├── .gitignore ├── LICENSE ├── README.md ├── SECURITY.md ├── encryption_template.ipynb ├── event_hub_encryption_template_errorhandling.scala └── ml_batch_scoring └── batch_scoring_ml_models.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributing 3 | 4 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 5 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 6 | the rights to use your contribution. For details, visit https://cla.microsoft.com. 7 | 8 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide 9 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions 10 | provided by the bot. You will only need to do this once across all repos using our CLA. 11 | 12 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 13 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 14 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 15 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /encryption_template.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"markdown","source":["## Parameterized Job for running encryption"],"metadata":{}},{"cell_type":"markdown","source":["This notebook is a template that illustrates how encryption at scale can be parameterized so that it can be run with Data Factory scheduling or Databricks job scheduling. The main idea is to read a JSON parameter file that contains the mapping between each dataset name and its schema definition and encryption definition, so that we can automate: \n1. Reading datasets of different source types: flat files (CSV and others) or streams from Event Hubs, and decoding them from raw into a structured format\n2. Automatically picking up the fields to encrypt\n3. Writing to a designated Delta table\nNote:\n- The structuring of data sources is currently done in a simplistic way, e.g. it supports string, numeric and date columns. Complex types such as timestamps or reformatted dates need future enhancement. The main focus here is to get the structure out as strings and apply encryption\n- To run multiple notebooks like this in parallel, where each notebook processes one dataset/table, you can either define parallel execution steps in ADF or use Databricks notebook workflows to run notebooks concurrently (https://docs.databricks.com/user-guide/notebooks/notebook-workflows.html#run-multiple-notebooks-concurrently). In that case, the step that loads the JSON file should be placed in the master notebook, and this notebook is used as a worker notebook that processes a specific dataset\n\nParameters for the job/notebook\nThese parameters can be set at run time either by a DevOps engineer or by another master notebook/job:\n1. schema_mapping_json_path: Path to the JSON file that contains the mapping between dataset names, fields, whether each field needs encryption, and data types. See below for an example of such a JSON file\n2. Dataset name: name of the dataset used to look up the schema and encryption mapping in the JSON file\n3. Ingestion type: Streaming or Batch. Streaming is used when the source is Event Hub. Batch is used when raw data is copied to a landing folder in ADLS Gen 2 and this notebook picks it up from there\n4. Input data folder: for batch processing from the landing zone. This can be set with a wildcard to read data recursively\n5. 
Output data folder: in case you let the job create the Delta table automatically in batch mode, set the output folder so that the encrypted data is created there. For batch mode this may be needed because there can be multiple checkpoint locations before the final append to the target table. For streaming, the final table can be used directly since Structured Streaming supports a checkpoint location\n6. Table name: output table name. Can be an existing table name\n7. Eventhub Account: the account name of the Event Hub when a streaming dataset is used\n8. Secret scope/secret: name of the secret scope and key used to retrieve the Event Hub key \n9. EH topic: topic to read data from. As a deviation from the original design, and because Delta tables are used, it's recommended to write the streaming encryption output to a Delta table instead of an Event Hub topic. The reason is that a Delta table supports change capture, so the next job can simply subscribe to the Delta table to read new changes. This may be faster than reading from an Event Hub topic."],"metadata":{}},{"cell_type":"markdown","source":["{\n \"datasets\": {\n \"chicago_crimes\": {\n \"first_row_is_header\": \"true\",\n \"field_target_dtype_mapping\": {\n \"Seq\": \"IntegerType()\",\n \"ID\": \"IntegerType()\",\n \"Case_Number\": \"StringType()\",\n \"Date\": \"StringType()\",\n \"Block\": \"StringType()\",\n \"IUCR\": \"StringType()\",\n \"Primary_Type\": \"StringType()\",\n \"Description\": \"StringType()\",\n \"Location_Description\": \"StringType()\",\n \"Beat\": \"StringType()\",\n \"Arrest\": \"StringType()\",\n \"Domestic\": \"StringType()\",\n \"District\": \"StringType()\",\n \"Ward\": \"StringType()\",\n \"Community_Area\": \"StringType()\",\n \"FBI_Code\": \"StringType()\",\n \"X_Coordinate\": \"StringType()\",\n \"Y_Coordinate\": \"StringType()\",\n \"Year\": \"StringType()\",\n \"Updated_On\": \"StringType()\",\n \"Latitude\": \"StringType()\",\n \"Longitude\": \"StringType()\",\n \"Location\": \"StringType()\"\n },\n \"field_encryption_mapping\": {\n \"Description\": \"alpha\",\n \"Location_Description\": \"alpha\",\n \"Block\": \"alpha\",\n \"X_Coordinate\": \"alpha\",\n \"Y_Coordinate\": \"alpha\"\n },\n \"delimiter\": \",\",\n \"format\": \"csv\"\n },\n \"Event_json_flat\": {\n \"field_target_dtype_mapping\": {\n \"time\": \"TimestampType()\",\n \"action\": \"StringType()\"\n },\n \"field_encryption_mapping\": {\n \"time\": \"alpha\",\n \"action\": \"alpha\"\n },\n \"format\": \"json\"\n },\n \"Event_json_stream\": {\n \"field_target_dtype_mapping\": {\n \"time\": \"TimestampType()\",\n \"action\": \"StringType()\"\n },\n \"field_encryption_mapping\": {\n \"time\": \"alpha\",\n \"action\": \"alpha\"\n },\n \"format\": \"json\"\n }\n }\n}"],"metadata":{}},{"cell_type":"code","source":["//Uncomment and run this here, or set it at the cluster level, if you'd like to access the ADLS Gen2 store directly instead of using a mount point. 
Currently mount point of ADLS Gen 2 is not supported for Delta table\n// spark.conf.set(\"fs.azure.account.auth.type..dfs.core.windows.net\", \"OAuth\")\n// spark.conf.set(\"fs.azure.account.oauth.provider.type.adlsdatalakegen6.dfs.core.windows.net\", \"org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider\")\n// spark.conf.set(\"fs.azure.account.oauth2.client.id..dfs.core.windows.net\", \"\")\n// spark.conf.set(\"fs.azure.account.oauth2.client.secret..dfs.core.windows.net\", \"\")\n// spark.conf.set(\"fs.azure.account.oauth2.client.endpoint..dfs.core.windows.net\", \"\")\n// spark.conf.set(\"fs.azure.createRemoteFileSystemDuringInitialization\", \"true\")\n// dbutils.fs.ls(\"abfss://test@.dfs.core.windows.net/\")\n// spark.conf.set(\"fs.azure.createRemoteFileSystemDuringInitialization\", \"false\")\n"],"metadata":{},"outputs":[],"execution_count":4},{"cell_type":"code","source":["%sql drop table test_db.my_table5 "],"metadata":{},"outputs":[],"execution_count":5},{"cell_type":"code","source":["%sql \n---replace this with the name of the table you'll use if you want to test the job with existing table instead of letting the job to create the table itself\n-- CREATE TABLE iF NOT EXISTS test_db.my_table5 (action STRING, time string)\n-- USING DELTA LOCATION 'abfss://test@adlsdatalakegen6.dfs.core.windows.net/testencrypt2'"],"metadata":{},"outputs":[],"execution_count":6},{"cell_type":"code","source":["%python\n# dbutils.widgets.removeAll()\n\n"],"metadata":{},"outputs":[],"execution_count":7},{"cell_type":"code","source":["%python\nfrom pyspark.sql.types import *\nimport json\nfrom json import JSONDecoder\nfrom collections import OrderedDict\n\ndbutils.widgets.dropdown(\"ingestion_type\", \"Batch\", [\"Batch\", \"Streaming\"], \"Ingestion type\")\ndbutils.widgets.text(\"dataset_name\", \"chicago_crimes\", \"Dataset name\")\ndbutils.widgets.text(\"output_tbl\", \"output_table\", \"output delta table name \")\n\ndbutils.widgets.text(\"eh_topic\", \"kafka_spark\", \"Input eventHub topic\")\ndbutils.widgets.text(\"secret_scope\", \"encryptionkey\", \"Secret scope\")\ndbutils.widgets.text(\"secret_key\", \"evh\", \"Secret key\")\n\ndbutils.widgets.text(\"checkpointLocation\", \"dbfs:/mnt/cp/testcp\", \"Check point Location for streaming\")\n\ndbutils.widgets.text(\"input_path\", \"/FileStore/tables/Chicago_*.csv\", \"Input Data Folder\")\ndbutils.widgets.text(\"eh_account\", \"kafkaeventhub01\", \"Event Hub Account\")\ndbutils.widgets.text(\"en_schema_path\", \"/dbfs/FileStore/tables/enc_schema.json\", \"Json Encryption Schema Path\")\ndbutils.widgets.text(\"output_path\", \"/FileStore/output/output.delta\", \"Output File Path\")\n\ndataset_name = dbutils.widgets.get(\"dataset_name\")\n\ningestion_type = dbutils.widgets.get(\"ingestion_type\")\nen_schema_path =dbutils.widgets.get(\"en_schema_path\")\ninput_path = dbutils.widgets.get(\"input_path\")\ncheckpointLocation = dbutils.widgets.get(\"checkpointLocation\")\n\neh_account = dbutils.widgets.get(\"eh_account\")\noutput_path=dbutils.widgets.get(\"output_path\")\noutput_tbl=dbutils.widgets.get(\"output_tbl\")\n\neh_topic=dbutils.widgets.get(\"eh_topic\")\nsecret_scope=dbutils.widgets.get(\"secret_scope\")\nsecret_key=dbutils.widgets.get(\"secret_key\")\n\n\n#Need a method to maintain order of json fields in the mapping\njson_data=open(en_schema_path).read()\ncustomdecoder = JSONDecoder(object_pairs_hook=OrderedDict)\njsondata = customdecoder.decode(json_data)\n\nfield_encryption_mapping 
=jsondata.get('datasets').get(dataset_name).get('field_encryption_mapping')\ndata_format = jsondata.get('datasets').get(dataset_name).get(\"format\")\ndelimiter= jsondata.get('datasets').get(dataset_name).get(\"delimiter\")\nschema_text = jsondata.get('datasets').get(dataset_name).get('field_target_dtype_mapping')\nschema = StructType()\n\nfor (field,dtype) in schema_text.items():\n if dtype =='StringType()':\n dtype =StringType()\n elif dtype =='IntegerType()':\n dtype =IntegerType()\n elif dtype =='LongType()':\n dtype =LongType()\n elif dtype =='LongType()':\n dtype =LongType()\n elif dtype =='DateType()':\n dtype =DateType()\n elif dtype =='TimestampType()':\n dtype =TimestampType()\n schema.add(field, data_type = dtype)\n\nprint(ingestion_type)\nprint(en_schema_path)\nprint(data_format)\nprint(input_path)\nprint(eh_account)\nprint(output_path)\nprint(schema)\nprint(eh_topic)\nprint(secret_scope)\nprint(secret_key)\n"],"metadata":{},"outputs":[],"execution_count":8},{"cell_type":"code","source":["%python\n#below is to clear the path before each run, need to disable this for real run\nfrom pyspark.sql.functions import *\n\ndbutils.fs.rm(output_path, True)\ndbutils.fs.rm(checkpointLocation, True)\nspark.sql(\"drop table if exists \"+output_tbl)\n\n#Batch copy and encryption\ntemp_tbl_name = 'temp_tbl'\nencrypt_func='encrypt'\nif ingestion_type == \"Batch\":\n if data_format =='csv':\n first_row_is_header = \"true\"\n df = spark.read.format(data_format) \\\n .option(\"header\", first_row_is_header) \\\n .option(\"sep\", delimiter) \\\n .schema(schema) \\\n .load(input_path) \n else:\n df = spark.read.format(data_format) \\\n .schema(schema) \\\n .load(input_path) \nelse: #Prepare for loading eventHub info\n eh_key = dbutils.secrets.get(secret_scope,secret_key)\n BOOTSTRAP_SERVERS = eh_account+\".servicebus.windows.net:9093\"\n EH_SASL = \"kafkashaded.org.apache.kafka.common.security.plain.PlainLoginModule required username='$ConnectionString' password='Endpoint=sb://\"+eh_account+\".servicebus.windows.net;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=\"+eh_key+\"';\"\n GROUP_ID = \"$Default\"\n df = spark \\\n .readStream \\\n .format(\"kafka\") \\\n .option(\"kafka.bootstrap.servers\", BOOTSTRAP_SERVERS) \\\n .option(\"subscribe\", eh_topic) \\\n .option(\"kafka.sasl.mechanism\",\"PLAIN\") \\\n .option(\"kafka.security.protocol\",\"SASL_SSL\") \\\n .option(\"kafka.sasl.jaas.config\", EH_SASL ) \\\n .option(\"kafka.request.timeout.ms\", \"60000\") \\\n .option(\"kafka.session.timeout.ms\", \"60000\") \\\n .option(\"kafka.group.id\", GROUP_ID) \\\n .option(\"failOnDataLoss\", \"false\") \\\n .load() \\\n .select(col(\"timestamp\"), col(\"value\").cast(\"STRING\").alias(\"value\")) \\\n .select(col(\"timestamp\"), from_json(col(\"value\"), schema).alias(\"json\"))\n \n\ndf.registerTempTable(temp_tbl_name)\n\nsql_statement = \"Select \"\n#Where statement to remove null value object\nwhere_statement = \" where \"\nfor field in schema_text.keys():\n if field in field_encryption_mapping.keys():\n #in case of streaming above, the field need to be accessed as json.field_name\n if ingestion_type == \"Batch\":\n sql_statement = sql_statement +encrypt_func +\"(\"+field+\",'\"+ field_encryption_mapping.get(field)+\"') \"+field+\" ,\" \n where_statement = where_statement +field+ \" is not null and \"\n else:\n sql_statement = sql_statement +encrypt_func +\"(json.\"+field+\",'\"+ field_encryption_mapping.get(field)+\"') \"+field+\" ,\" \n where_statement = where_statement 
+\"json.\"+field+ \" is not null and \"\n \n else:\n if ingestion_type == \"Batch\":\n sql_statement = sql_statement+ field+\",\" \n else:\n sql_statement = sql_statement+\"json.\"+ field+\",\"\nsql_statement = sql_statement[:-1] + \" from \" + temp_tbl_name + where_statement[:-4]\nprint(sql_statement)\ndf = spark.sql(sql_statement)\n\n#Write out the encrypted dataset\nif ingestion_type == \"Batch\":\n df.write.format(\"delta\").saveAsTable(output_tbl, path=output_path)\nelse:#write stream to the delta table\n df.writeStream \\\n .format(\"delta\") \\\n .outputMode(\"append\") \\\n .option(\"checkpointLocation\", checkpointLocation) \\\n .table(\"output_tbl\") \n"],"metadata":{},"outputs":[],"execution_count":9},{"cell_type":"code","source":["//Show how the next job can just subscribe to delta table to read new data automatically without having to read from EH.\n// display(spark.readStream.table(\"test_db.my_table5\"))"],"metadata":{},"outputs":[],"execution_count":10},{"cell_type":"code","source":["%python\n#To simulate a stream to Kakfa for testing purpose, uncomment and run the following \n# from pyspark.sql.types import *\n# from pyspark.sql.functions import *\n\n# inputPath = \"/databricks-datasets/structured-streaming/events/\"\n\n# # Since we know the data format already, let's define the schema to speed up processing (no need for Spark to infer schema)\n# jsonSchema = StructType([ StructField(\"time\", TimestampType(), True), StructField(\"action\", StringType(), True) ])\n\n\n# # Similar to definition of staticInputDF above, just using `readStream` instead of `read`\n# streamingInputDF = (\n# spark\n# .readStream \n# .schema(jsonSchema) # Set the schema of the JSON data\n# .option(\"maxFilesPerTrigger\", 1) # Treat a sequence of files as a stream by picking one file at a time\n# .json(inputPath)\n# )\n\n\n# EH_SASL = \"kafkashaded.org.apache.kafka.common.security.plain.PlainLoginModule required username='$ConnectionString' password='Endpoint=sb://\"+eh_account+\".servicebus.windows.net;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=\"+eh_key+\"';\"\n# BOOTSTRAP_SERVERS = eh_account+\".servicebus.windows.net:9093\"\n\n# GROUP_ID = \"$Default\"\n# df = streamingInputDF.selectExpr(\"CAST(time AS STRING) AS key\", \"to_json(struct(*)) AS value\") \\\n# .writeStream \\\n# .format(\"kafka\") \\\n# .option(\"kafka.bootstrap.servers\", BOOTSTRAP_SERVERS) \\\n# .option(\"topic\", \"kafka_spark\") \\\n# .option(\"kafka.sasl.mechanism\",\"PLAIN\") \\\n# .option(\"kafka.security.protocol\",\"SASL_SSL\") \\\n# .option(\"kafka.sasl.jaas.config\", EH_SASL ) \\\n# .option(\"checkpointLocation\", \"dbfs:/mnt/demo/checkpoint19\") \\\n# .start() \n\n"],"metadata":{},"outputs":[],"execution_count":11}],"metadata":{"name":"encryption_template","notebookId":1113761770974841},"nbformat":4,"nbformat_minor":0} -------------------------------------------------------------------------------- /event_hub_encryption_template_errorhandling.scala: -------------------------------------------------------------------------------- 1 | // Databricks notebook source 2 | // MAGIC %md ## Parameterized Job for running encryption 3 | 4 | // COMMAND ---------- 5 | 6 | // MAGIC %md 7 | // MAGIC This notebook is a template to illustrate how encryption at scale can be parameterized so that it can be run with data factory scheduling or Databricks job scheduling. 
The main idea is to read a JSON parameter file that contains the mapping between each dataset name and its schema definition and encryption definition, so that we can automate: 8 | // MAGIC 1. Reading datasets of different source types: flat files (CSV and others) or streams from Event Hubs, and decoding them from raw into a structured format 9 | // MAGIC 2. Automatically picking up the fields to encrypt 10 | // MAGIC 3. Writing to a designated Delta table 11 | // MAGIC Note: 12 | // MAGIC - The structuring of data sources is currently done in a simplistic way, e.g. it supports string, numeric and date columns. Complex types such as timestamps or reformatted dates need future enhancement. The main focus here is to get the structure out as strings and apply encryption 13 | // MAGIC - To run multiple notebooks like this in parallel, where each notebook processes one dataset/table, you can either define parallel execution steps in ADF or use Databricks notebook workflows to run notebooks concurrently (https://docs.databricks.com/user-guide/notebooks/notebook-workflows.html#run-multiple-notebooks-concurrently). In that case, the step that loads the JSON file should be placed in the master notebook, and this notebook is used as a worker notebook that processes a specific dataset; a minimal sketch of that master/worker pattern follows the parameter list below 14 | // MAGIC 15 | // MAGIC Parameters for the job/notebook 16 | // MAGIC These parameters can be set at run time either by a DevOps engineer or by another master notebook/job: 17 | // MAGIC 1. schema_mapping_json_path: Path to the JSON file that contains the mapping between dataset names, fields, whether each field needs encryption, and data types. See below for an example of such a JSON file 18 | // MAGIC 2. Dataset name: name of the dataset used to look up the schema and encryption mapping in the JSON file 19 | // MAGIC 3. Ingestion type: Streaming or Batch. Streaming is used when the source is Event Hub. Batch is used when raw data is copied to a landing folder in ADLS Gen 2 and this notebook picks it up from there 20 | // MAGIC 4. Input data folder: for batch processing from the landing zone. This can be set with a wildcard to read data recursively 21 | // MAGIC 5. Output data folder: in case you let the job create the Delta table automatically in batch mode, set the output folder so that the encrypted data is created there. For batch mode this may be needed because there can be multiple checkpoint locations before the final append to the target table. For streaming, the final table can be used directly since Structured Streaming supports a checkpoint location 22 | // MAGIC 6. Table name: output table name. Can be an existing table name 23 | // MAGIC 7. Eventhub Account: the account name of the Event Hub when a streaming dataset is used 24 | // MAGIC 8. Secret scope/secret: name of the secret scope and key used to retrieve the Event Hub key 25 | // MAGIC 9. EH topic: topic to read data from. As a deviation from the original design, and because Delta tables are used, it's recommended to write the streaming encryption output to a Delta table instead of an Event Hub topic. The reason is that a Delta table supports change capture, so the next job can simply subscribe to the Delta table to read new changes. This may be faster than reading from an Event Hub topic. 
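// MAGIC
// MAGIC A minimal sketch of the master/worker pattern referenced in the note above (not part of the original template). It assumes the Python worker template in this repo (encryption_template.ipynb), its widget names (dataset_name, ingestion_type) and the default JSON mapping path; adjust the notebook path and arguments to your workspace.
// MAGIC
// MAGIC ```python
// MAGIC # Hypothetical master notebook: load the JSON mapping once, then launch one worker run per dataset.
// MAGIC import json
// MAGIC from concurrent.futures import ThreadPoolExecutor
// MAGIC
// MAGIC # Same default path as the worker's en_schema_path widget
// MAGIC with open("/dbfs/FileStore/tables/enc_schema.json") as f:
// MAGIC     dataset_names = list(json.load(f)["datasets"].keys())
// MAGIC
// MAGIC def run_worker(dataset_name):
// MAGIC     # dbutils.notebook.run blocks until the worker notebook finishes and returns its exit value
// MAGIC     return dbutils.notebook.run("./encryption_template", 3600,
// MAGIC                                 {"dataset_name": dataset_name, "ingestion_type": "Batch"})
// MAGIC
// MAGIC # A small thread pool is enough: each call starts a separate notebook run on the cluster
// MAGIC with ThreadPoolExecutor(max_workers=4) as pool:
// MAGIC     results = list(pool.map(run_worker, dataset_names))
// MAGIC ```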
26 | 27 | // COMMAND ---------- 28 | 29 | import org.apache.spark.sql.types._ 30 | import scala.collection.JavaConversions._ 31 | 32 | // import json 33 | // from json import JSONDecoder 34 | // from collections import OrderedDict 35 | 36 | dbutils.widgets.text("schema_name", "", "Schema name") 37 | 38 | dbutils.widgets.text("eh_topic", "", "EH Raw topic") //Devops will provide 39 | dbutils.widgets.text("eh_en_topic", "", "EH Encrypted topic") //Devops will provide 40 | 41 | dbutils.widgets.text("secret_scope", "", "EH Secret scope") //Scope will be same 42 | dbutils.widgets.text("secret_key", "", "EH Raw Secret key") //Devops will provide 43 | // dbutils.widgets.text("access_policy_key", "", "EH Raw Access Policy key") //Devops will provide 44 | // dbutils.widgets.text("en_access_policy_key", "", "EH Encrypted Access Policy key") //Devops will provide 45 | 46 | dbutils.widgets.text("en_secret_key", "", "EH Encrypted Secret key")//Devops will provide 47 | dbutils.widgets.text("eh_account", "", "Event Hub Account") //Devops will provide 48 | 49 | 50 | dbutils.widgets.text("checkpointLocation", "dbfs:/mnt/cp/testcp", "Check point Location") 51 | 52 | // #Fetch data from parameters 53 | 54 | val schema_name = dbutils.widgets.get("schema_name") //Devops will provide used for both ADF and EH 55 | 56 | val ingestion_type = dbutils.widgets.get("ingestion_type") //EH =streaming or ADF = batch 57 | val checkpointLocation = dbutils.widgets.get("checkpointLocation") //where do we defined this ? 58 | 59 | //Error table path to capture bad data 60 | val error_logtbl_path="dbfs:/mnt/tbl/error_tbl" 61 | 62 | // #Getting detail for Streaming 63 | val eh_account = dbutils.widgets.get("eh_account") 64 | val eh_topic=dbutils.widgets.get("eh_topic") 65 | val eh_en_topic =dbutils.widgets.get("eh_en_topic") 66 | 67 | val secret_scope=dbutils.widgets.get("secret_scope") 68 | val secret_key=dbutils.widgets.get("secret_key") 69 | val en_secret_key=dbutils.widgets.get("en_secret_key") 70 | 71 | val eh_key = dbutils.secrets.get(secret_scope,secret_key) 72 | val access_policy_name ="SendListenSharedAccesskey" 73 | val en_access_policy_name ="EncryptedSendListenSharedAccesskey" 74 | 75 | val eh_en_key =dbutils.secrets.get(secret_scope,en_secret_key) 76 | 77 | //Look up the schema detail from 78 | val jdbc_secret_password_key ="edsschemadb-password" 79 | val jdbc_secret_username_key = "edsschemadb-username" 80 | val jdbc_secret_hostname_key= "edsschemadb-hostname" 81 | val jdbcUsername = dbutils.secrets.get(secret_scope,jdbc_secret_username_key) 82 | 83 | val jdbcPassword =dbutils.secrets.get(secret_scope,jdbc_secret_password_key) 84 | val jdbcHostname= dbutils.secrets.get(secret_scope,jdbc_secret_hostname_key) 85 | val jdbcPort = 1433 86 | val jdbcDatabase ="edsschemadb" 87 | 88 | val jdbcUrl = s"jdbc:sqlserver://${jdbcHostname}:${jdbcPort};database=${jdbcDatabase}" 89 | 90 | // Create a Properties() object to hold the parameters. 
91 | import java.util.Properties 92 | val connectionProperties = new Properties() 93 | 94 | connectionProperties.put("user", s"${jdbcUsername}") 95 | connectionProperties.put("password", s"${jdbcPassword}") 96 | 97 | val driverClass = "com.microsoft.sqlserver.jdbc.SQLServerDriver" 98 | connectionProperties.setProperty("Driver", driverClass) 99 | 100 | val schema_df = spark.read.jdbc(url=jdbcUrl,table="(select schema_content from schema_registry where schema_name='"+schema_name+"') as tbl", properties=connectionProperties) 101 | 102 | val field_encryption_df = spark.read.jdbc(url=jdbcUrl, table="(select attribute_name, encr_key from sensitive_field_mst A, schema_registry B where A.schema_id = B.schema_id and B.schema_name='"+schema_name+"') as tbl", properties=connectionProperties) 103 | 104 | //HOW DO WE COME UP WITH SCHEMA BASED on JSON 105 | 106 | val schema = DataType.fromJson(schema_df.take(1)(0)(0).toString).asInstanceOf[StructType] 107 | val field_encryption_list = field_encryption_df.collectAsList() 108 | val field_encryption_map = field_encryption_list.map(_.toSeq).map({ a => a.get(0) -> a.get(1) }).toMap 109 | 110 | 111 | // print(ingestion_type) 112 | // print(eh_account) 113 | // print(eh_topic) 114 | // print(secret_scope) 115 | // print(secret_key) 116 | 117 | 118 | // COMMAND ---------- 119 | 120 | // #Below, To fix bad data, I will use something like this 121 | // #select CASE WHEN Description rlike "^[a-zA-Z ]+$" THEN Description ELSE NULL END logic_column from chicago_crimes_delta 122 | // #To identify 123 | // #This is the regex map in the form of encryption_key: regex_expression 124 | // // # regex_map ={ 125 | // // # 'name2':'^[A-Z]', 126 | // // # 'ssn':'[A-Z!@#$*<>]', 127 | // // # 'dob':'[A-Za-z]' 128 | // // # # 'govt_id':[govt_id: [!"#()$%&*+,./:<=>?@^_a-z\\[\]{|}~;] 129 | 130 | // // # } 131 | 132 | //For special characters Scala require \\\\ in the regex 133 | 134 | val regex_map =Map( 135 | "name2" ->"[^a-zA-Z\\\\d:\\\\s]", 136 | 137 | "ssn" -> "[A-Z!@#$*<>]", 138 | "dob" -> "[A-Za-z]" 139 | 140 | ) 141 | 142 | def fullFlattenSchema(schema: StructType): Seq[String] = { 143 | def helper(schema: StructType, prefix: String): Seq[String] = { 144 | val fullName: String => String = name => if (prefix.isEmpty) name else s"$prefix.$name" 145 | schema.fields.flatMap { 146 | case StructField(name, inner: StructType, _, _) => 147 | fullName(name) +: helper(inner, fullName(name)) 148 | case StructField(name, _, _, _) => Seq(fullName(name)) 149 | } 150 | } 151 | 152 | helper(schema, "") 153 | } 154 | val schema_fields = fullFlattenSchema(schema) 155 | 156 | // COMMAND ---------- 157 | 158 | // #This part is to catch all 'bad' data records and store into an error log table 159 | import org.apache.spark.sql.functions._ 160 | 161 | // 162 | // #Batch copy and encryption 163 | val encr_invalid_tbl_name = "encr_invalid_tbl" //used for processing data that passed invalidation criteria 164 | 165 | val encrypt_func="encrypt" 166 | // #Prepare for consuming from eventHub - create listener 167 | val BOOTSTRAP_SERVERS = eh_account+".servicebus.windows.net:9093" 168 | val EH_SASL = "kafkashaded.org.apache.kafka.common.security.plain.PlainLoginModule required username=\"$ConnectionString\" password=\"Endpoint=sb://"+eh_account+".servicebus.windows.net;SharedAccessKeyName="+access_policy_name+";SharedAccessKey="+eh_key+"\" ;" 169 | // print(EH_SASL) 170 | //print(EH_SASL2) 171 | 172 | val GROUP_ID_BAD_DATA = "$BAD_DATA" 173 | val encr_invalid_df = spark 174 | .readStream 175 | 
.format("kafka") 176 | .option("kafka.bootstrap.servers", BOOTSTRAP_SERVERS) 177 | .option("subscribe", eh_topic) 178 | .option("kafka.sasl.mechanism","PLAIN") 179 | .option("kafka.security.protocol","SASL_SSL") 180 | .option("kafka.sasl.jaas.config", EH_SASL ) 181 | .option("startingOffsets", "earliest") 182 | .option("kafka.request.timeout.ms", "60000") 183 | .option("kafka.session.timeout.ms", "60000") 184 | .option("kafka.group.id", GROUP_ID_BAD_DATA) 185 | .option("failOnDataLoss", "false") 186 | .load() 187 | .select($"timestamp", $"value".cast("STRING").alias("value")) 188 | .select($"timestamp", from_json($"value", schema).as("json")) 189 | 190 | // #temporary tables to process sql stmts 191 | 192 | encr_invalid_df.createOrReplaceTempView(encr_invalid_tbl_name) 193 | 194 | // #Constructing the SQL statement for good data 195 | var encr_invalid_sql_statement:String = "" 196 | var where_statement: String =" where " 197 | 198 | for (field <- schema_fields) { 199 | if (field_encryption_map.keys.contains((field.split("\\.").last.toLowerCase()))){ //this is handling for sensitive fields 200 | val encr_key = field_encryption_map.getOrElse(field.split("\\.").last.toLowerCase(),"").toString 201 | // print(encr_key) 202 | val regex_exp = regex_map.getOrElse(encr_key,"").toString 203 | // #EH 204 | // #filter good data based on invalidation criteria provided by infosec team and encrypt it 205 | encr_invalid_sql_statement = encr_invalid_sql_statement + "json."+field+" " +field.split("\\.").last+ ";" 206 | where_statement = where_statement +"json."+field+" rlike '"+regex_exp+"' or " 207 | } 208 | else //this is handling for non-sensitive fields 209 | encr_invalid_sql_statement = encr_invalid_sql_statement+"json."+ field+";" 210 | } 211 | 212 | 213 | //print(encr_invalid_sql_statement) 214 | 215 | //This is the part to build a struct SQL statement (hierchary) so that we can convert back to hierarchical Json format 216 | var after_struct="struct(" 217 | var before_struct="struct(" 218 | var singles="" 219 | for (word <- encr_invalid_sql_statement.split(";")) { 220 | if (word.contains("after")){ 221 | after_struct=after_struct+word+"," 222 | }else if(word.contains("before")){ 223 | before_struct=before_struct+word+"," 224 | }else{ 225 | singles = singles+word+"," 226 | } 227 | } 228 | 229 | 230 | after_struct = after_struct.slice(0, after_struct.length - 1) +") as after" 231 | before_struct = before_struct.slice(0, before_struct.length - 1) +") as before," 232 | where_statement = where_statement.slice(0, where_statement.length - 4) 233 | encr_invalid_sql_statement = "Select CAST(json.current_ts AS STRING) as Key, struct("+singles+before_struct +after_struct+ ") AS value from " + encr_invalid_tbl_name + where_statement 234 | // print(encr_invalid_sql_statement) 235 | 236 | val encr_invalid_dfout = spark.sql(encr_invalid_sql_statement) 237 | // #write invalid data to a delta table 238 | // # encr_invalid_dfout.writeStream.option("checkpointLocation", checkpointLocation+"/error").option("path",error_logtbl_path).table(error_logtbl) 239 | encr_invalid_dfout.writeStream.format("delta").option("checkpointLocation", checkpointLocation+"/error").option("path",error_logtbl_path).start() 240 | 241 | // #this option is to write invalid data to a EH topic 242 | // # EH_ENC_SASL = "kafkashaded.org.apache.kafka.common.security.plain.PlainLoginModule required username='$ConnectionString' 
password='Endpoint=sb://"+eh_account+".servicebus.windows.net;SharedAccessKeyName="+en_access_policy_name+";SharedAccessKey="+eh_en_key+"';" 243 | 244 | // # encr_invalid_dfout_df =encr_invalid_dfout.writeStream \ 245 | // # .format("kafka") \ 246 | // # .option("kafka.bootstrap.servers", BOOTSTRAP_SERVERS) \ 247 | // # .option("topic", "error_topic_here") \ 248 | // # .option("kafka.sasl.mechanism","PLAIN") \ 249 | // # .option("kafka.security.protocol","SASL_SSL") \ 250 | // # .option("kafka.sasl.jaas.config", EH_ENC_SASL ) \ 251 | // # .option("checkpointLocation", checkpointLocation) \ 252 | // # .start() 253 | 254 | 255 | // COMMAND ---------- 256 | 257 | display(spark.readStream.format("delta").load(error_logtbl_path)) 258 | //This is to display if any bad records are caught 259 | 260 | 261 | // COMMAND ---------- 262 | 263 | // #This part is to use regex_replace to replace invalid characters in columns that are supposed to be encrypted before passing to encryption function to minimize problem with encryption due to valid format. 264 | import org.apache.spark.sql.functions._ 265 | 266 | // dbutils.fs.rm(checkpointLocation, True) 267 | val encr_valid_tbl_name = "encr_valid_tbl" //#used for processing data that passed invalidation criteria 268 | 269 | val encrypt_func="encrypt" 270 | //Prepare for consuming from eventHub - create listener 271 | val BOOTSTRAP_SERVERS = eh_account+".servicebus.windows.net:9093" 272 | val EH_SASL = "kafkashaded.org.apache.kafka.common.security.plain.PlainLoginModule required username=\"$ConnectionString\" password=\"Endpoint=sb://"+eh_account+".servicebus.windows.net;SharedAccessKeyName="+access_policy_name+";SharedAccessKey="+eh_key+"\" ;" 273 | // print(EH_SASL) 274 | 275 | //Consumer group for taking care of valid senstive data 276 | val GROUP_ID_valid_DATA = "$GOOD_DATA" 277 | val encr_valid_df = spark 278 | .readStream 279 | .format("kafka") 280 | .option("kafka.bootstrap.servers", BOOTSTRAP_SERVERS) 281 | .option("subscribe", eh_topic) 282 | .option("kafka.sasl.mechanism","PLAIN") 283 | .option("kafka.security.protocol","SASL_SSL") 284 | .option("kafka.sasl.jaas.config", EH_SASL ) 285 | .option("kafka.request.timeout.ms", "60000") 286 | .option("kafka.session.timeout.ms", "60000") 287 | .option("startingOffsets", "earliest") 288 | .option("kafka.group.id", GROUP_ID_valid_DATA) 289 | .option("failOnDataLoss", "false") 290 | .load() 291 | .select(col("timestamp"), col("value").cast("STRING").alias("value")) 292 | .select(col("timestamp"), from_json(col("value"), schema).alias("json")) 293 | 294 | //#temporary tables to process sql stmts 295 | encr_valid_df.createOrReplaceTempView(encr_valid_tbl_name) 296 | 297 | //#Constructing the SQL statement for good data 298 | var encr_valid_sql_statement: String = "" 299 | //# #Where statement to remove null value object 300 | //#James: not where but it should be regexp_replace((column, pattern, replacement)) to remove/replace special character before apply encrypt_func 301 | //# 302 | for (field <- schema_fields) { 303 | if (field_encryption_map.keys.contains((field.split("\\.").last.toLowerCase()))){ 304 | //#this is handling for sensitive fields 305 | // #in case of streaming above, the field need to be accessed as json.field_name 306 | // #EH 307 | //#filter good data based on validation criteria provided by infosec team and encrypt it 308 | //# encr_valid_sql_statement = encr_valid_sql_statement +"(CASE WHEN json."+field+" IS NOT NULL AND LENGTH(json."+field+") <> 0 THEN " +encrypt_func 
+"(json."+field+",'"+ field_encryption_map.keys.contains((field.split("\\.").last.toLowerCase()))+"') ELSE json."+field+" END) " +field.split("\\.").last+ ";" 309 | //#To apply regex to remove special characters, you can do it as follow 310 | val encr_key = field_encryption_map.getOrElse(field.split("\\.").last.toLowerCase(),"").toString 311 | val regex_exp = regex_map.getOrElse(encr_key,"").toString 312 | 313 | encr_valid_sql_statement = encr_valid_sql_statement +"(CASE WHEN json."+field+" IS NOT NULL AND LENGTH(json."+field+") <> 0 AND " +"json."+field+" !rlike '"+regex_exp+"' THEN " +encrypt_func +"(json."+field+",'"+encr_key+"') ELSE NULL END) " +field.split("\\.").last+ ";" 314 | 315 | } else { 316 | //#this is handling for non-sensitive fields 317 | encr_valid_sql_statement = encr_valid_sql_statement+"json."+ field+";" 318 | } 319 | } 320 | 321 | // print (encr_valid_sql_statement) 322 | //#This is the part to build a struct SQL statement (hierchary) so that we can convert back to hierarchical Json format 323 | var after_struct="struct(" 324 | var before_struct="struct(" 325 | var singles="" 326 | for (word <- encr_valid_sql_statement.split(";")) { 327 | if (word.contains("after")){ 328 | after_struct=after_struct+word+"," 329 | }else if(word.contains("before")){ 330 | before_struct=before_struct+word+"," 331 | }else{ 332 | singles = singles+word+"," 333 | } 334 | } 335 | 336 | 337 | after_struct = after_struct.slice(0, after_struct.length - 1) +") as after" 338 | before_struct = before_struct.slice(0, before_struct.length - 1) +") as before," 339 | 340 | encr_valid_sql_statement = "Select CAST(json.current_ts AS STRING) as Key, to_json(struct("+singles+before_struct +after_struct+ ")) AS value from " + encr_valid_tbl_name 341 | 342 | val encr_valid_dfout = spark.sql(encr_valid_sql_statement) 343 | 344 | // display(encr_valid_dfout) 345 | //#Stream this data to Databasew 346 | 347 | //# #Write out the encrypted dataset 348 | val EH_ENC_SASL = "kafkashaded.org.apache.kafka.common.security.plain.PlainLoginModule required username='$ConnectionString' password='Endpoint=sb://"+eh_account+".servicebus.windows.net;SharedAccessKeyName="+en_access_policy_name+";SharedAccessKey="+eh_en_key+"';" 349 | 350 | val encr_valid_dfout_df =encr_valid_dfout.writeStream 351 | .format("kafka") 352 | .option("kafka.bootstrap.servers", BOOTSTRAP_SERVERS) 353 | .option("topic", eh_en_topic) 354 | .option("kafka.sasl.mechanism","PLAIN") 355 | .option("kafka.security.protocol","SASL_SSL") 356 | .option("kafka.sasl.jaas.config", EH_ENC_SASL ) 357 | .option("checkpointLocation", checkpointLocation+"/main_process") 358 | .start() -------------------------------------------------------------------------------- /ml_batch_scoring/batch_scoring_ml_models.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"markdown","source":["# Batch Scoring ML Model with Spark Pandas UDF\n\nAfter you train your ML model, how do you use it to perform batch scoring of a very large dataset? How would you do this in parallel to minimize scoring time?.\nIf you trained your model with Spark ML then this is not a problem as Spark ML model is designed to score Spark distributed data objects. 
However, if Spark ML is not what you used due to its limitation and your model happens to be SKlearn or a Tensorflow model or is in the form of published web service (your own model or a cognitive API) then there's no straight forward way to do this.\n \nIn this post, I'll show two examples of how batch scoring can be applied using the relatively new Pandas UDF function in Spark 2.x:\nBatch scoring from cognitive API (or your own ML model published as API)\nBatch scoring from persisted SKlearn model"],"metadata":{}},{"cell_type":"markdown","source":["## Scoring from persisted sklearn model"],"metadata":{}},{"cell_type":"code","source":["#This is the example to load a sklearn model from pkl file and use it to score mini batches of data from Spark Streaming.\n#But the dataset does not need to be streaming, it can be any Spark dataset\ndf_stream = spark.readStream.format(\"delta\").table(\"events\")\ndf_stream.withWatermark(\"aiv_epoch_start\", \"10 minutes\").registerTempTable(\"amazon_msess_events\")\nfrom pyspark.sql.types import *\nfrom pyspark.sql.functions import pandas_udf, PandasUDFType,window\nimport datetime\n\nimport pandas as pd\nimport numpy as np\nfrom sklearn.externals import joblib\nimport pandas.errors\n\njdbcHostname = 'DBhost'\njdbcUsername =''\njdbcPassword = ''\ntable = ''\n\njdbcDatabase = \"DBname\"\njdbcPort = 1433\njdbcUrl = \"jdbc:sqlserver://{0}:{1};database={2}\".format(jdbcHostname, jdbcPort, jdbcDatabase)\nprint(jdbcUrl)\nconnectionProperties = {\n \"user\" : jdbcUsername,\n \"password\" : jdbcPassword,\n \"driver\" : \"com.microsoft.sqlserver.jdbc.SQLServerDriver\"\n}\n\n\n\nreq_eval = spark.sql(\"select ng_start_time, ng_stop_time, cast(ng_stop_time as long) as start_time,cast(ng_stop_time as long) as stop_time, (cast(ng_stop_time as long) - cast(ng_start_time as long)) as duration, total_count, manifest_requests, avg_request_time,avg_tdwait_time,max_request,max_tdwait, max_mbps, min_mbps,avg_mbps,avg_rtt,bytes,total_2xx,total_3xx,total_4xx,total_5xx,td_lte_1s_count,tc_lte_1s_count,\ttd_gt_1s_lte_2s_count,\ttc_gt_1s_lte_2s_count,\ttd_gt_2s_lte_4s_count,\ttc_gt_2s_lte_4s_count,\ttd_gt_4s_lte_6s_count,\ttc_gt_4s_lte_6s_count,\ttd_gt_6s_lte_8s_count,\ttc_gt_6s_lte_8s_count,\ttd_gt_8s_lte_10s_count,\ttc_gt_8s_lte_10s_count,\ttd_gt_10s_lte_30s_count,\ttc_gt_10s_lte_30s_count,\ttd_gt_30s_lte_60s_count,\ttc_gt_30s_lte_60s_count,\ttd_gt_60s_count\ttc_gt_60s_count,asn, cc, state,cast(max_err as string) as max_err, (case when aiv_num_rebuffers >0 then 1 else 0 end) as aiv_num_rebuffers from amazon_msess_events\")\n\n\nschema = StructType([\n StructField(\"ng_start_time\", TimestampType()),\n StructField(\"ng_stop_time\", TimestampType()),\n StructField(\"cc\", StringType()),\n StructField(\"state\", StringType()),\n StructField(\"asn\", DoubleType()),\n\n StructField(\"predicted_buffering\", DoubleType()),\n StructField(\"actual_buffering\", DoubleType()),\n\n])\n\n@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)\n\n\ndef predict_buffering(panda_buffer):\n from sklearn.externals import joblib\n model = joblib.load('/dbfs/mnt/regr.joblib') \n\n #if datasize is not long enough, return default result\n #score threshold whether to score as positive\n threshold = 0.45\n states=['SC','AZ','LA','MN','NJ','DC','OR','unknown','VA','RI','KY','WY','NH','MI','NV','WI','ID','CA','CT','NE','MT','NC','VT','MD','DE','MO','IL','ME','WA','ND','MS','AL','IN','OH','TN','IA','NM','PA','SD','NY','TX','WV','GA','MA','KS','CO','FL','AK','AR','OK','UT','HI']\n countries 
=['cr','pr','us','vi','cl','il','ro','jp','ag','vn','pl','vg','za','sk','mu','pt','ke','ni','sg','ae','iq','hk','be','qa','bz','gb','me','ec','sa','co','tr','de','is','tt','lu','br', 'im', 'gt', 'bb', 'jo', 'es', 'hr', 'eu', 'dj', 'kr', 'it', 'uy', 'af', 'pe', 'vc', 'ar', 'sv', 'jm', 'ph', 'nl', 'bo', 'gp', 'hn', 'hu', 'ca', 'al', 'bm', 've', 'gu','ee','nz','si','gr','aw','ru','mt','th','cw','ch','ma','gh','do','lt','ht','pa','no','bg','cy','at','cz','ua','dm','mx','bs','ai','ky','fr','se','ie','dk','gd','id','bh','gg','gy','fi']\n\n errors =[412, 206, 500, 504, 502, 503, 400, 403, 404, 591, 408, 200, 260, 499]\n ori_states = panda_buffer['state'].values.tolist()\n ori_cc=panda_buffer['cc'].values.tolist()\n panda_buffer['ng_start_time'] = pd.to_datetime(panda_buffer['start_time'], unit='ms', utc=True)\n panda_buffer['ng_stop_time'] = pd.to_datetime(panda_buffer['stop_time'],unit='ms', utc=True)\n start_time = panda_buffer['ng_start_time'].values.tolist()\n stop_time = panda_buffer['ng_stop_time'].values.tolist()\n del panda_buffer['ng_start_time']\n del panda_buffer['ng_stop_time']\n del panda_buffer['start_time']\n del panda_buffer['stop_time']\n# del panda_buffer['window']\n panda_buffer['state']= panda_buffer['state'].astype('category',categories=states)\n temp = pd.get_dummies(panda_buffer['state'], prefix='state')\n panda_buffer = pd.concat([panda_buffer, temp], axis = 1)\n del panda_buffer['state'], temp\n\n panda_buffer['cc']= panda_buffer['cc'].astype('category',categories=countries)\n temp = pd.get_dummies(panda_buffer['cc'], prefix='cc')\n panda_buffer = pd.concat([panda_buffer, temp], axis = 1)\n del panda_buffer['cc'], temp\n\n panda_buffer['max_err']= panda_buffer['max_err'].astype('category',categories=errors)\n temp = pd.get_dummies(panda_buffer['max_err'], prefix='max_err')\n panda_buffer = pd.concat([panda_buffer, temp], axis = 1)\n del panda_buffer['max_err'], temp\n\n # panda_buffer['asn']= panda_buffer['asn'].astype('category',categories=asns)\n # temp = pd.get_dummies(panda_buffer['asn'], prefix='asn')\n # panda_buffer = pd.concat([panda_buffer, temp], axis = 1)\n # del panda_buffer['asn'], temp\n asn =panda_buffer['asn'].values.tolist()\n y = panda_buffer['aiv_num_rebuffers'].values.tolist()\n X_dataset= panda_buffer.copy()\n del X_dataset['aiv_num_rebuffers']\n X = X_dataset.values\n\n y_pred= model.predict(X).tolist()\n def score(y_pred, threshold=0.5):\n out = []\n for item in y_pred:\n if item[1] >=threshold:\n out.append(1)\n else:\n out.append(0)\n return out\n out_pred = model.predict_proba(X).tolist()\n y_pred = score(out_pred, threshold)\n\n\n return pd.DataFrame({'ng_start_time': start_time,'ng_stop_time': stop_time, 'cc':ori_cc,'state':ori_states,'asn':asn, 'predicted_buffering': y_pred, 'actual_buffering': y})\n\n\n\n\n\n\n\nagg= req_eval.groupby(window(\"ng_start_time\", \"5 minutes\", \"10 seconds\"))\noutput =agg.apply(predict_buffering)\n\n#here you can set the mode = \"overwrite\" (inside JDBC) if you only want to see latest data in output table. Otherwise can you set the mode ='append', then in the query of client tool, you need to select latest record sort by start_time, end_time\n#side note: if you want to ouput a regular Dataframe, not stream dataframe to SQL table, use syntax: yourDF.write.jdbc(url=jdbcUrl, table=table, mode=\"append\", properties=connectionProperties)\n#IMPORTANT: Please make sure you create output table in advance in SQL Server. 
There's a bug for non-numeric column to be included as colum index if you let the driver create table automatically.\noutput.registerTempTable(\"prediction_out\")\n\n "],"metadata":{},"outputs":[],"execution_count":3},{"cell_type":"markdown","source":["## Batch Scoring from ML rest API"],"metadata":{}},{"cell_type":"code","source":["\n#This Example use the Azure's Anomaly Finder API to score against Spark data in a batch fashion.\n#By doing this, we can score records in parallel for supposed to be real time API.\n\n#This function has input of a spark data frame which include mulitple variables aligned in the same timeline to test for anomalies simultaneously\n# @columns: This is the list of numeric column to detect anomaly\n# @cat_columns: this the list of categorical columns to include for query purpose\n# @Groupby_cols: the list of columns for grouping by\ndef anomaly_df_finder(df,columns, cat_columns,groupby_cols,max_ratio=0.25,sens=95, rare_as_exception =False,higher_value_better_cols=None ):\n from pyspark.sql.types import StructType,StructField,StringType,BooleanType, TimestampType,DoubleType\n from pyspark.sql.functions import pandas_udf, PandasUDFType\n import datetime\n from pyspark.sql.functions import window\n\n import pandas as pd\n import json\n import requests\n import numpy as np\n import time\n\n schema = StructType([\n StructField(\"timestamp\", TimestampType()),\n StructField(\"col_anomaly_count\", DoubleType()),\n\n ])\n for cat_column in cat_columns:\n schema.add(cat_column,StringType() )\n\n for column in columns:\n schema.add(column+'_IsAnomaly_Pos',BooleanType() )\n schema.add(column+'_IsAnomaly_Neg',BooleanType() )\n schema.add(column+'_is_anomaly',BooleanType() )\n schema.add(column+'_upper_value',DoubleType() )\n schema.add(column+'_lower_value',DoubleType() )\n schema.add(column+'_expected_value',DoubleType() )\n schema.add(column+'_value',DoubleType() )\n\n @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)\n\n\n def detect_anomaly(df):\n MaxAnomalyRatio=max_ratio\n Sensitivity=sens\n\n # columns = ['avg_rtt', 'avg_tdwait']\n output_dict = {'timestamp': df['Timestamp']}\n output_dict['col_anomaly_count']=0\n for cat_column in cat_columns:\n output_dict[cat_column]= df[cat_column]\n\n #if datasize is not long enough, return default result\n if df.shape[0] <12:\n for column in columns:\n output_dict[column+'_IsAnomaly_Pos']=rare_as_exception\n output_dict[column+'_IsAnomaly_Neg']=False\n output_dict[column+'_is_anomaly']=rare_as_exception\n output_dict[column+'_upper_value']=-999\n output_dict[column+'_lower_value']=-999\n output_dict[column+'_value']=df[column]\n output_dict[column+'_expected_value']=df[column]\n output_dict['col_anomaly_count']=output_dict['col_anomaly_count']+np.array(output_dict[column+'_IsAnomaly_Pos'])\n\n return pd.DataFrame(data=output_dict)\n\n\n\n\n endpoint = 'https://westus2.api.cognitive.microsoft.com/anomalyfinder/v2.0/timeseries/entire/detect'\n subscription_key = '' #Key for version 2\n\n def detect(endpoint, subscription_key, request_data):\n headers = {'Content-Type': 'application/json', 'Ocp-Apim-Subscription-Key': subscription_key}\n response = requests.post(endpoint, data=json.dumps(request_data), headers=headers)\n #Dealing with threshold exceeding exception, retry util we can call the api\n while response.status_code == 429:\n time.sleep(1)\n response = requests.post(endpoint, data=json.dumps(request_data), headers=headers)\n\n if response.status_code == 200:\n return 
json.loads(response.content.decode(\"utf-8\"))\n\n else:\n # print(response.status_code)\n raise Exception(str(response.status_code)+\":\" +response.text + json.dumps(request_data))\n\n #Loop for each column in \n# df.sort_values(by= 'Timestamp', inplace=True)\n \n for column in columns:\n df_out = df[['Timestamp', column]].copy()\n df_out[\"Value\"] = df_out[column]\n del df_out[column]\n df_out.Timestamp = pd.to_datetime(df.Timestamp, unit ='ms',utc =True)\n json_data = df_out.to_json(orient='records',date_format ='iso',date_unit ='s')\n json_loaded = json.loads(json_data)\n json_loaded = {\"Granularity\":\"minutely\", \"CustomInterval\":5,\"MaxAnomalyRatio\": MaxAnomalyRatio, \"Sensitivity\": Sensitivity, \"Series\":json_loaded }\n # json_loaded = {\"Period\": None, \"Points\":json_loaded }\n\n try:\n result = detect(endpoint, subscription_key, json_loaded)\n except Exception as e:\n output_dict[column+'_IsAnomaly_Pos']=rare_as_exception\n output_dict[column+'_IsAnomaly_Neg']=False\n output_dict[column+'_is_anomaly']=rare_as_exception\n output_dict[column+'_upper_value']=-999\n output_dict[column+'_lower_value']=-999\n output_dict[column+'_value']=df[column]\n output_dict[column+'_expected_value']=df[column]\n continue\n\n\n output_dict[column+'_IsAnomaly_Pos']=result['IsPositiveAnomaly']\n output_dict[column+'_IsAnomaly_Neg']=result['IsNegativeAnomaly']\n output_dict[column+'_is_anomaly']=result['IsAnomaly']\n if higher_value_better_cols and column in higher_value_better_cols:\n output_dict['col_anomaly_count']=output_dict['col_anomaly_count']+np.array(result['IsNegativeAnomaly'])\n else:\n output_dict['col_anomaly_count']=output_dict['col_anomaly_count']+np.array(result['IsPositiveAnomaly'])\n\n# output_dict['col_anomaly_count']=output_dict['col_anomaly_count']+np.array(result['IsAnomaly'])\n\n output_dict[column+'_upper_value']=result['UpperMargins']\n output_dict[column+'_lower_value']=result['LowerMargins']\n output_dict[column+'_value']=df[column]\n output_dict[column+'_expected_value']=result['ExpectedValues']\n# output_dict[column+'_upper_value'] = np.array(result['ExpectedValues'])+(100-Sensitivity)*np.array(result['UpperMargins'])\n# output_dict[column+'_lower_value'] = np.array(result['ExpectedValues'])-(100-Sensitivity)*np.array(result['LowerMargins']) \n output_dict[column+'_upper_value'] = np.array(result['UpperMargins'])\n output_dict[column+'_lower_value'] = np.array(result['LowerMargins']) \n\n output = pd.DataFrame(data=output_dict)\n\n return output\n agg=df.groupby(groupby_cols)\n df_output =agg.apply(detect_anomaly)\n\n return df_output"],"metadata":{},"outputs":[],"execution_count":5}],"metadata":{"name":"batch_scoring_ml_models","notebookId":3972720607427244},"nbformat":4,"nbformat_minor":0} --------------------------------------------------------------------------------
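A hypothetical usage sketch for the `anomaly_df_finder` helper defined above in ml_batch_scoring/batch_scoring_ml_models.ipynb; it is not part of the original notebook. `metrics_df` is a placeholder DataFrame, and the column choices (`avg_rtt`, `avg_tdwait`, `cc`, `state`) plus the 5-minute grouping window are assumptions that mirror names used elsewhere in the notebook.

```python
# Hypothetical call to anomaly_df_finder. metrics_df is assumed to already contain a
# 'Timestamp' column (epoch milliseconds), the categorical columns and the numeric metrics.
from pyspark.sql.functions import window

scored_df = anomaly_df_finder(
    metrics_df,
    columns=["avg_rtt", "avg_tdwait"],          # numeric columns tested for anomalies
    cat_columns=["cc", "state"],                # categorical columns carried through for querying
    groupby_cols=[window("Timestamp", "5 minutes"), "state"],  # one Anomaly Finder call per group
    max_ratio=0.25,
    sens=95,
)
display(scored_df)  # each row carries the *_is_anomaly flags, expected values and margins
```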