├── CLAUDE.md ├── tests ├── __init__.py ├── samples │ └── json │ │ └── zeek-http.jsonl ├── test_common.py ├── README.md ├── test_restapi.py ├── test_splunk.py ├── test_mssentinel.py ├── test_mssentinel_stream_reader.py └── test_mssentinel_reader.py ├── cyber_connectors ├── __init__.py ├── common.py ├── RestApi.py ├── Splunk.py └── MsSentinel.py ├── .github ├── dependabot.yml └── workflows │ └── onpush.yml ├── pytest.ini ├── docker-compose.yaml ├── NOTICE ├── utils ├── http_server.py └── README.md ├── .gitignore ├── pyproject.toml ├── LICENSE ├── AGENTS.md └── README.md /CLAUDE.md: -------------------------------------------------------------------------------- 1 | AGENTS.md -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit tests for cyber-spark-data-connectors.""" 2 | -------------------------------------------------------------------------------- /cyber_connectors/__init__.py: -------------------------------------------------------------------------------- 1 | from cyber_connectors.MsSentinel import AzureMonitorDataSource, MicrosoftSentinelDataSource 2 | from cyber_connectors.RestApi import RestApiDataSource 3 | from cyber_connectors.Splunk import SplunkDataSource 4 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | ignore: 6 | - dependency-name: "grpcio*" 7 | - dependency-name: "googleapis-common-protos" 8 | schedule: 9 | interval: "weekly" 10 | day: "sunday" 11 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -s -p no:warnings 3 | log_cli = 1 4 | log_cli_level = INFO 5 | log_cli_format = [pytest][%(asctime)s][%(levelname)s][%(module)s][%(funcName)s] %(message)s 6 | log_cli_date_format = %Y-%m-%d %H:%M:%S 7 | log_level = INFO 8 | spark_options = 9 | spark.sql.catalogImplementation: in-memory 10 | spark.sql.session.timeZone: UTC 11 | spark.sql.shuffle.partitions: 1 12 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '3.5' 2 | 3 | networks: 4 | splunk: 5 | 6 | services: 7 | 8 | splunk: 9 | image: splunk/splunk:latest 10 | environment: 11 | SPLUNK_START_ARGS: "--accept-license" 12 | SPLUNK_PASSWORD: ${SPLUNK_PASSWORD} 13 | SPLUNK_HEC_TOKEN: ${SPLUNK_HEC_TOKEN} 14 | ports: 15 | - 8000:8000 16 | - 8088:8088 17 | volumes: 18 | - ${PWD}/default.yml:/tmp/defaults/default.yml 19 | networks: 20 | - splunk 21 | 22 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2024 Alex Ott (alexott at gmail.com) 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /utils/http_server.py: -------------------------------------------------------------------------------- 1 | # Generate simple REST API server for testing purposes 2 | # Usage: python3 http_server.py 3 | # It should receive POST requests with JSON data and return JSON response 4 | 5 | 6 | from http.server import BaseHTTPRequestHandler, HTTPServer 7 | 8 | 9 | class MyHandler(BaseHTTPRequestHandler): 10 | def do_POST(self): 11 | content_length = int(self.headers['Content-Length']) 12 | post_data = self.rfile.read(content_length) 13 | print(f'Received POST request: {post_data}') 14 | self.send_response(200) 15 | self.end_headers() 16 | self.wfile.write(b'{"status": "ok"}') 17 | 18 | 19 | def run(server_class=HTTPServer, handler_class=MyHandler, port=8001): 20 | server_address = ('', port) 21 | httpd = server_class(server_address, handler_class) 22 | print(f'Starting httpd on port {port}') 23 | httpd.serve_forever() 24 | 25 | 26 | if __name__ == '__main__': 27 | run() 28 | -------------------------------------------------------------------------------- /tests/samples/json/zeek-http.jsonl: -------------------------------------------------------------------------------- 1 | {"ts":1362692526.939527,"uid":"Ctoe9Q2STDWcguMFZ8","id.orig_h":"141.142.228.5","id.orig_p":59856,"id.resp_h":"192.150.187.43","id.resp_p":80,"trans_depth":1,"method":"GET","host":"bro.org","uri":"/download/CHANGES.bro-aux.txt","version":"1.1","user_agent":"Wget/1.14 (darwin12.2.0)","request_body_len":0,"response_body_len":4705,"status_code":200,"status_msg":"OK","tags":[],"resp_fuids":["FMnxxt3xjVcWNS2141"],"resp_mime_types":["text/plain"]} 2 | {"ts":1445000735.104954,"uid":"CzBcXU2g1gSfhJ2N2c","id.orig_h":"10.1.9.63","id.orig_p":63526,"id.resp_h":"54.175.222.246","id.resp_p":80,"trans_depth":1,"method":"GET","host":"httpbin.org","uri":"/response-headers?Content-Type=application/octet-stream; charset=UTF-8&Content-Disposition=attachment; filename=\"test.json\"","version":"1.1","user_agent":"curl/7.45.0","request_body_len":0,"response_body_len":191,"status_code":200,"status_msg":"OK","tags":[],"resp_fuids":["FiokML36uuy5agr5x3"],"resp_filenames":["test.json"],"resp_mime_types":["text/json"]} 3 | {"ts": 1591367999.512593, "uid": "C5bLoe2Mvxqhawzqqd", "id.orig_h": "192.168.4.76", "id.orig_p": 46378, "id.resp_h": "31.3.245.133", "id.resp_p": 80, "trans_depth": 1, "method": "GET", "host": "testmyids.com", "uri": "/", "version": "1.1", "user_agent": "curl/7.47.0", "request_body_len": 0, "response_body_len": 39, "status_code": 200, "status_msg": "OK", "tags": [], "resp_fuids": ["FEEsZS1w0Z0VJIb5x4"], "resp_mime_types": ["text/plain"]} 4 | -------------------------------------------------------------------------------- /cyber_connectors/common.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from datetime import date, datetime 4 | 5 | from pyspark.sql.datasource import WriterCommitMessage 6 | 7 | 8 | @dataclass 9 | class 
SimpleCommitMessage(WriterCommitMessage): 10 | partition_id: int 11 | count: int 12 | 13 | 14 | class DateTimeJsonEncoder(json.JSONEncoder): 15 | def default(self, o): 16 | if isinstance(o, datetime) or isinstance(o, date): 17 | return o.isoformat() 18 | 19 | return json.JSONEncoder.default(self, o) 20 | 21 | 22 | def get_http_session(retry: int = 5, additional_headers: dict = None, retry_on_post: bool = False): 23 | import requests 24 | from requests.adapters import HTTPAdapter 25 | from urllib3.util import Retry 26 | 27 | session = requests.Session() 28 | if additional_headers: 29 | session.headers.update(additional_headers) 30 | 31 | if retry > 0: 32 | allowed_methods = ["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"] 33 | if retry_on_post: 34 | allowed_methods.append("POST") 35 | retry_strategy = Retry( 36 | total=retry, 37 | status_forcelist=[429, 500, 502, 503, 504], 38 | allowed_methods=allowed_methods, 39 | ) 40 | adapter = HTTPAdapter(max_retries=retry_strategy) 41 | session.mount("http://", adapter) 42 | session.mount("https://", adapter) 43 | 44 | return session 45 | -------------------------------------------------------------------------------- /.github/workflows/onpush.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | pull_request: 5 | types: [ opened, synchronize ] 6 | push: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | max-parallel: 4 14 | matrix: 15 | os: [ 'ubuntu-22.04' ] 16 | python-version: [ '3.10', '3.12' ] 17 | poetry-version: [ '1.3' ] 18 | 19 | steps: 20 | - uses: actions/checkout@v1 21 | - name: Install poetry 22 | run: pipx install poetry 23 | - uses: actions/setup-python@v4 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | cache: 'poetry' 27 | - name: Install dependencies 28 | run: poetry install 29 | - name: Run tests 30 | run: poetry run pytest --junitxml=test-unit.xml --cov=cyber_connectors --cov-report=term-missing:skip-covered --cov-report xml:cov.xml -vv tests | tee pytest-coverage.txt 31 | # - name: Publish Test Report 32 | # uses: mikepenz/action-junit-report@v3 33 | # if: always() # always run even if the previous step fails 34 | # with: 35 | # report_paths: './test-unit.xml' 36 | # - name: Pytest coverage comment 37 | # uses: MishaKav/pytest-coverage-comment@main 38 | # if: always() # always run even if the previous step fails 39 | # with: 40 | # pytest-coverage-path: ./pytest-coverage.txt 41 | # pytest-xml-coverage-path: ./cov.xml 42 | # title: Unit tests code coverage 43 | # badge-title: Coverage 44 | # hide-badge: false 45 | # hide-report: false 46 | # create-new-comment: false 47 | # hide-comment: false 48 | # report-only-changed-files: false 49 | # remove-link-from-badge: false 50 | # junitxml-path: ./test-unit.xml 51 | # junitxml-title: Unit tests summary 52 | -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # Utilities for Cyber Spark Data Connectors 2 | 3 | This directory contains utility scripts for manual testing and development of the cyber-spark-data-connectors package. 4 | 5 | ## Contents 6 | 7 | ### `http_server.py` 8 | 9 | A simple HTTP server for manual testing of the REST API data source. 
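For a quick connectivity check without Spark (the full Spark-based walkthrough is in the Usage section below), you can post one sample payload with `requests`; this sketch assumes the server is already running locally on its default port 8001:

```python
import requests

# Send one JSON record to the locally running test server and print its reply.
resp = requests.post(
    "http://localhost:8001/",
    data='{"id": 0}',
    headers={"Content-Type": "application/json"},
)
print(resp.status_code, resp.text)  # expected: 200 {"status": "ok"}
```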
10 | 11 | **Purpose:** 12 | - Manual/integration testing of the REST API connector 13 | - Debugging HTTP payload format 14 | - Verifying data is correctly sent to REST endpoints 15 | 16 | **Usage:** 17 | 18 | 1. Start the test server: 19 | ```bash 20 | python3 utils/http_server.py 21 | ``` 22 | 23 | The server will start on port 8001 and print received POST requests to the console. 24 | 25 | 2. In a separate terminal, run your Spark code with the REST API data source: 26 | ```python 27 | from cyber_connectors import RestApiDataSource 28 | spark.dataSource.register(RestApiDataSource) 29 | 30 | df = spark.range(10) 31 | df.write.format("rest").mode("overwrite") \ 32 | .option("url", "http://localhost:8001/") \ 33 | .save() 34 | ``` 35 | 36 | 3. Observe the server output to verify the data being sent. 37 | 38 | **Example Output:** 39 | ``` 40 | Starting httpd on port 8001 41 | Received POST request: b'{"id": 0}' 42 | Received POST request: b'{"id": 1}' 43 | Received POST request: b'{"id": 2}' 44 | ... 45 | ``` 46 | 47 | **Note:** 48 | - This is for **manual testing only** and is not used by the automated unit test suite. 49 | - The automated tests mock HTTP requests for speed and reliability. 50 | - You can modify the port by editing the `port` parameter in the script or passing it as an argument. 51 | 52 | ## Adding More Utilities 53 | 54 | Feel free to add more utility scripts here for: 55 | - Mock servers for Splunk or other endpoints 56 | - Data generation tools 57 | - Testing helpers 58 | - Development utilities 59 | 60 | Keep utility scripts separate from the automated test suite in the `tests/` directory. 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # TODO: add Azure stuff as optional dependencies 2 | # https://stackoverflow.com/questions/60971502/python-poetry-how-to-install-optional-dependencies 3 | [tool.poetry] 4 | name = "cyber-spark-data-connectors" 5 | version = "0.0.5" 6 | description = "Cybersecurity-related custom data connectors for Spark (readers and writers)." 
7 | authors = ["Alex Ott <alexott@gmail.com>"] 8 | packages = [ 9 | {include = "cyber_connectors"} 10 | ] 11 | 12 | [tool.poetry.dependencies] 13 | python = ">=3.9,<3.14" 14 | requests = "^2.32.5" 15 | azure-monitor-ingestion = "^1.1.0" 16 | azure-identity = "^1.25.1" 17 | azure-monitor-query = "^2.0.0" 18 | 19 | [tool.poetry.group.test.dependencies] 20 | pytest = "^8.4.2" 21 | pytest-cov = "^7.0.0" 22 | pytest-spark = "^0.8.0" 23 | 24 | [tool.poetry.group.dev.dependencies] 25 | ruff = "^0.14.9" 26 | pyspark = {extras = ["sql,connect"], version = "4.0.1"} 27 | grpcio = ">=1.67.0" 28 | grpcio-status = ">=1.67.0" 29 | googleapis-common-protos = ">=1.65.0" 30 | #delta-spark = "^3.2.1" 31 | chispa = "^0.11.1" 32 | mypy = "^1.19.0" 33 | types-requests = "^2.32.4.20250913" 34 | pyarrow = "^21.0.0" 35 | 36 | [build-system] 37 | requires = ["poetry-core>=1.0.0"] 38 | build-backend = "poetry.core.masonry.api" 39 | 40 | [tool.ruff] 41 | # Exclude a variety of commonly ignored directories. 42 | exclude = [ 43 | ".bzr", 44 | ".direnv", 45 | ".eggs", 46 | ".git", 47 | ".git-rewrite", 48 | ".hg", 49 | ".ipynb_checkpoints", 50 | ".mypy_cache", 51 | ".nox", 52 | ".pants.d", 53 | ".pyenv", 54 | ".pytest_cache", 55 | ".pytype", 56 | ".ruff_cache", 57 | ".svn", 58 | ".tox", 59 | ".venv", 60 | ".vscode", 61 | "__pypackages__", 62 | "_build", 63 | "buck-out", 64 | "build", 65 | "dist", 66 | "node_modules", 67 | "site-packages", 68 | "venv", 69 | "notebooks", 70 | "tests", 71 | ] 72 | 73 | # Same as Black. 74 | line-length = 120 75 | indent-width = 4 76 | 77 | # Assume Python 3.8 78 | target-version = "py38" 79 | 80 | [tool.ruff.lint] 81 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 82 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or 83 | # McCabe complexity (`C901`) by default. 84 | select = [ 85 | "E", # pycodestyle errors 86 | "W", # pycodestyle warnings 87 | "F", # pyflakes 88 | "I", # isort 89 | "B", # flake8-bugbear 90 | "C4", # flake8-comprehensions 91 | "N", # PEP8 naming conventions 92 | "D" # pydocstyle 93 | ] 94 | ignore = [ 95 | "D103", # No docstring in public function 96 | "N812", # Lowercase `functions` imported as non-lowercase `F` 97 | "F405", # `assert_df_equality` may be undefined, or defined from star imports 98 | ] 99 | 100 | # Allow fix for all enabled rules (when `--fix`) is provided. 101 | fixable = ["ALL"] 102 | unfixable = [] 103 | 104 | # Allow unused variables when underscore-prefixed. 105 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 106 | 107 | [tool.ruff.format] 108 | # Like Black, use double quotes for strings. 109 | quote-style = "double" 110 | 111 | # Like Black, indent with spaces, rather than tabs. 112 | indent-style = "space" 113 | 114 | # Like Black, respect magic trailing commas. 115 | skip-magic-trailing-comma = false 116 | 117 | # Like Black, automatically detect the appropriate line ending. 118 | line-ending = "auto" 119 | 120 | # Enable auto-formatting of code examples in docstrings. Markdown, 121 | # reStructuredText code/literal blocks and doctests are all supported. 122 | # 123 | # This is currently disabled by default, but it is planned for this 124 | # to be opt-out in the future. 125 | docstring-code-format = false 126 | 127 | # Set the line length limit used when formatting code snippets in 128 | # docstrings. 129 | # 130 | # This only has an effect when the `docstring-code-format` setting is 131 | # enabled.
132 | docstring-code-line-length = "dynamic" 133 | -------------------------------------------------------------------------------- /cyber_connectors/RestApi.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterator 2 | 3 | from pyspark.sql.datasource import DataSource, DataSourceStreamWriter, DataSourceWriter, WriterCommitMessage 4 | from pyspark.sql.types import Row, StructType 5 | 6 | from cyber_connectors.common import DateTimeJsonEncoder, SimpleCommitMessage, get_http_session 7 | 8 | 9 | class RestApiDataSource(DataSource): 10 | """Data source for REST APIs. Right now supports writing to a REST API. 11 | 12 | Write options: 13 | - url: REST API URL 14 | - http_format: (optional) format of the payload (default: json) 15 | - http_method: (optional) HTTP method to use - post or put (default: post) 16 | 17 | """ 18 | 19 | @classmethod 20 | def name(cls): 21 | return "rest" 22 | 23 | # needed only for reads without schema 24 | # def schema(self): 25 | # return "name string, date string, zipcode string, state string" 26 | 27 | def streamWriter(self, schema: StructType, overwrite: bool) -> DataSourceStreamWriter: 28 | return RestApiStreamWriter(self.options) 29 | 30 | def writer(self, schema: StructType, overwrite: bool) -> DataSourceWriter: 31 | return RestApiBatchWriter(self.options) 32 | 33 | 34 | class RestApiWriter: 35 | def __init__(self, options: Dict[str, any]): 36 | self.options = options 37 | self.url = self.options.get("url") 38 | self.payload_format: str = self.options.get("http_format", "json").lower() 39 | self.http_method: str = self.options.get("http_method", "post").lower() 40 | assert self.url is not None 41 | assert self.payload_format == "json" 42 | assert self.http_method in ["post", "put"] 43 | 44 | def write(self, iterator: Iterator[Row]): 45 | """Writes the data, then returns the commit message of that partition. Library imports must be within the method.""" 46 | import json 47 | 48 | from pyspark import TaskContext 49 | 50 | additional_headers = {} 51 | if self.payload_format == "json": 52 | additional_headers.update({"Content-Type": "application/json"}) 53 | # make retry_on_post configurable 54 | s = get_http_session(additional_headers=additional_headers, retry_on_post=True) 55 | context = TaskContext.get() 56 | partition_id = context.partitionId() 57 | cnt = 0 58 | for row in iterator: 59 | cnt += 1 60 | data = "" 61 | if self.payload_format == "json": 62 | data = json.dumps(row.asDict(), cls=DateTimeJsonEncoder) 63 | if self.http_method == "post": 64 | response = s.post(self.url, data=data) 65 | elif self.http_method == "put": 66 | response = s.put(self.url, data=data) 67 | else: 68 | raise ValueError(f"Unsupported http method: {self.http_method}") 69 | print(response.status_code, response.text) 70 | 71 | return SimpleCommitMessage(partition_id=partition_id, count=cnt) 72 | 73 | 74 | class RestApiBatchWriter(RestApiWriter, DataSourceWriter): 75 | def __init__(self, options): 76 | super().__init__(options) 77 | 78 | 79 | class RestApiStreamWriter(RestApiWriter, DataSourceStreamWriter): 80 | def __init__(self, options): 81 | super().__init__(options) 82 | 83 | def commit(self, messages: list[WriterCommitMessage | None], batchId: int) -> None: 84 | """Receives a sequence of :class:`WriterCommitMessage` when all write tasks have succeeded, then decides what to do with it. 
85 | Currently this is a no-op: the commented-out code below shows how the microbatch metadata (number of rows and partitions) could be persisted to a JSON file. 86 | """ 87 | # status = {"num_partitions": len(messages), "rows": sum(m.count for m in messages)} 88 | # with open(os.path.join(self.path, f"{batchId}.json"), "a") as file: 89 | # file.write(json.dumps(status) + "\n") 90 | 91 | def abort(self, messages: list[WriterCommitMessage | None], batchId: int) -> None: 92 | """Receives a sequence of :class:`WriterCommitMessage` from successful tasks when some other tasks have failed, then decides what to do with it. 93 | Currently this is a no-op: the commented-out code below shows how a failure marker could be written for the aborted batch. 94 | """ 95 | # with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file: 96 | # file.write(f"failed in batch {batchId}") 97 | pass 98 | -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | """Unit tests for common utilities.""" 2 | 3 | from datetime import date, datetime 4 | 5 | from cyber_connectors.common import DateTimeJsonEncoder, SimpleCommitMessage, get_http_session 6 | 7 | 8 | class TestSimpleCommitMessage: 9 | """Test SimpleCommitMessage dataclass.""" 10 | 11 | def test_creation(self): 12 | """Test creating a SimpleCommitMessage.""" 13 | msg = SimpleCommitMessage(partition_id=0, count=100) 14 | assert msg.partition_id == 0 15 | assert msg.count == 100 16 | 17 | def test_attributes(self): 18 | """Test that SimpleCommitMessage has correct attributes.""" 19 | msg = SimpleCommitMessage(partition_id=5, count=42) 20 | assert hasattr(msg, "partition_id") 21 | assert hasattr(msg, "count") 22 | 23 | 24 | class TestDateTimeJsonEncoder: 25 | """Test DateTimeJsonEncoder class.""" 26 | 27 | def test_encode_datetime(self): 28 | """Test encoding datetime objects.""" 29 | import json 30 | 31 | dt = datetime(2024, 1, 1, 12, 30, 45) 32 | result = json.dumps({"timestamp": dt}, cls=DateTimeJsonEncoder) 33 | assert "2024-01-01T12:30:45" in result 34 | 35 | def test_encode_date(self): 36 | """Test encoding date objects.""" 37 | import json 38 | 39 | d = date(2024, 1, 1) 40 | result = json.dumps({"date": d}, cls=DateTimeJsonEncoder) 41 | assert "2024-01-01" in result 42 | 43 | def test_encode_datetime_with_microseconds(self): 44 | """Test encoding datetime with microseconds.""" 45 | import json 46 | 47 | dt = datetime(2024, 1, 1, 12, 30, 45, 123456) 48 | result = json.dumps({"timestamp": dt}, cls=DateTimeJsonEncoder) 49 | assert "2024-01-01T12:30:45.123456" in result 50 | 51 | def test_encode_regular_types(self): 52 | """Test that regular types still work.""" 53 | import json 54 | 55 | data = { 56 | "string": "test", 57 | "number": 42, 58 | "float": 3.14, 59 | "boolean": True, 60 | "null": None, 61 | "list": [1, 2, 3], 62 | "dict": {"key": "value"}, 63 | } 64 | result = json.dumps(data, cls=DateTimeJsonEncoder) 65 | parsed = json.loads(result) 66 | assert parsed["string"] == "test" 67 | assert parsed["number"] == 42 68 | assert parsed["float"] == 3.14 69 | assert parsed["boolean"] is True 70 | assert parsed["null"] is None 71 | assert parsed["list"] == [1, 2, 3] 72 | assert parsed["dict"] == {"key": "value"} 73 | 74 | def test_encode_mixed_types(self): 75 | """Test encoding mixed types including datetime.""" 76 | import json 77 | 78 | dt = datetime(2024, 1, 1, 12, 0, 0) 79 | d = date(2024, 6, 15) 80 | data = {"timestamp": dt, "date": d, "string": "test", "number": 42} 81 | result =
json.dumps(data, cls=DateTimeJsonEncoder) 82 | assert "2024-01-01T12:00:00" in result 83 | assert "2024-06-15" in result 84 | assert "test" in result 85 | assert "42" in result 86 | 87 | 88 | class TestGetHttpSession: 89 | """Test get_http_session function.""" 90 | 91 | def test_default_session(self): 92 | """Test creating a session with default parameters.""" 93 | session = get_http_session() 94 | 95 | assert session is not None 96 | assert hasattr(session, "headers") 97 | 98 | def test_session_with_headers(self): 99 | """Test creating a session with additional headers.""" 100 | headers = {"Authorization": "Bearer token123"} 101 | session = get_http_session(additional_headers=headers) 102 | 103 | assert session is not None 104 | assert session.headers.get("Authorization") == "Bearer token123" 105 | 106 | def test_session_with_retry(self): 107 | """Test creating a session with retry configuration.""" 108 | session = get_http_session(retry=3) 109 | 110 | assert session is not None 111 | 112 | def test_session_without_retry(self): 113 | """Test creating a session without retry.""" 114 | session = get_http_session(retry=0) 115 | 116 | assert session is not None 117 | 118 | def test_session_retry_on_post(self): 119 | """Test creating a session with retry on POST enabled.""" 120 | session = get_http_session(retry=5, retry_on_post=True) 121 | 122 | assert session is not None 123 | 124 | def test_session_no_retry_on_post(self): 125 | """Test creating a session with retry on POST disabled.""" 126 | session = get_http_session(retry=5, retry_on_post=False) 127 | 128 | assert session is not None 129 | 130 | def test_retry_status_codes(self): 131 | """Test that session is created successfully with retries.""" 132 | session = get_http_session(retry=3) 133 | 134 | assert session is not None 135 | 136 | def test_session_with_multiple_headers(self): 137 | """Test creating a session with multiple headers.""" 138 | headers = { 139 | "Authorization": "Bearer token123", 140 | "Content-Type": "application/json", 141 | "User-Agent": "TestClient/1.0", 142 | } 143 | session = get_http_session(additional_headers=headers) 144 | 145 | assert session is not None 146 | for key, value in headers.items(): 147 | assert session.headers.get(key) == value 148 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Unit Tests for Cyber Spark Data Connectors 2 | 3 | This directory contains comprehensive unit tests for all implemented data sources in the cyber-spark-data-connectors package. 4 | 5 | ## Test Coverage 6 | 7 | The test suite covers the following data sources: 8 | 9 | ### 1. Splunk Data Source (`test_splunk.py`) 10 | - **TestSplunkDataSource**: Tests for the main data source class 11 | - Data source name validation 12 | - Stream and batch writer creation 13 | 14 | - **TestSplunkHecWriter**: Tests for the Splunk HEC writer functionality 15 | - Initialization with required and optional parameters 16 | - Missing parameter validation 17 | - Basic write operations 18 | - Time column handling 19 | - Indexed fields support 20 | - Single event column mode 21 | - Batching behavior 22 | 23 | - **TestSplunkHecStreamWriter**: Tests for stream-specific functionality 24 | - Commit and abort methods 25 | 26 | ### 2. 
REST API Data Source (`test_restapi.py`) 27 | - **TestRestApiDataSource**: Tests for the main data source class 28 | - Data source name validation 29 | - Stream and batch writer creation 30 | 31 | - **TestRestApiWriter**: Tests for the REST API writer functionality 32 | - Initialization with required and optional parameters 33 | - Missing/invalid parameter validation 34 | - POST and PUT method support 35 | - JSON serialization 36 | - Multiple row handling 37 | - HTTP headers configuration 38 | 39 | - **TestRestApiStreamWriter**: Tests for stream-specific functionality 40 | - Commit and abort methods 41 | 42 | ### 3. Microsoft Sentinel / Azure Monitor Data Source (`test_mssentinel.py`) 43 | - **TestAzureMonitorDataSource**: Tests for the Azure Monitor data source 44 | - Data source name validation 45 | - Stream and batch writer creation 46 | 47 | - **TestMicrosoftSentinelDataSource**: Tests for the MS Sentinel data source 48 | - Data source name validation (alias for Azure Monitor) 49 | - Stream and batch writer creation 50 | 51 | - **TestAzureMonitorWriter**: Tests for the Azure Monitor writer functionality 52 | - Initialization with all required Azure parameters 53 | - Missing parameter validation for each required field 54 | - Custom batch size configuration 55 | - Basic write operations 56 | - DateTime field conversion 57 | - Batching behavior 58 | - Azure credential creation 59 | 60 | - **TestAzureMonitorStreamWriter**: Tests for stream-specific functionality 61 | - Commit and abort methods 62 | 63 | ### 4. Common Utilities (`test_common.py`) 64 | - **TestSimpleCommitMessage**: Tests for the commit message dataclass 65 | - Creation and attribute access 66 | 67 | - **TestDateTimeJsonEncoder**: Tests for JSON encoding utilities 68 | - DateTime and date object serialization 69 | - Microseconds handling 70 | - Regular types support 71 | - Mixed type handling 72 | 73 | - **TestGetHttpSession**: Tests for HTTP session creation 74 | - Default session configuration 75 | - Custom headers 76 | - Retry configuration 77 | - Retry on POST option 78 | 79 | ## Running the Tests 80 | 81 | ### Run all tests 82 | ```bash 83 | poetry run pytest tests/ 84 | ``` 85 | 86 | ### Run tests with verbose output 87 | ```bash 88 | poetry run pytest tests/ -v 89 | ``` 90 | 91 | ### Run tests with coverage 92 | ```bash 93 | poetry run pytest tests/ --cov=cyber_connectors --cov-report=term-missing 94 | ``` 95 | 96 | ### Run specific test file 97 | ```bash 98 | poetry run pytest tests/test_splunk.py 99 | poetry run pytest tests/test_restapi.py 100 | poetry run pytest tests/test_mssentinel.py 101 | poetry run pytest tests/test_common.py 102 | ``` 103 | 104 | ### Run specific test class or method 105 | ```bash 106 | poetry run pytest tests/test_splunk.py::TestSplunkDataSource 107 | poetry run pytest tests/test_splunk.py::TestSplunkHecWriter::test_write_basic 108 | ``` 109 | 110 | ## Test Coverage Summary 111 | 112 | As of the latest run: 113 | - **Overall Coverage**: 95% 114 | - **MsSentinel.py**: 100% 115 | - **RestApi.py**: 98% 116 | - **Splunk.py**: 89% 117 | - **common.py**: 97% 118 | - **__init__.py**: 100% 119 | 120 | Total: 66 tests, all passing 121 | 122 | ## Testing Approach 123 | 124 | The tests use the following approach: 125 | - **Unit Tests**: All tests are unit tests with mocked dependencies 126 | - **Mocking**: External dependencies (HTTP sessions, Azure clients, PySpark TaskContext) are mocked 127 | - **Fixtures**: pytest fixtures are used for common test data (options, schemas) 128 | - **Assertions**: 
Comprehensive assertions to validate behavior 129 | - **Edge Cases**: Tests cover both happy paths and error conditions 130 | 131 | ## Dependencies 132 | 133 | The test suite requires: 134 | - pytest 135 | - pytest-cov (for coverage reports) 136 | - pytest-spark (for Spark testing utilities) 137 | - unittest.mock (standard library) 138 | 139 | All dependencies are managed through Poetry and are installed automatically when running tests in the Poetry environment. 140 | 141 | ### pytest-spark Fixtures 142 | 143 | The project uses `pytest-spark` which provides the following fixtures if needed for future tests: 144 | - `spark_session` - A SparkSession instance (scope: session) 145 | - `spark_context` - A SparkContext instance (scope: session) 146 | 147 | **Note:** The current unit tests don't use these fixtures because they mock all Spark interactions for speed and isolation. However, these fixtures are available for integration tests if needed in the future. 148 | 149 | Example usage (if needed): 150 | ```python 151 | def test_with_real_spark(spark_session): 152 | """Example test using actual Spark session.""" 153 | df = spark_session.range(10) 154 | assert df.count() == 10 155 | ``` 156 | -------------------------------------------------------------------------------- /cyber_connectors/Splunk.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | from pyspark.sql.datasource import DataSource, DataSourceStreamWriter, DataSourceWriter, WriterCommitMessage 4 | from pyspark.sql.types import Row, StructType 5 | from requests import Session 6 | 7 | from cyber_connectors.common import DateTimeJsonEncoder, SimpleCommitMessage, get_http_session 8 | 9 | 10 | class SplunkDataSource(DataSource): 11 | """Data source for Splunk. Right now supports writing to Splunk HEC. 12 | 13 | Write options: 14 | - url: Splunk HEC URL 15 | - token: Splunk HEC token 16 | - time_column: (optional) column name to use as event time 17 | - batch_size: (optional) number of events to batch before sending to Splunk. 
(default: 50) 18 | - index: (optional) Splunk index 19 | - source: (optional) Splunk source 20 | - host: (optional) Splunk host 21 | - sourcetype: (optional) Splunk sourcetype (default: _json) 22 | - single_event_column: (optional) column name to use as the full event payload (i.e., text column) 23 | - indexed_fields: (optional) comma separated list of fields to index 24 | - remove_indexed_fields: (optional) remove indexed fields from event payload (default: false) 25 | """ 26 | 27 | @classmethod 28 | def name(cls): 29 | return "splunk" 30 | 31 | def streamWriter(self, schema: StructType, overwrite: bool): 32 | return SplunkHecStreamWriter(self.options) 33 | 34 | def writer(self, schema: StructType, overwrite: bool): 35 | return SplunkHecBatchWriter(self.options) 36 | 37 | 38 | class SplunkHecWriter: 39 | """Shared Splunk HEC writer: formats rows as HEC events and posts them to the collector in batches.""" 40 | 41 | def __init__(self, options): 42 | self.options = options 43 | self.url = self.options.get("url") 44 | self.token = self.options.get("token") 45 | assert self.url is not None 46 | assert self.token is not None 47 | # extract optional parameters 48 | self.time_col = self.options.get("time_column") 49 | self.batch_size = int(self.options.get("batch_size", "50")) 50 | self.index = self.options.get("index") 51 | self.source = self.options.get("source") 52 | self.host = self.options.get("host") 53 | self.source_type = self.options.get("sourcetype", "_json") 54 | self.single_event_column = self.options.get("single_event_column") 55 | if self.single_event_column and self.source_type == "_json": 56 | self.source_type = "text" 57 | self.indexed_fields = str(self.options.get("indexed_fields", "")).split(",") 58 | self.omit_indexed_fields = self.options.get("remove_indexed_fields", False) 59 | if isinstance(self.omit_indexed_fields, str): 60 | self.omit_indexed_fields = self.omit_indexed_fields.lower() == "true" 61 | 62 | def _send_to_splunk(self, s: Session, msgs: list): 63 | if len(msgs) > 0: 64 | response = s.post(self.url, data="\n".join(msgs)) 65 | print(response.status_code, response.text) 66 | 67 | def write(self, iterator: Iterator[Row]): 68 | """Writes the data, then returns the commit message of that partition. 69 | Library imports must be within the method.
70 | """ 71 | import datetime 72 | import json 73 | 74 | from pyspark import TaskContext 75 | 76 | context = TaskContext.get() 77 | partition_id = context.partitionId() 78 | cnt = 0 79 | s = get_http_session(additional_headers={"Authorization": f"Splunk {self.token}"}, retry_on_post=True) 80 | 81 | msgs = [] 82 | for row in iterator: 83 | cnt += 1 84 | rd = row.asDict() 85 | d = {"sourcetype": self.source_type} 86 | if self.index: 87 | d["index"] = self.index 88 | if self.source: 89 | d["source"] = self.source 90 | if self.host: 91 | d["host"] = self.host 92 | if self.time_col and self.time_col in rd: 93 | tm = rd.get(self.time_col, datetime.datetime.now()) 94 | if isinstance(tm, datetime.datetime): 95 | d["time"] = tm.timestamp() 96 | elif isinstance(tm, int) or isinstance(tm, float): 97 | d["time"] = tm 98 | else: 99 | d["time"] = datetime.datetime.now().timestamp() 100 | else: 101 | d["time"] = datetime.datetime.now().timestamp() 102 | if self.single_event_column and self.single_event_column in rd: 103 | d["event"] = rd.get(self.single_event_column) 104 | elif self.indexed_fields: 105 | idx_fields = {k: rd.get(k) for k in self.indexed_fields if k in rd} 106 | if idx_fields: 107 | d["fields"] = idx_fields 108 | if self.omit_indexed_fields: 109 | ev_fields = {k: v for k, v in rd.items() if k not in self.indexed_fields} 110 | if ev_fields: 111 | d["event"] = ev_fields 112 | else: 113 | d["event"] = rd 114 | else: 115 | d["event"] = rd 116 | msgs.append(json.dumps(d, cls=DateTimeJsonEncoder)) 117 | 118 | if len(msgs) >= self.batch_size: 119 | self._send_to_splunk(s, msgs) 120 | msgs = [] 121 | 122 | self._send_to_splunk(s, msgs) 123 | 124 | return SimpleCommitMessage(partition_id=partition_id, count=cnt) 125 | 126 | 127 | class SplunkHecBatchWriter(SplunkHecWriter, DataSourceWriter): 128 | def __init__(self, options): 129 | super().__init__(options) 130 | 131 | 132 | class SplunkHecStreamWriter(SplunkHecWriter, DataSourceStreamWriter): 133 | def __init__(self, options): 134 | super().__init__(options) 135 | 136 | def commit(self, messages: list[WriterCommitMessage | None], batchId: int) -> None: 137 | """Receives a sequence of :class:`WriterCommitMessage` when all write tasks have succeeded, then decides what to do with it. 138 | Currently this is a no-op: the commented-out line below shows how the microbatch metadata (number of rows and partitions) could be collected. 139 | """ 140 | # status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages)) 141 | pass 142 | 143 | def abort(self, messages: list[WriterCommitMessage | None], batchId: int) -> None: 144 | """Receives a sequence of :class:`WriterCommitMessage` from successful tasks when some other tasks have failed, then decides what to do with it. 145 | Currently this is a no-op; override or extend it if failed batches require cleanup or logging.
146 | """ 147 | pass 148 | -------------------------------------------------------------------------------- /tests/test_restapi.py: -------------------------------------------------------------------------------- 1 | """Unit tests for REST API data source.""" 2 | 3 | import json 4 | from datetime import datetime 5 | from unittest.mock import Mock, patch 6 | 7 | import pytest 8 | from pyspark.sql.types import IntegerType, Row, StringType, StructField, StructType, TimestampType 9 | 10 | from cyber_connectors.RestApi import RestApiBatchWriter, RestApiDataSource, RestApiStreamWriter 11 | 12 | 13 | @pytest.fixture 14 | def basic_options(): 15 | """Basic required options for REST API data source.""" 16 | return {"url": "http://localhost:8001/api/endpoint"} 17 | 18 | 19 | @pytest.fixture 20 | def sample_schema(): 21 | """Sample schema for testing.""" 22 | return StructType( 23 | [ 24 | StructField("id", IntegerType(), True), 25 | StructField("name", StringType(), True), 26 | StructField("timestamp", TimestampType(), True), 27 | ] 28 | ) 29 | 30 | 31 | class TestRestApiDataSource: 32 | """Test RestApiDataSource class.""" 33 | 34 | def test_name(self): 35 | """Test that data source name is 'rest'.""" 36 | assert RestApiDataSource.name() == "rest" 37 | 38 | def test_streamWriter(self, basic_options, sample_schema): 39 | """Test that streamWriter returns RestApiStreamWriter.""" 40 | ds = RestApiDataSource(options=basic_options) 41 | writer = ds.streamWriter(sample_schema, overwrite=True) 42 | assert isinstance(writer, RestApiStreamWriter) 43 | 44 | def test_writer(self, basic_options, sample_schema): 45 | """Test that writer returns RestApiBatchWriter.""" 46 | ds = RestApiDataSource(options=basic_options) 47 | writer = ds.writer(sample_schema, overwrite=True) 48 | assert isinstance(writer, RestApiBatchWriter) 49 | 50 | 51 | class TestRestApiWriter: 52 | """Test RestApiWriter functionality.""" 53 | 54 | def test_init_required_options(self, basic_options): 55 | """Test initialization with required options.""" 56 | writer = RestApiBatchWriter(basic_options) 57 | assert writer.url == "http://localhost:8001/api/endpoint" 58 | assert writer.payload_format == "json" 59 | assert writer.http_method == "post" 60 | 61 | def test_init_missing_url(self): 62 | """Test that missing URL raises assertion error.""" 63 | with pytest.raises(AssertionError): 64 | RestApiBatchWriter({}) 65 | 66 | def test_init_with_custom_format(self): 67 | """Test initialization with custom format.""" 68 | options = {"url": "http://localhost:8001", "http_format": "json"} 69 | writer = RestApiBatchWriter(options) 70 | assert writer.payload_format == "json" 71 | 72 | def test_init_with_custom_method(self): 73 | """Test initialization with custom HTTP method.""" 74 | options = {"url": "http://localhost:8001", "http_method": "put"} 75 | writer = RestApiBatchWriter(options) 76 | assert writer.http_method == "put" 77 | 78 | def test_init_invalid_format(self): 79 | """Test that invalid format raises assertion error.""" 80 | options = {"url": "http://localhost:8001", "http_format": "xml"} 81 | with pytest.raises(AssertionError): 82 | RestApiBatchWriter(options) 83 | 84 | def test_init_invalid_method(self): 85 | """Test that invalid HTTP method raises assertion error.""" 86 | options = {"url": "http://localhost:8001", "http_method": "delete"} 87 | with pytest.raises(AssertionError): 88 | RestApiBatchWriter(options) 89 | 90 | @patch("pyspark.TaskContext") 91 | @patch("cyber_connectors.RestApi.get_http_session") 92 | def test_write_basic_post(self, 
mock_get_session, mock_task_context, basic_options): 93 | """Test basic write functionality with POST.""" 94 | mock_context = Mock() 95 | mock_context.partitionId.return_value = 0 96 | mock_task_context.get.return_value = mock_context 97 | 98 | mock_session = Mock() 99 | mock_response = Mock() 100 | mock_response.status_code = 200 101 | mock_response.text = '{"status": "ok"}' 102 | mock_session.post.return_value = mock_response 103 | mock_get_session.return_value = mock_session 104 | 105 | writer = RestApiBatchWriter(basic_options) 106 | rows = [Row(id=1, name="test")] 107 | commit_msg = writer.write(iter(rows)) 108 | 109 | assert commit_msg.partition_id == 0 110 | assert commit_msg.count == 1 111 | assert mock_session.post.called 112 | call_args = mock_session.post.call_args 113 | assert call_args[0][0] == "http://localhost:8001/api/endpoint" 114 | 115 | @patch("pyspark.TaskContext") 116 | @patch("cyber_connectors.RestApi.get_http_session") 117 | def test_write_with_put(self, mock_get_session, mock_task_context): 118 | """Test write functionality with PUT method.""" 119 | mock_context = Mock() 120 | mock_context.partitionId.return_value = 0 121 | mock_task_context.get.return_value = mock_context 122 | 123 | mock_session = Mock() 124 | mock_response = Mock() 125 | mock_response.status_code = 200 126 | mock_response.text = '{"status": "ok"}' 127 | mock_session.put.return_value = mock_response 128 | mock_get_session.return_value = mock_session 129 | 130 | options = {"url": "http://localhost:8001", "http_method": "put"} 131 | writer = RestApiBatchWriter(options) 132 | rows = [Row(id=1, name="test")] 133 | commit_msg = writer.write(iter(rows)) 134 | 135 | assert commit_msg.count == 1 136 | assert mock_session.put.called 137 | assert not mock_session.post.called 138 | 139 | @patch("pyspark.TaskContext") 140 | @patch("cyber_connectors.RestApi.get_http_session") 141 | def test_write_json_format(self, mock_get_session, mock_task_context, basic_options): 142 | """Test that data is properly serialized to JSON.""" 143 | mock_context = Mock() 144 | mock_context.partitionId.return_value = 0 145 | mock_task_context.get.return_value = mock_context 146 | 147 | mock_session = Mock() 148 | mock_response = Mock() 149 | mock_response.status_code = 200 150 | mock_response.text = '{"status": "ok"}' 151 | mock_session.post.return_value = mock_response 152 | mock_get_session.return_value = mock_session 153 | 154 | writer = RestApiBatchWriter(basic_options) 155 | timestamp = datetime(2024, 1, 1, 12, 0, 0) 156 | rows = [Row(id=1, name="test", timestamp=timestamp)] 157 | commit_msg = writer.write(iter(rows)) 158 | 159 | assert commit_msg.count == 1 160 | call_args = mock_session.post.call_args 161 | data = call_args[1]["data"] 162 | payload = json.loads(data) 163 | assert payload["id"] == 1 164 | assert payload["name"] == "test" 165 | assert "timestamp" in payload 166 | 167 | @patch("pyspark.TaskContext") 168 | @patch("cyber_connectors.RestApi.get_http_session") 169 | def test_write_multiple_rows(self, mock_get_session, mock_task_context, basic_options): 170 | """Test writing multiple rows.""" 171 | mock_context = Mock() 172 | mock_context.partitionId.return_value = 0 173 | mock_task_context.get.return_value = mock_context 174 | 175 | mock_session = Mock() 176 | mock_response = Mock() 177 | mock_response.status_code = 200 178 | mock_response.text = '{"status": "ok"}' 179 | mock_session.post.return_value = mock_response 180 | mock_get_session.return_value = mock_session 181 | 182 | writer = 
RestApiBatchWriter(basic_options) 183 | rows = [Row(id=i, name=f"test{i}") for i in range(5)] 184 | commit_msg = writer.write(iter(rows)) 185 | 186 | assert commit_msg.count == 5 187 | assert mock_session.post.call_count == 5 188 | 189 | @patch("pyspark.TaskContext") 190 | @patch("cyber_connectors.RestApi.get_http_session") 191 | def test_write_content_type_header(self, mock_get_session, mock_task_context, basic_options): 192 | """Test that Content-Type header is set correctly.""" 193 | mock_context = Mock() 194 | mock_context.partitionId.return_value = 0 195 | mock_task_context.get.return_value = mock_context 196 | 197 | mock_session = Mock() 198 | mock_response = Mock() 199 | mock_response.status_code = 200 200 | mock_response.text = '{"status": "ok"}' 201 | mock_session.post.return_value = mock_response 202 | mock_get_session.return_value = mock_session 203 | 204 | writer = RestApiBatchWriter(basic_options) 205 | rows = [Row(id=1)] 206 | writer.write(iter(rows)) 207 | 208 | # Check that get_http_session was called with correct headers 209 | call_args = mock_get_session.call_args 210 | headers = call_args[1]["additional_headers"] 211 | assert headers["Content-Type"] == "application/json" 212 | 213 | 214 | class TestRestApiStreamWriter: 215 | """Test RestApiStreamWriter functionality.""" 216 | 217 | def test_commit(self, basic_options): 218 | """Test commit method.""" 219 | writer = RestApiStreamWriter(basic_options) 220 | messages = [Mock(count=10), Mock(count=20)] 221 | writer.commit(messages, batchId=1) 222 | 223 | def test_abort(self, basic_options): 224 | """Test abort method.""" 225 | writer = RestApiStreamWriter(basic_options) 226 | messages = [Mock(count=10)] 227 | writer.abort(messages, batchId=1) 228 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /tests/test_splunk.py: -------------------------------------------------------------------------------- 1 | """Unit tests for Splunk data source.""" 2 | 3 | import json 4 | from datetime import datetime 5 | from unittest.mock import Mock, patch 6 | 7 | import pytest 8 | from pyspark.sql.types import IntegerType, Row, StringType, StructField, StructType, TimestampType 9 | 10 | from cyber_connectors.Splunk import SplunkDataSource, SplunkHecBatchWriter, SplunkHecStreamWriter 11 | 12 | 13 | @pytest.fixture 14 | def basic_options(): 15 | """Basic required options for Splunk data source.""" 16 | return {"url": "http://localhost:8088/services/collector/event", "token": "test-token-12345"} 17 | 18 | 19 | @pytest.fixture 20 | def sample_schema(): 21 | """Sample schema for testing.""" 22 | return StructType( 23 | [ 24 | StructField("id", IntegerType(), True), 25 | StructField("name", StringType(), True), 26 | StructField("timestamp", TimestampType(), True), 27 | ] 28 | ) 29 | 30 | 31 | class TestSplunkDataSource: 32 | """Test SplunkDataSource class.""" 33 | 34 | def test_name(self): 35 | """Test that data source name is 'splunk'.""" 36 | assert SplunkDataSource.name() == "splunk" 37 | 38 | def test_streamWriter(self, basic_options, sample_schema): 39 | """Test that streamWriter returns SplunkHecStreamWriter.""" 40 | ds = SplunkDataSource(options=basic_options) 41 | writer = ds.streamWriter(sample_schema, overwrite=True) 42 | assert isinstance(writer, SplunkHecStreamWriter) 43 | 44 | def test_writer(self, basic_options, sample_schema): 45 | """Test that writer returns SplunkHecBatchWriter.""" 46 | ds = SplunkDataSource(options=basic_options) 47 | writer = ds.writer(sample_schema, overwrite=True) 48 | assert 
isinstance(writer, SplunkHecBatchWriter) 49 | 50 | 51 | class TestSplunkHecWriter: 52 | """Test SplunkHecWriter functionality.""" 53 | 54 | def test_init_required_options(self, basic_options): 55 | """Test initialization with required options.""" 56 | writer = SplunkHecBatchWriter(basic_options) 57 | assert writer.url == "http://localhost:8088/services/collector/event" 58 | assert writer.token == "test-token-12345" 59 | assert writer.batch_size == 50 60 | assert writer.source_type == "_json" 61 | 62 | def test_init_missing_url(self): 63 | """Test that missing URL raises assertion error.""" 64 | with pytest.raises(AssertionError): 65 | SplunkHecBatchWriter({"token": "test-token"}) 66 | 67 | def test_init_missing_token(self): 68 | """Test that missing token raises assertion error.""" 69 | with pytest.raises(AssertionError): 70 | SplunkHecBatchWriter({"url": "http://localhost:8088"}) 71 | 72 | def test_init_with_all_options(self): 73 | """Test initialization with all options.""" 74 | options = { 75 | "url": "http://localhost:8088/services/collector/event", 76 | "token": "test-token", 77 | "time_column": "ts", 78 | "batch_size": "100", 79 | "index": "main", 80 | "source": "spark", 81 | "host": "test-host", 82 | "sourcetype": "custom", 83 | "single_event_column": "message", 84 | "indexed_fields": "field1,field2", 85 | "remove_indexed_fields": "true", 86 | } 87 | writer = SplunkHecBatchWriter(options) 88 | assert writer.time_col == "ts" 89 | assert writer.batch_size == 100 90 | assert writer.index == "main" 91 | assert writer.source == "spark" 92 | assert writer.host == "test-host" 93 | assert writer.source_type == "custom" 94 | assert writer.single_event_column == "message" 95 | assert writer.indexed_fields == ["field1", "field2"] 96 | assert writer.omit_indexed_fields is True 97 | 98 | def test_single_event_column_changes_sourcetype(self): 99 | """Test that single_event_column changes sourcetype to 'text' if default.""" 100 | options = {"url": "http://localhost:8088", "token": "test-token", "single_event_column": "message"} 101 | writer = SplunkHecBatchWriter(options) 102 | assert writer.source_type == "text" 103 | 104 | @patch("pyspark.TaskContext") 105 | @patch("cyber_connectors.Splunk.get_http_session") 106 | def test_write_basic(self, mock_get_session, mock_task_context, basic_options): 107 | """Test basic write functionality.""" 108 | mock_context = Mock() 109 | mock_context.partitionId.return_value = 0 110 | mock_task_context.get.return_value = mock_context 111 | 112 | mock_session = Mock() 113 | mock_response = Mock() 114 | mock_response.status_code = 200 115 | mock_response.text = '{"text":"Success","code":0}' 116 | mock_session.post.return_value = mock_response 117 | mock_get_session.return_value = mock_session 118 | 119 | writer = SplunkHecBatchWriter(basic_options) 120 | rows = [Row(id=1, name="test")] 121 | commit_msg = writer.write(iter(rows)) 122 | 123 | assert commit_msg.partition_id == 0 124 | assert commit_msg.count == 1 125 | assert mock_session.post.called 126 | 127 | @patch("pyspark.TaskContext") 128 | @patch("cyber_connectors.Splunk.get_http_session") 129 | def test_write_with_time_column(self, mock_get_session, mock_task_context): 130 | """Test write with time_column option.""" 131 | mock_context = Mock() 132 | mock_context.partitionId.return_value = 0 133 | mock_task_context.get.return_value = mock_context 134 | 135 | mock_session = Mock() 136 | mock_response = Mock() 137 | mock_response.status_code = 200 138 | mock_response.text = '{"text":"Success","code":0}' 139 | 
mock_session.post.return_value = mock_response 140 | mock_get_session.return_value = mock_session 141 | 142 | options = {"url": "http://localhost:8088", "token": "test-token", "time_column": "timestamp"} 143 | writer = SplunkHecBatchWriter(options) 144 | 145 | timestamp = datetime(2024, 1, 1, 12, 0, 0) 146 | rows = [Row(id=1, timestamp=timestamp)] 147 | commit_msg = writer.write(iter(rows)) 148 | 149 | assert commit_msg.count == 1 150 | call_args = mock_session.post.call_args 151 | data = call_args[1]["data"] 152 | payload = json.loads(data) 153 | assert payload["time"] == timestamp.timestamp() 154 | 155 | @patch("pyspark.TaskContext") 156 | @patch("cyber_connectors.Splunk.get_http_session") 157 | def test_write_with_indexed_fields(self, mock_get_session, mock_task_context): 158 | """Test write with indexed_fields option.""" 159 | mock_context = Mock() 160 | mock_context.partitionId.return_value = 0 161 | mock_task_context.get.return_value = mock_context 162 | 163 | mock_session = Mock() 164 | mock_response = Mock() 165 | mock_response.status_code = 200 166 | mock_response.text = '{"text":"Success","code":0}' 167 | mock_session.post.return_value = mock_response 168 | mock_get_session.return_value = mock_session 169 | 170 | options = {"url": "http://localhost:8088", "token": "test-token", "indexed_fields": "field1,field2"} 171 | writer = SplunkHecBatchWriter(options) 172 | 173 | rows = [Row(id=1, field1="value1", field2="value2", field3="value3")] 174 | commit_msg = writer.write(iter(rows)) 175 | 176 | assert commit_msg.count == 1 177 | call_args = mock_session.post.call_args 178 | data = call_args[1]["data"] 179 | payload = json.loads(data) 180 | assert "fields" in payload 181 | assert payload["fields"]["field1"] == "value1" 182 | assert payload["fields"]["field2"] == "value2" 183 | assert "event" in payload 184 | 185 | @patch("pyspark.TaskContext") 186 | @patch("cyber_connectors.Splunk.get_http_session") 187 | def test_write_with_single_event_column(self, mock_get_session, mock_task_context): 188 | """Test write with single_event_column option.""" 189 | mock_context = Mock() 190 | mock_context.partitionId.return_value = 0 191 | mock_task_context.get.return_value = mock_context 192 | 193 | mock_session = Mock() 194 | mock_response = Mock() 195 | mock_response.status_code = 200 196 | mock_response.text = '{"text":"Success","code":0}' 197 | mock_session.post.return_value = mock_response 198 | mock_get_session.return_value = mock_session 199 | 200 | options = {"url": "http://localhost:8088", "token": "test-token", "single_event_column": "message"} 201 | writer = SplunkHecBatchWriter(options) 202 | 203 | rows = [Row(id=1, message="This is a log message")] 204 | commit_msg = writer.write(iter(rows)) 205 | 206 | assert commit_msg.count == 1 207 | call_args = mock_session.post.call_args 208 | data = call_args[1]["data"] 209 | payload = json.loads(data) 210 | assert payload["event"] == "This is a log message" 211 | assert payload["sourcetype"] == "text" 212 | 213 | @patch("pyspark.TaskContext") 214 | @patch("cyber_connectors.Splunk.get_http_session") 215 | def test_write_batching(self, mock_get_session, mock_task_context): 216 | """Test that batching works correctly.""" 217 | mock_context = Mock() 218 | mock_context.partitionId.return_value = 0 219 | mock_task_context.get.return_value = mock_context 220 | 221 | mock_session = Mock() 222 | mock_response = Mock() 223 | mock_response.status_code = 200 224 | mock_response.text = '{"text":"Success","code":0}' 225 | mock_session.post.return_value = 
mock_response 226 | mock_get_session.return_value = mock_session 227 | 228 | options = {"url": "http://localhost:8088", "token": "test-token", "batch_size": "2"} 229 | writer = SplunkHecBatchWriter(options) 230 | 231 | rows = [Row(id=i) for i in range(5)] 232 | commit_msg = writer.write(iter(rows)) 233 | 234 | assert commit_msg.count == 5 235 | # Should be called 3 times: 2+2+1 236 | assert mock_session.post.call_count == 3 237 | 238 | 239 | class TestSplunkHecStreamWriter: 240 | """Test SplunkHecStreamWriter functionality.""" 241 | 242 | def test_commit(self, basic_options): 243 | """Test commit method.""" 244 | writer = SplunkHecStreamWriter(basic_options) 245 | messages = [Mock(count=10), Mock(count=20)] 246 | writer.commit(messages, batchId=1) 247 | 248 | def test_abort(self, basic_options): 249 | """Test abort method.""" 250 | writer = SplunkHecStreamWriter(basic_options) 251 | messages = [Mock(count=10)] 252 | writer.abort(messages, batchId=1) 253 | -------------------------------------------------------------------------------- /tests/test_mssentinel.py: -------------------------------------------------------------------------------- 1 | """Unit tests for Microsoft Sentinel / Azure Monitor data source.""" 2 | 3 | from datetime import datetime 4 | from unittest.mock import Mock, patch 5 | 6 | import pytest 7 | from pyspark.sql.types import IntegerType, Row, StringType, StructField, StructType, TimestampType 8 | 9 | from cyber_connectors.MsSentinel import ( 10 | AzureMonitorBatchWriter, 11 | AzureMonitorDataSource, 12 | AzureMonitorStreamWriter, 13 | MicrosoftSentinelDataSource, 14 | ) 15 | 16 | 17 | @pytest.fixture 18 | def basic_options(): 19 | """Basic required options for Azure Monitor data source.""" 20 | return { 21 | "dce": "https://test-dce.monitor.azure.com", 22 | "dcr_id": "dcr-test123456789", 23 | "dcs": "Custom-TestTable_CL", 24 | "tenant_id": "tenant-id-12345", 25 | "client_id": "client-id-12345", 26 | "client_secret": "client-secret-12345", 27 | } 28 | 29 | 30 | @pytest.fixture 31 | def sample_schema(): 32 | """Sample schema for testing.""" 33 | return StructType( 34 | [ 35 | StructField("id", IntegerType(), True), 36 | StructField("name", StringType(), True), 37 | StructField("timestamp", TimestampType(), True), 38 | ] 39 | ) 40 | 41 | 42 | class TestAzureMonitorDataSource: 43 | """Test AzureMonitorDataSource class.""" 44 | 45 | def test_name(self): 46 | """Test that data source name is 'azure-monitor'.""" 47 | assert AzureMonitorDataSource.name() == "azure-monitor" 48 | 49 | def test_streamWriter(self, basic_options, sample_schema): 50 | """Test that streamWriter returns AzureMonitorStreamWriter.""" 51 | ds = AzureMonitorDataSource(options=basic_options) 52 | writer = ds.streamWriter(sample_schema, overwrite=True) 53 | assert isinstance(writer, AzureMonitorStreamWriter) 54 | 55 | def test_writer(self, basic_options, sample_schema): 56 | """Test that writer returns AzureMonitorBatchWriter.""" 57 | ds = AzureMonitorDataSource(options=basic_options) 58 | writer = ds.writer(sample_schema, overwrite=True) 59 | assert isinstance(writer, AzureMonitorBatchWriter) 60 | 61 | 62 | class TestMicrosoftSentinelDataSource: 63 | """Test MicrosoftSentinelDataSource class.""" 64 | 65 | def test_name(self): 66 | """Test that data source name is 'ms-sentinel'.""" 67 | assert MicrosoftSentinelDataSource.name() == "ms-sentinel" 68 | 69 | def test_streamWriter(self, basic_options, sample_schema): 70 | """Test that streamWriter returns AzureMonitorStreamWriter.""" 71 | ds = 
MicrosoftSentinelDataSource(options=basic_options) 72 | writer = ds.streamWriter(sample_schema, overwrite=True) 73 | assert isinstance(writer, AzureMonitorStreamWriter) 74 | 75 | def test_writer(self, basic_options, sample_schema): 76 | """Test that writer returns AzureMonitorBatchWriter.""" 77 | ds = MicrosoftSentinelDataSource(options=basic_options) 78 | writer = ds.writer(sample_schema, overwrite=True) 79 | assert isinstance(writer, AzureMonitorBatchWriter) 80 | 81 | 82 | class TestAzureMonitorWriter: 83 | """Test AzureMonitorWriter functionality.""" 84 | 85 | def test_init_required_options(self, basic_options): 86 | """Test initialization with required options.""" 87 | writer = AzureMonitorBatchWriter(basic_options) 88 | assert writer.dce == "https://test-dce.monitor.azure.com" 89 | assert writer.dcr_id == "dcr-test123456789" 90 | assert writer.dcs == "Custom-TestTable_CL" 91 | assert writer.tenant_id == "tenant-id-12345" 92 | assert writer.client_id == "client-id-12345" 93 | assert writer.client_secret == "client-secret-12345" 94 | assert writer.batch_size == 50 95 | 96 | def test_init_missing_dce(self): 97 | """Test that missing DCE raises assertion error.""" 98 | options = { 99 | "dcr_id": "dcr-test", 100 | "dcs": "stream", 101 | "tenant_id": "tenant", 102 | "client_id": "client", 103 | "client_secret": "secret", 104 | } 105 | with pytest.raises(AssertionError): 106 | AzureMonitorBatchWriter(options) 107 | 108 | def test_init_missing_dcr_id(self): 109 | """Test that missing DCR ID raises assertion error.""" 110 | options = { 111 | "dce": "https://test.monitor.azure.com", 112 | "dcs": "stream", 113 | "tenant_id": "tenant", 114 | "client_id": "client", 115 | "client_secret": "secret", 116 | } 117 | with pytest.raises(AssertionError): 118 | AzureMonitorBatchWriter(options) 119 | 120 | def test_init_missing_dcs(self): 121 | """Test that missing DCS raises assertion error.""" 122 | options = { 123 | "dce": "https://test.monitor.azure.com", 124 | "dcr_id": "dcr-test", 125 | "tenant_id": "tenant", 126 | "client_id": "client", 127 | "client_secret": "secret", 128 | } 129 | with pytest.raises(AssertionError): 130 | AzureMonitorBatchWriter(options) 131 | 132 | def test_init_missing_tenant_id(self): 133 | """Test that missing tenant ID raises assertion error.""" 134 | options = { 135 | "dce": "https://test.monitor.azure.com", 136 | "dcr_id": "dcr-test", 137 | "dcs": "stream", 138 | "client_id": "client", 139 | "client_secret": "secret", 140 | } 141 | with pytest.raises(AssertionError): 142 | AzureMonitorBatchWriter(options) 143 | 144 | def test_init_missing_client_id(self): 145 | """Test that missing client ID raises assertion error.""" 146 | options = { 147 | "dce": "https://test.monitor.azure.com", 148 | "dcr_id": "dcr-test", 149 | "dcs": "stream", 150 | "tenant_id": "tenant", 151 | "client_secret": "secret", 152 | } 153 | with pytest.raises(AssertionError): 154 | AzureMonitorBatchWriter(options) 155 | 156 | def test_init_missing_client_secret(self): 157 | """Test that missing client secret raises assertion error.""" 158 | options = { 159 | "dce": "https://test.monitor.azure.com", 160 | "dcr_id": "dcr-test", 161 | "dcs": "stream", 162 | "tenant_id": "tenant", 163 | "client_id": "client", 164 | } 165 | with pytest.raises(AssertionError): 166 | AzureMonitorBatchWriter(options) 167 | 168 | def test_init_with_custom_batch_size(self): 169 | """Test initialization with custom batch size.""" 170 | options = { 171 | "dce": "https://test.monitor.azure.com", 172 | "dcr_id": "dcr-test", 173 | "dcs": 
"stream", 174 | "tenant_id": "tenant", 175 | "client_id": "client", 176 | "client_secret": "secret", 177 | "batch_size": "100", 178 | } 179 | writer = AzureMonitorBatchWriter(options) 180 | assert writer.batch_size == 100 181 | 182 | @patch("pyspark.TaskContext") 183 | def test_write_basic(self, mock_task_context, basic_options): 184 | """Test basic write functionality.""" 185 | mock_context = Mock() 186 | mock_context.partitionId.return_value = 0 187 | mock_task_context.get.return_value = mock_context 188 | 189 | with patch("azure.identity.ClientSecretCredential") as mock_credential, patch( 190 | "azure.monitor.ingestion.LogsIngestionClient" 191 | ) as mock_logs_client_class: 192 | mock_credential_instance = Mock() 193 | mock_credential.return_value = mock_credential_instance 194 | 195 | mock_logs_client = Mock() 196 | mock_logs_client.upload = Mock() 197 | mock_logs_client_class.return_value = mock_logs_client 198 | 199 | writer = AzureMonitorBatchWriter(basic_options) 200 | rows = [Row(id=1, name="test")] 201 | commit_msg = writer.write(iter(rows)) 202 | 203 | assert commit_msg.partition_id == 0 204 | assert commit_msg.count == 1 205 | assert mock_logs_client.upload.called 206 | 207 | call_args = mock_logs_client.upload.call_args 208 | assert call_args[1]["rule_id"] == "dcr-test123456789" 209 | assert call_args[1]["stream_name"] == "Custom-TestTable_CL" 210 | assert len(call_args[1]["logs"]) == 1 211 | 212 | @patch("pyspark.TaskContext") 213 | def test_write_with_datetime(self, mock_task_context, basic_options): 214 | """Test write with datetime fields (should be converted to strings).""" 215 | mock_context = Mock() 216 | mock_context.partitionId.return_value = 0 217 | mock_task_context.get.return_value = mock_context 218 | 219 | with patch("azure.identity.ClientSecretCredential") as mock_credential, patch( 220 | "azure.monitor.ingestion.LogsIngestionClient" 221 | ) as mock_logs_client_class: 222 | mock_credential_instance = Mock() 223 | mock_credential.return_value = mock_credential_instance 224 | 225 | mock_logs_client = Mock() 226 | mock_logs_client.upload = Mock() 227 | mock_logs_client_class.return_value = mock_logs_client 228 | 229 | writer = AzureMonitorBatchWriter(basic_options) 230 | timestamp = datetime(2024, 1, 1, 12, 0, 0) 231 | rows = [Row(id=1, timestamp=timestamp)] 232 | commit_msg = writer.write(iter(rows)) 233 | 234 | assert commit_msg.count == 1 235 | call_args = mock_logs_client.upload.call_args 236 | logs = call_args[1]["logs"] 237 | assert isinstance(logs[0]["timestamp"], str) 238 | 239 | @patch("pyspark.TaskContext") 240 | def test_write_batching(self, mock_task_context): 241 | """Test that batching works correctly.""" 242 | mock_context = Mock() 243 | mock_context.partitionId.return_value = 0 244 | mock_task_context.get.return_value = mock_context 245 | 246 | with patch("azure.identity.ClientSecretCredential") as mock_credential, patch( 247 | "azure.monitor.ingestion.LogsIngestionClient" 248 | ) as mock_logs_client_class: 249 | mock_credential_instance = Mock() 250 | mock_credential.return_value = mock_credential_instance 251 | 252 | mock_logs_client = Mock() 253 | mock_logs_client.upload = Mock() 254 | mock_logs_client_class.return_value = mock_logs_client 255 | 256 | options = { 257 | "dce": "https://test.monitor.azure.com", 258 | "dcr_id": "dcr-test", 259 | "dcs": "stream", 260 | "tenant_id": "tenant", 261 | "client_id": "client", 262 | "client_secret": "secret", 263 | "batch_size": "2", 264 | } 265 | writer = AzureMonitorBatchWriter(options) 266 | rows = 
[Row(id=i) for i in range(5)] 267 | commit_msg = writer.write(iter(rows)) 268 | 269 | assert commit_msg.count == 5 270 | # Should be called 3 times: 2+2+1 271 | assert mock_logs_client.upload.call_count == 3 272 | 273 | @patch("pyspark.TaskContext") 274 | def test_write_credential_creation(self, mock_task_context, basic_options): 275 | """Test that credentials are created correctly.""" 276 | mock_context = Mock() 277 | mock_context.partitionId.return_value = 0 278 | mock_task_context.get.return_value = mock_context 279 | 280 | with patch("azure.identity.ClientSecretCredential") as mock_credential, patch( 281 | "azure.monitor.ingestion.LogsIngestionClient" 282 | ) as mock_logs_client_class: 283 | mock_credential_instance = Mock() 284 | mock_credential.return_value = mock_credential_instance 285 | 286 | mock_logs_client = Mock() 287 | mock_logs_client.upload = Mock() 288 | mock_logs_client_class.return_value = mock_logs_client 289 | 290 | writer = AzureMonitorBatchWriter(basic_options) 291 | rows = [Row(id=1)] 292 | writer.write(iter(rows)) 293 | 294 | # Check that ClientSecretCredential was called with correct parameters 295 | mock_credential.assert_called_once_with("tenant-id-12345", "client-id-12345", "client-secret-12345") 296 | 297 | # Check that LogsIngestionClient was created with correct parameters 298 | mock_logs_client_class.assert_called_once_with( 299 | "https://test-dce.monitor.azure.com", mock_credential_instance 300 | ) 301 | 302 | 303 | class TestAzureMonitorStreamWriter: 304 | """Test AzureMonitorStreamWriter functionality.""" 305 | 306 | def test_commit(self, basic_options): 307 | """Test commit method.""" 308 | writer = AzureMonitorStreamWriter(basic_options) 309 | messages = [Mock(count=10), Mock(count=20)] 310 | writer.commit(messages, batchId=1) 311 | 312 | def test_abort(self, basic_options): 313 | """Test abort method.""" 314 | writer = AzureMonitorStreamWriter(basic_options) 315 | messages = [Mock(count=10)] 316 | writer.abort(messages, batchId=1) 317 | -------------------------------------------------------------------------------- /AGENTS.md: -------------------------------------------------------------------------------- 1 | # AGENTS.md 2 | 3 | This file provides guidance to LLMs when working with code in this repository. 4 | 5 | You're an experienced senior developer with a background in Cybersecurity. You're developing a 6 | library of custom Python data sources to simplify working with (reading/writing) data in different 7 | cybersecurity-related products. You follow the architecture and development guidelines 8 | outlined below. 9 | 10 | ## Project Overview 11 | 12 | This repository contains the source code of different custom Python data sources (part of the 13 | Apache Spark API) related to Cybersecurity. These data sources allow reading from or 14 | writing to different Cybersecurity solutions in batch and/or streaming mode. 15 | 16 | **Architecture**: Simple, flat structure with one data source per file. 
Each data source implements: 17 | - `DataSource` base class with `name()`, `reader`, `streamReader`, `writer()`, and `streamWriter()` methods 18 | - Separate writer/reader classes for batch (`DataSourceWriter`, `DataSourceReader`) and streaming (`DataSourceStreamWriter`, `DataSourceStreamReader`) 19 | - Shared writer logic in a common base class (e.g., `SplunkHecWriter`) 20 | 21 | ### Documentation and examples of custom data sources using the same API 22 | 23 | There are a number of publicly available examples that demonstrate how to implement custom Python data sources: 24 | 25 | - https://github.com/databricks/tmm/tree/main/Lakeflow-OpenSkyNetwork 26 | - https://github.com/allisonwang-db/pyspark-data-sources 27 | - https://github.com/databricks-industry-solutions/python-data-sources 28 | - https://github.com/dmatrix/spark-misc/tree/main/src/py/data_source 29 | - https://github.com/huggingface/pyspark_huggingface 30 | - https://github.com/dmoore247/PythonDataSources 31 | - https://github.com/dgomez04/pyspark-hubspot 32 | - https://github.com/dgomez04/pyspark-faker 33 | - https://github.com/skyler-myers-db/activemq_pyspark_connector 34 | - https://github.com/jiteshsoni/ethereum-streaming-pipeline/blob/6e06cdea573780ba09a33a334f7f07539721b85e/ethereum_block_stream_chainstack.py 35 | - https://www.canadiandataguy.com/p/stop-waiting-for-connectors-stream 36 | - https://www.databricks.com/blog/simplify-data-ingestion-new-python-data-source-api 37 | 38 | More information can be found in the documentation: 39 | 40 | - https://docs.databricks.com/aws/en/pyspark/datasources 41 | - https://spark.apache.org/docs/latest/api/python/tutorial/sql/python_data_source.html 42 | 43 | ## Architecture Patterns 44 | 45 | ### Data Source Implementation Pattern 46 | 47 | Most data sources follow this structure (see `cyber_connectors/Splunk.py` as a reference): 48 | 49 | 1. **DataSource class**: Entry point, returns appropriate readers/writers 50 | - Implements `name()` class method (returns format name like "splunk") 51 | - Implements `writer()` for batch write operations 52 | - Implements `streamWriter()` for streaming write operations 53 | - Implements `reader()` for batch read operations 54 | - Implements `streamReader()` for streaming read operations 55 | - Implements `schema` to return a predefined schema for read operations (the schema can also be inferred automatically from the response, but that is slower than using a predefined schema). 56 | 57 | 2. **Base Writer class**: Shared write logic for batch and streaming 58 | - Extracts and validates options in `__init__` 59 | - Implements `write(iterator)` that processes rows and returns `SimpleCommitMessage` 60 | - May batch records before sending (configurable via `batch_size` option) 61 | 62 | 3. **Batch Writer class**: Inherits from base writer + `DataSourceWriter` 63 | - No additional methods needed 64 | 65 | 4. **Stream Writer class**: Inherits from base writer + `DataSourceStreamWriter` 66 | - Implements `commit()` (handles successful batch completion) 67 | - Implements `abort()` (handles failed batch) 68 | 69 | 5. **Base Reader class**: Shared read logic for batch and streaming reads 70 | - Extracts and validates base options in `__init__` 71 | - Implements `partitions()` to distribute reads over multiple executors (where possible). A custom class can be used to specify partition information (it should inherit from `InputPartition`). 72 | - Implements `read` to get data for a specific partition. 73 | 74 | 6. 
**Batch Reader class**: Inherits from base reader + `DataSourceReader`. 75 | - No additional methods needed 76 | 77 | 7. **Stream Reader class**: Inherits from base reader + `DataSourceStreamReader`. 78 | - `initialOffset` - returns the initial offset provided during the first initialization (or inferred automatically). The offset class should implement `json` and `from_json` methods. 79 | - `latestOffset` - returns the latest available offset. 80 | 81 | ### Common Utilities (`cyber_connectors/common.py`) 82 | 83 | - `SimpleCommitMessage`: Dataclass for partition write results (partition_id, count) 84 | - `DateTimeJsonEncoder`: JSON encoder that converts datetime/date to ISO format 85 | - `get_http_session()`: Creates requests.Session with retry logic (5 retries by default, handles 429/5xx errors) 86 | 87 | ### Key Design Principles 88 | 89 | 1. **SIMPLE over CLEVER**: No abstract base classes, factory patterns, or complex inheritance 90 | 2. **EXPLICIT over IMPLICIT**: Direct implementations, no hidden abstractions 91 | 3. **FLAT over NESTED**: Single-level inheritance (DataSource → Writer → Batch/Stream) 92 | 4. **Imports inside methods**: For partition-level execution, import libraries within `write()` methods 93 | 5. **Row-by-row processing**: Iterate rows, batch them, send when buffer full 94 | 95 | ## Adding a New Data Source 96 | 97 | Follow this checklist (use existing sources as templates): 98 | 99 | 1. Create new file `cyber_connectors/YourSource.py` 100 | 2. Implement `YourSourceDataSource(DataSource)` with `name()`, `writer()`, `streamWriter()` 101 | 3. Implement base writer class with: 102 | - Options validation in `__init__` 103 | - `write(iterator)` method with write logic 104 | 4. Implement batch and stream writer classes (minimal boilerplate) 105 | 5. Implement base reader class with: 106 | - Options validation in `__init__` 107 | - `read(partition)` method with read logic 108 | - `partitions(start, end)` method to split data into partitions 109 | 6. Implement batch and stream reader classes (minimal boilerplate) 110 | 7. Add exports to `cyber_connectors/__init__.py` 111 | 8. Create test file `tests/test_yoursource.py` with unit tests 112 | 9. Update README.md with usage examples and options 113 | 114 | ### Data Source Registration 115 | 116 | Users register data sources like this: 117 | ```python 118 | from cyber_connectors import SplunkDataSource 119 | spark.dataSource.register(SplunkDataSource) 120 | 121 | # Then use with .format("splunk") 122 | df.write.format("splunk").option("url", "...").save() 123 | ``` 124 | 125 | ## Current Data Sources 126 | 127 | 1. **Splunk** (`splunk`): Write to Splunk HEC endpoint 128 | - Supports indexed fields, custom event columns, metadata 129 | - See `cyber_connectors/Splunk.py` 130 | 131 | 2. **Microsoft Sentinel** (`ms-sentinel` / `azure-monitor`): Write to and read from Azure Monitor/Sentinel 132 | - Uses Azure service principal authentication 133 | - See `cyber_connectors/MsSentinel.py` 134 | 135 | 3. **REST API** (`rest`): Generic REST API writer 136 | - Supports POST/PUT with JSON payload 137 | - See `cyber_connectors/RestApi.py` 138 | 139 | ## 🚨 SENIOR DEVELOPER GUIDELINES 🚨 140 | 141 | **CRITICAL: This project follows SIMPLE, MAINTAINABLE patterns. 
DO NOT over-engineer!** 142 | 143 | ### Forbidden Patterns (DO NOT ADD THESE) 144 | 145 | - ❌ **Abstract base classes** or complex inheritance hierarchies 146 | - ❌ **Factory patterns** or dependency injection containers 147 | - ❌ **Decorators for cross-cutting concerns** (logging, caching, performance monitoring) 148 | - ❌ **Complex configuration classes** with nested structures 149 | - ❌ **Async/await patterns** unless absolutely necessary 150 | - ❌ **Connection pooling** or caching layers 151 | - ❌ **Generic "framework" code** or reusable utilities 152 | - ❌ **Complex error handling systems** or custom exceptions 153 | - ❌ **Performance optimization** patterns (premature optimization) 154 | - ❌ **Enterprise patterns** like singleton, observer, strategy, etc. 155 | 156 | ### Required Patterns (ALWAYS USE THESE) 157 | - ✅ **Direct function calls** - no indirection or abstraction layers 158 | - ✅ **Simple classes** with clear, single responsibilities 159 | - ✅ **Environment variables** for configuration (no complex config objects) 160 | - ✅ **Explicit imports** - import exactly what you need 161 | - ✅ **Basic error handling** with try/catch and simple return dictionaries 162 | - ✅ **Straightforward control flow** - avoid complex conditional logic 163 | - ✅ **Standard library first** - only add dependencies when absolutely necessary 164 | 165 | ### Implementation Rules 166 | 167 | 1. **One concept per file**: Each module should have a single, clear purpose 168 | 2. **Functions over classes**: Prefer functions unless you need state management 169 | 3. **Direct SDK calls**: Call Databricks SDK directly, no wrapper layers 170 | 4. **Simple data structures**: Use dicts and lists, avoid custom data classes 171 | 5. **Basic testing**: Simple unit tests with basic mocking, no complex test frameworks 172 | 6. **Minimal dependencies**: Only add new dependencies if critically needed 173 | 174 | ### Code Review Questions 175 | 176 | Before adding any code, ask yourself: 177 | - "Is this the simplest way to solve this problem?" 178 | - "Would a new developer understand this immediately?" 179 | - "Am I adding abstraction for a real need or hypothetical flexibility?" 180 | - "Can I solve this with standard library or existing dependencies?" 181 | - "Does this follow the existing patterns in the codebase?" 182 | 183 | ## Development Commands 184 | 185 | ### Python Execution Rules 186 | 187 | **CRITICAL: Always use `poetry run` instead of direct `python`:** 188 | ```bash 189 | # ✅ CORRECT 190 | poetry run python script.py 191 | 192 | # ❌ WRONG 193 | python script.py 194 | ``` 195 | 196 | ## Development Workflow 197 | 198 | ### Package Management 199 | 200 | - **Python**: Use `poetry add/remove` for dependencies, never edit `pyproject.toml` manually 201 | - Always check if dependencies already exist before adding new ones 202 | - **Principle**: Only add dependencies if absolutely critical 203 | 204 | ### Setup 205 | ```bash 206 | # Install dependencies (first time) 207 | poetry install 208 | 209 | # Activate environment 210 | . 
$(poetry env info -p)/bin/activate 211 | ``` 212 | 213 | ### Testing 214 | ```bash 215 | # Run all tests 216 | poetry run pytest 217 | 218 | # Run specific test file 219 | poetry run pytest tests/test_splunk.py 220 | 221 | # Run single test 222 | poetry run pytest tests/test_splunk.py::TestSplunkDataSource::test_name 223 | 224 | # Run with verbose output 225 | poetry run pytest -v 226 | ``` 227 | 228 | ### Building 229 | ```bash 230 | # Build wheel package 231 | poetry build 232 | 233 | # Output will be in dist/ directory 234 | ``` 235 | 236 | ### Code Quality 237 | ```bash 238 | # Format and lint code (ruff) 239 | poetry run ruff check cyber_connectors/ 240 | poetry run ruff format cyber_connectors/ 241 | 242 | # Type checking 243 | poetry run mypy cyber_connectors/ 244 | ``` 245 | 246 | ## Testing Guidelines 247 | 248 | - Tests use `pytest` with `pytest-spark` for Spark session fixtures 249 | - Mock external HTTP calls using `unittest.mock.patch` 250 | - Test writer initialization, option validation, and data processing logic 251 | - See `tests/test_splunk.py` for comprehensive examples 252 | 253 | **Test structure**: 254 | - Use fixtures for common setup (`basic_options`, `sample_schema`) 255 | - Test data source name registration 256 | - Test writer instantiation (batch and streaming) 257 | - Test option validation (required vs optional parameters) 258 | - Mock HTTP responses to test write operations 259 | 260 | ## Important Notes 261 | 262 | - **Python version**: 3.10-3.13 (defined in `pyproject.toml`) 263 | - **Spark version**: 4.0.1+ required (PySpark DataSource API) 264 | - **Dependencies**: Keep minimal - only add if critically needed 265 | - **Never use direct `python` commands**: Always use `poetry run python` 266 | - **Ruff configuration**: Line length 120, enforces docstrings, isort, flake8-bugbear 267 | - **No premature optimization**: Focus on clarity over performance 268 | 269 | ## Summary: What Makes This Project "Senior Developer Approved" 270 | 271 | - **Readable**: Any developer can understand the code immediately 272 | - **Maintainable**: Simple patterns that are easy to modify 273 | - **Focused**: Each module has a single, clear responsibility 274 | - **Direct**: No unnecessary abstractions or indirection 275 | - **Practical**: Solves the specific problem without over-engineering 276 | 277 | When in doubt, choose the **simpler** solution. Your future self (and your teammates) will thank you. 278 | 279 | --- 280 | 281 | ## Important Instruction Reminders 282 | 283 | **For an agent when working on this project:** 284 | 285 | 1. **Do what has been asked; nothing more, nothing less** 286 | 2. **NEVER create files unless absolutely necessary for achieving the goal** 287 | 3. **ALWAYS prefer editing an existing file to creating a new one** 288 | 4. **NEVER proactively create documentation files (*.md) or README files** 289 | 5. **Follow the SIMPLE patterns established in this codebase** 290 | 6. **When in doubt, ask "Is this the simplest way?" before implementing** 291 | 292 | This project is intentionally simplified. **Respect that simplicity.** 293 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Custom data sources/sinks for Cybersecurity-related work 2 | 3 | > [!WARNING] 4 | > **Experimental! 
Work in progress** 5 | 6 | Based on [PySpark DataSource API](https://spark.apache.org/docs/preview/api/python/user_guide/sql/python_data_source.html) available with Spark 4 & [DBR 15.3+](https://docs.databricks.com/en/pyspark/datasources.html). 7 | 8 | - [Custom data sources/sinks for Cybersecurity-related work](#custom-data-sourcessinks-for-cybersecurity-related-work) 9 | - [Available data sources](#available-data-sources) 10 | - [Splunk data source](#splunk-data-source) 11 | - [Microsoft Sentinel / Azure Monitor](#microsoft-sentinel--azure-monitor) 12 | - [Authentication Requirements](#authentication-requirements) 13 | - [Writing to Microsoft Sentinel / Azure Monitor](#writing-to-microsoft-sentinel--azure-monitor) 14 | - [Reading from Microsoft Sentinel / Azure Monitor](#reading-from-microsoft-sentinel--azure-monitor) 15 | - [Simple REST API](#simple-rest-api) 16 | - [Building](#building) 17 | - [References](#references) 18 | 19 | 20 | # Available data sources 21 | 22 | > [!NOTE] 23 | > Most of these data sources/sinks are designed to work with relatively small amounts of data - alerts, etc. If you need to read or write huge amounts of data, use the native export/import functionality of the corresponding external system. 24 | 25 | ## Splunk data source 26 | 27 | Currently this data source only implements writing to [Splunk](https://www.splunk.com/) - both batch & streaming. The registered data source name is `splunk`. 28 | 29 | By default, this data source will put all columns into the `event` object and send it to Splunk together with metadata (`index`, `source`, ...). This behavior can be changed by providing the `single_event_column` option to specify which string column should be used as the single value of `event`. 30 | 31 | Batch usage: 32 | 33 | ```python 34 | from cyber_connectors import * 35 | spark.dataSource.register(SplunkDataSource) 36 | 37 | df = spark.range(10) 38 | df.write.format("splunk").mode("overwrite") \ 39 | .option("url", "http://localhost:8088/services/collector/event") \ 40 | .option("token", "...").save() 41 | ``` 42 | 43 | Streaming usage: 44 | 45 | ```python 46 | from cyber_connectors import * 47 | spark.dataSource.register(SplunkDataSource) 48 | 49 | dir_name = "tests/samples/json/" 50 | bdf = spark.read.format("json").load(dir_name) # to infer schema - don't use in prod! 51 | 52 | sdf = spark.readStream.format("json").schema(bdf.schema).load(dir_name) 53 | 54 | stream_options = { 55 | "url": "http://localhost:8088/services/collector/event", 56 | "token": "....", 57 | "source": "zeek", 58 | "index": "zeek", 59 | "host": "my_host", 60 | "time_column": "ts", 61 | "checkpointLocation": "/tmp/splunk-checkpoint/" 62 | } 63 | stream = sdf.writeStream.format("splunk") \ 64 | .trigger(availableNow=True) \ 65 | .options(**stream_options).start() 66 | ``` 67 | 68 | Supported options: 69 | 70 | - `url` (string, required) - URL of the Splunk HTTP Event Collector (HEC) endpoint to send data to. For example, `http://localhost:8088/services/collector/event`. 71 | - `token` (string, required) - HEC token to [authenticate to the HEC endpoint](https://docs.splunk.com/Documentation/Splunk/9.3.1/Data/FormateventsforHTTPEventCollector#HTTP_authentication). 72 | - `index` (string, optional) - name of the Splunk index to send data to. If omitted, the default index configured for the HEC endpoint is used. 73 | - `source` (string, optional) - the source value to assign to the event data. 74 | - `host` (string, optional) - the host value to assign to the event data. 
75 | - `sourcetype` (string, optional, default: `_json`) - the sourcetype value to assign to the event data. 76 | - `single_event_column` (string, optional) - specify which string column will be used as the `event` payload. Typically this is used to ingest log file content. 77 | - `time_column` (string, optional) - specify which column to use as the event time value (the `time` value in the Splunk payload). Supported data types: `timestamp`, `float`, `int`, `long` (`float`/`int`/`long` values are treated as seconds since epoch). If not specified, the current timestamp will be used. 78 | - `indexed_fields` (string, optional) - comma-separated list of string columns to be [indexed at ingestion time](http://docs.splunk.com/Documentation/Splunk/9.3.1/Data/IFXandHEC). 79 | - `remove_indexed_fields` (boolean, optional, default: `false`) - whether indexed fields should be removed from the `event` object. 80 | - `batch_size` (int, optional, default: 50) - the size of the buffer to collect payload before sending to Splunk. 81 | 82 | ## Microsoft Sentinel / Azure Monitor 83 | 84 | This data source supports both reading from and writing to [Microsoft Sentinel](https://learn.microsoft.com/en-us/azure/sentinel/overview/) / [Azure Monitor Log Analytics](https://learn.microsoft.com/en-us/azure/azure-monitor/logs/log-analytics-overview). The registered data source names are `ms-sentinel` and `azure-monitor`. 85 | 86 | ### Authentication Requirements 87 | 88 | This connector uses an Azure Service Principal Client ID/Secret for authentication. 89 | 90 | The service principal needs the following permissions: 91 | - For reading: **Log Analytics Reader** role on the Log Analytics workspace 92 | - For writing: **Monitoring Metrics Publisher** role on the DCE and DCR 93 | 94 | 95 | ### Writing to Microsoft Sentinel / Azure Monitor 96 | 97 | The integration uses the [Logs Ingestion API of Azure Monitor](https://learn.microsoft.com/en-us/azure/sentinel/create-custom-connector#connect-with-the-log-ingestion-api) for writing data. 98 | 99 | To push data you need to create a Data Collection Endpoint (DCE), a Data Collection Rule (DCR), and a custom table in the Log Analytics workspace. See the [documentation](https://learn.microsoft.com/en-us/azure/azure-monitor/logs/logs-ingestion-api-overview) for a description of this process. The structure of the data in the DataFrame should match the structure of the defined custom table. 100 | 101 | You need to grant the correct permissions (`Monitoring Metrics Publisher`) to the service principal on the DCE and DCR. 102 | 103 | Batch write usage: 104 | 105 | ```python 106 | from cyber_connectors import * 107 | spark.dataSource.register(MicrosoftSentinelDataSource) 108 | 109 | sentinel_options = { 110 | "dce": dc_endpoint, 111 | "dcr_id": dc_rule_id, 112 | "dcs": dc_stream_name, 113 | "tenant_id": tenant_id, 114 | "client_id": client_id, 115 | "client_secret": client_secret, 116 | } 117 | 118 | df = spark.range(10) 119 | df.write.format("ms-sentinel") \ 120 | .mode("overwrite") \ 121 | .options(**sentinel_options) \ 122 | .save() 123 | ``` 124 | 125 | Streaming write usage: 126 | 127 | ```python 128 | from cyber_connectors import * 129 | spark.dataSource.register(MicrosoftSentinelDataSource) 130 | 131 | dir_name = "tests/samples/json/" 132 | bdf = spark.read.format("json").load(dir_name) # to infer schema - don't use in prod! 
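# In production, pass an explicit schema instead of inferring it from the files. The schema
# below is only an illustrative sketch based on the bundled zeek-http.jsonl sample - adjust
# field names and types to match your DCR stream definition:
#   from pyspark.sql.types import DoubleType, LongType, StringType, StructField, StructType
#   explicit_schema = StructType([
#       StructField("ts", DoubleType()),
#       StructField("uid", StringType()),
#       StructField("status_code", LongType()),
#   ])
#   sdf = spark.readStream.format("json").schema(explicit_schema).load(dir_name)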
133 | 134 | sdf = spark.readStream.format("json").schema(bdf.schema).load(dir_name) 135 | 136 | sentinel_stream_options = { 137 | "dce": dc_endpoint, 138 | "dcr_id": dc_rule_id, 139 | "dcs": dc_stream_name, 140 | "tenant_id": tenant_id, 141 | "client_id": client_id, 142 | "client_secret": client_secret, 143 | "checkpointLocation": "/tmp/sentinel-checkpoint/" 144 | } 145 | stream = sdf.writeStream.format("ms-sentinel") \ 146 | .trigger(availableNow=True) \ 147 | .options(**sentinel_stream_options).start() 148 | ``` 149 | 150 | Supported write options: 151 | 152 | - `dce` (string, required) - URL of the Data Collection Endpoint. 153 | - `dcr_id` (string, required) - ID of the Data Collection Rule. 154 | - `dcs` (string, required) - name of the custom table created in the Log Analytics Workspace. 155 | - `tenant_id` (string, required) - Azure Tenant ID. 156 | - `client_id` (string, required) - Application ID (client ID) of the Azure Service Principal. 157 | - `client_secret` (string, required) - Client Secret of the Azure Service Principal. 158 | - `batch_size` (int, optional, default: 50) - the size of the buffer to collect payload before sending to MS Sentinel. 159 | 160 | ### Reading from Microsoft Sentinel / Azure Monitor 161 | 162 | The data source supports both batch and streaming reads from Azure Monitor / Log Analytics workspaces using KQL (Kusto Query Language) queries. If the schema isn't specified with `.schema`, it will be inferred automatically. 163 | 164 | > [!NOTE] 165 | > For streaming reads of large amounts of data, it's recommended to export the necessary tables to Event Hubs and consume from there. 166 | 167 | #### Batch Read 168 | 169 | Batch read usage: 170 | 171 | ```python 172 | from cyber_connectors import * 173 | spark.dataSource.register(AzureMonitorDataSource) 174 | 175 | # Option 1: Using timespan (ISO 8601 duration) 176 | read_options = { 177 | "workspace_id": "your-workspace-id", 178 | "query": "AzureActivity | where TimeGenerated > ago(1d) | take 100", 179 | "timespan": "P1D", # ISO 8601 duration: 1 day 180 | "tenant_id": tenant_id, 181 | "client_id": client_id, 182 | "client_secret": client_secret, 183 | } 184 | 185 | # Option 2: Using start_time and end_time (ISO 8601 timestamps) 186 | read_options = { 187 | "workspace_id": "your-workspace-id", 188 | "query": "AzureActivity | take 100", 189 | "start_time": "2024-01-01T00:00:00Z", 190 | "end_time": "2024-01-02T00:00:00Z", 191 | "tenant_id": tenant_id, 192 | "client_id": client_id, 193 | "client_secret": client_secret, 194 | } 195 | 196 | # Option 3: Using only start_time (end_time defaults to current time) 197 | read_options = { 198 | "workspace_id": "your-workspace-id", 199 | "query": "AzureActivity | take 100", 200 | "start_time": "2024-01-01T00:00:00Z", # Query from start_time to now 201 | "tenant_id": tenant_id, 202 | "client_id": client_id, 203 | "client_secret": client_secret, 204 | } 205 | 206 | df = spark.read.format("azure-monitor") \ 207 | .options(**read_options) \ 208 | .load() 209 | 210 | df.show() 211 | ``` 212 | 213 | Supported read options: 214 | 215 | - `workspace_id` (string, required) - Log Analytics workspace ID 216 | - `query` (string, required) - KQL query to execute 217 | - **Time range options (choose one approach):** 218 | - `timespan` (string) - Time range in ISO 8601 duration format (e.g., "P1D" = 1 day, "PT1H" = 1 hour, "P7D" = 7 days) 219 | - `start_time` (string) - Start time in ISO 8601 format (e.g., "2024-01-01T00:00:00Z"). 
If provided without `end_time`, queries from `start_time` to the current time 220 | - `end_time` (string, optional) - End time in ISO 8601 format. Only valid when `start_time` is specified 221 | - **Note**: `timespan` and `start_time/end_time` are mutually exclusive - choose one approach 222 | - `tenant_id` (string, required) - Azure Tenant ID 223 | - `client_id` (string, required) - Application ID (client ID) of the Azure Service Principal 224 | - `client_secret` (string, required) - Client Secret of the Azure Service Principal 225 | - `num_partitions` (int, optional, default: 1) - Number of partitions for reading data 226 | - `inferSchema` (bool, optional, default: true) - whether to infer the schema by sampling the query result. 227 | 228 | **KQL Query Examples:** 229 | 230 | ```python 231 | # Get recent Azure Activity logs 232 | query = "AzureActivity | where TimeGenerated > ago(24h) | project TimeGenerated, OperationName, ResourceGroup" 233 | 234 | # Get security alerts 235 | query = "SecurityAlert | where TimeGenerated > ago(7d) | project TimeGenerated, AlertName, Severity" 236 | 237 | # Custom table query 238 | query = "MyCustomTable_CL | where TimeGenerated > ago(1h)" 239 | ``` 240 | 241 | #### Streaming Read 242 | 243 | The data source supports streaming reads from Azure Monitor / Log Analytics. The streaming reader uses time-based offsets to track progress and splits time ranges into partitions for parallel processing. 244 | 245 | Streaming read usage: 246 | 247 | ```python 248 | from cyber_connectors import * 249 | spark.dataSource.register(AzureMonitorDataSource) 250 | 251 | # Stream from a specific timestamp 252 | stream_options = { 253 | "workspace_id": "your-workspace-id", 254 | "query": "AzureActivity | project TimeGenerated, OperationName, ResourceGroup", 255 | "start_time": "2024-01-01T00:00:00Z", # Start streaming from this timestamp 256 | "tenant_id": tenant_id, 257 | "client_id": client_id, 258 | "client_secret": client_secret, 259 | "checkpointLocation": "/tmp/azure-monitor-checkpoint/", 260 | "partition_duration": "3600", # Optional: partition size in seconds (default 1 hour) 261 | } 262 | 263 | # Read stream 264 | stream_df = spark.readStream.format("azure-monitor") \ 265 | .options(**stream_options) \ 266 | .load() 267 | 268 | # Write to console or another sink 269 | query = stream_df.writeStream \ 270 | .format("console") \ 271 | .trigger(availableNow=True) \ 272 | .option("checkpointLocation", "/tmp/azure-monitor-checkpoint/") \ 273 | .start() 274 | 275 | query.awaitTermination() 276 | ``` 277 | 278 | Supported streaming read options: 279 | 280 | - `workspace_id` (string, required) - Log Analytics workspace ID 281 | - `query` (string, required) - KQL query to execute (should not include time filters - these are added automatically) 282 | - `start_time` (string, optional, default: "latest") - Start time in ISO 8601 format (e.g., "2024-01-01T00:00:00Z"). 
Use "latest" to start from the current time 283 | - `partition_duration` (int, optional, default: 3600) - Duration in seconds for each partition (controls parallelism) 284 | - `tenant_id` (string, required) - Azure Tenant ID 285 | - `client_id` (string, required) - Application ID (client ID) of the Azure Service Principal 286 | - `client_secret` (string, required) - Client Secret of the Azure Service Principal 287 | - `checkpointLocation` (string, required) - Directory path for Spark streaming checkpoints 288 | 289 | **Important notes for streaming:** 290 | - The reader automatically tracks the timestamp of the last processed data in checkpoints 291 | - Time ranges are split into partitions based on `partition_duration` for parallel processing 292 | - The query should NOT include time filters (e.g., `where TimeGenerated > ago(1d)`) - the reader adds these automatically based on offsets 293 | - Use `start_time: "latest"` to begin streaming from the current time (useful for monitoring real-time data) 294 | 295 | ## Simple REST API 296 | 297 | Currently this data source only implements writing to an arbitrary REST API - both batch & streaming. The registered data source name is `rest`. 298 | 299 | Usage: 300 | 301 | ```python 302 | from cyber_connectors import * 303 | 304 | spark.dataSource.register(RestApiDataSource) 305 | 306 | df = spark.range(10) 307 | df.write.format("rest").mode("overwrite") \ 308 | .option("url", "http://localhost:8001/") \ 309 | .save() 310 | ``` 311 | 312 | Supported options: 313 | 314 | - `url` (string, required) - URL of the REST API endpoint to send data to. 315 | - `http_format` (string, optional, default: `json`) - what payload format to use (currently only `json` is supported). 316 | - `http_method` (string, optional, default: `post`) - what HTTP method to use (`post` or `put`). 317 | 318 | This data source can easily be used to write to a Tines webhook. Just specify the [Tines webhook URL](https://www.tines.com/docs/actions/types/webhook/#secrets-in-url) as the `url` option: 319 | 320 | ```python 321 | df.write.format("rest").mode("overwrite") \ 322 | .option("url", "https://tenant.tines.com/webhook//") \ 323 | .save() 324 | ``` 325 | 326 | # Building 327 | 328 | This project uses [Poetry](https://python-poetry.org/) to manage dependencies and build the package. 329 | 330 | Initial setup & build: 331 | 332 | - Install Poetry 333 | - Set the Poetry environment with `poetry env use 3.10` (or a higher Python version) 334 | - Activate the Poetry environment with `. $(poetry env info -p)/bin/activate` 335 | - Build the wheel file with `poetry build`. The generated file will be stored in the `dist` directory. 336 | 337 | > [!CAUTION] 338 | > Right now, some dependencies aren't included in the manifest, so if you try it with OSS Spark, you will need to make sure that you have the following dependencies installed: `pyspark[sql]` (version `4.0.0.dev2` or higher), `grpcio` (`>=1.48,<1.57`), `grpcio-status` (`>=1.48,<1.57`), `googleapis-common-protos` (`1.56.4`). 
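After installing the built wheel (plus the dependencies listed above) into an OSS Spark environment, a minimal smoke test like the following sketch can confirm that the data sources import and register correctly (the local session settings here are illustrative, not part of the project):

```python
from pyspark.sql import SparkSession

from cyber_connectors import RestApiDataSource, SplunkDataSource

# Local session just for the smoke test; use your own Spark setup in practice.
spark = SparkSession.builder.master("local[1]").appName("cyber-connectors-smoke-test").getOrCreate()

# Registering the classes is enough to make .format("splunk") / .format("rest") available.
spark.dataSource.register(SplunkDataSource)
spark.dataSource.register(RestApiDataSource)

print("registered formats:", SplunkDataSource.name(), RestApiDataSource.name())
spark.stop()
```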
339 | 340 | 341 | # References 342 | 343 | - Splunk: [Format events for HTTP Event Collector](https://docs.splunk.com/Documentation/Splunk/9.3.1/Data/FormateventsforHTTPEventCollector) 344 | 345 | 346 | 347 | -------------------------------------------------------------------------------- /tests/test_mssentinel_stream_reader.py: -------------------------------------------------------------------------------- 1 | """Unit tests for Azure Monitor Stream Reader.""" 2 | 3 | from datetime import datetime, timezone 4 | from unittest.mock import Mock, patch 5 | 6 | import pytest 7 | from pyspark.sql.types import StringType, StructField, StructType 8 | 9 | from cyber_connectors.MsSentinel import ( 10 | AzureMonitorDataSource, 11 | AzureMonitorOffset, 12 | AzureMonitorStreamReader, 13 | ) 14 | 15 | 16 | class TestAzureMonitorDataSourceStream: 17 | """Test data source stream reader registration.""" 18 | 19 | @pytest.fixture 20 | def stream_options(self): 21 | """Basic valid options for stream reader.""" 22 | return { 23 | "workspace_id": "test-workspace-id", 24 | "query": "AzureActivity", 25 | "start_time": "2024-01-01T00:00:00Z", 26 | "tenant_id": "test-tenant", 27 | "client_id": "test-client", 28 | "client_secret": "test-secret", 29 | } 30 | 31 | def test_stream_reader_method_exists(self, stream_options): 32 | """Test that streamReader method exists and returns reader.""" 33 | ds = AzureMonitorDataSource(options=stream_options) 34 | schema = StructType([StructField("test", StringType(), True)]) 35 | 36 | # Should not raise exception (reader creation will validate options) 37 | reader = ds.streamReader(schema) 38 | assert reader is not None 39 | assert isinstance(reader, AzureMonitorStreamReader) 40 | 41 | 42 | class TestAzureMonitorOffset: 43 | """Test AzureMonitorOffset serialization and deserialization.""" 44 | 45 | def test_offset_initialization(self): 46 | """Test offset initializes with timestamp.""" 47 | offset = AzureMonitorOffset("2024-01-01T00:00:00Z") 48 | assert offset.timestamp == "2024-01-01T00:00:00Z" 49 | 50 | def test_offset_json_serialization(self): 51 | """Test offset JSON serialization.""" 52 | offset = AzureMonitorOffset("2024-01-01T00:00:00Z") 53 | json_str = offset.json() 54 | assert json_str is not None 55 | assert "2024-01-01T00:00:00Z" in json_str 56 | assert "timestamp" in json_str 57 | 58 | def test_offset_json_deserialization(self): 59 | """Test offset JSON deserialization.""" 60 | offset = AzureMonitorOffset("2024-01-01T00:00:00Z") 61 | json_str = offset.json() 62 | restored = AzureMonitorOffset.from_json(json_str) 63 | assert restored.timestamp == offset.timestamp 64 | 65 | def test_offset_roundtrip(self): 66 | """Test offset roundtrip serialization/deserialization.""" 67 | original_timestamp = "2024-12-31T23:59:59Z" 68 | offset1 = AzureMonitorOffset(original_timestamp) 69 | json_str = offset1.json() 70 | offset2 = AzureMonitorOffset.from_json(json_str) 71 | assert offset2.timestamp == original_timestamp 72 | 73 | 74 | class TestAzureMonitorStreamReader: 75 | """Test Azure Monitor Stream Reader implementation.""" 76 | 77 | @pytest.fixture 78 | def stream_options(self): 79 | """Basic valid options for stream reader.""" 80 | return { 81 | "workspace_id": "test-workspace-id", 82 | "query": "AzureActivity", 83 | "start_time": "2024-01-01T00:00:00Z", 84 | "tenant_id": "test-tenant", 85 | "client_id": "test-client", 86 | "client_secret": "test-secret", 87 | } 88 | 89 | @pytest.fixture 90 | def basic_schema(self): 91 | """Basic schema for testing.""" 92 | return StructType( 
93 | [StructField("TimeGenerated", StringType(), True), StructField("OperationName", StringType(), True)] 94 | ) 95 | 96 | def test_stream_reader_initialization(self, stream_options, basic_schema): 97 | """Test stream reader initializes with valid options.""" 98 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 99 | 100 | assert reader.workspace_id == "test-workspace-id" 101 | assert reader.query == "AzureActivity" 102 | assert reader.start_time == "2024-01-01T00:00:00Z" 103 | assert reader.tenant_id == "test-tenant" 104 | assert reader.client_id == "test-client" 105 | assert reader.client_secret == "test-secret" 106 | assert reader.partition_duration == 3600 # default 1 hour 107 | 108 | def test_stream_reader_custom_partition_duration(self, stream_options, basic_schema): 109 | """Test stream reader with custom partition_duration.""" 110 | stream_options["partition_duration"] = "1800" # 30 minutes 111 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 112 | assert reader.partition_duration == 1800 113 | 114 | def test_stream_reader_start_time_latest(self, basic_schema): 115 | """Test stream reader with start_time='latest' defaults to current time.""" 116 | options = { 117 | "workspace_id": "test-workspace-id", 118 | "query": "AzureActivity", 119 | # start_time not provided - should default to 'latest' 120 | "tenant_id": "test-tenant", 121 | "client_id": "test-client", 122 | "client_secret": "test-secret", 123 | } 124 | 125 | reader = AzureMonitorStreamReader(options, basic_schema) 126 | assert reader.start_time is not None 127 | # Should be in ISO format 128 | assert "T" in reader.start_time 129 | 130 | def test_stream_reader_start_time_latest_explicit(self, basic_schema): 131 | """Test stream reader with explicit start_time='latest'.""" 132 | options = { 133 | "workspace_id": "test-workspace-id", 134 | "query": "AzureActivity", 135 | "start_time": "latest", 136 | "tenant_id": "test-tenant", 137 | "client_id": "test-client", 138 | "client_secret": "test-secret", 139 | } 140 | 141 | reader = AzureMonitorStreamReader(options, basic_schema) 142 | assert reader.start_time is not None 143 | assert "T" in reader.start_time 144 | 145 | def test_stream_reader_invalid_start_time(self, basic_schema): 146 | """Test stream reader fails with invalid start_time format.""" 147 | options = { 148 | "workspace_id": "test-workspace-id", 149 | "query": "AzureActivity", 150 | "start_time": "invalid-timestamp", 151 | "tenant_id": "test-tenant", 152 | "client_id": "test-client", 153 | "client_secret": "test-secret", 154 | } 155 | 156 | with pytest.raises(ValueError, match="Invalid start_time format"): 157 | AzureMonitorStreamReader(options, basic_schema) 158 | 159 | def test_stream_reader_missing_workspace_id(self, basic_schema): 160 | """Test stream reader fails without workspace_id.""" 161 | options = { 162 | "query": "AzureActivity", 163 | "start_time": "2024-01-01T00:00:00Z", 164 | "tenant_id": "test-tenant", 165 | "client_id": "test-client", 166 | "client_secret": "test-secret", 167 | } 168 | 169 | with pytest.raises(AssertionError, match="workspace_id is required"): 170 | AzureMonitorStreamReader(options, basic_schema) 171 | 172 | def test_stream_reader_missing_query(self, basic_schema): 173 | """Test stream reader fails without query.""" 174 | options = { 175 | "workspace_id": "test-workspace-id", 176 | "start_time": "2024-01-01T00:00:00Z", 177 | "tenant_id": "test-tenant", 178 | "client_id": "test-client", 179 | "client_secret": "test-secret", 180 | } 181 | 182 | with 
pytest.raises(AssertionError, match="query is required"): 183 | AzureMonitorStreamReader(options, basic_schema) 184 | 185 | def test_stream_reader_missing_tenant_id(self, basic_schema): 186 | """Test stream reader fails without tenant_id.""" 187 | options = { 188 | "workspace_id": "test-workspace-id", 189 | "query": "AzureActivity", 190 | "start_time": "2024-01-01T00:00:00Z", 191 | "client_id": "test-client", 192 | "client_secret": "test-secret", 193 | } 194 | 195 | with pytest.raises(AssertionError, match="tenant_id is required"): 196 | AzureMonitorStreamReader(options, basic_schema) 197 | 198 | def test_stream_reader_missing_client_id(self, basic_schema): 199 | """Test stream reader fails without client_id.""" 200 | options = { 201 | "workspace_id": "test-workspace-id", 202 | "query": "AzureActivity", 203 | "start_time": "2024-01-01T00:00:00Z", 204 | "tenant_id": "test-tenant", 205 | "client_secret": "test-secret", 206 | } 207 | 208 | with pytest.raises(AssertionError, match="client_id is required"): 209 | AzureMonitorStreamReader(options, basic_schema) 210 | 211 | def test_stream_reader_missing_client_secret(self, basic_schema): 212 | """Test stream reader fails without client_secret.""" 213 | options = { 214 | "workspace_id": "test-workspace-id", 215 | "query": "AzureActivity", 216 | "start_time": "2024-01-01T00:00:00Z", 217 | "tenant_id": "test-tenant", 218 | "client_id": "test-client", 219 | } 220 | 221 | with pytest.raises(AssertionError, match="client_secret is required"): 222 | AzureMonitorStreamReader(options, basic_schema) 223 | 224 | def test_initial_offset(self, stream_options, basic_schema): 225 | """Test initial offset returns start_time as JSON string.""" 226 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 227 | offset_json = reader.initialOffset() 228 | 229 | # Should return JSON string 230 | assert isinstance(offset_json, str) 231 | assert "timestamp" in offset_json 232 | 233 | # Deserialize and verify 234 | offset = AzureMonitorOffset.from_json(offset_json) 235 | assert offset.timestamp == "2024-01-01T00:00:00Z" 236 | 237 | def test_latest_offset(self, stream_options, basic_schema): 238 | """Test latest offset returns current time as JSON string.""" 239 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 240 | offset_json = reader.latestOffset() 241 | 242 | # Should return JSON string 243 | assert isinstance(offset_json, str) 244 | assert "timestamp" in offset_json 245 | 246 | # Deserialize and verify 247 | offset = AzureMonitorOffset.from_json(offset_json) 248 | assert offset.timestamp is not None 249 | # Should be in ISO format 250 | assert "T" in offset.timestamp 251 | # Should be recent (within last minute) 252 | offset_time = datetime.fromisoformat(offset.timestamp.replace("Z", "+00:00")) 253 | now = datetime.now(timezone.utc) 254 | time_diff = (now - offset_time).total_seconds() 255 | assert time_diff < 60 # Less than 1 minute difference 256 | 257 | def test_partitions_single_partition(self, stream_options, basic_schema): 258 | """Test partitions with time range smaller than partition_duration.""" 259 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 260 | 261 | start_offset = AzureMonitorOffset("2024-01-01T00:00:00Z").json() 262 | end_offset = AzureMonitorOffset("2024-01-01T00:30:00Z").json() # 30 minutes (< 1 hour) 263 | 264 | partitions = reader.partitions(start_offset, end_offset) 265 | 266 | # Should have single partition 267 | assert len(partitions) == 1 268 | assert partitions[0].start_time.isoformat() == 
"2024-01-01T00:00:00+00:00" 269 | assert partitions[0].end_time.isoformat() == "2024-01-01T00:30:00+00:00" 270 | 271 | def test_partitions_multiple_partitions(self, stream_options, basic_schema): 272 | """Test partitions splits time range into multiple partitions.""" 273 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 274 | 275 | start_offset = AzureMonitorOffset("2024-01-01T00:00:00Z").json() 276 | end_offset = AzureMonitorOffset("2024-01-01T03:00:00Z").json() # 3 hours 277 | 278 | partitions = reader.partitions(start_offset, end_offset) 279 | 280 | # Should have 3 partitions (1 hour each) 281 | assert len(partitions) == 3 282 | 283 | # Verify partition time ranges 284 | # First partition: exact start and end at 1-hour boundary 285 | assert partitions[0].start_time.isoformat() == "2024-01-01T00:00:00+00:00" 286 | assert partitions[0].end_time.isoformat() == "2024-01-01T01:00:00+00:00" 287 | 288 | # Second partition: starts 1 microsecond after first ends (to avoid overlap) 289 | # End time = start_time + 1 hour 290 | assert partitions[1].start_time.isoformat() == "2024-01-01T01:00:00.000001+00:00" 291 | assert partitions[1].end_time.isoformat() == "2024-01-01T02:00:00.000001+00:00" 292 | 293 | # Third partition: starts 1 microsecond after second ends 294 | # End time = min(start_time + 1 hour, end_offset) = end_offset 295 | assert partitions[2].start_time.isoformat() == "2024-01-01T02:00:00.000002+00:00" 296 | assert partitions[2].end_time.isoformat() == "2024-01-01T03:00:00+00:00" 297 | 298 | def test_partitions_custom_duration(self, stream_options, basic_schema): 299 | """Test partitions with custom partition_duration.""" 300 | stream_options["partition_duration"] = "1800" # 30 minutes 301 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 302 | 303 | start_offset = AzureMonitorOffset("2024-01-01T00:00:00Z").json() 304 | end_offset = AzureMonitorOffset("2024-01-01T01:00:00Z").json() # 1 hour 305 | 306 | partitions = reader.partitions(start_offset, end_offset) 307 | 308 | # Should have 2 partitions (30 minutes each) 309 | assert len(partitions) == 2 310 | 311 | assert partitions[0].start_time.isoformat() == "2024-01-01T00:00:00+00:00" 312 | assert partitions[0].end_time.isoformat() == "2024-01-01T00:30:00+00:00" 313 | 314 | # Second partition starts 1 microsecond after first ends (to avoid overlap) 315 | assert partitions[1].start_time.isoformat() == "2024-01-01T00:30:00.000001+00:00" 316 | assert partitions[1].end_time.isoformat() == "2024-01-01T01:00:00+00:00" 317 | 318 | def test_partitions_partial_last_partition(self, stream_options, basic_schema): 319 | """Test partitions handles partial last partition correctly.""" 320 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 321 | 322 | start_offset = AzureMonitorOffset("2024-01-01T00:00:00Z").json() 323 | end_offset = AzureMonitorOffset("2024-01-01T02:30:00Z").json() # 2.5 hours 324 | 325 | partitions = reader.partitions(start_offset, end_offset) 326 | 327 | # Should have 3 partitions: 1h, 1h, 0.5h 328 | assert len(partitions) == 3 329 | 330 | # First partition ends at 1-hour boundary 331 | assert partitions[0].end_time.isoformat() == "2024-01-01T01:00:00+00:00" 332 | # Second partition ends at 2-hour boundary + 1 microsecond (starts 1 microsecond after first) 333 | assert partitions[1].end_time.isoformat() == "2024-01-01T02:00:00.000001+00:00" 334 | # Third partition (partial, starts 1 microsecond after second) ends at 2.5 hours (end_offset) 335 | assert partitions[2].end_time.isoformat() 
== "2024-01-01T02:30:00+00:00" 336 | 337 | @patch("azure.monitor.query.LogsQueryClient") 338 | @patch("azure.identity.ClientSecretCredential") 339 | def test_read_success(self, mock_credential, mock_client, stream_options, basic_schema): 340 | """Test successful streaming read.""" 341 | from azure.monitor.query import LogsQueryStatus 342 | 343 | from cyber_connectors.MsSentinel import TimeRangePartition 344 | 345 | # Create mock response 346 | mock_response = Mock() 347 | mock_response.status = LogsQueryStatus.SUCCESS 348 | 349 | mock_table = Mock() 350 | mock_table.columns = ["TimeGenerated", "OperationName"] 351 | mock_table.rows = [ 352 | ["2024-01-01T00:15:00Z", "Read"], 353 | ["2024-01-01T00:30:00Z", "Write"], 354 | ] 355 | mock_response.tables = [mock_table] 356 | 357 | # Setup mock client 358 | mock_client_instance = Mock() 359 | mock_client_instance.query_workspace.return_value = mock_response 360 | mock_client.return_value = mock_client_instance 361 | 362 | # Create reader 363 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 364 | 365 | # Create partition with time range 366 | partition = TimeRangePartition( 367 | start_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), 368 | end_time=datetime(2024, 1, 1, 1, 0, 0, tzinfo=timezone.utc), 369 | ) 370 | 371 | # Read data 372 | rows = list(reader.read(partition)) 373 | 374 | # Verify results 375 | assert len(rows) == 2 376 | assert rows[0].TimeGenerated == "2024-01-01T00:15:00Z" 377 | assert rows[0].OperationName == "Read" 378 | assert rows[1].TimeGenerated == "2024-01-01T00:30:00Z" 379 | assert rows[1].OperationName == "Write" 380 | 381 | # Verify query was called with correct parameters 382 | mock_client_instance.query_workspace.assert_called_once() 383 | call_kwargs = mock_client_instance.query_workspace.call_args[1] 384 | assert call_kwargs["workspace_id"] == "test-workspace-id" 385 | assert call_kwargs["query"] == "AzureActivity" 386 | 387 | @patch("azure.monitor.query.LogsQueryClient") 388 | @patch("azure.identity.ClientSecretCredential") 389 | def test_read_empty_results(self, mock_credential, mock_client, stream_options, basic_schema): 390 | """Test streaming read with empty results.""" 391 | from azure.monitor.query import LogsQueryStatus 392 | 393 | from cyber_connectors.MsSentinel import TimeRangePartition 394 | 395 | # Create mock response with no rows 396 | mock_response = Mock() 397 | mock_response.status = LogsQueryStatus.SUCCESS 398 | mock_table = Mock() 399 | mock_table.columns = ["TimeGenerated", "OperationName"] 400 | mock_table.rows = [] 401 | mock_response.tables = [mock_table] 402 | 403 | mock_client_instance = Mock() 404 | mock_client_instance.query_workspace.return_value = mock_response 405 | mock_client.return_value = mock_client_instance 406 | 407 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 408 | partition = TimeRangePartition( 409 | start_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), 410 | end_time=datetime(2024, 1, 1, 1, 0, 0, tzinfo=timezone.utc), 411 | ) 412 | 413 | rows = list(reader.read(partition)) 414 | assert len(rows) == 0 415 | 416 | @patch("azure.monitor.query.LogsQueryClient") 417 | @patch("azure.identity.ClientSecretCredential") 418 | def test_read_query_failure(self, mock_credential, mock_client, stream_options, basic_schema): 419 | """Test handling of query failure in streaming.""" 420 | from azure.monitor.query import LogsQueryStatus 421 | 422 | from cyber_connectors.MsSentinel import TimeRangePartition 423 | 424 | # Create mock response 
with failure status 425 | mock_response = Mock() 426 | mock_response.status = LogsQueryStatus.PARTIAL 427 | 428 | mock_client_instance = Mock() 429 | mock_client_instance.query_workspace.return_value = mock_response 430 | mock_client.return_value = mock_client_instance 431 | 432 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 433 | partition = TimeRangePartition( 434 | start_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), 435 | end_time=datetime(2024, 1, 1, 1, 0, 0, tzinfo=timezone.utc), 436 | ) 437 | 438 | with pytest.raises(Exception, match="Query failed with status"): 439 | list(reader.read(partition)) 440 | 441 | def test_commit_does_nothing(self, stream_options, basic_schema): 442 | """Test commit method (should do nothing as Spark handles checkpointing).""" 443 | reader = AzureMonitorStreamReader(stream_options, basic_schema) 444 | end_offset = AzureMonitorOffset("2024-01-01T01:00:00Z").json() 445 | 446 | # Should not raise exception 447 | reader.commit(end_offset) 448 | 449 | @patch("azure.monitor.query.LogsQueryClient") 450 | @patch("azure.identity.ClientSecretCredential") 451 | def test_read_with_type_conversion(self, mock_credential, mock_client, stream_options): 452 | """Test streaming read with type conversion from schema.""" 453 | from azure.monitor.query import LogsQueryStatus 454 | from pyspark.sql.types import LongType, StructField, StructType 455 | 456 | from cyber_connectors.MsSentinel import TimeRangePartition 457 | 458 | schema = StructType([StructField("Count", LongType(), True)]) 459 | 460 | # Create mock response with string values that should be converted to int 461 | mock_response = Mock() 462 | mock_response.status = LogsQueryStatus.SUCCESS 463 | mock_table = Mock() 464 | mock_table.columns = ["Count"] 465 | mock_table.rows = [["123"], ["456"]] 466 | mock_response.tables = [mock_table] 467 | 468 | mock_client_instance = Mock() 469 | mock_client_instance.query_workspace.return_value = mock_response 470 | mock_client.return_value = mock_client_instance 471 | 472 | reader = AzureMonitorStreamReader(stream_options, schema) 473 | partition = TimeRangePartition( 474 | start_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), 475 | end_time=datetime(2024, 1, 1, 1, 0, 0, tzinfo=timezone.utc), 476 | ) 477 | 478 | rows = list(reader.read(partition)) 479 | 480 | # Verify type conversion 481 | assert len(rows) == 2 482 | assert rows[0].Count == 123 483 | assert isinstance(rows[0].Count, int) 484 | assert rows[1].Count == 456 485 | assert isinstance(rows[1].Count, int) 486 | 487 | @patch("azure.monitor.query.LogsQueryClient") 488 | @patch("azure.identity.ClientSecretCredential") 489 | def test_read_with_missing_columns(self, mock_credential, mock_client, stream_options): 490 | """Test streaming read when query results are missing schema columns.""" 491 | from azure.monitor.query import LogsQueryStatus 492 | from pyspark.sql.types import StringType, StructField, StructType 493 | 494 | from cyber_connectors.MsSentinel import TimeRangePartition 495 | 496 | # Schema expects Name and Extra columns 497 | schema = StructType([StructField("Name", StringType(), True), StructField("Extra", StringType(), True)]) 498 | 499 | # Query results only have Name (missing Extra) 500 | mock_response = Mock() 501 | mock_response.status = LogsQueryStatus.SUCCESS 502 | mock_table = Mock() 503 | mock_table.columns = ["Name"] 504 | mock_table.rows = [["Alice"], ["Bob"]] 505 | mock_response.tables = [mock_table] 506 | 507 | mock_client_instance = Mock() 508 | 
mock_client_instance.query_workspace.return_value = mock_response 509 | mock_client.return_value = mock_client_instance 510 | 511 | reader = AzureMonitorStreamReader(stream_options, schema) 512 | partition = TimeRangePartition( 513 | start_time=datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), 514 | end_time=datetime(2024, 1, 1, 1, 0, 0, tzinfo=timezone.utc), 515 | ) 516 | 517 | rows = list(reader.read(partition)) 518 | 519 | # Verify missing column is set to None 520 | assert len(rows) == 2 521 | assert rows[0].Name == "Alice" 522 | assert rows[0].Extra is None 523 | assert rows[1].Name == "Bob" 524 | assert rows[1].Extra is None 525 | -------------------------------------------------------------------------------- /cyber_connectors/MsSentinel.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from datetime import date, datetime 3 | 4 | from azure.monitor.ingestion import LogsIngestionClient 5 | from pyspark.sql.datasource import ( 6 | DataSource, 7 | DataSourceReader, 8 | DataSourceStreamReader, 9 | DataSourceStreamWriter, 10 | DataSourceWriter, 11 | InputPartition, 12 | WriterCommitMessage, 13 | ) 14 | from pyspark.sql.types import StructType 15 | 16 | from cyber_connectors.common import DateTimeJsonEncoder, SimpleCommitMessage 17 | 18 | 19 | def _create_azure_credential(tenant_id, client_id, client_secret): 20 | """Create Azure ClientSecretCredential for authentication. 21 | 22 | Args: 23 | tenant_id: Azure tenant ID 24 | client_id: Azure service principal client ID 25 | client_secret: Azure service principal client secret 26 | 27 | Returns: 28 | ClientSecretCredential: Authenticated credential object 29 | 30 | """ 31 | from azure.identity import ClientSecretCredential 32 | 33 | return ClientSecretCredential(tenant_id=tenant_id, client_id=client_id, client_secret=client_secret) 34 | 35 | 36 | def _parse_time_range(timespan=None, start_time=None, end_time=None): 37 | """Parse time range from timespan or start_time/end_time options. 
38 | 39 | Args: 40 | timespan: ISO 8601 duration string (e.g., "P1D", "PT1H") 41 | start_time: ISO 8601 datetime string (e.g., "2024-01-01T00:00:00Z") 42 | end_time: ISO 8601 datetime string (optional, defaults to now) 43 | 44 | Returns: 45 | tuple: (start_datetime, end_datetime) as datetime objects with timezone 46 | 47 | Raises: 48 | ValueError: If timespan format is invalid 49 | Exception: If neither timespan nor start_time is provided 50 | 51 | """ 52 | import re 53 | from datetime import datetime, timedelta, timezone 54 | 55 | if timespan: 56 | # Parse ISO 8601 duration 57 | # Format: P[n]D or PT[n]H[n]M[n]S or combination P[n]DT[n]H[n]M[n]S 58 | match = re.match(r"P(?:(\d+)D)?(?:T(?:(\d+)H)?(?:(\d+)M)?(?:(\d+)S)?)?$", timespan) 59 | if match: 60 | days = int(match.group(1) or 0) 61 | hours = int(match.group(2) or 0) 62 | minutes = int(match.group(3) or 0) 63 | seconds = int(match.group(4) or 0) 64 | 65 | # Validate that at least one component was specified 66 | if days == 0 and hours == 0 and minutes == 0 and seconds == 0: 67 | raise ValueError(f"Invalid timespan format: {timespan} - must specify at least one duration component") 68 | 69 | delta = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds) 70 | end_time_val = datetime.now(timezone.utc) 71 | start_time_val = end_time_val - delta 72 | return (start_time_val, end_time_val) 73 | else: 74 | raise ValueError(f"Invalid timespan format: {timespan}") 75 | elif start_time: 76 | start_time_val = datetime.fromisoformat(start_time.replace("Z", "+00:00")) 77 | if end_time: 78 | end_time_val = datetime.fromisoformat(end_time.replace("Z", "+00:00")) 79 | else: 80 | end_time_val = datetime.now(timezone.utc) 81 | return (start_time_val, end_time_val) 82 | else: 83 | raise Exception("Either 'timespan' or 'start_time' must be provided") 84 | 85 | 86 | def _execute_logs_query( 87 | workspace_id, 88 | query, 89 | timespan, 90 | tenant_id, 91 | client_id, 92 | client_secret, 93 | ): 94 | """Execute a KQL query against Azure Monitor Log Analytics workspace. 95 | 96 | Args: 97 | workspace_id: Log Analytics workspace ID 98 | query: KQL query to execute 99 | timespan: Time range as tuple (start_time, end_time) 100 | tenant_id: Azure tenant ID 101 | client_id: Azure service principal client ID 102 | client_secret: Azure service principal client secret 103 | 104 | Returns: 105 | Query response object from Azure Monitor 106 | 107 | Raises: 108 | Exception: If query fails 109 | 110 | """ 111 | from azure.monitor.query import LogsQueryClient, LogsQueryStatus 112 | 113 | # Create authenticated client 114 | credential = _create_azure_credential(tenant_id, client_id, client_secret) 115 | client = LogsQueryClient(credential) 116 | 117 | # Execute query 118 | response = client.query_workspace( 119 | workspace_id=workspace_id, 120 | query=query, 121 | timespan=timespan, 122 | include_statistics=False, 123 | include_visualization=False, 124 | ) 125 | 126 | if response.status != LogsQueryStatus.SUCCESS: 127 | raise Exception(f"Query failed with status: {response.status}") 128 | 129 | return response 130 | 131 | 132 | def _convert_value_to_schema_type(value, spark_type): 133 | """Convert a value to match the expected PySpark schema type. 
134 | 135 | Args: 136 | value: The raw value from Azure Monitor 137 | spark_type: The expected PySpark DataType 138 | 139 | Returns: 140 | Converted value matching the schema type 141 | 142 | Raises: 143 | ValueError: If conversion fails 144 | 145 | """ 146 | from pyspark.sql.types import ( 147 | BooleanType, 148 | DateType, 149 | DoubleType, 150 | FloatType, 151 | IntegerType, 152 | LongType, 153 | StringType, 154 | TimestampType, 155 | ) 156 | 157 | # Handle None/NULL values 158 | if value is None: 159 | return None 160 | 161 | try: 162 | # String type - convert everything to string 163 | if isinstance(spark_type, StringType): 164 | return str(value) 165 | 166 | # Boolean type 167 | elif isinstance(spark_type, BooleanType): 168 | if isinstance(value, bool): 169 | return value 170 | elif isinstance(value, str): 171 | if value.lower() in ("true", "1", "yes"): 172 | return True 173 | elif value.lower() in ("false", "0", "no"): 174 | return False 175 | else: 176 | raise ValueError(f"Cannot convert string '{value}' to boolean") 177 | elif isinstance(value, (int, float)): 178 | return bool(value) 179 | else: 180 | raise ValueError(f"Cannot convert {type(value).__name__} to boolean") 181 | 182 | # Integer types 183 | elif isinstance(spark_type, (IntegerType, LongType)): 184 | if isinstance(value, bool): 185 | # Don't convert bool to int (bool is subclass of int in Python) 186 | raise ValueError("Cannot convert boolean to integer") 187 | return int(value) 188 | 189 | # Float types 190 | elif isinstance(spark_type, (FloatType, DoubleType)): 191 | if isinstance(value, bool): 192 | raise ValueError("Cannot convert boolean to float") 193 | return float(value) 194 | 195 | # Timestamp type 196 | elif isinstance(spark_type, TimestampType): 197 | if isinstance(value, datetime): 198 | return value 199 | elif isinstance(value, str): 200 | # Try parsing ISO 8601 format 201 | return datetime.fromisoformat(value.replace("Z", "+00:00")) 202 | else: 203 | raise ValueError(f"Cannot convert {type(value).__name__} to timestamp") 204 | 205 | # Date type 206 | elif isinstance(spark_type, DateType): 207 | if isinstance(value, date) and not isinstance(value, datetime): 208 | return value 209 | elif isinstance(value, datetime): 210 | return value.date() 211 | elif isinstance(value, str): 212 | # Try parsing ISO 8601 date format 213 | return datetime.fromisoformat(value.replace("Z", "+00:00")).date() 214 | else: 215 | raise ValueError(f"Cannot convert {type(value).__name__} to date") 216 | 217 | # Unsupported type - return as-is 218 | else: 219 | return value 220 | 221 | except (ValueError, TypeError) as e: 222 | raise ValueError(f"Failed to convert value '{value}' (type: {type(value).__name__}) to {spark_type}: {e}") 223 | 224 | 225 | @dataclass 226 | class TimeRangePartition(InputPartition): 227 | """Represents a time range partition for parallel query execution.""" 228 | 229 | start_time: datetime 230 | end_time: datetime 231 | 232 | 233 | class AzureMonitorDataSource(DataSource): 234 | """Data source for Azure Monitor. Supports reading from and writing to Azure Monitor. 
235 | 236 | Write options: 237 | - dce: data collection endpoint URL 238 | - dcr_id: data collection rule ID 239 | - dcs: data collection stream name 240 | - tenant_id: Azure tenant ID 241 | - client_id: Azure service principal ID 242 | - client_secret: Azure service principal client secret 243 | 244 | Read options: 245 | - workspace_id: Log Analytics workspace ID 246 | - query: KQL query to execute 247 | - timespan: Time range for query in ISO 8601 duration format 248 | - tenant_id: Azure tenant ID 249 | - client_id: Azure service principal ID 250 | - client_secret: Azure service principal client secret 251 | """ 252 | 253 | @classmethod 254 | def name(cls): 255 | return "azure-monitor" 256 | 257 | def schema(self): 258 | """Return the schema for reading data. 259 | 260 | If the user doesn't provide a schema, this method infers it by executing 261 | a sample query with limit 1. Only if inferSchema is true. 262 | 263 | Returns: 264 | StructType: The schema of the data 265 | 266 | """ 267 | infer_schema = self.options.get("inferSchema", "true").lower() == "true" 268 | if infer_schema: 269 | return self._infer_read_schema() 270 | else: 271 | raise Exception("Must provide schema if inferSchema is false") 272 | 273 | def _infer_read_schema(self): 274 | """Infer schema by executing a sample query with limit 1. 275 | 276 | Returns: 277 | StructType: The inferred schema 278 | 279 | Raises: 280 | Exception: If query returns no results or fails 281 | 282 | """ 283 | from pyspark.sql.types import ( 284 | BooleanType, 285 | DateType, 286 | DoubleType, 287 | LongType, 288 | StringType, 289 | StructField, 290 | StructType, 291 | TimestampType, 292 | ) 293 | 294 | # Get read options 295 | workspace_id = self.options.get("workspace_id") 296 | query = self.options.get("query") 297 | tenant_id = self.options.get("tenant_id") 298 | client_id = self.options.get("client_id") 299 | client_secret = self.options.get("client_secret") 300 | timespan = self.options.get("timespan") 301 | start_time = self.options.get("start_time") 302 | end_time = self.options.get("end_time") 303 | 304 | # Validate required options 305 | assert workspace_id is not None, "workspace_id is required" 306 | assert query is not None, "query is required" 307 | assert tenant_id is not None, "tenant_id is required" 308 | assert client_id is not None, "client_id is required" 309 | assert client_secret is not None, "client_secret is required" 310 | 311 | # Parse time range using module-level function 312 | timespan_value = _parse_time_range(timespan=timespan, start_time=start_time, end_time=end_time) 313 | 314 | # Modify query to limit results to 1 row 315 | sample_query = query.strip() 316 | if not any(keyword in sample_query.lower() for keyword in ["| take 1", "| limit 1"]): 317 | sample_query = f"{sample_query} | take 1" 318 | 319 | # Execute sample query using module-level function 320 | response = _execute_logs_query( 321 | workspace_id=workspace_id, 322 | query=sample_query, 323 | timespan=timespan_value, 324 | tenant_id=tenant_id, 325 | client_id=client_id, 326 | client_secret=client_secret, 327 | ) 328 | 329 | # Check if we got any tables 330 | if not response.tables or len(response.tables) == 0: 331 | raise Exception("Schema inference failed: query returned no tables") 332 | 333 | table = response.tables[0] 334 | 335 | # Check if table has any columns 336 | if not table.columns or len(table.columns) == 0: 337 | raise Exception("Schema inference failed: query returned no columns") 338 | 339 | # Check if we have any rows to infer types 
from 340 | if not table.rows or len(table.rows) == 0: 341 | # No data to infer types from, use string type for all columns 342 | fields = [StructField(str(col), StringType(), nullable=True) for col in table.columns] 343 | return StructType(fields) 344 | 345 | # Infer schema from actual data in the first row 346 | # table.columns is always a list of strings (column names) 347 | first_row = table.rows[0] 348 | fields = [] 349 | 350 | for i, column_name in enumerate(table.columns): 351 | # Get the value from the first row to infer type 352 | value = first_row[i] if i < len(first_row) else None 353 | 354 | # Infer PySpark type from Python type 355 | if value is None: 356 | # If first value is None, default to StringType 357 | spark_type = StringType() 358 | elif isinstance(value, bool): 359 | # Check bool before int (bool is subclass of int in Python) 360 | spark_type = BooleanType() 361 | elif isinstance(value, int): 362 | spark_type = LongType() 363 | elif isinstance(value, float): 364 | spark_type = DoubleType() 365 | elif isinstance(value, datetime): 366 | spark_type = TimestampType() 367 | elif isinstance(value, date): 368 | spark_type = DateType() 369 | elif isinstance(value, str): 370 | spark_type = StringType() 371 | else: 372 | # For any other type, use StringType 373 | spark_type = StringType() 374 | 375 | fields.append(StructField(column_name, spark_type, nullable=True)) 376 | 377 | return StructType(fields) 378 | 379 | def reader(self, schema: StructType): 380 | return AzureMonitorBatchReader(self.options, schema) 381 | 382 | def streamReader(self, schema: StructType): 383 | return AzureMonitorStreamReader(self.options, schema) 384 | 385 | def streamWriter(self, schema: StructType, overwrite: bool): 386 | return AzureMonitorStreamWriter(self.options) 387 | 388 | def writer(self, schema: StructType, overwrite: bool): 389 | return AzureMonitorBatchWriter(self.options) 390 | 391 | 392 | class MicrosoftSentinelDataSource(AzureMonitorDataSource): 393 | """Same implementation as AzureMonitorDataSource, just exposed as ms-sentinel name.""" 394 | 395 | @classmethod 396 | def name(cls): 397 | return "ms-sentinel" 398 | 399 | 400 | class AzureMonitorReader: 401 | """Base reader class for Azure Monitor / Log Analytics workspaces. 402 | 403 | Shared read logic for batch and streaming reads. 404 | """ 405 | 406 | def __init__(self, options, schema: StructType): 407 | """Initialize the reader with options and schema. 408 | 409 | Args: 410 | options: Dictionary of options containing workspace_id, query, credentials 411 | schema: StructType schema (provided by DataSource.schema()) 412 | 413 | """ 414 | # Extract and validate required options 415 | self.workspace_id = options.get("workspace_id") 416 | self.query = options.get("query") 417 | self.tenant_id = options.get("tenant_id") 418 | self.client_id = options.get("client_id") 419 | self.client_secret = options.get("client_secret") 420 | 421 | # Validate required options 422 | assert self.workspace_id is not None, "workspace_id is required" 423 | assert self.query is not None, "query is required" 424 | assert self.tenant_id is not None, "tenant_id is required" 425 | assert self.client_id is not None, "client_id is required" 426 | assert self.client_secret is not None, "client_secret is required" 427 | 428 | # Store schema (provided by DataSource.schema()) 429 | self._schema = schema 430 | 431 | def read(self, partition: TimeRangePartition): 432 | """Read data for the given partition time range. 
433 | 434 | Args: 435 | partition: TimeRangePartition containing start_time and end_time 436 | 437 | Yields: 438 | Row objects from the query results 439 | 440 | """ 441 | # Import inside method for partition-level execution 442 | from pyspark.sql import Row 443 | 444 | # Use partition's time range 445 | timespan_value = (partition.start_time, partition.end_time) 446 | 447 | # Execute query using module-level function 448 | response = _execute_logs_query( 449 | workspace_id=self.workspace_id, 450 | query=self.query, 451 | timespan=timespan_value, 452 | tenant_id=self.tenant_id, 453 | client_id=self.client_id, 454 | client_secret=self.client_secret, 455 | ) 456 | 457 | # Create a mapping of column names to their expected types from schema 458 | schema_field_map = {field.name: field.dataType for field in self._schema.fields} 459 | 460 | # Process all tables in response 461 | for table in response.tables: 462 | # Convert Azure Monitor rows to Spark Rows 463 | # table.columns is always a list of strings (column names) 464 | for row_idx, row_data in enumerate(table.rows): 465 | row_dict = {} 466 | 467 | # First, process columns from the query results 468 | for i, col in enumerate(table.columns): 469 | # Handle both string columns (real API) and objects with .name attribute (test mocks) 470 | column_name = str(col) if isinstance(col, str) else str(col.name) 471 | raw_value = row_data[i] 472 | 473 | # If column is in schema, convert to expected type 474 | if column_name in schema_field_map: 475 | expected_type = schema_field_map[column_name] 476 | try: 477 | converted_value = _convert_value_to_schema_type(raw_value, expected_type) 478 | row_dict[column_name] = converted_value 479 | except ValueError as e: 480 | raise ValueError(f"Row {row_idx}, column '{column_name}': {e}") 481 | # Note: columns not in schema are ignored (not included in row) 482 | 483 | # Second, add NULL values for schema columns that are not in query results 484 | for schema_column_name in schema_field_map.keys(): 485 | if schema_column_name not in row_dict: 486 | row_dict[schema_column_name] = None 487 | 488 | yield Row(**row_dict) 489 | 490 | 491 | class AzureMonitorBatchReader(AzureMonitorReader, DataSourceReader): 492 | """Batch reader for Azure Monitor / Log Analytics workspaces.""" 493 | 494 | def __init__(self, options, schema: StructType): 495 | """Initialize the batch reader with options and schema. 496 | 497 | Args: 498 | options: Dictionary of options containing workspace_id, query, time range, credentials 499 | schema: StructType schema (provided by DataSource.schema()) 500 | 501 | """ 502 | super().__init__(options, schema) 503 | 504 | # Time range options (mutually exclusive) 505 | timespan = options.get("timespan") 506 | start_time = options.get("start_time") 507 | end_time = options.get("end_time") 508 | 509 | # Optional options 510 | self.num_partitions = int(options.get("num_partitions", "1")) 511 | 512 | # Parse time range using module-level function 513 | self.start_time, self.end_time = _parse_time_range(timespan=timespan, start_time=start_time, end_time=end_time) 514 | 515 | def partitions(self): 516 | """Generate list of non-overlapping time range partitions. 
517 | 518 | Returns: 519 | List of TimeRangePartition objects, each containing start_time and end_time 520 | 521 | """ 522 | # Calculate total time range duration 523 | total_duration = self.end_time - self.start_time 524 | 525 | # Split into N equal partitions 526 | partition_duration = total_duration / self.num_partitions 527 | 528 | partitions = [] 529 | for i in range(self.num_partitions): 530 | partition_start = self.start_time + (partition_duration * i) 531 | partition_end = self.start_time + (partition_duration * (i + 1)) 532 | 533 | # Ensure last partition ends exactly at end_time (avoid rounding errors) 534 | if i == self.num_partitions - 1: 535 | partition_end = self.end_time 536 | 537 | partitions.append(TimeRangePartition(partition_start, partition_end)) 538 | 539 | return partitions 540 | 541 | 542 | class AzureMonitorOffset: 543 | """Represents the offset for Azure Monitor streaming. 544 | 545 | The offset tracks the timestamp of the last processed data to enable incremental streaming. 546 | """ 547 | 548 | def __init__(self, timestamp: str): 549 | """Initialize offset with ISO 8601 timestamp. 550 | 551 | Args: 552 | timestamp: ISO 8601 formatted timestamp string (e.g., "2024-01-01T00:00:00Z") 553 | 554 | """ 555 | self.timestamp = timestamp 556 | 557 | def json(self): 558 | """Serialize offset to JSON string. 559 | 560 | Returns: 561 | JSON string representation of the offset 562 | 563 | """ 564 | import json 565 | 566 | return json.dumps({"timestamp": self.timestamp}) 567 | 568 | @staticmethod 569 | def from_json(json_str: str): 570 | """Deserialize offset from JSON string. 571 | 572 | Args: 573 | json_str: JSON string containing offset data 574 | 575 | Returns: 576 | AzureMonitorOffset instance 577 | 578 | """ 579 | import json 580 | 581 | data = json.loads(json_str) 582 | return AzureMonitorOffset(data["timestamp"]) 583 | 584 | 585 | class AzureMonitorStreamReader(AzureMonitorReader, DataSourceStreamReader): 586 | """Stream reader for Azure Monitor / Log Analytics workspaces. 587 | 588 | Implements incremental streaming by tracking time-based offsets and splitting 589 | time ranges into partitions for parallel processing. 590 | """ 591 | 592 | def __init__(self, options, schema: StructType): 593 | """Initialize the stream reader with options and schema. 594 | 595 | Args: 596 | options: Dictionary of options containing workspace_id, query, start_time, credentials 597 | schema: StructType schema (provided by DataSource.schema()) 598 | 599 | """ 600 | super().__init__(options, schema) 601 | 602 | # Stream-specific options 603 | start_time = options.get("start_time", "latest") 604 | # Support 'latest' as alias for current timestamp 605 | if start_time == "latest": 606 | from datetime import datetime, timezone 607 | 608 | self.start_time = datetime.now(timezone.utc).isoformat() 609 | else: 610 | # Validate that start_time is a valid ISO 8601 timestamp 611 | from datetime import datetime 612 | 613 | try: 614 | datetime.fromisoformat(start_time.replace("Z", "+00:00")) 615 | self.start_time = start_time 616 | except (ValueError, AttributeError) as e: 617 | raise ValueError( 618 | f"Invalid start_time format: {start_time}. Expected ISO 8601 format (e.g., '2024-01-01T00:00:00Z')" 619 | ) from e 620 | 621 | # Partition duration in seconds (default 1 hour) 622 | self.partition_duration = int(options.get("partition_duration", "3600")) 623 | 624 | def initialOffset(self): 625 | """Return the initial offset (start time). 
626 | 627 | Returns: 628 | JSON string representation of AzureMonitorOffset with the configured start time 629 | 630 | """ 631 | return AzureMonitorOffset(self.start_time).json() 632 | 633 | def latestOffset(self): 634 | """Return the latest offset (current time). 635 | 636 | Returns: 637 | JSON string representation of AzureMonitorOffset with the current UTC timestamp 638 | 639 | """ 640 | from datetime import datetime, timezone 641 | 642 | current_time = datetime.now(timezone.utc).isoformat() 643 | return AzureMonitorOffset(current_time).json() 644 | 645 | def partitions(self, start, end): 646 | """Create partitions for the time range between start and end offsets. 647 | 648 | Splits the time range into fixed-duration partitions based on partition_duration. 649 | 650 | Args: 651 | start: JSON string representing AzureMonitorOffset for the start of the range 652 | end: JSON string representing AzureMonitorOffset for the end of the range 653 | 654 | Returns: 655 | List of TimeRangePartition objects 656 | 657 | """ 658 | from datetime import datetime, timedelta 659 | 660 | # Deserialize JSON strings to offset objects 661 | start_offset = AzureMonitorOffset.from_json(start) 662 | end_offset = AzureMonitorOffset.from_json(end) 663 | 664 | # Parse timestamps 665 | start_time = datetime.fromisoformat(start_offset.timestamp.replace("Z", "+00:00")) 666 | end_time = datetime.fromisoformat(end_offset.timestamp.replace("Z", "+00:00")) 667 | 668 | # Calculate total duration 669 | total_duration = (end_time - start_time).total_seconds() 670 | 671 | # If total duration is less than partition_duration, create a single partition 672 | if total_duration <= self.partition_duration: 673 | return [TimeRangePartition(start_time, end_time)] 674 | 675 | # Split into fixed-duration partitions 676 | partitions = [] 677 | current_start = start_time 678 | partition_delta = timedelta(seconds=self.partition_duration) 679 | 680 | while current_start < end_time: 681 | current_end = min(current_start + partition_delta, end_time) 682 | partitions.append(TimeRangePartition(current_start, current_end)) 683 | # Next partition starts 1 microsecond after current partition ends to avoid overlap 684 | current_start = current_end + timedelta(microseconds=1) 685 | 686 | return partitions 687 | 688 | def commit(self, end): 689 | """Called when a batch is successfully processed. 
690 | 691 | Args: 692 | end: AzureMonitorOffset representing the end of the committed batch 693 | 694 | """ 695 | # Nothing special needed - Spark handles checkpointing 696 | pass 697 | 698 | 699 | # https://learn.microsoft.com/en-us/python/api/overview/azure/monitor-ingestion-readme?view=azure-python 700 | class AzureMonitorWriter: 701 | def __init__(self, options): 702 | self.options = options 703 | self.dce = self.options.get("dce") # data_collection_endpoint 704 | self.dcr_id = self.options.get("dcr_id") # data_collection_rule_id 705 | self.dcs = self.options.get("dcs") # data_collection_stream 706 | self.tenant_id = self.options.get("tenant_id") 707 | self.client_id = self.options.get("client_id") 708 | self.client_secret = self.options.get("client_secret") 709 | self.batch_size = int(self.options.get("batch_size", "50")) 710 | assert self.dce is not None 711 | assert self.dcr_id is not None 712 | assert self.dcs is not None 713 | assert self.tenant_id is not None 714 | assert self.client_id is not None 715 | assert self.client_secret is not None 716 | 717 | def _send_to_sentinel(self, s: LogsIngestionClient, msgs: list): 718 | if len(msgs) > 0: 719 | # TODO: add retries 720 | s.upload(rule_id=self.dcr_id, stream_name=self.dcs, logs=msgs) 721 | 722 | def write(self, iterator): 723 | """Writes the data, then returns the commit message of that partition. Library imports must be within the method.""" 724 | import json 725 | 726 | from azure.identity import ClientSecretCredential 727 | from azure.monitor.ingestion import LogsIngestionClient 728 | from pyspark import TaskContext 729 | # from azure.core.exceptions import HttpResponseError 730 | 731 | credential = ClientSecretCredential(self.tenant_id, self.client_id, self.client_secret) 732 | logs_client = LogsIngestionClient(self.dce, credential) 733 | 734 | msgs = [] 735 | 736 | context = TaskContext.get() 737 | partition_id = context.partitionId() 738 | cnt = 0 739 | for row in iterator: 740 | cnt += 1 741 | # Workaround to convert datetime/date to string 742 | msgs.append(json.loads(json.dumps(row.asDict(), cls=DateTimeJsonEncoder))) 743 | if len(msgs) >= self.batch_size: 744 | self._send_to_sentinel(logs_client, msgs) 745 | msgs = [] 746 | 747 | self._send_to_sentinel(logs_client, msgs) 748 | 749 | return SimpleCommitMessage(partition_id=partition_id, count=cnt) 750 | 751 | 752 | class AzureMonitorBatchWriter(AzureMonitorWriter, DataSourceWriter): 753 | def __init__(self, options): 754 | super().__init__(options) 755 | 756 | 757 | class AzureMonitorStreamWriter(AzureMonitorWriter, DataSourceStreamWriter): 758 | def __init__(self, options): 759 | super().__init__(options) 760 | 761 | def commit(self, messages: list[WriterCommitMessage | None], batchId: int) -> None: 762 | """Receives a sequence of :class:`WriterCommitMessage` when all write tasks have succeeded, then decides what to do with it. 763 | In this FakeStreamWriter, the metadata of the microbatch(number of rows and partitions) is written into a JSON file inside commit(). 764 | """ 765 | # status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages)) 766 | # with open(os.path.join(self.path, f"{batchId}.json"), "a") as file: 767 | # file.write(json.dumps(status) + "\n") 768 | pass 769 | 770 | def abort(self, messages: list[WriterCommitMessage | None], batchId: int) -> None: 771 | """Receives a sequence of :class:`WriterCommitMessage` from successful tasks when some other tasks have failed, then decides what to do with it. 
772 | In this FakeStreamWriter, a failure message is written into a text file inside abort(). 773 | """ 774 | # with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file: 775 | # file.write(f"failed in batch {batchId}") 776 | pass 777 | -------------------------------------------------------------------------------- /tests/test_mssentinel_reader.py: -------------------------------------------------------------------------------- 1 | """Unit tests for Azure Monitor Batch Reader.""" 2 | 3 | from datetime import datetime 4 | from unittest.mock import Mock, patch 5 | 6 | import pytest 7 | from pyspark.sql.types import StringType, StructField, StructType 8 | 9 | from cyber_connectors.MsSentinel import ( 10 | AzureMonitorBatchReader, 11 | AzureMonitorDataSource, 12 | MicrosoftSentinelDataSource, 13 | ) 14 | 15 | 16 | class TestAzureMonitorDataSource: 17 | """Test data source registration and reader creation.""" 18 | 19 | @pytest.fixture 20 | def basic_options(self): 21 | """Basic valid options for reader.""" 22 | return { 23 | "workspace_id": "test-workspace-id", 24 | "query": "AzureActivity | take 5", 25 | "timespan": "P1D", 26 | "tenant_id": "test-tenant", 27 | "client_id": "test-client", 28 | "client_secret": "test-secret", 29 | } 30 | 31 | def test_azure_monitor_name(self): 32 | """Test Azure Monitor data source name.""" 33 | assert AzureMonitorDataSource.name() == "azure-monitor" 34 | 35 | def test_ms_sentinel_name(self): 36 | """Test Microsoft Sentinel data source name.""" 37 | assert MicrosoftSentinelDataSource.name() == "ms-sentinel" 38 | 39 | def test_reader_method_exists(self): 40 | """Test that reader method exists and returns reader.""" 41 | ds = AzureMonitorDataSource(options={}) 42 | schema = StructType([StructField("test", StringType(), True)]) 43 | 44 | # Should not raise exception (actual reader creation will fail due to missing options) 45 | assert hasattr(ds, "reader") 46 | 47 | def test_datasource_has_schema_method(self, basic_options): 48 | """Test that data source has schema() method that infers schema.""" 49 | # Mock the Azure Monitor query client 50 | from unittest.mock import Mock, patch 51 | 52 | from azure.monitor.query import LogsQueryStatus 53 | 54 | with patch("azure.monitor.query.LogsQueryClient") as mock_client_cls, patch( 55 | "azure.identity.ClientSecretCredential" 56 | ) as mock_credential: 57 | # Create mock response - columns are strings, data types inferred from rows 58 | mock_response = Mock() 59 | mock_response.status = LogsQueryStatus.SUCCESS 60 | mock_table = Mock() 61 | mock_table.columns = ["TestCol"] 62 | mock_table.rows = [["test value"]] 63 | mock_response.tables = [mock_table] 64 | 65 | mock_client = Mock() 66 | mock_client.query_workspace.return_value = mock_response 67 | mock_client_cls.return_value = mock_client 68 | 69 | ds = AzureMonitorDataSource(options=basic_options) 70 | 71 | # Verify schema method exists and returns a schema 72 | assert hasattr(ds, "schema") 73 | returned_schema = ds.schema() 74 | assert returned_schema is not None 75 | assert isinstance(returned_schema, StructType) 76 | 77 | @patch("azure.monitor.query.LogsQueryClient") 78 | @patch("azure.identity.ClientSecretCredential") 79 | def test_schema_inference(self, mock_credential, mock_client, basic_options): 80 | """Test that schema is inferred by DataSource from actual row data.""" 81 | from datetime import datetime, timezone 82 | 83 | from azure.monitor.query import LogsQueryStatus 84 | 85 | # Create mock response for schema inference 86 | mock_response = Mock() 87 
| mock_response.status = LogsQueryStatus.SUCCESS 88 | 89 | # table.columns is always a list of strings (column names) 90 | # Schema is inferred from actual data types in rows 91 | mock_table = Mock() 92 | mock_table.columns = ["TimeGenerated", "OperationName", "Count"] 93 | mock_table.rows = [ 94 | [ 95 | datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), # datetime -> TimestampType 96 | "Read", # str -> StringType 97 | 100, # int -> LongType 98 | ] 99 | ] 100 | mock_response.tables = [mock_table] 101 | 102 | # Setup mock client 103 | mock_client_instance = Mock() 104 | mock_client_instance.query_workspace.return_value = mock_response 105 | mock_client.return_value = mock_client_instance 106 | 107 | # Create data source and call schema() 108 | ds = AzureMonitorDataSource(options=basic_options) 109 | schema = ds.schema() 110 | 111 | # Verify schema was inferred 112 | assert schema is not None 113 | assert len(schema.fields) == 3 114 | 115 | # Verify field names and types 116 | from pyspark.sql.types import LongType, StringType, TimestampType 117 | 118 | assert schema.fields[0].name == "TimeGenerated" 119 | assert isinstance(schema.fields[0].dataType, TimestampType) 120 | 121 | assert schema.fields[1].name == "OperationName" 122 | assert isinstance(schema.fields[1].dataType, StringType) 123 | 124 | assert schema.fields[2].name == "Count" 125 | assert isinstance(schema.fields[2].dataType, LongType) 126 | 127 | # Verify the query was modified to include "| take 1" 128 | call_kwargs = mock_client_instance.query_workspace.call_args[1] 129 | assert "| take 1" in call_kwargs["query"] 130 | 131 | @patch("azure.monitor.query.LogsQueryClient") 132 | @patch("azure.identity.ClientSecretCredential") 133 | def test_schema_inference_query_already_has_limit(self, mock_credential, mock_client, basic_options): 134 | """Test that schema inference doesn't add duplicate limit when query already has one.""" 135 | from azure.monitor.query import LogsQueryStatus 136 | 137 | # Modify query to already have a limit 138 | basic_options["query"] = "AzureActivity | take 1" 139 | 140 | # Create mock response - columns are strings, data types inferred from rows 141 | mock_response = Mock() 142 | mock_response.status = LogsQueryStatus.SUCCESS 143 | mock_table = Mock() 144 | mock_table.columns = ["TestCol"] 145 | mock_table.rows = [["test"]] 146 | mock_response.tables = [mock_table] 147 | 148 | # Setup mock client 149 | mock_client_instance = Mock() 150 | mock_client_instance.query_workspace.return_value = mock_response 151 | mock_client.return_value = mock_client_instance 152 | 153 | # Create data source and call schema() 154 | ds = AzureMonitorDataSource(options=basic_options) 155 | ds.schema() 156 | 157 | # Verify the query was not modified (already has limit) 158 | call_kwargs = mock_client_instance.query_workspace.call_args[1] 159 | query_used = call_kwargs["query"] 160 | # Should only have one "| take 1" 161 | assert query_used.count("| take 1") == 1 162 | 163 | @patch("azure.monitor.query.LogsQueryClient") 164 | @patch("azure.identity.ClientSecretCredential") 165 | def test_schema_inference_no_tables(self, mock_credential, mock_client, basic_options): 166 | """Test that schema inference fails when query returns no tables.""" 167 | from azure.monitor.query import LogsQueryStatus 168 | 169 | # Create mock response with no tables 170 | mock_response = Mock() 171 | mock_response.status = LogsQueryStatus.SUCCESS 172 | mock_response.tables = [] 173 | 174 | # Setup mock client 175 | mock_client_instance = Mock() 176 | 
mock_client_instance.query_workspace.return_value = mock_response 177 | mock_client.return_value = mock_client_instance 178 | 179 | # Create data source and attempt to call schema() 180 | ds = AzureMonitorDataSource(options=basic_options) 181 | 182 | with pytest.raises(Exception, match="Schema inference failed: query returned no tables"): 183 | ds.schema() 184 | 185 | @patch("azure.monitor.query.LogsQueryClient") 186 | @patch("azure.identity.ClientSecretCredential") 187 | def test_schema_inference_no_columns(self, mock_credential, mock_client, basic_options): 188 | """Test that schema inference fails when query returns no columns.""" 189 | from azure.monitor.query import LogsQueryStatus 190 | 191 | # Create mock response with table but no columns 192 | mock_response = Mock() 193 | mock_response.status = LogsQueryStatus.SUCCESS 194 | mock_table = Mock() 195 | mock_table.columns = [] 196 | mock_response.tables = [mock_table] 197 | 198 | # Setup mock client 199 | mock_client_instance = Mock() 200 | mock_client_instance.query_workspace.return_value = mock_response 201 | mock_client.return_value = mock_client_instance 202 | 203 | # Create data source and attempt to call schema() 204 | ds = AzureMonitorDataSource(options=basic_options) 205 | 206 | with pytest.raises(Exception, match="Schema inference failed: query returned no columns"): 207 | ds.schema() 208 | 209 | @patch("azure.monitor.query.LogsQueryClient") 210 | @patch("azure.identity.ClientSecretCredential") 211 | def test_schema_inference_query_failure(self, mock_credential, mock_client, basic_options): 212 | """Test that schema inference fails when query fails.""" 213 | from azure.monitor.query import LogsQueryStatus 214 | 215 | # Create mock response with failure status 216 | mock_response = Mock() 217 | mock_response.status = LogsQueryStatus.PARTIAL 218 | 219 | # Setup mock client 220 | mock_client_instance = Mock() 221 | mock_client_instance.query_workspace.return_value = mock_response 222 | mock_client.return_value = mock_client_instance 223 | 224 | # Create data source and attempt to call schema() 225 | ds = AzureMonitorDataSource(options=basic_options) 226 | 227 | with pytest.raises(Exception, match="Query failed with status"): 228 | ds.schema() 229 | 230 | @patch("azure.monitor.query.LogsQueryClient") 231 | @patch("azure.identity.ClientSecretCredential") 232 | def test_schema_inference_type_mapping(self, mock_credential, mock_client, basic_options): 233 | """Test that schema inference correctly maps Python types to PySpark types.""" 234 | from datetime import datetime, timezone 235 | 236 | from azure.monitor.query import LogsQueryStatus 237 | from pyspark.sql.types import BooleanType, DoubleType, LongType, StringType, TimestampType 238 | 239 | # Create mock response with various Python data types 240 | mock_response = Mock() 241 | mock_response.status = LogsQueryStatus.SUCCESS 242 | 243 | # Map Python types to expected PySpark types 244 | # (column_name, python_value, expected_spark_type) 245 | test_data = [ 246 | ("BoolCol", True, BooleanType), 247 | ("DateCol", datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), TimestampType), 248 | ("IntCol", 42, LongType), 249 | ("FloatCol", 3.14, DoubleType), 250 | ("StringCol", "test", StringType), 251 | ("NullCol", None, StringType), # None defaults to StringType 252 | ] 253 | 254 | mock_table = Mock() 255 | mock_table.columns = [col_name for col_name, _, _ in test_data] 256 | mock_table.rows = [[value for _, value, _ in test_data]] 257 | mock_response.tables = [mock_table] 258 | 259 | # 
Setup mock client 260 | mock_client_instance = Mock() 261 | mock_client_instance.query_workspace.return_value = mock_response 262 | mock_client.return_value = mock_client_instance 263 | 264 | # Create data source and call schema() 265 | ds = AzureMonitorDataSource(options=basic_options) 266 | schema = ds.schema() 267 | 268 | # Verify schema was inferred with correct types 269 | assert len(schema.fields) == len(test_data) 270 | 271 | for i, (expected_name, _, expected_type_class) in enumerate(test_data): 272 | assert schema.fields[i].name == expected_name 273 | assert isinstance(schema.fields[i].dataType, expected_type_class) 274 | 275 | 276 | class TestAzureMonitorBatchReader: 277 | """Test Azure Monitor Batch Reader implementation.""" 278 | 279 | @pytest.fixture 280 | def basic_options(self): 281 | """Basic valid options for reader.""" 282 | return { 283 | "workspace_id": "test-workspace-id", 284 | "query": "AzureActivity | take 5", 285 | "timespan": "P1D", 286 | "tenant_id": "test-tenant", 287 | "client_id": "test-client", 288 | "client_secret": "test-secret", 289 | } 290 | 291 | @pytest.fixture 292 | def basic_schema(self): 293 | """Basic schema for testing.""" 294 | return StructType( 295 | [StructField("TimeGenerated", StringType(), True), StructField("OperationName", StringType(), True)] 296 | ) 297 | 298 | def test_reader_initialization(self, basic_options, basic_schema): 299 | """Test reader initializes with valid options.""" 300 | reader = AzureMonitorBatchReader(basic_options, basic_schema) 301 | 302 | assert reader.workspace_id == "test-workspace-id" 303 | assert reader.query == "AzureActivity | take 5" 304 | assert reader.tenant_id == "test-tenant" 305 | assert reader.client_id == "test-client" 306 | assert reader.client_secret == "test-secret" 307 | assert reader.num_partitions == 1 308 | # Verify that start_time and end_time were set (from timespan) 309 | assert isinstance(reader.start_time, datetime) 310 | assert isinstance(reader.end_time, datetime) 311 | assert reader.start_time < reader.end_time 312 | 313 | def test_reader_optional_parameters(self, basic_options, basic_schema): 314 | """Test reader handles optional parameters.""" 315 | basic_options["num_partitions"] = "3" 316 | 317 | reader = AzureMonitorBatchReader(basic_options, basic_schema) 318 | 319 | assert reader.num_partitions == 3 320 | 321 | def test_reader_missing_workspace_id(self, basic_schema): 322 | """Test reader fails without workspace_id.""" 323 | options = { 324 | "query": "AzureActivity | take 5", 325 | "timespan": "P1D", 326 | "tenant_id": "test-tenant", 327 | "client_id": "test-client", 328 | "client_secret": "test-secret", 329 | } 330 | 331 | with pytest.raises(AssertionError, match="workspace_id is required"): 332 | AzureMonitorBatchReader(options, basic_schema) 333 | 334 | def test_reader_missing_query(self, basic_schema): 335 | """Test reader fails without query.""" 336 | options = { 337 | "workspace_id": "test-workspace-id", 338 | "timespan": "P1D", 339 | "tenant_id": "test-tenant", 340 | "client_id": "test-client", 341 | "client_secret": "test-secret", 342 | } 343 | 344 | with pytest.raises(AssertionError, match="query is required"): 345 | AzureMonitorBatchReader(options, basic_schema) 346 | 347 | def test_reader_missing_timespan(self, basic_schema): 348 | """Test reader fails without timespan (now requires either timespan or start_time).""" 349 | options = { 350 | "workspace_id": "test-workspace-id", 351 | "query": "AzureActivity | take 5", 352 | "tenant_id": "test-tenant", 353 | "client_id": 
"test-client", 354 | "client_secret": "test-secret", 355 | } 356 | 357 | with pytest.raises(Exception, match="Either 'timespan' or 'start_time' must be provided"): 358 | AzureMonitorBatchReader(options, basic_schema) 359 | 360 | def test_reader_missing_tenant_id(self, basic_schema): 361 | """Test reader fails without tenant_id.""" 362 | options = { 363 | "workspace_id": "test-workspace-id", 364 | "query": "AzureActivity | take 5", 365 | "timespan": "P1D", 366 | "client_id": "test-client", 367 | "client_secret": "test-secret", 368 | } 369 | 370 | with pytest.raises(AssertionError, match="tenant_id is required"): 371 | AzureMonitorBatchReader(options, basic_schema) 372 | 373 | def test_reader_missing_client_id(self, basic_schema): 374 | """Test reader fails without client_id.""" 375 | options = { 376 | "workspace_id": "test-workspace-id", 377 | "query": "AzureActivity | take 5", 378 | "timespan": "P1D", 379 | "tenant_id": "test-tenant", 380 | "client_secret": "test-secret", 381 | } 382 | 383 | with pytest.raises(AssertionError, match="client_id is required"): 384 | AzureMonitorBatchReader(options, basic_schema) 385 | 386 | def test_reader_missing_client_secret(self, basic_schema): 387 | """Test reader fails without client_secret.""" 388 | options = { 389 | "workspace_id": "test-workspace-id", 390 | "query": "AzureActivity | take 5", 391 | "timespan": "P1D", 392 | "tenant_id": "test-tenant", 393 | "client_id": "test-client", 394 | } 395 | 396 | with pytest.raises(AssertionError, match="client_secret is required"): 397 | AzureMonitorBatchReader(options, basic_schema) 398 | 399 | def test_reader_missing_time_range(self, basic_schema): 400 | """Test reader fails without timespan or start_time.""" 401 | options = { 402 | "workspace_id": "test-workspace-id", 403 | "query": "AzureActivity | take 5", 404 | "tenant_id": "test-tenant", 405 | "client_id": "test-client", 406 | "client_secret": "test-secret", 407 | } 408 | 409 | with pytest.raises(Exception, match="Either 'timespan' or 'start_time' must be provided"): 410 | AzureMonitorBatchReader(options, basic_schema) 411 | 412 | def test_reader_both_timespan_and_start_time(self, basic_schema): 413 | """Test that reader uses timespan when both are provided (timespan takes precedence).""" 414 | options = { 415 | "workspace_id": "test-workspace-id", 416 | "query": "AzureActivity | take 5", 417 | "timespan": "P1D", 418 | "start_time": "2024-01-01T00:00:00Z", 419 | "tenant_id": "test-tenant", 420 | "client_id": "test-client", 421 | "client_secret": "test-secret", 422 | } 423 | 424 | # Should succeed - timespan takes precedence over start_time in module-level function 425 | reader = AzureMonitorBatchReader(options, basic_schema) 426 | assert reader.start_time is not None 427 | assert reader.end_time is not None 428 | 429 | def test_reader_with_start_time_only(self, basic_schema): 430 | """Test reader initializes with only start_time (end_time defaults to now).""" 431 | options = { 432 | "workspace_id": "test-workspace-id", 433 | "query": "AzureActivity | take 5", 434 | "start_time": "2024-01-01T00:00:00Z", 435 | "tenant_id": "test-tenant", 436 | "client_id": "test-client", 437 | "client_secret": "test-secret", 438 | } 439 | 440 | reader = AzureMonitorBatchReader(options, basic_schema) 441 | assert isinstance(reader.start_time, datetime) 442 | assert isinstance(reader.end_time, datetime) 443 | assert reader.end_time is not None # Should be set to current time 444 | assert reader.start_time < reader.end_time 445 | 446 | def 
test_reader_with_start_and_end_time(self, basic_schema): 447 | """Test reader initializes with both start_time and end_time.""" 448 | options = { 449 | "workspace_id": "test-workspace-id", 450 | "query": "AzureActivity | take 5", 451 | "start_time": "2024-01-01T00:00:00Z", 452 | "end_time": "2024-01-02T00:00:00Z", 453 | "tenant_id": "test-tenant", 454 | "client_id": "test-client", 455 | "client_secret": "test-secret", 456 | } 457 | 458 | reader = AzureMonitorBatchReader(options, basic_schema) 459 | assert isinstance(reader.start_time, datetime) 460 | assert isinstance(reader.end_time, datetime) 461 | # Verify the times are approximately correct (allowing for timezone handling) 462 | assert reader.start_time.year == 2024 463 | assert reader.start_time.month == 1 464 | assert reader.start_time.day == 1 465 | assert reader.end_time.year == 2024 466 | assert reader.end_time.month == 1 467 | assert reader.end_time.day == 2 468 | 469 | def test_reader_with_end_time_only(self, basic_schema): 470 | """Test reader fails with only end_time (start_time is required).""" 471 | options = { 472 | "workspace_id": "test-workspace-id", 473 | "query": "AzureActivity | take 5", 474 | "end_time": "2024-01-02T00:00:00Z", 475 | "tenant_id": "test-tenant", 476 | "client_id": "test-client", 477 | "client_secret": "test-secret", 478 | } 479 | 480 | # Should fail because neither timespan nor start_time is provided 481 | with pytest.raises(Exception, match="Either 'timespan' or 'start_time' must be provided"): 482 | AzureMonitorBatchReader(options, basic_schema) 483 | 484 | def test_timespan_parsing_days(self): 485 | """Test timespan parsing for days using module-level function.""" 486 | from cyber_connectors.MsSentinel import _parse_time_range 487 | 488 | start, end = _parse_time_range(timespan="P1D") 489 | assert (end - start).days == 1 490 | 491 | start, end = _parse_time_range(timespan="P7D") 492 | assert (end - start).days == 7 493 | 494 | def test_timespan_parsing_hours(self): 495 | """Test timespan parsing for hours using module-level function.""" 496 | from cyber_connectors.MsSentinel import _parse_time_range 497 | 498 | start, end = _parse_time_range(timespan="PT1H") 499 | assert (end - start).total_seconds() == 3600 500 | 501 | start, end = _parse_time_range(timespan="PT24H") 502 | assert (end - start).total_seconds() == 86400 503 | 504 | def test_timespan_parsing_minutes(self): 505 | """Test timespan parsing for minutes using module-level function.""" 506 | from cyber_connectors.MsSentinel import _parse_time_range 507 | 508 | start, end = _parse_time_range(timespan="PT30M") 509 | assert (end - start).total_seconds() == 1800 510 | 511 | def test_timespan_parsing_seconds(self): 512 | """Test timespan parsing for seconds using module-level function.""" 513 | from cyber_connectors.MsSentinel import _parse_time_range 514 | 515 | start, end = _parse_time_range(timespan="PT120S") 516 | assert (end - start).total_seconds() == 120 517 | 518 | def test_timespan_parsing_combined(self): 519 | """Test timespan parsing for combined duration using module-level function.""" 520 | from cyber_connectors.MsSentinel import _parse_time_range 521 | 522 | # P1DT2H30M 523 | start, end = _parse_time_range(timespan="P1DT2H30M") 524 | total_seconds = (end - start).total_seconds() 525 | expected_seconds = 1 * 86400 + 2 * 3600 + 30 * 60 526 | assert total_seconds == expected_seconds 527 | 528 | def test_timespan_parsing_invalid(self): 529 | """Test timespan parsing with invalid format using module-level function.""" 530 | from 
cyber_connectors.MsSentinel import _parse_time_range 531 | 532 | with pytest.raises(ValueError, match="Invalid timespan format"): 533 | _parse_time_range(timespan="invalid") 534 | 535 | with pytest.raises(ValueError, match="Invalid timespan format"): 536 | _parse_time_range(timespan="1D") 537 | 538 | @patch("azure.monitor.query.LogsQueryClient") 539 | @patch("azure.identity.ClientSecretCredential") 540 | def test_read_success_single_partition(self, mock_credential, mock_client, basic_options, basic_schema): 541 | """Test successful data reading with single partition.""" 542 | from azure.monitor.query import LogsQueryStatus 543 | 544 | # Create mock response 545 | mock_response = Mock() 546 | mock_response.status = LogsQueryStatus.SUCCESS 547 | 548 | # Create mock table with columns and rows 549 | # table.columns is always a list of strings (column names) 550 | mock_table = Mock() 551 | mock_table.columns = ["TimeGenerated", "OperationName"] 552 | mock_table.rows = [ 553 | ["2024-01-01T00:00:00Z", "Read"], 554 | ["2024-01-01T01:00:00Z", "Write"], 555 | ["2024-01-01T02:00:00Z", "Delete"], 556 | ] 557 | mock_response.tables = [mock_table] 558 | 559 | # Setup mock client 560 | mock_client_instance = Mock() 561 | mock_client_instance.query_workspace.return_value = mock_response 562 | mock_client.return_value = mock_client_instance 563 | 564 | # Create reader and read 565 | reader = AzureMonitorBatchReader(basic_options, basic_schema) 566 | # Use the partitions method to get the partition 567 | partitions = reader.partitions() 568 | assert len(partitions) == 1 569 | rows = list(reader.read(partitions[0])) 570 | 571 | # Verify results 572 | assert len(rows) == 3 573 | assert rows[0].TimeGenerated == "2024-01-01T00:00:00Z" 574 | assert rows[0].OperationName == "Read" 575 | assert rows[1].TimeGenerated == "2024-01-01T01:00:00Z" 576 | assert rows[1].OperationName == "Write" 577 | assert rows[2].TimeGenerated == "2024-01-01T02:00:00Z" 578 | assert rows[2].OperationName == "Delete" 579 | 580 | # Verify query_workspace was called with correct parameters 581 | mock_client_instance.query_workspace.assert_called_once() 582 | call_kwargs = mock_client_instance.query_workspace.call_args[1] 583 | assert call_kwargs["workspace_id"] == "test-workspace-id" 584 | assert call_kwargs["query"] == "AzureActivity | take 5" 585 | 586 | @patch("azure.monitor.query.LogsQueryClient") 587 | @patch("azure.identity.ClientSecretCredential") 588 | def test_read_multiple_tables(self, mock_credential, mock_client, basic_options, basic_schema): 589 | """Test reading with multiple tables in response.""" 590 | from azure.monitor.query import LogsQueryStatus 591 | 592 | # Create mock response with two tables 593 | mock_response = Mock() 594 | mock_response.status = LogsQueryStatus.SUCCESS 595 | 596 | # First table 597 | mock_col1 = Mock() 598 | mock_col1.name = "Col1" 599 | mock_table1 = Mock() 600 | mock_table1.columns = [mock_col1] 601 | mock_table1.rows = [["value1"], ["value2"]] 602 | 603 | # Second table 604 | mock_col2 = Mock() 605 | mock_col2.name = "Col2" 606 | mock_table2 = Mock() 607 | mock_table2.columns = [mock_col2] 608 | mock_table2.rows = [["value3"]] 609 | 610 | mock_response.tables = [mock_table1, mock_table2] 611 | 612 | # Setup mock client 613 | mock_client_instance = Mock() 614 | mock_client_instance.query_workspace.return_value = mock_response 615 | mock_client.return_value = mock_client_instance 616 | 617 | # Create reader and read 618 | reader = AzureMonitorBatchReader(basic_options, basic_schema) 619 | 
partitions = reader.partitions() 620 | rows = list(reader.read(partitions[0])) 621 | 622 | # Should have rows from both tables 623 | assert len(rows) == 3 624 | 625 | @patch("azure.monitor.query.LogsQueryClient") 626 | @patch("azure.identity.ClientSecretCredential") 627 | def test_read_query_failure(self, mock_credential, mock_client, basic_options, basic_schema): 628 | """Test handling of query failure.""" 629 | from azure.monitor.query import LogsQueryStatus 630 | 631 | # Create mock response with failure status 632 | mock_response = Mock() 633 | mock_response.status = LogsQueryStatus.PARTIAL 634 | 635 | # Setup mock client 636 | mock_client_instance = Mock() 637 | mock_client_instance.query_workspace.return_value = mock_response 638 | mock_client.return_value = mock_client_instance 639 | 640 | # Create reader and attempt to read 641 | reader = AzureMonitorBatchReader(basic_options, basic_schema) 642 | partitions = reader.partitions() 643 | 644 | with pytest.raises(Exception, match="Query failed with status"): 645 | list(reader.read(partitions[0])) 646 | 647 | @patch("azure.monitor.query.LogsQueryClient") 648 | @patch("azure.identity.ClientSecretCredential") 649 | def test_read_empty_results(self, mock_credential, mock_client, basic_options, basic_schema): 650 | """Test reading with empty results.""" 651 | from azure.monitor.query import LogsQueryStatus 652 | 653 | # Create mock response with no rows 654 | mock_response = Mock() 655 | mock_response.status = LogsQueryStatus.SUCCESS 656 | mock_column = Mock() 657 | mock_column.name = "TestCol" 658 | mock_table = Mock() 659 | mock_table.columns = [mock_column] 660 | mock_table.rows = [] 661 | mock_response.tables = [mock_table] 662 | 663 | # Setup mock client 664 | mock_client_instance = Mock() 665 | mock_client_instance.query_workspace.return_value = mock_response 666 | mock_client.return_value = mock_client_instance 667 | 668 | # Create reader and read 669 | reader = AzureMonitorBatchReader(basic_options, basic_schema) 670 | partitions = reader.partitions() 671 | rows = list(reader.read(partitions[0])) 672 | 673 | # Should have no rows 674 | assert len(rows) == 0 675 | 676 | def test_partitions_generation(self, basic_options, basic_schema): 677 | """Test that partitions method generates correct time ranges.""" 678 | basic_options["num_partitions"] = "3" 679 | basic_options["start_time"] = "2024-01-01T00:00:00Z" 680 | basic_options["end_time"] = "2024-01-01T03:00:00Z" 681 | del basic_options["timespan"] # Remove timespan to use start/end time 682 | 683 | reader = AzureMonitorBatchReader(basic_options, basic_schema) 684 | partitions = reader.partitions() 685 | 686 | # Should have 3 partitions 687 | assert len(partitions) == 3 688 | 689 | # Each partition should be approximately 1 hour 690 | # Partition 0: 00:00 - 01:00 691 | # Partition 1: 01:00 - 02:00 692 | # Partition 2: 02:00 - 03:00 693 | assert partitions[0].start_time == reader.start_time 694 | assert partitions[2].end_time == reader.end_time 695 | 696 | # Verify partitions are contiguous and non-overlapping 697 | assert partitions[0].end_time == partitions[1].start_time 698 | assert partitions[1].end_time == partitions[2].start_time 699 | 700 | @patch("azure.monitor.query.LogsQueryClient") 701 | @patch("azure.identity.ClientSecretCredential") 702 | def test_read_multiple_partitions(self, mock_credential, mock_client, basic_options, basic_schema): 703 | """Test reading with multiple partitions (each partition queries independently).""" 704 | from azure.monitor.query import 
LogsQueryStatus 705 | 706 | basic_options["num_partitions"] = "3" 707 | 708 | # Setup mock client to return different results based on timespan 709 | mock_client_instance = Mock() 710 | 711 | def query_side_effect(*args, **kwargs): 712 | # Each partition will get a different response 713 | mock_response = Mock() 714 | mock_response.status = LogsQueryStatus.SUCCESS 715 | mock_column = Mock() 716 | mock_column.name = "Value" 717 | mock_table = Mock() 718 | mock_table.columns = [mock_column] 719 | # Return 3 rows per partition 720 | mock_table.rows = [["row1"], ["row2"], ["row3"]] 721 | mock_response.tables = [mock_table] 722 | return mock_response 723 | 724 | mock_client_instance.query_workspace.side_effect = query_side_effect 725 | mock_client.return_value = mock_client_instance 726 | 727 | # Create reader 728 | reader = AzureMonitorBatchReader(basic_options, basic_schema) 729 | partitions = reader.partitions() 730 | 731 | # Verify we have 3 partitions 732 | assert len(partitions) == 3 733 | 734 | # Read from each partition independently 735 | rows0 = list(reader.read(partitions[0])) 736 | rows1 = list(reader.read(partitions[1])) 737 | rows2 = list(reader.read(partitions[2])) 738 | 739 | # Each partition should have returned 3 rows 740 | assert len(rows0) == 3 741 | assert len(rows1) == 3 742 | assert len(rows2) == 3 743 | 744 | # Verify query was called 3 times (once per partition) 745 | assert mock_client_instance.query_workspace.call_count == 3 746 | 747 | # Verify each query was called with different timespan ranges 748 | call_args_list = mock_client_instance.query_workspace.call_args_list 749 | timespan0 = call_args_list[0][1]["timespan"] 750 | timespan1 = call_args_list[1][1]["timespan"] 751 | timespan2 = call_args_list[2][1]["timespan"] 752 | 753 | # Verify timespans are tuples (start, end) 754 | assert isinstance(timespan0, tuple) and len(timespan0) == 2 755 | assert isinstance(timespan1, tuple) and len(timespan1) == 2 756 | assert isinstance(timespan2, tuple) and len(timespan2) == 2 757 | 758 | # Verify partitions are non-overlapping 759 | assert timespan0[0] == partitions[0].start_time 760 | assert timespan0[1] == partitions[0].end_time 761 | assert timespan1[0] == partitions[1].start_time 762 | assert timespan1[1] == partitions[1].end_time 763 | assert timespan2[0] == partitions[2].start_time 764 | assert timespan2[1] == partitions[2].end_time 765 | 766 | @patch("azure.monitor.query.LogsQueryClient") 767 | @patch("azure.identity.ClientSecretCredential") 768 | def test_read_with_start_and_end_time(self, mock_credential, mock_client, basic_schema): 769 | """Test reading with start_time and end_time instead of timespan.""" 770 | from azure.monitor.query import LogsQueryStatus 771 | 772 | options = { 773 | "workspace_id": "test-workspace-id", 774 | "query": "AzureActivity | take 5", 775 | "start_time": "2024-01-01T00:00:00Z", 776 | "end_time": "2024-01-02T00:00:00Z", 777 | "tenant_id": "test-tenant", 778 | "client_id": "test-client", 779 | "client_secret": "test-secret", 780 | } 781 | 782 | # Create mock response 783 | mock_response = Mock() 784 | mock_response.status = LogsQueryStatus.SUCCESS 785 | mock_column = Mock() 786 | mock_column.name = "TestCol" 787 | mock_table = Mock() 788 | mock_table.columns = [mock_column] 789 | mock_table.rows = [["value1"], ["value2"]] 790 | mock_response.tables = [mock_table] 791 | 792 | # Setup mock client 793 | mock_client_instance = Mock() 794 | mock_client_instance.query_workspace.return_value = mock_response 795 | mock_client.return_value = 
mock_client_instance 796 | 797 | # Create reader and read 798 | reader = AzureMonitorBatchReader(options, basic_schema) 799 | partitions = reader.partitions() 800 | rows = list(reader.read(partitions[0])) 801 | 802 | # Verify results 803 | assert len(rows) == 2 804 | 805 | # Verify query_workspace was called with tuple timespan (start, end) 806 | call_kwargs = mock_client_instance.query_workspace.call_args[1] 807 | assert "timespan" in call_kwargs 808 | timespan_arg = call_kwargs["timespan"] 809 | assert isinstance(timespan_arg, tuple) 810 | assert len(timespan_arg) == 2 811 | 812 | @patch("azure.monitor.query.LogsQueryClient") 813 | @patch("azure.identity.ClientSecretCredential") 814 | def test_read_with_start_time_only(self, mock_credential, mock_client, basic_schema): 815 | """Test reading with only start_time (end_time defaults to now).""" 816 | from azure.monitor.query import LogsQueryStatus 817 | 818 | options = { 819 | "workspace_id": "test-workspace-id", 820 | "query": "AzureActivity | take 5", 821 | "start_time": "2024-01-01T00:00:00Z", 822 | "tenant_id": "test-tenant", 823 | "client_id": "test-client", 824 | "client_secret": "test-secret", 825 | } 826 | 827 | # Create mock response 828 | mock_response = Mock() 829 | mock_response.status = LogsQueryStatus.SUCCESS 830 | mock_column = Mock() 831 | mock_column.name = "TestCol" 832 | mock_table = Mock() 833 | mock_table.columns = [mock_column] 834 | mock_table.rows = [["value1"]] 835 | mock_response.tables = [mock_table] 836 | 837 | # Setup mock client 838 | mock_client_instance = Mock() 839 | mock_client_instance.query_workspace.return_value = mock_response 840 | mock_client.return_value = mock_client_instance 841 | 842 | # Create reader and read 843 | reader = AzureMonitorBatchReader(options, basic_schema) 844 | partitions = reader.partitions() 845 | rows = list(reader.read(partitions[0])) 846 | 847 | # Verify results 848 | assert len(rows) == 1 849 | 850 | # Verify end_time was set automatically 851 | assert isinstance(reader.end_time, datetime) 852 | assert isinstance(reader.start_time, datetime) 853 | assert reader.start_time < reader.end_time 854 | 855 | # Verify query_workspace was called with tuple timespan 856 | call_kwargs = mock_client_instance.query_workspace.call_args[1] 857 | timespan_arg = call_kwargs["timespan"] 858 | assert isinstance(timespan_arg, tuple) 859 | assert len(timespan_arg) == 2 860 | 861 | @patch("azure.monitor.query.LogsQueryClient") 862 | @patch("azure.identity.ClientSecretCredential") 863 | def test_type_conversion_string_to_int(self, mock_credential, mock_client, basic_options): 864 | """Test that string values are converted to int when schema specifies LongType.""" 865 | from azure.monitor.query import LogsQueryStatus 866 | from pyspark.sql.types import LongType, StructField, StructType 867 | 868 | # Schema expects int 869 | schema = StructType([StructField("Count", LongType(), True)]) 870 | 871 | # Create mock response with string value 872 | mock_response = Mock() 873 | mock_response.status = LogsQueryStatus.SUCCESS 874 | mock_table = Mock() 875 | mock_table.columns = ["Count"] 876 | mock_table.rows = [["123"], ["456"]] # String values 877 | mock_response.tables = [mock_table] 878 | 879 | mock_client_instance = Mock() 880 | mock_client_instance.query_workspace.return_value = mock_response 881 | mock_client.return_value = mock_client_instance 882 | 883 | reader = AzureMonitorBatchReader(basic_options, schema) 884 | partitions = reader.partitions() 885 | rows = list(reader.read(partitions[0])) 886 | 
887 | # Verify conversion to int 888 | assert len(rows) == 2 889 | assert rows[0].Count == 123 890 | assert isinstance(rows[0].Count, int) 891 | assert rows[1].Count == 456 892 | assert isinstance(rows[1].Count, int) 893 | 894 | @patch("azure.monitor.query.LogsQueryClient") 895 | @patch("azure.identity.ClientSecretCredential") 896 | def test_type_conversion_string_to_bool(self, mock_credential, mock_client, basic_options): 897 | """Test that string values are converted to bool when schema specifies BooleanType.""" 898 | from azure.monitor.query import LogsQueryStatus 899 | from pyspark.sql.types import BooleanType, StructField, StructType 900 | 901 | schema = StructType([StructField("IsActive", BooleanType(), True)]) 902 | 903 | mock_response = Mock() 904 | mock_response.status = LogsQueryStatus.SUCCESS 905 | mock_table = Mock() 906 | mock_table.columns = ["IsActive"] 907 | mock_table.rows = [["true"], ["false"], ["1"], ["0"]] 908 | mock_response.tables = [mock_table] 909 | 910 | mock_client_instance = Mock() 911 | mock_client_instance.query_workspace.return_value = mock_response 912 | mock_client.return_value = mock_client_instance 913 | 914 | reader = AzureMonitorBatchReader(basic_options, schema) 915 | partitions = reader.partitions() 916 | rows = list(reader.read(partitions[0])) 917 | 918 | assert len(rows) == 4 919 | assert rows[0].IsActive is True 920 | assert rows[1].IsActive is False 921 | assert rows[2].IsActive is True 922 | assert rows[3].IsActive is False 923 | 924 | @patch("azure.monitor.query.LogsQueryClient") 925 | @patch("azure.identity.ClientSecretCredential") 926 | def test_type_conversion_string_to_timestamp(self, mock_credential, mock_client, basic_options): 927 | """Test that string values are converted to datetime when schema specifies TimestampType.""" 928 | from azure.monitor.query import LogsQueryStatus 929 | from pyspark.sql.types import StructField, StructType, TimestampType 930 | 931 | schema = StructType([StructField("Timestamp", TimestampType(), True)]) 932 | 933 | mock_response = Mock() 934 | mock_response.status = LogsQueryStatus.SUCCESS 935 | mock_table = Mock() 936 | mock_table.columns = ["Timestamp"] 937 | mock_table.rows = [["2024-01-01T00:00:00Z"], ["2024-12-31T23:59:59Z"]] 938 | mock_response.tables = [mock_table] 939 | 940 | mock_client_instance = Mock() 941 | mock_client_instance.query_workspace.return_value = mock_response 942 | mock_client.return_value = mock_client_instance 943 | 944 | reader = AzureMonitorBatchReader(basic_options, schema) 945 | partitions = reader.partitions() 946 | rows = list(reader.read(partitions[0])) 947 | 948 | assert len(rows) == 2 949 | assert isinstance(rows[0].Timestamp, datetime) 950 | assert isinstance(rows[1].Timestamp, datetime) 951 | 952 | @patch("azure.monitor.query.LogsQueryClient") 953 | @patch("azure.identity.ClientSecretCredential") 954 | def test_type_conversion_invalid_raises_error(self, mock_credential, mock_client, basic_options): 955 | """Test that invalid type conversions raise ValueError with descriptive message.""" 956 | from azure.monitor.query import LogsQueryStatus 957 | from pyspark.sql.types import LongType, StructField, StructType 958 | 959 | schema = StructType([StructField("Count", LongType(), True)]) 960 | 961 | # Create mock response with non-convertible value 962 | mock_response = Mock() 963 | mock_response.status = LogsQueryStatus.SUCCESS 964 | mock_table = Mock() 965 | mock_table.columns = ["Count"] 966 | mock_table.rows = [["not-a-number"]] # Cannot convert to int 967 | 
mock_response.tables = [mock_table] 968 | 969 | mock_client_instance = Mock() 970 | mock_client_instance.query_workspace.return_value = mock_response 971 | mock_client.return_value = mock_client_instance 972 | 973 | reader = AzureMonitorBatchReader(basic_options, schema) 974 | partitions = reader.partitions() 975 | 976 | # Should raise ValueError with row/column info 977 | with pytest.raises(ValueError) as exc_info: 978 | list(reader.read(partitions[0])) 979 | 980 | assert "Row 0" in str(exc_info.value) 981 | assert "Count" in str(exc_info.value) 982 | 983 | @patch("azure.monitor.query.LogsQueryClient") 984 | @patch("azure.identity.ClientSecretCredential") 985 | def test_type_conversion_mixed_types(self, mock_credential, mock_client, basic_options): 986 | """Test converting multiple columns with different types.""" 987 | from azure.monitor.query import LogsQueryStatus 988 | from pyspark.sql.types import BooleanType, DoubleType, LongType, StringType, StructField, StructType 989 | 990 | schema = StructType( 991 | [ 992 | StructField("Name", StringType(), True), 993 | StructField("Count", LongType(), True), 994 | StructField("Score", DoubleType(), True), 995 | StructField("Active", BooleanType(), True), 996 | ] 997 | ) 998 | 999 | mock_response = Mock() 1000 | mock_response.status = LogsQueryStatus.SUCCESS 1001 | mock_table = Mock() 1002 | mock_table.columns = ["Name", "Count", "Score", "Active"] 1003 | mock_table.rows = [ 1004 | ["Alice", "100", "95.5", "true"], 1005 | ["Bob", "200", "87.3", "false"], 1006 | ] 1007 | mock_response.tables = [mock_table] 1008 | 1009 | mock_client_instance = Mock() 1010 | mock_client_instance.query_workspace.return_value = mock_response 1011 | mock_client.return_value = mock_client_instance 1012 | 1013 | reader = AzureMonitorBatchReader(basic_options, schema) 1014 | partitions = reader.partitions() 1015 | rows = list(reader.read(partitions[0])) 1016 | 1017 | assert len(rows) == 2 1018 | # First row 1019 | assert rows[0].Name == "Alice" 1020 | assert rows[0].Count == 100 1021 | assert isinstance(rows[0].Count, int) 1022 | assert rows[0].Score == 95.5 1023 | assert isinstance(rows[0].Score, float) 1024 | assert rows[0].Active is True 1025 | # Second row 1026 | assert rows[1].Name == "Bob" 1027 | assert rows[1].Count == 200 1028 | assert rows[1].Score == 87.3 1029 | assert rows[1].Active is False 1030 | 1031 | @patch("azure.monitor.query.LogsQueryClient") 1032 | @patch("azure.identity.ClientSecretCredential") 1033 | def test_type_conversion_preserves_none(self, mock_credential, mock_client, basic_options): 1034 | """Test that None/NULL values are preserved regardless of schema type.""" 1035 | from azure.monitor.query import LogsQueryStatus 1036 | from pyspark.sql.types import LongType, StringType, StructField, StructType 1037 | 1038 | schema = StructType( 1039 | [ 1040 | StructField("Name", StringType(), True), 1041 | StructField("Count", LongType(), True), 1042 | ] 1043 | ) 1044 | 1045 | mock_response = Mock() 1046 | mock_response.status = LogsQueryStatus.SUCCESS 1047 | mock_table = Mock() 1048 | mock_table.columns = ["Name", "Count"] 1049 | mock_table.rows = [ 1050 | ["Alice", None], 1051 | [None, "123"], 1052 | ] 1053 | mock_response.tables = [mock_table] 1054 | 1055 | mock_client_instance = Mock() 1056 | mock_client_instance.query_workspace.return_value = mock_response 1057 | mock_client.return_value = mock_client_instance 1058 | 1059 | reader = AzureMonitorBatchReader(basic_options, schema) 1060 | partitions = reader.partitions() 1061 | rows = 
list(reader.read(partitions[0])) 1062 | 1063 | assert len(rows) == 2 1064 | assert rows[0].Name == "Alice" 1065 | assert rows[0].Count is None 1066 | assert rows[1].Name is None 1067 | assert rows[1].Count == 123 1068 | 1069 | @patch("azure.monitor.query.LogsQueryClient") 1070 | @patch("azure.identity.ClientSecretCredential") 1071 | def test_type_conversion_missing_columns_set_to_null(self, mock_credential, mock_client, basic_options): 1072 | """Test that columns in schema but not in results are set to NULL.""" 1073 | from azure.monitor.query import LogsQueryStatus 1074 | from pyspark.sql.types import LongType, StringType, StructField, StructType 1075 | 1076 | # Schema expects Name, Count, and Extra columns 1077 | schema = StructType( 1078 | [ 1079 | StructField("Name", StringType(), True), 1080 | StructField("Count", LongType(), True), 1081 | StructField("Extra", StringType(), True), 1082 | ] 1083 | ) 1084 | 1085 | # Query results only have Name and Count (missing Extra) 1086 | mock_response = Mock() 1087 | mock_response.status = LogsQueryStatus.SUCCESS 1088 | mock_table = Mock() 1089 | mock_table.columns = ["Name", "Count"] # Extra is missing 1090 | mock_table.rows = [ 1091 | ["Alice", "100"], 1092 | ["Bob", "200"], 1093 | ] 1094 | mock_response.tables = [mock_table] 1095 | 1096 | mock_client_instance = Mock() 1097 | mock_client_instance.query_workspace.return_value = mock_response 1098 | mock_client.return_value = mock_client_instance 1099 | 1100 | reader = AzureMonitorBatchReader(basic_options, schema) 1101 | partitions = reader.partitions() 1102 | rows = list(reader.read(partitions[0])) 1103 | 1104 | # Verify all schema columns are present, missing ones are None 1105 | assert len(rows) == 2 1106 | assert rows[0].Name == "Alice" 1107 | assert rows[0].Count == 100 1108 | assert rows[0].Extra is None # Missing in query results 1109 | assert rows[1].Name == "Bob" 1110 | assert rows[1].Count == 200 1111 | assert rows[1].Extra is None # Missing in query results 1112 | 1113 | @patch("azure.monitor.query.LogsQueryClient") 1114 | @patch("azure.identity.ClientSecretCredential") 1115 | def test_type_conversion_extra_columns_ignored(self, mock_credential, mock_client, basic_options): 1116 | """Test that columns in results but not in schema are ignored.""" 1117 | from azure.monitor.query import LogsQueryStatus 1118 | from pyspark.sql.types import StringType, StructField, StructType 1119 | 1120 | # Schema only expects Name 1121 | schema = StructType([StructField("Name", StringType(), True)]) 1122 | 1123 | # Query results have Name and Extra (Extra not in schema) 1124 | mock_response = Mock() 1125 | mock_response.status = LogsQueryStatus.SUCCESS 1126 | mock_table = Mock() 1127 | mock_table.columns = ["Name", "Extra"] 1128 | mock_table.rows = [ 1129 | ["Alice", "value1"], 1130 | ["Bob", "value2"], 1131 | ] 1132 | mock_response.tables = [mock_table] 1133 | 1134 | mock_client_instance = Mock() 1135 | mock_client_instance.query_workspace.return_value = mock_response 1136 | mock_client.return_value = mock_client_instance 1137 | 1138 | reader = AzureMonitorBatchReader(basic_options, schema) 1139 | partitions = reader.partitions() 1140 | rows = list(reader.read(partitions[0])) 1141 | 1142 | # Verify only schema columns are present 1143 | assert len(rows) == 2 1144 | assert rows[0].Name == "Alice" 1145 | assert not hasattr(rows[0], "Extra") # Extra column ignored 1146 | assert rows[1].Name == "Bob" 1147 | assert not hasattr(rows[1], "Extra") # Extra column ignored 1148 | 
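# ---------------------------------------------------------------------------
# Illustrative usage sketch (not exercised by the unit tests above).
# It shows how the reader under test is expected to be driven through the
# Spark Python Data Source API once the package is registered, mirroring the
# options used throughout this module (workspace_id, query, timespan /
# start_time / end_time, tenant_id, client_id, client_secret, num_partitions).
# Assumptions, not asserted anywhere in this test module: the data source is
# registered under the short name "azure-monitor", and a Spark version with
# the Python Data Source API (spark.dataSource.register) is available.
# Placeholder values in angle brackets are hypothetical.
def _example_azure_monitor_read(spark):
    """Hedged sketch; requires a real SparkSession and valid Azure credentials."""
    from cyber_connectors import AzureMonitorDataSource

    # Register the Python data source with the current session.
    spark.dataSource.register(AzureMonitorDataSource)

    # Read one day of AzureActivity, split into 4 contiguous, non-overlapping
    # time slices (one query per partition, as in test_read_multiple_partitions).
    return (
        spark.read.format("azure-monitor")  # short name is an assumption
        .option("workspace_id", "<workspace-id>")
        .option("query", "AzureActivity | take 100")
        .option("timespan", "P1D")  # ISO-8601 duration, as parsed by _parse_time_range
        .option("num_partitions", "4")
        .option("tenant_id", "<tenant-id>")
        .option("client_id", "<client-id>")
        .option("client_secret", "<client-secret>")
        .load()
    )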
--------------------------------------------------------------------------------