├── .github └── workflows │ ├── JOSS-pdf.yml │ ├── ci.yml │ └── python-publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CITATION.cff ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── JOSS │ ├── paper.bib │ └── paper.md ├── Makefile ├── README.md ├── _static │ └── images │ │ ├── jetnetlogo.key │ │ ├── jetnetlogo.pdf │ │ ├── jetnetlogo.png │ │ ├── jetnetlogo_white.key │ │ └── jetnetlogo_white.png ├── conf.py ├── index.rst ├── make.bat ├── pages │ ├── contents.rst │ ├── datasets.rst │ ├── losses.rst │ ├── metrics.rst │ ├── tutorials │ └── utils.rst └── requirements.txt ├── jetnet ├── __init__.py ├── datasets │ ├── __init__.py │ ├── dataset.py │ ├── jetnet.py │ ├── normalisations.py │ ├── qgjets.py │ ├── toptagging.py │ └── utils.py ├── evaluation │ ├── __init__.py │ ├── fpnd_resources │ │ ├── __init__.py │ │ └── jetnet │ │ │ ├── 30_particles │ │ │ ├── __init__.py │ │ │ ├── g_mu.txt │ │ │ ├── g_sigma.txt │ │ │ ├── pnet_state_dict.pt │ │ │ ├── q_mu.txt │ │ │ ├── q_sigma.txt │ │ │ ├── t_mu.txt │ │ │ └── t_sigma.txt │ │ │ └── __init__.py │ ├── gen_metrics.py │ └── particlenet.py ├── losses │ ├── __init__.py │ └── losses.py └── utils │ ├── __init__.py │ ├── coord_transform.py │ └── utils.py ├── pyproject.toml ├── setup.py ├── tests ├── datasets │ ├── test_jetnet.py │ ├── test_normalisations.py │ ├── test_qgjets.py │ ├── test_toptagging.py │ └── test_utils.py ├── evaluation │ └── test_gen_metrics.py └── utils │ └── test_image.py └── tutorials └── pyhep-data-access.ipynb /.github/workflows/JOSS-pdf.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | paths: 4 | - docs/JOSS/** 5 | 6 | jobs: 7 | paper: 8 | runs-on: ubuntu-latest 9 | name: JOSS Draft 10 | steps: 11 | - name: Checkout 12 | uses: actions/checkout@v3 13 | - name: Build draft JOSS PDF 14 | uses: openjournals/openjournals-draft-action@master 15 | with: 16 | journal: joss 17 | # This should be the path to the paper within your repo. 18 | paper-path: docs/JOSS/paper.md 19 | - name: Upload 20 | uses: actions/upload-artifact@v1 21 | with: 22 | name: paper 23 | # This is the output path where Pandoc will write the compiled 24 | # PDF. 
Note, this should be the same directory as the input 25 | # paper.md 26 | path: docs/JOSS/paper.pdf 27 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | run-linters: 7 | name: Run linters 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - name: Check out Git repository 12 | uses: actions/checkout@v2 13 | 14 | - name: Set up Python 15 | uses: actions/setup-python@v1 16 | with: 17 | python-version: 3.9 18 | 19 | - name: Install Python dependencies 20 | run: pip install black 21 | 22 | - name: Run linters 23 | uses: wearerequired/lint-action@v2 24 | with: 25 | auto_fix: true 26 | black: true 27 | black_auto_fix: true 28 | 29 | # TODO: might want to consider running tests only for changed files at some point https://github.com/marketplace/actions/changed-files 30 | 31 | pytest: 32 | runs-on: ubuntu-latest 33 | 34 | strategy: 35 | matrix: 36 | python-version: ["3.9", "3.12"] 37 | pytest-file: 38 | [ 39 | "tests/datasets/test_jetnet.py", 40 | "tests/datasets/test_normalisations.py", 41 | "tests/datasets/test_utils.py", 42 | "tests/utils/test_image.py", 43 | "tests/evaluation/test_gen_metrics.py", 44 | ] 45 | 46 | steps: 47 | - uses: actions/checkout@v3 48 | - name: Set up Python ${{ matrix.python-version }} 49 | uses: actions/setup-python@v4 50 | with: 51 | python-version: ${{ matrix.python-version }} 52 | - name: Install dependencies 53 | run: | 54 | python -m pip install --upgrade pip 55 | pip install pytest pytest-xdist 56 | pip install -e . 57 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 58 | - name: Test with pytest 59 | run: | 60 | pytest -n 0 --durations 0 -v -m "not slow" ${{ matrix.pytest-file }} 61 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
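# Note: the publish job below authenticates to PyPI using the PYPI_API_TOKEN
# repository secret, which must be configured in the repository settings for
# uploads to succeed.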
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | if: github.repository_owner == 'jet-net' 18 | 19 | runs-on: ubuntu-latest 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | with: 24 | ref: main 25 | 26 | - name: Set up Python 27 | uses: actions/setup-python@v2 28 | with: 29 | python-version: "3.x" 30 | 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install build 35 | 36 | - name: Build package 37 | run: python -m build 38 | 39 | - name: Publish package 40 | uses: pypa/gh-action-pypi-publish@release/v1 41 | with: 42 | password: ${{ secrets.PYPI_API_TOKEN }} 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # setuptools_scm 141 | src/*/_version.py 142 | 143 | 144 | # ruff 145 | .ruff_cache/ 146 | 147 | # OS specific stuff 148 | .DS_Store 149 | .DS_Store? 150 | ._* 151 | .Spotlight-V100 152 | .Trashes 153 | ehthumbs.db 154 | Thumbs.db 155 | 156 | # Common editor files 157 | *~ 158 | *.swp 159 | 160 | 161 | .DS_Store 162 | **/.DS_Store 163 | docs/_build 164 | /build 165 | **/__pycache__ 166 | /dist 167 | *egg-info 168 | *test*.py 169 | *test*.ipynb 170 | /datasets 171 | tutorials/datasets 172 | .vscode 173 | 174 | !tests 175 | !tests/**/*.py 176 | 177 | /test.ipynb 178 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_commit_msg: "chore: update pre-commit hooks" 3 | autofix_commit_msg: "style: pre-commit fixes" 4 | autoupdate_schedule: "monthly" 5 | 6 | exclude: ".*key" 7 | 8 | repos: 9 | - repo: https://github.com/psf/black-pre-commit-mirror 10 | rev: 25.1.0 11 | hooks: 12 | - id: black-jupyter 13 | language_version: python3 14 | args: [--line-length=100] 15 | 16 | - repo: https://github.com/adamchainz/blacken-docs 17 | rev: "1.19.1" 18 | hooks: 19 | - id: blacken-docs 20 | additional_dependencies: [black==23.*] 21 | 22 | - repo: https://github.com/pre-commit/pre-commit-hooks 23 | rev: "v5.0.0" 24 | hooks: 25 | - id: check-added-large-files 26 | - id: check-case-conflict 27 | - id: check-merge-conflict 28 | - id: check-symlinks 29 | - id: check-yaml 30 | - id: debug-statements 31 | - id: end-of-file-fixer 32 | - id: mixed-line-ending 33 | - id: name-tests-test 34 | args: ["--pytest-test-first"] 35 | - id: requirements-txt-fixer 36 | - id: trailing-whitespace 37 | 38 | - repo: https://github.com/pre-commit/pygrep-hooks 39 | rev: "v1.10.0" 40 | hooks: 41 | - id: rst-backticks 42 | - id: rst-directive-colons 43 | - id: rst-inline-touching-normal 44 | 45 | - repo: https://github.com/pre-commit/mirrors-prettier 46 | rev: "v4.0.0-alpha.8" 47 | hooks: 48 | - id: prettier 49 | types_or: [yaml, markdown, html, css, scss, javascript, json] 50 | args: [--prose-wrap=preserve] 51 | 52 | - repo: https://github.com/astral-sh/ruff-pre-commit 53 | rev: "v0.11.12" 54 | hooks: 55 | - id: ruff 56 | args: ["--fix", "--show-fixes"] 57 | 58 | # - repo: https://github.com/pre-commit/mirrors-mypy 59 | # rev: "v1.6.1" 60 | # hooks: 61 | # - id: mypy 62 | # files: src|tests 63 | # args: [] 64 | # additional_dependencies: 65 | # - pytest 66 | 67 | - repo: https://github.com/codespell-project/codespell 68 | rev: "v2.4.1" 69 | hooks: 70 | - id: codespell 71 | types_or: [python, rst, markdown] 72 | 73 | - repo: https://github.com/shellcheck-py/shellcheck-py 74 | rev: "v0.10.0.1" 75 | hooks: 76 | - id: shellcheck 77 | 78 | - repo: 
https://github.com/abravalheri/validate-pyproject 79 | rev: v0.24.1 80 | hooks: 81 | - id: validate-pyproject 82 | 83 | - repo: https://github.com/python-jsonschema/check-jsonschema 84 | rev: 0.33.0 85 | hooks: 86 | - id: check-dependabot 87 | - id: check-github-workflows 88 | - id: check-readthedocs 89 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: docs/conf.py 17 | 18 | # Include PDF and ePub 19 | formats: all 20 | 21 | python: 22 | install: 23 | - requirements: docs/requirements.txt 24 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: "1.2.0" 2 | authors: 3 | - family-names: Kansal 4 | given-names: Raghav 5 | orcid: "https://orcid.org/0000-0003-2445-1060" 6 | - family-names: Pareja 7 | given-names: Carlos 8 | orcid: "https://orcid.org/0000-0002-9022-2349" 9 | - family-names: Hao 10 | given-names: Zichun 11 | orcid: "https://orcid.org/0000-0002-5624-4907" 12 | - family-names: Duarte 13 | given-names: Javier 14 | orcid: "https://orcid.org/0000-0002-5076-7096" 15 | contact: 16 | - family-names: Kansal 17 | given-names: Raghav 18 | orcid: "https://orcid.org/0000-0003-2445-1060" 19 | doi: 10.5281/zenodo.10044601 20 | message: If you use this library for your research, please cite our article in the Journal of Open Source Software. 
21 | preferred-citation: 22 | authors: 23 | - family-names: Kansal 24 | given-names: Raghav 25 | orcid: "https://orcid.org/0000-0003-2445-1060" 26 | - family-names: Pareja 27 | given-names: Carlos 28 | orcid: "https://orcid.org/0000-0002-9022-2349" 29 | - family-names: Hao 30 | given-names: Zichun 31 | orcid: "https://orcid.org/0000-0002-5624-4907" 32 | - family-names: Duarte 33 | given-names: Javier 34 | orcid: "https://orcid.org/0000-0002-5076-7096" 35 | date-published: 2023-10-30 36 | doi: 10.21105/joss.05789 37 | issn: 2475-9066 38 | issue: 90 39 | journal: Journal of Open Source Software 40 | publisher: 41 | name: Open Journals 42 | start: 5789 43 | title: "JetNet: A Python package for accessing open datasets and 44 | benchmarking machine learning methods in high energy physics" 45 | type: article 46 | url: "https://joss.theoj.org/papers/10.21105/joss.05789" 47 | volume: 8 48 | title: "JetNet: A Python package for accessing open datasets and 49 | benchmarking machine learning methods in high energy physics" 50 | version: "v0.2.4" 51 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Raghav Kansal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | recursive-include jetnet/evaluation/fpnd_resources *.txt *.pt 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | <p align="center"><img alt="JetNet logo" src="docs/_static/images/jetnetlogo.png"/></p>
3 |
4 |
5 | <p align="center">
6 | For developing and reproducing ML + HEP projects.
7 | </p>
8 |
9 | ---
10 |
11 | <p align="center">
20 | 21 | --- 22 | 23 | [](https://github.com/jet-net/jetnet/actions) 24 | [](https://jetnet.readthedocs.io/en/latest/) 25 | [](https://github.com/psf/black) 26 | [](https://results.pre-commit.ci/latest/github/jet-net/JetNet/main) 27 | 28 | [](https://pypi.org/project/jetnet/) 29 | [](https://pepy.tech/project/jetnet) 30 | [](https://doi.org/10.5281/zenodo.10044601) 31 | [](https://doi.org/10.21105/joss.05789) 32 | 33 | --- 34 | 35 | ## JetNet 36 | 37 | JetNet is an effort to increase accessibility and reproducibility in jet-based machine learning. 38 | 39 | Currently we provide: 40 | 41 | - Easy-to-access and standardised interfaces for the following datasets: 42 | - [JetNet](https://zenodo.org/record/6975118) 43 | - [TopTagging](https://zenodo.org/record/2603256) 44 | - [QuarkGluon](https://zenodo.org/record/3164691) 45 | - Standard implementations of generative evaluation metrics (Ref. [[1, 2](#references)]), including: 46 | - Fréchet physics distance (FPD) 47 | - Kernel physics distance (KPD) 48 | - Wasserstein-1 (W1) 49 | - Fréchet ParticleNet Distance (FPND) 50 | - coverage and minimum matching distance (MMD) 51 | - Loss functions: 52 | - Differentiable implementation of the energy mover's distance [[3](#references)] 53 | - And more general jet utilities. 54 | 55 | Additional functionality is under development, and please reach out if you're interested in contributing! 56 | 57 | ## Installation 58 | 59 | JetNet can be installed with pip: 60 | 61 | ```bash 62 | pip install jetnet 63 | ``` 64 | 65 | To use the differentiable EMD loss `jetnet.losses.EMDLoss`, additional libraries must be installed via 66 | 67 | ```bash 68 | pip install "jetnet[emdloss]" 69 | ``` 70 | 71 | Finally, [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) must be installed independently for the Fréchet ParticleNet Distance metric `jetnet.evaluation.fpnd` ([Installation instructions](https://github.com/pyg-team/pytorch_geometric#installation)). 72 | 73 | ## Quickstart 74 | 75 | Datasets can be downloaded and accessed quickly, for example: 76 | 77 | ```python 78 | from jetnet.datasets import JetNet, TopTagging 79 | 80 | # as numpy arrays: 81 | particle_data, jet_data = JetNet.getData( 82 | jet_type=["g", "q"], data_dir="./datasets/jetnet/", download=True 83 | ) 84 | # or as a PyTorch dataset: 85 | dataset = TopTagging( 86 | jet_type="all", data_dir="./datasets/toptagging/", split="train", download=True 87 | ) 88 | ``` 89 | 90 | Evaluation metrics can be used as such: 91 | 92 | ```python 93 | generated_jets = np.random.rand(50000, 30, 3) 94 | fpnd_score = jetnet.evaluation.fpnd(generated_jets, jet_type="g") 95 | ``` 96 | 97 | Loss functions can be initialized and used similarly to standard PyTorch in-built losses such as MSE: 98 | 99 | ```python 100 | emd_loss = jetnet.losses.EMDLoss(num_particles=30) 101 | loss = emd_loss(real_jets, generated_jets) 102 | loss.backward() 103 | ``` 104 | 105 | ## Documentation 106 | 107 | The full API reference and tutorials are available at [jetnet.readthedocs.io](https://jetnet.readthedocs.io/en/latest/). 108 | Tutorial notebooks are in the [tutorials](https://github.com/jet-net/JetNet/tree/main/tutorials) folder, with more to come. 109 | 110 | 111 | 112 | ## Contributing 113 | 114 | We welcome feedback and contributions! 
Please feel free to [create an issue](https://github.com/jet-net/JetNet/issues/new) for bugs or functionality requests, or open [pull requests](https://github.com/jet-net/JetNet/pulls) from your [forked repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo) to solve them. 115 | 116 | ### Building and testing locally 117 | 118 | Perform an editable installation of the package from inside your forked repo and install the `pytest` package for unit testing: 119 | 120 | ```bash 121 | pip install -e . 122 | pip install pytest 123 | ``` 124 | 125 | Run the test suite to ensure everything is working as expected: 126 | 127 | ```bash 128 | pytest tests # tests all datasets 129 | pytest tests -m "not slow" # tests only on the JetNet dataset for convenience 130 | ``` 131 | 132 | ## Citation 133 | 134 | If you use this library for your research, please cite our article in the Journal of Open Source Software: 135 | 136 | ``` 137 | @article{Kansal_JetNet_2023, 138 | author = {Kansal, Raghav and Pareja, Carlos and Hao, Zichun and Duarte, Javier}, 139 | doi = {10.21105/joss.05789}, 140 | journal = {Journal of Open Source Software}, 141 | number = {90}, 142 | pages = {5789}, 143 | title = {{JetNet: A Python package for accessing open datasets and benchmarking machine learning methods in high energy physics}}, 144 | url = {https://joss.theoj.org/papers/10.21105/joss.05789}, 145 | volume = {8}, 146 | year = {2023} 147 | } 148 | ``` 149 | 150 | Please further cite the following if you use these components of the library. 151 | 152 | ### JetNet dataset or FPND 153 | 154 | ``` 155 | @inproceedings{Kansal_MPGAN_2021, 156 | author = {Kansal, Raghav and Duarte, Javier and Su, Hao and Orzari, Breno and Tomei, Thiago and Pierini, Maurizio and Touranakou, Mary and Vlimant, Jean-Roch and Gunopulos, Dimitrios}, 157 | booktitle = "{Advances in Neural Information Processing Systems}", 158 | editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. Wortman Vaughan}, 159 | pages = {23858--23871}, 160 | publisher = {Curran Associates, Inc.}, 161 | title = {Particle Cloud Generation with Message Passing Generative Adversarial Networks}, 162 | url = {https://proceedings.neurips.cc/paper_files/paper/2021/file/c8512d142a2d849725f31a9a7a361ab9-Paper.pdf}, 163 | volume = {34}, 164 | year = {2021}, 165 | eprint = {2106.11535}, 166 | archivePrefix = {arXiv}, 167 | } 168 | ``` 169 | 170 | ### FPD or KPD 171 | 172 | ``` 173 | @article{Kansal_Evaluating_2023, 174 | author = {Kansal, Raghav and Li, Anni and Duarte, Javier and Chernyavskaya, Nadezda and Pierini, Maurizio and Orzari, Breno and Tomei, Thiago}, 175 | title = {Evaluating generative models in high energy physics}, 176 | reportNumber = "FERMILAB-PUB-22-872-CMS-PPD", 177 | doi = "10.1103/PhysRevD.107.076017", 178 | journal = "{Phys. Rev. D}", 179 | volume = "107", 180 | number = "7", 181 | pages = "076017", 182 | year = "2023", 183 | eprint = "2211.10295", 184 | archivePrefix = "arXiv", 185 | } 186 | ``` 187 | 188 | ### EMD Loss 189 | 190 | Please cite the respective [qpth](https://locuslab.github.io/qpth/) or [cvxpy](https://github.com/cvxpy/cvxpy) libraries, depending on the method used (`qpth` by default), as well as the original EMD paper [[3]](#references). 191 | 192 | ## References 193 | 194 | [1] R. 
Kansal et al., _Particle Cloud Generation with Message Passing Generative Adversarial Networks_, [NeurIPS 2021](https://proceedings.neurips.cc/paper/2021/hash/c8512d142a2d849725f31a9a7a361ab9-Abstract.html) [[2106.11535](https://arxiv.org/abs/2106.11535)]. 195 | 196 | [2] R. Kansal et al., _Evaluating Generative Models in High Energy Physics_, [Phys. Rev. D **107** (2023) 076017](https://doi.org/10.1103/PhysRevD.107.076017) [[2211.10295](https://arxiv.org/abs/2211.10295)]. 197 | 198 | [3] P. T. Komiske, E. M. Metodiev, and J. Thaler, _The Metric Space of Collider Events_, [Phys. Rev. Lett. **123** (2019) 041801](https://doi.org/10.1103/PhysRevLett.123.041801) [[1902.02346](https://arxiv.org/abs/1902.02346)]. 199 | -------------------------------------------------------------------------------- /docs/JOSS/paper.bib: -------------------------------------------------------------------------------- 1 | @article{Buhmann:2023pmh, 2 | archiveprefix = {arXiv}, 3 | author = {Buhmann, Erik and Kasieczka, Gregor and Thaler, Jesse}, 4 | doi = {10.21468/SciPostPhys.15.4.130}, 5 | eprint = {2301.08128}, 6 | journal = {SciPost Phys.}, 7 | pages = {130}, 8 | primaryclass = {hep-ph}, 9 | reportnumber = {MIT-CTP 5519}, 10 | title = {{EPiC-GAN: Equivariant Point Cloud Generation for Particle Jets}}, 11 | volume = {15}, 12 | year = {2023} 13 | } 14 | 15 | 16 | @article{Chen:2021euv, 17 | archiveprefix = {arXiv}, 18 | author = {Chen, Yifan and others}, 19 | doi = {10.1038/s41597-021-01109-0}, 20 | eprint = {2108.02214}, 21 | journal = {Sci. Data}, 22 | pages = {31}, 23 | primaryclass = {hep-ex}, 24 | title = {{A FAIR and AI-ready Higgs boson decay dataset}}, 25 | volume = {9}, 26 | year = {2022} 27 | } 28 | 29 | 30 | @article{Hao:2022zns, 31 | archiveprefix = {arXiv}, 32 | author = {Hao, Zichun and Kansal, Raghav and Duarte, Javier and Chernyavskaya, Nadezda}, 33 | doi = {10.1140/epjc/s10052-023-11633-5}, 34 | eprint = {2212.07347}, 35 | journal = {Eur. Phys. J. 
C}, 36 | number = {6}, 37 | pages = {485}, 38 | primaryclass = {hep-ex}, 39 | reportnumber = {FERMILAB-PUB-22-963-V}, 40 | title = {{Lorentz group equivariant autoencoders}}, 41 | volume = {83}, 42 | year = {2023} 43 | } 44 | 45 | @software{jax2018github, 46 | author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, 47 | title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, 48 | url = {http://github.com/google/jax}, 49 | version = {0.3.13}, 50 | year = {2018} 51 | } 52 | 53 | @article{Kach:2022uzq, 54 | archiveprefix = {arXiv}, 55 | author = {Kach, Benno and Kr{\"u}cker, Dirk and Melzer-Pellmann, Isabell}, 56 | eprint = {2211.13623}, 57 | primaryclass = {hep-ex}, 58 | title = {{Point Cloud Generation using Transformer Encoders and Normalising Flows}}, 59 | year = {2022} 60 | } 61 | 62 | @article{Kach:2023rqw, 63 | archiveprefix = {arXiv}, 64 | author = {Kach, Benno and Melzer-Pellmann, Isabell}, 65 | eprint = {2305.15254}, 66 | primaryclass = {hep-ex}, 67 | title = {{Attention to Mean-Fields for Particle Cloud Generation}}, 68 | year = {2023} 69 | } 70 | 71 | @dataset{kansal_raghav_2022_6975118, 72 | author = {Kansal, Raghav and 73 | Duarte, Javier and 74 | Su, Hao and 75 | Orzari, Breno and 76 | Tomei, Thiago and 77 | Pierini, Maurizio and 78 | Touranakou, Mary and 79 | Vlimant, Jean-Roch and 80 | Gunopulos, Dimitrios}, 81 | doi = {10.5281/zenodo.6975118}, 82 | publisher = {Zenodo}, 83 | title = {JetNet}, 84 | url = {https://doi.org/10.5281/zenodo.6975118}, 85 | version = {2}, 86 | year = {2022} 87 | } 88 | 89 | @inproceedings{Kansal:2021cqp, 90 | archiveprefix = {arXiv}, 91 | author = {Kansal, Raghav and Duarte, Javier and Su, Hao and Orzari, Breno and Tomei, Thiago and Pierini, Maurizio and Touranakou, Mary and Vlimant, Jean-Roch and Gunopulos, Dimitrios}, 92 | booktitle = {{Advances in Neural Information Processing Systems}}, 93 | eprint = {2106.11535}, 94 | primaryclass = {cs.LG}, 95 | publisher = {Curran Associates, Inc.}, 96 | title = {Particle Cloud Generation with Message Passing Generative Adversarial Networks}, 97 | url = {https://papers.neurips.cc/paper_files/paper/2021/file/c8512d142a2d849725f31a9a7a361ab9-Paper.pdf}, 98 | volume = {34}, 99 | year = {2021} 100 | } 101 | 102 | @article{Kansal:2022spb, 103 | archiveprefix = {arXiv}, 104 | author = {Kansal, Raghav and Li, Anni and Duarte, Javier and Chernyavskaya, Nadezda and Pierini, Maurizio and Orzari, Breno and Tomei, Thiago}, 105 | doi = {10.1103/PhysRevD.107.076017}, 106 | eprint = {2211.10295}, 107 | journal = {Phys. Rev. 
D},
108 |   number = {7},
109 |   pages = {076017},
110 |   primaryclass = {hep-ex},
111 |   reportnumber = {FERMILAB-PUB-22-872-CMS-PPD},
112 |   title = {{Evaluating generative models in high energy physics}},
113 |   volume = {107},
114 |   year = {2023}
115 | }
116 |
117 | @article{Kasieczka_2021,
118 |   archiveprefix = {arXiv},
119 |   author = {Gregor Kasieczka and Benjamin Nachman and David Shih and Oz Amram and Anders Andreassen and Kees Benkendorfer and Blaz Bortolato and Gustaaf Brooijmans and Florencia Canelli and Jack H Collins and Biwei Dai and Felipe F De Freitas and Barry M Dillon and Ioan-Mihail Dinu and Zhongtian Dong and Julien Donini and Javier Duarte and D A Faroughy and Julia Gonski and Philip Harris and Alan Kahn and Jernej F Kamenik and Charanjit K Khosa and Patrick Komiske and Luc Le Pottier and Pablo Martín-Ramiro and Andrej Matevc and Eric Metodiev and Vinicius Mikuni and Christopher W Murphy and Inês Ochoa and Sang Eon Park and Maurizio Pierini and Dylan Rankin and Veronica Sanz and Nilai Sarda and Uroš Seljak and Aleks Smolkovic and George Stein and Cristina Mantilla Suarez and Manuel Szewc and Jesse Thaler and Steven Tsan and Silviu-Marian Udrescu and Louis Vaslin and Jean-Roch Vlimant and Daniel Williams and Mikaeel Yunus},
120 |   doi = {10.1088/1361-6633/ac36b9},
121 |   eprint = {2101.08320},
122 |   journal = {Rept. Prog. Phys.},
123 |   number = {12},
124 |   pages = {124201},
125 |   primaryclass = {hep-ph},
126 |   title = {{The LHC Olympics 2020 a community challenge for anomaly detection in high energy physics}},
127 |   url = {https://doi.org/10.1088/1361-6633/ac36b9},
128 |   volume = {84},
129 |   year = {2021}
130 | }
131 |
132 | @dataset{kasieczka_gregor_2019_2603256,
133 |   author = {Kasieczka, Gregor and
134 |             Plehn, Tilman and
135 |             Thompson, Jennifer and
136 |             Russel, Michael},
137 |   doi = {10.5281/zenodo.2603256},
138 |   publisher = {Zenodo},
139 |   title = {Top Quark Tagging Reference Dataset},
140 |   url = {https://doi.org/10.5281/zenodo.2603256},
141 |   version = {v0 (2018\_03\_27)},
142 |   year = {2019}
143 | }
144 |
145 | @article{Kasieczka:2019dbj,
146 |   archiveprefix = {arXiv},
147 |   author = {Butter, Anja and others},
148 |   doi = {10.21468/SciPostPhys.7.1.014},
149 |   editor = {Kasieczka, Gregor and Plehn, Tilman},
150 |   eprint = {1902.09914},
151 |   journal = {SciPost Phys.},
152 |   pages = {014},
153 |   primaryclass = {hep-ph},
154 |   title = {{The Machine Learning landscape of top taggers}},
155 |   volume = {7},
156 |   year = {2019}
157 | }
158 |
159 | @dataset{komiske_patrick_2019_3164691,
160 |   author = {Komiske, Patrick and
161 |             Metodiev, Eric and
162 |             Thaler, Jesse},
163 |   doi = {10.5281/zenodo.3164691},
164 |   publisher = {Zenodo},
165 |   title = {Pythia8 Quark and Gluon Jets for Energy Flow},
166 |   url = {https://doi.org/10.5281/zenodo.3164691},
167 |   version = {v1},
168 |   year = 2019
169 | }
170 |
171 | @article{Komiske:2019jim,
172 |   archiveprefix = {arXiv},
173 |   author = {Komiske, Patrick T. and Mastandrea, Radha and Metodiev, Eric M. and Naik, Preksha and Thaler, Jesse},
174 |   doi = {10.1103/PhysRevD.101.034009},
175 |   eprint = {1908.08542},
176 |   journal = {Phys. Rev.
D}, 177 | number = {3}, 178 | pages = {034009}, 179 | primaryclass = {hep-ph}, 180 | reportnumber = {MIT-CTP 5129}, 181 | title = {{Exploring the Space of Jets with CMS Open Data}}, 182 | volume = {101}, 183 | year = {2020} 184 | } 185 | 186 | @article{Leigh:2023toe, 187 | archiveprefix = {arXiv}, 188 | author = {Leigh, Matthew and Sengupta, Debajyoti and Qu\'etant, Guillaume and Raine, John Andrew and Zoch, Knut and Golling, Tobias}, 189 | eprint = {2303.05376}, 190 | primaryclass = {hep-ph}, 191 | title = {{PC-JeDi: Diffusion for Particle Cloud Generation in High Energy Physics}}, 192 | year = {2023} 193 | } 194 | 195 | @article{Mikuni:2023dvk, 196 | archiveprefix = {arXiv}, 197 | author = {Mikuni, Vinicius and Nachman, Benjamin and Pettee, Mariel}, 198 | doi = {10.1103/PhysRevD.108.036025}, 199 | eprint = {2304.01266}, 200 | journal = {Phys. Rev. D}, 201 | number = {3}, 202 | pages = {036025}, 203 | primaryclass = {hep-ph}, 204 | title = {{Fast point cloud generation with diffusion models in high energy physics}}, 205 | volume = {108}, 206 | year = {2023} 207 | } 208 | 209 | @inproceedings{NEURIPS2019_9015, 210 | archiveprefix = {arXiv}, 211 | author = {Paszke, Adam and Gross, Sam and Massa, Francisco and Lerer, Adam and Bradbury, James and Chanan, Gregory and Killeen, Trevor and Lin, Zeming and Gimelshein, Natalia and Antiga, Luca and Desmaison, Alban and Kopf, Andreas and Yang, Edward and DeVito, Zachary and Raison, Martin and Tejani, Alykhan and Chilamkurthy, Sasank and Steiner, Benoit and Fang, Lu and Bai, Junjie and Chintala, Soumith}, 212 | booktitle = {Advances in Neural Information Processing Systems}, 213 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett}, 214 | eprint = {1912.01703}, 215 | pages = {8024}, 216 | publisher = {Curran Associates, Inc.}, 217 | title = {PyTorch: An Imperative Style, High-Performance Deep Learning Library}, 218 | url = {http://papers.neurips.cc/paper/9015-pytorch-an-imperative-style-high-performance-deep-learning-library.pdf}, 219 | volume = {32}, 220 | year = {2019} 221 | } 222 | 223 | @article{PhysRevLett.123.041801, 224 | archiveprefix = {arXiv}, 225 | author = {Komiske, Patrick T. and Metodiev, Eric M. and Thaler, Jesse}, 226 | doi = {10.1103/PhysRevLett.123.041801}, 227 | eprint = {1902.02346}, 228 | journal = {Phys. Rev. 
Lett.},
229 |   number = {4},
230 |   pages = {041801},
231 |   primaryclass = {hep-ph},
232 |   reportnumber = {MIT-CTP 5102},
233 |   title = {{Metric Space of Collider Events}},
234 |   volume = {123},
235 |   year = {2019}
236 | }
237 |
238 | @misc{Zenodo,
239 |   author = {{European Organization For Nuclear Research} and {OpenAIRE}},
240 |   doi = {10.25495/7GXK-RD71},
241 |   keywords = {FOS: Physical sciences, Publication, Dataset},
242 |   publisher = {CERN},
243 |   title = {Zenodo},
244 |   url = {https://www.zenodo.org/},
245 |   year = {2013}
246 | }
247 |
-------------------------------------------------------------------------------- /docs/JOSS/paper.md: --------------------------------------------------------------------------------
1 | ---
2 | title: "JetNet: A Python package for accessing open datasets and benchmarking machine learning methods in high energy physics"
3 | tags:
4 |   - Python
5 |   - PyTorch
6 |   - high energy physics
7 |   - machine learning
8 |   - jets
9 | authors:
10 |   - name: Raghav Kansal
11 |     orcid: 0000-0003-2445-1060
12 |     affiliation: "1, 2" # (Multiple affiliations must be quoted)
13 |     corresponding: true
14 |   - name: Carlos Pareja
15 |     orcid: 0000-0002-9022-2349
16 |     affiliation: 1
17 |   - name: Zichun Hao
18 |     orcid: 0000-0002-5624-4907
19 |     affiliation: 3
20 |   - name: Javier Duarte
21 |     orcid: 0000-0002-5076-7096
22 |     affiliation: 1
23 | affiliations:
24 |   - name: UC San Diego, USA
25 |     index: 1
26 |   - name: Fermilab, USA
27 |     index: 2
28 |   - name: California Institute of Technology, USA
29 |     index: 3
30 | date: 2023
31 | bibliography: paper.bib
32 | ---
33 |
34 | # Summary
35 |
36 | `JetNet` is a Python package that aims to increase accessibility and reproducibility for machine learning (ML) research in high energy physics (HEP), primarily related to particle jets. Based on the popular PyTorch ML framework, it provides easy-to-access and standardized interfaces for multiple heterogeneous HEP datasets and implementations of evaluation metrics, loss functions, and more general utilities relevant to HEP.
37 |
38 | # Statement of need
39 |
40 | It is essential in scientific research to maintain standardized benchmark datasets following the findable, accessible, interoperable, and reproducible (FAIR) data principles [@Chen:2021euv], practices for using the data, and methods for evaluating and comparing different algorithms. This can often be difficult in high energy physics (HEP) because of the broad set of formats in which data is released and the expert knowledge required to parse the relevant information. The `JetNet` Python package aims to facilitate this by providing a standard interface and format for HEP datasets, integrated with PyTorch [@NEURIPS2019_9015], to improve accessibility for both HEP experts and new or interdisciplinary researchers looking to do ML. Furthermore, by providing standard formats and implementations for evaluation metrics, results are more easily reproducible, and models are more easily assessed and benchmarked. `JetNet` is complementary to existing efforts for improving HEP dataset accessibility, notably the `EnergyFlow` library [@Komiske:2019jim], with a unique focus on ML applications and integration with PyTorch.
41 |
42 | ## Content
43 |
44 | `JetNet` currently provides easy-to-access and standardized interfaces for the JetNet [@kansal_raghav_2022_6975118], top quark tagging [@kasieczka_gregor_2019_2603256; @Kasieczka:2019dbj], and quark-gluon tagging [@komiske_patrick_2019_3164691] reference datasets, all hosted on Zenodo [@Zenodo].
It also provides standard implementations of generative evaluation metrics [@Kansal:2021cqp; @Kansal:2022spb], including Fréchet physics distance (FPD), kernel physics distance (KPD), 1-Wasserstein distance (W1), Fréchet ParticleNet distance (FPND), coverage, and minimum matching distance (MMD). Finally, `JetNet` implements custom loss functions like a differentiable version of the energy mover's distance [@PhysRevLett.123.041801] and more general jet utilities. 45 | 46 | ## Impact 47 | 48 | The impact of `JetNet` is demonstrated by the surge in ML and HEP research facilitated by the package, including in the areas of generative adversarial networks [@Kansal:2021cqp], transformers [@Kach:2022uzq; @Kansal:2022spb; @Kach:2023rqw], diffusion models [@Leigh:2023toe; @Mikuni:2023dvk], and equivariant networks [@Hao:2022zns; @Buhmann:2023pmh], all accessing datasets, metrics, and more through `JetNet`. 49 | 50 | ## Future Work 51 | 52 | Future work will expand the package to additional dataset loaders, including detector-level data, and different machine learning backends such as JAX [@jax2018github]. Improvements to the performance, such as optional lazy loading of large datasets, are also planned, as well as community challenges to benchmark algorithms facilitated by `JetNet`. 53 | 54 | # Acknowledgements 55 | 56 | We thank the `JetNet` community for their support and feedback. J.D. and R.K. received support for work related to `JetNet` provided by the U.S. Department of Energy (DOE), Office of Science, Office of High Energy Physics Early Career Research Program under Award No. DE-SC0021187, the DOE, Office of Advanced Scientific Computing Research under Award No. DE-SC0021396 (FAIR4HEP). R.K. was partially supported by the LHC Physics Center at Fermi National Accelerator Laboratory, managed and operated by Fermi Research Alliance, LLC under Contract No. DE-AC02-07CH11359 with the DOE. C.P. was supported by the Experiential Projects for Accelerated Networking and Development (EXPAND) mentorship program at UC San Diego. 57 | 58 | # References 59 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # JetNet Docs 2 | 3 | We're using Sphinx with the Napoleon extension, and the Google Python docstrings style. 4 | 5 | ## Build 6 | 7 | Install requirements with `pip install -r requirements.txt`. 8 | 9 | And then you can run the following command to build locally 10 | 11 | ```bash 12 | make html 13 | ``` 14 | 15 | After which you can open the `docs/_build/html/index.html` file in your browser. 
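
For reference, a full build from scratch might look like the following (run from the repository root; `open` is macOS-specific, so on Linux use `xdg-open` or open the file in your browser):

```bash
cd docs
pip install -r requirements.txt
make html
open _build/html/index.html
```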
16 | -------------------------------------------------------------------------------- /docs/_static/images/jetnetlogo.key: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jet-net/JetNet/6ca12b7e63062ddc99bcdf48cb304303c64b5be3/docs/_static/images/jetnetlogo.key -------------------------------------------------------------------------------- /docs/_static/images/jetnetlogo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jet-net/JetNet/6ca12b7e63062ddc99bcdf48cb304303c64b5be3/docs/_static/images/jetnetlogo.pdf -------------------------------------------------------------------------------- /docs/_static/images/jetnetlogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jet-net/JetNet/6ca12b7e63062ddc99bcdf48cb304303c64b5be3/docs/_static/images/jetnetlogo.png -------------------------------------------------------------------------------- /docs/_static/images/jetnetlogo_white.key: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jet-net/JetNet/6ca12b7e63062ddc99bcdf48cb304303c64b5be3/docs/_static/images/jetnetlogo_white.key -------------------------------------------------------------------------------- /docs/_static/images/jetnetlogo_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jet-net/JetNet/6ca12b7e63062ddc99bcdf48cb304303c64b5be3/docs/_static/images/jetnetlogo_white.png -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | from __future__ import annotations 14 | 15 | import os 16 | import sys 17 | 18 | import sphinx_rtd_theme 19 | 20 | sys.path.insert(0, os.path.abspath("../")) # noqa: PTH100 21 | # sys.path.insert(0, os.path.abspath("../tutorials/")) 22 | autodoc_mock_imports = [ 23 | "energyflow", 24 | "awkward", 25 | "coffea", 26 | "tqdm", 27 | "scipy", 28 | "torch_geometric", 29 | "torch", 30 | "cvxpy", 31 | "qpth", 32 | "numba", 33 | ] 34 | 35 | # -- Project information ----------------------------------------------------- 36 | 37 | project = "JetNet" 38 | copyright = "2021, Raghav Kansal" 39 | author = "Raghav Kansal" 40 | 41 | # The full version, including alpha/beta/rc tags 42 | release = "0.2.0a" 43 | 44 | 45 | # -- General configuration --------------------------------------------------- 46 | 47 | # Add any Sphinx extension module names here, as strings. They can be 48 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 49 | # ones. 
50 | extensions = [ 51 | "sphinx.ext.autodoc", 52 | "sphinx.ext.autosummary", 53 | "sphinx.ext.napoleon", 54 | "autodocsumm", 55 | "m2r2", 56 | "nbsphinx", 57 | "sphinx_rtd_theme", 58 | ] 59 | autosummary_generate = True # Turn on sphinx.ext.autosummary 60 | 61 | autodoc_type_aliases = {"ArrayLike": "ArrayLike"} 62 | 63 | source_suffix = [".rst", ".md"] 64 | 65 | # Add any paths that contain templates here, relative to this directory. 66 | templates_path = ["_templates"] 67 | 68 | # List of patterns, relative to source directory, that match files and 69 | # directories to ignore when looking for source files. 70 | # This pattern also affects html_static_path and html_extra_path. 71 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "README.md", "JOSS"] 72 | 73 | master_doc = "pages/contents" 74 | 75 | # -- Options for HTML output ------------------------------------------------- 76 | 77 | # The theme to use for HTML and HTML Help pages. See the documentation for 78 | # a list of builtin themes. 79 | # 80 | html_theme = "sphinx_rtd_theme" 81 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 82 | 83 | # Add any paths that contain custom static files (such as style sheets) here, 84 | # relative to this directory. They are copied after the builtin static files, 85 | # so a file named "default.css" will overwrite the builtin "default.css". 86 | html_static_path = ["_static"] 87 | 88 | html_sidebars = {"**": ["globaltoc.html", "relations.html", "sourcelink.html", "searchbox.html"]} 89 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to JetNet! 2 | ================== 3 | 4 | .. mdinclude:: ../README.md 5 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/pages/contents.rst: -------------------------------------------------------------------------------- 1 | .. image:: ../_static/images/jetnetlogo.png 2 | :width: 75% 3 | :alt: JetNet logo 4 | 5 | Contents 6 | ============== 7 | 8 | .. toctree:: 9 | :maxdepth: 2 10 | 11 | ../index 12 | 13 | .. toctree:: 14 | :maxdepth: 2 15 | :name: api 16 | :caption: API 17 | 18 | datasets 19 | metrics 20 | losses 21 | utils 22 | 23 | .. 
toctree:: 24 | :maxdepth: 2 25 | :caption: Tutorials 26 | 27 | tutorials/pyhep-data-access 28 | 29 | Indices and tables 30 | ================== 31 | 32 | * :ref:`genindex` 33 | * :ref:`modindex` 34 | * :ref:`search` 35 | -------------------------------------------------------------------------------- /docs/pages/datasets.rst: -------------------------------------------------------------------------------- 1 | Datasets 2 | =========== 3 | 4 | JetNet 5 | ************** 6 | 7 | .. autoclass:: jetnet.datasets.JetNet 8 | :members: 9 | :autosummary: 10 | :noindex: 11 | 12 | TopTagging 13 | ************** 14 | 15 | .. autoclass:: jetnet.datasets.TopTagging 16 | :members: 17 | :autosummary: 18 | :noindex: 19 | 20 | QuarkGluon 21 | ************** 22 | 23 | .. autoclass:: jetnet.datasets.QuarkGluon 24 | :members: 25 | :autosummary: 26 | :noindex: 27 | 28 | Normalisations 29 | ************** 30 | 31 | .. automodule:: jetnet.datasets.normalisations 32 | :members: 33 | :noindex: 34 | :autosummary: 35 | 36 | Utility Functions 37 | ****************** 38 | 39 | .. automodule:: jetnet.datasets.utils 40 | :members: 41 | :noindex: 42 | :autosummary: 43 | -------------------------------------------------------------------------------- /docs/pages/losses.rst: -------------------------------------------------------------------------------- 1 | Loss Functions 2 | ********************************** 3 | .. automodule:: jetnet.losses 4 | :members: 5 | :imported-members: 6 | :autosummary: 7 | -------------------------------------------------------------------------------- /docs/pages/metrics.rst: -------------------------------------------------------------------------------- 1 | 2 | Metrics 3 | ********************************** 4 | .. automodule:: jetnet.evaluation 5 | :members: 6 | :imported-members: 7 | :autosummary: 8 | :exclude-members: JetNet 9 | -------------------------------------------------------------------------------- /docs/pages/tutorials: -------------------------------------------------------------------------------- 1 | ../../tutorials -------------------------------------------------------------------------------- /docs/pages/utils.rst: -------------------------------------------------------------------------------- 1 | Utility Functions 2 | ********************************** 3 | .. 
automodule:: jetnet.utils
4 |     :imported-members:
5 |     :members:
6 |     :autosummary:
7 |
-------------------------------------------------------------------------------- /docs/requirements.txt: --------------------------------------------------------------------------------
1 | # File: docs/requirements.txt
2 |
3 | autodocsumm
4 | ipykernel
5 | m2r2
6 | nbsphinx
7 | numpy
8 | readthedocs-sphinx-search
9 | scipy
10 | sphinx<7
11 | sphinx_rtd_theme==0.5.2
12 | torch
13 | tqdm
14 |
-------------------------------------------------------------------------------- /jetnet/__init__.py: --------------------------------------------------------------------------------
1 | # __init__
2 | from __future__ import annotations
3 |
4 | import jetnet.datasets
5 | import jetnet.datasets.normalisations
6 | import jetnet.datasets.utils
7 | import jetnet.evaluation
8 | import jetnet.losses
9 | import jetnet.utils  # noqa: F401
10 |
11 | __version__ = "0.2.5"
12 |
-------------------------------------------------------------------------------- /jetnet/datasets/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from .jetnet import JetNet  # noqa: F401
4 | from .qgjets import QuarkGluon  # noqa: F401
5 | from .toptagging import TopTagging  # noqa: F401
6 |
-------------------------------------------------------------------------------- /jetnet/datasets/dataset.py: --------------------------------------------------------------------------------
1 | """
2 | Base classes for JetNet datasets.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | from typing import Any, Callable
8 |
9 | import torch
10 | from torch import Tensor
11 |
12 | from .normalisations import NormaliseABC
13 | from .utils import checkListNotEmpty, checkStrToList, firstNotNoneElement
14 |
15 |
16 | class JetDataset(torch.utils.data.Dataset):
17 |     """
18 |     Base class for jet datasets.
19 |     Inspired by https://pytorch.org/vision/main/generated/torchvision.datasets.VisionDataset.html
20 |
21 |     Args:
22 |         data_dir (str): directory where dataset is or will be stored.
23 |         particle_features (list[str], optional): list of particle features to retrieve. If empty
24 |             or None, gets no particle features. Should default to all.
25 |         jet_features (list[str], optional): list of jet features to retrieve. If empty or None,
26 |             gets no jet features. Should default to all.
27 |         particle_normalisation (NormaliseABC, optional): optional normalisation for
28 |             particle-level features. Defaults to None.
29 |         jet_normalisation (NormaliseABC, optional): optional normalisation for jet-level
30 |             features. Defaults to None.
31 |         particle_transform (callable, optional): A function/transform that takes in the particle
32 |             data tensor and transforms it. Defaults to None.
33 |         jet_transform (callable, optional): A function/transform that takes in the jet
34 |             data tensor and transforms it. Defaults to None.
35 |         num_particles (int, optional): max number of particles to retain per jet. Defaults to None.
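
    Note: subclasses are expected to load their data into ``self.particle_data`` and
    ``self.jet_data`` (e.g. via a ``getData`` class method, as in ``JetNet``) before
    calling ``super().__init__``.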
36 | """ 37 | 38 | _repr_indent = 4 39 | 40 | particle_data = None 41 | jet_data = None 42 | MAX_NUM_PARTICLES = None 43 | 44 | def __init__( 45 | self, 46 | data_dir: str = "./", 47 | particle_features: list[str] | None = "all", 48 | jet_features: list[str] | None = "all", 49 | particle_normalisation: NormaliseABC | None = None, 50 | jet_normalisation: NormaliseABC | None = None, 51 | particle_transform: Callable | None = None, 52 | jet_transform: Callable | None = None, 53 | num_particles: int | None = None, 54 | ): 55 | self.data_dir = data_dir 56 | 57 | self.particle_features, self.jet_features = checkStrToList(particle_features, jet_features) 58 | self.use_particle_features, self.use_jet_features = checkListNotEmpty( 59 | particle_features, jet_features 60 | ) 61 | 62 | self.particle_normalisation = particle_normalisation 63 | self.jet_normalisation = jet_normalisation 64 | 65 | if self.use_particle_features and self.particle_normalisation is not None: 66 | if self.particle_normalisation.features_need_deriving(): 67 | self.particle_normalisation.derive_dataset_features(self.particle_data) 68 | self.particle_data = self.particle_normalisation(self.particle_data) 69 | 70 | if self.use_jet_features and self.jet_normalisation is not None: 71 | if self.jet_normalisation.features_need_deriving(): 72 | self.jet_normalisation.derive_dataset_features(self.jet_data) 73 | self.jet_data = self.jet_normalisation(self.jet_data) 74 | 75 | self.particle_transform = particle_transform 76 | self.jet_transform = jet_transform 77 | 78 | self.num_particles = num_particles 79 | 80 | @classmethod 81 | def getData(**opts) -> Any: 82 | """Class method to download and return numpy arrays of the data""" 83 | raise NotImplementedError 84 | 85 | def __getitem__(self, index) -> tuple[Tensor | None, Tensor | None]: 86 | """ 87 | Gets data and if needed transforms it. 
88 |
89 |         Args:
90 |             index (int): Index
91 |
92 |         Returns:
93 |             (Tuple[Tensor | None, Tensor | None]): particle, jet data
94 |         """
95 |
96 |         if self.use_particle_features:
97 |             particle_data = self.particle_data[index]
98 |
99 |             if self.particle_transform is not None:
100 |                 particle_data = self.particle_transform(particle_data)
101 |
102 |             particle_data = Tensor(particle_data)
103 |         else:
104 |             particle_data = []
105 |
106 |         if self.use_jet_features:
107 |             jet_data = self.jet_data[index]
108 |
109 |             if self.jet_transform is not None:
110 |                 jet_data = self.jet_transform(jet_data)
111 |
112 |             jet_data = Tensor(jet_data)
113 |         else:
114 |             jet_data = []
115 |
116 |         return particle_data, jet_data
117 |
118 |     def __len__(self) -> int:
119 |         return len(firstNotNoneElement(self.particle_data, self.jet_data))
120 |
121 |     def __repr__(self) -> str:
122 |         head = "Dataset " + self.__class__.__name__
123 |         body = [f"Number of datapoints: {self.__len__()}"]
124 |
125 |         if self.data_dir is not None:
126 |             body.append(f"Data location: {self.data_dir}")
127 |
128 |         body += self.extra_repr().splitlines()
129 |
130 |         if self.particle_features is not None:
131 |             bstr = f"Particle features: {self.particle_features}"
132 |             if self.num_particles is not None:
133 |                 bstr += f", max {self.num_particles} particles per jet"
134 |
135 |             body += [bstr]
136 |
137 |         if self.jet_features is not None:
138 |             body += [f"Jet features: {self.jet_features}"]
139 |
140 |         if self.particle_normalisation is not None:
141 |             body += [f"Particle normalisation: {self.particle_normalisation}"]
142 |
143 |         if self.jet_normalisation is not None:
144 |             body += [f"Jet normalisation: {self.jet_normalisation}"]
145 |
146 |         if self.particle_transform is not None:
147 |             body += [f"Particle transform: {self.particle_transform}"]
148 |
149 |         if self.jet_transform is not None:
150 |             body += [f"Jet transform: {self.jet_transform}"]
151 |
152 |         lines = [head] + [" " * self._repr_indent + line for line in body]
153 |
154 |         return "\n".join(lines)
155 |
156 |     def extra_repr(self) -> str:
157 |         return ""
158 |
-------------------------------------------------------------------------------- /jetnet/datasets/jetnet.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from copy import copy
4 | from typing import Callable
5 |
6 | import numpy as np
7 |
8 | from .dataset import JetDataset
9 | from .normalisations import FeaturewiseLinearBounded, NormaliseABC
10 | from .utils import (
11 |     checkConvertElements,
12 |     checkDownloadZenodoDataset,
13 |     checkListNotEmpty,
14 |     checkStrToList,
15 |     firstNotNoneElement,
16 |     getOrderedFeatures,
17 |     getSplitting,
18 | )
19 |
20 |
21 | class JetNet(JetDataset):
22 |     """
23 |     PyTorch ``torch.utils.data.Dataset`` class for the JetNet dataset.
24 |
25 |     If hdf5 files are not found in the ``data_dir`` directory then the dataset will be downloaded
26 |     from Zenodo (https://zenodo.org/record/6975118 or https://zenodo.org/record/6975117).
27 |
28 |     Args:
29 |         jet_type (Union[str, Set[str]], optional): individual type or set of types out of
30 |             'g' (gluon), 'q' (light quarks), 't' (top quarks), 'w' (W bosons), or 'z' (Z bosons).
31 |             "all" will get all types. Defaults to "all".
32 |         data_dir (str, optional): directory in which data is (to be) stored. Defaults to "./".
33 |         particle_features (List[str], optional): list of particle features to retrieve. If empty
34 |             or None, gets no particle features.
Defaults to 35 | ``["etarel", "phirel", "ptrel", "mask"]``. 36 | jet_features (List[str], optional): list of jet features to retrieve. If empty or None, 37 | gets no jet features. Defaults to 38 | ``["type", "pt", "eta", "mass", "num_particles"]``. 39 | particle_normalisation (NormaliseABC, optional): optional normalisation to apply to 40 | particle data. Defaults to None. 41 | jet_normalisation (NormaliseABC, optional): optional normalisation to apply to jet data. 42 | Defaults to None. 43 | particle_transform (callable, optional): A function/transform that takes in the particle 44 | data tensor and transforms it. Defaults to None. 45 | jet_transform (callable, optional): A function/transform that takes in the jet 46 | data tensor and transforms it. Defaults to None. 47 | num_particles (int, optional): number of particles to retain per jet, max of 150. 48 | Defaults to 30. 49 | split (str, optional): dataset split, out of {"train", "valid", "test", "all"}. Defaults 50 | to "train". 51 | split_fraction (List[float], optional): splitting fraction of training, validation, 52 | testing data respectively. Defaults to [0.7, 0.15, 0.15]. 53 | seed (int, optional): PyTorch manual seed - important to use the same seed for all 54 | dataset splittings. Defaults to 42. 55 | download (bool, optional): If True, downloads the dataset from the internet and 56 | puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not 57 | downloaded again. Defaults to False. 58 | """ 59 | 60 | _ZENODO_RECORD_IDS = {"30": 6975118, "150": 6975117} 61 | 62 | MAX_NUM_PARTICLES = 150 63 | 64 | JET_TYPES = ["g", "q", "t", "w", "z"] 65 | ALL_PARTICLE_FEATURES = ["etarel", "phirel", "ptrel", "mask"] 66 | ALL_JET_FEATURES = ["type", "pt", "eta", "mass", "num_particles"] 67 | SPLITS = ["train", "valid", "test", "all"] 68 | 69 | # normalisation used for ParticleNet training for FPND, as defined in arXiv:2106.11535 70 | fpnd_norm = FeaturewiseLinearBounded( 71 | feature_norms=1.0, 72 | feature_shifts=[0.0, 0.0, -0.5], 73 | feature_maxes=[1.6211985349655151, 0.520724892616272, 0.8934717178344727], 74 | ) 75 | 76 | def __init__( 77 | self, 78 | jet_type: str | set[str] = "all", 79 | data_dir: str = "./", 80 | particle_features: list[str] | None = "all", 81 | jet_features: list[str] | None = "all", 82 | particle_normalisation: NormaliseABC | None = None, 83 | jet_normalisation: NormaliseABC | None = None, 84 | particle_transform: Callable | None = None, 85 | jet_transform: Callable | None = None, 86 | num_particles: int = 30, 87 | split: str = "train", 88 | split_fraction: list[float] | None = None, 89 | seed: int = 42, 90 | download: bool = False, 91 | ): 92 | if particle_features == "all": 93 | particle_features = copy(self.ALL_PARTICLE_FEATURES) 94 | 95 | if jet_features == "all": 96 | jet_features = copy(self.ALL_JET_FEATURES) 97 | 98 | if split_fraction is None: 99 | split_fraction = [0.7, 0.15, 0.15] 100 | 101 | self.particle_data, self.jet_data = self.getData( 102 | jet_type, 103 | data_dir, 104 | particle_features, 105 | jet_features, 106 | num_particles, 107 | split, 108 | split_fraction, 109 | seed, 110 | download, 111 | ) 112 | 113 | super().__init__( 114 | data_dir=data_dir, 115 | particle_features=particle_features, 116 | jet_features=jet_features, 117 | particle_normalisation=particle_normalisation, 118 | jet_normalisation=jet_normalisation, 119 | particle_transform=particle_transform, 120 | jet_transform=jet_transform, 121 | num_particles=num_particles, 122 | ) 123 | 124 | self.jet_type = 
jet_type 125 | self.split = split 126 | self.split_fraction = split_fraction 127 | 128 | @classmethod 129 | def getData( 130 | cls: JetDataset, 131 | jet_type: str | set[str] = "all", 132 | data_dir: str = "./", 133 | particle_features: list[str] | None = "all", 134 | jet_features: list[str] | None = "all", 135 | num_particles: int = 30, 136 | split: str = "all", 137 | split_fraction: list[float] | None = None, 138 | seed: int = 42, 139 | download: bool = False, 140 | ) -> tuple[np.ndarray | None, np.ndarray | None]: 141 | """ 142 | Downloads, if needed, and loads and returns JetNet data. 143 | 144 | Args: 145 | jet_type (Union[str, Set[str]], optional): individual type or set of types out of 146 | 'g' (gluon), 't' (top quarks), 'q' (light quarks), 'w' (W bosons), 147 | or 'z' (Z bosons). "all" will get all types. Defaults to "all". 148 | data_dir (str, optional): directory in which data is (to be) stored. Defaults to "./". 149 | particle_features (List[str], optional): list of particle features to retrieve. If empty 150 | or None, gets no particle features. Defaults to 151 | ``["etarel", "phirel", "ptrel", "mask"]``. 152 | jet_features (List[str], optional): list of jet features to retrieve. If empty or None, 153 | gets no jet features. Defaults to 154 | ``["type", "pt", "eta", "mass", "num_particles"]``. 155 | num_particles (int, optional): number of particles to retain per jet, max of 150. 156 | Defaults to 30. 157 | split (str, optional): dataset split, out of {"train", "valid", "test", "all"}. Defaults 158 | to "all". 159 | split_fraction (List[float], optional): splitting fraction of training, validation, 160 | testing data respectively. Defaults to [0.7, 0.15, 0.15]. 161 | seed (int, optional): random seed for the dataset splitting - important to use the same 162 | seed for all dataset splittings. Defaults to 42. 163 | download (bool, optional): If True, downloads the dataset from the internet and 164 | puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not 165 | downloaded again. Defaults to False.
166 | 167 | Returns: 168 | tuple[np.ndarray | None, np.ndarray | None]: particle data, jet data 169 | """ 170 | if particle_features == "all": 171 | particle_features = copy(cls.ALL_PARTICLE_FEATURES) 172 | 173 | if jet_features == "all": 174 | jet_features = copy(cls.ALL_JET_FEATURES) 175 | 176 | if split_fraction is None: 177 | split_fraction = [0.7, 0.15, 0.15] 178 | 179 | assert num_particles <= cls.MAX_NUM_PARTICLES, ( 180 | f"num_particles {num_particles} exceeds max number of " 181 | + f"particles in the dataset {cls.MAX_NUM_PARTICLES}" 182 | ) 183 | jet_type = checkConvertElements(jet_type, cls.JET_TYPES, ntype="jet type") 184 | particle_features, jet_features = checkStrToList(particle_features, jet_features) 185 | use_particle_features, use_jet_features = checkListNotEmpty(particle_features, jet_features) 186 | 187 | import h5py 188 | 189 | # Use JetNet150 if ``num_particles`` > 30 190 | use_150 = num_particles > 30 191 | 192 | particle_data = [] 193 | jet_data = [] 194 | 195 | for j in jet_type: 196 | dname = f"{j}{'150' if use_150 else ''}" 197 | 198 | hdf5_file = checkDownloadZenodoDataset( 199 | data_dir, 200 | dataset_name=dname, 201 | record_id=cls._ZENODO_RECORD_IDS["150" if use_150 else "30"], 202 | key=f"{dname}.hdf5", 203 | download=download, 204 | ) 205 | 206 | with h5py.File(hdf5_file, "r") as f: 207 | pf = ( 208 | np.array(f["particle_features"])[:, :num_particles] 209 | if use_particle_features 210 | else None 211 | ) 212 | jf = np.array(f["jet_features"]) if use_jet_features else None 213 | 214 | if use_particle_features: 215 | # reorder if needed 216 | pf = getOrderedFeatures(pf, particle_features, cls.ALL_PARTICLE_FEATURES) 217 | 218 | if use_jet_features: 219 | # add class index as first jet feature 220 | class_index = cls.JET_TYPES.index(j) 221 | jf = np.concatenate( 222 | ( 223 | np.full([len(jf), 1], class_index), 224 | jf[:, :3], 225 | # cap the num_particles jet feature at the requested maximum 226 | np.minimum(jf[:, 3:], num_particles), 227 | ), 228 | axis=1, 229 | ) 230 | # reorder if needed 231 | jf = getOrderedFeatures(jf, jet_features, cls.ALL_JET_FEATURES) 232 | 233 | particle_data.append(pf) 234 | jet_data.append(jf) 235 | 236 | particle_data = np.concatenate(particle_data, axis=0) if use_particle_features else None 237 | jet_data = np.concatenate(jet_data, axis=0) if use_jet_features else None 238 | 239 | length = len(firstNotNoneElement(particle_data, jet_data)) 240 | 241 | # shuffling and splitting into train/validation/test 242 | lcut, rcut = getSplitting(length, split, cls.SPLITS, split_fraction) 243 | 244 | rng = np.random.default_rng(seed) 245 | randperm = rng.permutation(length) 246 | 247 | if use_particle_features: 248 | particle_data = particle_data[randperm][lcut:rcut] 249 | 250 | if use_jet_features: 251 | jet_data = jet_data[randperm][lcut:rcut] 252 | 253 | return particle_data, jet_data 254 | 255 | def extra_repr(self) -> str: 256 | ret = f"Including {self.jet_type} jets" 257 | 258 | if self.split == "all": 259 | ret += "\nUsing all data (no split)" 260 | else: 261 | ret += ( 262 | f"\nSplit into {self.split} data out of {self.SPLITS} possible splits, " 263 | f"with splitting fractions {self.split_fraction}" 264 | ) 265 | 266 | return ret 267 | --------------------------------------------------------------------------------
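# Usage sketch: constructing the JetNet dataset defined above and batching it with a
# PyTorch DataLoader. A minimal example, assuming the class is exported from
# ``jetnet.datasets``; the data directory, jet type, and batch size are illustrative.
from torch.utils.data import DataLoader

from jetnet.datasets import JetNet

dataset = JetNet(
    jet_type="g",  # gluon jets only
    data_dir="./datasets",
    num_particles=30,
    split="train",
    download=True,  # fetch from Zenodo on first use
)

loader = DataLoader(dataset, batch_size=128, shuffle=True)
particle_batch, jet_batch = next(iter(loader))
# particle_batch has shape [128, 30, 4] (etarel, phirel, ptrel, mask);
# jet_batch has shape [128, 5] (type, pt, eta, mass, num_particles).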
3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from abc import ABC, abstractmethod 8 | 9 | import numpy as np 10 | import torch 11 | from numpy.typing import ArrayLike 12 | 13 | 14 | class NormaliseABC(ABC): 15 | """ 16 | ABC for generalised normalisation class. 17 | """ 18 | 19 | def features_need_deriving(self) -> bool: 20 | """Checks if any dataset values or features need to be derived""" 21 | return False 22 | 23 | def derive_dataset_features(self, x: ArrayLike): # noqa: ARG002 24 | """Derive features from dataset needed for normalisation if needed""" 25 | return 26 | 27 | @abstractmethod 28 | def __call__(self, x: ArrayLike, inverse: bool = False, inplace: bool = False) -> ArrayLike: 29 | """ 30 | Normalises (``inverse`` = False) or inverses normalisation of (``inverse`` = True) ``x`` 31 | Performed inplace if ``inplace`` is True. 32 | """ 33 | 34 | 35 | class FeaturewiseLinear(NormaliseABC): 36 | """ 37 | Shifts features by ``feature_shifts`` then multiplies by ``feature_scales``. 38 | 39 | If using the ``normal`` option, ``feature_shifts`` and ``feature_scales`` can be derived from 40 | the dataset (by calling ``derive_dataset_features``) to normalise the data to have 0 mean and 41 | unit standard deviation per feature. 42 | 43 | Args: 44 | feature_shifts (Union[float, List[float]], optional): value to shift features by. 45 | Can either be a single float for all features, or a list of length ``num_features``. 46 | Defaults to 0.0. 47 | feature_scales (Union[float, List[float]], optional): after shifting, value to multiply 48 | features by. Can either be a single float for all features, or a list of length 49 | ``num_features``. Defaults to 1.0. 50 | normalise_features (Optional[List[bool]], optional): if only some features need to be 51 | normalised, can input here a list of booleans of length ``num_features`` with ``True`` 52 | meaning normalise and ``False`` meaning to ignore. Defaults to None i.e. normalise all. 53 | normal (bool, optional): derive ``feature_shifts`` and ``feature_scales`` to have 0 mean and 54 | unit standard deviation per feature after normalisation (``derive_dataset_features`` 55 | method must be called before normalising). 56 | 57 | """ 58 | 59 | def __init__( 60 | self, 61 | feature_shifts: float | list[float] = 0.0, 62 | feature_scales: float | list[float] = 1.0, 63 | normalise_features: list[bool] | None = None, 64 | normal: bool = False, 65 | ): 66 | super().__init__() 67 | 68 | self.feature_shifts = feature_shifts 69 | self.feature_scales = feature_scales 70 | self.normalise_features = normalise_features 71 | self.normal = normal 72 | 73 | def derive_dataset_features(self, x: ArrayLike) -> tuple[np.ndarray, np.ndarray] | None: 74 | """ 75 | If using the ``normal`` option, this will derive the means and standard deviations per 76 | feature, and save and return them. If not, will do nothing. 77 | 78 | Args: 79 | x (ArrayLike): dataset of shape [..., ``num_features``]. 80 | 81 | Returns: 82 | (Optional[Tuple[np.ndarray, np.ndarray]]): if ``normal`` option, means and stds of each 83 | feature. 
84 | 85 | """ 86 | if self.normal: 87 | num_features = x.shape[-1] 88 | self.feature_shifts = -np.mean(x.reshape(-1, num_features), axis=0) 89 | self.feature_scales = 1.0 / np.std(x.reshape(-1, num_features), axis=0) 90 | return self.feature_shifts, self.feature_scales 91 | 92 | return None 93 | 94 | def features_need_deriving(self) -> bool: 95 | """Checks if any dataset values or features need to be derived""" 96 | return (self.feature_shifts is None) or (self.feature_scales is None) 97 | 98 | def __call__(self, x: ArrayLike, inverse: bool = False, inplace: bool = False) -> ArrayLike: 99 | assert not self.features_need_deriving(), ( 100 | "Feature means and stds have not been specified, " 101 | + "you need to either set or derive them first" 102 | ) 103 | 104 | num_features = x.shape[-1] 105 | 106 | if isinstance(self.feature_shifts, float): 107 | feature_shifts = np.full(num_features, self.feature_shifts) 108 | else: 109 | feature_shifts = self.feature_shifts 110 | 111 | if isinstance(self.feature_scales, float): 112 | feature_scales = np.full(num_features, self.feature_scales) 113 | else: 114 | feature_scales = self.feature_scales 115 | 116 | if self.normalise_features is None: 117 | normalise_features = np.full(num_features, True) 118 | elif isinstance(self.normalise_features, bool): 119 | normalise_features = np.full(num_features, self.normalise_features) 120 | else: 121 | normalise_features = self.normalise_features 122 | 123 | assert ( 124 | len(feature_shifts) == num_features 125 | ), "Number of features in input does not equal number of specified feature shifts" 126 | 127 | assert ( 128 | len(feature_scales) == num_features 129 | ), "Number of features in input does not equal number of specified feature scales" 130 | 131 | assert ( 132 | len(normalise_features) == num_features 133 | ), "Number of features in input does not equal length of ``normalise_features``" 134 | 135 | if not inplace: 136 | x = torch.clone(x) if isinstance(x, torch.Tensor) else np.copy(x) 137 | 138 | if not inverse: 139 | for i in range(num_features): 140 | if normalise_features[i]: 141 | x[..., i] += feature_shifts[i] 142 | x[..., i] *= feature_scales[i] 143 | 144 | else: 145 | for i in range(num_features): 146 | if normalise_features[i]: 147 | x[..., i] /= feature_scales[i] 148 | x[..., i] -= feature_shifts[i] 149 | 150 | return x 151 | 152 | def __repr__(self) -> str: 153 | if self.normal: 154 | ret = "Normalising features to zero mean and unit standard deviation" 155 | else: 156 | ret = ( 157 | f"Shifting features by {self.feature_shifts} " 158 | f"and then multiplying by {self.feature_scales}" 159 | ) 160 | 161 | if self.normalise_features is not None and self.normalise_features is not True: 162 | ret += f", normalising features: {self.normalise_features}" 163 | 164 | return ret 165 | 166 |
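# Usage sketch: standardising features to zero mean and unit variance with
# FeaturewiseLinear's ``normal`` option. A minimal example, assuming the class is
# importable from ``jetnet.datasets.normalisations``; the random array stands in
# for a real dataset, and with ``normal=True`` the shifts and scales must be
# derived before normalising.
import numpy as np

from jetnet.datasets.normalisations import FeaturewiseLinear

norm = FeaturewiseLinear(normal=True)

x = np.random.default_rng(42).normal(loc=5.0, scale=2.0, size=(1000, 3))
norm.derive_dataset_features(x)  # sets feature_shifts = -mean, feature_scales = 1 / std

y = norm(x)  # y now has ~zero mean and ~unit standard deviation per feature
x_back = norm(y, inverse=True)  # inverts the normalisation
assert np.allclose(x, x_back)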
167 | class FeaturewiseLinearBounded(NormaliseABC): 168 | """ 169 | Normalises dataset features by scaling each to an (absolute) max of ``feature_norms`` 170 | and shifting by ``feature_shifts``. 171 | 172 | If the value in the list for a feature is None, it won't be scaled or shifted. 173 | 174 | Args: 175 | feature_norms (Union[float, List[float]], optional): max value to scale each feature to. 176 | Can either be a single float for all features, or a list of length ``num_features``. 177 | Defaults to 1.0. 178 | feature_shifts (Union[float, List[float]], optional): after scaling, 179 | value to shift feature by. 180 | Can either be a single float for all features, or a list of length ``num_features``. 181 | Defaults to 0.0. 182 | feature_maxes (List[float], optional): max pre-scaling absolute value of each feature, used 183 | for scaling to the norm and inverting. 184 | normalise_features (Optional[List[bool]], optional): if only some features need to be 185 | normalised, can input here a list of booleans of length ``num_features`` with ``True`` 186 | meaning normalise and ``False`` meaning to ignore. Defaults to None i.e. normalise all. 187 | 188 | """ 189 | 190 | def __init__( 191 | self, 192 | feature_norms: float | list[float] = 1.0, 193 | feature_shifts: float | list[float] = 0.0, 194 | feature_maxes: list[float] | None = None, 195 | normalise_features: list[bool] | None = None, 196 | ): 197 | super().__init__() 198 | 199 | self.feature_norms = feature_norms 200 | self.feature_shifts = feature_shifts 201 | self.feature_maxes = feature_maxes 202 | self.normalise_features = normalise_features 203 | 204 | def derive_dataset_features(self, x: ArrayLike) -> np.ndarray: 205 | """ 206 | Derives, saves, and returns absolute feature maxes of dataset ``x``. 207 | 208 | Args: 209 | x (ArrayLike): dataset of shape [..., ``num_features``]. 210 | 211 | Returns: 212 | np.ndarray: feature maxes 213 | 214 | """ 215 | num_features = x.shape[-1] 216 | self.feature_maxes = np.max(np.abs(x.reshape(-1, num_features)), axis=0) 217 | return self.feature_maxes 218 | 219 | def features_need_deriving(self) -> bool: 220 | """Checks if any dataset values or features need to be derived""" 221 | return self.feature_maxes is None 222 | 223 | def __call__(self, x: ArrayLike, inverse: bool = False, inplace: bool = False) -> ArrayLike: 224 | assert ( 225 | not self.features_need_deriving() 226 | ), "Feature maxes have not been specified, you need to either set or derive them first" 227 | 228 | num_features = x.shape[-1] 229 | 230 | assert num_features == len( 231 | self.feature_maxes 232 | ), "Number of features in ``x`` does not equal length of saved feature maxes" 233 | 234 | if isinstance(self.feature_norms, float): 235 | feature_norms = np.full(num_features, self.feature_norms) 236 | else: 237 | feature_norms = self.feature_norms 238 | 239 | if isinstance(self.feature_shifts, float): 240 | feature_shifts = np.full(num_features, self.feature_shifts) 241 | else: 242 | feature_shifts = self.feature_shifts 243 | 244 | if self.normalise_features is None: 245 | normalise_features = np.full(num_features, True) 246 | elif isinstance(self.normalise_features, bool): 247 | normalise_features = np.full(num_features, self.normalise_features) 248 | else: 249 | normalise_features = self.normalise_features 250 | 251 | assert ( 252 | len(feature_shifts) == num_features 253 | ), "Number of features in input does not equal number of specified feature shifts" 254 | 255 | assert ( 256 | len(feature_norms) == num_features 257 | ), "Number of features in input does not equal number of specified feature norms" 258 | 259 | assert ( 260 | len(normalise_features) == num_features 261 | ), "Number of features in input does not equal length of ``normalise_features``" 262 | 263 | if not inplace: 264 | x = torch.clone(x) if isinstance(x, torch.Tensor) else np.copy(x) 265 | 266 | if not inverse: 267 | for i in range(num_features): 268 | if normalise_features[i]: 269 | if feature_norms[i] is not None: 270 | x[..., i] /= self.feature_maxes[i] 271 | x[..., i] *= feature_norms[i] 272 | 273 | if feature_shifts[i] is not None: 274 | x[..., i] += feature_shifts[i] 275 | 276 | else: 277 | for i in range(num_features): 278 | if normalise_features[i]: 279 | if feature_shifts[i] is not None: 280 | x[..., i] -= feature_shifts[i] 281 | 282 | if feature_norms[i] is not None: 283 | x[..., i] /= feature_norms[i] 284 | x[..., i] *= self.feature_maxes[i] 285 | 286 | return x 287 | 288 | def __repr__(self) -> str: 289 | ret = ( 290 | f"Linear scaling features to feature norms {self.feature_norms} " 291 | f"and (post-scaling) feature shifts {self.feature_shifts}" 292 | ) 293 | 294 | if self.feature_maxes is not None: 295 | ret += f", with pre-scaling feature maxes {self.feature_maxes}" 296 | 297 | if self.normalise_features is not None and self.normalise_features is not True: 298 | ret += f", normalising features: {self.normalise_features}" 299 | 300 | return ret 301 | --------------------------------------------------------------------------------
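# Usage sketch: bounding each feature to [-1, 1] with FeaturewiseLinearBounded,
# mirroring the ``fpnd_norm`` instance used by the JetNet class above. A minimal
# example, assuming the class is importable from ``jetnet.datasets.normalisations``;
# the random array stands in for a real dataset.
import numpy as np

from jetnet.datasets.normalisations import FeaturewiseLinearBounded

norm = FeaturewiseLinearBounded(feature_norms=1.0)

x = np.random.default_rng(42).uniform(-3.0, 3.0, size=(500, 4))
norm.derive_dataset_features(x)  # stores the per-feature absolute maxima

y = norm(x)
assert np.abs(y).max() <= 1.0  # every feature scaled into [-1, 1]

x_back = norm(y, inverse=True)
assert np.allclose(x, x_back)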
/jetnet/datasets/qgjets.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from copy import copy 4 | from typing import Callable 5 | 6 | import numpy as np 7 | 8 | from .dataset import JetDataset 9 | from .normalisations import NormaliseABC 10 | from .utils import ( 11 | checkConvertElements, 12 | checkDownloadZenodoDataset, 13 | checkListNotEmpty, 14 | checkStrToList, 15 | getOrderedFeatures, 16 | getSplitting, 17 | ) 18 | 19 | 20 | class QuarkGluon(JetDataset): 21 | """ 22 | PyTorch ``torch.utils.data.Dataset`` class for the Quark Gluon Jets dataset. Either jets with 23 | or without bottom and charm quark jets can be selected (``with_bc`` flag). 24 | 25 | If npz files are not found in the ``data_dir`` directory then the dataset will be automatically 26 | downloaded from Zenodo (https://zenodo.org/record/3164691). 27 | 28 | Args: 29 | jet_type (Union[str, Set[str]], optional): individual type or set of types out of 30 | 'g' (gluon) and 'q' (light quarks). Defaults to "all". 31 | data_dir (str, optional): directory in which data is (to be) stored. Defaults to "./". 32 | with_bc (bool, optional): with or without bottom and charm quark jets. Defaults to True. 33 | particle_features (List[str], optional): list of particle features to retrieve. If empty 34 | or None, gets no particle features. Defaults to 35 | ``["pt", "eta", "phi", "pdgid"]``. 36 | jet_features (List[str], optional): list of jet features to retrieve. If empty or None, 37 | gets no jet features. Defaults to 38 | ``["type"]``. 39 | particle_normalisation (NormaliseABC, optional): optional normalisation to apply to 40 | particle data. Defaults to None. 41 | jet_normalisation (NormaliseABC, optional): optional normalisation to apply to jet data. 42 | Defaults to None. 43 | particle_transform (callable, optional): A function/transform that takes in the particle 44 | data tensor and transforms it. Defaults to None. 45 | jet_transform (callable, optional): A function/transform that takes in the jet 46 | data tensor and transforms it. Defaults to None. 47 | num_particles (int, optional): number of particles to retain per jet, max of 153. 48 | Defaults to 153. 49 | split (str, optional): dataset split, out of {"train", "valid", "test", "all"}. Defaults 50 | to "train". 51 | split_fraction (List[float], optional): splitting fraction of training, validation, 52 | testing data respectively. Defaults to [0.7, 0.15, 0.15]. 53 | seed (int, optional): random seed for the dataset splitting - important to use the same 54 | seed for all dataset splittings. Defaults to 42. 55 | file_list (List[str], optional): list of files to load, if full dataset is not required.
56 | Defaults to None (will load all files). 57 | download (bool, optional): If True, downloads the dataset from the internet and 58 | puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not 59 | downloaded again. Defaults to False. 60 | """ 61 | 62 | _ZENODO_RECORD_ID = 3164691 63 | 64 | # False - without bc, True - with bc 65 | _FILE_LIST = { 66 | False: [ 67 | "QG_jets.npz", 68 | "QG_jets_1.npz", 69 | "QG_jets_2.npz", 70 | "QG_jets_3.npz", 71 | "QG_jets_4.npz", 72 | "QG_jets_5.npz", 73 | "QG_jets_6.npz", 74 | "QG_jets_7.npz", 75 | "QG_jets_8.npz", 76 | "QG_jets_9.npz", 77 | "QG_jets_10.npz", 78 | "QG_jets_11.npz", 79 | "QG_jets_12.npz", 80 | "QG_jets_13.npz", 81 | "QG_jets_14.npz", 82 | "QG_jets_15.npz", 83 | "QG_jets_16.npz", 84 | "QG_jets_17.npz", 85 | "QG_jets_18.npz", 86 | "QG_jets_19.npz", 87 | ], 88 | True: [ 89 | "QG_jets_withbc_0.npz", 90 | "QG_jets_withbc_1.npz", 91 | "QG_jets_withbc_2.npz", 92 | "QG_jets_withbc_3.npz", 93 | 94 | "QG_jets_withbc_4.npz", 95 | "QG_jets_withbc_5.npz", 96 | "QG_jets_withbc_6.npz", 97 | "QG_jets_withbc_7.npz", 98 | "QG_jets_withbc_8.npz", 99 | "QG_jets_withbc_9.npz", 100 | "QG_jets_withbc_10.npz", 101 | "QG_jets_withbc_11.npz", 102 | "QG_jets_withbc_12.npz", 103 | "QG_jets_withbc_13.npz", 104 | "QG_jets_withbc_14.npz", 105 | "QG_jets_withbc_15.npz", 106 | "QG_jets_withbc_16.npz", 107 | "QG_jets_withbc_17.npz", 108 | "QG_jets_withbc_18.npz", 109 | "QG_jets_withbc_19.npz", 110 | ], 111 | } 112 | 113 | MAX_NUM_PARTICLES = 153 114 | 115 | JET_TYPES = ["g", "q"] 116 | ALL_PARTICLE_FEATURES = ["pt", "eta", "phi", "pdgid"] 117 | ALL_JET_FEATURES = ["type"] 118 | SPLITS = ["train", "valid", "test", "all"] 119 | 120 | def __init__( 121 | self, 122 | jet_type: str | set[str] = "all", 123 | data_dir: str = "./", 124 | with_bc: bool = True, 125 | particle_features: list[str] | None = "all", 126 | jet_features: list[str] | None = "all", 127 | particle_normalisation: NormaliseABC | None = None, 128 | jet_normalisation: NormaliseABC | None = None, 129 | particle_transform: Callable | None = None, 130 | jet_transform: Callable | None = None, 131 | num_particles: int = MAX_NUM_PARTICLES, 132 | split: str = "train", 133 | split_fraction: list[float] | None = None, 134 | seed: int = 42, 135 | file_list: list[str] | None = None, 136 | download: bool = False, 137 | ): 138 | if particle_features == "all": 139 | particle_features = copy(self.ALL_PARTICLE_FEATURES) 140 | 141 | if jet_features == "all": 142 | jet_features = copy(self.ALL_JET_FEATURES) 143 | 144 | if split_fraction is None: 145 | split_fraction = [0.7, 0.15, 0.15] 146 | 147 | self.particle_data, self.jet_data = self.getData( 148 | jet_type, 149 | data_dir, 150 | with_bc, 151 | particle_features, 152 | jet_features, 153 | num_particles, 154 | split, 155 | split_fraction, 156 | seed, 157 | file_list, 158 | download, 159 | ) 160 | 161 | super().__init__( 162 | data_dir=data_dir, 163 | particle_features=particle_features, 164 | jet_features=jet_features, 165 | particle_normalisation=particle_normalisation, 166 | jet_normalisation=jet_normalisation, 167 | particle_transform=particle_transform, 168 | jet_transform=jet_transform, 169 | num_particles=num_particles, 170 | ) 171 | 172 | self.jet_type = jet_type 173 | self.split = split 174 | self.split_fraction = split_fraction 175 | 176 | @classmethod 177 | def getData( 178 | cls: JetDataset, 179 | jet_type: str | set[str] = "all", 180 | data_dir: str = "./", 181 | with_bc: bool = True, 182 | particle_features:
list[str] | None = "all", 183 | jet_features: list[str] | None = "all", 184 | num_particles: int = MAX_NUM_PARTICLES, 185 | split: str = "all", 186 | split_fraction: list[float] | None = None, 187 | seed: int = 42, 188 | file_list: list[str] | None = None, 189 | download: bool = False, 190 | ) -> tuple[np.ndarray | None, np.ndarray | None]: 191 | """ 192 | Downloads, if needed, and loads and returns Quark Gluon data. 193 | 194 | Args: 195 | jet_type (Union[str, Set[str]], optional): individual type or set of types out of 196 | 'g' (gluon) and 'q' (light quarks). Defaults to "all". 197 | data_dir (str, optional): directory in which data is (to be) stored. Defaults to "./". 198 | with_bc (bool, optional): with or without bottom and charm quark jets. Defaults to True. 199 | particle_features (List[str], optional): list of particle features to retrieve. If empty 200 | or None, gets no particle features. Defaults to 201 | ``["pt", "eta", "phi", "pdgid"]``. 202 | jet_features (List[str], optional): list of jet features to retrieve. If empty or None, 203 | gets no jet features. Defaults to 204 | ``["type"]``. 205 | num_particles (int, optional): number of particles to retain per jet, max of 153. 206 | Defaults to 153. 207 | split (str, optional): dataset split, out of {"train", "valid", "test", "all"}. Defaults 208 | to "train". 209 | split_fraction (List[float], optional): splitting fraction of training, validation, 210 | testing data respectively. Defaults to [0.7, 0.15, 0.15]. 211 | seed (int, optional): PyTorch manual seed - important to use the same seed for all 212 | dataset splittings. Defaults to 42. 213 | file_list (List[str], optional): list of files to load, if full dataset is not required. 214 | Defaults to None (will load all files). 215 | download (bool, optional): If True, downloads the dataset from the internet and 216 | puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not 217 | downloaded again. Defaults to False. 
218 | 219 | Returns: 220 | tuple[np.ndarray | None, np.ndarray | None]: particle data, jet data 221 | """ 222 | if particle_features == "all": 223 | particle_features = copy(cls.ALL_PARTICLE_FEATURES) 224 | 225 | if jet_features == "all": 226 | jet_features = copy(cls.ALL_JET_FEATURES) 227 | 228 | if split_fraction is None: 229 | split_fraction = [0.7, 0.15, 0.15] 230 | 231 | assert num_particles <= cls.MAX_NUM_PARTICLES, ( 232 | f"num_particles {num_particles} exceeds max number of " 233 | + f"particles in the dataset {cls.MAX_NUM_PARTICLES}" 234 | ) 235 | 236 | jet_type = checkConvertElements(jet_type, cls.JET_TYPES, ntype="jet type") 237 | type_indices = [cls.JET_TYPES.index(t) for t in jet_type] 238 | 239 | particle_features, jet_features = checkStrToList(particle_features, jet_features) 240 | use_particle_features, use_jet_features = checkListNotEmpty(particle_features, jet_features) 241 | 242 | particle_data = [] 243 | jet_data = [] 244 | 245 | file_list = cls._FILE_LIST[with_bc] if file_list is None else file_list 246 | 247 | for file_name in file_list: 248 | npz_file = checkDownloadZenodoDataset( 249 | data_dir, 250 | dataset_name=file_name, 251 | record_id=cls._ZENODO_RECORD_ID, 252 | key=file_name, 253 | download=download, 254 | ) 255 | 256 | print(f"Loading {file_name}") 257 | data = np.load(npz_file) 258 | 259 | # select only the specified types of jets (gluon or quark or both) 260 | jet_selector = np.sum([data["y"] == i for i in type_indices], axis=0).astype(bool) 261 | 262 | if use_particle_features: 263 | pf = data["X"][jet_selector][:, :num_particles] 264 | 265 | # zero-pad if needed (datasets have different numbers of max particles) 266 | pf_np = pf.shape[1] 267 | if pf_np < num_particles: 268 | pf = np.pad(pf, ((0, 0), (0, num_particles - pf_np), (0, 0)), constant_values=0) 269 | 270 | # reorder if needed 271 | pf = getOrderedFeatures(pf, particle_features, cls.ALL_PARTICLE_FEATURES) 272 | 273 | if use_jet_features: 274 | jf = data["y"][jet_selector].reshape(-1, 1) 275 | jf = getOrderedFeatures(jf, jet_features, cls.ALL_JET_FEATURES) 276 | 277 | length = np.sum(jet_selector) 278 | 279 | # shuffling and splitting each file into train/validation/test 280 | lcut, rcut = getSplitting(length, split, cls.SPLITS, split_fraction) 281 | 282 | rng = np.random.default_rng(seed) 283 | randperm = rng.permutation(length) 284 | 285 | if use_particle_features: 286 | pf = pf[randperm][lcut:rcut] 287 | particle_data.append(pf) 288 | 289 | if use_jet_features: 290 | jf = jf[randperm][lcut:rcut] 291 | jet_data.append(jf) 292 | 293 | particle_data = np.concatenate(particle_data, axis=0) if use_particle_features else None 294 | jet_data = np.concatenate(jet_data, axis=0) if use_jet_features else None 295 | 296 | return particle_data, jet_data 297 | 298 | def extra_repr(self) -> str: 299 | ret = f"Including {self.jet_type} jets" 300 | 301 | if self.split == "all": 302 | ret += "\nUsing all data (no split)" 303 | else: 304 | ret += ( 305 | f"\nSplit into {self.split} data out of {self.SPLITS} possible splits, " 306 | f"with splitting fractions {self.split_fraction}" 307 | ) 308 | 309 | return ret 310 | --------------------------------------------------------------------------------
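# Usage sketch: pulling a single file of the quark/gluon dataset as numpy arrays via
# the ``getData`` class method above, without building a Dataset object. A minimal
# example, assuming the class is exported from ``jetnet.datasets``; the file name
# comes from ``_FILE_LIST`` and the data directory is illustrative.
from jetnet.datasets import QuarkGluon

particle_data, jet_data = QuarkGluon.getData(
    jet_type="g",
    data_dir="./datasets",
    with_bc=False,
    file_list=["QG_jets.npz"],  # first file only, to keep the download small
    split="train",
    download=True,
)
# particle_data has shape [N, 153, 4] (pt, eta, phi, pdgid);
# jet_data has shape [N, 1] (type).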
/jetnet/datasets/toptagging.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from copy import copy 4 | from typing import Callable 5 | 6 | import numpy as np 7 | 8 | from .dataset import JetDataset 9 | from .normalisations import NormaliseABC 10 | from .utils import ( 11 | checkConvertElements, 12 | checkDownloadZenodoDataset, 13 | checkListNotEmpty, 14 | checkStrToList, 15 | getOrderedFeatures, 16 | ) 17 | 18 | 19 | class TopTagging(JetDataset): 20 | """ 21 | PyTorch ``torch.utils.data.Dataset`` class for the Top Quark Tagging Reference dataset. 22 | 23 | If hdf5 files are not found in the ``data_dir`` directory then the dataset will be downloaded 24 | from Zenodo (https://zenodo.org/record/2603256). 25 | 26 | Args: 27 | jet_type (Union[str, Set[str]], optional): individual type or set of types out of 'qcd' and 28 | 'top'. Defaults to "all". 29 | data_dir (str, optional): directory in which data is (to be) stored. Defaults to "./". 30 | particle_features (List[str], optional): list of particle features to retrieve. If empty 31 | or None, gets no particle features. Defaults to ``["E", "px", "py", "pz"]``. 32 | jet_features (List[str], optional): list of jet features to retrieve. If empty or None, 33 | gets no jet features. Defaults to ``["type", "E", "px", "py", "pz"]``. 34 | particle_normalisation (NormaliseABC, optional): optional normalisation to apply to 35 | particle data. Defaults to None. 36 | jet_normalisation (NormaliseABC, optional): optional normalisation to apply to jet data. 37 | Defaults to None. 38 | particle_transform (callable, optional): A function/transform that takes in the particle 39 | data tensor and transforms it. Defaults to None. 40 | jet_transform (callable, optional): A function/transform that takes in the jet 41 | data tensor and transforms it. Defaults to None. 42 | num_particles (int, optional): number of particles to retain per jet, max of 200. Defaults 43 | to 200. 44 | split (str, optional): dataset split, out of {"train", "valid", "test", "all"}. Defaults 45 | to "train". 46 | download (bool, optional): If True, downloads the dataset from the internet and 47 | puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not 48 | downloaded again. Defaults to False.
49 | """ 50 | 51 | _ZENODO_RECORD_ID = 2603256 52 | MAX_NUM_PARTICLES = 200 53 | 54 | JET_TYPES = ["qcd", "top"] 55 | ALL_PARTICLE_FEATURES = ["E", "px", "py", "pz"] 56 | ALL_JET_FEATURES = ["type", "E", "px", "py", "pz"] 57 | SPLITS = ["train", "valid", "test"] 58 | _SPLIT_KEY_MAPPING = {"train": "train", "valid": "val", "test": "test"} # map to file name 59 | 60 | def __init__( 61 | self, 62 | jet_type: str | set[str] = "all", 63 | data_dir: str = "./", 64 | particle_features: list[str] | None = "all", 65 | jet_features: list[str] | None = "all", 66 | particle_normalisation: NormaliseABC | None = None, 67 | jet_normalisation: NormaliseABC | None = None, 68 | particle_transform: Callable | None = None, 69 | jet_transform: Callable | None = None, 70 | num_particles: int = MAX_NUM_PARTICLES, 71 | split: str = "train", 72 | download: bool = False, 73 | ): 74 | if particle_features == "all": 75 | particle_features = copy(self.ALL_PARTICLE_FEATURES) 76 | 77 | if jet_features == "all": 78 | jet_features = copy(self.ALL_JET_FEATURES) 79 | 80 | self.particle_data, self.jet_data = self.getData( 81 | jet_type, data_dir, particle_features, jet_features, num_particles, split, download 82 | ) 83 | 84 | super().__init__( 85 | data_dir=data_dir, 86 | particle_features=particle_features, 87 | jet_features=jet_features, 88 | particle_normalisation=particle_normalisation, 89 | jet_normalisation=jet_normalisation, 90 | particle_transform=particle_transform, 91 | jet_transform=jet_transform, 92 | num_particles=num_particles, 93 | ) 94 | 95 | self.jet_type = jet_type 96 | self.split = split 97 | 98 | @classmethod 99 | def getData( 100 | cls, 101 | jet_type: str | set[str] = "all", 102 | data_dir: str = "./", 103 | particle_features: list[str] | None = "all", 104 | jet_features: list[str] | None = "all", 105 | num_particles: int = MAX_NUM_PARTICLES, 106 | split: str = "all", 107 | download: bool = False, 108 | ) -> tuple[np.ndarray | None, np.ndarray | None]: 109 | """ 110 | Downloads, if needed, and loads and returns Top Quark Tagging data. 111 | 112 | Args: 113 | jet_type (Union[str, Set[str]], optional): individual type or set of types out of 'qcd' 114 | and 'top'. Defaults to "all". 115 | data_dir (str, optional): directory in which data is (to be) stored. Defaults to "./". 116 | particle_features (List[str], optional): list of particle features to retrieve. If empty 117 | or None, gets no particle features. Defaults to ``["E", "px", "py", "pz"]``. 118 | jet_features (List[str], optional): list of jet features to retrieve. If empty or None, 119 | gets no jet features. Defaults to ``["type", "E", "px", "py", "pz"]``. 120 | num_particles (int, optional): number of particles to retain per jet, max of 200. 121 | Defaults to 200. 122 | split (str, optional): dataset split, out of {"train", "valid", "test", "all"}. Defaults 123 | to "all". 124 | download (bool, optional): If True, downloads the dataset from the internet and 125 | puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not 126 | downloaded again. Defaults to False. 
127 | 128 | Returns: 129 | (tuple[np.ndarray | None, np.ndarray | None]): particle data, jet data 130 | """ 131 | if particle_features == "all": 132 | particle_features = copy(cls.ALL_PARTICLE_FEATURES) 133 | 134 | if jet_features == "all": 135 | jet_features = copy(cls.ALL_JET_FEATURES) 136 | 137 | assert num_particles <= cls.MAX_NUM_PARTICLES, ( 138 | f"num_particles {num_particles} exceeds max number of " 139 | + f"particles in the dataset {cls.MAX_NUM_PARTICLES}" 140 | ) 141 | 142 | jet_type = checkConvertElements(jet_type, cls.JET_TYPES, ntype="jet type") 143 | type_indices = [cls.JET_TYPES.index(t) for t in jet_type] 144 | 145 | particle_features, jet_features = checkStrToList(particle_features, jet_features) 146 | use_particle_features, use_jet_features = checkListNotEmpty(particle_features, jet_features) 147 | split = checkConvertElements(split, cls.SPLITS, ntype="splitting") 148 | 149 | import pandas as pd 150 | 151 | particle_data = [] 152 | jet_data = [] 153 | 154 | for s in split: 155 | hdf5_file = checkDownloadZenodoDataset( 156 | data_dir, 157 | dataset_name=cls._SPLIT_KEY_MAPPING[s], 158 | record_id=cls._ZENODO_RECORD_ID, 159 | key=f"{cls._SPLIT_KEY_MAPPING[s]}.h5", 160 | download=download, 161 | ) 162 | 163 | data = np.array(pd.read_hdf(hdf5_file, key="table")) 164 | 165 | # select only the specified types of jets (qcd or top or both) 166 | jet_selector = np.sum([data[:, -1] == i for i in type_indices], axis=0).astype(bool) 167 | data = data[jet_selector] 168 | 169 | # extract particle and jet features in the order specified by the class's 170 | # ``ALL_PARTICLE_FEATURES`` and ``ALL_JET_FEATURES`` variables 171 | total_particle_features = cls.MAX_NUM_PARTICLES * len(cls.ALL_PARTICLE_FEATURES) 172 | 173 | if use_particle_features: 174 | pf = data[:, :total_particle_features].reshape( 175 | -1, cls.MAX_NUM_PARTICLES, len(cls.ALL_PARTICLE_FEATURES) 176 | )[:, :num_particles] 177 | 178 | # reorder if needed 179 | pf = getOrderedFeatures(pf, particle_features, cls.ALL_PARTICLE_FEATURES) 180 | particle_data.append(pf) 181 | 182 | if use_jet_features: 183 | jf = np.concatenate( 184 | (data[:, -1:], data[:, total_particle_features : total_particle_features + 4]), 185 | axis=-1, 186 | ) 187 | 188 | # reorder if needed 189 | jf = getOrderedFeatures(jf, jet_features, cls.ALL_JET_FEATURES) 190 | jet_data.append(jf) 191 | 192 | particle_data = np.concatenate(particle_data, axis=0) if use_particle_features else None 193 | jet_data = np.concatenate(jet_data, axis=0) if use_jet_features else None 194 | 195 | return particle_data, jet_data 196 | 197 | def extra_repr(self) -> str: 198 | ret = f"Including {self.jet_type} jets" 199 | 200 | if self.split == "all": 201 | ret += "\nUsing all data (no split)" 202 | else: 203 | ret += f"\nSplit into {self.split} data out of {self.SPLITS} possible splits" 204 | 205 | return ret 206 | --------------------------------------------------------------------------------
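# Usage sketch: the TopTagging dataset above restricted to top quark jets and
# kinematic particle features only. A minimal example, assuming the class is
# exported from ``jetnet.datasets``; paths and feature choices are illustrative,
# and passing ``jet_features=None`` skips jet-level data entirely.
from jetnet.datasets import TopTagging

dataset = TopTagging(
    jet_type="top",
    data_dir="./datasets",
    particle_features=["px", "py", "pz"],
    jet_features=None,
    num_particles=100,
    split="valid",
    download=True,
)
print(dataset)  # summary via JetDataset.__repr__ and extra_repr
particle_data, _ = dataset[0]  # particle_data: [100, 3]; the jet entry is an empty list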
/jetnet/datasets/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility methods for datasets. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import hashlib 8 | import sys 9 | from pathlib import Path 10 | from typing import Any 11 | 12 | import numpy as np 13 | import requests 14 | from numpy.typing import ArrayLike 15 | 16 | 17 | def download_progress_bar(file_url: str, file_dest: str): 18 | """ 19 | Download while outputting a progress bar. 20 | Modified from https://sumit-ghosh.com/articles/python-download-progress-bar/ 21 | 22 | Args: 23 | file_url (str): url to download from 24 | file_dest (str): path at which to save downloaded file 25 | """ 26 | 27 | with Path(file_dest).open("wb") as f: 28 | response = requests.get(file_url, stream=True) 29 | total = response.headers.get("content-length") 30 | 31 | if total is None: 32 | f.write(response.content) 33 | else: 34 | downloaded = 0 35 | total = int(total) 36 | 37 | print("Downloading dataset") 38 | for data in response.iter_content(chunk_size=max(int(total / 1000), 1024 * 1024)): 39 | downloaded += len(data) 40 | f.write(data) 41 | done = int(50 * downloaded / total) 42 | sys.stdout.write( 43 | "\r[{}{}] {:.0f}%".format( 44 | "█" * done, "." * (50 - done), float(downloaded / total) * 100 45 | ) 46 | ) 47 | sys.stdout.flush() 48 | 49 | sys.stdout.write("\n") 50 | 51 | 52 | # from TorchVision 53 | # https://github.com/pytorch/vision/blob/48f8473e21b0f3e425aabc60db201b68fedf59b3/torchvision/datasets/utils.py#L51-L66 54 | def _calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str: 55 | # Setting the `usedforsecurity` flag does not change anything about the functionality, but 56 | # indicates that we are not using the MD5 checksum for cryptography. This enables its usage 57 | # in restricted environments like FIPS. 58 | if sys.version_info >= (3, 9):  # noqa: UP036 59 | md5 = hashlib.md5(usedforsecurity=False) 60 | else: 61 | md5 = hashlib.md5() 62 | with Path(fpath).open("rb") as f: 63 | # switch to simpler assignment operator once we support only Python >=3.8 64 | # while chunk := f.read(chunk_size): 65 | for chunk in iter(lambda: f.read(chunk_size), b""): 66 | md5.update(chunk) 67 | return md5.hexdigest() 68 | 69 | 70 | def _check_md5(fpath: str, md5: str, **kwargs: Any) -> tuple[bool, str]: 71 | fmd5 = _calculate_md5(fpath, **kwargs) 72 | return (md5 == fmd5), fmd5 73 | 74 | 75 | def _getZenodoFileURL(record_id: int, file_name: str) -> tuple[str, str]: 76 | """Finds the URL and md5 hash for downloading the file ``file_name`` from a Zenodo record.""" 77 | 78 | 79 | 80 | records_url = f"https://zenodo.org/api/records/{record_id}" 81 | r = requests.get(records_url).json() 82 | 83 | # Zenodo API seems to be switching back and forth between these at the moment... so trying both 84 | try: 85 | file = next(item for item in r["files"] if item["filename"] == file_name) 86 | file_url = file["links"]["download"] 87 | md5 = file["checksum"] 88 | except KeyError: 89 | file = next(item for item in r["files"] if item["key"] == file_name) 90 | file_url = file["links"]["self"] 91 | md5 = file["checksum"].split("md5:")[1] 92 | 93 | return file_url, md5 94 | 95 | 96 | def checkDownloadZenodoDataset( 97 | data_dir: str, dataset_name: str, record_id: int, key: str, download: bool 98 | ) -> Path: 99 | """ 100 | Checks if the dataset file exists and its md5 hash matches. 101 | If it does not and ``download`` is True, downloads it from Zenodo and returns the file path; 102 | if it does not and ``download`` is False, raises an error. 103 | """ 104 | data_dir = Path(data_dir) 105 | file_path = data_dir / key 106 | file_url, md5 = _getZenodoFileURL(record_id, key) 107 | 108 | if download: 109 | if file_path.is_file(): 110 | match_md5, fmd5 = _check_md5(file_path, md5) 111 | if not match_md5: 112 | print( 113 | f"File corrupted - MD5 hash of {file_path} does not match: " 114 | f"(expected md5:{md5}, got md5:{fmd5}), " 115 | "removing existing file and re-downloading." 116 | "\nPlease open an issue at https://github.com/jet-net/JetNet/issues/new " 117 | "if you believe this is an error." 118 | ) 119 | file_path.unlink() 120 | 121 | if not file_path.is_file(): 122 | data_dir.mkdir(parents=True, exist_ok=True) 123 | 124 | print(f"Downloading {dataset_name} dataset to {file_path}") 125 | download_progress_bar(file_url, file_path) 126 | 127 | if not file_path.is_file(): 128 | raise RuntimeError( 129 | f"Dataset {dataset_name} not found at {file_path}, " 130 | "you can use download=True to download it." 131 | ) 132 | 133 | match_md5, fmd5 = _check_md5(file_path, md5) 134 | if not match_md5: 135 | raise RuntimeError( 136 | f"File corrupted - MD5 hash of {file_path} does not match: " 137 | f"(expected md5:{md5}, got md5:{fmd5}), " 138 | "you can use download=True to re-download it." 139 | "\nPlease open an issue at https://github.com/jet-net/JetNet/issues/new " 140 | "if you believe this is an error." 141 | ) 142 | 143 | return file_path 144 |
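# Usage sketch for the function above, assuming it is importable from
# ``jetnet.datasets.utils``. The values mirror those the JetNet class passes in:
# record 6975118 is the 30-particle JetNet Zenodo record and "g.hdf5" its gluon
# file; the data directory is illustrative.
from jetnet.datasets.utils import checkDownloadZenodoDataset

path = checkDownloadZenodoDataset(
    data_dir="./datasets",
    dataset_name="g",
    record_id=6975118,
    key="g.hdf5",
    download=True,
)
print(path)  # local path to g.hdf5, after an MD5 integrity check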
116 | "\nPlease open an issue at https://github.com/jet-net/JetNet/issues/new " 117 | "if you believe this is an error." 118 | ) 119 | file_path.unlink() 120 | 121 | if not file_path.is_file(): 122 | data_dir.mkdir(parents=True, exist_ok=True) 123 | 124 | print(f"Downloading {dataset_name} dataset to {file_path}") 125 | download_progress_bar(file_url, file_path) 126 | 127 | if not file_path.is_file(): 128 | raise RuntimeError( 129 | f"Dataset {dataset_name} not found at {file_path}, " 130 | "you can use download=True to download it." 131 | ) 132 | 133 | match_md5, fmd5 = _check_md5(file_path, md5) 134 | if not match_md5: 135 | raise RuntimeError( 136 | f"File corrupted - MD5 hash of {file_path} does not match: " 137 | f"(expected md5:{md5}, got md5:{fmd5}), " 138 | "you can use download=True to re-download it." 139 | "\nPlease open an issue at https://github.com/jet-net/JetNet/issues/new " 140 | "if you believe this is an error." 141 | ) 142 | 143 | return file_path 144 | 145 | 146 | def getOrderedFeatures( 147 | data: ArrayLike, features: list[str], features_order: list[str] 148 | ) -> np.ndarray: 149 | """Returns data with features in the order specified by ``features``. 150 | 151 | Args: 152 | data (ArrayLike): input data 153 | features (List[str]): desired features in order 154 | features_order (List[str]): name and ordering of features in input data 155 | 156 | Returns: 157 | (np.ndarray): data with features in specified order 158 | """ 159 | 160 | if np.all(features == features_order): # check if already in order 161 | return data 162 | 163 | ret_data = [] 164 | for feat in features: 165 | assert ( 166 | feat in features_order 167 | ), f"`{feat}` feature does not exist in this dataset (available features: {features_order})" 168 | index = features_order.index(feat) 169 | ret_data.append(data[..., index, np.newaxis]) 170 | 171 | return np.concatenate(ret_data, axis=-1) 172 | 173 | 174 | def checkStrToList( 175 | *inputs: list[str | list[str] | set[str]], to_set: bool = False 176 | ) -> list[list[str]] | list[set[str]] | list: 177 | """Converts str inputs to a list or set""" 178 | ret = [] 179 | for inp in inputs: 180 | if isinstance(inp, str): 181 | inp = [inp] if not to_set else {inp} # noqa: PLW2901 182 | ret.append(inp) 183 | 184 | return ret if len(inputs) > 1 else ret[0] 185 | 186 | 187 | def checkListNotEmpty(*inputs: list[list]) -> list[bool]: 188 | """Checks that list inputs are not None or empty""" 189 | ret = [] 190 | for inp in inputs: 191 | ret.append(inp is not None and len(inp)) 192 | 193 | return ret if len(inputs) > 1 else ret[0] 194 | 195 | 196 | def firstNotNoneElement(*inputs: list[Any]) -> Any: 197 | """Returns the first element out of all inputs which isn't None""" 198 | for inp in inputs: 199 | if inp is not None: 200 | return inp 201 | 202 | return None 203 | 204 | 205 | def checkConvertElements(elem: str | list[str], valid_types: list[str], ntype: str = "element"): 206 | """Checks if elem(s) are valid and if needed converts into a list""" 207 | if elem != "all": 208 | elem = checkStrToList(elem, to_set=True) 209 | 210 | for j in elem: 211 | assert j in valid_types, f"{j} is not a valid {ntype}, must be one of {valid_types}" 212 | 213 | else: 214 | elem = valid_types 215 | 216 | return elem 217 | 218 | 219 | def getSplitting( 220 | length: int, split: str, splits: list[str], split_fraction: list[float] 221 | ) -> tuple[int, int]: 222 | """ 223 | Returns starting and ending index for splitting a dataset of length ``length`` according to 224 | the input ``split`` 
out of the total possible ``splits`` and a given ``split_fraction``. 225 | 226 | "all" is considered a special keyword to mean the entire dataset - it cannot be used to define a 227 | normal splitting, and if it is a possible splitting it must be the last entry in ``splits``. 228 | 229 | e.g. for ``length = 100``, ``split = "valid"``, ``splits = ["train", "valid", "test"]``, 230 | ``split_fraction = [0.7, 0.15, 0.15]`` 231 | 232 | This will return ``(70, 85)``. 233 | """ 234 | 235 | assert split in splits, f"{split} not a valid splitting, must be one of {splits}" 236 | 237 | if "all" in splits: 238 | if split == "all": 239 | return 0, length 240 | else: 241 | assert splits[-1] == "all", "'all' must be last entry in ``splits`` array" 242 | splits = splits[:-1] 243 | 244 | assert np.sum(split_fraction) <= 1.0, "sum of split fractions must be ≤ 1" 245 | 246 | split_index = splits.index(split) 247 | cuts = (np.cumsum(np.insert(split_fraction, 0, 0)) * length).astype(int) 248 | return cuts[split_index], cuts[split_index + 1] 249 | -------------------------------------------------------------------------------- /jetnet/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .gen_metrics import * # noqa: F403 4 | -------------------------------------------------------------------------------- /jetnet/evaluation/fpnd_resources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jet-net/JetNet/6ca12b7e63062ddc99bcdf48cb304303c64b5be3/jetnet/evaluation/fpnd_resources/__init__.py -------------------------------------------------------------------------------- /jetnet/evaluation/fpnd_resources/jetnet/30_particles/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jet-net/JetNet/6ca12b7e63062ddc99bcdf48cb304303c64b5be3/jetnet/evaluation/fpnd_resources/jetnet/30_particles/__init__.py -------------------------------------------------------------------------------- /jetnet/evaluation/fpnd_resources/jetnet/30_particles/g_mu.txt: -------------------------------------------------------------------------------- 1 | -4.371720850467681885e-01 2 | -1.805645376443862915e-01 3 | -3.922226838767528534e-03 4 | -1.144851818680763245e-01 5 | 4.805534780025482178e-01 6 | -5.648179054260253906e-01 7 | -7.620820403099060059e-02 8 | 1.456556618213653564e-01 9 | -1.398528963327407837e-01 10 | -4.191589727997779846e-02 11 | 1.344792842864990234e-01 12 | 3.165954053401947021e-01 13 | -6.360217928886413574e-02 14 | -7.133946567773818970e-02 15 | 2.720776386559009552e-02 16 | 5.035558715462684631e-02 17 | -1.499548792839050293e+00 18 | -1.038656115531921387e+00 19 | 3.316102921962738037e-02 20 | -2.380084991455078125e+00 21 | 6.171974539756774902e-01 22 | -3.522773385047912598e-01 23 | 2.404114902019500732e-01 24 | -8.551157265901565552e-02 25 | -8.465415835380554199e-01 26 | 4.785576462745666504e-01 27 | -5.424700379371643066e-01 28 | -1.162979125976562500e+00 29 | -1.495897173881530762e+00 30 | -1.475711166858673096e-01 31 | 4.989392161369323730e-01 32 | 1.061629951000213623e-01 33 | -1.887385398149490356e-01 34 | -1.746551275253295898e+00 35 | -4.468872249126434326e-01 36 | -8.725170791149139404e-02 37 | -7.502891868352890015e-02 38 | 2.594395875930786133e-01 39 | 1.840624660253524780e-01 40 | -7.945888042449951172e-01 41 | 3.160526156425476074e-01 42 | 
-3.048188984394073486e-01 43 | 3.557741343975067139e-01 44 | 8.169592916965484619e-02 45 | -6.985864639282226562e-01 46 | 4.886779189109802246e-01 47 | -6.139116287231445312e-01 48 | 3.562860190868377686e-01 49 | -4.766258299350738525e-01 50 | -4.784351587295532227e-01 51 | -2.442610217258334160e-03 52 | 2.727622985839843750e-01 53 | -7.171751856803894043e-01 54 | -2.463864833116531372e-01 55 | 4.092823565006256104e-01 56 | -1.571005582809448242e-01 57 | 7.710742354393005371e-01 58 | 3.773435950279235840e-01 59 | -5.394867062568664551e-01 60 | -7.830733060836791992e-01 61 | -1.510766297578811646e-01 62 | -4.123195409774780273e-01 63 | 2.969397008419036865e-01 64 | 3.770505264401435852e-02 65 | -5.519575476646423340e-01 66 | 4.071931540966033936e-01 67 | -6.472921967506408691e-01 68 | -2.840783260762691498e-02 69 | -2.346236258745193481e-01 70 | 8.663079738616943359e-01 71 | -4.527220427989959717e-01 72 | 3.369300439953804016e-02 73 | 4.723206162452697754e-01 74 | -1.098836421966552734e+00 75 | 7.643222808837890625e-02 76 | -4.397888779640197754e-01 77 | -1.009121298789978027e+00 78 | -3.323112428188323975e-01 79 | -2.327213138341903687e-01 80 | 5.621105805039405823e-02 81 | -2.819559276103973389e-01 82 | 9.416612237691879272e-02 83 | 3.509472608566284180e-01 84 | -4.308800697326660156e-01 85 | -1.064589470624923706e-01 86 | -1.205602049827575684e+00 87 | -3.481802046298980713e-01 88 | 3.726439476013183594e-01 89 | 2.048564255237579346e-01 90 | -1.626667217351496220e-03 91 | -4.506298303604125977e-01 92 | 9.711784869432449341e-02 93 | 6.800799816846847534e-02 94 | 2.331841140985488892e-01 95 | -8.724997639656066895e-01 96 | -8.136501312255859375e-01 97 | -1.316208124160766602e+00 98 | -2.303700000047683716e-01 99 | -1.516915440559387207e+00 100 | -6.604931950569152832e-01 101 | -7.996947765350341797e-01 102 | -5.047752335667610168e-02 103 | -1.314321517944335938e+00 104 | 3.070122003555297852e-01 105 | 6.434792303480207920e-04 106 | -4.783317446708679199e-01 107 | -1.012912318110466003e-01 108 | 5.209989547729492188e-01 109 | 7.858695387840270996e-01 110 | -2.878283262252807617e-01 111 | 1.279395520687103271e-01 112 | -1.370989680290222168e-01 113 | -3.968503475189208984e-01 114 | 3.680468499660491943e-01 115 | 2.024173550307750702e-02 116 | 3.463025093078613281e-01 117 | -3.479301929473876953e-01 118 | 4.415942430496215820e-01 119 | -4.106250405311584473e-01 120 | -7.002981901168823242e-01 121 | -9.529718160629272461e-01 122 | -7.855919599533081055e-01 123 | 7.998730242252349854e-02 124 | -9.796912670135498047e-01 125 | 2.727634012699127197e-01 126 | 3.319811224937438965e-01 127 | 1.116193085908889771e-01 128 | -1.460960060358047485e-01 129 | -8.218395113945007324e-01 130 | -7.186439633369445801e-02 131 | -2.346745878458023071e-01 132 | 2.542576380074024200e-02 133 | 1.816838085651397705e-01 134 | 5.876022577285766602e-01 135 | 3.162940144538879395e-01 136 | -7.620092481374740601e-02 137 | -4.904680550098419189e-01 138 | -5.839932560920715332e-01 139 | -5.164038538932800293e-01 140 | -1.462005257606506348e+00 141 | -4.774998426437377930e-01 142 | 4.595757275819778442e-02 143 | -5.097575187683105469e-01 144 | -9.771474003791809082e-01 145 | -1.881053149700164795e-01 146 | 5.001438856124877930e-01 147 | -6.051782369613647461e-01 148 | 3.425657004117965698e-02 149 | -1.128476783633232117e-01 150 | -1.384837031364440918e-01 151 | -1.237563192844390869e-01 152 | -6.253792047500610352e-01 153 | 3.204698860645294189e-02 154 | -9.198061823844909668e-01 155 | -1.271724343299865723e+00 156 | 
2.197372317314147949e-01 157 | -6.598924994468688965e-01 158 | 1.089294701814651489e-01 159 | -5.932117700576782227e-01 160 | -2.714620828628540039e-01 161 | 2.297144532203674316e-01 162 | -3.174534142017364502e-01 163 | -1.230260252952575684e+00 164 | -6.159781217575073242e-01 165 | 4.722552597522735596e-01 166 | -1.569443941116333008e-01 167 | 4.642578959465026855e-02 168 | -2.780922129750251770e-02 169 | -2.199016809463500977e-01 170 | -1.163347885012626648e-01 171 | 3.830692498013377190e-03 172 | -6.645449995994567871e-01 173 | -3.374440968036651611e-02 174 | -1.814230531454086304e-02 175 | -2.982152104377746582e-01 176 | 1.432804018259048462e-01 177 | -4.904610812664031982e-01 178 | 3.358787894248962402e-01 179 | -1.696789711713790894e-01 180 | -3.305447846651077271e-02 181 | -2.128667831420898438e-01 182 | 4.716456234455108643e-01 183 | -3.524016588926315308e-02 184 | -5.457623004913330078e-01 185 | -1.423557519912719727e+00 186 | 5.526475980877876282e-02 187 | -4.537776112556457520e-01 188 | -1.297913551330566406e+00 189 | 1.807704418897628784e-01 190 | 1.331230700016021729e-01 191 | 1.931185424327850342e-01 192 | -4.061575829982757568e-01 193 | 2.985567450523376465e-01 194 | -4.399412572383880615e-01 195 | -1.477907299995422363e-01 196 | -5.870941281318664551e-02 197 | -9.247033596038818359e-01 198 | 2.165328897535800934e-02 199 | 3.458697348833084106e-02 200 | -2.492623478174209595e-01 201 | 3.014889955520629883e-01 202 | 3.760405778884887695e-01 203 | 4.254507124423980713e-01 204 | 4.157692492008209229e-01 205 | -3.962422907352447510e-01 206 | -3.981198370456695557e-01 207 | -4.856810271739959717e-01 208 | 8.768606185913085938e-02 209 | -5.529445409774780273e-01 210 | 3.885187506675720215e-01 211 | -3.097257912158966064e-01 212 | -4.833705723285675049e-01 213 | -1.072345733642578125e+00 214 | 6.858038902282714844e-01 215 | -4.481486082077026367e-01 216 | 3.248192667961120605e-01 217 | 4.484864473342895508e-01 218 | -2.977046072483062744e-01 219 | -9.272873997688293457e-01 220 | 2.930595576763153076e-01 221 | -8.383497595787048340e-01 222 | 4.459198117256164551e-01 223 | -6.182915568351745605e-01 224 | -5.606850981712341309e-01 225 | -1.579932272434234619e-01 226 | 4.240046143531799316e-01 227 | 2.476030737161636353e-01 228 | -3.763284683227539062e-01 229 | -6.023084521293640137e-01 230 | 5.583781003952026367e-01 231 | -7.434175610542297363e-01 232 | -2.303280234336853027e-01 233 | -2.957838773727416992e-01 234 | -1.138446688652038574e+00 235 | 2.429798394441604614e-01 236 | -2.334023416042327881e-01 237 | -6.806474924087524414e-02 238 | 4.592995047569274902e-01 239 | -5.372208356857299805e-01 240 | 2.400450408458709717e-01 241 | -3.004089891910552979e-01 242 | -4.917985498905181885e-01 243 | -4.610557481646537781e-02 244 | 1.275183558464050293e-01 245 | -4.134217202663421631e-01 246 | -4.745686054229736328e-01 247 | -1.494868278503417969e+00 248 | -1.214967370033264160e+00 249 | 2.700903117656707764e-01 250 | 3.530358895659446716e-02 251 | -1.260660439729690552e-01 252 | -1.801397442817687988e+00 253 | 1.643899232149124146e-01 254 | -1.548262219876050949e-02 255 | -3.474367558956146240e-01 256 | 3.958128392696380615e-01 257 | -------------------------------------------------------------------------------- /jetnet/evaluation/fpnd_resources/jetnet/30_particles/pnet_state_dict.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jet-net/JetNet/6ca12b7e63062ddc99bcdf48cb304303c64b5be3/jetnet/evaluation/fpnd_resources/jetnet/30_particles/pnet_state_dict.pt -------------------------------------------------------------------------------- /jetnet/evaluation/fpnd_resources/jetnet/30_particles/q_mu.txt: -------------------------------------------------------------------------------- 1 | -3.448236286640167236e-01 2 | -2.860240936279296875e-01 3 | 4.316971600055694580e-01 4 | 2.114026099443435669e-01 5 | 4.048407077789306641e-01 6 | -7.472324371337890625e-01 7 | -3.415537476539611816e-01 8 | 6.920142769813537598e-01 9 | -4.207466840744018555e-01 10 | 2.682481706142425537e-01 11 | 4.023147523403167725e-01 12 | 5.226641893386840820e-01 13 | 2.767414152622222900e-01 14 | -1.369227468967437744e-01 15 | -6.378980278968811035e-01 16 | 1.335402876138687134e-01 17 | -2.277965068817138672e+00 18 | -2.011763334274291992e+00 19 | 3.264744281768798828e-01 20 | -2.850445270538330078e+00 21 | 8.144413828849792480e-01 22 | -4.054320752620697021e-01 23 | -1.408622264862060547e-01 24 | 4.969568252563476562e-01 25 | -1.043303251266479492e+00 26 | 4.936068356037139893e-01 27 | -1.391484260559082031e+00 28 | -1.709470987319946289e+00 29 | -2.646713495254516602e+00 30 | 4.055457413196563721e-01 31 | 8.257797360420227051e-01 32 | -4.636558890342712402e-01 33 | -5.166595578193664551e-01 34 | -2.647644042968750000e+00 35 | 2.921877615153789520e-02 36 | -3.723706603050231934e-01 37 | 3.424092233180999756e-01 38 | 2.364711835980415344e-02 39 | -1.614864319562911987e-01 40 | -9.467721581459045410e-01 41 | -1.004892736673355103e-01 42 | -1.122161149978637695e+00 43 | 5.080600976943969727e-01 44 | 2.791655659675598145e-01 45 | -1.217724561691284180e+00 46 | -1.302572488784790039e-01 47 | -1.149832367897033691e+00 48 | 5.986269935965538025e-02 49 | -5.742069482803344727e-01 50 | -2.711922526359558105e-01 51 | 5.547340512275695801e-01 52 | 4.175775051116943359e-01 53 | -8.171770572662353516e-01 54 | -6.083912216126918793e-03 55 | 1.255086157470941544e-02 56 | 4.354733973741531372e-02 57 | 7.781461477279663086e-01 58 | -1.505583226680755615e-01 59 | -7.688400149345397949e-01 60 | -1.047348856925964355e+00 61 | 2.618552744388580322e-01 62 | 3.275084123015403748e-02 63 | -2.195692509412765503e-01 64 | 4.734301567077636719e-01 65 | -5.923132300376892090e-01 66 | -1.442272216081619263e-01 67 | -6.929616332054138184e-01 68 | -1.877778954803943634e-02 69 | -8.928155153989791870e-02 70 | 1.043061614036560059e+00 71 | -4.462654590606689453e-01 72 | 4.177887439727783203e-01 73 | 5.731233954429626465e-01 74 | -1.660756826400756836e+00 75 | 6.736028194427490234e-01 76 | -4.943993687629699707e-01 77 | -1.107497215270996094e+00 78 | -8.977988362312316895e-01 79 | -4.050123393535614014e-01 80 | -3.268640935420989990e-01 81 | 1.501356214284896851e-01 82 | 5.005853176116943359e-01 83 | -1.887347102165222168e-01 84 | -5.438641905784606934e-01 85 | 2.663469910621643066e-01 86 | -1.119373679161071777e+00 87 | -4.855801463127136230e-01 88 | 5.534688830375671387e-01 89 | 3.466225862503051758e-01 90 | 5.360110402107238770e-01 91 | -5.281301736831665039e-01 92 | 3.138095140457153320e-01 93 | 4.714883565902709961e-01 94 | -2.012610286474227905e-01 95 | -1.000772237777709961e+00 96 | -1.397985339164733887e+00 97 | -2.284084796905517578e+00 98 | -1.281093358993530273e+00 99 | -1.885194659233093262e+00 100 | -1.144259095191955566e+00 101 | -1.795475363731384277e+00 102 | 3.764282166957855225e-01 103 | -1.919138312339782715e+00 104 | 
4.662797152996063232e-01 105 | 5.144885778427124023e-01 106 | -6.312644481658935547e-01 107 | 4.367765188217163086e-01 108 | 9.790994971990585327e-02 109 | 9.397537708282470703e-01 110 | -1.129408121109008789e+00 111 | 3.973918855190277100e-01 112 | -8.304964900016784668e-01 113 | -4.584658741950988770e-01 114 | 3.711727261543273926e-02 115 | -2.562644481658935547e-01 116 | -1.301011145114898682e-01 117 | -3.918544948101043701e-01 118 | 5.834288001060485840e-01 119 | -5.969913005828857422e-01 120 | -8.957766294479370117e-01 121 | -1.943243026733398438e+00 122 | -8.552753925323486328e-01 123 | 4.663594961166381836e-01 124 | -2.040998935699462891e+00 125 | -4.033117592334747314e-01 126 | 7.086278796195983887e-01 127 | -3.093762397766113281e-01 128 | -2.304511517286300659e-01 129 | -1.804449796676635742e+00 130 | -1.001575961709022522e-01 131 | -1.018240571022033691e+00 132 | -4.541819989681243896e-01 133 | 6.651262938976287842e-02 134 | 7.936475276947021484e-01 135 | 7.414385080337524414e-01 136 | 4.400554299354553223e-01 137 | -4.518392980098724365e-01 138 | -6.510295271873474121e-01 139 | -9.125501513481140137e-01 140 | -1.339298129081726074e+00 141 | -6.232970356941223145e-01 142 | 3.483478426933288574e-01 143 | -1.253634095191955566e+00 144 | -1.518315911293029785e+00 145 | -3.541562855243682861e-01 146 | -1.449910998344421387e-01 147 | -1.325427174568176270e+00 148 | 4.116210937500000000e-01 149 | 5.194989442825317383e-01 150 | -6.833115126937627792e-03 151 | -7.372644543647766113e-01 152 | -1.497585415840148926e+00 153 | 4.709804356098175049e-01 154 | -3.448056578636169434e-01 155 | -2.126025199890136719e+00 156 | 2.406137287616729736e-01 157 | -8.993841409683227539e-01 158 | 4.108309447765350342e-01 159 | -5.958597064018249512e-01 160 | -1.249573707580566406e+00 161 | 3.303063213825225830e-01 162 | -1.110633492469787598e+00 163 | -1.622221231460571289e+00 164 | -1.052463293075561523e+00 165 | 4.987423121929168701e-01 166 | -5.464392900466918945e-01 167 | -2.490926831960678101e-01 168 | 5.335001349449157715e-01 169 | -9.762493893504142761e-03 170 | 3.653861284255981445e-01 171 | 4.738501310348510742e-01 172 | -7.092235684394836426e-01 173 | 2.252282649278640747e-01 174 | 5.252654552459716797e-01 175 | -1.184747219085693359e+00 176 | 3.891493678092956543e-01 177 | -5.309808850288391113e-01 178 | 1.428944170475006104e-01 179 | 6.151017546653747559e-02 180 | -2.753069996833801270e-01 181 | -1.177823066711425781e+00 182 | 6.377012729644775391e-01 183 | 2.930057942867279053e-01 184 | -5.294016599655151367e-01 185 | -1.699206113815307617e+00 186 | -1.184429600834846497e-01 187 | -6.647282838821411133e-01 188 | -1.108105897903442383e+00 189 | 5.042155385017395020e-01 190 | 6.475363969802856445e-01 191 | -1.050402671098709106e-01 192 | -4.637343883514404297e-01 193 | -8.956113457679748535e-02 194 | -6.456416249275207520e-01 195 | -5.638803243637084961e-01 196 | -1.008124828338623047e+00 197 | -1.837265491485595703e+00 198 | 6.154600903391838074e-02 199 | 3.397713601589202881e-01 200 | -5.247316956520080566e-01 201 | -8.771530538797378540e-02 202 | 6.683788299560546875e-01 203 | 4.114940166473388672e-01 204 | -9.826095402240753174e-02 205 | -4.504830837249755859e-01 206 | 3.903784602880477905e-02 207 | -3.800178766250610352e-01 208 | 5.255084633827209473e-01 209 | -1.174760103225708008e+00 210 | 2.492522448301315308e-02 211 | -1.459077447652816772e-01 212 | -7.468535900115966797e-01 213 | -9.900581836700439453e-01 214 | 1.085715532302856445e+00 215 | -4.173258543014526367e-01 216 | 
5.078753232955932617e-01 217 | 2.541098184883594513e-02 218 | -3.582891821861267090e-01 219 | -1.078568816184997559e+00 220 | -2.991097569465637207e-01 221 | -1.669080018997192383e+00 222 | -2.048849463462829590e-01 223 | -7.773264646530151367e-01 224 | -1.238261461257934570e+00 225 | 5.451040863990783691e-01 226 | 5.741064548492431641e-01 227 | -2.636810950934886932e-02 228 | -1.284635365009307861e-01 229 | -1.372915148735046387e+00 230 | -1.130890399217605591e-01 231 | -1.659625411033630371e+00 232 | -1.276493966579437256e-01 233 | -4.425036609172821045e-01 234 | -1.535377025604248047e+00 235 | 6.157888174057006836e-01 236 | -1.518803387880325317e-01 237 | -1.347669214010238647e-01 238 | 3.831249848008155823e-02 239 | -5.537612438201904297e-01 240 | 4.683212041854858398e-01 241 | -1.151367664337158203e+00 242 | -6.045078039169311523e-01 243 | 5.217700600624084473e-01 244 | 1.745016314089298248e-02 245 | -1.395949959754943848e+00 246 | -5.257391333580017090e-01 247 | -2.420942068099975586e+00 248 | -1.966244816780090332e+00 249 | -7.832395285367965698e-02 250 | 5.593221187591552734e-01 251 | -3.563382029533386230e-01 252 | -2.411356210708618164e+00 253 | -3.674933910369873047e-01 254 | 4.713683724403381348e-01 255 | -5.152627825736999512e-01 256 | 5.340193510055541992e-01 257 | -------------------------------------------------------------------------------- /jetnet/evaluation/fpnd_resources/jetnet/30_particles/t_mu.txt: -------------------------------------------------------------------------------- 1 | -1.919607162475585938e+00 2 | 4.656288623809814453e-01 3 | 3.143086433410644531e-01 4 | -6.669927835464477539e-01 5 | -1.724123507738113403e-01 6 | -4.384452700614929199e-01 7 | 7.996671199798583984e-01 8 | -4.601585566997528076e-01 9 | 7.665176391601562500e-01 10 | 2.281616777181625366e-01 11 | -3.430414199829101562e-01 12 | 3.169580996036529541e-01 13 | -2.049736529588699341e-01 14 | 6.503489613533020020e-01 15 | 1.485682129859924316e+00 16 | -3.553236424922943115e-01 17 | 5.169294774532318115e-02 18 | 4.028459787368774414e-01 19 | -6.149333715438842773e-01 20 | -1.170877099037170410e+00 21 | -3.112658560276031494e-01 22 | -4.972251951694488525e-01 23 | -1.252682805061340332e-01 24 | -2.356266081333160400e-01 25 | -1.190475344657897949e+00 26 | 9.617172479629516602e-01 27 | 9.143186807632446289e-01 28 | -9.670587778091430664e-01 29 | 1.091158017516136169e-01 30 | -8.095042109489440918e-01 31 | -4.164728522300720215e-01 32 | -1.218203783035278320e+00 33 | 5.765347480773925781e-01 34 | -8.953229188919067383e-01 35 | -1.472953081130981445e+00 36 | 5.623524188995361328e-01 37 | 4.510621130466461182e-01 38 | 1.328066736459732056e-01 39 | 8.810572624206542969e-01 40 | -1.149511575698852539e+00 41 | -3.066723942756652832e-01 42 | 9.098187685012817383e-01 43 | 9.248798489570617676e-01 44 | 2.989077568054199219e-01 45 | 2.631780803203582764e-01 46 | 1.020086407661437988e+00 47 | 2.466192543506622314e-01 48 | 6.414339691400527954e-02 49 | -7.393011450767517090e-01 50 | -1.946285486221313477e+00 51 | -1.087762191891670227e-01 52 | 3.782342970371246338e-01 53 | -1.212731480598449707e+00 54 | -1.611511856317520142e-01 55 | -7.597271353006362915e-02 56 | -6.237940192222595215e-01 57 | -1.939526796340942383e-01 58 | 9.001361131668090820e-01 59 | -6.185238957405090332e-01 60 | 1.288689821958541870e-01 61 | -1.179372072219848633e+00 62 | -1.679799199104309082e+00 63 | -2.989626228809356689e-01 64 | -1.893358826637268066e-01 65 | -6.635152101516723633e-01 66 | 5.397852063179016113e-01 67 | 
-9.953084588050842285e-01 68 | -1.155411243438720703e+00 69 | -6.375781893730163574e-01 70 | 7.825260609388351440e-02 71 | -8.238940834999084473e-01 72 | -5.442811846733093262e-01 73 | -8.276981115341186523e-01 74 | -9.539427161216735840e-01 75 | -4.123525023460388184e-01 76 | -7.059239149093627930e-01 77 | -1.105544090270996094e+00 78 | 1.027651309967041016e+00 79 | -2.130596041679382324e-01 80 | -5.067097544670104980e-01 81 | -1.042125821113586426e+00 82 | -8.792169690132141113e-01 83 | -2.352587878704071045e-01 84 | -7.521837353706359863e-01 85 | -4.220880568027496338e-01 86 | -1.804182648658752441e+00 87 | 6.798790693283081055e-01 88 | -2.831193506717681885e-01 89 | 5.158485770225524902e-01 90 | -2.018180340528488159e-01 91 | -6.817542314529418945e-01 92 | -9.894375205039978027e-01 93 | -4.651043713092803955e-01 94 | -8.494793623685836792e-02 95 | -9.449350237846374512e-01 96 | -3.278186321258544922e-01 97 | 8.577052503824234009e-02 98 | 1.577664852142333984e+00 99 | -1.397526144981384277e+00 100 | -1.868851333856582642e-01 101 | 8.533503413200378418e-01 102 | 4.248158633708953857e-01 103 | -9.258887767791748047e-01 104 | -2.368453741073608398e-01 105 | -9.010270237922668457e-02 106 | -6.119343638420104980e-01 107 | -5.891281366348266602e-01 108 | 7.191983461380004883e-01 109 | 2.955370582640171051e-02 110 | 1.206668257713317871e+00 111 | 6.946456432342529297e-01 112 | 1.118457555770874023e+00 113 | -5.398175120353698730e-01 114 | 2.553229033946990967e-01 115 | 7.950652837753295898e-01 116 | -2.915655374526977539e-01 117 | -3.902739882469177246e-01 118 | 4.938508272171020508e-01 119 | -4.755576252937316895e-01 120 | -8.193768262863159180e-01 121 | 7.926660180091857910e-01 122 | -8.476743102073669434e-01 123 | -8.025608062744140625e-01 124 | 6.463884711265563965e-01 125 | -2.626391351222991943e-01 126 | -2.791174948215484619e-01 127 | 3.023090362548828125e-01 128 | 2.112307399511337280e-01 129 | 6.889151334762573242e-01 130 | 7.472928762435913086e-01 131 | 1.219888925552368164e+00 132 | -4.926699697971343994e-01 133 | 2.493357807397842407e-01 134 | -2.325252741575241089e-01 135 | -4.283604025840759277e-01 136 | 1.798242777585983276e-01 137 | -6.087763905525207520e-01 138 | -7.567882537841796875e-01 139 | 4.664313793182373047e-01 140 | -2.319591522216796875e+00 141 | -5.358254909515380859e-01 142 | -6.826263070106506348e-01 143 | 5.264142155647277832e-01 144 | 2.555173933506011963e-01 145 | 3.007619082927703857e-01 146 | 5.843247175216674805e-01 147 | 3.486472368240356445e-01 148 | 4.419296979904174805e-01 149 | -2.031341344118118286e-01 150 | -9.328261613845825195e-01 151 | 1.324127912521362305e+00 152 | 7.321987301111221313e-02 153 | -5.341126322746276855e-01 154 | -4.421860277652740479e-01 155 | -8.315241336822509766e-01 156 | 1.418796777725219727e-01 157 | -9.182305335998535156e-01 158 | 7.004160434007644653e-02 159 | -7.847358584403991699e-01 160 | 1.134046673774719238e+00 161 | -3.076154589653015137e-01 162 | 1.223272085189819336e+00 163 | -1.078833222389221191e+00 164 | -1.190787330269813538e-01 165 | -4.840435087680816650e-02 166 | -1.128192067146301270e+00 167 | 8.865109086036682129e-01 168 | -2.088767588138580322e-01 169 | -7.078149914741516113e-01 170 | -1.131381154060363770e+00 171 | -5.183210968971252441e-01 172 | -1.094166755676269531e+00 173 | -8.271268010139465332e-01 174 | -8.721852302551269531e-01 175 | 1.117804765701293945e+00 176 | -8.646777868270874023e-01 177 | -4.620472192764282227e-01 178 | -5.398727655410766602e-01 179 | 9.732392430305480957e-02 180 | 
3.556648492813110352e-01 181 | 1.331714510917663574e+00 182 | -9.940794855356216431e-02 183 | 3.306509256362915039e-01 184 | -6.415281891822814941e-01 185 | -1.285492181777954102e+00 186 | -1.814385354518890381e-01 187 | 6.028547883033752441e-01 188 | -1.858239889144897461e+00 189 | -5.734735727310180664e-01 190 | -3.941416144371032715e-01 191 | 8.563172817230224609e-01 192 | -5.027634501457214355e-01 193 | -1.346674514934420586e-03 194 | -5.586805343627929688e-01 195 | 7.432445287704467773e-01 196 | 1.261430501937866211e+00 197 | -1.897506862878799438e-01 198 | 6.696163415908813477e-01 199 | -7.139615416526794434e-01 200 | -6.934331059455871582e-01 201 | -2.526594996452331543e-01 202 | -2.403129041194915771e-01 203 | -1.007690057158470154e-01 204 | 3.904160857200622559e-01 205 | -4.890286028385162354e-01 206 | -1.927569985389709473e+00 207 | -5.368205308914184570e-01 208 | 4.976361989974975586e-01 209 | 2.046933919191360474e-01 210 | 1.820291280746459961e-01 211 | -3.095429837703704834e-01 212 | 6.006532311439514160e-01 213 | -1.681893229484558105e+00 214 | -1.675780415534973145e-01 215 | -4.785014986991882324e-01 216 | -3.858889341354370117e-01 217 | 4.428193867206573486e-01 218 | -4.082077741622924805e-01 219 | -8.311929702758789062e-01 220 | -2.440002560615539551e-01 221 | 7.815997600555419922e-01 222 | 7.201396822929382324e-01 223 | -6.865431070327758789e-01 224 | 9.569305181503295898e-01 225 | -1.409661889076232910e+00 226 | -5.221148729324340820e-01 227 | 1.182062178850173950e-01 228 | -9.755721092224121094e-01 229 | 9.030334651470184326e-02 230 | 1.038209319114685059e+00 231 | 9.354519248008728027e-01 232 | 9.493125230073928833e-02 233 | -1.515232026576995850e-01 234 | -1.118843078613281250e+00 235 | -4.374151527881622314e-01 236 | -8.177415728569030762e-01 237 | 8.400709629058837891e-01 238 | -1.703475974500179291e-02 239 | -5.031045675277709961e-01 240 | -3.185825049877166748e-01 241 | 5.038411021232604980e-01 242 | -5.699746012687683105e-01 243 | -2.212360948324203491e-01 244 | 4.871013164520263672e-01 245 | 9.771687984466552734e-01 246 | -6.658237576484680176e-01 247 | -7.770020365715026855e-01 248 | 1.578502506017684937e-01 249 | -7.723373174667358398e-02 250 | -5.432156920433044434e-01 251 | -4.745414480566978455e-02 252 | -1.289033174514770508e+00 253 | -5.989030003547668457e-01 254 | 4.385402798652648926e-01 255 | 3.653763532638549805e-01 256 | -1.277545690536499023e-01 257 | -------------------------------------------------------------------------------- /jetnet/evaluation/fpnd_resources/jetnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jet-net/JetNet/6ca12b7e63062ddc99bcdf48cb304303c64b5be3/jetnet/evaluation/fpnd_resources/jetnet/__init__.py -------------------------------------------------------------------------------- /jetnet/evaluation/particlenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch_cluster import knn_graph 8 | from torch_geometric.nn import EdgeConv, global_mean_pool 9 | 10 | 11 | class _ParticleNetEdgeNet(nn.Module): 12 | def __init__(self, in_size, layer_size): 13 | super().__init__() 14 | 15 | layers = [] 16 | 17 | layers.append(nn.Linear(in_size * 2, layer_size)) 18 | layers.append(nn.BatchNorm1d(layer_size)) 19 | layers.append(nn.ReLU()) 20 | 21 | for _ in range(2): 22 | 
layers.append(nn.Linear(layer_size, layer_size)) 23 | layers.append(nn.BatchNorm1d(layer_size)) 24 | layers.append(nn.ReLU()) 25 | 26 | self.model = nn.Sequential(*layers) 27 | 28 | def forward(self, x): 29 | return self.model(x) 30 | 31 | def __repr__(self): 32 | return f"{self.__class__.__name__}(nn={self.model})" 33 | 34 | 35 | class _ParticleNet(nn.Module): 36 | def __init__(self, num_hits, node_feat_size, num_classes=5): 37 | super().__init__() 38 | self.num_hits = num_hits 39 | self.node_feat_size = node_feat_size 40 | self.num_classes = num_classes 41 | 42 | self.k = 16 43 | self.num_edge_convs = 3 44 | self.kernel_sizes = [64, 128, 256] 45 | self.fc_size = 256 46 | self.dropout = 0.1 47 | 48 | self.edge_nets = nn.ModuleList() 49 | self.edge_convs = nn.ModuleList() 50 | 51 | self.kernel_sizes.insert(0, self.node_feat_size) 52 | self.output_sizes = np.cumsum(self.kernel_sizes) 53 | 54 | self.edge_nets.append(_ParticleNetEdgeNet(self.node_feat_size, self.kernel_sizes[1])) 55 | self.edge_convs.append(EdgeConv(self.edge_nets[-1], aggr="mean")) 56 | 57 | for i in range(1, self.num_edge_convs): 58 | # adding kernel sizes because of skip connections 59 | self.edge_nets.append( 60 | _ParticleNetEdgeNet(self.output_sizes[i], self.kernel_sizes[i + 1]) 61 | ) 62 | self.edge_convs.append(EdgeConv(self.edge_nets[-1], aggr="mean")) 63 | 64 | self.fc1 = nn.Sequential(nn.Linear(self.output_sizes[-1], self.fc_size)) 65 | 66 | self.dropout_layer = nn.Dropout(p=self.dropout) 67 | 68 | self.fc2 = nn.Linear(self.fc_size, self.num_classes) 69 | 70 | def forward(self, x, ret_activations=False, relu_activations=False): 71 | batch_size = x.size(0) 72 | x = x.reshape(batch_size * self.num_hits, self.node_feat_size) 73 | zeros = torch.zeros(batch_size * self.num_hits, dtype=int).to(x.device) 74 | zeros[torch.arange(batch_size) * self.num_hits] = 1 75 | batch = torch.cumsum(zeros, 0) - 1 76 | 77 | for i in range(self.num_edge_convs): 78 | # using only angular coords for knn in first edgeconv block 79 | edge_index = ( 80 | knn_graph(x[:, :2], self.k, batch) if i == 0 else knn_graph(x, self.k, batch) 81 | ) 82 | x = torch.cat( 83 | (self.edge_convs[i](x, edge_index), x), dim=1 84 | ) # concatenating with original features i.e. skip connection 85 | 86 | x = global_mean_pool(x, batch) 87 | x = self.fc1(x) 88 | 89 | if ret_activations: 90 | if relu_activations: 91 | return F.relu(x) 92 | else: 93 | return x # for Frechet ParticleNet Distance 94 | else: 95 | x = self.dropout_layer(F.relu(x)) 96 | 97 | return self.fc2(x) # no softmax because pytorch cross entropy loss includes softmax 98 | 99 | # TODO: !! 
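# A minimal sketch of how the pre-classifier activations above could be extracted,
# e.g. for the Frechet ParticleNet Distance (illustrative only: the checkpoint path,
# input shapes, and random inputs are assumptions):
#
#   model = _ParticleNet(num_hits=30, node_feat_size=3)
#   model.load_state_dict(torch.load("pnet_state_dict.pt"))
#   model.eval()
#   jets = torch.rand(128, 30, 3)  # [num_jets, num_particles, features]
#   with torch.no_grad():
#       activations = model(jets, ret_activations=True)  # shape [128, 256]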
100 | # def __repr__(self): 101 | # return "" 102 | -------------------------------------------------------------------------------- /jetnet/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .losses import * # noqa: F403 4 | -------------------------------------------------------------------------------- /jetnet/losses/losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor, nn 6 | 7 | 8 | class EMDLoss(nn.Module): 9 | """ 10 | Calculates the energy mover's distance between two batches of jets differentiably, 11 | posed as a convex optimization problem and solved either as a linear program via the 12 | ``cvxpy`` library or as a quadratic program via the ``qpth`` library. 13 | ``cvxpy`` is marginally more accurate, but ``qpth`` is significantly faster, so the 14 | method defaults to ``qpth``. 15 | 16 | **JetNet must be installed with the extra option** ``pip install jetnet[emdloss]`` 17 | **to use this.** 18 | 19 | *Note: PyTorch <= 1.9 has a bug which will cause this to fail for >= 32 particles.* 20 | *This PR should fix this from 1.10 onwards* https://github.com/pytorch/pytorch/pull/61815. 21 | 22 | Args: 23 | method (str): 'cvxpy' or 'qpth'. Defaults to 'qpth'. 24 | num_particles (int): number of particles per jet 25 | - only needs to be specified if method is 'cvxpy'. 26 | qpth_form (str): 'L2' or 'QP'. Defaults to 'L2'. 27 | qpth_l2_strength (float): regularization parameter for 'L2' qp form. 28 | Defaults to 0.0001. 29 | device (str): 'cpu' or 'cuda'. Defaults to 'cpu'. 30 | 31 | """ 32 | 33 | def __init__( 34 | self, 35 | method: str = "qpth", 36 | num_particles: int | None = None, 37 | qpth_form: str = "L2", 38 | qpth_l2_strength: float = 0.0001, 39 | device: str = "cpu", 40 | ): 41 | super().__init__() 42 | 43 | if method == "qpth": 44 | try: 45 | global qpth # noqa: PLW0603 46 | qpth = __import__("qpth", globals(), locals()) 47 | except ImportError: 48 | print( 49 | "QPTH needs to be installed separately to use this method " 50 | + "- try pip install jetnet[emdloss]" 51 | ) 52 | raise 53 | else: 54 | try: 55 | global cp, cvxpylayers # noqa: PLW0603 56 | cp = __import__("cvxpy", globals(), locals()) 57 | cvxpylayers = __import__("cvxpylayers", globals(), locals()) 58 | except ImportError: 59 | print( 60 | "cvxpy needs to be installed separately to use this method " 61 | + "- try pip install jetnet[emdloss]" 62 | ) 63 | raise 64 | 65 | assert method == "qpth" or method == "cvxpy", "invalid method type" 66 | assert method != "cvxpy" or ( 67 | num_particles is not None and num_particles > 0 68 | ), "num_particles must be specified to use 'cvxpy' method" 69 | assert qpth_form == "L2" or qpth_form == "QP", "invalid qpth form" 70 | assert device == "cpu" or device == "cuda", "invalid device type" 71 | 72 | self.num_particles = num_particles 73 | self.method = method 74 | if method == "qpth": 75 | self.form = qpth_form 76 | self.l2_strength = qpth_l2_strength 77 | self.device = device 78 | 79 | if method == "cvxpy": 80 | x = cp.Variable(num_particles * num_particles) # flows 81 | c = cp.Parameter(num_particles * num_particles) # costs 82 | w = cp.Parameter(num_particles + num_particles) # weights 83 | Emin = cp.Parameter(1) # min energy out of the two jets 84 | 85 | g1 = np.zeros((num_particles, num_particles * num_particles)) 86 | for
i in range(num_particles): 87 | g1[i, i * num_particles : (i + 1) * num_particles] = 1 88 | g2 = np.concatenate([np.eye(num_particles) for i in range(num_particles)], axis=1) 89 | g = np.concatenate((g1, g2), axis=0) 90 | 91 | constraints = [x >= 0, g @ x <= w, cp.sum(x) == Emin] 92 | objective = cp.Minimize(c.T @ x) 93 | problem = cp.Problem(objective, constraints) 94 | 95 | self.cvxpylayer = cvxpylayers.torch.CvxpyLayer( 96 | problem, parameters=[c, w, Emin], variables=[x] 97 | ).to(device) 98 | 99 | def _emd_inference_qpth( 100 | self, distance_matrix: Tensor, weight1: Tensor, weight2: Tensor 101 | ) -> tuple[Tensor, Tensor]: 102 | """ 103 | Using the QP solver QPTH to get EMDs (LP problem), adapted from 104 | https://github.com/icoz69/DeepEMD/blob/master/Models/models/emd_utils.py. 105 | One can transform the LP problem to QP, or suppress the QP term by multiplying 106 | it by a small value, i.e. ``l2_strength``. 107 | 108 | Args: 109 | distance_matrix (Tensor): nbatch * element_number * element_number. 110 | weight1 (Tensor): nbatch * weight_number. 111 | weight2 (Tensor): nbatch * weight_number. 112 | 113 | Returns: 114 | EMD distance: nbatch * 1. 115 | flow: nbatch * weight_number * weight_number. 116 | """ 117 | nbatch = distance_matrix.shape[0] 118 | nelement_distmatrix = distance_matrix.shape[1] * distance_matrix.shape[2] 119 | nelement_weight1 = weight1.shape[1] 120 | nelement_weight2 = weight2.shape[1] 121 | 122 | # reshape dist matrix to (nbatch, 1, n1 * n2) 123 | Q_1 = distance_matrix.view(-1, 1, nelement_distmatrix).double() 124 | 125 | if ( 126 | self.form == "QP" 127 | ): # converting to QP - after testing, L2 reg performs marginally better than QP 128 | # version: QTQ 129 | Q = torch.bmm(Q_1.transpose(2, 1), Q_1).double() + 1e-4 * torch.eye( 130 | nelement_distmatrix 131 | ).double().unsqueeze(0).repeat( 132 | nbatch, 1, 1 133 | ) # 0.00001 * 134 | p = torch.zeros(nbatch, nelement_distmatrix).double().to(self.device) 135 | elif self.form == "L2": # regularizing a trivial Q term with l2_strength 136 | # version: regularizer 137 | Q = ( 138 | (self.l2_strength * torch.eye(nelement_distmatrix).double()) 139 | .unsqueeze(0) 140 | .repeat(nbatch, 1, 1) 141 | .to(self.device) 142 | ) 143 | p = distance_matrix.view(nbatch, nelement_distmatrix).double() 144 | else: 145 | raise ValueError("Unknown form") 146 | 147 | # h = [0 ...
0 w1 w2] 148 | h_1 = torch.zeros(nbatch, nelement_distmatrix).double().to(self.device) 149 | h_2 = torch.cat([weight1, weight2], 1).double() 150 | h = torch.cat((h_1, h_2), 1) 151 | 152 | G_1 = ( 153 | -torch.eye(nelement_distmatrix) 154 | .double() 155 | .unsqueeze(0) 156 | .repeat(nbatch, 1, 1) 157 | .to(self.device) 158 | ) 159 | G_2 = ( 160 | torch.zeros([nbatch, nelement_weight1 + nelement_weight2, nelement_distmatrix]) 161 | .double() 162 | .to(self.device) 163 | ) 164 | # sum_j(xij) = si 165 | for i in range(nelement_weight1): 166 | G_2[:, i, nelement_weight2 * i : nelement_weight2 * (i + 1)] = 1 167 | # sum_i(xij) = dj 168 | for j in range(nelement_weight2): 169 | G_2[:, nelement_weight1 + j, j::nelement_weight2] = 1 170 | 171 | # xij>=0, sum_j(xij) <= si,sum_i(xij) <= dj, sum_ij(x_ij) = min(sum(si), sum(dj)) 172 | G = torch.cat((G_1, G_2), 1) 173 | A = torch.ones(nbatch, 1, nelement_distmatrix).double().to(self.device) 174 | b = torch.min(torch.sum(weight1, 1), torch.sum(weight2, 1)).unsqueeze(1).double() 175 | flow = qpth.qp.QPFunction(verbose=-1)(Q, p, G, h, A, b) 176 | 177 | energy_diff = torch.abs(torch.sum(weight1, dim=1) - torch.sum(weight2, dim=1)) 178 | 179 | emd_score = torch.sum((Q_1).squeeze() * flow, 1) 180 | emd_score += energy_diff 181 | 182 | return emd_score, flow.view(-1, nelement_weight1, nelement_weight2) 183 | 184 | def forward( 185 | self, jets1: Tensor, jets2: Tensor, return_flows: bool = False 186 | ) -> Tensor | tuple[Tensor, Tensor]: 187 | """ 188 | Calculate EMD between ``jets1`` and ``jets2``. 189 | 190 | Args: 191 | jets1 (Tensor): tensor of shape ``[num_jets, num_particles, num_features]``, 192 | with features in order ``[eta, phi, pt]``. 193 | jets2 (Tensor): tensor of same format as ``jets1``. 194 | return_flows (bool): return energy flows between particles in each jet. 195 | Defaults to False. 196 | 197 | Returns: 198 | Union[Tensor, Tuple[Tensor, Tensor]]: 199 | - **Tensor**: EMD scores tensor of shape [num_jets]. 200 | - **Tensor** *Optional*, if ``return_flows`` is True: tensor of flows between 201 | particles of shape ``[num_jets, num_particles, num_particles]``. 
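        Example:
            A minimal usage sketch (illustrative only: the jet count, particle count, and
            random features below are assumptions, and the ``jetnet[emdloss]`` extra must
            be installed)::

                import torch
                from jetnet.losses import EMDLoss

                emd_loss = EMDLoss(num_particles=30)
                jets1 = torch.rand(8, 30, 3)   # [num_jets, num_particles, (eta, phi, pt)]
                jets2 = torch.rand(8, 30, 3)
                emds = emd_loss(jets1, jets2)  # shape [8]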
202 | 203 | """ 204 | assert (len(jets1.shape) == 3) and (len(jets2.shape) == 3), "Jets shape incorrect" 205 | assert jets1.shape[0] == jets2.shape[0], "jets1 and jets2 have different numbers of jets" 206 | assert self.num_particles is None or ( 207 | (jets1.shape[1] == self.num_particles) and (jets2.shape[1] == self.num_particles) 208 | ), "jets don't have num_particles particles" 209 | 210 | if self.method == "cvxpy": 211 | diffs = -(jets1[:, :, :2].unsqueeze(2) - jets2[:, :, :2].unsqueeze(1)) + 1e-12 212 | dists = torch.norm(diffs, dim=3).view(-1, self.num_particles * self.num_particles) 213 | 214 | weights = torch.cat((jets1[:, :, 2], jets2[:, :, 2]), dim=1) 215 | 216 | E1 = torch.sum(jets1[:, :, 2], dim=1) 217 | E2 = torch.sum(jets2[:, :, 2], dim=1) 218 | 219 | Emin = torch.minimum(E1, E2).unsqueeze(1) 220 | EabsDiff = torch.abs(E2 - E1).unsqueeze(1) 221 | 222 | (flows,) = self.cvxpylayer(dists, weights, Emin) 223 | 224 | emds = torch.sum(dists * flows, dim=1) + EabsDiff 225 | elif self.method == "qpth": 226 | diffs = -(jets1[:, :, :2].unsqueeze(2) - jets2[:, :, :2].unsqueeze(1)) + 1e-12 227 | dists = torch.norm(diffs, dim=3) 228 | 229 | emds, flows = self._emd_inference_qpth(dists, jets1[:, :, 2], jets2[:, :, 2]) 230 | 231 | return (emds, flows) if return_flows else emds 232 | -------------------------------------------------------------------------------- /jetnet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .coord_transform import * # noqa: F403 4 | from .utils import * # noqa: F403 5 | -------------------------------------------------------------------------------- /jetnet/utils/coord_transform.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Iterable 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def cartesian_to_EtaPhiPtE(p4: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 10 | r""" 11 | Transform 4-momenta from Cartesian coordinates to polar coordinates for massless particles. 12 | 13 | Args: 14 | p4 (np.ndarray or torch.Tensor): array of 4-momenta in Cartesian coordinates, 15 | of shape ``[..., 4]``. The last axis should be in order 16 | :math:`(E/c, p_x, p_y, p_z)`. 17 | 18 | Returns: 19 | np.ndarray or torch.Tensor: array of 4-momenta in polar coordinates, arranged in order 20 | :math:`(\eta, \phi, p_\mathrm{T}, E/c)`, where :math:`\eta` is the pseudorapidity. 21 | """ 22 | 23 | eps = __get_default_eps(p4) # default epsilon for the dtype 24 | 25 | # (E/c, px, py, pz) -> (eta, phi, pT, E/c) 26 | p0, px, py, pz = __unbind(p4, axis=-1) 27 | pt = __sqrt(px**2 + py**2) 28 | eta = __arcsinh(pz / (pt + eps)) 29 | phi = __arctan2(py, px) 30 | 31 | return __stack([eta, phi, pt, p0], axis=-1) 32 | 33 | 34 | def EtaPhiPtE_to_cartesian(p4: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 35 | r""" 36 | Transform 4-momenta from polar coordinates to Cartesian coordinates for massless particles. 37 | 38 | Args: 39 | p4 (np.ndarray or torch.Tensor): array of 4-momenta in polar coordinates, 40 | of shape ``[..., 4]``. The last axis should be in order 41 | :math:`(\eta, \phi, p_\mathrm{T}, E/c)`, where :math:`\eta` is the pseudorapidity. 42 | 43 | Returns: 44 | np.ndarray or torch.Tensor: array of 4-momenta in Cartesian coordinates, arranged in order 45 | :math:`(E/c, p_x, p_y, p_z)`.
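        Example:
            A minimal round-trip sketch (illustrative only; the single 4-momentum below is
            an assumed value, using :math:`E/c = p_\mathrm{T}\cosh\eta` for a massless
            particle)::

                import numpy as np
                from jetnet.utils import EtaPhiPtE_to_cartesian, cartesian_to_EtaPhiPtE

                eta, phi, pt = 0.5, 1.2, 100.0
                p4_polar = np.array([[eta, phi, pt, pt * np.cosh(eta)]])  # (eta, phi, pT, E/c)
                p4_cart = EtaPhiPtE_to_cartesian(p4_polar)                # (E/c, px, py, pz)
                cartesian_to_EtaPhiPtE(p4_cart)                           # ~= p4_polar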
46 | """ 47 | 48 | # (eta, phi, pT, E/c) -> (E/c, px, py, pz) 49 | eta, phi, pt, p0 = __unbind(p4, axis=-1) 50 | px = pt * __cos(phi) 51 | py = pt * __sin(phi) 52 | pz = pt * __sinh(eta) 53 | 54 | return __stack([p0, px, py, pz], axis=-1) 55 | 56 | 57 | def cartesian_to_YPhiPtE(p4: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 58 | r""" 59 | Transform 4-momenta from Cartesian coordinates to polar coordinates. 60 | 61 | Args: 62 | p4 (np.ndarray or torch.Tensor): array of 4-momenta in Cartesian coordinates, 63 | of shape ``[..., 4]``. The last axis should be in order 64 | :math:`(E/c, p_x, p_y, p_z)`. 65 | 66 | Returns: 67 | np.ndarray or torch.Tensor: array of 4-momenta in polar coordinates, arranged in order 68 | :math:`(y, \phi, E/c, p_\mathrm{T})`, where :math:`y` is the rapidity. 69 | """ 70 | 71 | eps = __get_default_eps(p4) # default epsilon for the dtype 72 | 73 | # (E/c, p_x, p_y, p_z) -> (y, phi, pT, E/c) 74 | p0, px, py, pz = __unbind(p4, axis=-1) 75 | pt = __sqrt(px**2 + py**2) 76 | y = 0.5 * __log((p0 + pz + eps) / (p0 - pz + eps)) 77 | phi = __arctan2(py, px) 78 | 79 | return __stack([y, phi, pt, p0], axis=-1) 80 | 81 | 82 | def YPhiPtE_to_cartesian(p4: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 83 | r""" 84 | Transform 4-momenta from polar coordinates to Cartesian coordinates. 85 | 86 | Args: 87 | p4 (np.ndarray or torch.Tensor): array of 4-momenta in Cartesian coordinates, 88 | of shape ``[..., 4]``. The last axis should be in order 89 | :math:`(y, \phi, E/c, p_\mathrm{T})`, where :math:`y` is the rapidity. 90 | 91 | Returns: 92 | np.ndarray or torch.Tensor: array of 4-momenta in polar coordinates, arranged in order 93 | :math:`(E/c, p_x, p_y, p_z)`. 94 | """ 95 | 96 | eps = __get_default_eps(p4) # default epsilon for the dtype 97 | 98 | # (y, phi, pt, E/c) -> (E/c, px, py, pz) 99 | y, phi, pt, p0 = __unbind(p4, axis=-1) 100 | px = pt * __cos(phi) 101 | py = pt * __sin(phi) 102 | # get pz 103 | mt = p0 / (__cosh(y) + eps) # get transverse mass 104 | pz = mt * __sinh(y) 105 | return __stack([p0, px, py, pz], axis=-1) 106 | 107 | 108 | def cartesian_to_relEtaPhiPt(p4: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 109 | r""" 110 | Get particle features in relative polar coordinates from 4-momenta in Cartesian coordinates. 111 | 112 | Args: 113 | p4 (np.ndarray or torch.Tensor): array of 4-momenta in Cartesian coordinates, 114 | of shape ``[..., 4]``. The last axis should be in order 115 | :math:`(E/c, p_x, p_y, p_z)`. 116 | 117 | Returns: 118 | np.ndarray or torch.Tensor: array of features in relative polar coordinates, arranged 119 | in order :math:`(\eta^\mathrm{rel}, \phi^\mathrm{rel}, p_\mathrm{T}^\mathrm{rel})`. 
120 | """ 121 | 122 | eps = __get_default_eps(p4) # default epsilon for the dtype 123 | 124 | # particle (eta, phi, pT) 125 | p4_polar = cartesian_to_EtaPhiPtE(p4) 126 | eta, phi, pt, _ = __unbind(p4_polar, axis=-1) 127 | 128 | # jet (Eta, Phi, PT) 129 | jet_cartesian = __sum(p4, axis=-2, keepdims=True) 130 | jet_polar = cartesian_to_EtaPhiPtE(jet_cartesian) 131 | Eta, Phi, Pt, _ = __unbind(jet_polar, axis=-1) 132 | 133 | # get relative features 134 | pt_rel = pt / (Pt + eps) 135 | eta_rel = eta - Eta 136 | phi_rel = phi - Phi 137 | phi_rel = (phi_rel + np.pi) % (2 * np.pi) - np.pi # map to [-pi, pi] 138 | 139 | return __stack([eta_rel, phi_rel, pt_rel], axis=-1) 140 | 141 | 142 | def EtaPhiPtE_to_relEtaPhiPt(p4: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 143 | r""" 144 | Get particle features in relative polar coordinates from 4-momenta in polar coordinates. 145 | 146 | Args: 147 | p4 (np.ndarray or torch.Tensor): array of 4-momenta in polar coordinates, 148 | of shape ``[..., 4]``. The last axis should be in order 149 | :math:`(\eta, \phi, p_\mathrm{T}, E/c)`, where :math:`\eta` is the pseudorapidity. 150 | 151 | Returns: 152 | np.ndarray or torch.Tensor: array of features in relative polar coordinates, arranged 153 | in order :math:`(\eta^\mathrm{rel}, \phi^\mathrm{rel}, p_\mathrm{T}^\mathrm{rel})`. 154 | """ 155 | 156 | eps = __get_default_eps(p4) # default epsilon for the dtype 157 | 158 | # particle (eta, phi, pT) 159 | p4_polar = p4 160 | eta, phi, pt, _ = __unbind(p4, axis=-1) 161 | 162 | # jet (Eta, Phi, PT) 163 | p4_cartesian = EtaPhiPtE_to_cartesian(p4_polar) 164 | # expand dimension to (..., 1, 4) to match p4 shape 165 | jet_cartesian = __sum(p4_cartesian, axis=-2, keepdims=True) 166 | jet_polar = cartesian_to_EtaPhiPtE(jet_cartesian) 167 | Eta, Phi, Pt, _ = __unbind(jet_polar, axis=-1) 168 | 169 | # get relative features 170 | pt_rel = pt / (Pt + eps) 171 | eta_rel = eta - Eta 172 | phi_rel = phi - Phi 173 | phi_rel = (phi_rel + np.pi) % (2 * np.pi) - np.pi # map to [-pi, pi] 174 | 175 | return __stack([eta_rel, phi_rel, pt_rel], axis=-1) 176 | 177 | 178 | def relEtaPhiPt_to_EtaPhiPt( 179 | p_polarrel: np.ndarray | torch.Tensor, 180 | jet_features: np.ndarray | torch.Tensor, 181 | jet_coord: str = "cartesian", 182 | ) -> np.ndarray | torch.Tensor: 183 | r""" 184 | Get particle features in absolute polar coordinates from relative polar coordinates 185 | and jet features. 186 | 187 | Args: 188 | p_polarrel (np.ndarray or torch.Tensor): array of particle features in 189 | relative polar coordinates of shape ``[..., 3]``. The last axis should be in 190 | order :math:`(\eta^\mathrm{rel}, \phi^\mathrm{rel}, p_\mathrm{T}^\mathrm{rel})`, 191 | where :math:`\eta` is the pseudorapidity. 192 | jet_features (np.ndarray or torch.Tensor): array of jet features in polar coordinates, 193 | of shape ``[..., 4]``. The coordinates are specified by ``jet_coord``. 194 | jet_coord (str): coordinate system of jet features. Can be either "cartesian" or "polar". 195 | Defaults to "cartesian". 196 | If "cartesian", the last axis of ``jet_features`` should be in order 197 | :math:`(E/c, p_x, p_y, p_z)`. 198 | If "polar", the last axis of ``jet_features`` should be in order 199 | :math:`(\eta, \phi, p_\mathrm{T}, E/c)`. 200 | 201 | Returns: 202 | np.ndarray or torch.Tensor: array of particle features in absolute polar coordinates, 203 | arranged in order :math:`(\eta, \phi, p_\mathrm{T}, E/c)`. 
204 | """ 205 | 206 | # particle features in relative polar coordinates 207 | eta_rel, phi_rel, pt_rel = __unbind(p_polarrel, axis=-1) 208 | 209 | # jet features in polar coordinates 210 | if jet_coord.lower() in ("cartesian", "epxpypz", "e_px_py_pz"): 211 | jet_features = cartesian_to_EtaPhiPtE(jet_features) 212 | elif jet_coord.lower() in ("polar", "etaphipte", "eta_phi_pt_e"): 213 | pass 214 | else: 215 | raise ValueError("jet_coord can only be 'cartesian' or 'polar'") 216 | # eta is used even though jet is massive 217 | Eta, Phi, Pt, _ = __unbind(jet_features, axis=-1) 218 | 219 | # transform back to absolute coordinates 220 | pt = pt_rel * Pt 221 | eta = eta_rel + Eta 222 | phi = phi_rel + Phi 223 | p0 = pt * __cosh(eta) 224 | 225 | return __stack([eta, phi, pt, p0], axis=-1) 226 | 227 | 228 | def relEtaPhiPt_to_cartesian( 229 | p_polarrel: np.ndarray | torch.Tensor, 230 | jet_features: np.ndarray | torch.Tensor, 231 | jet_coord: str = "cartesian", 232 | ) -> np.ndarray | torch.Tensor: 233 | r""" 234 | Get particle features in absolute Cartesian coordinates from relative polar coordinates 235 | and jet features. 236 | 237 | Args: 238 | p_polarrel (np.ndarray or torch.Tensor): array of particle features in relative 239 | polar coordinates of shape ``[..., 3]``. The last axis should be in order 240 | :math:`(\eta^\mathrm{rel}, \phi^\mathrm{rel}, p_\mathrm{T}^\mathrm{rel})`, 241 | where :math:`\eta` is the pseudorapidity. 242 | jet_features (np.ndarray or torch.Tensor): array of jet features in polar coordinates, 243 | of shape ``[..., 4]``. The coordinates are specified by ``jet_coord``. 244 | jet_coord (str): coordinate system of jet features. Can be either "cartesian" or "polar". 245 | Defaults to "cartesian". 246 | If "cartesian", the last axis of ``jet_features`` should be in order 247 | :math:`(E/c, p_x, p_y, p_z)`. 248 | If "polar", the last axis of ``jet_features`` should be in order 249 | :math:`(\eta, \phi, p_\mathrm{T}, E/c)`. 250 | 251 | Returns: 252 | np.ndarray or torch.Tensor: array of particle features in absolute polar coordinates, 253 | arranged in order :math:`(E/c, p_x, p_y, p_z)`. 
254 | """ 255 | p4_polar = relEtaPhiPt_to_EtaPhiPt(p_polarrel, jet_features, jet_coord) 256 | # eta is used even though jet is massive 257 | return EtaPhiPtE_to_cartesian(p4_polar) 258 | 259 | 260 | def __unbind(x: np.ndarray | torch.Tensor, axis: int) -> np.ndarray | torch.Tensor: 261 | """Unbind an np.ndarray or torch.Tensor along a given axis.""" 262 | if isinstance(x, torch.Tensor): 263 | return torch.unbind(x, dim=axis) 264 | elif isinstance(x, np.ndarray): 265 | return np.rollaxis(x, axis=axis) 266 | else: 267 | raise TypeError(f"x must be either a numpy array or a torch tensor, not {type(x)}") 268 | 269 | 270 | def __stack(x: Iterable[np.ndarray | torch.Tensor], axis: int) -> np.ndarray | torch.Tensor: 271 | """Stack an iterable of np.ndarray or torch.Tensor along a given axis.""" 272 | if not isinstance(x, Iterable): 273 | raise TypeError("x must be an iterable.") 274 | 275 | if isinstance(x[0], torch.Tensor): 276 | return torch.stack(x, dim=axis) 277 | elif isinstance(x[0], np.ndarray): 278 | return np.stack(x, axis=axis) 279 | else: 280 | raise TypeError(f"x must be either a numpy array or a torch tensor, not {type(x)}") 281 | 282 | 283 | def __cos(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 284 | """Cosine function that works with np.ndarray and torch.Tensor.""" 285 | if isinstance(x, torch.Tensor): 286 | return torch.cos(x) 287 | elif isinstance(x, np.ndarray): 288 | return np.cos(x) 289 | else: 290 | raise TypeError(f"x must be either a numpy array or a torch tensor, not {type(x)}") 291 | 292 | 293 | def __sin(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 294 | """Sine function that works with np.ndarray and torch.Tensor.""" 295 | if isinstance(x, torch.Tensor): 296 | return torch.sin(x) 297 | elif isinstance(x, np.ndarray): 298 | return np.sin(x) 299 | else: 300 | raise TypeError(f"x must be either a numpy array or a torch tensor, not {type(x)}") 301 | 302 | 303 | def __sinh(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 304 | """Hyperbolic sine function that works with np.ndarray and torch.Tensor.""" 305 | if isinstance(x, torch.Tensor): 306 | return torch.sinh(x) 307 | elif isinstance(x, np.ndarray): 308 | return np.sinh(x) 309 | else: 310 | raise TypeError(f"x must be either a numpy array or a torch tensor, not {type(x)}") 311 | 312 | 313 | def __arcsinh(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 314 | """Inverse hyperbolic sine function that works with np.ndarray and torch.Tensor.""" 315 | if isinstance(x, torch.Tensor): 316 | return torch.asinh(x) 317 | elif isinstance(x, np.ndarray): 318 | return np.arcsinh(x) 319 | else: 320 | raise TypeError(f"x must be either a numpy array or a torch tensor, not {type(x)}") 321 | 322 | 323 | def __cosh(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 324 | """Hyperbolic cosine function that works with np.ndarray and torch.Tensor.""" 325 | if isinstance(x, torch.Tensor): 326 | return torch.cosh(x) 327 | elif isinstance(x, np.ndarray): 328 | return np.cosh(x) 329 | else: 330 | raise TypeError(f"x must be either a numpy array or a torch tensor, not {type(x)}") 331 | 332 | 333 | def __log(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 334 | """Logarithm function that works with np.ndarray and torch.Tensor.""" 335 | if isinstance(x, torch.Tensor): 336 | return torch.log(x) 337 | elif isinstance(x, np.ndarray): 338 | return np.log(x) 339 | else: 340 | raise TypeError(f"x must be either a numpy array or a torch tensor, not {type(x)}") 341 | 342 | 343 | def 
__arctan2( 344 | y: np.ndarray | torch.Tensor, x: np.ndarray | torch.Tensor 345 | ) -> np.ndarray | torch.Tensor: 346 | """Arctangent function that works with np.ndarray and torch.Tensor.""" 347 | if isinstance(y, torch.Tensor): 348 | return torch.atan2(y, x) 349 | elif isinstance(y, np.ndarray): 350 | return np.arctan2(y, x) 351 | else: 352 | raise TypeError(f"y must be either a numpy array or a torch tensor, not {type(y)}") 353 | 354 | 355 | def __sqrt(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 356 | """Square root function that works with np.ndarray and torch.Tensor.""" 357 | if isinstance(x, torch.Tensor): 358 | return torch.sqrt(x) 359 | elif isinstance(x, np.ndarray): 360 | return np.sqrt(x) 361 | else: 362 | raise TypeError(f"x must be either a numpy array or a torch tensor, not {type(x)}") 363 | 364 | 365 | def __sum( 366 | x: np.ndarray | torch.Tensor, axis: int, keepdims: bool = False 367 | ) -> np.ndarray | torch.Tensor: 368 | """Sum function that works with np.ndarray and torch.Tensor.""" 369 | if isinstance(x, torch.Tensor): 370 | return x.sum(axis, keepdim=keepdims) 371 | elif isinstance(x, np.ndarray): 372 | return np.sum(x, axis=axis, keepdims=keepdims) 373 | else: 374 | raise TypeError(f"x must be either a numpy array or a torch tensor, not {type(x)}") 375 | 376 | 377 | def __get_default_eps(x: np.ndarray | torch.Tensor) -> float: 378 | if isinstance(x, torch.Tensor): 379 | return torch.finfo(x.dtype).eps 380 | elif isinstance(x, np.ndarray): 381 | return np.finfo(x.dtype).eps 382 | else: 383 | raise TypeError(f"x must be either a numpy array or a torch tensor, not {type(x)}") 384 | 385 | 386 | __all__ = [ 387 | # cartesian <-> polar 388 | "cartesian_to_EtaPhiPtE", 389 | "EtaPhiPtE_to_cartesian", 390 | "cartesian_to_YPhiPtE", 391 | "YPhiPtE_to_cartesian", 392 | # cartesian <-> polarrel 393 | "cartesian_to_relEtaPhiPt", 394 | "relEtaPhiPt_to_cartesian", 395 | # polar <-> polarrel 396 | "EtaPhiPtE_to_relEtaPhiPt", 397 | "relEtaPhiPt_to_EtaPhiPt", 398 | ] 399 | -------------------------------------------------------------------------------- /jetnet/utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations # for ArrayLike type in docs 2 | 3 | # for calculating jet features quickly, 4 | # TODO: replace with vector library when summing over axis feature is implemented 5 | import awkward as ak 6 | import numpy as np 7 | from coffea.nanoevents.methods import vector 8 | from energyflow import EFPSet 9 | from numpy.typing import ArrayLike 10 | 11 | ak.behavior.update(vector.behavior) 12 | 13 | 14 | def jet_features(jets: np.ndarray) -> dict[str, float | np.ndarray]: 15 | """ 16 | Calculates jet features by summing over particle Lorentz 4-vectors. 17 | 18 | Args: 19 | jets (np.ndarray): array of either a single or multiple jets, of shape either 20 | ``[num_particles, num_features]`` or ``[num_jets, num_particles, num_features]``, 21 | with features in order ``[eta, phi, pt, (optional) mass]``. If no particle masses are 22 | given, they are assumed to be 0. 23 | 24 | Returns: 25 | Dict[str, Union[float, np.ndarray]]: 26 | dict of float (if inputted single jet) or 27 | 1D arrays of length ``num_jets`` (if inputted multiple jets) 28 | with 'mass', 'pt', and 'eta' keys.
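    Example:
        A sketch on random jets (the shapes and random inputs below are assumptions,
        not realistic kinematics)::

            import numpy as np
            from jetnet.utils import jet_features

            rng = np.random.default_rng(0)
            jets = rng.random((100, 30, 3))   # [num_jets, num_particles, (eta, phi, pt)]
            jf = jet_features(jets)
            jf["mass"].shape, jf["pt"].shape  # (100,), (100,)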
29 | 30 | """ 31 | 32 | assert len(jets.shape) == 2 or len(jets.shape) == 3, "jets dimensions are incorrect" 33 | assert jets.shape[-1] >= 3, "missing particle features" 34 | 35 | if len(jets.shape) == 2: 36 | vecs = ak.zip( 37 | { 38 | "pt": jets[:, 2:3], 39 | "eta": jets[:, 0:1], 40 | "phi": jets[:, 1:2], 41 | # 0s for mass if no mass given 42 | "mass": ak.full_like(jets[:, 2:3], 0) if jets.shape[1] == 3 else jets[:, 3:4], 43 | }, 44 | with_name="PtEtaPhiMLorentzVector", 45 | ) 46 | 47 | sum_vecs = vecs.sum(axis=0) 48 | else: 49 | vecs = ak.zip( 50 | { 51 | "pt": jets[:, :, 2:3], 52 | "eta": jets[:, :, 0:1], 53 | "phi": jets[:, :, 1:2], 54 | # 0s for mass if no mass given 55 | "mass": ak.full_like(jets[:, :, 2:3], 0) if jets.shape[2] == 3 else jets[:, :, 3:4], 56 | }, 57 | with_name="PtEtaPhiMLorentzVector", 58 | ) 59 | 60 | sum_vecs = vecs.sum(axis=1) 61 | 62 | jf = { 63 | "mass": np.nan_to_num(np.array(sum_vecs.mass)).squeeze(), 64 | "pt": np.nan_to_num(np.array(sum_vecs.pt)).squeeze(), 65 | "eta": np.nan_to_num(np.array(sum_vecs.eta)).squeeze(), 66 | } 67 | 68 | return jf 69 | 70 | 71 | def efps( 72 | jets: np.ndarray, 73 | use_particle_masses: bool = False, 74 | efpset_args: list | None = None, 75 | efp_jobs: int | None = None, 76 | ) -> np.ndarray: 77 | """ 78 | Utility for calculating EFPs for jets in JetNet format using the energyflow library. 79 | 80 | Args: 81 | jets (np.ndarray): array of either a single or multiple jets, of shape either 82 | ``[num_particles, num_features]`` or ``[num_jets, num_particles, num_features]``, 83 | with features in order ``[eta, phi, pt, (optional) mass]``. If no particle masses are 84 | given, they are assumed to be 0 (and ``use_particle_masses`` should be left False). 85 | efpset_args (List): Args for the energyflow.efpset function to specify which EFPs to use, 86 | as defined here https://energyflow.network/docs/efp/#efpset. 87 | Defaults to the n=4, d=4, prime EFPs. 88 | efp_jobs (int): number of jobs to use for energyflow's EFP batch computation. 89 | None means as many processes as there are CPUs. 90 | 91 | Returns: 92 | np.ndarray: 93 | 1D (if inputted single jet) or 2D array of shape ``[num_jets, num_efps]`` of EFPs per jet 94 | 95 | """ 96 | 97 | if efpset_args is None: 98 | efpset_args = [("n==", 4), ("d==", 4), ("p==", 1)] 99 | assert len(jets.shape) == 2 or len(jets.shape) == 3, "jets dimensions are incorrect" 100 | assert jets.shape[-1] - int(use_particle_masses) >= 3, "particle feature format is incorrect" 101 | 102 | efpset = EFPSet(*efpset_args, measure="hadr", beta=1, normed=None, coords="ptyphim") 103 | 104 | if len(jets.shape) == 2: 105 | # convert to energyflow format 106 | jets = jets[:, [2, 0, 1]] if not use_particle_masses else jets[:, [2, 0, 1, 3]] 107 | efps = efpset.compute(jets) 108 | else: 109 | # convert to energyflow format 110 | jets = jets[:, :, [2, 0, 1]] if not use_particle_masses else jets[:, :, [2, 0, 1, 3]] 111 | efps = efpset.batch_compute(jets, efp_jobs) 112 | 113 | return efps 114 | 115 | 116 | def to_image( 117 | jets: np.ndarray, im_size: int, mask: np.ndarray | None = None, maxR: float = 1.0 118 | ) -> np.ndarray: 119 | """ 120 | Convert jet(s) into 2D ``im_size`` x ``im_size`` or 3D ``num_jets`` x ``im_size`` x ``im_size`` 121 | image arrays. 122 | 123 | Args: 124 | jets (np.ndarray): array of jet(s) of shape ``[num_particles, num_features]`` or 125 | ``[num_jets, num_particles, num_features]`` with features in order ``[eta, phi, pt]``. 126 | im_size (int): number of pixels per row and column.
127 | mask (np.ndarray): optional binary array of masks of shape ``[num_particles]`` or 128 | ``[num_jets, num_particles]``. 129 | maxR (float): max radius of the jet. Defaults to 1.0. 130 | 131 | Returns: 132 | np.ndarray: 2D or 3D array of shape ``[im_size, im_size]`` or 133 | ``[num_jets, im_size, im_size]``. 134 | 135 | """ 136 | assert len(jets.shape) == 2 or len(jets.shape) == 3, "jets dimensions are incorrect" 137 | assert jets.shape[-1] >= 3, "particle feature format is incorrect" 138 | 139 | eta = jets[..., 0] 140 | phi = jets[..., 1] 141 | pt = jets[..., 2] 142 | num_jets = 1 if len(jets.shape) == 2 else jets.shape[0] 143 | 144 | if mask is not None: 145 | assert len(mask.shape) == 1 or len(mask.shape) == 2, "mask shape incorrect" 146 | assert mask.shape == jets.shape[:-1], "mask shape and jets shape do not agree" 147 | mask = mask.astype(int) 148 | pt *= mask 149 | 150 | jet_images = np.zeros((num_jets, im_size, im_size)) 151 | 152 | for i_jet in range(num_jets): 153 | hist_2d, _, _ = np.histogram2d( 154 | eta[i_jet] if num_jets > 1 else eta, 155 | phi[i_jet] if num_jets > 1 else phi, 156 | bins=[im_size, im_size], 157 | range=[[-maxR, maxR], [-maxR, maxR]], 158 | weights=pt[i_jet] if num_jets > 1 else pt, 159 | ) 160 | jet_images[i_jet] = hist_2d 161 | 162 | if num_jets == 1: 163 | jet_images = jet_images[0] 164 | 165 | return jet_images 166 | 167 | 168 | def gen_jet_corrections( 169 | jets: ArrayLike, 170 | ret_mask_separate: bool = True, 171 | zero_mask_particles: bool = True, 172 | zero_neg_pt: bool = True, 173 | pt_index: int = 2, 174 | ) -> ArrayLike | tuple[ArrayLike, ArrayLike]: 175 | """ 176 | Zeroes masked particles and negative pTs. 177 | 178 | Args: 179 | jets (ArrayLike): jets to correct. 180 | ret_mask_separate (bool, optional): return the jet and mask separately. Defaults to True. 181 | zero_mask_particles (bool, optional): set features of zero-masked particles to 0. Defaults 182 | to True. 183 | zero_neg_pt (bool, optional): set pT to 0 for particles with negative pt. Defaults to True. 184 | pt_index (int, optional): index of the pT feature. Defaults to 2. 185 | 186 | Returns: 187 | Jets of same type as input, of shape 188 | ``[num_jets, num_particles, num_features (including mask)]`` if ``ret_mask_separate`` 189 | is False, else a tuple with a tensor/array of shape 190 | ``[num_jets, num_particles, num_features (excluding mask)]`` and another binary mask 191 | tensor/array of shape ``[num_jets, num_particles]``.
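    Example:
        A sketch (the random inputs below are assumptions; the last feature is treated
        as the mask)::

            import numpy as np
            from jetnet.utils import gen_jet_corrections

            rng = np.random.default_rng(0)
            jets = rng.normal(size=(10, 30, 4))
            corrected, mask = gen_jet_corrections(jets)
            corrected.shape, mask.shape  # (10, 30, 3), (10, 30)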
192 | """ 193 | 194 | use_mask = ret_mask_separate or zero_mask_particles 195 | 196 | mask = jets[:, :, -1] >= 0.5 if use_mask else None 197 | 198 | if zero_mask_particles and use_mask: 199 | jets[~mask] = 0 200 | 201 | if zero_neg_pt: 202 | jets[:, :, pt_index][jets[:, :, pt_index] < 0] = 0 203 | 204 | return (jets[:, :, :-1], mask) if ret_mask_separate else jets 205 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | 4 | [tool.pytest.ini_options] 5 | markers = ["slow: test is too slow for Github Actions"] 6 | 7 | [tool.coverage] 8 | run.source = ["simulation_challenge"] 9 | port.exclude_lines = [ 10 | 'pragma: no cover', 11 | '\.\.\.', 12 | 'if typing.TYPE_CHECKING:', 13 | ] 14 | 15 | [tool.mypy] 16 | files = ["src", "tests"] 17 | python_version = "3.8" 18 | warn_unused_configs = true 19 | strict = true 20 | show_error_codes = true 21 | enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] 22 | warn_unreachable = true 23 | disallow_untyped_defs = false 24 | disallow_incomplete_defs = false 25 | 26 | [[tool.mypy.overrides]] 27 | module = "simulation_challenge.*" 28 | disallow_untyped_defs = true 29 | disallow_incomplete_defs = true 30 | 31 | 32 | [tool.ruff] 33 | src = ["src"] 34 | 35 | [tool.ruff.lint] 36 | extend-select = [ 37 | "B", # flake8-bugbear 38 | "I", # isort 39 | "ARG", # flake8-unused-arguments 40 | "C4", # flake8-comprehensions 41 | "EM", # flake8-errmsg 42 | "ICN", # flake8-import-conventions 43 | "G", # flake8-logging-format 44 | "PGH", # pygrep-hooks 45 | "PIE", # flake8-pie 46 | "PL", # pylint 47 | "PT", # flake8-pytest-style 48 | "PTH", # flake8-use-pathlib 49 | "RET", # flake8-return 50 | "RUF", # Ruff-specific 51 | "SIM", # flake8-simplify 52 | "T20", # flake8-print 53 | "UP", # pyupgrade 54 | "YTT", # flake8-2020 55 | "EXE", # flake8-executable 56 | "NPY", # NumPy specific rules 57 | "PD", # pandas-vet 58 | ] 59 | ignore = [ 60 | "PLR", # Design related pylint codes 61 | "PT013", # incorrect import of pytest 62 | "PT018", # assertion should be broken down 63 | "T201", 64 | "EM101", 65 | "EM102", 66 | "G004", # logging format string 67 | "RUF012", # mutable class defaults 68 | "RET504", # unnecessary assignment before return statement 69 | "RET505", # unnecessary else after return statement 70 | ] 71 | isort.required-imports = ["from __future__ import annotations"] 72 | # Uncomment if using a _compat.typing backport 73 | # typing-modules = ["simulation_challenge._compat.typing"] 74 | 75 | [tool.ruff.lint.per-file-ignores] 76 | "tests/**" = ["T20"] 77 | "noxfile.py" = ["T20"] 78 | 79 | 80 | [tool.pylint] 81 | py-version = "3.8" 82 | ignore-paths = [".*/_version.py"] 83 | reports.output-format = "colorized" 84 | similarities.ignore-imports = "yes" 85 | messages_control.disable = [ 86 | "design", 87 | "fixme", 88 | "line-too-long", 89 | "missing-module-docstring", 90 | "wrong-import-position", 91 | ] 92 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | from pathlib import Path 5 | 6 | from setuptools import find_packages, setup 7 | 8 | install_requires = [ 9 | "numpy >= 1.21.0", 10 | "torch >= 1.8.0", 11 | "energyflow >= 1.3.0", 12 | "scipy >= 1.6.2", 13 | "awkward >= 1.4.0", 14 | "coffea >= 
0.7.0", 15 | "h5py >= 3.0.0", 16 | "pandas", 17 | "tables", 18 | "requests", 19 | "tqdm", 20 | ] 21 | 22 | extras_require = {"emdloss": ["qpth", "cvxpy"]} 23 | 24 | 25 | classifiers = [ 26 | # How mature is this project? Common values are 27 | # 3 - Alpha 28 | # 4 - Beta 29 | # 5 - Production/Stable 30 | "Development Status :: 3 - Alpha", 31 | # Pick your license as you wish (should match "license" above) 32 | "License :: OSI Approved :: MIT License", 33 | # Specify the Python versions you support here. In particular, ensure 34 | # that you indicate whether you support Python 2, Python 3 or both. 35 | "Programming Language :: Python :: 3", 36 | "Programming Language :: Python :: 3.7", 37 | "Programming Language :: Python :: 3.8", 38 | "Programming Language :: Python :: 3.9", 39 | "Programming Language :: Python :: 3.10", 40 | "Programming Language :: Python :: 3.11", 41 | "Programming Language :: Python :: 3.12", 42 | ] 43 | 44 | 45 | def readme(): 46 | with Path("README.md").open() as f: 47 | return f.read() 48 | 49 | 50 | with Path("jetnet/__init__.py").open() as f: 51 | __version__ = re.search(r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read()).group(1) 52 | 53 | 54 | setup( 55 | name="jetnet", 56 | version=__version__, 57 | description="Jets + ML integration", 58 | long_description=readme(), 59 | long_description_content_type="text/markdown", 60 | url="http://github.com/jet-net/JetNet", 61 | author="Raghav Kansal", 62 | author_email="rkansal@cern.ch", 63 | license="MIT", 64 | packages=find_packages(), 65 | install_requires=install_requires, 66 | python_requires=">=3.7", 67 | extras_require=extras_require, 68 | classifiers=classifiers, 69 | zip_safe=False, 70 | include_package_data=True, 71 | ) 72 | -------------------------------------------------------------------------------- /tests/datasets/test_jetnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pytest 7 | from jetnet.datasets import JetNet, normalisations 8 | from pytest import approx 9 | from torch.utils.data import DataLoader 10 | 11 | # TODO: use checksum for downloaded files 12 | 13 | 14 | data_dir = Path("./datasets/jetnet") 15 | DataClass = JetNet 16 | jet_types = ["g", "q"] # subset of jet types 17 | gq_length = 177252 + 170679 18 | 19 | 20 | @pytest.mark.parametrize( 21 | ("jet_types", "expected_length", "class_id"), 22 | [ 23 | ("g", 177252, 0), 24 | ("q", 170679, 1), 25 | (jet_types, gq_length, None), 26 | ], 27 | ) 28 | @pytest.mark.parametrize("num_particles", [30, 75]) 29 | def test_getData(jet_types, num_particles, expected_length, class_id): 30 | # test getData errors and md5 checksum for one of the datasets 31 | if jet_types == "q": 32 | file_path = data_dir / f"q{'150' if num_particles > 30 else ''}.hdf5" 33 | 34 | if file_path.is_file(): 35 | file_path.unlink() 36 | 37 | # should raise a RunetimeError since file doesn't exist 38 | with pytest.raises(RuntimeError): 39 | DataClass.getData(jet_types, data_dir, num_particles=num_particles) 40 | 41 | # write random data to file 42 | with file_path.open("wb") as f: 43 | f.write(np.random.bytes(100)) # noqa: NPY002 44 | 45 | # should raise a RunetimeError since file exists but is incorrect 46 | with pytest.raises(RuntimeError): 47 | DataClass.getData(jet_types, data_dir, num_particles=num_particles) 48 | 49 | pf, jf = DataClass.getData(jet_types, data_dir, num_particles=num_particles, download=True) 50 | assert pf.shape == 
(expected_length, num_particles, 4) 51 | assert jf.shape == (expected_length, 5) 52 | if class_id is not None: 53 | assert np.all(jf[:, 0] == class_id) 54 | 55 | 56 | @pytest.mark.parametrize("num_particles", [30, 75]) 57 | def test_getDataFeatures(num_particles): 58 | pf, jf = DataClass.getData( 59 | jet_types, 60 | data_dir=data_dir, 61 | num_particles=num_particles, 62 | jet_features=["pt", "num_particles"], 63 | ) 64 | assert pf.shape == (gq_length, num_particles, 4) 65 | assert jf.shape == (gq_length, 2) 66 | assert np.max(jf[:, 0], axis=0) == approx(3000, rel=0.1) 67 | assert np.max(jf[:, 1], axis=0) == num_particles 68 | 69 | pf, jf = DataClass.getData( 70 | jet_types, data_dir=data_dir, num_particles=num_particles, jet_features=None 71 | ) 72 | assert pf.shape == (gq_length, num_particles, 4) 73 | assert jf is None 74 | 75 | pf, jf = DataClass.getData( 76 | jet_types, 77 | data_dir=data_dir, 78 | num_particles=num_particles, 79 | particle_features=["etarel", "mask"], 80 | ) 81 | assert pf.shape == (gq_length, num_particles, 2) 82 | assert jf.shape == (gq_length, 5) 83 | assert np.max(pf.reshape(-1, 2), axis=0) == approx([1, 1], rel=1e-2) 84 | 85 | 86 | @pytest.mark.parametrize("num_particles", [30, 75]) 87 | def test_getDataSplitting(num_particles): 88 | pf, jf = DataClass.getData( 89 | jet_type=jet_types, 90 | data_dir=data_dir, 91 | num_particles=num_particles, 92 | split_fraction=[0.6, 0.2, 0.2], 93 | split="train", 94 | ) 95 | assert len(pf) == int(gq_length * 0.6) 96 | assert len(jf) == int(gq_length * 0.6) 97 | 98 | pf, jf = DataClass.getData( 99 | jet_type=jet_types, data_dir=data_dir, num_particles=num_particles, split="all" 100 | ) 101 | assert len(pf) == int(gq_length) 102 | assert len(jf) == int(gq_length) 103 | 104 | pf, jf = DataClass.getData( 105 | jet_type=jet_types, 106 | data_dir=data_dir, 107 | num_particles=num_particles, 108 | split_fraction=[0.6, 0.2, 0.2], 109 | split="valid", 110 | ) 111 | assert len(pf) == int(gq_length * 0.8) - int(gq_length * 0.6) 112 | assert len(jf) == int(gq_length * 0.8) - int(gq_length * 0.6) 113 | 114 | pf, jf = DataClass.getData( 115 | jet_type=jet_types, 116 | data_dir=data_dir, 117 | num_particles=num_particles, 118 | split_fraction=[0.5, 0.2, 0.3], 119 | split="test", 120 | ) 121 | assert len(pf) == gq_length - int(gq_length * 0.7) 122 | assert len(jf) == gq_length - int(gq_length * 0.7) 123 | 124 | 125 | def test_getDataErrors(): 126 | with pytest.raises(AssertionError): 127 | DataClass.getData(jet_type="f") 128 | 129 | with pytest.raises(AssertionError): 130 | DataClass.getData(jet_type={"g", "f"}) 131 | 132 | with pytest.raises(AssertionError): 133 | DataClass.getData(data_dir=data_dir, particle_features="foo") 134 | 135 | with pytest.raises(AssertionError): 136 | DataClass.getData(data_dir=data_dir, jet_features=["eta", "mask"]) 137 | 138 | 139 | @pytest.mark.parametrize("num_particles", [30, 75]) 140 | def test_DataClass(num_particles): 141 | X = DataClass(jet_type=jet_types, data_dir=data_dir, num_particles=num_particles) 142 | assert len(X) == int(gq_length * 0.7) 143 | 144 | X_loaded = DataLoader(X) 145 | pf, jf = next(iter(X_loaded)) 146 | assert pf.shape == (1, num_particles, 4) 147 | assert jf.shape == (1, 5) 148 | 149 | X = DataClass( 150 | jet_type=jet_types, 151 | data_dir=data_dir, 152 | num_particles=num_particles, 153 | particle_features=["mask", "ptrel"], 154 | jet_features=None, 155 | ) 156 | X_loaded = DataLoader(X) 157 | pf, jf = next(iter(X_loaded)) 158 | assert pf.shape == (1, num_particles, 2) 159 | assert
jf == [] 160 | 161 | X = DataClass( 162 | jet_type=jet_types, data_dir=data_dir, num_particles=num_particles, particle_features=None 163 | ) 164 | X_loaded = DataLoader(X) 165 | pf, jf = next(iter(X_loaded)) 166 | assert pf == [] 167 | assert jf.shape == (1, 5) 168 | 169 | 170 | @pytest.mark.parametrize("num_particles", [30, 75]) 171 | def test_DataClassNormalisation(num_particles): 172 | X = DataClass( 173 | jet_type=jet_types, 174 | data_dir=data_dir, 175 | num_particles=num_particles, 176 | particle_normalisation=normalisations.FeaturewiseLinearBounded(), 177 | jet_normalisation=normalisations.FeaturewiseLinearBounded( 178 | normalise_features=[False, True, True, True, True] 179 | ), 180 | split="all", 181 | ) 182 | 183 | assert np.all(np.max(np.abs(X.particle_data.reshape(-1, 4)), axis=0) == approx(1)) 184 | assert np.all(np.max(np.abs(X.jet_data[:, 1:].reshape(-1, 4)), axis=0) == approx(1)) 185 | assert np.all(np.sum([X.jet_data[:, 0] == 0, X.jet_data[:, 0] == 1], axis=-1)) 186 | -------------------------------------------------------------------------------- /tests/datasets/test_normalisations.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import pytest 5 | from jetnet.datasets.normalisations import FeaturewiseLinear, FeaturewiseLinearBounded 6 | from pytest import approx 7 | 8 | rng = np.random.default_rng(42) 9 | test_data_1d = rng.random(3) * 100 10 | test_data_2d = rng.random((4, 3)) * 100 11 | test_data_3d = rng.random((5, 4, 3)) * 100 12 | 13 | test_data_1d_posneg = rng.random(3) * 100 - 50 14 | test_data_2d_posneg = rng.random((4, 3)) * 100 - 50 15 | test_data_3d_posneg = rng.random((5, 4, 3)) * 100 - 50 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "data", 20 | [ 21 | test_data_1d, 22 | test_data_2d, 23 | test_data_3d, 24 | test_data_1d_posneg, 25 | test_data_2d_posneg, 26 | test_data_3d_posneg, 27 | ], 28 | ) 29 | def test_FeaturewiseLinearBounded(data): 30 | norm = FeaturewiseLinearBounded() 31 | 32 | norm.derive_dataset_features(data) 33 | assert norm.feature_maxes.shape == (data.shape[-1],) 34 | assert np.all(norm.feature_maxes == np.max(np.abs(data.reshape(-1, 3)), axis=0)) 35 | 36 | normed = norm(data) 37 | assert normed.shape == data.shape 38 | assert np.all(np.max(np.abs(normed.reshape(-1, 3)), axis=0) == approx(1)) 39 | 40 | unnormed = norm(normed, inverse=True) 41 | assert np.all(unnormed == approx(data)) 42 | 43 | 44 | @pytest.mark.parametrize( 45 | "data", 46 | [ 47 | test_data_1d, 48 | test_data_2d, 49 | test_data_3d, 50 | test_data_1d_posneg, 51 | test_data_2d_posneg, 52 | test_data_3d_posneg, 53 | ], 54 | ) 55 | def test_FeaturewiseLinearBoundedErrors(data): 56 | norm = FeaturewiseLinearBounded() 57 | with pytest.raises(AssertionError): 58 | norm(data) 59 | 60 | norm = FeaturewiseLinearBounded(feature_norms=[3, 5]) 61 | norm.derive_dataset_features(data) 62 | with pytest.raises(AssertionError): 63 | norm(data) 64 | 65 | norm = FeaturewiseLinearBounded(feature_shifts=[3, 5]) 66 | norm.derive_dataset_features(data) 67 | with pytest.raises(AssertionError): 68 | norm(data) 69 | 70 | norm = FeaturewiseLinearBounded(feature_maxes=[3, 5]) 71 | with pytest.raises(AssertionError): 72 | norm(data) 73 | 74 | norm = FeaturewiseLinearBounded(normalise_features=[True, False]) 75 | norm.derive_dataset_features(data) 76 | with pytest.raises(AssertionError): 77 | norm(data) 78 | 79 | 80 | @pytest.mark.parametrize( 81 | "data", 82 | [ 83 | test_data_1d, 84 | test_data_2d,
85 | test_data_3d, 86 | test_data_1d_posneg, 87 | test_data_2d_posneg, 88 | test_data_3d_posneg, 89 | ], 90 | ) 91 | def test_FeaturewiseLinearBoundedNones(data): 92 | norm = FeaturewiseLinearBounded(feature_norms=[None, 8, None], feature_shifts=[3, None, None]) 93 | norm.derive_dataset_features(data) 94 | 95 | normed = norm(data) 96 | assert normed.shape == data.shape 97 | assert np.all(normed[..., 0] == approx(data[..., 0] + 3)) 98 | assert np.max(np.abs(normed[..., 1])) == approx(8) 99 | assert np.all(normed[..., 2] == approx(data[..., 2])) 100 | 101 | unnormed = norm(normed, inverse=True) 102 | assert np.all(unnormed == approx(data)) 103 | 104 | 105 | @pytest.mark.parametrize( 106 | "data", 107 | [ 108 | test_data_1d, 109 | test_data_2d, 110 | test_data_3d, 111 | test_data_1d_posneg, 112 | test_data_2d_posneg, 113 | test_data_3d_posneg, 114 | ], 115 | ) 116 | def test_FeaturewiseLinearBoundedCustom(data): 117 | norm = FeaturewiseLinearBounded( 118 | feature_norms=[3, 8, -1], feature_shifts=[2, 0, 3], normalise_features=[True, False, True] 119 | ) 120 | norm.derive_dataset_features(data) 121 | 122 | normed = norm(data) 123 | assert normed.shape == data.shape 124 | assert np.all(normed[..., 0] == approx(data[..., 0] / np.max(np.abs(data[..., 0])) * 3 + 2)) 125 | assert np.all(normed[..., 1] == approx(data[..., 1]))  # feature 1 is not normalised 126 | assert np.all(normed[..., 2] == approx(data[..., 2] / np.max(np.abs(data[..., 2])) * (-1) + 3)) 127 | 128 | unnormed = norm(normed, inverse=True) 129 | assert np.all(unnormed == approx(data)) 130 | 131 | 132 | @pytest.mark.parametrize( 133 | "data", 134 | [ 135 | test_data_1d, 136 | test_data_2d, 137 | test_data_3d, 138 | test_data_1d_posneg, 139 | test_data_2d_posneg, 140 | test_data_3d_posneg, 141 | ], 142 | ) 143 | def test_FeaturewiseLinear(data): 144 | norm = FeaturewiseLinear( 145 | feature_scales=[5, 0.25, -4], 146 | feature_shifts=[2, -1, 3], 147 | normalise_features=[False, True, True], 148 | ) 149 | 150 | norm.derive_dataset_features(data) # should do nothing 151 | 152 | normed = norm(data) 153 | assert normed.shape == data.shape 154 | assert np.all(normed[..., 0] == approx(data[..., 0])) 155 | assert np.all(normed[..., 1] == approx((data[..., 1] - 1) * 0.25)) 156 | assert np.all(normed[..., 2] == approx((data[..., 2] + 3) * (-4))) 157 | 158 | unnormed = norm(normed, inverse=True) 159 | assert np.all(unnormed == approx(data)) 160 | 161 | 162 | @pytest.mark.parametrize( 163 | "data", 164 | [ 165 | test_data_2d, 166 | test_data_3d, 167 | test_data_2d_posneg, 168 | test_data_3d_posneg, 169 | ], 170 | ) 171 | def test_FeaturewiseLinearNormal(data): 172 | norm = FeaturewiseLinear(normal=True, normalise_features=[True, False, True]) 173 | norm.derive_dataset_features(data) 174 | 175 | normed = norm(data) 176 | assert normed.shape == data.shape 177 | assert np.mean(normed[..., 0]) == approx(0) 178 | assert np.std(normed[..., 0]) == approx(1) 179 | assert np.mean(normed[..., 2]) == approx(0) 180 | assert np.std(normed[..., 2]) == approx(1) 181 | assert np.all(normed[..., 1] == approx(data[..., 1])) 182 | 183 | unnormed = norm(normed, inverse=True) 184 | assert np.all(unnormed == approx(data)) 185 | 186 | 187 | @pytest.mark.parametrize( 188 | "data", 189 | [ 190 | test_data_1d, 191 | test_data_2d, 192 | test_data_3d, 193 | test_data_1d_posneg, 194 | test_data_2d_posneg, 195 | test_data_3d_posneg, 196 | ], 197 | ) 198 | def test_FeaturewiseLinearErrors(data): 199 | norm = FeaturewiseLinear(feature_scales=[3, 5]) 200 | norm.derive_dataset_features(data) 201 | with
pytest.raises(AssertionError): 202 | norm(data) 203 | 204 | norm = FeaturewiseLinear(feature_shifts=[3, 5]) 205 | norm.derive_dataset_features(data) 206 | with pytest.raises(AssertionError): 207 | norm(data) 208 | 209 | norm = FeaturewiseLinear(normalise_features=[False, False]) 210 | with pytest.raises(AssertionError): 211 | norm(data) 212 | -------------------------------------------------------------------------------- /tests/datasets/test_qgjets.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pytest 7 | from jetnet.datasets import QuarkGluon 8 | from pytest import approx 9 | 10 | # TODO: use checksum for downloaded files 11 | 12 | 13 | test_file_list_withbc = [ 14 | "QG_jets_withbc_0.npz", 15 | "QG_jets_withbc_1.npz", 16 | ] 17 | 18 | test_file_list_withoutbc = [ 19 | "QG_jets.npz", 20 | "QG_jets_1.npz", 21 | ] 22 | 23 | data_dir = Path("./datasets/qgjets") 24 | total_length = 200_000 25 | DataClass = QuarkGluon 26 | num_particles = 153 27 | 28 | 29 | @pytest.mark.slow 30 | @pytest.mark.parametrize( 31 | ("jet_types", "split", "expected_length", "class_id"), 32 | [ 33 | ("g", "all", total_length / 2, 0), 34 | ("q", "train", total_length * 0.7 / 2, 1), 35 | ("all", "valid", total_length * 0.15, None), 36 | ], 37 | ) 38 | @pytest.mark.parametrize("file_list", [test_file_list_withbc, test_file_list_withoutbc]) 39 | def test_getData(jet_types, split, expected_length, class_id, file_list): 40 | # test md5 checksum is working for one of the datasets 41 | if jet_types == "q" and file_list == test_file_list_withoutbc: 42 | file_path = data_dir / file_list[-1] 43 | 44 | if file_path.is_file(): 45 | file_path.unlink() 46 | 47 | # should raise a RuntimeError since the file doesn't exist 48 | with pytest.raises(RuntimeError): 49 | DataClass.getData(jet_types, data_dir, file_list=file_list, split=split) 50 | 51 | # write random data to file 52 | with file_path.open("wb") as f: 53 | f.write(np.random.bytes(100)) # noqa: NPY002 54 | 55 | # should raise a RuntimeError since the file exists but is incorrect 56 | with pytest.raises(RuntimeError): 57 | DataClass.getData(jet_types, data_dir, file_list=file_list, split=split) 58 | 59 | pf, jf = DataClass.getData(jet_types, data_dir, file_list=file_list, split=split, download=True) 60 | assert pf.shape == (expected_length, num_particles, 4) 61 | assert jf.shape == (expected_length, 1) 62 | if class_id is not None: 63 | assert np.all(jf[:, 0] == class_id) 64 | 65 | 66 | @pytest.mark.slow 67 | @pytest.mark.parametrize("file_list", [test_file_list_withbc, test_file_list_withoutbc]) 68 | def test_getDataFeatures(file_list): 69 | pf, jf = DataClass.getData(data_dir=data_dir, jet_features=None, file_list=file_list) 70 | assert pf.shape == (total_length, num_particles, 4) 71 | assert jf is None 72 | 73 | pf, jf = DataClass.getData( 74 | data_dir=data_dir, 75 | particle_features=["pdgid", "pt"], 76 | num_particles=30, 77 | file_list=file_list, 78 | ) 79 | assert pf.shape == (total_length, 30, 2) 80 | assert jf.shape == (total_length, 1) 81 | assert np.max(pf[:, :, 0]) == approx(2212) 82 | assert np.max(pf[:, :, 1]) == approx(550, rel=0.2) 83 | 84 | 85 | @pytest.mark.slow 86 | def test_getDataErrors(): 87 | with pytest.raises(AssertionError): 88 | DataClass.getData(jet_type="f") 89 | 90 | with pytest.raises(AssertionError): 91 | DataClass.getData(jet_type={"qcd", "f"}) 92 | 93 | with pytest.raises(AssertionError): 94 |
DataClass.getData(data_dir=data_dir, particle_features="foo") 95 | 96 | with pytest.raises(AssertionError): 97 | DataClass.getData(data_dir=data_dir, jet_features=["eta", "mask"]) 98 | -------------------------------------------------------------------------------- /tests/datasets/test_toptagging.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pytest 7 | from jetnet.datasets import TopTagging, normalisations 8 | from pytest import approx 9 | 10 | # TODO: use checksum for downloaded files 11 | 12 | 13 | data_dir = Path("./datasets/toptagging") 14 | DataClass = TopTagging 15 | 16 | valid_length = 403000 17 | num_particles = 200 18 | split = "valid" # for faster testing 19 | 20 | 21 | @pytest.mark.slow 22 | @pytest.mark.parametrize( 23 | ("jet_types", "split", "expected_length", "class_id"), 24 | [ 25 | # ("qcd", "all", 1008940, 0), 26 | # ("top", "all", 1009060, 1), 27 | ("qcd", "valid", 201503, 0), 28 | ("top", "valid", 201497, 1), 29 | # ("top", "test", 202086, 1), 30 | # ("top", "train", 605477, 1), 31 | # ("all", "train", 1211000, None), 32 | ("all", "valid", valid_length, None), 33 | # ("all", "test", 404000, None), 34 | # ("all", "all", total_length, None), 35 | ], 36 | ) 37 | def test_getData(jet_types, split, expected_length, class_id): 38 | # test md5 checksum is working for one of the datasets 39 | if jet_types == "top" and split == "valid": 40 | file_path = data_dir / "val.h5" 41 | 42 | if file_path.is_file(): 43 | file_path.unlink() 44 | 45 | # should raise a RuntimeError since the file doesn't exist 46 | with pytest.raises(RuntimeError): 47 | DataClass.getData(jet_types, data_dir, split=split) 48 | 49 | # write random data to file 50 | with file_path.open("wb") as f: 51 | f.write(np.random.bytes(100)) # noqa: NPY002 52 | 53 | # should raise a RuntimeError since the file exists but is incorrect 54 | with pytest.raises(RuntimeError): 55 | DataClass.getData(jet_types, data_dir, split=split) 56 | 57 | pf, jf = DataClass.getData(jet_types, data_dir, split=split, download=True) 58 | assert pf.shape == (expected_length, num_particles, 4) 59 | assert jf.shape == (expected_length, 5) 60 | if class_id is not None: 61 | assert np.all(jf[:, 0] == class_id) 62 | 63 | 64 | @pytest.mark.slow 65 | def test_getDataFeatures(): 66 | pf, jf = DataClass.getData(data_dir=data_dir, jet_features=["E", "type"], split=split) 67 | assert pf.shape == (valid_length, num_particles, 4) 68 | assert jf.shape == (valid_length, 2) 69 | assert np.max(jf[:, 0]) == approx(4000, rel=0.2) 70 | assert np.max(jf[:, 1]) == 1 71 | 72 | pf, jf = DataClass.getData(data_dir=data_dir, jet_features=None, split=split) 73 | assert pf.shape == (valid_length, num_particles, 4) 74 | assert jf is None 75 | 76 | pf, jf = DataClass.getData( 77 | data_dir=data_dir, particle_features=["px", "E"], num_particles=30, split=split 78 | ) 79 | assert pf.shape == (valid_length, 30, 2) 80 | assert jf.shape == (valid_length, 5) 81 | assert np.max(pf[:, :, 0]) == approx(700, rel=0.2) 82 | assert np.max(pf[:, :, 1]) == approx(2000, rel=0.2) 83 | 84 | 85 | @pytest.mark.slow 86 | def test_DataClassNormalisation(): 87 | X = DataClass( 88 | data_dir=data_dir, 89 | num_particles=num_particles, 90 | particle_normalisation=normalisations.FeaturewiseLinearBounded(), 91 | jet_normalisation=normalisations.FeaturewiseLinearBounded( 92 | normalise_features=[False, True, True, True, True] 93 | ), 94 | split=split, 95 | )
96 | 97 | assert np.all(np.max(np.abs(X.particle_data.reshape(-1, 4)), axis=0) == approx(1)) 98 | assert np.all(np.max(np.abs(X.jet_data[:, 1:].reshape(-1, 4)), axis=0) == approx(1)) 99 | assert np.max(X.jet_data[:, 0]) == 1 100 | 101 | 102 | @pytest.mark.slow 103 | def test_getDataErrors(): 104 | with pytest.raises(AssertionError): 105 | DataClass.getData(jet_type="f", split=split) 106 | 107 | with pytest.raises(AssertionError): 108 | DataClass.getData(jet_type={"qcd", "f"}, split=split) 109 | 110 | with pytest.raises(AssertionError): 111 | DataClass.getData(data_dir=data_dir, particle_features="foo", split=split) 112 | 113 | with pytest.raises(AssertionError): 114 | DataClass.getData(data_dir=data_dir, jet_features=["eta", "mask"], split=split) 115 | -------------------------------------------------------------------------------- /tests/datasets/test_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import pytest 5 | from jetnet.datasets import utils 6 | 7 | rng = np.random.default_rng(42) 8 | test_data_2d = rng.random((4, 3)) 9 | test_data_3d = rng.random((5, 4, 3)) 10 | 11 | 12 | @pytest.fixture 13 | def features_order(): 14 | return ["eta", "phi", "pt"] 15 | 16 | 17 | @pytest.mark.parametrize( 18 | ("data", "features", "expected"), 19 | [ 20 | (test_data_2d, ["phi", "pt"], test_data_2d[:, 1:]), 21 | (test_data_2d, ["phi"], test_data_2d[:, 1:2]), 22 | (test_data_2d, ["pt", "phi"], np.stack((test_data_2d[:, 2], test_data_2d[:, 1]), axis=-1)), 23 | (test_data_3d, ["phi", "pt"], test_data_3d[:, :, 1:]), 24 | (test_data_3d, ["phi"], test_data_3d[:, :, 1:2]), 25 | ( 26 | test_data_3d, 27 | ["pt", "phi"], 28 | np.stack((test_data_3d[:, :, 2], test_data_3d[:, :, 1]), axis=-1), 29 | ), 30 | ], 31 | ) 32 | def test_getOrderedFeatures(data, features, features_order, expected): 33 | assert np.all(utils.getOrderedFeatures(data, features, features_order) == expected) 34 | 35 | 36 | @pytest.mark.parametrize( 37 | ("data", "features"), 38 | [ 39 | (test_data_2d, ["phi", "pt", "foo"]), 40 | (test_data_2d, "foo"), 41 | ], 42 | ) 43 | def test_getOrderedFeaturesException(data, features, features_order): 44 | with pytest.raises(AssertionError): 45 | utils.getOrderedFeatures(data, features, features_order) 46 | 47 | 48 | @pytest.mark.parametrize( 49 | ("inputs", "to_set", "expected"), 50 | [ 51 | (["foo"], False, ["foo"]), 52 | (["foo"], True, {"foo"}), 53 | ([["foo"]], False, ["foo"]), 54 | ([{"foo"}], True, {"foo"}), 55 | (["foo", ["bar"]], False, [["foo"], ["bar"]]), 56 | (["foo", ["bar", "boom"]], False, [["foo"], ["bar", "boom"]]), 57 | ], 58 | ) 59 | def test_checkStrToList(inputs, to_set, expected): 60 | assert utils.checkStrToList(*inputs, to_set=to_set) == expected 61 | 62 | 63 | @pytest.mark.parametrize( 64 | ("inputs", "expected"), 65 | [([[]], False), ([None], False), ([[3]], True), ([[], None, [3]], [False, False, True])], 66 | ) 67 | def test_checkListNotEmpty(inputs, expected): 68 | assert utils.checkListNotEmpty(*inputs) == expected 69 | 70 | 71 | @pytest.mark.parametrize( 72 | ("inputs", "expected"), 73 | [([None, 3], 3), ([None], None), ([3, 5, None], 3)], 74 | ) 75 | def test_firstNotNoneElement(inputs, expected): 76 | assert utils.firstNotNoneElement(*inputs) == expected 77 | 78 | 79 | tvt_splits = ["train", "valid", "test"] 80 | tvt_splits_all = ["train", "valid", "test", "all"] 81 | 82 | 83 | @pytest.mark.parametrize( 84 | ("length", "split", "splits", "split_fraction",
"expected"), 85 | [ 86 | (100, "train", tvt_splits, [0.7, 0.15, 0.15], (0, 70)), 87 | (100, "valid", tvt_splits, [0.7, 0.15, 0.15], (70, 85)), 88 | (100, "test", tvt_splits, [0.7, 0.15, 0.15], (85, 100)), 89 | (100, "train", tvt_splits, [0.5, 0.2, 0.3], (0, 50)), 90 | (100, "valid", tvt_splits, [0.5, 0.2, 0.3], (50, 70)), 91 | (100, "test", tvt_splits, [0.5, 0.2, 0.3], (70, 100)), 92 | (10, "valid", tvt_splits, [0.7, 0.15, 0.15], (7, 8)), 93 | (10, "test", tvt_splits, [0.7, 0.15, 0.15], (8, 10)), 94 | (100, "valid", tvt_splits_all, [0.7, 0.15, 0.15], (70, 85)), 95 | (100, "all", tvt_splits_all, [0.7, 0.15, 0.15], (0, 100)), 96 | (100, "all", tvt_splits_all, [0.7, 0.15, 0.2], (0, 100)), 97 | ], 98 | ) 99 | def test_getSplitting(length, split, splits, split_fraction, expected): 100 | assert utils.getSplitting(length, split, splits, split_fraction) == expected 101 | -------------------------------------------------------------------------------- /tests/evaluation/test_gen_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import pytest 5 | from jetnet import evaluation 6 | from pytest import approx 7 | 8 | test_zeros = np.zeros((50_000, 2)) 9 | test_ones = np.ones((50_000, 2)) 10 | test_twos = np.ones((50_000, 2)) * 2 11 | 12 | 13 | def test_fpd(): 14 | val, err = evaluation.fpd(test_zeros, test_zeros) 15 | assert val == approx(0, abs=0.01) 16 | assert err < 1e-3 17 | 18 | val, err = evaluation.fpd(test_twos, test_zeros) 19 | assert val == approx(2, rel=0.01) # 1^2 + 1^2 20 | assert err < 1e-3 21 | 22 | # test normalization 23 | val, err = evaluation.fpd(test_zeros, test_zeros, normalise=False) # should have no effect 24 | assert val == approx(0, abs=0.01) 25 | assert err < 1e-3 26 | 27 | val, err = evaluation.fpd(test_twos, test_zeros, normalise=False) 28 | assert val == approx(8, rel=0.01) # 2^2 + 2^2 29 | assert err < 1e-3 30 | 31 | 32 | @pytest.mark.parametrize("num_threads", [None, 2]) # test numba parallelization 33 | def test_kpd(num_threads): 34 | assert evaluation.kpd(test_zeros, test_zeros, num_threads=num_threads) == approx([0, 0]) 35 | assert evaluation.kpd(test_twos, test_zeros, num_threads=num_threads) == approx([15, 0]) 36 | 37 | # test normalization 38 | assert evaluation.kpd( 39 | test_zeros, test_zeros, normalise=False, num_threads=num_threads 40 | ) == approx([0, 0]) 41 | assert evaluation.kpd( 42 | test_twos, test_zeros, normalise=False, num_threads=num_threads 43 | ) == approx([624, 0]) 44 | -------------------------------------------------------------------------------- /tests/utils/test_image.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import pytest 5 | from jetnet.utils import to_image 6 | 7 | # 1 jet with 3 test particles 8 | test_data_2d = np.zeros((3, 3)) 9 | test_data_2d[:, 0] = np.array([-0.5, 0, 0.5]) # eta 10 | test_data_2d[:, 1] = np.array([-0.5, 0, 0.5]) # phi 11 | test_data_2d[:, 2] = np.array([1, 1, 1]) # pt 12 | expected_2d = np.identity(3) 13 | 14 | # 2 jets 15 | test_data_3d = np.stack([test_data_2d] * 2) 16 | expected_3d = np.stack([expected_2d] * 2) 17 | 18 | 19 | @pytest.mark.parametrize( 20 | ("data", "expected"), [(test_data_2d, expected_2d), (test_data_3d, expected_3d)] 21 | ) 22 | def test_to_image(data, expected): 23 | jet_image = to_image(data, im_size=3, maxR=1.0) 24 | assert len(jet_image.shape) == len(data.shape), "wrong jet
image shape" 25 | assert jet_image.shape[-2:] == (3, 3), "wrong jet image size" 26 | np.testing.assert_allclose(jet_image, expected) 27 | --------------------------------------------------------------------------------
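
Usage sketch (an illustrative addition, not a file in this repository): the snippet below chains together the public API exercised by the tests above, JetNet.getData for downloading and loading jets and evaluation.fpd / evaluation.kpd for the distance metrics. Function names, argument names, and returned shapes are taken from the test suite; the jet types, the download directory, and the comparison of two halves of the real data are assumptions made purely for the example.

# Minimal sketch, assuming `jetnet` is installed and ./datasets/jetnet is writable.
from jetnet import evaluation
from jetnet.datasets import JetNet

# Returns particle features of shape (N, 30, 4) and jet features of shape (N, 5);
# column 0 of the jet features is the class id, as asserted in test_jetnet.py.
particle_data, jet_data = JetNet.getData(
    jet_type=["g", "q"],
    data_dir="./datasets/jetnet",
    num_particles=30,
    download=True,
)

# Compare two disjoint halves of the real jet features (dropping the class id);
# in a real study, the second half would be replaced by generated jets.
half = len(jet_data) // 2
real, other = jet_data[:half, 1:], jet_data[half:, 1:]

fpd_val, fpd_err = evaluation.fpd(real, other)
kpd_val, kpd_err = evaluation.kpd(real, other)
print(f"FPD = {fpd_val:.4f} +/- {fpd_err:.4f}")
print(f"KPD = {kpd_val:.4f} +/- {kpd_err:.4f}")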