├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── docs ├── Makefile └── source │ ├── api.rst │ ├── conf.py │ ├── constraints.rst │ ├── criticism.rst │ ├── index.rst │ ├── internals.rst │ └── testing.rst ├── examples ├── README.md ├── immune_sequence.py ├── oed_vs_rand.png ├── plot_tf8_results.ipynb ├── rollout_tf8.py ├── tf8_criticism.ipynb └── tf8_demo.ipynb ├── pyroed ├── __init__.py ├── api.py ├── constraints.py ├── criticism.py ├── datasets │ ├── __init__.py │ ├── data.py │ └── tf_bind_8-PBX4_REF_R2.npy.gz ├── inference.py ├── models.py ├── oed.py ├── optimizers.py ├── py.typed ├── testing.py └── typing.py ├── setup.cfg ├── setup.py └── test ├── conftest.py ├── test_constraints.py └── test_e2e.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | env: 13 | CXX: g++-8 14 | CC: gcc-8 15 | 16 | jobs: 17 | lint: 18 | runs-on: ubuntu-20.04 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: 3.7 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip wheel setuptools 28 | pip install flake8 black "isort>=5.0" mypy # quoted: unquoted, the shell reads >=5.0 as a redirect 29 | - name: Lint 30 | run: | 31 | make lint 32 | 33 | docs: 34 | runs-on: ubuntu-20.04 35 | needs: lint 36 | steps: 37 | - name: Set up Python 38 | uses: actions/setup-python@v2 39 | with: 40 | python-version: 3.7 41 | - uses: actions/checkout@v2 42 | with: 43 | fetch-depth: 0 # otherwise, you will fail to push refs to dest repo 44 | - name: Install dependencies 45 | run: | 46 | python -m pip install --upgrade pip wheel setuptools 47 | pip install sphinx_rtd_theme 48 | pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f
https://download.pytorch.org/whl/torch_stable.html 49 | pip install .[test] 50 | pip freeze 51 | - name: Build and Commit 52 | uses: sphinx-notes/pages@v2 53 | with: 54 | documentation_path: docs/source 55 | - name: Push changes 56 | uses: ad-m/github-push-action@master 57 | with: 58 | github_token: ${{ secrets.GITHUB_TOKEN }} 59 | branch: gh-pages 60 | 61 | test: 62 | runs-on: ubuntu-20.04 63 | needs: lint 64 | steps: 65 | - uses: actions/checkout@v2 66 | - name: Set up Python 67 | uses: actions/setup-python@v2 68 | with: 69 | python-version: 3.7 70 | - name: Install dependencies 71 | run: | 72 | sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test 73 | sudo apt-get update 74 | sudo apt-get install gcc-8 g++-8 ninja-build 75 | python -m pip install --upgrade pip wheel setuptools 76 | pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 77 | pip install .[test] 78 | pip freeze 79 | - name: Run tests 80 | run: | 81 | make test 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | *.pdf 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: FORCE 2 | python -m pip install -e .[test] 3 | 4 | docs: FORCE 5 | $(MAKE) -C docs html 6 | 7 | lint: FORCE 8 | python -m flake8 9 | python -m black --check *.py pyroed test examples/*.py 10 | python -m isort --check . 11 | python -m mypy --install-types --non-interactive pyroed test 12 | 13 | format: FORCE 14 | python -m black *.py pyroed test examples/*.py 15 | python -m isort . 
 16 | 17 | test: lint FORCE 18 | pytest -vx test 19 | python examples/immune_sequence.py --simulate-experiments=1 20 | @echo PASSED 21 | 22 | FORCE: 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://github.com/pyro-ppl/pyroed/workflows/CI/badge.svg)](https://github.com/pyro-ppl/pyroed/actions) 2 | [![Docs](https://img.shields.io/badge/api-docs-blue)](https://pyro-ppl.github.io/pyroed/) 3 | 4 | # Bayesian optimization of discrete sequences 5 | 6 | Pyroed is a framework for model-based optimization of sequences of discrete 7 | choices with constraints among choices. 8 | Pyroed aims to address the regime where there is very little data (100-10000 9 | observations), small batch size (say 10-100), short sequences (length 2-100) of 10 | heterogeneous choice sets, and possibly with [constraints](https://pyro-ppl.github.io/pyroed/constraints.html) among choices at 11 | different positions in the sequence. 12 | 13 | Under the hood, Pyroed performs 14 | [Thompson sampling](https://pyro-ppl.github.io/pyroed/internals.html#pyroed.oed.thompson_sample) 15 | against a hierarchical Bayesian 16 | [linear regression model](https://pyro-ppl.github.io/pyroed/internals.html#pyroed.models.model) 17 | that is automatically generated from a Pyroed problem specification, deferring 18 | to [Pyro](https://pyro.ai) for Bayesian inference (either 19 | [variational](https://pyro-ppl.github.io/pyroed/internals.html#pyroed.inference.fit_svi) 20 | or 21 | [MCMC](https://pyro-ppl.github.io/pyroed/internals.html#pyroed.inference.fit_mcmc)) 22 | and to 23 | [annealed Gibbs sampling](https://pyro-ppl.github.io/pyroed/internals.html#pyroed.optimizers.optimize_simulated_annealing) 24 | for discrete optimization. All numerics are performed by 25 | [PyTorch](https://pytorch.org).
