├── .gitignore ├── README.md ├── airflow-dags ├── README.md ├── images │ └── load_weather_data_into_dbsql.png └── load_weather_data_into_dbsql.py ├── azure-mi-in-py-sdk ├── README.md ├── azure-mi-in-py-sdk.py ├── poetry.lock └── pyproject.toml ├── clojure-spark-demo ├── .gitignore ├── LICENSE ├── README.md ├── project.clj └── src │ └── clojure_spark_demo │ └── core.clj ├── database-diagram-builder ├── README.md ├── database-diagram-builder-notebook.py └── database-diagram-builder-standalone.py ├── dbconnect-maven ├── README.md ├── pom.xml └── src │ └── main │ └── scala │ └── net │ └── alexott │ └── demos │ └── databricks │ └── SimpleSQL.scala ├── dbconnect-package-versions-check ├── README.md └── package_versions_check.py ├── dbconnect-sbt ├── README.md ├── build.sbt ├── project │ └── build.properties └── src │ └── main │ └── scala │ └── net │ └── alexott │ └── demos │ └── databricks │ └── SimpleSQL.scala ├── dbsql-with-aad-token ├── README.md ├── dbsql-with-aad-token-mi.py ├── dbsql-with-aad-token-spn.py ├── dbsql-with-aad-token-user.py └── requirements.txt ├── dbutils-in-jar ├── README.md ├── pom.xml └── src │ └── main │ └── scala │ └── net │ └── alexott │ └── demos │ └── DbutilsDemo.scala ├── dbutils-snippets └── repos.py ├── deactivate-activate-users-sps ├── README.md └── deactivate-activate-users-sps.py ├── guardrails-checks └── check_guardrails.py ├── init-scripts ├── README.md └── install-ssl-certificates.sh ├── ip-access-list-analyzer └── README.md ├── java-sdk-example ├── pom.xml └── src │ └── main │ └── java │ └── net │ └── alexott │ └── demos │ └── M2MAzureAuth.java ├── kafka-eventhubs-aad-auth ├── .editorconfig ├── LICENSE ├── README.md ├── pom.xml └── src │ └── main │ └── java │ └── net │ └── alexott │ └── demos │ ├── eventhubs_aad │ ├── ServicePrincipalAuthBase.java │ └── ServicePrincipalCredentialsAuth.java │ └── kafka_aad │ ├── KafkaOAuthBearerTokenImp.java │ ├── ServicePrincipalAuthBase.java │ └── ServicePrincipalCredentialsAuth.java ├── oss-jdbc-aad-token ├── README.md ├── pom.xml └── src │ └── main │ └── java │ └── net │ └── alexott │ └── demos │ ├── OssJDBCAadTokenUser.java │ ├── OssJDBCOAuthAzureServicePrincipal.java │ └── OssJDBCOAuthDatabricksServicePrincipal.java ├── pause-unpause-jobs ├── README.md └── pause_unpause_jobs.py ├── pyspark-snippets ├── README.md ├── conftest.py ├── pyspark_helpers │ ├── columns_helpers.py │ ├── data_skipping.py │ ├── delta_helpers.py │ ├── drop_duplicates.py │ └── json_helpers.py ├── pytest.ini ├── requirements.txt ├── setup.py ├── tests │ └── unit │ │ ├── test_columns_helpers.py │ │ └── test_data_skipping.py └── unit-requirements.txt ├── simba-jdbc-aad-token ├── README.md ├── pom.xml └── src │ └── main │ └── java │ └── net │ └── alexott │ └── demos │ ├── SimbaJDBCAadTokenMI.java │ ├── SimbaJDBCAadTokenServicePrincipal.java │ └── SimbaJDBCAadTokenUser.java ├── simba-odbc-aad-token ├── README.md ├── odbc-aad-service-principal.py ├── odbc-aad-user.py └── requirements.txt ├── spring-jdbc-dbsql ├── README.md ├── pom.xml └── src │ └── main │ └── java │ └── net │ └── alexott │ └── demos │ └── spring_jdbc │ └── SimpleQuery.java ├── stackoverflow └── snippets.py └── streaming_query_listener ├── QueryListenerUsageExample.py └── query_listener.py /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | target/ 3 | \.idea/ 4 | \.vscode/ 5 | *.log 6 | *.iml 7 | /dbconnect-sbt/.sbtopts 8 | /pyspark-snippets/databricks_pyspark_helpers.egg-info/ 9 | /pyspark-snippets/dist/ 10 | __pycache__/ 11 | 
.DS_Store 12 | /kafka-aad-auth/dependency-reduced-pom.xml 13 | dependency-reduced-pom.xml 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains different code samples & other examples related to the Databricks platform & Spark: 2 | 3 | * [airflow-dags](airflow-dags) - Examples of Airflow DAGs for Databricks. 4 | * [database-diagram-builder](database-diagram-builder) - tool to generate UML diagram(s) for tables in Databricks/Spark database. 5 | * [dbconnect-maven](dbconnect-maven) - skeleton of the Maven project for simple Spark job, and instructions on how to run it via databricks-connect. 6 | * [dbconnect-package-versions-check](dbconnect-package-versions-check) - tool to checks compatibility of local Databricks connect environment with Databricks cluster. 7 | * [dbconnect-sbt](dbconnect-sbt) - skeleton of the SBT project for simple Spark job, and instructions on how to run it via databricks-connect. 8 | * [dbsql-with-aad-token](dbsql-with-aad-token) - example of querying data on Databricks using [python-sql-connector](https://docs.databricks.com/dev-tools/python-sql-connector.html) library. Authentication to Databricks is performed using Azure Active Directory tokens issued for Azure Service Principal. 9 | * [dbutils-in-jar](dbutils-in-jar) - example of using Databricks dbutils in JVM-based code compiled into `.jar`. 10 | * [ip-access-list-analyzer](ip-access-list-analyzer) - analyzer/fix tool for Databricks IP Access Lists. 11 | * [kafka-eventhubs-aad-auth](kafka-eventhubs-aad-auth) - library to support Azure Active Directory authentication for Spark Kafka & EventHubs connectors accessing Event Hubs. 12 | * [pyspark-snippets](pyspark-snippets) - functions that simplify development of PySpark code for Databricks 13 | * [simba-jdbc-aad-token](simba-jdbc-aad-token) - example of querying data on Databricks using JDBC protocol. Authentication to Databricks is performed using Azure Active Directory tokens issued for Azure Service Principal. 14 | * [spring-jdbc-dbsql](spring-jdbc-dbsql) - example of querying Databricks via JDBC using Spring JDBC. 15 | 16 | You can also find more examples of Spark code in the other repositories: 17 | * [spark-playground](https://github.com/alexott/spark-playground) - general Spark code 18 | * [cassandra-dse-playground](https://github.com/alexott/cassandra-dse-playground) - Spark code specific for use with Cassandra/DataStax Enterprise (DSE) 19 | 20 | 21 | -------------------------------------------------------------------------------- /airflow-dags/README.md: -------------------------------------------------------------------------------- 1 | # Airflow DAGs 2 | 3 | This directory contains examples of Airflow DAGs that use [apache-airflow-providers-databricks](https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/index.html). 4 | 5 | ## Loading weather data into Databricks SQL Endpoint 6 | 7 | [load_weather_data_into_dbsql.py](load_weather_data_into_dbsql.py) is an example of the Airflow DAG that loads weather data for some cities into a Delta table using Databricks SQL Endpoint. 
DAG consists of following steps: 8 | 9 | * `create_table` - creates a Delta table if it doesn't exist using `DatabricksSqlOperator` 10 | * `get_weather_data` - fetch weather data using calls to REST API and saves data to a local disk using `PythonOperator` 11 | * `upload_weather_data` - uploads data from local disk to Azure Blob Storage using `LocalFilesystemToWasbOperator` 12 | * `import_weather_data` - imports uploaded data with `COPY INTO` SQL command executed via `DatabricksCopyIntoOperator`. 13 | 14 | ![DAG in Ariflow UI](images/load_weather_data_into_dbsql.png) 15 | 16 | To make it working in your environment you need to change following constants: 17 | 18 | * `WASBS_CONN_ID` - name of a [Azure Blob Storage connection](https://airflow.apache.org/docs/apache-airflow-providers-microsoft-azure/stable/connections/wasb.html). 19 | * `DATABRICKS_SQL_ENDPOINT_NAME` - name of a [Databricks SQL endpoint](https://docs.databricks.com/sql/admin/sql-endpoints.html) that will be used for creation of the table and importing of data. 20 | * `DATABRICKS_CONN_ID` - name of a [Databricks connection](https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/connections/databricks.html) that will be used for authentication to Databricks workspace. 21 | * `DESTINATION_TABLE_NAME` - name of Delta table that will be created & loaded with data. 22 | * `LANDING_LOCATION_PREFIX` - name of directory inside the ADLS container. 23 | * `ADLS_CONTAINER_NAME` - name of ADLS container. 24 | * `ADLS_STORAGE_NAME` - name of ADLS storage account (without `.dfs.core.windows.net`). 25 | 26 | -------------------------------------------------------------------------------- /airflow-dags/images/load_weather_data_into_dbsql.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexott/databricks-playground/3bb1b2fb5e21730704b6c905f78324026ee1cfed/airflow-dags/images/load_weather_data_into_dbsql.png -------------------------------------------------------------------------------- /airflow-dags/load_weather_data_into_dbsql.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Optional 3 | from datetime import date 4 | 5 | import requests 6 | from airflow import DAG 7 | from airflow.operators.python import PythonOperator 8 | from airflow.providers.databricks.operators.databricks_sql import ( 9 | DatabricksCopyIntoOperator, 10 | DatabricksSqlOperator, 11 | ) 12 | from airflow.providers.microsoft.azure.transfers.local_to_wasb import LocalFilesystemToWasbOperator 13 | from airflow.utils.dates import days_ago 14 | 15 | # constants to change 16 | WASBS_CONN_ID = 'wasbs-test' 17 | DATABRICKS_SQL_ENDPOINT_NAME = "Airflow Test" 18 | DATABRICKS_CONN_ID = "my-shard-pat" 19 | DESTINATION_TABLE_NAME = "default.airflow_weather" 20 | LANDING_LOCATION_PREFIX = "airflow/landing" 21 | ADLS_CONTAINER_NAME = "test" 22 | ADLS_STORAGE_NAME = "test" 23 | 24 | # 25 | cities = ["london", "berlin", "amsterdam"] 26 | BASE_API_URL = "https://weatherdbi.herokuapp.com/data/weather/" 27 | ADLS_LANDING_LOCATION = f"abfss://{ADLS_CONTAINER_NAME}@{ADLS_STORAGE_NAME}.dfs.core.windows.net/" \ 28 | f"{LANDING_LOCATION_PREFIX}/ " 29 | 30 | 31 | airflow_default_args = { 32 | "owner": "airflow", 33 | "databricks_conn_id": DATABRICKS_CONN_ID, 34 | "sql_endpoint_name": DATABRICKS_SQL_ENDPOINT_NAME, 35 | } 36 | 37 | 38 | def import_city_data(city: str): 39 | def convert_percents_string(s: str) -> Optional[float]: 40 | if 
s and s.endswith('%'): 41 | return float(s[:-1]) / 100 42 | return None 43 | 44 | url = BASE_API_URL + city.lower() 45 | req = requests.get(url) 46 | data_json = req.json() 47 | data = {'date': str(date.today()), 48 | 'region': data_json['region'], 49 | 'temperature': data_json['currentConditions']['temp']['c'], 50 | 'precipitation': convert_percents_string(data_json['currentConditions']['precip']), 51 | 'humidity': convert_percents_string(data_json['currentConditions']['humidity']), 52 | 'wind': data_json['currentConditions']['wind']['km'], 53 | 'condition': data_json['currentConditions']['comment'], 54 | 'next_days': [{'day': d['day'], 'condition': d['comment'], 55 | 'max_temp': d['max_temp']['c'], 'min_temp': d['min_temp']['c']} 56 | for d in data_json.get('next_days', [])], 57 | } 58 | return data 59 | 60 | 61 | def get_weather_data(output_path: str): 62 | data = [json.dumps(import_city_data(city)) + "\n" for city in cities] 63 | with open(output_path, "w") as f: 64 | f.writelines(data) 65 | 66 | 67 | with DAG( 68 | "load_weather_into_dbsql", 69 | start_date=days_ago(0), 70 | schedule_interval="@daily", 71 | default_args=airflow_default_args, 72 | catchup=False, 73 | ) as dag: 74 | schema = "date date, condition STRING, humidity double, precipitation double, " \ 75 | "region STRING, temperature long, wind long, " \ 76 | "next_days ARRAY>" 78 | 79 | create_table = DatabricksSqlOperator( 80 | task_id="create_table", 81 | sql=[f"create table if not exists {DESTINATION_TABLE_NAME}({schema}) using delta"], 82 | ) 83 | 84 | get_weather_data = PythonOperator(task_id="get_weather_data", 85 | python_callable=get_weather_data, 86 | op_kwargs={ 87 | "output_path": "/tmp/{{next_ds}}.json" 88 | }, 89 | ) 90 | 91 | copy_data_to_adls = LocalFilesystemToWasbOperator( 92 | task_id='upload_weather_data', 93 | wasb_conn_id=WASBS_CONN_ID, 94 | file_path="/tmp/{{next_ds}}.json", 95 | container_name='test', 96 | blob_name= LANDING_LOCATION_PREFIX + "/{{next_ds}}.json", 97 | load_options={"overwrite": True,}, 98 | ) 99 | 100 | import_weather_data = DatabricksCopyIntoOperator( 101 | task_id="import_weather_data", 102 | expression_list="date::date, * except(date)", 103 | table_name=DESTINATION_TABLE_NAME, 104 | file_format="JSON", 105 | file_location=ADLS_LANDING_LOCATION, 106 | files=["{{next_ds}}.json"], 107 | #validate=True, # this requires Preview channel 108 | force_copy=True, 109 | ) 110 | 111 | (create_table >> get_weather_data >> copy_data_to_adls >> import_weather_data) 112 | -------------------------------------------------------------------------------- /azure-mi-in-py-sdk/README.md: -------------------------------------------------------------------------------- 1 | # An example of using Azure Managed Identity with Databricks Python SDK 2 | 3 | This directory contains an example of using Azure Managed Identity with Databricks Python SDK. 4 | This is done by implementing a custom credential strategy that wraps the `azure-identity` package. 5 | 6 | The code is in the `azure-mi-in-py-sdk.py` file - it instantiates a Databricks workspace client with 7 | the custom credential strategy and lists clusters in the workspace. To run it you need to install 8 | `databricks-sdk` and `azure-identity` packages (this could be done using the Poetry tool). 9 | 10 | Managed identity (system or user-assigned) needs to be added into Databricks workspace. The URL 11 | of the Databricks workspace is specified via `DATABRICKS_HOST` environment variable or directly in 12 | the code (line 32). 
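For illustration, the token exchange that the custom credential strategy performs under the hood can be sketched as follows (a minimal sketch, assuming the code runs on Azure compute with a managed identity attached; the resource ID is the application ID of the Azure Databricks service, the same constant used in the script below):

```python
# Minimal sketch of the underlying token request (assumes an attached managed identity).
from azure.identity import ManagedIdentityCredential

credential = ManagedIdentityCredential()  # pass client_id=... to pick a user-assigned identity
access_token = credential.get_token("2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default")
# access_token.token is the bearer token sent in the Authorization header of each SDK request
```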
It will be using default managed identity to generate token, but you can 13 | specify a different managed identity by setting `AZURE_CLIENT_ID` environment variable. 14 | -------------------------------------------------------------------------------- /azure-mi-in-py-sdk/azure-mi-in-py-sdk.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import os 3 | 4 | from azure.identity import ManagedIdentityCredential 5 | from databricks.sdk import WorkspaceClient 6 | from databricks.sdk.config import Config 7 | from databricks.sdk.credentials_provider import CredentialsProvider, \ 8 | CredentialsStrategy 9 | 10 | 11 | class AzureIdentityCredentialsStrategy(CredentialsStrategy): 12 | def auth_type(self) -> str: 13 | return 'azure-mi' 14 | 15 | def __init__(self, client_id: str = None): 16 | self.client_id = client_id 17 | 18 | def __call__(self, cfg: 'Config') -> CredentialsProvider: 19 | if self.client_id: 20 | mi_credential = ManagedIdentityCredential(client_id=self.client_id) 21 | else: 22 | mi_credential = ManagedIdentityCredential() 23 | 24 | def inner() -> Dict[str, str]: 25 | token = mi_credential.get_token("2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default") 26 | return {'Authorization': f'Bearer {token.token}'} 27 | 28 | return inner 29 | 30 | 31 | client_id = os.getenv( "AZURE_CLIENT_ID") 32 | host = os.getenv( "DATABRICKS_HOST") or "https://adb-....17.azuredatabricks.net" 33 | 34 | wc = WorkspaceClient(host=host, 35 | credentials_strategy=AzureIdentityCredentialsStrategy(client_id=client_id)) 36 | 37 | for cluster in wc.clusters.list(): 38 | print(cluster.cluster_id, cluster.cluster_name) -------------------------------------------------------------------------------- /azure-mi-in-py-sdk/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "azure-mi-in-py-sdk" 3 | version = "0.0.1" 4 | description = "" 5 | authors = ["Alex Ott "] 6 | license = "Databricks License" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.10" 11 | databricks-sdk = "^0.44.1" 12 | azure-identity = "^1.21.0" 13 | 14 | [tool.poetry.group.test.dependencies] 15 | 16 | [build-system] 17 | requires = ["poetry-core"] 18 | build-backend = "poetry.core.masonry.api" 19 | -------------------------------------------------------------------------------- /clojure-spark-demo/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /classes 3 | /checkouts 4 | profiles.clj 5 | pom.xml 6 | pom.xml.asc 7 | *.jar 8 | *.class 9 | /.lein-* 10 | /.nrepl-port 11 | /.prepl-port 12 | .hgignore 13 | .hg/ 14 | -------------------------------------------------------------------------------- /clojure-spark-demo/LICENSE: -------------------------------------------------------------------------------- 1 | Eclipse Public License - v 2.0 2 | 3 | THE ACCOMPANYING PROGRAM IS PROVIDED UNDER THE TERMS OF THIS ECLIPSE 4 | PUBLIC LICENSE ("AGREEMENT"). ANY USE, REPRODUCTION OR DISTRIBUTION 5 | OF THE PROGRAM CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS AGREEMENT. 6 | 7 | 1. 
DEFINITIONS 8 | 9 | "Contribution" means: 10 | 11 | a) in the case of the initial Contributor, the initial content 12 | Distributed under this Agreement, and 13 | 14 | b) in the case of each subsequent Contributor: 15 | i) changes to the Program, and 16 | ii) additions to the Program; 17 | where such changes and/or additions to the Program originate from 18 | and are Distributed by that particular Contributor. A Contribution 19 | "originates" from a Contributor if it was added to the Program by 20 | such Contributor itself or anyone acting on such Contributor's behalf. 21 | Contributions do not include changes or additions to the Program that 22 | are not Modified Works. 23 | 24 | "Contributor" means any person or entity that Distributes the Program. 25 | 26 | "Licensed Patents" mean patent claims licensable by a Contributor which 27 | are necessarily infringed by the use or sale of its Contribution alone 28 | or when combined with the Program. 29 | 30 | "Program" means the Contributions Distributed in accordance with this 31 | Agreement. 32 | 33 | "Recipient" means anyone who receives the Program under this Agreement 34 | or any Secondary License (as applicable), including Contributors. 35 | 36 | "Derivative Works" shall mean any work, whether in Source Code or other 37 | form, that is based on (or derived from) the Program and for which the 38 | editorial revisions, annotations, elaborations, or other modifications 39 | represent, as a whole, an original work of authorship. 40 | 41 | "Modified Works" shall mean any work in Source Code or other form that 42 | results from an addition to, deletion from, or modification of the 43 | contents of the Program, including, for purposes of clarity any new file 44 | in Source Code form that contains any contents of the Program. Modified 45 | Works shall not include works that contain only declarations, 46 | interfaces, types, classes, structures, or files of the Program solely 47 | in each case in order to link to, bind by name, or subclass the Program 48 | or Modified Works thereof. 49 | 50 | "Distribute" means the acts of a) distributing or b) making available 51 | in any manner that enables the transfer of a copy. 52 | 53 | "Source Code" means the form of a Program preferred for making 54 | modifications, including but not limited to software source code, 55 | documentation source, and configuration files. 56 | 57 | "Secondary License" means either the GNU General Public License, 58 | Version 2.0, or any later versions of that license, including any 59 | exceptions or additional permissions as identified by the initial 60 | Contributor. 61 | 62 | 2. GRANT OF RIGHTS 63 | 64 | a) Subject to the terms of this Agreement, each Contributor hereby 65 | grants Recipient a non-exclusive, worldwide, royalty-free copyright 66 | license to reproduce, prepare Derivative Works of, publicly display, 67 | publicly perform, Distribute and sublicense the Contribution of such 68 | Contributor, if any, and such Derivative Works. 69 | 70 | b) Subject to the terms of this Agreement, each Contributor hereby 71 | grants Recipient a non-exclusive, worldwide, royalty-free patent 72 | license under Licensed Patents to make, use, sell, offer to sell, 73 | import and otherwise transfer the Contribution of such Contributor, 74 | if any, in Source Code or other form. 
This patent license shall 75 | apply to the combination of the Contribution and the Program if, at 76 | the time the Contribution is added by the Contributor, such addition 77 | of the Contribution causes such combination to be covered by the 78 | Licensed Patents. The patent license shall not apply to any other 79 | combinations which include the Contribution. No hardware per se is 80 | licensed hereunder. 81 | 82 | c) Recipient understands that although each Contributor grants the 83 | licenses to its Contributions set forth herein, no assurances are 84 | provided by any Contributor that the Program does not infringe the 85 | patent or other intellectual property rights of any other entity. 86 | Each Contributor disclaims any liability to Recipient for claims 87 | brought by any other entity based on infringement of intellectual 88 | property rights or otherwise. As a condition to exercising the 89 | rights and licenses granted hereunder, each Recipient hereby 90 | assumes sole responsibility to secure any other intellectual 91 | property rights needed, if any. For example, if a third party 92 | patent license is required to allow Recipient to Distribute the 93 | Program, it is Recipient's responsibility to acquire that license 94 | before distributing the Program. 95 | 96 | d) Each Contributor represents that to its knowledge it has 97 | sufficient copyright rights in its Contribution, if any, to grant 98 | the copyright license set forth in this Agreement. 99 | 100 | e) Notwithstanding the terms of any Secondary License, no 101 | Contributor makes additional grants to any Recipient (other than 102 | those set forth in this Agreement) as a result of such Recipient's 103 | receipt of the Program under the terms of a Secondary License 104 | (if permitted under the terms of Section 3). 105 | 106 | 3. REQUIREMENTS 107 | 108 | 3.1 If a Contributor Distributes the Program in any form, then: 109 | 110 | a) the Program must also be made available as Source Code, in 111 | accordance with section 3.2, and the Contributor must accompany 112 | the Program with a statement that the Source Code for the Program 113 | is available under this Agreement, and informs Recipients how to 114 | obtain it in a reasonable manner on or through a medium customarily 115 | used for software exchange; and 116 | 117 | b) the Contributor may Distribute the Program under a license 118 | different than this Agreement, provided that such license: 119 | i) effectively disclaims on behalf of all other Contributors all 120 | warranties and conditions, express and implied, including 121 | warranties or conditions of title and non-infringement, and 122 | implied warranties or conditions of merchantability and fitness 123 | for a particular purpose; 124 | 125 | ii) effectively excludes on behalf of all other Contributors all 126 | liability for damages, including direct, indirect, special, 127 | incidental and consequential damages, such as lost profits; 128 | 129 | iii) does not attempt to limit or alter the recipients' rights 130 | in the Source Code under section 3.2; and 131 | 132 | iv) requires any subsequent distribution of the Program by any 133 | party to be under a license that satisfies the requirements 134 | of this section 3. 
135 | 136 | 3.2 When the Program is Distributed as Source Code: 137 | 138 | a) it must be made available under this Agreement, or if the 139 | Program (i) is combined with other material in a separate file or 140 | files made available under a Secondary License, and (ii) the initial 141 | Contributor attached to the Source Code the notice described in 142 | Exhibit A of this Agreement, then the Program may be made available 143 | under the terms of such Secondary Licenses, and 144 | 145 | b) a copy of this Agreement must be included with each copy of 146 | the Program. 147 | 148 | 3.3 Contributors may not remove or alter any copyright, patent, 149 | trademark, attribution notices, disclaimers of warranty, or limitations 150 | of liability ("notices") contained within the Program from any copy of 151 | the Program which they Distribute, provided that Contributors may add 152 | their own appropriate notices. 153 | 154 | 4. COMMERCIAL DISTRIBUTION 155 | 156 | Commercial distributors of software may accept certain responsibilities 157 | with respect to end users, business partners and the like. While this 158 | license is intended to facilitate the commercial use of the Program, 159 | the Contributor who includes the Program in a commercial product 160 | offering should do so in a manner which does not create potential 161 | liability for other Contributors. Therefore, if a Contributor includes 162 | the Program in a commercial product offering, such Contributor 163 | ("Commercial Contributor") hereby agrees to defend and indemnify every 164 | other Contributor ("Indemnified Contributor") against any losses, 165 | damages and costs (collectively "Losses") arising from claims, lawsuits 166 | and other legal actions brought by a third party against the Indemnified 167 | Contributor to the extent caused by the acts or omissions of such 168 | Commercial Contributor in connection with its distribution of the Program 169 | in a commercial product offering. The obligations in this section do not 170 | apply to any claims or Losses relating to any actual or alleged 171 | intellectual property infringement. In order to qualify, an Indemnified 172 | Contributor must: a) promptly notify the Commercial Contributor in 173 | writing of such claim, and b) allow the Commercial Contributor to control, 174 | and cooperate with the Commercial Contributor in, the defense and any 175 | related settlement negotiations. The Indemnified Contributor may 176 | participate in any such claim at its own expense. 177 | 178 | For example, a Contributor might include the Program in a commercial 179 | product offering, Product X. That Contributor is then a Commercial 180 | Contributor. If that Commercial Contributor then makes performance 181 | claims, or offers warranties related to Product X, those performance 182 | claims and warranties are such Commercial Contributor's responsibility 183 | alone. Under this section, the Commercial Contributor would have to 184 | defend claims against the other Contributors related to those performance 185 | claims and warranties, and if a court requires any other Contributor to 186 | pay any damages as a result, the Commercial Contributor must pay 187 | those damages. 188 | 189 | 5. 
NO WARRANTY 190 | 191 | EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, AND TO THE EXTENT 192 | PERMITTED BY APPLICABLE LAW, THE PROGRAM IS PROVIDED ON AN "AS IS" 193 | BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR 194 | IMPLIED INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF 195 | TITLE, NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR 196 | PURPOSE. Each Recipient is solely responsible for determining the 197 | appropriateness of using and distributing the Program and assumes all 198 | risks associated with its exercise of rights under this Agreement, 199 | including but not limited to the risks and costs of program errors, 200 | compliance with applicable laws, damage to or loss of data, programs 201 | or equipment, and unavailability or interruption of operations. 202 | 203 | 6. DISCLAIMER OF LIABILITY 204 | 205 | EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, AND TO THE EXTENT 206 | PERMITTED BY APPLICABLE LAW, NEITHER RECIPIENT NOR ANY CONTRIBUTORS 207 | SHALL HAVE ANY LIABILITY FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 208 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING WITHOUT LIMITATION LOST 209 | PROFITS), HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 210 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 211 | ARISING IN ANY WAY OUT OF THE USE OR DISTRIBUTION OF THE PROGRAM OR THE 212 | EXERCISE OF ANY RIGHTS GRANTED HEREUNDER, EVEN IF ADVISED OF THE 213 | POSSIBILITY OF SUCH DAMAGES. 214 | 215 | 7. GENERAL 216 | 217 | If any provision of this Agreement is invalid or unenforceable under 218 | applicable law, it shall not affect the validity or enforceability of 219 | the remainder of the terms of this Agreement, and without further 220 | action by the parties hereto, such provision shall be reformed to the 221 | minimum extent necessary to make such provision valid and enforceable. 222 | 223 | If Recipient institutes patent litigation against any entity 224 | (including a cross-claim or counterclaim in a lawsuit) alleging that the 225 | Program itself (excluding combinations of the Program with other software 226 | or hardware) infringes such Recipient's patent(s), then such Recipient's 227 | rights granted under Section 2(b) shall terminate as of the date such 228 | litigation is filed. 229 | 230 | All Recipient's rights under this Agreement shall terminate if it 231 | fails to comply with any of the material terms or conditions of this 232 | Agreement and does not cure such failure in a reasonable period of 233 | time after becoming aware of such noncompliance. If all Recipient's 234 | rights under this Agreement terminate, Recipient agrees to cease use 235 | and distribution of the Program as soon as reasonably practicable. 236 | However, Recipient's obligations under this Agreement and any licenses 237 | granted by Recipient relating to the Program shall continue and survive. 238 | 239 | Everyone is permitted to copy and distribute copies of this Agreement, 240 | but in order to avoid inconsistency the Agreement is copyrighted and 241 | may only be modified in the following manner. The Agreement Steward 242 | reserves the right to publish new versions (including revisions) of 243 | this Agreement from time to time. No one other than the Agreement 244 | Steward has the right to modify this Agreement. The Eclipse Foundation 245 | is the initial Agreement Steward. 
The Eclipse Foundation may assign the 246 | responsibility to serve as the Agreement Steward to a suitable separate 247 | entity. Each new version of the Agreement will be given a distinguishing 248 | version number. The Program (including Contributions) may always be 249 | Distributed subject to the version of the Agreement under which it was 250 | received. In addition, after a new version of the Agreement is published, 251 | Contributor may elect to Distribute the Program (including its 252 | Contributions) under the new version. 253 | 254 | Except as expressly stated in Sections 2(a) and 2(b) above, Recipient 255 | receives no rights or licenses to the intellectual property of any 256 | Contributor under this Agreement, whether expressly, by implication, 257 | estoppel or otherwise. All rights in the Program not expressly granted 258 | under this Agreement are reserved. Nothing in this Agreement is intended 259 | to be enforceable by any entity that is not a Contributor or Recipient. 260 | No third-party beneficiary rights are created under this Agreement. 261 | 262 | Exhibit A - Form of Secondary Licenses Notice 263 | 264 | "This Source Code may also be made available under the following 265 | Secondary Licenses when the conditions for such availability set forth 266 | in the Eclipse Public License, v. 2.0 are satisfied: GNU General Public 267 | License as published by the Free Software Foundation, either version 2 268 | of the License, or (at your option) any later version, with the GNU 269 | Classpath Exception which is available at 270 | https://www.gnu.org/software/classpath/license.html." 271 | 272 | Simply including a copy of this Agreement, including this Exhibit A 273 | is not sufficient to license the Source Code under Secondary Licenses. 274 | 275 | If it is not possible or desirable to put the notice in a particular 276 | file, then You may include the notice in a location (such as a LICENSE 277 | file in a relevant directory) where a recipient would be likely to 278 | look for such a notice. 279 | 280 | You may add additional accurate notices of copyright ownership. 281 | -------------------------------------------------------------------------------- /clojure-spark-demo/README.md: -------------------------------------------------------------------------------- 1 | # clojure-spark-demo 2 | 3 | A small demo of using Clojure on Databricks using [Geni](https://github.com/zero-one-group/geni) library. 4 | 5 | ## Usage 6 | 7 | Build the uberjar, copy it to DBFS and then create a Jar task that uses that jar. Main class name is `clojure_spark_demo.core`. 8 | 9 | ## License 10 | 11 | Copyright © 2022 Alex Ott 12 | 13 | This program and the accompanying materials are made available under the 14 | terms of the Eclipse Public License 2.0 which is available at 15 | http://www.eclipse.org/legal/epl-2.0. 16 | 17 | This Source Code may also be made available under the following Secondary 18 | Licenses when the conditions for such availability set forth in the Eclipse 19 | Public License, v. 2.0 are satisfied: GNU General Public License as published by 20 | the Free Software Foundation, either version 2 of the License, or (at your 21 | option) any later version, with the GNU Classpath Exception which is available 22 | at https://www.gnu.org/software/classpath/license.html. 
23 | -------------------------------------------------------------------------------- /clojure-spark-demo/project.clj: -------------------------------------------------------------------------------- 1 | (defproject clojure-spark-demo "0.1.0-SNAPSHOT" 2 | :description "Example of Clojure project for Databricks" 3 | :license {:name "EPL-2.0 OR GPL-2.0-or-later WITH Classpath-exception-2.0" 4 | :url "https://www.eclipse.org/legal/epl-2.0/"} 5 | :dependencies [[org.clojure/clojure "1.10.3"] 6 | [zero.one/geni "0.0.40" 7 | :exclusions [commons-codec reply nrepl org.nrepl/incomplete]];; 8 | [org.apache.spark/spark-sql_2.12 "3.2.1" :scope "provided"] 9 | ;;[org.apache.spark/spark-streaming_2.12 "3.2.1" :scope "provided"] 10 | [org.apache.spark/spark-mllib_2.12 "3.2.1" :scope "provided"] 11 | ;;[com.github.fommil.netlib/all "1.1.2" :extension "pom"] 12 | [org.apache.arrow/arrow-memory-netty "2.0.0" :scope "provided"] 13 | [org.apache.arrow/arrow-memory-core "2.0.0" :scope "provided"] 14 | [org.apache.arrow/arrow-vector "2.0.0" :scope "provided" 15 | :exclusions [commons-codec com.fasterxml.jackson.core/jackson-databind]] 16 | [com.fasterxml.jackson.core/jackson-databind "2.12.3" :scope "provided"] 17 | [com.fasterxml.jackson.core/jackson-core "2.12.3" :scope "provided"] 18 | [zero.one/fxl "0.0.6" :scope "provided"] 19 | ] 20 | :aot :all 21 | :main clojure-spark-demo.core 22 | :repl-options {:init-ns clojure-spark-demo.core}) 23 | -------------------------------------------------------------------------------- /clojure-spark-demo/src/clojure_spark_demo/core.clj: -------------------------------------------------------------------------------- 1 | (ns clojure-spark-demo.core 2 | (:require [zero-one.geni.core :as g]) 3 | (:gen-class) 4 | ) 5 | 6 | 7 | (defn -main [& args] 8 | (let* [source (g/read-csv! "dbfs:/databricks-datasets/flights/departuredelays.csv" 9 | {:header "true" :infer-schema "true"})] 10 | (println "Source schema") 11 | (g/print-schema source) 12 | (println "rows count: " (g/count source)) 13 | (println "going to write data to a table") 14 | (-> source 15 | (g/select :date :delay :distance :origin :destination) 16 | (g/write-table! "default.clj_test" {:mode "overwrite"}) 17 | ) 18 | ) 19 | ) 20 | -------------------------------------------------------------------------------- /database-diagram-builder/README.md: -------------------------------------------------------------------------------- 1 | ## Generates UML diagram for specified Spark databases 2 | 3 | Generates [PlantUML](https://plantuml.com/) diagram for all, or selected databases registered in Databricks/Spark. Generated UML diagram then could be converted into PDF/SVG/PNG or other formats using `plantuml` command-line tool. 4 | 5 | There are two variants of the code: 6 | 1. [database-diagram-builder-notebook.py](database-diagram-builder-notebook.py) - Databricks notebook that accepts parameters via widgets, and besides generation of PlantUML source code, can also generate PDF/SVG/PNG representation. 7 | 1. [database-diagram-builder-standalone.py](database-diagram-builder-standalone.py) - for use with Databricks connect or OSS Spark. Only generates PlantUML source code. 
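As a usage illustration for the standalone variant, the sketch below generates the PlantUML source for two databases and then renders it to PNG (the database names are placeholders; it assumes a configured Databricks Connect or Spark environment and a local `plantuml.jar`):

```python
# Hypothetical end-to-end run: generate db_schema.puml, then render it with PlantUML.
import subprocess

# Writes db_schema.puml into the current directory (database names are examples)
subprocess.run(
    ["python", "database-diagram-builder-standalone.py", "sales_db", "marketing_db"],
    check=True,
)
# Convert the generated source to PNG
subprocess.run(["java", "-jar", "plantuml.jar", "-tpng", "db_schema.puml"], check=True)
```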
8 | -------------------------------------------------------------------------------- /database-diagram-builder/database-diagram-builder-notebook.py: -------------------------------------------------------------------------------- 1 | # Databricks notebook source 2 | # MAGIC %md This notebook contains code that allows to generate Plant UML schema from the Databricks database(s). The generated file could be transformed into desired format by selecting necessary formats in "Generate image" dropdown, or by executing `%sh java -jar plantuml.jar -t `. You can upgrade version of PlantUML by modifying `plantum_download_link` variable below. 3 | 4 | # COMMAND ---------- 5 | 6 | plantum_download_link = "https://github.com/plantuml/plantuml/releases/download/v1.2022.2/plantuml-pdf-1.2022.2.jar" 7 | 8 | # COMMAND ---------- 9 | 10 | from pyspark.sql import SparkSession 11 | from pyspark.sql.utils import AnalysisException 12 | from pyspark.sql.types import * 13 | import sys 14 | import requests 15 | import os 16 | import tempfile 17 | 18 | # COMMAND ---------- 19 | 20 | # dbutils.widgets.removeAll() 21 | dbutils.widgets.text("name_offset", "3", "Offset for names") 22 | dbutils.widgets.text("pad_offset", "2", "Offset for padding") 23 | dbutils.widgets.dropdown("include_temp", "No", ["Yes", "No"], "Include temp tables/views") 24 | dbutils.widgets.multiselect("generate_image", "", ["", "PNG", "SVG", "PDF"], "Generate image") 25 | dbutils.widgets.text("output_path", "/tmp/db_schema.puml", "Output path on DBFS") 26 | dbutils.widgets.text("databases", "", "Databases to process (comma-separated, empty - all)") 27 | #dbutils.widgets.text("", "", "") 28 | 29 | # COMMAND ---------- 30 | 31 | name_offset = int(dbutils.widgets.get("name_offset")) 32 | pad_offset = int(dbutils.widgets.get("pad_offset")) 33 | include_temp = dbutils.widgets.get("include_temp").lower() == "yes" 34 | enable_hive_support = False 35 | generate_image = dbutils.widgets.get("generate_image").lower() 36 | output_path = dbutils.widgets.get("output_path") 37 | t = dbutils.widgets.get("databases") 38 | if len(t.strip()) == 0: 39 | databases = [db[0] for db in spark.sql("show databases").collect()] 40 | else: 41 | databases = [d.strip() for d in t.split(",")] 42 | # dbutils.widgets.get("") 43 | 44 | # COMMAND ---------- 45 | 46 | def is_struct_type(typ: object) -> bool: 47 | return isinstance(typ, dict) and typ['type'] == 'struct' 48 | 49 | 50 | def maybe_get_field(obj: object, name: str, default: object) -> object: 51 | if isinstance(obj, dict): 52 | return obj.get(name, default) 53 | return default 54 | 55 | # COMMAND ---------- 56 | 57 | def format_structfield(type_val: object, padding: int, isNullable: bool = False) -> str: 58 | if isinstance(type_val, str): 59 | type_string = type_val 60 | elif isinstance(type_val, dict): 61 | sub_type = type_val['type'] 62 | if sub_type == 'array': 63 | type_string = "array< " 64 | element_type = type_val['elementType'] 65 | is_struct = is_struct_type(element_type) 66 | if is_struct: 67 | type_string += "\n" 68 | padding += pad_offset 69 | type_string += (" " * padding) 70 | 71 | type_string += format_structfield(element_type, padding, type_val.get('containsNull', False)) 72 | if is_struct: 73 | type_string += "\n" 74 | padding -= pad_offset 75 | type_string += (" " * (padding - pad_offset)) 76 | 77 | type_string += " >" 78 | elif sub_type == 'map': 79 | # TODO: fix it - need to find the example of output 80 | type_string = "map< " 81 | element_type = type_val['keyType'] 82 | is_struct = 
is_struct_type(element_type) 83 | if is_struct: 84 | type_string += "\n" 85 | type_string += (" " * padding) 86 | padding += pad_offset 87 | 88 | type_string += format_structfield(element_type, padding) 89 | if is_struct: 90 | padding -= pad_offset 91 | type_string += ", " 92 | 93 | element_type = type_val['valueType'] 94 | is_struct = is_struct_type(element_type) 95 | if is_struct: 96 | type_string += "\n" 97 | type_string += (" " * padding) 98 | padding += pad_offset 99 | type_string += format_structfield(element_type, padding, type_val.get('valueContainsNull', False)) 100 | if is_struct: 101 | type_string += "\n" 102 | padding -= pad_offset 103 | type_string += (" " * (padding - pad_offset)) 104 | 105 | type_string += " >" 106 | elif sub_type == 'struct': 107 | pad_str = (" " * (padding + pad_offset)) 108 | type_string = "struct<\n" 109 | for field in type_val['fields']: 110 | fname = field['name'] 111 | type_string += pad_str + fname + " : " 112 | type_string += format_structfield(field['type'], padding + len(fname) + name_offset + pad_offset, 113 | field.get('nullable', False)) 114 | type_string += "\n" 115 | 116 | type_string += (" " * padding) + ">" 117 | else: 118 | raise Exception(f'Unknown subtype: {sub_type}') 119 | else: 120 | raise Exception(f'Unknown type: {type_val}') 121 | 122 | if isNullable: 123 | type_string += ' ?' 124 | return type_string 125 | 126 | # COMMAND ---------- 127 | 128 | def format_type_name(col_name: str, typ: StructField, isNullable: bool = False, 129 | isPartition: bool = False, isBucket: bool = False) -> str: 130 | current_pad = len(col_name) + name_offset 131 | jsn = typ.jsonValue() 132 | type_string = format_structfield(jsn['type'], current_pad, isNullable) 133 | if isPartition: 134 | type_string += " (pk)" 135 | if isBucket: 136 | type_string += " (bk)" 137 | return type_string.replace('\n', '\\n') 138 | 139 | # COMMAND ---------- 140 | 141 | def generate_plantuml_schema(spark: SparkSession, databases: list, file_name: str): 142 | with open(file_name, "w") as f: 143 | f.write("\n".join(["@startuml", "skinparam packageStyle rectangle", 144 | "hide circle", "hide empty methods", 145 | "skinparam defaultFontName Courier", "", ""])) 146 | 147 | for database_name in databases[:3]: 148 | print(f"processing database {database_name}") 149 | f.write(f'package "{database_name}" {{\n') 150 | tables = spark.sql(f"show tables in `{database_name}`") 151 | # TODO: allow to pass additional mapping between table and partition keys in it that aren't defined explicitly 152 | partition_keys = {} 153 | columns_mapping = {} 154 | for tbl in tables.collect(): 155 | table_name = tbl["tableName"] 156 | db = tbl["database"] 157 | # TODO: we can try to parallelize this by running in the thread pool 158 | if include_temp or not tbl["isTemporary"]: # include only not temporary tables 159 | lines = [] 160 | try: 161 | tmp_txt = "" 162 | if tbl["isTemporary"]: 163 | tmp_txt = "(temp)" 164 | lines.append(f'class {table_name} {tmp_txt} {{') 165 | cols = spark.catalog.listColumns(table_name, dbName=db) 166 | # TODO: find the column with the longest name, and use it as offset for all? 
167 | # Pad actual column name to that length 168 | column_names = [] 169 | columns = [] 170 | for cl in cols: 171 | col_name = cl.name 172 | column_names.append(col_name) 173 | schema = spark.createDataFrame([], cl.dataType).schema[0] 174 | is_partition = cl.isPartition 175 | if is_partition: 176 | if col_name in partition_keys: 177 | partition_keys[col_name].add(table_name) 178 | else: 179 | partition_keys[col_name] = {table_name} 180 | type_string = format_type_name(col_name, schema, cl.nullable, 181 | is_partition, cl.isBucket) 182 | columns.append({'name': col_name, 'is_pk': is_partition, 'type': type_string}) 183 | 184 | columns.sort(key=lambda col: (not col['is_pk'], col['name'].lower())) 185 | for col in columns: 186 | lines.append(f'{{field}} {col["name"]} : {col["type"]}') 187 | 188 | lines.append('}\n') 189 | f.write("\n".join(lines)) 190 | columns_mapping[table_name] = column_names 191 | except AnalysisException as ex: 192 | print(f"Error when trying to describe {tbl.database}.{table_name}: {ex}") 193 | 194 | links = set() 195 | for table_name, columns in columns_mapping.items(): 196 | for col in columns: 197 | for pkey_table in partition_keys.get(col, []): 198 | if table_name != pkey_table: 199 | links.add(f'{table_name} *.. {pkey_table}: {col}\n') 200 | 201 | for link in links: 202 | f.write(link) 203 | 204 | f.write("}\n\n") 205 | 206 | f.write("@enduml\n") 207 | 208 | # COMMAND ---------- 209 | 210 | fd, tmp_name = tempfile.mkstemp() 211 | os.close(fd) 212 | generate_plantuml_schema(spark, databases, tmp_name) 213 | dbutils.fs.cp(f"file:{tmp_name}", output_path) 214 | 215 | # COMMAND ---------- 216 | 217 | if generate_image != "": 218 | os.system("apt-get -f install -y graphviz") 219 | plantuml_path = "/tmp/plantuml.jar" 220 | if not os.path.exists(plantuml_path): 221 | r = requests.get(plantum_download_link) 222 | with open(plantuml_path , 'wb') as file: 223 | file.write(r.content) 224 | for image_type in generate_image.split(","): 225 | os.system(f"""java -jar /tmp/plantuml.jar -t{image_type} {tmp_name}""") 226 | dbutils.fs.cp(f"file:{tmp_name}.{image_type}", output_path + "." + image_type) 227 | -------------------------------------------------------------------------------- /database-diagram-builder/database-diagram-builder-standalone.py: -------------------------------------------------------------------------------- 1 | # This script generates PlantUML diagram for tables visible to Spark. 
2 | # The diagram is stored in the db_schema.puml file, so just run 3 | # 'java -jar plantuml.jar db_schema.puml' to get PNG file 4 | 5 | from pyspark.sql import SparkSession 6 | from pyspark.sql.utils import AnalysisException 7 | from pyspark.sql.types import * 8 | import sys 9 | 10 | name_offset = 3 11 | pad_offset = 2 12 | 13 | # TODO: allow to specify via command-line 14 | include_temp = False 15 | enable_hive_support = False 16 | 17 | def is_struct_type(typ: object) -> bool: 18 | return isinstance(typ, dict) and typ['type'] == 'struct' 19 | 20 | 21 | def maybe_get_field(obj: object, name: str, default: object) -> object: 22 | if isinstance(obj, dict): 23 | return obj.get(name, default) 24 | return default 25 | 26 | 27 | def format_structfield(type_val: object, padding: int, isNullable: bool = False) -> str: 28 | if isinstance(type_val, str): 29 | type_string = type_val 30 | elif isinstance(type_val, dict): 31 | sub_type = type_val['type'] 32 | if sub_type == 'array': 33 | type_string = "array< " 34 | element_type = type_val['elementType'] 35 | is_struct = is_struct_type(element_type) 36 | if is_struct: 37 | type_string += "\n" 38 | padding += pad_offset 39 | type_string += (" " * padding) 40 | 41 | type_string += format_structfield(element_type, padding, type_val.get('containsNull', False)) 42 | if is_struct: 43 | type_string += "\n" 44 | padding -= pad_offset 45 | type_string += (" " * (padding - pad_offset)) 46 | 47 | type_string += " >" 48 | elif sub_type == 'map': 49 | # TODO: fix it - need to find the example of output 50 | type_string = "map< " 51 | element_type = type_val['keyType'] 52 | is_struct = is_struct_type(element_type) 53 | if is_struct: 54 | type_string += "\n" 55 | type_string += (" " * padding) 56 | padding += pad_offset 57 | 58 | type_string += format_structfield(element_type, padding) 59 | if is_struct: 60 | padding -= pad_offset 61 | type_string += ", " 62 | 63 | element_type = type_val['valueType'] 64 | is_struct = is_struct_type(element_type) 65 | if is_struct: 66 | type_string += "\n" 67 | type_string += (" " * padding) 68 | padding += pad_offset 69 | type_string += format_structfield(element_type, padding, type_val.get('valueContainsNull', False)) 70 | if is_struct: 71 | type_string += "\n" 72 | padding -= pad_offset 73 | type_string += (" " * (padding - pad_offset)) 74 | 75 | type_string += " >" 76 | elif sub_type == 'struct': 77 | pad_str = (" " * (padding + pad_offset)) 78 | type_string = "struct<\n" 79 | for field in type_val['fields']: 80 | fname = field['name'] 81 | type_string += pad_str + fname + " : " 82 | type_string += format_structfield(field['type'], padding + len(fname) + name_offset + pad_offset, 83 | field.get('nullable', False)) 84 | type_string += "\n" 85 | 86 | type_string += (" " * padding) + ">" 87 | else: 88 | raise Exception(f'Unknown subtype: {sub_type}') 89 | else: 90 | raise Exception(f'Unknown type: {type_val}') 91 | 92 | if isNullable: 93 | type_string += ' ?' 
94 | return type_string 95 | 96 | 97 | def format_type_name(col_name: str, typ: StructField, isNullable: bool = False, 98 | isPartition: bool = False, isBucket: bool = False) -> str: 99 | current_pad = len(col_name) + name_offset 100 | jsn = typ.jsonValue() 101 | type_string = format_structfield(jsn['type'], current_pad, isNullable) 102 | if isPartition: 103 | type_string += " (pk)" 104 | if isBucket: 105 | type_string += " (bk)" 106 | return type_string.replace('\n', '\\n') 107 | 108 | 109 | def generate_plantuml_schema(spark: SparkSession, databases: list, file_name: str): 110 | with open(file_name, "w") as f: 111 | f.write("\n".join(["@startuml", "skinparam packageStyle rectangle", 112 | "hide circle", "hide empty methods", 113 | "skinparam defaultFontName Courier", "", ""])) 114 | 115 | for database_name in databases[:3]: 116 | print(f"processing database {database_name}") 117 | f.write(f'package "{database_name}" {{\n') 118 | tables = spark.sql(f"show tables in `{database_name}`") 119 | # TODO: allow to pass additional mapping between table and partition keys in it that aren't defined explicitly 120 | partition_keys = {} 121 | columns_mapping = {} 122 | for tbl in tables.collect(): 123 | table_name = tbl["tableName"] 124 | db = tbl["database"] 125 | # TODO: we can try to parallelize this by running in the thread pool 126 | if include_temp or not tbl["isTemporary"]: # include only not temporary tables 127 | lines = [] 128 | try: 129 | tmp_txt = "" 130 | if tbl["isTemporary"]: 131 | tmp_txt = "(temp)" 132 | lines.append(f'class {table_name} {tmp_txt} {{') 133 | cols = spark.catalog.listColumns(table_name, dbName=db) 134 | # TODO: find the column with the longest name, and use it as offset for all? 135 | # Pad actual column name to that length 136 | column_names = [] 137 | columns = [] 138 | for cl in cols: 139 | col_name = cl.name 140 | column_names.append(col_name) 141 | schema = spark.createDataFrame([], cl.dataType).schema[0] 142 | is_partition = cl.isPartition 143 | if is_partition: 144 | if col_name in partition_keys: 145 | partition_keys[col_name].add(table_name) 146 | else: 147 | partition_keys[col_name] = {table_name} 148 | type_string = format_type_name(col_name, schema, cl.nullable, 149 | is_partition, cl.isBucket) 150 | columns.append({'name': col_name, 'is_pk': is_partition, 'type': type_string}) 151 | 152 | columns.sort(key=lambda col: (not col['is_pk'], col['name'].lower())) 153 | for col in columns: 154 | lines.append(f'{{field}} {col["name"]} : {col["type"]}') 155 | 156 | lines.append('}\n') 157 | f.write("\n".join(lines)) 158 | columns_mapping[table_name] = column_names 159 | except AnalysisException as ex: 160 | print(f"Error when trying to describe {tbl.database}.{table_name}: {ex}") 161 | 162 | links = set() 163 | for table_name, columns in columns_mapping.items(): 164 | for col in columns: 165 | for pkey_table in partition_keys.get(col, []): 166 | if table_name != pkey_table: 167 | links.add(f'{table_name} *.. {pkey_table}: {col}\n') 168 | 169 | for link in links: 170 | f.write(link) 171 | 172 | f.write("}\n\n") 173 | 174 | f.write("@enduml\n") 175 | 176 | 177 | if __name__ == '__main__': 178 | # Variables 179 | # list of databases/namespaces to analyze. 
Could be empty, then all existing databases/namespaces will be processed 180 | # put databases/namespace to handle 181 | databases = [x for x in sys.argv if len(x) > 0 and not x.endswith(".py")] 182 | # change this if you want to include temporary tables as well 183 | 184 | # implementation 185 | builder = SparkSession.builder.appName("Database Schema Generator") 186 | if enable_hive_support: 187 | builder.enableHiveSupport() 188 | spark = builder.getOrCreate() 189 | 190 | # if databases aren't specified, then fetch list from the Spark 191 | if len(databases) == 0: 192 | databases = [db[0] for db in spark.sql("show databases").collect()] 193 | 194 | generate_plantuml_schema(spark, databases, "db_schema.puml") 195 | -------------------------------------------------------------------------------- /dbconnect-maven/README.md: -------------------------------------------------------------------------------- 1 | This is an example of how Maven could be used to build Spark project that later will be executed via [databricks-connect](https://docs.databricks.com/dev-tools/databricks-connect.html). 2 | 3 | ## Setup 4 | 5 | To build the project following setup steps are necessary: 6 | 7 | * Install Databricks Connect & Databricks CLI (better into a separate virtual Python/Conda environment - in this example, with name `test`. For DBR 6.x we must to use Python 3.7!): 8 | 9 | ```sh 10 | conda create --name test python=3.8 11 | conda activate test 12 | pip install -U databricks-connect==7.3 databricks-cli 13 | export DBCONNECT_JARS=$(databricks-connect get-jar-dir) 14 | export SPARK_HOME=$(dirname $DBCONNECT_JARS) 15 | ``` 16 | 17 | * Configure Databricks Connect as described in [documentation](https://docs.databricks.com/dev-tools/databricks-connect.html) - it's required if you want to run the artifact from your machine 18 | * (optional) Configure Databricks CLI as described in [documentation](https://docs.databricks.com/dev-tools/cli/index.html) 19 | 20 | ## Build 21 | 22 | This project is configured to be compiled with OSS Spark 3.0.1 that is compatible with DBR 7.x. If you want to adapt it to another Databricks Runtime (DBR) version, then adjust properties `scala.version`, `spark.version`, and `spark.scala.version` to match Spark & Scala versions used in given DBR. 23 | 24 | After that just execute `mvn clean package` to build the project. 25 | 26 | ## Executing from command-line 27 | 28 | To run the built project from the command line we need to build the project with `mvn package` & execute following command: 29 | 30 | ```sh 31 | $SPARK_HOME/bin/spark-submit --class net.alexott.demos.databricks.SimpleSQL target/dbconnect-maven-demo-0.0.1-jar-with-dependencies.jar 32 | ``` 33 | 34 | In some cases you need to set environment variable `DEBUG_IGNORE_VERSION_MISMATCH` to value `1` because `spark-submit` checks if the databricks-connect checks for its version & fails if it doesn't match to the cluster's version. And it could be a problem as databricks-connect isn't published for every DBR version. You just need to make sure that you're using databricks-connect 6.x with DBR 6.x, and similarly for 7.x versions. 35 | 36 | ## Executing from Intellij Idea 37 | 38 | The Ultimate version of Intellij Idea has [built-in support for submitting Spark jobs via spark-submit](https://www.jetbrains.com/help/idea/big-data-tools-spark-submit.html), so refer to documentation on how to configure it. 
39 | 40 | For Community Edition of Idea, we'll need to force the explicit submission of the Spark job by using the `SparkSubmit` class. To configure this, go to the "Run" > "Edit Configurations..." and change settings as following: 41 | 42 | * Main class: `org.apache.spark.deploy.SparkSubmit` 43 | * VM options: `-cp $Classpath$:$SPARK_HOME` 44 | * Program arguments: `--class net.alexott.demos.databricks.SimpleSQL target/dbconnect-maven-demo-0.0.1-jar-with-dependencies.jar`. The `.jar` file should be built before execution, so it makes sense to hook `mvn package` into the "Before launch" configuration - select "Run Maven Goal" and put `package` there 45 | * Environment variables: `SPARK_HOME=...` (put the value of `SPARK_HOME` defined above), and maybe `DEBUG_IGNORE_VERSION_MISMATCH=1` to allow to run on "incompatible" clusters. 46 | 47 | After that you can execute Spark job directly from Idea. 48 | -------------------------------------------------------------------------------- /dbconnect-maven/pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | net.alexott.demos.databricks 6 | dbconnect-maven-demo 7 | 0.0.1 8 | jar 9 | 10 | 11 | UTF-8 12 | 2.12.10 13 | 3.0.1 14 | 2.12 15 | 1.8 16 | 17 | 18 | 19 | 20 | Spark-Packages 21 | https://dl.bintray.com/spark-packages/maven/ 22 | 23 | 24 | 25 | 26 | 27 | org.apache.spark 28 | spark-sql_${spark.scala.version} 29 | ${spark.version} 30 | provided 31 | 32 | 33 | 34 | 35 | 36 | 37 | maven-compiler-plugin 38 | 3.8.1 39 | 40 | ${java.version} 41 | ${java.version} 42 | true 43 | 44 | 45 | 46 | net.alchim31.maven 47 | scala-maven-plugin 48 | 4.4.0 49 | 50 | 51 | process-sources 52 | 53 | compile 54 | testCompile 55 | 56 | 57 | ${scala.version} 58 | 59 | 62 | 63 | 64 | 65 | 66 | org.apache.maven.plugins 67 | maven-assembly-plugin 68 | 3.2.0 69 | 70 | 71 | jar-with-dependencies 72 | 73 | 74 | 75 | 76 | package 77 | 78 | single 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /dbconnect-maven/src/main/scala/net/alexott/demos/databricks/SimpleSQL.scala: -------------------------------------------------------------------------------- 1 | package net.alexott.demos.databricks 2 | 3 | import org.apache.spark.sql.SparkSession 4 | import org.apache.spark.sql.functions._ 5 | 6 | object SimpleSQL { 7 | def main(args: Array[String]): Unit = { 8 | val spark = SparkSession.builder() 9 | .appName("Simple SQL Test").getOrCreate() 10 | 11 | val df = spark.range(100000).withColumn("randc", rand()) 12 | df.createOrReplaceTempView("my_sql_test") 13 | df.printSchema() 14 | spark.sql("select * from my_sql_test where randc > 0.9 limit 10").show(false) 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /dbconnect-package-versions-check/README.md: -------------------------------------------------------------------------------- 1 | ## Checks compatibility of local Databricks connect environment with Databricks cluster 2 | 3 | [`package_versions_check.py`](package_versions_check.py) - checks the versions of Python packages installed on the driver node & one of the executors. I was using that to debug the "strange error" when using [dbconnect](https://docs.databricks.com/dev-tools/databricks-connect.html) when local environment had different versions of the packages than were used in the Databricks ML Runtime. 
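The core trick is to evaluate a small UDF on an executor to capture the package versions installed there, and to diff them against the local environment. A condensed sketch of what the full script below does (assuming a configured databricks-connect session):

```python
# Condensed sketch of the comparison performed by package_versions_check.py.
import pkg_resources
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import MapType, StringType

spark = SparkSession.builder.appName("Packages Checker").getOrCreate()

# Run a UDF on one row so the packages are listed on an executor, not locally
get_pkgs = udf(lambda _: {p.key: p.version for p in pkg_resources.working_set},
               MapType(StringType(), StringType()))
remote = spark.range(1).select(get_pkgs(col("id"))).first()[0]

# Packages installed in the local (databricks-connect) environment
local = {p.key: p.version for p in pkg_resources.working_set}

# Report packages present in both environments but with different versions
mismatches = {name: (ver, local.get(name)) for name, ver in remote.items()
              if name in local and local[name] != ver}
print(mismatches)
```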
4 | -------------------------------------------------------------------------------- /dbconnect-package-versions-check/package_versions_check.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql.types import StringType, MapType 2 | from pyspark.sql.functions import udf, col 3 | from pyspark.sql import SparkSession 4 | 5 | import pkg_resources 6 | 7 | def get_pkgs(dummy): 8 | pkgs = dict([(i.key, i.version) for i in pkg_resources.working_set]) 9 | return pkgs 10 | 11 | spark = SparkSession.builder.appName("Packages Checker").getOrCreate() 12 | 13 | pudf = udf(get_pkgs, MapType(StringType(), StringType())) 14 | 15 | remote_pkgs = spark.range(1, 2).select(pudf(col("id"))).first()[0] 16 | 17 | local_pkgs = dict([(i.key, i.version) for i in pkg_resources.working_set]) 18 | 19 | ignored_pkgs = {'setuptools', 'pip', 'wheel'} 20 | # set to True to output the libraries that are installed on executors, but not on driver 21 | report_not_installed = False 22 | # TODO: instead of direct printing, collect into array of dicts, and output at the end 23 | for pkg, ver in remote_pkgs.items(): 24 | if pkg in ignored_pkgs: 25 | continue 26 | loc_ver = local_pkgs.get(pkg) 27 | if loc_ver is not None: 28 | if ver != loc_ver: 29 | print(f"Version mismatch for package {pkg}! Remote: {ver}, Local: {loc_ver}") 30 | elif report_not_installed: 31 | print(f"Package {pkg} is not installed locally. Remote version is {ver}") 32 | -------------------------------------------------------------------------------- /dbconnect-sbt/README.md: -------------------------------------------------------------------------------- 1 | This is an example of how SBT can be used to build a project that uses jars from [databricks-connect](https://docs.databricks.com/dev-tools/databricks-connect.html). Sometimes this is required because the Databricks runtime (DBR) has more functionality than open source Spark. But in most cases it's enough to build a project using OSS Spark dependencies and declare them as provided (shown as an example in `build.sbt`, as commented-out code). 2 | 3 | ## Setup 4 | 5 | To build the project, the following setup steps are necessary: 6 | 7 | * Install Databricks Connect & Databricks CLI (better into a separate virtual Python/Conda environment - in this example, with name `test`. For DBR 6.x we must use Python 3.7!): 8 | 9 | ```sh 10 | conda create --name test python=3.8 11 | conda activate test 12 | pip install -U databricks-connect==7.3 databricks-cli 13 | export DBCONNECT_JARS=$(databricks-connect get-jar-dir) 14 | export SPARK_HOME=$(dirname $DBCONNECT_JARS) 15 | ``` 16 | 17 | * Configure Databricks Connect as described in the [documentation](https://docs.databricks.com/dev-tools/databricks-connect.html) - it's required if you want to run the artifact from your machine 18 | * (optional) Configure Databricks CLI as described in the [documentation](https://docs.databricks.com/dev-tools/cli/index.html) 19 | 20 | ## Build 21 | 22 | To compile & package this project you need to do the following: 23 | 24 | * (optional) Set the location of the DB jars. It can be configured in multiple ways, checked in the following order: 25 | 1. SBT will try to use the path from the environment variable `DBCONNECT_JARS` set above. 26 | 1. SBT will try to get the path from the `DbConnectJars` system property - that's useful when you're running the project from the IDE.
It could be set in the `.sbtopts` file, for example, as `-DDbConnectJars=....` (execute in the root directory of the project following shell command: `echo "-DDbConnectJars=$(databricks-connect get-jar-dir)" |tee -a .sbtopts` to add this definition to the `.sbtopts` file) 27 | 1. try to execute `databricks-connect get-jar-dir` if the `databricks-connect` is in the `PATH` 28 | 1. take path from the `jarsPathManual` that is defined in the `build.sbt` - open it in the editor, and set this variable to the path to jars obtained via `databricks-connect get-jar-dir` 29 | * Execute `sbt clean package` to build the project 30 | 31 | ## Executing from command-line 32 | 33 | To run the built project from the command line we need to build the project with `sbt package` & execute following command: 34 | 35 | ```sh 36 | $SPARK_HOME/bin/spark-submit --class net.alexott.demos.databricks.SimpleSQL target/scala-2.12/dbconnect-sbt_2.12-0.1.jar 37 | ``` 38 | 39 | In some cases you need to set environment variable `DEBUG_IGNORE_VERSION_MISMATCH` to value `1` because `spark-submit` checks if the databricks-connect checks for its version & fails if it doesn't match to the cluster's version. And it could be a problem as databricks-connect isn't published for every DBR version. You just need to make sure that you're using databricks-connect 6.x with DBR 6.x, and similarly for 7.x versions. 40 | 41 | ## Executing from Intellij Idea 42 | 43 | The Ultimate version of Intellij Idea has [built-in support for submitting Spark jobs via spark-submit](https://www.jetbrains.com/help/idea/big-data-tools-spark-submit.html), so refer to documentation on how to configure it. 44 | 45 | For Community Edition of Idea, we'll need to force the explicit submission of the Spark job by using the `SparkSubmit` class. To configure this, go to the "Run" > "Edit Configurations..." and change settings as following: 46 | 47 | * Main class: `org.apache.spark.deploy.SparkSubmit` 48 | * VM options: `-cp $Classpath$:$SPARK_HOME` 49 | * Program arguments: `--class net.alexott.demos.databricks.SimpleSQL target/scala-2.12/dbconnect-sbt_2.12-0.1.jar` - the `.jar` file should be built before execution, so it makes sense to hook `sbt package` into the "Before launch" configuration 50 | * Environment variables: `SPARK_HOME=...` (put the value of `SPARK_HOME` defined above), and maybe `DEBUG_IGNORE_VERSION_MISMATCH=1` to allow to run on "incompatible" clusters. 51 | 52 | After that you can execute Spark job directly from Idea. 53 | -------------------------------------------------------------------------------- /dbconnect-sbt/build.sbt: -------------------------------------------------------------------------------- 1 | name := "dbconnect-sbt" 2 | 3 | organization := "net.alexott.demos.databricks" 4 | 5 | version := "0.1" 6 | 7 | scalaVersion := "2.12.12" 8 | scalacOptions ++= Seq("-Xmax-classfile-name", "78") 9 | 10 | // To use jars from the databricks-connect, comment out if it's not required 11 | //// default value 12 | val jarsPathManual = "" 13 | unmanagedBase := { 14 | import java.nio.file.{Files, Paths} 15 | import scala.sys.process._ 16 | 17 | val jarsPathEnv = System.getenv("DBCONNECT_JARS") 18 | if (jarsPathEnv != null && Files.isDirectory(Paths.get(jarsPathEnv))) { 19 | // println("We have path from the environment variable! 
" + jarsPathEnv) 20 | new java.io.File(jarsPathEnv) 21 | } else { 22 | val paramPathEnv = System.getProperty("DbConnectJars") 23 | if (paramPathEnv != null && Files.isDirectory(Paths.get(paramPathEnv))) { 24 | // println("We have path from the system parameter! " + paramPathEnv) 25 | new java.io.File(paramPathEnv) 26 | } else { 27 | val dbConenctPath: String = try { 28 | Seq("databricks-connect", "get-jar-dir").!!.trim 29 | } catch { 30 | case e: Exception => 31 | // println(s"Exception running databricks-connect: ${e.getMessage}") 32 | "" 33 | } 34 | if (!dbConenctPath.isEmpty && Files.isDirectory(Paths.get(dbConenctPath))) { 35 | // println("We have path from the databricks-connect! " + dbConenctPath) 36 | new java.io.File(dbConenctPath) 37 | } else if (Files.isDirectory(Paths.get(jarsPathManual))) { 38 | // println("We have path from the manual path! " + jarsPathManual) 39 | new java.io.File(jarsPathManual) 40 | } else { 41 | throw new RuntimeException("Can't find DB jars required for build! Set DBCONNECT_JARS environment variable, set -DDbConnectJars=path in .sbtopts, activate conda environment, or set the 'jarsPathManual' variable") 42 | } 43 | } 44 | } 45 | } 46 | 47 | // Example how to use OSS Spark dependencies, to use, comment out the 'unmanagedBase' piece and uncomment next lines 48 | // val sparkVersion = "3.0.1" 49 | // libraryDependencies += "org.apache.spark" %% "spark-core" % sparkVersion % Provided 50 | // libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion % Provided 51 | // libraryDependencies += "org.apache.spark" %% "spark-hive" % sparkVersion % Provided 52 | -------------------------------------------------------------------------------- /dbconnect-sbt/project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.4.0 2 | -------------------------------------------------------------------------------- /dbconnect-sbt/src/main/scala/net/alexott/demos/databricks/SimpleSQL.scala: -------------------------------------------------------------------------------- 1 | package net.alexott.demos.databricks 2 | 3 | import org.apache.spark.sql.SparkSession 4 | import org.apache.spark.sql.functions._ 5 | 6 | object SimpleSQL { 7 | def main(args: Array[String]): Unit = { 8 | val spark = SparkSession.builder() 9 | .appName("Simple SQL Test").getOrCreate() 10 | 11 | val df = spark.range(100000).withColumn("randc", rand()) 12 | df.createOrReplaceTempView("my_sql_test") 13 | df.printSchema() 14 | spark.sql("select * from my_sql_test where randc > 0.9 limit 10").show(false) 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /dbsql-with-aad-token/README.md: -------------------------------------------------------------------------------- 1 | # Querying Databricks SQL or cluster via Databricks SQL Python connector using AAD token 2 | 3 | This directory contain demonstration of how to authenticate to Databricks using Azure Active Directory token when using [python-sql-connector](https://docs.databricks.com/dev-tools/python-sql-connector.html) library. There are three implementations: 4 | 1. `dbsql-with-aad-token-spn.py` using service principal and execute query against Databricks SQL Endpoint or Databricks Cluster. 5 | 2. `dbsql-with-aad-token-user.py` using user AAD interactive authentication. 6 | 2. `dbsql-with-aad-token-mi.py` using Azure Managed Identity associated with VM or AKS cluster. 
7 | 8 | ## Installation 9 | 10 | Install all necessary dependencies: 11 | 12 | * `databricks-sql-connector` 13 | 14 | by executing: 15 | 16 | ```sh 17 | pip install -U -r requirements.txt 18 | ``` 19 | 20 | ## Modify the script(s) 21 | 22 | You need to modify the scripts and change the following variables: 23 | 24 | * `host_name` - set to the host name of the Databricks workspace (without `https://`) 25 | * `http_path` - obtain the HTTP Path parameter of the Databricks SQL Endpoint or Databricks Cluster as per [documentation](https://docs.databricks.com/dev-tools/python-sql-connector.html#get-started). 26 | * `query` - the query that should be executed 27 | 28 | ### Set authentication parameters for dbsql-with-aad-token-spn.py 29 | 30 | Authentication parameters of the service principal can be set in the code directly (not the best way), or obtained from the following environment variables: 31 | 32 | * `ARM_TENANT_ID` - tenant ID in Azure Active Directory 33 | * `ARM_CLIENT_ID` - client ID of application in Azure Active Directory 34 | * `ARM_CLIENT_SECRET` - secret for AAD application 35 | 36 | ### Set authentication parameters for dbsql-with-aad-token-user.py 37 | 38 | The authentication parameter can be set in the code directly (not the best way), or obtained from the following environment variable: 39 | 40 | * `ARM_CLIENT_ID` - client ID of application in Azure Active Directory that has user impersonation permission for Azure Databricks 41 | 42 | ### Set authentication parameters for dbsql-with-aad-token-mi.py 43 | 44 | By default, the script will use the default managed identity associated with the Azure VM or AKS. If you want to authenticate using a specific user-assigned managed identity, then set the following environment variable: 45 | 46 | * `ARM_CLIENT_ID` - client ID of user-assigned managed identity associated with VM or AKS. 47 | 48 | ## Execute script(s) 49 | 50 | Just run: 51 | 52 | ```sh 53 | python dbsql-with-aad-token-spn.py 54 | ``` 55 | 56 | or 57 | 58 | ```sh 59 | python dbsql-with-aad-token-user.py 60 | ``` 61 | 62 | or (only from VM with MI) 63 | 64 | ```sh 65 | python dbsql-with-aad-token-mi.py 66 | ``` 67 | 68 | and it will print the result of the query execution.
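For reference, the service principal and managed identity scripts build a credentials provider and pass it to `databricks.sql.connect`, while the user script passes `auth_type="databricks-oauth"` instead. A minimal sketch of the service principal variant is shown below; the host name and HTTP path are placeholder values, and the other two credential setups can be found in `dbsql-with-aad-token-user.py` and `dbsql-with-aad-token-mi.py` below:

```python
import os

from databricks import sql
from databricks.sdk.core import Config, azure_service_principal

# Placeholder values - see "Modify the script(s)" above
host_name = "adb-xxx.17.azuredatabricks.net"  # workspace host name, without https://
http_path = "/sql/1.0/warehouses/xxx"         # HTTP Path of a SQL warehouse or cluster
query = "select 42"

# Credentials provider built from the ARM_* environment variables (service principal variant)
creds = azure_service_principal(Config(azure_tenant_id=os.environ["ARM_TENANT_ID"],
                                       azure_client_id=os.environ["ARM_CLIENT_ID"],
                                       azure_client_secret=os.environ["ARM_CLIENT_SECRET"],
                                       host=f"https://{host_name}"))

with sql.connect(server_hostname=host_name, http_path=http_path,
                 credentials_provider=lambda: creds) as connection:
    cursor = connection.cursor()
    cursor.execute(query)
    for row in cursor.fetchall():
        print(row)
    cursor.close()
```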
69 | -------------------------------------------------------------------------------- /dbsql-with-aad-token/dbsql-with-aad-token-mi.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | from databricks import sql 4 | import requests 5 | import time 6 | import os 7 | 8 | from databricks.sql.auth.authenticators import AuthProvider 9 | 10 | # Variables to fill 11 | # Host name without https:// 12 | host_name = "adb-.....17.azuredatabricks.net" 13 | # Path obtained as per instructions https://docs.databricks.com/dev-tools/python-sql-connector.html#get-started 14 | http_path = "/sql/1.0/warehouses/951d1b041fc6c792" 15 | # Your query to execute 16 | query = "select 42, current_timestamp(), current_catalog(), current_database(), current_user()" 17 | 18 | TOKEN_REFRESH_LEAD_TIME = 120 19 | DEFAULT_DATABRICKS_SCOPE = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d" 20 | AZURE_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/" 21 | AZURE_METADATA_SERVICE_TOKEN_URL = "http://169.254.169.254/metadata/identity/oauth2/token" 22 | 23 | 24 | # Heavily based on the Apache Airflow implementation that is originally written by me 25 | # https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token 26 | class AzureMIAuthProvider(AuthProvider): 27 | def __init__(self, uami_application_id: Optional[str] = None, 28 | databricks_resource_id: Optional[str] = None): 29 | """ 30 | 31 | :param uami_application_id: optional client ID of User Assigned Managed Identity 32 | :param databricks_resource_id: optional Azure Databricks resource ID 33 | """ 34 | self.__uami_application_id = uami_application_id 35 | self.__databricks_resource_id = databricks_resource_id 36 | self.__oauth_tokens = {} 37 | 38 | @staticmethod 39 | def _is_oauth_token_valid(token: dict, time_key="expires_on") -> bool: 40 | if "access_token" not in token or token.get("token_type", 41 | "") != "Bearer" or time_key not in token: 42 | raise Exception(f"Can't get necessary data from OAuth token: {token}") 43 | return int(token[time_key]) > (int(time.time()) + TOKEN_REFRESH_LEAD_TIME) 44 | 45 | def _get_aad_token(self, resource: str) -> str: 46 | aad_token = self.__oauth_tokens.get(resource) 47 | if aad_token and self._is_oauth_token_valid(aad_token): 48 | return aad_token["access_token"] 49 | params = { 50 | "api-version": "2018-02-01", 51 | "resource": resource, 52 | } 53 | if self.__uami_application_id: 54 | params["client_id"] = self.__uami_application_id 55 | resp = requests.get(AZURE_METADATA_SERVICE_TOKEN_URL, params=params, 56 | headers={"Metadata": "true"}, timeout=2, 57 | ) 58 | resp.raise_for_status() 59 | jsn = resp.json() 60 | 61 | self._is_oauth_token_valid(jsn) 62 | self.__oauth_tokens[resource] = jsn 63 | return jsn["access_token"] 64 | 65 | def __call__(self, *args, **kwargs): 66 | request_headers = {} 67 | if self.__databricks_resource_id: 68 | mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT) 69 | request_headers["X-Databricks-Azure-Workspace-Resource-Id"] = self.__databricks_resource_id 70 | request_headers["X-Databricks-Azure-SP-Management-Token"] = mgmt_token 71 | 72 | request_headers["Authorization"] = f"Bearer {self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)}" 73 | # print(request_headers) 74 | return request_headers 75 | 76 | 77 | creds = AzureMIAuthProvider(uami_application_id=os.getenv("ARM_CLIENT_ID"), 78 | databricks_resource_id=os.getenv("DATABRICKS_AZURE_RESOURCE_ID")) 79 | 80 | with 
sql.connect(server_hostname=host_name, http_path=http_path, 81 | credentials_provider=lambda: creds) as connection: 82 | cursor = connection.cursor() 83 | cursor.execute(query) 84 | result = cursor.fetchall() 85 | 86 | for row in result: 87 | print(row) 88 | 89 | cursor.close() 90 | connection.close() 91 | -------------------------------------------------------------------------------- /dbsql-with-aad-token/dbsql-with-aad-token-spn.py: -------------------------------------------------------------------------------- 1 | from databricks import sql 2 | from databricks.sdk.core import azure_service_principal, Config 3 | import os 4 | 5 | def get_env_or_raise(name: str) -> str: 6 | value = os.environ[name] 7 | if not value: 8 | raise Exception(f"Environment variable {name} isn't set or empty") 9 | 10 | return value 11 | 12 | # Variables to fill 13 | # Host name without https:// 14 | host_name="adb-xxx.17.azuredatabricks.net" 15 | # Path obtained as per instructions https://docs.databricks.com/dev-tools/python-sql-connector.html#get-started 16 | http_path="/sql/1.0/warehouses/xxx" 17 | # Your query to execute 18 | query = "select 42, current_timestamp(), current_catalog(), current_database(), current_user()" 19 | 20 | # Explicit initialization of the authenticator 21 | tenant_id = get_env_or_raise("ARM_TENANT_ID") 22 | client_id = get_env_or_raise("ARM_CLIENT_ID") 23 | client_secret = get_env_or_raise("ARM_CLIENT_SECRET") 24 | 25 | creds = azure_service_principal(Config(azure_client_secret=client_secret, azure_client_id=client_id, 26 | azure_tenant_id=tenant_id, host=f"https://{host_name}")) 27 | 28 | # Implicit initialization - you need to specify corresponding environment variables 29 | # creds = azure_service_principal(Config()) 30 | 31 | with sql.connect(server_hostname=host_name, http_path=http_path, 32 | credentials_provider=lambda: creds) as connection: 33 | cursor = connection.cursor() 34 | cursor.execute(query) 35 | result = cursor.fetchall() 36 | 37 | for row in result: 38 | print(row) 39 | 40 | cursor.close() 41 | connection.close() 42 | -------------------------------------------------------------------------------- /dbsql-with-aad-token/dbsql-with-aad-token-user.py: -------------------------------------------------------------------------------- 1 | from databricks import sql 2 | import os 3 | 4 | 5 | def get_env_or_raise(name: str) -> str: 6 | value = os.environ[name] 7 | if not value: 8 | raise Exception(f"Environment variable {name} isn't set or empty") 9 | 10 | return value 11 | 12 | 13 | # Variables to fill 14 | # Host name without https:// 15 | host_name = "adb-xxxx.17.azuredatabricks.net" 16 | # Path obtained as per instructions https://docs.databricks.com/dev-tools/python-sql-connector.html#get-started 17 | http_path = "/sql/1.0/warehouses/xxx" 18 | # Your query to execute 19 | query = "select 42, current_timestamp(), current_catalog(), current_database(), current_user()" 20 | 21 | # Explicit initialization of the authenticator 22 | client_id = get_env_or_raise("ARM_CLIENT_ID") 23 | 24 | with sql.connect(server_hostname=host_name, http_path=http_path, 25 | auth_type="databricks-oauth", 26 | oauth_client_id=client_id 27 | ) as connection: 28 | cursor = connection.cursor() 29 | cursor.execute(query) 30 | result = cursor.fetchall() 31 | 32 | for row in result: 33 | print(row) 34 | 35 | cursor.close() 36 | connection.close() 37 | -------------------------------------------------------------------------------- /dbsql-with-aad-token/requirements.txt: 
-------------------------------------------------------------------------------- 1 | databricks-sql-connector 2 | -------------------------------------------------------------------------------- /dbutils-in-jar/README.md: -------------------------------------------------------------------------------- 1 | # Example of using Databricks dbutils in JVM-based code compiled into .jar 2 | 3 | This directory shows how to use Databricks dbutils from the code compiled into & deployed as `.jar` file. It shows use of the `dbutils.secrets.get` and `dbutils.fs.ls` by listing files on ADLS using the service principal & SP secret is retrieved from secret scope (in my case, it was based on Azure KeyVault). 4 | 5 | The main caveat is that when building a fat jar, the `dbutils` and Spark dependencies shouldn't be included into the resulting `.jar` file - it's done by marking these dependencies as `provided` in the `pom.xml`. 6 | 7 | ## Setup 8 | 9 | 1. Change secret scope & secret names in line 15 of `src/main/scala/net/alexott/demos/DbutilsDemo.scala`, update other variables to point to your ADLS account. 10 | 1. Compile using `mvn package`. 11 | 1. Copy `target/dbutils-in-jar-0.0.1-jar-with-dependencies.jar` to DBFS or cloud storage. 12 | 1. Create a Databricks job with "Jar task", specify `net.alexott.demos.DbutilsDemo` as main class, use "No isolation shared" Access Mode in cluster configuration as we're using direct access to ADLS. 13 | 1. Run job - it should produce a list of files/directories in your ADLS account. 14 | -------------------------------------------------------------------------------- /dbutils-in-jar/pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | net.alexott.demos.databricks 6 | dbutils-in-jar 7 | 0.0.1 8 | jar 9 | 10 | dbutils-in-jar 11 | 12 | 13 | UTF-8 14 | 2.12.12 15 | 3.3.2 16 | 2.12 17 | 1.8 18 | 19 | 20 | 21 | 22 | org.apache.spark 23 | spark-sql_${spark.scala.version} 24 | ${spark.version} 25 | provided 26 | 27 | 28 | com.databricks 29 | dbutils-api_${spark.scala.version} 30 | 0.0.6 31 | provided 32 | 33 | 34 | 35 | 36 | 37 | 38 | maven-compiler-plugin 39 | 3.8.1 40 | 41 | ${java.version} 42 | ${java.version} 43 | true 44 | 45 | 46 | 47 | net.alchim31.maven 48 | scala-maven-plugin 49 | 4.3.0 50 | 51 | 52 | process-sources 53 | 54 | compile 55 | testCompile 56 | 57 | 58 | ${scala.version} 59 | 60 | 63 | 64 | 65 | 66 | 67 | org.apache.maven.plugins 68 | maven-assembly-plugin 69 | 3.2.0 70 | 71 | 72 | jar-with-dependencies 73 | 74 | 75 | 76 | 77 | package 78 | 79 | single 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /dbutils-in-jar/src/main/scala/net/alexott/demos/DbutilsDemo.scala: -------------------------------------------------------------------------------- 1 | package net.alexott.demos 2 | 3 | import com.databricks.dbutils_v1.DBUtilsHolder.dbutils 4 | import org.apache.spark.sql.SparkSession 5 | 6 | object DbutilsDemo { 7 | def main(args: Array[String]): Unit = { 8 | val spark = SparkSession.builder().getOrCreate() 9 | 10 | // test that Spark works correctly... 11 | val df = spark.range(100) 12 | df.show(5) 13 | 14 | // get AAD service principal secret from the secret scope 15 | val clientSecret = dbutils.secrets.get("test", "test") 16 | val aadTenantId = "..." 
17 | val clientID = "" 18 | // setup authentication parameters 19 | spark.conf.set("fs.azure.account.auth.type", "OAuth") 20 | spark.conf.set("fs.azure.account.oauth.provider.type", "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider") 21 | spark.conf.set("fs.azure.account.oauth2.client.endpoint", s"https://login.microsoftonline.com/$aadTenantId/oauth2/token") 22 | spark.conf.set("fs.azure.account.oauth2.client.id", clientID) 23 | spark.conf.set("fs.azure.account.oauth2.client.secret", clientSecret) 24 | 25 | // list files on ADLS 26 | dbutils.fs.ls("abfss://test@test.dfs.core.windows.net/").foreach(println) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /dbutils-snippets/repos.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | import sys 3 | 4 | 5 | def add_repos_paths(paths_in_repos: Optional[List[str]]=None): 6 | """Automatically add Repos to sys.path for use with files in repos. 7 | 8 | :param paths_in_repos: optional list of subdirectories inside the repository (relative paths!) 9 | """ 10 | np = dbutils.notebook.entry_point.getDbutils().notebook().getContext().extraContext().get("notebook_path") 11 | if np is None: 12 | return 13 | np = np.get() 14 | splits = np.split('/') 15 | if splits[1] == 'Repos': 16 | repo_root = '/Workspace' + '/'.join(splits[:4]) 17 | if repo_root not in sys.path: 18 | sys.path.append(repo_root) 19 | 20 | if paths_in_repos: 21 | for p in paths_in_repos: 22 | tp = repo_root + '/' + p 23 | if tp not in sys.path: 24 | sys.path.append(tp) 25 | 26 | -------------------------------------------------------------------------------- /deactivate-activate-users-sps/README.md: -------------------------------------------------------------------------------- 1 | # Tool for bulk deactivating/reactivating Databricks users and service principals inside the workspace 2 | 3 | This directory contains the `deactivate-activate-users-sps.py` script that deactivates all non-admin users and service principals in the Databricks workspace and stores the list in a config file. It's then possible to reactivate them based on the stored data. 4 | 5 | ## Installation 6 | 7 | You need to have the [Databricks SDK for Python](https://pypi.org/project/databricks-sdk/) installed to run this tool. Do 8 | 9 | ```sh 10 | pip install databricks-sdk 11 | ``` 12 | 13 | to install the SDK. 14 | 15 | 16 | ## Running 17 | 18 | You must configure environment variables to perform [authentication](https://pypi.org/project/databricks-sdk/#authentication) to the Databricks workspace where the work will be done. 19 | 20 | Use the `-h` command-line option to get help on supported commands: 21 | 22 | ``` 23 | usage: deactivate-activate-users-sps.py [-h] [--file FILE] [--debug] [--verbose] {scan,deactivate,reactivate} 24 | 25 | Deactivate or reactivate Databricks users and service principals 26 | 27 | positional arguments: 28 | {scan,deactivate,reactivate} 29 | 30 | options: 31 | -h, --help show this help message and exit 32 | --file FILE File to store deactivated users/SPs (default: deactivated_users.json) 33 | --debug Enable debug output 34 | --verbose Enable verbose output 35 | ``` 36 | 37 | There are three supported commands: 38 | 39 | * `scan` just prints a list of active users and service principals that aren't part of the `admins` group. 40 | * `deactivate` finds all active non-admin users and service principals and deactivates them.
Information about deactivated users/SPs is stored in the file that then will be used with `reactivate` command. 41 | * `reactivate` reads the previously generated file and reactivates users and service principals. 42 | 43 | 44 | Optionally, you can use `--file` command line option to change the file name where we store data about deactivated users/SPs (default file name is `deactivated_users.json`). 45 | 46 | 47 | 48 | ## TODOs 49 | 50 | 51 | - \[ \] Better error handling. 52 | -------------------------------------------------------------------------------- /deactivate-activate-users-sps/deactivate-activate-users-sps.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from databricks.sdk import WorkspaceClient 4 | import argparse 5 | import json 6 | import os 7 | 8 | from databricks.sdk.service.iam import Patch, PatchSchema, PatchOp 9 | 10 | default_file_name = "deactivated_users.json" 11 | default_schemas = [PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP] 12 | 13 | def find_non_admin_users_sps(wc: WorkspaceClient): 14 | admin_users = set() 15 | admin_sps = set() 16 | res = list(wc.groups.list(attributes="members", filter="displayName eq \"admins\"")) 17 | if not res: 18 | raise Exception("Can't find 'admins' group") 19 | for member in res[0].members: 20 | if member.ref.startswith("Users/"): 21 | admin_users.add(member.value) 22 | elif member.ref.startswith("ServicePrincipals/"): 23 | admin_sps.add(member.value) 24 | 25 | users_to_deactivate = [] 26 | attrs = "id,displayName,userName,active" 27 | all_active_users = wc.users.list(attributes=attrs, filter="active eq true") 28 | for user in all_active_users: 29 | if user.id not in admin_users: 30 | users_to_deactivate.append({"id": user.id, "display_name": user.display_name, "user_name": user.user_name}) 31 | 32 | sps_to_deactivate = [] 33 | all_active_sps = wc.service_principals.list(attributes=attrs, filter="active eq true") 34 | for sp in all_active_sps: 35 | if sp.id not in admin_sps: 36 | sps_to_deactivate.append({"id": sp.id, "display_name": sp.display_name, 37 | "application_id": sp.application_id}) 38 | 39 | return {"users": users_to_deactivate, "service_principals": sps_to_deactivate} 40 | 41 | 42 | def make_patch_op(active: bool): 43 | return [Patch(op=PatchOp.REPLACE, path="active", value=[ 44 | { 45 | "value": active 46 | } 47 | ])] 48 | 49 | 50 | def switch_user_active_status(wc: WorkspaceClient, user_id: str, active: bool): 51 | wc.users.patch(user_id, schemas=default_schemas, operations=make_patch_op(active)) 52 | 53 | 54 | def switch_sp_active_status(wc: WorkspaceClient, sp_id: str, active: bool): 55 | wc.service_principals.patch(sp_id, schemas=default_schemas, operations=make_patch_op(active)) 56 | 57 | 58 | def scan_identities_to_deactivate(wc: WorkspaceClient): 59 | identities = find_non_admin_users_sps(wc) 60 | users = identities['users'] 61 | sps = identities['service_principals'] 62 | print(f"Found {len(users)} users and {len(sps)} service principals.") 63 | print("Users:") 64 | for user in users: 65 | print(f"\tSCIM ID: {user['id']}, Display Name: {user['display_name']}, User Name: {user['user_name']}") 66 | print("\nService Principals:") 67 | for sp in sps: 68 | print(f"\tSCIM ID: {sp['id']}, Display Name: {sp['display_name']}, Application ID: {sp['application_id']}") 69 | 70 | 71 | def deactivate_users_sps(wc: WorkspaceClient, fname: str, verbose: bool): 72 | print("Deactivating users and service principals...") 73 | identities = 
find_non_admin_users_sps(wc) 74 | users = identities['users'] 75 | sps = identities['service_principals'] 76 | print(f"Found {len(users)} users and {len(sps)} service principals.") 77 | with open(args.file, "w") as f: 78 | json.dump(identities, f, indent=2) 79 | 80 | switch_users_status(wc, users, False, verbose) 81 | switch_sps_status(wc, sps, False, verbose) 82 | 83 | print(f"All users & SPs are deactivated. The list of identities is saved in {fname}") 84 | 85 | 86 | def switch_sps_status(wc: WorkspaceClient, sps: list, active: bool, verbose: bool): 87 | for sp in sps: 88 | switch_sp_active_status(wc, sp['id'], active) 89 | if verbose: 90 | status = "Reactivated" if active else "Deactivated" 91 | print( 92 | f"{status} service principal '{sp['display_name']} ({sp['application_id']})' with SCIM ID {sp['id']}") 93 | 94 | 95 | def switch_users_status(wc: WorkspaceClient, users: list, active: bool, verbose: bool): 96 | for user in users: 97 | switch_user_active_status(wc, user['id'], active) 98 | if verbose: 99 | status = "Reactivated" if active else "Deactivated" 100 | print(f"{status} user '{user['display_name']} ({user['user_name']})' with SCIM ID {user['id']}") 101 | 102 | 103 | def reactivate_users_sps(wc: WorkspaceClient, fname: str, verbose: bool): 104 | print(f"Activating users/SPs from {fname}") 105 | with open(args.file, "r") as f: 106 | identities = json.load(f) 107 | 108 | users = identities['users'] 109 | sps = identities['service_principals'] 110 | print(f"Loaded {len(users)} users and {len(sps)} service principals.") 111 | switch_users_status(wc, users, True, verbose) 112 | switch_sps_status(wc, sps, True, verbose) 113 | print(f"Users/SPs are reactivated. Removing {fname}...") 114 | os.remove(fname) 115 | 116 | 117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser(description='Deactivate or reactivate Databricks users and service principals') 119 | supported_commands = ['scan', 'deactivate', 'reactivate'] 120 | parser.add_argument('command', choices=supported_commands) 121 | parser.add_argument('--file', default=default_file_name, 122 | help=f'File to store deactivated users/SPs (default: {default_file_name})') 123 | parser.add_argument('--debug', action='store_true', help='Enable debug output') 124 | parser.add_argument('--verbose', action='store_true', help='Enable verbose output') 125 | args = parser.parse_args() 126 | if args.command not in supported_commands: 127 | print(f"Unknown command {args.command}") 128 | exit(1) 129 | if args.debug: 130 | logging.basicConfig(level=logging.DEBUG) 131 | 132 | client = WorkspaceClient() 133 | if args.command == 'scan': 134 | scan_identities_to_deactivate(client) 135 | elif args.command == 'deactivate': 136 | deactivate_users_sps(client, args.file, args.verbose) 137 | elif args.command == 'reactivate': 138 | reactivate_users_sps(client, args.file, args.verbose) 139 | 140 | 141 | -------------------------------------------------------------------------------- /guardrails-checks/check_guardrails.py: -------------------------------------------------------------------------------- 1 | from databricks.sdk import WorkspaceClient 2 | 3 | workspace_client = WorkspaceClient() 4 | 5 | 6 | for se in workspace_client.serving_endpoints.list(): 7 | # print(f"Analyzing {se.name}") 8 | should_do_check = False 9 | if se.config.served_entities: 10 | for entity in se.config.served_entities: 11 | if entity.external_model or entity.foundation_model or entity.entity_name.startswith("system.ai."): 12 | should_do_check = True 13 | break 14 | 
15 | if not should_do_check: 16 | continue 17 | found_problems = [] 18 | # print(f"Checking {se.name}") 19 | if not se.ai_gateway: 20 | found_problems.append("No AI Gateway found") 21 | else: 22 | if not se.ai_gateway.guardrails: 23 | found_problems.append("No guardrails found") 24 | else: 25 | if not se.ai_gateway.guardrails.input: 26 | found_problems.append("No input guardrails found") 27 | if not se.ai_gateway.guardrails.output: 28 | found_problems.append("No output guardrails found") 29 | 30 | if not se.ai_gateway.inference_table_config or not se.ai_gateway.inference_table_config.enabled: 31 | found_problems.append("Inference table is not enabled") 32 | 33 | if found_problems: 34 | print(f"Found problems for {se.name}: {found_problems}") 35 | # else: 36 | # print(f"No problems found for {se.name}") 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /init-scripts/README.md: -------------------------------------------------------------------------------- 1 | # Global & cluster init scripts for Databricks 2 | 3 | This directory contains a collection of [global & cluster init scripts for Databricks](https://docs.databricks.com/clusters/init-scripts.html). 4 | 5 | 6 | ## Global init scripts 7 | 8 | Many of these scripts could be used as cluster-level init scripts as well, but their main goal to configure all clusters in a workspace. To use, just add their content via [Databricks Admin Console](https://docs.databricks.com/clusters/init-scripts.html#global-init-scripts), via [Terraform](https://registry.terraform.io/providers/databrickslabs/databricks/latest/docs/resources/global_init_script), or [REST API](https://docs.databricks.com/dev-tools/api/latest/global-init-scripts.html). 9 | 10 | * [`install-ssl-certificates.sh`](install-ssl-certificates.sh) - allows to install custom CA SSL certificates into all certificates chains - Linux, JVM and Python's certifi package. This is required when organization uses an organization-specific SSL certificate authority (CA), and want to avoid getting SSL certificate validation errors when accessing internal resources signed by that CA. To use this script, upload to DBFS SSL CA certificate(s) in PEM format, and update list of paths on line 11 (`declare -a certs`). Please note that you need to use `/dbfs/` instead of `dbfs:/`. 11 | 12 | 13 | 14 | ## Cluster-level init scripts 15 | -------------------------------------------------------------------------------- /init-scripts/install-ssl-certificates.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # File: install-ssl-certificates.sh 4 | # Author: Alex Ott 5 | # Created: Tuesday, September 28 2021 6 | # 7 | 8 | #set -x 9 | 10 | # Update this list of certificates 11 | declare -a certs=("/dbfs/tmp/myCA.pem" "/dbfs/tmp/myCA2.pem") 12 | 13 | mkdir -p /usr/share/ca-certificates/extra 14 | CERTIFI_HOME="$(python -m certifi 2>/dev/null)" 15 | J_HOME="$(dirname $(realpath $(which java)))/.." 
16 | 17 | for cert in ${certs[@]}; do 18 | BNAME="$(basename $cert)" 19 | echo "cert=$cert BNAME=$BNAME" 20 | cp $cert /usr/share/ca-certificates/extra/$BNAME 21 | echo "extra/$BNAME" >> /etc/ca-certificates.conf 22 | if [ -n "$CERTIFI_HOME" ]; then 23 | cat $cert >> $CERTIFI_HOME 24 | fi 25 | keytool -importcert -keystore ${J_HOME}/lib/security/cacerts -file $cert -alias $(basename $cert .pem) -storepass changeit -noprompt 26 | done 27 | 28 | update-ca-certificates 29 | 30 | #keytool -list -keystore ${J_HOME}/lib/security/cacerts -storepass changeit 31 | -------------------------------------------------------------------------------- /ip-access-list-analyzer/README.md: -------------------------------------------------------------------------------- 1 | # Analyzer/fix tool for Databricks IP Access Lists 2 | 3 | This tool was moved to [Databricks Labs Sandbox](https://github.com/databrickslabs/sandbox/tree/main/ip_access_list_analyzer) project. 4 | -------------------------------------------------------------------------------- /java-sdk-example/pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | net.alexott.demos 6 | java-sdk-example 7 | 0.0.1 8 | jar 9 | 10 | 11 | UTF-8 12 | 11 13 | 11 14 | 11 15 | 16 | 17 | 18 | 19 | 20 | com.databricks 21 | databricks-sdk-java 22 | 0.38.0 23 | 24 | 25 | 26 | 27 | 28 | 29 | maven-compiler-plugin 30 | 3.8.1 31 | 32 | ${java.version} 33 | ${java.version} 34 | true 35 | 36 | 37 | 38 | org.apache.maven.plugins 39 | maven-assembly-plugin 40 | 3.2.0 41 | 42 | 43 | jar-with-dependencies 44 | 45 | 46 | 47 | 48 | package 49 | 50 | single 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /java-sdk-example/src/main/java/net/alexott/demos/M2MAzureAuth.java: -------------------------------------------------------------------------------- 1 | package net.alexott.demos; 2 | 3 | import com.databricks.sdk.WorkspaceClient; 4 | import com.databricks.sdk.core.DatabricksConfig; 5 | import com.databricks.sdk.service.compute.ClusterDetails; 6 | import com.databricks.sdk.service.compute.ClusterSource; 7 | import com.databricks.sdk.service.compute.ListClustersRequest; 8 | import com.databricks.sdk.service.compute.ListClustersFilterBy; 9 | 10 | import java.util.Arrays; 11 | 12 | class M2MAzureAuth { 13 | public static void main(String[] args) { 14 | // implicit configuration - via environment variables 15 | // https://github.com/databricks/databricks-sdk-java?tab=readme-ov-file#azure-native-authentication 16 | WorkspaceClient client = null; 17 | if (System.getenv("IMPLICIT_AUTH") != null) { 18 | client = new WorkspaceClient(); 19 | } else { 20 | DatabricksConfig config = new DatabricksConfig().setAuthType("azure-client-secret") 21 | .setAzureClientId(System.getenv("DATABRICKS_CLIENT_ID")) 22 | .setAzureClientSecret(System.getenv("DATABRICKS_CLIENT_SECRET")) 23 | .setAzureTenantId(System.getenv("DATABRICKS_TENANT_ID")) 24 | .setHost(System.getenv("DATABRICKS_HOST")); 25 | client = new WorkspaceClient(config); 26 | } 27 | System.out.println("Client: " + client.config()); 28 | for (ClusterDetails c : client.clusters().list(new ListClustersRequest().setFilterBy( 29 | new ListClustersFilterBy().setClusterSources(Arrays.asList(ClusterSource.UI, ClusterSource.API))))) { 30 | System.out.println(c.getClusterName()); 31 | } 32 | } 33 | } -------------------------------------------------------------------------------- 
/kafka-eventhubs-aad-auth/.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | insert_final_newline = true 6 | charset = utf-8 7 | 8 | [*.java] 9 | indent_style = space 10 | indent_size = 4 11 | trim_trailing_whitespace = true 12 | max_line_length = 120 13 | -------------------------------------------------------------------------------- /kafka-eventhubs-aad-auth/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | AAD Auth for Spark Kafka connector 4 | 5 | Copyright (c) Alex Ott. All rights reserved. 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE -------------------------------------------------------------------------------- /kafka-eventhubs-aad-auth/README.md: -------------------------------------------------------------------------------- 1 | # Azure Active Directory authentication for Spark Kafka & EventHubs connectors accessing Event Hubs 2 | 3 | This directory contains a library for using Azure Active Directory tokens to connect to the Azure Event Hubs using the Spark Kafka connector via Kafka protocol or [Spark EventHubs connector](https://github.com/Azure/azure-event-hubs-spark/). 
4 | 5 | ## Using Spark Kafka connector on Databricks Runtime version 12.2+ and Delta Live Tables 6 | 7 | Since DBR 12.2, the Apache Kafka client that is used by Databricks Runtime is upgraded to versions that are directly supporting OAuth/OIDC flows, so there is no need to build this library - you just need to correctly configure Kafka consumer (see [blog post](https://alexott.blogspot.com/2023/10/delta-live-tables-recipes-consuming.html)), like this (please note that UC Shared and UC DLT require DBR 14.x): 8 | 9 | ```py 10 | topic = "" 11 | eh_namespace_name = "" 12 | eh_server = f"{eh_namespace_name}.servicebus.windows.net" 13 | 14 | # Data for service principal are stored in the secret scope 15 | tenant_id = dbutils.secrets.get("scope", "tenant_id") 16 | client_id = dbutils.secrets.get("scope", "sp-id") 17 | client_secret = dbutils.secrets.get("scope", "sp-secret") 18 | # Generate 19 | sasl_config = f'kafkashaded.org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required clientId="{client_id}" clientSecret="{client_secret}" scope="https://{eh_server}/.default" ssl.protocol="SSL";' 20 | 21 | # Create Kafka options dictionary 22 | callback_class = "kafkashaded.org.apache.kafka.common.security.oauthbearer.secured.OAuthBearerLoginCallbackHandler" 23 | kafka_options = { 24 | "kafka.bootstrap.servers": f"{eh_server}:9093", 25 | "subscribe": topic, 26 | "startingOffsets": "earliest", 27 | "kafka.security.protocol": "SASL_SSL", 28 | "kafka.sasl.mechanism": "OAUTHBEARER", 29 | "kafka.sasl.jaas.config": sasl_config, 30 | "kafka.sasl.oauthbearer.token.endpoint.url": f"https://login.microsoft.com/{tenant_id}/oauth2/v2.0/token", 31 | "kafka.sasl.login.callback.handler.class": callback_class, 32 | } 33 | 34 | df = spark.readStream.format("kafka").options(**kafka_options).load() 35 | ``` 36 | 37 | ## Build the library for DBR runtime versions below 12.2 38 | 39 | You need to select profile to build this project: 40 | 41 | | Profile name | Supported DBR versions | Status | 42 | |--------------|------------------------------|------------| 43 | |`dbr-10.4` | DBR 10.4, DBR 11.3, DBR 12.2 | tested | 44 | |`dbr-9.1` | DBR 9.1 LTS | not tested | 45 | |`dbr-7.3` | DBR 7.3 LTS | not tested | 46 | 47 | 48 | and then use it to build package using Maven: 49 | 50 | ```sh 51 | mvn clean package -P profile-name 52 | ``` 53 | 54 | that will generate a jar file with name `kafka-eventhubs-aad-auth--.jar` in the `target` directory. Add this library to a Databricks cluster or Databricks job. You also need to add `com.microsoft.azure:msal4j:1.10.1` library to a cluster/job, as it isn't shipped together with Databricks runtimes. If you use this library to access via EventHubs protocol, you need to add corresponding library using following coordinates: `com.microsoft.azure:azure-eventhubs-spark_2.12:2.3.22`. 55 | 56 | ## Configure Service Principal 57 | 58 | Right now this library supports following authentication methods: 59 | 60 | 1. using Service Principal client ID & secret 61 | 1. 
(coming soon) using Service Principal certificate 62 | 63 | Service principal should have following permissions on the specific topic: 64 | 65 | * `Azure Event Hubs Data Receiver` - for consuming from Event Hubs 66 | * `Azure Event Hubs Data Sender` - for writing to Event Hubs 67 | 68 | ## Use with Spark Kafka connector 69 | 70 | All authentication methods should provide following options: 71 | 72 | * `kafka.aad_tenant_id` - required Azure Active Directory Tenant ID 73 | * `kafka.aad_client_id` - Service Principal Client ID (also called "Application ID", **don't use Object ID!**). 74 | * `kafka.aad_authority_endpoint` - optional host name of Azure AD authentication endpoint if you're using special [Azure Cloud (GovCloud, China, Germany)](https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints). The value must contain a protocol and end with `/`. For example: `https://login.microsoftonline.de/`. Default value is `https://login.microsoftonline.com/`. 75 | * `kafka.security.protocol` should be set to `SASL_SSL` 76 | * `kafka.sasl.mechanism` should be set to `OAUTHBEARER` 77 | * `kafka.sasl.jaas.config` should be set to `kafkashaded.org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;` (trailing `;` is required!) 78 | 79 | 80 | ### Authenticating using Service Principal client ID & secret 81 | 82 | We need to provide following options: 83 | 84 | * `kafka.aad_client_secret` - Service Principal Client Secret (it's recommended to store it in the Azure KeyVault-baked secret scope and retrieve it via `dbutils.secrets.get`). 85 | * `kafka.sasl.login.callback.handler.class` should be set to the name of the class: `net.alexott.demos.kafka_aad.ServicePrincipalCredentialsAuth` 86 | 87 | Example: 88 | 89 | ```python 90 | sasl_config = "kafkashaded.org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;" 91 | topic = "..." 92 | tenant_id = "..." 93 | client_id = "..." 94 | client_secret = dbutils.secrets.get("secret_scope", "sp-secret") 95 | 96 | kafka_options = { 97 | "kafka.bootstrap.servers": ".servicebus.windows.net:9093", 98 | "subscribe": topic, 99 | "failOnDataLoss": "false", 100 | "startingOffsets": "earliest", 101 | "kafka.security.protocol": "SASL_SSL", 102 | "kafka.sasl.mechanism": "OAUTHBEARER", 103 | "kafka.sasl.jaas.config": sasl_config, 104 | "kafka.sasl.login.callback.handler.class": "net.alexott.demos.kafka_aad.ServicePrincipalCredentialsAuth", 105 | "kafka.aad_tenant_id": tenant_id, 106 | "kafka.aad_client_id": client_id, 107 | "kafka.aad_client_secret": client_secret, 108 | } 109 | 110 | df = ( 111 | spark 112 | .readStream 113 | .format("kafka") 114 | .options(**kafka_options) 115 | .load() 116 | ) 117 | ``` 118 | 119 | ## Use with EventHubs 120 | 121 | All authentication methods should provide following options: 122 | 123 | * `aad_tenant_id` - required Azure Active Directory Tenant ID 124 | * `aad_client_id` - Service Principal Client ID (also called "Application ID", **don't use Object ID!**). 125 | * `aad_authority_endpoint` - optional host name of Azure AD authentication endpoint if you're using special [Azure Cloud (GovCloud, China, Germany)](https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints). The value must contain a protocol and end with `/`. For example: `https://login.microsoftonline.de/`. Default value is `https://login.microsoftonline.com/`. 
126 | * `eventhubs.useAadAuth` - should be set to `True` 127 | * `eventhubs.AadAuthCallbackParams` - is the JSON-encoded dictionary with parameters that will be passed to callback handler's class constructor. 128 | * `eventhubs.aadAuthCallback` - is the class name of specific implementation 129 | 130 | ### Authenticating using Service Principal client ID & secret 131 | 132 | We need to provide additional options: 133 | 134 | * `aad_client_secret` - Service Principal Client Secret (it's recommended to store it in the Azure KeyVault-baked secret scope and retrieve it via `dbutils.secrets.get`). 135 | * `eventhubs.aadAuthCallback` should be set to the name of the class: `net.alexott.demos.eventhubs_aad.ServicePrincipalCredentialsAuth` 136 | 137 | This authentication method needs following parameters in the parameters dictionary: 138 | 139 | * `aad_tenant_id` (described above) 140 | * `aad_client_id` (described above) 141 | * `aad_client_secret` - Service Principal Client Secret (it's recommended to store it in the Azure KeyVault-baked secret scope and retrieve it via `dbutils.secrets.get`). 142 | 143 | Example: 144 | 145 | ```python 146 | topic = "..." 147 | tenant_id = "..." 148 | client_id = "..." 149 | client_secret = dbutils.secrets.get("secret_scope", "sp-secret") 150 | ehs_ns_name = "eventhubs-namespace-name" 151 | callback_class_name = "net.alexott.demos.eventhubs_aad.ServicePrincipalCredentialsAuth" 152 | # Instead of `servicebus.windows.net` there could be regional endpoints 153 | ehs_endpoint = f"sb://{ehs_ns_name}.servicebus.windows.net" 154 | 155 | # EventHubs connection string. 156 | connectionString = f"Endpoint=sb://{ehs_ns_name}.servicebus.windows.net;EntityPath={topic}" 157 | 158 | # Parameters that will be passed to the callback function 159 | params = { 160 | "aad_tenant_id": tenant_id, 161 | "aad_client_id": client_id, 162 | "aad_client_secret": client_secret, 163 | } 164 | 165 | # Spark EventHubs options 166 | ehConf = { 167 | 'eventhubs.connectionString': sc._jvm.org.apache.spark.eventhubs.EventHubsUtils.encrypt(connectionString), 168 | 'eventhubs.useAadAuth': True, 169 | 'eventhubs.aadAuthCallback': callback_class_name, 170 | 'eventhubs.AadAuthCallbackParams': json.dumps(params), 171 | } 172 | 173 | df = spark.readStream.format("eventhubs").options(**ehConf).load() 174 | ``` 175 | 176 | 177 | ## Limitations 178 | 179 | * We currently support only one bootstrap server for Kafka 180 | 181 | 182 | ## Troubleshooting 183 | 184 | ### TopicAuthorizationException: Not authorized to access topics 185 | 186 | Check that service principal has correct permission for all Event Hubs topics as described above 187 | 188 | ### No OAuth Bearer tokens in Subject's private credentials 189 | 190 | Usually is caused the incorrect configuration, check Spark logs to see actual exception. 191 | 192 | 193 | ## Project Support 194 | 195 | Please note that this project is provided for your exploration only, and are not formally supported by Databricks with Service Level Agreements (SLAs). They are provided AS-IS and we do not make any guarantees of any kind. Please do not submit a support ticket relating to any issues arising from the use of this project. 196 | 197 | Any issues discovered through the use of this project should be filed as GitHub Issues on the Repo. They will be reviewed as time permits, but there are no formal SLAs for support. 
198 | -------------------------------------------------------------------------------- /kafka-eventhubs-aad-auth/pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | net.alexott.demos 6 | kafka-eventhubs-aad-auth 7 | 0.0.1 8 | jar 9 | 10 | kafka-aad-auth 11 | http://maven.apache.org 12 | 13 | 14 | UTF-8 15 | 1.8 16 | 17 | 18 | 19 | 20 | com.microsoft.azure 21 | msal4j 22 | 1.10.1 23 | 24 | 25 | com.microsoft.azure 26 | azure-eventhubs-spark_2.12 27 | 2.3.22 28 | provided 29 | 30 | 31 | 32 | 33 | 34 | dbr-10.4 35 | 36 | dbr_10.4 37 | 38 | 39 | 40 | org.apache.kafka 41 | kafka-clients 42 | 2.8.0 43 | provided 44 | 45 | 46 | 47 | 48 | dbr-7.3 49 | 50 | dbr_7.3 51 | 52 | 53 | 54 | org.apache.kafka 55 | kafka-clients 56 | 2.4.1 57 | provided 58 | 59 | 60 | 61 | 62 | dbr-9.1 63 | 64 | dbr_9.1 65 | 66 | 67 | 68 | org.apache.kafka 69 | kafka-clients 70 | 2.6.0 71 | provided 72 | 73 | 74 | 75 | 76 | 77 | 78 | ${project.artifactId}-${project.version}-${dbrver} 79 | 80 | 81 | maven-compiler-plugin 82 | 3.8.1 83 | 84 | ${java.version} 85 | ${java.version} 86 | true 87 | 88 | 89 | 90 | org.apache.maven.plugins 91 | maven-resources-plugin 92 | 3.0.2 93 | 94 | UTF-8 95 | 96 | 97 | 98 | org.apache.maven.plugins 99 | maven-shade-plugin 100 | 3.2.0 101 | 102 | 103 | package 104 | 105 | shade 106 | 107 | 108 | 109 | 110 | com.fasterxml.jackson.core:* 111 | com.nimbusds:* 112 | com.microsoft.azure:* 113 | net.minidev:* 114 | com.github.stephenc.jcip:* 115 | org.slf4j:* 116 | 117 | 118 | 119 | 120 | *:* 121 | 122 | META-INF/*.SF 123 | META-INF/*.DSA 124 | META-INF/*.RSA 125 | org/objectweb/** 126 | 127 | 128 | 129 | 130 | 131 | org.apache.kafka 132 | kafkashaded.org.apache.kafka 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /kafka-eventhubs-aad-auth/src/main/java/net/alexott/demos/eventhubs_aad/ServicePrincipalAuthBase.java: -------------------------------------------------------------------------------- 1 | package net.alexott.demos.eventhubs_aad; 2 | 3 | import com.microsoft.aad.msal4j.ClientCredentialParameters; 4 | import com.microsoft.aad.msal4j.ConfidentialClientApplication; 5 | import com.microsoft.aad.msal4j.IAuthenticationResult; 6 | import org.apache.kafka.common.protocol.types.Field; 7 | import org.apache.spark.eventhubs.utils.AadAuthenticationCallback; 8 | import org.slf4j.Logger; 9 | import org.slf4j.LoggerFactory; 10 | 11 | import java.io.IOException; 12 | import java.net.MalformedURLException; 13 | import java.util.Collections; 14 | 15 | import scala.collection.immutable.Map; 16 | import java.util.concurrent.CompletableFuture; 17 | import scala.collection.JavaConverters; 18 | 19 | abstract class ServicePrincipalAuthBase implements AadAuthenticationCallback { 20 | protected volatile ConfidentialClientApplication msalClient; 21 | protected final String authority; // tenant ID 22 | protected String authEndpoint; 23 | protected final String clientId; 24 | 25 | private static final String AAD_TENANT_ID_KEY = "aad_tenant_id"; 26 | private static final String AAD_AUTHORITY_ENDPOINT_KEY = "aad_authority_endpoint"; 27 | private static final String AAD_CLIENT_ID_KEY = "aad_client_id"; 28 | private static final Logger logger = LoggerFactory.getLogger(ServicePrincipalAuthBase.class); 29 | 30 | ServicePrincipalAuthBase(Map params) { 31 | java.util.Map m = JavaConverters.mapAsJavaMapConverter(params).asJava(); 32 | // String mapStr = 
m.entrySet().stream() 33 | // .map(e -> "'" + e.getKey() + "'='" + e.getValue() + "'") 34 | // .collect(Collectors.joining(", ")); 35 | // logger.info("params: {}", mapStr); 36 | this.authority = m.get(AAD_TENANT_ID_KEY); 37 | this.clientId = m.get(AAD_CLIENT_ID_KEY); 38 | this.authEndpoint = m.getOrDefault(AAD_AUTHORITY_ENDPOINT_KEY, "https://login.microsoftonline.com/"); 39 | if (!this.authEndpoint.endsWith("/")) { 40 | this.authEndpoint += "/"; 41 | } 42 | this.authEndpoint += this.authority; 43 | } 44 | 45 | /** 46 | * Creates an instance of the client that will be used to obtain AAD tokens 47 | * 48 | * @return client that will be used to obtain AAD tokens 49 | * @throws MalformedURLException when authEndpoint URL is malformed 50 | */ 51 | abstract ConfidentialClientApplication getClient() throws MalformedURLException; 52 | 53 | @Override 54 | public String authority() { 55 | return authority; 56 | } 57 | 58 | @Override 59 | public CompletableFuture acquireToken(String audience, String authority, Object state) { 60 | try { 61 | if (this.msalClient == null) { 62 | synchronized (this) { 63 | if (this.msalClient == null) { 64 | this.msalClient = getClient(); 65 | if (this.msalClient == null) { 66 | throw new IOException("Can't create MSAL client"); 67 | } 68 | } 69 | } 70 | } 71 | ClientCredentialParameters ccParams = ClientCredentialParameters.builder( 72 | Collections.singleton(audience + ".default")).build(); 73 | return this.msalClient.acquireToken(ccParams).thenApply(IAuthenticationResult::accessToken); 74 | } catch (IOException ex) { 75 | CompletableFuture failed = new CompletableFuture<>(); 76 | failed.completeExceptionally(ex); 77 | return failed; 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /kafka-eventhubs-aad-auth/src/main/java/net/alexott/demos/eventhubs_aad/ServicePrincipalCredentialsAuth.java: -------------------------------------------------------------------------------- 1 | package net.alexott.demos.eventhubs_aad; 2 | 3 | import com.microsoft.aad.msal4j.ClientCredentialFactory; 4 | import com.microsoft.aad.msal4j.ConfidentialClientApplication; 5 | import com.microsoft.aad.msal4j.IClientCredential; 6 | 7 | 8 | import java.net.MalformedURLException; 9 | import scala.collection.immutable.Map; 10 | 11 | public class ServicePrincipalCredentialsAuth extends ServicePrincipalAuthBase { 12 | private final String clientSecret; 13 | 14 | private static final String AAD_CLIENT_SECRET_KEY = "aad_client_secret"; 15 | 16 | public ServicePrincipalCredentialsAuth(Map params) { 17 | super(params); 18 | clientSecret = params.get(AAD_CLIENT_SECRET_KEY).get(); 19 | } 20 | 21 | @Override 22 | ConfidentialClientApplication getClient() throws MalformedURLException { 23 | IClientCredential credential = ClientCredentialFactory.createFromSecret(this.clientSecret); 24 | return ConfidentialClientApplication.builder(this.clientId, credential) 25 | .authority(this.authEndpoint) 26 | .build(); 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /kafka-eventhubs-aad-auth/src/main/java/net/alexott/demos/kafka_aad/KafkaOAuthBearerTokenImp.java: -------------------------------------------------------------------------------- 1 | //Copyright (c) Microsoft Corporation. All rights reserved. 2 | //Licensed under the MIT License. 
3 | 4 | package net.alexott.demos.kafka_aad; 5 | 6 | import java.util.Date; 7 | import java.util.Set; 8 | 9 | import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; 10 | 11 | public class KafkaOAuthBearerTokenImp implements OAuthBearerToken 12 | { 13 | String token; 14 | long lifetimeMs; 15 | 16 | public KafkaOAuthBearerTokenImp(final String token, Date expiresOn) { 17 | this.token = token; 18 | this.lifetimeMs = expiresOn.getTime(); 19 | } 20 | 21 | @Override 22 | public String value() { 23 | return this.token; 24 | } 25 | 26 | @Override 27 | public Set scope() { 28 | return null; 29 | } 30 | 31 | @Override 32 | public long lifetimeMs() { 33 | return this.lifetimeMs; 34 | } 35 | 36 | @Override 37 | public String principalName() { 38 | return null; 39 | } 40 | 41 | @Override 42 | public Long startTimeMs() { 43 | return null; 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /kafka-eventhubs-aad-auth/src/main/java/net/alexott/demos/kafka_aad/ServicePrincipalAuthBase.java: -------------------------------------------------------------------------------- 1 | package net.alexott.demos.kafka_aad; 2 | 3 | import com.microsoft.aad.msal4j.ClientCredentialParameters; 4 | import com.microsoft.aad.msal4j.ConfidentialClientApplication; 5 | import com.microsoft.aad.msal4j.IAuthenticationResult; 6 | import org.apache.kafka.clients.producer.ProducerConfig; 7 | import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; 8 | import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; 9 | import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; 10 | import org.slf4j.Logger; 11 | import org.slf4j.LoggerFactory; 12 | 13 | import javax.security.auth.callback.Callback; 14 | import javax.security.auth.callback.UnsupportedCallbackException; 15 | import java.io.IOException; 16 | import java.net.MalformedURLException; 17 | import java.net.URL; 18 | import java.util.Collections; 19 | import java.util.Map; 20 | import java.util.concurrent.ExecutionException; 21 | import java.util.concurrent.TimeoutException; 22 | 23 | /** 24 | * Base class for Service Principal based authentication 25 | */ 26 | abstract public class ServicePrincipalAuthBase implements AuthenticateCallbackHandler { 27 | protected volatile ConfidentialClientApplication msalClient; 28 | protected ClientCredentialParameters msalParameters; 29 | protected String authEndpoint; 30 | protected String clientId; 31 | 32 | private static final String AAD_TENANT_ID_KEY = "aad_tenant_id"; 33 | private static final String AAD_AUTHORITY_ENDPOINT_KEY = "aad_authority_endpoint"; 34 | private static final String AAD_CLIENT_ID_KEY = "aad_client_id"; 35 | private static final Logger logger = LoggerFactory.getLogger(ServicePrincipalAuthBase.class); 36 | 37 | /** 38 | * Creates an instance of the client that will be used to obtain AAD tokens 39 | * 40 | * @return client that will be used to obtain AAD tokens 41 | * @throws MalformedURLException when authEndpoint URL is malformed 42 | */ 43 | abstract ConfidentialClientApplication getClient() throws MalformedURLException; 44 | 45 | /** 46 | * Extracts common configuration properties, such as, AAD Tenant ID 47 | * 48 | * @param configs Kafka configuration parameters 49 | */ 50 | protected void configureCommon(Map configs) { 51 | this.msalParameters = getMsalParameters(configs); 52 | 53 | String tenantId = configs.get(AAD_TENANT_ID_KEY).toString(); 54 | clientId = configs.get(AAD_CLIENT_ID_KEY).toString(); 55 | Object 
ob = configs.get(AAD_AUTHORITY_ENDPOINT_KEY); 56 | if (ob == null) { 57 | authEndpoint = "https://login.microsoftonline.com/"; 58 | } else { 59 | authEndpoint = ob.toString(); 60 | } 61 | authEndpoint += tenantId; 62 | } 63 | 64 | /** 65 | * Creates an instance of the parameters for MSAL. Currently only consists of the scope 66 | * 67 | * @param configs Kafka configs map 68 | * @return MSAL properties 69 | */ 70 | static ClientCredentialParameters getMsalParameters(Map configs) { 71 | String bootstrapServer = configs.get(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG).toString(); 72 | bootstrapServer = bootstrapServer.substring(1, bootstrapServer.length() - 2); 73 | bootstrapServer = bootstrapServer.substring(0, bootstrapServer.indexOf(':')); 74 | try { 75 | URL url = new URL("https", bootstrapServer, "/.default"); 76 | return ClientCredentialParameters.builder(Collections.singleton(url.toString())).build(); 77 | } catch (MalformedURLException e) { 78 | logger.info("Exception building scope: ", e); 79 | } 80 | return null; 81 | } 82 | 83 | @Override 84 | public void close() { 85 | } 86 | 87 | @Override 88 | public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { 89 | for (Callback callback : callbacks) { 90 | if (callback instanceof OAuthBearerTokenCallback) { 91 | try { 92 | if (this.msalClient == null) { 93 | synchronized (this) { 94 | if (this.msalClient == null) { 95 | this.msalClient = getClient(); 96 | if (this.msalClient == null) { 97 | throw new IOException("Can't create MSAL client"); 98 | } 99 | } 100 | } 101 | } 102 | IAuthenticationResult authResult = this.msalClient.acquireToken(this.msalParameters).get(); 103 | // logger.debug("Token acquired: {}", authResult.accessToken()); 104 | OAuthBearerToken token = new KafkaOAuthBearerTokenImp(authResult.accessToken(), authResult.expiresOnDate()); 105 | OAuthBearerTokenCallback oauthCallback = (OAuthBearerTokenCallback) callback; 106 | oauthCallback.token(token); 107 | } catch (InterruptedException | ExecutionException e) { 108 | e.printStackTrace(); 109 | } 110 | } else { 111 | throw new UnsupportedCallbackException(callback); 112 | } 113 | } 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /kafka-eventhubs-aad-auth/src/main/java/net/alexott/demos/kafka_aad/ServicePrincipalCredentialsAuth.java: -------------------------------------------------------------------------------- 1 | package net.alexott.demos.kafka_aad; 2 | 3 | import com.microsoft.aad.msal4j.ClientCredentialFactory; 4 | import com.microsoft.aad.msal4j.ConfidentialClientApplication; 5 | import com.microsoft.aad.msal4j.IClientCredential; 6 | 7 | import javax.security.auth.login.AppConfigurationEntry; 8 | import java.net.MalformedURLException; 9 | import java.util.List; 10 | import java.util.Map; 11 | 12 | /** 13 | * Class for generation of AAD tokens based on the Service Principal Client ID & Secret 14 | */ 15 | public class ServicePrincipalCredentialsAuth extends ServicePrincipalAuthBase { 16 | private String clientSecret; 17 | 18 | private static final String AAD_CLIENT_SECRET_KEY = "aad_client_secret"; 19 | 20 | @Override 21 | public void configure(Map configs, String mechanism, List jaasConfigEntries) { 22 | configureCommon(configs); 23 | clientSecret = configs.get(AAD_CLIENT_SECRET_KEY).toString(); 24 | } 25 | 26 | @Override 27 | ConfidentialClientApplication getClient() throws MalformedURLException { 28 | IClientCredential credential = 
ClientCredentialFactory.createFromSecret(this.clientSecret); 29 | return ConfidentialClientApplication.builder(this.clientId, credential) 30 | .authority(this.authEndpoint) 31 | .build(); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /oss-jdbc-aad-token/README.md: -------------------------------------------------------------------------------- 1 | # Querying Databricks via JDBC with OAuth tokens 2 | 3 | This directory contains the example of querying data via Databricks SQL Endpoint or Databricks Cluster using JDBC protocol and Databricks OSS JDBC driver. The JDBC driver supports OAuth authentication. Following examples demonstrate that: 4 | 5 | * `OssJDBCAadTokenUser` - authentication to Databricks is performed using Azure Active Directory tokens issued for the current user. 6 | * `OssJDBCOAuthAzureServicePrincipal` - authentication to Databricks is performed using Azure Active Directory tokens issued for Azure Service Principal. 7 | * `OssJDBCOAuthDatabricksServicePrincipal` - authentication to Databricks is performed using OAuth tokens issued for [Databricks Service Principal](https://learn.microsoft.com/en-us/azure/databricks/admin/users-groups/service-principals). 8 | 9 | ## Adjusting parameters 10 | 11 | Right now many parameters are hard-coded, so you need to change them before running. OSS JDBC driver can also pickup [authentication parameters defined as environment variables](https://learn.microsoft.com/en-us/azure/databricks/dev-tools/auth/). 12 | 13 | ### For `OssJDBCOAuthAzureServicePrincipal` 14 | 15 | You need to update source code and adjust following parameters: 16 | 17 | * `query` - what query should be executed 18 | * `clientId` - client ID of application in Azure Active Directory 19 | * `clientSecret` - secret for AAD application (ideally should be taken from KeyVault); 20 | * `host` - the host portion of the Databricks workspace (obtained from SQL Warehouse configuraiton) 21 | * `httpPath` - the HTTP Path of the SQL Warehouse (obtained from SQL Warehouse configuraiton) 22 | 23 | ### For `OssJDBCOAuthDatabricksServicePrincipal` 24 | 25 | You need to update source code and adjust following parameters: 26 | 27 | * `query` - what query should be executed 28 | * `clientId` - client ID of Databricks service principal 29 | * `clientSecret` - secret for Databricks service principal 30 | * `host` - the host portion of the Databricks workspace (obtained from SQL Warehouse configuraiton) 31 | * `httpPath` - the HTTP Path of the SQL Warehouse (obtained from SQL Warehouse configuraiton) 32 | 33 | ### For `OssJDBCAadTokenUser` 34 | 35 | * `query` - what query should be executed 36 | * `host` - the host portion of the Databricks workspace (obtained from SQL Warehouse configuraiton) 37 | * `httpPath` - the HTTP Path of the SQL Warehouse (obtained from SQL Warehouse configuraiton) 38 | * `oauthClientId` - (optional) if you don't have Azure application with name `databricks-sql-jdbc`, then set it to the Application ID of the Azure application that will be used for authentication. 39 | 40 | ## Build & run 41 | 42 | Just execute `mvn package` to build the code, and then you can execute resulting uber jar: 43 | 44 | ```sh 45 | java -cp target/oss-jdbc-aad-token-0.0.2-jar-with-dependencies.jar \ 46 | net.alexott.demos.OssJDBCOAuthAzureServicePrincipal 47 | ``` 48 | 49 | Or the code could be executed from an IDE. 
50 | -------------------------------------------------------------------------------- /oss-jdbc-aad-token/pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | net.alexott.demos 6 | oss-jdbc-aad-token 7 | 0.0.2 8 | jar 9 | 10 | 11 | UTF-8 12 | 11 13 | 11 14 | 11 15 | 16 | 17 | 18 | 19 | 20 | com.databricks 21 | databricks-jdbc 22 | 0.9.8-oss 23 | 24 | 25 | 26 | 27 | 28 | 29 | maven-compiler-plugin 30 | 3.8.1 31 | 32 | ${java.version} 33 | ${java.version} 34 | true 35 | 36 | 37 | 38 | org.apache.maven.plugins 39 | maven-assembly-plugin 40 | 3.2.0 41 | 42 | 43 | jar-with-dependencies 44 | 45 | 46 | 47 | 48 | package 49 | 50 | single 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /oss-jdbc-aad-token/src/main/java/net/alexott/demos/OssJDBCAadTokenUser.java: -------------------------------------------------------------------------------- 1 | package net.alexott.demos; 2 | 3 | import java.sql.Connection; 4 | import java.sql.DriverManager; 5 | import java.sql.ResultSet; 6 | import java.sql.ResultSetMetaData; 7 | import java.sql.Statement; 8 | 9 | class OssJDBCAadTokenUser { 10 | 11 | public static void main(String[] args) throws Exception { 12 | // Variables to set 13 | String query = "select 42, current_timestamp(), current_catalog(), current_database(), current_user()"; 14 | String host = "adb-.....17.azuredatabricks.net"; 15 | String httpPath = "/sql/1.0/warehouses/...."; 16 | String oauthClientId = ""; 17 | 18 | boolean enableTokenCache = System.getProperty("enableTokenCache", "false") 19 | .equalsIgnoreCase("true"); 20 | String tokenCachePassPhrase = "1234567"; 21 | String jdbcString = String.format("jdbc:databricks://%s:443;httpPath=%s;AuthMech=11;Auth_Flow=2", 22 | host, httpPath); 23 | if (!oauthClientId.isEmpty()) { 24 | jdbcString = String.format("%s;OAuth2ClientId=%s", jdbcString, oauthClientId); 25 | } 26 | if (enableTokenCache) { 27 | jdbcString = String.format("%s;TokenCachePassPhrase=%s;EnableTokenCache=1", 28 | jdbcString, tokenCachePassPhrase); 29 | } else { 30 | jdbcString += ";EnableTokenCache=0"; 31 | } 32 | 33 | String JDBC_DRIVER = "com.databricks.client.jdbc.Driver"; 34 | 35 | Class.forName(JDBC_DRIVER); 36 | System.out.println("Getting connection"); 37 | try (Connection conn = DriverManager.getConnection(jdbcString); 38 | Statement stmt = conn.createStatement()) { 39 | System.out.println("Going to execute query"); 40 | try (ResultSet rs = stmt.executeQuery(query)) { 41 | System.out.println("Query is executed"); 42 | ResultSetMetaData md = rs.getMetaData(); 43 | String[] columns = new String[md.getColumnCount()]; 44 | for (int i = 0; i < columns.length; i++) { 45 | columns[i] = md.getColumnName(i + 1); 46 | } 47 | while (rs.next()) { 48 | System.out.print("Row " + rs.getRow() + "=["); 49 | for (int i = 0; i < columns.length; i++) { 50 | if (i != 0) { 51 | System.out.print(", "); 52 | } 53 | System.out.print(columns[i] + "='" + rs.getObject(i + 1) + "'"); 54 | } 55 | System.out.println(")]"); 56 | } 57 | } 58 | } 59 | System.exit(0); 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /oss-jdbc-aad-token/src/main/java/net/alexott/demos/OssJDBCOAuthAzureServicePrincipal.java: -------------------------------------------------------------------------------- 1 | package net.alexott.demos; 2 | 3 | import java.sql.Connection; 4 | import java.sql.DriverManager; 5 | import java.sql.ResultSet; 6 | 
import java.sql.ResultSetMetaData; 7 | import java.sql.Statement; 8 | 9 | class OssJDBCOAuthAzureServicePrincipal { 10 | 11 | public static void main(String[] args) throws Exception { 12 | // Variables to set 13 | String query = "select 42, current_timestamp(), current_catalog(), current_database(), current_user()"; 14 | String clientId = "..."; 15 | String clientSecret = "..."; 16 | String host = "adb-....azuredatabricks.net"; 17 | String httpPath = "/sql/1.0/warehouses/..."; 18 | String jdbcString = String.format("jdbc:databricks://%s:443/default;transportMode=http;ssl=1;httpPath=%s;AuthMech=11;Auth_Flow=1;OAuth2ClientId=%s;OAuth2Secret=%s", 19 | host, httpPath, clientId, clientSecret); 20 | 21 | String JDBC_DRIVER = "com.databricks.client.jdbc.Driver"; 22 | 23 | Class.forName(JDBC_DRIVER); 24 | System.out.println("Getting JDBC connection"); 25 | try (Connection conn = DriverManager.getConnection(jdbcString); 26 | Statement stmt = conn.createStatement()) { 27 | System.out.println("Going to execute query"); 28 | try (ResultSet rs = stmt.executeQuery(query)) { 29 | System.out.println("Query is executed"); 30 | ResultSetMetaData md = rs.getMetaData(); 31 | String[] columns = new String[md.getColumnCount()]; 32 | for (int i = 0; i < columns.length; i++) { 33 | columns[i] = md.getColumnName(i + 1); 34 | } 35 | while (rs.next()) { 36 | System.out.print("Row " + rs.getRow() + "=["); 37 | for (int i = 0; i < columns.length; i++) { 38 | if (i != 0) { 39 | System.out.print(", "); 40 | } 41 | System.out.print(columns[i] + "='" + rs.getObject(i + 1) + "'"); 42 | } 43 | System.out.println(")]"); 44 | } 45 | } 46 | } 47 | System.exit(0); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /oss-jdbc-aad-token/src/main/java/net/alexott/demos/OssJDBCOAuthDatabricksServicePrincipal.java: -------------------------------------------------------------------------------- 1 | package net.alexott.demos; 2 | 3 | import java.sql.Connection; 4 | import java.sql.DriverManager; 5 | import java.sql.ResultSet; 6 | import java.sql.ResultSetMetaData; 7 | import java.sql.Statement; 8 | 9 | class OssJDBCOAuthDatabricksServicePrincipal { 10 | 11 | public static void main(String[] args) throws Exception { 12 | // Variables to set 13 | String query = "select 42, current_timestamp(), current_catalog(), current_database(), current_user()"; 14 | String clientId = "..."; 15 | String clientSecret = "..."; 16 | String host = "adb-...azuredatabricks.net"; 17 | String httpPath = "/sql/1.0/warehouses/...."; 18 | String jdbcString = String.format("jdbc:databricks://%s:443/default;transportMode=http;ssl=1;httpPath=%s;AuthMech=11;Auth_Flow=1;OAuth2ClientId=%s;OAuth2Secret=%s", 19 | host, httpPath, clientId, clientSecret); 20 | 21 | String JDBC_DRIVER = "com.databricks.client.jdbc.Driver"; 22 | 23 | Class.forName(JDBC_DRIVER); 24 | System.out.println("Getting JDBC connection"); 25 | try (Connection conn = DriverManager.getConnection(jdbcString); 26 | Statement stmt = conn.createStatement()) { 27 | System.out.println("Going to execute query"); 28 | try (ResultSet rs = stmt.executeQuery(query)) { 29 | System.out.println("Query is executed"); 30 | ResultSetMetaData md = rs.getMetaData(); 31 | String[] columns = new String[md.getColumnCount()]; 32 | for (int i = 0; i < columns.length; i++) { 33 | columns[i] = md.getColumnName(i + 1); 34 | } 35 | while (rs.next()) { 36 | System.out.print("Row " + rs.getRow() + "=["); 37 | for (int i = 0; i < columns.length; i++) { 38 | if (i != 0) { 39 | 
System.out.print(", "); 40 | } 41 | System.out.print(columns[i] + "='" + rs.getObject(i + 1) + "'"); 42 | } 43 | System.out.println(")]"); 44 | } 45 | } 46 | } 47 | System.exit(0); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /pause-unpause-jobs/README.md: -------------------------------------------------------------------------------- 1 | # Tool to bulk pausing/unpausing of Databricks Workflows 2 | 3 | This directory contains the `pause_unpause_jobs.py` script that allows to pause all Workflows in the Databricks workspace, and store the list in the config file. And then it's possible to unpause them based on the stored data. 4 | 5 | ## Installation 6 | 7 | You need to have [Databricks SDK for Python](https://pypi.org/project/databricks-sdk/) installed to run this tool. Do 8 | 9 | ```sh 10 | pip install databricks-sdk 11 | ``` 12 | 13 | to install SDK 14 | 15 | 16 | ## Running 17 | 18 | You must configure environment variables to perform [authentication](https://pypi.org/project/databricks-sdk/#authentication) to a Databricks workspace where work will be done. 19 | 20 | Use `-h` command-line option to get help on supported commands: 21 | 22 | ``` 23 | usage: pause_unpause_jobs.py [-h] [--file FILE] {scan,pause,unpause} 24 | 25 | Pause or unpause Databricks jobs with schedules or triggers 26 | 27 | positional arguments: 28 | {scan,pause,unpause} 29 | 30 | options: 31 | -h, --help show this help message and exit 32 | --file FILE File to store paused jobs (default: paused_jobs.json) 33 | ``` 34 | 35 | There are three commands supported 36 | 37 | * `scan` just prints a list jobs that have a trigger or a schedule in the `UNPAUSED` status. 38 | * `pause` finds all workflows with `UNPAUSED` status and pause them. Information about paused tasks is stored in the file that then will be used with `unpause` command. 39 | * `unpause` reads the previously generated file and unpause paused Workflows. 40 | 41 | 42 | Optionally, you can use `--file` command line option to change the file name where we store data about paused tasks (default file name is `paused_jobs.json`). 43 | 44 | 45 | 46 | ## TODOs 47 | 48 | 49 | - \[ \] Better error handling - for example, if a specific Databricks workflow is deleted. 50 | -------------------------------------------------------------------------------- /pause-unpause-jobs/pause_unpause_jobs.py: -------------------------------------------------------------------------------- 1 | from databricks.sdk import WorkspaceClient 2 | from databricks.sdk.service.jobs import PauseStatus, JobSettings, CronSchedule, TriggerSettings 3 | import argparse 4 | import json 5 | import os 6 | 7 | default_file_name = "paused_jobs.json" 8 | 9 | 10 | def find_jobs_with_schedule(wc: WorkspaceClient) -> dict[int, dict]: 11 | jobs_with_schedule_or_trigger = {} 12 | jobs = wc.jobs.list() 13 | for job in jobs: 14 | if job.settings.schedule and job.settings.schedule.pause_status == PauseStatus.UNPAUSED: 15 | # print(f"Job {job.settings.name} has a schedule. 
{job.settings.schedule.pause_status}") 16 | jobs_with_schedule_or_trigger[job.job_id] = { 17 | "name": job.settings.name, 18 | "type": "schedule", 19 | "data": job.settings.schedule.as_dict() 20 | } 21 | elif job.settings.trigger and job.settings.trigger.pause_status == PauseStatus.UNPAUSED: 22 | # print(f"Job {job.settings.name} has a trigger: {job.settings.trigger.pause_status}") 23 | jobs_with_schedule_or_trigger[job.job_id] = { 24 | "name": job.settings.name, 25 | "type": "trigger", 26 | "data": job.settings.trigger.as_dict() 27 | } 28 | 29 | return jobs_with_schedule_or_trigger 30 | 31 | 32 | def switch_pause_status(wc: WorkspaceClient, jid:int, job_config: dict, pause: bool): 33 | new_settings = JobSettings() 34 | pause_status = PauseStatus.PAUSED if pause else PauseStatus.UNPAUSED 35 | if job_config['type'] == 'schedule': 36 | new_settings.schedule = CronSchedule.from_dict(job_config['data']) 37 | new_settings.schedule.pause_status = pause_status 38 | elif job_config['type'] == 'trigger': 39 | new_settings.trigger = TriggerSettings.from_dict(job_config['data']) 40 | new_settings.trigger.pause_status = pause_status 41 | else: 42 | print(f"Unknown job type {job_config['type']} for job {jid}") 43 | return 44 | wc.jobs.update(jid, new_settings=new_settings) 45 | 46 | 47 | if __name__ == '__main__': 48 | # parse arguments for the script that supports three commands: list, pause, unpause 49 | parser = argparse.ArgumentParser(description='Pause or unpause Databricks jobs with schedules or triggers') 50 | supported_commands = ['scan', 'pause', 'unpause'] 51 | parser.add_argument('command', choices=supported_commands) 52 | parser.add_argument('--file', default=default_file_name, 53 | help=f'File to store paused jobs (default: {default_file_name})') 54 | args = parser.parse_args() 55 | if args.command not in supported_commands: 56 | print(f"Unknown command {args.command}") 57 | exit(1) 58 | wc = WorkspaceClient() 59 | if args.command == 'scan': 60 | jobs = find_jobs_with_schedule(wc) 61 | if jobs: 62 | print("Jobs with schedules or triggers:") 63 | for k, v in jobs.items(): 64 | print(f"Job ID: {k}, Name: {v['name']}, Type: {v['type']}") 65 | elif args.command == 'pause': 66 | print("Pausing jobs...") 67 | jobs = find_jobs_with_schedule(wc) 68 | print(f"Found {len(jobs)} jobs with schedules or triggers. Pausing them...") 69 | for job_id, job in jobs.items(): 70 | print(f"Pausing job '{job['name']}' with ID {job_id}") 71 | switch_pause_status(wc, job_id, job, True) 72 | with open(args.file, "w") as f: 73 | f.write(json.dumps(jobs)) 74 | print(f"Jobs are paused. The list of paused jobs is saved in {args.file}") 75 | elif args.command == 'unpause': 76 | print(f"Unpausing jobs from {args.file}") 77 | with open(args.file, "r") as f: 78 | jobs = json.load(f) 79 | for job_id, job in jobs.items(): 80 | print(f"Unpausing job '{job['name']}' with ID {job_id}") 81 | switch_pause_status(wc, int(job_id), job, False) 82 | print(f"Jobs are unpaused. Removing {args.file}...") 83 | os.remove(args.file) 84 | 85 | 86 | -------------------------------------------------------------------------------- /pyspark-snippets/README.md: -------------------------------------------------------------------------------- 1 | # pyspark-snippets 2 | 3 | This directory contains a number of functions that simplify development of PySpark code for Databricks. 
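As a quick illustration, here is a minimal usage sketch of a few of the helpers (the sample dataframes and values are invented for the example; the package is importable as `pyspark_helpers` once the wheel is installed):

```python
from pyspark.sql import SparkSession

from pyspark_helpers.columns_helpers import dataframe_except_columns, add_missing_columns
from pyspark_helpers.json_helpers import schema_of_json

spark = SparkSession.builder.getOrCreate()

df1 = spark.createDataFrame([[1, 2, 3]], schema="id int, value int, internal_flag int")
df2 = spark.createDataFrame([[1, "a"]], schema="id int, name string")

# Drop technical columns before writing the data out
cleaned = dataframe_except_columns(df1, ["internal_flag"])  # columns: id, value

# Align schemas before a union: missing columns are added as typed nulls
aligned = add_missing_columns(cleaned, df2)  # columns: id, value, name

# Derive a Spark DDL schema string from a sample JSON document
ddl = schema_of_json('{"id": 1, "name": "test"}')
```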
4 | 5 | 6 | ## Building 7 | 8 | ```sh 9 | python setup.py clean --all && python setup.py bdist_wheel --universal 10 | ``` 11 | 12 | ## Testing 13 | 14 | You need to install packages that are necessary for execution of tests: 15 | 16 | ``` 17 | pip install -U -r unit-requirements.txt 18 | ``` 19 | 20 | ### Unit testing 21 | 22 | Just execute: 23 | 24 | ``` 25 | pytest tests/unit 26 | ``` 27 | 28 | -------------------------------------------------------------------------------- /pyspark-snippets/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexott/databricks-playground/3bb1b2fb5e21730704b6c905f78324026ee1cfed/pyspark-snippets/conftest.py -------------------------------------------------------------------------------- /pyspark-snippets/pyspark_helpers/columns_helpers.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Optional 2 | 3 | import pyspark.sql.functions as F 4 | from pyspark.sql import DataFrame, Column 5 | 6 | 7 | def columns_except(df: DataFrame, ex: Optional[List[str]] = None, as_column: bool = False) -> List[Union[str, Column]]: 8 | """ 9 | Returns a new list of columns without specified columns 10 | :param df: dataframe 11 | :param ex: columns to exclude 12 | :param as_column: if we should return list of columns instead of list of strings 13 | :return: new list of columns 14 | """ 15 | if ex is None: 16 | ex = [] 17 | 18 | return [F.col(cl) if as_column else cl 19 | for cl in df.columns if cl not in ex] 20 | 21 | 22 | def dataframe_except_columns(df: DataFrame, ex: Optional[List[str]] = None) -> DataFrame: 23 | """ 24 | Creates a new dataframe without specified columns 25 | :param df: dataframe 26 | :param ex: columns to exclude 27 | :return: new dataframe 28 | """ 29 | return df.select(*columns_except(df, ex, as_column=True)) 30 | 31 | 32 | def add_missing_columns(df1: DataFrame, df2: DataFrame) -> DataFrame: 33 | """ 34 | Adds to first dataframe columns from the second dataframe that don't exist in first one. 35 | Columns get null values casted to respective data types 36 | :param df1: first dataframe 37 | :param df2: second dataframe 38 | :return: new dataframe with added columns 39 | """ 40 | additional_cols = [F.lit(None).cast(field.dataType).alias(field.name) 41 | for field in df2.schema.fields if field.name not in df1.columns] 42 | return df1.select("*", *additional_cols) 43 | -------------------------------------------------------------------------------- /pyspark-snippets/pyspark_helpers/data_skipping.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import pyspark.sql.types as T 4 | from pyspark.sql import DataFrame, SparkSession 5 | 6 | 7 | def _is_indexable_column(typ) -> bool: 8 | """Returns true if the column is indexable for Delta Data skipping 9 | :param typ: Spark Type object 10 | :return: true if column is indexable 11 | """ 12 | return ( 13 | typ == T.IntegerType() 14 | or typ == T.TimestampType() 15 | or typ == T.DateType() 16 | or typ == T.DoubleType() 17 | or typ == T.FloatType() 18 | or typ == T.LongType() 19 | ) 20 | 21 | 22 | def reorder_columns(df: DataFrame, first_columns: Optional[List[str]] = None, 23 | partition_columns: Optional[List[str]] = None, 24 | automatic_indexing: bool = True): 25 | """Reorders columns of the dataframe to make them indexable for Delta Data Skipping. 
Besides the 26 | columns specified by ``first_columns`` parameter, all time & numeric columns are moved forward. 27 | On Databricks, it also sets the ``spark.databricks.delta.properties.defaults.dataSkippingNumIndexedCols`` 28 | configuration, so when you're writing data first time, it will set ``delta.dataSkippingNumIndexedCols`` 29 | property on the table. For OSS, this property needs to be set explicitly. 30 | 31 | :param df: dataframe to process 32 | :param first_columns: list of additional columns that needs to be moved first 33 | :param partition_columns: list of columns that will be used for partitioning 34 | :param automatic_indexing: flag defining if we should automatically select columns for indexing ( 35 | :return: modified dataframe 36 | """ 37 | if first_columns is None: 38 | first_columns = [] 39 | if partition_columns is None: 40 | partition_columns = [] 41 | not_first_cols = [ 42 | field for field in df.schema.fields if field.name not in first_columns 43 | ] 44 | if automatic_indexing: 45 | indexable_cols = [ 46 | field.name for field in not_first_cols if _is_indexable_column(field.dataType) 47 | ] 48 | non_indexable_cols = [ 49 | field.name for field in not_first_cols if not _is_indexable_column(field.dataType) 50 | ] 51 | else: 52 | indexable_cols = [] 53 | non_indexable_cols = [field.name for field in not_first_cols] 54 | 55 | # Correct number of columns to index if column(s) is used for partitioning 56 | cols_len = len(first_columns + indexable_cols) - \ 57 | len(set(partition_columns).intersection(set(first_columns + indexable_cols))) 58 | # TODO: think how this will be handled when doing multiple reorders inside the pipeline 59 | SparkSession.getActiveSession().conf.set( 60 | "spark.databricks.delta.properties.defaults.dataSkippingNumIndexedCols", str(cols_len), 61 | ) 62 | return df.select(*first_columns, *indexable_cols, *non_indexable_cols) 63 | -------------------------------------------------------------------------------- /pyspark-snippets/pyspark_helpers/delta_helpers.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Optional 2 | 3 | from delta.tables import DeltaTable 4 | from pyspark.sql import DataFrame, SparkSession 5 | 6 | 7 | def merge_into_or_create( 8 | df: DataFrame, 9 | primary_key_columns: List[str], 10 | path: str = "", 11 | table_name: str = "", 12 | partitionby: Optional[List[str]] = None, 13 | opts: Optional[Dict[str, Any]] = None, 14 | ): 15 | """Performs merging of data into given Delta table, inserting the new data & updating existing data. 
16 | Creates a new table if it doesn't exist 17 | 18 | :param df: new dataframe 19 | :param primary_key_columns: required list of the column names that are used as primary key of table 20 | :param path: optional path to table (required if table_name isn't specified) 21 | :param table_name: optional name of the table (required if path isn't specified) 22 | :param partitionby: optional list of columns to partition by 23 | :param opts: optional dictionary with options for creation of Delta table 24 | :return: nothing 25 | """ 26 | # print(f"Performing merge for {path=} or {table_name=}") 27 | if opts is None: 28 | opts = {} 29 | if partitionby is None: 30 | partitionby = [] 31 | df = df.dropDuplicates(primary_key_columns) 32 | if path == "" and table_name == "": 33 | raise Exception( 34 | "At least one parameter, 'path' or 'table_name' must be specified" 35 | ) 36 | if not df._jdf.isEmpty(): 37 | try: 38 | spark = SparkSession.getActiveSession() 39 | if table_name != "": 40 | tbl = DeltaTable.forName(spark, table_name) 41 | else: 42 | tbl = DeltaTable.forPath(spark, path) 43 | dname = "dests" 44 | uname = "updates" 45 | merge_cond = " and ".join( 46 | [f"{dname}.{col} <=> {uname}.{col}" for col in primary_key_columns] 47 | ) 48 | tbl.alias(dname).merge( 49 | df.alias(uname), merge_cond 50 | ).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute() 51 | # except AnalysisException as ex: # this happens when table doesn't exist 52 | except: # this happens when table doesn't exist 53 | # print(f"Delta table ({path=}, {table_name=}) doesn't exist, writing all data as new table...") 54 | if table_name != "": 55 | if path != "": 56 | opts["path"] = path 57 | df.write.format("delta").partitionBy(partitionby).options( 58 | **opts 59 | ).saveAsTable(table_name) 60 | else: 61 | df.write.format("delta").partitionBy(partitionby).options(**opts).save( 62 | path 63 | ) 64 | -------------------------------------------------------------------------------- /pyspark-snippets/pyspark_helpers/drop_duplicates.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Optional 2 | 3 | from delta.tables import DeltaTable 4 | from pyspark.sql import DataFrame, SparkSession 5 | 6 | 7 | def drop_duplicates_with_merge( 8 | df: DataFrame, 9 | primary_key_columns: List[str], 10 | path: str = "", 11 | table_name: str = "", 12 | partitionby: Optional[List[str]] = None, 13 | opts: Optional[Dict[str, Any]] = None, 14 | additional_merge_cond: Optional[str] = None, 15 | ): 16 | """Performs removal of duplicates using the Delta MERGE operation. If table doesn't exist, 17 | it's created by writing the dataframe into a specified location. This function is primarily 18 | designed for use in ``foreachBatch``. 19 | 20 | :param df: new dataframe 21 | :param primary_key_columns: required list of the column names that are used for detection of duplicates 22 | :param path: optional path to table (required if table_name isn't specified) 23 | :param table_name: optional name of the table (required if path isn't specified) 24 | :param partitionby: optional list of columns to partition by 25 | :param opts: optional dictionary with options for creation of Delta table 26 | :param additional_merge_cond: additional merge condition appended to the generated condition 27 | using ``AND``. The destination columns are prefixed as ``dest`` & new data as ``update``. 
28 | :return: nothing 29 | """ 30 | # print(f"Performing merge for {path=} or {table_name=}") 31 | if opts is None: 32 | opts = {} 33 | if partitionby is None: 34 | partitionby = [] 35 | df = df.dropDuplicates(primary_key_columns) 36 | if path == "" and table_name == "": 37 | raise Exception( 38 | "At least one parameter, 'path' or 'table_name' must be specified" 39 | ) 40 | if not df._jdf.isEmpty(): 41 | try: 42 | spark = SparkSession.getActiveSession() 43 | if table_name != "": 44 | tbl = DeltaTable.forName(spark, table_name) 45 | else: 46 | tbl = DeltaTable.forPath(spark, path) 47 | dname = "dest" 48 | uname = "update" 49 | merge_cond = " and ".join( 50 | [f"{dname}.{col} <=> {uname}.{col}" for col in primary_key_columns] 51 | ) 52 | if additional_merge_cond: 53 | merge_cond = merge_cond + " AND " + additional_merge_cond 54 | tbl.alias(dname).merge( 55 | df.alias(uname), merge_cond 56 | ).whenNotMatchedInsertAll().execute() 57 | # except AnalysisException as ex: # this happens when table doesn't exist 58 | except: # this happens when table doesn't exist 59 | # print(f"Delta table ({path=}, {table_name=}) doesn't exist, writing all data as new table...") 60 | if table_name != "": 61 | if path != "": 62 | opts["path"] = path 63 | df.write.format("delta").partitionBy(partitionby).options( 64 | **opts 65 | ).saveAsTable(table_name) 66 | else: 67 | df.write.format("delta").partitionBy(partitionby).options(**opts).save( 68 | path 69 | ) 70 | 71 | 72 | def drop_duplicates_builtin( 73 | df: DataFrame, 74 | primary_key_columns: List[str], 75 | watermark_column: Optional[str] = None, 76 | watermark_time: Optional[str] = None, 77 | ): 78 | """Performs deletion of duplicates on the given dataframe using the `.dropDuplicates` function. 79 | :param df: dataframe to process 80 | :param primary_key_columns: required list of the column names that are used for detection of duplicates 81 | :param watermark_column: optional column name that will be used for watermark in the streaming mode. 
82 | :param watermark_time: 83 | :return: modified dataframe 84 | """ 85 | if "watermark_column" and "watermark_time": 86 | if not df.isStreaming: 87 | raise Exception("Can't set watermark on the non-streaming dataframe") 88 | df = df.withWatermark(watermark_column, watermark_time) 89 | return df.dropDuplicates(primary_key_columns) 90 | -------------------------------------------------------------------------------- /pyspark-snippets/pyspark_helpers/json_helpers.py: -------------------------------------------------------------------------------- 1 | import pyspark.sql.functions as F 2 | from pyspark.sql import SparkSession 3 | 4 | 5 | def schema_of_json(line: str) -> str: 6 | """ 7 | Returns a Spark DDL string 8 | :param line: string with JSON data 9 | :return: Spark DDL string 10 | """ 11 | df = SparkSession.getActiveSession().range(1).select(F.schema_of_json(F.lit(line))) 12 | return df.collect()[0][0] 13 | -------------------------------------------------------------------------------- /pyspark-snippets/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -s -p no:warnings 3 | log_cli = 1 4 | log_cli_level = INFO 5 | log_cli_format = [pytest][%(asctime)s][%(levelname)s][%(module)s][%(funcName)s] %(message)s 6 | log_cli_date_format = %Y-%m-%d %H:%M:%S 7 | log_level = INFO 8 | spark_options = 9 | spark.sql.catalogImplementation: in-memory -------------------------------------------------------------------------------- /pyspark-snippets/requirements.txt: -------------------------------------------------------------------------------- 1 | pyspark==3.2.1 2 | delta-spark==1.2.0 3 | -------------------------------------------------------------------------------- /pyspark-snippets/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | __version__ = "0.0.1" 4 | 5 | INSTALL_REQUIRES=[] 6 | with open("requirements.txt", "r") as file: 7 | for l in file: 8 | l = l.strip() 9 | if l != "" and l[0] != "#": 10 | INSTALL_REQUIRES.append(l) 11 | 12 | setup( 13 | name="databricks-pyspark-helpers", 14 | packages=find_packages(exclude=["tests", "tests.*"]), 15 | setup_requires=["wheel"], 16 | install_requires=INSTALL_REQUIRES, 17 | version=__version__, 18 | description="Useful PySpark/Databricks functions", 19 | author="Alex Ott " 20 | ) 21 | -------------------------------------------------------------------------------- /pyspark-snippets/tests/unit/test_columns_helpers.py: -------------------------------------------------------------------------------- 1 | from chispa.dataframe_comparer import * 2 | from pyspark.sql import SparkSession 3 | 4 | from pyspark_helpers.columns_helpers import * 5 | 6 | 7 | def test_columns_except(spark_session: SparkSession): 8 | original_df = spark_session.createDataFrame( 9 | [[1, 2, 3, 4]], schema="col1 int, col2 int, col3 int, col4 int") 10 | new_cols = columns_except(original_df, ["col2", "col4"]) 11 | assert new_cols == ["col1", "col3"] 12 | 13 | 14 | def test_dataframe_except_columns(spark_session: SparkSession): 15 | original_df = spark_session.createDataFrame( 16 | [[1, 2, 3, 4]], schema="col1 int, col2 int, col3 int, col4 int") 17 | new_df = dataframe_except_columns(original_df, ["col2", "col4"]) 18 | expected_df = spark_session.createDataFrame([[1, 3]], schema="col1 int, col3 int") 19 | assert_df_equality(new_df, expected_df, ignore_nullable=True) 20 | 21 | 22 | def test_add_missing_columns(spark_session: 
SparkSession): 23 | df1 = spark_session.createDataFrame([[1, 2]], schema="col1 int, col2 int") 24 | df2 = spark_session.createDataFrame([[1, "2", 3.0]], schema="col1 int, col4 string, col5 double") 25 | new_df = add_missing_columns(df1, df2) 26 | expected_df = spark_session.createDataFrame([[1, 2, None, None]], 27 | schema="col1 int, col2 int, col4 string, col5 double") 28 | assert_df_equality(new_df, expected_df, ignore_nullable=True) 29 | -------------------------------------------------------------------------------- /pyspark-snippets/tests/unit/test_data_skipping.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql import SparkSession 2 | 3 | from pyspark_helpers.data_skipping import reorder_columns 4 | 5 | 6 | def test_data_skipping1(spark_session: SparkSession): 7 | original_df = spark_session.createDataFrame( 8 | [], schema="col1 string, col2 int, col3 string, col4 date") 9 | processed_df = reorder_columns(original_df) 10 | assert processed_df.columns == ["col2", "col4", "col1", "col3"] 11 | assert spark_session.conf.get( 12 | "spark.databricks.delta.properties.defaults.dataSkippingNumIndexedCols") == "2" 13 | 14 | 15 | def test_data_skipping2(spark_session: SparkSession): 16 | original_df = spark_session.createDataFrame( 17 | [], schema="col1 string, col2 int, col3 string, col4 date") 18 | processed_df = reorder_columns(original_df, first_columns=["col3"]) 19 | assert processed_df.columns == ["col3", "col2", "col4", "col1"] 20 | assert spark_session.conf.get( 21 | "spark.databricks.delta.properties.defaults.dataSkippingNumIndexedCols") == "3" 22 | 23 | 24 | def test_data_skipping3(spark_session: SparkSession): 25 | original_df = spark_session.createDataFrame( 26 | [], schema="col1 string, col2 int, col3 string, col4 date") 27 | processed_df = reorder_columns(original_df, partition_columns=["col4"]) 28 | assert processed_df.columns == ["col2", "col4", "col1", "col3"] 29 | assert spark_session.conf.get( 30 | "spark.databricks.delta.properties.defaults.dataSkippingNumIndexedCols") == "1" 31 | 32 | 33 | def test_data_skipping_no_automatic(spark_session: SparkSession): 34 | original_df = spark_session.createDataFrame( 35 | [], schema="col1 string, col2 int, col3 string, col4 date") 36 | processed_df = reorder_columns(original_df, first_columns=["col3", "col1"], 37 | automatic_indexing=False) 38 | assert processed_df.columns == ["col3", "col1", "col2", "col4"] 39 | assert spark_session.conf.get( 40 | "spark.databricks.delta.properties.defaults.dataSkippingNumIndexedCols") == "2" 41 | 42 | -------------------------------------------------------------------------------- /pyspark-snippets/unit-requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools 2 | wheel 3 | pytest 4 | pytest-cov 5 | pytest-spark 6 | chispa 7 | -------------------------------------------------------------------------------- /simba-jdbc-aad-token/README.md: -------------------------------------------------------------------------------- 1 | # Querying Databricks via JDBC with AAD token 2 | 3 | This directory contains the example of querying data via Databricks SQL Endpoint or Databricks Cluster using JDBC protocol. Since version 2.6.36 the JDBC driver supports OAuth authentication. Following examples demonstrate that: 4 | 5 | * `SimbaJDBCAadTokenUser` - authentication to Databricks is performed using Azure Active Directory tokens issued for the current user. 
6 | * `SimbaJDBCAadTokenServicePrincipal` - authentication to Databricks is performed using Azure Active Directory tokens issued for Azure Service Principal. 7 | 8 | ## Adjusting parameters 9 | 10 | Right now many parameters are hard-coded, so you need to change them before running. 11 | 12 | ### For `SimbaJDBCAadTokenServicePrincipal` 13 | 14 | You need to update source code and adjust following parameters: 15 | 16 | * `query` - what query should be executed 17 | * `clientId` - client ID of application in Azure Active Directory 18 | * `clientSecret` - secret for AAD application (ideally should be taken from KeyVault); 19 | * `host` - the host portion of the Databricks workspace (obtained from SQL Warehouse configuraiton) 20 | * `httpPath` - the HTTP Path of the SQL Warehouse (obtained from SQL Warehouse configuraiton) 21 | 22 | ### For `SimbaJDBCAadTokenUser` 23 | 24 | * `query` - what query should be executed 25 | * `host` - the host portion of the Databricks workspace (obtained from SQL Warehouse configuraiton) 26 | * `httpPath` - the HTTP Path of the SQL Warehouse (obtained from SQL Warehouse configuraiton) 27 | * `oauthClientId` - (optional) if you don't have Azure application with name `databricks-sql-jdbc`, then set it to the Application ID of the Azure application that will be used for authentication. 28 | 29 | ## Build & run 30 | 31 | Just execute `mvn package` to build the code, and then you can execute resulting uber jar: 32 | 33 | ```sh 34 | java -cp target/simba-jdbc-aad-token-0.0.2-jar-with-dependencies.jar \ 35 | net.alexott.demos.SimbaJDBCAadTokenServicePrincipal 36 | ``` 37 | 38 | Or the code could be executed from an IDE. 39 | -------------------------------------------------------------------------------- /simba-jdbc-aad-token/pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | net.alexott.demos 6 | simba-jdbc-aad-token 7 | 0.0.2 8 | jar 9 | 10 | 11 | UTF-8 12 | 1.8 13 | 14 | 15 | 16 | 17 | 18 | com.databricks 19 | databricks-jdbc 20 | 2.7.1 21 | 22 | 23 | 24 | 25 | 26 | 27 | maven-compiler-plugin 28 | 3.8.1 29 | 30 | ${java.version} 31 | ${java.version} 32 | true 33 | 34 | 35 | 36 | org.apache.maven.plugins 37 | maven-assembly-plugin 38 | 3.2.0 39 | 40 | 41 | jar-with-dependencies 42 | 43 | 44 | 45 | 46 | package 47 | 48 | single 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /simba-jdbc-aad-token/src/main/java/net/alexott/demos/SimbaJDBCAadTokenMI.java: -------------------------------------------------------------------------------- 1 | package net.alexott.demos; 2 | 3 | import java.sql.Connection; 4 | import java.sql.DriverManager; 5 | import java.sql.ResultSet; 6 | import java.sql.ResultSetMetaData; 7 | import java.sql.Statement; 8 | 9 | class SimbaJDBCAadTokenMI { 10 | 11 | public static void main(String[] args) throws Exception { 12 | // Variables to set 13 | String query = "select 42, current_timestamp(), current_catalog(), current_database(), current_user()"; 14 | String host = "adb-.....17.azuredatabricks.net"; 15 | String httpPath = "/sql/1.0/warehouses/...."; 16 | String oauthClientId = System.getenv("OAUTH_CLIENT_ID"); 17 | 18 | String jdbcString = String.format("jdbc:databricks://%s:443;httpPath=%s;AuthMech=11;Auth_Flow=3", 19 | host, httpPath); 20 | if (oauthClientId != null && !oauthClientId.isEmpty()) { 21 | jdbcString = String.format("%s;OAuth2ClientId=%s", jdbcString, oauthClientId); 22 | } 23 | 
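// Note (assumption based on the class name and the driver's OAuth documentation): with AuthMech=11,
// Auth_Flow=3 selects authentication via Azure Managed Identity; the optional OAUTH_CLIENT_ID
// environment variable presumably carries the client ID of a user-assigned managed identity.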
System.out.println("jdbcString=" + jdbcString); 24 | 25 | String JDBC_DRIVER = "com.databricks.client.jdbc.Driver"; 26 | 27 | Class.forName(JDBC_DRIVER); 28 | System.out.println("Getting connection"); 29 | try (Connection conn = DriverManager.getConnection(jdbcString); 30 | Statement stmt = conn.createStatement()) { 31 | System.out.println("Going to execute query"); 32 | try (ResultSet rs = stmt.executeQuery(query)) { 33 | System.out.println("Query is executed"); 34 | ResultSetMetaData md = rs.getMetaData(); 35 | String[] columns = new String[md.getColumnCount()]; 36 | for (int i = 0; i < columns.length; i++) { 37 | columns[i] = md.getColumnName(i + 1); 38 | } 39 | while (rs.next()) { 40 | System.out.print("Row " + rs.getRow() + "=["); 41 | for (int i = 0; i < columns.length; i++) { 42 | if (i != 0) { 43 | System.out.print(", "); 44 | } 45 | System.out.print(columns[i] + "='" + rs.getObject(i + 1) + "'"); 46 | } 47 | System.out.println(")]"); 48 | } 49 | } 50 | } 51 | System.exit(0); 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /simba-jdbc-aad-token/src/main/java/net/alexott/demos/SimbaJDBCAadTokenServicePrincipal.java: -------------------------------------------------------------------------------- 1 | package net.alexott.demos; 2 | 3 | import java.sql.Connection; 4 | import java.sql.DriverManager; 5 | import java.sql.ResultSet; 6 | import java.sql.ResultSetMetaData; 7 | import java.sql.Statement; 8 | 9 | class SimbaJDBCAadTokenServicePrincipal { 10 | 11 | public static void main(String[] args) throws Exception { 12 | // Variables to set 13 | String query = "select 42, current_timestamp(), current_catalog(), current_database(), current_user()"; 14 | String clientId = "..."; 15 | String clientSecret = "..."; 16 | String host = "adb-.....17.azuredatabricks.net"; 17 | String httpPath = "/sql/1.0/warehouses/...."; 18 | String jdbcString = String.format("jdbc:databricks://%s:443/default;transportMode=http;ssl=1;httpPath=%s;AuthMech=11;Auth_Flow=1;OAuth2ClientId=%s;OAuth2Secret=%s", 19 | host, httpPath, clientId, clientSecret); 20 | 21 | String JDBC_DRIVER = "com.databricks.client.jdbc.Driver"; 22 | 23 | Class.forName(JDBC_DRIVER); 24 | System.out.println("Getting JDBC connection"); 25 | try (Connection conn = DriverManager.getConnection(jdbcString); 26 | Statement stmt = conn.createStatement()) { 27 | System.out.println("Going to execute query"); 28 | try (ResultSet rs = stmt.executeQuery(query)) { 29 | System.out.println("Query is executed"); 30 | ResultSetMetaData md = rs.getMetaData(); 31 | String[] columns = new String[md.getColumnCount()]; 32 | for (int i = 0; i < columns.length; i++) { 33 | columns[i] = md.getColumnName(i + 1); 34 | } 35 | while (rs.next()) { 36 | System.out.print("Row " + rs.getRow() + "=["); 37 | for (int i = 0; i < columns.length; i++) { 38 | if (i != 0) { 39 | System.out.print(", "); 40 | } 41 | System.out.print(columns[i] + "='" + rs.getObject(i + 1) + "'"); 42 | } 43 | System.out.println(")]"); 44 | } 45 | } 46 | } 47 | System.exit(0); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /simba-jdbc-aad-token/src/main/java/net/alexott/demos/SimbaJDBCAadTokenUser.java: -------------------------------------------------------------------------------- 1 | package net.alexott.demos; 2 | 3 | import java.sql.Connection; 4 | import java.sql.DriverManager; 5 | import java.sql.ResultSet; 6 | import java.sql.ResultSetMetaData; 7 | import java.sql.Statement; 8 | 9 | class 
SimbaJDBCAadTokenUser { 10 | 11 | public static void main(String[] args) throws Exception { 12 | // Variables to set 13 | String query = "select 42, current_timestamp(), current_catalog(), current_database(), current_user()"; 14 | String host = "adb-.....17.azuredatabricks.net"; 15 | String httpPath = "/sql/1.0/warehouses/...."; 16 | String oauthClientId = ""; 17 | 18 | boolean enableTokenCache = System.getProperty("enableTokenCache", "false") 19 | .equalsIgnoreCase("true"); 20 | String tokenCachePassPhrase = "1234567"; 21 | String jdbcString = String.format("jdbc:databricks://%s:443;httpPath=%s;AuthMech=11;Auth_Flow=2", 22 | host, httpPath); 23 | if (!oauthClientId.isEmpty()) { 24 | jdbcString = String.format("%s;OAuth2ClientId=%s", jdbcString, oauthClientId); 25 | } 26 | if (enableTokenCache) { 27 | jdbcString = String.format("%s;TokenCachePassPhrase=%s;EnableTokenCache=1", 28 | jdbcString, tokenCachePassPhrase); 29 | } else { 30 | jdbcString += ";EnableTokenCache=0"; 31 | } 32 | 33 | String JDBC_DRIVER = "com.databricks.client.jdbc.Driver"; 34 | 35 | Class.forName(JDBC_DRIVER); 36 | System.out.println("Getting connection"); 37 | try (Connection conn = DriverManager.getConnection(jdbcString); 38 | Statement stmt = conn.createStatement()) { 39 | System.out.println("Going to execute query"); 40 | try (ResultSet rs = stmt.executeQuery(query)) { 41 | System.out.println("Query is executed"); 42 | ResultSetMetaData md = rs.getMetaData(); 43 | String[] columns = new String[md.getColumnCount()]; 44 | for (int i = 0; i < columns.length; i++) { 45 | columns[i] = md.getColumnName(i + 1); 46 | } 47 | while (rs.next()) { 48 | System.out.print("Row " + rs.getRow() + "=["); 49 | for (int i = 0; i < columns.length; i++) { 50 | if (i != 0) { 51 | System.out.print(", "); 52 | } 53 | System.out.print(columns[i] + "='" + rs.getObject(i + 1) + "'"); 54 | } 55 | System.out.println(")]"); 56 | } 57 | } 58 | } 59 | System.exit(0); 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /simba-odbc-aad-token/README.md: -------------------------------------------------------------------------------- 1 | # Querying Databricks SQL or cluster via Databricks ODBC driver using AAD token 2 | 3 | This directory contain demonstration of how to authenticate to Databricks using Azure Active Directory token when using [Databricks ODBC driver](https://www.databricks.com/spark/odbc-drivers-download) via `pyodbc` library. There are two implementations: 4 | 1. `odbc-aad-service-principal.py` using service principal and execute query against Databricks SQL Endpoint or Databricks Cluster. 5 | 2. `odbc-aad-user.py` using user AAD interactive authentication. 6 | 7 | ## Installation 8 | 9 | Install all necessary dependencies: 10 | 11 | * `pyodbc` 12 | 13 | by executing: 14 | 15 | ```sh 16 | pip install -U -r requirements.txt 17 | ``` 18 | 19 | You also need to install ODBC driver as outlined in the [documentation](https://learn.microsoft.com/en-us/azure/databricks/integrations/jdbc-odbc-bi). 20 | 21 | ## Modify the script(s) 22 | 23 | You need to modify scripts and change following variables: 24 | 25 | * `host` - set to host name of the Databricks workspace (without `https://`) 26 | * `http_path` - obtain HTTP Path parameter of Databricks SQL Endpoint or Databricks Cluster as per [documentation](https://docs.databricks.com/dev-tools/python-sql-connector.html#get-started). 
27 | * `driver` - location of ODBC driver library (i.e., `/Library/simba/spark/lib/libsparkodbc_sb64-universal.dylib` on Mac OS). 28 | * `query` - what query should be executed 29 | 30 | ### Set authentication parameters for odbc-aad-service-principal.py 31 | 32 | Authentication parameters of service principal could be set in the code directly (not the best way), or obtained from following environment variables: 33 | 34 | * `ARM_CLIENT_ID` - client ID of application in Azure Active Directory 35 | * `ARM_CLIENT_SECRET` - secret for AAD application 36 | 37 | ## Execute script(s) 38 | 39 | Just run: 40 | 41 | ```sh 42 | python odbc-aad-service-principal.py 43 | ``` 44 | 45 | or 46 | 47 | ```sh 48 | python odbc-aad-user.py 49 | ``` 50 | 51 | and it will print result of query execution. 52 | -------------------------------------------------------------------------------- /simba-odbc-aad-token/odbc-aad-service-principal.py: -------------------------------------------------------------------------------- 1 | import pyodbc 2 | import os 3 | 4 | 5 | def get_env_or_raise(name: str) -> str: 6 | value = os.environ[name] 7 | if not value: 8 | raise Exception(f"Environment variable {name} isn't set or empty") 9 | 10 | return value 11 | 12 | 13 | # Update parameters 14 | host="adb-....17.azuredatabricks.net" 15 | http_path = "/sql/1.0/warehouses/..." 16 | driver = "/Library/simba/spark/lib/libsparkodbc_sb64-universal.dylib" 17 | client_id = get_env_or_raise("ARM_CLIENT_ID") 18 | client_secret = get_env_or_raise("ARM_CLIENT_SECRET") 19 | 20 | # 21 | url = f"Driver={driver};Host={host};Port=443;ThriftTransport=2;SSL=1;AuthMech=11;Auth_Flow=1;HTTPPath={http_path};IgnoreTransactions=true;Auth_Client_ID={client_id};Auth_Client_Secret={client_secret}" 22 | db_conn = pyodbc.connect(url) 23 | 24 | query = "select 42, current_timestamp(), current_catalog(), current_database(), current_user()" 25 | cursor = db_conn.cursor().execute(query) 26 | columns = [column[0] for column in cursor.description] 27 | results = [dict(zip(columns, row)) for row in cursor] 28 | print(results) 29 | -------------------------------------------------------------------------------- /simba-odbc-aad-token/odbc-aad-user.py: -------------------------------------------------------------------------------- 1 | import pyodbc 2 | 3 | # Update parameters 4 | host="adb-....17.azuredatabricks.net" 5 | http_path = "/sql/1.0/warehouses/...." 
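# Note: Auth_Flow=2 in the connection string below triggers interactive (browser-based) AAD login;
# the PWD value is not a real password but the passphrase the driver uses to encrypt its local token
# cache (an assumption based on the Databricks ODBC driver's OAuth documentation).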
6 | driver = "/Library/simba/spark/lib/libsparkodbc_sb64-universal.dylib" 7 | query = "select 42, current_timestamp(), current_catalog(), current_database(), current_user()" 8 | 9 | # 10 | url = f"Driver={driver};Host={host};Port=443;ThriftTransport=2;SSL=1;AuthMech=11;Auth_Flow=2;HTTPPath={http_path};IgnoreTransactions=true;PWD=1234567" 11 | db_conn = pyodbc.connect(url) 12 | 13 | cursor = db_conn.cursor().execute(query) 14 | columns = [column[0] for column in cursor.description] 15 | results = [dict(zip(columns, row)) for row in cursor] 16 | print(results) 17 | -------------------------------------------------------------------------------- /simba-odbc-aad-token/requirements.txt: -------------------------------------------------------------------------------- 1 | pyodbc 2 | -------------------------------------------------------------------------------- /spring-jdbc-dbsql/README.md: -------------------------------------------------------------------------------- 1 | # Querying Databricks via JDBC using Spring JDBC 2 | 3 | This directory contains the example of querying data via Databricks SQL Endpoint or Databricks Cluster using JDBC protocol. Authentication to Databricks is performed using Databricks personal access token. 4 | 5 | 6 | ## Adjusting parameters 7 | 8 | You need to update source code and adjust following parameters: 9 | 10 | * `query` - what query should be executed 11 | * `host` - hostname of Databricks workspace obtained as per [documentation](https://docs.databricks.com/integrations/bi/jdbc-odbc-bi.html#jdbc-configuration-and-connection-parameters), 12 | * `httpPath` - HTTP Path of Databricks cluster or SQL Endpoint 13 | * `token` - personal access token 14 | 15 | ## Build & run 16 | 17 | Just execute `mvn package` to build the code, and then you can execute resulting uber jar: 18 | 19 | ```sh 20 | java -cp target/pring-jdbc-dbsql-0.0.1-jar-with-dependencies.jar \ 21 | net.alexott.demos.spring_jdbc.SimpleQuery 22 | ``` 23 | 24 | Or the code could be executed from an IDE. 
25 | 
-------------------------------------------------------------------------------- /spring-jdbc-dbsql/pom.xml: --------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
3 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
4 |     <modelVersion>4.0.0</modelVersion>
5 |     <groupId>net.alexott.demos</groupId>
6 |     <artifactId>spring-jdbc-dbsql</artifactId>
7 |     <version>0.0.1</version>
8 |     <packaging>jar</packaging>
9 |     <properties>
10 |         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
11 |         <java.version>1.8</java.version>
12 |         <spring.version>5.3.16</spring.version>
13 |     </properties>
14 | 
15 |     <dependencies>
16 |         <dependency>
17 |             <groupId>com.databricks</groupId>
18 |             <artifactId>databricks-jdbc</artifactId>
19 |             <version>2.6.25-1</version>
20 |         </dependency>
21 |         <dependency>
22 |             <groupId>org.springframework</groupId>
23 |             <artifactId>spring-jdbc</artifactId>
24 |             <version>4.3.19.RELEASE</version>
25 |         </dependency>
26 |     </dependencies>
27 | 
28 |     <build>
29 |         <plugins>
30 |             <plugin>
31 |                 <artifactId>maven-compiler-plugin</artifactId>
32 |                 <version>3.8.1</version>
33 |                 <configuration>
34 |                     <source>${java.version}</source>
35 |                     <target>${java.version}</target>
36 |                     <!-- additional compiler option with value "true" -->
37 |                 </configuration>
38 |             </plugin>
39 |             <plugin>
40 |                 <groupId>org.apache.maven.plugins</groupId>
41 |                 <artifactId>maven-assembly-plugin</artifactId>
42 |                 <version>3.2.0</version>
43 |                 <configuration>
44 |                     <descriptorRefs>
45 |                         <descriptorRef>jar-with-dependencies</descriptorRef>
46 |                     </descriptorRefs>
47 |                 </configuration>
48 |                 <executions>
49 |                     <execution>
50 |                         <phase>package</phase>
51 |                         <goals>
52 |                             <goal>single</goal>
53 |                         </goals>
54 |                     </execution>
55 |                 </executions>
56 |             </plugin>
57 |         </plugins>
58 |     </build>
59 | </project>
60 | 
-------------------------------------------------------------------------------- /spring-jdbc-dbsql/src/main/java/net/alexott/demos/spring_jdbc/SimpleQuery.java: --------------------------------------------------------------------------------
1 | package net.alexott.demos.spring_jdbc;
2 | 
3 | import org.springframework.jdbc.core.JdbcTemplate;
4 | import org.springframework.jdbc.datasource.SimpleDriverDataSource;
5 | import com.databricks.client.jdbc.Driver;
6 | 
7 | import java.util.List;
8 | import java.util.Map;
9 | 
10 | public class SimpleQuery {
11 |     public static void main(String[] args) {
12 |         // Variables to update
13 |         String host = "";
14 |         String httpPath = "";
15 |         String token = "";
16 |         String query = "select * from default.table";
17 | 
18 |         String jdbcUrl = "jdbc:databricks://" + host + ":443/default;transportMode=http;ssl=1;httpPath=" +
19 |             httpPath + ";AuthMech=3;UID=token;PWD=" + token;
20 | 
21 |         // define data source
22 |         SimpleDriverDataSource ds = new SimpleDriverDataSource();
23 |         ds.setDriver(new Driver());
24 |         ds.setUrl(jdbcUrl);
25 |         JdbcTemplate jdbcTemplate = new JdbcTemplate(ds);
26 | 
27 |         // query data
28 |         List<Map<String, Object>> data = jdbcTemplate.queryForList(query);
29 |         int cnt = 0;
30 |         for (Map<String, Object> row : data) {
31 |             System.out.format("Row(%5d)[", cnt + 1);
32 |             int i = 0;
33 |             for (Map.Entry<String, Object> entry : row.entrySet()) {
34 |                 if (i > 0) {
35 |                     System.out.print(", ");
36 |                 }
37 |                 System.out.print(entry.getKey() + "='" + entry.getValue() + "'");
38 |                 i++;
39 |             }
40 |             System.out.println("]");
41 |             cnt++;
42 |         }
43 |         System.out.format("There are %d rows in the table\n", cnt);
44 |     }
45 | }
46 | 
-------------------------------------------------------------------------------- /stackoverflow/snippets.py: --------------------------------------------------------------------------------
1 | # Requires com.databricks:spark-xml_2.12:0.14.0
2 | 
3 | # TODO: normalize column names - remove `_`, switch to snake case
4 | 
5 | posts_schema = "`_AcceptedAnswerId` BIGINT, `_AnswerCount` int,`_Body` STRING,`_ClosedDate` TIMESTAMP,`_CommentCount` int,`_CommunityOwnedDate` TIMESTAMP,`_ContentLicense` STRING,`_CreationDate` TIMESTAMP,`_FavoriteCount` int,`_Id` BIGINT,`_LastActivityDate` TIMESTAMP,`_LastEditDate` TIMESTAMP,`_LastEditorDisplayName` STRING,`_LastEditorUserId` BIGINT,`_OwnerDisplayName` STRING,`_OwnerUserId` BIGINT,`_ParentId` BIGINT,`_PostTypeId` BIGINT,`_Score` int,`_Tags` STRING,`_Title` STRING,`_ViewCount` BIGINT"
6 | df = spark.read.format("xml").option("rowTag", "row").option("rootTag", "posts").schema(posts_schema).load("/Users/ott/tmp/stackoverflow-vi/Posts.xml")
7 | 
8 | df = spark.read.format("xml").option("rowTag", "row").option("rootTag", "badges").load("/Users/ott/tmp/stackoverflow-vi/Badges.xml")
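# NOTE: the *_schema strings defined below can be applied to the corresponding reads via .schema(...), in the same way posts_schema is used above.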
9 | badges_schema = "`_Class` BIGINT,`_Date` TIMESTAMP,`_Id` BIGINT,`_Name` STRING,`_TagBased` BOOLEAN,`_UserId` BIGINT"
10 | 
11 | df = spark.read.format("xml").option("rowTag", "row").option("rootTag", "comments").load("/Users/ott/tmp/stackoverflow-vi/Comments.xml")
12 | comments_schema = "`_ContentLicense` STRING,`_CreationDate` TIMESTAMP,`_Id` BIGINT,`_PostId` BIGINT,`_Score` BIGINT,`_Text` STRING,`_UserDisplayName` STRING,`_UserId` BIGINT"
13 | 
14 | df = spark.read.format("xml").option("rowTag", "row").option("rootTag", "posthistory").load("/Users/ott/tmp/stackoverflow-vi/PostHistory.xml")
15 | post_history_schema = "`_Comment` STRING,`_ContentLicense` STRING,`_CreationDate` TIMESTAMP,`_Id` BIGINT,`_PostHistoryTypeId` BIGINT,`_PostId` BIGINT,`_RevisionGUID` STRING,`_Text` STRING,`_UserDisplayName` STRING,`_UserId` BIGINT"
16 | 
17 | df = spark.read.format("xml").option("rowTag", "row").option("rootTag", "postlinks").load("/Users/ott/tmp/stackoverflow-vi/PostLinks.xml")
18 | post_links_schema = "`_CreationDate` TIMESTAMP,`_Id` BIGINT,`_LinkTypeId` BIGINT,`_PostId` BIGINT,`_RelatedPostId` BIGINT"
19 | 
20 | df = spark.read.format("xml").option("rowTag", "row").option("rootTag", "tags").load("/Users/ott/tmp/stackoverflow-vi/Tags.xml")
21 | tags_schema = "`_Count` BIGINT,`_ExcerptPostId` BIGINT,`_Id` BIGINT,`_TagName` STRING,`_WikiPostId` BIGINT"
22 | 
23 | df = spark.read.format("xml").option("rowTag", "row").option("rootTag", "users").load("/Users/ott/tmp/stackoverflow-vi/Users.xml")
24 | users_schema = "`_AboutMe` STRING,`_AccountId` BIGINT,`_CreationDate` TIMESTAMP,`_DisplayName` STRING,`_DownVotes` BIGINT,`_Id` BIGINT,`_LastAccessDate` TIMESTAMP,`_Location` STRING,`_ProfileImageUrl` STRING,`_Reputation` BIGINT,`_UpVotes` BIGINT,`_Views` BIGINT,`_WebsiteUrl` STRING"
25 | 
26 | df = spark.read.format("xml").option("rowTag", "row").option("rootTag", "votes").load("/Users/ott/tmp/stackoverflow-vi/Votes.xml")
27 | votes_schema = "`_BountyAmount` int,`_CreationDate` TIMESTAMP,`_Id` BIGINT,`_PostId` BIGINT,`_UserId` BIGINT,`_VoteTypeId` BIGINT"
28 | 
29 | #df = spark.read.format("xml").option("rowTag", "row").option("rootTag", "posts").load("/Users/ott/tmp/stackoverflow-vi/")
30 | 
31 | #df = spark.read.format("xml").option("rowTag", "row").option("rootTag", "posts").load("/Users/ott/tmp/stackoverflow-vi/")
32 | 
33 | 
-------------------------------------------------------------------------------- /streaming_query_listener/QueryListenerUsageExample.py: --------------------------------------------------------------------------------
1 | # Databricks notebook source
2 | # MAGIC %load_ext autoreload
3 | # MAGIC %autoreload 2
4 | 
5 | # COMMAND ----------
6 | 
7 | from query_listener import UcVolumeWriteListener
8 | 
9 | # COMMAND ----------
10 | 
11 | listener = UcVolumeWriteListener("/Volumes/main/tmp/tmp/stream1", "stream1")
12 | spark.streams.addListener(listener)
13 | read_df = spark.readStream.format("rate").load()
14 | query1 = (
15 |     read_df
16 |     .writeStream
17 |     .format("memory")
18 |     .queryName("stream1")
19 |     .trigger(processingTime="10 seconds")
20 |     .start()
21 | )
22 | 
23 | # COMMAND ----------
24 | 
25 | query1.stop()
-------------------------------------------------------------------------------- /streaming_query_listener/query_listener.py: --------------------------------------------------------------------------------
1 | import json
2 | from datetime import datetime, timezone, date
3 | from typing import Any
4 | import os
5 | import uuid
6 | 
7 | 
8 | from pyspark.sql.streaming import StreamingQueryListener
9 | 
10 | from databricks.sdk import WorkspaceClient
11 | from pyspark.sql.streaming.listener import QueryProgressEvent, QueryStartedEvent, \
12 |     QueryTerminatedEvent, QueryIdleEvent
13 | 
14 | 
15 | class AdditionalJsonEncoder(json.JSONEncoder):
16 |     def default(self, o):
17 |         if isinstance(o, datetime) or isinstance(o, date):
18 |             return o.isoformat()
19 |         if isinstance(o, uuid.UUID):
20 |             return o.hex
21 | 
22 |         return json.JSONEncoder.default(self, o)
23 | 
24 | 
25 | class UcVolumeWriteListener(StreamingQueryListener):
26 |     def __init__(self, uc_volume_dir: str, file_name_prefix: str, write_interval_sec: int = 60):
27 |         """Initializes the UcVolumeWriteListener.
28 |         This listener writes query events to a specified UC Volume directory in JSON format.
29 | 
30 |         Args:
31 |             uc_volume_dir (str): Path to the UC Volume directory to write to (`/Volumes/<catalog>/<schema>/<volume>/<path>`).
32 |                 The directory will be created if it does not exist.
33 |             file_name_prefix (str): Prefix for the output file names.
34 |             write_interval_sec (int, optional): Interval in seconds for writing to the volume. Defaults to 60.
35 |         """
36 |         super().__init__()
37 |         self.uc_volume_dir = uc_volume_dir
38 |         self.stream_name = file_name_prefix
39 |         self.base_record: dict[str, Any] = {
40 |             "rowsProcessed": 0,
41 |             "duration": 0,
42 |             "timestamp": None,
43 |             "stream_name": self.stream_name,
44 |         }
45 |         # TODO: add handling of error situations when we can't write to the volume?
46 |         os.makedirs(uc_volume_dir, exist_ok=True)
47 |         self.last_update = datetime.now(timezone.utc)
48 |         self.buffer = []
49 |         self.write_interval_sec = write_interval_sec
50 | 
51 |     def _push_to_volume(self, data: dict, event, force_write: bool = False):
52 |         # TODO: do buffering, and write to the volume in batches
53 |         ts = datetime.now(timezone.utc)
54 |         if not data.get("timestamp"):
55 |             data["timestamp"] = ts.isoformat()
56 |         data["report_timestamp"] = ts.isoformat()
57 |         data["runId"] = event.runId
58 |         data["query_id"] = event.id
59 |         data["stream_name"] = self.stream_name
60 |         self.buffer.append(data)
61 |         # maybe write the data...
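        # Flush the buffer to a new JSON file when forced (e.g., on query termination) or when at least write_interval_sec seconds have passed since the last write.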
62 |         time_diff = (ts - self.last_update).total_seconds()
63 |         if force_write or (time_diff >= self.write_interval_sec):
64 |             path = os.path.join(self.uc_volume_dir, f"{ts.strftime('%Y-%m-%d-%H-%M-%S.%fZ')}.json")
65 |             with open(path, "w") as f:
66 |                 for i in self.buffer:
67 |                     t = json.dumps(i, cls=AdditionalJsonEncoder)
68 |                     f.write(t + "\n")
69 |             self.last_update = ts
70 |             self.buffer = []
71 | 
72 |     def onQueryStarted(self, event: QueryStartedEvent):
73 |         r = self.base_record.copy()
74 |         r["status"] = "started"
75 |         r["timestamp"] = event.timestamp
76 |         r["query_name"] = event.name
77 |         print("Query started", r)
78 |         self._push_to_volume(r, event)
79 | 
80 |     def onQueryProgress(self, event: QueryProgressEvent):
81 |         progress = event.progress
82 |         r = self.base_record.copy()
83 |         r["status"] = "running"
84 |         r["rowsProcessed"] = progress.numInputRows
85 |         r["durationMs"] = progress.batchDuration
86 |         r["batchId"] = progress.batchId
87 |         r["numInputRows"] = progress.numInputRows
88 |         r["inputRowsPerSecond"] = progress.inputRowsPerSecond
89 |         r["timestamp"] = progress.timestamp
90 |         r["query_name"] = progress.name
91 |         print("Query progressed", r)
92 | 
93 |         self._push_to_volume(r, progress)
94 | 
95 |     def onQueryIdle(self, event: QueryIdleEvent):
96 |         r = self.base_record.copy()
97 |         r["status"] = "idle"
98 |         r["timestamp"] = event.timestamp
99 |         print("Query idle", r)
100 |         self._push_to_volume(r, event)
101 | 
102 |     def onQueryTerminated(self, event: QueryTerminatedEvent):
103 |         r = self.base_record.copy()
104 |         if event.exception:
105 |             r["status"] = "failed"
106 |             r["exception"] = event.exception
107 |             print(
108 |                 f"Query terminated with exception: id={event.id}, exception={event.exception}"
109 |             )
110 |         else:
111 |             r["status"] = "succeeded"
112 | 
113 |         print("Query terminated", r)
114 |         self._push_to_volume(r, event, force_write=True)
115 | 
--------------------------------------------------------------------------------