├── .bumpversion.cfg ├── .github ├── stale.yml └── workflows │ ├── release.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.rst ├── HISTORY.rst ├── LICENSE.txt ├── MANIFEST.in ├── Makefile ├── README.rst ├── dev-requirements.txt ├── doc-requirements.txt ├── docs ├── Makefile ├── apidoc.rst ├── collected_information.rst ├── command_line.rst ├── conf.py ├── configuration.rst ├── examples.rst ├── experiment.rst ├── images │ ├── incense-artifact.png │ ├── incense-metric.png │ ├── neptune-collaboration.png │ ├── neptune-compare.png │ ├── neptune-query-api.png │ ├── omniboard-metric-graphs.png │ ├── omniboard-table.png │ ├── prophet.png │ ├── sacred_browser.png │ ├── sacredboard.png │ ├── slack_observer.png │ └── sql_schema.png ├── index.rst ├── ingredients.rst ├── internals.rst ├── logging.rst ├── make.bat ├── observers.rst ├── optional.rst ├── projects_using_sacred.rst ├── quickstart.rst ├── randomness.rst ├── settings.rst └── tensorflow.rst ├── examples ├── 01_hello_world.py ├── 02_hello_config_dict.py ├── 03_hello_config_scope.py ├── 04_captured_functions.py ├── 05_my_commands.py ├── 06_randomness.py ├── 07_magic.py ├── 08_less_magic.py ├── __init__.py ├── captured_out_filter.py ├── docker │ ├── .env │ ├── docker-compose.yml │ └── sacredboard │ │ └── Dockerfile ├── ingredient.py ├── log_example.py ├── modular.py └── named_config.py ├── pyproject.toml ├── requirements.txt ├── sacred ├── __about__.py ├── __init__.py ├── arg_parser.py ├── commandline_options.py ├── commands.py ├── config │ ├── __init__.py │ ├── captured_function.py │ ├── config_dict.py │ ├── config_files.py │ ├── config_scope.py │ ├── config_summary.py │ ├── custom_containers.py │ ├── signature.py │ └── utils.py ├── data │ └── mime.types ├── dependencies.py ├── experiment.py ├── host_info.py ├── ingredient.py ├── initialize.py ├── metrics_logger.py ├── observers │ ├── __init__.py │ ├── base.py │ ├── file_storage.py │ ├── gcs_observer.py │ ├── mongo.py │ ├── queue.py │ ├── s3_observer.py │ ├── slack.py │ ├── sql.py │ ├── sql_bases.py │ ├── telegram_obs.py │ └── tinydb_hashfs │ │ ├── __init__.py │ │ ├── bases.py │ │ └── tinydb_hashfs.py ├── optional.py ├── py.typed ├── pytee.py ├── randomness.py ├── run.py ├── serializer.py ├── settings.py ├── stdout_capturing.py ├── stflow │ ├── __init__.py │ ├── internal.py │ └── method_interception.py └── utils.py ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── basedir │ ├── __init__.py │ └── my_experiment.py ├── check_pre_commit.sh ├── conftest.py ├── dependency_example.py ├── donotimport.py ├── foo │ ├── __init__.py │ ├── bar.py │ └── mock_extension.py ├── test_arg_parser.py ├── test_commands.py ├── test_config │ ├── __init__.py │ ├── enclosed_config_scope.py │ ├── test_captured_functions.py │ ├── test_config_dict.py │ ├── test_config_files.py │ ├── test_config_scope.py │ ├── test_config_scope_chain.py │ ├── test_dogmatic_dict.py │ ├── test_dogmatic_list.py │ ├── test_fallback_dict.py │ ├── test_readonly_containers.py │ ├── test_signature.py │ └── test_utils.py ├── test_dependencies.py ├── test_examples.py ├── test_exceptions.py ├── test_experiment.py ├── test_host_info.py ├── test_ingredients.py ├── test_metrics_logger.py ├── test_modules.py ├── test_observers │ ├── __init__.py │ ├── failing_mongo_mock.py │ ├── test_file_storage_observer.py │ ├── test_gcs_observer.py │ ├── test_mongo_observer.py │ ├── test_mongo_option.py │ ├── test_queue_mongo_observer.py │ ├── test_queue_observer.py │ ├── test_run_observer.py │ ├── test_s3_observer.py │ ├── 
test_sql_observer.py │ ├── test_sql_observer_not_installed.py │ ├── test_tinydb_observer.py │ ├── test_tinydb_observer_not_installed.py │ └── test_tinydb_reader.py ├── test_optional.py ├── test_run.py ├── test_serializer.py ├── test_settings.py ├── test_stdout_capturing.py ├── test_stflow │ ├── __init__.py │ ├── test_internal.py │ └── test_method_interception.py └── test_utils.py └── tox.ini /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.8.7 3 | commit = True 4 | tag = True 5 | tag_name = {new_version} 6 | 7 | [bumpversion:file:sacred/__about__.py] 8 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 90 3 | 4 | # Number of days of inactivity before a stale issue is closed 5 | daysUntilClose: 7 6 | 7 | # Issues with these labels will never be considered stale 8 | exemptLabels: 9 | - feature 10 | - bug 11 | - documentation 12 | - high priority 13 | - in progress 14 | - question 15 | - refactoring 16 | 17 | # Label to use when marking an issue as stale 18 | staleLabel: stale 19 | 20 | # Set to true to ignore issues in a project (defaults to false) 21 | exemptProjects: true 22 | 23 | # Set to true to ignore issues in a milestone (defaults to false) 24 | exemptMilestones: true 25 | 26 | # Set to true to ignore issues with an assignee (defaults to false) 27 | exemptAssignees: true 28 | 29 | # Comment to post when marking an issue as stale. Set to `false` to disable 30 | markComment: > 31 | This issue has been automatically marked as stale because it has not had 32 | recent activity. It will be closed if no further activity occurs. Thank you 33 | for your contributions. 34 | # Comment to post when closing a stale issue. Set to `false` to disable 35 | closeComment: false -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # https://github.com/marketplace/actions/pypi-github-auto-release 2 | 3 | name: Auto-publish 4 | 5 | on: 6 | release: 7 | types: [published] 8 | 9 | jobs: 10 | # Auto-publish when version is increased 11 | deploy: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@master 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: "3.x" 20 | - name: Install dependencies 21 | run: >- 22 | python -m pip install build --user 23 | - name: Build 24 | run: >- 25 | python -m build --sdist --wheel --outdir dist/ . 
26 | - uses: pypa/gh-action-pypi-publish@release/v1 27 | with: 28 | password: ${{ secrets.PYPI_API_TOKEN }} 29 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | pytest: 9 | strategy: 10 | fail-fast: false 11 | matrix: 12 | # These tests run on all os versions 13 | env: [ 14 | {python: '3.8', tox: 'py38'}, 15 | {python: '3.9', tox: 'py39'}, 16 | {python: '3.10', tox: 'py310'}, 17 | {python: '3.11', tox: 'py311'}, 18 | ] 19 | os: [ubuntu-latest, windows-latest, macos-latest] 20 | 21 | # These tests only run on ubuntu 22 | include: 23 | - os: ubuntu-latest 24 | env: 25 | python: '3.8' 26 | tox: 'tensorflow-27' 27 | - os: ubuntu-latest 28 | env: 29 | python: '3.9' 30 | tox: 'tensorflow-28' 31 | - os: ubuntu-latest 32 | env: 33 | python: '3.10' 34 | tox: 'tensorflow-29' 35 | - os: ubuntu-latest 36 | env: 37 | python: '3.10' 38 | tox: 'tensorflow-210' 39 | - os: ubuntu-latest 40 | env: 41 | python: '3.10' 42 | tox: 'tensorflow-211' 43 | - os: ubuntu-latest 44 | env: 45 | python: '3.9' 46 | tox: 'numpy-120' 47 | - os: ubuntu-latest 48 | env: 49 | python: '3.10' 50 | tox: 'numpy-121' 51 | - os: ubuntu-latest 52 | env: 53 | python: '3.10' 54 | tox: 'numpy-123' 55 | - os: ubuntu-latest 56 | env: 57 | python: '3.11' 58 | tox: 'numpy-124' 59 | 60 | runs-on: ${{ matrix.os }} 61 | 62 | steps: 63 | - uses: actions/checkout@v3 64 | - name: Set up Python ${{ matrix.env.python }} 65 | uses: actions/setup-python@v4 66 | with: 67 | python-version: ${{ matrix.env.python }} 68 | - name: Install tox 69 | run: | 70 | python -m pip install tox 71 | - name: Test with tox against environments ${{ matrix.env.tox }} 72 | run: python -m tox -e ${{ matrix.env.tox }} 73 | test_pre_commit: 74 | runs-on: ubuntu-latest 75 | steps: 76 | - uses: actions/checkout@v3 77 | - name: Set up Python 78 | uses: actions/setup-python@v4 79 | with: 80 | python-version: 3.11 81 | - name: Test pre-commit 82 | run: | 83 | bash ./tests/check_pre_commit.sh 84 | coverage: 85 | runs-on: ubuntu-latest 86 | steps: 87 | - uses: actions/checkout@v3 88 | - name: Set up Python 89 | uses: actions/setup-python@v4 90 | with: 91 | python-version: 3.11 92 | - name: Install tox 93 | run: | 94 | python -m pip install tox 95 | - name: Run coverage tox job 96 | run: python -m tox -e coverage 97 | setup: 98 | strategy: 99 | fail-fast: false 100 | matrix: 101 | os: [ ubuntu-latest, windows-latest, macos-latest ] 102 | runs-on: ${{ matrix.os }} 103 | steps: 104 | - uses: actions/checkout@v3 105 | - name: Set up Python 106 | uses: actions/setup-python@v4 107 | with: 108 | python-version: 3.11 109 | - name: Install tox 110 | run: | 111 | python -m pip install tox 112 | - name: Run setup tox job 113 | run: python -m tox -e setup 114 | 115 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by http://www.gitignore.io 2 | 3 | ### Python ### 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | bin/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # Installer logs 29 | pip-log.txt 30 | 
pip-delete-this-directory.txt 31 | 32 | # Unit test / coverage reports 33 | htmlcov/ 34 | .tox/ 35 | .coverage 36 | .cache 37 | nosetests.xml 38 | coverage.xml 39 | 40 | # Translations 41 | *.mo 42 | 43 | # Mr Developer 44 | .mr.developer.cfg 45 | .project 46 | .pydevproject 47 | 48 | # Rope 49 | .ropeproject 50 | 51 | # Django stuff: 52 | *.log 53 | *.pot 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # Jupyter 59 | .ipynb_checkpoints 60 | 61 | 62 | ### PyCharm ### 63 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm 64 | 65 | ## Directory-based project format 66 | .idea/ 67 | # if you remove the above rule, at least ignore user-specific stuff: 68 | # .idea/workspace.xml 69 | # .idea/tasks.xml 70 | # and these sensitive or high-churn files: 71 | # .idea/dataSources.ids 72 | # .idea/dataSources.xml 73 | # .idea/sqlDataSources.xml 74 | # .idea/dynamic.xml 75 | 76 | ## File-based project format 77 | *.ipr 78 | *.iws 79 | *.iml 80 | 81 | ## Additional for IntelliJ 82 | out/ 83 | 84 | # generated by mpeltonen/sbt-idea plugin 85 | .idea_modules/ 86 | 87 | # generated by JIRA plugin 88 | atlassian-ide-plugin.xml 89 | 90 | # generated by Crashlytics plugin (for Android Studio and Intellij) 91 | com_crashlytics_export_strings.xml 92 | 93 | # GEdit temporary files 94 | *~ 95 | 96 | /.pytest_cache/ 97 | pip-wheel-metadata/ 98 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: true 2 | repos: 3 | - repo: https://github.com/psf/black 4 | rev: 22.10.0 5 | hooks: 6 | - id: black 7 | language_version: python3 8 | - repo: https://github.com/pycqa/flake8 9 | rev: 6.0.0 10 | hooks: 11 | - id: flake8 12 | exclude: ^(tests|examples|docs)/.* 13 | additional_dependencies: [pep8-naming, flake8-docstrings] 14 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributing 3 | ============ 4 | 5 | Contributions are welcome, and they are greatly appreciated! Every little bit 6 | helps, and credit will always be given. 7 | 8 | You can contribute in many ways: 9 | 10 | Types of Contributions 11 | ---------------------- 12 | 13 | Report Bugs 14 | ~~~~~~~~~~~ 15 | 16 | Report bugs at https://github.com/IDSIA/sacred/issues. 17 | 18 | If you are reporting a bug, please include: 19 | 20 | * Any details about your local setup that might be helpful in troubleshooting. 21 | * Steps to reproduce the bug, and if possible a minimal example demonstrating the problem. 22 | 23 | Good first issue 24 | ~~~~~~~~~~~~~~~~ 25 | 26 | Look through the GitHub issues for bugs. Anything tagged with "good first issue" 27 | is a great place to get started. 28 | 29 | Fix Bugs 30 | ~~~~~~~~ 31 | 32 | Look through the GitHub issues for bugs. Anything tagged with "bug" 33 | is open to whoever wants to fix it. 34 | 35 | Implement Features 36 | ~~~~~~~~~~~~~~~~~~ 37 | 38 | Look through the GitHub issues for features. Anything tagged with "feature" 39 | is open to whoever wants to implement it. 40 | 41 | Write Documentation 42 | ~~~~~~~~~~~~~~~~~~~ 43 | 44 | Sacred could always use more documentation, whether as part of the 45 | official Sacred docs, in docstrings, or even on the web in blog posts, 46 | articles, and such. 47 | 48 | When writing docstrings, stick to the `NumPy style 49 | `_. 
50 | However, prefer using Python type hints over type annotations in the docstring. 51 | This makes your type hints usable by type checkers and IDEs. An example docstring 52 | could look like this. 53 | 54 | .. code-block :: python 55 | 56 | def add(a: int, b: int) -> int: 57 | """Add two numbers. 58 | 59 | Parameters 60 | ---------- 61 | a 62 | The first number. 63 | b 64 | The second number. 65 | 66 | Returns 67 | ------- 68 | The sum of the two numbers. 69 | """ 70 | return a + b 71 | 72 | Submit Feedback 73 | ~~~~~~~~~~~~~~~ 74 | 75 | The best way to send feedback is to file an issue at https://github.com/IDSIA/sacred/issues. 76 | 77 | If you are proposing a feature: 78 | 79 | * Explain in detail how it would work. 80 | * Keep the scope as narrow as possible, to make it easier to implement. 81 | * Remember that this is a volunteer-driven project, and that contributions 82 | are welcome :) 83 | 84 | Get Started! 85 | ------------ 86 | 87 | Ready to contribute? Here's how to set up `sacred` for 88 | local development. 89 | 90 | 1. Fork_ the `sacred` repo on GitHub. 91 | 2. Clone your fork locally:: 92 | 93 | $ git clone git@github.com:your_name_here/sacred.git 94 | 95 | 3. Create a branch for local development:: 96 | 97 | $ git checkout -b name-of-your-bugfix-or-feature 98 | 99 | 4. Create your development environment and install the pre-commit hooks:: 100 | 101 | $ # Activate your environment 102 | $ pip install -e . 103 | $ pip install -r dev-requirements.txt 104 | $ pre-commit install 105 | 106 | You can check that pre-commit works with:: 107 | 108 | $ pre-commit run --all-files 109 | 110 | If you get the error ``ModuleNotFoundError: No module named 'distutils.spawn'``, 111 | you should do the following:: 112 | 113 | $ sudo apt-get update 114 | $ sudo apt-get install python3-distutils 115 | 116 | It should solve the problem with ``distutils.spawn``. 117 | 118 | Now you can make your changes locally. 119 | 120 | 5. When you're done making changes, check that your changes pass style and unit 121 | tests, including testing other Python versions with tox:: 122 | 123 | $ tox 124 | 125 | To get tox, use ``pip install tox`` or ``pip install tox-conda``. If you have a conda distribution, you MUST use tox-conda. 126 | 127 | 6. Commit your changes and push your branch to GitHub:: 128 | 129 | $ git add . 130 | $ git commit -m "Your detailed description of your changes." 131 | $ git push origin name-of-your-bugfix-or-feature 132 | 133 | 7. Submit a pull request through the GitHub website. 134 | 135 | .. _Fork: https://github.com/IDSIA/sacred/fork 136 | 137 | Pull Request Guidelines 138 | ----------------------- 139 | 140 | Before you submit a pull request, check that it meets these guidelines: 141 | 142 | 1. Pull requests should be made on their own branch or against master. 143 | 2. The pull request should include tests. 144 | 3. If the pull request adds functionality, the docs should be updated. Put 145 | your new functionality into a function with a docstring, and add the 146 | feature to the list in README.rst. 147 | 4. The pull request should work for all Python versions listed in the ``setup.py``. 148 | Check https://travis-ci.org/IDSIA/sacred/pull_requests 149 | for active pull requests or run the ``tox`` command and make sure that the tests pass for all supported Python versions.
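While iterating on a change it is often faster to run only the tests affected by it directly with ``pytest`` (installed via ``dev-requirements.txt``) and leave the full ``tox`` matrix for a final check, for example::

    $ py.test tests/test_config
    $ py.test tests/test_observers/test_mongo_observer.py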
150 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Klaus Greff 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE.txt 2 | include README.rst 3 | include requirements.txt 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help clean clean-pyc clean-build clean-test clean-docs lint test \ 2 | test-all coverage docs release dist 3 | 4 | help: 5 | @echo "clean - remove all build, doc, test, coverage and Python artifacts" 6 | @echo "clean-build - remove build artifacts" 7 | @echo "clean-pyc - remove Python file artifacts" 8 | @echo "clean-test - remove test and coverage artifacts" 9 | @echo "clean-doc - remove documentation artifacts" 10 | @echo "lint - check style with flake8" 11 | @echo "test - run tests quickly with the default Python" 12 | @echo "test-all - run tests on every Python version with tox" 13 | @echo "coverage - check code coverage quickly with the default Python" 14 | @echo "docs - generate Sphinx HTML documentation, including API docs" 15 | @echo "release - package and upload a release" 16 | @echo "dist - package" 17 | 18 | clean: clean-build clean-pyc clean-test clean-docs 19 | 20 | clean-build: 21 | rm -fr build/ 22 | rm -fr dist/ 23 | rm -fr *.egg-info 24 | 25 | clean-pyc: 26 | find . -name '*.pyc' -exec rm -f {} + 27 | find . -name '*.pyo' -exec rm -f {} + 28 | find . -name '*~' -exec rm -f {} + 29 | find . 
-name '__pycache__' -exec rm -fr {} + 30 | 31 | clean-test: 32 | rm -fr .tox/ 33 | rm -f .coverage 34 | rm -fr htmlcov/ 35 | rm -fr .cache/ 36 | 37 | clean-docs: 38 | $(MAKE) -C docs clean 39 | 40 | lint: 41 | flake8 sacred 42 | 43 | test: clean-pyc clean-test 44 | py.test 45 | 46 | test-all: clean-pyc 47 | tox 48 | 49 | coverage: clean-pyc 50 | py.test --cov sacred 51 | coverage html 52 | xdg-open htmlcov/index.html 53 | 54 | docs: 55 | $(MAKE) -C docs clean 56 | $(MAKE) -C docs html 57 | xdg-open docs/_build/html/index.html 58 | 59 | release: clean 60 | python setup.py sdist upload 61 | python setup.py bdist_wheel upload 62 | 63 | dist: clean 64 | python setup.py sdist 65 | python setup.py bdist_wheel 66 | ls -l dist 67 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | pytest==7.1.2 # tests/test_utils.py depends on pytest being exactly 7.1.2 2 | colorama 3 | docopt-ng 4 | gitdb2 5 | GitPython 6 | hashfs 7 | jsonpickle 8 | Mako 9 | MarkupSafe 10 | mock 11 | mongomock 12 | munch>=2.5 13 | packaging 14 | pandas 15 | pbr 16 | python-dateutil 17 | pytz 18 | PyYAML 19 | scandir 20 | sentinels 21 | smmap2 22 | SQLAlchemy 23 | tinydb 24 | tinydb-serialization 25 | wrapt 26 | scikit-learn 27 | pymongo<4.9 # mongomock.gridfs.enable_gridfs_integration() is not compatible with pymongo>=4.9. https://github.com/mongomock/mongomock/issues/903 28 | py-cpuinfo 29 | boto3 30 | moto 31 | google-compute-engine 32 | google-cloud-storage 33 | pre-commit 34 | -------------------------------------------------------------------------------- /doc-requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-rtd-theme 3 | doc8 -------------------------------------------------------------------------------- /docs/apidoc.rst: -------------------------------------------------------------------------------- 1 | API Documentation 2 | ***************** 3 | This is a construction site... 4 | 5 | Experiment 6 | ========== 7 | 8 | .. note:: 9 | 10 | Experiment inherits from Ingredient_, so all methods from there are also 11 | available in the Experiment. 12 | 13 | .. autoclass:: sacred.Experiment 14 | :members: 15 | :inherited-members: 16 | :special-members: __init__ 17 | 18 | Ingredient 19 | ========== 20 | 21 | .. autoclass:: sacred.Ingredient 22 | :members: 23 | :special-members: __init__ 24 | 25 | 26 | .. _api_run: 27 | 28 | The Run Object 29 | ============== 30 | The Run object can be accessed from python after the run is finished: 31 | ``run = ex.run()`` or during a run using the ``_run`` 32 | :ref:`special value ` in a 33 | :ref:`captured function `. 34 | 35 | .. autoclass:: sacred.run.Run 36 | :members: 37 | :undoc-members: 38 | :special-members: __call__ 39 | 40 | ConfigScope 41 | =========== 42 | .. autoclass:: sacred.config.config_scope.ConfigScope 43 | :members: 44 | :undoc-members: 45 | 46 | ConfigDict 47 | ========== 48 | .. autoclass:: sacred.config.config_dict.ConfigDict 49 | :members: 50 | :undoc-members: 51 | 52 | Observers 53 | ========= 54 | 55 | .. autoclass:: sacred.observers.RunObserver 56 | :members: 57 | :undoc-members: 58 | 59 | .. autoclass:: sacred.observers.MongoObserver 60 | :members: 61 | :undoc-members: 62 | 63 | Host Info 64 | ========= 65 | 66 | .. automodule:: sacred.host_info 67 | :members: 68 | 69 | 70 | Custom Exceptions 71 | ================= 72 | 73 | .. autoclass:: sacred.utils.SacredInterrupt 74 | :members: 75 | 76 | .. autoclass:: sacred.utils.TimeoutInterrupt 77 | :members:
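These interrupt classes can be subclassed to signal a custom status for an aborted run. A minimal sketch (the class name and status string are only illustrative):

.. code-block:: python

    from sacred.utils import SacredInterrupt

    class CustomInterrupt(SacredInterrupt):
        # Runs aborted by raising this exception are recorded with this status.
        STATUS = "CUSTOM_INTERRUPTED"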
78 | -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ******** 3 | You can find these examples in the examples directory (surprise!) of the 4 | Sacred sources or in the 5 | `Github Repository `_. 6 | Look at them for the source code; it is an important part of the examples. 7 | It can also be very helpful to run them yourself and play with the command-line 8 | interface. 9 | 10 | The following is just their documentation from their docstrings, which you can 11 | also get by running them with the ``-h``, ``--help`` or ``help`` flags. 12 | 13 | Hello World 14 | =========== 15 | `examples/01_hello_world.py `_ 16 | 17 | .. automodule:: examples.01_hello_world 18 | 19 | Hello Config Dict 20 | ================= 21 | `examples/02_hello_config_dict.py `_ 22 | 23 | .. automodule:: examples.02_hello_config_dict 24 | 25 | Hello Config Scope 26 | ================== 27 | `examples/03_hello_config_scope.py `_ 28 | 29 | .. automodule:: examples.03_hello_config_scope 30 | 31 | Captured Functions 32 | ================== 33 | `examples/04_captured_functions.py `_ 34 | 35 | .. automodule:: examples.04_captured_functions 36 | 37 | My Commands 38 | =========== 39 | `examples/05_my_commands.py `_ 40 | 41 | .. automodule:: examples.05_my_commands 42 | 43 | Randomness 44 | ========== 45 | `examples/06_randomness.py `_ 46 | 47 | .. automodule:: examples.06_randomness 48 | 49 | 50 | Less magic 51 | ========== 52 | If you are new to Sacred, you might be surprised by the number of new idioms it 53 | introduces compared to standard Python. But don't worry, you don't have to use any of the 54 | magic if you don't want to and can still benefit from the excellent tracking capabilities. 55 | 56 | `examples/07_magic.py `_ 57 | shows a standard machine learning task that uses many of the possible Sacred idioms: 58 | 59 | * configuration definition through local variables 60 | * parameter injection through captured functions 61 | * command line interface integration through the ``ex.automain`` decorator 62 | 63 | `examples/08_less_magic.py `_ 64 | shows the same task without any of those idioms. The recipe for replacing Sacred magic with 65 | standard Python is simple (a condensed sketch follows the list): 66 | 67 | * define your configuration in a dictionary; alternatively, you can use an external ``JSON`` or ``YAML`` file 68 | * avoid the ``ex.capture`` decorator. Instead, only pass ``_config`` to the main function 69 | and access all parameters explicitly through the configuration dictionary 70 | * use ``ex.main`` instead of ``ex.automain`` and call ``ex.run()`` 71 | explicitly. This avoids the parsing of command line parameters you did not define yourself. 72 |
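A condensed sketch of this pattern might look as follows (the configuration values are placeholders; see ``examples/08_less_magic.py`` for the complete script):

.. code-block:: python

    from sacred import Experiment

    ex = Experiment("less_magic")
    ex.add_config({"batch_size": 32, "optimizer": "adam"})

    @ex.main
    def run(_config):
        # read all parameters explicitly from the config dictionary
        print(_config["batch_size"], _config["optimizer"])

    if __name__ == "__main__":
        ex.run()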
73 | While we believe that using Sacred idioms makes things easier by hard-wiring parameters 74 | and giving you a flexible command line interface, we do not enforce their usage 75 | if you feel more comfortable with classical Python. At its core Sacred is about 76 | tracking computational experiments, not about any particular coding style. 77 | 78 | 79 | .. _docker_setup: 80 | 81 | Docker Setup 82 | ============ 83 | `examples/docker `_ 84 | 85 | To use Sacred to its full potential you probably want to use it together with 86 | MongoDB and dashboards like `Omniboard `_ that have been developed for it. 87 | To ease getting started with these services you will find an exemplary ``docker-compose`` configuration in 88 | `examples/docker `_. After installing 89 | `Docker Engine `_ and `Docker Compose `_ 90 | (only necessary for Linux) go to the directory and run:: 91 | 92 | docker-compose up 93 | 94 | 95 | This will pull the necessary containers from the internet and build them. This may take several 96 | minutes. 97 | Afterwards MongoDB should be up and running. ``mongo-express``, an admin interface for MongoDB, should now 98 | be available on port ``8081``, accessible by the user and password set in the ``.env`` file 99 | (``ME_CONFIG_BASICAUTH_USERNAME`` and ``ME_CONFIG_BASICAUTH_PASSWORD``). 100 | ``Sacredboard`` should be available on port ``5000``. ``Omniboard`` should be 101 | available on port ``9000``. They will both listen to the database name set 102 | in the ``.env`` file (``MONGO_DATABASE``), which will allow the boards to listen to the appropriate 103 | Mongo database name set when creating the ``MongoObserver`` with the ``db_name`` arg. 104 | All services will by default only be exposed to ``localhost``. If you want 105 | to expose them on all interfaces, e.g. for use on a server, you need to change the port mappings 106 | in ``docker-compose.yml`` from ``127.0.0.1:XXXX:XXXX`` to ``XXXX:XXXX``. However, in this case you should 107 | change the authentication information in ``.env`` to something more secure. 108 |
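To record your own runs in this dockerized MongoDB, attach a ``MongoObserver`` whose ``db_name`` matches the ``MONGO_DATABASE`` value from ``.env``. A minimal sketch (URL, credentials and database name are assumptions that depend on your ``.env``):

.. code-block:: python

    from sacred import Experiment
    from sacred.observers import MongoObserver

    ex = Experiment("docker_example")
    ex.observers.append(
        MongoObserver(
            url="mongodb://user:password@localhost:27017/?authSource=admin",
            db_name="sacred",
        )
    )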
-------------------------------------------------------------------------------- /docs/images/incense-artifact.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/incense-artifact.png -------------------------------------------------------------------------------- /docs/images/incense-metric.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/incense-metric.png -------------------------------------------------------------------------------- /docs/images/neptune-collaboration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/neptune-collaboration.png -------------------------------------------------------------------------------- /docs/images/neptune-compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/neptune-compare.png -------------------------------------------------------------------------------- /docs/images/neptune-query-api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/neptune-query-api.png -------------------------------------------------------------------------------- /docs/images/omniboard-metric-graphs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/omniboard-metric-graphs.png -------------------------------------------------------------------------------- /docs/images/omniboard-table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/omniboard-table.png -------------------------------------------------------------------------------- /docs/images/prophet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/prophet.png -------------------------------------------------------------------------------- /docs/images/sacred_browser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/sacred_browser.png -------------------------------------------------------------------------------- /docs/images/sacredboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/sacredboard.png -------------------------------------------------------------------------------- /docs/images/slack_observer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/slack_observer.png -------------------------------------------------------------------------------- /docs/images/sql_schema.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/docs/images/sql_schema.png -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to Sacred's documentation! 2 | ================================== 3 | | *Every experiment is sacred* 4 | | *Every experiment is great* 5 | | *If an experiment is wasted* 6 | | *God gets quite irate* 7 | 8 | Sacred is a tool to configure, organize, log and reproduce computational 9 | experiments. It is designed to introduce only minimal overhead, while 10 | encouraging modularity and configurability of experiments. 11 | 12 | The ability to conveniently make experiments configurable is at the heart of 13 | Sacred. If the parameters of an experiment are exposed in this way, it 14 | will help you to: 15 | 16 | - keep track of all the parameters of your experiment 17 | - easily run your experiment for different settings 18 | - save configurations for individual runs in files or a database 19 | - reproduce your results 20 | 21 | In Sacred we achieve this through the following main mechanisms: 22 | 23 | 1. *Config Scopes* are functions with a ``@ex.config`` decorator that turn 24 | all local variables into configuration entries. This helps to set up your 25 | configuration really easily. 26 | 2. Those entries can then be used in *captured functions* via *dependency 27 | injection*. That way the system takes care of passing parameters around 28 | for you, which makes using your config values really easy. 29 | 3. The *command-line interface* can be used to change the parameters, which 30 | makes it really easy to run your experiment with modified parameters. 31 | 4. Observers log all information about your experiment and the 32 | configuration you used, and save it, for example, to a database.
33 | This helps to keep track of all your experiments. 34 | 5. Automatic seeding helps to control the randomness in your experiments, 35 | such that they stay reproducible. 36 | 37 | 38 | 39 | Contents 40 | ======== 41 | 42 | .. toctree:: 43 | :maxdepth: 1 44 | 45 | quickstart 46 | experiment 47 | configuration 48 | command_line 49 | collected_information 50 | observers 51 | randomness 52 | logging 53 | ingredients 54 | optional 55 | settings 56 | examples 57 | projects_using_sacred 58 | tensorflow 59 | apidoc 60 | internals 61 | 62 | 63 | Index 64 | ===== 65 | 66 | :ref:`genindex` 67 | 68 | -------------------------------------------------------------------------------- /docs/internals.rst: -------------------------------------------------------------------------------- 1 | Internals of Sacred 2 | ******************* 3 | This section is meant as a reference for Sacred developers. 4 | It should give a high-level description of some of the more intricate 5 | internals of Sacred. 6 | 7 | 8 | Configuration Process 9 | ===================== 10 | The configuration process is executed when an experiment is started, and 11 | determines the final configuration that should be used for the run: 12 | 13 | #. Determine the order for running the ingredients 14 | 15 | - topological 16 | - in the order they were added 17 | 18 | #. For each ingredient do: 19 | 20 | - gather all config updates that apply (needs ``config_updates``) 21 | - gather all named configs to use (needs ``named_configs``) 22 | - gather all fallbacks that apply from subrunners (needs ``subrunners.config``) 23 | - make the fallbacks read-only 24 | - run all named configs and use the results as additional config updates, 25 | but with lower priority than the global ones. (needs ``named_configs``, ``config_updates``) 26 | - run all normal configs 27 | - update the global ``config`` 28 | - run the config hook 29 | - update the global ``config_updates`` 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/logging.rst: -------------------------------------------------------------------------------- 1 | Logging 2 | ******* 3 | Sacred uses the Python `logging `_ 4 | module to log some basic information about the execution. It also makes it easy 5 | for you to integrate that logging with your code. 6 | 7 | .. _log_levels: 8 | 9 | Adjusting Log-Levels from command line 10 | ====================================== 11 | If you run the hello_world example you will see the following output:: 12 | 13 | >> python hello_world.py 14 | INFO - hello_world - Running command 'main' 15 | INFO - hello_world - Started 16 | Hello world! 17 | INFO - hello_world - Completed after 0:00:00 18 | 19 | The lines starting with ``INFO`` are logging outputs. They can be suppressed by 20 | adjusting the loglevel. This can be done from the command line with the 21 | ``--loglevel`` (``-l`` for short) option:: 22 | 23 | >> python hello_world.py -l ERROR 24 | Hello world!
25 | 26 | The specified level can be either a string or an integer: 27 | 28 | +----------+---------------+ 29 | | Level | Numeric value | 30 | +==========+===============+ 31 | | CRITICAL | 50 | 32 | +----------+---------------+ 33 | | ERROR | 40 | 34 | +----------+---------------+ 35 | | WARNING | 30 | 36 | +----------+---------------+ 37 | | INFO | 20 | 38 | +----------+---------------+ 39 | | DEBUG | 10 | 40 | +----------+---------------+ 41 | | NOTSET | 0 | 42 | +----------+---------------+ 43 | 44 | Adjusting Log-Levels from python 45 | ================================ 46 | 47 | If you prefer, you can also adjust the logging level from python when 48 | running an experiment by passing the long version of the log level 49 | command line parameter as an option, as follows: 50 | 51 | .. code-block:: python 52 | 53 | ex.run(options={'--loglevel': 'ERROR'}) 54 | 55 | Note that this can only be done when using ``Experiment.run``, not when using 56 | ``Experiment.main`` or ``Experiment.automain``. 57 | 58 | Integrate Logging Into Your Experiment 59 | ====================================== 60 | If you want to make use of the logging mechanism for your own experiments the 61 | easiest way is to use the special ``_log`` argument in your captured functions: 62 | 63 | .. code-block:: python 64 | 65 | @ex.capture 66 | def some_function(_log): 67 | _log.warning('My warning message!') 68 | 69 | This will by default print a line like this:: 70 | 71 | WARNING - some_function - My warning message! 72 | 73 | The ``_log`` is a standard 74 | `Logger object `_ 75 | for your function, as a child logger of the experiments main logger. 76 | So it allows calls to ``debug``, ``info``, ``warning``, ``error``, ``critical`` 77 | and some more. Check out the documentation to see what you can do with them. 78 | 79 | Customize the Logger 80 | ==================== 81 | It is easy to customize the logging behaviour of your experiment by just 82 | providing a custom 83 | `Logger object `_ 84 | to your experiment: 85 | 86 | .. code-block:: python 87 | 88 | import logging 89 | logger = logging.getLogger('my_custom_logger') 90 | ## configure your logger here 91 | ex.logger = logger 92 | 93 | The custom logger will be used to generate all the loggers for all 94 | captured functions. This way you can use all the features of the 95 | `logging `_ package. See the 96 | ``examples/log_example.py`` file for an example of this. 97 | 98 | 99 | -------------------------------------------------------------------------------- /docs/optional.rst: -------------------------------------------------------------------------------- 1 | Optional Features 2 | ***************** 3 | 4 | Sacred offers a set of specialized features which are kept optional in order to 5 | keep the list of requirements small. 6 | This page provides a short description of these optional features. 7 | 8 | Git Integration 9 | =============== 10 | If the experiment sources are maintained in a git repository, then Sacred can 11 | extract information about the current state of the repository. More 12 | specifically it will collect the following information, which is stored by the 13 | observers as part of the experiment info: 14 | 15 | * **url:** The url of the origin repository 16 | * **commit:** The SHA256 hash of the current commit 17 | * **dirty:** A boolean indicating if the repository is dirty, i.e. has 18 | uncommitted changes. 19 | 20 | This can be especially useful together with the :ref:`cmdline_enforce_clean` 21 | (``-e / --enforce_clean``) commandline option. 
If this flag is used, the 22 | experiment immediately fails with an error if started on a dirty repository. 23 | 24 | .. note:: 25 | Git integration can be disabled with ``save_git_info`` flag in the 26 | ``Experiment`` or ``Ingredient`` constructor. 27 | 28 | 29 | Optional Observers 30 | ================== 31 | 32 | MongoDB 33 | ------- 34 | An observer which stores run information in a MongoDB. For more information see 35 | :ref:`mongo_observer`. 36 | 37 | .. note:: 38 | Requires the `pymongo `_ package. 39 | Install with ``pip install pymongo``. 40 | 41 | TinyDB 42 | ------ 43 | An observer which stores run information in a tinyDB. It can be seen as a local 44 | alternative for the MongoDB Observer. For more information see 45 | :ref:`tinydb_observer`. 46 | 47 | .. note:: 48 | Requires the 49 | `tinydb `_, 50 | `tinydb-serialization `_, 51 | and `hashfs `_ packages. 52 | Install with ``pip install tinydb tinydb-serialization hashfs``. 53 | 54 | SQL 55 | --- 56 | An observer that stores run information in a SQL database. For more information 57 | see :ref:`sql_observer` 58 | 59 | .. note:: 60 | Requires the `sqlalchemy `_ package. 61 | Install with ``pip install sqlalchemy``. 62 | 63 | Template Rendering 64 | ------------------ 65 | The :ref:`file_observer` supports automatic report generation using the 66 | `mako `_ package. 67 | 68 | .. note:: 69 | Requires the `mako `_ package. 70 | Install with ``pip install mako``. 71 | 72 | 73 | Numpy and Pandas Integration 74 | ============================ 75 | If ``numpy`` or ``pandas`` are installed Sacred will automatically take care of 76 | a set of type conversions and other details to make working with these packages 77 | as smooth as possible. Normally you won't need to know about any details. But 78 | for some cases it might be useful to know what is happening. So here is a list 79 | of what Sacred will do: 80 | 81 | * automatically set the global numpy random seed (``numpy.random.seed()``). 82 | * if ``numpy`` is installed the :ref:`special value ` ``_rnd`` will be a 83 | ``numpy.random.RandomState`` instead of ``random.Random``. 84 | * because of these two points having numpy installed actually changes the way 85 | randomness is handled. Therefore ``numpy`` is then automatically added to 86 | the dependencies of the experiment, irrespective of its usage in the code. 87 | * ignore typechanges in the configuration from ``numpy`` types to normal 88 | types, such as ``numpy.float32`` to ``float``. 89 | * convert basic numpy types in the configuration to normal types if possible. 90 | This includes converting ``numpy.array`` to ``list``. 91 | * convert ``numpy.array``, ``pandas.Series``, ``pandas.DataFrame`` and 92 | ``pandas.Panel`` to json before storing them in the MongoDB. This includes 93 | instances in the :ref:`info_dict`. 94 | 95 | YAML Format for Configurations 96 | ============================== 97 | If the `PyYAML `_ package is installed Sacred automatically 98 | supports using config files in the yaml format (see :ref:`config_files`). 99 | 100 | .. note:: 101 | Requires the `PyYAML `_ package. 102 | Install with ``pip install PyYAML``. 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /docs/projects_using_sacred.rst: -------------------------------------------------------------------------------- 1 | Projects using Sacred 2 | ********************* 3 | 4 | This is a curated list of projects that make use of Sacred. 
The list can include 5 | code from research projects or show how to integrate Sacred with another library. 6 | 7 | 8 | `Sacred_HyperOpt `_ 9 | ====================================================================== 10 | An example for integrating a general machine learning training script with Sacred 11 | and `HyperOpt `_ (Distributed Asynchronous Hyperparameter Optimization). -------------------------------------------------------------------------------- /docs/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | ********** 3 | 4 | Installation 5 | ============ 6 | You can get Sacred directly from PyPI like this:: 7 | 8 | pip install sacred 9 | 10 | But you can of course also clone the git repo and install it from there:: 11 | 12 | git clone https://github.com/IDSIA/sacred.git 13 | cd sacred 14 | [sudo] python setup.py install 15 | 16 | Hello World 17 | =========== 18 | Let's jump right into it. This is a minimal experiment using Sacred: 19 | 20 | .. code-block:: python 21 | 22 | from sacred import Experiment 23 | 24 | ex = Experiment() 25 | 26 | @ex.automain 27 | def my_main(): 28 | print('Hello world!') 29 | 30 | We did three things here: 31 | - import ``Experiment`` from ``sacred`` 32 | - create an experiment instance ``ex`` 33 | - decorate the function that we want to run with ``@ex.automain`` 34 | 35 | This experiment can be run from the command-line, and this is what we get:: 36 | 37 | > python 01_hello_world.py 38 | INFO - 01_hello_world - Running command 'my_main' 39 | INFO - 01_hello_world - Started 40 | Hello world! 41 | INFO - 01_hello_world - Completed after 0:00:00 42 | 43 | 44 | This experiment already has a full command-line interface that we could use 45 | to control the logging level or to automatically save information about the run 46 | in a database. But all of that is of limited use for an experiment without 47 | configurations. 48 | 49 | Our First Configuration 50 | ======================= 51 | 52 | So let us add some configuration to our program: 53 | 54 | .. code-block:: python 55 | 56 | from sacred import Experiment 57 | 58 | ex = Experiment('hello_config') 59 | 60 | @ex.config 61 | def my_config(): 62 | recipient = "world" 63 | message = "Hello %s!" % recipient 64 | 65 | @ex.automain 66 | def my_main(message): 67 | print(message) 68 | 69 | If we run this the output will look precisely as before, but there is a lot 70 | going on already, so let's look at what we did: 71 | 72 | - add the ``my_config`` function and decorate it with ``@ex.config``. 73 | - within that function define the variable ``message`` 74 | - add the ``message`` parameter to the function ``main`` and use it instead of "Hello world!" 75 | 76 | When we run this experiment, Sacred will run the ``my_config`` function and 77 | put all variables from its local scope into the configuration of our experiment. 78 | All the variables defined there can then be used in the ``main`` function. We can see 79 | this happening by asking the command-line interface to print the configuration 80 | for us:: 81 | 82 | > python hello_config.py print_config 83 | INFO - hello_config - Running command 'print_config' 84 | INFO - hello_config - started 85 | Configuration: 86 | message = 'Hello world!' 87 | recipient = 'world' 88 | seed = 746486301 89 | INFO - hello_config - finished after 0:00:00. 90 | 91 | Notice how Sacred picked up the ``message`` and the ``recipient`` variables.
92 | It also added a ``seed`` to our configuration, but we are going to ignore that 93 | for now. 94 | 95 | Now that our experiment has a configuration we can change it from the 96 | :doc:`command_line`:: 97 | 98 | > python hello_config.py with recipient="that is cool" 99 | INFO - hello_config - Running command 'my_main' 100 | INFO - hello_config - started 101 | Hello that is cool! 102 | INFO - hello_config - finished after 0:00:00. 103 | 104 | Notice how changing the ``recipient`` also changed the message. This should give 105 | you a glimpse of the power of Sacred. But there is a lot more to it, so keep reading :). 106 | -------------------------------------------------------------------------------- /docs/randomness.rst: -------------------------------------------------------------------------------- 1 | Controlling Randomness 2 | ********************** 3 | Many experiments rely on some form of randomness. Controlling this randomness is 4 | key to ensure reproducibility of the results. This typically happens by manually 5 | seeding the *Pseudo Random Number Generator (PRNG)*. Sacred can help you manage 6 | this error-prone procedure. 7 | 8 | Automatic Seed 9 | ============== 10 | Sacred auto-generates a seed for each run as part of the configuration (you 11 | might have noticed it when printing the configuration of an experiment). 12 | This seed has a different value every time the experiment is run and is stored 13 | as part of the configuration. You can easily set it by:: 14 | 15 | >>./experiment.py with seed=123 16 | 17 | This root-seed is the central place to control randomness, because internally 18 | all other seeds and PRNGs depend on it in a deterministic way. 19 | 20 | Global Seeds 21 | ============ 22 | Upon starting the experiment, Sacred automatically sets the global seed of 23 | ``random`` and (if installed) ``numpy.random`` (which, as of v1.19, is marked as legacy), 24 | ``tensorflow.set_random_seed``, ``pytorch.manual_seed`` to the auto-generated 25 | root-seed of the experiment. This means that even if you don't take any further 26 | steps, at least the randomness stemming from those libraries is properly seeded. 27 | 28 | If you rely on any other library that you want to seed globally you should do 29 | so manually first thing inside your main function. For this you can take either 30 | the argument ``seed`` (the root-seed) or ``_seed`` (a seed generated for this 31 | call of the main function); in this case it doesn't really matter which one you use. 32 | 33 | Special Arguments 34 | ================= 35 | To generate random numbers that are controlled by the root-seed Sacred provides 36 | two special arguments: ``_rnd`` and ``_seed``. 37 | You can just accept them as parameters in any captured function: 38 | 39 | .. code-block:: python 40 | 41 | @ex.capture 42 | def do_random_stuff(_rnd, _seed): 43 | print(_seed) 44 | print(_rnd.randint(1, 100)) 45 | 46 | ``_seed`` is an integer that is different every time the function is called. 47 | Likewise ``_rnd`` is a PRNG that you can directly use to generate random numbers. 48 | 49 | .. note:: 50 | If ``numpy`` is installed, ``_rnd`` will be either a `numpy.random.Generator `_ object or a `numpy.random.RandomState `_ object. The default behavior depends on the numpy version, i.e. as of version `v1.19` `numpy.random.RandomState` is marked as legacy. To use the legacy numpy random API regardless of the numpy version, set `NUMPY_RANDOM_LEGACY_API` to `True`. 51 | Otherwise it will be a `random.Random `_ object.
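If another library needs its own seed inside a captured function, ``_seed`` can simply be passed on. A small sketch (the experiment instance ``ex``, the ``n_estimators`` config entry and the use of scikit-learn are only illustrative):

.. code-block:: python

    from sklearn.ensemble import RandomForestClassifier

    @ex.capture
    def build_model(n_estimators, _seed):
        # _seed is derived from the root-seed, so the model stays reproducible
        return RandomForestClassifier(n_estimators=n_estimators, random_state=_seed)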
52 | 53 | All ``_seed`` and ``_rnd`` instances depend deterministically on the root-seed 54 | so they can be controlled centrally. 55 | 56 | Resilience to Change 57 | ==================== 58 | The way Sacred generates these seeds and PRNGs actually offers some amount of 59 | resilience to changes in your experiment or your program flow. So suppose for 60 | example you have an experiment that has two methods that use randomness: 61 | ``A`` and ``B``. You want to run and compare two variants of that experiment: 62 | 63 | 1. Only call ``B``. 64 | 2. First call ``A`` and then ``B``. 65 | 66 | If you use just a single global PRNG, that would mean that for a fixed seed the 67 | call to ``B`` gives different results for the two variants, because the call to 68 | ``A`` changed the state of the global PRNG. 69 | 70 | Sacred generates these seeds and PRNGs in a hierarchical way. That makes the 71 | calls to ``A`` and ``B`` independent from one another. So ``B`` would give the 72 | same results in both cases. 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /docs/settings.rst: -------------------------------------------------------------------------------- 1 | .. _settings: 2 | 3 | Settings 4 | ******** 5 | 6 | Some of Sacred's general behaviour is configurable via ``sacred.SETTINGS``. 7 | Its entries can be set simply by importing and modifying it using dict or attribute notation: 8 | 9 | .. code-block:: python 10 | 11 | from sacred import SETTINGS 12 | SETTINGS['HOST_INFO']['INCLUDE_GPU_INFO'] = False 13 | SETTINGS.HOST_INFO.INCLUDE_GPU_INFO = False # equivalent 14 | 15 | Settings 16 | ======== 17 | Here is a brief list of all currently available options. 18 | 19 | 20 | * ``CAPTURE_MODE`` *(default: 'fd' (linux/osx) or 'sys' (windows))* 21 | Configure how stdout/stderr are captured. ['no', 'sys', 'fd'] 22 | * ``DEFAULT_BEAT_INTERVAL`` *(default: 10.0)* Configures the default beat interval 23 | * ``CONFIG`` 24 | 25 | * ``ENFORCE_KEYS_MONGO_COMPATIBLE`` *(default: True)* 26 | Make sure all config keys are compatible with MongoDB. 27 | * ``ENFORCE_KEYS_JSONPICKLE_COMPATIBLE`` *(default: True)* 28 | Make sure all config keys are serializable with jsonpickle. 29 | IMPORTANT: Only deactivate if you know what you're doing. 30 | * ``ENFORCE_VALID_PYTHON_IDENTIFIER_KEYS`` *(default: False)* 31 | Make sure all config keys are valid python identifiers. 32 | * ``ENFORCE_STRING_KEYS`` *(default: False)* 33 | Make sure all config keys are strings. 34 | * ``ENFORCE_KEYS_NO_EQUALS`` *(default: True)* 35 | Make sure no config key contains an equals sign. 36 | * ``IGNORED_COMMENTS`` *(default: ['^pylint:', '^noinspection'])* 37 | List of regex patterns to filter out certain IDE or linter directives 38 | from in-line comments in the documentation. 39 | * ``READ_ONLY_CONFIG`` *(default: True)* 40 | Make the configuration read-only inside of captured functions. This 41 | only works to a limited extent because custom types cannot be 42 | controlled. 43 | 44 | * ``HOST_INFO`` 45 | 46 | * ``INCLUDE_GPU_INFO`` *(default: True)* 47 | Try to collect information about GPUs using the nvidia-smi tool. 48 | Deactivating this can cut the start-up time of a Sacred run by about 1 sec. 49 | * ``INCLUDE_CPU_INFO`` *(default: True)* 50 | Try to collect information about the CPU using py-cpuinfo. 51 | Deactivating this can cut the start-up time of a Sacred run by about 3 sec. 52 | * ``CAPTURED_ENV`` *(default: [])* 53 | List of ENVIRONMENT variable names to store in the host-info.
54 | 55 | 56 | * ``COMMAND_LINE`` 57 | 58 | * ``STRICT_PARSING`` *(default: False)* 59 | Disallow string fallback if parsing a value from command-line failed. 60 | This enforces the usage of quotes in the command-line. Note that this can 61 | be very tedious since bash removes one set of quotes, such that double 62 | quotes will be needed. 63 | -------------------------------------------------------------------------------- /docs/tensorflow.rst: -------------------------------------------------------------------------------- 1 | Integration with Tensorflow 2 | *************************** 3 | 4 | Sacred provides ways to interact with the Tensorflow_ library. 5 | The goal is to provide an API that would allow tracking certain 6 | information about how Tensorflow is being used with Sacred. 7 | The collected data are stored in ``experiment.info["tensorflow"]`` 8 | where they can be accessed by various :doc:`observers `. 9 | 10 | Storing Tensorflow Logs (FileWriter) 11 | ------------------------------------ 12 | 13 | To store the location of summaries produced by Tensorflow 14 | (created by ``tensorflow.summary.FileWriter``) into the experiment record 15 | specified by the ``ex`` argument, use the ``sacred.stflow.LogFileWriter(ex)`` 16 | decorator or context manager. 17 | Whenever a new ``FileWriter`` instantiation is detected in a scope of the 18 | decorator or the context manager, the path of the log is 19 | copied to the experiment record exactly as passed to the FileWriter. 20 | 21 | The location(s) can be then found under ``info["tensorflow"]["logdirs"]`` 22 | of the experiment. 23 | 24 | **Important**: The experiment must be in the *RUNNING* state before calling 25 | the decorated method or entering the context. 26 | 27 | 28 | Example Usage As a Decorator 29 | ............................ 30 | 31 | ``LogFileWriter(ex)`` as a decorator can be used either on a function or 32 | on a class method. 33 | 34 | .. code-block:: python 35 | 36 | from sacred.stflow import LogFileWriter 37 | from sacred import Experiment 38 | import tensorflow as tf 39 | 40 | ex = Experiment("my experiment") 41 | 42 | @ex.automain 43 | @LogFileWriter(ex) 44 | def run_experiment(_run): 45 | with tf.Session() as s: 46 | swr = tf.summary.FileWriter("/tmp/1", s.graph) 47 | # _run.info["tensorflow"]["logdirs"] == ["/tmp/1"] 48 | swr2 = tf.summary.FileWriter("./test", s.graph) 49 | #_run.info["tensorflow"]["logdirs"] == ["/tmp/1", "./test"] 50 | 51 | 52 | 53 | Example Usage As a Context Manager 54 | .................................. 55 | 56 | There is a context manager available to catch the paths 57 | in a smaller portion of code. 58 | 59 | .. code-block:: python 60 | 61 | ex = Experiment("my experiment") 62 | def run_experiment(_run): 63 | with tf.Session() as s: 64 | with LogFileWriter(ex): 65 | swr = tf.summary.FileWriter("/tmp/1", s.graph) 66 | # _run.info["tensorflow"]["logdirs"] == ["/tmp/1"] 67 | swr3 = tf.summary.FileWriter("./test", s.graph) 68 | # _run.info["tensorflow"]["logdirs"] == ["/tmp/1", "./test"] 69 | # This is called outside the scope and won't be captured 70 | swr3 = tf.summary.FileWriter("./nothing", s.graph) 71 | # Nothing has changed: 72 | # _run.info["tensorflow"]["logdirs"] == ["/tmp/1", "./test"] 73 | .. 
_Tensorflow: http://www.tensorflow.org/ -------------------------------------------------------------------------------- /examples/01_hello_world.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | This is a minimal example of a Sacred experiment. 5 | 6 | Not much to see here. But it comes with a command-line interface and can be 7 | called like this:: 8 | 9 | $ ./01_hello_world.py 10 | WARNING - 01_hello_world - No observers have been added to this run 11 | INFO - 01_hello_world - Running command 'main' 12 | INFO - 01_hello_world - Started 13 | Hello world! 14 | INFO - 01_hello_world - Completed after 0:00:00 15 | 16 | As you can see it prints 'Hello world!' as expected, but there is also some 17 | additional logging. The log-level can be controlled using the ``-l`` argument:: 18 | 19 | $ ./01_hello_world.py -l WARNING 20 | WARNING - 01_hello_world - No observers have been added to this run 21 | Hello world! 22 | 23 | If you want to learn more about the command-line interface try 24 | ``help`` or ``-h``. 25 | """ 26 | 27 | from sacred import Experiment 28 | 29 | # Create an Experiment instance 30 | ex = Experiment() 31 | 32 | 33 | # This function should be executed so we are decorating it with @ex.automain 34 | @ex.automain 35 | def main(): 36 | print("Hello world!") 37 | -------------------------------------------------------------------------------- /examples/02_hello_config_dict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ A configurable Hello World "experiment". 4 | In this example we configure the message using a dictionary with 5 | ``ex.add_config`` 6 | 7 | You can run it like this:: 8 | 9 | $ ./02_hello_config_dict.py 10 | WARNING - 02_hello_config_dict - No observers have been added to this run 11 | INFO - 02_hello_config_dict - Running command 'main' 12 | INFO - 02_hello_config_dict - Started 13 | Hello world! 14 | INFO - 02_hello_config_dict - Completed after 0:00:00 15 | 16 | The message can also easily be changed using the ``with`` command-line 17 | argument:: 18 | 19 | $ ./02_hello_config_dict.py with message='Ciao world!' 20 | WARNING - 02_hello_config_dict - No observers have been added to this run 21 | INFO - 02_hello_config_dict - Running command 'main' 22 | INFO - 02_hello_config_dict - Started 23 | Ciao world! 24 | INFO - 02_hello_config_dict - Completed after 0:00:00 25 | """ 26 | 27 | from sacred import Experiment 28 | 29 | ex = Experiment() 30 | 31 | # We add message to the configuration of the experiment here 32 | ex.add_config({"message": "Hello world!"}) 33 | # Equivalent: 34 | # ex.add_config( 35 | # message="Hello world!" 36 | # ) 37 | 38 | 39 | # notice how we can access the message here by taking it as an argument 40 | @ex.automain 41 | def main(message): 42 | print(message) 43 | -------------------------------------------------------------------------------- /examples/03_hello_config_scope.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | A configurable Hello World "experiment". 5 | In this example we configure the message using Sacreds special ``ConfigScope``. 
6 | 7 | As with hello_config_dict you can run it like this:: 8 | 9 | $ ./03_hello_config_scope.py 10 | WARNING - hello_cs - No observers have been added to this run 11 | INFO - hello_cs - Running command 'main' 12 | INFO - hello_cs - Started 13 | Hello world! 14 | INFO - hello_cs - Completed after 0:00:00 15 | 16 | The message can also easily be changed using the ``with`` command-line 17 | argument:: 18 | 19 | $ ./03_hello_config_scope.py with message='Ciao world!' 20 | WARNING - hello_cs - No observers have been added to this run 21 | INFO - hello_cs - Running command 'main' 22 | INFO - hello_cs - Started 23 | Ciao world! 24 | INFO - hello_cs - Completed after 0:00:00 25 | 26 | 27 | But because we are using a ``ConfigScope`` that constructs the message from a 28 | recipient we can also just modify that:: 29 | 30 | $ ./03_hello_config_scope.py with recipient='Bob' 31 | WARNING - hello_cs - No observers have been added to this run 32 | INFO - hello_cs - Running command 'main' 33 | INFO - hello_cs - Started 34 | Hello Bob! 35 | INFO - hello_cs - Completed after 0:00:00 36 | """ 37 | 38 | from sacred import Experiment 39 | 40 | ex = Experiment("hello_cs") # here we name the experiment explicitly 41 | 42 | 43 | # A ConfigScope is a function like this decorated with @ex.config 44 | # All local variables of this function will be put into the configuration 45 | @ex.config 46 | def cfg(_log): 47 | # The recipient of the greeting 48 | recipient = "world" 49 | 50 | # The message used for greeting 51 | message = "Hello {}!".format(recipient) 52 | 53 | 54 | # again we can access the message here by taking it as an argument 55 | @ex.automain 56 | def main(message): 57 | print(message) 58 | -------------------------------------------------------------------------------- /examples/04_captured_functions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | In this example the use of captured functions is demonstrated. Like the 5 | main function, they have access to the configuration parameters by just 6 | accepting them as arguments. 7 | 8 | When calling a captured function we do not need to specify the parameters that 9 | we want to be taken from the configuration. They will automatically be filled 10 | by Sacred. But we can always override that by passing them in explicitly. 11 | 12 | When run, this example will output the following:: 13 | 14 | $ ./04_captured_functions.py -l WARNING 15 | WARNING - captured_functions - No observers have been added to this run 16 | This is printed by function foo. 17 | This is printed by function bar. 18 | Overriding the default message for foo. 19 | 20 | """ 21 | 22 | from sacred import Experiment 23 | 24 | ex = Experiment("captured_functions") 25 | 26 | 27 | @ex.config 28 | def cfg(): 29 | message = "This is printed by function {}." 30 | 31 | 32 | # Captured functions have access to all the configuration parameters 33 | @ex.capture 34 | def foo(message): 35 | print(message.format("foo")) 36 | 37 | 38 | @ex.capture 39 | def bar(message): 40 | print(message.format("bar")) 41 | 42 | 43 | @ex.automain 44 | def main(): 45 | foo() # Notice that we do not pass message here 46 | bar() # or here 47 | # But we can if we feel like it... 
48 | foo("Overriding the default message for {}.") 49 | -------------------------------------------------------------------------------- /examples/05_my_commands.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | This experiment showcases the concept of commands in Sacred. 5 | By just using the ``@ex.command`` decorator we can add additional commands to 6 | the command-line interface of the experiment:: 7 | 8 | $ ./05_my_commands.py greet 9 | WARNING - my_commands - No observers have been added to this run 10 | INFO - my_commands - Running command 'greet' 11 | INFO - my_commands - Started 12 | Hello John! Nice to greet you! 13 | INFO - my_commands - Completed after 0:00:00 14 | 15 | :: 16 | 17 | $ ./05_my_commands.py shout 18 | WARNING - my_commands - No observers have been added to this run 19 | INFO - my_commands - Running command 'shout' 20 | INFO - my_commands - Started 21 | WHAZZZUUUUUUUUUUP!!!???? 22 | INFO - my_commands - Completed after 0:00:00 23 | 24 | Of course we can also use ``with`` and other flags with those commands:: 25 | 26 | $ ./05_my_commands.py greet with name='Jane' -l WARNING 27 | WARNING - my_commands - No observers have been added to this run 28 | Hello Jane! Nice to greet you! 29 | 30 | In fact, the main function is also just a command:: 31 | 32 | $ ./05_my_commands.py main 33 | WARNING - my_commands - No observers have been added to this run 34 | INFO - my_commands - Running command 'main' 35 | INFO - my_commands - Started 36 | This is just the main command. Try greet or shout. 37 | INFO - my_commands - Completed after 0:00:00 38 | 39 | Commands also appear in the help text, and you can get additional information 40 | about all commands using ``./05_my_commands.py help [command]``. 41 | """ 42 | 43 | from sacred import Experiment 44 | 45 | ex = Experiment("my_commands") 46 | 47 | 48 | @ex.config 49 | def cfg(): 50 | name = "John" 51 | 52 | 53 | @ex.command 54 | def greet(name): 55 | """ 56 | Print a nice greet message. 57 | 58 | Uses the name from config. 59 | """ 60 | print("Hello {}! Nice to greet you!".format(name)) 61 | 62 | 63 | @ex.command 64 | def shout(): 65 | """ 66 | Shout slang question for "what is up?" 67 | """ 68 | print("WHAZZZUUUUUUUUUUP!!!????") 69 | 70 | 71 | @ex.automain 72 | def main(): 73 | print("This is just the main command. Try greet or shout.") 74 | -------------------------------------------------------------------------------- /examples/06_randomness.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | This example showcases the randomness features of Sacred. 5 | 6 | Sacred generates a random global seed for every experiment, that you can 7 | find in the configuration. It will be different every time you run the 8 | experiment. 9 | 10 | Based on this global seed it will generate the special parameters ``_seed`` and 11 | ``_rnd`` for each captured function. Every time you call such a function the 12 | ``_seed`` will be different and ``_rnd`` will be differently seeded random 13 | state. But their values depend deterministically on the global seed and on how 14 | often the function has been called. 15 | 16 | Here are a couple of things you should try: 17 | 18 | - run the experiment a couple of times and notice how the results are 19 | different every time 20 | 21 | - run the experiment a couple of times with a fixed seed. 
22 | Notice that the results are the same:: 23 | 24 | :$ ./06_randomness.py with seed=12345 -l WARNING 25 | [57] 26 | [28] 27 | 695891797 28 | [82] 29 | 30 | - run the experiment with a fixed seed and vary the numbers parameter. 31 | Notice that all the results stay the same except for the added numbers. 32 | This demonstrates that all the calls to one function are in fact 33 | independent from each other:: 34 | 35 | :$ ./06_randomness.py with seed=12345 numbers=3 -l WARNING 36 | [57, 79, 86] 37 | [28, 90, 92] 38 | 695891797 39 | [82, 9, 3] 40 | 41 | - run the experiment with a fixed seed and set the reverse parameter to true. 42 | Notice how the results are the same, but in slightly different order. 43 | This shows that calls to different functions do not interfere with one 44 | another:: 45 | 46 | :$ ./06_randomness.py with seed=12345 reverse=True numbers=3 -l WARNING 47 | 695891797 48 | [57, 79, 86] 49 | [28, 90, 92] 50 | [82, 9, 3] 51 | 52 | """ 53 | 54 | from sacred import Experiment 55 | 56 | ex = Experiment("randomness") 57 | 58 | 59 | @ex.config 60 | def cfg(): 61 | reverse = False 62 | numbers = 1 63 | 64 | 65 | @ex.capture 66 | def do_random_stuff(numbers, _rnd): 67 | print([_rnd.randint(1, 100) for _ in range(numbers)]) 68 | 69 | 70 | @ex.capture 71 | def do_more_random_stuff(_seed): 72 | print(_seed) 73 | 74 | 75 | @ex.automain 76 | def run(reverse): 77 | if reverse: 78 | do_more_random_stuff() 79 | do_random_stuff() 80 | do_random_stuff() 81 | else: 82 | do_random_stuff() 83 | do_random_stuff() 84 | do_more_random_stuff() 85 | 86 | do_random_stuff() 87 | -------------------------------------------------------------------------------- /examples/07_magic.py: -------------------------------------------------------------------------------- 1 | """A standard machine learning task using sacred's magic.""" 2 | from sacred import Experiment 3 | from sacred.observers import FileStorageObserver 4 | from sklearn import svm, datasets, model_selection 5 | 6 | ex = Experiment("svm") 7 | 8 | ex.observers.append(FileStorageObserver("my_runs")) 9 | 10 | 11 | @ex.config # Configuration is defined through local variables. 12 | def cfg(): 13 | C = 1.0 14 | gamma = 0.7 15 | kernel = "rbf" 16 | seed = 42 17 | 18 | 19 | @ex.capture 20 | def get_model(C, gamma, kernel): 21 | return svm.SVC(C=C, kernel=kernel, gamma=gamma) 22 | 23 | 24 | @ex.automain # Using automain to enable command line integration. 25 | def run(): 26 | X, y = datasets.load_breast_cancer(return_X_y=True) 27 | X_train, X_test, y_train, y_test = model_selection.train_test_split( 28 | X, y, test_size=0.2 29 | ) 30 | clf = get_model() # Parameters are injected automatically. 31 | clf.fit(X_train, y_train) 32 | return clf.score(X_test, y_test) 33 | -------------------------------------------------------------------------------- /examples/08_less_magic.py: -------------------------------------------------------------------------------- 1 | """A standard machine learning task without much sacred magic.""" 2 | from sacred import Experiment 3 | from sacred.observers import FileStorageObserver 4 | from sklearn import svm, datasets, model_selection 5 | 6 | ex = Experiment("svm") 7 | 8 | ex.observers.append(FileStorageObserver("my_runs")) 9 | ex.add_config( 10 | { # Configuration is explicitly defined as dictionary. 
11 | "C": 1.0, 12 | "gamma": 0.7, 13 | "kernel": "rbf", 14 | "seed": 42, 15 | } 16 | ) 17 | 18 | 19 | def get_model(C, gamma, kernel): 20 | return svm.SVC(C=C, kernel=kernel, gamma=gamma) 21 | 22 | 23 | @ex.main # Using main, command-line arguments will not be interpreted in any special way. 24 | def run(_config): 25 | X, y = datasets.load_breast_cancer(return_X_y=True) 26 | X_train, X_test, y_train, y_test = model_selection.train_test_split( 27 | X, y, test_size=0.2 28 | ) 29 | clf = get_model( 30 | _config["C"], _config["gamma"], _config["kernel"] 31 | ) # Parameters are passed explicitly. 32 | clf.fit(X_train, y_train) 33 | return clf.score(X_test, y_test) 34 | 35 | 36 | if __name__ == "__main__": 37 | ex.run() 38 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | -------------------------------------------------------------------------------- /examples/captured_out_filter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | This example shows how to apply a filter function to the captured output 5 | of a run. This is often useful when using progress bars or similar in the text 6 | UI and you don't want to store formatting characters like backspaces and 7 | linefeeds in the database. 8 | """ 9 | 10 | import sys 11 | import time 12 | 13 | from sacred import Experiment 14 | from sacred.utils import apply_backspaces_and_linefeeds 15 | 16 | ex = Experiment("progress") 17 | 18 | # try commenting out the line below to see the difference in captured output 19 | ex.captured_out_filter = apply_backspaces_and_linefeeds 20 | 21 | 22 | def write_and_flush(*args): 23 | for arg in args: 24 | sys.stdout.write(arg) 25 | sys.stdout.flush() 26 | 27 | 28 | class ProgressMonitor: 29 | def __init__(self, count): 30 | self.count, self.progress = count, 0 31 | 32 | def show(self, n=1): 33 | self.progress += n 34 | text = "Completed {}/{} tasks".format(self.progress, self.count) 35 | write_and_flush("\b" * 80, "\r", text) 36 | 37 | def done(self): 38 | write_and_flush("\n") 39 | 40 | 41 | def progress(items): 42 | p = ProgressMonitor(len(items)) 43 | for item in items: 44 | yield item 45 | p.show() 46 | p.done() 47 | 48 | 49 | @ex.main 50 | def main(): 51 | for item in progress(range(100)): 52 | time.sleep(0.05) 53 | 54 | 55 | if __name__ == "__main__": 56 | run = ex.run_commandline() 57 | print("=" * 80) 58 | print("Captured output: ", repr(run.captured_out)) 59 | -------------------------------------------------------------------------------- /examples/docker/.env: -------------------------------------------------------------------------------- 1 | MONGO_INITDB_ROOT_USERNAME=mongo_user 2 | MONGO_INITDB_ROOT_PASSWORD=mongo_password 3 | ME_CONFIG_BASICAUTH_USERNAME=mongo_express_user 4 | ME_CONFIG_BASICAUTH_PASSWORD=mongo_express_pw 5 | MONGO_DATABASE=sacred -------------------------------------------------------------------------------- /examples/docker/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.6' 2 | 3 | services: 4 | 5 | mongo: 6 | image: mongo 7 | ports: 8 | - 127.0.0.1:27017:27017 9 | restart: unless-stopped 10 | env_file: .env 11 | volumes: 12 | - mongodb_data:/data/db 13 | - mongodb_config:/data/configdb 14 | 15 | mongo-express: 16 | image: mongo-express 17 | 
ports: 18 | - 127.0.0.1:8081:8081 19 | restart: unless-stopped 20 | env_file: .env 21 | environment: 22 | ME_CONFIG_MONGODB_ADMINUSERNAME: $MONGO_INITDB_ROOT_USERNAME 23 | ME_CONFIG_MONGODB_ADMINPASSWORD: $MONGO_INITDB_ROOT_PASSWORD 24 | ME_CONFIG_MONGODB_SERVER: mongo 25 | depends_on: 26 | - mongo 27 | 28 | omniboard: 29 | image: vivekratnavel/omniboard:latest 30 | command: ["--mu", "mongodb://$MONGO_INITDB_ROOT_USERNAME:$MONGO_INITDB_ROOT_PASSWORD@mongo:27017/$MONGO_DATABASE?authSource=admin"] 31 | ports: 32 | - 127.0.0.1:9000:9000 33 | restart: unless-stopped 34 | env_file: .env 35 | depends_on: 36 | - mongo 37 | 38 | sacredboard: 39 | build: ./sacredboard 40 | ports: 41 | - 127.0.0.1:5000:5000 42 | restart: unless-stopped 43 | env_file: .env 44 | depends_on: 45 | - mongo 46 | 47 | volumes: 48 | mongodb_data: 49 | mongodb_config: 50 | -------------------------------------------------------------------------------- /examples/docker/sacredboard/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6-jessie 2 | 3 | RUN apt update \ 4 | && pip install https://github.com/chovanecm/sacredboard/archive/develop.zip \ 5 | && rm -rf /var/lib/apt/lists/* 6 | 7 | ENTRYPOINT sacredboard -mu mongodb://$MONGO_INITDB_ROOT_USERNAME:$MONGO_INITDB_ROOT_PASSWORD@mongo:27017/?authMechanism=SCRAM-SHA-1 $MONGO_DATABASE 8 | -------------------------------------------------------------------------------- /examples/ingredient.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | from sacred import Ingredient, Experiment 5 | 6 | # ================== Dataset Ingredient ======================================= 7 | # could be in a separate file 8 | 9 | data_ingredient = Ingredient("dataset") 10 | 11 | 12 | @data_ingredient.config 13 | def cfg1(): 14 | filename = "my_dataset.npy" # dataset filename 15 | normalize = True # normalize dataset 16 | 17 | 18 | @data_ingredient.capture 19 | def load_data(filename, normalize): 20 | print("loading dataset from '{}'".format(filename)) 21 | if normalize: 22 | print("normalizing dataset") 23 | return 1 24 | return 42 25 | 26 | 27 | @data_ingredient.command 28 | def stats(filename, foo=12): 29 | print('Statistics for dataset "{}":'.format(filename)) 30 | print("mean = 42.23") 31 | print("foo=", foo) 32 | 33 | 34 | # ================== Experiment =============================================== 35 | 36 | 37 | @data_ingredient.config 38 | def cfg2(): 39 | filename = "foo.npy" 40 | 41 | 42 | # add the Ingredient while creating the experiment 43 | ex = Experiment("my_experiment", ingredients=[data_ingredient]) 44 | 45 | 46 | @ex.config 47 | def cfg3(): 48 | a = 12 49 | b = 42 50 | 51 | 52 | @ex.named_config 53 | def fbb(): 54 | a = 22 55 | dataset = {"filename": "AwwwJiss.py"} 56 | 57 | 58 | @ex.automain 59 | def run(): 60 | data = load_data() # just use the function 61 | print("data={}".format(data)) 62 | -------------------------------------------------------------------------------- /examples/log_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ An example showcasing the logging system of Sacred.""" 4 | 5 | import logging 6 | from sacred import Experiment 7 | 8 | ex = Experiment("log_example") 9 | 10 | # set up a custom logger 11 | logger = logging.getLogger("mylogger") 12 | logger.handlers = [] 13 | ch = logging.StreamHandler() 14 | formatter = 
logging.Formatter('[%(levelname).1s] %(name)s >> "%(message)s"') 15 | ch.setFormatter(formatter) 16 | logger.addHandler(ch) 17 | logger.setLevel("INFO") 18 | 19 | # attach it to the experiment 20 | ex.logger = logger 21 | 22 | 23 | @ex.config 24 | def cfg(): 25 | number = 2 26 | got_gizmo = False 27 | 28 | 29 | @ex.capture 30 | def transmogrify(got_gizmo, number, _log): 31 | if got_gizmo: 32 | _log.debug("Got gizmo. Performing transmogrification...") 33 | return number * 42 34 | else: 35 | _log.warning("No gizmo. Can't transmogrify!") 36 | return 0 37 | 38 | 39 | @ex.automain 40 | def main(number, _log): 41 | _log.info("Attempting to transmogrify %d...", number) 42 | result = transmogrify() 43 | _log.info("Transmogrification complete: %d", result) 44 | return result 45 | -------------------------------------------------------------------------------- /examples/modular.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | This is a very basic example of how to use Sacred. 5 | """ 6 | 7 | from sacred import Experiment, Ingredient 8 | 9 | # ============== Ingredient 0: settings ================= 10 | s = Ingredient("settings") 11 | 12 | 13 | @s.config 14 | def cfg1(): 15 | verbose = True 16 | 17 | 18 | # ============== Ingredient 1: dataset.paths ================= 19 | data_paths = Ingredient("dataset.paths", ingredients=[s]) 20 | 21 | 22 | @data_paths.config 23 | def cfg2(settings): 24 | v = not settings["verbose"] 25 | base = "/home/sacred/" 26 | 27 | 28 | # ============== Ingredient 2: dataset ======================= 29 | data = Ingredient("dataset", ingredients=[data_paths, s]) 30 | 31 | 32 | @data.config 33 | def cfg3(paths): 34 | basepath = paths["base"] + "datasets/" 35 | filename = "foo.hdf5" 36 | 37 | 38 | @data.capture 39 | def foo(basepath, filename, paths, settings): 40 | print(paths) 41 | print(settings) 42 | return basepath + filename 43 | 44 | 45 | # ============== Experiment ============================== 46 | ex = Experiment("modular_example", ingredients=[data, data_paths]) 47 | 48 | 49 | @ex.config 50 | def cfg(dataset): 51 | a = 10 52 | b = 17 53 | c = a + b 54 | out_base = dataset["paths"]["base"] + "outputs/" 55 | out_filename = dataset["filename"].replace(".hdf5", ".out") 56 | 57 | 58 | @ex.automain 59 | def main(a, b, c, out_base, out_filename, dataset): 60 | print("a =", a) 61 | print("b =", b) 62 | print("c =", c) 63 | print("out_base =", out_base, out_filename) 64 | # print("dataset", dataset) 65 | # print("dataset.paths", dataset['paths']) 66 | print("foo()", foo()) 67 | -------------------------------------------------------------------------------- /examples/named_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ A very configurable Hello World. Yay! 
""" 4 | 5 | from sacred import Experiment 6 | 7 | ex = Experiment("hello_config") 8 | 9 | 10 | @ex.named_config 11 | def rude(): 12 | """A rude named config""" 13 | recipient = "bastard" 14 | message = "Fuck off you {}!".format(recipient) 15 | 16 | 17 | @ex.config 18 | def cfg(): 19 | recipient = "world" 20 | message = "Hello {}!".format(recipient) 21 | 22 | 23 | @ex.automain 24 | def main(message): 25 | print(__name__) 26 | print(message) 27 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | target-version = ['py38', 'py39', 'py310', 'py311'] 3 | include = '\.pyi?$' 4 | exclude = ''' 5 | ( 6 | /( 7 | \.eggs # exclude a few common directories in the 8 | | \.git # root of the project 9 | | \.hg 10 | | \.mypy_cache 11 | | \.tox 12 | | \.venv 13 | | _build 14 | | buck-out 15 | | build 16 | | dist 17 | )/ 18 | ) 19 | ''' 20 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | docopt-ng>=0.9, <1.0 2 | jsonpickle>=2.2.0 3 | munch>=2.5, <5.0 4 | wrapt>=1.0, <2.0 5 | py-cpuinfo>=4.0 6 | colorama>=0.4 7 | packaging>=18.0 8 | GitPython 9 | -------------------------------------------------------------------------------- /sacred/__about__.py: -------------------------------------------------------------------------------- 1 | """Contains meta-information about the Sacred package. 2 | 3 | It is kept simple and separate from the main module, because this information 4 | is also read by the setup.py. And during installation the sacred module cannot 5 | yet be imported. 6 | """ 7 | 8 | __all__ = ("__version__", "__author__", "__author_email__", "__url__") 9 | 10 | __version__ = "0.8.7" 11 | 12 | __author__ = "Klaus Greff" 13 | __author_email__ = "klaus.greff@startmail.com" 14 | 15 | __url__ = "https://github.com/IDSIA/sacred" 16 | -------------------------------------------------------------------------------- /sacred/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | The main module of sacred. 5 | 6 | It provides access to the two main classes Experiment and Ingredient. 
7 | """ 8 | 9 | from sacred.__about__ import __version__, __author__, __author_email__, __url__ 10 | from sacred.settings import SETTINGS 11 | from sacred.experiment import Experiment 12 | from sacred.ingredient import Ingredient 13 | from sacred import observers 14 | from sacred.host_info import host_info_getter, host_info_gatherer 15 | from sacred.commandline_options import cli_option 16 | 17 | 18 | __all__ = ( 19 | "Experiment", 20 | "Ingredient", 21 | "observers", 22 | "host_info_getter", 23 | "__version__", 24 | "__author__", 25 | "__author_email__", 26 | "__url__", 27 | "SETTINGS", 28 | "host_info_gatherer", 29 | "cli_option", 30 | ) 31 | -------------------------------------------------------------------------------- /sacred/config/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | from sacred.config.config_dict import ConfigDict 5 | from sacred.config.config_scope import ConfigScope 6 | from sacred.config.config_files import load_config_file, save_config_file 7 | from sacred.config.captured_function import create_captured_function 8 | from sacred.config.utils import chain_evaluate_config_scopes, dogmatize, undogmatize 9 | 10 | __all__ = ( 11 | "ConfigDict", 12 | "ConfigScope", 13 | "load_config_file", 14 | "save_config_file", 15 | "create_captured_function", 16 | "chain_evaluate_config_scopes", 17 | "dogmatize", 18 | "undogmatize", 19 | ) 20 | -------------------------------------------------------------------------------- /sacred/config/captured_function.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import time 5 | from datetime import timedelta 6 | 7 | import wrapt 8 | from sacred.config.custom_containers import fallback_dict 9 | from sacred.config.signature import Signature 10 | from sacred.randomness import create_rnd, get_seed 11 | from sacred.utils import ConfigError 12 | 13 | 14 | def create_captured_function(function, prefix=None): 15 | sig = Signature(function) 16 | function.signature = sig 17 | function.uses_randomness = "_seed" in sig.arguments or "_rnd" in sig.arguments 18 | function.logger = None 19 | function.config = {} 20 | function.rnd = None 21 | function.run = None 22 | function.prefix = prefix 23 | return captured_function(function) 24 | 25 | 26 | @wrapt.decorator 27 | def captured_function(wrapped, instance, args, kwargs): 28 | options = fallback_dict( 29 | wrapped.config, _config=wrapped.config, _log=wrapped.logger, _run=wrapped.run 30 | ) 31 | if wrapped.uses_randomness: # only generate _seed and _rnd if needed 32 | options["_seed"] = get_seed(wrapped.rnd) 33 | options["_rnd"] = create_rnd(options["_seed"]) 34 | 35 | bound = instance is not None 36 | args, kwargs = wrapped.signature.construct_arguments(args, kwargs, options, bound) 37 | if wrapped.logger is not None: 38 | wrapped.logger.debug("Started") 39 | start_time = time.time() 40 | # =================== run actual function ================================= 41 | with ConfigError.track(wrapped.config, wrapped.prefix): 42 | result = wrapped(*args, **kwargs) 43 | # ========================================================================= 44 | if wrapped.logger is not None: 45 | stop_time = time.time() 46 | elapsed_time = timedelta(seconds=round(stop_time - start_time)) 47 | wrapped.logger.debug("Finished after %s.", elapsed_time) 48 | 49 | return result 50 | 
-------------------------------------------------------------------------------- /sacred/config/config_dict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | from sacred.config.config_summary import ConfigSummary 5 | from sacred.config.utils import ( 6 | dogmatize, 7 | normalize_or_die, 8 | undogmatize, 9 | recursive_fill_in, 10 | ) 11 | 12 | 13 | class ConfigDict: 14 | def __init__(self, d): 15 | self._conf = normalize_or_die(d) 16 | 17 | def __call__(self, fixed=None, preset=None, fallback=None): 18 | result = dogmatize(fixed or {}) 19 | recursive_fill_in(result, self._conf) 20 | recursive_fill_in(result, preset or {}) 21 | added = result.revelation() 22 | config_summary = ConfigSummary(added, result.modified, result.typechanges) 23 | config_summary.update(undogmatize(result)) 24 | return config_summary 25 | -------------------------------------------------------------------------------- /sacred/config/config_files.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import os 5 | import pickle 6 | 7 | import json 8 | 9 | import sacred.optional as opt 10 | from sacred.serializer import flatten, restore 11 | 12 | __all__ = ("load_config_file", "save_config_file") 13 | 14 | 15 | class Handler: 16 | def __init__(self, load, dump, mode): 17 | self.load = load 18 | self.dump = dump 19 | self.mode = mode 20 | 21 | 22 | HANDLER_BY_EXT = { 23 | ".json": Handler( 24 | lambda fp: restore(json.load(fp)), 25 | lambda obj, fp: json.dump(flatten(obj), fp, sort_keys=True, indent=2), 26 | "", 27 | ), 28 | ".pickle": Handler(pickle.load, pickle.dump, "b"), 29 | } 30 | 31 | yaml_extensions = (".yaml", ".yml") 32 | if opt.has_yaml: 33 | 34 | def load_yaml(filename): 35 | return opt.yaml.load(filename, Loader=opt.yaml.FullLoader) 36 | 37 | yaml_handler = Handler(load_yaml, opt.yaml.dump, "") 38 | 39 | for extension in yaml_extensions: 40 | HANDLER_BY_EXT[extension] = yaml_handler 41 | 42 | 43 | def get_handler(filename): 44 | _, extension = os.path.splitext(filename) 45 | if extension in yaml_extensions and not opt.has_yaml: 46 | raise KeyError( 47 | 'Configuration file "{}" cannot be loaded as ' 48 | "you do not have PyYAML installed.".format(filename) 49 | ) 50 | try: 51 | return HANDLER_BY_EXT[extension] 52 | except KeyError as e: 53 | raise ValueError( 54 | 'Configuration file "{}" has invalid or unsupported extension ' 55 | '"{}".'.format(filename, extension) 56 | ) from e 57 | 58 | 59 | def load_config_file(filename): 60 | handler = get_handler(filename) 61 | with open(filename, "r" + handler.mode) as f: 62 | return handler.load(f) 63 | 64 | 65 | def save_config_file(config, filename): 66 | handler = get_handler(filename) 67 | with open(filename, "w" + handler.mode) as f: 68 | handler.dump(config, f) 69 | -------------------------------------------------------------------------------- /sacred/config/config_summary.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | from sacred.utils import iter_prefixes, join_paths 5 | 6 | 7 | class ConfigSummary(dict): 8 | def __init__( 9 | self, added=(), modified=(), typechanged=(), ignored_fallbacks=(), docs=() 10 | ): 11 | super().__init__() 12 | self.added = set(added) 13 | self.modified = set(modified) # TODO: test for this member 14 | self.typechanged = dict(typechanged) 15 | self.ignored_fallbacks = 
set(ignored_fallbacks) # TODO: test 16 | self.docs = dict(docs) 17 | self.ensure_coherence() 18 | 19 | def update_from(self, config_mod, path=""): 20 | added = config_mod.added 21 | updated = config_mod.modified 22 | typechanged = config_mod.typechanged 23 | self.added &= {join_paths(path, a) for a in added} 24 | self.modified |= {join_paths(path, u) for u in updated} 25 | self.typechanged.update( 26 | {join_paths(path, k): v for k, v in typechanged.items()} 27 | ) 28 | self.ensure_coherence() 29 | for k, v in config_mod.docs.items(): 30 | if not self.docs.get(k, ""): 31 | self.docs[k] = v 32 | 33 | def update_add(self, config_mod, path=""): 34 | added = config_mod.added 35 | updated = config_mod.modified 36 | typechanged = config_mod.typechanged 37 | self.added |= {join_paths(path, a) for a in added} 38 | self.modified |= {join_paths(path, u) for u in updated} 39 | self.typechanged.update( 40 | {join_paths(path, k): v for k, v in typechanged.items()} 41 | ) 42 | self.docs.update( 43 | { 44 | join_paths(path, k): v 45 | for k, v in config_mod.docs.items() 46 | if path == "" or k != "seed" 47 | } 48 | ) 49 | self.ensure_coherence() 50 | 51 | def ensure_coherence(self): 52 | # make sure parent paths show up as updated appropriately 53 | self.modified |= {p for a in self.added for p in iter_prefixes(a)} 54 | self.modified |= {p for u in self.modified for p in iter_prefixes(u)} 55 | self.modified |= {p for t in self.typechanged for p in iter_prefixes(t)} 56 | 57 | # make sure there is no overlap 58 | self.added -= set(self.typechanged.keys()) 59 | self.modified -= set(self.typechanged.keys()) 60 | self.modified -= self.added 61 | -------------------------------------------------------------------------------- /sacred/config/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import jsonpickle.tags 5 | 6 | from sacred import SETTINGS 7 | import sacred.optional as opt 8 | from sacred.config.custom_containers import DogmaticDict, DogmaticList 9 | from sacred.utils import PYTHON_IDENTIFIER 10 | 11 | 12 | def assert_is_valid_key(key): 13 | """ 14 | Raise KeyError if a given config key violates any requirements. 15 | 16 | The requirements are the following and can be individually deactivated 17 | in ``sacred.SETTINGS.CONFIG_KEYS``: 18 | * ENFORCE_MONGO_COMPATIBLE (default: True): 19 | make sure the keys don't contain a '.' or start with a '$' 20 | * ENFORCE_JSONPICKLE_COMPATIBLE (default: True): 21 | make sure the keys do not contain any reserved jsonpickle tags 22 | This is very important. Only deactivate if you know what you are doing. 23 | * ENFORCE_STRING (default: False): 24 | make sure all keys are string. 25 | * ENFORCE_VALID_PYTHON_IDENTIFIER (default: False): 26 | make sure all keys are valid python identifiers. 27 | 28 | Parameters 29 | ---------- 30 | key: 31 | The key that should be checked 32 | 33 | Raises 34 | ------ 35 | KeyError: 36 | if the key violates any requirements 37 | 38 | """ 39 | if SETTINGS.CONFIG.ENFORCE_KEYS_MONGO_COMPATIBLE and ( 40 | isinstance(key, str) and ("." in key or key[0] == "$") 41 | ): 42 | raise KeyError( 43 | 'Invalid key "{}". Config-keys cannot ' 44 | 'contain "." or start with "$"'.format(key) 45 | ) 46 | 47 | if ( 48 | SETTINGS.CONFIG.ENFORCE_KEYS_JSONPICKLE_COMPATIBLE 49 | and isinstance(key, str) 50 | and (key in jsonpickle.tags.RESERVED or key.startswith("json://")) 51 | ): 52 | raise KeyError( 53 | 'Invalid key "{}". 
Config-keys cannot be one of the' 54 | "reserved jsonpickle tags: {}".format(key, jsonpickle.tags.RESERVED) 55 | ) 56 | 57 | if SETTINGS.CONFIG.ENFORCE_STRING_KEYS and (not isinstance(key, str)): 58 | raise KeyError( 59 | 'Invalid key "{}". Config-keys have to be strings, ' 60 | "but was {}".format(key, type(key)) 61 | ) 62 | 63 | if SETTINGS.CONFIG.ENFORCE_VALID_PYTHON_IDENTIFIER_KEYS and ( 64 | isinstance(key, str) and not PYTHON_IDENTIFIER.match(key) 65 | ): 66 | raise KeyError('Key "{}" is not a valid python identifier'.format(key)) 67 | 68 | if SETTINGS.CONFIG.ENFORCE_KEYS_NO_EQUALS and (isinstance(key, str) and "=" in key): 69 | raise KeyError( 70 | 'Invalid key "{}". Config keys may not contain an' 71 | 'equals sign ("=").'.format("=") 72 | ) 73 | 74 | 75 | def normalize_numpy(obj): 76 | if opt.has_numpy and isinstance(obj, opt.np.generic): 77 | try: 78 | return obj.item() 79 | except ValueError: 80 | pass 81 | return obj 82 | 83 | 84 | def normalize_or_die(obj): 85 | if isinstance(obj, dict): 86 | res = dict() 87 | for key, value in obj.items(): 88 | assert_is_valid_key(key) 89 | res[key] = normalize_or_die(value) 90 | return res 91 | elif isinstance(obj, (list, tuple)): 92 | return list([normalize_or_die(value) for value in obj]) 93 | return normalize_numpy(obj) 94 | 95 | 96 | def recursive_fill_in(config, preset): 97 | for key in preset: 98 | if key not in config: 99 | config[key] = preset[key] 100 | elif isinstance(config[key], dict) and isinstance(preset[key], dict): 101 | recursive_fill_in(config[key], preset[key]) 102 | 103 | 104 | def chain_evaluate_config_scopes(config_scopes, fixed=None, preset=None, fallback=None): 105 | fixed = fixed or {} 106 | fallback = fallback or {} 107 | final_config = dict(preset or {}) 108 | config_summaries = [] 109 | for config in config_scopes: 110 | cfg = config(fixed=fixed, preset=final_config, fallback=fallback) 111 | config_summaries.append(cfg) 112 | final_config.update(cfg) 113 | 114 | if not config_scopes: 115 | final_config.update(fixed) 116 | 117 | return undogmatize(final_config), config_summaries 118 | 119 | 120 | def dogmatize(obj): 121 | if isinstance(obj, dict): 122 | return DogmaticDict({key: dogmatize(val) for key, val in obj.items()}) 123 | elif isinstance(obj, list): 124 | return DogmaticList([dogmatize(value) for value in obj]) 125 | elif isinstance(obj, tuple): 126 | return tuple(dogmatize(value) for value in obj) 127 | else: 128 | return obj 129 | 130 | 131 | def undogmatize(obj): 132 | if isinstance(obj, DogmaticDict): 133 | return dict({key: undogmatize(value) for key, value in obj.items()}) 134 | elif isinstance(obj, DogmaticList): 135 | return list([undogmatize(value) for value in obj]) 136 | elif isinstance(obj, tuple): 137 | return tuple(undogmatize(value) for value in obj) 138 | else: 139 | return obj 140 | -------------------------------------------------------------------------------- /sacred/metrics_logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import datetime 4 | import sacred.optional as opt 5 | 6 | from queue import Queue, Empty 7 | 8 | 9 | class MetricsLogger: 10 | """MetricsLogger collects metrics measured during experiments. 11 | 12 | MetricsLogger is the (only) part of the Metrics API. 13 | An instance of the class should be created for the Run class, such that the 14 | log_scalar_metric method is accessible from running experiments using 15 | _run.metrics.log_scalar_metric. 
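
A standalone sketch of the logger API (illustrative only; inside an experiment
you would normally go through the running ``Run`` object as described above):

    logger = MetricsLogger()
    logger.log_scalar_metric("training.loss", 0.9)            # step defaults to an internal counter
    logger.log_scalar_metric("training.loss", 0.5, step=10)
    entries = logger.get_last_metrics()                       # -> list of ScalarMetricLogEntry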
16 | """ 17 | 18 | def __init__(self): 19 | # Create a message queue that remembers 20 | # calls of the log_scalar_metric 21 | self._logged_metrics = Queue() 22 | self._metric_step_counter = {} 23 | """Remembers the last number of each metric.""" 24 | 25 | def log_scalar_metric(self, metric_name, value, step=None): 26 | """ 27 | Add a new measurement. 28 | 29 | The measurement will be processed by the MongoDB observer 30 | during a heartbeat event. 31 | Other observers are not yet supported. 32 | 33 | :param metric_name: The name of the metric, e.g. training.loss. 34 | :param value: The measured value. 35 | :param step: The step number (integer), e.g. the iteration number 36 | If not specified, an internal counter for each metric 37 | is used, incremented by one. 38 | """ 39 | if opt.has_numpy: 40 | np = opt.np 41 | if isinstance(value, np.generic): 42 | value = value.item() 43 | if isinstance(step, np.generic): 44 | step = step.item() 45 | if step is None: 46 | step = self._metric_step_counter.get(metric_name, -1) + 1 47 | self._logged_metrics.put( 48 | ScalarMetricLogEntry(metric_name, step, datetime.datetime.utcnow(), value) 49 | ) 50 | self._metric_step_counter[metric_name] = step 51 | 52 | def get_last_metrics(self): 53 | """Read all measurement events since last call of the method. 54 | 55 | :return List[ScalarMetricLogEntry] 56 | """ 57 | read_up_to = self._logged_metrics.qsize() 58 | messages = [] 59 | for i in range(read_up_to): 60 | try: 61 | messages.append(self._logged_metrics.get_nowait()) 62 | except Empty: 63 | pass 64 | return messages 65 | 66 | 67 | class ScalarMetricLogEntry: 68 | """Container for measurements of scalar metrics. 69 | 70 | There is exactly one ScalarMetricLogEntry per logged scalar metric value. 71 | """ 72 | 73 | def __init__(self, name, step, timestamp, value): 74 | self.name = name 75 | self.step = step 76 | self.timestamp = timestamp 77 | self.value = value 78 | 79 | 80 | def linearize_metrics(logged_metrics): 81 | """ 82 | Group metrics by name. 83 | 84 | Takes a list of individual measurements, possibly belonging 85 | to different metrics and groups them by name. 
86 | 87 | :param logged_metrics: A list of ScalarMetricLogEntries 88 | :return: Measured values grouped by the metric name: 89 | {"metric_name1": {"steps": [0,1,2], "values": [4, 5, 6], 90 | "timestamps": [datetime, datetime, datetime]}, 91 | "metric_name2": {...}} 92 | """ 93 | metrics_by_name = {} 94 | for metric_entry in logged_metrics: 95 | if metric_entry.name not in metrics_by_name: 96 | metrics_by_name[metric_entry.name] = { 97 | "steps": [], 98 | "values": [], 99 | "timestamps": [], 100 | "name": metric_entry.name, 101 | } 102 | metrics_by_name[metric_entry.name]["steps"].append(metric_entry.step) 103 | metrics_by_name[metric_entry.name]["values"].append(metric_entry.value) 104 | metrics_by_name[metric_entry.name]["timestamps"].append(metric_entry.timestamp) 105 | return metrics_by_name 106 | -------------------------------------------------------------------------------- /sacred/observers/__init__.py: -------------------------------------------------------------------------------- 1 | from sacred.observers.base import RunObserver 2 | from sacred.observers.file_storage import FileStorageObserver 3 | from sacred.observers.mongo import MongoObserver, QueuedMongoObserver 4 | from sacred.observers.sql import SqlObserver 5 | from sacred.observers.tinydb_hashfs import TinyDbObserver, TinyDbReader 6 | from sacred.observers.slack import SlackObserver 7 | from sacred.observers.telegram_obs import TelegramObserver 8 | from sacred.observers.s3_observer import S3Observer 9 | from sacred.observers.queue import QueueObserver 10 | from sacred.observers.gcs_observer import GoogleCloudStorageObserver 11 | 12 | 13 | __all__ = ( 14 | "FileStorageObserver", 15 | "RunObserver", 16 | "MongoObserver", 17 | "QueuedMongoObserver", 18 | "SqlObserver", 19 | "TinyDbObserver", 20 | "TinyDbReader", 21 | "SlackObserver", 22 | "TelegramObserver", 23 | "S3Observer", 24 | "QueueObserver", 25 | "GoogleCloudStorageObserver", 26 | ) 27 | -------------------------------------------------------------------------------- /sacred/observers/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | __all__ = ("RunObserver", "td_format") 5 | 6 | 7 | class RunObserver: 8 | """Defines the interface for all run observers.""" 9 | 10 | priority = 0 11 | 12 | def queued_event( 13 | self, ex_info, command, host_info, queue_time, config, meta_info, _id 14 | ): 15 | pass 16 | 17 | def started_event( 18 | self, ex_info, command, host_info, start_time, config, meta_info, _id 19 | ): 20 | pass 21 | 22 | def heartbeat_event(self, info, captured_out, beat_time, result): 23 | pass 24 | 25 | def completed_event(self, stop_time, result): 26 | pass 27 | 28 | def interrupted_event(self, interrupt_time, status): 29 | pass 30 | 31 | def failed_event(self, fail_time, fail_trace): 32 | pass 33 | 34 | def resource_event(self, filename): 35 | pass 36 | 37 | def artifact_event(self, name, filename, metadata=None, content_type=None): 38 | pass 39 | 40 | def log_metrics(self, metrics_by_name, info): 41 | pass 42 | 43 | def join(self): 44 | pass 45 | 46 | 47 | # http://stackoverflow.com/questions/538666/python-format-timedelta-to-string 48 | def td_format(td_object): 49 | seconds = int(td_object.total_seconds()) 50 | if seconds == 0: 51 | return "less than a second" 52 | 53 | periods = [ 54 | ("year", 60 * 60 * 24 * 365), 55 | ("month", 60 * 60 * 24 * 30), 56 | ("day", 60 * 60 * 24), 57 | ("hour", 60 * 60), 58 | ("minute", 60), 59 | ("second", 1), 60 | ] 61 | 62 | 
strings = [] 63 | for period_name, period_seconds in periods: 64 | if seconds >= period_seconds: 65 | period_value, seconds = divmod(seconds, period_seconds) 66 | if period_value == 1: 67 | strings.append("%s %s" % (period_value, period_name)) 68 | else: 69 | strings.append("%s %ss" % (period_value, period_name)) 70 | 71 | return ", ".join(strings) 72 | -------------------------------------------------------------------------------- /sacred/observers/queue.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from queue import Queue 3 | from sacred.observers.base import RunObserver 4 | from sacred.utils import IntervalTimer 5 | import traceback 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | WrappedEvent = namedtuple("WrappedEvent", "name args kwargs") 11 | 12 | 13 | class QueueObserver(RunObserver): 14 | """Wraps any observer and puts processing of events in the background. 15 | 16 | If the covered observer fails to process an event, the queue observer 17 | will retry until it works. This is useful for observers that rely on 18 | external services like databases that might become temporarily 19 | unavailable. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | covered_observer: RunObserver, 25 | interval: float = 20.0, 26 | retry_interval: float = 10.0, 27 | ): 28 | """Initialize QueueObserver. 29 | 30 | Parameters 31 | ---------- 32 | covered_observer 33 | The real observer that is being wrapped. 34 | interval 35 | The interval in seconds at which the background thread is woken up to process new events. 36 | retry_interval 37 | The interval in seconds to wait if an event failed to be processed. 38 | """ 39 | self._covered_observer = covered_observer 40 | self._retry_interval = retry_interval 41 | self._interval = interval 42 | self._queue = None 43 | self._worker = None 44 | self._stop_worker_event = None 45 | logger.debug("just testing") 46 | 47 | def queued_event(self, *args, **kwargs): 48 | self._queue.put(WrappedEvent("queued_event", args, kwargs)) 49 | 50 | def started_event(self, *args, **kwargs): 51 | self._queue = Queue() 52 | self._stop_worker_event, self._worker = IntervalTimer.create( 53 | self._run, interval=self._interval 54 | ) 55 | self._worker.start() 56 | 57 | # Putting the started event on the queue makes no sense 58 | # as it is required for initialization of the covered observer. 
59 | return self._covered_observer.started_event(*args, **kwargs) 60 | 61 | def heartbeat_event(self, *args, **kwargs): 62 | self._queue.put(WrappedEvent("heartbeat_event", args, kwargs)) 63 | 64 | def completed_event(self, *args, **kwargs): 65 | self._queue.put(WrappedEvent("completed_event", args, kwargs)) 66 | self.join() 67 | 68 | def interrupted_event(self, *args, **kwargs): 69 | self._queue.put(WrappedEvent("interrupted_event", args, kwargs)) 70 | self.join() 71 | 72 | def failed_event(self, *args, **kwargs): 73 | self._queue.put(WrappedEvent("failed_event", args, kwargs)) 74 | self.join() 75 | 76 | def resource_event(self, *args, **kwargs): 77 | self._queue.put(WrappedEvent("resource_event", args, kwargs)) 78 | 79 | def artifact_event(self, *args, **kwargs): 80 | self._queue.put(WrappedEvent("artifact_event", args, kwargs)) 81 | 82 | def log_metrics(self, metrics_by_name, info): 83 | for metric_name, metric_values in metrics_by_name.items(): 84 | self._queue.put( 85 | WrappedEvent("log_metrics", [metric_name, metric_values, info], {}) 86 | ) 87 | 88 | def _run(self): 89 | """Empty the queue every interval.""" 90 | while not self._queue.empty(): 91 | try: 92 | event = self._queue.get() 93 | except IndexError: 94 | # Currently there is no event on the queue so 95 | # just go back to sleep. 96 | pass 97 | else: 98 | try: 99 | method = getattr(self._covered_observer, event.name) 100 | except NameError: 101 | # The covered observer does not implement an event handler 102 | # for the event, so just discard the message. 103 | self._queue.task_done() 104 | else: 105 | while True: 106 | try: 107 | method(*event.args, **event.kwargs) 108 | except: 109 | # Something went wrong during the processing of 110 | # the event so wait for some time and 111 | # then try again. 112 | logger.debug( 113 | "Error while processing event. Trying again.\n{}".format( 114 | traceback.format_exc() 115 | ) 116 | ) 117 | # logging.debug(f"""Error while processing event. Trying again. 118 | # {traceback.format_exc()}""") 119 | 120 | self._stop_worker_event.wait(self._retry_interval) 121 | continue 122 | else: 123 | self._queue.task_done() 124 | break 125 | 126 | def join(self): 127 | if self._queue is not None: 128 | self._queue.join() 129 | self._stop_worker_event.set() 130 | self._worker.join(timeout=10) 131 | 132 | def __getattr__(self, item): 133 | return getattr(self._covered_observer, item) 134 | 135 | def __eq__(self, other): 136 | return self._covered_observer == other 137 | -------------------------------------------------------------------------------- /sacred/observers/slack.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | from sacred.observers.base import RunObserver, td_format 5 | from sacred.config.config_files import load_config_file 6 | import json 7 | 8 | 9 | DEFAULT_SLACK_PRIORITY = 10 10 | 11 | 12 | class SlackObserver(RunObserver): 13 | """Sends a message to Slack upon completion/failing of an experiment.""" 14 | 15 | @classmethod 16 | def from_config(cls, filename): 17 | """ 18 | Create a SlackObserver from a given configuration file. 19 | 20 | The file can be in any format supported by Sacred 21 | (.json, .pickle, [.yaml]). 22 | It has to specify a ``webhook_url`` and can optionally set 23 | ``bot_name``, ``icon``, ``completed_text``, ``interrupted_text``, and 24 | ``failed_text``. 
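
A sketch of typical usage (file name and webhook URL are placeholders):

    # slack.json needs at least: {"webhook_url": "https://hooks.slack.com/services/..."}
    slack_obs = SlackObserver.from_config("slack.json")
    ex.observers.append(slack_obs)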
25 | """ 26 | return cls(**load_config_file(filename)) 27 | 28 | def __init__( 29 | self, 30 | webhook_url, 31 | bot_name="sacred-bot", 32 | icon=":angel:", 33 | priority=DEFAULT_SLACK_PRIORITY, 34 | completed_text=None, 35 | interrupted_text=None, 36 | failed_text=None, 37 | ): 38 | self.webhook_url = webhook_url 39 | self.bot_name = bot_name 40 | self.icon = icon 41 | self.completed_text = completed_text or ( 42 | ":white_check_mark: *{experiment[name]}* " 43 | "completed after _{elapsed_time}_ with result=`{result}`" 44 | ) 45 | self.interrupted_text = interrupted_text or ( 46 | ":warning: *{experiment[name]}* " "interrupted after _{elapsed_time}_" 47 | ) 48 | self.failed_text = failed_text or ( 49 | ":x: *{experiment[name]}* failed after " "_{elapsed_time}_ with `{error}`" 50 | ) 51 | self.run = None 52 | self.priority = priority 53 | 54 | def started_event( 55 | self, ex_info, command, host_info, start_time, config, meta_info, _id 56 | ): 57 | self.run = { 58 | "_id": _id, 59 | "config": config, 60 | "start_time": start_time, 61 | "experiment": ex_info, 62 | "command": command, 63 | "host_info": host_info, 64 | } 65 | 66 | def get_completed_text(self): 67 | return self.completed_text.format(**self.run) 68 | 69 | def get_interrupted_text(self): 70 | return self.interrupted_text.format(**self.run) 71 | 72 | def get_failed_text(self): 73 | return self.failed_text.format(**self.run) 74 | 75 | def completed_event(self, stop_time, result): 76 | import requests 77 | 78 | if self.completed_text is None: 79 | return 80 | 81 | self.run["result"] = result 82 | self.run["stop_time"] = stop_time 83 | self.run["elapsed_time"] = td_format(stop_time - self.run["start_time"]) 84 | 85 | data = { 86 | "username": self.bot_name, 87 | "icon_emoji": self.icon, 88 | "text": self.get_completed_text(), 89 | } 90 | headers = {"Content-type": "application/json", "Accept": "text/plain"} 91 | requests.post(self.webhook_url, data=json.dumps(data), headers=headers) 92 | 93 | def interrupted_event(self, interrupt_time, status): 94 | import requests 95 | 96 | if self.interrupted_text is None: 97 | return 98 | 99 | self.run["status"] = status 100 | self.run["interrupt_time"] = interrupt_time 101 | self.run["elapsed_time"] = td_format(interrupt_time - self.run["start_time"]) 102 | 103 | data = { 104 | "username": self.bot_name, 105 | "icon_emoji": self.icon, 106 | "text": self.get_interrupted_text(), 107 | } 108 | headers = {"Content-type": "application/json", "Accept": "text/plain"} 109 | requests.post(self.webhook_url, data=json.dumps(data), headers=headers) 110 | 111 | def failed_event(self, fail_time, fail_trace): 112 | import requests 113 | 114 | if self.failed_text is None: 115 | return 116 | 117 | self.run["fail_trace"] = fail_trace 118 | self.run["error"] = fail_trace[-1].strip() 119 | self.run["fail_time"] = fail_time 120 | self.run["elapsed_time"] = td_format(fail_time - self.run["start_time"]) 121 | 122 | data = { 123 | "username": self.bot_name, 124 | "icon_emoji": self.icon, 125 | "text": self.get_failed_text(), 126 | } 127 | headers = {"Content-type": "application/json", "Accept": "text/plain"} 128 | requests.post(self.webhook_url, data=json.dumps(data), headers=headers) 129 | -------------------------------------------------------------------------------- /sacred/observers/sql.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import json 5 | from threading import Lock 6 | import warnings 7 | 8 | from 
sacred.commandline_options import cli_option 9 | from sacred.observers.base import RunObserver 10 | from sacred.serializer import flatten 11 | 12 | DEFAULT_SQL_PRIORITY = 40 13 | 14 | 15 | # ############################# Observer #################################### # 16 | 17 | 18 | class SqlObserver(RunObserver): 19 | @classmethod 20 | def create(cls, url, echo=False, priority=DEFAULT_SQL_PRIORITY): 21 | warnings.warn( 22 | "SqlObserver.create(...) is deprecated. Please use" 23 | " SqlObserver(...) instead.", 24 | DeprecationWarning, 25 | ) 26 | return cls(url, echo, priority) 27 | 28 | def __init__(self, url, echo=False, priority=DEFAULT_SQL_PRIORITY): 29 | from sqlalchemy.orm import sessionmaker, scoped_session 30 | import sqlalchemy as sa 31 | 32 | engine = sa.create_engine(url, echo=echo) 33 | session_factory = sessionmaker(bind=engine) 34 | # make session thread-local to avoid problems with sqlite (see #275) 35 | session = scoped_session(session_factory) 36 | self.engine = engine 37 | self.session = session 38 | self.priority = priority 39 | self.run = None 40 | self.lock = Lock() 41 | 42 | @classmethod 43 | def create_from(cls, engine, session, priority=DEFAULT_SQL_PRIORITY): 44 | """Instantiate a SqlObserver with an existing engine and session.""" 45 | self = cls.__new__(cls) # skip __init__ call 46 | self.engine = engine 47 | self.session = session 48 | self.priority = priority 49 | self.run = None 50 | self.lock = Lock() 51 | return self 52 | 53 | def started_event( 54 | self, ex_info, command, host_info, start_time, config, meta_info, _id 55 | ): 56 | return self._add_event( 57 | ex_info, 58 | command, 59 | host_info, 60 | config, 61 | meta_info, 62 | _id, 63 | "RUNNING", 64 | start_time=start_time, 65 | ) 66 | 67 | def queued_event( 68 | self, ex_info, command, host_info, queue_time, config, meta_info, _id 69 | ): 70 | return self._add_event( 71 | ex_info, command, host_info, config, meta_info, _id, "QUEUED" 72 | ) 73 | 74 | def _add_event( 75 | self, ex_info, command, host_info, config, meta_info, _id, status, **kwargs 76 | ): 77 | from .sql_bases import Base, Experiment, Host, Run 78 | 79 | Base.metadata.create_all(self.engine) 80 | sql_exp = Experiment.get_or_create(ex_info, self.session) 81 | sql_host = Host.get_or_create(host_info, self.session) 82 | if _id is None: 83 | i = self.session.query(Run).order_by(Run.id.desc()).first() 84 | _id = 0 if i is None else i.id + 1 85 | 86 | self.run = Run( 87 | run_id=str(_id), 88 | config=json.dumps(flatten(config)), 89 | command=command, 90 | priority=meta_info.get("priority", 0), 91 | comment=meta_info.get("comment", ""), 92 | experiment=sql_exp, 93 | host=sql_host, 94 | status=status, 95 | **kwargs, 96 | ) 97 | self.session.add(self.run) 98 | self.save() 99 | return _id or self.run.run_id 100 | 101 | def heartbeat_event(self, info, captured_out, beat_time, result): 102 | self.run.info = json.dumps(flatten(info)) 103 | self.run.captured_out = captured_out 104 | self.run.heartbeat = beat_time 105 | self.run.result = result 106 | self.save() 107 | 108 | def completed_event(self, stop_time, result): 109 | self.run.stop_time = stop_time 110 | self.run.result = result 111 | self.run.status = "COMPLETED" 112 | self.save() 113 | 114 | def interrupted_event(self, interrupt_time, status): 115 | self.run.stop_time = interrupt_time 116 | self.run.status = status 117 | self.save() 118 | 119 | def failed_event(self, fail_time, fail_trace): 120 | self.run.stop_time = fail_time 121 | self.run.fail_trace = "\n".join(fail_trace) 122 | 
self.run.status = "FAILED" 123 | self.save() 124 | 125 | def resource_event(self, filename): 126 | from .sql_bases import Resource 127 | 128 | res = Resource.get_or_create(filename, self.session) 129 | self.run.resources.append(res) 130 | self.save() 131 | 132 | def artifact_event(self, name, filename, metadata=None, content_type=None): 133 | from .sql_bases import Artifact 134 | 135 | a = Artifact.create(name, filename) 136 | self.run.artifacts.append(a) 137 | self.save() 138 | 139 | def save(self): 140 | with self.lock: 141 | self.session.commit() 142 | 143 | def query(self, _id): 144 | from .sql_bases import Run 145 | 146 | run = self.session.query(Run).filter_by(id=_id).first() 147 | return run.to_json() 148 | 149 | def __eq__(self, other): 150 | if isinstance(other, SqlObserver): 151 | # fixme: this will probably fail to detect two equivalent engines 152 | return self.engine == other.engine and self.session == other.session 153 | return False 154 | 155 | 156 | # ######################## Commandline Option ############################### # 157 | 158 | 159 | @cli_option("-s", "--sql") 160 | def sql_option(args, run): 161 | """Add a SQL Observer to the experiment. 162 | 163 | The typical form is: dialect://username:password@host:port/database 164 | """ 165 | run.observers.append(SqlObserver(args)) 166 | -------------------------------------------------------------------------------- /sacred/observers/tinydb_hashfs/__init__.py: -------------------------------------------------------------------------------- 1 | from .tinydb_hashfs import TinyDbReader, TinyDbObserver, tiny_db_option 2 | 3 | __all__ = ["TinyDbObserver", "TinyDbReader", "tiny_db_option"] 4 | -------------------------------------------------------------------------------- /sacred/observers/tinydb_hashfs/bases.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import json 3 | import os 4 | from io import BufferedReader, FileIO 5 | from pathlib import Path 6 | from typing import Tuple 7 | 8 | from hashfs import HashFS 9 | from tinydb import TinyDB 10 | from tinydb_serialization import Serializer, SerializationMiddleware 11 | 12 | import sacred.optional as opt 13 | 14 | # Set data type values for abstract properties in Serializers 15 | series_type = opt.pandas.Series if opt.has_pandas else None 16 | dataframe_type = opt.pandas.DataFrame if opt.has_pandas else None 17 | ndarray_type = opt.np.ndarray if opt.has_numpy else None 18 | 19 | 20 | class BufferedReaderWrapper(BufferedReader): 21 | """Custom wrapper to allow for copying of file handle. 22 | 23 | tinydb_serialisation currently does a deepcopy on all the content of the 24 | dictionary before serialisation. By default, file handles are not 25 | copiable so this wrapper is necessary to create a duplicate of the 26 | file handle passes in. 27 | 28 | Note that the file passed in will therefor remain open as the copy is the 29 | one that gets closed. 
30 | """ 31 | 32 | def __init__(self, f_obj): 33 | f_obj = FileIO(f_obj.name) 34 | super().__init__(f_obj) 35 | 36 | def __copy__(self): 37 | f = open(self.name, self.mode) 38 | return BufferedReaderWrapper(f) 39 | 40 | def __deepcopy__(self, memo): 41 | f = open(self.name, self.mode) 42 | return BufferedReaderWrapper(f) 43 | 44 | 45 | class DateTimeSerializer(Serializer): 46 | OBJ_CLASS = dt.datetime # The class this serializer handles 47 | 48 | def encode(self, obj): 49 | return obj.strftime("%Y-%m-%dT%H:%M:%S.%f") 50 | 51 | def decode(self, s): 52 | return dt.datetime.strptime(s, "%Y-%m-%dT%H:%M:%S.%f") 53 | 54 | 55 | class NdArraySerializer(Serializer): 56 | OBJ_CLASS = ndarray_type 57 | 58 | def encode(self, obj): 59 | return json.dumps(obj.tolist(), check_circular=True) 60 | 61 | def decode(self, s): 62 | return opt.np.array(json.loads(s)) 63 | 64 | 65 | class DataFrameSerializer(Serializer): 66 | OBJ_CLASS = dataframe_type 67 | 68 | def encode(self, obj): 69 | return obj.to_json() 70 | 71 | def decode(self, s): 72 | return opt.pandas.read_json(s) 73 | 74 | 75 | class SeriesSerializer(Serializer): 76 | OBJ_CLASS = series_type 77 | 78 | def encode(self, obj): 79 | return obj.to_json() 80 | 81 | def decode(self, s): 82 | return opt.pandas.read_json(s, typ="series") 83 | 84 | 85 | class FileSerializer(Serializer): 86 | OBJ_CLASS = BufferedReaderWrapper 87 | 88 | def __init__(self, fs): 89 | self.fs = fs 90 | 91 | def encode(self, obj): 92 | address = self.fs.put(obj) 93 | return json.dumps(address.id) 94 | 95 | def decode(self, s): 96 | id_ = json.loads(s) 97 | file_reader = self.fs.open(id_) 98 | file_reader = BufferedReaderWrapper(file_reader) 99 | file_reader.hash = id_ 100 | return file_reader 101 | 102 | 103 | def get_db_file_manager(root_dir) -> Tuple[TinyDB, HashFS]: 104 | root_dir = Path(root_dir) 105 | fs = HashFS(root_dir / "hashfs", depth=3, width=2, algorithm="md5") 106 | 107 | # Setup Serialisation object for non list/dict objects 108 | serialization_store = SerializationMiddleware() 109 | serialization_store.register_serializer(DateTimeSerializer(), "TinyDate") 110 | serialization_store.register_serializer(FileSerializer(fs), "TinyFile") 111 | 112 | if opt.has_numpy: 113 | serialization_store.register_serializer(NdArraySerializer(), "TinyArray") 114 | if opt.has_pandas: 115 | serialization_store.register_serializer(DataFrameSerializer(), "TinyDataFrame") 116 | serialization_store.register_serializer(SeriesSerializer(), "TinySeries") 117 | 118 | db = TinyDB(os.path.join(root_dir, "metadata.json"), storage=serialization_store) 119 | return db, fs 120 | -------------------------------------------------------------------------------- /sacred/optional.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import importlib 5 | from sacred.utils import modules_exist 6 | from sacred.utils import get_package_version, parse_version 7 | 8 | 9 | def optional_import(*package_names): 10 | try: 11 | packages = [importlib.import_module(pn) for pn in package_names] 12 | return True, packages[0] 13 | except ImportError: 14 | return False, None 15 | 16 | 17 | def get_tensorflow(): 18 | # Ensures backward and forward compatibility with TensorFlow 1 and 2. 19 | if get_package_version("tensorflow") < parse_version("1.13.1"): 20 | import warnings 21 | 22 | warnings.warn( 23 | "Use of TensorFlow 1.12 and older is deprecated. 
" 24 | "Use Tensorflow 1.13 or newer instead.", 25 | DeprecationWarning, 26 | ) 27 | import tensorflow as tf 28 | else: 29 | import tensorflow.compat.v1 as tf 30 | return tf 31 | 32 | 33 | # Get libc in a cross-platform way and use it to also flush the c stdio buffers 34 | # credit to J.F. Sebastians SO answer from here: 35 | # http://stackoverflow.com/a/22434262/1388435 36 | try: 37 | import ctypes 38 | from ctypes.util import find_library 39 | except ImportError: 40 | libc = None 41 | else: 42 | try: 43 | libc = ctypes.cdll.msvcrt # Windows 44 | except (OSError, AttributeError): 45 | libc = ctypes.cdll.LoadLibrary(find_library("c")) 46 | 47 | 48 | has_numpy, np = optional_import("numpy") 49 | has_yaml, yaml = optional_import("yaml") 50 | has_pandas, pandas = optional_import("pandas") 51 | 52 | has_sqlalchemy = modules_exist("sqlalchemy") 53 | has_mako = modules_exist("mako") 54 | has_tinydb = modules_exist("tinydb", "tinydb_serialization", "hashfs") 55 | has_tensorflow = modules_exist("tensorflow") 56 | -------------------------------------------------------------------------------- /sacred/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/sacred/py.typed -------------------------------------------------------------------------------- /sacred/pytee.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | if __name__ == "__main__": 5 | import sys 6 | 7 | buffer = " " 8 | while len(buffer): 9 | buffer = sys.stdin.read() 10 | sys.stdout.write(buffer) 11 | sys.stderr.write(buffer) 12 | -------------------------------------------------------------------------------- /sacred/randomness.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import random 5 | 6 | import sacred.optional as opt 7 | from sacred.settings import SETTINGS 8 | from sacred.utils import module_is_in_cache 9 | 10 | SEEDRANGE = (1, int(1e9)) 11 | 12 | 13 | def get_seed(rnd=None): 14 | if rnd is None: 15 | return random.randint(*SEEDRANGE) 16 | else: 17 | try: 18 | return rnd.integers(*SEEDRANGE, dtype=int) 19 | except Exception: 20 | return rnd.randint(*SEEDRANGE) 21 | 22 | 23 | def create_rnd(seed): 24 | assert isinstance(seed, int), "Seed has to be integer but was {} {}".format( 25 | repr(seed), type(seed) 26 | ) 27 | if opt.has_numpy: 28 | if SETTINGS.CONFIG.NUMPY_RANDOM_LEGACY_API: 29 | return opt.np.random.RandomState(seed) 30 | else: 31 | return opt.np.random.default_rng(seed) 32 | else: 33 | return random.Random(seed) 34 | 35 | 36 | def set_global_seed(seed): 37 | random.seed(seed) 38 | if opt.has_numpy: 39 | opt.np.random.seed(seed) 40 | if module_is_in_cache("tensorflow"): 41 | tf = opt.get_tensorflow() 42 | tf.set_random_seed(seed) 43 | if module_is_in_cache("torch"): 44 | import torch 45 | 46 | torch.manual_seed(seed) 47 | if torch.cuda.is_available(): 48 | torch.cuda.manual_seed_all(seed) 49 | -------------------------------------------------------------------------------- /sacred/serializer.py: -------------------------------------------------------------------------------- 1 | import jsonpickle 2 | import json as _json 3 | from sacred import optional as opt 4 | 5 | json = jsonpickle 6 | 7 | 8 | __all__ = ("flatten", "restore") 9 | 10 | 11 | if opt.has_numpy: 12 | import jsonpickle.ext.numpy as jsonpickle_numpy 13 | 14 | 
np = opt.np 15 | 16 | jsonpickle_numpy.register_handlers() 17 | 18 | if opt.has_pandas: 19 | import jsonpickle.ext.pandas as jsonpickle_pandas 20 | 21 | jsonpickle_pandas.register_handlers() 22 | 23 | 24 | jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4) 25 | jsonpickle.set_encoder_options("demjson", compactly=False) 26 | 27 | 28 | def flatten(obj): 29 | return _json.loads(json.encode(obj, keys=True)) 30 | 31 | 32 | def restore(flat): 33 | return json.decode(_json.dumps(flat), keys=True, on_missing="error") 34 | -------------------------------------------------------------------------------- /sacred/settings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import platform 5 | from sacred.utils import SacredError 6 | import sacred.optional as opt 7 | from munch import Munch 8 | from packaging import version 9 | 10 | __all__ = ("SETTINGS", "SettingError") 11 | 12 | 13 | class SettingError(SacredError): 14 | """Error for invalid settings.""" 15 | 16 | 17 | class FrozenKeyMunch(Munch): 18 | __frozen_keys = False 19 | 20 | def freeze_keys(self): 21 | if self.__frozen_keys: 22 | return 23 | self.__frozen_keys = True 24 | for v in self.values(): 25 | if isinstance(v, FrozenKeyMunch): 26 | v.freeze_keys() 27 | 28 | def _check_can_set(self, key, value): 29 | if not self.__frozen_keys: 30 | return 31 | 32 | # Don't allow unknown keys 33 | if key not in self: 34 | raise SettingError( 35 | f"Unknown setting: {key}. Possible keys are: " f"{list(self.keys())}" 36 | ) 37 | 38 | # Don't allow setting keys that represent nested settings 39 | if isinstance(self[key], Munch) and not isinstance(value, Munch): 40 | # We don't want to overwrite a munch mapping. This is the easiest 41 | # solution and closest to the original implementation where setting 42 | # a setting with a dict would likely at some point cause an 43 | # exception 44 | raise SettingError( 45 | f"Can't set this setting ({key}) to a non-munch value " 46 | f"{value}, it is a nested setting!" 47 | ) 48 | 49 | def __setitem__(self, key, value): 50 | self._check_can_set(key, value) 51 | super().__setitem__(key, value) 52 | 53 | def __setattr__(self, key, value): 54 | self._check_can_set(key, value) 55 | super().__setattr__(key, value) 56 | 57 | def __deepcopy__(self, memodict=None): 58 | obj = self.__class__.fromDict(self.toDict()) 59 | if self.__frozen_keys: 60 | obj.freeze_keys() 61 | return obj 62 | 63 | 64 | SETTINGS = FrozenKeyMunch.fromDict( 65 | { 66 | "CONFIG": { 67 | # make sure all config keys are compatible with MongoDB 68 | "ENFORCE_KEYS_MONGO_COMPATIBLE": True, 69 | # make sure all config keys are serializable with jsonpickle 70 | # THIS IS IMPORTANT. Only deactivate if you know what you're doing. 
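The ``flatten``/``restore`` pair defined in ``sacred/serializer.py`` above is the jsonpickle machinery that the ``ENFORCE_KEYS_JSONPICKLE_COMPATIBLE`` setting below alludes to; a small sketch with made-up config values::

    from sacred.serializer import flatten, restore

    cfg = {"lr": 0.1, "hidden": [64, 64], "optimizer": "adam"}
    flat = flatten(cfg)          # plain dicts/lists/strings/numbers only
    assert restore(flat) == cfg  # round-trips for ordinary Python values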
71 | "ENFORCE_KEYS_JSONPICKLE_COMPATIBLE": True, 72 | # make sure all config keys are valid python identifiers 73 | "ENFORCE_VALID_PYTHON_IDENTIFIER_KEYS": False, 74 | # make sure all config keys are strings 75 | "ENFORCE_STRING_KEYS": False, 76 | # make sure no config key contains an equals sign 77 | "ENFORCE_KEYS_NO_EQUALS": True, 78 | # if true, all dicts and lists in the configuration of a captured 79 | # function are replaced with a read-only container that raises an 80 | # Exception if it is attempted to write to those containers 81 | "READ_ONLY_CONFIG": True, 82 | # regex patterns to filter out certain IDE or linter directives 83 | # from inline comments in the documentation 84 | "IGNORED_COMMENTS": ["^pylint:", "^noinspection"], 85 | # if true uses the numpy legacy API, i.e. _rnd in captured functions is 86 | # a numpy.random.RandomState rather than numpy.random.Generator. 87 | # numpy.random.RandomState became legacy with numpy v1.19. 88 | "NUMPY_RANDOM_LEGACY_API": version.parse(opt.np.__version__) 89 | < version.parse("1.19") 90 | if opt.has_numpy 91 | else False, 92 | }, 93 | "HOST_INFO": { 94 | # Collect information about GPUs using the nvidia-smi tool 95 | "INCLUDE_GPU_INFO": True, 96 | # Collect information about CPUs using py-cpuinfo 97 | "INCLUDE_CPU_INFO": True, 98 | # List of ENVIRONMENT variables to store in host-info 99 | "CAPTURED_ENV": [], 100 | }, 101 | "COMMAND_LINE": { 102 | # disallow string fallback, if parsing a value from command-line 103 | # failed 104 | "STRICT_PARSING": False, 105 | # show command line options that are disabled (e.g. unmet 106 | # dependencies) 107 | "SHOW_DISABLED_OPTIONS": True, 108 | }, 109 | # configure how stdout/stderr are captured. ['no', 'sys', 'fd'] 110 | "CAPTURE_MODE": "sys" if platform.system() == "Windows" else "fd", 111 | # configure how dependencies are discovered. [none, imported, sys, pkg] 112 | "DISCOVER_DEPENDENCIES": "imported", 113 | # configure how source-files are discovered. [none, imported, sys, dir] 114 | "DISCOVER_SOURCES": "imported", 115 | # Configure the default beat interval, in seconds 116 | "DEFAULT_BEAT_INTERVAL": 10.0, 117 | }, 118 | ) 119 | SETTINGS.freeze_keys() 120 | -------------------------------------------------------------------------------- /sacred/stdout_capturing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import os 5 | import sys 6 | import subprocess 7 | import warnings 8 | from io import StringIO 9 | from contextlib import contextmanager 10 | import wrapt 11 | from sacred.optional import libc 12 | from tempfile import NamedTemporaryFile 13 | from sacred.settings import SETTINGS 14 | 15 | 16 | def flush(): 17 | """Try to flush all stdio buffers, both from python and from C.""" 18 | try: 19 | sys.stdout.flush() 20 | sys.stderr.flush() 21 | except (AttributeError, ValueError, OSError): 22 | pass # unsupported 23 | try: 24 | libc.fflush(None) 25 | except (AttributeError, ValueError, OSError): 26 | pass # unsupported 27 | 28 | 29 | def get_stdcapturer(mode=None): 30 | mode = mode if mode is not None else SETTINGS.CAPTURE_MODE 31 | capture_options = {"no": no_tee, "fd": tee_output_fd, "sys": tee_output_python} 32 | if mode not in capture_options: 33 | raise KeyError( 34 | "Unknown capture mode '{}'. 
Available options are {}".format( 35 | mode, sorted(capture_options.keys()) 36 | ) 37 | ) 38 | return mode, capture_options[mode] 39 | 40 | 41 | class TeeingStreamProxy(wrapt.ObjectProxy): 42 | """A wrapper around stdout or stderr that duplicates all output to out.""" 43 | 44 | def __init__(self, wrapped, out): 45 | super().__init__(wrapped) 46 | self._self_out = out 47 | 48 | def write(self, data): 49 | self.__wrapped__.write(data) 50 | self._self_out.write(data) 51 | 52 | def flush(self): 53 | self.__wrapped__.flush() 54 | self._self_out.flush() 55 | 56 | 57 | class CapturedStdout: 58 | def __init__(self, buffer): 59 | self.buffer = buffer 60 | self.read_position = 0 61 | self.final = None 62 | 63 | @property 64 | def closed(self): 65 | return self.buffer.closed 66 | 67 | def flush(self): 68 | return self.buffer.flush() 69 | 70 | def get(self): 71 | if self.final is None: 72 | self.buffer.seek(self.read_position) 73 | value = self.buffer.read() 74 | self.read_position = self.buffer.tell() 75 | return value 76 | else: 77 | value = self.final 78 | self.final = None 79 | return value 80 | 81 | def finalize(self): 82 | self.flush() 83 | self.final = self.get() 84 | self.buffer.close() 85 | 86 | 87 | @contextmanager 88 | def no_tee(): 89 | out = CapturedStdout(StringIO()) 90 | try: 91 | yield out 92 | finally: 93 | out.finalize() 94 | 95 | 96 | @contextmanager 97 | def tee_output_python(): 98 | """Duplicate sys.stdout and sys.stderr to new StringIO.""" 99 | buffer = StringIO() 100 | out = CapturedStdout(buffer) 101 | orig_stdout, orig_stderr = sys.stdout, sys.stderr 102 | flush() 103 | sys.stdout = TeeingStreamProxy(sys.stdout, buffer) 104 | sys.stderr = TeeingStreamProxy(sys.stderr, buffer) 105 | try: 106 | yield out 107 | finally: 108 | flush() 109 | out.finalize() 110 | sys.stdout, sys.stderr = orig_stdout, orig_stderr 111 | 112 | 113 | # Duplicate stdout and stderr to a file. Inspired by: 114 | # http://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/ 115 | # http://stackoverflow.com/a/651718/1388435 116 | # http://stackoverflow.com/a/22434262/1388435 117 | @contextmanager 118 | def tee_output_fd(): 119 | """Duplicate stdout and stderr to a file on the file descriptor level.""" 120 | with NamedTemporaryFile(mode="w+", newline="") as target: 121 | original_stdout_fd = 1 122 | original_stderr_fd = 2 123 | target_fd = target.fileno() 124 | 125 | # Save a copy of the original stdout and stderr file descriptors) 126 | saved_stdout_fd = os.dup(original_stdout_fd) 127 | saved_stderr_fd = os.dup(original_stderr_fd) 128 | 129 | try: 130 | # start_new_session=True to move process to a new process group 131 | # this is done to avoid receiving KeyboardInterrupts (see #149) 132 | tee_stdout = subprocess.Popen( 133 | ["tee", "-a", target.name], 134 | start_new_session=True, 135 | stdin=subprocess.PIPE, 136 | stdout=1, 137 | ) 138 | tee_stderr = subprocess.Popen( 139 | ["tee", "-a", target.name], 140 | start_new_session=True, 141 | stdin=subprocess.PIPE, 142 | stdout=2, 143 | ) 144 | except (FileNotFoundError, OSError, AttributeError): 145 | # No tee found in this operating system. Trying to use a python 146 | # implementation of tee. However this is slow and error-prone. 
147 | tee_stdout = subprocess.Popen( 148 | [sys.executable, "-m", "sacred.pytee"], 149 | stdin=subprocess.PIPE, 150 | stderr=target_fd, 151 | ) 152 | tee_stderr = subprocess.Popen( 153 | [sys.executable, "-m", "sacred.pytee"], 154 | stdin=subprocess.PIPE, 155 | stdout=target_fd, 156 | ) 157 | 158 | flush() 159 | os.dup2(tee_stdout.stdin.fileno(), original_stdout_fd) 160 | os.dup2(tee_stderr.stdin.fileno(), original_stderr_fd) 161 | out = CapturedStdout(target) 162 | 163 | try: 164 | yield out # let the caller do their printing 165 | finally: 166 | flush() 167 | 168 | # then redirect stdout back to the saved fd 169 | tee_stdout.stdin.close() 170 | tee_stderr.stdin.close() 171 | 172 | # restore original fds 173 | os.dup2(saved_stdout_fd, original_stdout_fd) 174 | os.dup2(saved_stderr_fd, original_stderr_fd) 175 | 176 | try: 177 | tee_stdout.wait(timeout=1) 178 | except subprocess.TimeoutExpired: 179 | warnings.warn("tee_stdout.wait timeout. Forcibly terminating.") 180 | tee_stdout.terminate() 181 | 182 | try: 183 | tee_stderr.wait(timeout=1) 184 | except subprocess.TimeoutExpired: 185 | warnings.warn("tee_stderr.wait timeout. Forcibly terminating.") 186 | tee_stderr.terminate() 187 | 188 | os.close(saved_stdout_fd) 189 | os.close(saved_stderr_fd) 190 | out.finalize() 191 | -------------------------------------------------------------------------------- /sacred/stflow/__init__.py: -------------------------------------------------------------------------------- 1 | from .method_interception import LogFileWriter 2 | 3 | __all__ = ("LogFileWriter",) 4 | -------------------------------------------------------------------------------- /sacred/stflow/internal.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | 4 | class ContextMethodDecorator: 5 | """A helper ContextManager decorating a method with a custom function.""" 6 | 7 | def __init__(self, classx, method_name, decorator_func): 8 | """ 9 | Create a new context manager decorating a function within its scope. 10 | 11 | This is a helper Context Manager that decorates a method of a class 12 | with a custom function. 13 | The decoration is only valid within the scope. 14 | :param classx: A class (object) 15 | :param method_name A string name of the method to be decorated 16 | :param decorator_func: The decorator function is responsible 17 | for calling the original method. 18 | The signature should be: func(instance, original_method, 19 | original_args, original_kwargs) 20 | when called, instance refers to an instance of classx and the 21 | original_method refers to the original method object which can be 22 | called. 
23 | args and kwargs are arguments passed to the method 24 | 25 | """ 26 | self.method_name = method_name 27 | self.decorator_func = decorator_func 28 | self.classx = classx 29 | self.patched_by_me = False 30 | 31 | def __enter__(self): 32 | 33 | self.original_method = getattr(self.classx, self.method_name) 34 | if not hasattr( 35 | self.original_method, "sacred_patched%s" % self.__class__.__name__ 36 | ): 37 | 38 | @functools.wraps(self.original_method) 39 | def decorated(instance, *args, **kwargs): 40 | return self.decorator_func(instance, self.original_method, args, kwargs) 41 | 42 | setattr(self.classx, self.method_name, decorated) 43 | setattr(decorated, "sacred_patched%s" % self.__class__.__name__, True) 44 | self.patched_by_me = True 45 | 46 | def __exit__(self, type, value, traceback): 47 | if self.patched_by_me: 48 | # Restore original function 49 | setattr(self.classx, self.method_name, self.original_method) 50 | -------------------------------------------------------------------------------- /sacred/stflow/method_interception.py: -------------------------------------------------------------------------------- 1 | from contextlib import ContextDecorator 2 | from .internal import ContextMethodDecorator 3 | import sacred.optional as opt 4 | 5 | 6 | if opt.has_tensorflow: 7 | tf = opt.get_tensorflow() 8 | else: 9 | tf = None 10 | 11 | 12 | class LogFileWriter(ContextDecorator, ContextMethodDecorator): 13 | """ 14 | Intercept ``logdir`` each time a new ``FileWriter`` instance is created. 15 | 16 | :param experiment: The Sacred experiment in whose ``info`` dict the log directories are recorded. 17 | 18 | The state of the experiment must be running when entering the annotated 19 | function / the context manager. 20 | 21 | When creating ``FileWriters`` in TensorFlow, you might want to 22 | store the path to the produced log files in the sacred database. 23 | 24 | In the scope of ``LogFileWriter``, the corresponding log directory path 25 | is appended to a list in experiment.info["tensorflow"]["logdirs"]. 26 | 27 | ``LogFileWriter`` can be used either as a context manager or as 28 | a decorator on a function.
29 | 30 | 31 | Example usage as decorator:: 32 | 33 | ex = Experiment("my experiment") 34 | @LogFileWriter(ex) 35 | def run_experiment(_run): 36 | with tf.Session() as s: 37 | swr = tf.summary.FileWriter("/tmp/1", s.graph) 38 | # _run.info["tensorflow"]["logdirs"] == ["/tmp/1"] 39 | swr2 = tf.summary.FileWriter("./test", s.graph) 40 | # _run.info["tensorflow"]["logdirs"] == ["/tmp/1", "./test"] 41 | 42 | 43 | Example usage as context manager:: 44 | 45 | ex = Experiment("my experiment") 46 | def run_experiment(_run): 47 | with tf.Session() as s: 48 | with LogFileWriter(ex): 49 | swr = tf.summary.FileWriter("/tmp/1", s.graph) 50 | # _run.info["tensorflow"]["logdirs"] == ["/tmp/1"] 51 | swr3 = tf.summary.FileWriter("./test", s.graph) 52 | # _run.info["tensorflow"]["logdirs"] == ["/tmp/1", "./test"] 53 | # This is called outside the scope and won't be captured 54 | swr3 = tf.summary.FileWriter("./nothing", s.graph) 55 | # Nothing has changed: 56 | # _run.info["tensorflow"]["logdirs"] == ["/tmp/1", "./test"] 57 | 58 | """ 59 | 60 | def __init__(self, experiment): 61 | self.experiment = experiment 62 | 63 | def log_writer_decorator( 64 | instance, original_method, original_args, original_kwargs 65 | ): 66 | result = original_method(instance, *original_args, **original_kwargs) 67 | if "logdir" in original_kwargs: 68 | logdir = original_kwargs["logdir"] 69 | else: 70 | logdir = original_args[0] 71 | self.experiment.info.setdefault("tensorflow", {}).setdefault( 72 | "logdirs", [] 73 | ).append(logdir) 74 | return result 75 | 76 | ContextMethodDecorator.__init__( 77 | self, tf.summary.FileWriter, "__init__", log_writer_decorator 78 | ) 79 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | [tool:pytest] 4 | pep8ignore = 5 | tests/test_config/test_config_scope.py ALL # Lots of errors that need to be there for testing 6 | build/* ALL 7 | dist/* ALL 8 | sacred.egg-info/* ALL 9 | [flake8] 10 | ignore = D100,D101,D102,D103,D104,D105,D203,D401,F821,E722,E203,E501,N818, 11 | # flake8 default ignores: 12 | E121,E123,E126,E226,E24,E704,W503,W504 13 | max-complexity = 10 14 | docstring-convention = numpy 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import setup, find_packages 4 | import os 5 | 6 | classifiers = """ 7 | Development Status :: 5 - Production/Stable 8 | Intended Audience :: Science/Research 9 | Natural Language :: English 10 | Operating System :: OS Independent 11 | Programming Language :: Python :: 3.8 12 | Programming Language :: Python :: 3.9 13 | Programming Language :: Python :: 3.10 14 | Programming Language :: Python :: 3.11 15 | Topic :: Utilities 16 | Topic :: Scientific/Engineering 17 | Topic :: Scientific/Engineering :: Artificial Intelligence 18 | Topic :: Software Development :: Libraries :: Python Modules 19 | License :: OSI Approved :: MIT License 20 | """ 21 | 22 | try: 23 | from sacred import __about__ 24 | 25 | about = __about__.__dict__ 26 | except ImportError: 27 | # installing - dependencies are not there yet 28 | # Manually extract the __about__ 29 | about = dict() 30 | exec(open("sacred/__about__.py").read(), about) 31 | 32 | 33 | setup( 34 | name="sacred", 35 | version=about["__version__"], 36 | author=about["__author__"], 37 |
author_email=about["__author_email__"], 38 | url=about["__url__"], 39 | packages=find_packages(include=["sacred", "sacred.*"]), 40 | package_data={"sacred": [os.path.join("data", "*"), "py.typed"]}, 41 | scripts=[], 42 | python_requires=">=3.8", 43 | install_requires=Path("requirements.txt").read_text().splitlines(), 44 | tests_require=["mock>=3.0, <5.0", "pytest==7.1.2"], 45 | classifiers=list(filter(None, classifiers.split("\n"))), 46 | description="Facilitates automated and reproducible experimental research", 47 | long_description=Path("README.rst").read_text(encoding="utf-8"), 48 | ) 49 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | -------------------------------------------------------------------------------- /tests/basedir/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/tests/basedir/__init__.py -------------------------------------------------------------------------------- /tests/basedir/my_experiment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | A file for testing the gathering of sources when a custom base directory is set. 5 | """ 6 | 7 | from tests.foo import bar 8 | 9 | 10 | def some_func(): 11 | pass 12 | -------------------------------------------------------------------------------- /tests/check_pre_commit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | pip install -e . 6 | pip install pre-commit 7 | pre-commit run --all-files 8 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import pytest 4 | import tempfile 5 | import hashlib 6 | import os.path 7 | import re 8 | import shlex 9 | import sys 10 | import warnings 11 | from importlib import reload 12 | 13 | from sacred.settings import SETTINGS 14 | 15 | EXAMPLES_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "examples") 16 | BLOCK_START = re.compile(r"^\s\s+\$.*$", flags=re.MULTILINE) 17 | 18 | 19 | def get_calls_from_doc(doc): 20 | """ 21 | Parses a docstring looking for indented blocks that start with $. 22 | It returns the first lines as call and the rest of the blocks as outputs. 
23 | """ 24 | if doc is None: 25 | return [] 26 | calls = [] 27 | outputs = [] 28 | out = [] 29 | block_indent = 2 30 | for line in doc.split("\n"): 31 | if BLOCK_START.match(line): 32 | block_indent = line.find("$") 33 | calls.append(shlex.split(line[block_indent + 1 :])) 34 | out = [] 35 | outputs.append(out) 36 | elif line.startswith(" " * block_indent): 37 | out.append(line[block_indent:]) 38 | else: 39 | out = [] 40 | 41 | return zip(calls, outputs) 42 | 43 | 44 | def pytest_generate_tests(metafunc): 45 | # collects all examples and parses their docstring for calls + outputs 46 | # it then parametrizes the function with 'example_test' 47 | if "example_test" in metafunc.fixturenames: 48 | examples = [ 49 | os.path.splitext(f)[0] 50 | for f in os.listdir(EXAMPLES_PATH) 51 | if os.path.isfile(os.path.join(EXAMPLES_PATH, f)) 52 | and f.endswith(".py") 53 | and f != "__init__.py" 54 | and re.match(r"^\d", f) 55 | ] 56 | 57 | sys.path.append(EXAMPLES_PATH) 58 | example_tests = [] 59 | example_ids = [] 60 | for example_name in sorted(examples): 61 | try: 62 | example = __import__(example_name) 63 | except ModuleNotFoundError: 64 | warnings.warn( 65 | "could not import {name}, skips during test.".format( 66 | name=example_name 67 | ) 68 | ) 69 | continue 70 | calls_outs = get_calls_from_doc(example.__doc__) 71 | for i, (call, out) in enumerate(calls_outs): 72 | example = reload(example) 73 | example_tests.append((example.ex, call, out)) 74 | example_ids.append("{}_{}".format(example_name, i)) 75 | metafunc.parametrize("example_test", example_tests, ids=example_ids) 76 | 77 | 78 | def pytest_addoption(parser): 79 | parser.addoption( 80 | "--sqlalchemy-connect-url", 81 | action="store", 82 | default="sqlite://", 83 | help="Name of the database to connect to", 84 | ) 85 | 86 | 87 | @pytest.fixture 88 | def tmpfile(): 89 | # NOTE: instead of using a with block and delete=True we are creating and 90 | # manually deleting the file, such that we can close it before running the 91 | # tests. This is necessary since on Windows we can not open the same file 92 | # twice, so for the FileStorageObserver to read it, we need to close it. 
93 | f = tempfile.NamedTemporaryFile(suffix=".py", delete=False) 94 | 95 | f.content = "import sacred\n" 96 | f.write(f.content.encode()) 97 | f.flush() 98 | f.seek(0) 99 | f.md5sum = hashlib.md5(f.read()).hexdigest() 100 | 101 | f.close() 102 | 103 | yield f 104 | 105 | os.remove(f.name) 106 | 107 | 108 | # Deactivate GPU and CPU info to speed up tests 109 | SETTINGS.HOST_INFO.INCLUDE_GPU_INFO = False 110 | SETTINGS.HOST_INFO.INCLUDE_CPU_INFO = False 111 | -------------------------------------------------------------------------------- /tests/dependency_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | A file for testing the gathering of sources and dependency by test_dependencies 5 | """ 6 | 7 | import mock 8 | import pytest 9 | 10 | from tests.foo import bar, mock_extension 11 | 12 | 13 | # Actually this would not work :( 14 | # import tests.foo.bar 15 | 16 | 17 | def some_func(): 18 | pass 19 | 20 | 21 | ignore_this = 17 22 | -------------------------------------------------------------------------------- /tests/donotimport.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | raise RuntimeError("Do NOT import this file!") 5 | -------------------------------------------------------------------------------- /tests/foo/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | A local package used to test the gathering of sources by test_dependencies. 5 | """ 6 | -------------------------------------------------------------------------------- /tests/foo/bar.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | A local imported module for the dependency example. Used to test the gathering 5 | of sources. 6 | """ 7 | 8 | 9 | def test_func(): 10 | pass 11 | -------------------------------------------------------------------------------- /tests/foo/mock_extension.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | This source should not be gathered. It is a regression test for an 5 | exception that happened if a custom pybind11 extension was present. 
6 | """ 7 | 8 | 9 | __file__ = None 10 | -------------------------------------------------------------------------------- /tests/test_arg_parser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | 5 | import pytest 6 | import shlex 7 | from docopt import docopt 8 | 9 | from sacred.arg_parser import _convert_value, get_config_updates, format_usage 10 | from sacred.experiment import gather_command_line_options 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "argv,expected", 15 | [ 16 | ("", {}), 17 | ("run", {"COMMAND": "run"}), 18 | ("with 1 2", {"with": True, "UPDATE": ["1", "2"]}), 19 | ("evaluate", {"COMMAND": "evaluate"}), 20 | ("help", {"help": True}), 21 | ("help evaluate", {"help": True, "COMMAND": "evaluate"}), 22 | ("-h", {"--help": True}), 23 | ("--help", {"--help": True}), 24 | ("-m foo", {"--mongo_db": "foo"}), 25 | ("--mongo_db=bar", {"--mongo_db": "bar"}), 26 | ("-l 10", {"--loglevel": "10"}), 27 | ("--loglevel=30", {"--loglevel": "30"}), 28 | ("--force", {"--force": True}), 29 | ( 30 | "run with a=17 b=1 -m localhost:22222", 31 | { 32 | "COMMAND": "run", 33 | "with": True, 34 | "UPDATE": ["a=17", "b=1"], 35 | "--mongo_db": "localhost:22222", 36 | }, 37 | ), 38 | ( 39 | "evaluate with a=18 b=2 -l30", 40 | { 41 | "COMMAND": "evaluate", 42 | "with": True, 43 | "UPDATE": ["a=18", "b=2"], 44 | "--loglevel": "30", 45 | }, 46 | ), 47 | ("--id=1", {"--id": "1"}), 48 | ], 49 | ) 50 | def test_parse_individual_arguments(argv, expected): 51 | options = gather_command_line_options() 52 | usage = format_usage("test.py", "", {}, options) 53 | argv = shlex.split(argv) 54 | plain = docopt(usage, [], default_help=False) 55 | args = docopt(usage, argv, default_help=False) 56 | plain.update(expected) 57 | assert args == plain 58 | 59 | 60 | @pytest.mark.parametrize( 61 | "update,expected", 62 | [ 63 | (None, {}), 64 | (["a=5"], {"a": 5}), 65 | (["foo.bar=6"], {"foo": {"bar": 6}}), 66 | (["a=9", "b=0"], {"a": 9, "b": 0}), 67 | (["hello='world'"], {"hello": "world"}), 68 | (['hello="world"'], {"hello": "world"}), 69 | (["f=23.5"], {"f": 23.5}), 70 | (["n=None"], {"n": None}), 71 | (["t=True"], {"t": True}), 72 | (["f=False"], {"f": False}), 73 | ], 74 | ) 75 | def test_get_config_updates(update, expected): 76 | assert get_config_updates(update) == (expected, []) 77 | 78 | 79 | @pytest.mark.parametrize( 80 | "value,expected", 81 | [ 82 | ("None", None), 83 | ("True", True), 84 | ("False", False), 85 | ("246", 246), 86 | ("1.0", 1.0), 87 | ("1.", 1.0), 88 | (".1", 0.1), 89 | ("1e3", 1e3), 90 | ("-.4e-12", -0.4e-12), 91 | ("-.4e-12", -0.4e-12), 92 | ("[1,2,3]", [1, 2, 3]), 93 | ("[1.,.1]", [1.0, 0.1]), 94 | ("[True, False]", [True, False]), 95 | ("[None, None]", [None, None]), 96 | ("[1.0,2.0,3.0]", [1.0, 2.0, 3.0]), 97 | ('{"a":1}', {"a": 1}), 98 | ('{"foo":1, "bar":2.0}', {"foo": 1, "bar": 2.0}), 99 | ('{"a":1., "b":.2}', {"a": 1.0, "b": 0.2}), 100 | ('{"a":True, "b":False}', {"a": True, "b": False}), 101 | ('{"a":None}', {"a": None}), 102 | ( 103 | '{"a":[1, 2.0, True, None], "b":"foo"}', 104 | {"a": [1, 2.0, True, None], "b": "foo"}, 105 | ), 106 | ("bob", "bob"), 107 | ('"hello world"', "hello world"), 108 | ("'hello world'", "hello world"), 109 | ], 110 | ) 111 | def test_convert_value(value, expected): 112 | assert _convert_value(value) == expected 113 | -------------------------------------------------------------------------------- /tests/test_config/__init__.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | -------------------------------------------------------------------------------- /tests/test_config/enclosed_config_scope.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | 5 | from sacred.config.config_scope import ConfigScope 6 | 7 | SIX = 6 8 | 9 | 10 | @ConfigScope 11 | def cfg(): 12 | answer = 7 * SIX 13 | 14 | 15 | @ConfigScope 16 | def cfg2(): 17 | answer = 6 * SEVEN 18 | -------------------------------------------------------------------------------- /tests/test_config/test_captured_functions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import datetime 5 | import mock 6 | import random 7 | import sacred.optional as opt 8 | from sacred.config.captured_function import create_captured_function 9 | from sacred.settings import SETTINGS 10 | 11 | 12 | def test_create_captured_function(): 13 | def foo(): 14 | """my docstring""" 15 | return 42 16 | 17 | cf = create_captured_function(foo) 18 | 19 | assert cf.__name__ == "foo" 20 | assert cf.__doc__ == "my docstring" 21 | assert cf.prefix is None 22 | assert cf.config == {} 23 | assert not cf.uses_randomness 24 | assert callable(cf) 25 | 26 | 27 | def test_call_captured_function(): 28 | def foo(a, b, c, d=4, e=5, f=6): 29 | return a, b, c, d, e, f 30 | 31 | cf = create_captured_function(foo) 32 | cf.logger = mock.MagicMock() 33 | cf.config = {"a": 11, "b": 12, "d": 14} 34 | 35 | assert cf(21, c=23, f=26) == (21, 12, 23, 14, 5, 26) 36 | cf.logger.debug.assert_has_calls( 37 | [mock.call("Started"), mock.call("Finished after %s.", datetime.timedelta(0))] 38 | ) 39 | 40 | 41 | def test_captured_function_randomness(): 42 | def foo(_rnd, _seed): 43 | try: 44 | return _rnd.integers(0, 1000), _seed 45 | except Exception: 46 | return _rnd.randint(0, 1000), _seed 47 | 48 | cf = create_captured_function(foo) 49 | assert cf.uses_randomness 50 | cf.logger = mock.MagicMock() 51 | cf.rnd = random.Random(1234) 52 | 53 | nr1, seed1 = cf() 54 | nr2, seed2 = cf() 55 | assert nr1 != nr2 56 | assert seed1 != seed2 57 | 58 | cf.rnd = random.Random(1234) 59 | 60 | assert cf() == (nr1, seed1) 61 | assert cf() == (nr2, seed2) 62 | 63 | 64 | def test_captured_function_numpy_randomness(): 65 | def foo(_rnd, _seed): 66 | return _rnd, _seed 67 | 68 | cf = create_captured_function(foo) 69 | assert cf.uses_randomness 70 | cf.logger = mock.MagicMock() 71 | cf.rnd = random.Random(1234) 72 | 73 | SETTINGS.CONFIG.NUMPY_RANDOM_LEGACY_API = False 74 | rnd, seed = cf() 75 | if opt.has_numpy: 76 | assert type(rnd) == opt.np.random.Generator 77 | 78 | SETTINGS.CONFIG.NUMPY_RANDOM_LEGACY_API = True 79 | rnd, seed = cf() 80 | assert type(rnd) == opt.np.random.RandomState 81 | else: 82 | assert type(rnd) == random.Random 83 | 84 | 85 | def test_captured_function_magic_logger_argument(): 86 | def foo(_log): 87 | return _log 88 | 89 | cf = create_captured_function(foo) 90 | cf.logger = mock.MagicMock() 91 | 92 | assert cf() == cf.logger 93 | 94 | 95 | def test_captured_function_magic_config_argument(): 96 | def foo(_config): 97 | return _config 98 | 99 | cf = create_captured_function(foo) 100 | cf.logger = mock.MagicMock() 101 | cf.config = {"a": 2, "b": 2} 102 | 103 | assert cf() == cf.config 104 | 105 | 106 | def test_captured_function_magic_run_argument(): 107 | def foo(_run): 
108 | return _run 109 | 110 | cf = create_captured_function(foo) 111 | cf.logger = mock.MagicMock() 112 | cf.run = mock.MagicMock() 113 | 114 | assert cf() == cf.run 115 | 116 | 117 | def test_captured_function_call_doesnt_modify_kwargs(): 118 | def foo(a, _log): 119 | if _log is not None: 120 | return a 121 | 122 | cf = create_captured_function(foo) 123 | cf.logger = mock.MagicMock() 124 | cf.run = mock.MagicMock() 125 | 126 | d = {"a": 7} 127 | assert cf(**d) == 7 128 | assert d == {"a": 7} 129 | -------------------------------------------------------------------------------- /tests/test_config/test_config_dict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | 5 | import pytest 6 | import sacred.optional as opt 7 | from sacred.config import ConfigDict 8 | from sacred.config.custom_containers import DogmaticDict, DogmaticList 9 | 10 | 11 | @pytest.fixture 12 | def conf_dict(): 13 | cfg = ConfigDict( 14 | { 15 | "a": 1, 16 | "b": 2.0, 17 | "c": True, 18 | "d": "string", 19 | "e": [1, 2, 3], 20 | "f": {"a": "b", "c": "d"}, 21 | } 22 | ) 23 | return cfg 24 | 25 | 26 | def test_config_dict_returns_dict(conf_dict): 27 | assert isinstance(conf_dict(), dict) 28 | 29 | 30 | def test_config_dict_result_contains_keys(conf_dict): 31 | cfg = conf_dict() 32 | assert set(cfg.keys()) == {"a", "b", "c", "d", "e", "f"} 33 | assert cfg["a"] == 1 34 | assert cfg["b"] == 2.0 35 | assert cfg["c"] 36 | assert cfg["d"] == "string" 37 | assert cfg["e"] == [1, 2, 3] 38 | assert cfg["f"] == {"a": "b", "c": "d"} 39 | 40 | 41 | def test_fixing_values(conf_dict): 42 | assert conf_dict({"a": 100})["a"] == 100 43 | 44 | 45 | @pytest.mark.parametrize("key", ["$f", "contains.dot", "py/tuple", "json://1"]) 46 | def test_config_dict_raises_on_invalid_keys(key): 47 | with pytest.raises(KeyError): 48 | ConfigDict({key: True}) 49 | 50 | 51 | @pytest.mark.parametrize("value", [lambda x: x, pytest, test_fixing_values]) 52 | def test_config_dict_accepts_special_types(value): 53 | assert ConfigDict({"special": value})()["special"] == value 54 | 55 | 56 | def test_fixing_nested_dicts(conf_dict): 57 | cfg = conf_dict({"f": {"c": "t"}}) 58 | assert cfg["f"]["a"] == "b" 59 | assert cfg["f"]["c"] == "t" 60 | 61 | 62 | def test_adding_values(conf_dict): 63 | cfg = conf_dict({"g": 23, "h": {"i": 10}}) 64 | assert cfg["g"] == 23 65 | assert cfg["h"] == {"i": 10} 66 | assert cfg.added == {"g", "h", "h.i"} 67 | 68 | 69 | def test_typechange(conf_dict): 70 | cfg = conf_dict({"a": "bar", "b": "foo", "c": 1}) 71 | assert cfg.typechanged == { 72 | "a": (int, type("bar")), 73 | "b": (float, type("foo")), 74 | "c": (bool, int), 75 | } 76 | 77 | 78 | def test_nested_typechange(conf_dict): 79 | cfg = conf_dict({"f": {"a": 10}}) 80 | assert cfg.typechanged == {"f.a": (type("a"), int)} 81 | 82 | 83 | def is_dogmatic(a): 84 | if isinstance(a, (DogmaticDict, DogmaticList)): 85 | return True 86 | elif isinstance(a, dict): 87 | return any(is_dogmatic(v) for v in a.values()) 88 | elif isinstance(a, (list, tuple)): 89 | return any(is_dogmatic(v) for v in a) 90 | 91 | 92 | def test_result_of_conf_dict_is_not_dogmatic(conf_dict): 93 | cfg = conf_dict({"e": [1, 1, 1]}) 94 | assert not is_dogmatic(cfg) 95 | 96 | 97 | @pytest.mark.skipif(not opt.has_numpy, reason="requires numpy") 98 | def test_conf_scope_handles_numpy_bools(): 99 | cfg = ConfigDict({"a": opt.np.bool_(1)}) 100 | assert "a" in cfg() 101 | assert cfg()["a"] 102 | 103 | 104 | def 
test_conf_scope_contains_presets(): 105 | conf_dict = ConfigDict({"answer": 42}) 106 | cfg = conf_dict(preset={"a": 21, "unrelated": True}) 107 | assert set(cfg.keys()) == {"a", "answer", "unrelated"} 108 | assert cfg["a"] == 21 109 | assert cfg["answer"] == 42 110 | assert cfg["unrelated"] is True 111 | 112 | 113 | def test_conf_scope_does_not_contain_fallback(): 114 | config_dict = ConfigDict({"answer": 42}) 115 | 116 | cfg = config_dict(fallback={"a": 21, "b": 10}) 117 | 118 | assert set(cfg.keys()) == {"answer"} 119 | 120 | 121 | def test_fixed_subentry_of_preset(): 122 | config_dict = ConfigDict({}) 123 | 124 | cfg = config_dict(preset={"d": {"a": 1, "b": 2}}, fixed={"d": {"a": 10}}) 125 | 126 | assert set(cfg.keys()) == {"d"} 127 | assert set(cfg["d"].keys()) == {"a", "b"} 128 | assert cfg["d"]["a"] == 10 129 | assert cfg["d"]["b"] == 2 130 | 131 | 132 | def test_add_config_dict_sequential(): 133 | # https://github.com/IDSIA/sacred/issues/409 134 | 135 | adict = ConfigDict(dict(dictnest2={"key_1": "value_1", "key_2": "value_2"})) 136 | 137 | bdict = ConfigDict( 138 | dict( 139 | dictnest2={"key_2": "update_value_2", "key_3": "value3", "key_4": "value4"} 140 | ) 141 | ) 142 | 143 | final_config = bdict(preset=adict()) 144 | assert final_config == { 145 | "dictnest2": { 146 | "key_1": "value_1", 147 | "key_2": "update_value_2", 148 | "key_3": "value3", 149 | "key_4": "value4", 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /tests/test_config/test_config_files.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | 5 | import os 6 | import re 7 | import tempfile 8 | import pytest 9 | 10 | from sacred.config.config_files import HANDLER_BY_EXT, load_config_file 11 | 12 | data = {"foo": 42, "baz": [1, 0.2, "bar", True, {"some_number": -12, "simon": "hugo"}]} 13 | 14 | 15 | @pytest.mark.parametrize("handler", HANDLER_BY_EXT.values()) 16 | def test_save_and_load(handler): 17 | with tempfile.TemporaryFile("w+" + handler.mode) as f: 18 | handler.dump(data, f) 19 | f.seek(0) # simulates closing and reopening 20 | d = handler.load(f) 21 | assert d == data 22 | 23 | 24 | @pytest.mark.parametrize("ext, handler", HANDLER_BY_EXT.items()) 25 | def test_load_config_file(ext, handler): 26 | handle, f_name = tempfile.mkstemp(suffix=ext) 27 | f = os.fdopen(handle, "w" + handler.mode) 28 | handler.dump(data, f) 29 | f.close() 30 | d = load_config_file(f_name) 31 | assert d == data 32 | os.remove(f_name) 33 | 34 | 35 | def test_load_config_file_exception_msg_invalid_ext(): 36 | handle, f_name = tempfile.mkstemp(suffix=".invalid") 37 | f = os.fdopen(handle, "w") # necessary for windows 38 | f.close() 39 | try: 40 | exception_msg = re.compile( 41 | 'Configuration file ".*.invalid" has invalid or ' 42 | 'unsupported extension ".invalid".' 
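Condensed, the preset behaviour exercised by ``test_conf_scope_contains_presets`` above amounts to this sketch::

    from sacred.config import ConfigDict

    conf = ConfigDict({"answer": 42})
    cfg = conf(preset={"a": 21})
    assert cfg == {"a": 21, "answer": 42}  # presets are kept; fallback values are not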
43 | ) 44 | with pytest.raises(ValueError) as excinfo: 45 | load_config_file(f_name) 46 | assert exception_msg.match(excinfo.value.args[0]) 47 | finally: 48 | os.remove(f_name) 49 | -------------------------------------------------------------------------------- /tests/test_config/test_config_scope_chain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | 5 | import pytest 6 | from sacred.config import ConfigScope, ConfigDict, chain_evaluate_config_scopes 7 | 8 | 9 | def test_chained_config_scopes_contain_combined_keys(): 10 | @ConfigScope 11 | def cfg1(): 12 | a = 10 13 | 14 | @ConfigScope 15 | def cfg2(): 16 | b = 20 17 | 18 | final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2]) 19 | assert set(final_cfg.keys()) == {"a", "b"} 20 | assert final_cfg["a"] == 10 21 | assert final_cfg["b"] == 20 22 | 23 | 24 | def test_chained_config_scopes_can_access_previous_keys(): 25 | @ConfigScope 26 | def cfg1(): 27 | a = 10 28 | 29 | @ConfigScope 30 | def cfg2(a): 31 | b = 2 * a 32 | 33 | final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2]) 34 | assert set(final_cfg.keys()) == {"a", "b"} 35 | assert final_cfg["a"] == 10 36 | 37 | 38 | def test_chained_config_scopes_can_modify_previous_keys(): 39 | @ConfigScope 40 | def cfg1(): 41 | a = 10 42 | b = 20 43 | 44 | @ConfigScope 45 | def cfg2(a): 46 | a *= 2 47 | b = 22 48 | 49 | final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2]) 50 | assert set(final_cfg.keys()) == {"a", "b"} 51 | assert final_cfg["a"] == 20 52 | assert final_cfg["b"] == 22 53 | 54 | 55 | def test_chained_config_scopes_raise_for_undeclared_previous_keys(): 56 | @ConfigScope 57 | def cfg1(): 58 | a = 10 59 | 60 | @ConfigScope 61 | def cfg2(): 62 | b = a * 2 63 | 64 | with pytest.raises(NameError): 65 | chain_evaluate_config_scopes([cfg1, cfg2]) 66 | 67 | 68 | def test_chained_config_scopes_cannot_modify_fixed(): 69 | @ConfigScope 70 | def cfg1(): 71 | c = 10 72 | a = c * 2 73 | 74 | @ConfigScope 75 | def cfg2(c): 76 | b = 4 * c 77 | c *= 3 78 | 79 | final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2], fixed={"c": 5}) 80 | assert set(final_cfg.keys()) == {"a", "b", "c"} 81 | assert final_cfg["a"] == 10 82 | assert final_cfg["b"] == 20 83 | assert final_cfg["c"] == 5 84 | 85 | 86 | def test_chained_config_scopes_can_access_preset(): 87 | @ConfigScope 88 | def cfg1(c): 89 | a = 10 + c 90 | 91 | @ConfigScope 92 | def cfg2(a, c): 93 | b = a * 2 + c 94 | 95 | final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2], preset={"c": 32}) 96 | assert set(final_cfg.keys()) == {"a", "b", "c"} 97 | assert final_cfg["a"] == 42 98 | assert final_cfg["b"] == 116 99 | assert final_cfg["c"] == 32 100 | 101 | 102 | def test_chained_config_scopes_can_access_fallback(): 103 | @ConfigScope 104 | def cfg1(c): 105 | a = 10 + c 106 | 107 | @ConfigScope 108 | def cfg2(a, c): 109 | b = a * 2 + c 110 | 111 | final_cfg, summary = chain_evaluate_config_scopes([cfg1, cfg2], fallback={"c": 32}) 112 | assert set(final_cfg.keys()) == {"a", "b"} 113 | assert final_cfg["a"] == 42 114 | assert final_cfg["b"] == 116 115 | 116 | 117 | def test_chained_config_scopes_fix_subentries(): 118 | @ConfigScope 119 | def cfg1(): 120 | d = {"a": 10, "b": 20} 121 | 122 | @ConfigScope 123 | def cfg2(): 124 | pass 125 | 126 | final_cfg, summary = chain_evaluate_config_scopes( 127 | [cfg1, cfg2], fixed={"d": {"a": 0}} 128 | ) 129 | assert set(final_cfg["d"].keys()) == {"a", "b"} 130 | assert 
final_cfg["d"]["a"] == 0 131 | assert final_cfg["d"]["b"] == 20 132 | 133 | 134 | def test_empty_chain_contains_preset_and_fixed(): 135 | final_cfg, summary = chain_evaluate_config_scopes( 136 | [], fixed={"a": 0}, preset={"a": 1, "b": 2} 137 | ) 138 | assert set(final_cfg.keys()) == {"a", "b"} 139 | assert final_cfg["a"] == 0 140 | assert final_cfg["b"] == 2 141 | 142 | 143 | def test_add_config_dict_sequential(): 144 | # https://github.com/IDSIA/sacred/issues/409 145 | @ConfigScope 146 | def cfg1(): 147 | dictnest2 = {"key_1": "value_1", "key_2": "value_2"} 148 | 149 | cfg1dict = ConfigDict(cfg1()) 150 | 151 | @ConfigScope 152 | def cfg2(): 153 | dictnest2 = {"key_2": "update_value_2", "key_3": "value3", "key_4": "value4"} 154 | 155 | cfg2dict = ConfigDict(cfg2()) 156 | final_config_scope, _ = chain_evaluate_config_scopes([cfg1, cfg2]) 157 | assert final_config_scope == { 158 | "dictnest2": { 159 | "key_1": "value_1", 160 | "key_2": "update_value_2", 161 | "key_3": "value3", 162 | "key_4": "value4", 163 | } 164 | } 165 | 166 | final_config_dict, _ = chain_evaluate_config_scopes([cfg1dict, cfg2dict]) 167 | assert final_config_dict == final_config_scope 168 | -------------------------------------------------------------------------------- /tests/test_config/test_dogmatic_dict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | 5 | import pytest 6 | from sacred.config.custom_containers import DogmaticDict 7 | 8 | 9 | def test_isinstance_of_dict(): 10 | assert isinstance(DogmaticDict(), dict) 11 | 12 | 13 | def test_dict_interface_initialized_empty(): 14 | d = DogmaticDict() 15 | assert d == {} 16 | assert set(d.keys()) == set() 17 | assert set(d.values()) == set() 18 | assert set(d.items()) == set() 19 | 20 | 21 | def test_dict_interface_set_item(): 22 | d = DogmaticDict() 23 | d["a"] = 12 24 | d["b"] = "foo" 25 | assert "a" in d 26 | assert "b" in d 27 | 28 | assert d["a"] == 12 29 | assert d["b"] == "foo" 30 | 31 | assert set(d.keys()) == {"a", "b"} 32 | assert set(d.values()) == {12, "foo"} 33 | assert set(d.items()) == {("a", 12), ("b", "foo")} 34 | 35 | 36 | def test_dict_interface_del_item(): 37 | d = DogmaticDict() 38 | d["a"] = 12 39 | del d["a"] 40 | assert "a" not in d 41 | 42 | 43 | def test_dict_interface_update_with_dict(): 44 | d = DogmaticDict() 45 | d["a"] = 12 46 | d["b"] = "foo" 47 | 48 | d.update({"a": 1, "c": 2}) 49 | assert d["a"] == 1 50 | assert d["b"] == "foo" 51 | assert d["c"] == 2 52 | 53 | 54 | def test_dict_interface_update_with_kwargs(): 55 | d = DogmaticDict() 56 | d["a"] = 12 57 | d["b"] = "foo" 58 | d.update(a=2, b=3) 59 | assert d["a"] == 2 60 | assert d["b"] == 3 61 | 62 | 63 | def test_dict_interface_update_with_list_of_items(): 64 | d = DogmaticDict() 65 | d["a"] = 12 66 | d["b"] = "foo" 67 | d.update([("b", 9), ("c", 7)]) 68 | assert d["a"] == 12 69 | assert d["b"] == 9 70 | assert d["c"] == 7 71 | 72 | 73 | def test_fixed_value_not_initialized(): 74 | d = DogmaticDict({"a": 7}) 75 | assert "a" not in d 76 | 77 | 78 | def test_fixed_value_fixed(): 79 | d = DogmaticDict({"a": 7}) 80 | d["a"] = 8 81 | assert d["a"] == 7 82 | 83 | del d["a"] 84 | assert "a" in d 85 | assert d["a"] == 7 86 | 87 | d.update([("a", 9), ("b", 12)]) 88 | assert d["a"] == 7 89 | 90 | d.update({"a": 9, "b": 12}) 91 | assert d["a"] == 7 92 | 93 | d.update(a=10, b=13) 94 | assert d["a"] == 7 95 | 96 | 97 | def test_revelation(): 98 | d = DogmaticDict({"a": 7, "b": 12}) 99 | d["b"] = 23 100 | assert 
"a" not in d 101 | m = d.revelation() 102 | assert set(m) == {"a"} 103 | assert "a" in d 104 | 105 | 106 | def test_fallback(): 107 | d = DogmaticDict(fallback={"a": 23}) 108 | assert "a" in d 109 | assert d["a"] == 23 110 | assert d.get("a") == 23 111 | 112 | d = DogmaticDict() 113 | d.fallback = {"a": 23} 114 | assert "a" in d 115 | assert d["a"] == 23 116 | assert d.get("a") == 23 117 | 118 | 119 | def test_fallback_not_iterated(): 120 | d = DogmaticDict(fallback={"a": 23}) 121 | d["b"] = 1234 122 | assert list(d.keys()) == ["b"] 123 | assert list(d.values()) == [1234] 124 | assert list(d.items()) == [("b", 1234)] 125 | 126 | 127 | def test_overwrite_fallback(): 128 | d = DogmaticDict(fallback={"a": 23}) 129 | d["a"] = 0 130 | assert d["a"] == 0 131 | assert list(d.keys()) == ["a"] 132 | assert list(d.values()) == [0] 133 | assert list(d.items()) == [("a", 0)] 134 | 135 | 136 | def test_fixed_has_precedence_over_fallback(): 137 | d = DogmaticDict(fixed={"a": 0}, fallback={"a": 23}) 138 | assert d["a"] == 0 139 | 140 | 141 | def test_nested_fixed_merges_with_fallback(): 142 | d = DogmaticDict(fixed={"foo": {"bar": 20}}, fallback={"foo": {"bar": 10, "c": 5}}) 143 | assert d["foo"]["bar"] == 20 144 | assert d["foo"]["c"] == 5 145 | 146 | 147 | def test_nested_fixed_with_fallback_madness(): 148 | d = DogmaticDict(fixed={"foo": {"bar": 20}}, fallback={"foo": {"bar": 10, "c": 5}}) 149 | d["foo"] = {"bar": 30, "a": 1} 150 | assert d["foo"]["bar"] == 20 151 | assert d["foo"]["a"] == 1 152 | assert d["foo"]["c"] == 5 153 | -------------------------------------------------------------------------------- /tests/test_config/test_dogmatic_list.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | 5 | import pytest 6 | from sacred.config.custom_containers import DogmaticDict, DogmaticList 7 | 8 | 9 | def test_isinstance_of_list(): 10 | assert isinstance(DogmaticList(), list) 11 | 12 | 13 | def test_init(): 14 | l = DogmaticList() 15 | assert l == [] 16 | 17 | l2 = DogmaticList([2, 3, 1]) 18 | assert l2 == [2, 3, 1] 19 | 20 | 21 | def test_append(): 22 | l = DogmaticList([1, 2]) 23 | l.append(3) 24 | l.append(4) 25 | assert l == [1, 2] 26 | 27 | 28 | def test_extend(): 29 | l = DogmaticList([1, 2]) 30 | l.extend([3, 4]) 31 | assert l == [1, 2] 32 | 33 | 34 | def test_insert(): 35 | l = DogmaticList([1, 2]) 36 | l.insert(1, 17) 37 | assert l == [1, 2] 38 | 39 | 40 | def test_pop(): 41 | l = DogmaticList([1, 2, 3]) 42 | with pytest.raises(TypeError): 43 | l.pop() 44 | assert l == [1, 2, 3] 45 | 46 | 47 | def test_sort(): 48 | l = DogmaticList([3, 1, 2]) 49 | l.sort() 50 | assert l == [3, 1, 2] 51 | 52 | 53 | def test_reverse(): 54 | l = DogmaticList([1, 2, 3]) 55 | l.reverse() 56 | assert l == [1, 2, 3] 57 | 58 | 59 | def test_setitem(): 60 | l = DogmaticList([1, 2, 3]) 61 | l[1] = 23 62 | assert l == [1, 2, 3] 63 | 64 | 65 | def test_setslice(): 66 | l = DogmaticList([1, 2, 3]) 67 | l[1:3] = [4, 5] 68 | assert l == [1, 2, 3] 69 | 70 | 71 | def test_delitem(): 72 | l = DogmaticList([1, 2, 3]) 73 | del l[1] 74 | assert l == [1, 2, 3] 75 | 76 | 77 | def test_delslice(): 78 | l = DogmaticList([1, 2, 3]) 79 | del l[1:] 80 | assert l == [1, 2, 3] 81 | 82 | 83 | def test_iadd(): 84 | l = DogmaticList([1, 2]) 85 | l += [3, 4] 86 | assert l == [1, 2] 87 | 88 | 89 | def test_imul(): 90 | l = DogmaticList([1, 2]) 91 | l *= 4 92 | assert l == [1, 2] 93 | 94 | 95 | def test_list_interface_getitem(): 96 | l = DogmaticList([0, 1, 2]) 
97 | assert l[0] == 0 98 | assert l[1] == 1 99 | assert l[2] == 2 100 | 101 | assert l[-1] == 2 102 | assert l[-2] == 1 103 | assert l[-3] == 0 104 | 105 | 106 | def test_list_interface_len(): 107 | l = DogmaticList() 108 | assert len(l) == 0 109 | l = DogmaticList([0, 1, 2]) 110 | assert len(l) == 3 111 | 112 | 113 | def test_list_interface_count(): 114 | l = DogmaticList([1, 2, 4, 4, 5]) 115 | assert l.count(1) == 1 116 | assert l.count(3) == 0 117 | assert l.count(4) == 2 118 | 119 | 120 | def test_list_interface_index(): 121 | l = DogmaticList([1, 2, 4, 4, 5]) 122 | assert l.index(1) == 0 123 | assert l.index(4) == 2 124 | assert l.index(5) == 4 125 | with pytest.raises(ValueError): 126 | l.index(3) 127 | 128 | 129 | def test_empty_revelation(): 130 | l = DogmaticList([1, 2, 3]) 131 | assert l.revelation() == set() 132 | 133 | 134 | def test_nested_dict_revelation(): 135 | d1 = DogmaticDict({"a": 7, "b": 12}) 136 | d2 = DogmaticDict({"c": 7}) 137 | l = DogmaticList([d1, 2, d2]) 138 | # assert l.revelation() == {'0.a', '0.b', '2.c'} 139 | l.revelation() 140 | assert "a" in l[0] 141 | assert "b" in l[0] 142 | assert "c" in l[2] 143 | -------------------------------------------------------------------------------- /tests/test_config/test_fallback_dict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import pytest 5 | from sacred.config.custom_containers import fallback_dict 6 | 7 | 8 | @pytest.fixture 9 | def fbdict(): 10 | return fallback_dict({"fall1": 7, "fall3": True}) 11 | 12 | 13 | def test_is_dictionary(fbdict): 14 | assert isinstance(fbdict, dict) 15 | 16 | 17 | def test_getitem(fbdict): 18 | assert "foo" not in fbdict 19 | fbdict["foo"] = 23 20 | assert "foo" in fbdict 21 | assert fbdict["foo"] == 23 22 | 23 | 24 | def test_fallback(fbdict): 25 | assert "fall1" in fbdict 26 | assert fbdict["fall1"] == 7 27 | fbdict["fall1"] = 8 28 | assert fbdict["fall1"] == 8 29 | 30 | 31 | def test_get(fbdict): 32 | fbdict["a"] = "b" 33 | assert fbdict.get("a", 18) == "b" 34 | assert fbdict.get("fall1", 18) == 7 35 | assert fbdict.get("notexisting", 18) == 18 36 | assert fbdict.get("fall3", 18) is True 37 | -------------------------------------------------------------------------------- /tests/test_config/test_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import pytest 5 | from sacred import optional as opt 6 | from sacred.config.utils import normalize_or_die 7 | 8 | 9 | @pytest.mark.skipif(not opt.has_numpy, reason="requires numpy") 10 | @pytest.mark.parametrize( 11 | "typename", 12 | [ 13 | "bool_", 14 | "int_", 15 | "intc", 16 | "intp", 17 | "int8", 18 | "int16", 19 | "int32", 20 | "int64", 21 | "uint8", 22 | "uint16", 23 | "uint32", 24 | "uint64", 25 | "float16", 26 | "float32", 27 | "float64", 28 | ], 29 | ) 30 | def test_normalize_or_die_for_numpy_datatypes(typename): 31 | dtype = getattr(opt.np, typename) 32 | assert normalize_or_die(dtype(7.0)) 33 | 34 | 35 | @pytest.mark.skipif(not opt.has_numpy, reason="requires numpy") 36 | @pytest.mark.parametrize( 37 | "typename", 38 | [ 39 | "bool_", 40 | "int_", 41 | "intc", 42 | "intp", 43 | "int8", 44 | "int16", 45 | "int32", 46 | "int64", 47 | "uint8", 48 | "uint16", 49 | "uint32", 50 | "uint64", 51 | "float16", 52 | "float32", 53 | "float64", 54 | ], 55 | ) 56 | def test_normalize_or_die_for_numpy_arrays(typename): 57 | np = opt.np 58 | dtype = getattr(np, 
typename) 59 | a = np.array([0, 1, 2], dtype=dtype) 60 | b = normalize_or_die(a) 61 | assert len(b) == 3 62 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import logging 5 | 6 | 7 | # example_test will be parametrized by the test generation hook in conftest.py 8 | def test_example(capsys, example_test): 9 | # pytest adds a `LogCaptureHandler` to the root logger. This conflicts 10 | # with the sacred logger setup, so remove it 11 | logging.root.handlers = [] 12 | ex, call, out = example_test 13 | ex.run_commandline(call) 14 | captured_out, captured_err = capsys.readouterr() 15 | print(captured_out) 16 | print(captured_err) 17 | captured_out = captured_out.split("\n") 18 | captured_err = captured_err.split("\n") 19 | for out_line in out: 20 | assert out_line in [captured_out[0], captured_err[0]] 21 | if out_line == captured_out[0]: 22 | captured_out.pop(0) 23 | else: 24 | captured_err.pop(0) 25 | assert captured_out == [""] 26 | assert captured_err == [""] 27 | -------------------------------------------------------------------------------- /tests/test_exceptions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import re 5 | 6 | from sacred import Ingredient, Experiment 7 | from sacred.utils import ( 8 | CircularDependencyError, 9 | ConfigAddedError, 10 | MissingConfigError, 11 | NamedConfigNotFoundError, 12 | format_filtered_stacktrace, 13 | format_sacred_error, 14 | SacredError, 15 | ) 16 | 17 | """Global Docstring""" 18 | 19 | import pytest 20 | 21 | 22 | def test_circular_dependency_raises(): 23 | # create experiment with circular dependency 24 | ing = Ingredient("ing") 25 | ex = Experiment("exp", ingredients=[ing]) 26 | ex.main(lambda: None) 27 | ing.ingredients.append(ex) 28 | 29 | # run and see if it raises 30 | with pytest.raises(CircularDependencyError, match="exp->ing->exp"): 31 | ex.run() 32 | 33 | 34 | def test_config_added_raises(): 35 | ex = Experiment("exp") 36 | ex.main(lambda: None) 37 | 38 | with pytest.raises( 39 | ConfigAddedError, 40 | match=r"Added new config entry that is not used anywhere.*\n" 41 | r"\s*Conflicting configuration values:\n" 42 | r"\s*a=42", 43 | ): 44 | ex.run(config_updates={"a": 42}) 45 | 46 | 47 | def test_missing_config_raises(): 48 | ex = Experiment("exp") 49 | ex.main(lambda a: None) 50 | with pytest.raises(MissingConfigError): 51 | ex.run() 52 | 53 | 54 | def test_named_config_not_found_raises(): 55 | ex = Experiment("exp") 56 | ex.main(lambda: None) 57 | with pytest.raises( 58 | NamedConfigNotFoundError, 59 | match='Named config not found: "not_there". 
' "Available config values are:", 60 | ): 61 | ex.run(named_configs=("not_there",)) 62 | 63 | 64 | def test_format_filtered_stacktrace_true(): 65 | ex = Experiment("exp") 66 | 67 | @ex.capture 68 | def f(): 69 | raise Exception() 70 | 71 | try: 72 | f() 73 | except Exception: 74 | st = format_filtered_stacktrace(filter_traceback="default") 75 | assert "captured_function" not in st 76 | assert "WITHOUT Sacred internals" in st 77 | 78 | try: 79 | f() 80 | except Exception: 81 | st = format_filtered_stacktrace(filter_traceback="always") 82 | assert "captured_function" not in st 83 | assert "WITHOUT Sacred internals" in st 84 | 85 | 86 | def test_format_filtered_stacktrace_false(): 87 | ex = Experiment("exp") 88 | 89 | @ex.capture 90 | def f(): 91 | raise Exception() 92 | 93 | try: 94 | f() 95 | except: 96 | st = format_filtered_stacktrace(filter_traceback="never") 97 | assert "captured_function" in st 98 | 99 | 100 | @pytest.mark.parametrize( 101 | "print_traceback,filter_traceback,print_usage,expected", 102 | [ 103 | (False, "never", False, ".*SacredError: message"), 104 | ( 105 | True, 106 | "never", 107 | False, 108 | r"Traceback \(most recent call last\):\n*" 109 | r'\s*File ".*", line \d*, in ' 110 | r"test_format_sacred_error\n*" 111 | r".*\n*" 112 | r".*SacredError: message", 113 | ), 114 | (False, "default", False, r".*SacredError: message"), 115 | (False, "always", False, r".*SacredError: message"), 116 | (False, "never", True, r"usage\n.*SacredError: message"), 117 | ( 118 | True, 119 | "default", 120 | False, 121 | r"Traceback \(most recent calls WITHOUT " 122 | r"Sacred internals\):\n*" 123 | r"(\n|.)*" 124 | r".*SacredError: message", 125 | ), 126 | ( 127 | True, 128 | "always", 129 | False, 130 | r"Traceback \(most recent calls WITHOUT " 131 | r"Sacred internals\):\n*" 132 | r"(\n|.)*" 133 | r".*SacredError: message", 134 | ), 135 | ], 136 | ) 137 | def test_format_sacred_error(print_traceback, filter_traceback, print_usage, expected): 138 | try: 139 | raise SacredError("message", print_traceback, filter_traceback, print_usage) 140 | except SacredError as e: 141 | st = format_sacred_error(e, "usage") 142 | assert re.match(expected, st, re.MULTILINE) 143 | 144 | 145 | def test_chained_error(): 146 | try: 147 | try: 148 | print(1 / 0) 149 | except Exception as e: 150 | raise SacredError("Something bad happened") from e 151 | except SacredError as e: 152 | st = format_sacred_error(e, "usage") 153 | assert re.match( 154 | r"Traceback \(most recent calls WITHOUT Sacred internals\):\s+File " 155 | + r"\"[^\"]+?test_exceptions.py\", line \d+, in test_chained_error\s+" 156 | + r"print\(1 / 0\)\n(\s+~~\^~~\s+)?ZeroDivisionError: division by zero\n\n" 157 | + r"The above exception was the direct cause of the following exception:\n" 158 | + r"\nTraceback \(most recent calls WITHOUT Sacred internals\):\s+File " 159 | + r"\"[^\"]+?test_exceptions.py\", line \d+, in test_chained_error\s+raise " 160 | + r"SacredError\(\"Something bad happened\"\) from e\nsacred.utils." 
161 | + r"SacredError: Something bad happened\n", 162 | st, 163 | re.MULTILINE, 164 | ) 165 | -------------------------------------------------------------------------------- /tests/test_host_info.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import pytest 5 | 6 | from sacred.host_info import get_host_info, host_info_getter, host_info_gatherers 7 | 8 | 9 | def test_get_host_info(monkeypatch: pytest.MonkeyPatch): 10 | with monkeypatch.context() as cntx: 11 | cntx.setattr("sacred.settings.SETTINGS.HOST_INFO.INCLUDE_CPU_INFO", True) 12 | host_info = get_host_info() 13 | assert isinstance(host_info["hostname"], str) 14 | assert isinstance(host_info["cpu"], str) 15 | assert host_info["cpu"] != "Unknown" 16 | assert isinstance(host_info["os"], (tuple, list)) 17 | assert isinstance(host_info["python_version"], str) 18 | 19 | 20 | def test_host_info_decorator(): 21 | try: 22 | assert "greeting" not in host_info_gatherers 23 | 24 | @host_info_getter 25 | def greeting(): 26 | return "hello" 27 | 28 | assert "greeting" in host_info_gatherers 29 | assert host_info_gatherers["greeting"] == greeting 30 | assert get_host_info()["greeting"] == "hello" 31 | finally: 32 | del host_info_gatherers["greeting"] 33 | 34 | 35 | def test_host_info_decorator_with_name(): 36 | try: 37 | assert "foo" not in host_info_gatherers 38 | 39 | @host_info_getter(name="foo") 40 | def greeting(): 41 | return "hello" 42 | 43 | assert "foo" in host_info_gatherers 44 | assert "greeting" not in host_info_gatherers 45 | assert host_info_gatherers["foo"] == greeting 46 | assert get_host_info()["foo"] == "hello" 47 | finally: 48 | del host_info_gatherers["foo"] 49 | 50 | 51 | def test_host_info_decorator_deprecation_warning(): 52 | try: 53 | assert "foo" not in host_info_gatherers 54 | 55 | with pytest.warns(DeprecationWarning): 56 | 57 | @host_info_getter(name="foo") 58 | def greeting(): 59 | return "hello" 60 | 61 | finally: 62 | del host_info_gatherers["foo"] 63 | -------------------------------------------------------------------------------- /tests/test_metrics_logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import datetime 5 | import pytest 6 | from sacred import Experiment 7 | from sacred.metrics_logger import ScalarMetricLogEntry, linearize_metrics 8 | 9 | 10 | @pytest.fixture() 11 | def ex(): 12 | return Experiment("Test experiment") 13 | 14 | 15 | def test_log_scalar_metric_with_run(ex): 16 | START = 10 17 | END = 100 18 | STEP_SIZE = 5 19 | messages = {} 20 | 21 | @ex.main 22 | def main_function(_run): 23 | # First, make sure the queue is empty: 24 | assert len(ex.current_run._metrics.get_last_metrics()) == 0 25 | for i in range(START, END, STEP_SIZE): 26 | val = i * i 27 | _run.log_scalar("training.loss", val, i) 28 | messages["messages"] = ex.current_run._metrics.get_last_metrics() 29 | """Calling get_last_metrics clears the metrics logger internal queue.
30 | If we don't call it here, it would be called during Sacred heartbeat 31 | event after the run finishes, and the data we want to test would 32 | be lost.""" 33 | 34 | ex.run() 35 | assert ex.current_run is not None 36 | messages = messages["messages"] 37 | assert len(messages) == (END - START) / STEP_SIZE 38 | for i in range(len(messages) - 1): 39 | assert messages[i].step < messages[i + 1].step 40 | assert messages[i].step == START + i * STEP_SIZE 41 | assert messages[i].timestamp <= messages[i + 1].timestamp 42 | 43 | 44 | def test_log_scalar_metric_with_ex(ex): 45 | messages = {} 46 | START = 10 47 | END = 100 48 | STEP_SIZE = 5 49 | 50 | @ex.main 51 | def main_function(_run): 52 | for i in range(START, END, STEP_SIZE): 53 | val = i * i 54 | ex.log_scalar("training.loss", val, i) 55 | messages["messages"] = ex.current_run._metrics.get_last_metrics() 56 | 57 | ex.run() 58 | assert ex.current_run is not None 59 | messages = messages["messages"] 60 | assert len(messages) == (END - START) / STEP_SIZE 61 | for i in range(len(messages) - 1): 62 | assert messages[i].step < messages[i + 1].step 63 | assert messages[i].step == START + i * STEP_SIZE 64 | assert messages[i].timestamp <= messages[i + 1].timestamp 65 | 66 | 67 | def test_log_scalar_metric_with_implicit_step(ex): 68 | messages = {} 69 | 70 | @ex.main 71 | def main_function(_run): 72 | for i in range(10): 73 | val = i * i 74 | ex.log_scalar("training.loss", val) 75 | messages["messages"] = ex.current_run._metrics.get_last_metrics() 76 | 77 | ex.run() 78 | assert ex.current_run is not None 79 | messages = messages["messages"] 80 | assert len(messages) == 10 81 | for i in range(len(messages) - 1): 82 | assert messages[i].step < messages[i + 1].step 83 | assert messages[i].step == i 84 | assert messages[i].timestamp <= messages[i + 1].timestamp 85 | 86 | 87 | def test_log_scalar_metrics_with_implicit_step(ex): 88 | messages = {} 89 | 90 | @ex.main 91 | def main_function(_run): 92 | for i in range(10): 93 | val = i * i 94 | ex.log_scalar("training.loss", val) 95 | ex.log_scalar("training.accuracy", val + 1) 96 | messages["messages"] = ex.current_run._metrics.get_last_metrics() 97 | 98 | ex.run() 99 | assert ex.current_run is not None 100 | messages = messages["messages"] 101 | tr_loss_messages = [m for m in messages if m.name == "training.loss"] 102 | tr_acc_messages = [m for m in messages if m.name == "training.accuracy"] 103 | 104 | assert len(tr_loss_messages) == 10 105 | # both should have 10 records 106 | assert len(tr_acc_messages) == len(tr_loss_messages) 107 | for i in range(len(tr_loss_messages) - 1): 108 | assert tr_loss_messages[i].step < tr_loss_messages[i + 1].step 109 | assert tr_loss_messages[i].step == i 110 | assert tr_loss_messages[i].timestamp <= tr_loss_messages[i + 1].timestamp 111 | 112 | assert tr_acc_messages[i].step < tr_acc_messages[i + 1].step 113 | assert tr_acc_messages[i].step == i 114 | assert tr_acc_messages[i].timestamp <= tr_acc_messages[i + 1].timestamp 115 | 116 | 117 | def test_linearize_metrics(): 118 | entries = [ 119 | ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(), 100), 120 | ScalarMetricLogEntry("training.accuracy", 5, datetime.datetime.utcnow(), 50), 121 | ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(), 200), 122 | ScalarMetricLogEntry("training.accuracy", 10, datetime.datetime.utcnow(), 100), 123 | ScalarMetricLogEntry("training.accuracy", 15, datetime.datetime.utcnow(), 150), 124 | ScalarMetricLogEntry("training.accuracy", 30, 
datetime.datetime.utcnow(), 300), 125 | ] 126 | linearized = linearize_metrics(entries) 127 | assert type(linearized) == dict 128 | assert len(linearized.keys()) == 2 129 | assert "training.loss" in linearized 130 | assert "training.accuracy" in linearized 131 | assert len(linearized["training.loss"]["steps"]) == 2 132 | assert len(linearized["training.loss"]["values"]) == 2 133 | assert len(linearized["training.loss"]["timestamps"]) == 2 134 | assert len(linearized["training.accuracy"]["steps"]) == 4 135 | assert len(linearized["training.accuracy"]["values"]) == 4 136 | assert len(linearized["training.accuracy"]["timestamps"]) == 4 137 | assert linearized["training.accuracy"]["steps"] == [5, 10, 15, 30] 138 | assert linearized["training.accuracy"]["values"] == [50, 100, 150, 300] 139 | assert linearized["training.loss"]["steps"] == [10, 20] 140 | assert linearized["training.loss"]["values"] == [100, 200] 141 | -------------------------------------------------------------------------------- /tests/test_modules.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | from sacred.config.config_scope import ConfigScope 5 | from sacred.experiment import Experiment, Ingredient 6 | 7 | 8 | def test_ingredient_config(): 9 | m = Ingredient("somemod") 10 | 11 | @m.config 12 | def cfg(): 13 | a = 5 14 | b = "foo" 15 | 16 | assert len(m.configurations) == 1 17 | cfg = m.configurations[0] 18 | assert isinstance(cfg, ConfigScope) 19 | assert cfg() == {"a": 5, "b": "foo"} 20 | 21 | 22 | def test_ingredient_captured_functions(): 23 | m = Ingredient("somemod") 24 | 25 | @m.capture 26 | def get_answer(b): 27 | return b 28 | 29 | assert len(m.captured_functions) == 1 30 | f = m.captured_functions[0] 31 | assert f == get_answer 32 | 33 | 34 | def test_ingredient_command(): 35 | m = Ingredient("somemod") 36 | 37 | m.add_config(a=42, b="foo{}") 38 | 39 | @m.command 40 | def transmogrify(a, b): 41 | return b.format(a) 42 | 43 | assert "transmogrify" in m.commands 44 | assert m.commands["transmogrify"] == transmogrify 45 | ex = Experiment("foo", ingredients=[m]) 46 | 47 | assert ex.run("somemod.transmogrify").result == "foo42" 48 | 49 | 50 | # ############# Experiment #################################################### 51 | 52 | 53 | def test_experiment_run(): 54 | ex = Experiment("some_experiment") 55 | 56 | @ex.main 57 | def main(): 58 | return 12 59 | 60 | assert ex.run().result == 12 61 | 62 | 63 | def test_experiment_run_access_subingredient(): 64 | somemod = Ingredient("somemod") 65 | 66 | @somemod.config 67 | def cfg(): 68 | a = 5 69 | b = "foo" 70 | 71 | ex = Experiment("some_experiment", ingredients=[somemod]) 72 | 73 | @ex.main 74 | def main(somemod): 75 | return somemod 76 | 77 | r = ex.run().result 78 | assert r["a"] == 5 79 | assert r["b"] == "foo" 80 | 81 | 82 | def test_experiment_run_subingredient_function(): 83 | somemod = Ingredient("somemod") 84 | 85 | @somemod.config 86 | def cfg(): 87 | a = 5 88 | b = "foo" 89 | 90 | @somemod.capture 91 | def get_answer(b): 92 | return b 93 | 94 | ex = Experiment("some_experiment", ingredients=[somemod]) 95 | 96 | @ex.main 97 | def main(): 98 | return get_answer() 99 | 100 | assert ex.run().result == "foo" 101 | 102 | 103 | def test_experiment_named_config_subingredient(): 104 | somemod = Ingredient("somemod") 105 | 106 | @somemod.config 107 | def sub_cfg(): 108 | a = 15 109 | 110 | @somemod.capture 111 | def get_answer(a): 112 | return a 113 | 114 | @somemod.named_config 115 | 
def nsubcfg(): 116 | a = 16 117 | 118 | ex = Experiment("some_experiment", ingredients=[somemod]) 119 | 120 | @ex.config 121 | def cfg(): 122 | a = 1 123 | 124 | @ex.named_config 125 | def ncfg(): 126 | a = 2 127 | somemod = {"a": 25} 128 | 129 | @ex.main 130 | def main(a): 131 | return a, get_answer() 132 | 133 | assert ex.run().result == (1, 15) 134 | assert ex.run(named_configs=["somemod.nsubcfg"]).result == (1, 16) 135 | assert ex.run(named_configs=["ncfg"]).result == (2, 25) 136 | assert ex.run(named_configs=["ncfg", "somemod.nsubcfg"]).result == (2, 16) 137 | assert ex.run(named_configs=["somemod.nsubcfg", "ncfg"]).result == (2, 25) 138 | 139 | 140 | def test_experiment_named_config_subingredient_overwrite(): 141 | somemod = Ingredient("somemod") 142 | 143 | @somemod.capture 144 | def get_answer(a): 145 | return a 146 | 147 | ex = Experiment("some_experiment", ingredients=[somemod]) 148 | 149 | @ex.named_config 150 | def ncfg(): 151 | somemod = {"a": 1} 152 | 153 | @ex.main 154 | def main(): 155 | return get_answer() 156 | 157 | assert ex.run(named_configs=["ncfg"]).result == 1 158 | assert ex.run(config_updates={"somemod": {"a": 2}}).result == 2 159 | assert ( 160 | ex.run(named_configs=["ncfg"], config_updates={"somemod": {"a": 2}}).result == 2 161 | ) 162 | 163 | 164 | def test_experiment_double_named_config(): 165 | ex = Experiment() 166 | 167 | @ex.config 168 | def config(): 169 | a = 0 170 | d = {"e": 0, "f": 0} 171 | 172 | @ex.named_config 173 | def A(): 174 | a = 2 175 | d = {"e": 2, "f": 2} 176 | 177 | @ex.named_config 178 | def B(): 179 | d = {"f": -1} 180 | 181 | @ex.main 182 | def run(a, d): 183 | return a, d["e"], d["f"] 184 | 185 | assert ex.run().result == (0, 0, 0) 186 | assert ex.run(named_configs=["A"]).result == (2, 2, 2) 187 | assert ex.run(named_configs=["B"]).result == (0, 0, -1) 188 | assert ex.run(named_configs=["A", "B"]).result == (2, 2, -1) 189 | assert ex.run(named_configs=["B", "A"]).result == (2, 2, 2) 190 | 191 | 192 | def test_double_nested_config(): 193 | sub_sub_ing = Ingredient("sub_sub_ing") 194 | sub_ing = Ingredient("sub_ing", [sub_sub_ing]) 195 | ing = Ingredient("ing", [sub_ing]) 196 | ex = Experiment("ex", [ing]) 197 | 198 | @ex.config 199 | def config(): 200 | a = 1 201 | seed = 42 202 | 203 | @ing.config 204 | def config(): 205 | b = 1 206 | 207 | @sub_ing.config 208 | def config(): 209 | c = 2 210 | 211 | @sub_sub_ing.config 212 | def config(): 213 | d = 3 214 | 215 | @sub_sub_ing.capture 216 | def sub_sub_ing_main(_config): 217 | assert _config == {"d": 3}, _config 218 | 219 | @sub_ing.capture 220 | def sub_ing_main(_config): 221 | assert _config == {"c": 2, "sub_sub_ing": {"d": 3}}, _config 222 | 223 | @ing.capture 224 | def ing_main(_config): 225 | assert _config == { 226 | "b": 1, 227 | "sub_sub_ing": {"d": 3}, 228 | "sub_ing": {"c": 2}, 229 | }, _config 230 | 231 | @ex.main 232 | def main(_config): 233 | assert _config == { 234 | "a": 1, 235 | "sub_sub_ing": {"d": 3}, 236 | "sub_ing": {"c": 2}, 237 | "ing": {"b": 1}, 238 | "seed": 42, 239 | }, _config 240 | 241 | ing_main() 242 | sub_ing_main() 243 | sub_sub_ing_main() 244 | 245 | ex.run() 246 | -------------------------------------------------------------------------------- /tests/test_observers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | -------------------------------------------------------------------------------- /tests/test_observers/test_mongo_option.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | 5 | import pytest 6 | 7 | pymongo = pytest.importorskip("pymongo") 8 | 9 | from sacred.observers.mongo import parse_mongo_db_arg, DEFAULT_MONGO_PRIORITY 10 | 11 | 12 | def test_parse_mongo_db_arg(): 13 | assert parse_mongo_db_arg("foo") == {"db_name": "foo"} 14 | 15 | 16 | def test_parse_mongo_db_arg_collection(): 17 | kwargs = parse_mongo_db_arg("foo.bar") 18 | assert kwargs == {"db_name": "foo", "collection": "bar"} 19 | 20 | 21 | def test_parse_mongo_db_arg_hostname(): 22 | assert parse_mongo_db_arg("localhost:28017") == {"url": "localhost:28017"} 23 | 24 | assert parse_mongo_db_arg("www.mymongo.db:28017") == {"url": "www.mymongo.db:28017"} 25 | 26 | assert parse_mongo_db_arg("123.45.67.89:27017") == {"url": "123.45.67.89:27017"} 27 | 28 | 29 | def test_parse_mongo_db_arg_hostname_dbname(): 30 | assert parse_mongo_db_arg("localhost:28017:foo") == { 31 | "url": "localhost:28017", 32 | "db_name": "foo", 33 | } 34 | 35 | assert parse_mongo_db_arg("www.mymongo.db:28017:bar") == { 36 | "url": "www.mymongo.db:28017", 37 | "db_name": "bar", 38 | } 39 | 40 | assert parse_mongo_db_arg("123.45.67.89:27017:baz") == { 41 | "url": "123.45.67.89:27017", 42 | "db_name": "baz", 43 | } 44 | 45 | 46 | def test_parse_mongo_db_arg_hostname_dbname_collection_name(): 47 | assert parse_mongo_db_arg("localhost:28017:foo.bar") == { 48 | "url": "localhost:28017", 49 | "db_name": "foo", 50 | "collection": "bar", 51 | } 52 | 53 | assert parse_mongo_db_arg("www.mymongo.db:28017:bar.baz") == { 54 | "url": "www.mymongo.db:28017", 55 | "db_name": "bar", 56 | "collection": "baz", 57 | } 58 | 59 | assert parse_mongo_db_arg("123.45.67.89:27017:baz.foo") == { 60 | "url": "123.45.67.89:27017", 61 | "db_name": "baz", 62 | "collection": "foo", 63 | } 64 | 65 | 66 | def test_parse_mongo_db_arg_priority(): 67 | assert parse_mongo_db_arg("localhost:28017:foo.bar!17") == { 68 | "url": "localhost:28017", 69 | "db_name": "foo", 70 | "collection": "bar", 71 | "priority": 17, 72 | } 73 | 74 | assert parse_mongo_db_arg("www.mymongo.db:28017:bar.baz!2") == { 75 | "url": "www.mymongo.db:28017", 76 | "db_name": "bar", 77 | "collection": "baz", 78 | "priority": 2, 79 | } 80 | 81 | assert parse_mongo_db_arg("123.45.67.89:27017:baz.foo!-123") == { 82 | "url": "123.45.67.89:27017", 83 | "db_name": "baz", 84 | "collection": "foo", 85 | "priority": -123, 86 | } 87 | -------------------------------------------------------------------------------- /tests/test_observers/test_queue_observer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from sacred.observers.queue import QueueObserver 4 | from sacred import Experiment 5 | import mock 6 | import pytest 7 | 8 | 9 | @pytest.fixture 10 | def queue_observer(): 11 | return QueueObserver(mock.MagicMock(), interval=0.01, retry_interval=0.01) 12 | 13 | 14 | def test_started_event(queue_observer): 15 | queue_observer.started_event("args", kwds="kwargs") 16 | assert queue_observer._worker.is_alive() 17 | queue_observer.join() 18 | assert queue_observer._covered_observer.method_calls[0][0] == "started_event" 19 | assert queue_observer._covered_observer.method_calls[0][1] == ("args",) 20 | assert queue_observer._covered_observer.method_calls[0][2] == {"kwds": "kwargs"} 21 | 22 | 23 | @pytest.mark.parametrize( 24 | "event_name", ["heartbeat_event", "resource_event", "artifact_event"] 25 | ) 26 
| def test_non_terminal_generic_events(queue_observer, event_name): 27 | queue_observer.started_event() 28 | getattr(queue_observer, event_name)("args", kwds="kwargs") 29 | queue_observer.join() 30 | assert queue_observer._covered_observer.method_calls[1][0] == event_name 31 | assert queue_observer._covered_observer.method_calls[1][1] == ("args",) 32 | assert queue_observer._covered_observer.method_calls[1][2] == {"kwds": "kwargs"} 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "event_name", ["completed_event", "interrupted_event", "failed_event"] 37 | ) 38 | def test_terminal_generic_events(queue_observer, event_name): 39 | queue_observer.started_event() 40 | getattr(queue_observer, event_name)("args", kwds="kwargs") 41 | assert queue_observer._covered_observer.method_calls[1][0] == event_name 42 | assert queue_observer._covered_observer.method_calls[1][1] == ("args",) 43 | assert queue_observer._covered_observer.method_calls[1][2] == {"kwds": "kwargs"} 44 | assert not queue_observer._worker.is_alive() 45 | 46 | 47 | def test_log_metrics(queue_observer): 48 | queue_observer.started_event() 49 | first = ("a", [1]) 50 | second = ("b", [2]) 51 | queue_observer.log_metrics(OrderedDict([first, second]), "info") 52 | queue_observer.join() 53 | assert queue_observer._covered_observer.method_calls[1][0] == "log_metrics" 54 | assert queue_observer._covered_observer.method_calls[1][1] == ( 55 | first[0], 56 | first[1], 57 | "info", 58 | ) 59 | assert queue_observer._covered_observer.method_calls[1][2] == {} 60 | assert queue_observer._covered_observer.method_calls[2][0] == "log_metrics" 61 | assert queue_observer._covered_observer.method_calls[2][1] == ( 62 | second[0], 63 | second[1], 64 | "info", 65 | ) 66 | assert queue_observer._covered_observer.method_calls[2][2] == {} 67 | 68 | 69 | def test_run_waits_for_running_queue_observer(): 70 | 71 | queue_observer_with_long_interval = QueueObserver( 72 | mock.MagicMock(), interval=1, retry_interval=0.01 73 | ) 74 | 75 | ex = Experiment("ator3000") 76 | ex.observers.append(queue_observer_with_long_interval) 77 | 78 | @ex.main 79 | def main(): 80 | print("do nothing") 81 | 82 | ex.run() 83 | assert ( 84 | queue_observer_with_long_interval._covered_observer.method_calls[-1][0] 85 | == "completed_event" 86 | ) 87 | 88 | 89 | def test_run_waits_for_running_queue_observer_after_failure(): 90 | 91 | queue_observer_with_long_interval = QueueObserver( 92 | mock.MagicMock(), interval=1, retry_interval=0.01 93 | ) 94 | 95 | ex = Experiment("ator3000") 96 | ex.observers.append(queue_observer_with_long_interval) 97 | 98 | @ex.main 99 | def main(): 100 | raise Exception("fatal error") 101 | 102 | try: 103 | ex.run() 104 | except: 105 | pass 106 | 107 | assert ( 108 | queue_observer_with_long_interval._covered_observer.method_calls[-1][0] 109 | == "failed_event" 110 | ) 111 | -------------------------------------------------------------------------------- /tests/test_observers/test_run_observer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | from datetime import datetime 5 | from sacred.observers.base import RunObserver 6 | 7 | 8 | def test_run_observer(): 9 | # basically to silence coverage 10 | r = RunObserver() 11 | assert ( 12 | r.started_event({}, "run", {}, datetime.utcnow(), {}, "comment", None) is None 13 | ) 14 | assert r.heartbeat_event({}, "", datetime.utcnow(), "result") is None 15 | assert r.completed_event(datetime.utcnow(), 123) is None 16 | assert 
r.interrupted_event(datetime.utcnow(), "INTERRUPTED") is None 17 | assert r.failed_event(datetime.utcnow(), "trace") is None 18 | assert r.artifact_event("foo", "foo.txt") is None 19 | assert r.resource_event("foo.txt") is None 20 | -------------------------------------------------------------------------------- /tests/test_observers/test_sql_observer_not_installed.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sacred.optional import has_sqlalchemy 3 | from sacred.observers import SqlObserver 4 | from sacred import Experiment 5 | 6 | 7 | @pytest.fixture 8 | def ex(): 9 | return Experiment("ator3000") 10 | 11 | 12 | @pytest.mark.skipif(has_sqlalchemy, reason="We are testing the import error.") 13 | def test_importerror_sql(ex): 14 | with pytest.raises(ImportError): 15 | ex.observers.append(SqlObserver("some_uri")) 16 | 17 | @ex.config 18 | def cfg(): 19 | a = {"b": 1} 20 | 21 | @ex.main 22 | def foo(a): 23 | return a["b"] 24 | 25 | ex.run() 26 | -------------------------------------------------------------------------------- /tests/test_observers/test_tinydb_observer_not_installed.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sacred.optional import has_tinydb 3 | from sacred.observers import TinyDbObserver 4 | from sacred import Experiment 5 | 6 | 7 | @pytest.fixture 8 | def ex(): 9 | return Experiment("ator3000") 10 | 11 | 12 | @pytest.mark.skipif(has_tinydb, reason="We are testing the import error.") 13 | def test_importerror_tinydb(ex): 14 | with pytest.raises(ImportError): 15 | ex.observers.append(TinyDbObserver.create("some_uri")) 16 | 17 | @ex.config 18 | def cfg(): 19 | a = {"b": 1} 20 | 21 | @ex.main 22 | def foo(a): 23 | return a["b"] 24 | 25 | ex.run() 26 | -------------------------------------------------------------------------------- /tests/test_optional.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import pytest 5 | from sacred.optional import optional_import, get_tensorflow, modules_exist 6 | 7 | 8 | def test_optional_import(): 9 | has_pytest, pyt = optional_import("pytest") 10 | assert has_pytest 11 | assert pyt == pytest 12 | 13 | 14 | def test_optional_import_nonexisting(): 15 | has_nonex, nonex = optional_import("clearlynonexistingpackage") 16 | assert not has_nonex 17 | assert nonex is None 18 | 19 | 20 | def test_get_tensorflow(): 21 | """Test that get_tensorflow() runs without error.""" 22 | pytest.importorskip("tensorflow") 23 | get_tensorflow() 24 | 25 | 26 | def test_module_exists_for_tensorflow(): 27 | """Check that modules_exist returns true if tf is there.""" 28 | pytest.importorskip("tensorflow") 29 | assert modules_exist("tensorflow") 30 | -------------------------------------------------------------------------------- /tests/test_serializer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import pytest 5 | 6 | from sacred.serializer import flatten, restore 7 | import sacred.optional as opt 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "obj", 12 | [ 13 | 12, 14 | 3.14, 15 | "mystring", 16 | "αβγδ", 17 | [1, 2.0, "3", [4]], 18 | {"foo": "bar", "answer": 42}, 19 | None, 20 | True, 21 | ], 22 | ) 23 | def test_flatten_on_json_is_noop(obj): 24 | assert flatten(obj) == obj 25 | 26 | 27 | @pytest.mark.parametrize( 28 | "obj", 29 | [ 30 | 12, 31 | 3.14, 32 |
"mystring", 33 | "αβγδ", 34 | [1, 2.0, "3", [4]], 35 | {"foo": "bar", "answer": 42}, 36 | None, 37 | True, 38 | ], 39 | ) 40 | def test_restore_on_json_is_noop(obj): 41 | assert restore(obj) == obj 42 | 43 | 44 | def test_serialize_non_str_keys(): 45 | d = {1: "one", 2: "two"} 46 | assert restore(flatten(d)) == d 47 | 48 | 49 | @pytest.mark.skipif(not opt.has_numpy, reason="requires numpy") 50 | def test_serialize_numpy_arrays(): 51 | a = opt.np.array([[1, 2, 3], [4, 5, 6]], dtype=opt.np.float32) 52 | b = restore(flatten(a)) 53 | assert opt.np.all(b == a) 54 | assert b.dtype == a.dtype 55 | assert b.shape == a.shape 56 | 57 | 58 | def test_serialize_tuples(): 59 | t = (1, "two") 60 | assert restore(flatten(t)) == t 61 | assert isinstance(restore(flatten(t)), tuple) 62 | 63 | 64 | @pytest.mark.skipif(not opt.has_pandas, reason="requires pandas") 65 | def test_serialize_pandas_dataframes(): 66 | pd, np = opt.pandas, opt.np 67 | df = pd.DataFrame(np.arange(20).reshape(5, 4), columns=list("ABCD")) 68 | b = restore(flatten(df)) 69 | assert np.all(df == b) 70 | assert np.all(df.dtypes == b.dtypes) 71 | -------------------------------------------------------------------------------- /tests/test_settings.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from sacred import SETTINGS as DEFAULT_SETTINGS 3 | 4 | import pytest 5 | 6 | from sacred.settings import SettingError 7 | 8 | 9 | def test_access_invalid_setting(): 10 | SETTINGS = copy.deepcopy(DEFAULT_SETTINGS) 11 | with pytest.raises(SettingError): 12 | SETTINGS.INVALID_SETTING = "invalid" 13 | with pytest.raises(SettingError): 14 | SETTINGS["INVALID_SETTING"] = "invalid" 15 | 16 | 17 | def test_overwrite_collection(): 18 | SETTINGS = copy.deepcopy(DEFAULT_SETTINGS) 19 | with pytest.raises(SettingError): 20 | SETTINGS.CONFIG = "invalid" 21 | with pytest.raises(SettingError): 22 | SETTINGS["CONFIG"] = "invalid" 23 | 24 | 25 | def test_access_valid_setting(): 26 | SETTINGS = copy.deepcopy(DEFAULT_SETTINGS) 27 | SETTINGS.CONFIG.READ_ONLY_CONFIG = True 28 | assert SETTINGS.CONFIG.READ_ONLY_CONFIG 29 | -------------------------------------------------------------------------------- /tests/test_stdout_capturing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import os 5 | import sys 6 | import pytest 7 | from sacred.stdout_capturing import get_stdcapturer 8 | from sacred.optional import libc 9 | 10 | 11 | def test_python_tee_output(capsys): 12 | expected_lines = {"captured stdout", "captured stderr"} 13 | 14 | capture_mode, capture_stdout = get_stdcapturer("sys") 15 | with capsys.disabled(): 16 | print("before (stdout)") 17 | print("before (stderr)") 18 | with capture_stdout() as out: 19 | print("captured stdout") 20 | print("captured stderr") 21 | output = out.get() 22 | 23 | print("after (stdout)") 24 | print("after (stderr)") 25 | 26 | assert set(output.strip().split("\n")) == expected_lines 27 | 28 | 29 | @pytest.mark.skipif(sys.platform.startswith("win"), reason="does not run on windows") 30 | def test_fd_tee_output(capsys): 31 | expected_lines = { 32 | "captured stdout", 33 | "captured stderr", 34 | "stdout from C", 35 | "and this is from echo", 36 | "keep\rcarriage\rreturns", 37 | } 38 | 39 | capture_mode, capture_stdout = get_stdcapturer("fd") 40 | output = "" 41 | with capsys.disabled(): 42 | print("before (stdout)") 43 | print("before (stderr)") 44 | with capture_stdout() as out: 45 |
print("captured stdout") 46 | print("captured stderr", file=sys.stderr) 47 | print("keep\rcarriage\rreturns") 48 | output += out.get() 49 | libc.puts(b"stdout from C") 50 | libc.fflush(None) 51 | os.system("echo and this is from echo") 52 | output += out.get() 53 | 54 | output += out.get() 55 | 56 | print("after (stdout)") 57 | print("after (stderr)") 58 | 59 | assert set(output.strip().split("\n")) == expected_lines 60 | -------------------------------------------------------------------------------- /tests/test_stflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDSIA/sacred/e4e5fd8c6c0a6f365ae9fe670fb691510f93be4c/tests/test_stflow/__init__.py -------------------------------------------------------------------------------- /tests/test_stflow/test_internal.py: -------------------------------------------------------------------------------- 1 | from sacred.stflow.internal import ContextMethodDecorator 2 | 3 | 4 | def test_context_method_decorator(): 5 | """ 6 | Ensure that ContextMethodDecorator can intercept method calls. 7 | """ 8 | 9 | class FooClass: 10 | def __init__(self, x): 11 | self.x = x 12 | 13 | def do_foo(self, y, z): 14 | print("foo") 15 | print(y) 16 | print(z) 17 | return y * self.x + z 18 | 19 | def decorate_three_times(instance, original_method, original_args, original_kwargs): 20 | print("three_times") 21 | print(original_args) 22 | print(original_kwargs) 23 | return original_method(instance, *original_args, **original_kwargs) * 3 24 | 25 | with ContextMethodDecorator(FooClass, "do_foo", decorate_three_times): 26 | foo = FooClass(10) 27 | assert foo.do_foo(5, 6) == (5 * 10 + 6) * 3 28 | assert foo.do_foo(5, z=6) == (5 * 10 + 6) * 3 29 | assert foo.do_foo(y=5, z=6) == (5 * 10 + 6) * 3 30 | assert foo.do_foo(5, 6) == (5 * 10 + 6) 31 | assert foo.do_foo(5, z=6) == (5 * 10 + 6) 32 | assert foo.do_foo(y=5, z=6) == (5 * 10 + 6) 33 | 34 | def decorate_three_times_with_exception( 35 | instance, original_method, original_args, original_kwargs 36 | ): 37 | raise RuntimeError("This should be caught") 38 | 39 | exception = False 40 | try: 41 | with ContextMethodDecorator( 42 | FooClass, "do_foo", decorate_three_times_with_exception 43 | ): 44 | foo = FooClass(10) 45 | this_should_raise_exception = foo.do_foo(5, 6) 46 | except RuntimeError: 47 | exception = True 48 | assert foo.do_foo(5, 6) == (5 * 10 + 6) 49 | assert exception is True 50 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox (http://tox.testrun.org/) is a tool for running tests 2 | # in multiple virtualenvs. This configuration file will run the 3 | # test suite on all supported python versions. To use it, "pip install tox" 4 | # and then run "tox" from this directory. 
5 | 6 | [tox] 7 | envlist = py{38,39,310,311}, setup, numpy-{120,121,123,200}, tensorflow-{212,216} 8 | 9 | [testenv] 10 | deps = 11 | -rdev-requirements.txt 12 | commands = 13 | pytest \ 14 | {posargs} # substitute with tox' positional arguments 15 | 16 | [testenv:numpy-120] 17 | basepython = python 18 | deps = 19 | -rdev-requirements.txt 20 | numpy~=1.20.0 21 | commands = 22 | pytest tests/test_config {posargs} 23 | 24 | [testenv:numpy-121] 25 | basepython = python 26 | deps = 27 | -rdev-requirements.txt 28 | numpy~=1.21.0 29 | commands = 30 | pytest tests/test_config {posargs} 31 | 32 | [testenv:numpy-122] 33 | basepython = python 34 | deps = 35 | -rdev-requirements.txt 36 | numpy~=1.22.0 37 | commands = 38 | pytest tests/test_config {posargs} 39 | 40 | [testenv:numpy-123] 41 | basepython = python 42 | deps = 43 | -rdev-requirements.txt 44 | numpy~=1.23.0 45 | commands = 46 | pytest tests/test_config {posargs} 47 | 48 | [testenv:numpy-124] 49 | basepython = python 50 | deps = 51 | -rdev-requirements.txt 52 | numpy~=1.24.0rc1 53 | commands = 54 | pytest tests/test_config {posargs} 55 | 56 | [testenv:numpy-200] 57 | basepython = python 58 | deps = 59 | -rdev-requirements.txt 60 | numpy~=2.0.0 61 | commands = 62 | pytest tests/test_config {posargs} 63 | 64 | [testenv:tensorflow-212] 65 | basepython = python 66 | deps = 67 | -rdev-requirements.txt 68 | numpy<2.0.0 69 | tensorflow~=2.12.0 70 | commands = 71 | pytest tests/test_stflow tests/test_optional.py \ 72 | {posargs} 73 | 74 | 75 | [testenv:tensorflow-216] 76 | basepython = python 77 | deps = 78 | -rdev-requirements.txt 79 | tensorflow~=2.16.0 80 | commands = 81 | pytest tests/test_stflow tests/test_optional.py \ 82 | {posargs} 83 | 84 | 85 | [testenv:setup] 86 | basepython = python 87 | deps = 88 | pytest==7.1.2 89 | mock 90 | commands = 91 | pytest {posargs} 92 | 93 | [testenv:coverage] 94 | passenv = TRAVIS, TRAVIS_* 95 | basepython = python 96 | deps = 97 | -rdev-requirements.txt 98 | pytest-cov 99 | coveralls 100 | 101 | commands = 102 | pytest \ 103 | --cov sacred \ 104 | {posargs} 105 | - coveralls 106 | --------------------------------------------------------------------------------
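
A note on usage: the tests above exercise Sacred's public API one piece at a time (Experiment, @ex.config, @ex.main, _run.log_scalar, ex.run). The following sketch puts those same calls together into one minimal, runnable experiment. It is illustrative only and not a file from this repository; the experiment name "minimal_example" and the "steps" config entry are made up for the example.

# minimal_example.py -- illustrative sketch, not part of the repository.
# It combines the API calls shown in the tests above into one small experiment.
from sacred import Experiment

ex = Experiment("minimal_example")


@ex.config
def config():
    steps = 5  # number of dummy training steps (hypothetical config value)


@ex.main
def main(_run, steps):
    # Log one scalar metric per step; this is the same call pattern that
    # tests/test_metrics_logger.py asserts on (name, value, explicit step).
    for i in range(steps):
        _run.log_scalar("training.loss", float(steps - i), i)
    return steps


if __name__ == "__main__":
    # ex.run() returns a Run object whose .result field is what the
    # assertions in tests/test_modules.py check.
    print(ex.run().result)

To run the repository's own test suite, the tox.ini above is the entry point: install tox ("pip install tox") and run "tox" from the repository root, or invoke pytest directly as the individual tox environments do.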