├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── WORKSPACE ├── kepler ├── __init__.py ├── data_management │ ├── __init__.py │ ├── database_simulator.py │ ├── database_simulator_test.py │ ├── test_util.py │ ├── workload.py │ └── workload_test.py ├── database_integrations │ ├── __init__.py │ ├── model_serving │ │ ├── __init__.py │ │ ├── model_server.py │ │ ├── model_server_main.py │ │ ├── model_server_test.py │ │ ├── query_parsing_utils.py │ │ └── query_parsing_utils_test.py │ └── postgres │ │ ├── 13 │ │ ├── README │ │ └── kepler_extension.patch │ │ ├── __init__.py │ │ └── pg_hint_plan_extension_test.py ├── evaluation │ ├── __init__.py │ ├── e2e_evaluation.py │ ├── e2e_evaluation_main.py │ └── e2e_evaluation_test.py ├── examples │ ├── __init__.py │ └── active_learning_for_training_data_collection_main.py ├── inputs │ └── stack_query_templates.json ├── model_trainer │ ├── __init__.py │ ├── evaluation.py │ ├── evaluation_test.py │ ├── loss_functions.py │ ├── loss_functions_test.py │ ├── model_base.py │ ├── model_base_test.py │ ├── multihead_model.py │ ├── multihead_model_test.py │ ├── sngp_multihead_model.py │ ├── sngp_multihead_model_test.py │ ├── test_util.py │ ├── trainer.py │ ├── trainer_main.py │ ├── trainer_test.py │ ├── trainer_util.py │ └── trainer_util_test.py └── training_data_collection_pipeline │ ├── README │ ├── __init__.py │ ├── e2e_test.py │ ├── main_utils.py │ ├── parameter_generator.py │ ├── parameter_generator_main.py │ ├── parameter_generator_test.py │ ├── pg_distributed_training_execution_main.py │ ├── pg_execute_explain_tools.py │ ├── pg_execute_explain_tools_main.py │ ├── pg_execute_explain_tools_test.py │ ├── pg_execute_training_data_queries.py │ ├── pg_execute_training_data_queries_main.py │ ├── pg_execute_training_data_queries_test.py │ ├── pg_generate_plan_candidates.py │ ├── pg_generate_plan_candidates_main.py │ ├── pg_generate_plan_candidates_test.py │ ├── pg_generate_plan_costs_main.py │ ├── pg_perturb_plan_cardinalities.py │ ├── pg_perturb_plan_cardinalities_test.py │ ├── pg_plan_hint_extractor.py │ ├── pg_plan_hint_extractor_main.py │ ├── pg_plan_hint_extractor_test.py │ ├── query_plan_utils.py │ ├── query_plan_utils_test.py │ ├── query_printer_main.py │ ├── query_text_utils.py │ ├── query_text_utils_test.py │ ├── query_utils.py │ ├── query_utils_test.py │ ├── test_util.py │ └── testdata │ ├── filtered_generate_candidates_explain_plans.json │ └── generate_candidates_explain_plans.json └── requirements.txt /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | Currently Kepler is primarily a research project and we are not accepting 4 | contributions. If you encounter an issue or have thoughts on our work, please 5 | reach out to the Learned Systems team at learned-systems@google.com. 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kepler 2 | 3 | Kepler is a learning-based optimizer for parameterized queries that provides: 4 | 5 | 1. faster query execution time, 6 | 7 | 2. fast query planning time, and 8 | 9 | 3. robustness, i.e. is not worse than the existing optimizer with high 10 | probability. 11 | 12 | Broadly, for a given query template, Kepler generates a set of candidate plans 13 | via Row Count Evolution (RCE), then executes those plans on a workload of query 14 | instances. 
This execution data is then used as a training dataset for machine 15 | learning models that predict the best plan for any given query instance. Kepler 16 | leverages Spectral-Normalized Gaussian Process (SNGP) models to produce 17 | calibrated confidence scores for its predictions, and will fall back to the 18 | built-in (Postgres) optimizer if its confidence score is not sufficiently high. 19 | 20 | For more details, see our SIGMOD 2023 paper 21 | [Kepler: Robust Learning for Parametric Query Optimization](https://dl.acm.org/doi/abs/10.1145/3588963). 22 | A brief 23 | [summary](https://www.growkudos.com/publications/10.1145%25252F3588963/reader) 24 | appears on Kudos. 25 | 26 | ## Usage 27 | 28 | Usage for the individual library components can be found in their respective 29 | comment headers. 30 | 31 | The `examples` directory source code includes a demonstration for how one could 32 | use the Kepler dataset and associated tooling like `DatabaseSimulator` for 33 | active learning research to reduce the training data collection cost of Kepler. 34 | A sample run command is provided in the 35 | [Using SPQD with DatabaseSimulator](#using-spqd-with-databasesimulator) section 36 | below. 37 | 38 | ## Dataset 39 | 40 | To benchmark Kepler, we constructed a dataset using StackExchange data, which we 41 | call the Stack Parametric Query Dataset (SPQD). SPQD is based on the original 42 | [Stack](https://rm.cab/stack) dataset, and consists of 87 unique query 43 | templates. For each template, up to 50000 synthetic query instances are 44 | provided, and a set of candidate plans (generated via RCE) are executed over 45 | this workload, yielding nearly 14 years worth of execution data. 46 | 47 | Using SPQD enables research on the machine learning aspects of the problem 48 | without requiring a live database instance or paying for the cost of additional 49 | query execution. SPQD contains the cross-product of the plan-cover candidate 50 | plans and query instances. This enables a researcher to evaluate approaches to 51 | reduce the required query instance executions to train a sufficient model by 52 | using SPQD to cost and compare counterfactuals. 53 | 54 | Of the 87 query templates in SPQD, 42 are automatically extracted from the 55 | original Stack dataset, and the remaining 45 are manually written. Full dataset 56 | details can be found in the paper. 57 | 58 | The dataset can be downloaded at 59 | [https://storage.googleapis.com/gresearch/stack_parametric_query_dataset/stack_parametric_query_dataset.zip](https://storage.googleapis.com/gresearch/stack_parametric_query_dataset/stack_parametric_query_dataset.zip) 60 | (2.7 GB). 61 | 62 | ### Dataset Structure 63 | 64 | The base directory with `LICENSE.txt` consists of the following: 65 | 66 | 1. `stack_query_templates_with_metadata.json` contains the query templates, 67 | query instance parameters, and auxiliary metadata. 68 | 69 | 2. The `training_metadata` directory contains auxiliary information about 70 | database table contents that is used in model training (e.g. vocabulary 71 | selection). 72 | 73 | 3. The `execution_data` directory contains a `hints` subdirectory for 74 | RCE-generated plans and a `results` subdirectory for execution latency data 75 | and plan-cover metadata. The `results` subdirectory comprises the outputs of 76 | our training data pipeline code and corresponds directly to the input data 77 | for model training. The utilities in the `data_management` package are the 78 | recommended interface for using this data. 
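
For example, the `workload` utilities in the `data_management` package can construct and split workloads directly from the SPQD execution files. The sketch below is illustrative only: the paths are placeholders (they follow the q31_0 sample used later in this README), and it assumes the SPQD result files share the structure of the repository's unit-test fixtures.

```
import json

from kepler.data_management import workload

# Placeholder paths: execution data for template q31_0 from the unzipped SPQD.
with open("stack_q31_0.json") as f:
  query_execution_data = json.load(f)
with open("stack_q31_0_metadata.json") as f:
  query_execution_metadata = json.load(f)

# Plan ids restricted to the plan cover computed during data collection.
plan_discoverer = workload.KeplerPlanDiscoverer(
    query_execution_metadata=query_execution_metadata)

# Sample query instances without replacement and split them for experiments.
generator = workload.WorkloadGenerator(query_execution_data, 12345)
train, test = workload.split(generator.random_sample(1000),
                             first_half_count=800)
```

The `DatabaseSimulator` in the same package builds on this data to answer query-execution requests without a live database; see `kepler/examples/active_learning_for_training_data_collection_main.py` for a complete example.
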
79 | 80 | ## Run 81 | 82 | ### General Set Up 83 | 84 | Environment tools: 85 | 86 | * The workflow has been tested with python3.10 and requires `python3.10-venv` 87 | and `python3.10-dev`. A python version below 3.8 will definitely not work 88 | due to the tensorflow version. 89 | 90 | * The psycopg2 in `requirements.txt` requires `libpq-dev` (or the equivalent 91 | for the OS) 92 | 93 | * Some libraries require installing `build-essential`. 94 | 95 | Ubuntu-friendly command: 96 | 97 | ``` 98 | sudo apt-get install python3.10-venv python3.10-dev libpq-dev build-essential 99 | ``` 100 | 101 | The following sample commands presume `BASE` is an environment variable set to 102 | the base repository directory containing `README.md` and `requirements.txt`. 103 | 104 | ``` 105 | cd $BASE 106 | python3 -m venv venv 107 | source venv/bin/activate 108 | pip install -r requirements.txt 109 | ``` 110 | 111 | ### Using SPQD with DatabaseSimulator 112 | 113 | To run the active learning example, first download and unzip the 114 | [SPQD dataset](#dataset). The following command presumes `SPQD` is the base 115 | dataset directory containing `LICENSE.txt`. 116 | 117 | *Note: No live database connection is required.* 118 | 119 | ``` 120 | cd $BASE 121 | python -m kepler.examples.active_learning_for_training_data_collection_main --query_metadata $SPQD/stack_query_templates_with_metadata.json --vocab_data_dir $SPQD/training_metadata --execution_metadata $SPQD/execution_data/results/q31_0/execution_output/stack_q31_0_metadata.json --execution_data $SPQD/execution_data/results/q31_0/execution_output/stack_q31_0.json --query_id q31_0 122 | ``` 123 | 124 | ### Postgres Set Up 125 | 126 | Postgres set up is **not required** for machine learning research using the SPQD 127 | and the associated `DatabaseSimulator` tool from the `data_management` package. 128 | 129 | The utilities which do connect to a database were tested using a Postgres 13 130 | instance with [pg_hint_plan](https://github.com/ossc-db/pg_hint_plan/tree/PG13) 131 | for PG13 installed. The instructions below cover the installation of 132 | `pg_hint_plan` but do not cover the installation of Postgres 13. 133 | 134 | Note that installing `pg_hint_plan` may require first installing Postgres dev 135 | libraries, such as via the following: 136 | 137 | ``` 138 | sudo apt install postgresql-server-dev-13 139 | ``` 140 | 141 | After installing Postgres 13 and the matching version of `pg_hint_plan` as shown 142 | below, execute `CREATE EXTENSION pg_hint_plan` from the `psql` prompt: 143 | 144 | ``` 145 | git clone https://github.com/ossc-db/pg_hint_plan.git 146 | cd pg_hint_plan 147 | git fetch origin PG13 148 | git checkout PG13 149 | sudo make install 150 | sudo service postgresql restart 151 | ``` 152 | 153 | If you have not created a user on Postgres yet, you may need to use `sudo su - 154 | postgres` before typing `psql`. 155 | 156 | The pg_stat_statements library needs to be enabled. This is typically done by 157 | adding `pg_stat_statements` to the `shared_preload_libraries` line in 158 | `/etc/postgresql/13/main/postgresql.conf` and then restarting Postgres. The 159 | edited line may look like this: 160 | 161 | ``` 162 | shared_preload_libraries = 'pg_stat_statements' 163 | ``` 164 | 165 | At this point, one can execute training data collection queries and build models 166 | using this codebase. The remaining steps describe additional set up to repeat 167 | how this project integrated Kepler into Postgres at query time. 
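
Query-time serving relies on a running model server. As a sketch (all paths below are placeholders; the flags are those defined in `model_server_main.py`, and the server listens on 127.0.0.1:30709 by default), it can be launched with:

```
python -m kepler.database_integrations.model_serving.model_server_main \
  --query_templates_file path/to/templates_with_metadata.json \
  --query_metadata_file path/to/templates_with_metadata.json \
  --plan_hints_dir path/to/plan_hints \
  --plan_covers_dir path/to/execution_metadata \
  --tflite_sngp_model_dir path/to/tflite_models \
  --confidence_threshold 0.9
```
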
The pg_hint_plan 168 | extension was patched to reach out to a server which hosts the models. The 169 | server checks the hint string in the query for a query id. If the server has a 170 | model matching the query id, it will produce a set of plan hints. See 171 | `kepler/database_integrations/model_serving` for an implementation of this 172 | server. 173 | 174 | To patch `pg_hint_plan` for PG13, amend paths below and run the following: 175 | 176 | ``` 177 | path_to_pg_hint_plan="$HOME/pg_hint_plan" 178 | path_to_patch="$HOME/kepler/kepler/database_integrations/postgres/13/kepler_extension.patch" 179 | cd $path_to_pg_hint_plan 180 | patch -p0 < $path_to_patch 181 | sudo make install 182 | sudo service postgresql restart 183 | ``` 184 | 185 | ### Running Binaries 186 | 187 | *Coming soon.* 188 | 189 | ### Running Tests 190 | 191 | Running tests requires a few additional steps after installing Postgres 13 and 192 | pg_hint_plan. Be sure to have executed `CREATE EXTENSION pg_hint_plan` on the 193 | database, per instructions above. Open a prompt to Postgres by typing `psql` in 194 | the shell. As before, if you have not created a user on Postgres yet, you may 195 | need to use `sudo su - postgres` before typing `psql`. Then execute the 196 | following commands: 197 | 198 | ``` 199 | CREATE USER test SUPERUSER PASSWORD 'test'; 200 | CREATE DATABASE test; 201 | ``` 202 | 203 | pytest kepler 204 | 205 | ``` 206 | cd $BASE 207 | pytest kepler 208 | ``` 209 | 210 | ## Disclaimer 211 | 212 | This is not an officially supported Google product. 213 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/kepler/502506874960a81f50420184c6ab262a8b51a39d/WORKSPACE -------------------------------------------------------------------------------- /kepler/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /kepler/data_management/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | -------------------------------------------------------------------------------- /kepler/data_management/test_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Shared data to simplify unit tests.""" 17 | 18 | TEST_SEED = 12345 19 | TEST_QUERY_ID = "q0_0" 20 | QUERY_METADATA = { 21 | TEST_QUERY_ID: { 22 | "query": 23 | "SELECT * FROM t as t WHERE a=@param0 AND b=@param1 AND c=@param2", 24 | "predicates": [{ 25 | "table": "t", 26 | "alias": "t", 27 | "column": "a", 28 | "operator": "=", 29 | "data_type": "text", 30 | "distinct_values": ["first"] 31 | }, { 32 | "table": "t", 33 | "alias": "t", 34 | "column": "b", 35 | "operator": "=", 36 | "data_type": "text", 37 | "distinct_values": ["a", "b"] 38 | }, { 39 | "table": "t", 40 | "alias": "t", 41 | "column": "c", 42 | "operator": "=", 43 | "data_type": "int", 44 | "min": 0, 45 | "max": 2 46 | }, { 47 | "table": "t", 48 | "alias": "t", 49 | "column": "d", 50 | "operator": "=", 51 | "data_type": "float", 52 | "preprocess_type": "to_timestamp", 53 | "mean": 0, 54 | "variance": 1 55 | }] 56 | } 57 | } 58 | QUERY_EXECUTION_METADATA = {TEST_QUERY_ID: {"plan_cover": [0, 1, 2]}} 59 | QUERY_EXECUTION_DATA = { 60 | TEST_QUERY_ID: { 61 | "first####a####2####1970-01-01": { 62 | "default": 63 | 0, 64 | "results": [[{ 65 | "duration_ms": 3 66 | }, { 67 | "duration_ms": 2 68 | }, { 69 | "duration_ms": 1 70 | }], [{ 71 | "duration_ms": 3 72 | }, { 73 | "duration_ms": 2 74 | }, { 75 | "duration_ms": 2 76 | }], [{ 77 | "duration_ms": 5 78 | }, { 79 | "duration_ms": 4 80 | }, { 81 | "duration_ms": 3 82 | }], [{ 83 | "duration_ms": 6 84 | }, { 85 | "duration_ms": 6 86 | }, { 87 | "duration_ms": 4 88 | }]] 89 | }, 90 | "first####a####1####1970-01-02": { 91 | "default": 92 | 0, 93 | "results": [[{ 94 | "duration_ms": 33 95 | }, { 96 | "duration_ms": 22 97 | }, { 98 | "duration_ms": 11 99 | }], [{ 100 | "duration_ms": 33 101 | }, { 102 | "duration_ms": 22 103 | }, { 104 | "duration_ms": 22 105 | }], [{ 106 | "duration_ms": 123 107 | }, { 108 | "duration_ms": 45 109 | }, { 110 | "duration_ms": 33 111 | }], [{ 112 | "duration_ms": 61 113 | }, { 114 | "duration_ms": 64 115 | }, { 116 | "duration_ms": 45 117 | }]] 118 | }, 119 | "first####b####0####1970-01-03": { 120 | "default": 121 | 1, 122 | "results": [[{ 123 | "duration_ms": 333 124 | }, { 125 | "duration_ms": 222 126 | }, { 127 | "duration_ms": 222 128 | }], 129 | [{ 130 | "duration_ms": 333 131 | }, { 132 | "duration_ms": 222 133 | }, { 134 | "duration_ms": 111 135 | }], 136 | [{ 137 | "duration_ms": 555 138 | }, { 139 | "duration_ms": 444 140 | }, { 141 | "duration_ms": 333 142 | }], [{ 143 | "skipped": True 144 | }]] 145 | }, 146 | "first####b####1####1970-01-04": { 147 | "default": 148 | 0, 149 | "results": [[{ 150 | "duration_ms": 3333 151 | }, { 152 | "duration_ms": 3333 153 | }, { 154 | "duration_ms": 3333 155 | 
}], 156 | [{ 157 | "duration_ms": 3333 158 | }, { 159 | "duration_ms": 2222 160 | }, { 161 | "duration_ms": 2222 162 | }], 163 | [{ 164 | "duration_ms": 3333 165 | }, { 166 | "duration_ms": 2222 167 | }, { 168 | "duration_ms": 1111 169 | }], [{ 170 | "skipped": True 171 | }]] 172 | }, 173 | "first####c####0####1970-01-05": { 174 | "default": 175 | 2, 176 | "results": [[{ 177 | "duration_ms": 50 178 | }, { 179 | "timed_out": 31 180 | }, { 181 | "duration_ms": 15 182 | }], [{ 183 | "timed_out": 32 184 | }, { 185 | "timed_out": 32 186 | }, { 187 | "timed_out": 32 188 | }], [{ 189 | "duration_ms": 3 190 | }, { 191 | "duration_ms": 2 192 | }, { 193 | "duration_ms": 1 194 | }], [{ 195 | "skipped": True 196 | }]] 197 | }, 198 | "first####c####0####1975-01-05": { 199 | "default": 2, 200 | "results": { 201 | "default_timed_out": 3600 202 | } 203 | } 204 | } 205 | } 206 | 207 | QUERY_EXPLAIN_DATA = { 208 | TEST_QUERY_ID: { 209 | "first####a####2####1970-01-01": { 210 | "results": [[{ 211 | "total_cost": 3.1 212 | }], [{ 213 | "total_cost": 4.1 214 | }], [{ 215 | "total_cost": 5.1 216 | }], [{ 217 | "total_cost": 6.1 218 | }]] 219 | }, 220 | "first####a####1####1970-01-02": { 221 | "results": [[{ 222 | "total_cost": 13.1 223 | }], [{ 224 | "total_cost": 14.1 225 | }], [{ 226 | "total_cost": 15.1 227 | }], [{ 228 | "total_cost": 16.1 229 | }]] 230 | }, 231 | "first####b####0####1970-01-03": { 232 | "results": [[{ 233 | "total_cost": 23.1 234 | }], [{ 235 | "total_cost": 24.1 236 | }], [{ 237 | "total_cost": 25.1 238 | }], [{ 239 | "total_cost": 26.1 240 | }]] 241 | }, 242 | "first####b####1####1970-01-04": { 243 | "results": [[{ 244 | "total_cost": 33.1 245 | }], [{ 246 | "total_cost": 34.1 247 | }], [{ 248 | "total_cost": 35.1 249 | }], [{ 250 | "total_cost": 36.1 251 | }]] 252 | }, 253 | "first####c####0####1970-01-05": { 254 | "results": [[{ 255 | "total_cost": 43.1 256 | }], [{ 257 | "total_cost": 44.1 258 | }], [{ 259 | "total_cost": 45.1 260 | }], [{ 261 | "total_cost": 46.1 262 | }]] 263 | } 264 | } 265 | } 266 | 267 | PARAMETERS_POOL = [["first", "a", "2", "1970-01-01"], 268 | ["first", "a", "1", "1970-01-02"], 269 | ["first", "b", "0", "1970-01-03"], 270 | ["first", "b", "1", "1970-01-04"], 271 | ["first", "c", "0", "1970-01-05"]] 272 | -------------------------------------------------------------------------------- /kepler/data_management/workload_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for Workload.""" 17 | import copy 18 | from typing import List 19 | 20 | from kepler.data_management import test_util 21 | from kepler.data_management import workload 22 | from absl.testing import absltest 23 | from absl.testing import parameterized 24 | 25 | 26 | class WorkloadGenerateTest(parameterized.TestCase): 27 | 28 | def test_kepler_plan_discoverer(self): 29 | plans = workload.KeplerPlanDiscoverer( 30 | query_execution_data=test_util.QUERY_EXECUTION_DATA) 31 | self.assertEqual(plans.plan_ids, [0, 1, 2, 3]) 32 | 33 | plans = workload.KeplerPlanDiscoverer( 34 | query_execution_metadata=test_util.QUERY_EXECUTION_METADATA) 35 | self.assertEqual(plans.plan_ids, [0, 1, 2]) 36 | 37 | def test_kepler_plan_discoverer_illegal_init_calls(self): 38 | self.assertRaisesRegex( 39 | ValueError, 40 | "Exactly one of query_execution_data and query_execution_metadata must", 41 | workload.KeplerPlanDiscoverer) 42 | 43 | self.assertRaisesRegex( 44 | ValueError, 45 | "Exactly one of query_execution_data and query_execution_metadata must", 46 | workload.KeplerPlanDiscoverer, test_util.QUERY_EXECUTION_DATA, 47 | test_util.QUERY_EXECUTION_METADATA) 48 | 49 | def test_workload_generator_properties(self): 50 | workload_generator = workload.WorkloadGenerator( 51 | test_util.QUERY_EXECUTION_DATA, test_util.TEST_SEED) 52 | self.assertEqual(workload_generator.parameter_count, 4) 53 | self.assertLen(test_util.PARAMETERS_POOL, 54 | workload_generator.workload_pool_size) 55 | 56 | def _verify_parameters_from_pool( 57 | self, query_log: List[workload.QueryInstance]) -> None: 58 | for query_instance in query_log: 59 | self.assertEqual(query_instance.execution_frequency, 1) 60 | self.assertIn(query_instance.parameters, test_util.PARAMETERS_POOL) 61 | 62 | @parameterized.named_parameters( 63 | dict(testcase_name="small", n=1), dict(testcase_name="medium", n=3), 64 | dict(testcase_name="all", n=4)) 65 | def test_workload_generator_random_sample(self, n): 66 | """Verifies workloads generated as random samples. 67 | 68 | Each sample must be the right size, selected without replacement, and only 69 | contain members from the known population. 70 | 71 | Args: 72 | n: The number of parameters to sample. 73 | """ 74 | 75 | workload_generator = workload.WorkloadGenerator( 76 | test_util.QUERY_EXECUTION_DATA, test_util.TEST_SEED) 77 | 78 | random_workload = workload_generator.random_sample(n) 79 | self.assertEqual(random_workload.query_id, test_util.TEST_QUERY_ID) 80 | self.assertLen(random_workload.query_log, n) 81 | 82 | parameters_set = set( 83 | [tuple(entry.parameters) for entry in random_workload.query_log]) 84 | self.assertLen(parameters_set, len(random_workload.query_log)) 85 | 86 | self._verify_parameters_from_pool(random_workload.query_log) 87 | 88 | def test_shuffle(self): 89 | workload_generator = workload.WorkloadGenerator( 90 | test_util.QUERY_EXECUTION_DATA, test_util.TEST_SEED) 91 | workload_all = workload_generator.all() 92 | self.assertLen(workload_all.query_log, 5) 93 | self.assertEqual(workload_all.query_id, test_util.TEST_QUERY_ID) 94 | self._verify_parameters_from_pool(workload_all.query_log) 95 | 96 | # Shuffling changes the instance order. 
97 | current_query_log = copy.copy(workload_all.query_log) 98 | workload.shuffle(workload_all, seed=test_util.TEST_SEED) 99 | self.assertLen(workload_all.query_log, 5) 100 | self.assertEqual(workload_all.query_id, test_util.TEST_QUERY_ID) 101 | self._verify_parameters_from_pool(workload_all.query_log) 102 | 103 | self.assertNotEqual(current_query_log, workload_all.query_log) 104 | 105 | # Shuffling again with the same seed provides the same instance order. 106 | workload_all_again = workload_generator.all() 107 | workload.shuffle(workload_all_again, seed=test_util.TEST_SEED) 108 | self.assertLen(workload_all_again.query_log, 5) 109 | self.assertEqual(workload_all_again.query_id, test_util.TEST_QUERY_ID) 110 | self._verify_parameters_from_pool(workload_all_again.query_log) 111 | 112 | self.assertEqual(workload_all_again.query_log, workload_all.query_log) 113 | 114 | # Shuffling with a new seed provides a new instance order. 115 | workload_all_new_seed = workload_generator.all() 116 | workload.shuffle(workload_all_new_seed, seed=test_util.TEST_SEED + 1) 117 | self.assertLen(workload_all_new_seed.query_log, 5) 118 | self.assertEqual(workload_all_new_seed.query_id, test_util.TEST_QUERY_ID) 119 | self._verify_parameters_from_pool(workload_all_new_seed.query_log) 120 | 121 | self.assertNotEqual(workload_all_new_seed.query_log, workload_all.query_log) 122 | 123 | def test_split_illegal_calls(self): 124 | workload_generator = workload.WorkloadGenerator( 125 | test_util.QUERY_EXECUTION_DATA, test_util.TEST_SEED) 126 | workload_all = workload_generator.all() 127 | 128 | self.assertRaisesRegex(ValueError, "Exactly one of first_half_count and ", 129 | workload.split, workload_all, None, None) 130 | self.assertRaisesRegex(ValueError, "Exactly one of first_half_count and ", 131 | workload.split, workload_all, 2, .5) 132 | 133 | self.assertRaisesRegex( 134 | ValueError, 135 | "The first_half_count must be", 136 | workload.split, 137 | workload_all, 138 | first_half_count=-1) 139 | 140 | self.assertRaisesRegex( 141 | ValueError, 142 | "The first_half_fraction must be", 143 | workload.split, 144 | workload_all, 145 | first_half_fraction=-1) 146 | self.assertRaisesRegex( 147 | ValueError, 148 | "The first_half_fraction must be", 149 | workload.split, 150 | workload_all, 151 | first_half_fraction=3) 152 | 153 | def test_split(self): 154 | workload_generator = workload.WorkloadGenerator( 155 | test_util.QUERY_EXECUTION_DATA, test_util.TEST_SEED) 156 | workload_all = workload_generator.all() 157 | 158 | workload_train_count, workload_test_count = workload.split( 159 | workload_all, first_half_count=3) 160 | self.assertEqual(workload_train_count.query_id, test_util.TEST_QUERY_ID) 161 | self.assertLen(workload_train_count.query_log, 3) 162 | self.assertEqual(workload_all.query_log[:3], workload_train_count.query_log) 163 | 164 | self.assertEqual(workload_test_count.query_id, test_util.TEST_QUERY_ID) 165 | self.assertLen(workload_test_count.query_log, 2) 166 | self.assertEqual(workload_all.query_log[3:], workload_test_count.query_log) 167 | 168 | workload_train_fraction, workload_test_fraction = workload.split( 169 | workload_all, first_half_fraction=.25) 170 | self.assertEqual(workload_train_fraction.query_id, test_util.TEST_QUERY_ID) 171 | self.assertLen(workload_train_fraction.query_log, 1) 172 | self.assertEqual(workload_all.query_log[:1], 173 | workload_train_fraction.query_log) 174 | 175 | self.assertEqual(workload_test_fraction.query_id, test_util.TEST_QUERY_ID) 176 | 
self.assertLen(workload_test_fraction.query_log, 4) 177 | self.assertEqual(workload_all.query_log[1:], 178 | workload_test_fraction.query_log) 179 | 180 | 181 | if __name__ == "__main__": 182 | absltest.main() 183 | -------------------------------------------------------------------------------- /kepler/database_integrations/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /kepler/database_integrations/model_serving/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /kepler/database_integrations/model_serving/model_server.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Server returning predicted query plan hints for provided query instances.""" 17 | 18 | import socket 19 | 20 | from typing import Any, Dict 21 | 22 | from kepler.database_integrations.model_serving import query_parsing_utils 23 | from kepler.model_trainer import model_base 24 | 25 | _READ_BYTES_COUNT = 2048 26 | 27 | 28 | class ModelServer: 29 | """Serves query plan hints for provided query instances. 30 | 31 | The server runs model inference for known query ids and returns a plan 32 | prediction if it is able to make one. The plan prediction is conveyed via a 33 | hint string. 
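
  The wire protocol: each request is a 4-byte little-endian length prefix
  followed by that many bytes of query text, where the query begins with a
  /*+ query_id */ comment. The response is the pg_hint_plan hint string for
  the predicted plan, or a single space when no prediction is made (unknown
  query id or an abstaining predictor).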
34 | """ 35 | 36 | def __init__( 37 | self, 38 | host: str, 39 | port: int, 40 | param_extractors: Dict[str, query_parsing_utils.ParamExtractor], 41 | plan_hints: Any, 42 | predictors: Dict[str, model_base.ModelPredictorBase], 43 | read_bytes_count: int = _READ_BYTES_COUNT, 44 | ): 45 | """Prepare a server to provide hints for queries. 46 | 47 | Args: 48 | host: The host address for the model server. 49 | port: The port for the model server to use. 50 | param_extractors: A mapping from query id to its corresponding 51 | ParamExtractor. 52 | plan_hints: A mapping from query id to pg_hint_plan hints representing the 53 | set of query plans for execution. 54 | predictors: A mapping from query id to its corresponding query plan 55 | predictor. 56 | read_bytes_count: The number of bytes to read at a time from the socket. 57 | This is exposed primarily for testing. 58 | 59 | Raises: 60 | ValueError: If predictors contains query_ids not found in plan_hints or 61 | param_extractors. 62 | """ 63 | self._host = host 64 | self._port = port 65 | self._param_extractors = param_extractors 66 | self._plan_hints = plan_hints 67 | self._predictors = predictors 68 | self._read_bytes_count = read_bytes_count 69 | 70 | extra_query_ids = self._predictors.keys() - self._param_extractors.keys() 71 | if extra_query_ids: 72 | raise ValueError( 73 | "Provided predictors contains the following query ids not found in" 74 | f" param_extractors. {extra_query_ids}" 75 | ) 76 | 77 | extra_query_ids = self._predictors.keys() - self._plan_hints.keys() 78 | if extra_query_ids: 79 | raise ValueError( 80 | "Provided predictors contains the following query ids not found in" 81 | f" plan_hints. {extra_query_ids}" 82 | ) 83 | 84 | def run(self) -> None: 85 | """Run the main server loop to await connections and handle requests.""" 86 | 87 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 88 | s.bind((self._host, self._port)) 89 | s.listen() 90 | 91 | while True: 92 | conn, _ = s.accept() 93 | with conn: 94 | while True: 95 | data = conn.recv(4) 96 | data_length = int.from_bytes(data, "little") 97 | data = b"" 98 | while data_length > len(data): 99 | data += conn.recv(self._read_bytes_count) 100 | 101 | if not data: 102 | break 103 | 104 | query = data.decode() 105 | query_id = query_parsing_utils.extract_comment_content(query) 106 | 107 | if query_id not in self._predictors: 108 | response = " " 109 | else: 110 | query_params = self._param_extractors[query_id].get_params(query) 111 | plan_ids, _ = self._predictors[query_id].predict(query_params) 112 | assert len(plan_ids) == 1 113 | plan_id = plan_ids[0] 114 | 115 | if plan_id is not None: 116 | response = query_parsing_utils.extract_comment_content( 117 | self._plan_hints[query_id][int(plan_id)]["hints"] 118 | ) 119 | else: 120 | response = " " 121 | 122 | conn.sendall(response.encode()) 123 | -------------------------------------------------------------------------------- /kepler/database_integrations/model_serving/model_server_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Runs a Kepler model prediction server. 17 | 18 | The ModelServer supports any predictor implementing ModelPredictorBase. This 19 | program instantiates a ModelServer using an 20 | sngp_multihead_model.SNGPMultiheadModel converted to tflite. 21 | """ 22 | 23 | import json 24 | import os 25 | 26 | from absl import app 27 | from absl import flags 28 | 29 | from kepler.database_integrations.model_serving import model_server 30 | from kepler.database_integrations.model_serving import query_parsing_utils 31 | from kepler.model_trainer import sngp_multihead_model 32 | 33 | _LOCAL_HOST = "127.0.0.1" 34 | _KEPLER_PORT = 30709 35 | _TFLITE_FILE_ENDING = ".tflite" 36 | 37 | _HOST = flags.DEFINE_string( 38 | "host", _LOCAL_HOST, "The host address for the model server." 39 | ) 40 | _PORT = flags.DEFINE_integer( 41 | "port", _KEPLER_PORT, "The port for the model server to use." 42 | ) 43 | 44 | _QUERY_TEMPLATES_FILE = flags.DEFINE_string( 45 | "query_templates_file", 46 | None, 47 | "Path to file in which query templates are stored.", 48 | ) 49 | flags.mark_flag_as_required("query_templates_file") 50 | 51 | _PLAN_HINTS_DIR = flags.DEFINE_string( 52 | "plan_hints_dir", 53 | None, 54 | ( 55 | "Directory containing plan hints files. The directory must contain a" 56 | " hint file for each query id found in _TFLITE_SNGP_MODEL_DIR. Hint" 57 | " files for additional query ids will not cause a problem." 58 | ), 59 | ) 60 | flags.mark_flag_as_required("plan_hints_dir") 61 | 62 | _PLAN_COVERS_DIR = flags.DEFINE_string( 63 | "plan_covers_dir", 64 | None, 65 | ( 66 | "Directory containing the metadata files which contain the plan_cover." 67 | " These are generated during training data collection. The directory" 68 | " must contain a metadata file for each query id found in" 69 | " _TFLITE_SNGP_MODEL_DIR. Metadata files for additional query ids will" 70 | " not cause a problem." 71 | ), 72 | ) 73 | flags.mark_flag_as_required("plan_covers_dir") 74 | 75 | _TFLITE_SNGP_MODEL_DIR = flags.DEFINE_string( 76 | "tflite_sngp_model_dir", 77 | None, 78 | ( 79 | "Directory containing tflite SNGP models to use for prediction. Assumes" 80 | " the contents of the directory are files named .tflite" 81 | ), 82 | ) 83 | flags.mark_flag_as_required("tflite_sngp_model_dir") 84 | 85 | 86 | _QUERY_METADATA_FILE = flags.DEFINE_string( 87 | "query_metadata_file", 88 | None, 89 | ( 90 | "File containing metadata describing the query predicates. This is the" 91 | " same file used in training." 92 | ), 93 | ) 94 | flags.mark_flag_as_required("query_metadata_file") 95 | 96 | _CONFIDENCE_THRESHOLD = flags.DEFINE_float( 97 | "confidence_threshold", 98 | 0.9, 99 | ( 100 | "The confidence threshold used by the SNGP models to determine whether" 101 | " to a prediction or abstain." 
102 | ), 103 | ) 104 | 105 | 106 | def main(unused_argv): 107 | with open(_QUERY_TEMPLATES_FILE.value) as f: 108 | templates = json.load(f) 109 | 110 | param_extractors = {} 111 | for query_id, template_entry in templates.items(): 112 | param_extractors[query_id] = query_parsing_utils.ParamExtractor( 113 | query_template=template_entry["query"] 114 | ) 115 | 116 | plan_hints = {} 117 | for plan_hints_file in os.listdir(_PLAN_HINTS_DIR.value): 118 | with open(os.path.join(_PLAN_HINTS_DIR.value, plan_hints_file)) as f: 119 | query_plan_hints = json.load(f) 120 | 121 | query_id_overlap = set(query_plan_hints).intersection(set(plan_hints)) 122 | assert not query_id_overlap, ( 123 | "One of these query ids was observed in multiple hint files:" 124 | f" {query_id_overlap}" 125 | ) 126 | plan_hints.update(query_plan_hints) 127 | 128 | plan_covers = {} 129 | for plan_cover_file in os.listdir(_PLAN_COVERS_DIR.value): 130 | with open(os.path.join(_PLAN_COVERS_DIR.value, plan_cover_file)) as f: 131 | metadata = json.load(f) 132 | 133 | query_id_overlap = set(metadata).intersection(set(plan_covers)) 134 | assert not query_id_overlap, ( 135 | "One of these query ids was observed in multiple metadata files:" 136 | f" {query_id_overlap}" 137 | ) 138 | for query_id, entry in metadata.items(): 139 | plan_covers[query_id] = entry["plan_cover"] 140 | 141 | with open(_QUERY_METADATA_FILE.value) as f: 142 | query_metadata = json.load(f) 143 | 144 | tflite_predictor_files = os.listdir(_TFLITE_SNGP_MODEL_DIR.value) 145 | predictors = {} 146 | for tflite_predictor_file in tflite_predictor_files: 147 | query_id = tflite_predictor_file[ 148 | : tflite_predictor_file.index(_TFLITE_FILE_ENDING) 149 | ] 150 | 151 | with open( 152 | os.path.join(_TFLITE_SNGP_MODEL_DIR.value, tflite_predictor_file), "rb" 153 | ) as f: 154 | tflite_model_content = f.read() 155 | predictor = sngp_multihead_model.SNGPMultiheadModelTFLitePredictor( 156 | tflite_model_content=tflite_model_content, 157 | metadata=query_metadata[query_id], 158 | plan_cover=plan_covers[query_id], 159 | confidence_threshold=_CONFIDENCE_THRESHOLD.value, 160 | ) 161 | predictors[query_id] = predictor 162 | 163 | server = model_server.ModelServer( 164 | host=_HOST.value, 165 | port=_PORT.value, 166 | param_extractors=param_extractors, 167 | plan_hints=plan_hints, 168 | predictors=predictors, 169 | ) 170 | server.run() 171 | 172 | 173 | if __name__ == "__main__": 174 | app.run(main) 175 | -------------------------------------------------------------------------------- /kepler/database_integrations/model_serving/model_server_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for model_server.""" 17 | 18 | import json 19 | import multiprocessing 20 | import socket 21 | import time 22 | from typing import Any, Dict, List, Optional, Tuple 23 | 24 | import numpy as np 25 | 26 | from kepler.database_integrations.model_serving import model_server 27 | from kepler.database_integrations.model_serving import query_parsing_utils 28 | from kepler.model_trainer import model_base 29 | from kepler.training_data_collection_pipeline import query_text_utils 30 | from kepler.training_data_collection_pipeline import test_util 31 | from absl.testing import absltest 32 | from absl.testing import parameterized 33 | 34 | _LOCAL_HOST = "127.0.0.1" 35 | _KEPLER_PORT = 30709 36 | 37 | 38 | _QUERY_REQUEST = "/*+ {query_id} */ {query}" 39 | 40 | _HINTS_UNUSED = """{"hints": "/*+ unused */", "source": "unused"}""" 41 | _HINTS = """[{hints_0}, {hints_unused}, {hints_1}]""".format( 42 | hints_0=test_util.HINTS_0, 43 | hints_unused=_HINTS_UNUSED, 44 | hints_1=test_util.HINTS_1, 45 | ) 46 | 47 | _PARAMS_HINT_0 = [0, "hi"] 48 | _PARAMS_HINT_1 = [2, "hi"] 49 | _PARAMS_NO_HINT = [0, "skip"] 50 | 51 | 52 | def _send_request(query_id: str, query: str) -> str: 53 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 54 | s.connect((_LOCAL_HOST, _KEPLER_PORT)) 55 | request = _QUERY_REQUEST.format(query_id=query_id, query=query).encode() 56 | s.sendall(len(request).to_bytes(4, "little")) 57 | s.sendall(request) 58 | return s.recv(2048).decode() 59 | 60 | 61 | class ConcreteModelPredictor(model_base.ModelPredictorBase): 62 | """Make non-abstract class for testing.""" 63 | 64 | def predict( 65 | self, params: List[Any] 66 | ) -> Tuple[Optional[int], Optional[Dict[str, Any]]]: 67 | if params[1] == "skip": 68 | return np.array([None]), None 69 | 70 | return np.array([params[0]]), None 71 | 72 | 73 | class ModelServerTest(parameterized.TestCase): 74 | 75 | def setUp(self): 76 | super().setUp() 77 | self._query_template = json.loads(test_util.TEST_TEMPLATE_STRING)["query"] 78 | self._param_extractors = { 79 | test_util.TEST_QUERY_ID: query_parsing_utils.ParamExtractor( 80 | query_template=self._query_template 81 | ) 82 | } 83 | self._plan_hints = {test_util.TEST_QUERY_ID: json.loads(_HINTS)} 84 | self._predictors = {test_util.TEST_QUERY_ID: ConcreteModelPredictor()} 85 | self._server_process = None 86 | 87 | def tearDown(self): 88 | if self._server_process: 89 | self._server_process.terminate() 90 | while self._server_process.is_alive(): 91 | time.sleep(0.01) 92 | 93 | super().tearDown() 94 | 95 | def _start_server( 96 | self, 97 | read_bytes_count: int = 1024, 98 | param_extractors=None, 99 | plan_hints=None, 100 | ) -> None: 101 | param_extractors = ( 102 | param_extractors 103 | if param_extractors is not None 104 | else self._param_extractors 105 | ) 106 | plan_hints = plan_hints if plan_hints is not None else self._plan_hints 107 | 108 | server = model_server.ModelServer( 109 | host=_LOCAL_HOST, 110 | port=_KEPLER_PORT, 111 | param_extractors=param_extractors, 112 | plan_hints=plan_hints, 113 | predictors=self._predictors, 114 | read_bytes_count=read_bytes_count, 115 | ) 116 | self._server_process = multiprocessing.Process(target=server.run) 117 | self._server_process.start() 118 | 119 | # Block until the server is ready to accept connections. 
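    # The server runs in a separate process, so retry the connection until
    # the child process has finished binding and listening on the socket.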
120 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 121 | while True: 122 | try: 123 | s.connect((_LOCAL_HOST, _KEPLER_PORT)) 124 | return 125 | except ConnectionRefusedError: 126 | time.sleep(0.01) 127 | 128 | @parameterized.named_parameters( 129 | dict( 130 | testcase_name="read_bytes_count larger than query", 131 | read_bytes_count=1024, 132 | ), 133 | dict( 134 | testcase_name="read_bytes_count smaller than query", 135 | read_bytes_count=5, 136 | ), 137 | ) 138 | def test_read_bytes_count(self, read_bytes_count: int): 139 | self._start_server(read_bytes_count=read_bytes_count) 140 | 141 | hints = _send_request( 142 | query_id=test_util.TEST_QUERY_ID, 143 | query=query_text_utils.substitute_query_params( 144 | query=self._query_template, params=_PARAMS_HINT_0 145 | ), 146 | ) 147 | self.assertIn("MergeJoin", hints) 148 | 149 | @parameterized.named_parameters( 150 | dict( 151 | testcase_name="hint 0", 152 | params=_PARAMS_HINT_0, 153 | expected="MergeJoin(foo bar) Leading((foo bar))", 154 | ), 155 | dict( 156 | testcase_name="hint 1", 157 | params=_PARAMS_HINT_1, 158 | expected="HashJoin(foo bar) Leading((foo bar))", 159 | ), 160 | dict(testcase_name="no hints", params=_PARAMS_NO_HINT, expected=" "), 161 | ) 162 | def test_get_hints(self, params: List[Any], expected: str): 163 | self._start_server() 164 | 165 | hints = _send_request( 166 | query_id=test_util.TEST_QUERY_ID, 167 | query=query_text_utils.substitute_query_params( 168 | query=self._query_template, params=params 169 | ), 170 | ) 171 | self.assertEqual(hints, expected) 172 | 173 | @parameterized.named_parameters( 174 | dict(testcase_name="unknown id", query_id="q22"), 175 | dict(testcase_name="empty", query_id=""), 176 | ) 177 | def test_query_id_without_predictor(self, query_id: str): 178 | self._start_server() 179 | 180 | hints = _send_request( 181 | query_id=query_id, 182 | query=query_text_utils.substitute_query_params( 183 | query=self._query_template, params=_PARAMS_HINT_1 184 | ), 185 | ) 186 | self.assertEqual(hints, " ") 187 | 188 | def test_inconsistent_init(self): 189 | plan_hints = {"q22": json.loads(_HINTS)} 190 | self.assertRaisesRegex( 191 | ValueError, 192 | "not found in plan_hints", 193 | self._start_server, 194 | plan_hints=plan_hints, 195 | ) 196 | 197 | param_extractors = { 198 | "q23": query_parsing_utils.ParamExtractor( 199 | query_template=self._query_template 200 | ) 201 | } 202 | self.assertRaisesRegex( 203 | ValueError, 204 | "not found in param_extractors", 205 | self._start_server, 206 | param_extractors=param_extractors, 207 | ) 208 | 209 | 210 | if __name__ == "__main__": 211 | absltest.main() 212 | -------------------------------------------------------------------------------- /kepler/database_integrations/model_serving/query_parsing_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Query parsing utilties to extract attributes for the prediction server.""" 17 | 18 | import re 19 | from typing import Any, Dict, List, Optional 20 | 21 | import pglast 22 | 23 | _COMMENT_START = "/*+" 24 | _COMMENT_END = "*/" 25 | 26 | _SKIP_ATTRIBUTES = ["location"] 27 | 28 | _PARAM_REGEX = r"@param[0-9]*" 29 | _PARAM_SINGLE_QUOTE_REGEX = r"'@param[0-9]*'" 30 | _PARAM_DOUBLE_QUOTE_REGEX = r"\"@param[0-9]*\"" 31 | _PARAM_TOKEN_CHECKER = re.compile(_PARAM_REGEX) 32 | 33 | 34 | def _add_quotes_to_param(match): 35 | return f"'{match.group()}'" 36 | 37 | 38 | def _strip_quotes_from_param(match): 39 | return match.group()[1:-1] 40 | 41 | 42 | def _canonize_params(query: str) -> str: 43 | """Ensures all params are single-quoted.""" 44 | query = re.sub(_PARAM_SINGLE_QUOTE_REGEX, _strip_quotes_from_param, query) 45 | query = re.sub(_PARAM_DOUBLE_QUOTE_REGEX, _strip_quotes_from_param, query) 46 | return re.sub(_PARAM_REGEX, _add_quotes_to_param, query) 47 | 48 | 49 | def _get_nodes_flattened(query: str) -> List[Any]: 50 | """Flattens out the query nodes and values from the parsed AST.""" 51 | root = pglast.Node(pglast.parse_sql(query)) 52 | nodes_flattened = [] 53 | _get_nodes_flattened_helper(root, nodes_flattened) 54 | return nodes_flattened 55 | 56 | 57 | def _get_nodes_flattened_helper(node, nodes_flattened: List[Any]): 58 | """Executes a step in the recursive tree walk to flatten the parsed AST.""" 59 | if isinstance(node, pglast.node.Scalar): 60 | if node.parent_attribute in _SKIP_ATTRIBUTES: 61 | return 62 | 63 | nodes_flattened.append(node.value) 64 | return 65 | elif isinstance(node, pglast.node.Node): 66 | nodes_flattened.append(node.node_tag) 67 | for child in node: 68 | _get_nodes_flattened_helper(child, nodes_flattened) 69 | 70 | 71 | def _get_param_indices_map(nodes_flattened: List[Any]) -> Dict[str, int]: 72 | """Maps params to positions to extract param values from query instances.""" 73 | param_indices_map = {} 74 | for i, token in enumerate(nodes_flattened): 75 | if _PARAM_TOKEN_CHECKER.search(str(token)): 76 | param_indices_map[token] = i 77 | 78 | return param_indices_map 79 | 80 | 81 | class ParamExtractor: 82 | """Extracts param values from query instances of a given query template.""" 83 | 84 | def __init__(self, query_template: str): 85 | canonized_query_template = _canonize_params(query_template) 86 | nodes_flattened = _get_nodes_flattened(canonized_query_template) 87 | self._node_count = len(nodes_flattened) 88 | self._param_indices_map = _get_param_indices_map(nodes_flattened) 89 | 90 | # Check for parameter number issues. 91 | self._extract_params(nodes_flattened) 92 | 93 | def _extract_params(self, nodes_flattened: List[Any]) -> List[Any]: 94 | params = [] 95 | for i in range(len(self._param_indices_map)): 96 | param_key = f"@param{i}" 97 | 98 | if param_key not in self._param_indices_map: 99 | raise ValueError("Param indices are not consecutive starting with 0.") 100 | 101 | params.append(nodes_flattened[self._param_indices_map[param_key]]) 102 | return params 103 | 104 | def get_params(self, query_instance: str) -> List[Any]: 105 | """Extracts parameter binding values from a query instance. 106 | 107 | The caller is expected to cast data types as required. 108 | 109 | The current implementation presumes the limitations used in the current 110 | Kepler iteration. For example, IN lists are expected to only contain a 111 | single element. 
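
    For example, for the template
      SELECT foo FROM bar where x > @param0 and y = '@param1'
    the query instance
      SELECT foo FROM bar where x > 5 and y = '2022-01-01'
    yields [5, '2022-01-01'].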
112 | 113 | Args: 114 | query_instance: An instance with parameter values substituted into the 115 | query template used to instantiate this instance of ParamExtractor. 116 | 117 | Returns: 118 | A list of param binding values found in the query instance. The order of 119 | the values in the list matches the @param tokens from the 120 | query template in sorted order, not appearance order. 121 | """ 122 | nodes_flattened = _get_nodes_flattened(query_instance) 123 | if len(nodes_flattened) != self._node_count: 124 | raise ValueError( 125 | f"Mismatch in flattened query tree size between query instance ({len(nodes_flattened)}) and template ({self._node_count})." 126 | ) 127 | 128 | return self._extract_params(nodes_flattened) 129 | 130 | 131 | def extract_comment_content(query: str) -> Optional[str]: 132 | """Extracts the content from the query comment syntax. 133 | 134 | This function assumes a protocol where the comment is of the following exact 135 | format: 136 | /*+ */ SELECT ... 137 | 138 | Args: 139 | query: The SQL query text to parse. 140 | 141 | Returns: 142 | The comment content if found, None otherwise. 143 | """ 144 | try: 145 | return query[len(_COMMENT_START):query.index(_COMMENT_END)].strip() 146 | except ValueError: 147 | return None 148 | -------------------------------------------------------------------------------- /kepler/database_integrations/model_serving/query_parsing_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for query_parsing_utils.""" 17 | from typing import Any, List 18 | 19 | 20 | from kepler.database_integrations.model_serving import query_parsing_utils 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | 24 | 25 | class QueryParsingUtilsTest(parameterized.TestCase): 26 | 27 | def test_extract_comment_content(self): 28 | self.assertEqual( 29 | query_parsing_utils.extract_comment_content("/*+ q12_1 */"), "q12_1" 30 | ) 31 | self.assertEqual( 32 | query_parsing_utils.extract_comment_content("/*+q12_2 */"), "q12_2" 33 | ) 34 | self.assertEqual( 35 | query_parsing_utils.extract_comment_content("/*+ q12_3*/"), "q12_3" 36 | ) 37 | self.assertEqual( 38 | query_parsing_utils.extract_comment_content("/*+q12_4*/"), "q12_4" 39 | ) 40 | self.assertEqual( 41 | query_parsing_utils.extract_comment_content("/*+ hi there*/"), 42 | "hi there", 43 | ) 44 | self.assertIsNone(query_parsing_utils.extract_comment_content("hi there")) 45 | 46 | @parameterized.named_parameters( 47 | dict( 48 | testcase_name="0 params", 49 | query_template="SELECT foo FROM bar where x > 0", 50 | query_instance="SELECT foo FROM bar where x > 0", 51 | expected=[], 52 | ), 53 | dict( 54 | testcase_name="some params with hint to ignore", 55 | query_template=( 56 | "/*+ q11_0 */ SELECT foo FROM bar where x > @param0 and y =" 57 | " '@param1' and z in ('@param2')" 58 | ), 59 | query_instance=( 60 | "SELECT foo FROM bar where x > 5 and y = '2022-01-01' and z in" 61 | " ('bye')" 62 | ), 63 | expected=[5, "2022-01-01", "bye"], 64 | ), 65 | dict( 66 | testcase_name="out-of-order params", 67 | query_template=( 68 | "SELECT foo FROM bar where x > @param2 and y = '@param0' and z in" 69 | " ('@param1')" 70 | ), 71 | query_instance=( 72 | "SELECT foo FROM bar where x > 5 and y = '2022-01-01' and z in" 73 | " ('bye')" 74 | ), 75 | expected=["2022-01-01", "bye", 5], 76 | ), 77 | dict( 78 | testcase_name="double digit params", 79 | query_template=( 80 | "SELECT foo FROM bar where x > @param0 and x > @param10 and " 81 | "x > @param2 and x > @param3 and x > @param4 and x > @param5 and " 82 | "x > @param6 and x > @param7 and x > @param8 and x > @param9 and " 83 | "x > @param1" 84 | ), 85 | query_instance=( 86 | "SELECT foo FROM bar where x > 1 and x > 11 and x > 3 and x > 4 " 87 | "and x > 5 and x > 6 and x > 7 and x > 8 and x > 9 and x > 10 " 88 | "and x > 2" 89 | ), 90 | expected=list(range(1, 12)), 91 | ), 92 | ) 93 | def test_extract_params(self, query_template: str, query_instance: str, 94 | expected: List[Any]): 95 | param_extractor = query_parsing_utils.ParamExtractor(query_template) 96 | self.assertEqual(param_extractor.get_params(query_instance), expected) 97 | 98 | def test_bad_query_template(self): 99 | self.assertRaisesRegex(ValueError, "Param indices are not consecutive", 100 | query_parsing_utils.ParamExtractor, 101 | "SELECT foo from bar where x > @param1") 102 | self.assertRaisesRegex( 103 | ValueError, "Param indices are not consecutive", 104 | query_parsing_utils.ParamExtractor, 105 | "SELECT foo from bar where x > @param0 and x > @param2") 106 | self.assertRaisesRegex( 107 | ValueError, "Param indices are not consecutive starting with 0", 108 | query_parsing_utils.ParamExtractor, 109 | "SELECT foo from bar where x > @param1 and x > @param2") 110 | 111 | def test_bad_query_instance(self): 112 | param_extractor = query_parsing_utils.ParamExtractor( 113 | "SELECT foo from bar where x > @param0") 114 | self.assertRaisesRegex(ValueError, "Mismatch in flattened query tree size", 115 | 
param_extractor.get_params, 116 | "SELECT foo from bar where x > @param0 and 1=1") 117 | 118 | 119 | if __name__ == "__main__": 120 | absltest.main() 121 | -------------------------------------------------------------------------------- /kepler/database_integrations/postgres/13/README: -------------------------------------------------------------------------------- 1 | # Kepler in Postgres 13 2 | 3 | We extend the pg_hint_plan extension to request hints from a Kepler server. 4 | 5 | ## Set Up 6 | git clone https://github.com/ossc-db/pg_hint_plan.git 7 | cd pg_hint_plan 8 | git fetch origin PG13 9 | git checkout PG13 10 | patch -p0 < /path/to/kepler_extension.patch 11 | sudo make install 12 | sudo service postgresql restart 13 | 14 | ## Enabling extension. 15 | Once the host is set, Postgres will expect the server to be reachable and error otherwise. 16 | 17 | LOAD 'pg_hint_plan'; 18 | 19 | SET pg_hint_plan.kepler_port TO ; 20 | SET pg_hint_plan.kepler_host TO ; 21 | 22 | ## Test Set Up 23 | The following may be required on a fresh postgres to run tests: 24 | 25 | CREATE DATABASE test; 26 | CREATE ROLE test SUPERUSER; 27 | ALTER ROLE test PASSWORD 'test'; 28 | ALTER ROLE test WITH LOGIN; -------------------------------------------------------------------------------- /kepler/database_integrations/postgres/13/kepler_extension.patch: -------------------------------------------------------------------------------- 1 | diff --git pg_hint_plan.c pg_hint_plan.c 2 | index 21e24d3..556d470 100644 3 | --- pg_hint_plan.c 4 | +++ pg_hint_plan.c 5 | @@ -8,6 +8,8 @@ 6 | *------------------------------------------------------------------------- 7 | */ 8 | #include 9 | +#include 10 | +#include 11 | 12 | #include "postgres.h" 13 | #include "access/genam.h" 14 | @@ -519,7 +521,10 @@ static int pg_hint_plan_parse_message_level = INFO; 15 | static int pg_hint_plan_debug_message_level = LOG; 16 | /* Default is off, to keep backward compatibility. */ 17 | static bool pg_hint_plan_enable_hint_table = false; 18 | +static char* pg_hint_plan_kepler_host = NULL; 19 | +static int pg_hint_plan_kepler_port = 30709; 20 | 21 | +static int kepler_fd = -1; /* Kepler server socket file descriptor */ 22 | static int plpgsql_recurse_level = 0; /* PLpgSQL recursion level */ 23 | static int recurse_level = 0; /* recursion level incl. direct SPI calls */ 24 | static int hint_inhibit_level = 0; /* Inhibit hinting if this is above 0 */ 25 | @@ -687,6 +692,30 @@ _PG_init(void) 26 | NULL, 27 | NULL); 28 | 29 | + DefineCustomStringVariable("pg_hint_plan.kepler_host", 30 | + "Kepler server host", 31 | + "Request plan hints from Kepler server at specified host", 32 | + &pg_hint_plan_kepler_host, 33 | + NULL, 34 | + PGC_USERSET, 35 | + 0, 36 | + NULL, 37 | + NULL, 38 | + NULL); 39 | + 40 | + DefineCustomIntVariable("pg_hint_plan.kepler_port", 41 | + "Kepler server port", 42 | + "Listening port for Kepler server", 43 | + &pg_hint_plan_kepler_port, 44 | + 30709, 45 | + 1, 46 | + 65536, 47 | + PGC_USERSET, 48 | + 0, 49 | + NULL, 50 | + NULL, 51 | + NULL); 52 | + 53 | /* Install hooks. 
*/ 54 | prev_post_parse_analyze_hook = post_parse_analyze_hook; 55 | post_parse_analyze_hook = pg_hint_plan_post_parse_analyze; 56 | @@ -2003,6 +2032,122 @@ get_hints_from_comment(const char *p) 57 | return p; 58 | } 59 | 60 | +static void 61 | +terminate_kepler_connection() 62 | +{ 63 | + shutdown(kepler_fd, SHUT_RDWR); 64 | + kepler_fd = -1; 65 | +} 66 | + 67 | +static bool 68 | +create_kepler_connection() 69 | +{ 70 | + int ret = -1; 71 | + struct sockaddr_in server_addr = {0}; 72 | + 73 | + if (kepler_fd > 0) 74 | + return true; 75 | + 76 | + elog(pg_hint_plan_debug_message_level, "pg_hint_plan: Creating new connection to Kepler server"); 77 | + kepler_fd = socket(AF_INET, SOCK_STREAM, 0); 78 | + if (kepler_fd < 0) 79 | + { 80 | + elog(ERROR, "Unable to create Kepler socket"); 81 | + return false; 82 | + } 83 | + 84 | + server_addr.sin_family = AF_INET; 85 | + server_addr.sin_port = htons(pg_hint_plan_kepler_port); 86 | + inet_pton(AF_INET, pg_hint_plan_kepler_host, &server_addr.sin_addr); 87 | + ret = connect(kepler_fd, (struct sockaddr*)&server_addr, sizeof(server_addr)); 88 | + if (ret == -1) 89 | + { 90 | + terminate_kepler_connection(); 91 | + return false; 92 | + } 93 | + 94 | + return true; 95 | +} 96 | + 97 | +/* 98 | + * Check for VariableSetStmt and VariableShowStmt. 99 | + */ 100 | +static bool 101 | +is_variable_statement(const Query *query) 102 | +{ 103 | + if (query->commandType != CMD_UTILITY) 104 | + return false; 105 | + 106 | + return (IsA((Query *)query->utilityStmt, VariableSetStmt) || 107 | + IsA((Query *)query->utilityStmt, VariableShowStmt)); 108 | +} 109 | + 110 | +/* 111 | + * Get hints from Kepler server. 112 | + */ 113 | +static const char * 114 | +get_hints_from_kepler(const char *query_str) 115 | +{ 116 | + int ret = -1; 117 | + int bytes_remaining = 0; 118 | + const char* buffer = NULL; 119 | + char* hint_str = NULL; 120 | + elog(INFO, ("get hints from Kepler")); 121 | + /* Connect to kepler socket if not already connected. */ 122 | + if (!create_kepler_connection()) 123 | + { 124 | + elog(ERROR, "Unable to create connection to Kepler server"); 125 | + return NULL; 126 | + } 127 | + 128 | + bytes_remaining = strlen(query_str); 129 | + buffer = query_str; 130 | + 131 | + elog(pg_hint_plan_debug_message_level, "pg_hint_plan: Requesting hint for: %s", query_str); 132 | + 133 | + /* First send the number of bytes in the query. */ 134 | + ret = write(kepler_fd, &bytes_remaining, sizeof(bytes_remaining)); 135 | + if (ret < 0) 136 | + { 137 | + elog(ERROR, "Unable to send request to Kepler server"); 138 | + terminate_kepler_connection(); 139 | + return NULL; 140 | + } 141 | + while (bytes_remaining) 142 | + { 143 | + ret = write(kepler_fd, buffer, bytes_remaining); 144 | + if (ret < 0) 145 | + { 146 | + elog(ERROR, "Unable to send request to Kepler server"); 147 | + terminate_kepler_connection(); 148 | + return NULL; 149 | + } 150 | + bytes_remaining -= ret; 151 | + buffer += ret; 152 | + } 153 | + 154 | + // TODO: Support dynamic length of hint strings. 
155 | + hint_str = palloc(4096); 156 | + memset(hint_str, 0, 4096); 157 | + ret = read(kepler_fd, hint_str, 4096); 158 | + if (ret <= 0) 159 | + { 160 | + elog(ERROR, "Unable to receive hint from Kepler server"); 161 | + pfree(hint_str); 162 | + terminate_kepler_connection(); 163 | + return NULL; 164 | + } 165 | + 166 | + if (ret == 1) 167 | + { 168 | + elog(pg_hint_plan_debug_message_level, "pg_hint_plan: Kepler server did not give hint"); 169 | + pfree(hint_str); 170 | + return NULL; 171 | + } 172 | + 173 | + return hint_str; 174 | +} 175 | + 176 | /* 177 | * Parse hints that got, create hint struct from parse tree and parse hints. 178 | */ 179 | @@ -2879,6 +3024,43 @@ get_current_hint_string(ParseState *pstate, Query *query) 180 | snprintf(qnostr, sizeof(qnostr), "[qno=0x%x]", qno++); 181 | qno++; 182 | 183 | + /* TODO: The get_query_string() function will get called twice in the 184 | + case of pg_hint_plan_enable_hint_table and Kepler doesn't provide a 185 | + hint. */ 186 | + query_str = get_query_string(pstate, query, NULL); 187 | + 188 | + if (pg_hint_plan_kepler_host && !is_variable_statement(query)) 189 | + { 190 | + /* get hints from Kepler. */ 191 | + oldcontext = MemoryContextSwitchTo(TopMemoryContext); 192 | + current_hint_str = get_hints_from_kepler(query_str); 193 | + MemoryContextSwitchTo(oldcontext); 194 | + 195 | + if (debug_level > 1) 196 | + { 197 | + if (debug_level == 2 && query_str && debug_query_string && 198 | + strcmp(query_str, debug_query_string)) 199 | + ereport(pg_hint_plan_debug_message_level, 200 | + (errmsg("hints in Kepler=\"%s\"", 201 | + current_hint_str ? current_hint_str : "(none)"), 202 | + errhidestmt(msgqno != qno), 203 | + errhidecontext(msgqno != qno))); 204 | + else 205 | + ereport(pg_hint_plan_debug_message_level, 206 | + (errmsg("hints in Kepler=\"%s\", query=\"%s\", debug_query_string=\"%s\"", 207 | + current_hint_str ? current_hint_str : "(none)", 208 | + query_str ? query_str : "(none)", 209 | + debug_query_string ? debug_query_string : "(none)"), 210 | + errhidestmt(msgqno != qno), 211 | + errhidecontext(msgqno != qno))); 212 | + msgqno = qno; 213 | + } 214 | + 215 | + /* Return if we have hint string here */ 216 | + if (current_hint_str) 217 | + return; 218 | + } 219 | + 220 | /* search the hint table for a hint if requested */ 221 | if (pg_hint_plan_enable_hint_table) 222 | { 223 | -------------------------------------------------------------------------------- /kepler/database_integrations/postgres/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /kepler/database_integrations/postgres/pg_hint_plan_extension_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for pg_hint_plan extension.""" 17 | 18 | import multiprocessing 19 | import socket 20 | import time 21 | 22 | import psycopg2 23 | 24 | from kepler.training_data_collection_pipeline import pg_plan_hint_extractor 25 | from kepler.training_data_collection_pipeline import query_utils 26 | from kepler.training_data_collection_pipeline import test_util 27 | from absl.testing import absltest 28 | from absl.testing import parameterized 29 | 30 | _LOCAL_HOST = "127.0.0.1" 31 | _KEPLER_PORT = 30709 32 | 33 | _SET_PORT = "SET pg_hint_plan.kepler_port TO {port};" 34 | _SET_HOST = "SET pg_hint_plan.kepler_host TO '127.0.0.1';" 35 | _RESET_HOST = "RESET pg_hint_plan.kepler_host;" 36 | _QUERY = "SELECT COUNT(*) FROM foo JOIN bar ON x = a JOIN baz ON k = a;" 37 | _DEFAULT_JOIN_ORDER = "Leading(((foo baz) bar))" 38 | _FORCED_JOIN_ORDER = "Leading(((bar foo) baz))" 39 | 40 | _TRIGGER_KEPLER = "hint" 41 | _TRIGGER_READ_ERROR = "read error" 42 | 43 | 44 | def _run_server(read_bytes_count: int) -> None: 45 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 46 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 47 | s.bind((_LOCAL_HOST, _KEPLER_PORT)) 48 | s.listen() 49 | 50 | b_trigger_kepler = _TRIGGER_KEPLER.encode("ascii") 51 | b_trigger_read_error = _TRIGGER_READ_ERROR.encode("ascii") 52 | while True: 53 | conn, _ = s.accept() 54 | with conn: 55 | while True: 56 | data = conn.recv(4) 57 | data_length = int.from_bytes(data, "little") 58 | data = b"" 59 | while data_length > len(data): 60 | data += conn.recv(read_bytes_count) 61 | 62 | if not data: 63 | break 64 | 65 | if b_trigger_kepler in data: 66 | response = _FORCED_JOIN_ORDER 67 | elif b_trigger_read_error in data: 68 | return 69 | else: 70 | response = " " 71 | 72 | conn.sendall(response.encode()) 73 | 74 | 75 | def _trigger_kepler(query: str) -> str: 76 | return f"/*+ {_TRIGGER_KEPLER} */ {query}" 77 | 78 | 79 | def _trigger_read_error(query: str) -> str: 80 | return f"/*+ {_TRIGGER_READ_ERROR} */ {query}" 81 | 82 | 83 | class PgHintPlanExtensionTest(parameterized.TestCase): 84 | 85 | def setUp(self): 86 | super().setUp() 87 | self._test_database = test_util.TestDatabase() 88 | self._query_manager = query_utils.QueryManager( 89 | database_configuration=query_utils.DatabaseConfiguration( 90 | dbname=self._test_database.dbname, 91 | user=test_util.USER, 92 | password=test_util.PASSWORD)) 93 | test_util.populate_database(self._query_manager) 94 | # Some tests read from pg_stats tables, which requires running ANALYZE. 95 | # Warning: Removing/adding this affects cardinality estimates. 
96 | self._query_manager.run_analyze() 97 | 98 | self._query_manager.execute_and_commit(_SET_PORT.format(port=_KEPLER_PORT)) 99 | self._server_process = None 100 | 101 | def tearDown(self): 102 | if self._server_process: 103 | self._server_process.terminate() 104 | while self._server_process.is_alive(): 105 | time.sleep(0.01) 106 | 107 | self._test_database.drop() 108 | super().tearDown() 109 | 110 | def start_server(self, read_bytes_count: int) -> None: 111 | self._server_process = multiprocessing.Process( 112 | target=_run_server, args=(read_bytes_count,)) 113 | self._server_process.start() 114 | # Block until the server is ready to accept connections. 115 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 116 | while True: 117 | try: 118 | s.connect((_LOCAL_HOST, _KEPLER_PORT)) 119 | return 120 | except ConnectionRefusedError: 121 | time.sleep(.01) 122 | 123 | @parameterized.named_parameters( 124 | dict( 125 | testcase_name="read_bytes_count larger than query", 126 | read_bytes_count=1024), 127 | dict( 128 | testcase_name="read_bytes_count smaller than query", 129 | read_bytes_count=5), 130 | ) 131 | def test_pg_hint_plan_extension(self, read_bytes_count: int): 132 | """Tests receiving hints while toggling the extension on and off.""" 133 | self.start_server(read_bytes_count=read_bytes_count) 134 | self._query_manager.execute_and_commit(_SET_HOST) 135 | 136 | kepler_query = _trigger_kepler(_QUERY) 137 | # Repeat requests to ensure there's no lag in hinting due to state 138 | # mismanagement. 139 | for _ in range(3): 140 | for _ in range(3): 141 | hints = pg_plan_hint_extractor.get_single_query_hints_with_plan( 142 | query_manager=self._query_manager, query=_QUERY, params=None)[0] 143 | self.assertIn(_DEFAULT_JOIN_ORDER, hints) 144 | 145 | for _ in range(3): 146 | hints = pg_plan_hint_extractor.get_single_query_hints_with_plan( 147 | query_manager=self._query_manager, query=kepler_query, 148 | params=None)[0] 149 | self.assertIn(_FORCED_JOIN_ORDER, hints) 150 | 151 | self._query_manager.execute_and_commit(_RESET_HOST) 152 | for _ in range(3): 153 | hints = pg_plan_hint_extractor.get_single_query_hints_with_plan( 154 | query_manager=self._query_manager, query=kepler_query, 155 | params=None)[0] 156 | self.assertIn(_DEFAULT_JOIN_ORDER, hints) 157 | self._query_manager.execute_and_commit(_SET_HOST) 158 | 159 | def test_create_connection_error(self): 160 | with self.assertRaisesRegex(psycopg2.InternalError, 161 | "Unable to create connection to Kepler server"): 162 | # The COMMIT triggers the Kepler connection creation attempt. 163 | self._query_manager.execute_and_commit(_SET_HOST) 164 | 165 | # TODO(lyric): Test that we can get out of this issue by RESET-ing 166 | # pg_hint_plan.kepler_host. This works at the SQL command prompt, but not in 167 | # the tests at present due to errors around failed transaction. Debugging 168 | # required. 169 | 170 | def test_unable_to_receive_hints(self): 171 | self.start_server(read_bytes_count=1024) 172 | self._query_manager.execute_and_commit(_SET_HOST) 173 | 174 | with self.assertRaisesRegex(psycopg2.InternalError, 175 | "Unable to receive hint from Kepler server"): 176 | self._query_manager.execute(_trigger_read_error(_QUERY)) 177 | 178 | 179 | if __name__ == "__main__": 180 | absltest.main() 181 | -------------------------------------------------------------------------------- /kepler/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 
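For reference, the wire protocol exercised by the extension test above is small: the extension sends a 4-byte little-endian length prefix followed by the query text, and expects a hint string in reply, where a single-character reply (such as a space) means "no hint". Below is a minimal sketch of a compatible server; `compute_hint` is a hypothetical stand-in for a real model lookup (a production setup would presumably use the model server under `kepler/database_integrations/model_serving`).

    # Minimal sketch of a Kepler-compatible hint server. The framing matches
    # the test server above: a 4-byte little-endian length prefix, then the
    # query text; the reply is a hint string, or a single space for "no hint".
    import socket

    def compute_hint(query: str) -> str:
      # Hypothetical stub: a real server would consult a trained Kepler model.
      return ""

    def serve_hints(host: str = "127.0.0.1", port: int = 30709) -> None:
      with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
        s.listen()
        while True:
          conn, _ = s.accept()
          with conn:
            while True:
              header = conn.recv(4)
              if not header:
                break
              length = int.from_bytes(header, "little")
              query = b""
              while len(query) < length:
                chunk = conn.recv(length - len(query))
                if not chunk:
                  break
                query += chunk
              hint = compute_hint(query.decode())
              conn.sendall((hint or " ").encode())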
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /kepler/evaluation/e2e_evaluation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Evaluates Kepler models integrated into a database system.""" 17 | 18 | 19 | from typing import Any, Callable, List, Tuple, Union 20 | 21 | 22 | from kepler.data_management import workload 23 | 24 | # TODO(b/199162711): Transition this script and downstream analysis scripts to a 25 | # structured format instead of using _NAME_DELIMITER. 26 | _NAME_DELIMITER = "####" 27 | 28 | 29 | def _get_params_as_string(params: List[Any]) -> str: 30 | return _NAME_DELIMITER.join([str(p) for p in params]) 31 | 32 | 33 | def evaluate_workload( 34 | workload_eval: workload.Workload, 35 | template: Any, 36 | iterations: int, 37 | batch_size: int, 38 | limit: Union[int, None], 39 | execute_query_fn: Callable[[str, List[Any]], Tuple[float, int]], 40 | checkpoint_results_fn: Callable[[str, Any], None], 41 | ) -> None: 42 | """Executes the SQL queries defined by workload_eval. 43 | 44 | For each set of parameter values, we execute a query using 45 | execute_query_fn. The query execution latencies are stored in dicts and then 46 | passed to checkpoint_results_fn. 47 | 48 | Args: 49 | workload_eval: The workload defining which parameter bindings to evaluate. 50 | template: The templatized SQL query text. 51 | iterations: The number of times to execute the requested query. 52 | batch_size: The number of parameter_values to fully evaluate before calling 53 | checkpoint_results_fn. 54 | limit: The number of parameter_values to gather execution data for. 55 | execute_query_fn: The function to call to execute a query. The function is 56 | expected to accept the query to execute and the list of parameters. It is 57 | expected to return a tuple of the query execution latency in ms and the 58 | rows produced count. 59 | checkpoint_results_fn: A callback function called at a period defined by 60 | batch_size as well as upon completion of all executions. The function is 61 | expected to accept the query_id and a JSON object of results.
62 | """ 63 | query_id = workload_eval.query_id 64 | kepler_enabled_query_template = f"/*+ {query_id} */ {template['query']}" 65 | 66 | results = {} 67 | for query_instance in workload_eval.query_log[:limit]: 68 | params_as_string = _get_params_as_string(query_instance.parameters) 69 | 70 | execution_results = [] 71 | rows_produced = None 72 | for _ in range(iterations): 73 | latency_ms, rows = execute_query_fn( 74 | kepler_enabled_query_template, query_instance.parameters 75 | ) 76 | execution_results.append({"duration_ms": latency_ms}) 77 | 78 | if rows_produced is None: 79 | rows_produced = rows 80 | else: 81 | assert rows == rows_produced, ( 82 | f"All executions for parameter {params_as_string} should return the" 83 | f" same number of rows. Got {rows} and {rows_produced} instead." 84 | ) 85 | 86 | results[params_as_string] = { 87 | "default": 0, 88 | "results": [execution_results], 89 | "rows": rows_produced, 90 | } 91 | 92 | if len(results) % batch_size == 0: 93 | checkpoint_results_fn(query_id, {query_id: results}) 94 | 95 | checkpoint_results_fn(query_id, {query_id: results}) 96 | -------------------------------------------------------------------------------- /kepler/evaluation/e2e_evaluation_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Orchestrates evaluating Kepler models integrated into a database system. 17 | 18 | Results in JSON format are saved to the directory indicated by the output flag. 19 | """ 20 | import json 21 | import os 22 | from typing import Any, List, Optional, Tuple 23 | 24 | from absl import app 25 | from absl import flags 26 | 27 | from kepler.data_management import workload 28 | from kepler.evaluation import e2e_evaluation 29 | from kepler.training_data_collection_pipeline import query_utils 30 | 31 | _SET_MODEL_SERVER_PORT = "SET pg_hint_plan.kepler_port TO {port};" 32 | _SET_MODEL_SERVER_HOST = "SET pg_hint_plan.kepler_host TO '127.0.0.1';" 33 | 34 | _DATABASE = flags.DEFINE_string("database", None, "Database name.") 35 | flags.mark_flag_as_required("database") 36 | _USER = flags.DEFINE_string("user", None, "Database username.") 37 | _PASSWORD = flags.DEFINE_string("password", None, "Database password.") 38 | 39 | _MODEL_SERVER_PORT = flags.DEFINE_integer( 40 | "model_server_port", 30709, "The port for the model server to use." 41 | ) 42 | 43 | _QUERY_TEMPLATES_FILE = flags.DEFINE_string( 44 | "query_templates_file", 45 | None, 46 | "Path to file in which query templates are stored.", 47 | ) 48 | flags.mark_flag_as_required("query_templates_file") 49 | 50 | _EXECUTION_DATA_FILE = flags.DEFINE_string( 51 | "execution_data_file", None, "Execution data file." 
52 | ) 53 | flags.mark_flag_as_required("execution_data_file") 54 | 55 | _ITERATIONS = flags.DEFINE_integer( 56 | "iterations", 57 | 3, 58 | ( 59 | "The number of iterations to execute query (query plan, parameter" 60 | " binding) pairing." 61 | ), 62 | ) 63 | _BATCH_SIZE = flags.DEFINE_integer( 64 | "batch_size", 65 | 50, 66 | ( 67 | "Batch of parameters for which to collect training data before" 68 | " checkpointing." 69 | ), 70 | ) 71 | _LIMIT = flags.DEFINE_integer( 72 | "limit", 73 | None, 74 | "Limit the number of parameters per query to gather execution data for.", 75 | ) 76 | 77 | _SEED = flags.DEFINE_integer( 78 | "seed", 0, "The seed to use when shuffling the workload before splitting." 79 | ) 80 | 81 | _TRAIN_SPLIT = flags.DEFINE_float( 82 | "train_split", 83 | 0.8, 84 | "The fraction of workload query instances to place in the training split.", 85 | ) 86 | 87 | _QUERY = flags.DEFINE_string("query", None, "Specific query id to execute.") 88 | flags.mark_flag_as_required("query") 89 | 90 | _OUTPUT_DIR = flags.DEFINE_string( 91 | "output_dir", None, "Directory to store execution results." 92 | ) 93 | flags.mark_flag_as_required("output_dir") 94 | 95 | 96 | def main(unused_argv): 97 | query_manager = query_utils.QueryManager( 98 | query_utils.DatabaseConfiguration( 99 | dbname=_DATABASE.value, user=_USER.value, password=_PASSWORD.value 100 | ) 101 | ) 102 | query_utils.save_postgres_config_info(query_manager, _OUTPUT_DIR.value) 103 | query_manager.execute_and_commit( 104 | _SET_MODEL_SERVER_PORT.format(port=_MODEL_SERVER_PORT.value) 105 | ) 106 | query_manager.execute_and_commit(_SET_MODEL_SERVER_HOST) 107 | 108 | with open(_EXECUTION_DATA_FILE.value) as f: 109 | execution_data = json.load(f) 110 | workload_generator = workload.WorkloadGenerator(execution_data) 111 | full_workload = workload_generator.all() 112 | workload.shuffle(full_workload, _SEED.value) 113 | _, workload_eval = workload.split( 114 | full_workload, first_half_fraction=_TRAIN_SPLIT.value 115 | ) 116 | 117 | with open(_QUERY_TEMPLATES_FILE.value) as f: 118 | templates = json.load(f) 119 | 120 | query_id = _QUERY.value 121 | 122 | output_subdir = os.path.join(_OUTPUT_DIR.value, "execution_output") 123 | os.makedirs(output_subdir, exist_ok=True) 124 | 125 | def checkpoint_results(query_id: str, results: Any) -> None: 126 | with open( 127 | os.path.join( 128 | output_subdir, f"{_DATABASE.value}_{query_id}_e2e_evaluation.json" 129 | ), 130 | "w", 131 | ) as f: 132 | json.dump(results, f) 133 | 134 | def execute_query( 135 | query: str, params: List[Any] 136 | ) -> Tuple[Optional[float], Optional[int]]: 137 | return query_manager.execute_timed(query, params) 138 | 139 | e2e_evaluation.evaluate_workload( 140 | workload_eval=workload_eval, 141 | template=templates[query_id], 142 | iterations=_ITERATIONS.value, 143 | batch_size=_BATCH_SIZE.value, 144 | limit=_LIMIT.value, 145 | execute_query_fn=execute_query, 146 | checkpoint_results_fn=checkpoint_results, 147 | ) 148 | 149 | 150 | if __name__ == "__main__": 151 | app.run(main) 152 | -------------------------------------------------------------------------------- /kepler/evaluation/e2e_evaluation_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for pg_execute_training_data_queries.py.""" 17 | 18 | import copy 19 | import dataclasses 20 | import itertools 21 | import math 22 | from typing import Any, List, Tuple 23 | 24 | 25 | from kepler.data_management import test_util 26 | from kepler.data_management import workload 27 | from kepler.evaluation import e2e_evaluation 28 | 29 | from absl.testing import absltest 30 | from absl.testing import parameterized 31 | 32 | 33 | JSON = Any 34 | 35 | 36 | @dataclasses.dataclass 37 | class QueryExecution: 38 | query: str 39 | params: List[Any] 40 | 41 | 42 | class E2EEvaluationTest(parameterized.TestCase): 43 | 44 | def setUp(self): 45 | super().setUp() 46 | self._workload_eval = workload.WorkloadGenerator( 47 | test_util.QUERY_EXECUTION_DATA, test_util.TEST_SEED 48 | ).all() 49 | self._params = [ 50 | query_instance.parameters 51 | for query_instance in self._workload_eval.query_log 52 | ] 53 | self._keys = [ 54 | "####".join([str(element) for element in param]) 55 | for param in self._params 56 | ] 57 | self._template = test_util.QUERY_METADATA[test_util.TEST_QUERY_ID] 58 | self._results = [] 59 | 60 | def checkpoint_results( 61 | query_id: str, 62 | results: Any, 63 | ) -> None: 64 | self.assertEqual(test_util.TEST_QUERY_ID, query_id) 65 | 66 | self._results.append(copy.deepcopy(results)) 67 | 68 | self._checkpoint_results = checkpoint_results 69 | 70 | self._query_executions = [] 71 | 72 | self._latency = 1 73 | 74 | def execute_query(query: str, params: List[Any]) -> Tuple[float, int]: 75 | """Mocks execution by logging the requested query and returning a latency. 76 | 77 | The query latency changes every call. 78 | 79 | Args: 80 | query: SQL query template string with 0 or more parameters provided in 81 | the form of @param#, starting with 0. 82 | params: List of parameter values to substitute into the query. All 83 | values will be cast to str. 84 | 85 | Returns: 86 | A tuple containing: 87 | 1. The execution time of the query in ms. 88 | 2. The number of rows produced by the query. 89 | """ 90 | self._query_executions.append(QueryExecution(query=query, params=params)) 91 | self._latency += 1 92 | 93 | # The 3rd element of a params is an int. Return it as the rows produced. 
94 | return self._latency, params[2] 95 | 96 | self._execute_query = execute_query 97 | 98 | def _verify_query_executions( 99 | self, iterations: int, params: List[List[str]] 100 | ) -> None: 101 | for query_execution, (current_params, _) in zip( 102 | self._query_executions, itertools.product(params, range(iterations)) 103 | ): 104 | self.assertIn(f"/*+ {test_util.TEST_QUERY_ID} */", query_execution.query) 105 | self.assertEqual(query_execution.params, current_params) 106 | 107 | def _verify_results( 108 | self, 109 | iterations: int, 110 | expected_checkpoint_count: int, 111 | keys_per_checkpoint: List[List[str]], 112 | rows_per_checkpoint: List[List[int]], 113 | ) -> None: 114 | self.assertLen(self._results, expected_checkpoint_count) 115 | self.assertLen(keys_per_checkpoint, expected_checkpoint_count) 116 | self.assertLen(rows_per_checkpoint, expected_checkpoint_count) 117 | 118 | for result, current_keys, current_rows in zip( 119 | self._results, keys_per_checkpoint, rows_per_checkpoint 120 | ): 121 | result_map = result[test_util.TEST_QUERY_ID] 122 | self.assertEqual(len(result_map), len(current_keys)) 123 | 124 | for key, rows in zip(current_keys, current_rows): 125 | self.assertEqual(result_map[key]["default"], 0) 126 | self.assertEqual(result_map[key]["rows"], rows) 127 | self.assertLen(result_map[key]["results"], 1) 128 | stats = result_map[key]["results"][0] 129 | self.assertLen(stats, iterations) 130 | for i in range(len(stats) - 1): 131 | self.assertGreater( 132 | stats[i + 1]["duration_ms"], stats[i]["duration_ms"] 133 | ) 134 | 135 | @parameterized.named_parameters( 136 | dict(testcase_name="basic", iterations=1, limit=None), 137 | dict(testcase_name="iterations_3", iterations=3, limit=None), 138 | dict(testcase_name="limit_2", iterations=1, limit=2), 139 | ) 140 | def test_iterations_and_limit(self, iterations, limit): 141 | """Verifies execution combinations abide by iterations and limit. 142 | 143 | The test checks all the queries that would have been executed as well as the 144 | shape of the results. 145 | 146 | Args: 147 | iterations: The number of times to execute the requested query. 148 | limit: The number of parameter_values to gather execution data for. 149 | """ 150 | e2e_evaluation.evaluate_workload( 151 | workload_eval=self._workload_eval, 152 | template=self._template, 153 | iterations=iterations, 154 | batch_size=100, 155 | limit=limit, 156 | execute_query_fn=self._execute_query, 157 | checkpoint_results_fn=self._checkpoint_results, 158 | ) 159 | 160 | self._verify_query_executions( 161 | iterations=iterations, params=self._params[:limit] 162 | ) 163 | 164 | keys = self._keys[:limit] 165 | rows = [params[2] for params in self._params] 166 | self._verify_results( 167 | iterations=iterations, 168 | expected_checkpoint_count=1, 169 | keys_per_checkpoint=[keys], 170 | rows_per_checkpoint=[rows], 171 | ) 172 | 173 | @parameterized.named_parameters( 174 | dict(testcase_name="tiny_batch", batch_size=1), 175 | dict(testcase_name="small_batch", batch_size=2), 176 | dict(testcase_name="input_size_batch", batch_size=3), 177 | dict(testcase_name="large_batch", batch_size=1000), 178 | ) 179 | def test_batch_size(self, batch_size): 180 | """Verifies results are checkpointed in accordance with batch size. 181 | 182 | The test checks all the queries that would have been executed as well as the 183 | shape of the results. 184 | 185 | Args: 186 | batch_size: The number of parameter_values to fully evaluate before 187 | calling checkpoint_results_fn. 
188 | """ 189 | e2e_evaluation.evaluate_workload( 190 | workload_eval=self._workload_eval, 191 | template=self._template, 192 | iterations=1, 193 | batch_size=batch_size, 194 | limit=None, 195 | execute_query_fn=self._execute_query, 196 | checkpoint_results_fn=self._checkpoint_results, 197 | ) 198 | 199 | self._verify_query_executions(iterations=1, params=self._params) 200 | 201 | expected_checkpoint_count = math.ceil(len(self._params) / batch_size) 202 | 203 | # Handle an edge condition where the last execution iteration will 204 | # checkpoint immediately before the post-execution checkpoint. 205 | if len(self._params) % batch_size == 0: 206 | expected_checkpoint_count += 1 207 | 208 | keys_per_checkpoint = [ 209 | self._keys[: (batch_end + 1) * batch_size] 210 | for batch_end in range(expected_checkpoint_count) 211 | ] 212 | rows = [params[2] for params in self._params] 213 | rows_per_checkpoint = [ 214 | rows[: (rows_end + 1) * batch_size] 215 | for rows_end in range(expected_checkpoint_count) 216 | ] 217 | 218 | self._verify_results( 219 | iterations=1, 220 | expected_checkpoint_count=expected_checkpoint_count, 221 | keys_per_checkpoint=keys_per_checkpoint, 222 | rows_per_checkpoint=rows_per_checkpoint, 223 | ) 224 | 225 | 226 | if __name__ == "__main__": 227 | absltest.main() 228 | -------------------------------------------------------------------------------- /kepler/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /kepler/model_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /kepler/model_trainer/loss_functions.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Module for Kepler loss functions.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def mse_loss(y_true: tf.Tensor, 22 | y_pred: tf.Tensor) -> tf.Tensor: 23 | """Computes mean squared error. 24 | 25 | Attempts to fit to the negative of the target to be compatible with 26 | argmax in MultiheadModel. The true target values are assumed to 27 | represent costs, so that the model learns to predict the negative cost. 28 | 29 | Args: 30 | y_true: True target costs (e.g. latencies). 31 | y_pred: Model predictions. 32 | 33 | Returns: 34 | Scalar tensor corresponding to the loss. 35 | """ 36 | return tf.math.reduce_mean(tf.square(y_true + y_pred)) 37 | 38 | 39 | def log_mse_loss(y_true: tf.Tensor, 40 | y_pred: tf.Tensor) -> tf.Tensor: 41 | """Computes mean squared error against log of true values. 42 | 43 | Like mse_loss, attempts to fit the negative cost. This loss uses the 44 | log of the cost for better stability. 45 | 46 | Args: 47 | y_true: True target costs (e.g. latencies). 48 | y_pred: Model predictions. 49 | 50 | Returns: 51 | Scalar tensor corresponding to the loss. 52 | """ 53 | return tf.math.reduce_mean(tf.square(tf.math.log(y_true) + y_pred)) 54 | -------------------------------------------------------------------------------- /kepler/model_trainer/loss_functions_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
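Both losses above fit the model outputs to the negative of the (log) cost, so taking the argmax over a trained model's output heads selects the plan with the lowest predicted latency. A small illustration with made-up numbers (not taken from SPQD):

    import numpy as np

    # Outputs are regressed toward the negative latency of each candidate
    # plan, so the largest head corresponds to the cheapest plan.
    predicted = np.array([-3.2, -1.4, -2.7])
    best_head = int(np.argmax(predicted))  # 1: predicted latency ~1.4 is lowest

    # Sanity check matching the "same_nonzero" test case below:
    # mse_loss([1, 2, 2], [1, 2, 2]) = mean([(1+1)^2, (2+2)^2, (2+2)^2]) = 12.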
15 | 16 | """Tests for loss_functions.""" 17 | 18 | import tensorflow as tf 19 | 20 | from kepler.model_trainer import loss_functions 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | 24 | 25 | class LossFunctionsTest(parameterized.TestCase): 26 | 27 | @parameterized.named_parameters( 28 | ("expected_zero", tf.constant([1., 2., 2.]), tf.constant([-1., -2., -2.]), 29 | tf.constant([0.])), 30 | ("batch_expected_zero", 31 | tf.constant([[1., 2., 2.], [2., 4., 4.], [3., 5., 5.]]), 32 | tf.constant([[-1., -2., -2.], [-2., -4., -4.], [-3., -5., -5.]]), 33 | tf.constant([0.])), 34 | ("same_nonzero", tf.constant([1., 2., 2.]), tf.constant([1., 2., 2.]), 35 | tf.constant([12.])) 36 | ) 37 | def test_mse_loss(self, y_true, y_pred, expected_loss): 38 | loss = loss_functions.mse_loss(y_true, y_pred) 39 | self.assertTrue(tf.math.equal(loss, expected_loss)) 40 | 41 | @parameterized.named_parameters( 42 | ("expected_zero", tf.constant([1., 1., 1.]), 43 | tf.constant([0., 0., 0.]), 44 | tf.constant([0.])), 45 | ("batch_expected_zero", 46 | tf.math.exp(tf.constant([[1., 2., 2.], [2., 4., 4.], [3., 5., 5.]])), 47 | tf.constant([[-1., -2., -2.], [-2., -4., -4.], [-3., -5., -5.]]), 48 | tf.constant([0.])), 49 | ("same_nonzero", 50 | tf.math.exp(tf.constant([1., 2., 2.])), tf.constant([1., 2., 2.]), 51 | tf.constant([12.])) 52 | ) 53 | def test_log_mse_loss(self, y_true, y_pred, expected_loss): 54 | loss = loss_functions.log_mse_loss(y_true, y_pred) 55 | self.assertTrue(tf.math.equal(loss, expected_loss)) 56 | 57 | 58 | if __name__ == "__main__": 59 | absltest.main() 60 | -------------------------------------------------------------------------------- /kepler/model_trainer/model_base.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Base functionality for Kepler models. 17 | 18 | Contains: 19 | - ModelConfig class for specifying model architectures and hyperparams. 20 | - ModelBase class as an abstract base class for training any Kepler model. 21 | - ModelPredictorBase class as an abstract base class for providing 22 | predictions using a trained model. 23 | """ 24 | import abc 25 | import dataclasses 26 | from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple 27 | 28 | import numpy as np 29 | import tensorflow as tf 30 | 31 | from kepler.model_trainer import trainer_util 32 | 33 | JSON = Any 34 | 35 | 36 | def prepare_input_param( 37 | param: Any, predicate: Any, tflite_mode: bool 38 | ) -> np.ndarray: 39 | """Casts and preprocesses input parameter(s). 40 | 41 | Args: 42 | param: List, array, or single parameter value. 43 | predicate: Predicate metadata. 44 | tflite_mode: Whether we are doing tflite inference or not. 45 | 46 | Returns: 47 | Preprocessed parameter(s). 
48 | """ 49 | np_type = trainer_util.get_np_type(predicate["data_type"]) 50 | transform = trainer_util.get_predicate_transform(predicate) 51 | if type(param) in [np.ndarray, list]: 52 | param = np.array(list(map(transform, param))) 53 | else: 54 | param = transform(param) 55 | casted_param = np.array(np.atleast_1d(param)).astype(np_type) 56 | if tflite_mode: 57 | casted_param = np.atleast_2d(casted_param) 58 | return casted_param 59 | 60 | 61 | @dataclasses.dataclass(frozen=True) 62 | class ModelConfig(): 63 | """Stores model architecture/hyperparameters. 64 | 65 | TODO(b/217974851): Options for optimizer, weight/batch norm, etc. 66 | """ 67 | layer_sizes: List[int] # All intermediate hidden layer sizes. 68 | dropout_rates: List[float] # Dropout rate for each layer. 69 | learning_rate: float 70 | activation: str 71 | loss: Any # Loss can be custom function or string. 72 | metrics: List[Any] 73 | spectral_norm_multiplier: float = 1.0 74 | num_gp_random_features: int = 128 75 | 76 | 77 | def _apply_preprocessing_layer( 78 | name: str, 79 | param: Mapping[str, Any], 80 | layer: tf.keras.Input, # pytype: disable=invalid-annotation # typed-keras 81 | preprocessing_info: Mapping[str, Any] 82 | ) -> Optional[tf.keras.layers.Layer]: # pytype: disable=invalid-annotation # typed-keras 83 | """Generates the preprocessing layer for a given input parameter. 84 | 85 | We currently support the following types of preprocessing: 86 | - Normalization to mean 0, variance 1. 87 | - Embedding of integer types. This requires specifying min/max of 88 | the integer values, since we can only embed non-negative values. 89 | - String embeddings. 90 | 91 | Args: 92 | name: Suffix for names of produced layers. 93 | param: Contains info about the parameter; a single entry in 94 | self.predicate_metadata. 95 | layer: Keras input layer for this parameter. 96 | preprocessing_info: Contains info about how to preprocess this parameter; a 97 | single entry in self.preprocessing_config. 98 | 99 | Returns: 100 | The preprocessing layer for the parameter. 101 | 102 | Raises: 103 | ValueError: If the parameter type or preprocessing type is invalid. 104 | """ 105 | data_type = param["data_type"] 106 | preprocessing_type = preprocessing_info["type"] 107 | if data_type == "float" and preprocessing_type == "std_normalization": 108 | norm = tf.keras.layers.experimental.preprocessing.Normalization( 109 | mean=preprocessing_info["mean"], 110 | variance=preprocessing_info["variance"], 111 | name=f"normalization_{name}") 112 | return norm(layer) 113 | if data_type == "int": 114 | # Shift layer to be non-negative. 
115 | shifted_layer = layer - tf.constant(param["min"], dtype=tf.int64) 116 | if preprocessing_type == "embedding": 117 | embedding = tf.keras.layers.Embedding( 118 | input_dim=param["max"] - param["min"] + 1, 119 | output_dim=preprocessing_info["output_dim"], 120 | name=f"embedding_{name}")( 121 | shifted_layer) 122 | return tf.keras.layers.Flatten()(embedding) 123 | elif preprocessing_type == "one_hot": 124 | onehot_layer = tf.keras.layers.CategoryEncoding( 125 | num_tokens=param["max"] - param["min"] + 1, 126 | output_mode="one_hot", 127 | name=f"one_hot_{name}")( 128 | shifted_layer) 129 | return tf.keras.layers.Flatten()(onehot_layer) 130 | if data_type == "text": 131 | vocabulary = param["distinct_values"] 132 | num_oov_indices = preprocessing_info.get("num_oov_indices", 0) 133 | lookup_layer = tf.keras.layers.StringLookup( 134 | num_oov_indices=num_oov_indices, 135 | vocabulary=vocabulary, 136 | name=f"lookup_{name}")( 137 | layer) 138 | if preprocessing_type == "embedding": 139 | embedding_layer = tf.keras.layers.Embedding( 140 | input_dim=len(vocabulary) + num_oov_indices, 141 | output_dim=preprocessing_info["output_dim"], 142 | name=f"embedding_{name}")( 143 | lookup_layer) 144 | return tf.keras.layers.Flatten()(embedding_layer) 145 | elif preprocessing_type == "one_hot": 146 | onehot_layer = tf.keras.layers.CategoryEncoding( 147 | num_tokens=len(vocabulary) + num_oov_indices, 148 | output_mode="one_hot", 149 | name=f"one_hot_{name}")( 150 | lookup_layer) 151 | return tf.keras.layers.Flatten()(onehot_layer) 152 | 153 | raise ValueError(f"Unsupported preprocessing: parameter type: {data_type}" 154 | f" preprocessing type: {preprocessing_type}") 155 | 156 | 157 | class ModelBase(metaclass=abc.ABCMeta): 158 | """Abstract base for classes that implement a Kepler model. 159 | 160 | A Kepler model is distinguished by the parameters of its corresponding 161 | query template. This base class contains logic to handle input parameters 162 | of varying types (str, int, float), as well as multiple ways to preprocess 163 | each type, e.g. normalization and embedding. 164 | 165 | This class contains abstract methods for performing model training. 166 | """ 167 | 168 | def __init__(self, metadata: JSON, 169 | plan_ids: List[int], 170 | model_config: Optional[ModelConfig], 171 | preprocessing_config: Sequence[Mapping[str, Any]]): 172 | """Initializes base model. 173 | 174 | Args: 175 | metadata: Metadata for entire query (i.e. under query_id key). The 176 | predicates value is structured as a list, with each entry corresponding 177 | to a param. Each entry in this list is a dict containing data_type (int, 178 | float, text), and associated info, e.g. min/max ranges for ints, 179 | distinct values for embedded parameters, etc. 180 | plan_ids: List of plan ids to predict. 181 | model_config: Model configuration data. 182 | preprocessing_config: Preprocessing config for each parameter, structured 183 | as a list corresponding to the elements of metadata['predicates']. 184 | For example: 185 | [{"type": "embedding"}, 186 | {"type": "std_normalization", "mean": 0., "variance": 1.}] 187 | 188 | Raises: 189 | ValueError: If predicates metadata and preprocessing config don't have 190 | the same length. 191 | """ 192 | if len(metadata["predicates"]) != len(preprocessing_config): 193 | raise ValueError("Predicates metadata and preprocessing config have " 194 | "mismatched lengths! 
%d != %d" % 195 | (len(metadata["predicates"]), 196 | len(preprocessing_config))) 197 | 198 | self._predicate_metadata = metadata["predicates"] 199 | self._num_plans = len(plan_ids) 200 | self._model_index_to_plan_id = {i: plan_id for i, plan_id in 201 | enumerate(plan_ids)} 202 | 203 | self._model_config = model_config 204 | self._preprocessing_config = preprocessing_config 205 | 206 | def _construct_preprocessing_layer(self) -> tf.keras.layers.Layer: # pytype: disable=invalid-annotation # typed-keras 207 | """Constructs input layer and preprocessing layer. 208 | 209 | Returns: 210 | Concatenated preprocessing layer for the model. 211 | """ 212 | self._inputs = [ 213 | self._input_layer(p["data_type"], f"param{i}") 214 | for i, p in enumerate(self._predicate_metadata) 215 | ] 216 | 217 | to_concatenate = [] 218 | for i in range(len(self._predicate_metadata)): 219 | to_concatenate.append( 220 | _apply_preprocessing_layer(f"preprocessing_param{i}", 221 | self._predicate_metadata[i], 222 | self._inputs[i], 223 | self._preprocessing_config[i])) 224 | return tf.keras.layers.Concatenate()(to_concatenate) 225 | 226 | def _input_layer( 227 | self, 228 | data_type: str, 229 | name: Optional[str] = None 230 | ) -> tf.keras.Input: # pytype: disable=invalid-annotation # typed-keras 231 | """Return a 1d input layer with the appropriate data type.""" 232 | return tf.keras.layers.Input( 233 | shape=(1), dtype=trainer_util.get_tf_type(data_type), name=name) 234 | 235 | @abc.abstractmethod 236 | def get_model(self) -> tf.keras.Model: # pytype: disable=invalid-annotation # typed-keras 237 | """Get keras model.""" 238 | raise NotImplementedError 239 | 240 | 241 | class ModelPredictorBase(metaclass=abc.ABCMeta): 242 | """Abstract base for classes that implement inference for a Kepler model.""" 243 | 244 | @abc.abstractmethod 245 | def predict( 246 | self, params: List[Any] 247 | ) -> Tuple[np.ndarray, Optional[Dict[str, Any]]]: 248 | """Returns the predicted plan id from the model. 249 | 250 | Args: 251 | params: Inputs to model inference. 252 | 253 | Returns: 254 | A tuple containing: 255 | 1. Array of plan ids predicted to give the best latency by the model, or 256 | None if the model abstains from making a prediction due to low 257 | confidence. 258 | 2. An optional dictionary containing predictor-specific auxiliary 259 | values. 260 | """ 261 | raise NotImplementedError 262 | -------------------------------------------------------------------------------- /kepler/model_trainer/model_base_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for model_base.""" 17 | 18 | import tensorflow as tf 19 | 20 | from kepler.model_trainer import test_util 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | 24 | 25 | class ModelBaseTest(parameterized.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | # ModelBase does not use any model configs nor num_plans. 30 | self.base_model = test_util.ConcreteModelBase( 31 | test_util.TEST_METADATA_0, 32 | [], 33 | None, 34 | test_util.TEST_PREPROCESSING_CONFIG_0, 35 | ) 36 | self.model = self.base_model.get_model() 37 | 38 | def test_preprocessing_dimension(self): 39 | self.assertEqual(self.base_model.preprocessing_layer.dtype, tf.float32) 40 | self.assertEqual(int(self.base_model.preprocessing_layer.shape[1]), 142) 41 | 42 | def test_input_layers(self): 43 | # Inputs should all be dimension 1. 44 | for input_layer in self.base_model._inputs: 45 | self.assertEqual(int(input_layer.shape[1]), 1) 46 | 47 | @parameterized.named_parameters( 48 | ("param0", "embedding_preprocessing_param0", 11, 100), 49 | ("param1", "embedding_preprocessing_param1", 5, 10), 50 | ("param3", "embedding_preprocessing_param3", 33, 5) 51 | ) 52 | def test_embedding_layer_dims(self, layer_name, input_dim, output_dim): 53 | # Input dimension for string embeddings is vocab size + num oov indices. 54 | self.assertIn(layer_name, [l.name for l in self.model.layers]) 55 | layer = self.model.get_layer(layer_name) 56 | self.assertEqual(layer.input_dim, input_dim) 57 | self.assertEqual(layer.output_dim, output_dim) 58 | 59 | @parameterized.named_parameters( 60 | ("param4", "one_hot_preprocessing_param4", 21), 61 | ("param5", "one_hot_preprocessing_param5", 4), 62 | ) 63 | def test_one_hot_layer_dims(self, layer_name, num_tokens): 64 | self.assertIn(layer_name, [l.name for l in self.model.layers]) 65 | layer = self.model.get_layer(layer_name) 66 | self.assertEqual(layer.num_tokens, num_tokens) 67 | 68 | @parameterized.named_parameters( 69 | ("param0", "normalization_preprocessing_param2", 10, 5),) 70 | def test_normalization_layers(self, layer_name, mean, variance): 71 | # Tests that normalization layers have correct mean and variance. 72 | self.assertIn(layer_name, [l.name for l in self.model.layers]) 73 | layer = self.model.get_layer(layer_name) 74 | self.assertEqual(layer.mean, mean) 75 | self.assertEqual(layer.variance, variance) 76 | 77 | 78 | if __name__ == "__main__": 79 | absltest.main() 80 | -------------------------------------------------------------------------------- /kepler/model_trainer/multihead_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Class for multihead model, for which there is an output head for each plan. 17 | 18 | In general, each output value indicates to the how good the corresponding plan 19 | is estimated to be. 
The predicted plan will be the plan corresponding to the 20 | head with the highest value. 21 | 22 | This model can be trained in a variety of ways. For example, in the 23 | classification setting, the outputs will correspond to the probability that 24 | that plan is the optimal plan for the inputs. One can also define 25 | latency-sensitive training losses for this model. 26 | """ 27 | 28 | from typing import Any, List, Mapping, Sequence 29 | 30 | import tensorflow as tf 31 | 32 | from kepler.model_trainer import model_base 33 | 34 | JSON = Any 35 | 36 | 37 | class MultiheadModel(model_base.ModelBase): 38 | """Model for multihead neural network model. 39 | 40 | The model configurations for this model are: 41 | - layer sizes (and number of layers) 42 | - dropout rates for each layer 43 | - optimizer learning rate 44 | - activation function 45 | - loss function. For example, with no softmax activation, we need to adjust 46 | the loss appropriately (standard categorical_crossentropy won't work). 47 | """ 48 | _inputs: List[tf.keras.Input] # pytype: disable=invalid-annotation # typed-keras 49 | 50 | def __init__(self, metadata: JSON, plan_ids: List[int], 51 | model_config: model_base.ModelConfig, 52 | preprocessing_config: Sequence[Mapping[str, Any]]): 53 | self._initialize_base(metadata, plan_ids, model_config, 54 | preprocessing_config) 55 | self._build_model() 56 | 57 | def _initialize_base(self, metadata: JSON, plan_ids: List[int], 58 | model_config: model_base.ModelConfig, 59 | preprocessing_config: Sequence[Mapping[str, Any]]): 60 | super().__init__(metadata, plan_ids, model_config, preprocessing_config) 61 | 62 | # LINT.IfChange 63 | def _build_model(self) -> None: 64 | """Constructs model via Keras Functional API.""" 65 | prev_layer = self._construct_preprocessing_layer() 66 | for i, (layer_size, dropout_rate) in enumerate( 67 | zip(self._model_config.layer_sizes, self._model_config.dropout_rates)): 68 | dense_layer = tf.keras.layers.Dense( 69 | layer_size, 70 | activation=self._model_config.activation, 71 | name=f"intermediate_dense_{i}")( 72 | prev_layer) 73 | prev_layer = tf.keras.layers.Dropout( 74 | dropout_rate, name=f"dropout_{i}")( 75 | dense_layer) 76 | 77 | result = tf.keras.layers.Dense( 78 | self._num_plans, name="output_dense")( 79 | prev_layer) 80 | model = tf.keras.models.Model(inputs=self._inputs, outputs=result) 81 | model.compile( 82 | optimizer=tf.keras.optimizers.Adam( 83 | learning_rate=self._model_config.learning_rate), 84 | loss=self._model_config.loss, 85 | metrics=self._model_config.metrics) 86 | 87 | self._model = model 88 | # LINT.ThenChange(//depot/google3/research/sir/kepler/model_trainer/sngp_multihead_model.py) 89 | 90 | def get_model(self) -> tf.keras.Model: # pytype: disable=invalid-annotation # typed-keras 91 | return self._model 92 | -------------------------------------------------------------------------------- /kepler/model_trainer/multihead_model_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
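A usage sketch, mirroring the smoke test below: the Keras model exposes one output column per candidate plan, dummy zero targets stand in for real training labels as in the test, and the argmax column index maps back to a plan id through the `plan_ids` ordering passed to the constructor.

    import copy

    import numpy as np

    from kepler.model_trainer import multihead_model
    from kepler.model_trainer import test_util
    from kepler.model_trainer import trainer_util

    plan_ids = list(range(test_util.TEST_NUM_PLANS_0))
    wrapper = multihead_model.MultiheadModel(
        test_util.TEST_METADATA_0, plan_ids,
        test_util.TEST_MODEL_CONFIG_0, test_util.TEST_PREPROCESSING_CONFIG_0)

    x = copy.deepcopy(test_util.TEST_INPUT_PARAMS_0)
    trainer_util.apply_preprocessing(x, test_util.TEST_METADATA_0["predicates"])
    keras_model = wrapper.get_model()
    keras_model.fit(x, np.zeros((4, len(plan_ids))), epochs=1)

    scores = keras_model.predict(x)        # shape: (num_instances, num_plans)
    best_plan_ids = [plan_ids[i] for i in scores.argmax(axis=1)]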
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for multihead model.""" 17 | 18 | import copy 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from kepler.model_trainer import multihead_model 23 | from kepler.model_trainer import sngp_multihead_model 24 | from kepler.model_trainer import test_util 25 | from kepler.model_trainer import trainer_util 26 | from absl.testing import absltest 27 | from absl.testing import parameterized 28 | 29 | 30 | _MODEL_0 = multihead_model.MultiheadModel( 31 | test_util.TEST_METADATA_0, list(range(test_util.TEST_NUM_PLANS_0)), 32 | test_util.TEST_MODEL_CONFIG_0, test_util.TEST_PREPROCESSING_CONFIG_0) 33 | 34 | _MODEL_1 = multihead_model.MultiheadModel( 35 | test_util.TEST_METADATA_0, list(range(test_util.TEST_NUM_PLANS_1)), 36 | test_util.TEST_MODEL_CONFIG_1, test_util.TEST_PREPROCESSING_CONFIG_0) 37 | 38 | _MODEL_2 = sngp_multihead_model.SNGPMultiheadModel( 39 | test_util.TEST_METADATA_0, list(range(test_util.TEST_NUM_PLANS_0)), 40 | test_util.TEST_MODEL_CONFIG_0, test_util.TEST_PREPROCESSING_CONFIG_0) 41 | 42 | _MODEL_3 = sngp_multihead_model.SNGPMultiheadModel( 43 | test_util.TEST_METADATA_0, list(range(test_util.TEST_NUM_PLANS_1)), 44 | test_util.TEST_MODEL_CONFIG_1, test_util.TEST_PREPROCESSING_CONFIG_0) 45 | 46 | 47 | class MultiheadModelTest(parameterized.TestCase): 48 | 49 | @parameterized.named_parameters( 50 | ("model0", _MODEL_0, test_util.TEST_NUM_PLANS_0), 51 | ("model1", _MODEL_1, test_util.TEST_NUM_PLANS_1), 52 | ("model2", _MODEL_2, test_util.TEST_NUM_PLANS_0), 53 | ("model3", _MODEL_3, test_util.TEST_NUM_PLANS_1) 54 | ) 55 | def test_basic_training(self, model, num_plans): 56 | # Basic check that this model can train without raising errors. 
57 | x = copy.deepcopy(test_util.TEST_INPUT_PARAMS_0) 58 | trainer_util.apply_preprocessing(x, test_util.TEST_METADATA_0["predicates"]) 59 | model.get_model().fit(x, np.zeros((4, num_plans)), epochs=1) 60 | 61 | @parameterized.named_parameters( 62 | ("model0", _MODEL_0, test_util.TEST_NUM_PLANS_0), 63 | ("model1", _MODEL_1, test_util.TEST_NUM_PLANS_1) 64 | ) 65 | def test_output_dimensions(self, model, num_plans): 66 | self.assertEqual(int(model.get_model().output.shape[1]), num_plans) 67 | 68 | @parameterized.named_parameters( 69 | ("model0", _MODEL_0, tf.keras.activations.relu), 70 | ("model1", _MODEL_1, tf.keras.activations.tanh), 71 | ("model2", _MODEL_2, tf.keras.activations.relu), 72 | ("model3", _MODEL_3, tf.keras.activations.tanh) 73 | ) 74 | def test_activations(self, model, activation): 75 | for layer in model.get_model().layers: 76 | if "intermediate_dense_" in layer.name: 77 | self.assertEqual(layer.activation, activation) 78 | elif layer.name == "output_dense": 79 | self.assertEqual(layer.activation, tf.keras.activations.linear) 80 | 81 | @parameterized.named_parameters( 82 | ("model0", _MODEL_0, test_util.TEST_MODEL_CONFIG_0, 83 | test_util.TEST_NUM_PLANS_0), 84 | ("model1", _MODEL_1, test_util.TEST_MODEL_CONFIG_1, 85 | test_util.TEST_NUM_PLANS_1), 86 | ("model2", _MODEL_2, test_util.TEST_MODEL_CONFIG_0, 87 | test_util.TEST_NUM_PLANS_0), 88 | ("model3", _MODEL_3, test_util.TEST_MODEL_CONFIG_1, 89 | test_util.TEST_NUM_PLANS_1) 90 | ) 91 | def test_layer_widths(self, model, model_config, num_plans): 92 | for layer in model.get_model().layers: 93 | if "intermediate_dense_" in layer.name: 94 | idx = int(layer.name.split("_")[2]) 95 | self.assertEqual(layer.units, model_config.layer_sizes[idx]) 96 | elif layer.name == "output_dense": 97 | self.assertEqual(layer.units, num_plans) 98 | 99 | @parameterized.named_parameters( 100 | ("model0", _MODEL_0, test_util.TEST_MODEL_CONFIG_0), 101 | ("model1", _MODEL_1, test_util.TEST_MODEL_CONFIG_1), 102 | ("model2", _MODEL_2, test_util.TEST_MODEL_CONFIG_0), 103 | ("model3", _MODEL_3, test_util.TEST_MODEL_CONFIG_1) 104 | ) 105 | def test_dropout_rates(self, model, model_config): 106 | for layer in model.get_model().layers: 107 | if "dropout_" in layer.name: 108 | idx = int(layer.name.split("_")[1]) 109 | self.assertEqual(layer.rate, model_config.dropout_rates[idx]) 110 | 111 | @parameterized.named_parameters( 112 | ("model0", _MODEL_0, 1), 113 | ("model1", _MODEL_1, 2) 114 | ) 115 | def test_num_layers(self, model, target_num_layers): 116 | self.assertLen([l for l in model.get_model().layers 117 | if "intermediate_dense_" in l.name], 118 | target_num_layers) 119 | 120 | 121 | if __name__ == "__main__": 122 | absltest.main() 123 | -------------------------------------------------------------------------------- /kepler/model_trainer/sngp_multihead_model_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for SNGP multihead model.""" 17 | 18 | import copy 19 | from typing import Any, List 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from kepler.model_trainer import model_base 25 | from kepler.model_trainer import sngp_multihead_model 26 | from kepler.model_trainer import test_util 27 | from kepler.model_trainer import trainer_util 28 | from absl.testing import absltest 29 | from absl.testing import parameterized 30 | 31 | 32 | _MODEL_0 = sngp_multihead_model.SNGPMultiheadModel( 33 | test_util.TEST_METADATA_0, list(range(test_util.TEST_NUM_PLANS_0)), 34 | test_util.TEST_MODEL_CONFIG_0, test_util.TEST_PREPROCESSING_CONFIG_0) 35 | 36 | _MODEL_1 = sngp_multihead_model.SNGPMultiheadModel( 37 | test_util.TEST_METADATA_0, list(range(test_util.TEST_NUM_PLANS_1)), 38 | test_util.TEST_MODEL_CONFIG_1, test_util.TEST_PREPROCESSING_CONFIG_0) 39 | 40 | 41 | class SNGPMultiheadModelTest(parameterized.TestCase): 42 | 43 | @parameterized.named_parameters( 44 | ("model0", _MODEL_0, test_util.TEST_NUM_PLANS_0), 45 | ("model1", _MODEL_1, test_util.TEST_NUM_PLANS_1) 46 | ) 47 | def test_output_dimensions(self, model, num_plans): 48 | self.assertEqual(int(model.get_model().output[0].shape[1]), num_plans) 49 | 50 | @parameterized.named_parameters( 51 | ("model0", _MODEL_0, 1), 52 | ("model1", _MODEL_1, 2) 53 | ) 54 | def test_num_layers(self, model, target_num_layers): 55 | self.assertLen([l for l in model.get_model().layers 56 | if "spectral_norm_" in l.name], 57 | target_num_layers) 58 | 59 | @parameterized.named_parameters( 60 | ("model0", _MODEL_0, test_util.TEST_MODEL_CONFIG_0), 61 | ("model1", _MODEL_1, test_util.TEST_MODEL_CONFIG_1) 62 | ) 63 | def test_spectral_norm_multiplier(self, model, model_config): 64 | num_spectral_norm_layers = 0 65 | for layer in model.get_model().layers: 66 | if "spectral_norm_" in layer.name: 67 | self.assertEqual(layer.norm_multiplier, 68 | model_config.spectral_norm_multiplier) 69 | num_spectral_norm_layers += 1 70 | self.assertLen(model_config.layer_sizes, num_spectral_norm_layers) 71 | 72 | @parameterized.named_parameters( 73 | ("model0", _MODEL_0, test_util.TEST_MODEL_CONFIG_0), 74 | ("model1", _MODEL_1, test_util.TEST_MODEL_CONFIG_1) 75 | ) 76 | def test_num_gp_features(self, model, model_config): 77 | contains_output_gp_layer = False 78 | for layer in model.get_model().layers: 79 | if "output_gp_layer" == layer.name: 80 | self.assertEqual(layer.num_inducing, 81 | model_config.num_gp_random_features) 82 | contains_output_gp_layer = True 83 | self.assertTrue(contains_output_gp_layer) 84 | 85 | 86 | def _get_tflite_predictor( 87 | x: np.ndarray, 88 | model: tf.keras.Model, 89 | metadata: Any, 90 | plan_cover: List[int], 91 | confidence_threshold: float, 92 | train: bool, 93 | ) -> model_base.ModelPredictorBase: 94 | if train: 95 | # Set the 4th plan as optimal for the first 2 inputs and the 2nd plan as 96 | # optimal for the latter 2 inputs. 
97 | y = np.zeros((4, test_util.TEST_NUM_PLANS_0)) 98 | y[0, 3] = 1 99 | y[1, 3] = 1 100 | y[2, 1] = 1 101 | y[3, 1] = 1 102 | 103 | model.fit(x, y, epochs=50) 104 | 105 | tflite_model = trainer_util.convert_to_tflite(model) 106 | return sngp_multihead_model.SNGPMultiheadModelTFLitePredictor( 107 | tflite_model_content=tflite_model, 108 | metadata=metadata, 109 | plan_cover=plan_cover, 110 | confidence_threshold=confidence_threshold, 111 | ) 112 | 113 | 114 | def _get_keras_model_predictor( 115 | x: np.ndarray, 116 | model: tf.keras.Model, 117 | metadata: Any, 118 | plan_cover: List[int], 119 | confidence_threshold: float, 120 | train: bool, 121 | ) -> model_base.ModelPredictorBase: 122 | if train: 123 | # Set the 4th plan as optimal for the first 2 inputs and the 2nd plan as 124 | # optimal for the latter 2 inputs. 125 | y = np.zeros((4, test_util.TEST_NUM_PLANS_0)) 126 | y[0, 3] = 1 127 | y[1, 3] = 1 128 | y[2, 1] = 1 129 | y[3, 1] = 1 130 | 131 | model.fit(x, y, epochs=50) 132 | 133 | return sngp_multihead_model.SNGPMultiheadModelPredictor( 134 | model=model, 135 | metadata=metadata, 136 | plan_cover=plan_cover, 137 | confidence_threshold=confidence_threshold, 138 | ) 139 | 140 | 141 | class SNGPMultiheadModelPredictorTest(parameterized.TestCase): 142 | 143 | def setUp(self): 144 | super().setUp() 145 | 146 | self._model = _MODEL_0.get_model() 147 | self._x_train = copy.deepcopy(test_util.TEST_INPUT_PARAMS_0) 148 | self._x_input = test_util.TEST_INPUT_PARAMS_0 149 | self._metadata = test_util.TEST_METADATA_0 150 | self._plan_cover = [10, 11, 12, 13, 14] 151 | trainer_util.apply_preprocessing( 152 | self._x_train, self._metadata["predicates"] 153 | ) 154 | 155 | # The expected_predictions are based on the values set for y in 156 | # _get_tflite_model(). 
157 | @parameterized.named_parameters( 158 | dict( 159 | testcase_name="tflite low abstention", 160 | get_predictor_fn=_get_tflite_predictor, 161 | confidence_threshold=0.6, 162 | train=True, 163 | expected_predictions=[13, 13, 11, 11], 164 | batch_inference=False, 165 | ), 166 | dict( 167 | testcase_name="tflite low abstention batch", 168 | get_predictor_fn=_get_tflite_predictor, 169 | confidence_threshold=0.6, 170 | train=True, 171 | expected_predictions=[13, 13, 11, 11], 172 | batch_inference=True, 173 | ), 174 | dict( 175 | testcase_name="tflite high abstention", 176 | get_predictor_fn=_get_tflite_predictor, 177 | confidence_threshold=0.95, 178 | train=True, 179 | expected_predictions=[None, None, None, None], 180 | batch_inference=False, 181 | ), 182 | dict( 183 | testcase_name="keras model low abstention", 184 | get_predictor_fn=_get_keras_model_predictor, 185 | confidence_threshold=0.6, 186 | train=True, 187 | expected_predictions=[13, 13, 11, 11], 188 | batch_inference=False, 189 | ), 190 | dict( 191 | testcase_name="keras model high abstention", 192 | get_predictor_fn=_get_keras_model_predictor, 193 | confidence_threshold=0.95, 194 | train=True, 195 | expected_predictions=[None, None, None, None], 196 | batch_inference=False, 197 | ), 198 | ) 199 | def test_predict( 200 | self, get_predictor_fn, confidence_threshold, train, expected_predictions, 201 | batch_inference 202 | ): 203 | predictor = get_predictor_fn( 204 | x=self._x_train, 205 | model=self._model, 206 | metadata=self._metadata, 207 | plan_cover=self._plan_cover, 208 | train=train, 209 | confidence_threshold=confidence_threshold, 210 | ) 211 | 212 | for i, params in enumerate(zip(*self._x_input)): 213 | if batch_inference: 214 | params = [[p] for p in params] 215 | prediction, auxiliary = predictor.predict(params) 216 | self.assertEqual(prediction, expected_predictions[i]) 217 | self.assertIn("confidences", auxiliary) 218 | self.assertLen(auxiliary["confidences"][0], test_util.TEST_NUM_PLANS_0) 219 | 220 | # The expected_predictions are based on the values set for y in 221 | # _get_tflite_model(). 
222 | @parameterized.named_parameters( 223 | dict( 224 | testcase_name="tflite model", 225 | get_predictor_fn=_get_tflite_predictor, 226 | confidence_threshold=0.2, 227 | train=False, 228 | ), 229 | dict( 230 | testcase_name="keras model", 231 | get_predictor_fn=_get_keras_model_predictor, 232 | confidence_threshold=0.2, 233 | train=False, 234 | ), 235 | ) 236 | def test_illegal_calls(self, get_predictor_fn, confidence_threshold, train): 237 | predictor = get_predictor_fn( 238 | x=self._x_train, 239 | model=self._model, 240 | metadata=self._metadata, 241 | plan_cover=self._plan_cover, 242 | train=train, 243 | confidence_threshold=confidence_threshold, 244 | ) 245 | 246 | self.assertRaisesRegex( 247 | ValueError, 248 | "Expected 7 parameter values and received 1 instead", 249 | predictor.predict, 250 | [1], 251 | ) 252 | 253 | self.assertRaisesRegex( 254 | ValueError, 255 | "could not convert string to float", 256 | predictor.predict, 257 | ["a"] * 7, 258 | ) 259 | 260 | short_metadata = {"predicates": [1]} 261 | self.assertRaisesRegex( 262 | ValueError, 263 | "Provided 1 predicates for a model", 264 | get_predictor_fn, 265 | self._x_train, 266 | self._model, 267 | short_metadata, 268 | self._plan_cover, 269 | train, 270 | confidence_threshold, 271 | ) 272 | 273 | short_plan_cover = [1] 274 | self.assertRaisesRegex( 275 | ValueError, 276 | r"Provided plan cover size of 1", 277 | get_predictor_fn, 278 | self._x_train, 279 | self._model, 280 | self._metadata, 281 | short_plan_cover, 282 | train, 283 | confidence_threshold, 284 | ) 285 | 286 | 287 | if __name__ == "__main__": 288 | absltest.main() 289 | -------------------------------------------------------------------------------- /kepler/model_trainer/test_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Common utils/constants for tests.""" 17 | 18 | from typing import Any, List, Mapping, Optional, Sequence 19 | 20 | import numpy as np 21 | import tensorflow as tf 22 | 23 | from kepler.model_trainer import loss_functions 24 | from kepler.model_trainer import model_base 25 | 26 | 27 | class ConcreteModelBase(model_base.ModelBase): 28 | """Make non-abstract class for testing.""" 29 | 30 | _inputs: List[tf.keras.Input] # pytype: disable=invalid-annotation # typed-keras 31 | 32 | def __init__( 33 | self, 34 | metadata: Mapping[str, Any], 35 | plan_ids: List[int], 36 | model_config: Optional[model_base.ModelConfig], 37 | preprocessing_config: Sequence[Mapping[str, Any]], 38 | ): 39 | super().__init__(metadata, plan_ids, model_config, preprocessing_config) 40 | 41 | self.preprocessing_layer = self._construct_preprocessing_layer() 42 | self._model = tf.keras.models.Model( 43 | inputs=self._inputs, outputs=self.preprocessing_layer 44 | ) 45 | 46 | def get_model(self) -> tf.keras.Model: 47 | return self._model 48 | 49 | 50 | # Only contains a subset of predicates metadata by design. Actual metadata 51 | # has other keys, which are unused and thus omitted for the sake of brevity. 52 | TEST_METADATA_0 = { 53 | "predicates": [ 54 | { 55 | "alias": "s", 56 | "column": "name", 57 | "operator": "=", 58 | "data_type": "text", 59 | "distinct_values": ["a", "b", "c", "d", "e", "f"], 60 | }, 61 | { 62 | "alias": "t", 63 | "column": "name", 64 | "operator": "=", 65 | "data_type": "text", 66 | "distinct_values": ["aa", "bb", "cc", "dd", "ee"], 67 | }, 68 | { 69 | "alias": "q", 70 | "column": "score", 71 | "operator": ">", 72 | "data_type": "float", 73 | }, 74 | { 75 | "alias": "q", 76 | "column": "view_count", 77 | "operator": "<", 78 | "data_type": "int", 79 | "min": -5, 80 | "max": 27, 81 | }, 82 | { 83 | "alias": "q", 84 | "column": "upvotes", 85 | "operator": "<", 86 | "data_type": "int", 87 | "min": 0, 88 | "max": 20, 89 | }, 90 | { 91 | "alias": "b", 92 | "column": "name", 93 | "operator": "=", 94 | "data_type": "text", 95 | "distinct_values": ["a", "b", "c"], 96 | }, 97 | { 98 | "alias": "b", 99 | "column": "name", 100 | "operator": "=", 101 | "data_type": "float", 102 | "preprocess_type": "to_timestamp", 103 | }, 104 | ] 105 | } 106 | 107 | TEST_PREPROCESSING_CONFIG_0 = [ 108 | {"type": "embedding", "output_dim": 100, "num_oov_indices": 5}, 109 | {"type": "embedding", "output_dim": 10}, 110 | {"type": "std_normalization", "mean": 10, "variance": 5}, 111 | {"type": "embedding", "output_dim": 5}, 112 | {"type": "one_hot"}, 113 | {"type": "one_hot", "num_oov_indices": 1}, 114 | {"type": "std_normalization", "mean": 1453276800, "variance": 1e10}, 115 | ] 116 | 117 | # Corresponds to data_management test_query 0. 118 | TEST_PREPROCESSING_CONFIG_1 = [{ 119 | "type": "embedding", 120 | "output_dim": 100, 121 | "num_oov_indices": 5 122 | }, { 123 | "type": "embedding", 124 | "output_dim": 100, 125 | "num_oov_indices": 1, 126 | }, { 127 | "type": "embedding", 128 | "output_dim": 10 129 | }, { 130 | "type": "std_normalization", 131 | "mean": 0, 132 | "variance": 1 133 | }] 134 | 135 | TEST_INPUT_PARAMS_0 = [ 136 | np.array(["a", "c", "e", "z"]), 137 | np.array(["aa", "bb", "cc", "ee"]), 138 | np.ones(4), 139 | np.arange(4), 140 | np.arange(4), 141 | np.array(["a"] * 4), 142 | np.array(["2016-01-20", "2016-01-20", "2016-01-20", "2016-01-20"]), 143 | ] 144 | 145 | # Note: We would not want to actually train the model using the default BCE 146 | # loss, since we assume linear output activation. 
Instead, in the classification 147 | # setting one should use BCE loss with from_logits set to True. 148 | TEST_MODEL_CONFIG_0 = model_base.ModelConfig( 149 | [64], 150 | [0.1], 151 | 1e-3, 152 | "relu", 153 | tf.keras.losses.BinaryCrossentropy(from_logits=True), 154 | [], 155 | ) 156 | TEST_NUM_PLANS_0 = 5 157 | 158 | TEST_MODEL_CONFIG_1 = model_base.ModelConfig( 159 | [64, 32], [0.25, 0.2], 1e-3, "tanh", "binary_crossentropy", ["accuracy"] 160 | ) 161 | TEST_NUM_PLANS_1 = 10 162 | 163 | TEST_MODEL_CONFIG_2 = model_base.ModelConfig( 164 | [64], [0.1], 1e-3, "relu", loss_functions.mse_loss, 165 | [tf.keras.metrics.MeanAbsoluteError()]) 166 | -------------------------------------------------------------------------------- /kepler/model_trainer/trainer_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Example script for training a Kepler model. 17 | 18 | This script is mostly designed for illustrative purposes, hence it does not 19 | fully parameterize things like model type, model architectures, hyperparameters, 20 | etc. We recommend using Colab for model experimentation instead. 21 | """ 22 | import json 23 | import os 24 | 25 | from absl import app 26 | from absl import flags 27 | from absl import logging 28 | 29 | import numpy as np 30 | import tensorflow as tf 31 | 32 | from kepler.data_management import database_simulator 33 | from kepler.data_management import workload 34 | from kepler.model_trainer import evaluation 35 | from kepler.model_trainer import model_base 36 | from kepler.model_trainer import sngp_multihead_model 37 | from kepler.model_trainer import trainer 38 | from kepler.model_trainer import trainer_util 39 | 40 | _QUERY_METADATA = flags.DEFINE_string( 41 | "query_metadata", 42 | None, 43 | ( 44 | "File containing query metadata which includes information on parameter" 45 | " values." 46 | ), 47 | ) 48 | flags.mark_flag_as_required("query_metadata") 49 | 50 | _EXECUTION_DATA = flags.DEFINE_string( 51 | "execution_data", 52 | None, 53 | "File containing query execution data across various parameter values.", 54 | ) 55 | flags.mark_flag_as_required("execution_data") 56 | 57 | _EXECUTION_METADATA = flags.DEFINE_string( 58 | "execution_metadata", 59 | None, 60 | "File containing query execution metadata across various parameter values.", 61 | ) 62 | flags.mark_flag_as_required("execution_metadata") 63 | 64 | _PREPROCESSING_CONFIG = flags.DEFINE_string( 65 | "preprocessing_config", 66 | None, 67 | "File containing preprocessing config information.", 68 | ) 69 | 70 | _VOCAB_DATA_DIR = flags.DEFINE_string( 71 | "vocab_data_dir", None, "Folder containing distinct values for each column." 
72 | ) 73 | flags.mark_flag_as_required("vocab_data_dir") 74 | 75 | _QUERY_ID = flags.DEFINE_string("query_id", None, "Name of query to train on") 76 | flags.mark_flag_as_required("query_id") 77 | _SEED = flags.DEFINE_integer( 78 | "seed", 0, "Seed used to shuffle workload before splitting." 79 | ) 80 | 81 | _TRAIN_SPLIT = flags.DEFINE_float( 82 | "train_split", 83 | 0.8, 84 | "Fraction of data to use for training." 85 | ) 86 | _NUM_EPOCHS = flags.DEFINE_integer( 87 | "num_epochs", 400, "Number of epochs to train for." 88 | ) 89 | _BATCH_SIZE = flags.DEFINE_integer("batch_size", 64, "Training minibatch size.") 90 | _VOCAB_SIZE_LIMIT = flags.DEFINE_integer( 91 | "vocab_size_limit", 200, "Maximum vocabulary size." 92 | ) 93 | 94 | 95 | def _get_distinct_values(table, column): 96 | with open( 97 | os.path.join(_VOCAB_DATA_DIR.value, f"{table}-{column}-distinct_values") 98 | ) as f: 99 | return json.load(f) 100 | 101 | 102 | def main(unused_argv): 103 | query_id = _QUERY_ID.value 104 | with open(_QUERY_METADATA.value) as json_file: 105 | query_metadata = json.load(json_file) 106 | metadata = query_metadata[query_id] 107 | 108 | with open(_EXECUTION_DATA.value) as json_file: 109 | execution_data = json.load(json_file) 110 | 111 | with open(_EXECUTION_METADATA.value) as json_file: 112 | execution_metadata = json.load(json_file) 113 | 114 | if _PREPROCESSING_CONFIG.value: 115 | with open(_PREPROCESSING_CONFIG.value) as json_file: 116 | preprocessing_config = json.load(json_file) 117 | else: 118 | preprocessing_config = trainer_util.construct_preprocessing_config( 119 | metadata) 120 | 121 | plans = workload.KeplerPlanDiscoverer( 122 | query_execution_metadata=execution_metadata) 123 | database = database_simulator.DatabaseSimulator( 124 | query_execution_data=execution_data, 125 | query_execution_metadata=execution_metadata, 126 | estimator=database_simulator.LatencyEstimator.MIN, 127 | ) 128 | client = database_simulator.DatabaseClient(database) 129 | workload_generator = workload.WorkloadGenerator(execution_data) 130 | full_workload = workload_generator.all() 131 | workload.shuffle(full_workload, seed=_SEED.value) 132 | 133 | workload_train, workload_eval = workload.split( 134 | full_workload, first_half_fraction=_TRAIN_SPLIT.value, 135 | ) 136 | 137 | # Additionally fetch all default execution times to compute near-optimality. 
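  # (Note: the default executions are used twice below: the latencies in
  # query_execution_train_with_default_df anchor the near-optimal labels built
  # by construct_training_data with default_relative=True, and
  # default_latencies_train feeds trainer_util.get_sample_weight to weight each
  # training sample.)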
138 | queries_train = workload.create_query_batch(plans.plan_ids, workload_train) 139 | queries_train_with_default = workload.create_query_batch( 140 | plans.plan_ids + [None], workload_train) 141 | query_execution_train_df = client.execute_timed_batch( 142 | planned_queries=queries_train) 143 | query_execution_train_with_default_df = client.execute_timed_batch( 144 | planned_queries=queries_train_with_default) 145 | 146 | default_latencies_train = evaluation.get_default_latencies( 147 | database=database, 148 | query_workload=workload_train) 149 | default_latencies_train = np.atleast_2d(default_latencies_train).T 150 | 151 | trainer_util.add_vocabulary_to_metadata( 152 | query_execution_train_df, metadata, 153 | _get_distinct_values, _VOCAB_SIZE_LIMIT.value) 154 | 155 | logging.info( 156 | "Num queries: %d, num plans %d", 157 | len(queries_train), 158 | len(plans.plan_ids), 159 | ) 160 | 161 | loss = tf.keras.losses.BinaryCrossentropy(from_logits=True) 162 | model_config = model_base.ModelConfig( 163 | layer_sizes=[64, 64, 64], 164 | dropout_rates=[0., 0., 0.], 165 | learning_rate=3e-4, 166 | activation="relu", 167 | spectral_norm_multiplier=10.0, 168 | num_gp_random_features=128, 169 | loss=loss, 170 | metrics=[] 171 | ) 172 | 173 | model = sngp_multihead_model.SNGPMultiheadModel( 174 | metadata, plans.plan_ids, model_config, 175 | preprocessing_config) 176 | 177 | t = trainer.NearOptimalClassificationTrainer(metadata, plans, model) 178 | x, y = t.construct_training_data(query_execution_train_with_default_df, 179 | default_relative=True, 180 | near_optimal_threshold=1.1) 181 | 182 | # Construct sample weight matrix. 183 | all_latencies = np.array(query_execution_train_df["latency_ms"]).reshape( 184 | (-1, len(plans.plan_ids))) 185 | sample_weight = trainer_util.get_sample_weight(all_latencies, 186 | default_latencies_train) 187 | 188 | print("Training on %d samples" % len(y)) 189 | t.train(x, y, 190 | epochs=_NUM_EPOCHS.value, 191 | batch_size=_BATCH_SIZE.value, 192 | sample_weight=sample_weight) 193 | 194 | # Perform some basic model evaluation. 195 | model_predictor = sngp_multihead_model.SNGPMultiheadModelPredictor( 196 | model.get_model(), metadata, plan_cover=plans.plan_ids, 197 | confidence_threshold=0.9) 198 | 199 | eval_inputs = trainer_util.construct_multihead_model_inputs(workload_eval) 200 | plan_selections, _ = model_predictor.predict(eval_inputs) 201 | 202 | candidate_latencies = evaluation.get_candidate_latencies( 203 | database=database, 204 | query_workload=workload_eval, 205 | plan_selections=plan_selections) 206 | default_latencies = evaluation.get_default_latencies( 207 | database=database, 208 | query_workload=workload_eval) 209 | optimal_latencies = evaluation.get_optimal_latencies( 210 | client=client, 211 | query_workload=workload_eval, 212 | kepler_plan_discoverer=plans) 213 | 214 | evaluation.evaluate( 215 | candidate_latencies=candidate_latencies, 216 | default_latencies=default_latencies, 217 | optimal_latencies=optimal_latencies) 218 | 219 | 220 | if __name__ == "__main__": 221 | app.run(main) 222 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/README: -------------------------------------------------------------------------------- 1 | ## How to Run Unit Tests 2 | 3 | blaze test :target --noforge 4 | 5 | The unit tests currently require an externally running instance of postgres on 6 | the workstation with a database called "test" and a user called "test" with 7 | password "test". 
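For reference, the unit tests open their database connections roughly as
follows (a sketch based on e2e_test.py; test_util.USER and test_util.PASSWORD
are assumed to match the "test"/"test" credentials above, and TestDatabase
creates and later drops a scratch database on the local server):

    from kepler.training_data_collection_pipeline import query_utils
    from kepler.training_data_collection_pipeline import test_util

    test_database = test_util.TestDatabase()
    database_configuration = query_utils.DatabaseConfiguration(
        dbname=test_database.dbname,
        user=test_util.USER,
        password=test_util.PASSWORD)
    query_manager = query_utils.QueryManager(database_configuration)
    test_util.populate_database(query_manager)
    # ... exercise the library under test ...
    test_database.drop()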
8 | 9 | ### Install postgres on the workstation. 10 | Example: https://www.digitalocean.com/community/tutorials/how-to-install-postgresql-on-ubuntu-20-04-quickstart 11 | 12 | Once you get the database prompt in Step 2: 13 | 14 | CREATE USER test superuser password 'test'; 15 | CREATE database test; 16 | 17 | You may also benefit from: 18 | 19 | CREATE USER superuser; 20 | 21 | Consider setting up a cronjob to clean up test databases. See comments in scripts/clean_up_test_databases.sh. 22 | 23 | Test connection: 24 | 25 | $ psql test -U test -h localhost --password 26 | 27 | ### Install the pg_stat_statements extension. 28 | Example: https://pganalyze.com/docs/install/01_enabling_pg_stat_statements 29 | 30 | At the database prompt: 31 | 32 | CREATE EXTENSION pg_stat_statements; 33 | 34 | Update shared_preload_libraries in the postgresql.conf. Location: 35 | 36 | $ psql -U test -c 'SHOW config_file' -h localhost 37 | 38 | Or at the database prompt: 39 | 40 | SHOW config_file; 41 | 42 | Edit the shared_preload_libraries line to include pg_stat_statements. 43 | 44 | shared_preload_libraries = 'pg_stat_statements' 45 | 46 | Restart postgres. 47 | 48 | sudo service postgresql restart 49 | 50 | ## How to Speed Up The Parameter Generator 51 | 52 | Early experiments demonstrated a query latency improvement of seconds/minutes vs 53 | hours when Vertica was used instead of Postgres to generate parameters for 54 | stack. 55 | 56 | The following code is useful to add to query_utils to use Vertica. The client 57 | library is not in third-party and Vertica is not used in any other part of the 58 | training data pipeline, so this code is not submitted with query_utils. 59 | 60 | 1. Adjust imports and pip install vertica_python if necessary. 61 | import vertica_python 62 | 63 | 2. Add a function to QueryManager to connect to Vertica using vertica_python. 64 | def connect_vertica(self, database: str, schema: str) -> None: 65 | """Connects to a Vertica database using the vertica_python client.""" 66 | 67 | self._database = database 68 | self._schema = schema 69 | conn_info = { 70 | 'host': '127.0.0.1', 71 | 'port': 5433, 72 | 'user': 'dbadmin', 73 | 'password': '', 74 | 'database': self._database, 75 | # Autogenerated session label by default. 76 | 'session_label': 'dataset', 77 | # Default throw error on invalid UTF-8 results. 78 | 'unicode_error': 'strict', 79 | # SSL is disabled by default. 80 | 'ssl': False, 81 | # Autocommit is off by default. 82 | 'autocommit': True, 83 | # Using server-side prepared statements is disabled by default. 84 | 'use_prepared_statements': False, 85 | # Connection timeout is not enabled by default. 86 | 'connection_timeout': 1500000 87 | } 88 | 89 | self._conn = vertica_python.connect(**conn_info) 90 | self._cursor = self._conn.cursor() 91 | self._cursor.execute(f'set search_path to {self._schema};') 92 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/e2e_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """End-to-end tests for the training data collection pipeline. 17 | 18 | The primary goal of this test is to ensure the outputs and inputs of each 19 | pipeline step flow seamlessly into the next step for various pipeline 20 | compositions. 21 | 22 | The correctness of the output of each library function is verified by the unit 23 | tests for that library function. 24 | """ 25 | 26 | import copy 27 | import json 28 | from typing import Any, List, Optional, Tuple 29 | 30 | from kepler.training_data_collection_pipeline import parameter_generator 31 | from kepler.training_data_collection_pipeline import pg_execute_training_data_queries 32 | from kepler.training_data_collection_pipeline import pg_generate_plan_candidates 33 | from kepler.training_data_collection_pipeline import pg_plan_hint_extractor 34 | from kepler.training_data_collection_pipeline import query_utils 35 | from kepler.training_data_collection_pipeline import test_util 36 | from absl.testing import absltest 37 | from absl.testing import parameterized 38 | 39 | # Typing aliases. 
40 | JSON = Any 41 | 42 | _TEST_TEMPLATE_INTEGER = """ 43 | { 44 | "query": "SELECT j, k, a, b, d_date \\n FROM \\n baz as baz \\n JOIN \\n bar as bar \\n on j = b \\n JOIN \\n foo as foo \\n on x = j \\n WHERE foo.y = 2 \\n and baz.j = '@param0' \\n and baz.k = '@param1';", 45 | "predicates": [ 46 | { 47 | "alias": "baz", 48 | "column": "j", 49 | "operator": "=" 50 | }, 51 | { 52 | "alias": "baz", 53 | "column": "k", 54 | "operator": "=" 55 | } 56 | ] 57 | } 58 | """ 59 | 60 | _GET_QUERY_PLANS_KWARGS = { 61 | "configs": ["enable_nestloop", "enable_hashjoin"], 62 | } 63 | 64 | _GENERATE_BY_ROW_NUM_EVOLUTION_KWARGS = { 65 | "num_generations": 3, 66 | "num_mutations_per_plan": 3, 67 | "exponent_base": 10, 68 | "exponent_range": 3, 69 | "max_plans_per_generation": 3, 70 | "perturb_unit_only": False 71 | } 72 | 73 | _GENERATE_BY_EXHAUSTIVE_CARDINALITY_PERTURBATION_KWARGS = { 74 | "cardinality_multipliers": [.1, 10], 75 | } 76 | 77 | 78 | class E2ETest(parameterized.TestCase): 79 | 80 | def setUp(self): 81 | super().setUp() 82 | self._test_database = test_util.TestDatabase() 83 | self._database_configuration = query_utils.DatabaseConfiguration( 84 | dbname=self._test_database.dbname, 85 | user=test_util.USER, 86 | password=test_util.PASSWORD) 87 | self._query_manager = query_utils.QueryManager(self._database_configuration) 88 | test_util.populate_database(self._query_manager) 89 | self._parameter_generator = parameter_generator.ParameterGenerator( 90 | self._database_configuration) 91 | 92 | def tearDown(self): 93 | self._test_database.drop() 94 | super().tearDown() 95 | 96 | @parameterized.named_parameters( 97 | ("get_query_plans", pg_generate_plan_candidates.get_query_plans, 98 | _GET_QUERY_PLANS_KWARGS, 2), 99 | ("generate_by_row_num_evolution", 100 | pg_generate_plan_candidates.generate_by_row_num_evolution, 101 | _GENERATE_BY_ROW_NUM_EVOLUTION_KWARGS, 3), 102 | ("generate_by_exhaustive_cardinality_perturbations", 103 | pg_generate_plan_candidates 104 | .generate_by_exhaustive_cardinality_perturbations, 105 | _GENERATE_BY_EXHAUSTIVE_CARDINALITY_PERTURBATION_KWARGS, 2)) 106 | def test_run_pipeline(self, plan_generation_function, 107 | plan_generation_function_specific_kwargs, 108 | extracted_plan_hints_count): 109 | # Step 0: Generate parameters. 110 | template_item = parameter_generator.TemplateItem( 111 | query_id=test_util.TEST_QUERY_ID, 112 | template=json.loads(_TEST_TEMPLATE_INTEGER)) 113 | parameter_count = 3 114 | query_templates = self._parameter_generator.generate_parameters( 115 | count=parameter_count, template_item=template_item) 116 | query_metadata = query_templates[test_util.TEST_QUERY_ID] 117 | self.assertLen(query_metadata["params"], parameter_count) 118 | 119 | # Step 1: Generate plan candidates. 
120 | plan_generation_function_kwargs = { 121 | "query": query_metadata["query"], 122 | "database_configuration": self._database_configuration 123 | } 124 | plan_generation_function_kwargs.update( 125 | plan_generation_function_specific_kwargs) 126 | plan_hint_extractor = pg_plan_hint_extractor.PlanHintExtractor() 127 | pg_generate_plan_candidates.execute_plan_generation( 128 | function=plan_generation_function, 129 | function_kwargs=plan_generation_function_kwargs, 130 | all_params=query_metadata["params"], 131 | plan_hint_extractor=plan_hint_extractor, 132 | chunksize=1, 133 | distributed=False) 134 | hint_counts, plan_hints, parameter_values_and_plan_indices, debug_infos = plan_hint_extractor.get_consolidated_plan_hints( 135 | ) 136 | self.assertLen(hint_counts, extracted_plan_hints_count) 137 | self.assertLen(plan_hints, extracted_plan_hints_count) 138 | self.assertLen(parameter_values_and_plan_indices, parameter_count) 139 | if plan_generation_function == pg_generate_plan_candidates.generate_by_row_num_evolution: 140 | self.assertLen(debug_infos, parameter_count) 141 | 142 | # Step 2: Verify plan hints. 143 | failure_counts = pg_plan_hint_extractor.verify_hints( 144 | query_id=test_util.TEST_QUERY_ID, 145 | query=query_metadata["query"], 146 | plan_hints=plan_hints, 147 | params_plan_indices=parameter_values_and_plan_indices, 148 | database_configuration=self._database_configuration) 149 | self.assertLen(failure_counts, 1) 150 | self.assertLen(failure_counts[test_util.TEST_QUERY_ID], 151 | extracted_plan_hints_count) 152 | 153 | # Step 3: Collect execution results. 154 | query_manager = query_utils.QueryManager(self._database_configuration) 155 | 156 | def execute_query( 157 | unused_query_manager: query_utils.QueryManager, 158 | query: str, 159 | params: List[Any], 160 | timeout_ms: Optional[int] = None, 161 | ) -> Tuple[Optional[float], Optional[int]]: 162 | del unused_query_manager 163 | return query_manager.execute_timed(query, params, timeout_ms) 164 | 165 | execute_query_results_key = "duration_ms" 166 | 167 | execution_results = {} 168 | execution_metadata = {} 169 | 170 | def checkpoint_results(query_id: str, results: Any, 171 | is_metadata: bool) -> None: 172 | self.assertEqual(test_util.TEST_QUERY_ID, query_id) 173 | 174 | if is_metadata: 175 | nonlocal execution_metadata 176 | execution_metadata = copy.deepcopy(results) 177 | else: 178 | nonlocal execution_results 179 | execution_results = copy.deepcopy(results) 180 | 181 | pg_execute_training_data_queries.execute_training_data_queries( 182 | batch_index=0, 183 | parameter_values={ 184 | test_util.TEST_QUERY_ID: parameter_values_and_plan_indices 185 | }, 186 | query_id=test_util.TEST_QUERY_ID, 187 | templates=query_templates, 188 | plan_hints={test_util.TEST_QUERY_ID: plan_hints}, 189 | iterations=3, 190 | batch_size=2, 191 | skip_indices=[], 192 | query_timeout_multiplier=3, 193 | query_timeout_min_ms=30, 194 | query_timeout_max_ms=300, 195 | execute_query_fn=execute_query, 196 | checkpoint_results_fn=checkpoint_results, 197 | results_key=execute_query_results_key, 198 | limit=None, 199 | plan_cover_num_params=2, 200 | near_optimal_threshold=1.05, 201 | num_params_threshold=0.95, 202 | query_timeout_minimum_speedup_multiplier=2, 203 | ) 204 | 205 | self.assertLen(execution_results[test_util.TEST_QUERY_ID], parameter_count) 206 | self.assertLen(execution_metadata[test_util.TEST_QUERY_ID], 1) 207 | 208 | 209 | if __name__ == "__main__": 210 | absltest.main() 211 | 
-------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/main_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Main file shared functionality. 17 | 18 | Overall, we strive to keep main files simple. There are cases where multiple 19 | main files may need to do something very similar. This file contains the helper 20 | functions that are shared between main files. 21 | 22 | Please double check when adding more functionality here to see if the 23 | functionality would be better suited in one of the libraries. 24 | """ 25 | 26 | import json 27 | import os 28 | from typing import Any, Dict, List 29 | 30 | from absl import logging 31 | 32 | # Typing aliases. 33 | JSON = Any 34 | 35 | 36 | class MissingVerificationDataError(ValueError): 37 | pass 38 | 39 | 40 | def get_skip_indices( 41 | plan_hints: Dict[str, List[Dict[str, str]]], verification_file: str 42 | ) -> Dict[str, List[int]]: 43 | """Get plan indices to skip for each query based on verification failures. 44 | 45 | Args: 46 | plan_hints: Dict mapping query id to list of hints, each a dict mapping 47 | "hints" to hint string. 48 | verification_file: Path to verification failures file. 49 | 50 | Returns: 51 | Dict mapping query id to list of plan indices to skip. 52 | 53 | Raises: 54 | MissingVerificationDataError: If the verification data doesn't match 55 | the provided query or hints. 56 | """ 57 | query_id_to_skip_indices = {} 58 | if not verification_file: 59 | return query_id_to_skip_indices 60 | 61 | with open(verification_file) as f: 62 | hint_failures = json.load(f) 63 | for query_id in plan_hints: 64 | if query_id not in hint_failures: 65 | raise MissingVerificationDataError( 66 | f"Query {query_id} not in verification file!" 
67 | ) 68 | query_id_to_skip_indices[query_id] = [] 69 | for i, hint in enumerate(plan_hints[query_id]): 70 | hint_str = hint["hints"] 71 | if hint_str not in hint_failures[query_id]: 72 | raise MissingVerificationDataError( 73 | f"Missing hint {hint_str} for query {query_id}" 74 | ) 75 | if hint_failures[query_id][hint_str]: 76 | query_id_to_skip_indices[query_id].append(i) 77 | 78 | return query_id_to_skip_indices 79 | 80 | 81 | def print_failure_counts(failure_counts: Dict[str, Dict[str, int]]) -> None: 82 | """Prints out failure summary.""" 83 | logging.info("Printing positive failure counts:") 84 | for query_id in sorted(failure_counts): 85 | plan_failure_count = 0 86 | for hint in sorted(failure_counts[query_id]): 87 | hint_fail_count = failure_counts[query_id][hint] 88 | if hint_fail_count > 0: 89 | plan_failure_count += 1 90 | logging.info("%s num failures for %s: %s ", query_id, hint, 91 | hint_fail_count) 92 | logging.info("Query %s failure ratio: %d/%d", query_id, plan_failure_count, 93 | len(failure_counts[query_id])) 94 | 95 | 96 | def print_hint_counts_by_source(query_id_to_counts: Dict[str, JSON]) -> None: 97 | """Prints out query hint counts summary.""" 98 | logging.info("Printing hints counts by source:") 99 | for query_id in sorted(query_id_to_counts): 100 | plan_hints = query_id_to_counts[query_id] 101 | for i, plan_hint in enumerate(sorted(plan_hints.keys())): 102 | logging.info("%s: number of suggestions for hint %s from source %s: %s", 103 | query_id, i, plan_hints[plan_hint]["source"], 104 | plan_hints[plan_hint]["count"]) 105 | 106 | 107 | class HintAccumulator: 108 | """Accumulates extracted hints across queries and saves them to file.""" 109 | 110 | def __init__(self): 111 | self.query_id_to_counts = {} 112 | self.query_id_to_plan_hints = {} 113 | self.query_id_to_params_plan_indices = {} 114 | self.query_id_to_debug_infos = {} 115 | self.combined_failure_counts = {} 116 | 117 | def save(self, output_dir: str, plans_output_file: str, 118 | verification_failures_file: str, plan_index_suffix: str) -> None: 119 | """Saves the content of the hint accumulator to a set of files. 120 | 121 | Args: 122 | output_dir: The output_dir for produced files. 123 | plans_output_file: The file to store the plans as hints. 124 | verification_failures_file: The file to save verification failures. 125 | plan_index_suffix: The suffix used for plan index file names. 
126 | """ 127 | verification_file_directory = os.path.join(output_dir, "verification") 128 | os.makedirs(verification_file_directory, exist_ok=True) 129 | verification_failures_file = os.path.join(verification_file_directory, 130 | verification_failures_file) 131 | with open(verification_failures_file, "w") as outfile: 132 | json.dump(self.combined_failure_counts, outfile) 133 | 134 | with open(os.path.join(output_dir, plans_output_file), "w") as outfile: 135 | json.dump(self.query_id_to_plan_hints, outfile) 136 | 137 | with open( 138 | os.path.join(output_dir, plans_output_file[:-5] + plan_index_suffix), 139 | "w") as outfile: 140 | json.dump(self.query_id_to_params_plan_indices, outfile) 141 | 142 | with open( 143 | os.path.join(output_dir, plans_output_file[:-5] + "_debug_infos.json"), 144 | "w") as outfile: 145 | json.dump(self.query_id_to_debug_infos, outfile) 146 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/parameter_generator_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Generate parameter bindings for the provided templates. 17 | 18 | The parameter bindings for each template are saved to a different file as 19 | indicated by the output-related flags. 
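A hypothetical invocation (all flag values below are illustrative placeholders,
not real paths) might look like:

  python parameter_generator_main.py \
    --database=stack --user=test --password=test \
    --template_file=templates.json --count=1000 \
    --counts_output_file=outputs/parameter_counts.json \
    --parameters_output_dir=outputs/parameters \
    --dry_run=true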
20 | """ 21 | import collections 22 | from concurrent import futures 23 | import functools 24 | import json 25 | import logging 26 | import os 27 | import time 28 | 29 | from absl import app 30 | from absl import flags 31 | 32 | from kepler.training_data_collection_pipeline import parameter_generator 33 | from kepler.training_data_collection_pipeline import query_utils 34 | 35 | _DATABASE = flags.DEFINE_string("database", None, "Database name.") 36 | flags.mark_flag_as_required("database") 37 | _USER = flags.DEFINE_string("user", None, "Database username.") 38 | _PASSWORD = flags.DEFINE_string("password", None, "Database password.") 39 | _HOST = flags.DEFINE_string("host", "localhost", "Database host.") 40 | _SEED = flags.DEFINE_float("seed", 0, "Database random number seed.") 41 | 42 | _TEMPLATE_FILE = flags.DEFINE_string("template_file", None, 43 | "Parameterized query template file.") 44 | flags.mark_flag_as_required("template_file") 45 | _COUNT = flags.DEFINE_integer( 46 | "count", 1000000, "The max number of parameters to generate per query.") 47 | 48 | _COUNTS_OUTPUT_FILE = flags.DEFINE_string( 49 | "counts_output_file", None, 50 | "Output file to store the parameter counts per query.") 51 | flags.mark_flag_as_required("counts_output_file") 52 | _PARAMETERS_OUTPUT_DIR = flags.DEFINE_string( 53 | "parameters_output_dir", None, 54 | "Directory to store parameter values per query.") 55 | flags.mark_flag_as_required("parameters_output_dir") 56 | 57 | _DRY_RUN = flags.DEFINE_bool( 58 | "dry_run", False, 59 | "If true, verify that the parameter generation process works correctly " 60 | "using a single, non-random parameter value. This involves a) verifying " 61 | "that the parameter generation query can be composed from the template " 62 | "query and b) ensuring that the template query executes successfully with " 63 | "the generated parameter value.") 64 | 65 | 66 | def main(unused_argv): 67 | with open(_TEMPLATE_FILE.value) as f: 68 | templates = json.load(f) 69 | 70 | work_list = [] 71 | # Query templates that failed hint verification using the parameters from the 72 | # original Stack benchmark. That is, at least one provided hint was ignored by 73 | # the PG optimizer for at least one parameter binding. 74 | skip_list = ["q3_0", "q3_1", "q3_2"] 75 | for query_id, template in templates.items(): 76 | if query_id not in skip_list: 77 | work_list.append( 78 | parameter_generator.TemplateItem( 79 | query_id=query_id, template=template)) 80 | 81 | database_configuration = query_utils.DatabaseConfiguration( 82 | dbname=_DATABASE.value, 83 | user=_USER.value, 84 | password=_PASSWORD.value, 85 | host=_HOST.value, 86 | seed=_SEED.value) 87 | generator = parameter_generator.ParameterGenerator(database_configuration) 88 | parameter_generation_function = functools.partial( 89 | generator.generate_parameters, _COUNT.value, dry_run=_DRY_RUN.value) 90 | 91 | output_counts = collections.defaultdict(lambda: {}) 92 | # The high-latency work occurs remotely via the database executing queries to 93 | # generate parameters. The number of max workers is limited empirically to 94 | # avoid memory issues on the database side. 
95 | with futures.ThreadPoolExecutor(max_workers=3) as executor: 96 | for result in executor.map(parameter_generation_function, work_list): 97 | query_id = next(iter(result)) 98 | logging.info("Finished generating for %s", query_id) 99 | with open( 100 | os.path.join(_PARAMETERS_OUTPUT_DIR.value, 101 | f"{query_id}-{len(result[query_id]['params'])}.json"), 102 | "w") as f: 103 | json.dump(result, f) 104 | 105 | output_counts[query_id] = len(result[query_id]["params"]) 106 | 107 | if _DRY_RUN.value: 108 | # Ensure that the template query executes successfully with the 109 | # generated parameter value. 110 | query_manager = query_utils.QueryManager(database_configuration) 111 | start_ms = int(time.time() * 1e3) 112 | query = result[query_id]["query"] 113 | params = result[query_id]["params"][0] 114 | results = query_manager.execute(query, params) 115 | end_ms = int(time.time() * 1e3) 116 | 117 | print(f"Query {query_id} approximate latency: {end_ms-start_ms} ms") 118 | print(f"Template: {query}") 119 | print(f"Params: {params}") 120 | 121 | # The query should return at least one result. 122 | assert results 123 | 124 | with open(_COUNTS_OUTPUT_FILE.value, "w") as f: 125 | json.dump(output_counts, f) 126 | 127 | 128 | if __name__ == "__main__": 129 | app.run(main) 130 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/pg_distributed_training_execution_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Distributed training data query execution. 17 | 18 | This script speeds up training data collection by executing queries in 19 | parallel. It is recommended to execute this script after computing the plan 20 | cover, which currently requires running the first N queries serially. 
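A hypothetical invocation (illustrative flag values only) might look like:

  python pg_distributed_training_execution_main.py \
    --database=stack --user=test --password=test \
    --query_templates_file=templates.json \
    --parameter_values_file=parameter_values.json \
    --plan_hints_file=plan_hints.json \
    --execution_method=regular --iterations=3 \
    --num_processes=8 --output_dir=outputs/execution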
21 | """ 22 | import functools 23 | import json 24 | import multiprocessing 25 | import os 26 | from typing import Any, List, Optional, Tuple 27 | 28 | from absl import app 29 | from absl import flags 30 | 31 | from kepler.training_data_collection_pipeline import main_utils 32 | from kepler.training_data_collection_pipeline import pg_execute_training_data_queries 33 | from kepler.training_data_collection_pipeline import query_text_utils 34 | from kepler.training_data_collection_pipeline import query_utils 35 | 36 | _DATABASE = flags.DEFINE_string("database", None, "Database name.") 37 | flags.mark_flag_as_required("database") 38 | _USER = flags.DEFINE_string("user", None, "Database username.") 39 | _PASSWORD = flags.DEFINE_string("password", None, "Database password.") 40 | 41 | _QUERY_TEMPLATES_FILE = flags.DEFINE_string( 42 | "query_templates_file", 43 | None, 44 | "Path to file in which query templates are stored.", 45 | ) 46 | flags.mark_flag_as_required("query_templates_file") 47 | _PARAMETER_VALUES_FILE = flags.DEFINE_string( 48 | "parameter_values_file", None, "Parameter values file." 49 | ) 50 | flags.mark_flag_as_required("parameter_values_file") 51 | _PLAN_HINTS_FILE = flags.DEFINE_string( 52 | "plan_hints_file", None, "Plan hints file." 53 | ) 54 | flags.mark_flag_as_required("plan_hints_file") 55 | 56 | _EXECUTION_METHOD = flags.DEFINE_string( 57 | "execution_method", 58 | "regular", 59 | ( 60 | "Which execution method to use: regular to simply time the latency," 61 | " explain, or explain_analyze." 62 | ), 63 | ) 64 | _ITERATIONS = flags.DEFINE_integer( 65 | "iterations", 66 | 3, 67 | ( 68 | "The number of iterations to execute query (query plan, parameter" 69 | " binding) pairing." 70 | ), 71 | ) 72 | _BATCH_SIZE = flags.DEFINE_integer( 73 | "batch_size", 74 | 10, 75 | ( 76 | "Batch of parameters for which to collect training data before" 77 | " checkpointing." 78 | ), 79 | ) 80 | 81 | _QUERY_TIMEOUT_MULTIPLIER = flags.DEFINE_integer( 82 | "query_timeout_multiplier", 83 | 5, 84 | ( 85 | "This factor is multiplied by the median execution time of the default" 86 | " query plan to provide an upper bound on the query execution time" 87 | " considered 'way too slow' during execution data collection for each" 88 | " set of parameter values. The product will be clipped to" 89 | " [query_timeout_min_ms, query_timeout_max_ms]." 90 | ), 91 | ) 92 | 93 | _QUERY_TIMEOUT_MIN_MS = flags.DEFINE_integer( 94 | "query_timeout_min_ms", 95 | 200, 96 | ( 97 | "The minimum timeout for each query execution to enable setting a low" 98 | " multiplier while balancing the risk of timeouts caused by system" 99 | " noise for very fast queries." 100 | ), 101 | ) 102 | 103 | _QUERY_TIMEOUT_MAX_MS = flags.DEFINE_integer( 104 | "query_timeout_max_ms", 105 | 60 * 1000, 106 | ( 107 | "The maximum timeout for each query execution to provide a hard-cap on" 108 | " the cost of very slow query plans." 109 | ), 110 | ) 111 | 112 | _VERIFICATION_FILE = flags.DEFINE_string( 113 | "verification_file", 114 | None, 115 | ( 116 | "File containing verification results. If specified, we will only" 117 | " execute hints for which there are no failures." 118 | ), 119 | ) 120 | 121 | _NUM_PROCESSES = flags.DEFINE_integer( 122 | "num_processes", 1, "Number of processes to distribute execution over." 123 | ) 124 | 125 | _QUERY = flags.DEFINE_string("query", None, "Specific query id to execute.") 126 | 127 | _OUTPUT_DIR = flags.DEFINE_string( 128 | "output_dir", None, "Directory to store execution results." 
129 | ) 130 | flags.mark_flag_as_required("output_dir") 131 | 132 | 133 | def execute_query( 134 | query_manager: query_utils.QueryManager, 135 | query: str, 136 | params: List[Any], 137 | timeout_ms: Optional[int] = None, 138 | ) -> Tuple[Optional[float], Optional[int]]: 139 | return query_manager.execute_timed(query, params, timeout_ms) 140 | 141 | 142 | def execute_query_local( 143 | query_manager: query_utils.QueryManager, 144 | query: str, 145 | params: List[Any], 146 | timeout_ms: Optional[int] = None, 147 | ) -> Tuple[Optional[float], Optional[int]]: 148 | return query_manager.execute_timed_local(query, params, timeout_ms) 149 | 150 | 151 | def execute_explain( 152 | query_manager: query_utils.QueryManager, 153 | query: str, 154 | params: List[Any], 155 | _: Optional[int] = None, 156 | ) -> Tuple[Any, None]: 157 | return query_manager.get_query_plan(query, params), None 158 | 159 | 160 | def execute_explain_analyze( 161 | query_manager: query_utils.QueryManager, 162 | query: str, 163 | params: List[Any], 164 | _: Optional[int] = None, 165 | ) -> Tuple[Any, None]: 166 | return query_manager.get_query_plan_and_execute(query, params), None 167 | 168 | 169 | def _execute_func(args, **kwargs): 170 | batch_index, parameter_values = args 171 | return pg_execute_training_data_queries.execute_training_data_queries( 172 | batch_index, parameter_values, **kwargs 173 | ) 174 | 175 | 176 | def main(unused_argv): 177 | db_config = query_utils.DatabaseConfiguration( 178 | dbname=_DATABASE.value, user=_USER.value, password=_PASSWORD.value 179 | ) 180 | query_manager = query_utils.QueryManager(db_config) 181 | query_utils.save_postgres_config_info(query_manager, _OUTPUT_DIR.value) 182 | 183 | with open(_PLAN_HINTS_FILE.value) as f: 184 | plan_hints = json.load(f) 185 | 186 | query_id_to_skip_indices = main_utils.get_skip_indices( 187 | plan_hints, _VERIFICATION_FILE.value 188 | ) 189 | 190 | with open(_PARAMETER_VALUES_FILE.value) as f: 191 | parameter_values = json.load(f) 192 | 193 | with open(_QUERY_TEMPLATES_FILE.value) as f: 194 | templates = json.load(f) 195 | 196 | query_ids = [_QUERY.value] if _QUERY.value else plan_hints.keys() 197 | 198 | output_subdir = os.path.join(_OUTPUT_DIR.value, "execution_output") 199 | os.makedirs(output_subdir, exist_ok=True) 200 | 201 | def get_filename(query_id: str, is_metadata: bool) -> str: 202 | type_token = "_metadata" if is_metadata else "" 203 | return os.path.join( 204 | output_subdir, f"{_DATABASE.value}_{query_id}{type_token}.json" 205 | ) 206 | 207 | def checkpoint_results( 208 | query_id: str, results: Any, is_metadata: bool 209 | ) -> None: 210 | with open(get_filename(query_id, is_metadata), "w") as f: 211 | json.dump(results, f) 212 | 213 | execution_method_map = { 214 | "regular": (execute_query_local, "duration_ms"), 215 | "explain": (execute_explain, "explain_output"), 216 | "explain_analyze": (execute_explain_analyze, "explain_analyze_output"), 217 | } 218 | execution_method, results_key = execution_method_map[_EXECUTION_METHOD.value] 219 | 220 | for query_id in query_ids: 221 | # Load previous results and metadata. 
222 | previous_results = {query_id: {}} 223 | previous_metadata = {query_id: {}} 224 | results_file = get_filename(query_id, False) 225 | metadata_file = get_filename(query_id, True) 226 | exists_results = os.path.exists(results_file) 227 | exists_metadata = os.path.exists(metadata_file) 228 | assert exists_metadata 229 | 230 | if exists_results: 231 | with open(results_file) as f: 232 | previous_results = json.load(f) 233 | with open(metadata_file) as f: 234 | previous_metadata = json.load(f) 235 | 236 | # Compute indices to skip. 237 | skip_indices = set(query_id_to_skip_indices.get(query_id, [])) 238 | plan_cover = previous_metadata[query_id].get("plan_cover", []) 239 | for i in range(len(plan_hints[query_id])): 240 | if i not in plan_cover: 241 | skip_indices.add(i) 242 | 243 | skip_indices = list(skip_indices) 244 | 245 | # Filter previously-executed query parameter values and split into batches. 246 | batch_size = _BATCH_SIZE.value 247 | parameter_value_batches = [] 248 | query_parameter_values = parameter_values[query_id] 249 | 250 | def filter_nonexecuted( 251 | param_set, previous_params=previous_results[query_id] 252 | ): 253 | return ( 254 | query_text_utils.get_params_as_string(param_set["params"]) 255 | not in previous_params 256 | ) 257 | 258 | query_parameter_values = list( 259 | filter(filter_nonexecuted, query_parameter_values) 260 | ) 261 | batch_indexes = list(range(0, len(query_parameter_values), batch_size)) 262 | for i in batch_indexes: 263 | parameter_value_batches.append( 264 | {query_id: query_parameter_values[i : i + batch_size]} 265 | ) 266 | 267 | results = previous_results[query_id] 268 | with multiprocessing.Pool( 269 | processes=_NUM_PROCESSES.value, 270 | initializer=pg_execute_training_data_queries.init_per_process_global_query_manager, 271 | initargs=[db_config], 272 | ) as pool: 273 | kwargs = { 274 | "query_id": query_id, 275 | "templates": templates, 276 | "plan_hints": plan_hints, 277 | "iterations": _ITERATIONS.value, 278 | "batch_size": _BATCH_SIZE.value, 279 | "skip_indices": skip_indices, 280 | "query_timeout_multiplier": _QUERY_TIMEOUT_MULTIPLIER.value, 281 | "query_timeout_min_ms": _QUERY_TIMEOUT_MIN_MS.value, 282 | "query_timeout_max_ms": _QUERY_TIMEOUT_MAX_MS.value, 283 | "execute_query_fn": execution_method, 284 | "checkpoint_results_fn": None, 285 | "results_key": results_key, 286 | "total_num_params": len(query_parameter_values), 287 | "print_skips": False, 288 | } 289 | 290 | for _, (batch_results, _) in enumerate( 291 | pool.imap( 292 | func=functools.partial(_execute_func, **kwargs), 293 | iterable=zip(batch_indexes, parameter_value_batches), 294 | chunksize=1, 295 | ) 296 | ): 297 | results.update(batch_results) 298 | checkpoint_results(query_id, {query_id: results}, False) 299 | 300 | 301 | if __name__ == "__main__": 302 | app.run(main) 303 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/pg_execute_explain_tools_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Executes the collection of EXPLAIN plans for various analyses. 17 | 18 | Results in JSON format are saved to the directory indicated by the output flag. 19 | """ 20 | import json 21 | import os 22 | 23 | from typing import Any 24 | 25 | from absl import app 26 | from absl import flags 27 | 28 | from kepler.training_data_collection_pipeline import pg_execute_explain_tools 29 | from kepler.training_data_collection_pipeline import query_utils 30 | 31 | 32 | # Typing aliases. 33 | JSON = Any 34 | 35 | _DATABASE = flags.DEFINE_string("database", None, "Database name.") 36 | flags.mark_flag_as_required("database") 37 | _USER = flags.DEFINE_string("user", None, "Database username.") 38 | _PASSWORD = flags.DEFINE_string("password", None, "Database password.") 39 | _HOST = flags.DEFINE_string("host", "localhost", "Database host.") 40 | _SEED = flags.DEFINE_float("seed", 0, "Database random number seed.") 41 | 42 | _QUERY_TEMPLATES_FILE = flags.DEFINE_string( 43 | "query_templates_file", None, 44 | "Path to file in which query templates are stored.") 45 | flags.mark_flag_as_required("query_templates_file") 46 | _PARAMETER_VALUES_FILE = flags.DEFINE_string("parameter_values_file", None, 47 | "Parameter values file.") 48 | flags.mark_flag_as_required("parameter_values_file") 49 | _PLAN_HINTS_FILE = flags.DEFINE_string("plan_hints_file", None, 50 | "Plan hints file.") 51 | _EXTRACT_FUNCTION = flags.DEFINE_enum_class( 52 | "extract_function", 53 | pg_execute_explain_tools.ExplainExtractionFunction.TOTAL_COSTS.value, 54 | pg_execute_explain_tools.ExplainExtractionFunction, 55 | "Which EXPLAIN extraction function to use.") 56 | _LIMIT = flags.DEFINE_integer( 57 | "limit", 1, "Limit the number of parameters per query to use for analysis.") 58 | _CHUNKSIZE = flags.DEFINE_integer("chunksize", 100, 59 | "Multiprocessing chunksize.") 60 | 61 | _QUERY = flags.DEFINE_string("query", None, "Specific query id to analyze.") 62 | 63 | _OUTPUT_DIR = flags.DEFINE_string("output_dir", None, 64 | "Directory to store execution results.") 65 | flags.mark_flag_as_required("output_dir") 66 | 67 | 68 | def _save(data: JSON, output_subdir: str, query_id: str, 69 | filename_prefix: str) -> None: 70 | with open( 71 | os.path.join(output_subdir, 72 | f"{filename_prefix}_{_DATABASE.value}_{query_id}.json"), 73 | "w") as f: 74 | json.dump(data, f) 75 | 76 | 77 | def _calculate_plan_changing_cardinality_estimates( 78 | output_subdir: str, 79 | database_configuration: query_utils.DatabaseConfiguration, query_id: str, 80 | templates: JSON, parameter_values: JSON) -> None: 81 | """Calculates plan changing cardinalty estimates and writes out the result.""" 82 | results = pg_execute_explain_tools.calculate_plan_changing_cardinality_estimates( 83 | database_configuration=database_configuration, 84 | query_id=query_id, 85 | templates=templates, 86 | parameter_values=parameter_values, 87 | limit=_LIMIT.value, 88 | multiprocessing_chunksize=_CHUNKSIZE.value) 89 | 90 | _save( 91 | data=results, 92 | output_subdir=output_subdir, 93 | query_id=query_id, 94 | 
filename_prefix="plan_changing_cardinality_estimates") 95 | 96 | 97 | def _collect_explain_plan_info( 98 | output_subdir: str, 99 | database_configuration: query_utils.DatabaseConfiguration, query_id: str, 100 | templates: JSON, parameter_values: JSON, plan_hints: JSON, 101 | extract_function: pg_execute_explain_tools.ExplainExtractionFunction 102 | ) -> None: 103 | """Collects explain plan info and writes out the result.""" 104 | results = pg_execute_explain_tools.collect_explain_plan_info( 105 | database_configuration=database_configuration, 106 | query_id=query_id, 107 | templates=templates, 108 | parameter_values=parameter_values, 109 | plan_hints=plan_hints, 110 | extract_function=extract_function, 111 | limit=_LIMIT.value, 112 | multiprocessing_chunksize=_CHUNKSIZE.value) 113 | 114 | _save( 115 | data=results, 116 | output_subdir=output_subdir, 117 | query_id=query_id, 118 | filename_prefix="explain_plan_total_costs") 119 | 120 | 121 | def main(unused_argv): 122 | database_configuration = query_utils.DatabaseConfiguration( 123 | dbname=_DATABASE.value, 124 | user=_USER.value, 125 | password=_PASSWORD.value, 126 | host=_HOST.value, 127 | seed=_SEED.value) 128 | query_manager = query_utils.QueryManager(database_configuration) 129 | query_utils.save_postgres_config_info(query_manager, _OUTPUT_DIR.value) 130 | 131 | output_subdir = os.path.join(_OUTPUT_DIR.value, "explain_output") 132 | os.makedirs(output_subdir, exist_ok=True) 133 | 134 | with open(_PARAMETER_VALUES_FILE.value) as f: 135 | parameter_values = json.load(f) 136 | 137 | with open(_QUERY_TEMPLATES_FILE.value) as f: 138 | templates = json.load(f) 139 | 140 | plan_hints = None 141 | if _PLAN_HINTS_FILE.value: 142 | with open(_PLAN_HINTS_FILE.value) as f: 143 | plan_hints = json.load(f) 144 | 145 | query_ids = [_QUERY.value] if _QUERY.value else parameter_values.keys() 146 | 147 | for query_id in query_ids: 148 | if plan_hints: 149 | _collect_explain_plan_info( 150 | output_subdir=output_subdir, 151 | database_configuration=database_configuration, 152 | query_id=query_id, 153 | templates=templates, 154 | parameter_values=parameter_values, 155 | plan_hints=plan_hints, 156 | extract_function=_EXTRACT_FUNCTION.value) 157 | else: 158 | _calculate_plan_changing_cardinality_estimates( 159 | output_subdir=output_subdir, 160 | database_configuration=database_configuration, 161 | query_id=query_id, 162 | templates=templates, 163 | parameter_values=parameter_values) 164 | 165 | 166 | if __name__ == "__main__": 167 | app.run(main) 168 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/pg_generate_plan_candidates_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Script for generating a set of query plans to use for a query template. 
17 | 18 | This step of the pipeline takes parameter values and query templates, and 19 | produces a set of query plans as execution options for the query template. The 20 | set of query plans are represented as pg_hint_plan hints. 21 | 22 | The script supports several methods for generating the plan candidates, listed 23 | in the _GENERATION_FUNCTION_MAP. 24 | """ 25 | import enum 26 | import json 27 | import os 28 | 29 | from absl import app 30 | from absl import flags 31 | from absl import logging 32 | 33 | from kepler.training_data_collection_pipeline import main_utils 34 | from kepler.training_data_collection_pipeline import pg_generate_plan_candidates 35 | from kepler.training_data_collection_pipeline import pg_plan_hint_extractor 36 | from kepler.training_data_collection_pipeline import query_utils 37 | 38 | 39 | class GenerationFunction(enum.Enum): 40 | PG_CONFIGS = "pg_configs" 41 | ROW_NUM_EVOLUTION = "row_num_evolution" 42 | EXHAUSTIVE_CARDINALITY_PERTURBATIONS = "exhaustive_cardinality_perturbations" 43 | 44 | 45 | _GENERATION_FUNCTION_MAP = { 46 | GenerationFunction.PG_CONFIGS: 47 | pg_generate_plan_candidates.get_query_plans, 48 | GenerationFunction.ROW_NUM_EVOLUTION: 49 | pg_generate_plan_candidates.generate_by_row_num_evolution, 50 | GenerationFunction.EXHAUSTIVE_CARDINALITY_PERTURBATIONS: 51 | pg_generate_plan_candidates 52 | .generate_by_exhaustive_cardinality_perturbations 53 | } 54 | 55 | 56 | def _supports_distributed_execution( 57 | generation_function: GenerationFunction) -> bool: 58 | return generation_function != GenerationFunction.EXHAUSTIVE_CARDINALITY_PERTURBATIONS 59 | 60 | _DATABASE = flags.DEFINE_string("database", None, "Database name.") 61 | flags.mark_flag_as_required("database") 62 | _USER = flags.DEFINE_string("user", None, "Database username.") 63 | _PASSWORD = flags.DEFINE_string("password", None, "Database password.") 64 | _HOST = flags.DEFINE_string("host", "localhost", "Database host.") 65 | 66 | _QUERY_PARAMS_FILE = flags.DEFINE_string( 67 | "query_params_file", None, 68 | "File containing parameterized queries with list of parameter values.") 69 | flags.mark_flag_as_required("query_params_file") 70 | _PARAMS_LIMIT = flags.DEFINE_integer( 71 | "params_limit", None, 72 | "The number of parameter values to use when generating plans.") 73 | _OUTPUT_DIR = flags.DEFINE_string( 74 | "output_dir", None, 75 | "Directory in which to store query plan hints and configs.") 76 | flags.mark_flag_as_required("output_dir") 77 | 78 | _PLANS_OUTPUT_FILE = flags.DEFINE_string( 79 | "plans_output_file", None, 80 | "File to store distinct plans per query. The file name is expected to end in .json" 81 | ) 82 | flags.mark_flag_as_required("plans_output_file") 83 | _PLAN_INDEX_SUFFIX = flags.DEFINE_string( 84 | "plan_index_suffix", "_plan_index.json", 85 | "Suffix of files to store plan indices in.") 86 | _VERIFICATION_FAILURES_FILE = flags.DEFINE_string( 87 | "verification_failures_file", "verification_failures.json", 88 | "Filename of file to save verification failures.") 89 | _CHUNKSIZE = flags.DEFINE_integer( 90 | "chunksize", 100, "How many params to include in each subprocess chunk.") 91 | _KEYS_TO_REMOVE = flags.DEFINE_list( 92 | "keys_to_remove", [], 93 | ("List of keys to filter from EXPLAIN plan JSON. 
Good candidates include " 94 | "\"Parallel Aware\", \"Relation Name\", \"Parent Relationship\"")) 95 | 96 | _GENERATION_FUNCTION = flags.DEFINE_enum_class( 97 | "generation_function", GenerationFunction.PG_CONFIGS.value, 98 | GenerationFunction, "Which plan generation function to use.") 99 | _SOFT_TOTAL_PLANS_LIMIT = flags.DEFINE_integer( 100 | "soft_total_plans_limit", None, 101 | "Soft limit on total number of plans to produce." 102 | ) 103 | # Pg configs flags. 104 | _CONFIG_STR = flags.DEFINE_string( 105 | "configs", "", 106 | "Comma-separated string of Postgres optimizer configuration parameters to toggle off." 107 | ) 108 | # Row number evolution flags. 109 | _MAX_PLANS_PER_PARAM = flags.DEFINE_integer( 110 | "max_plans_per_param", None, 111 | "Stop evolution after this number of plans is exceeded.") 112 | _NUM_GENERATIONS = flags.DEFINE_integer( 113 | "num_generations", 3, "Number of generations of row number evolution.") 114 | _NUM_MUTATIONS_PER_PLAN = flags.DEFINE_integer( 115 | "num_mutations_per_plan", 25, "Number of random mutations for each plan.") 116 | _EXPONENT_BASE = flags.DEFINE_integer( 117 | "exponent_base", 10, "Base of exponential row number perturbations.") 118 | _EXPONENT_RANGE = flags.DEFINE_integer( 119 | "exponent_range", 3, "One-sided range of exponent of perturbations.") 120 | _MAX_PLANS_PER_GENERATION = flags.DEFINE_integer( 121 | "max_plans_per_generation", 20, 122 | "Max number of plans to mutate per generation.") 123 | _PERTURB_UNIT_ONLY = flags.DEFINE_bool( 124 | "perturb_unit_only", True, 125 | "Whether to perturb only row counts exactly equal to one." 126 | ) 127 | _MAX_PERTURBS_PER_JOIN = flags.DEFINE_integer( 128 | "max_perturbs_per_join", 1, 129 | "Limit on how many times a specific join can be perturbed." 130 | ) 131 | # Exhaustive cardinality perturbation flags. 132 | _CARDINALITY_MULTIPLIERS = flags.DEFINE_list( 133 | "cardinality_multipliers", None, 134 | "List of cardinality multipliers to apply when generating plans.") 135 | 136 | 137 | def main(unused_argv): 138 | configs = _CONFIG_STR.value.split(",") if _CONFIG_STR.value else [] 139 | 140 | with open(_QUERY_PARAMS_FILE.value) as json_file: 141 | info = json.load(json_file) 142 | 143 | hints_output_dir = os.path.join(_OUTPUT_DIR.value, _DATABASE.value) 144 | os.makedirs(hints_output_dir, exist_ok=True) 145 | 146 | database_configuration = query_utils.DatabaseConfiguration( 147 | dbname=_DATABASE.value, 148 | user=_USER.value, 149 | password=_PASSWORD.value, 150 | host=_HOST.value) 151 | query_manager = query_utils.QueryManager(database_configuration) 152 | query_utils.save_postgres_config_info(query_manager, _OUTPUT_DIR.value) 153 | 154 | hint_accumulator = main_utils.HintAccumulator() 155 | for query_id, query_metadata in info.items(): 156 | logging.info("Start: %s", query_id) 157 | 158 | output = {} 159 | output["output"] = {} 160 | 161 | function_kwargs = { 162 | "database_configuration": database_configuration, 163 | "query": query_metadata["query"], 164 | "keys_to_remove": _KEYS_TO_REMOVE.value 165 | } 166 | 167 | # Augment kwargs depending on generation function. 
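    # For instance, with --generation_function=row_num_evolution and the flag
    # defaults above, the fully augmented kwargs would look roughly like this
    # (illustrative sketch; real values come from the command-line flags):
    #   {"database_configuration": database_configuration,
    #    "query": query_metadata["query"], "keys_to_remove": [],
    #    "max_plans": None, "num_generations": 3, "num_mutations_per_plan": 25,
    #    "exponent_base": 10, "exponent_range": 3,
    #    "max_plans_per_generation": 20, "perturb_unit_only": True,
    #    "max_perturbs_per_join": 1}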
168 | if _GENERATION_FUNCTION.value == GenerationFunction.PG_CONFIGS: 169 | function_kwargs["configs"] = configs 170 | elif _GENERATION_FUNCTION.value == GenerationFunction.ROW_NUM_EVOLUTION: 171 | function_kwargs.update({ 172 | "max_plans": _MAX_PLANS_PER_PARAM.value, 173 | "num_generations": _NUM_GENERATIONS.value, 174 | "num_mutations_per_plan": _NUM_MUTATIONS_PER_PLAN.value, 175 | "exponent_base": _EXPONENT_BASE.value, 176 | "exponent_range": _EXPONENT_RANGE.value, 177 | "max_plans_per_generation": _MAX_PLANS_PER_GENERATION.value, 178 | "perturb_unit_only": _PERTURB_UNIT_ONLY.value, 179 | "max_perturbs_per_join": _MAX_PERTURBS_PER_JOIN.value 180 | }) 181 | elif _GENERATION_FUNCTION.value == GenerationFunction.EXHAUSTIVE_CARDINALITY_PERTURBATIONS: 182 | cardinality_multipliers = [ 183 | float(multiplier) for multiplier in _CARDINALITY_MULTIPLIERS.value 184 | ] 185 | 186 | function_kwargs.update( 187 | {"cardinality_multipliers": cardinality_multipliers}) 188 | 189 | if _PARAMS_LIMIT.value: 190 | query_metadata["params"] = query_metadata["params"][:_PARAMS_LIMIT.value] 191 | 192 | plan_hint_extractor = pg_plan_hint_extractor.PlanHintExtractor() 193 | pg_generate_plan_candidates.execute_plan_generation( 194 | _GENERATION_FUNCTION_MAP[_GENERATION_FUNCTION.value], 195 | function_kwargs, 196 | query_metadata["params"], 197 | plan_hint_extractor=plan_hint_extractor, 198 | chunksize=_CHUNKSIZE.value, 199 | distributed=_supports_distributed_execution(_GENERATION_FUNCTION.value), 200 | soft_total_plans_limit=_SOFT_TOTAL_PLANS_LIMIT.value) 201 | counts, plan_hints, params_plan_indices, debug_infos = ( 202 | plan_hint_extractor.get_consolidated_plan_hints()) 203 | 204 | hint_accumulator.query_id_to_counts[query_id] = counts 205 | hint_accumulator.query_id_to_plan_hints[query_id] = plan_hints 206 | hint_accumulator.query_id_to_params_plan_indices[ 207 | query_id] = params_plan_indices 208 | hint_accumulator.query_id_to_debug_infos[query_id] = debug_infos 209 | 210 | failure_counts = pg_plan_hint_extractor.verify_hints( 211 | query_id=query_id, 212 | query=query_metadata["query"], 213 | plan_hints=plan_hints, 214 | params_plan_indices=params_plan_indices, 215 | database_configuration=database_configuration) 216 | hint_accumulator.combined_failure_counts.update(failure_counts) 217 | 218 | main_utils.print_failure_counts(hint_accumulator.combined_failure_counts) 219 | main_utils.print_hint_counts_by_source(hint_accumulator.query_id_to_counts) 220 | 221 | hint_accumulator.save( 222 | output_dir=hints_output_dir, 223 | plans_output_file=_PLANS_OUTPUT_FILE.value, 224 | verification_failures_file=_VERIFICATION_FAILURES_FILE.value, 225 | plan_index_suffix=_PLAN_INDEX_SUFFIX.value) 226 | 227 | if __name__ == "__main__": 228 | app.run(main) 229 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/pg_generate_plan_costs_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Orchestrates training data collection for Kepler. 17 | 18 | The tool perturbs join cardinality estimates and collects the corresponding 19 | EXPLAIN plans which include the total cost. This data is intended for 20 | cardinality-estimate representation learning. 21 | 22 | Results in JSON format are saved to the directory indicated by the output flag. 23 | """ 24 | import json 25 | import os 26 | 27 | from absl import app 28 | from absl import flags 29 | 30 | from kepler.training_data_collection_pipeline import pg_perturb_plan_cardinalities 31 | from kepler.training_data_collection_pipeline import query_utils 32 | 33 | _DATABASE = flags.DEFINE_string("database", None, "Database name.") 34 | flags.mark_flag_as_required("database") 35 | _USER = flags.DEFINE_string("user", None, "Database username.") 36 | _PASSWORD = flags.DEFINE_string("password", None, "Database password.") 37 | _HOST = flags.DEFINE_string("host", "localhost", "Database host.") 38 | _SEED = flags.DEFINE_float("seed", 0, "Database random number seed.") 39 | 40 | _QUERY_TEMPLATES_FILE = flags.DEFINE_string( 41 | "query_templates_file", None, 42 | "Path to file in which query templates are stored.") 43 | flags.mark_flag_as_required("query_templates_file") 44 | _PARAMETER_VALUES_FILE = flags.DEFINE_string("parameter_values_file", None, 45 | "Parameter values file.") 46 | flags.mark_flag_as_required("parameter_values_file") 47 | _PLAN_HINTS_FILE = flags.DEFINE_string("plan_hints_file", None, 48 | "Plan hints file.") 49 | flags.mark_flag_as_required("plan_hints_file") 50 | 51 | _CARDINALITY_MULTIPLIERS = flags.DEFINE_list( 52 | "cardinality_multipliers", None, 53 | "List of cardinality multipliers to apply when generating plans.") 54 | flags.mark_flag_as_required("cardinality_multipliers") 55 | 56 | _LIMIT = flags.DEFINE_integer( 57 | "limit", 1, 58 | "Limit the number of parameters per query to gather cost estimates for.") 59 | 60 | _KEYS_TO_REMOVE = flags.DEFINE_list( 61 | "keys_to_remove", [], 62 | ("List of keys to filter from EXPLAIN plan JSON. 
Good candidates include " 63 | "\"Parallel Aware\", \"Relation Name\", \"Parent Relationship\"")) 64 | 65 | _QUERY = flags.DEFINE_string("query", None, "Specific query id to execute.") 66 | 67 | _OUTPUT_DIR = flags.DEFINE_string("output_dir", None, 68 | "Directory to store execution results.") 69 | flags.mark_flag_as_required("output_dir") 70 | 71 | 72 | def main(unused_argv): 73 | query_manager = query_utils.QueryManager( 74 | query_utils.DatabaseConfiguration( 75 | dbname=_DATABASE.value, user=_USER.value, password=_PASSWORD.value)) 76 | query_utils.save_postgres_config_info(query_manager, _OUTPUT_DIR.value) 77 | 78 | output_subdir = os.path.join(_OUTPUT_DIR.value, "execution_output") 79 | os.makedirs(output_subdir, exist_ok=True) 80 | 81 | with open(_PLAN_HINTS_FILE.value) as f: 82 | plan_hints = json.load(f) 83 | 84 | with open(_PARAMETER_VALUES_FILE.value) as f: 85 | parameter_values = json.load(f) 86 | 87 | with open(_QUERY_TEMPLATES_FILE.value) as f: 88 | templates = json.load(f) 89 | 90 | cardinality_multipliers = [ 91 | float(multiplier) for multiplier in _CARDINALITY_MULTIPLIERS.value 92 | ] 93 | 94 | query_ids = [_QUERY.value] if _QUERY.value else plan_hints.keys() 95 | 96 | for query_id in query_ids: 97 | results = pg_perturb_plan_cardinalities.multiplicatively_perturb_plan_cardinalities( 98 | query_manager=query_manager, 99 | query_id=query_id, 100 | templates=templates, 101 | parameter_values=parameter_values, 102 | plan_hints=plan_hints, 103 | cardinality_multipliers=cardinality_multipliers, 104 | limit=_LIMIT.value, 105 | keys_to_remove=_KEYS_TO_REMOVE.value) 106 | 107 | with open( 108 | os.path.join(output_subdir, f"{_DATABASE.value}_{query_id}.json"), 109 | "w") as f: 110 | json.dump(results, f) 111 | 112 | 113 | if __name__ == "__main__": 114 | app.run(main) 115 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/query_plan_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Query plan navigation and manipulation utilties.""" 17 | 18 | from typing import Any, Optional, Sequence 19 | 20 | JSON = Any 21 | 22 | 23 | def filter_keys(plan: JSON, 24 | keys_to_remove: Optional[Sequence[str]] = None) -> JSON: 25 | """Removes extraneous keys from plan JSON to reduce file size. 26 | 27 | Args: 28 | plan: Explain plan in JSON format. 29 | keys_to_remove: Keys to remove from the plan. 30 | 31 | Returns: 32 | A new JSON with the specified keys removed. 
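
  Example (illustrative values):
    filter_keys({"Node Type": "Hash Join", "Total Cost": 4.1}, ["Total Cost"])
    returns {"Node Type": "Hash Join"}.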
33 | """ 34 | if not keys_to_remove: 35 | return plan 36 | 37 | if isinstance(plan, list): 38 | return [filter_keys(p, keys_to_remove) for p in plan] 39 | elif isinstance(plan, dict): 40 | new_plan = {} 41 | for k, v in plan.items(): 42 | if k in keys_to_remove: 43 | continue 44 | if isinstance(v, (dict, list)): 45 | new_plan[k] = filter_keys(v, keys_to_remove) 46 | else: 47 | new_plan[k] = v 48 | return new_plan 49 | else: 50 | return plan 51 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/query_plan_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for query_plan_utils.""" 17 | 18 | import json 19 | import os 20 | 21 | # Elided resources import 22 | from kepler.training_data_collection_pipeline import query_plan_utils 23 | from absl.testing import absltest 24 | 25 | _TEST_DATA_DIR = "kepler/training_data_collection_pipeline/testdata" 26 | 27 | 28 | class QueryPlanUtilsTest(absltest.TestCase): 29 | 30 | def test_filter_plan_keys(self): 31 | remove_keys = [ 32 | "Parallel Aware", "Relation Name", "Startup Cost", 33 | "Parent Relationship", "Inner Unique", "Join Filter", "Filter", 34 | "Plan Rows", "Plan Width", "Total Cost" 35 | ] 36 | 37 | plan_path = os.path.join(_TEST_DATA_DIR, 38 | "generate_candidates_explain_plans.json") 39 | test_query_explain_plans = json.loads(open(plan_path, mode='rb').read()) 40 | 41 | filtered_plan_path = os.path.join( 42 | _TEST_DATA_DIR, "filtered_generate_candidates_explain_plans.json") 43 | expected_filtered_query_explain_plans = json.loads( 44 | open(filtered_plan_path, mode='rb').read()) 45 | 46 | filtered_plans = query_plan_utils.filter_keys(test_query_explain_plans, 47 | remove_keys) 48 | self.assertEqual(filtered_plans, expected_filtered_query_explain_plans) 49 | 50 | 51 | if __name__ == "__main__": 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/query_printer_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Outputs instances of query templates populated with bind parameters.""" 17 | import json 18 | 19 | from absl import app 20 | from absl import flags 21 | from tensorflow.io import gfile 22 | 23 | from kepler.training_data_collection_pipeline import query_text_utils 24 | 25 | _PARAMETER_VALUES_FILE = flags.DEFINE_string("parameter_values_file", None, 26 | "Parameter values file.") 27 | flags.mark_flag_as_required("parameter_values_file") 28 | 29 | _LIMIT = flags.DEFINE_integer("limit", 100, 30 | "Limit the number of parameters to instantiate.") 31 | 32 | _OUTPUT_FILE = flags.DEFINE_string( 33 | "output_file", None, 34 | "File to append the instantiated queries. If no file is provided, the " 35 | "instantiated queries will be printed to the console instead.") 36 | 37 | 38 | def main(unused_argv): 39 | with gfile.GFile(_PARAMETER_VALUES_FILE.value) as f: 40 | parameter_values = json.load(f) 41 | 42 | if len(parameter_values) != 1: 43 | raise ValueError( 44 | "The parameter_values_file should contain contents for a single query" 45 | f"id, but found the follow query ids instead: {parameter_values.keys()}" 46 | ) 47 | 48 | query_id = next(iter(parameter_values)) 49 | query_info = parameter_values[query_id] 50 | 51 | query_instances = [ 52 | query_text_utils.substitute_query_params(query_info["query"], params) 53 | for params in query_info["params"][:_LIMIT.value] 54 | ] 55 | query_instances = [ 56 | query_instance + 57 | ";" if query_instance.strip()[-1] != ";" else query_instance 58 | for query_instance in query_instances 59 | ] 60 | 61 | description = f"-- Queries for {query_id}" 62 | if _OUTPUT_FILE.value: 63 | with gfile.GFile(_OUTPUT_FILE.value, "a") as f: 64 | f.write(description + "\n") 65 | for query_instance in query_instances: 66 | f.write(f"{query_instance}\n\n") 67 | 68 | else: 69 | for query_instance in query_instances: 70 | print(f"{query_instance}\n") 71 | 72 | 73 | if __name__ == "__main__": 74 | app.run(main) 75 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/query_text_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Provides simple generic utils.""" 17 | 18 | from typing import Any, List, Sequence 19 | 20 | 21 | # TODO(b/199162711): Transition this script and downstream analysis scripts to a 22 | # structured format instead of using _NAME_DELIMITER. 23 | _NAME_DELIMITER = '####' 24 | 25 | 26 | def get_params_as_string(params: List[Any]) -> str: 27 | return _NAME_DELIMITER.join([str(p) for p in params]) 28 | 29 | 30 | def substitute_query_params(query: str, params: Sequence[Any]) -> str: 31 | # Replace in reverse order so that eg param1 does not interfere with 32 | # param12 via a substring match. 
33 | if params: 34 | for i in range(len(params) - 1, -1, -1): 35 | query = query.replace(f'@param{i}', str(params[i])) 36 | 37 | return query 38 | 39 | 40 | def get_hinted_query(query: str, hints: str) -> str: 41 | return f'{hints} {query}' 42 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/query_text_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for utils.""" 17 | from typing import Any, List 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | 21 | from kepler.training_data_collection_pipeline import query_text_utils 22 | 23 | 24 | class QueryManagerPostgresTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters( 27 | dict( 28 | testcase_name="trivial no-op", 29 | query_template="SELECT foo FROM bar where x > 0", 30 | params=[], 31 | expected="SELECT foo FROM bar where x > 0"), 32 | dict( 33 | testcase_name="extra params no-op", 34 | query_template="SELECT foo FROM bar where x > 0", 35 | params=[1], 36 | expected="SELECT foo FROM bar where x > 0"), 37 | dict( 38 | testcase_name="insufficient params no-op", 39 | query_template="SELECT foo FROM bar where x > @param0", 40 | params=[], 41 | expected="SELECT foo FROM bar where x > @param0"), 42 | dict( 43 | testcase_name="simple substitution", 44 | query_template="SELECT foo FROM bar where x > @param0", 45 | params=[1], 46 | expected="SELECT foo FROM bar where x > 1"), 47 | # The query template with @param1 and @param10 demonstrates that the 48 | # direction of the naive param# replacement matters and needs to be done 49 | # in reverse to avoid false matches. 50 | dict( 51 | testcase_name="complex substitution", 52 | query_template=( 53 | "SELECT foo FROM bar where x > @param0 and x > @param1 and " 54 | "x > @param2 and x > @param3 and x > @param4 and x > @param5 and " 55 | "x > @param6 and x > @param7 and x > @param8 and x > @param9 and " 56 | "x > @param10"), 57 | params=list(range(1, 12)), 58 | expected=( 59 | "SELECT foo FROM bar where x > 1 and x > 2 and x > 3 and x > 4 " 60 | "and x > 5 and x > 6 and x > 7 and x > 8 and x > 9 and x > 10 " 61 | "and x > 11"))) 62 | def test_substitute_query_params(self, query_template: str, params: List[Any], 63 | expected: str): 64 | """Verifies query template parameter substitution. 65 | 66 | Args: 67 | query_template: The query template that may contain 0 to many 68 | parameterized predicates. 69 | params: The parameter values to substitute into the query. 70 | expected: The resulting query instance with the params substituted in when 71 | possible. 
72 | """ 73 | self.assertEqual( 74 | query_text_utils.substitute_query_params(query_template, params), 75 | expected) 76 | 77 | def test_get_hinted_query(self): 78 | hints = "/*+ SeqScan(foo) */" 79 | query = "SELECT x from foo" 80 | self.assertEqual( 81 | query_text_utils.get_hinted_query(query=query, hints=hints), 82 | "/*+ SeqScan(foo) */ SELECT x from foo") 83 | 84 | 85 | if __name__ == "__main__": 86 | absltest.main() 87 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/test_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utility functions to simplify unit tests. 17 | 18 | For simplicity, these tests assume a database called 'test' already exists. 19 | 20 | """ 21 | 22 | import time 23 | import psycopg2 24 | 25 | from kepler.training_data_collection_pipeline import query_utils 26 | 27 | TEST_QUERY_ID = 'q0_0' 28 | TEST_TEMPLATE_STRING = """ 29 | { 30 | "query": "SELECT x, y, a, b, d_date \\n FROM \\n foo JOIN bar on x = b WHERE \\n a < @param0 \\n and bar.c = '@param1';", 31 | "predicates": [ 32 | { 33 | "alias": "bar", 34 | "column": "c", 35 | "operator": "=" 36 | } 37 | ] 38 | } 39 | """ 40 | 41 | HINTS_0 = """{"hints": "/*+ MergeJoin(foo bar) Leading((foo bar)) */", "source": "default"}""" 42 | HINTS_1 = """{"hints": "/*+ HashJoin(foo bar) Leading((foo bar)) */", "source": "default"}""" 43 | 44 | _DBNAME = 'test' 45 | USER = 'test' 46 | PASSWORD = 'test' 47 | 48 | _SCHEMA = """ 49 | CREATE TABLE foo(x int, y int); 50 | CREATE TABLE bar(a int, b int, c varchar, d_date date, e_date timestamp, website_url varchar); 51 | CREATE TABLE baz(j int, k int, l varchar); 52 | """ 53 | 54 | _DATA_FOO = [ 55 | [1, -1], 56 | [1, 2], 57 | [1, 3], 58 | [2, 2], 59 | [2, 4], 60 | ] 61 | 62 | _DATA_BAR = [ 63 | [ 64 | 1, 1, 'alfa', '2021-09-19', '2022-02-07 14:28:59.473387-08', 65 | 'https://hello.com' 66 | ], 67 | [ 68 | 1, 2, 'bravo', '2021-09-21', '2022-02-08 14:28:59.473387-08', 69 | 'http://goodbye.org/methods' 70 | ], 71 | [ 72 | 1, 3, 'bravo', '2021-09-22', '2022-02-07 15:28:59.473387-08', 73 | 'http://www.goodnight.org' 74 | ], 75 | [ 76 | 3, 2, 'charlie', '2021-09-20', '2022-02-07 13:28:59.473387-08', 77 | 'http://www.goodmorning.com' 78 | ], 79 | [3, 2, None, None, None, None], 80 | ] 81 | 82 | _DATA_BAZ = [ 83 | [1, 3, 'single-string'], 84 | [1, 4, 'split string'], 85 | [2, 4, 'single-string'], 86 | [3, 5, 'split string'], 87 | ] 88 | 89 | _TABLE_LIST = ['foo', 'bar', 'baz'] 90 | 91 | 92 | def populate_database(query_manager: query_utils.QueryManager): 93 | """Create schema and insert data for test cases.""" 94 | query_manager.execute_and_commit(_SCHEMA) 95 | 96 | columns_list = ['x, y', 'a, b, c, d_date, e_date, website_url', 'j, k, l'] 97 | data_list = [_DATA_FOO, _DATA_BAR, _DATA_BAZ] 98 | 99 | for table, columns, data in 
zip(_TABLE_LIST, columns_list, data_list): 100 | for row in data: 101 | row_values_as_strings = [ 102 | f"'{str(value)}'" if value else 'NULL' for value in row 103 | ] 104 | query_manager.execute_and_commit( 105 | f"INSERT INTO {table} ({columns}) VALUES ({','.join(row_values_as_strings)})" 106 | ) 107 | 108 | 109 | class TestDatabase: 110 | """Manages the creating and dropping of a test-case-specific database.""" 111 | 112 | def __init__(self): 113 | """Creates test database. 114 | 115 | Connects to the preexisting test database to create a new database with a 116 | unique name. 117 | 118 | Attributes: 119 | dbname: The name of the newly created database. This name will be unique 120 | each time a TestDatabase is created. 121 | """ 122 | self.dbname = f'test_{time.time_ns()}' 123 | 124 | conn = psycopg2.connect( 125 | dbname=_DBNAME, user=USER, password=PASSWORD, host='localhost') 126 | # Enabling AUTOCOMMIT is required to execute CREATE DATABASE. 127 | conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) 128 | self._cursor = conn.cursor() 129 | self._cursor.execute(f'CREATE DATABASE {self.dbname};') 130 | 131 | # Complete setup of the newly created database per README instructions. 132 | query_manager = query_utils.QueryManager( 133 | query_utils.DatabaseConfiguration( 134 | dbname=self.dbname, user=USER, password=PASSWORD)) 135 | query_manager.execute_and_commit('CREATE EXTENSION pg_stat_statements;') 136 | 137 | def drop(self): 138 | self._cursor.execute('SELECT pg_terminate_backend(pg_stat_activity.pid) ' 139 | 'FROM pg_stat_activity ' 140 | f"WHERE pg_stat_activity.datname = '{self.dbname}';") 141 | self._cursor.execute(f'DROP DATABASE {self.dbname};') 142 | -------------------------------------------------------------------------------- /kepler/training_data_collection_pipeline/testdata/filtered_generate_candidates_explain_plans.json: -------------------------------------------------------------------------------- 1 | { 2 | "output": { 3 | "q0_0": [ 4 | { 5 | "params": [ 6 | "https://hello.com" 7 | ], 8 | "result": { 9 | "Plan": { 10 | "Node Type": "Hash Join", 11 | "Join Type": "Inner", 12 | "Hash Cond": "(foo.x = bar.b)", 13 | "Plans": [ 14 | { 15 | "Node Type": "Seq Scan", 16 | "Alias": "foo" 17 | }, 18 | { 19 | "Node Type": "Hash", 20 | "Plans": [ 21 | { 22 | "Node Type": "Seq Scan", 23 | "Alias": "bar" 24 | } 25 | ] 26 | } 27 | ] 28 | } 29 | }, 30 | "additional_plans": [ 31 | { 32 | "Plan": { 33 | "Node Type": "Nested Loop", 34 | "Join Type": "Inner", 35 | "Plans": [ 36 | { 37 | "Node Type": "Seq Scan", 38 | "Alias": "bar" 39 | }, 40 | { 41 | "Node Type": "Seq Scan", 42 | "Alias": "foo" 43 | } 44 | ] 45 | } 46 | }, 47 | { 48 | "Plan": { 49 | "Node Type": "Hash Join", 50 | "Join Type": "Inner", 51 | "Hash Cond": "(foo.x = bar.b)", 52 | "Plans": [ 53 | { 54 | "Node Type": "Seq Scan", 55 | "Alias": "foo" 56 | }, 57 | { 58 | "Node Type": "Hash", 59 | "Plans": [ 60 | { 61 | "Node Type": "Seq Scan", 62 | "Alias": "bar" 63 | } 64 | ] 65 | } 66 | ] 67 | } 68 | }, 69 | { 70 | "Plan": { 71 | "Node Type": "Nested Loop", 72 | "Join Type": "Inner", 73 | "Plans": [ 74 | { 75 | "Node Type": "Seq Scan", 76 | "Alias": "bar" 77 | }, 78 | { 79 | "Node Type": "Seq Scan", 80 | "Alias": "foo" 81 | } 82 | ] 83 | } 84 | } 85 | ], 86 | "sources": [ 87 | [ 88 | "enable_hashjoin" 89 | ], 90 | [ 91 | "enable_mergejoin" 92 | ], 93 | [ 94 | "enable_hashjoin", 95 | "enable_mergejoin" 96 | ] 97 | ] 98 | }, 99 | { 100 | "params": [ 101 | "http://goodbye.org/methods" 102 | ], 103 | "result": 
{ 104 | "Plan": { 105 | "Node Type": "Hash Join", 106 | "Join Type": "Inner", 107 | "Hash Cond": "(foo.x = bar.b)", 108 | "Plans": [ 109 | { 110 | "Node Type": "Seq Scan", 111 | "Alias": "foo" 112 | }, 113 | { 114 | "Node Type": "Hash", 115 | "Plans": [ 116 | { 117 | "Node Type": "Seq Scan", 118 | "Alias": "bar" 119 | } 120 | ] 121 | } 122 | ] 123 | } 124 | }, 125 | "additional_plans": [ 126 | { 127 | "Plan": { 128 | "Node Type": "Nested Loop", 129 | "Join Type": "Inner", 130 | "Plans": [ 131 | { 132 | "Node Type": "Seq Scan", 133 | "Alias": "bar" 134 | }, 135 | { 136 | "Node Type": "Seq Scan", 137 | "Alias": "foo" 138 | } 139 | ] 140 | } 141 | }, 142 | { 143 | "Plan": { 144 | "Node Type": "Hash Join", 145 | "Join Type": "Inner", 146 | "Hash Cond": "(foo.x = bar.b)", 147 | "Plans": [ 148 | { 149 | "Node Type": "Seq Scan", 150 | "Alias": "foo" 151 | }, 152 | { 153 | "Node Type": "Hash", 154 | "Plans": [ 155 | { 156 | "Node Type": "Seq Scan", 157 | "Alias": "bar" 158 | } 159 | ] 160 | } 161 | ] 162 | } 163 | }, 164 | { 165 | "Plan": { 166 | "Node Type": "Nested Loop", 167 | "Join Type": "Inner", 168 | "Plans": [ 169 | { 170 | "Node Type": "Seq Scan", 171 | "Alias": "bar" 172 | }, 173 | { 174 | "Node Type": "Seq Scan", 175 | "Alias": "foo" 176 | } 177 | ] 178 | } 179 | } 180 | ], 181 | "sources": [ 182 | [ 183 | "enable_hashjoin" 184 | ], 185 | [ 186 | "enable_mergejoin" 187 | ], 188 | [ 189 | "enable_hashjoin", 190 | "enable_mergejoin" 191 | ] 192 | ] 193 | }, 194 | { 195 | "params": [ 196 | "http://www.goodnight.org" 197 | ], 198 | "result": { 199 | "Plan": { 200 | "Node Type": "Hash Join", 201 | "Join Type": "Inner", 202 | "Hash Cond": "(foo.x = bar.b)", 203 | "Plans": [ 204 | { 205 | "Node Type": "Seq Scan", 206 | "Alias": "foo" 207 | }, 208 | { 209 | "Node Type": "Hash", 210 | "Plans": [ 211 | { 212 | "Node Type": "Seq Scan", 213 | "Alias": "bar" 214 | } 215 | ] 216 | } 217 | ] 218 | } 219 | }, 220 | "additional_plans": [ 221 | { 222 | "Plan": { 223 | "Node Type": "Nested Loop", 224 | "Join Type": "Inner", 225 | "Plans": [ 226 | { 227 | "Node Type": "Seq Scan", 228 | "Alias": "bar" 229 | }, 230 | { 231 | "Node Type": "Seq Scan", 232 | "Alias": "foo" 233 | } 234 | ] 235 | } 236 | }, 237 | { 238 | "Plan": { 239 | "Node Type": "Hash Join", 240 | "Join Type": "Inner", 241 | "Hash Cond": "(foo.x = bar.b)", 242 | "Plans": [ 243 | { 244 | "Node Type": "Seq Scan", 245 | "Alias": "foo" 246 | }, 247 | { 248 | "Node Type": "Hash", 249 | "Plans": [ 250 | { 251 | "Node Type": "Seq Scan", 252 | "Alias": "bar" 253 | } 254 | ] 255 | } 256 | ] 257 | } 258 | }, 259 | { 260 | "Plan": { 261 | "Node Type": "Nested Loop", 262 | "Join Type": "Inner", 263 | "Plans": [ 264 | { 265 | "Node Type": "Seq Scan", 266 | "Alias": "bar" 267 | }, 268 | { 269 | "Node Type": "Seq Scan", 270 | "Alias": "foo" 271 | } 272 | ] 273 | } 274 | } 275 | ], 276 | "sources": [ 277 | [ 278 | "enable_hashjoin" 279 | ], 280 | [ 281 | "enable_mergejoin" 282 | ], 283 | [ 284 | "enable_hashjoin", 285 | "enable_mergejoin" 286 | ] 287 | ] 288 | }, 289 | { 290 | "params": [ 291 | "http://www.goodmorning.com" 292 | ], 293 | "result": { 294 | "Plan": { 295 | "Node Type": "Hash Join", 296 | "Join Type": "Inner", 297 | "Hash Cond": "(foo.x = bar.b)", 298 | "Plans": [ 299 | { 300 | "Node Type": "Seq Scan", 301 | "Alias": "foo" 302 | }, 303 | { 304 | "Node Type": "Hash", 305 | "Plans": [ 306 | { 307 | "Node Type": "Seq Scan", 308 | "Alias": "bar" 309 | } 310 | ] 311 | } 312 | ] 313 | } 314 | }, 315 | "additional_plans": [ 316 | { 317 | "Plan": { 318 
| "Node Type": "Nested Loop", 319 | "Join Type": "Inner", 320 | "Plans": [ 321 | { 322 | "Node Type": "Seq Scan", 323 | "Alias": "bar" 324 | }, 325 | { 326 | "Node Type": "Seq Scan", 327 | "Alias": "foo" 328 | } 329 | ] 330 | } 331 | }, 332 | { 333 | "Plan": { 334 | "Node Type": "Hash Join", 335 | "Join Type": "Inner", 336 | "Hash Cond": "(foo.x = bar.b)", 337 | "Plans": [ 338 | { 339 | "Node Type": "Seq Scan", 340 | "Alias": "foo" 341 | }, 342 | { 343 | "Node Type": "Hash", 344 | "Plans": [ 345 | { 346 | "Node Type": "Seq Scan", 347 | "Alias": "bar" 348 | } 349 | ] 350 | } 351 | ] 352 | } 353 | }, 354 | { 355 | "Plan": { 356 | "Node Type": "Nested Loop", 357 | "Join Type": "Inner", 358 | "Plans": [ 359 | { 360 | "Node Type": "Seq Scan", 361 | "Alias": "bar" 362 | }, 363 | { 364 | "Node Type": "Seq Scan", 365 | "Alias": "foo" 366 | } 367 | ] 368 | } 369 | } 370 | ], 371 | "sources": [ 372 | [ 373 | "enable_hashjoin" 374 | ], 375 | [ 376 | "enable_mergejoin" 377 | ], 378 | [ 379 | "enable_hashjoin", 380 | "enable_mergejoin" 381 | ] 382 | ] 383 | } 384 | ] 385 | } 386 | } 387 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | numpy 3 | pandas==1.3.5 4 | pglast==3.17 5 | psycopg2 6 | pyglib 7 | pytest 8 | python-dateutil 9 | scipy 10 | tensorflow 11 | tf-models-official 12 | --------------------------------------------------------------------------------