├── .flake8 ├── .github ├── CONTRIBUTING.md ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── animation.gif ├── applying_neos.md ├── binder └── postBuild ├── demo.ipynb ├── docs ├── Makefile ├── conf.py ├── index.rst └── make.bat ├── examples ├── ap000.gif ├── binning.ipynb ├── cuts.ipynb ├── diffable_histograms.ipynb ├── float.png ├── requirements.txt ├── simple-analysis-optimisation.ipynb ├── withbinfloat.png └── withnobinfloat.png ├── nbs ├── assets │ ├── 2_model_demo.gif │ ├── Screenshot 2020-07-27 at 11.50.53.png │ ├── anaflow.png │ ├── anaflowgrad.png │ ├── cc.png │ ├── cern.jpg │ ├── cut.gif │ ├── eu.png │ ├── fixed.png │ ├── free.png │ ├── goodkde.gif │ ├── gradhep.png │ ├── insights.jpg │ ├── jax.png │ ├── kde_bins.gif │ ├── kde_interesting.gif │ ├── kde_pyhf_animation.gif │ ├── kde_sigmoif.gif │ ├── kdesig.gif │ ├── kdestud.png │ ├── lu.png │ ├── multi_bin.gif │ ├── neoflow.png │ ├── neos-slide.png │ ├── neos.png │ ├── neos_banner.png │ ├── neos_logo.png │ ├── noaxis.gif │ ├── pyhf-logo.png │ ├── pyhf.png │ ├── pyhf_3.gif │ ├── soft2.gif │ ├── softmax_animation.gif │ ├── softmax_animation_2.gif │ ├── softmax_pyhf_animation.gif │ └── training.gif └── talk_slides.ipynb ├── noxfile.py ├── pyproject.toml ├── random.pdf ├── setup.cfg ├── setup.py ├── src └── neos │ ├── __init__.py │ ├── losses.py │ ├── py.typed │ └── top_level.py └── tests └── test_package.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | extend-ignore = E203, E501, E722, B950 3 | select = C,E,F,W,T,B,B9,I 4 | per-file-ignores = 5 | tests/*: T 6 | noxfile.py: T 7 | src/neos/pipeline.py: T 8 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | See the 
[Scikit-HEP Developer introduction][skhep-dev-intro] for a 2 | detailed description of best practices for developing Scikit-HEP packages. 3 | 4 | [skhep-dev-intro]: https://scikit-hep.org/developer/intro 5 | 6 | # Quick development 7 | 8 | The fastest way to start with development is to use nox. If you don't have nox, 9 | you can use `pipx run nox` to run it without installing, or `pipx install nox`. 10 | If you don't have pipx (pip for applications), then you can install it with 11 | `pip install pipx` (the only case where installing an application with regular 12 | pip is reasonable). If you use macOS, then pipx and nox are both in brew, use 13 | `brew install pipx nox`. 14 | 15 | To use, run `nox`. This will lint and test using every installed version of 16 | Python on your system, skipping ones that are not installed. You can also run 17 | specific jobs: 18 | 19 | ```console 20 | $ nox -s lint # Lint only 21 | $ nox -s tests-3.9 # Python 3.9 tests only 22 | $ nox -s docs -- serve # Build and serve the docs 23 | $ nox -s build # Make an SDist and wheel 24 | ``` 25 | 26 | Nox handles everything for you, including setting up a temporary virtual 27 | environment for each run.
28 | 29 | 30 | # Setting up a development environment manually 31 | 32 | You can set up a development environment by running: 33 | 34 | ```bash 35 | python3 -m venv .venv 36 | source ./.venv/bin/activate 37 | pip install -v -e .[dev] 38 | ``` 39 | 40 | If you have the [Python Launcher for Unix](https://github.com/brettcannon/python-launcher), 41 | you can instead do: 42 | 43 | ```bash 44 | py -m venv .venv 45 | py -m pip install -v -e .[dev] 46 | ``` 47 | 48 | # Post setup 49 | 50 | You should prepare pre-commit, which will help you by checking that commits 51 | pass required checks: 52 | 53 | ```bash 54 | pip install pre-commit # or brew install pre-commit on macOS 55 | pre-commit install # Will install a pre-commit hook into the git repo 56 | ``` 57 | 58 | You can also/alternatively run `pre-commit run` (changes only) or `pre-commit 59 | run --all-files` to check even without installing the hook. 60 | 61 | # Testing 62 | 63 | Use pytest to run the unit checks: 64 | 65 | ```bash 66 | pytest 67 | ``` 68 | 69 | # Building docs 70 | 71 | You can build the docs using: 72 | 73 | ```bash 74 | nox -s docs 75 | ``` 76 | 77 | You can see a preview with: 78 | 79 | ```bash 80 | nox -s docs -- serve 81 | ``` 82 | 83 | # Pre-commit 84 | 85 | This project uses pre-commit for all style checking. While you can run it with 86 | nox, this is such an important tool that it deserves to be installed on its 87 | own. Install pre-commit and run: 88 | 89 | ```bash 90 | pre-commit run -a 91 | ``` 92 | 93 | to check all files.
94 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Maintain dependencies for GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "daily" 8 | ignore: 9 | # Official actions have moving tags like v1 10 | # that are used, so they don't need updates here 11 | - dependency-name: "actions/*" 12 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | push: 7 | branches: 8 | - master 9 | - main 10 | - develop 11 | release: 12 | types: 13 | - published 14 | 15 | jobs: 16 | dist: 17 | name: Distribution build 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v1 22 | 23 | - name: Build sdist and wheel 24 | run: pipx run build 25 | 26 | - uses: actions/upload-artifact@v2 27 | with: 28 | path: dist 29 | 30 | - name: Check products 31 | run: pipx run twine check dist/* 32 | 33 | - uses: pypa/gh-action-pypi-publish@v1.4.2 34 | if: github.event_name == 'release' && github.event.action == 'published' 35 | with: 36 | user: __token__ 37 | # Remember to generate this and set it in "GitHub Secrets" 38 | password: ${{ secrets.pypi_password }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | 
share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # setuptools_scm 141 | src/*/_version.py 142 | 143 | # jupyter 144 | t.ipynb 145 | Untitled.ipynb 146 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.3.0 4 | hooks: 5 | - id: black-jupyter 6 | 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.3.0 9 | hooks: 10 | - id: check-case-conflict 11 | - id: check-merge-conflict 12 | - id: check-symlinks 13 | - id: check-yaml 14 | - id: debug-statements 15 | - id: end-of-file-fixer 16 | - id: mixed-line-ending 17 | - id: requirements-txt-fixer 18 | - id: trailing-whitespace 19 | 20 | - repo: https://github.com/pre-commit/pygrep-hooks 21 | rev: v1.9.0 22 | hooks: 23 | - id: python-check-blanket-noqa 24 | - id: python-check-blanket-type-ignore 25 | - id: python-no-log-warn 26 | - id: python-no-eval 27 | - id: python-use-type-annotations 28 | - id: rst-backticks 29 | - id: rst-directive-colons 30 | - id: rst-inline-touching-normal 31 | 32 | - repo: https://github.com/PyCQA/isort 33 | rev: 5.10.1 34 | hooks: 35 | - id: isort 36 | 37 | - repo: https://github.com/asottile/pyupgrade 38 | rev: v2.34.0 39 | 
hooks: 40 | - id: pyupgrade 41 | args: ["--py36-plus"] 42 | 43 | - repo: https://github.com/asottile/setup-cfg-fmt 44 | rev: v1.20.1 45 | hooks: 46 | - id: setup-cfg-fmt 47 | 48 | - repo: https://github.com/hadialqattan/pycln 49 | rev: v1.3.5 50 | hooks: 51 | - id: pycln 52 | args: [--config=pyproject.toml] 53 | 54 | - repo: https://github.com/asottile/yesqa 55 | rev: v1.3.0 56 | hooks: 57 | - id: yesqa 58 | exclude: docs/conf.py 59 | additional_dependencies: &flake8_dependencies 60 | - flake8-bugbear 61 | - flake8-print 62 | 63 | - repo: https://github.com/pycqa/flake8 64 | rev: 4.0.1 65 | hooks: 66 | - id: flake8 67 | exclude: docs/conf.py 68 | additional_dependencies: *flake8_dependencies 69 | 70 | # - repo: https://github.com/pre-commit/mirrors-mypy 71 | # rev: v0.910-1 72 | # hooks: 73 | # - id: mypy 74 | # files: src 75 | 76 | 77 | - repo: https://github.com/shellcheck-py/shellcheck-py 78 | rev: v0.8.0.4 79 | hooks: 80 | - id: shellcheck 81 | 82 | 83 | - repo: https://github.com/mgedmin/check-manifest 84 | rev: "0.48" 85 | hooks: 86 | - id: check-manifest 87 | stages: [manual] 88 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/conf.py 11 | 12 | # Include PDF and ePub 13 | formats: all 14 | 15 | python: 16 | version: 3.8 17 | install: 18 | - method: pip 19 | path: . 20 | extra_requirements: 21 | - docs 22 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "Thanks for being interested in neos! 
If you use this software in a project, please cite it as below." 3 | authors: 4 | - family-names: Simpson 5 | given-names: Nathan 6 | orcid: https://orcid.org/0000-0003-4188-829 7 | - family-names: Heinrich 8 | given-names: Lukas 9 | orcid: https://orcid.org/0000-0002-4048-7584 10 | title: "neos: version 0.2.0" 11 | version: v0.2.0 12 | date-released: 2021-01-12 13 | url: "https://github.com/gradhep/neos" 14 | doi: 10.5281/zenodo.6351423 15 | references: 16 | - type: article 17 | authors: 18 | - family-names: Simpson 19 | given-names: Nathan 20 | orcid: https://orcid.org/0000-0003-4188-829 21 | - family-names: "Heinrich" 22 | given-names: "Lukas" 23 | orcid: "https://orcid.org/0000-0002-4048-7584" 24 | affiliation: "TU Munich" 25 | title: "neos: End-to-End-Optimised Summary Statistics for High Energy Physics" 26 | doi: 10.48550/arXiv.2203.05570 27 | url: "https://doi.org/10.48550/arXiv.2203.05570" 28 | year: 2022 29 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | See the [Scikit-HEP Developer introduction][skhep-dev-intro] for a 2 | detailed description of best practices for developing packages. 3 | 4 | [skhep-dev-intro]: https://scikit-hep.org/developer/intro 5 | 6 | # Setting up a development environment 7 | 8 | You can set up a development environment by running: 9 | 10 | ```bash 11 | python3 -m venv .env 12 | source ./.env/bin/activate 13 | pip install -v -e .[all] 14 | ``` 15 | 16 | # Post setup 17 | 18 | You should prepare pre-commit, which will help you by checking that commits 19 | pass required checks: 20 | 21 | ```bash 22 | pip install pre-commit # or brew install pre-commit on macOS 23 | pre-commit install # Will install a pre-commit hook into the git repo 24 | ``` 25 | 26 | You can also/alternatively run `pre-commit run` (changes only) or `pre-commit 27 | run --all-files` to check even without installing the hook. 
28 | 29 | # Testing 30 | 31 | Use PyTest to run the unit checks: 32 | 33 | ```bash 34 | pytest 35 | ``` 36 | 37 | # Building docs 38 | 39 | You can build the docs with Sphinx. 40 | 41 | 42 | From inside your environment with the docs extra installed, run: 43 | 44 | ```bash 45 | sphinx-build -M html docs docs/_build 46 | ``` 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Nathan Simpson. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the vector package developers nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune * 2 | graft src 3 | graft tests 4 | 5 | include LICENSE README.md pyproject.toml setup.py setup.cfg 6 | include *.gif 7 | include *.ipynb 8 | include *.md 9 | include *.pdf 10 | global-exclude __pycache__ *.py[cod] .* 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | neos logo
3 | neural end-to-end-optimised summary statistics 4 |
5 | arxiv.org/abs/2203.05570 6 |
7 | 8 | GitHub Workflow Status 9 | 10 | 11 | Zenodo DOI 12 | 13 | 14 | Binder 15 | 16 |

