├── torcheval
├── py.typed
├── utils
│ ├── test_utils
│ │ ├── __init__.py
│ │ └── dummy_metric.py
│ └── __init__.py
├── metrics
│ ├── audio
│ │ └── __init__.py
│ ├── functional
│ │ ├── image
│ │ │ ├── __init__.py
│ │ │ └── psnr.py
│ │ ├── statistical
│ │ │ └── __init__.py
│ │ ├── regression
│ │ │ └── __init__.py
│ │ ├── aggregation
│ │ │ ├── __init__.py
│ │ │ ├── throughput.py
│ │ │ ├── sum.py
│ │ │ ├── mean.py
│ │ │ └── auc.py
│ │ ├── text
│ │ │ ├── __init__.py
│ │ │ ├── helper.py
│ │ │ ├── word_information_lost.py
│ │ │ └── word_information_preserved.py
│ │ ├── ranking
│ │ │ ├── __init__.py
│ │ │ ├── frequency.py
│ │ │ ├── num_collisions.py
│ │ │ ├── reciprocal_rank.py
│ │ │ ├── hit_rate.py
│ │ │ └── click_through_rate.py
│ │ ├── tensor_utils.py
│ │ ├── frechet.py
│ │ ├── classification
│ │ │ └── __init__.py
│ │ └── __init__.py
│ ├── statistical
│ │ └── __init__.py
│ ├── regression
│ │ └── __init__.py
│ ├── image
│ │ └── __init__.py
│ ├── text
│ │ ├── __init__.py
│ │ └── word_error_rate.py
│ ├── aggregation
│ │ ├── __init__.py
│ │ ├── min.py
│ │ ├── max.py
│ │ ├── cov.py
│ │ ├── sum.py
│ │ ├── cat.py
│ │ └── throughput.py
│ ├── window
│ │ └── __init__.py
│ ├── ranking
│ │ ├── __init__.py
│ │ ├── hit_rate.py
│ │ └── reciprocal_rank.py
│ └── classification
│ │ └── __init__.py
├── __init__.py
└── version.py
├── requirements.txt
├── image-requirements.txt
├── docs
├── requirements.txt
├── source
│ ├── torcheval.metrics.toolkit.rst
│ ├── _static
│ │ ├── css
│ │ │ └── torcheval.css
│ │ └── js
│ │ │ └── torcheval.js
│ ├── templates
│ │ └── layout.html
│ ├── ext
│ │ └── fbcode.py
│ ├── torcheval.metrics.functional.rst
│ ├── torcheval.metrics.rst
│ └── index.rst
├── license_header.txt
├── Makefile
└── build_docs.sh
├── pyproject.toml
├── dev-requirements.txt
├── tests
└── metrics
│ ├── __init__.py
│ ├── text
│ ├── __init__.py
│ ├── test_word_information_lost.py
│ ├── test_word_error_rate.py
│ └── test_word_information_preserved.py
│ ├── ranking
│ ├── __init__.py
│ ├── test_hit_rate.py
│ ├── test_reciprocal_rank.py
│ └── test_click_through_rate.py
│ ├── aggregation
│ ├── __init__.py
│ ├── test_cov.py
│ ├── test_max.py
│ ├── test_min.py
│ ├── test_throughput.py
│ └── test_sum.py
│ ├── classification
│ └── __init__.py
│ ├── functional
│ ├── __init__.py
│ ├── image
│ │ ├── __init__.py
│ │ └── test_psnr.py
│ ├── ranking
│ │ ├── __init__.py
│ │ ├── test_num_collisions.py
│ │ ├── test_frequency.py
│ │ ├── test_click_through_rate.py
│ │ ├── test_hit_rate.py
│ │ ├── test_reciprocal_rank.py
│ │ └── test_weighted_calibration.py
│ ├── text
│ │ ├── __init__.py
│ │ ├── test_word_information_lost.py
│ │ ├── test_word_error_rate.py
│ │ └── test_word_information_preserved.py
│ ├── aggregation
│ │ ├── __init__.py
│ │ ├── test_throughput.py
│ │ ├── test_sum.py
│ │ └── test_mean.py
│ ├── regression
│ │ └── __init__.py
│ ├── statistical
│ │ └── __init__.py
│ └── classification
│ │ └── __init__.py
│ ├── regression
│ └── __init__.py
│ ├── statistical
│ └── __init__.py
│ └── image
│ ├── test_ssim.py
│ └── test_psnr.py
├── .github
├── workflows
│ ├── pre_commit.yaml
│ ├── build_docs.yaml
│ ├── release_build.yaml
│ ├── nightly_build_cpu.yaml
│ ├── release_build_docs.yaml
│ └── unit_test.yaml
├── PULL_REQUEST_TEMPLATE.md
└── ISSUE_TEMPLATE
│ ├── help-support.yml
│ ├── documentation.yml
│ ├── feature-request.yml
│ └── bug-report.yml
├── .flake8
├── .pre-commit-config.yaml
├── LICENSE
├── CONTRIBUTING.md
├── setup.py
├── examples
└── simple_example.py
├── CODE_OF_CONDUCT.md
└── .gitignore
/torcheval/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | typing_extensions
2 |
--------------------------------------------------------------------------------
/image-requirements.txt:
--------------------------------------------------------------------------------
1 | torchvision
2 | # `skimage` is only the import name; the PyPI distribution is `scikit-image`
3 | # (matching dev-requirements.txt). The PyPI "skimage" project is a placeholder.
4 | scikit-image
5 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==5.0.1
2 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
3 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.usort]
2 |
3 | first_party_detection = false
4 |
5 | [tool.pytest.ini_options]
6 | markers =[
7 | "cpu_and_gpu",
8 | "gpu_only",
9 | ]
10 |
--------------------------------------------------------------------------------
/dev-requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pre-commit
3 | pytest
4 | pytest-timeout
5 | pytest-cov
6 | Cython>=0.28.5
7 | scikit-learn>=0.22
8 | scikit-image==0.18.3
9 | torchtnt-nightly
10 |
--------------------------------------------------------------------------------
/docs/source/torcheval.metrics.toolkit.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: torcheval.metrics.toolkit
2 |
3 | Metric Toolkit
4 | ==================
5 |
6 | .. automodule:: torcheval.metrics.toolkit
7 | :members:
8 | :undoc-members:
9 |
--------------------------------------------------------------------------------
/docs/license_header.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) Meta Platforms, Inc. and affiliates.
2 | All rights reserved.
3 |
4 | This source code is licensed under the BSD-style license found in the
5 | LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/text/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/ranking/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/aggregation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/classification/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/functional/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/regression/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/statistical/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torcheval/utils/test_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/functional/image/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/functional/ranking/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/functional/text/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/functional/aggregation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/functional/regression/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/functional/statistical/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/metrics/functional/classification/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/.github/workflows/pre_commit.yaml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches: [main]
7 |
8 | jobs:
9 | pre-commit:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v3
13 | - uses: actions/setup-python@v3
14 | - uses: pre-commit/action@v3.0.0
15 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | Please read through our [contribution guide](https://github.com/pytorch/torcheval/blob/main/CONTRIBUTING.md) prior to creating your pull request.
2 |
3 | Summary:
4 |
5 |
6 | Test plan:
7 |
8 |
9 | Fixes #{issue number}
10 |
11 |
--------------------------------------------------------------------------------
/torcheval/metrics/audio/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.audio.fad import FrechetAudioDistance
10 |
11 |
12 | __all__ = ["FrechetAudioDistance"]
13 |
--------------------------------------------------------------------------------
/docs/source/_static/css/torcheval.css:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | #redirect-banner {
10 | border-bottom: 1px solid #e2e2e2;
11 | }
12 | #redirect-banner > p {
13 | margin: 0.8rem;
14 | text-align: center;
15 | }
16 |
--------------------------------------------------------------------------------
/torcheval/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | "A library that contains a collection of performant PyTorch model metrics"
10 |
11 | from .version import __version__
12 |
13 | __all__ = [
14 | "__version__",
15 | ]
16 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/image/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.functional.image.psnr import peak_signal_noise_ratio
10 |
11 | __all__ = ["peak_signal_noise_ratio"]
12 | __doc_name__ = "Image Metrics"
13 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/help-support.yml:
--------------------------------------------------------------------------------
1 | name: 📚 Help Support
2 | description: Do you need help/support? Send us your questions.
3 |
4 | body:
5 | - type: textarea
6 | attributes:
7 | label: 📚 Question
8 | description: >
9 | Description of your question or what you need support with.
10 | validations:
11 | required: true
12 | - type: markdown
13 | attributes:
14 | value: >
15 | Thanks for contributing 🎉!
16 |
--------------------------------------------------------------------------------
/docs/source/_static/js/torcheval.js:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the BSD-style license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | const NETWORK_TEST_URL = 'https://staticdocs.thefacebook.com/ping';
10 | fetch(NETWORK_TEST_URL).then(() => {
11 | $("#redirect-banner").prependTo("body").show();
12 | });
13 |
--------------------------------------------------------------------------------
/docs/source/templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 |
3 | {%- block extrabody %}
4 | {% if not fbcode %}
5 |
11 | {% endif %}
12 | {%- endblock %}
13 |
--------------------------------------------------------------------------------
/torcheval/metrics/statistical/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
8 |
9 | from torcheval.metrics.statistical.wasserstein import Wasserstein1D
10 |
11 | __all__ = ["Wasserstein1D"]
12 | __doc_name__ = "Statistical Metrics"
13 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/statistical/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
8 |
9 | from torcheval.metrics.functional.statistical.wasserstein import wasserstein_1d
10 |
11 | __all__ = ["wasserstein_1d"]
12 | __doc_name__ = "Statistical Metrics"
13 |
--------------------------------------------------------------------------------
/torcheval/metrics/regression/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.regression.mean_squared_error import MeanSquaredError
10 | from torcheval.metrics.regression.r2_score import R2Score
11 |
12 | __all__ = ["MeanSquaredError", "R2Score"]
13 | __doc_name__ = "Regression Metrics"
14 |
--------------------------------------------------------------------------------
/torcheval/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.utils.random_data import (
10 | get_rand_data_binary,
11 | get_rand_data_binned_binary,
12 | get_rand_data_multiclass,
13 | )
14 |
15 | __all__ = [
16 | "get_rand_data_binary",
17 | "get_rand_data_binned_binary",
18 | "get_rand_data_multiclass",
19 | ]
20 |
--------------------------------------------------------------------------------
/torcheval/version.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | # Follows PEP-0440 version scheme guidelines
11 | # https://www.python.org/dev/peps/pep-0440/#version-scheme
12 | #
13 | # Examples:
14 | # 0.1.0.devN # Developmental release
15 | # 0.1.0aN # Alpha release
16 | # 0.1.0bN # Beta release
17 | # 0.1.0rcN # Release Candidate
18 | # 0.1.0 # Final release
19 | __version__: str = "0.0.7"
20 |
--------------------------------------------------------------------------------
/torcheval/metrics/image/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.image.fid import FrechetInceptionDistance
10 | from torcheval.metrics.image.psnr import PeakSignalNoiseRatio
11 | from torcheval.metrics.image.ssim import StructuralSimilarity
12 |
13 | __all__ = [
14 | "FrechetInceptionDistance",
15 | "PeakSignalNoiseRatio",
16 | "StructuralSimilarity",
17 | ]
18 | __doc_name__ = "Image Metrics"
19 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | # Suggested config from pytorch that we can adopt
3 | select = B,C,E,F,P,T4,W,B9
4 | max-line-length = 120
5 | # C408 ignored because we like the dict keyword argument syntax
6 | # E501 is not flexible enough, we're using B950 instead
7 | ignore =
8 | E203,E305,E402,E501,E704,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
9 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying
10 | # to line this up with executable bit
11 | EXE001,
12 | optional-ascii-coding = True
13 | exclude =
14 | ./.git,
15 | ./docs
16 | ./build
17 | ./scripts,
18 | ./venv,
19 | *.pyi
20 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation.yml:
--------------------------------------------------------------------------------
1 | name: 📚 Documentation
2 | description: Report an issue related to inline documentation
3 |
4 | body:
5 | - type: textarea
6 | attributes:
7 | label: 📚 The doc issue
8 | description: >
9 | A clear and concise description of what content is an issue.
10 | validations:
11 | required: true
12 | - type: textarea
13 | attributes:
14 | label: Suggest a potential alternative/fix
15 | description: >
16 | Tell us how we could improve the documentation in this regard.
17 | - type: markdown
18 | attributes:
19 | value: >
20 | Thanks for contributing 🎉!
21 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/regression/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
10 |
11 | from torcheval.metrics.functional.regression.mean_squared_error import (
12 | mean_squared_error,
13 | )
14 |
15 | from torcheval.metrics.functional.regression.r2_score import r2_score
16 |
17 | __all__ = ["mean_squared_error", "r2_score"]
18 | __doc_name__ = "Regression Metrics"
19 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/aggregation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.functional.aggregation.auc import auc
10 |
11 | from torcheval.metrics.functional.aggregation.mean import mean
12 |
13 | from torcheval.metrics.functional.aggregation.sum import sum
14 |
15 | from torcheval.metrics.functional.aggregation.throughput import throughput
16 |
17 |
18 | __all__ = ["auc", "mean", "sum", "throughput"]
19 | __doc_name__ = "Aggregation Metrics"
20 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | python update_docs.py
21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
22 |
--------------------------------------------------------------------------------
/docs/build_docs.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # This script shows how to build the docs.
9 | # 1. First ensure you have installed the requirements for torcheval in `requirements.txt`.
10 | # 2. Then make sure you have installed the requirements inside `docs/requirements.txt`.
11 | # 3. Finally run this script (e.g. `bash docs/build_docs.sh`) from any directory; don't source it,
12 | # as the `|| exit` guards would close your shell on failure. Sphinx pulls docstrings from the
13 | # installed module, so the current torcheval is pip-installed (aborting on failure) before `make html`.
14 | cd "$(dirname "${BASH_SOURCE[0]}")/.." || exit
15 | pip install --no-build-isolation . || exit
16 | cd docs || exit
17 | make html
18 |
--------------------------------------------------------------------------------
/torcheval/metrics/text/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.text.bleu import BLEUScore
10 | from torcheval.metrics.text.perplexity import Perplexity
11 | from torcheval.metrics.text.word_error_rate import WordErrorRate
12 | from torcheval.metrics.text.word_information_lost import WordInformationLost
13 | from torcheval.metrics.text.word_information_preserved import WordInformationPreserved
14 |
15 | __all__ = [
16 | "BLEUScore",
17 | "Perplexity",
18 | "WordErrorRate",
19 | "WordInformationLost",
20 | "WordInformationPreserved",
21 | ]
22 | __doc_name__ = "Text Metrics"
23 |
--------------------------------------------------------------------------------
/torcheval/metrics/aggregation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.aggregation.auc import AUC
10 | from torcheval.metrics.aggregation.cat import Cat
11 | from torcheval.metrics.aggregation.cov import Covariance
12 | from torcheval.metrics.aggregation.max import Max
13 | from torcheval.metrics.aggregation.mean import Mean
14 | from torcheval.metrics.aggregation.min import Min
15 | from torcheval.metrics.aggregation.sum import Sum
16 | from torcheval.metrics.aggregation.throughput import Throughput
17 |
18 | __all__ = ["AUC", "Cat", "Covariance", "Max", "Mean", "Min", "Sum", "Throughput"]
19 | __doc_name__ = "Aggregation Metrics"
20 |
--------------------------------------------------------------------------------
/torcheval/metrics/window/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.window.auroc import WindowedBinaryAUROC
10 | from torcheval.metrics.window.click_through_rate import WindowedClickThroughRate
11 | from torcheval.metrics.window.mean_squared_error import WindowedMeanSquaredError
12 | from torcheval.metrics.window.normalized_entropy import WindowedBinaryNormalizedEntropy
13 | from torcheval.metrics.window.weighted_calibration import WindowedWeightedCalibration
14 |
15 | __all__ = [
16 | "WindowedBinaryAUROC",
17 | "WindowedBinaryNormalizedEntropy",
18 | "WindowedClickThroughRate",
19 | "WindowedMeanSquaredError",
20 | "WindowedWeightedCalibration",
21 | ]
22 | __doc_name__ = "Windowed Metrics"
23 |
--------------------------------------------------------------------------------
/torcheval/metrics/ranking/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.ranking.click_through_rate import ClickThroughRate
10 | from torcheval.metrics.ranking.hit_rate import HitRate
11 | from torcheval.metrics.ranking.reciprocal_rank import ReciprocalRank
12 | from torcheval.metrics.ranking.retrieval_precision import RetrievalPrecision
13 | from torcheval.metrics.ranking.retrieval_recall import RetrievalRecall
14 | from torcheval.metrics.ranking.weighted_calibration import WeightedCalibration
15 |
16 | __all__ = [
17 | "ClickThroughRate",
18 | "HitRate",
19 | "ReciprocalRank",
20 | "RetrievalPrecision",
21 | "RetrievalRecall",
22 | "WeightedCalibration",
23 | ]
24 | __doc_name__ = "Ranking Metrics"
25 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/text/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.functional.text.bleu import bleu_score
10 |
11 | from torcheval.metrics.functional.text.perplexity import perplexity
12 |
13 | from torcheval.metrics.functional.text.word_error_rate import word_error_rate
14 |
15 | from torcheval.metrics.functional.text.word_information_lost import (
16 | word_information_lost,
17 | )
18 |
19 | from torcheval.metrics.functional.text.word_information_preserved import (
20 | word_information_preserved,
21 | )
22 |
# Names exported by ``from torcheval.metrics.functional.text import *``.
__all__ = [
    "bleu_score",
    "perplexity",
    "word_error_rate",
    "word_information_preserved",
    "word_information_lost",
]
# Human-readable title for this package; presumably read by the docs build
# (nothing in this file uses it) — confirm before relying on it.
__doc_name__ = "Text Metrics"
31 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | repos:
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 | rev: v4.1.0
7 | hooks:
8 | - id: trailing-whitespace
9 | - id: check-ast
10 | - id: check-merge-conflict
11 | - id: check-added-large-files
12 | args: ['--maxkb=500']
13 | - id: end-of-file-fixer
14 | exclude: '.*\.rst'
15 |
16 | - repo: https://github.com/Lucas-C/pre-commit-hooks
17 | rev: v1.1.7
18 | hooks:
19 | - id: insert-license
20 | files: \.py$
21 | args:
22 | - --license-filepath
23 | - docs/license_header.txt
24 |
25 | - repo: https://github.com/pycqa/flake8
26 | rev: 6.1.0
27 | hooks:
28 | - id: flake8
29 | args:
30 | - --config=.flake8
31 |
32 | - repo: https://github.com/omnilib/ufmt
33 | rev: v2.5.1
34 | hooks:
35 | - id: ufmt
36 | additional_dependencies:
37 | - black == 24.2.0
38 | - usort == 1.0.2
39 |
--------------------------------------------------------------------------------
/docs/source/ext/fbcode.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import os
9 |
10 | from docutils import nodes
11 | from sphinx.util.docutils import SphinxDirective
12 | from sphinx.util.nodes import nested_parse_with_titles
13 |
14 |
class FbcodeDirective(SphinxDirective):
    """Sphinx directive whose body is rendered only inside an fbcode checkout.

    The check is a plain substring test on the current working directory.
    When the path does not contain ``fbcode``, the directive emits no nodes.
    """

    # Allow the directive to carry a body of reST content.
    has_content = True

    def run(self):
        if "fbcode" in os.getcwd():
            container = nodes.section()
            container.document = self.state.document
            # Parse the directive body as nested reST (titles allowed) into
            # the temporary section, then hand back only its children.
            nested_parse_with_titles(self.state, self.content, container)
            return container.children
        return []
26 |
27 |
def setup(app):
    """Sphinx extension entry point: register the ``fbcode`` directive.

    Returns the extension metadata dict, declaring the extension safe for
    parallel reading and writing.
    """
    app.add_directive("fbcode", FbcodeDirective)
    metadata = {
        "version": "0.1",
        "parallel_read_safe": True,
        "parallel_write_safe": True,
    }
    return metadata
36 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.yml:
--------------------------------------------------------------------------------
1 | name: 🚀 Feature request
2 | description: Submit a proposal/request for a new feature
3 |
4 | body:
5 | - type: textarea
6 | attributes:
7 | label: 🚀 The feature
8 | description: >
9 | A clear and concise description of the feature proposal
10 | validations:
11 | required: true
12 | - type: textarea
13 | attributes:
14 | label: Motivation, pitch
15 | description: >
16 | Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g.,
17 | *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link
18 | here too.
19 | validations:
20 | required: true
21 | - type: textarea
22 | attributes:
23 | label: Alternatives
24 | description: >
25 | A description of any alternative solutions or features you've considered, if any.
26 | - type: textarea
27 | attributes:
28 | label: Additional context
29 | description: >
30 | Add any other context or screenshots about the feature request.
31 | - type: markdown
32 | attributes:
33 | value: >
34 | Thanks for contributing 🎉!
35 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/ranking/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.functional.ranking.click_through_rate import click_through_rate
10 | from torcheval.metrics.functional.ranking.frequency import frequency_at_k
11 |
12 | from torcheval.metrics.functional.ranking.hit_rate import hit_rate
13 |
14 | from torcheval.metrics.functional.ranking.num_collisions import num_collisions
15 |
16 | from torcheval.metrics.functional.ranking.reciprocal_rank import reciprocal_rank
17 |
18 | from torcheval.metrics.functional.ranking.retrieval_precision import retrieval_precision
19 |
20 | from torcheval.metrics.functional.ranking.retrieval_recall import retrieval_recall
21 |
22 | from torcheval.metrics.functional.ranking.weighted_calibration import (
23 | weighted_calibration,
24 | )
25 |
# Names exported by ``from torcheval.metrics.functional.ranking import *``,
# kept in alphabetical order to match the import list above.
__all__ = [
    "click_through_rate",
    "frequency_at_k",
    "hit_rate",
    "num_collisions",
    "reciprocal_rank",
    "retrieval_precision",
    "retrieval_recall",
    "weighted_calibration",
]
__doc_name__ = "Ranking Metrics"
37 |
--------------------------------------------------------------------------------
/tests/metrics/functional/ranking/test_num_collisions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import torch
12 | from torcheval.metrics.functional import num_collisions
13 |
14 |
class TestNumCollisions(unittest.TestCase):
    def test_num_collisions_with_valid_input(self) -> None:
        # (ids, expected number of other entries sharing each id)
        cases = [
            ([3, 4, 2, 3], [1, 0, 0, 1]),
            ([3, 4, 1, 3, 1, 1, 5], [1, 0, 2, 1, 2, 2, 0]),
        ]
        for ids, expected in cases:
            torch.testing.assert_close(
                num_collisions(torch.tensor(ids)),
                torch.tensor(expected),
            )

    def test_num_collisions_with_invalid_input(self) -> None:
        two_dimensional = torch.randint(10, (3, 2))
        with self.assertRaisesRegex(
            ValueError, "input should be a one-dimensional tensor"
        ):
            num_collisions(two_dimensional)

        floating_ids = torch.rand(3)
        with self.assertRaisesRegex(ValueError, "input should be an integer tensor"):
            num_collisions(floating_ids)
