├── .dockerignore ├── .gitignore ├── CONTRIBUTING.md ├── LICENCE.md ├── README.md ├── assemblyline_core ├── __init__.py ├── alerter │ ├── __init__.py │ ├── processing.py │ └── run_alerter.py ├── archiver │ ├── __init__.py │ └── run_archiver.py ├── badlist_client.py ├── dispatching │ ├── __init__.py │ ├── __main__.py │ ├── client.py │ ├── dispatcher.py │ ├── schedules.py │ └── timeout.py ├── expiry │ ├── __init__.py │ └── run_expiry.py ├── ingester │ ├── __init__.py │ ├── __main__.py │ ├── constants.py │ └── ingester.py ├── metrics │ ├── __init__.py │ ├── es_metrics.py │ ├── heartbeat_formatter.py │ ├── helper.py │ ├── metrics_server.py │ ├── run_heartbeat_manager.py │ ├── run_metrics_aggregator.py │ └── run_statistics_aggregator.py ├── plumber │ ├── __init__.py │ └── run_plumber.py ├── replay │ ├── __init__.py │ ├── client.py │ ├── creator │ │ ├── __init__.py │ │ ├── run.py │ │ └── run_worker.py │ ├── loader │ │ ├── __init__.py │ │ ├── run.py │ │ └── run_worker.py │ └── replay.py ├── safelist_client.py ├── scaler │ ├── __init__.py │ ├── collection.py │ ├── controllers │ │ ├── __init__.py │ │ ├── docker_ctl.py │ │ ├── interface.py │ │ └── kubernetes_ctl.py │ ├── run_scaler.py │ └── scaler_server.py ├── server_base.py ├── signature_client.py ├── submission_client.py ├── tasking_client.py ├── updater │ ├── __init__.py │ ├── helper.py │ └── run_updater.py ├── vacuum │ ├── __init__.py │ ├── crawler.py │ ├── department_map.py │ ├── safelist.py │ ├── stream_map.py │ └── worker.py └── workflow │ ├── __init__.py │ └── run_workflow.py ├── codecov.yml ├── deployment └── Dockerfile ├── pipelines ├── azure-tests.yaml └── config.yml ├── setup.cfg ├── setup.py └── test ├── classification.yml ├── conftest.py ├── docker-compose.yml ├── mocking ├── __init__.py └── random_service.py ├── requirements.txt ├── test_alerter.py ├── test_badlist_client.py ├── test_dispatcher.py ├── test_expiry.py ├── test_plumber.py ├── test_replay.py ├── test_safelist_client.py ├── test_scaler.py ├── test_scheduler.py ├── test_signature_client.py ├── test_simulation.py ├── test_tasking_client.py ├── test_vacuum.py ├── test_worker_ingest.py ├── test_worker_submit.py └── test_workflow.py /.dockerignore: -------------------------------------------------------------------------------- 1 | Dockerfile 2 | .idea 3 | .git 4 | 5 | pipelines 6 | venv 7 | env 8 | test 9 | tests 10 | exemples 11 | docs 12 | 13 | pip-log.txt 14 | pip-delete-this-directory.txt 15 | .tox 16 | .coverage 17 | .coverage.* 18 | .cache 19 | nosetests.xml 20 | coverage.xml 21 | *,cover 22 | *.log 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .mypy_cache 6 | 7 | # C extensions 8 | *.so 9 | 10 | # IDE files 11 | .pydevproject 12 | .python-version 13 | .idea 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | VERSION 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # PyBuilder 66 | .pybuilder/ 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # IPython 73 | profile_default/ 74 | ipython_config.py 75 | 76 | # Environments 77 | .env 78 | .venv 79 | env/ 80 | venv/ 81 | ENV/ 82 | env.bak/ 83 | venv.bak/ 84 | 85 | # Cython debug symbols 86 | cython_debug/ -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Assemblyline contributing guide 2 | 3 | This guide covers the basics of how to contribute to the Assemblyline project. 4 | 5 | Python code should follow the PEP8 guidelines defined here: [PEP8 Guidelines](https://www.python.org/dev/peps/pep-0008/). 6 | 7 | ## Tell us what you want to build/fix 8 | Before you start coding anything, you should connect with the Assemblyline community via the [Assemblyline Discord server](https://discord.gg/GUAy9wErNu) and/or the [central Assemblyline GitHub project](https://github.com/CybercentreCanada/assemblyline/issues) to make sure no one else is working on the same thing and that whatever you are going to build still fits with the vision of the system. 9 | 10 | ## Git workflow 11 | 12 | - Fork the repo to your own account 13 | - Check out and pull the latest commits from the master branch 14 | - Make a branch 15 | - Work in any way you like and make sure your changes actually work 16 | - When you're satisfied with your changes, create a pull request to the main assemblyline repo 17 | 18 | #### Transfer your service repo 19 | If you've worked on a new service that you want to be included in the default service selection, you'll have to transfer the repo into our control. 20 | 21 | #### You are not allowed to merge: 22 | 23 | Even if you try to merge in your pull request, you will be denied. Only a few people in our team are allowed to merge code into our repositories. 24 | 25 | We check for new pull requests every day and will merge them in once they have been approved by someone in our team. 26 | -------------------------------------------------------------------------------- /LICENCE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Crown Copyright, Government of Canada (Canadian Centre for Cyber Security / Communications Security Establishment) 4 | 5 | Copyright title to all 3rd party software distributed with Assemblyline (AL) is held by the respective copyright holders as noted in those files. Users are asked to read the 3rd Party Licenses referenced with those assets.
6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Discord](https://img.shields.io/badge/chat-on%20discord-7289da.svg?sanitize=true)](https://discord.gg/GUAy9wErNu) 2 | [![](https://img.shields.io/discord/908084610158714900)](https://discord.gg/GUAy9wErNu) 3 | [![Static Badge](https://img.shields.io/badge/github-assemblyline-blue?logo=github)](https://github.com/CybercentreCanada/assemblyline) 4 | [![Static Badge](https://img.shields.io/badge/github-assemblyline--core-blue?logo=github)](https://github.com/CybercentreCanada/assemblyline-core) 5 | [![GitHub Issues or Pull Requests by label](https://img.shields.io/github/issues/CybercentreCanada/assemblyline/core)](https://github.com/CybercentreCanada/assemblyline/issues?q=is:issue+is:open+label:core) 6 | [![License](https://img.shields.io/github/license/CybercentreCanada/assemblyline-core)](./LICENCE.md) 7 | 8 | # Assemblyline 4 - Core 9 | 10 | This repository provides core services for Assemblyline 4. 11 | 12 | ## Image variants and tags 13 | 14 | | **Tag Type** | **Description** | **Example Tag** | 15 | | :----------: | :----------------------------------------------------------------------------------------------- | :------------------------: | 16 | | latest | The most recent build (can be unstable). | `latest` | 17 | | build_type | The type of build used. `dev` is the latest unstable build. `stable` is the latest stable build. | `stable` or `dev` | 18 | | series | Complete build details, including version and build type: `version.buildType`. | `4.5.stable`, `4.5.1.dev3` | 19 | 20 | ## Components 21 | 22 | ### Alerter 23 | 24 | Create alerts for the different submissions in the system. 25 | 26 | ```bash 27 | docker run --name alerter cccs/assemblyline-core python -m assemblyline_core.alerter.run_alerter 28 | ``` 29 | 30 | ### Archiver 31 | 32 | Archives submissions and their results & files into the archive. 33 | 34 | ```bash 35 | docker run --name archiver cccs/assemblyline-core python -m assemblyline_core.archiver.run_archiver 36 | ``` 37 | 38 | ### Dispatcher 39 | 40 | Route the files in the system while a submission is taking place. Make sure all files during a submission are completed by all required services.
41 | 42 | ```bash 43 | docker run --name dispatcher cccs/assemblyline-core python -m assemblyline_core.dispatching 44 | ``` 45 | 46 | ### Expiry 47 | 48 | Delete submissions and their results when their time-to-live expires. 49 | 50 | ```bash 51 | docker run --name expiry cccs/assemblyline-core python -m assemblyline_core.expiry.run_expiry 52 | ``` 53 | 54 | ### Ingester 55 | 56 | Move ingested files from the priority queues to the processing queues. 57 | 58 | ```bash 59 | docker run --name ingester cccs/assemblyline-core python -m assemblyline_core.ingester 60 | ``` 61 | 62 | ### Metrics 63 | 64 | Generates metrics of the different components in the system. 65 | 66 | #### Heartbeat Manager 67 | 68 | ```bash 69 | docker run --name heartbeat cccs/assemblyline-core python -m assemblyline_core.metrics.run_heartbeat_manager 70 | ``` 71 | 72 | #### Metrics Aggregator 73 | 74 | ```bash 75 | docker run --name metrics cccs/assemblyline-core python -m assemblyline_core.metrics.run_metrics_aggregator 76 | ``` 77 | 78 | #### Statistics Aggregator 79 | 80 | ```bash 81 | docker run --name statistics cccs/assemblyline-core python -m assemblyline_core.metrics.run_statistics_aggregator 82 | ``` 83 | 84 | ### Scaler 85 | 86 | Spin up and down services in the system depending on the load. 87 | 88 | ```bash 89 | docker run --name scaler cccs/assemblyline-core python -m assemblyline_core.scaler.run_scaler 90 | ``` 91 | 92 | ### Updater 93 | 94 | Make sure the different services get their latest update files. 95 | 96 | ```bash 97 | docker run --name updater cccs/assemblyline-core python -m assemblyline_core.updater.run_updater 98 | ``` 99 | 100 | ### Workflow 101 | 102 | Run the different workflows in the system and apply their labels, priority and status. 103 | 104 | ```bash 105 | docker run --name workflow cccs/assemblyline-core python -m assemblyline_core.workflow.run_workflow 106 | ``` 107 | 108 | ## Documentation 109 | 110 | For more information about these Assemblyline components, follow this [overview](https://cybercentrecanada.github.io/assemblyline4_docs/overview/architecture/) of the system's architecture. 111 | 112 | --- 113 | 114 | # Assemblyline 4 - Core 115 | 116 | Ce dépôt fournit des services de base pour Assemblyline 4. 117 | 118 | ## Variantes et étiquettes d'image 119 | 120 | | **Type d'étiquette** | **Description** | **Exemple d'étiquette** | 121 | | :------------------: | :--------------------------------------------------------------------------------------------------------------- | :------------------------: | 122 | | dernière | La version la plus récente (peut être instable). | `latest` | 123 | | build_type | Le type de compilation utilisé. `dev` est la dernière version instable. `stable` est la dernière version stable. | `stable` ou `dev` | 124 | | séries | Le détail de compilation utilisé, incluant la version et le type de compilation : `version.buildType`. | `4.5.stable`, `4.5.1.dev3` | 125 | 126 | ## Composants 127 | 128 | ### Alerter 129 | 130 | Crée des alertes pour les différentes soumissions dans le système. 131 | 132 | ```bash 133 | docker run --name alerter cccs/assemblyline-core python -m assemblyline_core.alerter.run_alerter 134 | ``` 135 | 136 | ### Archiver 137 | 138 | Archivage des soumissions, de leurs résultats et des fichiers dans l'archive. 
139 | 140 | ```bash 141 | docker run --name archiver cccs/assemblyline-core python -m assemblyline_core.archiver.run_archiver 142 | ``` 143 | 144 | ### Dispatcher 145 | 146 | Achemine les fichiers dans le système durant une soumission. S'assure que tous les fichiers de la soumission courante soient complétés par tous les services requis. 147 | 148 | ```bash 149 | docker run --name dispatcher cccs/assemblyline-core python -m assemblyline_core.dispatching 150 | ``` 151 | 152 | ### Expiration 153 | 154 | Supprimer les soumissions et leurs résultats à l'expiration de leur durée de vie. 155 | 156 | ```bash 157 | docker run --name expiry cccs/assemblyline-core python -m assemblyline_core.expiry.run_expiry 158 | ``` 159 | 160 | ### Ingester 161 | 162 | Déplace les fichiers ingérés des files d'attente prioritaires vers les files d'attente de traitement. 163 | 164 | ```bash 165 | docker run --name ingester cccs/assemblyline-core python -m assemblyline_core.ingester 166 | ``` 167 | 168 | ### Métriques 169 | 170 | Génère des métriques des différents composants du système. 171 | 172 | #### Heartbeat Manager 173 | 174 | ```bash 175 | docker run --name heartbeat cccs/assemblyline-core python -m assemblyline_core.metrics.run_heartbeat_manager 176 | ``` 177 | 178 | #### Agrégateur de métriques 179 | 180 | ```bash 181 | docker run --name metrics cccs/assemblyline-core python -m assemblyline_core.metrics.run_metrics_aggregator 182 | ``` 183 | 184 | ##### Agrégateur de statistiques 185 | 186 | ```bash 187 | docker run --name statistics cccs/assemblyline-core python -m assemblyline_core.metrics.run_statistics_aggregator 188 | ``` 189 | 190 | ### Scaler 191 | 192 | Augmente et diminue les services dans le système en fonction de la charge. 193 | 194 | ```bash 195 | docker run --name scaler cccs/assemblyline-core python -m assemblyline_core.scaler.run_scaler 196 | ``` 197 | 198 | ### Mise à jour 199 | 200 | Assure que les différents services reçoivent leurs derniers fichiers de mise à jour. 201 | 202 | ```bash 203 | docker run --name updater cccs/assemblyline-core python -m assemblyline_core.updater.run_updater 204 | ``` 205 | 206 | ### Workflow 207 | 208 | Exécute les différents flux de travail dans le système et appliquer leurs étiquettes, leur priorité et leur statut. 209 | 210 | ```bash 211 | docker run --name workflow cccs/assemblyline-core python -m assemblyline_core.workflow.run_workflow 212 | ``` 213 | 214 | ## Documentation 215 | 216 | Pour plus d'informations sur ces composants Assemblyline, suivez ce [overview](https://cybercentrecanada.github.io/assemblyline4_docs/overview/architecture/) de l'architecture du système. 
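Each of these components can also be started outside of Docker during development. The entry points in this repository follow one pattern: a `ServerBase` subclass is used as a context manager and `serve_forever()` is called on it. Below is a minimal sketch mirroring the `run_alerter.py` entry point shown later in this repository; it assumes a valid Assemblyline configuration and reachable Redis/Elasticsearch services, which the components connect to at startup.

```python
# Minimal sketch: launching a core component directly for local development.
# Assumes a working Assemblyline configuration plus the Redis and Elasticsearch
# services that ServerBase subclasses such as Alerter connect to on startup.
from assemblyline_core.alerter.run_alerter import Alerter

if __name__ == "__main__":
    # The context manager handles clean start-up and shutdown;
    # serve_forever() runs the component's main loop until it is stopped.
    with Alerter() as alerter:
        alerter.serve_forever()
```

The same pattern applies to the other components (for example `Archiver`, `Dispatcher` and `Ingester`), whose `run_*` and `__main__` modules appear in the sources that follow.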
217 | -------------------------------------------------------------------------------- /assemblyline_core/__init__.py: -------------------------------------------------------------------------------- 1 | PAUSABLE_COMPONENTS = ['ingester', 'dispatcher'] 2 | 3 | def normalize_hashlist_item(tag_type: str, tag_value: str) -> str: 4 | # Normalize tag data pertaining to domains or URIs 5 | if tag_type.endswith('.domain'): 6 | tag_value = tag_value.lower() 7 | elif tag_type.endswith('.uri'): 8 | hostname = tag_value.split('//', 1)[1].split('/', 1)[0] 9 | tag_value = tag_value.replace(hostname, hostname.lower(), 1) 10 | return tag_value 11 | -------------------------------------------------------------------------------- /assemblyline_core/alerter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/alerter/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/alerter/run_alerter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import elasticapm 4 | 5 | from assemblyline.common import forge 6 | from assemblyline.common.isotime import now 7 | from assemblyline.common.metrics import MetricsFactory 8 | from assemblyline.remote.datatypes import get_client 9 | from assemblyline.remote.datatypes.queues.named import NamedQueue 10 | from assemblyline.odm.messages.alerter_heartbeat import Metrics 11 | 12 | from assemblyline_core.alerter.processing import SubmissionNotFinalized 13 | from assemblyline_core.server_base import ServerBase 14 | 15 | ALERT_QUEUE_NAME = 'm-alert' 16 | ALERT_RETRY_QUEUE_NAME = 'm-alert-retry' 17 | MAX_RETRIES = 10 18 | SUBMISSION_RETRY_SEC = 15 19 | 20 | 21 | class Alerter(ServerBase): 22 | def __init__(self): 23 | super().__init__('assemblyline.alerter') 24 | # Publish counters to the metrics sink. 
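# Note: the alerter consumes the persistent 'm-alert' Redis queue; alerts that fail are pushed
# to the 'm-alert-retry' queue and dropped once 'alert_retries' exceeds MAX_RETRIES, while alerts
# whose submission is not yet finalized are re-queued with a 'wait_until' timestamp set
# SUBMISSION_RETRY_SEC seconds in the future (see run_once below).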
25 | self.counter = MetricsFactory('alerter', Metrics) 26 | self.datastore = forge.get_datastore(self.config) 27 | self.persistent_redis = get_client( 28 | host=self.config.core.redis.persistent.host, 29 | port=self.config.core.redis.persistent.port, 30 | private=False, 31 | ) 32 | self.process_alert_message = forge.get_process_alert_message() 33 | self.running = False 34 | self.next_retry_available = 0 35 | 36 | self.alert_queue: NamedQueue[dict] = NamedQueue(ALERT_QUEUE_NAME, self.persistent_redis) 37 | self.alert_retry_queue: NamedQueue[dict] = NamedQueue(ALERT_RETRY_QUEUE_NAME, self.persistent_redis) 38 | if self.config.core.metrics.apm_server.server_url is not None: 39 | self.log.info(f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}") 40 | elasticapm.instrument() 41 | self.apm_client = forge.get_apm_client("alerter") 42 | else: 43 | self.apm_client = None 44 | 45 | def stop(self): 46 | if self.counter: 47 | self.counter.stop() 48 | 49 | if self.apm_client: 50 | elasticapm.uninstrument() 51 | super().stop() 52 | 53 | def run_once(self): 54 | # Check if there is a due alert in the retry queue 55 | alert = None 56 | if self.next_retry_available < now(): 57 | # Check if the next alert in the queue has a wait_until in the past (<) 58 | alert = self.alert_retry_queue.peek_next() 59 | if alert and alert.get('wait_until', 0) < now(): 60 | # Double check after popping it to be sure we got the alert we expected 61 | # if it is in the future (>) put it back in the queue 62 | alert = self.alert_retry_queue.pop(blocking=False) 63 | if alert and alert.get('wait_until', 0) > now(): 64 | self.alert_retry_queue.push(alert) 65 | alert = None 66 | elif alert: 67 | # If we have peeked an alert and it isn't time for it to be processed 68 | # yet, wait until then before trying to peek the retry queue again 69 | self.next_retry_available = alert.get('wait_until', 0) 70 | alert = None 71 | 72 | # If we haven't gotten an alert from retry queue, pop on the main queue 73 | if not alert: 74 | alert = self.alert_queue.pop(timeout=1) 75 | 76 | # If there is no alert bail out 77 | if not alert: 78 | return 79 | 80 | # Start of process alert transaction 81 | if self.apm_client: 82 | self.apm_client.begin_transaction('Process alert message') 83 | 84 | self.counter.increment('received') 85 | try: 86 | alert_type = self.process_alert_message(self.counter, self.datastore, self.log, alert) 87 | 88 | # End of process alert transaction (success) 89 | if self.apm_client: 90 | self.apm_client.end_transaction(alert_type, 'success') 91 | 92 | return alert_type 93 | except SubmissionNotFinalized as error: 94 | self.counter.increment('wait') 95 | self.log.error(str(error)) 96 | 97 | # Wait a bit for the submission to complete 98 | alert['wait_until'] = now(SUBMISSION_RETRY_SEC) 99 | self.alert_retry_queue.push(alert) 100 | 101 | # End of process alert transaction (wait) 102 | if self.apm_client: 103 | self.apm_client.end_transaction('unknown', 'wait') 104 | 105 | return 'wait' 106 | except Exception: # pylint: disable=W0703 107 | retries = alert['alert_retries'] = alert.get('alert_retries', 0) + 1 108 | self.counter.increment('error') 109 | if retries > MAX_RETRIES: 110 | self.log.exception(f'Max retries exceeded for: {alert}') 111 | else: 112 | self.alert_retry_queue.push(alert) 113 | self.log.exception(f'Unhandled exception processing: {alert}') 114 | 115 | # End of process alert transaction (failure) 116 | if self.apm_client: 117 | self.apm_client.end_transaction('unknown', 'exception') 118 
| 119 | return 'exception' 120 | 121 | def try_run(self): 122 | while self.running: 123 | self.heartbeat() 124 | self.run_once() 125 | 126 | 127 | if __name__ == "__main__": 128 | with Alerter() as alerter: 129 | alerter.serve_forever() 130 | -------------------------------------------------------------------------------- /assemblyline_core/archiver/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/archiver/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/archiver/run_archiver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import elasticapm 3 | import os 4 | import tempfile 5 | 6 | from assemblyline.common import forge 7 | from assemblyline.common.archiving import ARCHIVE_QUEUE_NAME 8 | from assemblyline.common.metrics import MetricsFactory 9 | from assemblyline.datastore.collection import ESCollection, Index 10 | from assemblyline.datastore.exceptions import VersionConflictException 11 | from assemblyline.odm.messages.archive_heartbeat import Metrics 12 | from assemblyline.remote.datatypes import get_client 13 | from assemblyline.remote.datatypes.queues.named import NamedQueue 14 | 15 | from assemblyline_core.server_base import ServerBase 16 | 17 | 18 | class SubmissionNotFound(Exception): 19 | pass 20 | 21 | 22 | class Archiver(ServerBase): 23 | def __init__(self): 24 | super().__init__('assemblyline.archiver') 25 | self.apm_client = None 26 | self.counter = None 27 | 28 | if self.config.datastore.archive.enabled: 29 | # Publish counters to the metrics sink. 
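# Note: messages on the archive queue are tuples of three to five elements:
# (archive_type, type_id, delete_after[, metadata[, use_alternate_dtl]]);
# run_once() below unpacks whichever form it receives.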
30 | self.counter = MetricsFactory('archiver', Metrics) 31 | self.datastore = forge.get_datastore(self.config, archive_access=True) 32 | self.filestore = forge.get_filestore(config=self.config) 33 | self.archivestore = forge.get_archivestore(config=self.config) 34 | self.persistent_redis = get_client( 35 | host=self.config.core.redis.persistent.host, 36 | port=self.config.core.redis.persistent.port, 37 | private=False, 38 | ) 39 | 40 | self.archive_queue: NamedQueue[dict] = NamedQueue(ARCHIVE_QUEUE_NAME, self.persistent_redis) 41 | if self.config.core.metrics.apm_server.server_url is not None: 42 | self.log.info(f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}") 43 | elasticapm.instrument() 44 | self.apm_client = forge.get_apm_client("archiver") 45 | else: 46 | self.log.warning("Archive is not enabled in the config, no need to run archiver.") 47 | exit() 48 | 49 | def stop(self): 50 | if self.counter: 51 | self.counter.stop() 52 | 53 | if self.apm_client: 54 | elasticapm.uninstrument() 55 | super().stop() 56 | 57 | def run_once(self): 58 | message = self.archive_queue.pop(timeout=1) 59 | 60 | # If there is no alert bail out 61 | if not message: 62 | return 63 | else: 64 | try: 65 | if len(message) == 3: 66 | archive_type, type_id, delete_after = message 67 | metadata = {} 68 | use_alternate_dtl = False 69 | elif len(message) == 4: 70 | archive_type, type_id, delete_after, metadata = message 71 | use_alternate_dtl = False 72 | else: 73 | archive_type, type_id, delete_after, metadata, use_alternate_dtl = message 74 | 75 | self.counter.increment('received') 76 | except Exception: 77 | self.log.error(f"Invalid message received: {message}") 78 | return 79 | 80 | # Start of process alert transaction 81 | if self.apm_client: 82 | self.apm_client.begin_transaction('Process archive message') 83 | 84 | try: 85 | if archive_type == "submission": 86 | self.counter.increment('submission') 87 | # Load submission 88 | while True: 89 | try: 90 | submission, version = self.datastore.submission.get_if_exists(type_id, version=True) 91 | 92 | # If we have metadata passed in the message, we need to apply it before archiving the submission 93 | if metadata and self.config.submission.metadata.archive: 94 | submission.metadata.update({k: v for k, v in metadata.items() 95 | if k not in submission.metadata}) 96 | self.datastore.submission.save(type_id, submission, version=version) 97 | 98 | break 99 | except VersionConflictException as vce: 100 | self.log.info(f"Retrying saving metadata due to version conflict: {str(vce)}") 101 | 102 | if not submission: 103 | raise SubmissionNotFound(type_id) 104 | 105 | self.datastore.submission.archive(type_id, delete_after=delete_after, 106 | use_alternate_dtl=use_alternate_dtl) 107 | if not delete_after: 108 | self.datastore.submission.update(type_id, [(ESCollection.UPDATE_SET, 'archived', True)], 109 | index_type=Index.HOT) 110 | 111 | # Gather list of files and archives them 112 | files = {(f.sha256, False) for f in submission.files} 113 | files.update(self.datastore.get_file_list_from_keys(submission.results)) 114 | for sha256, supplementary in files: 115 | self.counter.increment('file') 116 | 117 | # Get the tags for this file 118 | tags = self.datastore.get_tag_list_from_keys( 119 | [r for r in submission.results if r.startswith(sha256)]) 120 | attributions = {x['value'] for x in tags if x['type'].startswith('attribution.')} 121 | techniques = {x['type'].rsplit('.', 1)[1] for x in tags if x['type'].startswith('technique.')} 122 | infos 
= {'ioc' for x in tags if x['type'] in self.config.submission.tag_types.ioc} 123 | infos = infos.union({'password' for x in tags if x['type'] == 'info.password'}) 124 | 125 | # Create the archive file 126 | self.datastore.file.archive(sha256, delete_after=delete_after, 127 | allow_missing=True, use_alternate_dtl=use_alternate_dtl) 128 | 129 | # Auto-Labelling 130 | operations = [] 131 | 132 | # Create default labels 133 | operations += [(self.datastore.file.UPDATE_APPEND_IF_MISSING, 'labels', x) for x in attributions] 134 | operations += [(self.datastore.file.UPDATE_APPEND_IF_MISSING, 'labels', x) for x in techniques] 135 | operations += [(self.datastore.file.UPDATE_APPEND_IF_MISSING, 'labels', x) for x in infos] 136 | 137 | # Create type specific labels 138 | operations += [ 139 | (self.datastore.file.UPDATE_APPEND_IF_MISSING, 'label_categories.attribution', x) 140 | for x in attributions] 141 | operations += [ 142 | (self.datastore.file.UPDATE_APPEND_IF_MISSING, 'label_categories.technique', x) 143 | for x in techniques] 144 | operations += [ 145 | (self.datastore.file.UPDATE_APPEND_IF_MISSING, 'label_categories.info', x) 146 | for x in infos] 147 | 148 | # Set the is_supplementary property 149 | operations += [(self.datastore.file.UPDATE_SET, 'is_supplementary', supplementary)] 150 | 151 | # Apply auto-created labels 152 | self.datastore.file.update(sha256, operations=operations, index_type=Index.ARCHIVE) 153 | self.datastore.file.update(sha256, operations=operations, index_type=Index.HOT) 154 | 155 | if self.filestore != self.archivestore: 156 | with tempfile.NamedTemporaryFile() as buf: 157 | try: 158 | self.filestore.download(sha256, buf.name) 159 | if os.path.getsize(buf.name): 160 | self.archivestore.upload(buf.name, sha256) 161 | except Exception as e: 162 | self.log.error( 163 | f"Could not copy file {sha256} from the filestore to the archivestore. ({e})") 164 | 165 | # Archive associated results (Skip emptys) 166 | for r in submission.results: 167 | if not r.endswith(".e"): 168 | self.counter.increment('result') 169 | self.datastore.result.archive(r, delete_after=delete_after, 170 | allow_missing=True, use_alternate_dtl=use_alternate_dtl) 171 | 172 | # End of process alert transaction (success) 173 | self.log.info(f"Successfully archived submission '{type_id}'.") 174 | if self.apm_client: 175 | self.apm_client.end_transaction(archive_type, 'success') 176 | 177 | # Invalid archiving type 178 | else: 179 | self.counter.increment('invalid') 180 | self.log.warning(f"'{archive_type}' is not a valid archive type.") 181 | # End of process alert transaction (success) 182 | if self.apm_client: 183 | self.apm_client.end_transaction(archive_type, 'invalid') 184 | 185 | except SubmissionNotFound: 186 | self.counter.increment('not_found') 187 | self.log.warning(f"Could not archive {archive_type} '{type_id}'. 
It was not found in the system.") 188 | # End of process alert transaction (failure) 189 | if self.apm_client: 190 | self.apm_client.end_transaction(archive_type, 'not_found') 191 | 192 | except Exception: # pylint: disable=W0703 193 | self.counter.increment('exception') 194 | self.log.exception(f'Unhandled exception processing {archive_type} ID: {type_id}') 195 | 196 | # End of process alert transaction (failure) 197 | if self.apm_client: 198 | self.apm_client.end_transaction(archive_type, 'exception') 199 | 200 | def try_run(self): 201 | while self.running: 202 | self.heartbeat() 203 | self.run_once() 204 | 205 | 206 | if __name__ == "__main__": 207 | with Archiver() as archiver: 208 | archiver.serve_forever() 209 | -------------------------------------------------------------------------------- /assemblyline_core/badlist_client.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import logging 3 | 4 | from assemblyline.common import forge 5 | from assemblyline.common.chunk import chunk 6 | from assemblyline.common.isotime import now_as_iso 7 | from assemblyline.datastore.helper import AssemblylineDatastore 8 | from assemblyline.odm.models.user import ROLES 9 | from assemblyline.remote.datatypes.lock import Lock 10 | 11 | from assemblyline_core import normalize_hashlist_item 12 | 13 | CHUNK_SIZE = 1000 14 | CLASSIFICATION = forge.get_classification() 15 | 16 | 17 | class InvalidBadhash(Exception): 18 | pass 19 | 20 | 21 | # Badlist class 22 | class BadlistClient: 23 | """A helper class to simplify badlisting for privileged services and service-server.""" 24 | 25 | def __init__(self, datastore: AssemblylineDatastore = None, config=None): 26 | self.log = logging.getLogger('assemblyline.badlist_client') 27 | self.config = config or forge.CachedObject(forge.get_config) 28 | self.datastore = datastore or forge.get_datastore(self.config) 29 | 30 | def _preprocess_object(self, data: dict) -> str: 31 | # Remove any null classification that might already be set 32 | if 'classification' in data and not data.get('classification'): 33 | data.pop('classification') 34 | 35 | # Set defaults 36 | data.setdefault('classification', CLASSIFICATION.UNRESTRICTED) 37 | data.setdefault('hashes', {}) 38 | data.setdefault('expiry_ts', None) 39 | if data['type'] == 'tag': 40 | # Remove file related fields 41 | data.pop('file', None) 42 | data.pop('hashes', None) 43 | 44 | tag_data = data.get('tag', None) 45 | if tag_data is None or 'type' not in tag_data or 'value' not in tag_data: 46 | raise ValueError("Tag data not found") 47 | 48 | # Normalize tag data before further processing 49 | tag_data['value'] = normalize_hashlist_item(tag_data['type'], tag_data['value']) 50 | 51 | hashed_value = f"{tag_data['type']}: {tag_data['value']}".encode('utf8') 52 | data['hashes'] = { 53 | 'sha256': hashlib.sha256(hashed_value).hexdigest() 54 | } 55 | 56 | elif data['type'] == 'file': 57 | data.pop('tag', None) 58 | data.setdefault('file', {}) 59 | 60 | # Ensure expiry_ts is set on tag-related items 61 | dtl = data.pop('dtl', None) or self.config.core.expiry.badlisted_tag_dtl 62 | if dtl: 63 | data['expiry_ts'] = now_as_iso(dtl * 24 * 3600) 64 | 65 | # Set last updated 66 | data['added'] = data['updated'] = now_as_iso() 67 | 68 | # Find the best hash to use for the key 69 | for hash_key in ['sha256', 'sha1', 'md5', 'tlsh', 'ssdeep']: 70 | qhash = data['hashes'].get(hash_key, None) 71 | if qhash: 72 | break 73 | 74 | # Validate hash length 75 | if not qhash: 76 | raise 
ValueError("No valid hash found") 77 | 78 | return qhash 79 | 80 | def add_update(self, badlist_object: dict, user: dict = None): 81 | qhash = self._preprocess_object(badlist_object) 82 | 83 | # Validate sources 84 | src_map = {} 85 | for src in badlist_object['sources']: 86 | if user: 87 | if src['type'] == 'user': 88 | if src['name'] != user['uname']: 89 | raise ValueError(f"You cannot add a source for another user. {src['name']} != {user['uname']}") 90 | else: 91 | if ROLES.signature_import not in user['roles']: 92 | raise PermissionError("You do not have sufficient priviledges to add an external source.") 93 | 94 | # Find the highest classification of all sources 95 | badlist_object['classification'] = CLASSIFICATION.max_classification( 96 | badlist_object['classification'], src.get('classification', None)) 97 | 98 | src_map[src['name']] = src 99 | 100 | with Lock(f'add_or_update-badlist-{qhash}', 30): 101 | old = self.datastore.badlist.get_if_exists(qhash, as_obj=False) 102 | if old: 103 | # Save data to the DB 104 | self.datastore.badlist.save(qhash, BadlistClient._merge_hashes(badlist_object, old)) 105 | return qhash, "update" 106 | else: 107 | try: 108 | badlist_object['sources'] = list(src_map.values()) 109 | self.datastore.badlist.save(qhash, badlist_object) 110 | return qhash, "add" 111 | except Exception as e: 112 | return ValueError(f"Invalid data provided: {str(e)}") 113 | 114 | def add_update_many(self, list_of_badlist_objects: list): 115 | if not isinstance(list_of_badlist_objects, list): 116 | raise ValueError("Could not get the list of hashes") 117 | 118 | new_data = {} 119 | for badlist_object in list_of_badlist_objects: 120 | qhash = self._preprocess_object(badlist_object) 121 | new_data[qhash] = badlist_object 122 | 123 | # Get already existing hashes 124 | old_data = self.datastore.badlist.multiget(list(new_data.keys()), as_dictionary=True, as_obj=False, 125 | error_on_missing=False) 126 | 127 | # Test signature names 128 | plan = self.datastore.badlist.get_bulk_plan() 129 | for key, val in new_data.items(): 130 | # Use maximum classification 131 | old_val = old_data.get(key, {'classification': CLASSIFICATION.UNRESTRICTED, 'attribution': {}, 132 | 'hashes': {}, 'sources': [], 'type': val['type']}) 133 | 134 | # Add upsert operation 135 | plan.add_upsert_operation(key, BadlistClient._merge_hashes(val, old_val)) 136 | 137 | if not plan.empty: 138 | # Execute plan 139 | res = self.datastore.badlist.bulk(plan) 140 | return {"success": len(res['items']), "errors": res['errors']} 141 | 142 | return {"success": 0, "errors": []} 143 | 144 | def exists(self, qhash): 145 | return self.datastore.badlist.get_if_exists(qhash, as_obj=False) 146 | 147 | def exists_tags(self, tag_map): 148 | lookup_keys = [] 149 | for tag_type, tag_values in tag_map.items(): 150 | for tag_value in tag_values: 151 | lookup_keys.append(hashlib.sha256(f"{tag_type}: {normalize_hashlist_item(tag_type, tag_value)}".encode('utf8')).hexdigest()) 152 | 153 | # Elasticsearch's result window can't be more than 10000 rows 154 | # we will query for matches in chunks 155 | results = [] 156 | for key_chunk in chunk(lookup_keys, CHUNK_SIZE): 157 | results += self.datastore.badlist.search("*", fl="*", rows=CHUNK_SIZE, 158 | as_obj=False, key_space=key_chunk)['items'] 159 | 160 | return results 161 | 162 | def find_similar_tlsh(self, tlsh): 163 | return self.datastore.badlist.search(f"hashes.tlsh:{tlsh}", fl="*", as_obj=False)['items'] 164 | 165 | def find_similar_ssdeep(self, ssdeep): 166 | try: 167 | _, long, _ = 
ssdeep.replace('/', '\\/').split(":") 168 | return self.datastore.badlist.search(f"hashes.ssdeep:{long}~", fl="*", as_obj=False)['items'] 169 | except ValueError: 170 | self.log.warning(f'This is not a valid SSDeep hash: {ssdeep}') 171 | return [] 172 | 173 | @staticmethod 174 | def _merge_hashes(new, old): 175 | # Account for the possibility of merging with null types 176 | if not (new or old): 177 | # Both are null 178 | raise ValueError("New and old are both null") 179 | elif not (new and old): 180 | # Only one is null, in which case return the other 181 | return new or old 182 | 183 | try: 184 | # Check if hash types match 185 | if new['type'] != old['type']: 186 | raise InvalidBadhash(f"Bad hash type mismatch: {new['type']} != {old['type']}") 187 | 188 | # Use the new classification but we will recompute it later anyway 189 | old['classification'] = new['classification'] 190 | 191 | # Update updated time 192 | old['updated'] = new.get('updated', now_as_iso()) 193 | 194 | # Update hashes 195 | old['hashes'].update({k: v for k, v in new['hashes'].items() if v}) 196 | 197 | # Merge attributions 198 | if not old['attribution']: 199 | old['attribution'] = new.get('attribution', None) 200 | elif new.get('attribution', None): 201 | for key in ["actor", 'campaign', 'category', 'exploit', 'implant', 'family', 'network']: 202 | old_value = old['attribution'].get(key, []) or [] 203 | new_value = new['attribution'].get(key, []) or [] 204 | old['attribution'][key] = list(set(old_value + new_value)) or None 205 | 206 | if old['attribution'] is not None: 207 | old['attribution'] = {key: value for key, value in old['attribution'].items() if value} 208 | 209 | # Update type specific info 210 | if old['type'] == 'file': 211 | old.setdefault('file', {}) 212 | new_names = new.get('file', {}).pop('name', []) 213 | if 'name' in old['file']: 214 | for name in new_names: 215 | if name not in old['file']['name']: 216 | old['file']['name'].append(name) 217 | elif new_names: 218 | old['file']['name'] = new_names 219 | old['file'].update({k: v for k, v in new.get('file', {}).items() if v}) 220 | elif old['type'] == 'tag': 221 | old['tag'] = new['tag'] 222 | 223 | # Merge sources 224 | src_map = {x['name']: x for x in new['sources']} 225 | if not src_map: 226 | raise InvalidBadhash("No valid source found") 227 | 228 | old_src_map = {x['name']: x for x in old['sources']} 229 | for name, src in src_map.items(): 230 | if name not in old_src_map: 231 | old_src_map[name] = src 232 | else: 233 | old_src = old_src_map[name] 234 | if old_src['type'] != src['type']: 235 | raise InvalidBadhash(f"Source {name} has a type conflict: {old_src['type']} != {src['type']}") 236 | 237 | for reason in src['reason']: 238 | if reason not in old_src['reason']: 239 | old_src['reason'].append(reason) 240 | old_src['classification'] = src.get('classification', old_src['classification']) 241 | old['sources'] = list(old_src_map.values()) 242 | 243 | # Calculate the new classification 244 | for src in old['sources']: 245 | old['classification'] = CLASSIFICATION.max_classification( 246 | old['classification'], src.get('classification', None)) 247 | 248 | # Set the expiry 249 | old['expiry_ts'] = new.get('expiry_ts', None) 250 | return old 251 | except Exception as e: 252 | raise InvalidBadhash(f"Invalid data provided: {str(e)}") 253 | -------------------------------------------------------------------------------- /assemblyline_core/dispatching/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/dispatching/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/dispatching/__main__.py: -------------------------------------------------------------------------------- 1 | from assemblyline_core.dispatching.dispatcher import Dispatcher 2 | 3 | 4 | with Dispatcher() as server: 5 | server.serve_forever() 6 | -------------------------------------------------------------------------------- /assemblyline_core/dispatching/schedules.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Dict, Optional, Set, cast 3 | 4 | import logging 5 | import os 6 | import re 7 | 8 | from assemblyline.common.forge import CachedObject, get_classification 9 | from assemblyline.datastore.helper import AssemblylineDatastore 10 | from assemblyline.odm.models.config import Config 11 | from assemblyline.odm.models.service import Service 12 | from assemblyline.odm.models.submission import Submission 13 | from assemblyline_core.server_base import get_service_stage_hash, ServiceStage 14 | 15 | 16 | # If you are doing development and you want the system to route jobs ignoring the service setup/teardown 17 | # set an environment variable SKIP_SERVICE_SETUP to true for all dispatcher containers 18 | SKIP_SERVICE_SETUP = os.environ.get('SKIP_SERVICE_SETUP', 'false').lower() in ['true', '1'] 19 | 20 | Classification = get_classification() 21 | 22 | 23 | class Scheduler: 24 | """This object encapsulates building the schedule for a given file type for a submission.""" 25 | 26 | def __init__(self, datastore: AssemblylineDatastore, config: Config, redis): 27 | self.datastore = datastore 28 | self.config = config 29 | self._services: Dict[str, Service] = {} 30 | self.services = cast(Dict[str, Service], CachedObject(self._get_services)) 31 | self.service_stage = get_service_stage_hash(redis) 32 | self.c12n_services: Dict[str, Set[str]] = {} 33 | 34 | def build_schedule(self, submission: Submission, file_type: str, file_depth: int = 0, 35 | runtime_excluded: Optional[list[str]] = None, 36 | submitter_c12n: Optional[str] = Classification.UNRESTRICTED) -> list[dict[str, Service]]: 37 | # Get the set of all services currently enabled on the system 38 | all_services = dict(self.services) 39 | 40 | # Retrieve a list of services that the classfication group is allowed to submit to 41 | if submitter_c12n is None: 42 | accessible = set(all_services.keys()) 43 | else: 44 | accessible = self.get_accessible_services(submitter_c12n) 45 | 46 | # Load the selected and excluded services by category 47 | excluded = self.expand_categories(submission.params.services.excluded) 48 | runtime_excluded = self.expand_categories(runtime_excluded or []) 49 | if not submission.params.services.selected: 50 | selected = [s for s in all_services.keys()] 51 | else: 52 | selected = self.expand_categories(submission.params.services.selected) 53 | 54 | if submission.params.services.rescan: 55 | selected.extend(self.expand_categories(submission.params.services.rescan)) 56 | 57 | # If we enable service safelisting, the Safelist service shouldn't run on extracted files unless: 58 | # - We're enforcing use of the Safelist service (we always want to run the Safelist service) 59 | # - We're running submission with Deep Scanning 60 | # - We want to Ignore Filtering (perform as much unfiltered 
analysis as possible) 61 | if "Safelist" in selected and file_depth and self.config.services.safelist.enabled and \ 62 | not self.config.services.safelist.enforce_safelist_service \ 63 | and not (submission.params.deep_scan or submission.params.ignore_filtering): 64 | # Alter schedule to remove Safelist, if scheduled to run 65 | selected.remove("Safelist") 66 | 67 | # Add all selected, accepted, and not rejected services to the schedule 68 | schedule: list[dict[str, Service]] = [{} for _ in self.config.services.stages] 69 | services = list(set(selected).intersection(accessible) - set(excluded) - set(runtime_excluded)) 70 | selected = [] 71 | skipped = [] 72 | for name in services: 73 | service = all_services.get(name, None) 74 | 75 | if not service: 76 | skipped.append(name) 77 | logging.warning(f"Service configuration not found: {name}") 78 | continue 79 | 80 | accepted = not service.accepts or re.match(service.accepts, file_type) 81 | rejected = bool(service.rejects) and re.match(service.rejects, file_type) 82 | 83 | if accepted and not rejected: 84 | schedule[self.stage_index(service.stage)][name] = service 85 | selected.append(name) 86 | else: 87 | skipped.append(name) 88 | 89 | return schedule 90 | 91 | def expand_categories(self, services: list[str]) -> list[str]: 92 | """Expands the names of service categories found in the list of services. 93 | 94 | Args: 95 | services (list): List of service category or service names. 96 | """ 97 | if services is None: 98 | return [] 99 | 100 | services = list(services) 101 | categories = self.categories() 102 | 103 | found_services = [] 104 | seen_categories: set[str] = set() 105 | while services: 106 | name = services.pop() 107 | 108 | # If we found a new category mix in it's content 109 | if name in categories: 110 | if name not in seen_categories: 111 | # Add all of the items in this group to the list of 112 | # things that we need to evaluate, and mark this 113 | # group as having been seen. 
114 | services.extend(categories[name]) 115 | seen_categories.add(name) 116 | continue 117 | 118 | # If it isn't a category, its a service 119 | found_services.append(name) 120 | 121 | # Use set to remove duplicates, set is more efficient in batches 122 | return list(set(found_services)) 123 | 124 | def categories(self) -> Dict[str, list[str]]: 125 | all_categories: dict[str, list[str]] = {} 126 | for service in self.services.values(): 127 | try: 128 | all_categories[service.category].append(service.name) 129 | except KeyError: 130 | all_categories[service.category] = [service.name] 131 | return all_categories 132 | 133 | def get_accessible_services(self, user_c12n: str) -> Set[str]: 134 | if not self.c12n_services.get(user_c12n): 135 | # Cache services that are accessible to a classification group 136 | self.c12n_services[user_c12n] = {_ for _, service in dict(self.services).items() 137 | if Classification.is_accessible(user_c12n, service.classification)} 138 | 139 | return self.c12n_services[user_c12n] 140 | 141 | def stage_index(self, stage): 142 | return self.config.services.stages.index(stage) 143 | 144 | def _get_services(self): 145 | old, self._services = self._services, {} 146 | stages = self.service_stage.items() 147 | services: list[Service] = self.datastore.list_all_services(full=True) 148 | for service in services: 149 | if service.enabled: 150 | # Determine if this is a service we would wait for the first update run for 151 | # Assume it is set to running so that in the case of a redis failure we fail 152 | # on the side of waiting for the update and processing more, rather than skipping 153 | wait_for = service.update_config and (service.update_config.wait_for_update and not SKIP_SERVICE_SETUP) 154 | # This is a service that we wait for, and is new, so check if it has finished its update setup 155 | if wait_for and service.name not in old: 156 | if stages.get(service.name, ServiceStage.Running) == ServiceStage.Running: 157 | self._services[service.name] = service 158 | else: 159 | self._services[service.name] = service 160 | return self._services 161 | -------------------------------------------------------------------------------- /assemblyline_core/dispatching/timeout.py: -------------------------------------------------------------------------------- 1 | """ 2 | A data structure encapsulating the timeout logic for the dispatcher. 3 | """ 4 | from __future__ import annotations 5 | import queue 6 | import time 7 | from queue import PriorityQueue 8 | from dataclasses import dataclass, field 9 | from typing import TypeVar, Generic, Hashable 10 | 11 | KeyType = TypeVar('KeyType', bound=Hashable) 12 | DataType = TypeVar('DataType') 13 | 14 | 15 | @dataclass(order=True) 16 | class TimeoutItem(Generic[KeyType, DataType]): 17 | expiry: float 18 | key: KeyType = field(compare=False) 19 | data: DataType = field(compare=False) 20 | 21 | 22 | class TimeoutTable(Generic[KeyType, DataType]): 23 | def __init__(self): 24 | self.timeout_queue: PriorityQueue[TimeoutItem] = PriorityQueue() 25 | self.event_data: dict[KeyType, TimeoutItem] = {} 26 | 27 | def set(self, key: KeyType, timeout: float, data: DataType): 28 | # If a timeout is set repeatedly with the same key, only the last one will count 29 | # even though we aren't removing the old ones from the queue. 
When the items are 30 | # popped from the queue they 31 | entry = TimeoutItem(time.time() + timeout, key, data) 32 | self.event_data[key] = entry 33 | self.timeout_queue.put(entry) 34 | 35 | def clear(self, key: KeyType): 36 | self.event_data.pop(key, None) 37 | 38 | def __contains__(self, item): 39 | return item in self.event_data 40 | 41 | def timeouts(self) -> dict[KeyType, DataType]: 42 | found = {} 43 | try: 44 | now = time.time() 45 | 46 | # Loop until we hit an entry that is active, and non expired 47 | current: TimeoutItem = self.timeout_queue.get_nowait() 48 | while current.expiry <= now or self.event_data.get(current.key, None) != current: 49 | if self.event_data.get(current.key, None) == current: 50 | self.event_data.pop(current.key) 51 | found[current.key] = current.data 52 | current = self.timeout_queue.get_nowait() 53 | 54 | # If we exit the loop, the last item was valid still, put it back 55 | self.timeout_queue.put(current) 56 | 57 | except queue.Empty: 58 | pass 59 | return found 60 | -------------------------------------------------------------------------------- /assemblyline_core/expiry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/expiry/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/ingester/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import INGEST_QUEUE_NAME, drop_chance 2 | -------------------------------------------------------------------------------- /assemblyline_core/ingester/__main__.py: -------------------------------------------------------------------------------- 1 | from assemblyline_core.ingester.ingester import Ingester 2 | 3 | 4 | with Ingester() as server: 5 | server.serve_forever() 6 | -------------------------------------------------------------------------------- /assemblyline_core/ingester/constants.py: -------------------------------------------------------------------------------- 1 | from math import tanh 2 | 3 | COMPLETE_QUEUE_NAME = 'm-complete' 4 | INGEST_QUEUE_NAME = 'm-ingest' 5 | 6 | 7 | def drop_chance(length, maximum): 8 | return max(0, tanh(float(length - maximum) / maximum * 2.0)) 9 | -------------------------------------------------------------------------------- /assemblyline_core/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/metrics/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/metrics/helper.py: -------------------------------------------------------------------------------- 1 | import time 2 | import elasticsearch 3 | 4 | from os import environ 5 | 6 | from assemblyline.datastore.exceptions import ILMException 7 | 8 | MAX_RETRY_BACKOFF = 10 9 | 10 | 11 | def ilm_policy_exists(es, name): 12 | try: 13 | es.ilm.get_lifecycle(name=name) 14 | return True 15 | except elasticsearch.NotFoundError: 16 | return False 17 | 18 | 19 | def create_ilm_policy(es, name, archive_config): 20 | data_base = { 21 | "phases": { 22 | "hot": { 23 | "min_age": "0ms", 24 | "actions": { 25 | "set_priority": { 26 | "priority": 100 27 | }, 28 | "rollover": { 29 | "max_age": 
f"{archive_config['warm']}{archive_config['unit']}" 30 | } 31 | } 32 | }, 33 | "warm": { 34 | "actions": { 35 | "set_priority": { 36 | "priority": 50 37 | } 38 | } 39 | }, 40 | "cold": { 41 | "min_age": f"{archive_config['cold']}{archive_config['unit']}", 42 | "actions": { 43 | "set_priority": { 44 | "priority": 20 45 | } 46 | } 47 | } 48 | } 49 | } 50 | 51 | if archive_config['delete']: 52 | data_base['phases']['delete'] = { 53 | "min_age": f"{archive_config['delete']}{archive_config['unit']}", 54 | "actions": { 55 | "delete": {} 56 | } 57 | } 58 | 59 | try: 60 | es.ilm.put_lifecycle(name=name, policy=data_base) 61 | except elasticsearch.ApiError: 62 | raise ILMException(f"ERROR: Failed to create ILM policy: {name}") 63 | 64 | 65 | def ensure_indexes(log, es, config, indexes, datastream_enabled=False): 66 | for index_type in indexes: 67 | try: 68 | # Get the shard/replica configuration for metric-based indices from the environment 69 | # Otherwise default to values for a single-node Elastic cluster 70 | replicas = environ.get(f"ELASTIC_{index_type.upper()}_METRICS_REPLICAS", environ.get('ELASTIC_DEFAULT_METRICS_REPLICAS', 0)) 71 | shards = environ.get(f"ELASTIC_{index_type.upper()}_METRICS_SHARDS", environ.get('ELASTIC_DEFAULT_METRICS_SHARDS', 1)) 72 | 73 | index = f"al_metrics_{index_type}_ds" if datastream_enabled else f"al_metrics_{index_type}" 74 | policy = f"{index}_policy" 75 | while True: 76 | try: 77 | while not ilm_policy_exists(es, policy): 78 | log.debug(f"ILM Policy {policy.upper()} does not exists. Creating it now...") 79 | create_ilm_policy(es, policy, config.as_primitives()) 80 | break 81 | except ILMException as e: 82 | log.warning(str(e)) 83 | time.sleep(1) 84 | pass 85 | 86 | if not with_retries(log, es.indices.exists_template, name=index): 87 | log.debug(f"Index template {index.upper()} does not exists. 
Creating it now...") 88 | 89 | template_body = { 90 | "settings": { 91 | "index.lifecycle.name": policy, 92 | "index.codec": "best_compression", 93 | "index.number_of_replicas": replicas, 94 | "index.number_of_shards": shards 95 | } 96 | } 97 | put_template_func = None 98 | # Check if datastream is enabled 99 | if datastream_enabled: 100 | put_template_func = es.indices.put_index_template 101 | component_name = f"{index}-settings" 102 | component_body = {"template": template_body} 103 | if not es.cluster.exists_component_template(name=component_name): 104 | try: 105 | # Create component template 106 | with_retries(log, es.cluster.put_component_template, 107 | name=component_name, body=component_body) 108 | except elasticsearch.exceptions.RequestError as e: 109 | if "resource_already_exists_exception" not in str(e): 110 | raise 111 | log.warning(f"Tried to create a component template that already exists: {index.upper()}") 112 | template_body = { 113 | "index_patterns": f"{index}*", 114 | "composed_of": [component_name], 115 | "data_stream": {}, 116 | "priority": 10 117 | } 118 | 119 | # Legacy template 120 | else: 121 | put_template_func = es.indices.put_template 122 | template_body["order"] = 1 123 | template_body["index_patterns"] = [f"{index}-*"] 124 | template_body["settings"]["index.lifecycle.rollover_alias"] = index 125 | 126 | try: 127 | with_retries(log, put_template_func, name=index, body=template_body) 128 | except elasticsearch.exceptions.RequestError as e: 129 | if "resource_already_exists_exception" not in str(e): 130 | raise 131 | log.warning(f"Tried to create an index template that already exists: {index.upper()}") 132 | 133 | if not with_retries(log, es.indices.exists_alias, name=index) and not datastream_enabled: 134 | log.debug(f"Index alias {index.upper()} does not exists. 
Creating it now...") 135 | 136 | index_body = {"aliases": {index: {"is_write_index": True}}} 137 | 138 | while True: 139 | try: 140 | with_retries(log, es.indices.create, index=f"{index}-000001", body=index_body) 141 | break 142 | except elasticsearch.exceptions.RequestError as e: 143 | if "resource_already_exists_exception" in str(e): 144 | log.warning(f"Tried to create an index template that " 145 | f"already exists: {index.upper()}-000001") 146 | break 147 | elif "invalid_alias_name_exception" in str(e): 148 | with_retries(log, es.indices.delete, index=index) 149 | log.warning(str(e)) 150 | time.sleep(1) 151 | else: 152 | raise 153 | 154 | except Exception as e: 155 | log.exception(e) 156 | 157 | 158 | def with_retries(log, func, *args, **kwargs): 159 | retries = 0 160 | updated = 0 161 | deleted = 0 162 | while True: 163 | try: 164 | ret_val = func(*args, **kwargs) 165 | 166 | if retries: 167 | log.info('Reconnected to elasticsearch!') 168 | 169 | if updated: 170 | ret_val['updated'] += updated 171 | 172 | if deleted: 173 | ret_val['deleted'] += deleted 174 | 175 | return ret_val 176 | 177 | except elasticsearch.exceptions.NotFoundError: 178 | raise 179 | 180 | except elasticsearch.exceptions.ConflictError as ce: 181 | updated += ce.info.get('updated', 0) 182 | deleted += ce.info.get('deleted', 0) 183 | 184 | time.sleep(min(retries, MAX_RETRY_BACKOFF)) 185 | retries += 1 186 | 187 | except (elasticsearch.exceptions.ConnectionError, 188 | elasticsearch.exceptions.ConnectionTimeout, 189 | elasticsearch.exceptions.AuthenticationException): 190 | log.warning("No connection to Elasticsearch, retrying...") 191 | time.sleep(min(retries, MAX_RETRY_BACKOFF)) 192 | retries += 1 193 | 194 | except elasticsearch.exceptions.TransportError as e: 195 | err_code, msg, cause = e.args 196 | if err_code == 503 or err_code == '503': 197 | log.warning("Looks like index is not ready yet, retrying...") 198 | time.sleep(min(retries, MAX_RETRY_BACKOFF)) 199 | retries += 1 200 | elif err_code == 429 or err_code == '429': 201 | log.warning("Elasticsearch is too busy to perform the requested task, " 202 | "we will wait a bit and retry...") 203 | time.sleep(min(retries, MAX_RETRY_BACKOFF)) 204 | retries += 1 205 | 206 | else: 207 | raise 208 | -------------------------------------------------------------------------------- /assemblyline_core/metrics/run_heartbeat_manager.py: -------------------------------------------------------------------------------- 1 | 2 | from assemblyline_core.metrics.metrics_server import HeartbeatManager 3 | from assemblyline.common import forge 4 | 5 | if __name__ == '__main__': 6 | config = forge.get_config() 7 | with HeartbeatManager(config=config) as metricsd: 8 | metricsd.serve_forever() 9 | -------------------------------------------------------------------------------- /assemblyline_core/metrics/run_metrics_aggregator.py: -------------------------------------------------------------------------------- 1 | 2 | from assemblyline_core.metrics.metrics_server import MetricsServer 3 | from assemblyline.common import forge 4 | 5 | if __name__ == '__main__': 6 | config = forge.get_config() 7 | with MetricsServer(config=config) as metricsd: 8 | metricsd.serve_forever() 9 | -------------------------------------------------------------------------------- /assemblyline_core/metrics/run_statistics_aggregator.py: -------------------------------------------------------------------------------- 1 | 2 | from assemblyline_core.metrics.metrics_server import StatisticsAggregator 3 | from 
assemblyline.common import forge 4 | 5 | if __name__ == '__main__': 6 | config = forge.get_config() 7 | with StatisticsAggregator(config=config) as sigAggregator: 8 | sigAggregator.serve_forever() 9 | -------------------------------------------------------------------------------- /assemblyline_core/plumber/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/plumber/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/replay/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/replay/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/replay/creator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/replay/creator/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/replay/creator/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from assemblyline_core.replay.client import APIClient, DirectClient 4 | from assemblyline_core.replay.replay import ReplayBase, INPUT_TYPES 5 | 6 | class ReplayCreator(ReplayBase): 7 | def __init__(self): 8 | super().__init__("assemblyline.replay_creator") 9 | 10 | if not self.replay_config.creator.alert_input.enabled and \ 11 | not self.replay_config.creator.submission_input.enabled: 12 | return 13 | 14 | # Create cache directory 15 | os.makedirs(self.replay_config.creator.working_directory, exist_ok=True) 16 | 17 | # Load client 18 | client_config = {f'{input_type}_fqs': getattr(self.replay_config.creator, f'{input_type}_input').filter_queries 19 | for input_type in INPUT_TYPES} 20 | client_config['lookback_time'] = self.replay_config.creator.lookback_time 21 | 22 | if self.replay_config.creator.client.type == 'direct': 23 | self.log.info("Using direct database access client") 24 | self.client = DirectClient(self.log, **client_config) 25 | elif self.replay_config.creator.client.type == 'api': 26 | self.log.info(f"Using API access client to ({self.replay_config.creator.client.options.host})") 27 | client_config.update(self.replay_config.creator.client.options.as_primitives()) 28 | self.client = APIClient(self.log, **client_config) 29 | else: 30 | raise ValueError(f'Invalid client type ({self.replay_config.creator.client.type}). 
' 31 | 'Must be either \'api\' or \'direct\'.') 32 | 33 | def try_run(self): 34 | threads = {} 35 | for input_type in INPUT_TYPES: 36 | if getattr(self.replay_config.creator, f'{input_type}_input').enabled: 37 | threads[f'Load {input_type.capitalize()}s'] = getattr(self.client, f'setup_{input_type}_input_queue') 38 | 39 | if threads: 40 | self.maintain_threads(threads) 41 | else: 42 | self.log.warning("There are no configured input, terminating") 43 | self.main_loop_exit.set() 44 | self.stop() 45 | 46 | 47 | if __name__ == '__main__': 48 | with ReplayCreator() as replay: 49 | replay.serve_forever() 50 | -------------------------------------------------------------------------------- /assemblyline_core/replay/creator/run_worker.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from cart import pack_stream 5 | from io import BytesIO 6 | 7 | from assemblyline.filestore import FileStore 8 | from assemblyline.common.isotime import now_as_iso 9 | from assemblyline_core.replay.client import APIClient, DirectClient 10 | from assemblyline_core.replay.replay import ReplayBase, INPUT_TYPES 11 | 12 | REPLAY_BATCH_SIZE = int(os.environ.get("REPLAY_BATCH_SIZE", "1000")) 13 | 14 | 15 | class ReplayCreatorWorker(ReplayBase): 16 | def __init__(self): 17 | super().__init__("assemblyline.replay_creator.worker") 18 | 19 | if not self.replay_config.creator.alert_input.enabled and \ 20 | not self.replay_config.creator.submission_input.enabled: 21 | return 22 | 23 | # Initialize filestore object 24 | self.filestore = FileStore(self.replay_config.creator.output_filestore) 25 | 26 | # Create cache directory 27 | os.makedirs(self.replay_config.creator.working_directory, exist_ok=True) 28 | 29 | # Load client 30 | client_config = dict(lookback_time=self.replay_config.creator.lookback_time, 31 | alert_fqs=self.replay_config.creator.alert_input.filter_queries, 32 | submission_fqs=self.replay_config.creator.submission_input.filter_queries) 33 | 34 | if self.replay_config.creator.client.type == 'direct': 35 | self.log.info("Using direct database access client") 36 | self.client = DirectClient(self.log, **client_config) 37 | elif self.replay_config.creator.client.type == 'api': 38 | self.log.info(f"Using API access client to ({self.replay_config.creator.client.options.host})") 39 | client_config.update(self.replay_config.creator.client.options.as_primitives()) 40 | self.client = APIClient(self.log, **client_config) 41 | else: 42 | raise ValueError(f'Invalid client type ({self.replay_config.creator.client.type}). 
' 43 | 'Must be either \'api\' or \'direct\'.') 44 | 45 | def process_alert(self, once=False): 46 | while self.running: 47 | # Process alerts found 48 | alert = self.client.get_next_alert() 49 | if alert: 50 | self.log.info(f"Processing alert: {alert['alert_id']}") 51 | 52 | # Make sure directories exists 53 | os.makedirs(self.replay_config.creator.working_directory, exist_ok=True) 54 | 55 | # Create the bundle 56 | bundle_path = os.path.join(self.replay_config.creator.working_directory, 57 | f"alert_{alert['alert_id']}.al_bundle") 58 | self.client.create_alert_bundle(alert['alert_id'], bundle_path) 59 | 60 | # Move the bundle 61 | self.filestore.upload(bundle_path, f"alert_{alert['alert_id']}.al_bundle") 62 | 63 | # Remove temp file 64 | if os.path.exists(bundle_path): 65 | os.unlink(bundle_path) 66 | 67 | # Set alert state done 68 | self.client.set_single_alert_complete(alert['alert_id']) 69 | 70 | if once: 71 | break 72 | 73 | def process_submission(self, once=False): 74 | while self.running: 75 | # Process submissions found 76 | submission = self.client.get_next_submission() 77 | if submission: 78 | self.log.info(f"Processing submission: {submission['sid']}") 79 | 80 | # Make sure directories exists 81 | os.makedirs(self.replay_config.creator.working_directory, exist_ok=True) 82 | 83 | # Create the bundle 84 | bundle_path = os.path.join(self.replay_config.creator.working_directory, 85 | f"submission_{submission['sid']}.al_bundle") 86 | self.client.create_submission_bundle(submission['sid'], bundle_path) 87 | 88 | # Move the bundle 89 | self.filestore.upload(bundle_path, f"submission_{submission['sid']}.al_bundle") 90 | 91 | # Remove temp file 92 | if os.path.exists(bundle_path): 93 | os.unlink(bundle_path) 94 | 95 | # Set submission state done 96 | self.client.set_single_submission_complete(submission['sid']) 97 | 98 | if once: 99 | break 100 | 101 | def _process_json_exports(self, collection, id_field, date_field, once=False): 102 | # Keep track of the last record exported to update checkpoint 103 | last_obj = None 104 | 105 | # Collection of records to be exported to a single JSON file per batch size 106 | batch = [] 107 | 108 | def upload_batch(): 109 | # Make sure directories exists 110 | os.makedirs(self.replay_config.creator.working_directory, exist_ok=True) 111 | 112 | # Create the JSON 113 | json_fn = f"{collection}_{now_as_iso()}.al_json.cart" 114 | json_path = os.path.join(self.replay_config.creator.working_directory, json_fn) 115 | with open(json_path, "wb") as fp: 116 | pack_stream(BytesIO(json.dumps(batch).encode()), fp) 117 | 118 | # Move the JSON 119 | self.filestore.upload(json_path, json_fn) 120 | 121 | # Remove temp file 122 | if os.path.exists(json_path): 123 | os.unlink(json_path) 124 | 125 | while self.running: 126 | # Process items found 127 | obj = getattr(self.client, f"get_next_{collection}")() 128 | if obj: 129 | obj_id = obj[id_field] 130 | self.log.info(f"Processing {collection}: {obj_id}") 131 | batch.append(obj) 132 | 133 | if len(batch) == REPLAY_BATCH_SIZE: 134 | upload_batch() 135 | elif last_obj: 136 | # Update the checkpoint based on the last item processed before nothing 137 | self.client._put_checkpoint(collection, last_obj[date_field]) 138 | 139 | # Check if there's anything that hasn't been exported before the queue went silent 140 | if batch: 141 | upload_batch() 142 | 143 | # Always keep track of the last object processed for later 144 | last_obj = obj 145 | 146 | if once: 147 | upload_batch() 148 | break 149 | 150 | def process_badlist(self, 
once=False): 151 | self._process_json_exports("badlist", "id", "updated", once) 152 | 153 | def process_safelist(self, once=False): 154 | self._process_json_exports("safelist", "id", "updated", once) 155 | 156 | def process_signature(self, once=False): 157 | self._process_json_exports("signature", "id", "last_modified", once) 158 | 159 | def process_workflow(self, once=False): 160 | self._process_json_exports("workflow", "id", "last_edit", once) 161 | 162 | def try_run(self): 163 | threads = {} 164 | for input_type in INPUT_TYPES: 165 | input_config = getattr(self.replay_config.creator, f"{input_type}_input") 166 | if input_config.enabled: 167 | for ii in range(input_config.threads): 168 | threads[f"{input_type.capitalize()} process thread #{ii}"] = getattr(self, f"process_{input_type}") 169 | if threads: 170 | self.maintain_threads(threads) 171 | else: 172 | self.log.warning("There are no configured input, terminating") 173 | self.main_loop_exit.set() 174 | self.stop() 175 | 176 | 177 | if __name__ == '__main__': 178 | with ReplayCreatorWorker() as replay: 179 | replay.serve_forever() 180 | -------------------------------------------------------------------------------- /assemblyline_core/replay/loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/replay/loader/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/replay/loader/run.py: -------------------------------------------------------------------------------- 1 | import shelve 2 | import os 3 | 4 | from datetime import datetime, timedelta 5 | 6 | from assemblyline_core.replay.client import APIClient, DirectClient 7 | from assemblyline_core.replay.replay import ReplayBase 8 | 9 | 10 | class ReplayLoader(ReplayBase): 11 | def __init__(self): 12 | super().__init__("assemblyline.replay_loader") 13 | 14 | # Make sure all directories exist 15 | os.makedirs(self.replay_config.loader.working_directory, exist_ok=True) 16 | os.makedirs(self.replay_config.loader.input_directory, exist_ok=True) 17 | os.makedirs(self.replay_config.loader.failed_directory, exist_ok=True) 18 | 19 | # Create/Load the cache 20 | self.cache = shelve.open(os.path.join(self.replay_config.loader.working_directory, 'loader_cache.db')) 21 | self.last_sync_check = datetime.now() 22 | if 'files' not in self.cache: 23 | self.cache['files'] = set() 24 | 25 | # Load client 26 | if self.replay_config.loader.client.type == 'direct': 27 | self.log.info("Using direct database access client") 28 | self.client = DirectClient(self.log) 29 | elif self.replay_config.loader.client.type == 'api': 30 | self.log.info(f"Using API access client to ({self.replay_config.loader.client.options.host})") 31 | self.client = APIClient(self.log, **self.replay_config.loader.client.options.as_primitives()) 32 | else: 33 | raise ValueError(f'Invalid client type ({self.replay_config.loader.client.type}). 
' 34 | 'Must be either \'api\' or \'direct\'.') 35 | 36 | def load_files(self, once=False): 37 | while self.running: 38 | new_files = False 39 | new_cache = set() 40 | 41 | # Check the datastore periodically for the last Replay bundle that was imported 42 | if datetime.now() > self.last_sync_check + timedelta(seconds=self.replay_config.loader.sync_check_interval): 43 | 44 | if not self.client.query_alerts( 45 | query=f"metadata.bundle.loaded:[now-{self.replay_config.loader.sync_check_interval}s TO now]", 46 | track_total_hits=True): 47 | self.log.warning("Haven't received a new bundle since the last check!") 48 | self.last_sync_check = datetime.now() 49 | 50 | for root, _, files in os.walk(self.replay_config.loader.input_directory, topdown=False): 51 | for name in files: 52 | # Unexpected files that could be the result of external transfer mechanisms 53 | if name.startswith('.') or not (name.endswith('.al_bundle') or \ 54 | name.endswith('.al_json') or \ 55 | name.endswith('.al_json.cart')): 56 | continue 57 | 58 | file_path = os.path.join(root, name) 59 | 60 | # Cache file 61 | new_cache.add(file_path) 62 | 63 | if file_path not in self.cache['files']: 64 | self.log.info(f'Queueing file: {file_path}') 65 | self.client.put_file(file_path) 66 | new_files = True 67 | 68 | # Cache file 69 | self.cache['files'].add(file_path) 70 | 71 | # Cleanup cache 72 | self.cache['files'] = new_cache 73 | 74 | if once: 75 | break 76 | 77 | if not new_files: 78 | self.sleep(5) 79 | 80 | def try_run(self): 81 | threads = { 82 | # Pull in completed submissions 83 | 'File loader': self.load_files 84 | } 85 | 86 | self.maintain_threads(threads) 87 | 88 | def stop(self): 89 | self.cache.close() 90 | return super().stop() 91 | 92 | 93 | if __name__ == '__main__': 94 | with ReplayLoader() as replay: 95 | replay.serve_forever() 96 | -------------------------------------------------------------------------------- /assemblyline_core/replay/loader/run_worker.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | 4 | from cart import unpack_file 5 | 6 | from assemblyline_core.replay.client import APIClient, DirectClient 7 | from assemblyline_core.replay.replay import ReplayBase 8 | 9 | 10 | class ReplayLoaderWorker(ReplayBase): 11 | def __init__(self): 12 | super().__init__("assemblyline.replay_loader.worker") 13 | 14 | # Load client 15 | if self.replay_config.loader.client.type == 'direct': 16 | self.log.info("Using direct database access client") 17 | self.client = DirectClient(self.log) 18 | elif self.replay_config.loader.client.type == 'api': 19 | self.log.info(f"Using API access client to ({self.replay_config.loader.client.options.host})") 20 | self.client = APIClient(self.log, **self.replay_config.loader.client.options.as_primitives()) 21 | else: 22 | raise ValueError(f'Invalid client type ({self.replay_config.loader.client.type}). 
' 23 | 'Must be either \'api\' or \'direct\'.') 24 | 25 | def process_file(self, once=False): 26 | while self.running: 27 | file_path = self.client.get_next_file() 28 | 29 | if file_path: 30 | self.log.info(f"Processing file: {file_path}") 31 | try: 32 | if file_path.endswith(".al_bundle"): 33 | self.client.load_bundle(file_path, 34 | min_classification=self.replay_config.loader.min_classification, 35 | rescan_services=self.replay_config.loader.rescan) 36 | elif file_path.endswith(".al_json"): 37 | self.client.load_json(file_path) 38 | 39 | elif file_path.endswith(".al_json.cart"): 40 | cart_path = file_path 41 | file_path = file_path[:-5] 42 | unpack_file(cart_path, file_path) 43 | self.client.load_json(file_path) 44 | os.unlink(cart_path) 45 | 46 | if os.path.exists(file_path): 47 | os.unlink(file_path) 48 | except OSError as e: 49 | # Critical exception occurred 50 | if 'Stale file handle' in str(e): 51 | # Terminate on stale file handle from NFS mount 52 | self.log.warning("Stale file handle detected. Terminating..") 53 | self.stop() 54 | elif 'Invalid cross-device link' in str(e): 55 | # Terminate on NFS-related error 56 | self.log.warning("'Invalid cross-device link' exception detected. Terminating..") 57 | self.stop() 58 | except Exception: 59 | # Make sure failed directory exists 60 | os.makedirs(self.replay_config.loader.failed_directory, exist_ok=True) 61 | 62 | self.log.error(f"Failed to load the bundle file {file_path}, moving it to the failed directory.") 63 | failed_path = os.path.join(self.replay_config.loader.failed_directory, os.path.basename(file_path)) 64 | shutil.move(file_path, failed_path) 65 | 66 | if once: 67 | break 68 | 69 | def try_run(self): 70 | threads = {} 71 | 72 | for ii in range(self.replay_config.loader.input_threads): 73 | threads[f'File processor #{ii}'] = self.process_file 74 | 75 | self.maintain_threads(threads) 76 | 77 | def stop(self): 78 | return super().stop() 79 | 80 | 81 | if __name__ == '__main__': 82 | with ReplayLoaderWorker() as replay: 83 | replay.serve_forever() 84 | -------------------------------------------------------------------------------- /assemblyline_core/replay/replay.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import threading 4 | import yaml 5 | 6 | from typing import Callable 7 | 8 | from assemblyline.common.dict_utils import recursive_update 9 | from assemblyline.common.forge import env_substitute 10 | from assemblyline.odm.models.replay import ReplayConfig 11 | from assemblyline_core.server_base import ServerBase 12 | 13 | CONFIG_PATH = os.environ.get('REPLAY_CONFIG_PATH', '/etc/assemblyline/replay.yml') 14 | INPUT_TYPES = ['alert', 'badlist', 'safelist', 'signature', 'submission', 'workflow'] 15 | 16 | 17 | class ReplayBase(ServerBase): 18 | def __init__(self, component_name): 19 | super().__init__(component_name) 20 | 21 | # Load updated values 22 | if os.path.exists(CONFIG_PATH): 23 | with open(CONFIG_PATH) as yml_fh: 24 | self.replay_config = ReplayConfig(recursive_update(ReplayConfig().as_primitives(), 25 | yaml.safe_load(env_substitute(yml_fh.read())))) 26 | else: 27 | self.replay_config = ReplayConfig() 28 | 29 | # Thread events related to exiting 30 | self.main_loop_exit = threading.Event() 31 | 32 | def stop(self): 33 | super().stop() 34 | self.main_loop_exit.wait(30) 35 | 36 | def sleep(self, timeout: float): 37 | self.stopping.wait(timeout) 38 | return self.running 39 | 40 | def log_crashes(self, fn): 41 | @functools.wraps(fn) 42 | 
def with_logs(*args, **kwargs): 43 | # noinspection PyBroadException 44 | try: 45 | fn(*args, **kwargs) 46 | except Exception: 47 | self.log.exception(f'Crash in dispatcher: {fn.__name__}') 48 | return with_logs 49 | 50 | def maintain_threads(self, expected_threads: dict[str, Callable[..., None]]): 51 | expected_threads = {name: self.log_crashes(start) for name, start in expected_threads.items()} 52 | threads: dict[str, threading.Thread] = {} 53 | 54 | # Run as long as we need to 55 | while self.running: 56 | # Check for any crashed threads 57 | for name, thread in list(threads.items()): 58 | if not thread.is_alive(): 59 | self.log.warning(f'Restarting thread: {name}') 60 | threads.pop(name) 61 | 62 | # Start any missing threads 63 | for name, function in expected_threads.items(): 64 | if name not in threads: 65 | self.log.info(f'Starting thread: {name}') 66 | threads[name] = thread = threading.Thread(target=function, name=name) 67 | thread.start() 68 | 69 | # Take a break before doing it again 70 | super().heartbeat() 71 | self.sleep(2) 72 | 73 | for _t in threads.values(): 74 | _t.join() 75 | 76 | self.main_loop_exit.set() 77 | -------------------------------------------------------------------------------- /assemblyline_core/scaler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/scaler/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/scaler/collection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for collecting metric data and feeding it into the orchestration framework. 3 | 4 | The task of this module is to take data available from metrics across multiple hosts 5 | providing the same service and combine them into general statistics about the service. 6 | """ 7 | import time 8 | from typing import Dict 9 | from collections import namedtuple 10 | 11 | Row = namedtuple('Row', ['timestamp', 'busy', 'throughput']) 12 | 13 | 14 | class Collection: 15 | def __init__(self, period, ttl=None): 16 | """ 17 | A buffer for metrics data from multiple instances of multiple services. 
18 | 19 |         :param period: Expected seconds between updates 20 |         :param ttl: Seconds before a message is dropped from the buffer 21 |         """ 22 |         self.period: float = period 23 |         self.ttl: float = ttl or (period * 1.5) 24 |         self.services: Dict[str, Dict[str, Row]] = {} 25 | 26 |     def update(self, service, host, busy_seconds, throughput): 27 |         # Load the sequence of data points that we have for this service 28 |         try: 29 |             hosts = self.services[service] 30 |         except KeyError: 31 |             hosts = self.services[service] = {} 32 | 33 |         # Add the new data 34 |         hosts[host] = Row(time.time(), busy_seconds, throughput) 35 | 36 |     def read(self, service): 37 |         now = time.time() 38 | 39 |         # Load the last messages from this service 40 |         try: 41 |             hosts = self.services[service] 42 |         except KeyError: 43 |             return None 44 | 45 |         # Flush out stale messages 46 |         expired = [_h for _h, _v in hosts.items() if now - _v.timestamp > self.ttl] 47 |         for host_name in expired: 48 |             hosts.pop(host_name, None) 49 | 50 |         # If flushing got rid of all of our messages our state is 'offline' 51 |         if not hosts: 52 |             return None 53 | 54 |         # 55 |         return { 56 |             'instances': len(hosts), 57 |             'duty_cycle': sum(_v.busy for _v in hosts.values())/(len(hosts) * self.period), 58 |         } 59 | -------------------------------------------------------------------------------- /assemblyline_core/scaler/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | from .docker_ctl import DockerController 2 | from .kubernetes_ctl import KubernetesController 3 | -------------------------------------------------------------------------------- /assemblyline_core/scaler/controllers/interface.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING, Optional 3 | 4 | if TYPE_CHECKING: 5 |     from assemblyline_core.scaler.scaler_server import ServiceProfile 6 | 7 | 8 | class ServiceControlError(RuntimeError): 9 |     def __init__(self, message, service_name): 10 |         super().__init__(message) 11 |         self.service_name = service_name 12 | 13 | 14 | class ControllerInterface: 15 |     def add_profile(self, profile, scale=0): 16 |         """Tell the controller about a service profile it needs to manage.""" 17 |         raise NotImplementedError() 18 | 19 |     def memory_info(self): 20 |         """Return free and total memory in the system.""" 21 |         raise NotImplementedError() 22 | 23 |     def cpu_info(self): 24 |         """Return free and total CPU in the system.""" 25 |         raise NotImplementedError() 26 | 27 |     def free_cpu(self) -> float: 28 |         """Number of cores available for reservation.""" 29 |         return self.cpu_info()[0] 30 | 31 |     def free_memory(self) -> float: 32 |         """Megabytes of RAM that has not been reserved.""" 33 |         return self.memory_info()[0] 34 | 35 |     def get_target(self, service_name): 36 |         """Get the target for running instances of a service.""" 37 |         raise NotImplementedError() 38 | 39 |     def get_targets(self): 40 |         """Get the target for running instances of all services.""" 41 |         raise NotImplementedError() 42 | 43 |     def set_target(self, service_name, target): 44 |         """Set the target for running instances of a service.""" 45 |         raise NotImplementedError() 46 | 47 |     def restart(self, service: ServiceProfile): 48 |         raise NotImplementedError() 49 | 50 |     def get_running_container_names(self): 51 |         raise NotImplementedError() 52 | 53 |     def new_events(self): 54 |         return [] 55 | 56 |     def stateful_container_key(self, service_name: str, container_name: str, spec, change_key: str) -> Optional[str]:
57 | raise NotImplementedError() 58 | 59 | def start_stateful_container(self, service_name: str, container_name: str, spec, labels, change_key): 60 | raise NotImplementedError() 61 | 62 | def stop_containers(self, labels): 63 | raise NotImplementedError() 64 | 65 | def prepare_network(self, service_name, internet, dependency_internet): 66 | raise NotImplementedError() 67 | 68 | def stop(self): 69 | pass 70 | -------------------------------------------------------------------------------- /assemblyline_core/scaler/run_scaler.py: -------------------------------------------------------------------------------- 1 | 2 | from assemblyline_core.scaler.scaler_server import ScalerServer 3 | 4 | 5 | if __name__ == '__main__': 6 | with ScalerServer() as scaler: 7 | scaler.serve_forever() 8 | -------------------------------------------------------------------------------- /assemblyline_core/signature_client.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from assemblyline.common import forge 4 | from assemblyline.common.isotime import iso_to_epoch, now_as_iso 5 | from assemblyline.common.memory_zip import InMemoryZip 6 | from assemblyline.datastore.helper import AssemblylineDatastore 7 | from assemblyline.odm.messages.changes import Operation 8 | from assemblyline.odm.models.service import SIGNATURE_DELIMITERS 9 | from assemblyline.odm.models.signature import DEPLOYED_STATUSES, STALE_STATUSES, DRAFT_STATUSES 10 | 11 | 12 | DEFAULT_DELIMITER = "\n\n" 13 | CLASSIFICATION = forge.get_classification() 14 | 15 | 16 | # Signature class 17 | class SignatureClient: 18 | """A helper class to simplify signature management for privileged services and service-server.""" 19 | 20 | def __init__(self, datastore: AssemblylineDatastore = None, config=None, classification_replace_map={}): 21 | self.log = logging.getLogger('assemblyline.signature_client') 22 | self.config = config or forge.CachedObject(forge.get_config) 23 | self.datastore = datastore or forge.get_datastore(self.config) 24 | self.service_list = forge.CachedObject(self.datastore.list_all_services, kwargs=dict(as_obj=False, full=True)) 25 | self.delimiters = forge.CachedObject(self._get_signature_delimiters) 26 | self.classification_replace_map = classification_replace_map 27 | 28 | def _get_signature_delimiters(self): 29 | signature_delimiters = {} 30 | for service in self.service_list: 31 | if service.get("update_config", {}).get("generates_signatures", False): 32 | signature_delimiters[service['name'].lower()] = self._get_signature_delimiter(service['update_config']) 33 | return signature_delimiters 34 | 35 | def _get_signature_delimiter(self, update_config): 36 | delimiter_type = update_config['signature_delimiter'] 37 | if delimiter_type == 'custom': 38 | delimiter = update_config['custom_delimiter'].encode().decode('unicode-escape') 39 | else: 40 | delimiter = SIGNATURE_DELIMITERS.get(delimiter_type, '\n\n') 41 | return {'type': delimiter_type, 'delimiter': delimiter} 42 | 43 | def _update_classification(self, signature): 44 | classification = signature['classification'] 45 | # Update classification of signatures based on rewrite definition 46 | for term, replacement in self.classification_replace_map.items(): 47 | if replacement.startswith('_'): 48 | # Replace with known field in Signature model 49 | # Otherwise replace with literal 50 | if signature.get(replacement[1:]): 51 | replacement = signature[replacement[1:]] 52 | 53 | classification = classification.replace(term, replacement) 
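        # For example (hypothetical map values): with classification_replace_map set to
        # {'TLP:W': '_source'}, the loop above replaces the literal term 'TLP:W' with the
        # value of the signature's own 'source' field, whereas a replacement that does not
        # start with an underscore (e.g. {'TLP:W': 'TLP:CLEAR'}) is substituted verbatim.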
54 | 55 |         # Save the (possibly) updated classification 56 |         signature['classification'] = classification 57 | 58 | 59 |     def add_update(self, data, dedup_name=True): 60 |         if data.get('type', None) is None or data['name'] is None or data['data'] is None: 61 |             raise ValueError("Signature id, name, type and data are mandatory fields.") 62 | 63 |         # Compute signature ID if missing 64 |         data['signature_id'] = data.get('signature_id', data['name']) 65 | 66 |         key = f"{data['type']}_{data['source']}_{data['signature_id']}" 67 | 68 |         # Test signature name 69 |         if dedup_name: 70 |             check_name_query = f"name:\"{data['name']}\" " \ 71 |                                f"AND type:\"{data['type']}\" " \ 72 |                                f"AND source:\"{data['source']}\" " \ 73 |                                f"AND NOT id:\"{key}\"" 74 |             other = self.datastore.signature.search(check_name_query, fl='id', rows='0') 75 |             if other['total'] > 0: 76 |                 raise ValueError("A signature with that name already exists") 77 | 78 |         old = self.datastore.signature.get(key, as_obj=False) 79 |         op = Operation.Modified if old else Operation.Added 80 |         if old: 81 |             if old['data'] == data['data']: 82 |                 return True, key, None 83 | 84 |             # Ensure that the last state change, if any, was made by a user and not a system account. 85 |             user_modified_last_state = old['state_change_user'] not in ['update_service_account', None] 86 | 87 |             # If rule state is moving to an active state but was disabled by a user before: 88 |             # Keep original inactive state, a user changed the state for a reason 89 |             if user_modified_last_state and data['status'] == 'DEPLOYED' and data['status'] != old['status']: 90 |                 data['status'] = old['status'] 91 | 92 |             # Preserve last state change 93 |             data['state_change_date'] = old['state_change_date'] 94 |             data['state_change_user'] = old['state_change_user'] 95 | 96 |             # Preserve signature stats 97 |             data['stats'] = old['stats'] 98 | 99 |         self._update_classification(data) 100 | 101 |         # Save the signature 102 |         success = self.datastore.signature.save(key, data) 103 |         return success, key, op 104 | 105 |     def add_update_many(self, source, sig_type, data, dedup_name=True): 106 |         if source is None or sig_type is None or not isinstance(data, list): 107 |             raise ValueError("Source, source type and data are mandatory fields.") 108 | 109 |         # Test signature names 110 |         names_map = {x['name']: f"{x['type']}_{x['source']}_{x.get('signature_id', x['name'])}" for x in data} 111 | 112 |         skip_list = [] 113 |         if dedup_name: 114 |             for item in self.datastore.signature.stream_search(f"type: \"{sig_type}\" AND source:\"{source}\"", 115 |                                                                fl="id,name", as_obj=False, item_buffer_size=1000): 116 |                 lookup_id = names_map.get(item['name'], None) 117 |                 if lookup_id and lookup_id != item['id']: 118 |                     skip_list.append(lookup_id) 119 | 120 |         if skip_list: 121 |             data = [ 122 |                 x for x in data 123 |                 if f"{x['type']}_{x['source']}_{x.get('signature_id', x['name'])}" not in skip_list] 124 | 125 |         old_data = self.datastore.signature.multiget(list(names_map.values()), as_dictionary=True, as_obj=False, 126 |                                                      error_on_missing=False) 127 | 128 |         plan = self.datastore.signature.get_bulk_plan() 129 |         for rule in data: 130 |             key = f"{rule['type']}_{rule['source']}_{rule.get('signature_id', rule['name'])}" 131 |             if key in old_data: 132 |                 # Ensure that the last state change, if any, was made by a user and not a system account.
133 | user_modified_last_state = old_data[key]['state_change_user'] not in ['update_service_account', None] 134 | 135 | # If rule state is moving to an active state but was disabled by a user before: 136 | # Keep original inactive state, a user changed the state for a reason 137 | if user_modified_last_state and rule['status'] == 'DEPLOYED' and rule['status'] != old_data[key][ 138 | 'status']: 139 | rule['status'] = old_data[key]['status'] 140 | 141 | # Preserve last state change 142 | rule['state_change_date'] = old_data[key]['state_change_date'] 143 | rule['state_change_user'] = old_data[key]['state_change_user'] 144 | 145 | # Preserve signature stats 146 | rule['stats'] = old_data[key]['stats'] 147 | 148 | self._update_classification(rule) 149 | 150 | plan.add_upsert_operation(key, rule) 151 | 152 | if not plan.empty: 153 | res = self.datastore.signature.bulk(plan) 154 | return {"success": len(res['items']), "errors": res['errors'], "skipped": skip_list} 155 | 156 | return {"success": 0, "errors": [], "skipped": skip_list} 157 | 158 | def change_status(self, signature_id, status, user={}): 159 | possible_statuses = DEPLOYED_STATUSES + DRAFT_STATUSES 160 | if status not in possible_statuses: 161 | raise ValueError(f"You cannot apply the status {status} on yara rules.") 162 | 163 | data = self.datastore.signature.get(signature_id, as_obj=False) 164 | if data: 165 | if user and not CLASSIFICATION.is_accessible(user['classification'], 166 | data.get('classification', CLASSIFICATION.UNRESTRICTED)): 167 | raise PermissionError("You are not allowed change status on this signature") 168 | 169 | if data['status'] in STALE_STATUSES and status not in DRAFT_STATUSES: 170 | raise ValueError(f"Only action available while signature in {data['status']} " 171 | f"status is to change signature to a DRAFT status. ({', '.join(DRAFT_STATUSES)})") 172 | 173 | if data['status'] in DEPLOYED_STATUSES and status in DRAFT_STATUSES: 174 | raise ValueError(f"You cannot change the status of signature {signature_id} from " 175 | f"{data['status']} to {status}.") 176 | 177 | today = now_as_iso() 178 | uname = user.get('uname') 179 | 180 | if status not in ['DISABLED', 'INVALID', 'TESTING']: 181 | query = f"status:{status} AND signature_id:{data['signature_id']} AND NOT id:{signature_id}" 182 | others_operations = [ 183 | ('SET', 'last_modified', today), 184 | ('SET', 'state_change_date', today), 185 | ('SET', 'state_change_user', uname), 186 | ('SET', 'status', 'DISABLED') 187 | ] 188 | self.datastore.signature.update_by_query(query, others_operations) 189 | 190 | operations = [ 191 | ('SET', 'last_modified', today), 192 | ('SET', 'state_change_date', today), 193 | ('SET', 'state_change_user', uname), 194 | ('SET', 'status', status) 195 | ] 196 | 197 | return self.datastore.signature.update(signature_id, operations), data 198 | raise FileNotFoundError(f"Signature not found. 
({signature_id})") 199 | 200 | def download(self, query=None, access=None) -> bytes: 201 | if not query: 202 | query = "*" 203 | 204 | output_files = {} 205 | 206 | signature_list = sorted( 207 | self.datastore.signature.stream_search( 208 | query, fl="signature_id,type,source,data,order", access_control=access, as_obj=False, 209 | item_buffer_size=1000), 210 | key=lambda x: x['order']) 211 | 212 | for sig in signature_list: 213 | out_fname = f"{sig['type']}/{sig['source']}" 214 | if self.delimiters.get(sig['type'], {}).get('type', None) == 'file': 215 | out_fname = f"{out_fname}/{sig['signature_id']}" 216 | output_files.setdefault(out_fname, []) 217 | output_files[out_fname].append(sig['data']) 218 | 219 | output_zip = InMemoryZip() 220 | for fname, data in output_files.items(): 221 | separator = self.delimiters.get(fname.split('/')[0], {}).get('delimiter', DEFAULT_DELIMITER) 222 | output_zip.append(fname, separator.join(data)) 223 | 224 | return output_zip.read() 225 | 226 | def update_available(self, since='', sig_type='*'): 227 | since = since or '1970-01-01T00:00:00.000000Z' 228 | last_update = iso_to_epoch(since) 229 | last_modified = iso_to_epoch(self.datastore.get_signature_last_modified(sig_type)) 230 | 231 | return last_modified > last_update 232 | -------------------------------------------------------------------------------- /assemblyline_core/updater/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/updater/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/vacuum/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/vacuum/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/vacuum/crawler.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import concurrent.futures 3 | import json 4 | import logging 5 | import os 6 | import io 7 | import signal 8 | from time import sleep, time 9 | from typing import TYPE_CHECKING 10 | 11 | from assemblyline.common.log import init_logging 12 | from assemblyline.common.forge import get_config 13 | from assemblyline.odm.models.config import Config 14 | from assemblyline.remote.datatypes.queues.named import NamedQueue 15 | from assemblyline.remote.datatypes import get_client as get_redis_client 16 | 17 | from multiprocessing import Event 18 | 19 | if TYPE_CHECKING: 20 | from redis import Redis 21 | 22 | 23 | DEL_TIME: int = 60 * 60 # An hour 24 | MAX_QUEUE_LENGTH = 100000 25 | VACUUM_BUFFER_NAME = 'vacuum-file-buffer' 26 | 27 | stop_event = Event() 28 | 29 | 30 | # noinspection PyUnusedLocal 31 | def sigterm_handler(_signum=0, _frame=None): 32 | stop_event.set() 33 | 34 | 35 | _last_heartbeat: float = 0 36 | 37 | 38 | def heartbeat(config: Config): 39 | global _last_heartbeat 40 | if _last_heartbeat + 3 < time(): 41 | with io.open(config.logging.heartbeat_file, 'ab'): 42 | os.utime(config.logging.heartbeat_file) 43 | _last_heartbeat = time() 44 | 45 | 46 | logger = logging.getLogger('assemblyline.vacuum') 47 | 48 | 49 | def main(): 50 | config = get_config() 51 | signal.signal(signal.SIGTERM, sigterm_handler) 52 | 53 
| # Initialize logging 54 | init_logging('assemblyline.vacuum') 55 | logger.info('Vacuum starting up...') 56 | 57 | # Initialize cache 58 | logger.info("Connect to redis...") 59 | redis = get_redis_client(config.core.redis.nonpersistent.host, config.core.redis.nonpersistent.port, False) 60 | run(config, redis) 61 | 62 | 63 | def run(config: Config, redis: Redis): 64 | vacuum_config = config.core.vacuum 65 | 66 | # connect to workers 67 | logger.info("Connect to work queue...") 68 | queue = NamedQueue(VACUUM_BUFFER_NAME, redis) 69 | 70 | logger.info("Load cache...") 71 | files_list_cache = os.path.join(vacuum_config.list_cache_directory, 'visited.json') 72 | try: 73 | with open(files_list_cache, 'r') as handle: 74 | previous_iteration_files: set[str] = set(json.load(handle)) 75 | except (OSError, json.JSONDecodeError): 76 | previous_iteration_files = set() 77 | 78 | # Make sure we can access the cache file 79 | with open(files_list_cache, 'w'): 80 | pass 81 | 82 | this_iteration_files: list[str] = [] 83 | length = queue.length() 84 | 85 | # Make sure some input is configured 86 | if not vacuum_config.data_directories: 87 | logger.error("No input directory configured.") 88 | return 89 | 90 | logger.info("Starting main loop...") 91 | while not stop_event.is_set(): 92 | heartbeat(config) 93 | remove_dir_list = [] 94 | futures: list[concurrent.futures.Future] = [] 95 | with concurrent.futures.ThreadPoolExecutor(20) as pool: 96 | for data_directory in vacuum_config.data_directories: 97 | for root, dirs, files in os.walk(data_directory): 98 | heartbeat(config) 99 | while len(futures) > 50: 100 | futures = [f for f in futures if not f.done()] 101 | heartbeat(config) 102 | sleep(0.1) 103 | 104 | if length > MAX_QUEUE_LENGTH: 105 | while len(futures) > 0: 106 | futures = [f for f in futures if not f.done()] 107 | heartbeat(config) 108 | sleep(0.1) 109 | length = queue.length() 110 | 111 | while length > MAX_QUEUE_LENGTH: 112 | logger.warning("Backlog full") 113 | length = queue.length() 114 | for _ in range(120): 115 | heartbeat(config) 116 | sleep(1) 117 | if stop_event.is_set(): 118 | break 119 | if stop_event.is_set(): 120 | break 121 | 122 | if stop_event.is_set(): 123 | break 124 | 125 | if not dirs and not files: 126 | if root == data_directory: 127 | continue 128 | 129 | cur_time = time() 130 | dir_time = os.lstat(root).st_mtime 131 | if (cur_time - dir_time) > DEL_TIME: 132 | logger.debug('Directory %s marked for removal.' % root) 133 | remove_dir_list.append(root) 134 | else: 135 | logger.debug(f'Directory {root} empty but not old enough. 
' 136 | f'[{int(cur_time - dir_time)}/{DEL_TIME}]') 137 | continue 138 | 139 | if files: 140 | new_file_list = [os.path.join(root, f) for f in files 141 | if not f.startswith(".") and not f.endswith('.bad')] 142 | new_files = set(new_file_list) - previous_iteration_files 143 | 144 | if new_files: 145 | futures.append(pool.submit(queue.push, *new_files)) 146 | # queue.push(*new_files) 147 | length += len(new_files) 148 | this_iteration_files.extend(new_files) 149 | 150 | with open(files_list_cache, 'w') as handle: 151 | json.dump(this_iteration_files, handle) 152 | 153 | previous_iteration_files = set(this_iteration_files) 154 | this_iteration_files = [] 155 | 156 | for d in remove_dir_list: 157 | logger.debug("Removing empty directory: %s" % d) 158 | # noinspection PyBroadException 159 | try: 160 | os.rmdir(d) 161 | except Exception: 162 | pass 163 | 164 | if not stop_event.is_set(): 165 | sleep(5) 166 | 167 | logger.info('Good bye!') 168 | 169 | 170 | if __name__ == '__main__': 171 | main() 172 | -------------------------------------------------------------------------------- /assemblyline_core/vacuum/department_map.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | import time 4 | import logging 5 | import threading 6 | from typing import Optional 7 | 8 | from assemblyline.common.iprange import RangeTable 9 | 10 | import requests 11 | 12 | 13 | logger = logging.getLogger("assemblyline.vacuum.department_map") 14 | 15 | 16 | class DepartmentMap: 17 | UPDATE_INTERVAL = 60 * 60 18 | 19 | @staticmethod 20 | @functools.cache 21 | def load(url: Optional[str], init: Optional[str]): 22 | return DepartmentMap(url, init) 23 | 24 | def __init__(self, url: Optional[str], init: Optional[str]): 25 | self.url = url 26 | self.init_data = init 27 | self.lock = threading.Lock() 28 | self.table = RangeTable() 29 | self.update_time = 0 30 | self._load_department_map() 31 | 32 | def _load_department_map(self): 33 | # Don't load more than once every 5 seconds 34 | if time.time() - self.update_time < 5: 35 | return 36 | 37 | with self.lock: 38 | # Recheck in case it was updated while waiting for lock 39 | if time.time() - self.update_time < 5: 40 | return 41 | 42 | table = RangeTable() 43 | 44 | try: 45 | if self.init_data: 46 | for row in json.loads(self.init_data): 47 | if ':' not in row['LOWER'] and ':' not in row['UPPER']: 48 | # print(row["LOWER"], row['UPPER'], row['LABEL']) 49 | table.add_range(row['LOWER'], row['UPPER'], row['LABEL']) 50 | except Exception: 51 | logger.exception("Error parsing department_map_init") 52 | 53 | try: 54 | if self.url: 55 | res = requests.get(self.url, verify=False) 56 | res.raise_for_status() 57 | 58 | for row in res.json(): 59 | if ':' not in row['LOWER'] and ':' not in row['UPPER']: 60 | # print(row["LOWER"], row['UPPER'], row['LABEL']) 61 | table.add_range(row['LOWER'], row['UPPER'], row['LABEL']) 62 | except Exception: 63 | logger.exception("Error parsing department_map_url") 64 | 65 | self.table = table 66 | self.update_time = time.time() 67 | 68 | def _refresh_department_map(self): 69 | if time.time() - self.update_time > self.UPDATE_INTERVAL: 70 | self._load_department_map() 71 | 72 | def __getitem__(self, ip) -> Optional[str]: 73 | self._refresh_department_map() 74 | try: 75 | return self.table[ip] 76 | except KeyError: 77 | self._load_department_map() 78 | try: 79 | return self.table[ip] 80 | except KeyError: 81 | return None 82 | 83 | 84 | if __name__ == '__main__': 85 | departments = 
DepartmentMap('', None) 86 | print(departments['48.49.39.100']) 87 | -------------------------------------------------------------------------------- /assemblyline_core/vacuum/safelist.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Tuple, Dict 3 | 4 | from assemblyline.odm.models.config import VacuumSafelistItem 5 | 6 | _safelist = [ 7 | { 8 | 'name': 'ibiblio.org', 9 | 'conditions': { 10 | 'url': r'^mirrors\.ibiblio\.org/' 11 | } 12 | }, 13 | { 14 | 'name': 'Symantec Updates', 15 | 'conditions': { 16 | 'url': r'[^/]*\.?liveupdate\.symantecliveupdate\.com(?:\:[0-9]{2,5})?/' 17 | } 18 | }, 19 | { 20 | 'name': 'Google Earth', 21 | 'conditions': { 22 | 'url': r'[^/]*\.?google\.com/mw-earth-vectordb/' 23 | } 24 | }, 25 | # Examples 26 | # download.windowsupdate.com/c/msdownload/update/software/defu/2015/01/am_delta_74634f1206094529d2f336f24da8429a7d4ebec0.exe 27 | # download.windowsupdate.com/c/msdownload/update/software/defu/2015/01/am_delta_74634f1206094529d2f336f24da8429a7d4ebec0.exe 28 | # au.v4.download.windowsupdate.com/d/msdownload/update/software/defu/2015/01/am_delta_b0023f90835cd814b953c331f93776f3108936b9.exe 29 | { 30 | 'name': 'Microsoft Windows Updates', 31 | 'conditions': { 32 | 'url': r'[^/]*\.windowsupdate\.com/' 33 | } 34 | }, 35 | { 36 | 'name': 'Microsoft Windows Updates', 37 | 'conditions': { 38 | 'domain': r'[^/]*\.windowsupdate\.com' 39 | } 40 | }, 41 | { 42 | 'name': 'Microsoft Package Distribution', 43 | 'conditions': { 44 | 'domain': r'[^/]*\.?delivery\.mp\.microsoft\.com' 45 | } 46 | }, 47 | ] 48 | 49 | _operators = { 50 | 'in': lambda args: lambda x: x in args, 51 | 'not in': lambda args: lambda x: x not in args, 52 | 'regexp': lambda args: re.compile(*args).match, 53 | } 54 | 55 | 56 | def _transform(condition): 57 | if isinstance(condition, str): 58 | args = [condition] 59 | func = 'regexp' 60 | else: 61 | args = list(condition[1:]) 62 | func = condition[0] 63 | 64 | return _operators[func](args) 65 | 66 | 67 | def _matches(data, sigs): 68 | cache = {} 69 | unknown = 0 70 | for sig in sigs: 71 | result = _match(cache, data, sig) 72 | if result: 73 | name = sig.get('name', None) 74 | if not name: 75 | unknown += 1 76 | name = "unknown%d" % unknown 77 | yield name, result 78 | return 79 | 80 | 81 | def _match(cache, data, sig): 82 | summary = {} 83 | results = [ 84 | _call(cache, data, f, k) for k, f in sig['conditions'].items() 85 | ] 86 | if all(results): 87 | [summary.update(r) for r in results] # pylint: disable=W0106 88 | return summary 89 | 90 | 91 | # noinspection PyBroadException 92 | def _call(cache, data, func, key): 93 | try: 94 | value = cache.get(key, None) 95 | if not value: 96 | cache[key] = value = data.get(key) 97 | if not callable(func): 98 | func = _transform(func) 99 | return {key: value} if func(value) else {} 100 | except Exception: 101 | return {} 102 | 103 | 104 | class VacuumSafelist: 105 | def __init__(self, data: list[VacuumSafelistItem]): 106 | self._safelist = _safelist 107 | self._safelist.extend([ 108 | row.as_primitives() if isinstance(row, VacuumSafelistItem) else row 109 | for row in data 110 | ]) 111 | VacuumSafelist.optimize(self._safelist) 112 | 113 | def drop(self, data: Dict) -> Tuple[str, Dict]: 114 | return next(_matches(data, self._safelist), ("", {})) 115 | 116 | @staticmethod 117 | def optimize(signatures): 118 | for sig in signatures: 119 | conditions = sig.get('conditions') 120 | for k, v in conditions.items(): 121 | conditions[k] = _transform(v) 
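# Illustrative usage sketch: this demo assumes only the default _safelist patterns defined
# above, reuses one of the windowsupdate.com example URLs quoted in the comments, and notes
# the expected results as comments rather than guaranteed output.
if __name__ == '__main__':
    demo = VacuumSafelist([])

    # Expected to hit the 'Microsoft Windows Updates' URL regex:
    # ('Microsoft Windows Updates', {'url': '...'})
    print(demo.drop({'url': 'au.v4.download.windowsupdate.com/d/msdownload/update/software/defu/2015/01/am_delta_b0023f90835cd814b953c331f93776f3108936b9.exe'}))

    # Expected to match no condition and fall back to the ("", {}) default:
    print(demo.drop({'url': 'downloads.example.com/tools/sample.exe'}))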
122 | 123 | -------------------------------------------------------------------------------- /assemblyline_core/vacuum/stream_map.py: -------------------------------------------------------------------------------- 1 | import time 2 | import threading 3 | import functools 4 | import logging 5 | from typing import Optional 6 | import json 7 | from collections import namedtuple 8 | 9 | import requests 10 | 11 | 12 | logger = logging.getLogger("assemblyline.vacuum.stream_map") 13 | 14 | Stream = namedtuple('Stream', [ 15 | 'id', 16 | 'name', 17 | 'description', 18 | 'zone_id', 19 | 'classification' 20 | ]) 21 | 22 | 23 | class StreamMap: 24 | UPDATE_INTERVAL = 60 * 15 25 | 26 | @staticmethod 27 | @functools.cache 28 | def load(url: Optional[str], init: Optional[str]): 29 | return StreamMap(url, init) 30 | 31 | def __init__(self, url: Optional[str], init: Optional[str]): 32 | self.url = url 33 | self.init_data = init 34 | self.lock = threading.Lock() 35 | self.table: dict[int, Stream] = {} 36 | self.update_time = 0 37 | self._load_stream_map() 38 | 39 | def _load_stream_map(self): 40 | # Don't load more than once every 5 seconds 41 | if time.time() - self.update_time < 5: 42 | return 43 | 44 | with self.lock: 45 | # Recheck in case it was updated while waiting for lock 46 | if time.time() - self.update_time < 5: 47 | return 48 | 49 | table = {} 50 | try: 51 | if self.init_data: 52 | for stream in json.loads(self.init_data): 53 | stream = Stream( 54 | id=int(stream['STREAM_ID']), 55 | name=stream['STREAM_NAME'], 56 | description=stream['STREAM_DESCRIPTION'], 57 | zone_id=stream['ZONE'], 58 | classification=f"{stream.get('LEVEL', 'PB')}//{stream.get('CAVEAT', 'CND')}" 59 | ) 60 | table[stream.id] = stream 61 | 62 | except Exception: 63 | logger.exception("Error parsing stream_map_init data") 64 | 65 | try: 66 | if self.url: 67 | res = requests.get(self.url, verify=False) 68 | res.raise_for_status() 69 | 70 | for stream in res.json()['data']: 71 | stream = Stream( 72 | id=int(stream['STREAM_ID']), 73 | name=stream['STREAM_NAME'], 74 | description=stream['STREAM_DESCRIPTION'], 75 | zone_id=stream['ZONE'], 76 | classification=f"{stream.get('LEVEL', 'PB')}//{stream.get('CAVEAT', 'CND')}" 77 | ) 78 | table[stream.id] = stream 79 | except Exception: 80 | logger.exception("Error parsing stream_map_url data") 81 | 82 | self.table = table 83 | self.update_time = time.time() 84 | 85 | def _refresh_stream_map(self): 86 | if time.time() - self.update_time > self.UPDATE_INTERVAL: 87 | self._load_stream_map() 88 | 89 | def __getitem__(self, stream_id: int) -> Optional[Stream]: 90 | self._refresh_stream_map() 91 | try: 92 | return self.table[stream_id] 93 | except KeyError: 94 | self._load_stream_map() 95 | return self.table.get(stream_id) 96 | 97 | 98 | if __name__ == '__main__': 99 | streams = StreamMap('', None) 100 | print(streams[10]) 101 | print(streams[10000000]) 102 | -------------------------------------------------------------------------------- /assemblyline_core/workflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CybercentreCanada/assemblyline-core/ae4bb2ae1f3da2218e7fc681405b57ed44f744ea/assemblyline_core/workflow/__init__.py -------------------------------------------------------------------------------- /assemblyline_core/workflow/run_workflow.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import elasticapm 4 | import time 5 | 6 | from 
assemblyline_core.server_base import ServerBase 7 | from assemblyline.common import forge 8 | from assemblyline.common.isotime import now_as_iso 9 | from assemblyline.common.str_utils import safe_str 10 | 11 | from assemblyline.datastore.exceptions import SearchException 12 | from assemblyline.odm.models.alert import Event 13 | from assemblyline.odm.models.workflow import Workflow 14 | 15 | 16 | class WorkflowManager(ServerBase): 17 | def __init__(self): 18 | super().__init__('assemblyline.workflow') 19 | 20 | self.config = forge.get_config() 21 | self.datastore = forge.get_datastore(self.config) 22 | self.start_ts = f"{self.datastore.ds.now}/{self.datastore.ds.day}-1{self.datastore.ds.day}" 23 | 24 | if self.config.core.metrics.apm_server.server_url is not None: 25 | self.log.info(f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}") 26 | elasticapm.instrument() 27 | self.apm_client = forge.get_apm_client("workflow") 28 | else: 29 | self.apm_client = None 30 | 31 | def stop(self): 32 | if self.apm_client: 33 | elasticapm.uninstrument() 34 | super().stop() 35 | 36 | def get_last_reporting_ts(self, p_start_ts): 37 | # Start of transaction 38 | if self.apm_client: 39 | self.apm_client.begin_transaction("Get last reporting timestamp") 40 | 41 | self.log.info(f"Finding reporting timestamp for the last alert since {p_start_ts}...") 42 | result = None 43 | while result is None: 44 | try: 45 | result = self.datastore.alert.search(f"reporting_ts:[{p_start_ts} TO *]", 46 | sort='reporting_ts desc', rows=1, fl='reporting_ts', as_obj=False) 47 | except SearchException as e: 48 | self.log.warning(f"Failed to load last reported alert from the datastore, retrying... :: {e}") 49 | continue 50 | 51 | items = result.get('items', [{}]) or [{}] 52 | 53 | ret_val = items[0].get("reporting_ts", p_start_ts) 54 | 55 | # End of transaction 56 | if self.apm_client: 57 | elasticapm.label(start_ts=p_start_ts, reporting_ts=ret_val) 58 | self.apm_client.end_transaction('get_last_reporting_ts', 'new_ts' if ret_val != p_start_ts else 'same_ts') 59 | 60 | return ret_val 61 | 62 | def try_run(self, run_once=False): 63 | self.datastore.alert.commit() 64 | while self.running: 65 | self.heartbeat() 66 | end_ts = self.get_last_reporting_ts(self.start_ts) 67 | if self.start_ts != end_ts: 68 | # Start of transaction 69 | if self.apm_client: 70 | self.apm_client.begin_transaction("Load workflows") 71 | 72 | workflow_queries = [Workflow({ 73 | 'status': "TRIAGE", 74 | 'name': "Triage all with no status", 75 | 'creator': "SYSTEM", 76 | 'edited_by': "SYSTEM", 77 | 'query': "NOT status:*", 78 | 'workflow_id': "DEFAULT" 79 | })] 80 | 81 | try: 82 | for item in self.datastore.workflow.stream_search("status:MALICIOUS"): 83 | workflow_queries.append(item) 84 | 85 | for item in self.datastore.workflow.stream_search("status:NON-MALICIOUS"): 86 | workflow_queries.append(item) 87 | 88 | for item in self.datastore.workflow.stream_search("status:ASSESS"): 89 | workflow_queries.append(item) 90 | 91 | for item in self.datastore.workflow.stream_search('-status:["" TO *]'): 92 | workflow_queries.append(item) 93 | except SearchException as e: 94 | self.log.warning(f"Failed to load workflows from the datastore, retrying... 
:: {e}") 95 | 96 | # End of transaction 97 | if self.apm_client: 98 | elasticapm.label(number_of_workflows=len(workflow_queries)) 99 | self.apm_client.end_transaction('loading_workflows', 'search_exception') 100 | continue 101 | 102 | # End of transaction 103 | if self.apm_client: 104 | elasticapm.label(number_of_workflows=len(workflow_queries)) 105 | self.apm_client.end_transaction('loading_workflows', 'success') 106 | 107 | for workflow in workflow_queries: 108 | # Only action workflow if it's enabled 109 | if not workflow.enabled: 110 | continue 111 | 112 | # Trigger a heartbeat to let the system know the workflow manager is still alive between tasks 113 | self.heartbeat() 114 | 115 | # Start of transaction 116 | if self.apm_client: 117 | self.apm_client.begin_transaction("Execute workflows") 118 | elasticapm.label(query=workflow.query, 119 | labels=workflow.labels, 120 | status=workflow.status, 121 | priority=workflow.priority, 122 | user=workflow.creator) 123 | 124 | self.log.info(f'Executing workflow filter: {workflow.name}') 125 | labels = workflow.labels or [] 126 | status = workflow.status or None 127 | priority = workflow.priority or None 128 | 129 | if not status and not labels and not priority: 130 | # End of transaction 131 | if self.apm_client: 132 | self.apm_client.end_transaction(workflow.name, 'no_action') 133 | continue 134 | 135 | fq = [f"reporting_ts:[{self.start_ts} TO {end_ts}]", "NOT extended_scan:submitted"] 136 | 137 | event_data = Event({'entity_type': 'workflow', 138 | 'entity_id': workflow.workflow_id, 139 | 'entity_name': workflow.name}) 140 | operations = [] 141 | fq_items = [] 142 | if labels: 143 | operations.extend([(self.datastore.alert.UPDATE_APPEND_IF_MISSING, 'label', lbl) 144 | for lbl in labels]) 145 | for label in labels: 146 | fq_items.append(f'label:"{label}"') 147 | event_data.labels = labels 148 | if priority: 149 | operations.append((self.datastore.alert.UPDATE_SET, 'priority', priority)) 150 | fq_items.append("priority:*") 151 | event_data.priority = priority 152 | if status: 153 | operations.append((self.datastore.alert.UPDATE_SET, 'status', status)) 154 | fq_items.append("(status:MALICIOUS OR status:NON-MALICIOUS OR status:ASSESS)") 155 | event_data.status = status 156 | 157 | fq.append(f"NOT ({' AND '.join(fq_items)})") 158 | # Add event to alert's audit history 159 | operations.append((self.datastore.alert.UPDATE_APPEND, 'events', event_data)) 160 | 161 | try: 162 | count = self.datastore.alert.update_by_query(workflow.query, operations, filters=fq) 163 | if self.apm_client: 164 | elasticapm.label(affected_alerts=count) 165 | 166 | if count: 167 | self.log.info(f"{count} Alert(s) were affected by this filter.") 168 | if workflow.workflow_id != "DEFAULT": 169 | seen = now_as_iso() 170 | operations = [ 171 | (self.datastore.workflow.UPDATE_INC, 'hit_count', count), 172 | (self.datastore.workflow.UPDATE_SET, 'last_seen', seen), 173 | ] 174 | if not workflow.first_seen: 175 | # Set first seen for workflow if not set 176 | operations.append((self.datastore.workflow.UPDATE_SET, 'first_seen', seen)) 177 | self.datastore.workflow.update(workflow.workflow_id, operations) 178 | 179 | except SearchException: 180 | self.log.warning(f"Invalid query '{safe_str(workflow.query or '')}' in workflow " 181 | f"'{workflow.name or 'unknown'}' by '{workflow.creator or 'unknown'}'") 182 | 183 | # End of transaction 184 | if self.apm_client: 185 | self.apm_client.end_transaction(workflow.name, 'search_exception') 186 | 187 | continue 188 | 189 | # End of 
transaction 190 | if self.apm_client: 191 | self.apm_client.end_transaction(workflow.name, 'success') 192 | 193 | # Marking all alerts for the time period as their workflow completed 194 | # Start of transaction 195 | if self.apm_client: 196 | self.apm_client.begin_transaction("Mark alerts complete") 197 | 198 | self.log.info(f'Marking all alerts between {self.start_ts} and {end_ts} as workflow completed...') 199 | wc_query = f"reporting_ts:[{self.start_ts} TO {end_ts}]" 200 | wc_operations = [(self.datastore.alert.UPDATE_SET, 'workflows_completed', True)] 201 | try: 202 | wc_count = self.datastore.alert.update_by_query(wc_query, wc_operations) 203 | if self.apm_client: 204 | elasticapm.label(affected_alerts=wc_count) 205 | 206 | if wc_count: 207 | self.log.info(f"{wc_count} Alert(s) workflows marked as completed.") 208 | 209 | # End of transaction 210 | if self.apm_client: 211 | self.apm_client.end_transaction("workflows_completed", 'success') 212 | 213 | except SearchException as e: 214 | self.log.warning(f"Failed to update alerts workflows_completed field. [{str(e)}]") 215 | 216 | # End of transaction 217 | if self.apm_client: 218 | self.apm_client.end_transaction("workflows_completed", 'search_exception') 219 | 220 | else: 221 | self.log.info("Skipping all workflows since there were no new alerts in the specified time period.") 222 | 223 | if run_once: 224 | break 225 | time.sleep(30) 226 | self.start_ts = end_ts 227 | 228 | 229 | if __name__ == "__main__": 230 | with WorkflowManager() as wm: 231 | wm.serve_forever() 232 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | threshold: 0.5% -------------------------------------------------------------------------------- /deployment/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG branch=latest 2 | ARG base=cccs/assemblyline 3 | FROM $base:$branch 4 | ARG version 5 | 6 | # Install assemblyline base (setup.py is just a file we know exists so the command 7 | # won't fail if dist isn't there. The dist* copies in any dist directory only if it exists.) 8 | COPY setup.py dist* dist/ 9 | RUN pip install --no-cache-dir -f dist/ --user assemblyline-core==$version && rm -rf ~/.cache/pip 10 | -------------------------------------------------------------------------------- /pipelines/azure-tests.yaml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | trigger: ["*"] 4 | pr: ["*"] 5 | 6 | pool: 7 | vmImage: "ubuntu-latest" 8 | 9 | variables: 10 | # Try to checkout the matching branch, if the command fails, don't care. 
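  # (coalesce() below picks the first value that is defined for this run: the
  # pull request source branch, the pull request target branch, or the plain
  # branch name from Build.SourceBranch with its refs/heads/ prefix stripped.
  # The "Install assemblyline" step later tries to check out that same branch
  # name in the assemblyline-base checkout.)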
11 | BRANCH_NAME: $[coalesce(variables['System.PullRequest.SourceBranch'], variables['System.PullRequest.TargetBranch'], replace(variables['Build.SourceBranch'], 'refs/heads/', ''))] 12 | 13 | resources: 14 | containers: 15 | - container: redis 16 | image: redis 17 | ports: 18 | - 6379:6379 19 | - container: elasticsearch 20 | image: docker.elastic.co/elasticsearch/elasticsearch:8.10.2 21 | env: 22 | xpack.security.enabled: true 23 | discovery.type: single-node 24 | ES_JAVA_OPTS: "-Xms256m -Xmx512m" 25 | ELASTIC_PASSWORD: devpass 26 | ports: 27 | - 9200:9200 28 | repositories: 29 | - repository: assemblyline-base 30 | type: github 31 | endpoint: github-repo-sa 32 | name: CybercentreCanada/assemblyline-base 33 | 34 | jobs: 35 | - job: run_test 36 | strategy: 37 | matrix: 38 | Python3_9: 39 | python.version: "3.9" 40 | Python3_10: 41 | python.version: "3.10" 42 | Python3_11: 43 | python.version: "3.11" 44 | Python3_12: 45 | python.version: "3.12" 46 | timeoutInMinutes: 10 47 | services: 48 | elasticsearch: elasticsearch 49 | redis: redis 50 | 51 | steps: 52 | - task: UsePythonVersion@0 53 | displayName: Set python version 54 | inputs: 55 | versionSpec: "$(python.version)" 56 | - checkout: self 57 | - checkout: assemblyline-base 58 | - script: | 59 | sudo apt-get update 60 | sudo apt-get install -y build-essential libffi-dev libfuzzy-dev python3-dev git 61 | sudo mkdir -p /etc/assemblyline/ 62 | sudo mkdir -p /var/cache/assemblyline/ 63 | sudo cp pipelines/config.yml /etc/assemblyline 64 | sudo chmod a+rw /var/cache/assemblyline/ 65 | sudo env "PATH=$PATH" "PIP_USE_PEP517=true" python -m pip install --no-cache-dir -U pip cython setuptools wheel 66 | workingDirectory: $(Pipeline.Workspace)/s/assemblyline-core 67 | displayName: Setup Environment 68 | - script: | 69 | set -xv 70 | git checkout -b $BRANCH_NAME -t origin/$BRANCH_NAME || true 71 | git status 72 | sudo env "PATH=$PATH" "PIP_USE_PEP517=true" python -m pip install --no-cache-dir -e . 
73 | displayName: Install assemblyline 74 | workingDirectory: $(Pipeline.Workspace)/s/assemblyline-base 75 | - script: | 76 | sudo env "PATH=$PATH" "PIP_USE_PEP517=true" python -m pip install --no-cache-dir -e .[test] 77 | displayName: Install assemblyline_core 78 | workingDirectory: $(Pipeline.Workspace)/s/assemblyline-core 79 | - script: sudo env "PATH=$PATH" python -m pytest -x -rsx -vv 80 | displayName: Test 81 | workingDirectory: $(Pipeline.Workspace)/s/assemblyline-core 82 | -------------------------------------------------------------------------------- /pipelines/config.yml: -------------------------------------------------------------------------------- 1 | filestore: 2 | cache: 3 | - file:///var/cache/assemblyline/ 4 | storage: 5 | - file:///var/cache/assemblyline/ 6 | archive: 7 | - file:///var/cache/assemblyline/ 8 | core: 9 | redis: 10 | nonpersistent: 11 | host: localhost 12 | port: 6379 13 | persistent: 14 | host: localhost 15 | port: 6379 16 | datastore: 17 | ilm: 18 | enabled: true 19 | hosts: ["http://elastic:devpass@localhost:9200"] 20 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | testpaths = test 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | from setuptools import setup, find_packages 5 | 6 | # Try to load the version from a datafile in the package 7 | package_version = "4.0.0.dev0" 8 | package_version_path = os.path.join(os.path.dirname(__file__), 'assemblyline_core', 'VERSION') 9 | if os.path.exists(package_version_path): 10 | with open(package_version_path) as package_version_file: 11 | package_version = package_version_file.read().strip() 12 | 13 | # read the contents of your README file 14 | this_directory = os.path.abspath(os.path.dirname(__file__)) 15 | with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f: 16 | long_description = f.read() 17 | 18 | setup( 19 | name="assemblyline-core", 20 | version=package_version, 21 | description="Assemblyline 4 - Core components", 22 | long_description=long_description, 23 | long_description_content_type='text/markdown', 24 | url="https://github.com/CybercentreCanada/assemblyline-core/", 25 | author="CCCS Assemblyline development team", 26 | author_email="assemblyline@cyber.gc.ca", 27 | license="MIT", 28 | classifiers=[ 29 | 'Development Status :: 5 - Production/Stable', 30 | 'Intended Audience :: Developers', 31 | 'Topic :: Software Development :: Libraries', 32 | 'License :: OSI Approved :: MIT License', 33 | 'Programming Language :: Python :: 3.7', 34 | 'Programming Language :: Python :: 3.8', 35 | 'Programming Language :: Python :: 3.9', 36 | 'Programming Language :: Python :: 3.10', 37 | 'Programming Language :: Python :: 3.11', 38 | 'Programming Language :: Python :: 3.12', 39 | ], 40 | keywords="assemblyline automated malware analysis gc canada cse-cst cse cst cyber cccs", 41 | packages=find_packages(exclude=['deployment/*', 'test/*']), 42 | install_requires=[ 43 | 'assemblyline', 44 | 'docker', 45 | 'kubernetes', 46 | ], 47 | extras_require={ 48 | 'test': [ 49 | 'pytest', 50 | 'assemblyline_client' 51 | ] 52 | }, 53 | tests_require=[ 54 | 'pytest', 55 | ], 56 | package_data={ 57 | '': ["*classification.yml", "*.magic", "VERSION"] 58 | } 59 | ) 60 | 
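The pipeline above installs this package in editable mode with its test extra (pip install -e .[test]) after installing a matching assemblyline-base checkout. A minimal sketch, assuming assemblyline-core has already been installed into the active environment, of how to confirm which build (and therefore which optional assemblyline_core/VERSION file, if any) was picked up:

import sys
from importlib.metadata import PackageNotFoundError, version

# Print the installed assemblyline-core version; a fresh checkout without an
# assemblyline_core/VERSION file resolves to the "4.0.0.dev0" fallback in setup.py.
try:
    print(f"assemblyline-core {version('assemblyline-core')} on Python {sys.version.split()[0]}")
except PackageNotFoundError:
    print("assemblyline-core is not installed in this environment")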
-------------------------------------------------------------------------------- /test/classification.yml: -------------------------------------------------------------------------------- 1 | 2 | enforce: true -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pytest configuration file, setup global pytest fixtures and functions here. 3 | """ 4 | import os 5 | import pytest 6 | 7 | from redis.exceptions import ConnectionError 8 | 9 | from assemblyline.common import forge 10 | from assemblyline.datastore.helper import AssemblylineDatastore 11 | from assemblyline.datastore.store import ESStore 12 | from assemblyline.odm.models.config import Config 13 | 14 | original_classification = forge.get_classification 15 | 16 | 17 | def test_classification(yml_config=None): 18 | path = os.path.join(os.path.dirname(__file__), 'classification.yml') 19 | return original_classification(path) 20 | 21 | 22 | forge.get_classification = test_classification 23 | 24 | 25 | original_skip = pytest.skip 26 | 27 | # Check if we are in an unattended build environment where skips won't be noticed 28 | IN_CI_ENVIRONMENT = any(indicator in os.environ for indicator in 29 | ['CI', 'BITBUCKET_BUILD_NUMBER', 'AGENT_JOBSTATUS']) 30 | 31 | 32 | def skip_or_fail(message): 33 | """Skip or fail the current test, based on the environment""" 34 | if IN_CI_ENVIRONMENT: 35 | pytest.fail(message) 36 | else: 37 | original_skip(message) 38 | 39 | 40 | # Replace the built in skip function with our own 41 | pytest.skip = skip_or_fail 42 | 43 | 44 | @pytest.fixture(scope='session') 45 | def config(): 46 | config = forge.get_config() 47 | config.logging.log_level = 'INFO' 48 | config.logging.log_as_json = False 49 | config.core.metrics.apm_server.server_url = None 50 | config.core.metrics.export_interval = 1 51 | config.datastore.archive.enabled = True 52 | return config 53 | 54 | 55 | @pytest.fixture(scope='module') 56 | def datastore_connection(config: Config): 57 | store = ESStore(config.datastore.hosts) 58 | ret_val = store.ping() 59 | if not ret_val: 60 | pytest.skip("Could not connect to datastore") 61 | return AssemblylineDatastore(store) 62 | 63 | 64 | @pytest.fixture(scope='module') 65 | def clean_datastore(datastore_connection: AssemblylineDatastore): 66 | for name in datastore_connection.ds.get_models(): 67 | datastore_connection.get_collection(name).wipe() 68 | return datastore_connection 69 | 70 | 71 | @pytest.fixture(scope='function') 72 | def function_clean_datastore(datastore_connection: AssemblylineDatastore): 73 | for name in datastore_connection.ds.get_models(): 74 | datastore_connection.get_collection(name).wipe() 75 | return datastore_connection 76 | 77 | 78 | @pytest.fixture(scope='module') 79 | def redis_connection(): 80 | from assemblyline.remote.datatypes import get_client 81 | c = get_client(None, None, False) 82 | try: 83 | ret_val = c.ping() 84 | if ret_val: 85 | return c 86 | except ConnectionError: 87 | pass 88 | 89 | return pytest.skip("Connection to the Redis server failed. 
This test cannot be performed...") 90 | 91 | 92 | @pytest.fixture(scope='function') 93 | def clean_redis(redis_connection): 94 | try: 95 | redis_connection.flushdb() 96 | yield redis_connection 97 | finally: 98 | redis_connection.flushdb() 99 | 100 | 101 | @pytest.fixture(scope='module') 102 | def filestore(config): 103 | try: 104 | return forge.get_filestore(config, connection_attempts=1) 105 | except ConnectionError as err: 106 | pytest.skip(str(err)) 107 | -------------------------------------------------------------------------------- /test/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | elasticsearch: 5 | image: docker.elastic.co/elasticsearch/elasticsearch:8.10.2 6 | environment: 7 | - xpack.security.enabled=true 8 | - discovery.type=single-node 9 | - logger.level=WARN 10 | - "ES_JAVA_OPTS=-Xms512m -Xmx512m" 11 | - ELASTIC_PASSWORD=devpass 12 | ports: 13 | - "9200:9200" 14 | 15 | minio: 16 | image: minio/minio 17 | environment: 18 | MINIO_ROOT_USER: al_storage_key 19 | MINIO_ROOT_PASSWORD: Ch@ngeTh!sPa33w0rd 20 | ports: 21 | - "9000:9000" 22 | command: server /data 23 | 24 | redis: 25 | image: redis 26 | ports: 27 | - "6379:6379" 28 | - "6380:6379" 29 | -------------------------------------------------------------------------------- /test/mocking/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class MockFactory: 4 | def __init__(self, mock_type): 5 | self.type = mock_type 6 | self.mocks = {} 7 | 8 | def __call__(self, name, *args): 9 | if name not in self.mocks: 10 | self.mocks[name] = self.type(name, *args) 11 | return self.mocks[name] 12 | 13 | def __getitem__(self, name): 14 | return self.mocks[name] 15 | 16 | def __len__(self): 17 | return len(self.mocks) 18 | 19 | def flush(self): 20 | self.mocks.clear() 21 | 22 | 23 | class TrueCountTimes: 24 | """A helper object that replaces a boolean. 25 | 26 | After being read a fixed number of times this object switches to false. 27 | """ 28 | def __init__(self, count): 29 | self.counter = count 30 | 31 | def __bool__(self): 32 | self.counter -= 1 33 | return self.counter >= 0 34 | 35 | 36 | class ToggleTrue: 37 | """A helper object that replaces a boolean. 38 | 39 | After every read the value switches from true to false. First call is true. 
40 | """ 41 | def __init__(self): 42 | self.next = True 43 | 44 | def __bool__(self): 45 | self.next = not self.next 46 | return not self.next 47 | -------------------------------------------------------------------------------- /test/mocking/random_service.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | 4 | from assemblyline.common.forge import CachedObject 5 | 6 | from assemblyline.common.constants import SERVICE_STATE_HASH, ServiceStatus 7 | from assemblyline.common import forge 8 | from assemblyline.common.uid import get_random_id 9 | from assemblyline.remote.datatypes.hash import ExpiringHash 10 | from assemblyline_core.dispatching.client import DispatchClient 11 | from assemblyline_core.server_base import ServerBase 12 | from assemblyline.common.isotime import now_as_iso 13 | from assemblyline.common.metrics import MetricsFactory 14 | from assemblyline.odm.messages.task import Task as ServiceTask 15 | from assemblyline.odm.messages.service_heartbeat import Metrics 16 | from assemblyline.odm.models.error import Error 17 | from assemblyline.odm.models.file import File 18 | from assemblyline.odm.models.result import Result 19 | from assemblyline.odm.randomizer import random_model_obj, random_minimal_obj 20 | from assemblyline.remote.datatypes.queues.priority import select 21 | 22 | 23 | class RandomService(ServerBase): 24 | """Replaces everything past the dispatcher. 25 | 26 | Including service API, in the future probably include that in this test. 27 | """ 28 | 29 | def __init__(self, datastore=None, filestore=None): 30 | super().__init__('assemblyline.randomservice') 31 | self.config = forge.get_config() 32 | self.datastore = datastore or forge.get_datastore() 33 | self.filestore = filestore or forge.get_filestore() 34 | self.client_id = get_random_id() 35 | self.service_state_hash = ExpiringHash(SERVICE_STATE_HASH, ttl=30 * 60) 36 | 37 | self.counters = {n: MetricsFactory('service', Metrics, name=n, config=self.config) 38 | for n in self.datastore.service_delta.keys()} 39 | self.queues = [forge.get_service_queue(name) for name in self.datastore.service_delta.keys()] 40 | self.dispatch_client = DispatchClient(self.datastore) 41 | self.service_info = CachedObject(self.datastore.list_all_services, kwargs={'as_obj': False}) 42 | 43 | def run(self): 44 | self.log.info("Random service result generator ready!") 45 | self.log.info("Monitoring queues:") 46 | for q in self.queues: 47 | self.log.info(f"\t{q.name}") 48 | 49 | self.log.info("Waiting for messages...") 50 | while self.running: 51 | # Reset Idle flags 52 | for s in self.service_info: 53 | if s['enabled']: 54 | self.service_state_hash.set(f"{self.client_id}_{s['name']}", 55 | (s['name'], ServiceStatus.Idle, time.time() + 30 + 5)) 56 | 57 | message = select(*self.queues, timeout=1) 58 | if not message: 59 | continue 60 | 61 | if self.config.submission.dtl: 62 | expiry_ts = now_as_iso(self.config.submission.dtl * 24 * 60 * 60) 63 | else: 64 | expiry_ts = None 65 | queue, msg = message 66 | task = ServiceTask(msg) 67 | 68 | if not self.dispatch_client.running_tasks.add(task.key(), task.as_primitives()): 69 | continue 70 | 71 | # Set service busy flag 72 | self.service_state_hash.set(f"{self.client_id}_{task.service_name}", 73 | (task.service_name, ServiceStatus.Running, time.time() + 30 + 5)) 74 | 75 | # METRICS 76 | self.counters[task.service_name].increment('execute') 77 | # METRICS (not caching here so always miss) 78 | 
self.counters[task.service_name].increment('cache_miss') 79 | 80 | self.log.info(f"\tQueue {queue} received a new task for sid {task.sid}.") 81 | action = random.randint(1, 10) 82 | if action >= 2: 83 | if action > 8: 84 | result = random_minimal_obj(Result) 85 | else: 86 | result = random_model_obj(Result) 87 | result.sha256 = task.fileinfo.sha256 88 | result.response.service_name = task.service_name 89 | result.archive_ts = None 90 | result.expiry_ts = expiry_ts 91 | result.response.extracted = result.response.extracted[task.depth+2:] 92 | result.response.supplementary = result.response.supplementary[task.depth+2:] 93 | result_key = Result.help_build_key(sha256=task.fileinfo.sha256, 94 | service_name=task.service_name, 95 | service_version='0', 96 | is_empty=result.is_empty()) 97 | 98 | self.log.info(f"\t\tA result was generated for this task: {result_key}") 99 | 100 | new_files = result.response.extracted + result.response.supplementary 101 | for f in new_files: 102 | if not self.datastore.file.get(f.sha256): 103 | random_file = random_model_obj(File) 104 | random_file.archive_ts = None 105 | random_file.expiry_ts = expiry_ts 106 | random_file.sha256 = f.sha256 107 | self.datastore.file.save(f.sha256, random_file) 108 | if not self.filestore.exists(f.sha256): 109 | self.filestore.put(f.sha256, f.sha256) 110 | 111 | time.sleep(random.randint(0, 2)) 112 | 113 | self.dispatch_client.service_finished(task.sid, result_key, result) 114 | 115 | # METRICS 116 | if result.result.score > 0: 117 | self.counters[task.service_name].increment('scored') 118 | else: 119 | self.counters[task.service_name].increment('not_scored') 120 | 121 | else: 122 | error = random_model_obj(Error) 123 | error.archive_ts = None 124 | error.expiry_ts = expiry_ts 125 | error.sha256 = task.fileinfo.sha256 126 | error.response.service_name = task.service_name 127 | error.type = random.choice(["EXCEPTION", "SERVICE DOWN", "SERVICE BUSY"]) 128 | 129 | error_key = error.build_key('0') 130 | 131 | self.log.info(f"\t\tA {error.response.status}:{error.type} " 132 | f"error was generated for this task: {error_key}") 133 | 134 | self.dispatch_client.service_failed(task.sid, error_key, error) 135 | 136 | # METRICS 137 | if error.response.status == "FAIL_RECOVERABLE": 138 | self.counters[task.service_name].increment('fail_recoverable') 139 | else: 140 | self.counters[task.service_name].increment('fail_nonrecoverable') 141 | 142 | 143 | if __name__ == "__main__": 144 | RandomService().serve_forever() 145 | -------------------------------------------------------------------------------- /test/requirements.txt: -------------------------------------------------------------------------------- 1 | # Only used in test files 2 | pytest 3 | assemblyline_client 4 | -------------------------------------------------------------------------------- /test/test_alerter.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import random 3 | import uuid 4 | 5 | import pytest 6 | 7 | from assemblyline_core.alerter.run_alerter import Alerter 8 | from assemblyline_core.ingester.ingester import IngestTask 9 | from assemblyline.common import forge 10 | from assemblyline.common.uid import get_random_id 11 | from assemblyline.odm.models.submission import Submission 12 | from assemblyline.odm.models.tagging import Tagging 13 | from assemblyline.odm.random_data import wipe_submissions, create_submission 14 | from assemblyline.odm.randomizer import random_model_obj, get_random_tags 15 | from 
assemblyline.remote.datatypes import get_client 16 | from assemblyline.remote.datatypes.queues.named import NamedQueue 17 | 18 | NUM_SUBMISSIONS = 3 19 | all_submissions = [] 20 | 21 | 22 | @pytest.fixture(scope='module') 23 | def fs(config): 24 | return forge.get_filestore(config) 25 | 26 | 27 | def recursive_extend(d, u): 28 | for k, v in u.items(): 29 | if isinstance(v, collections.abc.Mapping): 30 | d[k] = recursive_extend(d.get(k, {}), v) 31 | else: 32 | if k not in d: 33 | d[k] = [] 34 | d[k].extend(v) 35 | 36 | return d 37 | 38 | 39 | @pytest.fixture(scope="module") 40 | def datastore(request, datastore_connection, fs): 41 | for _ in range(NUM_SUBMISSIONS): 42 | all_submissions.append(create_submission(datastore_connection, fs)) 43 | 44 | try: 45 | yield datastore_connection 46 | finally: 47 | wipe_submissions(datastore_connection, fs) 48 | datastore_connection.alert.wipe() 49 | 50 | 51 | def test_create_single_alert(config, datastore): 52 | persistent_redis = get_client( 53 | host=config.core.redis.persistent.host, 54 | port=config.core.redis.persistent.port, 55 | private=False, 56 | ) 57 | alerter = Alerter() 58 | # Swap our alerter onto a private queue so our test doesn't get intercepted 59 | alerter.alert_queue = alert_queue = NamedQueue(uuid.uuid4().hex, persistent_redis) 60 | 61 | # Get a random submission 62 | submission = random.choice(all_submissions) 63 | all_submissions.remove(submission) 64 | 65 | # Generate a task for the submission 66 | ingest_msg = random_model_obj(IngestTask) 67 | ingest_msg.submission.sid = submission.sid 68 | ingest_msg.submission.metadata = submission.metadata 69 | ingest_msg.submission.params = submission.params 70 | ingest_msg.submission.files = submission.files 71 | 72 | alert_queue.push(ingest_msg.as_primitives()) 73 | alert_type = alerter.run_once() 74 | assert alert_type == 'create' 75 | datastore.alert.commit() 76 | 77 | res = datastore.alert.search("id:*", as_obj=False) 78 | assert res['total'] == 1 79 | 80 | alert = datastore.alert.get(res['items'][0]['alert_id']) 81 | assert alert.sid == submission.sid 82 | 83 | 84 | def test_update_single_alert(config, datastore, delete_original=False): 85 | persistent_redis = get_client( 86 | host=config.core.redis.persistent.host, 87 | port=config.core.redis.persistent.port, 88 | private=False, 89 | ) 90 | alerter = Alerter() 91 | # Swap our alerter onto a private queue so our test doesn't get intercepted 92 | alerter.alert_queue = alert_queue = NamedQueue(uuid.uuid4().hex, persistent_redis) 93 | 94 | # Get a random submission 95 | submission = random.choice(all_submissions) 96 | all_submissions.remove(submission) 97 | 98 | # Generate a task for the submission 99 | ingest_msg = random_model_obj(IngestTask) 100 | ingest_msg.submission.sid = submission.sid 101 | ingest_msg.submission.metadata = submission.metadata 102 | ingest_msg.submission.params = submission.params 103 | ingest_msg.submission.files = submission.files 104 | 105 | alert_queue.push(ingest_msg.as_primitives()) 106 | alert_type = alerter.run_once() 107 | assert alert_type == 'create' 108 | datastore.alert.commit() 109 | 110 | original_alert = datastore.alert.get(datastore.alert.search(f"sid:{submission.sid}", fl="id", 111 | as_obj=False)['items'][0]['id']) 112 | assert original_alert is not None 113 | 114 | # Generate a children task 115 | child_submission = Submission(submission.as_primitives()) 116 | child_submission.sid = get_random_id() 117 | child_submission.params.psid = submission.sid 118 | 119 | # Alter the result of one of the 
services 120 | r = None 121 | while r is None: 122 | r = datastore.result.get(random.choice(child_submission.results)) 123 | 124 | for s in r.result.sections: 125 | old_tags = s.tags.as_primitives(strip_null=True) 126 | s.tags = Tagging(recursive_extend(old_tags, get_random_tags())) 127 | 128 | datastore.result.save(r.build_key(), r) 129 | datastore.result.commit() 130 | 131 | datastore.submission.save(child_submission.sid, child_submission) 132 | datastore.submission.commit() 133 | 134 | child_ingest_msg = random_model_obj(IngestTask) 135 | child_ingest_msg.submission.sid = child_submission.sid 136 | child_ingest_msg.submission.metadata = child_submission.metadata 137 | child_ingest_msg.submission.params = child_submission.params 138 | child_ingest_msg.submission.files = child_submission.files 139 | child_ingest_msg.submission.time = ingest_msg.submission.time 140 | child_ingest_msg.ingest_id = ingest_msg.ingest_id 141 | 142 | alert_type_assertion = 'update' 143 | 144 | if delete_original: 145 | datastore.alert.delete(original_alert['alert_id']) 146 | alert_type_assertion = 'create' 147 | 148 | alert_queue.push(child_ingest_msg.as_primitives()) 149 | alert_type = alerter.run_once() 150 | assert alert_type == alert_type_assertion 151 | 152 | datastore.alert.commit() 153 | 154 | updated_alert = datastore.alert.get(datastore.alert.search(f"sid:{child_submission.sid}", 155 | fl="id", as_obj=False)['items'][0]['id']) 156 | assert updated_alert is not None 157 | 158 | assert updated_alert != original_alert 159 | 160 | 161 | def test_update_expired_alert(config, datastore): 162 | # If we're attempting to update an alert that has either expired or was removed from the 'alert' collection 163 | # Alert should be created in it's stead 164 | test_update_single_alert(config, datastore, delete_original=True) 165 | -------------------------------------------------------------------------------- /test/test_badlist_client.py: -------------------------------------------------------------------------------- 1 | 2 | import hashlib 3 | import random 4 | import time 5 | from copy import deepcopy 6 | 7 | import pytest 8 | from assemblyline.common.forge import get_classification 9 | from assemblyline.common.isotime import iso_to_epoch 10 | from assemblyline.odm.random_data import ( 11 | create_badlists, 12 | create_users, 13 | wipe_badlist, 14 | wipe_users, 15 | ) 16 | from assemblyline.odm.randomizer import get_random_hash 17 | from assemblyline_core.badlist_client import BadlistClient, InvalidBadhash 18 | 19 | CLASSIFICATION = get_classification() 20 | 21 | add_hash_file = "10" + get_random_hash(62) 22 | add_error_hash = "11" + get_random_hash(62) 23 | update_hash = "12" + get_random_hash(62) 24 | update_conflict_hash = "13" + get_random_hash(62) 25 | source_hash = "14" + get_random_hash(62) 26 | 27 | BAD_SOURCE = { 28 | "classification": CLASSIFICATION.UNRESTRICTED, 29 | "name": "BAD", 30 | "reason": [ 31 | "2nd stage for implant BAD", 32 | "Used by actor BLAH!" 
33 | ], 34 | "type": "external"} 35 | 36 | BAD2_SOURCE = { 37 | "classification": CLASSIFICATION.UNRESTRICTED, 38 | "name": "BAD2", 39 | "reason": [ 40 | "Use for phishing" 41 | ], 42 | "type": "external"} 43 | 44 | ADMIN_SOURCE = { 45 | "classification": CLASSIFICATION.UNRESTRICTED, 46 | "name": "admin", 47 | "reason": [ 48 | "It's denifitely bad", 49 | ], 50 | "type": "user"} 51 | 52 | USER_SOURCE = { 53 | "classification": CLASSIFICATION.UNRESTRICTED, 54 | "name": "user", 55 | "reason": [ 56 | "I just feel like it!", 57 | "I just feel like it!", 58 | ], 59 | "type": "user"} 60 | 61 | 62 | @pytest.fixture(scope="module") 63 | def client(datastore_connection): 64 | try: 65 | create_users(datastore_connection) 66 | create_badlists(datastore_connection) 67 | yield BadlistClient(datastore_connection) 68 | finally: 69 | wipe_users(datastore_connection) 70 | wipe_badlist(datastore_connection) 71 | 72 | 73 | # noinspection PyUnusedLocal 74 | def test_badlist_add_file(client): 75 | # Generate a random badlist 76 | sl_data = { 77 | 'attribution': None, 78 | 'hashes': {'md5': get_random_hash(32), 79 | 'sha1': get_random_hash(40), 80 | 'sha256': add_hash_file, 81 | 'ssdeep': None, 82 | 'tlsh': None}, 83 | 'file': {'name': ['file.txt'], 84 | 'size': random.randint(128, 4096), 85 | 'type': 'document/text'}, 86 | 'sources': [BAD_SOURCE, ADMIN_SOURCE], 87 | 'type': 'file' 88 | } 89 | sl_data_original = deepcopy(sl_data) 90 | 91 | # Insert it and test return value 92 | qhash, op = client.add_update(sl_data) 93 | assert qhash == add_hash_file 94 | assert op == 'add' 95 | 96 | # Load inserted data from DB 97 | ds_sl = client.datastore.badlist.get(add_hash_file, as_obj=False) 98 | 99 | # Test dates 100 | added = ds_sl.pop('added', None) 101 | updated = ds_sl.pop('updated', None) 102 | 103 | # File item will live forever 104 | assert ds_sl.pop('expiry_ts', None) is None 105 | 106 | assert added == updated 107 | assert added is not None and updated is not None 108 | 109 | # Make sure tag is none 110 | tag = ds_sl.pop('tag', None) 111 | assert tag is None 112 | 113 | # Test classification 114 | classification = ds_sl.pop('classification', None) 115 | assert classification is not None 116 | 117 | # Test enabled 118 | enabled = ds_sl.pop('enabled', None) 119 | assert enabled 120 | 121 | # Normalize classification in sources 122 | for source in ds_sl['sources']: 123 | source['classification'] = CLASSIFICATION.normalize_classification(source['classification']) 124 | 125 | # Test rest 126 | assert ds_sl == sl_data_original 127 | 128 | 129 | def test_badlist_add_tag(client): 130 | tag_type = 'network.static.ip' 131 | tag_value = '127.0.0.1' 132 | hashed_value = f"{tag_type}: {tag_value}".encode('utf8') 133 | expected_qhash = hashlib.sha256(hashed_value).hexdigest() 134 | 135 | # Generate a random badlist 136 | sl_data = { 137 | 'attribution': { 138 | 'actor': ["SOMEONE!"], 139 | 'campaign': None, 140 | 'category': None, 141 | 'exploit': None, 142 | 'implant': None, 143 | 'family': None, 144 | 'network': None 145 | }, 146 | 'dtl': 15, 147 | 'hashes': {'sha256': expected_qhash}, 148 | 'tag': {'type': tag_type, 149 | 'value': tag_value}, 150 | 'sources': [BAD_SOURCE, ADMIN_SOURCE], 151 | 'type': 'tag' 152 | } 153 | sl_data_original = deepcopy(sl_data) 154 | 155 | # Insert it and test return value 156 | qhash, op = client.add_update(sl_data) 157 | assert qhash == expected_qhash 158 | assert op == 'add' 159 | 160 | # Load inserted data from DB 161 | ds_sl = client.datastore.badlist.get(expected_qhash, as_obj=False) 162 
| 163 | # Test dates 164 | added = ds_sl.pop('added', None) 165 | updated = ds_sl.pop('updated', None) 166 | 167 | # Tag item will live up to a certain date 168 | assert ds_sl.pop('expiry_ts', None) is not None 169 | 170 | assert added == updated 171 | assert added is not None and updated is not None 172 | 173 | # Make sure file is None 174 | file = ds_sl.pop('file', {}) 175 | assert file is None 176 | 177 | # Test classification 178 | classification = ds_sl.pop('classification', None) 179 | assert classification is not None 180 | 181 | # Test enabled 182 | enabled = ds_sl.pop('enabled', None) 183 | assert enabled 184 | 185 | # Test rest, dtl should not exist anymore 186 | sl_data_original.pop('dtl', None) 187 | 188 | # Normalize classification in sources 189 | for source in ds_sl['sources']: 190 | source['classification'] = CLASSIFICATION.normalize_classification(source['classification']) 191 | 192 | for hashtype in ['md5', 'sha1', 'ssdeep', 'tlsh']: 193 | ds_sl['hashes'].pop(hashtype, None) 194 | 195 | # Test rest 196 | assert ds_sl == sl_data_original 197 | 198 | 199 | def test_badlist_add_invalid(client): 200 | # Generate a random badlist 201 | sl_data = { 202 | 'hashes': {'sha256': add_error_hash}, 203 | 'sources': [USER_SOURCE], 204 | 'type': 'file'} 205 | 206 | # Insert it and test return value 207 | with pytest.raises(ValueError) as conflict_exc: 208 | client.add_update(sl_data, user={"uname": "test"}) 209 | 210 | assert 'for another user' in conflict_exc.value.args[0] 211 | 212 | 213 | def test_badlist_update(client): 214 | # Generate a random badlist 215 | sl_data = { 216 | 'attribution': { 217 | 'actor': None, 218 | 'campaign': None, 219 | 'category': None, 220 | 'exploit': None, 221 | 'implant': ['BAD'], 222 | 'family': None, 223 | 'network': None}, 224 | 'hashes': {'md5': get_random_hash(32), 225 | 'sha1': get_random_hash(40), 226 | 'sha256': update_hash, 227 | 'ssdeep': None, 228 | 'tlsh': None}, 229 | 'file': {'name': [], 230 | 'size': random.randint(128, 4096), 231 | 'type': 'document/text'}, 232 | 'sources': [BAD_SOURCE], 233 | 'type': 'file' 234 | } 235 | sl_data_original = deepcopy(sl_data) 236 | 237 | # Insert it and test return value 238 | qhash, op = client.add_update(sl_data) 239 | assert qhash == update_hash 240 | assert op == 'add' 241 | 242 | # Load inserted data from DB 243 | ds_sl = client.datastore.badlist.get(update_hash, as_obj=False) 244 | 245 | # Normalize classification in sources 246 | for source in ds_sl['sources']: 247 | source['classification'] = CLASSIFICATION.normalize_classification(source['classification']) 248 | 249 | # Test rest 250 | assert {k: v for k, v in ds_sl.items() if k in sl_data_original} == sl_data_original 251 | 252 | u_data = { 253 | 'attribution': {'implant': ['TEST'], 'actor': ['TEST']}, 254 | 'hashes': {'sha256': update_hash, 'tlsh': 'faketlsh'}, 255 | 'sources': [USER_SOURCE], 256 | 'type': 'file' 257 | } 258 | 259 | # Insert it and test return value 260 | qhash, op = client.add_update(u_data) 261 | assert qhash == update_hash 262 | assert op == 'update' 263 | 264 | # Load inserted data from DB 265 | ds_u = client.datastore.badlist.get(update_hash, as_obj=False) 266 | 267 | # Normalize classification in sources 268 | for source in ds_u['sources']: 269 | source['classification'] = CLASSIFICATION.normalize_classification(source['classification']) 270 | 271 | assert ds_u['added'] == ds_sl['added'] 272 | assert iso_to_epoch(ds_u['updated']) > iso_to_epoch(ds_sl['updated']) 273 | assert len(ds_u['sources']) == 2 274 | assert 
USER_SOURCE in ds_u['sources'] 275 | assert BAD_SOURCE in ds_u['sources'] 276 | assert 'TEST' in ds_u['attribution']['implant'] 277 | assert 'BAD' in ds_u['attribution']['implant'] 278 | assert 'TEST' in ds_u['attribution']['actor'] 279 | assert 'faketlsh' in ds_u['hashes']['tlsh'] 280 | 281 | 282 | def test_badlist_update_conflict(client): 283 | # Generate a random badlist 284 | sl_data = {'hashes': {'sha256': update_conflict_hash}, 'file': {}, 'sources': [ADMIN_SOURCE], 'type': 'file'} 285 | 286 | # Insert it and test return value 287 | qhash, op = client.add_update(sl_data) 288 | assert qhash == update_conflict_hash 289 | assert op == 'add' 290 | 291 | # Insert the same source with a different type 292 | sl_data['sources'][0]['type'] = 'external' 293 | with pytest.raises(InvalidBadhash) as conflict_exc: 294 | client.add_update(sl_data) 295 | 296 | assert 'has a type conflict:' in conflict_exc.value.args[0] 297 | 298 | def test_badlist_tag_normalization(client): 299 | tag_type = 'network.static.uri' 300 | tag_value = 'https://BaD.com/About' 301 | 302 | normalized_value = 'https://bad.com/About' 303 | hashed_value = f"{tag_type}: {normalized_value}".encode('utf8') 304 | expected_qhash = hashlib.sha256(hashed_value).hexdigest() 305 | 306 | # Generate a random badlist 307 | sl_data = { 308 | 'attribution': { 309 | 'actor': ["SOMEONE!"], 310 | 'campaign': None, 311 | 'category': None, 312 | 'exploit': None, 313 | 'implant': None, 314 | 'family': None, 315 | 'network': None 316 | }, 317 | 'dtl': 15, 318 | 'tag': {'type': tag_type, 319 | 'value': tag_value}, 320 | 'sources': [BAD_SOURCE, ADMIN_SOURCE], 321 | 'type': 'tag' 322 | } 323 | 324 | client.add_update(sl_data) 325 | 326 | # Assert that item got created with the expected ID from the normalized tag value 327 | assert client.datastore.badlist.exists(expected_qhash) 328 | time.sleep(1) 329 | 330 | # Assert that the tag exists in either format (within reason) 331 | assert client.exists_tags({tag_type: [tag_value]}) 332 | assert client.exists_tags({tag_type: [normalized_value]}) 333 | -------------------------------------------------------------------------------- /test/test_expiry.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import random 4 | import concurrent.futures 5 | 6 | from assemblyline.common.isotime import now_as_iso 7 | from assemblyline.datastore.helper import AssemblylineDatastore 8 | from assemblyline.odm.randomizer import random_model_obj 9 | 10 | from assemblyline_core.expiry.run_expiry import ExpiryManager 11 | 12 | MAX_OBJECTS = 10 13 | MIN_OBJECTS = 2 14 | expiry_collections_len = {} 15 | archive_collections_len = {} 16 | 17 | 18 | def purge_data(datastore_connection: AssemblylineDatastore): 19 | for name, definition in datastore_connection.ds.get_models().items(): 20 | if hasattr(definition, 'expiry_ts'): 21 | getattr(datastore_connection, name).wipe() 22 | 23 | 24 | @pytest.fixture(scope="function") 25 | def ds_expiry(request, datastore_connection): 26 | for name, definition in datastore_connection.ds.get_models().items(): 27 | if hasattr(definition, 'expiry_ts'): 28 | collection = getattr(datastore_connection, name) 29 | collection.wipe() 30 | expiry_len = random.randint(MIN_OBJECTS, MAX_OBJECTS) 31 | for x in range(expiry_len): 32 | obj = random_model_obj(collection.model_class) 33 | if hasattr(definition, 'from_archive'): 34 | obj.from_archive = False 35 | obj.expiry_ts = now_as_iso(-10000) 36 | collection.save('longer_name'+str(x), obj) 37 | 38 | 
expiry_collections_len[name] = expiry_len 39 | collection.commit() 40 | 41 | request.addfinalizer(lambda: purge_data(datastore_connection)) 42 | return datastore_connection 43 | 44 | 45 | class FakeCounter(object): 46 | def __init__(self): 47 | self.counts = {} 48 | 49 | def increment(self, name, increment_by=1): 50 | if name not in self.counts: 51 | self.counts[name] = 0 52 | 53 | self.counts[name] += increment_by 54 | 55 | def get(self, name): 56 | return self.counts.get(name, 0) 57 | 58 | 59 | def test_expire_all(config, ds_expiry, filestore): 60 | expiry = ExpiryManager(config=config, datastore=ds_expiry, filestore=filestore) 61 | expiry.running = True 62 | expiry.counter = FakeCounter() 63 | with concurrent.futures.ThreadPoolExecutor(5) as pool: 64 | for collection in expiry.expirable_collections: 65 | expiry.feed_expiry_jobs(collection=collection, pool=pool, start='*', jobs=[]) 66 | 67 | for k, v in expiry_collections_len.items(): 68 | assert v == expiry.counter.get(k) 69 | collection = getattr(ds_expiry, k) 70 | collection.commit() 71 | assert collection.search("id:*")['total'] == 0 72 | -------------------------------------------------------------------------------- /test/test_plumber.py: -------------------------------------------------------------------------------- 1 | from time import sleep 2 | from unittest import mock 3 | 4 | from assemblyline_core.plumber.run_plumber import Plumber 5 | from assemblyline_core.server_base import ServiceStage 6 | from mocking import TrueCountTimes 7 | from redis import Redis 8 | 9 | from assemblyline.odm.messages.task import Task 10 | from assemblyline.odm.models.service import Service 11 | from assemblyline.odm.models.user import ApiKey, User 12 | from assemblyline.odm.random_data import random_model_obj 13 | 14 | 15 | def test_expire_missing_service(): 16 | redis = mock.MagicMock(spec=Redis) 17 | redis.keys.return_value = [b'service-queue-not-service-a'] 18 | redis.zcard.return_value = 0 19 | redis_persist = mock.MagicMock(spec=Redis) 20 | datastore = mock.MagicMock() 21 | 22 | service_a = random_model_obj(Service) 23 | service_a.name = 'a' 24 | service_a.enabled = True 25 | 26 | datastore.list_all_services.return_value = [service_a] 27 | datastore.ds.ca_certs = None 28 | datastore.ds.get_hosts.return_value = ["http://localhost:9200"] 29 | 30 | plumber = Plumber(redis=redis, redis_persist=redis_persist, datastore=datastore, delay=1) 31 | plumber.get_service_stage = mock.MagicMock(return_value=ServiceStage.Running) 32 | plumber.dispatch_client = mock.MagicMock() 33 | 34 | task = random_model_obj(Task) 35 | plumber.dispatch_client.request_work.side_effect = [task, None, None] 36 | 37 | plumber.running = TrueCountTimes(count=1) 38 | plumber.service_queue_plumbing() 39 | 40 | assert plumber.dispatch_client.service_failed.call_count == 1 41 | args = plumber.dispatch_client.service_failed.call_args 42 | assert args[0][0] == task.sid 43 | 44 | 45 | def test_flush_paused_queues(): 46 | redis = mock.MagicMock(spec=Redis) 47 | redis.keys.return_value = [b'service-queue-a'] 48 | redis.zcard.return_value = 0 49 | redis_persist = mock.MagicMock(spec=Redis) 50 | datastore = mock.MagicMock() 51 | 52 | service_a = random_model_obj(Service) 53 | service_a.name = 'a' 54 | service_a.enabled = True 55 | 56 | datastore.list_all_services.return_value = [service_a] 57 | datastore.ds.ca_certs = None 58 | datastore.ds.get_hosts.return_value = ["http://localhost:9200"] 59 | 60 | plumber = Plumber(redis=redis, redis_persist=redis_persist, datastore=datastore, delay=1) 
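    # While the service stage reads as Running, the queued task should be left
    # alone (no service_failed call); once the stage is flipped to Paused below,
    # the same queue is drained and the task is reported failed through the
    # dispatch client.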
61 | plumber.get_service_stage = mock.MagicMock(return_value=ServiceStage.Running) 62 | plumber.dispatch_client = mock.MagicMock() 63 | 64 | task = random_model_obj(Task) 65 | plumber.dispatch_client.request_work.side_effect = [task, None, None] 66 | 67 | plumber.running = TrueCountTimes(count=1) 68 | plumber.service_queue_plumbing() 69 | 70 | assert plumber.dispatch_client.service_failed.call_count == 0 71 | 72 | plumber.get_service_stage = mock.MagicMock(return_value=ServiceStage.Paused) 73 | 74 | plumber.running = TrueCountTimes(count=1) 75 | plumber.service_queue_plumbing() 76 | 77 | assert plumber.dispatch_client.service_failed.call_count == 1 78 | args = plumber.dispatch_client.service_failed.call_args 79 | assert args[0][0] == task.sid 80 | 81 | 82 | def test_cleanup_old_tasks(datastore_connection): 83 | # Create a bunch of random "old" tasks and clean them up 84 | redis = mock.MagicMock(spec=Redis) 85 | redis_persist = mock.MagicMock(spec=Redis) 86 | plumber = Plumber(redis=redis, redis_persist=redis_persist, datastore=datastore_connection, delay=1) 87 | 88 | # Generate new documents in .tasks index 89 | num_old_tasks = 10 90 | [plumber.datastore.ds.client.index(index=".tasks", document={ 91 | "completed": True, 92 | "task": { 93 | "start_time_in_millis": 0 94 | } 95 | }) for _ in range(num_old_tasks)] 96 | sleep(1) 97 | 98 | # Assert that these have been indeed committed to the tasks index 99 | assert plumber.datastore.ds.client.search(index='.tasks', 100 | q="task.start_time_in_millis:0", 101 | track_total_hits=True, 102 | size=0)['hits']['total']['value'] == num_old_tasks 103 | 104 | # Run task cleanup, we should return to no more "old" completed tasks 105 | plumber.running = TrueCountTimes(count=1) 106 | plumber.cleanup_old_tasks() 107 | sleep(1) 108 | assert plumber.datastore.ds.client.search(index='.tasks', 109 | q="task.start_time_in_millis:0", 110 | track_total_hits=True, 111 | size=0)['hits']['total']['value'] == 0 112 | 113 | def test_user_setting_migrations(datastore_connection): 114 | from assemblyline.odm.models.config import SubmissionProfileParams 115 | 116 | SubmissionProfileParams.fields().keys() 117 | # Create a bunch of random "old" tasks and clean them up 118 | redis = mock.MagicMock(spec=Redis) 119 | redis_persist = mock.MagicMock(spec=Redis) 120 | plumber = Plumber(redis=redis, redis_persist=redis_persist, datastore=datastore_connection, delay=1) 121 | 122 | # Create a user with old settings (format prior to 4.6) 123 | settings = {'classification': 'TLP:CLEAR', 'deep_scan': False, 'description': '', 'download_encoding': 'cart', 'default_external_sources': ['Malware Bazaar', 'VirusTotal'], 'default_zip_password': 'zippy', 'executive_summary': False, 'expand_min_score': 500, 'generate_alert': False, 'ignore_cache': False, 'ignore_dynamic_recursion_prevention': False, 'ignore_recursion_prevention': False, 'ignore_filtering': False, 'malicious': False, 'priority': 369, 'profile': False, 'service_spec': {'AVClass': {'include_malpedia_dataset': False}}, 'services': {'selected': ['Extraction', 'ConfigExtractor', 'YARA'], 'excluded': [], 'rescan': [], 'resubmit': [], 'runtime_excluded': []}, 'submission_view': 'report', 'ttl': 0} 124 | 125 | user_account = random_model_obj(User, as_json=True) 126 | user_account['uname'] = "admin" 127 | user_account['apikeys'] = {'test': random_model_obj(ApiKey, as_json=True)} 128 | datastore_connection.ds.client.index(index="user_settings", id="admin", document=settings) 129 | datastore_connection.ds.client.index(index="user", 
id="admin", document=user_account) 130 | 131 | datastore_connection.user_settings.commit() 132 | datastore_connection.user.commit() 133 | 134 | # Initiate the migration 135 | plumber.user_apikey_cleanup() 136 | plumber.migrate_user_settings() 137 | 138 | # Check that the settings have been migrated 139 | migrated_settings = datastore_connection.user_settings.get("admin", as_obj=False) 140 | 141 | # Check to see if API keys for the user were transferred to the new index 142 | assert datastore_connection.apikey.search('uname:admin', rows=0)['total'] > 0 143 | 144 | # Deprecated settings should be removed 145 | assert "ignore_dynamic_recursion_prevention" not in migrated_settings 146 | 147 | # All former submission settings at the root-level should be moved to submission profiles 148 | assert all([key not in migrated_settings for key in SubmissionProfileParams.fields().keys()] ) 149 | 150 | for settings in migrated_settings['submission_profiles'].values(): 151 | assert settings['classification'] == 'TLP:C' 152 | assert settings['deep_scan'] is False 153 | assert settings['generate_alert'] is False 154 | assert settings['ignore_cache'] is False 155 | assert settings['priority'] == 369 156 | # Full service spec should be preserved in default profile (along with others by default if there's no restricted parameters) 157 | assert settings['service_spec'] == {'AVClass': {'include_malpedia_dataset': False}} 158 | assert settings['ttl'] == 0 159 | -------------------------------------------------------------------------------- /test/test_scaler.py: -------------------------------------------------------------------------------- 1 | from assemblyline.odm.models.service import DockerConfig 2 | from pytest import approx 3 | from unittest.mock import Mock, patch 4 | from assemblyline_core.scaler.collection import Collection 5 | from assemblyline_core.scaler.scaler_server import ServiceProfile 6 | 7 | mock_time = Mock() 8 | 9 | 10 | @patch('time.time', mock_time) 11 | def test_collection(): 12 | mock_time.return_value = 0 13 | collection = Collection(60, ttl=61) 14 | 15 | # Insert some sample data 16 | collection.update('service-a', 'host-a', 30, 1) 17 | collection.update('service-b', 'host-a', 60, 1) 18 | mock_time.return_value = 30 19 | collection.update('service-a', 'host-b', 30, 1) 20 | 21 | assert collection.read('service-c') is None 22 | assert collection.read('service-a')['instances'] == 2 23 | assert collection.read('service-a')['duty_cycle'] == approx(0.5) 24 | 25 | assert collection.read('service-b')['instances'] == 1 26 | assert collection.read('service-b')['duty_cycle'] == approx(1) 27 | 28 | # Move forward enough that the first two messages expire, send another message from the second 29 | # service, now both should have one active message/host 30 | mock_time.return_value = 62 31 | collection.update('service-b', 'host-a', 30, 1) 32 | 33 | assert collection.read('service-a')['instances'] == 1 34 | assert collection.read('service-a')['duty_cycle'] == approx(0.5) 35 | 36 | assert collection.read('service-b')['instances'] == 1 37 | assert collection.read('service-b')['duty_cycle'] == approx(0.5) 38 | 39 | # Move forward that the last of the original group of messages expire, but the update for the second 40 | # service is still in effect 41 | mock_time.return_value = 100 42 | 43 | assert collection.read('service-a') is None 44 | assert collection.read('service-b')['instances'] == 1 45 | assert collection.read('service-b')['duty_cycle'] == approx(0.5) 46 | 47 | 48 | def 
test_default_bucket_rates(): 49 | service = ServiceProfile('service', DockerConfig(dict(image='redis'))) 50 | before = service.pressure 51 | service.update(5, 1, 1, 1) 52 | print(service.pressure, before) 53 | assert service.pressure > before 54 | -------------------------------------------------------------------------------- /test/test_scheduler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from assemblyline.odm.models.submission import Submission 4 | from assemblyline.odm.models.config import Config, DEFAULT_CONFIG 5 | from assemblyline.odm.models.service import Service 6 | from assemblyline.odm.randomizer import random_model_obj 7 | 8 | from assemblyline_core.dispatching.dispatcher import Scheduler 9 | from assemblyline_core.server_base import get_service_stage_hash, ServiceStage 10 | 11 | 12 | @pytest.fixture(scope='module') 13 | def redis(redis_connection): 14 | redis_connection.flushdb() 15 | yield redis_connection 16 | redis_connection.flushdb() 17 | 18 | 19 | def dummy_service(name, stage, category='static', accepts='', rejects=None, docid=None, extra_data=False, monitored_keys=()): 20 | return Service({ 21 | 'name': name, 22 | 'stage': stage, 23 | 'category': category, 24 | 'accepts': accepts, 25 | 'uses_temp_submission_data': extra_data, 26 | 'uses_tags': extra_data, 27 | 'rejects': rejects, 28 | 'version': '0', 29 | 'enabled': True, 30 | 'timeout': 2, 31 | 'monitored_keys': list(monitored_keys), 32 | 'docker_config': { 33 | 'image': 'somefakedockerimage:latest' 34 | } 35 | }, docid=docid) 36 | 37 | 38 | # noinspection PyUnusedLocal,PyMethodMayBeStatic 39 | class FakeDatastore: 40 | def __init__(self): 41 | self.service = self 42 | 43 | def stream_search(self, *args, **kwargs): 44 | return [] 45 | 46 | def list_all_services(self, full=True): 47 | return { 48 | 'extract': dummy_service( 49 | name='extract', 50 | stage='pre', 51 | accepts='archive/.*', 52 | ), 53 | 'AnAV': dummy_service( 54 | name='AnAV', 55 | stage='core', 56 | category='av', 57 | accepts='.*', 58 | ), 59 | 'cuckoo': dummy_service( 60 | name='cuckoo', 61 | stage='core', 62 | category='dynamic', 63 | accepts='document/.*|executable/.*', 64 | ), 65 | 'polish': dummy_service( 66 | name='polish', 67 | stage='post', 68 | category='static', 69 | accepts='.*', 70 | ), 71 | 'not_documents': dummy_service( 72 | name='not_documents', 73 | stage='post', 74 | category='static', 75 | accepts='.*', 76 | rejects='document/*', 77 | ), 78 | 'Safelist': dummy_service( 79 | name='Safelist', 80 | stage='pre', 81 | category='static', 82 | accepts='.*', 83 | ) 84 | }.values() 85 | 86 | 87 | def submission(selected, excluded): 88 | sub = random_model_obj(Submission) 89 | sub.params.services.selected = selected 90 | sub.params.services.excluded = excluded 91 | return sub 92 | 93 | 94 | @pytest.fixture 95 | def scheduler(redis): 96 | config = Config(DEFAULT_CONFIG) 97 | config.services.stages = ['pre', 'core', 'post'] 98 | stages = get_service_stage_hash(redis) 99 | ds = FakeDatastore() 100 | for service in ds.list_all_services(): 101 | stages.set(service.name, ServiceStage.Running) 102 | return Scheduler(ds, config, redis) 103 | 104 | 105 | def test_schedule_simple(scheduler): 106 | schedule = scheduler.build_schedule(submission(['static', 'av'], ['dynamic']), 'document/word') 107 | for a, b in zip(schedule, [['Safelist'], ['AnAV'], ['polish']]): 108 | assert set(a) == set(b) 109 | 110 | 111 | def test_schedule_no_excludes(scheduler): 112 | schedule = 
scheduler.build_schedule(submission(['static', 'av', 'dynamic'], []), 'document/word') 113 | assert all(set(a) == set(b) for a, b in zip(schedule, [['Safelist'], ['AnAV', 'cuckoo'], ['polish']])) 114 | 115 | 116 | def test_schedule_all_defaults_word(scheduler): 117 | schedule = scheduler.build_schedule(submission([], []), 'document/word') 118 | assert all(set(a) == set(b) for a, b in zip(schedule, [['Safelist'], ['AnAV', 'cuckoo'], ['polish']])) 119 | 120 | 121 | def test_schedule_all_defaults_zip(scheduler): 122 | schedule = scheduler.build_schedule(submission([], []), 'archive/zip') 123 | assert all(set(a) == set(b) 124 | for a, b in zip(schedule, [['extract', 'Safelist'], ['AnAV'], ['polish', 'not_documents']])) 125 | 126 | 127 | def test_schedule_service_safelist(scheduler): 128 | # Safelist service should still be scheduled 129 | schedule = scheduler.build_schedule(submission(["Safelist"], []), 'document/word', file_depth=0) 130 | for a, b in zip(schedule, [["Safelist"], [], []]): 131 | assert set(a) == set(b) 132 | 133 | # Safelist service should NOT still be scheduled because we're not enforcing Safelist service by default 134 | # and deep_scan and ignore_filtering is OFF for this submission 135 | sub = submission(["Safelist"], []) 136 | sub.params.deep_scan = False 137 | sub.params.ignore_filtering = False 138 | schedule = scheduler.build_schedule(sub, 'document/word', file_depth=1) 139 | for a, b in zip(schedule, [[], [], []]): 140 | assert set(a) == set(b) 141 | 142 | # Safelist service should be scheduled because we're enabling deep_scan 143 | sub.params.deep_scan = True 144 | sub.params.ignore_filtering = False 145 | schedule = scheduler.build_schedule(sub, 'document/word', file_depth=1) 146 | for a, b in zip(schedule, [["Safelist"], [], []]): 147 | assert set(a) == set(b) 148 | 149 | # Safelist service should be scheduled because we're enabling ignore_filtering 150 | sub.params.deep_scan = False 151 | sub.params.ignore_filtering = True 152 | schedule = scheduler.build_schedule(sub, 'document/word', file_depth=1) 153 | for a, b in zip(schedule, [["Safelist"], [], []]): 154 | assert set(a) == set(b) 155 | -------------------------------------------------------------------------------- /test/test_signature_client.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import pytest 4 | 5 | from assemblyline.odm.models.signature import Signature 6 | from assemblyline.odm.randomizer import random_model_obj 7 | from assemblyline.odm.random_data import create_signatures, wipe_signatures, create_users, wipe_users 8 | from assemblyline_core.signature_client import SignatureClient 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def client(datastore_connection): 13 | try: 14 | create_users(datastore_connection) 15 | create_signatures(datastore_connection) 16 | yield SignatureClient(datastore_connection) 17 | finally: 18 | wipe_users(datastore_connection) 19 | wipe_signatures(datastore_connection) 20 | 21 | 22 | # noinspection PyUnusedLocal 23 | def test_add_update_signature(client): 24 | # Insert a dummy signature 25 | data = random_model_obj(Signature).as_primitives() 26 | data['status'] = "DEPLOYED" 27 | expected_key = f'{data["type"]}_{data["source"]}_{data["signature_id"]}' 28 | success, key, _ = client.add_update(data) 29 | assert success 30 | assert key == expected_key 31 | 32 | # Test the signature data 33 | client.datastore.signature.commit() 34 | added_sig = client.datastore.signature.get(key, as_obj=False) 35 
| assert data == added_sig 36 | 37 | # Change the signature status as a user 38 | success, _ = client.change_status(key, "DISABLED", client.datastore.user.get('user', as_obj=False)) 39 | assert success 40 | 41 | # Update signature data as an internal component 42 | new_sig_data = "NEW SIGNATURE DATA" 43 | data['data'] = new_sig_data 44 | success, key, _ = client.add_update(data) 45 | assert success 46 | assert expected_key == key 47 | modded_sig = client.datastore.signature.get(key, as_obj=False) 48 | assert modded_sig["data"] == new_sig_data 49 | # Was state kept from user setting? 50 | assert "DISABLED" == modded_sig.pop('status') 51 | 52 | 53 | # noinspection PyUnusedLocal 54 | def test_add_update_signature_many(client): 55 | 56 | # Insert a dummy signature 57 | source = "source" 58 | s_type = "type" 59 | sig_list = [] 60 | for x in range(10): 61 | data = random_model_obj(Signature).as_primitives() 62 | data['signature_id'] = f"test_sig_{x}" 63 | data['name'] = f"sig_name_{x}" 64 | data['status'] = "DEPLOYED" 65 | data['source'] = source 66 | data['type'] = s_type 67 | sig_list.append(data) 68 | 69 | assert {'errors': False, 'success': 10, 'skipped': []} == client.add_update_many(source, s_type, sig_list) 70 | 71 | # Test the signature data 72 | client.datastore.signature.commit() 73 | data = random.choice(sig_list) 74 | key = f"{data['type']}_{data['source']}_{data['signature_id']}" 75 | added_sig = client.datastore.signature.get(key, as_obj=False) 76 | assert data == added_sig 77 | 78 | # Change the signature status 79 | success, _ = client.change_status(key, "DISABLED", client.datastore.user.get('user', as_obj=False)) 80 | assert success 81 | 82 | # Update signature data 83 | new_sig_data = "NEW SIGNATURE DATA" 84 | data['data'] = new_sig_data 85 | assert {'errors': False, 'success': 1, 'skipped': []} == client.add_update_many(source, s_type, [data]) 86 | 87 | # Test the signature data 88 | modded_sig = client.datastore.signature.get(key, as_obj=False) 89 | assert modded_sig["data"] == new_sig_data 90 | # Was state kept? 
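    # (add_update_many should merge the new signature data without resetting the
    # DISABLED status that the user set through change_status above)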
91 | assert "DISABLED" == modded_sig.pop('status') 92 | 93 | 94 | # noinspection PyUnusedLocal 95 | def test_download_signatures(client): 96 | resp = client.download() 97 | assert resp.startswith(b"PK") 98 | assert b"YAR_SAMPLE" in resp 99 | assert b"ET_SAMPLE" in resp 100 | 101 | 102 | # noinspection PyUnusedLocal 103 | def test_update_available(client): 104 | assert client.update_available() 105 | assert not client.update_available(since='2030-01-01T00:00:00.000000Z') 106 | 107 | def test_update_classification(client): 108 | sig = client.datastore.signature.search("*", rows=1, as_obj=False)['items'][0] 109 | 110 | # Update classification with literal string 111 | client.classification_replace_map = {"TLP:C": "TLP:A//TEST"} 112 | client._update_classification(sig) 113 | assert sig['classification'] == "TLP:A//TEST" 114 | 115 | # Update classification with value from another field within the signature 116 | client.classification_replace_map = {"TEST": "_source"} 117 | client._update_classification(sig) 118 | assert sig['classification'] == f"TLP:A//{sig['source']}" 119 | -------------------------------------------------------------------------------- /test/test_tasking_client.py: -------------------------------------------------------------------------------- 1 | from assemblyline_core.tasking_client import TaskingClient 2 | 3 | from assemblyline.odm.models.service import Service 4 | from assemblyline.odm.models.heuristic import Heuristic 5 | from assemblyline.odm.models.result import Result, Section, Heuristic as SectionHeuristic 6 | 7 | from assemblyline.odm.randomizer import random_minimal_obj 8 | 9 | def test_register_service(datastore_connection): 10 | client = TaskingClient(datastore_connection, register_only=True) 11 | 12 | # Test service registration 13 | service = random_minimal_obj(Service).as_primitives() 14 | heuristics = [random_minimal_obj(Heuristic).as_primitives() for _ in range(2)] 15 | service['heuristics'] = heuristics 16 | assert client.register_service(service) 17 | assert all([datastore_connection.heuristic.exists(h['heur_id']) for h in heuristics]) 18 | 19 | # Test registration with heuristics that were removed but still have related results 20 | heuristic = heuristics.pop(0) 21 | result = random_minimal_obj(Result) 22 | section = random_minimal_obj(Section) 23 | section.heuristic = SectionHeuristic(heuristic) 24 | result.result.sections = [section] 25 | datastore_connection.result.save('test_result', result) 26 | datastore_connection.result.commit() 27 | 28 | # Heuristics that were removed should still reside in the system if there are still associated data to it 29 | service['heuristics'] = heuristics 30 | assert client.register_service(service) 31 | assert datastore_connection.heuristic.exists(heuristic['heur_id']) 32 | 33 | # Test registration with removed heuristics that have no related results 34 | datastore_connection.result.delete('test_result') 35 | datastore_connection.result.commit() 36 | assert client.register_service(service) 37 | assert not datastore_connection.heuristic.exists(heuristic['heur_id']) 38 | -------------------------------------------------------------------------------- /test/test_vacuum.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import os 3 | import os.path 4 | import pathlib 5 | import uuid 6 | import json 7 | import hashlib 8 | import time 9 | import random 10 | import threading 11 | 12 | from assemblyline.odm.models.config import Config, MetadataConfig 13 | from 
assemblyline.remote.datatypes.queues.named import NamedQueue 14 | from assemblyline_core.vacuum import crawler, worker 15 | from assemblyline_core.ingester.constants import INGEST_QUEUE_NAME 16 | 17 | 18 | def test_crawler(config: Config, redis_connection): 19 | try: 20 | with tempfile.TemporaryDirectory() as workdir: 21 | # Configure an environment 22 | data_dir = os.path.join(workdir, 'meta') 23 | os.mkdir(data_dir) 24 | 25 | config.core.vacuum.data_directories = [data_dir] 26 | config.core.vacuum.list_cache_directory = workdir 27 | config.core.vacuum.worker_cache_directory = workdir 28 | 29 | # Put some files named the right thing 30 | file_names = [uuid.uuid4().hex + '.meta' for _ in range(100)] 31 | for name in file_names: 32 | pathlib.Path(data_dir, name).touch() 33 | 34 | # Kick off a crawler thread 35 | threading.Thread(target=crawler.run, args=[config, redis_connection], daemon=True).start() 36 | 37 | # Check that they have been picked up 38 | queue = NamedQueue(crawler.VACUUM_BUFFER_NAME, redis_connection) 39 | while file_names: 40 | path = queue.pop(timeout=1) 41 | if path is None: 42 | assert False 43 | file_names = [f for f in file_names if not path.endswith(f)] 44 | finally: 45 | # shut down the crawler 46 | crawler.stop_event.set() 47 | 48 | 49 | def test_worker(config: Config, redis_connection): 50 | try: 51 | with tempfile.TemporaryDirectory() as workdir: 52 | # Configure an environment 53 | data_dir = os.path.join(workdir, 'meta') 54 | file_dir = os.path.join(workdir, 'files') 55 | os.mkdir(data_dir) 56 | os.mkdir(file_dir) 57 | 58 | config.core.vacuum.data_directories = [data_dir] 59 | config.core.vacuum.file_directories = [file_dir] 60 | config.core.vacuum.list_cache_directory = workdir 61 | config.core.vacuum.worker_cache_directory = workdir 62 | config.core.vacuum.assemblyline_user = 'service-account' 63 | config.core.vacuum.safelist = [{ 64 | 'name': 'good_streams', 65 | 'conditions': { 66 | 'stream': '10+' 67 | } 68 | }] 69 | config.core.vacuum.worker_threads = 1 70 | 71 | # Apply strict metadata validation 72 | config.submission.metadata.ingest = { 73 | config.core.vacuum.ingest_type: { 74 | 'stream': { 75 | 'validator_type': 'integer', 76 | 'required': True 77 | } 78 | } 79 | } 80 | config.submission.metadata.strict_schemes = [config.core.vacuum.ingest_type] 81 | 82 | # Place a file that will be safe 83 | first_file = os.path.join(data_dir, uuid.uuid4().hex + '.meta') 84 | with open(first_file, 'w') as temp: 85 | temp.write(json.dumps({ 86 | 'sha256': '0'*64, 87 | 'metadata': {'stream': '100'} 88 | })) 89 | 90 | # Place a file that will be ingested 91 | test_file = random.randbytes(100) 92 | sha256 = hashlib.sha256(test_file).hexdigest() 93 | name = os.path.join(file_dir, sha256[0], sha256[1], sha256[2], sha256[3]) 94 | os.makedirs(name) 95 | with open(os.path.join(name, sha256), 'bw') as temp: 96 | temp.write(test_file) 97 | 98 | second_file = os.path.join(data_dir, uuid.uuid4().hex + '.meta') 99 | with open(second_file, 'w') as temp: 100 | temp.write(json.dumps({ 101 | 'sha256': sha256, 102 | 'metadata': {'stream': '99'} 103 | })) 104 | 105 | # Start the worker 106 | threading.Thread(target=worker.run, args=[config, redis_connection, redis_connection], daemon=True).start() 107 | 108 | # Tell the worker about the files 109 | queue = NamedQueue(crawler.VACUUM_BUFFER_NAME, redis_connection) 110 | queue.push(first_file) 111 | queue.push(second_file) 112 | 113 | # Get a message for the ingested file 114 | ingest_queue = NamedQueue(INGEST_QUEUE_NAME,
redis_connection) 115 | ingested = ingest_queue.pop(timeout=20) 116 | assert ingested is not None 117 | assert ingested['files'][0]['sha256'] == sha256 118 | assert ingest_queue.length() == 0 119 | 120 | # Make sure all the meta files have been consumed 121 | rounds = 0 122 | while (os.path.exists(first_file) or os.path.exists(second_file)) and rounds < 10: 123 | time.sleep(0.1) 124 | rounds += 1 125 | assert not os.path.exists(first_file) 126 | assert not os.path.exists(second_file) 127 | 128 | finally: 129 | worker.stop_event.set() 130 | -------------------------------------------------------------------------------- /test/test_worker_ingest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest import mock 3 | import time 4 | 5 | from assemblyline.datastore.helper import AssemblylineDatastore 6 | from assemblyline.odm.models.user import User 7 | from assemblyline.odm.models.file import File 8 | from assemblyline.odm.randomizer import random_minimal_obj 9 | 10 | from assemblyline_core.ingester.ingester import IngestTask, _notification_queue_prefix, Ingester 11 | 12 | from mocking import TrueCountTimes 13 | 14 | 15 | def make_message(message=None, files=None, params=None): 16 | """A helper function to fill in some fields that are largely invariant across tests.""" 17 | send = dict( 18 | # describe the file being ingested 19 | files=[{ 20 | 'sha256': '0'*64, 21 | 'size': 100, 22 | 'name': 'abc' 23 | }], 24 | metadata={}, 25 | 26 | # Information about who wants this file ingested 27 | params={ 28 | 'description': 'file abc', 29 | 'submitter': 'user', 30 | 'groups': ['users'], 31 | } 32 | ) 33 | send.update(**(message or {})) 34 | send['files'][0].update(files or {}) 35 | send['params'].update(**(params or {})) 36 | return send 37 | 38 | 39 | @pytest.fixture 40 | def ingest_harness(clean_redis, clean_datastore: AssemblylineDatastore): 41 | datastore = clean_datastore 42 | ingester = Ingester(datastore=datastore, redis=clean_redis, persistent_redis=clean_redis) 43 | ingester.running = TrueCountTimes(1) 44 | ingester.counter.increment = mock.MagicMock() 45 | ingester.submit_client.submit = mock.MagicMock() 46 | return datastore, ingester, ingester.ingest_queue 47 | 48 | 49 | def test_ingest_simple(ingest_harness): 50 | datastore, ingester, in_queue = ingest_harness 51 | 52 | user = random_minimal_obj(User) 53 | user.name = 'user' 54 | custom_user_groups = ['users', 'the_user'] 55 | user.groups = list(custom_user_groups) 56 | datastore.user.save('user', user) 57 | 58 | # Send a message with a garbled sha, this should be dropped 59 | in_queue.push(make_message( 60 | files={'sha256': '1'*10} 61 | )) 62 | 63 | # Process garbled message 64 | ingester.handle_ingest() 65 | ingester.counter.increment.assert_called_with('error') 66 | 67 | # Send a message that is fine, but has an illegal metadata field 68 | in_queue.push(make_message(dict( 69 | metadata={ 70 | 'tobig': 'a' * (ingester.config.submission.max_metadata_length + 2), 71 | 'small': '100' 72 | } 73 | ), params={'submitter': 'user', 'groups': custom_user_groups})) 74 | 75 | # Process the ok message 76 | ingester.running.counter = 1 77 | ingester.handle_ingest() 78 | 79 | # The only task that makes it through should fit these parameters 80 | task = ingester.unique_queue.pop() 81 | assert task 82 | task = IngestTask(task) 83 | assert task.submission.files[0].sha256 == '0' * 64 # Only the valid sha passed through 84 | assert 'tobig' not in task.submission.metadata # The bad metadata
was stripped 85 | assert task.submission.metadata['small'] == '100' # The valid metadata is unchanged 86 | assert task.submission.params.submitter == 'user' 87 | assert task.submission.params.groups == custom_user_groups 88 | 89 | # None of the other tasks should reach the end 90 | assert ingester.unique_queue.length() == 0 91 | assert ingester.ingest_queue.length() == 0 92 | 93 | 94 | def test_ingest_stale_score_exists(ingest_harness): 95 | datastore, ingester, in_queue = ingest_harness 96 | get_if_exists = datastore.filescore.get_if_exists 97 | try: 98 | # Add a stale file score to the database for every file always 99 | from assemblyline.odm.models.filescore import FileScore 100 | datastore.filescore.get_if_exists = mock.MagicMock( 101 | return_value=FileScore(dict(psid='000', expiry_ts=0, errors=0, score=10, sid='000', time=0)) 102 | ) 103 | 104 | # Process a message that hits the stale score 105 | in_queue.push(make_message()) 106 | ingester.handle_ingest() 107 | 108 | # The stale filescore was retrieved 109 | datastore.filescore.get_if_exists.assert_called_once() 110 | 111 | # but message was ingested as a cache miss 112 | task = ingester.unique_queue.pop() 113 | assert task 114 | task = IngestTask(task) 115 | assert task.submission.files[0].sha256 == '0' * 64 116 | 117 | assert ingester.unique_queue.length() == 0 118 | assert ingester.ingest_queue.length() == 0 119 | finally: 120 | datastore.filescore.get_if_exists = get_if_exists 121 | 122 | 123 | def test_ingest_score_exists(ingest_harness): 124 | datastore, ingester, in_queue = ingest_harness 125 | get_if_exists = datastore.filescore.get_if_exists 126 | try: 127 | # Add a valid file score for all files 128 | from assemblyline.odm.models.filescore import FileScore 129 | datastore.filescore.get_if_exists = mock.MagicMock( 130 | return_value=FileScore(dict(psid='000', expiry_ts=0, errors=0, score=10, sid='000', time=time.time())) 131 | ) 132 | 133 | # Ingest a file 134 | in_queue.push(make_message()) 135 | ingester.handle_ingest() 136 | 137 | # No file has made it into the internal buffer => cache hit and drop 138 | datastore.filescore.get_if_exists.assert_called_once() 139 | ingester.counter.increment.assert_any_call('cache_hit') 140 | ingester.counter.increment.assert_any_call('duplicates') 141 | assert ingester.unique_queue.length() == 0 142 | assert ingester.ingest_queue.length() == 0 143 | finally: 144 | datastore.filescore.get_if_exists = get_if_exists 145 | 146 | 147 | def test_ingest_groups_custom(ingest_harness): 148 | datastore, ingester, in_queue = ingest_harness 149 | 150 | user = random_minimal_obj(User) 151 | user.name = 'user' 152 | custom_user_groups = ['users', 'the_user'] 153 | user.groups = list(custom_user_groups) 154 | datastore.user.save('user', user) 155 | 156 | in_queue.push(make_message(params={'submitter': 'user', 'groups': ['group_b']})) 157 | ingester.handle_ingest() 158 | 159 | task = ingester.unique_queue.pop() 160 | assert task 161 | task = IngestTask(task) 162 | assert task.submission.params.submitter == 'user' 163 | assert task.submission.params.groups == ['group_b'] 164 | 165 | 166 | def test_ingest_size_error(ingest_harness): 167 | datastore, ingester, in_queue = ingest_harness 168 | 169 | # Send a rather big file 170 | submission = make_message( 171 | files={ 172 | 'size': ingester.config.submission.max_file_size + 1, 173 | # 'ascii': 'abc' 174 | }, 175 | params={ 176 | 'ignore_size': False, 177 | 'never_drop': False 178 | } 179 | ) 180 | fo = random_minimal_obj(File) 181 | fo.sha256 = 
submission['files'][0]['sha256'] 182 | datastore.file.save(submission['files'][0]['sha256'], fo) 183 | submission['notification'] = {'queue': 'drop_test'} 184 | in_queue.push(submission) 185 | ingester.handle_ingest() 186 | 187 | # No files in the internal buffer 188 | assert ingester.unique_queue.length() == 0 189 | assert ingester.ingest_queue.length() == 0 190 | 191 | # A file was dropped 192 | queue_name = _notification_queue_prefix + submission['notification']['queue'] 193 | queue = ingester.notification_queues[queue_name] 194 | message = queue.pop() 195 | assert message is not None 196 | 197 | def test_ingest_always_create_submission(ingest_harness): 198 | datastore, ingester, in_queue = ingest_harness 199 | 200 | # Simulate configuration where we'll always create a submission 201 | ingester.config.core.ingester.always_create_submission = True 202 | get_if_exists = datastore.filescore.get_if_exists 203 | try: 204 | # Add a valid file score for all files 205 | from assemblyline.odm.models.filescore import FileScore 206 | from assemblyline.odm.models.submission import Submission 207 | datastore.filescore.get_if_exists = mock.MagicMock( 208 | return_value=FileScore(dict(psid='000', expiry_ts=0, errors=0, score=10, sid='001', time=time.time())) 209 | ) 210 | # Create a submission for cache hit 211 | old_sub = random_minimal_obj(Submission) 212 | old_sub.sid = '001' 213 | old_sub.params.psid = '000' 214 | old_sub = old_sub.as_primitives() 215 | datastore.submission.save('001', old_sub) 216 | 217 | # Ingest a file 218 | submission_msg = make_message(message={'sid': '002', 'metadata': {'blah': 'blah'}}) 219 | submission_msg['sid'] = '002' 220 | in_queue.push(submission_msg) 221 | ingester.handle_ingest() 222 | 223 | # No file has made it into the internal buffer => cache hit and drop 224 | datastore.filescore.get_if_exists.assert_called_once() 225 | ingester.counter.increment.assert_any_call('cache_hit') 226 | ingester.counter.increment.assert_any_call('duplicates') 227 | assert ingester.unique_queue.length() == 0 228 | assert ingester.ingest_queue.length() == 0 229 | 230 | # Check to see if new submission was created 231 | new_sub = datastore.submission.get_if_exists('002', as_obj=False) 232 | assert new_sub and new_sub['params']['psid'] == old_sub['sid'] 233 | 234 | # Check to see if certain properties are same (anything relating to analysis) 235 | assert all([old_sub.get(attr) == new_sub.get(attr) \ 236 | for attr in ['error_count', 'errors', 'file_count', 'files', 'max_score', 'results', 'state', 'verdict']]) 237 | 238 | # Check to see if certain properties are different 239 | # (anything that isn't related to analysis but can be set at submission time) 240 | assert all([old_sub.get(attr) != new_sub.get(attr) \ 241 | for attr in ['expiry_ts', 'metadata', 'params', 'times']]) 242 | 243 | # Check to see if certain properties have been nullified 244 | # (properties that are set outside of submission) 245 | assert not all([new_sub.get(attr) \ 246 | for attr in ['archived', 'archive_ts', 'to_be_deleted', 'from_archive']]) 247 | finally: 248 | datastore.filescore.get_if_exists = get_if_exists 249 | -------------------------------------------------------------------------------- /test/test_worker_submit.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | import pytest 3 | import time 4 | 5 | from assemblyline.datastore.helper import AssemblylineDatastore 6 | from assemblyline.odm.models.submission import SubmissionParams 7 | 
from assemblyline.odm.models.filescore import FileScore 8 | 9 | from assemblyline_core.ingester.ingester import IngestTask, _dup_prefix, Ingester 10 | 11 | from mocking import TrueCountTimes 12 | 13 | 14 | @pytest.fixture 15 | def submit_harness(clean_redis, clean_datastore: AssemblylineDatastore): 16 | """Set up a test environment for the ingester submit tests.""" 17 | datastore = clean_datastore 18 | submitter = Ingester(datastore=datastore, redis=clean_redis, persistent_redis=clean_redis) 19 | submitter.running = TrueCountTimes(1) 20 | submitter.counter.increment = mock.MagicMock() 21 | submitter.submit_client.submit = mock.MagicMock() 22 | return datastore, submitter 23 | 24 | 25 | def test_submit_simple(submit_harness): 26 | datastore, submitter = submit_harness 27 | 28 | # Push a normal ingest task 29 | submitter.unique_queue.push(0, IngestTask({ 30 | 'submission': { 31 | 'params': SubmissionParams({ 32 | 'classification': 'U', 33 | 'description': 'file abc', 34 | 'services': { 35 | 'selected': [], 36 | 'excluded': [], 37 | 'resubmit': [], 38 | }, 39 | 'submitter': 'user', 40 | }), 41 | 'files': [{ 42 | 'sha256': '0' * 64, 43 | 'size': 100, 44 | 'name': 'abc', 45 | }], 46 | 'metadata': {} 47 | }, 48 | 'ingest_id': '123abc' 49 | }).as_primitives()) 50 | submitter.handle_submit() 51 | 52 | # The task has been passed to the submit tool and there are no other submissions 53 | submitter.submit_client.submit.assert_called() 54 | assert submitter.unique_queue.pop() is None 55 | 56 | 57 | def test_submit_duplicate(submit_harness): 58 | datastore, submitter = submit_harness 59 | 60 | # A normal ingest task 61 | task = IngestTask({ 62 | 'submission': { 63 | 'params': SubmissionParams({ 64 | 'classification': 'U', 65 | 'description': 'file abc', 66 | 'services': { 67 | 'selected': [], 68 | 'excluded': [], 69 | 'resubmit': [], 70 | }, 71 | 'submitter': 'user', 72 | }), 73 | 'files': [{ 74 | 'sha256': '0' * 64, 75 | 'size': 100, 76 | 'name': 'abc', 77 | }], 78 | 'metadata': {} 79 | }, 80 | 'ingest_id': 'abc123' 81 | }) 82 | # Make sure the scan key is correct, this is normally done on ingest 83 | task.submission.scan_key = task.params.create_filescore_key(task.submission.files[0].sha256, []) 84 | 85 | # Add this file to the scanning table, so it looks like it has already been submitted + ingest again 86 | submitter.scanning.add(task.submission.scan_key, task.as_primitives()) 87 | submitter.unique_queue.push(0, task.as_primitives()) 88 | 89 | submitter.handle_submit() 90 | 91 | # No tasks should be left in the queue 92 | assert submitter.unique_queue.pop() is None 93 | # The task should have been pushed to the duplicates queue 94 | assert submitter.duplicate_queue.length(_dup_prefix + task.submission.scan_key) == 1 95 | 96 | 97 | def test_existing_score(submit_harness): 98 | datastore, submitter = submit_harness 99 | get_if_exists = datastore.filescore.get_if_exists 100 | try: 101 | # Set everything to have an existing filescore 102 | datastore.filescore.get_if_exists = mock.MagicMock(return_value=FileScore( 103 | dict(psid='000', expiry_ts=0, errors=0, score=10, sid='000', time=time.time()))) 104 | 105 | # add task to internal queue 106 | submitter.unique_queue.push(0, IngestTask({ 107 | 'submission': { 108 | 'params': SubmissionParams({ 109 | 'classification': 'U', 110 | 'description': 'file abc', 111 | 'services': { 112 | 'selected': [], 113 | 'excluded': [], 114 | 'resubmit': [], 115 | }, 116 | 'submitter': 'user', 117 | }), 118 | 'files': [{ 119 | 'sha256': '0' * 64, 120 | 'size': 100, 121 |
'name': 'abc', 122 | }], 123 | 'metadata': {}, 124 | 'notification': { 125 | 'queue': 'our_queue' 126 | } 127 | }, 128 | 'ingest_id': 'abc123' 129 | }).as_primitives()) 130 | 131 | submitter.handle_submit() 132 | 133 | # No tasks should be left in the queue 134 | assert submitter.unique_queue.pop() is None 135 | # We should have received a notification about our task, since it was already 'done' 136 | assert submitter.notification_queues['nq-our_queue'].length() == 1 137 | finally: 138 | datastore.filescore.get_if_exists = get_if_exists 139 | -------------------------------------------------------------------------------- /test/test_workflow.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import random 3 | from assemblyline_core.workflow.run_workflow import WorkflowManager 4 | 5 | from assemblyline.common.isotime import now_as_iso 6 | from assemblyline.odm.models.workflow import Workflow 7 | from assemblyline.odm.random_data import create_alerts, wipe_alerts, wipe_workflows 8 | from assemblyline.odm.randomizer import random_minimal_obj 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def manager(datastore_connection): 13 | try: 14 | create_alerts(datastore_connection) 15 | wipe_workflows(datastore_connection) 16 | datastore_connection.alert.update_by_query("*", [(datastore_connection.alert.UPDATE_SET, 'reporting_ts', now_as_iso())]) 17 | datastore_connection.alert.commit() 18 | yield WorkflowManager() 19 | finally: 20 | wipe_alerts(datastore_connection) 21 | 22 | def test_workflow(manager, datastore_connection): 23 | # Create a workflow that targets alerts based on YARA rule association 24 | workflow = random_minimal_obj(Workflow) 25 | 26 | yara_rule = random.choice(list(datastore_connection.alert.facet("al.yara").keys())) 27 | workflow.query = f'al.yara:"{yara_rule}"' 28 | workflow.workflow_id = "AL_TEST" 29 | workflow.labels = ["AL_TEST"] 30 | workflow.priority = "LOW" 31 | workflow.status = "MALICIOUS" 32 | datastore_connection.workflow.save(workflow.workflow_id, workflow) 33 | datastore_connection.workflow.commit() 34 | 35 | # Run the workflow manager to process the new workflow against existing alerts 36 | manager.running = True 37 | manager.get_last_reporting_ts = lambda x: "now/d+1d" 38 | manager.try_run(run_once=True) 39 | datastore_connection.alert.commit() 40 | 41 | # Assert that custom labels were applied to alerts 42 | assert datastore_connection.alert.search("label:AL_TEST", track_total_hits=True)['total'] 43 | 44 | # Assert that the change has been recorded in the alerts' event history 45 | assert datastore_connection.alert.search(f"events.entity_id:{workflow.workflow_id}", track_total_hits=True)['total'] 46 | --------------------------------------------------------------------------------