├── .gitignore ├── LICENSE ├── README.md ├── docs ├── artificial_data_admin_guide.md ├── artificial_data_user_notice.md └── build_docs │ ├── README.md │ ├── create_user_docs_template.py │ └── make_create_user_docs.py ├── environment.yml ├── notebooks ├── admin │ └── collab │ │ ├── artificial_hes │ │ ├── 1_scrape_metadata.py │ │ ├── 2_review_metadata.py │ │ ├── 3_release_metadata.py │ │ ├── 4_generate_artificial_data.py │ │ ├── 5_qa_artificial_data.py │ │ └── 6_release_artificial_data.py │ │ ├── example_dotenv.py │ │ ├── job_utils.py │ │ └── meta_review │ │ ├── assertions.py │ │ ├── assertions_tests.py │ │ ├── field_definitions.py │ │ ├── regex_patterns.py │ │ └── regex_patterns_tests.py └── user │ └── collab │ └── artificial_data_user_notice.py ├── projects ├── artificial_hes │ ├── cp_config.py │ ├── init_schemas.py │ ├── notebooks │ │ ├── code_promotion_paths.py │ │ └── common │ │ │ ├── coerce_schema.py │ │ │ ├── common_exports.py │ │ │ ├── spark_helpers.py │ │ │ ├── table_helpers.py │ │ │ └── widget_utils.py │ ├── run_notebooks.py │ ├── run_tests.py │ ├── schemas │ │ └── create_user_docs.py │ └── tests │ │ └── placeholder.py ├── artificial_hes_meta │ ├── cp_config.py │ ├── init_schemas.py │ ├── notebooks │ │ ├── code_promotion_paths.py │ │ └── common │ │ │ ├── coerce_schema.py │ │ │ ├── common_exports.py │ │ │ ├── spark_helpers.py │ │ │ ├── table_helpers.py │ │ │ └── widget_utils.py │ ├── run_notebooks.py │ ├── run_tests.py │ ├── schemas │ │ └── placeholder.py │ └── tests │ │ └── placeholder.py ├── iuod_artificial_data_admin │ ├── cp_config.py │ ├── init_schemas.py │ ├── notebooks │ │ ├── code_promotion_paths.py │ │ ├── code_promotion_versions.py │ │ └── widget_utils.py │ ├── run_notebooks.py │ ├── run_tests.py │ ├── schemas │ │ └── placeholder.py │ └── tests │ │ └── placeholder.py └── iuod_artificial_data_generator │ ├── cp_config.py │ ├── init_schemas.py │ ├── notebooks │ ├── admin │ │ └── driver.py │ ├── common │ │ ├── coerce_schema.py │ │ ├── common_exports.py │ │ ├── spark_helpers.py │ │ ├── table_helpers.py │ │ └── widget_utils.py │ ├── dataset_definitions │ │ ├── hes │ │ │ ├── hes_ae_schema.py │ │ │ ├── hes_ae_tables.py │ │ │ ├── hes_apc_schema.py │ │ │ ├── hes_apc_tables.py │ │ │ ├── hes_op_schema.py │ │ │ ├── hes_op_tables.py │ │ │ ├── hes_patient_table.py │ │ │ ├── hes_schemas.py │ │ │ └── hes_tables.py │ │ ├── relational.py │ │ └── relational_helpers.py │ ├── dependencies │ │ ├── random.py │ │ ├── rstr.py │ │ ├── spark_rstr.py │ │ └── tests │ │ │ ├── test_rstr.py │ │ │ └── test_spark_rstr.py │ ├── generator_pipelines │ │ ├── artificial_hes │ │ │ ├── coerce_hes_schema.py │ │ │ ├── demographic_field_generators.py │ │ │ ├── driver.py │ │ │ └── field_definitions │ │ │ │ ├── derivations │ │ │ │ ├── age_fields.py │ │ │ │ ├── coded_field_lists.py │ │ │ │ ├── coded_fields.py │ │ │ │ ├── derivations_helpers.py │ │ │ │ ├── sequential_field_lists.py │ │ │ │ └── sequential_fields.py │ │ │ │ ├── id_field_lists.py │ │ │ │ ├── id_field_regex_patterns.py │ │ │ │ └── id_fields.py │ │ ├── driver_imports.py │ │ └── driver_parameters.py │ ├── generator_stages │ │ ├── derivations.py │ │ ├── field_generators.py │ │ ├── relationship_generator.py │ │ └── sampling │ │ │ ├── cdf_sampling.py │ │ │ └── field_definitions.py │ ├── scraper_pipelines │ │ ├── driver_imports.py │ │ └── hes │ │ │ ├── constants │ │ │ ├── disclosure_control_parameters.py │ │ │ ├── excluded_fields.py │ │ │ └── type_overrides.py │ │ │ ├── demographic_aggregation_parameters.py │ │ │ ├── demographic_field_summarisers.py │ │ │ └── 
driver.py │ └── scraper_stages │ │ ├── aggregation │ │ ├── aggregation_parameters.py │ │ ├── field_summarisers.py │ │ ├── relationship_summariser.py │ │ └── summariser_factory.py │ │ ├── ingestion.py │ │ ├── preprocessing │ │ └── meta_type_classifier.py │ │ ├── schemas │ │ ├── long_schema.py │ │ └── meta_schema.py │ │ ├── scrape_metadata.py │ │ └── wide_to_long.py │ ├── run_notebooks.py │ ├── run_tests.py │ ├── schemas │ ├── create_meta_table.py │ └── uplifts │ │ ├── artificial_hes_schema_coercion.py │ │ ├── artificial_hes_schema_coercion_tests.py │ │ ├── open_data_metadata_uplift.py │ │ └── open_data_metadata_uplift_tests.py │ └── tests │ ├── admin_tests │ └── placeholder_tests.py │ ├── common_tests │ ├── run_tests.py │ └── test_spark_helpers.py │ ├── dataset_definitions_tests │ ├── placeholder_tests.py │ └── run_tests.py │ ├── generator_pipelines_tests │ ├── artificial_hes_tests │ │ ├── demographic_field_generators_tests.py │ │ ├── field_definitions_tests │ │ │ ├── derivations_tests │ │ │ │ ├── age_fields_tests.py │ │ │ │ ├── coded_fields_tests.py │ │ │ │ ├── derivations_helpers_tests.py │ │ │ │ └── sequential_fields_tests.py │ │ │ ├── id_fields_tests.py │ │ │ └── imports.py │ │ └── run_tests.py │ └── run_tests.py │ ├── generator_stages_tests │ ├── field_generators_tests.py │ ├── relationship_generator_tests.py │ ├── run_tests.py │ └── sampling_tests │ │ └── field_definitions_tests.py │ ├── scraper_pipelines_tests │ ├── hes │ │ └── demographic_field_summarisers_tests.py │ └── run_tests.py │ ├── scraper_stages_tests │ ├── aggregation_tests │ │ ├── field_summarisers_tests.py │ │ └── relationship_summariser_tests.py │ ├── preprocessing_tests │ │ └── test_meta_type_classifier.py │ └── run_tests.py │ └── test_helpers.py └── utils ├── export.ps1 ├── import.ps1 └── list_releases.ps1 /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | 4 | # Environments 5 | *.bak 6 | *.env 7 | .venv 8 | 9 | .vscode 10 | 11 | # Notebooks containing credentials 12 | dotenv.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 NHS Digital 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /docs/artificial_data_admin_guide.md: -------------------------------------------------------------------------------- 1 | ## Guide for artificial data admins 2 | This section is aimed at administrators of artificial data, working in an environment where they can trigger the different stages of the process. 3 | 4 | ### How do I extract metadata? 5 | To extract metadata for a dataset, run the `run_notebooks.py` notebook in the corresponding `artificial_{dataset_name}_meta` project (for example `artificial_hes_meta/run_notebooks.py`). 6 | 7 | By default, this will read data from a database named according to the `{dataset_name}` parameter in the template string above and write to a database called `artificial_{dataset_name}_meta`. These can be changed via the notebook widgets. 8 | 9 | When running in production on live data, this should be done by triggering the `run_notebooks` job for the respective project, as this will have the privileges to access the live data. 10 | Approved human users will only have access to aggregates for review, not to the underlying record-level data. 11 | 12 | ### What happens after metadata is extracted? 13 | Once the metadata has been extracted, it should be manually reviewed by a member of staff working in a secure environment to ensure no personally identifiable information (PII) is disclosed. This should be signed off by a senior member of staff. 14 | 15 | At NHS Digital we have a checklist that was approved by the Statistical Disclosure Control Panel, chaired by the Chief Statistician. 16 | 17 | Once we have checked the metadata and signed it off, we move it into a database which is inaccessible to the metadata scraper and so is completely isolated from the database containing the real data. 18 | This is done by executing the `run_notebooks.py` notebook in the `iuod_artificial_data_admin` project. 19 | 20 | When running in production on live data, this should be done by triggering the `run_notebooks` job for the `iuod_artificial_data_admin` project, as this will have the privileges to read from / write to the appropriate databases. 21 | 22 | ### How do I generate artificial data? 23 | To generate data and run postprocessing for a dataset, run the `run_notebooks.py` notebook in the `iuod_artificial_data_generator` project with the name of the dataset to generate artificial data for entered accordingly. 24 | For example, for artificial HES data, set the 'artificial_dataset' parameter to 'hes'. 25 | 26 | By default, this process will read metadata from and write artificial data to a database called `iuod_artificial_data_generator`, but this parameter can be changed via the notebook widgets. 27 | 28 | When running in production on live data, this should be done by triggering the `run_notebooks` job for the `iuod_artificial_data_generator` project, as this will have the privileges to access the aggregated data. -------------------------------------------------------------------------------- /docs/artificial_data_user_notice.md: -------------------------------------------------------------------------------- 1 |

Notice For Artificial Data Users

2 | 3 |

What is artificial data?

4 | 5 |

Artificial data is an anonymous representation of real data

6 | 13 | 14 |

How is it generated?

15 | 16 | There are three stages involved in generating the artificial data (see the illustrative sketch after this list): 17 | 18 |     19 |
  1. The Metadata Scraper: extracts anonymised, high-level aggregates from real data at a national level. At this stage key identifiers (such as patient ID) are removed and small number suppression is applied in order to prevent reidentification at a later stage. 20 |
  2. The Data Generator: samples from the aggregates generated by the Metadata Scraper on a field-by-field basis and puts the sampled values together to create artificial records. 21 |
  3. Postprocessing: using the output of the Data Generator, dataset-specific tweaks are applied to make the data appear more realistic (such as swapping randomly generated birth and death dates to ensure sensible ordering). This also includes adding in randomly generated identifying fields (such as ‘patient’ ID) which were removed at the Metadata Scraper stage. 22 |
23 | 24 |
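To make these three stages concrete, here is a minimal, illustrative Python sketch of the flow. It is not the implementation in this repository (the real pipelines are PySpark Databricks notebooks under `scraper_pipelines`, `generator_stages` and `generator_pipelines`); the function names and the single-field example are hypothetical, although the suppression threshold and rounding of 100 mirror the default checks in `notebooks/admin/collab/meta_review/assertions.py`.

```python
# Illustrative sketch only: the real pipelines are PySpark notebooks in this repository.
# All names below are hypothetical; THRESHOLD/ROUNDING of 100 mirror meta_review/assertions.py.
import random
from collections import Counter
from typing import Dict, List

THRESHOLD = 100  # minimum (rounded) frequency allowed to appear in the metadata
ROUNDING = 100   # frequencies are rounded to the nearest multiple of this value


def scrape_metadata(records: List[dict], field: str) -> Dict[str, int]:
    """Stage 1: aggregate one field into value frequencies, applying small number suppression."""
    counts = Counter(record[field] for record in records)
    metadata = {}
    for value, frequency in counts.items():
        rounded = round(frequency / ROUNDING) * ROUNDING
        if rounded >= THRESHOLD:  # values with small rounded frequencies are suppressed entirely
            metadata[value] = rounded
    return metadata


def generate_records(metadata: Dict[str, int], field: str, n_records: int) -> List[dict]:
    """Stage 2: sample field values independently from the aggregate frequencies."""
    values = list(metadata.keys())
    weights = list(metadata.values())
    return [{field: random.choices(values, weights=weights)[0]} for _ in range(n_records)]


def postprocess(artificial_records: List[dict]) -> List[dict]:
    """Stage 3: add back randomly generated identifiers that were removed at stage 1."""
    for index, record in enumerate(artificial_records):
        record["PATIENT_ID"] = f"FAKE{index:08d}"  # purely synthetic, never derived from real IDs
    return artificial_records
```

Because only rounded, thresholded frequencies ever leave stage 1, the generator never sees record-level data, and the identifiers added in stage 3 have no relationship to real patients.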

What is it used for?

25 | 26 | The purpose of artificial data is twofold. 27 | 28 |

1. Artificial data enables faster onboarding for new data projects

29 | 30 | Users can begin work on a project in advance of access to real data being granted. There are multiple use cases for the artificial data. 31 | 35 | 36 |

2. Artificial data minimises the amount of personal data being processed

37 | 38 | The activities mentioned above can all be completed without handling personal data, improving the ability of NHS Digital to protect patient privacy by minimising access to sensitive data. 39 | 40 | 41 |

What are its limitations?

42 | 43 | Artificial data is not real data, and is not intended to represent something real or link to real records. Artificial records are not based on any specific records found in the original data, only on high-level, anonymised aggregates. As outlined above, it is intended to improve efficiency and protect patient data. 44 | 45 | It is crucial to note that artificial data is not synthetic data. 46 | 47 | Synthetic data is generated using sophisticated methods to create realistic records. Usually synthetic data aims to enable the building of statistical models, or the gaining of insights, that transfer onto real data. This is not possible with artificial data. The downside of synthetic data is that it is associated with non-negligible risks to patient privacy through reidentification. It is not possible to reidentify individuals using artificial data. 48 | 49 | 50 |

Additional Information

51 | 52 |

Support from senior leadership

53 | This work has undergone due process to fully assess any potential risks to patient privacy and has been approved by senior leadership within NHS Digital, including: the Senior Information Risk Owner (SIRO); the Caldicott Guardian; the Data Protection Officer (DPO); the Executive Director of Data and Analytics Services; and the Information Asset Owners (IAOs) for the datasets represented by the Artificial Data assets. A full Data Protection Impact Assessment (DPIA) has been completed and is available upon request. 54 | 55 | For further details, please get in touch via the mailbox linked below. 56 | 57 |

Contact

58 | For further details, please get in touch via: nhsdigital.artificialdata@nhs.net 59 | -------------------------------------------------------------------------------- /docs/build_docs/README.md: -------------------------------------------------------------------------------- 1 | 2 | # How do I add user documentation to the production DAE? 3 | There are 2 steps to building the documentation and adding it to DAE: 4 | 1. (Desktop Python) Run the python file `notebooks/user/build_docs/make_create_user_docs.py`: the full documentation exists in the `docs` folder in the top-level of the repo; this step takes this documentation and puts the contents into a file that can be run on DAE to make it readable by users. 5 | 1. (DAE Prod) Copy the notebook `notebooks/user/collab/create_user_docs.py` created in step 1 into DAE Prod and run it: this will add the documentation to a user-facing table in DAE Prod. 6 | -------------------------------------------------------------------------------- /docs/build_docs/create_user_docs_template.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from pyspark.sql import functions as F 3 | 4 | # COMMAND ---------- 5 | 6 | dbutils.widgets.text("db", "", "0.1 Project Database") 7 | 8 | # COMMAND ---------- 9 | 10 | # Config 11 | database_name = dbutils.widgets.get("db") 12 | table_name = "user_docs" 13 | table_path = f"{database_name}.{table_name}" 14 | 15 | # Check database exists 16 | if spark.sql(f"SHOW DATABASES LIKE '{database_name}'").first() is None: 17 | # Database not found - exit 18 | dbutils.notebook.exit({}) 19 | 20 | # Template variables replaced during build 21 | user_notice_file_name = "artificial_data_user_notice" 22 | 23 | # COMMAND ---------- 24 | 25 | user_notice_html = """{{artificial_data_user_notice.md}}""" 26 | 27 | # COMMAND ---------- 28 | 29 | # Create and upload the docs 30 | docs_data = [ 31 | [user_notice_file_name, user_notice_html] 32 | ] 33 | docs_schema = "file_name: string, content_html: string" 34 | docs_df = spark.createDataFrame(docs_data, docs_schema) 35 | ( 36 | docs_df.write 37 | .format("delta") 38 | .mode("overwrite") 39 | .saveAsTable(table_path) 40 | ) 41 | 42 | # Make sure users can select from the table but not overwrite 43 | if os.getenv("env", "ref") == "ref": 44 | owner = "data-managers" 45 | else: 46 | owner = "admin" 47 | 48 | spark.sql(f"ALTER TABLE {table_path} OWNER TO `{owner}`") 49 | 50 | # Check the uploads 51 | for file_name, content_html in docs_data: 52 | result_content_html = ( 53 | spark.table(table_path) 54 | .where(F.col("file_name") == file_name) 55 | .first() 56 | .content_html 57 | ) 58 | assert result_content_html == content_html 59 | -------------------------------------------------------------------------------- /docs/build_docs/make_create_user_docs.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | # Projects requiring docs 4 | projects = ["artificial_hes"] 5 | 6 | # Constant paths 7 | this_path = pathlib.Path(__file__) 8 | template_path = this_path.parent / "create_user_docs_template.py" 9 | docs_path = this_path.parent.parent 10 | projects_root = this_path.parent.parent.parent / "projects" 11 | 12 | def replace_template_variable(template: str, token_value: str, new_value: str) -> str: 13 | filled_template = template.replace("".join(["{{", token_value, "}}"]), new_value) 14 | return filled_template 15 | 16 | for project_name in projects: 17 | project_path = 
projects_root / project_name 18 | notebook_path = project_path / "schemas" / "create_user_docs.py" 19 | 20 | # Read the template content 21 | with template_path.open("r") as template_file: 22 | notebook_content = template_file.read() 23 | 24 | # Populate placeholders in the template 25 | for doc_file_path in docs_path.glob("**/*.md"): 26 | with doc_file_path.open("r") as doc_file: 27 | doc_file_content = doc_file.read() 28 | 29 | notebook_content = replace_template_variable(notebook_content, doc_file_path.name, doc_file_content) 30 | notebook_content = replace_template_variable(notebook_content, "database_name", project_name) 31 | 32 | # Write the full content 33 | with notebook_path.open("w+") as notebook_file: 34 | notebook_file.write(notebook_content) 35 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NHSDigital/artificial-data-generator/47a406cad27125f9499a1af10ba2afac3df0c554/environment.yml -------------------------------------------------------------------------------- /notebooks/admin/collab/artificial_hes/1_scrape_metadata.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md # Scrape HES metadata 3 | # MAGIC Run the `cp_run_notebooks_artificial_hes_meta` job with user-entered config. 4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ## Setup 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %run ../job_utils 13 | 14 | # COMMAND ---------- 15 | 16 | # MAGIC %md 17 | # MAGIC ## Main 18 | 19 | # COMMAND ---------- 20 | 21 | USERNAME = "" 22 | assert USERNAME, "Please provide a username for getting the API token" 23 | 24 | token = dbutils.notebook.run(f"/Users/{USERNAME}/token", 60) 25 | configure_databricks_cli(token) 26 | 27 | # COMMAND ---------- 28 | 29 | run_job_async("cp_run_notebooks_artificial_hes_meta", sleep_seconds=0, max_retries=1) 30 | 31 | # COMMAND ---------- 32 | 33 | -------------------------------------------------------------------------------- /notebooks/admin/collab/artificial_hes/3_release_metadata.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md # Release hes metadata 3 | # MAGIC Run the `cp_run_notebooks_iuod_artificial_data_admin` to move HES metadata into the `iuod_artificial_data_generator` database. 
4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ## Setup 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %run ../job_utils 13 | 14 | # COMMAND ---------- 15 | 16 | # MAGIC %md ## Config 17 | 18 | # COMMAND ---------- 19 | 20 | USERNAME = "" 21 | assert USERNAME, "Please provide a username for getting the API token" 22 | 23 | job_params = { 24 | "safety_dt": "", 25 | "source_database_name": "artificial_hes_meta", 26 | "table_name": "artificial_hes_meta", 27 | "reviewer_name": "", 28 | "reviewer_email": "", 29 | "reviewed_date": "", 30 | "approver_name": "", 31 | "approver_email": "", 32 | "approved_date": "", 33 | } 34 | 35 | for param_name, param_value in job_params.items(): 36 | assert param_value, f"Please provide a value for parameter '{job_params}'" 37 | 38 | # COMMAND ---------- 39 | 40 | # MAGIC %md 41 | # MAGIC ## Main 42 | 43 | # COMMAND ---------- 44 | 45 | token = dbutils.notebook.run(f"/Users/{USERNAME}/token", 60) 46 | configure_databricks_cli(token) 47 | 48 | # COMMAND ---------- 49 | 50 | run_job_async("cp_run_notebooks_iuod_artificial_data_admin", **job_params, sleep_seconds=30, max_retries=20) 51 | 52 | # COMMAND ---------- 53 | 54 | -------------------------------------------------------------------------------- /notebooks/admin/collab/artificial_hes/4_generate_artificial_data.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md # Generate artificial HES data 3 | # MAGIC Run the`cp_run_notebooks_iuod_artificial_data_generator` job for the `artificial_hes` data asset. 4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ## Setup 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %run ../job_utils 13 | 14 | # COMMAND ---------- 15 | 16 | import json 17 | import pyspark.sql.functions as F 18 | 19 | # COMMAND ---------- 20 | 21 | # MAGIC %md ## Config 22 | 23 | # COMMAND ---------- 24 | 25 | USERNAME = "" 26 | assert USERNAME, "Please provide a username for getting the API token" 27 | 28 | parameters_json = json.dumps(dict(n_patients=1000000)) 29 | artificial_dataset = "artificial_hes" 30 | 31 | # COMMAND ---------- 32 | 33 | # MAGIC %md 34 | # MAGIC ## Main 35 | 36 | # COMMAND ---------- 37 | 38 | token = dbutils.notebook.run(f"/Users/{USERNAME}/token", 60) 39 | configure_databricks_cli(token) 40 | 41 | # COMMAND ---------- 42 | 43 | run_job_async( 44 | "cp_run_notebooks_iuod_artificial_data_generator", 45 | artificial_dataset=artificial_dataset, 46 | parameters_json=parameters_json, 47 | sleep_seconds=0, 48 | max_retries=1 49 | ) -------------------------------------------------------------------------------- /notebooks/admin/collab/artificial_hes/5_qa_artificial_data.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # TODO: expand the notebook below with checks implemented in Ref 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run /Users/admin/releases/code-promotion/iuod_artificial_data_generator/4332+20220930151507.gite6c853018/notebooks/dataset_definitions/hes/hes_schemas 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run /Users/admin/releases/code-promotion/iuod_artificial_data_generator/4332+20220930151507.gite6c853018/notebooks/common/spark_helpers 11 | 12 | # COMMAND ---------- 13 | 14 | import pyspark.sql.functions as F 15 | import pyspark.sql.types as T 16 | 17 | 18 | database_name = "artificial_hes" 19 | table_name_pattern = r"artificial_hes_(op|apc|ae)_\d{4}" 20 | 21 | # COMMAND ---------- 22 | 23 
| tables = ( 24 | spark.sql(f"show tables from {database_name}") 25 | .filter(F.col("tableName").rlike(table_name_pattern)) 26 | .select("tableName") 27 | .collect() 28 | ) 29 | 30 | for table in tables: 31 | artificial_df = spark.table(f"{database_name}.{table.tableName}") 32 | 33 | # Check non-zero rows 34 | row_count = artificial_df.count() 35 | assert row_count > 100, "Found fewer rows than expected" 36 | print(table.tableName, f"{row_count:,}") 37 | 38 | # Compare schemas 39 | if "ae" in table.tableName: 40 | hes_schema = get_hes_ae_schema() 41 | elif "op" in table.tableName: 42 | hes_schema = get_hes_op_schema() 43 | elif "apc" in table.tableName: 44 | hes_schema = get_hes_apc_schema() 45 | else: 46 | assert False, f"Table {table.table_name} does not belong to HES dataset (apc, op, ae)" 47 | 48 | # Sort for comparison between fields 49 | hes_schema = T.StructType(sorted(hes_schema.fields, key=lambda f: f.name)) 50 | artificial_schema = artificial_df.select(*sorted(artificial_df.columns)).schema 51 | 52 | # TODO: 53 | # Will the artificial_hes pipeline fail when it tries to update tha AE tables? 54 | # Do we need to explicitly drop columns via init_schemas? 55 | assert datatypes_equal(artificial_schema, hes_schema), "Schema not equal to expected schema" 56 | 57 | # if len(hes_schema) != len(artificial_schema): 58 | # # TODO: handle this properly? 59 | # print( 60 | # [f for f in artificial_schema.fields if f.name not in hes_schema.fieldNames()], 61 | # [f for f in hes_schema.fields if f.name not in artificial_schema.fieldNames()] 62 | # ) 63 | # else: 64 | 65 | # TODO: Check values? 66 | 67 | # COMMAND ---------- 68 | 69 | -------------------------------------------------------------------------------- /notebooks/admin/collab/artificial_hes/6_release_artificial_data.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | -------------------------------------------------------------------------------- /notebooks/admin/collab/example_dotenv.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Example dotenv.py 4 | 5 | # COMMAND ---------- 6 | 7 | API_PREFIX = "my_api_host_prefix" 8 | API_SUFFIX = "my_api_host_suffix.com/api/2.0" 9 | -------------------------------------------------------------------------------- /notebooks/admin/collab/job_utils.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./dotenv.py 3 | 4 | # COMMAND ---------- 5 | 6 | import json 7 | import os 8 | import time 9 | from typing import Dict, Tuple, Any 10 | 11 | # From dotenv.py 12 | API_PREFIX = API_PREFIX 13 | API_SUFFIX = API_SUFFIX 14 | 15 | DATABRICKS_ENV = os.getenv("env", "dev") 16 | HTTPS = "https://" 17 | PENDING_STATE = "PENDING" 18 | RUNNING_STATE = "RUNNING" 19 | 20 | 21 | def get_databricks_api_host() -> str: 22 | """ Get URL for the Databricks API 2.0 for the current environment. 23 | 24 | NOTE: This is specific to NHS Digital infrastructure. 
25 | 26 | Returns 27 | ------- 28 | str 29 | API base URL for the current environment 30 | 31 | """ 32 | host_segments = [f"https://{API_PREFIX}"] 33 | 34 | if DATABRICKS_ENV != "prod": 35 | host_segments.append(DATABRICKS_ENV) 36 | 37 | host_segments.append(API_SUFFIX) 38 | host = ".".join(host_segments) 39 | 40 | return host 41 | 42 | 43 | def configure_databricks_cli(token: str) -> None: 44 | """ Configure the Databricks CLI with the hostname and token. 45 | 46 | Parameters 47 | ---------- 48 | token: str 49 | User generated access token 50 | 51 | """ 52 | os.environ["DATABRICKS_HOST"] = get_databricks_api_host() 53 | os.environ["DATABRICKS_TOKEN"] = token 54 | 55 | 56 | def get_job_name_id_map() -> Dict[str, str]: 57 | result = os.popen("databricks jobs list").read() 58 | result = result.split("\n") 59 | result = filter(None, result) 60 | job_name_id_map = dict(map(lambda s: s.split()[::-1], result)) 61 | return job_name_id_map 62 | 63 | 64 | def get_job_run_state(run_id: str) -> Dict[str, str]: 65 | get_run_command = f"databricks runs get --run-id {run_id}" 66 | result_json = os.popen(get_run_command).read() 67 | result = json.loads(result_json) 68 | state = result.get("state", {}) 69 | return state 70 | 71 | 72 | def wait_for_job_result(run_id: str, sleep_seconds: int=5, max_retries: int=60) -> str: 73 | print(f"Checking state of run with run_id '{run_id}'") 74 | 75 | for i in range(max_retries): 76 | state = get_job_run_state(run_id) 77 | life_cycle_state = state.get("life_cycle_state", PENDING_STATE) 78 | result_state = state.get("result_state", PENDING_STATE) 79 | 80 | if life_cycle_state in (PENDING_STATE, RUNNING_STATE): 81 | print(f"Run with run_id '{run_id}' incomplete with life_cycle_state '{life_cycle_state}', checking again in {sleep_seconds} seconds") 82 | time.sleep(sleep_seconds) 83 | else: 84 | print(f"Run with run_id '{run_id}' completed with result_state '{result_state}'") 85 | return result_state 86 | 87 | else: 88 | print(f"Run with run_id '{run_id}' incomplete with life_cycle_state '{life_cycle_state}' but maximum number of retries ({max_retries}) has been exceeded") 89 | return result_state 90 | 91 | 92 | def run_job_with_notebook_params(job_id: str, notebook_params: str) -> str: 93 | # Check if there is already a job running 94 | print(f"Checking state of last run for job with job_id '{job_id}'") 95 | list_runs_command = f"databricks runs list --job-id {job_id}" 96 | result = os.popen(list_runs_command).read() 97 | 98 | if result == "\n": 99 | # No runs yet 100 | last_run_id = None 101 | last_run_life_cycle_state = None 102 | else: 103 | last_run_id = result.split("\n")[0].split()[0] 104 | last_run_state = get_job_run_state(last_run_id) 105 | last_run_life_cycle_state = last_run_state.get("life_cycle_state", PENDING_STATE) 106 | print(f"Last run for job with job_id '{job_id}' has life_cycle_state '{last_run_life_cycle_state}': run_id '{last_run_id}'") 107 | 108 | if last_run_life_cycle_state in (PENDING_STATE, RUNNING_STATE): 109 | # Job still pending/running 110 | return {"run_id": last_run_id} 111 | else: 112 | # Start a new job 113 | run_job_command = f"databricks jobs run-now --job-id {job_id} --notebook-params '{notebook_params}'" 114 | result_json = os.popen(run_job_command).read() 115 | result = json.loads(result_json) 116 | print(f"Started new job run for job_id '{job_id}' with run_id '{result['run_id']}'") 117 | return result 118 | 119 | 120 | def run_job(job_name: str, **job_kwargs) -> str: 121 | job_name_id_map = get_job_name_id_map() 122 | job_id = 
job_name_id_map[job_name] 123 | print(f"Found job_id for job '{job_name}': '{job_id}'") 124 | 125 | # Notebook arguments 126 | notebook_params = json.dumps(job_kwargs) 127 | 128 | return run_job_with_notebook_params(job_id, notebook_params) 129 | 130 | 131 | def run_job_async(*run_args: Tuple[Any], sleep_seconds: int=30, max_retries: int=20, **run_kwargs) -> Dict[str, str]: 132 | """ Run a job and keep checking status until it finishes or until a maximum number of retries 133 | 134 | Parameters 135 | ---------- 136 | *run_args: Any 137 | Passed through to run_job as arguments 138 | sleep_seconds: int (default=30) 139 | Number of seconds to sleep after each retry 140 | max_retries: int (default=20) 141 | Number of times to retry checking if the run has finished 142 | **run_kwargs: Any 143 | Passed through to run_job as key-word arguments 144 | 145 | Returns 146 | ------- 147 | Dict[str, str] 148 | Result of the job run with the result state 149 | 150 | """ 151 | result = run_job(*run_args, **run_kwargs) 152 | run_id = result["run_id"] 153 | result_state = wait_for_job_result(run_id, sleep_seconds=sleep_seconds, max_retries=max_retries) 154 | return {**result, "result_state": result_state} 155 | -------------------------------------------------------------------------------- /notebooks/admin/collab/meta_review/assertions.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from pyspark.sql import DataFrame, Column, functions as F 3 | from nhsdccl.util import DFCheckpoint 4 | 5 | 6 | class StatisticalDisclosureControlException(Exception): 7 | def __init__(self, message="Statistical disclosure control check failed!"): 8 | return super().__init__(message) 9 | 10 | 11 | def modulo(col_name: str, x: int) -> Column: 12 | """ Spark SQL modulo expression as a PySpark Column. 13 | """ 14 | return F.expr(f"{col_name} % {x}") 15 | 16 | 17 | def assert_statistical_disclosure_controls( 18 | meta_df: DataFrame, 19 | freq_col: Column, 20 | threshold: int = 100, 21 | rounding: int = 100 22 | ) -> None: 23 | """ 24 | Check that the statistical disclosure controls are correctly applied to the metadata. 25 | By default the threshold and rounding on frequencies is set to 100 to avoid accidental disclosures. 
26 | 27 | Parameters 28 | ---------- 29 | meta_df: DataFrame 30 | Metadata 31 | freq_col: Column 32 | Column of `meta_df` holding the frequencies 33 | threshold: int (default = 100) 34 | Minimum frequency for a value to appear in the metadata 35 | rounding: int (default = 100) 36 | Frequencies should be rounded to this nearest value 37 | 38 | Returns 39 | ------- 40 | None 41 | 42 | Raises 43 | ------ 44 | StatisticalDisclosureControlException 45 | Raised if any frequencies do not obey the disclosure controls 46 | 47 | """ 48 | # Filters for rows that break the SDCs 49 | frequency_below_threshold = F.col("FREQUENCY") < F.lit(threshold) 50 | frequency_not_divisible_by_rounding = modulo("FREQUENCY", rounding) != F.lit(0) 51 | 52 | # Find rows that break the SDCs 53 | sdc_check_df = ( 54 | meta_df 55 | .withColumn("FREQUENCY", freq_col) 56 | .filter(frequency_below_threshold | frequency_not_divisible_by_rounding) 57 | ) 58 | 59 | # Check that there are no rows breaking the SDCs 60 | with DFCheckpoint(sdc_check_df) as df_cp: 61 | try: 62 | assert df_cp.df.first() is None 63 | 64 | except AssertionError as e: 65 | print("Frequency check failed for the following rows: ") 66 | df_cp.df.show() 67 | raise StatisticalDisclosureControlException 68 | 69 | -------------------------------------------------------------------------------- /notebooks/admin/collab/meta_review/assertions_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./assertions 3 | 4 | # COMMAND ---------- 5 | 6 | from unittest import TestCase 7 | 8 | test_data = [ 9 | ["table1", "FIELD1", "bad_value1", 4], 10 | ["table1", "FIELD1", "bad_value2", 7], 11 | ["table1", "FIELD1", "good_value", 15], 12 | ] 13 | test_schema = "TABLE_NAME: string, FIELD_NAME: string, VALUE: string, FREQUENCY: integer" 14 | test_df = spark.createDataFrame(test_data, test_schema) 15 | test_value_col = F.col("VALUE") 16 | test_freq_col = F.col("FREQUENCY") 17 | 18 | with TestCase().assertRaises(StatisticalDisclosureControlException): 19 | assert_statistical_disclosure_controls(test_df.where(F.col("VALUE") == "bad_value1"), test_freq_col, threshold=5, rounding=5) 20 | 21 | with TestCase().assertRaises(StatisticalDisclosureControlException): 22 | assert_statistical_disclosure_controls(test_df.where(F.col("VALUE") == "bad_value2"), test_freq_col, threshold=5, rounding=5) 23 | 24 | assert_statistical_disclosure_controls(test_df.where(F.col("VALUE") == "good_value"), test_freq_col, threshold=5, rounding=5) -------------------------------------------------------------------------------- /notebooks/admin/collab/meta_review/field_definitions.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from pyspark.sql import functions as F 3 | 4 | 5 | table_name_col = F.col("TABLE_NAME") 6 | field_name_col = F.col("FIELD_NAME") 7 | 8 | is_categorical_col = F.col("VALUE_TYPE") == "CATEGORICAL" 9 | is_discrete_col = F.col("VALUE_TYPE") == "DISCRETE" 10 | is_date_col = F.col("VALUE_TYPE") == "DATE" 11 | is_continuous_col = F.col("VALUE_TYPE") == "CONTINUOUS" 12 | 13 | is_demographic_categorical_col = F.col("VALUE_TYPE") == "DEMOGRAPHIC_CATEGORICAL" 14 | is_demographic_date_col = F.col("VALUE_TYPE") == "DEMOGRAPHIC_DATE" 15 | 16 | is_relationship_col = F.col("VALUE_TYPE") == "RELATIONSHIP" -------------------------------------------------------------------------------- /notebooks/admin/collab/meta_review/regex_patterns.py: 
-------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from typing import Union 3 | from pyspark.sql import DataFrame, Column, functions as F 4 | 5 | 6 | PHONE_NUMBER_PATTERN = r"^(((\+44\s?\d{4}|\(?0\d{4}\)?)\s?\d{3}\s?\d{3})|((\+44\s?\d{3}|\(?0\d{3}\)?)\s?\d{3}\s?\d{4})|((\+44\s?\d{2}|\(?0\d{2}\)?)\s?\d{4}\s?\d{4}))(\s?\#(\d{4}|\d{3}))?$" 7 | EMAIL_ADDRESS_PATTERN = r"([A-Za-z0-9]+[.-_])*[A-Za-z0-9]+@[A-Za-z0-9-]+(\.[A-Z|a-z]{2,})+" 8 | POSTCODE_PATTERN = r"([Gg][Ii][Rr] 0[Aa]{2})|((([A-Za-z][0-9]{1,2})|(([A-Za-z][A-Ha-hJ-Yj-y][0-9]{1,2})|(([A-Za-z][0-9][A-Za-z])|([A-Za-z][A-Ha-hJ-Yj-y][0-9][A-Za-z]?))))\s?[0-9][A-Za-z]{2})" 9 | GT_6_ALPHA_CHARS_PATTERN = r"^([A-Za-z]{6,})$" 10 | 11 | 12 | def filter_matches(df: DataFrame, col: Union[Column, str], pattern: str) -> DataFrame: 13 | """ 14 | Filter rows of a DataFrame based on a regex pattern match within a specified column. 15 | 16 | Parameters 17 | ---------- 18 | df : pyspark.sql.DataFrame 19 | DataFrame to filter 20 | col : Union[pyspark.sql.Column, str] 21 | Column to look for matches within 22 | pattern : str 23 | Pattern to be matched 24 | 25 | Returns 26 | ------- 27 | pyspark.sql.DataFrame 28 | Filtered DataFrame based on pattern matches 29 | 30 | """ 31 | extracted_value = F.regexp_extract(col, pattern, 0) 32 | is_match = F.length(extracted_value) > 0 33 | match_df = df.filter(is_match) 34 | return match_df 35 | -------------------------------------------------------------------------------- /notebooks/admin/collab/meta_review/regex_patterns_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./regex_patterns 3 | 4 | # COMMAND ---------- 5 | 6 | test_df = spark.createDataFrame( 7 | [ 8 | ["hello@gmail.com"], 9 | ["hello.world@nhs.net"], 10 | ["some other string"] 11 | ], 12 | "VALUE: string" 13 | ) 14 | assert filter_matches(test_df, "VALUE", EMAIL_ADDRESS_PATTERN).count() == 2 15 | 16 | # COMMAND ---------- 17 | 18 | test_df = spark.createDataFrame( 19 | [ 20 | ["07123456789"], 21 | ["07 123 456 789"], 22 | ["+447123456789"] 23 | ], 24 | "VALUE: string" 25 | ) 26 | assert filter_matches(test_df, "VALUE", PHONE_NUMBER_PATTERN).count() == 2 # Need to strip spaces 27 | 28 | # COMMAND ---------- 29 | 30 | test_df = spark.createDataFrame( 31 | [ 32 | ["LS1 4AP"], 33 | ["LS14AP"], 34 | ["Hello"] 35 | ], 36 | "VALUE: string" 37 | ) 38 | assert filter_matches(test_df, "VALUE", POSTCODE_PATTERN).count() == 2 39 | 40 | # COMMAND ---------- 41 | 42 | test_df = spark.createDataFrame( 43 | [ 44 | ["ABCdef"], 45 | ["ghiJK"] 46 | ], 47 | "VALUE: string" 48 | ) 49 | assert filter_matches(test_df, "VALUE", GT_6_ALPHA_CHARS_PATTERN).count() == 1 -------------------------------------------------------------------------------- /notebooks/user/collab/artificial_data_user_notice.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql import functions as F 2 | 3 | # Config 4 | database_name = "artificial_hes" 5 | table_name = "user_docs" 6 | table_path = f"{database_name}.{table_name}" 7 | 8 | # Get the content 9 | content_html = ( 10 | spark.table(table_path) 11 | .where(F.col("file_name") == "artificial_data_user_notice") 12 | .first() 13 | .content_html 14 | ) 15 | 16 | displayHTML(content_html) 17 | -------------------------------------------------------------------------------- /projects/artificial_hes/cp_config.py: 
-------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Code Promotion Job Configuration 4 | # MAGIC ### This notebook sets some parameters for Databricks jobs wrapping the top-level entry point notebooks. 5 | # MAGIC Only simple setting of variables is allowed in this notebook. 6 | 7 | # COMMAND ---------- 8 | 9 | # DBTITLE 1,Global settings 10 | # MAGIC %md 11 | # MAGIC spark_version can be either "6.6.x-scala2.11" (spark 2) or "9.1.x-scala2.12" (spark 3). 12 | # MAGIC This applies to all jobs created 13 | 14 | # COMMAND ---------- 15 | 16 | spark_version = "6.6.x-scala2.11" 17 | 18 | # COMMAND ---------- 19 | 20 | # DBTITLE 1,init_schemas 21 | # MAGIC %md 22 | # MAGIC Available parameters: 23 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 24 | 25 | # COMMAND ---------- 26 | 27 | init_schemas = { 28 | "retain_cluster": False 29 | } 30 | 31 | # COMMAND ---------- 32 | 33 | # DBTITLE 1,run_notebooks 34 | # MAGIC %md 35 | # MAGIC Available parameters: 36 | # MAGIC - **concurrency**: Integer between 1 and 10. Allows you to run multiple *run_notebooks* jobs at the same time. 37 | # MAGIC - **extra_parameters**: Dictionary(String, String) that maps *parameter names* to *default values*. These parameters are added to the list of parameters for the job. 38 | # MAGIC - **schedule**: A quartz cron syntax on when to run - see https://www.freeformatter.com/cron-expression-generator-quartz.html for cron syntax. 39 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 40 | 41 | # COMMAND ---------- 42 | 43 | run_notebooks = { 44 | "concurrency": 1, 45 | "extra_parameters": { 46 | "included_tables": "", 47 | "excluded_tables": "", 48 | }, 49 | "retain_cluster": False 50 | } 51 | 52 | # COMMAND ---------- 53 | 54 | # DBTITLE 1,tool_config 55 | # MAGIC %md 56 | # MAGIC Available parameters: 57 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 58 | 59 | # COMMAND ---------- 60 | 61 | # Example: 62 | # tool_config = { 63 | # "retain_cluster": True 64 | # } 65 | #tool_config = { 66 | #} -------------------------------------------------------------------------------- /projects/artificial_hes/init_schemas.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md ## Artificial HES: Init Schemas 3 | # MAGIC This is the init_schemas notebook for the artificial_hes Code Promotion project 4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ## Setup 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %md 13 | # MAGIC ### Code Promotion Project Paths 14 | 15 | # COMMAND ---------- 16 | 17 | # MAGIC %run ./notebooks/code_promotion_paths 18 | 19 | # COMMAND ---------- 20 | 21 | # MAGIC %md 22 | # MAGIC ### Imports 23 | 24 | # COMMAND ---------- 25 | 26 | # MAGIC %run ./notebooks/common/widget_utils 27 | 28 | # COMMAND ---------- 29 | 30 | import pprint 31 | 32 | # COMMAND ---------- 33 | 34 | # MAGIC %md 35 | # MAGIC ### Widgets 36 | 37 | # COMMAND ---------- 38 | 39 | if check_databricks(): 40 | dbutils.widgets.text("notebook_root", "artificial_hes/dev", "0.1 Notebook Root") 41 | dbutils.widgets.text("db", "artificial_hes", "0.2 Project Database") 42 | dbutils.widgets.text("iuod_artificial_data_generator", "iuod_artificial_data_generator", "0.3 ADG Project Database") 43 
| else: 44 | # Only make widgets in databricks 45 | pass 46 | 47 | # COMMAND ---------- 48 | 49 | # MAGIC %md 50 | # MAGIC ## Main 51 | # MAGIC Run schema notebooks 52 | 53 | # COMMAND ---------- 54 | 55 | notebook_root = get_required_argument("notebook_root") 56 | *_, project_name, project_version = notebook_root.split("/") 57 | 58 | cp_project_params = {} 59 | cp_project_params["project_name"] = project_name 60 | cp_project_params["project_version"] = project_version 61 | cp_project_params["adg_project_path"] = get_adg_project_path(project_version) 62 | 63 | # COMMAND ---------- 64 | 65 | print(f"Running with context:") 66 | pprint.pprint(cp_project_params) 67 | 68 | # COMMAND ---------- 69 | 70 | database_name = get_required_argument("db") 71 | adg_schemas_path = cp_project_params["adg_project_path"] / "schemas" 72 | artificial_hes_schema_coercion_path = str(adg_schemas_path / "uplifts" / "artificial_hes_schema_coercion") 73 | 74 | notebook_params = [ 75 | { 76 | "path": "./schemas/create_user_docs", 77 | "timeout_seconds": 0, 78 | "arguments": {"database_name": database_name} 79 | }, 80 | { 81 | "path": artificial_hes_schema_coercion_path, 82 | "timeout_seconds": 0, 83 | "arguments": {"database_name": database_name} 84 | }, 85 | ] 86 | 87 | # COMMAND ---------- 88 | 89 | for params in notebook_params: 90 | dbutils.notebook.run(**params) -------------------------------------------------------------------------------- /projects/artificial_hes/notebooks/code_promotion_paths.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | import pathlib 3 | 4 | ADG_PROJECT_VERSION = "4497+20221019131315.git06db5e6f6" 5 | ADG_PROJECT_NAME = "iuod_artificial_data_generator" 6 | CODE_PROMOTION_RELEASES_PATH = pathlib.Path("/Users/admin/releases/code-promotion") 7 | 8 | 9 | def get_adg_env(project_version: str) -> str: 10 | if project_version == "dev": 11 | return "dev" 12 | elif project_version == "staging": 13 | return "staging" 14 | else: 15 | return "release" 16 | 17 | 18 | def get_adg_project_path(project_version: str) -> pathlib.Path: 19 | adg_env = get_adg_env(project_version) 20 | 21 | if adg_env == "release": 22 | adg_project_path = CODE_PROMOTION_RELEASES_PATH / ADG_PROJECT_NAME / ADG_PROJECT_VERSION 23 | elif adg_env == "staging": 24 | root_path = pathlib.Path("/staging") 25 | adg_project_path = root_path / ADG_PROJECT_NAME 26 | elif adg_env == "dev": 27 | root_path = pathlib.Path("../") 28 | adg_project_path = root_path / ADG_PROJECT_NAME 29 | 30 | return adg_project_path -------------------------------------------------------------------------------- /projects/artificial_hes/notebooks/common/coerce_schema.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from pyspark.sql import DataFrame, types as T, functions as F 3 | 4 | 5 | def coerce_schema(df: DataFrame, target_schema: T.StructType) -> DataFrame: 6 | """Transform the fields of a dataframe to a target schema (on a field-by-field basis). 7 | 8 | Fields that exist in the target schema but which are not present in the original dataframe 9 | will be added as nulls if the field is nullable. 
Otherwise a ValueError is raised 10 | 11 | Fields that exist in the original dataframe but not the target schema will be excluded 12 | from the result 13 | 14 | Args: 15 | df (pyspark.sql.DataFrame): Dataframe to transform 16 | target_schema (pyspark.sql.types.StructType): Intended schema of the result 17 | 18 | Raises: 19 | ValueError: If a non-nullable field exists in the target schema which is not in 20 | the original dataframe 21 | 22 | Returns: 23 | DataFrame: transformed dataframe 24 | """ 25 | selected_fields = [] 26 | 27 | for field in target_schema.fields: 28 | if field.name in df.columns: 29 | value_col = F.col(field.name) 30 | else: 31 | if field.nullable: 32 | value_col = F.lit(None) 33 | else: 34 | raise ValueError(f"Field with name `{field.name}` is not nullable in target schema, but no column exists on provided dataframe!") 35 | 36 | selected_fields.append(value_col.cast(field.dataType).alias(field.name)) 37 | 38 | return df.select(*selected_fields) -------------------------------------------------------------------------------- /projects/artificial_hes/notebooks/common/common_exports.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./table_helpers 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ./widget_utils 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ./spark_helpers 11 | 12 | # COMMAND ---------- 13 | 14 | # MAGIC %run ./coerce_schema -------------------------------------------------------------------------------- /projects/artificial_hes/notebooks/common/widget_utils.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | import os 3 | 4 | # COMMAND ---------- 5 | 6 | def check_databricks() -> bool: 7 | """Check whether the code is running in the databricks environment 8 | 9 | Returns: 10 | bool: indicates whether the current environment is databricks 11 | """ 12 | return os.environ.get("VIRTUAL_ENV") == "/databricks/python3" 13 | 14 | 15 | # COMMAND ---------- 16 | 17 | def get_required_argument(name: str) -> str: 18 | """Get the value of a dbutils widget, raising an exception if 19 | the value is not provided. 20 | 21 | Args: 22 | name (str): name of the widget to get the value of 23 | 24 | Raises: 25 | NotImplementedError: running this function outside databricks 26 | ValueError: if a value is not provided for the widget 27 | 28 | Returns: 29 | str: value of the widget 30 | """ 31 | if check_databricks(): 32 | value = dbutils.widgets.get(name) 33 | else: 34 | # TODO 35 | raise NotImplementedError("No way to provide arguments outside of Databricks") 36 | 37 | try: 38 | assert value 39 | return value 40 | except AssertionError: 41 | raise ValueError(f"Argument '{name}' must be provided") 42 | -------------------------------------------------------------------------------- /projects/artificial_hes/run_notebooks.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md ## Artificial HES: Run Notebooks 3 | # MAGIC 4 | # MAGIC This notebook copies tables containing artificial HES data from `iuod_artificial_data_generator` into a target database (by default this is the CP project database, namely `artificial_hes`). 
5 | 6 | # COMMAND ---------- 7 | 8 | # MAGIC %md 9 | # MAGIC ## Setup 10 | 11 | # COMMAND ---------- 12 | 13 | # MAGIC %md 14 | # MAGIC ### Imports 15 | 16 | # COMMAND ---------- 17 | 18 | # MAGIC %run ./notebooks/common/widget_utils 19 | 20 | # COMMAND ---------- 21 | 22 | # MAGIC %run ./notebooks/common/table_helpers 23 | 24 | # COMMAND ---------- 25 | 26 | from pprint import pprint 27 | from functools import reduce 28 | 29 | from pyspark.sql import DataFrame, functions as F 30 | 31 | # COMMAND ---------- 32 | 33 | # MAGIC %md 34 | # MAGIC ### Widgets 35 | 36 | # COMMAND ---------- 37 | 38 | if check_databricks(): 39 | dbutils.widgets.removeAll() 40 | dbutils.widgets.text("notebook_root", "artificial_hes/dev", "0.1 Notebook Root") 41 | dbutils.widgets.text("db", "artificial_hes", "0.2 Project Database") 42 | dbutils.widgets.text("iuod_artificial_data_generator", "iuod_artificial_data_generator", "0.3 ADG Project Database") 43 | dbutils.widgets.text("included_tables", "", "1.1 Tables to copy (comma separated)") 44 | dbutils.widgets.text("excluded_tables", "", "1.2 Tables to exclude from copy (comma separated)") 45 | else: 46 | # Only make widgets in databricks 47 | pass 48 | 49 | # COMMAND ---------- 50 | 51 | # MAGIC %md 52 | # MAGIC ## Main 53 | 54 | # COMMAND ---------- 55 | 56 | source_database = get_required_argument("iuod_artificial_data_generator") 57 | target_database = get_required_argument("db") 58 | 59 | # COMMAND ---------- 60 | 61 | # Parse the inclusions / exclusions 62 | 63 | included_tables = list(filter(lambda x: x, dbutils.widgets.get("included_tables").split(","))) 64 | excluded_tables = list(filter(lambda x: x, dbutils.widgets.get("excluded_tables").split(","))) 65 | 66 | tables_to_clone_df = ( 67 | spark.sql(rf"SHOW TABLES IN {source_database}") 68 | .filter(F.col("tableName").rlike(r"artificial_hes_(apc|ae|op)_\d{4}")) 69 | .filter(~F.col("tableName").isin(excluded_tables)) 70 | ) 71 | 72 | if included_tables: 73 | tables_to_clone_df = tables_to_clone_df.filter(F.col("tableName").isin(included_tables)) 74 | 75 | tables_to_clone = [row.tableName for row in tables_to_clone_df.collect()] 76 | 77 | print("The following tables will be cloned: ") 78 | pprint(tables_to_clone) 79 | 80 | print("\nThe following tables will be excluded: ") 81 | pprint(excluded_tables) 82 | 83 | # COMMAND ---------- 84 | 85 | # NOTE: delta history logic commented out. We do want this but we will need to figure out how to ensure 86 | # table permissions work across different environments. Currently the artificial data table gets created 87 | # and the owner is set to admin or data-managers (on ref), but this then breaks the CP jobs permissions 88 | # because it no longer owns the table so can't call 'DESCRIBE HISTORY'. Ticket added to the backlog to 89 | # revisit this. 90 | 91 | # # Record the history of delta tables for users - only table owners can see the history otherwise! 
92 | # history_cols = [ 93 | # F.col("version").alias("VERSION"), 94 | # F.col("timestamp").alias("TIMESTAMP"), 95 | # ] 96 | # history_dfs = [] 97 | 98 | for table_name in tables_to_clone: 99 | print(f"Cloning table `{table_name}` from database `{source_database}` to database `{target_database}`") 100 | source_df = spark.table(f"{source_database}.{table_name}") 101 | create_table(spark, source_df, target_database, table_name, mode="overwrite", overwriteSchema="true") 102 | 103 | # history_dfs.append( 104 | # spark.sql(f"DESCRIBE HISTORY {target_database}.{table_name}") 105 | # .select( 106 | # F.lit(table_name).alias("TABLE_NAME"), 107 | # *history_cols 108 | # ) 109 | # ) 110 | 111 | # history_df = reduce(DataFrame.union, history_dfs) 112 | # create_table(spark, history_df, target_database, "delta_history", mode="overwrite", overwriteSchema="true") 113 | 114 | # COMMAND ---------- 115 | 116 | -------------------------------------------------------------------------------- /projects/artificial_hes/run_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | -------------------------------------------------------------------------------- /projects/artificial_hes/tests/placeholder.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | -------------------------------------------------------------------------------- /projects/artificial_hes_meta/cp_config.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Code Promotion Job Configuration 4 | # MAGIC ### This notebook sets some parameters for Databricks jobs wrapping the top-level entry point notebooks. 5 | # MAGIC Only simple setting of variables is allowed in this notebook. 6 | 7 | # COMMAND ---------- 8 | 9 | import os 10 | 11 | # COMMAND ---------- 12 | 13 | # DBTITLE 1,Global settings 14 | # MAGIC %md 15 | # MAGIC spark_version can be either "6.6.x-scala2.11" (spark 2) or "9.1.x-scala2.12" (spark 3). 16 | # MAGIC This applies to all jobs created 17 | 18 | # COMMAND ---------- 19 | 20 | spark_version = "6.6.x-scala2.11" 21 | 22 | # COMMAND ---------- 23 | 24 | # DBTITLE 1,init_schemas 25 | # MAGIC %md 26 | # MAGIC Available parameters: 27 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 28 | 29 | # COMMAND ---------- 30 | 31 | # Example: 32 | init_schemas = { 33 | "retain_cluster": False 34 | } 35 | 36 | # COMMAND ---------- 37 | 38 | # DBTITLE 1,run_notebooks 39 | # MAGIC %md 40 | # MAGIC Available parameters: 41 | # MAGIC - **concurrency**: Integer between 1 and 10. Allows you to run multiple *run_notebooks* jobs at the same time. 42 | # MAGIC - **extra_parameters**: Dictionary(String, String) that maps *parameter names* to *default values*. These parameters are added to the list of parameters for the job. 43 | # MAGIC - **schedule**: A quartz cron syntax on when to run - see https://www.freeformatter.com/cron-expression-generator-quartz.html for cron syntax. 
44 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 45 | 46 | # COMMAND ---------- 47 | 48 | # Example: 49 | run_notebooks = { 50 | "concurrency": 1, 51 | "extra_parameters": {}, 52 | "retain_cluster": False, 53 | } 54 | 55 | if os.getenv("env", "ref") == "prod": 56 | run_notebooks = { 57 | **run_notebooks, 58 | "num_workers": 8, 59 | "instance_type": "i3.2xlarge", 60 | } 61 | 62 | # COMMAND ---------- 63 | 64 | # DBTITLE 1,tool_config 65 | # MAGIC %md 66 | # MAGIC Available parameters: 67 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 68 | 69 | # COMMAND ---------- 70 | 71 | # Example: 72 | # tool_config = { 73 | # "retain_cluster": True 74 | # } 75 | #tool_config = { 76 | #} -------------------------------------------------------------------------------- /projects/artificial_hes_meta/init_schemas.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Artificial HES Meta: Init Schemas 4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ## Setup 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %md 13 | # MAGIC ### Code Promotion Project Paths 14 | 15 | # COMMAND ---------- 16 | 17 | # MAGIC %run ./notebooks/code_promotion_paths 18 | 19 | # COMMAND ---------- 20 | 21 | # MAGIC %md 22 | # MAGIC ### Imports 23 | 24 | # COMMAND ---------- 25 | 26 | # MAGIC %run ./notebooks/common/widget_utils 27 | 28 | # COMMAND ---------- 29 | 30 | import pprint 31 | 32 | # COMMAND ---------- 33 | 34 | # MAGIC %md 35 | # MAGIC ### Widgets 36 | 37 | # COMMAND ---------- 38 | 39 | if check_databricks(): 40 | dbutils.widgets.removeAll() 41 | dbutils.widgets.text("notebook_root", "artificial_hes_meta/dev", "0.1 Notebook Root") 42 | dbutils.widgets.text("db", "artificial_hes_meta", "0.2 Project Database") 43 | else: 44 | # Only make widgets in databricks 45 | pass 46 | 47 | # COMMAND ---------- 48 | 49 | # MAGIC %md ## Main 50 | 51 | # COMMAND ---------- 52 | 53 | notebook_root = get_required_argument("notebook_root") 54 | *_, project_name, project_version = notebook_root.split("/") 55 | 56 | cp_project_params = {} 57 | cp_project_params["project_name"] = project_name 58 | cp_project_params["project_version"] = project_version 59 | cp_project_params["adg_project_path"] = get_adg_project_path(project_version) 60 | 61 | # COMMAND ---------- 62 | 63 | print(f"Running with context:") 64 | pprint.pprint(cp_project_params) 65 | 66 | # COMMAND ---------- 67 | 68 | database_name = get_required_argument("db") 69 | adg_schemas_path = cp_project_params["adg_project_path"] / "schemas" 70 | open_data_metadata_uplift_path = str(adg_schemas_path / "uplifts" / "open_data_metadata_uplift") 71 | 72 | notebook_params = [ 73 | { 74 | "path": open_data_metadata_uplift_path, 75 | "timeout_seconds": 0, 76 | "arguments": { 77 | "database_name": database_name, 78 | "table_name": "artificial_hes_meta", 79 | } 80 | }, 81 | ] 82 | 83 | # COMMAND ---------- 84 | 85 | for params in notebook_params: 86 | dbutils.notebook.run(**params) 87 | 88 | # COMMAND ---------- 89 | 90 | -------------------------------------------------------------------------------- /projects/artificial_hes_meta/notebooks/code_promotion_paths.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | import pathlib 3 | 4 | ADG_PROJECT_VERSION = "4497+20221019131315.git06db5e6f6" 5 | 
ADG_PROJECT_NAME = "iuod_artificial_data_generator" 6 | CODE_PROMOTION_RELEASES_PATH = pathlib.Path("/Users/admin/releases/code-promotion") 7 | 8 | 9 | def get_adg_env(project_version: str) -> str: 10 | if project_version == "dev": 11 | return "dev" 12 | elif project_version == "staging": 13 | return "staging" 14 | else: 15 | return "release" 16 | 17 | 18 | def get_adg_project_path(project_version: str) -> pathlib.Path: 19 | adg_env = get_adg_env(project_version) 20 | 21 | if adg_env == "release": 22 | adg_project_path = CODE_PROMOTION_RELEASES_PATH / ADG_PROJECT_NAME / ADG_PROJECT_VERSION 23 | elif adg_env == "staging": 24 | root_path = pathlib.Path("/staging") 25 | adg_project_path = root_path / ADG_PROJECT_NAME 26 | elif adg_env == "dev": 27 | root_path = pathlib.Path("../") 28 | adg_project_path = root_path / ADG_PROJECT_NAME 29 | 30 | return adg_project_path -------------------------------------------------------------------------------- /projects/artificial_hes_meta/notebooks/common/coerce_schema.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from pyspark.sql import DataFrame, types as T, functions as F 3 | 4 | 5 | def coerce_schema(df: DataFrame, target_schema: T.StructType) -> DataFrame: 6 | """Transform the fields of a dataframe to a target schema (on a field-by-field basis). 7 | 8 | Fields that exist in the target schema but which are not present in the original dataframe 9 | will be added as nulls if the field is nullable. Otherwise a ValueError is raised 10 | 11 | Fields that exist in the original dataframe but not the target schema will be excluded 12 | from the result 13 | 14 | Args: 15 | df (pyspark.sql.DataFrame): Dataframe to transform 16 | target_schema (pyspark.sql.types.StructType): Intended schema of the result 17 | 18 | Raises: 19 | ValueError: If a non-nullable field exists in the target schema which is not in 20 | the original dataframe 21 | 22 | Returns: 23 | DataFrame: transformed dataframe 24 | """ 25 | selected_fields = [] 26 | 27 | for field in target_schema.fields: 28 | if field.name in df.columns: 29 | value_col = F.col(field.name) 30 | else: 31 | if field.nullable: 32 | value_col = F.lit(None) 33 | else: 34 | raise ValueError(f"Field with name `{field.name}` is not nullable in target schema, but no column exists on provided dataframe!") 35 | 36 | selected_fields.append(value_col.cast(field.dataType).alias(field.name)) 37 | 38 | return df.select(*selected_fields) -------------------------------------------------------------------------------- /projects/artificial_hes_meta/notebooks/common/common_exports.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./table_helpers 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ./widget_utils 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ./spark_helpers 11 | 12 | # COMMAND ---------- 13 | 14 | # MAGIC %run ./coerce_schema -------------------------------------------------------------------------------- /projects/artificial_hes_meta/notebooks/common/widget_utils.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | import os 3 | 4 | # COMMAND ---------- 5 | 6 | def check_databricks() -> bool: 7 | """Check whether the code is running in the databricks environment 8 | 9 | Returns: 10 | bool: indicates whether the current environment is databricks 11 | """ 12 | return 
os.environ.get("VIRTUAL_ENV") == "/databricks/python3" 13 | 14 | 15 | # COMMAND ---------- 16 | 17 | def get_required_argument(name: str) -> str: 18 | """Get the value of a dbutils widget, raising an exception if 19 | the value is not provided. 20 | 21 | Args: 22 | name (str): name of the widget to get the value of 23 | 24 | Raises: 25 | NotImplementedError: running this function outside databricks 26 | ValueError: if a value is not provided for the widget 27 | 28 | Returns: 29 | str: value of the widget 30 | """ 31 | if check_databricks(): 32 | value = dbutils.widgets.get(name) 33 | else: 34 | # TODO 35 | raise NotImplementedError("No way to provide arguments outside of Databricks") 36 | 37 | try: 38 | assert value 39 | return value 40 | except AssertionError: 41 | raise ValueError(f"Argument '{name}' must be provided") 42 | -------------------------------------------------------------------------------- /projects/artificial_hes_meta/run_notebooks.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Artificial HES Meta: Run Notebooks 4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ## Setup 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %md 13 | # MAGIC ### Code Promotion Project Paths 14 | 15 | # COMMAND ---------- 16 | 17 | # MAGIC %run ./notebooks/code_promotion_paths 18 | 19 | # COMMAND ---------- 20 | 21 | # MAGIC %md 22 | # MAGIC ### Imports 23 | 24 | # COMMAND ---------- 25 | 26 | # MAGIC %run ./notebooks/common/widget_utils 27 | 28 | # COMMAND ---------- 29 | 30 | import pprint 31 | import json 32 | 33 | # COMMAND ---------- 34 | 35 | # MAGIC %md 36 | # MAGIC ### Widgets 37 | 38 | # COMMAND ---------- 39 | 40 | if check_databricks(): 41 | dbutils.widgets.removeAll() 42 | dbutils.widgets.text("notebook_root", "artificial_hes_meta/dev", "0.1 Notebook Root") 43 | dbutils.widgets.text("db", "artificial_hes_meta", "0.2 Project Database") 44 | dbutils.widgets.text("hes", "hes", "0.3 HES Database") 45 | else: 46 | # Only make widgets in databricks 47 | pass 48 | 49 | # COMMAND ---------- 50 | 51 | # MAGIC %md ## Main 52 | 53 | # COMMAND ---------- 54 | 55 | notebook_root = get_required_argument("notebook_root") 56 | *_, project_name, project_version = notebook_root.split("/") 57 | 58 | cp_project_params = {} 59 | cp_project_params["project_name"] = project_name 60 | cp_project_params["project_version"] = project_version 61 | cp_project_params["adg_project_path"] = get_adg_project_path(project_version) 62 | 63 | # COMMAND ---------- 64 | 65 | print(f"Running with context:") 66 | pprint.pprint(cp_project_params) 67 | 68 | # COMMAND ---------- 69 | 70 | params = { 71 | "path": str(cp_project_params["adg_project_path"] / "notebooks" / "scraper_pipelines" / "hes" / "driver"), 72 | "timeout_seconds": 0, 73 | "arguments": { 74 | "hes_database": get_required_argument("hes"), 75 | "meta_database": get_required_argument("db"), 76 | } 77 | } 78 | 79 | dbutils.notebook.run(**params) 80 | 81 | # COMMAND ---------- 82 | 83 | -------------------------------------------------------------------------------- /projects/artificial_hes_meta/run_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | -------------------------------------------------------------------------------- /projects/artificial_hes_meta/schemas/placeholder.py: -------------------------------------------------------------------------------- 1 | # Databricks 
notebook source 2 | -------------------------------------------------------------------------------- /projects/artificial_hes_meta/tests/placeholder.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_admin/cp_config.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Code Promotion Job Configuration 4 | # MAGIC ### This notebook sets some parameters for Databricks jobs wrapping the top-level entry point notebooks. 5 | # MAGIC Only simple setting of variables is allowed in this notebook. 6 | 7 | # COMMAND ---------- 8 | 9 | # DBTITLE 1,Global settings 10 | # MAGIC %md 11 | # MAGIC spark_version can be either "6.6.x-scala2.11" (spark 2) or "9.1.x-scala2.12" (spark 3). 12 | # MAGIC This applies to all jobs created 13 | 14 | # COMMAND ---------- 15 | 16 | spark_version = "6.6.x-scala2.11" 17 | #spark_version = "9.1.x-scala2.12" 18 | 19 | 20 | # COMMAND ---------- 21 | 22 | # DBTITLE 1,init_schemas 23 | # MAGIC %md 24 | # MAGIC Available parameters: 25 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 26 | 27 | # COMMAND ---------- 28 | 29 | init_schemas = { 30 | "retain_cluster": False 31 | } 32 | 33 | # COMMAND ---------- 34 | 35 | # DBTITLE 1,run_notebooks 36 | # MAGIC %md 37 | # MAGIC Available parameters: 38 | # MAGIC - **concurrency**: Integer between 1 and 10. Allows you to run multiple *run_notebooks* jobs at the same time. 39 | # MAGIC - **extra_parameters**: Dictionary(String, String) that maps *parameter names* to *default values*. These parameters are added to the list of parameters for the job. 40 | # MAGIC - **schedule**: A quartz cron syntax on when to run - see https://www.freeformatter.com/cron-expression-generator-quartz.html for cron syntax. 
41 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 42 | 43 | # COMMAND ---------- 44 | 45 | # Example: 46 | run_notebooks = { 47 | "concurrency": 1, 48 | "extra_parameters": { 49 | "safety_dt": "", 50 | "source_database_name": "", 51 | "table_name": "", 52 | "reviewer_name": "", 53 | "reviewer_email": "", 54 | "reviewed_date": "", 55 | "approver_name": "", 56 | "approver_email": "", 57 | "approved_date": "", 58 | }, 59 | "retain_cluster": False 60 | } 61 | 62 | # COMMAND ---------- 63 | 64 | # DBTITLE 1,tool_config 65 | # MAGIC %md 66 | # MAGIC Available parameters: 67 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 68 | 69 | # COMMAND ---------- 70 | 71 | # Example: 72 | # tool_config = { 73 | # "retain_cluster": True 74 | # } 75 | tool_config = { 76 | } -------------------------------------------------------------------------------- /projects/iuod_artificial_data_admin/init_schemas.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_admin/notebooks/code_promotion_paths.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | import pathlib 3 | 4 | # COMMAND ---------- 5 | 6 | def get_adg_project_path(project_env: str, project_version: str) -> pathlib.Path: 7 | adg_project_name = "iuod_artificial_data_generator" 8 | 9 | if project_env == "release": 10 | root_path = pathlib.Path("/Users/admin/releases/code-promotion") 11 | adg_project_path = root_path / adg_project_name / project_version 12 | elif project_env == "staging": 13 | root_path = pathlib.Path("/staging") 14 | adg_project_path = root_path / adg_project_name 15 | elif project_env == "dev": 16 | root_path = pathlib.Path("../") 17 | adg_project_path = root_path / adg_project_name 18 | 19 | return adg_project_path -------------------------------------------------------------------------------- /projects/iuod_artificial_data_admin/notebooks/code_promotion_versions.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | ADG_PROJECT_VERSION = "4497+20221019131315.git06db5e6f6" 3 | print(f"iuod_artificial_data_generator: {ADG_PROJECT_VERSION}") -------------------------------------------------------------------------------- /projects/iuod_artificial_data_admin/notebooks/widget_utils.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | import os 3 | 4 | # COMMAND ---------- 5 | 6 | def check_databricks() -> bool: 7 | return os.environ.get("VIRTUAL_ENV") == "/databricks/python3" 8 | 9 | 10 | # COMMAND ---------- 11 | 12 | def get_required_argument(name: str) -> str: 13 | if check_databricks(): 14 | value = dbutils.widgets.get(name) 15 | else: 16 | # TODO 17 | raise NotImplementedError("No way to provide arguments outside of Databricks") 18 | 19 | try: 20 | assert value 21 | return value 22 | except AssertionError: 23 | raise ValueError(f"Argument '{name}' must be provided") 24 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_admin/run_notebooks.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 
3 | # MAGIC ## Run Notebooks 4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ## Setup 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %md 13 | # MAGIC ### Code Promotion Version 14 | 15 | # COMMAND ---------- 16 | 17 | # MAGIC %run ./notebooks/code_promotion_versions 18 | 19 | # COMMAND ---------- 20 | 21 | # MAGIC %md 22 | # MAGIC ### Imports 23 | 24 | # COMMAND ---------- 25 | 26 | # MAGIC %run ./notebooks/code_promotion_paths 27 | 28 | # COMMAND ---------- 29 | 30 | # MAGIC %run ./notebooks/widget_utils 31 | 32 | # COMMAND ---------- 33 | 34 | import json 35 | from typing import Dict 36 | 37 | # COMMAND ---------- 38 | 39 | # MAGIC %md 40 | # MAGIC ### Widgets 41 | 42 | # COMMAND ---------- 43 | 44 | if check_databricks(): 45 | # TODO: update these! 46 | dbutils.widgets.removeAll() 47 | dbutils.widgets.text("notebook_root", "iuod_artificial_data_admin/dev", "0.1 Notebook Root") 48 | dbutils.widgets.text("db", "iuod_artificial_data_admin", "0.2 Project Database") 49 | dbutils.widgets.text("iuod_artificial_data_generator", "iuod_artificial_data_generator", "0.3 ADG Project Database") 50 | dbutils.widgets.text("artificial_hes_meta", "artificial_hes_meta", "0.4 HES Meta Project Database") 51 | dbutils.widgets.text("safety_dt", "", "1.1 Safety Timestamp (UTC)") 52 | dbutils.widgets.text("source_database_name", "", "2.1 Source Database Name") 53 | dbutils.widgets.text("table_name", "", "2.2 Source / Target Table Name") 54 | dbutils.widgets.text("reviewer_name", "", "3.1 Reviewer Name") 55 | dbutils.widgets.text("reviewer_email", "", "3.2 Reviewer Email") 56 | dbutils.widgets.text("reviewed_date", "", "3.3 Reviewed Date") 57 | dbutils.widgets.text("approver_name", "", "3.4 Approver Name") 58 | dbutils.widgets.text("approver_email", "", "3.5 Approver Email") 59 | dbutils.widgets.text("approved_date", "", "3.6 Approved Date") 60 | else: 61 | # Only make widgets in databricks 62 | pass 63 | 64 | # COMMAND ---------- 65 | 66 | # MAGIC %md 67 | # MAGIC ## Functions 68 | 69 | # COMMAND ---------- 70 | 71 | def get_adg_env(project_version: str) -> str: 72 | if project_version == "dev": 73 | return "dev" 74 | elif project_version == "staging": 75 | return "staging" 76 | else: 77 | return "release" 78 | 79 | # COMMAND ---------- 80 | 81 | def get_driver_arguments() -> Dict[str, str]: 82 | source_database_name = get_required_argument(get_required_argument("source_database_name")) # Must be one of the databases in the project permissions scope 83 | return { 84 | "safety_dt": get_required_argument("safety_dt"), 85 | "config_json": json.dumps(dict( 86 | source_database_name = source_database_name, 87 | target_database_name = get_required_argument("iuod_artificial_data_generator"), 88 | table_name = get_required_argument("table_name"), 89 | reviewer_name = get_required_argument("reviewer_name"), 90 | reviewer_email = get_required_argument("reviewer_email"), 91 | reviewed_date = get_required_argument("reviewed_date"), 92 | approver_name = get_required_argument("approver_name"), 93 | approver_email = get_required_argument("approver_email"), 94 | approved_date = get_required_argument("approved_date"), 95 | releases_database_name = get_required_argument("db"), 96 | )) 97 | } 98 | 99 | # COMMAND ---------- 100 | 101 | # MAGIC %md 102 | # MAGIC ## Main 103 | 104 | # COMMAND ---------- 105 | 106 | # TODO: wrap in try catch 107 | driver_args = get_driver_arguments() # All args must be provided to proceed! 
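# Example (a hedged sketch for the "wrap in try catch" TODO above, not part of the current
# flow): exit early with the same JSON response shape this notebook already returns if a
# required argument is missing (get_required_argument raises ValueError in that case).
# try:
#     driver_args = get_driver_arguments()
# except ValueError as missing_arg_error:
#     dbutils.notebook.exit(json.dumps({"status": "failed", "message": str(missing_arg_error)}))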
108 | 109 | # COMMAND ---------- 110 | 111 | notebook_response = {} 112 | 113 | # COMMAND ---------- 114 | 115 | notebook_root = get_required_argument("notebook_root") 116 | *_, project_name, project_version = notebook_root.split("/") 117 | 118 | adg_env = get_adg_env(project_version) 119 | adg_project_path = get_adg_project_path(adg_env, ADG_PROJECT_VERSION) 120 | 121 | # COMMAND ---------- 122 | 123 | driver_params = { 124 | "path": str(adg_project_path / "notebooks" / "admin" / "driver"), 125 | "timeout_seconds": 0, 126 | "arguments": get_driver_arguments() 127 | } 128 | driver_response_json = dbutils.notebook.run(**driver_params) 129 | 130 | notebook_response["status"] = "success" 131 | notebook_response["message"] = "" 132 | notebook_response["children"] = [{**driver_params, "response": json.loads(driver_response_json)}] 133 | 134 | # COMMAND ---------- 135 | 136 | dbutils.notebook.exit(json.dumps(notebook_response)) 137 | 138 | # COMMAND ---------- 139 | 140 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_admin/run_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_admin/schemas/placeholder.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_admin/tests/placeholder.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/cp_config.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Code Promotion Job Configuration 4 | # MAGIC ### This notebook sets some parameters for Databricks jobs wrapping the top-level entry point notebooks. 5 | # MAGIC Only simple setting of variables is allowed in this notebook. 6 | 7 | # COMMAND ---------- 8 | 9 | import os 10 | 11 | # COMMAND ---------- 12 | 13 | # DBTITLE 1,Global settings 14 | # MAGIC %md 15 | # MAGIC spark_version can be either "6.6.x-scala2.11" (spark 2) or "9.1.x-scala2.12" (spark 3). 16 | # MAGIC This applies to all jobs created 17 | 18 | # COMMAND ---------- 19 | 20 | spark_version = "6.6.x-scala2.11" 21 | 22 | # COMMAND ---------- 23 | 24 | # DBTITLE 1,init_schemas 25 | # MAGIC %md 26 | # MAGIC Available parameters: 27 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 28 | 29 | # COMMAND ---------- 30 | 31 | init_schemas = { 32 | "retain_cluster": False 33 | } 34 | 35 | # COMMAND ---------- 36 | 37 | # DBTITLE 1,run_notebooks 38 | # MAGIC %md 39 | # MAGIC Available parameters: 40 | # MAGIC - **concurrency**: Integer between 1 and 10. Allows you to run multiple *run_notebooks* jobs at the same time. 41 | # MAGIC - **extra_parameters**: Dictionary(String, String) that maps *parameter names* to *default values*. These parameters are added to the list of parameters for the job. 42 | # MAGIC - **schedule**: A quartz cron syntax on when to run - see https://www.freeformatter.com/cron-expression-generator-quartz.html for cron syntax. 
43 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 44 | 45 | # COMMAND ---------- 46 | 47 | run_notebooks = { 48 | "concurrency": 1, 49 | "extra_parameters": { 50 | "artificial_dataset": "", 51 | "parameters_json": "{}", 52 | }, 53 | "retain_cluster": False, 54 | } 55 | 56 | if os.getenv("env", "ref") == "prod": 57 | run_notebooks = { 58 | **run_notebooks, 59 | "num_workers": 8, 60 | "instance_type": "i3.2xlarge", 61 | } 62 | 63 | # COMMAND ---------- 64 | 65 | # DBTITLE 1,tool_config 66 | # MAGIC %md 67 | # MAGIC Available parameters: 68 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained 69 | 70 | # COMMAND ---------- 71 | 72 | # Example: 73 | # tool_config = { 74 | # "retain_cluster": True 75 | # } 76 | #tool_config = { 77 | #} -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/init_schemas.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Artificial Data Generator: Init Schemas 4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ## Setup 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %md 13 | # MAGIC ### Imports 14 | 15 | # COMMAND ---------- 16 | 17 | # MAGIC %run ./notebooks/common/widget_utils 18 | 19 | # COMMAND ---------- 20 | 21 | # MAGIC %md 22 | # MAGIC ### Widgets 23 | 24 | # COMMAND ---------- 25 | 26 | if check_databricks(): 27 | dbutils.widgets.removeAll() 28 | dbutils.widgets.text("db", "iuod_artificial_data_generator", "0.1 Database") 29 | dbutils.widgets.text("notebook_root", "iuod_artificial_data_generator/dev", "0.2 Notebook Root") 30 | else: 31 | # Only make widgets in databricks 32 | pass 33 | 34 | # COMMAND ---------- 35 | 36 | # MAGIC %md 37 | # MAGIC ## Main 38 | 39 | # COMMAND ---------- 40 | 41 | database_name = get_required_argument("db") 42 | create_meta_notebook_path = "./schemas/create_meta_table" 43 | open_data_metadata_uplift_path = "./schemas/uplifts/open_data_metadata_uplift" 44 | 45 | notebook_params = [ 46 | { 47 | "path": open_data_metadata_uplift_path, 48 | "timeout_seconds": 0, 49 | "arguments": { 50 | "database_name": database_name, 51 | "table_name": "artificial_hes_meta", 52 | } 53 | }, 54 | { 55 | "path": create_meta_notebook_path, 56 | "timeout_seconds": 300, 57 | "arguments": { 58 | "database_name": database_name, 59 | "table_name": "artificial_hes_meta", 60 | } 61 | }, 62 | ] 63 | 64 | # COMMAND ---------- 65 | 66 | for params in notebook_params: 67 | dbutils.notebook.run(**params) 68 | 69 | # COMMAND ---------- 70 | 71 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/admin/driver.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC ## Admin Driver 4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ## Setup 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %md 13 | # MAGIC ### Imports 14 | 15 | # COMMAND ---------- 16 | 17 | # MAGIC %run ../common/table_helpers 18 | 19 | # COMMAND ---------- 20 | 21 | # MAGIC %run ../common/widget_utils 22 | 23 | # COMMAND ---------- 24 | 25 | import datetime 26 | from datetime import datetime 27 | import json 28 | from typing import Iterable, Dict 29 | from pyspark.sql import DataFrame, functions as F, Column 30 | 31 | # COMMAND 
---------- 32 | 33 | # MAGIC %md 34 | # MAGIC ### Widgets 35 | 36 | # COMMAND ---------- 37 | 38 | if check_databricks(): 39 | dbutils.widgets.removeAll() 40 | dbutils.widgets.text("safety_dt", "", "0 Safety Timestamp (UTC)") 41 | dbutils.widgets.text("config_json", "", "1 Config JSON") 42 | else: 43 | pass 44 | 45 | # COMMAND ---------- 46 | 47 | # MAGIC %md 48 | # MAGIC ## Functions 49 | 50 | # COMMAND ---------- 51 | 52 | def get_release_data(config: Dict[str, str]) -> DataFrame: 53 | reviewed_date = datetime.strptime(config["reviewed_date"], "%Y-%m-%d") 54 | approved_date = datetime.strptime(config["approved_date"], "%Y-%m-%d") 55 | release_cols = [ 56 | F.current_timestamp() .alias("RELEASE_TIMESTAMP_UTC"), 57 | F.lit(config["source_database_name"]).alias("SOURCE_DATABASE_NAME"), 58 | F.lit(config["target_database_name"]).alias("TARGET_DATABASE_NAME"), 59 | F.lit(config["table_name"]) .alias("TABLE_NAME"), 60 | F.lit(config["reviewer_name"]) .alias("REVIEWER_NAME"), 61 | F.lit(config["reviewer_email"]) .alias("REVIEWER_EMAIL"), 62 | F.lit(reviewed_date).cast("date") .alias("REVIEWED_DATE"), 63 | F.lit(config["approver_name"]) .alias("APPROVER_NAME"), 64 | F.lit(config["approver_email"]) .alias("APPROVER_EMAIL"), 65 | F.lit(approved_date).cast("date") .alias("APPROVED_DATE"), 66 | ] 67 | 68 | return spark.range(1).select(*release_cols) 69 | 70 | # COMMAND ---------- 71 | 72 | def check_safe() -> bool: 73 | safety_window_seconds = 600 74 | current_dt = datetime.now() 75 | safety_dt = datetime.strptime(get_required_argument("safety_dt"), "%Y-%m-%d %H:%M") 76 | safe = abs(current_dt - safety_dt).total_seconds() < safety_window_seconds 77 | return safe 78 | 79 | # COMMAND ---------- 80 | 81 | # MAGIC %md 82 | # MAGIC ## Main 83 | 84 | # COMMAND ---------- 85 | 86 | config = json.loads(get_required_argument("config_json")) 87 | source_table_path = f"{config['source_database_name']}.{config['table_name']}" 88 | target_table_path = f"{config['target_database_name']}.{config['table_name']}" 89 | 90 | # COMMAND ---------- 91 | 92 | notebook_response = {} 93 | 94 | if check_safe(): 95 | print("Metadata marked as safe") 96 | 97 | print(f"Creating release log entry in table '{config['releases_database_name']}.releases'") 98 | 99 | release_df = get_release_data(config) # Must be populated before release 100 | create_table(spark, release_df, config['releases_database_name'], "releases", format="delta", mode="append") 101 | 102 | print(f"Publishing from '{source_table_path}' to '{target_table_path}'") 103 | 104 | meta_df = spark.table(source_table_path) 105 | 106 | # Insert the rows into the target table 107 | # Table should already have been created by the owner) 108 | insert_into(spark, meta_df, *target_table_path.split("."), overwrite=True) 109 | spark.sql(f"OPTIMIZE {target_table_path}") 110 | 111 | print("Done!") 112 | 113 | notebook_response["status"] = "success" 114 | notebook_response["message"] = "" 115 | else: 116 | notebook_response["status"] = "failed" 117 | notebook_response["message"] = "Failed safety check on user-entered timestamp" 118 | 119 | # COMMAND ---------- 120 | 121 | dbutils.notebook.exit(json.dumps(notebook_response)) 122 | 123 | # COMMAND ---------- 124 | 125 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/common/coerce_schema.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from pyspark.sql import DataFrame, types as 
T, functions as F 3 | 4 | 5 | def coerce_schema(df: DataFrame, target_schema: T.StructType) -> DataFrame: 6 | """Transform the fields of a dataframe to a target schema (on a field-by-field basis). 7 | 8 | Fields that exist in the target schema but which are not present in the original dataframe 9 | will be added as nulls if the field is nullable. Otherwise a ValueError is raised 10 | 11 | Fields that exist in the original dataframe but not the target schema will be excluded 12 | from the result 13 | 14 | Args: 15 | df (pyspark.sql.DataFrame): Dataframe to transform 16 | target_schema (pyspark.sql.types.StructType): Intended schema of the result 17 | 18 | Raises: 19 | ValueError: If a non-nullable field exists in the target schema which is not in 20 | the original dataframe 21 | 22 | Returns: 23 | DataFrame: transformed dataframe 24 | """ 25 | selected_fields = [] 26 | 27 | for field in target_schema.fields: 28 | if field.name in df.columns: 29 | value_col = F.col(field.name) 30 | else: 31 | if field.nullable: 32 | value_col = F.lit(None) 33 | else: 34 | raise ValueError(f"Field with name `{field.name}` is not nullable in target schema, but no column exists on provided dataframe!") 35 | 36 | selected_fields.append(value_col.cast(field.dataType).alias(field.name)) 37 | 38 | return df.select(*selected_fields) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/common/common_exports.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./table_helpers 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ./widget_utils 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ./spark_helpers 11 | 12 | # COMMAND ---------- 13 | 14 | # MAGIC %run ./coerce_schema -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/common/widget_utils.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | import os 3 | 4 | # COMMAND ---------- 5 | 6 | def check_databricks() -> bool: 7 | """Check whether the code is running in the databricks environment 8 | 9 | Returns: 10 | bool: indicates whether the current environment is databricks 11 | """ 12 | return os.environ.get("VIRTUAL_ENV") == "/databricks/python3" 13 | 14 | 15 | # COMMAND ---------- 16 | 17 | def get_required_argument(name: str) -> str: 18 | """Get the value of a dbutils widget, raising an exception if 19 | the value is not provided. 
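Example (illustrative usage; the "db" widget is one of those created by this
    project's entry-point notebooks):
        database_name = get_required_argument("db")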
20 | 21 | Args: 22 | name (str): name of the widget to get the value of 23 | 24 | Raises: 25 | NotImplementedError: running this function outside databricks 26 | ValueError: if a value is not provided for the widget 27 | 28 | Returns: 29 | str: value of the widget 30 | """ 31 | if check_databricks(): 32 | value = dbutils.widgets.get(name) 33 | else: 34 | # TODO 35 | raise NotImplementedError("No way to provide arguments outside of Databricks") 36 | 37 | try: 38 | assert value 39 | return value 40 | except AssertionError: 41 | raise ValueError(f"Argument '{name}' must be provided") 42 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_ae_tables.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from typing import List 3 | 4 | 5 | HES_AE_PRIMARY_KEY = "AEKEY" 6 | HES_AE_TABLE_NAMES = [ 7 | "hes_ae_2021", 8 | "hes_ae_1920", 9 | "hes_ae_1819", 10 | "hes_ae_1718", 11 | "hes_ae_1617", 12 | "hes_ae_1516", 13 | "hes_ae_1415", 14 | "hes_ae_1314", 15 | "hes_ae_1213", 16 | "hes_ae_1112", 17 | "hes_ae_1011", 18 | "hes_ae_0910", 19 | "hes_ae_0809", 20 | "hes_ae_0708", 21 | ] 22 | 23 | 24 | def _get_hes_ae_key_fields(patient_table: Table) -> List[Field]: 25 | return [ 26 | Field(HES_AE_PRIMARY_KEY, str, primary=True), 27 | Field(HES_PATIENT_KEY, str, foreign=patient_table[HES_PATIENT_KEY]), 28 | ] 29 | 30 | 31 | def get_hes_ae_tables(patient_table: Table) -> List[Table]: 32 | hes_ae_key_fields = _get_hes_ae_key_fields(patient_table) 33 | hes_ae_tables = [] 34 | 35 | for table_name in HES_AE_TABLE_NAMES: 36 | hes_ae_tables.append(Table(table_name, *hes_ae_key_fields)) 37 | 38 | return hes_ae_tables 39 | 40 | 41 | HES_AE_TABLES = get_hes_ae_tables(HES_PATIENTS_SYMBOLIC_TABLE) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_apc_tables.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from typing import List 3 | 4 | 5 | HES_APC_PRIMARY_KEY = "EPIKEY" 6 | HES_APC_TABLE_NAMES = [ 7 | "hes_apc_2021", 8 | "hes_apc_1920", 9 | "hes_apc_1819", 10 | "hes_apc_1718", 11 | "hes_apc_1617", 12 | "hes_apc_1516", 13 | "hes_apc_1415", 14 | "hes_apc_1314", 15 | "hes_apc_1213", 16 | "hes_apc_1112", 17 | "hes_apc_1011", 18 | "hes_apc_0910", 19 | "hes_apc_0809", 20 | "hes_apc_0708", 21 | "hes_apc_0607", 22 | "hes_apc_0506", 23 | "hes_apc_0405", 24 | "hes_apc_0304", 25 | "hes_apc_0203", 26 | "hes_apc_0102", 27 | ] 28 | 29 | 30 | def _get_hes_apc_key_fields(patient_table: Table) -> List[Field]: 31 | return [ 32 | Field(HES_APC_PRIMARY_KEY, str, primary=True), 33 | Field(HES_PATIENT_KEY, str, foreign=patient_table[HES_PATIENT_KEY]), 34 | ] 35 | 36 | 37 | def get_hes_apc_tables(patient_table: Table) -> List[Table]: 38 | hes_apc_key_fields = _get_hes_apc_key_fields(patient_table) 39 | hes_apc_tables = [] 40 | 41 | for table_name in HES_APC_TABLE_NAMES: 42 | hes_apc_tables.append(Table(table_name, *hes_apc_key_fields)) 43 | 44 | return hes_apc_tables 45 | 46 | 47 | HES_APC_TABLES = get_hes_apc_tables(HES_PATIENTS_SYMBOLIC_TABLE) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_op_tables.py: 
-------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from typing import List 3 | 4 | 5 | HES_OP_PRIMARY_KEY = "ATTENDKEY" 6 | HES_OP_TABLE_NAMES = [ 7 | "hes_op_2021", 8 | "hes_op_1920", 9 | "hes_op_1819", 10 | "hes_op_1718", 11 | "hes_op_1617", 12 | "hes_op_1516", 13 | "hes_op_1415", 14 | "hes_op_1314", 15 | "hes_op_1213", 16 | "hes_op_1112", 17 | "hes_op_1011", 18 | "hes_op_0910", 19 | "hes_op_0809", 20 | "hes_op_0708", 21 | "hes_op_0607", 22 | "hes_op_0506", 23 | "hes_op_0405", 24 | "hes_op_0304", 25 | ] 26 | 27 | 28 | def _get_hes_op_key_fields(patient_table: Table) -> List[Field]: 29 | return [ 30 | Field(HES_OP_PRIMARY_KEY, str, primary=True), 31 | Field(HES_PATIENT_KEY, str, foreign=patient_table[HES_PATIENT_KEY]), 32 | ] 33 | 34 | 35 | def get_hes_op_tables(patient_table: Table) -> List[Table]: 36 | hes_op_key_fields = _get_hes_op_key_fields(patient_table) 37 | hes_op_tables = [] 38 | 39 | for table_name in HES_OP_TABLE_NAMES: 40 | hes_op_tables.append(Table(table_name, *hes_op_key_fields)) 41 | 42 | return hes_op_tables 43 | 44 | 45 | HES_OP_TABLES = get_hes_op_tables(HES_PATIENTS_SYMBOLIC_TABLE) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_patient_table.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../relational 3 | 4 | # COMMAND ---------- 5 | 6 | # from dsp.common.relational import Table, Field 7 | 8 | HES_PATIENT_KEY = "PSEUDO_HESID" 9 | 10 | 11 | def get_hes_patients_symbolic_table() -> Table: 12 | """ 13 | Note 14 | ---- 15 | This function returns an object that symbolises the fact that there is 16 | no 'Master' table in HES. Patient IDs are duplicated across events, but in 17 | principle they are unique for each individual and the distinct values 18 | could be extracted into a table which represents the unique patients within HES. 
19 | 20 | We don't actually need to create the table, just to be aware 21 | that the primary key on this symbolic table is the foreign key on the 22 | actual hes tables used to specify the patient to which an episode relates 23 | """ 24 | fields = [ 25 | Field(HES_PATIENT_KEY, str, primary=True), 26 | Field("MYDOB", str), 27 | Field("SEX", str), 28 | Field("ETHNOS", str), 29 | ] 30 | return Table("hes_patients_symbolic", *fields) 31 | 32 | 33 | HES_PATIENTS_SYMBOLIC_TABLE = get_hes_patients_symbolic_table() -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_schemas.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./hes_ae_schema 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ./hes_apc_schema 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ./hes_op_schema 11 | 12 | # COMMAND ---------- 13 | 14 | from pyspark.sql import types as T 15 | 16 | # COMMAND ---------- 17 | 18 | def get_hes_schema(hes_dataset: str) -> T.StructType: 19 | if hes_dataset == "hes_ae": 20 | return get_hes_ae_schema() 21 | elif hes_dataset == "hes_op": 22 | return get_hes_op_schema() 23 | elif hes_dataset == "hes_apc": 24 | return get_hes_apc_schema() 25 | else: 26 | raise NotImplementedError("Valid options for hes_dataset are: hes_apc, hes_op, hes_ae") 27 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_tables.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./hes_patient_table 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ./hes_ae_tables 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ./hes_apc_tables 11 | 12 | # COMMAND ---------- 13 | 14 | # MAGIC %run ./hes_op_tables 15 | 16 | # COMMAND ---------- 17 | 18 | HES_TABLES = { 19 | "hes_ae": HES_AE_TABLES, 20 | "hes_apc": HES_APC_TABLES, 21 | "hes_op": HES_OP_TABLES, 22 | } -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/dataset_definitions/relational_helpers.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./relational 3 | 4 | # COMMAND ---------- 5 | 6 | from typing import Generator 7 | 8 | # region Prevent downstream linting highlights 9 | Table = Table 10 | Field = Field 11 | # endregion 12 | 13 | 14 | def get_parent_table(table: Table): 15 | for field in table.fields: 16 | if field.foreign: 17 | return field.foreign.table 18 | else: 19 | return None 20 | 21 | 22 | def get_primary_key(table: Table) -> Field: 23 | for field in table.fields: 24 | if field.primary: 25 | return field 26 | else: 27 | return None 28 | 29 | 30 | def get_foreign_keys(table: Table) -> Generator[Field, None, None]: 31 | return filter(lambda field: field.foreign is not None, table.fields) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/dependencies/random.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from __future__ import annotations 3 | from typing import Union, Iterable 4 | import itertools 5 | import random 6 | 7 | import numpy as np 8 | from sklearn.neighbors import 
KernelDensity 9 | 10 | from pyspark.sql import Column, functions as F, types as T 11 | 12 | # COMMAND ---------- 13 | 14 | def weighted_choice_column(weighted_values: dict[Union[str, int], float], seed: int=None) -> Column: 15 | """Create a DataFrame column by sampling from weighted values 16 | 17 | Parameters 18 | ---------- 19 | weighted_values: dict[Union[str, int], float] 20 | Values to sample mapped to probabilities/frequencies. 21 | seed: int (default=None) 22 | Random seed 23 | 24 | Returns 25 | ------- 26 | pyspark.sql.Column 27 | Randomly chosen values 28 | 29 | Examples 30 | -------- 31 | >>> # df is a pyspark DataFrame 32 | >>> # Add a column with a random choice between 1 & 2 33 | >>> df = df.withColumn('choice', weighted_choice_column({1: 0.5, 2: 0.5})) 34 | | ... | choice | 35 | | ... | 1 | 36 | | ... | 2 | 37 | | ... | ... | 38 | | ... | 2 | 39 | """ 40 | values, weights = zip(*weighted_values.items()) 41 | weights = list(map(float, weights)) 42 | 43 | rv = sum(weights) * (F.lit(1.0) - F.rand(seed=seed)) 44 | cum_weights = map(F.lit, itertools.accumulate(weights)) 45 | 46 | # Choose the values based on the weights 47 | # Note: instead of using array, removing null values and picking the 0th element, you could use F.coalesce. However, 48 | # this was found to give skewed results with the resulting random values (why?), so the implementation here is a workaround 49 | choices_col = F.array(*itertools.starmap(lambda cum_weight, value: F.when(rv < cum_weight, F.lit(value)), zip(cum_weights, values))) 50 | choices_col = F.array_except(choices_col, F.array(F.lit(None)))[0] 51 | 52 | return choices_col.alias(f'weighted_choice({weighted_values}, seed={seed})') 53 | 54 | 55 | def uniform_choice_column(values: list[Union[int, str, float]], seed: int=None) -> Column: 56 | """Create a Column by sampling from values with equal probabilities. 57 | 58 | Parameters 59 | ---------- 60 | values: list[Union[str, int, float]] 61 | Values to sample 62 | seed: int 63 | Random seed 64 | 65 | Returns 66 | ------- 67 | pyspark.sql.Column 68 | Randomly chosen values 69 | 70 | Examples 71 | -------- 72 | >>> # df is a pyspark DataFrame 73 | >>> # Add a column with a random choice between 1 & 2 74 | >>> df = df.withColumn('choice', uniform_choice_column([1, 2])) 75 | | ... | choice | 76 | | ... | 1 | 77 | | ... | 2 | 78 | | ... | ... | 79 | | ... | 2 | 80 | """ 81 | weighted_values = dict(zip(values, itertools.repeat(1))) 82 | return weighted_choice_column(weighted_values).alias(f'uniform_choice({values}, seed={seed})') 83 | 84 | 85 | def random_integer_column( 86 | start: Union[int, str, Column], 87 | stop: Union[int, str, Column], 88 | seed: int=None 89 | ) -> Column: 90 | """Create a column of random integers within a given range. 91 | 92 | Parameters 93 | ---------- 94 | start: Union[int, str, Column] 95 | Start of range (inclusive). If start is an int then it is used as a literal value column. 96 | If it is a string it is taken to refer to a column name. Else it is used directly 97 | as a column expression. 98 | stop: Union[int, str, Column] 99 | End of range (inclusive). Similar interpretation to start. 
100 | seed: int (default=None) 101 | Seed for the random number generator 102 | 103 | Returns 104 | ------- 105 | pyspark.sql.Column 106 | Random numbers in the given range 107 | """ 108 | alias = f'random_integer({start}, {stop}, seed={seed})' 109 | 110 | if isinstance(start, int): 111 | start = F.lit(start) 112 | elif isinstance(start, str): 113 | start = F.col(start) 114 | 115 | if isinstance(stop, int): 116 | stop = F.lit(stop) 117 | elif isinstance(stop, str): 118 | stop = F.col(stop) 119 | 120 | # Extend limits to ensure boundaries are equally likely 121 | stop = stop + F.lit(0.999999) 122 | 123 | return F.floor(start + ((stop - start) * F.rand(seed=seed))).cast(T.IntegerType()).alias(alias) 124 | 125 | # COMMAND ---------- 126 | 127 | class SparkRandom(object): 128 | """Provides interface to methods analogous to those provided by Python's 129 | built-in random module but for operating on columns of a PySpark DataFrame. 130 | """ 131 | def __init__(self, seed=None): 132 | self._seed = seed 133 | 134 | def seed(self, x): 135 | self._seed = x 136 | 137 | def choice(self, values: list[Union[str, int]]) -> Column: 138 | return uniform_choice_column(list(values), seed=self._seed) 139 | 140 | def choices(self, population: list[Union[str, int]], weights: list[Union[int, float]]) -> Column: 141 | return weighted_choice_column(dict(*zip(population, weights)), seed=self._seed) 142 | 143 | def randint(self, start: Union[int, str, Column], stop: Union[int, str, Column]) -> Column: 144 | return random_integer_column(start, stop, seed=self._seed) 145 | 146 | def shuffle(self, col: Union[str, Column]) -> Column: 147 | return F.shuffle(col) 148 | 149 | 150 | spark_random = SparkRandom() -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/coerce_hes_schema.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../../dataset_definitions/hes/hes_schemas 3 | 4 | # COMMAND ---------- 5 | 6 | from pyspark.sql import DataFrame, functions as F 7 | 8 | # region Prevent downstream linting highlights 9 | get_hes_schema = get_hes_schema 10 | # endregion 11 | 12 | def coerce_hes_schema(df: DataFrame, hes_dataset: str) -> DataFrame: 13 | """Transform the schema of a dataframe to match the schema of a given HES dataset 14 | 15 | Args: 16 | df (DataFrame): DataFrame to transform 17 | hes_dataset (str): HES dataset with the target schema. Either hes_apc, hes_op or hes_ae 18 | 19 | Returns: 20 | DataFrame: New DataFrame transformed to the target schema 21 | """ 22 | current_fields = {f.name: f for f in df.schema.fields} 23 | coerced_fields = [] 24 | 25 | for field in get_hes_schema(hes_dataset).fields: 26 | target_datatype = field.dataType 27 | 28 | if field.name not in current_fields: 29 | print(f"Warning: field `{field.name}` missing from schema! 
Adding field and filling with nulls") 30 | value_col = F.lit(None) 31 | else: 32 | value_col = F.col(field.name) 33 | 34 | coerced_fields.append(value_col.cast(target_datatype).alias(field.name)) 35 | 36 | coerced_df = df.select(*coerced_fields) 37 | 38 | return coerced_df 39 | 40 | # COMMAND ---------- 41 | 42 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/driver.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../../dataset_definitions/hes/hes_tables 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ./coerce_hes_schema 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ./field_definitions/derivations/derivations_helpers 11 | 12 | # COMMAND ---------- 13 | 14 | # MAGIC %run ./field_definitions/id_fields 15 | 16 | # COMMAND ---------- 17 | 18 | # MAGIC %run ./demographic_field_generators 19 | 20 | # COMMAND ---------- 21 | 22 | # MAGIC %run ../driver_imports 23 | 24 | # COMMAND ---------- 25 | 26 | import json 27 | from pprint import pprint 28 | from pyspark.sql import DataFrame, functions as F 29 | 30 | # Prevent spark from auto broadcasting the metadata 31 | spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1) 32 | 33 | ARTIFICIAL_TABLE_NAME_TEMPLATE = "artificial_{table_name}" 34 | 35 | # COMMAND ---------- 36 | 37 | dbutils.widgets.removeAll() 38 | 39 | # COMMAND ---------- 40 | 41 | dbutils.widgets.text("meta_database", "iuod_artificial_data_generator", "0.1 Meta DB") 42 | dbutils.widgets.text("artificial_database", "iuod_artificial_data_generator", "0.2 Artificial DB") 43 | dbutils.widgets.text("parameters_json", json.dumps({"n_patients": 100}), "1.1 Parameters JSON") 44 | 45 | # COMMAND ---------- 46 | 47 | # Config 48 | meta_database = get_required_argument("meta_database") 49 | artificial_database = get_required_argument("artificial_database") 50 | 51 | notebook_parameters = json.loads(get_required_argument("parameters_json")) 52 | n_patients = notebook_parameters['n_patients'] 53 | 54 | print("Running with parameters: ") 55 | pprint({ 56 | "n_patients": n_patients, 57 | "artificial_database": artificial_database, 58 | "meta_database": meta_database, 59 | }) 60 | 61 | # COMMAND ---------- 62 | 63 | variable_fields = ["FIELD_NAME"] 64 | index_fields = ["ARTIFICIAL_DEMOGRAPHIC_ID", "ARTIFICIAL_EPISODE_ID"] 65 | 66 | 67 | def generate_artificial_hes_table(meta_df: DataFrame, hes_dataset: str, hes_table: Table, target_database: str) -> bool: 68 | # Create base dataframe for sampling onto 69 | base_demographic_df = ( 70 | spark.range(n_patients) 71 | .select(F.col("id").alias(index_fields[0])) 72 | ) 73 | 74 | # Generate ids for sampling based on the relationships between patients and episodes 75 | base_demographic_episode_df = ( 76 | relationship_generator(filtered_meta_df, base_demographic_df, variable_fields) 77 | .select( 78 | index_fields[0], 79 | F.col("VALUE_NUMERIC").alias(index_fields[1]), 80 | ) 81 | ) 82 | 83 | # Generate 84 | artificial_demographic_df = demographic_field_generator( 85 | filtered_meta_df, 86 | base_demographic_episode_df, 87 | "FIELD_NAME", 88 | index_fields, 89 | ) 90 | artificial_episode_df = field_generator(filtered_meta_df, base_demographic_episode_df, "FIELD_NAME") 91 | artificial_df = artificial_demographic_df.join(artificial_episode_df, how="outer", on=index_fields) 92 | 93 | # Postprocessing 94 | artificial_df = 
with_hes_id_fields(artificial_df, hes_dataset, hes_table, index_fields) 95 | derived_fields = get_derivations(hes_dataset) 96 | artificial_df = with_derived_fields(artificial_df, *derived_fields) 97 | artificial_df = coerce_hes_schema(artificial_df, hes_dataset) 98 | 99 | # Output 100 | success = create_table( 101 | spark, 102 | artificial_df, 103 | database_name=target_database, 104 | table_name=ARTIFICIAL_TABLE_NAME_TEMPLATE.format(table_name=hes_table.name), 105 | format="delta", 106 | mode="overwrite", 107 | overwriteSchema="true", # This doesn't work on ACL clusters! 108 | ) 109 | 110 | return success 111 | 112 | # COMMAND ---------- 113 | 114 | meta_table_fullname = f"{meta_database}.artificial_hes_meta" 115 | meta_df = spark.table(meta_table_fullname) 116 | 117 | # Generate artificial data for each table 118 | for hes_dataset, hes_dataset_tables in HES_TABLES.items(): 119 | print(f"\nGenerating artificial data for dataset `{hes_dataset}`") 120 | 121 | for hes_table in hes_dataset_tables: 122 | print(f"\tGenerating artificial data representing table `{hes_table.name}` into database `{artificial_database}`") 123 | 124 | filtered_meta_df = ( 125 | meta_df 126 | .filter(F.col("TABLE_NAME") == hes_table.name) 127 | .drop("TABLE_NAME") 128 | ) 129 | 130 | if filtered_meta_df.first() is None: 131 | print(f"\t\tNo metadata found for table '{hes_table.name}': artificial data table will be empty!") 132 | 133 | success = generate_artificial_hes_table(filtered_meta_df, hes_dataset, hes_table, artificial_database) 134 | 135 | 136 | # COMMAND ---------- 137 | 138 | notebook_response = {} 139 | 140 | # TODO: work into for loop above (e.g. try/except for each table and compile list of fails / successes) 141 | if success: 142 | notebook_response["status"] = "success" 143 | notebook_response["message"] = "" 144 | 145 | else: 146 | notebook_response["status"] = "failed" 147 | notebook_response["message"] = "Failed to write test data" 148 | 149 | dbutils.notebook.exit(json.dumps(notebook_response)) 150 | 151 | # COMMAND ---------- 152 | 153 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/age_fields.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from typing import Union 3 | from pyspark.sql import Column, functions as F 4 | 5 | AGE_IN_DAYS_CUTOFF = 120 * 365 6 | 7 | mydob_date = F.to_date("MYDOB", "MMyyyy") 8 | 9 | 10 | def get_fractional_age_field(ref_date: Union[str, Column], dob: Union[str, Column]) -> Column: 11 | age_in_days = F.datediff(ref_date, dob) 12 | return ( 13 | F.when(age_in_days.isNull(), F.expr("null")) 14 | .when(age_in_days < 1, 0.002) 15 | .when(age_in_days <= 6, 0.010) 16 | .when(age_in_days <= 28, 0.048) 17 | .when(age_in_days <= 90, 0.167) 18 | .when(age_in_days <= 181, 0.375) 19 | .when(age_in_days <= 272, 0.625) 20 | .when(age_in_days <= 365 - 1, 0.875) 21 | .when(age_in_days >= AGE_IN_DAYS_CUTOFF, 120) 22 | .otherwise(age_in_days / 365) 23 | .cast("double") 24 | ) 25 | 26 | 27 | def get_categorized_age_field(ref_date: Union[str, Column], dob: Union[str, Column]) -> Column: 28 | age_in_days = F.datediff(ref_date, dob) 29 | return ( 30 | F.when(age_in_days.isNull(), F.expr("null")) 31 | .when(age_in_days < 1, 7001) 32 | .when(age_in_days <= 6, 7002) 33 | .when(age_in_days <= 28, 7003) 34 | .when(age_in_days <= 90, 7004) 35 | .when(age_in_days <= 181, 7005) 
36 | .when(age_in_days <= 272, 7006) 37 | .when(age_in_days <= 365-1, 7007) 38 | .when(age_in_days >= AGE_IN_DAYS_CUTOFF, 120) 39 | .otherwise(F.round(age_in_days / 365, 0)) 40 | .cast("integer") 41 | ) 42 | 43 | 44 | def get_fractional_from_categorized_age_field(categorized_age: str) -> Column: 45 | return ( 46 | F.when(F.col(categorized_age).isNull(), F.expr("null")) 47 | .when(F.col(categorized_age) == 7001, 0.002) 48 | .when(F.col(categorized_age) == 7002, 0.010) 49 | .when(F.col(categorized_age) == 7003, 0.048) 50 | .when(F.col(categorized_age) == 7004, 0.167) 51 | .when(F.col(categorized_age) == 7005, 0.375) 52 | .when(F.col(categorized_age) == 7006, 0.625) 53 | .when(F.col(categorized_age) == 7007, 0.875) 54 | .otherwise(F.col(categorized_age)) 55 | .cast("double") 56 | ) 57 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/coded_field_lists.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | opertn_4_fields = [ 3 | "OPERTN_4_01", 4 | "OPERTN_4_02", 5 | "OPERTN_4_03", 6 | "OPERTN_4_04", 7 | "OPERTN_4_05", 8 | "OPERTN_4_06", 9 | "OPERTN_4_07", 10 | "OPERTN_4_08", 11 | "OPERTN_4_09", 12 | "OPERTN_4_10", 13 | "OPERTN_4_11", 14 | "OPERTN_4_12", 15 | "OPERTN_4_13", 16 | "OPERTN_4_14", 17 | "OPERTN_4_15", 18 | "OPERTN_4_16", 19 | "OPERTN_4_17", 20 | "OPERTN_4_18", 21 | "OPERTN_4_19", 22 | "OPERTN_4_20", 23 | "OPERTN_4_21", 24 | "OPERTN_4_22", 25 | "OPERTN_4_23", 26 | "OPERTN_4_24", 27 | ] 28 | 29 | diag_4_fields = [ 30 | "DIAG_4_01", 31 | "DIAG_4_02", 32 | "DIAG_4_03", 33 | "DIAG_4_04", 34 | "DIAG_4_05", 35 | "DIAG_4_06", 36 | "DIAG_4_07", 37 | "DIAG_4_08", 38 | "DIAG_4_09", 39 | "DIAG_4_10", 40 | "DIAG_4_11", 41 | "DIAG_4_12", 42 | "DIAG_4_13", 43 | "DIAG_4_14", 44 | "DIAG_4_15", 45 | "DIAG_4_16", 46 | "DIAG_4_17", 47 | "DIAG_4_18", 48 | "DIAG_4_19", 49 | "DIAG_4_20", 50 | ] 51 | 52 | diag3_fields = [ 53 | "DIAG3_01", 54 | "DIAG3_02", 55 | "DIAG3_03", 56 | "DIAG3_04", 57 | "DIAG3_05", 58 | "DIAG3_06", 59 | "DIAG3_07", 60 | "DIAG3_08", 61 | "DIAG3_09", 62 | "DIAG3_10", 63 | "DIAG3_11", 64 | "DIAG3_12", 65 | ] 66 | 67 | treat3_fields = [ 68 | "TREAT3_01", 69 | "TREAT3_02", 70 | "TREAT3_03", 71 | "TREAT3_04", 72 | "TREAT3_05", 73 | "TREAT3_06", 74 | "TREAT3_07", 75 | "TREAT3_08", 76 | "TREAT3_09", 77 | "TREAT3_10", 78 | "TREAT3_11", 79 | "TREAT3_12", 80 | ] -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/coded_fields.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./coded_field_lists 3 | 4 | # COMMAND ---------- 5 | 6 | from typing import Iterable, Union, List, Tuple 7 | from pyspark.sql import Column, functions as F 8 | 9 | 10 | def _concat_coded_fields(*fields: Union[Iterable[str], Iterable[Column]]) -> Column: 11 | coalesced_fields = [F.coalesce(field, F.lit('')) for field in fields] 12 | concat_field_raw = F.concat_ws(",", *coalesced_fields) 13 | concat_field = F.regexp_replace(concat_field_raw, ',*$', '') 14 | return concat_field 15 | 16 | 17 | def _count_coded_fields(*fields: Iterable[str]) -> Column: 18 | count_field_raw = F.expr(f"cardinality(filter(array({','.join(fields)}), x -> x is not null AND x not in ('&', '-')))") 19 | 
count_field = ( 20 | F.when(count_field_raw > 24, 24) 21 | .when(count_field_raw > 0, count_field_raw) 22 | .otherwise(None) 23 | ) 24 | return count_field 25 | 26 | 27 | def _get_opertn_4_fields() -> List[str]: 28 | return opertn_4_fields # Global constant from coded_field_lists 29 | 30 | 31 | def get_opertn_34_fields() -> List[Tuple[str, Column]]: 32 | opertn_4_fields = _get_opertn_4_fields() 33 | opertn_3_fields = [(f"OPERTN_3_{i+1:02}", F.substring(x, 0, 3)) for i, x in enumerate(opertn_4_fields)] 34 | opertn_3_concat = ("OPERTN_3_CONCAT", _concat_coded_fields(*next(zip(*opertn_3_fields)))) # Extract just the column defintions 35 | opertn_4_concat = ("OPERTN_4_CONCAT", _concat_coded_fields(*opertn_4_fields)) 36 | opertn_count = ("OPERTN_COUNT", _count_coded_fields(*opertn_4_fields)) 37 | return [ 38 | *opertn_3_fields, 39 | opertn_3_concat, 40 | opertn_4_concat, 41 | opertn_count, 42 | ] 43 | 44 | 45 | def _get_diag_4_fields(limit: int=20) -> List[str]: 46 | # NOTE: OP and APC have a different number of DIAG fields 47 | return diag_4_fields[:limit] # Global constant from coded_field_lists 48 | 49 | 50 | def get_diag_34_fields(limit: int=20) -> List[Tuple[str, Column]]: 51 | diag_4_fields = _get_diag_4_fields(limit=limit) 52 | diag_3_fields = [(f"DIAG_3_{i+1:02}", F.substring(x, 0, 3)) for i, x in enumerate(diag_4_fields)] 53 | diag_3_concat = ("DIAG_3_CONCAT", _concat_coded_fields(*next(zip(*diag_3_fields)))) # TODO: replace null 0th element with "R69" before concat 54 | diag_4_concat = ("DIAG_4_CONCAT", _concat_coded_fields(*diag_4_fields)) # TODO: replace null 0th element with "R69X" before concat 55 | diag_count = ("DIAG_COUNT", _count_coded_fields(*diag_4_fields)) 56 | return [ 57 | *diag_3_fields, 58 | diag_3_concat, 59 | diag_4_concat, 60 | diag_count, 61 | ] 62 | 63 | # COMMAND ---------- 64 | 65 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/derivations_helpers.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./coded_fields 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ./sequential_fields 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ./age_fields 11 | 12 | # COMMAND ---------- 13 | 14 | from typing import List, Tuple 15 | from pyspark.sql import Column 16 | 17 | 18 | def _get_apc_derivations() -> List[Tuple[str, Column]]: 19 | startage_calc = ("STARTAGE_CALC", get_fractional_age_field("ADMIDATE", mydob_date)) 20 | startage = ("STARTAGE", get_categorized_age_field("ADMIDATE", mydob_date)) 21 | opertn_34_fields = get_opertn_34_fields() 22 | diag_34_fields = get_diag_34_fields() 23 | apc_sequential_fields = get_apc_sequential_fields() 24 | return [ 25 | startage_calc, 26 | startage, 27 | *opertn_34_fields, 28 | *diag_34_fields, 29 | *apc_sequential_fields, 30 | ] 31 | 32 | 33 | def _get_op_derivations() -> List[Tuple[str, Column]]: 34 | apptage_calc = ("APPTAGE_CALC", get_fractional_from_categorized_age_field("APPTAGE")) 35 | opertn_34_fields = get_opertn_34_fields() 36 | diag_34_fields = get_diag_34_fields(limit=12) # OP only has 12 DIAG fields 37 | op_sequential_fields = get_op_sequential_fields() 38 | return [ 39 | apptage_calc, 40 | *opertn_34_fields, 41 | *diag_34_fields, 42 | *op_sequential_fields, 43 | ] 44 | 45 | 46 | def _get_ae_derivations() -> List[Tuple[str, Column]]: 47 | arrivalage_calc = ("ARRIVALAGE_CALC", 
get_fractional_from_categorized_age_field("ARRIVALAGE")) 48 | diag2_fields = [(f"DIAG2_{i+1:02}", F.substring(x, 0, 3)) for i, x in enumerate(diag3_fields)] 49 | treat2_fields = [(f"TREAT2_{i+1:02}", F.substring(x, 0, 3)) for i, x in enumerate(treat3_fields)] 50 | ae_sequential_fields = get_ae_sequential_fields() 51 | return [ 52 | *diag2_fields, 53 | *treat2_fields, 54 | arrivalage_calc, 55 | *ae_sequential_fields, 56 | ] 57 | 58 | 59 | def get_derivations(dataset: str) -> List[Tuple[str, Column]]: 60 | derivation_functions = { 61 | "hes_ae": _get_ae_derivations, 62 | "hes_apc": _get_apc_derivations, 63 | "hes_op": _get_op_derivations, 64 | } 65 | _get_derivations = derivation_functions.get(dataset, lambda: []) 66 | derived_fields = _get_derivations() 67 | return derived_fields 68 | 69 | 70 | # COMMAND ---------- 71 | 72 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/sequential_field_lists.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # NOTE: "RTTPERSTART" & "RTTPEREND" appear in data dictionary but not APC, OP, AE tables for DAE version of the HES asset 3 | APC_SEQUENTIAL_DATE_FIELDS = ["ELECDATE", "ADMIDATE", "EPISTART", "EPIEND", "DISDATE"] # TODO: confirm the ordering 4 | AE_SEQUENTIAL_DATE_FIELDS = [] 5 | OP_SEQUENTIAL_DATE_FIELDS = [] -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/sequential_fields.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./sequential_field_lists 3 | 4 | # COMMAND ---------- 5 | 6 | from typing import Iterable, List, Tuple 7 | from pyspark.sql import Column, functions as F 8 | 9 | 10 | def sort_fields_l2r(*fields: Iterable[str]) -> List[Tuple[str, Column]]: # TODO: pull this out into common functions 11 | sorted_array = F.array_sort(F.array(*fields)) 12 | sorted_value_fields = [(x, F.element_at(sorted_array, i+1)) for i, x in enumerate(fields)] 13 | return sorted_value_fields 14 | 15 | 16 | def get_apc_sequential_fields() -> List[Tuple[str, Column]]: 17 | sorted_fields = sort_fields_l2r(*APC_SEQUENTIAL_DATE_FIELDS) 18 | return sorted_fields 19 | 20 | 21 | def get_op_sequential_fields() -> List[Tuple[str, Column]]: 22 | sorted_fields = sort_fields_l2r(*OP_SEQUENTIAL_DATE_FIELDS) 23 | return sorted_fields 24 | 25 | 26 | def get_ae_sequential_fields() -> List[Tuple[str, Column]]: 27 | sorted_fields = sort_fields_l2r(*AE_SEQUENTIAL_DATE_FIELDS) 28 | return sorted_fields -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/id_field_lists.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | APC_ID_FIELDS = [ 3 | "SUSRECID", 4 | "PSEUDO_HESID", 5 | "AEKEY", 6 | "EPIKEY", 7 | ] 8 | 9 | OP_ID_FIELDS = [ 10 | "PSEUDO_HESID", 11 | "ATTENDID", 12 | "ATTENDKEY", 13 | "PREFERER", 14 | ] 15 | 16 | AE_ID_FIELDS = [ 17 | "AEKEY", 18 | "PSEUDO_HESID", 19 | "EPIKEY" 20 | ] -------------------------------------------------------------------------------- 
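A minimal usage sketch for the sort_fields_l2r helper defined in sequential_fields.py above (illustrative only: the toy DataFrame and ISO-formatted date strings are assumptions, and the helper is assumed to be in scope, e.g. via %run ./sequential_fields):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("2021-03-01", "2021-01-15", "2021-02-10")],
    ["ELECDATE", "ADMIDATE", "EPISTART"],
)
# Each output column receives the i-th smallest of the original values, so the
# dates read left-to-right in chronological order after this single select.
sorted_fields = sort_fields_l2r("ELECDATE", "ADMIDATE", "EPISTART")
df_sorted = df.select(*[col_spec.alias(name) for name, col_spec in sorted_fields])
# ELECDATE -> 2021-01-15, ADMIDATE -> 2021-02-10, EPISTART -> 2021-03-01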
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/id_field_regex_patterns.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | ID_FIELD_PATTERNS = dict( 3 | PSEUDO_HESID = r"TEST[0-9a-zA-Z]{28}", # 32an - first 4 chars = TEST to ensure no overlap with real IDs 4 | SUSRECID = r"\d{14}", 5 | AEKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22 6 | EPIKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22 7 | ATTENDNO = r"[0-9a-zA-Z]{12}", 8 | ATTENDKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22 9 | ATTENDID = r"[0-9a-zA-Z]{12}", 10 | PREFERER = r"[0-9a-zA-Z]{16}", # What about nulls (&) / invalids (99) 11 | ) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/id_fields.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../../../dependencies/spark_rstr 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../../../dataset_definitions/relational_helpers 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ./id_field_regex_patterns 11 | 12 | # COMMAND ---------- 13 | 14 | # MAGIC %run ./id_field_lists 15 | 16 | # COMMAND ---------- 17 | 18 | from typing import Iterable, Dict, List 19 | 20 | from pyspark.sql import DataFrame, Column, functions as F, Window as W 21 | 22 | # region Prevent downstream linter highlights 23 | get_parent_table = get_parent_table 24 | SparkRstr = SparkRstr 25 | ID_FIELD_PATTERNS = ID_FIELD_PATTERNS 26 | APC_ID_FIELDS = APC_ID_FIELDS 27 | AE_ID_FIELDS = AE_ID_FIELDS 28 | OP_ID_FIELDS = OP_ID_FIELDS 29 | Table = Table 30 | # endregion 31 | 32 | 33 | def _get_dataset_id_fields(field_names: Iterable[str]) -> Dict[str, Column]: 34 | rstr = SparkRstr() 35 | dataset_id_fields = set(field_names).intersection(ID_FIELD_PATTERNS.keys()) 36 | return {field: rstr.xeger(ID_FIELD_PATTERNS[field]).alias(field) for field in dataset_id_fields} 37 | 38 | 39 | def _get_apc_id_fields() -> Dict[str, Column]: 40 | return _get_dataset_id_fields(APC_ID_FIELDS) 41 | 42 | 43 | def _get_ae_id_fields() -> Dict[str, Column]: 44 | return _get_dataset_id_fields(AE_ID_FIELDS) 45 | 46 | 47 | def _get_op_id_fields() -> Dict[str, Column]: 48 | return _get_dataset_id_fields(OP_ID_FIELDS) 49 | 50 | 51 | def get_id_fields(dataset: str) -> Dict[str, Column]: 52 | id_fields_functions = { 53 | "hes_ae": _get_ae_id_fields, 54 | "hes_apc": _get_apc_id_fields, 55 | "hes_op": _get_op_id_fields, 56 | } 57 | _get_id_fields = id_fields_functions.get(dataset, lambda: {}) # Unknown datasets yield an empty dict (callers use .pop and .items) 58 | id_fields = _get_id_fields() 59 | return id_fields 60 | 61 | 62 | def with_hes_id_fields(df: DataFrame, dataset: str, table: Table, index_fields: Iterable[str]) -> DataFrame: 63 | demographic_table = get_parent_table(table) 64 | demographic_field_names = [f.name for f in demographic_table.fields] 65 | 66 | # Split demographic and episode id fields so that demographic ids can be carried over 67 | id_fields = get_id_fields(dataset) 68 | demographic_id_fields = [id_fields.pop(field_name) for field_name in demographic_field_names if field_name in id_fields] 69 | episode_id_fields = [field_value_col for _, field_value_col in id_fields.items()] 70 | 71 | demographic_id_window = W.partitionBy(index_fields[0]).orderBy(index_fields[1]) 72 | 73 | df_with_id_fields = ( 74 | df 75 | .withColumn("DEMOGRAPHIC_IDS",
F.first(F.struct(*demographic_id_fields)).over(demographic_id_window)) 76 | .select(*df.columns, "DEMOGRAPHIC_IDS.*", *episode_id_fields) 77 | .drop(*index_fields) 78 | ) 79 | 80 | return df_with_id_fields -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/driver_imports.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../common/common_exports 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../dataset_definitions/relational_helpers 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ../generator_stages/field_generators 11 | 12 | # COMMAND ---------- 13 | 14 | # MAGIC %run ../generator_stages/relationship_generator 15 | 16 | # COMMAND ---------- 17 | 18 | # MAGIC %run ../generator_stages/derivations -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_pipelines/driver_parameters.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | import json 3 | from pprint import pprint 4 | 5 | # Prevent spark from auto broadcasting the metadata 6 | spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1) 7 | 8 | ARTIFICIAL_TABLE_NAME_TEMPLATE = "artificial_{table_name}" 9 | 10 | # COMMAND ---------- 11 | 12 | dbutils.widgets.removeAll() 13 | 14 | # COMMAND ---------- 15 | 16 | dbutils.widgets.text("meta_database", "iuod_artificial_data_generator", "0.1 Meta DB") 17 | dbutils.widgets.text("artificial_database", "iuod_artificial_data_generator", "0.2 Artificial DB") 18 | dbutils.widgets.text("parameters_json", "", "1.1 Parameters JSON") 19 | 20 | # COMMAND ---------- 21 | 22 | # Config 23 | notebook_parameters = json.loads(get_required_argument("parameters_json")) 24 | n_patients = notebook_parameters['n_patients'] 25 | artificial_database = get_required_argument("artificial_database") 26 | meta_database = get_required_argument("meta_database") 27 | 28 | print("Running with parameters: ") 29 | pprint({ 30 | "n_patients": n_patients, 31 | "artificial_database": artificial_database, 32 | "meta_database": meta_database, 33 | }) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_stages/derivations.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from typing import Iterable, Tuple 3 | from pyspark.sql import DataFrame, Column 4 | 5 | 6 | def with_derived_fields(df: DataFrame, *derived_fields: Iterable[Tuple[str, Column]]) -> DataFrame: 7 | """Add columns to a dataframe, with values possibly derived from existing columns. 8 | 9 | Args: 10 | df (DataFrame): DataFrame to add fields to 11 | *derived_fields (Iterable[Tuple[str, Column]]): Pairs of field_name, column_spec defining the columns to add 12 | 13 | Returns: 14 | DataFrame: New DataFrame with additional derived fields 15 | 16 | Notes: 17 | If it is not possible to add a particular column due to an exception, then a printout will appear 18 | and the column will be skipped. The resulting dataframe will not have this column added to it. 
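Example (illustrative): with_derived_fields(df, ("TOTAL", F.col("A") + F.col("B"))) returns a new DataFrame with a TOTAL column computed from the existing A and B columns.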
19 | """ 20 | for field_name, col_spec in derived_fields: 21 | try: 22 | df = df.withColumn(field_name, col_spec) 23 | except Exception as e: 24 | # TODO: replace generic Exception 25 | print(f"Failed to derive field `{field_name}`: {e}") 26 | continue 27 | 28 | return df -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_stages/relationship_generator.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./field_generators 3 | 4 | # COMMAND ---------- 5 | 6 | from typing import Iterable 7 | from pyspark.sql import DataFrame, functions as F 8 | 9 | # COMMAND ---------- 10 | 11 | def relationship_generator(meta_df: DataFrame, base_df: DataFrame, variable_fields: Iterable[str]) -> DataFrame: 12 | """Generate a dataframe of IDs according to the cardinal relationships between fields described by the metadata. 13 | 14 | Args: 15 | meta_df (pyspark.sql.DataFrame): Contains relationship metadata. This is a frequency distribution of 16 | cardinality, weight pairs - the weight for a given cardinality describes how likely it is the relationship 17 | between fields has that cardinality in the real data 18 | base_df (pyspark.sql.DataFrame): DataFrame containing a single column of IDs 19 | variable_fields (Iterable[str]): Name of fields in meta_df describing the variables to which the values relate 20 | 21 | Returns: 22 | pyspark.sql.DataFrame: Pairs of values for ID fields with the cardinal relationships described by the metadata. 23 | The generated IDs are stored in a column called 'VALUE_NUMERIC' 24 | 25 | """ 26 | _relationship_generator = create_frequency_field_generator("RELATIONSHIP", "VALUE_NUMERIC") 27 | generated_df = _relationship_generator(meta_df, base_df, variable_fields) 28 | 29 | # Assign array of ids with length based on sampled value for the mapping 30 | # ids are created in chunks to avoid out of memory exceptions for large mapping values 31 | chunk_size = 100 32 | step_expr = f"int(least(greatest(1, VALUE_NUMERIC), {chunk_size}))" 33 | # Use the output ratio threshold to place a cap on the maximum cardinality that will be generated 34 | # If this cap is exceeded then the executors will fall over! 
35 | output_ratio_threshold = spark.conf.get("spark.databricks.queryWatchdog.outputRatioThreshold") 36 | max_lower_bound_expr = f"int(least(coalesce(VALUE_NUMERIC, 0), {output_ratio_threshold}))" 37 | 38 | # Define the lower bounds for the id ranges 39 | id_lower_bound_col = F.expr(f"sequence(0, {max_lower_bound_expr}, {step_expr})") 40 | 41 | # Zip the bounds to get pairs of (lower, upper) for the ranges to generate 42 | id_bounds_col = F.expr("transform(slice(ID_LOWER_BOUNDS, 1, size(ID_LOWER_BOUNDS)-1), (x, i) -> struct(x AS LOWER, ID_LOWER_BOUNDS[i+1] AS UPPER))") 43 | 44 | generated_df = ( 45 | generated_df 46 | .withColumn("ID_LOWER_BOUNDS", id_lower_bound_col) 47 | .withColumn("ID_BOUNDS", id_bounds_col) 48 | .select("*", F.expr("inline(ID_BOUNDS)")) 49 | .withColumn("VALUE_NUMERIC", F.expr("explode(sequence(LOWER, UPPER-1, 1))")) 50 | ) 51 | 52 | return generated_df 53 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_stages/sampling/cdf_sampling.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./field_definitions 3 | 4 | # COMMAND ---------- 5 | 6 | from typing import Iterable 7 | from pyspark.sql import DataFrame, functions as F 8 | 9 | # COMMAND ---------- 10 | 11 | def sample_cdf( 12 | meta_df: DataFrame, 13 | base_df: DataFrame, 14 | variable_fields: Iterable[str], 15 | value_field: str, 16 | frequency_field: str, 17 | bin_size: float=0.05 18 | ) -> DataFrame: 19 | """Generate data by sampling from a cumulative distribution over each field in the metadata. 20 | 21 | Args: 22 | meta_df (pyspark.sql.DataFrame): Contains metadata 23 | base_df (pyspark.sql.DataFrame): Contains rows to generate samples for 24 | variable_fields (Iterable[str]): Field labels, where each field has its own CDF to sample 25 | value_field (str): Values in the distribution (e.g. 
categories, bins etc) 26 | frequency_field (str): Frequencies to sum to compute the CDF 27 | bin_size (float (default = 0.05)): Bin size in the range_join, used for performance tuning 28 | 29 | Returns: 30 | pyspark.sql.DataFrame: Generated data 31 | """ 32 | cumsum_col = cumsum_over_window(frequency_field, variable_fields, (frequency_field, value_field)) 33 | total_col = total_over_window(frequency_field, variable_fields) 34 | cumdist_df = ( 35 | meta_df 36 | .withColumn("CUMSUM", cumsum_col) 37 | .withColumn("TOTAL", total_col) 38 | .withColumn("CDF_BIN", get_cdf_bin(frequency_field=frequency_field)) 39 | ) 40 | 41 | # Select the sampled values from the cumulative distribution 42 | # Generate a random number and use to do a range join on the cumulative distribution 43 | # The sampled value is the one where the random number falls within the bin of the CDF corresponding 44 | # to that value 45 | generated_df = ( 46 | base_df 47 | .withColumn("RAND", F.rand()) 48 | .alias("left") 49 | .hint("range_join", bin_size) 50 | .join( 51 | cumdist_df.alias("right"), 52 | how="inner", 53 | on=[ 54 | *[F.col(f"left.{field}") == F.col(f"right.{field}") for field in variable_fields], 55 | F.col("right.CDF_BIN.LOWER") <= F.col("left.RAND"), 56 | F.col("left.RAND") < F.col("right.CDF_BIN.UPPER"), 57 | ] 58 | ) 59 | ) 60 | 61 | # Clean up 62 | ignore_fields = ["CUMSUM", "TOTAL", "CDF_BIN", "RAND", *[f"right.{field}" for field in variable_fields]] 63 | for field in ignore_fields: 64 | generated_df = generated_df.drop(F.col(field)) 65 | 66 | return generated_df -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/generator_stages/sampling/field_definitions.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from typing import Iterable, Union 3 | from pyspark.sql import Column, functions as F, Window as W 4 | 5 | 6 | def cumsum_over_window( 7 | sum_col: Column, 8 | partition_cols: Union[Iterable[str], Iterable[Column]], 9 | order_cols: Union[Iterable[str], Iterable[Column]], 10 | ) -> Column: 11 | """Calculate the cumulative sum of a column over a window. 12 | 13 | Args: 14 | sum_col (pyspark.sql.Column): Column to sum over the partitions 15 | partition_cols (Union[Iterable[str], Iterable[pyspark.sql.Column]]): Partitions to sum within 16 | order_cols (Union[Iterable[str], Iterable[pyspark.sql.Column]]): Ordering for the cumulative sum 17 | 18 | Returns: 19 | pyspark.sql.Column: Column of within-partition cumulative sums 20 | """ 21 | w = ( 22 | W.partitionBy(*partition_cols) 23 | .orderBy(*order_cols) # Ordering by value will be slow! 24 | .rangeBetween(W.unboundedPreceding, 0) 25 | ) 26 | return F.sum(sum_col).over(w) 27 | 28 | 29 | def total_over_window(sum_col: Column, partition_cols: Union[Iterable[str], Iterable[Column]]) -> Column: 30 | """Calculate the total sum of a column over a window. 
31 | 32 | Parameters: 33 | sum_col (pyspark.sql.Column): Column to sum over the partitions 34 | partition_cols (Union[Iterable[str], Iterable[pyspark.sql.Column]]): Partitions to sum within 35 | 36 | Returns: 37 | pyspark.sql.Column: Column of within-partition sums 38 | 39 | """ 40 | w = ( 41 | W.partitionBy(*partition_cols) 42 | .rangeBetween(W.unboundedPreceding, W.unboundedFollowing) 43 | ) 44 | return F.sum(sum_col).over(w) 45 | 46 | 47 | def get_cdf_bin( 48 | frequency_field: str="WEIGHT", 49 | cumsum_field: str="CUMSUM", 50 | total_field: str="TOTAL" 51 | ) -> Column: 52 | lower_bound_col = (F.col(cumsum_field) - F.col(frequency_field)) / F.col(total_field) 53 | upper_bound_col = F.col(cumsum_field) / F.col(total_field) 54 | cdf_bin_col = F.struct( 55 | lower_bound_col.alias("LOWER"), 56 | upper_bound_col.alias("UPPER") 57 | ) 58 | return cdf_bin_col -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/driver_imports.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../common/common_exports 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../scraper_stages/scrape_metadata 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ../dataset_definitions/relational_helpers 11 | 12 | # COMMAND ---------- 13 | 14 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/hes/constants/disclosure_control_parameters.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | HES_FREQUENCY_ROUNDING = 5 3 | HES_MIN_FREQUENCY = 5 -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/hes/constants/excluded_fields.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | ID_FIELDS = [ 3 | "PSEUDO_HESID", 4 | "SUSRECID", 5 | "AEKEY", 6 | "EPIKEY", 7 | "ATTENDNO", 8 | "ATTENDKEY", 9 | "ATTENDID", 10 | "PREFERER", 11 | ] 12 | 13 | OTHER_SENSITIVE_FIELDS = [] 14 | 15 | DERIVED_FIELDS = [ 16 | *[f"OPERTN_3_{i+1:02}" for i in range(24)], # Derived from OPERTN_4 (APC & OP) 17 | *[f"DIAG_3_{i+1:02}" for i in range(20)], # Derived from DIAG_4 (APC & OP) 18 | *[f"DIAG2_{i+1:02}" for i in range(12)], # Derived from DIAG3 (AE) 19 | *[f"TREAT2_{i+1:02}" for i in range(12)], # Derived from TREAT3 (AE) 20 | "OPERTN_3_CONCAT", 21 | "OPERTN_4_CONCAT", 22 | "OPERTN_COUNT", 23 | "DIAG_3_CONCAT", 24 | "DIAG_4_CONCAT", 25 | "DIAG_COUNT", 26 | "STARTAGE_CALC", # Derived from MYDOB and ADMIDATE (APC) 27 | "STARTAGE", # Derived from MYDOB and ADMIDATE (APC) 28 | "APPTAGE_CALC", # Can't be derived from MYDOB because this isn't in HES Non Sensitive! Derive from APPTAGE (OP) 29 | # "APPTAGE", # Can't be derived from MYDOB because this isn't in HES Non Sensitive! (OP) 30 | "ARRIVALAGE_CALC", # Can't be derived from MYDOB because this isn't in HES Non Sensitive! Derive from ARRIVALAGE (AE) 31 | # "ARRIVALAGE", # Can't be derived from MYDOB because this isn't in HES Non Sensitive! 
(AE) 32 | ] 33 | 34 | EXCLUDED_FIELDS = [ 35 | *ID_FIELDS, 36 | *OTHER_SENSITIVE_FIELDS, 37 | *DERIVED_FIELDS, 38 | ] 39 | 40 | # COMMAND ---------- 41 | 42 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/hes/demographic_aggregation_parameters.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../../scraper_stages/aggregation/aggregation_parameters 3 | 4 | # COMMAND ---------- 5 | 6 | from pyspark.sql import functions as F 7 | 8 | # region Prevent downstream linting highlights 9 | AggregationParametersType = AggregationParametersType 10 | get_frequency_aggregation_params = get_frequency_aggregation_params 11 | BASE_GROUPING_FIELDS = BASE_GROUPING_FIELDS 12 | weight_col = weight_col 13 | # endregion 14 | 15 | 16 | def get_demographic_categorical_aggregation_params( 17 | patient_index_field: str, 18 | frequency_rounding: int, 19 | min_frequency: int 20 | ) -> AggregationParametersType: 21 | """Get parameters defining the aggregations required to summarise 'DEMOGRAPHIC_CATEGORICAL' 22 | fields in long-formatted data. The result is a frequency distribution over weighted vectors, 23 | where the inner weights describe the proportion of events that a patient has recorded a given value 24 | and the outer weights indicate the number of patients with that pattern of inner weights across 25 | the dataset. 26 | 27 | Note: this logic is hes-specific. 28 | 29 | Args: 30 | patient_index_field (str): Field that indexes patients 31 | frequency_rounding (int): Rounding to apply to frequencies 32 | min_frequency (int): Cutoff below which frequencies are removed 33 | 34 | Returns: 35 | AggregationParametersType: Parameters defining the aggregations 36 | """ 37 | # Aggregation 1: count total rows and distinct values per distinct value per patient 38 | params1 = dict( 39 | grouping_fields=[ 40 | *BASE_GROUPING_FIELDS, 41 | patient_index_field, 42 | ], 43 | agg_cols=[ 44 | F.collect_list(F.coalesce("VALUE_STRING", F.lit("null"))).alias("VALUE_STRING_ARRAY"), # Need nulls to be strings for aggregation 45 | weight_col, 46 | ], 47 | derivations=[ 48 | F.array_sort(F.array_distinct("VALUE_STRING_ARRAY")).alias("DISTINCT_VALUE_STRING_ARRAY"), 49 | F.expr("transform(DISTINCT_VALUE_STRING_ARRAY, x -> round(size(filter(VALUE_STRING_ARRAY, y -> y == x)) / WEIGHT, 1))").alias("WEIGHT_ARRAY"), # Normalised weight per value 50 | F.expr("transform(DISTINCT_VALUE_STRING_ARRAY, x -> CASE WHEN x != 'null' THEN x END)").alias("VALUE_STRING_ARRAY"), # Replace nulls within the array of values 51 | ] 52 | ) 53 | 54 | # Aggregation 2: count patients per distinct combinations of values and weights 55 | params2, = get_frequency_aggregation_params( 56 | "VALUE_STRING_ARRAY", 57 | frequency_rounding=frequency_rounding, 58 | min_frequency=min_frequency 59 | ) 60 | params2["grouping_fields"].append("WEIGHT_ARRAY") 61 | 62 | return [ 63 | params1, 64 | params2, 65 | ] 66 | 67 | 68 | def get_demographic_date_aggregation_params( 69 | patient_index_field: str, 70 | frequency_rounding: int, 71 | min_frequency: int 72 | ) -> AggregationParametersType: 73 | """Get parameters defining the aggregations required to summarise 'DEMOGRAPHIC_DATE' 74 | fields in long-formatted data. 
The result is a frequency distribution over dates 75 | with an associated incidence, which captures how consistently a patient's records carry 76 | a given value. E.g. a patient's true DOB is 01/02/2022 but it 77 | could be recorded as 02/01/2022 or 01/02/2021. 78 | 79 | Note: this logic is hes-specific. 80 | 81 | Args: 82 | patient_index_field (str): Field that indexes patients 83 | frequency_rounding (int): Rounding to apply to frequencies 84 | min_frequency (int): Cutoff below which frequencies are removed 85 | 86 | Returns: 87 | AggregationParametersType: Parameters defining the aggregations 88 | """ 89 | # Aggregation 1: frequency of each value per patient 90 | params1 = dict( 91 | grouping_fields=[ 92 | *BASE_GROUPING_FIELDS, 93 | patient_index_field, 94 | "VALUE_STRING", # Demographic dates (e.g. MYDOB) are strings with date-like information 95 | ], 96 | agg_cols=[weight_col], 97 | ) 98 | 99 | # Aggregation 2: collect values and weights per patient 100 | # The 'VALUE_NUMERIC' column for demographic dates contains the incidence, which is 101 | # defined for each patient as the proportion of that patient's records that carry the 102 | # patient's mode value for the field 103 | # (i.e. the proportion of records where the value is recorded 'correctly') 104 | params2 = dict( 105 | grouping_fields = [ 106 | *BASE_GROUPING_FIELDS, 107 | patient_index_field, 108 | ], 109 | agg_cols = [ 110 | F.collect_list("VALUE_STRING").alias("VALUE_STRING_ARRAY"), # All values 111 | F.collect_list("WEIGHT").alias("WEIGHT_ARRAY"), # All weights 112 | F.sum("WEIGHT").alias("TOTAL_WEIGHT"), # Total weight for the patient 113 | F.max("WEIGHT").alias("MAX_WEIGHT"), # Largest weight for a value 114 | ], 115 | derivations = [ 116 | F.expr("int(array_position(WEIGHT_ARRAY, MAX_WEIGHT))").alias("MAX_WEIGHT_INDEX"), 117 | F.expr("element_at(VALUE_STRING_ARRAY, MAX_WEIGHT_INDEX)").alias("VALUE_STRING"), 118 | (F.col("MAX_WEIGHT") / F.col("TOTAL_WEIGHT")).alias("VALUE_NUMERIC"), # Proportion of records for a patient equal to the mode value 119 | ], 120 | ) 121 | 122 | # Aggregation 3: count patients with each mode value and compute the average proportion of episodes where the value matches the mode 123 | params3, = get_frequency_aggregation_params( 124 | "VALUE_STRING", 125 | frequency_rounding=frequency_rounding, 126 | min_frequency=min_frequency 127 | ) 128 | params3["agg_cols"].append(F.mean("VALUE_NUMERIC").alias("VALUE_NUMERIC")) # Average agreement with the mode value 129 | 130 | return [ 131 | params1, 132 | params2, 133 | params3, 134 | ] 135 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/hes/demographic_field_summarisers.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../../dataset_definitions/hes/hes_patient_table 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../../scraper_stages/aggregation/summariser_factory 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ./demographic_aggregation_parameters 11 | 12 | # COMMAND ---------- 13 | 14 | # MAGIC %run ./constants/disclosure_control_parameters 15 | 16 | # COMMAND ---------- 17 | 18 | from typing import List, Callable 19 | from pyspark.sql import DataFrame 20 | 21 | # region Prevent downstream linting highlights 22 | HES_FREQUENCY_ROUNDING = HES_FREQUENCY_ROUNDING 23 | HES_MIN_FREQUENCY = HES_MIN_FREQUENCY 24 |
HES_PATIENT_KEY = HES_PATIENT_KEY 25 | get_demographic_categorical_aggregation_params = get_demographic_categorical_aggregation_params 26 | get_demographic_date_aggregation_params = get_demographic_date_aggregation_params 27 | create_summariser = create_summariser 28 | # endregion 29 | 30 | 31 | def demographic_categorical_summariser(long_df: DataFrame) -> DataFrame: 32 | """Summarise fields in a long-formatted dataframe which have VALUE_TYPE='DEMOGRAPHIC_CATEGORICAL' 33 | These are fields such as sex or ethnicity, which in principle have a single underlying value for a given 34 | patient, but which in practice can have multiple values recorded across different visits. We want to 35 | sample in such a way that the patterns of 'errors' that we would see in real data are represented in 36 | artificial data. 37 | 38 | Note: this logic is hes-specific. 39 | 40 | Args: 41 | long_df (DataFrame): Long-formatted data to summarise 42 | 43 | Returns: 44 | DataFrame: Summarised data (long-formatted) 45 | """ 46 | aggregation_params = get_demographic_categorical_aggregation_params( 47 | HES_PATIENT_KEY, 48 | frequency_rounding=HES_FREQUENCY_ROUNDING, 49 | min_frequency=HES_MIN_FREQUENCY 50 | ) 51 | summariser = create_summariser("DEMOGRAPHIC_CATEGORICAL", aggregation_params) 52 | return summariser(long_df) 53 | 54 | 55 | def demographic_date_summariser(long_df: DataFrame) -> DataFrame: 56 | """Summarise fields in a long-formatted dataframe which have VALUE_TYPE='DEMOGRAPHIC_DATE' 57 | These are fields such as date of birth, which in principle have a single underlying value for a given 58 | patient, but which in practice can have multiple values recorded across different visits. We want to 59 | sample in such a way that the patterns of 'errors' that we would see in real data are represented in 60 | artificial data. 61 | 62 | Note: this logic is hes-specific. 
63 | 64 | Args: 65 | long_df (DataFrame): Long-formatted data to summarise 66 | 67 | Returns: 68 | DataFrame: Summarised data (long-formatted) 69 | """ 70 | aggregation_params = get_demographic_date_aggregation_params( 71 | HES_PATIENT_KEY, 72 | frequency_rounding=HES_FREQUENCY_ROUNDING, 73 | min_frequency=HES_MIN_FREQUENCY 74 | ) 75 | summariser = create_summariser("DEMOGRAPHIC_DATE", aggregation_params) 76 | return summariser(long_df) 77 | 78 | 79 | def get_demographic_summarisers() -> List[Callable[[DataFrame], DataFrame]]: 80 | return [ 81 | demographic_categorical_summariser, 82 | demographic_date_summariser, 83 | ] 84 | 85 | # COMMAND ---------- 86 | 87 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/hes/driver.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../../dataset_definitions/hes/hes_tables 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../driver_imports 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ./constants/type_overrides 11 | 12 | # COMMAND ---------- 13 | 14 | # MAGIC %run ./constants/excluded_fields 15 | 16 | # COMMAND ---------- 17 | 18 | # MAGIC %run ./constants/disclosure_control_parameters 19 | 20 | # COMMAND ---------- 21 | 22 | # MAGIC %run ./demographic_field_summarisers 23 | 24 | # COMMAND ---------- 25 | 26 | from pprint import pprint 27 | import json 28 | from typing import List, Dict 29 | 30 | from pyspark.sql import functions as F 31 | 32 | 33 | # region Prevent downstream linting highlights 34 | spark = spark 35 | dbutils = dbutils 36 | get_required_argument = get_required_argument 37 | Table = Table 38 | EXCLUDED_FIELDS = EXCLUDED_FIELDS 39 | create_table = create_table 40 | insert_into = insert_into 41 | get_type_overrides = get_type_overrides 42 | table_exists = table_exists 43 | wide_to_long = wide_to_long 44 | get_foreign_keys = get_foreign_keys 45 | scrape_fields = scrape_fields 46 | HES_TABLES = HES_TABLES 47 | HES_FREQUENCY_ROUNDING = HES_FREQUENCY_ROUNDING 48 | HES_MIN_FREQUENCY = HES_MIN_FREQUENCY 49 | get_demographic_summarisers = get_demographic_summarisers 50 | scrape_relationships = scrape_relationships 51 | ingest = ingest 52 | # endregion 53 | 54 | # COMMAND ---------- 55 | 56 | dbutils.widgets.removeAll() 57 | 58 | # COMMAND ---------- 59 | 60 | dbutils.widgets.text("hes_database", "hes", "0.1 HES DB") 61 | dbutils.widgets.text("meta_database", "artificial_hes_meta", "0.2 Meta DB") 62 | 63 | # COMMAND ---------- 64 | 65 | # Config 66 | hes_database = get_required_argument("hes_database") 67 | meta_database = get_required_argument("meta_database") 68 | 69 | print("Running with parameters: ") 70 | pprint({ 71 | "hes_database": hes_database, 72 | "meta_database": meta_database, 73 | }) 74 | 75 | # COMMAND ---------- 76 | 77 | def scrape_hes_metadata( 78 | hes_table: Table, 79 | source_database: str, 80 | target_database: str, 81 | type_overrides: List[Dict[str, str]], 82 | overwrite=False 83 | ) -> bool: 84 | hes_df = ingest(source_database, hes_table.qualifier) 85 | 86 | # TODO: explain this 87 | hes_table_foreign_key = next(get_foreign_keys(hes_table)).name # PSEUDO_HESID 88 | included_fields = set(hes_df.columns).difference(EXCLUDED_FIELDS).union([hes_table_foreign_key]) 89 | hes_included_df = hes_df.select(*included_fields) 90 | 91 | field_meta_df = scrape_fields( 92 | hes_included_df, 93 | type_overrides, 94 | 
frequency_rounding=HES_FREQUENCY_ROUNDING, 95 | min_frequency=HES_MIN_FREQUENCY, 96 | dataset_summarisers=get_demographic_summarisers(), 97 | unpivoted_fields=[hes_table_foreign_key] 98 | ) 99 | 100 | # Check exclusions 101 | assert field_meta_df.filter(F.col("FIELD_NAME").isin(EXCLUDED_FIELDS)).first() is None, "Scraped fields which should have been excluded!" 102 | 103 | relational_meta_df = scrape_relationships(hes_df, hes_table) 104 | 105 | meta_df = field_meta_df.unionByName(relational_meta_df) 106 | 107 | # Write output 108 | output_args = (spark, meta_df, target_database, "artificial_hes_meta") 109 | if overwrite: 110 | success = create_table( 111 | *output_args, 112 | format="delta", 113 | mode="overwrite", 114 | overwriteSchema="true", # This doesn't work on ACL clusters! 115 | partitionBy=["TABLE_NAME", "FIELD_NAME"] 116 | ) 117 | else: 118 | success = insert_into(*output_args, overwrite=False) 119 | 120 | return success 121 | 122 | 123 | # COMMAND ---------- 124 | 125 | type_overrides = get_type_overrides() 126 | overwrite = True # Switch to False after first overwrite so that table gets appended 127 | 128 | # Generate artificial data for each table 129 | for hes_dataset, hes_dataset_tables in HES_TABLES.items(): 130 | print(f"\nExtracting metadata for dataset '{hes_dataset}'") 131 | 132 | for hes_table in hes_dataset_tables: 133 | if not table_exists(spark, hes_database, hes_table.name): 134 | print(f"\tTable `{hes_database}.{hes_table.name}` does not exist: skipping") 135 | continue 136 | else: 137 | print(f"\tExtracting metadata for table `{hes_database}.{hes_table.name}`") 138 | 139 | success = scrape_hes_metadata(hes_table, hes_database, meta_database, type_overrides, overwrite=overwrite) 140 | 141 | # Switch to prevent overwriting current outputs 142 | overwrite = False 143 | 144 | spark.sql(f"OPTIMIZE {meta_database}.artificial_hes_meta").show() 145 | 146 | # COMMAND ---------- 147 | 148 | notebook_response = {} 149 | 150 | if success: 151 | notebook_response["status"] = "success" 152 | notebook_response["message"] = "" 153 | 154 | else: 155 | notebook_response["status"] = "failed" 156 | notebook_response["message"] = "Failed to write metadata" 157 | 158 | dbutils.notebook.exit(json.dumps(notebook_response)) 159 | 160 | # COMMAND ---------- 161 | 162 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_stages/aggregation/aggregation_parameters.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from typing import Dict, List, Union 3 | from pyspark.sql import functions as F, Column 4 | 5 | 6 | BASE_GROUPING_FIELDS = ["TABLE_NAME", "FIELD_NAME", "VALUE_TYPE"] 7 | weight_col = F.count(F.lit(1)).cast("double").alias("WEIGHT") 8 | AggregationParametersType = List[Dict[str, Union[List[str], List[Column]]]] 9 | 10 | 11 | def round_to_nearest_interval(col: Column, interval: int=5) -> Column: 12 | """Build a new pyspark column by rounding a given column to a specified interval. 13 | 14 | Args: 15 | col (Column): Original column to be rounded 16 | interval (int, optional): Interval to round to. Defaults to 5. 
17 | 18 | Returns: 19 | Column: Rounded column 20 | """ 21 | rounded_col = (F.round(col / interval, 0) * interval) 22 | return rounded_col 23 | 24 | 25 | def get_frequency_aggregation_params( 26 | value_field: str, 27 | frequency_rounding: int, 28 | min_frequency: int 29 | ) -> AggregationParametersType: 30 | """Get the parameters required to aggregate frequency distributions and apply 31 | disclosure controls to the resulting frequencies. 32 | 33 | Args: 34 | value_field (str): Values to calculate the frequency of 35 | frequency_rounding (int): Rounding applied to frequencies 36 | min_frequency (int): Cutoff below which frequencies / values are excluded 37 | 38 | Returns: 39 | AggregationParametersType: Parameters to do the aggregations 40 | """ 41 | 42 | # Need to make sure that filtering is based on actual weights, rather than rounded weights 43 | # Using rounded weights could let potentially disclosive values through! 44 | weight_above_cutoff_col = weight_col >= min_frequency 45 | rounded_weight_col = round_to_nearest_interval(weight_col, interval=frequency_rounding).alias("WEIGHT") 46 | 47 | return [ 48 | dict( 49 | grouping_fields=[ 50 | *BASE_GROUPING_FIELDS, 51 | value_field, 52 | ], 53 | agg_cols=[rounded_weight_col], 54 | filters=[weight_above_cutoff_col], # NOTE: if min_frequency = 1, any rows where the frequency is small enough to be rounded to 0 will be removed 55 | ) 56 | ] 57 | 58 | 59 | def get_categorical_aggregation_params(frequency_rounding: int, min_frequency: int) -> AggregationParametersType: 60 | """Get the parameters required to build the frequency distributions for categorical variables and apply 61 | disclosure controls to the resulting frequencies. 62 | 63 | Args: 64 | frequency_rounding (int): Rounding applied to frequencies 65 | min_frequency (int): Cutoff below which frequencies / values are excluded 66 | 67 | Returns: 68 | AggregationParametersType: Parameters to do the aggregations 69 | """ 70 | return get_frequency_aggregation_params("VALUE_STRING", frequency_rounding, min_frequency) 71 | 72 | 73 | def get_discrete_aggregation_params(frequency_rounding: int, min_frequency: int) -> AggregationParametersType: 74 | """Get the parameters required to build the frequency distributions for discrete variables and apply 75 | disclosure controls to the resulting frequencies. 76 | 77 | Args: 78 | frequency_rounding (int): Rounding applied to frequencies 79 | min_frequency (int): Cutoff below which frequencies / values are excluded 80 | 81 | Returns: 82 | AggregationParametersType: Parameters to do the aggregations 83 | """ 84 | return get_frequency_aggregation_params("VALUE_NUMERIC", frequency_rounding, min_frequency) 85 | 86 | 87 | def get_date_aggregation_params(frequency_rounding: int, min_frequency: int) -> AggregationParametersType: 88 | """Get the parameters required to build the frequency distributions for date variables and apply 89 | disclosure controls to the resulting frequencies. 90 | 91 | Args: 92 | frequency_rounding (int): Rounding applied to frequencies 93 | min_frequency (int): Cutoff below which frequencies / values are excluded 94 | 95 | Returns: 96 | AggregationParametersType: Parameters to do the aggregations 97 | """ 98 | return get_frequency_aggregation_params("VALUE_DATE", frequency_rounding, min_frequency) 99 | 100 | 101 | def get_continuous_aggregation_params(): 102 | """Get the parameters required to calculate the distribution of percentiles for continuous variables 103 | for aggregation. 
104 | 105 | Returns: 106 | AggregationParametersType: Parameters to do the aggregations 107 | """ 108 | # TODO: remove nulls from percentiles 109 | # TODO: incidence (also need to put in generator function) 110 | return [ 111 | dict( 112 | grouping_fields=BASE_GROUPING_FIELDS, 113 | agg_cols=[ 114 | F.array([F.expr(f"approx_percentile(VALUE_NUMERIC, {i / 100})") for i in range(2, 100)]).alias("VALUE_NUMERIC_ARRAY"), 115 | ] 116 | ) 117 | ] 118 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_stages/aggregation/field_summarisers.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./summariser_factory 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ./aggregation_parameters 7 | 8 | # COMMAND ---------- 9 | 10 | from functools import reduce 11 | from typing import Optional, Iterable, Callable 12 | 13 | from pyspark.sql import functions as F, DataFrame 14 | 15 | 16 | # region Prevent downstream linting highlights 17 | create_summariser = create_summariser 18 | get_categorical_aggregation_params = get_categorical_aggregation_params 19 | get_date_aggregation_params = get_date_aggregation_params 20 | get_discrete_aggregation_params = get_discrete_aggregation_params 21 | get_continuous_aggregation_params = get_continuous_aggregation_params 22 | # endregion 23 | 24 | 25 | def field_summariser( 26 | long_df: DataFrame, 27 | frequency_rounding: int, 28 | min_frequency: int, 29 | dataset_summarisers: Optional[Iterable[Callable]]=None 30 | ) -> DataFrame: 31 | """Summarise fields in a long-formatted dataframe with value type 'CATEGORICAL', 32 | 'DATE', 'DISCRETE' or 'CONTINUOUS'. 33 | 34 | Args: 35 | long_df (pyspark.sql.DataFrame): Long-formatted data to summarise 36 | frequency_rounding (int): Rounding applied to frequencies 37 | min_frequency (int): Cutoff below which frequencies are removed. 38 | dataset_summarisers (Optional[Iterable[Callable]], optional): Additional summariser 39 | functions for dataset-specific operations. Defaults to None. 
40 | 41 | Returns: 42 | pyspark.sql.DataFrame: Summarised data (long-formatted) 43 | """ 44 | disclosure_control_params = (frequency_rounding, min_frequency) 45 | 46 | if dataset_summarisers is None: 47 | dataset_summarisers = [] 48 | 49 | summarisers = [ 50 | create_summariser("CATEGORICAL", get_categorical_aggregation_params(*disclosure_control_params)), 51 | create_summariser("DATE", get_date_aggregation_params(*disclosure_control_params)), 52 | create_summariser("DISCRETE", get_discrete_aggregation_params(*disclosure_control_params)), 53 | create_summariser("CONTINUOUS", get_continuous_aggregation_params()), 54 | *dataset_summarisers, 55 | ] 56 | 57 | meta_dfs = map(lambda summariser: summariser(long_df), summarisers) 58 | meta_df = reduce(DataFrame.union, meta_dfs) 59 | 60 | return meta_df -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_stages/aggregation/relationship_summariser.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../wide_to_long 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ./summariser_factory 7 | 8 | # COMMAND ---------- 9 | 10 | from pyspark.sql import DataFrame, functions as F 11 | 12 | # region Prevent downstream linting highlights 13 | BASE_GROUPING_FIELDS = BASE_GROUPING_FIELDS 14 | AggregationParametersType = AggregationParametersType 15 | weight_col = weight_col 16 | create_summariser = create_summariser 17 | # endregion 18 | 19 | 20 | def get_relationship_aggregation_params(pk_field: str, fk_field: str) -> AggregationParametersType: 21 | """Get the parameters required to build the frequency distributions of cardinal relationships 22 | between a primary key and foreign key 23 | 24 | Args: 25 | pk_field (str): Name of the primary key field 26 | fk_field (str): Name of the foreign key field 27 | 28 | Returns: 29 | AggregationParametersType: Parameters to do the aggregations 30 | """ 31 | # Count the number of 1-to-N mappings (where N >= 0) between foreign and primary keys 32 | # Null primary keys contribute towards 1-to-0 counts 33 | # Null foreign keys are ignored 34 | cardinality_col = F.approx_count_distinct(pk_field) 35 | return [ 36 | dict( 37 | grouping_fields=[*BASE_GROUPING_FIELDS, fk_field], 38 | agg_cols=[cardinality_col.alias("VALUE_NUMERIC")], 39 | filters=[~F.isnull(fk_field)] # TODO: Handle nulls separately 40 | ), 41 | dict( 42 | grouping_fields=[*BASE_GROUPING_FIELDS, "VALUE_NUMERIC"], 43 | agg_cols=[weight_col], 44 | ), 45 | ] 46 | 47 | 48 | def summarise_relationships(wide_df: DataFrame, pk_field: str, fk_field: str) -> DataFrame: 49 | """Summarise the frequency distributions of cardinal relationships between primary key 50 | and foreign key fields in wide-formatted data. 
51 | 52 | Args: 53 | wide_df (DataFrame): DataFrame holding the relationships to summarise 54 | pk_field (str): Name of the primary key field 55 | fk_field (str): Name of the foreign key field 56 | 57 | Returns: 58 | DataFrame: Summarised relationship data 59 | """ 60 | wide_df = ( 61 | wide_df 62 | .withColumn("FIELD_NAME", F.lit(pk_field)) 63 | .withColumn("VALUE_TYPE", F.lit("RELATIONSHIP")) 64 | .select(*BASE_GROUPING_FIELDS, pk_field, fk_field) 65 | ) 66 | 67 | aggregation_params = get_relationship_aggregation_params(pk_field, fk_field) 68 | summariser = create_summariser("RELATIONSHIP", aggregation_params) 69 | meta_df = ( 70 | summariser(wide_df) 71 | .withColumn("VALUE_STRING", F.lit(fk_field)) 72 | ) 73 | 74 | return meta_df -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_stages/aggregation/summariser_factory.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./aggregation_parameters 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../../common/coerce_schema 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ../schemas/meta_schema 11 | 12 | # COMMAND ---------- 13 | 14 | from functools import reduce 15 | from typing import Iterable, Callable, Dict, Union 16 | 17 | from pyspark.sql import functions as F, DataFrame, Column 18 | 19 | 20 | # region Prevent downstream linting highlights 21 | coerce_schema = coerce_schema 22 | AggregationParametersType = AggregationParametersType 23 | get_meta_schema = get_meta_schema 24 | # endregion 25 | 26 | 27 | def df_aggregation_reducer( 28 | df: DataFrame, 29 | aggregation_params: AggregationParametersType 30 | ) -> DataFrame: 31 | """Apply a sequence of aggregations, filters and derivations to a dataframe 32 | 33 | Args: 34 | df (DataFrame): DataFrame to aggregate 35 | aggregation_params (AggregationParametersType): Defines the aggregations. 36 | This is an iterable of keyed objects, where each element has the keys 37 | 'grouping_fields', 'agg_cols', 'derivations' and 'filters': the first 2 38 | are used in the DataFrame.groupBy(...).agg(...); the 3rd defines columns 39 | to add to the dataframe after the aggregation; and the 4th is a set 40 | of filter columns to apply to the output 41 | 42 | Returns: 43 | DataFrame: Aggregated data 44 | """ 45 | def _reducer(acc, x): 46 | acc = ( 47 | acc 48 | .groupBy(*x["grouping_fields"]) 49 | .agg(*x["agg_cols"]) 50 | ) 51 | 52 | for derived_col in x.get("derivations", []): 53 | col_name = acc.select(derived_col).columns[0] 54 | acc = acc.withColumn(col_name, derived_col) 55 | 56 | for condition_col in x.get("filters", []): 57 | acc = acc.filter(condition_col) 58 | 59 | return acc 60 | 61 | agg_df = reduce(_reducer, aggregation_params, df) 62 | 63 | return agg_df 64 | 65 | 66 | def create_summariser( 67 | value_type: str, 68 | aggregation_params: AggregationParametersType 69 | ) -> Callable[[DataFrame], DataFrame]: 70 | """Create a function which summarises (long-formatted) data with a given type. 71 | 72 | Args: 73 | value_type (str): Describes type of data to summarise (e.g. 
'CATEGORICAL') 74 | aggregation_params (AggregationParametersType): Defines the aggregations to summarise the data 75 | 76 | Returns: 77 | Callable[[DataFrame], DataFrame]: Summarises long-formatted data 78 | """ 79 | value_type_condition_col = F.col("VALUE_TYPE") == value_type 80 | 81 | def _summariser(df: DataFrame) -> DataFrame: 82 | filtered_df = df.filter(value_type_condition_col) 83 | aggregated_df = df_aggregation_reducer(filtered_df, aggregation_params) 84 | meta_df = coerce_schema(aggregated_df, get_meta_schema()) 85 | return meta_df 86 | 87 | return _summariser -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_stages/ingestion.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../common/common_exports 3 | 4 | # COMMAND ---------- 5 | 6 | from pyspark.sql import DataFrame, functions as F 7 | 8 | # COMMAND ---------- 9 | 10 | def ingest(source_database: str, source_table: str) -> DataFrame: 11 | """Read the data from the source table 12 | 13 | Args: 14 | source_database (str): Database name containing source table 15 | source_table (str): Name of source table 16 | 17 | Returns: 18 | pyspark.sql.DataFrame: DataFrame with data to scrape 19 | 20 | """ 21 | source_table_fullname = f"{source_database}.{source_table}" 22 | dataset_df = ( 23 | spark.table(source_table_fullname) 24 | .withColumn("TABLE_NAME", F.lit(source_table)) 25 | ) 26 | 27 | return dataset_df -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_stages/schemas/long_schema.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from pyspark.sql import types as T 3 | 4 | 5 | def get_long_schema() -> T.StructType: 6 | return T.StructType([ 7 | T.StructField("FIELD_NAME", T.StringType(), False), 8 | T.StructField("VALUE_STRING", T.StringType(), True), 9 | T.StructField("VALUE_NUMERIC", T.DoubleType(), True), 10 | T.StructField("VALUE_DATE", T.DateType(), True), 11 | T.StructField("VALUE_TIMESTAMP", T.TimestampType(), True), 12 | ]) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_stages/schemas/meta_schema.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | from pyspark.sql import types as T 3 | 4 | 5 | def get_meta_schema() -> T.StructType: 6 | return T.StructType([ 7 | T.StructField("TABLE_NAME", T.StringType(), True), # This should ideally be false but that breaks the schema uplift. Left as TODO 8 | T.StructField("FIELD_NAME", T.StringType(), True), # This should ideally be false but that breaks the schema uplift. Left as TODO 9 | T.StructField("VALUE_TYPE", T.StringType(), True), # This should ideally be false but that breaks the schema uplift. 
Left as TODO 10 | T.StructField("VALUE_STRING", T.StringType(), True), 11 | T.StructField("VALUE_NUMERIC", T.DoubleType(), True), 12 | # T.StructField("VALUE_TIMESTAMP", T.TimestampType(), True), 13 | T.StructField("VALUE_DATE", T.DateType(), True), 14 | T.StructField("VALUE_STRING_ARRAY", T.ArrayType(T.StringType()), True), 15 | T.StructField("VALUE_NUMERIC_ARRAY", T.ArrayType(T.DoubleType()), True), 16 | # T.StructField("VALUE_TIMESTAMP_ARRAY", T.ArrayType(T.TimestampType()), True), 17 | # T.StructField("VALUE_DATE_ARRAY", T.ArrayType(T.DateType()), True), 18 | T.StructField("WEIGHT", T.DoubleType(), True), 19 | T.StructField("WEIGHT_ARRAY", T.ArrayType(T.DoubleType()), True), 20 | ]) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/notebooks/scraper_stages/scrape_metadata.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./ingestion 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ./wide_to_long 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ./schemas/meta_schema 11 | 12 | # COMMAND ---------- 13 | 14 | # MAGIC %run ./aggregation/field_summarisers 15 | 16 | # COMMAND ---------- 17 | 18 | # MAGIC %run ./aggregation/relationship_summariser 19 | 20 | # COMMAND ---------- 21 | 22 | # MAGIC %run ../dataset_definitions/relational_helpers 23 | 24 | # COMMAND ---------- 25 | 26 | from typing import Optional, Iterable, Dict, Callable 27 | from pyspark.sql import DataFrame, functions as F 28 | 29 | # region Prevent downstream linting highlights 30 | spark = spark 31 | Table = Table 32 | wide_to_long = wide_to_long 33 | field_summariser = field_summariser 34 | get_primary_key = get_primary_key 35 | get_foreign_keys = get_foreign_keys 36 | summarise_relationships = summarise_relationships 37 | get_meta_schema = get_meta_schema 38 | ingest = ingest 39 | # endregion 40 | 41 | 42 | def scrape_fields( 43 | wide_df: DataFrame, 44 | type_overrides: Iterable[Dict[str, str]], 45 | frequency_rounding: int, 46 | min_frequency: int, 47 | dataset_summarisers: Optional[Iterable[Callable]] = None, 48 | unpivoted_fields: Optional[Iterable[str]] = None 49 | )-> DataFrame: 50 | """Scrape the field metadata (i.e. frequency / percentile distributions) from wide-formatted data 51 | 52 | Args: 53 | wide_df (DataFrame): Wide-formatted data to scrape 54 | type_overrides (Iterable[Dict[str, str]]): Overrides for metadata types, used in place of automatic types 55 | frequency_rounding (int): Rounding applied to frequency-typed fields 56 | min_frequency (int): Cutoff below which frequencies / values are removed 57 | dataset_summarisers (Optional[Iterable[Callable]], optional): Dataset-specific summariser functions. Passed through 58 | to summarise_fields. Defaults to None. 59 | unpivoted_fields (Optional[Iterable[str]], optional): Fields to preserve in the original wide_df. Defaults to None. 60 | 61 | Returns: 62 | DataFrame: Field metadata 63 | """ 64 | 65 | if unpivoted_fields is None: 66 | unpivoted_fields = [] 67 | 68 | unpivoted_fields = list(set(["TABLE_NAME", *unpivoted_fields])) 69 | 70 | # Pivot to long format 71 | long_df = wide_to_long(wide_df, preserved_field_names=unpivoted_fields) 72 | 73 | # Associate metadata types (e.g. 'categorical') 74 | # TODO: what happens if fields don't appear in type overrides? 
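# (As currently written, fields missing from type_overrides get a null VALUE_TYPE from the left join below, and the coalesce then defaults them to 'CATEGORICAL'.)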
75 | type_overrides_df = spark.createDataFrame(type_overrides, "FIELD_NAME: string, VALUE_TYPE: string") 76 | long_df = ( 77 | long_df 78 | .join(type_overrides_df.hint("broadcast"), on="FIELD_NAME", how="left") 79 | .withColumn("VALUE_TYPE", F.coalesce(F.col("VALUE_TYPE"), F.lit("CATEGORICAL"))) # Fill nulls 80 | ) 81 | 82 | # Summarise values 83 | meta_df = field_summariser( 84 | long_df, 85 | frequency_rounding=frequency_rounding, 86 | min_frequency=min_frequency, 87 | dataset_summarisers=dataset_summarisers, 88 | ) 89 | 90 | return meta_df 91 | 92 | 93 | def scrape_relationships(wide_df: DataFrame, table: Table) -> DataFrame: 94 | """Scrape the relationship metadata (i.e. frequency distributions of cardinal relationshios) 95 | from wide-formatted data 96 | 97 | Args: 98 | wide_df (DataFrame): Wide-formatted data to scrape 99 | table (Table): Holds the names of primary and foreign key fields 100 | 101 | Returns: 102 | DataFrame: Relationship metadata 103 | """ 104 | primary_key_field = get_primary_key(table) 105 | 106 | for foreign_key_field in get_foreign_keys(table): 107 | index_fields = [foreign_key_field.name, primary_key_field.name] 108 | index_df = wide_df.select("TABLE_NAME", *index_fields) 109 | meta_df = summarise_relationships(index_df, primary_key_field.name, foreign_key_field.name) 110 | return meta_df 111 | 112 | else: 113 | empty_meta_df = spark.createDataFrame([], get_meta_schema()) 114 | return empty_meta_df 115 | 116 | 117 | def scrape_metadata( 118 | database_name: str, 119 | source_table: Table, 120 | type_overrides: Iterable[Dict[str, str]], 121 | frequency_rounding: int, 122 | min_frequency: int, 123 | excluded_fields: Iterable[str], 124 | included_fields: Optional[Iterable[str]] = None, 125 | dataset_summarisers: Optional[Iterable[Callable]] = None, 126 | ) -> DataFrame: 127 | """Scrape field and relationship metadata from a table 128 | 129 | Args: 130 | database_name (str): Database containing the table to scrape 131 | source_table (Table): Table to scrape 132 | type_overrides (Iterable[Dict[str, str]]): Overrides for metadata types, used in place of automatic types 133 | frequency_rounding (int): Rounding applied to frequency-typed fields 134 | min_frequency (int): Cutoff below which frequencies / values are removed 135 | dataset_summarisers (Optional[Iterable[Callable]], optional): Dataset-specific summariser functions. Passed through 136 | to summarise_fields. Defaults to None. 137 | 138 | Returns: 139 | DataFrame: Scraped metadata 140 | """ 141 | dataset_df = ingest(database_name, source_table.qualifier) 142 | 143 | field_meta_df = scrape_fields( 144 | dataset_df, 145 | type_overrides, 146 | frequency_rounding=frequency_rounding, 147 | min_frequency=min_frequency, 148 | dataset_summarisers=dataset_summarisers 149 | ) 150 | 151 | relational_meta_df = scrape_relationships(dataset_df, source_table) 152 | 153 | meta_df = field_meta_df.unionByName(relational_meta_df) 154 | 155 | return meta_df -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/run_notebooks.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Artificial Data Generator: Run Notebooks 4 | # MAGIC This notebook runs the data generator pipeline for a given artificial dataset. The artificial data will be written into the database specified by the `CP Project Database` argument. 
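# MAGIC For example (illustrative values only), setting `Artificial Dataset` to `artificial_hes` and `Parameters JSON` to `{"n_patients": 1000}` runs the artificial HES driver; the keys accepted in the JSON are those read by `driver_parameters.py`, which currently expects `n_patients`.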
5 | # MAGIC 6 | # MAGIC **Prerequisites:** The underlying driver will ingest the metadata that was scraped from the real data and signed off against the disclosure control checklist - the metadata must exist within a database that is readable by this code promotion project (aka _iuod_artificial_data_generator_). The database containing metadata is specified via the `CP Project Database` argument. 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %md 11 | # MAGIC ## Setup 12 | 13 | # COMMAND ---------- 14 | 15 | # MAGIC %md 16 | # MAGIC ### Import dependencies 17 | 18 | # COMMAND ---------- 19 | 20 | # MAGIC %run ./notebooks/common/widget_utils 21 | 22 | # COMMAND ---------- 23 | import json 24 | from pprint import pprint 25 | from pathlib import Path 26 | 27 | GENERATOR_PIPELINES_ROOT = Path("./notebooks/generator_pipelines") 28 | AVAILABLE_DRIVERS = [ 29 | "artificial_hes" 30 | ] 31 | 32 | # COMMAND ---------- 33 | 34 | # MAGIC %md 35 | # MAGIC ### Create widgets 36 | 37 | # COMMAND ---------- 38 | 39 | if check_databricks(): 40 | dbutils.widgets.removeAll() 41 | dbutils.widgets.text("db", "iuod_artificial_data_generator", "0.1 CP Project Database") 42 | dbutils.widgets.text("artificial_dataset", "", "1.1 Artificial Dataset") 43 | dbutils.widgets.text("parameters_json", "", "1.2 Parameters JSON") 44 | else: 45 | # Only make widgets in databricks 46 | pass 47 | 48 | # COMMAND ---------- 49 | 50 | # MAGIC %md 51 | # MAGIC ## Driver execution 52 | # MAGIC Trigger the driver notebook for the artificial dataset specified via the arguments 53 | 54 | # COMMAND ---------- 55 | 56 | artificial_dataset = get_required_argument("artificial_dataset") 57 | 58 | if artificial_dataset not in AVAILABLE_DRIVERS: 59 | message = f"Driver does not exist for dataset '{artificial_dataset}': please choose from {AVAILABLE_DRIVERS}" 60 | dbutils.notebook.exit(json.dumps({"status": "failed", "message": message})) 61 | 62 | # Path to the notebook that generates data for the given dataset 63 | driver_path = GENERATOR_PIPELINES_ROOT / artificial_dataset / "driver" 64 | 65 | # Prepare driver parameters 66 | database = get_required_argument("db") 67 | driver_args = dict( 68 | meta_database = database, 69 | artificial_database = database, 70 | parameters_json = get_required_argument("parameters_json"), 71 | ) 72 | 73 | print(f"Running driver {str(driver_path)} with arguments:") 74 | pprint(driver_args) 75 | 76 | driver_response = dbutils.notebook.run(str(driver_path), 0, arguments=driver_args) 77 | 78 | dbutils.notebook.exit(driver_response) 79 | 80 | # COMMAND ---------- 81 | 82 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/run_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md ### README 3 | # MAGIC This notebook will run all the tests! Each test notebook is run in turn and the results are printed out in the cell below the run. New test notebooks should be added to the runs. 4 | # MAGIC 5 | # MAGIC Test failures will be raised as an error in the final cell of this notebook (i.e. all the tests will run first, regardless of passes or failures, and then there is a final step to check how many failures there were and alert the user). This should avoid failing tests remaining undetected.
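The README above asks for new test notebooks to be added to the runs; concretely, that means adding an entry to the `tests` mapping defined further down in this notebook. The snippet below is a hypothetical illustration only (the suite name and path do not exist in the repository):

```python
# Hypothetical example of registering an additional suite with the runner defined below.
tests = {
    # ...existing suites...
    "my_new_stage": "./tests/my_new_stage_tests/run_tests",  # illustrative name and path
}
run_tests(tests)
```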
6 | 7 | # COMMAND ---------- 8 | 9 | from typing import Generator, Union, Tuple, Dict 10 | from pyspark.util import Py4JJavaError 11 | 12 | 13 | def get_test_runner() -> Generator[Dict[str, Union[None, AssertionError]], Tuple[str, str], None]: 14 | errors = {} 15 | 16 | while True: 17 | test_name, test_path = yield 18 | 19 | try: 20 | dbutils.notebook.run(test_path, 0) 21 | 22 | if test_name in errors: 23 | del errors[test_name] 24 | except Py4JJavaError as e: 25 | errors[test_name] = e 26 | 27 | yield errors 28 | 29 | 30 | def run_tests(tests): 31 | test_runner = get_test_runner() 32 | errors = test_runner.send(None) 33 | 34 | for test_name, test_path in tests.items(): 35 | errors = test_runner.send((test_name, test_path)) 36 | next(test_runner) 37 | 38 | if errors: 39 | for test_name, e in errors.items(): 40 | print(f'Test failures detected in {tests[test_name]}: {e}') 41 | 42 | # COMMAND ---------- 43 | 44 | tests = { 45 | "scraper_stages": "./tests/scraper_stages_tests/run_tests", 46 | "generator_stages": "./tests/generator_stages_tests/run_tests", 47 | "scraper_pipelines": "./tests/scraper_pipelines_tests/run_tests", 48 | "generator_pipelines": "./tests/generator_pipelines_tests/run_tests", 49 | "dataset_definitions": "./tests/dataset_definitions_tests/run_tests", 50 | "common": "./tests/common_tests/run_tests", 51 | } 52 | 53 | run_tests(tests) 54 | 55 | # COMMAND ---------- 56 | 57 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/schemas/create_meta_table.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Create Metadata Table 4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ## Setup 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %md 13 | # MAGIC ### Imports 14 | 15 | # COMMAND ---------- 16 | 17 | # MAGIC %run ../notebooks/common/widget_utils 18 | 19 | # COMMAND ---------- 20 | 21 | # MAGIC %run ../notebooks/common/table_helpers 22 | 23 | # COMMAND ---------- 24 | 25 | # MAGIC %run ../notebooks/scraper_stages/schemas/meta_schema 26 | 27 | # COMMAND ---------- 28 | 29 | # MAGIC %md 30 | # MAGIC ### Widgets 31 | 32 | # COMMAND ---------- 33 | 34 | dbutils.widgets.text( "database_name", "", "1.1 Database Name") 35 | dbutils.widgets.text( "table_name", "", "1.2 Table Name") 36 | dbutils.widgets.dropdown("replace_if_exists", "false", ["true", "false"], "2.1 Replace If Exists") 37 | 38 | # COMMAND ---------- 39 | 40 | # MAGIC %md 41 | # MAGIC ## Main 42 | 43 | # COMMAND ---------- 44 | 45 | database_name = get_required_argument("database_name") 46 | table_name = get_required_argument("table_name") 47 | replace_if_exists = dbutils.widgets.get("replace_if_exists") == "true" 48 | 49 | # COMMAND ---------- 50 | 51 | if database_exists(spark, database_name, verbose=True): 52 | if table_exists(spark, database_name, table_name): 53 | if replace_if_exists: 54 | drop_table(spark, database_name, table_name) 55 | else: 56 | dbutils.notebook.exit({"status": "skipped", "message": "Table exists but replace_if_exists=False"}) 57 | 58 | create_table_from_schema( 59 | spark, 60 | database_name, 61 | table_name, 62 | get_meta_schema(), 63 | mode="overwrite", 64 | format="delta", 65 | partitionBy=["TABLE_NAME", "FIELD_NAME"] 66 | ) 67 | 68 | # COMMAND ---------- 69 | 70 | -------------------------------------------------------------------------------- 
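`create_table_from_schema`, used by the create_meta_table notebook above, comes from the shared `table_helpers` notebook, whose source is not included in this extract. Purely as a sketch of what such a helper plausibly does (create an empty DataFrame from the schema and save it as a managed Delta table), and not the actual implementation:

```python
# Rough sketch only: assumed behaviour of a create_table_from_schema-style helper.
# The real helper lives in notebooks/common/table_helpers.py and may differ.
from typing import Iterable, Optional
from pyspark.sql import SparkSession, types as T


def create_table_from_schema_sketch(
    spark: SparkSession,
    database_name: str,
    table_name: str,
    schema: T.StructType,
    mode: str = "overwrite",
    format: str = "delta",
    partitionBy: Optional[Iterable[str]] = None,
) -> None:
    """Create an (empty) managed table with the given schema."""
    empty_df = spark.createDataFrame([], schema)  # schema only, no rows
    writer = empty_df.write.format(format).mode(mode)
    if partitionBy:
        writer = writer.partitionBy(*partitionBy)
    writer.saveAsTable(f"{database_name}.{table_name}")
```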
/projects/iuod_artificial_data_generator/schemas/uplifts/artificial_hes_schema_coercion.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Artificial HES Schema Coercion 4 | # MAGIC Coerces the schema of existing `artificial_hes` tables to match the schema specification for HES tables. 5 | 6 | # COMMAND ---------- 7 | 8 | # MAGIC %md ## Setup 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %md 13 | # MAGIC ### Imports 14 | 15 | # COMMAND ---------- 16 | 17 | # MAGIC %run ../../notebooks/common/widget_utils 18 | 19 | # COMMAND ---------- 20 | 21 | # MAGIC %run ../../notebooks/common/spark_helpers 22 | 23 | # COMMAND ---------- 24 | 25 | # MAGIC %run ../../notebooks/common/table_helpers 26 | 27 | # COMMAND ---------- 28 | 29 | # MAGIC %run ../../notebooks/common/coerce_schema 30 | 31 | # COMMAND ---------- 32 | 33 | # MAGIC %run ../../notebooks/dataset_definitions/hes/hes_schemas 34 | 35 | # COMMAND ---------- 36 | 37 | # MAGIC %run ../../notebooks/dataset_definitions/hes/hes_tables 38 | 39 | # COMMAND ---------- 40 | 41 | from typing import Iterable 42 | from datetime import datetime 43 | 44 | # COMMAND ---------- 45 | 46 | # MAGIC %md 47 | # MAGIC ### Widgets 48 | 49 | # COMMAND ---------- 50 | 51 | dbutils.widgets.text("database_name", "artificial_hes", "0.1 Database Name") 52 | 53 | # COMMAND ---------- 54 | 55 | database_name = get_required_argument("database_name") 56 | 57 | # COMMAND ---------- 58 | 59 | # MAGIC %md 60 | # MAGIC ## Functions 61 | 62 | # COMMAND ---------- 63 | 64 | def backup_and_uplift_schema( 65 | database_name: str, 66 | table_name: str, 67 | target_schema: T.StructType, 68 | partitionBy: Iterable[str]=[], 69 | ) -> bool: 70 | backup_version = datetime.now().strftime("%Y%m%d%H%M") 71 | backup_table_name = f"{table_name}_backup_{backup_version}" 72 | 73 | current_df = spark.table(f"{database_name}.{table_name}") 74 | 75 | if not datatypes_equal(current_df.schema, target_schema): 76 | print("Modifying schema to target") 77 | 78 | # Backup the current table 79 | create_table( 80 | spark, 81 | current_df, 82 | database_name, 83 | backup_table_name, 84 | format="delta", 85 | mode="overwrite", 86 | partitionBy=partitionBy, 87 | ) 88 | 89 | # Drop the current table which has now been backed-up 90 | drop_table(spark, database_name, table_name) 91 | 92 | # Apply the schema changes 93 | uplifted_df = coerce_schema( 94 | spark.table(f"{database_name}.{backup_table_name}"), 95 | target_schema 96 | ) 97 | 98 | # Write the new table 99 | create_table( 100 | spark, 101 | uplifted_df, 102 | database_name, 103 | table_name, 104 | format="delta", 105 | mode="overwrite", 106 | partitionBy=partitionBy, 107 | ) 108 | 109 | # Remove the backup 110 | drop_table(spark, database_name, backup_table_name) 111 | 112 | else: 113 | print("Current schema already equal to target schema!") 114 | 115 | target_df = spark.table(f"{database_name}.{table_name}") 116 | 117 | assert datatypes_equal(target_df.schema, target_schema) 118 | 119 | return True 120 | 121 | 122 | # COMMAND ---------- 123 | 124 | # MAGIC %md 125 | # MAGIC ## Main 126 | 127 | # COMMAND ---------- 128 | 129 | for dataset_name, dataset_tables in HES_TABLES.items(): 130 | target_schema = get_hes_schema(dataset_name) 131 | 132 | for hes_table in dataset_tables: 133 | artificial_table_name = f"artificial_{hes_table.name}" 134 | 135 | if table_exists(spark, database_name, artificial_table_name): 136 | print(f"Uplifting schema for table 
{database_name}.{artificial_table_name}") 137 | 138 | backup_and_uplift_schema( 139 | database_name, 140 | artificial_table_name, 141 | target_schema 142 | ) 143 | else: 144 | continue 145 | 146 | # COMMAND ---------- 147 | 148 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/schemas/uplifts/artificial_hes_schema_coercion_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC ## Setup 4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ### Imports 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %run ../../notebooks/common/spark_helpers 13 | 14 | # COMMAND ---------- 15 | 16 | # MAGIC %run ../../notebooks/common/table_helpers 17 | 18 | # COMMAND ---------- 19 | 20 | # MAGIC %run ../../notebooks/dataset_definitions/hes/hes_schemas 21 | 22 | # COMMAND ---------- 23 | 24 | # MAGIC %run ../../notebooks/dataset_definitions/hes/hes_tables 25 | 26 | # COMMAND ---------- 27 | 28 | from pyspark.sql import functions as F, types as T 29 | 30 | # COMMAND ---------- 31 | 32 | # MAGIC %md 33 | # MAGIC ### Setup tables 34 | 35 | # COMMAND ---------- 36 | 37 | target_schema = get_hes_ae_schema() 38 | 39 | # Before the uplift there are some extra fields in the hes ae tables 40 | # We want to test that these are removed by the uplift 41 | current_schema = T.StructType([ 42 | *target_schema.fields, 43 | T.StructField("DIAG_3_CONCAT", T.StringType(), True), 44 | T.StructField("DIAG_4_CONCAT", T.StringType(), True), 45 | T.StructField("DIAG_COUNT", T.StringType(), True), 46 | T.StructField("OPERTN_3_CONCAT", T.StringType(), True), 47 | T.StructField("OPERTN_4_CONCAT", T.StringType(), True), 48 | T.StructField("OPERTN_COUNT", T.StringType(), True), 49 | ]) 50 | 51 | database_name = "alistair_jones5_101351" 52 | table_name = "artificial_hes_ae_1415" 53 | 54 | # Recreate the table 55 | if table_exists(spark, database_name, table_name): 56 | drop_table(spark, database_name, table_name) 57 | 58 | create_table_from_schema( 59 | spark, 60 | database_name, 61 | table_name, 62 | current_schema, 63 | format="delta", 64 | mode="overwrite", 65 | ) 66 | 67 | # COMMAND ---------- 68 | 69 | # MAGIC %md 70 | # MAGIC ## Run tests 71 | 72 | # COMMAND ---------- 73 | 74 | # Should be able to run twice without failure 75 | for i in range(2): 76 | dbutils.notebook.run("./artificial_hes_schema_coercion", 0, {"database_name": database_name}) 77 | 78 | actual_schema = spark.table(f"{database_name}.{table_name}").schema 79 | assert datatypes_equal(target_schema, actual_schema) 80 | 81 | # COMMAND ---------- 82 | 83 | # MAGIC %md 84 | # MAGIC ## Cleanup 85 | 86 | # COMMAND ---------- 87 | 88 | # MAGIC %md 89 | # MAGIC ### Teardown tables 90 | 91 | # COMMAND ---------- 92 | 93 | drop_table(spark, database_name, table_name) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/schemas/uplifts/open_data_metadata_uplift.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Open Data Metadata Uplift 4 | # MAGIC 5 | # MAGIC This notebook uplifts the schema of a table containing metadata to greatly improve adherence to [NHS Digital's Open Data standards](https://github.com/NHSDigital/open-data-standards).
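A note on the uplift query further down: it splits rows using Spark SQL's null-safe equality operator `<=>`. With `WHERE NOT METADATA.SUMMARY.VALUE_TYPE <=> 'CONTINUOUS'`, rows whose value type is null are retained in the non-continuous branch, whereas a plain `=` would drop them from both branches. A minimal, self-contained illustration (not part of the uplift itself):

```python
# Minimal illustration of '=' vs '<=>' with nulls in Spark SQL.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.createDataFrame([("A",), (None,)], "VALUE_TYPE: string").createOrReplaceTempView("demo")

# Plain equality: NULL = 'CONTINUOUS' evaluates to NULL, so the null row is filtered out.
spark.sql("SELECT * FROM demo WHERE NOT VALUE_TYPE = 'CONTINUOUS'").show()

# Null-safe equality: NULL <=> 'CONTINUOUS' is false, so NOT(...) keeps the null row.
spark.sql("SELECT * FROM demo WHERE NOT VALUE_TYPE <=> 'CONTINUOUS'").show()
```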
6 | 7 | # COMMAND ---------- 8 | 9 | # MAGIC %md ## Setup, Imports & Widgets 10 | 11 | # COMMAND ---------- 12 | 13 | # MAGIC %run ../../notebooks/common/common_exports 14 | 15 | # COMMAND ---------- 16 | 17 | # MAGIC %run ../../notebooks/scraper_stages/schemas/meta_schema 18 | 19 | # COMMAND ---------- 20 | 21 | dbutils.widgets.removeAll() 22 | 23 | # COMMAND ---------- 24 | 25 | dbutils.widgets.text("database_name", "iuod_artificial_data_generator", "0.1 Database Name") 26 | dbutils.widgets.text("table_name", "", "0.2 Metadata Table Name") 27 | 28 | # COMMAND ---------- 29 | 30 | import json 31 | from pyspark.sql import types as T 32 | 33 | # COMMAND ---------- 34 | 35 | # MAGIC %md 36 | # MAGIC ## Constants 37 | 38 | # COMMAND ---------- 39 | 40 | # MAGIC %md ### SQL Expressions 41 | 42 | # COMMAND ---------- 43 | 44 | # Template for transforming the metadata 45 | select_open_data_metadata_template_expr = """ 46 | SELECT 47 | TABLE_NAME 48 | ,FIELD_NAME 49 | ,coalesce( 50 | CASE 51 | WHEN METADATA.SUMMARY.CATEGORICAL.VALUE_TYPE == 'FLAT' THEN 'CATEGORICAL' 52 | WHEN METADATA.SUMMARY.CATEGORICAL.VALUE_TYPE == 'NESTED' THEN 'DEMOGRAPHIC_CATEGORICAL' 53 | WHEN FIELD_NAME == 'MYDOB' THEN 'DEMOGRAPHIC_DATE' 54 | END, 55 | METADATA.SUMMARY.VALUE_TYPE, 56 | METADATA.VALUE_TYPE 57 | ) AS VALUE_TYPE 58 | ,coalesce( 59 | METADATA.FIELD_TYPE, 60 | METADATA.SUMMARY.CATEGORICAL.FLAT.VALUE, 61 | METADATA.SUMMARY.DATE.FORMAT 62 | ) AS VALUE_STRING 63 | ,coalesce( 64 | METADATA.SUMMARY.DISCRETE.VALUE, 65 | METADATA.RELATIONSHIP.MAPPING 66 | ) AS VALUE_NUMERIC 67 | ,METADATA.SUMMARY.DATE.VALUE AS VALUE_DATE 68 | ,coalesce( 69 | METADATA.SUMMARY.CATEGORICAL.NESTED.VALUE.VALUE, 70 | METADATA.SUMMARY.DATE.NOISE.REPLACEMENT 71 | ) as VALUE_STRING_ARRAY 72 | ,cast(null as array<double>) as VALUE_NUMERIC_ARRAY 73 | ,coalesce( 74 | METADATA.SUMMARY.CATEGORICAL.FLAT.WEIGHT, 75 | METADATA.SUMMARY.CATEGORICAL.NESTED.WEIGHT, 76 | METADATA.SUMMARY.DATE.WEIGHT, 77 | METADATA.SUMMARY.DISCRETE.WEIGHT, 78 | METADATA.RELATIONSHIP.FREQUENCY 79 | ) AS WEIGHT 80 | ,coalesce( 81 | METADATA.SUMMARY.CATEGORICAL.NESTED.VALUE.WEIGHT, 82 | METADATA.SUMMARY.DATE.NOISE.WEIGHT 83 | ) AS WEIGHT_ARRAY 84 | FROM {db}.{table} -- Template parameters 85 | WHERE NOT METADATA.SUMMARY.VALUE_TYPE <=> 'CONTINUOUS' 86 | UNION 87 | SELECT 88 | TABLE_NAME 89 | ,FIELD_NAME 90 | ,METADATA.SUMMARY.VALUE_TYPE 91 | ,cast(null AS string) AS VALUE_STRING 92 | ,cast(null AS double) AS VALUE_NUMERIC 93 | ,cast(null AS date) AS VALUE_DATE 94 | ,cast(null AS array<string>) AS VALUE_STRING_ARRAY 95 | ,array_sort(collect_list(METADATA.SUMMARY.CONTINUOUS.VALUE)) AS VALUE_NUMERIC_ARRAY 96 | ,cast(null AS double) AS WEIGHT 97 | ,cast(null AS array<double>) AS WEIGHT_ARRAY 98 | FROM {db}.{table} -- Template parameters 99 | WHERE METADATA.SUMMARY.VALUE_TYPE <=> 'CONTINUOUS' 100 | GROUP BY TABLE_NAME, FIELD_NAME, METADATA.SUMMARY.VALUE_TYPE 101 | """ 102 | 103 | # COMMAND ---------- 104 | 105 | # MAGIC %md 106 | # MAGIC ## Main 107 | 108 | # COMMAND ---------- 109 | 110 | db = get_required_argument("database_name") 111 | target_table = get_required_argument("table_name") 112 | backup_table = f"{target_table}_v1_schema_backup" 113 | 114 | print(f"Uplifting schema of table {db}.{target_table}") 115 | 116 | # COMMAND ---------- 117 | 118 | if not database_exists(spark, db, verbose=True): 119 | dbutils.notebook.exit(json.dumps({"status": "skipped", "message": f"Database {db} not found"})) 120 | 121 | if not table_exists(spark, db, target_table): 122 | 
dbutils.notebook.exit(json.dumps({"status": "skipped", "message": f"Table {db}.{target_table} does not exist"})) 123 | 124 | source_df = spark.table(f"{db}.{target_table}") 125 | source_schema = T.StructType(sorted(source_df.schema.fields, key=lambda f: f.name)) 126 | target_schema = T.StructType(sorted(get_meta_schema().fields, key=lambda f: f.name)) 127 | 128 | if not datatypes_equal(target_schema, source_schema): 129 | print(f"Schema of table {db}.{target_table} not equal to target schema") 130 | 131 | # Can't rename tables, so need to backup first and then drop to reuse the name 132 | print(f"Backing up table {db}.{target_table} to table {db}.{backup_table}") 133 | create_table(spark, source_df, db, backup_table, partitionBy=["TABLE_NAME", "FIELD_NAME"]) 134 | drop_table(spark, db, target_table) 135 | 136 | print(f"Recreating table {db}.{target_table} from backup with target schema") 137 | select_open_data_metadata_expr = select_open_data_metadata_template_expr.format(db=db, table=backup_table) 138 | target_df = spark.sql(select_open_data_metadata_expr) 139 | create_table(spark, target_df, db, target_table, partitionBy=["TABLE_NAME", "FIELD_NAME"]) 140 | dbutils.notebook.exit(json.dumps({"status": "success", "message": ""})) 141 | 142 | else: 143 | dbutils.notebook.exit(json.dumps({"status": "skipped", "message": f"Schema of table {db}.{target_table} already equal to target schema"})) 144 | 145 | 146 | # COMMAND ---------- 147 | 148 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/schemas/uplifts/open_data_metadata_uplift_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC # Setup 4 | 5 | # COMMAND ---------- 6 | 7 | # MAGIC %md 8 | # MAGIC ## Imports 9 | 10 | # COMMAND ---------- 11 | 12 | # MAGIC %run ../../notebooks/common/table_helpers 13 | 14 | # COMMAND ---------- 15 | 16 | # MAGIC %run ../../notebooks/common/spark_helpers 17 | 18 | # COMMAND ---------- 19 | 20 | # MAGIC %run ../../notebooks/scraper_stages/schemas/meta_schema 21 | 22 | # COMMAND ---------- 23 | 24 | import json 25 | from pyspark.sql.types import StructType 26 | 27 | database_name = "artificial_hes_meta" 28 | table_name = "open_data_metadata_uplift_tests__artificial_hes_meta" 29 | backup_table_name = f"{table_name}_v1_schema_backup" 30 | 31 | # COMMAND ---------- 32 | 33 | # MAGIC %md 34 | # MAGIC ## Setup tables 35 | 36 | # COMMAND ---------- 37 | 38 | v1_meta_schema_json = '{"type": "struct", "fields": [{"name": "TABLE_NAME", "type": "string", "nullable": true, "metadata": {}}, {"name": "FIELD_NAME", "type": "string", "nullable": true, "metadata": {}}, {"name": "METADATA", "type": {"type": "struct", "fields": [{"name": "VALUE_TYPE", "type": "string", "nullable": true, "metadata": {}}, {"name": "FIELD_TYPE", "type": "string", "nullable": true, "metadata": {}}, {"name": "SUMMARY", "type": {"type": "struct", "fields": [{"name": "VALUE_TYPE", "type": "string", "nullable": true, "metadata": {}}, {"name": "CATEGORICAL", "type": {"type": "struct", "fields": [{"name": "VALUE_TYPE", "type": "string", "nullable": true, "metadata": {}}, {"name": "FLAT", "type": {"type": "struct", "fields": [{"name": "VALUE", "type": "string", "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "NESTED", "type": {"type": "struct", "fields": [{"name": "VALUE", "type": 
{"type": "array", "elementType": {"type": "struct", "fields": [{"name": "VALUE", "type": "string", "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}]}, "containsNull": true}, "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "DISCRETE", "type": {"type": "struct", "fields": [{"name": "VALUE", "type": "integer", "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "CONTINUOUS", "type": {"type": "struct", "fields": [{"name": "PERCENTILE", "type": "double", "nullable": true, "metadata": {}}, {"name": "VALUE", "type": "double", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "DATE", "type": {"type": "struct", "fields": [{"name": "VALUE", "type": "date", "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}, {"name": "FORMAT", "type": "string", "nullable": true, "metadata": {}}, {"name": "NOISE", "type": {"type": "array", "elementType": {"type": "struct", "fields": [{"name": "REPLACEMENT", "type": "string", "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}]}, "containsNull": true}, "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "ID", "type": "string", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "RELATIONSHIP", "type": {"type": "struct", "fields": [{"name": "LINKED_TABLE_NAME", "type": "string", "nullable": true, "metadata": {}}, {"name": "LINKED_FIELD_NAME", "type": "string", "nullable": true, "metadata": {}}, {"name": "MAPPING", "type": "long", "nullable": true, "metadata": {}}, {"name": "FREQUENCY", "type": "double", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}]}' 39 | 40 | v1_meta_schema = StructType.fromJson(json.loads(v1_meta_schema_json)) 41 | 42 | print(f"Creating table to test uplift: {database_name}.{table_name}") 43 | create_table_from_schema(spark, database_name, table_name, v1_meta_schema) 44 | 45 | # COMMAND ---------- 46 | 47 | # MAGIC %md # Run tests 48 | 49 | # COMMAND ---------- 50 | 51 | # Try uplift twice: once represents fresh database, twice is to check that running it again doesn't break anything! 
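# (On the second pass the uplift notebook is expected to exit via its "skipped" branch, because the table schema should already match the target - i.e. the uplift is idempotent.)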
52 | for i in range(2): 53 | # Uplift 54 | result = dbutils.notebook.run("./open_data_metadata_uplift", 0, {"database_name": database_name, "table_name": table_name}) 55 | print(f"Uplift finished with response: {result}") 56 | 57 | # Test for backup 58 | assert table_exists(spark, database_name, backup_table_name), "Backup table doesn't exist" 59 | assert datatypes_equal(spark.table(f"{database_name}.{backup_table_name}").schema, v1_meta_schema), "Backup table has incorrect schema" 60 | 61 | # Test for uplifted 62 | uplift_meta_schema = get_meta_schema() 63 | assert table_exists(spark, database_name, table_name), "Uplift table doesn't exist" 64 | assert datatypes_equal(spark.table(f"{database_name}.{table_name}").schema, uplift_meta_schema), "Uplifted table has incorrect schema" 65 | 66 | 67 | # COMMAND ---------- 68 | 69 | # MAGIC %md 70 | # MAGIC # Cleanup 71 | 72 | # COMMAND ---------- 73 | 74 | # MAGIC %md 75 | # MAGIC ## Teardown tables 76 | 77 | # COMMAND ---------- 78 | 79 | print("Cleaning up tables") 80 | drop_table(spark, database_name, backup_table_name) 81 | drop_table(spark, database_name, table_name) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/admin_tests/placeholder_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/common_tests/run_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC This notebook will run all the tests! 4 | 5 | # COMMAND ---------- 6 | 7 | test_notebook_paths = [ 8 | "./test_spark_helpers", 9 | ] 10 | 11 | for notebook_path in test_notebook_paths: 12 | dbutils.notebook.run(notebook_path, 0) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/dataset_definitions_tests/placeholder_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/dataset_definitions_tests/run_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC This notebook will run all the tests! 
4 | 5 | # COMMAND ---------- 6 | 7 | test_notebook_paths = [] 8 | 9 | for notebook_path in test_notebook_paths: 10 | dbutils.notebook.run(notebook_path, 0) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/demographic_field_generators_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../../../notebooks/common/spark_helpers 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../../test_helpers 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ../../../notebooks/generator_pipelines/artificial_hes/demographic_field_generators 11 | 12 | # COMMAND ---------- 13 | 14 | # import demographic_field_generators 15 | 16 | 17 | from datetime import datetime 18 | from pyspark.sql import DataFrame, functions as F 19 | 20 | 21 | def tidy_demographic_meta_df_fixture() -> DataFrame: 22 | data = [ 23 | ["field1", "DEMOGRAPHIC_CATEGORICAL", ["a", "b"], None, 1., [2., 3.], None], 24 | ["field1", "DEMOGRAPHIC_CATEGORICAL", ["c", "d"], None, 2., [1., 1.], None], 25 | ["field1", "DEMOGRAPHIC_CATEGORICAL", ["e"], None, 3., [1.], None], 26 | ["field2", "DEMOGRAPHIC_CATEGORICAL", ["a"], None, 3., [1.], None], 27 | ["field2", "DEMOGRAPHIC_CATEGORICAL", ["b", "c"], None, 2., [2., 3.], None], 28 | ["field2", "DEMOGRAPHIC_CATEGORICAL", ["d", "e"], None, 1., [1., 1.], None], 29 | ["field3", "DEMOGRAPHIC_DATE", None, "2022-01-01", 1., None, 0.5], 30 | ["field3", "DEMOGRAPHIC_DATE", None, "2022-01-02", 2., None, 1.0], 31 | ["field3", "DEMOGRAPHIC_DATE", None, "2022-01-03", 3., None, 1.0], 32 | ["field4", "DEMOGRAPHIC_DATE", None, "2022-01-01", 3., None, 1.0], 33 | ["field4", "DEMOGRAPHIC_DATE", None, "2022-01-02", 2., None, 0.8], 34 | ["field4", "DEMOGRAPHIC_DATE", None, "2022-01-03", 1., None, 1.0], 35 | ] 36 | schema = "FIELD_NAME: string, VALUE_TYPE: string, VALUE_STRING_ARRAY: array<string>, VALUE_STRING: string, WEIGHT: double, WEIGHT_ARRAY: array<double>, VALUE_NUMERIC: double" 37 | df = spark.createDataFrame(data, schema) 38 | return df 39 | 40 | # COMMAND ---------- 41 | 42 | demographic_categorical_field_generator_test_suite = FunctionTestSuite() 43 | 44 | 45 | @demographic_categorical_field_generator_test_suite.add_test 46 | def test_generated_values_have_expected_frequency_distribution(): 47 | n_rows = 10000 48 | n_array_values = 10 49 | base_df = ( 50 | spark.range(n_rows) 51 | .withColumn("id2", F.explode(F.sequence(F.lit(1), F.lit(n_array_values), F.lit(1)))) # 10 values per unique id 52 | ) 53 | 54 | meta_df = tidy_demographic_meta_df_fixture() 55 | 56 | result_df = demographic_categorical_field_generator(meta_df, base_df, ["FIELD_NAME"], ["id", "id2"]) 57 | 58 | aggregated_result_df = ( 59 | result_df 60 | .groupBy("FIELD_NAME", "VALUE_STRING") 61 | .agg((F.count(F.lit(1)) / (n_rows * n_array_values)).alias("FREQUENCY")) 62 | ) 63 | 64 | # Need to account for the vector weights too 65 | expected_frequency_col = ( 66 | F 67 | .when((F.col("FIELD_NAME") == "field1") & (F.col("VALUE_STRING") == "a"), (2 / 5) * (1 / 6)) 68 | .when((F.col("FIELD_NAME") == "field1") & (F.col("VALUE_STRING") == "b"), (3 / 5) * (1 / 6)) 69 | .when((F.col("FIELD_NAME") == "field1") & (F.col("VALUE_STRING") == "c"), (1 / 2) * (2 / 6)) 70 | .when((F.col("FIELD_NAME") == "field1") & (F.col("VALUE_STRING") == "d"), (1 / 2) * (2 / 6)) 71 | .when((F.col("FIELD_NAME") == "field1") & (F.col("VALUE_STRING") == "e"), 3 / 6) 72 | 
.when((F.col("FIELD_NAME") == "field2") & (F.col("VALUE_STRING") == "a"), 3 / 6) 73 | .when((F.col("FIELD_NAME") == "field2") & (F.col("VALUE_STRING") == "b"), (2 / 5) * (2 / 6)) 74 | .when((F.col("FIELD_NAME") == "field2") & (F.col("VALUE_STRING") == "c"), (3 / 5) * (2 / 6)) 75 | .when((F.col("FIELD_NAME") == "field2") & (F.col("VALUE_STRING") == "d"), (1 / 2) * (1 / 6)) 76 | .when((F.col("FIELD_NAME") == "field2") & (F.col("VALUE_STRING") == "e"), (1 / 2) * (1 / 6)) 77 | ) 78 | 79 | aggregated_result_df = aggregated_result_df.withColumn("EXPECTED_FREQUENCY", expected_frequency_col) 80 | assert columns_approx_equal(aggregated_result_df, "FREQUENCY", "EXPECTED_FREQUENCY", 0.1), "Some or all frequencies did not equal the expected values" 81 | 82 | 83 | demographic_categorical_field_generator_test_suite.run() 84 | 85 | 86 | # COMMAND ---------- 87 | 88 | demographic_date_field_generator_test_suite = FunctionTestSuite() 89 | 90 | 91 | @demographic_date_field_generator_test_suite.add_test 92 | def test_generated_values_have_expected_frequency_distribution(): 93 | n_rows = 10000 94 | n_array_values = 100 95 | base_df = ( 96 | spark.range(n_rows) 97 | .withColumn("id2", F.explode(F.sequence(F.lit(1), F.lit(n_array_values), F.lit(1)))) # n values per unique id 98 | ) 99 | 100 | meta_df = tidy_demographic_meta_df_fixture() 101 | 102 | result_df = demographic_date_field_generator(meta_df, base_df, ["FIELD_NAME"], ["id", "id2"]) 103 | 104 | aggregated_result_df = ( 105 | result_df 106 | .groupBy("FIELD_NAME", "VALUE_STRING") 107 | .agg((F.count(F.lit(1)) / (n_rows * n_array_values)).alias("FREQUENCY")) 108 | ) 109 | 110 | # Expected frequencies (approx. accounting for the noise incidence) 111 | expected_frequency_col = ( 112 | F 113 | .when((F.col("FIELD_NAME") == "field3") & (F.col("VALUE_STRING") == F.lit("2022-01-01")), 0.5 * (1 / 6)) 114 | .when((F.col("FIELD_NAME") == "field3") & (F.col("VALUE_STRING") == F.lit("2022-01-02")), (2 / 6) + ((2 / 5) * (1 / 24))) 115 | .when((F.col("FIELD_NAME") == "field3") & (F.col("VALUE_STRING") == F.lit("2022-01-03")), (3 / 6) + ((3 / 5) * (1 / 24))) 116 | .when((F.col("FIELD_NAME") == "field4") & (F.col("VALUE_STRING") == F.lit("2022-01-01")), (3 / 6) + ((3 / 5) * (1 / 30))) 117 | .when((F.col("FIELD_NAME") == "field4") & (F.col("VALUE_STRING") == F.lit("2022-01-02")), 0.8 * (2 / 6)) 118 | .when((F.col("FIELD_NAME") == "field4") & (F.col("VALUE_STRING") == F.lit("2022-01-03")), (1 / 6) + ((2 / 5) * (1 / 30))) 119 | ) 120 | aggregated_result_df = aggregated_result_df.withColumn("EXPECTED_FREQUENCY", expected_frequency_col) 121 | 122 | assert columns_approx_equal(aggregated_result_df, "FREQUENCY", "EXPECTED_FREQUENCY", 0.1), "Some or all frequencies did not equal the expected values" 123 | 124 | 125 | demographic_date_field_generator_test_suite.run() 126 | 127 | # COMMAND ---------- 128 | 129 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/field_definitions_tests/derivations_tests/age_fields_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../imports 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../../../../../notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/age_fields 7 | 8 | # COMMAND ---------- 9 | 10 | from pyspark.sql import functions as F 11 | 12 | # COMMAND ---------- 13 | 14 | def create_test_date_df(): 15 | 
df = spark.createDataFrame([ 16 | (0, None, "2021-12-31"), 17 | (1, "2021-12-31", "2021-12-31"), 18 | (2, "2021-12-31", "2021-12-26"), 19 | (3, "2021-12-31", "2021-12-04"), 20 | (4, "2021-12-31", "2021-10-08"), 21 | (5, "2021-12-31", "2021-07-05"), 22 | (6, "2021-12-31", "2021-05-04"), 23 | (7, "2021-12-31", "2021-01-01"), 24 | (8, "2021-12-31", "1900-12-21"), 25 | (9, "2021-12-31", "2001-12-31"), 26 | ], ["id", "ADMIDATE", "DOB"] 27 | ) 28 | 29 | return df 30 | 31 | 32 | # COMMAND ---------- 33 | 34 | test_get_fractional_age_field = FunctionTestSuite() 35 | 36 | 37 | @test_get_fractional_age_field.add_test 38 | def test_creates_column_with_expected_values(): 39 | # Inputs 40 | df = create_test_date_df() 41 | 42 | # Define expectations 43 | expected_fractional_age = ( 44 | F.when(F.col("id") == 0, None) 45 | .when(F.col("id") == 1, 0.002) 46 | .when(F.col("id") == 2, 0.010) 47 | .when(F.col("id") == 3, 0.048) 48 | .when(F.col("id") == 4, 0.167) 49 | .when(F.col("id") == 5, 0.375) 50 | .when(F.col("id") == 6, 0.625) 51 | .when(F.col("id") == 7, 0.875) 52 | .when(F.col("id") == 8, 120) 53 | .when(F.col("id") == 9, 7305/365) 54 | ) 55 | df = df.withColumn("EXPECTED_FRAC_AGE", expected_fractional_age) 56 | 57 | # Execute function under test 58 | fractional_age_calc = get_fractional_age_field("ADMIDATE", "DOB") 59 | df = df.withColumn("FRAC_AGE_CALC", fractional_age_calc) 60 | 61 | # Check against expectation 62 | assert columns_equal(df, "FRAC_AGE_CALC", "EXPECTED_FRAC_AGE") 63 | 64 | 65 | test_get_fractional_age_field.run() 66 | 67 | # COMMAND ---------- 68 | 69 | test_get_categorized_age_field = FunctionTestSuite() 70 | 71 | 72 | @test_get_categorized_age_field.add_test 73 | def test_creates_column_with_expected_values(): 74 | # Inputs 75 | df = create_test_date_df() 76 | 77 | # Define expectations 78 | expected_categorized_age = ( 79 | F.when(F.col("id") == 0, None) 80 | .when(F.col("id") == 1, 7001) 81 | .when(F.col("id") == 2, 7002) 82 | .when(F.col("id") == 3, 7003) 83 | .when(F.col("id") == 4, 7004) 84 | .when(F.col("id") == 5, 7005) 85 | .when(F.col("id") == 6, 7006) 86 | .when(F.col("id") == 7, 7007) 87 | .when(F.col("id") == 8, 120) 88 | .when(F.col("id") == 9, 20) 89 | ) 90 | df = df.withColumn("EXPECTED_CAT_AGE", expected_categorized_age) 91 | 92 | # Execute function under test 93 | categorized_age_calc = get_categorized_age_field("ADMIDATE", "DOB") 94 | df = df.withColumn("CAT_AGE_CALC", categorized_age_calc) 95 | 96 | # Check against expectation 97 | assert columns_equal(df, "CAT_AGE_CALC", "EXPECTED_CAT_AGE") 98 | 99 | 100 | test_get_categorized_age_field.run() 101 | 102 | # COMMAND ---------- 103 | 104 | test_get_fractional_from_categorized_age_field = FunctionTestSuite() 105 | 106 | 107 | @test_get_fractional_from_categorized_age_field.add_test 108 | def test_creates_column_with_expected_values(): 109 | # Inputs 110 | schema = "id: int, APPTAGE: int" 111 | df = spark.createDataFrame( 112 | [ 113 | (0, None), 114 | (1, 7001), 115 | (2, 7002), 116 | (3, 7003), 117 | (4, 7004), 118 | (5, 7005), 119 | (6, 7006), 120 | (7, 7007), 121 | (8, 120), 122 | (9, 38) 123 | ], 124 | schema 125 | ) 126 | 127 | # Define expectations 128 | expected_fractional_age = ( 129 | F.when(F.col("id") == 0, None) 130 | .when(F.col("id") == 1, 0.002) 131 | .when(F.col("id") == 2, 0.010) 132 | .when(F.col("id") == 3, 0.048) 133 | .when(F.col("id") == 4, 0.167) 134 | .when(F.col("id") == 5, 0.375) 135 | .when(F.col("id") == 6, 0.625) 136 | .when(F.col("id") == 7, 0.875) 137 | .when(F.col("id") == 8, 120) 
138 | .when(F.col("id") == 9, 38) 139 | ) 140 | df = df.withColumn("EXPECTED_APPTAGE_CALC", expected_fractional_age) 141 | 142 | # Execute function under test 143 | fractional_age = get_fractional_from_categorized_age_field("APPTAGE") 144 | df = df.withColumn("APPTAGE_CALC", fractional_age) 145 | 146 | # Check against expectation 147 | assert columns_equal(df, "APPTAGE_CALC", "EXPECTED_APPTAGE_CALC"), df.show() 148 | 149 | 150 | test_get_fractional_from_categorized_age_field.run() 151 | 152 | # COMMAND ---------- 153 | 154 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/field_definitions_tests/derivations_tests/sequential_fields_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../imports 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../../../../../notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/sequential_fields 7 | 8 | # COMMAND ---------- 9 | 10 | from pyspark.sql import functions as F, types as T 11 | from datetime import datetime 12 | from itertools import starmap 13 | 14 | # COMMAND ---------- 15 | 16 | test_sort_fields_l2r = FunctionTestSuite() 17 | 18 | 19 | @test_sort_fields_l2r.add_test 20 | def test_sorts_two_columns_correctly(): 21 | # Inputs 22 | schema = "DATE_OF_BIRTH: date, DATE_OF_DEATH: date" 23 | df = spark.createDataFrame([ 24 | [datetime(2021, 1, 1), datetime(2020, 1, 1)], 25 | [datetime(2019, 1, 1), datetime(2020, 1, 1)], 26 | [datetime(2019, 1, 1), datetime(2016, 1, 12)], 27 | [datetime(2019, 2, 1), datetime(2019, 1, 1)], 28 | ], schema) 29 | 30 | # Execute function under test 31 | sorted_columns = sort_fields_l2r(*df.columns) 32 | sorted_columns = starmap(lambda field_name, col_spec: col_spec.alias(field_name), sorted_columns) # Assign aliases to columns 33 | result_df = df.select(*sorted_columns) 34 | 35 | # Define expectations 36 | expected_df = spark.createDataFrame([ 37 | [datetime(2020, 1, 1), datetime(2021, 1, 1)], 38 | [datetime(2019, 1, 1), datetime(2020, 1, 1)], 39 | [datetime(2016, 1, 12), datetime(2019, 1, 1)], 40 | [datetime(2019, 1, 1), datetime(2019, 2, 1)], 41 | ], schema) 42 | 43 | # Check against expectation 44 | assert dataframes_equal(result_df, expected_df) 45 | 46 | 47 | @test_sort_fields_l2r.add_test 48 | def test_sorts_many_columns_correctly(): 49 | # Inputs 50 | schema = "DATE_OF_BIRTH: date, DATE_OF_ADMISSION: date, DATE_OF_TREATMENT: date, DATE_OF_DISCHARGE: date" 51 | df = spark.createDataFrame([ 52 | [datetime(2021, 1, 1), datetime(2020, 1, 1), datetime(2020, 3, 1), datetime(2020, 2, 1)], 53 | [datetime(2016, 1, 1), datetime(2018, 1, 1), datetime(2018, 3, 5), datetime(2020, 3, 4)], 54 | [datetime(2021, 1, 1), datetime(2020, 2, 1), datetime(2020, 3, 1), datetime(2020, 4, 1)], 55 | [datetime(2021, 1, 1), datetime(2020, 1, 1), datetime(2019, 3, 1), datetime(2018, 2, 1)], 56 | ], schema) 57 | 58 | # Execute function under test 59 | sorted_columns = sort_fields_l2r(*df.columns) 60 | sorted_columns = starmap(lambda field_name, col_spec: col_spec.alias(field_name), sorted_columns) # Assign aliases to columns 61 | result_df = df.select(*sorted_columns) 62 | 63 | # Define expectations 64 | expected_df = spark.createDataFrame([ 65 | [datetime(2020, 1, 1), datetime(2020, 2, 1), datetime(2020, 3, 1), datetime(2021, 1, 1)], 66 | [datetime(2016, 1, 1), datetime(2018, 1, 1), datetime(2018, 3, 5), 
datetime(2020, 3, 4)], 67 | [datetime(2020, 2, 1), datetime(2020, 3, 1), datetime(2020, 4, 1), datetime(2021, 1, 1)], 68 | [datetime(2018, 2, 1), datetime(2019, 3, 1), datetime(2020, 1, 1), datetime(2021, 1, 1)], 69 | ], schema) 70 | 71 | # Check against expectation 72 | assert dataframes_equal(result_df, expected_df) 73 | 74 | 75 | @test_sort_fields_l2r.add_test 76 | def test_pushes_nulls_to_right_when_sorting(): 77 | # Inputs 78 | schema = "DATE_OF_BIRTH: date, DATE_OF_DEATH: date" 79 | df = spark.createDataFrame([ 80 | [datetime(2021, 1, 1), None], 81 | [datetime(2019, 1, 1), datetime(2020, 1, 1)], 82 | [None, None], 83 | [None, datetime(2019, 1, 1)], 84 | ], schema) 85 | 86 | # Execute function under test 87 | sorted_columns = sort_fields_l2r(*df.columns) 88 | sorted_columns = starmap(lambda field_name, col_spec: col_spec.alias(field_name), sorted_columns) # Assign aliases to columns 89 | result_df = df.select(*sorted_columns) 90 | 91 | # Define expectations 92 | expected_df = spark.createDataFrame([ 93 | [datetime(2021, 1, 1), None], 94 | [datetime(2019, 1, 1), datetime(2020, 1, 1)], 95 | [None, None], 96 | [datetime(2019, 1, 1), None], 97 | ], schema) 98 | 99 | # Check against expectation 100 | assert dataframes_equal(result_df, expected_df) 101 | 102 | 103 | @test_sort_fields_l2r.add_test 104 | def test_does_not_affect_unspecified_columns(): 105 | # Inputs 106 | schema = ["DATE_OF_BIRTH", "DATE_OF_DEATH", "ID", "NAME"] 107 | df = spark.createDataFrame([ 108 | [datetime(2021, 1, 1), datetime(2020, 1, 1), 0, "name1"], 109 | [datetime(2019, 1, 1), datetime(2020, 1, 1), 1, "name1"], 110 | [datetime(2019, 1, 1), datetime(2016, 1, 12), 2, "name2"], 111 | [datetime(2019, 2, 1), datetime(2019, 1, 1), 3, "name3"], 112 | ], schema) 113 | 114 | # Execute function under test 115 | sorted_columns = sort_fields_l2r("DATE_OF_BIRTH", "DATE_OF_DEATH") 116 | sorted_columns = starmap(lambda field_name, col_spec: col_spec.alias(field_name), sorted_columns) # Assign aliases to columns 117 | result_df = df.select(*sorted_columns, "ID", "NAME") 118 | 119 | # Define expectations 120 | expected_df = spark.createDataFrame([ 121 | [datetime(2020, 1, 1), datetime(2021, 1, 1), 0, "name1"], 122 | [datetime(2019, 1, 1), datetime(2020, 1, 1), 1, "name1"], 123 | [datetime(2016, 1, 12), datetime(2019, 1, 1), 2, "name2"], 124 | [datetime(2019, 1, 1), datetime(2019, 2, 1), 3, "name3"], 125 | ], schema) 126 | 127 | # Check against expectation 128 | assert dataframes_equal(result_df, expected_df) 129 | 130 | 131 | test_sort_fields_l2r.run() 132 | 133 | # COMMAND ---------- 134 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/field_definitions_tests/id_fields_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ./imports 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../../../../notebooks/dependencies/spark_rstr 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ../../../../notebooks/generator_pipelines/artificial_hes/field_definitions/id_fields 11 | 12 | # COMMAND ---------- 13 | 14 | from pyspark.sql import functions as F 15 | 16 | # COMMAND ---------- 17 | 18 | rstr = SparkRstr() 19 | 20 | # COMMAND ---------- 21 | 22 | # ID fields: 23 | 24 | APC_ID_FIELDS = [ 25 | "SUSRECID", 26 | "PSEUDO_HESID", 27 | "AEKEY", 28 | "EPIKEY", 29 | ] 30 | 31 | OP_ID_FIELDS = [ 32 | "PSEUDO_HESID", 33 | "ATTENDID", 34 | 
"ATTENDKEY", 35 | "PREFERER", 36 | ] 37 | 38 | AE_ID_FIELDS = [ 39 | "AEKEY", 40 | "PSEUDO_HESID", 41 | "EPIKEY" 42 | ] 43 | 44 | # COMMAND ---------- 45 | 46 | # Patterns: 47 | 48 | ID_FIELD_PATTERNS = dict( 49 | PSEUDO_HESID = r"TEST[0-9a-zA-Z]{28}", # 32an - first 4 chars = TEST to ensure no overlap with real IDs 50 | SUSRECID = r"\d{14}", 51 | AEKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22 52 | EPIKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22 53 | ATTENDNO = r"[0-9a-zA-Z]{12}", 54 | ATTENDKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22 55 | ATTENDID = r"[0-9a-zA-Z]{12}", 56 | PREFERER = r"[0-9a-zA-Z]{16}", # What about nulls (&) / invalids (99) 57 | ) 58 | 59 | # COMMAND ---------- 60 | 61 | test_get_id_fields = FunctionTestSuite() 62 | 63 | 64 | @test_get_id_fields.add_test 65 | def test_creates_hes_ae_columns(): 66 | # Inputs 67 | df = spark.range(15) 68 | 69 | # Execute function under test 70 | id_field_cols = get_id_fields("hes_ae") 71 | for field_name in id_field_cols: 72 | df = df.withColumn(field_name, id_field_cols[field_name]) 73 | 74 | for field in AE_ID_FIELDS: 75 | # Define expectations 76 | extracted_field_name = f'{field}_TEST' 77 | extracted_field = F.regexp_extract(field, ID_FIELD_PATTERNS[field], 0) 78 | df = df.withColumn(extracted_field_name, extracted_field) 79 | 80 | # Check against expectation 81 | assert columns_equal(df, field, extracted_field_name) 82 | 83 | 84 | @test_get_id_fields.add_test 85 | def test_creates_hes_apc_columns(): 86 | # Inputs 87 | df = spark.range(15) 88 | 89 | # Execute function under test 90 | id_field_cols = get_id_fields("hes_apc") 91 | for field_name in id_field_cols: 92 | df = df.withColumn(field_name, id_field_cols[field_name]) 93 | 94 | for field in APC_ID_FIELDS: 95 | # Define expectations 96 | extracted_field_name = f'{field}_TEST' 97 | extracted_field = F.regexp_extract(field, ID_FIELD_PATTERNS[field], 0) 98 | df = df.withColumn(extracted_field_name, extracted_field) 99 | 100 | # Check against expectation 101 | assert columns_equal(df, field, extracted_field_name) 102 | 103 | 104 | @test_get_id_fields.add_test 105 | def test_creates_hes_op_columns(): 106 | # Inputs 107 | df = spark.range(15) 108 | 109 | # Execute function under test 110 | id_field_cols = get_id_fields("hes_op") 111 | for field_name in id_field_cols: 112 | df = df.withColumn(field_name, id_field_cols[field_name]) 113 | 114 | for field in OP_ID_FIELDS: 115 | # Define expectations 116 | extracted_field_name = f'{field}_TEST' 117 | extracted_field = F.regexp_extract(field, ID_FIELD_PATTERNS[field], 0) 118 | df = df.withColumn(extracted_field_name, extracted_field) 119 | 120 | # Check against expectation 121 | assert columns_equal(df, field, extracted_field_name) 122 | 123 | 124 | test_get_id_fields.run() -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/field_definitions_tests/imports.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../../../test_helpers 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../../../../notebooks/common/spark_helpers -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/run_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | 
# MAGIC %md 3 | # MAGIC This notebook will run all the tests! 4 | 5 | # COMMAND ---------- 6 | 7 | test_notebook_paths = [ 8 | "./field_definitions_tests/derivations_tests/age_fields_tests", 9 | "./field_definitions_tests/derivations_tests/coded_fields_tests", 10 | "./field_definitions_tests/derivations_tests/derivations_helpers_tests", 11 | "./field_definitions_tests/derivations_tests/sequential_fields_tests", 12 | "./field_definitions_tests/id_fields_tests", 13 | "./demographic_field_generators_tests", 14 | ] 15 | 16 | for notebook_path in test_notebook_paths: 17 | dbutils.notebook.run(notebook_path, 0) 18 | 19 | # COMMAND ---------- 20 | 21 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/run_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC This notebook will run all the tests! 4 | 5 | # COMMAND ---------- 6 | 7 | test_notebook_paths = [ 8 | "./artificial_hes_tests/run_tests" 9 | ] 10 | 11 | for notebook_path in test_notebook_paths: 12 | dbutils.notebook.run(notebook_path, 0) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/generator_stages_tests/run_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC This notebook will run all the tests! 4 | 5 | # COMMAND ---------- 6 | 7 | test_notebook_paths = [ 8 | "./field_generators_tests", 9 | "./relationship_generator_tests", 10 | "./sampling_tests/field_definitions_tests", 11 | ] 12 | 13 | for notebook_path in test_notebook_paths: 14 | dbutils.notebook.run(notebook_path, 0) 15 | 16 | # COMMAND ---------- 17 | 18 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/generator_stages_tests/sampling_tests/field_definitions_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %run ../../../notebooks/common/spark_helpers 3 | 4 | # COMMAND ---------- 5 | 6 | # MAGIC %run ../../test_helpers 7 | 8 | # COMMAND ---------- 9 | 10 | # MAGIC %run ../../../notebooks/generator_stages/sampling/field_definitions 11 | 12 | # COMMAND ---------- 13 | 14 | import random 15 | from pyspark.sql import functions as F 16 | 17 | # COMMAND ---------- 18 | 19 | # region total_over_window tests 20 | 21 | total_over_window_test_suite = FunctionTestSuite() 22 | 23 | 24 | @total_over_window_test_suite.add_test 25 | def test_correctly_computes_total_within_partition(): 26 | input_data = [ 27 | ["x", 1], 28 | ["y", 2], 29 | ["z", 3], 30 | ["x", 4], 31 | ["y", 5], 32 | ["z", 6], 33 | ["x", 7], 34 | ["y", 8], 35 | ["z", 9], 36 | ] 37 | input_schema = "FIELD_NAME: string, VALUE: integer" 38 | input_df = spark.createDataFrame(input_data, input_schema) 39 | 40 | sum_col = "VALUE" 41 | partition_cols = ["FIELD_NAME"] 42 | result_df = input_df.withColumn("TOTAL", total_over_window(sum_col, partition_cols)) 43 | 44 | expected_total_col = ( 45 | F 46 | .when(F.col("FIELD_NAME") == "x", 1 + 4 + 7) 47 | .when(F.col("FIELD_NAME") == "y", 2 + 5 + 8) 48 | .when(F.col("FIELD_NAME") == "z", 3 + 6 + 9) 49 | .cast("long") 50 | ) 51 | 52 | result_df = result_df.withColumn("EXPECTED_TOTAL", expected_total_col) 53 | 54 | assert columns_equal(result_df, "TOTAL", 
"EXPECTED_TOTAL"), "Total sum did not equal expected value" 55 | 56 | 57 | total_over_window_test_suite.run() 58 | 59 | # endregion 60 | 61 | # COMMAND ---------- 62 | 63 | # region cumsum_over_window tests 64 | 65 | cumsum_over_window_test_suite = FunctionTestSuite() 66 | 67 | 68 | @cumsum_over_window_test_suite.add_test 69 | def test_correctly_computes_cumsum_over_partition(): 70 | # Column 1 specifies partition 71 | # Column 2 specifies ordering within the partition 72 | # Column 3 is the value to cumsum 73 | input_data = [ 74 | ["x", 1, 1], 75 | ["y", 1, 2], 76 | ["z", 1, 3], 77 | ["x", 2, 4], 78 | ["y", 2, 5], 79 | ["z", 2, 6], 80 | ["x", 3, 7], 81 | ["y", 3, 8], 82 | ["z", 3, 9], 83 | ] 84 | random.shuffle(input_data) 85 | input_schema = "PARTITION_FIELD: string, ORDER_FIELD: integer, VALUE: long" 86 | input_df = spark.createDataFrame(input_data, input_schema) 87 | 88 | sum_col = "VALUE" 89 | partition_cols = "PARTITION_FIELD", 90 | order_cols = "ORDER_FIELD", 91 | result_df = input_df.withColumn("CUMSUM", cumsum_over_window(sum_col, partition_cols, order_cols)) 92 | 93 | # Order expected by partition_key, expected cumulative sum value 94 | expected_cumsum_col = ( 95 | F 96 | .when((F.col("PARTITION_FIELD") == "x") & (F.col("ORDER_FIELD") == 1), 1) 97 | .when((F.col("PARTITION_FIELD") == "x") & (F.col("ORDER_FIELD") == 2), 1 + 4) 98 | .when((F.col("PARTITION_FIELD") == "x") & (F.col("ORDER_FIELD") == 3), 1 + 4 + 7) 99 | .when((F.col("PARTITION_FIELD") == "y") & (F.col("ORDER_FIELD") == 1), 2) 100 | .when((F.col("PARTITION_FIELD") == "y") & (F.col("ORDER_FIELD") == 2), 2 + 5) 101 | .when((F.col("PARTITION_FIELD") == "y") & (F.col("ORDER_FIELD") == 3), 2 + 5 + 8) 102 | .when((F.col("PARTITION_FIELD") == "z") & (F.col("ORDER_FIELD") == 1), 3) 103 | .when((F.col("PARTITION_FIELD") == "z") & (F.col("ORDER_FIELD") == 2), 3 + 6) 104 | .when((F.col("PARTITION_FIELD") == "z") & (F.col("ORDER_FIELD") == 3), 3 + 6 + 9) 105 | .cast("long") 106 | ) 107 | 108 | result_df = result_df.withColumn("EXPECTED_CUMSUM", expected_cumsum_col) 109 | 110 | assert columns_equal(result_df, "CUMSUM", "EXPECTED_CUMSUM"), "Cumulative sum did not equal expected value" 111 | 112 | 113 | cumsum_over_window_test_suite.run() 114 | 115 | # endregion 116 | 117 | # COMMAND ---------- 118 | 119 | -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/scraper_pipelines_tests/run_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC This notebook will run all the tests! 4 | 5 | # COMMAND ---------- 6 | 7 | test_notebook_paths = [ 8 | "./hes/demographic_field_summarisers_tests" 9 | ] 10 | 11 | for notebook_path in test_notebook_paths: 12 | dbutils.notebook.run(notebook_path, 0) -------------------------------------------------------------------------------- /projects/iuod_artificial_data_generator/tests/scraper_stages_tests/run_tests.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md 3 | # MAGIC This notebook will run all the tests! 
4 | 5 | # COMMAND ---------- 6 | 7 | test_notebook_paths = [ 8 | "./aggregation_tests/field_summarisers_tests", 9 | "./aggregation_tests/relationship_summariser_tests", 10 | "./preprocessing_tests/test_meta_type_classifier", 11 | ] 12 | 13 | for notebook_path in test_notebook_paths: 14 | dbutils.notebook.run(notebook_path, 0) 15 | 16 | # COMMAND ---------- 17 | 18 | -------------------------------------------------------------------------------- /utils/export.ps1: -------------------------------------------------------------------------------- 1 | # Export directories from Databricks 2 | 3 | $exportEnv = $(Read-Host -Prompt "Export 'dev' or 'staging' folders? (default=dev)") 4 | 5 | if ($exportEnv -eq "staging") 6 | { 7 | $workspaceRoot = "/staging" 8 | } 9 | elseif ($exportEnv -eq "dev") 10 | { 11 | $branchName = $(git rev-parse --abbrev-ref HEAD) 12 | $workspaceRoot = "/data_manager_projects/iuod_dae_test_data_generator/${branchName}" 13 | } 14 | else 15 | { 16 | "Invalid value: '$exportEnv'" 17 | exit 1; 18 | } 19 | 20 | $projects = @( 21 | "iuod_artificial_data_generator", 22 | "iuod_artificial_data_admin", 23 | "artificial_hes", 24 | "artificial_hes_meta" 25 | ) 26 | 27 | git stash # Stash local changes before overwrite 28 | 29 | foreach ( $projectName in $projects ) 30 | { 31 | $databricksPath = "$workspaceRoot/$projectName" 32 | $localPath = "projects/$projectName" 33 | Remove-Item $localPath -Recurse 34 | databricks workspace export_dir -o $databricksPath $localPath 35 | } -------------------------------------------------------------------------------- /utils/import.ps1: -------------------------------------------------------------------------------- 1 | # Export directories from Databricks 2 | 3 | $importEnv = $(Read-Host -Prompt "Import to 'dev' or 'staging'? (default='dev')") 4 | 5 | if ($importEnv -eq "staging") 6 | { 7 | $workspaceRoot = "/staging" 8 | } 9 | elseif ($importEnv -eq "dev") 10 | { 11 | $branchName = $(git rev-parse --abbrev-ref HEAD) 12 | $workspaceRoot = "/data_manager_projects/iuod_dae_test_data_generator/${branchName}" 13 | } 14 | else 15 | { 16 | "Invalid value: '$importEnv'" 17 | exit 1; 18 | } 19 | 20 | $projects = @( 21 | "iuod_artificial_data_generator" 22 | "iuod_artificial_data_admin", 23 | "artificial_hes", 24 | "artificial_hes_meta" 25 | ) 26 | 27 | foreach ( $projectName in $projects ) 28 | { 29 | $localPath = "projects/$projectName" 30 | $databricksPath = "$workspaceRoot/$projectName" 31 | $resourceExists = !($(databricks workspace ls $databricksPath) -like '*"error_code":"RESOURCE_DOES_NOT_EXIST"*') 32 | 33 | if ($resourceExists) { 34 | "'$($databricksPath)' already exists in the Databricks workspace" 35 | $overwrite = Read-Host -Prompt "Do you wish to overwrite? 
(y/n)" 36 | 37 | if ($overwrite.ToLower() -eq "y") 38 | { 39 | # Require user confirmation of the project name for overwriting 40 | $confirmMessage = "Confirm the project name to overwrite ($($projectName))" 41 | $projectNameCheck = Read-Host -Prompt $confirmMessage 42 | 43 | while ($projectNameCheck -ne $projectName) 44 | { 45 | # Keep going until the names match 46 | "'$($projectNameCheck)' does not match '$($projectName)', please try again" 47 | $projectNameCheck = Read-Host -Prompt $confirmMessage 48 | } 49 | 50 | # Import the directory with overwrite 51 | databricks workspace import_dir -o $localPath $databricksPath 52 | } 53 | else { 54 | continue 55 | } 56 | 57 | } 58 | else 59 | { 60 | # Import the directory 61 | databricks workspace import_dir $localPath $databricksPath 62 | } 63 | 64 | } -------------------------------------------------------------------------------- /utils/list_releases.ps1: -------------------------------------------------------------------------------- 1 | # List the latest code promotion project releases 2 | 3 | $workspaceRoot = "/Users/admin/releases/code-promotion" 4 | $projects = @( 5 | "iuod_artificial_data_generator" 6 | "iuod_artificial_data_admin", 7 | "artificial_hes", 8 | "artificial_hes_meta" 9 | ) 10 | 11 | foreach ( $projectName in $projects ) 12 | { 13 | $projectReleasesRoot = "$workspaceRoot/$projectName" 14 | 15 | # Get all the releases for the project 16 | $releases = $(databricks workspace ls $projectReleasesRoot) 17 | 18 | # Find the latest release 19 | $latestRelease = $null 20 | $latestTimestamp = [datetime]"1900-01-01" 21 | foreach ($release in $releases) 22 | { 23 | $releaseSegments = $release.Split(".") 24 | $releaseIdTimestamp = $releaseSegments[0].Split("+") 25 | $releaseId = $releaseIdTimestamp[0] # Not used 26 | $releaseTimestamp = [datetime]::ParseExact($releaseIdTimestamp[1], "yyyyMMddHHmmss", $null) 27 | $releaseHash = $releaseSegments[1].Substring(3) # Not used 28 | 29 | if ($latestTimestamp -lt $releaseTimestamp) 30 | { 31 | # Update the latest release 32 | $latestRelease = $release 33 | $latestTimestamp = $releaseTimestamp 34 | } 35 | 36 | } 37 | 38 | "Latest release for '$projectName': '$latestRelease' on $($latestTimestamp.ToString("yyyy-MM-dd 'at' HH:mm:ss"))" 39 | } --------------------------------------------------------------------------------
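For reference, the parsing in list_releases.ps1 appears to assume release folder names of the form `<releaseId>+<yyyyMMddHHmmss>.git<hash>`. The snippet below merely restates that naming convention in Python as an illustration; the folder name shown is made up.

```python
# Illustrative only: the release-name convention assumed by list_releases.ps1.
from datetime import datetime

release = "1234+20220131120000.gitabc1234"  # made-up example name
id_and_timestamp, remainder = release.split(".", 1)
release_id, timestamp_str = id_and_timestamp.split("+")
release_hash = remainder[3:]  # strip the leading "git"
release_timestamp = datetime.strptime(timestamp_str, "%Y%m%d%H%M%S")

print(release_id, release_timestamp, release_hash)
```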