17 | 18 | 19 | 20 | [actions-badge]: https://github.com/gradhep/neos/workflows/CI/badge.svg 21 | [actions-link]: https://github.com/gradhep/neos/actions 22 | [black-badge]: https://img.shields.io/badge/code%20style-black-000000.svg 23 | [black-link]: https://github.com/psf/black 24 | [conda-badge]: https://img.shields.io/conda/vn/conda-forge/neos 25 | [conda-link]: https://github.com/conda-forge/neos-feedstock 26 | [codecov-badge]: https://app.codecov.io/gh/gradhep/neos/branch/main/graph/badge.svg 27 | [codecov-link]: https://app.codecov.io/gh/gradhep/neos 28 | [github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github 29 | [github-discussions-link]: https://github.com/gradhep/neos/discussions 30 | [gitter-badge]: https://badges.gitter.im/https://github.com/gradhep/neos/community.svg 31 | [gitter-link]: https://gitter.im/https://github.com/gradhep/neos/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge 32 | [pypi-link]: https://pypi.org/project/neos/ 33 | [pypi-platforms]: https://img.shields.io/pypi/pyversions/neos 34 | [pypi-version]: https://badge.fury.io/py/neos.svg 35 | [rtd-badge]: https://readthedocs.org/projects/neos/badge/?version=latest 36 | [rtd-link]: https://neos.readthedocs.io/en/latest/?badge=latest 37 | [sk-badge]: https://scikit-hep.org/assets/images/Scikit--HEP-Project-blue.svg 38 | 39 | ![](animation.gif) 40 | 41 | ## About 42 | 43 | Leverages the shoulders of giants ([`jax`](https://github.com/google/jax/) and [`pyhf`](https://github.com/scikit-hep/pyhf)) to differentiate through a high-energy physics analysis workflow, including the construction of the frequentist profile likelihood. 44 | 45 | If you're more of a video person, see [this talk](https://www.youtube.com/watch?v=3P4ZDkbleKs) given by [Nathan](https://github.com/phinate) on the broader topic of differentiable programming in high-energy physics, which also covers `neos`. 
46 | 47 | ## You want to apply this to your analysis? 48 | 49 | Some things need to happen first. [Click here for more info -- I wrote them up!](applying_neos.md) 50 | 51 | ## Have questions? 52 | 53 | Do you want to chat about `neos`? Join us in Mattermost: [![Mattermost](https://img.shields.io/badge/chat-mattermost-blue.svg)](https://mattermost.web.cern.ch/signup_user_complete/?id=zf7w5rb1miy85xsfjqm68q9hwr&md=link&sbr=su) 54 | 55 | ## Cite 56 | 57 | Please cite our newly released paper: 58 | 59 | ``` 60 | @article{neos, 61 | Author = {Nathan Simpson and Lukas Heinrich}, 62 | Title = {neos: End-to-End-Optimised Summary Statistics for High Energy Physics}, 63 | Year = {2022}, 64 | Eprint = {arXiv:2203.05570}, 65 | doi = {10.48550/arXiv.2203.05570}, 66 | url = {https://doi.org/10.48550/arXiv.2203.05570} 67 | } 68 | ``` 69 | 70 | 71 | ## Example usage -- train a neural network to optimize an expected p-value 72 | 73 | ### setup 74 | In a python 3 environment, run the following: 75 | ``` 76 | pip install --upgrade pip setuptools wheel 77 | pip install neos 78 | pip install git+http://github.com/scikit-hep/pyhf.git@make_difffable_model_ctor 79 | ``` 80 | 81 | With this, you should be able to run the demo notebook [demo.ipynb](demo.ipynb) on your pc :) 82 | 83 | This workflow is as follows: 84 | - From a set of normal distributions with different means, we'll generate four blobs of `(x,y)` points, corresponding to a signal process, a nominal background process, and two variations of the background from varying the background distribution's mean up and down. 85 | - We'll then feed these points into the previously defined neural network for each blob, and construct a histogram of the output using kernel density estimation. The difference between the two background variations is used as a systematic uncertainty on the nominal background. 
86 | - We can then leverage the magic of `pyhf` to construct an [event-counting statistical model](https://scikit-hep.org/pyhf/intro.html#histfactory) from the histogram yields. 87 | - Finally, we calculate the p-value of a test between the nominal signal and background-only hypotheses. This uses the familiar [profile likelihood-based test statistic](https://arxiv.org/abs/1007.1727). 88 | 89 | This counts as one forward pass of the workflow -- we then optimize the neural network by gradient descent, backpropagating through the whole analysis! 90 | 91 | 92 | 93 | ## Thanks 94 | 95 | A big thanks to the teams behind [`jax`](https://github.com/google/jax/), [`fax`](https://github.com/gehring/fax), [`jaxopt`](http://github.com/google/jaxopt) and [`pyhf`](https://github.com/scikit-hep/pyhf) for their software and support. 96 | -------------------------------------------------------------------------------- /animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/animation.gif -------------------------------------------------------------------------------- /applying_neos.md: -------------------------------------------------------------------------------- 1 | # The current state of `neos` (written by Nathan) 2 | 3 | At the moment, **there is no working version of the full `neos` pipeline with `pyhf`**. The main reason is that due to a number of errors and bugs when trying to make `pyhf` model building properly use `jax`, it was rabbit hole after rabbit hole of trying to fix things, and it's become clear that there's too much to change to hack it together in a way that makes sense. 4 | 5 | 6 | ## i want to use `neos` with `pyhf` now 7 | 8 | If you want to do something *right now*, there's a pretty simple but time-consuming solution to this: write a new likelihood function in `jax` that can be used in the `neos` pipeline. 
This is a non-trivial task, and will require a significant amount of time to complete. 9 | 10 | If you're doing unbinned fits, it's likely that you would always have had to do this (there are no HEP-driven unbinned likelihoods in JAX yet, unless you can somehow make [`zfit`](https://github.com/zfit/zfit) work, but this would be another `pyhf`-scale operation). I'd recommend you make your model class a JAX PyTree by inheriting from `equinox.Module` -- this will make it work out of the box with `relaxed` (see below for more on this). 11 | 12 | If you need some inspiration for a HistFactory-based solution, there are a couple of places I can point you to: 13 | 14 | - the [`dummy_pyhf` file in `relaxed`](https://github.com/gradhep/relaxed/blob/main/tests/dummy_pyhf.py) has a working example of a HistFactory-based likelihood function that can be used roughly interchangeably with the `pyhf` one for simple likelihoods with one (bin-uncorrelated) background systematic. It's not perfect, but it could serve as a starting point to try testing your pipelines. 15 | - [`dilax`](https://github.com/pfackeldey/dilax) is a slightly more mature version of this, but does not use the same naming conventions as `pyhf`. It's a nice first attempt at what could be the right way to go about this in future. 16 | 17 | ## long-term plans 18 | I'm not working very actively in the field right now, but I've tried my best to indicate the direction I think things should go in [this discussion on the `pyhf` repo](https://github.com/scikit-hep/pyhf/discussions/2196) -- if this is important to you, maybe leave a reaction or a comment there! The key ingredient is PyTrees (read the discussion for more details). If you're interested in working on this, I'd be happy to help out, but I don't have the time (or the HistFactory expertise) to do it fully myself. I think it's a really important thing to do, though -- probably essential if this is going to be truly used in HEP!
19 | 20 | I've just released [`relaxed` v0.3.0](https://github.com/gradhep/relaxed), which has been tested with dummy PyTree models to work. It's designed for a `pyhf` that doesn't exist yet, and may never exist at all. But it will work with any PyTree model, so if you can write a PyTree model, you can use `relaxed` to do your fits, then backpropagate through them. 21 | 22 | ## reaching out 23 | 24 | If you're interested in working on this, please reach out to me through [Mattermost](https://mattermost.web.cern.ch/signup_user_complete/?id=zf7w5rb1miy85xsfjqm68q9hwr&md=link&sbr=su), or by email. 25 | 26 | -------------------------------------------------------------------------------- /binder/postBuild: -------------------------------------------------------------------------------- 1 | python -m pip install --upgrade . 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # Warning: do not change the path here. To use autodoc, you need to install the 8 | # package first. 9 | 10 | from typing import List 11 | 12 | # -- Project information ----------------------------------------------------- 13 | 14 | project = "neos" 15 | copyright = "2021, Nathan Simpson" 16 | author = "Nathan Simpson" 17 | 18 | 19 | # -- General configuration --------------------------------------------------- 20 | 21 | # Add any Sphinx extension module names here, as strings. They can be 22 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 23 | # ones. 24 | extensions = [ 25 | "myst_parser", 26 | "sphinx.ext.autodoc", 27 | "sphinx.ext.mathjax", 28 | "sphinx.ext.napoleon", 29 | "sphinx_copybutton", 30 | ] 31 | 32 | # Add any paths that contain templates here, relative to this directory. 33 | templates_path = [] 34 | 35 | # List of patterns, relative to source directory, that match files and 36 | # directories to ignore when looking for source files. 37 | # This pattern also affects html_static_path and html_extra_path. 38 | exclude_patterns = ["_build", "**.ipynb_checkpoints", "Thumbs.db", ".DS_Store", ".env"] 39 | 40 | 41 | # -- Options for HTML output ------------------------------------------------- 42 | 43 | # The theme to use for HTML and HTML Help pages. See the documentation for 44 | # a list of builtin themes. 
45 | # 46 | html_theme = "sphinx_book_theme" 47 | 48 | html_title = f"{project}" 49 | 50 | html_baseurl = "https://neos.readthedocs.io/en/latest/" 51 | 52 | html_theme_options = { 53 | "home_page_in_toc": True, 54 | "repository_url": "https://github.com/gradhep/neos", 55 | "use_repository_button": True, 56 | "use_issues_button": True, 57 | "use_edit_page_button": True, 58 | } 59 | 60 | # Add any paths that contain custom static files (such as style sheets) here, 61 | # relative to this directory. They are copied after the builtin static files, 62 | # so a file named "default.css" will overwrite the builtin "default.css". 63 | html_static_path: List[str] = [] 64 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | 2 | Welcome to documentation! 3 | ========================= 4 | 5 | 6 | Introduction 7 | ------------ 8 | 9 | This should be updated! 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :titlesonly: 14 | :caption: Contents 15 | :glob: 16 | 17 | 18 | 19 | Indices and tables 20 | ================== 21 | 22 | * :ref:`genindex` 23 | * :ref:`modindex` 24 | * :ref:`search` 25 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 
22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /examples/ap000.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/examples/ap000.gif -------------------------------------------------------------------------------- /examples/float.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/examples/float.png -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | celluloid 2 | git+http://github.com/scikit-hep/pyhf.git@make_difffable_model_ctor 3 | plothelp 4 | -------------------------------------------------------------------------------- /examples/simple-analysis-optimisation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import jax\n", 10 | "import jax.numpy as jnp\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import optax\n", 13 | "from jaxopt import OptaxSolver\n", 14 | "import relaxed\n", 15 | "from celluloid import Camera\n", 16 | "from functools import partial\n", 17 | "import matplotlib.lines as mlines\n", 18 | "\n", 19 | "# matplotlib settings\n", 20 | "plt.rc(\"figure\", figsize=(6, 3), dpi=220, facecolor=\"w\")\n", 21 | "plt.rc(\"legend\", fontsize=6)" 
22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Optimising a simple one-bin analysis with `relaxed`\n", 29 | "\n", 30 | "Let's define an analysis with a predicted number of signal and background events, with some uncertainty on the background estimate. We'll abstract the analysis configuration into a single parameter $\\phi$ like so:\n", 31 | "\n", 32 | "$$s = 15 + \\phi $$\n", 33 | "$$b = 45 - 2 \\phi $$\n", 34 | "$$\\sigma_b = 0.5 + 0.1\\,\\phi^2 $$\n", 35 | "\n", 36 | "Note that $s$ grows like $+\\phi$ while $b$ shrinks like $-2\\phi$, so increasing $\\phi$ increases the signal/background ratio. However, our uncertainty scales like $\\phi^2$, so we also sacrifice certainty in the background count as we do that. This kind of tradeoff between $s/b$ ratio and uncertainty is important for the discovery of a new signal, so we can't get away with optimising $s/b$ alone.\n", 37 | "\n", 38 | "To illustrate this, we'll plot the discovery $p$-value for this model with and without uncertainty."
39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "# model definition\n", 48 | "def yields(phi, uncertainty=True):\n", 49 | " s = 15 + phi\n", 50 | " b = 45 - 2 * phi\n", 51 | " db = (\n", 52 | " 0.5 + 0.1 * phi**2 if uncertainty else jnp.zeros_like(phi) + 0.001\n", 53 | " ) # small enough to be negligible\n", 54 | " return jnp.asarray([s]), jnp.asarray([b]), jnp.asarray([db])\n", 55 | "\n", 56 | "\n", 57 | "# our analysis pipeline, from phi to p-value\n", 58 | "def pipeline(phi, return_yields=False, uncertainty=True):\n", 59 | " y = yields(phi, uncertainty=uncertainty)\n", 60 | " # use a dummy version of pyhf for simplicity + compatibility with jax\n", 61 | " model = relaxed.dummy_pyhf.uncorrelated_background(*y)\n", 62 | " nominal_pars = jnp.array([1.0, 1.0])\n", 63 | " data = model.expected_data(nominal_pars) # we expect the nominal model\n", 64 | " # do the hypothesis test (and fit model pars with gradient descent)\n", 65 | " pvalue = relaxed.infer.hypotest(\n", 66 | " 0.0, # value of mu for the alternative hypothesis\n", 67 | " data,\n", 68 | " model,\n", 69 | " test_stat=\"q0\", # discovery significance test\n", 70 | " lr=1e-3,\n", 71 | " expected_pars=nominal_pars, # optionally providing MLE pars in advance\n", 72 | " )\n", 73 | " if return_yields:\n", 74 | " return pvalue, y\n", 75 | " else:\n", 76 | " return pvalue\n", 77 | "\n", 78 | "\n", 79 | "# calculate p-values for a range of phi values\n", 80 | "phis = jnp.linspace(0, 10, 100)\n", 81 | "\n", 82 | "# with uncertainty\n", 83 | "pipe = partial(pipeline, return_yields=True, uncertainty=True)\n", 84 | "pvals, ys = jax.vmap(pipe)(phis) # map over phi grid\n", 85 | "# without uncertainty\n", 86 | "pipe_no_uncertainty = partial(pipeline, uncertainty=False)\n", 87 | "pvals_no_uncertainty = jax.vmap(pipe_no_uncertainty)(phis)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": 
{}, 94 | "outputs": [], 95 | "source": [ 96 | "fig, axs = plt.subplots(2, 1, sharex=True)\n", 97 | "axs[0].plot(phis, pvals, label=\"with uncertainty\", color=\"C2\")\n", 98 | "axs[0].plot(phis, pvals_no_uncertainty, label=\"no uncertainty\", color=\"C4\")\n", 99 | "axs[0].set_ylabel(\"$p$-value\")\n", 100 | "# plot vertical dotted line at minimum of p-values + s/b\n", 101 | "best_phi = phis[jnp.argmin(pvals)]\n", 102 | "axs[0].axvline(x=best_phi, linestyle=\"dotted\", color=\"C2\", label=\"optimal p-value\")\n", 103 | "axs[0].axvline(\n", 104 | " x=phis[jnp.argmin(pvals_no_uncertainty)],\n", 105 | " linestyle=\"dotted\",\n", 106 | " color=\"C4\",\n", 107 | " label=r\"optimal $s/b$\",\n", 108 | ")\n", 109 | "axs[0].legend(loc=\"upper left\", ncol=2)\n", 110 | "s, b, db = ys\n", 111 | "s, b, db = s.ravel(), b.ravel(), db.ravel() # everything is [[x]] for pyhf\n", 112 | "axs[1].fill_between(phis, s + b, b, color=\"C9\", label=\"signal\")\n", 113 | "axs[1].fill_between(phis, b, color=\"C1\", label=\"background\")\n", 114 | "axs[1].fill_between(phis, b - db, b + db, facecolor=\"k\", alpha=0.2, label=r\"$\\sigma_b$\")\n", 115 | "axs[1].set_xlabel(\"$\\phi$\")\n", 116 | "axs[1].set_ylabel(\"yield\")\n", 117 | "axs[1].legend(loc=\"lower left\")\n", 118 | "plt.suptitle(\"Discovery p-values, with and without uncertainty\")\n", 119 | "plt.tight_layout()" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "Using gradient descent, we can optimise this analysis in an uncertainty-aware way by directly optimising $\\phi$ for the lowest discovery p-value. 
Here's how you do that:" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "# The fast way!\n", 136 | "# use the OptaxSolver wrapper from jaxopt to perform the minimisation\n", 137 | "# set a couple of tolerance kwargs to make sure we don't get stuck\n", 138 | "solver = OptaxSolver(pipeline, opt=optax.adam(1e-3), tol=1e-8, maxiter=10000)\n", 139 | "pars = 9.0 # random init\n", 140 | "result = solver.run(pars).params\n", 141 | "print(\n", 142 | " f\"our solution: phi={result:.5f}\\ntrue optimum: phi={phis[jnp.argmin(pvals)]:.5f}\\nbest s/b: phi=10\"\n", 143 | ")" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "# The longer way (but with plots)!\n", 153 | "pipe = partial(pipeline, return_yields=True, uncertainty=True)\n", 154 | "solver = OptaxSolver(pipe, opt=optax.adam(1e-1), has_aux=True)\n", 155 | "pars = 9.0\n", 156 | "state = solver.init_state(pars) # we're doing init, update steps instead of .run()\n", 157 | "\n", 158 | "plt.rc(\"figure\", figsize=(6, 3), dpi=220, facecolor=\"w\")\n", 159 | "plt.rc(\"legend\", fontsize=8)\n", 160 | "fig, axs = plt.subplots(1, 2)\n", 161 | "cam = Camera(fig)\n", 162 | "steps = 5 # increase me for better results! 
(100ish works well)\n", 163 | "for i in range(steps):\n", 164 | " pars, state = solver.update(pars, state)\n", 165 | " s, b, db = state.aux\n", 166 | " val = state.value\n", 167 | " ax = axs[0]\n", 168 | " cv = ax.plot(phis, pvals, c=\"C0\")\n", 169 | " cvs = ax.plot(phis, pvals_no_uncertainty, c=\"green\")\n", 170 | " current = ax.scatter(pars, val, c=\"C0\")\n", 171 | " ax.set_xlabel(r\"analysis config $\\phi$\")\n", 172 | " ax.set_ylabel(\"p-value\")\n", 173 | " ax.legend(\n", 174 | " [\n", 175 | " mlines.Line2D([], [], color=\"C0\"),\n", 176 | " mlines.Line2D([], [], color=\"green\"),\n", 177 | " current,\n", 178 | " ],\n", 179 | " [\"p-value (with uncert)\", \"p-value (without uncert)\", \"current value\"],\n", 180 | " frameon=False,\n", 181 | " )\n", 182 | " ax.text(0.3, 0.61, f\"step {i}\", transform=ax.transAxes)\n", 183 | " ax = axs[1]\n", 184 | " ax.set_ylim((0, 80))\n", 185 | " b1 = ax.bar(0.5, b, facecolor=\"C1\", label=\"b\")\n", 186 | " b2 = ax.bar(0.5, s, bottom=b, facecolor=\"C9\", label=\"s\")\n", 187 | " b3 = ax.bar(\n", 188 | " 0.5, db, bottom=b - db / 2, facecolor=\"k\", alpha=0.5, label=r\"$\\sigma_b$\"\n", 189 | " )\n", 190 | " ax.set_ylabel(\"yield\")\n", 191 | " ax.set_xticks([])\n", 192 | " ax.legend([b1, b2, b3], [\"b\", \"s\", r\"$\\sigma_b$\"], frameon=False)\n", 193 | " plt.tight_layout()\n", 194 | " cam.snap()\n", 195 | "\n", 196 | "ani = cam.animate()\n", 197 | "# uncomment this to save and view the animation!\n", 198 | "# ani.save(\"ap00.gif\", fps=9)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [] 207 | } 208 | ], 209 | "metadata": { 210 | "interpreter": { 211 | "hash": "22d6333b89854cd01c2018f3ca2f5a59a2cde2765fbca789ff36cfad48ca629b" 212 | }, 213 | "kernelspec": { 214 | "display_name": "Python 3.9.12 ('venv': venv)", 215 | "language": "python", 216 | "name": "python3" 217 | }, 218 | "language_info": { 219 | "codemirror_mode": { 220 | 
"name": "ipython", 221 | "version": 3 222 | }, 223 | "file_extension": ".py", 224 | "mimetype": "text/x-python", 225 | "name": "python", 226 | "nbconvert_exporter": "python", 227 | "pygments_lexer": "ipython3", 228 | "version": "3.9.12" 229 | }, 230 | "orig_nbformat": 4 231 | }, 232 | "nbformat": 4, 233 | "nbformat_minor": 2 234 | } 235 | -------------------------------------------------------------------------------- /examples/withbinfloat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/examples/withbinfloat.png -------------------------------------------------------------------------------- /examples/withnobinfloat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/examples/withnobinfloat.png -------------------------------------------------------------------------------- /nbs/assets/2_model_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/2_model_demo.gif -------------------------------------------------------------------------------- /nbs/assets/Screenshot 2020-07-27 at 11.50.53.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/Screenshot 2020-07-27 at 11.50.53.png -------------------------------------------------------------------------------- /nbs/assets/anaflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/anaflow.png -------------------------------------------------------------------------------- 
/nbs/assets/anaflowgrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/anaflowgrad.png -------------------------------------------------------------------------------- /nbs/assets/cc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/cc.png -------------------------------------------------------------------------------- /nbs/assets/cern.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/cern.jpg -------------------------------------------------------------------------------- /nbs/assets/cut.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/cut.gif -------------------------------------------------------------------------------- /nbs/assets/eu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/eu.png -------------------------------------------------------------------------------- /nbs/assets/fixed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/fixed.png -------------------------------------------------------------------------------- /nbs/assets/free.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/free.png 
-------------------------------------------------------------------------------- /nbs/assets/goodkde.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/goodkde.gif -------------------------------------------------------------------------------- /nbs/assets/gradhep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/gradhep.png -------------------------------------------------------------------------------- /nbs/assets/insights.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/insights.jpg -------------------------------------------------------------------------------- /nbs/assets/jax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/jax.png -------------------------------------------------------------------------------- /nbs/assets/kde_bins.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/kde_bins.gif -------------------------------------------------------------------------------- /nbs/assets/kde_interesting.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/kde_interesting.gif -------------------------------------------------------------------------------- /nbs/assets/kde_pyhf_animation.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/kde_pyhf_animation.gif -------------------------------------------------------------------------------- /nbs/assets/kde_sigmoif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/kde_sigmoif.gif -------------------------------------------------------------------------------- /nbs/assets/kdesig.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/kdesig.gif -------------------------------------------------------------------------------- /nbs/assets/kdestud.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/kdestud.png -------------------------------------------------------------------------------- /nbs/assets/lu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/lu.png -------------------------------------------------------------------------------- /nbs/assets/multi_bin.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/multi_bin.gif -------------------------------------------------------------------------------- /nbs/assets/neoflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/neoflow.png -------------------------------------------------------------------------------- 
/nbs/assets/neos-slide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/neos-slide.png -------------------------------------------------------------------------------- /nbs/assets/neos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/neos.png -------------------------------------------------------------------------------- /nbs/assets/neos_banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/neos_banner.png -------------------------------------------------------------------------------- /nbs/assets/neos_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/neos_logo.png -------------------------------------------------------------------------------- /nbs/assets/noaxis.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/noaxis.gif -------------------------------------------------------------------------------- /nbs/assets/pyhf-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/pyhf-logo.png -------------------------------------------------------------------------------- /nbs/assets/pyhf.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/pyhf.png -------------------------------------------------------------------------------- /nbs/assets/pyhf_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/pyhf_3.gif -------------------------------------------------------------------------------- /nbs/assets/soft2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/soft2.gif -------------------------------------------------------------------------------- /nbs/assets/softmax_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/softmax_animation.gif -------------------------------------------------------------------------------- /nbs/assets/softmax_animation_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/softmax_animation_2.gif -------------------------------------------------------------------------------- /nbs/assets/softmax_pyhf_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/softmax_pyhf_animation.gif -------------------------------------------------------------------------------- /nbs/assets/training.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/nbs/assets/training.gif 
-------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | 4 | import nox 5 | 6 | DIR = Path(__file__).parent.resolve() # absolute path of the directory holding this noxfile 7 | 8 | nox.options.sessions = ["lint", "tests"] # sessions run by a bare `nox`; docs and build must be selected with -s 9 | 10 | 11 | @nox.session 12 | def lint(session: nox.Session) -> None: 13 | """ 14 | Run the linter: every pre-commit hook on all files. Extra posargs are passed to pre-commit. 15 | """ 16 | session.install("pre-commit") 17 | session.run("pre-commit", "run", "--all-files", *session.posargs) 18 | 19 | 20 | @nox.session 21 | def tests(session: nox.Session) -> None: 22 | """ 23 | Run the unit and regular tests with pytest. Extra posargs are passed to pytest. 24 | """ 25 | session.install(".[test]") 26 | session.run("pytest", *session.posargs) 27 | 28 | 29 | @nox.session 30 | def docs(session: nox.Session) -> None: 31 | """ 32 | Build the HTML docs into docs/_build. Pass "serve" to also serve them on localhost:8000. 33 | """ 34 | 35 | session.install(".[docs]") 36 | session.chdir("docs") 37 | session.run("sphinx-build", "-M", "html", ".", "_build") # HTML output ends up in docs/_build/html 38 | 39 | if session.posargs: 40 | if "serve" in session.posargs: 41 | print("Launching docs at http://localhost:8000/ - use Ctrl-C to quit") 42 | session.run("python", "-m", "http.server", "8000", "-d", "_build/html") # runs until interrupted 43 | else: 44 | print("Unsupported argument to docs") # other posargs are only reported; the build above already ran 45 | 46 | 47 | @nox.session 48 | def build(session: nox.Session) -> None: 49 | """ 50 | Build an SDist and wheel with `python -m build`, after deleting any existing build/ directory.
51 | """ 52 | 53 | build_p = DIR.joinpath("build") # presumably cleared so stale artifacts can't leak into the new SDist/wheel 54 | if build_p.exists(): 55 | shutil.rmtree(build_p) 56 | 57 | session.install("build") 58 | session.run("python", "-m", "build") 59 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["wheel", "setuptools>=42", "setuptools_scm[toml]>=3.4"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | [tool.setuptools_scm] 7 | write_to = "src/neos/_version.py" 8 | 9 | 10 | [tool.pytest.ini_options] 11 | minversion = "6.0" 12 | addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] 13 | xfail_strict = true 14 | filterwarnings = [ 15 | "error", 16 | "ignore:the imp module is deprecated:DeprecationWarning", 17 | ] 18 | testpaths = [ 19 | "tests", 20 | ] 21 | 22 | 23 | [tool.pycln] 24 | all = true 25 | 26 | 27 | [tool.mypy] 28 | files = "src" 29 | python_version = "3.7" 30 | warn_unused_configs = true 31 | 32 | disallow_any_generics = true 33 | disallow_subclassing_any = true 34 | disallow_untyped_calls = true 35 | disallow_untyped_defs = true 36 | disallow_incomplete_defs = true 37 | check_untyped_defs = true 38 | disallow_untyped_decorators = true 39 | no_implicit_optional = true 40 | warn_redundant_casts = true 41 | warn_unused_ignores = true 42 | warn_return_any = true 43 | no_implicit_reexport = true 44 | strict_equality = true 45 | 46 | 47 | [tool.check-manifest] 48 | ignore = [ 49 | ".github/**", 50 | "docs/**", 51 | ".pre-commit-config.yaml", 52 | ".readthedocs.yml", 53 | "src/*/_version.py", 54 | "noxfile.py", 55 | ] 56 | 57 | [tool.isort] 58 | profile = "black" 59 | -------------------------------------------------------------------------------- /random.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/random.pdf
-------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = neos 3 | description = Upstream optimization of a neural net summary statistic with respect to downstream inference goals. 4 | long_description = file: README.md 5 | long_description_content_type = text/markdown 6 | url = https://github.com/gradhep/neos 7 | author = Nathan Simpson 8 | author_email = n.s@cern.ch 9 | maintainer = Nathan Simpson 10 | maintainer_email = n.s@cern.ch 11 | license = BSD-3-Clause 12 | license_file = LICENSE 13 | platforms = 14 | Any 15 | classifiers = 16 | Development Status :: 1 - Planning 17 | Intended Audience :: Developers 18 | Intended Audience :: Science/Research 19 | License :: OSI Approved :: BSD License 20 | Operating System :: OS Independent 21 | Programming Language :: Python 22 | Programming Language :: Python :: 3 23 | Programming Language :: Python :: 3 :: Only 24 | Programming Language :: Python :: 3.6 25 | Programming Language :: Python :: 3.7 26 | Programming Language :: Python :: 3.8 27 | Programming Language :: Python :: 3.9 28 | Programming Language :: Python :: 3.10 29 | Topic :: Scientific/Engineering 30 | project_urls = 31 | Documentation = https://neos.readthedocs.io/ 32 | Bug Tracker = https://github.com/gradhep/neos/issues 33 | Discussions = https://github.com/gradhep/neos/discussions 34 | Changelog = https://github.com/gradhep/neos/releases 35 | 36 | [options] 37 | packages = find: 38 | install_requires = 39 | celluloid 40 | matplotlib 41 | relaxed>=0.2.0 42 | scikit-learn 43 | typing-extensions>=3.7;python_version<'3.8' 44 | python_requires = >=3.6 45 | include_package_data = True 46 | package_dir = 47 | =src 48 | 49 | [options.packages.find] 50 | where = src 51 | 52 | [options.extras_require] 53 | dev = 54 | pytest>=6 55 | docs = 56 | Sphinx~=3.0 57 | myst-parser>=0.13 58 | sphinx-book-theme>=0.1.0 59 | 
sphinx-copybutton 60 | test = 61 | pytest>=6 62 | 63 | [flake8] 64 | ignore = E203, E231, E501, E722, W503, B950 65 | select = C,E,F,W,T,B,B9,I 66 | per-file-ignores = 67 | tests/*: T 68 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2021, Nathan Simpson 3 | # 4 | # Distributed under the 3-clause BSD license, see accompanying file LICENSE 5 | # or https://github.com/gradhep/neos for details. 6 | 7 | from setuptools import setup 8 | 9 | setup() 10 | 11 | # This file is optional, on recent versions of pip you can remove it and even 12 | # still get editable installs. 13 | -------------------------------------------------------------------------------- /src/neos/__init__.py: -------------------------------------------------------------------------------- 1 | from neos._version import version as __version__ 2 | 3 | __all__ = ( 4 | "__version__", 5 | "hists_from_nn", 6 | "loss_from_model", 7 | "losses", 8 | ) 9 | 10 | from neos import losses 11 | from neos.top_level import hists_from_nn, loss_from_model 12 | -------------------------------------------------------------------------------- /src/neos/losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __all__ = ( 4 | "poi_uncert", 5 | "discovery_significance", 6 | "cls_value", 7 | "generalised_variance", 8 | "bce", 9 | ) 10 | 11 | import jax.numpy as jnp 12 | import pyhf 13 | 14 | import relaxed 15 | 16 | Array = jnp.ndarray 17 | 18 | 19 | def poi_uncert(model: pyhf.Model) -> float: 20 | hypothesis_pars = ( 21 | jnp.asarray(model.config.suggested_init()).at[model.config.poi_index].set(1.0) 22 | ) 23 | observed_hist = jnp.asarray(model.expected_data(hypothesis_pars)) 24 | return relaxed.cramer_rao_uncert(model, hypothesis_pars, observed_hist)[ 25 | 
model.config.poi_index 26 | ] 27 | 28 | 29 | def discovery_significance(model: pyhf.Model, fit_lr: float) -> float: 30 | test_stat = "q0" 31 | test_poi = 0.0 # background-only as the alternative 32 | # nominal s+b as the null 33 | hypothesis_pars = ( 34 | jnp.asarray(model.config.suggested_init()).at[model.config.poi_index].set(1.0) 35 | ) 36 | observed_hist = jnp.asarray(model.expected_data(hypothesis_pars)) 37 | return relaxed.infer.hypotest( 38 | test_poi=test_poi, 39 | data=observed_hist, 40 | model=model, 41 | test_stat=test_stat, 42 | expected_pars=hypothesis_pars, 43 | lr=fit_lr, 44 | ) 45 | 46 | 47 | def cls_value(model: pyhf.Model, fit_lr: float) -> float: 48 | test_stat = "q" 49 | test_poi = 1.0 # nominal s+b as the null 50 | # background-only as the alternative 51 | hypothesis_pars = ( 52 | jnp.asarray(model.config.suggested_init()).at[model.config.poi_index].set(0.0) 53 | ) 54 | observed_hist = jnp.asarray(model.expected_data(hypothesis_pars)) 55 | return relaxed.infer.hypotest( 56 | test_poi=test_poi, 57 | data=observed_hist, 58 | model=model, 59 | test_stat=test_stat, 60 | expected_pars=hypothesis_pars, 61 | lr=fit_lr, 62 | ) 63 | 64 | 65 | def generalised_variance(model: pyhf.Model) -> float: 66 | hypothesis_pars = ( 67 | jnp.asarray(model.config.suggested_init()).at[model.config.poi_index].set(0.0) 68 | ) 69 | observed_hist = jnp.asarray(model.expected_data(hypothesis_pars)) 70 | return 1 / jnp.linalg.det( 71 | relaxed.fisher_info(model, hypothesis_pars, observed_hist) 72 | ) 73 | 74 | 75 | def sigmoid_cross_entropy_with_logits(preds, labels): 76 | return jnp.mean( 77 | jnp.maximum(preds, 0) - preds * labels + jnp.log1p(jnp.exp(-jnp.abs(preds))) 78 | ) 79 | 80 | 81 | def bce(data, nn, pars, with_aug=False, signal_label="sig", background_label="bkg"): 82 | preds = {k: nn(pars, data[k]).ravel() for k in data} 83 | if with_aug: 84 | bkg = jnp.concatenate([preds[k] for k in preds if signal_label not in k]) 85 | all_vals = 
jnp.concatenate(list(preds.values())).ravel() 86 | else: 87 | bkg = preds[background_label] 88 | all_vals = jnp.concatenate( 89 | [preds[signal_label], preds[background_label]] 90 | ).ravel() 91 | sig = preds[signal_label] 92 | labels = jnp.concatenate([jnp.ones_like(sig), jnp.zeros_like(bkg)]) 93 | return sigmoid_cross_entropy_with_logits(all_vals, labels).mean() 94 | -------------------------------------------------------------------------------- /src/neos/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gradhep/neos/fd52e3fc91a24805880cdc0a8d0669a06235955e/src/neos/py.typed -------------------------------------------------------------------------------- /src/neos/top_level.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __all__ = ( 4 | "loss_from_model", 5 | "hists_from_nn", 6 | ) 7 | 8 | from functools import partial 9 | from typing import Any, Callable 10 | 11 | import jax.numpy as jnp 12 | import pyhf 13 | 14 | import relaxed 15 | from neos.losses import ( 16 | cls_value, 17 | discovery_significance, 18 | generalised_variance, 19 | poi_uncert, 20 | ) 21 | 22 | Array = jnp.ndarray 23 | 24 | 25 | def hists_from_nn( 26 | pars: Array, 27 | data: dict[str, Array], 28 | nn: Callable, 29 | bandwidth: float, 30 | bins: Array, 31 | scale_factors: dict[str, float] | None = None, 32 | overall_scale: float = 1.0, 33 | ) -> dict[str, Array]: 34 | """Function that takes in data + analysis config parameters, and constructs yields.""" 35 | # apply the neural network to each data sample, and keep track of the sample names in a dict 36 | nn_output = {k: nn(pars, data[k]).ravel() for k in data} 37 | 38 | # The next two lines allow you to also optimise your binning: 39 | bins_new = jnp.concatenate( 40 | ( 41 | jnp.array([bins[0]]), 42 | jnp.where(bins[1:] > bins[:-1], bins[1:], bins[:-1] + 1e-4), 43 | ), 44 | axis=0, 45 | 
) 46 | # define our histogram-maker with some hyperparameters (bandwidth, binning) 47 | make_hist = partial(relaxed.hist, bandwidth=bandwidth, bins=bins_new) 48 | 49 | # every histogram is scaled to the number of points from that data source in the batch 50 | # so we have more control over the scaling of sig/bkg for realism 51 | scale_factors = scale_factors or {k: 1.0 for k in nn_output} 52 | hists = { 53 | k: make_hist(nn_output[k]) * scale_factors[k] * overall_scale / len(v) 54 | + 1e-3 # add a floor so no zeros in any bin! 55 | for k, v in nn_output.items() 56 | } 57 | return hists 58 | 59 | 60 | def loss_from_model( 61 | model: pyhf.Model, 62 | loss: str | Callable[[dict[str, Any]], float] = "neos", 63 | fit_lr: float = 1e-3, 64 | ) -> float: 65 | if isinstance(loss, Callable): 66 | # everything 67 | return 0 68 | # loss specific 69 | if loss.lower() == "discovery": 70 | return discovery_significance(model, fit_lr) 71 | elif loss.lower() in ["neos", "cls"]: 72 | return cls_value(model, fit_lr) 73 | elif loss.lower() in ["inferno", "poi_uncert", "mu_uncert"]: 74 | return poi_uncert(model) 75 | elif loss.lower() in [ 76 | "general_variance", 77 | "generalised_variance", 78 | "generalized_variance", 79 | ]: 80 | return generalised_variance(model) 81 | else: 82 | raise ValueError(f"loss function {loss} not recognised") 83 | -------------------------------------------------------------------------------- /tests/test_package.py: -------------------------------------------------------------------------------- 1 | import neos 2 | 3 | 4 | def test_version(): 5 | assert neos.__version__ 6 | --------------------------------------------------------------------------------