├── .gitignore
├── LICENSE
├── README.md
├── config.json.example
├── config.json.example_all
├── config_reader.py
├── database_helper.py
├── db_connect.py
├── direct_subset.py
├── mysql_database_creator.py
├── mysql_database_helper.py
├── psql_database_creator.py
├── psql_database_helper.py
├── requirements.txt
├── result_tabulator.py
├── subset.py
├── subset_utils.py
└── topo_orderer.py

/.gitignore:
--------------------------------------------------------------------------------
.python-version
__pycache__/
.vscode/
SQL/
config.json

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright 2019, Tonic AI

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Condenser

Condenser is a config-driven database subsetting tool for Postgres and MySQL.

Subsetting data is the process of taking a representative sample of your data in a manner that preserves the integrity of your database, e.g., give me 5% of my users. If you do this naively, e.g., by grabbing 5% of each table in your database independently, the result will most likely violate foreign key constraints. At best, you’ll end up with a statistically non-representative data sample.

One common use case is to scale down a production database to a more reasonable size so that it can be used in staging, test, and development environments. This can be done to save costs and, when used in tandem with PII removal, can be quite powerful as a productivity enhancer. Another example is copying specific rows from one database and placing them into another while maintaining referential integrity.

You can find more details about how we built this [here](https://www.tonic.ai/blog/condenser-a-database-subsetting-tool) and [here](https://www.tonic.ai/blog/condenser-v2/).

## Need to Subset a Large Database?

Our open-source tool can subset databases up to 10GB, but it will struggle with larger databases. Our premium database subsetter can, among other things (graphical UI, job scheduling, fancy algorithms), subset multi-TB databases with ease. If you're interested, find us at [hello@tonic.ai](mailto:hello@tonic.ai).

# Installation

Five steps to install, assuming Python 3.5+:

1. Download the required Python modules. You can use [`pip`](https://pypi.org/project/pip/) for easy installation. The required modules are `toposort`, `psycopg2-binary`, and `mysql-connector-python`.
   ```
   $ pip install toposort
   $ pip install psycopg2-binary
   $ pip install mysql-connector-python
   ```
2. Install the Postgres and/or MySQL database tools. For Postgres we need `pg_dump` and `psql`; they need to be on your `$PATH`, or you can point to them with `$POSTGRES_PATH`. For MySQL we need `mysqldump` and `mysql`; they can be on your `$PATH`, or you can point to them with `$MYSQL_PATH`.
3. Download this repo. You can clone the repo or download it as a zip. Scroll up; it's the green button that says "Clone or download".
4. Set up your configuration and save it in `config.json`. The provided `config.json.example` has the skeleton of what you need to provide: source and destination database connection details, as well as subsetting goals in `initial_targets`. Here's an example that will collect 10% of a table named `public.target_table`.
   ```
   "initial_targets": [
       {
           "table": "public.target_table",
           "percent": 10
       }
   ]
   ```
   There may be more required configuration depending on your database, but simple databases should be easy. See the Config section for more details, and `config.json.example_all` for all of the options in a single config file.

5. Run! `$ python direct_subset.py`

# Config

Configuration must exist in `config.json`. There is an example configuration provided in `config.json.example`. Most of the configuration is straightforward: source and destination DB connection details and subsetting settings. There are three fields that deserve some additional attention.

The first is `initial_targets`. This is where you tell the subsetter where to begin the subset. You can specify any number of tables as initial targets, and provide either a percent goal (e.g. 5% of the `users` table) or a WHERE clause.

Next is `dependency_breaks`.
The best way to get a full understanding of this is to read our [blog post](https://www.tonic.ai/blog/condenser-a-database-subsetting-tool). But if you want a TLDR, it's this: the subsetting tool cannot operate on databases with cycles in their foreign key relationships. (Example: table `events` references `users`, which references `company`, which references `events`; if you think of the foreign keys as edges in a directed graph, these three tables form a cycle.) If your database has a foreign key cycle (and many do), have no fear! This field lets you tell the subsetter to ignore certain foreign keys, which removes the cycle. You'll have to know a bit about your database to use this field effectively. The tool will warn you if you have a cycle that you haven't broken.

The last is `fk_augmentation`. Databases frequently have foreign keys that are not codified as constraints on the database; these are implicit foreign keys. For a subsetter to create useful subsets, it needs to know about these implicit constraints. This field lets you add foreign keys to the subsetter that the DB doesn't have listed as constraints.

Below we describe the use of all configuration parameters, but the best place to start for the exact format is `config.json.example_all`.

`db_type`: The type of the database to subset. Valid values are `"postgres"` or `"mysql"`.

`source_db_connection_info`: Source database connection details. These are recorded as a JSON object with the fields `user_name`, `host`, `db_name`, `ssl_mode` (optional), `password` (optional), and `port`. If `password` is omitted, then you will be prompted for a password. See `config.json.example_all` for details.

`destination_db_connection_info`: Destination database connection details. Same fields as `source_db_connection_info`.

`initial_targets`: JSON array of JSON objects.
The inner object must contain a `table` field, which names the target table, and either a `where` field or a `percent` field. The `where` field is used to specify a WHERE clause for the subsetting. The `percent` field indicates we want a specific percentage of the target table; it is equivalent to `"where": "random() < <percent>/100.0"`.

`passthrough_tables`: Tables that will be copied to the destination database in whole. The value is a JSON array of strings, of the form `"<schema>.<table>"` for Postgres and `"<database>.<table>"` for MySQL.

`excluded_tables`: Tables that will be excluded from the subset. The table will exist in the output, but contain no rows. The value is a JSON array of strings, of the form `"<schema>.<table>"` for Postgres and `"<database>.<table>"` for MySQL.

`upstream_filters`: Additional filtering to be applied to tables during upstream subsetting. Upstream subsetting happens when a row is imported and there are rows with foreign keys to that row. The subsetter then greedily grabs as many rows from the database as it can, based on the rows already imported. If you don't want such greedy behavior, you can impose additional filters with this option. This is an advanced feature that you probably won't need for your first subsets. The value is a JSON array of JSON objects. See `config.json.example_all` for details.

`fk_augmentation`: Additional foreign keys that, while not represented as constraints in the database, are logically present in the data. Foreign keys listed in `fk_augmentation` are unioned with the foreign keys provided by constraints in the database. The value is a JSON array of JSON objects. See `config.json.example_all` for details.

`dependency_breaks`: An array of JSON objects with `"fk_table"` and `"target_table"` fields, each naming a table relationship to be ignored in order to break cycles.

`keep_disconnected_tables`: If `true`, tables that the subset targets don't reach by following foreign keys will be copied over in full. If it's `false`, their schema will be copied but the table contents will be empty. Put more mathematically, the tables and foreign keys form a graph (where tables are nodes and foreign keys are directed edges); disconnected tables are the tables in components that don't contain any targets. This setting decides how to import those tables.

`max_rows_per_table`: A row limit applied to every table that is copied. Useful if you have some very large tables that you want a sampling from. For an unlimited dataset (recommended), set this parameter to `ALL`.

`pre_constraint_sql`: An array of SQL commands that will be issued on the destination database after subsetting is complete, but before the database constraints have been applied. Useful to perform tasks that will clean up any data that would otherwise violate the database constraints. `post_subset_sql` is the preferred option for any general-purpose queries.

`post_subset_sql`: An array of SQL commands that will be issued on the destination database after subsetting is complete, and after the database constraints have been applied. Useful to perform additional ad hoc tasks after subsetting.

# Running

Almost all the configuration is in the `config.json` file, so running is as simple as

```
$ python direct_subset.py
```

Three command-line arguments are supported:

`-v`: Verbose output. Useful for performance debugging. Lists almost every query made, and how long it took.

`--no-constraints`: For Postgres this will not add constraints found in the source database to the destination database. This option has no effect for MySQL.

`--stdin`: Read the configuration from standard input instead of from `config.json`.

# Requirements

Reference the `requirements.txt` file for a list of required Python packages. Also, please note that Python 3.5+ is required.
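
# Example: Finding Foreign Key Cycles

The `dependency_breaks` option exists because the subsetter cannot operate on a foreign key graph that contains a cycle. To illustrate the idea, here is a minimal Python sketch of how one might locate such a cycle before writing a config. It is an illustration under stated assumptions, not the code in `topo_orderer.py`: the helper name `find_fk_cycle` is hypothetical, and its inputs are assumed to be lists of dicts using the same `fk_table`/`target_table` keys as `dependency_breaks` and `fk_augmentation`. It runs a depth-first search and reports the first back edge it finds as a cycle.

```python
from collections import defaultdict


def find_fk_cycle(relationships, dependency_breaks=()):
    """Return one foreign key cycle as a list of tables, or None if acyclic.

    Hypothetical helper (not part of Condenser). Both arguments are lists of
    dicts with "fk_table" and "target_table" keys, as in config.json.
    """
    ignored = {(b["fk_table"], b["target_table"]) for b in dependency_breaks}

    # Directed graph: an edge runs from the referencing table to the referenced table.
    graph = defaultdict(set)
    for rel in relationships:
        edge = (rel["fk_table"], rel["target_table"])
        if edge not in ignored:
            graph[edge[0]].add(edge[1])

    UNVISITED, ON_PATH, DONE = 0, 1, 2
    state = defaultdict(int)  # every table starts out UNVISITED
    path = []  # tables on the current depth-first search path

    def visit(table):
        state[table] = ON_PATH
        path.append(table)
        for target in sorted(graph[table]):
            if state[target] == ON_PATH:
                # Back edge: the cycle runs from `target` along the path and back.
                return path[path.index(target):] + [target]
            if state[target] == UNVISITED:
                cycle = visit(target)
                if cycle:
                    return cycle
        path.pop()
        state[table] = DONE
        return None

    for table in sorted(graph):
        if state[table] == UNVISITED:
            cycle = visit(table)
            if cycle:
                return cycle
    return None


# The events -> users -> company -> events example from the Config section.
fks = [
    {"fk_table": "public.events", "target_table": "public.users"},
    {"fk_table": "public.users", "target_table": "public.company"},
    {"fk_table": "public.company", "target_table": "public.events"},
]
print(find_fk_cycle(fks))
# ['public.company', 'public.events', 'public.users', 'public.company']
breaks = [{"fk_table": "public.company", "target_table": "public.events"}]
print(find_fk_cycle(fks, breaks))  # None
```

Listing any single edge of the reported cycle in `dependency_breaks` is enough for this particular check to pass; as noted in the Config section, choosing which edge to break requires some knowledge of your data.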

--------------------------------------------------------------------------------
/config.json.example:
--------------------------------------------------------------------------------
{
    "initial_targets": [
        {
            "table": "public.target_table",
            "percent": 10
        }
    ],
    "db_type": "postgres",
    "source_db_connection_info": {
        "user_name": "user",
        "host": "host.host.com",
        "db_name": "source_db",
        "port": 5432
    },
    "destination_db_connection_info": {
        "user_name": "user",
        "host": "host.host.com",
        "db_name": "destination_db",
        "port": 5432
    },
    "keep_disconnected_tables": false,
    "excluded_tables": [ ],
    "passthrough_tables": [ ],
    "dependency_breaks": [ ],
    "fk_augmentation": [ ],
    "upstream_filters": [ ]
}

--------------------------------------------------------------------------------
/config.json.example_all:
--------------------------------------------------------------------------------
{
    "initial_targets": [
        {
            "table": "public.target_table",
            "percent": 10
        },
        {
            "table": "public.users",
            "where": "split_part(email, '@', 2) = 'hotmail.com'"
        }
    ],
    "db_type": "postgres",
    "source_db_connection_info": {
        "user_name": "user",
        "host": "host.host.com",
        "db_name": "source_db",
        "port": 5432
    },
    "destination_db_connection_info": {
        "user_name": "user",
        "host": "host.host.com",
        "db_name": "destination_db",
        "password": "if you don't include a password option, you will be prompted every time",
        "port": 5432
    },
    "keep_disconnected_tables": false,
    "upstream_filters": [
        {
            "table": "public.an_upstream_table",
            "condition": "timestamp > '01-01-2001'"
        },
        {
            "column": "condition_applied_to_any_table_with_this_column",
            "condition": "condition_applied_to_any_table_with_this_column > 42"
        }
    ],
    "max_rows_per_table": 100000,
    "excluded_tables": [
        "public.table_to_ignore", "public.spatial_ref_sys"
    ],
    "passthrough_tables": [
        "public.table_of_settings", "public.table_of_constants"
    ],
    "dependency_breaks": [
        {"fk_table": "schema2.table2", "target_table": "schema3.table3"}
    ],
    "fk_augmentation": [
        {
            "fk_table": "public.fk_table",
            "fk_columns": ["user_id"],
            "target_table": "public.user",
            "target_columns": ["id"]
        }
    ],
    "pre_constraint_sql": ["DELETE FROM table where non_nullable_column IS NULL"],
    "post_subset_sql": ["UPDATE a_table SET a_column = 'value'"]
}

--------------------------------------------------------------------------------
/config_reader.py:
--------------------------------------------------------------------------------
import json, sys, collections

_config = None

def initialize(file_like=None):
    global _config
    if _config is not None:
        print('WARNING: Attempted to initialize configuration twice.', file=sys.stderr)

    if not file_like:
        with open('config.json', 'r') as fp:
            _config = json.load(fp)
    else:
        _config = json.load(file_like)

    if "desired_result" in _config:
        raise ValueError("desired_result is a key in the old config spec. Check the README.md and config.json.example_all for the latest configuration parameters.")

DependencyBreak = collections.namedtuple('DependencyBreak', ['fk_table', 'target_table'])
def get_dependency_breaks():
    return set([DependencyBreak(b['fk_table'], b['target_table']) for b in _config['dependency_breaks']])

# Note: the config key is spelled 'perserve_fk_opportunistically' (sic).
def get_preserve_fk_opportunistically():
    return set([DependencyBreak(b['fk_table'], b['target_table']) for b in _config['dependency_breaks'] if 'perserve_fk_opportunistically' in b and b['perserve_fk_opportunistically']])

def get_initial_targets():
    return _config['initial_targets']

def get_initial_target_tables():
    return [target["table"] for target in _config['initial_targets']]

def keep_disconnected_tables():
    return 'keep_disconnected_tables' in _config and bool(_config['keep_disconnected_tables'])

def get_db_type():
    return _config['db_type']

def get_source_db_connection_info():
    return _config['source_db_connection_info']

def get_destination_db_connection_info():
    return _config['destination_db_connection_info']

def get_excluded_tables():
    return list(_config['excluded_tables'])

def get_passthrough_tables():
    return list(_config['passthrough_tables'])

def get_fk_augmentation():
    return list(map(__convert_tonic_format, _config['fk_augmentation']))

def get_upstream_filters():
    return _config["upstream_filters"]

def get_pre_constraint_sql():
    return _config["pre_constraint_sql"] if "pre_constraint_sql" in _config else []

def get_post_subset_sql():
    return _config["post_subset_sql"] if "post_subset_sql" in _config else []

def get_max_rows_per_table():
    return _config["max_rows_per_table"] if "max_rows_per_table" in _config else None

# Accepts fk_augmentation entries that give schema and table separately
# and joins them into "<schema>.<table>" names.
def __convert_tonic_format(obj):
    if "fk_schema" in obj:
        return {
            "fk_table": obj["fk_schema"] + "." + obj["fk_table"],
            "fk_columns": obj["fk_columns"],
            "target_table": obj["target_schema"] + "." + obj["target_table"],
            "target_columns": obj["target_columns"],
        }
    else:
        return obj

def verbose_logging():
    return '-v' in sys.argv

--------------------------------------------------------------------------------
/database_helper.py:
--------------------------------------------------------------------------------
import config_reader

def get_specific_helper():
    if config_reader.get_db_type() == 'postgres':
        import psql_database_helper
        return psql_database_helper
    else:
        import mysql_database_helper
        return mysql_database_helper

--------------------------------------------------------------------------------
/db_connect.py:
--------------------------------------------------------------------------------
import config_reader
import psycopg2, mysql.connector
import os, pathlib, re, urllib, subprocess, os.path, json, getpass, time, sys, datetime

class DbConnect:

    def __init__(self, db_type, connection_info):
        requiredKeys = [
            'user_name',
            'host',
            'db_name',
            'port'
        ]

        for r in requiredKeys:
            if r not in connection_info.keys():
                raise Exception('Missing required key in database connection info: ' + r)
        if 'password' not in connection_info.keys():
            connection_info['password'] = getpass.getpass('Enter password for {0} on host {1}: '.format(connection_info['user_name'], connection_info['host']))

        self.user = connection_info['user_name']
        self.password = connection_info['password']
        self.host = connection_info['host']
        self.port = connection_info['port']
        self.db_name = connection_info['db_name']
        self.ssl_mode = connection_info['ssl_mode'] if 'ssl_mode' in connection_info else None
        self.__db_type = db_type.lower()

    def get_db_connection(self,
                          read_repeatable=False):

        if self.__db_type == 'postgres':
            return PsqlConnection(self, read_repeatable)
        elif self.__db_type == 'mysql':
            return MySqlConnection(self, read_repeatable)
        else:
            raise ValueError('unknown db_type ' + self.__db_type)

class DbConnection:
    def __init__(self, connection):
        self.connection = connection

    def commit(self):
        self.connection.commit()

    def close(self):
        self.connection.close()


class LoggingCursor:
    def __init__(self, cursor):
        self.inner_cursor = cursor

    def execute(self, query):
        start_time = time.time()
        if config_reader.verbose_logging():
            print('Beginning query @ {}:\n\t{}'.format(str(datetime.datetime.now()), query))
            sys.stdout.flush()
        retval = self.inner_cursor.execute(query)
        if config_reader.verbose_logging():
            print('\tQuery completed in {}s'.format(time.time() - start_time))
            sys.stdout.flush()
        return retval

    # Delegate every other attribute (fetchall, close, ...) to the wrapped cursor.
    def __getattr__(self, name):
        return getattr(self.inner_cursor, name)

    def __exit__(self, a, b, c):
        return self.inner_cursor.__exit__(a, b, c)

    def __enter__(self):
        return LoggingCursor(self.inner_cursor.__enter__())

# small wrapper to the connection class that gives us a common interface to the cursor()
# method across MySQL and Postgres. This one is for Postgres
class PsqlConnection(DbConnection):
    def __init__(self, connect, read_repeatable):
        connection_string = 'dbname=\'{0}\' user=\'{1}\' password=\'{2}\' host={3} port={4}'.format(connect.db_name, connect.user, connect.password, connect.host, connect.port)

        if connect.ssl_mode:
            connection_string = connection_string + ' sslmode={0}'.format(connect.ssl_mode)

        DbConnection.__init__(self, psycopg2.connect(connection_string))
        if read_repeatable:
            self.connection.isolation_level = psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ

    def cursor(self, name=None, withhold=False):
        return LoggingCursor(self.connection.cursor(name=name, withhold=withhold))


# small wrapper to the connection class that gives us a common interface to the cursor()
# method across MySQL and Postgres. This one is for MySQL
class MySqlConnection(DbConnection):
    def __init__(self, connect, read_repeatable):
        DbConnection.__init__(self, mysql.connector.connect(host=connect.host, port=connect.port, user=connect.user, password=connect.password, database=connect.db_name))

        self.db_name = connect.db_name

        if read_repeatable:
            self.connection.start_transaction(isolation_level='REPEATABLE READ')

    # name and withhold are accepted for interface parity with Postgres but ignored.
    def cursor(self, name=None, withhold=False):
        return LoggingCursor(self.connection.cursor())

--------------------------------------------------------------------------------
/direct_subset.py:
--------------------------------------------------------------------------------
import uuid, sys
import config_reader, result_tabulator
import time
from subset import Subset
from psql_database_creator import PsqlDatabaseCreator
from mysql_database_creator import MySqlDatabaseCreator
from db_connect import DbConnect
from subset_utils import print_progress
import database_helper

def db_creator(db_type, source, dest):
    if db_type == 'postgres':
        return PsqlDatabaseCreator(source, dest, False)
    elif db_type == 'mysql':
        return MySqlDatabaseCreator(source, dest)
    else:
        raise ValueError('unknown db_type ' + db_type)


if __name__ == '__main__':
    if "--stdin" in sys.argv:
        config_reader.initialize(sys.stdin)
    else:
        config_reader.initialize()

    db_type = config_reader.get_db_type()
    source_dbc = DbConnect(db_type, config_reader.get_source_db_connection_info())
    destination_dbc = DbConnect(db_type, config_reader.get_destination_db_connection_info())

    database = db_creator(db_type, source_dbc, destination_dbc)
    database.teardown()
    database.create()

    # Get list of tables to operate on
    db_helper = database_helper.get_specific_helper()
    all_tables = db_helper.list_all_tables(source_dbc)
    all_tables = [x for x in all_tables if x not in config_reader.get_excluded_tables()]

    subsetter = Subset(source_dbc, destination_dbc, all_tables)

    try:
        subsetter.prep_temp_dbs()
        subsetter.run_middle_out()

        print("Beginning pre constraint SQL calls")
        start_time = time.time()
        for idx, sql in enumerate(config_reader.get_pre_constraint_sql()):
            print_progress(sql, idx+1, len(config_reader.get_pre_constraint_sql()))
            db_helper.run_query(sql, destination_dbc.get_db_connection())
        print("Completed pre constraint SQL calls in {}s".format(time.time()-start_time))

        if "--no-constraints" not in sys.argv:
            print("Adding database constraints")
            database.add_constraints()

        print("Beginning post subset SQL calls")
        start_time = time.time()
        for idx, sql in enumerate(config_reader.get_post_subset_sql()):
            print_progress(sql, idx+1, len(config_reader.get_post_subset_sql()))
            db_helper.run_query(sql, destination_dbc.get_db_connection())
        print("Completed post subset SQL calls in {}s".format(time.time()-start_time))

        result_tabulator.tabulate(source_dbc, destination_dbc, all_tables)
    finally:
        subsetter.unprep_temp_dbs()

--------------------------------------------------------------------------------
/mysql_database_creator.py:
--------------------------------------------------------------------------------
import os, urllib, subprocess, io

class MySqlDatabaseCreator:
    def __init__(self, source_connect, destination_connect):
        self.__source_connect = source_connect
        self.__destination_connect = destination_connect

    def create(self):
        # Build full paths to the client tools. On POSIX, subprocess looks up a bare
        # program name on PATH, not in the current working directory.
        mysql_bin_path = get_mysql_bin_path()
        mysqldump_exe = os.path.join(mysql_bin_path, 'mysqldump')
        mysql_exe = os.path.join(mysql_bin_path, 'mysql')

        # Capture the source schema (tables and routines, no data).
        ca = connection_args(self.__source_connect)
        args = [mysqldump_exe, '--no-data', '--routines'] + ca + [self.__source_connect.db_name]
        result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            raise Exception('Capturing schema failed. Details:\n{}'.format(result.stderr))
        commands_to_create_schema = result.stdout

        # Create the empty destination database.
        ca = connection_args(self.__destination_connect)
        args = [mysql_exe] + ca + ['-e', 'CREATE DATABASE ' + self.__destination_connect.db_name]
        result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            raise Exception('Creating destination database failed. Details:\n{}'.format(result.stderr))

        # Replay the captured schema into the destination database.
        args = [mysql_exe, '-D', self.__destination_connect.db_name] + ca
        result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, input=commands_to_create_schema)
        if result.returncode != 0:
            raise Exception('Creating destination schema failed. Details:\n{}'.format(result.stderr))

    def teardown(self):
        self.run_query_on_destination('DROP DATABASE IF EXISTS ' + self.__destination_connect.db_name + ';')

    def add_constraints(self):
        # no-op for mysql
        pass

    def run_query_on_destination(self, command):
        mysql_exe = os.path.join(get_mysql_bin_path(), 'mysql')
        ca = connection_args(self.__destination_connect)
        args = [mysql_exe] + ca + ['-e', command]
        result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            raise Exception('Failed to run command \'{}\'. Details:\n{}'.format(command, result.stderr))

def get_mysql_bin_path():
    if 'MYSQL_PATH' in os.environ:
        mysql_bin_path = os.environ['MYSQL_PATH']
    else:
        mysql_bin_path = ''
    err = os.system('"' + os.path.join(mysql_bin_path, 'mysqldump') + '"' + ' --help > ' + os.devnull)
    if err != 0:
        raise Exception("Couldn't find MySQL utilities, consider specifying MYSQL_PATH environment variable if MySQL isn't " +
                        "in your PATH.")
    return mysql_bin_path

def connection_args(connect):
    host_arg = '--host={}'.format(connect.host)
    port_arg = '--port={}'.format(connect.port)
    user_arg = '--user={}'.format(connect.user)
    password_arg = '--password={}'.format(connect.password)
    return [host_arg, port_arg, user_arg, password_arg]


# This is just for unit testing the creation and tear down processes
if __name__ == '__main__':
    import config_reader, db_connect
    config_reader.initialize()
    # DbConnect takes (db_type, connection_info), in that order.
    src_connect = db_connect.DbConnect('mysql', config_reader.get_source_db_connection_info())
    dest_connect = db_connect.DbConnect('mysql', config_reader.get_destination_db_connection_info())
    msdbc = MySqlDatabaseCreator(src_connect,
dest_connect) 83 | msdbc.teardown() 84 | msdbc.create() 85 | -------------------------------------------------------------------------------- /mysql_database_helper.py: -------------------------------------------------------------------------------- 1 | import os, uuid, csv 2 | import config_reader 3 | from pathlib import Path 4 | from subset_utils import columns_joined, columns_tupled, quoter, schema_name, table_name, fully_qualified_table, redact_relationships 5 | 6 | system_schemas_str = ','.join(['\'' + schema + '\'' for schema in ['information_schema', 'performance_schema', 'sys', 'mysql', 'innodb','tmp']]) 7 | temp_db = 'tonic_subset_temp_db_398dhjr23' 8 | 9 | def prep_temp_dbs(source_conn, destination_conn): 10 | run_query('DROP DATABASE IF EXISTS ' + temp_db, source_conn) 11 | run_query('DROP DATABASE IF EXISTS ' + temp_db, destination_conn) 12 | run_query('CREATE DATABASE IF NOT EXISTS ' + temp_db, source_conn) 13 | run_query('CREATE DATABASE IF NOT EXISTS ' + temp_db, destination_conn) 14 | 15 | def unprep_temp_dbs(source_conn, destination_conn): 16 | run_query('DROP DATABASE IF EXISTS ' + temp_db, source_conn) 17 | run_query('DROP DATABASE IF EXISTS ' + temp_db, destination_conn) 18 | 19 | def turn_off_constraints(connection): 20 | cur = connection.cursor() 21 | try: 22 | cur.execute('SET UNIQUE_CHECKS=0, FOREIGN_KEY_CHECKS=0;') 23 | finally: 24 | cur.close() 25 | 26 | def copy_rows(source, destination, query, destination_table): 27 | cursor = source.cursor() 28 | 29 | try: 30 | cursor.execute(query) 31 | fetch_row_count = 1000 32 | while True: 33 | rows = cursor.fetchmany(fetch_row_count) 34 | if len(rows) == 0: 35 | break 36 | 37 | template = ','.join(['%s']*len(rows[0])) 38 | destination_cursor = destination.cursor() 39 | insert_query = 'INSERT INTO {} VALUES ({})'.format(fully_qualified_table(destination_table), template) 40 | destination_cursor.executemany(insert_query, rows) 41 | 42 | destination_cursor.close() 43 | destination.commit() 44 | 45 | 
if len(rows) < fetch_row_count: 46 | # necessary because mysql doesn't behave if you fetchmany after the last row 47 | break 48 | except Exception as e: 49 | if hasattr(e, 'msg') and e.msg.startswith('Table') and e.msg.endswith('doesn\'t exist'): 50 | raise ValueError('Your database has foreign keys to another database. This is not currently supported.') 51 | else: 52 | raise e 53 | finally: 54 | cursor.close() 55 | 56 | def create_id_temp_table(conn, number_of_columns): 57 | temp_table = temp_db + '.' + str(uuid.uuid4()) 58 | cursor = conn.cursor() 59 | column_defs = ',\n'.join([' col' + str(aye) + ' text' for aye in range(number_of_columns)]) 60 | q = 'CREATE TABLE {} (\n {} \n)'.format(fully_qualified_table(temp_table), column_defs) 61 | cursor.execute(q) 62 | cursor.close() 63 | return temp_table 64 | 65 | def copy_to_temp_table(conn, query, target_table, pk_columns = None): 66 | cur = conn.cursor() 67 | temp_table = fully_qualified_table(source_db_temp_table(target_table)) 68 | try: 69 | cur.execute('CREATE TABLE IF NOT EXISTS ' + temp_table + ' AS ' + query + ' LIMIT 0') 70 | if pk_columns: 71 | query = query + ' WHERE {} NOT IN (SELECT {} FROM {})'.format(columns_tupled(pk_columns), columns_joined(pk_columns), temp_table) 72 | cur.execute('INSERT INTO ' + temp_table + ' ' + query) 73 | conn.commit() 74 | finally: 75 | cur.close() 76 | 77 | def clean_temp_table_cells(fk_table, fk_columns, target_table, target_columns, conn): 78 | fk_alias = 'tonic_subset_398dhjr23_fk' 79 | target_alias = 'tonic_subset_398dhjr23_target' 80 | 81 | fk_table = fully_qualified_table(source_db_temp_table(fk_table)) 82 | target_table = fully_qualified_table(source_db_temp_table(target_table)) 83 | assignment_list = ','.join(['{}.{} = NULL'.format(fk_alias, quoter(c)) for c in fk_columns]) 84 | column_matching = ' AND '.join(['{}.{} = {}.{}'.format(fk_alias, quoter(fc), target_alias, quoter(tc)) for fc, tc in zip(fk_columns, target_columns)]) 85 | target_columns_null = ' AND 
'.join(['{}.{} IS NULL'.format(target_alias, quoter(tc)) for tc in target_columns] 86 | + ['{}.{} IS NOT NULL'.format(fk_alias, quoter(c)) for c in fk_columns]) 87 | q = 'UPDATE {} {} LEFT JOIN {} {} ON {} SET {} WHERE {}'.format(fk_table, fk_alias, target_table, target_alias, column_matching, assignment_list, target_columns_null) 88 | run_query(q, conn) 89 | 90 | def source_db_temp_table(target_table): 91 | return temp_db + '.' + schema_name(target_table) + '_' + table_name(target_table) 92 | 93 | def get_redacted_table_references(table_name, tables, conn): 94 | relationships = get_unredacted_fk_relationships(tables, conn) 95 | redacted = redact_relationships(relationships) 96 | return [r for r in redacted if r['target_table']==table_name] 97 | 98 | def get_unredacted_fk_relationships(tables, conn): 99 | cur = conn.cursor() 100 | 101 | q = ''' 102 | SELECT 103 | concat(table_schema, '.', table_name) AS fk_table, 104 | group_concat(column_name ORDER BY ordinal_position) AS fk_column, 105 | concat(referenced_table_schema, '.', referenced_table_name) AS pk_name, 106 | group_concat(referenced_column_name ORDER BY ordinal_position) AS pk_name 107 | FROM 108 | information_schema.key_column_usage 109 | WHERE 110 | referenced_table_schema NOT IN ({}) 111 | GROUP BY 1, 3, constraint_schema, constraint_name; 112 | '''.format(system_schemas_str) 113 | 114 | cur.execute(q) 115 | 116 | relationships = list() 117 | 118 | for row in cur.fetchall(): 119 | d = dict() 120 | d['fk_table'] = row[0] 121 | d['fk_columns'] = row[1].split(',') 122 | d['target_table'] = row[2] 123 | d['target_columns'] = row[3].split(',') 124 | 125 | if d['fk_table'] in tables and d['target_table'] in tables: 126 | relationships.append( d ) 127 | cur.close() 128 | 129 | for augment in config_reader.get_fk_augmentation(): 130 | not_present = True 131 | for r in relationships: 132 | not_present = not_present and not all([r[key] == augment[key] for key in r.keys()]) 133 | if not not_present: 134 | break 135 
| 136 | if augment['fk_table'] in tables and augment['target_table'] in tables and not_present: 137 | relationships.append(augment) 138 | 139 | return relationships 140 | 141 | def run_query(query, conn, commit=True): 142 | cur = conn.cursor() 143 | try: 144 | cur.execute(query) 145 | if commit: 146 | conn.commit() 147 | finally: 148 | cur.close() 149 | 150 | def get_table_count_estimate(table_name, schema, conn): 151 | cur = conn.cursor() 152 | try: 153 | cur.execute('SELECT table_rows AS count FROM information_schema.tables WHERE table_schema=\'{}\' AND table_name=\'{}\''.format(schema, table_name)) 154 | return cur.fetchone()[0] 155 | finally: 156 | cur.close() 157 | 158 | def get_table_columns(table, schema, conn): 159 | cur = conn.cursor() 160 | try: 161 | cur.execute('SELECT column_name FROM information_schema.columns WHERE table_schema = \'{}\' AND table_name = \'{}\' ORDER BY ordinal_position'.format(schema, table)) 162 | return [r[0] for r in cur.fetchall()] 163 | finally: 164 | cur.close() 165 | 166 | def list_all_tables(db_connect): 167 | conn = db_connect.get_db_connection() 168 | cur = conn.cursor() 169 | config_reader.get_source_db_connection_info() 170 | try: 171 | cur.execute('''SELECT 172 | concat(concat(table_schema,'.'),table_name) 173 | FROM 174 | information_schema.tables 175 | WHERE 176 | table_schema = '{}' AND table_type = 'BASE TABLE';'''.format(db_connect.db_name)) 177 | return [r[0] for r in cur.fetchall()] 178 | finally: 179 | cur.close() 180 | 181 | def truncate_table(target_table, conn): 182 | cur = conn.cursor() 183 | try: 184 | cur.execute("TRUNCATE TABLE {}".format(target_table)) 185 | conn.commit() 186 | finally: 187 | cur.close() 188 | -------------------------------------------------------------------------------- /psql_database_creator.py: -------------------------------------------------------------------------------- 1 | import os, urllib.parse, subprocess 2 | from db_connect import DbConnect 3 | import database_helper 4 | 5 | class
PsqlDatabaseCreator: 6 | def __init__(self, source_dbc, destination_dbc, use_existing_dump = False): 7 | self.destination_dbc = destination_dbc 8 | self.source_dbc = source_dbc 9 | self.__source_db_connection = source_dbc.get_db_connection() 10 | 11 | self.use_existing_dump = use_existing_dump 12 | 13 | self.output_path = os.path.join(os.getcwd(),'SQL') 14 | if not os.path.isdir(self.output_path): 15 | os.mkdir(self.output_path) 16 | 17 | self.add_constraint_output_path = os.path.join(os.getcwd(), 'SQL', 'add_constraint_output.txt') 18 | self.add_constraint_error_path = os.path.join(os.getcwd(), 'SQL', 'add_constraint_error.txt') 19 | 20 | if os.path.exists(self.add_constraint_output_path): 21 | os.remove(self.add_constraint_output_path) 22 | if os.path.exists(self.add_constraint_error_path): 23 | os.remove(self.add_constraint_error_path) 24 | 25 | 26 | self.create_output_path = os.path.join(os.getcwd(), 'SQL', 'create_output.txt') 27 | self.create_error_path = os.path.join(os.getcwd(), 'SQL', 'create_error.txt') 28 | 29 | if os.path.exists(self.create_output_path): 30 | os.remove(self.create_output_path) 31 | if os.path.exists(self.create_error_path): 32 | os.remove(self.create_error_path) 33 | 34 | def create(self): 35 | 36 | if self.use_existing_dump == True: 37 | pass 38 | else: 39 | cur_path = os.getcwd() 40 | 41 | pg_dump_path = get_pg_bin_path() 42 | if pg_dump_path != '': 43 | os.chdir(pg_dump_path) 44 | 45 | connection = '--dbname=postgresql://{0}@{2}:{3}/{4}?{1}'.format(self.source_dbc.user, urllib.parse.urlencode({'password': self.source_dbc.password}), self.source_dbc.host, self.source_dbc.port, self.source_dbc.db_name) 46 | 47 | result = subprocess.run(['pg_dump', connection, '--schema-only', '--no-owner', '--no-privileges', '--section=pre-data'] 48 | , stdout = subprocess.PIPE, stderr = subprocess.PIPE) 49 | if result.returncode != 0 or contains_errors(result.stderr): 50 | raise Exception('Capturing pre-data schema failed.
Details:\n{}'.format(result.stderr)) 51 | os.chdir(cur_path) 52 | 53 | pre_data_sql = self.__filter_commands(result.stdout.decode('utf-8')) 54 | self.run_psql(pre_data_sql) 55 | 56 | def teardown(self): 57 | user_schemas = database_helper.get_specific_helper().list_all_user_schemas(self.__source_db_connection) 58 | 59 | if len(user_schemas) == 0: 60 | raise Exception("Couldn't find any non-system schemas.") 61 | 62 | drop_statements = ["DROP SCHEMA IF EXISTS \"{}\" CASCADE".format(s) for s in user_schemas if s != 'public'] 63 | 64 | q = ';'.join(drop_statements) 65 | q += ";DROP SCHEMA IF EXISTS public CASCADE;CREATE SCHEMA IF NOT EXISTS public;" 66 | 67 | self.run_query(q) 68 | 69 | 70 | def add_constraints(self): 71 | if self.use_existing_dump == True: 72 | pass 73 | else: 74 | cur_path = os.getcwd() 75 | 76 | pg_dump_path = get_pg_bin_path() 77 | if pg_dump_path != '': 78 | os.chdir(pg_dump_path) 79 | connection = '--dbname=postgresql://{0}@{2}:{3}/{4}?{1}'.format(self.source_dbc.user, urllib.parse.urlencode({'password': self.source_dbc.password}), self.source_dbc.host, self.source_dbc.port, self.source_dbc.db_name) 80 | result = subprocess.run(['pg_dump', connection, '--schema-only', '--no-owner', '--no-privileges', '--section=post-data'] 81 | , stderr = subprocess.PIPE, stdout = subprocess.PIPE) 82 | if result.returncode != 0 or contains_errors(result.stderr): 83 | raise Exception('Capturing post-data schema failed.
Details:\n{}'.format(result.stderr)) 84 | 85 | os.chdir(cur_path) 86 | 87 | self.run_psql(result.stdout.decode('utf-8')) 88 | 89 | def __filter_commands(self, input): 90 | 91 | input = input.split('\n') 92 | filtered_key_words = [ 93 | 'COMMENT ON CONSTRAINT', 94 | 'COMMENT ON EXTENSION' 95 | ] 96 | 97 | retval = [] 98 | for line in input: 99 | l = line.rstrip() 100 | filtered = False 101 | for key in filtered_key_words: 102 | if l.startswith(key): 103 | filtered = True 104 | 105 | if not filtered: 106 | retval.append(l) 107 | 108 | return '\n'.join(retval) 109 | 110 | def run_query(self, query): 111 | 112 | pg_dump_path = get_pg_bin_path() 113 | cur_path = os.getcwd() 114 | 115 | if(pg_dump_path != ''): 116 | os.chdir(pg_dump_path) 117 | 118 | connection_info = self.destination_dbc 119 | connection_string = '--dbname=postgresql://{0}@{2}:{3}/{4}?{1}'.format( 120 | connection_info.user, urllib.parse.urlencode({'password': connection_info.password}), connection_info.host, 121 | connection_info.port, connection_info.db_name) 122 | 123 | 124 | result = subprocess.run(['psql', connection_string, '-c {0}'.format(query)], stderr = subprocess.PIPE, stdout = subprocess.DEVNULL) 125 | if result.returncode != 0 or contains_errors(result.stderr): 126 | raise Exception('Running query: "{}" failed. 
Details:\n{}'.format(query, result.stderr)) 127 | 128 | os.chdir(cur_path) 129 | 130 | def run_psql(self, queries): 131 | 132 | pg_dump_path = get_pg_bin_path() 133 | cur_path = os.getcwd() 134 | 135 | if(pg_dump_path != ''): 136 | os.chdir(pg_dump_path) 137 | 138 | connect = self.destination_dbc 139 | connection_string = '--dbname=postgresql://{0}@{2}:{3}/{4}?{1}'.format( 140 | connect.user, urllib.parse.urlencode({'password': connect.password}), connect.host, 141 | connect.port, connect.db_name) 142 | 143 | input = queries.encode('utf-8') 144 | result = subprocess.run(['psql', connection_string], stderr = subprocess.PIPE, input = input, stdout= subprocess.DEVNULL) 145 | if result.returncode != 0 or contains_errors(result.stderr): 146 | raise Exception('Creating schema failed. Details:\n{}'.format(result.stderr)) 147 | 148 | os.chdir(cur_path) 149 | 150 | def get_pg_bin_path(): 151 | if 'POSTGRES_PATH' in os.environ: 152 | pg_dump_path = os.environ['POSTGRES_PATH'] 153 | else: 154 | pg_dump_path = '' 155 | err = os.system('"' + os.path.join(pg_dump_path, 'pg_dump') + '"' + ' --help > ' + os.devnull) 156 | if err != 0: 157 | raise Exception("Couldn't find Postgres utilities, consider specifying POSTGRES_PATH environment variable if Postgres isn't " + 158 | "in your PATH.") 159 | return pg_dump_path 160 | 161 | def contains_errors(stderr): 162 | msgs = stderr.decode('utf-8') 163 | return any(filter(lambda msg: msg.strip().startswith('ERROR'), msgs.split('\n'))) 164 | -------------------------------------------------------------------------------- /psql_database_helper.py: -------------------------------------------------------------------------------- 1 | import os, uuid, csv 2 | import config_reader 3 | from pathlib import Path 4 | from psycopg2.extras import execute_values, register_default_json, register_default_jsonb 5 | from subset_utils import columns_joined, columns_tupled, schema_name, table_name, fully_qualified_table, redact_relationships, quoter 6 | 7 | 
register_default_json(loads=lambda x: str(x)) 8 | register_default_jsonb(loads=lambda x: str(x)) 9 | 10 | def prep_temp_dbs(_, __): 11 | pass 12 | 13 | def unprep_temp_dbs(_, __): 14 | pass 15 | 16 | def turn_off_constraints(connection): 17 | # can't be done in postgres 18 | pass 19 | 20 | def copy_rows(source, destination, query, destination_table): 21 | datatypes = get_table_datatypes(table_name(destination_table), schema_name(destination_table), destination) 22 | 23 | non_generated_columns = [(dt[0], dt[1]) for i, dt in enumerate(datatypes) if dt[2] != 's'] 24 | generated_columns_positions = [i for i, dt in enumerate(datatypes) if 's' in dt[2]] 25 | always_generated_id = any([dt[3] == 'a' for dt in datatypes]) 26 | 27 | def template_piece(dt): 28 | if dt == '_json': 29 | return '%s::json[]' 30 | elif dt == '_jsonb': 31 | return '%s::jsonb[]' 32 | else: 33 | return '%s' 34 | 35 | template = '(' + ','.join([template_piece(dt[1]) for dt in non_generated_columns]) + ')' 36 | columns = '("' + '","'.join([dt[0] for dt in non_generated_columns]) + '")' 37 | 38 | cursor_name='table_cursor_'+str(uuid.uuid4()).replace('-','') 39 | cursor = source.cursor(name=cursor_name) 40 | cursor.execute(query) 41 | 42 | fetch_row_count = 100000 43 | while True: 44 | rows = cursor.fetchmany(fetch_row_count) 45 | if len(rows) == 0: 46 | break 47 | 48 | # using the inner_cursor means we don't log all the noise 49 | destination_cursor = destination.cursor().inner_cursor 50 | 51 | insert_query = 'INSERT INTO {} {} VALUES %s'.format(fully_qualified_table(destination_table), columns) 52 | if (always_generated_id): 53 | insert_query = 'INSERT INTO {} {} OVERRIDING SYSTEM VALUE VALUES %s'.format(fully_qualified_table(destination_table), columns) 54 | 55 | updated_rows = [tuple(val for i, val in enumerate(row) if i not in generated_columns_positions) for row in rows] 56 | 57 | execute_values(destination_cursor, insert_query, updated_rows, template) 58 | 59 | destination_cursor.close() 60 | 61 | 
cursor.close() 62 | destination.commit() 63 | 64 | def source_db_temp_table(target_table): 65 | return 'tonic_subset_' + schema_name(target_table) + '_' + table_name(target_table) 66 | 67 | def create_id_temp_table(conn, number_of_columns): 68 | table_name = 'tonic_subset_' + str(uuid.uuid4()) 69 | cursor = conn.cursor() 70 | column_defs = ',\n'.join([' col' + str(aye) + ' varchar' for aye in range(number_of_columns)]) 71 | q = 'CREATE TEMPORARY TABLE "{}" (\n {} \n)'.format(table_name, column_defs) 72 | cursor.execute(q) 73 | cursor.close() 74 | return table_name 75 | 76 | def copy_to_temp_table(conn, query, target_table, pk_columns = None): 77 | temp_table = fully_qualified_table(source_db_temp_table(target_table)) 78 | with conn.cursor() as cur: 79 | cur.execute('CREATE TEMPORARY TABLE IF NOT EXISTS ' + temp_table + ' AS ' + query + ' LIMIT 0') 80 | if pk_columns: 81 | query = query + ' WHERE {} NOT IN (SELECT {} FROM {})'.format(columns_tupled(pk_columns), columns_joined(pk_columns), temp_table) 82 | cur.execute('INSERT INTO ' + temp_table + ' ' + query) 83 | conn.commit() 84 | 85 | def clean_temp_table_cells(fk_table, fk_columns, target_table, target_columns, conn): 86 | fk_alias = 'tonic_subset_398dhjr23_fk' 87 | target_alias = 'tonic_subset_398dhjr23_target' 88 | 89 | fk_table = fully_qualified_table(source_db_temp_table(fk_table)) 90 | target_table = fully_qualified_table(source_db_temp_table(target_table)) 91 | assignment_list = ','.join(['{} = NULL'.format(quoter(c)) for c in fk_columns]) 92 | column_matching = ' AND '.join(['{}.{} = {}.{}'.format(fk_alias, quoter(fc), target_alias, quoter(tc)) for fc, tc in zip(fk_columns, target_columns)]) 93 | q = 'UPDATE {} {} SET {} WHERE NOT EXISTS (SELECT 1 FROM {} {} WHERE {})'.format(fk_table, fk_alias, assignment_list, target_table, target_alias, column_matching) 94 | run_query(q, conn) 95 | 96 | def get_redacted_table_references(table_name, tables, conn): 97 | relationships = 
get_unredacted_fk_relationships(tables, conn) 98 | redacted = redact_relationships(relationships) 99 | return [r for r in redacted if r['target_table']==table_name] 100 | 101 | def get_unredacted_fk_relationships(tables, conn): 102 | cur = conn.cursor() 103 | 104 | q = ''' 105 | SELECT fk_nsp.nspname || '.' || fk_table AS fk_table, array_agg(fk_att.attname ORDER BY fk_att.attnum) AS fk_columns, tar_nsp.nspname || '.' || target_table AS target_table, array_agg(tar_att.attname ORDER BY fk_att.attnum) AS target_columns 106 | FROM ( 107 | SELECT 108 | fk.oid AS fk_table_id, 109 | fk.relnamespace AS fk_schema_id, 110 | fk.relname AS fk_table, 111 | unnest(con.conkey) as fk_column_id, 112 | 113 | tar.oid AS target_table_id, 114 | tar.relnamespace AS target_schema_id, 115 | tar.relname AS target_table, 116 | unnest(con.confkey) as target_column_id, 117 | 118 | con.connamespace AS constraint_nsp, 119 | con.conname AS constraint_name 120 | 121 | FROM pg_constraint con 122 | JOIN pg_class fk ON con.conrelid = fk.oid 123 | JOIN pg_class tar ON con.confrelid = tar.oid 124 | WHERE con.contype = 'f' 125 | ) sub 126 | JOIN pg_attribute fk_att ON fk_att.attrelid = fk_table_id AND fk_att.attnum = fk_column_id 127 | JOIN pg_attribute tar_att ON tar_att.attrelid = target_table_id AND tar_att.attnum = target_column_id 128 | JOIN pg_namespace fk_nsp ON fk_schema_id = fk_nsp.oid 129 | JOIN pg_namespace tar_nsp ON target_schema_id = tar_nsp.oid 130 | GROUP BY 1, 3, sub.constraint_nsp, sub.constraint_name; 131 | ''' 132 | 133 | cur.execute(q) 134 | 135 | relationships = list() 136 | 137 | for row in cur.fetchall(): 138 | d = dict() 139 | d['fk_table'] = row[0] 140 | d['fk_columns'] = row[1] 141 | d['target_table'] = row[2] 142 | d['target_columns'] = row[3] 143 | 144 | if d['fk_table'] in tables and d['target_table'] in tables: 145 | relationships.append( d ) 146 | cur.close() 147 | 148 | for augment in config_reader.get_fk_augmentation(): 149 | not_present = True 150 | for r in 
relationships: 151 | not_present = not_present and not all([r[key] == augment[key] for key in r.keys()]) 152 | if not not_present: 153 | break 154 | 155 | if augment['fk_table'] in tables and augment['target_table'] in tables and not_present: 156 | relationships.append(augment) 157 | 158 | return relationships 159 | 160 | def run_query(query, conn, commit=True): 161 | with conn.cursor() as cur: 162 | cur.execute(query) 163 | if commit: 164 | conn.commit() 165 | 166 | def get_table_count_estimate(table_name, schema, conn): 167 | with conn.cursor() as cur: 168 | cur.execute('SELECT reltuples::BIGINT AS count FROM pg_class WHERE oid=\'"{}"."{}"\'::regclass'.format(schema, table_name)) 169 | return cur.fetchone()[0] 170 | 171 | def get_table_columns(table, schema, conn): 172 | with conn.cursor() as cur: 173 | cur.execute('SELECT attname FROM pg_attribute WHERE attrelid=\'"{}"."{}"\'::regclass AND attnum > 0 AND NOT attisdropped ORDER BY attnum;'.format(schema, table)) 174 | return [r[0] for r in cur.fetchall()] 175 | 176 | def list_all_user_schemas(conn): 177 | with conn.cursor() as cur: 178 | cur.execute(r"SELECT nspname FROM pg_catalog.pg_namespace WHERE nspname NOT LIKE 'pg\_%' and nspname != 'information_schema';") 179 | return [r[0] for r in cur.fetchall()] 180 | 181 | def list_all_tables(db_connect): 182 | conn = db_connect.get_db_connection() 183 | with conn.cursor() as cur: 184 | cur.execute("""SELECT concat(concat(nsp.nspname,'.'),cls.relname) 185 | FROM pg_class cls 186 | JOIN pg_namespace nsp ON nsp.oid = cls.relnamespace 187 | WHERE nsp.nspname NOT IN ('information_schema', 'pg_catalog') AND cls.relkind = 'r';""") 188 | return [r[0] for r in cur.fetchall()] 189 | 190 | def get_table_datatypes(table, schema, conn): 191 | if not schema: 192 | table_clause = "cl.relname = '{}'".format(table) 193 | else: 194 | table_clause = "cl.relname = '{}' AND ns.nspname = '{}'".format(table, schema) 195 | with conn.cursor() as cur: 196 | cur.execute("""SELECT att.attname,
ty.typname, att.attgenerated, att.attidentity 197 | FROM pg_attribute att 198 | JOIN pg_class cl ON cl.oid = att.attrelid 199 | JOIN pg_type ty ON ty.oid = att.atttypid 200 | JOIN pg_namespace ns ON ns.oid = cl.relnamespace 201 | WHERE {} AND att.attnum > 0 AND 202 | NOT att.attisdropped 203 | ORDER BY att.attnum; 204 | """.format(table_clause)) 205 | 206 | return [(r[0], r[1], r[2], r[3]) for r in cur.fetchall()] 207 | 208 | def truncate_table(target_table, conn): 209 | with conn.cursor() as cur: 210 | cur.execute("TRUNCATE TABLE {}".format(target_table)) 211 | conn.commit() 212 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | toposort 2 | psycopg2-binary 3 | mysql-connector-python 4 | -------------------------------------------------------------------------------- /result_tabulator.py: -------------------------------------------------------------------------------- 1 | import config_reader 2 | import database_helper 3 | from db_connect import MySqlConnection 4 | 5 | 6 | def tabulate(source_dbc, destination_dbc, tables): 7 | #tabulate 8 | row_counts = list() 9 | source_conn = source_dbc.get_db_connection() 10 | dest_conn = destination_dbc.get_db_connection() 11 | db_helper = database_helper.get_specific_helper() 12 | try: 13 | for table in tables: 14 | o = db_helper.get_table_count_estimate(table_name(table), schema_name(table), source_conn) 15 | dest_schema_name = dest_conn.db_name if isinstance(dest_conn, MySqlConnection) else schema_name(table) 16 | n = db_helper.get_table_count_estimate(table_name(table), dest_schema_name, dest_conn) 17 | row_counts.append((table,o,n)) 18 | finally: 19 | source_conn.close() 20 | dest_conn.close() 21 | 22 | print('\n'.join(['{}, {}, {}, {}'.format(x[0], x[1], x[2], x[2]/x[1] if x[1] > 0 else 0) for x in row_counts])) 23 | 24 | 25 | def schema_name(table): 26 | return table.split('.')[0] 27 | 
28 | def table_name(table): 29 | return table.split('.')[1] 30 | -------------------------------------------------------------------------------- /subset.py: -------------------------------------------------------------------------------- 1 | from topo_orderer import get_topological_order_by_tables 2 | from subset_utils import UnionFind, schema_name, table_name, find, compute_disconnected_tables, compute_downstream_tables, compute_upstream_tables, columns_joined, columns_tupled, columns_to_copy, quoter, fully_qualified_table, print_progress, mysql_db_name_hack, upstream_filter_match, redact_relationships 3 | import database_helper 4 | import config_reader 5 | import shutil, os, uuid, time, itertools 6 | 7 | # 8 | # A QUICK NOTE ON DEFINITIONS: 9 | # 10 | # Foreign key relationships form a graph. We make sure all subsetting happens on DAGs. 11 | # Nodes in the DAG are tables, and FKs point from the table with a FK column to the table 12 | # with the PK column. In other words, tables with FKs are upstream of tables with PKs. 13 | # 14 | # Sometimes we'll refer to tables as downstream or 'target' tables, because they are 15 | # targeted by foreign keys. We will also use upstream or 'fk' tables, because they 16 | # have foreign keys. 17 | # 18 | # Generally speaking, tables downstream of other tables have their membership defined 19 | # by the requirements of their upstream tables. And tables upstream can be more flexible 20 | # about their membership vis-a-vis the downstream tables (i.e. upstream tables can decide 21 | # to include more or less). 
22 | # 23 | 24 | class Subset: 25 | 26 | def __init__(self, source_dbc, destination_dbc, all_tables, clean_previous = True): 27 | self.__source_dbc = source_dbc 28 | self.__destination_dbc = destination_dbc 29 | 30 | self.__source_conn = source_dbc.get_db_connection(read_repeatable=True) 31 | self.__destination_conn = destination_dbc.get_db_connection() 32 | 33 | self.__all_tables = all_tables 34 | 35 | self.__db_helper = database_helper.get_specific_helper() 36 | 37 | self.__db_helper.turn_off_constraints(self.__destination_conn) 38 | 39 | 40 | def run_middle_out(self): 41 | passthrough_tables = self.__get_passthrough_tables() 42 | relationships = self.__db_helper.get_unredacted_fk_relationships(self.__all_tables, self.__source_conn) 43 | disconnected_tables = compute_disconnected_tables(config_reader.get_initial_target_tables(), passthrough_tables, self.__all_tables, relationships) 44 | connected_tables = [table for table in self.__all_tables if table not in disconnected_tables] 45 | order = get_topological_order_by_tables(relationships, connected_tables) 46 | order = list(order) 47 | 48 | # start by subsetting the direct targets 49 | print('Beginning subsetting with these direct targets: ' + str(config_reader.get_initial_target_tables())) 50 | start_time = time.time() 51 | processed_tables = set() 52 | for idx, target in enumerate(config_reader.get_initial_targets()): 53 | print_progress(target, idx+1, len(config_reader.get_initial_targets())) 54 | self.__subset_direct(target, relationships) 55 | processed_tables.add(target['table']) 56 | print('Direct target tables completed in {}s'.format(time.time()-start_time)) 57 | 58 | # greedily grab rows with foreign keys to rows in the target strata 59 | upstream_tables = compute_upstream_tables(config_reader.get_initial_target_tables(), order) 60 | print('Beginning greedy upstream subsetting with these tables: ' + str(upstream_tables)) 61 | start_time = time.time() 62 | for idx, t in enumerate(upstream_tables): 63 | 
print_progress(t, idx+1, len(upstream_tables)) 64 | data_added = self.__subset_upstream(t, processed_tables, relationships) 65 | if data_added: 66 | processed_tables.add(t) 67 | print('Greedy subsetting completed in {}s'.format(time.time()-start_time)) 68 | 69 | # process pass-through tables; this must happen before subset_downstream so that all required downstream rows can be found 70 | print('Beginning pass-through tables: ' + str(passthrough_tables)) 71 | start_time = time.time() 72 | for idx, t in enumerate(passthrough_tables): 73 | print_progress(t, idx+1, len(passthrough_tables)) 74 | q = 'SELECT * FROM {}'.format(fully_qualified_table(t)) 75 | if config_reader.get_max_rows_per_table() is not None: 76 | q += ' LIMIT {}'.format(config_reader.get_max_rows_per_table()) 77 | self.__db_helper.copy_rows(self.__source_conn, self.__destination_conn, q, mysql_db_name_hack(t, self.__destination_conn)) 78 | print('Pass-through completed in {}s'.format(time.time()-start_time)) 79 | 80 | # use subset_downstream to get all supporting rows according to existing needs 81 | downstream_tables = compute_downstream_tables(passthrough_tables, disconnected_tables, order) 82 | print('Beginning downstream subsetting with these tables: ' + str(downstream_tables)) 83 | start_time = time.time() 84 | for idx, t in enumerate(downstream_tables): 85 | print_progress(t, idx+1, len(downstream_tables)) 86 | self.subset_downstream(t, relationships) 87 | print('Downstream subsetting completed in {}s'.format(time.time()-start_time)) 88 | 89 | if config_reader.keep_disconnected_tables(): 90 | # get all the data for tables in disconnected components (i.e.
pass those tables through) 91 | print('Beginning disconnected tables: ' + str(disconnected_tables)) 92 | start_time = time.time() 93 | for idx, t in enumerate(disconnected_tables): 94 | print_progress(t, idx+1, len(disconnected_tables)) 95 | q = 'SELECT * FROM {}'.format(fully_qualified_table(t)) 96 | self.__db_helper.copy_rows(self.__source_conn, self.__destination_conn, q, mysql_db_name_hack(t, self.__destination_conn)) 97 | print('Disconnected tables completed in {}s'.format(time.time()-start_time)) 98 | 99 | def prep_temp_dbs(self): 100 | self.__db_helper.prep_temp_dbs(self.__source_conn, self.__destination_conn) 101 | 102 | def unprep_temp_dbs(self): 103 | self.__db_helper.unprep_temp_dbs(self.__source_conn, self.__destination_conn) 104 | 105 | def __subset_direct(self, target, relationships): 106 | t = target['table'] 107 | columns_query = columns_to_copy(t, relationships, self.__source_conn) 108 | if 'where' in target: 109 | q = 'SELECT {} FROM {} WHERE {}'.format(columns_query, fully_qualified_table(t), target['where']) 110 | elif 'percent' in target: 111 | if config_reader.get_db_type() == 'postgres': 112 | q = 'SELECT {} FROM {} WHERE random() < {}'.format(columns_query, fully_qualified_table(t), float(target['percent'])/100) 113 | else: 114 | q = 'SELECT {} FROM {} WHERE rand() < {}'.format(columns_query, fully_qualified_table(t), float(target['percent'])/100) 115 | else: 116 | raise ValueError('target table {} had no \'where\' or \'percent\' term defined, check your configuration.'.format(t)) 117 | self.__db_helper.copy_rows(self.__source_conn, self.__destination_conn, q, mysql_db_name_hack(t, self.__destination_conn)) 118 | 119 | 120 | def __subset_upstream(self, target, processed_tables, relationships): 121 | 122 | redacted_relationships = redact_relationships(relationships) 123 | relevant_key_constraints = list(filter(lambda r: r['target_table'] in processed_tables and r['fk_table'] == target, redacted_relationships)) 124 | # this table isn't 
referenced by anything we've already processed, so let's leave it empty 125 | # OR 126 | # table was already added, this only happens if the upstream table was also a direct target 127 | if len(relevant_key_constraints) == 0 or target in processed_tables: 128 | return False 129 | 130 | temp_target_name = 'subset_temp_' + table_name(target) 131 | 132 | try: 133 | # copy the whole table 134 | columns_query = columns_to_copy(target, relationships, self.__source_conn) 135 | self.__db_helper.run_query('CREATE TEMPORARY TABLE {} AS SELECT * FROM {} LIMIT 0'.format(quoter(temp_target_name), fully_qualified_table(mysql_db_name_hack(target, self.__destination_conn))), self.__destination_conn) 136 | query = 'SELECT {} FROM {}'.format(columns_query, fully_qualified_table(target)) 137 | self.__db_helper.copy_rows(self.__source_conn, self.__destination_conn, query, temp_target_name) 138 | 139 | # filter it down in the target database 140 | table_columns = self.__db_helper.get_table_columns(table_name(target), schema_name(target), self.__source_conn) 141 | clauses = ['{} IN (SELECT {} FROM {})'.format(columns_tupled(kc['fk_columns']), columns_joined(kc['target_columns']), fully_qualified_table(mysql_db_name_hack(kc['target_table'], self.__destination_conn))) for kc in relevant_key_constraints] 142 | clauses.extend(upstream_filter_match(target, table_columns)) 143 | 144 | select_query = 'SELECT * FROM {} WHERE TRUE AND {}'.format(quoter(temp_target_name), ' AND '.join(clauses)) 145 | if config_reader.get_max_rows_per_table() is not None: 146 | select_query += " LIMIT {}".format(config_reader.get_max_rows_per_table()) 147 | insert_query = 'INSERT INTO {} {}'.format(fully_qualified_table(mysql_db_name_hack(target, self.__destination_conn)), select_query) 148 | self.__db_helper.run_query(insert_query, self.__destination_conn) 149 | self.__destination_conn.commit() 150 | 151 | finally: 152 | # delete temporary table 153 | mysql_temporary = 'TEMPORARY' if config_reader.get_db_type() 
== 'mysql' else '' 154 | self.__db_helper.run_query('DROP {} TABLE IF EXISTS {}'.format(mysql_temporary, quoter(temp_target_name)), self.__destination_conn) 155 | 156 | return True 157 | 158 | 159 | def __get_passthrough_tables(self): 160 | passthrough_tables = config_reader.get_passthrough_tables() 161 | return list(set(passthrough_tables)) 162 | 163 | # Table A -> Table B and Table A has the column b_id. So we SELECT b_id from table_a from our destination 164 | # database. And we take those b_ids and run `select * from table b where id in (those list of ids)` then insert 165 | # that result set into table b of the destination database 166 | def subset_downstream(self, table, relationships): 167 | referencing_tables = self.__db_helper.get_redacted_table_references(table, self.__all_tables, self.__source_conn) 168 | 169 | if len(referencing_tables) > 0: 170 | pk_columns = referencing_tables[0]['target_columns'] 171 | else: 172 | return 173 | 174 | temp_table = self.__db_helper.create_id_temp_table(self.__destination_conn, len(pk_columns)) 175 | 176 | for r in referencing_tables: 177 | fk_table = r['fk_table'] 178 | fk_columns = r['fk_columns'] 179 | 180 | q='SELECT {} FROM {} WHERE {} NOT IN (SELECT {} FROM {})'.format(columns_joined(fk_columns), fully_qualified_table(mysql_db_name_hack(fk_table, self.__destination_conn)), columns_tupled(fk_columns), columns_joined(pk_columns), fully_qualified_table(mysql_db_name_hack(table, self.__destination_conn))) 181 | self.__db_helper.copy_rows(self.__destination_conn, self.__destination_conn, q, temp_table) 182 | 183 | columns_query = columns_to_copy(table, relationships, self.__source_conn) 184 | 185 | cursor_name='table_cursor_'+str(uuid.uuid4()).replace('-','') 186 | cursor = self.__destination_conn.cursor(name=cursor_name, withhold=True) 187 | cursor_query ='SELECT DISTINCT * FROM {}'.format(fully_qualified_table(temp_table)) 188 | cursor.execute(cursor_query) 189 | fetch_row_count = 100000 190 | while True: 191 | rows = 
cursor.fetchmany(fetch_row_count) 192 | if len(rows) == 0: 193 | break 194 | 195 | ids = ['('+','.join(['\'' + str(c) + '\'' for c in row])+')' for row in rows if all([c is not None for c in row])] 196 | 197 | if len(ids) == 0: 198 | continue 199 | 200 | ids_to_query = ','.join(ids) 201 | q = 'SELECT {} FROM {} WHERE {} IN ({})'.format(columns_query, fully_qualified_table(table), columns_tupled(pk_columns), ids_to_query) 202 | self.__db_helper.copy_rows(self.__source_conn, self.__destination_conn, q, mysql_db_name_hack(table, self.__destination_conn)) 203 | 204 | cursor.close() 205 | -------------------------------------------------------------------------------- /subset_utils.py: -------------------------------------------------------------------------------- 1 | import config_reader 2 | import database_helper 3 | from db_connect import MySqlConnection 4 | 5 | # this function generally copies all columns as is, but if the table has been selected as 6 | # breaking a dependency cycle, then it will insert NULLs instead of that table's foreign keys 7 | # to the downstream dependency that breaks the cycle 8 | def columns_to_copy(table, relationships, conn): 9 | target_breaks = set() 10 | opportunists = config_reader.get_preserve_fk_opportunistically() 11 | for dep_break in config_reader.get_dependency_breaks(): 12 | if dep_break.fk_table == table and dep_break not in opportunists: 13 | target_breaks.add(dep_break.target_table) 14 | 15 | columns_to_null = set() 16 | for rel in relationships: 17 | if rel['fk_table'] == table and rel['target_table'] in target_breaks: 18 | columns_to_null.update(rel['fk_columns']) 19 | 20 | columns = database_helper.get_specific_helper().get_table_columns(table_name(table), schema_name(table), conn) 21 | return ','.join(['{}.{}'.format(quoter(table_name(table)), quoter(c)) if c not in columns_to_null else 'NULL as {}'.format(quoter(c)) for c in columns]) 22 | 23 | def upstream_filter_match(target, table_columns): 24 | retval = [] 25 | filters
= config_reader.get_upstream_filters() 26 | for filter in filters: 27 | if "table" in filter and target == filter["table"]: 28 | retval.append(filter["condition"]) 29 | if "column" in filter and filter["column"] in table_columns: 30 | retval.append(filter["condition"]) 31 | return retval 32 | 33 | def redact_relationships(relationships): 34 | breaks = config_reader.get_dependency_breaks() 35 | retval = [r for r in relationships if (r['fk_table'], r['target_table']) not in breaks] 36 | return retval 37 | 38 | def find(f, seq): 39 | """Return first item in sequence where f(item) == True.""" 40 | for item in seq: 41 | if f(item): 42 | return item 43 | 44 | def compute_upstream_tables(target_tables, order): 45 | upstream_tables = [] 46 | in_upstream = False 47 | for strata in order: 48 | if in_upstream: 49 | upstream_tables.extend(strata) 50 | if any([tt in strata for tt in target_tables]): 51 | in_upstream = True 52 | return upstream_tables 53 | 54 | def compute_downstream_tables(passthrough_tables, disconnected_tables, order): 55 | downstream_tables = [] 56 | for strata in order: 57 | downstream_tables.extend(strata) 58 | downstream_tables = list(reversed(list(filter(lambda table: table not in passthrough_tables and table not in disconnected_tables, downstream_tables)))) 59 | return downstream_tables 60 | 61 | def compute_disconnected_tables(target_tables, passthrough_tables, all_tables, relationships): 62 | uf = UnionFind() 63 | for t in all_tables: 64 | uf.make_set(t) 65 | for rel in relationships: 66 | uf.link(rel['fk_table'], rel['target_table']) 67 | 68 | connected_components = set([uf.find(tt) for tt in target_tables]) 69 | connected_components.update([uf.find(pt) for pt in passthrough_tables]) 70 | return [t for t in all_tables if uf.find(t) not in connected_components] 71 | 72 | def fully_qualified_table(table): 73 | if '.' in table: 74 | return quoter(schema_name(table)) + '.' 
+ quoter(table_name(table)) 75 | else: 76 | return quoter(table_name(table)) 77 | 78 | def schema_name(table): 79 | return table.split('.')[0] if '.' in table else None 80 | 81 | def table_name(table): 82 | split = table.split('.') 83 | return split[1] if len(split) > 1 else split[0] 84 | 85 | def columns_tupled(columns): 86 | return '(' + ','.join([quoter(c) for c in columns]) + ')' 87 | 88 | def columns_joined(columns): 89 | return ','.join([quoter(c) for c in columns]) 90 | 91 | def quoter(id): 92 | q = '"' if config_reader.get_db_type() == 'postgres' else '`' 93 | return q + id + q 94 | 95 | def print_progress(target, idx, count): 96 | print('Processing {} of {}: {}'.format(idx, count, target)) 97 | 98 | class UnionFind: 99 | 100 | def __init__(self): 101 | self.elementsToId = dict() 102 | self.elements = [] 103 | self.roots = [] 104 | self.ranks = [] 105 | 106 | def __len__(self): 107 | return len(self.roots) 108 | 109 | def make_set(self, elem): 110 | self.id_of(elem) 111 | 112 | def find(self, elem): 113 | x = self.elementsToId[elem] 114 | if x == None: 115 | return None 116 | 117 | rootId = self.find_internal(x) 118 | return self.elements[rootId] 119 | 120 | def find_internal(self, x): 121 | x0 = x 122 | while self.roots[x] != x: 123 | x = self.roots[x] 124 | 125 | while self.roots[x0] != x: 126 | y = self.roots[x0] 127 | self.roots[x0] = x 128 | x0 = y 129 | 130 | return x 131 | 132 | def id_of(self, elem): 133 | if elem not in self.elementsToId: 134 | idx = len(self.roots) 135 | self.elements.append(elem) 136 | self.elementsToId[elem] = idx 137 | self.roots.append(idx) 138 | self.ranks.append(0) 139 | 140 | return self.elementsToId[elem] 141 | 142 | def link(self, elem1, elem2): 143 | x = self.id_of(elem1) 144 | y = self.id_of(elem2) 145 | 146 | xr = self.find_internal(x) 147 | yr = self.find_internal(y) 148 | if xr == yr: 149 | return 150 | 151 | xd = self.ranks[xr] 152 | yd = self.ranks[yr] 153 | if xd < yd: 154 | self.roots[xr] = yr 155 | elif yd < xd: 
            self.roots[yr] = xr
        else:
            self.roots[yr] = xr
            self.ranks[xr] = self.ranks[xr] + 1

    def members_of(self, elem):
        elem_id = self.elementsToId.get(elem)
        if elem_id is None:
            raise ValueError("tried calling members_of on an unknown element")

        elemRoot = self.find_internal(elem_id)
        retval = []
        for idx in range(len(self.elements)):
            otherRoot = self.find_internal(idx)
            if elemRoot == otherRoot:
                retval.append(self.elements[idx])

        return retval

def mysql_db_name_hack(target, conn):
    if not isinstance(conn, MySqlConnection) or '.' not in target:
        return target
    else:
        return conn.db_name + '.' + table_name(target)
--------------------------------------------------------------------------------
/topo_orderer.py:
--------------------------------------------------------------------------------
from toposort import toposort
import config_reader

def get_topological_order_by_tables(relationships, tables):
    topsort_input = __prepare_topsort_input(relationships, tables)
    return list(toposort(topsort_input))

def __prepare_topsort_input(relationships, tables):
    dep_breaks = config_reader.get_dependency_breaks()
    deps = dict()
    for r in relationships:
        p = r['fk_table']
        c = r['target_table']

        # skip relationships configured as dependency breaks; this is how cycles are broken
        dep_break_found = False
        for dep_break in dep_breaks:
            if p == dep_break.fk_table and c == dep_break.target_table:
                dep_break_found = True
                break

        if dep_break_found:
            continue

        # toposort silently drops self-dependencies, but we cannot ignore them
        if p == c:
            raise ValueError('Circular dependency, {} depends on itself!'.format(p))

        if tables is not None and len(tables) > 0 and (p not in tables or c not in tables):
            continue

        deps.setdefault(p, set()).add(c)

    return deps
--------------------------------------------------------------------------------
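`compute_disconnected_tables` treats every foreign key as an undirected edge and uses union-find to drop each table that shares no connected component with a target or passthrough table. The following is a minimal standalone sketch of that idea, not Condenser's own class. It assumes relationships are given as plain `(fk_table, target_table)` pairs, and the names `DisjointSet` and `disconnected` are hypothetical.

```python
from typing import Dict, Iterable, List, Tuple


class DisjointSet:
    """Hypothetical union-find with path compression and union by rank."""

    def __init__(self) -> None:
        self.parent: Dict[str, str] = {}
        self.rank: Dict[str, int] = {}

    def add(self, x: str) -> None:
        if x not in self.parent:
            self.parent[x] = x
            self.rank[x] = 0

    def find(self, x: str) -> str:
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # path compression: re-point every node on the path at the root
        while self.parent[x] != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, a: str, b: str) -> None:
        self.add(a)
        self.add(b)
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        # union by rank: hang the shallower tree under the deeper one
        if self.rank[ra] < self.rank[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        if self.rank[ra] == self.rank[rb]:
            self.rank[ra] += 1


def disconnected(keep: Iterable[str], all_tables: Iterable[str],
                 edges: Iterable[Tuple[str, str]]) -> List[str]:
    """Tables sharing no component with any table in `keep` (hypothetical helper)."""
    tables = list(all_tables)
    keep = list(keep)
    ds = DisjointSet()
    for t in tables + keep:
        ds.add(t)
    for fk_table, target_table in edges:
        ds.union(fk_table, target_table)
    kept_roots = {ds.find(t) for t in keep}
    return [t for t in tables if ds.find(t) not in kept_roots]
```

With `users`, `orders`, and `order_items` linked by foreign keys and `audit_detail` referencing only `audit_log`, keeping `users` leaves the two audit tables as the disconnected set. Because they can never be reached from a target, they can be handled separately.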
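`get_topological_order_by_tables` passes the `toposort` package a map from each referencing table to the set of tables it references. The package then yields sets of tables with dependencies first. To show what that layering looks like without the third-party dependency, here is a small Kahn-style sketch over plain dicts. `build_deps` and `strata` are hypothetical names, and dependency breaks are modeled as `(fk_table, target_table)` tuples instead of being read from `config_reader`.

```python
from typing import Dict, Iterable, List, Set, Tuple

Edge = Tuple[str, str]  # (fk_table, target_table)


def build_deps(relationships: Iterable[Edge], breaks: Set[Edge]) -> Dict[str, Set[str]]:
    deps: Dict[str, Set[str]] = {}
    for fk_table, target_table in relationships:
        if (fk_table, target_table) in breaks:
            continue  # a configured dependency break removes this edge
        if fk_table == target_table:
            raise ValueError('Circular dependency, {} depends on itself!'.format(fk_table))
        deps.setdefault(fk_table, set()).add(target_table)
    return deps


def strata(deps: Dict[str, Set[str]]) -> List[Set[str]]:
    # also include tables that only ever appear as a referenced target
    remaining = {t: set(d) for t, d in deps.items()}
    for d in deps.values():
        for t in d:
            remaining.setdefault(t, set())
    layers: List[Set[str]] = []
    while remaining:
        # every table whose dependencies are all placed forms the next layer
        ready = {t for t, d in remaining.items() if not d}
        if not ready:
            raise ValueError('cycle among: {}'.format(sorted(remaining)))
        layers.append(ready)
        remaining = {t: d - ready for t, d in remaining.items() if t not in ready}
    return layers
```

With `users` and `orgs` referencing each other, the cycle makes `strata` raise. Declaring `('orgs', 'users')` as a break removes that edge and yields the layers `{'orgs'}`, `{'users'}`, `{'orders'}`, `{'order_items'}`.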
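The comment on `columns_to_copy` explains that a table chosen to break a cycle has its foreign-key columns toward the other side of the break replaced with `NULL`. Below is a minimal sketch of just the select-list construction. The configuration lookups are replaced by plain arguments (`relationships`, `broken_targets`), and Postgres-style double-quote quoting is assumed. `select_list` and `quote` are hypothetical names.

```python
from typing import Dict, List, Set


def quote(identifier: str) -> str:
    # Postgres-style identifier quoting (assumption; MySQL would use backticks)
    return '"' + identifier + '"'


def select_list(table: str, columns: List[str], relationships: List[Dict],
                broken_targets: Set[str]) -> str:
    # FK columns of `table` that point at a table on the far side of a break
    to_null: Set[str] = set()
    for rel in relationships:
        if rel['fk_table'] == table and rel['target_table'] in broken_targets:
            to_null.update(rel['fk_columns'])
    parts = []
    for c in columns:
        if c in to_null:
            # keep the column name so the INSERT shape is unchanged, but drop the value
            parts.append('NULL as {}'.format(quote(c)))
        else:
            parts.append('{}.{}'.format(quote(table), quote(c)))
    return ','.join(parts)
```

Selecting `NULL as "owner_id"` keeps the column in place for the copy, so the destination row still satisfies its shape while the broken foreign key no longer has to resolve.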
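The downstream step at the end of `subset.py` reads key tuples from a server-side cursor in batches of `fetch_row_count`, drops tuples that contain `NULL`, and turns each batch into a `WHERE (...) IN (...)` query against the source database. The sketch below isolates that batching logic, assuming rows arrive as an iterable of tuples. `batched_in_queries` and `sql_literal` are hypothetical names. The sketch doubles embedded single quotes (standard SQL; this is not a complete escaping scheme for MySQL's default backslash handling). It also skips an all-`NULL` batch instead of stopping, because later batches may still hold usable keys. Letting the database driver bind parameters would be preferable to building literals by hand.

```python
from itertools import islice
from typing import Iterable, Iterator, Sequence


def sql_literal(value: object) -> str:
    # standard-SQL string literal: embedded single quotes are doubled
    return "'" + str(value).replace("'", "''") + "'"


def batched_in_queries(select_prefix: str, key_sql: str,
                       rows: Iterable[Sequence[object]], batch_size: int) -> Iterator[str]:
    it = iter(rows)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            break  # source exhausted
        # one literal tuple per key; a tuple containing NULL can never match, so skip it
        tuples = ['(' + ','.join(sql_literal(c) for c in row) + ')'
                  for row in batch if all(c is not None for c in row)]
        if not tuples:
            continue  # all-NULL batch: keep reading
        yield '{} WHERE {} IN ({})'.format(select_prefix, key_sql, ','.join(tuples))
```

In this sketch `select_prefix` stands in for the `SELECT <columns> FROM <table>` part built by `columns_to_copy` and `fully_qualified_table`, and `key_sql` for the output of `columns_tupled(pk_columns)`.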