├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── Changelog.md ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── lint.sh ├── pre-commit ├── pyproject.toml ├── setup.cfg ├── setup.py ├── shub_workflow ├── __init__.py ├── base.py ├── contrib │ ├── __init__.py │ ├── hubstorage.py │ ├── sentry.py │ └── slack.py ├── crawl.py ├── deliver │ ├── __init__.py │ └── base.py ├── graph │ ├── __init__.py │ ├── task.py │ └── utils.py ├── py.typed ├── script.py └── utils │ ├── __init__.py │ ├── alert_sender.py │ ├── clone_job.py │ ├── contexts.py │ ├── dupefilter.py │ ├── futils.py │ ├── gcstorage.py │ ├── monitor.py │ ├── scanjobs.py │ ├── sesemail.py │ └── watchdog.py └── tests ├── __init__.py ├── test_base_manager.py ├── test_crawl_manager.py ├── test_futils.py ├── test_graph_manager.py └── test_typing.py /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | pull_request: 9 | workflow_dispatch: 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: ["3.9", "3.10", "3.11"] 18 | 19 | steps: 20 | - uses: actions/checkout@v3 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install pipenv 29 | pipenv install --dev --deploy --system 30 | - name: Lint 31 | run: | 32 | ./lint.sh 33 | - name: Test with pytest 34 | run: | 35 | pytest 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | #### joe made this: http://goel.io/joe 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | .pytest_cache/ 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *,cover 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | 57 | # Sphinx documentation 58 | docs/_build/ 59 | 60 | # PyBuilder 61 | target/ 62 | 63 | -------------------------------------------------------------------------------- /Changelog.md: -------------------------------------------------------------------------------- 1 | Mayor versions changes 2 | ====================== 3 | 4 | W.X.Y.Z 5 | 6 | W - Package version 7 | X - Major version. Introduces backward incompatibilities. 8 | Y - Minor version. Introduces new features. 9 | Z - Micro version. Fixes and improvements. 
Sometimes incremental sub-features of a new feature are added within the same minor version.
10 | 
11 | Fixes and performance improvements are made
12 | continuously and are not all indicated here.
13 | 
14 | In general, you can find most incompatibilities by generous use of typing hints and mypy in your project.
15 | 
16 | 1.9 (From July 2022 to October 2022)
17 | ------------------------------------
18 | 
19 | Backward incompatibility issues:
20 | 
21 | - Python versions older than 3.8 are no longer supported
22 | - The name attribute is now required for every subclass of WorkFlowManager (either as a hardcoded attribute, or passed via required command line arguments)
23 | - The delivery script has been refactored. The old code is deprecated and will be removed in future versions.
24 | 
25 | Minor version changes:
26 | 
27 | - 1.9.0 crawlmanager `bad_outcome_hook()`
28 | - 1.9.1 crawlmanager ability to resume
29 | - 1.9.2 implicit crawlmanager resumability via flow id (before, a command line option was required)
30 | - 1.9.3 performance improvements
31 | - 1.9.4 name attribute logic fixes
32 | - 1.9.5 delivery script refactor, old delivery code deprecated.
33 | - 1.9.6 diverse fixes and improvements in base script, added `get_jobs_with_tags()` method.
34 | - 1.9.7 refactors and new s3 helper method
35 | 
36 | 1.10 (From October 2022 to April 2023)
37 | --------------------------------------
38 | 
39 | Backward incompatibility issues:
40 | 
41 | - Backward incompatibilities may come from the massive introduction of typing hints and fixes in type consistency, especially when you override some methods.
42 | Many of the incompatibilities you may find can be detected in advance by using mypy in your project and using typing abundantly in the classes that
43 | use shub-workflow.
44 | 
45 | Minor version changes:
46 | 
47 | - 1.10.1 continuation of massive typing hints adoption
48 | - 1.10.2 continuation of massive typing hints adoption
49 | - 1.10.3 new BaseLoopScript class
50 | - 1.10.4 `script_args()` context manager
51 | - 1.10.5 some refactoring and improvements in the new delivery script.
52 | - 1.10.6 crawl manager new async schedule mode
53 | - 1.10.7 crawl manager extension of async scheduling
54 | - 1.10.8 max running time for all scripts
55 | - 1.10.9 performance improvements of resume feature
56 | - 1.10.10 new method for async tagging of jobs
57 | - 1.10.11 async tagging on delivery script
58 | - 1.10.12 introduction of script stats
59 | - 1.10.13 crawlmanager stats, delivery stats, AWS email utils.
60 | - 1.10.14 graph manager ability to resume
61 | - 1.10.15 graph manager `bad_outcome_hook()`
62 | - 1.10.16 max running time feature on delivery script
63 | - 1.10.17 spider loader object in base script
64 | - 1.10.18 default implementation of `bad_outcome_hook()` on generator crawl manager
65 | - 1.10.19 some refactoring of base script.
66 | - 1.10.20 base script: resolve project id before parsing options
67 | - 1.10.21 minor improvement in base script spider loader object
68 | - 1.10.22 crawlmanager `finished_ok_hook()` method
69 | - 1.10.23 generator crawlmanager that handles multiple spiders in the same process (not a good experiment, will probably be deprecated in the future)
70 | 
71 | 1.11 (From April 2023 to June 2023)
72 | -----------------------------------
73 | 
74 | Backward incompatibility issues:
75 | 
76 | - Definitively removed the old legacy delivery class
77 | - Dupefilter classes moved from the delivery folder into utils
78 | 
79 | 
80 | Minor version changes:
81 | 
82 | - 1.11.1 replaced the `bloom_filter` dependency by `bloom_filter2`
83 | - 1.11.2 typing improvements
84 | - 1.11.3 method for removing job tags
85 | - 1.11.4 method for reading the log level from kumo and using it in the script main() function.
86 | - 1.11.5 added methods for working with Google Cloud Storage, with the same interface as the existing ones for AWS S3
87 | - 1.11.6 generator crawlmanager: method for computing max jobs per spider
88 | - 1.11.7 GCS additions
89 | - 1.11.8 typing hint updates
90 | - 1.11.9 added support for python 3.11
91 | 
92 | 1.12 (From June 2023 to September 2023)
93 | ---------------------------------------
94 | 
95 | Backward incompatibility issues:
96 | 
97 | - made installation of s3 and gcs dependencies optional (with shub-workflow[with-s3-tools] and shub-workflow[with-gcs-tools])
98 | 
99 | Minor version changes:
100 | - 1.12.1 more typing additions and improvements, and related refactoring
101 | - 1.12.2 generator crawlmanager: some methods for conditioning the scheduling of new spider jobs
102 | - 1.12.3 reimplementation of the `upload_file` s3 util using boto3, for improved performance
103 | - 1.12.4 improvements in the job cloner class
104 | 
105 | 1.13 (From September 2023 to June 2024)
106 | ---------------------------------------
107 | 
108 | Backward incompatibility issues:
109 | 
110 | - Some changes in the delivery script interface
111 | 
112 | Minor version changes:
113 | - 1.13.0 new configuration watchdog script
114 | - 1.13.1 generator crawlmanager: added method for determining retry parameters when a job is retried from the default `bad_outcome_hook()` method
115 | - 1.13.2 generator crawlmanager: additional features for multiple spiders handling
116 | - 1.13.3 configurable retry logging via environment variables
117 | - 1.13.4 base script: some handlers for scheduling methods
118 | - 1.13.5 additions in GCS utils
119 | - 1.13.6 additions in GCS utils
120 | - 1.13.7 base script: print stats on close
121 | - 1.13.8 avoid multiple warnings on `kumo_settings()` function. Additions in GCS utils
122 | 
123 | 1.14 (From June 2024 to present)
124 | --------------------------------
125 | 
126 | Backward incompatibility issues:
127 | 
128 | - Moved filesystem utils (S3 and GCS utils) from the deliver folder into the utils folder
129 | 
130 | Minor version changes:
131 | 
132 | - 1.14.0 generator crawl manager: acquire all jobs if the crawlmanager doesn't have a flow id.
133 | Added `get_canonical_spidername()` and `get_project_running_spiders()` helper methods on base script.
134 | Added the fshelper base script attribute for ready access to this helper tool
135 | - 1.14.1 Added method for getting live real-time settings from ScrapyCloud
136 | - 1.14.2 Added BaseMonitor class that is able to monitor aggregated stats on entire workflow jobs.
137 | - 1.14.3 Mixin for provision of Sentry alert capabilities in monitor.
138 | - 1.14.4 Monitor ratios 139 | - 1.14.5 Finished job metadata hook 140 | - 1.14.6 Allow to use project entry keyword in scrapinghub.yml as alternative to project numeric id, when passing command line --project-id option. 141 | - 1.14.7 Mixin for provision of Slack alert capabilities in monitor. 142 | - 1.14.8 Allow to load settings from SC when running script on local environment. 143 | - 1.14.9 Added stats aggregation capabilities to crawlmanager 144 | - 1.14.10 AlertSender class to allow both slack and sentry alerts combined. 145 | - 1.14.11 Created SlackSender class for easier reusage of slack messaging 146 | - 1.14.12 SlackSender: allow to send attachments 147 | - 1.14.13 Extended monitor to be able to generate reports 148 | - 1.14.14 Monitor improvements: target_spider_stats can be a regex, and allow to print report tables suitable for copy/paste over spreadsheet 149 | - 1.14.15 Allow to tune how finished and running jobs are acquired, via a couple of new attributes 150 | - 1.14.16 added mixin for allowing scripts to issue items on SC 151 | - 1.14.17 fshelper: allow to get GCS object for operations like metadata set/read 152 | - 1.14.18 monitor: allow to scan multiple SC projects 153 | - 1.14.19 monitor: compute response status ratios 154 | - 1.14.20 Added tool for finding/analyzing jobs stats/logs/args/items 155 | - 1.14.21 Allow to pre set command line programs 156 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Scrapinghub 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of exporters nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | 
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | url = "https://pypi.org/simple"
3 | verify_ssl = true
4 | name = "pypi"
5 | 
6 | [packages]
7 | pyyaml = "==6.0"
8 | scrapinghub = {version = ">=2.4.0", extras = ["msgpack"]}
9 | jinja2 = ">=2.7.3"
10 | sqlitedict = "==2.1.0"
11 | s3fs = ">=0.4.0"
12 | google-cloud-storage = "==1.38.0"
13 | bloom-filter2 = "2.0.0"
14 | black = "==24.3.0"
15 | tenacity = "*"
16 | typing-extensions = "*"
17 | collection-scanner = "*"
18 | scrapy = "*"
19 | boto3 = ">=1.9.92"
20 | requests = "<=2.32.3"
21 | spidermon = "*"
22 | sentry-sdk = "*"
23 | slack_sdk = "*"
24 | prettytable = "*"
25 | timelength = "*"
26 | 
27 | [dev-packages]
28 | flake8 = "*"
29 | black = "==24.3.0"
30 | mypy = ">=1.4.1"
31 | pylint = "*"
32 | pytest = "*"
33 | twine = "*"
34 | ipython = "<=8.18.1"
35 | grip = "*"
36 | types-requests = "*"
37 | types-setuptools = "*"
38 | types-pyyaml = "*"
39 | flake8-import-order = "*"
40 | pytest-cov = "*"
41 | types-dateparser = "*"
42 | 
43 | [requires]
44 | python_version = "3.10"
45 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | A set of tools for controlling processing workflows with spiders and scripts running in Scrapinghub ScrapyCloud.
2 | 
3 | # Installation
4 | 
5 | ```
6 | pip install shub-workflow
7 | ```
8 | 
9 | If you want support for S3 tools:
10 | 
11 | ```
12 | pip install shub-workflow[with-s3-tools]
13 | ```
14 | 
15 | For Google Cloud Storage tools support:
16 | 
17 | ```
18 | pip install shub-workflow[with-gcs-tools]
19 | ```
20 | 
21 | # Usage
22 | 
23 | Check the [Project Wiki](https://github.com/scrapinghub/shub-workflow/wiki) for documentation. You can also check the code tests for lots of usage examples.
24 | 
25 | # Note
26 | 
27 | The requirements for this library are defined in setup.py as usual. The Pipfile files in the repository don't define its dependencies. They are only used
28 | for setting up a development environment for shub-workflow library development and testing.
29 | 
30 | 
31 | # For developers
32 | 
33 | To set up a development environment for shub-workflow, the package comes with Pipfile and Pipfile.lock files. Clone or fork the repository and run:
34 | 
35 | ```
36 | > pipenv install --dev
37 | > cp pre-commit .git/hooks/
38 | ```
39 | 
40 | to install the environment, and:
41 | 
42 | ```
43 | > pipenv shell
44 | ```
45 | 
46 | to activate it.
47 | 
48 | There is a script, lint.sh, that you can run whenever you need from the repo root folder, but it is also executed each time you do `git commit` (provided
49 | you installed the pre-commit hook during the installation step described above). It checks PEP8 compliance and typing integrity, via flake8 and mypy.
50 | 51 | ``` 52 | > ./lint.sh 53 | ``` 54 | -------------------------------------------------------------------------------- /lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | result=0 3 | flake8 shub_workflow/ tests/ --application-import-names=shub_workflow --import-order-style=pep8 4 | result=$(($result | $?)) 5 | mypy --ignore-missing-imports --disable-error-code=method-assign --check-untyped-defs shub_workflow/ tests/ 6 | result=$(($result | $?)) 7 | exit $result 8 | -------------------------------------------------------------------------------- /pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | BASE=$(git rev-parse --show-toplevel) 3 | cd $BASE 4 | ./lint.sh 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = W503 3 | max-line-length = 120 4 | # max-complexity = 18 5 | select = B,C,E,F,W,T4,B9,I100 6 | 7 | [tool:pytest] 8 | addopts = --doctest-modules --cov 9 | 10 | [coverage:run] 11 | branch = true 12 | source = 13 | midjourney 14 | omit = 15 | tests/* 16 | docs/* 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Automatically created by: shub deploy 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup( 6 | name="shub-workflow", 7 | version="1.14.21.16", 8 | description="Workflow manager for Zyte ScrapyCloud tasks.", 9 | long_description=open("README.md").read(), 10 | long_description_content_type="text/markdown", 11 | license="BSD", 12 | url="https://github.com/scrapinghub/shub-workflow", 13 | maintainer="Scrapinghub", 14 | packages=find_packages(), 15 | install_requires=( 16 | "pyyaml>=3.12", 17 | "scrapinghub[msgpack]>=2.3.1", 18 | "jinja2>=2.7.3", 19 | "sqlitedict==2.1.0", 20 | "boto3>=1.9.92", 21 | "bloom-filter2==2.0.0", 22 | "collection-scanner", 23 | "tenacity", 24 | "typing-extensions", 25 | "scrapy", 26 | "prettytable", 27 | "jmespath", 28 | "timelength", 29 | ), 30 | extras_require = { 31 | "with-s3-tools": ["s3fs>=0.4.0"], 32 | "with-gcs-tools": ["google-cloud-storage>=1.38.0"], 33 | }, 34 | scripts=[], 35 | classifiers=[ 36 | "Development Status :: 5 - Production/Stable", 37 | "Intended Audience :: Developers", 38 | "License :: OSI Approved :: BSD License", 39 | "Operating System :: OS Independent", 40 | "Programming Language :: Python :: 3.9", 41 | "Programming Language :: Python :: 3.10", 42 | "Programming Language :: Python :: 3.11", 43 | ], 44 | package_data={"shub_workflow": ["py.typed"]}, 45 | ) 46 | -------------------------------------------------------------------------------- /shub_workflow/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.14.21.16" 2 | -------------------------------------------------------------------------------- /shub_workflow/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements common methods for any workflow manager 3 | """ 4 | import abc 5 | import time 6 | import logging 7 | from uuid 
import uuid4 8 | from argparse import Namespace 9 | from collections import defaultdict 10 | from typing import Optional, Generator, Protocol, List, Union, Dict, Tuple 11 | 12 | from .script import BaseLoopScript, JobKey, JobDict, Outcome, BaseLoopScriptProtocol 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class WorkFlowManagerProtocol(BaseLoopScriptProtocol, Protocol): 19 | @abc.abstractmethod 20 | def get_owned_jobs(self, project_id: Optional[int] = None, **kwargs) -> Generator[JobDict, None, None]: 21 | ... 22 | 23 | @abc.abstractmethod 24 | def get_finished_owned_jobs(self, project_id: Optional[int] = None, **kwargs) -> Generator[JobDict, None, None]: 25 | ... 26 | 27 | @property 28 | @abc.abstractmethod 29 | def max_running_jobs(self) -> int: 30 | ... 31 | 32 | 33 | class CachedFinishedJobsMixin(WorkFlowManagerProtocol): 34 | def __init__(self): 35 | super().__init__() 36 | self.__finished_cache: Dict[JobKey, Outcome] = {} 37 | self.__update_finished_cache_called: Dict[int, bool] = defaultdict(bool) 38 | 39 | def update_finished_cache(self, project_id: int): 40 | if not self.__update_finished_cache_called[project_id]: 41 | logger.info("Initiating finished cache update.") 42 | for job in self.get_owned_jobs(project_id, state=["finished"], meta=["close_reason"]): 43 | if job["key"] in self.__finished_cache: 44 | break 45 | self.__finished_cache[job["key"]] = Outcome(job["close_reason"]) 46 | logger.info(f"Finished jobs cache length: {len(self.__finished_cache)}") 47 | self.__update_finished_cache_called[project_id] = True 48 | 49 | def get_finished_owned_jobs(self, project_id: Optional[int] = None, **kwargs) -> Generator[JobDict, None, None]: 50 | kwargs.setdefault("meta", []).append("close_reason") 51 | # only use for updating finished cache if no filter/limit imposed. 52 | update_finished_cache = not set(kwargs.keys()).intersection(["count", "has_tag", "lacks_tag"]) 53 | finished_cache = [] 54 | for job in super().get_finished_owned_jobs(project_id, **kwargs): # type: ignore 55 | if update_finished_cache: 56 | finished_cache.append((job["key"], Outcome(job["close_reason"]))) 57 | yield job 58 | logger.info(f"Preread {len(finished_cache)} finished jobs.") 59 | while finished_cache: 60 | key, close_reason = finished_cache.pop() 61 | self.__finished_cache[key] = close_reason 62 | logger.info(f"Finished jobs cache length: {len(self.__finished_cache)}") 63 | 64 | def is_finished(self, jobkey: JobKey) -> Optional[Outcome]: 65 | project_id = int(jobkey.split("/", 1)[0]) 66 | self.update_finished_cache(project_id) 67 | return self.__finished_cache.get(jobkey) 68 | 69 | def base_loop_tasks(self): 70 | for project_id in self.__update_finished_cache_called.keys(): 71 | self.__update_finished_cache_called[project_id] = False 72 | 73 | 74 | class WorkFlowManager(BaseLoopScript, WorkFlowManagerProtocol): 75 | 76 | # --max-running-job command line option overrides it 77 | default_max_jobs: int = 1000 78 | 79 | flow_id_required = True 80 | # if True, acquire all jobs regardless flow id 81 | # use with caution: if there are many finished jobs it may acquire thousands of them. 82 | # so you may probably want to use with dont_acquire_finished_jobs = True 83 | acquire_all_jobs = False 84 | # if True, don't acquire finished jobs 85 | dont_acquire_finished_jobs = False 86 | 87 | base_failed_outcomes: Tuple[str, ...] 
= ( 88 | "failed", 89 | "killed by oom", 90 | "cancelled", 91 | "cancel_timeout", 92 | "memusage_exceeded", 93 | "diskusage_exceeded", 94 | "cancelled (stalled)", 95 | ) 96 | 97 | def __init__(self): 98 | self.failed_outcomes = list(self.base_failed_outcomes) 99 | self.is_resumed = False 100 | self.__autogenerated_flow_id = False 101 | super().__init__() 102 | 103 | def set_flow_id_name(self, args): 104 | super().set_flow_id_name(args) 105 | if not self.name: 106 | self.argparser.error("Manager name not set.") 107 | 108 | def get_owned_jobs(self, project_id: Optional[int] = None, **kwargs) -> Generator[JobDict, None, None]: 109 | assert self.flow_id, "This job doesn't have a flow id." 110 | assert self.name, "This job doesn't have a name." 111 | assert "has_tag" not in kwargs, "Filtering by flow id requires no extra has_tag." 112 | assert "state" in kwargs, "'state' parameter must be provided." 113 | if not self.acquire_all_jobs: 114 | kwargs["has_tag"] = [f"FLOW_ID={self.flow_id}"] 115 | parent_tag = f"PARENT_NAME={self.name}" 116 | meta = kwargs.get("meta") or [] 117 | if "tags" not in meta: 118 | meta.append("tags") 119 | kwargs["meta"] = meta 120 | for job in self.get_jobs(project_id, **kwargs): 121 | if parent_tag in job["tags"]: 122 | yield job 123 | 124 | def generate_flow_id(self) -> str: 125 | self.__autogenerated_flow_id = True 126 | return str(uuid4()) 127 | 128 | @property 129 | def max_running_jobs(self) -> int: 130 | return self.args.max_running_jobs 131 | 132 | def add_argparser_options(self): 133 | super().add_argparser_options() 134 | if not self.name: 135 | self.argparser.add_argument("name", help="Script name.") 136 | self.argparser.add_argument( 137 | "--max-running-jobs", 138 | type=int, 139 | default=self.default_max_jobs, 140 | help="If given, don't allow more than the given jobs running at once.\ 141 | Default: %(default)s", 142 | ) 143 | 144 | def parse_args(self) -> Namespace: 145 | args = super().parse_args() 146 | if not self.name: 147 | self.name = args.name 148 | return args 149 | 150 | def wait_for( 151 | self, 152 | jobs_keys: Union[JobKey, List[JobKey]], 153 | interval: int = 60, 154 | timeout: float = float("inf"), 155 | heartbeat: Optional[int] = None, 156 | ): 157 | """Waits until all given jobs are not running anymore or until the 158 | timeout is reached, if a heartbeat is given it'll log an entry every 159 | heartbeat seconds (considering the interval), otherwise it'll log an 160 | entry every interval seconds. 
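
        Illustrative usage sketch (the job key and timing values below are hypothetical):

            self.wait_for(["123/45/67"], interval=30, heartbeat=300)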
161 | """ 162 | if isinstance(jobs_keys, str): 163 | jobs_keys = [jobs_keys] 164 | if heartbeat is None or heartbeat < interval: 165 | heartbeat = interval 166 | still_running = dict((key, True) for key in jobs_keys) 167 | time_waited, next_heartbeat = 0, heartbeat 168 | while any(still_running.values()) and time_waited < timeout: 169 | time.sleep(interval) 170 | time_waited += interval 171 | for key in (k for k, v in still_running.items() if v): 172 | if self.is_running(key): 173 | if time_waited >= next_heartbeat: 174 | next_heartbeat += heartbeat 175 | logger.info(f"{key} still running") 176 | break 177 | still_running[key] = False 178 | 179 | def _check_resume_workflow(self): 180 | for job in self.get_jobs(state=["finished"], meta=["tags"], has_tag=[f"NAME={self.name}"]): 181 | if self.get_keyvalue_job_tag("FLOW_ID", job["tags"]) == self.flow_id: 182 | inherited_tags = [] 183 | for tag in job["tags"]: 184 | if len(tag.split("=")) == 2: 185 | inherited_tags.append(tag) 186 | self.add_job_tags(tags=inherited_tags) 187 | self.is_resumed = True 188 | break 189 | 190 | def get_finished_owned_jobs(self, project_id: Optional[int] = None, **kwargs) -> Generator[JobDict, None, None]: 191 | if not self.dont_acquire_finished_jobs: 192 | kwargs.pop("state", None) 193 | for job in self.get_owned_jobs(project_id, state=["finished"], **kwargs): 194 | yield job 195 | 196 | def resume_workflow(self): 197 | rcount = 0 198 | for job in self.get_owned_jobs(state=["running", "pending"], meta=["spider_args", "job_cmd", "tags", "spider"]): 199 | self.resume_running_job_hook(job) 200 | logger.info(f"Found running job {job['key']}") 201 | rcount += 1 202 | if rcount > 0: 203 | logger.info(f"Found a total of {rcount} running children jobs.") 204 | 205 | fcount = 0 206 | logger.info("Searching finished children jobs...") 207 | for job in self.get_finished_owned_jobs(meta=["spider_args", "job_cmd", "tags", "spider", "close_reason"]): 208 | self.resume_finished_job_hook(job) 209 | fcount += 1 210 | if fcount > 0: 211 | logger.info(f"Found a total of {fcount} finished children jobs.") 212 | 213 | def resume_running_job_hook(self, job: JobDict): 214 | pass 215 | 216 | def resume_finished_job_hook(self, job: JobDict): 217 | pass 218 | 219 | def _on_start(self): 220 | if not self.__autogenerated_flow_id or self.acquire_all_jobs: 221 | self._check_resume_workflow() 222 | if self.is_resumed or self.acquire_all_jobs: 223 | self.resume_workflow() 224 | super()._on_start() 225 | -------------------------------------------------------------------------------- /shub_workflow/contrib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrapinghub/shub-workflow/bfce391181ecb99b8fcb7593278cef5af1e38719/shub_workflow/contrib/__init__.py -------------------------------------------------------------------------------- /shub_workflow/contrib/hubstorage.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pprint import pformat 3 | 4 | from scrapy.signals import item_scraped 5 | 6 | from shub_workflow.script import BaseScriptProtocol 7 | 8 | 9 | LOGGER = logging.getLogger(__name__) 10 | 11 | 12 | class ItemHSIssuerMixin(BaseScriptProtocol): 13 | """ 14 | A class for allowing to issue items on hubstorage, so a script running on SC can return items as a spider. 
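
    A minimal usage sketch (the script class and the item dict below are hypothetical;
    only hs_issue_item() is provided by this mixin, and BaseLoopScript comes from
    shub_workflow.script):

        class MyItemsScript(ItemHSIssuerMixin, BaseLoopScript):

            def workflow_loop(self) -> bool:
                # issue a single (hypothetical) item on ScrapyCloud and stop looping
                self.hs_issue_item({"url": "https://example.com"})
                return False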
15 | """ 16 | 17 | def __init__(self): 18 | super().__init__() 19 | try: 20 | from sh_scrapy.extension import HubstorageExtension 21 | 22 | class _HubstorageExtension(HubstorageExtension): 23 | def item_scraped(slf, item, spider): 24 | try: 25 | return super().item_scraped(item, spider) 26 | except RuntimeError: 27 | self.log_item(item, spider) 28 | 29 | self.hextension = _HubstorageExtension.from_crawler(self._pseudo_crawler) 30 | except ImportError: 31 | self._pseudo_crawler.signals.connect(self.log_item, item_scraped) 32 | 33 | def log_item(self, item, spider, **kwargs): 34 | LOGGER.info(pformat(item)) 35 | 36 | def hs_issue_item(self, item): 37 | self._pseudo_crawler.signals.send_catch_log_deferred(item_scraped, dont_log=True, item=item, spider=self) 38 | -------------------------------------------------------------------------------- /shub_workflow/contrib/sentry.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pprint import pformat 3 | from typing import Dict, Any 4 | 5 | from shub_workflow.utils import resolve_shub_jobkey 6 | from shub_workflow.utils.alert_sender import AlertSenderMixin 7 | 8 | 9 | LOG = logging.getLogger(__name__) 10 | 11 | 12 | class SentryMixin(AlertSenderMixin): 13 | """ 14 | A class for adding sentry alert capabilities to a shub_workflow class. 15 | """ 16 | 17 | def __init__(self): 18 | super().__init__() 19 | try: 20 | from spidermon.contrib.actions.sentry import SendSentryMessage 21 | except ImportError: 22 | raise ImportError("spidermon[sentry-sdk] is required for using SentryMixin.") 23 | self.sentry_handler = SendSentryMessage( 24 | fake=self.project_settings.getbool("SPIDERMON_SENTRY_FAKE"), 25 | sentry_dsn=self.project_settings.get("SPIDERMON_SENTRY_DSN"), 26 | sentry_log_level=self.project_settings.get("SPIDERMON_SENTRY_LOG_LEVEL"), 27 | project_name=self.project_settings.get("SPIDERMON_SENTRY_PROJECT_NAME"), 28 | environment=self.project_settings.get("SPIDERMON_SENTRY_ENVIRONMENT_TYPE"), 29 | ) 30 | self.register_sender_method(self.send_sentry_message) 31 | 32 | def send_sentry_message(self): 33 | if self.messages: 34 | message: Dict[str, Any] = dict() 35 | title = f"{self.sentry_handler.project_name} | {self.sentry_handler.environment} | {self.args.subject}" 36 | message["title"] = title 37 | message["failure_reasons"] = "/n".join(self.messages) 38 | job_key = resolve_shub_jobkey() 39 | if job_key: 40 | message["job_link"] = f"https://app.zyte.com/p/{job_key}" 41 | if self.sentry_handler.fake: 42 | message["failure_reasons"] = self.messages 43 | LOG.info(pformat(message)) 44 | else: 45 | self.sentry_handler.send_message(message) 46 | -------------------------------------------------------------------------------- /shub_workflow/contrib/slack.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pprint import pformat 3 | from typing import Any, Dict, List 4 | 5 | from shub_workflow.utils import resolve_shub_jobkey 6 | from shub_workflow.utils.alert_sender import AlertSenderMixin 7 | 8 | LOG = logging.getLogger(__name__) 9 | 10 | 11 | class SlackSender: 12 | def __init__(self, project_settings): 13 | super().__init__() 14 | try: 15 | from spidermon.contrib.actions.slack import SlackMessageManager 16 | except ImportError: 17 | raise ImportError("spidermon[slack-sdk] is required for using SlackSender") 18 | self.slack_handler = SlackMessageManager( 19 | fake=project_settings.getbool("SPIDERMON_SLACK_FAKE"), 20 | 
sender_token=project_settings.get("SPIDERMON_SLACK_SENDER_TOKEN"), 21 | sender_name=project_settings.get("SPIDERMON_SLACK_SENDER_NAME"), 22 | ) 23 | self.recipients = project_settings.get("SPIDERMON_SLACK_RECIPIENTS") 24 | 25 | def send_slack_messages(self, messages: List[str], subject: str, attachments=None): 26 | message: Dict[str, Any] = dict() 27 | title = f"{self.slack_handler.sender_name} | {subject}" 28 | message["title"] = title 29 | message["failure_reasons"] = "\n".join(messages) 30 | job_key = resolve_shub_jobkey() 31 | if job_key: 32 | message["job_link"] = f"https://app.zyte.com/p/{job_key}" 33 | text = ( 34 | f"{title}\n\n" 35 | "Alert Reasons:\n" 36 | f"{message['failure_reasons']}\n\n" 37 | f"Job Link: {message.get('job_link', 'N/A')}\n" 38 | "-------------------------------------" 39 | ) 40 | if self.slack_handler.fake: 41 | message["failure_reasons"] = messages 42 | LOG.info(pformat(message)) 43 | else: 44 | self.slack_handler.send_message(self.recipients, text, attachments=attachments) 45 | 46 | 47 | class SlackMixin(AlertSenderMixin): 48 | """ 49 | A class for adding slack alert capabilities to a shub_workflow class. 50 | """ 51 | 52 | def __init__(self): 53 | super().__init__() 54 | self.register_sender_method(self.send_slack_queued_messages) 55 | self.sender = SlackSender(self.project_settings) 56 | 57 | def send_slack_queued_messages(self): 58 | if self.messages: 59 | self.sender.send_slack_messages(self.messages, self.args.subject) 60 | -------------------------------------------------------------------------------- /shub_workflow/crawl.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base script class for spiders crawl managers. 3 | """ 4 | import abc 5 | import json 6 | import asyncio 7 | import logging 8 | from copy import deepcopy 9 | from argparse import Namespace 10 | from typing import Optional, List, Tuple, Dict, NewType, cast, Generator, Any, AsyncGenerator, Set, Protocol, TypedDict 11 | 12 | from bloom_filter2 import BloomFilter 13 | 14 | from shub_workflow.script import ( 15 | JobKey, 16 | JobDict, 17 | JobMeta, 18 | SpiderName, 19 | Outcome, 20 | BaseLoopScriptAsyncMixin, 21 | ) 22 | from shub_workflow.base import WorkFlowManager 23 | from shub_workflow.utils import hashstr 24 | from shub_workflow.utils.dupefilter import DupesFilterProtocol 25 | from shub_workflow.utils.monitor import SpiderStatsAggregatorMixin 26 | 27 | 28 | _LOG = logging.getLogger(__name__) 29 | 30 | 31 | SpiderArgs = NewType("SpiderArgs", Dict[str, str]) 32 | ScheduleArgs = NewType("ScheduleArgs", Dict[str, Any]) 33 | 34 | 35 | class JobParams(TypedDict, total=False): 36 | units: int 37 | tags: List[str] 38 | job_settings: Dict[str, str] 39 | project_id: int 40 | spider_args: Dict[str, str] 41 | 42 | 43 | class FullJobParams(JobParams, total=False): 44 | spider: SpiderName 45 | 46 | 47 | class StopRetry(Exception): 48 | pass 49 | 50 | 51 | def get_jobseq(tags: List[str]) -> Tuple[int, int]: 52 | """ 53 | returns tuple (job sequence, retry index) 54 | """ 55 | for tag in tags: 56 | if tag.startswith("JOBSEQ="): 57 | tag = tag.replace("JOBSEQ=", "") 58 | jobseq, *rep = tag.split(".r") 59 | rep.append("0") 60 | return int(jobseq), int(rep[0]) 61 | return 0, 0 62 | 63 | 64 | class CrawlManagerProtocol(Protocol): 65 | _running_job_keys: Dict[JobKey, Tuple[SpiderName, JobParams]] 66 | 67 | @abc.abstractmethod 68 | def check_running_jobs(self) -> Dict[JobKey, Outcome]: 69 | ... 
70 | 71 | @abc.abstractmethod 72 | def schedule_spider_with_jobargs( 73 | self, 74 | job_args_override: Optional[JobParams] = None, 75 | spider: Optional[SpiderName] = None, 76 | ) -> Optional[JobKey]: 77 | ... 78 | 79 | @abc.abstractmethod 80 | def get_job_settings(self, override: Optional[Dict[str, str]] = None) -> Dict[str, str]: 81 | ... 82 | 83 | 84 | class CrawlManager(SpiderStatsAggregatorMixin, WorkFlowManager, CrawlManagerProtocol): 85 | """ 86 | Schedules a single spider job. If loop mode is enabled, it will shutdown only after the scheduled spider 87 | finished. Close reason of the manager will be inherited from spider one. 88 | """ 89 | 90 | spider: Optional[SpiderName] = None 91 | 92 | def __init__(self): 93 | super().__init__() 94 | 95 | # running jobs represents the state of a crawl manager. 96 | # a dict job key : (spider, job_args_override) 97 | self._running_job_keys: Dict[JobKey, Tuple[SpiderName, JobParams]] = {} 98 | 99 | @property 100 | def description(self): 101 | return self.__doc__ 102 | 103 | def add_argparser_options(self): 104 | super().add_argparser_options() 105 | if self.spider is None: 106 | self.argparser.add_argument("spider", help="Spider name") 107 | self.argparser.add_argument("--spider-args", help="Spider arguments dict in json format", default="{}") 108 | self.argparser.add_argument("--job-settings", help="Job settings dict in json format", default="{}") 109 | self.argparser.add_argument("--units", help="Set default number of ScrapyCloud units for each job", type=int) 110 | 111 | def parse_args(self) -> Namespace: 112 | args = super().parse_args() 113 | if self.spider is None: 114 | self.spider = args.spider 115 | return args 116 | 117 | def get_job_settings(self, override: Optional[Dict[str, str]] = None) -> Dict[str, str]: 118 | job_settings = json.loads(self.args.job_settings) 119 | if override: 120 | job_settings.update(override) 121 | return job_settings 122 | 123 | def schedule_spider_with_jobargs( 124 | self, 125 | job_args_override: Optional[JobParams] = None, 126 | spider: Optional[SpiderName] = None, 127 | ) -> Optional[JobKey]: 128 | job_args_override = job_args_override or {} 129 | spider = spider or self.spider 130 | if spider is not None: 131 | schedule_args: ScheduleArgs = json.loads(self.args.spider_args) 132 | spider_args = job_args_override.get("spider_args") or {} 133 | schedule_args.update(job_args_override) 134 | schedule_args.pop("spider_args", None) 135 | schedule_args.update(spider_args) 136 | 137 | job_settings_override = schedule_args.get("job_settings", None) 138 | schedule_args["job_settings"] = self.get_job_settings(job_settings_override) 139 | schedule_args["units"] = schedule_args.get("units", self.args.units) 140 | jobkey = self.schedule_spider(spider, **schedule_args) 141 | if jobkey is not None: 142 | self._running_job_keys[jobkey] = spider, job_args_override 143 | self.stats.inc_value("crawlmanager/scheduled_jobs") 144 | return jobkey 145 | return None 146 | 147 | def check_running_jobs(self) -> Dict[JobKey, Outcome]: 148 | outcomes = {} 149 | running_job_keys = list(self._running_job_keys) 150 | while running_job_keys: 151 | jobkey = running_job_keys.pop() 152 | if (outcome := self.is_finished(jobkey)) is not None: 153 | 154 | metadata = self.get_job_metadata(jobkey) 155 | self.finished_metadata_hook(jobkey, metadata) 156 | 157 | spider, job_args_override = self._running_job_keys.pop(jobkey) 158 | if outcome in self.failed_outcomes: 159 | _LOG.warning(f"Job {jobkey} finished with outcome {outcome}.") 160 | if 
job_args_override is not None: 161 | job_args_override = job_args_override.copy() 162 | self.bad_outcome_hook(spider, outcome, job_args_override, jobkey) 163 | else: 164 | self.finished_ok_hook(spider, outcome, job_args_override, jobkey) 165 | outcomes[jobkey] = outcome 166 | 167 | scrapystats = self.get_metadata_key(metadata, "scrapystats") 168 | self.aggregate_spider_stats(JobDict(spider=spider, key=jobkey, scrapystats=scrapystats)) 169 | _LOG.info(f"There are {len(self._running_job_keys)} jobs still running.") 170 | 171 | return outcomes 172 | 173 | def finished_metadata_hook(self, jobkey: JobKey, metadata: JobMeta): 174 | """ 175 | allow to add some reaction on each finished job, based solely on its metadata. 176 | Use self.get_metadata_key(metadata, ) in order to get metadata with handled retries. 177 | """ 178 | 179 | def bad_outcome_hook(self, spider: SpiderName, outcome: Outcome, job_args_override: JobParams, jobkey: JobKey): 180 | if self.get_close_reason() is None: 181 | self.set_close_reason(outcome) 182 | 183 | def finished_ok_hook(self, spider: SpiderName, outcome: Outcome, job_args_override: JobParams, jobkey: JobKey): 184 | pass 185 | 186 | def workflow_loop(self) -> bool: 187 | outcomes = self.check_running_jobs() 188 | if outcomes: 189 | return False 190 | if not self._running_job_keys: 191 | self.schedule_spider_with_jobargs() 192 | return True 193 | 194 | def resume_running_job_hook(self, job: JobDict): 195 | key = job["key"] 196 | spider_args = job.get("spider_args", {}).copy() 197 | job_args_override = JobParams( 198 | { 199 | "tags": job["tags"], 200 | "spider_args": spider_args, 201 | } 202 | ) 203 | self._running_job_keys[key] = job["spider"], job_args_override 204 | _LOG.info(f"added running job {key}") 205 | 206 | def on_close(self): 207 | job = self.get_job() 208 | if job: 209 | close_reason = self.get_close_reason() 210 | self.finish(close_reason=close_reason) 211 | 212 | 213 | class PeriodicCrawlManager(CrawlManager): 214 | """ 215 | Schedule a spider periodically, waiting for the previous job to finish before scheduling it again with same 216 | parameters. Don't forget to set loop mode. 217 | """ 218 | 219 | def bad_outcome_hook(self, spider: str, outcome: Outcome, job_args_override: JobParams, jobkey: JobKey): 220 | pass 221 | 222 | def workflow_loop(self) -> bool: 223 | self.check_running_jobs() 224 | if not self._running_job_keys: 225 | self.schedule_spider_with_jobargs() 226 | return True 227 | 228 | def on_close(self): 229 | pass 230 | 231 | 232 | class GeneratorCrawlManagerProtocol(CrawlManagerProtocol, Protocol): 233 | 234 | _jobuids: DupesFilterProtocol 235 | spider: Optional[SpiderName] 236 | 237 | @abc.abstractmethod 238 | def _workflow_step_gen(self, max_next_params: int) -> Generator[Tuple[str, Optional[JobKey]], None, None]: 239 | ... 240 | 241 | @abc.abstractmethod 242 | def get_max_next_params(self) -> int: 243 | ... 244 | 245 | 246 | class GeneratorCrawlManager(CrawlManager, GeneratorCrawlManagerProtocol): 247 | """ 248 | Schedule a spider periodically, each time with different parameters yielded by a generator, until stop. 249 | Number of simultaneos spider jobs will be limited by max running jobs (see WorkFlowManager). 250 | Instructions: 251 | - Override set_parameters_gen() method. It must implement a generator of dictionaries, each one being 252 | the spider arguments (argument name: argument value) passed to the spider on each successive job. 253 | - Don't forget to set loop mode and max jobs (which defaults to infinite). 
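
    A minimal sketch of a concrete subclass (the manager name, spider name and the
    "country" spider argument below are hypothetical):

        class MyCrawlManager(GeneratorCrawlManager):
            name = "mycrawlmanager"
            spider = SpiderName("myspider")
            default_max_jobs = 4

            def set_parameters_gen(self):
                # schedule one job per country, each receiving its own spider argument
                for country in ("us", "uk", "de"):
                    yield {"country": country}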
254 | """ 255 | 256 | MAX_RETRIES = 0 257 | 258 | def __init__(self): 259 | super().__init__() 260 | self.__parameters_gen: Generator[ScheduleArgs, None, None] = self.set_parameters_gen() 261 | self.__additional_jobs: List[FullJobParams] = [] 262 | self.__delayed_jobs: List[FullJobParams] = [] 263 | self.__next_job_seq = 1 264 | self._jobuids = self.create_dupe_filter() 265 | 266 | def get_delayed_jobs(self) -> List[FullJobParams]: 267 | return deepcopy(self.__delayed_jobs) 268 | 269 | def add_delayed_jobs(self, params: FullJobParams): 270 | self.__delayed_jobs.append(params) 271 | 272 | @classmethod 273 | def create_dupe_filter(cls) -> DupesFilterProtocol: 274 | return BloomFilter(max_elements=1e6, error_rate=1e-6) 275 | 276 | def spider_running_count(self, spider: SpiderName) -> int: 277 | count = 0 278 | for spidername, _ in self._running_job_keys.values(): 279 | if spidername == spider: 280 | count += 1 281 | return count 282 | 283 | def bad_outcome_hook(self, spider: SpiderName, outcome: Outcome, job_args_override: JobParams, jobkey: JobKey): 284 | if outcome == "cancelled": 285 | return 286 | if self.MAX_RETRIES == 0: 287 | return 288 | spider_args = job_args_override.setdefault("spider_args", {}) 289 | tags = job_args_override["tags"] 290 | retries = get_jobseq(tags)[1] 291 | if retries < self.MAX_RETRIES: 292 | for t in list(tags): 293 | if t.startswith("RETRIED_FROM="): 294 | tags.remove(t) 295 | break 296 | tags.append(f"RETRIED_FROM={jobkey}") 297 | spider_args["retry_num"] = str(retries + 1) 298 | try: 299 | retry_override = self.get_retry_override(spider, outcome, job_args_override, jobkey) 300 | for rtag in retry_override.pop("tags", []): 301 | if rtag not in tags: 302 | tags.append(rtag) 303 | job_args_override.setdefault("spider_args", {}).update(retry_override.pop("spider_args", {})) 304 | job_args_override.setdefault("job_settings", {}).update(retry_override.pop("job_settings", {})) 305 | job_args_override.update(retry_override) 306 | except StopRetry as e: 307 | _LOG.info(f"Job {jobkey} failed with reason '{outcome}'. Will not be retried. Reason: {e}") 308 | else: 309 | self.add_job(spider, job_args_override) 310 | _LOG.info( 311 | f"Job {jobkey} failed with reason '{outcome}'. Retrying ({retries + 1} of {self.MAX_RETRIES})." 312 | ) 313 | 314 | def get_retry_override( 315 | self, spider: SpiderName, outcome: Outcome, job_args_override: JobParams, jobkey: JobKey 316 | ) -> JobParams: 317 | return JobParams({}) 318 | 319 | def add_job(self, spider: SpiderName, job_args_override: JobParams): 320 | params = cast(FullJobParams, (job_args_override or {}).copy()) 321 | params["spider"] = spider 322 | self.__additional_jobs.append(params) 323 | 324 | @abc.abstractmethod 325 | def set_parameters_gen(self) -> Generator[ScheduleArgs, None, None]: 326 | ... 
327 | 328 | def __add_jobseq_tag(self, params: FullJobParams): 329 | tags = params.setdefault("tags", []) 330 | jobseq_tag = None 331 | for tag in tags: 332 | if tag.startswith("JOBSEQ="): 333 | jobseq_tag = tag 334 | break 335 | if jobseq_tag is None: 336 | jobseq_tag = f"JOBSEQ={self.__next_job_seq:010d}" 337 | self.__next_job_seq += 1 338 | else: 339 | tags.remove(jobseq_tag) 340 | jobseq, repn = get_jobseq([jobseq_tag]) 341 | jobseq_tag = f"JOBSEQ={jobseq:010d}.r{repn + 1}" 342 | tags.append(jobseq_tag) 343 | 344 | def get_running_spiders(self) -> Set[SpiderName]: 345 | return set(sp for sp, _ in self._running_job_keys.values()) 346 | 347 | def get_delayed_spiders(self) -> Set[SpiderName]: 348 | return set(np["spider"] for np in self.__delayed_jobs) 349 | 350 | def spider_delayed_count(self, spider: SpiderName) -> int: 351 | count = 0 352 | for np in self.__delayed_jobs: 353 | if np["spider"] == spider: 354 | count += 1 355 | return count 356 | 357 | def can_schedule_job_with_params(self, params: FullJobParams) -> bool: 358 | """ 359 | If returns False, delay the scheduling of the params. 360 | """ 361 | return True 362 | 363 | def _fulljobparams_from_spiderargs(self, schedule_args: ScheduleArgs) -> FullJobParams: 364 | spider = schedule_args.get("spider", self.spider) 365 | spider_args = deepcopy(schedule_args) 366 | for key in tuple(FullJobParams.__annotations__): 367 | spider_args.pop(key, None) 368 | result = FullJobParams( 369 | { 370 | "spider": spider, 371 | "spider_args": spider_args, 372 | } 373 | ) 374 | if "units" in schedule_args: 375 | result["units"] = schedule_args["units"] 376 | if "tags" in schedule_args: 377 | result["tags"] = schedule_args["tags"] 378 | if "job_settings" in schedule_args: 379 | result["job_settings"] = schedule_args["job_settings"] 380 | if "project_id" in schedule_args: 381 | result["project_id"] = schedule_args["project_id"] 382 | return result 383 | 384 | def _workflow_step_gen(self, max_next_params: int) -> Generator[Tuple[str, Optional[JobKey]], None, None]: 385 | new_params: List[FullJobParams] = [] 386 | next_params: Optional[FullJobParams] 387 | 388 | while len(new_params) < max_next_params: 389 | next_params = None 390 | for idx, np in enumerate(self.__delayed_jobs): 391 | if self.can_schedule_job_with_params(np): 392 | next_params = np 393 | self.__delayed_jobs = self.__delayed_jobs[:idx] + self.__delayed_jobs[idx + 1:] 394 | break 395 | else: 396 | for idx, np in enumerate(self.__additional_jobs): 397 | if self.can_schedule_job_with_params(np): 398 | next_params = np 399 | self.__additional_jobs = self.__additional_jobs[:idx] + self.__additional_jobs[idx + 1:] 400 | break 401 | else: 402 | try: 403 | np = self._fulljobparams_from_spiderargs(next(self.__parameters_gen)) 404 | except StopIteration: 405 | break 406 | else: 407 | spider = np.get("spider", self.spider) 408 | assert spider, f"No spider set for parameters {np}" 409 | np["spider"] = spider 410 | if self.can_schedule_job_with_params(np): 411 | next_params = np 412 | else: 413 | self.__delayed_jobs.append(np) 414 | continue 415 | 416 | if next_params is None: 417 | break 418 | spider_args = next_params.get("spider_args") or {} 419 | jobuid = self.get_job_unique_id({"spider": next_params["spider"], "spider_args": spider_args}) 420 | if jobuid in self._jobuids: 421 | _LOG.warning(f"Job with parameters {next_params} was already scheduled. 
Skipped.") 422 | continue 423 | new_params.append(next_params) 424 | 425 | for next_params in new_params: 426 | spider = next_params.pop("spider") 427 | self.__add_jobseq_tag(next_params) 428 | yield jobuid, self.schedule_spider_with_jobargs(next_params, spider) 429 | 430 | def get_max_next_params(self) -> int: 431 | return self.max_running_jobs - len(self._running_job_keys) 432 | 433 | def workflow_loop(self) -> bool: 434 | self.check_running_jobs() 435 | max_next_params = self.get_max_next_params() 436 | retval = False 437 | for jobuid, jobid in self._workflow_step_gen(max_next_params): 438 | retval = True 439 | if jobid is not None: 440 | self._jobuids.add(jobuid) 441 | return retval or bool(self._running_job_keys) 442 | 443 | @staticmethod 444 | def get_job_unique_id(job: JobDict) -> str: 445 | jdict = job.get("spider_args", {}).copy() 446 | jdict["spider"] = job["spider"] 447 | for k, v in jdict.items(): 448 | jdict[k] = str(v) 449 | jid = json.dumps(jdict, sort_keys=True) 450 | return hashstr(jid) 451 | 452 | def resume_running_job_hook(self, job: JobDict): 453 | super().resume_running_job_hook(job) 454 | jobuid = self.get_job_unique_id(job) 455 | self._jobuids.add(jobuid) 456 | self.__next_job_seq = max(self.__next_job_seq, get_jobseq(job["tags"])[0] + 1) 457 | 458 | def resume_finished_job_hook(self, job: JobDict): 459 | jobuid = self.get_job_unique_id(job) 460 | if jobuid not in self._jobuids: 461 | self._jobuids.add(jobuid) 462 | self.__next_job_seq = max(self.__next_job_seq, get_jobseq(job["tags"])[0] + 1) 463 | 464 | def resume_workflow(self): 465 | super().resume_workflow() 466 | _LOG.info(f"Next job sequence number: {self.__next_job_seq}") 467 | 468 | def on_close(self): 469 | self._jobuids.close() 470 | 471 | 472 | class AsyncSchedulerCrawlManagerMixin(BaseLoopScriptAsyncMixin, GeneratorCrawlManagerProtocol): 473 | async def _async_workflow_step_gen(self, max_next_params: int) -> AsyncGenerator[Tuple[str, JobKey], None]: 474 | jobuids_cors = list(self._workflow_step_gen(max_next_params)) 475 | if jobuids_cors: 476 | jobuids, cors = zip(*jobuids_cors) 477 | results: List[JobKey] = await asyncio.gather(*cors) 478 | for jobuid, jobid in zip(jobuids, results): 479 | yield jobuid, jobid 480 | 481 | async def workflow_loop(self) -> bool: # type: ignore 482 | self.check_running_jobs() 483 | max_next_params = self.get_max_next_params() 484 | retval = False 485 | async for jobuid, jobid in self._async_workflow_step_gen(max_next_params): 486 | retval = True 487 | if jobid is not None: 488 | self._jobuids.add(jobuid) 489 | return retval or bool(self._running_job_keys) 490 | 491 | async def schedule_spider_with_jobargs( # type: ignore 492 | self, 493 | job_args_override: Optional[JobParams] = None, 494 | spider: Optional[SpiderName] = None, 495 | ) -> Optional[JobKey]: 496 | job_args_override = job_args_override or {} 497 | spider = spider or self.spider 498 | assert spider is not None 499 | schedule_args: ScheduleArgs = json.loads(self.args.spider_args) 500 | spider_args = job_args_override.get("spider_args") or {} 501 | schedule_args.update(job_args_override) 502 | schedule_args.pop("spider_args", None) 503 | schedule_args.update(spider_args) 504 | 505 | job_settings_override = schedule_args.get("job_settings", None) 506 | schedule_args["job_settings"] = self.get_job_settings(job_settings_override) 507 | schedule_args["units"] = schedule_args.get("units", self.args.units) 508 | jobkey = await self.async_schedule_spider(spider, **schedule_args) 509 | if jobkey is not None: 510 | 
self._running_job_keys[jobkey] = spider, job_args_override 511 | self.stats.inc_value("crawlmanager/scheduled_jobs") 512 | return jobkey 513 | -------------------------------------------------------------------------------- /shub_workflow/deliver/__init__.py: -------------------------------------------------------------------------------- 1 | from shub_workflow.deliver.base import BaseDeliverScript 2 | 3 | __all__ = ("BaseDeliverScript",) 4 | -------------------------------------------------------------------------------- /shub_workflow/deliver/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import time 3 | import asyncio 4 | import logging 5 | from collections import defaultdict 6 | from typing import Generator, List, Tuple, Protocol, Union, Type, Dict 7 | 8 | from scrapinghub.client.jobs import Job 9 | from scrapy import Item 10 | 11 | from shub_workflow.script import BaseLoopScript, BaseScriptProtocol 12 | from shub_workflow.utils.dupefilter import SqliteDictDupesFilter, DupesFilterProtocol 13 | 14 | _LOG = logging.getLogger(__name__) 15 | 16 | 17 | class DeliverScriptProtocol(BaseScriptProtocol, Protocol): 18 | 19 | DELIVERED_TAG: str 20 | SCRAPERNAME_NARGS: Union[str, int] 21 | 22 | @abc.abstractmethod 23 | def get_delivery_spider_jobs( 24 | self, 25 | scrapername: str, 26 | target_tags: List[str], 27 | only_finished: bool = True, 28 | ) -> Generator[Job, None, None]: 29 | ... 30 | 31 | @abc.abstractmethod 32 | def has_delivery_running_spider_jobs(self, scrapername: str, target_tags: List[str]) -> bool: 33 | ... 34 | 35 | @abc.abstractmethod 36 | def on_item(self, item: Item, scrapername: str): 37 | ... 38 | 39 | 40 | class BaseDeliverScript(BaseLoopScript, DeliverScriptProtocol): 41 | 42 | DELIVERED_TAG = "delivered" 43 | SCRAPERNAME_NARGS: Union[str, int] = "+" 44 | 45 | # print log every given items processed 46 | LOG_EVERY = 1000 47 | 48 | # minimal run time in order to ensure target jobs started to be scheduled 49 | MIN_RUN_TIME = 30 50 | 51 | # define here the fields used to deduplicate items. All them compose the dedupe key. 52 | # target item values must be strings. 53 | # for changing behavior, override is_seen_item() 54 | DEDUPE_KEY_BY_FIELDS: Tuple[str, ...] = () 55 | 56 | MAX_PROCESSED_ITEMS = float("inf") 57 | 58 | SEEN_ITEMS_CLASS: Type[DupesFilterProtocol] = SqliteDictDupesFilter 59 | strict_max_time = False 60 | 61 | def __init__(self): 62 | super().__init__() 63 | self._all_jobs_to_tag = [] 64 | self.total_items_count = 0 65 | self.total_dupe_filtered_items_count = 0 66 | self.seen_items: DupesFilterProtocol = self.SEEN_ITEMS_CLASS() 67 | self.seen_fields: Dict[str, int] = defaultdict(int) 68 | self.start_time = time.time() 69 | 70 | def add_argparser_options(self): 71 | super().add_argparser_options() 72 | self.argparser.add_argument("scrapername", help="Target scraper names", nargs=self.SCRAPERNAME_NARGS) 73 | self.argparser.add_argument( 74 | "--test-mode", 75 | action="store_true", 76 | help="Run in test mode (performs all processes, but doesn't upload files nor consumes jobs)", 77 | ) 78 | 79 | def get_target_tags(self) -> List[str]: 80 | """ 81 | Return here additional tags (aside from FLOW_ID, which is automatically handled) 82 | that jobs must have in order to be included in delivery. 
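
        A minimal override sketch (the tag value is hypothetical):

            def get_target_tags(self) -> List[str]:
                return ["CLIENT=acme"]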
83 | """ 84 | return [] 85 | 86 | def get_delivery_spider_jobs( 87 | self, scrapername: str, target_tags: List[str], only_finished: bool = True 88 | ) -> Generator[Job, None, None]: 89 | """ 90 | Select which jobs will be processed by the current deliver script instance. 91 | """ 92 | if self.flow_id: 93 | flow_id_tag = [f"FLOW_ID={self.flow_id}"] 94 | target_tags = flow_id_tag + target_tags 95 | state = ["finished"] 96 | if not only_finished: 97 | state.append("running") 98 | yield from self.get_jobs_with_tags(scrapername, target_tags, state=state, lacks_tag=[self.DELIVERED_TAG]) 99 | 100 | def has_delivery_running_spider_jobs(self, scrapername: str, target_tags: List[str]) -> bool: 101 | """ 102 | While this method returns True, the delivery script continue looping (provided loop_mode is set) 103 | """ 104 | if self.flow_id: 105 | flow_id_tag = [f"FLOW_ID={self.flow_id}"] 106 | target_tags = flow_id_tag + target_tags 107 | for sj in self.get_jobs_with_tags(scrapername, target_tags, state=["running", "pending"]): 108 | if sj.key not in self._all_jobs_to_tag: 109 | return True 110 | return False 111 | 112 | def process_spider_jobs(self, scrapername: str, only_finished: bool = True) -> bool: 113 | if self.total_items_count >= self.MAX_PROCESSED_ITEMS: 114 | return False 115 | if self.strict_max_time and self.is_max_time_ran_out(): 116 | return False 117 | target_tags = self.get_target_tags() 118 | for sj in self.get_delivery_spider_jobs(scrapername, target_tags, only_finished): 119 | if sj.key in self._all_jobs_to_tag: 120 | continue 121 | self.process_job_items(scrapername, sj) 122 | if not self.args.test_mode: 123 | self._all_jobs_to_tag.append(sj.key) 124 | if self.total_items_count >= self.MAX_PROCESSED_ITEMS: 125 | return False 126 | if self.strict_max_time and self.is_max_time_ran_out(): 127 | return False 128 | return True 129 | 130 | def get_item_unique_key(self, item: Item) -> str: 131 | assert all(isinstance(item[f], str) for f in self.DEDUPE_KEY_BY_FIELDS) 132 | key = tuple(item[f] for f in self.DEDUPE_KEY_BY_FIELDS) 133 | return ",".join(key) 134 | 135 | def is_seen_item(self, item: Item) -> bool: 136 | key = self.get_item_unique_key(item) 137 | return key != "" and key in self.seen_items 138 | 139 | def add_seen_item(self, item: Item): 140 | key = self.get_item_unique_key(item) 141 | if key: 142 | self.seen_items.add(key) 143 | 144 | def process_job_items(self, scrapername: str, spider_job: Job): 145 | idx = -1 146 | try: 147 | for item in spider_job.items.iter(): 148 | if self.is_seen_item(item): 149 | self.total_dupe_filtered_items_count += 1 150 | else: 151 | self.on_item(item, scrapername) 152 | self.add_seen_item(item) 153 | for key, value in item.items(): 154 | if value: 155 | self.seen_fields[key] += 1 156 | self.total_items_count += 1 157 | if self.total_items_count % self.LOG_EVERY == 0: 158 | _LOG.info(f"Processed {self.total_items_count} items.") 159 | idx += 1 160 | except UnicodeDecodeError as e: 161 | _LOG.error(f"Exception while decoding item {spider_job.key}/{idx + 1}: {e}") 162 | 163 | def workflow_loop(self) -> bool: 164 | for scrapername in self.args.scrapername: 165 | _LOG.info(f"Processing spider {scrapername}") 166 | if not self.process_spider_jobs(scrapername): 167 | return False 168 | if self.loop_mode: 169 | target_tags = self.get_target_tags() 170 | if self.has_delivery_running_spider_jobs(scrapername, target_tags): 171 | return True 172 | if time.time() - self.start_time < self.MIN_RUN_TIME: 173 | return True 174 | return False 175 | 176 | def 
close_files(self): 177 | pass 178 | 179 | def on_close(self): 180 | if self.loop_mode: 181 | for scrapername in self.args.scrapername: 182 | _LOG.info(f"Processing remaining spider {scrapername}") 183 | self.process_spider_jobs(scrapername, only_finished=False) 184 | self.close_files() 185 | jobs_count = len(self._all_jobs_to_tag) 186 | _LOG.info(f"Processed a total of {jobs_count} jobs.") 187 | _LOG.info(f"Processed a total of {self.total_items_count} items.") 188 | if self.DEDUPE_KEY_BY_FIELDS: 189 | _LOG.info(f"A total of {self.total_dupe_filtered_items_count} items were duplicated.") 190 | asyncio.run(self._tag_all()) 191 | if hasattr(self.seen_items, "close"): 192 | self.seen_items.close() 193 | for key, value in self.seen_fields.items(): 194 | self.stats.set_value(f"delivery/fields/count/{key}", value) 195 | self.stats.set_value("delivery/items/count", self.total_items_count) 196 | self.stats.set_value("delivery/items/duplicated", self.total_dupe_filtered_items_count) 197 | self.stats.set_value("delivery/jobs/count", jobs_count) 198 | 199 | async def _tag_all(self): 200 | while self._all_jobs_to_tag: 201 | to_tag, self._all_jobs_to_tag = self._all_jobs_to_tag[:1000], self._all_jobs_to_tag[1000:] 202 | cors = [self.async_add_job_tags(jkey, tags=[self.DELIVERED_TAG]) for jkey in to_tag] 203 | await asyncio.gather(*cors) 204 | _LOG.info("Marked %d jobs as delivered", len(to_tag)) 205 | -------------------------------------------------------------------------------- /shub_workflow/graph/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Meta manager. Defines complex workflow in terms of lower level managers 3 | 4 | For usage example see tests 5 | 6 | """ 7 | import re 8 | import logging 9 | from time import time 10 | from argparse import Namespace 11 | from typing import NewType, Dict, List, Optional, Set, Tuple, DefaultDict 12 | 13 | from collections import defaultdict, OrderedDict 14 | from copy import copy, deepcopy 15 | 16 | from typing_extensions import TypedDict, NotRequired 17 | import yaml 18 | 19 | try: 20 | from yaml import CLoader 21 | 22 | OLD_YAML = True 23 | except ImportError: 24 | from yaml import Loader 25 | 26 | OLD_YAML = False 27 | 28 | 29 | from shub_workflow.script import JobKey, JobDict, Outcome 30 | from shub_workflow.base import WorkFlowManager 31 | from shub_workflow.graph.task import ( 32 | JobGraphDict, 33 | TaskId, 34 | BaseTask, 35 | Resource, 36 | ResourceAmmount, 37 | OnFinishKey, 38 | OnFinishTarget, 39 | ) 40 | 41 | 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | _STARTING_JOB_RE = re.compile("--starting-job(?:=(.+))?") 46 | JobsGraphs = NewType("JobsGraphs", Dict[TaskId, JobGraphDict]) 47 | 48 | 49 | class PendingJobDict(TypedDict): 50 | wait_for: Set[TaskId] 51 | is_retry: bool 52 | wait_time: Optional[int] 53 | origin: NotRequired[TaskId] 54 | 55 | 56 | class GraphManager(WorkFlowManager): 57 | 58 | jobs_graph: JobsGraphs = JobsGraphs({}) 59 | 60 | def __init__(self): 61 | # Ensure jobs are traversed in the same order as they went pending. 
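        # Bookkeeping: pending jobs waiting to be scheduled, running jobs keyed by task id,
        # completed jobs with their outcomes, plus resource accounting and the task registry
        # filled from configure_workflow(). Plain dicts already preserve insertion order on
        # the supported Python versions, but OrderedDict keeps the FIFO intent explicit.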
62 | self.__pending_jobs: Dict[TaskId, PendingJobDict] = OrderedDict() 63 | self.__running_jobs: Dict[TaskId, JobKey] = OrderedDict() 64 | self.__completed_jobs: Dict[TaskId, Tuple[JobKey, Outcome]] = dict() 65 | self._available_resources: Dict[Resource, ResourceAmmount] = {} # map resource : ammount 66 | self._acquired_resources: DefaultDict[Resource, List[Tuple[TaskId, ResourceAmmount]]] = defaultdict(list) 67 | self.__tasks: Dict[TaskId, BaseTask] = {} 68 | super(GraphManager, self).__init__() 69 | self.__start_time: DefaultDict[TaskId, float] = defaultdict(time) 70 | self.__starting_jobs: List[TaskId] = self.args.starting_job 71 | for task in self.configure_workflow() or (): 72 | if self.args.root_jobs: 73 | self.__starting_jobs.append(task.task_id) 74 | self._add_task(task) 75 | 76 | @property 77 | def description(self): 78 | return f"Workflow manager for {self.name!r}" 79 | 80 | def _add_task(self, task: BaseTask): 81 | assert task.task_id not in self.jobs_graph, ( 82 | "Workflow inconsistency detected: task %s referenced twice." % task.task_id 83 | ) 84 | self.jobs_graph[task.task_id] = task.as_jobgraph_dict() 85 | self.__tasks[task.task_id] = task 86 | for ntask in task.get_next_tasks(): 87 | self._add_task(ntask) 88 | 89 | def get_task(self, task_id: TaskId) -> BaseTask: 90 | return self.__tasks[task_id] 91 | 92 | def configure_workflow(self) -> Tuple[BaseTask, ...]: 93 | raise NotImplementedError("configure_workflow() method need to be implemented.") 94 | 95 | def on_start(self): 96 | if not self.jobs_graph: 97 | self.argparser.error("Jobs graph configuration is empty.") 98 | if not self.__starting_jobs: 99 | self.argparser.error("You must provide either --starting-job or --root-jobs.") 100 | self._fill_available_resources() 101 | self._setup_starting_jobs() 102 | self.workflow_loop_enabled = True 103 | logger.info("Starting '%s' workflow", self.name) 104 | 105 | def _setup_starting_jobs(self) -> None: 106 | for taskid in self.__starting_jobs: 107 | wait_for: List[TaskId] = self.get_jobdict(taskid).get("wait_for", []) 108 | self._add_pending_job(taskid, wait_for=tuple(wait_for)) 109 | 110 | initial_pending_jobs: Dict[TaskId, PendingJobDict] = OrderedDict() 111 | while initial_pending_jobs != self.__pending_jobs: 112 | initial_pending_jobs = deepcopy(self.__pending_jobs) 113 | for taskid in list(self.__pending_jobs.keys()): 114 | if taskid in self.__completed_jobs: 115 | jobid, outcome = self.__completed_jobs[taskid] 116 | self.__pending_jobs.pop(taskid) 117 | self._check_completed_job(taskid, jobid, outcome) 118 | elif taskid in self.__running_jobs: 119 | self.__pending_jobs.pop(taskid) 120 | 121 | def _fill_available_resources(self): 122 | """ 123 | Ensure there are enough starting resources in order every job 124 | can run at some point 125 | """ 126 | for job in self.jobs_graph.keys(): 127 | for required_resources in self.__tasks[job].get_required_resources(): 128 | for resource, req_amount in required_resources.items(): 129 | old_amount = self._available_resources.get(resource, 0) 130 | if old_amount < req_amount: 131 | logger.info( 132 | "Increasing available resources count for %r" 133 | " from %r to %r. Old value was not enough" 134 | " for job %r to run.", 135 | resource, 136 | old_amount, 137 | req_amount, 138 | job, 139 | ) 140 | self._available_resources[resource] = req_amount 141 | 142 | def get_jobdict(self, job: TaskId, pop=False) -> JobGraphDict: 143 | if job not in self.jobs_graph: 144 | self.argparser.error("Invalid job: %s. 
Available jobs: %s" % (job, repr(self.jobs_graph.keys()))) 145 | if pop: 146 | return self.jobs_graph.pop(job) 147 | return self.jobs_graph[job] 148 | 149 | def _add_pending_job(self, job: TaskId, wait_for=(), is_retry=False): 150 | self._maybe_add_on_finish_default(job, is_retry) 151 | if job in self.args.skip_job: 152 | return 153 | if job in self.__tasks: 154 | task = self.__tasks[job] 155 | parallelization = task.get_parallel_jobs() 156 | else: 157 | task_id = self.get_jobdict(job).get("origin", job) 158 | task = self.__tasks[task_id] 159 | parallelization = 1 160 | if parallelization == 1: 161 | self.__pending_jobs[job] = { 162 | "wait_for": set(wait_for), 163 | "is_retry": is_retry, 164 | "wait_time": task.wait_time, 165 | } 166 | else: 167 | # Split parallelized task into N parallel jobs. 168 | basejobconf = self.get_jobdict(job, pop=True) 169 | for i in range(parallelization): 170 | subtask_id = TaskId("%s.%i" % (job, i)) 171 | subtask_conf = deepcopy(basejobconf) 172 | subtask_conf["origin"] = job 173 | subtask_conf["index"] = i 174 | 175 | for _, nextjobs in subtask_conf["on_finish"].items(): 176 | if i != 0: # only job 0 will conserve finish targets 177 | for nextjob in copy(nextjobs): 178 | if nextjob != "retry": 179 | if nextjob in self.jobs_graph: 180 | self.get_jobdict(nextjob)["wait_for"].append(subtask_id) 181 | if nextjob in self.__pending_jobs: 182 | self.__pending_jobs[nextjob]["wait_for"].add(subtask_id) 183 | else: 184 | for i in range(parallelization): 185 | nextjobp = TaskId("%s.%i" % (job, i)) 186 | self.get_jobdict(nextjobp)["wait_for"].append(subtask_id) 187 | if nextjobp in self.__pending_jobs: 188 | self.__pending_jobs[nextjobp]["wait_for"].add(subtask_id) 189 | nextjobs.remove(nextjob) 190 | self.jobs_graph[subtask_id] = subtask_conf 191 | self.__pending_jobs[subtask_id] = { 192 | "wait_for": set(wait_for), 193 | "is_retry": is_retry, 194 | "origin": job, 195 | "wait_time": task.wait_time, 196 | } 197 | for other, oconf in self.jobs_graph.items(): 198 | if job in oconf.get("wait_for", []): 199 | oconf["wait_for"].remove(job) 200 | if other in self.__pending_jobs: 201 | self.__pending_jobs[other]["wait_for"].discard(job) 202 | for i in range(parallelization): 203 | subtask_id = TaskId("%s.%i" % (job, i)) 204 | oconf["wait_for"].append(subtask_id) 205 | if other in self.__pending_jobs: 206 | self.__pending_jobs[other]["wait_for"].add(subtask_id) 207 | 208 | def add_argparser_options(self): 209 | super(GraphManager, self).add_argparser_options() 210 | self.argparser.add_argument("--jobs-graph", help="Define jobs graph_dict on command line", default="{}") 211 | self.argparser.add_argument( 212 | "--starting-job", 213 | "-s", 214 | action="append", 215 | default=[], 216 | help="Set starting jobs. Can be given multiple times.", 217 | ) 218 | self.argparser.add_argument( 219 | "--only-starting-jobs", 220 | action="store_true", 221 | help="If given, only run the starting jobs (don't follow on finish next jobs)", 222 | ) 223 | self.argparser.add_argument( 224 | "--comment", 225 | help="Can be used for differentiate command line and avoid scheduling " 226 | "fail when a graph manager job is scheduled when another one with same option " 227 | "signature is running. Doesn't do anything else.", 228 | ) 229 | self.argparser.add_argument( 230 | "--skip-job", 231 | default=[], 232 | action="append", 233 | help="Skip given job. Can be given multiple times. 
Also next jobs for the skipped" "one will be skipped.", 234 | ) 235 | self.argparser.add_argument("--root-jobs", action="store_true", help="Set root jobs as starting jobs.") 236 | 237 | def parse_args(self) -> Namespace: 238 | args = super(GraphManager, self).parse_args() 239 | if OLD_YAML: 240 | self.jobs_graph = yaml.load(args.jobs_graph, Loader=CLoader) or deepcopy(self.jobs_graph) 241 | else: 242 | self.jobs_graph = yaml.load(args.jobs_graph, Loader=Loader) or deepcopy(self.jobs_graph) 243 | 244 | if args.starting_job and args.root_jobs: 245 | self.argparser.error("You can't provide both --starting-job and --root-jobs.") 246 | return args 247 | 248 | def workflow_loop(self) -> bool: 249 | logger.debug("Pending jobs: %r", self.__pending_jobs) 250 | logger.debug("Running jobs: %r", self.__running_jobs) 251 | logger.debug("Available resources: %r", self._available_resources) 252 | logger.debug("Acquired resources: %r", self._acquired_resources) 253 | self.check_running_jobs() 254 | if self.__pending_jobs: 255 | self.run_pending_jobs() 256 | elif not self.__running_jobs: 257 | return False 258 | return True 259 | 260 | def run_job(self, job: TaskId, is_retry=False) -> Optional[JobKey]: 261 | task = self.__tasks.get(job) 262 | if task is not None: 263 | return task.run(self, is_retry) 264 | 265 | jobconf = self.get_jobdict(job) 266 | task = self.__tasks.get(jobconf["origin"]) 267 | if task is not None: 268 | idx = jobconf["index"] 269 | return task.run(self, is_retry, index=idx) 270 | raise RuntimeError(f"Failed to run task {job}") 271 | 272 | def _must_wait_time(self, job: TaskId) -> bool: 273 | status = self.__pending_jobs[job] 274 | if status["wait_time"] is not None: 275 | wait_time = status["wait_time"] - time() + self.__start_time[job] 276 | if wait_time > 0: 277 | logger.info("Job %s must wait %d seconds for running", job, wait_time) 278 | return True 279 | return False 280 | 281 | def run_pending_jobs(self): 282 | """Try running pending jobs. 283 | 284 | Normally, only jobs that have no outstanding dependencies are started. 285 | 286 | If all pending jobs have outstanding dependencies, try to start one job 287 | ignoring unknown tasks, i.e. those that are not currently pending. 288 | 289 | If none of the pending jobs cannot be started either way, it means 290 | there's a dependency cycle, in this case an error is raised. 291 | 292 | """ 293 | 294 | # Normal mode: start jobs without dependencies. 295 | for task_id in sorted(self.__pending_jobs.keys()): 296 | if len(self.__running_jobs) >= self.max_running_jobs: 297 | break 298 | status = self.__pending_jobs[task_id] 299 | 300 | job_can_run = ( 301 | not status["wait_for"] and not self._must_wait_time(task_id) and self._try_acquire_resources(task_id) 302 | ) 303 | if job_can_run: 304 | try: 305 | jobid = self.run_job(task_id, status["is_retry"]) 306 | assert jobid is not None, f"Failed to run task {task_id}" 307 | except Exception: 308 | self._release_resources(task_id) 309 | raise 310 | self.__pending_jobs.pop(task_id) 311 | self.__running_jobs[task_id] = jobid 312 | 313 | if ( 314 | not self.__pending_jobs 315 | or self.__running_jobs 316 | or any(status["wait_time"] is not None for status in self.__pending_jobs.values()) 317 | ): 318 | return 319 | 320 | # At this point, there are pending jobs, but none were started because 321 | # of dependencies, try "skip unknown deps" mode: start one job that 322 | # only has "unseen" dependencies to try to break the "stalemate." 
323 | origin_job = None 324 | for task_id in sorted(self.__pending_jobs.keys()): 325 | if len(self.__running_jobs) >= self.max_running_jobs: 326 | break 327 | status = self.__pending_jobs[task_id] 328 | job_can_run = ( 329 | all(w not in self.__pending_jobs for w in status["wait_for"]) 330 | and (not origin_job or status.get("origin") == origin_job) 331 | and self._try_acquire_resources(task_id) 332 | ) 333 | origin_job = status.get("origin") 334 | if job_can_run: 335 | try: 336 | jobid = self.run_job(task_id, status["is_retry"]) 337 | assert jobid is not None, f"Failed to run task {task_id}" 338 | except Exception: 339 | self._release_resources(task_id) 340 | raise 341 | self.__pending_jobs.pop(task_id) 342 | self.__running_jobs[task_id] = jobid 343 | if not origin_job and self.__running_jobs: 344 | return 345 | 346 | if self.__running_jobs: 347 | return 348 | 349 | # Nothing helped, all pending jobs wait for each other somehow. 350 | raise RuntimeError( 351 | "Job dependency cycle detected: %s" 352 | % ", ".join( 353 | "%s waits for %s" % (task_id, sorted(self.__pending_jobs[task_id]["wait_for"])) 354 | for task_id in sorted(self.__pending_jobs.keys()) 355 | ) 356 | ) 357 | 358 | def get_running_jobid(self, task_id: TaskId) -> JobKey: 359 | return self.__running_jobs[task_id] 360 | 361 | def handle_retry(self, job: TaskId, outcome: str) -> bool: 362 | jobconf = self.get_jobdict(job) 363 | retries = jobconf.get("retries", 0) 364 | if retries > 0: 365 | self._add_pending_job(job, is_retry=True) 366 | self.__completed_jobs.pop(job, None) 367 | jobconf["retries"] -= 1 368 | logger.warning( 369 | "Will retry job %s (outcome: %s, number of retries left: %s)", 370 | job, 371 | outcome, 372 | jobconf["retries"], 373 | ) 374 | return True 375 | logger.warning("No more retries for failed job %s (outcome: %s)", job, outcome) 376 | return False 377 | 378 | def check_running_jobs(self) -> None: 379 | for task_id, jobid in list(self.__running_jobs.items()): 380 | outcome = self.is_finished(jobid) 381 | if outcome is not None: 382 | self._check_completed_job(task_id, jobid, outcome) 383 | self.__running_jobs.pop(task_id) 384 | else: 385 | logger.info("Job %s (%s) still running", task_id, jobid) 386 | 387 | def _check_completed_job(self, task_id: TaskId, jobid: JobKey, outcome: Outcome): 388 | will_retry = False 389 | logger.info('Job "%s/%s" (%s) finished', self.name, task_id, jobid) 390 | for nextjob in self._get_next_jobs(task_id, outcome): 391 | if nextjob == "retry": 392 | will_retry = self.handle_retry(task_id, outcome) 393 | if not will_retry: 394 | self.bad_outcome_hook(task_id, jobid, outcome) 395 | elif nextjob in self.__pending_jobs: 396 | logger.error("Job %s already pending", nextjob) 397 | else: 398 | wait_for = self.get_jobdict(nextjob).get("wait_for", []) 399 | self._add_pending_job(nextjob, wait_for) 400 | self._release_resources(task_id) 401 | if not will_retry: 402 | self.__completed_jobs[task_id] = jobid, outcome 403 | for st in self.__pending_jobs.values(): 404 | st["wait_for"].discard(task_id) 405 | for conf in self.jobs_graph.values(): 406 | if task_id in conf.get("wait_for", []): 407 | conf["wait_for"].remove(task_id) 408 | 409 | def bad_outcome_hook(self, task_id: TaskId, jobid: JobKey, outcome: Outcome): 410 | pass 411 | 412 | def _try_acquire_resources(self, job: TaskId) -> bool: 413 | result = True 414 | task_id = self.get_jobdict(job).get("origin", job) 415 | for required_resources in self.__tasks[task_id].get_required_resources(partial=True): 416 | for resource, req_amount 
in required_resources.items(): 417 | if self._available_resources[resource] < req_amount: 418 | result = False 419 | break 420 | else: 421 | for resource, req_amount in required_resources.items(): 422 | self._available_resources[resource] -= req_amount 423 | self._acquired_resources[resource].append((job, req_amount)) 424 | return True 425 | return result 426 | 427 | def _release_resources(self, job: TaskId): 428 | for res, acquired in self._acquired_resources.items(): 429 | for rjob, res_amount in acquired: 430 | if rjob == job: 431 | self._available_resources[res] += res_amount 432 | self._acquired_resources[res].remove((rjob, res_amount)) 433 | 434 | def _maybe_add_on_finish_default(self, job: TaskId, is_retry: bool) -> Dict[OnFinishKey, OnFinishTarget]: 435 | on_finish = self.get_jobdict(job)["on_finish"] 436 | task = self.__tasks.get(job) 437 | if task is not None and not task.is_locked: 438 | task.get_start_callback()(self, is_retry) 439 | task.set_is_locked() 440 | for t in task.get_next_tasks(): 441 | on_finish["default"].append(t.task_id) 442 | if t.task_id not in self.jobs_graph: 443 | self._add_task(t) 444 | 445 | return on_finish 446 | 447 | def _get_next_jobs(self, job: TaskId, outcome: Outcome) -> OnFinishTarget: 448 | if self.args.only_starting_jobs: 449 | return [] 450 | on_finish = self.get_jobdict(job)["on_finish"] 451 | if outcome in on_finish: 452 | nextjobs = on_finish[outcome] 453 | elif outcome in self.failed_outcomes: 454 | nextjobs = on_finish.get("failed", []) 455 | else: 456 | nextjobs = on_finish["default"] 457 | 458 | return nextjobs 459 | 460 | @property 461 | def pending_jobs(self) -> Dict[TaskId, PendingJobDict]: 462 | return self.__pending_jobs 463 | 464 | def get_job_taskid(self, job: JobDict) -> Optional[TaskId]: 465 | task_id = self.get_keyvalue_job_tag("TASK_ID", job["tags"]) 466 | if task_id is not None: 467 | return TaskId(task_id) 468 | return None 469 | 470 | def resume_running_job_hook(self, job: JobDict): 471 | task_id = self.get_job_taskid(job) 472 | if task_id is not None: 473 | self.__running_jobs[task_id] = job["key"] 474 | origin_task_id = TaskId(task_id.rsplit(".", 1)[0]) 475 | self.__tasks[origin_task_id].append_jobid(job["key"]) 476 | 477 | def resume_finished_job_hook(self, job: JobDict): 478 | task_id = self.get_job_taskid(job) 479 | if task_id is not None: 480 | self.__completed_jobs[task_id] = job["key"], Outcome(job["close_reason"]) 481 | origin_task_id = TaskId(task_id.rsplit(".", 1)[0]) 482 | self.__tasks[origin_task_id].append_jobid(job["key"]) 483 | -------------------------------------------------------------------------------- /shub_workflow/graph/task.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import shlex 3 | import abc 4 | from fractions import Fraction 5 | from typing import NewType, List, Dict, Optional, Union, Literal, Callable, Protocol 6 | from typing_extensions import TypedDict, NotRequired 7 | 8 | from jinja2 import Template 9 | 10 | from shub_workflow.script import JobKey 11 | from shub_workflow.base import WorkFlowManagerProtocol, Outcome 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | Resource = NewType("Resource", str) 17 | ResourceAmmount = Union[int, Fraction] 18 | ResourcesDict = NewType("ResourcesDict", Dict[Resource, ResourceAmmount]) 19 | TaskId = NewType("TaskId", str) 20 | OnFinishKey = Union[Outcome, Literal["default", "failed"]] 21 | OnFinishTarget = List[Union[Literal["retry"], TaskId]] 22 | 23 | 24 | class 
GraphManagerProtocol(WorkFlowManagerProtocol, Protocol): 25 | @abc.abstractmethod 26 | def get_task(self, task_id: TaskId) -> "BaseTask": 27 | ... 28 | 29 | 30 | class JobGraphDict(TypedDict): 31 | tags: Optional[List[str]] 32 | units: Optional[int] 33 | on_finish: Dict[OnFinishKey, OnFinishTarget] 34 | wait_for: List[TaskId] 35 | retries: int 36 | wait_time: Optional[int] 37 | project_id: Optional[int] 38 | 39 | command: NotRequired[List[str]] 40 | init_args: NotRequired[List[str]] 41 | retry_args: NotRequired[List[str]] 42 | 43 | origin: NotRequired[TaskId] 44 | index: NotRequired[int] 45 | 46 | spider: NotRequired[str] 47 | spider_args: NotRequired[Dict[str, str]] 48 | 49 | 50 | class BaseTask(abc.ABC): 51 | def __init__( 52 | self, 53 | task_id: TaskId, 54 | tags: Optional[List[str]] = None, 55 | units: Optional[int] = None, 56 | retries: int = 1, 57 | project_id: Optional[int] = None, 58 | wait_time: Optional[int] = None, 59 | on_finish: Optional[Dict[OnFinishKey, OnFinishTarget]] = None, 60 | ): 61 | assert task_id != "retry", "Reserved word 'retry' can't be used as task id" 62 | self.task_id = task_id 63 | self.tags = tags 64 | self.units = units 65 | self.retries = retries 66 | self.project_id = project_id 67 | self.wait_time = wait_time 68 | self.on_finish: Dict[OnFinishKey, OnFinishTarget] = on_finish or {} 69 | 70 | self.__is_locked: bool = False 71 | self.__next_tasks: List[BaseTask] = [] 72 | self.__wait_for: List[BaseTask] = [] 73 | self.__required_resources: List[ResourcesDict] = [] 74 | self.__start_callback: Callable[[GraphManagerProtocol, bool], None] 75 | 76 | self.__job_ids: List[JobKey] = [] 77 | 78 | self.set_start_callback(self._default_start_callback) 79 | 80 | def set_is_locked(self): 81 | self.__is_locked = True 82 | 83 | @property 84 | def is_locked(self) -> bool: 85 | return self.__is_locked 86 | 87 | def append_jobid(self, jobid: JobKey): 88 | self.__job_ids.append(jobid) 89 | 90 | def get_scheduled_jobs(self) -> List[JobKey]: 91 | """ 92 | - Returns the task job ids 93 | """ 94 | return self.__job_ids 95 | 96 | def add_next_task(self, task: "BaseTask"): 97 | assert not self.__is_locked, "You can't alter a locked job." 98 | self.__next_tasks.append(task) 99 | 100 | def add_wait_for(self, task: "BaseTask"): 101 | assert not self.__is_locked, "You can't alter a locked job." 102 | self.__wait_for.append(task) 103 | 104 | def add_required_resources(self, resources_dict: ResourcesDict): 105 | assert not self.__is_locked, "You can't alter a locked job." 106 | self.__required_resources.append(resources_dict) 107 | 108 | def get_next_tasks(self) -> List["BaseTask"]: 109 | return self.__next_tasks 110 | 111 | def get_required_resources(self, partial: bool = False) -> List[ResourcesDict]: 112 | """ 113 | If partial is True, return required resources for each splitted job. 114 | Otherwise return the resouces required for the full task. 115 | """ 116 | if not partial: 117 | return self.__required_resources 118 | required_resources = [] 119 | parallelization = self.get_parallel_jobs() 120 | for reqset in self.__required_resources: 121 | reqres: ResourcesDict = ResourcesDict({}) 122 | for resource, req_amount in reqset.items(): 123 | # Split required resource into N parts. 
There are two 124 | # ideas behind this: 125 | # 126 | # - if the job in whole requires some resources, each of 127 | # its N parts should be using 1/N of that resource 128 | # 129 | # - in most common scenario when 1 unit of something is 130 | # required, allocating 1/N of it means that when we start 131 | # one unit job, we can start another unit job to allocate 132 | # 2/N, but not a completely different job (as it would 133 | # consume (1 + 1/N) of the resource. 134 | # 135 | # Use fraction to avoid any floating point quirks. 136 | reqres[resource] = Fraction(req_amount, parallelization) 137 | required_resources.append(reqres) 138 | return required_resources 139 | 140 | def get_wait_for(self) -> List["BaseTask"]: 141 | return self.__wait_for 142 | 143 | def as_jobgraph_dict(self) -> JobGraphDict: 144 | jdict: JobGraphDict = { 145 | "tags": self.tags, 146 | "units": self.units, 147 | "on_finish": self.on_finish, 148 | "wait_for": [t.task_id for t in self.get_wait_for()], 149 | "retries": self.retries, 150 | "project_id": self.project_id, 151 | "wait_time": self.wait_time, 152 | } 153 | self.on_finish["default"] = [] 154 | if self.retries > 0: 155 | self.on_finish["failed"] = ["retry"] 156 | 157 | return jdict 158 | 159 | def set_start_callback(self, func: Callable[[GraphManagerProtocol, bool], None]): 160 | self.__start_callback = func 161 | 162 | def get_start_callback(self) -> Callable[[GraphManagerProtocol, bool], None]: 163 | assert self.__start_callback is not None, "Start callback not initialized." 164 | return self.__start_callback 165 | 166 | def _default_start_callback(self, manager: GraphManagerProtocol, is_retry: bool): 167 | pass 168 | 169 | @abc.abstractmethod 170 | def run(self, manager: GraphManagerProtocol, is_retry=False, index: Optional[int] = None) -> Optional[JobKey]: 171 | ... 172 | 173 | @abc.abstractmethod 174 | def get_parallel_jobs(self): 175 | """ 176 | Returns total number of parallel jobs that this task will consist on. 177 | """ 178 | 179 | 180 | class Task(BaseTask): 181 | def __init__( 182 | self, 183 | task_id, 184 | command, 185 | init_args=None, 186 | retry_args=None, 187 | tags=None, 188 | units=None, 189 | retries=1, 190 | project_id=None, 191 | wait_time=None, 192 | on_finish=None, 193 | ): 194 | """ 195 | id - String. identifies the task. 196 | command - String. script name or jinja2 template string. 197 | init_args - List of strings. Arguments and options to add to the command. 198 | retry_args - List of strings. If given and job is retries, use this list of arguments instead the ones 199 | specified in init_args. 200 | tags - List of strings. tags to add to the scheduled job. 201 | units - Int. units to use for this scheduled job. 202 | retries - Int. Max number of retries in case job failed. 203 | project_id - Int. Run task in given project. If not given, just run in the actual project. 204 | wait_time - Int. Don't run the task before the given number of seconds after job goes to pending status. 205 | """ 206 | super().__init__(task_id, tags, units, retries, project_id, wait_time, on_finish) 207 | assert "." not in self.task_id, ". is not allowed in task name." 
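        # Parallelization sketch (hypothetical script name): a jinja2 command template that
        # renders to several lines yields one parallel job per line, e.g.
        #     Task("process", "{% for i in range(3) %}process.py --part={{ i }}\n{% endfor %}")
        # is split by the graph manager into jobs process.0, process.1 and process.2.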
208 | self.command = command 209 | self.init_args = init_args or [] 210 | self.retry_args = retry_args or [] 211 | self.__template = Template(self.command) 212 | 213 | def as_jobgraph_dict(self) -> JobGraphDict: 214 | jdict = super().as_jobgraph_dict() 215 | jdict.update({"command": self.get_commands(), "init_args": self.init_args, "retry_args": self.retry_args}) 216 | return jdict 217 | 218 | def get_commands(self) -> List[str]: 219 | return self.__template.render().splitlines() 220 | 221 | def get_command(self, index: int = 0) -> List[str]: 222 | return shlex.split(self.get_commands()[index]) 223 | 224 | def get_parallel_jobs(self) -> int: 225 | """ 226 | Returns total number of parallel jobs that this task will consist on. 227 | """ 228 | return len(self.get_commands()) 229 | 230 | def run( 231 | self, manager: GraphManagerProtocol, is_retry: bool = False, index: Optional[int] = None 232 | ) -> Optional[JobKey]: 233 | command = self.get_command(index or 0) 234 | tags = [] 235 | if self.tags is not None: 236 | tags.extend(self.tags) 237 | if index is None: 238 | jobname = f"{manager.name}/{self.task_id}" 239 | tags.append(f"TASK_ID={self.task_id}") 240 | else: 241 | jobname = f"{manager.name}/{self.task_id}.{index}" 242 | tags.append(f"TASK_ID={self.task_id}.{index}") 243 | if is_retry: 244 | logger.info('Will retry task "%s"', jobname) 245 | else: 246 | logger.info('Will start task "%s"', jobname) 247 | if is_retry: 248 | retry_args = self.retry_args or self.init_args 249 | cmd = command + retry_args 250 | else: 251 | cmd = command + self.init_args 252 | 253 | jobid = manager.schedule_script(cmd, tags=tags, units=self.units, project_id=self.project_id) 254 | if jobid is not None: 255 | logger.info('Scheduled task "%s" (%s)', jobname, jobid) 256 | self.append_jobid(jobid) 257 | return jobid 258 | return None 259 | 260 | 261 | class SpiderTask(BaseTask): 262 | """ 263 | A simple spider task. 264 | """ 265 | 266 | def __init__( 267 | self, 268 | task_id, 269 | spider, 270 | tags=None, 271 | units=None, 272 | retries=1, 273 | wait_time=None, 274 | on_finish=None, 275 | job_settings=None, 276 | **spider_args, 277 | ): 278 | super().__init__(task_id, tags, units, retries, None, wait_time, on_finish) 279 | self.spider = spider 280 | self.__spider_args = spider_args 281 | self.__job_settings = job_settings 282 | 283 | def get_spider_args(self): 284 | spider_args = self.__spider_args 285 | if self.__job_settings is not None: 286 | spider_args.update({"job_settings": self.__job_settings}) 287 | return spider_args 288 | 289 | def as_jobgraph_dict(self) -> JobGraphDict: 290 | jdict = super().as_jobgraph_dict() 291 | jdict.update({"spider": self.spider, "spider_args": self.get_spider_args()}) 292 | return jdict 293 | 294 | def get_parallel_jobs(self): 295 | return 1 296 | 297 | def run(self, manager: GraphManagerProtocol, is_retry=False, index: Optional[int] = None) -> Optional[JobKey]: 298 | assert index is None, "Spider Task don't support parallelization." 
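        # Construction sketch (hypothetical spider and argument names):
        #     SpiderTask("discover", "myspider", tags=["daily"], job_settings={"LOG_LEVEL": "INFO"}, country="us")
        # schedules spider "myspider" with the spider argument country=us plus the given job settings.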
299 | jobname = f"{manager.name}/{self.task_id}" 300 | if is_retry: 301 | logger.info('Will retry spider "%s"', jobname) 302 | else: 303 | logger.info('Will start spider "%s"', jobname) 304 | tags = [] 305 | if self.tags is not None: 306 | tags.extend(self.tags) 307 | tags.append(f"TASK_ID={self.task_id}") 308 | jobid = manager.schedule_spider( 309 | self.spider, tags=tags, units=self.units, project_id=self.project_id, **self.get_spider_args() 310 | ) 311 | if jobid is not None: 312 | logger.info('Scheduled spider "%s" (%s)', jobname, jobid) 313 | self.append_jobid(jobid) 314 | return jobid 315 | return None 316 | -------------------------------------------------------------------------------- /shub_workflow/graph/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Tuple, Optional, cast, Literal 3 | 4 | from shub_workflow.base import WorkFlowManager 5 | from shub_workflow.script import JobKey 6 | 7 | _SCHEDULED_RE = re.compile(r"Scheduled (?:(task|spider) \"(.+?)\" \()?.*?(\d+/\d+/\d+)\)?", re.I) 8 | 9 | 10 | def _search_scheduled_line(txt: str) -> Optional[Tuple[str, str, str]]: 11 | """ 12 | graph manager scheduled task 13 | >>> _search_scheduled_line('Scheduled task "totalwine/productsJob" (168012/20/62)') 14 | ('task', 'totalwine/productsJob', '168012/20/62') 15 | 16 | graph manager scheduled spider 17 | >>> _search_scheduled_line('Scheduled spider "totalwine/storesJob" (168012/27/2)') 18 | ('spider', 'totalwine/storesJob', '168012/27/2') 19 | """ 20 | m = _SCHEDULED_RE.search(txt) 21 | if m is not None: 22 | return cast(Tuple[Literal["task", "spider"], str, JobKey], m.groups()) 23 | return None 24 | 25 | 26 | def get_scheduled_jobs_specs(manager: WorkFlowManager, job_ids: List[JobKey]) -> List[Tuple[str, str, str]]: 27 | """ 28 | Return the jobs specs of the jobs scheduled by the jobs identified 29 | by given job_ids 30 | 31 | Each job spec is a 3-element tuple which contains, in order: 32 | - the kind of task job (spider/task) 33 | - the complete id name of the task job 34 | - the job id of the of the task job 35 | """ 36 | scheduled_jobs = [] 37 | for jobid in job_ids: 38 | project_id = jobid.split("/")[0] 39 | job = manager.get_project(project_id).jobs.get(jobid) 40 | for logline in job.logs.iter(): 41 | if "message" not in logline: 42 | continue 43 | m = _search_scheduled_line(logline["message"]) 44 | if m is not None: 45 | scheduled_jobs.append(m) 46 | return scheduled_jobs 47 | -------------------------------------------------------------------------------- /shub_workflow/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrapinghub/shub-workflow/bfce391181ecb99b8fcb7593278cef5af1e38719/shub_workflow/py.typed -------------------------------------------------------------------------------- /shub_workflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import json 3 | import os 4 | import hashlib 5 | from typing import Optional 6 | 7 | from scrapy.settings import BaseSettings 8 | from tenacity import retry, retry_if_exception_type, before_sleep_log, stop_after_attempt, wait_fixed 9 | from scrapinghub.client.exceptions import ServerError 10 | from requests.exceptions import ReadTimeout, ConnectionError, HTTPError 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def hashstr(text: str) -> str: 17 | u = hashlib.sha1() 18 | 
u.update(text.encode("utf8")) 19 | return u.hexdigest() 20 | 21 | 22 | def resolve_shub_jobkey() -> Optional[str]: 23 | return os.environ.get("SHUB_JOBKEY") 24 | 25 | 26 | def resolve_project_id(project_id=None) -> int: 27 | """ 28 | Gets project id from following sources in following order of precedence: 29 | - default parameter values 30 | - environment variables 31 | - scrapinghub.yml file 32 | 33 | in order to allow to use codes that needs HS or dash API, 34 | either locally or from scrapinghub, correctly configured 35 | """ 36 | if project_id: 37 | try: 38 | return int(project_id) 39 | except ValueError: 40 | pass 41 | else: 42 | # read from environment only if not explicitly provided 43 | if os.environ.get("PROJECT_ID") is not None: 44 | return int(os.environ["PROJECT_ID"]) 45 | 46 | # for ScrapyCloud jobs: 47 | jobkey = resolve_shub_jobkey() 48 | if jobkey: 49 | return int(jobkey.split("/")[0]) 50 | 51 | # read from scrapinghub.yml 52 | try: 53 | from shub.config import load_shub_config 54 | 55 | cfg = load_shub_config() 56 | try: 57 | project_id = project_id or "default" 58 | return int(cfg.get_project_id(project_id)) 59 | except Exception: 60 | logger.error(f"Project entry '{project_id}' not found in scrapinghub.yml.") 61 | except ImportError: 62 | logger.error("Install shub package if want to access scrapinghub.yml.") 63 | except TypeError: 64 | logger.error("Default project entry not available in scrapinghub.yml.") 65 | 66 | raise ValueError( 67 | "No default project id found. Use either PROJECT_ID env. variable or set 'default' entry in scrapinghub.yml, " 68 | "or use --project-id with a project numeric id or an existing entry in scrapinghub.yml." 69 | ) 70 | 71 | 72 | MINS_IN_A_DAY = 24 * 60 73 | ONE_MIN_IN_S = 60 74 | 75 | 76 | DASH_RETRY_MAX = int(os.environ.get("DASH_RETRY_MAX", MINS_IN_A_DAY)) 77 | DASH_RETRY_WAIT_SECS = int(os.environ.get("DASH_RETRY_WAIT_SECS", ONE_MIN_IN_S)) 78 | DASH_RETRY_LOGGING_LEVEL = os.environ.get("DASH_RETRY_LOGGING_LEVEL", "ERROR") 79 | 80 | 81 | dash_retry_decorator = retry( 82 | # ServerError is the only ScrapinghubAPIError that should be retried. 
Don't capture ScrapinghubAPIError here 83 | retry=retry_if_exception_type((ServerError, ReadTimeout, ConnectionError, HTTPError)), 84 | before_sleep=before_sleep_log(logger, getattr(logging, DASH_RETRY_LOGGING_LEVEL)), 85 | reraise=True, 86 | stop=stop_after_attempt(DASH_RETRY_MAX), 87 | wait=wait_fixed(DASH_RETRY_WAIT_SECS), 88 | ) 89 | 90 | _settings_warning_issued = False 91 | 92 | 93 | def kumo_settings(): 94 | global _settings_warning_issued 95 | settings = {} 96 | shub_job_data = json.loads(os.environ.get("SHUB_SETTINGS", "{}")) 97 | if shub_job_data: 98 | settings.update(shub_job_data["project_settings"]) 99 | settings.update(shub_job_data["spider_settings"]) 100 | elif not _settings_warning_issued: 101 | logger.warning("Couldn't find Dash project settings, probably not running in Dash") 102 | _settings_warning_issued = True 103 | return settings 104 | 105 | 106 | def get_project_settings() -> BaseSettings: 107 | from scrapy.utils.project import get_project_settings as scrapy_get_project_settings # pylint: disable=import-error 108 | 109 | settings = scrapy_get_project_settings() 110 | settings.setdict(kumo_settings(), priority="project") 111 | try: 112 | # test sh_scrapy is available 113 | import sh_scrapy # noqa: F401 114 | 115 | settings.setdict({"STATS_CLASS": "sh_scrapy.stats.HubStorageStatsCollector"}, priority="cmdline") 116 | except ImportError: 117 | pass 118 | return settings 119 | 120 | 121 | def get_kumo_loglevel(default="INFO"): 122 | loglevel = kumo_settings().get("LOG_LEVEL", default) 123 | return getattr(logging, loglevel) 124 | -------------------------------------------------------------------------------- /shub_workflow/utils/alert_sender.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Callable 3 | 4 | from shub_workflow.script import BaseScript 5 | 6 | LOG = logging.getLogger(__name__) 7 | 8 | 9 | class AlertSenderMixin(BaseScript): 10 | """ 11 | A class for adding slack alert capabilities to a shub_workflow class. 12 | """ 13 | 14 | default_subject = "No Subject" 15 | 16 | def __init__(self): 17 | self.messages: List[str] = [] 18 | self.registered_senders: List[Callable[[], None]] = [] 19 | super().__init__() 20 | 21 | def add_argparser_options(self): 22 | super().add_argparser_options() 23 | self.argparser.add_argument("--subject", help="Set alert message subject.", default=self.default_subject) 24 | 25 | def append_message(self, message: str): 26 | self.messages.append(message) 27 | 28 | def register_sender_method(self, sender: Callable[[], None]): 29 | self.registered_senders.append(sender) 30 | 31 | def send_messages(self): 32 | for sender in self.registered_senders: 33 | try: 34 | sender() 35 | except Exception as e: 36 | LOG.error(repr(e)) 37 | -------------------------------------------------------------------------------- /shub_workflow/utils/clone_job.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility for cloning ScrapyCloud jobs 3 | Features tagging of cloned from/to jobs (both source and destination) and avoids to clone source jobs already cloned. 4 | By default cloned jobs are scheduled in the same project as source job. If --project-id is given, target project 5 | is overriden. 
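
Example invocation (hypothetical job keys):

    python -m shub_workflow.utils.clone_job 123456/7/8 123456/7/9 --units 2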
6 | """ 7 | import logging 8 | from typing import Optional, List 9 | 10 | from scrapinghub.client.jobs import Job 11 | from scrapinghub import DuplicateJobError 12 | 13 | from shub_workflow.script import BaseScript, JobKey 14 | from shub_workflow.utils import dash_retry_decorator 15 | 16 | 17 | _LOG = logging.getLogger(__name__) 18 | 19 | 20 | def _transform_cmd(job_cmd): 21 | if isinstance(job_cmd, list): 22 | return " ".join(["'%s'" % cmd for cmd in job_cmd[1:]]) 23 | 24 | return job_cmd 25 | 26 | 27 | _COPIED_FROM_META = { 28 | "job_cmd": ("cmd_args", _transform_cmd), 29 | "units": (None, None), 30 | "spider_args": ("job_args", None), 31 | "tags": ("add_tag", None), 32 | "job_settings": (None, None), 33 | } 34 | 35 | 36 | class BaseClonner(BaseScript): 37 | 38 | MAX_CLONES = 10 39 | 40 | def is_cloned(self, jobkey: JobKey): 41 | for tag in self.get_job_tags(jobkey): 42 | if tag.startswith("ClonedTo="): 43 | _LOG.warning(f"Job {jobkey} already cloned. Skipped.") 44 | return True 45 | return False 46 | 47 | def job_params_hook(self, job_params): 48 | pass 49 | 50 | def clone_job( 51 | self, job_key: JobKey, units: Optional[int] = None, extra_tags: Optional[List[str]] = None 52 | ) -> Optional[Job]: 53 | extra_tags = extra_tags or [] 54 | job = self.get_job(job_key) 55 | 56 | spider = self.get_metadata_key(job.metadata, "spider") 57 | 58 | job_params = dict() 59 | for key, (target_key, _) in _COPIED_FROM_META.items(): 60 | 61 | if target_key is None: 62 | target_key = key 63 | 64 | job_params[target_key] = self.get_metadata_key(job.metadata, key) 65 | 66 | clone_number = 0 67 | add_tag = job_params.setdefault("add_tag", []) 68 | 69 | add_tag = list(filter(lambda x: not x.startswith("ClonedFrom="), add_tag)) 70 | add_tag.append(f"ClonedFrom={job_key}") 71 | 72 | for tag in add_tag: 73 | if tag.startswith("CloneNumber="): 74 | clone_number = int(tag.replace("CloneNumber=", "")) 75 | break 76 | 77 | clone_number += 1 78 | 79 | if clone_number >= self.MAX_CLONES: 80 | _LOG.warning(f"Already reached max clones allowed for job {job_key}.") 81 | return None 82 | 83 | add_tag = list(filter(lambda x: not x.startswith("CloneNumber="), add_tag)) 84 | add_tag.append(f"CloneNumber={clone_number}") 85 | 86 | add_tag.extend(extra_tags) 87 | job_params["add_tag"] = add_tag 88 | if units is not None: 89 | job_params["units"] = units 90 | 91 | self.job_params_hook(job_params) 92 | 93 | for key, (target_key, transform) in _COPIED_FROM_META.items(): 94 | 95 | target_key = target_key or key 96 | 97 | if transform is None: 98 | 99 | def transform(x): 100 | return x 101 | 102 | job_params[target_key] = transform(job_params[target_key]) 103 | 104 | project_id, _, _ = job_key.split("/") 105 | project = self.get_project(self.project_id or project_id) 106 | new_job = self.schedule_generic(project, spider, **job_params) 107 | if new_job is not None: 108 | _LOG.info("Cloned %s to %s", job_key, new_job.key) 109 | jobtags = self.get_metadata_key(job.metadata, "tags") 110 | jobtags.append(f"ClonedTo={new_job.key}") 111 | self._update_metadata(job.metadata, {"tags": jobtags}) 112 | 113 | return new_job 114 | 115 | @dash_retry_decorator 116 | def schedule_generic(self, project, spider, **job_params) -> Optional[Job]: 117 | try: 118 | return project.jobs.run(spider, **job_params) 119 | except DuplicateJobError as e: 120 | _LOG.error(str(e)) 121 | return None 122 | 123 | 124 | class CloneJobScript(BaseClonner): 125 | 126 | flow_id_required = False 127 | 128 | @property 129 | def description(self): 130 | return __doc__ 131 | 
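    # When --project-id is not provided, derive the target project from the first job key,
    # so clones are scheduled in the same project as their source jobs.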
132 | def parse_project_id(self, args): 133 | project_id = super().parse_project_id(args) 134 | if project_id: 135 | return project_id 136 | if args.key: 137 | return args.key[0].split("/")[0] 138 | 139 | def add_argparser_options(self): 140 | super().add_argparser_options() 141 | self.argparser.add_argument( 142 | "key", 143 | type=str, 144 | nargs="+", 145 | default=[], 146 | help="Target job key. Can be given multiple times. All must be in same project.", 147 | ) 148 | self.argparser.add_argument("--units", help="Set number of units. Default is the same as cloned job.", type=int) 149 | 150 | def run(self): 151 | keys = list(filter(lambda x: not self.is_cloned(x), self.args.key)) 152 | for job_key in keys: 153 | try: 154 | self.clone_job(job_key, self.args.units, self.args.children_tag) 155 | except Exception as e: 156 | _LOG.error("Could not restart job %s: %s", job_key, e) 157 | 158 | 159 | if __name__ == "__main__": 160 | logging.basicConfig( 161 | format="%(asctime)s %(name)s [%(levelname)s]: %(message)s", 162 | level=logging.DEBUG, 163 | ) 164 | script = CloneJobScript() 165 | script.run() 166 | -------------------------------------------------------------------------------- /shub_workflow/utils/contexts.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from contextlib import contextmanager 3 | 4 | 5 | @contextmanager 6 | def script_args(argv): 7 | """Capture cmdline args, for easily test scripts""" 8 | old_argv = sys.argv 9 | try: 10 | sys.argv = ["fake_script.py"] + argv 11 | yield sys.argv 12 | finally: 13 | sys.argv = old_argv 14 | -------------------------------------------------------------------------------- /shub_workflow/utils/dupefilter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import abc 4 | import tempfile 5 | from typing import Union 6 | from typing import Container 7 | 8 | from typing_extensions import Protocol 9 | from sqlitedict import SqliteDict 10 | 11 | 12 | class DupesFilterProtocol(Protocol, Container[str]): 13 | @abc.abstractmethod 14 | def add(self, elem: str): 15 | ... 16 | 17 | @abc.abstractmethod 18 | def __contains__(self, element: object) -> bool: 19 | ... 20 | 21 | @abc.abstractmethod 22 | def close(self): 23 | ... 
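# Usage sketch (this is how BaseDeliverScript deduplicates delivered items):
#
#     dupes: DupesFilterProtocol = SqliteDictDupesFilter()
#     if "some-item-key" not in dupes:
#         dupes.add("some-item-key")
#     dupes.close()  # closes the SqliteDict and removes its temporary db file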
24 | 25 | 26 | class SqliteDictDupesFilter: 27 | def __init__(self): 28 | """ 29 | SqlteDict based dupes filter 30 | """ 31 | self.dupes_db_file = tempfile.mktemp() 32 | self.__filter: Union[SqliteDict, None] = None 33 | 34 | def __create_db(self): 35 | self.__filter = SqliteDict(self.dupes_db_file, flag="n", autocommit=True) 36 | 37 | def __contains__(self, element: object) -> bool: 38 | if self.__filter is None: 39 | self.__create_db() 40 | assert self.__filter is not None 41 | return element in self.__filter 42 | 43 | def add(self, element: str): 44 | if self.__filter is None: 45 | self.__create_db() 46 | assert self.__filter is not None 47 | self.__filter[element] = "-" 48 | 49 | def close(self): 50 | if self.__filter is not None: 51 | try: 52 | self.__filter.close() 53 | os.remove(self.dupes_db_file) 54 | except Exception: 55 | pass 56 | -------------------------------------------------------------------------------- /shub_workflow/utils/futils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import logging 3 | import uuid 4 | from typing import List, Generator, Callable 5 | from operator import itemgetter 6 | from glob import iglob 7 | from os import listdir, remove, environ, makedirs 8 | from os.path import exists as os_exists, dirname, getctime, basename 9 | from shutil import copyfile 10 | from datetime import datetime, timedelta, timezone 11 | 12 | try: 13 | from s3fs import S3FileSystem as OriginalS3FileSystem 14 | import boto3 15 | 16 | s3_enabled = True 17 | except ImportError: 18 | s3_enabled = False 19 | 20 | 21 | try: 22 | from shub_workflow.utils import gcstorage 23 | 24 | gcs_enabled = True 25 | except ImportError: 26 | gcs_enabled = False 27 | 28 | 29 | S3_PATH_RE = re.compile("s3://(.+?)/(.+)") 30 | _S3_ATTRIBUTE = "s3://" 31 | _GS_ATTRIBUTE = "gs://" 32 | BUFFER = 1024 * 1024 33 | 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | def just_log_exception(exception): 39 | logger.error(str(exception)) 40 | for etype in (KeyboardInterrupt, SystemExit, ImportError): 41 | if isinstance(exception, etype): 42 | return False 43 | return True # retries any other exception 44 | 45 | 46 | if s3_enabled: 47 | 48 | class S3FileSystem(OriginalS3FileSystem): 49 | read_timeout = 120 50 | connect_timeout = 60 51 | 52 | 53 | def s3_path(path, is_folder=False): 54 | if not path: 55 | return "" 56 | path = path.strip() 57 | if is_folder and not path.endswith("/"): 58 | path = path + "/" 59 | return path[len(_S3_ATTRIBUTE):] 60 | 61 | 62 | def s3_credentials(key, secret, token, region=None): 63 | creds = dict( 64 | key=environ.get("AWS_ACCESS_KEY_ID", key), secret=environ.get("AWS_SECRET_ACCESS_KEY", secret), token=token 65 | ) 66 | 67 | result = { 68 | **creds, 69 | "config_kwargs": {"retries": {"max_attempts": 20}}, 70 | } 71 | region = region or environ.get("AWS_REGION") 72 | if region is not None: 73 | result.update({"client_kwargs": {"region_name": region}}) 74 | return result 75 | 76 | 77 | def check_s3_path(path): 78 | if path.startswith(_S3_ATTRIBUTE): 79 | if s3_enabled: 80 | return True 81 | raise ModuleNotFoundError( 82 | "S3 dependencies are not installed. Install shubw-workflow as shub-workflow[with-s3-tools]" 83 | ) 84 | return False 85 | 86 | 87 | def check_gcs_path(path): 88 | if path.startswith(_GS_ATTRIBUTE): 89 | if gcs_enabled: 90 | return True 91 | raise ModuleNotFoundError( 92 | "GCS dependencies are not installed. 
Install shubw-workflow as shub-workflow[with-gcs-tools]" 93 | ) 94 | return False 95 | 96 | 97 | def get_s3_bucket_keyname(s3path, aws_key, aws_secret, aws_token, **kwargs): 98 | region = kwargs.pop("region", None) 99 | region_name = kwargs.pop("region_name", region) 100 | botocore_session = kwargs.pop("botocore_session", None) 101 | profile_name = kwargs.pop("profile_name", None) 102 | 103 | session = boto3.Session( 104 | aws_access_key_id=aws_key, 105 | aws_secret_access_key=aws_secret, 106 | aws_session_token=aws_token, 107 | region_name=region_name, 108 | botocore_session=botocore_session, 109 | profile_name=profile_name, 110 | ) 111 | s3 = session.resource("s3") 112 | m = S3_PATH_RE.match(s3path) 113 | if m: 114 | bucket_name, path = m.groups() 115 | bucket = s3.Bucket(bucket_name) 116 | else: 117 | raise ValueError(f"Bad s3 path specification: {s3path}") 118 | return bucket, path, kwargs.pop("op_kwargs", {}) 119 | 120 | 121 | def get_file(path, *args, aws_key=None, aws_secret=None, aws_token=None, **kwargs): 122 | op_kwargs = kwargs.pop("op_kwargs", {}) 123 | region = kwargs.pop("region", None) 124 | if check_s3_path(path): 125 | fs = S3FileSystem(**s3_credentials(aws_key, aws_secret, aws_token, region), **kwargs) 126 | if "ACL" in op_kwargs: 127 | op_kwargs["acl"] = op_kwargs.pop("ACL") 128 | fp = fs.open(s3_path(path), *args, **op_kwargs) 129 | try: 130 | fp.name = path 131 | except Exception: 132 | pass 133 | elif check_gcs_path(path): 134 | fp = gcstorage.get_file(path, *args, **kwargs) 135 | else: 136 | fp = open(path, *args, **kwargs) 137 | 138 | return fp 139 | 140 | 141 | def get_object(path, *args, aws_key=None, aws_secret=None, aws_token=None, **kwargs): 142 | fp = None 143 | if check_gcs_path(path): 144 | fp = gcstorage.get_object(path) 145 | return fp 146 | 147 | 148 | DOWNLOAD_CHUNK_SIZE = 500 * 1024 * 1024 149 | UPLOAD_CHUNK_SIZE = 100 * 1024 * 1024 150 | 151 | 152 | def download_file(path, dest=None, aws_key=None, aws_secret=None, aws_token=None, **kwargs): 153 | if dest is None: 154 | dest = basename(path) 155 | if check_s3_path(path): 156 | with get_file(path, "rb", aws_key=aws_key, aws_secret=aws_secret, aws_token=aws_token, **kwargs) as r, open( 157 | dest, "wb" 158 | ) as w: 159 | total = 0 160 | while True: 161 | size = w.write(r.read(DOWNLOAD_CHUNK_SIZE)) 162 | total += size 163 | logger.info(f"Downloaded {total} bytes from {path}") 164 | if size < DOWNLOAD_CHUNK_SIZE: 165 | break 166 | elif check_gcs_path(path): 167 | gcstorage.download_file(path, dest) 168 | else: 169 | raise ValueError(f"Not supported file system fpr path {path}") 170 | 171 | 172 | def upload_file_obj(robj, dest, aws_key=None, aws_secret=None, aws_token=None, **kwargs): 173 | if check_s3_path(dest): 174 | bucket, keyname, op_kwargs = get_s3_bucket_keyname(dest, aws_key, aws_secret, aws_token, **kwargs) 175 | bucket.upload_fileobj(robj, keyname, ExtraArgs=op_kwargs) 176 | logger.info(f"Uploaded file obj to {dest}.") 177 | else: 178 | raise ValueError("Not a supported cloud.") 179 | 180 | 181 | def upload_file(path, dest, aws_key=None, aws_secret=None, aws_token=None, **kwargs): 182 | if dest.endswith("/"): 183 | dest = dest + basename(path) 184 | if check_s3_path(dest): 185 | bucket, keyname, op_kwargs = get_s3_bucket_keyname(dest, aws_key, aws_secret, aws_token, **kwargs) 186 | bucket.upload_file(path, keyname, ExtraArgs=op_kwargs) 187 | logger.info(f"Uploaded {path} to {dest}.") 188 | elif check_gcs_path(dest): 189 | gcstorage.upload_file(path, dest) 190 | else: 191 | raise ValueError("Not a 
supported cloud.") 192 | 193 | 194 | def get_glob(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs) -> List[str]: 195 | region = kwargs.pop("region", None) 196 | if check_s3_path(path): 197 | fs = S3FileSystem(**s3_credentials(aws_key, aws_secret, aws_token, region), **kwargs) 198 | fp = [_S3_ATTRIBUTE + p for p in fs.glob(s3_path(path))] 199 | else: 200 | fp = list(iglob(path)) 201 | 202 | return fp 203 | 204 | 205 | def cp_file(src_path, dest_path, aws_key=None, aws_secret=None, aws_token=None, **kwargs): 206 | if dest_path.endswith("/"): 207 | dest_path = dest_path + basename(src_path) 208 | region = kwargs.pop("region", None) 209 | if check_s3_path(src_path) and check_s3_path(dest_path): 210 | op_kwargs = kwargs.pop("op_kwargs", {}) 211 | fs = S3FileSystem(**s3_credentials(aws_key, aws_secret, aws_token, region), **kwargs) 212 | fs.copy(s3_path(src_path), s3_path(dest_path), **op_kwargs) 213 | elif check_s3_path(src_path): 214 | download_file(src_path, dest_path, aws_key, aws_secret, aws_token, **kwargs) 215 | elif check_s3_path(dest_path): 216 | upload_file(src_path, dest_path, aws_key, aws_secret, aws_token, **kwargs) 217 | elif check_gcs_path(src_path) and check_gcs_path(dest_path): 218 | gcstorage.cp_file(src_path, dest_path) 219 | elif check_gcs_path(src_path): 220 | gcstorage.download_file(src_path, dest_path) 221 | elif check_gcs_path(dest_path): 222 | gcstorage.upload_file(src_path, dest_path) 223 | else: 224 | dname = dirname(dest_path) 225 | if dname: 226 | makedirs(dname, exist_ok=True) 227 | copyfile(src_path, dest_path) 228 | 229 | 230 | def mv_file(src_path, dest_path, aws_key=None, aws_secret=None, aws_token=None, **kwargs): 231 | if dest_path.endswith("/"): 232 | dest_path = dest_path + basename(src_path) 233 | if check_s3_path(src_path) and check_s3_path(dest_path): 234 | cp_file(src_path, dest_path, aws_key, aws_secret, aws_token, **kwargs) 235 | rm_file(src_path, aws_key, aws_secret, aws_token, **kwargs) 236 | elif check_s3_path(src_path): 237 | download_file(src_path, dest_path, aws_key, aws_secret, aws_token, **kwargs) 238 | rm_file(src_path, aws_key, aws_secret, aws_token, **kwargs) 239 | elif check_s3_path(dest_path): 240 | upload_file(src_path, dest_path, aws_key, aws_secret, aws_token, **kwargs) 241 | rm_file(src_path) 242 | elif check_gcs_path(src_path) and check_gcs_path(dest_path): 243 | gcstorage.mv_file(src_path, dest_path) 244 | elif check_gcs_path(src_path): 245 | gcstorage.download_file(src_path, dest_path) 246 | gcstorage.rm_file(src_path) 247 | elif check_gcs_path(dest_path): 248 | gcstorage.upload_file(src_path, dest_path) 249 | rm_file(src_path) 250 | else: 251 | dname = dirname(dest_path) 252 | if dname: 253 | makedirs(dname, exist_ok=True) 254 | cp_file(src_path, dest_path) 255 | rm_file(src_path) 256 | 257 | 258 | def rm_file(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs): 259 | region = kwargs.pop("region", None) 260 | if check_s3_path(path): 261 | op_kwargs = kwargs.pop("op_kwargs", {}) 262 | fs = S3FileSystem(**s3_credentials(aws_key, aws_secret, aws_token, region), **kwargs) 263 | fs.rm(s3_path(path), **op_kwargs) 264 | elif check_gcs_path(path): 265 | gcstorage.rm_file(path) 266 | else: 267 | remove(path) 268 | 269 | 270 | def list_path(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs) -> Generator[str, None, None]: 271 | """ 272 | More efficient boto3 based path listing, that accepts prefix 273 | """ 274 | if check_s3_path(path): 275 | bucket, path, op_kwargs = get_s3_bucket_keyname(path, aws_key, 
aws_secret, aws_token, **kwargs) 276 | for result in bucket.objects.filter(Prefix=path): 277 | yield f"s3://{bucket.name}/{result.key}" 278 | elif check_gcs_path(path): 279 | yield from gcstorage.list_path(path) 280 | else: 281 | listing = [] 282 | if path.strip().endswith("/"): 283 | path = path[:-1] 284 | try: 285 | listing = [f"{path}/{name}" for name in listdir(path)] 286 | except FileNotFoundError: 287 | pass 288 | yield from listing 289 | 290 | 291 | def list_folder(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs) -> List[str]: 292 | if check_s3_path(path): 293 | region = kwargs.pop("region", None) 294 | fs = S3FileSystem(**s3_credentials(aws_key, aws_secret, aws_token, region), **kwargs) 295 | 296 | try: 297 | path = s3_path(path, is_folder=True) 298 | listing = [f"s3://{name}" for name in fs.ls(path) if name != path] 299 | except FileNotFoundError: 300 | listing = [] 301 | elif check_gcs_path(path): 302 | listing = list(gcstorage.list_folder(path)) 303 | else: 304 | if path.strip().endswith("/"): 305 | path = path[:-1] 306 | try: 307 | listing = [f"{path}/{name}" for name in listdir(path)] 308 | except FileNotFoundError: 309 | listing = [] 310 | 311 | return listing 312 | 313 | 314 | def list_folder_in_ts_order( 315 | input_folder: str, aws_key=None, aws_secret=None, aws_token=None, **kwargs 316 | ) -> Generator[str, None, None]: 317 | results = [] 318 | for input_file in list_folder(input_folder, aws_key=aws_key, aws_secret=aws_secret, aws_token=aws_token, **kwargs): 319 | if check_s3_path(input_folder): 320 | with get_file(input_file, "rb", aws_key=aws_key, aws_secret=aws_secret, aws_token=aws_token, **kwargs) as f: 321 | results.append((input_file, f.info()["LastModified"])) 322 | else: 323 | results.append((input_file, getctime(input_file))) 324 | for n, _ in sorted(results, key=itemgetter(1)): 325 | yield n 326 | 327 | 328 | def list_folder_files_recursive( 329 | folder: str, aws_key=None, aws_secret=None, aws_token=None, **kwargs 330 | ) -> Generator[str, None, None]: 331 | for f in list_folder(folder, aws_key=aws_key, aws_secret=aws_secret, aws_token=aws_token, **kwargs): 332 | if f == folder: 333 | yield f 334 | else: 335 | yield from list_folder_files_recursive( 336 | f, aws_key=aws_key, aws_secret=aws_secret, aws_token=aws_token, **kwargs 337 | ) 338 | 339 | 340 | def s3_folder_size(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs): 341 | region = kwargs.pop("region", None) 342 | if check_s3_path(path): 343 | fs = S3FileSystem(**s3_credentials(aws_key, aws_secret, aws_token, region), **kwargs) 344 | return sum(fs.du(s3_path(path, is_folder=True), deep=True).values()) 345 | 346 | 347 | def exists(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs): 348 | region = kwargs.pop("region", None) 349 | if check_s3_path(path): 350 | fs = S3FileSystem(**s3_credentials(aws_key, aws_secret, aws_token, region), **kwargs) 351 | return fs.exists(path) 352 | if check_gcs_path(path): 353 | return gcstorage.exists(path) 354 | return os_exists(path) 355 | 356 | 357 | def empty_folder(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs) -> List[str]: 358 | region = kwargs.pop("region", None) 359 | removed_files = [] 360 | if check_s3_path(path): 361 | fs = S3FileSystem(**s3_credentials(aws_key, aws_secret, aws_token, region), **kwargs) 362 | for s3file in list_folder(path, aws_key=aws_key, aws_secret=aws_secret, aws_token=aws_token): 363 | try: 364 | fs.rm(s3file) 365 | removed_files.append(s3file) 366 | except FileNotFoundError: 367 | pass 368 
| else: 369 | if path.strip().endswith("/"): 370 | path = path[:-1] 371 | for fsfile in listdir(path): 372 | remove(f"{path}/{fsfile}") 373 | removed_files.append(f"{path}/{fsfile}") 374 | 375 | return removed_files 376 | 377 | 378 | def touch(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs): 379 | if not exists(path): 380 | if not check_s3_path(path): 381 | dname = dirname(path) 382 | makedirs(dname, exist_ok=True) 383 | with get_file(path, "w", aws_key=aws_key, aws_secret=aws_secret, aws_token=aws_token, **kwargs) as f: 384 | f.write("") 385 | else: 386 | raise ValueError(f"File {path} already exists") 387 | 388 | 389 | class S3SessionFactory: 390 | 391 | """ 392 | Generate s3 temporal session credentials in cases where that is required 393 | """ 394 | 395 | def __init__(self, aws_key, aws_secret, aws_role, aws_external_id, expiration_margin_minutes=10): 396 | self.aws_role = aws_role 397 | self.aws_key = aws_key 398 | self.aws_secret = aws_secret 399 | self.aws_external_id = aws_external_id 400 | self.expiration_margin_minutes = expiration_margin_minutes 401 | self.credentials_expiration = None 402 | self.sts_client = boto3.client("sts", aws_access_key_id=aws_key, aws_secret_access_key=aws_secret) 403 | self.full_credentials = {} 404 | self.s3_client = None 405 | 406 | def get_credentials(self): 407 | """ 408 | Get a new set of credentials if the existing ones are close to expire, 409 | update the S3 client. 410 | """ 411 | if self.credentials_expiration is None or datetime.now(timezone.utc) > self.credentials_expiration - timedelta( 412 | minutes=self.expiration_margin_minutes 413 | ): 414 | session_name = str(uuid.uuid4()) 415 | self.full_credentials = self.sts_client.assume_role( 416 | RoleArn=self.aws_role, 417 | RoleSessionName=session_name, 418 | ExternalId=self.aws_external_id, 419 | )["Credentials"] 420 | self.credentials_expiration = self.full_credentials["Expiration"] 421 | self.s3_client = boto3.session.Session( 422 | aws_access_key_id=self.full_credentials["AccessKeyId"], 423 | aws_secret_access_key=self.full_credentials["SecretAccessKey"], 424 | aws_session_token=self.full_credentials["SessionToken"], 425 | ).client("s3") 426 | return { 427 | "aws_key": self.full_credentials["AccessKeyId"], 428 | "aws_secret": self.full_credentials["SecretAccessKey"], 429 | "aws_token": self.full_credentials["SessionToken"], 430 | } 431 | 432 | 433 | class S3Helper: 434 | 435 | cp_file: Callable 436 | download_file: Callable 437 | empty_folder: Callable 438 | exists: Callable 439 | get_file: Callable 440 | get_glob: Callable 441 | get_object: Callable 442 | list_folder: Callable 443 | list_folder_files_recursive: Callable 444 | list_folder_in_ts_order: Callable 445 | list_path: Callable 446 | mv_file: Callable 447 | rm_file: Callable 448 | s3_folder_size: Callable 449 | touch: Callable 450 | upload_file: Callable 451 | upload_file_obj: Callable 452 | 453 | def __init__( 454 | self, 455 | aws_key, 456 | aws_secret, 457 | aws_role=None, 458 | aws_external_id=None, 459 | expiration_margin_minutes=10, 460 | op_kwargs_by_method_name=None, 461 | ): 462 | self.s3_session_factory = None 463 | self.credentials = {"aws_key": aws_key, "aws_secret": aws_secret} 464 | if aws_role is not None and aws_external_id is not None: 465 | self.s3_session_factory = S3SessionFactory( 466 | aws_key, aws_secret, aws_role, aws_external_id, expiration_margin_minutes 467 | ) 468 | 469 | op_kwargs_by_method_name = op_kwargs_by_method_name or {} 470 | for method_name, _type in 
S3Helper.__annotations__.items(): 471 | if not hasattr(self, method_name) and _type is Callable: 472 | method = globals()[method_name] 473 | self._wrap_method(method, **op_kwargs_by_method_name.get(method_name, {})) 474 | 475 | def _wrap_method(self, method, **op_kwargs): 476 | def _method(*args, **kwargs): 477 | if self.s3_session_factory is not None: 478 | self.credentials = self.s3_session_factory.get_credentials() 479 | if op_kwargs: 480 | kwargs.update({"op_kwargs": op_kwargs}) 481 | kwargs.update(self.credentials) 482 | return method(*args, **kwargs) 483 | 484 | setattr(self, method.__name__, _method) 485 | 486 | 487 | class FSHelper(S3Helper): 488 | def __init__(self, aws_key=None, aws_secret=None, **kwargs): 489 | super().__init__(aws_key, aws_secret, **kwargs) 490 | -------------------------------------------------------------------------------- /shub_workflow/utils/gcstorage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import logging 4 | from typing import Generator, List 5 | 6 | from pkg_resources import resource_filename 7 | 8 | from google.cloud import storage 9 | 10 | 11 | _GS_FOLDER_RE = re.compile(r"gs://([-\w]+)/(.*)$") 12 | 13 | _LOGGER = logging.getLogger(__name__) 14 | 15 | 16 | def get_credfile_path(module, resource, check_exists=True): 17 | credfile = resource_filename(module, resource) 18 | if not check_exists or os.path.exists(credfile): 19 | return credfile 20 | 21 | 22 | def set_credential_file_environ(module, resource, check_exists=True): 23 | credfile = get_credfile_path(module, resource, check_exists) 24 | 25 | assert credfile, "Google application credentials file does not exist." 26 | os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credfile 27 | 28 | 29 | def upload_file(src_path: str, dest_path: str): 30 | storage_client = storage.Client() 31 | m = _GS_FOLDER_RE.match(dest_path) 32 | if m is None: 33 | raise ValueError(f"Invalid destination {dest_path}") 34 | bucket_name, destination_blob_name = m.groups() 35 | bucket = storage_client.bucket(bucket_name) 36 | blob = bucket.blob(destination_blob_name) 37 | blob.upload_from_filename(src_path, retry=storage.retry.DEFAULT_RETRY) 38 | _LOGGER.info(f"File {src_path} uploaded to {dest_path}.") 39 | 40 | 41 | def list_path(path: str) -> Generator[str, None, None]: 42 | storage_client = storage.Client() 43 | m = _GS_FOLDER_RE.match(path) 44 | if m is None: 45 | raise ValueError(f"Invalid path {path} for GCS.") 46 | bucket_name, blob_prefix = m.groups() 47 | bucket = storage_client.bucket(bucket_name) 48 | 49 | for blob in bucket.list_blobs(prefix=blob_prefix, retry=storage.retry.DEFAULT_RETRY): 50 | yield f"gs://{bucket_name}/{blob.name}" 51 | 52 | 53 | def list_folder(path: str) -> List[str]: 54 | storage_client = storage.Client() 55 | m = _GS_FOLDER_RE.match(path) 56 | if m is None: 57 | raise ValueError(f"Invalid path {path} for GCS.") 58 | bucket_name, blob_prefix = m.groups() 59 | bucket = storage_client.bucket(bucket_name) 60 | 61 | start_offset = "" 62 | new_results = True 63 | result = [] 64 | while new_results: 65 | new_results = False 66 | for blob in bucket.list_blobs(prefix=blob_prefix, retry=storage.retry.DEFAULT_RETRY, start_offset=start_offset): 67 | new_results = True 68 | if blob_prefix: 69 | suffix = blob.name.split(blob_prefix, 1)[1] 70 | else: 71 | suffix = blob.name 72 | if "/" not in suffix.lstrip("/"): 73 | result.append(f"gs://{bucket_name}/{blob.name}") 74 | else: 75 | folder = (blob_prefix + 
suffix.split("/")[0]).lstrip("/") 76 | result.append(f"gs://{bucket_name}/{folder}/") 77 | start_offset = folder + "0" 78 | break 79 | else: 80 | break 81 | return result 82 | 83 | 84 | def download_file(path: str, dest: str): 85 | storage_client = storage.Client() 86 | m = _GS_FOLDER_RE.match(path) 87 | if m: 88 | bucket_name, blob_name = m.groups() 89 | else: 90 | raise ValueError(f"Invalid path {path} for GCS.") 91 | bucket = storage_client.bucket(bucket_name) 92 | blob = bucket.blob(blob_name) 93 | with open(dest, "wb") as w: 94 | blob.download_to_file(w, retry=storage.retry.DEFAULT_RETRY) 95 | _LOGGER.info(f"File {path} downloaded to {dest}.") 96 | 97 | 98 | def rm_file(path: str): 99 | storage_client = storage.Client() 100 | m = _GS_FOLDER_RE.match(path) 101 | if m: 102 | bucket_name, blob_name = m.groups() 103 | else: 104 | raise ValueError(f"Invalid path {path} for GCS.") 105 | bucket = storage_client.bucket(bucket_name) 106 | blob = bucket.blob(blob_name) 107 | blob.delete(retry=storage.retry.DEFAULT_RETRY) 108 | _LOGGER.info(f"Deleted {path}.") 109 | 110 | 111 | def mv_file(src: str, dest: str): 112 | storage_client = storage.Client() 113 | m = _GS_FOLDER_RE.match(src) 114 | assert m is not None, "Source must be in the format gs:///" 115 | src_bucket_name, src_blob_name = m.groups() 116 | 117 | m = _GS_FOLDER_RE.match(dest) 118 | assert m is not None, "Destination must be in the format gs:///" 119 | dest_bucket_name, dest_blob_name = m.groups() 120 | 121 | assert src_bucket_name == dest_bucket_name, "Source and destination bucket must be the same." 122 | 123 | bucket = storage_client.bucket(src_bucket_name) 124 | bucket.rename_blob(bucket.blob(src_blob_name), dest_blob_name) 125 | _LOGGER.info(f"Moved {src} to {dest}") 126 | 127 | 128 | def cp_file(src: str, dest: str): 129 | storage_client = storage.Client() 130 | m = _GS_FOLDER_RE.match(src) 131 | assert m is not None, "Source must be in the format gs:///" 132 | src_bucket_name, src_blob_name = m.groups() 133 | 134 | m = _GS_FOLDER_RE.match(dest) 135 | assert m is not None, "Destination must be in the format gs:///" 136 | dest_bucket_name, dest_blob_name = m.groups() 137 | 138 | assert src_bucket_name == dest_bucket_name, "Source and destination bucket must be the same." 
139 | 140 | bucket = storage_client.bucket(src_bucket_name) 141 | dest_bucket = storage_client.bucket(dest_bucket_name) 142 | bucket.copy_blob(bucket.blob(src_blob_name), dest_bucket, dest_blob_name) 143 | _LOGGER.info(f"Copied {src} to {dest}") 144 | 145 | 146 | def exists(src: str) -> bool: 147 | storage_client = storage.Client() 148 | m = _GS_FOLDER_RE.match(src) 149 | assert m is not None, "Source must be in the format gs:///" 150 | src_bucket_name, src_blob_name = m.groups() 151 | 152 | bucket = storage_client.bucket(src_bucket_name) 153 | return bucket.blob(src_blob_name).exists() 154 | 155 | 156 | def get_object(src: str): 157 | storage_client = storage.Client() 158 | m = _GS_FOLDER_RE.match(src) 159 | assert m is not None, "Source must be in the format gs:///" 160 | src_bucket_name, src_blob_name = m.groups() 161 | 162 | bucket = storage_client.bucket(src_bucket_name) 163 | return bucket.blob(src_blob_name) 164 | 165 | 166 | def get_file(src: str, *args, **kwargs): 167 | return get_object(src).open(*args, **kwargs) 168 | -------------------------------------------------------------------------------- /shub_workflow/utils/monitor.py: -------------------------------------------------------------------------------- 1 | import re 2 | import abc 3 | import time 4 | import logging 5 | import inspect 6 | import csv 7 | from io import StringIO 8 | from typing import Dict, Type, Tuple, Optional, Protocol, Union 9 | from datetime import timedelta, datetime 10 | from collections import Counter 11 | 12 | import dateparser 13 | from scrapy import Spider 14 | from prettytable import PrettyTable 15 | 16 | from shub_workflow.script import BaseScript, BaseScriptProtocol, SpiderName, JobDict 17 | from shub_workflow.utils.alert_sender import AlertSenderMixin 18 | from shub_workflow.contrib.slack import SlackSender 19 | 20 | LOG = logging.getLogger(__name__) 21 | 22 | 23 | def _get_number(txt: str) -> Optional[int]: 24 | try: 25 | return int(txt) 26 | except ValueError: 27 | return None 28 | 29 | 30 | class BaseMonitorProtocol(BaseScriptProtocol, Protocol): 31 | 32 | @abc.abstractmethod 33 | def close(self): 34 | pass 35 | 36 | 37 | BASE_TARGET_SPIDER_STATS = ( 38 | "downloader/response_status_count/", 39 | "downloader/response_count", 40 | "item_scraped_count", 41 | "spider_exceptions/", 42 | "scrapy-zyte-api/429", 43 | "zyte_api_proxy/response/status/429", 44 | ) 45 | 46 | 47 | RESPONSE_STATUS_COUNT_RE = re.compile(r"downloader/response_status_count/(\d+)/(.+)") 48 | 49 | 50 | class SpiderStatsAggregatorMixin(BaseScriptProtocol): 51 | # stats aggregated from spiders. A tuple of stats regex prefixes. 52 | target_spider_stats: Tuple[str, ...] 
= () 53 | stats_only_total = True 54 | 55 | def aggregate_spider_stats(self, jobdict: JobDict, stats_added_prefix: str = ""): 56 | canonical = self.get_canonical_spidername(jobdict["spider"]) 57 | for statkey in jobdict.get("scrapystats") or {}: 58 | for statnameprefix in self.target_spider_stats + BASE_TARGET_SPIDER_STATS: 59 | if re.match(statnameprefix, statkey) is not None: 60 | value = jobdict["scrapystats"][statkey] 61 | if stats_added_prefix != canonical: 62 | if not self.stats_only_total: 63 | self.stats.inc_value(f"{stats_added_prefix}/{statkey}/{canonical}".strip("/"), value) 64 | self.stats.inc_value(f"{stats_added_prefix}/{statkey}/total".strip("/"), value) 65 | else: 66 | self.stats.inc_value(f"{stats_added_prefix}/{statkey}".strip("/"), value) 67 | 68 | 69 | class BaseMonitor(AlertSenderMixin, SpiderStatsAggregatorMixin, BaseScript, BaseMonitorProtocol): 70 | 71 | # a map from spiders classes to check, to a stats prefix to identify the aggregated stats. 72 | target_spider_classes: Dict[Type[Spider], str] = {Spider: ""} 73 | 74 | # - a map from script name into a tuple of 2-elem tuples (aggregating stat regex, aggregated stat prefix) 75 | # - the aggregating stat regex is used to match stat on target script 76 | # - the aggregated stat prefix is used to generate the monitor stat. The original stat name is appended to 77 | # the prefix. 78 | # - if a group is present in the regex, its value is used as suffix of the generate stat, instead of 79 | # the complete original stat name. 80 | target_script_stats: Dict[str, Tuple[Tuple[str, str], ...]] = {} 81 | 82 | # - a map from script name into a tuple of 2-elem tuples (aggregating log regex, aggregated stat name) 83 | # - the aggregating log regex must match log lines in the target script job, with a first group to extract a number 84 | # from it. If not a group number is extracted, the match alone aggregates 1. 85 | # - the aggregated stat name is the name where to aggregate the number extracted by the regex, plus a second group, 86 | # if exists. 87 | 88 | # - a map from script name into a tuple of 2-elem tuples (aggregating log regex, aggregated stat name) 89 | # - the aggregating log regex must match log lines in the target script job, with a group to extract a number 90 | # from it. If not a group number is extracted, the match alone aggregates 1. 91 | # - the final aggregated stat name is the specified aggregated stat name, plus a second non numeric group in the 92 | # match, if exists. 93 | target_script_logs: Dict[str, Tuple[Tuple[str, str], ...]] = {} 94 | 95 | # Define here ratios computations. Each 3-tuple contains: 96 | # - a string or regex that matches the numerator stat. You can use a regex group here for computing multiple ratios 97 | # - a string or regex that matches the denominator stat. You can use a regex group for matching numerator group and 98 | # have a different denominator for each one. If no regex group in denominator, there will be a single 99 | # denominator. 100 | # - the target stat. If there are numerator groups, this will be the target stat prefix. 101 | # Ratios computing are performed after the stats_postprocessing() method is called. It is a stat postprocessing 102 | # itself. 103 | stats_ratios: Tuple[Tuple[str, str, str], ...] 
= () 104 | 105 | # A tuple of 2-elem tuples each one with a stat regex and the name of the monitor instance method that will receive: 106 | # - start and end limit of the window (epoch seconds) 107 | # - the value of the stat 108 | # - one extra argument per regex group, if any. 109 | # Useful for adding monitor alerts or any kind of reaction according to stat value. 110 | stats_hooks: Tuple[Tuple[str, str], ...] = () 111 | 112 | # if True, only generate the totals and not per crawler stats. This has effect on default stats only, not the 113 | # custom ones added by the developer. 114 | stats_only_total = False 115 | 116 | # A tuple of string tuples for generating a report table. 117 | # The first line is the header. Following ones are the rows. Ensure that all tuples has the same length. 118 | report_table: Tuple[Tuple[str, ...], ...] = () 119 | 120 | # additional projects in multiproject projects 121 | additional_projects: Tuple[int, ...] = () 122 | 123 | def add_argparser_options(self): 124 | super().add_argparser_options() 125 | self.argparser.add_argument( 126 | "--period", "-p", type=int, default=86400, help="Time window period in seconds. Default: %(default)s" 127 | ) 128 | self.argparser.add_argument( 129 | "--end-time", 130 | "-e", 131 | type=str, 132 | help="""End side of the time window. By default it is just now. Format: any string 133 | that can be recognized by dateparser.""", 134 | ) 135 | self.argparser.add_argument( 136 | "--start-time", 137 | "-s", 138 | type=str, 139 | help=""" 140 | Starting side of the time window. By default it is the end side minus the period. Format: 141 | any string that can be recognized by dateparser. 142 | """, 143 | ) 144 | self.argparser.add_argument( 145 | "--generate-report", 146 | action="store_true", 147 | help=("Generate report table. See report_table attribute and generate_report() method."), 148 | ) 149 | self.argparser.add_argument( 150 | "--report-format", 151 | choices=("csv", "pretty", "pretty_with_borders", "pretty_with_tabs"), 152 | default="pretty", 153 | help="'pretty_with_tabs' is suitable for easy copy and paste over a spreadsheet.", 154 | ) 155 | self.argparser.add_argument( 156 | "--slack-report", action="store_true", help="Send report to slack. By default it just prints it in the log." 157 | ) 158 | 159 | def stats_postprocessing(self, start_limit, end_limit): 160 | """ 161 | Apply here any additional code for post processing stats, generate derivated stats, reports, etc. 162 | """ 163 | 164 | def generate_report(self): 165 | """ 166 | Facilitates the generation of printed/slack reports. 167 | """ 168 | header: Tuple[str, ...] = self.report_table[0] 169 | rows: Tuple[Tuple[str, ...], ...] 
= self.report_table[1:] 170 | 171 | table = PrettyTable( 172 | field_names=header, 173 | border=self.args.report_format == "pretty_with_borders", 174 | preserve_internal_border=self.args.report_format == "pretty_with_tabs", 175 | ) 176 | for row in rows: 177 | table.add_row(list(row)) 178 | 179 | if self.args.report_format == "csv": 180 | fp = StringIO() 181 | w = csv.writer(fp) 182 | w.writerow(table.field_names) 183 | w.writerows(table.rows) 184 | fp.seek(0) 185 | table_text = fp.read() 186 | fp.close() 187 | elif self.args.report_format == "pretty_with_tabs": 188 | table_text = str(table).replace(" | ", " \t ") 189 | else: 190 | table_text = str(table) 191 | 192 | if self.args.slack_report: 193 | table_text = "```" + table_text + "```" 194 | sender = SlackSender(self.project_settings) 195 | sender.send_slack_messages( 196 | [table_text], 197 | self.args.subject or "Stats Report", 198 | ) 199 | # avoid slack rate limit error when sending the rates alert to slack 200 | time.sleep(30) 201 | else: 202 | print(table_text) 203 | 204 | def run(self): 205 | end_limit = time.time() 206 | if self.args.end_time is not None and (dt := dateparser.parse(self.args.end_time)) is not None: 207 | end_limit = dt.timestamp() 208 | start_limit = end_limit - self.args.period 209 | if self.args.start_time and (dt := dateparser.parse(self.args.start_time)) is not None: 210 | start_limit = dt.timestamp() 211 | else: 212 | LOG.info(f"Period set: {timedelta(seconds=self.args.period)}") 213 | LOG.info(f"Start time: {str(datetime.fromtimestamp(start_limit))}") 214 | LOG.info(f"End time: {str(datetime.fromtimestamp(end_limit))}") 215 | 216 | for attr in dir(self): 217 | if attr.startswith("check_"): 218 | method = getattr(self, attr) 219 | if inspect.ismethod(method): 220 | check_name = attr.replace("check_", "") 221 | LOG.info(f"Checking {check_name}...") 222 | method(start_limit, end_limit) 223 | 224 | self.stats_postprocessing(start_limit, end_limit) 225 | self.run_stats_ratios() 226 | if self.args.generate_report: 227 | self.generate_report() 228 | self.run_stats_hooks(start_limit, end_limit) 229 | self.upload_stats() 230 | self.print_stats() 231 | self.close() 232 | 233 | def run_stats_ratios(self): 234 | """ 235 | Generate new ratio stats, based on definitions in stats_ratios attribute. 
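        A minimal illustration (the ratio name "items_per_response" is an example chosen here,
        not something shub-workflow defines; the source stats come from the default
        BASE_TARGET_SPIDER_STATS aggregation):

            stats_ratios = (
                (r"item_scraped_count/(.+)", r"downloader/response_count/(.+)", "items_per_response"),
            )

        The numerator regex group selects the per-spider suffix, the denominator entry matching
        the same suffix is used for that ratio, and the rounded quotient is stored as
        "items_per_response/<suffix>".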
236 | """ 237 | 238 | stats_ratios = self.stats_ratios 239 | 240 | # append response status ratios 241 | for prefix in self.target_spider_classes.values(): 242 | if prefix: 243 | prefix = prefix + "/" 244 | for stat, value in list(self.stats.get_stats().items()): 245 | if (m := RESPONSE_STATUS_COUNT_RE.search(stat)) is not None and stat.startswith(prefix): 246 | status_code, spider = m.groups() 247 | stats_ratios += (( 248 | f"{prefix}downloader/response_status_count/{status_code}/{spider}", 249 | f"{prefix}downloader/response_count/{spider}", 250 | f"{prefix}downloader/response_count/rate/{status_code}/{spider}" 251 | ),) 252 | 253 | for num_regex, den_regex, target_prefix in stats_ratios: 254 | numerators: Dict[str, int] = Counter() 255 | denominators: Dict[str, int] = Counter() 256 | numerator_has_groups = False 257 | denominator_has_groups = False 258 | for stat, value in list(self.stats.get_stats().items()): 259 | if m := re.search(num_regex, stat): 260 | if m.groups(): 261 | source = m.groups()[0] 262 | numerators[source] += value 263 | numerator_has_groups = True 264 | else: 265 | numerators[stat] += value 266 | if m := re.search(den_regex, stat): 267 | if m.groups(): 268 | source = m.groups()[0] 269 | denominators[source] += value 270 | denominator_has_groups = True 271 | else: 272 | denominators[stat] += value 273 | for source, numer in numerators.items(): 274 | denominator = 0 275 | if denominator_has_groups: 276 | denominator = denominators.pop(source, 0) 277 | elif denominators: 278 | denominator = list(denominators.values())[0] 279 | denominators = {} 280 | if denominator > 0: 281 | target_stat = target_prefix 282 | if numerator_has_groups: 283 | target_stat += "/" + source 284 | self.stats.set_value(target_stat, round(numer / denominator, 4)) 285 | for source, denom in denominators.items(): 286 | target_stat = target_prefix 287 | if numerator_has_groups: 288 | target_stat += "/" + source 289 | self.stats.set_value(target_stat, round(0, 4)) 290 | 291 | def run_stats_hooks(self, start_limit, end_limit): 292 | for stat, val in self.stats.get_stats().items(): 293 | for stat_re, hook_name in self.stats_hooks: 294 | if (m := re.search(stat_re, stat)) is not None: 295 | hook = getattr(self, hook_name) 296 | hook(start_limit, end_limit, self.stats.get_value(stat), *m.groups()) 297 | 298 | def _get_stats_prefix_from_spider_class(self, spiderclass: Type[Spider]) -> str: 299 | for cls, prefix in self.target_spider_classes.items(): 300 | if issubclass(spiderclass, cls): 301 | return prefix 302 | return "" 303 | 304 | def spider_job_hook(self, jobdict: JobDict): 305 | """ 306 | This is called for every spider job retrieved, from the spiders declared on target_spider_classes, 307 | so additional per job customization can be added. 
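        A minimal sketch of an override (the stat name is hypothetical; "close_reason" is one
        of the meta fields requested by check_spiders()):

            def spider_job_hook(self, jobdict: JobDict):
                self.stats.inc_value(f"jobs/close_reason/{jobdict['close_reason']}", 1)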
308 | """ 309 | 310 | def get_jobs_in_window(self, start_limit: int, end_limit: Union[int, float, None], **kwargs): 311 | end_limit = end_limit or float("inf") 312 | for project_id in (self.project_id,) + self.additional_projects: 313 | for jobdict in self.get_jobs(project_id=project_id, **kwargs): 314 | if "finished_time" in jobdict and jobdict["finished_time"] / 1000 < start_limit: 315 | break 316 | if "finished_time" in jobdict and jobdict["finished_time"] / 1000 > end_limit: 317 | continue 318 | yield jobdict 319 | 320 | def check_spiders(self, start_limit, end_limit): 321 | 322 | spiders: Dict[SpiderName, Type[Spider]] = {} 323 | for s in self.spider_loader.list(): 324 | subclass = self.spider_loader.load(s) 325 | if issubclass(subclass, tuple(self.target_spider_classes.keys())): 326 | spiders[SpiderName(s)] = subclass 327 | 328 | for jobcount, jobdict in enumerate( 329 | self.get_jobs_in_window( 330 | start_limit, 331 | end_limit, 332 | state=["finished"], 333 | meta=["spider", "finished_time", "scrapystats", "spider_args", "close_reason", "tags"], 334 | has_tag=[f"FLOW_ID={self.flow_id}"] if self.flow_id is not None else None, 335 | endts=int(end_limit * 1000), 336 | ), 337 | start=1, 338 | ): 339 | if jobdict["spider"] in spiders: 340 | self.spider_job_hook(jobdict) 341 | stats_added_prefix = self._get_stats_prefix_from_spider_class(spiders[jobdict["spider"]]) 342 | self.aggregate_spider_stats(jobdict, stats_added_prefix) 343 | if jobcount % 1000 == 0: 344 | LOG.info(f"Read {jobcount} jobs") 345 | 346 | def script_job_hook(self, jobdict: JobDict): 347 | """ 348 | This is called for every script job retrieved, from the scripts declared on target_script_stats, 349 | so additional per job customization can be added 350 | """ 351 | 352 | def check_scripts_stats(self, start_limit, end_limit): 353 | for script, regexes in self.target_script_stats.items(): 354 | plural = script.replace("py:", "").replace(".py", "") + "s" 355 | LOG.info(f"Checking {plural} stats ...") 356 | for jobdict in self.get_jobs_in_window( 357 | start_limit, 358 | end_limit, 359 | spider=script, 360 | state=["finished"], 361 | meta=["finished_time", "scrapystats", "close_reason", "job_cmd", "tags"], 362 | endts=int(end_limit * 1000), 363 | has_tag=[f"FLOW_ID={self.flow_id}"] if self.flow_id is not None else None, 364 | ): 365 | self.script_job_hook(jobdict) 366 | for key, val in jobdict.get("scrapystats", {}).items(): 367 | for regex, prefix in regexes: 368 | if (m := re.search(regex, key)) is not None: 369 | aggregated_stat_name = m.groups()[0] if m.groups() else m.group() 370 | self.stats.inc_value(f"{prefix}/{aggregated_stat_name}".strip("/"), val) 371 | if prefix: 372 | self.stats.inc_value(f"{prefix}/total".strip("/"), val) 373 | 374 | def check_script_logs(self, start_limit, end_limit): 375 | for script, regexes in self.target_script_logs.items(): 376 | plural = script.replace("py:", "").replace(".py", "") + "s" 377 | LOG.info(f"Checking {plural} logs ...") 378 | for jobdict in self.get_jobs_in_window( 379 | start_limit, 380 | None, 381 | spider=script, 382 | state=["running", "finished"], 383 | meta=["state", "finished_time"], 384 | has_tag=[f"FLOW_ID={self.flow_id}"] if self.flow_id is not None else None, 385 | ): 386 | if jobdict["state"] == "running" or jobdict["finished_time"] / 1000 > start_limit: 387 | job = self.get_job(jobdict["key"]) 388 | for logline in job.logs.iter(): 389 | if logline["time"] / 1000 < start_limit: 390 | continue 391 | if logline["time"] / 1000 > end_limit: 392 | break 393 | for 
regex, stat in regexes: 394 | if (m := re.search(regex, logline["message"])) is not None: 395 | stat_suffix = "" 396 | if gr := m.groups(): 397 | val = _get_number(gr[0]) 398 | if val is None: 399 | if len(gr) > 1: 400 | val = _get_number(gr[1]) 401 | stat_suffix = gr[0] 402 | if val is None: 403 | val = 1 404 | else: 405 | stat_suffix = gr[1] if len(gr) > 0 else "" 406 | else: 407 | val = 1 408 | self.stats.inc_value(stat, val) 409 | if stat_suffix: 410 | self.stats.inc_value(stat + f"/{stat_suffix}", val) 411 | 412 | def close(self): 413 | self.send_messages() 414 | -------------------------------------------------------------------------------- /shub_workflow/utils/sesemail.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from email.message import Message 4 | from email.mime.base import MIMEBase 5 | from email.mime.multipart import MIMEMultipart 6 | from email.mime.application import MIMEApplication 7 | from email.mime.text import MIMEText 8 | from email.mime.image import MIMEImage 9 | from typing import List, Optional, Dict 10 | 11 | import boto3 12 | from botocore.client import Config 13 | 14 | from shub_workflow.script import BaseScriptProtocol 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class SESHelper: 20 | 21 | DEFAULT_SES_REGION = "us-east-1" 22 | 23 | DEFAULT_EMAIL_SUBJECT_PREFIX = "[Zyte]" 24 | DEFAULT_EMAIL_SUBJECT = "Notification from Zyte" 25 | DEFAULT_FROM_ADDR: str 26 | 27 | def __init__(self, aws_key: str, aws_secret: str, aws_region: Optional[str] = None): 28 | self.aws_key = aws_key 29 | self.aws_secret = aws_secret 30 | assert self.aws_key and self.aws_secret, "SES Credentials not set." 31 | self.aws_region = aws_region or self.DEFAULT_SES_REGION 32 | 33 | def send_ses_email( 34 | self, 35 | from_addr: str, 36 | to_addrs: List[str], 37 | msg: Message, 38 | region: str = DEFAULT_SES_REGION, 39 | cc_addrs: Optional[List[str]] = None, 40 | bcc_addrs: Optional[List[str]] = None, 41 | reply_to: Optional[str] = None, 42 | ) -> Dict: 43 | 44 | config = Config(connect_timeout=60, retries={"max_attempts": 20}) 45 | client = boto3.client( 46 | "ses", region, aws_access_key_id=self.aws_key, aws_secret_access_key=self.aws_secret, config=config 47 | ) 48 | logger.info("Sending mail as %s to: %s", from_addr, to_addrs) 49 | msg["From"] = from_addr or self.DEFAULT_FROM_ADDR 50 | msg["To"] = ",".join(to_addrs) 51 | if cc_addrs: 52 | msg["cc"] = ",".join(cc_addrs) 53 | logger.info(f"CC to {cc_addrs}") 54 | if bcc_addrs: 55 | logger.info(f"BCC to {bcc_addrs}") 56 | if reply_to: 57 | msg["Reply-To"] = reply_to 58 | 59 | destinations = to_addrs + (cc_addrs or []) + (bcc_addrs or []) 60 | response = client.send_raw_email( 61 | Source=from_addr, Destinations=destinations, RawMessage={"Data": msg.as_string()} 62 | ) 63 | return response 64 | 65 | def build_email_message( 66 | self, 67 | body: str, 68 | image_attachments: Optional[List[str]] = None, 69 | text_attachments: Optional[List[str]] = None, 70 | other_attachments: Optional[List[Message]] = None, 71 | subject_prefix: Optional[str] = None, 72 | subject: Optional[str] = None, 73 | ) -> Message: 74 | subject_prefix = (subject_prefix or self.DEFAULT_EMAIL_SUBJECT_PREFIX).strip() 75 | subject_header = "" 76 | if subject_prefix: 77 | subject_header += subject_prefix + " " 78 | subject_header += subject or self.DEFAULT_EMAIL_SUBJECT 79 | 80 | msg = MIMEMultipart() 81 | msg["Subject"] = subject_header 82 | 83 | msg.attach(MIMEText(body)) 84 | 85 | for 
imgpath in image_attachments or []: 86 | imgatt = MIMEImage(open(imgpath, "rb").read()) 87 | imgatt.add_header("Content-Disposition", "attachment", filename=os.path.basename(imgpath)) 88 | msg.attach(imgatt) 89 | 90 | for path in text_attachments or []: 91 | textatt: MIMEBase 92 | if path.endswith(".gz"): 93 | textatt = MIMEApplication(open(path, "rb").read(), "gzip") 94 | else: 95 | textatt = MIMEText(open(path, "r").read()) 96 | textatt.add_header("Content-Disposition", "attachment", filename=os.path.basename(path)) 97 | msg.attach(textatt) 98 | 99 | for att in other_attachments or []: 100 | msg.attach(att) 101 | 102 | return msg 103 | 104 | 105 | class SESMailSenderMixin(BaseScriptProtocol): 106 | """Use this mixin for enabling ses email sending capabilities on your script class""" 107 | 108 | def __init__(self): 109 | self.notification_emails: List[str] = [] 110 | self.cc_emails: List[str] = [] 111 | self.bcc_emails: List[str] = [] 112 | super().__init__() 113 | self.seshelper = None 114 | try: 115 | self.seshelper = SESHelper( 116 | self.project_settings["AWS_EMAIL_ACCESS_KEY"], self.project_settings["AWS_EMAIL_SECRET_KEY"] 117 | ) 118 | except AssertionError: 119 | logger.warning("No SES credentials set. No mails will be sent.") 120 | 121 | def send_ses_email( 122 | self, 123 | body: str, 124 | subject: str, 125 | text_attachments=None, 126 | image_attachments=None, 127 | ): 128 | if self.notification_emails and self.seshelper is not None: 129 | msg = self.seshelper.build_email_message( 130 | body, 131 | text_attachments=text_attachments, 132 | image_attachments=image_attachments, 133 | subject=subject, 134 | ) 135 | self.seshelper.send_ses_email( 136 | "noreply@zyte.com", self.notification_emails, msg, cc_addrs=self.cc_emails, bcc_addrs=self.bcc_emails 137 | ) 138 | -------------------------------------------------------------------------------- /shub_workflow/utils/watchdog.py: -------------------------------------------------------------------------------- 1 | """ 2 | Watchdog script. 3 | 4 | It checks for failed scripts, and send alerts in case issues are detected. Optionally it can clone 5 | failed scripts, typically workflow managers and other standalone scripts. Scripts handled by workflow managers 6 | should not be retried here. 7 | 8 | Spiders are handled by crawl managers, so they are not handled here. 9 | 10 | Script configurable attributes: 11 | 12 | MONITORED_SCRIPTS - A tuple containing all scripts to check (each one in the format "py:scriptname.py") 13 | CLONE_SCRIPTS - A tuple containing all scripts that must be cloned in case of failed. This tuple must be 14 | a subset of MONITORED_SCRIPTS. 15 | 16 | 17 | """ 18 | 19 | import time 20 | import abc 21 | import logging 22 | from typing import List, Tuple 23 | 24 | from shub_workflow.script import JobKey, Outcome, JobDict 25 | from shub_workflow.utils.clone_job import BaseClonner 26 | 27 | 28 | WATCHDOG_CHECKED_TAG = "WATCHDOG_CHECKED=True" 29 | LOGGER = logging.getLogger(__name__) 30 | LOGGER.setLevel(logging.INFO) 31 | 32 | 33 | class WatchdogBaseScript(BaseClonner): 34 | 35 | MONITORED_SCRIPTS: Tuple[str, ...] = () 36 | CLONE_SCRIPTS: Tuple[str, ...] 
= () 37 | DEFAULT_SPIDER_MAX_RUNNING_TIME_SECS: int = 3600 * 24 * 365 38 | CHECK_RUNNING_SPIDERS = False 39 | 40 | def __init__(self) -> None: 41 | super().__init__() 42 | self.failed_jobs: List[Tuple[str, JobKey, Outcome]] = [] 43 | self.__notification_lines: List[str] = [] 44 | 45 | def add_argparser_options(self): 46 | super().add_argparser_options() 47 | self.argparser.add_argument("period", help="How much go to past to check (in hours)", type=int) 48 | 49 | def append_notification_line(self, line: str): 50 | self.__notification_lines.append(line) 51 | 52 | def get_notification_lines(self) -> List[str]: 53 | return self.__notification_lines 54 | 55 | def run(self) -> None: 56 | self.check_failed_scripts() 57 | if self.CHECK_RUNNING_SPIDERS: 58 | self.check_running_spider_jobs() 59 | self.close() 60 | 61 | def check_failed_scripts(self) -> None: 62 | now = time.time() 63 | for script in self.MONITORED_SCRIPTS: 64 | count = 0 65 | for job in self.get_jobs( 66 | spider=script, 67 | state=["finished"], 68 | meta=["finished_time", "close_reason"], 69 | lacks_tag=WATCHDOG_CHECKED_TAG, 70 | ): 71 | age = (now - job["finished_time"] / 1000) / 3600 72 | if age > self.args.period: 73 | break 74 | count += 1 75 | if job["close_reason"] != "finished": 76 | msg = ( 77 | f"Failed task: {script} (job: https://app.zyte.com/p/{job['key']}) " 78 | f"(Reason: {job['close_reason']})" 79 | ) 80 | LOGGER.error(msg) 81 | new_job = None 82 | if script in self.CLONE_SCRIPTS: 83 | new_job = self.clone_job(job["key"]) 84 | if new_job is None: 85 | self.append_notification_line(msg) 86 | self.add_job_tags(job["key"], tags=[WATCHDOG_CHECKED_TAG]) 87 | LOGGER.info(f"Checked {count} {script} jobs") 88 | 89 | def get_spider_job_max_running_time(self, job: JobDict) -> int: 90 | """ 91 | Return max running job time in seconds 92 | """ 93 | return self.DEFAULT_SPIDER_MAX_RUNNING_TIME_SECS 94 | 95 | def check_running_spider_jobs(self) -> None: 96 | now = time.time() 97 | for job in self.get_jobs( 98 | state=["running"], meta=["spider", "spider_args", "running_time", "job_settings"] 99 | ): 100 | if job["spider"].startswith("py:"): 101 | continue 102 | max_seconds = self.get_spider_job_max_running_time(job) 103 | if max_seconds: 104 | running_time = now - job["running_time"] / 1000 105 | if running_time > int(max_seconds) * 1.1: 106 | self.finish(job["key"], close_reason="cancelled (watchdog)") 107 | msg = f"Cancelled job https://app.zyte.com/p/{job['key']} (running for {int(running_time)} seconds)" 108 | LOGGER.warning(msg) 109 | self.append_notification_line(msg) 110 | 111 | def close(self) -> None: 112 | if self.__notification_lines: 113 | self.send_alert() 114 | 115 | @abc.abstractmethod 116 | def send_alert(self): 117 | ... 
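# Illustrative sketch only, not part of this module: one possible concrete subclass, reusing
# SlackSender in the same way monitor.py does. The subclass name and the monitored script
# names are placeholders, and project_settings is assumed to be available from the base
# script machinery.
class ExampleProjectWatchdog(WatchdogBaseScript):
    MONITORED_SCRIPTS = ("py:crawlmanager.py", "py:deliver.py")
    CLONE_SCRIPTS = ("py:crawlmanager.py",)

    def send_alert(self):
        # imported here only to keep this illustrative sketch self-contained
        from shub_workflow.contrib.slack import SlackSender

        SlackSender(self.project_settings).send_slack_messages(
            self.get_notification_lines(), "Watchdog alerts"
        )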
118 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrapinghub/shub-workflow/bfce391181ecb99b8fcb7593278cef5af1e38719/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_base_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | from io import StringIO 3 | from unittest import TestCase 4 | from unittest.mock import patch 5 | 6 | from shub_workflow.base import WorkFlowManager 7 | from shub_workflow.utils.contexts import script_args 8 | 9 | 10 | @patch("shub_workflow.base.WorkFlowManager.get_job_tags") 11 | @patch("shub_workflow.base.WorkFlowManager._update_metadata") 12 | @patch("shub_workflow.script.BaseScript.get_sc_project_settings", new=lambda _: {}) 13 | class WorkFlowManagerTest(TestCase): 14 | def setUp(self): 15 | os.environ["SH_APIKEY"] = "ffff" 16 | os.environ["PROJECT_ID"] = "999" 17 | 18 | @patch("sys.stderr", new_callable=StringIO) 19 | def test_name_required_not_set(self, mocked_stderr, mocked_update_metadata, mocked_get_job_tags): 20 | class TestManager(WorkFlowManager): 21 | def workflow_loop(self): 22 | return True 23 | 24 | mocked_get_job_tags.side_effect = [[], []] 25 | 26 | with script_args([]): 27 | with self.assertRaises(SystemExit): 28 | TestManager() 29 | self.assertTrue("the following arguments are required: name" in mocked_stderr.getvalue()) 30 | 31 | def test_name_required_set(self, mocked_update_metadata, mocked_get_job_tags): 32 | class TestManager(WorkFlowManager): 33 | def workflow_loop(self): 34 | return True 35 | 36 | mocked_get_job_tags.side_effect = [[], []] 37 | 38 | with script_args(["my_fantasy_name"]): 39 | manager = TestManager() 40 | self.assertEqual(manager.name, "my_fantasy_name") 41 | self.assertEqual(manager.project_id, 999) 42 | self.assertEqual(manager.get_project().key, '999') 43 | 44 | @patch("shub_workflow.base.WorkFlowManager._check_resume_workflow") 45 | def test_check_resume_workflow_not_called( 46 | self, mocked_check_resume_workflow, mocked_update_metadata, mocked_get_job_tags 47 | ): 48 | class TestManager(WorkFlowManager): 49 | def workflow_loop(self): 50 | return True 51 | 52 | mocked_get_job_tags.side_effect = [[], []] 53 | 54 | with script_args(["my_fantasy_name"]): 55 | manager = TestManager() 56 | self.assertEqual(manager.name, "my_fantasy_name") 57 | 58 | manager._on_start() 59 | self.assertFalse(mocked_check_resume_workflow.called) 60 | 61 | @patch("shub_workflow.base.WorkFlowManager._check_resume_workflow") 62 | def test_check_resume_workflow_called( 63 | self, mocked_check_resume_workflow, mocked_update_metadata, mocked_get_job_tags 64 | ): 65 | class TestManager(WorkFlowManager): 66 | def workflow_loop(self): 67 | return True 68 | 69 | mocked_get_job_tags.side_effect = [[], []] 70 | 71 | with script_args(["my_fantasy_name", "--flow-id=3456"]): 72 | manager = TestManager() 73 | self.assertEqual(manager.name, "my_fantasy_name") 74 | 75 | manager._on_start() 76 | self.assertTrue(mocked_check_resume_workflow.called) 77 | 78 | def test_project_id_override(self, mocked_update_metadata, mocked_get_job_tags): 79 | class TestManager(WorkFlowManager): 80 | def workflow_loop(self): 81 | return True 82 | 83 | mocked_get_job_tags.side_effect = [[], []] 84 | 85 | with script_args(["my_fantasy_name", "--project-id=888"]): 86 | manager = TestManager() 87 | 
self.assertEqual(manager.project_id, 888) 88 | self.assertEqual(manager.get_project().key, '888') 89 | -------------------------------------------------------------------------------- /tests/test_crawl_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest import TestCase 3 | from unittest.mock import patch 4 | 5 | from shub_workflow.crawl import CrawlManager, PeriodicCrawlManager, GeneratorCrawlManager 6 | from shub_workflow.utils.contexts import script_args 7 | from shub_workflow.script import SpiderName, Outcome 8 | 9 | 10 | class TestManager(CrawlManager): 11 | 12 | name = "test" 13 | 14 | 15 | class PeriodicTestManager(PeriodicCrawlManager): 16 | 17 | name = "test" 18 | 19 | 20 | class ListTestManager(GeneratorCrawlManager): 21 | 22 | name = "test" 23 | default_max_jobs = 2 24 | 25 | def set_parameters_gen(self): 26 | parameters_list = [ 27 | {"argA": "valA"}, 28 | {"argA": "valB"}, 29 | {"argB": "valC"}, 30 | ] 31 | for args in parameters_list: 32 | yield args 33 | 34 | def bad_outcome_hook(self, spider, outcome, job_args_override, jobkey): 35 | spider_args = job_args_override.setdefault("spider_args", {}) 36 | if "argR" not in spider_args: 37 | spider_args.update({"argR": "valR"}) 38 | self.add_job(spider, job_args_override) 39 | elif "argS" not in spider_args: 40 | spider_args.update({"argS": "valS"}) 41 | self.add_job(spider, job_args_override) 42 | 43 | 44 | class TestManagerWithSpider(CrawlManager): 45 | 46 | name = "test" 47 | spider = SpiderName("myimplicitspider") 48 | 49 | 50 | @patch("shub_workflow.script.BaseScript.get_jobs") 51 | @patch("shub_workflow.script.BaseScript.add_job_tags") 52 | @patch("shub_workflow.script.BaseScript.get_sc_project_settings", new=lambda _: {}) 53 | @patch("shub_workflow.script.BaseScript.get_metadata_key", new=lambda s, m, k: {}) 54 | class CrawlManagerTest(TestCase): 55 | def setUp(self): 56 | os.environ["SH_APIKEY"] = "ffff" 57 | os.environ["PROJECT_ID"] = "999" 58 | 59 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 60 | def test_schedule_spider(self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs): 61 | 62 | with script_args(["myspider"]): 63 | manager = TestManager() 64 | 65 | mocked_super_schedule_spider.side_effect = ["999/1/1"] 66 | manager._on_start() 67 | 68 | # first loop: schedule spider 69 | result = next(manager._run_loops()) 70 | 71 | self.assertTrue(result) 72 | self.assertEqual(mocked_super_schedule_spider.call_count, 1) 73 | mocked_super_schedule_spider.assert_any_call("myspider", units=None, job_settings={}) 74 | 75 | # second loop: spider still running. Continue. 76 | manager.is_finished = lambda jobkey: None 77 | result = next(manager._run_loops()) 78 | self.assertTrue(result) 79 | 80 | # third loop: spider is finished. Stop. 
81 | manager.is_finished = lambda jobkey: Outcome("finished") if jobkey == "999/1/1" else None 82 | mocked_super_schedule_spider.reset_mock() 83 | result = next(manager._run_loops()) 84 | 85 | self.assertFalse(result) 86 | self.assertFalse(mocked_super_schedule_spider.called) 87 | 88 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 89 | def test_schedule_implicit_spider(self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs): 90 | 91 | with script_args([]): 92 | manager = TestManagerWithSpider() 93 | 94 | mocked_super_schedule_spider.side_effect = ["999/1/1"] 95 | manager._on_start() 96 | 97 | # first loop: schedule spider 98 | result = next(manager._run_loops()) 99 | 100 | self.assertTrue(result) 101 | self.assertEqual(mocked_super_schedule_spider.call_count, 1) 102 | mocked_super_schedule_spider.assert_any_call("myimplicitspider", units=None, job_settings={}) 103 | 104 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 105 | def test_schedule_spider_bad_outcome(self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs): 106 | 107 | with script_args(["myspider"]): 108 | manager = TestManager() 109 | 110 | mocked_super_schedule_spider.side_effect = ["999/1/1"] 111 | manager._on_start() 112 | 113 | # first loop: schedule spider 114 | result = next(manager._run_loops()) 115 | 116 | self.assertTrue(result) 117 | self.assertEqual(mocked_super_schedule_spider.call_count, 1) 118 | mocked_super_schedule_spider.assert_any_call("myspider", units=None, job_settings={}) 119 | 120 | # second loop: spider still running. Continue. 121 | manager.is_finished = lambda jobkey: None 122 | result = next(manager._run_loops()) 123 | self.assertTrue(result) 124 | 125 | # third loop: spider is cancelled. Stop. Manager must be closed with cancelled close reason 126 | manager.is_finished = lambda jobkey: Outcome("cancelled") if jobkey == "999/1/1" else None 127 | mocked_super_schedule_spider.reset_mock() 128 | result = next(manager._run_loops()) 129 | 130 | self.assertFalse(result) 131 | self.assertFalse(mocked_super_schedule_spider.called) 132 | 133 | self.assertEqual(manager.get_close_reason(), "cancelled") 134 | 135 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 136 | def test_schedule_spider_with_resume(self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs): 137 | with script_args(["myspider", "--flow-id=3a20"]): 138 | manager = TestManager() 139 | 140 | mocked_get_jobs_side_effect = [ 141 | # the resumed job 142 | [{"tags": ["FLOW_ID=3a20", "NAME=test", "OTHER=other"], "key": "999/10/1"}], 143 | # the owned running job 144 | [{"spider": "myspider", "key": "999/1/1", "tags": ["FLOW_ID=3a20", "PARENT_NAME=test"]}], 145 | # the owned finished jobs 146 | [], 147 | ] 148 | mocked_get_jobs.side_effect = mocked_get_jobs_side_effect 149 | manager._on_start() 150 | self.assertTrue(manager.is_resumed) 151 | self.assertEqual(len(manager._running_job_keys), 1) 152 | 153 | for v in manager._running_job_keys.values(): 154 | self.assertEqual(set(v[1].keys()), {"spider_args", "tags"}) 155 | 156 | self.assertEqual(mocked_get_jobs.call_count, len(mocked_get_jobs_side_effect)) 157 | mocked_add_job_tags.assert_any_call(tags=["FLOW_ID=3a20", "NAME=test", "OTHER=other"]) 158 | 159 | # first loop: spider still running in workflow. Continue. 160 | manager.is_finished = lambda jobkey: None 161 | result = next(manager._run_loops()) 162 | self.assertTrue(result) 163 | 164 | # second loop: spider is finished. Stop. 
165 | manager.is_finished = lambda jobkey: Outcome("finished") if jobkey == "999/1/1" else None 166 | result = next(manager._run_loops()) 167 | 168 | self.assertFalse(result) 169 | self.assertFalse(mocked_super_schedule_spider.called) 170 | 171 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 172 | def test_schedule_spider_with_resume_not_found( 173 | self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs 174 | ): 175 | with script_args(["myspider", "--flow-id=3a20"]): 176 | manager = TestManager() 177 | 178 | mocked_get_jobs_side_effect = [ 179 | # the not resumed job (different flow id) 180 | [{"tags": ["FLOW_ID=3344", "NAME=othertest"], "key": "999/10/1"}], 181 | ] 182 | mocked_get_jobs.side_effect = mocked_get_jobs_side_effect 183 | manager._on_start() 184 | self.assertFalse(manager.is_resumed) 185 | self.assertEqual(len(manager._running_job_keys), 0) 186 | self.assertEqual(mocked_get_jobs.call_count, len(mocked_get_jobs_side_effect)) 187 | 188 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 189 | def test_schedule_spider_with_resume_not_owned( 190 | self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs 191 | ): 192 | with script_args(["myspider", "--flow-id=3a20"]): 193 | manager = TestManager() 194 | 195 | mocked_get_jobs.side_effect = [ 196 | # the resumed job 197 | [{"tags": ["FLOW_ID=3a20", "NAME=test"], "key": "999/10/1"}], 198 | # the not owned job 199 | [{"key": "999/1/1", "tags": ["FLOW_ID=3a20", "PARENT_NAME=testa"]}], 200 | # owned finished jobs 201 | [], 202 | ] 203 | manager._on_start() 204 | self.assertTrue(manager.is_resumed) 205 | self.assertEqual(len(manager._running_job_keys), 0) 206 | 207 | # first loop: no spider, schedule one. 208 | manager.is_finished = lambda jobkey: None 209 | mocked_super_schedule_spider.side_effect = ["999/1/2"] 210 | result = next(manager._run_loops()) 211 | self.assertTrue(result) 212 | self.assertEqual(mocked_super_schedule_spider.call_count, 1) 213 | 214 | # second loop: spider is finished. Stop. 215 | manager.is_finished = lambda jobkey: Outcome("finished") if jobkey == "999/1/2" else None 216 | result = next(manager._run_loops()) 217 | 218 | self.assertFalse(result) 219 | 220 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 221 | def test_schedule_spider_periodic(self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs): 222 | with script_args(["myspider"]): 223 | manager = PeriodicTestManager() 224 | 225 | mocked_super_schedule_spider.side_effect = ["999/1/1"] 226 | manager._on_start() 227 | 228 | # first loop: schedule spider 229 | result = next(manager._run_loops()) 230 | 231 | self.assertTrue(result) 232 | self.assertEqual(mocked_super_schedule_spider.call_count, 1) 233 | mocked_super_schedule_spider.assert_any_call("myspider", units=None, job_settings={}) 234 | 235 | # second loop: spider still running. Continue. 236 | manager.is_finished = lambda jobkey: None 237 | result = next(manager._run_loops()) 238 | self.assertTrue(result) 239 | 240 | # third loop: spider is finished. Schedule again. 241 | manager.is_finished = lambda jobkey: Outcome("finished") if jobkey == "999/1/1" else None 242 | mocked_super_schedule_spider.reset_mock() 243 | mocked_super_schedule_spider.side_effect = ["999/1/2"] 244 | result = next(manager._run_loops()) 245 | 246 | self.assertTrue(result) 247 | mocked_super_schedule_spider.assert_any_call("myspider", units=None, job_settings={}) 248 | 249 | # four loop: spider is cancelled. Schedule again. 
250 | manager.is_finished = lambda jobkey: Outcome("cancelled") if jobkey == "999/1/2" else None 251 | mocked_super_schedule_spider.reset_mock() 252 | mocked_super_schedule_spider.side_effect = ["999/1/3"] 253 | result = next(manager._run_loops()) 254 | 255 | self.assertTrue(result) 256 | mocked_super_schedule_spider.assert_any_call("myspider", units=None, job_settings={}) 257 | 258 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 259 | def test_schedule_spider_list_bad_outcome_hook( 260 | self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs 261 | ): 262 | with script_args(["myspider"]): 263 | manager = ListTestManager() 264 | 265 | mocked_super_schedule_spider.side_effect = ["999/1/1", "999/1/2", "999/1/3", "999/1/4", "999/1/5"] 266 | manager._on_start() 267 | 268 | # first loop: schedule spider with first set of arguments 269 | result = next(manager._run_loops()) 270 | 271 | self.assertTrue(result) 272 | self.assertEqual(mocked_super_schedule_spider.call_count, 2) 273 | mocked_super_schedule_spider.assert_any_call( 274 | "myspider", units=None, argA="valA", tags=["JOBSEQ=0000000001"], job_settings={} 275 | ) 276 | mocked_super_schedule_spider.assert_any_call( 277 | "myspider", units=None, argA="valB", tags=["JOBSEQ=0000000002"], job_settings={} 278 | ) 279 | 280 | # second loop: still no job finished. Wait for a free slot 281 | manager.is_finished = lambda jobkey: None 282 | result = next(manager._run_loops()) 283 | self.assertTrue(result) 284 | self.assertEqual(mocked_super_schedule_spider.call_count, 2) 285 | 286 | # third loop: finish one job. We can schedule last one with third set of arguments 287 | manager.is_finished = lambda jobkey: Outcome("finished") if jobkey == "999/1/1" else None 288 | result = next(manager._run_loops()) 289 | self.assertTrue(result) 290 | self.assertEqual(mocked_super_schedule_spider.call_count, 3) 291 | mocked_super_schedule_spider.assert_any_call( 292 | "myspider", units=None, argB="valC", tags=["JOBSEQ=0000000003"], job_settings={} 293 | ) 294 | 295 | # fourth loop: waiting jobs to finish 296 | result = next(manager._run_loops()) 297 | self.assertTrue(result) 298 | self.assertEqual(mocked_super_schedule_spider.call_count, 3) 299 | 300 | # fifth loop: second job finished with failed outcome. Retry according 301 | # to bad outcome hook 302 | manager.is_finished = lambda jobkey: Outcome("cancelled (stalled)") if jobkey == "999/1/2" else None 303 | result = next(manager._run_loops()) 304 | self.assertTrue(result) 305 | self.assertEqual(mocked_super_schedule_spider.call_count, 4) 306 | mocked_super_schedule_spider.assert_any_call( 307 | "myspider", units=None, argA="valB", tags=["JOBSEQ=0000000002.r1"], argR="valR", job_settings={} 308 | ) 309 | 310 | # sixth loop: third job finished. 311 | manager.is_finished = lambda jobkey: Outcome("finished") if jobkey == "999/1/3" else None 312 | result = next(manager._run_loops()) 313 | self.assertTrue(result) 314 | self.assertEqual(mocked_super_schedule_spider.call_count, 4) 315 | 316 | # seventh loop: retried job failed again. Retry with new argument. 
317 | manager.is_finished = lambda jobkey: Outcome("memusage_exceeded") if jobkey == "999/1/4" else None 318 | result = next(manager._run_loops()) 319 | self.assertTrue(result) 320 | self.assertEqual(mocked_super_schedule_spider.call_count, 5) 321 | mocked_super_schedule_spider.assert_any_call( 322 | "myspider", units=None, argA="valB", tags=["JOBSEQ=0000000002.r2"], argR="valR", job_settings={} 323 | ) 324 | 325 | # eighth loop: retried job finished. Exit. 326 | manager.is_finished = lambda jobkey: Outcome("finished") if jobkey == "999/1/5" else None 327 | result = next(manager._run_loops()) 328 | self.assertFalse(result) 329 | self.assertEqual(mocked_super_schedule_spider.call_count, 5) 330 | 331 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 332 | def test_schedule_spider_list_explicit_spider( 333 | self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs 334 | ): 335 | class _ListTestManager(GeneratorCrawlManager): 336 | 337 | name = "test" 338 | default_max_jobs = 2 339 | spider = SpiderName("myspider") 340 | 341 | def set_parameters_gen(self): 342 | parameters_list = [ 343 | {"argA": "valA"}, 344 | {"argA": "valB", "spider": "myspidertwo"}, 345 | ] 346 | for args in parameters_list: 347 | yield args 348 | 349 | with script_args([]): 350 | manager = _ListTestManager() 351 | 352 | mocked_super_schedule_spider.side_effect = ["999/1/1", "999/1/2"] 353 | manager._on_start() 354 | 355 | # first loop: schedule spider with first set of arguments 356 | result = next(manager._run_loops()) 357 | 358 | self.assertTrue(result) 359 | self.assertEqual(mocked_super_schedule_spider.call_count, 2) 360 | mocked_super_schedule_spider.assert_any_call( 361 | "myspider", units=None, argA="valA", tags=["JOBSEQ=0000000001"], job_settings={} 362 | ) 363 | mocked_super_schedule_spider.assert_any_call( 364 | "myspidertwo", units=None, argA="valB", tags=["JOBSEQ=0000000002"], job_settings={} 365 | ) 366 | 367 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 368 | def test_schedule_spider_list_scrapy_cloud_params( 369 | self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs 370 | ): 371 | class _ListTestManager(GeneratorCrawlManager): 372 | 373 | name = "test" 374 | default_max_jobs = 2 375 | spider = SpiderName("myspider") 376 | 377 | def set_parameters_gen(self): 378 | parameters_list = [ 379 | { 380 | "argA": "valA", 381 | "units": 2, 382 | "tags": ["CHECKED"], 383 | "project_id": 999, 384 | "job_settings": {"CONCURRENT_REQUESTS": 2}, 385 | }, 386 | ] 387 | for args in parameters_list: 388 | yield args 389 | 390 | with script_args([]): 391 | manager = _ListTestManager() 392 | 393 | mocked_super_schedule_spider.side_effect = ["999/1/1", "999/1/2"] 394 | manager._on_start() 395 | 396 | # first loop: schedule spider with first set of arguments 397 | result = next(manager._run_loops()) 398 | 399 | self.assertTrue(result) 400 | self.assertEqual(mocked_super_schedule_spider.call_count, 1) 401 | mocked_super_schedule_spider.assert_any_call( 402 | "myspider", 403 | units=2, 404 | argA="valA", 405 | job_settings={"CONCURRENT_REQUESTS": 2}, 406 | project_id=999, 407 | tags=["CHECKED", "JOBSEQ=0000000001"], 408 | ) 409 | jobid = GeneratorCrawlManager.get_job_unique_id( 410 | {"spider": SpiderName("myspider"), "spider_args": {"argA": "valA"}} 411 | ) 412 | self.assertTrue(jobid in manager._jobuids) 413 | 414 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 415 | def test_schedule_spider_list_with_resume(self, mocked_super_schedule_spider, 
mocked_add_job_tags, mocked_get_jobs): 416 | class _ListTestManager(GeneratorCrawlManager): 417 | 418 | name = "test" 419 | default_max_jobs = 2 420 | spider = SpiderName("myspider") 421 | 422 | def set_parameters_gen(self): 423 | parameters_list = [ 424 | {"argA": "valA"}, 425 | {"argA": "valB", "spider": "myspidertwo"}, 426 | ] 427 | for args in parameters_list: 428 | yield args 429 | 430 | with script_args(["--flow-id=3a20"]): 431 | manager = _ListTestManager() 432 | 433 | mocked_get_jobs.side_effect = [ 434 | # the resumed job 435 | [{"tags": ["FLOW_ID=3a20", "NAME=test"], "key": "999/10/1"}], 436 | # running spiders 437 | [], 438 | # finished spiders 439 | [ 440 | { 441 | "spider": "myspider", 442 | "key": "999/1/1", 443 | "tags": ["FLOW_ID=3a20", "PARENT_NAME=test", "JOBSEQ=0000000001"], 444 | "spider_args": {"argA": "valA"}, 445 | } 446 | ], 447 | ] 448 | mocked_super_schedule_spider.side_effect = ["999/2/1"] 449 | manager._on_start() 450 | self.assertTrue(manager.is_resumed) 451 | self.assertEqual(len(manager._running_job_keys), 0) 452 | 453 | # first loop: only second task will be scheduled. First one already completed before resuming. 454 | manager.is_finished = lambda jobkey: None 455 | result = next(manager._run_loops()) 456 | self.assertTrue(result) 457 | self.assertEqual(mocked_super_schedule_spider.call_count, 1) 458 | mocked_super_schedule_spider.assert_any_call( 459 | "myspidertwo", argA="valB", tags=["JOBSEQ=0000000002"], job_settings={}, units=None 460 | ) 461 | 462 | # second loop: finished second spider. Finish execution 463 | manager.is_finished = lambda jobkey: Outcome("finished") if jobkey == "999/2/1" else None 464 | result = next(manager._run_loops()) 465 | self.assertEqual(mocked_super_schedule_spider.call_count, 1) 466 | self.assertFalse(result) 467 | 468 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 469 | def test_schedule_spider_list_resumed_running_job_with_bad_outcome( 470 | self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs 471 | ): 472 | class _ListTestManager(ListTestManager): 473 | 474 | name = "test" 475 | default_max_jobs = 2 476 | spider = SpiderName("myspider") 477 | 478 | def set_parameters_gen(self): 479 | parameters_list = [ 480 | {"argA": "valA"}, 481 | ] 482 | for args in parameters_list: 483 | yield args 484 | 485 | with script_args(["--flow-id=3a20"]): 486 | manager = _ListTestManager() 487 | 488 | mocked_get_jobs.side_effect = [ 489 | # the resumed job 490 | [{"tags": ["FLOW_ID=3a20", "NAME=test"], "key": "999/10/1"}], 491 | # running spiders 492 | [ 493 | { 494 | "spider": "myspider", 495 | "key": "999/1/1", 496 | "tags": ["FLOW_ID=3a20", "PARENT_NAME=test", "JOBSEQ=0000000001"], 497 | "spider_args": {"argA": "valA"}, 498 | } 499 | ], 500 | # finished spiders 501 | [], 502 | ] 503 | mocked_super_schedule_spider.side_effect = ["999/2/1"] 504 | manager._on_start() 505 | self.assertTrue(manager.is_resumed) 506 | self.assertEqual(len(manager._running_job_keys), 1) 507 | 508 | for v in manager._running_job_keys.values(): 509 | self.assertEqual(set(v[1].keys()), {"spider_args", "tags"}) 510 | 511 | # first loop: acquire running job. 512 | manager.is_finished = lambda jobkey: None 513 | result = next(manager._run_loops()) 514 | self.assertTrue(result) 515 | self.assertEqual(mocked_super_schedule_spider.call_count, 0) 516 | 517 | # second loop: second job finished with failed outcome. 
Retry according 518 | # to bad outcome hook 519 | manager.is_finished = lambda jobkey: Outcome("cancelled (stalled)") 520 | mocked_super_schedule_spider.side_effect = ["999/2/2"] 521 | result = next(manager._run_loops()) 522 | self.assertTrue(result) 523 | self.assertEqual(mocked_super_schedule_spider.call_count, 1) 524 | mocked_super_schedule_spider.assert_any_call( 525 | "myspider", 526 | units=None, 527 | argA="valA", 528 | tags=["FLOW_ID=3a20", "PARENT_NAME=test", "JOBSEQ=0000000001.r1"], 529 | argR="valR", 530 | job_settings={}, 531 | ) 532 | 533 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 534 | def test_default_bad_outcome_no_retry(self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs): 535 | class _ListTestManager(GeneratorCrawlManager): 536 | 537 | name = "test" 538 | default_max_jobs = 1 539 | spider = SpiderName("myspider") 540 | 541 | def set_parameters_gen(self): 542 | parameters_list = [ 543 | {"argA": "valA"}, 544 | ] 545 | for args in parameters_list: 546 | yield args 547 | 548 | with script_args([]): 549 | manager = _ListTestManager() 550 | 551 | mocked_super_schedule_spider.side_effect = ["999/1/1"] 552 | manager._on_start() 553 | 554 | # first loop: schedule spider 555 | result = next(manager._run_loops()) 556 | 557 | self.assertTrue(result) 558 | self.assertEqual(mocked_super_schedule_spider.call_count, 1) 559 | mocked_super_schedule_spider.assert_any_call( 560 | "myspider", units=None, argA="valA", tags=["JOBSEQ=0000000001"], job_settings={} 561 | ) 562 | 563 | # second loop: finish job with bad outcome, but there is no retry. Stop. 564 | manager.is_finished = lambda jobkey: Outcome("cancelled (stalled)") if jobkey == "999/1/1" else None 565 | result = next(manager._run_loops()) 566 | self.assertFalse(result) 567 | self.assertEqual(mocked_super_schedule_spider.call_count, 1) 568 | 569 | @patch("shub_workflow.crawl.WorkFlowManager.schedule_spider") 570 | def test_default_bad_outcome_with_retries(self, mocked_super_schedule_spider, mocked_add_job_tags, mocked_get_jobs): 571 | class _ListTestManager(GeneratorCrawlManager): 572 | 573 | name = "test" 574 | default_max_jobs = 2 575 | spider = SpiderName("myspider") 576 | 577 | MAX_RETRIES = 2 578 | 579 | def set_parameters_gen(self): 580 | parameters_list = [ 581 | {"argA": "valA"}, 582 | {"argB": "valB"}, 583 | ] 584 | for args in parameters_list: 585 | yield args 586 | 587 | with script_args([]): 588 | manager = _ListTestManager() 589 | 590 | mocked_super_schedule_spider.side_effect = ["999/1/1", "999/2/1"] 591 | manager._on_start() 592 | 593 | # first loop: schedule spiders 594 | result = next(manager._run_loops()) 595 | 596 | self.assertTrue(result) 597 | self.assertEqual(mocked_super_schedule_spider.call_count, 2) 598 | mocked_super_schedule_spider.assert_any_call( 599 | "myspider", units=None, argA="valA", tags=["JOBSEQ=0000000001"], job_settings={} 600 | ) 601 | 602 | mocked_super_schedule_spider.assert_any_call( 603 | "myspider", units=None, argB="valB", tags=["JOBSEQ=0000000002"], job_settings={} 604 | ) 605 | 606 | # second loop: finish first job with bad outcome, retry 1. 
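        # (As asserted below, the retried job reuses the same spider arguments and is
        # tagged RETRIED_FROM=<previous job key>, with the JOBSEQ tag getting an
        # incremented .rN suffix.)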
607 |         manager.is_finished = lambda jobkey: Outcome("cancelled (stalled)") if jobkey == "999/1/1" else None
608 |         mocked_super_schedule_spider.side_effect = ["999/1/2"]
609 |         result = next(manager._run_loops())
610 |         self.assertTrue(result)
611 |         self.assertEqual(mocked_super_schedule_spider.call_count, 3)
612 |         mocked_super_schedule_spider.assert_any_call(
613 |             "myspider", units=None, argA="valA", tags=["RETRIED_FROM=999/1/1", "JOBSEQ=0000000001.r1"], job_settings={}
614 |         )
615 | 
616 |         # third loop: second job finishes with "cancelled", don't retry it.
617 |         manager.is_finished = lambda jobkey: Outcome("cancelled") if jobkey == "999/2/1" else None
618 |         result = next(manager._run_loops())
619 |         self.assertTrue(result)
620 |         self.assertEqual(mocked_super_schedule_spider.call_count, 3)
621 | 
622 |         # fourth loop: first job finishes again with abnormal reason, retry it.
623 |         manager.is_finished = lambda jobkey: Outcome("cancelled (stalled)") if jobkey == "999/1/2" else None
624 |         mocked_super_schedule_spider.side_effect = ["999/1/3"]
625 |         result = next(manager._run_loops())
626 |         self.assertTrue(result)
627 |         self.assertEqual(mocked_super_schedule_spider.call_count, 4)
628 |         mocked_super_schedule_spider.assert_any_call(
629 |             "myspider", units=None, argA="valA", tags=["RETRIED_FROM=999/1/2", "JOBSEQ=0000000001.r2"], job_settings={}
630 |         )
631 | 
632 |         # fifth loop: first job finishes again with abnormal reason, but max retries reached. Stop.
633 |         manager.is_finished = lambda jobkey: Outcome("cancelled (stalled)") if jobkey == "999/1/3" else None
634 |         result = next(manager._run_loops())
635 |         self.assertFalse(result)
636 |         self.assertEqual(mocked_super_schedule_spider.call_count, 4)
637 | 
638 |         manager._close()
639 | 
--------------------------------------------------------------------------------
/tests/test_futils.py:
--------------------------------------------------------------------------------
1 | from tempfile import mktemp
2 | from unittest import TestCase
3 | 
4 | from shub_workflow.utils.futils import FSHelper
5 | 
6 | 
7 | class FUtilsTest(TestCase):
8 |     def test_fshelper(self) -> None:
9 |         helper = FSHelper()
10 |         test_file_1 = mktemp()
11 |         test_file_2 = mktemp()
12 |         helper.touch(test_file_1)
13 |         helper.mv_file(test_file_1, test_file_2)
14 |         self.assertTrue(helper.exists(test_file_2))
15 |         helper.rm_file(test_file_2)
16 | 
--------------------------------------------------------------------------------
/tests/test_typing.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import Namespace
3 | from unittest import TestCase
4 | from unittest.mock import patch
5 | 
6 | from shub_workflow.script import BaseScript, BaseScriptProtocol
7 | from shub_workflow.base import WorkFlowManager, WorkFlowManagerProtocol
8 | from shub_workflow.graph import GraphManager
9 | from shub_workflow.graph.task import Task
10 | from shub_workflow.crawl import GeneratorCrawlManager, AsyncSchedulerCrawlManagerMixin
11 | from shub_workflow.deliver import BaseDeliverScript
12 | from shub_workflow.utils.contexts import script_args
13 | 
14 | 
15 | class MyDeliverScript(BaseDeliverScript):
16 |     SCRAPERNAME_NARGS = 1
17 | 
18 | 
19 | class MyMixin(BaseScriptProtocol):
20 |     def __init__(self):
21 |         super().__init__()
22 |         self.target = 0
23 | 
24 |     def append_flow_tag(self, tag: str):
25 |         self.target = 1
26 | 
27 |     def parse_args(self) -> Namespace:
28 |         args = super().parse_args()  # type: ignore
29 |         if hasattr(args, "exec_id"):
30 | 
self.append_flow_tag(f"EXEC_ID={args.exec_id}") 31 | return args 32 | 33 | 34 | class MyWFMixin(WorkFlowManagerProtocol): 35 | def __init__(self): 36 | super().__init__() 37 | self.target = 0 38 | 39 | def append_flow_tag(self, tag: str): 40 | self.target = 1 41 | 42 | 43 | class MyScript(MyMixin, BaseScript): 44 | def run(self): 45 | pass 46 | 47 | 48 | class MyWorkFlowManager(MyMixin, WorkFlowManager): 49 | def workflow_loop(self) -> bool: 50 | return False 51 | 52 | 53 | class MyWorkFlowManagerTwo(MyWFMixin, WorkFlowManager): 54 | def workflow_loop(self) -> bool: 55 | return False 56 | 57 | 58 | class MyGraphManager(MyMixin, GraphManager): 59 | def configure_workflow(self): 60 | jobA = Task("jobA", "py:command.py") 61 | return (jobA,) 62 | 63 | 64 | class MyAsyncCrawlManager(AsyncSchedulerCrawlManagerMixin, GeneratorCrawlManager): # type: ignore 65 | def set_parameters_gen(self): 66 | yield from () 67 | 68 | 69 | @patch("shub_workflow.script.BaseScript.get_sc_project_settings", new=lambda _: {}) 70 | class TypingTest(TestCase): 71 | def setUp(self): 72 | os.environ["SH_APIKEY"] = "ffff" 73 | os.environ["PROJECT_ID"] = "999" 74 | 75 | @patch("shub_workflow.script.BaseScript.add_job_tags") 76 | def test_script_instantiation(self, mocked_add_job_tags): 77 | with script_args([]): 78 | script = MyScript() 79 | script.append_flow_tag("mytag") 80 | # self.assertEqual(mocked_add_job_tags.call_count, 0) 81 | self.assertEqual(script.target, 1) 82 | 83 | @patch("shub_workflow.script.BaseScript.add_job_tags") 84 | def test_workflow_manager_instantiation(self, mocked_add_job_tags): 85 | with script_args(["myname"]): 86 | manager = MyWorkFlowManager() 87 | manager.append_flow_tag("mytag") 88 | # self.assertEqual(mocked_add_job_tags.call_count, 0) 89 | self.assertEqual(manager.target, 1) 90 | 91 | @patch("shub_workflow.script.BaseScript.add_job_tags") 92 | def test_workflow_manager_instantiation_two(self, mocked_add_job_tags): 93 | with script_args(["myname"]): 94 | manager = MyWorkFlowManagerTwo() 95 | manager.append_flow_tag("mytag") 96 | # self.assertEqual(mocked_add_job_tags.call_count, 0) 97 | self.assertEqual(manager.target, 1) 98 | 99 | @patch("shub_workflow.script.BaseScript.add_job_tags") 100 | def test_graph_manager_instantiation(self, mocked_add_job_tags): 101 | with script_args(["myname"]): 102 | manager = MyGraphManager() 103 | manager.append_flow_tag("mytag") 104 | # self.assertEqual(mocked_add_job_tags.call_count, 0) 105 | self.assertEqual(manager.target, 1) 106 | 107 | def test_asyncrawlmanager(self): 108 | with script_args(["myname", "myspider"]): 109 | MyAsyncCrawlManager() 110 | --------------------------------------------------------------------------------
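
The tests above pin down the usage pattern of GeneratorCrawlManager: a concrete subclass declares name, spider, default_max_jobs and optionally MAX_RETRIES, and yields one parameter dict per job from set_parameters_gen(), where the reserved keys spider, units, tags, project_id and job_settings tune the scheduling call while every other key is passed through as a spider argument. The following is a minimal sketch of that usage outside the test harness; the manager name, spider names, argument values and the final run() entry point are illustrative assumptions rather than part of the test suite.

from shub_workflow.crawl import GeneratorCrawlManager


class MyCrawlManager(GeneratorCrawlManager):

    name = "mycrawlmanager"   # required for every WorkFlowManager subclass
    spider = "myspider"       # default spider (the tests wrap this in SpiderName)
    default_max_jobs = 2      # maximum number of spider jobs running in parallel
    MAX_RETRIES = 2           # retries allowed when a job ends with a bad outcome

    def set_parameters_gen(self):
        # Plain keys become spider arguments; reserved keys override scheduling
        # options for that particular job.
        yield {"argA": "valA"}
        yield {
            "argA": "valB",
            "spider": "myspidertwo",
            "units": 2,
            "tags": ["CHECKED"],
            "job_settings": {"CONCURRENT_REQUESTS": 2},
        }


if __name__ == "__main__":
    # Assumed entry point: shub-workflow manager scripts are normally
    # instantiated and run directly.
    MyCrawlManager().run()

As the retry tests assert, a job ending with an abnormal outcome such as "cancelled (stalled)" is rescheduled with the same parameters, a RETRIED_FROM tag and an incremented JOBSEQ .rN suffix, up to MAX_RETRIES times, whereas a plain "cancelled" outcome is not retried.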