├── .cruft.json ├── .gitattributes ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── codecov.yml ├── dependabot.yml └── workflows │ ├── release.yml │ ├── tests.yml │ └── update-template.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.rst ├── CONTRIBUTING.rst ├── LICENSE ├── README.md ├── docs ├── dvc_plots_diff.png ├── studio_compare.png ├── vscode_experiments.png └── vscode_plots.png ├── examples ├── DVCLive-Evidently.ipynb ├── DVCLive-Fabric.ipynb ├── DVCLive-HuggingFace.ipynb ├── DVCLive-PyTorch-Lightning.ipynb ├── DVCLive-Quickstart.ipynb ├── DVCLive-YOLO.ipynb └── DVCLive-scikit-learn.ipynb ├── noxfile.py ├── pyproject.toml ├── src └── dvclive │ ├── __init__.py │ ├── dvc.py │ ├── env.py │ ├── error.py │ ├── fabric.py │ ├── fastai.py │ ├── huggingface.py │ ├── keras.py │ ├── lgbm.py │ ├── lightning.py │ ├── live.py │ ├── monitor_system.py │ ├── optuna.py │ ├── plots │ ├── __init__.py │ ├── base.py │ ├── custom.py │ ├── image.py │ ├── metric.py │ ├── sklearn.py │ └── utils.py │ ├── py.typed │ ├── report.py │ ├── serialize.py │ ├── studio.py │ ├── utils.py │ ├── vscode.py │ └── xgb.py └── tests ├── __init__.py ├── conftest.py ├── frameworks ├── test_fabric.py ├── test_fastai.py ├── test_huggingface.py ├── test_keras.py ├── test_lgbm.py ├── test_lightning.py ├── test_optuna.py └── test_xgboost.py ├── plots ├── test_custom.py ├── test_image.py ├── test_metric.py └── test_sklearn.py ├── test_cleanup.py ├── test_context_manager.py ├── test_dvc.py ├── test_log_artifact.py ├── test_log_metric.py ├── test_log_param.py ├── test_logging.py ├── test_make_dvcyaml.py ├── test_make_report.py ├── test_make_summary.py ├── test_monitor_system.py ├── test_post_to_studio.py ├── test_resume.py ├── test_step.py ├── test_utils.py └── test_vscode.py /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "https://github.com/iterative/py-template", 3 | "commit": "e4ec95f4cfd03d4af0a8604d462ee11d07d63b42", 4 | 
"checkout": null, 5 | "context": { 6 | "cookiecutter": { 7 | "project_name": "dvclive", 8 | "package_name": "dvclive", 9 | "friendly_name": "dvclive", 10 | "author": "Iterative", 11 | "email": "support@dvc.org", 12 | "github_user": "iterative", 13 | "version": "0.0.0", 14 | "copyright_year": "2022", 15 | "license": "Apache-2.0", 16 | "docs": "False", 17 | "short_description": "Metric logger for ML projects.", 18 | "development_status": "Development Status :: 4 - Beta", 19 | "_template": "https://github.com/iterative/py-template" 20 | } 21 | }, 22 | "directory": null 23 | } 24 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | - [ ] ❗ I have followed the [Contributing to DVCLive](https://github.com/iterative/dvclive/blob/main/CONTRIBUTING.rst) guide. 2 | 3 | - [ ] 📖 If this PR requires [documentation](https://dvc.org/doc) updates, I have created a separate PR (or issue, at least) in [dvc.org](https://github.com/iterative/dvc.org) and linked it here. 4 | 5 | Thank you for the contribution - we'll try to review it as soon as possible. 
🙏 6 | -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | # auto compares coverage to the previous base commit 6 | target: auto 7 | # adjust accordingly based on how flaky your tests are 8 | # this allows a 10% drop from the previous base commit coverage 9 | threshold: 10% 10 | # blocking status checks (set informational: true to make them non-blocking) 11 | informational: false 12 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | - directory: "/" 5 | package-ecosystem: "pip" 6 | schedule: 7 | interval: "weekly" 8 | labels: 9 | - "maintenance" 10 | 11 | - directory: "/" 12 | package-ecosystem: "github-actions" 13 | schedule: 14 | interval: "weekly" 15 | labels: 16 | - "maintenance" 17 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | env: 9 | FORCE_COLOR: "1" 10 | 11 | jobs: 12 | release: 13 | environment: pypi 14 | permissions: 15 | contents: read 16 | id-token: write 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Check out the repository 20 | uses: actions/checkout@v4 21 | with: 22 | fetch-depth: 0 23 | 24 | - name: Set up Python 3.10 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: '3.10' 28 | 29 | - name: Upgrade pip 30 | run: | 31 | pip install --upgrade pip 32 | pip --version 33 | 34 | - name: Install 35 | run: python -m pip install build setuptools 36 | 37 | - name: Build package 38 | run: python -m build 39 | 40 | - name: Upload package 41 | if: github.event_name == 'release' 42 | uses: 
pypa/gh-action-pypi-publish@release/v1 43 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | workflow_dispatch: 8 | 9 | env: 10 | FORCE_COLOR: "1" 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | test_full: 18 | timeout-minutes: 30 19 | runs-on: ubuntu-latest 20 | 21 | steps: 22 | - name: Check out the repository 23 | uses: actions/checkout@v4 24 | with: 25 | fetch-depth: 0 26 | 27 | - name: Set up Python 28 | uses: actions/setup-python@v5 29 | with: 30 | python-version: "3.12" 31 | 32 | - uses: astral-sh/setup-uv@v6 33 | with: 34 | enable-cache: true 35 | cache-suffix: ${{ matrix.pyv }} 36 | cache-dependency-glob: pyproject.toml 37 | 38 | - name: Full install 39 | run: uv pip install -e '.[dev]' --system 40 | 41 | - uses: actions/cache@v4 42 | with: 43 | path: ~/.cache/pre-commit/ 44 | key: pre-commit-4|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }} 45 | 46 | - name: pre-commit 47 | run: uv tool run pre-commit run --show-diff-on-failure --color=always --all-files 48 | 49 | - name: mypy 50 | run: mypy 51 | 52 | - name: Run tests 53 | run: pytest -v tests --cov --cov-report=xml --cov-config=pyproject.toml 54 | 55 | - name: Upload coverage report 56 | uses: codecov/codecov-action@v5 57 | with: 58 | token: ${{ secrets.CODECOV_TOKEN }} 59 | files: coverage.xml 60 | flags: dvclive 61 | 62 | test_core: 63 | timeout-minutes: 30 64 | runs-on: ${{ matrix.os }} 65 | strategy: 66 | fail-fast: false 67 | matrix: 68 | os: [ubuntu-latest, windows-latest, macos-latest] 69 | pyv: ["3.9", "3.10", "3.11", "3.12"] 70 | 71 | steps: 72 | - name: Check out the repository 73 | uses: actions/checkout@v4 74 | with: 75 | fetch-depth: 0 76 | 77 | - 
name: Set up Python ${{ matrix.pyv }} 78 | uses: actions/setup-python@v5 79 | with: 80 | python-version: ${{ matrix.pyv }} 81 | cache: "pip" 82 | cache-dependency-path: setup.cfg 83 | 84 | - name: Upgrade pip 85 | run: | 86 | python -m pip install --upgrade pip wheel 87 | pip --version 88 | 89 | - name: Install core 90 | run: | 91 | pip install -e '.[tests]' 92 | 93 | - name: Run tests 94 | run: pytest -v tests --ignore=tests/frameworks 95 | -------------------------------------------------------------------------------- /.github/workflows/update-template.yaml: -------------------------------------------------------------------------------- 1 | name: Update template 2 | 3 | on: 4 | schedule: 5 | - cron: '5 1 * * *' # every day at 01:05 6 | workflow_dispatch: 7 | 8 | jobs: 9 | update: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Check out the repository 13 | uses: actions/checkout@v4 14 | 15 | - name: Update template 16 | uses: iterative/py-template@main 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # Editors 141 | .idea 142 | .vscode 143 | 144 | .dvc/ 145 | .dvcignore 146 | src/dvclive/_dvclive_version.py 147 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: check-added-large-files 8 | - id: check-case-conflict 9 | - id: check-docstring-first 10 | - id: check-executables-have-shebangs 11 | - id: check-json 12 | - id: check-merge-conflict 13 | args: ["--assume-in-merge"] 14 | - id: check-toml 15 | - id: check-yaml 16 | - id: debug-statements 17 | - id: end-of-file-fixer 18 | - id: mixed-line-ending 19 | args: ["--fix=lf"] 20 | - id: sort-simple-yaml 21 | - id: trailing-whitespace 22 | - repo: https://github.com/codespell-project/codespell 23 | rev: v2.4.1 24 | hooks: 25 | - id: codespell 26 | additional_dependencies: ["tomli"] 27 | exclude: > 28 | (?x)^( 29 | .*\.ipynb 30 | )$ 31 | - repo: https://github.com/astral-sh/ruff-pre-commit 32 | rev: "v0.11.12" 33 | hooks: 34 | - id: ruff 35 | args: [--fix, --exit-non-zero-on-fix] 36 | - id: ruff-format 37 | - repo: 
https://github.com/macisamuele/language-formatters-pre-commit-hooks 38 | rev: v2.14.0 39 | hooks: 40 | - id: pretty-format-toml 41 | args: [--autofix, --no-sort] 42 | - id: pretty-format-yaml 43 | args: [--autofix, --indent, '2', '--offset', '2', --preserve-quotes] 44 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.rst: -------------------------------------------------------------------------------- 1 | Contributor Covenant Code of Conduct 2 | ==================================== 3 | 4 | Our Pledge 5 | ---------- 6 | 7 | We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 8 | 9 | We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 
10 | 11 | 12 | Our Standards 13 | ------------- 14 | 15 | Examples of behavior that contributes to a positive environment for our community include: 16 | 17 | - Demonstrating empathy and kindness toward other people 18 | - Being respectful of differing opinions, viewpoints, and experiences 19 | - Giving and gracefully accepting constructive feedback 20 | - Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience 21 | - Focusing on what is best not just for us as individuals, but for the overall community 22 | 23 | Examples of unacceptable behavior include: 24 | 25 | - The use of sexualized language or imagery, and sexual attention or 26 | advances of any kind 27 | - Trolling, insulting or derogatory comments, and personal or political attacks 28 | - Public or private harassment 29 | - Publishing others' private information, such as a physical or email 30 | address, without their explicit permission 31 | - Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | Enforcement Responsibilities 35 | ---------------------------- 36 | 37 | Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. 38 | 39 | Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. 40 | 41 | 42 | Scope 43 | ----- 44 | 45 | This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. 
Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 46 | 47 | 48 | Enforcement 49 | ----------- 50 | 51 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at support@dvc.org. All complaints will be reviewed and investigated promptly and fairly. 52 | 53 | All community leaders are obligated to respect the privacy and security of the reporter of any incident. 54 | 55 | 56 | Enforcement Guidelines 57 | ---------------------- 58 | 59 | Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: 60 | 61 | 62 | 1. Correction 63 | ~~~~~~~~~~~~~ 64 | 65 | **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. 66 | 67 | **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. 68 | 69 | 70 | 2. Warning 71 | ~~~~~~~~~~ 72 | 73 | **Community Impact**: A violation through a single incident or series of actions. 74 | 75 | **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. 76 | 77 | 78 | 3. Temporary Ban 79 | ~~~~~~~~~~~~~~~~ 80 | 81 | **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. 
82 | 83 | **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. 84 | 85 | 86 | 4. Permanent Ban 87 | ~~~~~~~~~~~~~~~~ 88 | 89 | **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. 90 | 91 | **Consequence**: A permanent ban from any sort of public interaction within the community. 92 | 93 | 94 | Attribution 95 | ----------- 96 | 97 | This Code of Conduct is adapted from the `Contributor Covenant <https://www.contributor-covenant.org>`__, version 2.0, 98 | available at https://www.contributor-covenant.org/version/2/0/code_of_conduct/. 99 | 100 | Community Impact Guidelines were inspired by `Mozilla’s code of conduct enforcement ladder <https://github.com/mozilla/diversity>`__. 101 | 102 | .. _homepage: https://www.contributor-covenant.org 103 | 104 | For answers to common questions about this code of conduct, see the FAQ at 105 | https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. 106 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Contributor Guide 2 | ================= 3 | 4 | Thank you for your interest in improving this project. 5 | This project is open-source under the `Apache 2.0 license`_ and 6 | welcomes contributions in the form of bug reports, feature requests, and pull requests. 7 | 8 | Here is a list of important resources for contributors: 9 | 10 | - `Source Code`_ 11 | - `Issue Tracker`_ 12 | - `Code of Conduct`_ 13 | 14 | .. 
_Apache 2.0 license: https://opensource.org/licenses/Apache-2.0 15 | .. _Source Code: https://github.com/iterative/dvclive 16 | .. _Issue Tracker: https://github.com/iterative/dvclive/issues 17 | 18 | How to report a bug 19 | ------------------- 20 | 21 | Report bugs on the `Issue Tracker`_. 22 | 23 | When filing an issue, make sure to answer these questions: 24 | 25 | - Which operating system and Python version are you using? 26 | - Which version of this project are you using? 27 | - What did you do? 28 | - What did you expect to see? 29 | - What did you see instead? 30 | 31 | The best way to get your bug fixed is to provide a test case, 32 | and/or steps to reproduce the issue. 33 | 34 | 35 | How to request a feature 36 | ------------------------ 37 | 38 | Request features on the `Issue Tracker`_. 39 | 40 | 41 | How to set up your development environment 42 | ------------------------------------------ 43 | 44 | You need Python 3.9+. 45 | 46 | - Clone the repository: 47 | 48 | .. code:: console 49 | 50 | $ git clone https://github.com/iterative/dvclive 51 | $ cd dvclive 52 | 53 | - Set up a virtual environment: 54 | 55 | .. code:: console 56 | 57 | $ python -m venv .venv 58 | $ source .venv/bin/activate 59 | 60 | Install in editable mode including development dependencies: 61 | 62 | .. code:: console 63 | 64 | $ pip install -e .[tests] 65 | 66 | If you need to test against a specific framework, you can install it separately: 67 | 68 | .. code:: console 69 | 70 | $ pip install -e .[tests,tf] 71 | $ pip install -e .[tests,optuna] 72 | 73 | How to test the project 74 | ----------------------- 75 | 76 | Run the full test suite: 77 | 78 | .. code:: console 79 | 80 | $ pytest -v tests 81 | 82 | Tests are located in the ``tests`` directory, 83 | and are written using the pytest_ testing framework. 84 | 85 | .. 
_pytest: https://pytest.readthedocs.io/ 86 | 87 | 88 | How to submit changes 89 | --------------------- 90 | 91 | Open a `pull request`_ to submit changes to this project. 92 | 93 | Your pull request needs to meet the following guidelines for acceptance: 94 | 95 | - The test suite must pass without errors and warnings. 96 | - Include unit tests. 97 | - If your changes add functionality, update the documentation accordingly. 98 | 99 | Feel free to submit early, though—we can always iterate on this. 100 | 101 | To run linting and code formatting checks, you can use `pre-commit`: 102 | 103 | .. code:: console 104 | 105 | $ pre-commit run --all-files 106 | 107 | It is recommended to open an issue before starting work on anything. 108 | This will allow a chance to talk it over with the owners and validate your approach. 109 | 110 | .. _pull request: https://github.com/iterative/dvclive/pulls 111 | .. github-only 112 | .. _Code of Conduct: CODE_OF_CONDUCT.rst 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DVCLive 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/dvclive.svg)](https://pypi.org/project/dvclive/) 4 | [![Status](https://img.shields.io/pypi/status/dvclive.svg)](https://pypi.org/project/dvclive/) 5 | [![Python Version](https://img.shields.io/pypi/pyversions/dvclive)](https://pypi.org/project/dvclive) 6 | [![License](https://img.shields.io/pypi/l/dvclive)](https://opensource.org/licenses/Apache-2.0) 7 | 8 | [![Tests](https://github.com/iterative/dvclive/workflows/Tests/badge.svg?branch=main)](https://github.com/iterative/dvclive/actions?workflow=Tests) 9 | [![Codecov](https://codecov.io/gh/iterative/dvclive/branch/main/graph/badge.svg)](https://app.codecov.io/gh/iterative/dvclive) 10 | 
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) 11 | [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 12 | 13 | DVCLive is a Python library for logging machine learning metrics and other 14 | metadata in simple file formats, which is fully compatible with DVC. 15 | 16 | # [Documentation](https://dvc.org/doc/dvclive) 17 | 18 | - [Get Started](https://dvc.org/doc/start/experiments) 19 | - [How it Works](https://dvc.org/doc/dvclive/how-it-works) 20 | - [API Reference](https://dvc.org/doc/dvclive/live) 21 | - [Integrations](https://dvc.org/doc/dvclive/ml-frameworks) 22 | 23 | ______________________________________________________________________ 24 | 25 | # Quickstart 26 | 27 | | Python API Overview | PyTorch Lightning | Scikit-learn | Ultralytics YOLO v8 | 28 | |--------|--------|--------|--------| 29 | | | | | | 30 | 31 | ## Install *dvclive* 32 | 33 | ```console 34 | $ pip install dvclive 35 | ``` 36 | 37 | ## Initialize DVC Repository 38 | 39 | ```console 40 | $ git init 41 | $ dvc init 42 | $ git commit -m "DVC init" 43 | ``` 44 | 45 | ## Example code 46 | 47 | Copy the snippet below into `train.py` for a basic API usage example: 48 | 49 | ```python 50 | import time 51 | import random 52 | 53 | from dvclive import Live 54 | 55 | params = {"learning_rate": 0.002, "optimizer": "Adam", "epochs": 20} 56 | 57 | with Live() as live: 58 | 59 | # log parameters 60 | for param in params: 61 | live.log_param(param, params[param]) 62 | 63 | # simulate training 64 | offset = random.uniform(0.2, 0.1) 65 | for epoch in range(1, params["epochs"]): 66 | fuzz = random.uniform(0.01, 0.1) 67 | accuracy = 1 - (2 ** - epoch) - fuzz - offset 68 | loss = (2 ** - epoch) + fuzz + offset 69 | 70 | # log metrics 71 | live.log_metric("accuracy", accuracy) 72 | live.log_metric("loss", loss) 73 | live.next_step() 74 | 
time.sleep(0.2) 75 | ``` 76 | 77 | See [Integrations](https://dvc.org/doc/dvclive/ml-frameworks) for examples using 78 | DVCLive alongside different ML Frameworks. 79 | 80 | ## Running 81 | 82 | Run this a couple of times to simulate multiple experiments: 83 | 84 | ```console 85 | $ python train.py 86 | $ python train.py 87 | $ python train.py 88 | ... 89 | ``` 90 | 91 | ## Comparing 92 | 93 | DVCLive outputs can be rendered in different ways: 94 | 95 | ### DVC CLI 96 | 97 | You can use [dvc exp show](https://dvc.org/doc/command-reference/exp/show) and 98 | [dvc plots](https://dvc.org/doc/command-reference/plots) to compare and 99 | visualize metrics, parameters and plots across experiments: 100 | 101 | ```console 102 | $ dvc exp show 103 | ``` 104 | 105 | ``` 106 | ───────────────────────────────────────────────────────────────────────────────────────────────────────────── 107 | Experiment Created train.accuracy train.loss val.accuracy val.loss step epochs 108 | ───────────────────────────────────────────────────────────────────────────────────────────────────────────── 109 | workspace - 6.0109 0.23311 6.062 0.24321 6 7 110 | master 08:50 PM - - - - - - 111 | ├── 4475845 [aulic-chiv] 08:56 PM 6.0109 0.23311 6.062 0.24321 6 7 112 | ├── 7d4cef7 [yarer-tods] 08:56 PM 4.8551 0.82012 4.5555 0.033533 4 5 113 | └── d503f8e [curst-chad] 08:56 PM 4.9768 0.070585 4.0773 0.46639 4 5 114 | ───────────────────────────────────────────────────────────────────────────────────────────────────────────── 115 | ``` 116 | 117 | ```console 118 | $ dvc plots diff $(dvc exp list --names-only) --open 119 | ``` 120 | 121 | ![dvc plots diff](./docs/dvc_plots_diff.png) 122 | 123 | ### DVC Extension for VS Code 124 | 125 | Inside the 126 | [DVC Extension for VS Code](https://marketplace.visualstudio.com/items?itemName=Iterative.dvc), 127 | you can compare and visualize results using the 128 | 
[Experiments](https://github.com/iterative/vscode-dvc/blob/main/extension/resources/walkthrough/experiments-table.md) 129 | and 130 | [Plots](https://github.com/iterative/vscode-dvc/blob/main/extension/resources/walkthrough/plots.md) 131 | views: 132 | 133 | ![VSCode Experiments](./docs/vscode_experiments.png) 134 | 135 | ![VSCode Plots](./docs/vscode_plots.png) 136 | 137 | While experiments are running, live updates will be displayed in both views. 138 | 139 | ### DVC Studio 140 | 141 | If you push the results to [DVC Studio](https://dvc.org/doc/studio), you can 142 | compare experiments against the entire repo history: 143 | 144 | ![Studio Compare](./docs/studio_compare.png) 145 | 146 | You can enable 147 | [Studio Live Experiments](https://dvc.org/doc/studio/user-guide/projects-and-experiments/live-metrics-and-plots) 148 | to see live updates while experiments are running. 149 | 150 | ______________________________________________________________________ 151 | 152 | # Comparison to related technologies 153 | 154 | **DVCLive** is an *ML Logger*, similar to: 155 | 156 | - [MLFlow](https://mlflow.org/) 157 | - [Weights & Biases](https://wandb.ai/site) 158 | - [Neptune](https://neptune.ai/) 159 | 160 | The main differences with those *ML Loggers* are: 161 | 162 | - **DVCLive** does not **require** any additional services or servers to run. 163 | - **DVCLive** metrics, parameters, and plots are 164 | [stored as plain text files](https://dvc.org/doc/dvclive/how-it-works#directory-structure) 165 | that can be versioned by tools like Git or tracked as pointers to files in DVC 166 | storage. 167 | - **DVCLive** can save experiments or runs as 168 | [hidden Git commits](https://dvc.org/doc/dvclive/how-it-works#track-the-results). 169 | 170 | You can then use different [options](#comparing) to visualize the metrics, 171 | parameters, and plots across experiments. 
172 | 173 | ______________________________________________________________________ 174 | 175 | # Contributing 176 | 177 | Contributions are very welcome. To learn more, see the 178 | [Contributor Guide](CONTRIBUTING.rst). 179 | 180 | # License 181 | 182 | Distributed under the terms of the 183 | [Apache 2.0 license](https://opensource.org/licenses/Apache-2.0), *dvclive* is 184 | free and open source software. 185 | -------------------------------------------------------------------------------- /docs/dvc_plots_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iterative/dvclive/4f3050de43e5d3b10c44ecfee6527958e4f0c9ce/docs/dvc_plots_diff.png -------------------------------------------------------------------------------- /docs/studio_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iterative/dvclive/4f3050de43e5d3b10c44ecfee6527958e4f0c9ce/docs/studio_compare.png -------------------------------------------------------------------------------- /docs/vscode_experiments.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iterative/dvclive/4f3050de43e5d3b10c44ecfee6527958e4f0c9ce/docs/vscode_experiments.png -------------------------------------------------------------------------------- /docs/vscode_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iterative/dvclive/4f3050de43e5d3b10c44ecfee6527958e4f0c9ce/docs/vscode_plots.png -------------------------------------------------------------------------------- /examples/DVCLive-HuggingFace.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "3SJ8SY6ldmsS" 7 | }, 8 | "source": [ 9 | "### How to do Experiment 
tracking with DVCLive\n", 10 | "\n", 11 | "What you will learn?\n", 12 | "\n", 13 | "- Fine-tuning a model on a binary text classification task\n", 14 | "- Track machine learning experiments with DVCLive\n", 15 | "- Visualize results and create a report\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "nxiSBytidmsU" 22 | }, 23 | "source": [ 24 | "#### Setup (Install Dependencies & Setup Git)\n", 25 | "\n", 26 | "- Install accelerate , Datasets , evaluate , transformers and dvclive\n", 27 | "- Start a Git repo. Your experiments will be saved in a commit but hidden in\n", 28 | " order to not clutter your repo.\n", 29 | "- Initialize DVC\n" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "id": "CLRgy2W4dmsU" 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "!pip install datasets dvclive evaluate pandas 'transformers[torch]' --upgrade" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "id": "fo0sq84UdmsV" 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "!git init -q\n", 52 | "!git config --local user.email \"you@example.com\"\n", 53 | "!git config --local user.name \"Your Name\"\n", 54 | "!dvc init -q\n", 55 | "!git commit -m \"DVC init\"" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": { 61 | "id": "T5WYJ31UdmsV" 62 | }, 63 | "source": [ 64 | "### Fine-tuning a model on a text classification task\n", 65 | "\n", 66 | "#### Loading the dataset\n", 67 | "\n", 68 | "We will use the [imdb](https://huggingface.co/datasets/imdb) Large Movie Review Dataset. 
This is a dataset for binary\n", 69 | "sentiment classification containing a set of 25K movie reviews for training and\n", 70 | "25K for testing.\n" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": { 77 | "id": "41fP0WCbdmsV" 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "from datasets import load_dataset\n", 82 | "from transformers import AutoTokenizer\n", 83 | "\n", 84 | "dataset = load_dataset(\"imdb\")" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": { 90 | "id": "V3gDKbbSdmsV" 91 | }, 92 | "source": [ 93 | "#### Preprocessing the data\n", 94 | "\n", 95 | "We use `transformers.AutoTokenizer` which transforms the inputs and put them in a format\n", 96 | "the model expects.\n" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": { 103 | "id": "uVr5lufodmsV" 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-cased\")\n", 108 | "\n", 109 | "\n", 110 | "def tokenize_function(examples):\n", 111 | " return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n", 112 | "\n", 113 | "\n", 114 | "small_train_dataset = (\n", 115 | " dataset[\"train\"]\n", 116 | " .shuffle(seed=42)\n", 117 | " .select(range(2000))\n", 118 | " .map(tokenize_function, batched=True)\n", 119 | ")\n", 120 | "small_eval_dataset = (\n", 121 | " dataset[\"test\"]\n", 122 | " .shuffle(seed=42)\n", 123 | " .select(range(200))\n", 124 | " .map(tokenize_function, batched=True)\n", 125 | ")" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": { 131 | "id": "g9sELYMHdmsV" 132 | }, 133 | "source": [ 134 | "#### Define evaluation metrics\n", 135 | "\n", 136 | "f1 is a metric for combining precision and recall metrics in one unique value, so\n", 137 | "we take this criteria for evaluating the models.\n" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | 
"metadata": { 144 | "id": "wmJoy5V-dmsW" 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "import numpy as np\n", 149 | "import evaluate\n", 150 | "\n", 151 | "metric = evaluate.load(\"f1\")\n", 152 | "\n", 153 | "\n", 154 | "def compute_metrics(eval_pred):\n", 155 | " logits, labels = eval_pred\n", 156 | " predictions = np.argmax(logits, axis=-1)\n", 157 | " return metric.compute(predictions=predictions, references=labels)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": { 163 | "id": "NwFntrIKdmsW" 164 | }, 165 | "source": [ 166 | "### Training and Tracking experiments with DVCLive\n", 167 | "\n", 168 | "Track experiments in DVC by changing a few lines of your Python code.\n", 169 | "Save model artifacts using `HF_DVCLIVE_LOG_MODEL=true`." 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": { 176 | "id": "-A1oXCxE4zGi" 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "%env HF_DVCLIVE_LOG_MODEL=true" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": { 187 | "id": "gKKSTh0ZdmsW" 188 | }, 189 | "outputs": [], 190 | "source": [ 191 | "from transformers.integrations import DVCLiveCallback\n", 192 | "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n", 193 | "\n", 194 | "model = AutoModelForSequenceClassification.from_pretrained(\n", 195 | " \"distilbert-base-cased\", num_labels=2\n", 196 | ")\n", 197 | "for param in model.base_model.parameters():\n", 198 | " param.requires_grad = False\n", 199 | "\n", 200 | "lr = 3e-4\n", 201 | "\n", 202 | "training_args = TrainingArguments(\n", 203 | " eval_strategy=\"epoch\",\n", 204 | " learning_rate=lr,\n", 205 | " logging_strategy=\"epoch\",\n", 206 | " num_train_epochs=5,\n", 207 | " output_dir=\"output\",\n", 208 | " overwrite_output_dir=True,\n", 209 | " load_best_model_at_end=True,\n", 210 | " save_strategy=\"epoch\",\n", 211 | " weight_decay=0.01,\n", 
212 | ")\n", 213 | "\n", 214 | "trainer = Trainer(\n", 215 | " model=model,\n", 216 | " args=training_args,\n", 217 | " train_dataset=small_train_dataset,\n", 218 | " eval_dataset=small_eval_dataset,\n", 219 | " compute_metrics=compute_metrics,\n", 220 | ")\n", 221 | "trainer.train()" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": { 227 | "id": "KKJCw0Vj6UTw" 228 | }, 229 | "source": [ 230 | "To customize tracking, include `transformers.integrations.DVCLiveCallback` in the `Trainer` callbacks and pass additional keyword arguments to `dvclive.Live`." 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": { 237 | "id": "M4FKUYTi5zYQ" 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "from dvclive import Live\n", 242 | "\n", 243 | "lr = 1e-4\n", 244 | "\n", 245 | "training_args = TrainingArguments(\n", 246 | " eval_strategy=\"epoch\",\n", 247 | " learning_rate=lr,\n", 248 | " logging_strategy=\"epoch\",\n", 249 | " num_train_epochs=5,\n", 250 | " output_dir=\"output\",\n", 251 | " overwrite_output_dir=True,\n", 252 | " load_best_model_at_end=True,\n", 253 | " save_strategy=\"epoch\",\n", 254 | " weight_decay=0.01,\n", 255 | ")\n", 256 | "\n", 257 | "trainer = Trainer(\n", 258 | " model=model,\n", 259 | " args=training_args,\n", 260 | " train_dataset=small_train_dataset,\n", 261 | " eval_dataset=small_eval_dataset,\n", 262 | " compute_metrics=compute_metrics,\n", 263 | " callbacks=[DVCLiveCallback(live=Live(report=\"notebook\"), log_model=True)],\n", 264 | ")\n", 265 | "trainer.train()" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": { 271 | "id": "l29wqAaDdmsW" 272 | }, 273 | "source": [ 274 | "### Comparing Experiments\n", 275 | "\n", 276 | "We create a dataframe with the experiments in order to visualize it.\n" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": { 283 | "id": "wwMwHvVtdmsW" 284 | }, 285 | "outputs": 
[], 286 | "source": [ 287 | "import dvc.api\n", 288 | "import pandas as pd\n", 289 | "\n", 290 | "columns = [\"Experiment\", \"epoch\", \"eval.f1\"]\n", 291 | "\n", 292 | "df = pd.DataFrame(dvc.api.exp_show(), columns=columns)\n", 293 | "\n", 294 | "df.dropna(inplace=True)\n", 295 | "df.reset_index(drop=True, inplace=True)\n", 296 | "df" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": { 303 | "id": "TNBGUqoCdmsW" 304 | }, 305 | "outputs": [], 306 | "source": [ 307 | "!dvc plots diff $(dvc exp list --names-only)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": { 314 | "id": "sL5pH4X5dmsW" 315 | }, 316 | "outputs": [], 317 | "source": [ 318 | "from IPython.display import HTML\n", 319 | "\n", 320 | "HTML(filename=\"./dvc_plots/index.html\")" 321 | ] 322 | } 323 | ], 324 | "metadata": { 325 | "colab": { 326 | "provenance": [] 327 | }, 328 | "kernelspec": { 329 | "display_name": "Python 3 (ipykernel)", 330 | "language": "python", 331 | "name": "python3" 332 | }, 333 | "language_info": { 334 | "codemirror_mode": { 335 | "name": "ipython", 336 | "version": 3 337 | }, 338 | "file_extension": ".py", 339 | "mimetype": "text/x-python", 340 | "name": "python", 341 | "nbconvert_exporter": "python", 342 | "pygments_lexer": "ipython3", 343 | "version": "3.11.7" 344 | } 345 | }, 346 | "nbformat": 4, 347 | "nbformat_minor": 4 348 | } 349 | -------------------------------------------------------------------------------- /examples/DVCLive-PyTorch-Lightning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "A812CVYi_B2b" 7 | }, 8 | "source": [ 9 | "\"Open" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "gPh2FiPo_B2e" 16 | }, 17 | "source": [ 18 | "# DVCLive and PyTorch Lightning" 19 | ] 20 | }, 21 | { 22 | "cell_type": 
"markdown", 23 | "metadata": { 24 | "id": "m0XW9Ml7_B2e" 25 | }, 26 | "source": [ 27 | "## Setup" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": { 34 | "id": "QivH1_cU_B2f" 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "%pip install \"dvclive[lightning]\"" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "id": "pn_5GW1f_B2g" 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "!git init -q\n", 50 | "!git config --local user.email \"you@example.com\"\n", 51 | "!git config --local user.name \"Your Name\"\n", 52 | "!dvc init -q\n", 53 | "!git commit -m \"DVC init\"" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": { 59 | "id": "zC9hk7kibFTX" 60 | }, 61 | "source": [ 62 | "### Define LightningModule" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "id": "t5PxdljP_B2h" 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "import lightning.pytorch as pl\n", 74 | "import torch\n", 75 | "\n", 76 | "\n", 77 | "class LitAutoEncoder(pl.LightningModule):\n", 78 | " def __init__(self, encoder_size=64, lr=1e-3): # noqa: ARG002\n", 79 | " super().__init__()\n", 80 | " self.save_hyperparameters()\n", 81 | " self.encoder = torch.nn.Sequential(\n", 82 | " torch.nn.Linear(28 * 28, encoder_size),\n", 83 | " torch.nn.ReLU(),\n", 84 | " torch.nn.Linear(encoder_size, 3),\n", 85 | " )\n", 86 | " self.decoder = torch.nn.Sequential(\n", 87 | " torch.nn.Linear(3, encoder_size),\n", 88 | " torch.nn.ReLU(),\n", 89 | " torch.nn.Linear(encoder_size, 28 * 28),\n", 90 | " )\n", 91 | "\n", 92 | " def training_step(self, batch, batch_idx): # noqa: ARG002\n", 93 | " x, y = batch\n", 94 | " x = x.view(x.size(0), -1)\n", 95 | " z = self.encoder(x)\n", 96 | " x_hat = self.decoder(z)\n", 97 | " train_mse = torch.nn.functional.mse_loss(x_hat, x)\n", 98 | " self.log(\"train_mse\", train_mse)\n", 99 | " return train_mse\n", 100 | "\n", 
101 | " def validation_step(self, batch, batch_idx): # noqa: ARG002\n", 102 | " x, y = batch\n", 103 | " x = x.view(x.size(0), -1)\n", 104 | " z = self.encoder(x)\n", 105 | " x_hat = self.decoder(z)\n", 106 | " val_mse = torch.nn.functional.mse_loss(x_hat, x)\n", 107 | " self.log(\"val_mse\", val_mse)\n", 108 | " return val_mse\n", 109 | "\n", 110 | " def configure_optimizers(self):\n", 111 | " return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": { 117 | "id": "St0ElX9obqRS" 118 | }, 119 | "source": [ 120 | "### Dataset and loaders" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": { 127 | "id": "T5s53qgr_B2h" 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "from torchvision.datasets import MNIST\n", 132 | "from torchvision import transforms\n", 133 | "\n", 134 | "transform = transforms.ToTensor()\n", 135 | "train_set = MNIST(root=\"MNIST\", download=True, train=True, transform=transform)\n", 136 | "validation_set = MNIST(root=\"MNIST\", download=True, train=False, transform=transform)\n", 137 | "train_loader = torch.utils.data.DataLoader(train_set)\n", 138 | "validation_loader = torch.utils.data.DataLoader(validation_set)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": { 144 | "id": "ttiwwreH_B2i" 145 | }, 146 | "source": [ 147 | "# Tracking experiments with DVCLive" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": { 154 | "id": "sE6qj6BMoDkn" 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "from dvclive.lightning import DVCLiveLogger" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": { 165 | "id": "XDqNY8pL_B2i" 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "for encoder_size in (64, 128):\n", 170 | " for lr in (1e-3, 0.1):\n", 171 | " model = 
LitAutoEncoder(encoder_size=encoder_size, lr=lr)\n", 172 | " trainer = pl.Trainer(\n", 173 | " limit_train_batches=200,\n", 174 | " limit_val_batches=100,\n", 175 | " max_epochs=5,\n", 176 | " logger=DVCLiveLogger(log_model=True, report=\"notebook\"),\n", 177 | " )\n", 178 | " trainer.fit(model, train_loader, validation_loader)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": { 184 | "id": "7zEi0BXp_B2i" 185 | }, 186 | "source": [ 187 | "## Comparing results" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": { 194 | "id": "1aHmLHmf_B2i" 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "import dvc.api\n", 199 | "import pandas as pd\n", 200 | "\n", 201 | "columns = [\"Experiment\", \"encoder_size\", \"lr\", \"train.mse\", \"val.mse\"]\n", 202 | "\n", 203 | "df = pd.DataFrame(dvc.api.exp_show(), columns=columns)\n", 204 | "\n", 205 | "df.dropna(inplace=True)\n", 206 | "df.reset_index(drop=True, inplace=True)\n", 207 | "df" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": { 214 | "id": "db42qeHEGqTA" 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "from plotly.express import parallel_coordinates\n", 219 | "\n", 220 | "fig = parallel_coordinates(df, columns, color=\"val.mse\")\n", 221 | "fig.show()" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": { 228 | "id": "3cfvi0Uk_B2j" 229 | }, 230 | "outputs": [], 231 | "source": [ 232 | "!dvc plots diff $(dvc exp list --names-only)" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": { 239 | "id": "Zx5n2zbn_B2j" 240 | }, 241 | "outputs": [], 242 | "source": [ 243 | "from IPython.display import HTML\n", 244 | "\n", 245 | "HTML(filename=\"./dvc_plots/index.html\")" 246 | ] 247 | } 248 | ], 249 | "metadata": { 250 | "accelerator": "GPU", 251 | "colab": { 252 | "gpuType": "T4", 253 | 
"provenance": [], 254 | "toc_visible": true 255 | }, 256 | "kernelspec": { 257 | "display_name": "Python 3", 258 | "name": "python3" 259 | }, 260 | "language_info": { 261 | "codemirror_mode": { 262 | "name": "ipython", 263 | "version": 3 264 | }, 265 | "file_extension": ".py", 266 | "mimetype": "text/x-python", 267 | "name": "python", 268 | "nbconvert_exporter": "python", 269 | "pygments_lexer": "ipython3", 270 | "version": "3.9.16" 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 0 275 | } 276 | -------------------------------------------------------------------------------- /examples/DVCLive-YOLO.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# DVCLive and Ultralytics YOLOv8" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Setup" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "%pip install dvclive ultralytics\n", 31 | "import ultralytics\n", 32 | "\n", 33 | "ultralytics.checks()" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "!git init -q\n", 43 | "!git config --local user.email \"you@example.com\"\n", 44 | "!git config --local user.name \"Your Name\"\n", 45 | "!dvc init -q\n", 46 | "!git commit -m \"DVC init\"" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "# Tracking experiments with DVCLive" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "If `dvclive` is installed, Ultralytics YOLO v8 will automatically use DVCLive for tracking experiments." 
61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "!yolo train model=yolov8n.pt data=coco8.yaml epochs=5 imgsz=512\n", 70 | "!yolo train model=yolov8n.pt data=coco8.yaml epochs=5 imgsz=640\n", 71 | "!yolo train model=yolov8n.pt data=coco8.yaml epochs=10 imgsz=640\n", 72 | "!yolo train model=yolov8s.pt data=coco8.yaml epochs=10 imgsz=640\n", 73 | "!yolo train model=yolov8m.pt data=coco8.yaml epochs=10 imgsz=640" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "# Comparing results" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "import dvc.api\n", 90 | "import pandas as pd\n", 91 | "\n", 92 | "columns = [\"Experiment\", \"epochs\", \"imgsz\", \"model\", \"metrics.mAP50-95(B)\"]\n", 93 | "\n", 94 | "df = pd.DataFrame(dvc.api.exp_show(), columns=columns)\n", 95 | "\n", 96 | "df.dropna(inplace=True)\n", 97 | "df.reset_index(drop=True, inplace=True)\n", 98 | "df" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "from plotly.express import parallel_coordinates\n", 108 | "\n", 109 | "fig = parallel_coordinates(df, columns, color=\"metrics.mAP50-95(B)\")\n", 110 | "fig.show()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "!dvc plots diff $(dvc exp list --names-only)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "from IPython.display import HTML\n", 129 | "\n", 130 | "HTML(filename=\"./dvc_plots/index.html\")" 131 | ] 132 | } 133 | ], 134 | "metadata": { 135 | "language_info": { 136 | "name": "python" 137 | }, 138 | "orig_nbformat": 4 139 | }, 
140 | "nbformat": 4, 141 | "nbformat_minor": 2 142 | } 143 | -------------------------------------------------------------------------------- /examples/DVCLive-scikit-learn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# DVCLive and scikit-learn" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Setup" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": { 28 | "vscode": { 29 | "languageId": "plaintext" 30 | } 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "!pip install dvclive scikit-learn" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "vscode": { 42 | "languageId": "plaintext" 43 | } 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "!git init -q\n", 48 | "!git config --local user.email \"you@example.com\"\n", 49 | "!git config --local user.name \"Your Name\"\n", 50 | "!dvc init -q\n", 51 | "!git commit -m \"DVC init\"" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": { 58 | "vscode": { 59 | "languageId": "plaintext" 60 | } 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "from sklearn.datasets import make_circles\n", 65 | "from sklearn.model_selection import train_test_split\n", 66 | "\n", 67 | "X, y = make_circles(noise=0.3, factor=0.5, random_state=42)\n", 68 | "\n", 69 | "X_train, X_test, y_train, y_test = train_test_split(\n", 70 | " X,\n", 71 | " y,\n", 72 | " random_state=42)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "# Tracking experiments with DVCLive" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": { 86 | "vscode": { 87 | 
"languageId": "plaintext" 88 | } 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "from dvclive import Live\n", 93 | "\n", 94 | "from sklearn.ensemble import RandomForestClassifier\n", 95 | "from sklearn.metrics import f1_score\n", 96 | "\n", 97 | "for n_estimators in (10, 50, 100):\n", 98 | "\n", 99 | " with Live() as live:\n", 100 | "\n", 101 | " live.log_param(\"n_estimators\", n_estimators)\n", 102 | "\n", 103 | " clf = RandomForestClassifier(n_estimators=n_estimators)\n", 104 | " clf.fit(X_train, y_train)\n", 105 | "\n", 106 | " y_train_pred = clf.predict(X_train)\n", 107 | "\n", 108 | " live.log_metric(\"train/f1\", f1_score(y_train, y_train_pred, average=\"weighted\"), plot=False)\n", 109 | " live.log_sklearn_plot(\n", 110 | " \"confusion_matrix\", y_train, y_train_pred, name=\"train/confusion_matrix\",\n", 111 | " title=\"Train Confusion Matrix\")\n", 112 | "\n", 113 | " y_test_pred = clf.predict(X_test)\n", 114 | "\n", 115 | " live.log_metric(\"test/f1\", f1_score(y_test, y_test_pred, average=\"weighted\"), plot=False)\n", 116 | " live.log_sklearn_plot(\n", 117 | " \"confusion_matrix\", y_test, y_test_pred, name=\"test/confusion_matrix\",\n", 118 | " title=\"Test Confusion Matrix\")" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Comparing results" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "vscode": { 133 | "languageId": "plaintext" 134 | } 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "import dvc.api\n", 139 | "import pandas as pd\n", 140 | "\n", 141 | "columns = [\"Experiment\", \"train.f1\", \"test.f1\", \"n_estimators\"]\n", 142 | "df = pd.DataFrame(dvc.api.exp_show(), columns=columns)\n", 143 | "\n", 144 | "df.dropna(inplace=True)\n", 145 | "df.reset_index(drop=True, inplace=True)\n", 146 | "df" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "vscode": { 154 
| "languageId": "plaintext" 155 | } 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "!dvc plots diff $(dvc exp list --names-only)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": { 166 | "vscode": { 167 | "languageId": "plaintext" 168 | } 169 | }, 170 | "outputs": [], 171 | "source": [ 172 | "from IPython.display import HTML\n", 173 | "HTML(filename='./dvc_plots/index.html')" 174 | ] 175 | } 176 | ], 177 | "metadata": { 178 | "language_info": { 179 | "name": "python" 180 | }, 181 | "orig_nbformat": 4 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 2 185 | } 186 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | """Automation using nox.""" 2 | 3 | import glob 4 | import os 5 | 6 | import nox 7 | 8 | nox.options.reuse_existing_virtualenvs = True 9 | nox.options.sessions = "lint", "tests" 10 | locations = "src", "tests" 11 | 12 | 13 | @nox.session(python=["3.9", "3.10", "3.11", "3.12", "pypy3.9", "pypy3.10"]) 14 | def tests(session: nox.Session) -> None: 15 | session.install(".[tests]") 16 | session.run( 17 | "pytest", 18 | "--cov", 19 | "--cov-config=pyproject.toml", 20 | *session.posargs, 21 | env={"COVERAGE_FILE": f".coverage.{session.python}"}, 22 | ) 23 | 24 | 25 | @nox.session 26 | def lint(session: nox.Session) -> None: 27 | session.install("pre-commit") 28 | session.install("-e", ".[dev]") 29 | 30 | args = *(session.posargs or ("--show-diff-on-failure",)), "--all-files" 31 | session.run("pre-commit", "run", *args) 32 | session.run("python", "-m", "mypy") 33 | 34 | 35 | @nox.session 36 | def safety(session: nox.Session) -> None: 37 | """Scan dependencies for insecure packages.""" 38 | session.install(".[dev]") 39 | session.install("safety") 40 | session.run("safety", "check", "--full-report") 41 | 42 | 43 | @nox.session 44 | def build(session: nox.Session) -> None: 45 | 
@nox.session
def dev(session: nox.Session) -> None:
    """Create a local virtualenv and install the project in editable dev mode."""
    target = session.posargs[0] if session.posargs else "venv"
    venv_dir = os.fsdecode(os.path.abspath(target))

    session.log(f"Setting up virtual environment in {venv_dir}")
    session.install("virtualenv")
    session.run("virtualenv", venv_dir, silent=True)

    # Run pip from the interpreter inside the new environment, not nox's own.
    python = os.path.join(venv_dir, "bin/python")
    session.run(python, "-m", "pip", "install", "-e", ".[dev]", external=True)
7 | name = "dvclive" 8 | readme = "README.md" 9 | keywords = [ 10 | "ai", 11 | "metrics", 12 | "collaboration", 13 | "data-science", 14 | "data-version-control", 15 | "developer-tools", 16 | "git", 17 | "machine-learning", 18 | "reproducibility" 19 | ] 20 | license = {text = "Apache License 2.0"} 21 | maintainers = [{name = "Iterative", email = "support@dvc.org"}] 22 | authors = [{name = "Iterative", email = "support@dvc.org"}] 23 | requires-python = ">=3.9" 24 | classifiers = [ 25 | "Development Status :: 4 - Beta", 26 | "Programming Language :: Python :: 3", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Programming Language :: Python :: 3.12" 31 | ] 32 | dynamic = ["version"] 33 | dependencies = [ 34 | "dvc>=3.48.4", 35 | "dvc-render>=1.0.0,<2", 36 | "dvc-studio-client>=0.20,<1", 37 | "funcy", 38 | "gto", 39 | "ruamel.yaml", 40 | "scmrepo>=3,<4", 41 | "psutil", 42 | "pynvml" 43 | ] 44 | 45 | [project.optional-dependencies] 46 | image = ["numpy", "pillow"] 47 | sklearn = ["scikit-learn"] 48 | plots = ["scikit-learn", "pandas", "numpy"] 49 | markdown = ["matplotlib"] 50 | tests = [ 51 | "pytest>=7.2.0,<9.0", 52 | "pytest-sugar>=0.9.6,<2.0", 53 | "pytest-cov>=3.0.0,<7.0", 54 | "pytest-mock>=3.8.2,<4.0", 55 | "dvclive[image,plots,markdown]", 56 | "ipython", 57 | "pytest_voluptuous", 58 | "dpath", 59 | "transformers[torch]", 60 | "tf-keras" 61 | ] 62 | dev = [ 63 | "dvclive[all,tests]", 64 | "mypy==1.15.0", 65 | "types-PyYAML" 66 | ] 67 | mmcv = ["mmcv"] 68 | tf = ["tensorflow"] 69 | xgb = ["xgboost"] 70 | lgbm = ["lightgbm"] 71 | huggingface = ["transformers", "datasets"] 72 | fastai = ["fastai"] 73 | lightning = ["lightning>=2.0", "torch", "jsonargparse[signatures]>=4.26.1"] 74 | optuna = ["optuna"] 75 | all = [ 76 | "dvclive[image,mmcv,tf,xgb,lgbm,huggingface,fastai,lightning,optuna,plots,markdown]" 77 | ] 78 | 79 | [project.urls] 80 | Homepage = 
"https://github.com/iterative/dvclive" 81 | Documentation = "https://dvc.org/doc/dvclive" 82 | Repository = "https://github.com/iterative/dvclive" 83 | Changelog = "https://github.com/iterative/dvclive/releases" 84 | Issues = "https://github.com/iterative/dvclive/issues" 85 | 86 | [tool.setuptools] 87 | license-files = ["LICENSE"] 88 | platforms = ["any"] 89 | 90 | [tool.setuptools.packages.find] 91 | exclude = ["tests", "tests.*"] 92 | where = ["src"] 93 | namespaces = false 94 | 95 | [tool.setuptools_scm] 96 | write_to = "src/dvclive/_dvclive_version.py" 97 | 98 | [tool.pytest.ini_options] 99 | addopts = "-ra" 100 | markers = """ 101 | vscode: mark a test that verifies behavior that VS Code relies on 102 | studio: mark a test that verifies behavior that Studio relies on 103 | """ 104 | 105 | [tool.coverage.run] 106 | branch = true 107 | source = ["dvclive", "tests"] 108 | 109 | [tool.coverage.paths] 110 | source = ["src", "*/site-packages"] 111 | 112 | [tool.coverage.report] 113 | show_missing = true 114 | exclude_lines = [ 115 | "pragma: no cover", 116 | "if __name__ == .__main__.:", 117 | "if typing.TYPE_CHECKING:", 118 | "if TYPE_CHECKING:", 119 | "raise NotImplementedError", 120 | "raise AssertionError", 121 | "@overload" 122 | ] 123 | 124 | [tool.mypy] 125 | # Error output 126 | show_column_numbers = true 127 | show_error_codes = true 128 | show_error_context = true 129 | show_traceback = true 130 | pretty = true 131 | check_untyped_defs = false 132 | # Warnings 133 | warn_no_return = true 134 | warn_redundant_casts = true 135 | warn_unreachable = true 136 | ignore_missing_imports = true 137 | files = ["src", "tests"] 138 | 139 | [tool.codespell] 140 | ignore-words-list = "fpr" 141 | 142 | [tool.ruff.lint] 143 | ignore = ["N818", "UP006", "UP007", "UP035", "UP038", "B905", "PGH003", "SIM103"] 144 | select = ["F", "E", "W", "C90", "N", "UP", "YTT", "S", "BLE", "B", "A", "C4", "T10", "EXE", "ISC", "INP", "PIE", "T20", "PT", "Q", "RSE", "RET", "SLF", "SIM", 
"TID", "TCH", "INT", "ARG", "PGH", "PL", "TRY", "NPY", "RUF"] 145 | 146 | [tool.ruff.lint.per-file-ignores] 147 | "noxfile.py" = ["D", "PTH"] 148 | "tests/*" = ["S101", "INP001", "SLF001", "ARG001", "ARG002", "ARG005", "PLR2004", "NPY002"] 149 | 150 | [tool.ruff.lint.pylint] 151 | max-args = 10 152 | -------------------------------------------------------------------------------- /src/dvclive/__init__.py: -------------------------------------------------------------------------------- 1 | from .live import Live # noqa: F401 2 | -------------------------------------------------------------------------------- /src/dvclive/dvc.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: SLF001 2 | import copy 3 | import logging 4 | import os 5 | from pathlib import Path 6 | from typing import TYPE_CHECKING, Any, List, Optional 7 | 8 | from dvclive import env 9 | from dvclive.plots import Image, Metric 10 | from dvclive.serialize import dump_yaml 11 | from dvclive.utils import StrPath, rel_path 12 | 13 | if TYPE_CHECKING: 14 | from dvc.repo import Repo 15 | from dvc.stage import Stage 16 | 17 | logger = logging.getLogger("dvclive") 18 | 19 | 20 | def _dvc_dir(dirname: StrPath) -> str: 21 | return os.path.join(dirname, ".dvc") 22 | 23 | 24 | def _find_dvc_root(root: Optional[StrPath] = None) -> Optional[str]: 25 | if not root: 26 | root = os.getcwd() 27 | 28 | root = os.path.realpath(root) 29 | 30 | if not os.path.isdir(root): 31 | raise NotADirectoryError(f"'{root}'") 32 | 33 | while True: 34 | if os.path.exists(_dvc_dir(root)): 35 | return root 36 | if os.path.ismount(root): 37 | break 38 | root = os.path.dirname(root) 39 | 40 | return None 41 | 42 | 43 | def get_dvc_repo() -> Optional["Repo"]: 44 | from dvc.exceptions import NotDvcRepoError 45 | from dvc.repo import Repo 46 | from dvc.scm import Git, SCMError 47 | from scmrepo.exceptions import SCMError as GitSCMError 48 | 49 | try: 50 | return Repo() 51 | except 
def make_dvcyaml(live) -> None:  # noqa: C901
    """Write (or update) the ``dvc.yaml`` entries generated by a Live run.

    Builds a dict with ``params``, ``metrics``, ``plots`` and ``artifacts``
    sections from the state of *live*, expressing every path relative to the
    directory containing ``live.dvc_file``. If that file does not exist yet
    it is dumped from scratch; otherwise the existing file is merged via
    ``update_dvcyaml`` so entries not owned by dvclive are preserved.
    """
    dvcyaml_dir = Path(live.dvc_file).parent.absolute().as_posix()

    dvcyaml = {}
    if live._params:
        dvcyaml["params"] = [rel_path(live.params_file, dvcyaml_dir)]
    # Summary values alone are enough to warrant a metrics entry.
    if live._metrics or live.summary:
        dvcyaml["metrics"] = [rel_path(live.metrics_file, dvcyaml_dir)]
    plots: List[Any] = []
    plots_path = Path(live.plots_dir)
    plots_metrics_path = plots_path / Metric.subfolder
    if plots_metrics_path.exists():
        # Scalar metric plots use the logged "step" field as the x-axis.
        metrics_config = {rel_path(plots_metrics_path, dvcyaml_dir): {"x": "step"}}
        plots.append(metrics_config)
    if live._images:
        images_path = rel_path(plots_path / Image.subfolder, dvcyaml_dir)
        plots.append(images_path)
    if live._plots:
        for plot in live._plots.values():
            plot_path = rel_path(plot.output_path, dvcyaml_dir)
            plots.append({plot_path: plot.plot_config})
    if plots:
        dvcyaml["plots"] = plots

    if live._artifacts:
        # Deep-copy so rewriting each "path" below cannot mutate live's state.
        dvcyaml["artifacts"] = copy.deepcopy(live._artifacts)
        for artifact in dvcyaml["artifacts"].values():  # type: ignore
            artifact["path"] = rel_path(artifact["path"], dvcyaml_dir)

    if not os.path.exists(live.dvc_file):
        dump_yaml(dvcyaml, live.dvc_file)
    else:
        update_dvcyaml(live, dvcyaml)
def get_exp_name(name, scm, baseline_rev) -> str:
    """Resolve the experiment name for this run.

    Preference order: the explicit *name* argument (or ``$DVC_EXP_NAME``)
    when it forms a valid, non-conflicting experiment ref under
    *baseline_rev*; otherwise a randomly generated name.
    """
    from dvc.exceptions import InvalidArgumentError
    from dvc.repo.experiments.refs import ExpRefInfo
    from dvc.repo.experiments.utils import (
        check_ref_format,
        gen_random_name,
        get_random_exp_name,
    )

    name = name or os.getenv(env.DVC_EXP_NAME)
    if name and scm and baseline_rev:
        ref = ExpRefInfo(baseline_sha=baseline_rev, name=name)
        if scm.get_ref(str(ref)):
            # An experiment ref with this name already exists; warn and fall
            # through to a random name rather than overwrite it.
            logger.warning(f"Experiment conflicts with existing experiment '{name}'.")
        else:
            try:
                check_ref_format(scm, ref)
            except InvalidArgumentError as e:
                # Name would not be a valid git ref; warn and fall through.
                logger.warning(e)
            else:
                return name
    if scm and baseline_rev:
        return get_random_exp_name(scm, baseline_rev)
    if name:
        # No scm context available to validate against; trust the given name.
        return name
    return gen_random_name()
def ensure_dir_is_tracked(directory: str, dvc_repo: "Repo") -> None:
    """Git-add every untracked file under *directory* that is not already
    an output of the DVC repository."""
    from pathspec import PathSpec

    dir_spec = PathSpec.from_lines("gitwildmatch", [directory])
    outs_spec = PathSpec.from_lines(
        "gitwildmatch", [str(out) for out in dvc_repo.index.outs]
    )
    untracked = dvc_repo.scm.untracked_files()
    to_add = [
        path
        for path in untracked
        if dir_spec.match_file(path) and not outs_spec.match_file(path)
    ]
    if to_add:
        dvc_repo.scm.add(to_add)
class InvalidParameterTypeError(DvcLiveError):
    """Raised when a logged parameter has a type that cannot be serialized."""

    def __init__(self, msg: Any):
        super().__init__(msg)


class InvalidReportModeError(DvcLiveError):
    """Raised when ``report`` is set to an unrecognized mode."""

    def __init__(self, val):
        message = (
            f"`report` can only be `None`, `auto`, `html`, `notebook` or `md`. "
            f"Got {val} instead."
        )
        super().__init__(message)
    @rank_zero_only
    def log_metrics(
        self,
        metrics: Mapping[str, Union[int, float, str]],
        step: Optional[int] = None,
        sync: Optional[bool] = True,
    ) -> None:
        """Log a mapping of metric names to scalar values with DVCLive.

        Args:
            metrics: metric names mapped to scalar (or 0-d tensor) values.
            step: explicit step to record at; when falsy the step is
                auto-incremented instead.
            sync: when True, flush the logged data via ``Live.sync()``.

        Raises:
            ValueError: if a value cannot be logged as a scalar.
        """
        assert (  # noqa: S101
            rank_zero_only.rank == 0  # type: ignore[attr-defined]
        ), "experiment tried to log from global_rank != 0"

        # NOTE(review): truthiness check — an explicit step=0 falls into the
        # auto-increment branch; confirm whether `step is not None` was meant.
        if step:
            self.experiment.step = step
        else:
            self.experiment.step = self.experiment.step + 1

        metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)  # type: ignore[assignment,arg-type]

        for metric_name, metric_val in metrics.items():
            val = metric_val
            if is_tensor(val):  # type: ignore[unreachable]
                # Move to CPU and extract the Python scalar from 0-d tensors.
                val = val.cpu().detach().item()  # type: ignore[union-attr,unreachable]
            name = standardize_metric_name(metric_name, __name__)
            if Metric.could_log(val):
                self.experiment.log_metric(name=name, val=val)
            else:
                raise ValueError(  # noqa: TRY003
                    f"\n you tried to log {val} which is currently not supported."
                    "Try a scalar/tensor."
                )

        if sync:
            self.experiment.sync()
110 | 111 | Args: 112 | params: a dictionary-like container with the hyperparameters 113 | """ 114 | params = _convert_params(params) 115 | params = _sanitize_callable_params(params) 116 | params = self._sanitize_params(params) 117 | self.experiment.log_params(params) 118 | 119 | @rank_zero_only 120 | def finalize(self, status: str) -> None: # noqa: ARG002 121 | if self._experiment is not None: 122 | self.experiment.end() 123 | 124 | @staticmethod 125 | def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: 126 | from argparse import Namespace 127 | 128 | # logging of arrays with dimension > 1 is not supported, sanitize as string 129 | params = { 130 | k: str(v) if hasattr(v, "ndim") and v.ndim > 1 else v 131 | for k, v in params.items() 132 | } 133 | 134 | # logging of argparse.Namespace is not supported, sanitize as string 135 | params = { 136 | k: str(v) if isinstance(v, Namespace) else v for k, v in params.items() 137 | } 138 | 139 | return params # noqa: RET504 140 | 141 | def __getstate__(self) -> Dict[str, Any]: 142 | state = self.__dict__.copy() 143 | state["_experiment"] = None 144 | return state 145 | -------------------------------------------------------------------------------- /src/dvclive/fastai.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Optional 3 | 4 | from fastai.callback.core import Callback 5 | 6 | from dvclive import Live 7 | from dvclive.utils import standardize_metric_name 8 | 9 | 10 | def _inside_fine_tune(): 11 | """ 12 | Hack to find out if fastai is calling `after_fit` at the end of the 13 | "freeze" stage part of `learn.fine_tune` . 
14 | """ 15 | fine_tune = False 16 | fit_one_cycle = False 17 | for frame in inspect.stack(): 18 | if frame.function == "fine_tune": 19 | fine_tune = True 20 | if frame.function == "fit_one_cycle": 21 | fit_one_cycle = True 22 | if fine_tune and fit_one_cycle: 23 | return True 24 | return False 25 | 26 | 27 | class DVCLiveCallback(Callback): 28 | def __init__( 29 | self, 30 | with_opt: bool = False, 31 | live: Optional[Live] = None, 32 | **kwargs, 33 | ): 34 | super().__init__() 35 | self.with_opt = with_opt 36 | self.live = live if live is not None else Live(**kwargs) 37 | self.freeze_stage_ended = False 38 | 39 | def before_fit(self): 40 | if hasattr(self, "lr_finder") or hasattr(self, "gather_preds"): 41 | return 42 | params = { 43 | "model": type(self.learn.model).__qualname__, 44 | "batch_size": getattr(self.dls, "bs", None), 45 | "batch_per_epoch": len(getattr(self.dls, "train", [])), 46 | "frozen": bool(getattr(self.opt, "frozen_idx", -1)), 47 | "frozen_idx": getattr(self.opt, "frozen_idx", -1), 48 | "transforms": f"{getattr(self.dls, 'tfms', None)}", 49 | } 50 | self.live.log_params(params) 51 | 52 | def after_epoch(self): 53 | if hasattr(self, "lr_finder") or hasattr(self, "gather_preds"): 54 | return 55 | logged_metrics = False 56 | for key, value in zip( 57 | self.learn.recorder.metric_names, self.learn.recorder.log 58 | ): 59 | if key == "epoch": 60 | continue 61 | self.live.log_metric(standardize_metric_name(key, __name__), float(value)) 62 | logged_metrics = True 63 | 64 | # When resuming (i.e. passing `start_epoch` to learner) 65 | # fast.ai calls after_epoch but we don't want to increase the step. 
    def after_fit(self):
        """Close out logging after a fit, unless this is the end of the
        ``fine_tune`` "freeze" stage (another fit follows immediately)."""
        # Skip entirely during lr_find / get_preds runs.
        if hasattr(self, "lr_finder") or hasattr(self, "gather_preds"):
            return
        if _inside_fine_tune() and not self.freeze_stage_ended:
            # First `after_fit` inside `fine_tune`: the "freeze" stage just
            # finished and training continues, so don't end the run yet.
            self.freeze_stage_ended = True
        else:
            if hasattr(self, "save_model") and self.save_model.last_saved_path:
                # Log the checkpoint written by fastai's SaveModelCallback.
                self.live.log_artifact(str(self.save_model.last_saved_path))
            self.live.end()
    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Optionally save and log the final model, then end the run.

        The model is saved only when ``log_model=True`` and this is the
        world-zero process.
        """
        if self._log_model is True and state.is_world_process_zero:
            # Build a throwaway Trainer solely to reuse its `save_model`
            # logic; `eval_dataset` is a placeholder for the constructor.
            fake_trainer = Trainer(
                args=args,
                model=kwargs.get("model"),
                tokenizer=kwargs.get("tokenizer"),
                eval_dataset=["fake"],
            )
            # With `load_best_model_at_end`, the in-memory model is the best.
            name = "best" if args.load_best_model_at_end else "last"
            assert args.output_dir is not None  # noqa: S101
            output_dir = os.path.join(args.output_dir, name)
            fake_trainer.save_model(output_dir)
            self.live.log_artifact(output_dir, name=name, type="model", copy=True)
        self.live.end()
class DVCLiveCallback:
    """LightGBM callback that logs each boosting round's evaluation
    results with DVCLive and advances one step per round."""

    def __init__(self, live: Optional[Live] = None, **kwargs):
        super().__init__()
        self.live = live if live is not None else Live(**kwargs)

    def __call__(self, env):
        results = env.evaluation_result_list
        use_prefix = len(results) > 1
        for data_name, eval_name, result, *_rest in results:
            # Prefix with the dataset name only when several sets are evaluated.
            metric_name = f"{data_name}/{eval_name}" if use_prefix else eval_name
            self.live.log_metric(metric_name, result)
        self.live.next_step()
import _scan_checkpoints 13 | from lightning.pytorch.utilities import rank_zero_only 14 | except ImportError: 15 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint # type: ignore[assignment] 16 | from pytorch_lightning.loggers.logger import Logger # type: ignore[assignment] 17 | from pytorch_lightning.utilities import rank_zero_only 18 | 19 | try: 20 | from pytorch_lightning.utilities.logger import _scan_checkpoints 21 | except ImportError: 22 | from pytorch_lightning.loggers.utilities import _scan_checkpoints # type: ignore[assignment] 23 | 24 | 25 | from dvclive.fabric import DVCLiveLogger as FabricDVCLiveLogger 26 | 27 | 28 | def _should_sync(): 29 | """ 30 | Find out if pytorch_lightning is calling `log_metrics` from the functions 31 | where we actually want to sync. 32 | For example, prevents calling sync when external callbacks call 33 | `log_metrics` or during the multiple `update_eval_step_metrics`. 34 | """ 35 | return any( 36 | frame.function 37 | in ( 38 | "update_train_step_metrics", 39 | "update_train_epoch_metrics", 40 | "log_eval_end_metrics", 41 | ) 42 | for frame in inspect.stack() 43 | ) 44 | 45 | 46 | class DVCLiveLogger(Logger, FabricDVCLiveLogger): 47 | def __init__( 48 | self, 49 | run_name: Optional[str] = "dvclive_run", 50 | prefix="", 51 | log_model: Union[str, bool] = False, 52 | experiment=None, 53 | **kwargs, 54 | ): 55 | super().__init__( 56 | run_name=run_name, 57 | prefix=prefix, 58 | experiment=experiment, 59 | **kwargs, 60 | ) 61 | self._log_model = log_model 62 | self._logged_model_time: Dict[str, float] = {} 63 | self._checkpoint_callback: Optional[ModelCheckpoint] = None 64 | self._all_checkpoint_paths: List[str] = [] 65 | 66 | @rank_zero_only 67 | def log_metrics( 68 | self, 69 | metrics: Mapping[str, Union[int, float, str]], 70 | step: Optional[int] = None, 71 | sync: Optional[bool] = False, 72 | ) -> None: 73 | if not sync and _should_sync(): 74 | sync = True 75 | super().log_metrics(metrics, step, 
    def _save_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Log the checkpoint directory as an artifact, after deleting
        checkpoint files that are no longer tracked by Lightning."""
        # drop unused checkpoints
        if not self.experiment._resume and checkpoint_callback.dirpath:  # noqa: SLF001
            for p in Path(checkpoint_callback.dirpath).iterdir():
                # Stale files left behind by e.g. `save_top_k` rotation.
                if str(p) not in self._all_checkpoint_paths:
                    p.unlink(missing_ok=True)

        # save directory
        self.experiment.log_artifact(checkpoint_callback.dirpath)
import psutil 6 | from statistics import mean 7 | from threading import Event, Thread 8 | from funcy import merge_with 9 | 10 | try: 11 | from pynvml import ( 12 | nvmlInit, 13 | nvmlDeviceGetCount, 14 | nvmlDeviceGetHandleByIndex, 15 | nvmlDeviceGetMemoryInfo, 16 | nvmlDeviceGetUtilizationRates, 17 | nvmlShutdown, 18 | NVMLError, 19 | ) 20 | 21 | GPU_AVAILABLE = True 22 | except ImportError: 23 | GPU_AVAILABLE = False 24 | 25 | logger = logging.getLogger("dvclive") 26 | GIGABYTES_DIVIDER = 1024.0**3 27 | 28 | MINIMUM_CPU_USAGE_TO_BE_ACTIVE = 20 29 | 30 | METRIC_CPU_COUNT = "system/cpu/count" 31 | METRIC_CPU_USAGE_PERCENT = "system/cpu/usage (%)" 32 | METRIC_CPU_PARALLELIZATION_PERCENT = "system/cpu/parallelization (%)" 33 | 34 | METRIC_RAM_USAGE_PERCENT = "system/ram/usage (%)" 35 | METRIC_RAM_USAGE_GB = "system/ram/usage (GB)" 36 | METRIC_RAM_TOTAL_GB = "system/ram/total (GB)" 37 | 38 | METRIC_DISK_USAGE_PERCENT = "system/disk/usage (%)" 39 | METRIC_DISK_USAGE_GB = "system/disk/usage (GB)" 40 | METRIC_DISK_TOTAL_GB = "system/disk/total (GB)" 41 | 42 | METRIC_GPU_COUNT = "system/gpu/count" 43 | METRIC_GPU_USAGE_PERCENT = "system/gpu/usage (%)" 44 | METRIC_VRAM_USAGE_PERCENT = "system/vram/usage (%)" 45 | METRIC_VRAM_USAGE_GB = "system/vram/usage (GB)" 46 | METRIC_VRAM_TOTAL_GB = "system/vram/total (GB)" 47 | 48 | 49 | class _SystemMonitor: 50 | _plot_blacklist_prefix: Tuple = ( 51 | METRIC_CPU_COUNT, 52 | METRIC_RAM_TOTAL_GB, 53 | METRIC_DISK_TOTAL_GB, 54 | METRIC_GPU_COUNT, 55 | METRIC_VRAM_TOTAL_GB, 56 | ) 57 | 58 | def __init__( 59 | self, 60 | live, 61 | interval: float, # seconds 62 | num_samples: int, 63 | directories_to_monitor: Dict[str, str], 64 | ): 65 | self._live = live 66 | self._interval = self._check_interval(interval, max_interval=0.1) 67 | self._num_samples = self._check_num_samples( 68 | num_samples, min_num_samples=1, max_num_samples=30 69 | ) 70 | self._disks_to_monitor = self._check_directories_to_monitor( 71 | directories_to_monitor 72 | ) 73 
| self._warn_cpu_problem = True 74 | self._warn_gpu_problem = True 75 | self._warn_disk_doesnt_exist: Dict[str, bool] = {} 76 | 77 | self._shutdown_event = Event() 78 | Thread( 79 | target=self._monitoring_loop, 80 | ).start() 81 | 82 | def _check_interval(self, interval: float, max_interval: float) -> float: 83 | if interval > max_interval: 84 | logger.warning( 85 | f"System monitoring `interval` should be less than {max_interval} " 86 | f"seconds. Setting `interval` to {max_interval} seconds." 87 | ) 88 | return max_interval 89 | return interval 90 | 91 | def _check_num_samples( 92 | self, num_samples: int, min_num_samples: int, max_num_samples: int 93 | ) -> int: 94 | min_num_samples = 1 95 | max_num_samples = 30 96 | if not min_num_samples < num_samples < max_num_samples: 97 | num_samples = max(min(num_samples, max_num_samples), min_num_samples) 98 | logger.warning( 99 | f"System monitoring `num_samples` should be between {min_num_samples} " 100 | f"and {max_num_samples}. Setting `num_samples` to {num_samples}." 101 | ) 102 | return num_samples 103 | 104 | def _check_directories_to_monitor( 105 | self, directories_to_monitor: Dict[str, str] 106 | ) -> Dict[str, str]: 107 | disks_to_monitor = {} 108 | for disk_name, disk_path in directories_to_monitor.items(): 109 | if disk_name != os.path.normpath(disk_name): 110 | raise ValueError( # noqa: TRY003 111 | "Keys for `directories_to_monitor` should be a valid name" 112 | f", but got '{disk_name}'." 
    def _monitoring_loop(self):
        """Background loop: take `_num_samples` readings per logging round,
        average them, and log to the attached `Live` until shutdown."""
        while not self._shutdown_event.is_set():
            self._metrics = {}
            last_metrics = {}
            for _ in range(self._num_samples):
                try:
                    last_metrics = self._get_metrics()
                except psutil.Error:
                    # Warn only on the first failure to avoid log spam.
                    if self._warn_cpu_problem:
                        logger.exception("Failed to monitor CPU metrics")
                        self._warn_cpu_problem = False
                except NVMLError:
                    if self._warn_gpu_problem:
                        logger.exception("Failed to monitor GPU metrics")
                        self._warn_gpu_problem = False

                # Accumulate per-sample values; divided by the sample count
                # below to produce the mean.
                self._metrics = merge_with(sum, self._metrics, last_metrics)
                self._shutdown_event.wait(self._interval)
                if self._shutdown_event.is_set():
                    break
            for name, values in self._metrics.items():
                # Blacklisted metrics (counts/totals) are logged but not plotted.
                blacklisted = any(
                    name.startswith(prefix) for prefix in self._plot_blacklist_prefix
                )
                self._live.log_metric(
                    name,
                    values / self._num_samples,
                    timestamp=True,
                    plot=None if blacklisted else True,
                )
    def _get_gpu_info(self) -> Dict[str, Union[float, int]]:
        """Collect per-GPU utilization and VRAM metrics via NVML.

        Returns an empty dict when pynvml is not importable.
        """
        if not GPU_AVAILABLE:
            return {}

        nvmlInit()
        num_gpus = nvmlDeviceGetCount()
        gpu_metrics = {
            "system/gpu/count": num_gpus,
        }

        for gpu_idx in range(num_gpus):
            gpu_handle = nvmlDeviceGetHandleByIndex(gpu_idx)
            memory_info = nvmlDeviceGetMemoryInfo(gpu_handle)
            usage_info = nvmlDeviceGetUtilizationRates(gpu_handle)

            gpu_metrics.update(
                {
                    # NOTE(review): NVML's utilization fields are already
                    # percentages; dividing memory utilization by GPU
                    # utilization looks suspicious — confirm this ratio is
                    # the intended "gpu usage" metric.
                    f"{METRIC_GPU_USAGE_PERCENT}/{gpu_idx}": (
                        100 * usage_info.memory / usage_info.gpu
                        if usage_info.gpu
                        else 0
                    ),
                    f"{METRIC_VRAM_USAGE_PERCENT}/{gpu_idx}": (
                        100 * memory_info.used / memory_info.total
                    ),
                    f"{METRIC_VRAM_USAGE_GB}/{gpu_idx}": (
                        memory_info.used / GIGABYTES_DIVIDER
                    ),
                    f"{METRIC_VRAM_TOTAL_GB}/{gpu_idx}": (
                        memory_info.total / GIGABYTES_DIVIDER
                    ),
                }
            )
        nvmlShutdown()
        return gpu_metrics
class DVCLiveCallback:
    """Optuna callback that records every finished trial's objective
    values and parameters in a fresh DVCLive run."""

    def __init__(self, metric_name="metric", **kwargs) -> None:
        kwargs.setdefault("dir", "dvclive-optuna")
        kwargs.pop("save_dvc_exp", None)
        self.metric_name = metric_name
        self.live_kwargs = kwargs

    def __call__(self, study, trial) -> None:
        with Live(**self.live_kwargs) as live:
            self._log_metrics(trial.values, live)
            live.log_params(trial.params)

    def _log_metrics(self, values, live):
        if values is None:
            return

        for name, value in zip(self._metric_names(len(values)), values):
            live.summary[name] = value

    def _metric_names(self, count):
        """Resolve the metric name(s) to use for *count* objective values."""
        if isinstance(self.metric_name, str):
            if count > 1:
                # Broadcast default name for multi-objective optimization.
                return [f"{self.metric_name}_{i}" for i in range(count)]
            return [self.metric_name]

        if len(self.metric_name) != count:
            msg = (
                "Running multi-objective optimization "
                f"with {count} objective values, "
                f"but {len(self.metric_name)} names specified. "
                "Match objective values and names,"
                "or use default broadcasting."
            )
            raise ValueError(msg)

        return [*self.metric_name]
class CustomPlot(Data):
    """A user-defined plot serialized as a JSON list of datapoints, with
    the plot configuration (template, axes, labels) kept for dvc.yaml."""

    suffixes = (".json",)
    subfolder = "custom"

    def __init__(
        self,
        name: str,
        output_folder: str,
        x: str,
        y: Union[str, list[str]],
        template: Optional[str],
        title: Optional[str] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
    ) -> None:
        super().__init__(name, output_folder)
        self.name = self.name.replace(".json", "")

        config = {
            "template": template or None,  # normalize empty template to None
            "x": x,
            "y": y,
            "title": title,
            "x_label": x_label,
            "y_label": y_label,
        }
        # Keep only the entries that were explicitly set.
        self._plot_config = {k: v for k, v in config.items() if v is not None}

    @property
    def output_path(self) -> Path:
        """Path of the JSON datapoints file; parent dirs are created."""
        json_path = Path(f"{self.output_folder / self.name}.json")
        json_path.parent.mkdir(exist_ok=True, parents=True)
        return json_path

    @staticmethod
    def could_log(val: object) -> bool:
        """A value is loggable when it is a list of dicts (datapoints)."""
        return isinstance(val, list) and all(isinstance(x, dict) for x in val)

    @property
    def plot_config(self):
        return self._plot_config

    def dump(self, val, **kwargs) -> None:  # noqa: ARG002
        dump_json(val, self.output_path)
class Metric(Data):
    """A scalar logged over steps, stored as an append-only TSV file."""

    suffixes = (".csv", ".tsv")
    subfolder = "metrics"

    @staticmethod
    def could_log(val: object) -> bool:
        """Accept python scalars/strings and numpy scalar types."""
        if isinstance(val, (int, float, str)):
            return True
        cls = val.__class__
        return cls.__module__ == "numpy" and cls.__name__ in NUMPY_SCALARS

    @property
    def output_path(self) -> Path:
        """TSV file for this metric; parent dirs are created on access."""
        target = Path(f"{self.output_folder / self.name}.tsv")
        target.parent.mkdir(exist_ok=True, parents=True)
        return target

    def dump(self, val, **kwargs) -> None:
        """Append one row (optional timestamp, step, value) to the TSV.

        The header line is only written when the file does not exist yet.
        """
        row = {}
        # Millisecond epoch timestamp, only when explicitly requested.
        if kwargs.get("timestamp", False):
            row["timestamp"] = int(time.time() * 1000)
        row["step"] = self.step
        row[os.path.basename(self.name)] = val

        write_header = not self.output_path.exists()
        with open(self.output_path, "a", encoding="utf-8", newline="") as fobj:
            writer = csv.DictWriter(
                fobj, row.keys(), delimiter="\t", lineterminator=os.linesep
            )
            if write_header:
                writer.writeheader()
            writer.writerow(row)

    @property
    def summary_keys(self) -> List[str]:
        """Path components of the metric name, used as nested summary keys."""
        return os.path.normpath(self.name).split(os.path.sep)
class PrecisionRecall(SKLearnPlot):
    """Precision-recall curve computed from labels and predicted scores."""

    def __init__(self, name: str, output_folder: str, **plot_config) -> None:
        # User-overridable presentation defaults.
        defaults = {
            "template": "simple",
            "title": "Precision-Recall Curve",
            "x_label": "Recall",
            "y_label": "Precision",
        }
        for key, value in defaults.items():
            plot_config.setdefault(key, value)
        # Fixed axis fields matching the dumped datapoint keys.
        plot_config["x"] = "recall"
        plot_config["y"] = "precision"
        super().__init__(name, output_folder, **plot_config)

    def dump(self, val, **kwargs) -> None:
        """Compute the curve from ``val == (y_true, probas_pred)`` and dump it."""
        from sklearn import metrics

        precision, recall, prc_thresholds = metrics.precision_recall_curve(
            y_true=val[0], probas_pred=val[1], **kwargs
        )

        datapoints = [
            {"precision": p, "recall": r, "threshold": t}
            for p, r, t in zip(precision, recall, prc_thresholds)
        ]
        dump_json({"precision_recall": datapoints}, self.output_path)
class Calibration(SKLearnPlot):
    """Calibration curve: predicted probability vs. observed frequency."""

    def __init__(self, name: str, output_folder: str, **plot_config) -> None:
        # User-overridable presentation defaults.
        defaults = {
            "template": "simple",
            "title": "Calibration Curve",
            "x_label": "Mean Predicted Probability",
            "y_label": "Fraction of Positives",
        }
        for key, value in defaults.items():
            plot_config.setdefault(key, value)
        # Fixed axis fields matching the dumped datapoint keys.
        plot_config["x"] = "prob_pred"
        plot_config["y"] = "prob_true"
        super().__init__(name, output_folder, **plot_config)

    def dump(self, val, **kwargs) -> None:
        """Compute the curve from ``val == (y_true, y_prob)`` and dump it."""
        from sklearn import calibration

        prob_true, prob_pred = calibration.calibration_curve(
            y_true=val[0], y_prob=val[1], **kwargs
        )

        datapoints = [
            {"prob_true": pt, "prob_pred": pp}
            for pt, pp in zip(prob_true, prob_pred)
        ]
        dump_json({"calibration": datapoints}, self.output_path)
# Names of the numpy scalar types we can losslessly convert to builtins.
NUMPY_INTS = [
    "intc",
    "intp",
    "int8",
    "int16",
    "int32",
    "int64",
    "uint8",
    "uint16",
    "uint32",
    "uint64",
]
NUMPY_FLOATS = ["float16", "float32", "float64"]
NUMPY_SCALARS = NUMPY_INTS + NUMPY_FLOATS


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts numpy scalars to plain python numbers.

    Detection is done via class module/name strings so this module never
    has to import numpy itself.
    """

    def default(self, o):
        cls = o.__class__
        if cls.__module__ == "numpy":
            if cls.__name__ in NUMPY_INTS:
                return int(o)
            if cls.__name__ in NUMPY_FLOATS:
                return float(o)
        # Anything else: defer to the base class (raises TypeError).
        return super().default(o)
25 | DVCLive Report 26 |
27 | """ 28 | 29 | 30 | def get_scalar_renderers(metrics_path): 31 | renderers = [] 32 | for suffix in Metric.suffixes: 33 | for file in metrics_path.rglob(f"*{suffix}"): 34 | data = parse_tsv(file) 35 | for row in data: 36 | row["rev"] = "workspace" 37 | 38 | name = file.relative_to(metrics_path.parent).with_suffix("") 39 | name = name.as_posix() 40 | title = name.replace(metrics_path.name, "").strip("/") 41 | name = name.replace(metrics_path.name, "static") 42 | 43 | properties = {"x": "step", "y": file.stem, "title": title} 44 | renderers.append(VegaRenderer(data, name, **properties)) 45 | return renderers 46 | 47 | 48 | def get_image_renderers(images_folder): 49 | renderers = [] 50 | for suffix in Image.suffixes: 51 | all_images = Path(images_folder).rglob(f"*{suffix}") 52 | for file in sorted(all_images): 53 | base64_str = base64.b64encode(file.read_bytes()).decode() 54 | src = f"data:image;base64,{base64_str}" 55 | name = str(file.relative_to(images_folder)) 56 | data = [ 57 | { 58 | ImageRenderer.SRC_FIELD: src, 59 | ImageRenderer.TITLE_FIELD: name, 60 | } 61 | ] 62 | renderers.append(ImageRenderer(data, name)) 63 | return renderers 64 | 65 | 66 | def get_custom_plot_renderers(plots_folder, live): 67 | renderers = [] 68 | for suffix in CustomPlot.suffixes: 69 | for file in Path(plots_folder).rglob(f"*{suffix}"): 70 | name = file.relative_to(plots_folder).with_suffix("").as_posix() 71 | 72 | logged_plot = live._plots[name] 73 | properties = logged_plot.plot_config 74 | 75 | data = json.loads(file.read_text()) 76 | 77 | for row in data: 78 | row["rev"] = "workspace" 79 | 80 | renderers.append(VegaRenderer(data, name, **properties)) 81 | return renderers 82 | 83 | 84 | def get_sklearn_plot_renderers(plots_folder, live): 85 | renderers = [] 86 | for suffix in SKLearnPlot.suffixes: 87 | for file in Path(plots_folder).rglob(f"*{suffix}"): 88 | name = file.relative_to(plots_folder).with_suffix("").as_posix() 89 | properties = {} 90 | 91 | logged_plot = 
live._plots[name] 92 | for default_name, plot_class in SKLEARN_PLOTS.items(): 93 | if isinstance(logged_plot, plot_class): 94 | properties = logged_plot.plot_config 95 | data_field = default_name 96 | break 97 | 98 | data = json.loads(file.read_text()) 99 | 100 | if data_field in data: 101 | data = data[data_field] 102 | 103 | for row in data: 104 | row["rev"] = "workspace" 105 | 106 | renderers.append(VegaRenderer(data, name, **properties)) 107 | return renderers 108 | 109 | 110 | def get_metrics_renderers(dvclive_summary): 111 | metrics_path = Path(dvclive_summary) 112 | if metrics_path.exists(): 113 | return [ 114 | TableRenderer( 115 | [json.loads(metrics_path.read_text(encoding="utf-8"))], 116 | metrics_path.name, 117 | ) 118 | ] 119 | return [] 120 | 121 | 122 | def get_params_renderers(dvclive_params): 123 | params_path = Path(dvclive_params) 124 | if params_path.exists(): 125 | return [ 126 | TableRenderer( 127 | [load_yaml(params_path)], 128 | params_path.name, 129 | ) 130 | ] 131 | return [] 132 | 133 | 134 | def make_report(live: "Live"): 135 | plots_path = Path(live.plots_dir) 136 | 137 | renderers = [] 138 | renderers.extend(get_params_renderers(live.params_file)) 139 | renderers.extend(get_metrics_renderers(live.metrics_file)) 140 | renderers.extend(get_scalar_renderers(plots_path / Metric.subfolder)) 141 | renderers.extend(get_image_renderers(plots_path / Image.subfolder)) 142 | renderers.extend( 143 | get_sklearn_plot_renderers(plots_path / SKLearnPlot.subfolder, live) 144 | ) 145 | renderers.extend(get_custom_plot_renderers(plots_path / CustomPlot.subfolder, live)) 146 | 147 | if live._report_mode == "html": 148 | render_html(renderers, live.report_file, refresh_seconds=5) 149 | elif live._report_mode == "notebook": 150 | from IPython.display import Markdown 151 | 152 | md = render_markdown(renderers) 153 | if live._report_notebook is not None: 154 | new_report = Markdown(md) # type: ignore [assignment] 155 | 
def dump_json(content, output_file, indent=4, **kwargs):
    """Serialize ``content`` to ``output_file`` as indented JSON.

    The parent directory is created when missing and the file is
    terminated with a trailing newline.
    """
    make_dir(output_file)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(content, f, indent=indent, **kwargs)
        f.write("\n")


def make_dir(output_file):
    """Create the parent directory of ``output_file`` if there is one."""
    parent = os.path.dirname(output_file)
    if parent:
        os.makedirs(parent, exist_ok=True)
-------------------------------------------------------------------------------- /src/dvclive/studio.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: SLF001 2 | from __future__ import annotations 3 | import base64 4 | import logging 5 | import math 6 | import os 7 | from pathlib import PureWindowsPath 8 | from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional 9 | 10 | from dvc.exceptions import DvcException 11 | from dvc_studio_client.config import get_studio_config 12 | from dvc_studio_client.post_live_metrics import post_live_metrics 13 | 14 | from .utils import catch_and_warn 15 | 16 | if TYPE_CHECKING: 17 | from dvclive.plots.image import Image 18 | from dvclive.live import Live 19 | from dvclive.utils import rel_path, StrPath 20 | 21 | logger = logging.getLogger("dvclive") 22 | 23 | 24 | def _cast_to_numbers(datapoints: Mapping): 25 | for datapoint in datapoints: 26 | for k, v in datapoint.items(): 27 | if k == "step": 28 | datapoint[k] = int(v) 29 | elif k == "timestamp": 30 | continue 31 | else: 32 | float_v = float(v) 33 | if math.isnan(float_v) or math.isinf(float_v): 34 | datapoint[k] = str(v) 35 | else: 36 | datapoint[k] = float_v 37 | return datapoints 38 | 39 | 40 | def _adapt_path(live: Live, name: StrPath): 41 | if live._dvc_repo is not None: 42 | name = rel_path(name, live._dvc_repo.root_dir) 43 | if os.name == "nt": 44 | name = str(PureWindowsPath(name).as_posix()) 45 | return name 46 | 47 | 48 | def _adapt_image(image_path: StrPath): 49 | with open(image_path, "rb") as fobj: 50 | return base64.b64encode(fobj.read()).decode("utf-8") 51 | 52 | 53 | def _adapt_images(live: Live, images: list[Image]): 54 | return { 55 | _adapt_path(live, image.output_path): {"image": _adapt_image(image.output_path)} 56 | for image in images 57 | if image.step > live._latest_studio_step 58 | } 59 | 60 | 61 | def _get_studio_updates(live: Live, data: dict[str, Any]): 62 | params = data["params"] 63 | plots 
= data["plots"] 64 | plots_start_idx = data["plots_start_idx"] 65 | metrics = data["metrics"] 66 | images = data["images"] 67 | 68 | params_file = live.params_file 69 | params_file = _adapt_path(live, params_file) 70 | params = {params_file: params} 71 | 72 | metrics_file = live.metrics_file 73 | metrics_file = _adapt_path(live, metrics_file) 74 | metrics = {metrics_file: {"data": metrics}} 75 | 76 | plots_to_send = {} 77 | for name, plot in plots.items(): 78 | path = _adapt_path(live, name) 79 | start_idx = plots_start_idx.get(name, 0) 80 | num_points_sent = live._num_points_sent_to_studio.get(name, 0) 81 | plots_to_send[path] = _cast_to_numbers(plot[num_points_sent - start_idx :]) 82 | 83 | plots_to_send = {k: {"data": v} for k, v in plots_to_send.items()} 84 | plots_to_send.update(_adapt_images(live, images)) 85 | 86 | return metrics, params, plots_to_send 87 | 88 | 89 | def get_dvc_studio_config(live: Live): 90 | config = {} 91 | if live._dvc_repo: 92 | config = live._dvc_repo.config.get("studio") 93 | return get_studio_config(dvc_studio_config=config) 94 | 95 | 96 | def increment_num_points_sent_to_studio(live, plots_sent, data): 97 | for name, _ in data["plots"].items(): 98 | path = _adapt_path(live, name) 99 | plot = plots_sent.get(path, {}) 100 | if "data" in plot: 101 | num_points_sent = live._num_points_sent_to_studio.get(name, 0) 102 | live._num_points_sent_to_studio[name] = num_points_sent + len(plot["data"]) 103 | return live 104 | 105 | 106 | @catch_and_warn(DvcException, logger) 107 | def post_to_studio( # noqa: C901 108 | live: Live, 109 | event: Literal["start", "data", "done"], 110 | data: Optional[dict[str, Any]] = None, 111 | ): 112 | if event in live._studio_events_to_skip: 113 | return 114 | 115 | kwargs = {} 116 | if event == "start": 117 | if message := live._exp_message: 118 | kwargs["message"] = message 119 | if subdir := live._subdir: 120 | kwargs["subdir"] = subdir 121 | elif event == "data": 122 | assert data is not None # noqa: S101 
123 | metrics, params, plots = _get_studio_updates(live, data) 124 | kwargs["step"] = data["step"] # type: ignore 125 | kwargs["metrics"] = metrics 126 | kwargs["params"] = params 127 | kwargs["plots"] = plots 128 | elif event == "done" and live._experiment_rev: 129 | kwargs["experiment_rev"] = live._experiment_rev 130 | 131 | response = post_live_metrics( 132 | event, 133 | live._baseline_rev, 134 | live._exp_name, # type: ignore 135 | "dvclive", 136 | dvc_studio_config=live._dvc_studio_config, 137 | studio_repo_url=live._repo_url, 138 | **kwargs, # type: ignore 139 | ) 140 | 141 | if not response: 142 | logger.warning(f"`post_to_studio` `{event}` failed.") 143 | if event == "start": 144 | live._studio_events_to_skip.add("start") 145 | live._studio_events_to_skip.add("data") 146 | live._studio_events_to_skip.add("done") 147 | elif event == "data": 148 | assert data is not None # noqa: S101 149 | live = increment_num_points_sent_to_studio(live, plots, data) 150 | live._latest_studio_step = data["step"] 151 | 152 | if event == "done": 153 | live._studio_events_to_skip.add("done") 154 | live._studio_events_to_skip.add("data") 155 | -------------------------------------------------------------------------------- /src/dvclive/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import csv 3 | import json 4 | import os 5 | import re 6 | import shutil 7 | from pathlib import Path, PurePath 8 | from platform import uname 9 | from typing import Union, List, Dict, TYPE_CHECKING 10 | import webbrowser 11 | 12 | from .error import InvalidDataTypeError 13 | 14 | if TYPE_CHECKING: 15 | import numpy as np 16 | import pandas as pd 17 | else: 18 | try: 19 | import pandas as pd 20 | except ImportError: 21 | pd = None 22 | 23 | try: 24 | import numpy as np 25 | except ImportError: 26 | np = None 27 | 28 | 29 | StrPath = Union[str, PurePath] 30 | 31 | 32 | def run_once(f): 33 | def wrapper(*args, 
def env2bool(var, undefined=False):
    """Interpret the environment variable ``var`` as a boolean.

    Any value containing ``1``, ``y``, ``yes`` or ``true`` (case
    insensitive) counts as True.  ``undefined`` is returned when the
    variable is not set at all.
    """
    value = os.getenv(var)
    if value is None:
        return undefined
    return re.search(r"1|y|yes|true", value, flags=re.I) is not None
def parse_tsv(path):
    """Read a tab-separated file into a list of per-row dicts.

    The first line is treated as the header; a header-only (or empty)
    file yields an empty list.
    """
    with open(path, encoding="utf-8", newline="") as fd:
        return list(csv.DictReader(fd, delimiter="\t"))
def convert_datapoints_to_list_of_dicts(
    datapoints: List[Dict] | pd.DataFrame | np.ndarray,
) -> List[Dict]:
    """
    Convert the given datapoints to a list of dictionaries.

    Args:
        datapoints: The input datapoints to be converted.

    Returns:
        A list of dictionaries representing the datapoints.

    Raises:
        TypeError: `datapoints` must be pd.DataFrame, np.ndarray, or List[Dict]
    """
    # Already in the target shape: pass through untouched.
    if isinstance(datapoints, list):
        return datapoints

    # pd/np may be None when the optional dependency is not installed.
    if pd is not None and isinstance(datapoints, pd.DataFrame):
        return datapoints.to_dict(orient="records")

    if np is not None and isinstance(datapoints, np.ndarray):
        names = datapoints.dtype.names
        if names is not None:
            # Structured array: field names become the dict keys.
            return [dict(zip(names, row)) for row in datapoints]
        # Regular 2-D array: column indices become the dict keys.
        return [dict(enumerate(row)) for row in datapoints]

    # Unsupported input type.
    raise InvalidDataTypeError("datapoints", type(datapoints))
def _dvc_exps_run_dir(dirname: StrPath) -> str:
    """Location of DVC's temporary experiments run directory under *dirname*."""
    return os.path.join(dirname, ".dvc", "tmp", "exps", "run")


def _dvclive_only_signal_file(root_dir: StrPath) -> str:
    """Signal file marking a DVCLive-only experiment running in the workspace."""
    return os.path.join(_dvc_exps_run_dir(root_dir), "DVCLIVE_ONLY")


def _dvclive_step_completed_signal_file(root_dir: StrPath) -> str:
    """Signal file marking that a queued experiment completed a step."""
    return os.path.join(_dvc_exps_run_dir(root_dir), "DVCLIVE_STEP_COMPLETED")


def _find_non_queue_root() -> Optional[str]:
    """DVC root of the non-queue repo: env var first, filesystem search second."""
    return os.getenv(env.DVC_ROOT) or _find_dvc_root()


def _write_file(file: str, contents: Dict[str, Union[str, int]]):
    """Serialize *contents* as JSON to *file* and force it onto disk."""
    import builtins

    payload = json.dumps(contents, sort_keys=True, ensure_ascii=False)
    with builtins.open(file, "w", encoding="utf-8") as fobj:
        # NOTE: force flushing/writing empty file to disk, otherwise when
        # run in certain contexts (pytest) file may not actually be written
        fobj.write(payload)
        fobj.flush()
        os.fsync(fobj.fileno())


def mark_dvclive_step_completed(step: int) -> None:
    """
    https://github.com/iterative/vscode-dvc/issues/4528
    Signal DVC VS Code extension that
    a step has been completed for an experiment running in the queue
    """
    root = _find_non_queue_root()
    if not root:
        return

    os.makedirs(_dvc_exps_run_dir(root), exist_ok=True)
    _write_file(
        _dvclive_step_completed_signal_file(root),
        {"pid": os.getpid(), "step": step},
    )
def cleanup_dvclive_step_completed() -> None:
    """Remove the step-completed signal file, if present."""
    root = _find_non_queue_root()
    if not root:
        return

    signal_file = _dvclive_step_completed_signal_file(root)
    if os.path.exists(signal_file):
        os.remove(signal_file)


def mark_dvclive_only_started(exp_name: str) -> None:
    """
    Signal DVC VS Code extension that
    an experiment is running in the workspace.
    """
    root_dir = _find_dvc_root()
    if not root_dir:
        return

    os.makedirs(_dvc_exps_run_dir(root_dir), exist_ok=True)
    _write_file(
        _dvclive_only_signal_file(root_dir),
        {"pid": os.getpid(), "exp_name": exp_name},
    )


def mark_dvclive_only_ended() -> None:
    """Remove the workspace-experiment signal file, if present."""
    root_dir = _find_dvc_root()
    if not root_dir:
        return

    signal_file = _dvclive_only_signal_file(root_dir)
    if os.path.exists(signal_file):
        os.remove(signal_file)


class DVCLiveCallback(TrainingCallback):
    """XGBoost callback that logs per-iteration eval metrics through DVCLive."""

    def __init__(
        self,
        metric_data: Optional[str] = None,
        live: Optional[Live] = None,
        **kwargs,
    ):
        super().__init__()
        if metric_data is not None:
            warn(
                "`metric_data` is deprecated and will be removed",
                category=DeprecationWarning,
                stacklevel=2,
            )
        self._metric_data = metric_data
        # Reuse a caller-provided Live instance, or build one from **kwargs.
        self.live = live if live is not None else Live(**kwargs)

    def after_iteration(self, model, epoch, evals_log):
        logs = evals_log
        if self._metric_data:
            # Deprecated behavior: restrict logging to one eval split,
            # logged without a split prefix.
            logs = {"": logs[self._metric_data]}
        for split, metrics in logs.items():
            for metric_name, history in metrics.items():
                name = f"{split}/{metric_name}" if split else metric_name
                # history holds one value per boosting round; log the newest.
                self.live.log_metric(name, history[-1])
        self.live.next_step()

    def after_training(self, model):
        self.live.end()
        return model
@pytest.fixture(autouse=True)
def _mocked_webbrowser_open(mocker):
    """Keep tests from opening a real browser window."""
    mocker.patch("webbrowser.open")


@pytest.fixture(autouse=True)
def _mocked_ci(monkeypatch):
    """Force non-CI behavior regardless of where the suite runs."""
    monkeypatch.setenv("CI", "false")


@pytest.fixture
def mocked_studio_post(mocker, monkeypatch):
    """Patch requests.post with a 200 response and set the Studio env vars."""
    response = mocker.MagicMock()
    response.status_code = 200
    post = mocker.patch("requests.post", return_value=response)
    for var, value in (
        (DVC_STUDIO_URL, "https://0.0.0.0"),
        (STUDIO_REPO_URL, "STUDIO_REPO_URL"),
        (DVC_STUDIO_TOKEN, "STUDIO_TOKEN"),
    ):
        monkeypatch.setenv(var, value)
    return post, response
"string": "abc", 42 | "bool": True, 43 | "dict": {"a": {"b": "c"}}, 44 | "list": [1, 2, 3], 45 | "namespace": Namespace(foo=Namespace(bar="buzz")), 46 | "layer": torch.nn.BatchNorm1d, 47 | "tensor": torch.empty(2, 2, 2), 48 | "array": np.empty([2, 2, 2]), 49 | } 50 | logger.log_hyperparams(hparams) 51 | 52 | 53 | def test_dvclive_finalize(monkeypatch, tmp_path, mocked_dvc_repo): 54 | """Test that the SummaryWriter closes in finalize.""" 55 | import dvclive 56 | 57 | monkeypatch.setattr(dvclive, "Live", Mock()) 58 | logger = DVCLiveLogger(dir=tmp_path) 59 | assert logger._experiment is None 60 | logger.finalize("any") 61 | 62 | # no log calls, no experiment created -> nothing to flush 63 | logger.experiment.assert_not_called() 64 | 65 | logger = DVCLiveLogger(dir=tmp_path) 66 | logger.log_hyperparams({"flush_me": 11.1}) # trigger creation of an experiment 67 | logger.finalize("any") 68 | 69 | # finalize flushes to experiment directory 70 | logger.experiment.end.assert_called() 71 | -------------------------------------------------------------------------------- /tests/frameworks/test_fastai.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from dvclive import Live 6 | from dvclive.plots.metric import Metric 7 | 8 | try: 9 | from fastai.callback.tracker import SaveModelCallback 10 | from fastai.tabular.all import ( 11 | Categorify, 12 | Normalize, 13 | ProgressCallback, 14 | TabularDataLoaders, 15 | accuracy, 16 | tabular_learner, 17 | ) 18 | 19 | from dvclive.fastai import DVCLiveCallback 20 | except ImportError: 21 | pytest.skip("skipping fastai tests", allow_module_level=True) 22 | 23 | 24 | @pytest.fixture 25 | def data_loader(): 26 | from pandas import DataFrame 27 | 28 | d = { 29 | "x1": [1, 1, 0, 0, 1, 1, 0, 0], 30 | "x2": [1, 0, 1, 0, 1, 0, 1, 0], 31 | "y": [1, 0, 0, 1, 1, 0, 0, 1], 32 | } 33 | df = DataFrame(d) 34 | return TabularDataLoaders.from_df( 35 | df, 36 | valid_idx=[4, 5, 
def test_fastai_callback(tmp_dir, data_loader, mocker):
    # End of training must call Live.end exactly once and persist the
    # params file plus per-split metric files under the live directory.
    learn = tabular_learner(data_loader, metrics=accuracy)
    learn.remove_cb(ProgressCallback)
    callback = DVCLiveCallback()
    live = callback.live

    spy = mocker.spy(live, "end")
    learn.fit_one_cycle(2, cbs=[callback])
    spy.assert_called_once()

    assert (tmp_dir / live.dir).exists()
    assert (tmp_dir / live.params_file).exists()
    assert (tmp_dir / live.params_file).read_text() == (
        "model: TabularModel\nbatch_size: 2\nbatch_per_epoch: 2\nfrozen: false"
        "\nfrozen_idx: 0\ntransforms: None\n"
    )

    metrics_path = tmp_dir / live.plots_dir / Metric.subfolder
    train_path = metrics_path / "train"
    valid_path = metrics_path / "eval"

    assert train_path.is_dir()
    assert valid_path.is_dir()
    assert (metrics_path / "accuracy.tsv").exists()
    # epoch drives the step counter, so it must not be logged as a metric
    assert not (metrics_path / "epoch.tsv").exists()


def test_fastai_pass_logger():
    # A caller-provided Live instance must be reused, not replaced.
    logger = Live("train_logs")

    assert DVCLiveCallback().live is not logger
    assert DVCLiveCallback(live=logger).live is logger


def test_fast_ai_resume(tmp_dir, data_loader, mocker):
    # First run: 2 epochs -> 2 next_step calls, 1 end call.
    learn = tabular_learner(data_loader, metrics=accuracy)
    learn.remove_cb(ProgressCallback)
    callback = DVCLiveCallback()
    live = callback.live

    spy = mocker.spy(live, "next_step")
    end = mocker.spy(live, "end")
    learn.fit_one_cycle(2, cbs=[callback])
    assert spy.call_count == 2
    assert end.call_count == 1

    # Resumed run starts at live.step, so only the remaining epoch
    # triggers next_step.
    callback = DVCLiveCallback(resume=True)
    live = callback.live
    spy = mocker.spy(live, "next_step")
    learn.fit_one_cycle(3, cbs=[callback], start_epoch=live.step)
    assert spy.call_count == 1


def test_fast_ai_avoid_unnecessary_end_calls(tmp_dir, data_loader, mocker):
    """
    `after_fit` might be called from different points and not all mean that the
    training has ended.
    """
    learn = tabular_learner(data_loader, metrics=accuracy)
    learn.remove_cb(ProgressCallback)
    callback = DVCLiveCallback()
    live = callback.live

    end = mocker.spy(live, "end")
    after_fit = mocker.spy(callback, "after_fit")
    # fine_tune runs a frozen then an unfrozen fit -> after_fit fires twice,
    # but end must still fire only once.
    learn.fine_tune(2, cbs=[callback])
    assert end.call_count == 1
    assert after_fit.call_count == 2


def test_fastai_save_model_callback(tmp_dir, data_loader, mocker):
    # SaveModelCallback's last saved checkpoint should be logged as an artifact.
    learn = tabular_learner(data_loader, metrics=accuracy)
    learn.remove_cb(ProgressCallback)
    learn.model_dir = os.path.abspath("./")

    save_callback = SaveModelCallback()
    live_callback = DVCLiveCallback()
    log_artifact = mocker.patch.object(live_callback.live, "log_artifact")
    learn.fit_one_cycle(2, cbs=[save_callback, live_callback])
    assert (tmp_dir / "model.pth").is_file()
    log_artifact.assert_called_with(str(save_callback.last_saved_path))
"""https://github.com/iterative/dvclive/pull/321#issuecomment-1266916039""" 29 | import time 30 | 31 | time.sleep(time.get_clock_info("time").resolution) 32 | return {"foo": 1} 33 | 34 | 35 | # From transformers/tests/trainer 36 | 37 | 38 | class RegressionDataset: 39 | def __init__(self, a=2, b=3, length=64, seed=42, label_names=None): 40 | np.random.seed(seed) 41 | self.label_names = ["labels"] if label_names is None else label_names 42 | self.length = length 43 | self.x = np.random.normal(size=(length,)).astype(np.float32) 44 | self.ys = [ 45 | a * self.x + b + np.random.normal(scale=0.1, size=(length,)) 46 | for _ in self.label_names 47 | ] 48 | self.ys = [y.astype(np.float32) for y in self.ys] 49 | 50 | def __len__(self): 51 | return self.length 52 | 53 | def __getitem__(self, i): 54 | result = {name: y[i] for name, y in zip(self.label_names, self.ys)} 55 | result["input_x"] = self.x[i] 56 | return result 57 | 58 | 59 | class RegressionModelConfig(PretrainedConfig): 60 | def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs): 61 | super().__init__(**kwargs) 62 | self.a = a 63 | self.b = b 64 | self.double_output = double_output 65 | self.random_torch = random_torch 66 | self.hidden_size = 1 67 | 68 | 69 | class RegressionPreTrainedModel(PreTrainedModel): 70 | config_class = RegressionModelConfig 71 | base_model_prefix = "regression" 72 | 73 | def __init__(self, config): 74 | super().__init__(config) 75 | self.a = nn.Parameter(torch.tensor(config.a).float()) 76 | self.b = nn.Parameter(torch.tensor(config.b).float()) 77 | self.double_output = config.double_output 78 | 79 | def forward(self, input_x, labels=None, **kwargs): 80 | y = input_x * self.a + self.b 81 | if labels is None: 82 | return (y, y) if self.double_output else (y,) 83 | loss = nn.functional.mse_loss(y, labels) 84 | return (loss, y, y) if self.double_output else (loss, y) 85 | 86 | 87 | @pytest.fixture 88 | def data(): 89 | return RegressionDataset(), RegressionDataset() 90 
| 91 | 92 | @pytest.fixture 93 | def model(): 94 | config = RegressionModelConfig() 95 | return RegressionPreTrainedModel(config) 96 | 97 | 98 | @pytest.fixture 99 | def args(): 100 | return TrainingArguments( 101 | "foo", 102 | eval_strategy="epoch", 103 | num_train_epochs=2, 104 | save_strategy="epoch", 105 | report_to="none", # Disable auto-reporting to avoid duplication 106 | use_cpu=True, 107 | ) 108 | 109 | 110 | @pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback]) 111 | def test_huggingface_integration(tmp_dir, model, args, data, mocker, callback): 112 | trainer = Trainer( 113 | model, 114 | args, 115 | train_dataset=data[0], 116 | eval_dataset=data[1], 117 | compute_metrics=compute_metrics, 118 | ) 119 | callback = callback() 120 | spy = mocker.spy(Live, "end") 121 | trainer.add_callback(callback) 122 | trainer.train() 123 | spy.assert_called_once() 124 | 125 | live = callback.live 126 | assert os.path.exists(live.dir) 127 | 128 | logs, _ = parse_metrics(live) 129 | 130 | scalars = os.path.join(live.plots_dir, Metric.subfolder) 131 | assert os.path.join(scalars, "eval", "foo.tsv") in logs 132 | assert os.path.join(scalars, "eval", "loss.tsv") in logs 133 | assert os.path.join(scalars, "train", "loss.tsv") in logs 134 | assert len(logs[os.path.join(scalars, "epoch.tsv")]) == 3 135 | assert len(logs[os.path.join(scalars, "eval", "loss.tsv")]) == 2 136 | 137 | params = load_yaml(live.params_file) 138 | assert params["num_train_epochs"] == 2 139 | 140 | 141 | @pytest.mark.parametrize("log_model", ["all", True, False, None]) 142 | @pytest.mark.parametrize("best", [True, False]) 143 | @pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback]) 144 | def test_huggingface_log_model( 145 | tmp_dir, 146 | mocked_dvc_repo, 147 | model, 148 | data, 149 | args, 150 | monkeypatch, 151 | mocker, 152 | log_model, 153 | best, 154 | callback, 155 | ): 156 | live = Live() 157 | log_artifact = mocker.patch.object(live, "log_artifact") 158 | 
if callback == ExternalCallback: 159 | monkeypatch.setenv("HF_DVCLIVE_LOG_MODEL", str(log_model)) 160 | live_callback = callback(live=live) 161 | else: 162 | live_callback = callback(live=live, log_model=log_model) 163 | 164 | args.load_best_model_at_end = best 165 | args.metric_for_best_model = "loss" 166 | 167 | trainer = Trainer( 168 | model, 169 | args, 170 | train_dataset=data[0], 171 | eval_dataset=data[1], 172 | compute_metrics=compute_metrics, 173 | ) 174 | trainer.add_callback(live_callback) 175 | trainer.train() 176 | 177 | expected_call_count = { 178 | "all": 2, 179 | True: 1, 180 | False: 0, 181 | None: 0, 182 | } 183 | assert log_artifact.call_count == expected_call_count[log_model] 184 | 185 | if log_model is True: 186 | name = "best" if best else "last" 187 | log_artifact.assert_called_with( 188 | os.path.join(args.output_dir, name), 189 | name=name, 190 | type="model", 191 | copy=True, 192 | ) 193 | 194 | 195 | @pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback]) 196 | def test_huggingface_pass_logger(callback): 197 | logger = Live("train_logs") 198 | 199 | assert callback().live is not logger 200 | assert callback(live=logger).live is logger 201 | 202 | 203 | @pytest.mark.parametrize("report_to", ["all", "dvclive", "none"]) 204 | def test_huggingface_report_to(model, report_to): 205 | args = TrainingArguments("foo", report_to=report_to) 206 | trainer = Trainer( 207 | model, 208 | args, 209 | ) 210 | live_cbs = [ 211 | cb 212 | for cb in trainer.callback_handler.callbacks 213 | if isinstance(cb, ExternalCallback) 214 | ] 215 | if report_to == "none": 216 | assert not any(live_cbs) 217 | else: 218 | assert any(live_cbs) 219 | -------------------------------------------------------------------------------- /tests/frameworks/test_keras.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from dvclive import Live 6 | from dvclive.plots.metric import Metric 
@pytest.fixture
def xor_model():
    # Factory fixture: returns a builder for a tiny XOR MLP plus its data,
    # so every test gets a fresh, untrained model.
    import numpy as np
    import tensorflow as tf

    def make():
        x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
        y = np.array([[0], [1], [1], [0]])

        model = tf.keras.Sequential()
        model.add(tf.keras.layers.Dense(8, input_dim=2))
        model.add(tf.keras.layers.Activation("relu"))
        model.add(tf.keras.layers.Dense(1))
        model.add(tf.keras.layers.Activation("sigmoid"))

        model.compile(loss="binary_crossentropy", optimizer="sgd", metrics=["accuracy"])

        return model, x, y

    return make


def test_keras_callback(tmp_dir, xor_model, mocker):
    model, x, y = xor_model()

    callback = DVCLiveCallback()
    live = callback.live
    spy = mocker.spy(live, "end")
    model.fit(
        x,
        y,
        epochs=1,
        batch_size=1,
        validation_split=0.2,
        callbacks=[callback],
    )
    # End of training should finalize the Live run exactly once.
    spy.assert_called_once()

    assert os.path.exists("dvclive")
    logs, _ = parse_metrics(callback.live)

    # Metrics are split per phase into train/ and eval/ subfolders.
    scalars = os.path.join(callback.live.plots_dir, Metric.subfolder)
    assert os.path.join(scalars, "train", "accuracy.tsv") in logs
    assert os.path.join(scalars, "eval", "accuracy.tsv") in logs


def test_keras_callback_pass_logger():
    # A caller-provided Live instance must be reused, not replaced.
    logger = Live("train_logs")

    assert DVCLiveCallback().live is not logger
    assert DVCLiveCallback(live=logger).live is logger
@pytest.fixture
def model_params():
    # n_estimators=5 keeps runs fast and pins the expected metric count.
    return {"objective": "multiclass", "n_estimators": 5, "seed": 0}


@pytest.fixture
def iris_data():
    # Iris split into ((x_train, y_train), (x_test, y_test)).
    iris = datasets.load_iris()
    x = pd.DataFrame(iris["data"], columns=iris["feature_names"])
    y = iris["target"]
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.33, random_state=42
    )
    return (x_train, y_train), (x_test, y_test)


@pytest.mark.skipif(platform == "darwin", reason="LIBOMP Segmentation fault on MacOS")
def test_lgbm_integration(tmp_dir, model_params, iris_data):
    model = lgbm.LGBMClassifier()
    model.set_params(**model_params)

    callback = DVCLiveCallback()
    model.fit(
        iris_data[0][0],
        iris_data[0][1],
        eval_set=(iris_data[1][0], iris_data[1][1]),
        eval_metric=["multi_logloss"],
        callbacks=[callback],
    )

    assert os.path.exists("dvclive")

    # A single eval set yields one metric file with one datapoint per
    # boosting round (n_estimators=5).
    logs, _ = parse_metrics(callback.live)
    assert "dvclive/plots/metrics/multi_logloss.tsv" in logs
    assert len(logs) == 1
    assert len(next(iter(logs.values()))) == 5


@pytest.mark.skipif(platform == "darwin", reason="LIBOMP Segmentation fault on MacOS")
def test_lgbm_integration_multi_eval(tmp_dir, model_params, iris_data):
    model = lgbm.LGBMClassifier()
    model.set_params(**model_params)

    callback = DVCLiveCallback()
    model.fit(
        iris_data[0][0],
        iris_data[0][1],
        eval_set=[
            (iris_data[0][0], iris_data[0][1]),
            (iris_data[1][0], iris_data[1][1]),
        ],
        eval_metric=["multi_logloss"],
        callbacks=[callback],
    )

    assert os.path.exists("dvclive")

    # With several eval sets, metrics are namespaced per split.
    logs, _ = parse_metrics(callback.live)
    assert "dvclive/plots/metrics/training/multi_logloss.tsv" in logs
    assert "dvclive/plots/metrics/valid_1/multi_logloss.tsv" in logs
    assert len(logs) == 2
    assert len(next(iter(logs.values()))) == 5


def test_lgbm_pass_logger():
    # A caller-provided Live instance must be reused, not replaced.
    logger = Live("train_logs")

    assert DVCLiveCallback().live is not logger
    assert DVCLiveCallback(live=logger).live is logger


class XORDataset(Dataset):
    # Four-sample XOR truth table exposed as a torch Dataset.
    def __init__(self, *args, **kwargs):
        self.ins = [[0, 0], [0, 1], [1, 0], [1, 1]]
        self.outs = [1, 0, 0, 1]

    def __getitem__(self, index):
        return torch.Tensor(self.ins[index]), torch.tensor(
            self.outs[index], dtype=torch.long
        )

    def __len__(self):
        return len(self.ins)
class LitXOR(LightningModule):
    # Minimal LightningModule learning XOR; hyperparameters are saved so
    # the logger's params handling can be asserted in the tests below.
    def __init__(
        self,
        latent_dims=4,
        optim=SGD,
        optim_params={"lr": 0.01},  # noqa: B006
        input_size=[256, 256, 256],  # noqa: B006
    ):
        super().__init__()

        self.save_hyperparameters()

        self.layer_1 = nn.Linear(2, latent_dims)
        self.layer_2 = nn.Linear(latent_dims, 2)

    def forward(self, *args, **kwargs):
        x = args[0]
        batch_size, _ = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return F.log_softmax(x, dim=1)

    def train_loader(self):
        dataset = XORDataset()
        return DataLoader(dataset, batch_size=1)

    def train_dataloader(self):
        return self.train_loader()

    def training_step(self, *args, **kwargs):
        batch = args[0]
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        # Log both per-step and per-epoch so both metric files are produced.
        self.log(
            "train_loss",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        return loss

    def configure_optimizers(self):
        return self.hparams.optim(self.parameters(), **self.hparams.optim_params)

    def predict_dataloader(self):
        pass

    def test_dataloader(self):
        pass

    def val_dataloader(self):
        pass


def test_lightning_integration(tmp_dir, mocker):
    # init model
    model = LitXOR(
        latent_dims=8, optim=Adam, optim_params={"lr": 0.02}, input_size=[128, 128, 128]
    )
    # init logger
    dvclive_logger = DVCLiveLogger("test_run", dir="logs")
    live = dvclive_logger.experiment
    spy = mocker.spy(live, "end")
    trainer = Trainer(
        logger=dvclive_logger,
        max_epochs=2,
        enable_checkpointing=False,
        log_every_n_steps=1,
    )
    trainer.fit(model)
    spy.assert_called_once()

    assert os.path.exists("logs")
    assert not os.path.exists("DvcLiveLogger")

    scalars = os.path.join(dvclive_logger.experiment.plots_dir, Metric.subfolder)
    logs, _ = parse_metrics(dvclive_logger.experiment)

    # Loss is logged per-step and per-epoch, plus the epoch counter itself.
    assert len(logs) == 3
    assert os.path.join(scalars, "train", "epoch", "loss.tsv") in logs
    assert os.path.join(scalars, "train", "step", "loss.tsv") in logs
    assert os.path.join(scalars, "epoch.tsv") in logs

    params_file = dvclive_logger.experiment.params_file
    assert os.path.exists(params_file)
    assert load_yaml(params_file) == {
        "latent_dims": 8,
        "optim": "Adam",
        "optim_params": {"lr": 0.02},
        "input_size": [128, 128, 128],
    }


def test_lightning_default_dir(tmp_dir):
    model = LitXOR()
    # If `dir` is not provided handle it properly, use default value
    dvclive_logger = DVCLiveLogger("test_run")
    trainer = Trainer(
        logger=dvclive_logger,
        max_epochs=2,
        enable_checkpointing=False,
        log_every_n_steps=1,
    )
    trainer.fit(model)

    assert os.path.exists("dvclive")
@pytest.mark.parametrize("log_model", [False, True, "all"])
@pytest.mark.parametrize("save_top_k", [1, -1])
def test_lightning_log_model(tmp_dir, mocker, log_model, save_top_k):
    model = LitXOR()
    dvclive_logger = DVCLiveLogger(dir="dir", log_model=log_model)
    checkpoint = ModelCheckpoint(dirpath="model", save_top_k=save_top_k)
    trainer = Trainer(
        logger=dvclive_logger,
        max_epochs=2,
        log_every_n_steps=1,
        callbacks=[checkpoint],
    )
    log_artifact = mocker.patch.object(dvclive_logger.experiment, "log_artifact")
    trainer.fit(model)

    # Check that log_artifact is called.
    if log_model is False:
        log_artifact.assert_not_called()
    elif (log_model is True) and (save_top_k != -1):
        # called once to cache, then again to log best artifact
        assert log_artifact.call_count == 2
    else:
        # once per epoch plus two calls at the end (see above)
        assert log_artifact.call_count == 4

    # Check that checkpoint files does not grow with each run.
    num_checkpoints = len(os.listdir(tmp_dir / "model"))
    if log_model in [True, "all"]:
        trainer.fit(model)
        assert len(os.listdir(tmp_dir / "model")) == num_checkpoints
        log_artifact.assert_any_call(
            checkpoint.best_model_path, name="best", type="model", copy=True
        )


def test_lightning_steps(tmp_dir, mocker):
    model = LitXOR()
    # Handle kwargs passed to Live.
    dvclive_logger = DVCLiveLogger(dir="logs")
    live = dvclive_logger.experiment
    spy = mocker.spy(live, "sync")
    trainer = Trainer(
        logger=dvclive_logger,
        max_epochs=2,
        enable_checkpointing=False,
        # Log one time in the middle of the epoch
        log_every_n_steps=3,
    )
    trainer.fit(model)

    # 4 samples * 2 epochs -> last recorded step index is 7.
    history, latest = parse_metrics(dvclive_logger.experiment)
    assert latest["step"] == 7
    assert latest["epoch"] == 1

    scalars = os.path.join(dvclive_logger.experiment.plots_dir, Metric.subfolder)
    epoch_loss = history[os.path.join(scalars, "train", "epoch", "loss.tsv")]
    step_loss = history[os.path.join(scalars, "train", "step", "loss.tsv")]
    assert len(epoch_loss) == 2
    assert len(step_loss) == 2

    # call sync:
    # - 2x epoch end
    # - 2x log_every_n_steps
    # - 1x experiment end
    assert spy.call_count == 5


class ValLitXOR(LitXOR):
    # LitXOR variant with a validation loop; logs val_loss once per epoch.
    def val_loader(self):
        dataset = XORDataset()
        return DataLoader(dataset, batch_size=1)

    def val_dataloader(self):
        return self.val_loader()

    def training_step(self, *args, **kwargs):
        batch = args[0]
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss, on_step=True)
        return loss

    def validation_step(self, *args, **kwargs):
        batch = args[0]
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        return loss
273 | """ 274 | init = mocker.spy(Live, "__init__") 275 | DVCLiveLogger() 276 | init.assert_not_called() 277 | 278 | 279 | # LightningCLI tests 280 | # Copied from https://github.com/Lightning-AI/lightning/blob/e7afe04ee86b64c76a5446088b3b75d9c275e5bf/tests/tests_pytorch/test_cli.py 281 | class TestModel(BoringModel): 282 | def __init__(self, foo, bar=5): 283 | super().__init__() 284 | self.foo = foo 285 | self.bar = bar 286 | 287 | 288 | def _test_logger_init_args(logger_name, init, unresolved={}): # noqa: B006 289 | cli_args = [f"--trainer.logger={logger_name}"] 290 | cli_args += [f"--trainer.logger.{k}={v}" for k, v in init.items()] 291 | cli_args += [f"--trainer.logger.dict_kwargs.{k}={v}" for k, v in unresolved.items()] 292 | cli_args.append("--print_config") 293 | 294 | out = StringIO() 295 | with ( 296 | mock.patch( 297 | "sys.argv", 298 | ["any.py"] + cli_args, # noqa: RUF005 299 | ), 300 | redirect_stdout( # noqa: RUF100 301 | out 302 | ), 303 | pytest.raises(SystemExit), 304 | ): 305 | LightningCLI(TestModel, run=False) 306 | 307 | data = yaml.safe_load(out.getvalue())["trainer"]["logger"] 308 | assert {k: data["init_args"][k] for k in init} == init 309 | if unresolved: 310 | assert data["dict_kwargs"] == unresolved 311 | 312 | 313 | def test_dvclive_logger_init_args(): 314 | _test_logger_init_args( 315 | "dvclive.lightning.DVCLiveLogger", 316 | { 317 | "run_name": "test_run", # Resolve from DVCLiveLogger.__init__ 318 | "dir": "results", # Resolve from Live.__init__ 319 | }, 320 | ) 321 | -------------------------------------------------------------------------------- /tests/frameworks/test_optuna.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from dvclive.serialize import load_yaml 4 | from dvclive.utils import parse_json 5 | 6 | try: 7 | import optuna 8 | 9 | from dvclive.optuna import DVCLiveCallback 10 | except ImportError: 11 | pytest.skip("skipping optuna tests", 
                allow_module_level=True)


# Simple quadratic objective with its minimum at x == 2.
def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2


def test_optuna_(tmp_dir, mocked_dvc_repo):
    # One DVC experiment should be saved per Optuna trial, and the logged
    # metric/param should land in the "dvclive-optuna" directory.
    n_trials = 5
    metric_name = "custom_name"
    callback = DVCLiveCallback(metric_name=metric_name)
    study = optuna.create_study()

    study.optimize(objective, n_trials=n_trials, callbacks=[callback])

    assert mocked_dvc_repo.experiments.save.call_count == n_trials

    metrics = parse_json("dvclive-optuna/metrics.json")
    assert metric_name in metrics
    params = load_yaml("dvclive-optuna/params.yaml")
    assert "x" in params

    # The Optuna callback logs no step-wise plots.
    assert not (tmp_dir / "dvclive-optuna" / "plots").exists()
--------------------------------------------------------------------------------
/tests/frameworks/test_xgboost.py:
--------------------------------------------------------------------------------
import os
from contextlib import nullcontext

import pytest

from dvclive import Live
from dvclive.plots.metric import Metric
from dvclive.utils import parse_metrics

try:
    import pandas as pd
    import xgboost as xgb
    from sklearn import datasets
    from sklearn.model_selection import train_test_split

    from dvclive.xgb import DVCLiveCallback
except ImportError:
    pytest.skip("skipping xgboost tests", allow_module_level=True)


@pytest.fixture
def train_params():
    # Minimal multiclass training config for the 3-class iris problem.
    return {"objective": "multi:softmax", "num_class": 3, "seed": 0}


@pytest.fixture
def iris_data():
    iris = datasets.load_iris()
    x = pd.DataFrame(iris["data"], columns=iris["feature_names"])
    y = iris["target"]
    return xgb.DMatrix(x, y)


@pytest.fixture
def iris_train_eval_data():
    # Iris split into a (train, eval) pair of DMatrix objects.
    iris = datasets.load_iris()
    x_train, x_eval, y_train, y_eval = train_test_split(
        iris.data, iris.target, random_state=0
    )
    return (xgb.DMatrix(x_train, y_train), xgb.DMatrix(x_eval, y_eval))


@pytest.mark.parametrize(
    ("metric_data", "subdirs", "context"),
    [
        (
            # Legacy positional `metric_data` emits a DeprecationWarning and
            # logs a single (unprefixed) metrics folder.
            "eval",
            ("",),
            pytest.warns(DeprecationWarning, match="`metric_data`.+deprecated"),
        ),
        # Modern usage logs one subfolder per evals entry.
        (None, ("train", "eval"), nullcontext()),
    ],
)
def test_xgb_integration(
    tmp_dir, train_params, iris_train_eval_data, metric_data, subdirs, context, mocker
):
    with context:
        callback = DVCLiveCallback(metric_data)
    live = callback.live
    spy = mocker.spy(live, "end")
    data_train, data_eval = iris_train_eval_data
    xgb.train(
        train_params,
        data_train,
        callbacks=[callback],
        num_boost_round=5,
        evals=[(data_train, "train"), (data_eval, "eval")],
    )
    # The callback must end the Live run exactly once after training.
    spy.assert_called_once()

    assert os.path.exists("dvclive")

    # One mlogloss series per subdir, with one row per boosting round.
    logs, _ = parse_metrics(callback.live)
    assert len(logs) == len(subdirs)
    assert list(map(len, logs.values())) == [5] * len(logs)
    scalars = os.path.join(callback.live.plots_dir, Metric.subfolder)
    assert all(
        os.path.join(scalars, subdir, "mlogloss.tsv") in logs for subdir in subdirs
    )


def test_xgb_pass_logger():
    # An explicitly passed Live instance is used as-is; otherwise a new one
    # is created internally.
    logger = Live("train_logs")

    assert DVCLiveCallback("eval_data").live is not logger
    assert DVCLiveCallback("eval_data", live=logger).live is logger
--------------------------------------------------------------------------------
/tests/plots/test_custom.py:
--------------------------------------------------------------------------------
import json

from dvclive import Live
from dvclive.plots.custom import CustomPlot


def test_log_custom_plot(tmp_dir):
    # Datapoints are dumped as JSON and the plot config is kept on the plot.
    live = Live()
    out = tmp_dir / live.plots_dir / CustomPlot.subfolder

    datapoints = [{"x": 1, "y": 2}, {"x": 3, "y": 4}]
    live.log_plot(
        "custom_linear",
        datapoints,
        x="x",
        y="y",
        template="linear",
        title="custom_title",
        x_label="x_label",
        y_label="y_label",
    )

    assert json.loads((out / "custom_linear.json").read_text()) == datapoints
    assert live._plots["custom_linear"].plot_config == {
        "template": "linear",
        "title": "custom_title",
        "x": "x",
        "y": "y",
        "x_label": "x_label",
        "y_label": "y_label",
    }


def test_log_custom_plot_multi_y(tmp_dir):
    # `y` may be a list of field names; it is stored as-is in plot_config.
    live = Live()
    out = tmp_dir / live.plots_dir / CustomPlot.subfolder

    datapoints = [{"x": 1, "y1": 2, "y2": 3}, {"x": 4, "y1": 5, "y2": 6}]
    live.log_plot(
        "custom_linear",
        datapoints,
        x="x",
        y=["y1", "y2"],
        template="linear",
        title="custom_title",
        x_label="x_label",
        y_label="y_label",
    )

    assert json.loads((out / "custom_linear.json").read_text()) == datapoints
    assert live._plots["custom_linear"].plot_config == {
        "template": "linear",
        "title": "custom_title",
        "x": "x",
        "y": ["y1", "y2"],
        "x_label": "x_label",
        "y_label": "y_label",
    }


def test_log_custom_plot_with_template_as_empty_string(tmp_dir):
    # An empty template string is treated as "not set".
    live = Live()
    out = tmp_dir / live.plots_dir / CustomPlot.subfolder

    datapoints = [{"x": 1, "y": 2}, {"x": 3, "y": 4}]
    live.log_plot(
        "custom_linear",
        datapoints,
        x="x",
        y="y",
        template="",
        title="custom_title",
        x_label="x_label",
        y_label="y_label",
    )

    assert json.loads((out / "custom_linear.json").read_text()) == datapoints
    # 'template' should not be in plot_config. Default template will be assigned later.
    assert live._plots["custom_linear"].plot_config == {
        "title": "custom_title",
        "x": "x",
        "y": "y",
        "x_label": "x_label",
        "y_label": "y_label",
    }
--------------------------------------------------------------------------------
/tests/plots/test_image.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np
import pytest
from PIL import Image

from dvclive import Live
from dvclive.error import InvalidImageNameError
from dvclive.plots import Image as LiveImage


# From https://stackoverflow.com/questions/5165317/how-can-i-extend-image-class
class ExtendedImage(Image.Image):
    def __init__(self, img):
        self._img = img

    def __getattr__(self, key):
        # Delegate everything to the wrapped PIL image.
        return getattr(self._img, key)


def test_pil(tmp_dir):
    live = Live()
    img = Image.new("RGB", (10, 10), (250, 250, 250))
    live.log_image("image.png", img)

    assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()


def test_pil_omitting_extension_doesnt_save_without_valid_format(tmp_dir):
    # A name without an extension is rejected when no format can be inferred.
    live = Live()
    img = Image.new("RGB", (10, 10), (250, 250, 250))
    with pytest.raises(
        InvalidImageNameError, match="Cannot log image with name 'whoops'"
    ):
        live.log_image("whoops", img)


def test_pil_omitting_extension_sets_the_format_if_path_given(tmp_dir):
    live = Live()
    img = Image.new("RGB", (10, 10), (250, 250, 250))

    # Save it first, we'll reload it and pass it's path to log_image again
    live.log_image("saved_with_format.png", img)

    # Now try saving without explicit format and check if the format is set correctly.
    live.log_image(
        "whoops",
        (tmp_dir / live.plots_dir / LiveImage.subfolder / "saved_with_format.png"),
    )

    # The format of the source path is applied to the extensionless name.
    assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "whoops.png").exists()


def test_invalid_extension(tmp_dir):
    live = Live()
    img = Image.new("RGB", (10, 10), (250, 250, 250))
    with pytest.raises(
        InvalidImageNameError, match="Cannot log image with name 'image.foo'"
    ):
        live.log_image("image.foo", img)


# grayscale, RGB and RGBA arrays should all round-trip through log_image
@pytest.mark.parametrize("shape", [(10, 10), (10, 10, 3), (10, 10, 4)])
def test_numpy(tmp_dir, shape):
    from PIL import Image as ImagePIL

    live = Live()
    img = np.ones(shape, np.uint8) * 255
    live.log_image("image.png", img)

    img_path = tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png"
    assert img_path.exists()

    val = np.asarray(ImagePIL.open(img_path))
    assert np.array_equal(val, img)


def test_path(tmp_dir):
    import numpy as np
    from PIL import Image as ImagePIL

    # NOTE(review): this first Live() looks redundant — it is immediately
    # replaced below and its instance is never used; confirm intent.
    live = Live()
    image_data = np.random.randint(0, 255, (100, 100, 3)).astype(np.uint8)
    pil_image = ImagePIL.fromarray(image_data)
    image_path = tmp_dir / "temp.png"
    pil_image.save(image_path)

    live = Live()
    live.log_image("foo.png", image_path)
    live.end()

    plot_file = tmp_dir / live.plots_dir / "images" / "foo.png"
    assert plot_file.exists()

    val = np.asarray(ImagePIL.open(plot_file))
    assert np.array_equal(val, image_data)


def test_override_on_step(tmp_dir):
    # Re-logging the same image name on a later step overwrites the file.
    live = Live()

    zeros = np.zeros((2, 2, 3), np.uint8)
    live.log_image("image.png", zeros)

    live.next_step()

    ones = np.ones((2, 2, 3), np.uint8)
    live.log_image("image.png", ones)

    img_path = tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png"
    assert np.array_equal(np.array(Image.open(img_path)), ones)


def test_cleanup(tmp_dir):
    # Starting a fresh Live run wipes previously logged image plots.
    live = Live()
    img = np.ones((10, 10, 3), np.uint8)
    live.log_image("image.png", img)

    assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()

    Live()

    assert not (tmp_dir / live.plots_dir / LiveImage.subfolder).exists()


def test_custom_class(tmp_dir):
    # Image-like wrappers that delegate to a PIL image are accepted too.
    live = Live()
    img = Image.new("RGB", (10, 10), (250, 250, 250))
    extended_img = ExtendedImage(img)
    live.log_image("image.png", extended_img)

    assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()


def test_matplotlib(tmp_dir):
    # Logging a figure saves it and closes it (fignum no longer exists).
    live = Live()
    fig, ax = plt.subplots()
    ax.plot([1, 2, 3, 4])

    assert plt.fignum_exists(fig.number)

    live.log_image("image.png", fig)

    assert not plt.fignum_exists(fig.number)

    assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()


@pytest.mark.parametrize("cache", [False, True])
def test_cache_images(tmp_dir, dvc_repo, cache):
    # cache_images=True should produce an images.dvc tracking file on end().
    live = Live(save_dvc_exp=False, cache_images=cache)
    img = Image.new("RGB", (10, 10), (250, 250, 250))
    live.log_image("image.png", img)
    live.end()
    assert (tmp_dir / "dvclive" / "plots" / "images.dvc").exists() == cache
--------------------------------------------------------------------------------
/tests/plots/test_metric.py:
--------------------------------------------------------------------------------
import json

import numpy as np
import pytest

from dvclive import Live
from dvclive.plots.metric import Metric
from dvclive.plots.utils import NUMPY_INTS, NUMPY_SCALARS
from dvclive.utils import parse_tsv


@pytest.mark.parametrize("dtype", NUMPY_SCALARS)
def test_numpy(tmp_dir, dtype):
    # Numpy scalars of every supported dtype serialize to plain int/float.
    scalar = np.random.rand(1).astype(dtype)[0]
    live = Live()

    live.log_metric("scalar", scalar)
    live.next_step()

    parsed = json.loads((tmp_dir / live.metrics_file).read_text())
    assert isinstance(parsed["scalar"], int if dtype in NUMPY_INTS else float)
    tsv_file = tmp_dir / live.plots_dir / Metric.subfolder / "scalar.tsv"
    tsv_val = parse_tsv(tsv_file)[0]["scalar"]
    assert tsv_val == str(scalar)


def test_name_with_dot(tmp_dir):
    """Regression test for #284"""
    live = Live()

    live.log_metric("scalar.foo.bar", 1.0)
    live.next_step()

    # Dots in the metric name must not be treated as path separators.
    tsv_file = tmp_dir / live.plots_dir / Metric.subfolder / "scalar.foo.bar.tsv"
    assert tsv_file.exists()
    tsv_val = parse_tsv(tsv_file)[0]["scalar.foo.bar"]
    assert tsv_val == "1.0"
--------------------------------------------------------------------------------
/tests/plots/test_sklearn.py:
--------------------------------------------------------------------------------
# ruff: noqa: N806
import json

import pytest
from sklearn import calibration, metrics

from dvclive import Live
from dvclive.plots.sklearn import SKLearnPlot


@pytest.fixture
def y_true_y_pred_y_score():
    # Small binary classification problem; returns ground truth, hard
    # predictions and positive-class scores for the held-out split.
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import train_test_split

    X, y = make_classification(random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = RandomForestClassifier(random_state=0)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_test)
    y_score = clf.predict_proba(X_test)[:, 1]

    return y_test, y_pred, y_score


def test_log_calibration_curve(tmp_dir, y_true_y_pred_y_score, mocker):
    live = Live()
    out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder

    y_true, _, y_score = y_true_y_pred_y_score

    spy = mocker.spy(calibration, "calibration_curve")

    live.log_sklearn_plot("calibration", y_true, y_score)

    spy.assert_called_once_with(y_true, y_score)

    assert (out / "calibration.json").exists()


def test_log_det_curve(tmp_dir, y_true_y_pred_y_score, mocker):
    live = Live()
    out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder

    y_true, _, y_score = y_true_y_pred_y_score

    spy = mocker.spy(metrics, "det_curve")

    live.log_sklearn_plot("det", y_true, y_score)

    spy.assert_called_once_with(y_true, y_score)
    assert (out / "det.json").exists()


def test_log_roc_curve(tmp_dir, y_true_y_pred_y_score, mocker):
    live = Live()
    out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder

    y_true, _, y_score = y_true_y_pred_y_score

    spy = mocker.spy(metrics, "roc_curve")

    live.log_sklearn_plot("roc", y_true, y_score)

    spy.assert_called_once_with(y_true, y_score)
    assert (out / "roc.json").exists()


def test_log_prc_curve(tmp_dir, y_true_y_pred_y_score, mocker):
    live = Live()
    out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder

    y_true, _, y_score = y_true_y_pred_y_score

    spy = mocker.spy(metrics, "precision_recall_curve")

    live.log_sklearn_plot("precision_recall", y_true, y_score)

    spy.assert_called_once_with(y_true=y_true, probas_pred=y_score)
    assert (out / "precision_recall.json").exists()


def test_log_confusion_matrix(tmp_dir, y_true_y_pred_y_score, mocker):
    # Confusion matrix is stored as a list of {"actual", "predicted"} rows.
    live = Live()
    out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder

    y_true, y_pred, _ = y_true_y_pred_y_score

    live.log_sklearn_plot("confusion_matrix", y_true, y_pred)

    cm = json.loads((out / "confusion_matrix.json").read_text())

    assert isinstance(cm, list)
    assert isinstance(cm[0], dict)
    assert cm[0]["actual"] == str(y_true[0])
    assert cm[0]["predicted"] == str(y_pred[0])


def test_dump_kwargs(tmp_dir, y_true_y_pred_y_score, mocker):
    # Extra kwargs are forwarded verbatim to the sklearn function.
    live = Live()

    y_true, _, y_score = y_true_y_pred_y_score

    spy = mocker.spy(metrics, "roc_curve")

    live.log_sklearn_plot("roc", y_true, y_score, drop_intermediate=True)

    spy.assert_called_once_with(y_true, y_score, drop_intermediate=True)


def test_override_on_step(tmp_dir):
    # Re-logging the same plot on a later step overwrites the file.
    live = Live()

    live.log_sklearn_plot("confusion_matrix", [0, 0], [0, 0])
    live.next_step()
    live.log_sklearn_plot("confusion_matrix", [0, 0], [1, 1])

    plot_path = tmp_dir / live.plots_dir / SKLearnPlot.subfolder
    plot_path = plot_path / "confusion_matrix.json"

    assert json.loads(plot_path.read_text()) == [
        {"actual": "0", "predicted": "1"},
        {"actual": "0", "predicted": "1"},
    ]


def test_cleanup(tmp_dir, y_true_y_pred_y_score):
    # Starting a fresh Live run wipes previously logged sklearn plots.
    live = Live()
    out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder

    y_true, y_pred, _ = y_true_y_pred_y_score

    live.log_sklearn_plot("confusion_matrix", y_true, y_pred)

    assert (out / "confusion_matrix.json").exists()

    Live()

    assert not (tmp_dir / live.plots_dir / SKLearnPlot.subfolder).exists()


def test_custom_name(tmp_dir, y_true_y_pred_y_score):
    live = Live()
    out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder

    y_true, y_pred, _ = y_true_y_pred_y_score

    live.log_sklearn_plot("confusion_matrix", y_true, y_pred, name="train/cm")
    live.log_sklearn_plot("confusion_matrix", y_true, y_pred, name="val/cm")
    # ".json" should be stripped from the name
    live.log_sklearn_plot("confusion_matrix", y_true, y_pred, name="cm.json")

    assert (out / "train" / "cm.json").exists()
    assert (out / "val" / "cm.json").exists()
    assert (out / "cm.json").exists()


def test_custom_title(tmp_dir, y_true_y_pred_y_score):
    """https://github.com/iterative/dvclive/issues/453"""
    live = Live()
    out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder

    y_true, y_pred, y_score = y_true_y_pred_y_score

    live.log_sklearn_plot(
        "confusion_matrix",
        y_true,
        y_pred,
        name="train/cm",
        title="Train Confusion Matrix",
    )
    live.log_sklearn_plot(
        "confusion_matrix", y_true, y_pred, name="val/cm", title="Val Confusion Matrix"
    )
    live.log_sklearn_plot(
        "precision_recall",
        y_true,
        y_score,
        name="val/prc",
        title="Val Precision Recall",
    )
    assert (out / "train" / "cm.json").exists()
    assert (out / "val" / "cm.json").exists()
    assert (out / "val" / "prc.json").exists()

    assert live._plots["train/cm"].plot_config["title"] == "Train Confusion Matrix"
    assert live._plots["val/cm"].plot_config["title"] == "Val Confusion Matrix"
    assert live._plots["val/prc"].plot_config["title"] == "Val Precision Recall"


def test_custom_labels(tmp_dir, y_true_y_pred_y_score):
    """https://github.com/iterative/dvclive/issues/453"""
    live = Live()
    out = tmp_dir / live.plots_dir / SKLearnPlot.subfolder

    y_true, _, y_score = y_true_y_pred_y_score

    live.log_sklearn_plot(
        "precision_recall",
        y_true,
        y_score,
        name="val/prc",
        x_label="x_test",
        y_label="y_test",
    )
    assert (out / "val" / "prc.json").exists()

    assert live._plots["val/prc"].plot_config["x_label"] == "x_test"
    assert live._plots["val/prc"].plot_config["y_label"] == "y_test"
--------------------------------------------------------------------------------
/tests/test_cleanup.py:
--------------------------------------------------------------------------------
import pytest

from dvclive import Live
from dvclive.plots import Metric

@pytest.mark.parametrize( 8 | "html", 9 | [True, False], 10 | ) 11 | @pytest.mark.parametrize( 12 | "dvcyaml", 13 | ["dvc.yaml", "logs/dvc.yaml"], 14 | ) 15 | def test_cleanup(tmp_dir, html, dvcyaml): 16 | dvclive = Live("logs", report="html" if html else None, dvcyaml=dvcyaml) 17 | dvclive.log_metric("m1", 1) 18 | dvclive.next_step() 19 | 20 | html_path = tmp_dir / dvclive.dir / "report.html" 21 | if html: 22 | html_path.touch() 23 | 24 | (tmp_dir / "logs" / "some_user_file.txt").touch() 25 | (tmp_dir / "dvc.yaml").touch() 26 | 27 | assert (tmp_dir / dvclive.plots_dir / Metric.subfolder / "m1.tsv").is_file() 28 | assert (tmp_dir / dvclive.metrics_file).is_file() 29 | assert (tmp_dir / dvclive.dvc_file).is_file() 30 | assert html_path.is_file() == html 31 | 32 | dvclive = Live("logs") 33 | 34 | assert (tmp_dir / "logs" / "some_user_file.txt").is_file() 35 | assert not (tmp_dir / dvclive.plots_dir / Metric.subfolder).exists() 36 | assert not (tmp_dir / dvclive.metrics_file).is_file() 37 | if dvcyaml == "dvc.yaml": 38 | assert (tmp_dir / dvcyaml).is_file() 39 | if dvcyaml == "logs/dvc.yaml": 40 | assert not (tmp_dir / dvcyaml).is_file() 41 | assert not (html_path).is_file() 42 | -------------------------------------------------------------------------------- /tests/test_context_manager.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from dvclive import Live 4 | from dvclive.plots import Metric 5 | 6 | 7 | def test_context_manager(tmp_dir): 8 | with Live(report="html") as live: 9 | live.summary["foo"] = 1.0 10 | 11 | assert json.loads((tmp_dir / live.metrics_file).read_text()) == { 12 | # no `step` 13 | "foo": 1.0 14 | } 15 | log_file = tmp_dir / live.plots_dir / Metric.subfolder / "foo.tsv" 16 | assert not log_file.exists() 17 | report_file = tmp_dir / live.report_file 18 | assert report_file.exists() 19 | 20 | 21 | def test_context_manager_skips_end_calls(tmp_dir): 22 | with Live() as live: 23 | 
live.summary["foo"] = 1.0 24 | live.end() 25 | assert not (tmp_dir / live.metrics_file).exists() 26 | assert (tmp_dir / live.metrics_file).exists() 27 | -------------------------------------------------------------------------------- /tests/test_dvc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from dvc.exceptions import DvcException 5 | from dvc.repo import Repo 6 | from dvc.scm import NoSCM 7 | from scmrepo.git import Git 8 | 9 | from dvclive import Live 10 | from dvclive.dvc import get_dvc_repo 11 | from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT, DVCLIVE_TEST 12 | 13 | 14 | def test_get_dvc_repo(tmp_dir): 15 | assert get_dvc_repo() is None 16 | Git.init(tmp_dir) 17 | assert isinstance(get_dvc_repo(), Repo) 18 | 19 | 20 | def test_get_dvc_repo_subdir(tmp_dir): 21 | Git.init(tmp_dir) 22 | subdir = tmp_dir / "sub" 23 | subdir.mkdir() 24 | os.chdir(subdir) 25 | assert get_dvc_repo().root_dir == str(tmp_dir) 26 | 27 | 28 | @pytest.mark.parametrize("save", [True, False]) 29 | def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo): 30 | live = Live(save_dvc_exp=save) 31 | live.end() 32 | assert live._baseline_rev is not None 33 | assert live._exp_name is not None 34 | if save: 35 | mocked_dvc_repo.experiments.save.assert_called_with( 36 | name=live._exp_name, 37 | include_untracked=[live.dir, "dvc.yaml"], 38 | force=True, 39 | message=None, 40 | ) 41 | else: 42 | mocked_dvc_repo.experiments.save.assert_not_called() 43 | 44 | 45 | def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch): 46 | monkeypatch.setenv(DVC_EXP_BASELINE_REV, "foo") 47 | monkeypatch.setenv(DVC_EXP_NAME, "bar") 48 | monkeypatch.setenv(DVC_ROOT, tmp_dir) 49 | 50 | live = Live() 51 | live.end() 52 | 53 | assert live._dvc_repo is None 54 | assert live._baseline_rev == "foo" 55 | assert live._exp_name == "bar" 56 | assert live._inside_dvc_exp 57 | assert live._inside_dvc_pipeline 58 | 59 | 60 | def 
test_exp_save_with_dvc_files(tmp_dir, mocker): 61 | dvc_repo = mocker.MagicMock() 62 | dvc_file = mocker.MagicMock() 63 | dvc_file.is_data_source = True 64 | dvc_repo.index.stages = [dvc_file] 65 | dvc_repo.scm.get_rev.return_value = "current_rev" 66 | dvc_repo.scm.get_ref.return_value = None 67 | dvc_repo.scm.no_commits = False 68 | dvc_repo.root_dir = tmp_dir 69 | dvc_repo.config = {} 70 | 71 | mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo) 72 | live = Live() 73 | live.end() 74 | 75 | dvc_repo.experiments.save.assert_called_with( 76 | name=live._exp_name, 77 | include_untracked=[live.dir, "dvc.yaml"], 78 | force=True, 79 | message=None, 80 | ) 81 | 82 | 83 | def test_exp_save_dvcexception_is_ignored(tmp_dir, mocker): 84 | from dvc.exceptions import DvcException 85 | 86 | dvc_repo = mocker.MagicMock() 87 | dvc_repo.index.stages = [] 88 | dvc_repo.scm.get_rev.return_value = "current_rev" 89 | dvc_repo.scm.get_ref.return_value = None 90 | dvc_repo.config = {} 91 | dvc_repo.experiments.save.side_effect = DvcException("foo") 92 | mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo) 93 | 94 | with Live(): 95 | pass 96 | 97 | 98 | def test_untracked_dvclive_files_inside_dvc_exp_run_are_added( 99 | tmp_dir, mocked_dvc_repo, monkeypatch 100 | ): 101 | monkeypatch.setenv(DVC_EXP_BASELINE_REV, "foo") 102 | monkeypatch.setenv(DVC_EXP_NAME, "bar") 103 | monkeypatch.setenv(DVC_ROOT, tmp_dir) 104 | plot_file = os.path.join("dvclive", "plots", "metrics", "foo.tsv") 105 | mocked_dvc_repo.scm.untracked_files.return_value = [ 106 | "dvclive/metrics.json", 107 | plot_file, 108 | ] 109 | with Live() as live: 110 | live.log_metric("foo", 1) 111 | live.next_step() 112 | live._dvc_repo.scm.add.assert_any_call(["dvclive/metrics.json", plot_file]) 113 | live._dvc_repo.scm.add.assert_any_call(live.dvc_file) 114 | 115 | 116 | def test_dvc_outs_are_not_added(tmp_dir, mocked_dvc_repo, monkeypatch): 117 | """Regression test for 
https://github.com/iterative/dvclive/issues/516""" 118 | monkeypatch.setenv(DVC_EXP_BASELINE_REV, "foo") 119 | monkeypatch.setenv(DVC_EXP_NAME, "bar") 120 | monkeypatch.setenv(DVC_ROOT, tmp_dir) 121 | mocked_dvc_repo.index.outs = ["dvclive/plots"] 122 | plot_file = os.path.join("dvclive", "plots", "metrics", "foo.tsv") 123 | mocked_dvc_repo.scm.untracked_files.return_value = [ 124 | "dvclive/metrics.json", 125 | plot_file, 126 | ] 127 | 128 | with Live() as live: 129 | live.log_metric("foo", 1) 130 | live.next_step() 131 | 132 | live._dvc_repo.scm.add.assert_any_call(["dvclive/metrics.json"]) 133 | 134 | 135 | def test_errors_on_git_add_are_catched(tmp_dir, mocked_dvc_repo, monkeypatch): 136 | monkeypatch.setenv(DVC_EXP_BASELINE_REV, "foo") 137 | monkeypatch.setenv(DVC_EXP_NAME, "bar") 138 | mocked_dvc_repo.scm.untracked_files.return_value = ["dvclive/metrics.json"] 139 | mocked_dvc_repo.scm.add.side_effect = DvcException("foo") 140 | 141 | with Live() as live: 142 | live.summary["foo"] = 1 143 | 144 | 145 | def test_exp_save_message(tmp_dir, mocked_dvc_repo): 146 | live = Live(exp_message="Custom message") 147 | live.end() 148 | mocked_dvc_repo.experiments.save.assert_called_with( 149 | name=live._exp_name, 150 | include_untracked=[live.dir, "dvc.yaml"], 151 | force=True, 152 | message="Custom message", 153 | ) 154 | 155 | 156 | def test_exp_save_name(tmp_dir, mocked_dvc_repo): 157 | live = Live(exp_name="custom-name") 158 | live.end() 159 | mocked_dvc_repo.experiments.save.assert_called_with( 160 | name="custom-name", 161 | include_untracked=[live.dir, "dvc.yaml"], 162 | force=True, 163 | message=None, 164 | ) 165 | 166 | 167 | def test_no_scm_repo(tmp_dir, mocker): 168 | dvc_repo = mocker.MagicMock() 169 | dvc_repo.scm = NoSCM() 170 | 171 | mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo) 172 | live = Live() 173 | assert live._dvc_repo == dvc_repo 174 | 175 | live = Live() 176 | assert live._save_dvc_exp is False 177 | 178 | 179 | def 
test_dvc_repro(tmp_dir, monkeypatch, mocked_dvc_repo, mocked_studio_post): 180 | monkeypatch.setenv(DVC_ROOT, "root") 181 | live = Live(save_dvc_exp=True) 182 | assert live._baseline_rev is not None 183 | assert live._exp_name is not None 184 | assert not live._studio_events_to_skip 185 | assert not live._save_dvc_exp 186 | 187 | 188 | def test_get_exp_name_valid(tmp_dir, mocked_dvc_repo): 189 | live = Live(exp_name="name") 190 | assert live._exp_name == "name" 191 | 192 | 193 | def test_get_exp_name_random(tmp_dir, mocked_dvc_repo, mocker): 194 | mocker.patch( 195 | "dvc.repo.experiments.utils.get_random_exp_name", return_value="random" 196 | ) 197 | live = Live() 198 | assert live._exp_name == "random" 199 | 200 | 201 | def test_get_exp_name_invalid(tmp_dir, mocked_dvc_repo, mocker, caplog): 202 | mocker.patch( 203 | "dvc.repo.experiments.utils.get_random_exp_name", return_value="random" 204 | ) 205 | with caplog.at_level("WARNING"): 206 | live = Live(exp_name="invalid//name") 207 | assert live._exp_name == "random" 208 | assert caplog.text 209 | 210 | 211 | def test_get_exp_name_duplicate(tmp_dir, mocked_dvc_repo, mocker, caplog): 212 | mocker.patch( 213 | "dvc.repo.experiments.utils.get_random_exp_name", return_value="random" 214 | ) 215 | mocked_dvc_repo.scm.get_ref.return_value = "duplicate" 216 | with caplog.at_level("WARNING"): 217 | live = Live(exp_name="duplicate") 218 | assert live._exp_name == "random" 219 | msg = "Experiment conflicts with existing experiment 'duplicate'." 
    assert msg in caplog.text


def test_test_mode(tmp_dir, monkeypatch, mocked_dvc_repo):
    # DVCLIVE_TEST redirects outputs away from the requested paths and
    # disables experiment saving.
    monkeypatch.setenv(DVCLIVE_TEST, "true")
    live = Live("dir", dvcyaml="dvc.yaml")
    live.make_dvcyaml()
    assert live._dir != "dir"
    assert live._dvc_file != "dvc.yaml"
    assert live._save_dvc_exp is False
    assert not os.path.exists("dir")
    assert not os.path.exists("dvc.yaml")
--------------------------------------------------------------------------------
/tests/test_log_metric.py:
--------------------------------------------------------------------------------
import math
import os

import numpy as np
import pytest

from dvclive import Live
from dvclive.error import InvalidDataTypeError
from dvclive.plots import Metric
from dvclive.serialize import load_yaml
from dvclive.utils import parse_metrics, parse_tsv


def test_logging_no_step(tmp_dir):
    # plot=False writes only the summary, no TSV history and no `step`.
    dvclive = Live("logs")

    dvclive.log_metric("m1", 1, plot=False)
    dvclive.make_summary()

    assert not (tmp_dir / "logs" / "plots" / "metrics" / "m1.tsv").is_file()
    assert (tmp_dir / dvclive.metrics_file).is_file()

    s = load_yaml(dvclive.metrics_file)
    assert s["m1"] == 1
    assert "step" not in s


@pytest.mark.parametrize("path", ["logs", os.path.join("subdir", "logs")])
def test_logging_step(tmp_dir, path):
    dvclive = Live(path)
    dvclive.log_metric("m1", 1)
    dvclive.next_step()
    assert (tmp_dir / dvclive.dir).is_dir()
    assert (tmp_dir / dvclive.plots_dir / Metric.subfolder / "m1.tsv").is_file()
    assert (tmp_dir / dvclive.metrics_file).is_file()

    s = load_yaml(dvclive.metrics_file)
    assert s["m1"] == 1
    assert s["step"] == 0


def test_nested_logging(tmp_dir):
    # "/" in metric names creates nested folders and nested summary keys.
    dvclive = Live("logs")

    out = tmp_dir / dvclive.plots_dir / Metric.subfolder

    dvclive.log_metric("train/m1", 1)
    dvclive.log_metric("val/val_1/m1", 1)
    dvclive.log_metric("val/val_1/m2", 1)

    dvclive.next_step()

    assert (out / "val" / "val_1").is_dir()
    assert (out / "train" / "m1.tsv").is_file()
    assert (out / "val" / "val_1" / "m1.tsv").is_file()
    assert (out / "val" / "val_1" / "m2.tsv").is_file()

    assert "m1" in parse_tsv(out / "train" / "m1.tsv")[0]
    assert "m1" in parse_tsv(out / "val" / "val_1" / "m1.tsv")[0]
    assert "m2" in parse_tsv(out / "val" / "val_1" / "m2.tsv")[0]

    summary = load_yaml(dvclive.metrics_file)

    assert summary["train"]["m1"] == 1
    assert summary["val"]["val_1"]["m1"] == 1
    assert summary["val"]["val_1"]["m2"] == 1


@pytest.mark.parametrize("timestamp", [True, False])
def test_log_metric_timestamp(tmp_dir, timestamp):
    live = Live()
    live.log_metric("foo", 1.0, timestamp=timestamp)
    live.next_step()

    history, _ = parse_metrics(live)
    logged = next(iter(history.values()))
    assert ("timestamp" in logged[0]) == timestamp


@pytest.mark.parametrize("invalid_type", [{0: 1}, [0, 1], (0, 1)])
def test_invalid_metric_type(tmp_dir, invalid_type):
    dvclive = Live()

    with pytest.raises(
        InvalidDataTypeError,
        match=f"Data 'm' has not supported type {type(invalid_type)}",
    ):
        dvclive.log_metric("m", invalid_type)


@pytest.mark.parametrize(
    ("val"),
    [math.inf, math.nan, np.nan, np.inf],
)
def test_log_metric_inf_nan(tmp_dir, val):
    # Non-finite values are stored as their string representation.
    with Live() as live:
        live.log_metric("metric", val)
    assert live.summary["metric"] == str(val)


# NOTE(review): function name has a typo ("metic" -> "metric"); kept as-is
# to avoid renaming a collected test in this change.
def test_log_metic_str(tmp_dir):
    with Live() as live:
        live.log_metric("metric", "foo")
    assert live.summary["metric"] == "foo"
--------------------------------------------------------------------------------
/tests/test_log_param.py:
--------------------------------------------------------------------------------
import os

import pytest

from dvclive import Live
from dvclive.error import InvalidParameterTypeError
from dvclive.serialize import load_yaml


def test_cleanup_params(tmp_dir):
    # A fresh Live run removes the previous params file.
    dvclive = Live("logs")
    dvclive.log_param("param", 42)

    assert os.path.isfile(dvclive.params_file)

    dvclive = Live("logs")
    assert not os.path.exists(dvclive.params_file)


@pytest.mark.parametrize(
    ("param_name", "param_value"),
    [
        ("param_string", "value"),
        ("param_int", 42),
        ("param_float", 42.0),
        ("param_bool_true", True),
        ("param_bool_false", False),
        ("param_list", [1, 2, 3]),
        (
            "param_dict_simple",
            {"str": "value", "int": 42, "bool": True, "list": [1, 2, 3]},
        ),
        (
            "param_dict_nested",
            {
                "str": "value",
                "int": 42,
                "bool": True,
                "list": [1, 2, 3],
                "dict": {"nested-str": "value", "nested-int": 42},
            },
        ),
    ],
)
def test_log_param(tmp_dir, param_name, param_value):
    # All YAML-serializable param types round-trip through params.yaml.
    dvclive = Live()

    dvclive.log_param(param_name, param_value)

    s = load_yaml(dvclive.params_file)
    assert s[param_name] == param_value


def test_log_params(tmp_dir):
    dvclive = Live()
    params = {
        "param_string": "value",
        "param_int": 42,
        "param_float": 42.0,
        "param_bool_true": True,
        "param_bool_false": False,
    }

    dvclive.log_params(params)

    s = load_yaml(dvclive.params_file)
    assert s == params


@pytest.mark.parametrize("resume", [False, True])
def test_log_params_resume(tmp_dir, resume):
    # resume=True reloads previously logged params into the new run.
    dvclive = Live(resume=resume)
    dvclive.log_param("param", 42)

    dvclive = Live(resume=resume)
    assert ("param" in dvclive._params) == resume


def test_log_param_custom_obj(tmp_dir):
    dvclive = Live("logs")
| 82 | class Dummy: 83 | val = 42 84 | 85 | param_value = Dummy() 86 | 87 | with pytest.raises(InvalidParameterTypeError) as excinfo: 88 | dvclive.log_param("param_complex", param_value) 89 | assert "Dummy" in excinfo.value.args[0] 90 | -------------------------------------------------------------------------------- /tests/test_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from dvclive import Live 4 | 5 | 6 | def test_logger(tmp_dir, mocker): 7 | logger = mocker.patch("dvclive.live.logger") 8 | 9 | live = Live() 10 | live.log_metric("foo", 0) 11 | logger.debug.assert_called_with("Logged foo: 0") 12 | live.next_step() 13 | logger.debug.assert_called_with("Step: 1") 14 | live.log_metric("foo", 1) 15 | live.next_step() 16 | 17 | live = Live(resume=True) 18 | logger.info.assert_called_with("Resuming from step 1") 19 | 20 | 21 | def test_suppress_dvc_logs(tmp_dir, mocked_dvc_repo): 22 | Live() 23 | assert logging.getLogger("dvc").level == 30 24 | -------------------------------------------------------------------------------- /tests/test_make_report.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from PIL import Image 4 | 5 | from dvclive import Live 6 | from dvclive.env import DVCLIVE_OPEN 7 | from dvclive.error import InvalidReportModeError 8 | from dvclive.plots import CustomPlot 9 | from dvclive.plots import Image as LiveImage 10 | from dvclive.plots import Metric 11 | from dvclive.plots.sklearn import SKLearnPlot 12 | from dvclive.report import ( 13 | get_custom_plot_renderers, 14 | get_image_renderers, 15 | get_metrics_renderers, 16 | get_params_renderers, 17 | get_scalar_renderers, 18 | get_sklearn_plot_renderers, 19 | ) 20 | 21 | 22 | @pytest.mark.parametrize("mode", ["html", "md", "notebook"]) 23 | def test_get_image_renderers(tmp_dir, mode): 24 | with Live() as live: 25 | img = Image.new("RGB", (10, 10), 
(255, 0, 0)) 26 | live.log_image("image.png", img) 27 | 28 | image_renderers = get_image_renderers( 29 | tmp_dir / live.plots_dir / LiveImage.subfolder 30 | ) 31 | assert len(image_renderers) == 1 32 | img = image_renderers[0].datapoints[0] 33 | assert img["src"].startswith("data:image;base64,") 34 | assert img["rev"] == "image.png" 35 | 36 | 37 | def test_get_renderers(tmp_dir, mocker): 38 | live = Live() 39 | 40 | live.log_param("string", "goo") 41 | live.log_param("number", 2) 42 | 43 | for i in range(2): 44 | live.log_metric("foo/bar", i) 45 | live.next_step() 46 | 47 | scalar_renderers = get_scalar_renderers(tmp_dir / live.plots_dir / Metric.subfolder) 48 | assert len(scalar_renderers) == 1 49 | assert scalar_renderers[0].datapoints == [ 50 | { 51 | "bar": "0", 52 | "rev": "workspace", 53 | "step": "0", 54 | }, 55 | { 56 | "bar": "1", 57 | "rev": "workspace", 58 | "step": "1", 59 | }, 60 | ] 61 | assert scalar_renderers[0].properties["y"] == "bar" 62 | assert scalar_renderers[0].properties["title"] == "foo/bar" 63 | assert scalar_renderers[0].name == "static/foo/bar" 64 | 65 | metrics_renderer = get_metrics_renderers(live.metrics_file)[0] 66 | assert metrics_renderer.datapoints == [{"step": 1, "foo": {"bar": 1}}] 67 | 68 | params_renderer = get_params_renderers(live.params_file)[0] 69 | assert params_renderer.datapoints == [{"string": "goo", "number": 2}] 70 | 71 | 72 | def test_report_init(monkeypatch, mocker): 73 | mocker.patch("dvclive.live.inside_notebook", return_value=False) 74 | live = Live(report="notebook") 75 | assert live._report_mode is None 76 | 77 | mocker.patch("dvclive.live.matplotlib_installed", return_value=False) 78 | live = Live(report="md") 79 | assert live._report_mode is None 80 | 81 | mocker.patch("dvclive.live.matplotlib_installed", return_value=True) 82 | live = Live(report="md") 83 | assert live._report_mode == "md" 84 | 85 | live = Live(report="html") 86 | assert live._report_mode == "html" 87 | 88 | with 
pytest.raises(InvalidReportModeError, match="Got foo instead."): 89 | Live(report="foo") 90 | 91 | 92 | @pytest.mark.parametrize("mode", ["html", "md"]) 93 | def test_make_report(tmp_dir, mode): 94 | last_report = "" 95 | live = Live(report=mode) 96 | for i in range(3): 97 | live.log_metric("foobar", i) 98 | live.log_metric("foo/bar", i) 99 | live.make_report() 100 | live.next_step() 101 | assert (tmp_dir / live.report_file).exists() 102 | current_report = (tmp_dir / live.report_file).read_text() 103 | assert last_report != current_report 104 | last_report = current_report 105 | 106 | 107 | @pytest.mark.vscode 108 | def test_make_report_open(tmp_dir, mocker, monkeypatch): 109 | mocked_open = mocker.patch("webbrowser.open") 110 | live = Live() 111 | live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1]) 112 | live.make_report() 113 | live.make_report() 114 | 115 | assert not mocked_open.called 116 | 117 | live = Live(report="html") 118 | live.log_metric("foo", 1) 119 | live.next_step() 120 | 121 | assert not mocked_open.called 122 | 123 | monkeypatch.setenv(DVCLIVE_OPEN, "true") 124 | 125 | live = Live(report="html") 126 | live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1]) 127 | live.make_report() 128 | 129 | mocked_open.assert_called_once() 130 | 131 | 132 | def test_get_plot_renderers_sklearn(tmp_dir): 133 | live = Live() 134 | 135 | for _ in range(2): 136 | live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1]) 137 | live.log_sklearn_plot( 138 | "confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1], name="train/cm" 139 | ) 140 | live.log_sklearn_plot("roc", [0, 0, 1, 1], [1, 0.1, 0, 1], name="roc_curve") 141 | live.log_sklearn_plot( 142 | "roc", [0, 0, 1, 1], [1, 0.1, 0, 1], name="other_roc.json" 143 | ) 144 | live.next_step() 145 | 146 | plot_renderers = get_sklearn_plot_renderers( 147 | tmp_dir / live.plots_dir / SKLearnPlot.subfolder, live 148 | ) 149 | assert len(plot_renderers) == 4 150 | plot_renderers_dict = { 151 
| plot_renderer.name: plot_renderer for plot_renderer in plot_renderers 152 | } 153 | for name in ("roc_curve", "other_roc"): 154 | plot_renderer = plot_renderers_dict[name] 155 | assert plot_renderer.datapoints == [ 156 | {"fpr": 0.0, "rev": "workspace", "threshold": np.inf, "tpr": 0.0}, 157 | {"fpr": 0.5, "rev": "workspace", "threshold": 1.0, "tpr": 0.5}, 158 | {"fpr": 1.0, "rev": "workspace", "threshold": 0.1, "tpr": 0.5}, 159 | {"fpr": 1.0, "rev": "workspace", "threshold": 0.0, "tpr": 1.0}, 160 | ] 161 | assert plot_renderer.properties == live._plots[name].plot_config 162 | 163 | for name in ("confusion_matrix", "train/cm"): 164 | plot_renderer = plot_renderers_dict[name] 165 | assert plot_renderer.datapoints == [ 166 | {"actual": "0", "rev": "workspace", "predicted": "1"}, 167 | {"actual": "0", "rev": "workspace", "predicted": "0"}, 168 | {"actual": "1", "rev": "workspace", "predicted": "0"}, 169 | {"actual": "1", "rev": "workspace", "predicted": "1"}, 170 | ] 171 | assert plot_renderer.properties == live._plots[name].plot_config 172 | 173 | 174 | def test_get_plot_renderers_custom(tmp_dir): 175 | live = Live() 176 | 177 | datapoints = [{"x": 1, "y": 2}, {"x": 3, "y": 4}] 178 | for _ in range(2): 179 | live.log_plot("foo_default", datapoints, x="x", y="y") 180 | live.log_plot( 181 | "foo_scatter", 182 | datapoints, 183 | x="x", 184 | y="y", 185 | template="scatter", 186 | ) 187 | live.next_step() 188 | plot_renderers = get_custom_plot_renderers( 189 | tmp_dir / live.plots_dir / CustomPlot.subfolder, live 190 | ) 191 | 192 | assert len(plot_renderers) == 2 193 | plot_renderers_dict = { 194 | plot_renderer.name: plot_renderer for plot_renderer in plot_renderers 195 | } 196 | for name in ("foo_default", "foo_scatter"): 197 | plot_renderer = plot_renderers_dict[name] 198 | assert plot_renderer.datapoints == [ 199 | {"rev": "workspace", "x": 1, "y": 2}, 200 | {"rev": "workspace", "x": 3, "y": 4}, 201 | ] 202 | assert plot_renderer.properties == 
live._plots[name].plot_config 203 | 204 | 205 | def test_report_notebook(tmp_dir, mocker): 206 | mocker.patch("dvclive.live.inside_notebook", return_value=True) 207 | mocked_display = mocker.MagicMock() 208 | mocker.patch("IPython.display.display", return_value=mocked_display) 209 | live = Live(report="notebook") 210 | assert live._report_mode == "notebook" 211 | live.make_report() 212 | assert mocked_display.update.called 213 | -------------------------------------------------------------------------------- /tests/test_make_summary.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from dvclive import Live 4 | from dvclive.plots import Metric 5 | 6 | 7 | def test_make_summary_without_calling_log(tmp_dir): 8 | dvclive = Live() 9 | 10 | dvclive.summary["foo"] = 1.0 11 | dvclive.make_summary() 12 | 13 | assert json.loads((tmp_dir / dvclive.metrics_file).read_text()) == { 14 | # no `step` 15 | "foo": 1.0 16 | } 17 | log_file = tmp_dir / dvclive.plots_dir / Metric.subfolder / "foo.tsv" 18 | assert not log_file.exists() 19 | 20 | 21 | def test_make_summary_is_called_on_end(tmp_dir): 22 | live = Live() 23 | 24 | live.summary["foo"] = 1.0 25 | live.end() 26 | 27 | assert json.loads((tmp_dir / live.metrics_file).read_text()) == { 28 | # no `step` 29 | "foo": 1.0 30 | } 31 | log_file = tmp_dir / live.plots_dir / Metric.subfolder / "foo.tsv" 32 | assert not log_file.exists() 33 | 34 | 35 | def test_make_summary_on_end_dont_increment_step(tmp_dir): 36 | with Live() as live: 37 | for i in range(2): 38 | live.log_metric("foo", i) 39 | live.next_step() 40 | 41 | assert json.loads((tmp_dir / live.metrics_file).read_text()) == { 42 | "foo": 1.0, 43 | "step": 1, 44 | } 45 | -------------------------------------------------------------------------------- /tests/test_resume.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from dvclive import Live 4 | from 
dvclive.env import DVCLIVE_RESUME 5 | from dvclive.utils import read_history, read_latest 6 | 7 | 8 | @pytest.mark.parametrize( 9 | ("resume", "steps", "metrics"), 10 | [(True, [0, 1, 2, 3], [0.9, 0.8, 0.7, 0.6]), (False, [0, 1], [0.7, 0.6])], 11 | ) 12 | def test_resume(tmp_dir, resume, steps, metrics): 13 | dvclive = Live("logs") 14 | 15 | for metric in [0.9, 0.8]: 16 | dvclive.log_metric("metric", metric) 17 | dvclive.next_step() 18 | dvclive.log_metric("summary", 1) 19 | dvclive.end() 20 | 21 | assert read_history(dvclive, "metric") == ([0, 1], [0.9, 0.8]) 22 | assert read_latest(dvclive, "metric") == (1, 0.8) 23 | 24 | dvclive = Live("logs", resume=resume) 25 | 26 | for new_metric in [0.7, 0.6]: 27 | dvclive.log_metric("metric", new_metric) 28 | dvclive.next_step() 29 | dvclive.end() 30 | 31 | assert read_history(dvclive, "metric") == (steps, metrics) 32 | assert read_latest(dvclive, "metric") == (steps[-1], metrics[-1]) 33 | if resume: 34 | assert dvclive.read_latest()["summary"] == 1 35 | else: 36 | assert "summary" not in dvclive.read_latest() 37 | 38 | 39 | def test_resume_on_first_init(tmp_dir): 40 | dvclive = Live(resume=True) 41 | 42 | assert dvclive._step == 0 43 | 44 | 45 | def test_resume_env_var(tmp_dir, monkeypatch): 46 | assert not Live()._resume 47 | 48 | monkeypatch.setenv(DVCLIVE_RESUME, "true") 49 | assert Live()._resume 50 | -------------------------------------------------------------------------------- /tests/test_step.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from dvclive import Live 6 | from dvclive.utils import read_history, read_latest 7 | 8 | 9 | @pytest.mark.parametrize("metric", ["m1", os.path.join("train", "m1")]) 10 | def test_allow_step_override(tmp_dir, metric): 11 | dvclive = Live("logs") 12 | 13 | dvclive.log_metric(metric, 1.0) 14 | dvclive.log_metric(metric, 2.0) 15 | 16 | 17 | def test_custom_steps(tmp_dir): 18 | dvclive = Live("logs") 19 
| 20 | steps = [0, 62, 1000] 21 | metrics = [0.9, 0.8, 0.7] 22 | 23 | for step, metric in zip(steps, metrics): 24 | dvclive.step = step 25 | dvclive.log_metric("m", metric) 26 | dvclive.make_summary() 27 | 28 | assert read_history(dvclive, "m") == (steps, metrics) 29 | assert read_latest(dvclive, "m") == (steps[-1], metrics[-1]) 30 | 31 | 32 | def test_log_reset_with_set_step(tmp_dir): 33 | dvclive = Live() 34 | 35 | for i in range(3): 36 | dvclive.step = i 37 | dvclive.log_metric("train_m", 1) 38 | dvclive.make_summary() 39 | 40 | for i in range(3): 41 | dvclive.step = i 42 | dvclive.log_metric("val_m", 1) 43 | dvclive.make_summary() 44 | 45 | assert read_history(dvclive, "train_m") == ([0, 1, 2], [1, 1, 1]) 46 | assert read_history(dvclive, "val_m") == ([0, 1, 2], [1, 1, 1]) 47 | assert read_latest(dvclive, "train_m") == (2, 1) 48 | assert read_latest(dvclive, "val_m") == (2, 1) 49 | 50 | 51 | def test_get_step_resume(tmp_dir): 52 | dvclive = Live() 53 | 54 | for metric in [0.9, 0.8]: 55 | dvclive.log_metric("metric", metric) 56 | dvclive.next_step() 57 | 58 | assert dvclive.step == 2 59 | 60 | dvclive = Live(resume=True) 61 | assert dvclive.step == 2 62 | 63 | dvclive = Live(resume=False) 64 | assert dvclive.step == 0 65 | 66 | 67 | def test_get_step_custom_steps(tmp_dir): 68 | dvclive = Live() 69 | 70 | steps = [0, 62, 1000] 71 | metrics = [0.9, 0.8, 0.7] 72 | 73 | for step, metric in zip(steps, metrics): 74 | dvclive.step = step 75 | dvclive.log_metric("x", metric) 76 | assert dvclive.step == step 77 | 78 | 79 | def test_get_step_control_flow(tmp_dir): 80 | dvclive = Live() 81 | 82 | while dvclive.step < 10: 83 | dvclive.log_metric("i", dvclive.step) 84 | dvclive.next_step() 85 | 86 | steps, values = read_history(dvclive, "i") 87 | assert steps == list(range(10)) 88 | assert values == [float(x) for x in range(10)] 89 | 90 | 91 | def test_set_step_only(tmp_dir): 92 | dvclive = Live() 93 | dvclive.step = 1 94 | dvclive.end() 95 | 96 | assert 
dvclive.read_latest() == {"step": 1} 97 | assert not os.path.exists(os.path.join(tmp_dir, "dvclive", "plots")) 98 | 99 | 100 | def test_step_on_end(tmp_dir): 101 | dvclive = Live() 102 | for metric in range(3): 103 | dvclive.log_metric("m", metric) 104 | dvclive.next_step() 105 | dvclive.end() 106 | assert dvclive.step == metric 107 | 108 | assert dvclive.read_latest() == {"step": metric, "m": metric} 109 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from dvclive.utils import standardize_metric_name, convert_datapoints_to_list_of_dicts 6 | from dvclive.error import InvalidDataTypeError 7 | 8 | 9 | @pytest.mark.parametrize( 10 | ("framework", "logged", "standardized"), 11 | [ 12 | ("dvclive.lightning", "epoch", "epoch"), 13 | ("dvclive.lightning", "train_loss", "train/loss"), 14 | ("dvclive.lightning", "train_loss_epoch", "train/epoch/loss"), 15 | ("dvclive.lightning", "train_model_error", "train/model_error"), 16 | ("dvclive.lightning", "grad_step", "grad_step"), 17 | ], 18 | ) 19 | def test_standardize_metric_name(framework, logged, standardized): 20 | assert standardize_metric_name(logged, framework) == standardized 21 | 22 | 23 | # Tests for convert_datapoints_to_list_of_dicts() 24 | @pytest.mark.parametrize( 25 | ("input_data", "expected_output"), 26 | [ 27 | ( 28 | pd.DataFrame({"A": [1, 2], "B": [3, 4]}), 29 | [{"A": 1, "B": 3}, {"A": 2, "B": 4}], 30 | ), 31 | (np.array([[1, 3], [2, 4]]), [{0: 1, 1: 3}, {0: 2, 1: 4}]), 32 | ( 33 | np.array([(1, 3), (2, 4)], dtype=[("A", "i4"), ("B", "i4")]), 34 | [{"A": 1, "B": 3}, {"A": 2, "B": 4}], 35 | ), 36 | ([{"A": 1, "B": 3}, {"A": 2, "B": 4}], [{"A": 1, "B": 3}, {"A": 2, "B": 4}]), 37 | ], 38 | ) 39 | def test_convert_datapoints_to_list_of_dicts(input_data, expected_output): 40 | assert 
convert_datapoints_to_list_of_dicts(input_data) == expected_output 41 | 42 | 43 | def test_unsupported_format(): 44 | with pytest.raises(InvalidDataTypeError) as exc_info: 45 | convert_datapoints_to_list_of_dicts("unsupported data format") 46 | 47 | assert "not supported type" in str(exc_info.value) 48 | -------------------------------------------------------------------------------- /tests/test_vscode.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import pytest 5 | 6 | from dvclive import Live, env 7 | 8 | 9 | @pytest.mark.vscode 10 | @pytest.mark.parametrize("dvc_root", [True, False]) 11 | def test_vscode_dvclive_step_completed_signal_file( 12 | tmp_dir, dvc_root, mocker, monkeypatch 13 | ): 14 | signal_file = os.path.join( 15 | tmp_dir, ".dvc", "tmp", "exps", "run", "DVCLIVE_STEP_COMPLETED" 16 | ) 17 | cwd = tmp_dir 18 | test_pid = 12345 19 | 20 | if dvc_root: 21 | cwd = tmp_dir / ".dvc" / "tmp" / "exps" / "asdasasf" 22 | monkeypatch.setenv(env.DVC_ROOT, tmp_dir.as_posix()) 23 | (cwd / ".dvc").mkdir(parents=True) 24 | 25 | assert not os.path.exists(signal_file) 26 | 27 | dvc_repo = mocker.MagicMock() 28 | dvc_repo.index.stages = [] 29 | dvc_repo.config = {} 30 | dvc_repo.scm.get_rev.return_value = "current_rev" 31 | dvc_repo.scm.get_ref.return_value = None 32 | dvc_repo.scm.no_commits = False 33 | mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo) 34 | mocker.patch("dvclive.live.os.getpid", return_value=test_pid) 35 | 36 | dvclive = Live(save_dvc_exp=True) 37 | assert not os.path.exists(signal_file) 38 | dvclive.next_step() 39 | assert dvclive.step == 1 40 | 41 | if dvc_root: 42 | assert os.path.exists(signal_file) 43 | with open(signal_file, encoding="utf-8") as f: 44 | assert json.load(f) == {"pid": test_pid, "step": 0} 45 | 46 | else: 47 | assert not os.path.exists(signal_file) 48 | 49 | dvclive.next_step() 50 | assert dvclive.step == 2 51 | 52 | if dvc_root: 53 | with 
open(signal_file, encoding="utf-8") as f: 54 | assert json.load(f) == {"pid": test_pid, "step": 1} 55 | 56 | dvclive.end() 57 | 58 | assert not os.path.exists(signal_file) 59 | 60 | 61 | @pytest.mark.vscode 62 | @pytest.mark.parametrize("dvc_root", [True, False]) 63 | def test_vscode_dvclive_only_signal_file(tmp_dir, dvc_root, mocker): 64 | signal_file = os.path.join(tmp_dir, ".dvc", "tmp", "exps", "run", "DVCLIVE_ONLY") 65 | test_pid = 12345 66 | 67 | if dvc_root: 68 | (tmp_dir / ".dvc").mkdir(parents=True) 69 | 70 | assert not os.path.exists(signal_file) 71 | 72 | dvc_repo = mocker.MagicMock() 73 | dvc_repo.index.stages = [] 74 | dvc_repo.config = {} 75 | dvc_repo.scm.get_rev.return_value = "current_rev" 76 | dvc_repo.scm.get_ref.return_value = None 77 | dvc_repo.scm.no_commits = False 78 | mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo) 79 | mocker.patch("dvclive.live.os.getpid", return_value=test_pid) 80 | 81 | dvclive = Live(save_dvc_exp=True) 82 | 83 | if dvc_root: 84 | assert os.path.exists(signal_file) 85 | with open(signal_file, encoding="utf-8") as f: 86 | assert json.load(f) == {"pid": test_pid, "exp_name": dvclive._exp_name} 87 | 88 | else: 89 | assert not os.path.exists(signal_file) 90 | 91 | dvclive.end() 92 | 93 | assert not os.path.exists(signal_file) 94 | --------------------------------------------------------------------------------