├── .gitignore
├── LICENSE
├── README.md
├── docs
├── artificial_data_admin_guide.md
├── artificial_data_user_notice.md
└── build_docs
│ ├── README.md
│ ├── create_user_docs_template.py
│ └── make_create_user_docs.py
├── environment.yml
├── notebooks
├── admin
│ └── collab
│ │ ├── artificial_hes
│ │ ├── 1_scrape_metadata.py
│ │ ├── 2_review_metadata.py
│ │ ├── 3_release_metadata.py
│ │ ├── 4_generate_artificial_data.py
│ │ ├── 5_qa_artificial_data.py
│ │ └── 6_release_artificial_data.py
│ │ ├── example_dotenv.py
│ │ ├── job_utils.py
│ │ └── meta_review
│ │ ├── assertions.py
│ │ ├── assertions_tests.py
│ │ ├── field_definitions.py
│ │ ├── regex_patterns.py
│ │ └── regex_patterns_tests.py
└── user
│ └── collab
│ └── artificial_data_user_notice.py
├── projects
├── artificial_hes
│ ├── cp_config.py
│ ├── init_schemas.py
│ ├── notebooks
│ │ ├── code_promotion_paths.py
│ │ └── common
│ │ │ ├── coerce_schema.py
│ │ │ ├── common_exports.py
│ │ │ ├── spark_helpers.py
│ │ │ ├── table_helpers.py
│ │ │ └── widget_utils.py
│ ├── run_notebooks.py
│ ├── run_tests.py
│ ├── schemas
│ │ └── create_user_docs.py
│ └── tests
│ │ └── placeholder.py
├── artificial_hes_meta
│ ├── cp_config.py
│ ├── init_schemas.py
│ ├── notebooks
│ │ ├── code_promotion_paths.py
│ │ └── common
│ │ │ ├── coerce_schema.py
│ │ │ ├── common_exports.py
│ │ │ ├── spark_helpers.py
│ │ │ ├── table_helpers.py
│ │ │ └── widget_utils.py
│ ├── run_notebooks.py
│ ├── run_tests.py
│ ├── schemas
│ │ └── placeholder.py
│ └── tests
│ │ └── placeholder.py
├── iuod_artificial_data_admin
│ ├── cp_config.py
│ ├── init_schemas.py
│ ├── notebooks
│ │ ├── code_promotion_paths.py
│ │ ├── code_promotion_versions.py
│ │ └── widget_utils.py
│ ├── run_notebooks.py
│ ├── run_tests.py
│ ├── schemas
│ │ └── placeholder.py
│ └── tests
│ │ └── placeholder.py
└── iuod_artificial_data_generator
│ ├── cp_config.py
│ ├── init_schemas.py
│ ├── notebooks
│ ├── admin
│ │ └── driver.py
│ ├── common
│ │ ├── coerce_schema.py
│ │ ├── common_exports.py
│ │ ├── spark_helpers.py
│ │ ├── table_helpers.py
│ │ └── widget_utils.py
│ ├── dataset_definitions
│ │ ├── hes
│ │ │ ├── hes_ae_schema.py
│ │ │ ├── hes_ae_tables.py
│ │ │ ├── hes_apc_schema.py
│ │ │ ├── hes_apc_tables.py
│ │ │ ├── hes_op_schema.py
│ │ │ ├── hes_op_tables.py
│ │ │ ├── hes_patient_table.py
│ │ │ ├── hes_schemas.py
│ │ │ └── hes_tables.py
│ │ ├── relational.py
│ │ └── relational_helpers.py
│ ├── dependencies
│ │ ├── random.py
│ │ ├── rstr.py
│ │ ├── spark_rstr.py
│ │ └── tests
│ │ │ ├── test_rstr.py
│ │ │ └── test_spark_rstr.py
│ ├── generator_pipelines
│ │ ├── artificial_hes
│ │ │ ├── coerce_hes_schema.py
│ │ │ ├── demographic_field_generators.py
│ │ │ ├── driver.py
│ │ │ └── field_definitions
│ │ │ │ ├── derivations
│ │ │ │ ├── age_fields.py
│ │ │ │ ├── coded_field_lists.py
│ │ │ │ ├── coded_fields.py
│ │ │ │ ├── derivations_helpers.py
│ │ │ │ ├── sequential_field_lists.py
│ │ │ │ └── sequential_fields.py
│ │ │ │ ├── id_field_lists.py
│ │ │ │ ├── id_field_regex_patterns.py
│ │ │ │ └── id_fields.py
│ │ ├── driver_imports.py
│ │ └── driver_parameters.py
│ ├── generator_stages
│ │ ├── derivations.py
│ │ ├── field_generators.py
│ │ ├── relationship_generator.py
│ │ └── sampling
│ │ │ ├── cdf_sampling.py
│ │ │ └── field_definitions.py
│ ├── scraper_pipelines
│ │ ├── driver_imports.py
│ │ └── hes
│ │ │ ├── constants
│ │ │ ├── disclosure_control_parameters.py
│ │ │ ├── excluded_fields.py
│ │ │ └── type_overrides.py
│ │ │ ├── demographic_aggregation_parameters.py
│ │ │ ├── demographic_field_summarisers.py
│ │ │ └── driver.py
│ └── scraper_stages
│ │ ├── aggregation
│ │ ├── aggregation_parameters.py
│ │ ├── field_summarisers.py
│ │ ├── relationship_summariser.py
│ │ └── summariser_factory.py
│ │ ├── ingestion.py
│ │ ├── preprocessing
│ │ └── meta_type_classifier.py
│ │ ├── schemas
│ │ ├── long_schema.py
│ │ └── meta_schema.py
│ │ ├── scrape_metadata.py
│ │ └── wide_to_long.py
│ ├── run_notebooks.py
│ ├── run_tests.py
│ ├── schemas
│ ├── create_meta_table.py
│ └── uplifts
│ │ ├── artificial_hes_schema_coercion.py
│ │ ├── artificial_hes_schema_coercion_tests.py
│ │ ├── open_data_metadata_uplift.py
│ │ └── open_data_metadata_uplift_tests.py
│ └── tests
│ ├── admin_tests
│ └── placeholder_tests.py
│ ├── common_tests
│ ├── run_tests.py
│ └── test_spark_helpers.py
│ ├── dataset_definitions_tests
│ ├── placeholder_tests.py
│ └── run_tests.py
│ ├── generator_pipelines_tests
│ ├── artificial_hes_tests
│ │ ├── demographic_field_generators_tests.py
│ │ ├── field_definitions_tests
│ │ │ ├── derivations_tests
│ │ │ │ ├── age_fields_tests.py
│ │ │ │ ├── coded_fields_tests.py
│ │ │ │ ├── derivations_helpers_tests.py
│ │ │ │ └── sequential_fields_tests.py
│ │ │ ├── id_fields_tests.py
│ │ │ └── imports.py
│ │ └── run_tests.py
│ └── run_tests.py
│ ├── generator_stages_tests
│ ├── field_generators_tests.py
│ ├── relationship_generator_tests.py
│ ├── run_tests.py
│ └── sampling_tests
│ │ └── field_definitions_tests.py
│ ├── scraper_pipelines_tests
│ ├── hes
│ │ └── demographic_field_summarisers_tests.py
│ └── run_tests.py
│ ├── scraper_stages_tests
│ ├── aggregation_tests
│ │ ├── field_summarisers_tests.py
│ │ └── relationship_summariser_tests.py
│ ├── preprocessing_tests
│ │ └── test_meta_type_classifier.py
│ └── run_tests.py
│ └── test_helpers.py
└── utils
├── export.ps1
├── import.ps1
└── list_releases.ps1
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 |
4 | # Environments
5 | *.bak
6 | *.env
7 | .venv
8 |
9 | .vscode
10 |
11 | # Notebooks containing credentials
12 | dotenv.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 NHS Digital
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/artificial_data_admin_guide.md:
--------------------------------------------------------------------------------
1 | ## Guide for artificial data admins
2 | This section is aimed at administrators of artificial data, working in an environment where they can trigger the different stages of the process.
3 |
4 | ### How do I extract metadata?
5 | To extract metadata for a dataset, run the `run_notebooks.py` notebook in the corresponding `artificial_{dataset_name}_meta` project (for example `artificial_hes_meta/run_notebooks.py`).
6 |
7 | By default, this will read data from a database named according to the `{dataset_name}` parameter in the template string above and write to a database called `artificial_{dataset_name}_meta`. These can be changed via the notebook widgets.
8 |
9 | When running in production on live data, this should be done by triggering the `run_notebooks` job for the respective project, as this will have the privileges to access the live data.
10 | Approved human users will only have access to aggregates for review, not to the underlying record-level data.
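
For illustration, the helper notebook `notebooks/admin/collab/artificial_hes/1_scrape_metadata.py` triggers this job through the Databricks CLI wrappers defined in `notebooks/admin/collab/job_utils.py`. A minimal sketch (assuming `job_utils` has been loaded with `%run` and `USERNAME` points at a user with a `token` notebook):

```python
# Sketch only, mirroring 1_scrape_metadata.py: fetch an API token, point the
# Databricks CLI at it, then run the metadata scraper job and poll for a result.
token = dbutils.notebook.run(f"/Users/{USERNAME}/token", 60)
configure_databricks_cli(token)
run_job_async("cp_run_notebooks_artificial_hes_meta", sleep_seconds=0, max_retries=1)
```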
11 |
12 | ### What happens after metadata is extracted?
13 | Once the metadata has been extracted, it should be manually reviewed by a member of staff working in a secure environment to ensure no personally identifiable information (PII) is disclosed. This should be signed off by a senior member of staff.
14 |
15 | At NHS Digital we have a checklist that was approved by the Statistical Disclosure Control Panel, chaired by the Chief Statistician.
16 |
17 | Once we have checked the metadata and signed it off, we move it into a database which is inaccessible to the metadata scraper and so is completely isolated from the database containing the real data.
18 | This is done by executing the `run_notebooks.py` notebook in the `iuod_artificial_data_admin` project.
19 |
20 | When running in production on live data, this should be done by triggering the `run_notebooks` job for the `iuod_artificial_data_admin` project, as this will have the privileges to read from / write to the appropriate databases.
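
For illustration, the helper notebook `notebooks/admin/collab/artificial_hes/3_release_metadata.py` passes the review and approval details as job parameters. A minimal sketch (assuming `job_utils` has been loaded with `%run`; every value must be filled in before the job will run):

```python
# Sketch only, mirroring 3_release_metadata.py
job_params = {
    "safety_dt": "",
    "source_database_name": "artificial_hes_meta",
    "table_name": "artificial_hes_meta",
    "reviewer_name": "", "reviewer_email": "", "reviewed_date": "",
    "approver_name": "", "approver_email": "", "approved_date": "",
}
run_job_async("cp_run_notebooks_iuod_artificial_data_admin", **job_params, sleep_seconds=30, max_retries=20)
```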
21 |
22 | ### How do I generate artificial data?
23 | To generate data and run postprocessing for a dataset, run the `run_notebooks.py` notebook in the `iuod_artificial_data_generator` project with the `artificial_dataset` parameter set to the name of the dataset to generate artificial data for.
24 | For example, for artificial HES data set the 'artificial_dataset' parameter to 'hes'.
25 |
26 | By default, this process will read metadata from and write artificial data to a database called `iuod_artificial_data_generator`, but this parameter can be changed via the notebook widgets.
27 |
28 | When running in production on live data, this should be done by triggering the `run_notebooks` job for the `iuod_artificial_data_generator` project, as this will have the privileges to access the aggregated data.
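
For illustration, the helper notebook `notebooks/admin/collab/artificial_hes/4_generate_artificial_data.py` triggers this job as follows; note that it passes the full asset name `artificial_hes` and a `parameters_json` payload. A minimal sketch (assuming `job_utils` has been loaded with `%run`):

```python
import json

# Sketch only, mirroring 4_generate_artificial_data.py
parameters_json = json.dumps({"n_patients": 1000000})
run_job_async(
    "cp_run_notebooks_iuod_artificial_data_generator",
    artificial_dataset="artificial_hes",
    parameters_json=parameters_json,
    sleep_seconds=0,
    max_retries=1,
)
```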
--------------------------------------------------------------------------------
/docs/artificial_data_user_notice.md:
--------------------------------------------------------------------------------
1 | Notice For Artificial Data Users
2 |
3 | What is artificial data?
4 |
5 | Artificial data is an anonymous representation of real data
6 |
7 | - Artificial data provides an anonymous representation of some of the properties of real datasets.
8 | - Artificial data preserves the formatting and structure of the original dataset, but may otherwise be unrealistic.
9 | - Artificial data reproduces some of the statistical properties and content complexity of fields in the real data, while excluding cross-dependencies between fields to prevent risks of reidentification.
10 | - Artificial data is completely isolated from any record-level data.
11 | - It is not possible to use artificial data to reidentify individuals, gain insights, or build statistical models that would transfer onto real data.
12 |
13 |
14 | How is it generated?
15 |
16 | There are three stages involved in generating the artificial data:
17 |
18 |
19 | - The Metadata Scraper: extracts anonymised, high-level aggregates from real data at a national level. At this stage key identifiers (such as patient ID) are removed and small number suppression is applied in order to prevent reidentification at a later stage.
20 | - The Data Generator: samples from the aggregates generated by the Metadata Scraper on a field-by-field basis and puts the sampled values together to create artificial records (illustrated in the sketch after this list).
21 | - Postprocessing: using the output of the Data Generator, dataset-specific tweaks are applied to make the data appear more realistic (such as swapping randomly generated birth and death dates to ensure sensible ordering). This also includes adding in randomly generated identifying fields (such as ‘patient’ ID) which were removed at the Metadata Scraper stage.
22 |
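A purely illustrative sketch of what field-by-field sampling means (the field names and counts below are invented, and this is not the actual implementation): each field is drawn independently from its own aggregate frequency distribution, so no combination of values in an artificial record is taken from a real record.

```python
import random

# Invented example aggregates: value -> frequency, per field
field_frequencies = {
    "AGE_BAND": {"0-17": 200, "18-64": 500, "65+": 300},
    "SEX": {"1": 500, "2": 500},
}

def sample_artificial_record(frequencies):
    """Sample each field independently, in proportion to its published frequency."""
    record = {}
    for field, counts in frequencies.items():
        values, weights = zip(*counts.items())
        record[field] = random.choices(values, weights=weights, k=1)[0]
    return record

print(sample_artificial_record(field_frequencies))
```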
23 |
24 | What is it used for?
25 |
26 | The purpose of artificial data is twofold.
27 |
28 | 1. Artificial data enables faster onboarding for new data projects
29 |
30 | Users can begin work on a project in advance of access to real data being granted. There are multiple use cases for the artificial data.
31 |
32 | - For TRE/DAE users who have submitted (or intend to submit) an access request for a given dataset: artificial data can give a feel for the format and layout of the real data prior to accessing it. It also allows users to create and test pipelines before accessing real data, or perhaps without ever accessing real data at all.
33 | - For TRE/DAE users who are unsure which datasets are relevant to their project: artificial data allows users to understand which data sets would be useful for them prior to applying for access. Artificial data can complement technical information found in the Data Dictionary for a given dataset, and give users a feel for how they would work with that dataset.
34 |
35 |
36 | 2. Artificial data minimises the amount of personal data being processed
37 |
38 | The activities mentioned above can all be completed without handling personal data, improving the ability of NHS Digital to protect patient privacy by minimising access to sensitive data.
39 |
40 |
41 | What are its limitations?
42 |
43 | Artificial data is not real data, and is not intended to represent something real or link to real records. Artificial records are not based on any specific records found in the original data, only on high-level, anonymised aggregates. As outlined above, it is intended to improve efficiency and protect patient data.
44 |
45 | It is crucial to note that artificial data is not synthetic data.
46 |
47 | Synthetic data is generated using sophisticated methods to create realistic records. Usually synthetic data aims to enable the building of statistical models or gaining of insights that transfer onto real data. This is not possible with artificial data. The downside of synthetic data is that it is associated with non-negligible risks to patient privacy through reidentification. It is not possible to reidentify individuals using artificial data.
48 |
49 |
50 | Additional Information
51 |
52 | Support from senior leadership
53 | This work has undergone due process to fully assess any potential risks to patient privacy and has been approved by senior leadership within NHS Digital, including: the Senior Information Risk Owner (SIRO); the Caldicott Guardian; the Data Protection Officer (DPO); the Executive Director of Data and Analytics Services; and the IAOs for the datasets represented by the Artificial Data assets. A full DPIA has been completed and is available upon request.
54 |
55 | For further details, please get in touch via the mailbox linked below.
56 |
57 | Contact
58 | For further details, please get in touch via: nhsdigital.artificialdata@nhs.net
59 |
--------------------------------------------------------------------------------
/docs/build_docs/README.md:
--------------------------------------------------------------------------------
1 |
2 | # How do I add user documentation to the production DAE?
3 | There are 2 steps to building the documentation and adding it to DAE:
4 | 1. (Desktop Python) Run the Python file `docs/build_docs/make_create_user_docs.py`: the full documentation lives in the `docs` folder at the top level of the repo; this step takes that documentation and puts its contents into a notebook that can be run on DAE to make it readable by users.
5 | 1. (DAE Prod) Copy the notebook `projects/<project_name>/schemas/create_user_docs.py` created in step 1 into DAE Prod and run it: this will add the documentation to a user-facing table in DAE Prod.
6 |
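Under the hood, step 1 reads `docs/build_docs/create_user_docs_template.py` and replaces each `{{<markdown file name>}}` placeholder with the contents of the matching file under `docs/`. A minimal sketch of that substitution, mirroring `replace_template_variable` in `make_create_user_docs.py` (the replacement value below is invented):

```python
def replace_template_variable(template: str, token_value: str, new_value: str) -> str:
    # Replace "{{token_value}}" in the template with the supplied value
    return template.replace("".join(["{{", token_value, "}}"]), new_value)

template = 'user_notice_html = """{{artificial_data_user_notice.md}}"""'
filled = replace_template_variable(template, "artificial_data_user_notice.md", "<h1>Notice</h1>")
print(filled)  # user_notice_html = """<h1>Notice</h1>"""
```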
--------------------------------------------------------------------------------
/docs/build_docs/create_user_docs_template.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | import os
3 | from pyspark.sql import functions as F
4 | # COMMAND ----------
5 |
6 | dbutils.widgets.text("db", "", "0.1 Project Database")
7 |
8 | # COMMAND ----------
9 |
10 | # Config
11 | database_name = dbutils.widgets.get("db")
12 | table_name = "user_docs"
13 | table_path = f"{database_name}.{table_name}"
14 |
15 | # Check database exists
16 | if spark.sql(f"SHOW DATABASES LIKE '{database_name}'").first() is None:
17 | # Database not found - exit
18 | dbutils.notebook.exit({})
19 |
20 | # Template variables replaced during build
21 | user_notice_file_name = "artificial_data_user_notice"
22 |
23 | # COMMAND ----------
24 |
25 | user_notice_html = """{{artificial_data_user_notice.md}}"""
26 |
27 | # COMMAND ----------
28 |
29 | # Create and upload the docs
30 | docs_data = [
31 | [user_notice_file_name, user_notice_html]
32 | ]
33 | docs_schema = "file_name: string, content_html: string"
34 | docs_df = spark.createDataFrame(docs_data, docs_schema)
35 | (
36 | docs_df.write
37 | .format("delta")
38 | .mode("overwrite")
39 | .saveAsTable(table_path)
40 | )
41 |
42 | # Make sure users can select from the table but not overwrite
43 | if os.getenv("env", "ref") == "ref":
44 | owner = "data-managers"
45 | else:
46 | owner = "admin"
47 |
48 | spark.sql(f"ALTER TABLE {table_path} OWNER TO `{owner}`")
49 |
50 | # Check the uploads
51 | for file_name, content_html in docs_data:
52 | result_content_html = (
53 | spark.table(table_path)
54 | .where(F.col("file_name") == file_name)
55 | .first()
56 | .content_html
57 | )
58 | assert result_content_html == content_html
59 |
--------------------------------------------------------------------------------
/docs/build_docs/make_create_user_docs.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | # Projects requiring docs
4 | projects = ["artificial_hes"]
5 |
6 | # Constant paths
7 | this_path = pathlib.Path(__file__)
8 | template_path = this_path.parent / "create_user_docs_template.py"
9 | docs_path = this_path.parent.parent
10 | projects_root = this_path.parent.parent.parent / "projects"
11 |
12 | def replace_template_variable(template: str, token_value: str, new_value: str) -> str:
13 | filled_template = template.replace("".join(["{{", token_value, "}}"]), new_value)
14 | return filled_template
15 |
16 | for project_name in projects:
17 | project_path = projects_root / project_name
18 | notebook_path = project_path / "schemas" / "create_user_docs.py"
19 |
20 | # Read the template content
21 | with template_path.open("r") as template_file:
22 | notebook_content = template_file.read()
23 |
24 | # Populate placeholders in the template
25 | for doc_file_path in docs_path.glob("**/*.md"):
26 | with doc_file_path.open("r") as doc_file:
27 | doc_file_content = doc_file.read()
28 |
29 | notebook_content = replace_template_variable(notebook_content, doc_file_path.name, doc_file_content)
30 | notebook_content = replace_template_variable(notebook_content, "database_name", project_name)
31 |
32 | # Write the full content
33 | with notebook_path.open("w+") as notebook_file:
34 | notebook_file.write(notebook_content)
35 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NHSDigital/artificial-data-generator/47a406cad27125f9499a1af10ba2afac3df0c554/environment.yml
--------------------------------------------------------------------------------
/notebooks/admin/collab/artificial_hes/1_scrape_metadata.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md # Scrape HES metadata
3 | # MAGIC Run the `cp_run_notebooks_artificial_hes_meta` job with user-entered config.
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ## Setup
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %run ../job_utils
13 |
14 | # COMMAND ----------
15 |
16 | # MAGIC %md
17 | # MAGIC ## Main
18 |
19 | # COMMAND ----------
20 |
21 | USERNAME = ""
22 | assert USERNAME, "Please provide a username for getting the API token"
23 |
24 | token = dbutils.notebook.run(f"/Users/{USERNAME}/token", 60)
25 | configure_databricks_cli(token)
26 |
27 | # COMMAND ----------
28 |
29 | run_job_async("cp_run_notebooks_artificial_hes_meta", sleep_seconds=0, max_retries=1)
30 |
31 | # COMMAND ----------
32 |
33 |
--------------------------------------------------------------------------------
/notebooks/admin/collab/artificial_hes/3_release_metadata.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md # Release HES metadata
3 | # MAGIC Run the `cp_run_notebooks_iuod_artificial_data_admin` job to move HES metadata into the `iuod_artificial_data_generator` database.
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ## Setup
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %run ../job_utils
13 |
14 | # COMMAND ----------
15 |
16 | # MAGIC %md ## Config
17 |
18 | # COMMAND ----------
19 |
20 | USERNAME = ""
21 | assert USERNAME, "Please provide a username for getting the API token"
22 |
23 | job_params = {
24 | "safety_dt": "",
25 | "source_database_name": "artificial_hes_meta",
26 | "table_name": "artificial_hes_meta",
27 | "reviewer_name": "",
28 | "reviewer_email": "",
29 | "reviewed_date": "",
30 | "approver_name": "",
31 | "approver_email": "",
32 | "approved_date": "",
33 | }
34 |
35 | for param_name, param_value in job_params.items():
36 | assert param_value, f"Please provide a value for parameter '{param_name}'"
37 |
38 | # COMMAND ----------
39 |
40 | # MAGIC %md
41 | # MAGIC ## Main
42 |
43 | # COMMAND ----------
44 |
45 | token = dbutils.notebook.run(f"/Users/{USERNAME}/token", 60)
46 | configure_databricks_cli(token)
47 |
48 | # COMMAND ----------
49 |
50 | run_job_async("cp_run_notebooks_iuod_artificial_data_admin", **job_params, sleep_seconds=30, max_retries=20)
51 |
52 | # COMMAND ----------
53 |
54 |
--------------------------------------------------------------------------------
/notebooks/admin/collab/artificial_hes/4_generate_artificial_data.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md # Generate artificial HES data
3 | # MAGIC Run the `cp_run_notebooks_iuod_artificial_data_generator` job for the `artificial_hes` data asset.
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ## Setup
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %run ../job_utils
13 |
14 | # COMMAND ----------
15 |
16 | import json
17 | import pyspark.sql.functions as F
18 |
19 | # COMMAND ----------
20 |
21 | # MAGIC %md ## Config
22 |
23 | # COMMAND ----------
24 |
25 | USERNAME = ""
26 | assert USERNAME, "Please provide a username for getting the API token"
27 |
28 | parameters_json = json.dumps(dict(n_patients=1000000))
29 | artificial_dataset = "artificial_hes"
30 |
31 | # COMMAND ----------
32 |
33 | # MAGIC %md
34 | # MAGIC ## Main
35 |
36 | # COMMAND ----------
37 |
38 | token = dbutils.notebook.run(f"/Users/{USERNAME}/token", 60)
39 | configure_databricks_cli(token)
40 |
41 | # COMMAND ----------
42 |
43 | run_job_async(
44 | "cp_run_notebooks_iuod_artificial_data_generator",
45 | artificial_dataset=artificial_dataset,
46 | parameters_json=parameters_json,
47 | sleep_seconds=0,
48 | max_retries=1
49 | )
--------------------------------------------------------------------------------
/notebooks/admin/collab/artificial_hes/5_qa_artificial_data.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # TODO: expand the notebook below with checks implemented in Ref
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run /Users/admin/releases/code-promotion/iuod_artificial_data_generator/4332+20220930151507.gite6c853018/notebooks/dataset_definitions/hes/hes_schemas
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run /Users/admin/releases/code-promotion/iuod_artificial_data_generator/4332+20220930151507.gite6c853018/notebooks/common/spark_helpers
11 |
12 | # COMMAND ----------
13 |
14 | import pyspark.sql.functions as F
15 | import pyspark.sql.types as T
16 |
17 |
18 | database_name = "artificial_hes"
19 | table_name_pattern = r"artificial_hes_(op|apc|ae)_\d{4}"
20 |
21 | # COMMAND ----------
22 |
23 | tables = (
24 | spark.sql(f"show tables from {database_name}")
25 | .filter(F.col("tableName").rlike(table_name_pattern))
26 | .select("tableName")
27 | .collect()
28 | )
29 |
30 | for table in tables:
31 | artificial_df = spark.table(f"{database_name}.{table.tableName}")
32 |
33 | # Check non-zero rows
34 | row_count = artificial_df.count()
35 | assert row_count > 100, "Found fewer rows than expected"
36 | print(table.tableName, f"{row_count:,}")
37 |
38 | # Compare schemas
39 | if "ae" in table.tableName:
40 | hes_schema = get_hes_ae_schema()
41 | elif "op" in table.tableName:
42 | hes_schema = get_hes_op_schema()
43 | elif "apc" in table.tableName:
44 | hes_schema = get_hes_apc_schema()
45 | else:
46 | assert False, f"Table {table.tableName} does not belong to HES dataset (apc, op, ae)"
47 |
48 | # Sort for comparison between fields
49 | hes_schema = T.StructType(sorted(hes_schema.fields, key=lambda f: f.name))
50 | artificial_schema = artificial_df.select(*sorted(artificial_df.columns)).schema
51 |
52 | # TODO:
53 | # Will the artificial_hes pipeline fail when it tries to update the AE tables?
54 | # Do we need to explicitly drop columns via init_schemas?
55 | assert datatypes_equal(artificial_schema, hes_schema), "Schema not equal to expected schema"
56 |
57 | # if len(hes_schema) != len(artificial_schema):
58 | # # TODO: handle this properly?
59 | # print(
60 | # [f for f in artificial_schema.fields if f.name not in hes_schema.fieldNames()],
61 | # [f for f in hes_schema.fields if f.name not in artificial_schema.fieldNames()]
62 | # )
63 | # else:
64 |
65 | # TODO: Check values?
66 |
67 | # COMMAND ----------
68 |
69 |
--------------------------------------------------------------------------------
/notebooks/admin/collab/artificial_hes/6_release_artificial_data.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/notebooks/admin/collab/example_dotenv.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Example dotenv.py
4 |
5 | # COMMAND ----------
6 |
7 | API_PREFIX = "my_api_host_prefix"
8 | API_SUFFIX = "my_api_host_suffix.com/api/2.0"
9 |
--------------------------------------------------------------------------------
/notebooks/admin/collab/job_utils.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./dotenv.py
3 |
4 | # COMMAND ----------
5 |
6 | import json
7 | import os
8 | import time
9 | from typing import Dict, Tuple, Any
10 |
11 | # From dotenv.py
12 | API_PREFIX = API_PREFIX
13 | API_SUFFIX = API_SUFFIX
14 |
15 | DATABRICKS_ENV = os.getenv("env", "dev")
16 | HTTPS = "https://"
17 | PENDING_STATE = "PENDING"
18 | RUNNING_STATE = "RUNNING"
19 |
20 |
21 | def get_databricks_api_host() -> str:
22 | """ Get URL for the Databricks API 2.0 for the current environment.
23 |
24 | NOTE: This is specific to NHS Digital infrastructure.
25 |
26 | Returns
27 | -------
28 | str
29 | API base URL for the current environment
30 |
31 | """
32 | host_segments = [f"https://{API_PREFIX}"]
33 |
34 | if DATABRICKS_ENV != "prod":
35 | host_segments.append(DATABRICKS_ENV)
36 |
37 | host_segments.append(API_SUFFIX)
38 | host = ".".join(host_segments)
39 |
40 | return host
41 |
42 |
43 | def configure_databricks_cli(token: str) -> None:
44 | """ Configure the Databricks CLI with the hostname and token.
45 |
46 | Parameters
47 | ----------
48 | token: str
49 | User generated access token
50 |
51 | """
52 | os.environ["DATABRICKS_HOST"] = get_databricks_api_host()
53 | os.environ["DATABRICKS_TOKEN"] = token
54 |
55 |
56 | def get_job_name_id_map() -> Dict[str, str]:
57 | result = os.popen("databricks jobs list").read()
58 | result = result.split("\n")
59 | result = filter(None, result)
60 | job_name_id_map = dict(map(lambda s: s.split()[::-1], result))
61 | return job_name_id_map
62 |
63 |
64 | def get_job_run_state(run_id: str) -> Dict[str, str]:
65 | get_run_command = f"databricks runs get --run-id {run_id}"
66 | result_json = os.popen(get_run_command).read()
67 | result = json.loads(result_json)
68 | state = result.get("state", {})
69 | return state
70 |
71 |
72 | def wait_for_job_result(run_id: str, sleep_seconds: int=5, max_retries: int=60) -> str:
73 | print(f"Checking state of run with run_id '{run_id}'")
74 |
75 | for i in range(max_retries):
76 | state = get_job_run_state(run_id)
77 | life_cycle_state = state.get("life_cycle_state", PENDING_STATE)
78 | result_state = state.get("result_state", PENDING_STATE)
79 |
80 | if life_cycle_state in (PENDING_STATE, RUNNING_STATE):
81 | print(f"Run with run_id '{run_id}' incomplete with life_cycle_state '{life_cycle_state}', checking again in {sleep_seconds} seconds")
82 | time.sleep(sleep_seconds)
83 | else:
84 | print(f"Run with run_id '{run_id}' completed with result_state '{result_state}'")
85 | return result_state
86 |
87 | else:
88 | print(f"Run with run_id '{run_id}' incomplete with life_cycle_state '{life_cycle_state}' but maximum number of retries ({max_retries}) has been exceeded")
89 | return result_state
90 |
91 |
92 | def run_job_with_notebook_params(job_id: str, notebook_params: str) -> Dict[str, str]:
93 | # Check if there is already a job running
94 | print(f"Checking state of last run for job with job_id '{job_id}'")
95 | list_runs_command = f"databricks runs list --job-id {job_id}"
96 | result = os.popen(list_runs_command).read()
97 |
98 | if result == "\n":
99 | # No runs yet
100 | last_run_id = None
101 | last_run_life_cycle_state = None
102 | else:
103 | last_run_id = result.split("\n")[0].split()[0]
104 | last_run_state = get_job_run_state(last_run_id)
105 | last_run_life_cycle_state = last_run_state.get("life_cycle_state", PENDING_STATE)
106 | print(f"Last run for job with job_id '{job_id}' has life_cycle_state '{last_run_life_cycle_state}': run_id '{last_run_id}'")
107 |
108 | if last_run_life_cycle_state in (PENDING_STATE, RUNNING_STATE):
109 | # Job still pending/running
110 | return {"run_id": last_run_id}
111 | else:
112 | # Start a new job
113 | run_job_command = f"databricks jobs run-now --job-id {job_id} --notebook-params '{notebook_params}'"
114 | result_json = os.popen(run_job_command).read()
115 | result = json.loads(result_json)
116 | print(f"Started new job run for job_id '{job_id}' with run_id '{result['run_id']}'")
117 | return result
118 |
119 |
120 | def run_job(job_name: str, **job_kwargs) -> Dict[str, str]:
121 | job_name_id_map = get_job_name_id_map()
122 | job_id = job_name_id_map[job_name]
123 | print(f"Found job_id for job '{job_name}': '{job_id}'")
124 |
125 | # Notebook arguments
126 | notebook_params = json.dumps(job_kwargs)
127 |
128 | return run_job_with_notebook_params(job_id, notebook_params)
129 |
130 |
131 | def run_job_async(*run_args: Any, sleep_seconds: int=30, max_retries: int=20, **run_kwargs) -> Dict[str, str]:
132 | """ Run a job and keep checking status until it finishes or until a maximum number of retries
133 |
134 | Parameters
135 | ----------
136 | *run_args: Any
137 | Passed through to run_job as arguments
138 | sleep_seconds: int (default=30)
139 | Number of seconds to sleep after each retry
140 | max_retries: int (default=20)
141 | Number of times to retry checking if the run has finished
142 | **run_kwargs: Any
143 | Passed through to run_job as keyword arguments
144 |
145 | Returns
146 | -------
147 | Dict[str, str]
148 | Result of the job run with the result state
149 |
150 | """
151 | result = run_job(*run_args, **run_kwargs)
152 | run_id = result["run_id"]
153 | result_state = wait_for_job_result(run_id, sleep_seconds=sleep_seconds, max_retries=max_retries)
154 | return {**result, "result_state": result_state}
155 |
--------------------------------------------------------------------------------
/notebooks/admin/collab/meta_review/assertions.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from pyspark.sql import DataFrame, Column, functions as F
3 | from nhsdccl.util import DFCheckpoint
4 |
5 |
6 | class StatisticalDisclosureControlException(Exception):
7 | def __init__(self, message="Statistical disclosure control check failed!"):
8 | super().__init__(message)
9 |
10 |
11 | def modulo(col_name: str, x: int) -> Column:
12 | """ Spark SQL modulo expression as a PySpark Column.
13 | """
14 | return F.expr(f"{col_name} % {x}")
15 |
16 |
17 | def assert_statistical_disclosure_controls(
18 | meta_df: DataFrame,
19 | freq_col: Column,
20 | threshold: int = 100,
21 | rounding: int = 100
22 | ) -> None:
23 | """
24 | Check that the statistical disclosure controls are correctly applied to the metadata.
25 | By default the threshold and rounding on frequencies is set to 100 to avoid accidental disclosures.
26 |
27 | Parameters
28 | ----------
29 | meta_df: DataFrame
30 | Metadata
31 | freq_col: Column
32 | Column of `meta_df` holding the frequencies
33 | threshold: int (default = 100)
34 | Minimum frequency for a value to appear in the metadata
35 | rounding: int (default = 100)
36 | Frequencies should be rounded to this nearest value
37 |
38 | Returns
39 | -------
40 | None
41 |
42 | Raises
43 | ------
44 | StatisticalDisclosureControlException
45 | Raised if any frequencies do not obey the disclosure controls
46 |
47 | """
48 | # Filters for rows that break the SDCs
49 | frequency_below_threshold = F.col("FREQUENCY") < F.lit(threshold)
50 | frequency_not_divisible_by_rounding = modulo("FREQUENCY", rounding) != F.lit(0)
51 |
52 | # Find rows that break the SDCs
53 | sdc_check_df = (
54 | meta_df
55 | .withColumn("FREQUENCY", freq_col)
56 | .filter(frequency_below_threshold | frequency_not_divisible_by_rounding)
57 | )
58 |
59 | # Check that there are no rows breaking the SDCs
60 | with DFCheckpoint(sdc_check_df) as df_cp:
61 | try:
62 | assert df_cp.df.first() is None
63 |
64 | except AssertionError as e:
65 | print("Frequency check failed for the following rows: ")
66 | df_cp.df.show()
67 | raise StatisticalDisclosureControlException
68 |
69 |
--------------------------------------------------------------------------------
/notebooks/admin/collab/meta_review/assertions_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./assertions
3 |
4 | # COMMAND ----------
5 |
6 | from unittest import TestCase
7 |
8 | test_data = [
9 | ["table1", "FIELD1", "bad_value1", 4],
10 | ["table1", "FIELD1", "bad_value2", 7],
11 | ["table1", "FIELD1", "good_value", 15],
12 | ]
13 | test_schema = "TABLE_NAME: string, FIELD_NAME: string, VALUE: string, FREQUENCY: integer"
14 | test_df = spark.createDataFrame(test_data, test_schema)
15 | test_value_col = F.col("VALUE")
16 | test_freq_col = F.col("FREQUENCY")
17 |
18 | with TestCase().assertRaises(StatisticalDisclosureControlException):
19 | assert_statistical_disclosure_controls(test_df.where(F.col("VALUE") == "bad_value1"), test_freq_col, threshold=5, rounding=5)
20 |
21 | with TestCase().assertRaises(StatisticalDisclosureControlException):
22 | assert_statistical_disclosure_controls(test_df.where(F.col("VALUE") == "bad_value2"), test_freq_col, threshold=5, rounding=5)
23 |
24 | assert_statistical_disclosure_controls(test_df.where(F.col("VALUE") == "good_value"), test_freq_col, threshold=5, rounding=5)
--------------------------------------------------------------------------------
/notebooks/admin/collab/meta_review/field_definitions.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from pyspark.sql import functions as F
3 |
4 |
5 | table_name_col = F.col("TABLE_NAME")
6 | field_name_col = F.col("FIELD_NAME")
7 |
8 | is_categorical_col = F.col("VALUE_TYPE") == "CATEGORICAL"
9 | is_discrete_col = F.col("VALUE_TYPE") == "DISCRETE"
10 | is_date_col = F.col("VALUE_TYPE") == "DATE"
11 | is_continuous_col = F.col("VALUE_TYPE") == "CONTINUOUS"
12 |
13 | is_demographic_categorical_col = F.col("VALUE_TYPE") == "DEMOGRAPHIC_CATEGORICAL"
14 | is_demographic_date_col = F.col("VALUE_TYPE") == "DEMOGRAPHIC_DATE"
15 |
16 | is_relationship_col = F.col("VALUE_TYPE") == "RELATIONSHIP"
--------------------------------------------------------------------------------
/notebooks/admin/collab/meta_review/regex_patterns.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from typing import Union
3 | from pyspark.sql import DataFrame, Column, functions as F
4 |
5 |
6 | PHONE_NUMBER_PATTERN = r"^(((\+44\s?\d{4}|\(?0\d{4}\)?)\s?\d{3}\s?\d{3})|((\+44\s?\d{3}|\(?0\d{3}\)?)\s?\d{3}\s?\d{4})|((\+44\s?\d{2}|\(?0\d{2}\)?)\s?\d{4}\s?\d{4}))(\s?\#(\d{4}|\d{3}))?$"
7 | EMAIL_ADDRESS_PATTERN = r"([A-Za-z0-9]+[.-_])*[A-Za-z0-9]+@[A-Za-z0-9-]+(\.[A-Z|a-z]{2,})+"
8 | POSTCODE_PATTERN = r"([Gg][Ii][Rr] 0[Aa]{2})|((([A-Za-z][0-9]{1,2})|(([A-Za-z][A-Ha-hJ-Yj-y][0-9]{1,2})|(([A-Za-z][0-9][A-Za-z])|([A-Za-z][A-Ha-hJ-Yj-y][0-9][A-Za-z]?))))\s?[0-9][A-Za-z]{2})"
9 | GT_6_ALPHA_CHARS_PATTERN = r"^([A-Za-z]{6,})$"
10 |
11 |
12 | def filter_matches(df: DataFrame, col: Union[Column, str], pattern: str) -> DataFrame:
13 | """
14 | Filter rows of a DataFrame based on a regex pattern match within a specified column.
15 |
16 | Parameters
17 | ----------
18 | df : pyspark.sql.DataFrame
19 | DataFrame to filter
20 | col : Union[pyspark.sql.Column, str]
21 | Column to look for matches within
22 | pattern : str
23 | Pattern to be matched
24 |
25 | Returns
26 | -------
27 | pyspark.sql.DataFrame
28 | Filtered DataFrame based on pattern matches
29 |
30 | """
31 | extracted_value = F.regexp_extract(col, pattern, 0)
32 | is_match = F.length(extracted_value) > 0
33 | match_df = df.filter(is_match)
34 | return match_df
35 |
--------------------------------------------------------------------------------
/notebooks/admin/collab/meta_review/regex_patterns_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./regex_patterns
3 |
4 | # COMMAND ----------
5 |
6 | test_df = spark.createDataFrame(
7 | [
8 | ["hello@gmail.com"],
9 | ["hello.world@nhs.net"],
10 | ["some other string"]
11 | ],
12 | "VALUE: string"
13 | )
14 | assert filter_matches(test_df, "VALUE", EMAIL_ADDRESS_PATTERN).count() == 2
15 |
16 | # COMMAND ----------
17 |
18 | test_df = spark.createDataFrame(
19 | [
20 | ["07123456789"],
21 | ["07 123 456 789"],
22 | ["+447123456789"]
23 | ],
24 | "VALUE: string"
25 | )
26 | assert filter_matches(test_df, "VALUE", PHONE_NUMBER_PATTERN).count() == 2 # Need to strip spaces
27 |
28 | # COMMAND ----------
29 |
30 | test_df = spark.createDataFrame(
31 | [
32 | ["LS1 4AP"],
33 | ["LS14AP"],
34 | ["Hello"]
35 | ],
36 | "VALUE: string"
37 | )
38 | assert filter_matches(test_df, "VALUE", POSTCODE_PATTERN).count() == 2
39 |
40 | # COMMAND ----------
41 |
42 | test_df = spark.createDataFrame(
43 | [
44 | ["ABCdef"],
45 | ["ghiJK"]
46 | ],
47 | "VALUE: string"
48 | )
49 | assert filter_matches(test_df, "VALUE", GT_6_ALPHA_CHARS_PATTERN).count() == 1
--------------------------------------------------------------------------------
/notebooks/user/collab/artificial_data_user_notice.py:
--------------------------------------------------------------------------------
1 | from pyspark.sql import functions as F
2 |
3 | # Config
4 | database_name = "artificial_hes"
5 | table_name = "user_docs"
6 | table_path = f"{database_name}.{table_name}"
7 |
8 | # Get the content
9 | content_html = (
10 | spark.table(table_path)
11 | .where(F.col("file_name") == "artificial_data_user_notice")
12 | .first()
13 | .content_html
14 | )
15 |
16 | displayHTML(content_html)
17 |
--------------------------------------------------------------------------------
/projects/artificial_hes/cp_config.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Code Promotion Job Configuration
4 | # MAGIC ### This notebook sets some parameters for Databricks jobs wrapping the top-level entry point notebooks.
5 | # MAGIC Only simple setting of variables is allowed in this notebook.
6 |
7 | # COMMAND ----------
8 |
9 | # DBTITLE 1,Global settings
10 | # MAGIC %md
11 | # MAGIC spark_version can be either "6.6.x-scala2.11" (spark 2) or "9.1.x-scala2.12" (spark 3).
12 | # MAGIC This applies to all jobs created
13 |
14 | # COMMAND ----------
15 |
16 | spark_version = "6.6.x-scala2.11"
17 |
18 | # COMMAND ----------
19 |
20 | # DBTITLE 1,init_schemas
21 | # MAGIC %md
22 | # MAGIC Available parameters:
23 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
24 |
25 | # COMMAND ----------
26 |
27 | init_schemas = {
28 | "retain_cluster": False
29 | }
30 |
31 | # COMMAND ----------
32 |
33 | # DBTITLE 1,run_notebooks
34 | # MAGIC %md
35 | # MAGIC Available parameters:
36 | # MAGIC - **concurrency**: Integer between 1 and 10. Allows you to run multiple *run_notebooks* jobs at the same time.
37 | # MAGIC - **extra_parameters**: Dictionary(String, String) that maps *parameter names* to *default values*. These parameters are added to the list of parameters for the job.
38 | # MAGIC - **schedule**: A quartz cron syntax on when to run - see https://www.freeformatter.com/cron-expression-generator-quartz.html for cron syntax.
39 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
40 |
41 | # COMMAND ----------
42 |
43 | run_notebooks = {
44 | "concurrency": 1,
45 | "extra_parameters": {
46 | "included_tables": "",
47 | "excluded_tables": "",
48 | },
49 | "retain_cluster": False
50 | }
51 |
52 | # COMMAND ----------
53 |
54 | # DBTITLE 1,tool_config
55 | # MAGIC %md
56 | # MAGIC Available parameters:
57 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
58 |
59 | # COMMAND ----------
60 |
61 | # Example:
62 | # tool_config = {
63 | # "retain_cluster": True
64 | # }
65 | #tool_config = {
66 | #}
--------------------------------------------------------------------------------
/projects/artificial_hes/init_schemas.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md ## Artificial HES: Init Schemas
3 | # MAGIC This is the init_schemas notebook for the artificial_hes Code Promotion project
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ## Setup
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %md
13 | # MAGIC ### Code Promotion Project Paths
14 |
15 | # COMMAND ----------
16 |
17 | # MAGIC %run ./notebooks/code_promotion_paths
18 |
19 | # COMMAND ----------
20 |
21 | # MAGIC %md
22 | # MAGIC ### Imports
23 |
24 | # COMMAND ----------
25 |
26 | # MAGIC %run ./notebooks/common/widget_utils
27 |
28 | # COMMAND ----------
29 |
30 | import pprint
31 |
32 | # COMMAND ----------
33 |
34 | # MAGIC %md
35 | # MAGIC ### Widgets
36 |
37 | # COMMAND ----------
38 |
39 | if check_databricks():
40 | dbutils.widgets.text("notebook_root", "artificial_hes/dev", "0.1 Notebook Root")
41 | dbutils.widgets.text("db", "artificial_hes", "0.2 Project Database")
42 | dbutils.widgets.text("iuod_artificial_data_generator", "iuod_artificial_data_generator", "0.3 ADG Project Database")
43 | else:
44 | # Only make widgets in databricks
45 | pass
46 |
47 | # COMMAND ----------
48 |
49 | # MAGIC %md
50 | # MAGIC ## Main
51 | # MAGIC Run schema notebooks
52 |
53 | # COMMAND ----------
54 |
55 | notebook_root = get_required_argument("notebook_root")
56 | *_, project_name, project_version = notebook_root.split("/")
57 |
58 | cp_project_params = {}
59 | cp_project_params["project_name"] = project_name
60 | cp_project_params["project_version"] = project_version
61 | cp_project_params["adg_project_path"] = get_adg_project_path(project_version)
62 |
63 | # COMMAND ----------
64 |
65 | print(f"Running with context:")
66 | pprint.pprint(cp_project_params)
67 |
68 | # COMMAND ----------
69 |
70 | database_name = get_required_argument("db")
71 | adg_schemas_path = cp_project_params["adg_project_path"] / "schemas"
72 | artificial_hes_schema_coercion_path = str(adg_schemas_path / "uplifts" / "artificial_hes_schema_coercion")
73 |
74 | notebook_params = [
75 | {
76 | "path": "./schemas/create_user_docs",
77 | "timeout_seconds": 0,
78 | "arguments": {"database_name": database_name}
79 | },
80 | {
81 | "path": artificial_hes_schema_coercion_path,
82 | "timeout_seconds": 0,
83 | "arguments": {"database_name": database_name}
84 | },
85 | ]
86 |
87 | # COMMAND ----------
88 |
89 | for params in notebook_params:
90 | dbutils.notebook.run(**params)
--------------------------------------------------------------------------------
/projects/artificial_hes/notebooks/code_promotion_paths.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | import pathlib
3 |
4 | ADG_PROJECT_VERSION = "4497+20221019131315.git06db5e6f6"
5 | ADG_PROJECT_NAME = "iuod_artificial_data_generator"
6 | CODE_PROMOTION_RELEASES_PATH = pathlib.Path("/Users/admin/releases/code-promotion")
7 |
8 |
9 | def get_adg_env(project_version: str) -> str:
10 | if project_version == "dev":
11 | return "dev"
12 | elif project_version == "staging":
13 | return "staging"
14 | else:
15 | return "release"
16 |
17 |
18 | def get_adg_project_path(project_version: str) -> pathlib.Path:
19 | adg_env = get_adg_env(project_version)
20 |
21 | if adg_env == "release":
22 | adg_project_path = CODE_PROMOTION_RELEASES_PATH / ADG_PROJECT_NAME / ADG_PROJECT_VERSION
23 | elif adg_env == "staging":
24 | root_path = pathlib.Path("/staging")
25 | adg_project_path = root_path / ADG_PROJECT_NAME
26 | elif adg_env == "dev":
27 | root_path = pathlib.Path("../")
28 | adg_project_path = root_path / ADG_PROJECT_NAME
29 |
30 | return adg_project_path
--------------------------------------------------------------------------------
/projects/artificial_hes/notebooks/common/coerce_schema.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from pyspark.sql import DataFrame, types as T, functions as F
3 |
4 |
5 | def coerce_schema(df: DataFrame, target_schema: T.StructType) -> DataFrame:
6 | """Transform the fields of a dataframe to a target schema (on a field-by-field basis).
7 |
8 | Fields that exist in the target schema but which are not present in the original dataframe
9 | will be added as nulls if the field is nullable. Otherwise a ValueError is raised
10 |
11 | Fields that exist in the original dataframe but not the target schema will be excluded
12 | from the result
13 |
14 | Args:
15 | df (pyspark.sql.DataFrame): Dataframe to transform
16 | target_schema (pyspark.sql.types.StructType): Intended schema of the result
17 |
18 | Raises:
19 | ValueError: If a non-nullable field exists in the target schema which is not in
20 | the original dataframe
21 |
22 | Returns:
23 | DataFrame: transformed dataframe
24 | """
25 | selected_fields = []
26 |
27 | for field in target_schema.fields:
28 | if field.name in df.columns:
29 | value_col = F.col(field.name)
30 | else:
31 | if field.nullable:
32 | value_col = F.lit(None)
33 | else:
34 | raise ValueError(f"Field with name `{field.name}` is not nullable in target schema, but no column exists on provided dataframe!")
35 |
36 | selected_fields.append(value_col.cast(field.dataType).alias(field.name))
37 |
38 | return df.select(*selected_fields)
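
# Illustrative usage sketch, not from the original notebook: assumes a live
# SparkSession named `spark`, as in the surrounding notebooks.
import pyspark.sql.types as T

# Target: ID is required, AGE is nullable and will be added as nulls if missing
target_schema = T.StructType([
    T.StructField("ID", T.StringType(), False),
    T.StructField("AGE", T.IntegerType(), True),
])

df = spark.createDataFrame([("a", "x")], "ID: string, EXTRA: string")
result = coerce_schema(df, target_schema)
result.printSchema()  # ID: string, AGE: int (all null); EXTRA is dropped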
--------------------------------------------------------------------------------
/projects/artificial_hes/notebooks/common/common_exports.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./table_helpers
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ./widget_utils
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ./spark_helpers
11 |
12 | # COMMAND ----------
13 |
14 | # MAGIC %run ./coerce_schema
--------------------------------------------------------------------------------
/projects/artificial_hes/notebooks/common/widget_utils.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | import os
3 |
4 | # COMMAND ----------
5 |
6 | def check_databricks() -> bool:
7 | """Check whether the code is running in the databricks environment
8 |
9 | Returns:
10 | bool: indicates whether the current environment is databricks
11 | """
12 | return os.environ.get("VIRTUAL_ENV") == "/databricks/python3"
13 |
14 |
15 | # COMMAND ----------
16 |
17 | def get_required_argument(name: str) -> str:
18 | """Get the value of a dbutils widget, raising an exception if
19 | the value is not provided.
20 |
21 | Args:
22 | name (str): name of the widget to get the value of
23 |
24 | Raises:
25 | NotImplementedError: running this function outside databricks
26 | ValueError: if a value is not provided for the widget
27 |
28 | Returns:
29 | str: value of the widget
30 | """
31 | if check_databricks():
32 | value = dbutils.widgets.get(name)
33 | else:
34 | # TODO
35 | raise NotImplementedError("No way to provide arguments outside of Databricks")
36 |
37 | try:
38 | assert value
39 | return value
40 | except AssertionError:
41 | raise ValueError(f"Argument '{name}' must be provided")
42 |
--------------------------------------------------------------------------------
/projects/artificial_hes/run_notebooks.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md ## Artificial HES: Run Notebooks
3 | # MAGIC
4 | # MAGIC This notebook copies tables containing artificial HES data from `iuod_artificial_data_generator` into a target database (by default this is the CP project database, namely `artificial_hes`).
5 |
6 | # COMMAND ----------
7 |
8 | # MAGIC %md
9 | # MAGIC ## Setup
10 |
11 | # COMMAND ----------
12 |
13 | # MAGIC %md
14 | # MAGIC ### Imports
15 |
16 | # COMMAND ----------
17 |
18 | # MAGIC %run ./notebooks/common/widget_utils
19 |
20 | # COMMAND ----------
21 |
22 | # MAGIC %run ./notebooks/common/table_helpers
23 |
24 | # COMMAND ----------
25 |
26 | from pprint import pprint
27 | from functools import reduce
28 |
29 | from pyspark.sql import DataFrame, functions as F
30 |
31 | # COMMAND ----------
32 |
33 | # MAGIC %md
34 | # MAGIC ### Widgets
35 |
36 | # COMMAND ----------
37 |
38 | if check_databricks():
39 | dbutils.widgets.removeAll()
40 | dbutils.widgets.text("notebook_root", "artificial_hes/dev", "0.1 Notebook Root")
41 | dbutils.widgets.text("db", "artificial_hes", "0.2 Project Database")
42 | dbutils.widgets.text("iuod_artificial_data_generator", "iuod_artificial_data_generator", "0.3 ADG Project Database")
43 | dbutils.widgets.text("included_tables", "", "1.1 Tables to copy (comma separated)")
44 | dbutils.widgets.text("excluded_tables", "", "1.2 Tables to exclude from copy (comma separated)")
45 | else:
46 | # Only make widgets in databricks
47 | pass
48 |
49 | # COMMAND ----------
50 |
51 | # MAGIC %md
52 | # MAGIC ## Main
53 |
54 | # COMMAND ----------
55 |
56 | source_database = get_required_argument("iuod_artificial_data_generator")
57 | target_database = get_required_argument("db")
58 |
59 | # COMMAND ----------
60 |
61 | # Parse the inclusions / exclusions
62 |
63 | included_tables = list(filter(lambda x: x, dbutils.widgets.get("included_tables").split(",")))
64 | excluded_tables = list(filter(lambda x: x, dbutils.widgets.get("excluded_tables").split(",")))
65 |
66 | tables_to_clone_df = (
67 | spark.sql(rf"SHOW TABLES IN {source_database}")
68 | .filter(F.col("tableName").rlike(r"artificial_hes_(apc|ae|op)_\d{4}"))
69 | .filter(~F.col("tableName").isin(excluded_tables))
70 | )
71 |
72 | if included_tables:
73 | tables_to_clone_df = tables_to_clone_df.filter(F.col("tableName").isin(included_tables))
74 |
75 | tables_to_clone = [row.tableName for row in tables_to_clone_df.collect()]
76 |
77 | print("The following tables will be cloned: ")
78 | pprint(tables_to_clone)
79 |
80 | print("\nThe following tables will be excluded: ")
81 | pprint(excluded_tables)
82 |
83 | # COMMAND ----------
84 |
85 | # NOTE: delta history logic commented out. We do want this but we will need to figure out how to ensure
86 | # table permissions work across different environments. Currently the artificial data table gets created
87 | # and the owner is set to admin or data-managers (on ref), but this then breaks the CP jobs permissions
88 | # because it no longer owns the table so can't call 'DESCRIBE HISTORY'. Ticket added to the backlog to
89 | # revisit this.
90 |
91 | # # Record the history of delta tables for users - only table owners can see the history otherwise!
92 | # history_cols = [
93 | # F.col("version").alias("VERSION"),
94 | # F.col("timestamp").alias("TIMESTAMP"),
95 | # ]
96 | # history_dfs = []
97 |
98 | for table_name in tables_to_clone:
99 | print(f"Cloning table `{table_name}` from database `{source_database}` to database `{target_database}`")
100 | source_df = spark.table(f"{source_database}.{table_name}")
101 | create_table(spark, source_df, target_database, table_name, mode="overwrite", overwriteSchema="true")
102 |
103 | # history_dfs.append(
104 | # spark.sql(f"DESCRIBE HISTORY {target_database}.{table_name}")
105 | # .select(
106 | # F.lit(table_name).alias("TABLE_NAME"),
107 | # *history_cols
108 | # )
109 | # )
110 |
111 | # history_df = reduce(DataFrame.union, history_dfs)
112 | # create_table(spark, history_df, target_database, "delta_history", mode="overwrite", overwriteSchema="true")
113 |
114 | # COMMAND ----------
115 |
116 |
--------------------------------------------------------------------------------
/projects/artificial_hes/run_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/projects/artificial_hes/tests/placeholder.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/projects/artificial_hes_meta/cp_config.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Code Promotion Job Configuration
4 | # MAGIC ### This notebook sets some parameters for Databricks jobs wrapping the top-level entry point notebooks.
5 | # MAGIC Only simple setting of variables is allowed in this notebook.
6 |
7 | # COMMAND ----------
8 |
9 | import os
10 |
11 | # COMMAND ----------
12 |
13 | # DBTITLE 1,Global settings
14 | # MAGIC %md
15 | # MAGIC spark_version can be either "6.6.x-scala2.11" (spark 2) or "9.1.x-scala2.12" (spark 3).
16 | # MAGIC This applies to all jobs created
17 |
18 | # COMMAND ----------
19 |
20 | spark_version = "6.6.x-scala2.11"
21 |
22 | # COMMAND ----------
23 |
24 | # DBTITLE 1,init_schemas
25 | # MAGIC %md
26 | # MAGIC Available parameters:
27 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
28 |
29 | # COMMAND ----------
30 |
31 | # Example:
32 | init_schemas = {
33 | "retain_cluster": False
34 | }
35 |
36 | # COMMAND ----------
37 |
38 | # DBTITLE 1,run_notebooks
39 | # MAGIC %md
40 | # MAGIC Available parameters:
41 | # MAGIC - **concurrency**: Integer between 1 and 10. Allows you to run multiple *run_notebooks* jobs at the same time.
42 | # MAGIC - **extra_parameters**: Dictionary(String, String) that maps *parameter names* to *default values*. These parameters are added to the list of parameters for the job.
43 | # MAGIC - **schedule**: A quartz cron syntax on when to run - see https://www.freeformatter.com/cron-expression-generator-quartz.html for cron syntax.
44 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
45 |
46 | # COMMAND ----------
47 |
48 | # Example:
49 | run_notebooks = {
50 | "concurrency": 1,
51 | "extra_parameters": {},
52 | "retain_cluster": False,
53 | }
54 |
55 | if os.getenv("env", "ref") == "prod":
56 | run_notebooks = {
57 | **run_notebooks,
58 | "num_workers": 8,
59 | "instance_type": "i3.2xlarge",
60 | }
61 |
62 | # COMMAND ----------
63 |
64 | # DBTITLE 1,tool_config
65 | # MAGIC %md
66 | # MAGIC Available parameters:
67 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
68 |
69 | # COMMAND ----------
70 |
71 | # Example:
72 | # tool_config = {
73 | # "retain_cluster": True
74 | # }
75 | #tool_config = {
76 | #}
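77 | 
78 | # COMMAND ----------
79 | 
80 | # Illustrative only (not an active setting for this project): the `schedule` parameter
81 | # described above takes a Quartz cron expression. For example, to run the run_notebooks
82 | # job daily at 04:00 you might add:
83 | # run_notebooks = {
84 | #     **run_notebooks,
85 | #     "schedule": "0 0 4 * * ?",  # seconds minutes hours day-of-month month day-of-week
86 | # }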
--------------------------------------------------------------------------------
/projects/artificial_hes_meta/init_schemas.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Artificial HES Meta: Init Schemas
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ## Setup
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %md
13 | # MAGIC ### Code Promotion Project Paths
14 |
15 | # COMMAND ----------
16 |
17 | # MAGIC %run ./notebooks/code_promotion_paths
18 |
19 | # COMMAND ----------
20 |
21 | # MAGIC %md
22 | # MAGIC ### Imports
23 |
24 | # COMMAND ----------
25 |
26 | # MAGIC %run ./notebooks/common/widget_utils
27 |
28 | # COMMAND ----------
29 |
30 | import pprint
31 |
32 | # COMMAND ----------
33 |
34 | # MAGIC %md
35 | # MAGIC ### Widgets
36 |
37 | # COMMAND ----------
38 |
39 | if check_databricks():
40 | dbutils.widgets.removeAll()
41 | dbutils.widgets.text("notebook_root", "artificial_hes_meta/dev", "0.1 Notebook Root")
42 | dbutils.widgets.text("db", "artificial_hes_meta", "0.2 Project Database")
43 | else:
44 | # Only make widgets in databricks
45 | pass
46 |
47 | # COMMAND ----------
48 |
49 | # MAGIC %md ## Main
50 |
51 | # COMMAND ----------
52 |
53 | notebook_root = get_required_argument("notebook_root")
54 | *_, project_name, project_version = notebook_root.split("/")
55 |
56 | cp_project_params = {}
57 | cp_project_params["project_name"] = project_name
58 | cp_project_params["project_version"] = project_version
59 | cp_project_params["adg_project_path"] = get_adg_project_path(project_version)
60 |
61 | # COMMAND ----------
62 |
63 | print(f"Running with context:")
64 | pprint.pprint(cp_project_params)
65 |
66 | # COMMAND ----------
67 |
68 | database_name = get_required_argument("db")
69 | adg_schemas_path = cp_project_params["adg_project_path"] / "schemas"
70 | open_data_metadata_uplift_path = str(adg_schemas_path / "uplifts" / "open_data_metadata_uplift")
71 |
72 | notebook_params = [
73 | {
74 | "path": open_data_metadata_uplift_path,
75 | "timeout_seconds": 0,
76 | "arguments": {
77 | "database_name": database_name,
78 | "table_name": "artificial_hes_meta",
79 | }
80 | },
81 | ]
82 |
83 | # COMMAND ----------
84 |
85 | for params in notebook_params:
86 | dbutils.notebook.run(**params)
87 |
88 | # COMMAND ----------
89 |
90 |
--------------------------------------------------------------------------------
/projects/artificial_hes_meta/notebooks/code_promotion_paths.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | import pathlib
3 |
4 | ADG_PROJECT_VERSION = "4497+20221019131315.git06db5e6f6"
5 | ADG_PROJECT_NAME = "iuod_artificial_data_generator"
6 | CODE_PROMOTION_RELEASES_PATH = pathlib.Path("/Users/admin/releases/code-promotion")
7 |
8 |
9 | def get_adg_env(project_version: str) -> str:
10 | if project_version == "dev":
11 | return "dev"
12 | elif project_version == "staging":
13 | return "staging"
14 | else:
15 | return "release"
16 |
17 |
18 | def get_adg_project_path(project_version: str) -> pathlib.Path:
19 | adg_env = get_adg_env(project_version)
20 |
21 | if adg_env == "release":
22 | adg_project_path = CODE_PROMOTION_RELEASES_PATH / ADG_PROJECT_NAME / ADG_PROJECT_VERSION
23 | elif adg_env == "staging":
24 | root_path = pathlib.Path("/staging")
25 | adg_project_path = root_path / ADG_PROJECT_NAME
26 | elif adg_env == "dev":
27 | root_path = pathlib.Path("../")
28 | adg_project_path = root_path / ADG_PROJECT_NAME
29 |
30 | return adg_project_path
--------------------------------------------------------------------------------
/projects/artificial_hes_meta/notebooks/common/coerce_schema.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from pyspark.sql import DataFrame, types as T, functions as F
3 |
4 |
5 | def coerce_schema(df: DataFrame, target_schema: T.StructType) -> DataFrame:
6 | """Transform the fields of a dataframe to a target schema (on a field-by-field basis).
7 |
8 | Fields that exist in the target schema but which are not present in the original dataframe
9 | will be added as nulls if the field is nullable. Otherwise a ValueError is raised
10 |
11 | Fields that exist in the original dataframe but not the target schema will be excluded
12 | from the result
13 |
14 | Args:
15 | df (pyspark.sql.DataFrame): Dataframe to transform
16 | target_schema (pyspark.sql.types.StructType): Intended schema of the result
17 |
18 | Raises:
19 | ValueError: If a non-nullable field exists in the target schema which is not in
20 | the original dataframe
21 |
22 | Returns:
23 | DataFrame: transformed dataframe
24 | """
25 | selected_fields = []
26 |
27 | for field in target_schema.fields:
28 | if field.name in df.columns:
29 | value_col = F.col(field.name)
30 | else:
31 | if field.nullable:
32 | value_col = F.lit(None)
33 | else:
34 | raise ValueError(f"Field with name `{field.name}` is not nullable in target schema, but no column exists on provided dataframe!")
35 |
36 | selected_fields.append(value_col.cast(field.dataType).alias(field.name))
37 |
38 | return df.select(*selected_fields)
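39 | 
40 | 
41 | # Usage sketch (illustrative only, not part of the project): coerce a dataframe onto a
42 | # target schema where one nullable field is missing and one extra column is dropped.
43 | # The names `example_df` and `target_schema` below are hypothetical.
44 | #
45 | # target_schema = T.StructType([
46 | #     T.StructField("ID", T.StringType(), nullable=False),
47 | #     T.StructField("SCORE", T.DoubleType(), nullable=True),
48 | # ])
49 | # example_df = spark.createDataFrame([("a", "extra")], ["ID", "UNUSED"])
50 | # coerce_schema(example_df, target_schema)  # -> columns: ID (string), SCORE (double, all null)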
--------------------------------------------------------------------------------
/projects/artificial_hes_meta/notebooks/common/common_exports.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./table_helpers
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ./widget_utils
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ./spark_helpers
11 |
12 | # COMMAND ----------
13 |
14 | # MAGIC %run ./coerce_schema
--------------------------------------------------------------------------------
/projects/artificial_hes_meta/notebooks/common/widget_utils.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | import os
3 |
4 | # COMMAND ----------
5 |
6 | def check_databricks() -> bool:
7 | """Check whether the code is running in the databricks environment
8 |
9 | Returns:
10 | bool: indicates whether the current environment is databricks
11 | """
12 | return os.environ.get("VIRTUAL_ENV") == "/databricks/python3"
13 |
14 |
15 | # COMMAND ----------
16 |
17 | def get_required_argument(name: str) -> str:
18 | """Get the value of a dbutils widget, raising an exception if
19 | the value is not provided.
20 |
21 | Args:
22 | name (str): name of the widget to get the value of
23 |
24 | Raises:
25 | NotImplementedError: running this function outside databricks
26 | ValueError: if a value is not provided for the widget
27 |
28 | Returns:
29 | str: value of the widget
30 | """
31 | if check_databricks():
32 | value = dbutils.widgets.get(name)
33 | else:
34 | # TODO
35 | raise NotImplementedError("No way to provide arguments outside of Databricks")
36 |
37 |     # An empty string counts as "not provided"
38 |     if not value:
39 |         raise ValueError(f"Argument '{name}' must be provided")
40 | 
41 |     return value
42 |
--------------------------------------------------------------------------------
/projects/artificial_hes_meta/run_notebooks.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Artificial HES Meta: Run Notebooks
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ## Setup
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %md
13 | # MAGIC ### Code Promotion Project Paths
14 |
15 | # COMMAND ----------
16 |
17 | # MAGIC %run ./notebooks/code_promotion_paths
18 |
19 | # COMMAND ----------
20 |
21 | # MAGIC %md
22 | # MAGIC ### Imports
23 |
24 | # COMMAND ----------
25 |
26 | # MAGIC %run ./notebooks/common/widget_utils
27 |
28 | # COMMAND ----------
29 |
30 | import pprint
31 | import json
32 |
33 | # COMMAND ----------
34 |
35 | # MAGIC %md
36 | # MAGIC ### Widgets
37 |
38 | # COMMAND ----------
39 |
40 | if check_databricks():
41 | dbutils.widgets.removeAll()
42 | dbutils.widgets.text("notebook_root", "artificial_hes_meta/dev", "0.1 Notebook Root")
43 | dbutils.widgets.text("db", "artificial_hes_meta", "0.2 Project Database")
44 | dbutils.widgets.text("hes", "hes", "0.3 HES Database")
45 | else:
46 | # Only make widgets in databricks
47 | pass
48 |
49 | # COMMAND ----------
50 |
51 | # MAGIC %md ## Main
52 |
53 | # COMMAND ----------
54 |
55 | notebook_root = get_required_argument("notebook_root")
56 | *_, project_name, project_version = notebook_root.split("/")
57 |
58 | cp_project_params = {}
59 | cp_project_params["project_name"] = project_name
60 | cp_project_params["project_version"] = project_version
61 | cp_project_params["adg_project_path"] = get_adg_project_path(project_version)
62 |
63 | # COMMAND ----------
64 |
65 | print(f"Running with context:")
66 | pprint.pprint(cp_project_params)
67 |
68 | # COMMAND ----------
69 |
70 | params = {
71 | "path": str(cp_project_params["adg_project_path"] / "notebooks" / "scraper_pipelines" / "hes" / "driver"),
72 | "timeout_seconds": 0,
73 | "arguments": {
74 | "hes_database": get_required_argument("hes"),
75 | "meta_database": get_required_argument("db"),
76 | }
77 | }
78 |
79 | dbutils.notebook.run(**params)
80 |
81 | # COMMAND ----------
82 |
83 |
--------------------------------------------------------------------------------
/projects/artificial_hes_meta/run_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/projects/artificial_hes_meta/schemas/placeholder.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/projects/artificial_hes_meta/tests/placeholder.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_admin/cp_config.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Code Promotion Job Configuration
4 | # MAGIC ### This notebook sets some parameters for Databricks jobs wrapping the top-level entry point notebooks.
5 | # MAGIC Only simple setting of variables is allowed in this notebook.
6 |
7 | # COMMAND ----------
8 |
9 | # DBTITLE 1,Global settings
10 | # MAGIC %md
11 | # MAGIC spark_version can be either "6.6.x-scala2.11" (spark 2) or "9.1.x-scala2.12" (spark 3).
12 | # MAGIC This applies to all jobs created
13 |
14 | # COMMAND ----------
15 |
16 | spark_version = "6.6.x-scala2.11"
17 | #spark_version = "9.1.x-scala2.12"
18 |
19 |
20 | # COMMAND ----------
21 |
22 | # DBTITLE 1,init_schemas
23 | # MAGIC %md
24 | # MAGIC Available parameters:
25 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
26 |
27 | # COMMAND ----------
28 |
29 | init_schemas = {
30 | "retain_cluster": False
31 | }
32 |
33 | # COMMAND ----------
34 |
35 | # DBTITLE 1,run_notebooks
36 | # MAGIC %md
37 | # MAGIC Available parameters:
38 | # MAGIC - **concurrency**: Integer between 1 and 10. Allows you to run multiple *run_notebooks* jobs at the same time.
39 | # MAGIC - **extra_parameters**: Dictionary(String, String) that maps *parameter names* to *default values*. These parameters are added to the list of parameters for the job.
40 | # MAGIC - **schedule**: A Quartz cron expression specifying when the job should run - see https://www.freeformatter.com/cron-expression-generator-quartz.html for the syntax.
41 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
42 |
43 | # COMMAND ----------
44 |
45 | # Example:
46 | run_notebooks = {
47 | "concurrency": 1,
48 | "extra_parameters": {
49 | "safety_dt": "",
50 | "source_database_name": "",
51 | "table_name": "",
52 | "reviewer_name": "",
53 | "reviewer_email": "",
54 | "reviewed_date": "",
55 | "approver_name": "",
56 | "approver_email": "",
57 | "approved_date": "",
58 | },
59 | "retain_cluster": False
60 | }
61 |
62 | # COMMAND ----------
63 |
64 | # DBTITLE 1,tool_config
65 | # MAGIC %md
66 | # MAGIC Available parameters:
67 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
68 |
69 | # COMMAND ----------
70 |
71 | # Example:
72 | # tool_config = {
73 | # "retain_cluster": True
74 | # }
75 | tool_config = {
76 | }
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_admin/init_schemas.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_admin/notebooks/code_promotion_paths.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | import pathlib
3 |
4 | # COMMAND ----------
5 |
6 | def get_adg_project_path(project_env: str, project_version: str) -> pathlib.Path:
7 | adg_project_name = "iuod_artificial_data_generator"
8 |
9 | if project_env == "release":
10 | root_path = pathlib.Path("/Users/admin/releases/code-promotion")
11 | adg_project_path = root_path / adg_project_name / project_version
12 | elif project_env == "staging":
13 | root_path = pathlib.Path("/staging")
14 | adg_project_path = root_path / adg_project_name
15 | elif project_env == "dev":
16 | root_path = pathlib.Path("../")
17 | adg_project_path = root_path / adg_project_name
18 |
19 | return adg_project_path
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_admin/notebooks/code_promotion_versions.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | ADG_PROJECT_VERSION = "4497+20221019131315.git06db5e6f6"
3 | print(f"iuod_artificial_data_generator: {ADG_PROJECT_VERSION}")
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_admin/notebooks/widget_utils.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | import os
3 |
4 | # COMMAND ----------
5 |
6 | def check_databricks() -> bool:
7 | return os.environ.get("VIRTUAL_ENV") == "/databricks/python3"
8 |
9 |
10 | # COMMAND ----------
11 |
12 | def get_required_argument(name: str) -> str:
13 | if check_databricks():
14 | value = dbutils.widgets.get(name)
15 | else:
16 | # TODO
17 | raise NotImplementedError("No way to provide arguments outside of Databricks")
18 |
19 |     # An empty string counts as "not provided"
20 |     if not value:
21 |         raise ValueError(f"Argument '{name}' must be provided")
22 | 
23 |     return value
24 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_admin/run_notebooks.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC ## Run Notebooks
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ## Setup
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %md
13 | # MAGIC ### Code Promotion Version
14 |
15 | # COMMAND ----------
16 |
17 | # MAGIC %run ./notebooks/code_promotion_versions
18 |
19 | # COMMAND ----------
20 |
21 | # MAGIC %md
22 | # MAGIC ### Imports
23 |
24 | # COMMAND ----------
25 |
26 | # MAGIC %run ./notebooks/code_promotion_paths
27 |
28 | # COMMAND ----------
29 |
30 | # MAGIC %run ./notebooks/widget_utils
31 |
32 | # COMMAND ----------
33 |
34 | import json
35 | from typing import Dict
36 |
37 | # COMMAND ----------
38 |
39 | # MAGIC %md
40 | # MAGIC ### Widgets
41 |
42 | # COMMAND ----------
43 |
44 | if check_databricks():
45 | # TODO: update these!
46 | dbutils.widgets.removeAll()
47 | dbutils.widgets.text("notebook_root", "iuod_artificial_data_admin/dev", "0.1 Notebook Root")
48 | dbutils.widgets.text("db", "iuod_artificial_data_admin", "0.2 Project Database")
49 | dbutils.widgets.text("iuod_artificial_data_generator", "iuod_artificial_data_generator", "0.3 ADG Project Database")
50 | dbutils.widgets.text("artificial_hes_meta", "artificial_hes_meta", "0.4 HES Meta Project Database")
51 | dbutils.widgets.text("safety_dt", "", "1.1 Safety Timestamp (UTC)")
52 | dbutils.widgets.text("source_database_name", "", "2.1 Source Database Name")
53 | dbutils.widgets.text("table_name", "", "2.2 Source / Target Table Name")
54 | dbutils.widgets.text("reviewer_name", "", "3.1 Reviewer Name")
55 | dbutils.widgets.text("reviewer_email", "", "3.2 Reviewer Email")
56 | dbutils.widgets.text("reviewed_date", "", "3.3 Reviewed Date")
57 | dbutils.widgets.text("approver_name", "", "3.4 Approver Name")
58 | dbutils.widgets.text("approver_email", "", "3.5 Approver Email")
59 | dbutils.widgets.text("approved_date", "", "3.6 Approved Date")
60 | else:
61 | # Only make widgets in databricks
62 | pass
63 |
64 | # COMMAND ----------
65 |
66 | # MAGIC %md
67 | # MAGIC ## Functions
68 |
69 | # COMMAND ----------
70 |
71 | def get_adg_env(project_version: str) -> str:
72 | if project_version == "dev":
73 | return "dev"
74 | elif project_version == "staging":
75 | return "staging"
76 | else:
77 | return "release"
78 |
79 | # COMMAND ----------
80 |
81 | def get_driver_arguments() -> Dict[str, str]:
82 |   source_database_name = get_required_argument("source_database_name")  # Must be one of the databases in the project permissions scope
83 | return {
84 | "safety_dt": get_required_argument("safety_dt"),
85 | "config_json": json.dumps(dict(
86 | source_database_name = source_database_name,
87 | target_database_name = get_required_argument("iuod_artificial_data_generator"),
88 | table_name = get_required_argument("table_name"),
89 | reviewer_name = get_required_argument("reviewer_name"),
90 | reviewer_email = get_required_argument("reviewer_email"),
91 | reviewed_date = get_required_argument("reviewed_date"),
92 | approver_name = get_required_argument("approver_name"),
93 | approver_email = get_required_argument("approver_email"),
94 | approved_date = get_required_argument("approved_date"),
95 | releases_database_name = get_required_argument("db"),
96 | ))
97 | }
98 |
99 | # COMMAND ----------
100 |
101 | # MAGIC %md
102 | # MAGIC ## Main
103 |
104 | # COMMAND ----------
105 |
106 | # TODO: wrap in try catch
107 | driver_args = get_driver_arguments() # All args must be provided to proceed!
108 |
109 | # COMMAND ----------
110 |
111 | notebook_response = {}
112 |
113 | # COMMAND ----------
114 |
115 | notebook_root = get_required_argument("notebook_root")
116 | *_, project_name, project_version = notebook_root.split("/")
117 |
118 | adg_env = get_adg_env(project_version)
119 | adg_project_path = get_adg_project_path(adg_env, ADG_PROJECT_VERSION)
120 |
121 | # COMMAND ----------
122 |
123 | driver_params = {
124 | "path": str(adg_project_path / "notebooks" / "admin" / "driver"),
125 | "timeout_seconds": 0,
126 |   "arguments": driver_args  # validated above, before any notebooks are run
127 | }
128 | driver_response_json = dbutils.notebook.run(**driver_params)
129 |
130 | notebook_response["status"] = "success"
131 | notebook_response["message"] = ""
132 | notebook_response["children"] = [{**driver_params, "response": json.loads(driver_response_json)}]
133 |
134 | # COMMAND ----------
135 |
136 | dbutils.notebook.exit(json.dumps(notebook_response))
137 |
138 | # COMMAND ----------
139 |
140 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_admin/run_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_admin/schemas/placeholder.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_admin/tests/placeholder.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/cp_config.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Code Promotion Job Configuration
4 | # MAGIC ### This notebook sets some parameters for Databricks jobs wrapping the top-level entry point notebooks.
5 | # MAGIC Only simple setting of variables is allowed in this notebook.
6 |
7 | # COMMAND ----------
8 |
9 | import os
10 |
11 | # COMMAND ----------
12 |
13 | # DBTITLE 1,Global settings
14 | # MAGIC %md
15 | # MAGIC spark_version can be either "6.6.x-scala2.11" (spark 2) or "9.1.x-scala2.12" (spark 3).
16 | # MAGIC This applies to all jobs created
17 |
18 | # COMMAND ----------
19 |
20 | spark_version = "6.6.x-scala2.11"
21 |
22 | # COMMAND ----------
23 |
24 | # DBTITLE 1,init_schemas
25 | # MAGIC %md
26 | # MAGIC Available parameters:
27 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
28 |
29 | # COMMAND ----------
30 |
31 | init_schemas = {
32 | "retain_cluster": False
33 | }
34 |
35 | # COMMAND ----------
36 |
37 | # DBTITLE 1,run_notebooks
38 | # MAGIC %md
39 | # MAGIC Available parameters:
40 | # MAGIC - **concurrency**: Integer between 1 and 10. Allows you to run multiple *run_notebooks* jobs at the same time.
41 | # MAGIC - **extra_parameters**: Dictionary(String, String) that maps *parameter names* to *default values*. These parameters are added to the list of parameters for the job.
42 | # MAGIC - **schedule**: A Quartz cron expression specifying when the job should run - see https://www.freeformatter.com/cron-expression-generator-quartz.html for the syntax.
43 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
44 |
45 | # COMMAND ----------
46 |
47 | run_notebooks = {
48 | "concurrency": 1,
49 | "extra_parameters": {
50 | "artificial_dataset": "",
51 | "parameters_json": "{}",
52 | },
53 | "retain_cluster": False,
54 | }
55 |
56 | if os.getenv("env", "ref") == "prod":
57 | run_notebooks = {
58 | **run_notebooks,
59 | "num_workers": 8,
60 | "instance_type": "i3.2xlarge",
61 | }
62 |
63 | # COMMAND ----------
64 |
65 | # DBTITLE 1,tool_config
66 | # MAGIC %md
67 | # MAGIC Available parameters:
68 | # MAGIC - **retain_cluster**: boolean flag to indicate if existing cluster definition for the job is retained
69 |
70 | # COMMAND ----------
71 |
72 | # Example:
73 | # tool_config = {
74 | # "retain_cluster": True
75 | # }
76 | #tool_config = {
77 | #}
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/init_schemas.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Artificial Data Generator: Init Schemas
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ## Setup
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %md
13 | # MAGIC ### Imports
14 |
15 | # COMMAND ----------
16 |
17 | # MAGIC %run ./notebooks/common/widget_utils
18 |
19 | # COMMAND ----------
20 |
21 | # MAGIC %md
22 | # MAGIC ### Widgets
23 |
24 | # COMMAND ----------
25 |
26 | if check_databricks():
27 | dbutils.widgets.removeAll()
28 | dbutils.widgets.text("db", "iuod_artificial_data_generator", "0.1 Database")
29 | dbutils.widgets.text("notebook_root", "iuod_artificial_data_generator/dev", "0.2 Notebook Root")
30 | else:
31 | # Only make widgets in databricks
32 | pass
33 |
34 | # COMMAND ----------
35 |
36 | # MAGIC %md
37 | # MAGIC ## Main
38 |
39 | # COMMAND ----------
40 |
41 | database_name = get_required_argument("db")
42 | create_meta_notebook_path = "./schemas/create_meta_table"
43 | open_data_metadata_uplift_path = "./schemas/uplifts/open_data_metadata_uplift"
44 |
45 | notebook_params = [
46 | {
47 | "path": open_data_metadata_uplift_path,
48 | "timeout_seconds": 0,
49 | "arguments": {
50 | "database_name": database_name,
51 | "table_name": "artificial_hes_meta",
52 | }
53 | },
54 | {
55 | "path": create_meta_notebook_path,
56 | "timeout_seconds": 300,
57 | "arguments": {
58 | "database_name": database_name,
59 | "table_name": "artificial_hes_meta",
60 | }
61 | },
62 | ]
63 |
64 | # COMMAND ----------
65 |
66 | for params in notebook_params:
67 | dbutils.notebook.run(**params)
68 |
69 | # COMMAND ----------
70 |
71 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/admin/driver.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC ## Admin Driver
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ## Setup
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %md
13 | # MAGIC ### Imports
14 |
15 | # COMMAND ----------
16 |
17 | # MAGIC %run ../common/table_helpers
18 |
19 | # COMMAND ----------
20 |
21 | # MAGIC %run ../common/widget_utils
22 |
23 | # COMMAND ----------
24 |
25 | # Only the datetime class is used below
26 | from datetime import datetime
27 | import json
28 | from typing import Iterable, Dict
29 | from pyspark.sql import DataFrame, functions as F, Column
30 |
31 | # COMMAND ----------
32 |
33 | # MAGIC %md
34 | # MAGIC ### Widgets
35 |
36 | # COMMAND ----------
37 |
38 | if check_databricks():
39 | dbutils.widgets.removeAll()
40 | dbutils.widgets.text("safety_dt", "", "0 Safety Timestamp (UTC)")
41 | dbutils.widgets.text("config_json", "", "1 Config JSON")
42 | else:
43 | pass
44 |
45 | # COMMAND ----------
46 |
47 | # MAGIC %md
48 | # MAGIC ## Functions
49 |
50 | # COMMAND ----------
51 |
52 | def get_release_data(config: Dict[str, str]) -> DataFrame:
53 | reviewed_date = datetime.strptime(config["reviewed_date"], "%Y-%m-%d")
54 | approved_date = datetime.strptime(config["approved_date"], "%Y-%m-%d")
55 | release_cols = [
56 | F.current_timestamp() .alias("RELEASE_TIMESTAMP_UTC"),
57 | F.lit(config["source_database_name"]).alias("SOURCE_DATABASE_NAME"),
58 | F.lit(config["target_database_name"]).alias("TARGET_DATABASE_NAME"),
59 | F.lit(config["table_name"]) .alias("TABLE_NAME"),
60 | F.lit(config["reviewer_name"]) .alias("REVIEWER_NAME"),
61 | F.lit(config["reviewer_email"]) .alias("REVIEWER_EMAIL"),
62 | F.lit(reviewed_date).cast("date") .alias("REVIEWED_DATE"),
63 | F.lit(config["approver_name"]) .alias("APPROVER_NAME"),
64 | F.lit(config["approver_email"]) .alias("APPROVER_EMAIL"),
65 | F.lit(approved_date).cast("date") .alias("APPROVED_DATE"),
66 | ]
67 |
68 | return spark.range(1).select(*release_cols)
69 |
70 | # COMMAND ----------
71 |
72 | def check_safe() -> bool:
73 |   safety_window_seconds = 600  # the user-entered safety_dt must be within 10 minutes of the current time
74 | current_dt = datetime.now()
75 | safety_dt = datetime.strptime(get_required_argument("safety_dt"), "%Y-%m-%d %H:%M")
76 | safe = abs(current_dt - safety_dt).total_seconds() < safety_window_seconds
77 | return safe
78 |
79 | # COMMAND ----------
80 |
81 | # MAGIC %md
82 | # MAGIC ## Main
83 |
84 | # COMMAND ----------
85 |
86 | config = json.loads(get_required_argument("config_json"))
87 | source_table_path = f"{config['source_database_name']}.{config['table_name']}"
88 | target_table_path = f"{config['target_database_name']}.{config['table_name']}"
89 |
90 | # COMMAND ----------
91 |
92 | notebook_response = {}
93 |
94 | if check_safe():
95 | print("Metadata marked as safe")
96 |
97 | print(f"Creating release log entry in table '{config['releases_database_name']}.releases'")
98 |
99 | release_df = get_release_data(config) # Must be populated before release
100 | create_table(spark, release_df, config['releases_database_name'], "releases", format="delta", mode="append")
101 |
102 | print(f"Publishing from '{source_table_path}' to '{target_table_path}'")
103 |
104 | meta_df = spark.table(source_table_path)
105 |
106 |   # Insert the rows into the target table
107 |   # (the table should already have been created by the owner)
108 | insert_into(spark, meta_df, *target_table_path.split("."), overwrite=True)
109 | spark.sql(f"OPTIMIZE {target_table_path}")
110 |
111 | print("Done!")
112 |
113 | notebook_response["status"] = "success"
114 | notebook_response["message"] = ""
115 | else:
116 | notebook_response["status"] = "failed"
117 | notebook_response["message"] = "Failed safety check on user-entered timestamp"
118 |
119 | # COMMAND ----------
120 |
121 | dbutils.notebook.exit(json.dumps(notebook_response))
122 |
123 | # COMMAND ----------
124 |
125 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/common/coerce_schema.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from pyspark.sql import DataFrame, types as T, functions as F
3 |
4 |
5 | def coerce_schema(df: DataFrame, target_schema: T.StructType) -> DataFrame:
6 | """Transform the fields of a dataframe to a target schema (on a field-by-field basis).
7 |
8 | Fields that exist in the target schema but which are not present in the original dataframe
9 | will be added as nulls if the field is nullable. Otherwise a ValueError is raised
10 |
11 | Fields that exist in the original dataframe but not the target schema will be excluded
12 | from the result
13 |
14 | Args:
15 | df (pyspark.sql.DataFrame): Dataframe to transform
16 | target_schema (pyspark.sql.types.StructType): Intended schema of the result
17 |
18 | Raises:
19 | ValueError: If a non-nullable field exists in the target schema which is not in
20 | the original dataframe
21 |
22 | Returns:
23 | DataFrame: transformed dataframe
24 | """
25 | selected_fields = []
26 |
27 | for field in target_schema.fields:
28 | if field.name in df.columns:
29 | value_col = F.col(field.name)
30 | else:
31 | if field.nullable:
32 | value_col = F.lit(None)
33 | else:
34 | raise ValueError(f"Field with name `{field.name}` is not nullable in target schema, but no column exists on provided dataframe!")
35 |
36 | selected_fields.append(value_col.cast(field.dataType).alias(field.name))
37 |
38 | return df.select(*selected_fields)
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/common/common_exports.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./table_helpers
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ./widget_utils
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ./spark_helpers
11 |
12 | # COMMAND ----------
13 |
14 | # MAGIC %run ./coerce_schema
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/common/widget_utils.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | import os
3 |
4 | # COMMAND ----------
5 |
6 | def check_databricks() -> bool:
7 | """Check whether the code is running in the databricks environment
8 |
9 | Returns:
10 | bool: indicates whether the current environment is databricks
11 | """
12 | return os.environ.get("VIRTUAL_ENV") == "/databricks/python3"
13 |
14 |
15 | # COMMAND ----------
16 |
17 | def get_required_argument(name: str) -> str:
18 | """Get the value of a dbutils widget, raising an exception if
19 | the value is not provided.
20 |
21 | Args:
22 | name (str): name of the widget to get the value of
23 |
24 | Raises:
25 | NotImplementedError: running this function outside databricks
26 | ValueError: if a value is not provided for the widget
27 |
28 | Returns:
29 | str: value of the widget
30 | """
31 | if check_databricks():
32 | value = dbutils.widgets.get(name)
33 | else:
34 | # TODO
35 | raise NotImplementedError("No way to provide arguments outside of Databricks")
36 |
37 |     # An empty string counts as "not provided"
38 |     if not value:
39 |         raise ValueError(f"Argument '{name}' must be provided")
40 | 
41 |     return value
42 |
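43 | # COMMAND ----------
44 | 
45 | # Usage sketch (illustrative only): read a required widget value inside Databricks.
46 | # get_required_argument raises ValueError if the widget exists but is left blank.
47 | # The widget name "db" matches the one created by the entry-point notebooks.
48 | #
49 | # if check_databricks():
50 | #     database_name = get_required_argument("db")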
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_ae_tables.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from typing import List
3 |
4 |
5 | HES_AE_PRIMARY_KEY = "AEKEY"
6 | HES_AE_TABLE_NAMES = [
7 | "hes_ae_2021",
8 | "hes_ae_1920",
9 | "hes_ae_1819",
10 | "hes_ae_1718",
11 | "hes_ae_1617",
12 | "hes_ae_1516",
13 | "hes_ae_1415",
14 | "hes_ae_1314",
15 | "hes_ae_1213",
16 | "hes_ae_1112",
17 | "hes_ae_1011",
18 | "hes_ae_0910",
19 | "hes_ae_0809",
20 | "hes_ae_0708",
21 | ]
22 |
23 |
24 | def _get_hes_ae_key_fields(patient_table: Table) -> List[Field]:
25 | return [
26 | Field(HES_AE_PRIMARY_KEY, str, primary=True),
27 | Field(HES_PATIENT_KEY, str, foreign=patient_table[HES_PATIENT_KEY]),
28 | ]
29 |
30 |
31 | def get_hes_ae_tables(patient_table: Table) -> List[Table]:
32 | hes_ae_key_fields = _get_hes_ae_key_fields(patient_table)
33 | hes_ae_tables = []
34 |
35 | for table_name in HES_AE_TABLE_NAMES:
36 | hes_ae_tables.append(Table(table_name, *hes_ae_key_fields))
37 |
38 | return hes_ae_tables
39 |
40 |
41 | HES_AE_TABLES = get_hes_ae_tables(HES_PATIENTS_SYMBOLIC_TABLE)
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_apc_tables.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from typing import List
3 |
4 |
5 | HES_APC_PRIMARY_KEY = "EPIKEY"
6 | HES_APC_TABLE_NAMES = [
7 | "hes_apc_2021",
8 | "hes_apc_1920",
9 | "hes_apc_1819",
10 | "hes_apc_1718",
11 | "hes_apc_1617",
12 | "hes_apc_1516",
13 | "hes_apc_1415",
14 | "hes_apc_1314",
15 | "hes_apc_1213",
16 | "hes_apc_1112",
17 | "hes_apc_1011",
18 | "hes_apc_0910",
19 | "hes_apc_0809",
20 | "hes_apc_0708",
21 | "hes_apc_0607",
22 | "hes_apc_0506",
23 | "hes_apc_0405",
24 | "hes_apc_0304",
25 | "hes_apc_0203",
26 | "hes_apc_0102",
27 | ]
28 |
29 |
30 | def _get_hes_apc_key_fields(patient_table: Table) -> List[Field]:
31 | return [
32 | Field(HES_APC_PRIMARY_KEY, str, primary=True),
33 | Field(HES_PATIENT_KEY, str, foreign=patient_table[HES_PATIENT_KEY]),
34 | ]
35 |
36 |
37 | def get_hes_apc_tables(patient_table: Table) -> List[Table]:
38 | hes_apc_key_fields = _get_hes_apc_key_fields(patient_table)
39 | hes_apc_tables = []
40 |
41 | for table_name in HES_APC_TABLE_NAMES:
42 | hes_apc_tables.append(Table(table_name, *hes_apc_key_fields))
43 |
44 | return hes_apc_tables
45 |
46 |
47 | HES_APC_TABLES = get_hes_apc_tables(HES_PATIENTS_SYMBOLIC_TABLE)
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_op_tables.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from typing import List
3 |
4 |
5 | HES_OP_PRIMARY_KEY = "ATTENDKEY"
6 | HES_OP_TABLE_NAMES = [
7 | "hes_op_2021",
8 | "hes_op_1920",
9 | "hes_op_1819",
10 | "hes_op_1718",
11 | "hes_op_1617",
12 | "hes_op_1516",
13 | "hes_op_1415",
14 | "hes_op_1314",
15 | "hes_op_1213",
16 | "hes_op_1112",
17 | "hes_op_1011",
18 | "hes_op_0910",
19 | "hes_op_0809",
20 | "hes_op_0708",
21 | "hes_op_0607",
22 | "hes_op_0506",
23 | "hes_op_0405",
24 | "hes_op_0304",
25 | ]
26 |
27 |
28 | def _get_hes_op_key_fields(patient_table: Table) -> List[Field]:
29 | return [
30 | Field(HES_OP_PRIMARY_KEY, str, primary=True),
31 | Field(HES_PATIENT_KEY, str, foreign=patient_table[HES_PATIENT_KEY]),
32 | ]
33 |
34 |
35 | def get_hes_op_tables(patient_table: Table) -> List[Table]:
36 | hes_op_key_fields = _get_hes_op_key_fields(patient_table)
37 | hes_op_tables = []
38 |
39 | for table_name in HES_OP_TABLE_NAMES:
40 | hes_op_tables.append(Table(table_name, *hes_op_key_fields))
41 |
42 | return hes_op_tables
43 |
44 |
45 | HES_OP_TABLES = get_hes_op_tables(HES_PATIENTS_SYMBOLIC_TABLE)
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_patient_table.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../relational
3 |
4 | # COMMAND ----------
5 |
6 | # from dsp.common.relational import Table, Field
7 |
8 | HES_PATIENT_KEY = "PSEUDO_HESID"
9 |
10 |
11 | def get_hes_patients_symbolic_table() -> Table:
12 | """
13 | Note
14 | ----
15 | This function returns an object that symbolises the fact that there is
16 | no 'Master' table in HES. Patient IDs are duplicated across events, but in
17 | principle they are unique for each individual and the distinct values
18 | could be extracted into a table which represents the unique patients within HES.
19 |
20 |   We don't actually need to create this table; we just need to be aware
21 |   that its primary key is the foreign key on the actual HES tables,
22 |   used to specify the patient to which an episode relates.
23 |   """
24 | fields = [
25 | Field(HES_PATIENT_KEY, str, primary=True),
26 | Field("MYDOB", str),
27 | Field("SEX", str),
28 | Field("ETHNOS", str),
29 | ]
30 | return Table("hes_patients_symbolic", *fields)
31 |
32 |
33 | HES_PATIENTS_SYMBOLIC_TABLE = get_hes_patients_symbolic_table()
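34 | 
35 | # COMMAND ----------
36 | 
37 | # Illustrative sketch only: how an event-level table references the symbolic patient table.
38 | # This mirrors the pattern in hes_ae_tables / hes_apc_tables / hes_op_tables; the table and
39 | # key names below are hypothetical.
40 | #
41 | # example_event_table = Table(
42 | #     "hes_example_2021",
43 | #     Field("EXAMPLEKEY", str, primary=True),
44 | #     Field(HES_PATIENT_KEY, str, foreign=HES_PATIENTS_SYMBOLIC_TABLE[HES_PATIENT_KEY]),
45 | # )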
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_schemas.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./hes_ae_schema
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ./hes_apc_schema
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ./hes_op_schema
11 |
12 | # COMMAND ----------
13 |
14 | from pyspark.sql import types as T
15 |
16 | # COMMAND ----------
17 |
18 | def get_hes_schema(hes_dataset: str) -> T.StructType:
19 | if hes_dataset == "hes_ae":
20 | return get_hes_ae_schema()
21 | elif hes_dataset == "hes_op":
22 | return get_hes_op_schema()
23 | elif hes_dataset == "hes_apc":
24 | return get_hes_apc_schema()
25 | else:
26 | raise NotImplementedError("Valid options for hes_dataset are: hes_apc, hes_op, hes_ae")
27 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/dataset_definitions/hes/hes_tables.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./hes_patient_table
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ./hes_ae_tables
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ./hes_apc_tables
11 |
12 | # COMMAND ----------
13 |
14 | # MAGIC %run ./hes_op_tables
15 |
16 | # COMMAND ----------
17 |
18 | HES_TABLES = {
19 | "hes_ae": HES_AE_TABLES,
20 | "hes_apc": HES_APC_TABLES,
21 | "hes_op": HES_OP_TABLES,
22 | }
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/dataset_definitions/relational_helpers.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./relational
3 |
4 | # COMMAND ----------
5 |
6 | from typing import Generator
7 |
8 | # region Prevent downstream linting highlights
9 | Table = Table
10 | Field = Field
11 | # endregion
12 |
13 |
14 | def get_parent_table(table: Table):
15 |   for field in table.fields:
16 |     if field.foreign:
17 |       return field.foreign.table
18 |   # No foreign key found, so no parent table
19 |   return None
20 |
21 |
22 | def get_primary_key(table: Table) -> Field:
23 |   for field in table.fields:
24 |     if field.primary:
25 |       return field
26 |   # No primary key defined on this table
27 |   return None
28 |
29 |
30 | def get_foreign_keys(table: Table) -> Generator[Field, None, None]:
31 | return filter(lambda field: field.foreign is not None, table.fields)
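32 | 
33 | # COMMAND ----------
34 | 
35 | # Illustrative sketch only (hypothetical tables), assuming the Table/Field API behaves as it
36 | # is used elsewhere in this project: given a child table whose foreign key points at a parent
37 | # table, the helpers above recover the relationship.
38 | #
39 | # parent = Table("parent", Field("PARENT_ID", str, primary=True))
40 | # child = Table(
41 | #   "child",
42 | #   Field("CHILD_ID", str, primary=True),
43 | #   Field("PARENT_ID", str, foreign=parent["PARENT_ID"]),
44 | # )
45 | # get_primary_key(child)         # -> the CHILD_ID field
46 | # get_parent_table(child)        # -> parent
47 | # list(get_foreign_keys(child))  # -> [the PARENT_ID field]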
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/dependencies/random.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from __future__ import annotations
3 | from typing import Union, Iterable
4 | import itertools
5 | import random
6 |
7 | import numpy as np
8 | from sklearn.neighbors import KernelDensity
9 |
10 | from pyspark.sql import Column, functions as F, types as T
11 |
12 | # COMMAND ----------
13 |
14 | def weighted_choice_column(weighted_values: dict[Union[str, int], float], seed: int=None) -> Column:
15 | """Create a DataFrame column by sampling from weighted values
16 |
17 | Parameters
18 | ----------
19 | weighted_values: dict[Union[str, int], float]
20 | Values to sample mapped to probabilities/frequencies.
21 | seed: int (default=None)
22 | Random seed
23 |
24 | Returns
25 | -------
26 | pyspark.sql.Column
27 | Randomly chosen values
28 |
29 | Examples
30 | --------
31 | >>> # df is a pyspark DataFrame
32 | >>> # Add a column with a random choice between 1 & 2
33 | >>> df = df.withColumn('choice', weighted_choice_column({1: 0.5, 2: 0.5}))
34 | | ... | choice |
35 | | ... | 1 |
36 | | ... | 2 |
37 | | ... | ... |
38 | | ... | 2 |
39 | """
40 | values, weights = zip(*weighted_values.items())
41 | weights = list(map(float, weights))
42 |
43 | rv = sum(weights) * (F.lit(1.0) - F.rand(seed=seed))
44 | cum_weights = map(F.lit, itertools.accumulate(weights))
45 |
46 | # Choose the values based on the weights
47 | # Note: instead of using array, removing null values and picking the 0th element, you could use F.coalesce. However,
48 | # this was found to give skewed results with the resulting random values (why?), so the implementation here is a workaround
49 | choices_col = F.array(*itertools.starmap(lambda cum_weight, value: F.when(rv < cum_weight, F.lit(value)), zip(cum_weights, values)))
50 | choices_col = F.array_except(choices_col, F.array(F.lit(None)))[0]
51 |
52 | return choices_col.alias(f'weighted_choice({weighted_values}, seed={seed})')
53 |
54 |
55 | def uniform_choice_column(values: list[Union[int, str, float]], seed: int=None) -> Column:
56 | """Create a Column by sampling from values with equal probabilities.
57 |
58 | Parameters
59 | ----------
60 | values: list[Union[str, int, float]]
61 | Values to sample
62 | seed: int
63 | Random seed
64 |
65 | Returns
66 | -------
67 | pyspark.sql.Column
68 | Randomly chosen values
69 |
70 | Examples
71 | --------
72 | >>> # df is a pyspark DataFrame
73 | >>> # Add a column with a random choice between 1 & 2
74 | >>> df = df.withColumn('choice', uniform_choice_column([1, 2]))
75 | | ... | choice |
76 | | ... | 1 |
77 | | ... | 2 |
78 | | ... | ... |
79 | | ... | 2 |
80 | """
81 | weighted_values = dict(zip(values, itertools.repeat(1)))
82 |   return weighted_choice_column(weighted_values, seed=seed).alias(f'uniform_choice({values}, seed={seed})')
83 |
84 |
85 | def random_integer_column(
86 | start: Union[int, str, Column],
87 | stop: Union[int, str, Column],
88 | seed: int=None
89 | ) -> Column:
90 | """Create a column of random integers within a given range.
91 |
92 | Parameters
93 | ----------
94 | start: Union[int, str, Column]
95 | Start of range (inclusive). If start is an int then it is used as a literal value column.
96 | If it is a string it is taken to refer to a column name. Else it is used directly
97 | as a column expression.
98 | stop: Union[int, str, Column]
99 | End of range (inclusive). Similar interpretation to start.
100 | seed: int (default=None)
101 | Seed for the random number generator
102 |
103 | Returns
104 | -------
105 | pyspark.sql.Column
106 | Random numbers in the given range
107 | """
108 | alias = f'random_integer({start}, {stop}, seed={seed})'
109 |
110 | if isinstance(start, int):
111 | start = F.lit(start)
112 | elif isinstance(start, str):
113 | start = F.col(start)
114 |
115 | if isinstance(stop, int):
116 | stop = F.lit(stop)
117 | elif isinstance(stop, str):
118 | stop = F.col(stop)
119 |
120 | # Extend limits to ensure boundaries are equally likely
121 | stop = stop + F.lit(0.999999)
122 |
123 | return F.floor(start + ((stop - start) * F.rand(seed=seed))).cast(T.IntegerType()).alias(alias)
124 |
125 | # COMMAND ----------
126 |
127 | class SparkRandom(object):
128 | """Provides interface to methods analogous to those provided by Python's
129 | built-in random module but for operating on columns of a PySpark DataFrame.
130 | """
131 | def __init__(self, seed=None):
132 | self._seed = seed
133 |
134 | def seed(self, x):
135 | self._seed = x
136 |
137 | def choice(self, values: list[Union[str, int]]) -> Column:
138 | return uniform_choice_column(list(values), seed=self._seed)
139 |
140 | def choices(self, population: list[Union[str, int]], weights: list[Union[int, float]]) -> Column:
141 |     return weighted_choice_column(dict(zip(population, weights)), seed=self._seed)
142 |
143 | def randint(self, start: Union[int, str, Column], stop: Union[int, str, Column]) -> Column:
144 | return random_integer_column(start, stop, seed=self._seed)
145 |
146 | def shuffle(self, col: Union[str, Column]) -> Column:
147 | return F.shuffle(col)
148 |
149 |
150 | spark_random = SparkRandom()
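151 | 
152 | # COMMAND ----------
153 | 
154 | # Usage sketch (illustrative only): the functions above return lazily-evaluated column
155 | # expressions, so they can be attached to any dataframe. `df` below is hypothetical.
156 | #
157 | # spark_random.seed(42)
158 | # df = spark.range(5)
159 | # df = df.withColumn("sex", spark_random.choice(["1", "2"]))                 # uniform choice
160 | # df = df.withColumn("age", spark_random.randint(0, 120))                    # integer in [0, 120]
161 | # df = df.withColumn("flag", weighted_choice_column({"Y": 0.9, "N": 0.1}))   # weighted choice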
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/coerce_hes_schema.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../../dataset_definitions/hes/hes_schemas
3 |
4 | # COMMAND ----------
5 |
6 | from pyspark.sql import DataFrame, functions as F
7 |
8 | # region Prevent downstream linting highlights
9 | get_hes_schema = get_hes_schema
10 | # endregion
11 |
12 | def coerce_hes_schema(df: DataFrame, hes_dataset: str) -> DataFrame:
13 | """Transform the schema of a dataframe to match the schema of a given HES dataset
14 |
15 | Args:
16 | df (DataFrame): DataFrame to transform
17 | hes_dataset (str): HES dataset with the target schema. Either hes_apc, hes_op or hes_ae
18 |
19 | Returns:
20 | DataFrame: New DataFrame transformed to the target schema
21 | """
22 | current_fields = {f.name: f for f in df.schema.fields}
23 | coerced_fields = []
24 |
25 | for field in get_hes_schema(hes_dataset).fields:
26 | target_datatype = field.dataType
27 |
28 | if field.name not in current_fields:
29 | print(f"Warning: field `{field.name}` missing from schema! Adding field and filling with nulls")
30 | value_col = F.lit(None)
31 | else:
32 | value_col = F.col(field.name)
33 |
34 | coerced_fields.append(value_col.cast(target_datatype).alias(field.name))
35 |
36 | coerced_df = df.select(*coerced_fields)
37 |
38 | return coerced_df
39 |
40 | # COMMAND ----------
41 |
42 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/driver.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../../dataset_definitions/hes/hes_tables
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ./coerce_hes_schema
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ./field_definitions/derivations/derivations_helpers
11 |
12 | # COMMAND ----------
13 |
14 | # MAGIC %run ./field_definitions/id_fields
15 |
16 | # COMMAND ----------
17 |
18 | # MAGIC %run ./demographic_field_generators
19 |
20 | # COMMAND ----------
21 |
22 | # MAGIC %run ../driver_imports
23 |
24 | # COMMAND ----------
25 |
26 | import json
27 | from pprint import pprint
28 | from pyspark.sql import DataFrame, functions as F
29 |
30 | # Prevent spark from auto broadcasting the metadata
31 | spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
32 |
33 | ARTIFICIAL_TABLE_NAME_TEMPLATE = "artificial_{table_name}"
34 |
35 | # COMMAND ----------
36 |
37 | dbutils.widgets.removeAll()
38 |
39 | # COMMAND ----------
40 |
41 | dbutils.widgets.text("meta_database", "iuod_artificial_data_generator", "0.1 Meta DB")
42 | dbutils.widgets.text("artificial_database", "iuod_artificial_data_generator", "0.2 Artificial DB")
43 | dbutils.widgets.text("parameters_json", json.dumps({"n_patients": 100}), "1.1 Parameters JSON")
44 |
45 | # COMMAND ----------
46 |
47 | # Config
48 | meta_database = get_required_argument("meta_database")
49 | artificial_database = get_required_argument("artificial_database")
50 |
51 | notebook_parameters = json.loads(get_required_argument("parameters_json"))
52 | n_patients = notebook_parameters['n_patients']
53 |
54 | print("Running with parameters: ")
55 | pprint({
56 | "n_patients": n_patients,
57 | "artificial_database": artificial_database,
58 | "meta_database": meta_database,
59 | })
60 |
61 | # COMMAND ----------
62 |
63 | variable_fields = ["FIELD_NAME"]
64 | index_fields = ["ARTIFICIAL_DEMOGRAPHIC_ID", "ARTIFICIAL_EPISODE_ID"]
65 |
66 |
67 | def generate_artificial_hes_table(meta_df: DataFrame, hes_dataset: str, hes_table: Table, target_database: str) -> bool:
68 | # Create base dataframe for sampling onto
69 | base_demographic_df = (
70 | spark.range(n_patients)
71 | .select(F.col("id").alias(index_fields[0]))
72 | )
73 |
74 | # Generate ids for sampling based on the relationships between patients and episodes
75 | base_demographic_episode_df = (
76 |     relationship_generator(meta_df, base_demographic_df, variable_fields)
77 | .select(
78 | index_fields[0],
79 | F.col("VALUE_NUMERIC").alias(index_fields[1]),
80 | )
81 | )
82 |
83 |   # Generate artificial values for the demographic and episode-level fields
84 |   artificial_demographic_df = demographic_field_generator(
85 |     meta_df,
86 | base_demographic_episode_df,
87 | "FIELD_NAME",
88 | index_fields,
89 | )
90 |   artificial_episode_df = field_generator(meta_df, base_demographic_episode_df, "FIELD_NAME")
91 | artificial_df = artificial_demographic_df.join(artificial_episode_df, how="outer", on=index_fields)
92 |
93 | # Postprocessing
94 | artificial_df = with_hes_id_fields(artificial_df, hes_dataset, hes_table, index_fields)
95 | derived_fields = get_derivations(hes_dataset)
96 | artificial_df = with_derived_fields(artificial_df, *derived_fields)
97 | artificial_df = coerce_hes_schema(artificial_df, hes_dataset)
98 |
99 | # Output
100 | success = create_table(
101 | spark,
102 | artificial_df,
103 | database_name=target_database,
104 | table_name=ARTIFICIAL_TABLE_NAME_TEMPLATE.format(table_name=hes_table.name),
105 | format="delta",
106 | mode="overwrite",
107 | overwriteSchema="true", # This doesn't work on ACL clusters!
108 | )
109 |
110 | return success
111 |
112 | # COMMAND ----------
113 |
114 | meta_table_fullname = f"{meta_database}.artificial_hes_meta"
115 | meta_df = spark.table(meta_table_fullname)
116 |
117 | # Generate artificial data for each table
118 | for hes_dataset, hes_dataset_tables in HES_TABLES.items():
119 | print(f"\nGenerating artificial data for dataset `{hes_dataset}`")
120 |
121 | for hes_table in hes_dataset_tables:
122 | print(f"\tGenerating artificial data representing table `{hes_table.name}` into database `{artificial_database}`")
123 |
124 | filtered_meta_df = (
125 | meta_df
126 | .filter(F.col("TABLE_NAME") == hes_table.name)
127 | .drop("TABLE_NAME")
128 | )
129 |
130 | if filtered_meta_df.first() is None:
131 | print(f"\t\tNo metadata found for table '{hes_table.name}': artificial data table will be empty!")
132 |
133 | success = generate_artificial_hes_table(filtered_meta_df, hes_dataset, hes_table, artificial_database)
134 |
135 |
136 | # COMMAND ----------
137 |
138 | notebook_response = {}
139 |
140 | # TODO: work into for loop above (e.g. try/except for each table and compile list of fails / successes)
141 | if success:
142 | notebook_response["status"] = "success"
143 | notebook_response["message"] = ""
144 |
145 | else:
146 | notebook_response["status"] = "failed"
147 | notebook_response["message"] = "Failed to write test data"
148 |
149 | dbutils.notebook.exit(json.dumps(notebook_response))
150 |
151 | # COMMAND ----------
152 |
153 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/age_fields.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from typing import Union
3 | from pyspark.sql import Column, functions as F
4 |
5 | AGE_IN_DAYS_CUTOFF = 120 * 365
6 |
7 | mydob_date = F.to_date("MYDOB", "MMyyyy")
8 |
9 |
10 | def get_fractional_age_field(ref_date: Union[str, Column], dob: Union[str, Column]) -> Column:
11 | age_in_days = F.datediff(ref_date, dob)
12 | return (
13 | F.when(age_in_days.isNull(), F.expr("null"))
14 | .when(age_in_days < 1, 0.002)
15 | .when(age_in_days <= 6, 0.010)
16 | .when(age_in_days <= 28, 0.048)
17 | .when(age_in_days <= 90, 0.167)
18 | .when(age_in_days <= 181, 0.375)
19 | .when(age_in_days <= 272, 0.625)
20 | .when(age_in_days <= 365 - 1, 0.875)
21 | .when(age_in_days >= AGE_IN_DAYS_CUTOFF, 120)
22 | .otherwise(age_in_days / 365)
23 | .cast("double")
24 | )
25 |
26 |
27 | def get_categorized_age_field(ref_date: Union[str, Column], dob: Union[str, Column]) -> Column:
28 | age_in_days = F.datediff(ref_date, dob)
29 | return (
30 | F.when(age_in_days.isNull(), F.expr("null"))
31 | .when(age_in_days < 1, 7001)
32 | .when(age_in_days <= 6, 7002)
33 | .when(age_in_days <= 28, 7003)
34 | .when(age_in_days <= 90, 7004)
35 | .when(age_in_days <= 181, 7005)
36 | .when(age_in_days <= 272, 7006)
37 | .when(age_in_days <= 365-1, 7007)
38 | .when(age_in_days >= AGE_IN_DAYS_CUTOFF, 120)
39 | .otherwise(F.round(age_in_days / 365, 0))
40 | .cast("integer")
41 | )
42 |
43 |
44 | def get_fractional_from_categorized_age_field(categorized_age: str) -> Column:
45 | return (
46 | F.when(F.col(categorized_age).isNull(), F.expr("null"))
47 | .when(F.col(categorized_age) == 7001, 0.002)
48 | .when(F.col(categorized_age) == 7002, 0.010)
49 | .when(F.col(categorized_age) == 7003, 0.048)
50 | .when(F.col(categorized_age) == 7004, 0.167)
51 | .when(F.col(categorized_age) == 7005, 0.375)
52 | .when(F.col(categorized_age) == 7006, 0.625)
53 | .when(F.col(categorized_age) == 7007, 0.875)
54 | .otherwise(F.col(categorized_age))
55 | .cast("double")
56 | )
57 |
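58 | # Usage sketch (illustrative only): derive HES-style age fields from a reference date and a
59 | # date of birth. `EPISODE_START` is a hypothetical column name used purely for illustration;
60 | # `mydob_date` is the MYDOB-derived date of birth defined above, so the dataframe must
61 | # contain both MYDOB and the reference date column.
62 | #
63 | # df = df.withColumn("AGE_AT_START", get_fractional_age_field("EPISODE_START", mydob_date))
64 | # df = df.withColumn("AGE_BAND", get_categorized_age_field("EPISODE_START", mydob_date))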
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/coded_field_lists.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | opertn_4_fields = [
3 | "OPERTN_4_01",
4 | "OPERTN_4_02",
5 | "OPERTN_4_03",
6 | "OPERTN_4_04",
7 | "OPERTN_4_05",
8 | "OPERTN_4_06",
9 | "OPERTN_4_07",
10 | "OPERTN_4_08",
11 | "OPERTN_4_09",
12 | "OPERTN_4_10",
13 | "OPERTN_4_11",
14 | "OPERTN_4_12",
15 | "OPERTN_4_13",
16 | "OPERTN_4_14",
17 | "OPERTN_4_15",
18 | "OPERTN_4_16",
19 | "OPERTN_4_17",
20 | "OPERTN_4_18",
21 | "OPERTN_4_19",
22 | "OPERTN_4_20",
23 | "OPERTN_4_21",
24 | "OPERTN_4_22",
25 | "OPERTN_4_23",
26 | "OPERTN_4_24",
27 | ]
28 |
29 | diag_4_fields = [
30 | "DIAG_4_01",
31 | "DIAG_4_02",
32 | "DIAG_4_03",
33 | "DIAG_4_04",
34 | "DIAG_4_05",
35 | "DIAG_4_06",
36 | "DIAG_4_07",
37 | "DIAG_4_08",
38 | "DIAG_4_09",
39 | "DIAG_4_10",
40 | "DIAG_4_11",
41 | "DIAG_4_12",
42 | "DIAG_4_13",
43 | "DIAG_4_14",
44 | "DIAG_4_15",
45 | "DIAG_4_16",
46 | "DIAG_4_17",
47 | "DIAG_4_18",
48 | "DIAG_4_19",
49 | "DIAG_4_20",
50 | ]
51 |
52 | diag3_fields = [
53 | "DIAG3_01",
54 | "DIAG3_02",
55 | "DIAG3_03",
56 | "DIAG3_04",
57 | "DIAG3_05",
58 | "DIAG3_06",
59 | "DIAG3_07",
60 | "DIAG3_08",
61 | "DIAG3_09",
62 | "DIAG3_10",
63 | "DIAG3_11",
64 | "DIAG3_12",
65 | ]
66 |
67 | treat3_fields = [
68 | "TREAT3_01",
69 | "TREAT3_02",
70 | "TREAT3_03",
71 | "TREAT3_04",
72 | "TREAT3_05",
73 | "TREAT3_06",
74 | "TREAT3_07",
75 | "TREAT3_08",
76 | "TREAT3_09",
77 | "TREAT3_10",
78 | "TREAT3_11",
79 | "TREAT3_12",
80 | ]
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/coded_fields.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./coded_field_lists
3 |
4 | # COMMAND ----------
5 |
6 | from typing import Iterable, Union, List, Tuple
7 | from pyspark.sql import Column, functions as F
8 |
9 |
10 | def _concat_coded_fields(*fields: Union[str, Column]) -> Column:
11 | coalesced_fields = [F.coalesce(field, F.lit('')) for field in fields]
12 | concat_field_raw = F.concat_ws(",", *coalesced_fields)
13 | concat_field = F.regexp_replace(concat_field_raw, ',*$', '')
14 | return concat_field
15 |
16 |
17 | def _count_coded_fields(*fields: str) -> Column:
18 | count_field_raw = F.expr(f"cardinality(filter(array({','.join(fields)}), x -> x is not null AND x not in ('&', '-')))")
19 | count_field = (
20 | F.when(count_field_raw > 24, 24)
21 | .when(count_field_raw > 0, count_field_raw)
22 | .otherwise(None)
23 | )
24 | return count_field
25 |
26 |
27 | def _get_opertn_4_fields() -> List[str]:
28 | return opertn_4_fields # Global constant from coded_field_lists
29 |
30 |
31 | def get_opertn_34_fields() -> List[Tuple[str, Column]]:
32 | opertn_4_fields = _get_opertn_4_fields()
33 | opertn_3_fields = [(f"OPERTN_3_{i+1:02}", F.substring(x, 0, 3)) for i, x in enumerate(opertn_4_fields)]
34 | opertn_3_concat = ("OPERTN_3_CONCAT", _concat_coded_fields(*next(zip(*opertn_3_fields)))) # Extract just the field names from the (name, column) pairs
35 | opertn_4_concat = ("OPERTN_4_CONCAT", _concat_coded_fields(*opertn_4_fields))
36 | opertn_count = ("OPERTN_COUNT", _count_coded_fields(*opertn_4_fields))
37 | return [
38 | *opertn_3_fields,
39 | opertn_3_concat,
40 | opertn_4_concat,
41 | opertn_count,
42 | ]
43 |
44 |
45 | def _get_diag_4_fields(limit: int=20) -> List[str]:
46 | # NOTE: OP and APC have a different number of DIAG fields
47 | return diag_4_fields[:limit] # Global constant from coded_field_lists
48 |
49 |
50 | def get_diag_34_fields(limit: int=20) -> List[Tuple[str, Column]]:
51 | diag_4_fields = _get_diag_4_fields(limit=limit)
52 | diag_3_fields = [(f"DIAG_3_{i+1:02}", F.substring(x, 0, 3)) for i, x in enumerate(diag_4_fields)]
53 | diag_3_concat = ("DIAG_3_CONCAT", _concat_coded_fields(*next(zip(*diag_3_fields)))) # TODO: replace null 0th element with "R69" before concat
54 | diag_4_concat = ("DIAG_4_CONCAT", _concat_coded_fields(*diag_4_fields)) # TODO: replace null 0th element with "R69X" before concat
55 | diag_count = ("DIAG_COUNT", _count_coded_fields(*diag_4_fields))
56 | return [
57 | *diag_3_fields,
58 | diag_3_concat,
59 | diag_4_concat,
60 | diag_count,
61 | ]
62 |
63 | # COMMAND ----------
64 |
65 |
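
A minimal sketch of applying the derived coded fields above, assuming a Spark session (`spark`); the episode row and its codes are hypothetical, and only three DIAG positions are used to keep the example small:

    df = spark.createDataFrame(
        [("A011", "B509", None)],
        "DIAG_4_01 string, DIAG_4_02 string, DIAG_4_03 string",
    )

    # The (name, column) pairs are ordered so the DIAG_3_xx columns exist before the concat fields use them
    for name, col_spec in get_diag_34_fields(limit=3):
        df = df.withColumn(name, col_spec)

    # Adds DIAG_3_01..03 ("A01", "B50", null), DIAG_3_CONCAT = "A01,B50",
    # DIAG_4_CONCAT = "A011,B509" and DIAG_COUNT = 2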
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/derivations_helpers.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./coded_fields
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ./sequential_fields
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ./age_fields
11 |
12 | # COMMAND ----------
13 |
14 | from typing import List, Tuple
15 | from pyspark.sql import Column
16 |
17 |
18 | def _get_apc_derivations() -> List[Tuple[str, Column]]:
19 | startage_calc = ("STARTAGE_CALC", get_fractional_age_field("ADMIDATE", mydob_date))
20 | startage = ("STARTAGE", get_categorized_age_field("ADMIDATE", mydob_date))
21 | opertn_34_fields = get_opertn_34_fields()
22 | diag_34_fields = get_diag_34_fields()
23 | apc_sequential_fields = get_apc_sequential_fields()
24 | return [
25 | startage_calc,
26 | startage,
27 | *opertn_34_fields,
28 | *diag_34_fields,
29 | *apc_sequential_fields,
30 | ]
31 |
32 |
33 | def _get_op_derivations() -> List[Tuple[str, Column]]:
34 | apptage_calc = ("APPTAGE_CALC", get_fractional_from_categorized_age_field("APPTAGE"))
35 | opertn_34_fields = get_opertn_34_fields()
36 | diag_34_fields = get_diag_34_fields(limit=12) # OP only has 12 DIAG fields
37 | op_sequential_fields = get_op_sequential_fields()
38 | return [
39 | apptage_calc,
40 | *opertn_34_fields,
41 | *diag_34_fields,
42 | *op_sequential_fields,
43 | ]
44 |
45 |
46 | def _get_ae_derivations() -> List[Tuple[str, Column]]:
47 | arrivalage_calc = ("ARRIVALAGE_CALC", get_fractional_from_categorized_age_field("ARRIVALAGE"))
48 | diag2_fields = [(f"DIAG2_{i+1:02}", F.substring(x, 0, 3)) for i, x in enumerate(diag3_fields)]
49 | treat2_fields = [(f"TREAT2_{i+1:02}", F.substring(x, 0, 3)) for i, x in enumerate(treat3_fields)]
50 | ae_sequential_fields = get_ae_sequential_fields()
51 | return [
52 | *diag2_fields,
53 | *treat2_fields,
54 | arrivalage_calc,
55 | *ae_sequential_fields,
56 | ]
57 |
58 |
59 | def get_derivations(dataset: str) -> List[Tuple[str, Column]]:
60 | derivation_functions = {
61 | "hes_ae": _get_ae_derivations,
62 | "hes_apc": _get_apc_derivations,
63 | "hes_op": _get_op_derivations,
64 | }
65 | _get_derivations = derivation_functions.get(dataset, lambda: [])
66 | derived_fields = _get_derivations()
67 | return derived_fields
68 |
69 |
70 | # COMMAND ----------
71 |
72 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/sequential_field_lists.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # NOTE: "RTTPERSTART" & "RTTPEREND" appear in data dictionary but not APC, OP, AE tables for DAE version of the HES asset
3 | APC_SEQUENTIAL_DATE_FIELDS = ["ELECDATE", "ADMIDATE", "EPISTART", "EPIEND", "DISDATE"] # TODO: confirm the ordering
4 | AE_SEQUENTIAL_DATE_FIELDS = []
5 | OP_SEQUENTIAL_DATE_FIELDS = []
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/sequential_fields.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./sequential_field_lists
3 |
4 | # COMMAND ----------
5 |
6 | from typing import Iterable, List, Tuple
7 | from pyspark.sql import Column, functions as F
8 |
9 |
10 | def sort_fields_l2r(*fields: str) -> List[Tuple[str, Column]]: # TODO: pull this out into common functions
11 | sorted_array = F.array_sort(F.array(*fields))
12 | sorted_value_fields = [(x, F.element_at(sorted_array, i+1)) for i, x in enumerate(fields)]
13 | return sorted_value_fields
14 |
15 |
16 | def get_apc_sequential_fields() -> List[Tuple[str, Column]]:
17 | sorted_fields = sort_fields_l2r(*APC_SEQUENTIAL_DATE_FIELDS)
18 | return sorted_fields
19 |
20 |
21 | def get_op_sequential_fields() -> List[Tuple[str, Column]]:
22 | sorted_fields = sort_fields_l2r(*OP_SEQUENTIAL_DATE_FIELDS)
23 | return sorted_fields
24 |
25 |
26 | def get_ae_sequential_fields() -> List[Tuple[str, Column]]:
27 | sorted_fields = sort_fields_l2r(*AE_SEQUENTIAL_DATE_FIELDS)
28 | return sorted_fields
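
A minimal sketch of sort_fields_l2r on hypothetical, out-of-order episode dates, assuming a Spark session (`spark`); the expressions are applied in a single select so that each one sees the original column values:

    df = spark.createDataFrame([("2021-03-01", "2021-01-01")], "EPISTART string, EPIEND string")

    sorted_pairs = sort_fields_l2r("EPISTART", "EPIEND")
    df = df.select(*[col_spec.alias(name) for name, col_spec in sorted_pairs])
    # Row becomes ("2021-01-01", "2021-03-01"): values are reassigned left-to-right
    # in ascending order, so EPISTART never falls after EPIEND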
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/id_field_lists.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | APC_ID_FIELDS = [
3 | "SUSRECID",
4 | "PSEUDO_HESID",
5 | "AEKEY",
6 | "EPIKEY",
7 | ]
8 |
9 | OP_ID_FIELDS = [
10 | "PSEUDO_HESID",
11 | "ATTENDID",
12 | "ATTENDKEY",
13 | "PREFERER",
14 | ]
15 |
16 | AE_ID_FIELDS = [
17 | "AEKEY",
18 | "PSEUDO_HESID",
19 | "EPIKEY"
20 | ]
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/id_field_regex_patterns.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | ID_FIELD_PATTERNS = dict(
3 | PSEUDO_HESID = r"TEST[0-9a-zA-Z]{28}", # 32an - first 4 chars = TEST to ensure no overlap with real IDs
4 | SUSRECID = r"\d{14}",
5 | AEKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22
6 | EPIKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22
7 | ATTENDNO = r"[0-9a-zA-Z]{12}",
8 | ATTENDKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22
9 | ATTENDID = r"[0-9a-zA-Z]{12}",
10 | PREFERER = r"[0-9a-zA-Z]{16}", # What about nulls (&) / invalids (99)
11 | )
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/artificial_hes/field_definitions/id_fields.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../../../dependencies/spark_rstr
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../../../dataset_definitions/relational_helpers
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ./id_field_regex_patterns
11 |
12 | # COMMAND ----------
13 |
14 | # MAGIC %run ./id_field_lists
15 |
16 | # COMMAND ----------
17 |
18 | from typing import Iterable, Dict, List
19 |
20 | from pyspark.sql import DataFrame, Column, functions as F, Window as W
21 |
22 | # region Prevent downstream linter highlights
23 | get_parent_table = get_parent_table
24 | SparkRstr = SparkRstr
25 | ID_FIELD_PATTERNS = ID_FIELD_PATTERNS
26 | APC_ID_FIELDS = APC_ID_FIELDS
27 | AE_ID_FIELDS = AE_ID_FIELDS
28 | OP_ID_FIELDS = OP_ID_FIELDS
29 | Table = Table
30 | # endregion
31 |
32 |
33 | def _get_dataset_id_fields(field_names: Iterable[str]) -> Dict[str, Column]:
34 | rstr = SparkRstr()
35 | dataset_id_fields = set(field_names).intersection(ID_FIELD_PATTERNS.keys())
36 | return {field: rstr.xeger(ID_FIELD_PATTERNS[field]).alias(field) for field in dataset_id_fields}
37 |
38 |
39 | def _get_apc_id_fields() -> Dict[str, Column]:
40 | return _get_dataset_id_fields(APC_ID_FIELDS)
41 |
42 |
43 | def _get_ae_id_fields() -> Dict[str, Column]:
44 | return _get_dataset_id_fields(AE_ID_FIELDS)
45 |
46 |
47 | def _get_op_id_fields() -> Dict[str, Column]:
48 | return _get_dataset_id_fields(OP_ID_FIELDS)
49 |
50 |
51 | def get_id_fields(dataset: str) -> Dict[str, Column]:
52 | id_fields_functions = {
53 | "hes_ae": _get_ae_id_fields,
54 | "hes_apc": _get_apc_id_fields,
55 | "hes_op": _get_op_id_fields,
56 | }
57 | _get_id_fields = id_fields_functions.get(dataset, lambda: {})  # Default to an empty dict so downstream dict operations still work
58 | id_fields = _get_id_fields()
59 | return id_fields
60 |
61 |
62 | def with_hes_id_fields(df: DataFrame, dataset: str, table: Table, index_fields: Iterable[str]) -> DataFrame:
63 | demographic_table = get_parent_table(table)
64 | demographic_field_names = [f.name for f in demographic_table.fields]
65 |
66 | # Split demographic and episode id fields so that demographic ids can be carried over
67 | id_fields = get_id_fields(dataset)
68 | demographic_id_fields = [id_fields.pop(field_name) for field_name in demographic_field_names if field_name in id_fields]
69 | episode_id_fields = [field_value_col for _, field_value_col in id_fields.items()]
70 |
71 | demographic_id_window = W.partitionBy(index_fields[0]).orderBy(index_fields[1])
72 |
73 | df_with_id_fields = (
74 | df
75 | .withColumn("DEMOGRAPHIC_IDS", F.first(F.struct(*demographic_id_fields)).over(demographic_id_window))
76 | .select(*df.columns, "DEMOGRAPHIC_IDS.*", *episode_id_fields)
77 | .drop(*index_fields)
78 | )
79 |
80 | return df_with_id_fields
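
A minimal sketch of the ID generation above, assuming the SparkRstr dependency has been %run as in this notebook and a Spark session (`spark`) is available; the three-row base frame is hypothetical:

    base_df = spark.range(3)

    id_fields = get_id_fields("hes_apc")  # dict of field name -> xeger column (PSEUDO_HESID, EPIKEY, ...)
    df_with_ids = base_df.select("id", *id_fields.values())
    # Each row gets independently generated IDs, e.g. a PSEUDO_HESID matching r"TEST[0-9a-zA-Z]{28}"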
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/driver_imports.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../common/common_exports
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../dataset_definitions/relational_helpers
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ../generator_stages/field_generators
11 |
12 | # COMMAND ----------
13 |
14 | # MAGIC %run ../generator_stages/relationship_generator
15 |
16 | # COMMAND ----------
17 |
18 | # MAGIC %run ../generator_stages/derivations
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_pipelines/driver_parameters.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | import json
3 | from pprint import pprint
4 |
5 | # Prevent spark from auto broadcasting the metadata
6 | spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
7 |
8 | ARTIFICIAL_TABLE_NAME_TEMPLATE = "artificial_{table_name}"
9 |
10 | # COMMAND ----------
11 |
12 | dbutils.widgets.removeAll()
13 |
14 | # COMMAND ----------
15 |
16 | dbutils.widgets.text("meta_database", "iuod_artificial_data_generator", "0.1 Meta DB")
17 | dbutils.widgets.text("artificial_database", "iuod_artificial_data_generator", "0.2 Artificial DB")
18 | dbutils.widgets.text("parameters_json", "", "1.1 Parameters JSON")
19 |
20 | # COMMAND ----------
21 |
22 | # Config
23 | notebook_parameters = json.loads(get_required_argument("parameters_json"))
24 | n_patients = notebook_parameters['n_patients']
25 | artificial_database = get_required_argument("artificial_database")
26 | meta_database = get_required_argument("meta_database")
27 |
28 | print("Running with parameters: ")
29 | pprint({
30 | "n_patients": n_patients,
31 | "artificial_database": artificial_database,
32 | "meta_database": meta_database,
33 | })
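
For reference, the "1.1 Parameters JSON" widget is expected to hold a JSON object with at least an n_patients entry; the value below is purely illustrative:

    import json

    example_parameters_json = '{"n_patients": 1000}'  # illustrative widget value
    assert json.loads(example_parameters_json)["n_patients"] == 1000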
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_stages/derivations.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from typing import Iterable, Tuple
3 | from pyspark.sql import DataFrame, Column
4 |
5 |
6 | def with_derived_fields(df: DataFrame, *derived_fields: Iterable[Tuple[str, Column]]) -> DataFrame:
7 | """Add columns to a dataframe, with values possibly derived from existing columns.
8 |
9 | Args:
10 | df (DataFrame): DataFrame to add fields to
11 | *derived_fields (Iterable[Tuple[str, Column]]): Pairs of field_name, column_spec defining the columns to add
12 |
13 | Returns:
14 | DataFrame: New DataFrame with additional derived fields
15 |
16 | Notes:
17 | If it is not possible to add a particular column due to an exception, then a printout will appear
18 | and the column will be skipped. The resulting dataframe will not have this column added to it.
19 | """
20 | for field_name, col_spec in derived_fields:
21 | try:
22 | df = df.withColumn(field_name, col_spec)
23 | except Exception as e:
24 | # TODO: replace generic Exception
25 | print(f"Failed to derive field `{field_name}`: {e}")
26 | continue
27 |
28 | return df
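
A minimal sketch of with_derived_fields, assuming a Spark session (`spark`); the column names are hypothetical, and the second pair deliberately references a missing column to show the skip-and-continue behaviour:

    from pyspark.sql import functions as F

    df = spark.createDataFrame([(1,)], "A int")

    df = with_derived_fields(
        df,
        ("A_PLUS_ONE", F.col("A") + 1),           # added as normal
        ("BROKEN", F.col("DOES_NOT_EXIST") + 1),  # fails to resolve, so it is printed and skipped
    )
    # df now has columns A and A_PLUS_ONE only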
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_stages/relationship_generator.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./field_generators
3 |
4 | # COMMAND ----------
5 |
6 | from typing import Iterable
7 | from pyspark.sql import DataFrame, functions as F
8 |
9 | # COMMAND ----------
10 |
11 | def relationship_generator(meta_df: DataFrame, base_df: DataFrame, variable_fields: Iterable[str]) -> DataFrame:
12 | """Generate a dataframe of IDs according to the cardinal relationships between fields described by the metadata.
13 |
14 | Args:
15 | meta_df (pyspark.sql.DataFrame): Contains relationship metadata. This is a frequency distribution of
16 | cardinality, weight pairs - the weight for a given cardinality describes how likely it is the relationship
17 | between fields has that cardinality in the real data
18 | base_df (pyspark.sql.DataFrame): DataFrame containing a single column of IDs
19 | variable_fields (Iterable[str]): Name of fields in meta_df describing the variables to which the values relate
20 |
21 | Returns:
22 | pyspark.sql.DataFrame: Pairs of values for ID fields with the cardinal relationships described by the metadata.
23 | The generated IDs are stored in a column called 'VALUE_NUMERIC'
24 |
25 | """
26 | _relationship_generator = create_frequency_field_generator("RELATIONSHIP", "VALUE_NUMERIC")
27 | generated_df = _relationship_generator(meta_df, base_df, variable_fields)
28 |
29 | # Assign array of ids with length based on sampled value for the mapping
30 | # ids are created in chunks to avoid out of memory exceptions for large mapping values
31 | chunk_size = 100
32 | step_expr = f"int(least(greatest(1, VALUE_NUMERIC), {chunk_size}))"
33 | # Use the output ratio threshold to place a cap on the maximum cardinality that will be generated
34 | # If this cap is exceeded then the executors will fall over!
35 | output_ratio_threshold = spark.conf.get("spark.databricks.queryWatchdog.outputRatioThreshold")
36 | max_lower_bound_expr = f"int(least(coalesce(VALUE_NUMERIC, 0), {output_ratio_threshold}))"
37 |
38 | # Define the lower bounds for the id ranges
39 | id_lower_bound_col = F.expr(f"sequence(0, {max_lower_bound_expr}, {step_expr})")
40 |
41 | # Zip the bounds to get pairs of (lower, upper) for the ranges to generate
42 | id_bounds_col = F.expr("transform(slice(ID_LOWER_BOUNDS, 1, size(ID_LOWER_BOUNDS)-1), (x, i) -> struct(x AS LOWER, ID_LOWER_BOUNDS[i+1] AS UPPER))")
43 |
44 | generated_df = (
45 | generated_df
46 | .withColumn("ID_LOWER_BOUNDS", id_lower_bound_col)
47 | .withColumn("ID_BOUNDS", id_bounds_col)
48 | .select("*", F.expr("inline(ID_BOUNDS)"))
49 | .withColumn("VALUE_NUMERIC", F.expr("explode(sequence(LOWER, UPPER-1, 1))"))
50 | )
51 |
52 | return generated_df
53 |
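
The chunked ID construction above is the least obvious step here, so this is a self-contained sketch of just the bound-building expressions, assuming a Spark session (`spark`); the cardinality of 300 and the cap of 1000 are illustrative:

    from pyspark.sql import functions as F

    df = spark.createDataFrame([(300,)], "VALUE_NUMERIC int")
    chunk_size = 100
    cap = 1000  # stands in for the queryWatchdog outputRatioThreshold

    step_expr = f"int(least(greatest(1, VALUE_NUMERIC), {chunk_size}))"
    max_lower_bound_expr = f"int(least(coalesce(VALUE_NUMERIC, 0), {cap}))"

    chunked_df = (
        df
        .withColumn("ID_LOWER_BOUNDS", F.expr(f"sequence(0, {max_lower_bound_expr}, {step_expr})"))  # [0, 100, 200, 300]
        .withColumn("ID_BOUNDS", F.expr("transform(slice(ID_LOWER_BOUNDS, 1, size(ID_LOWER_BOUNDS)-1), (x, i) -> struct(x AS LOWER, ID_LOWER_BOUNDS[i+1] AS UPPER))"))  # [(0,100), (100,200), (200,300)]
        .select("*", F.expr("inline(ID_BOUNDS)"))
        .withColumn("VALUE_NUMERIC", F.expr("explode(sequence(LOWER, UPPER-1, 1))"))
    )
    # chunked_df has 300 rows with VALUE_NUMERIC = 0..299, and no single sequence(...) call
    # ever has to materialise more than chunk_size elements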
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_stages/sampling/cdf_sampling.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./field_definitions
3 |
4 | # COMMAND ----------
5 |
6 | from typing import Iterable
7 | from pyspark.sql import DataFrame, functions as F
8 |
9 | # COMMAND ----------
10 |
11 | def sample_cdf(
12 | meta_df: DataFrame,
13 | base_df: DataFrame,
14 | variable_fields: Iterable[str],
15 | value_field: str,
16 | frequency_field: str,
17 | bin_size: float=0.05
18 | ) -> DataFrame:
19 | """Generate data by sampling from a cumulative distribution over each field in the metadata.
20 |
21 | Args:
22 | meta_df (pyspark.sql.DataFrame): Contains metadata
23 | base_df (pyspark.sql.DataFrame): Contains rows to generate samples for
24 | variable_fields (Iterable[str]): Field labels, where each field has its own CDF to sample
25 | value_field (str): Values in the distribution (e.g. categories, bins etc)
26 | frequency_field (str): Frequencies to sum to compute the CDF
27 | bin_size (float (default = 0.05)): Bin size in the range_join, used for performance tuning
28 |
29 | Returns:
30 | pyspark.sql.DataFrame: Generated data
31 | """
32 | cumsum_col = cumsum_over_window(frequency_field, variable_fields, (frequency_field, value_field))
33 | total_col = total_over_window(frequency_field, variable_fields)
34 | cumdist_df = (
35 | meta_df
36 | .withColumn("CUMSUM", cumsum_col)
37 | .withColumn("TOTAL", total_col)
38 | .withColumn("CDF_BIN", get_cdf_bin(frequency_field=frequency_field))
39 | )
40 |
41 | # Select the sampled values from the cumulative distribution
42 | # Generate a random number and use to do a range join on the cumulative distribution
43 | # The sampled value is the one where the random number falls within the bin of the CDF corresponding
44 | # to that value
45 | generated_df = (
46 | base_df
47 | .withColumn("RAND", F.rand())
48 | .alias("left")
49 | .hint("range_join", bin_size)
50 | .join(
51 | cumdist_df.alias("right"),
52 | how="inner",
53 | on=[
54 | *[F.col(f"left.{field}") == F.col(f"right.{field}") for field in variable_fields],
55 | F.col("right.CDF_BIN.LOWER") <= F.col("left.RAND"),
56 | F.col("left.RAND") < F.col("right.CDF_BIN.UPPER"),
57 | ]
58 | )
59 | )
60 |
61 | # Clean up
62 | ignore_fields = ["CUMSUM", "TOTAL", "CDF_BIN", "RAND", *[f"right.{field}" for field in variable_fields]]
63 | for field in ignore_fields:
64 | generated_df = generated_df.drop(F.col(field))
65 |
66 | return generated_df
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/generator_stages/sampling/field_definitions.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from typing import Iterable, Union
3 | from pyspark.sql import Column, functions as F, Window as W
4 |
5 |
6 | def cumsum_over_window(
7 | sum_col: Column,
8 | partition_cols: Union[Iterable[str], Iterable[Column]],
9 | order_cols: Union[Iterable[str], Iterable[Column]],
10 | ) -> Column:
11 | """Calculate the cumulative sum of a column over a window.
12 |
13 | Args:
14 | sum_col (pyspark.sql.Column): Column to sum over the partitions
15 | partition_cols (Union[Iterable[str], Iterable[pyspark.sql.Column]]): Partitions to sum within
16 | order_cols (Union[Iterable[str], Iterable[pyspark.sql.Column]]): Ordering for the cumulative sum
17 |
18 | Returns:
19 | pyspark.sql.Column: Column of within-partition cumulative sums
20 | """
21 | w = (
22 | W.partitionBy(*partition_cols)
23 | .orderBy(*order_cols) # Ordering by value will be slow!
24 | .rangeBetween(W.unboundedPreceding, 0)
25 | )
26 | return F.sum(sum_col).over(w)
27 |
28 |
29 | def total_over_window(sum_col: Column, partition_cols: Union[Iterable[str], Iterable[Column]]) -> Column:
30 | """Calculate the total sum of a column over a window.
31 |
32 | Args:
33 | sum_col (pyspark.sql.Column): Column to sum over the partitions
34 | partition_cols (Union[Iterable[str], Iterable[pyspark.sql.Column]]): Partitions to sum within
35 |
36 | Returns:
37 | pyspark.sql.Column: Column of within-partition sums
38 |
39 | """
40 | w = (
41 | W.partitionBy(*partition_cols)
42 | .rangeBetween(W.unboundedPreceding, W.unboundedFollowing)
43 | )
44 | return F.sum(sum_col).over(w)
45 |
46 |
47 | def get_cdf_bin(
48 | frequency_field: str="WEIGHT",
49 | cumsum_field: str="CUMSUM",
50 | total_field: str="TOTAL"
51 | ) -> Column:
52 | lower_bound_col = (F.col(cumsum_field) - F.col(frequency_field)) / F.col(total_field)
53 | upper_bound_col = F.col(cumsum_field) / F.col(total_field)
54 | cdf_bin_col = F.struct(
55 | lower_bound_col.alias("LOWER"),
56 | upper_bound_col.alias("UPPER")
57 | )
58 | return cdf_bin_col
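
A minimal sketch showing how the three helpers above combine into CDF bins, assuming a Spark session (`spark`); the table/field names and weights are hypothetical:

    meta_df = spark.createDataFrame(
        [("t1", "f1", "CATEGORICAL", "A", 10.0),
         ("t1", "f1", "CATEGORICAL", "B", 30.0),
         ("t1", "f1", "CATEGORICAL", "C", 60.0)],
        "TABLE_NAME string, FIELD_NAME string, VALUE_TYPE string, VALUE_STRING string, WEIGHT double",
    )

    variable_fields = ["TABLE_NAME", "FIELD_NAME", "VALUE_TYPE"]
    cumdist_df = (
        meta_df
        .withColumn("CUMSUM", cumsum_over_window("WEIGHT", variable_fields, ("WEIGHT", "VALUE_STRING")))
        .withColumn("TOTAL", total_over_window("WEIGHT", variable_fields))
        .withColumn("CDF_BIN", get_cdf_bin(frequency_field="WEIGHT"))
    )
    # CDF_BIN spans [0.0, 0.1) for A, [0.1, 0.4) for B and [0.4, 1.0) for C, so a uniform
    # random number selects each value with probability proportional to its weight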
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/driver_imports.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../common/common_exports
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../scraper_stages/scrape_metadata
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ../dataset_definitions/relational_helpers
11 |
12 | # COMMAND ----------
13 |
14 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/hes/constants/disclosure_control_parameters.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | HES_FREQUENCY_ROUNDING = 5
3 | HES_MIN_FREQUENCY = 5
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/hes/constants/excluded_fields.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | ID_FIELDS = [
3 | "PSEUDO_HESID",
4 | "SUSRECID",
5 | "AEKEY",
6 | "EPIKEY",
7 | "ATTENDNO",
8 | "ATTENDKEY",
9 | "ATTENDID",
10 | "PREFERER",
11 | ]
12 |
13 | OTHER_SENSITIVE_FIELDS = []
14 |
15 | DERIVED_FIELDS = [
16 | *[f"OPERTN_3_{i+1:02}" for i in range(24)], # Derived from OPERTN_4 (APC & OP)
17 | *[f"DIAG_3_{i+1:02}" for i in range(20)], # Derived from DIAG_4 (APC & OP)
18 | *[f"DIAG2_{i+1:02}" for i in range(12)], # Derived from DIAG3 (AE)
19 | *[f"TREAT2_{i+1:02}" for i in range(12)], # Derived from TREAT3 (AE)
20 | "OPERTN_3_CONCAT",
21 | "OPERTN_4_CONCAT",
22 | "OPERTN_COUNT",
23 | "DIAG_3_CONCAT",
24 | "DIAG_4_CONCAT",
25 | "DIAG_COUNT",
26 | "STARTAGE_CALC", # Derived from MYDOB and ADMIDATE (APC)
27 | "STARTAGE", # Derived from MYDOB and ADMIDATE (APC)
28 | "APPTAGE_CALC", # Can't be derived from MYDOB because this isn't in HES Non Sensitive! Derive from APPTAGE (OP)
29 | # "APPTAGE", # Can't be derived from MYDOB because this isn't in HES Non Sensitive! (OP)
30 | "ARRIVALAGE_CALC", # Can't be derived from MYDOB because this isn't in HES Non Sensitive! Derive from ARRIVALAGE (AE)
31 | # "ARRIVALAGE", # Can't be derived from MYDOB because this isn't in HES Non Sensitive! (AE)
32 | ]
33 |
34 | EXCLUDED_FIELDS = [
35 | *ID_FIELDS,
36 | *OTHER_SENSITIVE_FIELDS,
37 | *DERIVED_FIELDS,
38 | ]
39 |
40 | # COMMAND ----------
41 |
42 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/hes/demographic_aggregation_parameters.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../../scraper_stages/aggregation/aggregation_parameters
3 |
4 | # COMMAND ----------
5 |
6 | from pyspark.sql import functions as F
7 |
8 | # region Prevent downstream linting highlights
9 | AggregationParametersType = AggregationParametersType
10 | get_frequency_aggregation_params = get_frequency_aggregation_params
11 | BASE_GROUPING_FIELDS = BASE_GROUPING_FIELDS
12 | weight_col = weight_col
13 | # endregion
14 |
15 |
16 | def get_demographic_categorical_aggregation_params(
17 | patient_index_field: str,
18 | frequency_rounding: int,
19 | min_frequency: int
20 | ) -> AggregationParametersType:
21 | """Get parameters defining the aggregations required to summarise 'DEMOGRAPHIC_CATEGORICAL'
22 | fields in long-formatted data. The result is a frequency distribution over weighted vectors,
23 | where the inner weights describe the proportion of events that a patient has recorded a given value
24 | and the outer weights indicate the number of patients with that pattern of inner weights across
25 | the dataset.
26 |
27 | Note: this logic is hes-specific.
28 |
29 | Args:
30 | patient_index_field (str): Field that indexes patients
31 | frequency_rounding (int): Rounding to apply to frequencies
32 | min_frequency (int): Cutoff below which frequencies are removed
33 |
34 | Returns:
35 | AggregationParametersType: Parameters defining the aggregations
36 | """
37 | # Aggregation 1: collect the values recorded for each patient and derive a normalised weight per distinct value
38 | params1 = dict(
39 | grouping_fields=[
40 | *BASE_GROUPING_FIELDS,
41 | patient_index_field,
42 | ],
43 | agg_cols=[
44 | F.collect_list(F.coalesce("VALUE_STRING", F.lit("null"))).alias("VALUE_STRING_ARRAY"), # Need nulls to be strings for aggregation
45 | weight_col,
46 | ],
47 | derivations=[
48 | F.array_sort(F.array_distinct("VALUE_STRING_ARRAY")).alias("DISTINCT_VALUE_STRING_ARRAY"),
49 | F.expr("transform(DISTINCT_VALUE_STRING_ARRAY, x -> round(size(filter(VALUE_STRING_ARRAY, y -> y == x)) / WEIGHT, 1))").alias("WEIGHT_ARRAY"), # Normalised weight per value
50 | F.expr("transform(DISTINCT_VALUE_STRING_ARRAY, x -> CASE WHEN x != 'null' THEN x END)").alias("VALUE_STRING_ARRAY"), # Replace nulls within the array of values
51 | ]
52 | )
53 |
54 | # Aggregation 2: count patients per distinct combinations of values and weights
55 | params2, = get_frequency_aggregation_params(
56 | "VALUE_STRING_ARRAY",
57 | frequency_rounding=frequency_rounding,
58 | min_frequency=min_frequency
59 | )
60 | params2["grouping_fields"].append("WEIGHT_ARRAY")
61 |
62 | return [
63 | params1,
64 | params2,
65 | ]
66 |
67 |
68 | def get_demographic_date_aggregation_params(
69 | patient_index_field: str,
70 | frequency_rounding: int,
71 | min_frequency: int
72 | ) -> AggregationParametersType:
73 | """Get parameters defining the aggregations required to summarise 'DEMOGRAPHIC_DATE'
74 | fields in long-formatted data. The result is a frequency distribution over dates
75 | with an associated incidence, which indicates how often a patient's records agree with
76 | their usual (mode) value. E.g. a patient's true DOB is 01/02/2022, but occasionally it
77 | could be recorded as 02/01/2022 or 01/02/2021.
78 |
79 | Note: this logic is hes-specific.
80 |
81 | Args:
82 | patient_index_field (str): Field that indexes patients
83 | frequency_rounding (int): Rounding to apply to frequencies
84 | min_frequency (int): Cutoff below which frequencies are removed
85 |
86 | Returns:
87 | AggregationParametersType: Parameters defining the aggregations
88 | """
89 | # Aggregation 1: frequency of each value per patient
90 | params1 = dict(
91 | grouping_fields=[
92 | *BASE_GROUPING_FIELDS,
93 | patient_index_field,
94 | "VALUE_STRING", # Demographic dates (e.g. MYDOB) are strings with date-like information
95 | ],
96 | agg_cols=[weight_col],
97 | )
98 |
99 | # Aggregation 2: collect values and weights per patient
100 | # The 'VALUE_NUMERIC' column for demographic dates contains the incidence, defined here as
101 | # the proportion of a patient's records whose value equals the patient's mode 'VALUE_STRING'
102 | # (i.e. the value is recorded 'correctly'); aggregation 3 averages this across patients
104 | params2 = dict(
105 | grouping_fields = [
106 | *BASE_GROUPING_FIELDS,
107 | patient_index_field,
108 | ],
109 | agg_cols = [
110 | F.collect_list("VALUE_STRING").alias("VALUE_STRING_ARRAY"), # All values
111 | F.collect_list("WEIGHT").alias("WEIGHT_ARRAY"), # All weights
112 | F.sum("WEIGHT").alias("TOTAL_WEIGHT"), # Total weight for the patient
113 | F.max("WEIGHT").alias("MAX_WEIGHT"), # Largest weight for a value
114 | ],
115 | derivations = [
116 | F.expr("int(array_position(WEIGHT_ARRAY, MAX_WEIGHT))").alias("MAX_WEIGHT_INDEX"),
117 | F.expr("element_at(VALUE_STRING_ARRAY, MAX_WEIGHT_INDEX)").alias("VALUE_STRING"),
118 | (F.col("MAX_WEIGHT") / F.col("TOTAL_WEIGHT")).alias("VALUE_NUMERIC"), # Proportion of records for a patient equal to the mode value
119 | ],
120 | )
121 |
122 | # Aggregation 3: count patients with each mode value and compute the average proportion of episodes where the value matches the mode
123 | params3, = get_frequency_aggregation_params(
124 | "VALUE_STRING",
125 | frequency_rounding=frequency_rounding,
126 | min_frequency=min_frequency
127 | )
128 | params3["agg_cols"].append(F.mean("VALUE_NUMERIC").alias("VALUE_NUMERIC")) # Average deviation from the mode value
129 |
130 | return [
131 | params1,
132 | params2,
133 | params3,
134 | ]
135 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/hes/demographic_field_summarisers.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../../dataset_definitions/hes/hes_patient_table
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../../scraper_stages/aggregation/summariser_factory
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ./demographic_aggregation_parameters
11 |
12 | # COMMAND ----------
13 |
14 | # MAGIC %run ./constants/disclosure_control_parameters
15 |
16 | # COMMAND ----------
17 |
18 | from typing import List, Callable
19 | from pyspark.sql import DataFrame
20 |
21 | # region Prevent downstream linting highlights
22 | HES_FREQUENCY_ROUNDING = HES_FREQUENCY_ROUNDING
23 | HES_MIN_FREQUENCY = HES_MIN_FREQUENCY
24 | HES_PATIENT_KEY = HES_PATIENT_KEY
25 | get_demographic_categorical_aggregation_params = get_demographic_categorical_aggregation_params
26 | get_demographic_date_aggregation_params = get_demographic_date_aggregation_params
27 | create_summariser = create_summariser
28 | # endregion
29 |
30 |
31 | def demographic_categorical_summariser(long_df: DataFrame) -> DataFrame:
32 | """Summarise fields in a long-formatted dataframe which have VALUE_TYPE='DEMOGRAPHIC_CATEGORICAL'
33 | These are fields such as sex or ethnicity, which in principle have a single underlying value for a given
34 | patient, but which in practice can have multiple values recorded across different visits. We want to
35 | sample in such a way that the patterns of 'errors' that we would see in real data are represented in
36 | artificial data.
37 |
38 | Note: this logic is hes-specific.
39 |
40 | Args:
41 | long_df (DataFrame): Long-formatted data to summarise
42 |
43 | Returns:
44 | DataFrame: Summarised data (long-formatted)
45 | """
46 | aggregation_params = get_demographic_categorical_aggregation_params(
47 | HES_PATIENT_KEY,
48 | frequency_rounding=HES_FREQUENCY_ROUNDING,
49 | min_frequency=HES_MIN_FREQUENCY
50 | )
51 | summariser = create_summariser("DEMOGRAPHIC_CATEGORICAL", aggregation_params)
52 | return summariser(long_df)
53 |
54 |
55 | def demographic_date_summariser(long_df: DataFrame) -> DataFrame:
56 | """Summarise fields in a long-formatted dataframe which have VALUE_TYPE='DEMOGRAPHIC_DATE'
57 | These are fields such as date of birth, which in principle have a single underlying value for a given
58 | patient, but which in practice can have multiple values recorded across different visits. We want to
59 | sample in such a way that the patterns of 'errors' that we would see in real data are represented in
60 | artificial data.
61 |
62 | Note: this logic is hes-specific.
63 |
64 | Args:
65 | long_df (DataFrame): Long-formatted data to summarise
66 |
67 | Returns:
68 | DataFrame: Summarised data (long-formatted)
69 | """
70 | aggregation_params = get_demographic_date_aggregation_params(
71 | HES_PATIENT_KEY,
72 | frequency_rounding=HES_FREQUENCY_ROUNDING,
73 | min_frequency=HES_MIN_FREQUENCY
74 | )
75 | summariser = create_summariser("DEMOGRAPHIC_DATE", aggregation_params)
76 | return summariser(long_df)
77 |
78 |
79 | def get_demographic_summarisers() -> List[Callable[[DataFrame], DataFrame]]:
80 | return [
81 | demographic_categorical_summariser,
82 | demographic_date_summariser,
83 | ]
84 |
85 | # COMMAND ----------
86 |
87 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_pipelines/hes/driver.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../../dataset_definitions/hes/hes_tables
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../driver_imports
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ./constants/type_overrides
11 |
12 | # COMMAND ----------
13 |
14 | # MAGIC %run ./constants/excluded_fields
15 |
16 | # COMMAND ----------
17 |
18 | # MAGIC %run ./constants/disclosure_control_parameters
19 |
20 | # COMMAND ----------
21 |
22 | # MAGIC %run ./demographic_field_summarisers
23 |
24 | # COMMAND ----------
25 |
26 | from pprint import pprint
27 | import json
28 | from typing import List, Dict
29 |
30 | from pyspark.sql import functions as F
31 |
32 |
33 | # region Prevent downstream linting highlights
34 | spark = spark
35 | dbutils = dbutils
36 | get_required_argument = get_required_argument
37 | Table = Table
38 | EXCLUDED_FIELDS = EXCLUDED_FIELDS
39 | create_table = create_table
40 | insert_into = insert_into
41 | get_type_overrides = get_type_overrides
42 | table_exists = table_exists
43 | wide_to_long = wide_to_long
44 | get_foreign_keys = get_foreign_keys
45 | scrape_fields = scrape_fields
46 | HES_TABLES = HES_TABLES
47 | HES_FREQUENCY_ROUNDING = HES_FREQUENCY_ROUNDING
48 | HES_MIN_FREQUENCY = HES_MIN_FREQUENCY
49 | get_demographic_summarisers = get_demographic_summarisers
50 | scrape_relationships = scrape_relationships
51 | ingest = ingest
52 | # endregion
53 |
54 | # COMMAND ----------
55 |
56 | dbutils.widgets.removeAll()
57 |
58 | # COMMAND ----------
59 |
60 | dbutils.widgets.text("hes_database", "hes", "0.1 HES DB")
61 | dbutils.widgets.text("meta_database", "artificial_hes_meta", "0.2 Meta DB")
62 |
63 | # COMMAND ----------
64 |
65 | # Config
66 | hes_database = get_required_argument("hes_database")
67 | meta_database = get_required_argument("meta_database")
68 |
69 | print("Running with parameters: ")
70 | pprint({
71 | "hes_database": hes_database,
72 | "meta_database": meta_database,
73 | })
74 |
75 | # COMMAND ----------
76 |
77 | def scrape_hes_metadata(
78 | hes_table: Table,
79 | source_database: str,
80 | target_database: str,
81 | type_overrides: List[Dict[str, str]],
82 | overwrite=False
83 | ) -> bool:
84 | hes_df = ingest(source_database, hes_table.qualifier)
85 |
86 | # TODO: explain this
87 | hes_table_foreign_key = next(get_foreign_keys(hes_table)).name # PSEUDO_HESID
88 | included_fields = set(hes_df.columns).difference(EXCLUDED_FIELDS).union([hes_table_foreign_key])
89 | hes_included_df = hes_df.select(*included_fields)
90 |
91 | field_meta_df = scrape_fields(
92 | hes_included_df,
93 | type_overrides,
94 | frequency_rounding=HES_FREQUENCY_ROUNDING,
95 | min_frequency=HES_MIN_FREQUENCY,
96 | dataset_summarisers=get_demographic_summarisers(),
97 | unpivoted_fields=[hes_table_foreign_key]
98 | )
99 |
100 | # Check exclusions
101 | assert field_meta_df.filter(F.col("FIELD_NAME").isin(EXCLUDED_FIELDS)).first() is None, "Scraped fields which should have been excluded!"
102 |
103 | relational_meta_df = scrape_relationships(hes_df, hes_table)
104 |
105 | meta_df = field_meta_df.unionByName(relational_meta_df)
106 |
107 | # Write output
108 | output_args = (spark, meta_df, target_database, "artificial_hes_meta")
109 | if overwrite:
110 | success = create_table(
111 | *output_args,
112 | format="delta",
113 | mode="overwrite",
114 | overwriteSchema="true", # This doesn't work on ACL clusters!
115 | partitionBy=["TABLE_NAME", "FIELD_NAME"]
116 | )
117 | else:
118 | success = insert_into(*output_args, overwrite=False)
119 |
120 | return success
121 |
122 |
123 | # COMMAND ----------
124 |
125 | type_overrides = get_type_overrides()
126 | overwrite = True # Overwrite on the first write only; the flag is switched to False below so subsequent tables are appended
127 |
128 | # Scrape metadata for each table
129 | for hes_dataset, hes_dataset_tables in HES_TABLES.items():
130 | print(f"\nExtracting metadata for dataset '{hes_dataset}'")
131 |
132 | for hes_table in hes_dataset_tables:
133 | if not table_exists(spark, hes_database, hes_table.name):
134 | print(f"\tTable `{hes_database}.{hes_table.name}` does not exist: skipping")
135 | continue
136 | else:
137 | print(f"\tExtracting metadata for table `{hes_database}.{hes_table.name}`")
138 |
139 | success = scrape_hes_metadata(hes_table, hes_database, meta_database, type_overrides, overwrite=overwrite)
140 |
141 | # Switch to prevent overwriting current outputs
142 | overwrite = False
143 |
144 | spark.sql(f"OPTIMIZE {meta_database}.artificial_hes_meta").show()
145 |
146 | # COMMAND ----------
147 |
148 | notebook_response = {}
149 |
150 | if success:
151 | notebook_response["status"] = "success"
152 | notebook_response["message"] = ""
153 |
154 | else:
155 | notebook_response["status"] = "failed"
156 | notebook_response["message"] = "Failed to write metadata"
157 |
158 | dbutils.notebook.exit(json.dumps(notebook_response))
159 |
160 | # COMMAND ----------
161 |
162 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_stages/aggregation/aggregation_parameters.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from typing import Dict, List, Union
3 | from pyspark.sql import functions as F, Column
4 |
5 |
6 | BASE_GROUPING_FIELDS = ["TABLE_NAME", "FIELD_NAME", "VALUE_TYPE"]
7 | weight_col = F.count(F.lit(1)).cast("double").alias("WEIGHT")
8 | AggregationParametersType = List[Dict[str, Union[List[str], List[Column]]]]
9 |
10 |
11 | def round_to_nearest_interval(col: Column, interval: int=5) -> Column:
12 | """Build a new pyspark column by rounding a given column to a specified interval.
13 |
14 | Args:
15 | col (Column): Original column to be rounded
16 | interval (int, optional): Interval to round to. Defaults to 5.
17 |
18 | Returns:
19 | Column: Rounded column
20 | """
21 | rounded_col = (F.round(col / interval, 0) * interval)
22 | return rounded_col
23 |
24 |
25 | def get_frequency_aggregation_params(
26 | value_field: str,
27 | frequency_rounding: int,
28 | min_frequency: int
29 | ) -> AggregationParametersType:
30 | """Get the parameters required to aggregate frequency distributions and apply
31 | disclosure controls to the resulting frequencies.
32 |
33 | Args:
34 | value_field (str): Values to calculate the frequency of
35 | frequency_rounding (int): Rounding applied to frequencies
36 | min_frequency (int): Cutoff below which frequencies / values are excluded
37 |
38 | Returns:
39 | AggregationParametersType: Parameters to do the aggregations
40 | """
41 |
42 | # Need to make sure that filtering is based on actual weights, rather than rounded weights
43 | # Using rounded weights could let potentially disclosive values through!
44 | weight_above_cutoff_col = weight_col >= min_frequency
45 | rounded_weight_col = round_to_nearest_interval(weight_col, interval=frequency_rounding).alias("WEIGHT")
46 |
47 | return [
48 | dict(
49 | grouping_fields=[
50 | *BASE_GROUPING_FIELDS,
51 | value_field,
52 | ],
53 | agg_cols=[rounded_weight_col],
54 | filters=[weight_above_cutoff_col], # NOTE: if min_frequency = 1, any rows where the frequency is small enough to be rounded to 0 will be removed
55 | )
56 | ]
57 |
58 |
59 | def get_categorical_aggregation_params(frequency_rounding: int, min_frequency: int) -> AggregationParametersType:
60 | """Get the parameters required to build the frequency distributions for categorical variables and apply
61 | disclosure controls to the resulting frequencies.
62 |
63 | Args:
64 | frequency_rounding (int): Rounding applied to frequencies
65 | min_frequency (int): Cutoff below which frequencies / values are excluded
66 |
67 | Returns:
68 | AggregationParametersType: Parameters to do the aggregations
69 | """
70 | return get_frequency_aggregation_params("VALUE_STRING", frequency_rounding, min_frequency)
71 |
72 |
73 | def get_discrete_aggregation_params(frequency_rounding: int, min_frequency: int) -> AggregationParametersType:
74 | """Get the parameters required to build the frequency distributions for discrete variables and apply
75 | disclosure controls to the resulting frequencies.
76 |
77 | Args:
78 | frequency_rounding (int): Rounding applied to frequencies
79 | min_frequency (int): Cutoff below which frequencies / values are excluded
80 |
81 | Returns:
82 | AggregationParametersType: Parameters to do the aggregations
83 | """
84 | return get_frequency_aggregation_params("VALUE_NUMERIC", frequency_rounding, min_frequency)
85 |
86 |
87 | def get_date_aggregation_params(frequency_rounding: int, min_frequency: int) -> AggregationParametersType:
88 | """Get the parameters required to build the frequency distributions for date variables and apply
89 | disclosure controls to the resulting frequencies.
90 |
91 | Args:
92 | frequency_rounding (int): Rounding applied to frequencies
93 | min_frequency (int): Cutoff below which frequencies / values are excluded
94 |
95 | Returns:
96 | AggregationParametersType: Parameters to do the aggregations
97 | """
98 | return get_frequency_aggregation_params("VALUE_DATE", frequency_rounding, min_frequency)
99 |
100 |
101 | def get_continuous_aggregation_params() -> AggregationParametersType:
102 | """Get the parameters required to calculate the distribution of percentiles for continuous variables
103 | for aggregation.
104 |
105 | Returns:
106 | AggregationParametersType: Parameters to do the aggregations
107 | """
108 | # TODO: remove nulls from percentiles
109 | # TODO: incidence (also need to put in generator function)
110 | return [
111 | dict(
112 | grouping_fields=BASE_GROUPING_FIELDS,
113 | agg_cols=[
114 | F.array([F.expr(f"approx_percentile(VALUE_NUMERIC, {i / 100})") for i in range(2, 100)]).alias("VALUE_NUMERIC_ARRAY"),
115 | ]
116 | )
117 | ]
118 |
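
A minimal sketch of the disclosure-control rounding above, assuming a Spark session (`spark`); the raw weights are hypothetical:

    from pyspark.sql import functions as F

    df = spark.createDataFrame([(3.0,), (7.0,), (12.0,)], "RAW_WEIGHT double")
    df = df.withColumn("WEIGHT", round_to_nearest_interval(F.col("RAW_WEIGHT"), interval=5))
    # RAW_WEIGHT 3.0 -> 5.0, 7.0 -> 5.0, 12.0 -> 10.0

    params, = get_frequency_aggregation_params("VALUE_STRING", frequency_rounding=5, min_frequency=5)
    # params["grouping_fields"] == ["TABLE_NAME", "FIELD_NAME", "VALUE_TYPE", "VALUE_STRING"]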
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_stages/aggregation/field_summarisers.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./summariser_factory
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ./aggregation_parameters
7 |
8 | # COMMAND ----------
9 |
10 | from functools import reduce
11 | from typing import Optional, Iterable, Callable
12 |
13 | from pyspark.sql import functions as F, DataFrame
14 |
15 |
16 | # region Prevent downstream linting highlights
17 | create_summariser = create_summariser
18 | get_categorical_aggregation_params = get_categorical_aggregation_params
19 | get_date_aggregation_params = get_date_aggregation_params
20 | get_discrete_aggregation_params = get_discrete_aggregation_params
21 | get_continuous_aggregation_params = get_continuous_aggregation_params
22 | # endregion
23 |
24 |
25 | def field_summariser(
26 | long_df: DataFrame,
27 | frequency_rounding: int,
28 | min_frequency: int,
29 | dataset_summarisers: Optional[Iterable[Callable]]=None
30 | ) -> DataFrame:
31 | """Summarise fields in a long-formatted dataframe with value type 'CATEGORICAL',
32 | 'DATE', 'DISCRETE' or 'CONTINUOUS'.
33 |
34 | Args:
35 | long_df (pyspark.sql.DataFrame): Long-formatted data to summarise
36 | frequency_rounding (int): Rounding applied to frequencies
37 | min_frequency (int): Cutoff below which frequencies are removed.
38 | dataset_summarisers (Optional[Iterable[Callable]], optional): Additional summariser
39 | functions for dataset-specific operations. Defaults to None.
40 |
41 | Returns:
42 | pyspark.sql.DataFrame: Summarised data (long-formatted)
43 | """
44 | disclosure_control_params = (frequency_rounding, min_frequency)
45 |
46 | if dataset_summarisers is None:
47 | dataset_summarisers = []
48 |
49 | summarisers = [
50 | create_summariser("CATEGORICAL", get_categorical_aggregation_params(*disclosure_control_params)),
51 | create_summariser("DATE", get_date_aggregation_params(*disclosure_control_params)),
52 | create_summariser("DISCRETE", get_discrete_aggregation_params(*disclosure_control_params)),
53 | create_summariser("CONTINUOUS", get_continuous_aggregation_params()),
54 | *dataset_summarisers,
55 | ]
56 |
57 | meta_dfs = map(lambda summariser: summariser(long_df), summarisers)
58 | meta_df = reduce(DataFrame.union, meta_dfs)
59 |
60 | return meta_df
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_stages/aggregation/relationship_summariser.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../wide_to_long
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ./summariser_factory
7 |
8 | # COMMAND ----------
9 |
10 | from pyspark.sql import DataFrame, functions as F
11 |
12 | # region Prevent downstream linting highlights
13 | BASE_GROUPING_FIELDS = BASE_GROUPING_FIELDS
14 | AggregationParametersType = AggregationParametersType
15 | weight_col = weight_col
16 | create_summariser = create_summariser
17 | # endregion
18 |
19 |
20 | def get_relationship_aggregation_params(pk_field: str, fk_field: str) -> AggregationParametersType:
21 | """Get the parameters required to build the frequency distributions of cardinal relationships
22 | between a primary key and foreign key
23 |
24 | Args:
25 | pk_field (str): Name of the primary key field
26 | fk_field (str): Name of the foreign key field
27 |
28 | Returns:
29 | AggregationParametersType: Parameters to do the aggregations
30 | """
31 | # Count the number of 1-to-N mappings (where N >= 0) between foreign and primary keys
32 | # Null primary keys contribute towards 1-to-0 counts
33 | # Null foreign keys are ignored
34 | cardinality_col = F.approx_count_distinct(pk_field)
35 | return [
36 | dict(
37 | grouping_fields=[*BASE_GROUPING_FIELDS, fk_field],
38 | agg_cols=[cardinality_col.alias("VALUE_NUMERIC")],
39 | filters=[~F.isnull(fk_field)] # TODO: Handle nulls separately
40 | ),
41 | dict(
42 | grouping_fields=[*BASE_GROUPING_FIELDS, "VALUE_NUMERIC"],
43 | agg_cols=[weight_col],
44 | ),
45 | ]
46 |
47 |
48 | def summarise_relationships(wide_df: DataFrame, pk_field: str, fk_field: str) -> DataFrame:
49 | """Summarise the frequency distributions of cardinal relationships between primary key
50 | and foreign key fields in wide-formatted data.
51 |
52 | Args:
53 | wide_df (DataFrame): DataFrame holding the relationships to summarise
54 | pk_field (str): Name of the primary key field
55 | fk_field (str): Name of the foreign key field
56 |
57 | Returns:
58 | DataFrame: Summarised relationship data
59 | """
60 | wide_df = (
61 | wide_df
62 | .withColumn("FIELD_NAME", F.lit(pk_field))
63 | .withColumn("VALUE_TYPE", F.lit("RELATIONSHIP"))
64 | .select(*BASE_GROUPING_FIELDS, pk_field, fk_field)
65 | )
66 |
67 | aggregation_params = get_relationship_aggregation_params(pk_field, fk_field)
68 | summariser = create_summariser("RELATIONSHIP", aggregation_params)
69 | meta_df = (
70 | summariser(wide_df)
71 | .withColumn("VALUE_STRING", F.lit(fk_field))
72 | )
73 |
74 | return meta_df
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_stages/aggregation/summariser_factory.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./aggregation_parameters
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../../common/coerce_schema
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ../schemas/meta_schema
11 |
12 | # COMMAND ----------
13 |
14 | from functools import reduce
15 | from typing import Iterable, Callable, Dict, Union
16 |
17 | from pyspark.sql import functions as F, DataFrame, Column
18 |
19 |
20 | # region Prevent downstream linting highlights
21 | coerce_schema = coerce_schema
22 | AggregationParametersType = AggregationParametersType
23 | get_meta_schema = get_meta_schema
24 | # endregion
25 |
26 |
27 | def df_aggregation_reducer(
28 | df: DataFrame,
29 | aggregation_params: AggregationParametersType
30 | ) -> DataFrame:
31 | """Apply a sequence of aggregations, filters and derivations to a dataframe
32 |
33 | Args:
34 | df (DataFrame): DataFrame to aggregate
35 | aggregation_params (AggregationParametersType): Defines the aggregations.
36 | This is an iterable of dicts, where each element has the keys
37 | 'grouping_fields' and 'agg_cols' (used in the DataFrame.groupBy(...).agg(...))
38 | plus, optionally, 'derivations' (columns added to the dataframe after the
39 | aggregation) and 'filters' (filter conditions applied to the aggregated output)
41 |
42 | Returns:
43 | DataFrame: Aggregated data
44 | """
45 | def _reducer(acc, x):
46 | acc = (
47 | acc
48 | .groupBy(*x["grouping_fields"])
49 | .agg(*x["agg_cols"])
50 | )
51 |
52 | for derived_col in x.get("derivations", []):
53 | col_name = acc.select(derived_col).columns[0]
54 | acc = acc.withColumn(col_name, derived_col)
55 |
56 | for condition_col in x.get("filters", []):
57 | acc = acc.filter(condition_col)
58 |
59 | return acc
60 |
61 | agg_df = reduce(_reducer, aggregation_params, df)
62 |
63 | return agg_df
64 |
65 |
66 | def create_summariser(
67 | value_type: str,
68 | aggregation_params: AggregationParametersType
69 | ) -> Callable[[DataFrame], DataFrame]:
70 | """Create a function which summarises (long-formatted) data with a given type.
71 |
72 | Args:
73 | value_type (str): Describes type of data to summarise (e.g. 'CATEGORICAL')
74 | aggregation_params (AggregationParametersType): Defines the aggregations to summarise the data
75 |
76 | Returns:
77 | Callable[[DataFrame], DataFrame]: Summarises long-formatted data
78 | """
79 | value_type_condition_col = F.col("VALUE_TYPE") == value_type
80 |
81 | def _summariser(df: DataFrame) -> DataFrame:
82 | filtered_df = df.filter(value_type_condition_col)
83 | aggregated_df = df_aggregation_reducer(filtered_df, aggregation_params)
84 | meta_df = coerce_schema(aggregated_df, get_meta_schema())
85 | return meta_df
86 |
87 | return _summariser
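
A minimal sketch of df_aggregation_reducer on toy long-formatted rows, assuming a Spark session (`spark`); the parameter values and the derived column are hypothetical:

    from pyspark.sql import functions as F

    long_df = spark.createDataFrame(
        [("t1", "f1", "CATEGORICAL", "A"),
         ("t1", "f1", "CATEGORICAL", "A"),
         ("t1", "f1", "CATEGORICAL", "B")],
        "TABLE_NAME string, FIELD_NAME string, VALUE_TYPE string, VALUE_STRING string",
    )

    toy_params = [
        dict(
            grouping_fields=["TABLE_NAME", "FIELD_NAME", "VALUE_TYPE", "VALUE_STRING"],
            agg_cols=[F.count(F.lit(1)).cast("double").alias("WEIGHT")],
            derivations=[(F.col("WEIGHT") * 2).alias("DOUBLE_WEIGHT")],  # hypothetical derived column
            filters=[F.col("WEIGHT") >= 2],
        ),
    ]

    agg_df = df_aggregation_reducer(long_df, toy_params)
    # One row remains: value "A" with WEIGHT = 2.0 and DOUBLE_WEIGHT = 4.0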
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_stages/ingestion.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../common/common_exports
3 |
4 | # COMMAND ----------
5 |
6 | from pyspark.sql import DataFrame, functions as F
7 |
8 | # COMMAND ----------
9 |
10 | def ingest(source_database: str, source_table: str) -> DataFrame:
11 | """Read the data from the source table
12 |
13 | Args:
14 | source_database (str): Database name containing source table
15 | source_table (str): Name of source table
16 |
17 | Returns:
18 | pyspark.sql.DataFrame: DataFrame with data to scrape
19 |
20 | """
21 | source_table_fullname = f"{source_database}.{source_table}"
22 | dataset_df = (
23 | spark.table(source_table_fullname)
24 | .withColumn("TABLE_NAME", F.lit(source_table))
25 | )
26 |
27 | return dataset_df
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_stages/schemas/long_schema.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from pyspark.sql import types as T
3 |
4 |
5 | def get_long_schema() -> T.StructType:
6 | return T.StructType([
7 | T.StructField("FIELD_NAME", T.StringType(), False),
8 | T.StructField("VALUE_STRING", T.StringType(), True),
9 | T.StructField("VALUE_NUMERIC", T.DoubleType(), True),
10 | T.StructField("VALUE_DATE", T.DateType(), True),
11 | T.StructField("VALUE_TIMESTAMP", T.TimestampType(), True),
12 | ])
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_stages/schemas/meta_schema.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | from pyspark.sql import types as T
3 |
4 |
5 | def get_meta_schema() -> T.StructType:
6 | return T.StructType([
7 | T.StructField("TABLE_NAME", T.StringType(), True), # This should ideally be false but that breaks the schema uplift. Left as TODO
8 | T.StructField("FIELD_NAME", T.StringType(), True), # This should ideally be false but that breaks the schema uplift. Left as TODO
9 | T.StructField("VALUE_TYPE", T.StringType(), True), # This should ideally be false but that breaks the schema uplift. Left as TODO
10 | T.StructField("VALUE_STRING", T.StringType(), True),
11 | T.StructField("VALUE_NUMERIC", T.DoubleType(), True),
12 | # T.StructField("VALUE_TIMESTAMP", T.TimestampType(), True),
13 | T.StructField("VALUE_DATE", T.DateType(), True),
14 | T.StructField("VALUE_STRING_ARRAY", T.ArrayType(T.StringType()), True),
15 | T.StructField("VALUE_NUMERIC_ARRAY", T.ArrayType(T.DoubleType()), True),
16 | # T.StructField("VALUE_TIMESTAMP_ARRAY", T.ArrayType(T.TimestampType()), True),
17 | # T.StructField("VALUE_DATE_ARRAY", T.ArrayType(T.DateType()), True),
18 | T.StructField("WEIGHT", T.DoubleType(), True),
19 | T.StructField("WEIGHT_ARRAY", T.ArrayType(T.DoubleType()), True),
20 | ])
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/notebooks/scraper_stages/scrape_metadata.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./ingestion
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ./wide_to_long
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ./schemas/meta_schema
11 |
12 | # COMMAND ----------
13 |
14 | # MAGIC %run ./aggregation/field_summarisers
15 |
16 | # COMMAND ----------
17 |
18 | # MAGIC %run ./aggregation/relationship_summariser
19 |
20 | # COMMAND ----------
21 |
22 | # MAGIC %run ../dataset_definitions/relational_helpers
23 |
24 | # COMMAND ----------
25 |
26 | from typing import Optional, Iterable, Dict, Callable
27 | from pyspark.sql import DataFrame, functions as F
28 |
29 | # region Prevent downstream linting highlights
30 | spark = spark
31 | Table = Table
32 | wide_to_long = wide_to_long
33 | field_summariser = field_summariser
34 | get_primary_key = get_primary_key
35 | get_foreign_keys = get_foreign_keys
36 | summarise_relationships = summarise_relationships
37 | get_meta_schema = get_meta_schema
38 | ingest = ingest
39 | # endregion
40 |
41 |
42 | def scrape_fields(
43 | wide_df: DataFrame,
44 | type_overrides: Iterable[Dict[str, str]],
45 | frequency_rounding: int,
46 | min_frequency: int,
47 | dataset_summarisers: Optional[Iterable[Callable]] = None,
48 | unpivoted_fields: Optional[Iterable[str]] = None
49 | ) -> DataFrame:
50 | """Scrape the field metadata (i.e. frequency / percentile distributions) from wide-formatted data
51 |
52 | Args:
53 | wide_df (DataFrame): Wide-formatted data to scrape
54 | type_overrides (Iterable[Dict[str, str]]): Overrides for metadata types, used in place of automatic types
55 | frequency_rounding (int): Rounding applied to frequency-typed fields
56 | min_frequency (int): Cutoff below which frequencies / values are removed
57 | dataset_summarisers (Optional[Iterable[Callable]], optional): Dataset-specific summariser functions. Passed through
58 | to summarise_fields. Defaults to None.
59 | unpivoted_fields (Optional[Iterable[str]], optional): Fields to preserve in the original wide_df. Defaults to None.
60 |
61 | Returns:
62 | DataFrame: Field metadata
63 | """
64 |
65 | if unpivoted_fields is None:
66 | unpivoted_fields = []
67 |
68 | unpivoted_fields = list(set(["TABLE_NAME", *unpivoted_fields]))
69 |
70 | # Pivot to long format
71 | long_df = wide_to_long(wide_df, preserved_field_names=unpivoted_fields)
72 |
73 |     # Associate metadata types (e.g. 'categorical')
74 |     # Fields that don't appear in type_overrides default to 'CATEGORICAL' via the coalesce below
75 | type_overrides_df = spark.createDataFrame(type_overrides, "FIELD_NAME: string, VALUE_TYPE: string")
76 | long_df = (
77 | long_df
78 | .join(type_overrides_df.hint("broadcast"), on="FIELD_NAME", how="left")
79 | .withColumn("VALUE_TYPE", F.coalesce(F.col("VALUE_TYPE"), F.lit("CATEGORICAL"))) # Fill nulls
80 | )
81 |
82 | # Summarise values
83 | meta_df = field_summariser(
84 | long_df,
85 | frequency_rounding=frequency_rounding,
86 | min_frequency=min_frequency,
87 | dataset_summarisers=dataset_summarisers,
88 | )
89 |
90 | return meta_df
91 |
92 |
93 | def scrape_relationships(wide_df: DataFrame, table: Table) -> DataFrame:
94 |     """Scrape the relationship metadata (i.e. frequency distributions of relationship cardinalities)
95 | from wide-formatted data
96 |
97 | Args:
98 | wide_df (DataFrame): Wide-formatted data to scrape
99 | table (Table): Holds the names of primary and foreign key fields
100 |
101 | Returns:
102 | DataFrame: Relationship metadata
103 | """
104 | primary_key_field = get_primary_key(table)
105 |
106 |     for foreign_key_field in get_foreign_keys(table):  # NOTE: only the first foreign key is summarised
107 | index_fields = [foreign_key_field.name, primary_key_field.name]
108 | index_df = wide_df.select("TABLE_NAME", *index_fields)
109 | meta_df = summarise_relationships(index_df, primary_key_field.name, foreign_key_field.name)
110 | return meta_df
111 |
112 |     else:  # No foreign keys: return empty metadata
113 | empty_meta_df = spark.createDataFrame([], get_meta_schema())
114 | return empty_meta_df
115 |
116 |
117 | def scrape_metadata(
118 | database_name: str,
119 | source_table: Table,
120 | type_overrides: Iterable[Dict[str, str]],
121 | frequency_rounding: int,
122 | min_frequency: int,
123 | excluded_fields: Iterable[str],
124 | included_fields: Optional[Iterable[str]] = None,
125 | dataset_summarisers: Optional[Iterable[Callable]] = None,
126 | ) -> DataFrame:
127 | """Scrape field and relationship metadata from a table
128 |
129 | Args:
130 | database_name (str): Database containing the table to scrape
131 | source_table (Table): Table to scrape
132 | type_overrides (Iterable[Dict[str, str]]): Overrides for metadata types, used in place of automatic types
133 | frequency_rounding (int): Rounding applied to frequency-typed fields
134 | min_frequency (int): Cutoff below which frequencies / values are removed
135 | dataset_summarisers (Optional[Iterable[Callable]], optional): Dataset-specific summariser functions. Passed through
136 | to summarise_fields. Defaults to None.
137 |
138 | Returns:
139 | DataFrame: Scraped metadata
140 | """
141 | dataset_df = ingest(database_name, source_table.qualifier)
142 |
143 | field_meta_df = scrape_fields(
144 | dataset_df,
145 | type_overrides,
146 | frequency_rounding=frequency_rounding,
147 | min_frequency=min_frequency,
148 | dataset_summarisers=dataset_summarisers
149 | )
150 |
151 | relational_meta_df = scrape_relationships(dataset_df, source_table)
152 |
153 | meta_df = field_meta_df.unionByName(relational_meta_df)
154 |
155 | return meta_df
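156 | 
157 | # COMMAND ----------
158 | 
159 | # Example (illustrative; the field names and types below are hypothetical): `type_overrides`
160 | # is an iterable of mappings matching the "FIELD_NAME: string, VALUE_TYPE: string" schema
161 | # used in scrape_fields, e.g.
162 | #
163 | #   example_type_overrides = [
164 | #       {"FIELD_NAME": "ADMIDATE", "VALUE_TYPE": "DATE"},
165 | #       {"FIELD_NAME": "SEX", "VALUE_TYPE": "CATEGORICAL"},
166 | #   ]
167 | #
168 | # Fields not listed in the overrides default to 'CATEGORICAL'.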
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/run_notebooks.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Artificial Data Generator: Run Notebooks
4 | # MAGIC This notebook runs the data generator pipeline for a given artificial dataset. The artificial data will be written into the database specified by the `CP Project Database` argument.
5 | # MAGIC
6 | # MAGIC **Prerequisites:** The underlying driver ingests metadata that was scraped from the real data and signed off against the disclosure control checklist. This metadata must exist in a database that is readable by this code promotion project (i.e. _iuod_artificial_data_generator_); the database containing the metadata is specified via the `CP Project Database` argument.
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %md
11 | # MAGIC ## Setup
12 |
13 | # COMMAND ----------
14 |
15 | # MAGIC %md
16 | # MAGIC ### Import dependencies
17 |
18 | # COMMAND ----------
19 |
20 | # MAGIC %run ./notebooks/common/widget_utils
21 |
22 | # COMMAND ----------
23 |
24 | import json
25 | from pprint import pprint
26 | from pathlib import Path
26 |
27 | GENERATOR_PIPELINES_ROOT = Path("./notebooks/generator_pipelines")
28 | AVAILABLE_DRIVERS = [
29 | "artificial_hes"
30 | ]
31 |
32 | # COMMAND ----------
33 |
34 | # MAGIC %md
35 | # MAGIC ### Create widgets
36 |
37 | # COMMAND ----------
38 |
39 | if check_databricks():
40 | dbutils.widgets.removeAll()
41 | dbutils.widgets.text("db", "iuod_artificial_data_generator", "0.1 CP Project Database")
42 | dbutils.widgets.text("artificial_dataset", "", "1.1 Artificial Dataset")
43 | dbutils.widgets.text("parameters_json", "", "1.2 Parameters JSON")
44 | else:
45 | # Only make widgets in databricks
46 | pass
47 |
48 | # COMMAND ----------
49 |
50 | # MAGIC %md
51 | # MAGIC ## Driver execution
52 | # MAGIC Trigger the driver notebook for the artificial dataset specified via the arguments
53 |
54 | # COMMAND ----------
55 |
56 | artificial_dataset = get_required_argument("artificial_dataset")
57 |
58 | if artificial_dataset not in AVAILABLE_DRIVERS:
59 | message = f"Driver does not exist for dataset '{artificial_dataset}': please choose from {AVAILABLE_DRIVERS}"
60 | dbutils.notebook.exit(json.dumps({"status": "failed", "message": message}))
61 |
62 | # Path to the notebook that generates data for the given dataset
63 | driver_path = GENERATOR_PIPELINES_ROOT / artificial_dataset / "driver"
64 |
65 | # Prepare driver parameters
66 | database = get_required_argument("db")
67 | driver_args = dict(
68 | meta_database = database,
69 | artificial_database = database,
70 | parameters_json = get_required_argument("parameters_json"),
71 | )
72 |
73 | print(f"Running driver {str(driver_path)} with arguments:")
74 | pprint(driver_args)
75 |
76 | driver_response = dbutils.notebook.run(str(driver_path), 0, arguments=driver_args)
77 |
78 | dbutils.notebook.exit(driver_response)
79 |
80 | # COMMAND ----------
81 |
82 |
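83 | # Example invocation (illustrative; the widget values are hypothetical and the exact shape of
84 | # `parameters_json` is defined by the chosen dataset's driver parameters):
85 | #
86 | #   db                 = "iuod_artificial_data_generator"
87 | #   artificial_dataset = "artificial_hes"
88 | #   parameters_json    = "{}"  # dataset-specific parameters, JSON-encoded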
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/run_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md ### README
3 | # MAGIC This notebook will run all the tests! Each test notebook is run in turn and its results are printed in the cell below the run. New test notebooks should be added to the `tests` dictionary below.
4 | # MAGIC
5 | # MAGIC Test failures are raised as an error once all the runs have finished (i.e. every test runs first, regardless of passes or failures, and a final check counts the failures and alerts the user). This avoids failing tests going undetected.
6 |
7 | # COMMAND ----------
8 |
9 | from typing import Dict, Generator, Optional, Tuple
10 | from py4j.protocol import Py4JJavaError
11 | 
12 | 
13 | def get_test_runner() -> Generator[Optional[Dict[str, Py4JJavaError]], Tuple[str, str], None]:
14 | errors = {}
15 |
16 | while True:
17 | test_name, test_path = yield
18 |
19 | try:
20 | dbutils.notebook.run(test_path, 0)
21 |
22 | if test_name in errors:
23 | del errors[test_name]
24 | except Py4JJavaError as e:
25 | errors[test_name] = e
26 |
27 | yield errors
28 |
29 |
30 | def run_tests(tests):
31 | test_runner = get_test_runner()
32 | errors = test_runner.send(None)
33 |
34 | for test_name, test_path in tests.items():
35 | errors = test_runner.send((test_name, test_path))
36 | next(test_runner)
37 |
38 |     if errors:
39 |         for test_name, e in errors.items():
40 |             print(f'Test failures detected in {tests[test_name]}: {e}')
41 |         raise AssertionError(f"{len(errors)} test notebook(s) failed - see the output above for details")
42 | # COMMAND ----------
43 |
44 | tests = {
45 | "scraper_stages": "./tests/scraper_stages_tests/run_tests",
46 | "generator_stages": "./tests/generator_stages_tests/run_tests",
47 | "scraper_pipelines": "./tests/scraper_pipelines_tests/run_tests",
48 | "generator_pipelines": "./tests/generator_pipelines_tests/run_tests",
49 | "dataset_definitions": "./tests/dataset_definitions_tests/run_tests",
50 | "common": "./tests/common_tests/run_tests",
51 | }
52 |
53 | run_tests(tests)
54 |
55 | # COMMAND ----------
56 |
57 |
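58 | # Example (illustrative; the path is hypothetical): to add a new suite to the runs, register
59 | # its runner notebook in the `tests` dictionary above, e.g.
60 | #
61 | #   tests["my_new_stage"] = "./tests/my_new_stage_tests/run_tests"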
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/schemas/create_meta_table.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Create Metadata Table
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ## Setup
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %md
13 | # MAGIC ### Imports
14 |
15 | # COMMAND ----------
16 |
17 | # MAGIC %run ../notebooks/common/widget_utils
18 |
19 | # COMMAND ----------
20 |
21 | # MAGIC %run ../notebooks/common/table_helpers
22 |
23 | # COMMAND ----------
24 |
25 | # MAGIC %run ../notebooks/scraper_stages/schemas/meta_schema
26 |
27 | # COMMAND ----------
28 |
29 | # MAGIC %md
30 | # MAGIC ### Widgets
31 |
32 | # COMMAND ----------
33 |
34 | dbutils.widgets.text( "database_name", "", "1.1 Database Name")
35 | dbutils.widgets.text( "table_name", "", "1.2 Table Name")
36 | dbutils.widgets.dropdown("replace_if_exists", "false", ["true", "false"], "2.1 Replace If Exists")
37 |
38 | # COMMAND ----------
39 |
40 | # MAGIC %md
41 | # MAGIC ## Main
42 |
43 | # COMMAND ----------
44 |
45 | database_name = get_required_argument("database_name")
46 | table_name = get_required_argument("table_name")
47 | replace_if_exists = dbutils.widgets.get("replace_if_exists") == "true"
48 |
49 | # COMMAND ----------
50 |
51 | if database_exists(spark, database_name, verbose=True):
52 | if table_exists(spark, database_name, table_name):
53 | if replace_if_exists:
54 | drop_table(spark, database_name, table_name)
55 | else:
56 |             dbutils.notebook.exit('{"status": "skipped", "message": "Table exists but replace_if_exists is false"}')
57 |
58 | create_table_from_schema(
59 | spark,
60 | database_name,
61 | table_name,
62 | get_meta_schema(),
63 | mode="overwrite",
64 | format="delta",
65 | partitionBy=["TABLE_NAME", "FIELD_NAME"]
66 | )
67 |
68 | # COMMAND ----------
69 |
70 |
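71 | # Example (illustrative; the argument values are hypothetical): this notebook can be triggered
72 | # from another notebook with the usual dbutils call, e.g.
73 | #
74 | #   dbutils.notebook.run(
75 | #       "./schemas/create_meta_table",
76 | #       0,
77 | #       {"database_name": "iuod_artificial_data_generator", "table_name": "artificial_hes_meta", "replace_if_exists": "false"},
78 | #   )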
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/schemas/uplifts/artificial_hes_schema_coercion.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Artificial HES Schema Coercion
4 | # MAGIC Coerces the schema of existing `artificial_hes` tables to match the schema specification for HES tables.
5 |
6 | # COMMAND ----------
7 |
8 | # MAGIC %md ## Setup
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %md
13 | # MAGIC ### Imports
14 |
15 | # COMMAND ----------
16 |
17 | # MAGIC %run ../../notebooks/common/widget_utils
18 |
19 | # COMMAND ----------
20 |
21 | # MAGIC %run ../../notebooks/common/spark_helpers
22 |
23 | # COMMAND ----------
24 |
25 | # MAGIC %run ../../notebooks/common/table_helpers
26 |
27 | # COMMAND ----------
28 |
29 | # MAGIC %run ../../notebooks/common/coerce_schema
30 |
31 | # COMMAND ----------
32 |
33 | # MAGIC %run ../../notebooks/dataset_definitions/hes/hes_schemas
34 |
35 | # COMMAND ----------
36 |
37 | # MAGIC %run ../../notebooks/dataset_definitions/hes/hes_tables
38 |
39 | # COMMAND ----------
40 |
41 | from typing import Iterable
42 | from datetime import datetime
43 | from pyspark.sql import types as T
44 | # COMMAND ----------
45 |
46 | # MAGIC %md
47 | # MAGIC ### Widgets
48 |
49 | # COMMAND ----------
50 |
51 | dbutils.widgets.text("database_name", "artificial_hes", "0.1 Database Name")
52 |
53 | # COMMAND ----------
54 |
55 | database_name = get_required_argument("database_name")
56 |
57 | # COMMAND ----------
58 |
59 | # MAGIC %md
60 | # MAGIC ## Functions
61 |
62 | # COMMAND ----------
63 |
64 | def backup_and_uplift_schema(
65 | database_name: str,
66 | table_name: str,
67 | target_schema: T.StructType,
68 | partitionBy: Iterable[str]=[],
69 | ) -> bool:
70 | backup_version = datetime.now().strftime("%Y%m%d%H%M")
71 | backup_table_name = f"{table_name}_backup_{backup_version}"
72 |
73 | current_df = spark.table(f"{database_name}.{table_name}")
74 |
75 | if not datatypes_equal(current_df.schema, target_schema):
76 | print("Modifying schema to target")
77 |
78 | # Backup the current table
79 | create_table(
80 | spark,
81 | current_df,
82 | database_name,
83 | backup_table_name,
84 | format="delta",
85 | mode="overwrite",
86 | partitionBy=partitionBy,
87 | )
88 |
89 | # Drop the current table which has now been backed-up
90 | drop_table(spark, database_name, table_name)
91 |
92 | # Apply the schema changes
93 | uplifted_df = coerce_schema(
94 | spark.table(f"{database_name}.{backup_table_name}"),
95 | target_schema
96 | )
97 |
98 | # Write the new table
99 | create_table(
100 | spark,
101 | uplifted_df,
102 | database_name,
103 | table_name,
104 | format="delta",
105 | mode="overwrite",
106 | partitionBy=partitionBy,
107 | )
108 |
109 | # Remove the backup
110 | drop_table(spark, database_name, backup_table_name)
111 |
112 | else:
113 | print("Current schema already equal to target schema!")
114 |
115 | target_df = spark.table(f"{database_name}.{table_name}")
116 |
117 | assert datatypes_equal(target_df.schema, target_schema)
118 |
119 | return True
120 |
121 |
122 | # COMMAND ----------
123 |
124 | # MAGIC %md
125 | # MAGIC ## Main
126 |
127 | # COMMAND ----------
128 |
129 | for dataset_name, dataset_tables in HES_TABLES.items():
130 | target_schema = get_hes_schema(dataset_name)
131 |
132 | for hes_table in dataset_tables:
133 | artificial_table_name = f"artificial_{hes_table.name}"
134 |
135 | if table_exists(spark, database_name, artificial_table_name):
136 | print(f"Uplifting schema for table {database_name}.{artificial_table_name}")
137 |
138 | backup_and_uplift_schema(
139 | database_name,
140 | artificial_table_name,
141 | target_schema
142 | )
143 | else:
144 | continue
145 |
146 | # COMMAND ----------
147 |
148 |
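149 | # Example (illustrative): the companion tests notebook runs this uplift against a non-default
150 | # database by passing the widget value through dbutils.notebook.run, e.g.
151 | #
152 | #   dbutils.notebook.run("./artificial_hes_schema_coercion", 0, {"database_name": "my_dev_db"})  # hypothetical database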
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/schemas/uplifts/artificial_hes_schema_coercion_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC ## Setup
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ### Imports
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %run ../../notebooks/common/spark_helpers
13 |
14 | # COMMAND ----------
15 |
16 | # MAGIC %run ../../notebooks/common/table_helpers
17 |
18 | # COMMAND ----------
19 |
20 | # MAGIC %run ../../notebooks/dataset_definitions/hes/hes_schemas
21 |
22 | # COMMAND ----------
23 |
24 | # MAGIC %run ../../notebooks/dataset_definitions/hes/hes_tables
25 |
26 | # COMMAND ----------
27 |
28 | from pyspark.sql import functions as F, types as T
29 |
30 | # COMMAND ----------
31 |
32 | # MAGIC %md
33 | # MAGIC ### Setup tables
34 |
35 | # COMMAND ----------
36 |
37 | target_schema = get_hes_ae_schema()
38 |
39 | # Before the uplift there are some extra fields in the hes ae tables
40 | # We want to test that these are removed by the uplift
41 | current_schema = T.StructType([
42 | *target_schema.fields,
43 | T.StructField("DIAG_3_CONCAT", T.StringType(), True),
44 | T.StructField("DIAG_4_CONCAT", T.StringType(), True),
45 | T.StructField("DIAG_COUNT", T.StringType(), True),
46 | T.StructField("OPERTN_3_CONCAT", T.StringType(), True),
47 | T.StructField("OPERTN_4_CONCAT", T.StringType(), True),
48 | T.StructField("OPERTN_COUNT", T.StringType(), True),
49 | ])
50 |
51 | database_name = "alistair_jones5_101351"
52 | table_name = "artificial_hes_ae_1415"
53 |
54 | # Recreate the table
55 | if table_exists(spark, database_name, table_name):
56 | drop_table(spark, database_name, table_name)
57 |
58 | create_table_from_schema(
59 | spark,
60 | database_name,
61 | table_name,
62 | current_schema,
63 | format="delta",
64 | mode="overwrite",
65 | )
66 |
67 | # COMMAND ----------
68 |
69 | # MAGIC %md
70 | # MAGIC ## Run tests
71 |
72 | # COMMAND ----------
73 |
74 | # Should be able to run twice without failure
75 | for i in range(2):
76 | dbutils.notebook.run("./artificial_hes_schema_coercion", 0, {"database_name": database_name})
77 |
78 | actual_schema = spark.table(f"{database_name}.{table_name}").schema
79 | assert datatypes_equal(target_schema, actual_schema)
80 |
81 | # COMMAND ----------
82 |
83 | # MAGIC %md
84 | # MAGIC ## Cleanup
85 |
86 | # COMMAND ----------
87 |
88 | # MAGIC %md
89 | # MAGIC ### Teardown tables
90 |
91 | # COMMAND ----------
92 |
93 | drop_table(spark, database_name, table_name)
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/schemas/uplifts/open_data_metadata_uplift.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Open Data Metadata Uplift
4 | # MAGIC
5 | # MAGIC This notebook uplifts the schema of a table containing metadata to improve adherence to [NHS Digital's Open Data standards](https://github.com/NHSDigital/open-data-standards).
6 |
7 | # COMMAND ----------
8 |
9 | # MAGIC %md ## Setup, Imports & Widgets
10 |
11 | # COMMAND ----------
12 |
13 | # MAGIC %run ../../notebooks/common/common_exports
14 |
15 | # COMMAND ----------
16 |
17 | # MAGIC %run ../../notebooks/scraper_stages/schemas/meta_schema
18 |
19 | # COMMAND ----------
20 |
21 | dbutils.widgets.removeAll()
22 |
23 | # COMMAND ----------
24 |
25 | dbutils.widgets.text("database_name", "iuod_artificial_data_generator", "0.1 Database Name")
26 | dbutils.widgets.text("table_name", "", "0.2 Metadata Table Name")
27 |
28 | # COMMAND ----------
29 |
30 | import json
31 | from pyspark.sql import types as T
32 |
33 | # COMMAND ----------
34 |
35 | # MAGIC %md
36 | # MAGIC ## Constants
37 |
38 | # COMMAND ----------
39 |
40 | # MAGIC %md ### SQL Expressions
41 |
42 | # COMMAND ----------
43 |
44 | # Template for transforming the metadata
45 | select_open_data_metadata_template_expr = """
46 | SELECT
47 | TABLE_NAME
48 | ,FIELD_NAME
49 | ,coalesce(
50 | CASE
51 | WHEN METADATA.SUMMARY.CATEGORICAL.VALUE_TYPE == 'FLAT' THEN 'CATEGORICAL'
52 | WHEN METADATA.SUMMARY.CATEGORICAL.VALUE_TYPE == 'NESTED' THEN 'DEMOGRAPHIC_CATEGORICAL'
53 | WHEN FIELD_NAME == 'MYDOB' THEN 'DEMOGRAPHIC_DATE'
54 | END,
55 | METADATA.SUMMARY.VALUE_TYPE,
56 | METADATA.VALUE_TYPE
57 | ) AS VALUE_TYPE
58 | ,coalesce(
59 | METADATA.FIELD_TYPE,
60 | METADATA.SUMMARY.CATEGORICAL.FLAT.VALUE,
61 | METADATA.SUMMARY.DATE.FORMAT
62 | ) AS VALUE_STRING
63 | ,coalesce(
64 | METADATA.SUMMARY.DISCRETE.VALUE,
65 | METADATA.RELATIONSHIP.MAPPING
66 | ) AS VALUE_NUMERIC
67 | ,METADATA.SUMMARY.DATE.VALUE AS VALUE_DATE
68 | ,coalesce(
69 | METADATA.SUMMARY.CATEGORICAL.NESTED.VALUE.VALUE,
70 | METADATA.SUMMARY.DATE.NOISE.REPLACEMENT
71 | ) as VALUE_STRING_ARRAY
72 |     ,cast(null as array<double>) as VALUE_NUMERIC_ARRAY
73 | ,coalesce(
74 | METADATA.SUMMARY.CATEGORICAL.FLAT.WEIGHT,
75 | METADATA.SUMMARY.CATEGORICAL.NESTED.WEIGHT,
76 | METADATA.SUMMARY.DATE.WEIGHT,
77 | METADATA.SUMMARY.DISCRETE.WEIGHT,
78 | METADATA.RELATIONSHIP.FREQUENCY
79 | ) AS WEIGHT
80 | ,coalesce(
81 | METADATA.SUMMARY.CATEGORICAL.NESTED.VALUE.WEIGHT,
82 | METADATA.SUMMARY.DATE.NOISE.WEIGHT
83 | ) AS WEIGHT_ARRAY
84 | FROM {db}.{table} -- Template parameters
85 | WHERE NOT METADATA.SUMMARY.VALUE_TYPE <=> 'CONTINUOUS'
86 | UNION
87 | SELECT
88 | TABLE_NAME
89 | ,FIELD_NAME
90 | ,METADATA.SUMMARY.VALUE_TYPE
91 | ,cast(null AS string) AS VALUE_STRING
92 | ,cast(null AS double) AS VALUE_NUMERIC
93 | ,cast(null AS date) AS VALUE_DATE
94 |     ,cast(null AS array<string>) AS VALUE_STRING_ARRAY
95 | ,array_sort(collect_list(METADATA.SUMMARY.CONTINUOUS.VALUE)) AS VALUE_NUMERIC_ARRAY
96 | ,cast(null AS double) AS WEIGHT
97 |     ,cast(null AS array<double>) AS WEIGHT_ARRAY
98 | FROM {db}.{table} -- Template parameters
99 | WHERE METADATA.SUMMARY.VALUE_TYPE <=> 'CONTINUOUS'
100 | GROUP BY TABLE_NAME, FIELD_NAME, METADATA.SUMMARY.VALUE_TYPE
101 | """
102 |
103 | # COMMAND ----------
104 |
105 | # MAGIC %md
106 | # MAGIC ## Main
107 |
108 | # COMMAND ----------
109 |
110 | db = get_required_argument("database_name")
111 | target_table = get_required_argument("table_name")
112 | backup_table = f"{target_table}_v1_schema_backup"
113 |
114 | print(f"Uplifting schema of table {db}.{target_table}")
115 |
116 | # COMMAND ----------
117 |
118 | if not database_exists(spark, db, verbose=True):
119 | dbutils.notebook.exit(json.dumps({"status": "skipped", "message": f"Database {db} not found"}))
120 |
121 | if not table_exists(spark, db, target_table):
122 | dbutils.notebook.exit(json.dumps({"status": "skipped", "message": f"Table {db}.{target_table} does not exist"}))
123 |
124 | source_df = spark.table(f"{db}.{target_table}")
125 | source_schema = T.StructType(sorted(source_df.schema.fields, key=lambda f: f.name))
126 | target_schema = T.StructType(sorted(get_meta_schema().fields, key=lambda f: f.name))
127 |
128 | if not datatypes_equal(target_schema, source_schema):
129 | print(f"Schema of table {db}.{target_table} not equal to target schema")
130 |
131 | # Can't rename tables, so need to backup first and then drop to reuse the name
132 | print(f"Backing up table {db}.{target_table} to table {db}.{backup_table}")
133 | create_table(spark, source_df, db, backup_table, partitionBy=["TABLE_NAME", "FIELD_NAME"])
134 | drop_table(spark, db, target_table)
135 |
136 | print(f"Recreating table {db}.{target_table} from backup with target schema")
137 | select_open_data_metadata_expr = select_open_data_metadata_template_expr.format(db=db, table=backup_table)
138 | target_df = spark.sql(select_open_data_metadata_expr)
139 | create_table(spark, target_df, db, target_table, partitionBy=["TABLE_NAME", "FIELD_NAME"])
140 | dbutils.notebook.exit(json.dumps({"status": "success", "message": ""}))
141 |
142 | else:
143 | dbutils.notebook.exit(json.dumps({"status": "skipped", "message": f"Schema of table {db}.{target_table} already equal to target schema"}))
144 |
145 |
146 | # COMMAND ----------
147 |
148 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/schemas/uplifts/open_data_metadata_uplift_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC # Setup
4 |
5 | # COMMAND ----------
6 |
7 | # MAGIC %md
8 | # MAGIC ## Imports
9 |
10 | # COMMAND ----------
11 |
12 | # MAGIC %run ../../notebooks/common/table_helpers
13 |
14 | # COMMAND ----------
15 |
16 | # MAGIC %run ../../notebooks/common/spark_helpers
17 |
18 | # COMMAND ----------
19 |
20 | # MAGIC %run ../../notebooks/scraper_stages/schemas/meta_schema
21 |
22 | # COMMAND ----------
23 |
24 | import json
25 | from pyspark.sql.types import StructType
26 |
27 | database_name = "artificial_hes_meta"
28 | table_name = "open_data_metadata_uplift_tests__artificial_hes_meta"
29 | backup_table_name = f"{table_name}_v1_schema_backup"
30 |
31 | # COMMAND ----------
32 |
33 | # MAGIC %md
34 | # MAGIC ## Setup tables
35 |
36 | # COMMAND ----------
37 |
38 | v1_meta_schema_json = '{"type": "struct", "fields": [{"name": "TABLE_NAME", "type": "string", "nullable": true, "metadata": {}}, {"name": "FIELD_NAME", "type": "string", "nullable": true, "metadata": {}}, {"name": "METADATA", "type": {"type": "struct", "fields": [{"name": "VALUE_TYPE", "type": "string", "nullable": true, "metadata": {}}, {"name": "FIELD_TYPE", "type": "string", "nullable": true, "metadata": {}}, {"name": "SUMMARY", "type": {"type": "struct", "fields": [{"name": "VALUE_TYPE", "type": "string", "nullable": true, "metadata": {}}, {"name": "CATEGORICAL", "type": {"type": "struct", "fields": [{"name": "VALUE_TYPE", "type": "string", "nullable": true, "metadata": {}}, {"name": "FLAT", "type": {"type": "struct", "fields": [{"name": "VALUE", "type": "string", "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "NESTED", "type": {"type": "struct", "fields": [{"name": "VALUE", "type": {"type": "array", "elementType": {"type": "struct", "fields": [{"name": "VALUE", "type": "string", "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}]}, "containsNull": true}, "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "DISCRETE", "type": {"type": "struct", "fields": [{"name": "VALUE", "type": "integer", "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "CONTINUOUS", "type": {"type": "struct", "fields": [{"name": "PERCENTILE", "type": "double", "nullable": true, "metadata": {}}, {"name": "VALUE", "type": "double", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "DATE", "type": {"type": "struct", "fields": [{"name": "VALUE", "type": "date", "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}, {"name": "FORMAT", "type": "string", "nullable": true, "metadata": {}}, {"name": "NOISE", "type": {"type": "array", "elementType": {"type": "struct", "fields": [{"name": "REPLACEMENT", "type": "string", "nullable": true, "metadata": {}}, {"name": "WEIGHT", "type": "double", "nullable": true, "metadata": {}}]}, "containsNull": true}, "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "ID", "type": "string", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}, {"name": "RELATIONSHIP", "type": {"type": "struct", "fields": [{"name": "LINKED_TABLE_NAME", "type": "string", "nullable": true, "metadata": {}}, {"name": "LINKED_FIELD_NAME", "type": "string", "nullable": true, "metadata": {}}, {"name": "MAPPING", "type": "long", "nullable": true, "metadata": {}}, {"name": "FREQUENCY", "type": "double", "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}]}, "nullable": true, "metadata": {}}]}'
39 |
40 | v1_meta_schema = StructType.fromJson(json.loads(v1_meta_schema_json))
41 |
42 | print(f"Creating table to test uplift: {database_name}.{table_name}")
43 | create_table_from_schema(spark, database_name, table_name, v1_meta_schema)
44 |
45 | # COMMAND ----------
46 |
47 | # MAGIC %md # Run tests
48 |
49 | # COMMAND ----------
50 |
51 | # Run the uplift twice: the first run represents a fresh database; the second checks that re-running doesn't break anything!
52 | for i in range(2):
53 | # Uplift
54 | result = dbutils.notebook.run("./open_data_metadata_uplift", 0, {"database_name": database_name, "table_name": table_name})
55 | print(f"Uplift finished with response: {result}")
56 |
57 | # Test for backup
58 | assert table_exists(spark, database_name, backup_table_name), "Backup table doesn't exist"
59 | assert datatypes_equal(spark.table(f"{database_name}.{backup_table_name}").schema, v1_meta_schema), "Backup table has incorrect schema"
60 |
61 | # Test for uplifted
62 | uplift_meta_schema = get_meta_schema()
63 | assert table_exists(spark, database_name, table_name), "Uplift table doesn't exist"
64 | assert datatypes_equal(spark.table(f"{database_name}.{table_name}").schema, uplift_meta_schema), "Uplifted table has incorrect schema"
65 |
66 |
67 | # COMMAND ----------
68 |
69 | # MAGIC %md
70 | # MAGIC # Cleanup
71 |
72 | # COMMAND ----------
73 |
74 | # MAGIC %md
75 | # MAGIC ## Teardown tables
76 |
77 | # COMMAND ----------
78 |
79 | print("Cleaning up tables")
80 | drop_table(spark, database_name, backup_table_name)
81 | drop_table(spark, database_name, table_name)
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/admin_tests/placeholder_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/common_tests/run_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC This notebook will run all the tests!
4 |
5 | # COMMAND ----------
6 |
7 | test_notebook_paths = [
8 | "./test_spark_helpers",
9 | ]
10 |
11 | for notebook_path in test_notebook_paths:
12 | dbutils.notebook.run(notebook_path, 0)
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/dataset_definitions_tests/placeholder_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/dataset_definitions_tests/run_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC This notebook will run all the tests!
4 |
5 | # COMMAND ----------
6 |
7 | test_notebook_paths = []
8 |
9 | for notebook_path in test_notebook_paths:
10 | dbutils.notebook.run(notebook_path, 0)
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/demographic_field_generators_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../../../notebooks/common/spark_helpers
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../../test_helpers
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ../../../notebooks/generator_pipelines/artificial_hes/demographic_field_generators
11 |
12 | # COMMAND ----------
13 |
14 | # import demographic_field_generators
15 |
16 |
17 | from datetime import datetime
18 | from pyspark.sql import DataFrame, functions as F
19 |
20 |
21 | def tidy_demographic_meta_df_fixture() -> DataFrame:
22 | data = [
23 | ["field1", "DEMOGRAPHIC_CATEGORICAL", ["a", "b"], None, 1., [2., 3.], None],
24 | ["field1", "DEMOGRAPHIC_CATEGORICAL", ["c", "d"], None, 2., [1., 1.], None],
25 | ["field1", "DEMOGRAPHIC_CATEGORICAL", ["e"], None, 3., [1.], None],
26 | ["field2", "DEMOGRAPHIC_CATEGORICAL", ["a"], None, 3., [1.], None],
27 | ["field2", "DEMOGRAPHIC_CATEGORICAL", ["b", "c"], None, 2., [2., 3.], None],
28 | ["field2", "DEMOGRAPHIC_CATEGORICAL", ["d", "e"], None, 1., [1., 1.], None],
29 | ["field3", "DEMOGRAPHIC_DATE", None, "2022-01-01", 1., None, 0.5],
30 | ["field3", "DEMOGRAPHIC_DATE", None, "2022-01-02", 2., None, 1.0],
31 | ["field3", "DEMOGRAPHIC_DATE", None, "2022-01-03", 3., None, 1.0],
32 | ["field4", "DEMOGRAPHIC_DATE", None, "2022-01-01", 3., None, 1.0],
33 | ["field4", "DEMOGRAPHIC_DATE", None, "2022-01-02", 2., None, 0.8],
34 | ["field4", "DEMOGRAPHIC_DATE", None, "2022-01-03", 1., None, 1.0],
35 | ]
36 |     schema = "FIELD_NAME: string, VALUE_TYPE: string, VALUE_STRING_ARRAY: array<string>, VALUE_STRING: string, WEIGHT: double, WEIGHT_ARRAY: array<double>, VALUE_NUMERIC: double"
37 | df = spark.createDataFrame(data, schema)
38 | return df
39 |
40 | # COMMAND ----------
41 |
42 | demographic_categorical_field_generator_test_suite = FunctionTestSuite()
43 |
44 |
45 | @demographic_categorical_field_generator_test_suite.add_test
46 | def test_generated_values_have_expected_frequency_distribution():
47 | n_rows = 10000
48 | n_array_values = 10
49 | base_df = (
50 | spark.range(n_rows)
51 | .withColumn("id2", F.explode(F.sequence(F.lit(1), F.lit(n_array_values), F.lit(1)))) # 10 values per unique id
52 | )
53 |
54 | meta_df = tidy_demographic_meta_df_fixture()
55 |
56 | result_df = demographic_categorical_field_generator(meta_df, base_df, ["FIELD_NAME"], ["id", "id2"])
57 |
58 | aggregated_result_df = (
59 | result_df
60 | .groupBy("FIELD_NAME", "VALUE_STRING")
61 | .agg((F.count(F.lit(1)) / (n_rows * n_array_values)).alias("FREQUENCY"))
62 | )
63 |
64 | # Need to account for the vector weights too
65 | expected_frequency_col = (
66 | F
67 | .when((F.col("FIELD_NAME") == "field1") & (F.col("VALUE_STRING") == "a"), (2 / 5) * (1 / 6))
68 | .when((F.col("FIELD_NAME") == "field1") & (F.col("VALUE_STRING") == "b"), (3 / 5) * (1 / 6))
69 | .when((F.col("FIELD_NAME") == "field1") & (F.col("VALUE_STRING") == "c"), (1 / 2) * (2 / 6))
70 | .when((F.col("FIELD_NAME") == "field1") & (F.col("VALUE_STRING") == "d"), (1 / 2) * (2 / 6))
71 | .when((F.col("FIELD_NAME") == "field1") & (F.col("VALUE_STRING") == "e"), 3 / 6)
72 | .when((F.col("FIELD_NAME") == "field2") & (F.col("VALUE_STRING") == "a"), 3 / 6)
73 | .when((F.col("FIELD_NAME") == "field2") & (F.col("VALUE_STRING") == "b"), (2 / 5) * (2 / 6))
74 | .when((F.col("FIELD_NAME") == "field2") & (F.col("VALUE_STRING") == "c"), (3 / 5) * (2 / 6))
75 | .when((F.col("FIELD_NAME") == "field2") & (F.col("VALUE_STRING") == "d"), (1 / 2) * (1 / 6))
76 | .when((F.col("FIELD_NAME") == "field2") & (F.col("VALUE_STRING") == "e"), (1 / 2) * (1 / 6))
77 | )
78 |
79 | aggregated_result_df = aggregated_result_df.withColumn("EXPECTED_FREQUENCY", expected_frequency_col)
80 | assert columns_approx_equal(aggregated_result_df, "FREQUENCY", "EXPECTED_FREQUENCY", 0.1), "Some or all frequencies did not equal the expected values"
81 |
82 |
83 | demographic_categorical_field_generator_test_suite.run()
84 |
85 |
86 | # COMMAND ----------
87 |
88 | demographic_date_field_generator_test_suite = FunctionTestSuite()
89 |
90 |
91 | @demographic_date_field_generator_test_suite.add_test
92 | def test_generated_values_have_expected_frequency_distribution():
93 | n_rows = 10000
94 | n_array_values = 100
95 | base_df = (
96 | spark.range(n_rows)
97 | .withColumn("id2", F.explode(F.sequence(F.lit(1), F.lit(n_array_values), F.lit(1)))) # n values per unique id
98 | )
99 |
100 | meta_df = tidy_demographic_meta_df_fixture()
101 |
102 | result_df = demographic_date_field_generator(meta_df, base_df, ["FIELD_NAME"], ["id", "id2"])
103 |
104 | aggregated_result_df = (
105 | result_df
106 | .groupBy("FIELD_NAME", "VALUE_STRING")
107 | .agg((F.count(F.lit(1)) / (n_rows * n_array_values)).alias("FREQUENCY"))
108 | )
109 |
110 | # Expected frequencies (approx. accounting for the noise incidence)
111 | expected_frequency_col = (
112 | F
113 | .when((F.col("FIELD_NAME") == "field3") & (F.col("VALUE_STRING") == F.lit("2022-01-01")), 0.5 * (1 / 6))
114 | .when((F.col("FIELD_NAME") == "field3") & (F.col("VALUE_STRING") == F.lit("2022-01-02")), (2 / 6) + ((2 / 5) * (1 / 24)))
115 | .when((F.col("FIELD_NAME") == "field3") & (F.col("VALUE_STRING") == F.lit("2022-01-03")), (3 / 6) + ((3 / 5) * (1 / 24)))
116 | .when((F.col("FIELD_NAME") == "field4") & (F.col("VALUE_STRING") == F.lit("2022-01-01")), (3 / 6) + ((3 / 5) * (1 / 30)))
117 | .when((F.col("FIELD_NAME") == "field4") & (F.col("VALUE_STRING") == F.lit("2022-01-02")), 0.8 * (2 / 6))
118 | .when((F.col("FIELD_NAME") == "field4") & (F.col("VALUE_STRING") == F.lit("2022-01-03")), (1 / 6) + ((2 / 5) * (1 / 30)))
119 | )
120 | aggregated_result_df = aggregated_result_df.withColumn("EXPECTED_FREQUENCY", expected_frequency_col)
121 |
122 | assert columns_approx_equal(aggregated_result_df, "FREQUENCY", "EXPECTED_FREQUENCY", 0.1), "Some or all frequencies did not equal the expected values"
123 |
124 |
125 | demographic_date_field_generator_test_suite.run()
126 |
127 | # COMMAND ----------
128 |
129 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/field_definitions_tests/derivations_tests/age_fields_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../imports
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../../../../../notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/age_fields
7 |
8 | # COMMAND ----------
9 |
10 | from pyspark.sql import functions as F
11 |
12 | # COMMAND ----------
13 |
14 | def create_test_date_df():
15 | df = spark.createDataFrame([
16 | (0, None, "2021-12-31"),
17 | (1, "2021-12-31", "2021-12-31"),
18 | (2, "2021-12-31", "2021-12-26"),
19 | (3, "2021-12-31", "2021-12-04"),
20 | (4, "2021-12-31", "2021-10-08"),
21 | (5, "2021-12-31", "2021-07-05"),
22 | (6, "2021-12-31", "2021-05-04"),
23 | (7, "2021-12-31", "2021-01-01"),
24 | (8, "2021-12-31", "1900-12-21"),
25 | (9, "2021-12-31", "2001-12-31"),
26 | ], ["id", "ADMIDATE", "DOB"]
27 | )
28 |
29 | return df
30 |
31 |
32 | # COMMAND ----------
33 |
34 | test_get_fractional_age_field = FunctionTestSuite()
35 |
36 |
37 | @test_get_fractional_age_field.add_test
38 | def test_creates_column_with_expected_values():
39 | # Inputs
40 | df = create_test_date_df()
41 |
42 | # Define expectations
43 | expected_fractional_age = (
44 | F.when(F.col("id") == 0, None)
45 | .when(F.col("id") == 1, 0.002)
46 | .when(F.col("id") == 2, 0.010)
47 | .when(F.col("id") == 3, 0.048)
48 | .when(F.col("id") == 4, 0.167)
49 | .when(F.col("id") == 5, 0.375)
50 | .when(F.col("id") == 6, 0.625)
51 | .when(F.col("id") == 7, 0.875)
52 | .when(F.col("id") == 8, 120)
53 | .when(F.col("id") == 9, 7305/365)
54 | )
55 | df = df.withColumn("EXPECTED_FRAC_AGE", expected_fractional_age)
56 |
57 | # Execute function under test
58 | fractional_age_calc = get_fractional_age_field("ADMIDATE", "DOB")
59 | df = df.withColumn("FRAC_AGE_CALC", fractional_age_calc)
60 |
61 | # Check against expectation
62 | assert columns_equal(df, "FRAC_AGE_CALC", "EXPECTED_FRAC_AGE")
63 |
64 |
65 | test_get_fractional_age_field.run()
66 |
67 | # COMMAND ----------
68 |
69 | test_get_categorized_age_field = FunctionTestSuite()
70 |
71 |
72 | @test_get_categorized_age_field.add_test
73 | def test_creates_column_with_expected_values():
74 | # Inputs
75 | df = create_test_date_df()
76 |
77 | # Define expectations
78 | expected_categorized_age = (
79 | F.when(F.col("id") == 0, None)
80 | .when(F.col("id") == 1, 7001)
81 | .when(F.col("id") == 2, 7002)
82 | .when(F.col("id") == 3, 7003)
83 | .when(F.col("id") == 4, 7004)
84 | .when(F.col("id") == 5, 7005)
85 | .when(F.col("id") == 6, 7006)
86 | .when(F.col("id") == 7, 7007)
87 | .when(F.col("id") == 8, 120)
88 | .when(F.col("id") == 9, 20)
89 | )
90 | df = df.withColumn("EXPECTED_CAT_AGE", expected_categorized_age)
91 |
92 | # Execute function under test
93 | categorized_age_calc = get_categorized_age_field("ADMIDATE", "DOB")
94 | df = df.withColumn("CAT_AGE_CALC", categorized_age_calc)
95 |
96 | # Check against expectation
97 | assert columns_equal(df, "CAT_AGE_CALC", "EXPECTED_CAT_AGE")
98 |
99 |
100 | test_get_categorized_age_field.run()
101 |
102 | # COMMAND ----------
103 |
104 | test_get_fractional_from_categorized_age_field = FunctionTestSuite()
105 |
106 |
107 | @test_get_fractional_from_categorized_age_field.add_test
108 | def test_creates_column_with_expected_values():
109 | # Inputs
110 | schema = "id: int, APPTAGE: int"
111 | df = spark.createDataFrame(
112 | [
113 | (0, None),
114 | (1, 7001),
115 | (2, 7002),
116 | (3, 7003),
117 | (4, 7004),
118 | (5, 7005),
119 | (6, 7006),
120 | (7, 7007),
121 | (8, 120),
122 | (9, 38)
123 | ],
124 | schema
125 | )
126 |
127 | # Define expectations
128 | expected_fractional_age = (
129 | F.when(F.col("id") == 0, None)
130 | .when(F.col("id") == 1, 0.002)
131 | .when(F.col("id") == 2, 0.010)
132 | .when(F.col("id") == 3, 0.048)
133 | .when(F.col("id") == 4, 0.167)
134 | .when(F.col("id") == 5, 0.375)
135 | .when(F.col("id") == 6, 0.625)
136 | .when(F.col("id") == 7, 0.875)
137 | .when(F.col("id") == 8, 120)
138 | .when(F.col("id") == 9, 38)
139 | )
140 | df = df.withColumn("EXPECTED_APPTAGE_CALC", expected_fractional_age)
141 |
142 | # Execute function under test
143 | fractional_age = get_fractional_from_categorized_age_field("APPTAGE")
144 | df = df.withColumn("APPTAGE_CALC", fractional_age)
145 |
146 | # Check against expectation
147 | assert columns_equal(df, "APPTAGE_CALC", "EXPECTED_APPTAGE_CALC"), df.show()
148 |
149 |
150 | test_get_fractional_from_categorized_age_field.run()
151 |
152 | # COMMAND ----------
153 |
154 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/field_definitions_tests/derivations_tests/sequential_fields_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../imports
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../../../../../notebooks/generator_pipelines/artificial_hes/field_definitions/derivations/sequential_fields
7 |
8 | # COMMAND ----------
9 |
10 | from pyspark.sql import functions as F, types as T
11 | from datetime import datetime
12 | from itertools import starmap
13 |
14 | # COMMAND ----------
15 |
16 | test_sort_fields_l2r = FunctionTestSuite()
17 |
18 |
19 | @test_sort_fields_l2r.add_test
20 | def test_sorts_two_columns_correctly():
21 | # Inputs
22 | schema = "DATE_OF_BIRTH: date, DATE_OF_DEATH: date"
23 | df = spark.createDataFrame([
24 | [datetime(2021, 1, 1), datetime(2020, 1, 1)],
25 | [datetime(2019, 1, 1), datetime(2020, 1, 1)],
26 | [datetime(2019, 1, 1), datetime(2016, 1, 12)],
27 | [datetime(2019, 2, 1), datetime(2019, 1, 1)],
28 | ], schema)
29 |
30 | # Execute function under test
31 | sorted_columns = sort_fields_l2r(*df.columns)
32 | sorted_columns = starmap(lambda field_name, col_spec: col_spec.alias(field_name), sorted_columns) # Assign aliases to columns
33 | result_df = df.select(*sorted_columns)
34 |
35 | # Define expectations
36 | expected_df = spark.createDataFrame([
37 | [datetime(2020, 1, 1), datetime(2021, 1, 1)],
38 | [datetime(2019, 1, 1), datetime(2020, 1, 1)],
39 | [datetime(2016, 1, 12), datetime(2019, 1, 1)],
40 | [datetime(2019, 1, 1), datetime(2019, 2, 1)],
41 | ], schema)
42 |
43 | # Check against expectation
44 | assert dataframes_equal(result_df, expected_df)
45 |
46 |
47 | @test_sort_fields_l2r.add_test
48 | def test_sorts_many_columns_correctly():
49 | # Inputs
50 | schema = "DATE_OF_BIRTH: date, DATE_OF_ADMISSION: date, DATE_OF_TREATMENT: date, DATE_OF_DISCHARGE: date"
51 | df = spark.createDataFrame([
52 | [datetime(2021, 1, 1), datetime(2020, 1, 1), datetime(2020, 3, 1), datetime(2020, 2, 1)],
53 | [datetime(2016, 1, 1), datetime(2018, 1, 1), datetime(2018, 3, 5), datetime(2020, 3, 4)],
54 | [datetime(2021, 1, 1), datetime(2020, 2, 1), datetime(2020, 3, 1), datetime(2020, 4, 1)],
55 | [datetime(2021, 1, 1), datetime(2020, 1, 1), datetime(2019, 3, 1), datetime(2018, 2, 1)],
56 | ], schema)
57 |
58 | # Execute function under test
59 | sorted_columns = sort_fields_l2r(*df.columns)
60 | sorted_columns = starmap(lambda field_name, col_spec: col_spec.alias(field_name), sorted_columns) # Assign aliases to columns
61 | result_df = df.select(*sorted_columns)
62 |
63 | # Define expectations
64 | expected_df = spark.createDataFrame([
65 | [datetime(2020, 1, 1), datetime(2020, 2, 1), datetime(2020, 3, 1), datetime(2021, 1, 1)],
66 | [datetime(2016, 1, 1), datetime(2018, 1, 1), datetime(2018, 3, 5), datetime(2020, 3, 4)],
67 | [datetime(2020, 2, 1), datetime(2020, 3, 1), datetime(2020, 4, 1), datetime(2021, 1, 1)],
68 | [datetime(2018, 2, 1), datetime(2019, 3, 1), datetime(2020, 1, 1), datetime(2021, 1, 1)],
69 | ], schema)
70 |
71 | # Check against expectation
72 | assert dataframes_equal(result_df, expected_df)
73 |
74 |
75 | @test_sort_fields_l2r.add_test
76 | def test_pushes_nulls_to_right_when_sorting():
77 | # Inputs
78 | schema = "DATE_OF_BIRTH: date, DATE_OF_DEATH: date"
79 | df = spark.createDataFrame([
80 | [datetime(2021, 1, 1), None],
81 | [datetime(2019, 1, 1), datetime(2020, 1, 1)],
82 | [None, None],
83 | [None, datetime(2019, 1, 1)],
84 | ], schema)
85 |
86 | # Execute function under test
87 | sorted_columns = sort_fields_l2r(*df.columns)
88 | sorted_columns = starmap(lambda field_name, col_spec: col_spec.alias(field_name), sorted_columns) # Assign aliases to columns
89 | result_df = df.select(*sorted_columns)
90 |
91 | # Define expectations
92 | expected_df = spark.createDataFrame([
93 | [datetime(2021, 1, 1), None],
94 | [datetime(2019, 1, 1), datetime(2020, 1, 1)],
95 | [None, None],
96 | [datetime(2019, 1, 1), None],
97 | ], schema)
98 |
99 | # Check against expectation
100 | assert dataframes_equal(result_df, expected_df)
101 |
102 |
103 | @test_sort_fields_l2r.add_test
104 | def test_does_not_affect_unspecified_columns():
105 | # Inputs
106 | schema = ["DATE_OF_BIRTH", "DATE_OF_DEATH", "ID", "NAME"]
107 | df = spark.createDataFrame([
108 | [datetime(2021, 1, 1), datetime(2020, 1, 1), 0, "name1"],
109 | [datetime(2019, 1, 1), datetime(2020, 1, 1), 1, "name1"],
110 | [datetime(2019, 1, 1), datetime(2016, 1, 12), 2, "name2"],
111 | [datetime(2019, 2, 1), datetime(2019, 1, 1), 3, "name3"],
112 | ], schema)
113 |
114 | # Execute function under test
115 | sorted_columns = sort_fields_l2r("DATE_OF_BIRTH", "DATE_OF_DEATH")
116 | sorted_columns = starmap(lambda field_name, col_spec: col_spec.alias(field_name), sorted_columns) # Assign aliases to columns
117 | result_df = df.select(*sorted_columns, "ID", "NAME")
118 |
119 | # Define expectations
120 | expected_df = spark.createDataFrame([
121 | [datetime(2020, 1, 1), datetime(2021, 1, 1), 0, "name1"],
122 | [datetime(2019, 1, 1), datetime(2020, 1, 1), 1, "name1"],
123 | [datetime(2016, 1, 12), datetime(2019, 1, 1), 2, "name2"],
124 | [datetime(2019, 1, 1), datetime(2019, 2, 1), 3, "name3"],
125 | ], schema)
126 |
127 | # Check against expectation
128 | assert dataframes_equal(result_df, expected_df)
129 |
130 |
131 | test_sort_fields_l2r.run()
132 |
133 | # COMMAND ----------
134 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/field_definitions_tests/id_fields_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ./imports
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../../../../notebooks/dependencies/spark_rstr
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ../../../../notebooks/generator_pipelines/artificial_hes/field_definitions/id_fields
11 |
12 | # COMMAND ----------
13 |
14 | from pyspark.sql import functions as F
15 |
16 | # COMMAND ----------
17 |
18 | rstr = SparkRstr()
19 |
20 | # COMMAND ----------
21 |
22 | # ID fields:
23 |
24 | APC_ID_FIELDS = [
25 | "SUSRECID",
26 | "PSEUDO_HESID",
27 | "AEKEY",
28 | "EPIKEY",
29 | ]
30 |
31 | OP_ID_FIELDS = [
32 | "PSEUDO_HESID",
33 | "ATTENDID",
34 | "ATTENDKEY",
35 | "PREFERER",
36 | ]
37 |
38 | AE_ID_FIELDS = [
39 | "AEKEY",
40 | "PSEUDO_HESID",
41 | "EPIKEY"
42 | ]
43 |
44 | # COMMAND ----------
45 |
46 | # Patterns:
47 |
48 | ID_FIELD_PATTERNS = dict(
49 | PSEUDO_HESID = r"TEST[0-9a-zA-Z]{28}", # 32an - first 4 chars = TEST to ensure no overlap with real IDs
50 | SUSRECID = r"\d{14}",
51 | AEKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22
52 | EPIKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22
53 | ATTENDNO = r"[0-9a-zA-Z]{12}",
54 | ATTENDKEY = r"\d{12}", # Changes to r"\d{20}" in 2021/22
55 | ATTENDID = r"[0-9a-zA-Z]{12}",
56 | PREFERER = r"[0-9a-zA-Z]{16}", # What about nulls (&) / invalids (99)
57 | )
58 |
59 | # COMMAND ----------
60 |
61 | test_get_id_fields = FunctionTestSuite()
62 |
63 |
64 | @test_get_id_fields.add_test
65 | def test_creates_hes_ae_columns():
66 | # Inputs
67 | df = spark.range(15)
68 |
69 | # Execute function under test
70 | id_field_cols = get_id_fields("hes_ae")
71 | for field_name in id_field_cols:
72 | df = df.withColumn(field_name, id_field_cols[field_name])
73 |
74 | for field in AE_ID_FIELDS:
75 | # Define expectations
76 | extracted_field_name = f'{field}_TEST'
77 | extracted_field = F.regexp_extract(field, ID_FIELD_PATTERNS[field], 0)
78 | df = df.withColumn(extracted_field_name, extracted_field)
79 |
80 | # Check against expectation
81 | assert columns_equal(df, field, extracted_field_name)
82 |
83 |
84 | @test_get_id_fields.add_test
85 | def test_creates_hes_apc_columns():
86 | # Inputs
87 | df = spark.range(15)
88 |
89 | # Execute function under test
90 | id_field_cols = get_id_fields("hes_apc")
91 | for field_name in id_field_cols:
92 | df = df.withColumn(field_name, id_field_cols[field_name])
93 |
94 | for field in APC_ID_FIELDS:
95 | # Define expectations
96 | extracted_field_name = f'{field}_TEST'
97 | extracted_field = F.regexp_extract(field, ID_FIELD_PATTERNS[field], 0)
98 | df = df.withColumn(extracted_field_name, extracted_field)
99 |
100 | # Check against expectation
101 | assert columns_equal(df, field, extracted_field_name)
102 |
103 |
104 | @test_get_id_fields.add_test
105 | def test_creates_hes_op_columns():
106 | # Inputs
107 | df = spark.range(15)
108 |
109 | # Execute function under test
110 | id_field_cols = get_id_fields("hes_op")
111 | for field_name in id_field_cols:
112 | df = df.withColumn(field_name, id_field_cols[field_name])
113 |
114 | for field in OP_ID_FIELDS:
115 | # Define expectations
116 | extracted_field_name = f'{field}_TEST'
117 | extracted_field = F.regexp_extract(field, ID_FIELD_PATTERNS[field], 0)
118 | df = df.withColumn(extracted_field_name, extracted_field)
119 |
120 | # Check against expectation
121 | assert columns_equal(df, field, extracted_field_name)
122 |
123 |
124 | test_get_id_fields.run()
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/field_definitions_tests/imports.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../../../test_helpers
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../../../../notebooks/common/spark_helpers
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/artificial_hes_tests/run_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC This notebook will run all the tests!
4 |
5 | # COMMAND ----------
6 |
7 | test_notebook_paths = [
8 | "./field_definitions_tests/derivations_tests/age_fields_tests",
9 | "./field_definitions_tests/derivations_tests/coded_fields_tests",
10 | "./field_definitions_tests/derivations_tests/derivations_helpers_tests",
11 | "./field_definitions_tests/derivations_tests/sequential_fields_tests",
12 | "./field_definitions_tests/id_fields_tests",
13 | "./demographic_field_generators_tests",
14 | ]
15 |
16 | for notebook_path in test_notebook_paths:
17 | dbutils.notebook.run(notebook_path, 0)
18 |
19 | # COMMAND ----------
20 |
21 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/generator_pipelines_tests/run_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC This notebook will run all the tests!
4 |
5 | # COMMAND ----------
6 |
7 | test_notebook_paths = [
8 | "./artificial_hes_tests/run_tests"
9 | ]
10 |
11 | for notebook_path in test_notebook_paths:
12 | dbutils.notebook.run(notebook_path, 0)
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/generator_stages_tests/run_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC This notebook will run all the tests!
4 |
5 | # COMMAND ----------
6 |
7 | test_notebook_paths = [
8 | "./field_generators_tests",
9 | "./relationship_generator_tests",
10 | "./sampling_tests/field_definitions_tests",
11 | ]
12 |
13 | for notebook_path in test_notebook_paths:
14 | dbutils.notebook.run(notebook_path, 0)
15 |
16 | # COMMAND ----------
17 |
18 |
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/generator_stages_tests/sampling_tests/field_definitions_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %run ../../../notebooks/common/spark_helpers
3 |
4 | # COMMAND ----------
5 |
6 | # MAGIC %run ../../test_helpers
7 |
8 | # COMMAND ----------
9 |
10 | # MAGIC %run ../../../notebooks/generator_stages/sampling/field_definitions
11 |
12 | # COMMAND ----------
13 |
14 | import random
15 | from pyspark.sql import functions as F
16 |
17 | # COMMAND ----------
18 |
19 | # region total_over_window tests
20 |
21 | total_over_window_test_suite = FunctionTestSuite()
22 |
23 |
24 | @total_over_window_test_suite.add_test
25 | def test_correctly_computes_total_within_partition():
26 | input_data = [
27 | ["x", 1],
28 | ["y", 2],
29 | ["z", 3],
30 | ["x", 4],
31 | ["y", 5],
32 | ["z", 6],
33 | ["x", 7],
34 | ["y", 8],
35 | ["z", 9],
36 | ]
37 | input_schema = "FIELD_NAME: string, VALUE: integer"
38 | input_df = spark.createDataFrame(input_data, input_schema)
39 |
40 | sum_col = "VALUE"
41 | partition_cols = ["FIELD_NAME"]
42 | result_df = input_df.withColumn("TOTAL", total_over_window(sum_col, partition_cols))
43 |
44 | expected_total_col = (
45 | F
46 | .when(F.col("FIELD_NAME") == "x", 1 + 4 + 7)
47 | .when(F.col("FIELD_NAME") == "y", 2 + 5 + 8)
48 | .when(F.col("FIELD_NAME") == "z", 3 + 6 + 9)
49 | .cast("long")
50 | )
51 |
52 | result_df = result_df.withColumn("EXPECTED_TOTAL", expected_total_col)
53 |
54 | assert columns_equal(result_df, "TOTAL", "EXPECTED_TOTAL"), "Total sum did not equal expected value"
55 |
56 |
57 | total_over_window_test_suite.run()
58 |
59 | # endregion
60 |
61 | # COMMAND ----------
62 |
63 | # region cumsum_over_window tests
64 |
65 | cumsum_over_window_test_suite = FunctionTestSuite()
66 |
67 |
68 | @cumsum_over_window_test_suite.add_test
69 | def test_correctly_computes_cumsum_over_partition():
70 | # Column 1 specifies partition
71 | # Column 2 specifies ordering within the partition
72 | # Column 3 is the value to cumsum
73 | input_data = [
74 | ["x", 1, 1],
75 | ["y", 1, 2],
76 | ["z", 1, 3],
77 | ["x", 2, 4],
78 | ["y", 2, 5],
79 | ["z", 2, 6],
80 | ["x", 3, 7],
81 | ["y", 3, 8],
82 | ["z", 3, 9],
83 | ]
84 | random.shuffle(input_data)
85 | input_schema = "PARTITION_FIELD: string, ORDER_FIELD: integer, VALUE: long"
86 | input_df = spark.createDataFrame(input_data, input_schema)
87 |
88 | sum_col = "VALUE"
89 |     partition_cols = ["PARTITION_FIELD"]
90 |     order_cols = ["ORDER_FIELD"]
91 | result_df = input_df.withColumn("CUMSUM", cumsum_over_window(sum_col, partition_cols, order_cols))
92 |
93 |     # Expected cumulative sum for each (PARTITION_FIELD, ORDER_FIELD) combination
94 | expected_cumsum_col = (
95 | F
96 | .when((F.col("PARTITION_FIELD") == "x") & (F.col("ORDER_FIELD") == 1), 1)
97 | .when((F.col("PARTITION_FIELD") == "x") & (F.col("ORDER_FIELD") == 2), 1 + 4)
98 | .when((F.col("PARTITION_FIELD") == "x") & (F.col("ORDER_FIELD") == 3), 1 + 4 + 7)
99 | .when((F.col("PARTITION_FIELD") == "y") & (F.col("ORDER_FIELD") == 1), 2)
100 | .when((F.col("PARTITION_FIELD") == "y") & (F.col("ORDER_FIELD") == 2), 2 + 5)
101 | .when((F.col("PARTITION_FIELD") == "y") & (F.col("ORDER_FIELD") == 3), 2 + 5 + 8)
102 | .when((F.col("PARTITION_FIELD") == "z") & (F.col("ORDER_FIELD") == 1), 3)
103 | .when((F.col("PARTITION_FIELD") == "z") & (F.col("ORDER_FIELD") == 2), 3 + 6)
104 | .when((F.col("PARTITION_FIELD") == "z") & (F.col("ORDER_FIELD") == 3), 3 + 6 + 9)
105 | .cast("long")
106 | )
107 |
108 | result_df = result_df.withColumn("EXPECTED_CUMSUM", expected_cumsum_col)
109 |
110 | assert columns_equal(result_df, "CUMSUM", "EXPECTED_CUMSUM"), "Cumulative sum did not equal expected value"
111 |
112 |
113 | cumsum_over_window_test_suite.run()
114 |
115 | # endregion
116 |
117 | # COMMAND ----------
118 |
119 |
--------------------------------------------------------------------------------
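The `total_over_window` and `cumsum_over_window` helpers exercised above are pulled in via `%run` from `notebooks/generator_stages/sampling/field_definitions`. A minimal sketch consistent with these tests, assuming both helpers are thin wrappers around PySpark window aggregations (illustrative only, not the repository's actual definitions):

```python
from pyspark.sql import Column, functions as F
from pyspark.sql.window import Window


def total_over_window(sum_col: str, partition_cols) -> Column:
    # Sum of sum_col over the whole partition defined by partition_cols
    return F.sum(sum_col).over(Window.partitionBy(*partition_cols))


def cumsum_over_window(sum_col: str, partition_cols, order_cols) -> Column:
    # Running total of sum_col within each partition, ordered by order_cols
    window = (
        Window
        .partitionBy(*partition_cols)
        .orderBy(*order_cols)
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    return F.sum(sum_col).over(window)
```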
/projects/iuod_artificial_data_generator/tests/scraper_pipelines_tests/run_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC This notebook will run all the tests!
4 |
5 | # COMMAND ----------
6 |
7 | test_notebook_paths = [
8 | "./hes/demographic_field_summarisers_tests"
9 | ]
10 |
11 | for notebook_path in test_notebook_paths:
12 | dbutils.notebook.run(notebook_path, 0)
--------------------------------------------------------------------------------
/projects/iuod_artificial_data_generator/tests/scraper_stages_tests/run_tests.py:
--------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %md
3 | # MAGIC This notebook will run all the tests!
4 |
5 | # COMMAND ----------
6 |
7 | test_notebook_paths = [
8 | "./aggregation_tests/field_summarisers_tests",
9 | "./aggregation_tests/relationship_summariser_tests",
10 | "./preprocessing_tests/test_meta_type_classifier",
11 | ]
12 |
13 | for notebook_path in test_notebook_paths:
14 | dbutils.notebook.run(notebook_path, 0)
15 |
16 | # COMMAND ----------
17 |
18 |
--------------------------------------------------------------------------------
/utils/export.ps1:
--------------------------------------------------------------------------------
1 | # Export directories from Databricks
2 |
3 | $exportEnv = $(Read-Host -Prompt "Export 'dev' or 'staging' folders? (default=dev)")
4 |
5 | if ($exportEnv -eq "staging")
6 | {
7 | $workspaceRoot = "/staging"
8 | }
9 | elseif (($exportEnv -eq "dev") -or ($exportEnv -eq ""))
10 | {
11 | $branchName = $(git rev-parse --abbrev-ref HEAD)
12 | $workspaceRoot = "/data_manager_projects/iuod_dae_test_data_generator/${branchName}"
13 | }
14 | else
15 | {
16 | "Invalid value: '$exportEnv'"
17 | exit 1;
18 | }
19 |
20 | $projects = @(
21 | "iuod_artificial_data_generator",
22 | "iuod_artificial_data_admin",
23 | "artificial_hes",
24 | "artificial_hes_meta"
25 | )
26 |
27 | git stash # Stash local changes before overwrite
28 |
29 | foreach ( $projectName in $projects )
30 | {
31 | $databricksPath = "$workspaceRoot/$projectName"
32 | $localPath = "projects/$projectName"
33 | Remove-Item $localPath -Recurse
34 | databricks workspace export_dir -o $databricksPath $localPath
35 | }
--------------------------------------------------------------------------------
/utils/import.ps1:
--------------------------------------------------------------------------------
1 | # Import directories into Databricks
2 |
3 | $importEnv = $(Read-Host -Prompt "Import to 'dev' or 'staging'? (default='dev')")
4 |
5 | if ($importEnv -eq "staging")
6 | {
7 | $workspaceRoot = "/staging"
8 | }
9 | elseif (($importEnv -eq "dev") -or ($importEnv -eq ""))
10 | {
11 | $branchName = $(git rev-parse --abbrev-ref HEAD)
12 | $workspaceRoot = "/data_manager_projects/iuod_dae_test_data_generator/${branchName}"
13 | }
14 | else
15 | {
16 | "Invalid value: '$importEnv'"
17 | exit 1;
18 | }
19 |
20 | $projects = @(
21 | "iuod_artificial_data_generator"
22 | "iuod_artificial_data_admin",
23 | "artificial_hes",
24 | "artificial_hes_meta"
25 | )
26 |
27 | foreach ( $projectName in $projects )
28 | {
29 | $localPath = "projects/$projectName"
30 | $databricksPath = "$workspaceRoot/$projectName"
31 | $resourceExists = !($(databricks workspace ls $databricksPath) -like '*"error_code":"RESOURCE_DOES_NOT_EXIST"*')
32 |
33 | if ($resourceExists) {
34 | "'$($databricksPath)' already exists in the Databricks workspace"
35 | $overwrite = Read-Host -Prompt "Do you wish to overwrite? (y/n)"
36 |
37 | if ($overwrite.ToLower() -eq "y")
38 | {
39 | # Require user confirmation of the project name for overwriting
40 | $confirmMessage = "Confirm the project name to overwrite ($($projectName))"
41 | $projectNameCheck = Read-Host -Prompt $confirmMessage
42 |
43 | while ($projectNameCheck -ne $projectName)
44 | {
45 | # Keep going until the names match
46 | "'$($projectNameCheck)' does not match '$($projectName)', please try again"
47 | $projectNameCheck = Read-Host -Prompt $confirmMessage
48 | }
49 |
50 | # Import the directory with overwrite
51 | databricks workspace import_dir -o $localPath $databricksPath
52 | }
53 | else {
54 | continue
55 | }
56 |
57 | }
58 | else
59 | {
60 | # Import the directory
61 | databricks workspace import_dir $localPath $databricksPath
62 | }
63 |
64 | }
--------------------------------------------------------------------------------
/utils/list_releases.ps1:
--------------------------------------------------------------------------------
1 | # List the latest code promotion project releases
2 |
3 | $workspaceRoot = "/Users/admin/releases/code-promotion"
4 | $projects = @(
5 | "iuod_artificial_data_generator"
6 | "iuod_artificial_data_admin",
7 | "artificial_hes",
8 | "artificial_hes_meta"
9 | )
10 |
11 | foreach ( $projectName in $projects )
12 | {
13 | $projectReleasesRoot = "$workspaceRoot/$projectName"
14 |
15 | # Get all the releases for the project
16 | $releases = $(databricks workspace ls $projectReleasesRoot)
17 |
18 | # Find the latest release
19 | $latestRelease = $null
20 | $latestTimestamp = [datetime]"1900-01-01"
21 | foreach ($release in $releases)
22 | {
23 | $releaseSegments = $release.Split(".")
24 | $releaseIdTimestamp = $releaseSegments[0].Split("+")
25 | $releaseId = $releaseIdTimestamp[0] # Not used
26 | $releaseTimestamp = [datetime]::ParseExact($releaseIdTimestamp[1], "yyyyMMddHHmmss", $null)
27 | $releaseHash = $releaseSegments[1].Substring(3) # Not used
28 |
29 | if ($latestTimestamp -lt $releaseTimestamp)
30 | {
31 | # Update the latest release
32 | $latestRelease = $release
33 | $latestTimestamp = $releaseTimestamp
34 | }
35 |
36 | }
37 |
38 | "Latest release for '$projectName': '$latestRelease' on $($latestTimestamp.ToString("yyyy-MM-dd 'at' HH:mm:ss"))"
39 | }
--------------------------------------------------------------------------------
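The release-name parsing in list_releases.ps1 implies release folder names shaped roughly like "<release-id>+<yyyyMMddHHmmss>.<three-character prefix><commit-hash>"; the actual naming convention is not documented in this repository dump. Purely to illustrate the same parse, with a made-up release name:

```python
from datetime import datetime

# Made-up release name with the shape the script assumes; the real convention may differ.
release = "1234+20230104153000.gitabc1234"

segments = release.split(".")
release_id, timestamp_text = segments[0].split("+")
release_timestamp = datetime.strptime(timestamp_text, "%Y%m%d%H%M%S")
release_hash = segments[1][3:]  # mirror Substring(3): drop the three-character prefix

print(release_id, release_timestamp, release_hash)  # 1234 2023-01-04 15:30:00 abc1234
```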