├── README.md ├── __init__.py ├── __main__.py ├── api_index_advise.py ├── executors ├── __init__.py ├── common.py ├── driver_executor.py └── gsql_executor.py ├── figures └── arch.jpg ├── index_advisor_workload.py ├── mcts.py ├── process_bar.py ├── sql_generator.py ├── sql_output_parser.py ├── table.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Index_advisor 2 | **Index_advisor** is a tool that recommends indexes for a workload. A workload consists 3 | of a set of SQL data manipulation statements, i.e., Select, Insert, Delete and Update. 4 | First, some candidate indexes are generated based on query syntax and database 5 | statistics. Then the optimal index set is determined by estimating its cost and 6 | benefit for the workload. 7 | 8 | 9 | Original repo link (check it for the up-to-date functionality): https://gitee.com/opengauss/openGauss-DBMind/tree/master/dbmind/components/index_advisor 10 | 11 | Documentation link: https://docs.opengauss.org/en/docs/3.1.1/docs/Developerguide/index-advisor-index-recommendation.html 12 | 13 | 14 | ![Architecture](./figures/arch.jpg?raw=true) 15 | 16 | ## Citing AutoIndex 17 | 18 | ```bibTeX 19 | @inproceedings{autoindex2022, 20 | author = {Xuanhe Zhou and Luyang Liu and Wenbo Li and Lianyuan Jin and Shifu Li and Tianqing Wang and Jianhua Feng}, 21 | title = {AutoIndex: An Incremental Index Management System for Dynamic Workloads}, 22 | booktitle = {ICDE}, 23 | year = {2022}} 24 | ``` 25 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 13 | 14 | from .index_advisor_workload import main 15 | -------------------------------------------------------------------------------- /__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 
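# Command-line entry point for the index advisor: the try/except below prefers the
# installed dbmind package and falls back to the local sources by appending the parent
# directory to sys.path, then hands sys.argv[1:] to the main() re-exported from
# index_advisor_workload.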
13 | 14 | import os 15 | import sys 16 | 17 | try: 18 | from dbmind.components.index_advisor import main 19 | except ImportError: 20 | libpath = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 21 | sys.path.append(libpath) 22 | from index_advisor import main 23 | 24 | main(sys.argv[1:]) 25 | -------------------------------------------------------------------------------- /api_index_advise.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 13 | import logging 14 | from contextlib import contextmanager 15 | import psycopg2 16 | 17 | from . import index_advisor_workload, process_bar 18 | from .executors.driver_executor import DriverExecutor 19 | 20 | 21 | class Executor(DriverExecutor): 22 | def __init__(self, dbname=None, user=None, password=None, host=None, port=None, schema=None): 23 | super().__init__(dbname, user, password, host, port, schema) 24 | self.conn = None 25 | self.cur = None 26 | 27 | def set_connection(self, connection): 28 | self.conn = connection 29 | self.cur = self.conn.cursor() 30 | 31 | def set_schemas(self, schemas): 32 | self.schema = ','.join(schemas) 33 | 34 | @contextmanager 35 | def session(self): 36 | yield 37 | 38 | 39 | def api_index_advise(sql_pairs, connection=None, dsn=None, 40 | schemas=("public",), improved_rate=0.5, 41 | improved_cost=None, max_index_num=None, 42 | max_index_columns=3, min_n_distinct=20, 43 | **connection_kwargs): 44 | search_path = None 45 | if min_n_distinct <= 0: 46 | raise ValueError('min_n_distinct must be a positive integer') 47 | process_bar.print = lambda *args, **kwargs: None 48 | # Only a single thread can be used. 49 | index_advisor_workload.get_workload_costs = index_advisor_workload.get_plan_cost 50 | templates = dict() 51 | executor = Executor() 52 | executor.set_schemas(schemas) 53 | if connection: 54 | cursor = connection.cursor() 55 | cursor.execute('show search_path;') 56 | search_path = cursor.fetchone()[0] 57 | connection.commit() 58 | elif dsn: 59 | connection = psycopg2.connect(dsn=dsn) 60 | else: 61 | connection = psycopg2.connect(**connection_kwargs) 62 | executor.set_connection(connection) 63 | 64 | for sql, count in sql_pairs.items(): 65 | templates[sql] = {'samples': [sql], 66 | 'cnt': count} 67 | if max_index_num: 68 | index_advisor_workload.MAX_INDEX_NUM = max_index_num 69 | if improved_cost: 70 | index_advisor_workload.MAX_BENEFIT_THRESHOLD = improved_cost 71 | detail_info, advised_indexes, _redundant_indexes = index_advisor_workload.index_advisor_workload( 72 | {'historyIndexes': {}}, executor, templates, 73 | multi_iter_mode=True, show_detail=True, 74 | n_distinct=0.02, reltuples=10, show_benefits=True, 75 | use_all_columns=True, improved_rate=improved_rate, 76 | max_index_columns=max_index_columns, max_n_distinct=1 / min_n_distinct, 77 | ) 78 | 79 | redundant_indexes = [] 80 | for index in _redundant_indexes: 81 | if index.get_is_unique() or index.is_primary_key(): 82 | continue 83 | statement = 
"DROP INDEX %s.%s;(%s)" % (index.get_schema(), index.get_indexname(), index.get_indexdef()) 84 | related_indexes = [] 85 | for _index in index.redundant_objs: 86 | related_indexes.append(_index.get_indexdef()) 87 | redundant_indexes.append({'redundant_index': statement, 88 | 'related_indexes': related_indexes}) 89 | cursor = connection.cursor() 90 | if search_path: 91 | cursor.execute(f'set current_schema={search_path}') 92 | connection.commit() 93 | else: 94 | cursor.close() 95 | connection.close() 96 | 97 | return advised_indexes, redundant_indexes 98 | -------------------------------------------------------------------------------- /executors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 13 | -------------------------------------------------------------------------------- /executors/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 13 | 14 | from abc import abstractmethod 15 | from typing import List 16 | 17 | 18 | class BaseExecutor: 19 | def __init__(self, dbname, user, password, host, port, schema, driver=None): 20 | self.dbname = dbname 21 | self.user = user 22 | self.password = password 23 | self.host = host 24 | self.port = port 25 | self.schema = schema 26 | self.driver = driver 27 | 28 | def get_schema(self): 29 | return self.schema 30 | 31 | @abstractmethod 32 | def execute_sqls(self, sqls) -> List[str]: 33 | pass 34 | 35 | @abstractmethod 36 | def session(self): 37 | pass 38 | -------------------------------------------------------------------------------- /executors/driver_executor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 
13 | import sys 14 | from typing import List 15 | import logging 16 | from contextlib import contextmanager 17 | 18 | import psycopg2 19 | 20 | sys.path.append('..') 21 | 22 | from .common import BaseExecutor 23 | 24 | 25 | class DriverExecutor(BaseExecutor): 26 | def __init__(self, *arg): 27 | super(DriverExecutor, self).__init__(*arg) 28 | self.conn = None 29 | self.cur = None 30 | with self.session(): 31 | pass 32 | 33 | def __init_conn_handle(self): 34 | self.conn = psycopg2.connect(dbname=self.dbname, 35 | user=self.user, 36 | password=self.password, 37 | host=self.host, 38 | port=self.port, 39 | application_name='DBMind-index-advisor') 40 | self.cur = self.conn.cursor() 41 | 42 | def __execute(self, sql): 43 | if self.cur.closed: 44 | self.__init_conn_handle() 45 | try: 46 | self.cur.execute(sql) 47 | self.conn.commit() 48 | if self.cur.rowcount == -1: 49 | return 50 | return [(self.cur.statusmessage,)] + self.cur.fetchall() 51 | except psycopg2.ProgrammingError: 52 | return [('ERROR',)] 53 | except Exception as e: 54 | logging.warning('Found %s while executing SQL statement.', e) 55 | return [('ERROR ' + str(e),)] 56 | finally: 57 | self.conn.rollback() 58 | 59 | def execute_sqls(self, sqls) -> List[str]: 60 | results = [] 61 | sqls = ['set current_schema = %s' % self.get_schema()] + sqls 62 | for sql in sqls: 63 | res = self.__execute(sql) 64 | if res: 65 | results.extend(res) 66 | return results 67 | 68 | def __close_conn(self): 69 | if self.conn and self.cur: 70 | self.cur.close() 71 | self.conn.close() 72 | 73 | @contextmanager 74 | def session(self): 75 | self.__init_conn_handle() 76 | yield 77 | self.__close_conn() 78 | -------------------------------------------------------------------------------- /executors/gsql_executor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 13 | 14 | import os 15 | import shlex 16 | import subprocess 17 | import sys 18 | import time 19 | from contextlib import contextmanager 20 | from typing import List, Tuple 21 | import re 22 | import tempfile 23 | 24 | from .common import BaseExecutor 25 | 26 | BLANK = ' ' 27 | 28 | 29 | def to_tuples(text): 30 | """Parse execution result by using gsql 31 | and convert to tuples.""" 32 | lines = text.splitlines() 33 | separator_location = -1 34 | for i, line in enumerate(lines): 35 | # Find separator line such as '-----+-----+------'. 36 | if re.match(r'^\s*?[-|+]+\s*$', line): 37 | separator_location = i 38 | break 39 | 40 | if separator_location < 0: 41 | return [] 42 | 43 | separator = lines[separator_location] 44 | left = 0 45 | right = len(separator) 46 | locations = list() 47 | while left < right: 48 | try: 49 | location = separator.index('+', left, right) 50 | except ValueError: 51 | break 52 | locations.append(location) 53 | left = location + 1 54 | # Record each value start location and end location. 
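# Each (start, end) pair built below delimits one column between '+' positions of the
# separator line; when gsql wraps a wide row onto the following line, wrap_flag is set and
# the continuation fragment is merged back into the last column of the current row.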
55 | pairs = list(zip([0] + locations, locations + [right])) 56 | tuples = [] 57 | row = [] 58 | wrap_flag = False 59 | # Continue to parse each line. 60 | for line in lines[separator_location + 1:]: 61 | # Prevent from parsing bottom lines. 62 | if len(line.strip()) == 0 or re.match(r'\(\d+ rows?\)', line): 63 | continue 64 | # Parse a record to tuple. 65 | if wrap_flag: 66 | row[-1] += line[pairs[-1][0] + 1: pairs[-1][1]].strip() 67 | else: 68 | for start, end in pairs: 69 | # Increase 1 to start index to go over vertical bar (|). 70 | row.append(line[start + 1: end].strip()) 71 | 72 | if len(line) == right and re.match(r'.*\s*\+$', line): 73 | wrap_flag = True 74 | row[-1] = row[-1].strip('+').strip(BLANK) + BLANK 75 | else: 76 | tuples.append(tuple(row)) 77 | row = [] 78 | wrap_flag = False 79 | return tuples 80 | 81 | 82 | class GsqlExecutor(BaseExecutor): 83 | def __init__(self, *args): 84 | super(GsqlExecutor, self).__init__(*args) 85 | self.base_cmd = '' 86 | with self.session(): 87 | self.__check_connect() 88 | 89 | def __init_conn_handle(self): 90 | self.base_cmd = 'gsql -p ' + str(self.port) + ' -d ' + self.dbname 91 | if self.host: 92 | self.base_cmd += ' -h ' + self.host 93 | if self.user: 94 | self.base_cmd += ' -U ' + self.user 95 | if self.password: 96 | self.base_cmd += ' -W ' + shlex.quote(self.password) 97 | 98 | def __check_connect(self): 99 | cmd = self.base_cmd + ' -c \"' 100 | cmd += 'select 1;\"' 101 | proc = subprocess.Popen( 102 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) 103 | (stdout, stderr) = proc.communicate() 104 | stdout, stderr = stdout.decode(errors='ignore'), stderr.decode(errors='ignore') 105 | if 'gsql: FATAL:' in stderr or 'failed to connect' in stderr: 106 | raise ConnectionError("An error occurred while connecting to the database.\n" + 107 | "Details: " + stderr) 108 | return stdout 109 | 110 | @staticmethod 111 | def __to_tuples(sql_result: str) -> List[Tuple[str]]: 112 | is_tuple = False 113 | results = [] 114 | tmp_tuple_lines = [] 115 | for line in sql_result.strip().split('\n'): 116 | if re.match(r'^\s*?[-|+]+\s*$', line): 117 | is_tuple = True 118 | elif re.match(r'\(\d+ rows?\)', line) and is_tuple: 119 | is_tuple = False 120 | results.extend(to_tuples('\n'.join(tmp_tuple_lines))) 121 | tmp_tuple_lines = [] 122 | if is_tuple: 123 | tmp_tuple_lines.append(line) 124 | else: 125 | results.append((line,)) 126 | 127 | return results 128 | 129 | def execute_sqls(self, sqls): 130 | sqls = ['set current_schema = %s' % self.get_schema()] + sqls 131 | 132 | file1 = tempfile.NamedTemporaryFile(mode='w+', delete=True) 133 | try: 134 | for sql in sqls: 135 | if not sql.strip().endswith(';'): 136 | sql += ';' 137 | file1.file.write(sql + '\n') 138 | file1.file.flush() 139 | cmd = self.base_cmd + ' -f ' + file1.name 140 | try: 141 | ret = subprocess.check_output( 142 | shlex.split(cmd), stderr=subprocess.STDOUT) 143 | return self.__to_tuples(ret.decode(errors='ignore')) 144 | except subprocess.CalledProcessError as e: 145 | print(e.output.decode(errors='ignore'), file=sys.stderr) 146 | finally: 147 | file1.close() 148 | 149 | @contextmanager 150 | def session(self): 151 | self.__init_conn_handle() 152 | yield 153 | -------------------------------------------------------------------------------- /figures/arch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuanheZhou/AutoIndex/fd7f4a9768ae8567d8259c05d8d0efd33215fe54/figures/arch.jpg 
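A minimal sketch of calling the advisor programmatically through `api_index_advise` (shown in api_index_advise.py above). The DSN, schema, and workload statements are placeholders, and the import path assumes the component is installed inside the dbmind package as referenced in `__main__.py`.

```python
from dbmind.components.index_advisor.api_index_advise import api_index_advise

# Hypothetical workload: each SQL statement mapped to how often it is executed.
sql_pairs = {
    "SELECT o_id FROM orders WHERE o_custkey = 42;": 120,
    "UPDATE stock SET quantity = quantity - 1 WHERE item_id = 7;": 15,
}

advised_indexes, redundant_indexes = api_index_advise(
    sql_pairs,
    dsn="dbname=postgres user=omm host=127.0.0.1 port=5432",  # placeholder DSN
    schemas=("public",),
    improved_rate=0.5,   # average cost improvement required over the queries an index helps
    max_index_num=10,
)

for index in advised_indexes:
    # Advised indexes expose get_index_statement() (see index_advisor_workload.py).
    print(index.get_index_statement())
for item in redundant_indexes:
    print(item['redundant_index'])
```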
-------------------------------------------------------------------------------- /index_advisor_workload.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 13 | 14 | import argparse 15 | import copy 16 | import getpass 17 | import json 18 | import os 19 | import random 20 | import re 21 | import sys 22 | import select 23 | import logging 24 | from logging.handlers import RotatingFileHandler 25 | from collections import defaultdict 26 | from functools import lru_cache 27 | from itertools import groupby, chain, combinations, permutations 28 | from typing import Tuple, List 29 | import heapq 30 | from multiprocessing import Pool 31 | 32 | import sqlparse 33 | from sql_metadata import Parser 34 | 35 | try: 36 | from .sql_output_parser import parse_single_advisor_results, parse_explain_plan, \ 37 | get_checked_indexes, parse_table_sql_results, parse_existing_indexes_results, parse_plan_cost, parse_hypo_index 38 | from .sql_generator import get_single_advisor_sql, get_index_check_sqls, get_existing_index_sql, \ 39 | get_workload_cost_sqls, get_index_setting_sqls, get_prepare_sqls, get_hypo_index_head_sqls 40 | from .executors.common import BaseExecutor 41 | from .executors.gsql_executor import GsqlExecutor 42 | from .mcts import MCTS 43 | from .table import get_table_context 44 | from .utils import match_table_name, IndexItemFactory, \ 45 | AdvisedIndex, ExistingIndex, QueryItem, WorkLoad, QueryType, IndexType, COLUMN_DELIMITER, \ 46 | lookfor_subsets_configs, has_dollar_placeholder, generate_placeholder_indexes, \ 47 | match_columns, infer_workload_benefit, UniqueList, is_multi_node, hypo_index_ctx, split_iter, \ 48 | replace_comma_with_dollar, replace_function_comma, flatten, ERROR_KEYWORD 49 | from .process_bar import bar_print, ProcessBar 50 | except ImportError: 51 | from sql_output_parser import parse_single_advisor_results, parse_explain_plan, \ 52 | get_checked_indexes, parse_table_sql_results, parse_existing_indexes_results, parse_plan_cost, parse_hypo_index 53 | from sql_generator import get_single_advisor_sql, get_index_check_sqls, get_existing_index_sql, \ 54 | get_workload_cost_sqls, get_index_setting_sqls, get_prepare_sqls, get_hypo_index_head_sqls 55 | from executors.common import BaseExecutor 56 | from executors.gsql_executor import GsqlExecutor 57 | from mcts import MCTS 58 | from table import get_table_context 59 | from utils import match_table_name, IndexItemFactory, \ 60 | AdvisedIndex, ExistingIndex, QueryItem, WorkLoad, QueryType, IndexType, COLUMN_DELIMITER, \ 61 | lookfor_subsets_configs, has_dollar_placeholder, generate_placeholder_indexes, \ 62 | match_columns, infer_workload_benefit, UniqueList, is_multi_node, hypo_index_ctx, split_iter, \ 63 | replace_comma_with_dollar, replace_function_comma, flatten, ERROR_KEYWORD 64 | from process_bar import bar_print, ProcessBar 65 | 66 | SAMPLE_NUM = 5 67 | MAX_INDEX_COLUMN_NUM = 5 68 | MAX_CANDIDATE_COLUMNS = 40 69 | MAX_INDEX_NUM = 
None 70 | MAX_INDEX_STORAGE = None 71 | FULL_ARRANGEMENT_THRESHOLD = 20 72 | NEGATIVE_RATIO_THRESHOLD = 0.2 73 | MAX_BENEFIT_THRESHOLD = float('inf') 74 | SHARP = '#' 75 | JSON_TYPE = False 76 | BLANK = ' ' 77 | GLOBAL_PROCESS_BAR = ProcessBar() 78 | SQL_TYPE = ['select', 'delete', 'insert', 'update'] 79 | NUMBER_SET_PATTERN = r'\((\s*(\-|\+)?\d+(\.\d+)?\s*)(,\s*(\-|\+)?\d+(\.\d+)?\s*)*[,]?\)' 80 | SQL_PATTERN = [r'([^\\])\'((\')|(.*?([^\\])\'))', # match all content in single quotes 81 | NUMBER_SET_PATTERN, # match integer set in the IN collection 82 | r'(([^<>]\s*=\s*)|([^<>]\s+))(\d+)(\.\d+)?'] # match single integer 83 | SQL_DISPLAY_PATTERN = [r'\'((\')|(.*?\'))', # match all content in single quotes 84 | NUMBER_SET_PATTERN, # match integer set in the IN collection 85 | r'([^\_\d])\d+(\.\d+)?'] # match single integer 86 | 87 | os.umask(0o0077) 88 | 89 | 90 | def path_type(path): 91 | realpath = os.path.realpath(path) 92 | if os.path.exists(realpath): 93 | return realpath 94 | raise argparse.ArgumentTypeError('%s is not a valid path.' % path) 95 | 96 | 97 | def set_logger(): 98 | logfile = 'index_advisor.log' 99 | handler = RotatingFileHandler( 100 | filename=logfile, 101 | maxBytes=100 * 1024 * 1024, 102 | backupCount=5, 103 | ) 104 | handler.setLevel(logging.INFO) 105 | handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s')) 106 | logger = logging.getLogger() 107 | logger.addHandler(handler) 108 | logger.setLevel(logging.INFO) 109 | 110 | 111 | class CheckWordValid(argparse.Action): 112 | def __call__(self, parser, namespace, values, option_string=None): 113 | ill_character = [" ", "|", ";", "&", "$", "<", ">", "`", "\\", "'", "\"", 114 | "{", "}", "(", ")", "[", "]", "~", "*", "?", "!", "\n"] 115 | if not values.strip(): 116 | return 117 | if any(ill_char in values for ill_char in ill_character): 118 | parser.error('There are illegal characters in your input.') 119 | setattr(namespace, self.dest, values) 120 | 121 | 122 | def read_input_from_pipe(): 123 | """ 124 | Read stdin input if there is "echo 'str1 str2' | python xx.py", return the input string. 125 | """ 126 | input_str = "" 127 | r_handle, _, _ = select.select([sys.stdin], [], [], 0) 128 | if not r_handle: 129 | return "" 130 | 131 | for item in r_handle: 132 | if item == sys.stdin: 133 | input_str = sys.stdin.read().strip() 134 | return input_str 135 | 136 | 137 | def get_password(): 138 | password = read_input_from_pipe() 139 | if password: 140 | logging.warning("Read password from pipe.") 141 | else: 142 | password = getpass.getpass("Password for database user:") 143 | if not password: 144 | raise ValueError('Please input the password') 145 | return password 146 | 147 | 148 | def is_valid_statement(conn, statement): 149 | """Determine if the query is correct by whether the executor throws an exception.""" 150 | queries = get_prepare_sqls(statement) 151 | res = conn.execute_sqls(queries) 152 | # Rpc executor return [] if the statement is not executed successfully. 
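# An empty result therefore indicates failure; otherwise every returned status tuple is
# scanned for the ERROR keyword below.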
153 | if not res: 154 | return False 155 | for _tuple in res: 156 | if isinstance(_tuple[0], str) and \ 157 | (_tuple[0].upper().startswith(ERROR_KEYWORD) or f' {ERROR_KEYWORD}: ' in _tuple[0].upper()): 158 | return False 159 | return True 160 | 161 | 162 | def get_positive_sql_count(candidate_indexes: List[AdvisedIndex], workload: WorkLoad): 163 | positive_sql_count = 0 164 | for query in workload.get_queries(): 165 | for index in candidate_indexes: 166 | if workload.is_positive_query(index, query): 167 | positive_sql_count += query.get_frequency() 168 | break 169 | return int(positive_sql_count) 170 | 171 | 172 | def print_statement(index_list: List[Tuple[str]], schema_table: str): 173 | for columns, index_type in index_list: 174 | index_name = 'idx_%s_%s%s' % (schema_table.split('.')[-1], 175 | (index_type + '_' if index_type else ''), 176 | '_'.join(columns.split(COLUMN_DELIMITER))) 177 | statement = 'CREATE INDEX %s ON %s%s%s;' % (index_name, schema_table, 178 | '(' + columns + ')', 179 | (' ' + index_type if index_type else '')) 180 | bar_print(statement) 181 | 182 | 183 | class IndexAdvisor: 184 | def __init__(self, executor: BaseExecutor, workload: WorkLoad, multi_iter_mode: bool): 185 | self.executor = executor 186 | self.workload = workload 187 | self.multi_iter_mode = multi_iter_mode 188 | 189 | self.determine_indexes = [] 190 | self.integrate_indexes = {} 191 | 192 | self.display_detail_info = {} 193 | self.index_benefits = [] 194 | self.redundant_indexes = [] 195 | 196 | def complex_index_advisor(self, candidate_indexes: List[AdvisedIndex]): 197 | atomic_config_total = generate_sorted_atomic_config(self.workload.get_queries(), candidate_indexes) 198 | same_columns_config = generate_atomic_config_containing_same_columns(candidate_indexes) 199 | for atomic_config in same_columns_config: 200 | if atomic_config not in atomic_config_total: 201 | atomic_config_total.append(atomic_config) 202 | if atomic_config_total and len(atomic_config_total[0]) != 0: 203 | raise ValueError("The empty atomic config isn't generated!") 204 | for atomic_config in GLOBAL_PROCESS_BAR.process_bar(atomic_config_total, 'Optimal indexes'): 205 | estimate_workload_cost_file(self.executor, self.workload, atomic_config) 206 | self.workload.set_index_benefit() 207 | if MAX_INDEX_STORAGE: 208 | opt_config = MCTS(self.workload, atomic_config_total, candidate_indexes, 209 | MAX_INDEX_STORAGE, MAX_INDEX_NUM) 210 | else: 211 | opt_config = greedy_determine_opt_config(self.workload, atomic_config_total, 212 | candidate_indexes) 213 | self.filter_redundant_indexes_with_diff_types(opt_config) 214 | self.filter_same_columns_indexes(opt_config, self.workload) 215 | self.display_detail_info['positive_stmt_count'] = get_positive_sql_count(opt_config, 216 | self.workload) 217 | if len(opt_config) == 0: 218 | bar_print("No optimal indexes generated!") 219 | return None 220 | return opt_config 221 | 222 | @staticmethod 223 | def filter_same_columns_indexes(opt_config, workload, rate=0.8): 224 | """If the columns in two indexes have a containment relationship, 225 | for example, index1 is table1(col1, col2), index2 is table1(col3, col1, col2), 226 | then when the gain of one index is close to the gain of both indexes as a whole, 227 | the addition of the other index obviously does not improve the gain much, 228 | and we filter it out.""" 229 | same_columns_config = generate_atomic_config_containing_same_columns(opt_config) 230 | origin_cost = workload.get_total_origin_cost() 231 | filtered_indexes = UniqueList() 232 | for 
short_index, long_index in same_columns_config: 233 | if workload.has_indexes((short_index, long_index)): 234 | combined_benefit = workload.get_total_index_cost((short_index, long_index)) - origin_cost 235 | elif workload.has_indexes((long_index, short_index)): 236 | combined_benefit = workload.get_total_index_cost((long_index, short_index)) - origin_cost 237 | else: 238 | continue 239 | short_index_benefit = workload.get_total_index_cost((short_index,)) - origin_cost 240 | long_index_benefit = workload.get_total_index_cost((long_index,)) - origin_cost 241 | if combined_benefit and short_index_benefit / combined_benefit > rate: 242 | filtered_indexes.append(long_index) 243 | continue 244 | if combined_benefit and long_index_benefit / combined_benefit > rate: 245 | filtered_indexes.append(short_index) 246 | for filtered_index in filtered_indexes: 247 | opt_config.remove(filtered_index) 248 | logging.info(f'filtered: {filtered_index} is removed due to similar benefits ' 249 | f'with other same column indexes') 250 | 251 | def simple_index_advisor(self, candidate_indexes: List[AdvisedIndex]): 252 | estimate_workload_cost_file(self.executor, self.workload) 253 | for index in GLOBAL_PROCESS_BAR.process_bar(candidate_indexes, 'Optimal indexes'): 254 | estimate_workload_cost_file(self.executor, self.workload, (index,)) 255 | self.workload.set_index_benefit() 256 | self.filter_redundant_indexes_with_diff_types(candidate_indexes) 257 | if not candidate_indexes: 258 | bar_print("No optimal indexes generated!") 259 | return None 260 | 261 | self.display_detail_info['positive_stmt_count'] = get_positive_sql_count(candidate_indexes, 262 | self.workload) 263 | return candidate_indexes 264 | 265 | def filter_low_benefit_index(self, opt_indexes: List[AdvisedIndex], improved_rate): 266 | index_current_storage = 0 267 | cnt = 0 268 | for key, index in enumerate(opt_indexes): 269 | sql_optimized = 0 270 | negative_sql_ratio = 0 271 | insert_queries, delete_queries, \ 272 | update_queries, select_queries, \ 273 | positive_queries, ineffective_queries, \ 274 | negative_queries = self.workload.get_index_related_queries(index) 275 | sql_num = self.workload.get_index_sql_num(index) 276 | total_benefit = 0 277 | # Calculate the average benefit of each positive SQL. 278 | for query in positive_queries: 279 | current_cost = self.workload.get_indexes_cost_of_query(query, (index,)) 280 | origin_cost = self.workload.get_origin_cost_of_query(query) 281 | sql_optimized += (1 - current_cost / origin_cost) * query.get_frequency() 282 | benefit = origin_cost - current_cost 283 | total_benefit += benefit 284 | total_queries_num = sql_num['negative'] + sql_num['ineffective'] + sql_num['positive'] 285 | if total_queries_num: 286 | negative_sql_ratio = sql_num['negative'] / total_queries_num 287 | # Filter the candidate indexes that do not meet the conditions of optimization. 
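# An index is kept only if it benefits at least one query, its average improvement over the
# positive queries reaches improved_rate (unless its total benefit already exceeds
# MAX_BENEFIT_THRESHOLD), it passes the NEGATIVE_RATIO_THRESHOLD check on negatively
# affected queries, and it still fits within the MAX_INDEX_STORAGE and MAX_INDEX_NUM limits.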
288 | logging.info(f'filter low benefit index for {index}') 289 | if not positive_queries: 290 | logging.info('filtered: positive_queries not found for the index') 291 | continue 292 | if sql_optimized / sql_num['positive'] < improved_rate and total_benefit < MAX_BENEFIT_THRESHOLD: 293 | logging.info(f"filtered: improved_rate {sql_optimized / sql_num['positive']} less than {improved_rate}") 294 | continue 295 | if sql_optimized / sql_num['positive'] < \ 296 | NEGATIVE_RATIO_THRESHOLD < negative_sql_ratio: 297 | logging.info(f'filtered: improved_rate {sql_optimized / sql_num["positive"]} < ' 298 | f'negative_ratio_threshold < negative_sql_ratio {negative_sql_ratio} is not met') 299 | continue 300 | logging.info(f'{index} has benefit of {self.workload.get_index_benefit(index)}') 301 | if MAX_INDEX_STORAGE and (index_current_storage + index.get_storage()) > MAX_INDEX_STORAGE: 302 | logging.info('filtered: if add the index {index}, it reaches the max index storage.') 303 | continue 304 | if MAX_INDEX_NUM and cnt == MAX_INDEX_NUM: 305 | logging.info('filtered: reach the maximum number for the index.') 306 | break 307 | if not self.multi_iter_mode and index.benefit <= 0: 308 | logging.info('filtered: benefit not above 0 for the index.') 309 | continue 310 | index_current_storage += index.get_storage() 311 | cnt += 1 312 | self.determine_indexes.append(index) 313 | 314 | def print_benefits(self, created_indexes: List[ExistingIndex]): 315 | print_header_boundary('Index benefits') 316 | table_indexes = defaultdict(UniqueList) 317 | for index in created_indexes: 318 | table_indexes[index.get_schema_table()].append(index) 319 | total_origin_cost = self.workload.get_total_origin_cost() 320 | for i, index in enumerate(self.determine_indexes): 321 | useless_indexes = [] 322 | existing_indexes = [] 323 | improved_queries = [] 324 | indexdef = index.get_index_statement() 325 | bar_print(f'INDEX {i}: {indexdef}') 326 | workload_benefit = sum([query.get_benefit() for query in index.get_positive_queries()]) 327 | workload_improved_rate = workload_benefit / total_origin_cost 328 | bar_print('\tCost benefit for workload: %.2f' % workload_benefit) 329 | bar_print('\tCost improved rate for workload: %.2f%%' 330 | % (workload_improved_rate * 100)) 331 | 332 | # invalid indexes caused by recommended indexes 333 | source_index = index.get_source_index() 334 | if source_index and (not source_index.is_primary_key()) and (not source_index.get_is_unique()): 335 | bar_print('\tCurrently existing useless indexes:') 336 | bar_print(f'\t\t{source_index.get_indexdef()}') 337 | useless_indexes.append(source_index.get_indexdef()) 338 | 339 | # information about existing indexes 340 | created_indexes = table_indexes.get(index.get_table(), []) 341 | if created_indexes: 342 | bar_print('\tExisting indexes of this relation:') 343 | for created_index in created_indexes: 344 | bar_print(f'\t\t{created_index.get_indexdef()}') 345 | existing_indexes.append(created_index.get_indexdef()) 346 | 347 | bar_print('\tImproved query:') 348 | # get benefit rate for subsequent sorting and display 349 | query_benefit_rate = [] 350 | for query in sorted(index.get_positive_queries(), key=lambda query: -query.get_benefit()): 351 | query_origin_cost = self.workload.get_origin_cost_of_query(query) 352 | current_cost = self.workload.get_indexes_cost_of_query(query, tuple([index])) 353 | query_improved_rate = (query_origin_cost - current_cost) / current_cost 354 | query_benefit_rate.append((query, query_improved_rate)) 355 | # sort query by benefit 
rate 356 | for j, (query, query_improved_rate) in enumerate(sorted(query_benefit_rate, key=lambda x: -x[1])): 357 | other_related_indexes = [] 358 | bar_print(f'\t\tQuery {j}: {query.get_statement()}') 359 | query_origin_cost = self.workload.get_origin_cost_of_query(query) 360 | current_cost = self.workload.get_indexes_cost_of_query(query, tuple([index])) 361 | query_benefit = query_origin_cost - current_cost 362 | origin_plan = self.workload.get_indexes_plan_of_query(query, None) 363 | current_plan = self.workload.get_indexes_plan_of_query(query, tuple([index])) 364 | bar_print('\t\t\tCost benefit for the query: %.2f' % query_benefit) 365 | bar_print('\t\t\tCost improved rate for the query: %.2f%%' % (query_improved_rate * 100)) 366 | bar_print(f'\t\t\tQuery number: {int(query.get_frequency())}') 367 | if len(query.get_indexes()) > 1: 368 | bar_print('\t\t\tOther optimal indexes:') 369 | for temp_index in query.get_indexes(): 370 | if temp_index is index: 371 | continue 372 | bar_print(f'\t\t\t\t{temp_index.get_index_statement()}') 373 | other_related_indexes.append(temp_index.get_index_statement()) 374 | improved_queries.append({'query': query.get_statement(), 375 | 'query_benefit': query_benefit, 376 | 'query_improved_rate': query_improved_rate, 377 | 'query_count': int(query.get_frequency()), 378 | 'origin_plan': origin_plan, 379 | 'current_plan': current_plan, 380 | 'other_related_indexes': other_related_indexes 381 | }) 382 | self.index_benefits.append({'indexdef': indexdef, 383 | 'workload_benefit': workload_benefit, 384 | 'workload_improved_rate': workload_improved_rate, 385 | 'useless_indexes': useless_indexes, 386 | 'existing_indexes': existing_indexes, 387 | 'improved_queriies': improved_queries, 388 | }) 389 | 390 | def record_info(self, index: AdvisedIndex, sql_info, table_name: str, statement: str): 391 | sql_num = self.workload.get_index_sql_num(index) 392 | total_sql_num = int(sql_num['positive'] + sql_num['ineffective'] + sql_num['negative']) 393 | workload_optimized = index.benefit / self.workload.get_total_origin_cost() * 100 394 | sql_info['workloadOptimized'] = '%.2f' % \ 395 | (workload_optimized if workload_optimized > 1 else 1) 396 | sql_info['schemaName'] = index.get_table().split('.')[0] 397 | sql_info['tbName'] = table_name 398 | sql_info['columns'] = index.get_columns() 399 | sql_info['index_type'] = index.get_index_type() 400 | sql_info['statement'] = statement 401 | sql_info['storage'] = index.get_storage() 402 | sql_info['dmlCount'] = total_sql_num 403 | sql_info['selectRatio'] = 1 404 | sql_info['insertRatio'] = sql_info['deleteRatio'] = sql_info['updateRatio'] = 0 405 | if total_sql_num: 406 | sql_info['selectRatio'] = round( 407 | (sql_num['select']) * 100 / total_sql_num, 2) 408 | sql_info['insertRatio'] = round( 409 | sql_num['insert'] * 100 / total_sql_num, 2) 410 | sql_info['deleteRatio'] = round( 411 | sql_num['delete'] * 100 / total_sql_num, 2) 412 | sql_info['updateRatio'] = round( 413 | 100 - sql_info['selectRatio'] - sql_info['insertRatio'] - sql_info['deleteRatio'], 2) 414 | sql_info['associationIndex'] = index.association_indexes 415 | self.display_detail_info['recommendIndexes'].append(sql_info) 416 | 417 | def compute_index_optimization_info(self, index: AdvisedIndex, table_name: str, statement: str): 418 | sql_info = {'sqlDetails': []} 419 | insert_queries, delete_queries, update_queries, select_queries, \ 420 | positive_queries, ineffective_queries, negative_queries = \ 421 | self.workload.get_index_related_queries(index) 422 | 423 | for 
category, queries in zip([QueryType.INEFFECTIVE, QueryType.POSITIVE, QueryType.NEGATIVE], 424 | [ineffective_queries, positive_queries, negative_queries]): 425 | sql_count = int(sum(query.get_frequency() for query in queries)) 426 | # Record 5 ineffective or negative queries. 427 | if category in [QueryType.INEFFECTIVE, QueryType.NEGATIVE]: 428 | queries = queries[:5] 429 | for query in queries: 430 | sql_detail = {} 431 | sql_template = query.get_statement() 432 | for pattern in SQL_DISPLAY_PATTERN: 433 | sql_template = re.sub(pattern, '?', sql_template) 434 | 435 | sql_detail['sqlTemplate'] = sql_template 436 | sql_detail['sql'] = query.get_statement() 437 | sql_detail['sqlCount'] = int(round(sql_count)) 438 | 439 | if category == QueryType.POSITIVE: 440 | origin_cost = self.workload.get_origin_cost_of_query(query) 441 | current_cost = self.workload.get_indexes_cost_of_query(query, tuple([index])) 442 | sql_optimized = (origin_cost - current_cost) / current_cost * 100 443 | sql_detail['optimized'] = '%.1f' % sql_optimized 444 | sql_detail['correlationType'] = category.value 445 | sql_info['sqlDetails'].append(sql_detail) 446 | self.record_info(index, sql_info, table_name, statement) 447 | 448 | def display_advise_indexes_info(self, show_detail: bool): 449 | self.display_detail_info['workloadCount'] = int( 450 | sum(query.get_frequency() for query in self.workload.get_queries())) 451 | self.display_detail_info['recommendIndexes'] = [] 452 | logging.info('filter advised indexes by using max-index-storage and max-index-num.') 453 | for key, index in enumerate(self.determine_indexes): 454 | # display determine indexes 455 | table_name = index.get_table().split('.')[-1] 456 | statement = index.get_index_statement() 457 | bar_print(statement) 458 | if show_detail: 459 | # Record detailed SQL optimization information for each index. 
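# compute_index_optimization_info() classifies the queries related to this index as
# positive, ineffective or negative, computes the optimization percentage for each positive
# query, and stores the details via record_info() for the JSON report.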
460 | self.compute_index_optimization_info( 461 | index, table_name, statement) 462 | 463 | def generate_incremental_index(self, history_advise_indexes): 464 | self.integrate_indexes = copy.copy(history_advise_indexes) 465 | self.integrate_indexes['currentIndexes'] = {} 466 | for key, index in enumerate(self.determine_indexes): 467 | self.integrate_indexes['currentIndexes'][index.get_table()] = \ 468 | self.integrate_indexes['currentIndexes'].get(index.get_table(), []) 469 | self.integrate_indexes['currentIndexes'][index.get_table()].append( 470 | (index.get_columns(), index.get_index_type())) 471 | 472 | def generate_redundant_useless_indexes(self, history_invalid_indexes): 473 | created_indexes = fetch_created_indexes(self.executor) 474 | record_history_invalid_indexes(self.integrate_indexes['historyIndexes'], history_invalid_indexes, 475 | created_indexes) 476 | print_header_boundary(" Created indexes ") 477 | self.display_detail_info['createdIndexes'] = [] 478 | if not created_indexes: 479 | bar_print("No created indexes!") 480 | else: 481 | self.record_created_indexes(created_indexes) 482 | for index in created_indexes: 483 | bar_print("%s: %s;" % (index.get_schema(), index.get_indexdef())) 484 | workload_indexnames = self.workload.get_used_index_names() 485 | display_useless_redundant_indexes(created_indexes, workload_indexnames, 486 | self.display_detail_info) 487 | unused_indexes = [index for index in created_indexes if index.get_indexname() not in workload_indexnames] 488 | self.redundant_indexes = get_redundant_created_indexes(created_indexes, unused_indexes) 489 | 490 | def record_created_indexes(self, created_indexes): 491 | for index in created_indexes: 492 | index_info = {'schemaName': index.get_schema(), 'tbName': index.get_table(), 493 | 'columns': index.get_columns(), 'statement': index.get_indexdef() + ';'} 494 | self.display_detail_info['createdIndexes'].append(index_info) 495 | 496 | def display_incremental_index(self, history_invalid_indexes, 497 | workload_file_path): 498 | 499 | # Display historical effective indexes. 500 | if self.integrate_indexes['historyIndexes']: 501 | print_header_boundary(" Historical effective indexes ") 502 | for table_name, index_list in self.integrate_indexes['historyIndexes'].items(): 503 | print_statement(index_list, table_name) 504 | # Display historical invalid indexes. 505 | if history_invalid_indexes: 506 | print_header_boundary(" Historical invalid indexes ") 507 | for table_name, index_list in history_invalid_indexes.items(): 508 | print_statement(index_list, table_name) 509 | # Save integrate indexes result. 
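# The newly determined indexes are merged into the historical set, the (columns, index_type)
# tuples are de-duplicated, and the result is written to index_result.json next to the
# workload file; this step is skipped when the workload was passed in as a dict.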
510 | if not isinstance(workload_file_path, dict): 511 | integrate_indexes_file = os.path.join(os.path.realpath(os.path.dirname(workload_file_path)), 512 | 'index_result.json') 513 | for table, indexes in self.integrate_indexes['currentIndexes'].items(): 514 | self.integrate_indexes['historyIndexes'][table] = \ 515 | self.integrate_indexes['historyIndexes'].get(table, []) 516 | self.integrate_indexes['historyIndexes'][table].extend(indexes) 517 | self.integrate_indexes['historyIndexes'][table] = \ 518 | list( 519 | set(map(tuple, (self.integrate_indexes['historyIndexes'][table])))) 520 | with open(integrate_indexes_file, 'w') as file: 521 | json.dump(self.integrate_indexes['historyIndexes'], file) 522 | 523 | @staticmethod 524 | def filter_redundant_indexes_with_diff_types(candidate_indexes: List[AdvisedIndex]): 525 | sorted_indexes = sorted(candidate_indexes, key=lambda x: (x.get_table(), x.get_columns())) 526 | for table, _index_group in groupby(sorted_indexes, key=lambda x: x.get_table()): 527 | index_group = list(_index_group) 528 | for i in range(len(index_group) - 1): 529 | cur_index = index_group[i] 530 | next_index = index_group[i + 1] 531 | if match_columns(cur_index.get_columns(), next_index.get_columns()): 532 | if cur_index.benefit == next_index.benefit: 533 | if cur_index.get_index_type() == 'global': 534 | candidate_indexes.remove(next_index) 535 | index_group[i + 1] = index_group[i] 536 | else: 537 | candidate_indexes.remove(cur_index) 538 | else: 539 | if cur_index.benefit < next_index.benefit: 540 | candidate_indexes.remove(cur_index) 541 | else: 542 | candidate_indexes.remove(next_index) 543 | index_group[i + 1] = index_group[i] 544 | 545 | 546 | def green(text): 547 | return '\033[32m%s\033[0m' % text 548 | 549 | 550 | def print_header_boundary(header): 551 | # Output a header first, which looks more beautiful. 552 | try: 553 | term_width = os.get_terminal_size().columns 554 | # Get the width of each of the two sides of the terminal. 
555 | side_width = (term_width - len(header)) // 2 556 | except (AttributeError, OSError): 557 | side_width = 0 558 | title = SHARP * side_width + header + SHARP * side_width 559 | bar_print(green(title)) 560 | 561 | 562 | def load_workload(file_path): 563 | wd_dict = {} 564 | workload = [] 565 | global BLANK 566 | with open(file_path, 'r', errors='ignore') as file: 567 | raw_text = ''.join(file.readlines()) 568 | sqls = sqlparse.split(raw_text) 569 | for sql in sqls: 570 | if any(re.search(r'((\A|[\s(,])%s[\s*(])' % tp, sql.lower()) for tp in SQL_TYPE): 571 | TWO_BLANKS = BLANK * 2 572 | while TWO_BLANKS in sql: 573 | sql = sql.replace(TWO_BLANKS, BLANK) 574 | if sql.strip() not in wd_dict.keys(): 575 | wd_dict[sql.strip()] = 1 576 | else: 577 | wd_dict[sql.strip()] += 1 578 | for sql, freq in wd_dict.items(): 579 | workload.append(QueryItem(sql, freq)) 580 | 581 | return workload 582 | 583 | 584 | def get_workload_template(workload): 585 | templates = {} 586 | placeholder = r'@@@' 587 | 588 | for item in workload: 589 | sql_template = item.get_statement() 590 | for pattern in SQL_PATTERN: 591 | sql_template = re.sub(pattern, placeholder, sql_template) 592 | if sql_template not in templates: 593 | templates[sql_template] = {} 594 | templates[sql_template]['cnt'] = 0 595 | templates[sql_template]['samples'] = [] 596 | templates[sql_template]['cnt'] += item.get_frequency() 597 | # reservoir sampling 598 | statement = item.get_statement() 599 | if has_dollar_placeholder(statement): 600 | statement = replace_function_comma(statement) 601 | statement = replace_comma_with_dollar(statement) 602 | if len(templates[sql_template]['samples']) < SAMPLE_NUM: 603 | templates[sql_template]['samples'].append(statement) 604 | else: 605 | if random.randint(0, templates[sql_template]['cnt']) < SAMPLE_NUM: 606 | templates[sql_template]['samples'][random.randint(0, SAMPLE_NUM - 1)] = \ 607 | statement 608 | 609 | return templates 610 | 611 | 612 | def compress_workload(input_path): 613 | compressed_workload = [] 614 | if isinstance(input_path, dict): 615 | templates = input_path 616 | elif JSON_TYPE: 617 | with open(input_path, 'r', errors='ignore') as file: 618 | templates = json.load(file) 619 | else: 620 | workload = load_workload(input_path) 621 | templates = get_workload_template(workload) 622 | 623 | for _, elem in templates.items(): 624 | for sql in elem['samples']: 625 | compressed_workload.append( 626 | QueryItem(sql.strip(), elem['cnt'] / len(elem['samples']))) 627 | 628 | return compressed_workload 629 | 630 | 631 | def generate_single_column_indexes(advised_indexes: List[AdvisedIndex]): 632 | """ Generate single column indexes. 
""" 633 | single_column_indexes = [] 634 | if len(advised_indexes) == 0: 635 | return single_column_indexes 636 | 637 | for index in advised_indexes: 638 | table = index.get_table() 639 | columns = index.get_columns() 640 | index_type = index.get_index_type() 641 | for column in columns.split(COLUMN_DELIMITER): 642 | single_column_index = IndexItemFactory().get_index(table, column, index_type) 643 | if single_column_index not in single_column_indexes: 644 | single_column_indexes.append(single_column_index) 645 | return single_column_indexes 646 | 647 | 648 | def add_more_column_index(indexes, table, columns_info, single_col_info): 649 | columns, columns_index_type = columns_info 650 | single_column, single_index_type = single_col_info 651 | if columns_index_type.strip('"') != single_index_type.strip('"'): 652 | add_more_column_index(indexes, table, (columns, 'local'), 653 | (single_column, 'local')) 654 | add_more_column_index(indexes, table, (columns, 'global'), 655 | (single_column, 'global')) 656 | else: 657 | current_columns_index = IndexItemFactory().get_index(table, columns + COLUMN_DELIMITER + single_column, 658 | columns_index_type) 659 | if current_columns_index in indexes: 660 | return 661 | # To make sure global is behind local 662 | if single_index_type == 'local': 663 | global_columns_index = IndexItemFactory().get_index(table, columns + COLUMN_DELIMITER + single_column, 664 | 'global') 665 | if global_columns_index in indexes: 666 | global_pos = indexes.index(global_columns_index) 667 | indexes[global_pos] = current_columns_index 668 | current_columns_index = global_columns_index 669 | indexes.append(current_columns_index) 670 | 671 | 672 | def query_index_advise(executor, query): 673 | """ Call the single-indexes-advisor in the database. 
""" 674 | 675 | sql = get_single_advisor_sql(query) 676 | results = executor.execute_sqls([sql]) 677 | advised_indexes = parse_single_advisor_results(results) 678 | 679 | return advised_indexes 680 | 681 | 682 | def get_index_storage(executor, hypo_index_id): 683 | sqls = get_hypo_index_head_sqls(is_multi_node(executor)) 684 | index_size_sqls = sqls + ['select * from pg_catalog.hypopg_estimate_size(%s);' % hypo_index_id] 685 | results = executor.execute_sqls(index_size_sqls) 686 | for cur_tuple in results: 687 | if re.match(r'\d+', str(cur_tuple[0]).strip()): 688 | return float(str(cur_tuple[0]).strip()) / 1024 / 1024 689 | 690 | 691 | def update_index_storage(indexes, hypo_index_ids, executor): 692 | if indexes: 693 | for index, hypo_index_id in zip(indexes, hypo_index_ids): 694 | storage = index.get_storage() 695 | if not storage: 696 | storage = get_index_storage(executor, hypo_index_id) 697 | index.set_storage(storage) 698 | 699 | 700 | def get_plan_cost(statements, executor): 701 | plan_sqls = [] 702 | plan_sqls.extend(get_hypo_index_head_sqls(is_multi_node(executor))) 703 | for statement in statements: 704 | plan_sqls.extend(get_prepare_sqls(statement)) 705 | results = executor.execute_sqls(plan_sqls) 706 | cost, index_names_list, plans = parse_explain_plan(results, len(statements)) 707 | return cost, index_names_list, plans 708 | 709 | 710 | def get_workload_costs(statements, executor, threads=20): 711 | costs = [] 712 | index_names_list = [] 713 | plans = [] 714 | statements_blocks = split_iter(statements, threads) 715 | try: 716 | with Pool(threads) as p: 717 | results = p.starmap(get_plan_cost, [[_statements, executor] for _statements in statements_blocks]) 718 | except TypeError: 719 | results = [get_plan_cost(statements, executor)] 720 | for _costs, _index_names_list, _plans in results: 721 | costs.extend(_costs) 722 | index_names_list.extend(_index_names_list) 723 | plans.extend(_plans) 724 | return costs, index_names_list, plans 725 | 726 | 727 | def estimate_workload_cost_file(executor, workload, indexes=None): 728 | select_queries = [] 729 | select_queries_pos = [] 730 | query_costs = [0] * len(workload.get_queries()) 731 | for i, query in enumerate(workload.get_queries()): 732 | select_queries.append(query.get_statement()) 733 | select_queries_pos.append(i) 734 | with hypo_index_ctx(executor): 735 | index_setting_sqls = get_index_setting_sqls(indexes, is_multi_node(executor)) 736 | hypo_index_ids = parse_hypo_index(executor.execute_sqls(index_setting_sqls)) 737 | update_index_storage(indexes, hypo_index_ids, executor) 738 | costs, index_names, plans = get_workload_costs([query.get_statement() for query in 739 | workload.get_queries()], executor) 740 | # Update query cost for select queries and positive_pos for indexes. 741 | for cost, query_pos in zip(costs, select_queries_pos): 742 | query_costs[query_pos] = cost * workload.get_queries()[query_pos].get_frequency() 743 | workload.add_indexes(indexes, query_costs, index_names, plans) 744 | 745 | 746 | def query_index_check(executor, query, indexes, sort_by_column_no=True): 747 | """ Obtain valid indexes based on the optimizer. """ 748 | valid_indexes = [] 749 | if len(indexes) == 0: 750 | return valid_indexes, None 751 | if sort_by_column_no: 752 | # When the cost values are the same, the execution plan picks the last index created. 753 | # Sort indexes to ensure that short indexes have higher priority. 
754 | indexes = sorted(indexes, key=lambda index: -len(index.get_columns())) 755 | index_check_results = executor.execute_sqls(get_index_check_sqls(query, indexes, is_multi_node(executor))) 756 | valid_indexes = get_checked_indexes(index_check_results, set(index.get_table() for index in indexes)) 757 | cost = None 758 | for res in index_check_results: 759 | if '(cost' in res[0]: 760 | cost = parse_plan_cost(res[0]) 761 | break 762 | return valid_indexes, cost 763 | 764 | 765 | def remove_unused_indexes(executor, statement, valid_indexes): 766 | """ Remove invalid indexes by creating virtual indexes in different order. """ 767 | least_indexes = valid_indexes 768 | for indexes in permutations(valid_indexes, len(valid_indexes)): 769 | cur_indexes, cost = query_index_check(executor, statement, indexes, False) 770 | if len(cur_indexes) < len(least_indexes): 771 | least_indexes = cur_indexes 772 | return least_indexes 773 | 774 | 775 | def filter_candidate_columns_by_cost(valid_indexes, statement, executor, max_candidate_columns): 776 | indexes = [] 777 | for table, index_group in groupby(valid_indexes, key=lambda x: x.get_table()): 778 | cost_index = [] 779 | index_group = list(index_group) 780 | if len(index_group) <= max_candidate_columns: 781 | indexes.extend(index_group) 782 | continue 783 | for _index in index_group: 784 | _indexes, _cost = query_index_check(executor, statement, [_index]) 785 | if _indexes: 786 | heapq.heappush(cost_index, (_cost, _indexes[0])) 787 | for _cost, _index in heapq.nsmallest(max_candidate_columns, cost_index): 788 | indexes.append(_index) 789 | return indexes 790 | 791 | 792 | def set_source_indexes(indexes, source_indexes): 793 | """Record the original index of the recommended index.""" 794 | for index in indexes: 795 | table = index.get_table() 796 | columns = index.get_columns() 797 | for source_index in source_indexes: 798 | if not source_index.get_source_index(): 799 | continue 800 | if not source_index.get_table() == table: 801 | continue 802 | if f'{columns}{COLUMN_DELIMITER}'.startswith(f'{source_index.get_columns()}{COLUMN_DELIMITER}'): 803 | index.set_source_index(source_index.get_source_index()) 804 | continue 805 | 806 | 807 | def get_valid_indexes(advised_indexes, original_base_indexes, statement, executor, **kwargs): 808 | need_check = False 809 | single_column_indexes = generate_single_column_indexes(advised_indexes) 810 | valid_indexes, cost = query_index_check(executor, statement, single_column_indexes) 811 | valid_indexes = filter_candidate_columns_by_cost(valid_indexes, statement, executor, 812 | kwargs.get('max_candidate_columns', MAX_CANDIDATE_COLUMNS)) 813 | valid_indexes, cost = query_index_check(executor, statement, valid_indexes) 814 | pre_indexes = valid_indexes[:] 815 | 816 | # Increase the number of index columns in turn and check their validity. 
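# Starting from the surviving single-column indexes, each pass widens candidates by one
# more column (up to MAX_INDEX_COLUMN_NUM); the loop stops early once an extra column no
# longer reduces the estimated plan cost by at least 5%.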
817 | for column_num in range(2, MAX_INDEX_COLUMN_NUM + 1): 818 | for table, index_group in groupby(valid_indexes, key=lambda x: x.get_table()): 819 | _original_base_indexes = [index for index in original_base_indexes if index.get_table() == table] 820 | for index in list(index_group) + _original_base_indexes: 821 | columns = index.get_columns() 822 | index_type = index.get_index_type() 823 | # only validate indexes with column number of column_num 824 | if index.get_columns_num() != column_num - 1: 825 | continue 826 | need_check = True 827 | for single_column_index in single_column_indexes: 828 | _table = single_column_index.get_table() 829 | if _table != table: 830 | continue 831 | single_column = single_column_index.get_columns() 832 | single_index_type = single_column_index.get_index_type() 833 | if single_column not in columns.split(COLUMN_DELIMITER): 834 | add_more_column_index(valid_indexes, table, (columns, index_type), 835 | (single_column, single_index_type)) 836 | if need_check: 837 | cur_indexes, cur_cost = query_index_check(executor, statement, valid_indexes) 838 | # If the cost reduction does not exceed 5%, return the previous indexes. 839 | if cur_cost is not None and cost / cur_cost < 1.05: 840 | set_source_indexes(pre_indexes, original_base_indexes) 841 | return pre_indexes 842 | valid_indexes = cur_indexes 843 | pre_indexes = valid_indexes[:] 844 | cost = cur_cost 845 | need_check = False 846 | else: 847 | break 848 | 849 | # filtering of functionally redundant indexes due to index order 850 | valid_indexes = remove_unused_indexes(executor, statement, valid_indexes) 851 | set_source_indexes(valid_indexes, original_base_indexes) 852 | return valid_indexes 853 | 854 | 855 | def get_redundant_created_indexes(indexes: List[ExistingIndex], unused_indexes: List[ExistingIndex]): 856 | sorted_indexes = sorted(indexes, key=lambda i: (i.get_table(), len(i.get_columns().split(COLUMN_DELIMITER)))) 857 | redundant_indexes = [] 858 | for table, index_group in groupby(sorted_indexes, key=lambda i: i.get_table()): 859 | cur_table_indexes = list(index_group) 860 | for pos, index in enumerate(cur_table_indexes[:-1]): 861 | is_redundant = False 862 | for next_index in cur_table_indexes[pos + 1:]: 863 | if match_columns(index.get_columns(), next_index.get_columns()): 864 | is_redundant = True 865 | index.redundant_objs.append(next_index) 866 | if is_redundant: 867 | redundant_indexes.append(index) 868 | remove_list = [] 869 | for pos, index in enumerate(redundant_indexes): 870 | is_redundant = False 871 | for redundant_obj in index.redundant_objs: 872 | # Redundant objects are not in the useless index set, or 873 | # both redundant objects and redundant index in the useless index must be redundant index. 874 | index_exist = redundant_obj not in unused_indexes or \ 875 | (redundant_obj in unused_indexes and index in unused_indexes) 876 | if index_exist: 877 | is_redundant = True 878 | if not is_redundant: 879 | remove_list.append(pos) 880 | for item in sorted(remove_list, reverse=True): 881 | redundant_indexes.pop(item) 882 | return redundant_indexes 883 | 884 | 885 | def record_history_invalid_indexes(history_indexes, history_invalid_indexes, indexes): 886 | for index in indexes: 887 | # Update historical indexes validity. 
888 | schema_table = index.get_schema_table() 889 | cur_columns = index.get_columns() 890 | if not history_indexes.get(schema_table): 891 | continue 892 | for column in history_indexes.get(schema_table, dict()): 893 | history_index_column = list(map(str.strip, column[0].split(','))) 894 | existed_index_column = list(map(str.strip, cur_columns[0].split(','))) 895 | if len(history_index_column) > len(existed_index_column): 896 | continue 897 | if history_index_column == existed_index_column[0:len(history_index_column)]: 898 | history_indexes[schema_table].remove(column) 899 | history_invalid_indexes[schema_table] = history_invalid_indexes.get( 900 | schema_table, list()) 901 | history_invalid_indexes[schema_table].append(column) 902 | if not history_indexes[schema_table]: 903 | del history_indexes[schema_table] 904 | 905 | 906 | @lru_cache(maxsize=None) 907 | def fetch_created_indexes(executor): 908 | schemas = [elem.lower() 909 | for elem in filter(None, executor.get_schema().split(','))] 910 | created_indexes = [] 911 | for schema in schemas: 912 | sql = "select tablename from pg_tables where schemaname = '%s'" % schema 913 | res = executor.execute_sqls([sql]) 914 | if not res: 915 | continue 916 | tables = parse_table_sql_results(res) 917 | if not tables: 918 | continue 919 | sql = get_existing_index_sql(schema, tables) 920 | res = executor.execute_sqls([sql]) 921 | if not res: 922 | continue 923 | _created_indexes = parse_existing_indexes_results(res, schema) 924 | created_indexes.extend(_created_indexes) 925 | 926 | return created_indexes 927 | 928 | 929 | def print_candidate_indexes(candidate_indexes): 930 | print_header_boundary(" Generate candidate indexes ") 931 | for index in candidate_indexes: 932 | table = index.get_table() 933 | columns = index.get_columns() 934 | index_type = index.get_index_type() 935 | if index.get_index_type(): 936 | bar_print("table: ", table, "columns: ", columns, "type: ", index_type) 937 | else: 938 | bar_print("table: ", table, "columns: ", columns) 939 | if not candidate_indexes: 940 | bar_print("No candidate indexes generated!") 941 | 942 | 943 | def index_sort_func(index): 944 | """ Sort indexes function. """ 945 | if index.get_index_type() == 'global': 946 | return index.get_table(), 0, index.get_columns() 947 | else: 948 | return index.get_table(), 1, index.get_columns() 949 | 950 | 951 | def filter_redundant_indexes_with_same_type(indexes: List[AdvisedIndex]): 952 | """ Filter redundant indexes with same index_type. 
""" 953 | candidate_indexes = [] 954 | for table, table_group_indexes in groupby(sorted(indexes, key=lambda x: x.get_table()), 955 | key=lambda x: x.get_table()): 956 | for index_type, index_type_group_indexes in groupby( 957 | sorted(table_group_indexes, key=lambda x: x.get_index_type()), key=lambda x: x.get_index_type()): 958 | column_sorted_indexes = sorted(index_type_group_indexes, key=lambda x: x.get_columns()) 959 | for i in range(len(column_sorted_indexes) - 1): 960 | if match_columns(column_sorted_indexes[i].get_columns(), column_sorted_indexes[i + 1].get_columns()): 961 | continue 962 | else: 963 | index = column_sorted_indexes[i] 964 | candidate_indexes.append(index) 965 | candidate_indexes.append(column_sorted_indexes[-1]) 966 | candidate_indexes.sort(key=index_sort_func) 967 | 968 | return candidate_indexes 969 | 970 | 971 | def add_query_indexes(indexes: List[AdvisedIndex], queries: List[QueryItem], pos): 972 | for table, index_group in groupby(indexes, key=lambda x: x.get_table()): 973 | _indexes = sorted(list(index_group), key=lambda x: -len(x.get_columns())) 974 | for _index in _indexes: 975 | if len(queries[pos].get_indexes()) >= FULL_ARRANGEMENT_THRESHOLD: 976 | break 977 | queries[pos].append_index(_index) 978 | 979 | 980 | def generate_query_placeholder_indexes(query, executor: BaseExecutor, n_distinct=0.01, reltuples=10000, 981 | use_all_columns=False): 982 | indexes = [] 983 | if not has_dollar_placeholder(query) and not use_all_columns: 984 | return [] 985 | parser = Parser(query) 986 | try: 987 | tables = [table.lower() for table in parser.tables] 988 | columns = [] 989 | for position, _columns in parser.columns_dict.items(): 990 | if position.upper() not in ['SELECT', 'INSERT', 'UPDATE']: 991 | columns.extend(_columns) 992 | flatten_columns = UniqueList() 993 | for column in flatten(columns): 994 | flatten_columns.append(column) 995 | logging.info(f'parsing query: {query}') 996 | logging.info(f'found tables: {" ".join(tables)}, columns: {" ".join(flatten_columns)}') 997 | except (ValueError, AttributeError, KeyError) as e: 998 | logging.warning('Found %s while parsing SQL statement.', e) 999 | return [] 1000 | for table in tables: 1001 | table_indexes = [] 1002 | table_context = get_table_context(table, executor) 1003 | if not table_context or table_context.reltuples < reltuples: 1004 | logging.info(f'filtered: table_context is {table_context} and does not meet the requirements') 1005 | continue 1006 | for column in flatten_columns: 1007 | if table_context.has_column(column) and table_context.get_n_distinct(column) <= n_distinct: 1008 | table_indexes.extend(generate_placeholder_indexes(table_context, column.split('.')[-1].lower())) 1009 | # top 20 for candidate indexes 1010 | indexes.extend(sorted(table_indexes, key=lambda x: table_context.get_n_distinct(x.get_columns()))[:20]) 1011 | logging.info(f'related indexes: {indexes}') 1012 | return indexes 1013 | 1014 | 1015 | def get_original_base_indexes(original_indexes: List[ExistingIndex]) -> List[AdvisedIndex]: 1016 | original_base_indexes = [] 1017 | for index in original_indexes: 1018 | table = f'{index.get_schema()}.{index.get_table()}' 1019 | columns = index.get_columns().split(COLUMN_DELIMITER) 1020 | index_type = index.get_index_type() 1021 | columns_length = len(columns) 1022 | for _len in range(1, columns_length): 1023 | _columns = COLUMN_DELIMITER.join(columns[:_len]) 1024 | original_base_indexes.append(IndexItemFactory().get_index(table, _columns, index_type)) 1025 | all_columns_index = 
IndexItemFactory().get_index(table, index.get_columns(), index_type) 1026 | original_base_indexes.append(all_columns_index) 1027 | all_columns_index.set_source_index(index) 1028 | return original_base_indexes 1029 | 1030 | 1031 | def generate_candidate_indexes(workload: WorkLoad, executor: BaseExecutor, n_distinct, reltuples, use_all_columns, 1032 | **kwargs): 1033 | all_indexes = [] 1034 | with executor.session(): 1035 | # Resolve the bug that indexes extended on top of the original index will not be recommended 1036 | # by building the base index related to the original index 1037 | original_indexes = fetch_created_indexes(executor) 1038 | original_base_indexes = get_original_base_indexes(original_indexes) 1039 | for pos, query in GLOBAL_PROCESS_BAR.process_bar(list(enumerate(workload.get_queries())), 'Candidate indexes'): 1040 | advised_indexes = [] 1041 | for advised_index in generate_query_placeholder_indexes(query.get_statement(), executor, n_distinct, 1042 | reltuples, use_all_columns, 1043 | ): 1044 | if advised_index not in advised_indexes: 1045 | advised_indexes.append(advised_index) 1046 | valid_indexes = get_valid_indexes(advised_indexes, original_base_indexes, query.get_statement(), executor, 1047 | **kwargs) 1048 | logging.info(f'get valid indexes: {valid_indexes} for the query {query}') 1049 | add_query_indexes(valid_indexes, workload.get_queries(), pos) 1050 | for index in valid_indexes: 1051 | if index not in all_indexes: 1052 | all_indexes.append(index) 1053 | 1054 | # Filter redundant indexes. 1055 | candidate_indexes = filter_redundant_indexes_with_same_type(all_indexes) 1056 | 1057 | if len(candidate_indexes) == 0: 1058 | estimate_workload_cost_file(executor, workload) 1059 | 1060 | return candidate_indexes 1061 | 1062 | 1063 | def powerset(iterable): 1064 | """ powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) """ 1065 | s = list(iterable) 1066 | return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) 1067 | 1068 | 1069 | def generate_sorted_atomic_config(queries: List[QueryItem], 1070 | candidate_indexes: List[AdvisedIndex]) -> List[Tuple[AdvisedIndex, ...]]: 1071 | atomic_config_total = [] 1072 | 1073 | for query in queries: 1074 | if len(query.get_indexes()) == 0: 1075 | continue 1076 | 1077 | indexes = [] 1078 | for i, (table, group) in enumerate(groupby(query.get_sorted_indexes(), lambda x: x.get_table())): 1079 | # The max number of table is 2. 1080 | if i > 1: 1081 | break 1082 | # The max index number for each table is 2. 1083 | indexes.extend(list(group)[:2]) 1084 | atomic_configs = powerset(indexes) 1085 | for new_config in atomic_configs: 1086 | if new_config not in atomic_config_total: 1087 | atomic_config_total.append(new_config) 1088 | # Make sure atomic_config_total contains candidate_indexes. 
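    # Every candidate is also added as a one-element configuration, presumably so
    # the benefit estimator always has a standalone sample for each index even if
    # it never appeared in any per-query combination above.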
1089 | for index in candidate_indexes: 1090 | if (index,) not in atomic_config_total: 1091 | atomic_config_total.append((index,)) 1092 | return atomic_config_total 1093 | 1094 | 1095 | def generate_atomic_config_containing_same_columns(candidate_indexes: List[AdvisedIndex]) \ 1096 | -> List[Tuple[AdvisedIndex, AdvisedIndex]]: 1097 | atomic_configs = [] 1098 | for _, _indexes in groupby(sorted(candidate_indexes, key=lambda index: (index.get_table(), index.get_index_type())), 1099 | key=lambda index: (index.get_table(), index.get_index_type())): 1100 | _indexes = list(_indexes) 1101 | _indexes.sort(key=lambda index: len(index.get_columns().split(COLUMN_DELIMITER))) 1102 | for short_index_idx in range(len(_indexes) - 1): 1103 | short_columns = set(_indexes[short_index_idx].get_columns().split(COLUMN_DELIMITER)) 1104 | for long_index_idx in range(short_index_idx + 1, len(_indexes)): 1105 | long_columns = set(_indexes[long_index_idx].get_columns().split(COLUMN_DELIMITER)) 1106 | if not (short_columns - long_columns): 1107 | atomic_configs.append((_indexes[short_index_idx], _indexes[long_index_idx])) 1108 | 1109 | return atomic_configs 1110 | 1111 | 1112 | def display_redundant_indexes(redundant_indexes: List[ExistingIndex]): 1113 | if not redundant_indexes: 1114 | bar_print("No redundant indexes!") 1115 | # Display redundant indexes. 1116 | for index in redundant_indexes: 1117 | if index.get_is_unique() or index.is_primary_key(): 1118 | continue 1119 | statement = "DROP INDEX %s.%s;(%s)" % (index.get_schema(), index.get_indexname(), index.get_indexdef()) 1120 | bar_print(statement) 1121 | bar_print('Related indexes:') 1122 | for _index in index.redundant_objs: 1123 | _statement = "\t%s" % (_index.get_indexdef()) 1124 | bar_print(_statement) 1125 | bar_print('') 1126 | 1127 | 1128 | def record_redundant_indexes(redundant_indexes: List[ExistingIndex], detail_info): 1129 | for index in redundant_indexes: 1130 | statement = "DROP INDEX %s.%s;" % (index.get_schema(), index.get_indexname()) 1131 | existing_index = [item.get_indexname() + ':' + 1132 | item.get_columns() for item in index.redundant_objs] 1133 | redundant_index = {"schemaName": index.get_schema(), "tbName": index.get_table(), 1134 | "type": IndexType.REDUNDANT.value, 1135 | "columns": index.get_columns(), "statement": statement, 1136 | "existingIndex": existing_index} 1137 | detail_info['uselessIndexes'].append(redundant_index) 1138 | 1139 | 1140 | def display_useless_redundant_indexes(created_indexes, workload_indexnames, detail_info): 1141 | unused_indexes = [index for index in created_indexes if index.get_indexname() not in workload_indexnames] 1142 | print_header_boundary(" Current workload useless indexes ") 1143 | detail_info['uselessIndexes'] = [] 1144 | has_unused_index = False 1145 | 1146 | for cur_index in unused_indexes: 1147 | if (not cur_index.get_is_unique()) and (not cur_index.is_primary_key()): 1148 | has_unused_index = True 1149 | statement = "DROP INDEX %s;" % cur_index.get_indexname() 1150 | bar_print(statement) 1151 | useless_index = {"schemaName": cur_index.get_schema(), "tbName": cur_index.get_table(), 1152 | "type": IndexType.INVALID.value, 1153 | "columns": cur_index.get_columns(), "statement": statement} 1154 | detail_info['uselessIndexes'].append(useless_index) 1155 | 1156 | if not has_unused_index: 1157 | bar_print("No useless indexes!") 1158 | print_header_boundary(" Redundant indexes ") 1159 | redundant_indexes = get_redundant_created_indexes(created_indexes, unused_indexes) 1160 | 
display_redundant_indexes(redundant_indexes) 1161 | record_redundant_indexes(redundant_indexes, detail_info) 1162 | 1163 | 1164 | def greedy_determine_opt_config(workload: WorkLoad, atomic_config_total: List[Tuple[AdvisedIndex]], 1165 | candidate_indexes: List[AdvisedIndex]): 1166 | opt_config = [] 1167 | candidate_indexes_copy = candidate_indexes[:] 1168 | for i in range(len(candidate_indexes_copy)): 1169 | cur_max_benefit = 0 1170 | cur_index = None 1171 | for index in candidate_indexes_copy: 1172 | cur_config = copy.copy(opt_config) 1173 | cur_config.append(index) 1174 | cur_estimated_benefit = infer_workload_benefit(workload, cur_config, atomic_config_total) 1175 | if cur_estimated_benefit > cur_max_benefit: 1176 | cur_max_benefit = cur_estimated_benefit 1177 | cur_index = index 1178 | if cur_index: 1179 | if len(opt_config) == MAX_INDEX_NUM: 1180 | break 1181 | opt_config.append(cur_index) 1182 | candidate_indexes_copy.remove(cur_index) 1183 | else: 1184 | break 1185 | 1186 | return opt_config 1187 | 1188 | 1189 | def get_last_indexes_result(input_path): 1190 | last_indexes_result_file = os.path.join(os.path.realpath( 1191 | os.path.dirname(input_path)), 'index_result.json') 1192 | integrate_indexes = {'historyIndexes': {}} 1193 | if os.path.exists(last_indexes_result_file): 1194 | try: 1195 | with open(last_indexes_result_file, 'r', errors='ignore') as file: 1196 | integrate_indexes['historyIndexes'] = json.load(file) 1197 | except json.JSONDecodeError: 1198 | return integrate_indexes 1199 | return integrate_indexes 1200 | 1201 | 1202 | def recalculate_cost_for_opt_indexes(workload: WorkLoad, indexes: Tuple[AdvisedIndex]): 1203 | """After the recommended indexes are all built, calculate the gain of each index.""" 1204 | all_used_index_names = workload.get_workload_used_indexes(indexes) 1205 | for query, used_index_names in zip(workload.get_queries(), all_used_index_names): 1206 | cost = workload.get_indexes_cost_of_query(query, indexes) 1207 | origin_cost = workload.get_indexes_cost_of_query(query, None) 1208 | query_benefit = origin_cost - cost 1209 | query.set_benefit(query_benefit) 1210 | query.reset_opt_indexes() 1211 | if not query_benefit > 0: 1212 | continue 1213 | for index in indexes: 1214 | for index_name in used_index_names: 1215 | if index.match_index_name(index_name): 1216 | index.append_positive_query(query) 1217 | query.append_index(index) 1218 | 1219 | 1220 | def filter_no_benefit_indexes(indexes): 1221 | for index in indexes[:]: 1222 | if not index.get_positive_queries(): 1223 | indexes.remove(index) 1224 | logging.info('remove no benefit index {index}') 1225 | 1226 | 1227 | def index_advisor_workload(history_advise_indexes, executor: BaseExecutor, workload_file_path, 1228 | multi_iter_mode: bool, show_detail: bool, n_distinct: float, reltuples: int, 1229 | use_all_columns: bool, **kwargs): 1230 | queries = compress_workload(workload_file_path) 1231 | queries = [query for query in queries if is_valid_statement(executor, query.get_statement())] 1232 | workload = WorkLoad(queries) 1233 | candidate_indexes = generate_candidate_indexes(workload, executor, n_distinct, reltuples, use_all_columns, **kwargs) 1234 | print_candidate_indexes(candidate_indexes) 1235 | index_advisor = IndexAdvisor(executor, workload, multi_iter_mode) 1236 | if candidate_indexes: 1237 | print_header_boundary(" Determine optimal indexes ") 1238 | with executor.session(): 1239 | if multi_iter_mode: 1240 | opt_indexes = index_advisor.complex_index_advisor(candidate_indexes) 1241 | else: 1242 | 
opt_indexes = index_advisor.simple_index_advisor(candidate_indexes) 1243 | if opt_indexes: 1244 | index_advisor.filter_low_benefit_index(opt_indexes, kwargs.get('improved_rate', 0)) 1245 | if index_advisor.determine_indexes: 1246 | estimate_workload_cost_file(executor, workload, tuple(index_advisor.determine_indexes)) 1247 | recalculate_cost_for_opt_indexes(workload, tuple(index_advisor.determine_indexes)) 1248 | determine_indexes = index_advisor.determine_indexes[:] 1249 | filter_no_benefit_indexes(index_advisor.determine_indexes) 1250 | index_advisor.determine_indexes.sort(key=lambda index: -sum(query.get_benefit() 1251 | for query in index.get_positive_queries())) 1252 | workload.replace_indexes(tuple(determine_indexes), tuple(index_advisor.determine_indexes)) 1253 | 1254 | index_advisor.display_advise_indexes_info(show_detail) 1255 | created_indexes = fetch_created_indexes(executor) 1256 | if kwargs.get('show_benefits'): 1257 | index_advisor.print_benefits(created_indexes) 1258 | index_advisor.generate_incremental_index(history_advise_indexes) 1259 | history_invalid_indexes = {} 1260 | with executor.session(): 1261 | index_advisor.generate_redundant_useless_indexes(history_invalid_indexes) 1262 | index_advisor.display_incremental_index( 1263 | history_invalid_indexes, workload_file_path) 1264 | if show_detail: 1265 | print_header_boundary(" Display detail information ") 1266 | sql_info = json.dumps( 1267 | index_advisor.display_detail_info, indent=4, separators=(',', ':')) 1268 | bar_print(sql_info) 1269 | return index_advisor.display_detail_info, index_advisor.index_benefits, index_advisor.redundant_indexes 1270 | 1271 | 1272 | def check_parameter(args): 1273 | global MAX_INDEX_NUM, MAX_INDEX_STORAGE, JSON_TYPE, MAX_INDEX_COLUMN_NUM 1274 | if args.max_index_num is not None and args.max_index_num <= 0: 1275 | raise argparse.ArgumentTypeError("%s is an invalid positive int value" % 1276 | args.max_index_num) 1277 | if args.max_candidate_columns <= 0: 1278 | raise argparse.ArgumentTypeError("%s is an invalid positive int value" % 1279 | args.max_candidate_columns) 1280 | if args.max_index_columns <= 0: 1281 | raise argparse.ArgumentTypeError("%s is an invalid positive int value" % 1282 | args.max_index_columns) 1283 | if args.max_index_storage is not None and args.max_index_storage <= 0: 1284 | raise argparse.ArgumentTypeError("%s is an invalid positive int value" % 1285 | args.max_index_storage) 1286 | if args.max_n_distinct <= 0 or args.max_n_distinct > 1: 1287 | raise argparse.ArgumentTypeError( 1288 | '%s is an invalid max-n-distinct which ranges from 0 to 1' % args.max_n_distinct) 1289 | if args.min_improved_rate < 0 or args.min_improved_rate >= 1: 1290 | raise argparse.ArgumentTypeError( 1291 | '%s is an invalid min-improved-rate which must be greater than ' 1292 | 'or equal to 0 and less than 1' % args.min_improved_rate) 1293 | if args.min_reltuples <= 0: 1294 | raise argparse.ArgumentTypeError('%s is an invalid positive int value' % args.min_reltuples) 1295 | JSON_TYPE = args.json 1296 | MAX_INDEX_NUM = args.max_index_num 1297 | MAX_INDEX_STORAGE = args.max_index_storage 1298 | MAX_INDEX_COLUMN_NUM = args.max_index_columns 1299 | # Check if the password contains illegal characters. 
1300 | is_legal = re.search(r'^[A-Za-z0-9~!@#$%^&*()-_=+\|\[{}\];:,<.>/?]+$', args.W) 1301 | if not is_legal: 1302 | raise ValueError("The password contains illegal characters.") 1303 | 1304 | 1305 | def main(argv): 1306 | arg_parser = argparse.ArgumentParser( 1307 | description='Generate index set for workload.') 1308 | arg_parser.add_argument("db_port", help="Port of database", type=int) 1309 | arg_parser.add_argument("database", help="Name of database", action=CheckWordValid) 1310 | arg_parser.add_argument( 1311 | "--db-host", "--h", help="Host for database", action=CheckWordValid) 1312 | arg_parser.add_argument( 1313 | "-U", "--db-user", help="Username for database log-in", action=CheckWordValid) 1314 | arg_parser.add_argument( 1315 | "file", type=path_type, help="File containing workload queries (One query per line)", action=CheckWordValid) 1316 | arg_parser.add_argument("--schema", help="Schema name for the current business data", 1317 | required=True, action=CheckWordValid) 1318 | arg_parser.add_argument( 1319 | "--max-index-num", "--max_index_num", help="Maximum number of suggested indexes", type=int) 1320 | arg_parser.add_argument("--max-index-storage", "--max_index_storage", 1321 | help="Maximum storage of suggested indexes/MB", type=int) 1322 | arg_parser.add_argument("--multi-iter-mode", "--multi_iter_mode", action='store_true', 1323 | help="Whether to use multi-iteration algorithm", default=False) 1324 | arg_parser.add_argument("--max-n-distinct", type=float, 1325 | help="Maximum n_distinct value (reciprocal of the distinct number)" 1326 | " for the index column.", 1327 | default=0.01) 1328 | arg_parser.add_argument("--min-improved-rate", type=float, 1329 | help="Minimum improved rate of the cost for the indexes", 1330 | default=0.1) 1331 | arg_parser.add_argument("--max-candidate-columns", type=int, 1332 | help='Maximum number of columns for candidate indexes', 1333 | default=MAX_CANDIDATE_COLUMNS) 1334 | arg_parser.add_argument('--max-index-columns', type=int, 1335 | help='Maximum number of columns in a joint index', 1336 | default=4) 1337 | arg_parser.add_argument("--min-reltuples", type=int, 1338 | help="Minimum reltuples value for the index column.", default=10000) 1339 | arg_parser.add_argument("--multi-node", "--multi_node", action='store_true', 1340 | help="Whether to support distributed scenarios", default=False) 1341 | arg_parser.add_argument("--json", action='store_true', 1342 | help="Whether the workload file format is json", default=False) 1343 | arg_parser.add_argument("--driver", action='store_true', 1344 | help="Whether to employ python-driver", default=False) 1345 | arg_parser.add_argument("--show-detail", "--show_detail", action='store_true', 1346 | help="Whether to show detailed sql information", default=False) 1347 | arg_parser.add_argument("--show-benefits", action='store_true', 1348 | help="Whether to show index benefits", default=False) 1349 | args = arg_parser.parse_args(argv) 1350 | 1351 | set_logger() 1352 | args.W = get_password() 1353 | check_parameter(args) 1354 | # Initialize the connection. 
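    # Prefer the psycopg2-based DriverExecutor when --driver is given; if the
    # driver (or psycopg2 itself) cannot be imported, fall back to the GsqlExecutor.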
1355 | if args.driver: 1356 | try: 1357 | import psycopg2 1358 | try: 1359 | from .executors.driver_executor import DriverExecutor 1360 | except ImportError: 1361 | from executors.driver_executor import DriverExecutor 1362 | executor = DriverExecutor(args.database, args.db_user, args.W, args.db_host, args.db_port, args.schema) 1363 | except ImportError: 1364 | logging.warning('Python driver import failed, ' 1365 | 'the gsql mode will be selected to connect to the database.') 1366 | 1367 | executor = GsqlExecutor(args.database, args.db_user, args.W, args.db_host, args.db_port, args.schema) 1368 | args.driver = None 1369 | else: 1370 | executor = GsqlExecutor(args.database, args.db_user, args.W, args.db_host, args.db_port, args.schema) 1371 | use_all_columns = True 1372 | index_advisor_workload(get_last_indexes_result(args.file), executor, args.file, 1373 | args.multi_iter_mode, args.show_detail, args.max_n_distinct, args.min_reltuples, 1374 | use_all_columns, improved_rate=args.min_improved_rate, 1375 | max_candidate_columns=args.max_candidate_columns, show_benefits=args.show_benefits) 1376 | 1377 | 1378 | if __name__ == '__main__': 1379 | main(sys.argv[1:]) 1380 | -------------------------------------------------------------------------------- /mcts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 13 | 14 | import sys 15 | import math 16 | import random 17 | import copy 18 | from collections import defaultdict 19 | try: 20 | from utils import infer_workload_benefit 21 | except ImportError: 22 | from .utils import infer_workload_benefit 23 | 24 | TOTAL_STORAGE = 0 25 | STORAGE_THRESHOLD = 0 26 | AVAILABLE_CHOICES = None 27 | ATOMIC_CHOICES = None 28 | WORKLOAD = None 29 | MAX_INDEX_NUM = 0 30 | CANDIDATE_SUBSET = defaultdict(list) 31 | CANDIDATE_SUBSET_BENEFIT = defaultdict(list) 32 | 33 | 34 | def find_best_benefit(choice): 35 | if choice[-1] in CANDIDATE_SUBSET.keys() and set(choice[:-1]) in CANDIDATE_SUBSET[choice[-1]]: 36 | return CANDIDATE_SUBSET_BENEFIT[choice[-1]][CANDIDATE_SUBSET[choice[-1]].index(set(choice[:-1]))] 37 | 38 | total_benefit = infer_workload_benefit(WORKLOAD, choice, ATOMIC_CHOICES) 39 | 40 | CANDIDATE_SUBSET[choice[-1]].append(set(choice[:-1])) 41 | CANDIDATE_SUBSET_BENEFIT[choice[-1]].append(total_benefit) 42 | return total_benefit 43 | 44 | 45 | def get_diff(available_choices, choices): 46 | return set(available_choices).difference(set(choices)) 47 | 48 | 49 | class State(object): 50 | """ 51 | The game state of the Monte Carlo tree search, 52 | the state data recorded under a specific Node node, 53 | including the current game score, the current number of game rounds, 54 | and the execution record from the beginning to the current. 55 | 56 | It is necessary to realize whether the current state has reached the end of the game state, 57 | and support the operation of randomly fetching from the Action collection. 
58 | """ 59 | 60 | def __init__(self): 61 | self.current_storage = 0.0 62 | self.current_benefit = 0.0 63 | # record the sum of choices up to the current state 64 | self.accumulation_choices = [] 65 | # record available choices of current state 66 | self.available_choices = [] 67 | self.displayable_choices = [] 68 | 69 | def reset_state(self): 70 | self.set_available_choices(set(AVAILABLE_CHOICES).difference(self.accumulation_choices)) 71 | 72 | def get_available_choices(self): 73 | return self.available_choices 74 | 75 | def set_available_choices(self, choices): 76 | self.available_choices = choices 77 | 78 | def get_current_storage(self): 79 | return self.current_storage 80 | 81 | def set_current_storage(self, value): 82 | self.current_storage = value 83 | 84 | def get_current_benefit(self): 85 | return self.current_benefit 86 | 87 | def set_current_benefit(self, value): 88 | self.current_benefit = value 89 | 90 | def get_accumulation_choices(self): 91 | return self.accumulation_choices 92 | 93 | def set_accumulation_choices(self, choices): 94 | self.accumulation_choices = choices 95 | 96 | def is_terminal(self): 97 | # The current node is a leaf node. 98 | return len(self.accumulation_choices) == MAX_INDEX_NUM 99 | 100 | def get_next_state_with_random_choice(self): 101 | # Ensure that the choices taken are not repeated. 102 | if not self.available_choices: 103 | return None 104 | random_choice = random.choice([choice for choice in self.available_choices]) 105 | self.available_choices.remove(random_choice) 106 | choice = copy.copy(self.accumulation_choices) 107 | choice.append(random_choice) 108 | benefit = find_best_benefit(choice) + self.current_benefit 109 | # If the current choice does not satisfy restrictions, then continue to get the next choice. 110 | if benefit <= self.current_benefit or \ 111 | self.current_storage + random_choice.get_storage() > STORAGE_THRESHOLD: 112 | return self.get_next_state_with_random_choice() 113 | 114 | next_state = State() 115 | # Initialize the properties of the new state. 116 | next_state.set_accumulation_choices(choice) 117 | next_state.set_current_benefit(benefit) 118 | next_state.set_current_storage(self.current_storage + random_choice.get_storage()) 119 | next_state.set_available_choices(get_diff(AVAILABLE_CHOICES, choice)) 120 | return next_state 121 | 122 | def __repr__(self): 123 | self.displayable_choices = ['{}: {}'.format(choice.get_table(), choice.get_columns()) 124 | for choice in self.accumulation_choices] 125 | return "reward: {}, storage :{}, choices: {}".format( 126 | self.current_benefit, self.current_storage, self.displayable_choices) 127 | 128 | 129 | class Node(object): 130 | """ 131 | The Node of the Monte Carlo tree search tree contains the parent node and 132 | current point information, 133 | which is used to calculate the traversal times and quality value of the UCB, 134 | and the State of the Node selected by the game. 
135 | """ 136 | def __init__(self): 137 | self.visit_number = 0 138 | self.quality = 0.0 139 | 140 | self.parent = None 141 | self.children = [] 142 | self.state = None 143 | 144 | def reset_node(self): 145 | self.visit_number = 0 146 | self.quality = 0.0 147 | self.children = [] 148 | 149 | def get_parent(self): 150 | return self.parent 151 | 152 | def set_parent(self, parent): 153 | self.parent = parent 154 | 155 | def get_children(self): 156 | return self.children 157 | 158 | def expand_child(self, node): 159 | node.set_parent(self) 160 | self.children.append(node) 161 | 162 | def set_state(self, state): 163 | self.state = state 164 | 165 | def get_state(self): 166 | return self.state 167 | 168 | def get_visit_number(self): 169 | return self.visit_number 170 | 171 | def set_visit_number(self, number): 172 | self.visit_number = number 173 | 174 | def update_visit_number(self): 175 | self.visit_number += 1 176 | 177 | def get_quality_value(self): 178 | return self.quality 179 | 180 | def set_quality_value(self, value): 181 | self.quality = value 182 | 183 | def update_quality_value(self, reward): 184 | self.quality += reward 185 | 186 | def is_all_expand(self): 187 | return False if self.state.available_choices else True 188 | 189 | def __repr__(self): 190 | return "Node: {}, Q/N: {}/{}, State: {}".format( 191 | hash(self), self.quality, self.visit_number, self.state) 192 | 193 | 194 | def tree_policy(node): 195 | """ 196 | In the Selection and Expansion stages of the Monte Carlo tree search, 197 | the node that needs to be searched (such as the root node) is passed in, 198 | and the best node that needs to be expanded is returned 199 | according to the exploration/exploitation algorithm. 200 | Note that if the node is a leaf node, it will be returned directly. 201 | 202 | The basic strategy is first to find the child nodes that have not been selected 203 | and pick them randomly if there is more than one. Then, if both are selected, 204 | find the one with the largest UCB value that has weighed exploration/exploitation, 205 | and randomly choose if the UCB values are equal. 206 | """ 207 | 208 | # Check if the current node is a leaf node. 209 | while node and not node.get_state().is_terminal(): 210 | 211 | if node.is_all_expand(): 212 | if not node.children: 213 | return node 214 | node = best_child(node, True) 215 | else: 216 | # Return the new sub-node. 217 | sub_node = expand(node) 218 | if sub_node: 219 | return sub_node 220 | # When no node satisfies the condition in the remaining nodes, this state is terminal. 221 | return node 222 | 223 | # Return the leaf node. 224 | return node 225 | 226 | 227 | def default_policy(node): 228 | """ 229 | In the Simulation stage of the Monte Carlo tree search, input a node that needs to be expanded, 230 | create a new node after a random operation, and return the reward of the new node. 231 | Note that the input node should not be a child node, 232 | and there are unexecuted Actions that can be expendable. 233 | 234 | The basic strategy is to choose the Action at random. 235 | """ 236 | 237 | # Get the state of the game. 238 | current_state = copy.deepcopy(node.get_state()) 239 | current_state.set_accumulation_choices(copy.copy(node.get_state().get_accumulation_choices())) 240 | current_state.set_available_choices(copy.copy(node.get_state().get_available_choices())) 241 | 242 | # Run until the game is over. 243 | while not current_state.is_terminal(): 244 | # Pick one random action to play and get the next state. 
245 | next_state = current_state.get_next_state_with_random_choice() 246 | if not next_state: 247 | break 248 | current_state = next_state 249 | 250 | final_state_reward = current_state.get_current_benefit() 251 | return final_state_reward 252 | 253 | 254 | def expand(node): 255 | """ 256 | Enter a node, expand a new node on the node, use the random method to execute the Action, 257 | and return the new node. Note that it is necessary to ensure that the newly 258 | added nodes differ from other node Action. 259 | """ 260 | 261 | new_state = node.get_state().get_next_state_with_random_choice() 262 | if not new_state: 263 | return None 264 | sub_node = Node() 265 | sub_node.set_state(new_state) 266 | node.expand_child(sub_node) 267 | 268 | return sub_node 269 | 270 | 271 | def best_child(node, is_exploration): 272 | """ 273 | Using the UCB algorithm, 274 | select the child node with the highest score after weighing the exploration and exploitation. 275 | Note that the current Q-value score with the highest score is directly chosen if it is in the prediction stage. 276 | """ 277 | 278 | best_score = -sys.maxsize 279 | best_sub_node = None 280 | 281 | # Travel all-sub nodes to find the best one. 282 | for sub_node in node.get_children(): 283 | # The children nodes of the node contain the children node whose state is empty, 284 | # this kind of node comes from the node that does not meet the conditions. 285 | if not sub_node.get_state(): 286 | continue 287 | # Explore constants. 288 | if is_exploration: 289 | C = 1 / math.sqrt(2.0) 290 | else: 291 | C = 0.0 292 | # UCB = quality / times + C * sqrt(2 * ln(total_times) / times) 293 | left = sub_node.get_quality_value() / sub_node.get_visit_number() 294 | right = 2.0 * math.log(node.get_visit_number()) / sub_node.get_visit_number() 295 | score = left + C * math.sqrt(right) 296 | # Get the maximum score while filtering nodes that do not meet the space constraints and 297 | # nodes that have no revenue 298 | if score > best_score \ 299 | and sub_node.get_state().get_current_storage() <= STORAGE_THRESHOLD \ 300 | and sub_node.get_state().get_current_benefit() > 0: 301 | best_sub_node = sub_node 302 | best_score = score 303 | 304 | return best_sub_node 305 | 306 | 307 | def backpropagate(node, reward): 308 | """ 309 | In the Backpropagation stage of the Monte Carlo tree search, 310 | input the node that needs to be expended and the reward of the newly executed Action, 311 | feed it back to the expend node and all upstream nodes, 312 | and update the corresponding data. 313 | """ 314 | 315 | # Update until the root node. 316 | while node is not None: 317 | # Update the visit number. 318 | node.update_visit_number() 319 | 320 | # Update the quality value. 321 | node.update_quality_value(reward) 322 | 323 | # Change the node to the parent node. 324 | node = node.parent 325 | 326 | 327 | def monte_carlo_tree_search(node): 328 | """ 329 | Implement the Monte Carlo tree search algorithm, pass in a root node, 330 | expand new nodes and update data according to the 331 | tree structure that has been explored before in a limited time, 332 | and then return as long as the child node with the highest exploitation. 333 | 334 | When making predictions, 335 | you only need to select the node with the largest exploitation according to the Q value, 336 | and find the next optimal node. 337 | """ 338 | 339 | computation_budget = len(AVAILABLE_CHOICES) * STORAGE_THRESHOLD / TOTAL_STORAGE * 50 340 | 341 | # Run as much as possible under the computation budget. 
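    # The budget scales with the number of candidate indexes and with the share of
    # the total candidate storage allowed by the threshold. Each iteration runs the
    # classic MCTS cycle: select/expand a node via UCB
    # (quality / visits + C * sqrt(2 * ln(parent_visits) / visits)),
    # simulate a random completion, then backpropagate its benefit as the reward.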
342 | for i in range(int(computation_budget)): 343 | # 1. find the best node to expand. 344 | expand_node = tree_policy(node) 345 | 346 | # 2. random get next action and get reward. 347 | reward = default_policy(expand_node) 348 | 349 | # 3. update all passing nodes with reward. 350 | backpropagate(expand_node, reward) 351 | 352 | # Get the best next node. 353 | best_next_node = best_child(node, False) 354 | 355 | return best_next_node 356 | 357 | 358 | def MCTS(workload_info, atomic_choices, available_choices, storage_threshold, max_index_num): 359 | global ATOMIC_CHOICES, STORAGE_THRESHOLD, WORKLOAD, \ 360 | AVAILABLE_CHOICES, MAX_INDEX_NUM, TOTAL_STORAGE 361 | WORKLOAD = workload_info 362 | AVAILABLE_CHOICES = available_choices 363 | ATOMIC_CHOICES = atomic_choices 364 | STORAGE_THRESHOLD = storage_threshold 365 | MAX_INDEX_NUM = max_index_num if max_index_num and max_index_num < len(available_choices) \ 366 | else len(available_choices) 367 | for index in available_choices: 368 | TOTAL_STORAGE += index.get_storage() 369 | if STORAGE_THRESHOLD >= TOTAL_STORAGE: 370 | return sorted(available_choices, key=lambda x: x.benefit, reverse=True)[:MAX_INDEX_NUM] 371 | # Create the initialized state and initialized node. 372 | init_state = State() 373 | init_node = Node() 374 | init_node.set_state(init_state) 375 | current_node = init_node 376 | 377 | opt_config = [] 378 | # Set the rounds to play. 379 | for i in range(len(AVAILABLE_CHOICES)): 380 | if current_node: 381 | current_node.reset_node() 382 | current_node.state.reset_state() 383 | current_node = monte_carlo_tree_search(current_node) 384 | if current_node: 385 | opt_config = current_node.state.accumulation_choices 386 | else: 387 | break 388 | return opt_config 389 | -------------------------------------------------------------------------------- /process_bar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 
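# This module draws a one-line textual progress bar (a '>' segment for the
# completed fraction, '*' padding for the remainder, plus the elapsed time) and
# exposes bar_print, which uses the ANSI cursor-up escape so that ordinary output
# can be printed above an active bar.
#
# A minimal standalone sketch of the same bar formatting; the helper name
# _example_render_bar is illustrative only and not used by the module:
def _example_render_bar(title, done, total, length=20, elapsed=0.0):
    percent = done / total
    return "{}: {:^3.0f}%[{}{}]{:.2f}s".format(
        title, percent * 100,
        '>' * int(percent * length),
        '*' * (length - int(percent * length)),
        elapsed)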
13 | 14 | import os 15 | import re 16 | import time 17 | 18 | 19 | try: 20 | TERMINAL_SIZE = os.get_terminal_size().columns 21 | except (AttributeError, OSError): 22 | TERMINAL_SIZE = 60 23 | 24 | 25 | class ProcessBar: 26 | LENGTH = TERMINAL_SIZE - 40 27 | 28 | def __init__(self): 29 | self.start = time.perf_counter() 30 | self.iterable = None 31 | self.title = None 32 | self.percent = 0 33 | self.process_num = 0 34 | 35 | def process_bar(self, iterable, title): 36 | self.iterable = iterable 37 | self.title = title 38 | self.percent = 0 39 | self.process_num = 0 40 | return self 41 | 42 | def __get_time(self): 43 | return time.perf_counter() - self.start 44 | 45 | def __processbar(self): 46 | bar_print(self.__output()) 47 | 48 | def __output(self): 49 | return "{}: {:^3.0f}%[{}{}]{:.2f}s".format(self.title, self.percent * 100, 50 | '>' * int(self.percent * ProcessBar.LENGTH), 51 | '*' * (ProcessBar.LENGTH - int(self.percent * ProcessBar.LENGTH)), 52 | self.__get_time()) 53 | 54 | @staticmethod 55 | def match(content): 56 | p = re.compile('[*>]+') 57 | res = p.search(str(content)) 58 | if res: 59 | return len(res.group()) == ProcessBar.LENGTH 60 | 61 | def __iter__(self): 62 | return self 63 | 64 | def __next__(self): 65 | self.process_num += 1 66 | if self.process_num > len(self.iterable): 67 | raise StopIteration 68 | self.percent = self.process_num / len(self.iterable) 69 | self.__processbar() 70 | 71 | return self.iterable[self.process_num - 1] 72 | 73 | 74 | def _print_wrap(): 75 | last_content = '' 76 | 77 | def inner_bar_print(*content): 78 | nonlocal last_content 79 | if ProcessBar.match(last_content): 80 | print(f'\x1b[1A{" " * TERMINAL_SIZE}\r') 81 | size = len(content) 82 | print('\x1b[1A' + (' '.join(['{}'] * size)).format(*content)) 83 | if ProcessBar.match(content[0]): 84 | last_content = content[0] 85 | else: 86 | print(last_content) 87 | else: 88 | if ProcessBar.match(content[0]): 89 | print(content[0]) 90 | last_content = content[0] 91 | 92 | return inner_bar_print 93 | 94 | 95 | bar_print = _print_wrap() 96 | -------------------------------------------------------------------------------- /sql_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 13 | 14 | from itertools import count 15 | from functools import lru_cache 16 | 17 | try: 18 | from .utils import get_placeholders, has_dollar_placeholder, replace_comma_with_dollar, replace_function_comma 19 | except ImportError: 20 | from utils import get_placeholders, has_dollar_placeholder, replace_comma_with_dollar, replace_function_comma 21 | 22 | counter = count(start=0, step=1) 23 | 24 | 25 | def get_existing_index_sql(schema, tables): 26 | tables_string = ','.join(["'%s'" % table for table in tables]) 27 | # Query all table indexes information and primary key information. 
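    # The catalog query joins pg_index with pg_class (once for the table, once for
    # the index), pg_namespace for the schema filter, and pg_constraint so that each
    # returned row also shows whether the index backs a primary key ('p' in pkey).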
28 | sql = "SELECT c.relname AS tablename, i.relname AS indexname, " \ 29 | "pg_catalog.pg_get_indexdef(i.oid) AS indexdef, p.contype AS pkey from " \ 30 | "pg_index x JOIN pg_class c ON c.oid = x.indrelid JOIN " \ 31 | "pg_class i ON i.oid = x.indexrelid LEFT JOIN pg_namespace n " \ 32 | "ON n.oid = c.relnamespace LEFT JOIN pg_constraint p ON (i.oid = p.conindid " \ 33 | "AND p.contype = 'p') WHERE (c.relkind = ANY (ARRAY['r'::\"char\", " \ 34 | "'m'::\"char\"])) AND (i.relkind = ANY (ARRAY['i'::\"char\", 'I'::\"char\"])) " \ 35 | "AND n.nspname = '%s' AND c.relname in (%s) order by c.relname;" % \ 36 | (schema, tables_string) 37 | return sql 38 | 39 | 40 | @lru_cache(maxsize=None) 41 | def get_prepare_sqls(statement): 42 | if has_dollar_placeholder(statement): 43 | statement = replace_function_comma(statement) 44 | statement = replace_comma_with_dollar(statement) 45 | prepare_id = 'prepare_' + str(next(counter)) 46 | placeholder_size = len(get_placeholders(statement)) 47 | prepare_args = '' if not placeholder_size else '(%s)' % (','.join(['NULL'] * placeholder_size)) 48 | return [f'prepare {prepare_id} as {statement}', f'explain execute {prepare_id}{prepare_args}', 49 | f'deallocate prepare {prepare_id}'] 50 | 51 | 52 | def get_workload_cost_sqls(statements, indexes, is_multi_node): 53 | sqls = [] 54 | if indexes: 55 | # Create hypo-indexes. 56 | sqls.append('SET enable_hypo_index = on;\n') 57 | for index in indexes: 58 | sqls.append("SELECT pg_catalog.hypopg_create_index('CREATE INDEX ON %s(%s) %s');" % 59 | (index.get_table(), index.get_columns(), index.get_index_type())) 60 | if is_multi_node: 61 | sqls.append('set enable_fast_query_shipping = off;') 62 | sqls.append('set enable_stream_operator = on; ') 63 | sqls.append("set explain_perf_mode = 'normal'; ") 64 | for index, statement in enumerate(statements): 65 | sqls.extend(get_prepare_sqls(statement)) 66 | return sqls 67 | 68 | 69 | def get_index_setting_sqls(indexes, is_multi_node): 70 | sqls = get_hypo_index_head_sqls(is_multi_node)[:] 71 | if indexes: 72 | # Create hypo-indexes. 
73 | for index in indexes: 74 | sqls.append("SELECT pg_catalog.hypopg_create_index('CREATE INDEX ON %s(%s) %s');" % 75 | (index.get_table(), index.get_columns(), index.get_index_type())) 76 | return sqls 77 | 78 | 79 | def get_single_advisor_sql(ori_sql): 80 | advisor_sql = 'select pg_catalog.gs_index_advise(\'' 81 | for elem in ori_sql: 82 | if elem == '\'': 83 | advisor_sql += '\'' 84 | advisor_sql += elem 85 | advisor_sql += '\');' 86 | return advisor_sql 87 | 88 | 89 | @lru_cache(maxsize=None) 90 | def get_hypo_index_head_sqls(is_multi_node): 91 | sqls = ['SET enable_hypo_index = on;'] 92 | if is_multi_node: 93 | sqls.append('SET enable_fast_query_shipping = off;') 94 | sqls.append('SET enable_stream_operator = on;') 95 | sqls.append("set explain_perf_mode = 'normal'; ") 96 | return sqls 97 | 98 | 99 | def get_index_check_sqls(query, indexes, is_multi_node): 100 | sqls = get_hypo_index_head_sqls(is_multi_node)[:] 101 | for index in indexes: 102 | table = index.get_table() 103 | columns = index.get_columns() 104 | index_type = index.get_index_type() 105 | sqls.append("SELECT pg_catalog.hypopg_create_index('CREATE INDEX ON %s(%s) %s')" % 106 | (table, columns, index_type)) 107 | sqls.append('SELECT pg_catalog.hypopg_display_index()') 108 | sqls.append("SET explain_perf_mode = 'normal';") 109 | sqls.extend(get_prepare_sqls(query)) 110 | sqls.append('SELECT pg_catalog.hypopg_reset_index()') 111 | return sqls 112 | 113 | 114 | def get_table_info_sql(table, schema): 115 | return f"select reltuples, parttype from pg_class where relname ilike '{table}' and " \ 116 | f"relnamespace = (select oid from pg_namespace where nspname = '{schema}');" 117 | 118 | 119 | def get_column_info_sql(table, schema): 120 | return f"select n_distinct, attname from pg_stats where tablename ilike '{table}' " \ 121 | f"and schemaname = '{schema}';" 122 | -------------------------------------------------------------------------------- /sql_output_parser.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 
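# This module turns raw execution output back into structured data: existing index
# definitions, hypopg index ids, estimated plan costs and the index names a plan
# actually used.
#
# A minimal standalone sketch of the cost extraction that parse_plan_cost performs
# further below, assuming the usual optimizer cost line format; the helper name
# _example_total_cost is illustrative only and not used by the module:
def _example_total_cost(line="Limit  (cost=19932.04..19933.29 rows=100 width=17)"):
    """Return the plan's total (upper-bound) cost, 19933.29 for the sample line."""
    import re
    matched = re.search(r'\(cost=([\d.]+)\.\.([\d.]+) rows=\d+ width=\d+\)', line)
    return float(matched.group(2)) if matched else -1.0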
13 | 14 | import re 15 | from typing import List 16 | import logging 17 | 18 | from sqlparse.tokens import Punctuation, Keyword, Name 19 | 20 | try: 21 | from utils import match_table_name, IndexItemFactory, ExistingIndex, AdvisedIndex, get_tokens, UniqueList, \ 22 | QUERY_PLAN_SUFFIX, EXPLAIN_SUFFIX, ERROR_KEYWORD, PREPARE_KEYWORD 23 | except ImportError: 24 | from .utils import match_table_name, IndexItemFactory, ExistingIndex, AdvisedIndex, get_tokens, UniqueList, \ 25 | QUERY_PLAN_SUFFIX, EXPLAIN_SUFFIX, ERROR_KEYWORD, PREPARE_KEYWORD 26 | 27 | 28 | def __get_columns_from_indexdef(indexdef): 29 | for content in get_tokens(indexdef): 30 | if content.ttype is Punctuation and content.normalized == '(': 31 | return content.parent.value.strip()[1:-1] 32 | 33 | 34 | def __is_unique_from_indexdef(indexdef): 35 | for content in get_tokens(indexdef): 36 | if content.ttype is Keyword: 37 | return content.value.upper() == 'UNIQUE' 38 | 39 | 40 | def __get_index_type_from_indexdef(indexdef): 41 | for content in get_tokens(indexdef): 42 | if content.ttype is Name: 43 | if content.value.upper() == 'LOCAL': 44 | return 'local' 45 | elif content.value.upper() == 'GLOBAL': 46 | return 'global' 47 | 48 | 49 | def parse_existing_indexes_results(results, schema) -> List[ExistingIndex]: 50 | indexes = list() 51 | indexdef_list = [] 52 | table = index = pkey = None 53 | for cur_tuple in results: 54 | if len(cur_tuple) == 1: 55 | continue 56 | else: 57 | temptable, tempindex, indexdef, temppkey = cur_tuple 58 | if temptable and tempindex: 59 | table, index, pkey = temptable, tempindex, temppkey 60 | if indexdef.endswith('+'): 61 | if len(indexdef_list) >= 1: 62 | if indexdef.startswith('SUBPARTITION'): 63 | indexdef_list.append(' ' * 8 + indexdef.strip(' +')) 64 | else: 65 | indexdef_list.append(' ' * 4 + indexdef.strip(' +')) 66 | else: 67 | indexdef_list.append(indexdef.strip(' +')) 68 | continue 69 | elif indexdef_list and indexdef.startswith(')'): 70 | indexdef_list.append(indexdef.strip().strip('+').strip()) 71 | indexdef = '\n'.join(indexdef_list) 72 | indexdef_list = [] 73 | cur_columns = __get_columns_from_indexdef(indexdef) 74 | is_unique = __is_unique_from_indexdef(indexdef) 75 | index_type = __get_index_type_from_indexdef(indexdef) 76 | cur_index = ExistingIndex( 77 | schema, table, index, cur_columns, indexdef) 78 | if pkey: 79 | cur_index.set_is_primary_key(True) 80 | if is_unique: 81 | cur_index.set_is_unique() 82 | if index_type: 83 | cur_index.set_index_type(index_type) 84 | indexes.append(cur_index) 85 | return indexes 86 | 87 | 88 | def parse_table_sql_results(table_sql_results): 89 | tables = [] 90 | for cur_tuple in table_sql_results: 91 | text = cur_tuple[0] 92 | if 'tablename' in text or re.match(r'-+', text) or re.match(r'\(\d+ rows?\)', text) \ 93 | or text.strip().startswith('SELECT '): 94 | continue 95 | tables.append(text.strip()) 96 | return tables 97 | 98 | 99 | def parse_hypo_index(results): 100 | hypo_index_ids = [] 101 | for cur_tuple in results: 102 | text = cur_tuple[0] 103 | if 'btree' in text: 104 | hypo_index_id = text.strip().strip('()').split(',')[0] 105 | hypo_index_ids.append(hypo_index_id) 106 | return hypo_index_ids 107 | 108 | 109 | def parse_explain_plan(results, query_num): 110 | # record execution plan for each explain statement (the parameter results contain multiple explain results) 111 | plans = [] 112 | plan = [] 113 | index_names_list = [] 114 | found_plan = False 115 | plan_start = False 116 | costs = [] 117 | i = 0 118 | index_names = UniqueList() 119 | 
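    # Scan the output line by line: a QUERY PLAN / EXPLAIN header starts collecting
    # the next statement's plan, an ERROR line (other than prepared-statement errors)
    # records a zero cost for that statement, the first '(cost=' line of a plan
    # supplies its estimated cost, and 'Index ... Scan' lines collect the names of
    # the indexes the plan used.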
for cur_tuple in results: 120 | text = cur_tuple[0] 121 | # Save the results of the last index_names according to the EXPLAIN keyword. 122 | if QUERY_PLAN_SUFFIX in text or text == EXPLAIN_SUFFIX: 123 | index_names_list.append(index_names) 124 | index_names = UniqueList() 125 | plans.append(plan) 126 | plan = [] 127 | found_plan = True 128 | plan_start = True 129 | continue 130 | if plan_start: 131 | plan.append(cur_tuple[0]) 132 | # Consider execution errors and ensure that the cost value of an explain is counted only once. 133 | if ERROR_KEYWORD in text and 'prepared statement' not in text: 134 | if i >= query_num: 135 | logging.info(f'Cannot correct parse the explain results: {results}') 136 | raise ValueError("The size of queries is not correct!") 137 | costs.append(0) 138 | index_names_list.append(index_names) 139 | index_names = UniqueList() 140 | i += 1 141 | if found_plan and '(cost=' in text: 142 | if i >= query_num: 143 | logging.info(f'Cannot correct parse the explain results: {results}') 144 | raise ValueError("The size of queries is not correct!") 145 | query_cost = parse_plan_cost(text) 146 | costs.append(query_cost) 147 | found_plan = False 148 | i += 1 149 | if 'Index' in text and 'Scan' in text: 150 | ind1, ind2 = re.search(r'Index.*Scan(.*)on ([^\s]+)', 151 | text.strip(), re.IGNORECASE).groups() 152 | if ind1.strip(): 153 | # `Index (Only)? Scan (Backward)? using index1` 154 | if ind1.strip().split(' ')[-1] not in index_names: 155 | index_names.append(ind1.strip().split(' ')[-1]) 156 | else: 157 | index_names.append(ind2) 158 | index_names_list.append(index_names) 159 | index_names_list = index_names_list[1:] 160 | plans.append(plan) 161 | plans = plans[1:] 162 | 163 | # when a syntax error causes multiple explain queries to be run as one query 164 | while len(index_names_list) < query_num: 165 | index_names_list.append([]) 166 | plans.append([]) 167 | while i < query_num: 168 | costs.append(0) 169 | i += 1 170 | return costs, index_names_list, plans 171 | 172 | 173 | def parse_plan_cost(line): 174 | """ Parse the explain plan to get the estimated cost by database optimizer. """ 175 | cost = -1 176 | # like "Limit (cost=19932.04..19933.29 rows=100 width=17)" 177 | pattern = re.compile(r'\(cost=([^)]*)\)', re.S) 178 | matched_res = re.search(pattern, line) 179 | if matched_res and len(matched_res.group(1).split()) == 3: 180 | _cost, _rows, _width = matched_res.group(1).split() 181 | # like cost=19932.04..19933.29 182 | cost = float(_cost.split('..')[-1]) 183 | return cost 184 | 185 | 186 | def parse_single_advisor_results(results) -> List[AdvisedIndex]: 187 | indexes = [] 188 | for cur_tuple in results: 189 | res = cur_tuple[0] 190 | schema_idx = 0 191 | table_idx = 1 192 | index_type_idx = -1 193 | columns_slice = slice(2, -1) 194 | # like '(1 row)' or (2 rows) 195 | if res.strip().endswith('rows)') or res.strip().endswith(' row)'): 196 | continue 197 | # like ' (public,date_dim,d_year,global)' or ' (public,store_sales,"ss_sold_date_sk,ss_item_sk","")' 198 | if len(res) > 2 and res.strip()[0:1] == '(': 199 | items = res.strip().split(',') 200 | table = items[schema_idx][1:] + '.' 
+ items[table_idx] 201 | columns = ','.join(items[columns_slice]).strip('\"') 202 | if columns == '': 203 | continue 204 | if items[index_type_idx].strip(') ') not in ['global', 'local']: 205 | index_type = '' 206 | else: 207 | index_type = items[index_type_idx].strip(') ') 208 | indexes.append(IndexItemFactory().get_index(table, columns, index_type)) 209 | return indexes 210 | 211 | 212 | def __add_valid_index(record, hypoid_table_column, valid_indexes: list): 213 | # like 'Index Scan using <134667>btree_global_item_i_manufact_id on item (cost=0.00..68.53 rows=16 width=59)' 214 | tokens = record.split(' ') 215 | for token in tokens: 216 | if 'btree' in token: 217 | if 'btree_global_' in token: 218 | index_type = 'global' 219 | elif 'btree_local_' in token: 220 | index_type = 'local' 221 | else: 222 | index_type = '' 223 | hypo_index_id = re.search( 224 | r'\d+', token.split('_', 1)[0]).group() 225 | table_columns = hypoid_table_column.get(hypo_index_id) 226 | if not table_columns: 227 | continue 228 | table, columns = table_columns.split(':') 229 | index = IndexItemFactory().get_index(table, columns, index_type) 230 | if index not in valid_indexes: 231 | valid_indexes.append(index) 232 | 233 | 234 | def get_checked_indexes(index_check_results, tables) -> list: 235 | valid_indexes = [] 236 | hypoid_table_column = {} 237 | hypo_index_info_length = 4 238 | btree_idx = 0 239 | index_id_idx = 1 240 | table_idx = 2 241 | columns_idx = 3 242 | for cur_tuple in index_check_results: 243 | # like '(<134672>btree_local_customer_c_customer_sk,134672,customer,"(c_customer_sk)")' 244 | text = cur_tuple[0] 245 | if text.strip().startswith('(<') and 'btree' in text: 246 | if len(text.split(',', 3)) == hypo_index_info_length: 247 | hypo_index_info = text.split(',', 3) 248 | table_name = re.search(r'btree(_global|_local|)_(.*?%s)' % hypo_index_info[table_idx], 249 | hypo_index_info[btree_idx]).group(2) 250 | match_flag, table_name = match_table_name(table_name, tables) 251 | if not match_flag: 252 | return valid_indexes 253 | hypoid_table_column[hypo_index_info[index_id_idx]] = \ 254 | table_name + ':' + hypo_index_info[columns_idx].strip('"()') 255 | 256 | if 'Index' in text and 'Scan' in text and 'btree' in text: 257 | __add_valid_index(text, hypoid_table_column, valid_indexes) 258 | return valid_indexes 259 | -------------------------------------------------------------------------------- /table.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 
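# This module caches per-table statistics (reltuples, partition type and per-column
# n_distinct values) read from the system catalogs and exposes them via TableContext.
#
# pg_stats.n_distinct follows the usual convention: a positive value is an absolute
# number of distinct values, a negative value is minus the ratio of distinct values
# to the row count. get_n_distinct converts either form into an estimated
# selectivity, as in this minimal sketch (_example_selectivity is illustrative only):
def _example_selectivity(n_distinct, reltuples):
    if n_distinct == 0:
        return 1
    # n_distinct = 50,   reltuples = 10000  ->  1 / 50            = 0.02
    # n_distinct = -0.2, reltuples = 10000  ->  1 / (0.2 * 10000) = 0.0005
    return 1 / (-n_distinct * reltuples) if n_distinct < 0 else 1 / n_distinct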
13 | 14 | from dataclasses import dataclass, field 15 | from typing import List 16 | from functools import lru_cache 17 | 18 | try: 19 | from .sql_generator import get_table_info_sql, get_column_info_sql 20 | from .executors.common import BaseExecutor 21 | from .utils import IndexItemFactory 22 | except ImportError: 23 | from sql_generator import get_table_info_sql, get_column_info_sql 24 | from executors.common import BaseExecutor 25 | from utils import IndexItemFactory 26 | 27 | 28 | @lru_cache(maxsize=None) 29 | def get_table_context(origin_table, executor: BaseExecutor): 30 | reltuples, parttype = None, None 31 | if '.' in origin_table: 32 | schemas, table = origin_table.split('.') 33 | else: 34 | table = origin_table 35 | schemas = executor.get_schema() 36 | for _schema in schemas.split(','): 37 | table_info_sqls = [get_table_info_sql(table, _schema)] 38 | column_info_sqls = [get_column_info_sql(table, _schema)] 39 | for _tuple in executor.execute_sqls(table_info_sqls): 40 | if len(_tuple) == 2: 41 | reltuples, parttype = _tuple 42 | reltuples = int(float(reltuples)) 43 | if not reltuples: 44 | continue 45 | is_partitioned_table = True if parttype == 'p' else False 46 | columns = [] 47 | n_distincts = [] 48 | for _tuple in executor.execute_sqls(column_info_sqls): 49 | if len(_tuple) != 2: 50 | continue 51 | n_distinct, column = _tuple 52 | if column not in columns: 53 | columns.append(column) 54 | n_distincts.append(float(n_distinct)) 55 | table_context = TableContext(_schema, table, int(reltuples), columns, n_distincts, is_partitioned_table) 56 | return table_context 57 | 58 | 59 | @dataclass(eq=False) 60 | class TableContext: 61 | schema: str 62 | table: str 63 | reltuples: int 64 | columns: List = field(default_factory=lambda: []) 65 | n_distincts: List = field(default_factory=lambda: []) 66 | is_partitioned_table: bool = field(default=False) 67 | 68 | @lru_cache(maxsize=None) 69 | def has_column(self, column): 70 | is_same_table = True 71 | if '.' in column: 72 | if column.split('.')[0].upper() != self.table.split('.')[-1].upper(): 73 | is_same_table = False 74 | column = column.split('.')[1].lower() 75 | return is_same_table and column in self.columns 76 | 77 | @lru_cache(maxsize=None) 78 | def get_n_distinct(self, column): 79 | column = column.split('.')[-1].lower() 80 | idx = self.columns.index(column) 81 | n_distinct = self.n_distincts[idx] 82 | if float(n_distinct) == float(0): 83 | return 1 84 | return 1 / (-n_distinct * self.reltuples) if n_distinct < 0 else 1 / n_distinct 85 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Huawei Technologies Co.,Ltd. 2 | # 3 | # openGauss is licensed under Mulan PSL v2. 4 | # You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | # You may obtain a copy of Mulan PSL v2 at: 6 | # 7 | # http://license.coscl.org.cn/MulanPSL2 8 | # 9 | # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 10 | # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, 11 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 12 | # See the Mulan PSL v2 for more details. 
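# This module holds the shared building blocks of the advisor: the ExistingIndex and
# AdvisedIndex classes, the IndexItemFactory cache, the UniqueList helper and
# assorted SQL/text utilities. As a worked example of AdvisedIndex.get_index_statement
# defined below, an advised index on table 'public.date_dim' with columns 'd_year'
# and index type 'global' renders:
#     CREATE INDEX idx_date_dim_global_d_year ON public.date_dim(d_year) global;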
13 | 14 | import re 15 | from collections import defaultdict 16 | from enum import Enum 17 | from functools import lru_cache 18 | from typing import List, Tuple, Sequence, Any 19 | from contextlib import contextmanager 20 | 21 | import sqlparse 22 | from sqlparse.tokens import Name 23 | from sqlparse.sql import Function, Parenthesis, IdentifierList 24 | 25 | COLUMN_DELIMITER = ', ' 26 | QUERY_PLAN_SUFFIX = 'QUERY PLAN' 27 | EXPLAIN_SUFFIX = 'EXPLAIN' 28 | ERROR_KEYWORD = 'ERROR' 29 | PREPARE_KEYWORD = 'PREPARE' 30 | 31 | 32 | class QueryType(Enum): 33 | INEFFECTIVE = 0 34 | POSITIVE = 1 35 | NEGATIVE = 2 36 | 37 | 38 | class IndexType(Enum): 39 | ADVISED = 1 40 | REDUNDANT = 2 41 | INVALID = 3 42 | 43 | 44 | def replace_function_comma(statement): 45 | """Replace the ? in function to the corresponding value to ensure that prepare execution can be executed properly""" 46 | function_value = {'count': '1', 'decode': "'1'"} 47 | new_statement = '' 48 | for token in get_tokens(statement): 49 | value = token.value 50 | if token.ttype is Name.Placeholder and token.value == '?': 51 | function_token = None 52 | if isinstance(token.parent, Parenthesis) and isinstance(token.parent.parent, Function): 53 | function_token = token.parent.parent 54 | elif isinstance(token.parent, IdentifierList) \ 55 | and isinstance(token.parent.parent, Parenthesis) \ 56 | and isinstance(token.parent.parent.parent, Function): 57 | function_token = token.parent.parent.parent 58 | if function_token: 59 | replaced_value = function_value.get(function_token.get_name().lower(), None) 60 | value = replaced_value if replaced_value else value 61 | new_statement += value 62 | return new_statement 63 | 64 | 65 | class UniqueList(list): 66 | 67 | def append(self, item: Any) -> None: 68 | if item not in self: 69 | super().append(item) 70 | 71 | def extend(self, items: Sequence[Any]) -> None: 72 | for item in items: 73 | self.append(item) 74 | 75 | 76 | class ExistingIndex: 77 | 78 | def __init__(self, schema, table, indexname, columns, indexdef): 79 | self.__schema = schema 80 | self.__table = table 81 | self.__indexname = indexname 82 | self.__columns = columns 83 | self.__indexdef = indexdef 84 | self.__primary_key = False 85 | self.__is_unique = False 86 | self.__index_type = '' 87 | self.redundant_objs = [] 88 | 89 | def set_is_unique(self): 90 | self.__is_unique = True 91 | 92 | def get_is_unique(self): 93 | return self.__is_unique 94 | 95 | def set_index_type(self, index_type): 96 | self.__index_type = index_type 97 | 98 | def get_index_type(self): 99 | return self.__index_type 100 | 101 | def get_table(self): 102 | return self.__table 103 | 104 | def get_schema(self): 105 | return self.__schema 106 | 107 | def get_indexname(self): 108 | return self.__indexname 109 | 110 | def get_columns(self): 111 | return self.__columns 112 | 113 | def get_indexdef(self): 114 | return self.__indexdef 115 | 116 | def is_primary_key(self): 117 | return self.__primary_key 118 | 119 | def set_is_primary_key(self, is_primary_key: bool): 120 | self.__primary_key = is_primary_key 121 | 122 | def get_schema_table(self): 123 | return self.__schema + '.' 
+ self.__table 124 | 125 | def __str__(self): 126 | return f'{self.__schema}, {self.__table}, {self.__indexname}, {self.__columns}, {self.__indexdef})' 127 | 128 | def __repr__(self): 129 | return self.__str__() 130 | 131 | 132 | class AdvisedIndex: 133 | def __init__(self, tbl, cols, index_type=None): 134 | self.__table = tbl 135 | self.__columns = cols 136 | self.benefit = 0 137 | self.__storage = 0 138 | self.__index_type = index_type 139 | self.association_indexes = defaultdict(list) 140 | self.__positive_queries = [] 141 | self.__source_index = None 142 | 143 | def set_source_index(self, source_index: ExistingIndex): 144 | self.__source_index = source_index 145 | 146 | def get_source_index(self): 147 | return self.__source_index 148 | 149 | def append_positive_query(self, query): 150 | self.__positive_queries.append(query) 151 | 152 | def get_positive_queries(self): 153 | return self.__positive_queries 154 | 155 | def set_storage(self, storage): 156 | self.__storage = storage 157 | 158 | def get_storage(self): 159 | return self.__storage 160 | 161 | def get_table(self): 162 | return self.__table 163 | 164 | def get_schema(self): 165 | return self.__table.split('.')[0] 166 | 167 | def get_columns(self): 168 | return self.__columns 169 | 170 | def get_columns_num(self): 171 | return len(self.get_columns().split(COLUMN_DELIMITER)) 172 | 173 | def get_index_type(self): 174 | return self.__index_type 175 | 176 | def get_index_statement(self): 177 | table_name = self.get_table().split('.')[-1] 178 | index_name = 'idx_%s_%s%s' % (table_name, (self.get_index_type() + '_' if self.get_index_type() else ''), 179 | '_'.join(self.get_columns().split(COLUMN_DELIMITER)) 180 | ) 181 | statement = 'CREATE INDEX %s ON %s%s%s;' % (index_name, self.get_table(), 182 | '(' + self.get_columns() + ')', 183 | (' ' + self.get_index_type() if self.get_index_type() else '') 184 | ) 185 | return statement 186 | 187 | def set_association_indexes(self, association_indexes_name, association_benefit): 188 | self.association_indexes[association_indexes_name].append(association_benefit) 189 | 190 | def match_index_name(self, index_name): 191 | schema = self.get_schema() 192 | if schema == 'public': 193 | return index_name.endswith(f'btree_{self.get_index_type() + "_" if self.get_index_type() else ""}' 194 | f'{self.get_table().split(".")[-1]}_' 195 | f'{"_".join(self.get_columns().split(COLUMN_DELIMITER))}') 196 | else: 197 | return index_name.endswith(f'btree_{self.get_index_type() + "_" if self.get_index_type() else ""}' 198 | f'{self.get_table().replace(".", "_")}_' 199 | f'{"_".join(self.get_columns().split(COLUMN_DELIMITER))}') 200 | 201 | def __str__(self): 202 | return f'table: {self.__table} columns: {self.__columns} index_type: ' \ 203 | f'{self.__index_type} storage: {self.__storage}' 204 | 205 | def __repr__(self): 206 | return self.__str__() 207 | 208 | 209 | def singleton(cls): 210 | instances = {} 211 | 212 | def _singleton(*args, **kwargs): 213 | if cls not in instances: 214 | instances[cls] = cls(*args, **kwargs) 215 | return instances[cls] 216 | 217 | return _singleton 218 | 219 | 220 | @singleton 221 | class IndexItemFactory: 222 | def __init__(self): 223 | self.indexes = {} 224 | 225 | def get_index(self, tbl, cols, index_type): 226 | if COLUMN_DELIMITER not in cols: 227 | cols = cols.replace(',', COLUMN_DELIMITER) 228 | if not (tbl, cols, index_type) in self.indexes: 229 | self.indexes[(tbl, cols, index_type)] = AdvisedIndex(tbl, cols, index_type=index_type) 230 | return self.indexes[(tbl, cols, 
index_type)] 231 | 232 | 233 | def match_table_name(table_name, tables): 234 | for elem in tables: 235 | item_tmp = '_'.join(elem.split('.')) 236 | if table_name == item_tmp: 237 | table_name = elem 238 | break 239 | elif 'public_' + table_name == item_tmp: 240 | table_name = 'public.' + table_name 241 | break 242 | else: 243 | return False, table_name 244 | return True, table_name 245 | 246 | 247 | class QueryItem: 248 | __valid_index_list: List[AdvisedIndex] 249 | 250 | def __init__(self, sql: str, freq: float): 251 | self.__statement = sql 252 | self.__frequency = freq 253 | self.__valid_index_list = [] 254 | self.__benefit = 0 255 | 256 | def get_statement(self): 257 | return self.__statement 258 | 259 | def get_frequency(self): 260 | return self.__frequency 261 | 262 | def append_index(self, index): 263 | self.__valid_index_list.append(index) 264 | 265 | def get_indexes(self): 266 | return self.__valid_index_list 267 | 268 | def reset_opt_indexes(self): 269 | self.__valid_index_list = [] 270 | 271 | def get_sorted_indexes(self): 272 | return sorted(self.__valid_index_list, key=lambda x: (x.get_table(), x.get_columns(), x.get_index_type())) 273 | 274 | def set_benefit(self, benefit): 275 | self.__benefit = benefit 276 | 277 | def get_benefit(self): 278 | return self.__benefit 279 | 280 | def __str__(self): 281 | return f'statement: {self.get_statement()} frequency: {self.get_frequency()} ' \ 282 | f'index_list: {self.__valid_index_list} benefit: {self.__benefit}' 283 | 284 | def __repr__(self): 285 | return self.__str__() 286 | 287 | 288 | class WorkLoad: 289 | def __init__(self, queries: List[QueryItem]): 290 | self.__indexes_list = [] 291 | self.__queries = queries 292 | self.__index_names_list = [[] for _ in range(len(self.__queries))] 293 | self.__indexes_costs = [[] for _ in range(len(self.__queries))] 294 | self.__plan_list = [[] for _ in range(len(self.__queries))] 295 | 296 | def get_queries(self) -> List[QueryItem]: 297 | return self.__queries 298 | 299 | def has_indexes(self, indexes: Tuple[AdvisedIndex]): 300 | return indexes in self.__indexes_list 301 | 302 | def get_used_index_names(self): 303 | used_indexes = set() 304 | for index_names in self.get_workload_used_indexes(None): 305 | for index_name in index_names: 306 | used_indexes.add(index_name) 307 | return used_indexes 308 | 309 | @lru_cache(maxsize=None) 310 | def get_workload_used_indexes(self, indexes: (Tuple[AdvisedIndex], None)): 311 | return list([index_names[self.__indexes_list.index(indexes if indexes else None)] 312 | for index_names in self.__index_names_list]) 313 | 314 | def get_query_advised_indexes(self, indexes, query): 315 | query_idx = self.__queries.index(query) 316 | indexes_idx = self.__indexes_list.index(indexes if indexes else None) 317 | used_index_names = self.__index_names_list[indexes_idx][query_idx] 318 | used_advised_indexes = [] 319 | for index in indexes: 320 | for index_name in used_index_names: 321 | if index.match(index_name): 322 | used_advised_indexes.append(index) 323 | return used_advised_indexes 324 | 325 | def set_index_benefit(self): 326 | for indexes in self.__indexes_list: 327 | if indexes and len(indexes) == 1: 328 | indexes[0].benefit = self.get_index_benefit(indexes[0]) 329 | 330 | def replace_indexes(self, origin, new): 331 | if not new: 332 | new = None 333 | self.__indexes_list[self.__indexes_list.index(origin if origin else None)] = new 334 | 335 | @lru_cache(maxsize=None) 336 | def get_total_index_cost(self, indexes: (Tuple[AdvisedIndex], None)): 337 | return sum( 338 
| query_index_cost[self.__indexes_list.index(indexes if indexes else None)] for query_index_cost in 339 | self.__indexes_costs) 340 | 341 | @lru_cache(maxsize=None) 342 | def get_total_origin_cost(self): 343 | return self.get_total_index_cost(None) 344 | 345 | @lru_cache(maxsize=None) 346 | def get_indexes_benefit(self, indexes: Tuple[AdvisedIndex]): 347 | return self.get_total_origin_cost() - self.get_total_index_cost(indexes) 348 | 349 | @lru_cache(maxsize=None) 350 | def get_index_benefit(self, index: AdvisedIndex): 351 | return self.get_indexes_benefit(tuple([index])) 352 | 353 | @lru_cache(maxsize=None) 354 | def get_indexes_cost_of_query(self, query: QueryItem, indexes: (Tuple[AdvisedIndex], None)): 355 | return self.__indexes_costs[self.__queries.index(query)][ 356 | self.__indexes_list.index(indexes if indexes else None)] 357 | 358 | @lru_cache(maxsize=None) 359 | def get_indexes_plan_of_query(self, query: QueryItem, indexes: (Tuple[AdvisedIndex], None)): 360 | return self.__plan_list[self.__queries.index(query)][ 361 | self.__indexes_list.index(indexes if indexes else None)] 362 | 363 | @lru_cache(maxsize=None) 364 | def get_origin_cost_of_query(self, query: QueryItem): 365 | return self.get_indexes_cost_of_query(query, None) 366 | 367 | @lru_cache(maxsize=None) 368 | def is_positive_query(self, index: AdvisedIndex, query: QueryItem): 369 | return self.get_origin_cost_of_query(query) > self.get_indexes_cost_of_query(query, tuple([index])) 370 | 371 | def add_indexes(self, indexes: (Tuple[AdvisedIndex], None), costs, index_names, plan_list): 372 | if not indexes: 373 | indexes = None 374 | self.__indexes_list.append(indexes) 375 | if len(costs) != len(self.__queries): 376 | raise 377 | for i, cost in enumerate(costs): 378 | self.__indexes_costs[i].append(cost) 379 | self.__index_names_list[i].append(index_names[i]) 380 | self.__plan_list[i].append(plan_list[i]) 381 | 382 | @lru_cache(maxsize=None) 383 | def get_index_related_queries(self, index: AdvisedIndex): 384 | insert_queries = [] 385 | delete_queries = [] 386 | update_queries = [] 387 | select_queries = [] 388 | positive_queries = [] 389 | ineffective_queries = [] 390 | negative_queries = [] 391 | 392 | cur_table = index.get_table() 393 | for query in self.get_queries(): 394 | if cur_table not in query.get_statement().lower() and \ 395 | not re.search(r'((\A|[\s(,])%s[\s),])' % cur_table.split('.')[-1], 396 | query.get_statement().lower()): 397 | continue 398 | 399 | if any(re.match(r'(insert\s+into\s+%s\s)' % table, query.get_statement().lower()) 400 | for table in [cur_table, cur_table.split('.')[-1]]): 401 | insert_queries.append(query) 402 | if not self.is_positive_query(index, query): 403 | negative_queries.append(query) 404 | elif any(re.match(r'(delete\s+from\s+%s\s)' % table, query.get_statement().lower()) or 405 | re.match(r'(delete\s+%s\s)' % table, query.get_statement().lower()) 406 | for table in [cur_table, cur_table.split('.')[-1]]): 407 | delete_queries.append(query) 408 | if not self.is_positive_query(index, query): 409 | negative_queries.append(query) 410 | elif any(re.match(r'(update\s+%s\s)' % table, query.get_statement().lower()) 411 | for table in [cur_table, cur_table.split('.')[-1]]): 412 | update_queries.append(query) 413 | if not self.is_positive_query(index, query): 414 | negative_queries.append(query) 415 | else: 416 | select_queries.append(query) 417 | if not self.is_positive_query(index, query): 418 | ineffective_queries.append(query) 419 | positive_queries = [query for query in insert_queries + 
delete_queries + update_queries + select_queries 420 | if query not in negative_queries + ineffective_queries] 421 | return insert_queries, delete_queries, update_queries, select_queries, \ 422 | positive_queries, ineffective_queries, negative_queries 423 | 424 | @lru_cache(maxsize=None) 425 | def get_index_sql_num(self, index: AdvisedIndex): 426 | insert_queries, delete_queries, update_queries, \ 427 | select_queries, positive_queries, ineffective_queries, \ 428 | negative_queries = self.get_index_related_queries(index) 429 | insert_sql_num = sum(query.get_frequency() for query in insert_queries) 430 | delete_sql_num = sum(query.get_frequency() for query in delete_queries) 431 | update_sql_num = sum(query.get_frequency() for query in update_queries) 432 | select_sql_num = sum(query.get_frequency() for query in select_queries) 433 | positive_sql_num = sum(query.get_frequency() for query in positive_queries) 434 | ineffective_sql_num = sum(query.get_frequency() for query in ineffective_queries) 435 | negative_sql_num = sum(query.get_frequency() for query in negative_queries) 436 | return {'insert': insert_sql_num, 'delete': delete_sql_num, 'update': update_sql_num, 'select': select_sql_num, 437 | 'positive': positive_sql_num, 'ineffective': ineffective_sql_num, 'negative': negative_sql_num} 438 | 439 | 440 | def get_statement_count(queries: List[QueryItem]): 441 | return int(sum(query.get_frequency() for query in queries)) 442 | 443 | 444 | def is_subset_index(indexes1: Tuple[AdvisedIndex], indexes2: Tuple[AdvisedIndex]): 445 | existing = False 446 | if len(indexes1) > len(indexes2): 447 | return existing 448 | for index1 in indexes1: 449 | existing = False 450 | for index2 in indexes2: 451 | # Example indexes1: [table1 col1 global] belong to indexes2:[table1 col1, col2 global]. 452 | if index2.get_table() == index1.get_table() \ 453 | and match_columns(index1.get_columns(), index2.get_columns()) \ 454 | and index2.get_index_type() == index1.get_index_type(): 455 | existing = True 456 | break 457 | if not existing: 458 | break 459 | return existing 460 | 461 | 462 | def lookfor_subsets_configs(config: List[AdvisedIndex], atomic_config_total: List[Tuple[AdvisedIndex]]): 463 | """ Look for the subsets of a given config in the atomic configs. """ 464 | contained_atomic_configs = [] 465 | for atomic_config in atomic_config_total: 466 | if len(atomic_config) == 1: 467 | continue 468 | if not is_subset_index(atomic_config, tuple(config)): 469 | continue 470 | # Atomic_config should contain the latest candidate_index. 471 | if not any(is_subset_index((atomic_index,), (config[-1],)) for atomic_index in atomic_config): 472 | continue 473 | # Filter redundant config in contained_atomic_configs. 474 | for contained_atomic_config in contained_atomic_configs[:]: 475 | if is_subset_index(contained_atomic_config, atomic_config): 476 | contained_atomic_configs.remove(contained_atomic_config) 477 | 478 | contained_atomic_configs.append(atomic_config) 479 | 480 | return contained_atomic_configs 481 | 482 | 483 | def match_columns(column1, column2): 484 | return re.match(column1 + ',', column2 + ',') 485 | 486 | 487 | def infer_workload_benefit(workload: WorkLoad, config: List[AdvisedIndex], 488 | atomic_config_total: List[Tuple[AdvisedIndex]]): 489 | """ Infer the total cost of queries for a config according to the cost of atomic configs. 
""" 490 | total_benefit = 0 491 | atomic_subsets_configs = lookfor_subsets_configs(config, atomic_config_total) 492 | is_recorded = [True] * len(atomic_subsets_configs) 493 | for query in workload.get_queries(): 494 | origin_cost_of_query = workload.get_origin_cost_of_query(query) 495 | if origin_cost_of_query == 0: 496 | continue 497 | # When there are multiple indexes, the benefit is the total benefit 498 | # of the multiple indexes minus the benefit of every single index. 499 | total_benefit += \ 500 | origin_cost_of_query - workload.get_indexes_cost_of_query(query, (config[-1],)) 501 | for k, sub_config in enumerate(atomic_subsets_configs): 502 | single_index_total_benefit = sum(origin_cost_of_query - 503 | workload.get_indexes_cost_of_query(query, (index,)) 504 | for index in sub_config) 505 | portfolio_returns = \ 506 | origin_cost_of_query \ 507 | - workload.get_indexes_cost_of_query(query, sub_config) \ 508 | - single_index_total_benefit 509 | total_benefit += portfolio_returns 510 | if portfolio_returns / origin_cost_of_query <= 0.01: 511 | continue 512 | # Record the portfolio returns of the index. 513 | association_indexes = ';'.join([str(index) for index in sub_config]) 514 | association_benefit = (query.get_statement(), portfolio_returns / origin_cost_of_query) 515 | if association_indexes not in config[-1].association_indexes: 516 | is_recorded[k] = False 517 | config[-1].set_association_indexes(association_indexes, association_benefit) 518 | continue 519 | if not is_recorded[k]: 520 | config[-1].set_association_indexes(association_indexes, association_benefit) 521 | 522 | return total_benefit 523 | 524 | 525 | @lru_cache(maxsize=None) 526 | def get_tokens(query): 527 | return list(sqlparse.parse(query)[0].flatten()) 528 | 529 | 530 | @lru_cache(maxsize=None) 531 | def has_dollar_placeholder(query): 532 | tokens = get_tokens(query) 533 | return any(item.ttype is Name.Placeholder for item in tokens) 534 | 535 | 536 | @lru_cache(maxsize=None) 537 | def get_placeholders(query): 538 | placeholders = set() 539 | for item in get_tokens(query): 540 | if item.ttype is Name.Placeholder: 541 | placeholders.add(item.value) 542 | return placeholders 543 | 544 | 545 | @lru_cache(maxsize=None) 546 | def generate_placeholder_indexes(table_cxt, column): 547 | indexes = [] 548 | schema_table = f'{table_cxt.schema}.{table_cxt.table}' 549 | if table_cxt.is_partitioned_table: 550 | indexes.append(IndexItemFactory().get_index(schema_table, column, 'global')) 551 | indexes.append(IndexItemFactory().get_index(schema_table, column, 'local')) 552 | else: 553 | indexes.append(IndexItemFactory().get_index(schema_table, column, '')) 554 | return indexes 555 | 556 | 557 | def replace_comma_with_dollar(query): 558 | """ 559 | Replacing '?' with '$+Numbers' in SQL: 560 | input: UPDATE bmsql_customer SET c_balance = c_balance + $1, c_delivery_cnt = c_delivery_cnt + ? 561 | WHERE c_w_id = $2 AND c_d_id = $3 AND c_id = $4 and c_info = ?; 562 | output: UPDATE bmsql_customer SET c_balance = c_balance + $1, c_delivery_cnt = c_delivery_cnt + $5 563 | WHERE c_w_id = $2 AND c_d_id = $3 AND c_id = $4 and c_info = $6; 564 | note: if track_stmt_parameter is off, all '?' in SQL need to be replaced 565 | """ 566 | if '?' not in query: 567 | return query 568 | max_dollar_number = 0 569 | dollar_parts = re.findall(r'(\$\d+)', query) 570 | if dollar_parts: 571 | max_dollar_number = max(int(item.strip('$')) for item in dollar_parts) 572 | while '?' 
in query: 573 | dollar = "$%s" % (max_dollar_number + 1) 574 | query = query.replace('?', dollar, 1) 575 | max_dollar_number += 1 576 | return query 577 | 578 | 579 | @lru_cache(maxsize=None) 580 | def is_multi_node(executor): 581 | sql = "select pg_catalog.count(*) from pgxc_node where node_type='C';" 582 | for cur_tuple in executor.execute_sqls([sql]): 583 | if str(cur_tuple[0]).isdigit(): 584 | return int(cur_tuple[0]) > 0 585 | 586 | 587 | @contextmanager 588 | def hypo_index_ctx(executor): 589 | yield 590 | executor.execute_sqls(['SELECT pg_catalog.hypopg_reset_index();']) 591 | 592 | 593 | def split_integer(m, n): 594 | quotient = int(m / n) 595 | remainder = m % n 596 | if m < n: 597 | return [1] * m 598 | if remainder > 0: 599 | return [quotient] * (n - remainder) + [quotient + 1] * remainder 600 | if remainder < 0: 601 | return [quotient - 1] * -remainder + [quotient] * (n + remainder) 602 | return [quotient] * n 603 | 604 | 605 | def split_iter(iterable, n): 606 | size_list = split_integer(len(iterable), n) 607 | index = 0 608 | res = [] 609 | for size in size_list: 610 | res.append(iterable[index:index + size]) 611 | index += size 612 | return res 613 | 614 | 615 | def flatten(iterable): 616 | for _iter in iterable: 617 | if hasattr(_iter, '__iter__') and not isinstance(_iter, str): 618 | for item in flatten(_iter): 619 | yield item 620 | else: 621 | yield _iter 622 | --------------------------------------------------------------------------------
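The helpers at the end of utils.py are small and self-contained. The sketch below is not part of the repository; it only illustrates their observable behaviour, assuming the module directory is on `sys.path` so that `utils` imports flat and that `sqlparse` is installed.

```python
# Minimal sketch (illustrative, not part of the original sources).
from utils import (UniqueList, match_columns, split_iter, flatten,
                   replace_comma_with_dollar)

cols = UniqueList()
cols.extend(['c_w_id', 'c_d_id', 'c_w_id'])     # duplicates are ignored
assert cols == ['c_w_id', 'c_d_id']

# match_columns succeeds when the first column list is a prefix of the second,
# which is how column-order subsumption between two indexes is detected.
assert match_columns('c_w_id', 'c_w_id, c_d_id')
assert not match_columns('c_d_id', 'c_w_id, c_d_id')

# split_iter partitions a sequence into n chunks of near-equal size.
assert split_iter([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4, 5]]

# flatten walks arbitrarily nested iterables, leaving strings intact.
assert list(flatten([['a', ['b']], 'cd'])) == ['a', 'b', 'cd']

# '?' placeholders are renumbered after the highest existing $n placeholder.
print(replace_comma_with_dollar('SELECT 1 WHERE a = $1 AND b = ?'))
# -> SELECT 1 WHERE a = $1 AND b = $2
```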