├── example ├── __init__.py ├── run.sh ├── benchmark │ ├── __init__.py │ └── benchmarks.py ├── index │ ├── __init__.py │ └── anserini.py ├── searcher │ ├── __init__.py │ └── bm25.py ├── collection │ ├── __init__.py │ └── collections.py ├── task │ ├── __init__.py │ ├── rank.py │ └── base.py ├── worker.py └── run.py ├── requirements.txt ├── profane ├── exceptions.py ├── __init__.py ├── constants.py ├── frozendict.py ├── cli.py ├── sql.py ├── config_option.py └── base.py ├── pyproject.toml ├── .gitignore ├── .github └── workflows │ └── pythonpackage.yml ├── tests ├── test_frozendict.py ├── test_module_registry.py ├── test_config_string.py ├── test_pipeline.py ├── test_config_types.py └── test_task_pipeline.py ├── setup.py ├── README.md ├── flexible_pipeline.md └── LICENSE /example/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /example/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | PYTHONPATH=.. 
python run.py $* 4 | -------------------------------------------------------------------------------- /example/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | from profane import import_all_modules 2 | 3 | import_all_modules(__file__, __package__) 4 | -------------------------------------------------------------------------------- /example/index/__init__.py: -------------------------------------------------------------------------------- 1 | from profane import import_all_modules 2 | 3 | import_all_modules(__file__, __package__) 4 | -------------------------------------------------------------------------------- /example/searcher/__init__.py: -------------------------------------------------------------------------------- 1 | from profane import import_all_modules 2 | 3 | import_all_modules(__file__, __package__) 4 | -------------------------------------------------------------------------------- /example/collection/__init__.py: -------------------------------------------------------------------------------- 1 | from profane import import_all_modules 2 | 3 | import_all_modules(__file__, __package__) 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | colorama 2 | docopt 3 | hypothesis 4 | numpy>=1.17 5 | pytest 6 | PyYAML>=5 7 | sqlalchemy 8 | sqlalchemy-utils 9 | -------------------------------------------------------------------------------- /example/task/__init__.py: -------------------------------------------------------------------------------- 1 | from profane import import_all_modules 2 | 3 | from task.base import Task 4 | 5 | import_all_modules(__file__, __package__) 6 | -------------------------------------------------------------------------------- /profane/exceptions.py: -------------------------------------------------------------------------------- 1 | class 
PipelineConstructionError(Exception): 2 | pass 3 | 4 | 5 | class InvalidConfigError(Exception): 6 | pass 7 | 8 | 9 | class InvalidModuleError(Exception): 10 | pass 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 130 3 | target_version = ['py37'] 4 | 5 | [tool.pytest.ini_options] 6 | pythonpath = [ 7 | "." 8 | ] 9 | 10 | [tool.pyright] 11 | include = ["profane"] 12 | exclude = ["**/__pycache__"] 13 | reportMissingImports = true 14 | reportMissingTypeStubs = false -------------------------------------------------------------------------------- /example/collection/collections.py: -------------------------------------------------------------------------------- 1 | from profane import ModuleBase 2 | 3 | 4 | class Collection(ModuleBase): 5 | module_type = "collection" 6 | 7 | 8 | @Collection.register 9 | class Robust04(Collection): 10 | module_name = "robust04" 11 | 12 | 13 | @Collection.register 14 | class MSMARCO(Collection): 15 | module_name = "MSMARCO" 16 | -------------------------------------------------------------------------------- /example/benchmark/benchmarks.py: -------------------------------------------------------------------------------- 1 | from profane import ModuleBase, Dependency, ConfigOption 2 | 3 | 4 | class Benchmark(ModuleBase): 5 | module_type = "benchmark" 6 | 7 | 8 | @Benchmark.register 9 | class WsdmBenchmark(Benchmark): 10 | module_name = "wsdm20demo" 11 | 12 | dependencies = [Dependency(key="collection", module="collection", name="robust04")] 13 | -------------------------------------------------------------------------------- /example/index/anserini.py: -------------------------------------------------------------------------------- 1 | from profane import ModuleBase, Dependency, ConfigOption 2 | 3 | 4 | class Index(ModuleBase): 5 | module_type = "index" 6 | 7 | 8 | 
@Index.register 9 | class Anserini(Index): 10 | module_name = "anserini" 11 | dependencies = [Dependency(key="collection", module="collection", name="MSMARCO")] 12 | config_spec = [ConfigOption(key="stemmer", default_value="porter", description="stemmer")] 13 | -------------------------------------------------------------------------------- /example/searcher/bm25.py: -------------------------------------------------------------------------------- 1 | from profane import ModuleBase, Dependency, ConfigOption 2 | 3 | 4 | class Searcher(ModuleBase): 5 | module_type = "searcher" 6 | 7 | 8 | @Searcher.register 9 | class BM25(Searcher): 10 | module_name = "BM25" 11 | dependencies = [Dependency(key="index", module="index", name="anserini")] 12 | config_spec = [ 13 | ConfigOption(key="b", default_value="0.8", description="b param", value_type="floatlist"), 14 | ConfigOption("z", default_value=1, value_type="intlist"), 15 | ] 16 | requires_random_seed = True 17 | -------------------------------------------------------------------------------- /profane/__init__.py: -------------------------------------------------------------------------------- 1 | import profane.base 2 | from profane.cli import config_list_to_dict 3 | from profane.config_option import ConfigOption 4 | from profane.exceptions import PipelineConstructionError, InvalidConfigError, InvalidModuleError 5 | from profane.frozendict import FrozenDict 6 | from profane.sql import DBManager 7 | 8 | __version__ = "0.2.4" 9 | 10 | constants = profane.base.constants 11 | module_registry = profane.base.module_registry 12 | import_all_modules = profane.base.import_all_modules 13 | Dependency = profane.base.Dependency 14 | ModuleBase = profane.base.ModuleBase 15 | -------------------------------------------------------------------------------- /example/task/rank.py: -------------------------------------------------------------------------------- 1 | from profane import Dependency 2 | 3 | from task import Task 4 | 5 | 6 | 
@Task.register 7 | class Rank(Task): 8 | module_name = "rank" 9 | dependencies = [ 10 | Dependency(key="benchmark", module="benchmark", name="wsdm20demo", provide_this=True, provide_children=["collection"]), 11 | Dependency(key="searcher", module="searcher", name="BM25"), 12 | ] 13 | commands = ["run"] + Task.help_commands 14 | default_command = "run" 15 | 16 | def run(self): 17 | print("in rank.run") 18 | print("benchmark:", self.benchmark) 19 | print("searcher:", self.searcher) 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # editor tmp files 7 | *~ 8 | 9 | # DB credentials 10 | .sql_url 11 | 12 | # mypy 13 | .mypy_cache/ 14 | 15 | # notebooks 16 | .ipynb_checkpoints/ 17 | .idea/ 18 | # scripts with credentials 19 | queue.sh 20 | queue_cpu.sh 21 | 22 | .DS_Store 23 | 24 | .pytest_cache/ 25 | 26 | # ide files 27 | /.vscode 28 | /*.code-workspace 29 | 30 | # default anserini, cache, and result directories 31 | /anserini 32 | /cache 33 | /results 34 | 35 | # testing 36 | .coverage 37 | .hypothesis/ 38 | 39 | # tmp 40 | build/ 41 | dist/ 42 | profane.egg-info/ 43 | -------------------------------------------------------------------------------- /example/task/base.py: -------------------------------------------------------------------------------- 1 | from profane import ModuleBase 2 | 3 | 4 | class Task(ModuleBase): 5 | module_type = "task" 6 | commands = [] 7 | help_commands = ["describe", "print_config", "print_paths", "print_pipeline"] 8 | default_command = "describe" 9 | requires_random_seed = True 10 | 11 | def print_config(self): 12 | print("Configuration:") 13 | self.print_module_config(prefix=" ") 14 | 15 | def print_paths(self): # TODO 16 | pass 17 | 18 | def print_pipeline(self): 19 | print(f"Module graph:") 20 | 
self.print_module_graph(prefix=" ") 21 | 22 | def describe(self): 23 | self.print_pipeline() 24 | print("\n") 25 | self.print_config() 26 | -------------------------------------------------------------------------------- /profane/constants.py: -------------------------------------------------------------------------------- 1 | class ConstantsRegistry: 2 | """Write-once registry that keeps track of constants shared by modules. 3 | ConstantsRegistry behaves like a dict, but keys can only be assigned to once. 4 | """ 5 | 6 | def __init__(self): 7 | self.reset() 8 | 9 | def reset(self): 10 | self._d = {} 11 | 12 | def __getitem__(self, key): 13 | return self._d[key] 14 | 15 | def __setitem__(self, key, val): 16 | if key in self._d and self._d[key] != val: 17 | raise TypeError( 18 | f"ConstantsRegistry does not support re-assignment of existing entries; already contains: {key}={self._d[key]}" 19 | ) 20 | else: 21 | self._d[key] = val 22 | 23 | def __repr__(self): 24 | return repr(self._d) 25 | 26 | def __len__(self): 27 | return len(self._d) 28 | 29 | def __contains__(self, item): 30 | return item in self._d 31 | -------------------------------------------------------------------------------- /.github/workflows/pythonpackage.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | max-parallel: 4 11 | matrix: 12 | python-version: [3.7, 3.8, 3.9, "3.10"] 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: black 21 | uses: lgeiger/black-action@master 22 | with: 23 | args: ". 
--check --config pyproject.toml" 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install -r requirements.txt 28 | - name: Test with pytest 29 | run: | 30 | pip install pytest 31 | export PYTHONPATH=${PYTHONPATH}:/home/runner/work/profane/profane/ 32 | pytest -vvv 33 | 34 | timeout-minutes: 20 35 | -------------------------------------------------------------------------------- /tests/test_frozendict.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from profane import FrozenDict 4 | 5 | 6 | def test_frozen(): 7 | d = FrozenDict({1: 2, "tuple": [9, 8], "inner": {"a": "b", "more": {"c": "d"}}}) 8 | 9 | with pytest.raises(TypeError): 10 | d[1] = 3 11 | 12 | with pytest.raises(TypeError): 13 | d["inner"]["a"] = "c" 14 | 15 | with pytest.raises(TypeError): 16 | d["inner"]["more"] = "e" 17 | 18 | with pytest.raises(TypeError): 19 | d["inner"]["more"]["c"] = "e" 20 | 21 | assert isinstance(d["tuple"], tuple) 22 | assert d["tuple"] == (9, 8) 23 | 24 | 25 | def test_copy(): 26 | d = FrozenDict({1: 2, 3: 4, "inner": {5: 6, "more": {7: 8, 9: 10}}}) 27 | unfrozen = d.unfrozen_copy() 28 | unfrozen[1] = 11 29 | unfrozen["inner"][5] = 7 30 | unfrozen["inner"]["more"][9] = 12 31 | 32 | modified = FrozenDict({1: 11, 3: 4, "inner": {5: 7, "more": {7: 8, 9: 12}}}) 33 | assert FrozenDict(unfrozen) == modified 34 | -------------------------------------------------------------------------------- /tests/test_module_registry.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from profane.base import ModuleBase, InvalidModuleError, ModuleRegistry 4 | 5 | 6 | def test_module_registry(): 7 | registry = ModuleRegistry() 8 | 9 | with pytest.raises(ValueError): 10 | registry.lookup("missing_type", "missing_name") 11 | 12 | class VeryIncompleteModule(ModuleBase): 13 | pass 14 | 15 | class HalfIncompleteModule(ModuleBase): 16 | 
module_type = "incomplete" 17 | 18 | with pytest.raises(InvalidModuleError): 19 | registry.register(VeryIncompleteModule) 20 | 21 | with pytest.raises(InvalidModuleError): 22 | registry.register(HalfIncompleteModule) 23 | 24 | class MinimalRegisterableModule(ModuleBase): 25 | module_type = "minimal" 26 | module_name = "MRM" 27 | 28 | registry.register(MinimalRegisterableModule) 29 | 30 | assert registry.lookup("minimal", "MRM") == MinimalRegisterableModule 31 | 32 | with pytest.raises(ValueError): 33 | registry.lookup("minimal", "missing") 34 | 35 | with pytest.raises(ValueError): 36 | registry.lookup("missing", "MRM") 37 | 38 | class WrongDependenciesTypeModule: 39 | module_type = "wrongdependencies" 40 | module_name = "WDTM" 41 | 42 | dependencies = {} 43 | 44 | with pytest.raises(TypeError): 45 | registry.register(WrongDependenciesTypeModule) 46 | -------------------------------------------------------------------------------- /example/worker.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import random 4 | import sys 5 | import time 6 | import traceback 7 | import sqlalchemy 8 | 9 | import sql 10 | from run import prepare_task 11 | 12 | db = sql.DBManager(os.environ.get("EXAMPLE_DB")) 13 | 14 | 15 | def try_run(run): 16 | if run.status not in ["QUEUED", "FAILED"]: 17 | return 18 | 19 | db.started_event(run) 20 | 21 | try: 22 | task, func = prepare_task(run.command, run.config) 23 | func() 24 | db.completed_event(run) 25 | print("run finished") 26 | return True 27 | except (Exception, KeyboardInterrupt) as e: 28 | db.failed_event(run) 29 | 30 | print("\nERROR: failed run for id: %s" % run) 31 | print("exception {0} with arguments:\n{1!r}".format(type(e).__name__, e.args)) 32 | print(traceback.format_exc()) 33 | 34 | return False 35 | 36 | 37 | print("%s checking for work" % datetime.datetime.now()) 38 | 39 | try: 40 | db.clear_zombie_runs() 41 | 42 | run = db.get_eligible_run(max_tries=3) 
43 | if run: 44 | try_run(run) 45 | 46 | print("%s done" % datetime.datetime.now()) 47 | except (sqlalchemy.exc.InvalidRequestError, sqlalchemy.exc.OperationalError) as e: 48 | if ("%s" % e).find("deadlock detected") != -1: 49 | print("got exception: %s\n" % e) 50 | print("%s deadlock detected; sleeping and exiting" % datetime.datetime.now()) 51 | time.sleep(random.randint(60, 450)) 52 | sys.exit(0) 53 | else: 54 | raise 55 | -------------------------------------------------------------------------------- /profane/frozendict.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from copy import deepcopy 3 | 4 | 5 | class FrozenDict(collections.abc.Mapping): 6 | """Based on frozen dict implementation from https://stackoverflow.com/a/2704866 by Mike Graham""" 7 | 8 | def __init__(self, *args, **kwargs): 9 | self._d = dict(*args, **kwargs) 10 | _freeze_dicts(self._d) 11 | 12 | self._hash = None 13 | 14 | def __iter__(self): 15 | return iter(self._d) 16 | 17 | def __len__(self): 18 | return len(self._d) 19 | 20 | def __str__(self): 21 | return self._d.__str__() 22 | 23 | def __repr__(self): 24 | return self._d.__repr__() 25 | 26 | def __getitem__(self, key): 27 | return self._d[key] 28 | 29 | def __eq__(self, other): 30 | if isinstance(other, dict): 31 | other = FrozenDict(other) 32 | 33 | if not isinstance(other, FrozenDict): 34 | return False 35 | 36 | return self._d == other._d 37 | 38 | def __hash__(self): 39 | if self._hash is None: 40 | self._hash = hash(frozenset(self._d.items())) 41 | return self._hash 42 | 43 | def _as_dict(self): 44 | unfrozen = deepcopy(self._d) 45 | 46 | for k in list(unfrozen.keys()): 47 | if isinstance(unfrozen[k], FrozenDict): 48 | unfrozen[k] = unfrozen[k]._as_dict() 49 | 50 | return unfrozen 51 | 52 | def unfrozen_copy(self): 53 | return self._as_dict() 54 | 55 | 56 | def _freeze_dicts(d): 57 | for k in list(d.keys()): 58 | if isinstance(d[k], dict): 59 | d[k] = FrozenDict(d[k]) 
60 | elif isinstance(d[k], list): 61 | d[k] = tuple(d[k]) 62 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import setuptools 3 | from setuptools.command.develop import develop 4 | from setuptools.command.install import install 5 | 6 | 7 | class PostDevelopCommand(develop): 8 | """Post-installation for development mode.""" 9 | 10 | def run(self): 11 | develop.run(self) 12 | 13 | 14 | class PostInstallCommand(install): 15 | """Post-installation for installation mode.""" 16 | 17 | def run(self): 18 | install.run(self) 19 | 20 | 21 | with open("README.md", "r") as fh: 22 | long_description = fh.read() 23 | 24 | # from https://packaging.python.org/guides/single-sourcing-package-version/ 25 | def read(rel_path): 26 | here = os.path.abspath(os.path.dirname(__file__)) 27 | with open(os.path.join(here, rel_path), "rt") as fp: 28 | return fp.read() 29 | 30 | 31 | def get_version(rel_path): 32 | for line in read(rel_path).splitlines(): 33 | if line.startswith("__version__"): 34 | delim = '"' if '"' in line else "'" 35 | return line.split(delim)[1] 36 | raise RuntimeError("Unable to find version string.") 37 | 38 | 39 | setuptools.setup( 40 | name="profane", 41 | version=get_version("profane/__init__.py"), 42 | author="Andrew Yates", 43 | author_email="", 44 | description="A library for creating complex experimental pipelines", 45 | long_description=long_description, 46 | long_description_content_type="text/markdown", 47 | url="https://github.com/andrewyates/profane", 48 | packages=setuptools.find_packages(), 49 | install_requires=["colorama", "docopt", "numpy>=1.17", "PyYAML>=5", "sqlalchemy", "sqlalchemy-utils"], 50 | classifiers=["Programming Language :: Python :: 3", "Operating System :: OS Independent"], 51 | python_requires=">=3.6", 52 | cmdclass={"develop": PostDevelopCommand, "install": PostInstallCommand}, 53 | 
include_package_data=True, 54 | ) 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 2 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black) 3 | [![Workflow](https://github.com/andrewyates/profane/workflows/pytest/badge.svg)](https://github.com/andrewyates/profane/actions) 4 | [![PyPI version fury.io](https://badge.fury.io/py/profane.svg)](https://pypi.python.org/pypi/profane/) 5 | 6 | 7 | # Overview 8 | *Profane* is a library for creating complex experimental pipelines. Profane pipelines are based on two key ideas: 9 | 1. An experiment is a *function of its configuration*. In other words, an experiment should be deterministic given a set of experimental parameters (random seed, specific algorithms to run, etc). 10 | 2. An experiment is described as a *DAG* representing modules' (nodes') dependencies in which the *state of a node is independent of its parent's state*. That is, a *node's operation is a function of its configuration and the configurations of its children*. This means that a node may not modify the configuration (or state) of its children (or descendants). 11 | 12 | These allow for the construction of a flexible pipeline with automatic caching. Each node's configuration can be modified to change experimental parameters, and a node's output can be safely cached in a path derived from its configuration and the configurations of its children. These nodes are called modules. 13 | 14 | This library is heavily inspired by the excellent [sacred](https://sacred.readthedocs.io/en/stable/) library. 
Among other differences, profane imposes a specific structure on the pipeline and leverages this to allow profane modules to be dynamically configured (which would be similar to dynamic sacred ingredients). Profane was developed based on experiences using sacred with a heavily modified pipeline initialization step. 15 | 16 | ## Example 17 | The `example/` directory contains a module graph similar to that used in Capreolus. Run it with the `run.sh` script. 18 | -------------------------------------------------------------------------------- /tests/test_config_string.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import yaml 4 | 5 | from profane.cli import config_list_to_dict 6 | 7 | 8 | def test_config_string_to_dict(): 9 | args = ["foo.bar=yes", "main=42"] 10 | assert config_list_to_dict(args) == {"foo": {"bar": "yes"}, "main": "42"} 11 | 12 | args = ["foo.bar=yes", "main=42", "foo.bar=override"] 13 | assert config_list_to_dict(args) == {"foo": {"bar": "override"}, "main": "42"} 14 | 15 | with pytest.raises(ValueError): 16 | args = ["invalid"] 17 | config_list_to_dict(args) 18 | 19 | with pytest.raises(ValueError): 20 | args = ["invalid="] 21 | config_list_to_dict(args) 22 | 23 | with pytest.raises(ValueError): 24 | args = ["invalid.=1"] 25 | config_list_to_dict(args) 26 | 27 | with pytest.raises(ValueError): 28 | args = [".invalid=1"] 29 | config_list_to_dict(args) 30 | 31 | 32 | def test_config_string_with_files_to_dict(tmpdir): 33 | mainfn = os.path.join(tmpdir, "main.txt") 34 | with open(mainfn, "wt") as f: 35 | print("main=24 # comment", file=f) 36 | print("#main=25", file=f) 37 | 38 | foofn = os.path.join(tmpdir, "foo.txt") 39 | with open(foofn, "wt") as f: 40 | print("test1=20 submod1.test1=21 ", file=f) 41 | print("submod1.submod2.test1=22", file=f) 42 | print("test3=extra", file=f) 43 | print(f"FILE={mainfn}", file=f) 44 | 45 | args = ["foo.test1=1", f"foo.file={foofn}", "main=42", 
f"file={mainfn}"] 46 | assert config_list_to_dict(args) == { 47 | "foo": {"test1": "20", "test3": "extra", "main": "24", "submod1": {"test1": "21", "submod2": {"test1": "22"}}}, 48 | "main": "24", 49 | } 50 | 51 | 52 | def test_config_string_with_yaml_files_to_dict(tmpdir): 53 | mainfn = os.path.join(tmpdir, "main.yaml") 54 | 55 | main_data = dict( 56 | main=24, 57 | ) 58 | with open(mainfn, "wt") as f: 59 | yaml.dump(main_data, f, default_flow_style=False) 60 | 61 | foo_data = dict(test1=20, submod1=dict(test1=21, submod2=dict(test1=22)), test3="extra", FILE=mainfn) 62 | 63 | foofn = os.path.join(tmpdir, "foo.yaml") 64 | with open(foofn, "wt") as f: 65 | yaml.dump(foo_data, f, default_flow_style=False) 66 | 67 | args = ["foo.test1=1", f"foo.file={foofn}", "main=42", f"file={mainfn}"] 68 | assert config_list_to_dict(args) == { 69 | "foo": {"test1": "20", "test3": "extra", "main": "24", "submod1": {"test1": "21", "submod2": {"test1": "22"}}}, 70 | "main": "24", 71 | } 72 | -------------------------------------------------------------------------------- /profane/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shlex import shlex 3 | import yaml 4 | import collections 5 | 6 | 7 | def config_string_to_dict(s): 8 | s = " ".join(s.split()) # remove consecutive whitespace 9 | return config_list_to_dict(s.split()) 10 | 11 | 12 | def config_list_to_dict(l): 13 | d = {} 14 | 15 | for k, v in _config_list_to_pairs(l): 16 | _dot_to_dict(d, k, v) 17 | 18 | return d 19 | 20 | 21 | def _recursive_update(ori_dict, new_dict): 22 | for k in new_dict: 23 | if k not in ori_dict: 24 | ori_dict[k] = new_dict[k] 25 | elif not isinstance(ori_dict[k], dict): 26 | ori_dict[k] = new_dict[k] 27 | else: 28 | ori_value, new_value = ori_dict[k], new_dict[k] 29 | ori_dict[k] = _recursive_update(ori_value, new_value) 30 | return ori_dict 31 | 32 | 33 | def _load_yaml(fn): 34 | with open(fn) as f: 35 | config = yaml.safe_load(f) 36 | 
return config 37 | 38 | 39 | def _flatten(d, parent_key="", sep="."): 40 | items = [] 41 | for k, v in d.items(): 42 | new_key = parent_key + sep + k if parent_key else k 43 | if isinstance(v, collections.abc.MutableMapping): 44 | items.extend(_flatten(v, new_key, sep=sep)) 45 | else: 46 | items.append(str(new_key) + "=" + str(v)) 47 | return list(items) 48 | 49 | 50 | def _dot_to_dict(d, k, v, DEL=""): 51 | if k.startswith(".") or k.endswith("."): 52 | raise ValueError(f"invalid path: {k}") 53 | 54 | if "." in k: 55 | path = k.split(".") 56 | current_k = path[0] 57 | remaining_path = ".".join(path[1:]) 58 | 59 | d.setdefault(current_k, {}) 60 | 61 | _dot_to_dict(d[current_k], remaining_path, v, DEL=DEL + " ") 62 | elif k.lower() == "file": 63 | lst = _config_file_to_list(v) 64 | for new_k, new_v in _config_list_to_pairs(lst): 65 | _dot_to_dict(d, new_k, new_v) 66 | else: 67 | d[k] = v 68 | 69 | 70 | def _config_list_to_pairs(l): 71 | pairs = [] 72 | for kv in l: 73 | kv = kv.strip() 74 | 75 | if len(kv) == 0: 76 | continue 77 | 78 | if kv.count("=") != 1: 79 | raise ValueError(f"invalid 'key=value' pair: {kv}") 80 | 81 | k, v = kv.split("=") 82 | if len(v) == 0: 83 | raise ValueError(f"invalid 'key=value' pair: {kv}") 84 | 85 | pairs.append((k, v)) 86 | 87 | return pairs 88 | 89 | 90 | def _config_file_to_list(fn): 91 | lst = [] 92 | ext = os.path.splitext(fn)[1] 93 | if ext == ".yaml": 94 | yaml_list = _flatten(_load_yaml(fn)) 95 | lst.extend(yaml_list) 96 | else: 97 | with open(os.path.expanduser(fn), "rt") as f: 98 | for line in f: 99 | lex = shlex(line) 100 | lex.whitespace = "" 101 | kvs = "".join(list(lex)) 102 | for kv in kvs.strip().split(): 103 | lst.append(kv) 104 | 105 | return lst 106 | -------------------------------------------------------------------------------- /example/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from docopt import docopt 5 | 6 | from profane import 
DBManager, config_list_to_dict, constants 7 | 8 | # specify a base package that we should look for modules under (e.g., .task) 9 | # constants must be specified before importing Task (or any other modules!) 10 | constants["BASE_PACKAGE"] = "example" 11 | 12 | from task import Task 13 | 14 | 15 | def parse_task_string(s): 16 | fields = s.split(".") 17 | task = fields[0] 18 | task_cls = Task.lookup(task) 19 | 20 | if len(fields) == 2: 21 | cmd = fields[1] 22 | else: 23 | cmd = task_cls.default_command 24 | 25 | if not hasattr(task_cls, cmd): 26 | print("error: invalid command:", s) 27 | print(f"valid commands for task={task}: {sorted(task_cls.commands)}") 28 | sys.exit(2) 29 | 30 | return task, cmd 31 | 32 | 33 | def prepare_task(fullcommand, config): 34 | taskstr, commandstr = parse_task_string(fullcommand) 35 | task = Task.create(taskstr, config) 36 | task_entry_function = getattr(task, commandstr) 37 | return task, task_entry_function 38 | 39 | 40 | if __name__ == "__main__": 41 | help = """ 42 | Usage: 43 | run.py COMMAND [(with CONFIG...)] [options] 44 | run.py help [COMMAND] 45 | run.py (-h | --help) 46 | 47 | 48 | Options: 49 | -h --help Print this help message and exit. 50 | -l VALUE --loglevel=VALUE Set the log level: DEBUG, INFO, WARNING, ERROR, or CRITICAL. 51 | -p VALUE --priority=VALUE Sets the priority for a queued up experiment. No effect without -q flag. 52 | -q --queue Only queue this run, do not start it. 53 | 54 | 55 | Arguments: 56 | COMMAND Name of command to run (see below for list of commands) 57 | CONFIG Configuration assignments of the form foo.bar=17 58 | 59 | 60 | Commands: 61 | rank.run ...description here... 62 | rank.describe ...description here... 
63 | """ 64 | 65 | # hack to make docopt print full help message if no arguments are give 66 | if len(sys.argv) == 1: 67 | sys.argv.append("-h") 68 | 69 | arguments = docopt(help, version="example") 70 | 71 | if arguments["--loglevel"]: 72 | loglevel = arguments["--loglevel"].upper() 73 | valid_loglevels = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL") 74 | 75 | if loglevel not in valid_loglevels: 76 | print("error: log level must be one of:", ", ".join(valid_loglevels)) 77 | sys.exit(1) 78 | 79 | os.environ["EXAMPLE_LOGGING"] = loglevel 80 | 81 | # prepare task even if we're queueing, so that we validate the config 82 | config = config_list_to_dict(arguments["CONFIG"]) 83 | task, task_entry_function = prepare_task(arguments["COMMAND"], config) 84 | 85 | if arguments["--queue"]: 86 | if not arguments["--priority"]: 87 | arguments["--priority"] = 0 88 | 89 | db = DBManager(os.environ.get("EXAMPLE_DB")) 90 | db.queue_run(command=arguments["COMMAND"], config=config, priority=arguments["--priority"]) 91 | else: 92 | print(f"starting {arguments['COMMAND']} with config: {task.config}\n") 93 | task_entry_function() 94 | -------------------------------------------------------------------------------- /flexible_pipeline.md: -------------------------------------------------------------------------------- 1 | `flexible_pipeline` was a Capreolus branch that heavily modified sacred's initialization in order to support profane-style pipelines. This document briefly describes how `profane` differs from the `flexible_pipeline` approach. 2 | 3 | ## Changes from `flexible_pipeline` 4 | - `Benchmark` now depends on `Collection`, so that the collection does not need to be specified by the user separately. 5 | - `Dependency` declarations now have a `key` in addition to the `module_name`. This allows multiple dependencies on the same module (but with different config options for each). For example, an independently-configured `searcher1` and `searcher2`. 
6 | - Previously, each time a dependency was declared the object was added to `provide`. (When a module requests a dependency available in `provide`, the *object* in `provide` is used directly rather than instantiating a new object.) Dependencies are no longer provided by default. When `Dependency.provide_this=True`, the dependency will be provided to all of the declaring module's descendants. For example, say we have a `Rank` task that declares a `Benchmark` dependency with `provide_this=True` followed by a `Searcher` dependency. The `Benchmark` will be instantiated first and then provided to the `Searcher` class. However, the `Benchmark` would not be provided to any of the `Rank` task's parents. 7 | - Similarly, a new `Dependency.provide_children` list allows a `Dependency` to provide some of its children to the declaring module's descendants. For example, in the previous example say the `Benchmark` class also declared `provide_children=["collection"]`. The benchmark's `Collection` object would then be passed to the `Rank` task's `Searcher` dependency. 8 | - `Task` is now a true module and can be used as a dependency like any other module. Coupled with the `provide` changes, this allows tasks to be arbitrarily combined (e.g., a `WeirdRerankTask` that depends on two `RerankTask` modules that each run on separate benchmarks). 9 | - Modules' `__init__` method now fully instantiates the module; previously, `instantiate_from_config` was required to instantiate dependencies. The `provide` argument can be used to specify existing dependency objects rather than creating them. For example, `index=Anserini(config=...); BM25(config=..., provide={"index": index})`. 10 | - In place of the `config` class method, modules declare their config by providing a `config_spec` class attribute containing `ConfigOption` objects. 11 | - `registry.all_known_modules` has been replaced with a `ModuleRegistry` class, which is instantiated at `base.module_registry`. 
12 | - Modules can also be instantiated using `create` method of the module's base class (e.g., `Reranker` or `Benchmark`). By default, modules instantiated with `create` are cached based on their configs, so that identical module objects are re-used. 13 | 14 | ## DB Queue and Worker 15 | I've revived the run queuing mechanism from before WSDM. To make this work, `EXAMPLE_DB` needs to be a URL pointing to a valid Postgres DB. e.g., `EXAMPLE_DB="postgresql+psycopg2://:@/"`. 16 | - To queue runs, `run.py` accepts a `-q` option that will cause the specified run to be queued rather than run immediately. e.g., `python run.py -q rank.run with searcher.b=0.123`. 17 | - To launch one of these queued runs, run `worker.py` with no arguments. This script will 1) clear any zombie runs on the current host (i.e., runs marked as running that have non-existent PIDs), and then 2) launch any QUEUED/FAILED run that has failed less than three times. 18 | - To continuously launch available runs, `worker.py` can be put in a shell script or queued with slurm. Placing the loop inside Python isn't great, because past experiences revealed a lot of memory errors with this. Looping in Python until a new run is found would be okay though. 
import datetime
import os
import socket

from contextlib import contextmanager

import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import database_exists, create_database

Base = declarative_base()


class Run(Base):
    """ORM model describing one queued/running/finished run in the run queue."""

    __tablename__ = "run"

    run_id = sa.Column(sa.Integer, primary_key=True)
    # command to execute (e.g., "rank.run") and its config options (stored as JSON)
    command = sa.Column(sa.String)
    config = sa.Column(sa.JSON)

    # host and PID executing the run; used by clear_zombie_runs to detect dead runs
    hostname = sa.Column(sa.String)
    pid = sa.Column(sa.Integer)

    status = sa.Column(sa.Enum("RUNNING", "COMPLETED", "INTERRUPTED", "FAILED", "QUEUED", name="statuses"))
    priority = sa.Column(sa.Integer)
    # number of times the run has been started; get_eligible_run skips runs with too many tries
    tries = sa.Column(sa.Integer, default=0)

    start_time = sa.Column(sa.DateTime(timezone=True))
    stop_time = sa.Column(sa.DateTime(timezone=True))
    queue_time = sa.Column(sa.DateTime(timezone=True))

    # composite index covering the worker's eligibility query (status, priority, tries)
    idx1 = sa.Index("idx_status_priority_tries", status, priority, tries)


class DBManager:
    """Manages the DB-backed run queue: queuing runs, claiming them, and recording outcomes.

    Creates the database and schema on first use if they do not already exist.
    """

    def __init__(self, url):
        """Connect to the DB at `url` (a SQLAlchemy URL), creating the DB/tables if missing."""
        # pool_pre_ping guards long-lived workers against stale pooled connections
        engine = sa.create_engine(url, pool_pre_ping=True)
        if not database_exists(engine.url):
            print("creating missing DB")
            create_database(engine.url)

        Base.metadata.create_all(engine)

        self.sessionmaker = sessionmaker(bind=engine)

    def queue_run(self, command, config, priority=0):
        """Insert a QUEUED run for `command` with `config` and return its run_id."""
        run = Run(
            config=config,
            command=command,
            priority=priority,
            status="QUEUED",
            queue_time=datetime.datetime.now(datetime.timezone.utc),
        )

        with self.session_scope() as session:
            session.add(run)

        # run_id is assigned by the DB on commit; readable afterwards because
        # session_scope intentionally does not close the session (see below)
        return run.run_id

    def clear_zombie_runs(self):
        """Mark RUNNING runs on this host whose PIDs no longer exist as FAILED."""
        # TODO first find runs, then do for_update later when clearing them only
        with self.session_scope() as session:
            for run in (
                session.query(Run)
                .filter(sa.and_(Run.status == "RUNNING", Run.hostname == socket.gethostname()))
                .with_for_update()
            ):
                # NOTE(review): /proc lookup assumes a Linux host -- confirm workers never run elsewhere
                if not os.path.exists(f"/proc/{run.pid}"):
                    print(f"found zombie run_id={run.run_id} with pid: {run.pid}")
                    run.status = "FAILED"
                    session.add(run)

    def get_eligible_run(self, max_tries=3):
        """Return one QUEUED or FAILED run with fewer than `max_tries` tries, or None.

        Highest priority first; ties are broken randomly. The row is selected
        FOR UPDATE so concurrent workers do not claim the same run.
        """
        with self.session_scope() as session:
            run = (
                session.query(Run)
                .filter(sa.or_(Run.status == "QUEUED", Run.status == "FAILED"))
                .filter(Run.tries < max_tries)
                # NOTE(review): random() is a DB-side SQL function; presumably requires Postgres/SQLite
                .order_by(Run.priority.desc(), sa.text("random()"))
                .limit(1)
                .with_for_update()
                .first()
            )

            return run

    def started_event(self, run):
        """Record that `run` has started on this host/PID and increment its try count."""
        with self.session_scope() as session:
            run = session.query(Run).filter(Run.run_id == run.run_id).with_for_update().one()
            run.start_time = datetime.datetime.now(datetime.timezone.utc)
            run.hostname = socket.gethostname()
            run.pid = os.getpid()
            run.status = "RUNNING"
            run.tries += 1

            session.add(run)

    def _ended_event(self, run, status):
        """Record that `run` stopped with the given terminal `status`."""
        with self.session_scope() as session:
            run = session.query(Run).filter(Run.run_id == run.run_id).with_for_update().one()
            run.stop_time = datetime.datetime.now(datetime.timezone.utc)
            run.status = status

            session.add(run)

    def completed_event(self, run):
        """Mark `run` as COMPLETED."""
        return self._ended_event(run, "COMPLETED")

    def interrupted_event(self, run):
        """Mark `run` as INTERRUPTED."""
        return self._ended_event(run, "INTERRUPTED")

    def failed_event(self, run):
        """Mark `run` as FAILED."""
        return self._ended_event(run, "FAILED")

    # context manager from SA docs
    # https://docs.sqlalchemy.org/en/13/orm/session_basics.html
    @contextmanager
    def session_scope(self):
        """Provide a transactional scope around a series of operations."""
        session = self.sessionmaker()
        try:
            yield session
            session.commit()
        except:
            # rollback on any error (including KeyboardInterrupt), then re-raise
            session.rollback()
            raise

        # unlike the example, we don't call session.close() since this invalidates Run objects (eg in worker.py)
14 | 15 | Args: 16 | key (str): a name for the config option 17 | default_value (str): the default value the config option should take 18 | description (str): a description to be shown in help messages 19 | value_type: either built-in type bool, int, float, or str; or "intlist", "floatlist", "strlist" 20 | """ 21 | 22 | def __init__(self, key, default_value, description="", value_type=None): 23 | self.key = key 24 | self.default_value = default_value 25 | self.description = description 26 | 27 | if value_type == "strlist": 28 | self.string_representation = partial(convert_list_to_string, item_type=str) 29 | elif value_type == "intlist": 30 | self.string_representation = partial(convert_list_to_string, item_type=int) 31 | elif value_type == "floatlist": 32 | self.string_representation = partial(convert_list_to_string, item_type=float) 33 | else: 34 | self.string_representation = str 35 | 36 | if value_type is None: 37 | value_type = type(self.default_value) 38 | 39 | if value_type == bool: 40 | self.type = lambda x: str(x).lower() == "true" 41 | elif value_type in [str, type(None)]: 42 | self.type = lambda x: None if str(x).lower() == "none" else str(x) 43 | elif value_type == "strlist": 44 | self.type = partial(convert_string_to_list, item_type=str) 45 | elif value_type == "intlist": 46 | self.type = partial(convert_string_to_list, item_type=int) 47 | elif value_type == "floatlist": 48 | self.type = partial(convert_string_to_list, item_type=float) 49 | elif value_type in [list, tuple]: 50 | raise InvalidModuleError( 51 | "ConfigOptions with a default_value of list must set value_type to one of: 'strlist', 'intlist', 'floatlist'" 52 | ) 53 | else: 54 | self.type = value_type 55 | 56 | 57 | def convert_string_to_list(values, item_type): 58 | """Convert a comma-seperated string '1,2,3' to a list of item_type elements.""" 59 | 60 | if isinstance(values, str): 61 | as_range = _parse_string_as_range(values, item_type) 62 | if as_range: 63 | return tuple(as_range) 64 | 65 
| values = values.split(",") 66 | elif isinstance(values, (tuple, list)): 67 | pass 68 | else: 69 | values = [values] 70 | 71 | return tuple(item_type(item) for item in values) 72 | 73 | 74 | def _parse_string_as_range(s, item_type): 75 | parts = s.split(",") 76 | if len(parts) != 2: 77 | return None 78 | 79 | ends = parts[0].split("..") 80 | if len(ends) != 2: 81 | return None 82 | 83 | start, stop = ends 84 | start, stop = item_type(start), item_type(stop) 85 | step = item_type(parts[1]) 86 | 87 | if stop <= start: 88 | raise ValueError(f"invalid range: {s}") 89 | 90 | if item_type == int: 91 | return list(range(start, stop + step, step)) 92 | elif item_type == float: 93 | precision = max(_rounding_precision(x) for x in (start, stop, step)) 94 | lst = [round(item, precision) for item in np.arange(start, stop + step, step)] 95 | if lst[-1] > stop: 96 | del lst[-1] 97 | return lst 98 | 99 | raise ValueError(f"unsupported type: {item_type}") 100 | 101 | 102 | def convert_list_to_string(lst, item_type): 103 | """Convert a list to a string. 104 | Try to represent it as a range if the list has more than two elements and item_type is float or int. 105 | [1,2] -> "1,2" 106 | [1,2,3,4] -> "1..5,1" 107 | """ 108 | 109 | lst = [item_type(x) for x in lst] 110 | 111 | # check whether we can represent lst as "start..stop,step" 112 | if len(lst) > 2 and item_type in (float, int): 113 | # for floating point lists, determine the number of significant digits to keep based on the user's input 114 | # e.g., 1.01 --> 2 or 3e-05 --> 5; this is necessary to avoid floating point weirdness when adding step 115 | if item_type == int: 116 | precision = 0 117 | else: 118 | precision = max(_rounding_precision(x) for x in lst) 119 | 120 | # is the distance between successive list elements always the same as the distance between the first two elements? 
121 | step = round(lst[1] - lst[0], precision) 122 | is_range = all(lst[idx + 1] == round(lst[idx] + step, precision) for idx in range(len(lst) - 1)) 123 | 124 | if is_range: 125 | start = round(lst[0], precision) 126 | stop = round(lst[-1], precision) 127 | 128 | start, stop, step = _unnecessary_floats_to_ints([start, stop, step]) 129 | return f"{start}..{stop},{step}" 130 | else: 131 | lst = _unnecessary_floats_to_ints(lst) 132 | 133 | return ",".join(str(item) for item in lst) 134 | 135 | 136 | def _rounding_precision(x): 137 | x = str(x) 138 | if len(x.split(".")) == 2: 139 | return len(x.split(".")[1]) 140 | elif len(x.split("e-")) == 2: 141 | return int(x.split("e-")[1]) 142 | 143 | raise ValueError(f"cannot parse: {x}") 144 | 145 | 146 | def _unnecessary_floats_to_ints(lst): 147 | return [int(x) if int(x) == x else x for x in lst] 148 | -------------------------------------------------------------------------------- /tests/test_pipeline.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from profane.base import ModuleBase, PipelineConstructionError, InvalidConfigError, Dependency, module_registry 4 | from profane.config_option import ConfigOption 5 | 6 | 7 | @pytest.fixture 8 | def test_modules(): 9 | module_registry.reset() 10 | 11 | class ModuleTypeA(ModuleBase): 12 | module_type = "Atype" 13 | 14 | @ModuleTypeA.register 15 | class AParent(ModuleTypeA): 16 | module_name = "AParent" 17 | config_spec = [ConfigOption(key="key1", default_value="val1", description="test option")] 18 | dependencies = [ 19 | Dependency(key="myfoo", module="Atype", name="AFoo", default_config_overrides={"changethis": 42}), 20 | Dependency(key="bar", module="Atype", name="ABar"), 21 | ] 22 | 23 | @ModuleTypeA.register 24 | class AFoo(ModuleTypeA): 25 | module_name = "AFoo" 26 | config_spec = [ 27 | ConfigOption(key="foo1", default_value="val1", description="test option"), 28 | ConfigOption(key="changethis", 
def test_module_creation():
    """Modules can be built via the base type's create() or direct init, with defaults applied."""

    class SimpleModuleType(ModuleBase):
        module_type = "SimpleType"

    @SimpleModuleType.register
    class ModuleA(SimpleModuleType):
        module_name = "A"
        config_spec = [ConfigOption(key="key1", default_value="val1", description="test option")]

    # unknown module names are rejected
    with pytest.raises(ValueError):
        SimpleModuleType.create(name="invalid")

    # unspecified options fall back to their declared defaults
    with_defaults = SimpleModuleType.create(name="A", config={})
    assert with_defaults.config["key1"] == "val1"

    # explicitly-passed options override the defaults
    overridden = SimpleModuleType.create(name="A", config={"key1": "val2"})
    assert overridden.config["key1"] == "val2"

    # constructing the class directly is equivalent to create()
    direct = ModuleA(config={"key1": "val2"})
    assert overridden.__class__ == direct.__class__
    assert overridden.config == direct.config
    assert overridden.dependencies == direct.dependencies

    # unknown config keys are rejected
    with pytest.raises(InvalidConfigError):
        SimpleModuleType.create(name="A", config={"invalid": "yes"})
def test_module_creation_with_provided_dependencies(test_modules):
    """Pre-built modules passed via `provide` are used in place of newly created dependencies."""
    ModuleTypeA, AParent = test_modules

    # build an AFoo whose 'myfoobar' dependency is redirected to an 'ABar' instance
    prebuilt_foo = ModuleTypeA.create(name="AFoo", config={"foo1": "provided1", "myfoobar": {"name": "ABar"}})
    # the dependency override took effect: 'ABar' rather than the default 'AFooBar'
    assert prebuilt_foo.myfoobar.module_name == "ABar"

    provided = {"myfoo": prebuilt_foo}
    parent_via_create = ModuleTypeA.create(name="AParent", config={}, provide=provided)

    # the provided object itself was used, and its config flows into the parent's config
    assert parent_via_create.myfoo == prebuilt_foo
    assert parent_via_create.config["myfoo"]["foo1"] == "provided1"

    # constructing AParent directly behaves the same as create()
    parent_via_init = AParent(provide=provided)
    assert parent_via_init.config == parent_via_create.config
    assert parent_via_init.dependencies == parent_via_create.dependencies
def test_types():
    """Config values are cast to each option's declared (or inferred) type at module creation."""
    module_registry.reset()

    class ModuleFoo(ModuleBase):
        module_type = "Atype"
        module_name = "foo"
        # one option per supported type; *2/*3 variants declare value_type explicitly
        # while *1 variants rely on inference from the default value
        config_spec = [
            ConfigOption(key="str1", default_value="foo"),
            ConfigOption(key="str2", default_value=9, value_type=str),
            ConfigOption(key="int1", default_value=2),
            ConfigOption(key="int2", default_value="3", value_type=int),
            ConfigOption(key="float1", default_value=2.2),
            ConfigOption(key="float2", default_value="3.3", value_type=float),
            ConfigOption(key="bool1", default_value=False),
            ConfigOption(key="bool2", default_value="false", value_type=bool),
            ConfigOption(key="bool3", default_value="true", value_type=bool),
            ConfigOption(key="strlist1", default_value=3, value_type="strlist"),
            ConfigOption(key="strlist2", default_value=[4, 5], value_type="strlist"),
            ConfigOption(key="strlist3", default_value="4,5", value_type="strlist"),
            ConfigOption(key="intlist1", default_value=3, value_type="intlist"),
            ConfigOption(key="intlist2", default_value="3", value_type="intlist"),
            ConfigOption(key="intlist3", default_value=(4, 5), value_type="intlist"),
            ConfigOption(key="intlist4", default_value="4,5", value_type="intlist"),
            ConfigOption(key="floatlist1", default_value=3, value_type="floatlist"),
            ConfigOption(key="none-or-str", default_value=None),
        ]

    foo = ModuleFoo()
    # scalar defaults are cast to the declared/inferred types
    assert type(foo.config["str1"]) == str
    assert type(foo.config["str2"]) == str
    assert type(foo.config["int1"]) == int
    assert type(foo.config["int2"]) == int
    assert type(foo.config["float1"]) == float
    assert type(foo.config["float2"]) == float

    # a None default yields None until a value is supplied (see end of test)
    assert type(foo.config["none-or-str"]) == type(None)

    # bools are parsed from the string "true"/"false" (case-insensitive)
    assert foo.config["bool1"] is False
    assert foo.config["bool2"] is False
    assert foo.config["bool3"] is True

    # list options always surface as tuples, regardless of whether the default
    # was a scalar, a comma-separated string, or an actual list/tuple
    assert foo.config["strlist1"] == ("3",)
    assert foo.config["strlist2"] == ("4", "5")
    assert foo.config["strlist3"] == ("4", "5")
    assert foo.config["intlist1"] == (3,)
    assert foo.config["intlist2"] == (3,)
    assert foo.config["intlist3"] == (4, 5)
    assert foo.config["intlist4"] == (4, 5)
    assert foo.config["floatlist1"] == (3.0,)

    # supplying a value for a None-defaulted option casts it to str
    foo = ModuleFoo({"none-or-str": "str"})
    assert type(foo.config["none-or-str"]) == str
    assert foo.config["none-or-str"] == "str"
def test_convert_string_to_list():
    """convert_string_to_list handles typed comma lists, range syntax, and invalid ranges."""

    # comma-separated values are cast item by item and returned as a tuple
    typed_cases = [
        ("1,2", int, (1, 2)),
        ("1", int, (1,)),
        ("1.1,1.2", float, (1.1, 1.2)),
        ("1.1", float, (1.1,)),
        ("1,2", str, ("1", "2")),
        ("1", str, ("1",)),
    ]
    for value, item_type, expected in typed_cases:
        assert convert_string_to_list(value, item_type) == expected

    # "start..stop,step" strings expand into arithmetic ranges
    range_cases = [
        ("1..4,1", int, (1, 2, 3, 4)),
        ("1..4,0.5", float, (1, 1.5, 2, 2.5, 3, 3.5, 4.0)),
        ("0.65..0.8,0.05", float, (0.65, 0.7, 0.75, 0.80)),
        ("0.00001..0.00002,2e-06", float, (1e-05, 1.2e-05, 1.4e-05, 1.6e-05, 1.8e-05, 2.0e-05)),
    ]
    for value, item_type, expected in range_cases:
        assert convert_string_to_list(value, item_type) == expected

    # near-arithmetic comma lists are NOT treated as ranges
    assert convert_string_to_list("1,2,3,4,6", int) == (1, 2, 3, 4, 6)
    assert convert_string_to_list("0,2,3,4,5", int) == (0, 2, 3, 4, 5)

    # ranges are rejected for non-numeric item types ...
    with pytest.raises(ValueError):
        convert_string_to_list("1..4,1", str)

    # ... and for descending endpoints
    with pytest.raises(ValueError):
        convert_string_to_list("3..1,1", int)
@composite
def arithmetic_sequence(draw, dtype):
    """Hypothesis strategy generating an arithmetic sequence of ints or floats.

    Used by the inversion tests to produce lists that convert_list_to_string
    should be able to render in "start..stop,step" range form and recover exactly.

    Args:
        draw: provided by @composite; draws values from the inner strategies
        dtype: "int" or "float"; any other value raises ValueError
    """
    if dtype == "int":
        start = draw(integers(min_value=0, max_value=3))
        end = start + draw(integers(min_value=1, max_value=10))
        step = draw(integers(min_value=1, max_value=3))
    elif dtype == "float":
        # mix coarse (large-step) and fine (small-step) float sequences
        if random.random() < 0.5:
            start = draw(floats(min_value=0.0, max_value=1.0))  # step=0.01
            end = start + draw(floats(min_value=1.0, max_value=3.0))  # , 0.01)
            step = draw(floats(min_value=0.01, max_value=0.51))
        else:
            start = draw(floats(min_value=0.0, max_value=1.0))
            end = start + draw(floats(min_value=0.0001, max_value=0.002))
            step = draw(floats(min_value=0.0001, max_value=0.0005))
    else:
        raise ValueError(f"Unexpected dtype {dtype}")

    # round to 4 decimals so the sequence survives a string round trip
    lst = np.around(np.arange(start, end + step, step), decimals=4).tolist()
    assert len(lst) > 0
    return lst
@given(arithmetic_sequence(dtype="float")) 155 | def test_string_list_inversion_arithmetic_float(lst): 156 | assert tuple(lst) == convert_string_to_list(convert_list_to_string(lst, float), float) 157 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /tests/test_task_pipeline.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | # import constants 4 | from profane.base import ( 5 | ModuleBase, 6 | PipelineConstructionError, 7 | ConfigOption, 8 | InvalidConfigError, 9 | Dependency, 10 | module_registry, 11 | constants, 12 | _DEFAULT_RANDOM_SEED, 13 | ) 14 | 15 | 16 | @pytest.fixture 17 | def rank_modules(): 18 | module_registry.reset() 19 | constants.reset() 20 | 21 | class Task(ModuleBase): 22 | module_type = "task" 23 | requires_random_seed = True 24 | 25 | @Task.register 26 | class ThreeRankTask(Task): 27 | """A strange rank task that runs two searchers on benchmark #1 (via TwoRank) and the third searcher on benchmark #2""" 28 | 29 | module_name = "threerank" 30 | dependencies = [ 31 | Dependency(key="tworank", module="task", name="tworank"), 32 | Dependency(key="rank3", module="task", name="rank"), 33 | ] 34 | 35 | @Task.register 36 | class TwoRankTask(Task): 37 | """A rank task that runs two searchers on the same benchmark""" 38 |
39 | module_name = "tworank" 40 | dependencies = [ 41 | Dependency(key="benchmark", module="benchmark", name="rob04yang", provide_this=True, provide_children=["collection"]), 42 | Dependency(key="rank1a", module="task", name="rank"), 43 | Dependency(key="rank1b", module="task", name="rank"), 44 | ] 45 | 46 | @Task.register 47 | class RankTask(Task): 48 | module_name = "rank" 49 | dependencies = [ 50 | Dependency(key="benchmark", module="benchmark", name="rob04yang", provide_this=True, provide_children=["collection"]), 51 | Dependency(key="searcher", module="searcher", name="bm25"), 52 | ] 53 | 54 | @Task.register 55 | class RerankTask(Task): 56 | module_name = "rerank" 57 | config_spec = [ 58 | ConfigOption("fold", "s1", "fold to run"), 59 | ConfigOption( 60 | "optimize", "map", "metric to maximize on the dev set" 61 | ), # affects train() because we check to save weights 62 | ] 63 | dependencies = [ 64 | Dependency(key="benchmark", module="benchmark", name="rob04yang", provide_this=True, provide_children=["collection"]), 65 | Dependency(key="rank", module="task", name="rank"), 66 | Dependency(key="reranker", module="reranker", name="DRMM"), 67 | ] 68 | 69 | @ModuleBase.register 70 | class BenchmarkRob04(ModuleBase): 71 | module_type = "benchmark" 72 | module_name = "rob04yang" 73 | dependencies = [Dependency(key="collection", module="collection", name="robust04")] 74 | 75 | @ModuleBase.register 76 | class BenchmarkTRECDL(ModuleBase): 77 | module_type = "benchmark" 78 | module_name = "trecdl" 79 | dependencies = [Dependency(key="collection", module="collection", name="msmarco")] 80 | 81 | @ModuleBase.register 82 | class SearcherBM25(ModuleBase): 83 | module_type = "searcher" 84 | module_name = "bm25" 85 | dependencies = [Dependency(key="index", module="index", name="anserini")] 86 | config_spec = [ConfigOption(key="k1", default_value=1.0, description="k1 parameter")] 87 | # Searchers are unlikely to actually need a seed, but we require it for testing 88 | 
requires_random_seed = True 89 | 90 | @ModuleBase.register 91 | class IndexAnserini(ModuleBase): 92 | module_type = "index" 93 | module_name = "anserini" 94 | dependencies = [Dependency(key="collection", module="collection", name="robust04")] 95 | config_spec = [ConfigOption(key="stemmer", default_value="porter", description="stemming")] 96 | 97 | @ModuleBase.register 98 | class CollectionRobust04(ModuleBase): 99 | module_type = "collection" 100 | module_name = "robust04" 101 | 102 | @ModuleBase.register 103 | class CollectionMSMARCO(ModuleBase): 104 | module_type = "collection" 105 | module_name = "msmarco" 106 | 107 | @ModuleBase.register 108 | class ExtractorEmbedtext(ModuleBase): 109 | module_type = "extractor" 110 | module_name = "embedtext" 111 | 112 | dependencies = [ 113 | Dependency(key="index", module="index", name="anserini", default_config_overrides={"stemmer": "none"}), 114 | Dependency(key="tokenizer", module="tokenizer", name="anserini"), 115 | ] 116 | config_spec = [ 117 | ConfigOption("embeddings", "glove6b"), 118 | ConfigOption("zerounk", False), 119 | ConfigOption("calcidf", True), 120 | ConfigOption("maxqlen", 4), 121 | ConfigOption("maxdoclen", 800), 122 | ConfigOption("usecache", False), 123 | ] 124 | 125 | @ModuleBase.register 126 | class TokenizerAnserini(ModuleBase): 127 | module_type = "tokenizer" 128 | module_name = "anserini" 129 | config_spec = [ 130 | ConfigOption("keepstops", True, "keep stopwords if True"), 131 | ConfigOption("stemmer", "none", "stemmer: porter, krovetz, or none"), 132 | ] 133 | 134 | @ModuleBase.register 135 | class TrainerPytorch(ModuleBase): 136 | module_type = "trainer" 137 | module_name = "pytorch" 138 | config_spec = [ 139 | ConfigOption("batch", 32, "batch size"), 140 | ConfigOption("niters", 20), 141 | ConfigOption("itersize", 512), 142 | ConfigOption("gradacc", 1), 143 | ConfigOption("lr", 0.001), 144 | ConfigOption("softmaxloss", False), 145 | ConfigOption("fastforward", False), 146 | 
ConfigOption("validatefreq", 1), 147 | ConfigOption("boardname", "default"), 148 | ] 149 | config_keys_not_in_path = ["fastforward", "boardname"] 150 | 151 | @ModuleBase.register 152 | class RerankerDRMM(ModuleBase): 153 | module_type = "reranker" 154 | module_name = "DRMM" 155 | dependencies = [ 156 | Dependency(key="extractor", module="extractor", name="embedtext"), 157 | Dependency(key="trainer", module="trainer", name="pytorch"), 158 | ] 159 | config_spec = [ 160 | ConfigOption("nbins", 29, "number of bins in matching histogram"), 161 | ConfigOption("nodes", 5, "hidden layer dimension for matching network"), 162 | ConfigOption("histType", "LCH", "histogram type: CH, NH, LCH"), 163 | ConfigOption("gateType", "IDF", "term gate type: TV or IDF"), 164 | ] 165 | 166 | return [ThreeRankTask, TwoRankTask, RankTask, RerankTask] 167 | 168 | 169 | def test_creation_with_simple_provide(rank_modules): 170 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = rank_modules 171 | 172 | # non-default collection should be set in both benchmark's and searcher's dependencies 173 | rank = RankTask({"benchmark": {"collection": {"name": "msmarco"}}}) 174 | assert rank.benchmark.collection.module_name == "msmarco" 175 | assert rank.searcher.index.collection.module_name == "msmarco" 176 | assert rank.benchmark.collection == rank.searcher.index.collection 177 | 178 | 179 | def test_creation_with_complex_provide(rank_modules): 180 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = rank_modules 181 | 182 | # TwoRank task should provide same default benchmark to both Rank tasks 183 | tworank_default = TwoRankTask() 184 | assert tworank_default.rank1a.benchmark == tworank_default.rank1b.benchmark 185 | assert tworank_default.rank1a.benchmark.module_name == "rob04yang" 186 | # and should provide same default collection to rank.searcher.index 187 | assert tworank_default.rank1a.searcher.index.collection == tworank_default.rank1b.searcher.index.collection 188 | assert 
tworank_default.rank1a.searcher.index.collection.module_name == "robust04" 189 | # re-using the config should yield a new object with the same config 190 | assert tworank_default.config == TwoRankTask(tworank_default.config).config 191 | 192 | # TwoRank task should provide same non-default benchmark to both Rank tasks 193 | tworank_trecdl = TwoRankTask({"benchmark": {"name": "trecdl"}}) 194 | assert tworank_trecdl.rank1a.benchmark == tworank_trecdl.rank1b.benchmark 195 | assert tworank_trecdl.rank1a.benchmark.module_name == "trecdl" 196 | # and should provide same non-default collection to rank.searcher.index 197 | assert tworank_trecdl.rank1a.searcher.index.collection == tworank_trecdl.rank1b.searcher.index.collection 198 | assert tworank_trecdl.rank1a.searcher.index.collection.module_name == "msmarco" 199 | # re-using the config should yield a new object with the same config 200 | assert tworank_trecdl.config == TwoRankTask(tworank_trecdl.config).config 201 | 202 | 203 | def test_creation_with_more_complex_provide(rank_modules): 204 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = rank_modules 205 | 206 | # this ThreeRank should provide a TwoRank with one benchmark and a Rank with a second (independent) benchmark 207 | threerank = ThreeRankTask({"tworank": {"benchmark": {"name": "rob04yang"}}, "rank3": {"benchmark": {"name": "trecdl"}}}) 208 | assert threerank.tworank.rank1a.benchmark == threerank.tworank.rank1b.benchmark 209 | assert threerank.tworank.rank1a.searcher.index.collection == threerank.tworank.rank1b.searcher.index.collection 210 | assert threerank.tworank.benchmark.module_name == "rob04yang" 211 | assert threerank.rank3.benchmark.module_name == "trecdl" 212 | assert threerank.tworank.rank1a.searcher.index.collection.module_name == "robust04" 213 | assert threerank.rank3.searcher.index.collection.module_name == "msmarco" 214 | # re-using the config should yield a new object with the same config 215 | assert threerank.config == 
ThreeRankTask(threerank.config).config 216 | 217 | 218 | def test_creation_with_module_object_sharing(rank_modules): 219 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = rank_modules 220 | 221 | tworank_trecdl = TwoRankTask({"benchmark": {"name": "trecdl"}}, share_dependency_objects=True) 222 | # both Rank tasks should be identical and thus pointing to the same object 223 | assert tworank_trecdl.rank1a == tworank_trecdl.rank1b 224 | # however, the TwoRankTask object is not shared because .create() was not used 225 | assert tworank_trecdl != TwoRankTask(tworank_trecdl.config) 226 | 227 | # calling .create() twice returns the same object when the config is the same 228 | assert TwoRankTask.create("tworank", tworank_trecdl.config) == TwoRankTask.create("tworank", tworank_trecdl.config) 229 | # and different objects when the configs are different 230 | assert TwoRankTask.create("tworank", tworank_trecdl.config) != TwoRankTask.create("tworank") 231 | 232 | # change k1 so that Rank and Searcher objects should be different 233 | tworank_k1 = TwoRankTask( 234 | {"rank1a": {"searcher": {"k1": 0.5}}, "rank1b": {"searcher": {"k1": 1.0}}}, share_dependency_objects=True 235 | ) 236 | assert tworank_k1.rank1a.benchmark == tworank_k1.rank1b.benchmark 237 | assert tworank_k1.rank1a.searcher.index == tworank_k1.rank1b.searcher.index 238 | # but Benchmark and Index should be the same objects 239 | assert tworank_k1.rank1a != tworank_k1.rank1b 240 | assert tworank_k1.rank1a.searcher != tworank_k1.rank1b.searcher 241 | # and rank1b should be the same object as used in tworank_default 242 | tworank_default = TwoRankTask(share_dependency_objects=True) 243 | assert tworank_k1.rank1b == tworank_default.rank1b 244 | 245 | # this ThreeRank should use the same benchmark for both its TwoRank and Rank 246 | threerank_same = ThreeRankTask( 247 | {"tworank": {"benchmark": {"name": "trecdl"}}, "rank3": {"benchmark": {"name": "trecdl"}}}, share_dependency_objects=True 248 | ) 249 | assert 
threerank_same.tworank.benchmark == threerank_same.rank3.benchmark 250 | 251 | 252 | def test_module_path(rank_modules): 253 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = rank_modules 254 | 255 | rt = RankTask({"searcher": {"index": {"stemmer": "other"}}}) 256 | assert ( 257 | rt.get_module_path() 258 | == "collection-robust04/benchmark-rob04yang/collection-robust04/index-anserini_stemmer-other/searcher-bm25_k1-1.0_seed-42/task-rank_seed-42" 259 | ) 260 | assert rt.benchmark.get_module_path() == "collection-robust04/benchmark-rob04yang" 261 | assert rt.searcher.get_module_path() == "collection-robust04/index-anserini_stemmer-other/searcher-bm25_k1-1.0_seed-42" 262 | 263 | rrt = RerankTask() 264 | assert ( 265 | rrt.get_module_path() 266 | == "collection-robust04/benchmark-rob04yang/collection-robust04/benchmark-rob04yang/collection-robust04/index-anserini_stemmer-porter/searcher-bm25_k1-1.0_seed-42/task-rank_seed-42/collection-robust04/index-anserini_stemmer-None/tokenizer-anserini_keepstops-True_stemmer-None/extractor-embedtext_calcidf-True_embeddings-glove6b_maxdoclen-800_maxqlen-4_usecache-False_zerounk-False/trainer-pytorch_batch-32_gradacc-1_itersize-512_lr-0.001_niters-20_softmaxloss-False_validatefreq-1/reranker-DRMM_gateType-IDF_histType-LCH_nbins-29_nodes-5/task-rerank_fold-s1_optimize-map_seed-42" 267 | ) 268 | assert rrt.benchmark.get_module_path() == "collection-robust04/benchmark-rob04yang" 269 | assert ( 270 | rrt.rank.get_module_path() 271 | == "collection-robust04/benchmark-rob04yang/collection-robust04/index-anserini_stemmer-porter/searcher-bm25_k1-1.0_seed-42/task-rank_seed-42" 272 | ) 273 | 274 | 275 | def test_config_keys_not_in_module_path(): 276 | @ModuleBase.register 277 | class CollectionSecret(ModuleBase): 278 | module_type = "collection" 279 | module_name = "secretdocs" 280 | config_keys_not_in_path = ["path"] 281 | config_spec = [ 282 | ConfigOption(key="version", default_value="aliens", description="redacted"), 283 | 
ConfigOption(key="path", default_value="nicetry", description="redacted"), 284 | ] 285 | 286 | collection = CollectionSecret({"version": "illuminati"}) 287 | assert collection.get_module_path() == "collection-secretdocs_version-illuminati" 288 | 289 | 290 | def test_config_seed_propagation(rank_modules): 291 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = rank_modules 292 | 293 | rt = RankTask({"seed": 123, "searcher": {"index": {"stemmer": "other"}}}) 294 | assert rt.config["seed"] == 123 295 | assert rt.searcher.config["seed"] == 123 296 | 297 | 298 | def test_config_seed_nonpropagation(rank_modules): 299 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = rank_modules 300 | 301 | rt = RankTask({"searcher": {"seed": 123, "index": {"stemmer": "other"}}}) 302 | assert rt.config["seed"] == _DEFAULT_RANDOM_SEED 303 | assert rt.searcher.config["seed"] == _DEFAULT_RANDOM_SEED 304 | 305 | 306 | def test_prng_creation(rank_modules): 307 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = rank_modules 308 | 309 | rt = RankTask({"searcher": {"seed": 123, "index": {"stemmer": "other"}}}) 310 | assert hasattr(rt, "rng") 311 | assert hasattr(rt.searcher, "rng") 312 | 313 | assert not hasattr(rt.searcher.index, "rng") 314 | assert not hasattr(rt.searcher.index.collection, "rng") 315 | assert not hasattr(rt.benchmark, "rng") 316 | assert not hasattr(rt.benchmark.collection, "rng") 317 | 318 | 319 | def test_creation_with_config_string(rank_modules): 320 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = rank_modules 321 | 322 | rt1 = RankTask({"searcher": {"seed": 123, "index": {"stemmer": "other"}}}) 323 | rt2 = RankTask("searcher.seed=123 searcher.index.stemmer=other") 324 | rt3 = RankTask("searcher.seed=456 searcher.seed=123 searcher.index.stemmer=other") 325 | 326 | assert rt1.config == rt2.config 327 | assert rt2.config == rt3.config 328 | 329 | 330 | def test_creation_with_provide_obj(rank_modules): 331 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = 
rank_modules 332 | 333 | benchmark = module_registry.lookup("benchmark", "trecdl")() 334 | rt = RankTask("benchmark.name=rob04yang", provide=benchmark) 335 | 336 | assert rt.benchmark == benchmark 337 | 338 | 339 | def test_creation_with_provide_list(rank_modules): 340 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = rank_modules 341 | 342 | benchmark = module_registry.lookup("benchmark", "trecdl")() 343 | rt = RankTask("benchmark.name=rob04yang", provide=[benchmark]) 344 | 345 | assert rt.benchmark == benchmark 346 | 347 | 348 | def test_skip_config_in_module_path(rank_modules): 349 | ThreeRankTask, TwoRankTask, RankTask, RerankTask = rank_modules 350 | 351 | rerank = RerankTask() 352 | assert rerank.get_module_path().endswith("/task-rerank_fold-s1_optimize-map_seed-42") 353 | assert rerank.get_module_path(skip_config_keys="fold").endswith("/task-rerank_optimize-map_seed-42") 354 | assert rerank.get_module_path(skip_config_keys=["fold", "seed"]).endswith("/task-rerank_optimize-map") 355 | 356 | 357 | def test_registry_enumeration(rank_modules): 358 | assert module_registry.get_module_types() == [ 359 | "benchmark", 360 | "collection", 361 | "extractor", 362 | "index", 363 | "reranker", 364 | "searcher", 365 | "task", 366 | "tokenizer", 367 | "trainer", 368 | ] 369 | assert module_registry.get_module_names("benchmark") == ["rob04yang", "trecdl"] 370 | assert module_registry.get_module_names("collection") == ["msmarco", "robust04"] 371 | assert module_registry.get_module_names("index") == ["anserini"] 372 | assert module_registry.get_module_names("searcher") == ["bm25"] 373 | assert module_registry.get_module_names("task") == ["rank", "rerank", "threerank", "tworank"] 374 | -------------------------------------------------------------------------------- /profane/base.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import os 4 | import random 5 | import numpy as np 6 | from glob 
import glob 7 | 8 | from colorama import Style, Fore 9 | 10 | from profane.cli import config_string_to_dict 11 | from profane.config_option import ConfigOption 12 | from profane.exceptions import PipelineConstructionError, InvalidConfigError, InvalidModuleError 13 | from profane.frozendict import FrozenDict 14 | import profane.constants as constants 15 | 16 | logger = logging.getLogger(__name__) 17 | logger.addHandler(logging.NullHandler()) 18 | 19 | 20 | _DEFAULT_RANDOM_SEED = 42 21 | constants = constants.ConstantsRegistry() 22 | 23 | 24 | class ModuleRegistry: 25 | """Keeps track of modules that have been registered with `ModuleBase.register`""" 26 | 27 | def __init__(self): 28 | self.reset() 29 | 30 | def reset(self): 31 | self.registry = {} 32 | self.shared_objects = {} 33 | 34 | def register(self, cls): 35 | """Register a class that describes itself via a `module_type` and a `module_name` variable.""" 36 | 37 | if not hasattr(cls, "module_type"): 38 | raise InvalidModuleError(f"missing module_type for class: {cls}") 39 | 40 | if not hasattr(cls, "module_name"): 41 | raise InvalidModuleError(f"missing module_name for class: {cls}") 42 | 43 | if not isinstance(cls.dependencies, list): 44 | raise TypeError(f"wrong type of dependencies for class {cls}, expect list but found {type(cls.dependencies)}") 45 | 46 | module_type_registry = self.registry.setdefault(cls.module_type, {}) 47 | 48 | # do we already have a different entry for this module_type and module_name?
49 | if module_type_registry.get(cls.module_name, cls) != cls: 50 | logger.warning(f"replacing entry {module_type_registry[cls.module_name]} for {cls.module_name} with {cls}") 51 | 52 | module_type_registry[cls.module_name] = cls 53 | 54 | def lookup(self, module_type, module_name): 55 | """Return the class corresponding to a `module_type` and `module_name` pair.""" 56 | 57 | if module_type not in self.registry: 58 | raise ValueError(f"unknown module_type '{module_type}'; known types: {self.get_module_types()}") 59 | 60 | if module_name not in self.registry[module_type]: 61 | raise ValueError( 62 | f"unknown module_name '{module_name}'; known modules of type '{module_type}': {sorted(self.get_module_names(module_type))}" 63 | ) 64 | 65 | return self.registry[module_type][module_name] 66 | 67 | def get_module_types(self): 68 | return sorted(k for k in self.registry.keys() if len(self.registry[k]) > 0) 69 | 70 | def get_module_names(self, module_type): 71 | return sorted(self.registry[module_type]) 72 | 73 | def get_registered_modules(self): 74 | return [ 75 | (module_type, module_name) 76 | for module_type in self.get_module_types() 77 | for module_name in self.get_module_names(module_type) 78 | ] 79 | 80 | 81 | module_registry = ModuleRegistry() 82 | 83 | 84 | class Dependency: 85 | """Represents a dependency on another module. 86 | 87 | If name is None, the dependency must be provided by the pipeline (i.e., in `provided_modules`). 88 | Otherwise, the module class corresponding to `name` will be used. 89 | 90 | If default_config_overrides is a dict, it will be used to override the dependency's default config options. 91 | Note that user may still override these options e.g. on the command line. 
92 | """ 93 | 94 | def __init__(self, key, module, name=None, default_config_overrides=None, provide_this=False, provide_children=None): 95 | try: 96 | if "BASE_PACKAGE" in constants: 97 | importlib.import_module(f"{constants['BASE_PACKAGE']}.{module}") 98 | except ModuleNotFoundError as e: 99 | pass 100 | 101 | self.key = key 102 | self.module = module 103 | self.name = name 104 | self.provide_this = provide_this 105 | 106 | if default_config_overrides is None: 107 | default_config_overrides = {} 108 | 109 | if provide_children is None: 110 | provide_children = [] 111 | 112 | self.default_config_overrides = default_config_overrides 113 | self.provide_children = provide_children 114 | 115 | def __str__(self): 116 | return f"" 117 | 118 | 119 | class ModuleBase: 120 | """Base class for profane modules. 121 | Module construction proceeds as follows: 122 | 1) Any config options not present in `config` are filled in with their default values. Config options and their defaults are specified in the `config_spec` class attribute. 123 | 2) Any dependencies declared in the `dependencies` class attribute are recursively instantiated. If the dependency object is present in `provide`, this object will be used instead of instantiating a new object for the dependency. 124 | 3) The module object's `config` variable is updated to reflect the configs of its dependencies and then frozen. 125 | 126 | After construction is complete, the module's dependencies are available as instance variables: self.`dependency key`. 127 | 128 | Args: 129 | config: dictionary containing a config to apply to this module and its dependencies 130 | provide: dictionary mapping dependency keys to module objects 131 | share_dependency_objects: if true, dependencies will be cached in the registry based on their configs and reused. See the `share_objects` argument of `ModuleBase.create`. 
132 | """ 133 | 134 | config_spec = [] 135 | dependencies = [] 136 | config_keys_not_in_path = [] 137 | requires_random_seed = False 138 | 139 | @staticmethod 140 | def register(cls): 141 | module_registry.register(cls) 142 | return cls 143 | 144 | @classmethod 145 | def _validate_and_cast_config(cls, config): 146 | """Validates `config` and casts values to their correct types. 147 | Reraises an exception if any option present is not recognized or is incompatible with its type. 148 | """ 149 | 150 | options = {option.key: option for option in cls.config_spec} 151 | dependencies = set(dependency.key for dependency in cls.dependencies) 152 | 153 | for key in list(config.keys()): 154 | if key == "name": 155 | if config[key] != cls.module_name: 156 | raise InvalidConfigError(f"key name={config[key]} does not match cls.module_name={cls.module_name}") 157 | elif key == "seed": 158 | if not cls.requires_random_seed: 159 | raise InvalidConfigError(f"seed={config[key]} was provided but cls.requires_random_seed=False") 160 | if config["seed"] != constants["RANDOM_SEED"]: 161 | raise InvalidConfigError( 162 | f"seed={config[key]} does not match constants.RANDOM_SEED={constants.RANDOM_SEED}. This indicates that different seeds were configured within the same process which is not possible. Please start a new process for each different seed." 163 | ) 164 | elif key in dependencies: 165 | if isinstance(config[key], str): 166 | raise InvalidConfigError( 167 | f"invalid option: '{key}={config[key]}' ... 
maybe you meant: '{key}.name={config[key]}'" 168 | ) 169 | elif key not in options: 170 | raise InvalidConfigError(f"received unknown config key: {key}") 171 | else: 172 | config[key] = options[key].type(config[key]) 173 | 174 | return config 175 | 176 | @classmethod 177 | def _fill_in_default_config_options(cls, config): 178 | """Adds default values to config for any key that is not already present""" 179 | for option in cls.config_spec: 180 | if option.key not in config: 181 | config[option.key] = option.type(option.default_value) 182 | return config 183 | 184 | @classmethod 185 | def _config_values_to_strings(cls, config): 186 | """Converts config values to strings that can be shown to the user""" 187 | 188 | options = {option.key: option for option in cls.config_spec} 189 | dependencies = set(dependency.key for dependency in cls.dependencies) 190 | 191 | config_as_strings = {} 192 | for key in config: 193 | if key in dependencies: 194 | continue 195 | elif key == "name" or key == "seed": 196 | val = config[key] 197 | else: 198 | val = options[key].string_representation(config[key]) 199 | 200 | reconverted_typed_value = options[key].type(val) 201 | current_typed_value = config[key] 202 | if current_typed_value != reconverted_typed_value: 203 | raise RuntimeError( 204 | f"value changed during type conversion: '{current_typed_value}' became '{reconverted_typed_value}'" 205 | ) 206 | 207 | config_as_strings[key] = val 208 | 209 | return config_as_strings 210 | 211 | @classmethod 212 | def create(cls, name, config=None, provide=None, share_objects=True): 213 | """Creates a module by looking up a `name` in the module registry corresponding to the calling class' module type. 214 | `config` and `provide` are passed to the module's constructor. 
        If `share_objects` is true:
        - any instantiated module objects will be cached in the registry based on their configs
        - when a module with the same config is created, the cached object is returned rather than a new instance
        This behavior applies to any module dependencies as well.
        """

        module_cls = module_registry.lookup(cls.module_type, name)
        module_obj = module_cls(config, provide, share_dependency_objects=share_objects)

        if not share_objects:
            return module_obj

        # cache is keyed on the module's (frozen) config; identical configs share one object.
        # NOTE(review): when a cached object already exists, the module_obj constructed above
        # (including any side effects of its build()) is discarded — confirm this is intended.
        if module_obj.config not in module_registry.shared_objects:
            module_registry.shared_objects[module_obj.config] = module_obj

        return module_registry.shared_objects[module_obj.config]

    @classmethod
    def lookup(cls, name):
        """Return the class registered under `name` for this module type (without instantiating it)."""
        return module_registry.lookup(cls.module_type, name)

    @classmethod
    def compute_config(cls, config=None, provide=None):
        """Return this module class' effective config after taking the module's defaults, `config`, and `provide` into account."""
        return cls(config, provide=provide, share_dependency_objects=False).config

    def __init__(self, config=None, provide=None, share_dependency_objects=False, build=True):
        """Construct the module described by `config`, instantiating any dependencies not supplied via `provide`.

        Args:
            config: a dict, a config string (parsed with `config_string_to_dict`), a FrozenDict, or None.
            provide: already-instantiated dependency modules; accepted as a single ModuleBase,
                     a list/tuple of them, or a dict mapping module_type -> module object.
            share_dependency_objects: forwarded to dependency creation so dependencies with
                     identical configs can be shared (see `create`).
            build: when True, call `self.build()` after construction if the subclass defines one.
        """
        # create new objects to prevent them from being shared with other class instances
        self._dependency_objects = {}
        self._provided_dependency = set()

        if isinstance(config, str):
            config = config_string_to_dict(config)

        if isinstance(config, FrozenDict):
            config = config._as_dict()

        # normalize `provide` to a dict keyed on module_type
        if isinstance(provide, ModuleBase):
            provide = [provide]

        if isinstance(provide, (list, tuple)):
            provide = {module.module_type: module for module in provide}

        # it is important that we create a new provide object here, because _instantiate_dependencies may add entries to it.
        # we don't want those entries to propagate higher in the module graph.
        # see the test with 'threerank_separate' in test_task_pipeline.py for illustration.
        # NOTE(review): a new dict is only created when `provide` is falsy or was normalized from a
        # ModuleBase/list/tuple above; a caller-supplied non-empty dict is used (and mutated) directly —
        # confirm that is intended given the comment above.
        if not config:
            config = {}
        if not provide:
            provide = {}

        # make a copy so we don't modify the object that was passed
        config = config.copy()

        config["name"] = self.module_name
        self._set_random_seed(config)
        self.config = self._validate_and_cast_config(config)
        self.config = self._fill_in_default_config_options(self.config)
        self._config_as_strings = self._config_values_to_strings(self.config)
        self._instantiate_dependencies(self.config, provide, share_dependency_objects)
        # freeze config
        self.config = FrozenDict(self.config)

        if build and hasattr(self, "build"):
            self.build()

    def _instantiate_dependencies(self, config, provide, share_objects):
        """Resolve every entry in `self.dependencies` and attach the resulting objects to self.

        Dependencies present in `provide` are used as-is; all others are instantiated via their
        class' `create`. May add entries to `provide` (for dependencies flagged with `provide_this`
        or listed in `provide_children`) so later dependencies can reuse them.
        Also records each dependency's config under `config[dependency.key]`.
        """
        dependencies = {}
        for dependency in self.dependencies:
            # if the dependency object has been provided, use it directly
            if dependency.key in provide:
                dependencies[dependency.key] = provide[dependency.key]
                self._provided_dependency.add(dependency.key)

                # any config the user passed for this key is superseded by the provided object's config
                if dependency.key in config:
                    logger.warning(
                        "config['%s']='%s' is being replaced with config from provided module: %s",
                        dependency.key,
                        config[dependency.key],
                        provide[dependency.key].config,
                    )

                continue

            # if not, we need to instantiate the dependency
            # apply any config overrides
            dependency_config = dependency.default_config_overrides.copy()

            # apply any config options we received
            for k, v in config.get(dependency.key, {}).items():
                dependency_config[k] = v

            # identify correct class for this dependency; an explicit config "name" wins over the declared default
            dependency_name = dependency_config.get("name", dependency.name)
            if dependency_name is None:
                raise PipelineConstructionError(f"No name provided for dependency {dependency}")
            dependency_cls = module_registry.lookup(dependency.module, dependency_name)

            # instantiate the dependency
            dependencies[dependency.key] = dependency_cls.create(
                dependency_name, dependency_config, provide=provide, share_objects=share_objects
            )

            # provide the dependency for later modules?
            if dependency.provide_this:
                if dependency.key in provide:
                    raise PipelineConstructionError(
                        f"'provide_this' flag on dependency '{dependency}' would replace existing provided module {provide[dependency.key]} with {dependencies[dependency.key]}"
                    )
                provide[dependency.key] = dependencies[dependency.key]

            # provide any of this dependency's children for later modules?
            for child_dep_key in dependency.provide_children:
                if child_dep_key in provide:
                    raise PipelineConstructionError(
                        f"'provide_children' list for dependency '{dependency}' would replace existing provided module"
                    )

                if not hasattr(dependencies[dependency.key], child_dep_key):
                    raise PipelineConstructionError(
                        f"'provide_children' list for dependency '{dependency}' contains key '{child_dep_key}', but the module has no such dependency"
                    )

                provide[child_dep_key] = getattr(dependencies[dependency.key], child_dep_key)

        # add dependency configs and objects to self
        for module_name, module_obj in dependencies.items():
            if hasattr(self, module_name):  # and getattr(self, module_name) != module_obj:
                raise PipelineConstructionError(f"would assign {module_obj} to self.{module_name} but it already exists")

            setattr(self, module_name, module_obj)
            self._dependency_objects[module_name] = module_obj
            self.config[module_name] = module_obj.config

    def _set_random_seed(self, config):
        """If this module requires a random seed, set one and initialize the RNGs.

        All modules must share the same seed, because they may make calls to the same RNGs (e.g., ``np.random``).
        However, this can lead to non-deterministic behavior and should be avoided whenever possible.
        Instead, modules should use their own numpy RNG at `self.rng` to avoid RNG interactions between modules."""

        if not self.requires_random_seed:
            return

        # must use the same seed for all modules
        if "RANDOM_SEED" not in constants:
            constants["RANDOM_SEED"] = int(config.get("seed", _DEFAULT_RANDOM_SEED))
            random.seed(constants["RANDOM_SEED"])
            np.random.seed(constants["RANDOM_SEED"])

        # per-module generator seeded with the shared seed, so modules can avoid the global RNGs
        self.rng = np.random.Generator(np.random.PCG64(constants["RANDOM_SEED"]))
        config["seed"] = constants["RANDOM_SEED"]

    def get_cache_path(self, *args, **kwargs):
        """Return an absolute path that can be used for caching.
        The path is a function of the module's config and the configs of its dependencies.
        """

        return constants["CACHE_BASE_PATH"] / self.get_module_path(*args, **kwargs)

    def get_module_path(self, skip_config_keys=None):
        """Return a relative path encoding the module's config and its dependencies"""

        if self.dependencies:
            # dependency paths come first, in declaration order, then this module's own path segment
            prefix = os.path.join(
                *[self._dependency_objects[dependency.key].get_module_path() for dependency in self.dependencies]
            )
            return os.path.join(prefix, self._this_module_path_only(skip_config_keys=skip_config_keys))
        else:
            # NOTE(review): skip_config_keys is not forwarded here, unlike the branch above —
            # likely should be self._this_module_path_only(skip_config_keys=skip_config_keys); confirm.
            return self._this_module_path_only()

    def _this_module_path_only(self, skip_config_keys=None):
        """Return a path encoding only the module's config (and not its dependencies)"""

        if isinstance(skip_config_keys, str):
            skip_config_keys = [skip_config_keys]

        if skip_config_keys is None:
            skip_config_keys = []

        # string form of every config option except dependency configs and explicitly excluded keys
        module_cfg = {
            k: self._config_as_strings[k]
            for k in self.config
            if k not in self._dependency_objects and k not in self.config_keys_not_in_path and k not in skip_config_keys
        }
        # path shape: "<type>-<name>_<key1>-<val1>_<key2>-<val2>..." with keys sorted for stability
        module_name_key = self.module_type + "-" + module_cfg.pop("name")
        return "_".join([module_name_key] + [f"{k}-{v}" for k, v in sorted(module_cfg.items())])

    def print_module_graph(self, prefix=""):
        """Print this module and its dependency tree, indenting two spaces per level."""
        childprefix = prefix + "  "
        this = f"{self.module_type}={self.module_name}"
        print(prefix + this)
        for dependency in self.dependencies:
            child = self._dependency_objects[dependency.key]

            # provided modules are shown as leaves; their own dependencies are not expanded
            if dependency.key in self._provided_dependency:
                print(f"{childprefix}{child.module_type}={child.module_name} [provided by pipeline]")
            else:
                child.print_module_graph(prefix=childprefix)

    def print_module_config(self, prefix=""):
        """Print a human-readable summary of this module's config (including dependencies)."""
        lines = []
        self._config_summary(lines, prefix)
        print("\n".join(lines))

    def _config_summary(self, lines, prefix=""):
        """Append a colorized, indented description of self.config (and dependency configs) to `lines`."""
        options = {option.key: option for option in self.config_spec}
        # "name" and "seed" are not part of config_spec, so synthesize ConfigOptions for display
        options["name"] = ConfigOption("name", self.module_name)
        options["seed"] = ConfigOption("seed", _DEFAULT_RANDOM_SEED, "random seed")

        # show name, followed by module config, followed by dependencies
        order = sorted(self.config.keys(), key=lambda x: (x != "name", x in self._dependency_objects, x))
        for key in order:
            if key in self._dependency_objects:
                lines.append(f"{prefix}{key}:{Style.RESET_ALL}")
                childprefix = prefix + "  "
                if key in self._provided_dependency:
                    lines.append(f"{childprefix}{Style.DIM}[provided by pipeline]{Style.RESET_ALL}")
                else:
                    self._dependency_objects[key]._config_summary(lines, prefix=childprefix)
            else:
                if options[key].description:
                    lines.append(f"{prefix}{Style.DIM}# {options[key].description}{Style.RESET_ALL}")

                # highlight values that differ from the declared default
                color = ""
                if self.config[key] != options[key].default_value:
                    color = Fore.GREEN
                lines.append(f"{color}{prefix}{key} = {self._config_as_strings[key]}{Style.RESET_ALL}")


def import_all_modules(file, package):
    """Import every non-private .py module in `file`'s directory, triggering module-registration side effects.

    Args:
        file: typically a package's `__file__`; only its directory component is used.
        package: the package name used to qualify the imports (typically `__package__`).
    """
    pwd = os.path.dirname(file)
    for fn in glob(os.path.join(pwd, "*.py")):
        module_name = os.path.basename(fn)[:-3]
        # skip dunder files (e.g., __init__), flycheck temp files, and editor files beginning with '#'
        if not (module_name.startswith("__") or module_name.startswith("flycheck_") or module_name.startswith("#")):
            importlib.import_module(f"{package}.{module_name}")