37 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/tensor_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 |
13 | def _riemann_integral(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
14 | """Riemann integral approximates the area of each cell with a rectangle positioned at the egde.
15 | It is conventionally used rather than trapezoid approximation, which uses a rectangle positioned in the
16 | center"""
17 | return -torch.sum((x[1:] - x[:-1]) * y[:-1])
18 |
19 |
20 | def _create_threshold_tensor(
21 | threshold: int | list[float] | torch.Tensor,
22 | device: torch.device,
23 | ) -> torch.Tensor:
24 | """
25 | Creates a threshold tensor from an integer, a list or a tensor.
26 | If `threshold` is an integer n, returns a Tensor with values [0, 1/(n-1), 2/(n-1), ..., (n-2)/(n-1), 1].
27 | If `threshold` is a list, returns the list converted to a Tensor.
28 | Otherwise, returns the tensor itself.
29 | """
30 | if isinstance(threshold, int):
31 | threshold = torch.linspace(0, 1.0, threshold, device=device)
32 | elif isinstance(threshold, list):
33 | threshold = torch.tensor(threshold, device=device)
34 | return threshold
35 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/ranking/frequency.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 |
@torch.inference_mode()
def frequency_at_k(
    input: torch.Tensor,
    k: float,
) -> torch.Tensor:
    """
    Flag which frequencies in ``input`` fall below the threshold ``k``.

    Returns a float tensor with the same shape as ``input``. Each element is
    ``1.0`` when the corresponding frequency is strictly less than ``k``, and
    ``0.0`` otherwise.

    Args:
        input (Tensor): One-dimensional tensor of frequencies.
        k (float): Threshold of the frequency. Must be non-negative.

    Raises:
        ValueError: If ``input`` is not one-dimensional or ``k`` is negative.

    Example:
        >>> import torch
        >>> from torcheval.metrics.functional import frequency_at_k
        >>> input = torch.tensor([0.3, 0.1, 0.6])
        >>> frequency_at_k(input, k=0.5)
        tensor([1., 1., 0.])
    """
    _frequency_input_check(input, k)

    return (input < k).float()


def _frequency_input_check(input: torch.Tensor, k: float) -> None:
    """Validate ``frequency_at_k`` arguments; raise ``ValueError`` if invalid."""
    if input.ndim != 1:
        raise ValueError(
            f"input should be a one-dimensional tensor, got shape {input.shape}."
        )
    if k < 0:
        raise ValueError(f"k should not be negative, got {k}.")
45 |
--------------------------------------------------------------------------------
/.github/workflows/build_docs.yaml:
--------------------------------------------------------------------------------
1 | name: Build and Update Docs
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 |
7 | # Allow one concurrent deployment
8 | concurrency:
9 | group: "pages"
10 | cancel-in-progress: true
11 |
12 | jobs:
13 | build_docs:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: Check out repo
17 | uses: actions/checkout@v2
18 | - name: Setup conda env
19 | uses: conda-incubator/setup-miniconda@v2
20 | with:
21 | miniconda-version: "latest"
22 | activate-environment: test
23 | - name: Install dependencies
24 | shell: bash -l {0}
25 | run: |
26 | set -eux
27 | conda activate test
28 | conda install pytorch cpuonly -c pytorch-nightly
29 | pip install -r requirements.txt
30 | pip install -r dev-requirements.txt
31 | python setup.py sdist bdist_wheel
32 | pip install dist/*.whl
33 | - name: Build docs
34 | shell: bash -l {0}
35 | run: |
36 | set -eux
37 | conda activate test
38 | cd docs
39 | pip install -r requirements.txt
40 | make html
41 | cd ..
42 | - name: Deploy docs to Github pages
43 | uses: JamesIves/github-pages-deploy-action@v4.4.1
44 | with:
45 | branch: gh-pages # The branch the action should deploy to.
46 | folder: docs/build/html # The folder the action should deploy.
47 | target-folder: main
48 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD License
2 |
3 | For torcheval software
4 |
5 | Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without modification,
8 | are permitted provided that the following conditions are met:
9 |
10 | * Redistributions of source code must retain the above copyright notice, this
11 | list of conditions and the following disclaimer.
12 |
13 | * Redistributions in binary form must reproduce the above copyright notice,
14 | this list of conditions and the following disclaimer in the documentation
15 | and/or other materials provided with the distribution.
16 |
17 | * Neither the name Meta nor the names of its contributors may be used to
18 | endorse or promote products derived from this software without specific
19 | prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/tests/metrics/aggregation/test_cov.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 | from torcheval.metrics import Covariance
12 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester
13 |
14 |
class TestCovariance(MetricClassTester):
    def _test_covariance_with_input(self, batching: list[int]) -> None:
        # Fixed-seed generator so the reference statistics are reproducible.
        generator = torch.Generator()
        generator.manual_seed(3)
        samples = torch.randn(sum(batching), 4, generator=generator)
        expected = (samples.mean(dim=0), torch.cov(samples.T))
        self.run_class_implementation_tests(
            metric=Covariance(),
            state_names={"n", "sum", "ss_sum"},
            update_kwargs={"obs": torch.split(samples, batching, dim=0)},
            compute_result=expected,
            num_total_updates=len(batching),
            min_updates_before_compute=1,
            num_processes=4,
        )

    def test_covariance_all_at_once(self) -> None:
        self._test_covariance_with_input([100] * 4)

    def test_covariance_one_by_one(self) -> None:
        self._test_covariance_with_input(list(range(2, 22)))

    def test_covariance_overflow(self) -> None:
        metric = Covariance()
        zero_sum = torch.zeros(10)
        ss = torch.ones(10, 10)
        largest_count = torch.iinfo(torch.uint64).max

        metric._update(zero_sum, ss, largest_count)
        # A further update past the uint64 maximum must not raise.
        metric._update(zero_sum, ss, 1)
43 |
--------------------------------------------------------------------------------
/tests/metrics/functional/text/test_word_information_lost.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import torch
12 | from torcheval.metrics.functional import word_information_lost
13 |
14 |
class TestWordInformationLost(unittest.TestCase):
    def test_word_information_lost(self) -> None:
        # (predictions, references, expected word information lost)
        cases = (
            (
                ["hello world", "welcome to the facebook"],
                ["hello metaverse", "welcome to meta"],
                0.7,
            ),
            (
                ["this is the prediction", "there is an other sample"],
                ["this is the reference", "there is another one"],
                0.6527777,
            ),
        )
        for predictions, references, expected in cases:
            torch.testing.assert_close(
                word_information_lost(predictions, references),
                torch.tensor(expected, dtype=torch.float64),
            )

    def test_word_information_lost_with_invalid_input(self) -> None:
        predictions = ["hello metaverse", "welcome to meta"]
        references = [
            "welcome to meta",
            "this is the prediction",
            "there is an other sample",
        ]
        with self.assertRaisesRegex(
            AssertionError,
            "Arguments must contain the same number of strings.",
        ):
            word_information_lost(predictions, references)
44 |
--------------------------------------------------------------------------------
/tests/metrics/functional/ranking/test_frequency.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import torch
12 | from torcheval.metrics.functional import frequency_at_k
13 |
14 |
class TestFrequency(unittest.TestCase):
    def test_frequency_with_valid_input(self) -> None:
        scores = torch.tensor(
            [0.4826, 0.9517, 0.8967, 0.8995, 0.1584, 0.9445, 0.9700],
        )
        # threshold k -> expected indicator of "score < k"
        expectations = {
            0.5: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            0.9: [1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0],
            0.95: [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0],
            1.0: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        }
        for threshold, expected in expectations.items():
            torch.testing.assert_close(
                frequency_at_k(scores, k=threshold),
                torch.tensor(expected),
            )

    def test_frequency_with_invalid_input(self) -> None:
        with self.assertRaisesRegex(
            ValueError, "input should be a one-dimensional tensor"
        ):
            frequency_at_k(torch.rand(3, 2, 2), k=1)
        with self.assertRaisesRegex(ValueError, "k should not be negative"):
            frequency_at_k(torch.rand(3), k=-1)
45 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/aggregation/throughput.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 |
@torch.inference_mode()
def throughput(
    num_processed: int = 0,
    elapsed_time_sec: float = 0.0,
) -> torch.Tensor:
    """
    Return the number of items processed per second as a 0-dim tensor.
    The stateful counterpart of this function is ``torcheval.metrics.Throughput``.

    Args:
        num_processed (int): How many items were processed.
        elapsed_time_sec (float): Time, in seconds, spent processing the
            ``num_processed`` items.
    Raises:
        ValueError:
            If ``num_processed`` is negative.
            If ``elapsed_time_sec`` is zero or negative.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import throughput
        >>> throughput(64, 2.0)
        tensor(32.)
    """
    rate = _throughput_compute(num_processed, elapsed_time_sec)
    return rate


def _throughput_compute(num_processed: int, elapsed_time_sec: float) -> torch.Tensor:
    """Validate the ``throughput`` arguments, then return their ratio as a tensor."""
    negative_count = num_processed < 0
    if negative_count:
        raise ValueError(
            f"Expected num_processed to be a non-negative number, but received {num_processed}."
        )
    non_positive_time = elapsed_time_sec <= 0
    if non_positive_time:
        raise ValueError(
            f"Expected elapsed_time_sec to be a positive number, but received {elapsed_time_sec}."
        )
    items_per_second = num_processed / elapsed_time_sec
    return torch.tensor(items_per_second)
50 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/ranking/num_collisions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import torch
10 |
11 |
@torch.inference_mode()
def num_collisions(input: torch.Tensor) -> torch.Tensor:
    """
    Compute, for every id in ``input``, how many *other* entries share that id.

    Args:
        input (Tensor): a one-dimensional integer tensor of ids, shape (num_samples,).

    Returns:
        Tensor: an int64 tensor of shape (num_samples,). Element ``i`` is the
        number of positions ``j != i`` with ``input[j] == input[i]``.

    Raises:
        ValueError: If ``input`` is not one-dimensional or not an integer tensor.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import num_collisions
        >>> input = torch.tensor([3, 4, 2, 3])
        >>> num_collisions(input)
        tensor([1, 0, 0, 1])
        >>> input = torch.tensor([3, 4, 1, 3, 1, 1, 5])
        >>> num_collisions(input)
        tensor([1, 0, 2, 1, 2, 2, 0])
    """
    _num_collisions_input_check(input)

    # ``inverse`` maps every element to the index of its distinct value, and
    # ``counts`` holds each distinct value's multiplicity. Looking up the count
    # per element and subtracting the element itself gives the number of
    # *other* matching entries. This is O(n log n) time and O(n) memory,
    # instead of materializing an n x n equality matrix.
    _, inverse, counts = torch.unique(input, return_inverse=True, return_counts=True)
    return counts[inverse] - 1


def _num_collisions_input_check(input: torch.Tensor) -> None:
    """Raise ``ValueError`` unless ``input`` is a one-dimensional integer tensor."""
    if input.ndim != 1:
        raise ValueError(
            f"input should be a one-dimensional tensor, got shape {input.shape}."
        )

    # ``torch.int`` is an alias of ``torch.int32``, so it is not listed twice.
    if input.dtype not in (
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    ):
        raise ValueError(f"input should be an integer tensor, got {input.dtype}.")
55 |
--------------------------------------------------------------------------------
/tests/metrics/functional/aggregation/test_throughput.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import random
10 | import unittest
11 |
12 | import torch
13 | from torcheval.metrics.functional import throughput
14 | from torcheval.utils.test_utils.metric_class_tester import NUM_PROCESSES
15 |
16 |
class TestThroughput(unittest.TestCase):
    def _test_throughput_with_input(
        self,
        num_processed: int,
        elapsed_time_sec: float,
    ) -> None:
        torch.testing.assert_close(
            throughput(num_processed, elapsed_time_sec),
            torch.tensor(num_processed / elapsed_time_sec),
            equal_nan=True,
            atol=1e-8,
            rtol=1e-5,
        )

    def test_throughput_base(self) -> None:
        num_processed = NUM_PROCESSES
        # ``random.random()`` may return exactly 0.0, which ``throughput``
        # rejects with a ValueError. Sample from a strictly positive range so
        # the test cannot flake.
        elapsed_time_sec = random.uniform(1e-3, 20.0)
        self._test_throughput_with_input(num_processed, elapsed_time_sec)

    def test_throughput_update_input_invalid_num_processed(self) -> None:
        with self.assertRaisesRegex(
            ValueError,
            r"Expected num_processed to be a non-negative number, but received",
        ):
            throughput(-1, 1.0)

    def test_throughput_update_input_invalid_elapsed_time_sec(self) -> None:
        with self.assertRaisesRegex(
            ValueError,
            r"Expected elapsed_time_sec to be a positive number, but received",
        ):
            throughput(42, -5.1)
        with self.assertRaisesRegex(
            ValueError,
            r"Expected elapsed_time_sec to be a positive number, but received",
        ):
            throughput(42, 0.0)
54 |
--------------------------------------------------------------------------------
/tests/metrics/aggregation/test_max.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import torch
10 | from torcheval.metrics import Max
11 | from torcheval.utils.test_utils.metric_class_tester import (
12 | BATCH_SIZE,
13 | MetricClassTester,
14 | NUM_TOTAL_UPDATES,
15 | )
16 |
17 |
class TestMax(MetricClassTester):
    def _test_max_class_with_input(self, input_val_tensor: torch.Tensor) -> None:
        expected = torch.max(input_val_tensor)
        self.run_class_implementation_tests(
            metric=Max(),
            state_names={"max"},
            update_kwargs={"input": input_val_tensor},
            compute_result=expected,
        )

    def test_max_class_base(self) -> None:
        # Per-update batches of increasing rank: (B,), (B, 4), (B, 3, 4).
        for trailing_dims in ((), (4,), (3, 4)):
            self._test_max_class_with_input(
                torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, *trailing_dims)
            )

    def test_max_class_update_input_dimension_different(self) -> None:
        # Each update has a different shape; the metric reduces over all of them.
        ragged_inputs = [
            torch.tensor(1.0),
            torch.tensor([2.0, 3.0, 5.0]),
            torch.tensor([-1.0, 2.0]),
            torch.tensor([[1.0, 6.0], [2.0, -4.0]]),
        ]
        self.run_class_implementation_tests(
            metric=Max(),
            state_names={"max"},
            update_kwargs={"input": ragged_inputs},
            compute_result=torch.tensor(6.0),
            num_total_updates=4,
            num_processes=2,
        )
51 |
--------------------------------------------------------------------------------
/tests/metrics/aggregation/test_min.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import torch
10 | from torcheval.metrics import Min
11 | from torcheval.utils.test_utils.metric_class_tester import (
12 | BATCH_SIZE,
13 | MetricClassTester,
14 | NUM_TOTAL_UPDATES,
15 | )
16 |
17 |
class TestMin(MetricClassTester):
    def _test_min_class_with_input(self, input_val_tensor: torch.Tensor) -> None:
        expected = torch.min(input_val_tensor)
        self.run_class_implementation_tests(
            metric=Min(),
            state_names={"min"},
            update_kwargs={"input": input_val_tensor},
            compute_result=expected,
        )

    def test_min_class_base(self) -> None:
        # Per-update batches of increasing rank: (B,), (B, 4), (B, 3, 4).
        for trailing_dims in ((), (4,), (3, 4)):
            self._test_min_class_with_input(
                torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, *trailing_dims)
            )

    def test_min_class_update_input_dimension_different(self) -> None:
        # Each update has a different shape; the metric reduces over all of them.
        ragged_inputs = [
            torch.tensor(1.0),
            torch.tensor([2.0, 3.0, 5.0]),
            torch.tensor([-1.0, 2.0]),
            torch.tensor([[1.0, 6.0], [2.0, -4.0]]),
        ]
        self.run_class_implementation_tests(
            metric=Min(),
            state_names={"min"},
            update_kwargs={"input": ragged_inputs},
            compute_result=torch.tensor(-4.0),
            num_total_updates=4,
            num_processes=2,
        )
51 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/aggregation/sum.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 |
13 | @torch.inference_mode()
14 | def sum(
15 | input: torch.Tensor,
16 | weight: float | torch.Tensor = 1.0,
17 | ) -> torch.Tensor:
18 | """
19 | Compute weighted sum. When weight is not provided, it calculates the unweighted sum.
20 | Its class version is ``torcheval.metrics.Sum``.
21 |
22 | Args:
23 | input (Tensor): Tensor of input values.
24 | weight(optional): Float or Int or Tensor of input weights. It is default to 1.0. If weight is a Tensor, its size should match the input tensor size.
25 | Raises:
26 | ValueError: If value of weight is neither a ``float`` nor an ``int`` nor a ``torch.Tensor`` that matches the input tensor size.
27 |
28 | Examples::
29 |
30 | >>> import torch
31 | >>> from torcheval.metrics.functional import sum
32 | >>> sum(torch.tensor([2, 3]))
33 | tensor(5.)
34 | >>> sum(torch.tensor([2, 3]), torch.tensor([0.1, 0.6]))
35 | tensor(2.)
36 | >>> sum(torch.tensor([2, 3]), 0.5)
37 | tensor(2.5)
38 | >>> sum(torch.tensor([2, 3]), 2)
39 | tensor(10.)
40 | """
41 | return _sum_update(input, weight)
42 |
43 |
44 | def _sum_update(
45 | input: torch.Tensor, weight: float | int | torch.Tensor
46 | ) -> torch.Tensor:
47 | if (
48 | isinstance(weight, float)
49 | or isinstance(weight, int)
50 | or (isinstance(weight, torch.Tensor) and input.size() == weight.size())
51 | ):
52 | return (input * weight).sum()
53 | else:
54 | raise ValueError(
55 | "Weight must be either a float value or an int value or a tensor that matches the input tensor size. "
56 | f"Got {weight} instead."
57 | )
58 |
--------------------------------------------------------------------------------
/torcheval/metrics/aggregation/min.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
10 |
11 | from collections.abc import Iterable
12 | from typing import TypeVar
13 |
14 | import torch
15 |
16 | from torcheval.metrics.metric import Metric
17 |
18 |
TMin = TypeVar("TMin")


class Min(Metric[torch.Tensor]):
    """
    Track the smallest element seen across every tensor passed to ``update``.
    Its functional version is ``torch.min(input)``.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import Min
        >>> metric = Min()
        >>> metric.update(torch.tensor([[1, 2], [3, 4]]))
        >>> metric.compute()
        tensor(1.)

        >>> metric.update(torch.tensor(-1)).compute()
        tensor(-1.)

        >>> metric.reset()
        >>> metric.update(torch.tensor(5)).compute()
        tensor(5.)
    """

    def __init__(
        self: TMin,
        *,
        device: torch.device | None = None,
    ) -> None:
        super().__init__(device=device)
        # +inf is the identity for min, so the first update always replaces it.
        initial_min = torch.tensor(float("inf"), device=self.device)
        self._add_state("min", initial_min)

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(self: TMin, input: torch.Tensor) -> TMin:
        batch_min = input.min()
        self.min = torch.min(self.min, batch_min)
        return self

    @torch.inference_mode()
    def compute(self: TMin) -> torch.Tensor:
        return self.min

    @torch.inference_mode()
    def merge_state(self: TMin, metrics: Iterable[TMin]) -> TMin:
        for other in metrics:
            # Bring the peer's state onto this metric's device before comparing.
            other_min = other.min.to(self.device)
            self.min = torch.min(self.min, other_min)
        return self
67 |
--------------------------------------------------------------------------------
/torcheval/metrics/aggregation/max.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
10 |
11 | from collections.abc import Iterable
12 | from typing import TypeVar
13 |
14 | import torch
15 |
16 | from torcheval.metrics.metric import Metric
17 |
18 |
TMax = TypeVar("TMax")


class Max(Metric[torch.Tensor]):
    """
    Calculate the maximum value of all elements in all the input tensors.
    Its functional version is ``torch.max(input)``.

    The running state starts at ``-inf``. ``compute`` therefore returns
    ``tensor(-inf)`` until a non-empty tensor has been passed to ``update``.
    Empty input tensors contain no elements and are skipped.

    Args:
        device (torch.device, optional): Forwarded to the ``Metric`` base
            class; the ``max`` state tensor is created on ``self.device``.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import Max
        >>> metric = Max()
        >>> metric.update(torch.tensor([[1, 2], [3, 4]]))
        >>> metric.compute()
        tensor(4.)

        >>> metric.update(torch.tensor(-1)).compute()
        tensor(4.)

        >>> metric.reset()
        >>> metric.update(torch.tensor(-1)).compute()
        tensor(-1.)
    """

    def __init__(
        self: TMax,
        *,
        device: torch.device | None = None,
    ) -> None:
        super().__init__(device=device)
        # -inf is the identity element of max, so the first real value wins.
        self._add_state("max", torch.tensor(float("-inf"), device=self.device))

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(self: TMax, input: torch.Tensor) -> TMax:
        """
        Fold the largest element of ``input`` into the running maximum.

        Args:
            input (Tensor): Tensor of any shape. An empty tensor leaves the
                state unchanged: ``torch.max`` cannot reduce an empty tensor
                without a ``dim`` argument and would raise.

        Returns:
            ``self``, so calls can be chained.
        """
        if input.numel() == 0:
            return self
        self.max = torch.max(self.max, torch.max(input))
        return self

    @torch.inference_mode()
    def compute(self: TMax) -> torch.Tensor:
        """Return the largest element observed so far (``-inf`` if none)."""
        return self.max

    @torch.inference_mode()
    def merge_state(self: TMax, metrics: Iterable[TMax]) -> TMax:
        """
        Combine the ``max`` state of every metric in ``metrics`` into this one.

        Each incoming state is moved to ``self.device`` before comparison.

        Returns:
            ``self``.
        """
        for metric in metrics:
            self.max = torch.max(self.max, metric.max.to(self.device))
        return self
67 |
--------------------------------------------------------------------------------
/tests/metrics/text/test_word_information_lost.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import torch
10 | from torcheval.metrics.text import WordInformationLost
11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester
12 |
13 |
class TestWordInformationLost(MetricClassTester):
    def test_word_information_lost(self) -> None:
        # Every one of the four updates receives the same two-sentence batch.
        num_updates = 4
        hypotheses = [
            ["hello world", "welcome to the facebook"] for _ in range(num_updates)
        ]
        references = [
            ["hello metaverse", "welcome to meta"] for _ in range(num_updates)
        ]
        expected = torch.tensor(0.7, dtype=torch.float64)

        self.run_class_implementation_tests(
            metric=WordInformationLost(),
            state_names={"correct_total", "target_total", "preds_total"},
            update_kwargs={"input": hypotheses, "target": references},
            compute_result=expected,
            num_total_updates=num_updates,
        )

    def test_word_information_lost_with_invalid_input(self) -> None:
        wil = WordInformationLost()
        two_sentences = ["hello metaverse", "welcome to meta"]
        three_sentences = [
            "welcome to meta",
            "this is the prediction",
            "there is an other sample",
        ]

        # NOTE(review): the WordErrorRate and WordInformationPreserved tests
        # expect a ValueError for mismatched list lengths, whereas this test
        # expects an AssertionError -- confirm the difference is intended.
        with self.assertRaisesRegex(
            AssertionError,
            "Arguments must contain the same number of strings.",
        ):
            wil.update(two_sentences, three_sentences)
52 |
--------------------------------------------------------------------------------
/tests/metrics/text/test_word_error_rate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import torch
10 | from torcheval.metrics.text import WordErrorRate
11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester
12 |
13 |
class TestWordErrorRate(MetricClassTester):
    def test_word_error_rate_with_valid_input(self) -> None:
        # The same two-sentence batch is fed in each of the four updates.
        updates = 4
        predictions = [
            ["hello world", "welcome to the facebook"] for _ in range(updates)
        ]
        references = [["hello metaverse", "welcome to meta"] for _ in range(updates)]

        self.run_class_implementation_tests(
            metric=WordErrorRate(),
            state_names={"errors", "total"},
            update_kwargs={"input": predictions, "target": references},
            compute_result=torch.tensor(0.6),
            num_total_updates=updates,
        )

    def test_word_error_rate_with_invalid_input(self) -> None:
        wer = WordErrorRate()

        # A list of sentences paired with a single string: mismatched types.
        with self.assertRaisesRegex(
            ValueError, "input and target should have the same type"
        ):
            wer.update(["hello metaverse", "welcome to meta"], "hello world")

        # Two predictions against three references: mismatched lengths.
        longer_target = [
            "welcome to meta",
            "this is the prediction",
            "there is an other sample",
        ]
        with self.assertRaisesRegex(
            ValueError, "input and target lists should have the same length"
        ):
            wer.update(["hello metaverse", "welcome to meta"], longer_target)
55 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to torcheval
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Development Installation
6 | To get the development installation with all the necessary dependencies for
7 | linting, testing, and building the documentation, run the following:
8 | ```bash
9 | git clone https://github.com/pytorch/torcheval
10 | cd torcheval
11 | pip install -r requirements.txt
12 | pip install -r dev-requirements.txt
13 | pip install -r docs/requirements.txt
14 | pip install --no-build-isolation -e ".[dev]"
15 | ```
16 |
17 | ## Pull Requests
18 | We actively welcome your pull requests.
19 |
20 | 1. Create your branch from `main`.
21 | 2. If you've added code that should be tested, add tests.
22 | 3. If you've changed APIs, update the documentation.
23 | - To build docs
24 | ```bash
25 | cd docs; make html
26 | ```
27 | - To view docs
28 | ```bash
29 | cd build/html; python -m http.server
30 | ```
31 | 4. Ensure the test suite passes.
32 | - To run all tests
33 | ```bash
34 | python -m pytest tests/
35 | ```
36 | - To run a single test
37 | ```bash
38 | python -m pytest -v tests/metrics/test_metric.py::MetricBaseClassTest::test_add_state_invalid
39 | ```
40 |
41 | 5. Make sure your code lints.
42 | ```bash
43 | pre-commit run --all-files
44 | ```
45 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
46 |
47 | ## Contributor License Agreement ("CLA")
48 | In order to accept your pull request, we need you to submit a CLA. You only need
49 | to do this once to work on any of Meta's open source projects.
50 |
Complete your CLA here: <https://code.facebook.com/cla>
52 |
53 | ## Issues
54 | We use GitHub issues to track public bugs. Please ensure your description is
55 | clear and has sufficient instructions to be able to reproduce the issue.
56 |
57 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
58 | disclosure of security bugs. In those cases, please go through the process
59 | outlined on that page and do not file a public issue.
60 |
61 | ## License
62 | By contributing to torcheval, you agree that your contributions will be licensed
63 | under the LICENSE file in the root directory of this source tree.
64 |
--------------------------------------------------------------------------------
/tests/metrics/text/test_word_information_preserved.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import torch
10 | from torcheval.metrics.text import WordInformationPreserved
11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester
12 |
13 |
class TestWordInformationPreserved(MetricClassTester):
    def test_word_information_preserved_with_valid_input(self) -> None:
        # Four identical update batches of two sentence pairs each.
        update_count = 4
        hyps = [
            ["hello world", "welcome to the facebook"] for _ in range(update_count)
        ]
        refs = [["hello metaverse", "welcome to meta"] for _ in range(update_count)]

        self.run_class_implementation_tests(
            metric=WordInformationPreserved(),
            state_names={"correct_total", "input_total", "target_total"},
            update_kwargs={"input": hyps, "target": refs},
            compute_result=torch.tensor(0.3, dtype=torch.float64),
            num_total_updates=update_count,
        )

    def test_word_information_preserved_with_invalid_input(self) -> None:
        wip = WordInformationPreserved()

        # A list of sentences paired with a single string: mismatched types.
        with self.assertRaisesRegex(
            ValueError, "input and target should have the same type"
        ):
            wip.update(["hello metaverse", "welcome to meta"], "hello world")

        # Two inputs against three targets: mismatched lengths.
        three_targets = [
            "welcome to meta",
            "this is the prediction",
            "there is an other sample",
        ]
        with self.assertRaisesRegex(
            ValueError, "input and target lists should have the same length"
        ):
            wip.update(["hello metaverse", "welcome to meta"], three_targets)
55 |
--------------------------------------------------------------------------------
/.github/workflows/release_build.yaml:
--------------------------------------------------------------------------------
name: Push Release to PyPi

on:
  workflow_dispatch:

jobs:
  unit_tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # The package evaluates PEP 604 unions (e.g. `torch.device | None`)
        # in function signatures at import time, so Python 3.10 is the oldest
        # interpreter that can import it.
        python-version: ["3.10", "3.11", "3.12"]
    steps:
      - name: Check out repo
        uses: actions/checkout@v4
      - name: Setup conda env
        uses: conda-incubator/setup-miniconda@v3
        with:
          miniconda-version: "latest"
          activate-environment: test
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        shell: bash -l {0}
        run: |
          set -eux
          conda activate test
          conda install pytorch torchaudio torchvision cpuonly -c pytorch-nightly
          pip install -r requirements.txt
          pip install -r dev-requirements.txt
          python setup.py sdist bdist_wheel
          pip install dist/*.whl
      - name: Run unit tests
        shell: bash -l {0}
        run: |
          set -eux
          conda activate test
          pytest tests -vv
  # TODO figure out how to deduplicate steps
  upload_to_pypi:
    needs: unit_tests
    runs-on: ubuntu-latest
    steps:
      - name: Check out repo
        uses: actions/checkout@v4
      - name: Setup conda env
        uses: conda-incubator/setup-miniconda@v3
        with:
          miniconda-version: "latest"
          activate-environment: test
          python-version: "3.10"
      - name: Install dependencies
        shell: bash -l {0}
        run: |
          set -eux
          conda activate test
          conda install pytorch cpuonly -c pytorch-nightly
          pip install -r requirements.txt
          pip install -r dev-requirements.txt
          pip install --no-build-isolation -e ".[dev]"
      - name: Upload to PyPI
        shell: bash -l {0}
        env:
          # twine reads credentials from these variables. This keeps them off
          # the command line and out of the `set -x` trace below.
          TWINE_USERNAME: ${{ secrets.PYPI_USER_RELEASE }}
          TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN_RELEASE }}
        run: |
          set -eux
          conda activate test
          pip install twine
          python setup.py sdist bdist_wheel
          twine upload dist/* --verbose
70 |
--------------------------------------------------------------------------------
/tests/metrics/functional/image/test_psnr.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import torch
12 |
13 | from skimage.metrics import peak_signal_noise_ratio as skimage_psnr
14 | from torcheval.metrics.functional import peak_signal_noise_ratio
15 | from torcheval.utils.test_utils.metric_class_tester import (
16 | BATCH_SIZE,
17 | IMG_CHANNELS,
18 | IMG_HEIGHT,
19 | IMG_WIDTH,
20 | )
21 |
22 |
class TestPeakSignalNoiseRatio(unittest.TestCase):
    def test_psnr_skimage_equivelant(self) -> None:
        # Compare torcheval against scikit-image on the same flattened data.
        input, target = self._get_random_data_peak_signal_to_noise_ratio(
            BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH
        )

        reference_value = skimage_psnr(target.numpy().ravel(), input.numpy().ravel())
        expected = torch.tensor(reference_value, dtype=torch.float32)
        actual = peak_signal_noise_ratio(input, target)

        torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3)

    def test_psnr_with_invalid_input(self) -> None:
        # The target is one pixel wider than the input.
        input = torch.rand(BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH)
        target = torch.rand(BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH + 1)
        expected_message = (
            r"^The `input` and `target` must have the same shape, "
            + rf"got shapes torch.Size\(\[{BATCH_SIZE}, {IMG_CHANNELS}, {IMG_HEIGHT}, {IMG_WIDTH}\]\) "
            + rf"and torch.Size\(\[{BATCH_SIZE}, {IMG_CHANNELS}, {IMG_HEIGHT}, {IMG_WIDTH + 1}\]\)."
        )
        with self.assertRaisesRegex(ValueError, expected_message):
            peak_signal_noise_ratio(input, target)

    def _get_random_data_peak_signal_to_noise_ratio(
        self, batch_size: int, num_channels: int, height: int, width: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return an ``(input, target)`` pair of uniform random images."""
        image_shape = (batch_size, num_channels, height, width)
        # Drawn in order: input first, then target.
        return torch.rand(size=image_shape), torch.rand(size=image_shape)
63 |
--------------------------------------------------------------------------------
/tests/metrics/functional/ranking/test_click_through_rate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import torch
12 | from torcheval.metrics.functional import click_through_rate
13 |
14 |
class TestClickThroughRate(unittest.TestCase):
    def test_click_through_rate_with_valid_input(self) -> None:
        # Single task: a flat vector of click indicators.
        clicks = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1])
        click_weights = torch.tensor([1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0])
        torch.testing.assert_close(click_through_rate(clicks), torch.tensor(0.5))
        torch.testing.assert_close(
            click_through_rate(clicks, click_weights), torch.tensor(0.58333334)
        )

        # Two tasks: one row of click indicators per task.
        task_clicks = torch.tensor([[0, 1, 0, 1], [1, 0, 0, 1]])
        task_weights = torch.tensor([[1.0, 2.0, 1.0, 2.0], [1.0, 2.0, 1.0, 1.0]])
        torch.testing.assert_close(
            click_through_rate(task_clicks, num_tasks=2), torch.tensor([0.5, 0.5])
        )
        torch.testing.assert_close(
            click_through_rate(task_clicks, task_weights, num_tasks=2),
            torch.tensor([0.66666667, 0.4]),
        )

    def test_click_through_rate_with_invalid_input(self) -> None:
        # Each entry pairs the expected message pattern with the bad call.
        bad_calls = [
            (
                "^`input` should be a one or two dimensional tensor",
                lambda: click_through_rate(torch.rand(3, 2, 2)),
            ),
            (
                "^tensor `weights` should have the same shape as tensor `input`",
                lambda: click_through_rate(torch.rand(4, 2), torch.rand(3)),
            ),
            (
                r"`num_tasks = 1`, `input` is expected to be one-dimensional tensor,",
                lambda: click_through_rate(torch.tensor([[1, 1], [0, 1]])),
            ),
            (
                r"`num_tasks = 2`, `input`'s shape is expected to be",
                lambda: click_through_rate(torch.tensor([1, 0, 0, 1]), num_tasks=2),
            ),
        ]
        for pattern, call in bad_calls:
            with self.assertRaisesRegex(ValueError, pattern):
                call()
60 |
--------------------------------------------------------------------------------
/docs/source/torcheval.metrics.functional.rst:
--------------------------------------------------------------------------------
1 | Functional Metrics
2 | ==================
3 |
4 | .. automodule:: torcheval.metrics.functional
5 |
6 | Aggregation Metrics
7 | -------------------------------------------------------------------
8 |
9 | .. autosummary::
10 | :toctree: generated
11 | :nosignatures:
12 |
13 | auc
14 | mean
15 | sum
16 | throughput
17 |
18 | Classification Metrics
19 | -------------------------------------------------------------------
20 |
21 | .. autosummary::
22 | :toctree: generated
23 | :nosignatures:
24 |
25 | binary_accuracy
26 | binary_auprc
27 | binary_auroc
28 | binary_binned_auroc
29 | binary_binned_precision_recall_curve
30 | binary_confusion_matrix
31 | binary_f1_score
32 | binary_normalized_entropy
33 | binary_precision
34 | binary_precision_recall_curve
35 | binary_recall
36 | binary_recall_at_fixed_precision
37 | multiclass_accuracy
38 | multiclass_auprc
39 | multiclass_auroc
40 | multiclass_binned_auroc
41 | multiclass_binned_precision_recall_curve
42 | multiclass_confusion_matrix
43 | multiclass_f1_score
44 | multiclass_precision
45 | multiclass_precision_recall_curve
46 | multiclass_recall
47 | multilabel_accuracy
48 | multilabel_auprc
49 | multilabel_precision_recall_curve
50 | multilabel_recall_at_fixed_precision
51 | topk_multilabel_accuracy
52 |
53 | Image Metrics
54 | -------------------------------------------------------------------
55 |
56 | .. autosummary::
57 | :toctree: generated
58 | :nosignatures:
59 |
60 | peak_signal_noise_ratio
61 |
62 | Ranking Metrics
63 | -------------------------------------------------------------------
64 |
65 | .. autosummary::
66 | :toctree: generated
67 | :nosignatures:
68 |
69 | click_through_rate
70 | frequency_at_k
71 | hit_rate
72 | num_collisions
73 | reciprocal_rank
74 | weighted_calibration
75 |
76 | Regression Metrics
77 | -------------------------------------------------------------------
78 |
79 | .. autosummary::
80 | :toctree: generated
81 | :nosignatures:
82 |
83 | mean_squared_error
84 | r2_score
85 |
86 | Text Metrics
87 | -------------------------------------------------------------------
88 |
89 | .. autosummary::
90 | :toctree: generated
91 | :nosignatures:
92 |
93 | bleu_score
94 | perplexity
95 | word_error_rate
96 | word_information_preserved
97 | word_information_lost
98 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/aggregation/mean.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 |
@torch.inference_mode()
def mean(
    input: torch.Tensor,
    weight: float | int | torch.Tensor = 1.0,
) -> torch.Tensor:
    """
    Compute the weighted mean of ``input``; with the default weight this is the plain mean.
    Its class version is ``torcheval.metrics.Mean``.

    weighted_mean = sum(weight * input) / sum(weight)

    Args:
        input (Tensor): Tensor of input values.
        weight(optional): Float or Int or Tensor of input weights. Defaults to 1.0.
            A scalar weight applies to every element; a Tensor weight must have
            the same size as ``input``.
    Raises:
        ValueError: If ``weight`` is neither a ``float`` nor an ``int`` nor a
            ``torch.Tensor`` whose size matches ``input``.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import mean
        >>> mean(torch.tensor([2, 3]))
        tensor(2.5)
        >>> mean(torch.tensor([2, 3]), torch.tensor([0.2, 0.8]))
        tensor(2.8)
        >>> mean(torch.tensor([2, 3]), 0.5)
        tensor(2.5)
        >>> mean(torch.tensor([2, 3]), 1)
        tensor(2.5)
    """
    result = _mean_compute(input, weight)
    return result
44 |
45 |
46 | def _mean_update(
47 | input: torch.Tensor, weight: float | int | torch.Tensor
48 | ) -> tuple[torch.Tensor, torch.Tensor]:
49 | if isinstance(weight, float) or isinstance(weight, int):
50 | weighted_sum = weight * torch.sum(input)
51 | weights = torch.tensor(float(weight) * torch.numel(input))
52 | return weighted_sum, weights
53 | elif isinstance(weight, torch.Tensor) and input.size() == weight.size():
54 | return torch.sum(weight * input), torch.sum(weight)
55 | else:
56 | raise ValueError(
57 | "Weight must be either a float value or a tensor that matches the input tensor size. "
58 | f"Got {weight} instead."
59 | )
60 |
61 |
def _mean_compute(
    input: torch.Tensor, weight: float | int | torch.Tensor
) -> torch.Tensor:
    """Return ``sum(weight * input) / sum(weight)`` for a single batch."""
    numerator, denominator = _mean_update(input, weight)
    return torch.div(numerator, denominator)
67 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/ranking/reciprocal_rank.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 |
@torch.inference_mode()
def reciprocal_rank(
    input: torch.Tensor,
    target: torch.Tensor,
    *,
    k: int | None = None,
) -> torch.Tensor:
    """
    Compute the reciprocal rank of the correct class among the top predicted classes.
    Its class version is ``torcheval.metrics.ReciprocalRank``.

    A sample's rank is one plus the number of classes scored strictly higher
    than its target class. Tied classes therefore do not push the target down.

    Args:
        input (Tensor): Predicted unnormalized scores (often referred to as logits) or
            class probabilities of shape (num_samples, num_classes).
        target (Tensor): Ground truth class indices of shape (num_samples,).
        k (int, optional): Number of top class probabilities to be considered.
            Samples whose target falls outside the top ``k`` get a score of 0.

    Returns:
        Tensor of shape (num_samples,) holding ``1 / rank`` per sample.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import reciprocal_rank
        >>> input = torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3], [0.2, 0.1, 0.7], [0.3, 0.3, 0.4]])
        >>> target = torch.tensor([2, 1, 1, 0])
        >>> reciprocal_rank(input, target)
        tensor([1.0000, 0.3333, 0.3333, 0.5000])
        >>> reciprocal_rank(input, target, k=2)
        tensor([1.0000, 0.0000, 0.0000, 0.5000])
    """
    _reciprocal_rank_input_check(input, target)

    # Score of the ground-truth class for each sample, shape (num_samples, 1).
    target_scores = input.gather(-1, target.unsqueeze(-1))
    num_higher = (input > target_scores).sum(dim=-1)
    scores = torch.reciprocal(num_higher + 1.0)
    if k is not None:
        # Zero out samples whose target is not within the top k.
        scores[num_higher >= k] = 0.0
    return scores
49 |
50 |
51 | def _reciprocal_rank_input_check(input: torch.Tensor, target: torch.Tensor) -> None:
52 | if target.ndim != 1:
53 | raise ValueError(
54 | f"target should be a one-dimensional tensor, got shape {target.shape}."
55 | )
56 | if input.ndim != 2:
57 | raise ValueError(
58 | f"input should be a two-dimensional tensor, got shape {input.shape}."
59 | )
60 | if input.shape[0] != target.shape[0]:
61 | raise ValueError(
62 | "`input` and `target` should have the same minibatch dimension, ",
63 | f"got shapes {input.shape} and {target.shape}, respectively.",
64 | )
65 |
--------------------------------------------------------------------------------
/.github/workflows/nightly_build_cpu.yaml:
--------------------------------------------------------------------------------
name: Push CPU Binary Nightly

on:
  # runs every day at 11:15 UTC (GitHub evaluates cron schedules in UTC)
  schedule:
    - cron: '15 11 * * *'
  # or manually trigger it
  workflow_dispatch:


jobs:
  unit_tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # The package evaluates PEP 604 unions (e.g. `torch.device | None`)
        # in function signatures at import time, so Python 3.10 is the oldest
        # interpreter that can import it.
        python-version: ["3.10", "3.11", "3.12"]
    steps:
      - name: Check out repo
        uses: actions/checkout@v4
      - name: Setup conda env
        uses: conda-incubator/setup-miniconda@v3
        with:
          miniconda-version: "latest"
          activate-environment: test
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        shell: bash -l {0}
        run: |
          set -eux
          conda activate test
          conda install pytorch torchaudio torchvision cpuonly -c pytorch-nightly
          pip install -r requirements.txt
          pip install -r dev-requirements.txt
          python setup.py sdist bdist_wheel
          pip install dist/*.whl
      - name: Run unit tests
        shell: bash -l {0}
        run: |
          set -eux
          conda activate test
          pytest tests -vv
  # TODO figure out how to deduplicate steps
  upload_to_pypi:
    needs: unit_tests
    runs-on: ubuntu-latest
    steps:
      - name: Check out repo
        uses: actions/checkout@v4
      - name: Setup conda env
        uses: conda-incubator/setup-miniconda@v3
        with:
          miniconda-version: "latest"
          activate-environment: test
          python-version: "3.10"
      - name: Install dependencies
        shell: bash -l {0}
        run: |
          set -eux
          conda activate test
          conda install pytorch cpuonly -c pytorch-nightly
          pip install -r requirements.txt
          pip install -r dev-requirements.txt
          pip install --no-build-isolation -e ".[dev]"
      - name: Upload to PyPI
        shell: bash -l {0}
        env:
          # twine reads credentials from these variables. This keeps them off
          # the command line and out of the `set -x` trace below.
          TWINE_USERNAME: ${{ secrets.PYPI_USER }}
          TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
        run: |
          set -eux
          conda activate test
          pip install twine
          python setup.py --nightly sdist bdist_wheel
          twine upload dist/* --verbose
75 |
--------------------------------------------------------------------------------
/tests/metrics/functional/ranking/test_hit_rate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import torch
12 | from torcheval.metrics.functional import hit_rate
13 |
14 |
class TestHitRate(unittest.TestCase):
    def test_hit_rate_with_valid_input(self) -> None:
        scores = torch.tensor(
            [
                [0.4826, 0.9517, 0.8967, 0.8995, 0.1584, 0.9445, 0.9700],
                [0.4938, 0.7517, 0.8039, 0.7167, 0.9488, 0.9607, 0.7091],
                [0.5127, 0.4732, 0.5461, 0.5617, 0.9198, 0.0847, 0.2337],
                [0.4175, 0.9452, 0.9852, 0.2131, 0.5016, 0.7305, 0.0516],
            ]
        )
        labels = torch.tensor([3, 5, 2, 1])

        # Expected per-sample hits for each cutoff k, checked in this order.
        expected_by_k = {
            None: [1.0, 1.0, 1.0, 1.0],
            1: [0.0, 1.0, 0.0, 0.0],
            3: [0.0, 1.0, 1.0, 1.0],
            5: [1.0, 1.0, 1.0, 1.0],
            20: [1.0, 1.0, 1.0, 1.0],
        }
        for k, expected in expected_by_k.items():
            torch.testing.assert_close(
                hit_rate(scores, labels, k=k), torch.tensor(expected)
            )

    def test_hit_rate_with_invalid_input(self) -> None:
        # Each entry pairs the expected message pattern with the bad call.
        bad_calls = [
            (
                "target should be a one-dimensional tensor",
                lambda: hit_rate(torch.rand(3, 2), torch.rand(3, 2)),
            ),
            (
                "input should be a two-dimensional tensor",
                lambda: hit_rate(torch.rand(3, 2, 2), torch.rand(3)),
            ),
            (
                "`input` and `target` should have the same minibatch dimension",
                lambda: hit_rate(torch.rand(4, 2), torch.rand(3)),
            ),
            (
                "k should be None or positive",
                lambda: hit_rate(torch.rand(3, 2), torch.rand(3), k=0),
            ),
        ]
        for pattern, call in bad_calls:
            with self.assertRaisesRegex(ValueError, pattern):
                call()
64 |
--------------------------------------------------------------------------------
/tests/metrics/ranking/test_hit_rate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import torch
10 | from torcheval.metrics.ranking import HitRate
11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester
12 |
13 |
class TestHitRate(MetricClassTester):
    def test_hitrate_with_valid_input(self) -> None:
        # Shape (4, 1, 7): one single-sample batch per update.
        batches = torch.tensor(
            [
                [[0.4826, 0.9517, 0.8967, 0.8995, 0.1584, 0.9445, 0.9700]],
                [[0.4938, 0.7517, 0.8039, 0.7167, 0.9488, 0.9607, 0.7091]],
                [[0.5127, 0.4732, 0.5461, 0.5617, 0.9198, 0.0847, 0.2337]],
                [[0.4175, 0.9452, 0.9852, 0.2131, 0.5016, 0.7305, 0.0516]],
            ]
        )
        labels = torch.tensor([[3], [5], [2], [1]])

        # (constructor kwargs, expected result); each metric is built lazily
        # right before its run.
        scenarios = [
            ({}, torch.tensor([1.0, 1.0, 1.0, 1.0])),
            ({"k": 3}, torch.tensor([0.0, 1.0, 1.0, 1.0])),
        ]
        for metric_kwargs, expected in scenarios:
            self.run_class_implementation_tests(
                metric=HitRate(**metric_kwargs),
                state_names={"scores"},
                update_kwargs={"input": batches, "target": labels},
                compute_result=expected,
                num_total_updates=4,
                num_processes=2,
            )

    def test_hitrate_with_invalid_input(self) -> None:
        hr = HitRate()
        bad_calls = [
            (
                "target should be a one-dimensional tensor",
                lambda: hr.update(torch.rand(3, 2), torch.rand(3, 2)),
            ),
            (
                "input should be a two-dimensional tensor",
                lambda: hr.update(torch.rand(3, 2, 2), torch.rand(3)),
            ),
            (
                "`input` and `target` should have the same minibatch dimension",
                lambda: hr.update(torch.rand(4, 2), torch.rand(3)),
            ),
        ]
        for pattern, call in bad_calls:
            with self.assertRaisesRegex(ValueError, pattern):
                call()
67 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/ranking/hit_rate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 |
@torch.inference_mode()
def hit_rate(
    input: torch.Tensor,
    target: torch.Tensor,
    *,
    k: int | None = None,
) -> torch.Tensor:
    """
    Compute the hit rate of the correct class among the top predicted classes.
    Its class version is ``torcheval.metrics.HitRate``.

    A sample is a hit when fewer than ``k`` classes are scored strictly higher
    than its target class. Tied classes therefore do not push the target out.

    Args:
        input (Tensor): Predicted unnormalized scores (often referred to as logits) or
            class probabilities of shape (num_samples, num_classes).
        target (Tensor): Ground truth class indices of shape (num_samples,).
        k (int, optional): Number of top predicted classes to be considered.
            If k is None, all classes are considered and a hit rate of 1.0 is returned.

    Returns:
        Tensor of shape (num_samples,) with 1 for a hit and 0 otherwise. Its
        dtype matches ``input``'s dtype regardless of ``k``.

    Raises:
        ValueError: If the shapes of ``input``/``target`` are invalid or ``k``
            is not positive.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import hit_rate
        >>> input = torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3], [0.2, 0.1, 0.7], [0.3, 0.3, 0.4]])
        >>> target = torch.tensor([2, 1, 1, 0])
        >>> hit_rate(input, target, k=2)
        tensor([1.0000, 0.0000, 0.0000, 1.0000])
    """
    _hit_rate_input_check(input, target, k)
    if k is None or k >= input.size(dim=-1):
        # Every class is within the top k, so every sample is a hit.
        return input.new_ones(target.size())

    y_score = torch.gather(input, dim=-1, index=target.unsqueeze(dim=-1))
    rank = torch.gt(input, y_score).sum(dim=-1)
    # Cast to input's dtype so both branches return the same dtype
    # (new_ones above also follows input.dtype).
    return (rank < k).to(dtype=input.dtype)
47 |
48 |
49 | def _hit_rate_input_check(
50 | input: torch.Tensor, target: torch.Tensor, k: int | None = None
51 | ) -> None:
52 | if target.ndim != 1:
53 | raise ValueError(
54 | f"target should be a one-dimensional tensor, got shape {target.shape}."
55 | )
56 | if input.ndim != 2:
57 | raise ValueError(
58 | f"input should be a two-dimensional tensor, got shape {input.shape}."
59 | )
60 | if input.shape[0] != target.shape[0]:
61 | raise ValueError(
62 | "`input` and `target` should have the same minibatch dimension, ",
63 | f"got shapes {input.shape} and {target.shape}, respectively.",
64 | )
65 | if k is not None and k <= 0:
66 | raise ValueError(f"k should be None or positive, got {k}.")
67 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/frechet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 | import torch
9 |
10 |
def gaussian_frechet_distance(
    mu_x: torch.Tensor, cov_x: torch.Tensor, mu_y: torch.Tensor, cov_y: torch.Tensor
) -> torch.Tensor:
    r"""Computes the Fréchet distance between two multivariate normal distributions :cite:`dowson1982frechet`.

    For Gaussians this quantity equals the squared Wasserstein-2 distance.

    Concretely, for multivariate Gaussians :math:`X(\mu_X, \Sigma_X)`
    and :math:`Y(\mu_Y, \Sigma_Y)`, the function computes and returns :math:`F` as

    .. math::
        F(X, Y) = || \mu_X - \mu_Y ||_2^2
        + \text{Tr}\left( \Sigma_X + \Sigma_Y - 2 \sqrt{\Sigma_X \Sigma_Y} \right)

    Args:
        mu_x (torch.Tensor): mean :math:`\mu_X` of multivariate Gaussian :math:`X`, with shape `(N,)`.
        cov_x (torch.Tensor): covariance matrix :math:`\Sigma_X` of :math:`X`, with shape `(N, N)`.
        mu_y (torch.Tensor): mean :math:`\mu_Y` of multivariate Gaussian :math:`Y`, with shape `(N,)`.
        cov_y (torch.Tensor): covariance matrix :math:`\Sigma_Y` of :math:`Y`, with shape `(N, N)`.

    Returns:
        torch.Tensor: the Fréchet distance between :math:`X` and :math:`Y`, as a 0-d tensor.

    Raises:
        ValueError: if a mean is not one-dimensional, a covariance matrix is not
            two-dimensional, the two means or the two covariance matrices differ
            in shape, or the covariance matrices are not `(N, N)` for means of
            shape `(N,)`.
    """
    if mu_x.ndim != 1:
        msg = f"Input mu_x must be one-dimensional; got dimension {mu_x.ndim}."
        raise ValueError(msg)
    if mu_y.ndim != 1:
        msg = f"Input mu_y must be one-dimensional; got dimension {mu_y.ndim}."
        raise ValueError(msg)
    if cov_x.ndim != 2:
        msg = f"Input cov_x must be two-dimensional; got dimension {cov_x.ndim}."
        raise ValueError(msg)
    if cov_y.ndim != 2:
        msg = f"Input cov_y must be two-dimensional; got dimension {cov_y.ndim}."
        raise ValueError(msg)
    if mu_x.shape != mu_y.shape:
        msg = f"Inputs mu_x and mu_y must have the same shape; got {mu_x.shape} and {mu_y.shape}."
        raise ValueError(msg)
    if cov_x.shape != cov_y.shape:
        msg = f"Inputs cov_x and cov_y must have the same shape; got {cov_x.shape} and {cov_y.shape}."
        raise ValueError(msg)
    # The covariances must be square and agree with the mean's dimensionality;
    # otherwise the trace/eigenvalue terms below describe a different space than
    # the mean term and the result is meaningless.
    n = mu_x.shape[0]
    if cov_x.shape != (n, n):
        msg = (
            f"Covariance matrices must have shape ({n}, {n}) to match the means; "
            f"got {tuple(cov_x.shape)}."
        )
        raise ValueError(msg)

    a = (mu_x - mu_y).square().sum()
    b = cov_x.trace() + cov_y.trace()
    # Tr(sqrt(cov_x @ cov_y)) is the sum of the square roots of the eigenvalues of
    # cov_x @ cov_y. The product is generally not symmetric, so the general
    # (complex-valued) eigvals is used and only the real part is kept.
    c = torch.linalg.eigvals(cov_x @ cov_y).sqrt().real.sum()
    return a + b - 2 * c
57 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/text/helper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 |
13 | def _edit_distance(
14 | prediction_tokens: list[str],
15 | reference_tokens: list[str],
16 | ) -> int:
17 | """
18 | Dynamic programming algorithm to compute the edit distance between two word sequences.
19 |
20 | Args:
21 | prediction_tokens (List[str]): A tokenized predicted sentence
22 | reference_tokens (List[str]): A tokenized reference sentence
23 | """
24 | dp = [[0] * (len(reference_tokens) + 1) for _ in range(len(prediction_tokens) + 1)]
25 | for i in range(len(prediction_tokens) + 1):
26 | dp[i][0] = i
27 | for j in range(len(reference_tokens) + 1):
28 | dp[0][j] = j
29 | for i in range(1, len(prediction_tokens) + 1):
30 | for j in range(1, len(reference_tokens) + 1):
31 | if prediction_tokens[i - 1] == reference_tokens[j - 1]:
32 | dp[i][j] = dp[i - 1][j - 1]
33 | else:
34 | dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
35 | return dp[-1][-1]
36 |
37 |
def _get_errors_and_totals(
    input: str | list[str],
    target: str | list[str],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculate the edit distance, max length and lengths of predicted and reference word sequences.

    Args:
        input (str, List[str]): Predicted word sequence(s) to score as a string or list of strings.
        target (str, List[str]): Reference word sequence(s) as a string or list of strings.

    Returns:
        tuple: ``(errors, max_total, target_total, input_total)`` as 0-d
        ``torch.float64`` tensors. The values are:

        - the summed token edit distance,
        - the summed per-pair maximum token count,
        - the total reference token count,
        - the total predicted token count.

    Raises:
        ValueError: If ``input`` and ``target`` contain a different number of
            sequences (previously the extra sequences were silently dropped).
    """
    if isinstance(input, str):
        input = [input]
    if isinstance(target, str):
        target = [target]
    if len(input) != len(target):
        raise ValueError(
            "input and target lists should have the same length, "
            f"got {len(input)} and {len(target)}."
        )

    # Accumulate exact integer counts and convert once at the end.
    errors = max_total = target_total = input_total = 0
    for ipt, tgt in zip(input, target):
        input_tokens = ipt.split()
        target_tokens = tgt.split()
        errors += _edit_distance(input_tokens, target_tokens)
        target_total += len(target_tokens)
        input_total += len(input_tokens)
        max_total += max(len(target_tokens), len(input_tokens))

    return (
        torch.tensor(errors, dtype=torch.float64),
        torch.tensor(max_total, dtype=torch.float64),
        torch.tensor(target_total, dtype=torch.float64),
        torch.tensor(input_total, dtype=torch.float64),
    )
66 |
--------------------------------------------------------------------------------
/tests/metrics/functional/text/test_word_error_rate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import torch
12 | from torcheval.metrics.functional import word_error_rate
13 |
14 |
class TestWordErrorRate(unittest.TestCase):
    """Tests for the functional ``word_error_rate`` metric."""

    def test_word_error_rate_with_valid_input(self) -> None:
        # Each case: (prediction(s), reference(s), expected word error rate).
        cases = [
            ("hello meta", "hello metaverse", 0.5),
            ("hello meta", "hello meta", 0.0),
            ("this is the prediction", "this is the reference", 0.25),
            (
                ["hello world", "welcome to the facebook"],
                ["hello metaverse", "welcome to meta"],
                0.6,
            ),
            (
                [
                    "hello metaverse",
                    "come to the facebook",
                    "this is reference",
                    "there is the other one",
                ],
                [
                    "hello world",
                    "welcome to meta",
                    "this is reference",
                    "there is another one",
                ],
                0.5,
            ),
        ]
        for predictions, references, expected in cases:
            torch.testing.assert_close(
                word_error_rate(predictions, references),
                torch.tensor(expected, dtype=torch.float64),
            )

    def test_word_error_rate_with_invalid_input(self) -> None:
        # A list prediction paired with a single-string reference.
        type_mismatch = (["hello metaverse", "welcome to meta"], "hello world")
        with self.assertRaisesRegex(
            ValueError, "input and target should have the same type"
        ):
            word_error_rate(*type_mismatch)

        # Two predictions paired with three references.
        length_mismatch = (
            ["hello metaverse", "welcome to meta"],
            [
                "welcome to meta",
                "this is the prediction",
                "there is an other sample",
            ],
        )
        with self.assertRaisesRegex(
            ValueError, "input and target lists should have the same length"
        ):
            word_error_rate(*length_mismatch)
71 |
--------------------------------------------------------------------------------
/tests/metrics/functional/ranking/test_reciprocal_rank.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import torch
12 | from torcheval.metrics.functional import reciprocal_rank
13 |
14 |
class TestReciprocalRank(unittest.TestCase):
    """Tests for the functional ``reciprocal_rank`` metric."""

    def test_reciprocal_rank_with_valid_input(self) -> None:
        scores = torch.tensor(
            [
                [0.4826, 0.9517, 0.8967, 0.8995, 0.1584, 0.9445, 0.9700],
                [0.4938, 0.7517, 0.8039, 0.7167, 0.9488, 0.9607, 0.7091],
                [0.5127, 0.4732, 0.5461, 0.5617, 0.9198, 0.0847, 0.2337],
                [0.4175, 0.9452, 0.9852, 0.2131, 0.5016, 0.7305, 0.0516],
            ]
        )
        labels = torch.tensor([3, 5, 2, 1])

        # Once k reaches the worst target rank (4), truncation no longer matters.
        untruncated = torch.tensor([0.2500, 1.0000, 1.0000 / 3, 0.5000])
        expected_by_k = {
            None: untruncated,
            1: torch.tensor([0.0000, 1.0000, 0.0000, 0.0000]),
            3: torch.tensor([0.0000, 1.0000, 1.0000 / 3, 0.5000]),
            5: untruncated,
            20: untruncated,
            100: untruncated,
        }
        for k, expected in expected_by_k.items():
            torch.testing.assert_close(reciprocal_rank(scores, labels, k=k), expected)

    def test_reciprocal_rank_with_invalid_input(self) -> None:
        # Each entry: (expected error regex, input, target).
        malformed_cases = [
            (
                "target should be a one-dimensional tensor",
                torch.rand(3, 2),
                torch.rand(3, 2),
            ),
            (
                "input should be a two-dimensional tensor",
                torch.rand(3, 2, 2),
                torch.rand(3),
            ),
            (
                "`input` and `target` should have the same minibatch dimension",
                torch.rand(4, 2),
                torch.rand(3),
            ),
        ]
        for expected_message, bad_input, bad_target in malformed_cases:
            with self.assertRaisesRegex(ValueError, expected_message):
                reciprocal_rank(bad_input, bad_target)
66 |
--------------------------------------------------------------------------------
/.github/workflows/release_build_docs.yaml:
--------------------------------------------------------------------------------
name: Build Docs for New Release

on:
  workflow_dispatch:
    inputs:
      RELEASE_TAG:
        description: 'Tag name for this release'
        required: true
        type: string
      DOCS_DIRECTORY:
        description: 'Directory which will store the compiled docs'
        required: true
        type: string

# Allow one concurrent deployment
concurrency:
  group: "pages"
  cancel-in-progress: true

env:
  RELEASE_TAG: ${{ inputs.RELEASE_TAG }}
  DOCS_DIRECTORY: ${{ inputs.DOCS_DIRECTORY }}

jobs:
  build_docs:
    runs-on: ubuntu-latest
    steps:
      - name: Check out repo
        uses: actions/checkout@v4
        with:
          ref: ${{ env.RELEASE_TAG }}
      - name: Setup conda env
        uses: conda-incubator/setup-miniconda@v3
        with:
          miniconda-version: "latest"
          activate-environment: test
      - name: Install dependencies
        shell: bash -l {0}
        run: |
          set -eux
          conda activate test
          conda install pytorch cpuonly -c pytorch-nightly
          pip install -r requirements.txt
          pip install -r dev-requirements.txt
          python setup.py sdist bdist_wheel
          pip install dist/*.whl
      - name: Build docs
        shell: bash -l {0}
        run: |
          set -eux
          conda activate test
          cd docs
          pip install -r requirements.txt
          make html
          cd ..
      - name: Deploy docs to Github pages
        uses: JamesIves/github-pages-deploy-action@v4.4.1
        with:
          branch: gh-pages # The branch the action should deploy to.
          folder: docs/build/html # The folder the action should deploy.
          target-folder: ${{ env.DOCS_DIRECTORY }}
  update_stable_link:
    needs: build_docs
    runs-on: ubuntu-latest
    steps:
      - name: Check out repo
        uses: actions/checkout@v4
        with:
          ref: gh-pages
      - name: Point the stable symbolic link at the latest release
        run: |
          set -eux
          # -f replaces an existing "stable" link; -n stops ln from following that
          # link into the previous release's directory and creating the new link
          # inside it. The directory comes from the env var rather than direct
          # ${{ }} interpolation into the script.
          ln -sfn "$DOCS_DIRECTORY" stable
          git add stable
      - name: Commit symbolic link
        run: |
          set -eux
          # Runners have no git identity by default; git commit fails without one.
          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
          # Re-running for the same release leaves nothing staged; don't fail on that.
          if git diff --cached --quiet; then
            echo "stable already points at $DOCS_DIRECTORY; nothing to commit"
          else
            git commit -m "Update symbolic link to latest release"
          fi
      - name: Push changes
        uses: ad-m/github-push-action@0fafdd62b84042d49ec0cb92d9cac7f7ce4ec79e
        with:
          branch: gh-pages
83 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.yml:
--------------------------------------------------------------------------------
1 | name: 🐛 Bug Report
2 | description: Create a report to help us reproduce and fix the bug
3 |
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: >
8 | #### Before submitting a bug, please make sure the issue hasn't already been addressed by searching through [the
9 | existing and past issues](https://github.com/pytorch/torcheval/issues?q=is%3Aissue+sort%3Acreated-desc+).
10 | - type: textarea
11 | attributes:
12 | label: 🐛 Describe the bug
13 | description: |
14 | Please provide a clear and concise description of what the bug is.
15 |
16 | If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example:
17 |
18 | ```python
19 | # All necessary imports at the beginning
20 | import torch
21 | import torcheval
22 |
23 | # A succinct reproducing example trimmed down to the essential parts
24 |
25 | ```
26 |
27 | If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
28 |
29 | Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
30 | placeholder: |
31 | A clear and concise description of what the bug is.
32 |
33 | ```python
34 | Sample code to reproduce the problem
35 | ```
36 |
37 | ```
38 | The error message you got, with the full traceback.
39 | ```
40 | validations:
41 | required: true
42 | - type: textarea
43 | attributes:
44 | label: Versions
45 | description: |
46 | Please run the following and paste the output below. Make sure the version numbers of all relevant packages (e.g. torch, torcheval, other domain packages) are included.
47 | ```sh
48 | wget https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
49 | # For security purposes, please check the contents of collect_env.py before running it.
50 | python collect_env.py
51 | ```
52 | validations:
53 | required: true
54 |
55 | - type: markdown
56 | attributes:
57 | value: >
58 | Thanks for contributing 🎉!
59 |
--------------------------------------------------------------------------------
/tests/metrics/image/test_ssim.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 | from torch import Tensor
12 |
13 | from torcheval.metrics.image.ssim import StructuralSimilarity
14 | from torcheval.utils.test_utils.metric_class_tester import (
15 | BATCH_SIZE,
16 | IMG_CHANNELS,
17 | IMG_HEIGHT,
18 | IMG_WIDTH,
19 | MetricClassTester,
20 | NUM_TOTAL_UPDATES,
21 | )
22 |
23 | # pyre-ignore-all-errors[6]
24 |
25 |
class TestStructuralSimilarity(MetricClassTester):
    """Class-level tests for the ``StructuralSimilarity`` (SSIM) metric."""

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed: the expected SSIM below depends on the generated images.
        torch.manual_seed(0)

    def _get_input_data(
        self,
        num_updates: int,
        batch_size: int,
        num_channels: int,
        height: int,
        width: int,
    ) -> dict[str, Tensor]:
        """Draw two independent uniform image stacks, ``images_1`` first."""
        stack_shape = (num_updates, batch_size, num_channels, height, width)
        return {name: torch.rand(size=stack_shape) for name in ("images_1", "images_2")}

    def test_ssim(
        self,
    ) -> None:
        images = self._get_input_data(
            NUM_TOTAL_UPDATES, BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH
        )
        expected_ssim = torch.tensor(0.022607240825891495)

        self.run_class_implementation_tests(
            metric=StructuralSimilarity(),
            state_names={"mssim_sum", "num_images"},
            update_kwargs={name: images[name] for name in ("images_1", "images_2")},
            compute_result=expected_ssim,
            min_updates_before_compute=2,
            test_merge_with_one_update=False,
            atol=1e-4,
            rtol=1e-4,
            test_devices=["cpu"],
        )

    def test_ssim_invalid_input(self) -> None:
        metric = StructuralSimilarity()
        # Same spatial size, but a different channel count (4) for the second set.
        first_set = torch.rand(BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH)
        second_set = torch.rand(BATCH_SIZE, 4, IMG_HEIGHT, IMG_WIDTH)

        with self.assertRaisesRegex(
            RuntimeError, "The two sets of images must have the same shape."
        ):
            metric.update(images_1=first_set, images_2=second_set)
88 |
--------------------------------------------------------------------------------
/tests/metrics/functional/aggregation/test_sum.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import torch
12 | from torcheval.metrics.functional import sum
13 | from torcheval.utils.test_utils.metric_class_tester import BATCH_SIZE, NUM_TOTAL_UPDATES
14 |
15 |
class TestSum(unittest.TestCase):
    """Tests for the functional (optionally weighted) ``sum`` metric."""

    def _test_sum_with_input(
        self,
        val: torch.Tensor,
        weight: float | torch.Tensor = 1.0,
    ) -> None:
        """Assert functional ``sum(val, weight)`` matches ``torch.sum(val * weight)``.

        Args:
            val: Tensor to reduce.
            weight: Scalar or tensor weight forwarded to ``sum``. The default 1.0
                reduces to an unweighted sum.
        """
        # Forward ``weight`` so callers that pass one actually exercise the
        # weighted path instead of having it silently ignored.
        torch.testing.assert_close(
            sum(val, weight),
            torch.sum(val * weight),
            equal_nan=True,
            atol=1e-8,
            rtol=1e-5,
        )

    def test_sum_base(self) -> None:
        input_val_tensor = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE)
        self._test_sum_with_input(input_val_tensor)
        self._test_sum_with_input(input_val_tensor, 0.5)
        input_val_tensor = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, 4)
        self._test_sum_with_input(input_val_tensor)
        input_val_tensor = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, 3, 4)
        self._test_sum_with_input(input_val_tensor)

    def test_sum_input_valid_weight(self) -> None:
        def _compute_result(
            val: torch.Tensor, weight: float | int | torch.Tensor
        ) -> torch.Tensor:
            # Reference via NumPy: a dot product for tensor weights, and an
            # elementwise scale followed by a sum for scalar weights.
            weighted_sum = torch.tensor(0.0)
            if isinstance(weight, torch.Tensor):
                weight = weight.numpy().flatten()
            weighted_sum += val.numpy().flatten().dot(weight).sum()

            return weighted_sum

        inputs = [
            torch.rand(1),
            torch.rand(BATCH_SIZE, 4),
            torch.rand(BATCH_SIZE, 3, 4),
            torch.rand(5),
            torch.rand(10),
        ]
        weights = [
            torch.rand(1),
            torch.rand(BATCH_SIZE, 4),
            torch.rand(BATCH_SIZE, 3, 4),
            0.8,
            2,
        ]

        for val, weight in zip(inputs, weights):
            torch.testing.assert_close(
                sum(val, weight),
                _compute_result(val, weight),
                equal_nan=True,
                atol=1e-8,
                rtol=1e-5,
            )

    def test_sum_input_invalid_weight(self) -> None:
        with self.assertRaisesRegex(
            ValueError,
            r"Weight must be either a float value or an int value or a tensor that matches the input tensor size.",
        ):
            sum(torch.tensor([2.0, 3.0]), torch.tensor([0.5]))
79 |
--------------------------------------------------------------------------------
/tests/metrics/functional/text/test_word_information_preserved.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import torch
12 | from torcheval.metrics.functional import word_information_preserved
13 |
14 |
class TestWordInformationPreserved(unittest.TestCase):
    """Tests for the functional ``word_information_preserved`` metric."""

    def test_word_information_preserved_with_valid_input(self) -> None:
        # Each case: (prediction(s), reference(s), expected WIP score).
        cases = [
            ("hello meta", "hi metaverse", 0.0),
            ("hello meta", "hello meta", 1.0),
            ("this is the prediction", "this is the reference", 0.5625),
            (
                ["hello world", "welcome to the facebook"],
                ["hello metaverse", "welcome to meta"],
                0.3,
            ),
            (
                [
                    "hello metaverse",
                    "come to the facebook",
                    "this is reference",
                    "there is the other one",
                ],
                [
                    "hello world",
                    "welcome to meta",
                    "this is reference",
                    "there is another one",
                ],
                0.38095238,
            ),
        ]
        for predictions, references, expected in cases:
            torch.testing.assert_close(
                word_information_preserved(predictions, references),
                torch.tensor(expected, dtype=torch.float64),
            )

    def test_word_information_preserved_with_invalid_input(self) -> None:
        # A list prediction paired with a single-string reference.
        type_mismatch = (["hello metaverse", "welcome to meta"], "hello world")
        with self.assertRaisesRegex(
            ValueError, "input and target should have the same type"
        ):
            word_information_preserved(*type_mismatch)

        # Two predictions paired with three references.
        length_mismatch = (
            ["hello metaverse", "welcome to meta"],
            [
                "welcome to meta",
                "this is the prediction",
                "there is an other sample",
            ],
        )
        with self.assertRaisesRegex(
            ValueError, "input and target lists should have the same length"
        ):
            word_information_preserved(*length_mismatch)
75 |
--------------------------------------------------------------------------------
/tests/metrics/functional/ranking/test_weighted_calibration.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import torch
12 | from torcheval.metrics.functional import weighted_calibration
13 |
14 |
class TestWeightedCalibration(unittest.TestCase):
    """Tests for the functional ``weighted_calibration`` metric."""

    def test_weighted_calibration_with_valid_input(self) -> None:
        predictions = torch.tensor([0.8, 0.4, 0.3, 0.8, 0.7, 0.6])
        labels = torch.tensor([1, 1, 0, 0, 1, 0])
        # Each case: (positional args, keyword args, expected calibration).
        cases = [
            ((predictions, labels), {}, torch.tensor(1.2000)),
            (
                (predictions, labels, torch.tensor([0.5, 1.0, 2.0, 0.4, 1.3, 0.9])),
                {},
                torch.tensor(1.1321428185),
            ),
            (
                (
                    torch.tensor([[0.8, 0.4], [0.8, 0.7]]),
                    torch.tensor([[1, 1], [0, 1]]),
                ),
                {"num_tasks": 2},
                torch.tensor([0.6000, 1.5000]),
            ),
        ]
        for args, kwargs, expected in cases:
            torch.testing.assert_close(weighted_calibration(*args, **kwargs), expected)

    def test_weighted_calibration_with_invalid_input(self) -> None:
        # Each case: (expected error regex, positional args, keyword args).
        cases = [
            (
                r"Weight must be either a float value or a tensor that matches the input tensor size.",
                (
                    torch.tensor([0.8, 0.4, 0.8, 0.7]),
                    torch.tensor([1, 1, 0, 1]),
                    torch.tensor([1, 1.5]),
                ),
                {},
            ),
            (
                r"is different from `target` shape",
                (
                    torch.tensor([0.8, 0.4, 0.8, 0.7]),
                    torch.tensor([[1, 1, 0], [0, 1, 1]]),
                ),
                {},
            ),
            (
                r"`num_tasks = 1`, `input` is expected to be one-dimensional tensor,",
                (
                    torch.tensor([[0.8, 0.4], [0.8, 0.7]]),
                    torch.tensor([[1, 1], [0, 1]]),
                ),
                {},
            ),
            (
                r"`num_tasks = 2`, `input`'s shape is expected to be",
                (
                    torch.tensor([0.8, 0.4, 0.8, 0.7]),
                    torch.tensor([1, 0, 0, 1]),
                ),
                {"num_tasks": 2},
            ),
        ]
        for expected_message, args, kwargs in cases:
            with self.assertRaisesRegex(ValueError, expected_message):
                weighted_calibration(*args, **kwargs)
80 |
--------------------------------------------------------------------------------
/docs/source/torcheval.metrics.rst:
--------------------------------------------------------------------------------
1 | Metrics
2 | =============
3 |
4 | .. automodule:: torcheval.metrics
5 |
6 |
7 | Aggregation Metrics
8 | -------------------------------------------------------------------
9 |
10 | .. autosummary::
11 | :toctree: generated
12 | :nosignatures:
13 |
14 | AUC
15 | Cat
16 | Max
17 | Mean
18 | Min
19 | Sum
20 | Throughput
21 |
22 | Audio Metrics
23 | -------------------------------------------------------------------
24 |
25 | .. autosummary::
26 | :toctree: generated
27 | :nosignatures:
28 |
29 | FrechetAudioDistance
30 |
31 | Classification Metrics
32 | -------------------------------------------------------------------
33 |
34 | .. autosummary::
35 | :toctree: generated
36 | :nosignatures:
37 |
38 | BinaryAccuracy
39 | BinaryAUPRC
40 | BinaryAUROC
41 | BinaryBinnedAUROC
42 | BinaryBinnedPrecisionRecallCurve
43 | BinaryConfusionMatrix
44 | BinaryF1Score
45 | BinaryNormalizedEntropy
46 | BinaryPrecision
47 | BinaryPrecisionRecallCurve
48 | BinaryRecall
49 | BinaryRecallAtFixedPrecision
50 | MulticlassAccuracy
51 | MulticlassAUPRC
52 | MulticlassAUROC
53 | MulticlassBinnedAUROC
54 | MulticlassBinnedPrecisionRecallCurve
55 | MulticlassConfusionMatrix
56 | MulticlassF1Score
57 | MulticlassPrecision
58 | MulticlassPrecisionRecallCurve
59 | MulticlassRecall
60 | MultilabelAccuracy
61 | MultilabelAUPRC
62 | MultilabelPrecisionRecallCurve
63 | MultilabelRecallAtFixedPrecision
64 | TopKMultilabelAccuracy
65 |
66 | Image Metrics
67 | -------------------------------------------------------------------
68 |
69 | .. autosummary::
70 | :toctree: generated
71 | :nosignatures:
72 |
73 | FrechetInceptionDistance
74 | PeakSignalNoiseRatio
75 | StructuralSimilarity
76 |
77 | Ranking Metrics
78 | -------------------------------------------------------------------
79 |
80 | .. autosummary::
81 | :toctree: generated
82 | :nosignatures:
83 |
84 | ClickThroughRate
85 | HitRate
86 | ReciprocalRank
87 | WeightedCalibration
88 |
89 | Regression Metrics
90 | -------------------------------------------------------------------
91 |
92 | .. autosummary::
93 | :toctree: generated
94 | :nosignatures:
95 |
96 | MeanSquaredError
97 | R2Score
98 |
99 | Text Metrics
100 | -------------------------------------------------------------------
101 |
102 | .. autosummary::
103 | :toctree: generated
104 | :nosignatures:
105 |
106 | BLEUScore
107 | Perplexity
108 | WordErrorRate
109 | WordInformationLost
110 | WordInformationPreserved
111 |
112 | Windowed Metrics
113 | -------------------------------------------------------------------
114 |
115 | .. autosummary::
116 | :toctree: generated
117 | :nosignatures:
118 |
119 | WindowedBinaryAUROC
120 | WindowedBinaryNormalizedEntropy
121 | WindowedClickThroughRate
122 | WindowedMeanSquaredError
123 | WindowedWeightedCalibration
124 |
--------------------------------------------------------------------------------
/tests/metrics/ranking/test_reciprocal_rank.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import torch
10 | from torcheval.metrics.ranking import ReciprocalRank
11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester
12 |
13 |
class TestReciprocalRank(MetricClassTester):
    """Class-level tests for the ``ReciprocalRank`` metric."""

    def test_mrr_with_valid_input(self) -> None:
        # Shape (num_updates=4, batch=2, num_classes=7).
        scores = torch.tensor(
            [
                [
                    [0.9005, 0.0998, 0.2470, 0.6188, 0.9497, 0.6083, 0.7258],
                    [0.9505, 0.3270, 0.4734, 0.5854, 0.5202, 0.6546, 0.7869],
                ],
                [
                    [0.5546, 0.6027, 0.2650, 0.6624, 0.8755, 0.7838, 0.7529],
                    [0.4121, 0.6082, 0.7813, 0.5947, 0.9582, 0.8736, 0.7389],
                ],
                [
                    [0.1306, 0.7939, 0.5192, 0.0494, 0.7987, 0.3898, 0.0108],
                    [0.2399, 0.2969, 0.6738, 0.8633, 0.7939, 0.1052, 0.7702],
                ],
                [
                    [0.9097, 0.7436, 0.0051, 0.6264, 0.6616, 0.7328, 0.7413],
                    [0.5286, 0.2956, 0.0578, 0.1913, 0.8118, 0.1047, 0.7966],
                ],
            ]
        )
        labels = torch.tensor([[1, 3], [3, 0], [2, 6], [4, 5]])

        # Each scenario: (ReciprocalRank constructor kwargs, expected per-sample scores).
        scenarios = [
            (
                {},
                torch.tensor(
                    [1.0 / 7, 0.25, 0.25, 1.0 / 7, 1.0 / 3, 1.0 / 3, 0.20, 1.0 / 6]
                ),
            ),
            (
                {"k": 5},
                torch.tensor([0.0, 0.25, 0.25, 0.0, 1.0 / 3, 1.0 / 3, 0.2, 0.0]),
            ),
        ]
        for metric_kwargs, expected in scenarios:
            self.run_class_implementation_tests(
                metric=ReciprocalRank(**metric_kwargs),
                state_names={"scores"},
                update_kwargs={"input": scores, "target": labels},
                compute_result=expected,
                num_total_updates=4,
                num_processes=2,
            )

    def test_mrr_with_invalid_input(self) -> None:
        metric = ReciprocalRank()
        # Each entry: (expected error regex, input, target).
        malformed_cases = [
            (
                "target should be a one-dimensional tensor",
                torch.rand(3, 2),
                torch.rand(3, 2),
            ),
            (
                "input should be a two-dimensional tensor",
                torch.rand(3, 2, 2),
                torch.rand(3),
            ),
            (
                "`input` and `target` should have the same minibatch dimension",
                torch.rand(4, 2),
                torch.rand(3),
            ),
        ]
        for expected_message, bad_input, bad_target in malformed_cases:
            with self.assertRaisesRegex(ValueError, expected_message):
                metric.update(bad_input, bad_target)
75 |
--------------------------------------------------------------------------------
/torcheval/metrics/aggregation/cov.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from collections.abc import Iterable
10 | from typing import TypeAlias, TypeVar, Union
11 |
12 | import torch
13 | from torcheval.metrics.metric import Metric
14 | from typing_extensions import Self
15 |
16 | # TODO: use a NamedTuple?
17 | _T = TypeVar("_T", bound=Union[torch.Tensor, int])
18 | _Output: TypeAlias = tuple[torch.Tensor, torch.Tensor] # mean, cov
19 |
20 |
21 | class Covariance(Metric[_Output]):
22 | """Fit sample mean + covariance to empirical distribution"""
23 |
24 | def __init__(self, *, device: torch.device | None = None) -> None:
25 | super().__init__(device=device)
26 | self.sum: torch.Tensor = self._add_state_and_return(
27 | "sum", default=torch.as_tensor(0.0)
28 | )
29 | self.ss_sum: torch.Tensor = self._add_state_and_return(
30 | "ss_sum", default=torch.as_tensor(0.0)
31 | )
32 | self.n: int = self._add_state_and_return("n", default=0)
33 |
34 | def _add_state_and_return(self, name: str, default: _T) -> _T:
35 | # Helper function for pyre
36 | self._add_state(name, default)
37 | return getattr(self, name)
38 |
39 | def _update(self, sum: torch.Tensor, ss_sum: torch.Tensor, n: int) -> None:
40 | if n == 0:
41 | return
42 | elif self.n == 0:
43 | self.n = n
44 | self.ss_sum = ss_sum
45 | self.sum = sum
46 | else:
47 | # Welford's algorithm for numerical stability
48 | delta = (self.sum / self.n) - (sum / n)
49 | outer = torch.outer(delta, delta)
50 |
51 | scale = n * self.n / (self.n + n)
52 | self.ss_sum += ss_sum + outer * scale
53 | self.sum += sum
54 | self.n += n
55 |
56 | # pyre-fixme[14]
57 | def update(self, obs: torch.Tensor) -> Self:
58 | assert obs.ndim == 2
59 | with torch.inference_mode():
60 | demeaned = obs - obs.mean(dim=0, keepdim=True)
61 | ss_sum = torch.einsum("ni,nj->ij", demeaned, demeaned)
62 | self._update(obs.sum(dim=0), ss_sum, len(obs))
63 | return self
64 |
65 | # pyre-fixme[14]
66 | def merge_state(self, metrics: Iterable[Self]) -> Self:
67 | with torch.inference_mode():
68 | for other in metrics:
69 | self._update(other.sum, other.ss_sum, other.n)
70 | return self
71 |
72 | def compute(self) -> _Output:
73 | if self.n < 2:
74 | msg = f"Not enough samples to estimate covariance (found {self.n})"
75 | raise ValueError(msg)
76 | with torch.inference_mode():
77 | mean = self.sum / self.n
78 | # TODO: make degress of freedom configurable?
79 | cov = self.ss_sum / (self.n - 1)
80 | return mean, cov
81 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/image/psnr.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 |
13 | @torch.inference_mode()
14 | def peak_signal_noise_ratio(
15 | input: torch.Tensor,
16 | target: torch.Tensor,
17 | data_range: float | None = None,
18 | ) -> torch.Tensor:
19 | """
20 | Compute the peak signal-to-noise ratio between two images.
21 | It's class version is `torcheval.metrics.PeakSignalNoiseRatio`
22 |
23 | Args:
24 | input (Tensor): Input image ``(N, C, H, W)``.
25 | target (Tensor): Target image ``(N, C, H, W)``.
26 | data_range (float): the range of the input images. Default: None.
27 | If None, the input range computed from the target data ``(target.max() - targert.min())``.
28 | Examples::
29 |
30 | >>> import torch
31 | >>> from torcheval.metrics.functional import peak_signal_noise_ratio
32 | >>> input = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
33 | >>> target = input * 0.9
34 | >>> peak_signal_noise_ratio(input, target)
35 | tensor(19.8767)
36 | """
37 | _psnr_param_check(data_range)
38 |
39 | if data_range is None:
40 | data_range_tensor = torch.max(target) - torch.min(target)
41 | else:
42 | data_range_tensor = torch.tensor(data=data_range, device=target.device)
43 |
44 | sum_square_error, num_observations = _psnr_update(input, target)
45 | psnr = _psnr_compute(sum_square_error, num_observations, data_range_tensor)
46 | return psnr
47 |
48 |
49 | def _psnr_param_check(data_range: float | None) -> None:
50 | # Check matching shapes
51 | if data_range is not None:
52 | if type(data_range) is not float:
53 | raise ValueError("`data_range needs to be either `None` or `float`.")
54 | if data_range <= 0:
55 | raise ValueError("`data_range` needs to be positive.")
56 |
57 |
58 | def _psnr_input_check(input: torch.Tensor, target: torch.Tensor) -> None:
59 | # Check matching shapes
60 | if input.shape != target.shape:
61 | raise ValueError(
62 | "The `input` and `target` must have the same shape, "
63 | f"got shapes {input.shape} and {target.shape}."
64 | )
65 |
66 |
67 | def _psnr_update(
68 | input: torch.Tensor, target: torch.Tensor
69 | ) -> tuple[torch.Tensor, torch.Tensor]:
70 | _psnr_input_check(input, target)
71 | sum_squared_error = torch.sum(torch.pow(input - target, 2))
72 | num_observations = torch.tensor(target.numel(), device=target.device)
73 | return sum_squared_error, num_observations
74 |
75 |
76 | def _psnr_compute(
77 | sum_square_error: torch.Tensor,
78 | num_observations: torch.Tensor,
79 | data_range: torch.Tensor,
80 | ) -> torch.Tensor:
81 | mse = sum_square_error / num_observations
82 | psnr = 10 * torch.log10(torch.pow(data_range, 2) / mse)
83 |
84 | return psnr
85 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import os
9 | import sys
10 |
11 | from datetime import date
12 |
13 | from setuptools import find_packages, setup
14 | from torcheval import __version__
15 |
16 |
17 | def current_path(file_name: str) -> str:
18 | return os.path.abspath(os.path.join(__file__, os.path.pardir, file_name))
19 |
20 |
21 | def read_requirements(file_name: str) -> list[str]:
22 | with open(current_path(file_name), encoding="utf8") as f:
23 | return f.read().strip().split()
24 |
25 |
26 | def get_nightly_version() -> str:
27 | return date.today().strftime("%Y.%m.%d")
28 |
29 |
30 | def parse_args() -> argparse.Namespace:
31 | parser = argparse.ArgumentParser(description="torcheval setup")
32 | parser.add_argument(
33 | "--nightly",
34 | dest="nightly",
35 | action="store_true",
36 | help="enable settings for nightly package build",
37 | )
38 | parser.set_defaults(nightly=False)
39 | return parser.parse_known_args()
40 |
41 |
42 | if __name__ == "__main__":
43 | with open(current_path("README.md"), encoding="utf8") as f:
44 | readme = f.read()
45 |
46 | custom_args, setup_args = parse_args()
47 | package_name = "torcheval" if not custom_args.nightly else "torcheval-nightly"
48 | version = __version__ if not custom_args.nightly else get_nightly_version()
49 | print(f"using package_name={package_name}, version={version}")
50 |
51 | sys.argv = [sys.argv[0]] + setup_args
52 |
53 | setup(
54 | name=package_name,
55 | version=version,
56 | author="torcheval team",
57 | author_email="yicongd@fb.com",
58 | description="A library for providing a simple interface to create new metrics and an easy-to-use toolkit for metric computations and checkpointing.",
59 | long_description=readme,
60 | long_description_content_type="text/markdown",
61 | url="https://github.com/pytorch/torcheval",
62 | license="BSD-3",
63 | keywords=["pytorch", "evaluation", "metrics"],
64 | python_requires=">=3.7",
65 | install_requires=read_requirements("requirements.txt"),
66 | packages=find_packages(),
67 | package_data={"torcheval": ["py.typed"]},
68 | zip_safe=True,
69 | classifiers=[
70 | "Development Status :: 2 - Pre-Alpha",
71 | "Intended Audience :: Developers",
72 | "Intended Audience :: Science/Research",
73 | "License :: OSI Approved :: BSD License",
74 | "Programming Language :: Python :: 3",
75 | "Programming Language :: Python :: 3.7",
76 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
77 | ],
78 | extras_require={
79 | "dev": read_requirements("dev-requirements.txt"),
80 | "image": read_requirements("image-requirements.txt"),
81 | },
82 | )
83 |
--------------------------------------------------------------------------------
/tests/metrics/functional/aggregation/test_mean.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import unittest
10 |
11 | import numpy as np
12 | import torch
13 | from torcheval.metrics.functional.aggregation import mean
14 | from torcheval.utils.test_utils.metric_class_tester import BATCH_SIZE, NUM_TOTAL_UPDATES
15 |
16 |
17 | class TestMean(unittest.TestCase):
18 | def _test_mean_with_input(
19 | self,
20 | val: torch.Tensor,
21 | weight: float | torch.Tensor = 1.0,
22 | ) -> None:
23 | torch.testing.assert_close(
24 | mean(val),
25 | torch.mean(val),
26 | equal_nan=True,
27 | atol=1e-8,
28 | rtol=1e-5,
29 | )
30 |
31 | def test_mean_base(self) -> None:
32 | input_val_tensor = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE)
33 | self._test_mean_with_input(input_val_tensor)
34 | input_val_tensor = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, 4)
35 | self._test_mean_with_input(input_val_tensor)
36 | input_val_tensor = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, 3, 4)
37 | self._test_mean_with_input(input_val_tensor)
38 |
39 | def test_mean_input_valid_weight(self) -> None:
40 | def _compute_result(
41 | val: torch.Tensor, weights: float | torch.Tensor
42 | ) -> torch.Tensor:
43 | # pyre-fixme[9]: val has type `Tensor`; used as `ndarray[Any, Any]`.
44 | val = val.numpy().flatten()
45 | if isinstance(weights, torch.Tensor):
46 | weights = weights.numpy().flatten()
47 | else:
48 | weights = weights * np.ones_like(val)
49 | weighted_mean = np.average(val, weights=weights)
50 | return torch.tensor(weighted_mean, dtype=torch.float32)
51 |
52 | inputs = [
53 | torch.rand(1),
54 | torch.rand(BATCH_SIZE, 4),
55 | torch.rand(BATCH_SIZE, 3, 4),
56 | torch.rand(5),
57 | torch.rand(10),
58 | ]
59 | weights = [
60 | torch.rand(1),
61 | torch.rand(BATCH_SIZE, 4),
62 | torch.rand(BATCH_SIZE, 3, 4),
63 | 0.8,
64 | 1,
65 | ]
66 |
67 | for input, weight in zip(inputs, weights):
68 | print(input)
69 | print(weight)
70 | torch.testing.assert_close(
71 | mean(input, weight),
72 | _compute_result(input, weight),
73 | equal_nan=True,
74 | atol=1e-8,
75 | rtol=1e-5,
76 | )
77 |
78 | def test_mean_input_invalid_weight(self) -> None:
79 | with self.assertRaisesRegex(
80 | ValueError,
81 | r"Weight must be either a float value or a tensor that matches the input tensor size.",
82 | ):
83 | mean(torch.tensor([2.0, 3.0]), torch.tensor([0.5]))
84 |
--------------------------------------------------------------------------------
/torcheval/metrics/classification/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.classification.accuracy import (
10 | BinaryAccuracy,
11 | MulticlassAccuracy,
12 | MultilabelAccuracy,
13 | TopKMultilabelAccuracy,
14 | )
15 | from torcheval.metrics.classification.auprc import (
16 | BinaryAUPRC,
17 | MulticlassAUPRC,
18 | MultilabelAUPRC,
19 | )
20 |
21 | from torcheval.metrics.classification.auroc import BinaryAUROC, MulticlassAUROC
22 | from torcheval.metrics.classification.binary_normalized_entropy import (
23 | BinaryNormalizedEntropy,
24 | )
25 | from torcheval.metrics.classification.binned_auprc import (
26 | BinaryBinnedAUPRC,
27 | MulticlassBinnedAUPRC,
28 | MultilabelBinnedAUPRC,
29 | )
30 | from torcheval.metrics.classification.binned_auroc import (
31 | BinaryBinnedAUROC,
32 | MulticlassBinnedAUROC,
33 | )
34 | from torcheval.metrics.classification.binned_precision_recall_curve import (
35 | BinaryBinnedPrecisionRecallCurve,
36 | MulticlassBinnedPrecisionRecallCurve,
37 | MultilabelBinnedPrecisionRecallCurve,
38 | )
39 | from torcheval.metrics.classification.confusion_matrix import (
40 | BinaryConfusionMatrix,
41 | MulticlassConfusionMatrix,
42 | )
43 | from torcheval.metrics.classification.f1_score import BinaryF1Score, MulticlassF1Score
44 | from torcheval.metrics.classification.precision import (
45 | BinaryPrecision,
46 | MulticlassPrecision,
47 | )
48 | from torcheval.metrics.classification.precision_recall_curve import (
49 | BinaryPrecisionRecallCurve,
50 | MulticlassPrecisionRecallCurve,
51 | MultilabelPrecisionRecallCurve,
52 | )
53 | from torcheval.metrics.classification.recall import BinaryRecall, MulticlassRecall
54 | from torcheval.metrics.classification.recall_at_fixed_precision import (
55 | BinaryRecallAtFixedPrecision,
56 | MultilabelRecallAtFixedPrecision,
57 | )
58 |
59 | __all__ = [
60 | "BinaryAccuracy",
61 | "BinaryAUPRC",
62 | "BinaryAUROC",
63 | "BinaryBinnedAUROC",
64 | "BinaryBinnedAUPRC",
65 | "BinaryBinnedPrecisionRecallCurve",
66 | "BinaryConfusionMatrix",
67 | "BinaryF1Score",
68 | "BinaryNormalizedEntropy",
69 | "BinaryPrecision",
70 | "BinaryPrecisionRecallCurve",
71 | "BinaryRecall",
72 | "BinaryRecallAtFixedPrecision",
73 | "MulticlassAccuracy",
74 | "MulticlassAUPRC",
75 | "MulticlassAUROC",
76 | "MulticlassBinnedAUPRC",
77 | "MulticlassBinnedAUROC",
78 | "MulticlassBinnedPrecisionRecallCurve",
79 | "MulticlassConfusionMatrix",
80 | "MulticlassF1Score",
81 | "MulticlassPrecision",
82 | "MulticlassPrecisionRecallCurve",
83 | "MulticlassRecall",
84 | "MultilabelAccuracy",
85 | "MultilabelAUPRC",
86 | "MultilabelBinnedAUPRC",
87 | "MultilabelBinnedPrecisionRecallCurve",
88 | "MultilabelPrecisionRecallCurve",
89 | "MultilabelRecallAtFixedPrecision",
90 | "TopKMultilabelAccuracy",
91 | ]
92 |
93 | __doc_name__ = "Classification Metrics"
94 |
--------------------------------------------------------------------------------
/torcheval/metrics/aggregation/sum.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
10 |
11 | from collections.abc import Iterable
12 | from typing import TypeVar
13 |
14 | import torch
15 |
16 | from torcheval.metrics.functional.aggregation.sum import _sum_update
17 | from torcheval.metrics.metric import Metric
18 |
19 | TSum = TypeVar("TSum")
20 |
21 |
22 | class Sum(Metric[torch.Tensor]):
23 | """
24 | Calculate the weighted sum value of all elements in all the input tensors.
25 | When weight is not provided, it calculates the unweighted sum.
26 | Its functional version is :func:`torcheval.metrics.functional.sum`.
27 |
28 | Examples::
29 |
30 | >>> import torch
31 | >>> from torcheval.metrics import Sum
32 | >>> metric = Sum()
33 | >>> metric.update(1)
34 | >>> metric.update(torch.tensor([2, 3]))
35 | >>> metric.compute()
36 | tensor(6.)
37 | >>> metric.update(torch.tensor(-1)).compute()
38 | tensor(5.)
39 | >>> metric.reset()
40 | >>> metric.update(torch.tensor(-1)).compute()
41 | tensor(-1.)
42 |
43 | >>> metric = Sum()
44 | >>> metric.update(torch.tensor([2, 3]), torch.tensor([0.1, 0.6])).compute()
45 | tensor(2.)
46 | >>> metric.update(torch.tensor([2, 3]), 0.5).compute()
47 | tensor(4.5)
48 | >>> metric.update(torch.tensor([4, 6]), 1).compute()
49 | tensor(14.5)
50 | """
51 |
52 | def __init__(
53 | self: TSum,
54 | *,
55 | device: torch.device | None = None,
56 | ) -> None:
57 | super().__init__(device=device)
58 | self._add_state(
59 | "weighted_sum", torch.tensor(0.0, device=self.device, dtype=torch.float64)
60 | )
61 |
62 | @torch.inference_mode()
63 | # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
64 | def update(
65 | self: TSum,
66 | input: torch.Tensor,
67 | *,
68 | weight: float | int | torch.Tensor = 1.0,
69 | ) -> TSum:
70 | """
71 | Update states with the values and weights.
72 |
73 | Args:
74 | input (Tensor): Tensor of input values.
75 | weight(optional): Float or Int or Tensor of input weights. It is default to 1.0. If weight is a Tensor, its size should match the input tensor size.
76 | Raises:
77 | ValueError: If value of weight is neither a ``float`` nor ``int`` nor a ``torch.Tensor`` that matches the input tensor size.
78 | """
79 |
80 | self.weighted_sum += _sum_update(input, weight)
81 | return self
82 |
83 | @torch.inference_mode()
84 | def compute(self: TSum) -> torch.Tensor:
85 | return self.weighted_sum
86 |
87 | @torch.inference_mode()
88 | def merge_state(self: TSum, metrics: Iterable[TSum]) -> TSum:
89 | for metric in metrics:
90 | self.weighted_sum += metric.weighted_sum.to(self.device)
91 | return self
92 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/text/word_information_lost.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 | from torcheval.metrics.functional.text.helper import _get_errors_and_totals
13 |
14 |
15 | def _wil_update(
16 | input: str | list[str],
17 | target: str | list[str],
18 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
19 | """Update the wil score with the current set of references and predictions.
20 | Args:
21 | input: Transcription(s) to score as a string or list of strings
22 | target: Reference(s) for each speech input as a string or list of strings
23 | Returns:
24 | Number of correct words
25 | Number of words overall references
26 | Number of words overall predictions
27 | """
28 | if isinstance(input, str):
29 | input = [input]
30 | if isinstance(target, str):
31 | target = [target]
32 | assert (
33 | len(input) == len(target)
34 | ), f"Arguments must contain the same number of strings, but got len(input)={len(input)} and len(target)={len(target)}"
35 | errors, max_total, target_total, input_total = _get_errors_and_totals(input, target)
36 | return errors - max_total, target_total, input_total
37 |
38 |
39 | def _wil_compute(
40 | correct_total: torch.Tensor, target_total: torch.Tensor, preds_total: torch.Tensor
41 | ) -> torch.Tensor:
42 | """Compute the Word Information Lost.
43 | Args:
44 | correct_total: Number of correct words
45 | target_total: Number of words overall references
46 | preds_total: Number of words overall prediction
47 | Returns:
48 | Word Information Lost score
49 | """
50 | return 1 - ((correct_total / target_total) * (correct_total / preds_total))
51 |
52 |
53 | @torch.inference_mode()
54 | def word_information_lost(
55 | input: str | list[str],
56 | target: str | list[str],
57 | ) -> torch.Tensor:
58 | """Word Information Lost rate is a metric of the performance of an automatic speech recognition system. This
59 | value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better
60 | the performance of the ASR system with a Word Information Lost rate of 0 being a perfect score.
61 |
62 | Its class version is ``torcheval.metrics.WordInformationLost``.
63 |
64 | Args:
65 | input: Transcription(s) to score as a string or list of strings
66 | target: Reference(s) for each speech input as a string or list of strings
67 | Returns:
68 | Word Information Lost rate
69 | Examples:
70 | >>> from torcheval.metrics.functional import word_information_lost
71 | >>> input = ["this is the prediction", "there is an other sample"]
72 | >>> target = ["this is the reference", "there is another one"]
73 | >>> word_information_lost(input, target)
74 | tensor(0.6528)
75 | """
76 | correct_total, target_total, preds_total = _wil_update(input, target)
77 | return _wil_compute(correct_total, target_total, preds_total)
78 |
--------------------------------------------------------------------------------
/tests/metrics/aggregation/test_throughput.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import random
10 |
11 | from torcheval.metrics import Throughput
12 | from torcheval.utils.test_utils.metric_class_tester import (
13 | MetricClassTester,
14 | NUM_PROCESSES,
15 | NUM_TOTAL_UPDATES,
16 | )
17 |
18 |
19 | class TestThroughput(MetricClassTester):
20 | def _test_throughput_class_with_input(
21 | self,
22 | num_processed: list[int],
23 | elapsed_time_sec: list[float],
24 | ) -> None:
25 | num_individual_update = NUM_TOTAL_UPDATES // NUM_PROCESSES
26 | expected_num_total = sum(num_processed)
27 | max_elapsed_time_sec = max(
28 | [
29 | sum(
30 | elapsed_time_sec[
31 | i * num_individual_update : (i + 1) * num_individual_update
32 | ]
33 | )
34 | for i in range(NUM_PROCESSES)
35 | ]
36 | )
37 | total_elapsed_time_sec = sum(elapsed_time_sec)
38 |
39 | expected_compute_result = (1.0 * expected_num_total) / total_elapsed_time_sec
40 | expected_merge_and_compute_result = (
41 | 1.0 * expected_num_total
42 | ) / max_elapsed_time_sec
43 | self.run_class_implementation_tests(
44 | metric=Throughput(),
45 | state_names={"num_total", "elapsed_time_sec"},
46 | update_kwargs={
47 | "num_processed": num_processed,
48 | "elapsed_time_sec": elapsed_time_sec,
49 | },
50 | compute_result=expected_compute_result,
51 | merge_and_compute_result=expected_merge_and_compute_result,
52 | )
53 |
54 | def test_throughput_class_base(self) -> None:
55 | num_processed = [random.randint(0, 40) for _ in range(NUM_TOTAL_UPDATES)]
56 | elapsed_time_sec = [random.uniform(0.1, 5.0) for _ in range(NUM_TOTAL_UPDATES)]
57 | self._test_throughput_class_with_input(num_processed, elapsed_time_sec)
58 |
59 | def test_throughput_class_update_input_invalid_num_processed(self) -> None:
60 | metric = Throughput()
61 | with self.assertRaisesRegex(
62 | ValueError,
63 | r"Expected num_processed to be a non-negative number, but received",
64 | ):
65 | metric.update(-1, 1.0)
66 |
67 | def test_throughput_class_update_input_invalid_elapsed_time_sec(self) -> None:
68 | metric = Throughput()
69 | with self.assertRaisesRegex(
70 | ValueError,
71 | r"Expected elapsed_time_sec to be a positive number, but received",
72 | ):
73 | metric.update(42, -5.1)
74 | with self.assertRaisesRegex(
75 | ValueError,
76 | r"Expected elapsed_time_sec to be a positive number, but received",
77 | ):
78 | metric.update(42, 0.0)
79 |
80 | def test_throughput_class_compute_without_update(self) -> None:
81 | metric = Throughput()
82 | self.assertEqual(metric.compute(), 0.0)
83 |
--------------------------------------------------------------------------------
/examples/simple_example.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # pyre-ignore-all-errors[5]: Undefined variable type
10 |
11 | import torch
12 | from torch.utils.data.dataset import TensorDataset
13 |
14 | from torcheval.metrics import MulticlassAccuracy
15 |
16 | NUM_EPOCHS = 4
17 | NUM_BATCHES = 16
18 | BATCH_SIZE = 8
19 |
20 |
21 | class Model(torch.nn.Module):
22 | def __init__(self) -> None:
23 | super().__init__()
24 | self.layers = torch.nn.Sequential(
25 | torch.nn.Linear(128, 64),
26 | torch.nn.ReLU(),
27 | torch.nn.Linear(64, 32),
28 | torch.nn.ReLU(),
29 | torch.nn.Linear(32, 2),
30 | )
31 |
32 | def forward(self, X: torch.Tensor) -> torch.Tensor:
33 | return self.layers(X)
34 |
35 |
36 | def prepare_dataloader() -> torch.utils.data.DataLoader:
37 | num_samples = NUM_BATCHES * BATCH_SIZE
38 | data = torch.randn(num_samples, 128)
39 | labels = torch.randint(low=0, high=2, size=(num_samples,))
40 | return torch.utils.data.DataLoader(
41 | TensorDataset(data, labels), batch_size=BATCH_SIZE
42 | )
43 |
44 |
45 | def main() -> None:
46 | torch.random.manual_seed(42)
47 |
48 | model = Model()
49 | optim = torch.optim.Adagrad(model.parameters(), lr=0.001)
50 |
51 | train_dataloader = prepare_dataloader()
52 |
53 | loss_fn = torch.nn.CrossEntropyLoss()
54 | metric = MulticlassAccuracy()
55 |
56 | compute_frequency = 4
57 | num_epochs_completed = 0
58 |
59 | while num_epochs_completed < NUM_EPOCHS:
60 | data_iter = iter(train_dataloader)
61 | batch_idx = 0
62 | while True:
63 | try:
64 | # get the next batch from data iterator
65 | input, target = next(data_iter)
66 | output = model(input)
67 |
68 | # metric.update() updates the metric state with new data
69 | metric.update(output, target)
70 |
71 | loss = loss_fn(output, target)
72 | optim.zero_grad()
73 | loss.backward()
74 | optim.step()
75 |
76 | if (batch_idx + 1) % compute_frequency == 0:
77 | print(
78 | "Epoch {}/{}, Batch {}/{} --- loss: {:.4f}, acc: {:.4f}".format(
79 | num_epochs_completed + 1,
80 | NUM_EPOCHS,
81 | batch_idx + 1,
82 | NUM_BATCHES,
83 | loss.item(),
84 | # metric.compute() returns metric value from all seen data
85 | metric.compute(),
86 | )
87 | )
88 | batch_idx += 1
89 | except StopIteration:
90 | break
91 |
92 | # metric.reset() cleans up all seen data
93 | metric.reset()
94 |
95 | num_epochs_completed += 1
96 |
97 |
98 | if __name__ == "__main__":
99 | main() # pragma: no cover
100 |
--------------------------------------------------------------------------------
/torcheval/metrics/aggregation/cat.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
10 |
11 |
12 | from collections.abc import Iterable
13 | from typing import TypeVar
14 |
15 | import torch
16 |
17 | from torcheval.metrics.metric import Metric
18 |
19 | TCat = TypeVar("TCat")
20 |
21 |
22 | class Cat(Metric[torch.Tensor]):
23 | """
24 | Concatenate all input tensors along dimension dim. Its functional
25 | version is ``torch.cat(input)``.
26 |
27 | All input tensors to ``Cat.update()`` must either have the same shape
28 | (except in the concatenating dimension) or be empty.
29 |
30 | Zero-dimensional tensor is not a valid input of ``Cat.update()``.
31 | ``torch.flatten()`` can be used to flatten zero-dimensional into
32 | an one-dimensional tensor before passing in ``Cat.update()``.
33 |
34 | Examples::
35 |
36 | >>> import torch
37 | >>> from torcheval.metrics import Cat
38 | >>> metric = Cat(dim=1)
39 | >>> metric.update(torch.tensor([[1, 2], [3, 4]]))
40 | >>> metric.compute()
41 | tensor([[1, 2],
42 | [3, 4]]))
43 |
44 | >>> metric.update(torch.tensor([[5, 6], [7, 8]]))).compute()
45 | tensor([[1, 2, 5, 6],
46 | [3, 4, 7, 8]]))
47 |
48 | >>> metric.reset()
49 | >>> metric.update(torch.tensor([0])).compute()
50 | tensor([0])
51 | """
52 |
53 | def __init__(
54 | self: "Cat",
55 | *,
56 | dim: int = 0,
57 | device: torch.device | None = None,
58 | ) -> None:
59 | """
60 | Initialize a Cat metric object.
61 |
62 | Args:
63 | dim: The dimension along which to concatenate, as in ``torch.cat()``.
64 | """
65 | super().__init__(device=device)
66 | self._add_state("dim", dim)
67 | self._add_state("inputs", [])
68 |
69 | @torch.inference_mode()
70 | # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
71 | def update(self: TCat, input: torch.Tensor) -> TCat:
72 | self.inputs.append(input)
73 | return self
74 |
75 | @torch.inference_mode()
76 | def compute(self: TCat) -> torch.Tensor:
77 | """
78 | Return the concatenated inputs.
79 |
80 | If no calls to ``update()`` are made before ``compute()`` is called,
81 | the function returns ``torch.empty(0)``.
82 | """
83 | if not self.inputs:
84 | return torch.empty(0)
85 | return torch.cat(self.inputs, dim=self.dim)
86 |
87 | @torch.inference_mode()
88 | def merge_state(self: TCat, metrics: Iterable[TCat]) -> TCat:
89 | for metric in metrics:
90 | if metric.inputs:
91 | self.inputs.append(
92 | torch.cat(metric.inputs, dim=metric.dim).to(self.device)
93 | )
94 | return self
95 |
96 | @torch.inference_mode()
97 | def _prepare_for_merge_state(self: TCat) -> None:
98 | if self.inputs:
99 | self.inputs = [torch.cat(self.inputs, dim=self.dim)]
100 |
--------------------------------------------------------------------------------
/.github/workflows/unit_test.yaml:
--------------------------------------------------------------------------------
1 | name: unit test
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 |
8 | jobs:
9 | unit_tests:
10 | runs-on: ubuntu-latest
11 | strategy:
12 | matrix:
13 | python-version: [3.8, 3.9]
14 | steps:
15 | - name: Check out repo
16 | uses: actions/checkout@v2
17 | - name: Setup conda env
18 | uses: conda-incubator/setup-miniconda@v2
19 | with:
20 | miniconda-version: "latest"
21 | activate-environment: test
22 | python-version: ${{ matrix.python-version }}
23 | - name: Install dependencies
24 | shell: bash -l {0}
25 | run: |
26 | set -eux
27 | pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
28 | pip install -r requirements.txt
29 | pip install -r dev-requirements.txt
30 | pip install --no-build-isolation -e ".[dev]"
31 | - name: Run unit tests with coverage
32 | shell: bash -l {0}
33 | run: |
34 | set -eux
35 | pytest --cov=. --cov-report xml tests -vv
36 | - name: Upload Coverage to Codecov
37 | uses: codecov/codecov-action@v2
38 |
39 | gpu_unit_tests:
40 | runs-on: ${{ matrix.os }}
41 | strategy:
42 | matrix:
43 | os: [linux.8xlarge.nvidia.gpu]
44 | python-version: [3.8]
45 | cuda-tag: ["cu11"]
46 | steps:
47 | - name: Check out repo
48 | uses: actions/checkout@v2
49 | - name: Setup conda env
50 | uses: conda-incubator/setup-miniconda@v2
51 | with:
52 | miniconda-version: "latest"
53 | activate-environment: test
54 | python-version: ${{ matrix.python-version }}
55 | - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG
56 | uses: pytorch/test-infra/.github/actions/setup-nvidia@main
57 | - name: Display EC2 information
58 | shell: bash
59 | run: |
60 | set -euo pipefail
61 | function get_ec2_metadata() {
62 | # Pulled from instance metadata endpoint for EC2
63 | # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
64 | category=$1
65 | curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
66 | }
67 | echo "ami-id: $(get_ec2_metadata ami-id)"
68 | echo "instance-id: $(get_ec2_metadata instance-id)"
69 | echo "instance-type: $(get_ec2_metadata instance-type)"
70 | - name: Install dependencies
71 | shell: bash -l {0}
72 | run: |
73 | set -eux
74 | conda activate test
75 | pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
76 | # Use stable fbgemm-gpu
77 | pip uninstall -y fbgemm-gpu-nightly
78 | pip install fbgemm-gpu==0.2.0
79 | pip install -r requirements.txt
80 | pip install -r dev-requirements.txt
81 | pip install --no-build-isolation -e ".[dev]"
82 | - name: Run unit tests with coverage
83 | shell: bash -l {0}
84 | run: |
85 | set -eux
86 | conda activate test
87 | pytest --timeout=60 --cov=. --cov-report xml -vv -rA -m "gpu_only or cpu_and_gpu" tests
88 | - name: Upload coverage to codecov
89 | uses: codecov/codecov-action@v2
90 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/classification/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.functional.classification.accuracy import (
10 | binary_accuracy,
11 | multiclass_accuracy,
12 | multilabel_accuracy,
13 | topk_multilabel_accuracy,
14 | )
15 | from torcheval.metrics.functional.classification.auprc import (
16 | binary_auprc,
17 | multiclass_auprc,
18 | multilabel_auprc,
19 | )
20 |
21 | from torcheval.metrics.functional.classification.auroc import (
22 | binary_auroc,
23 | multiclass_auroc,
24 | )
25 |
26 | from torcheval.metrics.functional.classification.binary_normalized_entropy import (
27 | binary_normalized_entropy,
28 | )
29 | from torcheval.metrics.functional.classification.binned_auprc import (
30 | binary_binned_auprc,
31 | multiclass_binned_auprc,
32 | multilabel_binned_auprc,
33 | )
34 | from torcheval.metrics.functional.classification.binned_auroc import (
35 | binary_binned_auroc,
36 | multiclass_binned_auroc,
37 | )
38 | from torcheval.metrics.functional.classification.binned_precision_recall_curve import (
39 | binary_binned_precision_recall_curve,
40 | multiclass_binned_precision_recall_curve,
41 | multilabel_binned_precision_recall_curve,
42 | )
43 | from torcheval.metrics.functional.classification.confusion_matrix import (
44 | binary_confusion_matrix,
45 | multiclass_confusion_matrix,
46 | )
47 | from torcheval.metrics.functional.classification.f1_score import (
48 | binary_f1_score,
49 | multiclass_f1_score,
50 | )
51 | from torcheval.metrics.functional.classification.precision import (
52 | binary_precision,
53 | multiclass_precision,
54 | )
55 | from torcheval.metrics.functional.classification.precision_recall_curve import (
56 | binary_precision_recall_curve,
57 | multiclass_precision_recall_curve,
58 | multilabel_precision_recall_curve,
59 | )
60 | from torcheval.metrics.functional.classification.recall import (
61 | binary_recall,
62 | multiclass_recall,
63 | )
64 | from torcheval.metrics.functional.classification.recall_at_fixed_precision import (
65 | binary_recall_at_fixed_precision,
66 | multilabel_recall_at_fixed_precision,
67 | )
68 |
# Public functional classification metrics re-exported by this package.
# Kept in alphabetical order; every name here is imported above.
__all__ = [
    "binary_accuracy",
    "binary_auprc",
    "binary_auroc",
    "binary_binned_auprc",
    "binary_binned_auroc",
    "binary_binned_precision_recall_curve",
    "binary_confusion_matrix",
    "binary_f1_score",
    "binary_normalized_entropy",
    "binary_precision",
    "binary_precision_recall_curve",
    "binary_recall",
    "binary_recall_at_fixed_precision",
    "multiclass_accuracy",
    "multiclass_auprc",
    "multiclass_auroc",
    "multiclass_binned_auprc",
    "multiclass_binned_auroc",
    "multiclass_binned_precision_recall_curve",
    "multiclass_confusion_matrix",
    "multiclass_f1_score",
    "multiclass_precision",
    "multiclass_precision_recall_curve",
    "multiclass_recall",
    "multilabel_accuracy",
    "multilabel_auprc",
    "multilabel_binned_auprc",
    "multilabel_binned_precision_recall_curve",
    "multilabel_precision_recall_curve",
    "multilabel_recall_at_fixed_precision",
    "topk_multilabel_accuracy",
]
# NOTE(review): presumably read by the docs build to title this section — confirm.
__doc_name__ = "Classification Metrics"
103 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/text/word_information_preserved.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 | from torcheval.metrics.functional.text.helper import _get_errors_and_totals
13 |
14 |
@torch.inference_mode()
def word_information_preserved(
    input: str | list[str],
    target: str | list[str],
) -> torch.Tensor:
    """
    Compute the word information preserved (WIP) score of the predicted word
    sequence(s) against the reference word sequence(s).
    Its class version is ``torcheval.metrics.WordInformationPreserved``.

    Args:
        input (str, List[str]): Predicted word sequence(s) to score as a string or list of strings.
        target (str, List[str]): Reference word sequence(s) as a string or list of strings.

    Returns:
        Tensor: ``(correct / target_words) * (correct / input_words)``, where every
        count is summed over all samples.

    Examples:

        >>> import torch
        >>> from torcheval.metrics.functional import word_information_preserved
        >>> input = ["hello world", "welcome to the facebook"]
        >>> target = ["hello metaverse", "welcome to meta"]
        >>> word_information_preserved(input, target)
        tensor(0.3000)
        >>> input = ["this is the prediction", "there is an other sample"]
        >>> target = ["this is the reference", "there is another one"]
        >>> word_information_preserved(input, target)
        tensor(0.3472)
    """
    # (correct_total, target_total, input_total), in the order the compute step expects.
    totals = _word_information_preserved_update(input, target)
    return _word_information_preserved_compute(*totals)
45 |
46 |
def _word_information_preserved_update(
    input: str | list[str],
    target: str | list[str],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Validate a batch and reduce it to the counts the WIP score is built from.

    Args:
        input (str, List[str]): Predicted word sequence(s) to score as a string or list of strings.
        target (str, List[str]): Reference word sequence(s) as a string or list of strings.

    Returns:
        Tuple of ``(correct_total, target_total, input_total)``, where
        ``correct_total`` is ``max_total - errors`` as reported by the helper.
    """
    _word_information_preserved_input_check(input, target)

    edit_errors, max_len_total, target_len_total, input_len_total = (
        _get_errors_and_totals(input, target)
    )
    # NOTE(review): assumes `errors` is the summed word-level edit distance and
    # `max_total` the summed max(len(input), len(target)) — confirm in the helper.
    correct_total = max_len_total - edit_errors
    return correct_total, target_len_total, input_len_total
62 |
63 |
def _word_information_preserved_compute(
    correct_total: torch.Tensor, target_total: torch.Tensor, input_total: torch.Tensor
) -> torch.Tensor:
    """
    Combine the accumulated counts into the word information preserved score.

    Args:
        correct_total (Tensor): number of words that are correctly predicted, summed over all samples
        target_total (Tensor): length of reference sequence, summed over all samples.
        input_total (Tensor): length of predicted sequence, summed over all samples.

    Returns:
        Tensor: ``(correct_total / target_total) * (correct_total / input_total)``.
    """
    # Fraction of reference words that were preserved.
    target_ratio = correct_total / target_total
    # Fraction of predicted words that were correct.
    input_ratio = correct_total / input_total
    return target_ratio * input_ratio
76 |
77 |
78 | def _word_information_preserved_input_check(
79 | input: str | list[str],
80 | target: str | list[str],
81 | ) -> None:
82 | if type(input) != type(target):
83 | raise ValueError(
84 | f"input and target should have the same type, got {type(input)} and {type(target)}."
85 | )
86 | if type(input) == list:
87 | if len(input) != len(target):
88 | raise ValueError(
89 | f"input and target lists should have the same length, got {len(input)} and {len(target)}",
90 | )
91 |
--------------------------------------------------------------------------------
/tests/metrics/image/test_psnr.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 | from skimage.metrics import peak_signal_noise_ratio as skimage_peak_signal_noise_ratio
13 | from torcheval.metrics import PeakSignalNoiseRatio
14 | from torcheval.utils.test_utils.metric_class_tester import (
15 | BATCH_SIZE,
16 | IMG_CHANNELS,
17 | IMG_HEIGHT,
18 | IMG_WIDTH,
19 | MetricClassTester,
20 | NUM_TOTAL_UPDATES,
21 | )
22 |
23 |
class TestPeakSignalNoiseRatio(MetricClassTester):
    """Tests for the ``PeakSignalNoiseRatio`` metric class, using scikit-image as the reference."""

    def _get_random_data_PeakSignalToNoiseRatio(
        self,
        num_updates: int,
        batch_size: int,
        num_channels: int,
        height: int,
        width: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return uniformly random ``(inputs, targets)`` tensors, each of shape
        ``(num_updates, batch_size, num_channels, height, width)``."""
        inputs = torch.rand(
            size=(num_updates, batch_size, num_channels, height, width),
        )
        targets = torch.rand(
            size=(num_updates, batch_size, num_channels, height, width),
        )
        return inputs, targets

    def _test_psnr_skimage_equivalent(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        data_range: float | None = None,
    ) -> None:
        """Check that ``PeakSignalNoiseRatio`` matches scikit-image's PSNR.

        Args:
            input: predicted images, stacked along the first dimension per update.
            target: reference images with the same shape as ``input``.
            data_range: forwarded to both the metric and scikit-image.
        """
        # scikit-image scores the whole dataset as one flattened array; the metric
        # is expected to reach the same value after consuming all of the data.
        # NOTE(review): presumably run_class_implementation_tests feeds one slice
        # of the first dimension per update() call — confirm in MetricClassTester.
        input_np = input.numpy().ravel()
        target_np = target.numpy().ravel()

        skimage_result = torch.tensor(
            skimage_peak_signal_noise_ratio(
                image_true=target_np, image_test=input_np, data_range=data_range
            )
        )

        state_names = {
            "num_observations",
            "sum_squared_error",
            "data_range",
            "min_target",
            "max_target",
        }

        self.run_class_implementation_tests(
            metric=PeakSignalNoiseRatio(data_range=data_range),
            state_names=state_names,
            update_kwargs={"input": input, "target": target},
            compute_result=skimage_result.to(torch.float32),
        )

    def test_psnr_with_random_data(self) -> None:
        input, target = self._get_random_data_PeakSignalToNoiseRatio(
            NUM_TOTAL_UPDATES, BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH
        )
        self._test_psnr_skimage_equivalent(input, target)

    def test_psnr_with_random_data_and_data_range(self) -> None:
        input, target = self._get_random_data_PeakSignalToNoiseRatio(
            NUM_TOTAL_UPDATES, BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH
        )
        self._test_psnr_skimage_equivalent(input, target, data_range=0.5)

    def test_psnr_class_invalid_input(self) -> None:
        metric = PeakSignalNoiseRatio()
        with self.assertRaisesRegex(
            ValueError,
            "The `input` and `target` must have the same shape, "
            r"got shapes torch.Size\(\[4, 3, 4, 4\]\) and torch.Size\(\[4, 3, 4, 6\]\).",
        ):
            metric.update(torch.rand(4, 3, 4, 4), torch.rand(4, 3, 4, 6))

    def test_psnr_class_invalid_data_range(self) -> None:
        with self.assertRaisesRegex(
            ValueError, "`data_range needs to be either `None` or `float`."
        ):
            PeakSignalNoiseRatio(data_range=5)

        with self.assertRaisesRegex(ValueError, "`data_range` needs to be positive."):
            PeakSignalNoiseRatio(data_range=-1.0)
100 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 |
--------------------------------------------------------------------------------
/torcheval/metrics/text/word_error_rate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
10 |
11 | from collections.abc import Iterable
12 | from typing import TypeVar
13 |
14 | import torch
15 |
16 | from torcheval.metrics.functional.text.word_error_rate import (
17 | _word_error_rate_compute,
18 | _word_error_rate_update,
19 | )
20 | from torcheval.metrics.metric import Metric
21 |
TWordErrorRate = TypeVar("TWordErrorRate")


class WordErrorRate(Metric[torch.Tensor]):
    """
    Compute the word error rate (WER) of the predicted word sequence(s) against the
    reference word sequence(s), accumulated across ``update`` calls.
    Its functional version is :func:`torcheval.metrics.functional.word_error_rate`.

    Examples:

        >>> import torch
        >>> from torcheval.metrics import WordErrorRate

        >>> metric = WordErrorRate()
        >>> metric.update(["this is the prediction", "there is an other sample"],
                          ["this is the reference", "there is another one"])
        >>> metric.compute()
        tensor(0.5000)

        >>> metric = WordErrorRate()
        >>> metric.update(["this is the prediction", "there is an other sample"],
                          ["this is the reference", "there is another one"])
        >>> metric.update(["hello world", "welcome to the facebook"],
                          ["hello metaverse", "welcome to meta"])
        >>> metric.compute()
        tensor(0.5385)
    """

    def __init__(
        self: TWordErrorRate,
        *,
        device: torch.device | None = None,
    ) -> None:
        super().__init__(device=device)
        # Two running float counters on the metric's device, registered as
        # "errors" first and then "total".
        for state_name in ("errors", "total"):
            self._add_state(
                state_name, torch.tensor(0, dtype=torch.float, device=self.device)
            )

    @torch.inference_mode()
    # pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently.
    def update(
        self: TWordErrorRate,
        input: str | list[str],
        target: str | list[str],
    ) -> TWordErrorRate:
        """
        Add a batch's edit-distance errors and reference lengths to the running totals.

        Args:
            input (str, List[str]): Predicted word sequence(s) to score as a string or list of strings.
            target (str, List[str]): Reference word sequence(s) as a string or list of strings.

        Returns:
            The metric itself, so calls can be chained.
        """
        batch_errors, batch_total = _word_error_rate_update(input, target)
        self.errors += batch_errors
        self.total += batch_total
        return self

    @torch.inference_mode()
    def compute(self: TWordErrorRate) -> torch.Tensor:
        """
        Return the word error rate computed from the accumulated totals.
        """
        return _word_error_rate_compute(self.errors, self.total)

    @torch.inference_mode()
    def merge_state(
        self: TWordErrorRate,
        metrics: Iterable[TWordErrorRate],
    ) -> TWordErrorRate:
        """
        Fold the counters of other metric instances into this one.

        Args:
            metrics (Iterable[Metric]): metric instances whose states are to be merged.
        """
        for other in metrics:
            # Move each counterpart's state onto this metric's device before adding.
            self.errors += other.errors.to(self.device)
            self.total += other.total.to(self.device)
        return self
102 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/aggregation/auc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import torch
10 |
11 |
def _auc_compute(
    x: torch.Tensor, y: torch.Tensor, reorder: bool = False
) -> torch.Tensor:
    """Integrate ``y`` over ``x`` with the trapezoidal rule.

    Args:
        x: x-coordinates; a 1-D input is treated as a single row.
        y: y-coordinates; a 1-D input is treated as a single row.
        reorder: if True, stably sort each row of ``x`` in ascending order and
            permute ``y`` the same way before integrating. Default value is False.

    Return:
        Tensor with one area per row, or an empty tensor when either input
        has no elements.
    """
    if 0 in (x.numel(), y.numel()):
        return torch.tensor([])

    # Lift 1-D inputs to 2-D so the row-wise sort/gather on dim 1 applies uniformly.
    x = x.unsqueeze(0) if x.ndim == 1 else x
    y = y.unsqueeze(0) if y.ndim == 1 else y

    if reorder:
        x, permutation = torch.sort(x, dim=1, stable=True)
        y = torch.gather(y, 1, permutation)

    # torch.trapz is an alias of torch.trapezoid; both integrate along the last dim.
    return torch.trapezoid(y, x)
36 |
37 |
def _auc_update_input_check(x: torch.Tensor, y: torch.Tensor, n_tasks: int = 1) -> None:
    """
    Validate the coordinates passed to AUC before integration.

    Args:
        x: x-coordinates; a 1-D input is treated as a single row.
        y: y-coordinates; a 1-D input is treated as a single row.
        n_tasks: Number of tasks that need AUC calculation. Default value is 1.

    Raises:
        ValueError: if either input is empty, if the (row-promoted) shapes differ,
            or if the first dimension does not equal ``n_tasks``.
    """
    # Error messages report the caller's shapes, not the row-promoted ones.
    size_x, size_y = x.size(), y.size()

    x_rows = x.unsqueeze(0) if x.ndim == 1 else x
    y_rows = y.unsqueeze(0) if y.ndim == 1 else y

    if x_rows.numel() == 0 or y_rows.numel() == 0:
        raise ValueError(
            f"The `x` and `y` should have atleast 1 element, got shapes {size_x} and {size_y}."
        )
    if x_rows.size() != y_rows.size():
        raise ValueError(
            f"Expected the same shape in `x` and `y` tensor but got shapes {size_x} and {size_y}."
        )
    if x_rows.size(0) != n_tasks or y_rows.size(0) != n_tasks:
        raise ValueError(
            f"Expected `x` dim_1={x_rows.size(0)} and `y` dim_1={y_rows.size(0)} have first dimension equals to n_tasks={n_tasks}."
        )
69 |
70 |
def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
    """Computes Area Under the Curve (AUC) using the trapezoidal rule.

    Args:
        x: x-coordinates. Either 1-D (a single curve) or 2-D with one curve per row.
        y: y-coordinates, matching the shape of ``x`` (a 1-D input is compared
            as a single row).
        reorder: sorts each row of ``x`` in ascending order (and permutes ``y``
            accordingly) before integrating. Default value is False.

    Return:
        Tensor containing one AUC score (float) per curve.

    Raises:
        ValueError:
            If ``x`` or ``y`` has no elements.
            If ``x`` and ``y`` don't have the same shape.

    Example:
        >>> import torch
        >>> from torcheval.metrics.functional.aggregation.auc import auc
        >>> x = torch.tensor([0,.1,.2,.3])
        >>> y = torch.tensor([1,1,1,1])
        >>> auc(x, y)
        tensor([0.3000])
        >>> y = torch.tensor([[0, 4, 0, 4, 3],
                              [1, 1, 2, 1, 1],
                              [4, 3, 1, 4, 4],
                              [1, 0, 0, 3, 0]])
        >>> x = torch.tensor([[0.2535, 0.1138, 0.1324, 0.1887, 0.3117],
                              [0.1434, 0.4404, 0.1100, 0.1178, 0.1883],
                              [0.2344, 0.1743, 0.3110, 0.0393, 0.2410],
                              [0.1381, 0.1564, 0.0320, 0.2220, 0.4515]])
        >>> auc(x, y, reorder=True) # Reorders X and calculates AUC.
        tensor([0.3667, 0.3343, 0.8843, 0.5048])
    """
    # A 2-D input holds one task per row; a 1-D input is a single task.
    n_tasks = x.size(0) if x.ndim > 1 else 1
    _auc_update_input_check(x, y, n_tasks)
    return _auc_compute(x, y, reorder)
105 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from torcheval.metrics.functional.aggregation import auc, mean, sum, throughput
10 | from torcheval.metrics.functional.classification import (
11 | binary_accuracy,
12 | binary_auprc,
13 | binary_auroc,
14 | binary_binned_auprc,
15 | binary_binned_auroc,
16 | binary_binned_precision_recall_curve,
17 | binary_confusion_matrix,
18 | binary_f1_score,
19 | binary_normalized_entropy,
20 | binary_precision,
21 | binary_precision_recall_curve,
22 | binary_recall,
23 | binary_recall_at_fixed_precision,
24 | multiclass_accuracy,
25 | multiclass_auprc,
26 | multiclass_auroc,
27 | multiclass_binned_auprc,
28 | multiclass_binned_auroc,
29 | multiclass_binned_precision_recall_curve,
30 | multiclass_confusion_matrix,
31 | multiclass_f1_score,
32 | multiclass_precision,
33 | multiclass_precision_recall_curve,
34 | multiclass_recall,
35 | multilabel_accuracy,
36 | multilabel_auprc,
37 | multilabel_binned_auprc,
38 | multilabel_binned_precision_recall_curve,
39 | multilabel_precision_recall_curve,
40 | multilabel_recall_at_fixed_precision,
41 | topk_multilabel_accuracy,
42 | )
43 | from torcheval.metrics.functional.frechet import gaussian_frechet_distance
44 | from torcheval.metrics.functional.image import peak_signal_noise_ratio
45 | from torcheval.metrics.functional.ranking import (
46 | click_through_rate,
47 | frequency_at_k,
48 | hit_rate,
49 | num_collisions,
50 | reciprocal_rank,
51 | retrieval_precision,
52 | retrieval_recall,
53 | weighted_calibration,
54 | )
55 | from torcheval.metrics.functional.regression import mean_squared_error, r2_score
56 | from torcheval.metrics.functional.text import (
57 | bleu_score,
58 | perplexity,
59 | word_error_rate,
60 | word_information_lost,
61 | word_information_preserved,
62 | )
63 |
# Public functional metrics re-exported by this package, kept in alphabetical order.
__all__ = [
    "auc",
    "binary_accuracy",
    "binary_auprc",
    "binary_auroc",
    "binary_binned_auprc",
    "binary_binned_auroc",
    "binary_binned_precision_recall_curve",
    "binary_confusion_matrix",
    "binary_f1_score",
    "binary_normalized_entropy",
    "binary_precision",
    "binary_precision_recall_curve",
    "binary_recall",
    "binary_recall_at_fixed_precision",
    "bleu_score",
    "click_through_rate",
    "frequency_at_k",
    "gaussian_frechet_distance",
    "hit_rate",
    "mean",
    "mean_squared_error",
    "multiclass_accuracy",
    "multiclass_auprc",
    "multiclass_auroc",
    "multiclass_binned_auprc",
    "multiclass_binned_auroc",
    "multiclass_binned_precision_recall_curve",
    "multiclass_confusion_matrix",
    "multiclass_f1_score",
    "multiclass_precision",
    "multiclass_precision_recall_curve",
    "multiclass_recall",
    "multilabel_accuracy",
    "multilabel_auprc",
    "multilabel_binned_auprc",
    "multilabel_binned_precision_recall_curve",
    "multilabel_precision_recall_curve",
    "multilabel_recall_at_fixed_precision",
    "num_collisions",
    "peak_signal_noise_ratio",
    "perplexity",
    "r2_score",
    "reciprocal_rank",
    "retrieval_precision",
    "retrieval_recall",
    "sum",
    "throughput",
    "topk_multilabel_accuracy",
    "weighted_calibration",
    "word_error_rate",
    "word_information_lost",
    "word_information_preserved",
]
118 |
--------------------------------------------------------------------------------
/torcheval/metrics/ranking/hit_rate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
10 |
11 | from collections.abc import Iterable
12 | from typing import TypeVar
13 |
14 | import torch
15 |
16 | from torcheval.metrics.functional import hit_rate
17 | from torcheval.metrics.metric import Metric
18 |
THitRate = TypeVar("THitRate")


class HitRate(Metric[torch.Tensor]):
    """
    Compute the hit rate of the correct class among the top predicted classes.
    Its functional version is :func:`torcheval.metrics.functional.hit_rate`.

    Args:
        k (int, optional): Number of top class probabilities to be considered.
            If k is None, all classes are considered and a hit rate of 1.0 is returned.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import HitRate

        >>> metric = HitRate()
        >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1]))
        >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0]))
        >>> metric.compute()
        tensor([1., 1., 1., 1.])

        >>> metric = HitRate(k=2)
        >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1]))
        >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0]))
        >>> metric.compute()
        tensor([1., 0., 0., 1.])
    """

    def __init__(
        self: THitRate,
        *,
        k: int | None = None,
        device: torch.device | None = None,
    ) -> None:
        super().__init__(device=device)
        self.k = k
        # Per-update score tensors; concatenated lazily in compute().
        self._add_state("scores", [])

    @torch.inference_mode()
    # pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently.
    def update(self: THitRate, input: torch.Tensor, target: torch.Tensor) -> THitRate:
        """
        Update the metric state with the ground truth labels and predictions.

        Args:
            input (Tensor): Predicted unnormalized scores (often referred to as logits) or
                class probabilities of shape (num_samples, num_classes).
            target (Tensor): Ground truth class indices of shape (num_samples,).
        """
        self.scores.append(hit_rate(input, target, k=self.k))
        return self

    @torch.inference_mode()
    def compute(self: THitRate) -> torch.Tensor:
        """
        Return the concatenated hit rate scores. If no ``update()`` calls are made before
        ``compute()`` is called, return an empty tensor on the metric's device.
        """
        if not self.scores:
            # Allocate on self.device so the empty result lives where real results would.
            return torch.empty(0, device=self.device)
        return torch.cat(self.scores, dim=0)

    @torch.inference_mode()
    def merge_state(self: THitRate, metrics: Iterable[THitRate]) -> THitRate:
        """
        Merge the metric state with its counterparts from other metric instances.

        Args:
            metrics (Iterable[Metric]): metric instances whose states are to be merged.
        """
        for metric in metrics:
            if metric.scores:
                self.scores.append(torch.cat(metric.scores).to(self.device))
        return self

    @torch.inference_mode()
    def _prepare_for_merge_state(self: THitRate) -> None:
        # Collapse the list of per-update tensors into a single tensor before merging.
        if self.scores:
            self.scores = [torch.cat(self.scores)]
100 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | TorchEval
2 | ===========================================
3 |
4 | A library with simple and straightforward tooling for model evaluations and a delightful user experience. At a high level TorchEval:
5 |
6 | 1. Contains a rich collection of high performance metric calculations out of the box. We utilize vectorization and GPU acceleration where possible via PyTorch.
2. Integrates seamlessly with distributed training and tools using `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_
8 | 3. Is designed with extensibility in mind: you have the freedom to easily create your own metrics and leverage our toolkit.
9 | 4. Provides tools for profiling memory and compute requirements for PyTorch based models.
10 |
11 | QuickStart
12 | ===========================================
13 |
14 | Installing
15 | -----------------
16 |
17 | TorchEval can be installed from PyPi via
18 |
19 | .. code-block:: console
20 |
21 | pip install torcheval
22 |
or from GitHub
24 |
25 | .. code-block:: console
26 |
27 | git clone https://github.com/pytorch/torcheval
28 | cd torcheval
29 | pip install -r requirements.txt
   pip install .
31 |
32 | Usage
33 | -----------------
34 |
35 | TorchEval provides two interfaces to each metric. If you are working in a single process environment, it is simplest to use metrics from the ``functional`` submodule. These can be found in ``torcheval.metrics.functional``.
36 |
37 | .. code-block:: python
38 |
39 | from torcheval.metrics.functional import binary_f1_score
40 | predictions = model(inputs)
41 | f1_score = binary_f1_score(predictions, targets)
42 |
The same metric is also available through the class-based interface, which provides tools that make computation simple in a multi-process setting. On a single device, you can use the class-based metrics as follows:
44 |
45 | .. code-block:: python
46 |
47 | from torcheval.metrics import BinaryF1Score
48 | predictions = model(inputs)
49 | metric = BinaryF1Score()
50 | metric.update(predictions, targets)
51 | f1_score = metric.compute()
52 |
53 | In a multi-process setting, the data from each process must be synchronized to compute the metric across the full dataset. To do this, simply replace ``metric.compute()`` with ``sync_and_compute(metric)``:
54 |
55 | .. code-block:: python
56 |
57 | from torcheval.metrics import BinaryF1Score
58 | from torcheval.metrics.toolkit import sync_and_compute
59 | predictions = model(inputs)
60 | metric = BinaryF1Score()
61 | metric.update(predictions, targets)
62 | f1_score = sync_and_compute(metric)
63 |
64 | Read more about the class-based interface in the distributed example.
65 |
66 | Further Reading
67 | -----------------
68 | * Check out the guides explaining the compute example
69 | * Check out the distributed example
70 | * Check out how to make your own metric
71 |
72 | Indices and tables
73 | ==================
74 |
75 | * :ref:`genindex`
76 | * :ref:`modindex`
77 | * :ref:`search`
78 |
79 | Getting Started
80 | -------------------
81 | .. fbcode::
82 |
83 | .. toctree::
84 | :maxdepth: 2
85 | :caption: Getting Started (Meta)
86 | :glob:
87 |
88 | meta/getting_started.rst
89 |
90 | .. toctree::
91 | :maxdepth: 2
92 | :caption: Migration (Meta)
93 | :glob:
94 |
95 | meta/migrating_to_torcheval.rst
96 |
97 | TorchEval Tutorials
98 | -------------------
99 | .. toctree::
100 | :maxdepth: 2
101 | :caption: Examples:
102 |
103 | QuickStart Notebook
104 | metric_example.rst
105 |
106 | TorchEval API
107 | -----------------
108 |
109 | .. toctree::
110 | :maxdepth: 2
111 | :caption: Contents:
112 |
113 | torcheval.metrics.rst
114 | torcheval.metrics.functional.rst
115 | torcheval.metrics.toolkit.rst
116 |
--------------------------------------------------------------------------------
/tests/metrics/ranking/test_click_through_rate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | import torch
10 | from torcheval.metrics.ranking import ClickThroughRate
11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester
12 |
13 |
class TestClickThroughRate(MetricClassTester):
    def test_ctr_with_valid_input(self) -> None:
        # Single task, unit weights: 4 updates of 4 events each.
        # (2 + 0 + 4 + 3) clicks / 16 events = 0.5625
        single_task_events = torch.tensor(
            [[1, 0, 0, 1], [0, 0, 0, 0], [1, 1, 1, 1], [0, 1, 1, 1]]
        )
        self.run_class_implementation_tests(
            metric=ClickThroughRate(),
            state_names={"click_total", "weight_total"},
            update_kwargs={"input": single_task_events},
            compute_result=torch.tensor([0.5625], dtype=torch.float64),
            num_total_updates=4,
            num_processes=2,
        )

        # Two tasks: each of the 4 updates is a (2, 4) block of click events.
        two_task_events = torch.tensor(
            [
                [[1, 0, 0, 1], [1, 1, 1, 1]],
                [[0, 0, 0, 0], [1, 1, 1, 1]],
                [[0, 1, 0, 1], [0, 1, 0, 1]],
                [[1, 1, 1, 1], [0, 1, 1, 1]],
            ]
        )

        # Per-event tensor weights with the same shape as the events.
        # task 0: 11 / 24 = 0.458333, task 1: 11 / 16 = 0.6875
        per_event_weights = torch.tensor(
            [
                [[1, 2, 3, 4], [0, 0, 0, 0]],
                [[1, 2, 1, 2], [1, 2, 1, 2]],
                [[1, 1, 1, 1], [1, 1, 3, 1]],
                [[1, 1, 1, 1], [1, 1, 1, 1]],
            ]
        )
        self.run_class_implementation_tests(
            metric=ClickThroughRate(num_tasks=2),
            state_names={"click_total", "weight_total"},
            update_kwargs={"input": two_task_events, "weights": per_event_weights},
            compute_result=torch.tensor([0.4583333, 0.6875], dtype=torch.float64),
            num_total_updates=4,
            num_processes=2,
        )

        # One weight per update: float/int scalars mixed with a (2, 4) tensor.
        # task 0: 14 / 30 = 0.466667, task 1: 26 / 30 = 0.866667
        per_update_weights = [4.0, 1, torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]]), 0.0]
        self.run_class_implementation_tests(
            metric=ClickThroughRate(num_tasks=2),
            state_names={"click_total", "weight_total"},
            update_kwargs={"input": two_task_events, "weights": per_update_weights},
            compute_result=torch.tensor([0.46666667, 0.86666667], dtype=torch.float64),
            num_total_updates=4,
            num_processes=2,
        )

    def test_ctr_with_invalid_input(self) -> None:
        # Three-dimensional input is rejected outright.
        single_task = ClickThroughRate()
        with self.assertRaisesRegex(
            ValueError, "^`input` should be a one or two dimensional tensor"
        ):
            single_task.update(torch.rand(3, 2, 2))

        # A weights tensor must match the input shape exactly.
        single_task = ClickThroughRate()
        with self.assertRaisesRegex(
            ValueError, "^tensor `weights` should have the same shape as tensor `input`"
        ):
            single_task.update(torch.rand(4, 2), torch.rand(3))

        # With the default single task, two-dimensional input is not allowed.
        with self.assertRaisesRegex(
            ValueError,
            r"`num_tasks = 1`, `input` is expected to be one-dimensional tensor,",
        ):
            single_task.update(torch.tensor([[1, 1], [0, 1]]))

        # With several tasks, the leading dimension must equal num_tasks.
        two_tasks = ClickThroughRate(num_tasks=2)
        with self.assertRaisesRegex(
            ValueError, r"`num_tasks = 2`, `input`'s shape is expected to be"
        ):
            two_tasks.update(torch.tensor([1, 0, 0, 1]))

        # num_tasks itself must be at least 1.
        with self.assertRaisesRegex(
            ValueError, r"`num_tasks` value should be greater than and equal to 1,"
        ):
            ClickThroughRate(num_tasks=0)
100 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 |
156 | # MacOS
157 | .DS_Store
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
165 |
--------------------------------------------------------------------------------
/torcheval/utils/test_utils/dummy_metric.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
10 |
11 | from collections import defaultdict
12 | from collections.abc import Iterable
13 | from typing import TypeVar
14 |
15 | import torch
16 |
17 | from torcheval.metrics import Metric
18 |
TDummySumMetric = TypeVar("TDummySumMetric")


class DummySumMetric(Metric[torch.Tensor]):
    """
    Test-only metric that accumulates every ``update()`` input into one tensor state.

    The state ``sum`` starts as ``torch.tensor(0.0)`` on the metric's device.
    """

    def __init__(self: TDummySumMetric, *, device: torch.device | None = None) -> None:
        super().__init__(device=device)
        zero = torch.tensor(0.0, device=self.device)
        self._add_state("sum", zero)

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(self: TDummySumMetric, x: torch.Tensor) -> TDummySumMetric:
        """Add ``x`` to the running ``sum`` state and return ``self``."""
        self.sum += x
        return self

    @torch.inference_mode()
    def compute(self: TDummySumMetric) -> torch.Tensor:
        """Return the ``sum`` state tensor itself (not a copy)."""
        return self.sum

    @torch.inference_mode()
    def merge_state(
        self: TDummySumMetric, metrics: Iterable[TDummySumMetric]
    ) -> TDummySumMetric:
        """Add each other metric's ``sum``, moved to this metric's device, into this one."""
        for other in metrics:
            incoming = other.sum.to(self.device)
            self.sum += incoming
        return self
44 |
45 |
TDummySumListStateMetric = TypeVar("TDummySumListStateMetric")


class DummySumListStateMetric(Metric[torch.Tensor]):
    """
    Test-only metric with a list state: ``update()`` stores each tensor and
    ``compute()`` sums the elements of all stored tensors.
    """

    def __init__(
        self: TDummySumListStateMetric, *, device: torch.device | None = None
    ) -> None:
        super().__init__(device=device)
        self._add_state("x", [])

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(
        self: TDummySumListStateMetric, x: torch.Tensor
    ) -> TDummySumListStateMetric:
        """Append ``x``, moved to the metric's device, to the list state."""
        moved = x.to(self.device)
        self.x.append(moved)
        return self

    @torch.inference_mode()
    def compute(self: TDummySumListStateMetric) -> torch.Tensor:
        """
        Return the total of all elements across the stored tensors.

        When nothing has been stored, this is the Python int ``0`` (the start value
        of the builtin ``sum``), not a tensor.
        """
        # pyre-fixme[7]: Expected `Tensor` but got `int`.
        return sum(map(torch.sum, self.x))

    @torch.inference_mode()
    def merge_state(
        self: TDummySumListStateMetric, metrics: Iterable[TDummySumListStateMetric]
    ) -> TDummySumListStateMetric:
        """Append every tensor stored by the other metrics, moved to this metric's device."""
        for other in metrics:
            for stored in other.x:
                self.x.append(stored.to(self.device))
        return self
76 |
77 |
TDummySumDictStateMetric = TypeVar("TDummySumDictStateMetric")


class DummySumDictStateMetric(Metric[torch.Tensor]):
    """
    Test-only metric whose state ``x`` maps string keys to running tensor sums.

    ``x`` is initialised as a ``defaultdict`` whose default is ``torch.tensor(0.0)``
    on the metric's device, so the first value added under a key starts from zero.
    """

    def __init__(
        self: TDummySumDictStateMetric, *, device: torch.device | None = None
    ) -> None:
        super().__init__(device=device)
        self._add_state("x", defaultdict(lambda: torch.tensor(0.0, device=self.device)))

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(
        self: TDummySumDictStateMetric,
        k: str,
        v: torch.Tensor,
    ) -> TDummySumDictStateMetric:
        """Add ``v`` to the running sum stored under key ``k`` and return ``self``."""
        self.x[k] += v
        return self

    @torch.inference_mode()
    def compute(self: TDummySumDictStateMetric) -> torch.Tensor:
        """Return the state mapping itself (key -> accumulated tensor)."""
        # NOTE(review): annotated as Tensor, but this returns the dict state `x`.
        return self.x

    @torch.inference_mode()
    def merge_state(
        self: TDummySumDictStateMetric, metrics: Iterable[TDummySumDictStateMetric]
    ) -> TDummySumDictStateMetric:
        """
        Add every other metric's per-key sums into this metric's state.

        Args:
            metrics: metric instances whose ``x`` states are merged into ``self``.
        """
        for metric in metrics:
            # Walk the other metric's dict state `x`; the metric object itself
            # provides no `keys()` method.
            for k, v in metric.x.items():
                self.x[k] += v.to(self.device)
        return self
111 |
--------------------------------------------------------------------------------
/torcheval/metrics/ranking/reciprocal_rank.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
10 |
11 | from collections.abc import Iterable
12 | from typing import TypeVar
13 |
14 | import torch
15 |
16 | from torcheval.metrics.functional import reciprocal_rank
17 | from torcheval.metrics.metric import Metric
18 |
19 |
TReciprocalRank = TypeVar("TReciprocalRank")


class ReciprocalRank(Metric[torch.Tensor]):
    """
    Compute the reciprocal rank of the correct class among the top predicted classes.
    Its functional version is :func:`torcheval.metrics.functional.reciprocal_rank`.

    The per-sample scores of every ``update()`` call are kept, and ``compute()`` returns
    them concatenated in update order (they are not averaged).

    Args:
        k (int, optional): Number of top class probabilities to be considered. Samples
            whose target falls outside the top ``k`` get a score of 0. When ``None``
            (the default), all classes are considered.
        device (torch.device, optional): Device passed to the ``Metric`` base class;
            merged scores and the empty result are placed on it.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import ReciprocalRank

        >>> metric = ReciprocalRank()
        >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1]))
        >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0]))
        >>> metric.compute()
        tensor([1.0000, 0.3333, 0.3333, 0.5000])

        >>> metric = ReciprocalRank(k=2)
        >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1]))
        >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0]))
        >>> metric.compute()
        tensor([1.0000, 0.0000, 0.0000, 0.5000])
    """

    def __init__(
        self: TReciprocalRank,
        *,
        k: int | None = None,
        device: torch.device | None = None,
    ) -> None:
        super().__init__(device=device)
        self.k = k
        self._add_state("scores", [])

    @torch.inference_mode()
    # pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently.
    def update(
        self: TReciprocalRank, input: torch.Tensor, target: torch.Tensor
    ) -> TReciprocalRank:
        """
        Update the metric state with the ground truth labels and predictions.

        Args:
            input (Tensor): Predicted unnormalized scores (often referred to as logits) or
                class probabilities of shape (num_samples, num_classes).
            target (Tensor): Ground truth class indices of shape (num_samples,).
        """
        self.scores.append(reciprocal_rank(input, target, k=self.k))
        return self

    @torch.inference_mode()
    def compute(self: TReciprocalRank) -> torch.Tensor:
        """
        Return the concatenated reciprocal rank scores. If no ``update()`` calls are made before
        ``compute()`` is called, return an empty tensor on the metric's device.
        """
        if not self.scores:
            # Keep the empty result on the same device as merged results.
            return torch.empty(0, device=self.device)
        return torch.cat(self.scores, dim=0)

    @torch.inference_mode()
    def merge_state(
        self: TReciprocalRank, metrics: Iterable[TReciprocalRank]
    ) -> TReciprocalRank:
        """
        Merge the metric state with its counterparts from other metric instances.

        Args:
            metrics (Iterable[Metric]): metric instances whose states are to be merged.
        """
        for metric in metrics:
            if metric.scores:
                # Collapse each incoming list to one tensor on this metric's device.
                self.scores.append(torch.cat(metric.scores).to(self.device))
        return self

    @torch.inference_mode()
    def _prepare_for_merge_state(self: TReciprocalRank) -> None:
        # Concatenate the list state into a single tensor before it is shipped for merging.
        if self.scores:
            self.scores = [torch.cat(self.scores)]
104 |
--------------------------------------------------------------------------------
/tests/metrics/aggregation/test_sum.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 | from torcheval.metrics import Sum
12 | from torcheval.utils.test_utils.metric_class_tester import (
13 | BATCH_SIZE,
14 | MetricClassTester,
15 | NUM_TOTAL_UPDATES,
16 | )
17 |
18 |
class TestSum(MetricClassTester):
    def _test_sum_class_with_input(self, input_val_tensor: torch.Tensor) -> None:
        # Unweighted: the expected result is the plain total of every update, as float64.
        expected = torch.sum(input_val_tensor).to(torch.float64)
        self.run_class_implementation_tests(
            metric=Sum(),
            state_names={"weighted_sum"},
            update_kwargs={"input": input_val_tensor},
            compute_result=expected,
        )

    def test_sum_class_base(self) -> None:
        # Same test on inputs of increasing rank: (U, B), (U, B, 4), (U, B, 3, 4).
        for trailing_shape in ((), (4,), (3, 4)):
            self._test_sum_class_with_input(
                torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, *trailing_shape)
            )

    def test_sum_class_update_input_dimension_different(self) -> None:
        # Each update has a different shape; totals: 1 + 10 + 1 + 5 = 17.
        mixed_shape_inputs = [
            torch.tensor(1.0),
            torch.tensor([2.0, 3.0, 5.0]),
            torch.tensor([-1.0, 2.0]),
            torch.tensor([[1.0, 6.0], [2.0, -4.0]]),
        ]
        self.run_class_implementation_tests(
            metric=Sum(),
            state_names={"weighted_sum"},
            update_kwargs={"input": mixed_shape_inputs},
            compute_result=torch.tensor(17.0, dtype=torch.float64),
            num_total_updates=4,
            num_processes=2,
        )

    def test_sum_class_update_input_valid_weight(self) -> None:
        update_inputs = [
            torch.rand(BATCH_SIZE),
            torch.rand(BATCH_SIZE, 4),
            torch.rand(BATCH_SIZE, 3, 4),
            torch.rand(5),
            torch.rand(10),
        ]
        # Tensor weights match their input's shape; scalar weights apply to every element.
        update_weights = [
            torch.rand(BATCH_SIZE),
            torch.rand(BATCH_SIZE, 4),
            torch.rand(BATCH_SIZE, 3, 4),
            0.8,
            2,
        ]

        def _expected_weighted_sum(
            inputs: list[torch.Tensor],
            weights: list[float | torch.Tensor],
        ) -> torch.Tensor:
            total = torch.tensor(0.0, dtype=torch.float64)
            for value, weight in zip(inputs, weights):
                flat_weight = (
                    weight.numpy().flatten()
                    if isinstance(weight, torch.Tensor)
                    else weight
                )
                total += value.numpy().flatten().dot(flat_weight).sum()
            return total

        self.run_class_implementation_tests(
            metric=Sum(),
            state_names={"weighted_sum"},
            update_kwargs={
                "input": update_inputs,
                "weight": update_weights,
            },
            compute_result=_expected_weighted_sum(update_inputs, update_weights),
            num_total_updates=5,
            num_processes=5,
        )

    def test_sum_class_update_input_invalid_weight(self) -> None:
        # A weight tensor whose size differs from the input is rejected.
        metric = Sum()
        with self.assertRaisesRegex(
            ValueError,
            r"Weight must be either a float value or an int value or a tensor that matches the input tensor size.",
        ):
            metric.update(torch.tensor([2.0, 3.0]), weight=torch.tensor([0.5]))

    def test_sum_class_compute_without_update(self) -> None:
        # A fresh metric reports a float64 zero.
        self.assertEqual(Sum().compute(), torch.tensor(0.0, dtype=torch.float64))
103 |
--------------------------------------------------------------------------------
/torcheval/metrics/aggregation/throughput.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states.
10 |
11 | import logging
12 | from collections.abc import Iterable
13 | from typing import TypeVar
14 |
15 | import torch
16 |
17 | from torcheval.metrics.metric import Metric
18 |
TThroughput = TypeVar("TThroughput")

_logger: logging.Logger = logging.getLogger(__name__)


class Throughput(Metric[float]):
    """
    Calculate the throughput value, i.e. the number of elements processed per second.

    The metric accumulates the item counts and elapsed times reported to ``update()``;
    ``compute()`` returns their ratio as a Python ``float``.

    Note: In a distributed setting, it's recommended to use ``world_size * metric.compute()``
    to get an approximation of the total throughput, because ``sync_and_compute(metric)``
    requires a state sync across processes. The two values differ slightly: merging states
    (``merge_state``) sums the item counts of all processes but keeps only the longest
    elapsed time among them.

    Args:
        device (torch.device, optional): Device passed to the ``Metric`` base class.

    Examples::

        >>> import time
        >>> from torcheval.metrics import Throughput
        >>> metric = Throughput()
        >>> items_processed = 64
        >>> ts = time.monotonic()
        >>> time.sleep(2.0)  # simulate executing the program for 2 seconds
        >>> elapsed_time_sec = time.monotonic() - ts
        >>> metric.update(items_processed, elapsed_time_sec)
        >>> metric.compute()  # a float close to 32.0; the measured time is not exactly 2.0 s

        >>> metric = Throughput()
        >>> metric.update(64, 2.0)  # 64 items processed in exactly 2.0 seconds
        >>> metric.compute()
        32.0
    """

    def __init__(
        self: TThroughput,
        *,
        device: torch.device | None = None,
    ) -> None:
        super().__init__(device=device)
        # Plain Python floats, so compute() returns a float rather than a tensor.
        self._add_state("num_total", 0.0)
        self._add_state("elapsed_time_sec", 0.0)

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(
        self: TThroughput,
        num_processed: int,
        elapsed_time_sec: float,
    ) -> TThroughput:
        """
        Add ``num_processed`` items and ``elapsed_time_sec`` seconds to the running totals.

        Args:
            num_processed: Number of items processed
            elapsed_time_sec: Total elapsed time in seconds to process ``num_processed`` items
        Raises:
            ValueError:
                If ``num_processed`` is a negative number.
                If ``elapsed_time_sec`` is a non-positive number.
        """
        if num_processed < 0:
            raise ValueError(
                f"Expected num_processed to be a non-negative number, but received {num_processed}."
            )
        if elapsed_time_sec <= 0:
            raise ValueError(
                f"Expected elapsed_time_sec to be a positive number, but received {elapsed_time_sec}."
            )

        self.elapsed_time_sec += elapsed_time_sec
        self.num_total += num_processed
        return self

    @torch.inference_mode()
    def compute(self: TThroughput) -> float:
        """
        Return the accumulated item count divided by the accumulated elapsed time.

        If no time has been accumulated yet, log a warning and return 0.0 instead of
        dividing by zero.
        """
        if not self.elapsed_time_sec:
            _logger.warning("No calls to update() have been made - returning 0.0")
            return 0.0

        return self.num_total / self.elapsed_time_sec

    @torch.inference_mode()
    def merge_state(self: TThroughput, metrics: Iterable[TThroughput]) -> TThroughput:
        """
        Merge other throughput states into this one.

        Item counts are summed; the elapsed time becomes the maximum over all states.
        """
        for metric in metrics:
            self.num_total += metric.num_total
            # this assumes the metric is used within a fully-synchronous program.
            # In this scenario, the slowest process becomes the bottleneck for the
            # program's execution. As a result, we use the max, as the overall throughput
            # is gated based on the rank that takes the longest to complete.
            # TODO: should this be configurable?
            self.elapsed_time_sec = max(self.elapsed_time_sec, metric.elapsed_time_sec)
        return self
107 |
--------------------------------------------------------------------------------
/torcheval/metrics/functional/ranking/click_through_rate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 |
10 | import torch
11 |
12 |
@torch.inference_mode()
def click_through_rate(
    input: torch.Tensor,
    weights: torch.Tensor | float | int | None = None,
    *,
    num_tasks: int = 1,
) -> torch.Tensor:
    """
    Compute the (optionally weighted) click-through rate of a series of click events.
    Its class version is ``torcheval.metrics.ClickThroughRate``.

    Args:
        input (Tensor): Series of values representing user click (1) or skip (0)
            of shape (num_events) or (num_tasks, num_events).
        weights (Tensor, float or int, Optional): Weights for each event. A tensor must have
            the same shape as ``input``; a scalar applies the same weight to every event.
            ``None`` (default) weights every event by 1.0.
        num_tasks (int): Number of tasks whose click-through rate is computed. When it is
            greater than 1, ``input`` must have shape (num_tasks, num_events). Default value
            is 1.

    Returns:
        Tensor: ``sum(input * weights) / sum(weights)`` over the last dimension, i.e. a 0-d
        tensor for one-dimensional ``input`` and a tensor of shape (num_tasks,) otherwise.
        A task whose weights are all zero yields 0.0.

    Raises:
        ValueError: If ``input`` is not one or two dimensional, if a ``weights`` tensor's
            shape differs from ``input``'s, or if ``input``'s shape does not match ``num_tasks``.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import click_through_rate
        >>> input = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1])
        >>> click_through_rate(input)
        tensor(0.5000)
        >>> weights = torch.tensor([1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0])
        >>> click_through_rate(input, weights)
        tensor(0.5833)
        >>> input = torch.tensor([[0, 1, 0, 1], [1, 0, 0, 1]])
        >>> weights = torch.tensor([[1.0, 2.0, 1.0, 2.0],[1.0, 2.0, 1.0, 1.0]])
        >>> click_through_rate(input, weights, num_tasks=2)
        tensor([0.6667, 0.4000])
    """
    if weights is None:
        weights = 1.0
    click_total, weight_total = _click_through_rate_update(
        input, weights, num_tasks=num_tasks
    )
    return _click_through_rate_compute(click_total, weight_total)
52 |
53 |
def _click_through_rate_update(
    input: torch.Tensor,
    weights: torch.Tensor | float | int = 1.0,
    *,
    num_tasks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Validate one batch of click events and reduce it to per-task sums.

    Args:
        input: click (1) / skip (0) events of shape (num_events) or (num_tasks, num_events).
        weights: a tensor shaped like ``input`` or a scalar weight shared by every event.
        num_tasks: number of tasks encoded in ``input``.

    Returns:
        ``(click_total, weight_total)``: the weighted click sum and the weight sum over the
        last dimension, both as float tensors.
    """
    _click_through_rate_input_check(input, weights, num_tasks=num_tasks)

    if not isinstance(weights, torch.Tensor):
        # A scalar weight applies equally to every event along the last dimension.
        click_total = weights * input.sum(-1).type(torch.float)
        weight_total = weights * input.size(-1) * torch.ones_like(click_total)
        return click_total, weight_total

    float_weights = weights.type(torch.float)
    return (input * float_weights).sum(-1), float_weights.sum(-1)
70 |
71 |
72 | def _click_through_rate_compute(
73 | click_total: torch.Tensor,
74 | weight_total: torch.Tensor,
75 | ) -> torch.Tensor:
76 | # epsilon is a performant solution to divide by zero errors when weight_total = 0.0
77 | # Since click_total = input*weights, weights = 0.0 implies 0.0/(0.0 + eps) = 0.0
78 | eps = torch.finfo(weight_total.dtype).tiny
79 | return click_total / (weight_total + eps)
80 |
81 |
82 | def _click_through_rate_input_check(
83 | input: torch.Tensor,
84 | weights: torch.Tensor | float | int,
85 | *,
86 | num_tasks: int,
87 | ) -> None:
88 | if input.ndim != 1 and input.ndim != 2:
89 | raise ValueError(
90 | f"`input` should be a one or two dimensional tensor, got shape {input.shape}."
91 | )
92 | if isinstance(weights, torch.Tensor) and weights.shape != input.shape:
93 | raise ValueError(
94 | f"tensor `weights` should have the same shape as tensor `input`, got shapes {weights.shape} and {input.shape}, respectively."
95 | )
96 | if num_tasks == 1:
97 | if len(input.shape) > 1:
98 | raise ValueError(
99 | f"`num_tasks = 1`, `input` is expected to be one-dimensional tensor, but got shape ({input.shape})."
100 | )
101 | elif len(input.shape) == 1 or input.shape[0] != num_tasks:
102 | raise ValueError(
103 | f"`num_tasks = {num_tasks}`, `input`'s shape is expected to be ({num_tasks}, num_samples), but got shape ({input.shape})."
104 | )
105 |
--------------------------------------------------------------------------------