26 | 27 | ## Installing 28 | 29 | You can install directly from github via 30 | ```sh 31 | pip install https://github.com/pyro-ppl/pyroed/archive/main.zip 32 | ``` 33 | For developing Pyroed you can install from source 34 | ```sh 35 | git clone git@github.com:pyro-ppl/pyroed 36 | cd pyroed 37 | pip install -e . 38 | ``` 39 | 40 | ## Quick Start 41 | 42 | ### 1. Specify your problem in the Pyroed language 43 | 44 | First specify your sequence space by declaring a `SCHEMA`, `CONSTRAINTS`, `FEATURE_BLOCKS`, and `GIBBS_BLOCKS`. These are all simple Python data structures. 45 | For example to optimize a nucleotide sequence of length 6: 46 | ```python 47 | # Declare the set of choices and the values each choice can take. 48 | SCHEMA = OrderedDict() 49 | SCHEMA["nuc0"] = ["A", "C", "G", "T"] # these are the same, but 50 | SCHEMA["nuc1"] = ["A", "C", "G", "T"] # you can make each list different 51 | SCHEMA["nuc2"] = ["A", "C", "G", "T"] 52 | SCHEMA["nuc3"] = ["A", "C", "G", "T"] 53 | SCHEMA["nuc4"] = ["A", "C", "G", "T"] 54 | SCHEMA["nuc5"] = ["A", "C", "G", "T"] 55 | 56 | # Declare some constraints. See pyroed.constraints for options. 57 | CONSTRAINTS = [] 58 | CONSTRAINTS.append(AllDifferent("nuc0", "nuc1", "nuc2")) 59 | CONSTRAINTS.append(Iff(TakesValue("nuc4", "T"), TakesValue("nuc5", "T"))) 60 | 61 | # Specify groups of cross features for the Bayesian linear regression model. 62 | FEATURE_BLOCKS = [] 63 | FEATURE_BLOCKS.append(["nuc0"]) # single features 64 | FEATURE_BLOCKS.append(["nuc1"]) 65 | FEATURE_BLOCKS.append(["nuc2"]) 66 | FEATURE_BLOCKS.append(["nuc3"]) 67 | FEATURE_BLOCKS.append(["nuc4"]) 68 | FEATURE_BLOCKS.append(["nuc5"]) 69 | FEATURE_BLOCKS.append(["nuc0", "nuc1"]) # consecutive pairs 70 | FEATURE_BLOCKS.append(["nuc1", "nuc2"]) 71 | FEATURE_BLOCKS.append(["nuc2", "nuc3"]) 72 | FEATURE_BLOCKS.append(["nuc3", "nuc4"]) 73 | FEATURE_BLOCKS.append(["nuc4", "nuc5"]) 74 | 75 | # Finally define Gibbs sampling blocks for the discrete optimization. 
76 | GIBBS_BLOCKS = [] 77 | GIBBS_BLOCKS.append(["nuc0", "nuc1"]) # consecutive pairs 78 | GIBBS_BLOCKS.append(["nuc1", "nuc2"]) 79 | GIBBS_BLOCKS.append(["nuc2", "nuc3"]) 80 | GIBBS_BLOCKS.append(["nuc3", "nuc4"]) 81 | GIBBS_BLOCKS.append(["nuc4", "nuc5"]) 82 | ``` 83 | 84 | ### 2. Declare your initial experiment 85 | 86 | An experiment consists of a set of `sequences` and the experimentally measured 87 | `responses` of those sequences. 88 | ```python 89 | # Enter your existing data. 90 | sequences = ["ACGAAA", "ACGATT", "AGTTTT"] 91 | responses = torch.tensor([0.1, 0.2, 0.6]) 92 | 93 | # Collect these into a dictionary that we'll maintain throughout our workflow. 94 | design = pyroed.encode_design(SCHEMA, sequences) 95 | experiment = pyroed.start_experiment(SCHEMA, design, responses) 96 | ``` 97 | 98 | ### 3. Iteratively create new designs 99 | 100 | At each step of our optimization loop, we'll query Pyroed for a new design. 101 | Pyroed chooses the design to balance exploitation (finding sequences with high 102 | response) and exploration. 103 | ```python 104 | design = pyroed.get_next_design( 105 | SCHEMA, CONSTRAINTS, FEATURE_BLOCKS, GIBBS_BLOCKS, experiment, design_size=3 106 | ) 107 | new_sequences = ["".join(s) for s in pyroed.decode_design(SCHEMA, design)] 108 | print(new_sequences) 109 | # ["CAGTGC", "GCAGTT", "TAGGTT"] 110 | ``` 111 | Then we'll go to the lab, measure the responses of these new sequences, and 112 | append the new results to our experiment: 113 | ```python 114 | new_responses = torch.tensor([0.04, 0.3, 0.25]) 115 | experiment = pyroed.update_experiment(SCHEMA, experiment, design, new_responses) 116 | ``` 117 | We repeat step 3 as long as we like.
118 | 119 | ## Demo: Semi-Synthetic Experiment 120 | 121 | For a more in-depth demonstration of Pyroed usage in practice on some transcription factor data 122 | see [`rollout_tf8.py`](https://github.com/pyro-ppl/pyroed/blob/main/examples/rollout_tf8.py) 123 | and [`tf8_demo.ipynb`](https://github.com/pyro-ppl/pyroed/blob/main/examples/tf8_demo.ipynb). 124 | 125 | ![plot](./examples/oed_vs_rand.png) 126 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS ?= -E -W 6 | SPHINXBUILD = python -msphinx 7 | APIDOC = sphinx-apidoc 8 | SPHINXPROJ = Pyroed 9 | SOURCEDIR = source 10 | PROJECTDIR = ../pyroed 11 | BUILDDIR = build 12 | 13 | # Put it first so that "make" without argument is like "make help". 14 | help: 15 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 16 | 17 | .PHONY: help Makefile 18 | 19 | apidoc: 20 | $(APIDOC) -o "$(SOURCEDIR)" "$(PROJECTDIR)" 21 | # Catch-all target: route all unknown targets to Sphinx using the new 22 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 23 | %: Makefile 24 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 25 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | Interface 2 | ========= 3 | .. 
automodule:: pyroed.api 4 | :members: 5 | :undoc-members: 6 | :show-inheritance: 7 | :member-order: bysource 8 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import sphinx_rtd_theme 5 | 6 | # Configuration file for the Sphinx documentation builder. 7 | # 8 | # This file only contains a selection of the most common options. For a full 9 | # list see the documentation: 10 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 11 | 12 | # -- Path setup -------------------------------------------------------------- 13 | 14 | # If extensions (or modules to document with autodoc) are in another directory, 15 | # add these directories to sys.path here. If the directory is relative to the 16 | # documentation root, use os.path.abspath to make it absolute, like shown here. 17 | sys.path.insert(0, os.path.abspath('.')) 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'Pyroed' 22 | copyright = '2022, Fritz Obermeyer, Martin Jankowiak' 23 | author = 'Fritz Obermeyer, Martin Jankowiak' 24 | 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = [ 32 | "sphinx.ext.intersphinx", 33 | "sphinx.ext.mathjax", 34 | "sphinx.ext.ifconfig", 35 | "sphinx.ext.viewcode", 36 | "sphinx.ext.githubpages", 37 | "sphinx.ext.autodoc", 38 | "sphinx.ext.doctest", 39 | ] 40 | 41 | # Add any paths that contain templates here, relative to this directory. 42 | templates_path = ['_templates'] 43 | 44 | # List of patterns, relative to source directory, that match files and 45 | # directories to ignore when looking for source files. 
46 | # This pattern also affects html_static_path and html_extra_path. 47 | exclude_patterns = [] 48 | 49 | autodoc_member_order = "bysource" 50 | 51 | 52 | # -- Options for HTML output ------------------------------------------------- 53 | 54 | # The theme to use for HTML and HTML Help pages. See the documentation for 55 | # a list of builtin themes. 56 | 57 | html_theme = "sphinx_rtd_theme" 58 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 59 | 60 | # Add any paths that contain custom static files (such as style sheets) here, 61 | # relative to this directory. They are copied after the builtin static files, 62 | # so a file named "default.css" will overwrite the builtin "default.css". 63 | # html_static_path = ['_static'] 64 | 65 | # Example configuration for intersphinx: refer to the Python standard library. 66 | intersphinx_mapping = { 67 | "python": ("https://docs.python.org/3/", None), 68 | "torch": ("https://pytorch.org/docs/master/", None), 69 | "funsor": ("http://funsor.pyro.ai/en/stable/", None), 70 | "pyro": ("http://docs.pyro.ai/en/stable/", None), 71 | "opt_einsum": ("https://optimized-einsum.readthedocs.io/en/stable/", None), 72 | "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), 73 | "Bio": ("https://biopython.org/docs/latest/api/", None), 74 | "horovod": ("https://horovod.readthedocs.io/en/stable/", None), 75 | "graphviz": ("https://graphviz.readthedocs.io/en/stable/", None), 76 | } 77 | -------------------------------------------------------------------------------- /docs/source/constraints.rst: -------------------------------------------------------------------------------- 1 | Constraints 2 | =========== 3 | .. 
automodule:: pyroed.constraints 4 | :members: 5 | :undoc-members: 6 | :show-inheritance: 7 | :member-order: bysource 8 | -------------------------------------------------------------------------------- /docs/source/criticism.rst: -------------------------------------------------------------------------------- 1 | Criticism 2 | ========= 3 | .. automodule:: pyroed.criticism 4 | :members: 5 | :undoc-members: 6 | :show-inheritance: 7 | :member-order: bysource 8 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Pyroed documentation master file, created by 2 | sphinx-quickstart on Tue Mar 15 13:33:23 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | :github_url: https://github.com/pyro-ppl/pyroed 7 | 8 | Pyroed Documentation 9 | ==================== 10 | 11 | This is the low-level documentation for 12 | `Pyroed `_, 13 | a library for combinatorial Bayesian optimization of constrained sequences. 14 | See also the 15 | `examples/ `_ 16 | directory on github. 17 | 18 | .. toctree:: 19 | :maxdepth: 2 20 | :caption: Contents: 21 | 22 | api.rst 23 | constraints.rst 24 | criticism.rst 25 | internals.rst 26 | testing.rst 27 | 28 | Indices and tables 29 | ================== 30 | 31 | * :ref:`genindex` 32 | * :ref:`modindex` 33 | * :ref:`search` 34 | -------------------------------------------------------------------------------- /docs/source/internals.rst: -------------------------------------------------------------------------------- 1 | Internals 2 | ========= 3 | 4 | Typing & Validation 5 | ------------------- 6 | .. automodule:: pyroed.typing 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | :member-order: bysource 11 | 12 | Models 13 | ------ 14 | .. 
automodule:: pyroed.models 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | :member-order: bysource 19 | 20 | Inference 21 | --------- 22 | .. automodule:: pyroed.inference 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | :member-order: bysource 27 | 28 | Optimization 29 | ------------ 30 | .. automodule:: pyroed.optimizers 31 | :members: 32 | :undoc-members: 33 | :show-inheritance: 34 | :member-order: bysource 35 | 36 | Experiment Design 37 | ----------------- 38 | .. automodule:: pyroed.oed 39 | :members: 40 | :undoc-members: 41 | :show-inheritance: 42 | :member-order: bysource 43 | -------------------------------------------------------------------------------- /docs/source/testing.rst: -------------------------------------------------------------------------------- 1 | Testing Utilities 2 | ================= 3 | .. automodule:: pyroed.testing 4 | :members: 5 | :undoc-members: 6 | :show-inheritance: 7 | :member-order: bysource 8 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples of Pyroed Usage 2 | 3 | The files in this directory are as follows: 4 | 5 | - `immune_sequence.py`: end-to-end example of Pyroed usage using simulated data 6 | - `oed_vs_rand.png`: figure generated by `plot_tf8_results.ipynb` 7 | - `plot_tf8_results.ipynb`: notebook for generating plots; companion to `rollout_tf8.py` 8 | - `rollout_tf8.py`: script for doing semi-synthetic experiment involving transcription factor data 9 | - `tf8_demo.ipynb`: notebook that demonstrates basic Pyroed usage in the context of transcription factor data 10 | -------------------------------------------------------------------------------- /examples/immune_sequence.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | import argparse 4 | import warnings 5 | from collections 
import OrderedDict 6 | 7 | import pandas as pd 8 | import pyro 9 | import torch 10 | import torch.multiprocessing as mp 11 | 12 | from pyroed.api import get_next_design, start_experiment, update_experiment 13 | from pyroed.constraints import AllDifferent, Iff, IfThen, TakesValue 14 | from pyroed.testing import generate_fake_data 15 | 16 | # Specify the design space via SCHEMA, CONSTRAINTS, FEATURE_BLOCKS, and GIBBS_BLOCKS. 17 | SCHEMA = OrderedDict() 18 | SCHEMA["Protein 1"] = ["Prot1", "Prot2", None] 19 | SCHEMA["Protein 2"] = ["Prot3", "HLA1", "HLA2", "HLA3", "HLA4", None] 20 | SCHEMA["Signalling Pep"] = ["Sig1", "Sig2", None] 21 | SCHEMA["EP"] = [f"Ep{i}" for i in range(1, 10 + 1)] + [None] 22 | SCHEMA["Linker"] = ["Link1", None] 23 | SCHEMA["Internal"] = ["Int1", "Int2", "Int3", "Int3", None] 24 | SCHEMA["2A-1"] = ["twoa1", "twoa2", None] 25 | SCHEMA["2A-2"] = ["twoa3", "twoa4", None] 26 | SCHEMA["2A-3"] = [f"twoa{i}" for i in range(1, 7 + 1)] 27 | 28 | CONSTRAINTS = [ 29 | AllDifferent("2A-1", "2A-2", "2A-3"), 30 | Iff(TakesValue("Protein 1", None), TakesValue("2A-1", None)), 31 | Iff(TakesValue("Signalling Pep", None), TakesValue("EP", None)), 32 | Iff(TakesValue("EP", None), TakesValue("Linker", None)), 33 | IfThen(TakesValue("Protein 2", None), TakesValue("Internal", None)), 34 | Iff(TakesValue("Protein 2", "Prot3"), TakesValue("2A-2", None)), 35 | ] 36 | 37 | FEATURE_BLOCKS = [[name] for name in SCHEMA] 38 | FEATURE_BLOCKS.append(["Protein 1", "Protein 2"]) # TODO(liz) add a real interaction 39 | 40 | GIBBS_BLOCKS = [ 41 | ["Protein 1", "2A-1"], 42 | ["Signalling Pep", "EP", "Linker"], 43 | ["2A-1", "2A-2", "2A-3"], 44 | ["Protein 2", "Internal", "2A-2"], 45 | ] 46 | 47 | 48 | def load_experiment(filename, schema): 49 | df = pd.read_csv(filename, sep="\t") 50 | 51 | # Load response. 52 | col = "Amount Expression Output 1" 53 | df = df[~df[col].isna()] # Filter to rows where response was observed. 
54 | N = len(df[col]) 55 | response = torch.zeros(N) 56 | response[:] = torch.tensor([float(cell.strip("%")) / 100 for cell in df[col]]) 57 | 58 | # Load sequences. 59 | sequences = torch.zeros(N, len(SCHEMA), dtype=torch.long) 60 | for i, (name, values) in enumerate(SCHEMA.items()): 61 | sequences[:, i] = torch.tensor( 62 | [values.index(v if isinstance(v, str) else None) for v in df[name]] 63 | ) 64 | 65 | # Optionally load batch id. 66 | col = "Batch ID" 67 | batch_id = torch.zeros(N, dtype=torch.long) 68 | if col in df: 69 | batch_id[:] = df[col].to_numpy() 70 | else: 71 | warnings.warn(f"Found no '{col}' column, assuming a single batch") 72 | 73 | experiment = start_experiment(SCHEMA, sequences, response, batch_id) 74 | return experiment 75 | 76 | 77 | def main(args): 78 | pyro.set_rng_seed(args.seed) 79 | 80 | if args.tsv_file_in: 81 | print(f"Loading data from {args.tsv_file_in}") 82 | experiment = load_experiment(args.tsv_file_in, SCHEMA) 83 | else: 84 | print("Generating fake data") 85 | truth, experiment = generate_fake_data( 86 | SCHEMA, FEATURE_BLOCKS, args.sequences_per_batch, args.simulate_batches 87 | ) 88 | 89 | config = { 90 | "inference": "mcmc" if args.mcmc else "svi", 91 | "mcmc_num_samples": args.mcmc_num_samples, 92 | "mcmc_warmup_steps": args.mcmc_warmup_steps, 93 | "mcmc_num_chains": args.mcmc_num_chains, 94 | "svi_num_steps": args.svi_num_steps, 95 | "sa_num_steps": args.sa_num_steps, 96 | "max_tries": args.max_tries, 97 | "thompson_temperature": args.thompson_temperature, 98 | "log_every": args.log_every, 99 | "jit_compile": args.jit, 100 | } 101 | design = get_next_design( 102 | SCHEMA, 103 | CONSTRAINTS, 104 | FEATURE_BLOCKS, 105 | GIBBS_BLOCKS, 106 | experiment, 107 | design_size=args.sequences_per_batch, 108 | config=config, 109 | ) 110 | print("Design:") 111 | for row in design.tolist(): 112 | cells = [values[i] for values, i in zip(SCHEMA.values(), row)] 113 | print("\t".join("-" if c is None else c for c in cells)) 114 | 115 | for 
step in range(args.simulate_experiments): 116 | print("Simulating fake responses") 117 | response = torch.rand(len(design)) 118 | experiment = update_experiment(SCHEMA, experiment, design, response) 119 | design = get_next_design( 120 | SCHEMA, 121 | CONSTRAINTS, 122 | FEATURE_BLOCKS, 123 | GIBBS_BLOCKS, 124 | experiment, 125 | design_size=args.sequences_per_batch, 126 | config=config, 127 | ) 128 | print("Design:") 129 | for row in design.tolist(): 130 | cells = [values[i] for values, i in zip(SCHEMA.values(), row)] 131 | print("\t".join("-" if c is None else c for c in cells)) 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser(description="Design sequences") 136 | 137 | # Data files. 138 | parser.add_argument("--tsv-file-in") 139 | 140 | # Simulation parameters. 141 | parser.add_argument("--sequences-per-batch", default=10, type=int) 142 | parser.add_argument("--simulate-batches", default=20) 143 | parser.add_argument("--simulate-experiments", default=0, type=int) 144 | 145 | # Algorithm parameters. 
146 | parser.add_argument("--max-tries", default=1000, type=int) 147 | parser.add_argument("--thompson-temperature", default=4.0, type=float) 148 | parser.add_argument("--mcmc", default=False, action="store_true") 149 | parser.add_argument("--svi", dest="mcmc", action="store_false") 150 | parser.add_argument("--mcmc-num-samples", default=500, type=int) 151 | parser.add_argument("--mcmc-warmup-steps", default=500, type=int) 152 | parser.add_argument("--mcmc-num-chains", default=min(4, mp.cpu_count()), type=int) 153 | parser.add_argument("--svi-num-steps", default=201, type=int) 154 | parser.add_argument("--sa-num-steps", default=201, type=int) 155 | parser.add_argument("--jit", action="store_true") 156 | parser.add_argument("--nojit", dest="jit", action="store_false") 157 | parser.add_argument("--seed", default=20210929) 158 | parser.add_argument("--log-every", default=100, type=int) 159 | args = parser.parse_args() 160 | main(args) 161 | -------------------------------------------------------------------------------- /examples/oed_vs_rand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyro-ppl/pyroed/c549bd9dc9511e2199ff55fb2c86f84226b9b1c2/examples/oed_vs_rand.png -------------------------------------------------------------------------------- /examples/rollout_tf8.py: -------------------------------------------------------------------------------- 1 | """ 2 | We consider data from 3 | 4 | Barrera, Luis A., et al. "Survey of variation in human transcription factors reveals 5 | prevalent DNA binding changes." Science 351.6280 (2016): 1450-1454. 6 | 7 | for the PBX4 transcription factor. The dataset consists of measurements of the binding 8 | affinities of PBX4 to all possible DNA sequences of length 8, i.e. for a total of 9 | 4^8 = 65536 sequences. 
Since this dataset is exhaustive we can use it to do a 10 | semi-synthetic experiment in which we first "measure" a small number of binding 11 | affinities and then do additional "experiments" in multiple rounds. 12 | 13 | In the script below we build a pipeline to run multiple trials of such roll-out 14 | experiments under different parameter settings so we can assess whether optimal 15 | experimental design (OED) is making our adaptive experiments more efficient. In 16 | particular we ask whether adaptive experiments are more efficient at identifying 17 | high-affinity DNA sequences than random experimentation in which designs 18 | (DNA sequences) are chosen at random. 19 | 20 | The results of this script are visualized here: 21 | https://github.com/pyro-ppl/pyroed/blob/main/examples/oed_vs_rand.png 22 | """ 23 | 24 | # type: ignore 25 | 26 | import argparse 27 | import pickle 28 | import time 29 | from collections import OrderedDict 30 | 31 | import pyro 32 | import torch 33 | 34 | from pyroed.datasets import load_tf_data 35 | from pyroed.oed import thompson_sample 36 | 37 | SCHEMA = OrderedDict() 38 | for n in range(8): 39 | SCHEMA[f"Nucleotide{n}"] = ["A", "C", "G", "T"] 40 | 41 | CONSTRAINTS = [] # No constraints. 
42 | 43 | singletons = [[name] for name in SCHEMA] 44 | pairs = [list(ns) for ns in zip(SCHEMA, list(SCHEMA)[1:])] 45 | triples = [list(ns) for ns in zip(SCHEMA, list(SCHEMA)[1:], list(SCHEMA)[2:])] 46 | 47 | SINGLETON_BLOCKS = singletons 48 | PAIRWISE_BLOCKS = singletons + pairs 49 | GIBBS_BLOCKS = triples 50 | 51 | 52 | def update_experiment(experiment: dict, design: set, data: dict) -> dict: 53 | ids = list(map(data["seq_to_id"].__getitem__, sorted(design))) 54 | new_data = { 55 | "sequences": data["sequences"][ids], 56 | "responses": data["responses"][ids], 57 | "batch_ids": torch.zeros(len(ids)).long(), 58 | } 59 | experiment = {k: torch.cat([v, new_data[k]]) for k, v in experiment.items()} 60 | return experiment 61 | 62 | 63 | def make_design( 64 | experiment: dict, 65 | design_size: int, 66 | thompson_temperature: float, 67 | feature_blocks: list, 68 | ) -> set: 69 | return thompson_sample( 70 | SCHEMA, 71 | CONSTRAINTS, 72 | feature_blocks, 73 | GIBBS_BLOCKS, 74 | experiment, 75 | design_size=design_size, 76 | thompson_temperature=thompson_temperature, 77 | inference="svi", 78 | svi_num_steps=1000, 79 | sa_num_steps=400, 80 | log_every=0, 81 | jit_compile=False, 82 | ) 83 | 84 | 85 | def main(args): 86 | pyro.set_rng_seed(args.seed) 87 | 88 | data = load_tf_data() 89 | ids = torch.randperm(len(data["responses"]))[: args.num_initial_sequences] 90 | experiment = {k: v[ids] for k, v in data.items()} 91 | data["seq_to_id"] = { 92 | tuple(row): i for i, row in enumerate(data["sequences"].tolist()) 93 | } 94 | 95 | experiments = [experiment] 96 | best_response = experiment["responses"].max().item() 97 | print("[0th batch] Best response thus far: {:0.6g}".format(best_response)) 98 | t0 = time.time() 99 | 100 | for batch in range(args.num_batches): 101 | design = make_design( 102 | experiments[-1], 103 | args.num_sequences_per_batch, 104 | args.thompson_temperature, 105 | SINGLETON_BLOCKS if args.features == "singleton" else PAIRWISE_BLOCKS, 106 | ) 107 | 
experiments.append(update_experiment(experiments[-1], design, data)) 108 | print( 109 | "[Batch #{}] Best response thus far: {:0.6g}".format( 110 | batch + 1, experiments[-1]["responses"].max().item() 111 | ) 112 | ) 113 | 114 | print( 115 | "Best response from all batches: {:0.6g}".format( 116 | experiments[-1]["responses"].max().item() 117 | ) 118 | ) 119 | print("Elapsed time: {:.4f}".format(time.time() - t0)) 120 | 121 | response_curve = [e["responses"].max().item() for e in experiments] 122 | 123 | f = "results.{}.s{}.temp{}.nb{}.nspb{}.nis{}.pkl" 124 | f = f.format( 125 | args.features, 126 | args.seed, 127 | int(args.thompson_temperature), 128 | args.num_batches, 129 | args.num_sequences_per_batch, 130 | args.num_initial_sequences, 131 | ) 132 | pickle.dump(response_curve, open(f, "wb")) 133 | 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser(description="Design sequences") 137 | 138 | parser.add_argument("--num-initial-sequences", default=30, type=int) 139 | parser.add_argument("--num-sequences-per-batch", default=10, type=int) 140 | parser.add_argument("--num-batches", default=7) 141 | parser.add_argument("--seed", default=0, type=int) 142 | parser.add_argument("--thompson-temperature", default=1.0, type=float) 143 | parser.add_argument( 144 | "--features", type=str, default="singleton", choices=["singleton", "pairwise"] 145 | ) 146 | 147 | args = parser.parse_args() 148 | 149 | main(args) 150 | -------------------------------------------------------------------------------- /examples/tf8_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "cfa88313", 6 | "metadata": {}, 7 | "source": [ 8 | "# Predicting transcription factor binding affinity" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "bbce78c2", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from collections import 
OrderedDict\n", 19 | "from pprint import pprint\n", 20 | "import pyro\n", 21 | "import torch\n", 22 | "import numpy as np\n", 23 | "import matplotlib\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "\n", 26 | "from pyroed.datasets.data import load_tf_data\n", 27 | "from pyroed.constraints import AllDifferent, Iff, IfThen, TakesValue\n", 28 | "from pyroed.oed import thompson_sample\n", 29 | "from pyroed.testing import generate_fake_data\n", 30 | "\n", 31 | "matplotlib.rcParams[\"figure.facecolor\"] = \"white\"" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "3168d890", 37 | "metadata": {}, 38 | "source": [ 39 | "### Transcription factor data\n", 40 | "\n", 41 | "We consider data from [Survey of variation in human transcription factors reveals prevalent DNA binding changes](https://www.science.org/doi/abs/10.1126/science.aad2257),\n", 42 | "in particular for the PBX4 transcription factor.\n", 43 | "The dataset consists of measurements of the binding affinities of PBX4 to all\n", 44 | "possible DNA sequences of length 8, i.e. for a total of $4^8 = 65536$ sequences.\n", 45 | "Since this dataset is exhaustive we can use it to do a semi-synthetic experiment in which\n", 46 | "we first \"measure\" a small number of binding affinities and then do additional \"experiments\" in multiple rounds." 
47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "id": "b9c5f1dc", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "data = load_tf_data(data_dir=\"../pyroed/datasets\")" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "id": "dbcf570c", 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "sequences torch.int64 (65792, 8)\n", 70 | "responses torch.float32 (65792,)\n", 71 | "batch_ids torch.int64 (65792,)\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "for k, v in data.items():\n", 77 | " print(f\"{k} {v.dtype} {tuple(v.shape)}\")" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "id": "add4b8cb", 83 | "metadata": {}, 84 | "source": [ 85 | "Note that there are actually more than $65536$ data points due to some repeats." 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "63917f48", 91 | "metadata": {}, 92 | "source": [ 93 | "Sequences take values 0,1,2,3 corresponding to nucleotides." 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 4, 99 | "id": "bf570325", 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "data": { 104 | "text/plain": [ 105 | "{0, 1, 2, 3}" 106 | ] 107 | }, 108 | "execution_count": 4, 109 | "metadata": {}, 110 | "output_type": "execute_result" 111 | } 112 | ], 113 | "source": [ 114 | "set(data[\"sequences\"].reshape(-1).tolist())" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "id": "fb9d3635", 120 | "metadata": {}, 121 | "source": [ 122 | "The response variable appears to be approximately Gaussian distributed." 
123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 5, 128 | "id": "e1ebbbac", 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "data": { 133 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD7CAYAAACG50QgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAZj0lEQVR4nO3df0yd5f3/8ecpTJJ1GO0EdtgNFj0M4Uit9hSLyVyVna3Hr4Np3QmdhsO0IcIWZvbHwj+Lc8nG6ZIlNo6wnYSYw36dhCUDo/Uwp3bLGpEd1G6CP06UGs7ZGdCBQx1ii/f3D7/fs48fYJwWOJRer0fSpFzcF+d6p82LK9d93dftsG3bRkREjLBtswcgIiLZo9AXETGIQl9ExCAKfRERgyj0RUQMotAXETHIqqH/2muvsXv37vSfSy+9lIcffpiZmRm8Xi/l5eV4vV5mZ2fTfTo7O3G5XFRUVDA4OJhuHxkZobq6GpfLRXt7O9otKiKSXY5z2ae/uLjIZz/7WZ5//nm6urrYsWMHHR0dBINBZmdnOXLkCGNjYxw6dIjh4WH+/ve/88UvfpHXX3+dnJwcampqOHr0KPv27eO2226jvb0dn8+3kfWJiMj/kHsuFz/99NNcffXVXHnllQwMDHD8+HEAAoEA+/fv58iRIwwMDNDY2EheXh5lZWW4XC6Gh4fZuXMnc3Nz1NbWAtDU1ER/f/+qoX/FFVewc+fO8ypORMRUp06d4vTp00vazyn0I5EIhw4dAmBychKn0wmA0+lkamoKgGQyyb59+9J9LMsimUzyiU98AsuylrQvJxQKEQqFANi+fTuxWOxchikiYjyPx7Nse8Y3cj/44AMee+wxvva1r/3X65ZbLXI4HCu2L6elpYVYLEYsFqOgoCDTIYqIyCoyDv0nn3ySG264gaKiIgCKiopIpVIApFIpCgsLgY9m8BMTE+l+iUSC4uJiLMsikUgsaRcRkezJOPR/85vfpJd2AOrr6wmHwwCEw2EaGhrS7ZFIhIWFBcbHx4nH49TU1OB0OsnPz2doaAjbtunt7U33ERGR7MhoTf/f//43Tz31FD//+c/TbR0dHfj9fnp6eigtLaWvrw8At9uN3++nqqqK3Nxcurq6yMnJAaC7u5vm5mbm5+fx+XzauSMikmXntGVzM3g8Ht3IFRE5Rytlp57IFRExiEJfRMQgCn0REYMo9EVEDHJOT+SKXIh2djxx3n1PBf/POo5E5MKnmb6IiEEU+iIiBtHyjhhNS0NiGs30RUQMotAXETGIQl9ExCAKfRERgyj0RUQMot07sunWsoNGRM6NQl/kPGm7p2xFWt4RETGIQl9ExCAKfRERgyj0RUQMotAXETGIQl9ExCAKfRERg2QU+m+//TZ33XUX11xzDZWVlTz33HPMzMzg9XopLy/H6/UyOzubvr6zsxOXy0VFRQWDg4Pp9pGREaqrq3G5XLS3t2Pb9vpXJCIiK8oo9L/97W9z4MABXn31VU6ePEllZSXBYJC6ujri8Th1dXUEg0EAxsbGiEQijI6OEo1GaWtrY3FxEYDW1lZCoRDxeJx4PE40Gt24ykREZIlVQ39ubo4//elP3HfffQBccsklXHbZZQwMDBAIBAAIBAL09/cDMDAwQGNjI3l5eZSVleFyuRgeHiaVSjE3N0dtbS0Oh4OmpqZ0HxERyY5VQ//NN9+koKCAb3zjG1x//fUcPnyY9957j8nJSZxOJwBOp5OpqSkAkskkJSUl6f6WZZFMJkkmk1iWtaR9OaFQCI/Hg8fjYXp6ek0FiojIf6w
a+mfPnuWFF16gtbWVF198ke3bt6eXcpaz3Dq9w+FYsX05LS0txGIxYrEYBQUFqw1RREQytOqBa5ZlYVkWN954IwB33XUXwWCQoqIiUqkUTqeTVCpFYWFh+vqJiYl0/0QiQXFxMZZlkUgklrSLmGitJ4vqwDY5X6vO9D/zmc9QUlLCa6+9BsDTTz9NVVUV9fX1hMNhAMLhMA0NDQDU19cTiURYWFhgfHyceDxOTU0NTqeT/Px8hoaGsG2b3t7edB8REcmOjI5WfuSRR7j77rv54IMPuOqqq3j00Uf58MMP8fv99PT0UFpaSl9fHwButxu/309VVRW5ubl0dXWRk5MDQHd3N83NzczPz+Pz+fD5fBtXmYiILOGwL/DN8h6Ph1gsttnDkA2kl6icOy3vyGpWyk49kSsiYhCFvoiIQRT6IiIGUeiLiBhEL0aXdaGbsSJbg2b6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYhCFvoiIQRT6IiIGUeiLiBhEoS8iYhCdpy+yBa3l/QV6qbrZNNMXETFIRqG/c+dOqqur2b17Nx6PB4CZmRm8Xi/l5eV4vV5mZ2fT13d2duJyuaioqGBwcDDdPjIyQnV1NS6Xi/b2dmzbXudyRETkv8l4pv/ss8/y0ksvEYvFAAgGg9TV1RGPx6mrqyMYDAIwNjZGJBJhdHSUaDRKW1sbi4uLALS2thIKhYjH48TjcaLR6AaUJCIiKznv5Z2BgQECgQAAgUCA/v7+dHtjYyN5eXmUlZXhcrkYHh4mlUoxNzdHbW0tDoeDpqamdB8REcmOjELf4XDwpS99iT179hAKhQCYnJzE6XQC4HQ6mZqaAiCZTFJSUpLua1kWyWSSZDKJZVlL2pcTCoXweDx4PB6mp6fPrzIREVkio907J06coLi4mKmpKbxeL9dcc82K1y63Tu9wOFZsX05LSwstLS0A6XsIIiKydhnN9IuLiwEoLCzkjjvuYHh4mKKiIlKpFACpVIrCwkLgoxn8xMREum8ikaC4uBjLskgkEkvaRUQke1YN/ffee4933nkn/fff//73XHvttdTX1xMOhwEIh8M0NDQAUF9fTyQSYWFhgfHxceLxODU1NTidTvLz8xkaGsK2bXp7e9N9REQkO1Zd3pmcnOSOO+4A4OzZs3z961/nwIED7N27F7/fT09PD6WlpfT19QHgdrvx+/1UVVWRm5tLV1cXOTk5AHR3d9Pc3Mz8/Dw+nw+fz7eBpYmIyP/msC/wzfIejye9TVQuXGt5QlSyS0/kmmGl7NQTuSIiBlHoi4gYRKEvImIQhb6IiEEU+iIiBlHoi4gYRKEvImIQhb6IiEH0ukRJ0wNWIhc/zfRFRAyi0BcRMYhCX0TEIFrTFzHMWu7d6LC2rU8zfRERgyj0RUQMotAXETGIQl9ExCAKfRERgyj0RUQMotAXETGIQl9ExCAKfRERg2Qc+ouLi1x//fXcfvvtAMzMzOD1eikvL8fr9TI7O5u+trOzE5fLRUVFBYODg+n2kZERqqurcblctLe3Y9v2OpYiIiKryTj0jx49SmVlZfrrYDBIXV0d8Xicuro6gsEgAGNjY0QiEUZHR4lGo7S1tbG4uAhAa2sroVCIeDxOPB4nGo2uczkiIvLfZBT6iUSCJ554gsOHD6fbBgYGCAQCAAQCAfr7+9PtjY2N5OXlUVZWhsvlYnh4mFQqxdzcHLW1tTgcDpqamtJ9REQkOzIK/QceeIAf//jHbNv2n8snJydxOp0AOJ1OpqamAEgmk5SUlKSvsyyLZDJJMpnEsqwl7csJhUJ4PB48Hg/T09PnXpWIiCxr1dB//PHHKSwsZM+ePRn9wOXW6R0Ox4rty2lpaSEWixGLxSgoKMjoc0VEZHWrHq184sQJHnvsMY4dO8b777/P3Nwc99xzD0VFRaRSKZxOJ6lUisLCQuCjGfzExES6fyKRoLi4GMuySCQ
SS9pFRCR7Vp3pd3Z2kkgkOHXqFJFIhFtvvZVf/vKX1NfXEw6HAQiHwzQ0NABQX19PJBJhYWGB8fFx4vE4NTU1OJ1O8vPzGRoawrZtent7031ERCQ7zvslKh0dHfj9fnp6eigtLaWvrw8At9uN3++nqqqK3Nxcurq6yMnJAaC7u5vm5mbm5+fx+Xz4fL71qUJERDLisC/wzfIej4dYLLbZwzDCWt6oJGbQm7O2jpWyU0/kiogYRKEvImIQhb6IiEEU+iIiBlHoi4gYRKEvImIQhb6IiEEU+iIiBlHoi4gYRKEvImIQhb6IiEEU+iIiBlHoi4gYRKEvImKQ8z5PX0TMs5bjt3Us84VBM30REYMo9EVEDKLQFxExiEJfRMQgCn0REYMo9EVEDKLQFxExyKqh//7771NTU8N1112H2+3mwQcfBGBmZgav10t5eTler5fZ2dl0n87OTlwuFxUVFQwODqbbR0ZGqK6uxuVy0d7ejm3bG1CSiIisZNXQz8vL45lnnuHkyZO89NJLRKNRhoaGCAaD1NXVEY/HqaurIxgMAjA2NkYkEmF0dJRoNEpbWxuLi4sAtLa2EgqFiMfjxONxotHoxlYnIiIfs+oTuQ6Hg0996lMAnDlzhjNnzuBwOBgYGOD48eMABAIB9u/fz5EjRxgYGKCxsZG8vDzKyspwuVwMDw+zc+dO5ubmqK2tBaCpqYn+/n58Pt/GVWegtTwxKSIXv4zW9BcXF9m9ezeFhYV4vV5uvPFGJicncTqdADidTqampgBIJpOUlJSk+1qWRTKZJJlMYlnWkvblhEIhPB4PHo+H6enp8y5OREQ+LqPQz8nJ4aWXXiKRSDA8PMzLL7+84rXLrdM7HI4V25fT0tJCLBYjFotRUFCQyRBFRCQD57R757LLLmP//v1Eo1GKiopIpVIApFIpCgsLgY9m8BMTE+k+iUSC4uJiLMsikUgsaRcRkexZNfSnp6d5++23AZifn+cPf/gD11xzDfX19YTDYQDC4TANDQ0A1NfXE4lEWFhYYHx8nHg8Tk1NDU6nk/z8fIaGhrBtm97e3nQfERHJjlVv5KZSKQKBAIuLi3z44Yf4/X5uv/12amtr8fv99PT0UFpaSl9fHwButxu/309VVRW5ubl0dXWRk5MDQHd3N83NzczPz+Pz+XQTV0Qkyxz2Bb5Z3uPxEIvFNnsYW4Z278iFSufpZ9dK2aknckVEDKLQFxExiEJfRMQgCn0REYMo9EVEDKLQFxExiEJfRMQgCn0REYMo9EVEDKLQFxExiEJfRMQgCn0REYMo9EVEDKLQFxExiEJfRMQgCn0REYMo9EVEDKLQFxExiEJfRMQgq74YXURkPazl/c16v+760UxfRMQgCn0REYOsGvoTExPccsstVFZW4na7OXr0KAAzMzN4vV7Ky8vxer3Mzs6m+3R2duJyuaioqGBwcDDdPjIyQnV1NS6Xi/b2dmzb3oCSRERkJauGfm5uLj/5yU945ZVXGBoaoquri7GxMYLBIHV1dcTjcerq6ggGgwCMjY0RiUQYHR0lGo3S1tbG4uIiAK2trYRCIeLxOPF4nGg0urHViYjIx6wa+k6nkxtuuAGA/Px8KisrSSaTDAwMEAgEAAgEAvT39wMwMDBAY2MjeXl5lJWV4XK5GB4eJpVKMTc3R21tLQ6Hg6ampnQfERHJjnNa0z916hQvvvgiN954I5OTkzidTuCjXwxTU1MAJJNJSkpK0n0syyKZTJJMJrEsa0n7ckKhEB6PB4/Hw/T09DkXJSIiy8s49N99910OHjzIww8/zKWXXrridcut0zscjhXbl9PS0kIsFiMWi1FQUJDpEEVEZBUZhf6ZM2c4ePAgd999N3feeScARUVFpFIpAFKpFIWFhcBHM/iJiYl030QiQXFxMZZlkUgklrSLiEj2rBr6tm1z3333UVlZyXe+8510e319PeFwGIBwOExDQ0O6PRKJsLCwwPj4OPF4nJqaGpxOJ/n5+QwNDWH
bNr29vek+IiKSHas+kXvixAl+8YtfUF1dze7duwH40Y9+REdHB36/n56eHkpLS+nr6wPA7Xbj9/upqqoiNzeXrq4ucnJyAOju7qa5uZn5+Xl8Ph8+n28DSxMRkf/NYV/gm+U9Hg+xWGyzh7FlrOVRd5ELlY5hOHcrZafO3rkAKbhFZKPoGAYREYMo9EVEDKLQFxExiEJfRMQgCn0REYMo9EVEDKLQFxExiEJfRMQgCn0REYMo9EVEDKLQFxExiEJfRMQgCn0REYMo9EVEDKLQFxExiEJfRMQgCn0REYPozVkicsFb69vk9LrF/9BMX0TEIAp9ERGDKPRFRAyyaujfe++9FBYWcu2116bbZmZm8Hq9lJeX4/V6mZ2dTX+vs7MTl8tFRUUFg4OD6faRkRGqq6txuVy0t7dj2/Y6lyIiIqtZNfSbm5uJRqMfawsGg9TV1RGPx6mrqyMYDAIwNjZGJBJhdHSUaDRKW1sbi4uLALS2thIKhYjH48Tj8SU/U0RENt6qoX/zzTezY8eOj7UNDAwQCAQACAQC9Pf3p9sbGxvJy8ujrKwMl8vF8PAwqVSKubk5amtrcTgcNDU1pfuIiEj2nNea/uTkJE6nEwCn08nU1BQAyWSSkpKS9HWWZZFMJkkmk1iWtaR9JaFQCI/Hg8fjYXp6+nyGKCIiy1jXG7nLrdM7HI4V21fS0tJCLBYjFotRUFCwnkMUETHaeYV+UVERqVQKgFQqRWFhIfDRDH5iYiJ9XSKRoLi4GMuySCQSS9pFRCS7ziv06+vrCYfDAITDYRoaGtLtkUiEhYUFxsfHicfj1NTU4HQ6yc/PZ2hoCNu26e3tTfcREZHsWfUYhkOHDnH8+HFOnz6NZVk89NBDdHR04Pf76enpobS0lL6+PgDcbjd+v5+qqipyc3Pp6uoiJycHgO7ubpqbm5mfn8fn8+Hz+Ta2sk201kfGRUQ2isO+wDfMezweYrHYZg/jnCj0RS4sJp69s1J26olcERGDKPRFRAyi0BcRMYhCX0TEIAp9ERGD6M1ZInLRW8uOuott549m+iIiBlHoi4gYRKEvImIQhb6IiEEU+iIiBlHoi4gYRKEvImIQ7dMXEfkvLrY9/prpi4gYRKEvImIQLe+sQC9CEZGLkWb6IiIGUeiLiBhEoS8iYhCFvoiIQXQjV0Rkg1yIe/yzPtOPRqNUVFTgcrkIBoPZ/ngREaNldaa/uLjIN7/5TZ566iksy2Lv3r3U19dTVVW1IZ+nbZciIh+X1Zn+8PAwLpeLq666iksuuYTGxkYGBgayOQQREaNldaafTCYpKSlJf21ZFs8///yS60KhEKFQCIBXX30Vj8dzXp9nT09TUFBwfoPdoqZVsxFMq9m0egGuvLJtTTWfOnVq2fashr5t20vaHA7HkraWlhZaWlrW/Hkej4dYLLbmn7OVqGYzmFazafXCxtWc1eUdy7KYmJhIf51IJCguLs7mEEREjJbV0N+7dy/xeJzx8XE++OADIpEI9fX12RyCiIjRcr7//e9/P1sftm3bNsrLy7nnnnt45JFHuOeeezh48OCGfuaePXs29OdfiFSzGUyr2bR6YWNqdtjLLbSLiMhFSccwiIgYRKEvImKQiyL0VzvawbZt2tvbcblc7Nq1ixdeeGETRrl+Vqv3V7/6Fbt27WLXrl3cdNNNnDx5chNGub4yPb7jL3/5Czk5Ofz2t7/N4ug2RiY1Hz9+nN27d+N2u/nCF76Q5RGuv9Vq/te//sVXvvIVrrvuOtxuN48++ugmjHL93HvvvRQWFnLttdcu+/0NyS57izt79qx91VVX2W+88Ya9sLBg79q1yx4dHf3YNU888YR94MAB+8MPP7Sfe+45u6amZpNGu3aZ1HvixAl7ZmbGtm3bPnbs2Jau17Yzq/n/X3fLLbfYPp/P7uvr24SRrp9Map6dnbUrKyvtt956y7Zt256cnNyMoa6bTGr+4Q9/aH/3u9+1bdu2p6am7Msvv9xeWFjYjOG
uiz/+8Y/2yMiI7Xa7l/3+RmTXlp/pZ3K0w8DAAE1NTTgcDvbt28fbb79NKpXapBGvTSb13nTTTVx++eUA7Nu3j0QisRlDXTeZHt/xyCOPcPDgQQoLCzdhlOsrk5p//etfc+edd1JaWgqw5evOpGaHw8E777yDbdu8++677Nixg9zcrXtY8M0338yOHTtW/P5GZNeWD/3ljnZIJpPnfM1Wca619PT04PP5sjG0DZPpv/Hvfvc77r///mwPb0NkUvPrr7/O7Ows+/fvZ8+ePfT29mZ7mOsqk5q/9a1v8corr1BcXEx1dTVHjx5l27YtH2Mr2ojs2rq/Iv8fO4OjHTK5Zqs4l1qeffZZenp6+POf/7zRw9pQmdT8wAMPcOTIEXJycrI1rA2VSc1nz55lZGSEp59+mvn5eWpra9m3bx+f+9znsjXMdZVJzYODg+zevZtnnnmGN954A6/Xy+c//3kuvfTSbA0zqzYiu7Z86GdytMPFdPxDprX89a9/5fDhwzz55JN8+tOfzuYQ110mNcdiMRobGwE4ffo0x44dIzc3l69+9atZHet6yfT/9RVXXMH27dvZvn07N998MydPntyyoZ9JzY8++igdHR04HA5cLhdlZWW8+uqr1NTUZHu4WbEh2bXmuwKb7MyZM3ZZWZn95ptvpm/+vPzyyx+75vHHH//YzZC9e/du0mjXLpN633rrLfvqq6+2T5w4sUmjXF+Z1Pw/BQKBLX8jN5Oax8bG7FtvvdU+c+aM/d5779lut9v+29/+tkkjXrtMar7//vvtBx980LZt2/7HP/5hFxcX29PT05sw2vUzPj6+4o3cjciuLT/Tz83N5ac//Slf/vKXWVxc5N5778XtdvOzn/0MgPvvv5/bbruNY8eO4XK5+OQnP7mlt3llUu8PfvAD/vnPf9LW1pbus5VPKMyk5otNJjVXVlZy4MABdu3axbZt2zh8+PCKW/+2gkxq/t73vkdzczPV1dXYts2RI0e44oorNnnk5+/QoUMcP36c06dPY1kWDz30EGfOnAE2Lrt0DIOIiEEu3tveIiKyhEJfRMQgCn0REYMo9EVEDKLQFxExiEJfRMQgCn0REYP8X+DB2OLe9oJNAAAAAElFTkSuQmCC\n", 134 | "text/plain": [ 135 | "
" 136 | ] 137 | }, 138 | "metadata": {}, 139 | "output_type": "display_data" 140 | } 141 | ], 142 | "source": [ 143 | "plt.hist(data[\"responses\"].numpy(), bins=20)\n", 144 | "plt.show()" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "f0369fac", 150 | "metadata": {}, 151 | "source": [ 152 | "## Modeling in pyroed\n", 153 | "\n", 154 | "Specify the design space via SCHEMA, CONSTRAINTS, FEATURE_BLOCKS, and GIBBS_BLOCKS." 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 6, 160 | "id": "a5568f60", 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "SCHEMA = OrderedDict()\n", 165 | "for n in range(8):\n", 166 | " SCHEMA[f\"Nucleotide{n}\"] = [\"A\", \"C\", \"G\", \"T\"]\n", 167 | "\n", 168 | "CONSTRAINTS = [] # No constraints.\n", 169 | "\n", 170 | "singletons = [[name] for name in SCHEMA]\n", 171 | "pairs = [list(ns) for ns in zip(SCHEMA, list(SCHEMA)[1:])]\n", 172 | "triples = [list(ns) for ns in zip(SCHEMA, list(SCHEMA)[1:], list(SCHEMA)[2:])]\n", 173 | "\n", 174 | "FEATURE_BLOCKS = singletons + pairs\n", 175 | "GIBBS_BLOCKS = triples" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "id": "53014fc6", 181 | "metadata": {}, 182 | "source": [ 183 | "Let's start with a random subsample of data." 
184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 7, 189 | "id": "921550d3", 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "Best response = 0.574323\n" 197 | ] 198 | } 199 | ], 200 | "source": [ 201 | "pyro.set_rng_seed(0)\n", 202 | "full_size = len(data[\"responses\"])\n", 203 | "batch_size = 10\n", 204 | "ids = torch.randperm(full_size)[:batch_size]\n", 205 | "experiment = {k: v[ids] for k, v in data.items()}\n", 206 | "print(f\"Best response = {experiment['responses'].max():0.6g}\")" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "id": "a92d114c", 212 | "metadata": {}, 213 | "source": [ 214 | "Each step of the OED process we'll test on new data. Let's make a helper to simulate lab work." 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 8, 220 | "id": "b56ba9d7", 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "seq_to_id = {tuple(row): i for i, row in enumerate(data[\"sequences\"].tolist())}\n", 225 | "\n", 226 | "def update_experiment(experiment: dict, design: set) -> dict:\n", 227 | " batch_ids = experiment[\"batch_ids\"].max().item() + 1\n", 228 | " ids = list(map(seq_to_id.__getitem__, sorted(design)))\n", 229 | " new_data = {\n", 230 | " \"sequences\": data[\"sequences\"][ids],\n", 231 | " \"batch_ids\": torch.full((len(ids),), batch_ids),\n", 232 | " \"responses\": data[\"responses\"][ids],\n", 233 | " }\n", 234 | " experiment = {k: torch.cat([v, new_data[k]]) for k, v in experiment.items()}\n", 235 | " print(f\"Best response = {experiment['responses'].max():0.6g}\")\n", 236 | " return experiment\n", 237 | "\n", 238 | "def make_design(experiment: dict) -> set:\n", 239 | " return thompson_sample(\n", 240 | " SCHEMA,\n", 241 | " CONSTRAINTS,\n", 242 | " FEATURE_BLOCKS,\n", 243 | " GIBBS_BLOCKS,\n", 244 | " experiment,\n", 245 | " inference=\"svi\",\n", 246 | " svi_num_steps=201,\n", 247 | " 
sa_num_steps=201,\n", 248 | " log_every=0,\n", 249 | " jit_compile=False,\n", 250 | " )" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "id": "2d74b628", 256 | "metadata": {}, 257 | "source": [ 258 | "Initialize our sequence of experiments:" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 9, 264 | "id": "4b2c2ed7", 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "experiments = [experiment]" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "id": "6bcabacb", 274 | "metadata": {}, 275 | "source": [ 276 | "Let's start with a single loop of active learning:" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 10, 282 | "id": "7b1ffbf8", 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "Design:\n", 290 | "{(0, 2, 3, 3, 1, 2, 1, 1),\n", 291 | " (0, 3, 0, 2, 0, 1, 3, 1),\n", 292 | " (1, 0, 3, 1, 0, 1, 0, 0),\n", 293 | " (1, 0, 3, 3, 3, 3, 1, 0),\n", 294 | " (1, 3, 0, 2, 0, 0, 3, 0),\n", 295 | " (2, 1, 1, 0, 2, 3, 2, 1),\n", 296 | " (2, 1, 3, 1, 3, 3, 3, 1),\n", 297 | " (2, 2, 0, 0, 2, 3, 1, 1),\n", 298 | " (3, 1, 3, 0, 0, 0, 3, 0),\n", 299 | " (3, 3, 3, 0, 0, 2, 3, 1)}\n", 300 | "Best response = 0.905897\n", 301 | "CPU times: user 4.91 s, sys: 17.3 ms, total: 4.93 s\n", 302 | "Wall time: 4.93 s\n" 303 | ] 304 | } 305 | ], 306 | "source": [ 307 | "%%time\n", 308 | "design = make_design(experiments[-1])\n", 309 | "print(\"Design:\")\n", 310 | "pprint(design)\n", 311 | "experiments.append(update_experiment(experiments[-1], design))" 312 | ] 313 | }, 314 | { 315 | "cell_type": "markdown", 316 | "id": "8047d410", 317 | "metadata": {}, 318 | "source": [ 319 | "Let's run multiple loops, say 10 more loops." 
320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 11, 325 | "id": "cf15268e", 326 | "metadata": {}, 327 | "outputs": [ 328 | { 329 | "name": "stdout", 330 | "output_type": "stream", 331 | "text": [ 332 | "Best response = 0.905897\n", 333 | "Best response = 0.905897\n", 334 | "Best response = 0.905897\n", 335 | "Best response = 0.94027\n", 336 | "Best response = 0.94027\n", 337 | "Best response = 0.959295\n", 338 | "Best response = 0.959295\n", 339 | "Best response = 0.959295\n", 340 | "Best response = 0.959295\n", 341 | "Best response = 0.959295\n", 342 | "CPU times: user 53.3 s, sys: 175 ms, total: 53.4 s\n", 343 | "Wall time: 53.5 s\n" 344 | ] 345 | } 346 | ], 347 | "source": [ 348 | "%%time\n", 349 | "for step in range(10):\n", 350 | " design = make_design(experiments[-1])\n", 351 | " experiments.append(update_experiment(experiments[-1], design))" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "id": "0a479832", 357 | "metadata": {}, 358 | "source": [ 359 | "How did the response improve over time?" 
360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 13, 365 | "id": "5b1d0c01", 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "data": { 370 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEGCAYAAAB/+QKOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dfVhUZf4/8PfA8CAIKAIqDog4iDyjDJhrQepXwR6GVcsltVJyiUKzb2XRfjftu3WVrd/1oahl+fmQ5iqurUraRpqZ5UOyqOiWDwxPBmPLjAjIgAgznN8frFMjg4PGMMzM+3VdXNfOmfuc+Ryvdt5zzn2f+xYJgiCAiIjoFg6WLoCIiPonBgQRERnFgCAiIqMYEEREZBQDgoiIjBJbuoDe5OPjg6CgIEuXQURkNaqqqnDlyhWj79lUQAQFBaG4uNjSZRARWQ2ZTNbte7zFRERERjEgiIjIKAYEEREZxYAgIiKjGBBERGQUA4KIiIxiQBARkVE29RwEEd05QRBwpqYR35Sq0a7rsHQ5dBfcXMTITBrd68dlQBDZqXK1BgUll/FJiRJVdS0AAJHIwkXRXfEZ6MKAIKJf5t+Nrdh39jIKSi7jX8pGiETAr0YPwbP3S5EcOQxeA5wsXSL1I2YNiMLCQixduhQ6nQ6LFi1Cdna2wfv19fVIT09HeXk5XF1dsXHjRkRGRgLonDbDw8MDjo6OEIvFnEKD6C41Xm9H4Xc/oqDkMo5X1EEQgGiJF37/YBgejvHHUE9XS5dI/ZTZAkKn0yErKwsHDhyARCJBfHw85HI5wsPD9W3eeustxMbGYvfu3bhw4QKysrJw8OBB/fuHDh2Cj4+PuUokslmt7Tp8eUGFghIlDl1Qo03XgaAhbnhuSgjksf4Y7TvQ0iWSFTBbQBQVFUEqlSI4OBgAkJaWhoKCAoOAOHfuHF599VUAwNixY1FVVYXa2loMHTrUXGUR2Sxdh4Dj5XXYU6LE59/9G003tPD1cMH8e0YiNdYf0RIviNjJQHfAbAGhVCoREBCgfy2RSHDixAmDNjExMdi1axfuvfdeFBUV4dKlS6ipqcHQoUMhEokwffp0iEQiPP3008jIyDD6OXl5ecjLywMAqNVqc50OUb8kCALO1jRiT4kS+87+CHXTDXi4iJESOQypsSMwcfQQODowFOjumC0gBEHosu3WXy/Z2dlYunQpYmNjERUVhXHjxkEs7izp6NGj8Pf3h0qlwrRp0zB27FgkJiZ2OWZGRoY+PG43bS2RLalQa7DnZyOQnB0dMGWsH1Jj/TF5rB9cnRwtXSLZALMFhEQiQXV1tf51TU0N/P39Ddp4enpi06ZNADoDZdSoURg1ahQA6Nv6+flh5syZKCoqMhoQRPai9lor9p4xHIE0MZgjkMh8zBYQ8fHxUCgUqKysxIgRI5Cfn49t27YZtGloaICbmxucnZ2xfv16JCYmwtPTE83Nzejo6ICHhweam5uxf/9+LF++3FylEvVbxkYgRY3gCCTqG2YLCLFYjJycHCQnJ0On0yE9PR0RERHIzc0FAGRmZuL8+fN44okn4OjoiPDwcGzYsAEAUFtbi5kzZwIAtFot5s6di5SUFHOVSmRSu64DG45U4uvSvuvn0uoElFQ36EcgLZkSglSOQKI+JBKMdRZYKZlMxuclqNeVVDcg++9nceHfTYjw94Sbc9/d348c4YVfx47gCCQym9t9b/JJaqJuNN/Q4k/7S/HhsUr4ergg7/E4TI8YZumyiPoMA4LIiEMXVfj97u+gbLiO+fcE4uWUsfB0ZS
cw2RcGBNHPXNHcwBv7zqGg5DKkfgPxceZEyIK8LV0WkUUwIIjQOcz676eUePPTc2i+ocXz/xWCZ+4fDRcxnycg+8WAILv3Q10Lfrf7XzhSdgVxIwdj5awohAz1sHRZRBbHgCC7pdV1YOPRSqw+UAqxgwPe+HUk5iUEwoFTUxABYECQnfpO2YjsXWfxnfIapoUPxR9SIzDca4ClyyLqVxgQZFeut+mw5otSbDhSCW93Z/x53nikRA7jMwZERjAgyG4cUVzB73b/Cz9cbcFjCQHInhHG+YuIboMBQTavvrkNb356Hn8/VYNgH3fkZ9yDe4KHWLoson6PAUE2SxAEfHLmMv6w9xwar7djyRQpsiZLORU2UQ8xIMgm1dS34H92f4fDpWrEBgzCX2dHYewwT0uXRWRVGBBkU3QdAj48VoU/7b8IAFjxcDiemBjEVdWI7gIDgmzG+R+vIfvvZ3GmphFTxvrhjV9HYsQgDl0lulsMCLJ6re06vHtQgbyvKzDIzQnvPTYOD0UP59BVol+IAWEB11rb0abtsHQZNuHCj034/Z5/oaquBY/GSfA/D4ZhkJuzpcsisgkMiD52+od6zPrzMdjOMk2WN3KIG7YtmoBfSX0sXQqRTTFrQBQWFmLp0qXQ6XRYtGgRsrOzDd6vr69Heno6ysvL4erqio0bNyIyMrJH+1qrUz80QBCA3z8YBhexg6XLsXoDnMV4KHo4h64SmYHZAkKn0yErKwsHDhyARCJBfHw85HI5wsPD9W3eeustxMbGYvfu3bhw4QKysrJw8ODBHu1rrcpUGgxyc8JT947iPXIi6tfM9hO2qKgIUqkUwcHBcHZ2RlpaGgoKCgzanDt3DlOnTgUAjB07FlVVVaitre3RvtaqTNWEEL+BDAci6vfMFhBKpRIBAQH61xKJBEql0qBNTEwMdu3aBaAzUC5duoSampoe7XtTXl4eZDIZZDIZ1Gq1Gc6k9wiCAIVKA6kf1xogov7PbAEhGOmFvfVXc3Z2Nurr6xEbG4v33nsP48aNg1gs7tG+N2VkZKC4uBjFxcXw9fXtneLNpK65DQ0t7ZD6DbR0KUREJpmtD0IikaC6ulr/uqamBv7+/gZtPD09sWnTJgCdgTJq1CiMGjUKLS0tJve1RopaDQAghAFBRFbAbFcQ8fHxUCgUqKysRFtbG/Lz8yGXyw3aNDQ0oK2tDQCwfv16JCYmwtPTs0f7WqMy9X8CYigDgoj6P7NdQYjFYuTk5CA5ORk6nQ7p6emIiIhAbm4uACAzMxPnz5/HE088AUdHR4SHh2PDhg233dfaldU2YaCLGMM8XS1dChGRSSLB2A1/KyWTyVBcXGzpMro1b/230NzQoSBrkqVLISICcPvvTT6p1YcUtRpIfXl7iYisAwOijzReb4eq6Qb7H4jIajAg+kiZiiOYiMi6MCD6SJmqCQD4DAQRWQ0GRB8pU2ngInaAZLCbpUshIuoRBkQfUag0CPYdyKUvichqMCD6iKJWw/4HIrIqDIg+0NKmhbLhOgOCiKwKA6IPlKuaAbCDmoisCwOiD5SpO0cw8RkIIrImDIg+oKjVQOwgwsgh7pYuhYioxxgQfUCh0iDIxx1OjvznJiLrwW+sPlCu4ggmIrI+JgPi6tWrfVGHzbqh1aGqrpkd1ERkdUwGxIQJE/Doo4/iH//4h9GlQOn2qq60oEPgCCYisj4mA6K0tBQZGRn46KOPIJVK8bvf/Q6lpaV9UZtNUHAOJiKyUiYDQiQSYdq0adi+fTvWr1+PzZs3IyEhAUlJSTh+/Pht9y0sLERoaCikUilWrlzZ5f3GxkY8/PDDiImJQUREhH59agAICgpCVFQUYmNjIZPJ7uLU+gdFrQYiETCa60AQkZUxueRoXV0dtm7dio8++ghDhw7Fe++9B7lcjpKSEjz66KOorKw0up9Op0NWVhYOHDgAiUSC+Ph4yOVyhIeH69u8//77CA8Px969e6FWqxEaGop58+bB2dkZAHDo0C
H4+Pj00qlaRplag0BvN7g6OVq6FCKiO2IyICZOnIjHH38ce/bsgUQi0W+XyWTIzMzsdr+ioiJIpVIEBwcDANLS0lBQUGAQECKRCE1NTRAEARqNBt7e3hCLzbZMtkWUcRU5IrJSJm8xvfnmm3jttdcMwmHnzp0AgFdeeaXb/ZRKJQICAvSvJRIJlEqlQZvFixfj/Pnz8Pf3R1RUFNatWwcHh86SRCIRpk+fjri4OOTl5XX7OXl5eZDJZJDJZFCr1aZOp09pdR2ovNIMKZ+gJiIrZDIgjPUdvP322yYPbGzEk0hkONX1559/jtjYWFy+fBklJSVYvHgxrl27BgA4evQoTp06hc8++wzvv/8+vv76a6Ofk5GRgeLiYhQXF8PX19dkXX3ph6staNN18AqCiKxSt/dzPvvsM/zjH/+AUqnEc889p99+7dq1Ht0GkkgkqK6u1r+uqamBv7+/QZtNmzYhOzsbIpEIUqkUo0aNwoULF5CQkKBv6+fnh5kzZ6KoqAiJiYl3fIKWpLi5zOhQDwtXQkR057q9gvD394dMJoOrqyvi4uL0f3K5HJ9//rnJA8fHx0OhUKCyshJtbW3Iz8+HXC43aBMYGIiDBw8CAGpra3Hx4kUEBwejubkZTU2dw0Obm5uxf/9+REZG/pLztIib61BziCsRWaNuLwViYmIQExODefPm3VXHsVgsRk5ODpKTk6HT6ZCeno6IiAjk5uYCADIzM/Haa69hwYIFiIqKgiAIeOedd+Dj44OKigrMnDkTAKDVajF37lykpKTc5SlaTplKg+FerhjoYlsd70RkH0RCN49Hz5kzB3/7298QFRXVpe8AAM6ePWv24u6UTCZDcXGxpcvQe/i9Ixjk5oSPnppg6VKIiIy63fdmtz9t161bBwDYt2+feaqycR0dAspUGqQlBJhuTETUD3UbEMOHDwcAjBw5ss+KsSXKhuu43q5DiB87qInIOpkc5rpr1y6EhITAy8sLnp6e8PDwgKenZ1/UZtXK1DdHMLGDmoisk8ne05dffhl79+5FWFhYX9RjM8pq/zOCic9AEJGVMnkFMXToUIbDXShTaeAz0BmD3Z0tXQoR0V0xeQUhk8nwm9/8Br/+9a/h4uKi3z5r1iyzFmbtFKomzuBKRFbNZEBcu3YNbm5u2L9/v36bSCRiQNyGIAhQqDRIjfU33ZiIqJ8yGRA/X6OBekbddANNrVqOYCIiq9ZtQPzxj3/Eyy+/jCVLlhh9UO7dd981a2HWTMEpNojIBnQbEDfXbbDm1dws5eYcTCEMCCKyYt0GxI4dO/DQQw+hoaEBS5cu7cuarJ5C1QRPVzF8PVxMNyYi6qe6HeZ68uRJXLp0CRs3bkR9fT2uXr1q8EfdU9RqIPUbaPTWHBGRtej2CiIzMxMpKSmoqKhAXFycwQJAIpEIFRUVfVKgNSpXazB17FBLl0FE9It0ewXx8MMP4/z580hPT0dFRQUqKyv1fwyH7l1tbsMVTRs7qInI6nUbEI888ggAoLS0tM+KsQX6RYI4BxMRWblubzF1dHTgf//3f1FaWorVq1d3ef+FF14wa2HWiiOYiMhWdHsFkZ+fD1dXV2i1WjQ1NXX5I+MUqiYMcHKEv9cAS5dCRPSLdHsFERoaildeeQXR0dGYMWPGXR28sLAQS5cuhU6nw6JFi5CdnW3wfmNjI+bPn48ffvgBWq0WL730EhYuXNijffurMlXnCCYHB45gIiLrZnKqjRkzZuDTTz/F999/j9bWVv325cuX33Y/nU6HrKwsHDhwABKJBPHx8ZDL5foH8ADg/fffR3h4OPbu3Qu1Wo3Q0FDMmzcPjo6OJvftr8pUGtwTPMTSZRAR/WImp/vOzMzEjh078N5770EQBOzcuROXLl0yeeCioiJIpVIEBwfD2dkZaWlpKCgoMGgjEonQ1NQEQRCg0Wjg7e0NsVjco337o6bWdvzY2MoRTERkE0wGxLFjx7BlyxYMHjwYK1aswPHjx1FdXW3ywE
qlEgEBP63HLJFIoFQqDdosXrwY58+fh7+/P6KiorBu3To4ODj0aN+b8vLyIJPJIJPJoFarTdZlTuXqZgDsoCYi22AyIAYM6OxsdXNzw+XLl+Hk5ITKykqTB/75g3U33fpk8eeff47Y2FhcvnwZJSUlWLx4Ma5du9ajfW/KyMhAcXExiouL4evra7Iuc1LUdnbe8wqCiGyByYC4OR/TsmXLMH78eAQFBeGxxx4zeWCJRGJwpVFTUwN/f8P1ETZt2oRZs2ZBJBJBKpVi1KhRuHDhQo/27Y/K1Bo4Ozog0NvN0qUQEf1iJjupX3vtNQDA7Nmz8dBDD6G1tRVeXl4mDxwfHw+FQoHKykqMGDEC+fn52LZtm0GbwMBAHDx4EPfddx9qa2tx8eJFBAcHY9CgQSb37Y/KajUY5eMOsaPJ3CUi6vdMBsTPubi4GCw7etsDi8XIyclBcnIydDod0tPTERERgdzcXACdnd+vvfYaFixYgKioKAiCgHfeeQc+Pj4AYHTf/k6h0iBKYjo8iYisgUgwdsPfSslkMhQXF1vks1vbdQhbXoilU0Pw/H+NsUgNRER36nbfm7wX0kvK1RoIAjuoich2mLzFdOrUqS7bvLy8MHLkSIjFd3SHyqb9NAcT16EmIttg8hv+2WefxalTpxAdHQ1BEPDdd98hOjoadXV1yM3NxfTp0/uizn6vTKWBgwgI8uEIJiKyDSZvMQUFBeH06dMoLi7GyZMncfr0aURGRuKLL77Ayy+/3Bc1WgVFrQZBQ9zhIna0dClERL3CZEBcuHDBYARReHg4Tp8+jeDgYLMWZm3K1Br2PxCRTTF5iyk0NBTPPPMM0tLSAAA7duzAmDFjcOPGDTg5OZm9QGvQrutA1ZVmTA/nMqNEZDtMXkF8+OGHkEqlWLt2LdasWYPg4GB8+OGHcHJywqFDh/qixn7vUl0ztB0CQriKHBHZEJNXEAMGDMCLL76IF198sct7AwfyCxHo7H8AAKkvRzARke0wGRBHjx7F66+/jkuXLkGr1eq3V1RUmLUwa6L4zxDX0X7uFq6EiKj3mAyIp556CmvWrEFcXBwcHTlCx5gylQaSwQPg5sznQojIdpj8RvPy8rrrJUfthULFEUxEZHtMBsTkyZOxbNkyzJo1y2CivvHjx5u1MGuh6xBQodbgXimXGSUi22IyIE6cOAEABpM5iUQifPnll+aryorU1LfghraDVxBEZHNMBgSHst6efgQT52AiIhvTbUBs3boV8+fPx+rVq42+/8ILL5itKGtSpr4ZELyCICLb0m1ANDc3AwCampr6rBhrpKjVwM/DBV4D+FQ5EdmWbgPi6aefBgCsWLGiz4qxRmVqDZ+gJiKb1G1APPfcc7fd8d133zV58MLCQixduhQ6nQ6LFi1Cdna2wfurVq3CX//6VwCAVqvF+fPnoVar4e3tjaCgIHh4eMDR0RFisdhiK8XdjiAIKKttwiNxEkuXQkTU67qdiykuLg5xcXFobW3FqVOnEBISgpCQEJSUlPTogTmdToesrCx89tlnOHfuHLZv345z584ZtFm2bBlKSkpQUlKCt99+G0lJSfD29ta/f+jQIZSUlPTLcACAHxtb0dymg3QoO6iJyPZ0ewXx5JNPAuicrO/QoUP6mVszMzN7tEhQUVERpFKpflrwtLQ0FBQUIDw83Gj77du347HHHrvjE7Ckn1aR4y0mIrI9JmdzvXz5skFHtUajweXLl00eWKlUIiAgQP9aIpFAqVQabdvS0oLCwkLMnj1bv00kEmH69OmIi4tDXl5et5+Tl5cHmUwGmUwGtVptsq7edHMOJo5gIiJbZPI5iOzsbIwbNw6TJ08GABw+fBivv/66yQMLgtBlm0gkMtp27969mDRpksHtpaNHj8Lf3x8qlQrTpk3D2LFjkZiY2GXfjIwMZGRkAABkMpnJunpTmUqDwW5OGOLu3KefS0TUF0wGxMKFCzFjxgz9E9UrV67EsGHDTB5YIpGgurpa/7
qmpgb+/v5G2+bn53e5vXSzrZ+fH2bOnImioiKjAWFJZaomhPh5dBt8RETWzOQtJqCzw9nX1xeDBw9GaWkpvv76a5P7xMfHQ6FQoLKyEm1tbcjPz4dcLu/SrrGxEYcPH0Zqaqp+W3Nzs/62VnNzM/bv34/IyMienlOfEAQBCpUGo3l7iYhslMkriFdeeQU7duxAREQEHBw680QkEpn8NS8Wi5GTk4Pk5GTodDqkp6cjIiICubm5ADo7uwFg9+7dmD59Otzdf1pLoba2FjNnzgTQOfx17ty5SElJubszNJO65jY0tLSzg5qIbJZIMNZZ8DOhoaE4e/aswUyu/ZVMJuuzIbHHy+vw2P/7FlvSE5A4xrdPPpOIqLfd7nvT5C2m4OBgtLe393pR1u7mHEx8ipqIbJXJW0xubm6IjY3F1KlTDa4ievIktS0rq23CQBcxhnm6WroUIiKzMBkQcrncaOeyvbvZQc0RTERkq0wGxM0nqslQmUrDvgcismndBsScOXPwt7/9DVFRUUZ/JZ89e9ashfVnjS3tUDXd4BPURGTTug2IdevWAQD27dvXZ8VYizJ15zMaHOJKRLas21FMw4cPBwCMHDkSLi4uOHPmjH6468iRI/uswP7op0n6OIsrEdkuk8Nc169fj4SEBOzatQsff/wx7rnnHmzcuLEvauu3FLUauIgdMGLwAEuXQkRkNiY7qVetWoXTp09jyJAhAIC6ujr86le/Qnp6utmL66/K1BqM9h0IRweOYCIi22XyCkIikcDD46dbKR4eHgbTeNsjRa2GHdREZPO6vYJYvXo1AGDEiBGYMGECUlNTIRKJUFBQgISEhD4rsL9pvqGFsuE60uLtOySJyPZ1GxA3Z1MdPXo0Ro8erd/+81lX7VGFuhkAp9ggItvXbUCsWLGiL+uwGgpVZ3DyFhMR2boerQdBPylTaSB2EGHkEHfTjYmIrBgD4g4pVBoE+bjDyZH/dERk20x+yx09erRH2+xFmUrDJ6iJyC6YDIglS5b0aJsxhYWFCA0NhVQqxcqVK7u8v2rVKsTGxiI2NhaRkZFwdHTE1atXe7SvJdzQ6nCprpkBQUR2odtO6uPHj+PYsWNQq9X6Ia8AcO3aNeh0OpMH1ul0yMrKwoEDByCRSBAfHw+5XI7w8HB9m2XLlmHZsmUAgL1792LNmjXw9vbu0b6WUHmlGR0CuA41EdmFbq8g2traoNFooNVq0dTUpP/z9PTExx9/bPLARUVFkEqlCA4OhrOzM9LS0lBQUNBt++3bt+Oxxx67q337CudgIiJ70u0VRFJSEpKSkrBgwQL95HwdHR3QaDTw9PQ0eWClUmnwxLVEIsGJEyeMtm1paUFhYSFycnLueN+8vDzk5eUBANRqtcm6fglFrQYiERDsyxFMRGT7TPZBvPrqq7h27Rqam5sRHh6O0NBQrFq1yuSBBUHosq271df27t2LSZMmwdvb+473zcjIQHFxMYqLi+Hra94FfMpUGgR6u8HVydGsn0NE1B+YDIhz587B09MTe/bswQMPPIAffvgBH330kckDSyQSVFdX61/X1NTA39/faNv8/Hz97aU73bcvcQQTEdkTkwHR3t6O9vZ27NmzB6mpqXBycurROszx8fFQKBSorKxEW1sb8vPzja5t3djYiMOHDxtM4dHTffuSVteBiisadlATkd0wOd33008/jaCgIMTExCAxMRGXLl3qUR+EWCxGTk4OkpOTodPpkJ6ejoiICOTm5gIAMjMzAQC7d+/G9OnT4e7ubnJfS/rhagvadQI7qInIbogEYzf8TdBqtRCLTWZLn5PJZCguLjbLsT///t94+qOT2JM1CbEBg8zyGUREfe1235smbzHV1tbiqaeewowZMwB09kls3ry5dyu0AjeHuHKSPiKyFyYDYsGCBUhOTsbly5cBAGPGjMHatWvNXlh/U6bSwN/LFQNd+t+VExGROZgMiCtXrmDOnDlwcOhsKhaL4ehof8M8FaomdlATkV
0xGRDu7u6oq6vTj1z69ttv4eXlZfbC+pOODgHlqmZ2UBORXTF5v2T16tWQy+UoLy/HpEmToFarezTVhi1RNlzH9XYd+x+IyK6YDIjx48fj8OHDuHjxIgRBQGhoKJycnPqitn5DPwcTlxklIjtiMiBaW1vxwQcf4MiRIxCJRLjvvvuQmZkJV1fXvqivX9CPYPJlQBCR/TAZEE888QQ8PDz0a0Bs374djz/+OHbu3Gn24voLhaoJPgOdMdjd2dKlEBH1GZMBcfHiRZw5c0b/evLkyYiJiTFrUf1NmUrD/gcisjsmRzGNGzcO3377rf71iRMnMGnSJLMW1Z8IggCFSsMRTERkd7q9goiKioJIJEJ7ezu2bNmCwMBAiEQiXLp0yeIru/UlVdMNNLVqeQVBRHan24DYt29fX9bRb/20ihwDgojsS7cBcXMVOXunqG0CwDmYiMj+mOyDsHdlag08XcXw9XCxdClERH2KAWGColaDkKEePVokiYjIljAgTChTafiAHBHZJQbEbVxtbkNdcxun2CAiu2TWgCgsLERoaCikUilWrlxptM1XX32F2NhYREREICkpSb89KCgIUVFRiI2NhUwmM2eZ3bo5gonTfBORPTLb6jc6nQ5ZWVk4cOAAJBIJ4uPjIZfLDZ6haGhowLPPPovCwkIEBgZCpVIZHOPQoUPw8fExV4kmcYgrEdkzs11BFBUVQSqVIjg4GM7OzkhLS0NBQYFBm23btmHWrFkIDAwEAPj5+ZmrnLuiUDXBzdkR/l4DLF0KEVGfM1tAKJVKBAQE6F9LJBIolUqDNqWlpaivr8f999+PuLg4bNmyRf+eSCTC9OnTERcXh7y8vG4/Jy8vDzKZDDKZDGq1ulfPoUylwWjfgXBw4AgmIrI/ZrvFJAhCl223DhXVarU4efIkDh48iOvXr2PixIm45557MGbMGBw9ehT+/v5QqVSYNm0axo4di8TExC7HzMjIQEZGBgD0el9FmUqDicFDevWYRETWwmxXEBKJBNXV1frXNTU18Pf379ImJSUF7u7u8PHxQWJion7m2Jtt/fz8MHPmTBQVFZmrVKOaWtvxY2MrO6iJyG6ZLSDi4+OhUChQWVmJtrY25OfnQy6XG7RJTU3FN998A61Wi5aWFpw4cQJhYWFobm5GU1PnFBfNzc3Yv38/IiMjzVWqUeXqZgDsoCYi+2W2W0xisRg5OTlITk6GTqdDeno6IiIikJubCwDIzMxEWFgYUlJSEB0dDQcHByxatAiRkZGoqKjAzJkzAXTehpo7dy5SUlLMVapRN+dgChnKab6JyCyiVKkAAA78SURBVD6JBGOdBVZKJpOhuLi4V4719j/OY9PRKpz7QzLEjnyekIhs0+2+N/nN140ylQbBvu4MByKyW/z264ZCpWEHNRHZNQaEEa3tOlTXt7CDmojsGgPCiHK1BoIArkNNRHaNAWHEzTmYuIocEdkzBoQRZSoNHB1ECPJxs3QpREQWw4AwQlGrwUhvN7iIHS1dChGRxTAgjFComnh7iYjsHgPiFm3aDlyqa+EqckRk9xgQt7hU1wxth8ArCCKyewyIW/y0ihyHuBKRfWNA3ELxn4AI9nW3cCVERJbFgLiFQqWBZPAAuDmbbaJbIiKrwIC4RZlKwyk2iIjAgDCg6xBQrtawg5qICAwIAzX1LWjTdrCDmogIDAgDitrODmpO801EZOaAKCwsRGhoKKRSKVauXGm0zVdffYXY2FhEREQgKSnpjvbtbQpO0kdEpGe2oTo6nQ5ZWVk4cOAAJBIJ4uPjIZfLER4erm/T0NCAZ599FoWFhQgMDIRKperxvuZQptJgqKcLvAY4mfVziIisgdmuIIqKiiCVShEcHAxnZ2ekpaWhoKDAoM22bdswa9YsBAYGAgD8/Px6vK85lHEOJiIiPbMFhFKpREBAgP61RCKBUqk0aFNaWor6+nrcf//9iIuLw5YtW3q87015eXmQyWSQyWRQq9V3Xa8gCP8Z4soOai
IiwIy3mARB6LJNJBIZvNZqtTh58iQOHjyI69evY+LEibjnnnt6tO9NGRkZyMjIAADIZLK7rvfHxlY0t+l4BUFE9B9mCwiJRILq6mr965qaGvj7+3dp4+PjA3d3d7i7uyMxMRFnzpzp0b69jR3URESGzHaLKT4+HgqFApWVlWhra0N+fj7kcrlBm9TUVHzzzTfQarVoaWnBiRMnEBYW1qN9e9tPk/QxIIiIADNeQYjFYuTk5CA5ORk6nQ7p6emIiIhAbm4uACAzMxNhYWFISUlBdHQ0HBwcsGjRIkRGRgKA0X3NqUzVhMFuThgy0MWsn0NEZC1EgrEb/lZKJpOhuLj4rvZ9NPcYRBDhb5kTe7kqIqL+63bfm3ySGp0d6qW1Gki5ihwRkR4DAsAVTRsar7dD6suAICK6iQGBn3VQ8wqCiEiPAYHODmqAQ1yJiH6OAYHOK4iBLmIM83S1dClERP0GAwKdD8lJ/QZ2+7Q2EZE9YkDgp4AgIqKf2H1AaHUdSAzxxSTpEEuXQkTUr5jtSWprIXZ0wJ/mxFi6DCKifsfuryCIiMg4BgQRERnFgCAiIqMYEEREZBQDgoiIjGJAEBGRUQwIIiIyigFBRERG2dSKcj4+PggKCrqrfdVqNXx9fXu3oH6C52a9bPn8eG79Q1VVFa5cuWL0PZsKiF/ilyxX2t/x3KyXLZ8fz63/4y0mIiIyigFBRERGOb7++uuvW7qI/iIuLs7SJZgNz8162fL58dz6N/ZBEBGRUbzFRERERjEgiIjIKLsPiMLCQoSGhkIqlWLlypWWLqdXVVdXY/LkyQgLC0NERATWrVtn6ZJ6nU6nw7hx4/DQQw9ZupRe1dDQgEceeQRjx45FWFgYjh8/bumSetWaNWsQERGByMhIPPbYY2htbbV0SXctPT0dfn5+iIyM1G+7evUqpk2bhpCQEEybNg319fUWrPDu2XVA6HQ6ZGVl4bPPPsO5c+ewfft2nDt3ztJl9RqxWIw//elPOH/+PL799lu8//77NnV+ALBu3TqEhYVZuoxet3TpUqSkpODChQs4c+aMTZ2jUqnEu+++i+LiYnz33XfQ6XTIz8+3dFl3bcGCBSgsLDTYtnLlSkydOhUKhQJTp0612h+fdh0QRUVFkEqlCA4OhrOzM9LS0lBQUGDpsnrN8OHDMX78eACAh4cHwsLCoFQqLVxV76mpqcGnn36KRYsWWbqUXnXt2jV8/fXXeOqppwAAzs7OGDRokIWr6l1arRbXr1+HVqtFS0sL/P39LV3SXUtMTIS3t7fBtoKCAjz55JMAgCeffBJ79uyxRGm/mF0HhFKpREBAgP61RCKxqS/Qn6uqqsLp06cxYcIES5fSa55//nn88Y9/hIODbf1nXFFRAV9fXyxcuBDjxo3DokWL0NzcbOmyes2IESPw0ksvITAwEMOHD4eXlxemT59u6bJ6VW1tLYYPHw6g84eaSqWycEV3x7b+n3WHjI3wFYlEFqjEvDQaDWbPno21a9fC09PT0uX0in379sHPz88mxprfSqvV4tSpU3jmmWdw+vRpuLu7W+0tCmPq6+tRUFCAyspKXL58Gc3Nzdi6daulyyIj7DogJBIJqqur9a9ramqs+lLXmPb2dsyePRvz5s3DrFmzLF1Orzl69Cg++eQTBAUFIS0tDV9++SXmz59v6bJ6hUQigUQi0V/tPfLIIzh16pSFq+o9X3zxBUaNGgVfX184OTlh1qxZOHbsmKXL6lVDhw7Fjz/+CAD48ccf4efnZ+GK7o5dB0R8fDwUCgUqKyvR1taG/Px8yOVyS5fVawRBwFNPPYWwsDC88MILli6nV7399tuoqalBVVUV8vPzMWXKFJv5FTps2DAEBATg4sWLAICDBw8iPDzcwlX1nsDAQHz77bdoaWmBIAg4ePCgTXXCA4BcLsfmzZsBAJs3b0ZqaqqFK7pLgp379NNPhZCQECE4OFh48803LV1Or/rmm28EAEJUVJQQExMjxMTECJ9++q
mly+p1hw4dEh588EFLl9GrTp8+LcTFxQlRUVFCamqqcPXqVUuX1KuWL18uhIaGChEREcL8+fOF1tZWS5d019LS0oRhw4YJYrFYGDFihLB+/XrhypUrwpQpUwSpVCpMmTJFqKurs3SZd4VTbRARkVF2fYuJiIi6x4AgIiKjGBBERGQUA4KIiIxiQBARkVEMCOp3RCIRXnzxRf3r//u//0NvLXy4YMECfPzxx71yrNvZuXMnwsLCMHny5C7vff/995gyZQrGjBmDkJAQvPHGG/qn+j/88EP4+voiNjZW/3fu3DlUVVVhwIABGDduHMLCwpCQkKAfZ98TDQ0N+OCDD3rt/Mg+MCCo33FxccGuXbtw5coVS5diQKfT9bjthg0b8MEHH+DQoUMG269fvw65XI7s7GyUlpbizJkzOHbsmMGX929+8xuUlJTo/24+JDd69GicPn0a58+fR35+PtasWYNNmzb1qB4GBN0NBgT1O2KxGBkZGVizZk2X9269Ahg4cCAA4KuvvkJSUhLmzJmDMWPGIDs7G3/961+RkJCAqKgolJeX6/f54osvcN9992HMmDHYt28fgM4v/2XLliE+Ph7R0dH4y1/+oj/u5MmTMXfuXERFRXWpZ/v27YiKikJkZCReeeUVAMAf/vAHHDlyBJmZmVi2bJlB+23btmHSpEn6yenc3NyQk5Nzx3MtBQcHY/Xq1Xj33Xe7vPf9998jISEBsbGxiI6OhkKhQHZ2NsrLyxEbG6uvadWqVfrzXbFiBYDOSR3Hjh2LJ598EtHR0XjkkUfQ0tICAMjOzkZ4eDiio6Px0ksv3VG9ZKUs/KAeURfu7u5CY2OjMHLkSKGhoUFYtWqVsGLFCkEQBOHJJ58Udu7cadBWEDqfpvby8hIuX74stLa2Cv7+/sLy5csFQRCEtWvXCkuXLtXvn5ycLOh0OqG0tFQYMWKEcP36deEvf/mL8MYbbwiCIAitra1CXFycUFFRIRw6dEhwc3MTKioqutSpVCqFgIAAQaVSCe3t7cLkyZOF3bt3C4IgCElJScI///nPLvv893//t7B27dou2wcNGiQ0NjYKmzZtEnx8fPRPvsfExAgtLS1CZWWlEBERYbBPfX294Orq2uVYixcvFrZu3SoIgiDcuHHD6P6ff/658Nvf/lbo6OgQdDqd8OCDDwqHDx8WKisrBQDCkSNHBEEQhIULFwqrVq0S6urqhDFjxggdHR36zybbxysI6pc8PT3xxBNPGP2F3J34+HgMHz4cLi4uGD16tP5XelRUFKqqqvTt5syZAwcHB4SEhCA4OBgXLlzA/v37sWXLFsTGxmLChAmoq6uDQqEAACQkJGDUqFFdPu+f//wn7r//fvj6+kIsFmPevHn4+uuvb1ujIAjdzhh8c/utt5gGDBjQ7bGMmThxIt566y288847uHTpktH99+/fj/3792PcuHEYP348Lly4oD/fgIAATJo0CQAwf/58HDlyBJ6ennB1dcWiRYuwa9cuuLm53fY8yTYwIKjfev7557FhwwaDtRDEYjE6OjoAdH5BtrW16d9zcXHR/28HBwf9awcHB2i1Wv17t35Bi0QiCIKA9957T/+lXFlZqQ8Yd3d3o/V19wV9OxERESguLjbYVlFRgYEDB8LDw+OOjnX69Gmjk9zNnTsXn3zyCQYMGIDk5GR8+eWXXdoIgoBXX31Vf75lZWX6BYqM/fuIxWIUFRVh9uzZ2LNnD1JSUu6oVrJODAjqt7y9vTFnzhxs2LBBvy0oKAgnT54E0LlqV3t7+x0fd+fOnejo6EB5eTkqKioQGhqK5ORk/PnPf9Yfr7S01OQiPRMmTMDhw4dx5coV6HQ6bN++HUlJSbfdZ968eThy5Ai++OILAJ2d1s899xxefvnlOzqHqqoqvPTSS1iyZEmX9yoqKhAcHIznnnsOcrkcZ8+ehYeHB5qamvRtkpOTsXHjRmg0GgCdi2fdXNTmhx9+0K+BvX37dtx7773QaDRobGzEAw
88gLVr16KkpOSO6iXrJLZ0AUS38+KLLyInJ0f/+re//S1SU1ORkJCAqVOndvvr/nZCQ0ORlJSE2tpa5Obm6m+dVFVVYfz48RAEAb6+viaXiRw+fDjefvttTJ48GYIg4IEHHjA5rfOAAQNQUFCAJUuWICsrCzqdDo8//jgWL16sb7Njxw4cOXJE//qDDz6Av78/ysvLMW7cOLS2tsLDwwNLlizBwoULu3zGjh07sHXrVjg5OWHYsGFYvnw5vL29MWnSJERGRmLGjBlYtWoVzp8/j4kTJwLo7OzfunUrHB0dERYWhs2bN+Ppp59GSEgInnnmGTQ2NiI1NRWtra0QBMHoAAKyPZzNlYj0qqqq8NBDD+G7776zdCnUD/AWExERGcUrCCIiMopXEEREZBQDgoiIjGJAEBGRUQwIIiIyigFBRERG/X9YkVsgemePxwAAAABJRU5ErkJggg==\n", 371 | "text/plain": [ 372 | "
" 373 | ] 374 | }, 375 | "metadata": {}, 376 | "output_type": "display_data" 377 | } 378 | ], 379 | "source": [ 380 | "responses = [e[\"responses\"].max().item() for e in experiments]\n", 381 | "plt.plot(responses)\n", 382 | "plt.xlabel(\"Number of OED steps\")\n", 383 | "plt.ylabel(\"best binding affinity\");" 384 | ] 385 | } 386 | ], 387 | "metadata": { 388 | "kernelspec": { 389 | "display_name": "Python 3", 390 | "language": "python", 391 | "name": "python3" 392 | }, 393 | "language_info": { 394 | "codemirror_mode": { 395 | "name": "ipython", 396 | "version": 3 397 | }, 398 | "file_extension": ".py", 399 | "mimetype": "text/x-python", 400 | "name": "python", 401 | "nbconvert_exporter": "python", 402 | "pygments_lexer": "ipython3", 403 | "version": "3.9.1" 404 | } 405 | }, 406 | "nbformat": 4, 407 | "nbformat_minor": 5 408 | } 409 | -------------------------------------------------------------------------------- /pyroed/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.0" 2 | 3 | from .api import ( 4 | decode_design, 5 | encode_design, 6 | get_next_design, 7 | start_experiment, 8 | update_experiment, 9 | ) 10 | 11 | __all__ = [ 12 | "decode_design", 13 | "encode_design", 14 | "get_next_design", 15 | "update_experiment", 16 | "start_experiment", 17 | ] 18 | -------------------------------------------------------------------------------- /pyroed/api.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pyroed's high-level interface includes a design language and a set of 3 | functions to operate on Python data structures. 4 | 5 | The **design language** allows you to specify a problem by defining a 6 | ``SCHEMA``, a list ``CONSTRAINTS`` of :class:`~pyroed.constraints.Constraint` 7 | objects, a list ``FEATURE_BLOCKS`` defining cross features, and a list 8 | ``GIBBS_BLOCKS`` defining groups of features that are related to each other. 
9 | The examples in this module will use the following model specification:: 10 | 11 | SCHEMA = OrderedDict() 12 | SCHEMA["aa1"] = ["P", "L", None] 13 | SCHEMA["aa2"] = ["N", "Y", "T", None] 14 | SCHEMA["aa3"] = ["R", "S"] 15 | 16 | CONSTRAINTS = [Not(And(TakesValue("aa1", None), TakesValue("aa2", None)))] 17 | 18 | FEATURE_BLOCKS = [["aa1"], ["aa2"], ["aa3"], ["aa1", "aa2"], ["aa2", "aa3"]] 19 | 20 | GIBBS_BLOCKS = [["aa1", "aa2"], ["aa2", "aa3"]] 21 | 22 | After declaring the design space, we can progressively gather data into an 23 | ``experiment`` dict by using the functions in this module and by experimentally 24 | measuring sequences. 25 | 26 | - :func:`encode_design` and :func:`decode_design` convert between 27 | text-representations of designs like ``[["P", "N", "R"], ["P", "N", "S"]]`` 28 | and PyTorch representations of designs like 29 | ``torch.tensor([[0, 0, 0], [0, 0, 1]])``. 30 | - :func:`start_experiment` initializes an experiment dict, 31 | - :func:`get_next_design` suggests a next set of sequences to test, and 32 | - :func:`update_experiment` updates an experiment dict with measured responses. 33 | 34 | Note that :func:`get_next_design` merely returns suggested sequences; you can 35 | ignore these suggestions or measure a different set of sequences if you want. 36 | For example if some of your measurements are lost due to technical reasons, you 37 | can simply pass a subset of the suggested design back to 38 | :func:`update_experiment`. 39 | """ 40 | 41 | from collections import OrderedDict 42 | from typing import Callable, Dict, Iterable, List, Optional 43 | 44 | import torch 45 | 46 | from .oed import thompson_sample 47 | from .typing import Blocks, Constraints, Schema, validate 48 | 49 | 50 | def encode_design( 51 | schema: Schema, design: Iterable[List[Optional[str]]] 52 | ) -> torch.Tensor: 53 | """ 54 | Converts a human readable list of sequences into a tensor.
55 | 56 | Example:: 57 | 58 | SCHEMA = OrderedDict() 59 | SCHEMA["aa1"] = ["P", "L", None] 60 | SCHEMA["aa2"] = ["N", "Y", "T", None] 61 | SCHEMA["aa3"] = ["R", "S"] 62 | 63 | design = [ 64 | ["P", "N", "R"], 65 | ["P", "N", "S"], 66 | [None, "N", "R"], 67 | ["P", None, "R"], 68 | ] 69 | sequences = encode_design(SCHEMA, design) 70 | print(sequences) 71 | # torch.tensor([[0, 0, 0], [0, 0, 1], [2, 0, 0], [0, 3, 0]]) 72 | 73 | :param OrderedDict schema: A schema dict. 74 | :param list design: A list of list of choices (strings or None). 75 | :returns: A tensor of encoded sequences. 76 | :rtype: torch.tensor 77 | """ 78 | # Validate inputs. 79 | design = list(design) 80 | assert len(design) > 0 81 | assert isinstance(schema, OrderedDict) 82 | if __debug__: 83 | for seq in design: 84 | assert len(seq) == len(schema) 85 | for value, (name, values) in zip(seq, schema.items()): 86 | if value not in values: 87 | raise ValueError( 88 | f"Value {repr(value)} not found in schema[{repr(name)}]" 89 | ) 90 | 91 | # Convert python list -> tensor. 92 | rows = [ 93 | [values.index(value) for value, values in zip(seq, schema.values())] 94 | for seq in design 95 | ] 96 | return torch.tensor(rows, dtype=torch.long) 97 | 98 | 99 | def decode_design(schema: Schema, sequences: torch.Tensor) -> List[List[Optional[str]]]: 100 | """ 101 | Converts a tensor representation of a design into a readable list of designs. 102 | 103 | Example:: 104 | 105 | SCHEMA = OrderedDict() 106 | SCHEMA["aa1"] = ["P", "L", None] 107 | SCHEMA["aa2"] = ["N", "Y", "T", None] 108 | SCHEMA["aa3"] = ["R", "S"] 109 | 110 | sequences = torch.tensor([[0, 0, 0], [0, 0, 1], [2, 0, 0]]) 111 | design = decode_design(SCHEMA, sequences) 112 | print(design) 113 | # [["P", "N", "R"], ["P", "N", "S"], [None, "N", "R"]] 114 | 115 | :param OrderedDict schema: A schema dict. 116 | :param torch.Tensor sequences: A tensor of encoded sequences. 117 | :returns: A list of list of choices (strings or None).
118 | :rtype: list 119 | """ 120 | # Validate. 121 | assert isinstance(schema, OrderedDict) 122 | assert isinstance(sequences, torch.Tensor) 123 | assert sequences.dtype == torch.long 124 | assert sequences.dim() == 2 125 | 126 | # Convert tensor -> python list. 127 | rows = [ 128 | [values[i] for i, values in zip(seq, schema.values())] 129 | for seq in sequences.tolist() 130 | ] 131 | return rows 132 | 133 | 134 | def start_experiment( 135 | schema: Schema, 136 | sequences: torch.Tensor, 137 | responses: torch.Tensor, 138 | batch_ids: Optional[torch.Tensor] = None, 139 | ) -> Dict[str, torch.Tensor]: 140 | """ 141 | Creates a cumulative experiment with initial data. 142 | 143 | Example:: 144 | 145 | SCHEMA = OrderedDict() 146 | SCHEMA["aa1"] = ["P", "L", None] 147 | SCHEMA["aa2"] = ["N", "Y", "T", None] 148 | SCHEMA["aa3"] = ["R", "S"] 149 | 150 | sequences = torch.tensor([[0, 0, 0], [0, 0, 1], [2, 0, 0]]) 151 | responses = torch.tensor([0.1, 0.4, 0.5]) 152 | 153 | experiment = start_experiment(SCHEMA, sequences, responses) 154 | 155 | :param OrderedDict schema: A schema dict. 156 | :param torch.Tensor sequences: A tensor of encoded sequences that have been 157 | measured. 158 | :param torch.Tensor responses: A tensor of the measured responses of sequences. 159 | :param torch.Tensor batch_ids: An optional tensor of batch ids. 160 | :returns: A cumulative experiment dict. 161 | :rtype: dict 162 | """ 163 | # If unspecified, simply create a single batch id. 164 | if batch_ids is None: 165 | batch_ids = sequences.new_zeros(responses.shape) 166 | 167 | # This function is a thin wrapper around dict(). 168 | experiment = { 169 | "sequences": sequences, 170 | "responses": responses, 171 | "batch_ids": batch_ids, 172 | } 173 | 174 | # Validate. 
175 | if __debug__: 176 | validate(schema, experiment=experiment) 177 | return experiment 178 | 179 | 180 | def get_next_design( 181 | schema: Schema, 182 | constraints: Constraints, 183 | feature_blocks: Blocks, 184 | gibbs_blocks: Blocks, 185 | experiment: Dict[str, torch.Tensor], 186 | *, 187 | design_size: int = 10, 188 | feature_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, 189 | config: Optional[dict] = None, 190 | ) -> torch.Tensor: 191 | """ 192 | Generate a new design given cumulative experimental data. 193 | 194 | Under the hood this runs :func:`~pyroed.oed.thompson_sample`, which 195 | performs Bayesian inference via either variational inference 196 | :func:`~pyroed.inference.fit_svi` or MCMC 197 | :func:`~pyroed.inference.fit_mcmc` and performs optimization via 198 | :func:`~pyroed.optimizers.optimize_simulated_annealing`. These algorithms 199 | can be tuned through the ``config`` dict. 200 | 201 | Example:: 202 | 203 | # Initialize experiment. 204 | sequences = encode_design(SCHEMA, [ 205 | ["P", "N", "R"], 206 | ["P", "N", "S"], 207 | [None, "N", "R"], 208 | ["P", None, "R"], 209 | ]) 210 | print(sequences) 211 | # torch.tensor([[0, 0, 0], [0, 0, 1], [2, 0, 0], [0, 3, 0]]) 212 | experiment = { 213 | "sequences": sequences, 214 | "responses": torch.tensor([0.1, 0.4, 0.5, 0.2]), 215 | "batch_ids": torch.tensor([0, 0, 1, 1]), 216 | } 217 | 218 | # Run Bayesian optimization to get the next sequences to measure. 219 | new_sequences = get_next_design( 220 | SCHEMA, CONSTRAINTS, FEATURE_BLOCKS, GIBBS_BLOCKS, 221 | experiment, design_size=2, 222 | ) 223 | print(new_sequences) 224 | # torch.tensor([[1, 1, 1], [1, 2, 0]]) 225 | print(decode_design(SCHEMA, new_sequences)) 226 | # [["L", "Y", "S"], ["L", "T", "R"]] 227 | 228 | :param OrderedDict schema: A schema dict. 229 | :param list constraints: A list of zero or more 230 | :class:`~pyroed.constraints.Constraint` objects.
231 | :param list feature_blocks: A list of choice blocks for linear regression. 232 | :param list gibbs_blocks: A list of choice blocks for Gibbs sampling. 233 | :param dict experiment: A dict containing all old experiment data. 234 | :param int design_size: Number of designs to try to return (sometimes 235 | fewer designs are found). 236 | :param callable feature_fn: An optional callback to generate additional 237 | features. If provided, this function should input a batch of sequences 238 | (say of shape ``batch_shape``) and return a floating point tensor of 239 | shape ``batch_shape + (F,)`` for some number of features ``F``. This 240 | will be called internally during inference. 241 | :param dict config: Optional config dict. See keyword arguments to 242 | :func:`~pyroed.oed.thompson_sample` for details. 243 | :returns: A tensor of encoded new sequences to measure, i.e. a ``design``. 244 | :rtype: torch.Tensor 245 | """ 246 | if config is None: 247 | config = {} 248 | 249 | # Validate inputs. 250 | assert isinstance(design_size, int) 251 | assert design_size > 0 252 | assert isinstance(config, dict) 253 | if __debug__: 254 | validate( 255 | schema, 256 | constraints=constraints, 257 | feature_blocks=feature_blocks, 258 | gibbs_blocks=gibbs_blocks, 259 | experiment=experiment, 260 | config=config, 261 | ) 262 | 263 | # Perform OED via Thompson sampling. 264 | design_set = thompson_sample( 265 | schema, 266 | constraints, 267 | feature_blocks, 268 | gibbs_blocks, 269 | experiment, 270 | design_size=design_size, 271 | feature_fn=feature_fn, 272 | **config, 273 | ) 274 | design = torch.tensor(sorted(design_set)) 275 | return design 276 | 277 | 278 | def update_experiment( 279 | schema: Schema, 280 | experiment: Dict[str, torch.Tensor], 281 | new_sequences: torch.Tensor, 282 | new_responses: torch.Tensor, 283 | new_batch_ids: Optional[torch.Tensor] = None, 284 | ) -> Dict[str, torch.Tensor]: 285 | """ 286 | Updates a cumulative experiment by appending new data.
287 | 288 | Note this does not modify its arguments; you must capture the result:: 289 | 290 | experiment = update_experiment( 291 | SCHEMA, experiment, new_sequences, new_responses, new_batch_ids 292 | ) 293 | 294 | :param OrderedDict schema: A schema dict. 295 | :param dict experiment: A dict containing all old experiment data. 296 | :param torch.Tensor new_sequences: A set of new sequences that have been 297 | measured. These may simply be the ``design`` returned by 298 | :func:`get_next_design`, or may be arbitrary new sequences you have 299 | decided to measure, or old sequences you have measured again, or a 300 | combination of all three. 301 | :param torch.Tensor new_responses: A tensor of the measured responses of sequences. 302 | :param torch.Tensor new_batch_ids: An optional tensor of batch ids. 303 | :returns: A concatenated experiment. 304 | :rtype: dict 305 | """ 306 | # If unspecified, simply create a new single batch id. 307 | if new_batch_ids is None: 308 | new_batch_ids = experiment["batch_ids"].new_full( 309 | new_responses.shape, experiment["batch_ids"].max().item() + 1 310 | ) 311 | 312 | # Validate. 313 | if __debug__: 314 | validate(schema, experiment=experiment) 315 | assert len(new_responses) == len(new_sequences) 316 | assert len(new_batch_ids) == len(new_sequences) 317 | 318 | # Concatenate the dictionaries. 319 | new_experiment = { 320 | "sequences": new_sequences, 321 | "responses": new_responses, 322 | "batch_ids": new_batch_ids, 323 | } 324 | experiment = {k: torch.cat([v, new_experiment[k]]) for k, v in experiment.items()} 325 | 326 | # Validate again. 
327 | if __debug__: 328 | validate(schema, experiment=experiment) 329 | return experiment 330 | 331 | 332 | __all__ = [ 333 | "decode_design", 334 | "encode_design", 335 | "get_next_design", 336 | "start_experiment", 337 | "update_experiment", 338 | ] 339 | -------------------------------------------------------------------------------- /pyroed/constraints.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements a declarative langauge of constraints on sequence 3 | values. The language consists of: 4 | 5 | - Atomic constraints :class:`TakesValue`, :class:`TakesValues`, 6 | :class:`AllDifferent`, and :class:`Function`. 7 | - Logical operations :class:`Not`, :class:`And`, :class:`Or`, :class:`Xor`, 8 | :class:`IfThen`, and :class:`Iff`. 9 | 10 | Examples in this file will reference the following ``SCHEMA``:: 11 | 12 | SCHEMA = OrderedDict() 13 | SCHEMA["aa1"] = ["P", "L", None] 14 | SCHEMA["aa2"] = ["N", "Y", "T", None] 15 | SCHEMA["aa3"] = ["R", "S"] 16 | """ 17 | 18 | from abc import ABC, abstractmethod 19 | from typing import Callable, Dict, Optional 20 | 21 | import torch 22 | 23 | from .typing import Schema 24 | 25 | 26 | class Constraint(ABC): 27 | """ 28 | Abstract base class for constraints. 29 | 30 | Derived classes must implement the :meth:`__call__` method. 31 | """ 32 | 33 | @abstractmethod 34 | def __call__(self, schema: Schema, choices: torch.Tensor) -> torch.Tensor: 35 | raise NotImplementedError 36 | 37 | 38 | class Function(Constraint): 39 | """ 40 | Constrains the choices by a black box user-provided function. 41 | 42 | Example:: 43 | 44 | def constraint(choices: torch.Tensor): 45 | # Put an upper bound on the sequences. 46 | return choices.float().sum(-1) < 8 47 | 48 | CONSTRAINTS.append(Function(constraint)) 49 | 50 | :param callable fn: A function inputting a batch of encoded sequences and 51 | returning a batched feasability tensor. 
52 | """ 53 | 54 | def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]): 55 | super().__init__() 56 | self.fn = fn 57 | 58 | def __str__(self): 59 | return f"Function({self.fn})" 60 | 61 | def __call__(self, schema: Schema, choices: torch.Tensor): 62 | ok = self.fn(choices) 63 | assert isinstance(ok, torch.Tensor) 64 | assert ok.dtype == torch.bool 65 | if not torch._C._get_tracing_state(): 66 | assert ok.shape == choices.shape[:-1] 67 | return ok 68 | 69 | 70 | class TakesValue(Constraint): 71 | """ 72 | Constrains a site to take a fixed value. 73 | 74 | Example:: 75 | 76 | CONSTRAINTS.append(TakesValue("aa1", "P")) 77 | 78 | :param str name: Name of a sequence variable in the schema. 79 | :param value: The value that the variable can take. 80 | """ 81 | 82 | def __init__(self, name: str, value: Optional[str]): 83 | super().__init__() 84 | self.name = name 85 | self.value = value 86 | 87 | def __str__(self): 88 | return f"TakesValue({repr(self.name)}, {repr(self.value)})" 89 | 90 | def __call__(self, schema: Schema, choices: torch.Tensor) -> torch.Tensor: 91 | for k, (name, values) in enumerate(schema.items()): 92 | if name == self.name: 93 | if self.value not in values: 94 | raise ValueError( 95 | f"In constraint {self}: " 96 | f"{repr(self.value)} not found in schema[repr(name)]" 97 | ) 98 | v = values.index(self.value) 99 | return choices[..., k] == v 100 | raise ValueError( 101 | f"In constraint {self}: {repr(self.value)} not found in schema" 102 | ) 103 | 104 | 105 | class TakesValues(Constraint): 106 | r""" 107 | Constrains a site to take one of a set of values. 108 | 109 | Example:: 110 | 111 | CONSTRAINTS.append(TakesValues("aa1", "P", "L")) 112 | 113 | :param str name: Name of a sequence variable in the schema. 114 | :param \*values: Values that the variable can take.
115 | """ 116 | 117 | def __init__(self, name: str, *values: Optional[str]): 118 | super().__init__() 119 | self.name = name 120 | self.values = values 121 | self.schema: Optional[Schema] = None 122 | self.index: Optional[int] = None 123 | self.mask: Optional[torch.Tensor] = None 124 | 125 | def __str__(self): 126 | args = ", ".join(map(repr, (self.name,) + self.values)) 127 | return f"TakesValues({args})" 128 | 129 | def __call__(self, schema: Schema, choices: torch.Tensor) -> torch.Tensor: 130 | # Lazily build tensor indices. 131 | if schema is not self.schema: 132 | assert self.name in schema 133 | for index, (name, values) in enumerate(schema.items()): 134 | if name == self.name: 135 | self.schema = schema 136 | self.index = index 137 | self.mask = torch.zeros(len(values), dtype=torch.bool) 138 | assert set(self.values).issubset(values) 139 | for i, value in enumerate(values): 140 | if value in self.values: 141 | self.mask[i] = True 142 | break 143 | assert self.mask is not None 144 | assert self.index is not None 145 | 146 | # Compute the constraint. 147 | return self.mask[choices[..., self.index]] 148 | 149 | 150 | class AllDifferent(Constraint): 151 | r""" 152 | Constrains a set of sites to all have distinct values. 153 | 154 | Example:: 155 | 156 | CONSTRAINTS.append(AllDifferent("aa1", "aa2", "aa3")) 157 | 158 | :param str \*names: Names of sequence variables that should be distinct. 159 | """ 160 | 161 | def __init__(self, *names: str): 162 | super().__init__() 163 | self.names = names 164 | self.schema: Optional[Schema] = None 165 | self.name_to_int: Optional[Dict[str, int]] = None 166 | self.standardize: Optional[torch.Tensor] = None 167 | 168 | def __str__(self): 169 | return "AllDifferent({})".format(", ".join(map(repr, self.names))) 170 | 171 | def __call__(self, schema: Schema, choices: torch.Tensor) -> torch.Tensor: 172 | # Lazily build tensor indices. 
173 | if schema is not self.schema: 174 | self.schema = schema 175 | self.name_to_int = {name: i for i, name in enumerate(schema)} 176 | standard = list(set().union(*(schema[n] for n in self.names))) 177 | standard.sort(key=lambda x: (0, "") if x is None else (1, x)) 178 | self.standardize = choices.new_empty( 179 | len(self.names), max(len(schema[n]) for n in self.names) 180 | ) 181 | for i, name in enumerate(self.names): 182 | for j, value in enumerate(schema[name]): 183 | self.standardize[i, j] = standard.index(value) 184 | assert self.name_to_int is not None 185 | assert self.standardize is not None 186 | 187 | # Compute the constraint. 188 | ps = [self.name_to_int[name] for name in self.names] 189 | ok = torch.tensor(True) 190 | for i, p1 in enumerate(ps): 191 | c1 = self.standardize[i][choices[..., p1]] 192 | for j, p2 in enumerate(ps[:i]): 193 | c2 = self.standardize[j][choices[..., p2]] 194 | ok = ok & (c1 != c2) 195 | return ok 196 | 197 | 198 | class Not(Constraint): 199 | """ 200 | Negates a constraint. 201 | 202 | Example:: 203 | 204 | CONSTRAINTS.append(Not(TakesValue("aa1", "P"))) 205 | 206 | :param Constraint arg: A constraint. 207 | """ 208 | 209 | def __init__(self, arg: Constraint): 210 | super().__init__() 211 | self.arg = arg 212 | 213 | def __str__(self): 214 | return f"Not({self.arg})" 215 | 216 | def __call__(self, schema: Schema, choices: torch.Tensor) -> torch.Tensor: 217 | arg = self.arg(schema, choices) 218 | return ~arg 219 | 220 | 221 | class And(Constraint): 222 | """ 223 | Conjoins two constraints. 224 | 225 | Example:: 226 | 227 | CONSTRAINTS.append(And(TakesValue("aa1", None), TakesValue("aa2", None))) 228 | 229 | :param Constraint lhs: A constraint. 230 | :param Constraint rhs: A constraint.
231 | """ 232 | 233 | def __init__(self, lhs: Constraint, rhs: Constraint): 234 | super().__init__() 235 | self.lhs = lhs 236 | self.rhs = rhs 237 | 238 | def __str__(self): 239 | return f"And({self.lhs}, {self.rhs})" 240 | 241 | def __call__(self, schema: Schema, choices: torch.Tensor) -> torch.Tensor: 242 | lhs = self.lhs(schema, choices) 243 | rhs = self.rhs(schema, choices) 244 | return rhs & lhs 245 | 246 | 247 | class Or(Constraint): 248 | """ 249 | Disjoins two constraints. 250 | 251 | Example:: 252 | 253 | CONSTRAINTS.append(Or(TakesValue("aa1", None), TakesValue("aa2", None))) 254 | 255 | :param Constraint lhs: A constraint. 256 | :param Constraint rhs: A constraint. 257 | """ 258 | 259 | def __init__(self, lhs: Constraint, rhs: Constraint): 260 | super().__init__() 261 | self.lhs = lhs 262 | self.rhs = rhs 263 | 264 | def __str__(self): 265 | return f"Or({self.lhs}, {self.rhs})" 266 | 267 | def __call__(self, schema: Schema, choices: torch.Tensor) -> torch.Tensor: 268 | lhs = self.lhs(schema, choices) 269 | rhs = self.rhs(schema, choices) 270 | return rhs | lhs 271 | 272 | 273 | class Xor(Constraint): 274 | """ 275 | Exclusive or among constraints. Equivalent to ``Not(Iff(lhs, rhs))``. 276 | 277 | Example:: 278 | 279 | CONSTRAINTS.append(Xor(TakesValue("aa1", None), TakesValue("aa2", None))) 280 | 281 | :param Constraint lhs: A constraint. 282 | :param Constraint rhs: A constraint. 283 | """ 284 | 285 | def __init__(self, lhs: Constraint, rhs: Constraint): 286 | super().__init__() 287 | self.lhs = lhs 288 | self.rhs = rhs 289 | 290 | def __str__(self): 291 | return f"Xor({self.lhs}, {self.rhs})" 292 | 293 | def __call__(self, schema: Schema, choices: torch.Tensor) -> torch.Tensor: 294 | lhs = self.lhs(schema, choices) 295 | rhs = self.rhs(schema, choices) 296 | return rhs ^ lhs 297 | 298 | 299 | class IfThen(Constraint): 300 | """ 301 | Conditional between constraints. 
302 | 303 | Example:: 304 | 305 | CONSTRAINTS.append(IfThen(TakesValue("aa1", None), TakesValue("aa2", None))) 306 | 307 | :param Constraint lhs: A constraint. 308 | :param Constraint rhs: A constraint. 309 | """ 310 | 311 | def __init__(self, lhs: Constraint, rhs: Constraint): 312 | super().__init__() 313 | self.lhs = lhs 314 | self.rhs = rhs 315 | 316 | def __str__(self): 317 | return f"IfThen({self.lhs}, {self.rhs})" 318 | 319 | def __call__(self, schema: Schema, choices: torch.Tensor) -> torch.Tensor: 320 | lhs = self.lhs(schema, choices) 321 | rhs = self.rhs(schema, choices) 322 | return rhs | ~lhs 323 | 324 | 325 | class Iff(Constraint): 326 | """ 327 | Equality among constraints. 328 | 329 | Exmaple:: 330 | 331 | CONSTRAINTS.append(Iff(TakesValue("aa1", None), TakesValue("aa2", None))) 332 | 333 | :param Constraint lhs: A constraint. 334 | :param Constraint rhs: A constraint. 335 | """ 336 | 337 | def __init__(self, lhs: Constraint, rhs: Constraint): 338 | super().__init__() 339 | self.lhs = lhs 340 | self.rhs = rhs 341 | 342 | def __str__(self): 343 | return f"IfThen({self.lhs}, {self.rhs})" 344 | 345 | def __call__(self, schema: Schema, choices: torch.Tensor) -> torch.Tensor: 346 | lhs = self.lhs(schema, choices) 347 | rhs = self.rhs(schema, choices) 348 | return lhs == rhs 349 | -------------------------------------------------------------------------------- /pyroed/criticism.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import warnings 3 | from typing import Callable, Dict, Optional 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | from pyro import poutine 9 | from pyro.infer.reparam import AutoReparam 10 | from scipy.stats import pearsonr 11 | 12 | from .inference import fit_mcmc, fit_svi 13 | from .models import linear_response, model 14 | from .typing import Blocks, Schema 15 | 16 | 17 | def criticize( 18 | schema: Schema, 19 | feature_blocks: Blocks, 20 
| experiment: Dict[str, torch.Tensor], 21 | test_data: Dict[str, torch.Tensor], 22 | *, 23 | feature_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, 24 | response_type: str = "unit_interval", 25 | inference: str = "svi", 26 | num_posterior_samples: int = 11, 27 | mcmc_num_samples: int = 500, 28 | mcmc_warmup_steps: int = 500, 29 | mcmc_num_chains: int = 1, 30 | svi_num_steps: int = 501, 31 | svi_reparam: bool = True, 32 | thompson_temperature: float = 1.0, 33 | jit_compile: bool = False, 34 | log_every: int = 100, 35 | filename: Optional[str] = None, 36 | ): 37 | """ 38 | Plots observed versus predicted responses on a held out test set. 39 | 40 | :param OrderedDict schema: A schema dict. 41 | :param list feature_blocks: A list of choice blocks for linear regression. 42 | :param dict experiment: A dict containing training data. 43 | :param dict test_data: A dict containing held out test data. 44 | """ 45 | # Compute extra features. 46 | extra_features = None 47 | if feature_fn is not None: 48 | extra_features = feature_fn(experiment["sequences"]) 49 | 50 | bound_model = functools.partial( 51 | model, 52 | schema, 53 | feature_blocks, 54 | extra_features, 55 | experiment, 56 | response_type=response_type, 57 | likelihood_temperature=thompson_temperature, 58 | ) 59 | if inference == "svi" and svi_reparam: 60 | bound_model = AutoReparam()(bound_model) 61 | poutine.block(bound_model)() # initialize reparam 62 | 63 | # Fit a posterior distribution over parameters given experiment data. 
64 | with warnings.catch_warnings(): 65 | warnings.filterwarnings( 66 | "ignore", 67 | "torch.tensor results are registered as constants", 68 | torch.jit.TracerWarning, 69 | ) 70 | if inference == "svi": 71 | sampler = fit_svi( 72 | bound_model, 73 | num_steps=svi_num_steps, 74 | jit_compile=jit_compile, 75 | plot=True, 76 | ) 77 | elif inference == "mcmc": 78 | sampler = fit_mcmc( 79 | bound_model, 80 | num_samples=mcmc_num_samples, 81 | warmup_steps=mcmc_warmup_steps, 82 | num_chains=mcmc_num_chains, 83 | jit_compile=jit_compile, 84 | ) 85 | else: 86 | raise ValueError(f"Unknown inference type: {inference}") 87 | 88 | test_responses = test_data["responses"] 89 | test_sequences = test_data["sequences"] 90 | sort_idx = test_responses.sort(0).indices 91 | test_responses = test_responses[sort_idx] 92 | test_sequences = test_sequences[sort_idx] 93 | if feature_fn is not None: 94 | extra_features = feature_fn(test_sequences) 95 | 96 | predictions = [] 97 | for _ in range(num_posterior_samples): 98 | coefs = poutine.condition(bound_model, sampler())() 99 | test_prediction = linear_response( 100 | schema, coefs, test_sequences, extra_features 101 | ).sigmoid() 102 | predictions.append(test_prediction) 103 | 104 | test_predictions = torch.stack(predictions).detach().cpu().numpy() 105 | mean_predictions = test_predictions.mean(0) 106 | std_predictions = test_predictions.std(0) 107 | corr = pearsonr(test_responses, mean_predictions)[0] 108 | 109 | fig = plt.figure(figsize=(6, 6)) 110 | plt.plot( 111 | np.linspace(0, 1, 100), 112 | np.linspace(0, 1, 100), 113 | color="k", 114 | linestyle="dotted", 115 | ) 116 | plt.errorbar( 117 | test_responses, 118 | mean_predictions, 119 | yerr=std_predictions, 120 | marker="o", 121 | linestyle="None", 122 | ) 123 | plt.text(0.2, 0.9, f"Pearson $\\rho$ = {corr:0.3g}", ha="center", va="center") 124 | plt.xlim(0, 1) 125 | plt.ylim(0, 1) 126 | plt.xlabel("Observed response") 127 | plt.ylabel("Predictedresponse") 128 | 129 | 
plt.tight_layout() 130 | if filename is not None: 131 | plt.savefig(filename) 132 | print(f"Saved {filename}") 133 | 134 | return fig 135 | -------------------------------------------------------------------------------- /pyroed/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import load_tf_data 2 | 3 | __all__ = ["load_tf_data"] 4 | -------------------------------------------------------------------------------- /pyroed/datasets/data.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def load_tf_data(data_dir="./pyroed/datasets"): 8 | """ 9 | Return a dict of torch tensors holding PBX4 transcription factor data. 10 | 11 | Reference: Barrera, Luis A., et al. "Survey of variation in human transcription 12 | factors reveals prevalent DNA binding changes." Science 351.6280 (2016): 1450-1454. 13 | """ 14 | xy = np.load(gzip.GzipFile(data_dir + "/tf_bind_8-PBX4_REF_R2.npy.gz", "rb")) 15 | x, y = xy[:, :-1], xy[:, -1] 16 | assert x.shape[0] == y.shape[0] 17 | assert x.ndim == 2 18 | return { 19 | "sequences": torch.tensor(x, dtype=torch.long), 20 | "responses": torch.tensor(y, dtype=torch.float), 21 | "batch_ids": torch.zeros(len(x), dtype=torch.long), 22 | } 23 | -------------------------------------------------------------------------------- /pyroed/inference.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict 2 | 3 | import pyro 4 | import torch 5 | from pyro.infer import SVI,
JitTrace_ELBO, Trace_ELBO 6 | from pyro.infer.autoguide import AutoLowRankMultivariateNormal 7 | from pyro.infer.mcmc import MCMC, NUTS 8 | from pyro.optim import ClippedAdam 9 | 10 | 11 | def fit_svi( 12 | model: Callable, 13 | *, 14 | lr: float = 0.02, 15 | num_steps: int = 501, 16 | jit_compile: bool = False, 17 | log_every: int = 100, 18 | plot: bool = False, 19 | ) -> Callable[[], Dict[str, torch.Tensor]]: 20 | """ 21 | Fits a model via stochastic variational inference. 22 | 23 | :param callable model: A Bayesian regression model from :mod:`pyroed.models`. 24 | :returns: A variational distribution that can generate samples. 25 | :rtype: callable 26 | """ 27 | pyro.clear_param_store() 28 | guide: Callable[[], Dict[str, torch.Tensor]] = AutoLowRankMultivariateNormal(model) 29 | optim = ClippedAdam({"lr": lr, "lrd": 0.1 ** (1 / num_steps)}) 30 | elbo = (JitTrace_ELBO if jit_compile else Trace_ELBO)() 31 | svi = SVI(model, guide, optim, elbo) 32 | losses = [] 33 | for step in range(num_steps): 34 | loss = svi.step() 35 | losses.append(loss) 36 | if log_every and step % log_every == 0: 37 | print(f"svi step {step} loss = {loss:0.6g}") 38 | 39 | if plot: 40 | import matplotlib.pyplot as plt 41 | 42 | plt.plot(losses) 43 | plt.xlabel("SVI step") 44 | plt.ylabel("loss") 45 | 46 | return guide 47 | 48 | 49 | def fit_mcmc( 50 | model: Callable, 51 | *, 52 | num_samples: int = 500, 53 | warmup_steps: int = 500, 54 | num_chains: int = 1, 55 | jit_compile: bool = True, 56 | ) -> Callable[[], Dict[str, torch.Tensor]]: 57 | """ 58 | Fits a model via Hamiltonian Monte Carlo. 59 | 60 | :param callable model: A Bayesian regression model from :mod:`pyroed.models`. 61 | :returns: A sampler that draws from the empirical distribution. 
62 | :rtype: Sampler 63 | """ 64 | kernel = NUTS(model, jit_compile=jit_compile) 65 | mcmc = MCMC( 66 | kernel, 67 | num_samples=num_samples, 68 | warmup_steps=warmup_steps, 69 | num_chains=num_chains, 70 | ) 71 | mcmc.run() 72 | samples = mcmc.get_samples() 73 | return Sampler(samples) 74 | 75 | 76 | class Sampler: 77 | """ 78 | Helper to sample from an empirical distribution. 79 | 80 | :param dict samples: A dictionary of batches of samples. 81 | """ 82 | 83 | def __init__(self, samples: Dict[str, torch.Tensor]): 84 | self.samples = samples 85 | self.num_samples = len(next(iter(samples.values()))) 86 | 87 | def __call__(self) -> Dict[str, torch.Tensor]: 88 | i = torch.randint(0, self.num_samples, ()) 89 | return {k: v[i] for k, v in self.samples.items()} 90 | -------------------------------------------------------------------------------- /pyroed/models.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Dict, Optional 3 | 4 | import pyro 5 | import pyro.distributions as dist 6 | import pyro.poutine as poutine 7 | import torch 8 | 9 | from .typing import Blocks, Coefs, Schema, validate 10 | 11 | 12 | def linear_response( 13 | schema: Schema, 14 | coefs: Coefs, 15 | sequence: torch.Tensor, 16 | extra_features: Optional[torch.Tensor], 17 | ) -> torch.Tensor: 18 | """ 19 | Linear response function. 20 | 21 | :param OrderedDict schema: A schema dict. 22 | :param dict coefs: A dictionary mapping feature tuples to coefficient 23 | tensors. 24 | :param torch.Tensor sequence: A tensor representing a sequence. 25 | :param torch.Tensor extra_features: An optional tensor of extra features, 26 | i.e. those computed by a custom ``features_fn`` rather than standard 27 | cross features from ``FEATURE_BLOCKS``. 28 | :returns: The response. 
29 | :rtype: torch.Tensor 30 | """ 31 | if not torch._C._get_tracing_state(): 32 | assert isinstance(schema, OrderedDict) 33 | assert isinstance(coefs, dict) 34 | assert sequence.dtype == torch.long 35 | assert sequence.size(-1) == len(schema) 36 | if extra_features is None: 37 | assert None not in coefs 38 | else: 39 | assert None in coefs 40 | assert coefs[None].dim() == 1 41 | assert extra_features.shape == sequence.shape[:-1] + coefs[None].shape 42 | assert (extra_features is not None) == (None in coefs) 43 | choices = dict(zip(schema, sequence.unbind(-1))) 44 | 45 | result = torch.tensor(0.0) 46 | for key, coef in coefs.items(): 47 | if key is None: 48 | assert extra_features is not None 49 | result = result + extra_features @ coefs[None] 50 | else: 51 | assert isinstance(key, tuple) 52 | if not torch._C._get_tracing_state(): 53 | assert coef.dim() == len(key) 54 | index = tuple(choices[name] for name in key) 55 | result = result + coef[index] 56 | 57 | return result 58 | 59 | 60 | def model( 61 | schema: Schema, 62 | feature_blocks: Blocks, 63 | extra_features: Optional[torch.Tensor], 64 | experiment: Dict[str, torch.Tensor], # sequences, batch_id, optional(response) 65 | *, 66 | max_batch_id: Optional[int] = None, 67 | response_type: str = "unit_interval", 68 | likelihood_temperature: float = 1.0, 69 | quantization_bins: int = 100, 70 | ): 71 | """ 72 | A `Pyro `_ model for Bayesian linear regression. 73 | 74 | :param OrderedDict schema: A schema dict. 75 | :param list feature_blocks: A list of choice blocks for linear regression. 76 | :param dict experiment: A dict containing all old experiment data. 77 | :param str response_type: Type of response, one of: "real", "unit_interval". 78 | :param int quantization_bins: Number of bins in which to quantize the 79 | "unit_interval" response response_type. 80 | :returns: A dictionary mapping feature tuples to coefficient tensors. 
81 | :rtype: dict 82 | """ 83 | if max_batch_id is None: 84 | max_batch_id = int(experiment["batch_ids"].max()) 85 | N = experiment["sequences"].size(0) 86 | B = 1 + max_batch_id 87 | if __debug__ and not torch._C._get_tracing_state(): 88 | validate(schema, experiment=experiment) 89 | if extra_features is not None: 90 | assert extra_features.dim() == 2 91 | assert extra_features.size(0) == N 92 | name_to_int = {name: i for i, name in enumerate(schema)} 93 | 94 | # Hierarchically sample linear coefficients. 95 | coef_scale_loc = pyro.sample("coef_scale_loc", dist.Normal(-2, 1)) 96 | coef_scale_scale = pyro.sample("coef_scale_scale", dist.LogNormal(-2, 1)) 97 | coefs: Coefs = {} 98 | trivial_blocks: Blocks = [[]] # For the constant term. 99 | for block in trivial_blocks + feature_blocks: 100 | shape = tuple(len(schema[name]) for name in block) 101 | ps = tuple(name_to_int[name] for name in block) 102 | suffix = "_".join(map(str, ps)) 103 | # Within-component variance of coefficients. 104 | coef_scale = pyro.sample( 105 | f"coef_scale_{suffix}", 106 | dist.LogNormal(coef_scale_loc, coef_scale_scale), 107 | ) 108 | # Linear coefficients. Note this overparametrizes; there are only 109 | # len(choices) - 1 degrees of freedom and 1 nuisance dim. 110 | coefs[tuple(block)] = pyro.sample( 111 | f"coef_{suffix}", 112 | dist.Normal(torch.zeros(shape), coef_scale).to_event(len(shape)), 113 | ) 114 | if extra_features is not None: 115 | # Sample coefficients for all extra user-provided features. 116 | shape = extra_features.shape[-1:] 117 | coef_scale = pyro.sample( 118 | "coef_scale", dist.LogNormal(coef_scale_loc, coef_scale_scale) 119 | ) 120 | coefs[None] = pyro.sample( 121 | "coef", dist.Normal(torch.zeros(shape), coef_scale).to_event(1) 122 | ) 123 | 124 | # Compute the linear response function. 125 | response_loc = linear_response( 126 | schema, 127 | coefs, 128 | experiment["sequences"], 129 | extra_features, 130 | ) 131 | 132 | # Observe a noisy response. 
133 | within_batch_scale = pyro.sample("within_batch_scale", dist.LogNormal(0, 1)) 134 | if B == 1: 135 | within_batch_loc = response_loc 136 | else: 137 | # Model batch effects. 138 | across_batch_scale = pyro.sample("across_batch_scale", dist.LogNormal(0, 1)) 139 | with pyro.plate("batch", B): 140 | batch_response = pyro.sample( 141 | "batch_response", dist.Normal(0, across_batch_scale) 142 | ) 143 | if not torch._C._get_tracing_state(): 144 | assert batch_response.shape == (B,) 145 | within_batch_loc = response_loc + batch_response[experiment["batch_ids"]] 146 | 147 | # This likelihood can be generalized to counts or other datatype. 148 | with pyro.plate("data", N): 149 | if response_type == "real": 150 | with poutine.scale(scale=1 / likelihood_temperature): 151 | pyro.sample( 152 | "responses", 153 | dist.Normal(within_batch_loc, within_batch_scale), 154 | obs=experiment.get("responses"), 155 | ) 156 | 157 | elif response_type == "unit_interval": 158 | logits = pyro.sample( 159 | "logits", 160 | dist.Normal(within_batch_loc, within_batch_scale), 161 | ) 162 | 163 | # Quantize the observation to avoid numerical artifacts near 0 and 1. 
164 | quantized_obs = None 165 | response = experiment.get("responses") 166 | if response is not None: # during inference 167 | quantized_obs = (response * quantization_bins).round() 168 | with poutine.scale(scale=1 / likelihood_temperature): 169 | quantized_obs = pyro.sample( 170 | "quantized_response", 171 | dist.Binomial(quantization_bins, logits=logits), 172 | obs=quantized_obs, 173 | ) 174 | assert quantized_obs is not None 175 | if response is None: # during simulation 176 | pyro.deterministic("responses", quantized_obs / quantization_bins) 177 | 178 | else: 179 | raise ValueError(f"Unknown response_type type {repr(response_type)}") 180 | 181 | return coefs 182 | -------------------------------------------------------------------------------- /pyroed/oed.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import warnings 3 | from typing import Callable, Dict, Optional, Set, Tuple 4 | 5 | import pyro.poutine as poutine 6 | import torch 7 | from pyro.infer.reparam import AutoReparam 8 | 9 | from .inference import fit_mcmc, fit_svi 10 | from .models import model 11 | from .optimizers import optimize_simulated_annealing 12 | from .typing import Blocks, Constraints, Schema 13 | 14 | 15 | def thompson_sample( 16 | schema: Schema, 17 | constraints: Constraints, 18 | feature_blocks: Blocks, 19 | gibbs_blocks: Blocks, 20 | experiment: Dict[str, torch.Tensor], 21 | *, 22 | design_size: int = 10, 23 | feature_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, 24 | response_type: str = "unit_interval", 25 | inference: str = "svi", 26 | mcmc_num_samples: int = 500, 27 | mcmc_warmup_steps: int = 500, 28 | mcmc_num_chains: int = 1, 29 | svi_num_steps: int = 501, 30 | svi_reparam: bool = True, 31 | svi_plot: bool = False, 32 | sa_num_steps: int = 1000, 33 | max_tries: int = 1000, 34 | thompson_temperature: float = 1.0, 35 | jit_compile: Optional[bool] = None, 36 | log_every: int = 100, 37 | ) -> 
Set[Tuple[int, ...]]: 38 | """ 39 | Performs Bayesian optimization via Thompson sampling. 40 | 41 | This fits a Bayesian model to existing experimental data, and draws 42 | Thompson samples wrt that model. To draw each Thompson sample, this first 43 | samples parameters from the fitted posterior (with likelihood annealed by 44 | ``thompson_temperature``), then finds an optimal sequence wrt those 45 | parameters via simulated annealing. 46 | 47 | The Bayesian model can be fit either via stochastic variational inference 48 | (SVI, faster but less accurate) or Markov chain Monte Carlo (MCMC, slower 49 | but more accurate). 50 | 51 | :param OrderedDict schema: A schema dict. 52 | :param list constraints: A list of constraints. 53 | :param list feature_blocks: A list of choice blocks for linear regression. 54 | :param list gibbs_blocks: A list of choice blocks for Gibbs sampling. 55 | :param dict experiment: A dict containing all old experiment data. 56 | :param int design_size: Number of designs to try to return (sometimes 57 | fewer designs are found). 58 | :param callable feature_fn: An optional callback to generate additional 59 | features. 60 | :param str response_type: Type of response, one of: "real", "unit_interval". 61 | :param str inference: Inference algorithm, one of: "svi", "mcmc". 62 | :param int mcmc_num_samples: If ``inference == "mcmc"``, this sets the 63 | number of posterior samples to draw from MCMC. Should be at least 64 | ``design_size``. 65 | :param int mcmc_warmup_steps: If ``inference == "mcmc"``, this sets the 66 | number of warmup steps for MCMC. Should be the same order of magnitude 67 | as ``mcmc_num_samples``. 68 | :param int svi_num_steps: If ``inference == "svi"`` this sets the number of 69 | steps to run stochastic variational inference. 70 | :param bool svi_reparam: Whether to reparametrize SVI inference. 71 | This only works when ``thompson_temperature == 1``.
72 | :param int sa_num_steps: Number of steps to run simulated annealing, at 73 | each Thompson sample. 74 | :param bool svi_plot: If ``inference == "svi"`` whether to plot loss curve. 75 | :param int max_tries: Number of extra Thompson samples to draw in search 76 | of novel sequences to add to the design. 77 | :param float thompson_temperature: Likelihood annealing temperature at 78 | which Thompson samples are drawn. Defaults to 1. You may want to 79 | increase this if you have trouble finding novel designs, i.e. if 80 | this function returns fewer designs than you request. 81 | :param bool jit_compile: Optional flag to force jit compilation during 82 | inference. Defaults to safe values for both SVI and MCMC inference. 83 | :param int log_every: Logging interval for internal algorithms. To disable 84 | logging, set this to zero. 85 | 86 | :returns: A design as a set of tuples of choices. 87 | :rtype: set 88 | """ 89 | if jit_compile is None: 90 | if inference == "svi": 91 | jit_compile = False # default to False to avoid jit errors 92 | elif inference == "mcmc": 93 | jit_compile = True # default to True for speed 94 | else: 95 | raise ValueError(f"Unknown inference type: {inference}") 96 | 97 | # Compute extra features. 98 | extra_features = None 99 | if feature_fn is not None: 100 | with torch.no_grad(): 101 | extra_features = feature_fn(experiment["sequences"]) 102 | assert isinstance(extra_features, torch.Tensor) 103 | 104 | # Pass max_batch_id separately as a python int to allow jitting. 105 | max_batch_id = int(experiment["batch_ids"].max()) 106 | assert thompson_temperature > 0 107 | bound_model = functools.partial( 108 | model, 109 | schema, 110 | feature_blocks, 111 | extra_features, 112 | experiment, 113 | max_batch_id=max_batch_id, 114 | response_type=response_type, 115 | likelihood_temperature=thompson_temperature, 116 | ) 117 | # Reparametrization can improve variational inference, 118 | # but isn't compatible with jit compilation or mcmc.
119 | if inference == "svi" and svi_reparam: 120 | bound_model = AutoReparam()(bound_model) 121 | poutine.block(bound_model)() # initialize reparam 122 | 123 | # Fit a posterior distribution over parameters given experiment data. 124 | with warnings.catch_warnings(): 125 | warnings.filterwarnings( 126 | "ignore", 127 | "torch.tensor results are registered as constants", 128 | torch.jit.TracerWarning, 129 | ) 130 | if inference == "svi": 131 | sampler = fit_svi( 132 | bound_model, 133 | num_steps=svi_num_steps, 134 | plot=svi_plot, 135 | jit_compile=jit_compile, 136 | log_every=log_every, 137 | ) 138 | elif inference == "mcmc": 139 | assert mcmc_num_samples >= design_size 140 | sampler = fit_mcmc( 141 | bound_model, 142 | num_samples=mcmc_num_samples, 143 | warmup_steps=mcmc_warmup_steps, 144 | num_chains=mcmc_num_chains, 145 | jit_compile=jit_compile, 146 | ) 147 | else: 148 | raise ValueError(f"Unknown inference type: {inference}") 149 | 150 | # Repeatedly sample coefficients from the posterior, 151 | # and for each sample find an optimal sequence. 
152 | with torch.no_grad(), poutine.mask(mask=False): 153 | logits = experiment["responses"].clamp(min=0.001, max=0.999).logit() 154 | extent = logits.max() - logits.min() 155 | temperature_schedule = extent * torch.logspace(0.0, -2.0, sa_num_steps) 156 | 157 | old_design = set(map(tuple, experiment["sequences"].tolist())) 158 | design: Set[Tuple[int, ...]] = set() 159 | 160 | for i in range(design_size + max_tries): 161 | if log_every: 162 | print(".", end="", flush=True) 163 | coefs = poutine.condition(bound_model, sampler())() 164 | 165 | seq = optimize_simulated_annealing( 166 | schema, 167 | constraints, 168 | gibbs_blocks, 169 | coefs, 170 | feature_fn=feature_fn, 171 | temperature_schedule=temperature_schedule, 172 | log_every=log_every, 173 | ) 174 | 175 | new_seq = tuple(seq.tolist()) 176 | if new_seq not in old_design: 177 | design.add(new_seq) 178 | if len(design) >= design_size: 179 | break 180 | 181 | if len(design) < design_size: 182 | warnings.warn(f"Found design of only {len(design)}/{design_size} sequences") 183 | 184 | return design 185 | -------------------------------------------------------------------------------- /pyroed/optimizers.py: -------------------------------------------------------------------------------- 1 | import operator 2 | from functools import reduce 3 | from typing import Callable, Optional 4 | 5 | import pyro.distributions as dist 6 | import torch 7 | 8 | from .models import linear_response 9 | from .typing import Blocks, Constraints, Schema 10 | 11 | 12 | @torch.no_grad() 13 | def optimize_simulated_annealing( 14 | schema: Schema, 15 | constraints: Constraints, 16 | gibbs_blocks: Blocks, 17 | coefs: dict, 18 | *, 19 | feature_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, 20 | temperature_schedule: torch.Tensor, 21 | max_tries=10000, 22 | log_every=100, 23 | ) -> torch.Tensor: 24 | """ 25 | Finds an optimal sequence via annealed Gibbs sampling. 26 | 27 | :param OrderedDict schema: A schema dict. 
28 | :param list constraints: A list of constraints. 29 | :param list gibbs_blocks: A list of choice blocks for Gibbs sampling. 30 | :param dict coefs: A dictionary mapping feature tuples to coefficient 31 | tensors. 32 | :returns: The single best found sequence. 33 | :rtype: torch.Tensor 34 | """ 35 | # Set up problem shape. 36 | P = len(schema) 37 | num_categories = torch.tensor([len(v) for v in schema.values()]) 38 | bounds = dist.constraints.integer_interval(0, num_categories) 39 | assert set(sum(gibbs_blocks, [])) == set(schema), "invalid gibbs blocks" 40 | name_to_int = {name: i for i, name in enumerate(schema)} 41 | int_blocks = [[name_to_int[name] for name in block] for block in gibbs_blocks] 42 | 43 | def constraint_fn(seq): 44 | if not constraints: 45 | return True 46 | return reduce(operator.and_, (c(schema, seq) for c in constraints)) 47 | 48 | # Initialize to a single random uniform feasible state. 49 | for i in range(max_tries): 50 | state = torch.stack( 51 | [torch.randint(0, Cp, ()) for Cp in num_categories.tolist()] 52 | ) 53 | assert bounds.check(state).all() 54 | if constraint_fn(state): 55 | break 56 | if not constraint_fn(state): 57 | raise ValueError("Failed to find a feasible initial state") 58 | best_state = state 59 | extra_features = None 60 | if feature_fn is not None: 61 | extra_features = feature_fn(state) 62 | best_logits = float(linear_response(schema, coefs, state, extra_features)) 63 | 64 | # Anneal, recording the best state. 65 | for step, temperature in enumerate(temperature_schedule): 66 | # Choose a random Gibbs block. 67 | b = int(torch.randint(0, len(gibbs_blocks), ())) 68 | block = int_blocks[b] 69 | Cs = [int(num_categories[p]) for p in block] 70 | 71 | # Create a cartesian product over choices within the block. 
72 | nbhd = state.expand(tuple(reversed(Cs)) + (P,)).clone() 73 | for i, (p, C) in enumerate(zip(block, Cs)): 74 | nbhd[..., p] = torch.arange(C).reshape((-1,) + (1,) * i) 75 | nbhd = nbhd.reshape(-1, P) 76 | 77 | # Restrict to feasible states. 78 | ok = constraint_fn(nbhd) 79 | if ok is not True: 80 | nbhd = nbhd[ok] 81 | assert bounds.check(nbhd).all() 82 | 83 | # Randomly sample variables in the block wrt an annealed logits. 84 | if feature_fn is not None: 85 | extra_features = feature_fn(nbhd) 86 | logits = linear_response(schema, coefs, nbhd, extra_features) 87 | assert logits.dim() == 1 88 | choice = dist.Categorical(logits=logits / temperature).sample() 89 | state[:] = nbhd[choice] 90 | assert bounds.check(state).all() 91 | assert constraint_fn(state) 92 | 93 | # Save the best response. 94 | current_logits = float(logits[choice]) 95 | if current_logits > best_logits: 96 | best_state = state.clone() 97 | best_logits = current_logits 98 | if log_every and step % log_every == 0: 99 | print( 100 | f"sa step {step} temp={temperature:0.3g} " 101 | f"logits={current_logits:0.6g}" 102 | ) 103 | 104 | return best_state 105 | -------------------------------------------------------------------------------- /pyroed/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyro-ppl/pyroed/c549bd9dc9511e2199ff55fb2c86f84226b9b1c2/pyroed/py.typed -------------------------------------------------------------------------------- /pyroed/testing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Dict, Tuple 3 | 4 | import pyro.poutine as poutine 5 | import torch 6 | 7 | from .models import model 8 | from .typing import Blocks, Schema, validate 9 | 10 | 11 | @torch.no_grad() 12 | def generate_fake_data( 13 | schema: Schema, 14 | feature_blocks: Blocks, 15 | sequences_per_batch: int, 16 | num_batches: int = 1, 17 | ) -> Tuple[Dict[str, 
torch.Tensor], Dict[str, torch.Tensor]]: 18 | """ 19 | Generates a fake dataset for testing. 20 | 21 | :param OrderedDict schema: A schema dict. 22 | :param list constraints: A list of constraints. 23 | :param list feature_blocks: A list of choice blocks for linear regression. 24 | :param int sequences_per_batch: The number of sequences per experiment 25 | batch. 26 | :param int num_batches: The number of experiment batches. 27 | :returns: A pair ``(truth, experiment)``, where ``truth`` is a dict of 28 | true values of latent variables (regression coefficients, etc.), and 29 | ``experiment`` is a standard experiment dict. 30 | :rtype: tuple 31 | """ 32 | B = num_batches 33 | N = sequences_per_batch * B 34 | experiment: Dict[str, torch.Tensor] = {} 35 | feature_fn = None 36 | 37 | # Work around irrelevant PyTorch interface change warnings. 38 | with warnings.catch_warnings(): 39 | warnings.filterwarnings("ignore", "floor_divide is deprecated", UserWarning) 40 | experiment["batch_ids"] = torch.arange(N) // sequences_per_batch 41 | 42 | experiment["sequences"] = torch.stack( 43 | [torch.randint(0, len(choices), (N,)) for choices in schema.values()], dim=-1 44 | ) 45 | trace = poutine.trace(model).get_trace( 46 | schema, feature_blocks, feature_fn, experiment 47 | ) 48 | truth: Dict[str, torch.Tensor] = { 49 | name: site["value"].detach() 50 | for name, site in trace.nodes.items() 51 | if site["type"] == "sample" and not site["is_observed"] 52 | if type(site["fn"]).__name__ != "_Subsample" 53 | if name != "batch_response" # shape varies in time 54 | } 55 | experiment["responses"] = trace.nodes["responses"]["value"].detach() 56 | if __debug__: 57 | validate(schema, feature_blocks=feature_blocks, experiment=experiment) 58 | 59 | return truth, experiment 60 | -------------------------------------------------------------------------------- /pyroed/typing.py: -------------------------------------------------------------------------------- 1 | import math 2 | from 
collections import OrderedDict 3 | from typing import Any, Callable, Dict, List, Optional, Tuple 4 | 5 | import torch 6 | 7 | # Cannot use OrderedDict yet https://stackoverflow.com/questions/41207128 8 | Schema = Dict[str, List[Optional[str]]] 9 | Coefs = Dict[Optional[Tuple[str, ...]], torch.Tensor] 10 | Blocks = List[List[str]] 11 | Constraints = List[Callable] 12 | 13 | 14 | def validate( 15 | schema: Schema, 16 | *, 17 | constraints: Optional[Constraints] = None, 18 | feature_blocks: Optional[Blocks] = None, 19 | gibbs_blocks: Optional[Blocks] = None, 20 | experiment: Optional[Dict[str, torch.Tensor]] = None, 21 | config: Optional[Dict[str, Any]] = None, 22 | ) -> None: 23 | """ 24 | Validates a Pyroed problem specification. 25 | 26 | :param OrderedDict schema: A schema dict. 27 | :param list constraints: An optional list of constraints. 28 | :param list feature_blocks: An optional list of choice blocks for linear 29 | regression. 30 | :param list gibbs_blocks: An optional list of choice blocks for Gibbs 31 | sampling. 32 | :param dict experiment: An optional dict containing all old experiment data. 33 | """ 34 | from .constraints import Constraint # avoid import cycle 35 | 36 | # Validate schema. 37 | assert isinstance(schema, OrderedDict) 38 | for name, values in schema.items(): 39 | assert isinstance(name, str) 40 | assert isinstance(values, list) 41 | assert name 42 | assert values 43 | for value in values: 44 | if value is not None: 45 | assert isinstance(value, str) 46 | assert value is not None 47 | 48 | # Validate constraints. 49 | if constraints is not None: 50 | assert isinstance(constraints, list) 51 | for constraint in constraints: 52 | assert isinstance(constraint, Constraint) 53 | 54 | # Validate feature_blocks. 
55 | if feature_blocks is not None: 56 | assert isinstance(feature_blocks, list) 57 | for block in feature_blocks: 58 | assert block, "empty blocks are not allowed" 59 | assert isinstance(block, list) 60 | for col in block: 61 | assert col in schema 62 | assert len({tuple(f) for f in feature_blocks}) == len( 63 | feature_blocks 64 | ), "duplicate feature_blocks" 65 | 66 | # Validate gibbs_blocks. 67 | if gibbs_blocks is not None: 68 | assert isinstance(gibbs_blocks, list) 69 | for block in gibbs_blocks: 70 | assert block, "empty blocks are not allowed" 71 | assert isinstance(block, list) 72 | for col in block: 73 | assert col in schema 74 | assert len({tuple(f) for f in gibbs_blocks}) == len( 75 | gibbs_blocks 76 | ), "duplicate gibbs_blocks" 77 | 78 | # Validate experiment. 79 | if experiment is not None: 80 | assert isinstance(experiment, dict) 81 | allowed_keys = {"sequences", "batch_ids", "responses"} 82 | required_keys = {"sequences", "batch_ids"} 83 | assert allowed_keys.issuperset(experiment) 84 | assert required_keys.issubset(experiment) 85 | 86 | sequences = experiment["sequences"] 87 | assert isinstance(sequences, torch.Tensor) 88 | assert sequences.dtype == torch.long 89 | assert sequences.dim() == 2 90 | assert sequences.shape[-1] == len(schema) 91 | 92 | batch_id = experiment["batch_ids"] 93 | assert isinstance(batch_id, torch.Tensor) 94 | assert batch_id.dtype == torch.long 95 | assert batch_id.shape == sequences.shape[:1] 96 | 97 | response = experiment.get("responses") 98 | if response is not None: 99 | assert isinstance(response, torch.Tensor) 100 | assert torch.is_floating_point(response) 101 | assert response.shape == sequences.shape[:1] 102 | assert -math.inf < response.min() 103 | assert response.max() < math.inf 104 | if config is not None: 105 | response_type = config.get("response_type", "unit_interval") 106 | if response_type == "unit_interval": 107 | message = ( 108 | "response outside of unit interval, " 109 | 'consider configuring 
response_type="real"' 110 | ) 111 | assert 0 <= response.min(), message 112 | assert response.max() <= 1, message 113 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 89 3 | ignore = E741,E203,W503 4 | exclude = build, dist 5 | 6 | [isort] 7 | profile = black 8 | skip_glob = .ipynb_checkpoints 9 | known_first_party = pyroed 10 | known_third_party = opt_einsum, pyro, torch, torchvision 11 | 12 | [tool:pytest] 13 | filterwarnings = error 14 | ignore::PendingDeprecationWarning 15 | ignore::DeprecationWarning 16 | once::DeprecationWarning 17 | 18 | [mypy] 19 | python_version = 3.7 20 | check_untyped_defs = True 21 | ignore_missing_imports = True 22 | warn_incomplete_stub = True 23 | warn_return_any = True 24 | warn_unreachable = True 25 | warn_unused_configs = True 26 | warn_unused_ignores = True 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | from setuptools import find_packages, setup 5 | 6 | with open("pyroed/__init__.py") as f: 7 | for line in f: 8 | match = re.match('^__version__ = "(.*)"$', line) 9 | if match: 10 | __version__ = match.group(1) 11 | break 12 | 13 | try: 14 | long_description = open("README.md", encoding="utf-8").read() 15 | except Exception as e: 16 | sys.stderr.write("Failed to read README.md: {}\n".format(e)) 17 | sys.stderr.flush() 18 | long_description = "" 19 | 20 | setup( 21 | name="pyroed", 22 | version=__version__, 23 | description="Sequence design using Pyro", 24 | long_description=long_description, 25 | long_description_content_type="text/markdown", 26 | packages=find_packages(include=["pyroed"]), 27 | package_data={"pyroed": ["py.typed"]}, 28 | url="https://github.com/pyro-ppl/pyroed", 29 | author="Pyro team at 
the Broad Institute of MIT and Harvard", 30 | author_email="fritz.obermeyer@gmail.com", 31 | install_requires=[ 32 | "matplotlib", 33 | "pandas", 34 | "pyro-ppl>=1.7", 35 | "scipy", 36 | ], 37 | extras_require={ 38 | "test": [ 39 | "black", 40 | "isort>=5.0", 41 | "flake8", 42 | "pytest>=5.0", 43 | "mypy>=0.812", 44 | ], 45 | }, 46 | python_requires=">=3.7", 47 | keywords="optimal experimental design pyro", 48 | license="Apache 2.0", 49 | classifiers=[ 50 | "Intended Audience :: Developers", 51 | "Intended Audience :: Science/Research", 52 | "License :: OSI Approved :: Apache Software License", 53 | "Operating System :: POSIX :: Linux", 54 | "Operating System :: MacOS :: MacOS X", 55 | "Programming Language :: Python :: 3.7", 56 | "Programming Language :: Python :: 3.8", 57 | "Programming Language :: Python :: 3.9", 58 | ], 59 | ) 60 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import pyro 4 | 5 | matplotlib.use("Agg") 6 | 7 | 8 | def pytest_runtest_setup(item): 9 | np.random.seed(20220324) 10 | pyro.set_rng_seed(20220324) 11 | pyro.enable_validation(True) 12 | pyro.clear_param_store() 13 | -------------------------------------------------------------------------------- /test/test_constraints.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import List, Optional 3 | 4 | from pyroed.api import encode_design 5 | from pyroed.constraints import ( 6 | AllDifferent, 7 | And, 8 | Function, 9 | Iff, 10 | IfThen, 11 | Not, 12 | Or, 13 | TakesValue, 14 | TakesValues, 15 | Xor, 16 | ) 17 | from pyroed.typing import Schema 18 | 19 | 20 | def stringify(bools: List[bool]) -> str: 21 | return "".join("1" if x else "0" for x in bools) 22 | 23 | 24 | def test_immune_sequence(): 25 | SCHEMA: Schema = 
OrderedDict() 26 | SCHEMA["Protein 1"] = ["Prot1", "Prot2", None] 27 | SCHEMA["Protein 2"] = ["Prot3", "HLA1", "HLA2", "HLA3", "HLA4", None] 28 | SCHEMA["Signalling Pep"] = ["Sig1", "Sig2", None] 29 | SCHEMA["EP"] = [f"Ep{i}" for i in range(1, 10 + 1)] 30 | SCHEMA["EP"].append(None) 31 | SCHEMA["Linker"] = ["Link1", None] 32 | SCHEMA["Internal"] = ["Int1", "Int2", "Int3", "Int3", None] 33 | SCHEMA["2A-1"] = ["twoa1", "twoa2", None] 34 | SCHEMA["2A-2"] = ["twoa3", "twoa4", None] 35 | SCHEMA["2A-3"] = [f"twoa{i}" for i in range(1, 7 + 1)] 36 | 37 | CONSTRAINTS = [ 38 | AllDifferent("2A-1", "2A-2", "2A-3"), 39 | Iff(TakesValue("Protein 1", None), TakesValue("2A-1", None)), 40 | Iff(TakesValue("Signalling Pep", None), TakesValue("EP", None)), 41 | Iff(TakesValue("EP", None), TakesValue("Linker", None)), 42 | IfThen(TakesValue("Protein 2", None), TakesValue("Internal", None)), 43 | Iff(TakesValue("Protein 2", "Prot3"), TakesValue("2A-2", None)), 44 | ] 45 | 46 | design: List[List[Optional[str]]] = [ 47 | ["Prot1", "Prot3", "Sig1", "Ep1", "Link1", "Int1", "twoa1", None, "twoa2"], 48 | ["Prot1", "Prot3", "Sig1", "Ep1", "Link1", "Int1", "twoa1", None, "twoa1"], 49 | [None, "Prot3", "Sig1", "Ep1", "Link1", "Int1", "twoa1", None, "twoa2"], 50 | ["Prot1", "Prot3", "Sig1", None, None, None, "twoa1", None, "twoa2"], 51 | ["Prot1", "Prot3", "Sig1", "Ep1", None, "Int1", "twoa1", None, "twoa2"], 52 | ["Prot1", None, "Sig1", "Ep1", "Link1", "Int1", "twoa1", "twoa4", "twoa2"], 53 | ["Prot1", "Prot3", "Sig1", "Ep1", "Link1", "Int1", "twoa1", "twoa4", "twoa2"], 54 | ] 55 | 56 | sequences = encode_design(SCHEMA, design) 57 | actual = [c(SCHEMA, sequences).tolist() for c in CONSTRAINTS] 58 | assert actual[0] == [True, False, True, True, True, True, True] 59 | assert actual[1] == [True, True, False, True, True, True, True] 60 | assert actual[2] == [True, True, True, False, True, True, True] 61 | assert actual[3] == [True, True, True, True, False, True, True] 62 | assert actual[4] == 
[True, True, True, True, True, False, True] 63 | assert actual[5] == [True, True, True, True, True, True, False] 64 | 65 | 66 | def test_function(): 67 | SCHEMA = OrderedDict() 68 | SCHEMA["foo"] = ["a", "b", "c", None] 69 | SCHEMA["bar"] = ["a", "b", None] 70 | 71 | CONSTRAINTS = [ 72 | Function(lambda x: x.sum(-1) <= 0), 73 | Function(lambda x: x.sum(-1) <= 1), 74 | Function(lambda x: x.sum(-1) <= 2), 75 | Function(lambda x: x.sum(-1) <= 3), 76 | Function(lambda x: x.sum(-1) <= 4), 77 | Function(lambda x: x.sum(-1) <= 5), 78 | ] 79 | 80 | design: List[List[Optional[str]]] = [ 81 | ["a", "a"], 82 | ["a", "b"], 83 | ["a", None], 84 | ["b", "a"], 85 | ["b", "b"], 86 | ["b", None], 87 | ["c", "a"], 88 | ["c", "b"], 89 | ["c", None], 90 | [None, "a"], 91 | [None, "b"], 92 | [None, None], 93 | ] 94 | 95 | sequences = encode_design(SCHEMA, design) 96 | actual = [c(SCHEMA, sequences).tolist() for c in CONSTRAINTS] 97 | assert stringify(actual[0]) == "100000000000" 98 | assert stringify(actual[1]) == "110100000000" 99 | assert stringify(actual[2]) == "111110100000" 100 | assert stringify(actual[3]) == "111111110100" 101 | assert stringify(actual[4]) == "111111111110" 102 | assert stringify(actual[5]) == "111111111111" 103 | 104 | 105 | def test_takes_value(): 106 | SCHEMA = OrderedDict() 107 | SCHEMA["foo"] = ["a", "b", "c", None] 108 | SCHEMA["bar"] = ["a", "b", None] 109 | 110 | CONSTRAINTS = [ 111 | TakesValue("foo", "a"), 112 | TakesValue("foo", "b"), 113 | TakesValue("foo", "c"), 114 | TakesValue("foo", None), 115 | TakesValue("bar", "a"), 116 | TakesValue("bar", "b"), 117 | TakesValue("bar", None), 118 | ] 119 | 120 | design: List[List[Optional[str]]] = [ 121 | ["a", "a"], 122 | ["a", "b"], 123 | ["a", None], 124 | ["b", "a"], 125 | ["b", "b"], 126 | ["b", None], 127 | ["c", "a"], 128 | ["c", "b"], 129 | ["c", None], 130 | [None, "a"], 131 | [None, "b"], 132 | [None, None], 133 | ] 134 | 135 | sequences = encode_design(SCHEMA, design) 136 | actual = [c(SCHEMA, 
sequences).tolist() for c in CONSTRAINTS] 137 | assert stringify(actual[0]) == "111000000000" 138 | assert stringify(actual[1]) == "000111000000" 139 | assert stringify(actual[2]) == "000000111000" 140 | assert stringify(actual[3]) == "000000000111" 141 | assert stringify(actual[4]) == "100100100100" 142 | assert stringify(actual[5]) == "010010010010" 143 | assert stringify(actual[6]) == "001001001001" 144 | 145 | 146 | def test_takes_values(): 147 | SCHEMA = OrderedDict() 148 | SCHEMA["foo"] = ["a", "b", "c", None] 149 | SCHEMA["bar"] = ["a", "b", None] 150 | 151 | CONSTRAINTS = [ 152 | TakesValues("foo", "a"), 153 | TakesValues("foo", "b", "c"), 154 | TakesValues("foo", "a", None), 155 | TakesValues("bar", "a", "b", None), 156 | TakesValues("bar", "b"), 157 | TakesValues("bar"), 158 | ] 159 | 160 | design: List[List[Optional[str]]] = [ 161 | ["a", "a"], 162 | ["a", "b"], 163 | ["a", None], 164 | ["b", "a"], 165 | ["b", "b"], 166 | ["b", None], 167 | ["c", "a"], 168 | ["c", "b"], 169 | ["c", None], 170 | [None, "a"], 171 | [None, "b"], 172 | [None, None], 173 | ] 174 | 175 | sequences = encode_design(SCHEMA, design) 176 | actual = [c(SCHEMA, sequences).tolist() for c in CONSTRAINTS] 177 | assert stringify(actual[0]) == "111000000000" 178 | assert stringify(actual[1]) == "000111111000" 179 | assert stringify(actual[2]) == "111000000111" 180 | assert stringify(actual[3]) == "111111111111" 181 | assert stringify(actual[4]) == "010010010010" 182 | assert stringify(actual[5]) == "000000000000" 183 | 184 | 185 | def test_logic(): 186 | SCHEMA = OrderedDict() 187 | SCHEMA["foo"] = ["a", None] 188 | SCHEMA["bar"] = ["a", None] 189 | 190 | foo = TakesValue("foo", "a") 191 | bar = TakesValue("bar", "a") 192 | CONSTRAINTS = [ 193 | foo, 194 | bar, 195 | Not(foo), 196 | Not(bar), 197 | And(foo, bar), 198 | Or(foo, bar), 199 | Xor(foo, bar), 200 | Iff(foo, bar), 201 | IfThen(foo, bar), 202 | ] 203 | 204 | design: List[List[Optional[str]]] = [ 205 | ["a", "a"], 206 | ["a", 
None], 207 | [None, "a"], 208 | [None, None], 209 | ] 210 | 211 | sequences = encode_design(SCHEMA, design) 212 | actual = [c(SCHEMA, sequences).tolist() for c in CONSTRAINTS] 213 | assert stringify(actual[0]) == "1100" 214 | assert stringify(actual[1]) == "1010" 215 | assert stringify(actual[2]) == "0011" 216 | assert stringify(actual[3]) == "0101" 217 | assert stringify(actual[4]) == "1000" 218 | assert stringify(actual[5]) == "1110" 219 | assert stringify(actual[6]) == "0110" 220 | assert stringify(actual[7]) == "1001" 221 | assert stringify(actual[8]) == "1011" 222 | -------------------------------------------------------------------------------- /test/test_e2e.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from collections import OrderedDict 4 | from typing import List, Optional 5 | 6 | import pytest 7 | import torch 8 | 9 | from pyroed import ( 10 | decode_design, 11 | encode_design, 12 | get_next_design, 13 | start_experiment, 14 | update_experiment, 15 | ) 16 | from pyroed.constraints import AllDifferent 17 | from pyroed.criticism import criticize 18 | from pyroed.typing import Constraints, Schema 19 | 20 | 21 | def example_feature_fn(sequence): 22 | sequence = sequence.to(torch.get_default_dtype()) 23 | return torch.stack( 24 | [ 25 | sequence.sum(-1), 26 | sequence.max(-1).values, 27 | sequence.min(-1).values, 28 | sequence.mean(-1), 29 | sequence.std(-1), 30 | ], 31 | dim=-1, 32 | ) 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "response_type, feature_fn", 37 | [ 38 | ("unit_interval", None), 39 | ("real", None), 40 | ("real", example_feature_fn), 41 | ], 42 | ) 43 | @pytest.mark.parametrize( 44 | "inference, jit_compile", 45 | [ 46 | ("svi", False), 47 | pytest.param("svi", True, marks=[pytest.mark.xfail(reason="jit error")]), 48 | pytest.param("mcmc", False, marks=[pytest.mark.skip(reason="slow")]), 49 | ("mcmc", True), 50 | ], 51 | ) 52 | def test_end_to_end(inference, 
jit_compile, response_type, feature_fn): 53 | # Declare a problem. 54 | SCHEMA: Schema = OrderedDict() 55 | SCHEMA["foo"] = ["a", "b", None] 56 | SCHEMA["bar"] = ["a", "b", "c", None] 57 | SCHEMA["baz"] = ["a", "b", "c", "d", None] 58 | 59 | CONSTRAINTS: Constraints = [AllDifferent("bar", "baz")] 60 | 61 | FEATURE_BLOCKS = [["foo"], ["bar", "baz"]] 62 | 63 | GIBBS_BLOCKS = [["foo", "bar"], ["bar", "baz"]] 64 | 65 | design_size = 4 66 | design: List[List[Optional[str]]] = [ 67 | ["a", "b", "c"], 68 | ["b", "c", "a"], 69 | ["a", "b", None], 70 | ["b", "c", None], 71 | ] 72 | 73 | # Initialize an experiment. 74 | sequences = encode_design(SCHEMA, design) 75 | if response_type == "unit_interval": 76 | responses = torch.rand(design_size) 77 | elif response_type == "real": 78 | responses = torch.rand(design_size) 79 | batch_ids = torch.zeros(design_size, dtype=torch.long) 80 | experiment = start_experiment(SCHEMA, sequences, responses, batch_ids) 81 | 82 | # Draw new batches. 83 | config = { 84 | "response_type": response_type, 85 | "inference": inference, 86 | "jit_compile": jit_compile, 87 | "mcmc_num_samples": 100, 88 | "mcmc_warmup_steps": 100, 89 | "svi_num_steps": 100, 90 | "sa_num_steps": 100, 91 | "log_every": 10, 92 | } 93 | for step in range(2): 94 | sequences = get_next_design( 95 | SCHEMA, 96 | CONSTRAINTS, 97 | FEATURE_BLOCKS, 98 | GIBBS_BLOCKS, 99 | experiment, 100 | design_size=design_size, 101 | feature_fn=feature_fn, 102 | config=config, 103 | ) 104 | 105 | design = decode_design(SCHEMA, sequences) 106 | actual_sequences = encode_design(SCHEMA, design) 107 | assert torch.allclose(actual_sequences, sequences) 108 | responses = torch.rand(design_size) 109 | 110 | if step == 0: 111 | experiment = update_experiment(SCHEMA, experiment, sequences, responses) 112 | assert len(experiment["sequences"]) == design_size * (2 + step) 113 | else: 114 | test_data = { 115 | "sequences": sequences, 116 | "responses": responses, 117 | } 118 | 119 | # Criticize. 
120 | with tempfile.TemporaryDirectory() as dirname: 121 | criticize( 122 | SCHEMA, 123 | FEATURE_BLOCKS, 124 | experiment, 125 | test_data, 126 | feature_fn=feature_fn, 127 | filename=os.path.join(dirname, "criticize.pdf"), 128 | ) 129 | --------------------------------------------------------------------------------