├── .flake8 ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── documentation.yml │ ├── feature-request.yml │ └── help-support.yml ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── build_docs.yaml │ ├── nightly_build_cpu.yaml │ ├── pre_commit.yaml │ ├── release_build.yaml │ ├── release_build_docs.yaml │ └── unit_test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── dev-requirements.txt ├── docs ├── Makefile ├── build_docs.sh ├── license_header.txt ├── requirements.txt ├── source │ ├── _static │ │ ├── css │ │ │ └── torcheval.css │ │ └── js │ │ │ └── torcheval.js │ ├── conf.py │ ├── ext │ │ └── fbcode.py │ ├── index.rst │ ├── metric_example.rst │ ├── templates │ │ └── layout.html │ ├── torcheval.metrics.functional.rst │ ├── torcheval.metrics.rst │ └── torcheval.metrics.toolkit.rst └── update_docs.py ├── examples ├── Introducing_TorchEval.ipynb ├── distributed_example.py └── simple_example.py ├── image-requirements.txt ├── pyproject.toml ├── requirements.txt ├── setup.py ├── tests ├── metrics │ ├── __init__.py │ ├── aggregation │ │ ├── __init__.py │ │ ├── test_auc.py │ │ ├── test_cat.py │ │ ├── test_cov.py │ │ ├── test_max.py │ │ ├── test_mean.py │ │ ├── test_min.py │ │ ├── test_sum.py │ │ └── test_throughput.py │ ├── audio │ │ └── test_fad.py │ ├── classification │ │ ├── __init__.py │ │ ├── test_accuracy.py │ │ ├── test_auprc.py │ │ ├── test_auroc.py │ │ ├── test_binned_auprc.py │ │ ├── test_binned_auroc.py │ │ ├── test_binned_precision_recall_curve.py │ │ ├── test_confusion_matrix.py │ │ ├── test_f1_score.py │ │ ├── test_normalized_entropy.py │ │ ├── test_precision.py │ │ ├── test_precision_recall_curve.py │ │ ├── test_recall.py │ │ └── test_recall_at_fixed_precision.py │ ├── functional │ │ ├── __init__.py │ │ ├── aggregation │ │ │ ├── __init__.py │ │ │ ├── test_auc.py │ │ │ ├── test_mean.py │ │ │ ├── test_sum.py │ │ │ └── test_throughput.py │ │ ├── classification │ │ │ ├── __init__.py │ │ 
│ ├── test_accuracy.py │ │ │ ├── test_auprc.py │ │ │ ├── test_auroc.py │ │ │ ├── test_binned_auprc.py │ │ │ ├── test_binned_auroc.py │ │ │ ├── test_binned_precision_recall_curve.py │ │ │ ├── test_confusion_matrix.py │ │ │ ├── test_f1_score.py │ │ │ ├── test_normalized_entropy.py │ │ │ ├── test_precision.py │ │ │ ├── test_precision_recall_curve.py │ │ │ ├── test_recall.py │ │ │ └── test_recall_at_fixed_precision.py │ │ ├── image │ │ │ ├── __init__.py │ │ │ └── test_psnr.py │ │ ├── ranking │ │ │ ├── __init__.py │ │ │ ├── test_click_through_rate.py │ │ │ ├── test_frequency.py │ │ │ ├── test_hit_rate.py │ │ │ ├── test_num_collisions.py │ │ │ ├── test_reciprocal_rank.py │ │ │ ├── test_retrieval_precision.py │ │ │ ├── test_retrieval_recall.py │ │ │ └── test_weighted_calibration.py │ │ ├── regression │ │ │ ├── __init__.py │ │ │ ├── test_mean_squared_error.py │ │ │ └── test_r2_score.py │ │ ├── statistical │ │ │ ├── __init__.py │ │ │ └── test_wasserstein.py │ │ └── text │ │ │ ├── __init__.py │ │ │ ├── test_bleu.py │ │ │ ├── test_perplexity.py │ │ │ ├── test_word_error_rate.py │ │ │ ├── test_word_information_lost.py │ │ │ └── test_word_information_preserved.py │ ├── image │ │ ├── test_fid.py │ │ ├── test_psnr.py │ │ └── test_ssim.py │ ├── ranking │ │ ├── __init__.py │ │ ├── test_click_through_rate.py │ │ ├── test_hit_rate.py │ │ ├── test_reciprocal_rank.py │ │ ├── test_retrieval_precision.py │ │ └── test_weighted_calibration.py │ ├── regression │ │ ├── __init__.py │ │ ├── test_mean_squared_error.py │ │ └── test_r2_score.py │ ├── statistical │ │ ├── __init__.py │ │ └── test_wasserstein.py │ ├── test_metric.py │ ├── test_synclib.py │ ├── test_toolkit.py │ ├── text │ │ ├── __init__.py │ │ ├── test_bleu.py │ │ ├── test_perplexity.py │ │ ├── test_word_error_rate.py │ │ ├── test_word_information_lost.py │ │ └── test_word_information_preserved.py │ └── window │ │ ├── test_auroc.py │ │ ├── test_click_through_rate.py │ │ ├── test_mean_squared_error.py │ │ ├── 
test_normalized_entropy.py │ │ └── test_weighted_calibration.py └── utils │ └── test_random_data.py └── torcheval ├── __init__.py ├── metrics ├── __init__.py ├── aggregation │ ├── __init__.py │ ├── auc.py │ ├── cat.py │ ├── cov.py │ ├── max.py │ ├── mean.py │ ├── min.py │ ├── sum.py │ └── throughput.py ├── audio │ ├── __init__.py │ └── fad.py ├── classification │ ├── __init__.py │ ├── accuracy.py │ ├── auprc.py │ ├── auroc.py │ ├── binary_normalized_entropy.py │ ├── binned_auprc.py │ ├── binned_auroc.py │ ├── binned_precision_recall_curve.py │ ├── confusion_matrix.py │ ├── f1_score.py │ ├── precision.py │ ├── precision_recall_curve.py │ ├── recall.py │ └── recall_at_fixed_precision.py ├── functional │ ├── __init__.py │ ├── aggregation │ │ ├── __init__.py │ │ ├── auc.py │ │ ├── mean.py │ │ ├── sum.py │ │ └── throughput.py │ ├── classification │ │ ├── __init__.py │ │ ├── accuracy.py │ │ ├── auprc.py │ │ ├── auroc.py │ │ ├── binary_normalized_entropy.py │ │ ├── binned_auprc.py │ │ ├── binned_auroc.py │ │ ├── binned_precision_recall_curve.py │ │ ├── confusion_matrix.py │ │ ├── f1_score.py │ │ ├── precision.py │ │ ├── precision_recall_curve.py │ │ ├── recall.py │ │ └── recall_at_fixed_precision.py │ ├── frechet.py │ ├── image │ │ ├── __init__.py │ │ └── psnr.py │ ├── ranking │ │ ├── __init__.py │ │ ├── click_through_rate.py │ │ ├── frequency.py │ │ ├── hit_rate.py │ │ ├── num_collisions.py │ │ ├── reciprocal_rank.py │ │ ├── retrieval_precision.py │ │ ├── retrieval_recall.py │ │ └── weighted_calibration.py │ ├── regression │ │ ├── __init__.py │ │ ├── mean_squared_error.py │ │ └── r2_score.py │ ├── statistical │ │ ├── __init__.py │ │ └── wasserstein.py │ ├── tensor_utils.py │ └── text │ │ ├── __init__.py │ │ ├── bleu.py │ │ ├── helper.py │ │ ├── perplexity.py │ │ ├── word_error_rate.py │ │ ├── word_information_lost.py │ │ └── word_information_preserved.py ├── image │ ├── __init__.py │ ├── fid.py │ ├── psnr.py │ └── ssim.py ├── metric.py ├── ranking │ ├── __init__.py │ ├── 
click_through_rate.py │ ├── hit_rate.py │ ├── reciprocal_rank.py │ ├── retrieval_precision.py │ ├── retrieval_recall.py │ └── weighted_calibration.py ├── regression │ ├── __init__.py │ ├── mean_squared_error.py │ └── r2_score.py ├── statistical │ ├── __init__.py │ └── wasserstein.py ├── synclib.py ├── text │ ├── __init__.py │ ├── bleu.py │ ├── perplexity.py │ ├── word_error_rate.py │ ├── word_information_lost.py │ └── word_information_preserved.py ├── toolkit.py └── window │ ├── __init__.py │ ├── auroc.py │ ├── click_through_rate.py │ ├── mean_squared_error.py │ ├── normalized_entropy.py │ └── weighted_calibration.py ├── py.typed ├── utils ├── __init__.py ├── random_data.py └── test_utils │ ├── __init__.py │ ├── dummy_metric.py │ └── metric_class_tester.py └── version.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # Suggested config from pytorch that we can adopt 3 | select = B,C,E,F,P,T4,W,B9 4 | max-line-length = 120 5 | # C408 ignored because we like the dict keyword argument syntax 6 | # E501 is not flexible enough, we're using B950 instead 7 | ignore = 8 | E203,E305,E402,E501,E704,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, 9 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying 10 | # to line this up with executable bit 11 | EXE001, 12 | optional-ascii-coding = True 13 | exclude = 14 | ./.git, 15 | ./docs 16 | ./build 17 | ./scripts, 18 | ./venv, 19 | *.pyi 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: 🐛 Bug Report 2 | description: Create a report to help us reproduce and fix the bug 3 | 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the 9 | existing and past 
issues](https://github.com/pytorch-labs/torcheval/issues?q=is%3Aissue+sort%3Acreated-desc+). 10 | - type: textarea 11 | attributes: 12 | label: 🐛 Describe the bug 13 | description: | 14 | Please provide a clear and concise description of what the bug is. 15 | 16 | If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example: 17 | 18 | ```python 19 | # All necessary imports at the beginning 20 | import torch 21 | import torcheval 22 | 23 | # A succinct reproducing example trimmed down to the essential parts 24 | 25 | ``` 26 | 27 | If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. 28 | 29 | Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. 30 | placeholder: | 31 | A clear and concise description of what the bug is. 32 | 33 | ```python 34 | Sample code to reproduce the problem 35 | ``` 36 | 37 | ``` 38 | The error message you got, with the full traceback. 39 | ``` 40 | validations: 41 | required: true 42 | - type: textarea 43 | attributes: 44 | label: Versions 45 | description: | 46 | Please run the following and paste the output below. Make sure the version numbers of all relevant packages (e.g. torch, torcheval, other domain packages) are included. 
47 | ```sh 48 | wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py 49 | # For security purposes, please check the contents of collect_env.py before running it. 50 | python collect_env.py 51 | ``` 52 | validations: 53 | required: true 54 | 55 | - type: markdown 56 | attributes: 57 | value: > 58 | Thanks for contributing 🎉! 59 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | name: 📚 Documentation 2 | description: Report an issue related to inline documentation 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: 📚 The doc issue 8 | description: > 9 | A clear and concise description of what content is an issue. 10 | validations: 11 | required: true 12 | - type: textarea 13 | attributes: 14 | label: Suggest a potential alternative/fix 15 | description: > 16 | Tell us how we could improve the documentation in this regard. 17 | - type: markdown 18 | attributes: 19 | value: > 20 | Thanks for contributing 🎉! 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: 🚀 Feature request 2 | description: Submit a proposal/request for a new feature 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: 🚀 The feature 8 | description: > 9 | A clear and concise description of the feature proposal 10 | validations: 11 | required: true 12 | - type: textarea 13 | attributes: 14 | label: Motivation, pitch 15 | description: > 16 | Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., 17 | *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link 18 | here too. 
19 | validations: 20 | required: true 21 | - type: textarea 22 | attributes: 23 | label: Alternatives 24 | description: > 25 | A description of any alternative solutions or features you've considered, if any. 26 | - type: textarea 27 | attributes: 28 | label: Additional context 29 | description: > 30 | Add any other context or screenshots about the feature request. 31 | - type: markdown 32 | attributes: 33 | value: > 34 | Thanks for contributing 🎉! 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/help-support.yml: -------------------------------------------------------------------------------- 1 | name: 📚 Help Support 2 | description: Do you need help/support? Send us your questions. 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: 📚 Question 8 | description: > 9 | Description of your question or what you need support with. 10 | validations: 11 | required: true 12 | - type: markdown 13 | attributes: 14 | value: > 15 | Thanks for contributing 🎉! 16 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Please read through our [contribution guide](https://github.com/pytorch-labs/torcheval/blob/main/CONTRIBUTING.md) prior to creating your pull request. 
2 | 3 | Summary: 4 | 5 | 6 | Test plan: 7 | 8 | 9 | Fixes #{issue number} 10 | 11 | -------------------------------------------------------------------------------- /.github/workflows/build_docs.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Update Docs 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | 7 | # Allow one concurrent deployment 8 | concurrency: 9 | group: "pages" 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | build_docs: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Check out repo 17 | uses: actions/checkout@v2 18 | - name: Setup conda env 19 | uses: conda-incubator/setup-miniconda@v2 20 | with: 21 | miniconda-version: "latest" 22 | activate-environment: test 23 | - name: Install dependencies 24 | shell: bash -l {0} 25 | run: | 26 | set -eux 27 | conda activate test 28 | conda install pytorch cpuonly -c pytorch-nightly 29 | pip install -r requirements.txt 30 | pip install -r dev-requirements.txt 31 | python setup.py sdist bdist_wheel 32 | pip install dist/*.whl 33 | - name: Build docs 34 | shell: bash -l {0} 35 | run: | 36 | set -eux 37 | conda activate test 38 | cd docs 39 | pip install -r requirements.txt 40 | make html 41 | cd .. 42 | - name: Deploy docs to Github pages 43 | uses: JamesIves/github-pages-deploy-action@v4.4.1 44 | with: 45 | branch: gh-pages # The branch the action should deploy to. 46 | folder: docs/build/html # The folder the action should deploy. 
47 | target-folder: main 48 | -------------------------------------------------------------------------------- /.github/workflows/nightly_build_cpu.yaml: -------------------------------------------------------------------------------- 1 | name: Push CPU Binary Nightly 2 | 3 | on: 4 | # run every day at 11:15am 5 | schedule: 6 | - cron: '15 11 * * *' 7 | # or manually trigger it 8 | workflow_dispatch: 9 | 10 | 11 | jobs: 12 | unit_tests: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: [3.8, 3.9, "3.10"] 17 | steps: 18 | - name: Check out repo 19 | uses: actions/checkout@v2 20 | - name: Setup conda env 21 | uses: conda-incubator/setup-miniconda@v2 22 | with: 23 | miniconda-version: "latest" 24 | activate-environment: test 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | shell: bash -l {0} 28 | run: | 29 | set -eux 30 | conda activate test 31 | conda install pytorch torchaudio torchvision cpuonly -c pytorch-nightly 32 | pip install -r requirements.txt 33 | pip install -r dev-requirements.txt 34 | python setup.py sdist bdist_wheel 35 | pip install dist/*.whl 36 | - name: Run unit tests 37 | shell: bash -l {0} 38 | run: | 39 | set -eux 40 | conda activate test 41 | pytest tests -vv 42 | # TODO figure out how to deduplicate steps 43 | upload_to_pypi: 44 | needs: unit_tests 45 | runs-on: ubuntu-latest 46 | steps: 47 | - name: Check out repo 48 | uses: actions/checkout@v2 49 | - name: Setup conda env 50 | uses: conda-incubator/setup-miniconda@v2 51 | with: 52 | miniconda-version: "latest" 53 | activate-environment: test 54 | python-version: "3.10" 55 | - name: Install dependencies 56 | shell: bash -l {0} 57 | run: | 58 | set -eux 59 | conda activate test 60 | conda install pytorch cpuonly -c pytorch-nightly 61 | pip install -r requirements.txt 62 | pip install -r dev-requirements.txt 63 | pip install --no-build-isolation -e ".[dev]" 64 | - name: Upload to PyPI 65 | shell: bash -l {0} 66 | env: 67 | 
PYPI_USER: ${{ secrets.PYPI_USER }} 68 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 69 | run: | 70 | set -eux 71 | conda activate test 72 | pip install twine 73 | python setup.py --nightly sdist bdist_wheel 74 | twine upload --username "$PYPI_USER" --password "$PYPI_TOKEN" dist/* --verbose 75 | -------------------------------------------------------------------------------- /.github/workflows/pre_commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | - uses: pre-commit/action@v3.0.0 15 | -------------------------------------------------------------------------------- /.github/workflows/release_build.yaml: -------------------------------------------------------------------------------- 1 | name: Push Release to PyPi 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | unit_tests: 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: [3.8, 3.9, "3.10"] 12 | steps: 13 | - name: Check out repo 14 | uses: actions/checkout@v2 15 | - name: Setup conda env 16 | uses: conda-incubator/setup-miniconda@v2 17 | with: 18 | miniconda-version: "latest" 19 | activate-environment: test 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | shell: bash -l {0} 23 | run: | 24 | set -eux 25 | conda activate test 26 | conda install pytorch torchaudio torchvision cpuonly -c pytorch-nightly 27 | pip install -r requirements.txt 28 | pip install -r dev-requirements.txt 29 | python setup.py sdist bdist_wheel 30 | pip install dist/*.whl 31 | - name: Run unit tests 32 | shell: bash -l {0} 33 | run: | 34 | set -eux 35 | conda activate test 36 | pytest tests -vv 37 | # TODO figure out how to deduplicate steps 38 | upload_to_pypi: 39 | needs: unit_tests 40 | runs-on: ubuntu-latest 41 | 
steps: 42 | - name: Check out repo 43 | uses: actions/checkout@v2 44 | - name: Setup conda env 45 | uses: conda-incubator/setup-miniconda@v2 46 | with: 47 | miniconda-version: "latest" 48 | activate-environment: test 49 | python-version: "3.10" 50 | - name: Install dependencies 51 | shell: bash -l {0} 52 | run: | 53 | set -eux 54 | conda activate test 55 | conda install pytorch cpuonly -c pytorch-nightly 56 | pip install -r requirements.txt 57 | pip install -r dev-requirements.txt 58 | pip install --no-build-isolation -e ".[dev]" 59 | - name: Upload to PyPI 60 | shell: bash -l {0} 61 | env: 62 | PYPI_USER: ${{ secrets.PYPI_USER_RELEASE }} 63 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN_RELEASE }} 64 | run: | 65 | set -eux 66 | conda activate test 67 | pip install twine 68 | python setup.py sdist bdist_wheel 69 | twine upload --username "$PYPI_USER" --password "$PYPI_TOKEN" dist/* --verbose 70 | -------------------------------------------------------------------------------- /.github/workflows/release_build_docs.yaml: -------------------------------------------------------------------------------- 1 | name: Build Docs for New Release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | RELEASE_TAG: 7 | description: 'Tag name for this release' 8 | required: true 9 | type: string 10 | DOCS_DIRECTORY: 11 | description: 'Directory which will store the compiled docs' 12 | required: true 13 | type: string 14 | 15 | # Allow one concurrent deployment 16 | concurrency: 17 | group: "pages" 18 | cancel-in-progress: true 19 | 20 | env: 21 | RELEASE_TAG: ${{ inputs.RELEASE_TAG }} 22 | DOCS_DIRECTORY: ${{ inputs.DOCS_DIRECTORY }} 23 | 24 | jobs: 25 | build_docs: 26 | runs-on: ubuntu-latest 27 | steps: 28 | - name: Check out repo 29 | uses: actions/checkout@v2 30 | with: 31 | ref: ${{ env.RELEASE_TAG }} 32 | - name: Setup conda env 33 | uses: conda-incubator/setup-miniconda@v2 34 | with: 35 | miniconda-version: "latest" 36 | activate-environment: test 37 | - name: Install dependencies 38 | 
shell: bash -l {0} 39 | run: | 40 | set -eux 41 | conda activate test 42 | conda install pytorch cpuonly -c pytorch-nightly 43 | pip install -r requirements.txt 44 | pip install -r dev-requirements.txt 45 | python setup.py sdist bdist_wheel 46 | pip install dist/*.whl 47 | - name: Build docs 48 | shell: bash -l {0} 49 | run: | 50 | set -eux 51 | conda activate test 52 | cd docs 53 | pip install -r requirements.txt 54 | make html 55 | cd .. 56 | - name: Deploy docs to Github pages 57 | uses: JamesIves/github-pages-deploy-action@v4.4.1 58 | with: 59 | branch: gh-pages # The branch the action should deploy to. 60 | folder: docs/build/html # The folder the action should deploy. 61 | target-folder: ${{ env.DOCS_DIRECTORY }} 62 | update_stable_link: 63 | needs: build_docs 64 | runs-on: ubuntu-latest 65 | steps: 66 | - name: Check out repo 67 | uses: actions/checkout@v2 68 | with: 69 | ref: gh-pages 70 | - name: Create symbolic link to latest release 71 | run: | 72 | ln -s ${{ env.DOCS_DIRECTORY }} stable 73 | - name: Add symbolic link to latest release 74 | run: | 75 | git add stable 76 | - name: Commit symbolic link 77 | run: | 78 | git commit -m "Update symbolic link to latest release" 79 | - name: Push changes 80 | uses: ad-m/github-push-action@0fafdd62b84042d49ec0cb92d9cac7f7ce4ec79e 81 | with: 82 | branch: gh-pages 83 | -------------------------------------------------------------------------------- /.github/workflows/unit_test.yaml: -------------------------------------------------------------------------------- 1 | name: unit test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | 8 | jobs: 9 | unit_tests: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: [3.8, 3.9] 14 | steps: 15 | - name: Check out repo 16 | uses: actions/checkout@v2 17 | - name: Setup conda env 18 | uses: conda-incubator/setup-miniconda@v2 19 | with: 20 | miniconda-version: "latest" 21 | activate-environment: test 22 | python-version: ${{ 
matrix.python-version }} 23 | - name: Install dependencies 24 | shell: bash -l {0} 25 | run: | 26 | set -eux 27 | pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu 28 | pip install -r requirements.txt 29 | pip install -r dev-requirements.txt 30 | pip install --no-build-isolation -e ".[dev]" 31 | - name: Run unit tests with coverage 32 | shell: bash -l {0} 33 | run: | 34 | set -eux 35 | pytest --cov=. --cov-report xml tests -vv 36 | - name: Upload Coverage to Codecov 37 | uses: codecov/codecov-action@v2 38 | 39 | gpu_unit_tests: 40 | runs-on: ${{ matrix.os }} 41 | strategy: 42 | matrix: 43 | os: [linux.8xlarge.nvidia.gpu] 44 | python-version: [3.8] 45 | cuda-tag: ["cu11"] 46 | steps: 47 | - name: Check out repo 48 | uses: actions/checkout@v2 49 | - name: Setup conda env 50 | uses: conda-incubator/setup-miniconda@v2 51 | with: 52 | miniconda-version: "latest" 53 | activate-environment: test 54 | python-version: ${{ matrix.python-version }} 55 | - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG 56 | uses: pytorch/test-infra/.github/actions/setup-nvidia@main 57 | - name: Display EC2 information 58 | shell: bash 59 | run: | 60 | set -euo pipefail 61 | function get_ec2_metadata() { 62 | # Pulled from instance metadata endpoint for EC2 63 | # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html 64 | category=$1 65 | curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" 66 | } 67 | echo "ami-id: $(get_ec2_metadata ami-id)" 68 | echo "instance-id: $(get_ec2_metadata instance-id)" 69 | echo "instance-type: $(get_ec2_metadata instance-type)" 70 | - name: Install dependencies 71 | shell: bash -l {0} 72 | run: | 73 | set -eux 74 | conda activate test 75 | pip install torch torchvision --extra-index-url 
https://download.pytorch.org/whl/cu117 76 | # Use stable fbgemm-gpu 77 | pip uninstall -y fbgemm-gpu-nightly 78 | pip install fbgemm-gpu==0.2.0 79 | pip install -r requirements.txt 80 | pip install -r dev-requirements.txt 81 | pip install --no-build-isolation -e ".[dev]" 82 | - name: Run unit tests with coverage 83 | shell: bash -l {0} 84 | run: | 85 | set -eux 86 | conda activate test 87 | pytest --timeout=60 --cov=. --cov-report xml -vv -rA -m "gpu_only or cpu_and_gpu" tests 88 | - name: Upload coverage to codecov 89 | uses: codecov/codecov-action@v2 90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | 156 | # MacOS 157 | .DS_Store 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | #.idea/ 165 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.1.0 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: check-ast 10 | - id: check-merge-conflict 11 | - id: check-added-large-files 12 | args: ['--maxkb=500'] 13 | - id: end-of-file-fixer 14 | exclude: '.*\.rst' 15 | 16 | - repo: https://github.com/Lucas-C/pre-commit-hooks 17 | rev: v1.1.7 18 | hooks: 19 | - id: insert-license 20 | files: \.py$ 21 | args: 22 | - --license-filepath 23 | - docs/license_header.txt 24 | 25 | - repo: https://github.com/pycqa/flake8 26 | rev: 6.1.0 27 | hooks: 28 | - id: flake8 29 | args: 30 | - --config=.flake8 31 | 32 | - repo: https://github.com/omnilib/ufmt 33 | rev: v2.5.1 34 | hooks: 35 | - id: ufmt 36 | additional_dependencies: 37 | - black == 24.2.0 38 | - usort == 1.0.2 39 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at <opensource-conduct@fb.com>. All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to torcheval 2 | We want to make contributing to this project as easy and transparent as 3 | possible.
4 | 5 | ## Development Installation 6 | To get the development installation with all the necessary dependencies for 7 | linting, testing, and building the documentation, run the following: 8 | ```bash 9 | git clone https://github.com/pytorch/torcheval 10 | cd torcheval 11 | pip install -r requirements.txt 12 | pip install -r dev-requirements.txt 13 | pip install -r docs/requirements.txt 14 | pip install --no-build-isolation -e ".[dev]" 15 | ``` 16 | 17 | ## Pull Requests 18 | We actively welcome your pull requests. 19 | 20 | 1. Create your branch from `main`. 21 | 2. If you've added code that should be tested, add tests. 22 | 3. If you've changed APIs, update the documentation. 23 | - To build docs 24 | ```bash 25 | cd docs; make html 26 | ``` 27 | - To view docs 28 | ```bash 29 | cd build/html; python -m http.server 30 | ``` 31 | 4. Ensure the test suite passes. 32 | - To run all tests 33 | ```bash 34 | python -m pytest tests/ 35 | ``` 36 | - To run a single test 37 | ```bash 38 | python -m pytest -v tests/metrics/test_metric.py::MetricBaseClassTest::test_add_state_invalid 39 | ``` 40 | 41 | 5. Make sure your code lints. 42 | ```bash 43 | pre-commit run --all-files 44 | ``` 45 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 46 | 47 | ## Contributor License Agreement ("CLA") 48 | In order to accept your pull request, we need you to submit a CLA. You only need 49 | to do this once to work on any of Meta's open source projects. 50 | 51 | Complete your CLA here: <https://code.facebook.com/cla> 52 | 53 | ## Issues 54 | We use GitHub issues to track public bugs. Please ensure your description is 55 | clear and has sufficient instructions to be able to reproduce the issue. 56 | 57 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 58 | disclosure of security bugs. In those cases, please go through the process 59 | outlined on that page and do not file a public issue.
60 | 61 | ## License 62 | By contributing to torcheval, you agree that your contributions will be licensed 63 | under the LICENSE file in the root directory of this source tree. 64 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For torcheval software 4 | 5 | Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Meta nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pre-commit 3 | pytest 4 | pytest-timeout 5 | pytest-cov 6 | Cython>=0.28.5 7 | scikit-learn>=0.22 8 | scikit-image==0.18.3 9 | torchtnt-nightly 10 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | python update_docs.py 21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | -------------------------------------------------------------------------------- /docs/build_docs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # This script shows how to build the docs. 9 | # 1. First ensure you have the installed requirements for torcheval in `requirements.txt` 10 | # 2. Then make sure you have installed the requirements inside `docs/requirements.txt` 11 | # 3. Finally cd into docs/ and source this script. Sphinx reads through the installed module 12 | # pull docstrings, so this script just installs the current version of torcheval on your 13 | # system before it builds the docs with `make html` 14 | cd .. || exit 15 | pip install --no-build-isolation . 16 | cd docs || exit 17 | make html 18 | -------------------------------------------------------------------------------- /docs/license_header.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) Meta Platforms, Inc. and affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the BSD-style license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==5.0.1 2 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 3 | -------------------------------------------------------------------------------- /docs/source/_static/css/torcheval.css: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #redirect-banner { 10 | border-bottom: 1px solid #e2e2e2; 11 | } 12 | #redirect-banner > p { 13 | margin: 0.8rem; 14 | text-align: center; 15 | } 16 | -------------------------------------------------------------------------------- /docs/source/_static/js/torcheval.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | const NETWORK_TEST_URL = 'https://staticdocs.thefacebook.com/ping'; 10 | fetch(NETWORK_TEST_URL).then(() => { 11 | $("#redirect-banner").prependTo("body").show(); 12 | }); 13 | -------------------------------------------------------------------------------- /docs/source/ext/fbcode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
class FbcodeDirective(SphinxDirective):
    """Sphinx directive whose content is emitted only inside an fbcode tree.

    The check is purely path based: the content is rendered when the current
    working directory contains the substring ``fbcode`` and dropped otherwise.
    """

    # Allow the directive to carry nested content.
    has_content = True

    def run(self):
        """Return the parsed child nodes, or an empty list outside fbcode."""
        inside_fbcode = "fbcode" in os.getcwd()
        if not inside_fbcode:
            return []

        # Parse the directive body into a throw-away section and hand back
        # only its children, so no wrapper node ends up in the document.
        container = nodes.section()
        container.document = self.state.document
        nested_parse_with_titles(self.state, self.content, container)
        return container.children


def setup(app):
    """Register the ``fbcode`` directive and return the extension metadata."""
    app.add_directive("fbcode", FbcodeDirective)

    metadata = {
        "version": "0.1",
        "parallel_read_safe": True,
        "parallel_write_safe": True,
    }
    return metadata
code-block:: console 26 | 27 | git clone https://github.com/pytorch/torcheval 28 | cd torcheval 29 | pip install -r requirements.txt 30 | python setup.py install 31 | 32 | Usage 33 | ----------------- 34 | 35 | TorchEval provides two interfaces to each metric. If you are working in a single process environment, it is simplest to use metrics from the ``functional`` submodule. These can be found in ``torcheval.metrics.functional``. 36 | 37 | .. code-block:: python 38 | 39 | from torcheval.metrics.functional import binary_f1_score 40 | predictions = model(inputs) 41 | f1_score = binary_f1_score(predictions, targets) 42 | 43 | We can use the same metric in the class based route, which provides tools that make computation simple in a multi-process setting. On a single device, you can use the class based metrics as follows: 44 | 45 | .. code-block:: python 46 | 47 | from torcheval.metrics import BinaryF1Score 48 | predictions = model(inputs) 49 | metric = BinaryF1Score() 50 | metric.update(predictions, targets) 51 | f1_score = metric.compute() 52 | 53 | In a multi-process setting, the data from each process must be synchronized to compute the metric across the full dataset. To do this, simply replace ``metric.compute()`` with ``sync_and_compute(metric)``: 54 | 55 | .. code-block:: python 56 | 57 | from torcheval.metrics import BinaryF1Score 58 | from torcheval.metrics.toolkit import sync_and_compute 59 | predictions = model(inputs) 60 | metric = BinaryF1Score() 61 | metric.update(predictions, targets) 62 | f1_score = sync_and_compute(metric) 63 | 64 | Read more about the class based method in the distributed example. 
65 | 66 | Further Reading 67 | ----------------- 68 | * Check out the guides explaining the compute example 69 | * Check out the distributed example 70 | * Check out how to make your own metric 71 | 72 | Indices and tables 73 | ================== 74 | 75 | * :ref:`genindex` 76 | * :ref:`modindex` 77 | * :ref:`search` 78 | 79 | Getting Started 80 | ------------------- 81 | .. fbcode:: 82 | 83 | .. toctree:: 84 | :maxdepth: 2 85 | :caption: Getting Started (Meta) 86 | :glob: 87 | 88 | meta/getting_started.rst 89 | 90 | .. toctree:: 91 | :maxdepth: 2 92 | :caption: Migration (Meta) 93 | :glob: 94 | 95 | meta/migrating_to_torcheval.rst 96 | 97 | TorchEval Tutorials 98 | ------------------- 99 | .. toctree:: 100 | :maxdepth: 2 101 | :caption: Examples: 102 | 103 | QuickStart Notebook 104 | metric_example.rst 105 | 106 | TorchEval API 107 | ----------------- 108 | 109 | .. toctree:: 110 | :maxdepth: 2 111 | :caption: Contents: 112 | 113 | torcheval.metrics.rst 114 | torcheval.metrics.functional.rst 115 | torcheval.metrics.toolkit.rst 116 | -------------------------------------------------------------------------------- /docs/source/templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | {%- block extrabody %} 4 | {% if not fbcode %} 5 | 11 | {% endif %} 12 | {%- endblock %} 13 | -------------------------------------------------------------------------------- /docs/source/torcheval.metrics.functional.rst: -------------------------------------------------------------------------------- 1 | Functional Metrics 2 | ================== 3 | 4 | .. automodule:: torcheval.metrics.functional 5 | 6 | Aggregation Metrics 7 | ------------------------------------------------------------------- 8 | 9 | .. 
autosummary:: 10 | :toctree: generated 11 | :nosignatures: 12 | 13 | auc 14 | mean 15 | sum 16 | throughput 17 | 18 | Classification Metrics 19 | ------------------------------------------------------------------- 20 | 21 | .. autosummary:: 22 | :toctree: generated 23 | :nosignatures: 24 | 25 | binary_accuracy 26 | binary_auprc 27 | binary_auroc 28 | binary_binned_auroc 29 | binary_binned_precision_recall_curve 30 | binary_confusion_matrix 31 | binary_f1_score 32 | binary_normalized_entropy 33 | binary_precision 34 | binary_precision_recall_curve 35 | binary_recall 36 | binary_recall_at_fixed_precision 37 | multiclass_accuracy 38 | multiclass_auprc 39 | multiclass_auroc 40 | multiclass_binned_auroc 41 | multiclass_binned_precision_recall_curve 42 | multiclass_confusion_matrix 43 | multiclass_f1_score 44 | multiclass_precision 45 | multiclass_precision_recall_curve 46 | multiclass_recall 47 | multilabel_accuracy 48 | multilabel_auprc 49 | multilabel_precision_recall_curve 50 | multilabel_recall_at_fixed_precision 51 | topk_multilabel_accuracy 52 | 53 | Image Metrics 54 | ------------------------------------------------------------------- 55 | 56 | .. autosummary:: 57 | :toctree: generated 58 | :nosignatures: 59 | 60 | peak_signal_noise_ratio 61 | 62 | Ranking Metrics 63 | ------------------------------------------------------------------- 64 | 65 | .. autosummary:: 66 | :toctree: generated 67 | :nosignatures: 68 | 69 | click_through_rate 70 | frequency_at_k 71 | hit_rate 72 | num_collisions 73 | reciprocal_rank 74 | weighted_calibration 75 | 76 | Regression Metrics 77 | ------------------------------------------------------------------- 78 | 79 | .. autosummary:: 80 | :toctree: generated 81 | :nosignatures: 82 | 83 | mean_squared_error 84 | r2_score 85 | 86 | Text Metrics 87 | ------------------------------------------------------------------- 88 | 89 | .. 
autosummary:: 90 | :toctree: generated 91 | :nosignatures: 92 | 93 | bleu_score 94 | perplexity 95 | word_error_rate 96 | word_information_preserved 97 | word_information_lost 98 | -------------------------------------------------------------------------------- /docs/source/torcheval.metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics 2 | ============= 3 | 4 | .. automodule:: torcheval.metrics 5 | 6 | 7 | Aggregation Metrics 8 | ------------------------------------------------------------------- 9 | 10 | .. autosummary:: 11 | :toctree: generated 12 | :nosignatures: 13 | 14 | AUC 15 | Cat 16 | Max 17 | Mean 18 | Min 19 | Sum 20 | Throughput 21 | 22 | Audio Metrics 23 | ------------------------------------------------------------------- 24 | 25 | .. autosummary:: 26 | :toctree: generated 27 | :nosignatures: 28 | 29 | FrechetAudioDistance 30 | 31 | Classification Metrics 32 | ------------------------------------------------------------------- 33 | 34 | .. autosummary:: 35 | :toctree: generated 36 | :nosignatures: 37 | 38 | BinaryAccuracy 39 | BinaryAUPRC 40 | BinaryAUROC 41 | BinaryBinnedAUROC 42 | BinaryBinnedPrecisionRecallCurve 43 | BinaryConfusionMatrix 44 | BinaryF1Score 45 | BinaryNormalizedEntropy 46 | BinaryPrecision 47 | BinaryPrecisionRecallCurve 48 | BinaryRecall 49 | BinaryRecallAtFixedPrecision 50 | MulticlassAccuracy 51 | MulticlassAUPRC 52 | MulticlassAUROC 53 | MulticlassBinnedAUROC 54 | MulticlassBinnedPrecisionRecallCurve 55 | MulticlassConfusionMatrix 56 | MulticlassF1Score 57 | MulticlassPrecision 58 | MulticlassPrecisionRecallCurve 59 | MulticlassRecall 60 | MultilabelAccuracy 61 | MultilabelAUPRC 62 | MultilabelPrecisionRecallCurve 63 | MultilabelRecallAtFixedPrecision 64 | TopKMultilabelAccuracy 65 | 66 | Image Metrics 67 | ------------------------------------------------------------------- 68 | 69 | .. 
autosummary:: 70 | :toctree: generated 71 | :nosignatures: 72 | 73 | FrechetInceptionDistance 74 | PeakSignalNoiseRatio 75 | StructuralSimilarity 76 | 77 | Ranking Metrics 78 | ------------------------------------------------------------------- 79 | 80 | .. autosummary:: 81 | :toctree: generated 82 | :nosignatures: 83 | 84 | ClickThroughRate 85 | HitRate 86 | ReciprocalRank 87 | WeightedCalibration 88 | 89 | Regression Metrics 90 | ------------------------------------------------------------------- 91 | 92 | .. autosummary:: 93 | :toctree: generated 94 | :nosignatures: 95 | 96 | MeanSquaredError 97 | R2Score 98 | 99 | Text Metrics 100 | ------------------------------------------------------------------- 101 | 102 | .. autosummary:: 103 | :toctree: generated 104 | :nosignatures: 105 | 106 | BLEUScore 107 | Perplexity 108 | WordErrorRate 109 | WordInformationLost 110 | WordInformationPreserved 111 | 112 | Windowed Metrics 113 | ------------------------------------------------------------------- 114 | 115 | .. autosummary:: 116 | :toctree: generated 117 | :nosignatures: 118 | 119 | WindowedBinaryAUROC 120 | WindowedBinaryNormalizedEntropy 121 | WindowedClickThroughRate 122 | WindowedMeanSquaredError 123 | WindowedWeightedCalibration 124 | -------------------------------------------------------------------------------- /docs/source/torcheval.metrics.toolkit.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: torcheval.metrics.toolkit 2 | 3 | Metric Toolkit 4 | ================== 5 | 6 | .. automodule:: torcheval.metrics.toolkit 7 | :members: 8 | :undoc-members: 9 | -------------------------------------------------------------------------------- /examples/simple_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
NUM_EPOCHS = 4
NUM_BATCHES = 16
BATCH_SIZE = 8


class Model(torch.nn.Module):
    """Three ``Linear`` layers (128 -> 64 -> 32 -> 2) with ``ReLU`` in between."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 2),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``X`` and return its output."""
        return self.layers(X)


def prepare_dataloader() -> torch.utils.data.DataLoader:
    """Build a loader over random data.

    The dataset holds ``NUM_BATCHES * BATCH_SIZE`` samples of 128 random
    features, each with a random integer label in ``{0, 1}``, and is served
    in batches of ``BATCH_SIZE``.
    """
    num_samples = NUM_BATCHES * BATCH_SIZE
    data = torch.randn(num_samples, 128)
    labels = torch.randint(low=0, high=2, size=(num_samples,))
    return torch.utils.data.DataLoader(
        TensorDataset(data, labels), batch_size=BATCH_SIZE
    )


def main() -> None:
    """Train ``Model`` for ``NUM_EPOCHS`` epochs and print running accuracy.

    The metric is updated on every batch, reported every
    ``compute_frequency`` batches together with the current loss, and reset
    at the end of each epoch.
    """
    torch.random.manual_seed(42)

    model = Model()
    optim = torch.optim.Adagrad(model.parameters(), lr=0.001)

    train_dataloader = prepare_dataloader()

    loss_fn = torch.nn.CrossEntropyLoss()
    metric = MulticlassAccuracy()

    compute_frequency = 4

    for epoch in range(NUM_EPOCHS):
        # Iterate the loader directly instead of calling next() inside a broad
        # try/except: a StopIteration raised by the training step itself can
        # then no longer be mistaken for the end of the epoch.
        for batch_idx, (inputs, target) in enumerate(train_dataloader):
            output = model(inputs)

            # metric.update() updates the metric state with new data
            metric.update(output, target)

            loss = loss_fn(output, target)
            optim.zero_grad()
            loss.backward()
            optim.step()

            if (batch_idx + 1) % compute_frequency == 0:
                # metric.compute() returns metric value from all seen data
                accuracy = metric.compute()
                print(
                    f"Epoch {epoch + 1}/{NUM_EPOCHS}, "
                    f"Batch {batch_idx + 1}/{NUM_BATCHES} --- "
                    f"loss: {loss.item():.4f}, acc: {accuracy:.4f}"
                )

        # metric.reset() cleans up all seen data
        metric.reset()


if __name__ == "__main__":
    main()  # pragma: no cover
def current_path(file_name: str) -> str:
    """Return the absolute path of ``file_name`` relative to this script's directory."""
    return os.path.abspath(os.path.join(__file__, os.path.pardir, file_name))


def read_requirements(file_name: str) -> list[str]:
    """Return the requirement specifiers listed in ``file_name``.

    The file is read one line at a time, so a specifier that contains spaces
    (``black == 24.2.0``, ``-e git+https://...``) stays a single entry. Blank
    lines, ``#`` comment lines and trailing `` # ...`` comments are skipped.

    Args:
        file_name: requirements file, resolved with :func:`current_path`.

    Returns:
        One string per requirement, in file order.
    """
    requirements = []
    with open(current_path(file_name), encoding="utf8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            # Only " #" starts a trailing comment; a bare "#" may belong to a
            # URL fragment such as "#egg=name" and must be kept.
            requirements.append(line.partition(" #")[0].rstrip())
    return requirements


def get_nightly_version() -> str:
    """Return today's date formatted as ``YYYY.MM.DD`` for nightly builds."""
    return date.today().strftime("%Y.%m.%d")


def parse_args() -> tuple[argparse.Namespace, list[str]]:
    """Split ``sys.argv`` into torcheval options and setuptools arguments.

    Returns:
        A ``(namespace, remaining)`` pair: ``namespace.nightly`` is ``True``
        when ``--nightly`` was given, and ``remaining`` holds every argument
        this parser did not recognise.
    """
    parser = argparse.ArgumentParser(description="torcheval setup")
    parser.add_argument(
        "--nightly",
        dest="nightly",
        action="store_true",
        help="enable settings for nightly package build",
    )
    parser.set_defaults(nightly=False)
    return parser.parse_known_args()


if __name__ == "__main__":
    with open(current_path("README.md"), encoding="utf8") as f:
        readme = f.read()

    custom_args, setup_args = parse_args()
    package_name = "torcheval" if not custom_args.nightly else "torcheval-nightly"
    version = __version__ if not custom_args.nightly else get_nightly_version()
    print(f"using package_name={package_name}, version={version}")

    # Hide the custom flags from setuptools, which parses sys.argv itself.
    sys.argv = [sys.argv[0]] + setup_args

    setup(
        name=package_name,
        version=version,
        author="torcheval team",
        author_email="yicongd@fb.com",
        description="A library for providing a simple interface to create new metrics and an easy-to-use toolkit for metric computations and checkpointing.",
        long_description=readme,
        long_description_content_type="text/markdown",
        url="https://github.com/pytorch/torcheval",
        license="BSD-3",
        keywords=["pytorch", "evaluation", "metrics"],
        # NOTE(review): the built-in generic annotations used in this file
        # (list[str], tuple[...]) are evaluated at definition time and need
        # Python 3.9+, while 3.7 is still advertised here - confirm the floor.
        python_requires=">=3.7",
        install_requires=read_requirements("requirements.txt"),
        packages=find_packages(),
        package_data={"torcheval": ["py.typed"]},
        # py.typed is shipped for type checkers, which need the package
        # unpacked on disk rather than inside a zipped egg.
        zip_safe=False,
        classifiers=[
            "Development Status :: 2 - Pre-Alpha",
            "Intended Audience :: Developers",
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: BSD License",
            "Programming Language :: Python :: 3",
            "Programming Language :: Python :: 3.7",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
        ],
        extras_require={
            "dev": read_requirements("dev-requirements.txt"),
            "image": read_requirements("image-requirements.txt"),
        },
    )
class TestCovariance(MetricClassTester):
    """Class-implementation checks for the ``Covariance`` metric."""

    def _test_covariance_with_input(self, batching: list[int]) -> None:
        """Run the shared tester on seeded random data split per ``batching``.

        ``sum(batching)`` rows of 4 features are drawn from a generator seeded
        with 3 and split into one update per entry of ``batching``. The
        expected result handed to the tester is the column mean together with
        ``torch.cov`` of the full sample.
        """
        rng = torch.Generator().manual_seed(3)
        samples = torch.randn(sum(batching), 4, generator=rng)
        per_update_obs = torch.split(samples, batching, dim=0)
        expected = (samples.mean(dim=0), torch.cov(samples.T))

        self.run_class_implementation_tests(
            metric=Covariance(),
            state_names={"n", "sum", "ss_sum"},
            update_kwargs={"obs": per_update_obs},
            compute_result=expected,
            num_total_updates=len(batching),
            min_updates_before_compute=1,
            num_processes=4,
        )

    def test_covariance_all_at_once(self) -> None:
        """Four updates of 100 rows each."""
        self._test_covariance_with_input([100] * 4)

    def test_covariance_one_by_one(self) -> None:
        """Twenty updates whose sizes grow from 2 to 21 rows."""
        self._test_covariance_with_input([*range(2, 22)])
class TestMax(MetricClassTester):
    """Class-implementation checks for the ``Max`` aggregation metric."""

    def _test_max_class_with_input(self, input_val_tensor: torch.Tensor) -> None:
        """Run the shared tester with ``torch.max`` of the input as the expected value."""
        expected = torch.max(input_val_tensor)
        self.run_class_implementation_tests(
            metric=Max(),
            state_names={"max"},
            update_kwargs={"input": input_val_tensor},
            compute_result=expected,
        )

    def test_max_class_base(self) -> None:
        """Random inputs with zero, one and two trailing dimensions."""
        for trailing_dims in ((), (4,), (3, 4)):
            sample = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, *trailing_dims)
            self._test_max_class_with_input(sample)

    def test_max_class_update_input_dimension_different(self) -> None:
        """Updates of differing rank; the largest value across them is 6.0."""
        updates = [
            torch.tensor(1.0),
            torch.tensor([2.0, 3.0, 5.0]),
            torch.tensor([-1.0, 2.0]),
            torch.tensor([[1.0, 6.0], [2.0, -4.0]]),
        ]
        self.run_class_implementation_tests(
            metric=Max(),
            state_names={"max"},
            update_kwargs={"input": updates},
            compute_result=torch.tensor(6.0),
            num_total_updates=4,
            num_processes=2,
        )
# pyre-strict

import torch
from torcheval.metrics import Min
from torcheval.utils.test_utils.metric_class_tester import (
    BATCH_SIZE,
    MetricClassTester,
    NUM_TOTAL_UPDATES,
)


class TestMin(MetricClassTester):
    """Class-level tests for the ``Min`` metric."""

    def _test_min_class_with_input(self, input_val_tensor: torch.Tensor) -> None:
        """Run the shared harness on ``Min``, using ``torch.min`` of the whole input as reference."""
        self.run_class_implementation_tests(
            metric=Min(),
            state_names={"min"},
            update_kwargs={"input": input_val_tensor},
            compute_result=torch.min(input_val_tensor),
        )

    def test_min_class_base(self) -> None:
        """Exercise ``Min`` on uniformly random inputs with 0, 1 and 2 trailing dimensions."""
        for trailing_dims in ((), (4,), (3, 4)):
            random_input = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, *trailing_dims)
            self._test_min_class_with_input(random_input)

    def test_min_class_update_input_dimension_different(self) -> None:
        """Updates of differing rank and length still produce the overall minimum."""
        mixed_rank_updates = [
            torch.tensor(1.0),
            torch.tensor([2.0, 3.0, 5.0]),
            torch.tensor([-1.0, 2.0]),
            torch.tensor([[1.0, 6.0], [2.0, -4.0]]),
        ]
        self.run_class_implementation_tests(
            metric=Min(),
            state_names={"min"},
            update_kwargs={"input": mixed_rank_updates},
            compute_result=torch.tensor(-4.0),
            num_total_updates=4,
            num_processes=2,
        )

# --------------------------------------------------------------------------------
# /tests/metrics/aggregation/test_sum.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict


import torch
from torcheval.metrics import Sum
from torcheval.utils.test_utils.metric_class_tester import (
    BATCH_SIZE,
    MetricClassTester,
    NUM_TOTAL_UPDATES,
)


class TestSum(MetricClassTester):
    """Class-level tests for the ``Sum`` metric."""

    def _test_sum_class_with_input(self, input_val_tensor: torch.Tensor) -> None:
        """Run the shared harness on ``Sum``; the reference is ``torch.sum`` cast to float64."""
        expected = torch.sum(input_val_tensor).to(torch.float64)
        self.run_class_implementation_tests(
            metric=Sum(),
            state_names={"weighted_sum"},
            update_kwargs={"input": input_val_tensor},
            compute_result=expected,
        )

    def test_sum_class_base(self) -> None:
        """Exercise ``Sum`` on uniformly random inputs with 0, 1 and 2 trailing dimensions."""
        for trailing_dims in ((), (4,), (3, 4)):
            random_input = torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, *trailing_dims)
            self._test_sum_class_with_input(random_input)

    def test_sum_class_update_input_dimension_different(self) -> None:
        """Updates of differing rank and length add up to the total of all elements."""
        mixed_rank_updates = [
            torch.tensor(1.0),
            torch.tensor([2.0, 3.0, 5.0]),
            torch.tensor([-1.0, 2.0]),
            torch.tensor([[1.0, 6.0], [2.0, -4.0]]),
        ]
        self.run_class_implementation_tests(
            metric=Sum(),
            state_names={"weighted_sum"},
            update_kwargs={"input": mixed_rank_updates},
            compute_result=torch.tensor(17.0, dtype=torch.float64),
            num_total_updates=4,
            num_processes=2,
        )

    def test_sum_class_update_input_valid_weight(self) -> None:
        """Weighted updates accept same-shaped tensor weights as well as float and int scalars."""
        update_inputs = [
            torch.rand(BATCH_SIZE),
            torch.rand(BATCH_SIZE, 4),
            torch.rand(BATCH_SIZE, 3, 4),
            torch.rand(5),
            torch.rand(10),
        ]
        update_weights: list[float | torch.Tensor] = [
            torch.rand(BATCH_SIZE),
            torch.rand(BATCH_SIZE, 4),
            torch.rand(BATCH_SIZE, 3, 4),
            0.8,
            2,
        ]

        # Reference: accumulate the weighted sum of every update via NumPy,
        # in float64 to match the dtype of the other expected values here.
        expected = torch.tensor(0.0, dtype=torch.float64)
        for values, weight in zip(update_inputs, update_weights):
            if isinstance(weight, torch.Tensor):
                weight = weight.numpy().flatten()
            expected += values.numpy().flatten().dot(weight).sum()

        self.run_class_implementation_tests(
            metric=Sum(),
            state_names={"weighted_sum"},
            update_kwargs={
                "input": update_inputs,
                "weight": update_weights,
            },
            compute_result=expected,
            num_total_updates=5,
            num_processes=5,
        )

    def test_sum_class_update_input_invalid_weight(self) -> None:
        """A tensor weight whose size differs from the input raises ``ValueError``."""
        metric = Sum()
        expected_message = r"Weight must be either a float value or an int value or a tensor that matches the input tensor size."
        with self.assertRaisesRegex(ValueError, expected_message):
            metric.update(torch.tensor([2.0, 3.0]), weight=torch.tensor([0.5]))

    def test_sum_class_compute_without_update(self) -> None:
        """``compute()`` before any update returns a float64 zero."""
        fresh_metric = Sum()
        zero = torch.tensor(0.0, dtype=torch.float64)
        self.assertEqual(fresh_metric.compute(), zero)

# --------------------------------------------------------------------------------
# /tests/metrics/aggregation/test_throughput.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import random

from torcheval.metrics import Throughput
from torcheval.utils.test_utils.metric_class_tester import (
    MetricClassTester,
    NUM_PROCESSES,
    NUM_TOTAL_UPDATES,
)


class TestThroughput(MetricClassTester):
    """Class-level tests for the ``Throughput`` metric."""

    def _test_throughput_class_with_input(
        self,
        num_processed: list[int],
        elapsed_time_sec: list[float],
    ) -> None:
        """Run the shared harness on ``Throughput`` with per-update counts and durations.

        Two reference values are supplied: items divided by the total elapsed
        time, and, as ``merge_and_compute_result``, items divided by the
        largest per-process elapsed time.
        """
        updates_per_process = NUM_TOTAL_UPDATES // NUM_PROCESSES
        total_items = sum(num_processed)

        # Each process is assumed to receive a contiguous run of
        # `updates_per_process` updates -- TODO confirm against the harness.
        slowest_process_time = max(
            sum(
                elapsed_time_sec[
                    rank * updates_per_process : (rank + 1) * updates_per_process
                ]
            )
            for rank in range(NUM_PROCESSES)
        )
        overall_time = sum(elapsed_time_sec)

        self.run_class_implementation_tests(
            metric=Throughput(),
            state_names={"num_total", "elapsed_time_sec"},
            update_kwargs={
                "num_processed": num_processed,
                "elapsed_time_sec": elapsed_time_sec,
            },
            compute_result=float(total_items) / overall_time,
            merge_and_compute_result=float(total_items) / slowest_process_time,
        )

    def test_throughput_class_base(self) -> None:
        """Random counts in [0, 40] and durations in [0.1, 5.0] seconds."""
        counts = [random.randint(0, 40) for _ in range(NUM_TOTAL_UPDATES)]
        durations = [random.uniform(0.1, 5.0) for _ in range(NUM_TOTAL_UPDATES)]
        self._test_throughput_class_with_input(counts, durations)

    def test_throughput_class_update_input_invalid_num_processed(self) -> None:
        """A negative ``num_processed`` raises ``ValueError``."""
        metric = Throughput()
        expected_message = (
            r"Expected num_processed to be a non-negative number, but received"
        )
        with self.assertRaisesRegex(ValueError, expected_message):
            metric.update(-1, 1.0)

    def test_throughput_class_update_input_invalid_elapsed_time_sec(self) -> None:
        """Negative and zero ``elapsed_time_sec`` both raise ``ValueError``."""
        metric = Throughput()
        expected_message = (
            r"Expected elapsed_time_sec to be a positive number, but received"
        )
        for bad_elapsed in (-5.1, 0.0):
            with self.assertRaisesRegex(ValueError, expected_message):
                metric.update(42, bad_elapsed)

    def test_throughput_class_compute_without_update(self) -> None:
        """``compute()`` before any update returns 0.0."""
        fresh_metric = Throughput()
        self.assertEqual(fresh_metric.compute(), 0.0)

# --------------------------------------------------------------------------------
# /tests/metrics/classification/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# --------------------------------------------------------------------------------
# /tests/metrics/functional/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# --------------------------------------------------------------------------------
# /tests/metrics/functional/aggregation/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# --------------------------------------------------------------------------------
# /tests/metrics/functional/aggregation/test_mean.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import numpy as np
import torch
from torcheval.metrics.functional.aggregation import mean
from torcheval.utils.test_utils.metric_class_tester import BATCH_SIZE, NUM_TOTAL_UPDATES


class TestMean(unittest.TestCase):
    """Tests for the functional ``mean`` aggregation."""

    def _test_mean_with_input(
        self,
        val: torch.Tensor,
        weight: float | torch.Tensor = 1.0,
    ) -> None:
        """Assert that unweighted ``mean(val)`` agrees with ``torch.mean(val)``.

        NOTE(review): ``weight`` is accepted but never forwarded to ``mean``;
        it is kept only so the helper's signature stays unchanged.
        """
        torch.testing.assert_close(
            mean(val),
            torch.mean(val),
            equal_nan=True,
            atol=1e-8,
            rtol=1e-5,
        )

    def test_mean_base(self) -> None:
        """Unweighted mean on random inputs with 0, 1 and 2 trailing dimensions."""
        for trailing_dims in ((), (4,), (3, 4)):
            self._test_mean_with_input(
                torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, *trailing_dims)
            )

    def test_mean_input_valid_weight(self) -> None:
        """Weighted mean matches ``np.average`` for tensor, float and int weights."""

        def _compute_result(
            val: torch.Tensor, weights: float | torch.Tensor
        ) -> torch.Tensor:
            """Reference weighted mean computed with NumPy, returned as float32."""
            # pyre-fixme[9]: val has type `Tensor`; used as `ndarray[Any, Any]`.
            val = val.numpy().flatten()
            if isinstance(weights, torch.Tensor):
                weights = weights.numpy().flatten()
            else:
                # Expand a scalar weight to one weight per element.
                weights = weights * np.ones_like(val)
            weighted_mean = np.average(val, weights=weights)
            return torch.tensor(weighted_mean, dtype=torch.float32)

        inputs = [
            torch.rand(1),
            torch.rand(BATCH_SIZE, 4),
            torch.rand(BATCH_SIZE, 3, 4),
            torch.rand(5),
            torch.rand(10),
        ]
        weights = [
            torch.rand(1),
            torch.rand(BATCH_SIZE, 4),
            torch.rand(BATCH_SIZE, 3, 4),
            0.8,
            1,
        ]

        # Leftover debug prints were removed: they dumped every tensor to
        # stdout on each run. The loop variable no longer shadows `input`.
        for val, weight in zip(inputs, weights):
            torch.testing.assert_close(
                mean(val, weight),
                _compute_result(val, weight),
                equal_nan=True,
                atol=1e-8,
                rtol=1e-5,
            )

    def test_mean_input_invalid_weight(self) -> None:
        """A tensor weight whose size differs from the input raises ``ValueError``."""
        with self.assertRaisesRegex(
            ValueError,
            r"Weight must be either a float value or a tensor that matches the input tensor size.",
        ):
            mean(torch.tensor([2.0, 3.0]), torch.tensor([0.5]))

# --------------------------------------------------------------------------------
# /tests/metrics/functional/aggregation/test_sum.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import unittest

import torch
from torcheval.metrics.functional import sum
from torcheval.utils.test_utils.metric_class_tester import BATCH_SIZE, NUM_TOTAL_UPDATES


class TestSum(unittest.TestCase):
    """Tests for the functional ``sum`` aggregation (imported over the builtin ``sum``)."""

    def _test_sum_with_input(
        self,
        val: torch.Tensor,
        weight: float | torch.Tensor = 1.0,
    ) -> None:
        """Assert that unweighted ``sum(val)`` agrees with ``torch.sum(val)``.

        NOTE(review): ``weight`` is accepted but never forwarded to ``sum``;
        it is kept only so the helper's signature stays unchanged.
        """
        torch.testing.assert_close(
            sum(val),
            torch.sum(val),
            equal_nan=True,
            atol=1e-8,
            rtol=1e-5,
        )

    def test_sum_base(self) -> None:
        """Unweighted sum on random inputs with 0, 1 and 2 trailing dimensions."""
        for trailing_dims in ((), (4,), (3, 4)):
            self._test_sum_with_input(
                torch.rand(NUM_TOTAL_UPDATES, BATCH_SIZE, *trailing_dims)
            )

    def test_sum_input_valid_weight(self) -> None:
        """Weighted sum matches a NumPy dot-product reference for tensor, float and int weights."""

        def _compute_result(
            val: torch.Tensor, weight: float | torch.Tensor
        ) -> torch.Tensor:
            """Reference weighted sum computed with NumPy."""
            weighted_sum = torch.tensor(0.0)
            if isinstance(weight, torch.Tensor):
                weight = weight.numpy().flatten()
            weighted_sum += val.numpy().flatten().dot(weight).sum()

            return weighted_sum

        inputs = [
            torch.rand(1),
            torch.rand(BATCH_SIZE, 4),
            torch.rand(BATCH_SIZE, 3, 4),
            torch.rand(5),
            torch.rand(10),
        ]
        weights = [
            torch.rand(1),
            torch.rand(BATCH_SIZE, 4),
            torch.rand(BATCH_SIZE, 3, 4),
            0.8,
            2,
        ]

        for val, weight in zip(inputs, weights):
            torch.testing.assert_close(
                sum(val, weight),
                _compute_result(val, weight),
                equal_nan=True,
                atol=1e-8,
                rtol=1e-5,
            )

    def test_sum_input_invalid_weight(self) -> None:
        """A tensor weight whose size differs from the input raises ``ValueError``."""
        with self.assertRaisesRegex(
            ValueError,
            r"Weight must be either a float value or an int value or a tensor that matches the input tensor size.",
        ):
            sum(torch.tensor([2.0, 3.0]), torch.tensor([0.5]))

# --------------------------------------------------------------------------------
# /tests/metrics/functional/aggregation/test_throughput.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import random
import unittest

import torch
from torcheval.metrics.functional import throughput
from torcheval.utils.test_utils.metric_class_tester import NUM_PROCESSES


class TestThroughput(unittest.TestCase):
    """Tests for the functional ``throughput`` metric."""

    def _test_throughput_with_input(
        self,
        num_processed: int,
        elapsed_time_sec: float,
    ) -> None:
        """Assert ``throughput`` equals ``num_processed / elapsed_time_sec``."""
        torch.testing.assert_close(
            throughput(num_processed, elapsed_time_sec),
            torch.tensor(num_processed / elapsed_time_sec),
            equal_nan=True,
            atol=1e-8,
            rtol=1e-5,
        )

    def test_throughput_base(self) -> None:
        """Throughput for a random, strictly positive elapsed time."""
        num_processed = NUM_PROCESSES
        # random.random() can return exactly 0.0. throughput() rejects a zero
        # elapsed time, and the reference value would divide by zero, so draw
        # from a strictly positive range instead.
        elapsed_time_sec = random.uniform(0.1, 20.0)
        self._test_throughput_with_input(num_processed, elapsed_time_sec)

    def test_throughput_update_input_invalid_num_processed(self) -> None:
        """A negative ``num_processed`` raises ``ValueError``."""
        with self.assertRaisesRegex(
            ValueError,
            r"Expected num_processed to be a non-negative number, but received",
        ):
            throughput(-1, 1.0)

    def test_throughput_update_input_invalid_elapsed_time_sec(self) -> None:
        """Negative and zero ``elapsed_time_sec`` both raise ``ValueError``."""
        expected_message = (
            r"Expected elapsed_time_sec to be a positive number, but received"
        )
        for bad_elapsed in (-5.1, 0.0):
            with self.assertRaisesRegex(ValueError, expected_message):
                throughput(42, bad_elapsed)

-------------------------------------------------------------------------------- /tests/metrics/functional/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/metrics/functional/image/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/metrics/functional/image/test_psnr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
# pyre-strict

import unittest

import torch

from skimage.metrics import peak_signal_noise_ratio as skimage_psnr
from torcheval.metrics.functional import peak_signal_noise_ratio
from torcheval.utils.test_utils.metric_class_tester import (
    BATCH_SIZE,
    IMG_CHANNELS,
    IMG_HEIGHT,
    IMG_WIDTH,
)


class TestPeakSignalNoiseRatio(unittest.TestCase):
    """Tests for the functional ``peak_signal_noise_ratio`` metric."""

    def test_psnr_skimage_equivelant(self) -> None:
        """The result matches scikit-image's PSNR on the flattened data (1e-3 tolerances)."""
        prediction, reference = self._get_random_data_peak_signal_to_noise_ratio(
            BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH
        )

        # scikit-image takes the reference image first, then the test image.
        skimage_value = skimage_psnr(
            reference.numpy().ravel(), prediction.numpy().ravel()
        )
        expected = torch.tensor(skimage_value, dtype=torch.float32)

        torch.testing.assert_close(
            peak_signal_noise_ratio(prediction, reference),
            expected,
            atol=1e-3,
            rtol=1e-3,
        )

    def test_psnr_with_invalid_input(self) -> None:
        """Mismatched ``input``/``target`` shapes raise ``ValueError`` naming both shapes."""
        prediction = torch.rand(BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH)
        # Target is one pixel wider than the input to force the mismatch.
        reference = torch.rand(BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH + 1)
        expected_message = (
            r"^The `input` and `target` must have the same shape, "
            + rf"got shapes torch.Size\(\[{BATCH_SIZE}, {IMG_CHANNELS}, {IMG_HEIGHT}, {IMG_WIDTH}\]\) "
            + rf"and torch.Size\(\[{BATCH_SIZE}, {IMG_CHANNELS}, {IMG_HEIGHT}, {IMG_WIDTH + 1}\]\)."
        )
        with self.assertRaisesRegex(ValueError, expected_message):
            peak_signal_noise_ratio(prediction, reference)

    def _get_random_data_peak_signal_to_noise_ratio(
        self, batch_size: int, num_channels: int, height: int, width: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return an ``(input, target)`` pair of uniform random tensors of the given 4-D shape."""
        shape = (batch_size, num_channels, height, width)
        return torch.rand(size=shape), torch.rand(size=shape)

# --------------------------------------------------------------------------------
# /tests/metrics/functional/ranking/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# --------------------------------------------------------------------------------
# /tests/metrics/functional/ranking/test_click_through_rate.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from torcheval.metrics.functional import click_through_rate


class TestClickThroughRate(unittest.TestCase):
    """Tests for the functional ``click_through_rate`` metric."""

    def test_click_through_rate_with_valid_input(self) -> None:
        """Single-task and two-task inputs, each with and without weights."""
        # Single task: 1-D click indicators.
        clicks_1d = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1])
        weights_1d = torch.tensor([1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0])
        torch.testing.assert_close(click_through_rate(clicks_1d), torch.tensor(0.5))
        torch.testing.assert_close(
            click_through_rate(clicks_1d, weights_1d), torch.tensor(0.58333334)
        )

        # Two tasks: one row per task.
        clicks_2d = torch.tensor([[0, 1, 0, 1], [1, 0, 0, 1]])
        weights_2d = torch.tensor([[1.0, 2.0, 1.0, 2.0], [1.0, 2.0, 1.0, 1.0]])
        torch.testing.assert_close(
            click_through_rate(clicks_2d, num_tasks=2), torch.tensor([0.5, 0.5])
        )
        torch.testing.assert_close(
            click_through_rate(clicks_2d, weights_2d, num_tasks=2),
            torch.tensor([0.66666667, 0.4]),
        )

    def test_click_through_rate_with_invalid_input(self) -> None:
        """Bad ranks, mismatched weights and inconsistent ``num_tasks`` raise ``ValueError``."""
        # Input with more than two dimensions.
        with self.assertRaisesRegex(
            ValueError,
            "^`input` should be a one or two dimensional tensor",
        ):
            click_through_rate(torch.rand(3, 2, 2))

        # Weights whose shape differs from the input.
        with self.assertRaisesRegex(
            ValueError,
            "^tensor `weights` should have the same shape as tensor `input`",
        ):
            click_through_rate(torch.rand(4, 2), torch.rand(3))

        # 2-D input with the default single task.
        with self.assertRaisesRegex(
            ValueError,
            r"`num_tasks = 1`, `input` is expected to be one-dimensional tensor,",
        ):
            click_through_rate(
                torch.tensor([[1, 1], [0, 1]]),
            )

        # 1-D input while two tasks are requested.
        with self.assertRaisesRegex(
            ValueError,
            r"`num_tasks = 2`, `input`'s shape is expected to be",
        ):
            click_through_rate(
                torch.tensor([1, 0, 0, 1]),
                num_tasks=2,
            )

# --------------------------------------------------------------------------------
# /tests/metrics/functional/ranking/test_frequency.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import unittest

import torch
from torcheval.metrics.functional import frequency_at_k


class TestFrequency(unittest.TestCase):
    """Tests for the functional ``frequency_at_k`` metric."""

    def test_frequency_with_valid_input(self) -> None:
        """Expected output is 1.0 where the score is below ``k`` and 0.0 elsewhere."""
        scores = torch.tensor(
            [0.4826, 0.9517, 0.8967, 0.8995, 0.1584, 0.9445, 0.9700],
        )

        # (threshold k, expected per-element output) pairs, in increasing k.
        cases: list[tuple[float, list[float]]] = [
            (0.5, [1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000]),
            (0.9, [1.0000, 0.0000, 1.0000, 1.0000, 1.0000, 0.0000, 0.0000]),
            (0.95, [1.0000, 0.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.0000]),
            (1.0, [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]),
        ]
        for k, expected in cases:
            torch.testing.assert_close(
                frequency_at_k(scores, k=k),
                torch.tensor(expected),
            )

    def test_frequency_with_invalid_input(self) -> None:
        """A non-1-D input and a negative ``k`` each raise ``ValueError``."""
        with self.assertRaisesRegex(
            ValueError, "input should be a one-dimensional tensor"
        ):
            frequency_at_k(torch.rand(3, 2, 2), k=1)

        with self.assertRaisesRegex(ValueError, "k should not be negative"):
            frequency_at_k(torch.rand(3), k=-1)

# --------------------------------------------------------------------------------
# /tests/metrics/functional/ranking/test_hit_rate.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import unittest

import torch
from torcheval.metrics.functional import hit_rate


class TestHitRate(unittest.TestCase):
    """Tests for the functional ``hit_rate`` metric."""

    def test_hit_rate_with_valid_input(self) -> None:
        """Check per-sample hit rates for ``k`` = None, 1, 3, 5 and 20."""
        scores = torch.tensor(
            [
                [0.4826, 0.9517, 0.8967, 0.8995, 0.1584, 0.9445, 0.9700],
                [0.4938, 0.7517, 0.8039, 0.7167, 0.9488, 0.9607, 0.7091],
                [0.5127, 0.4732, 0.5461, 0.5617, 0.9198, 0.0847, 0.2337],
                [0.4175, 0.9452, 0.9852, 0.2131, 0.5016, 0.7305, 0.0516],
            ]
        )
        target = torch.tensor([3, 5, 2, 1])

        # (k, expected per-sample hit rate) pairs; k=20 exceeds the 7 classes.
        cases: list[tuple[int | None, list[float]]] = [
            (None, [1.0000, 1.0000, 1.0000, 1.0000]),
            (1, [0.0000, 1.0000, 0.0000, 0.0000]),
            (3, [0.0000, 1.0000, 1.0000, 1.0000]),
            (5, [1.0000, 1.0000, 1.0000, 1.0000]),
            (20, [1.0000, 1.0000, 1.0000, 1.0000]),
        ]
        for k, expected in cases:
            torch.testing.assert_close(
                hit_rate(scores, target, k=k),
                torch.tensor(expected),
            )

    def test_hit_rate_with_invalid_input(self) -> None:
        """Bad ranks, mismatched batch sizes and non-positive ``k`` raise ``ValueError``."""
        with self.assertRaisesRegex(
            ValueError, "target should be a one-dimensional tensor"
        ):
            hit_rate(torch.rand(3, 2), torch.rand(3, 2))

        with self.assertRaisesRegex(
            ValueError, "input should be a two-dimensional tensor"
        ):
            hit_rate(torch.rand(3, 2, 2), torch.rand(3))

        with self.assertRaisesRegex(
            ValueError, "`input` and `target` should have the same minibatch dimension"
        ):
            hit_rate(torch.rand(4, 2), torch.rand(3))

        with self.assertRaisesRegex(ValueError, "k should be None or positive"):
            hit_rate(torch.rand(3, 2), torch.rand(3), k=0)

# --------------------------------------------------------------------------------
# /tests/metrics/functional/ranking/test_num_collisions.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from torcheval.metrics.functional import num_collisions


class TestNumCollisions(unittest.TestCase):
    """Tests for the functional ``num_collisions`` metric."""

    def test_num_collisions_with_valid_input(self) -> None:
        """Each element maps to the number of *other* positions holding the same id."""
        # (input ids, expected collision count per position) pairs.
        cases: list[tuple[list[int], list[int]]] = [
            ([3, 4, 2, 3], [1, 0, 0, 1]),
            ([3, 4, 1, 3, 1, 1, 5], [1, 0, 2, 1, 2, 2, 0]),
        ]
        for ids, expected in cases:
            torch.testing.assert_close(
                num_collisions(torch.tensor(ids)),
                torch.tensor(expected),
            )

    def test_num_collisions_with_invalid_input(self) -> None:
        """A 2-D input and a floating-point input each raise ``ValueError``."""
        with self.assertRaisesRegex(
            ValueError, "input should be a one-dimensional tensor"
        ):
            num_collisions(torch.randint(10, (3, 2)))

        with self.assertRaisesRegex(ValueError, "input should be an integer tensor"):
            num_collisions(torch.rand(3))

# --------------------------------------------------------------------------------
# /tests/metrics/functional/ranking/test_reciprocal_rank.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import unittest

import torch
from torcheval.metrics.functional import reciprocal_rank


class TestReciprocalRank(unittest.TestCase):
    """Tests for the functional ``reciprocal_rank`` metric."""

    def test_reciprocal_rank_with_valid_input(self) -> None:
        """Check per-sample reciprocal ranks for ``k`` = None, 1, 3, 5, 20 and 100."""
        scores = torch.tensor(
            [
                [0.4826, 0.9517, 0.8967, 0.8995, 0.1584, 0.9445, 0.9700],
                [0.4938, 0.7517, 0.8039, 0.7167, 0.9488, 0.9607, 0.7091],
                [0.5127, 0.4732, 0.5461, 0.5617, 0.9198, 0.0847, 0.2337],
                [0.4175, 0.9452, 0.9852, 0.2131, 0.5016, 0.7305, 0.0516],
            ]
        )
        target = torch.tensor([3, 5, 2, 1])

        # (k, expected per-sample reciprocal rank) pairs. In the expected
        # values a target ranked outside the top k scores 0.
        cases: list[tuple[int | None, list[float]]] = [
            (None, [0.2500, 1.0000, 1.0000 / 3, 0.5000]),
            (1, [0.0000, 1.0000, 0.0000, 0.0000]),
            (3, [0.0000, 1.0000, 1.0000 / 3, 0.5000]),
            (5, [0.2500, 1.0000, 1.0000 / 3, 0.5000]),
            (20, [0.2500, 1.0000, 1.0000 / 3, 0.5000]),
            (100, [0.2500, 1.0000, 1.0000 / 3, 0.5000]),
        ]
        for k, expected in cases:
            torch.testing.assert_close(
                reciprocal_rank(scores, target, k=k),
                torch.tensor(expected),
            )

    def test_reciprocal_rank_with_invalid_input(self) -> None:
        """Bad ranks and mismatched batch sizes raise ``ValueError``."""
        with self.assertRaisesRegex(
            ValueError, "target should be a one-dimensional tensor"
        ):
            reciprocal_rank(torch.rand(3, 2), torch.rand(3, 2))

        with self.assertRaisesRegex(
            ValueError, "input should be a two-dimensional tensor"
        ):
            reciprocal_rank(torch.rand(3, 2, 2), torch.rand(3))

        with self.assertRaisesRegex(
            ValueError, "`input` and `target` should have the same minibatch dimension"
        ):
            reciprocal_rank(torch.rand(4, 2), torch.rand(3))

# --------------------------------------------------------------------------------
# /tests/metrics/functional/ranking/test_weighted_calibration.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from torcheval.metrics.functional import weighted_calibration


class TestWeightedCalibration(unittest.TestCase):
    """Tests for the functional ``weighted_calibration`` metric."""

    def test_weighted_calibration_with_valid_input(self) -> None:
        """Unweighted, weighted and two-task inputs give the pinned values."""
        predictions = torch.tensor([0.8, 0.4, 0.3, 0.8, 0.7, 0.6])
        labels = torch.tensor([1, 1, 0, 0, 1, 0])

        # Single task, no weights.
        torch.testing.assert_close(
            weighted_calibration(predictions, labels),
            torch.tensor(1.2000),
        )

        # Single task, per-sample weights. Fresh tensors are built for each
        # call, as in the unweighted case above.
        torch.testing.assert_close(
            weighted_calibration(
                torch.tensor([0.8, 0.4, 0.3, 0.8, 0.7, 0.6]),
                torch.tensor([1, 1, 0, 0, 1, 0]),
                torch.tensor([0.5, 1.0, 2.0, 0.4, 1.3, 0.9]),
            ),
            torch.tensor(1.1321428185),
        )

        # Two tasks: one row per task.
        torch.testing.assert_close(
            weighted_calibration(
                torch.tensor([[0.8, 0.4], [0.8, 0.7]]),
                torch.tensor([[1, 1], [0, 1]]),
                num_tasks=2,
            ),
            torch.tensor([0.6000, 1.5000]),
        )

    def test_weighted_calibration_with_invalid_input(self) -> None:
        """Mismatched weights or shapes and inconsistent ``num_tasks`` raise ``ValueError``."""
        # Weight tensor whose size differs from the input.
        with self.assertRaisesRegex(
            ValueError,
            r"Weight must be either a float value or a tensor that matches the input tensor size.",
        ):
            weighted_calibration(
                torch.tensor([0.8, 0.4, 0.8, 0.7]),
                torch.tensor([1, 1, 0, 1]),
                torch.tensor([1, 1.5]),
            )

        # Input and target of different shapes.
        with self.assertRaisesRegex(
            ValueError,
            r"is different from `target` shape",
        ):
            weighted_calibration(
                torch.tensor([0.8, 0.4, 0.8, 0.7]),
                torch.tensor([[1, 1, 0], [0, 1, 1]]),
            )

        # 2-D input with the default single task.
        with self.assertRaisesRegex(
            ValueError,
            r"`num_tasks = 1`, `input` is expected to be one-dimensional tensor,",
        ):
            weighted_calibration(
                torch.tensor([[0.8, 0.4], [0.8, 0.7]]),
                torch.tensor([[1, 1], [0, 1]]),
            )

        # 1-D input while two tasks are requested.
        with self.assertRaisesRegex(
            ValueError,
            r"`num_tasks = 2`, `input`'s shape is expected to be",
        ):
            weighted_calibration(
                torch.tensor([0.8, 0.4, 0.8, 0.7]),
                torch.tensor([1, 0, 0, 1]),
                num_tasks=2,
            )

# --------------------------------------------------------------------------------
# /tests/metrics/functional/regression/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# --------------------------------------------------------------------------------
# /tests/metrics/functional/statistical/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# --------------------------------------------------------------------------------
# /tests/metrics/functional/text/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6 | -------------------------------------------------------------------------------- /tests/metrics/functional/text/test_word_error_rate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import unittest 10 | 11 | import torch 12 | from torcheval.metrics.functional import word_error_rate 13 | 14 | 15 | class TestWordErrorRate(unittest.TestCase): 16 | def test_word_error_rate_with_valid_input(self) -> None: 17 | torch.testing.assert_close( 18 | word_error_rate("hello meta", "hello metaverse"), 19 | torch.tensor(0.5, dtype=torch.float64), 20 | ) 21 | torch.testing.assert_close( 22 | word_error_rate("hello meta", "hello meta"), 23 | torch.tensor(0.0, dtype=torch.float64), 24 | ) 25 | torch.testing.assert_close( 26 | word_error_rate("this is the prediction", "this is the reference"), 27 | torch.tensor(0.25, dtype=torch.float64), 28 | ) 29 | torch.testing.assert_close( 30 | word_error_rate( 31 | ["hello world", "welcome to the facebook"], 32 | ["hello metaverse", "welcome to meta"], 33 | ), 34 | torch.tensor(0.6, dtype=torch.float64), 35 | ) 36 | torch.testing.assert_close( 37 | word_error_rate( 38 | [ 39 | "hello metaverse", 40 | "come to the facebook", 41 | "this is reference", 42 | "there is the other one", 43 | ], 44 | [ 45 | "hello world", 46 | "welcome to meta", 47 | "this is reference", 48 | "there is another one", 49 | ], 50 | ), 51 | torch.tensor(0.5, dtype=torch.float64), 52 | ) 53 | 54 | def test_word_error_rate_with_invalid_input(self) -> None: 55 | with self.assertRaisesRegex( 56 | ValueError, "input and target should have the same type" 57 | ): 58 | word_error_rate(["hello metaverse", "welcome to meta"], "hello world") 59 | 60 | with self.assertRaisesRegex( 61 | 
ValueError, "input and target lists should have the same length" 62 | ): 63 | word_error_rate( 64 | ["hello metaverse", "welcome to meta"], 65 | [ 66 | "welcome to meta", 67 | "this is the prediction", 68 | "there is an other sample", 69 | ], 70 | ) 71 | -------------------------------------------------------------------------------- /tests/metrics/functional/text/test_word_information_lost.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import unittest 10 | 11 | import torch 12 | from torcheval.metrics.functional import word_information_lost 13 | 14 | 15 | class TestWordInformationLost(unittest.TestCase): 16 | def test_word_information_lost(self) -> None: 17 | input = ["hello world", "welcome to the facebook"] 18 | target = ["hello metaverse", "welcome to meta"] 19 | torch.testing.assert_close( 20 | word_information_lost(input, target), 21 | torch.tensor(0.7, dtype=torch.float64), 22 | ) 23 | 24 | input = ["this is the prediction", "there is an other sample"] 25 | target = ["this is the reference", "there is another one"] 26 | torch.testing.assert_close( 27 | word_information_lost(input, target), 28 | torch.tensor(0.6527777, dtype=torch.float64), 29 | ) 30 | 31 | def test_word_information_lost_with_invalid_input(self) -> None: 32 | with self.assertRaisesRegex( 33 | AssertionError, 34 | "Arguments must contain the same number of strings.", 35 | ): 36 | word_information_lost( 37 | ["hello metaverse", "welcome to meta"], 38 | [ 39 | "welcome to meta", 40 | "this is the prediction", 41 | "there is an other sample", 42 | ], 43 | ) 44 | -------------------------------------------------------------------------------- 
/tests/metrics/functional/text/test_word_information_preserved.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import unittest 10 | 11 | import torch 12 | from torcheval.metrics.functional import word_information_preserved 13 | 14 | 15 | class TestWordInformationPreserved(unittest.TestCase): 16 | def test_word_information_preserved_with_valid_input(self) -> None: 17 | torch.testing.assert_close( 18 | word_information_preserved("hello meta", "hi metaverse"), 19 | torch.tensor(0.0, dtype=torch.float64), 20 | ) 21 | torch.testing.assert_close( 22 | word_information_preserved("hello meta", "hello meta"), 23 | torch.tensor(1.0, dtype=torch.float64), 24 | ) 25 | torch.testing.assert_close( 26 | word_information_preserved( 27 | "this is the prediction", "this is the reference" 28 | ), 29 | torch.tensor(0.5625, dtype=torch.float64), 30 | ) 31 | torch.testing.assert_close( 32 | word_information_preserved( 33 | ["hello world", "welcome to the facebook"], 34 | ["hello metaverse", "welcome to meta"], 35 | ), 36 | torch.tensor(0.3, dtype=torch.float64), 37 | ) 38 | torch.testing.assert_close( 39 | word_information_preserved( 40 | [ 41 | "hello metaverse", 42 | "come to the facebook", 43 | "this is reference", 44 | "there is the other one", 45 | ], 46 | [ 47 | "hello world", 48 | "welcome to meta", 49 | "this is reference", 50 | "there is another one", 51 | ], 52 | ), 53 | torch.tensor(0.38095238, dtype=torch.float64), 54 | ) 55 | 56 | def test_word_information_preserved_with_invalid_input(self) -> None: 57 | with self.assertRaisesRegex( 58 | ValueError, "input and target should have the same type" 59 | ): 60 | word_information_preserved( 61 | ["hello metaverse", "welcome to meta"], "hello 
world" 62 | ) 63 | 64 | with self.assertRaisesRegex( 65 | ValueError, "input and target lists should have the same length" 66 | ): 67 | word_information_preserved( 68 | ["hello metaverse", "welcome to meta"], 69 | [ 70 | "welcome to meta", 71 | "this is the prediction", 72 | "there is an other sample", 73 | ], 74 | ) 75 | -------------------------------------------------------------------------------- /tests/metrics/image/test_psnr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | 12 | from skimage.metrics import peak_signal_noise_ratio as skimage_peak_signal_noise_ratio 13 | from torcheval.metrics import PeakSignalNoiseRatio 14 | from torcheval.utils.test_utils.metric_class_tester import ( 15 | BATCH_SIZE, 16 | IMG_CHANNELS, 17 | IMG_HEIGHT, 18 | IMG_WIDTH, 19 | MetricClassTester, 20 | NUM_TOTAL_UPDATES, 21 | ) 22 | 23 | 24 | class TestPeakSignalNoiseRatio(MetricClassTester): 25 | def _get_random_data_PeakSignalToNoiseRatio( 26 | self, 27 | num_updates: int, 28 | batch_size: int, 29 | num_channels: int, 30 | height: int, 31 | width: int, 32 | ) -> tuple[torch.Tensor, torch.Tensor]: 33 | inputs = torch.rand( 34 | size=(num_updates, batch_size, num_channels, height, width), 35 | ) 36 | targets = torch.rand( 37 | size=(num_updates, batch_size, num_channels, height, width), 38 | ) 39 | return inputs, targets 40 | 41 | def _test_psnr_skimage_equivelant( 42 | self, 43 | input: torch.Tensor, 44 | target: torch.Tensor, 45 | data_range: float | None = None, 46 | ) -> None: 47 | input_np = input.numpy().ravel() 48 | target_np = target.numpy().ravel() 49 | 50 | skimage_result = torch.tensor( 51 | skimage_peak_signal_noise_ratio( 52 | image_true=target_np, 
image_test=input_np, data_range=data_range 53 | ) 54 | ) 55 | 56 | state_names = { 57 | "num_observations", 58 | "sum_squared_error", 59 | "data_range", 60 | "min_target", 61 | "max_target", 62 | } 63 | 64 | self.run_class_implementation_tests( 65 | metric=PeakSignalNoiseRatio(data_range=data_range), 66 | state_names=state_names, 67 | update_kwargs={"input": input, "target": target}, 68 | compute_result=skimage_result.to(torch.float32), 69 | ) 70 | 71 | def test_psnr_with_random_data(self) -> None: 72 | input, target = self._get_random_data_PeakSignalToNoiseRatio( 73 | NUM_TOTAL_UPDATES, BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH 74 | ) 75 | self._test_psnr_skimage_equivelant(input, target) 76 | 77 | def test_psnr_with_random_data_and_data_range(self) -> None: 78 | input, target = self._get_random_data_PeakSignalToNoiseRatio( 79 | NUM_TOTAL_UPDATES, BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH 80 | ) 81 | self._test_psnr_skimage_equivelant(input, target, data_range=0.5) 82 | 83 | def test_psnr_class_invalid_input(self) -> None: 84 | metric = PeakSignalNoiseRatio() 85 | with self.assertRaisesRegex( 86 | ValueError, 87 | "The `input` and `target` must have the same shape, " 88 | r"got shapes torch.Size\(\[4, 3, 4, 4\]\) and torch.Size\(\[4, 3, 4, 6\]\).", 89 | ): 90 | metric.update(torch.rand(4, 3, 4, 4), torch.rand(4, 3, 4, 6)) 91 | 92 | def test_psnr_class_invalid_data_range(self) -> None: 93 | with self.assertRaisesRegex( 94 | ValueError, "`data_range needs to be either `None` or `float`." 95 | ): 96 | PeakSignalNoiseRatio(data_range=5) 97 | 98 | with self.assertRaisesRegex(ValueError, "`data_range` needs to be positive."): 99 | PeakSignalNoiseRatio(data_range=-1.0) 100 | -------------------------------------------------------------------------------- /tests/metrics/image/test_ssim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | from torch import Tensor 12 | 13 | from torcheval.metrics.image.ssim import StructuralSimilarity 14 | from torcheval.utils.test_utils.metric_class_tester import ( 15 | BATCH_SIZE, 16 | IMG_CHANNELS, 17 | IMG_HEIGHT, 18 | IMG_WIDTH, 19 | MetricClassTester, 20 | NUM_TOTAL_UPDATES, 21 | ) 22 | 23 | # pyre-ignore-all-errors[6] 24 | 25 | 26 | class TestStructuralSimilarity(MetricClassTester): 27 | def setUp(self) -> None: 28 | super().setUp() 29 | torch.manual_seed(0) 30 | 31 | def _get_input_data( 32 | self, 33 | num_updates: int, 34 | batch_size: int, 35 | num_channels: int, 36 | height: int, 37 | width: int, 38 | ) -> dict[str, Tensor]: 39 | images = { 40 | "images_1": torch.rand( 41 | size=(num_updates, batch_size, num_channels, height, width) 42 | ), 43 | "images_2": torch.rand( 44 | size=(num_updates, batch_size, num_channels, height, width) 45 | ), 46 | } 47 | 48 | return images 49 | 50 | def test_ssim( 51 | self, 52 | ) -> None: 53 | images = self._get_input_data( 54 | NUM_TOTAL_UPDATES, BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH 55 | ) 56 | 57 | expected_result = torch.tensor(0.022607240825891495) 58 | 59 | state_names = { 60 | "mssim_sum", 61 | "num_images", 62 | } 63 | 64 | self.run_class_implementation_tests( 65 | metric=StructuralSimilarity(), 66 | state_names=state_names, 67 | update_kwargs={ 68 | "images_1": images["images_1"], 69 | "images_2": images["images_2"], 70 | }, 71 | compute_result=expected_result, 72 | min_updates_before_compute=2, 73 | test_merge_with_one_update=False, 74 | atol=1e-4, 75 | rtol=1e-4, 76 | test_devices=["cpu"], 77 | ) 78 | 79 | def test_ssim_invalid_input(self) -> None: 80 | metric = StructuralSimilarity() 81 | images_1 = torch.rand(BATCH_SIZE, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH) 82 | images_2 = torch.rand(BATCH_SIZE, 4, 
IMG_HEIGHT, IMG_WIDTH) 83 | 84 | with self.assertRaisesRegex( 85 | RuntimeError, "The two sets of images must have the same shape." 86 | ): 87 | metric.update(images_1=images_1, images_2=images_2) 88 | -------------------------------------------------------------------------------- /tests/metrics/ranking/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/metrics/ranking/test_click_through_rate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | import torch 10 | from torcheval.metrics.ranking import ClickThroughRate 11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester 12 | 13 | 14 | class TestClickThroughRate(MetricClassTester): 15 | def test_ctr_with_valid_input(self) -> None: 16 | input = torch.tensor([[1, 0, 0, 1], [0, 0, 0, 0], [1, 1, 1, 1], [0, 1, 1, 1]]) 17 | 18 | self.run_class_implementation_tests( 19 | metric=ClickThroughRate(), 20 | state_names={"click_total", "weight_total"}, 21 | update_kwargs={"input": input}, 22 | compute_result=torch.tensor([0.5625], dtype=torch.float64), 23 | num_total_updates=4, 24 | num_processes=2, 25 | ) 26 | 27 | input = torch.tensor( 28 | [ 29 | [[1, 0, 0, 1], [1, 1, 1, 1]], 30 | [[0, 0, 0, 0], [1, 1, 1, 1]], 31 | [[0, 1, 0, 1], [0, 1, 0, 1]], 32 | [[1, 1, 1, 1], [0, 1, 1, 1]], 33 | ] 34 | ) 35 | weights = torch.tensor( 36 | [ 37 | [[1, 2, 3, 4], [0, 0, 0, 0]], 38 | [[1, 2, 1, 2], [1, 2, 1, 2]], 39 | [[1, 1, 1, 1], [1, 1, 3, 1]], 40 | [[1, 1, 1, 1], [1, 1, 1, 1]], 41 | ] 42 | ) 43 | 44 | self.run_class_implementation_tests( 45 | metric=ClickThroughRate(num_tasks=2), 46 | state_names={"click_total", "weight_total"}, 47 | update_kwargs={"input": input, "weights": weights}, 48 | compute_result=torch.tensor([0.4583333, 0.6875], dtype=torch.float64), 49 | num_total_updates=4, 50 | num_processes=2, 51 | ) 52 | 53 | weights = [4.0, 1, torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]]), 0.0] 54 | 55 | self.run_class_implementation_tests( 56 | metric=ClickThroughRate(num_tasks=2), 57 | state_names={"click_total", "weight_total"}, 58 | update_kwargs={"input": input, "weights": weights}, 59 | compute_result=torch.tensor([0.46666667, 0.86666667], dtype=torch.float64), 60 | num_total_updates=4, 61 | num_processes=2, 62 | ) 63 | 64 | def test_ctr_with_invalid_input(self) -> None: 65 | metric = ClickThroughRate() 66 | with self.assertRaisesRegex( 67 | ValueError, 68 | "^`input` should be a one or two dimensional tensor", 69 | ): 70 
| metric.update(torch.rand(3, 2, 2)) 71 | 72 | metric = ClickThroughRate() 73 | with self.assertRaisesRegex( 74 | ValueError, 75 | "^tensor `weights` should have the same shape as tensor `input`", 76 | ): 77 | metric.update(torch.rand(4, 2), torch.rand(3)) 78 | with self.assertRaisesRegex( 79 | ValueError, 80 | r"`num_tasks = 1`, `input` is expected to be one-dimensional tensor,", 81 | ): 82 | metric.update( 83 | torch.tensor([[1, 1], [0, 1]]), 84 | ) 85 | 86 | metric = ClickThroughRate(num_tasks=2) 87 | with self.assertRaisesRegex( 88 | ValueError, 89 | r"`num_tasks = 2`, `input`'s shape is expected to be", 90 | ): 91 | metric.update( 92 | torch.tensor([1, 0, 0, 1]), 93 | ) 94 | 95 | with self.assertRaisesRegex( 96 | ValueError, 97 | r"`num_tasks` value should be greater than and equal to 1,", 98 | ): 99 | metric = ClickThroughRate(num_tasks=0) 100 | -------------------------------------------------------------------------------- /tests/metrics/ranking/test_hit_rate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | import torch 10 | from torcheval.metrics.ranking import HitRate 11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester 12 | 13 | 14 | class TestHitRate(MetricClassTester): 15 | def test_hitrate_with_valid_input(self) -> None: 16 | input = torch.tensor( 17 | [ 18 | [ 19 | [0.4826, 0.9517, 0.8967, 0.8995, 0.1584, 0.9445, 0.9700], 20 | ], 21 | [ 22 | [0.4938, 0.7517, 0.8039, 0.7167, 0.9488, 0.9607, 0.7091], 23 | ], 24 | [ 25 | [0.5127, 0.4732, 0.5461, 0.5617, 0.9198, 0.0847, 0.2337], 26 | ], 27 | [ 28 | [0.4175, 0.9452, 0.9852, 0.2131, 0.5016, 0.7305, 0.0516], 29 | ], 30 | ] 31 | ) 32 | target = torch.tensor([[3], [5], [2], [1]]) 33 | 34 | self.run_class_implementation_tests( 35 | metric=HitRate(), 36 | state_names={"scores"}, 37 | update_kwargs={"input": input, "target": target}, 38 | compute_result=torch.tensor([1.0000, 1.0000, 1.0000, 1.0000]), 39 | num_total_updates=4, 40 | num_processes=2, 41 | ) 42 | 43 | self.run_class_implementation_tests( 44 | metric=HitRate(k=3), 45 | state_names={"scores"}, 46 | update_kwargs={"input": input, "target": target}, 47 | compute_result=torch.tensor([0.0000, 1.0000, 1.0000, 1.0000]), 48 | num_total_updates=4, 49 | num_processes=2, 50 | ) 51 | 52 | def test_hitrate_with_invalid_input(self) -> None: 53 | metric = HitRate() 54 | with self.assertRaisesRegex( 55 | ValueError, "target should be a one-dimensional tensor" 56 | ): 57 | metric.update(torch.rand(3, 2), torch.rand(3, 2)) 58 | 59 | with self.assertRaisesRegex( 60 | ValueError, "input should be a two-dimensional tensor" 61 | ): 62 | metric.update(torch.rand(3, 2, 2), torch.rand(3)) 63 | with self.assertRaisesRegex( 64 | ValueError, "`input` and `target` should have the same minibatch dimension" 65 | ): 66 | metric.update(torch.rand(4, 2), torch.rand(3)) 67 | -------------------------------------------------------------------------------- /tests/metrics/ranking/test_reciprocal_rank.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import torch 10 | from torcheval.metrics.ranking import ReciprocalRank 11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester 12 | 13 | 14 | class TestReciprocalRank(MetricClassTester): 15 | def test_mrr_with_valid_input(self) -> None: 16 | input = torch.tensor( 17 | [ 18 | [ 19 | [0.9005, 0.0998, 0.2470, 0.6188, 0.9497, 0.6083, 0.7258], 20 | [0.9505, 0.3270, 0.4734, 0.5854, 0.5202, 0.6546, 0.7869], 21 | ], 22 | [ 23 | [0.5546, 0.6027, 0.2650, 0.6624, 0.8755, 0.7838, 0.7529], 24 | [0.4121, 0.6082, 0.7813, 0.5947, 0.9582, 0.8736, 0.7389], 25 | ], 26 | [ 27 | [0.1306, 0.7939, 0.5192, 0.0494, 0.7987, 0.3898, 0.0108], 28 | [0.2399, 0.2969, 0.6738, 0.8633, 0.7939, 0.1052, 0.7702], 29 | ], 30 | [ 31 | [0.9097, 0.7436, 0.0051, 0.6264, 0.6616, 0.7328, 0.7413], 32 | [0.5286, 0.2956, 0.0578, 0.1913, 0.8118, 0.1047, 0.7966], 33 | ], 34 | ] 35 | ) 36 | target = torch.tensor([[1, 3], [3, 0], [2, 6], [4, 5]]) 37 | 38 | self.run_class_implementation_tests( 39 | metric=ReciprocalRank(), 40 | state_names={"scores"}, 41 | update_kwargs={"input": input, "target": target}, 42 | compute_result=torch.tensor( 43 | [1.0 / 7, 0.25, 0.25, 1.0 / 7, 1.0 / 3, 1.0 / 3, 0.20, 1.0 / 6] 44 | ), 45 | num_total_updates=4, 46 | num_processes=2, 47 | ) 48 | 49 | self.run_class_implementation_tests( 50 | metric=ReciprocalRank(k=5), 51 | state_names={"scores"}, 52 | update_kwargs={"input": input, "target": target}, 53 | compute_result=torch.tensor( 54 | [0.0, 0.25, 0.25, 0.0, 1.0 / 3, 1.0 / 3, 0.2, 0.0] 55 | ), 56 | num_total_updates=4, 57 | num_processes=2, 58 | ) 59 | 60 | def test_mrr_with_invalid_input(self) -> None: 61 | metric = 
ReciprocalRank() 62 | with self.assertRaisesRegex( 63 | ValueError, "target should be a one-dimensional tensor" 64 | ): 65 | metric.update(torch.rand(3, 2), torch.rand(3, 2)) 66 | 67 | with self.assertRaisesRegex( 68 | ValueError, "input should be a two-dimensional tensor" 69 | ): 70 | metric.update(torch.rand(3, 2, 2), torch.rand(3)) 71 | with self.assertRaisesRegex( 72 | ValueError, "`input` and `target` should have the same minibatch dimension" 73 | ): 74 | metric.update(torch.rand(4, 2), torch.rand(3)) 75 | -------------------------------------------------------------------------------- /tests/metrics/regression/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/metrics/statistical/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/metrics/text/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /tests/metrics/text/test_word_error_rate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import torch 10 | from torcheval.metrics.text import WordErrorRate 11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester 12 | 13 | 14 | class TestWordErrorRate(MetricClassTester): 15 | def test_word_error_rate_with_valid_input(self) -> None: 16 | self.run_class_implementation_tests( 17 | metric=WordErrorRate(), 18 | state_names={"errors", "total"}, 19 | update_kwargs={ 20 | "input": [ 21 | ["hello world", "welcome to the facebook"], 22 | ["hello world", "welcome to the facebook"], 23 | ["hello world", "welcome to the facebook"], 24 | ["hello world", "welcome to the facebook"], 25 | ], 26 | "target": [ 27 | ["hello metaverse", "welcome to meta"], 28 | ["hello metaverse", "welcome to meta"], 29 | ["hello metaverse", "welcome to meta"], 30 | ["hello metaverse", "welcome to meta"], 31 | ], 32 | }, 33 | compute_result=torch.tensor(0.6), 34 | num_total_updates=4, 35 | ) 36 | 37 | def test_word_error_rate_with_invalid_input(self) -> None: 38 | metric = WordErrorRate() 39 | with self.assertRaisesRegex( 40 | ValueError, "input and target should have the same type" 41 | ): 42 | metric.update(["hello metaverse", "welcome to meta"], "hello world") 43 | 44 | with self.assertRaisesRegex( 45 | ValueError, "input and target lists should have the same length" 46 | ): 47 | metric.update( 48 | ["hello metaverse", "welcome to meta"], 49 | [ 50 | "welcome to meta", 51 | "this is the prediction", 52 | "there is an other sample", 53 | ], 54 | ) 55 | 
-------------------------------------------------------------------------------- /tests/metrics/text/test_word_information_lost.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import torch 10 | from torcheval.metrics.text import WordInformationLost 11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester 12 | 13 | 14 | class TestWordInformationLost(MetricClassTester): 15 | def test_word_information_lost(self) -> None: 16 | self.run_class_implementation_tests( 17 | metric=WordInformationLost(), 18 | state_names={"correct_total", "target_total", "preds_total"}, 19 | update_kwargs={ 20 | "input": [ 21 | ["hello world", "welcome to the facebook"], 22 | ["hello world", "welcome to the facebook"], 23 | ["hello world", "welcome to the facebook"], 24 | ["hello world", "welcome to the facebook"], 25 | ], 26 | "target": [ 27 | ["hello metaverse", "welcome to meta"], 28 | ["hello metaverse", "welcome to meta"], 29 | ["hello metaverse", "welcome to meta"], 30 | ["hello metaverse", "welcome to meta"], 31 | ], 32 | }, 33 | compute_result=torch.tensor(0.7, dtype=torch.float64), 34 | num_total_updates=4, 35 | ) 36 | 37 | def test_word_information_lost_with_invalid_input(self) -> None: 38 | metric = WordInformationLost() 39 | 40 | with self.assertRaisesRegex( 41 | AssertionError, 42 | "Arguments must contain the same number of strings.", 43 | ): 44 | metric.update( 45 | ["hello metaverse", "welcome to meta"], 46 | [ 47 | "welcome to meta", 48 | "this is the prediction", 49 | "there is an other sample", 50 | ], 51 | ) 52 | -------------------------------------------------------------------------------- /tests/metrics/text/test_word_information_preserved.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import torch 10 | from torcheval.metrics.text import WordInformationPreserved 11 | from torcheval.utils.test_utils.metric_class_tester import MetricClassTester 12 | 13 | 14 | class TestWordInformationPreserved(MetricClassTester): 15 | def test_word_information_preserved_with_valid_input(self) -> None: 16 | self.run_class_implementation_tests( 17 | metric=WordInformationPreserved(), 18 | state_names={"correct_total", "input_total", "target_total"}, 19 | update_kwargs={ 20 | "input": [ 21 | ["hello world", "welcome to the facebook"], 22 | ["hello world", "welcome to the facebook"], 23 | ["hello world", "welcome to the facebook"], 24 | ["hello world", "welcome to the facebook"], 25 | ], 26 | "target": [ 27 | ["hello metaverse", "welcome to meta"], 28 | ["hello metaverse", "welcome to meta"], 29 | ["hello metaverse", "welcome to meta"], 30 | ["hello metaverse", "welcome to meta"], 31 | ], 32 | }, 33 | compute_result=torch.tensor(0.3, dtype=torch.float64), 34 | num_total_updates=4, 35 | ) 36 | 37 | def test_word_information_preserved_with_invalid_input(self) -> None: 38 | metric = WordInformationPreserved() 39 | with self.assertRaisesRegex( 40 | ValueError, "input and target should have the same type" 41 | ): 42 | metric.update(["hello metaverse", "welcome to meta"], "hello world") 43 | 44 | with self.assertRaisesRegex( 45 | ValueError, "input and target lists should have the same length" 46 | ): 47 | metric.update( 48 | ["hello metaverse", "welcome to meta"], 49 | [ 50 | "welcome to meta", 51 | "this is the prediction", 52 | "there is an other sample", 53 | ], 54 | ) 55 | 
-------------------------------------------------------------------------------- /torcheval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | "A library that contains a collection of performant PyTorch model metrics" 10 | 11 | from .version import __version__ 12 | 13 | __all__ = [ 14 | "__version__", 15 | ] 16 | -------------------------------------------------------------------------------- /torcheval/metrics/aggregation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from torcheval.metrics.aggregation.auc import AUC 10 | from torcheval.metrics.aggregation.cat import Cat 11 | from torcheval.metrics.aggregation.cov import Covariance 12 | from torcheval.metrics.aggregation.max import Max 13 | from torcheval.metrics.aggregation.mean import Mean 14 | from torcheval.metrics.aggregation.min import Min 15 | from torcheval.metrics.aggregation.sum import Sum 16 | from torcheval.metrics.aggregation.throughput import Throughput 17 | 18 | __all__ = ["AUC", "Cat", "Covariance", "Max", "Mean", "Min", "Sum", "Throughput"] 19 | __doc_name__ = "Aggregation Metrics" 20 | -------------------------------------------------------------------------------- /torcheval/metrics/aggregation/cat.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states. 10 | 11 | 12 | from collections.abc import Iterable 13 | from typing import TypeVar 14 | 15 | import torch 16 | 17 | from torcheval.metrics.metric import Metric 18 | 19 | TCat = TypeVar("TCat") 20 | 21 | 22 | class Cat(Metric[torch.Tensor]): 23 | """ 24 | Concatenate all input tensors along dimension ``dim``. Its functional 25 | version is ``torch.cat(input)``. 26 | 27 | All input tensors to ``Cat.update()`` must either have the same shape 28 | (except in the concatenating dimension) or be empty. 29 | 30 | A zero-dimensional tensor is not a valid input to ``Cat.update()``. 31 | ``torch.flatten()`` can be used to flatten a zero-dimensional tensor into 32 | a one-dimensional tensor before passing it to ``Cat.update()``. 33 | 34 | Examples:: 35 | 36 | >>> import torch 37 | >>> from torcheval.metrics import Cat 38 | >>> metric = Cat(dim=1) 39 | >>> metric.update(torch.tensor([[1, 2], [3, 4]])) 40 | >>> metric.compute() 41 | tensor([[1, 2], 42 | [3, 4]]) 43 | 44 | >>> metric.update(torch.tensor([[5, 6], [7, 8]])).compute() 45 | tensor([[1, 2, 5, 6], 46 | [3, 4, 7, 8]]) 47 | 48 | >>> metric.reset() 49 | >>> metric.update(torch.tensor([0])).compute() 50 | tensor([0]) 51 | """ 52 | 53 | def __init__( 54 | self: "Cat", 55 | *, 56 | dim: int = 0, 57 | device: torch.device | None = None, 58 | ) -> None: 59 | """ 60 | Initialize a Cat metric object. 61 | 62 | Args: 63 | dim: The dimension along which to concatenate, as in ``torch.cat()``. 
64 | """ 65 | super().__init__(device=device) 66 | self._add_state("dim", dim) 67 | self._add_state("inputs", []) 68 | 69 | @torch.inference_mode() 70 | # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any 71 | def update(self: TCat, input: torch.Tensor) -> TCat: 72 | self.inputs.append(input) 73 | return self 74 | 75 | @torch.inference_mode() 76 | def compute(self: TCat) -> torch.Tensor: 77 | """ 78 | Return the concatenated inputs. 79 | 80 | If no calls to ``update()`` are made before ``compute()`` is called, 81 | the function returns ``torch.empty(0)``. 82 | """ 83 | if not self.inputs: 84 | return torch.empty(0) 85 | return torch.cat(self.inputs, dim=self.dim) 86 | 87 | @torch.inference_mode() 88 | def merge_state(self: TCat, metrics: Iterable[TCat]) -> TCat: 89 | for metric in metrics: 90 | if metric.inputs: 91 | self.inputs.append( 92 | torch.cat(metric.inputs, dim=metric.dim).to(self.device) 93 | ) 94 | return self 95 | 96 | @torch.inference_mode() 97 | def _prepare_for_merge_state(self: TCat) -> None: 98 | if self.inputs: 99 | self.inputs = [torch.cat(self.inputs, dim=self.dim)] 100 | -------------------------------------------------------------------------------- /torcheval/metrics/aggregation/cov.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from collections.abc import Iterable 10 | from typing import TypeAlias, TypeVar, Union 11 | 12 | import torch 13 | from torcheval.metrics.metric import Metric 14 | from typing_extensions import Self 15 | 16 | # TODO: use a NamedTuple? 
class Covariance(Metric[_Output]):
    """Fit sample mean + covariance to empirical distribution.

    ``update()`` takes a 2-D batch and reduces over its first dimension, so
    each row is treated as one observation. The metric keeps three running
    quantities: the per-column sum (``sum``), the sum of outer products of
    the de-meaned observations (``ss_sum``) and the observation count (``n``).
    ``compute()`` returns ``(mean, cov)``.
    """

    def __init__(self, *, device: torch.device | None = None) -> None:
        super().__init__(device=device)
        self.sum: torch.Tensor = self._add_state_and_return(
            "sum", default=torch.as_tensor(0.0)
        )
        self.ss_sum: torch.Tensor = self._add_state_and_return(
            "ss_sum", default=torch.as_tensor(0.0)
        )
        self.n: int = self._add_state_and_return("n", default=0)

    def _add_state_and_return(self, name: str, default: _T) -> _T:
        # Helper function for pyre
        self._add_state(name, default)
        return getattr(self, name)

    def _update(self, sum: torch.Tensor, ss_sum: torch.Tensor, n: int) -> None:
        """Fold the statistics of ``n`` further observations into the state.

        Args:
            sum: Per-column sum of the new observations.
            ss_sum: Sum of outer products of the new observations after
                subtracting their own mean.
            n: Number of new observations; ``0`` makes this a no-op.
        """
        if n == 0:
            return
        elif self.n == 0:
            self.n = n
            # Clone rather than adopt the caller's tensors: the branch below
            # updates the state in place, and in ``merge_state`` the arguments
            # are another metric's live state, which must not be modified.
            self.ss_sum = ss_sum.clone()
            self.sum = sum.clone()
        else:
            # Pairwise (parallel) combination of two partial results, the
            # batched counterpart of Welford's algorithm, for numerical
            # stability: the cross term corrects for the two means differing.
            delta = (self.sum / self.n) - (sum / n)
            outer = torch.outer(delta, delta)
            self.ss_sum += ss_sum + outer * (n * self.n) / (self.n + n)
            self.sum += sum
            self.n += n

    # pyre-fixme[14]
    def update(self, obs: torch.Tensor) -> Self:
        """Add a batch of observations.

        Args:
            obs: 2-D tensor; the statistics are reduced over dimension 0.
        Raises:
            AssertionError: If ``obs`` is not 2-D.
        """
        # NOTE: kept as an ``assert`` so callers relying on AssertionError keep
        # working; be aware the check disappears under ``python -O``.
        assert obs.ndim == 2
        with torch.inference_mode():
            demeaned = obs - obs.mean(dim=0, keepdim=True)
            ss_sum = torch.einsum("ni,nj->ij", demeaned, demeaned)
            self._update(obs.sum(dim=0), ss_sum, len(obs))
        return self

    # pyre-fixme[14]
    def merge_state(self, metrics: Iterable[Self]) -> Self:
        """Fold the accumulated statistics of ``metrics`` into this metric."""
        with torch.inference_mode():
            for other in metrics:
                self._update(other.sum, other.ss_sum, other.n)
        return self

    def compute(self) -> _Output:
        """Return ``(mean, cov)`` for everything observed so far.

        Raises:
            ValueError: If fewer than two observations have been seen.
        """
        if self.n < 2:
            msg = f"Not enough samples to estimate covariance (found {self.n})"
            raise ValueError(msg)
        with torch.inference_mode():
            mean = self.sum / self.n
            # TODO: make degrees of freedom configurable?
            cov = self.ss_sum / (self.n - 1)
        return mean, cov
class Min(Metric[torch.Tensor]):
    """
    Track the smallest element seen across every tensor passed to ``update()``.
    Its functional version is ``torch.min(input)``.

    The running minimum starts at ``+inf``.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import Min
        >>> metric = Min()
        >>> metric.update(torch.tensor([[1, 2], [3, 4]]))
        >>> metric.compute()
        tensor(1.)

        >>> metric.update(torch.tensor(-1)).compute()
        tensor(-1.)

        >>> metric.reset()
        >>> metric.update(torch.tensor(5)).compute()
        tensor(5.)
    """

    def __init__(
        self: TMin,
        *,
        device: torch.device | None = None,
    ) -> None:
        super().__init__(device=device)
        # +inf is the identity element for ``min``.
        initial_min = torch.tensor(float("inf"), device=self.device)
        self._add_state("min", initial_min)

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(self: TMin, input: torch.Tensor) -> TMin:
        """Lower the running minimum to the smallest element of ``input``, if smaller."""
        batch_min = torch.min(input)
        self.min = torch.min(self.min, batch_min)
        return self

    @torch.inference_mode()
    def compute(self: TMin) -> torch.Tensor:
        """Return the running minimum."""
        return self.min

    @torch.inference_mode()
    def merge_state(self: TMin, metrics: Iterable[TMin]) -> TMin:
        """Combine with the minima held by ``metrics`` (moved to this metric's device)."""
        for other in metrics:
            incoming = other.min.to(self.device)
            self.min = torch.min(self.min, incoming)
        return self
27 | 28 | Examples:: 29 | 30 | >>> import torch 31 | >>> from torcheval.metrics import Sum 32 | >>> metric = Sum() 33 | >>> metric.update(1) 34 | >>> metric.update(torch.tensor([2, 3])) 35 | >>> metric.compute() 36 | tensor(6.) 37 | >>> metric.update(torch.tensor(-1)).compute() 38 | tensor(5.) 39 | >>> metric.reset() 40 | >>> metric.update(torch.tensor(-1)).compute() 41 | tensor(-1.) 42 | 43 | >>> metric = Sum() 44 | >>> metric.update(torch.tensor([2, 3]), torch.tensor([0.1, 0.6])).compute() 45 | tensor(2.) 46 | >>> metric.update(torch.tensor([2, 3]), 0.5).compute() 47 | tensor(4.5) 48 | >>> metric.update(torch.tensor([4, 6]), 1).compute() 49 | tensor(14.5) 50 | """ 51 | 52 | def __init__( 53 | self: TSum, 54 | *, 55 | device: torch.device | None = None, 56 | ) -> None: 57 | super().__init__(device=device) 58 | self._add_state( 59 | "weighted_sum", torch.tensor(0.0, device=self.device, dtype=torch.float64) 60 | ) 61 | 62 | @torch.inference_mode() 63 | # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any 64 | def update( 65 | self: TSum, 66 | input: torch.Tensor, 67 | *, 68 | weight: float | int | torch.Tensor = 1.0, 69 | ) -> TSum: 70 | """ 71 | Update states with the values and weights. 72 | 73 | Args: 74 | input (Tensor): Tensor of input values. 75 | weight(optional): Float or Int or Tensor of input weights. It is default to 1.0. If weight is a Tensor, its size should match the input tensor size. 76 | Raises: 77 | ValueError: If value of weight is neither a ``float`` nor ``int`` nor a ``torch.Tensor`` that matches the input tensor size. 
class Throughput(Metric[float]):
    """
    Calculate the throughput value which is the number of elements processed per second.

    Note: In a distributed setting, it's recommended to use `world_size * metric.compute()`
    to get an approximation of total throughput, because `sync_and_compute(metric)` requires
    a state sync. Additionally, `sync_and_compute(metric)` will give a slightly different value
    compared to `world_size * metric.compute()`.

    Examples::

        >>> import time
        >>> import torch
        >>> from torcheval.metrics import Throughput
        >>> metric = Throughput()
        >>> items_processed = 64
        >>> ts = time.monotonic()
        >>> time.sleep(2.0)  # simulate executing the program for 2 seconds
        >>> elapsed_time_sec = time.monotonic() - ts
        >>> metric.update(items_processed, elapsed_time_sec)
        >>> metric.compute()  # a Python float, roughly 32.0
    """

    def __init__(
        self: TThroughput,
        *,
        device: torch.device | None = None,
    ) -> None:
        super().__init__(device=device)
        # Both running totals start at 0.0.
        for state_name in ("num_total", "elapsed_time_sec"):
            self._add_state(state_name, 0.0)

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(
        self: TThroughput,
        num_processed: int,
        elapsed_time_sec: float,
    ) -> TThroughput:
        """
        Add one measurement to the running totals.

        Args:
            num_processed: Number of items processed
            elapsed_time_sec: Total elapsed time in seconds to process ``num_processed`` items
        Raises:
            ValueError:
                If ``num_processed`` is a negative number.
                If ``elapsed_time_sec`` is a non-positive number.
        """
        if num_processed < 0:
            message = f"Expected num_processed to be a non-negative number, but received {num_processed}."
            raise ValueError(message)
        if elapsed_time_sec <= 0:
            message = f"Expected elapsed_time_sec to be a positive number, but received {elapsed_time_sec}."
            raise ValueError(message)

        self.num_total += num_processed
        self.elapsed_time_sec += elapsed_time_sec
        return self

    @torch.inference_mode()
    def compute(self: TThroughput) -> float:
        """
        Return items processed per second.

        When no time has been accumulated a warning is logged and ``0.0`` is returned.
        """
        total_time = self.elapsed_time_sec
        if total_time:
            return self.num_total / total_time

        _logger.warning("No calls to update() have been made - returning 0.0")
        return 0.0

    @torch.inference_mode()
    def merge_state(self: TThroughput, metrics: Iterable[TThroughput]) -> TThroughput:
        """Add the item counts of ``metrics`` and keep the longest elapsed time."""
        for other in metrics:
            self.num_total += other.num_total
            # this assumes the metric is used within a fully-synchronous program.
            # In this scenario, the slowest process becomes the bottleneck for the
            # program's execution. As a result, we use the max, as the overall throughput
            # is gated based on the rank that takes the longest to complete.
            # TODO: should this be configurable?
            longest = (
                other.elapsed_time_sec
                if other.elapsed_time_sec > self.elapsed_time_sec
                else self.elapsed_time_sec
            )
            self.elapsed_time_sec = longest
        return self
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from torcheval.metrics.classification.accuracy import ( 10 | BinaryAccuracy, 11 | MulticlassAccuracy, 12 | MultilabelAccuracy, 13 | TopKMultilabelAccuracy, 14 | ) 15 | from torcheval.metrics.classification.auprc import ( 16 | BinaryAUPRC, 17 | MulticlassAUPRC, 18 | MultilabelAUPRC, 19 | ) 20 | 21 | from torcheval.metrics.classification.auroc import BinaryAUROC, MulticlassAUROC 22 | from torcheval.metrics.classification.binary_normalized_entropy import ( 23 | BinaryNormalizedEntropy, 24 | ) 25 | from torcheval.metrics.classification.binned_auprc import ( 26 | BinaryBinnedAUPRC, 27 | MulticlassBinnedAUPRC, 28 | MultilabelBinnedAUPRC, 29 | ) 30 | from torcheval.metrics.classification.binned_auroc import ( 31 | BinaryBinnedAUROC, 32 | MulticlassBinnedAUROC, 33 | ) 34 | from torcheval.metrics.classification.binned_precision_recall_curve import ( 35 | BinaryBinnedPrecisionRecallCurve, 36 | MulticlassBinnedPrecisionRecallCurve, 37 | MultilabelBinnedPrecisionRecallCurve, 38 | ) 39 | from torcheval.metrics.classification.confusion_matrix import ( 40 | BinaryConfusionMatrix, 41 | MulticlassConfusionMatrix, 42 | ) 43 | from torcheval.metrics.classification.f1_score import BinaryF1Score, MulticlassF1Score 44 | from torcheval.metrics.classification.precision import ( 45 | BinaryPrecision, 46 | MulticlassPrecision, 47 | ) 48 | from torcheval.metrics.classification.precision_recall_curve import ( 49 | BinaryPrecisionRecallCurve, 50 | MulticlassPrecisionRecallCurve, 51 | MultilabelPrecisionRecallCurve, 52 | ) 53 | from torcheval.metrics.classification.recall import BinaryRecall, MulticlassRecall 54 | from torcheval.metrics.classification.recall_at_fixed_precision import ( 55 | BinaryRecallAtFixedPrecision, 56 | MultilabelRecallAtFixedPrecision, 57 | ) 58 | 59 | __all__ = [ 60 | "BinaryAccuracy", 61 | 
"BinaryAUPRC", 62 | "BinaryAUROC", 63 | "BinaryBinnedAUROC", 64 | "BinaryBinnedAUPRC", 65 | "BinaryBinnedPrecisionRecallCurve", 66 | "BinaryConfusionMatrix", 67 | "BinaryF1Score", 68 | "BinaryNormalizedEntropy", 69 | "BinaryPrecision", 70 | "BinaryPrecisionRecallCurve", 71 | "BinaryRecall", 72 | "BinaryRecallAtFixedPrecision", 73 | "MulticlassAccuracy", 74 | "MulticlassAUPRC", 75 | "MulticlassAUROC", 76 | "MulticlassBinnedAUPRC", 77 | "MulticlassBinnedAUROC", 78 | "MulticlassBinnedPrecisionRecallCurve", 79 | "MulticlassConfusionMatrix", 80 | "MulticlassF1Score", 81 | "MulticlassPrecision", 82 | "MulticlassPrecisionRecallCurve", 83 | "MulticlassRecall", 84 | "MultilabelAccuracy", 85 | "MultilabelAUPRC", 86 | "MultilabelBinnedAUPRC", 87 | "MultilabelBinnedPrecisionRecallCurve", 88 | "MultilabelPrecisionRecallCurve", 89 | "MultilabelRecallAtFixedPrecision", 90 | "TopKMultilabelAccuracy", 91 | ] 92 | 93 | __doc_name__ = "Classification Metrics" 94 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | from torcheval.metrics.functional.aggregation import auc, mean, sum, throughput 10 | from torcheval.metrics.functional.classification import ( 11 | binary_accuracy, 12 | binary_auprc, 13 | binary_auroc, 14 | binary_binned_auprc, 15 | binary_binned_auroc, 16 | binary_binned_precision_recall_curve, 17 | binary_confusion_matrix, 18 | binary_f1_score, 19 | binary_normalized_entropy, 20 | binary_precision, 21 | binary_precision_recall_curve, 22 | binary_recall, 23 | binary_recall_at_fixed_precision, 24 | multiclass_accuracy, 25 | multiclass_auprc, 26 | multiclass_auroc, 27 | multiclass_binned_auprc, 28 | multiclass_binned_auroc, 29 | multiclass_binned_precision_recall_curve, 30 | multiclass_confusion_matrix, 31 | multiclass_f1_score, 32 | multiclass_precision, 33 | multiclass_precision_recall_curve, 34 | multiclass_recall, 35 | multilabel_accuracy, 36 | multilabel_auprc, 37 | multilabel_binned_auprc, 38 | multilabel_binned_precision_recall_curve, 39 | multilabel_precision_recall_curve, 40 | multilabel_recall_at_fixed_precision, 41 | topk_multilabel_accuracy, 42 | ) 43 | from torcheval.metrics.functional.frechet import gaussian_frechet_distance 44 | from torcheval.metrics.functional.image import peak_signal_noise_ratio 45 | from torcheval.metrics.functional.ranking import ( 46 | click_through_rate, 47 | frequency_at_k, 48 | hit_rate, 49 | num_collisions, 50 | reciprocal_rank, 51 | retrieval_precision, 52 | retrieval_recall, 53 | weighted_calibration, 54 | ) 55 | from torcheval.metrics.functional.regression import mean_squared_error, r2_score 56 | from torcheval.metrics.functional.text import ( 57 | bleu_score, 58 | perplexity, 59 | word_error_rate, 60 | word_information_lost, 61 | word_information_preserved, 62 | ) 63 | 64 | __all__ = [ 65 | "auc", 66 | "binary_accuracy", 67 | "binary_auprc", 68 | "binary_auroc", 69 | "binary_binned_auprc", 70 | "binary_binned_auroc", 71 | "binary_binned_precision_recall_curve", 72 | 
"binary_confusion_matrix", 73 | "binary_f1_score", 74 | "binary_normalized_entropy", 75 | "binary_precision", 76 | "binary_precision_recall_curve", 77 | "binary_recall", 78 | "binary_recall_at_fixed_precision", 79 | "bleu_score", 80 | "click_through_rate", 81 | "frequency_at_k", 82 | "gaussian_frechet_distance", 83 | "hit_rate", 84 | "mean", 85 | "mean_squared_error", 86 | "multiclass_accuracy", 87 | "multiclass_auprc", 88 | "multiclass_auroc", 89 | "multiclass_binned_auprc", 90 | "multiclass_binned_auroc", 91 | "multiclass_binned_precision_recall_curve", 92 | "multiclass_confusion_matrix", 93 | "multiclass_f1_score", 94 | "multiclass_precision", 95 | "multiclass_precision_recall_curve", 96 | "multiclass_recall", 97 | "multilabel_accuracy", 98 | "multilabel_auprc", 99 | "multilabel_binned_auprc", 100 | "multilabel_binned_precision_recall_curve", 101 | "multilabel_precision_recall_curve", 102 | "multilabel_recall_at_fixed_precision", 103 | "num_collisions", 104 | "peak_signal_noise_ratio", 105 | "perplexity", 106 | "r2_score", 107 | "reciprocal_rank", 108 | "retrieval_precision", 109 | "retrieval_recall", 110 | "sum", 111 | "throughput", 112 | "topk_multilabel_accuracy", 113 | "weighted_calibration", 114 | "word_error_rate", 115 | "word_information_preserved", 116 | "word_information_lost", 117 | ] 118 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/aggregation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
def _auc_compute(
    x: torch.Tensor, y: torch.Tensor, reorder: bool = False
) -> torch.Tensor:
    """Integrate ``y`` over ``x`` with the trapezoidal rule.
    Args:
        x: x-coordinates
        y: y-coordinates
        reorder: when True, sort each row of ``x`` (stable sort) and permute
            ``y`` the same way before integrating; default value is False
    Return:
        Tensor with one AUC value per row; an empty tensor when ``x`` or ``y``
        has no elements.
    """
    if 0 in (x.numel(), y.numel()):
        return torch.tensor([])

    # Treat 1-D inputs as a single task so the integration is always row-wise.
    xs = x if x.ndim != 1 else x.unsqueeze(0)
    ys = y if y.ndim != 1 else y.unsqueeze(0)

    if reorder:
        xs, order = torch.sort(xs, dim=1, stable=True)
        ys = ys.gather(1, order)

    return torch.trapezoid(ys, xs)


def _auc_update_input_check(x: torch.Tensor, y: torch.Tensor, n_tasks: int = 1) -> None:
    """
    Validate a pair of coordinate tensors.
    Both tensors must be non-empty, have the same shape and, once 1-D inputs
    are viewed as a single row, have ``n_tasks`` rows.
    Args:
        x: x-coordinates
        y: y-coordinates
        n_tasks: Number of tasks that need AUC calculation. Default value is 1.
    Raises:
        ValueError: If any of the conditions above is violated.
    """
    # Shapes as passed by the caller; these are what the messages report.
    size_x, size_y = x.size(), y.size()

    xs = x.unsqueeze(0) if x.ndim == 1 else x
    ys = y.unsqueeze(0) if y.ndim == 1 else y

    if xs.numel() == 0 or ys.numel() == 0:
        raise ValueError(
            f"The `x` and `y` should have atleast 1 element, got shapes {size_x} and {size_y}."
        )

    if xs.size() != ys.size():
        raise ValueError(
            f"Expected the same shape in `x` and `y` tensor but got shapes {size_x} and {size_y}."
        )

    if not (xs.size(0) == n_tasks and ys.size(0) == n_tasks):
        raise ValueError(
            f"Expected `x` dim_1={xs.size(0)} and `y` dim_1={ys.size(0)} have first dimension equals to n_tasks={n_tasks}."
        )


def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
    """Computes Area Under the Curve (AUC) using the trapezoidal rule.
    Args:
        x: x-coordinates
        y: y-coordinates
        reorder: sorts the x input tensor in order, default value is False
    Return:
        Tensor containing AUC score (float), one value per row of the input
    Raises:
        ValueError:
            If ``x`` and ``y`` don't have the same shape.
            If ``x`` or ``y`` doesn't have at least 1 element.
    Example:
        >>> import torch
        >>> from torcheval.metrics.functional.aggregation.auc import auc
        >>> x = torch.tensor([0,.1,.2,.3])
        >>> y = torch.tensor([1,1,1,1])
        >>> auc(x, y)
        tensor([0.3000])
        >>> y = torch.tensor([[0, 4, 0, 4, 3],
                        [1, 1, 2, 1, 1],
                        [4, 3, 1, 4, 4],
                        [1, 0, 0, 3, 0]])
        >>> x = torch.tensor([[0.2535, 0.1138, 0.1324, 0.1887, 0.3117],
                        [0.1434, 0.4404, 0.1100, 0.1178, 0.1883],
                        [0.2344, 0.1743, 0.3110, 0.0393, 0.2410],
                        [0.1381, 0.1564, 0.0320, 0.2220, 0.4515]])
        >>> auc(x, y, reorder=True) # Reorders X and calculates AUC.
        tensor([0.3667, 0.3343, 0.8843, 0.5048])
    """
    # A multi-dimensional ``x`` carries one task per row; a 1-D ``x`` is one task.
    n_tasks = x.size(0) if x.ndim > 1 else 1
    _auc_update_input_check(x, y, n_tasks)
    return _auc_compute(x, y, reorder)
29 | 30 | Examples:: 31 | 32 | >>> import torch 33 | >>> from torcheval.metrics.functional import mean 34 | >>> mean(torch.tensor([2, 3])) 35 | tensor(2.5) 36 | >>> mean(torch.tensor([2, 3]), torch.tensor([0.2, 0.8])) 37 | tensor(2.8) 38 | >>> mean(torch.tensor([2, 3]), 0.5) 39 | tensor(2.5) 40 | >>> mean(torch.tensor([2, 3]), 1) 41 | tensor(2.5) 42 | """ 43 | return _mean_compute(input, weight) 44 | 45 | 46 | def _mean_update( 47 | input: torch.Tensor, weight: float | int | torch.Tensor 48 | ) -> tuple[torch.Tensor, torch.Tensor]: 49 | if isinstance(weight, float) or isinstance(weight, int): 50 | weighted_sum = weight * torch.sum(input) 51 | weights = torch.tensor(float(weight) * torch.numel(input)) 52 | return weighted_sum, weights 53 | elif isinstance(weight, torch.Tensor) and input.size() == weight.size(): 54 | return torch.sum(weight * input), torch.sum(weight) 55 | else: 56 | raise ValueError( 57 | "Weight must be either a float value or a tensor that matches the input tensor size. " 58 | f"Got {weight} instead." 59 | ) 60 | 61 | 62 | def _mean_compute( 63 | input: torch.Tensor, weight: float | int | torch.Tensor 64 | ) -> torch.Tensor: 65 | weighted_sum, weights = _mean_update(input, weight) 66 | return weighted_sum / weights 67 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/aggregation/sum.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | 12 | 13 | @torch.inference_mode() 14 | def sum( 15 | input: torch.Tensor, 16 | weight: float | torch.Tensor = 1.0, 17 | ) -> torch.Tensor: 18 | """ 19 | Compute weighted sum. When weight is not provided, it calculates the unweighted sum. 
20 | Its class version is ``torcheval.metrics.Sum``. 21 | 22 | Args: 23 | input (Tensor): Tensor of input values. 24 | weight(optional): Float or Int or Tensor of input weights. It is default to 1.0. If weight is a Tensor, its size should match the input tensor size. 25 | Raises: 26 | ValueError: If value of weight is neither a ``float`` nor an ``int`` nor a ``torch.Tensor`` that matches the input tensor size. 27 | 28 | Examples:: 29 | 30 | >>> import torch 31 | >>> from torcheval.metrics.functional import sum 32 | >>> sum(torch.tensor([2, 3])) 33 | tensor(5.) 34 | >>> sum(torch.tensor([2, 3]), torch.tensor([0.1, 0.6])) 35 | tensor(2.) 36 | >>> sum(torch.tensor([2, 3]), 0.5) 37 | tensor(2.5) 38 | >>> sum(torch.tensor([2, 3]), 2) 39 | tensor(10.) 40 | """ 41 | return _sum_update(input, weight) 42 | 43 | 44 | def _sum_update( 45 | input: torch.Tensor, weight: float | int | torch.Tensor 46 | ) -> torch.Tensor: 47 | if ( 48 | isinstance(weight, float) 49 | or isinstance(weight, int) 50 | or (isinstance(weight, torch.Tensor) and input.size() == weight.size()) 51 | ): 52 | return (input * weight).sum() 53 | else: 54 | raise ValueError( 55 | "Weight must be either a float value or an int value or a tensor that matches the input tensor size. " 56 | f"Got {weight} instead." 57 | ) 58 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/aggregation/throughput.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
@torch.inference_mode()
def throughput(
    num_processed: int = 0,
    elapsed_time_sec: float = 0.0,
) -> torch.Tensor:
    """
    Calculate the throughput value which is the number of elements processed per second.
    Its class version is ``torcheval.metrics.Throughput``.

    Args:
        num_processed (int): Number of items processed.
        elapsed_time_sec (float): Total elapsed time in seconds to process ``num_processed`` items.
    Raises:
        ValueError:
            If ``num_processed`` is a negative number.
            If ``elapsed_time_sec`` is a non-positive number (this includes the default ``0.0``).

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import throughput
        >>> throughput(64, 2.0)
        tensor(32.)
    """
    return _throughput_compute(num_processed, elapsed_time_sec)


def _throughput_compute(num_processed: int, elapsed_time_sec: float) -> torch.Tensor:
    """Validate the arguments and return ``num_processed / elapsed_time_sec`` as a tensor."""
    if num_processed < 0:
        message = f"Expected num_processed to be a non-negative number, but received {num_processed}."
        raise ValueError(message)
    if elapsed_time_sec <= 0:
        message = f"Expected elapsed_time_sec to be a positive number, but received {elapsed_time_sec}."
        raise ValueError(message)
    items_per_second = num_processed / elapsed_time_sec
    return torch.tensor(items_per_second)
6 | 7 | # pyre-strict 8 | 9 | from torcheval.metrics.functional.classification.accuracy import ( 10 | binary_accuracy, 11 | multiclass_accuracy, 12 | multilabel_accuracy, 13 | topk_multilabel_accuracy, 14 | ) 15 | from torcheval.metrics.functional.classification.auprc import ( 16 | binary_auprc, 17 | multiclass_auprc, 18 | multilabel_auprc, 19 | ) 20 | 21 | from torcheval.metrics.functional.classification.auroc import ( 22 | binary_auroc, 23 | multiclass_auroc, 24 | ) 25 | 26 | from torcheval.metrics.functional.classification.binary_normalized_entropy import ( 27 | binary_normalized_entropy, 28 | ) 29 | from torcheval.metrics.functional.classification.binned_auprc import ( 30 | binary_binned_auprc, 31 | multiclass_binned_auprc, 32 | multilabel_binned_auprc, 33 | ) 34 | from torcheval.metrics.functional.classification.binned_auroc import ( 35 | binary_binned_auroc, 36 | multiclass_binned_auroc, 37 | ) 38 | from torcheval.metrics.functional.classification.binned_precision_recall_curve import ( 39 | binary_binned_precision_recall_curve, 40 | multiclass_binned_precision_recall_curve, 41 | multilabel_binned_precision_recall_curve, 42 | ) 43 | from torcheval.metrics.functional.classification.confusion_matrix import ( 44 | binary_confusion_matrix, 45 | multiclass_confusion_matrix, 46 | ) 47 | from torcheval.metrics.functional.classification.f1_score import ( 48 | binary_f1_score, 49 | multiclass_f1_score, 50 | ) 51 | from torcheval.metrics.functional.classification.precision import ( 52 | binary_precision, 53 | multiclass_precision, 54 | ) 55 | from torcheval.metrics.functional.classification.precision_recall_curve import ( 56 | binary_precision_recall_curve, 57 | multiclass_precision_recall_curve, 58 | multilabel_precision_recall_curve, 59 | ) 60 | from torcheval.metrics.functional.classification.recall import ( 61 | binary_recall, 62 | multiclass_recall, 63 | ) 64 | from torcheval.metrics.functional.classification.recall_at_fixed_precision import ( 65 | 
binary_recall_at_fixed_precision, 66 | multilabel_recall_at_fixed_precision, 67 | ) 68 | 69 | __all__ = [ 70 | "binary_accuracy", 71 | "binary_auprc", 72 | "binary_auroc", 73 | "binary_binned_auprc", 74 | "binary_binned_auroc", 75 | "binary_binned_precision_recall_curve", 76 | "binary_confusion_matrix", 77 | "binary_f1_score", 78 | "binary_normalized_entropy", 79 | "binary_precision", 80 | "binary_precision_recall_curve", 81 | "binary_recall", 82 | "binary_recall_at_fixed_precision", 83 | "multiclass_accuracy", 84 | "multiclass_auprc", 85 | "multiclass_auroc", 86 | "multiclass_binned_auprc", 87 | "multiclass_binned_auroc", 88 | "multiclass_binned_precision_recall_curve", 89 | "multiclass_confusion_matrix", 90 | "multiclass_f1_score", 91 | "multiclass_precision", 92 | "multiclass_precision_recall_curve", 93 | "multiclass_recall", 94 | "multilabel_accuracy", 95 | "multilabel_auprc", 96 | "multilabel_binned_auprc", 97 | "multilabel_binned_precision_recall_curve", 98 | "multilabel_precision_recall_curve", 99 | "multilabel_recall_at_fixed_precision", 100 | "topk_multilabel_accuracy", 101 | ] 102 | __doc_name__ = "Classification Metrics" 103 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/frechet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | import torch 9 | 10 | 11 | def gaussian_frechet_distance( 12 | mu_x: torch.Tensor, cov_x: torch.Tensor, mu_y: torch.Tensor, cov_y: torch.Tensor 13 | ) -> torch.Tensor: 14 | r"""Computes the Fréchet distance between two multivariate normal distributions :cite:`dowson1982frechet`. 15 | 16 | The Fréchet distance is also known as the Wasserstein-2 distance. 
17 | 18 | Concretely, for multivariate Gaussians :math:`X(\mu_X, \cov_X)` 19 | and :math:`Y(\mu_Y, \cov_Y)`, the function computes and returns :math:`F` as 20 | 21 | .. math:: 22 | F(X, Y) = || \mu_X - \mu_Y ||_2^2 23 | + \text{Tr}\left( \cov_X + \cov_Y - 2 \sqrt{\cov_X \cov_Y} \right) 24 | 25 | Args: 26 | mu_x (torch.Tensor): mean :math:`\mu_X` of multivariate Gaussian :math:`X`, with shape `(N,)`. 27 | cov_x (torch.Tensor): covariance matrix :math:`\cov_X` of :math:`X`, with shape `(N, N)`. 28 | mu_y (torch.Tensor): mean :math:`\mu_Y` of multivariate Gaussian :math:`Y`, with shape `(N,)`. 29 | cov_y (torch.Tensor): covariance matrix :math:`\cov_Y` of :math:`Y`, with shape `(N, N)`. 30 | 31 | Returns: 32 | torch.Tensor: the Fréchet distance between :math:`X` and :math:`Y`. 33 | """ 34 | if mu_x.ndim != 1: 35 | msg = f"Input mu_x must be one-dimensional; got dimension {mu_x.ndim}." 36 | raise ValueError(msg) 37 | if mu_y.ndim != 1: 38 | msg = f"Input mu_y must be one-dimensional; got dimension {mu_y.ndim}." 39 | raise ValueError(msg) 40 | if cov_x.ndim != 2: 41 | msg = f"Input cov_x must be two-dimensional; got dimension {cov_x.ndim}." 42 | raise ValueError(msg) 43 | if cov_y.ndim != 2: 44 | msg = f"Input cov_x must be two-dimensional; got dimension {cov_y.ndim}." 45 | raise ValueError(msg) 46 | if mu_x.shape != mu_y.shape: 47 | msg = f"Inputs mu_x and mu_y must have the same shape; got {mu_x.shape} and {mu_y.shape}." 48 | raise ValueError(msg) 49 | if cov_x.shape != cov_y.shape: 50 | msg = f"Inputs cov_x and cov_y must have the same shape; got {cov_x.shape} and {cov_y.shape}." 
51 | raise ValueError(msg) 52 | 53 | a = (mu_x - mu_y).square().sum() 54 | b = cov_x.trace() + cov_y.trace() 55 | c = torch.linalg.eigvals(cov_x @ cov_y).sqrt().real.sum() 56 | return a + b - 2 * c 57 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/image/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from torcheval.metrics.functional.image.psnr import peak_signal_noise_ratio 10 | 11 | __all__ = ["peak_signal_noise_ratio"] 12 | __doc_name__ = "Image Metrics" 13 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/image/psnr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | 12 | 13 | @torch.inference_mode() 14 | def peak_signal_noise_ratio( 15 | input: torch.Tensor, 16 | target: torch.Tensor, 17 | data_range: float | None = None, 18 | ) -> torch.Tensor: 19 | """ 20 | Compute the peak signal-to-noise ratio between two images. 21 | It's class version is `torcheval.metrics.PeakSignalNoiseRatio` 22 | 23 | Args: 24 | input (Tensor): Input image ``(N, C, H, W)``. 25 | target (Tensor): Target image ``(N, C, H, W)``. 26 | data_range (float): the range of the input images. Default: None. 27 | If None, the input range computed from the target data ``(target.max() - targert.min())``. 
28 | Examples:: 29 | 30 | >>> import torch 31 | >>> from torcheval.metrics.functional import peak_signal_noise_ratio 32 | >>> input = torch.tensor([[0.1, 0.2], [0.3, 0.4]]) 33 | >>> target = input * 0.9 34 | >>> peak_signal_noise_ratio(input, target) 35 | tensor(19.8767) 36 | """ 37 | _psnr_param_check(data_range) 38 | 39 | if data_range is None: 40 | data_range_tensor = torch.max(target) - torch.min(target) 41 | else: 42 | data_range_tensor = torch.tensor(data=data_range, device=target.device) 43 | 44 | sum_square_error, num_observations = _psnr_update(input, target) 45 | psnr = _psnr_compute(sum_square_error, num_observations, data_range_tensor) 46 | return psnr 47 | 48 | 49 | def _psnr_param_check(data_range: float | None) -> None: 50 | # Check matching shapes 51 | if data_range is not None: 52 | if type(data_range) is not float: 53 | raise ValueError("`data_range needs to be either `None` or `float`.") 54 | if data_range <= 0: 55 | raise ValueError("`data_range` needs to be positive.") 56 | 57 | 58 | def _psnr_input_check(input: torch.Tensor, target: torch.Tensor) -> None: 59 | # Check matching shapes 60 | if input.shape != target.shape: 61 | raise ValueError( 62 | "The `input` and `target` must have the same shape, " 63 | f"got shapes {input.shape} and {target.shape}." 
64 | ) 65 | 66 | 67 | def _psnr_update( 68 | input: torch.Tensor, target: torch.Tensor 69 | ) -> tuple[torch.Tensor, torch.Tensor]: 70 | _psnr_input_check(input, target) 71 | sum_squared_error = torch.sum(torch.pow(input - target, 2)) 72 | num_observations = torch.tensor(target.numel(), device=target.device) 73 | return sum_squared_error, num_observations 74 | 75 | 76 | def _psnr_compute( 77 | sum_square_error: torch.Tensor, 78 | num_observations: torch.Tensor, 79 | data_range: torch.Tensor, 80 | ) -> torch.Tensor: 81 | mse = sum_square_error / num_observations 82 | psnr = 10 * torch.log10(torch.pow(data_range, 2) / mse) 83 | 84 | return psnr 85 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/ranking/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | 9 | from torcheval.metrics.functional.ranking.click_through_rate import click_through_rate 10 | from torcheval.metrics.functional.ranking.frequency import frequency_at_k 11 | 12 | from torcheval.metrics.functional.ranking.hit_rate import hit_rate 13 | 14 | from torcheval.metrics.functional.ranking.num_collisions import num_collisions 15 | 16 | from torcheval.metrics.functional.ranking.reciprocal_rank import reciprocal_rank 17 | 18 | from torcheval.metrics.functional.ranking.retrieval_precision import retrieval_precision 19 | 20 | from torcheval.metrics.functional.ranking.retrieval_recall import retrieval_recall 21 | 22 | from torcheval.metrics.functional.ranking.weighted_calibration import ( 23 | weighted_calibration, 24 | ) 25 | 26 | __all__ = [ 27 | "click_through_rate", 28 | "frequency_at_k", 29 | "hit_rate", 30 | "num_collisions", 31 | "reciprocal_rank", 32 | "weighted_calibration", 33 | "retrieval_precision", 34 | "retrieval_recall", 35 | ] 36 | __doc_name__ = "Ranking Metrics" 37 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/ranking/click_through_rate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | 12 | 13 | @torch.inference_mode() 14 | def click_through_rate( 15 | input: torch.Tensor, 16 | weights: torch.Tensor | None = None, 17 | *, 18 | num_tasks: int = 1, 19 | ) -> torch.Tensor: 20 | """ 21 | Compute the click through rate given a click events. 22 | Its class version is ``torcheval.metrics.ClickThroughRate``. 
23 | 24 | Args: 25 | input (Tensor): Series of values representing user click (1) or skip (0) 26 | of shape (num_events) or (num_objectives, num_events). 27 | weights (Tensor, Optional): Weights for each event, tensor with the same shape as input. 28 | num_tasks (int): Number of tasks that need weighted_calibration calculation. Default value 29 | is 1. 30 | 31 | Examples:: 32 | 33 | >>> import torch 34 | >>> from torcheval.metrics.functional import click_through_rate 35 | >>> input = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1]) 36 | >>> click_through_rate(input) 37 | tensor(0.5) 38 | >>> weights = torch.tensor([1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]) 39 | >>> click_through_rate(input, weights) 40 | tensor(0.58333) 41 | >>> input = torch.tensor([[0, 1, 0, 1], [1, 0, 0, 1]]) 42 | >>> weights = torch.tensor([[1.0, 2.0, 1.0, 2.0],[1.0, 2.0, 1.0, 1.0]]) 43 | >>> click_through_rate(input, weights, num_tasks=2) 44 | tensor([0.6667, 0.4]) 45 | """ 46 | if weights is None: 47 | weights = 1.0 48 | click_total, weight_total = _click_through_rate_update( 49 | input, weights, num_tasks=num_tasks 50 | ) 51 | return _click_through_rate_compute(click_total, weight_total) 52 | 53 | 54 | def _click_through_rate_update( 55 | input: torch.Tensor, 56 | weights: torch.Tensor | float | int = 1.0, 57 | *, 58 | num_tasks: int, 59 | ) -> tuple[torch.Tensor, torch.Tensor]: 60 | _click_through_rate_input_check(input, weights, num_tasks=num_tasks) 61 | if isinstance(weights, torch.Tensor): 62 | weights = weights.type(torch.float) 63 | click_total = (input * weights).sum(-1) 64 | weight_total = weights.sum(-1) 65 | else: 66 | click_total = weights * input.sum(-1).type(torch.float) 67 | weight_total = weights * input.size(-1) * torch.ones_like(click_total) 68 | 69 | return click_total, weight_total 70 | 71 | 72 | def _click_through_rate_compute( 73 | click_total: torch.Tensor, 74 | weight_total: torch.Tensor, 75 | ) -> torch.Tensor: 76 | # epsilon is a performant solution to divide by zero errors 
when weight_total = 0.0 77 | # Since click_total = input*weights, weights = 0.0 implies 0.0/(0.0 + eps) = 0.0 78 | eps = torch.finfo(weight_total.dtype).tiny 79 | return click_total / (weight_total + eps) 80 | 81 | 82 | def _click_through_rate_input_check( 83 | input: torch.Tensor, 84 | weights: torch.Tensor | float | int, 85 | *, 86 | num_tasks: int, 87 | ) -> None: 88 | if input.ndim != 1 and input.ndim != 2: 89 | raise ValueError( 90 | f"`input` should be a one or two dimensional tensor, got shape {input.shape}." 91 | ) 92 | if isinstance(weights, torch.Tensor) and weights.shape != input.shape: 93 | raise ValueError( 94 | f"tensor `weights` should have the same shape as tensor `input`, got shapes {weights.shape} and {input.shape}, respectively." 95 | ) 96 | if num_tasks == 1: 97 | if len(input.shape) > 1: 98 | raise ValueError( 99 | f"`num_tasks = 1`, `input` is expected to be one-dimensional tensor, but got shape ({input.shape})." 100 | ) 101 | elif len(input.shape) == 1 or input.shape[0] != num_tasks: 102 | raise ValueError( 103 | f"`num_tasks = {num_tasks}`, `input`'s shape is expected to be ({num_tasks}, num_samples), but got shape ({input.shape})." 104 | ) 105 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/ranking/frequency.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | 12 | 13 | @torch.inference_mode() 14 | def frequency_at_k( 15 | input: torch.Tensor, 16 | k: float, 17 | ) -> torch.Tensor: 18 | """ 19 | Calculate the frequency given a list of frequencies and threshold k. 20 | Generate a binary list to indicate if frequencies is less than k. 
21 | 22 | Args: 23 | input (Tensor): Predicted unnormalized scores (often referred to as logits). 24 | k (float): Threshold of the frequency. k should not negative value. 25 | 26 | Example: 27 | >>> import torch 28 | >>> from torcheval.metrics.functional import frequency 29 | >>> input = torch.tensor([0.3, 0.1, 0.6]) 30 | >>> frequency(input, k=0.5) 31 | tensor([1.0000, 1.0000, 0.0000]) 32 | """ 33 | _frequency_input_check(input, k) 34 | 35 | return (input < k).float() 36 | 37 | 38 | def _frequency_input_check(input: torch.Tensor, k: float) -> None: 39 | if input.ndim != 1: 40 | raise ValueError( 41 | f"input should be a one-dimensional tensor, got shape {input.shape}." 42 | ) 43 | if k < 0: 44 | raise ValueError(f"k should not be negative, got {k}.") 45 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/ranking/hit_rate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | 12 | 13 | @torch.inference_mode() 14 | def hit_rate( 15 | input: torch.Tensor, 16 | target: torch.Tensor, 17 | *, 18 | k: int | None = None, 19 | ) -> torch.Tensor: 20 | """ 21 | Compute the hit rate of the correct class among the top predicted classes. 22 | Its class version is ``torcheval.metrics.HitRate``. 23 | 24 | Args: 25 | input (Tensor): Predicted unnormalized scores (often referred to as logits) or 26 | class probabilities of shape (num_samples, num_classes). 27 | target (Tensor): Ground truth class indices of shape (num_samples,). 28 | k (int, optional): Number of top predicted classes to be considered. 29 | If k is None, all classes are considered and a hit rate of 1.0 is returned. 
30 | 31 | Examples:: 32 | 33 | >>> import torch 34 | >>> from torcheval.metrics.functional import hit_rate 35 | >>> input = torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3], [0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]) 36 | >>> target = torch.tensor([2, 1, 1, 0]) 37 | >>> hit_rate(input, target, k=2) 38 | tensor([1.0000, 0.0000, 0.0000, 1.0000]) 39 | """ 40 | _hit_rate_input_check(input, target, k) 41 | if k is None or k >= input.size(dim=-1): 42 | return input.new_ones(target.size()) 43 | 44 | y_score = torch.gather(input, dim=-1, index=target.unsqueeze(dim=-1)) 45 | rank = torch.gt(input, y_score).sum(dim=-1) 46 | return (rank < k).float() 47 | 48 | 49 | def _hit_rate_input_check( 50 | input: torch.Tensor, target: torch.Tensor, k: int | None = None 51 | ) -> None: 52 | if target.ndim != 1: 53 | raise ValueError( 54 | f"target should be a one-dimensional tensor, got shape {target.shape}." 55 | ) 56 | if input.ndim != 2: 57 | raise ValueError( 58 | f"input should be a two-dimensional tensor, got shape {input.shape}." 59 | ) 60 | if input.shape[0] != target.shape[0]: 61 | raise ValueError( 62 | "`input` and `target` should have the same minibatch dimension, ", 63 | f"got shapes {input.shape} and {target.shape}, respectively.", 64 | ) 65 | if k is not None and k <= 0: 66 | raise ValueError(f"k should be None or positive, got {k}.") 67 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/ranking/num_collisions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
# File: torcheval/metrics/functional/ranking/num_collisions.py
import torch


@torch.inference_mode()
def num_collisions(input: torch.Tensor) -> torch.Tensor:
    """
    Compute the number of collisions given a list of input(ids).

    For every position, the result is the number of OTHER positions holding the same id.

    Args:
        input (Tensor): a tensor of input ids (num_samples, ).

    Returns:
        Tensor: integer tensor of shape (num_samples, ) with the collision count of each id.

    Raises:
        ValueError: If ``input`` is not one-dimensional or is not a signed integer tensor.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import num_collisions
        >>> input = torch.tensor([3, 4, 2, 3])
        >>> num_collisions(input)
        tensor([1, 0, 0, 1])
        >>> input = torch.tensor([3, 4, 1, 3, 1, 1, 5])
        >>> num_collisions(input)
        tensor([1, 0, 2, 1, 2, 2, 0])
    """
    _num_collisions_input_check(input)

    # Count each distinct id once and scatter the counts back to the original
    # positions, instead of materialising an (n, n) pairwise-equality matrix.
    _, inverse, counts = torch.unique(input, return_inverse=True, return_counts=True)
    # Each id matches itself, so subtract one to count only the other occurrences.
    return counts[inverse] - 1


def _num_collisions_input_check(input: torch.Tensor) -> None:
    """Raise ``ValueError`` unless ``input`` is a one-dimensional signed integer tensor."""
    if input.ndim != 1:
        raise ValueError(
            f"input should be a one-dimensional tensor, got shape {input.shape}."
        )

    # torch.int is an alias of torch.int32, so four dtypes cover the accepted set.
    if input.dtype not in (
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    ):
        raise ValueError(f"input should be an integer tensor, got {input.dtype}.")
6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | 12 | 13 | @torch.inference_mode() 14 | def reciprocal_rank( 15 | input: torch.Tensor, 16 | target: torch.Tensor, 17 | *, 18 | k: int | None = None, 19 | ) -> torch.Tensor: 20 | """ 21 | Compute the reciprocal rank of the correct class among the top predicted classes. 22 | Its class version is ``torcheval.metrics.ReciprocalRank``. 23 | 24 | Args: 25 | input (Tensor): Predicted unnormalized scores (often referred to as logits) or 26 | class probabilities of shape (num_samples, num_classes). 27 | target (Tensor): Ground truth class indices of shape (num_samples,). 28 | k (int, optional): Number of top class probabilities to be considered. 29 | 30 | Examples:: 31 | 32 | >>> import torch 33 | >>> from torcheval.metrics.functional import reciprocal_rank 34 | >>> input = torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3], [0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]) 35 | >>> target = torch.tensor([2, 1, 1, 0]) 36 | >>> reciprocal_rank(input, target) 37 | tensor([1.0000, 0.3333, 0.3333, 0.5000]) 38 | >>> reciprocal_rank(input, target, k=2) 39 | tensor([1.0000, 0.0000, 0.0000, 0.5000]) 40 | """ 41 | _reciprocal_rank_input_check(input, target) 42 | 43 | y_score = torch.gather(input, dim=-1, index=target.unsqueeze(dim=-1)) 44 | rank = torch.gt(input, y_score).sum(dim=-1) 45 | score = torch.reciprocal(rank + 1.0) 46 | if k is not None: 47 | score[rank >= k] = 0.0 48 | return score 49 | 50 | 51 | def _reciprocal_rank_input_check(input: torch.Tensor, target: torch.Tensor) -> None: 52 | if target.ndim != 1: 53 | raise ValueError( 54 | f"target should be a one-dimensional tensor, got shape {target.shape}." 55 | ) 56 | if input.ndim != 2: 57 | raise ValueError( 58 | f"input should be a two-dimensional tensor, got shape {input.shape}." 
59 | ) 60 | if input.shape[0] != target.shape[0]: 61 | raise ValueError( 62 | "`input` and `target` should have the same minibatch dimension, ", 63 | f"got shapes {input.shape} and {target.shape}, respectively.", 64 | ) 65 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/regression/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states. 10 | 11 | from torcheval.metrics.functional.regression.mean_squared_error import ( 12 | mean_squared_error, 13 | ) 14 | 15 | from torcheval.metrics.functional.regression.r2_score import r2_score 16 | 17 | __all__ = ["mean_squared_error", "r2_score"] 18 | __doc_name__ = "Regression Metrics" 19 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/statistical/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states. 
8 | 9 | from torcheval.metrics.functional.statistical.wasserstein import wasserstein_1d 10 | 11 | __all__ = ["wasserstein_1d"] 12 | __doc_name__ = "Statistical Metrics" 13 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/tensor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | 12 | 13 | def _riemann_integral(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 14 | """Riemann integral approximates the area of each cell with a rectangle positioned at the egde. 15 | It is conventionally used rather than trapezoid approximation, which uses a rectangle positioned in the 16 | center""" 17 | return -torch.sum((x[1:] - x[:-1]) * y[:-1]) 18 | 19 | 20 | def _create_threshold_tensor( 21 | threshold: int | list[float] | torch.Tensor, 22 | device: torch.device, 23 | ) -> torch.Tensor: 24 | """ 25 | Creates a threshold tensor from an integer, a list or a tensor. 26 | If `threshold` is an integer n, returns a Tensor with values [0, 1/(n-1), 2/(n-1), ..., (n-2)/(n-1), 1]. 27 | If `threshold` is a list, returns the list converted to a Tensor. 28 | Otherwise, returns the tensor itself. 29 | """ 30 | if isinstance(threshold, int): 31 | threshold = torch.linspace(0, 1.0, threshold, device=device) 32 | elif isinstance(threshold, list): 33 | threshold = torch.tensor(threshold, device=device) 34 | return threshold 35 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/text/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from torcheval.metrics.functional.text.bleu import bleu_score 10 | 11 | from torcheval.metrics.functional.text.perplexity import perplexity 12 | 13 | from torcheval.metrics.functional.text.word_error_rate import word_error_rate 14 | 15 | from torcheval.metrics.functional.text.word_information_lost import ( 16 | word_information_lost, 17 | ) 18 | 19 | from torcheval.metrics.functional.text.word_information_preserved import ( 20 | word_information_preserved, 21 | ) 22 | 23 | __all__ = [ 24 | "bleu_score", 25 | "perplexity", 26 | "word_error_rate", 27 | "word_information_preserved", 28 | "word_information_lost", 29 | ] 30 | __doc_name__ = "Text Metrics" 31 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/text/helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | 12 | 13 | def _edit_distance( 14 | prediction_tokens: list[str], 15 | reference_tokens: list[str], 16 | ) -> int: 17 | """ 18 | Dynamic programming algorithm to compute the edit distance between two word sequences. 
19 | 20 | Args: 21 | prediction_tokens (List[str]): A tokenized predicted sentence 22 | reference_tokens (List[str]): A tokenized reference sentence 23 | """ 24 | dp = [[0] * (len(reference_tokens) + 1) for _ in range(len(prediction_tokens) + 1)] 25 | for i in range(len(prediction_tokens) + 1): 26 | dp[i][0] = i 27 | for j in range(len(reference_tokens) + 1): 28 | dp[0][j] = j 29 | for i in range(1, len(prediction_tokens) + 1): 30 | for j in range(1, len(reference_tokens) + 1): 31 | if prediction_tokens[i - 1] == reference_tokens[j - 1]: 32 | dp[i][j] = dp[i - 1][j - 1] 33 | else: 34 | dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1 35 | return dp[-1][-1] 36 | 37 | 38 | def _get_errors_and_totals( 39 | input: str | list[str], 40 | target: str | list[str], 41 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 42 | """ 43 | Calculate the edit distance, max length and lengths of predicted and reference word sequences. 44 | 45 | Args: 46 | input (str, List[str]): Predicted word sequence(s) to score as a string or list of strings. 47 | target (str, List[str]): Reference word sequence(s) as a string or list of strings. 
48 | """ 49 | if isinstance(input, str): 50 | input = [input] 51 | if isinstance(target, str): 52 | target = [target] 53 | max_total = torch.tensor(0.0, dtype=torch.float64) 54 | errors = torch.tensor(0.0, dtype=torch.float64) 55 | target_total = torch.tensor(0.0, dtype=torch.float64) 56 | input_total = torch.tensor(0.0, dtype=torch.float64) 57 | for ipt, tgt in zip(input, target): 58 | input_tokens = ipt.split() 59 | target_tokens = tgt.split() 60 | errors += _edit_distance(input_tokens, target_tokens) 61 | target_total += len(target_tokens) 62 | input_total += len(input_tokens) 63 | max_total += max(len(target_tokens), len(input_tokens)) 64 | 65 | return errors, max_total, target_total, input_total 66 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/text/word_information_lost.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | 12 | from torcheval.metrics.functional.text.helper import _get_errors_and_totals 13 | 14 | 15 | def _wil_update( 16 | input: str | list[str], 17 | target: str | list[str], 18 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 19 | """Update the wil score with the current set of references and predictions. 
20 | Args: 21 | input: Transcription(s) to score as a string or list of strings 22 | target: Reference(s) for each speech input as a string or list of strings 23 | Returns: 24 | Number of correct words 25 | Number of words overall references 26 | Number of words overall predictions 27 | """ 28 | if isinstance(input, str): 29 | input = [input] 30 | if isinstance(target, str): 31 | target = [target] 32 | assert ( 33 | len(input) == len(target) 34 | ), f"Arguments must contain the same number of strings, but got len(input)={len(input)} and len(target)={len(target)}" 35 | errors, max_total, target_total, input_total = _get_errors_and_totals(input, target) 36 | return errors - max_total, target_total, input_total 37 | 38 | 39 | def _wil_compute( 40 | correct_total: torch.Tensor, target_total: torch.Tensor, preds_total: torch.Tensor 41 | ) -> torch.Tensor: 42 | """Compute the Word Information Lost. 43 | Args: 44 | correct_total: Number of correct words 45 | target_total: Number of words overall references 46 | preds_total: Number of words overall prediction 47 | Returns: 48 | Word Information Lost score 49 | """ 50 | return 1 - ((correct_total / target_total) * (correct_total / preds_total)) 51 | 52 | 53 | @torch.inference_mode() 54 | def word_information_lost( 55 | input: str | list[str], 56 | target: str | list[str], 57 | ) -> torch.Tensor: 58 | """Word Information Lost rate is a metric of the performance of an automatic speech recognition system. This 59 | value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better 60 | the performance of the ASR system with a Word Information Lost rate of 0 being a perfect score. 61 | 62 | Its class version is ``torcheval.metrics.WordInformationLost``. 
63 | 64 | Args: 65 | input: Transcription(s) to score as a string or list of strings 66 | target: Reference(s) for each speech input as a string or list of strings 67 | Returns: 68 | Word Information Lost rate 69 | Examples: 70 | >>> from torcheval.metrics.functional import word_information_lost 71 | >>> input = ["this is the prediction", "there is an other sample"] 72 | >>> target = ["this is the reference", "there is another one"] 73 | >>> word_information_lost(input, target) 74 | tensor(0.6528) 75 | """ 76 | correct_total, target_total, preds_total = _wil_update(input, target) 77 | return _wil_compute(correct_total, target_total, preds_total) 78 | -------------------------------------------------------------------------------- /torcheval/metrics/functional/text/word_information_preserved.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | 10 | import torch 11 | 12 | from torcheval.metrics.functional.text.helper import _get_errors_and_totals 13 | 14 | 15 | @torch.inference_mode() 16 | def word_information_preserved( 17 | input: str | list[str], 18 | target: str | list[str], 19 | ) -> torch.Tensor: 20 | """ 21 | Compute the word information preserved score of the predicted word sequence(s) against the reference word sequence(s). 22 | Its class version is ``torcheval.metrics.WordInformationPreserved``. 23 | 24 | Args: 25 | input (str, List[str]): Predicted word sequence(s) to score as a string or list of strings. 26 | target (str, List[str]): Reference word sequence(s) as a string or list of strings. 
27 | 28 | Examples: 29 | 30 | >>> import torch 31 | >>> from torcheval.metrics.functional import word_information_preserved 32 | >>> input = ["hello world", "welcome to the facebook"] 33 | >>> target = ["hello metaverse", "welcome to meta"] 34 | >>> word_information_preserved(input, target) 35 | tensor(0.3) 36 | >>> input = ["this is the prediction", "there is an other sample"] 37 | >>> target = ["this is the reference", "there is another one"] 38 | >>> word_information_preserved(input, target) 39 | tensor(0.3472) 40 | """ 41 | correct_total, target_total, input_total = _word_information_preserved_update( 42 | input, target 43 | ) 44 | return _word_information_preserved_compute(correct_total, target_total, input_total) 45 | 46 | 47 | def _word_information_preserved_update( 48 | input: str | list[str], 49 | target: str | list[str], 50 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 51 | """ 52 | Update the word information preserved score with current set of predictions and references. 53 | 54 | Args: 55 | input (str, List[str]): Predicted word sequence(s) to score as a string or list of strings. 56 | target (str, List[str]): Reference word sequence(s) as a string or list of strings. 57 | """ 58 | _word_information_preserved_input_check(input, target) 59 | errors, max_total, target_total, input_total = _get_errors_and_totals(input, target) 60 | 61 | return max_total - errors, target_total, input_total 62 | 63 | 64 | def _word_information_preserved_compute( 65 | correct_total: torch.Tensor, target_total: torch.Tensor, input_total: torch.Tensor 66 | ) -> torch.Tensor: 67 | """ 68 | Return the word information preserved score 69 | 70 | Args: 71 | correct_total (Tensor): number of words that are correctly predicted, summed over all samples 72 | target_total (Tensor): length of reference sequence, summed over all samples. 73 | input_total (Tensor): length of predicted sequence, summed over all samples. 
74 | """ 75 | return (correct_total / target_total) * (correct_total / input_total) 76 | 77 | 78 | def _word_information_preserved_input_check( 79 | input: str | list[str], 80 | target: str | list[str], 81 | ) -> None: 82 | if type(input) != type(target): 83 | raise ValueError( 84 | f"input and target should have the same type, got {type(input)} and {type(target)}." 85 | ) 86 | if type(input) == list: 87 | if len(input) != len(target): 88 | raise ValueError( 89 | f"input and target lists should have the same length, got {len(input)} and {len(target)}", 90 | ) 91 | -------------------------------------------------------------------------------- /torcheval/metrics/image/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from torcheval.metrics.image.fid import FrechetInceptionDistance 10 | from torcheval.metrics.image.psnr import PeakSignalNoiseRatio 11 | from torcheval.metrics.image.ssim import StructuralSimilarity 12 | 13 | __all__ = [ 14 | "FrechetInceptionDistance", 15 | "PeakSignalNoiseRatio", 16 | "StructuralSimilarity", 17 | ] 18 | __doc_name__ = "Image Metrics" 19 | -------------------------------------------------------------------------------- /torcheval/metrics/ranking/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
# --- torcheval/metrics/ranking/hit_rate.py ---

THitRate = TypeVar("THitRate")


class HitRate(Metric[torch.Tensor]):
    """
    Compute the hit rate of the correct class among the top predicted classes.
    Its functional version is :func:`torcheval.metrics.functional.hit_rate`.

    Args:
        k (int, optional): Number of top class probabilities to be considered.
            If k is None, all classes are considered and a hit rate of 1.0 is returned.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import HitRate

        >>> metric = HitRate()
        >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1]))
        >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0]))
        >>> metric.compute()
        tensor([1., 1., 1., 1.])

        >>> metric = HitRate(k=2)
        >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1]))
        >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0]))
        >>> metric.compute()
        tensor([1., 0., 0., 1.])
    """

    def __init__(
        self: THitRate,
        *,
        k: int | None = None,
        device: torch.device | None = None,
    ) -> None:
        super().__init__(device=device)
        self.k = k
        # One tensor of per-sample hit rates is stored per ``update()`` call.
        self._add_state("scores", [])

    @torch.inference_mode()
    # pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently.
    def update(self: THitRate, input: torch.Tensor, target: torch.Tensor) -> THitRate:
        """
        Record the hit rate of one batch of predictions.

        Args:
            input (Tensor): Predicted unnormalized scores (often referred to as logits) or
                class probabilities of shape (num_samples, num_classes).
            target (Tensor): Ground truth class indices of shape (num_samples,).
        """
        batch_scores = hit_rate(input, target, k=self.k)
        self.scores.append(batch_scores)
        return self

    @torch.inference_mode()
    def compute(self: THitRate) -> torch.Tensor:
        """
        Return the concatenated hit rate scores. If no ``update()`` calls are made before
        ``compute()`` is called, return an empty tensor.
        """
        if self.scores:
            return torch.cat(self.scores, dim=0)
        return torch.empty(0)

    @torch.inference_mode()
    def merge_state(self: THitRate, metrics: Iterable[THitRate]) -> THitRate:
        """
        Merge the metric state with its counterparts from other metric instances.

        Args:
            metrics (Iterable[Metric]): metric instances whose states are to be merged.
        """
        for other in metrics:
            if not other.scores:
                continue
            # Collapse the other instance's batches and move them to this device.
            self.scores.append(torch.cat(other.scores).to(self.device))
        return self

    @torch.inference_mode()
    def _prepare_for_merge_state(self: THitRate) -> None:
        # Collapse the accumulated batches into a single tensor before merging.
        if not self.scores:
            return
        self.scores = [torch.cat(self.scores)]


# --- torcheval/metrics/ranking/reciprocal_rank.py ---

TReciprocalRank = TypeVar("TReciprocalRank")


class ReciprocalRank(Metric[torch.Tensor]):
    """
    Compute the reciprocal rank of the correct class among the top predicted classes.
    Its functional version is :func:`torcheval.metrics.functional.reciprocal_rank`.

    Args:
        k (int, optional): Number of top class probabilities to be considered.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import ReciprocalRank

        >>> metric = ReciprocalRank()
        >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1]))
        >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0]))
        >>> metric.compute()
        tensor([1.0000, 0.3333, 0.3333, 0.5000])

        >>> metric = ReciprocalRank(k=2)
        >>> metric.update(torch.tensor([[0.3, 0.1, 0.6], [0.5, 0.2, 0.3]]), torch.tensor([2, 1]))
        >>> metric.update(torch.tensor([[0.2, 0.1, 0.7], [0.3, 0.3, 0.4]]), torch.tensor([1, 0]))
        >>> metric.compute()
        tensor([1.0000, 0.0000, 0.0000, 0.5000])
    """

    def __init__(
        self: TReciprocalRank,
        *,
        k: int | None = None,
        device: torch.device | None = None,
    ) -> None:
        super().__init__(device=device)
        self.k = k
        # One tensor of per-sample reciprocal ranks is stored per ``update()`` call.
        self._add_state("scores", [])

    @torch.inference_mode()
    # pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently.
    def update(
        self: TReciprocalRank, input: torch.Tensor, target: torch.Tensor
    ) -> TReciprocalRank:
        """
        Record the reciprocal rank of one batch of predictions.

        Args:
            input (Tensor): Predicted unnormalized scores (often referred to as logits) or
                class probabilities of shape (num_samples, num_classes).
            target (Tensor): Ground truth class indices of shape (num_samples,).
        """
        batch_scores = reciprocal_rank(input, target, k=self.k)
        self.scores.append(batch_scores)
        return self

    @torch.inference_mode()
    def compute(self: TReciprocalRank) -> torch.Tensor:
        """
        Return the concatenated reciprocal rank scores. If no ``update()`` calls are made before
        ``compute()`` is called, return an empty tensor.
        """
        if self.scores:
            return torch.cat(self.scores, dim=0)
        return torch.empty(0)

    @torch.inference_mode()
    def merge_state(
        self: TReciprocalRank, metrics: Iterable[TReciprocalRank]
    ) -> TReciprocalRank:
        """
        Merge the metric state with its counterparts from other metric instances.

        Args:
            metrics (Iterable[Metric]): metric instances whose states are to be merged.
        """
        for other in metrics:
            if not other.scores:
                continue
            # Collapse the other instance's batches and move them to this device.
            self.scores.append(torch.cat(other.scores).to(self.device))
        return self

    @torch.inference_mode()
    def _prepare_for_merge_state(self: TReciprocalRank) -> None:
        # Collapse the accumulated batches into a single tensor before merging.
        if not self.scores:
            return
        self.scores = [torch.cat(self.scores)]
8 | 9 | from torcheval.metrics.statistical.wasserstein import Wasserstein1D 10 | 11 | __all__ = ["Wasserstein1D"] 12 | __doc_name__ = "Statistical Metrics" 13 | -------------------------------------------------------------------------------- /torcheval/metrics/text/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from torcheval.metrics.text.bleu import BLEUScore 10 | from torcheval.metrics.text.perplexity import Perplexity 11 | from torcheval.metrics.text.word_error_rate import WordErrorRate 12 | from torcheval.metrics.text.word_information_lost import WordInformationLost 13 | from torcheval.metrics.text.word_information_preserved import WordInformationPreserved 14 | 15 | __all__ = [ 16 | "BLEUScore", 17 | "Perplexity", 18 | "WordErrorRate", 19 | "WordInformationLost", 20 | "WordInformationPreserved", 21 | ] 22 | __doc_name__ = "Text Metrics" 23 | -------------------------------------------------------------------------------- /torcheval/metrics/text/word_error_rate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | # pyre-ignore-all-errors[16]: Undefined attribute of metric states. 
# --- torcheval/metrics/text/word_error_rate.py ---

TWordErrorRate = TypeVar("TWordErrorRate")


class WordErrorRate(Metric[torch.Tensor]):
    """
    Compute the word error rate of the predicted word sequence(s) with the reference word sequence(s).
    Its functional version is :func:`torcheval.metrics.functional.word_error_rate`.

    Examples:

        >>> import torch
        >>> from torcheval.metrics import WordErrorRate

        >>> metric = WordErrorRate()
        >>> metric.update(["this is the prediction", "there is an other sample"],
        ["this is the reference", "there is another one"])
        >>> metric.compute()
        tensor(0.5)

        >>> metric = WordErrorRate()
        >>> metric.update(["this is the prediction", "there is an other sample"],
        ["this is the reference", "there is another one"])
        >>> metric.update(["hello world", "welcome to the facebook"],
        ["hello metaverse", "welcome to meta"])
        >>> metric.compute()
        tensor(0.53846)
    """

    def __init__(
        self: TWordErrorRate,
        *,
        device: torch.device | None = None,
    ) -> None:
        super().__init__(device=device)
        # Two scalar float accumulators, registered in this order.
        for state_name in ("errors", "total"):
            self._add_state(
                state_name, torch.tensor(0, dtype=torch.float, device=self.device)
            )

    @torch.inference_mode()
    # pyre-ignore[14]: `update` overrides method defined in `Metric` inconsistently.
    def update(
        self: TWordErrorRate,
        input: str | list[str],
        target: str | list[str],
    ) -> TWordErrorRate:
        """
        Accumulate the error count and total returned by the functional update step.

        Args:
            input (str, List[str]): Predicted word sequence(s) to score as a string or list of strings.
            target (str, List[str]): Reference word sequence(s) as a string or list of strings.
        """
        batch_errors, batch_total = _word_error_rate_update(input, target)
        self.errors += batch_errors
        self.total += batch_total
        return self

    @torch.inference_mode()
    def compute(self: TWordErrorRate) -> torch.Tensor:
        """
        Return the word error rate score
        """
        return _word_error_rate_compute(self.errors, self.total)

    @torch.inference_mode()
    def merge_state(
        self: TWordErrorRate,
        metrics: Iterable[TWordErrorRate],
    ) -> TWordErrorRate:
        """
        Merge the metric state with its counterparts from other metric instances.

        Args:
            metrics (Iterable[Metric]): metric instances whose states are to be merged.
        """
        for other in metrics:
            # Move each accumulator to this instance's device before adding.
            self.errors += other.errors.to(self.device)
            self.total += other.total.to(self.device)
        return self
6 | 7 | # pyre-strict 8 | 9 | from torcheval.metrics.window.auroc import WindowedBinaryAUROC 10 | from torcheval.metrics.window.click_through_rate import WindowedClickThroughRate 11 | from torcheval.metrics.window.mean_squared_error import WindowedMeanSquaredError 12 | from torcheval.metrics.window.normalized_entropy import WindowedBinaryNormalizedEntropy 13 | from torcheval.metrics.window.weighted_calibration import WindowedWeightedCalibration 14 | 15 | __all__ = [ 16 | "WindowedBinaryAUROC", 17 | "WindowedBinaryNormalizedEntropy", 18 | "WindowedClickThroughRate", 19 | "WindowedMeanSquaredError", 20 | "WindowedWeightedCalibration", 21 | ] 22 | __doc_name__ = "Windowed Metrics" 23 | -------------------------------------------------------------------------------- /torcheval/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/torcheval/5736efbbd6642ab8f42689f4783d271780a26432/torcheval/py.typed -------------------------------------------------------------------------------- /torcheval/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from torcheval.utils.random_data import ( 10 | get_rand_data_binary, 11 | get_rand_data_binned_binary, 12 | get_rand_data_multiclass, 13 | ) 14 | 15 | __all__ = [ 16 | "get_rand_data_binary", 17 | "get_rand_data_binned_binary", 18 | "get_rand_data_multiclass", 19 | ] 20 | -------------------------------------------------------------------------------- /torcheval/utils/test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
# --- torcheval/utils/test_utils/dummy_metric.py ---

TDummySumMetric = TypeVar("TDummySumMetric")


class DummySumMetric(Metric[torch.Tensor]):
    """Test metric that keeps a running sum of every value passed to ``update()``."""

    def __init__(self: TDummySumMetric, *, device: torch.device | None = None) -> None:
        super().__init__(device=device)
        self._add_state("sum", torch.tensor(0.0, device=self.device))

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(self: TDummySumMetric, x: torch.Tensor) -> TDummySumMetric:
        """Add ``x`` to the running sum in place."""
        self.sum += x
        return self

    @torch.inference_mode()
    def compute(self: TDummySumMetric) -> torch.Tensor:
        """Return the accumulated sum."""
        return self.sum

    @torch.inference_mode()
    def merge_state(
        self: TDummySumMetric, metrics: Iterable[TDummySumMetric]
    ) -> TDummySumMetric:
        """Add the ``sum`` state of every metric in ``metrics`` to this one."""
        for metric in metrics:
            self.sum += metric.sum.to(self.device)
        return self


TDummySumListStateMetric = TypeVar("TDummySumListStateMetric")


class DummySumListStateMetric(Metric[torch.Tensor]):
    """Test metric whose state is a list of tensors; ``compute()`` sums all of them."""

    def __init__(
        self: TDummySumListStateMetric, *, device: torch.device | None = None
    ) -> None:
        super().__init__(device=device)
        self._add_state("x", [])

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(
        self: TDummySumListStateMetric, x: torch.Tensor
    ) -> TDummySumListStateMetric:
        """Append ``x`` (moved to this metric's device) to the list state."""
        self.x.append(x.to(self.device))
        return self

    @torch.inference_mode()
    def compute(self: TDummySumListStateMetric) -> torch.Tensor:
        """Return the sum over all stored tensors (the int ``0`` when none are stored)."""
        # pyre-fixme[7]: Expected `Tensor` but got `int`.
        return sum(tensor.sum() for tensor in self.x)

    @torch.inference_mode()
    def merge_state(
        self: TDummySumListStateMetric, metrics: Iterable[TDummySumListStateMetric]
    ) -> TDummySumListStateMetric:
        """Extend this metric's list with the tensors held by every metric in ``metrics``."""
        for metric in metrics:
            self.x.extend(element.to(self.device) for element in metric.x)
        return self


TDummySumDictStateMetric = TypeVar("TDummySumDictStateMetric")


class DummySumDictStateMetric(Metric[torch.Tensor]):
    """Test metric whose state is a dict mapping a string key to a running sum."""

    def __init__(
        self: TDummySumDictStateMetric, *, device: torch.device | None = None
    ) -> None:
        super().__init__(device=device)
        # Missing keys start from a zero scalar on this metric's device.
        self._add_state("x", defaultdict(lambda: torch.tensor(0.0, device=self.device)))

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(
        self: TDummySumDictStateMetric,
        k: str,
        v: torch.Tensor,
    ) -> TDummySumDictStateMetric:
        """Add ``v`` to the running sum stored under key ``k``."""
        self.x[k] += v
        return self

    @torch.inference_mode()
    def compute(self: TDummySumDictStateMetric) -> torch.Tensor:
        """Return the state dict itself (note: a dict, despite the Tensor annotation)."""
        return self.x

    @torch.inference_mode()
    def merge_state(
        self: TDummySumDictStateMetric, metrics: Iterable[TDummySumDictStateMetric]
    ) -> TDummySumDictStateMetric:
        """Add the per-key sums of every metric in ``metrics`` into this one."""
        for metric in metrics:
            # The keys live on the ``x`` state dict; the metric object itself
            # is not a mapping, so iterate ``metric.x`` rather than ``metric``.
            for k in metric.x:
                self.x[k] += metric.x[k].to(self.device)

        return self
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | # Follows PEP-0440 version scheme guidelines 11 | # https://www.python.org/dev/peps/pep-0440/#version-scheme 12 | # 13 | # Examples: 14 | # 0.1.0.devN # Developmental release 15 | # 0.1.0aN # Alpha release 16 | # 0.1.0bN # Beta release 17 | # 0.1.0rcN # Release Candidate 18 | # 0.1.0 # Final release 19 | __version__: str = "0.0.7" 20 | --------------------------------------------------------------------------------