├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── Handbook.pdf ├── LICENSE ├── README.md ├── assets ├── logo-square-bw.png ├── logo-square-bw.xcf ├── logo-square.png ├── logo-square.xcf ├── logo.png └── logo.xcf ├── demo ├── 7-aux │ └── resnetweights55 │ │ ├── _CHECKPOINT_METADATA │ │ ├── _METADATA │ │ ├── _sharding │ │ ├── d │ │ └── c802c67d362c7a34e02297c9e0cbcd85 │ │ ├── manifest.ocdbt │ │ └── ocdbt.process_0 │ │ ├── d │ │ ├── 3f1344dd12d907d616120783e1f73fab │ │ ├── 72bebf4e70b763393a847042cae09d9f │ │ ├── a4a999e8901a8f3285edda24d5a3ea81 │ │ ├── c803c308d312a8392790127807c2d766 │ │ ├── d2fac1f5bd5ce5c5a4adc784fe5d8d33 │ │ ├── d57d0e0fe8ba3a7e2265ead9ba2ff67f │ │ └── faeab70d4852b3d07f3c7d03e60b7f95 │ │ └── manifest.ocdbt ├── Memonomicon.ipynb ├── README.md ├── demo-23.ipynb ├── demo-7segment.ipynb ├── demo-cheryl.ipynb ├── demo-dining-cryptographers.ipynb ├── demo-eig.ipynb ├── demo-empowerment.py ├── demo-fib.py ├── demo-i-pomdp.ipynb ├── demo-mdp.ipynb ├── demo-mdp.wppl ├── demo-monty.ipynb ├── demo-newcomb.ipynb ├── demo-pc.ipynb ├── demo-persuasion.ipynb ├── demo-physics.ipynb ├── demo-politeness.ipynb ├── demo-pomdp.ipynb ├── demo-risk-aversion.ipynb ├── demo-rsa.py ├── demo-sally-anne.ipynb ├── demo-scalar.py ├── demo-schelling.ipynb ├── demo-schelling.wppl ├── demo-takeaway.ipynb ├── demo-ultimatum.ipynb ├── memo └── test.py ├── docs ├── index.html └── tomalot │ ├── index.html │ └── tomalot.pdf ├── memo ├── __init__.py ├── codegen.py ├── comic.py ├── core.py ├── lib.py ├── parse.py ├── utils.py └── version.py └── pyproject.toml /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not 
certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | .DS_Store 3 | __pycache__/ 4 | paper/ 5 | scratch/* 6 | .env/ 7 | .cpu_env 8 | figures.key 9 | out.png 10 | demo/*-aux 11 | demo/*.dot 12 | demo/*.png 13 | Handbook.key 14 | -------------------------------------------------------------------------------- /Handbook.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/Handbook.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024-2025 Kartik Chandra and Tony Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation 
the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![memo's logo](https://github.com/kach/memo/raw/main/assets/logo.png) 2 | 3 | memo is a probabilistic programming language for expressing computational cognitive models involving **recursive reasoning about reasoning**. memo inherits from the tradition of WebPPL-based Bayesian modeling (see [probmods](http://probmods.org/), [agentmodels](https://agentmodels.org/), and [problang](https://www.problang.org/)), but aims to make models **easier to write and run** by taking advantage of modern programming language techniques and hardware capabilities (including GPUs!). As a result, models are often significantly simpler to express (we've seen codebases shrink by a **factor of 3x or more**), and dramatically faster to execute and fit to data (we've seen **speedups of 3,000x or more**). In idiomatic memo, a POMDP solver is 15 lines of code, and is just as fast as a hand-optimized solver written in 200 lines of code. 
4 | 5 | memo stands for: mental modeling, memoized matrix operations, model-expressed-model-optimized, and metacognitive memos. 6 | 7 | ## Installing memo 8 | 9 | 1. memo is based on Python. Before installing memo, make sure you have Python 3.12 or higher installed. You can check this by running `python --version`. 10 | 2. Next, install [JAX](https://github.com/google/jax), a Python module that memo uses to produce fast, differentiable, GPU-enabled code. If you don't have a GPU, then running `pip install jax` should be enough. Otherwise, please consult the JAX website for installation instructions. You can check if JAX is installed by running `import jax` in Python. 11 | 3. Finally, install memo by running `pip install memo-lang`. You can check if memo is installed by running `from memo import memo` in Python. (Make sure to install `memo-lang`, not `memo`! The latter is a different package, unrelated to this project.) 12 | 13 | ## Learning memo 14 | 15 | There are many resources available for learning memo. 16 | 1. The [Memonomicon](./demo/Memonomicon.ipynb) gives a brief tour of the language, and an example of how to build a model and fit it to data by parallel grid search and/or gradient descent. 17 | 2. You can watch a [video tutorial](https://www.dropbox.com/scl/fi/c3jjup1lheowfppbz41zr/memo-live-tutorial.mp4?rlkey=ce7reeadff2nh2ktqh3tubbik&st=lai8yx1h&dl=0) that covers similar material. You can also check out a [talk given at LAFI '25](https://www.youtube.com/live/RLEFVgx2UWk?t=12500s) that offers a bigger-picture overview of memo. 18 | 3. The [Handbook](./Handbook.pdf) is a complete reference for memo's syntactic constructs. 19 | 4. This repository includes over a dozen classic examples of recursive reasoning models implemented in memo, which you can find in the [demo directory](./demo/). 20 | 5. I am happy to give a short hands-on tutorial on memo in your lab. Just email me to ask!
21 | 22 | You may also be looking for general resources on the theory behind memo modeling. 23 | 1. For background on the theory of decision making under uncertainty, e.g. MDPs and POMDPs, we recommend consulting _Decision Making Under Uncertainty_ as a reference. You can read the entire book for free online [here](https://algorithmsbook.com/decisionmaking/). 24 | 2. For background on Bayesian models of theory of mind, we recommend consulting chapter 14 of _Bayesian Models of Cognition_ as a reference. You can read the published version [here](https://mitpress.ublish.com/ebook/bayesian-models-of-cognition-reverse-engineering-the-mind-preview/12799/341) and a PDF preprint [here](https://www.tomerullman.org/papers/BBB_chapter14.pdf). 25 | 3. Dae Houlihan (Dartmouth College) is teaching a winter '25 [course](https://comosoco.daeh.info) on computational models of social cognition using memo. 26 | 27 | ## The memo community 28 | 29 | Here are some ways to engage with the memo community. 30 | 31 | 1. For updates on memo's development, we _strongly_ encourage you to subscribe to our low-traffic monthly announcements mailing list [here](https://lists.csail.mit.edu/mailman/listinfo/memo-lang). 32 | 2. To ask questions about memo, and to get help from other memo users, use [GitHub Discussions](https://github.com/kach/memo/discussions). Note that you will need a GitHub account to participate. 33 | 3. For live support, we host memOH (**memo office hours**) every Tuesday at 2pm ET. Email Kartik for the Zoom link! 34 | 35 | ## The memo on memo 36 | 37 | An early draft of a paper on memo's design and implementation is available [here](https://osf.io/preprints/psyarxiv/pt863).
If you use memo in your work, you are invited to cite this paper: 38 | 39 | ```bibtex 40 | @article{chandra2025memo, 41 | title={A Domain-Specific Probabilistic Programming Language for Reasoning About Reasoning (or: a memo on memo)}, 42 | year={2025}, 43 | author={Kartik Chandra and Tony Chen and Joshua B. Tenenbaum and Jonathan Ragan-Kelley}, 44 | journal={psyarxiv preprint}, 45 | url={https://doi.org/10.31234/osf.io/pt863} 46 | } 47 | ``` 48 | 49 | I would love to hear about any research using memo. Please don't hesitate to share your work with me! 50 | 51 | ## FAQ 52 | 53 |
How do I capitalize memo? Is it Memo? MEMO? MeMo? 54 | 55 | "memo," all-lowercase. 56 |
57 | 58 |
When should I use memo rather than Gen or WebPPL? 59 | 60 | memo's core competence is fast tabular/enumerative inference on models with recursive reasoning about reasoning. That covers a wide range of common models: from RSA, to POMDP planning (value iteration = tabular operations), to inverse planning. In general, if you are making nested queries, we recommend using memo. 61 | 62 | There are, however, two particular cases where you may prefer another PPL: 63 | 1. If you are interested specifically in modeling a sophisticated inference scheme, such as MCMC, particle filters, or variational inference, then we recommend trying Gen. _(But make sure you really need those tools — the fast enumerative inference provided by memo is sufficient for many common kinds of models!)_ 64 | 2. If you are performing inference over an unbounded domain of hypotheses with varied structure, such as programs generated by a grammar, then we recommend trying Gen or WebPPL because memo's tabular enumerative inference can only handle probability distributions with finite support. _(But if you are okay with inference over a "truncated" domain, e.g. the 1,000,000 shortest programs, then memo can do that! Similarly, memo can handle continuous domains by discretizing finely.)_ 65 | 66 | The aforementioned cases are explicitly out of scope for memo. By specializing memo to a particular commonly-used class of models and inference strategies, we can generate extremely fast code that general-purpose PPLs would find difficult to match. 67 |
68 | 69 |
Okay, so how does memo produce such fast code? 70 | 71 | memo compiles enumerative inference to JAX array programs, which can be run extremely fast. The reason for this is that array programs are inherently very easy to execute in parallel (by performing operations on each element of the array independently). Modern hardware is particularly good at parallel processing. 72 |
73 | 74 |
What exactly is JAX? 75 | 76 | [JAX](https://github.com/google/jax) is a library developed by Google that takes Python array programs (similar to NumPy) and compiles them to very fast code that can run on CPUs and GPUs, taking advantage of modern hardware functionality. JAX is used for much of Google's deep learning work, because neural networks involve a lot of array operations. memo compiles your probabilistic models into JAX array programs, and JAX further compiles those array programs into machine code. 77 | 78 | Note that JAX has some unintuitive behaviors. We recommend reading [this guide](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) to get a sense of its "sharp edges." 79 |
80 | 81 |
I'm used to thinking of probabilistic programming as sampling execution traces from generative models. Should I think of memo the same way? 82 | 83 | One way to think about memo is that it simulates _all_ possible traces at the same time. There is no need for sampling, because we always have access to the full posterior distribution. 84 |
85 | 86 |
Is memo a research prototype, or a mature software product? Should I invest in learning memo? 87 | 88 | While memo originated as a research project, it is now stable software that is being used by many labs around the world. memo will be supported for a long time to come, and you should feel confident in using memo for your own projects. 89 |
90 | 91 | --- 92 | 93 |
94 | I installed memo but importing memo gives an error. 95 | 96 | Did you accidentally pip-install the (unrelated) package [memo](https://pypi.org/project/memo/) instead of [memo-lang](https://pypi.org/project/memo-lang/)? 97 |
98 | 99 |
100 | I installed memo on my Mac, but running models gives a weird JAX error about "AVX". 101 | 102 | The most common cause of this is that you have a modern Mac (with an ARM processor), but a version of Python that was compiled for x86 (often an older installation). We recommend the following installation strategy on ARM-based Macs: 103 | 1. Do not use conda. 104 | 2. Install Homebrew. Make sure you have the ARM version of brew: `brew --prefix` should be `/opt/homebrew`, and `brew config` should say `Rosetta 2: false`. If this is not the case, you have the x86 version of brew, which you should uninstall. 105 | 3. Install Python via `brew install python3`. Ensure that `python3 --version` works as expected, and that `which python3` points to something in `/opt/homebrew/bin/`. 106 | 4. In your project directory, create a virtual environment via `python3 -m venv venv`. 107 | 5. Activate the virtual environment via `. venv/bin/activate`. Your prompt should now begin with `(venv)`. 108 | 6. Install memo via `pip install memo-lang`. 109 |
110 | 111 |
How do I use memo with a GPU? 112 | 113 | Assuming you have [installed JAX with GPU support](https://jax.readthedocs.io/en/latest/installation.html), all you have to do is plug in your GPU! 114 |
115 | 116 |
Can I run memo on Apple's "metal" platform? 117 | 118 | Yes! See this issue for details: https://github.com/kach/memo/issues/66 119 |
120 | 121 | --- 122 | 123 |
VS Code underlines all my memo code in red. It's a bloodbath out there! 124 | 125 | If you write `# type: ignore` at the top of your file (even before the imports), then VS Code will suppress the red lines. If you use Ruff, additionally add `# ruff: noqa`. 126 |
127 | 128 |
Sometimes my model returns 0 in unexpected places, often at the edges/extreme values of distributions. 129 | 130 | This can be caused by numerical instability. For example, if a `wpp=` expression gets too big, then it might "overflow" to infinity and wreak havoc downstream. Similarly, if a `wpp=` expression returns 0 for all possible choices, then normalizing that distribution causes a division-by-zero error that wreaks havoc downstream. This havoc usually comes in the form of calculations being unexpectedly clipped to 0. 131 | 132 | So, if you are seeing unexpected 0s, we recommend inspecting your `wpp=` expressions to see whether they could be returning very large or very small values. Often, you can fix the problem by adding a little epsilon value (e.g. `wpp=f(x) + 1e-5`). 133 |
134 | 135 |
Some of my output array's dimensions are unexpectedly of size 1. 136 | 137 | memo attempts to minimize redundant computation. If the output of your model doesn't depend on an input axis, then instead of repeating the computation along that axis, memo will set that axis to size 1. The idea is that [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) will keep the array compatible with downstream computations. 138 | 139 | As an example, consider the following models: 140 | 141 | ```python 142 | X = np.arange(10) 143 | 144 | @memo 145 | def f[a: X, b: X](): 146 | return a 147 | f().shape # (10, 1) because output is independent of b 148 | 149 | @memo 150 | def f[a: X, b: X](): 151 | return b 152 | f().shape # (1, 10) because output is independent of a 153 | 154 | @memo 155 | def f[a: X, b: X](): 156 | return a + b 157 | f().shape # (10, 10) because output depends on a and b 158 | 159 | @memo 160 | def f[a: X, b: X](): 161 | return 999 162 | f().shape # (1, 1) because output depends on neither a nor b 163 | ``` 164 |
165 | 166 |
How can I visualize what's going on with my model in "comic-book" format? 167 | 168 | Use `@memo(save_comic="filename")` instead of just `@memo`. memo will produce a [Graphviz](https://graphviz.org/) `filename.dot` file that you can [render online](https://dreampuf.github.io/GraphvizOnline/). If you have Graphviz installed, memo will also automatically render a `filename.png` file for you. 169 | 170 |
171 | 172 |
How can I get model outputs in pandas/xarray format? 173 | Pass the `return_pandas=True` or `return_xarray=True` keyword argument to your model. Your model will then return a tuple: the first element will be the raw array, and the second element will have a `.pandas` or `.xarray` property, respectively. 174 |
175 | -------------------------------------------------------------------------------- /assets/logo-square-bw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/assets/logo-square-bw.png -------------------------------------------------------------------------------- /assets/logo-square-bw.xcf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/assets/logo-square-bw.xcf -------------------------------------------------------------------------------- /assets/logo-square.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/assets/logo-square.png -------------------------------------------------------------------------------- /assets/logo-square.xcf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/assets/logo-square.xcf -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/assets/logo.png -------------------------------------------------------------------------------- /assets/logo.xcf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/assets/logo.xcf -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/_CHECKPOINT_METADATA: -------------------------------------------------------------------------------- 1 | 
{"init_timestamp_nsecs": 1729646697790768321, "commit_timestamp_nsecs": 1729646698137921362} -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/_METADATA: -------------------------------------------------------------------------------- 1 | {"tree_metadata": {"('Dense_0', 'bias')": {"key_metadata": [{"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('Dense_0', 'kernel')": {"key_metadata": [{"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_0', 'Conv_0', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_0", "key_type": 2}, {"key": "Conv_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_0', 'Conv_1', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_0", "key_type": 2}, {"key": "Conv_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "ResNetBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "ResNetBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "ResNetBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": 
"ResNetBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_1', 'Conv_0', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_1", "key_type": 2}, {"key": "Conv_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_1', 'Conv_1', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_1", "key_type": 2}, {"key": "Conv_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "ResNetBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "ResNetBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "ResNetBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "ResNetBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_2', 'Conv_0', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_2", "key_type": 2}, {"key": "Conv_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_2', 'Conv_1', 'kernel')": {"key_metadata": [{"key": 
"ResNetBlock_2", "key_type": 2}, {"key": "Conv_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "ResNetBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "ResNetBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "ResNetBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "ResNetBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_2', 'conv_proj', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_2", "key_type": 2}, {"key": "conv_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_2', 'norm_proj', 'bias')": {"key_metadata": [{"key": "ResNetBlock_2", "key_type": 2}, {"key": "norm_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_2', 'norm_proj', 'scale')": {"key_metadata": [{"key": "ResNetBlock_2", "key_type": 2}, {"key": "norm_proj", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_3', 'Conv_0', 'kernel')": {"key_metadata": [{"key": 
"ResNetBlock_3", "key_type": 2}, {"key": "Conv_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_3', 'Conv_1', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_3", "key_type": 2}, {"key": "Conv_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "ResNetBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "ResNetBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "ResNetBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "ResNetBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_4', 'Conv_0', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_4", "key_type": 2}, {"key": "Conv_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_4', 'Conv_1', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_4", "key_type": 2}, {"key": "Conv_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "ResNetBlock_4", 
"key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "ResNetBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "ResNetBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "ResNetBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_4', 'conv_proj', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_4", "key_type": 2}, {"key": "conv_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_4', 'norm_proj', 'bias')": {"key_metadata": [{"key": "ResNetBlock_4", "key_type": 2}, {"key": "norm_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_4', 'norm_proj', 'scale')": {"key_metadata": [{"key": "ResNetBlock_4", "key_type": 2}, {"key": "norm_proj", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_5', 'Conv_0', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_5", "key_type": 2}, {"key": "Conv_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_5', 'Conv_1', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_5", 
"key_type": 2}, {"key": "Conv_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "ResNetBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "ResNetBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "ResNetBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "ResNetBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_6', 'Conv_0', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_6", "key_type": 2}, {"key": "Conv_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_6', 'Conv_1', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_6", "key_type": 2}, {"key": "Conv_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_6', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "ResNetBlock_6", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_6', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "ResNetBlock_6", 
"key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_6', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "ResNetBlock_6", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_6', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "ResNetBlock_6", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_6', 'conv_proj', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_6", "key_type": 2}, {"key": "conv_proj", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_6', 'norm_proj', 'bias')": {"key_metadata": [{"key": "ResNetBlock_6", "key_type": 2}, {"key": "norm_proj", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_6', 'norm_proj', 'scale')": {"key_metadata": [{"key": "ResNetBlock_6", "key_type": 2}, {"key": "norm_proj", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_7', 'Conv_0', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_7", "key_type": 2}, {"key": "Conv_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_7', 'Conv_1', 'kernel')": {"key_metadata": [{"key": "ResNetBlock_7", "key_type": 2}, {"key": "Conv_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_7', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "ResNetBlock_7", "key_type": 
2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_7', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "ResNetBlock_7", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_7', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "ResNetBlock_7", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('ResNetBlock_7', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "ResNetBlock_7", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('bn_init', 'bias')": {"key_metadata": [{"key": "bn_init", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('bn_init', 'scale')": {"key_metadata": [{"key": "bn_init", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('conv_init', 'kernel')": {"key_metadata": [{"key": "conv_init", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}}, "use_zarr3": false} -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/_sharding: -------------------------------------------------------------------------------- 1 | {"RGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","RGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMC5Db252XzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": 
\"cuda:0\"}","UmVzTmV0QmxvY2tfMC5Db252XzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMC5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMC5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMC5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMC5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMS5Db252XzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMS5Db252XzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMS5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMS5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMS5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMS5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMi5Db252XzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMi5Db252XzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMi5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMi5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMi5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMi5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": 
\"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMi5jb252X3Byb2oua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMi5ub3JtX3Byb2ouYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMi5ub3JtX3Byb2ouc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMy5Db252XzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMy5Db252XzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMy5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMy5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMy5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfMy5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNC5Db252XzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNC5Db252XzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNC5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNC5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNC5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNC5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNC5jb252X3Byb2oua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": 
\"cuda:0\"}","UmVzTmV0QmxvY2tfNC5ub3JtX3Byb2ouYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNC5ub3JtX3Byb2ouc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNS5Db252XzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNS5Db252XzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNS5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNS5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNS5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNS5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNi5Db252XzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNi5Db252XzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNi5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNi5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNi5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNi5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNi5jb252X3Byb2oua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNi5ub3JtX3Byb2ouYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNi5ub3JtX3Byb2ouc2NhbGU=":"{\"sharding_type\": 
\"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNy5Db252XzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNy5Db252XzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNy5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNy5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNy5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","UmVzTmV0QmxvY2tfNy5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","Y29udl9pbml0Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","Ym5faW5pdC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","Ym5faW5pdC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"} -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/d/c802c67d362c7a34e02297c9e0cbcd85: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/demo/7-aux/resnetweights55/d/c802c67d362c7a34e02297c9e0cbcd85 -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/manifest.ocdbt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/demo/7-aux/resnetweights55/manifest.ocdbt -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/ocdbt.process_0/d/3f1344dd12d907d616120783e1f73fab: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/demo/7-aux/resnetweights55/ocdbt.process_0/d/3f1344dd12d907d616120783e1f73fab -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/ocdbt.process_0/d/72bebf4e70b763393a847042cae09d9f: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/demo/7-aux/resnetweights55/ocdbt.process_0/d/72bebf4e70b763393a847042cae09d9f -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/ocdbt.process_0/d/a4a999e8901a8f3285edda24d5a3ea81: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/demo/7-aux/resnetweights55/ocdbt.process_0/d/a4a999e8901a8f3285edda24d5a3ea81 -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/ocdbt.process_0/d/c803c308d312a8392790127807c2d766: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/demo/7-aux/resnetweights55/ocdbt.process_0/d/c803c308d312a8392790127807c2d766 -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/ocdbt.process_0/d/d2fac1f5bd5ce5c5a4adc784fe5d8d33: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/demo/7-aux/resnetweights55/ocdbt.process_0/d/d2fac1f5bd5ce5c5a4adc784fe5d8d33 -------------------------------------------------------------------------------- 
/demo/7-aux/resnetweights55/ocdbt.process_0/d/d57d0e0fe8ba3a7e2265ead9ba2ff67f: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/demo/7-aux/resnetweights55/ocdbt.process_0/d/d57d0e0fe8ba3a7e2265ead9ba2ff67f -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/ocdbt.process_0/d/faeab70d4852b3d07f3c7d03e60b7f95: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/demo/7-aux/resnetweights55/ocdbt.process_0/d/faeab70d4852b3d07f3c7d03e60b7f95 -------------------------------------------------------------------------------- /demo/7-aux/resnetweights55/ocdbt.process_0/manifest.ocdbt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/demo/7-aux/resnetweights55/ocdbt.process_0/manifest.ocdbt -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | # Index of demos 2 | 3 | ## Language (variations on RSA) 4 | 5 | - [Lewis game](./demo-rsa.py) (discussed in paper) 6 | - [Scalar implicature](./demo-scalar.py) (discussed in paper) 7 | - [7-segment displays](./demo-7segment.ipynb) (discussed in paper) 8 | (we show how memo can interface with deep learning and computer vision) 9 | - [Polite speech](./demo-politeness.ipynb) 10 | 11 | ## Economic games 12 | 13 | - [Schelling game](./demo-schelling.ipynb) (discussed in paper) 14 | - [Ultimatum game](./demo-ultimatum.ipynb) 15 | - [Guess 2/3 of average](./demo-23.ipynb) 16 | (we use memo to fit a model to real-world data collected by the New York Times) 17 | - [Bayesian Persuasion](./demo-persuasion.ipynb) 18 | (we 
use memo to recover a result presented in a classic paper on persuasion) 19 | - [Risk aversion](./demo-risk-aversion.ipynb) 20 | 21 | ## Planning and inverse planning 22 | 23 | - [MDP (gridworld)](./demo-mdp.ipynb) (discussed in paper) 24 | - [POMDP (crying baby)](./demo-pomdp.ipynb) (discussed in paper) 25 | - [Takeaway](./demo-takeaway.ipynb) (discussed in paper) 26 | (we use memo to reflect on the cost of computation and make inferences about _whether_ someone is thinking) 27 | - [Perturbation confusion](./demo-pc.ipynb) (discussed in paper) 28 | - [I-POMDP (investment game)](./demo-i-pomdp.ipynb) 29 | (agents reason about their uncertainty about each other) 30 | - [Sally-Anne (false belief test)](./demo-sally-anne.ipynb) 31 | - [Integrating intuitive psychology and intuitive physics](./demo-physics.ipynb) 32 | 33 | ## Information theoretic calculations 34 | 35 | - [Expected information gain](./demo-eig.ipynb) 36 | - [Empowerment](./demo-empowerment.py) 37 | 38 | ## Puzzles 39 | 40 | - [Cheryl's birthday](./demo-cheryl.ipynb) 41 | - [Dining cryptographers](./demo-dining-cryptographers.ipynb) 42 | - [Newcomb's problem](./demo-newcomb.ipynb) 43 | - [Monty Hall](./demo-monty.ipynb) 44 | -------------------------------------------------------------------------------- /demo/demo-cheryl.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "8d8e4df7-475d-4a2f-81a1-becfe3078512", 6 | "metadata": {}, 7 | "source": [ 8 | "# Cheryl's Birthday\n", 9 | "\n", 10 | "> Everyone knows that Cheryl was born in February, March, or April.\n", 11 | ">\n", 12 | "> Cheryl separately tells Alice the month and Bob the day. Then they have this dialogue:\n", 13 | ">\n", 14 | "> 1. Alice: \"I don't know when Cheryl's birthday is...\"\n", 15 | "> 2. Alice: \"but I know that Bob doesn't know either.\"\n", 16 | "> 3. Bob: \"At first I didn't know when Cheryl's birthday is...\"\n", 17 | "> 4. 
Bob: \"but now I know.\"\n", 18 | "> 5. Alice: \"Now I know when Cheryl's birthday is.\"\n", 19 | ">\n", 20 | "> When is Cheryl's birthday?\n", 21 | "\n", 22 | "_(This is actually our in-house variant of the [original puzzle](https://en.wikipedia.org/wiki/Cheryl%27s_Birthday). It is logically the same, but we find this variant more fun and easier to explain to people, because it doesn't rely on positing an arbitrary subset of dates.)_\n", 23 | "\n", 24 | "We will progressively build up a model of this scenario in memo by writing a model for each of the 5 utterances. To model whether someone knows for certain when Cheryl's birthday is, we will check whether, according to them, the variance over the date is zero." 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "id": "a1c25755-6319-4bf6-9f39-fe290c8895d1", 31 | "metadata": {}, 32 | "source": [ 33 | "from memo import memo, domain\n", 34 | "import jax\n", 35 | "import jax.numpy as np\n", 36 | "from enum import IntEnum\n", 37 | "\n", 38 | "class Month(IntEnum):\n", 39 | " February = 0\n", 40 | " March = 1\n", 41 | " April = 2\n", 42 | "\n", 43 | "Day = np.arange(1, 31 + 1)\n", 44 | "\n", 45 | "class U(IntEnum):\n", 46 | " DUNNO = 0\n", 47 | " KNOWN = 1\n", 48 | "\n", 49 | "@jax.jit\n", 50 | "def possible(m, d): # 31 days hath...\n", 51 | " return d <= np.array([29, 31, 30])[m]" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "6c544a37-d090-4839-bc39-db2ef0b7fed4", 57 | "metadata": {}, 58 | "source": [ 59 | "Now, let's build up this model step by step.\n", 60 | "\n", 61 | "Alice: (after observing month $m$) \"I don't know when Cheryl's birthday is.\"" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 2, 67 | "id": "fd945542-3e43-411c-b4fb-bb7c7acd4d02", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "@memo\n", 72 | "def a_u1[m: Month, u: U]():\n", 73 | " a: thinks[\n", 74 | " c: chooses(m in Month, wpp=1),\n", 75 | 
" c: chooses(d in Day, wpp=possible(m, d))\n", 77 | " ]\n", 78 | " a: observes [c.m] is m\n", 79 | " return u == a[Var[c.d] == 0] # note: Alice's variance over Cheryl's d\n", 80 | "# print(a_u1())" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "62ab6e22-7ed5-4427-b8aa-b0945d9ab683", 86 | "metadata": {}, 87 | "source": [ 88 | "Alice: \"...but I know that Bob doesn't know either.\"" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 3, 94 | "id": "611ecaf5-d6a0-40b4-84c7-2101cf8ecbd3", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "@memo\n", 99 | "def a_u2[m: Month, u: U]():\n", 100 | " a: thinks[\n", 101 | " c: chooses(m in Month, wpp=1),\n", 102 | " c: chooses(d in Day, wpp=possible(m, d)),\n", 103 | " b: thinks[\n", 104 | " c: chooses(m in Month, wpp=1),\n", 105 | " c: chooses(d in Day, wpp=possible(m, d))\n", 106 | " ],\n", 107 | " b: observes [c.d] is c.d\n", 108 | " ]\n", 109 | " a: observes [c.m] is m\n", 110 | " return u == a[Pr[b[Var[c.m] == 0]] > 0]\n", 111 | "# print(a_u2())" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "id": "127b9390-d44f-4e9b-a231-d2483d9e3fec", 117 | "metadata": {}, 118 | "source": [ 119 | "Bob: \"At first I didn't know...\" (similar to `a_u1`)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 4, 125 | "id": "0fe09dd8-b817-426b-82ae-773aa00dfb10", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "@memo\n", 130 | "def b_u3[d: Day, u: U]():\n", 131 | " b: thinks[\n", 132 | " c: chooses(m in Month, wpp=1),\n", 133 | " c: chooses(d in Day, wpp=possible(m, d))\n", 134 | " ]\n", 135 | " b: observes [c.d] is d\n", 136 | " return u == b[Var[c.m] == 0]\n", 137 | "# print(b_u3())" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "id": "52ed3b7a-ff5e-493c-ab69-36d8f6edb886", 143 | "metadata": {}, 144 | "source": [ 145 | "Bob: \"But now...\" (conditions on result of `a_u1` and `a_u2`!)" 146 | ] 147 | }, 148 | { 149 | 
"cell_type": "code", 150 | "execution_count": 5, 151 | "id": "d20e2fc1-d38f-4f0c-aa17-66f5247396a4", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "@memo\n", 156 | "def b_u4[d: Day, u1: U, u2: U, u: U]():\n", 157 | " b: thinks[\n", 158 | " c: chooses(m in Month, wpp=1),\n", 159 | " c: chooses(d in Day, wpp=possible(m, d)),\n", 160 | " a: thinks[\n", 161 | " c: chooses(m in Month, wpp=1),\n", 162 | " c: chooses(d in Day, wpp=possible(m, d)),\n", 163 | " ],\n", 164 | " a: observes [c.m] is c.m,\n", 165 | " a: chooses(u1 in U, wpp=a_u1[c.m, u1]()),\n", 166 | " a: chooses(u2 in U, wpp=a_u2[c.m, u2]())\n", 167 | " ]\n", 168 | " b: observes [c.d] is d\n", 169 | " b: observes [a.u1] is u1\n", 170 | " b: observes [a.u2] is u2\n", 171 | " return u == b[Var[c.m] == 0]\n", 172 | "# print(b_u4()[:, U.DUNNO, U.DUNNO])" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "id": "6a5e7d5b-4cce-451d-b0cd-39bc51f5622a", 178 | "metadata": {}, 179 | "source": [ 180 | "Alice: \"Now I know.\" (conditions on `b_u3` and `b_u4`)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 6, 186 | "id": "3a34bf80-3219-46b5-8d36-2c4f253999ed", 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "@memo\n", 191 | "def a_u5[m: Month, u1: U, u2: U, u3: U, u4: U, u: U]():\n", 192 | " a: knows(u1, u2)\n", 193 | " a: thinks[\n", 194 | " c: chooses(m in Month, wpp=1),\n", 195 | " c: chooses(d in Day, wpp=possible(m, d)),\n", 196 | " b: thinks[\n", 197 | " c: chooses(m in Month, wpp=1),\n", 198 | " c: chooses(d in Day, wpp=possible(m, d))\n", 199 | " ],\n", 200 | " b: knows(u1, u2),\n", 201 | " b: observes [c.d] is c.d,\n", 202 | " b: chooses(u3 in U, wpp=b_u3[c.d, u3]()),\n", 203 | " b: chooses(u4 in U, wpp=b_u4[c.d, u1, u2, u4]()),\n", 204 | " ]\n", 205 | " a: observes [c.m] is m\n", 206 | " a: observes [b.u3] is u3\n", 207 | " a: observes [b.u4] is u4\n", 208 | " return u == a[Var[c.d] == 0]\n", 209 | "# a_u5()[:, U.DUNNO, U.DUNNO, 
U.DUNNO, U.KNOWN]" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "id": "31123856-0109-4f3d-93ea-099043fcb15b", 215 | "metadata": {}, 216 | "source": [ 217 | "Putting everything together, we condition on all 5 utterances." 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 7, 223 | "id": "ccc80742-2333-49fa-bcfd-8a99574adfa2", 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "403 μs ± 5.1 μs per loop (mean ± std. dev. of 10 runs, 100 loops each)\n" 231 | ] 232 | } 233 | ], 234 | "source": [ 235 | "@memo\n", 236 | "def puzzle[m: Month, d: Day, u1: U, u2: U, u3: U, u4: U, u5: U]():\n", 237 | " z: thinks[\n", 238 | " c: chooses(m in Month, wpp=1),\n", 239 | " c: chooses(d in Day, wpp=possible(m, d)),\n", 240 | " c: chooses(u1 in U, wpp=a_u1[m, u1]()),\n", 241 | " c: chooses(u2 in U, wpp=a_u2[m, u2]()),\n", 242 | " c: chooses(u3 in U, wpp=b_u3[d, u3]()),\n", 243 | " c: chooses(u4 in U, wpp=b_u4[d, u1, u2, u4]()),\n", 244 | " c: chooses(u5 in U, wpp=a_u5[m, u1, u2, u3, u4, u5]()),\n", 245 | " ]\n", 246 | " z: observes [c.u1] is u1\n", 247 | " z: observes [c.u2] is u2\n", 248 | " z: observes [c.u3] is u3\n", 249 | " z: observes [c.u4] is u4\n", 250 | " z: observes [c.u5] is u5\n", 251 | " z: knows(m, d)\n", 252 | " return z[E[c.m == m and c.d == d]]\n", 253 | "\n", 254 | "answer = puzzle()[:, :, U.DUNNO, U.DUNNO, U.DUNNO, U.KNOWN, U.KNOWN]\n", 255 | "\n", 256 | "%timeit -r 10 -n 100 out = puzzle().block_until_ready()" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "id": "d79c2cc8-e317-4806-a4fb-e3d670d4322f", 262 | "metadata": {}, 263 | "source": [ 264 | "Finally, we extract the answer by finding the nonzero entry in the inferred $(m, d)$." 
265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 8, 270 | "id": "757e3c66-2453-4e58-9d9f-043a4fa6df64", 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "April 30\n" 278 | ] 279 | } 280 | ], 281 | "source": [ 282 | "for m in Month:\n", 283 | " for di, d in enumerate(Day):\n", 284 | " if answer[m, di]:\n", 285 | " print(m.name, d)" 286 | ] 287 | } 288 | ], 289 | "metadata": { 290 | "language_info": { 291 | "name": "python" 292 | } 293 | }, 294 | "nbformat": 4, 295 | "nbformat_minor": 5 296 | } 297 | -------------------------------------------------------------------------------- /demo/demo-dining-cryptographers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b9ed40f2-90c7-4bee-815c-1ddda2b31410", 6 | "metadata": {}, 7 | "source": [ 8 | "# Dining cryptographers\n", 9 | "\n", 10 | "**Inspired by:** Chaum, D. (1988). _The dining cryptographers problem: Unconditional sender and recipient untraceability._ Journal of cryptology, 1, 65-75.\n", 11 | "\n", 12 | "Three cryptographers are out for dinner. After dessert the waiter informs them that their meal was paid for anonymously. It was either paid for by one of the three cryptographers, or by the NSA. Yikes! To find out whether the bill was paid for by the NSA — without revealing which cryptographer paid, in case it wasn't the NSA — they carry out a protocol involving some coin tossing.\n", 13 | "\n", 14 | "Each cryptographer tosses a coin and shows its outcome to their neighbor-to-the-left (hiding the coin behind their menus so the third cryptographer cannot see). Each cryptographer then announces the XOR of (1) their coin, (2) their neighbor's coin, and (3) whether or not they paid. 
Now, the XOR of all cryptographers' announcements reveals whether the NSA paid, without revealing which cryptographer (if any) paid." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "3637690b-5a2b-4e66-b3c1-20e12aaf8e76", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "from memo import memo\n", 25 | "import jax\n", 26 | "import jax.numpy as np\n", 27 | "from enum import IntEnum" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "id": "0dd33c57-e226-41df-b612-7a443f6460f9", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "class Bit(IntEnum):\n", 38 | " NOT_PAID = 0\n", 39 | " PAID = 1\n", 40 | "\n", 41 | "class Who(IntEnum):\n", 42 | " A_PAID = 0\n", 43 | " B_PAID = 1\n", 44 | " C_PAID = 2\n", 45 | " NSA_PAID = 3" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "df038c58-09b8-4234-8ce5-7ba3ae114a5d", 51 | "metadata": {}, 52 | "source": [ 53 | "We will model this from the perspective of cryptographer A (who didn't pay). We will show by computation that no matter how the coins come up and no matter what B and C announce, it is impossible to distinguish between B paying and C paying." 
54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "id": "7bdd2769-7009-4da1-a7b6-02eab187e8b4", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "@memo\n", 64 | "def model[a_: Bit, b_: Bit, bx: Bit, cx: Bit, w: Who]():\n", 65 | " a: knows(a_, w)\n", 66 | " a: thinks[\n", 67 | " world: chooses(w in Who, wpp=(w != 0)),\n", 68 | " b: chooses(b_ in Bit, wpp=1),\n", 69 | " c: chooses(c_ in Bit, wpp=1),\n", 70 | "\n", 71 | " b: knows(world.w, c.c_),\n", 72 | " c: knows(world.w, a_),\n", 73 | "\n", 74 | " b: chooses(bx in Bit, wpp=(b_ ^ c.c_ ^ (world.w == 1) == bx)),\n", 75 | " c: chooses(cx in Bit, wpp=(c_ ^ a_ ^ (world.w == 2) == cx)),\n", 76 | " ]\n", 77 | " a: observes [b.b_] is b_\n", 78 | " a: observes [b.bx] is bx\n", 79 | " a: observes [c.cx] is cx\n", 80 | " return a[Pr[world.w == w]]" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "21df0926-ffec-4b40-a3a3-edfc27487241", 86 | "metadata": {}, 87 | "source": [ 88 | "Let's see what happens." 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 4, 94 | "id": "f6b28232-8dda-4f45-aa90-907da9be61ea", 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "23.5 μs ± 1.74 μs per loop (mean ± std. dev. of 10 runs, 100 loops each)\n", 102 | "A flips 0, B flips 0, B says 0, C says 0 -> [0. 0. 0. 1.]\n", 103 | "A flips 0, B flips 0, B says 0, C says 1 -> [0. 0.5 0.5 0. ]\n", 104 | "A flips 0, B flips 0, B says 1, C says 0 -> [0. 0.5 0.5 0. ]\n", 105 | "A flips 0, B flips 0, B says 1, C says 1 -> [0. 0. 0. 1.]\n", 106 | "A flips 0, B flips 1, B says 0, C says 0 -> [0. 0.5 0.5 0. ]\n", 107 | "A flips 0, B flips 1, B says 0, C says 1 -> [0. 0. 0. 1.]\n", 108 | "A flips 0, B flips 1, B says 1, C says 0 -> [0. 0. 0. 1.]\n", 109 | "A flips 0, B flips 1, B says 1, C says 1 -> [0. 0.5 0.5 0. ]\n", 110 | "A flips 1, B flips 0, B says 0, C says 0 -> [0. 0.5 0.5 0. 
]\n", 111 | "A flips 1, B flips 0, B says 0, C says 1 -> [0. 0. 0. 1.]\n", 112 | "A flips 1, B flips 0, B says 1, C says 0 -> [0. 0. 0. 1.]\n", 113 | "A flips 1, B flips 0, B says 1, C says 1 -> [0. 0.5 0.5 0. ]\n", 114 | "A flips 1, B flips 1, B says 0, C says 0 -> [0. 0. 0. 1.]\n", 115 | "A flips 1, B flips 1, B says 0, C says 1 -> [0. 0.5 0.5 0. ]\n", 116 | "A flips 1, B flips 1, B says 1, C says 0 -> [0. 0.5 0.5 0. ]\n", 117 | "A flips 1, B flips 1, B says 1, C says 1 -> [0. 0. 0. 1.]\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "out = model()\n", 123 | "%timeit -r 10 -n 100 out = model().block_until_ready()\n", 124 | "\n", 125 | "import itertools\n", 126 | "for a_, b_, bx, cx in itertools.product(Bit, Bit, Bit, Bit):\n", 127 | " print(f\"A flips {a_}, B flips {b_}, B says {bx}, C says {cx} -> {out[a_, b_, bx, cx]}\")" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "id": "6fa3f797-442c-4220-b657-434d57e36bd0", 133 | "metadata": {}, 134 | "source": [ 135 | "No matter what happens, A's posterior belief is always `[0 0 0 1]` (NSA paid) or `[0 .5 .5 0]` (B or C paid, but unsure which)." 136 | ] 137 | } 138 | ], 139 | "metadata": { 140 | "language_info": { 141 | "name": "python" 142 | } 143 | }, 144 | "nbformat": 4, 145 | "nbformat_minor": 5 146 | } 147 | -------------------------------------------------------------------------------- /demo/demo-eig.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c8a2f179-3175-442f-96f3-667948ad4113", 6 | "metadata": {}, 7 | "source": [ 8 | "# Question-asking based on Expected Information Gain (EIG)\n", 9 | "\n", 10 | "**Inspired by:** Rothe, A., Lake, B. M., & Gureckis, T. M. (2018). _Do people ask good questions?._ Computational Brain & Behavior, 1, 69-89.\n", 11 | "\n", 12 | "Bob rolls a red die and a blue die. Alice gets to ask one yes-no question about the sum. 
What is the most informative question she could ask, in order to learn the most about the two die rolls? For example, is it better to ask if the sum is a perfect square, or if the sum is prime?\n", 13 | "\n", 14 | "We'll compute the EIG of various questions..." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "4ee47cfd-0371-489f-b1c2-a277007eea27", 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "48.4 μs ± 2.47 μs per loop (mean ± std. dev. of 10 runs, 100 loops each)\n", 28 | "EIG Question\n", 29 | "--- ---\n", 30 | "0.69315 n % 2 == 1, # odd?\n", 31 | "0.69315 n % 2 == 0, # even?\n", 32 | "0.67919 n > 6\n", 33 | "0.67919 is_prime[n]\n", 34 | "0.63651 n % 3 == 0\n", 35 | "0.59084 n > 5\n", 36 | "0.59084 n > 8\n", 37 | "0.56233 n % 4 == 0\n", 38 | "0.56233 is_pow_2[n]\n", 39 | "0.49260 n % 5 == 0\n", 40 | "0.49260 is_square[n]\n", 41 | "0.45056 n == 7\n", 42 | "0.28684 n > 10\n", 43 | "0.12693 n == 12\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "from memo import memo\n", 49 | "import jax\n", 50 | "import jax.numpy as np\n", 51 | "\n", 52 | "is_prime = np.array([0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0])\n", 53 | "is_square = np.array([1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0])\n", 54 | "is_pow_2 = np.array([0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0])\n", 55 | "Qs = [\n", 56 | " lambda n: n == 7,\n", 57 | " lambda n: n == 12,\n", 58 | " lambda n: n > 10,\n", 59 | " lambda n: n > 8,\n", 60 | " lambda n: n > 6,\n", 61 | " lambda n: n > 5,\n", 62 | " lambda n: n % 2 == 0, # even??\n", 63 | " lambda n: n % 2 == 1, # odd??\n", 64 | " lambda n: n % 3 == 0,\n", 65 | " lambda n: n % 4 == 0,\n", 66 | " lambda n: n % 5 == 0,\n", 67 | " lambda n: is_prime[n],\n", 68 | " lambda n: is_square[n],\n", 69 | " lambda n: is_pow_2[n],\n", 70 | "]\n", 71 | "\n", 72 | "N = np.arange(1, 6 + 1) # single die's outcomes\n", 73 | "Q = np.arange(len(Qs)) # questions\n", 74 | "A = np.array([0, 
1]) # answers (yes/no)\n", 75 | "\n", 76 | "@jax.jit\n", 77 | "def respond(q, a, n):\n", 78 | " return np.array([q_(n) for q_ in Qs])[q] == a\n", 79 | "\n", 80 | "@memo\n", 81 | "def eig[q: Q]():\n", 82 | " alice: knows(q)\n", 83 | " alice: thinks[\n", 84 | " # bob rolls dice...\n", 85 | " bob: chooses(n_red in N, wpp=1),\n", 86 | " bob: chooses(n_blu in N, wpp=1),\n", 87 | "\n", 88 | " # bob answers question...\n", 89 | " bob: knows(q),\n", 90 | " bob: chooses(a in A, wpp=respond(q, a, n_red + n_blu))\n", 91 | " ]\n", 92 | " alice: snapshots_self_as(future_self)\n", 93 | "\n", 94 | " return alice[ imagine[\n", 95 | " # if I were to get the answer...\n", 96 | " future_self: observes [bob.a] is bob.a,\n", 97 | " # EIG = entropy minus conditional entropy\n", 98 | " H[bob.n_red, bob.n_blu] - E[future_self[ H[bob.n_red, bob.n_blu] ]]\n", 99 | " ] ]\n", 100 | "\n", 101 | "z = eig()\n", 102 | "%timeit -r 10 -n 100 eig().block_until_ready()\n", 103 | "\n", 104 | "## print questions and EIGs in sorted order\n", 105 | "print('EIG Question')\n", 106 | "print('--- ---')\n", 107 | "import inspect\n", 108 | "q_names = [inspect.getsource(q_).strip()[10:-1] for q_ in Qs]\n", 109 | "for eig_, q_ in reversed(sorted(list(zip(z, q_names)))):\n", 110 | " print(f'{eig_:0.5f}', q_)" 111 | ] 112 | } 113 | ], 114 | "metadata": { 115 | "language_info": { 116 | "name": "python" 117 | } 118 | }, 119 | "nbformat": 4, 120 | "nbformat_minor": 5 121 | } 122 | -------------------------------------------------------------------------------- /demo/demo-empowerment.py: -------------------------------------------------------------------------------- 1 | from memo import memo, domain, make_module 2 | import jax 3 | import jax.numpy as np 4 | from enum import IntEnum 5 | 6 | from matplotlib import pyplot as plt 7 | 8 | """ 9 | **Inspired by:** Klyubin, A. S., Polani, D., & Nehaniv, C. L. (2005, September). 10 | All else being equal be empowered. In European Conference on Artificial Life 11 | (pp. 
744-753). Berlin, Heidelberg: Springer Berlin Heidelberg. 12 | 13 | This example shows how to use memo to compute an agent's empowerment in a gridworld. 14 | The particular example is inspired by Figure 3a in Klyubin et al (2005). 15 | """ 16 | 17 | ## This is a little memo module that implements the Blahut-Arimoto algorithm for empowerment 18 | # See: https://www.comm.utoronto.ca/~weiyu/ab_isit04.pdf 19 | def make_blahut_arimoto(X, Y, Z, p_Y_given_X): 20 | m = make_module('blahut_arimoto') 21 | m.X = X 22 | m.Y = Y 23 | m.Z = Z 24 | m.p_Y_given_X = p_Y_given_X 25 | 26 | @memo(install_module=m.install) 27 | def q[x: X, z: Z](t): 28 | alice: knows(z) 29 | alice: chooses(x in X, wpp=imagine[ 30 | bob: knows(x, z), 31 | bob: chooses(y in Y, wpp=p_Y_given_X(y, x, z)), 32 | # exp(E[ log(Q[x, bob.y, z](t - 1) if t > 0 else 1) ]) 33 | bob: thinks[ 34 | charlie: knows(y, z), 35 | charlie: chooses(x in X, wpp=Q[x, y, z](t - 1) if t > 0 else 1) 36 | ], 37 | exp(E[bob[H[charlie.x]]]) 38 | ]) 39 | return Pr[alice.x == x] 40 | 41 | @memo(install_module=m.install) 42 | def Q[x: X, y: Y, z: Z](t): 43 | alice: knows(x, y, z) 44 | alice: thinks[ 45 | bob: knows(z), 46 | bob: chooses(x in X, wpp=q[x, z](t)), 47 | bob: chooses(y in Y, wpp=p_Y_given_X(y, x, z)) 48 | ] 49 | alice: observes [bob.y] is y 50 | return alice[Pr[bob.x == x]] 51 | 52 | @memo(install_module=m.install) 53 | def C[z: Z](t): 54 | alice: knows(z) 55 | alice: chooses(x in X, wpp=q[x, z](t)) 56 | alice: chooses(y in Y, wpp=p_Y_given_X(y, x, z)) 57 | return (H[alice.x] + H[alice.y] - H[alice.x, alice.y]) / log(2) # convert to bits 58 | 59 | return m 60 | 61 | # # Sanity check: a channel that drops messages with probability 0.1 should have capacity 0.9 bits. 
62 | # X = [0, 1] 63 | # Y = [0, 1, 2] 64 | # @jax.jit 65 | # def p_Y_given_X(y, x, z): 66 | # return np.array([ 67 | # [0.9, 0.1, 1e-10], 68 | # [1e-10, 0.1, 0.9] 69 | # ])[x, y] 70 | # m = make_blahut_arimoto(X, Y, np.array([0]), p_Y_given_X) 71 | # print(m.q(10)) 72 | # print(m.C(10)) 73 | 74 | 75 | ## Setting up a gridworld... 76 | N = 13 77 | world = np.zeros((N, N)) 78 | world = world.at[N // 2, N // 2].set(1) 79 | 80 | X = np.arange(world.shape[0]) 81 | Y = np.arange(world.shape[1]) 82 | S = domain(x=len(X), y=len(Y)) 83 | 84 | class A(IntEnum): 85 | N = 0 86 | S = 1 87 | W = 2 88 | E = 3 89 | O = 4 90 | Ax = domain( 91 | a1=len(A), 92 | a2=len(A), 93 | a3=len(A), 94 | a4=len(A), 95 | a5=len(A), 96 | ) 97 | 98 | @jax.jit 99 | def Tr1(s, a): 100 | x = S.x(s) 101 | y = S.y(s) 102 | z = np.array([ 103 | [x, y - 1], 104 | [x, y + 1], 105 | [x - 1, y], 106 | [x + 1, y], 107 | [x, y] 108 | ])[a] 109 | x_ = np.clip(z[0], 0, len(X) - 1) 110 | y_ = np.clip(z[1], 0, len(Y) - 1) 111 | return np.where(world[x_, y_], s, S(x_, y_)) 112 | 113 | 114 | @jax.jit 115 | def Tr(s_, ax, s): 116 | for a in Ax._tuple(ax): 117 | s = Tr1(s, a) 118 | return s == s_ 119 | 120 | # ...and computing 5-step empowerment in the gridworld! 
121 | m = make_blahut_arimoto(X=Ax, Y=S, Z=S, p_Y_given_X=Tr) 122 | m.Z = S 123 | @memo(install_module=m.install, debug_trace=True) 124 | def empowerment[s: Z](t): 125 | return C[s](t) 126 | 127 | emp = m.empowerment(5).block_until_ready() 128 | emp = emp.reshape(len(X), len(Y)) 129 | emp = emp * (1 - world) 130 | plt.colorbar(plt.imshow(emp.reshape(len(X), len(Y)) * (1 - world), cmap='gray')) 131 | plt.savefig('out.png') 132 | -------------------------------------------------------------------------------- /demo/demo-fib.py: -------------------------------------------------------------------------------- 1 | from memo import memo 2 | import jax 3 | import jax.numpy as np 4 | import functools 5 | 6 | ''' 7 | This file is useful for exploring/tinkering with the memo compiler, especially 8 | understanding the subtleties of statically-known parameters and recursion. 9 | ''' 10 | 11 | Unit = [0] 12 | 13 | @functools.cache 14 | @memo(debug_print_compiled=True, debug_trace=True) 15 | def fib[a: Unit](n): 16 | return 1 if n < 2 else fib[a](n - 1) + fib[a](n - 2) 17 | 18 | print([fib(n) for n in range(10 + 1)]) # this works 19 | # print(jax.vmap(fib)(np.arange(10 + 1))) # this rightfully doesn't work 20 | -------------------------------------------------------------------------------- /demo/demo-i-pomdp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ad16d500-51fd-4b94-ad33-d6f704ca944f", 6 | "metadata": {}, 7 | "source": [ 8 | "# The Investment Game\n", 9 | "\n", 10 | "**Inspired by:** Berg, J., Dickhaut, J., & McCabe, K. (1995). _Trust, reciprocity, and social history._ Games and economic behavior, 10(1), 122-142.\n", 11 | "\n", 12 | "and\n", 13 | "\n", 14 | "Gmytrasiewicz, P. J., & Doshi, P. (2005). 
_A framework for sequential planning in multi-agent settings._ Journal of Artificial Intelligence Research, 24, 49-79.\n", 15 | "\n", 16 | "The investor is endowed with \\$1 and can choose to send a fraction `fi` to the trustee (and keep the rest). The investment is multiplied by the factor `mult` before reaching the trustee. The trustee can choose to send back to the investor some fraction `ft` of the multiplied investment (and keep the rest). What should they do?\n", 17 | "\n", 18 | "How does their behavior change if they each have a hidden \"guilt\" parameter that modulates a cost they incur for inequitable outcomes?\n", 19 | "\n", 20 | "We can model this as an IPOMDP, a multi-agent extension of POMDPs where agents model their uncertainty about each other…" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "id": "64db0cb2-8882-4a53-bf76-51128a87ab7b", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "from memo import memo, domain\n", 31 | "import jax\n", 32 | "import jax.numpy as np" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "5e5dfa78-8a64-48a9-8f5a-6f978c2292d9", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "mult = 3.0\n", 43 | "F = np.arange(2) # actions are fractions :)\n", 44 | "Fractions = np.array([0.2, 0.8])\n", 45 | "\n", 46 | "G = np.array([0.0, 0.7]) # guilts\n", 47 | "\n", 48 | "@jax.jit\n", 49 | "def payout_investor(fi, ft): # fi = fraction chosen by investor, ft = fraction chosen by trustee\n", 50 | " return (1 - Fractions[fi]) + mult * Fractions[fi] * Fractions[ft]\n", 51 | "\n", 52 | "@jax.jit\n", 53 | "def payout_trustee(fi, ft):\n", 54 | " return mult * Fractions[fi] * (1 - Fractions[ft])\n", 55 | "\n", 56 | "H = domain( # histories - for simplicity, just 2 rounds of history (i.e. 
3 rounds of game)\n", 57 | " i1=len(F), t1=len(F), # 1st round, investor + trustee (i, t)\n", 58 | " i2=len(F), t2=len(F) # 2nd round, investor + trustee (i, t)\n", 59 | ")" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "id": "a317ebb1-d9f2-4c3c-a4c5-4cbe4d5f8653", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "@jax.jit\n", 70 | "def reludiff(x, y):\n", 71 | " return np.maximum(x - y, 0)\n", 72 | "\n", 73 | "@jax.jit\n", 74 | "def is_init(h):\n", 75 | " return h == 0\n", 76 | "\n", 77 | "@jax.jit\n", 78 | "def Tr(r, h, fi, ft, h_):\n", 79 | " # step game: update h to h_ with moves (fi, ft) at round r\n", 80 | " z = H._tuple(h)\n", 81 | " z = np.array(z)\n", 82 | " z = z.at[r * 2].set(fi)\n", 83 | " z = z.at[r * 2 + 1].set(ft)\n", 84 | " return h_ == H(*z)\n", 85 | "\n", 86 | "@memo(cache=True)\n", 87 | "def hist[ig: G, tg: G, h: H](r, level, beta):\n", 88 | " # prior over histories h at start of round r >= 1 conditioned on guilts ig, tg\n", 89 | " world: knows(ig, tg)\n", 90 | " # start with h at round r - 1\n", 91 | " world: chooses(h in H, wpp=hist[ig, tg, h](r - 1, level, beta) if r > 1 else is_init(h))\n", 92 | " # investor and trustee make moves\n", 93 | " world: chooses(fi in F, wpp=exp(beta * investor[h, ig, fi](r - 1, level, beta)))\n", 94 | " world: chooses(ft in F, wpp=exp(beta * trustee[h, tg, fi, ft](r - 1, level, beta)))\n", 95 | " # h gets updated\n", 96 | " world: chooses(h_ in H, wpp=Tr(r - 1, h, fi, ft, h_))\n", 97 | " return Pr[world.h_ == h]\n", 98 | "\n", 99 | "\n", 100 | "@memo(cache=True)\n", 101 | "def investor[h: H, ig: G, fi: F](r, level, beta):\n", 102 | " # Q-function for investor conditioned on h, own guilt\n", 103 | " investor: knows(ig)\n", 104 | " investor: thinks[\n", 105 | " trustee: knows(ig),\n", 106 | " trustee: chooses(tg in G, wpp=1),\n", 107 | " trustee: chooses(h in H, wpp=(hist[ig, tg, h](r, level - 1, beta) if r > 0 and level > 0 else is_init(h)) + 1e-3)\n", 108 | " ]\n", 
109 | " investor: observes [trustee.h] is h # now investor has posterior belief over trustee.tg\n", 110 | " investor: knows(fi)\n", 111 | " return investor[\n", 112 | " imagine[ # Q-value having chosen fi\n", 113 | " trustee: knows(fi),\n", 114 | " trustee: chooses(ft in F, wpp=exp(beta * trustee[h, tg, fi, ft](r, level - 1, beta)) if level > 0 else 1),\n", 115 | " trustee: chooses(h_ in H, wpp=Tr(r, h, fi, ft, h_)),\n", 116 | " trustee: chooses(fi_ in F, wpp=exp(beta * investor[h_, ig, fi_](r + 1, level, beta)) if r < 2 and level > 0 else 1),\n", 117 | " E[\n", 118 | " payout_investor(fi, trustee.ft)\n", 119 | " - ig * reludiff(\n", 120 | " payout_investor(fi, trustee.ft),\n", 121 | " payout_trustee(fi, trustee.ft)\n", 122 | " )\n", 123 | " + (investor[trustee.h_, ig, trustee.fi_](r + 1, level, beta) if r < 2 and level > 0 else 0)\n", 124 | " ]\n", 125 | " ]\n", 126 | " ]\n", 127 | "\n", 128 | "@memo(cache=True)\n", 129 | "def trustee[h: H, tg: G, fi: F, ft: F](r, level, beta):\n", 130 | " # Q-function for trustee conditioned on h, own guilt, investor's most recent move\n", 131 | " trustee: knows(tg)\n", 132 | " trustee: thinks[\n", 133 | " investor: knows(tg),\n", 134 | " investor: chooses(ig in G, wpp=1),\n", 135 | " investor: chooses(h in H, wpp=(hist[ig, tg, h](r, level - 1, beta) if r > 0 and level > 0 else is_init(h)) + 1e-3),\n", 136 | " investor: chooses(fi in F, wpp=exp(beta * investor[h, ig, fi](r, level - 1, beta)) if level > 0 else 1)\n", 137 | " ]\n", 138 | " trustee: observes [investor.h] is h\n", 139 | " trustee: observes [investor.fi] is fi\n", 140 | " trustee: knows(ft)\n", 141 | " return trustee[\n", 142 | " imagine[ # Q-value having chosen ft\n", 143 | " investor: knows(ft),\n", 144 | " investor: chooses(h_ in H, wpp=Tr(r, h, fi, ft, h_)),\n", 145 | " investor: chooses(fi_ in F, wpp=exp(beta * investor[h_, ig, fi_](r + 1, level, beta)) if r < 2 and level > 0 else 1),\n", 146 | " investor: chooses(ft_ in F, wpp=exp(beta * trustee[h_, tg, fi_, 
ft_](r + 1, level, beta)) if r < 2 and level > 0 else 1),\n", 147 | " E[\n", 148 | " payout_trustee(investor.fi, ft)\n", 149 | " - tg * reludiff(\n", 150 | " payout_trustee(investor.fi, ft),\n", 151 | " payout_investor(investor.fi, ft)\n", 152 | " )\n", 153 | " + (trustee[investor.h_, tg, investor.fi_, investor.ft_](r + 1, level, beta) if r < 2 and level > 0 else 0)\n", 154 | " ]\n", 155 | " ]\n", 156 | " ]" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 4, 162 | "id": "c851dafa-293d-41d8-9591-433b67e67e03", 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "name": "stdout", 167 | "output_type": "stream", 168 | "text": [ 169 | "Investor round 1:\n", 170 | " For guilt 0.00 offer:\n", 171 | " 0.20 with probability 0.78\n", 172 | " 0.80 with probability 0.22\n", 173 | " For guilt 0.70 offer:\n", 174 | " 0.20 with probability 0.38\n", 175 | " 0.80 with probability 0.62\n", 176 | "CPU times: user 1.24 s, sys: 42.3 ms, total: 1.28 s\n", 177 | "Wall time: 757 ms\n" 178 | ] 179 | } 180 | ], 181 | "source": [ 182 | "%%time\n", 183 | "\n", 184 | "beta = 5.0\n", 185 | "level = 5\n", 186 | "\n", 187 | "inv = investor(r=0, level=level, beta=beta)\n", 188 | "inv = np.exp(beta * inv) / np.exp(beta * inv).sum(axis=-1, keepdims=True)\n", 189 | "\n", 190 | "h = 0\n", 191 | "print('Investor round 1:')\n", 192 | "for gi, g in enumerate(G):\n", 193 | " print(f' For guilt {g:.02f} offer:')\n", 194 | " for f in F:\n", 195 | " print(f' {Fractions[f]:.02f} with probability {inv[h, gi, f]:.02f}')" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 5, 201 | "id": "a379fa27-7f36-4fcc-a7bc-34c98f3095fd", 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "Trustee round 1, after receiving low offer:\n", 209 | " For guilt 0.00 offer:\n", 210 | " 0.20 with probability 0.86\n", 211 | " 0.80 with probability 0.14\n", 212 | " For guilt 0.70 offer:\n", 213 | " 0.20 
with probability 0.86\n", 214 | " 0.80 with probability 0.14\n", 215 | "Trustee round 1, after receiving high offer:\n", 216 | " For guilt 0.00 offer:\n", 217 | " 0.20 with probability 1.00\n", 218 | " 0.80 with probability 0.00\n", 219 | " For guilt 0.70 offer:\n", 220 | " 0.20 with probability 0.94\n", 221 | " 0.80 with probability 0.06\n", 222 | "CPU times: user 105 ms, sys: 2.3 ms, total: 107 ms\n", 223 | "Wall time: 106 ms\n" 224 | ] 225 | } 226 | ], 227 | "source": [ 228 | "%%time\n", 229 | "tr = trustee(r=0, level=level, beta=beta)\n", 230 | "tr = np.exp(beta * tr) / np.exp(beta * tr).sum(axis=-1, keepdims=True)\n", 231 | "\n", 232 | "h = 0\n", 233 | "print('Trustee round 1, after receiving low offer:')\n", 234 | "for gi, g in enumerate(G):\n", 235 | " print(f' For guilt {g:.02f} offer:')\n", 236 | " for f in F:\n", 237 | " print(f' {Fractions[f]:.02f} with probability {tr[h, gi, 0, f]:.02f}')\n", 238 | "\n", 239 | "print('Trustee round 1, after receiving high offer:')\n", 240 | "for gi, g in enumerate(G):\n", 241 | " print(f' For guilt {g:.02f} offer:')\n", 242 | " for f in F:\n", 243 | " print(f' {Fractions[f]:.02f} with probability {tr[h, gi, 1, f]:.02f}')" 244 | ] 245 | } 246 | ], 247 | "metadata": { 248 | "language_info": { 249 | "name": "python" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 5 254 | } 255 | -------------------------------------------------------------------------------- /demo/demo-mdp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e98b3371-c71f-4edd-9bbf-9461dd258e93", 6 | "metadata": {}, 7 | "source": [ 8 | "# MDP planning and inverse planning\n", 9 | "\n", 10 | "In this notebook we will set up a simple grid-world, plan routes to goals, and infer goals given actions." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "dbc21fd7-4da7-4135-a10c-0ae18c50f0ec", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from functools import cache\n", 21 | "import jax\n", 22 | "import jax.numpy as np\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "from memo import memo\n", 25 | "\n", 26 | "H = 21\n", 27 | "W = 21\n", 28 | "S = np.arange(H * W) # state space\n", 29 | "G = np.array([0, H * W - 1]) # possible goals: NW and SE corners\n", 30 | "\n", 31 | "A = np.array([0, 1, 2, 3]) # action space: left, right, up, down\n", 32 | "coord_actions = np.array([[-1, 0], [+1, 0], [0, -1], [0, +1]])\n", 33 | "\n", 34 | "maze_raw = np.array(1 - plt.imread('../paper/fig/logo-maze.png'), dtype=int);\n", 35 | "maze = maze_raw.reshape(-1)\n", 36 | "assert maze_raw.size == H * W\n", 37 | "\n", 38 | "# # Alternatively...\n", 39 | "# maze = np.zeros(H * W) # blank maze\n", 40 | "\n", 41 | "# transition function: P(s_ | s, a)\n", 42 | "@jax.jit\n", 43 | "def Tr(s, a, s_):\n", 44 | " x, y = s % W, s // W\n", 45 | " next_coords = np.array([x, y]) + coord_actions[a]\n", 46 | " next_state = (\n", 47 | " + 1 * np.clip(next_coords[0], 0, W - 1)\n", 48 | " + W * np.clip(next_coords[1], 0, H - 1)\n", 49 | " )\n", 50 | " return (\n", 51 | " + 1.0 * ((maze[next_state] == 0) & (next_state == s_)) # next state free, can move there\n", 52 | " + 1.0 * ((maze[next_state] == 1) & (s == s_)) # next state occupied, stay where you are\n", 53 | " )\n", 54 | "\n", 55 | "# reward function\n", 56 | "@jax.jit\n", 57 | "def R(s, a, g):\n", 58 | " return 1.0 * (s == g) - 0.1\n", 59 | "\n", 60 | "@jax.jit\n", 61 | "def is_terminating(s, g):\n", 62 | " return s == g\n", 63 | "\n", 64 | "# discount factor\n", 65 | "@jax.jit\n", 66 | "def gamma():\n", 67 | " return 1.0" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "1e81c7dc-612b-4228-8896-6be6122e8782", 73 | "metadata": {}, 74 | "source": [ 75 | "We can plan via Q-value 
iteration and inverse-plan by inferring $P(g \\mid s, a)$ where $P(a \\mid s, g)$ is given by a softmax over Q-value with $\\beta=2$." 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 2, 81 | "id": "734c556e-dcc4-4b39-b773-be1949a3606e", 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "6.88 s ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)\n", 89 | "4.74 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "@cache\n", 95 | "@memo\n", 96 | "def Q[s: S, a: A, g: G](t):\n", 97 | " alice: knows(s, a, g)\n", 98 | " alice: given(s_ in S, wpp=Tr(s, a, s_))\n", 99 | " alice: chooses(a_ in A, to_maximize=0.0 if t < 0 else Q[s_, a_, g](t - 1))\n", 100 | " return E[\n", 101 | " R(s, a, g) + (0.0 if t < 0 else\n", 102 | " 0.0 if is_terminating(s, g) else\n", 103 | " gamma() * Q[alice.s_, alice.a_, g](t - 1))\n", 104 | " ]\n", 105 | "\n", 106 | "@memo\n", 107 | "def invplan[s: S, a: A, g: G](t):\n", 108 | " observer: knows(a, s, g)\n", 109 | " observer: thinks[\n", 110 | " alice: chooses(g in G, wpp=1),\n", 111 | " alice: knows(s),\n", 112 | " alice: chooses(a in A, wpp=exp(2 * Q[s, a, g](t))),\n", 113 | " ]\n", 114 | " observer: observes [alice.a] is a\n", 115 | " return observer[E[alice.g == g]]\n", 116 | "\n", 117 | "Q(0) # pre-compile Q\n", 118 | "%timeit -r 1 -n 10 Q.cache_clear(); Q(1000).block_until_ready()\n", 119 | "%timeit -r 1 -n 10 invplan(1000).block_until_ready()\n", 120 | "ip = invplan(1000)\n", 121 | "v = Q(1000).max(axis=1)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "id": "b94b6f0b-dd44-4787-a3f4-49c0726d607e", 127 | "metadata": {}, 128 | "source": [ 129 | "This is already pretty fast, though it is even faster on a GPU.\n", 130 | "\n", 131 | "Finally, let's make the plots shown in the paper." 
132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 3, 137 | "id": "4dc266a7-e85a-4ee1-8498-099e44d0b641", 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "data": { 142 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQsAAAEhCAYAAABoYoUCAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAID1JREFUeJzt3X2UVOWdJ/Dvrap+k266mwYDLX10BKHboElARN4m6iKBmEB2POoxyjEke7KbiXo0sxPwxLNGIhlAhKMZwTljaDe6zmSOJuAoswY2OxPAIxnQqGSDKDYs0IQs0C/VNN31cp/9o6nqqn5+Rf+q6+1W8f380/L43Hufqnr617d+93lxjDEGRETD8BW6AURUHBgsiEiFwYKIVBgsiEiFwYKIVBgsiEiFwYKIVBgsiEiFwYKIVEo2WKxbtw7Nzc1wXVd9TDgcRlNTEzZt2pTDlpGXJfabI0eOwHEcrF+/ftjjVq5ciVmzZuWhhYVTksGiu7sba9euxYoVK+Dz6V9iWVkZvve972H16tXo6+vLYQvJi0babwDg4Ycfxvvvv4/XX389R60rvJIMFlu2bEEkEsE999yT9rHLly/H6dOn8corr+SgZeRlmfSb8ePHY+nSpaq7kGJVksGitbUVS5YsQWVlZdrH1tXVYeHChXjxxRez3zDytEz6DQDcdddd2L17Nz799NMst8wbSi5YtLW14YMPPsCCBQuSytevX485c+agoaEBVVVVmDFjBl599VXxHLfddht2796Ns2fP5qPJ5AGp+k3Mxo0bceWVV6Kqqgpf/OIXceDAAatO7Nht27bltK2FUnLB4u233wYATJ8+Pan8mWeewRe+8AWsWrUKP/7xjxEIBHDnnXfizTfftM4xY8YMGGPi56LSl6rfAMDPfvYzPPvss/jud7+LRx99FAcOHMCtt96KU6dOJdWrra3FpEmTsGfPnry0Oe9MiXnssccMABMMBpPKe3t7k/4dCoXMtGnTzK233mqdo7293QAwa9euzWlbyTukftPW1mYAmKqqKnP8+PF4+d69ew0A88gjj1jnWbhwoWlpaclLm/Ot5O4szpw5g0AggOrq6qTyqqqq+H93dHSgq6sL8+fPx7vvvmudo76+HgBw+vTp3DaWPCNVvwGAr33ta7jiiivi/77xxhsxa9YsbN++3apbX19fsv2m5IJFKm+88QZuuukmVFZWYsyYMRg3bhw2b96Mrq4uq665sHiY4zj5biZ50DXXXGOVTZkyBUeOHLHKjTEl229KLlg0NDQgEokgGAzGy3bt2hXPcm/atAnbt2/Hjh078PWvfz0eGBJ1dHQAAMaOHZu3dlNhSf1mJDo6Okq23wQK3YBsa25uBjCQ3b7++usBAK+99hoqKyvx1ltvoaKiIl63tbVVPEdbWxsAoKWlJcetJa+Q+k3Mxx9/bNU/dOgQrrrqKqu8ra0Nn/vc53LSxkIruTuL2bNnAwD27dsXL/P7/XAcB9FoNF525MgRbN26VTzH/v374ThO/FxU+qR+E7N161acOHEi/u/f/va32Lt3LxYvXpxUr6urC4cPH8acOXNy29gCKblgcfXVV2PatGnYuXNnvOz2229Hb28vFi1ahOeffx6rVq3CrFmzMHnyZPEcO3bswNy5c9HQ0JCvZlOBSf0mZvLkyZg3bx7WrVuHH/3oR1i8eDEaGhrw/e9/P6nezp07YYzB0qVL89Xs/Crsw5jc2LBhg6murk56XPrTn/7UXHPNNaaiosI0Nzeb1tZW8/jjj5uhb0FnZ6cpLy83L7z
wQr6bTQU2tN/EHp0+9dRT5umnnzZNTU2moqLCzJ8/37z//vvW8XfffbeZN29evpudNyUZLDo7O82YMWNG9Au/ceNGM2HCBGtcBpW+TPrNyZMnTWVlpdm6dWsOWuYNJRksjDFmzZo1ZurUqSYajaqPCYVCpqmpyTz33HM5bBl52Uj6jTHGrFixwsycOTNHrfIGxxjuSEZEwyu5BCcR5QaDBRGpMFgQkYpqBKfrumhvb0dNTU3JjnsvdsYYBINBNDY2pr0kXK6w33hfOv1GFSza29vR1NSUlcZRbh07dgwTJ04sdDMAsN8UE02/UQWLmpoaAMDRl8ej9qE/wTnrwozxIfj39YABTL0P7hUDp+p2o9bxXa59mS7XXrqsK1ol1LPLuqOjhGPt83VH7GODkQqrLFXdnrBdtzdcZpX1CWXnhbJQ2G+VRYQyVygzITvqO+HBMrevDyd+sDr+WXlBrC33Tb0PT37yC9RHe9HhvwxPfPbrAIDusir8v4o6AEC0wn59bplQVm7foUTLhXpCz3bLpGN19aTzDbTHLvOFhLKIUBa2H0T6pWOFetL5/CF7JXtfSDg2PFgvEunH3l1rVP1GFSxit5B19/0RNQ7gADAdLmrvOBOv03FiYL6/ce3GRV37w4wIZeGo/UsSitpNlMr6o/YvZ39EKhM+XQBlYbtcKgsIZX4hMPhDQlnYbrcUGBASgoVfCBYBocxDt/uxtjz30cuowUC/qYn2YssHL8Tr/Ic/Xz1QV3gtYrAQfpEdoZ7QHeAIx0LoDlI9RzgfADhSsBAuI93h+xwhWAjXEOsJ1/AbIVgIIyN8Qj1Nv0l71mms3bGfJgD0bqxP9zR0iXGG/Iw4PqybckehmkMjkPEU9eAb4xC9Tv5rTZTKA5//L/i45orhK5JnpB0sjDNwVxH72eO6CCXkKTqFL3edQn6iM3qZqqxLzE8IuQ0h59AdEfIiYbseIOcnzglfOaRcxHnhK0dI+MoRlvIT0leOYfIT8bL+wVtHJ+Sdrx8SFw58MPGf0XJfUp7ClfIOyvyE9JUjo/yEdL4Ufw9T5TJ0pM9MGlCtrSc9zZB25Et43x39k7O0Xqrb4CA6MYD+e0ah4h/OAe1RRMd64zEdedfZwCi0V9bhX8bfgMV/3IdxoS50lttrXZK3pRUsun79GZgGH+A4CN13Gbr6okCFt/+iUeH9pxsegikbBTgO3pgwE75yg7Cv5BZpK3np3RZUOEAsa+o4DBSkEvYFkvoNA0Vx4ncIIlJJK8R3u9GkcRRSMvMLVx5Xnav1o5usMm0y8/Hr3lBd4863v22VSYlMQE5m/mbB06rr5MNVG+0Nd32JSU0PJzh94Sh8xh6sl+jfdjyqOtfc/7hOuoJQZicA97f+V9U1rl2xwSpLOShLSIYeXP091XXyYf7ta60yX8LgLV9ESoDKeGdBRCoMFkSkwmBBRCoMFkSkkt44CzeQNClMGpmptXzqO1bZEx9+xSrTJjMl2lGZgDwyU3LjvXaiMSqONLSPlUck2vV+/7SdIPMJCUxfOOEfYet/e4Zb5ocbGBypKs0w1drzy+9bZVLS852f2/W05BmrKeqW65aw/VLtN+3CCrt/OpV2makQLl5ul731+x9bZT5hJqq/f7DMMMFJRNnGYEFEKgwWRKTCYEFEKmkmOCuTVriSppRLJj9pj4T85LG/ssq0ycyWR+0Rdn/4GzspqJ1iDsjTzCXZTmZKU5+n/jf79UXFJdgSrhH27gjOaIUvaSUsaTq6ZMZyu9/sb7X7jZT0lFz/gP2+fvC3dr+Rp6jLiUz1FPUsJzNdod6C2avsev3C6lnhwdG0vsjFR9YmHaeuSUSXNAYLIlJhsCAiFQYLIlJJL8EZrUparl+b4JQSRtKU6yOP2FOIpeSoK+ytINGulwnIa2ZKpASb1vUP2gk2vYuP4DSeHsHpS1rWX0r+SqSEsDZJKU4zV/YbKZmZeoq6bgTnW3/arLu44EvX/LVVlv5qm4Oc/sGkphP
lCE4iyjIGCyJSYbAgIhUGCyJSSXMEZ1XSPqPSmpkSKWFkKuwyKekpJ6V0SSXt5j+AvAGQZPadT1ll8qhOoUzYpFbeQEYncQNeo0zeFYJb7iRt5iNtFCQep9wAKJ01M1XXFaeoy33OKKeo31Z5r1XmE0Z1iiM9++3sdUZbDCWczxfVZ8Z5Z0FEKgwWRKTCYEFEKgwWRKSS3iZD0VFDEpzyjuRDSclMU6YbbybV+nSFbrMY7U7mgLybucTfp0toyUa+a7a0LmexiJb74CSM4JR2PZdodzPXrpmp3fxHSmamSmSagK4/mP5++zpCPd12SfqetENYl3OkeGdBRCoMFkSkwmBBRCoMFkSkkuYU9Ur0J2SnuiLKBKeQzHTK7TI5VTTyeCYlM1MlMk1Id53d/zzyzWsuVW4gOakpjcwUj8sgmand/EciJTNTJjLFRL1tp3nVLuwTKkplHsE7CyJSYbAgIhUGCyJSYbAgIhXHGDNsJqi7uxu1tbX47q6voaJ6MJvUHbF3UX9h5kuqC//ZPzxplUnJRynx6ITtMnUSNUUiUzqn0y+sezncbuYX+JX1fMLUcr9Q5hOmtyfWi4b68Lv/8QN0dXVh9OjR9gkKINZvPn/vavjLB/uKtLbm7/5Ot7bplB8K09GVa2ZqR2aKycwUiUxfub1Jjz9g1w2U2fUqyiKqssqAXXZZmd1JLgvoykYHBkeThnrC+OnN/6TqN7yzICIVBgsiUmGwICIVBgsiUklrBGcwUoH+yGB2qiusG8EpKRMSPtJqgFJaSZyyKyU9lfWA4kpmDpYl1BN2WfeKaDmAhKSmdgSnpFDJTCmRCRRXMjNmVEJZIMA1OIkoyxgsiEiFwYKIVBgsiEiFwYKIVNJbsDdShbLwYFq7J2zvnvTnO+1hu+fD9oID50N21lci5Wrb7n9UdeyVm+zdw6SnHkCRPfkQyoyHn4a4ZQ6chCcgmewqpt+hzv6cPn1Et9DzVS+vtsqkpx5AcT35iKnxDy6aUebn0xAiyjIGCyJSYbAgIhUGCyJSSSvB2ROuSEpwngvbCxP8ZsHTI26Mdo0LLe0QbkBOUn7ymG6Nhc/+lZ2c0yYztes4zL7TTtYm7o7m5QSnL2zgcxLbZ38GBzaOfMc1aY2LVMs/a2TS5wDgozt+qKo3819Wjvgav5z3vKreEx9+xSqr9Z+P/3eFX/egAeCdBREpMVgQkQqDBRGpMFgQkUpaCc7ecBkCCUlNaWSmRBpJKS2w64Z0C+xK5zv6l39tlWlHZQLyyEyt3z9tJ+c+/5/tRK82mSlJTGbGyxLeLxPR7YxVCP4QkJwy1CUfpz0iLc5r13PFz9T+PJt/YJ/v4GohsSqseZKLd/ffF6+xyqSk567b1o/4GonJzJiahLIAE5xElG0MFkSkwmBBRCoMFkSkklaCsy9cBn9CUvN8SJfglKeFj3yB3VTTzK0rKKeYp6orWdT4gFX2P9v/1irTJjMXVt1nlf3q/MtWmV9I/vr7B6dHm4i8oKwXaEZwysdluR3KXJ4TsdtnUvxd1SY+b7rbTlK+83N7yryU9JTMv32tVbbrzRVWWY2U4PQNTlH3O/p+wzsLIlJhsCAiFQYLIlJhsCAilbQSnOfDZfAnJDVDYd3h0khKKTF09C/tBM1VG+3EUKpp5kNp18scOKfqlEClve7ol6611wR96//8jVUmJUdNn71GIuycVFIyM8aXUObzcoIzAviTPgrdCE6fOO1eN71dWtNT+xk7aYzmTZX4HCoqPAuYsdwe5bu/1U6MS8lRabSzJDGZGVPnPxf/7zI/E5xElGUMFkSkwmBBRCoMFkSkklaCMxT2w5+Q1AyHdWsVyklFO4k0+Uk74SNtKqMd2afd/AeQ18yUmEp7jrSpsLNXUtITffZFtKk0n5Tg7BsckuiLejfB6Q+58JvEhJzub5T8mdhJz+sfEKayZ9RvdAn5dLhl9jmjwnR7KekZFRO9uvcwMZkZk5j09Pn0r4x3FkSkwmBBRCo
MFkSkwmBBRCppJTgjYT/chKSmdjMWeVq4tKZkJvVs2p3MB+rqRhVKIzMXzF5ln091tjQSnH32/GqnfzBj50SzPJ87i3whA59JfH91STXtCE7tLury+YTrilPZ5U9Kmx6URmZKiVnta9GOgpVGcNY4Cf3GYYKTiLKMwYKIVBgsiEiFwYKIVNJKcLphP5CQ1JQ2AJJ8tGrkO2RnQruTeaq6Wm6FciTryC+RlMyMlyWMCHWkIYse4Qu78I1gBOe/v2yvUZkP8hIIqRKKI9+cyhWXsB15AleSmMyMqUsYtennCE4iyjYGCyJSYbAgIhUGCyJSSSvBaUI+GP9gfJE2ANKumandAEg7zVybzEyVyPQLdWffae/WLu1mHlWuhyjZ+Z49+lPiCNPbkbh+p4cTnG6ZD27ZYF9xy+3PWVpnUlq3UjvVW6onJRQz2ZV9gN0fpvxQWP9TOKer3sDcvvbv/k730KBOSGDW+AZ/7Q0TnESUbQwWRKTCYFHkukJN+O3pB9EVuqLQTaEi8u771+Mrd/4j3j/wWfUxaeUsyHtOnL8RZ0NTMKpvRqGbQkXkH1+9A7venotrJi0BsEN1TFrBwgn74AQSEpzCbuZeT2ZKicxU5VIyc7jdzOPtGWbNzLQlJDPPR+oRcquBUAh/7J0OAPhj7+dGfu4c043glBJtUr3sTlvPnLIfCx+91LfFY5VT6yWJycz/e/wKnDk7BqN8Zfjl60sAAK+/eTuAB1Xn4p1FEfq3PyU+QRnoSGHUFKYxVDSum/VO/L8dZ6DfnD47Rn08cxZF6Pq6/w4HsTsXZ8hPItnf/+RBBAIDty7GpN9vGCyKUONl+3DTWHtcAtHF3PUXW/G/3lgy4uNVX0PMhSXR3L7kJbrEDWSlMmmQi7A3gxHqGWmskVhP+F4nHit//zNCuVgWsb9bG2FTYmmj4kz29ogMGXQVcfsAdGPgu74PQOdAW8zIv99mW6wtkUjy5s+uY/+Nch1hsJVUT3h5rhGOFdoj1XOFVIl4jRQfnVRuhPyEWCb0T7FeBjmL7mDyCzx3LgKgGw7cC5s6dw5cQ9FvHKOodfz4cTQ1NY2krZRnx44dw8SJEwvdDADsN8VE029UwcJ1XbS3t6OmpgaO8BeACs8Yg2AwiMbGRvh83vh2yX7jfen0G1WwICLyxp8gIvI8BgsiUmGwICIVBgsiUmGwICIVBgsiUmGwICIVBgsiUmGwICIVBgsiUmGwICIVBgsiUinpYLFu3To0NzfDlRYtEITDYTQ1NWHTpk05bhl5labPHDlyBI7j4MUXX4yXrVy5ErNmzcpDCwvIlKiuri4zZswYs2XLlrSO27Bhg2lsbDTnz5/PUcvIq7R9pq2tzQAwra2t8bKTJ0+aiooKs23bthy3snBK9s5iy5YtiEQiuOeee9I6bvny5Th9+jReeeWVHLWMvGqkfQYAxo8fj6VLl2L9+tJd7rBkg0VrayuWLFmCysrKtI6rq6vDwoULk24x6dIw0j4Tc9ddd2H37t349NNPs9wybyjJYNHW1oYPPvgACxYsSCp3XRfPPPMMrrvuOlRWVmLcuHFYtGgR9u3bl1Tvtttuw+7du3H27Nl8NpsKKFWf6ezsxDe+8Q3U1tairq4O999/Pzo7O8VzxI7dtm1brptbECUZLN5++20AwPTp05PKv/Wtb+Hhhx9GU1MT1q5di5UrV6KyshLvvPNOUr0ZM2bAGBM/D5U+qc8YY7B06VK89NJLuO+++/Dkk0/i+PHjuP/++8Vz1NbWYtKkSdizZ09e2px3hU6a5MJjjz1mAJhgMBgv+/Wvf20AmIceesiq77pu0r/b29sNALN27dqct5W8QeozW7duNQDMunXr4mWRSMTMnz/fSnDGLFy40LS0tOSjyXlXkncWZ86cQSAQQHV1dbzstddeg+M4ePzxx636QxeTra+vBwCcPn06tw0lz5D6zPbt2xEIBPCd73wnXub3+/Hgg6m3+6uvry/ZflOSwUJy+PBhNDY
2YsyY4bdrMxfWMOaK1Je2o0ePYsKECUkBBACmTp2a8hhjTMn2m5IMFg0NDYhEIggGgyM6vqOjAwAwduzYbDaLPCzTPhPT0dFRsv2mJINFc3MzgIEMd8ykSZPQ3t6uesIRO66lpSU3DSTPkfrMlVdeiZMnT6Knpyep7kcffZTyPG1tbSXbb0oyWMyePRsAkh6J3nHHHTDG4IknnrDqmyFbp+zfvx+O48TPQ6VP6jNf/vKXEYlEsHnz5nhZNBrFT37yE/EcXV1dOHz4MObMmZPbxhaIaq/TYnP11Vdj2rRp2LlzJ775zW8CAG655RYsW7YMzz77LD7++GMsWrQIruti165duOWWW/DAAw/Ej9+xYwfmzp2LhoaGQr0EyjOpz3z1q1/F3LlzsXLlShw5cgTXXnstfvGLX6Crq0s8x86dO+OPW0tSQZ/F5NCGDRtMdXW16e3tjZdFIhHz1FNPmebmZlNeXm7GjRtnFi9ebPbv3x+v09nZacrLy80LL7xQiGZTAUl95syZM2bZsmVm9OjRpra21ixbtsy899574qPTu+++28ybNy/Prc6fkg0WnZ2dZsyYMWn/0m/cuNFMmDAhqcPQpWGkfcaYgYlklZWVZuvWrTlomTeUbLAwxpg1a9aYqVOnmmg0qqofCoVMU1OTee6553LcMvKqdPtMzIoVK8zMmTNz1Cpv4MbIRKRSkk9DiCj7GCyISIXBgohUGCyISEU1KMt1XbS3t6OmpqZkJ8kUO2MMgsEgGhsb4fN5428A+433pdNvVMGivb0dTU1NWWkc5daxY8cwceLEQjcDAPtNMdH0G1WwqKmpGTghgNGpKr35ZhpN86BRo+yyyZPz344R6g4G0dTSEv+svCCx39QAcACYCz9jPt4nD50uFkPmmAEAhsxo97Senm7cfHOTqt+ogkXsFnI0hGDh9wMrV8q/bMVE+oRHpwyNnuWl2/2L9RvjD+DkmhdRXV187/FwiilYxGj6TeYTyTZvBqZMyfg0dGk5+k970f/Z6cNXJM9IP1j4fIDrDv4cNw5obBz8/8rdv5BJEk57DekeMRXhLiII79zSDycIbw/ENT4fHNeN/xw3DnAbhz+uEA4d0teVulgx/e3s7tbXTe839vLLgenTB+4mpk8Hxo4FOI2bhmEuvxyYPgNm8/PA9BlwLx8PM+7yQjeL0pTencWHHw4ECMcBvv1t4A9/AMrLc9Q0KhkfHkjqN+fOhoCKikK3itKU3p1FRcXABw4M/GSgII2h/YaBoih5Y/QOEXleel9DTp0CensH/y1kR5wbblCdykiZFSmzJDyHci4srpq1awBi0nT0aO88hjTDJHWN690EZ3s7kLho9qlTdp0bbtC919JnaqrtRLSUeNR+ngcP2u9lerly7/Qbd5h+oX1WAPDOgoiUGCyISIXBgohUGCyISCW9BOeoUckJxwwGwTvCiEkpGePzeSdZBAyfaLzosdC9Fq+95kz19iYP2E0nWTiUtt94KckIZNZvtJwcL03AOwsiUmGwICIVBgsiUmGwICKVvGyMrE1cahN7xZAIzYT0+qRxeI7Hp6VnSkoKSkk87WevPV8haZPgkuFGa2bKW+8UEXkWgwURqTBYEJEKgwURqeQlwSnRJinlkW+FS+xlkhDLdQLqUqBNUuZjxGQ6SqHf8M6CiFQYLIhIhcGCiFQYLIhIpWAJTqcnaJWZffvsip98YpcVcA/SzBJnukRVJqP4Sp4wv136TII99ntYU124RGEp9BveWRCRCoMFEakwWBCRCoMFEakULMGpVgTTirNNmnpe6q85Hy7F9zCbSzdceu8eEY0IgwURqTBYEJEKgwURqeQlwVlKa0Veikkyylwp9JvifwVElBcMFkSkwmBBRCoMFkSkkp8Ep7RGYne37mBpSvLBg3a9xsZ0mzX8df71XzM7Z64l7GLf3dOD2ptvLlxbciCTfiNNRw8G7bITJ+xj08lFJnwEcQcPejuhn9jVe3q6cfPNtarjeGdBRCo
MFkSkwmBBRCoMFkSkUrgp6lJmaPRo3bFSkksq054vFamNXuL19uXCoUN2mfJzrhbWbh092p7CLeS605Lp8bmW2L5z5/TH8c6CiFQYLIhIhcGCiFQYLIhIhcGCiFTSexpSXZ2cgZeG40q7ikl+8xv9NYdwbrhBdai4C1Q6T0i0u0hlslaB9hpeT7FfxOTJyW+7NDLfdZVDpIWd7ERSv1F+TtodzoqB9PBopHhnQUQqDBZEpMJgQUQqDBZEpJJegrOnJzmZJ+0Wpkw+SsS1CtrbR3w+MSmYaj0EKSHW3Ky6jNhuKbOU7Wtk8t7kkdMThJO0M1aNVSeTnbOk9yajhKTQb2pSDK2XriMNIZeIn6lwnUyuIa2tMdJcOe8siEiFwYKIVBgsiEiFwYKIVPKynkVGo/OkUaJSYrWAOz45wqhQr7WxGGV7VKd0vkwSq5nS9httMjPX2HuJSIXBgohUGCyISIXBgohUCrdgb5HSJi4zmQ7NRGjpKYV+w15JRCoMFkSkwmBBRCoMFkSk4qkEp3ZEm9dok03F8FqKkdRv1KM/C6jY+g3vLIhIhcGCiFQYLIhIhcGCiFQKluB0hGnF5uBBuwz29FwHxZm8osxJ/aa72/v9QcvL/YZ3FkSkwmBBRCoMFkSkwmBBRCqeGsFZDLw8wu5S5XhjicqLKoV+wzsLIlJhsCAiFQYLIlJhsCAilfQSnOfOJW/6o9yOuZAbuVDhneiuQbcZ3Dn91Cndcew33sI7CyJSYbAgIhUGCyJSYbAgIpX0Epyf+QyQuN6hkOA0+/Zl2qbsOXRIX1cYYSet45jt6fGXwoZCjY3J3aa21q5TqDUzc7HcQT5GZhai35R+TyWirGCwICIVBosit+/DCtx6/0S8+/vyQjeFisjBg8Ajj6T3TZ2zTovcz7bV4n/vHYUpV9l7ZxCl8qtfAe+9BzQ16Y9JL1j09CQnAru70zrc06QNX7y01md1dfw/jx4FTp8GnPMT8fO3qgAAr+0UsoYe4fQE4SSMxqyurrlIbcqmqVMG+3Cs3/h8DnbtGijbs0d/Lt5ZFKGr/iz2izcKjjPQGU6f4dBourjBfjO4BsiZM/rjmbMoQi+/ZBAIDAQJY2IdgMGCLi6536R/PINFEbr3XmDvO4VuBRWbTPuN6muIuRCGuoND9mxQzjotWh7KyXQPacvAW+/AccyFu4suAIOflRek6jfGQ5sWZ5qXCvbYd3Reen2p+03s7mLg/2v6jSpYBC982E0tLem0k7Kotq5OLB/6GQeDQdRKQyQLgP2m8LLZbxyjCCmu66K9vR01NTVwimF11EuQMQbBYBCNjY3weWQIOfuN96XTb1TBgojIG3+CiMjzGCyISIXBgohUGCyISIXBgohUGCyISIXBgohU/j+2jLaMo+6fxgAAAABJRU5ErkJggg==", 143 | "text/plain": [ 144 | "
" 145 | ] 146 | }, 147 | "metadata": {}, 148 | "output_type": "display_data" 149 | } 150 | ], 151 | "source": [ 152 | "plt.figure(figsize=(3, 3))\n", 153 | "\n", 154 | "plt.subplot(2, 2, 1)\n", 155 | "plt.imshow((v[:, 0].reshape(H, W)))\n", 156 | "plt.imshow(1 - maze_raw, cmap='gray', alpha=1. * maze_raw)\n", 157 | "plt.plot([0], [0], 'r*')\n", 158 | "plt.plot([20], [20], 'b*')\n", 159 | "plt.xticks([]); plt.yticks([])\n", 160 | "plt.title('(a)')\n", 161 | "\n", 162 | "plt.subplot(2, 2, 2)\n", 163 | "plt.imshow((v[:, 1].reshape(H, W)))\n", 164 | "plt.imshow(1 - maze_raw, cmap='gray', alpha=1. * maze_raw)\n", 165 | "plt.xticks([]); plt.yticks([])\n", 166 | "plt.plot([0], [0], 'r*')\n", 167 | "plt.plot([20], [20], 'b*')\n", 168 | "plt.title('(b)')\n", 169 | "\n", 170 | "dip = ip[:, :, 0] - ip[:, :, 1]\n", 171 | "plt.subplot(2, 2, 3)\n", 172 | "plt.imshow(dip[:, 0].reshape(H, W), cmap='bwr', vmin=-1, vmax=+1)\n", 173 | "plt.imshow(1 - maze_raw, cmap='gray', alpha=1. * maze_raw)\n", 174 | "plt.xticks([]); plt.yticks([])\n", 175 | "plt.plot([0], [0], 'r*')\n", 176 | "plt.plot([20], [20], 'b*')\n", 177 | "plt.title('(c)')\n", 178 | "\n", 179 | "plt.subplot(2, 2, 4)\n", 180 | "plt.imshow(dip[:, 3].reshape(H, W), cmap='bwr', vmin=-1, vmax=+1)\n", 181 | "plt.imshow(1 - maze_raw, cmap='gray', alpha=1. 
* maze_raw)\n", 182 | "plt.xticks([]); plt.yticks([])\n", 183 | "plt.plot([0], [0], 'r*')\n", 184 | "plt.plot([20], [20], 'b*')\n", 185 | "plt.title('(d)')\n", 186 | "\n", 187 | "plt.tight_layout()\n", 188 | "\n", 189 | "plt.savefig('../paper/fig/mdp.pdf')" 190 | ] 191 | } 192 | ], 193 | "metadata": { 194 | "language_info": { 195 | "name": "python" 196 | } 197 | }, 198 | "nbformat": 4, 199 | "nbformat_minor": 5 200 | } 201 | -------------------------------------------------------------------------------- /demo/demo-mdp.wppl: -------------------------------------------------------------------------------- 1 | var H = 21 2 | var W = 21 3 | var make_states = function(i) { 4 | return i == 1 ? [0] : make_states(i - 1).concat([i - 1]) 5 | } 6 | var make_blank_maze = function(n) { 7 | return n == 0 ? [] : make_blank_maze(n - 1).concat([0]) 8 | } 9 | var S = make_states(H * W) 10 | var G = [0, H * W - 1] 11 | var A = [0, 1, 2, 3] // left, right, up, down 12 | 13 | var coord_actions = [ 14 | [-1, 0], 15 | [1, 0], 16 | [0, 1], 17 | [0, -1], 18 | ] 19 | 20 | var maze = make_blank_maze(H * W); 21 | 22 | var Tr = function(s, a) { 23 | var x = s % W 24 | var y = Math.floor(s / W) 25 | 26 | var next_x = x + coord_actions[a][0] 27 | var next_y = y + coord_actions[a][1] 28 | 29 | var next_x = next_x < 0 ? 0 : (next_x > W - 1 ? W - 1 : next_x) 30 | var next_y = next_y < 0 ? 0 : (next_y > H - 1 ? H - 1 : next_y) 31 | 32 | var next_state = next_x + W * next_y 33 | return maze[next_state] == 1 ? s : next_state 34 | } 35 | 36 | var R = function(s, a, g){ 37 | return s == g ? 1.0 : 0.0 38 | } 39 | 40 | var is_terminating = function(s, g) { 41 | return s == g 42 | } 43 | 44 | 45 | // Adapted from https://agentmodels.org/chapters/3b-mdp-gridworld.html 46 | var policy = dp.cache(function(s, g, t) { 47 | return Infer({method: 'enumerate', model: function() { 48 | var a = uniformDraw(A) 49 | var value = R(s, a, g) + (t <= 0 ? 0. : ( 50 | is_terminating(s, g) ? 
0.0 : (0.9 * expectation(Infer(function() { 51 | var s_ = Tr(s, a) 52 | return V(s_, g, t-1) 53 | }))))) 54 | factor(2 * value) // factor(beta * value) 55 | 56 | return a 57 | }}) 58 | }) 59 | 60 | var V = dp.cache(function(s, g, t) { 61 | return expectation(Infer({method: 'enumerate', model: function() { 62 | var a = sample(policy(s, g, t)) 63 | return R(s, a, g) + (t <= 0 ? 64 | 0.0 : (is_terminating(s, g) ? 0. : 65 | 0.9 * expectation(Infer(function() { 66 | var s_ = Tr(s, a) 67 | return V(s_, g, t-1) 68 | }))))}})) 69 | }) 70 | 71 | // Run with webppl --require webppl-timeit demo-grid.wppl -- 1000 72 | var t = argv._[1]; 73 | console.log('t = ' + t.toString()); // toString() guards against undefined! 74 | 75 | var out = function () { return map( 76 | function(g) { return map( 77 | function(s) { 78 | return V(s, g, t) 79 | }, S) 80 | }, G); 81 | } 82 | console.log(timeit(out)['runtimeInMilliseconds']); 83 | -------------------------------------------------------------------------------- /demo/demo-monty.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "9393290f-db1c-418c-a116-503a7c194f3c", 6 | "metadata": {}, 7 | "source": [ 8 | "# The Monty Hall problem\n", 9 | "\n", 10 | "In a game show, contestant Alice faces three doors, one of which hides a prize.\n", 11 | "\n", 12 | "Alice picks a door, and Monty reveals one of the two other doors that does _not_ hide the prize.\n", 13 | "\n", 14 | "Now Alice has the option to keep her current door, or switch to the other not-yet-revealed door. What should she do?" 
15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "a492fbd5-a840-4c16-ab8f-a00e5e434053", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "from memo import memo\n", 25 | "import jax\n", 26 | "import jax.numpy as np\n", 27 | "\n", 28 | "Door = np.arange(3)\n", 29 | "\n", 30 | "@memo\n", 31 | "def monty[pick: Door, reveal: Door, d: Door]():\n", 32 | " alice: thinks[ monty: chooses(prize in Door, wpp=1) ]\n", 33 | " alice: knows(pick)\n", 34 | " alice: thinks[\n", 35 | " monty: knows(pick),\n", 36 | " monty: chooses(reveal in Door, wpp=(reveal != prize and reveal != pick))\n", 37 | " ]\n", 38 | " alice: observes [monty.reveal] is reveal\n", 39 | " alice: knows(d)\n", 40 | " return alice[Pr[monty.prize == d]]" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 2, 46 | "id": "245f73ab-f1f6-4c97-b7ac-ffe208ab2bc0", 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "7.04 μs ± 105 ns per loop (mean ± std. dev. 
of 10 runs, 100 loops each)\n", 54 | "If Alice picked door 0 and Monty revealed door 1, then her belief about the prize's location is:\n", 55 | "p(door 0) = 0.3333333432674408\n", 56 | "p(door 1) = 0.0\n", 57 | "p(door 2) = 0.6666666865348816\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "z = monty()\n", 63 | "%timeit -r 10 -n 100 monty().block_until_ready()\n", 64 | "print(\"If Alice picked door 0 and Monty revealed door 1, then her belief about the prize's location is:\")\n", 65 | "for d in Door:\n", 66 | " print(f'p(door {d}) = {z[0, 1, d]}')" 67 | ] 68 | } 69 | ], 70 | "metadata": { 71 | "language_info": { 72 | "name": "python" 73 | } 74 | }, 75 | "nbformat": 4, 76 | "nbformat_minor": 5 77 | } 78 | -------------------------------------------------------------------------------- /demo/demo-newcomb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e85e33ae-7b76-4db4-85b5-f77fc0035335", 6 | "metadata": {}, 7 | "source": [ 8 | "# Newcomb's Paradox\n", 9 | "\n", 10 | "**Inspired by:** Wolpert, D. H., & Benford, G. (2013). _The lesson of Newcomb’s paradox._ Synthese, 190, 1637-1646.\n", 11 | "\n", 12 | "Suppose there are two boxes, A and B. You can choose to take the contents of either just box B, or both boxes A and B.\n", 13 | "* You know that box A contains \\$1000.\n", 14 | "* Box B's contents were determined ahead of time by an omniscient adversary (\"God\"), who places:\n", 15 | " - \\$1,000,000 if he predicts you will take Box B only\n", 16 | " - \\$0 if he predicts you will take both boxes.\n", 17 | "\n", 18 | "Assuming (because of determinism) that God can indeed make such predictions perfectly, what should you do? People have differing intuitions. Some think you should take both boxes and earn \\$1,001,000. Others think you should only take box B.\n", 19 | "\n", 20 | "Here, we use memo to model how Wolpert and Benford (2013) resolve the paradox. 
They argue that people's intuitions come from different intuitions of how to translate the English description into a formal game. We will show how both translations can be implemented in memo models.\n", 21 | "\n", 22 | "We'll start with some groundwork…" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "id": "4aa782ac-d04f-4945-91c2-15fc7f7c4ea4", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "from memo import memo\n", 33 | "import jax\n", 34 | "import jax.numpy as np\n", 35 | "from enum import IntEnum\n", 36 | "\n", 37 | "class Pick(IntEnum):\n", 38 | " # You could either take both A and B, or just B.\n", 39 | " AB = 0\n", 40 | " B = 1\n", 41 | "\n", 42 | "# y is (y)our choice\n", 43 | "# g is (g)od's choice\n", 44 | "@jax.jit\n", 45 | "def payout(g, y):\n", 46 | " return np.array([\n", 47 | " [1e3 + 000, 000],\n", 48 | " [1e3 + 1e6, 1e6]\n", 49 | " ])[g, y]" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "id": "c1e7c9df-2c91-46e5-be1a-fa592a1d0033", 55 | "metadata": {}, 56 | "source": [ 57 | "In Wolpert and Benford's words, the \"fearful\" interpretation is that you have \"free will\" to pick $y$ unconditionally, but God \"knew\" what you would have picked and predicts $g$ correctly. In this case, you should pick B only." 
58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "id": "56a5752f-8bed-4719-9906-f8dc1c4484f3", 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "Fearful chooses AB with probability 0.0.\n", 71 | "Fearful chooses B with probability 1.0.\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "@memo\n", 77 | "def fearful[p: Pick]():\n", 78 | " alice: chooses(y in Pick, to_maximize=imagine[\n", 79 | " god: knows(y),\n", 80 | " god: chooses(g in Pick, to_maximize=(g == y)),\n", 81 | " E[payout(god.g, y)]\n", 82 | " ])\n", 83 | " return Pr[alice.y == p]\n", 84 | "\n", 85 | "for p, pr in zip(Pick, fearful()):\n", 86 | " print(f'Fearful chooses {p.name} with probability {pr}.')" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "id": "2e7f89b5-ef91-48db-975d-54cc71e8ebd1", 92 | "metadata": {}, 93 | "source": [ 94 | "The \"realist\" interpretation is that God pre-registers $g$, and you have \"free will\" to pick the conditional distribution of $y \\mid g$. In this case, you should pick both boxes." 
95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 3, 100 | "id": "c1af18e1-626f-4d1f-8ace-ad323cc503c0", 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "Realist chooses AB with probability 1.0.\n", 108 | "Realist chooses B with probability 0.0.\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "@memo\n", 114 | "def realist[p: Pick]():\n", 115 | " alice: thinks[ god: chooses(g in Pick, wpp=1) ] # this distribution is arbitrary\n", 116 | " alice: chooses(y in Pick, to_maximize=E[payout(god.g, y)])\n", 117 | " return Pr[alice.y == p]\n", 118 | "\n", 119 | "for p, pr in zip(Pick, realist()):\n", 120 | " print(f'Realist chooses {p.name} with probability {pr}.')" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "35df39bd-702a-4955-b719-c6227de874a5", 126 | "metadata": {}, 127 | "source": [ 128 | "## Advanced explorations\n", 129 | "\n", 130 | "What should God do? It is tempting to implement a better version of the realist's model by having God choose $g$ by modeling Alice as choosing $y$ to be equal to $g$. But memo prohibits that: Alice can't choose based on a variable she doesn't yet know." 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 4, 136 | "id": "1f336674-9853-414e-b819-1a71a54829e0", 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | " Error: Unknown choice alice.g\n", 144 | " file: \"810282331.py\", line 5, in @memo should_fail_to_compile\n", 145 | " alice: chooses(y in Pick, to_maximize=(g == y)),\n", 146 | " ^\n", 147 | "\n", 148 | " hint: Did you perhaps misspell g? alice is not yet aware of any\n", 149 | " choice called g. Or, did you forget to call alice.chooses(g\n", 150 | " ...) 
or alice.knows(g) earlier in the memo?\n", 151 | "\n", 152 | " ctxt: This error was encountered in the frame of alice, as modeled\n", 153 | " by imagined_god, as modeled by alice. In that frame, alice is\n", 154 | " currently modeling the following 1 choices: y.\n", 155 | "\n", 156 | " info: You are using memo 1.1.2, JAX 0.5.0, Python 3.13.2 on Darwin.\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "try:\n", 162 | " @memo\n", 163 | " def should_fail_to_compile[p: Pick]():\n", 164 | " alice: thinks[ god: chooses(g in Pick, wpp=imagine[\n", 165 | " alice: chooses(y in Pick, to_maximize=(g == y)),\n", 166 | " E[g == alice.y]\n", 167 | " ]) ]\n", 168 | " alice: chooses(y in Pick, to_maximize=(g == y))\n", 169 | " return Pr[alice.y == p]\n", 170 | "except Exception as e:\n", 171 | " print(' Error:', e.message)\n", 172 | " print('\\n'.join(e.__notes__))" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "id": "b4c1c680-3e7c-401a-b9c6-73196bc9d6ad", 178 | "metadata": {}, 179 | "source": [ 180 | "Similarly, it is tempting to ask: what if Alice actually _does_ make the choice that God predicted? But memo catches the error here as well." 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 5, 186 | "id": "69942d9e-c27d-4b4b-9463-0ba4a6b0c103", 187 | "metadata": {}, 188 | "outputs": [ 189 | { 190 | "name": "stdout", 191 | "output_type": "stream", 192 | "text": [ 193 | " Error: Choice based on uncertain expression\n", 194 | " file: \"3580391324.py\", line 5, in @memo should_also_fail_to_compile\n", 195 | " alice: chooses(y in Pick, to_maximize=(y == god.g))\n", 196 | " ^\n", 197 | "\n", 198 | " hint: alice is uncertain about the value of the expression\n", 199 | " (wpp/to_maximize) that alice is using to choose y. Hence,\n", 200 | " alice cannot compute the probabilities needed to make the\n", 201 | " choice. 
Perhaps you meant to take an expected value somewhere,\n", 202 | " using E[...]?\n", 203 | "\n", 204 | " ctxt: This error was encountered in the frame of alice. In that\n", 205 | " frame, alice is currently modeling the following 2 choices:\n", 206 | " god.g, y.\n", 207 | "\n", 208 | " info: You are using memo 1.1.2, JAX 0.5.0, Python 3.13.2 on Darwin.\n" 209 | ] 210 | } 211 | ], 212 | "source": [ 213 | "try:\n", 214 | " @memo\n", 215 | " def should_also_fail_to_compile[p: Pick]():\n", 216 | " alice: thinks[ god: chooses(g in Pick, wpp=1) ]\n", 217 | " alice: chooses(y in Pick, to_maximize=(y == god.g))\n", 218 | " return Pr[alice.y == p]\n", 219 | "except Exception as e:\n", 220 | " print(' Error:', e.message)\n", 221 | " print('\\n'.join(e.__notes__))" 222 | ] 223 | } 224 | ], 225 | "metadata": { 226 | "language_info": { 227 | "name": "python" 228 | } 229 | }, 230 | "nbformat": 4, 231 | "nbformat_minor": 5 232 | } 233 | -------------------------------------------------------------------------------- /demo/demo-pc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "1330a5e8-2186-408a-9326-71ac967784d7", 6 | "metadata": {}, 7 | "source": [ 8 | "# Perturbation Confusion\n", 9 | "\n", 10 | "This notebook demonstrates the \"perturbation confusion\" issue referenced in Section 1.1.1 of the memo paper, and shows how memo addresses it.\n", 11 | "\n", 12 | "The problem set-up is that Alice has to choose between an indoor and outdoor restaurant, and has a slight preference for outdoor over indoor (utility 11 vs 10). However, there is a 50% chance of snow, and if it snows she greatly prefers indoor seating over outdoor (cost of 0 vs 100). We thus expect that Alice will \"play it safe\" and pick indoor seating.\n", 13 | "\n", 14 | "However, in WebPPL, a naïve implementation of this problem leads Alice to pick outdoor seating. 
This is because this model erroneously gives Alice \"control\" over the weather." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "4ddde059-cec1-4f77-b5b6-4a0148cce075", 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "{\n", 28 | " \u001b[32m'\"indoor\"'\u001b[39m: { val: \u001b[32m'indoor'\u001b[39m, prob: \u001b[33m0.4238831152341708\u001b[39m },\n", 29 | " \u001b[32m'\"" 30 | ] 31 | }, 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "outdoor\"'\u001b[39m: { val: \u001b[32m'outdoor'\u001b[39m, prob: \u001b[33m0.576116884765829\u001b[39m }\n", 37 | "}\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "%%bash\n", 43 | "webppl <(cat <" 126 | ] 127 | }, 128 | "metadata": {}, 129 | "output_type": "display_data" 130 | } 131 | ], 132 | "source": [ 133 | "# Run the solver\n", 134 | "q = Q(10)\n", 135 | "v = np.max(q, axis=1, keepdims=True)\n", 136 | "p = (q == v) * 1.0\n", 137 | "v = v.squeeze(-1)\n", 138 | "\n", 139 | "# Make the figure in the paper\n", 140 | "from matplotlib import pyplot as plt\n", 141 | "plt.figure(figsize=(3, 2))\n", 142 | "plt.plot(B[p[:, 0] == 1], v[p[:, 0] == 1], label='feed')\n", 143 | "plt.plot(B[p[:, 1] == 1], v[p[:, 1] == 1], ':', label='sing')\n", 144 | "plt.plot(B[p[:, 2] == 1], v[p[:, 2] == 1], '--', label='ignore')\n", 145 | "plt.legend()\n", 146 | "plt.xlabel('Belief state, P(hungry)')\n", 147 | "plt.ylabel('Long-term reward')\n", 148 | "plt.title('Crying baby POMDP solution')\n", 149 | "plt.tight_layout()\n", 150 | "plt.savefig('../paper/fig/pomdp.pdf')" 151 | ] 152 | } 153 | ], 154 | "metadata": { 155 | "language_info": { 156 | "name": "python" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 5 161 | } 162 | -------------------------------------------------------------------------------- /demo/demo-risk-aversion.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6e5b27c6-9918-458d-b1df-c194bebdd650", 6 | "metadata": {}, 7 | "source": [ 8 | "# Risk aversion\n", 9 | "\n", 10 | "Terry Tao recently shared some thoughts about risk aversion [on Mastodon](https://mathstodon.xyz/@tao/113479000564381543).\n", 11 | "\n", 12 | "His example is easy to model in memo. First, some imports and boilerplate…" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "id": "678aef32-1994-4770-a559-2c7fbbb144e8", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from memo import memo\n", 23 | "import jax.numpy as np\n", 24 | "import jax\n", 25 | "from enum import IntEnum\n", 26 | "\n", 27 | "class Action(IntEnum): # two types of actions\n", 28 | " Safe = 0\n", 29 | " Bold = 1\n", 30 | "\n", 31 | "# Outcome space: unit normal, support truncated to -10 to 10, discretized to 101 cells\n", 32 | "Outcome = np.linspace(-10, 10, 101)\n", 33 | "from jax.scipy.stats.norm import pdf as normpdf\n", 34 | "normpdf = jax.jit(normpdf)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "id": "9d9f8e03-da11-4b62-ac94-616da9beffbc", 40 | "metadata": {}, 41 | "source": [ 42 | "An agent's utility depends on the type of action and the outcome." 
43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "id": "2b76392f-0344-42e3-9330-a43c02dbcb2d", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "@jax.jit\n", 53 | "def utility(a, o):\n", 54 | " means = np.array([5, 9]) # safe, bold\n", 55 | " stdvs = np.array([3, 10]) # safe, bold\n", 56 | " return means[a] + stdvs[a] * o" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "id": "81fc00f2-66cc-40d5-a982-803f896d2724", 62 | "metadata": {}, 63 | "source": [ 64 | "Now in memo, we can model an agent who minimizes their \"value at risk\", which is defined by Terry as $\\sqrt{\\text{Var}[u]} - E[u]$ for utility $u$." 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "id": "87e6cd78-f09b-4d36-9284-2e7a8cdbde5b", 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "P(Safe) = 1.0\n", 78 | "P(Bold) = 0.0\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "@memo\n", 84 | "def model[a: Action]():\n", 85 | " terry: chooses(a in Action, to_minimize=imagine[ # terry minimizes \"value at risk\"\n", 86 | " world: chooses(o in Outcome, wpp=normpdf(o)), # outcome of action ~ N(0, 1)\n", 87 | " # value risk:\n", 88 | " Var[utility(a, world.o)]**0.5 - E[utility(a, world.o)]\n", 89 | " ])\n", 90 | " return Pr[terry.a == a]\n", 91 | "for a in Action: print(f'P({a.name}) = {model()[a]}')" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "2efc65c1-1932-44d2-bc80-d1c1d3568647", 97 | "metadata": {}, 98 | "source": [ 99 | "Terry chooses the safe action.\n", 100 | "\n", 101 | "Now what happens if we introduce an external \"shock\" factor with mean 5 and variance 10?" 
102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 4, 107 | "id": "ce8a68b3-db84-4195-84b6-9d9b753551da", 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "P(Safe) = 0.0\n", 115 | "P(Bold) = 1.0\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "@jax.jit\n", 121 | "def utility(a, o, shock):\n", 122 | " mean = np.array([5, 9])\n", 123 | " stdv = np.array([3, 10])\n", 124 | " return mean[a] + stdv[a] * o + (-5 + 10 * shock)\n", 125 | "\n", 126 | "@memo\n", 127 | "def model[a: Action]():\n", 128 | " terry: chooses(a in Action, to_minimize=imagine[ # terry minimizes \"value at risk\"\n", 129 | " world: chooses(o in Outcome, wpp=normpdf(o)), # outcome of action ~ N(0, 1)\n", 130 | " world: chooses(s in Outcome, wpp=normpdf(s)), # external \"shock\" factor ~ N(0, 1)\n", 131 | " # value risk:\n", 132 | " Var[utility(a, world.o, world.s)]**0.5 - E[utility(a, world.o, world.s)]\n", 133 | " ])\n", 134 | " return Pr[terry.a == a]\n", 135 | "\n", 136 | "for a in Action: print(f'P({a.name}) = {model()[a]}')" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "id": "a3880bba-b0cc-4720-ae0b-3c9d348ee673", 142 | "metadata": {}, 143 | "source": [ 144 | "Now Terry prefers the risky action!" 145 | ] 146 | } 147 | ], 148 | "metadata": { 149 | "language_info": { 150 | "name": "python" 151 | } 152 | }, 153 | "nbformat": 4, 154 | "nbformat_minor": 5 155 | } 156 | -------------------------------------------------------------------------------- /demo/demo-rsa.py: -------------------------------------------------------------------------------- 1 | from memo import memo 2 | import jax 3 | import jax.numpy as np 4 | from enum import IntEnum 5 | 6 | ## Boilerplate - define U, R, and denotes. 
7 | class U(IntEnum): # utterance space 8 | GREEN = 0b0001 9 | PINK = 0b0010 10 | SQUARE = 0b0100 11 | ROUND = 0b1000 12 | 13 | class R(IntEnum): # referent space 14 | GREEN_SQUARE = U.GREEN | U.SQUARE 15 | GREEN_CIRCLE = U.GREEN | U.ROUND 16 | PINK_CIRCLE = U.PINK | U.ROUND 17 | 18 | @jax.jit 19 | def denotes(u, r): 20 | return (u & r) != 0 21 | 22 | ## Recursive RSA model 23 | @memo 24 | def L[u: U, r: R](beta, t): 25 | listener: thinks[ 26 | speaker: given(r in R, wpp=1), 27 | speaker: chooses(u in U, wpp= 28 | denotes(u, r) * (1 if t == 0 else exp(beta * L[u, r](beta, t - 1)))) 29 | ] 30 | listener: observes [speaker.u] is u 31 | listener: chooses(r in R, wpp=Pr[speaker.r == r]) 32 | return Pr[listener.r == r] 33 | 34 | beta = 1. 35 | print(L(beta, 0)) 36 | print(L(beta, 1)) 37 | 38 | 39 | ## Fitting the model to data... 40 | Y = np.array([65, 115, 0]) / 180 # data from Qing & Franke 2015 41 | @jax.jit 42 | def loss(beta): 43 | return np.mean((L(beta, 1)[0] - Y) ** 2) 44 | 45 | from matplotlib import pyplot as plt 46 | plt.figure(figsize=(5, 4)) 47 | 48 | ## Best model fit vs. data 49 | beta = 1.74 50 | plt.subplot(2, 1, 2) 51 | X = np.array([0, 1, 2]) 52 | plt.bar(X - 0.25, Y, width=0.25, yerr=2 * np.sqrt(Y * (1 - Y) / 180), capsize=2, label='humans') 53 | plt.bar(X + 0.00, L(beta, 1)[0], width=0.25, label='model, ℓ=1') 54 | plt.bar(X + 0.25, L(beta, 0)[0], width=0.25, label='model, ℓ=0') 55 | plt.xticks([0, 1, 2], ['green\nsquare', 'green\ncircle', 'pink\ncircle']) 56 | plt.xlabel('Inferred referent r') 57 | plt.ylabel("Probability") 58 | plt.ylim(0, 1) 59 | plt.legend() 60 | plt.title('Final model fit') 61 | 62 | ## Fitting by grid search! 63 | plt.subplot(2, 2, 1) 64 | beta = np.linspace(0, 3, 100) 65 | plt.plot(beta, jax.vmap(loss)(beta)) 66 | plt.xlabel('beta') 67 | plt.ylabel('MSE (%)') 68 | plt.yticks([0, 0.02], [0, 2]) 69 | plt.xticks([0, 1, 2, 3]) 70 | plt.title('Grid search') 71 | 72 | ## Fitting by gradient descent! 
73 | vg = jax.value_and_grad(loss) 74 | plt.subplot(2, 2, 2) 75 | losses = [] 76 | beta = 0. 77 | for _ in range(26): 78 | l, dbeta = vg(beta) 79 | losses.append(l) 80 | beta = beta - dbeta * 12. 81 | plt.plot(np.arange(len(losses)), losses) 82 | plt.ylabel('MSE (%)') 83 | plt.xlabel('Step #') 84 | plt.yticks([0, 0.02], [0, 2]) 85 | plt.title('Gradient descent') 86 | 87 | plt.tight_layout() 88 | plt.savefig('../paper/fig/rsa-fit.pdf') 89 | -------------------------------------------------------------------------------- /demo/demo-sally-anne.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f9a718b6-f4ee-4aab-aea6-e8339720b630", 6 | "metadata": {}, 7 | "source": [ 8 | "# The Sally-Anne test\n", 9 | "\n", 10 | "**Inspired by:** Wimmer, H., & Perner, J. (1983). _Beliefs about beliefs: Representation and constraining function of wrong beliefs in young children's understanding of deception._ Cognition, 13(1), 103-128.\n", 11 | "\n", 12 | "Sally sees a marble in a box, then leaves the room. While she is gone, anne secretly moves the marble to a basket. When Sally returns to the room, where will she look for the marble—the box or the basket?" 
13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "id": "448d601a-77a4-413a-9b73-95d6b733af30", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from memo import memo\n", 23 | "import jax.numpy as np\n", 24 | "import jax\n", 25 | "from enum import IntEnum" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 3, 31 | "id": "e1567f73-5335-41e8-8874-19a793367a26", 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "+--------------------+----------+-----------------+-----------------------+\n", 39 | "| marble_pos_t0: Loc | obs: Obs | where_look: Loc | model |\n", 40 | "+--------------------+----------+-----------------+-----------------------+\n", 41 | "| BOX | OBS_NONE | BOX | 0.9900000095367432 |\n", 42 | "| BOX | OBS_NONE | BASKET | 0.009999999776482582 |\n", 43 | "| BOX | OBS_STAY | BOX | 1.0 |\n", 44 | "| BOX | OBS_STAY | BASKET | 0.0 |\n", 45 | "| BOX | OBS_MOVE | BOX | 0.0 |\n", 46 | "| BOX | OBS_MOVE | BASKET | 1.0 |\n", 47 | "| BASKET | OBS_NONE | BOX | 0.009999999776482582 |\n", 48 | "| BASKET | OBS_NONE | BASKET | 0.9900000095367432 |\n", 49 | "| BASKET | OBS_STAY | BOX | 0.0 |\n", 50 | "| BASKET | OBS_STAY | BASKET | 1.0 |\n", 51 | "| BASKET | OBS_MOVE | BOX | 1.0 |\n", 52 | "| BASKET | OBS_MOVE | BASKET | 0.0 |\n", 53 | "+--------------------+----------+-----------------+-----------------------+\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "class Loc(IntEnum): # marble's location\n", 59 | " BOX = 0\n", 60 | " BASKET = 1\n", 61 | "\n", 62 | "class Action(IntEnum): # anne's action on marble\n", 63 | " ACT_STAY = 0\n", 64 | " ACT_MOVE = 1\n", 65 | "\n", 66 | "@jax.jit\n", 67 | "def do(l, a): # apply action to marble to get new location\n", 68 | " return np.array([\n", 69 | " [0, 1],\n", 70 | " [1, 0]\n", 71 | " ])[a, l]\n", 72 | "\n", 73 | "class Obs(IntEnum): # what sally sees\n", 74 | " OBS_NONE = -1 # sees nothing\n", 75 | " 
OBS_STAY = Action.ACT_STAY\n", 76 | " OBS_MOVE = Action.ACT_MOVE\n", 77 | "\n", 78 | "@memo\n", 79 | "def model[marble_pos_t0: Loc, obs: Obs, where_look: Loc]():\n", 80 | " child: knows(marble_pos_t0, obs, where_look)\n", 81 | " child: thinks[\n", 82 | " sally: knows(marble_pos_t0),\n", 83 | " sally: thinks[\n", 84 | " anne: knows(marble_pos_t0),\n", 85 | " anne: chooses(a in Action, wpp=0.01 if a=={Action.ACT_MOVE} else 0.99),\n", 86 | " anne: chooses(marble_pos_t1 in Loc, wpp=do(marble_pos_t0, a)==marble_pos_t1),\n", 87 | " anne: chooses(o in Obs, wpp=1 if o=={Obs.OBS_NONE} or o==a else 0),\n", 88 | " ],\n", 89 | " sally: observes [anne.o] is obs,\n", 90 | " sally: chooses(where_look in Loc, wpp=Pr[anne.marble_pos_t1 == where_look])\n", 91 | " ]\n", 92 | " return child[ Pr[sally.where_look == where_look] ]\n", 93 | "\n", 94 | "model(print_table=True);" 95 | ] 96 | } 97 | ], 98 | "metadata": { 99 | "kernelspec": { 100 | "display_name": "Python 3 (ipykernel)", 101 | "language": "python", 102 | "name": "python3" 103 | }, 104 | "language_info": { 105 | "codemirror_mode": { 106 | "name": "ipython", 107 | "version": 3 108 | }, 109 | "file_extension": ".py", 110 | "mimetype": "text/x-python", 111 | "name": "python", 112 | "nbconvert_exporter": "python", 113 | "pygments_lexer": "ipython3", 114 | "version": "3.13.2" 115 | } 116 | }, 117 | "nbformat": 4, 118 | "nbformat_minor": 5 119 | } 120 | -------------------------------------------------------------------------------- /demo/demo-scalar.py: -------------------------------------------------------------------------------- 1 | from memo import memo 2 | import jax 3 | import jax.numpy as np 4 | 5 | ## Scalar implicature 6 | 7 | N = [0, 1, 2, 3] # number of nice people 8 | U = [0, 1, 2] # utterance: {none, some, all} of the people are nice 9 | 10 | @jax.jit 11 | def meaning(n, u): # (none) (some) (all) 12 | return np.array([ n == 0, n > 0, n == 3 ])[u] 13 | 14 | @memo 15 | def scalar[n: N, u: U](): 16 | listener: thinks[ 17 
| speaker: chooses(n in N, wpp=1), 18 | speaker: chooses(u in U, wpp=imagine[ 19 | listener: knows(u), 20 | listener: chooses(n in N, wpp=meaning(n, u)), 21 | Pr[listener.n == n] 22 | ]) 23 | ] 24 | listener: observes [speaker.u] is u 25 | listener: chooses(n in N, wpp=E[speaker.n == n]) 26 | return Pr[listener.n == n] 27 | 28 | print(scalar()) -------------------------------------------------------------------------------- /demo/demo-schelling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "489aee96-05b8-426e-bafe-5750cd9a73e6", 6 | "metadata": {}, 7 | "source": [ 8 | "# Schelling game\n", 9 | "\n", 10 | "**Inspired by:** Schelling, T. C. (1980). _The Strategy of Conflict._ Harvard university press.\n", 11 | "\n", 12 | "Alice and Bob agree to meet at a bar on Sunday, but they forget to decide on a bar to meet at. One bar is slightly more popular than the other. Where do they go? We can model this with recursive reasoning — Alice thinks about where Bob thinks Alice might go." 
13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "id": "50196ddd-e76d-4193-ba2e-6cdfff225e5b", 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "alice(1) = [0.59900993 0.40099007]; bob(1) = [0.64611655 0.35388345]\n", 26 | "alice(2) = [0.6905481 0.30945188]; bob(2) = [0.73171747 0.26828253]\n", 27 | "alice(3) = [0.76923996 0.23076005]; bob(3) = [0.8029279 0.1970721]\n", 28 | "alice(4) = [0.832767 0.16723298]; bob(4) = [0.8588822 0.1411178]\n", 29 | "alice(5) = [0.88149947 0.11850046]; bob(5) = [0.90091014 0.09908987]\n", 30 | "alice(6) = [0.91743904 0.08256097]; bob(6) = [0.9314207 0.06857934]\n", 31 | "alice(7) = [0.9431811 0.05681884]; bob(7) = [0.9530266 0.04697341]\n", 32 | "alice(8) = [0.9612362 0.03876386]; bob(8) = [0.96805906 0.031941 ]\n", 33 | "alice(9) = [0.9737139 0.0262862]; bob(9) = [0.97838986 0.02161017]\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "from memo import memo\n", 39 | "import jax\n", 40 | "import jax.numpy as np\n", 41 | "\n", 42 | "Bar = np.arange(2)\n", 43 | "@jax.jit\n", 44 | "def prior(b): return np.array([0.55, 0.45])[b]\n", 45 | "\n", 46 | "@memo\n", 47 | "def alice[b: Bar](depth):\n", 48 | " alice: thinks[ bob: chooses(b in Bar, wpp=bob[b](depth - 1)) ]\n", 49 | " alice: chooses(b in Bar, wpp=prior(b) * Pr[b == bob.b])\n", 50 | " return Pr[alice.b == b]\n", 51 | "\n", 52 | "@memo\n", 53 | "def bob[b: Bar](depth):\n", 54 | " bob: thinks[ alice: chooses(b in Bar, wpp=alice[b](depth) if depth > 0 else 1) ]\n", 55 | " bob: chooses(b in Bar, wpp=prior(b) * Pr[b == alice.b])\n", 56 | " return Pr[bob.b == b]\n", 57 | "\n", 58 | "for i in range(1, 10):\n", 59 | " print(f'alice({i}) = {alice(i)}; bob({i}) = {bob(i)}')" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "2533c2d1-9e3b-414f-b7f8-e751f03843e9", 65 | "metadata": {}, 66 | "source": [ 67 | "Notice the rapid convergence to the more popular bar.\n", 68 | "\n", 69 | 
"How confident is Alice in meeting Bob? How confident is an observer that Alice will meet Bob? Which do you expect to be higher?" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "id": "2b23102d-3a45-4742-8295-ee085913cf87", 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "Alice is 0.510 confident, observer is 0.529 confident.\n", 83 | "Alice is 0.556 confident, observer is 0.588 confident.\n", 84 | "Alice is 0.625 confident, observer is 0.663 confident.\n", 85 | "Alice is 0.702 confident, observer is 0.739 confident.\n", 86 | "Alice is 0.774 confident, observer is 0.806 confident.\n", 87 | "Alice is 0.835 confident, observer is 0.860 confident.\n", 88 | "Alice is 0.882 confident, observer is 0.902 confident.\n", 89 | "Alice is 0.918 confident, observer is 0.932 confident.\n", 90 | "Alice is 0.943 confident, observer is 0.953 confident.\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "@memo\n", 96 | "def alice_confidence(depth):\n", 97 | " alice: thinks[ bob: chooses(b in Bar, wpp=bob[b](depth - 1)) ]\n", 98 | " alice: chooses(b in Bar, wpp=prior(b) * Pr[b == bob.b])\n", 99 | " return E[alice[Pr[b == bob.b]]]\n", 100 | "\n", 101 | "@memo\n", 102 | "def obs_confidence(depth):\n", 103 | " alice: chooses(b in Bar, wpp=alice[b](depth))\n", 104 | " bob: chooses(b in Bar, wpp=bob[b](depth))\n", 105 | " return Pr[alice.b == bob.b]\n", 106 | "\n", 107 | "for i in range(1, 10):\n", 108 | " print(f'Alice is {alice_confidence(i):.3f} confident, observer is {obs_confidence(i):.3f} confident.')" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "id": "b7b09ce2-6df4-410d-9c9a-4ba05fcbd745", 114 | "metadata": {}, 115 | "source": [ 116 | "The observer is always slightly more confident, because they think both Alice and Bob are thinking at level $i$. 
Alice on the other hand thinks Bob is thinking at level $i-1$.\n", 117 | "\n", 118 | "The code below reproduces the scaled-up experiment shown in the paper (100 bars, 100 levels of recursion)." 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 3, 124 | "id": "95af5623-f5c5-4a5d-b225-138f08c424ef", 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "3.06 ms ± 14.4 μs per loop (mean ± std. dev. of 10 runs, 100 loops each)\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "Bar = np.arange(100)\n", 137 | "@jax.jit\n", 138 | "def prior(b): return 1\n", 139 | "\n", 140 | "@memo\n", 141 | "def alice[b: Bar](depth):\n", 142 | " alice: thinks[ bob: chooses(b in Bar, wpp=bob[b](depth - 1)) ]\n", 143 | " alice: chooses(b in Bar, wpp=prior(b) * Pr[b == bob.b])\n", 144 | " return Pr[alice.b == b]\n", 145 | "\n", 146 | "@memo\n", 147 | "def bob[b: Bar](depth):\n", 148 | " bob: thinks[ alice: chooses(b in Bar, wpp=alice[b](depth) if depth > 0 else 1) ]\n", 149 | " bob: chooses(b in Bar, wpp=prior(b) * Pr[b == alice.b])\n", 150 | " return Pr[bob.b == b]\n", 151 | "\n", 152 | "alice(100)\n", 153 | "%timeit -r 10 -n 100 alice(100).block_until_ready()" 154 | ] 155 | } 156 | ], 157 | "metadata": { 158 | "language_info": { 159 | "name": "python" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 5 164 | } 165 | -------------------------------------------------------------------------------- /demo/demo-schelling.wppl: -------------------------------------------------------------------------------- 1 | // Adapted from https://agentmodels.org/chapters/7-multi-agent.html 2 | 3 | var locationPrior = function() { 4 | if (flip(0.55)) { 5 | return 'popular-bar'; 6 | } else { 7 | return 'unpopular-bar'; 8 | } 9 | } 10 | 11 | /* 12 | var locationPrior = function() { 13 | return randomInteger(100); 14 | } 15 | */ 16 | 17 | var alice = dp.cache(function(depth) { 18 | return Infer({ 
model() { 19 | var myLocation = locationPrior(); 20 | var bobLocation = sample(bob(depth - 1)); 21 | condition(myLocation === bobLocation); 22 | return myLocation; 23 | }}); 24 | }); 25 | 26 | var bob = dp.cache(function(depth) { 27 | return Infer({ model() { 28 | var myLocation = locationPrior(); 29 | if (depth === 0) { 30 | return myLocation; 31 | } else { 32 | var aliceLocation = sample(alice(depth)); 33 | condition(myLocation === aliceLocation); 34 | return myLocation; 35 | } 36 | }}); 37 | }); 38 | 39 | alice(1).getDist(); 40 | // timeit(function() { alice(100).getDist() }); 41 | // https://github.com/stuhlmueller/webppl-timeit -------------------------------------------------------------------------------- /demo/demo-ultimatum.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "066f6614-c751-425e-a741-834b2669acdf", 6 | "metadata": {}, 7 | "source": [ 8 | "# Ultimatum game\n", 9 | "\n", 10 | "**Inspired by:** Güth, W., Schmittberger, R., & Schwarze, B. (1982). _An experimental analysis of ultimatum bargaining._ Journal of economic behavior & organization, 3(4), 367-388.\n", 11 | "\n", 12 | "This is classic economic game played between an \"offerer\" and a \"receiver.\" The offerer is endowed with $100. They get to make an offer of some (or all) of the endowment to the receiver. The receiver can _accept_ the offer (in which case the deal goes through), or _reject_ it (in which case the deal falls through and nobody gets anything).\n", 13 | "\n", 14 | "We will model a simple softmax-rational offerer-receiver pair and show that the expected rational behavior is for the offerer to offer only slightly more than zero. Interestingly, in practice people tend to offer closer to 50%. This disparity between theoretical prediction and empirical observation has been used to argue that people must be reasoning about some other factor—perhaps perceived fairness." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "e8e9778b-0efd-4944-be96-0028fd0511d7", 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "36.6 μs ± 1.36 μs per loop (mean ± std. dev. of 10 runs, 100 loops each)\n", 28 | "9.71 μs ± 811 ns per loop (mean ± std. dev. of 10 runs, 100 loops each)\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "from memo import memo\n", 34 | "import jax\n", 35 | "import jax.numpy as np\n", 36 | "from enum import IntEnum\n", 37 | "\n", 38 | "Proposal = np.linspace(0, 1, 100)\n", 39 | "class Decision(IntEnum):\n", 40 | " Accept = 0\n", 41 | " Reject = 1\n", 42 | "\n", 43 | "@jax.jit\n", 44 | "def payout_offerer(prop, dec):\n", 45 | " return np.array([\n", 46 | " 1 - prop, # if receiver accepts\n", 47 | " 0 # if receiver rejects\n", 48 | " ])[dec]\n", 49 | "\n", 50 | "@jax.jit\n", 51 | "def payout_receiver(prop, dec):\n", 52 | " return np.array([\n", 53 | " prop, # if receiver accepts\n", 54 | " 0 # if receiver rejects\n", 55 | " ])[dec]\n", 56 | "\n", 57 | "@memo # probability that receiver accepts proposal\n", 58 | "def receiver[prop: Proposal, dec: Decision]():\n", 59 | " receiver: knows(prop)\n", 60 | " receiver: chooses(dec in Decision, wpp=exp(50.0 * payout_receiver(prop, dec)))\n", 61 | " return Pr[ receiver.dec == dec ]\n", 62 | "\n", 63 | "@memo # probability that offerer proposes proposal\n", 64 | "def offerer[prop: Proposal]():\n", 65 | " offerer: chooses(prop in Proposal, wpp=exp(50.0 * imagine[\n", 66 | " receiver: knows(prop),\n", 67 | " receiver: chooses(dec in Decision, wpp=receiver[prop, dec]()),\n", 68 | " E[ payout_offerer(prop, receiver.dec) ]\n", 69 | " ]))\n", 70 | " return Pr[ offerer.prop == prop ]\n", 71 | "\n", 72 | "o = offerer()\n", 73 | "r = receiver()\n", 74 | "\n", 75 | "%timeit -r 10 -n 100 offerer().block_until_ready()\n", 76 | "%timeit -r 10 -n 100 receiver().block_until_ready()" 77 | ] 78 | }, 79 | { 
80 | "cell_type": "code", 81 | "execution_count": 2, 82 | "id": "8f936869-5f47-4ef8-af1e-455c5e4e4bd7", 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "data": { 87 | "text/plain": [ 88 | "" 89 | ] 90 | }, 91 | "execution_count": 2, 92 | "metadata": {}, 93 | "output_type": "execute_result" 94 | }, 95 | { 96 | "data": { 97 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGwCAYAAABVdURTAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAASeJJREFUeJzt3Xt8y/f+B/BXkjZJ7y3VG6XuY+5KV/dNj5qts7OLHvOjDBura3djLsXOVrMxM8aZMXYOYxczhx7dVgydbcYYWrZSamip6v2SNvn8/ohmSlWTJvkm6ev5eOSR5Jvv5Z3P6U5ePt/P9/uRCSEEiIiIiByEXOoCiIiIiMyJ4YaIiIgcCsMNERERORSGGyIiInIoDDdERETkUBhuiIiIyKEw3BAREZFDcZK6AGvT6XS4fPkyPDw8IJPJpC6HiIiI6kAIgcLCQgQFBUEur71vpsGFm8uXLyM4OFjqMoiIiMgEFy9eRLNmzWpdp8GFGw8PDwD6xvH09JS4GiIiIqqLgoICBAcHG37Ha9Pgwk3VqShPT0+GGyIiIjtTlyElHFBMREREDoXhhoiIiBwKww0RERE5FIYbIiIicigMN0RERORQGG6IiIjIoTDcEBERkUNhuCEiIiKHwnBDREREDoXhhoiIiByKpOFm//79iIqKQlBQEGQyGbZv337Pbfbt24cePXpApVKhTZs22LBhg8XrJCIiIvshabgpLi5G165dsWrVqjqtn5GRgUceeQQPPvggjh07hhkzZmDChAlISkqycKVERERkLySdOPPhhx/Gww8/XOf116xZg5YtW2Lp0qUAgA4dOuDgwYN49913ERkZaakyyVyEAHRaQGgBodM/dDdfQ+g/F+Lma91f21R9pl9wy/I7DlDzMS3CUvslInIAChXg4S/Z4e1qVvBDhw4hIiKi2rLIyEjMmDHjrtuUl5ejvLzc8L6goMBS5TmeSg1QeAUouAwUXgZKb9x85AFleUB5IaApASpKgYpi/XNlOaCtALQaQFsOaCsB3c2H0Er9jYiIyBqa9QYmfCvZ4e0q3GRlZcHfv3oS9Pf3R0FBAUpLS+Hi4nLHNgkJCVi4cKG1SrRPmhLgaiqQdQLIPglknQRuZABFVyF9D4UMMExvf7fXt61/z13WYR0iIjKdQinp4e0q3Jhi9uzZiIuLM7wvKChAcHCwhBXZiJw/gDOJwJn/ARd/+us00O0USsAzCPBsCrj43Hx4659VnoCzK+DsAijdACe1/qFwBpxU+m3lTvr3cif9QyYH5Ar9s6zqWXbzWQ5DaGEAISIiE9lVuAkICEB2dna1ZdnZ2fD09Kyx1wYAVCoVVCqVNcqzfYVZwE//AtJ2ANfTq3/m1gQI6Az4dwICugC+bQDPZoCbL4MGERHZFbsKN+Hh4UhMTKy27Ntvv0V4eLhEFdmJwiwg5T3gl/VAZZl+mdwZaDkAaP8w0C4S8G4ubY1ERERmImm4KSoqQnr6Xz0IGRkZOHbsGBo1aoTmzZtj9uzZuHTpEj755BMAwKRJk7By5Uq88sorePbZZ7Fnzx589tln2LVrl1RfwbYVXQMOvgv8su6vUNOsN/DAZKBNBKD2lLY+IiIiC5A03Pzyyy948MEHDe+rxsbExMRgw4YNuHLlCjIzMw2ft2zZErt27cLMm
TPx3nvvoVmzZvjoo494GXhN/vwF2DIKKMrSv2/WGxg0C2j9EE8zERGRQ5MJYbEbgdikgoICeHl5IT8/H56eDtpzcWwz8N/p+suxm9wHRL7JUENERHbNmN9vuxpzQ/egrQS+nQ/8ePOOz/c9Cvx9DaDykLYuIiIiK2K4cRTlRcDW/wPO7dW/H/gqMHAWIOfcqERE1LAw3DgCIYAdU/XBxtkN+PtqoONwqasiIiKSBMONIzj8EXBqm/4meaO3Ac0fkLoiIiIiyfCchb378wiwe7b+9d8WMdgQEVGDx3Bjz0pygc9jAF0F0CEKeOAFqSsiIiKSHMONvdLpgK+eB/IvAo1aAcNX8VJvIiIiMNzYr5TlwB/fAAoV8PRGQO0ldUVEREQ2geHGHhVmA9+/pX897G0gsIu09RAREdkQhht7dGilfq6oZr2AHmOkroaIiMimMNzYm5Jc/ezeADDgZY6zISIiug3Djb35aQ2gKQICOgNth0hdDRERkc1huLEnZfn6cAOw14aIiOguGG7syeGP9AHHtz1wX5TU1RAREdkkhht7oSkGDt2c7bv/i5wQk4iI6C74C2kvjmwESq4DPiFApyelroaIiMhmMdzYg4oy4IcV+tf94gAF5zslIiK6G4Ybe3Dic6DwCuDZFOg6UupqiIiIbBrDjT1I3a5/Dh0HOCklLYWIiMjWMdzYurJ84Nz3+tcdHpO2FiIiIjvAcGPr/vgW0FUAjdsCTdpLXQ0REZHNY7ixdad36p87PCptHURERHaC4caWVZTpe24A3rSPiIiojhhubFnG9/p5pDwCgaDuUldDRERkFxhubFnVKan7HuEdiYmIiOqIv5i2SqcFTifqX9/H8TZERER1xXBjqy7+BJTkAGovIKSf1NUQERHZDYYbW3V6l/653VBA4SxtLURERHaE4cYWCQGk/Vf/mqekiIiIjMJwY4uyTwJ5FwAnNdBmsNTVEBER2RWGG1uUdvMqqdYPAUo3aWshIiKyMww3tqhqvA1PSRERERmN4cbWFF8Hsk/oX7cbKm0tREREdojhxtZc+kX/3Lgt4NZY2lqIiIjsEMONrfnzZrhpFiptHURERHaK4cbWVPXcNO0pbR1ERER2iuHGluh0wKUj+tfsuSEiIjIJw40tyT0LlOXr72/j30nqaoiIiOwSw40tqRpvE9iVUy4QERGZiOHGlhjG2/CUFBERkakYbmyJ4UopDiYmIiIyFcONrago1c8pBbDnhoiIqB4YbmzFld8AXSXg1gTwbi51NURERHaL4cZW3DreRiaTthYiIiI7xnBjKzjehoiIyCwYbmwFr5QiIiIyC4YbW1B0DcjLBCADmvaQuhoiIiK7xnBjC6p6bXzbAWovaWshIiKycww3toAzgRMREZkNw40tuMRwQ0REZC4MN1LT6YBLR/WvOZiYiIio3hhupHb9D6C8AHB2Bfw6Sl0NERGR3WO4kZphJvBugMJJ0lKIiIgcAcON1K4c0z/zEnAiIiKzYLiR2rUz+me/DtLWQURE5CAYbqR2PV3/7NtO2jqIiIgcBMONlMoLgYJL+teN20hbCxERkYOQPNysWrUKISEhUKvVCAsLw88//1zr+suXL0f79u3h4uKC4OBgzJw5E2VlZVaq1syqem3cmgCujaSthYiIyEFIGm62bt2KuLg4xMfH4+jRo+jatSsiIyNx9erVGtffvHkzZs2ahfj4eKSlpWHdunXYunUrXnvtNStXbiY5f+ifeUqKiIjIbCQNN8uWLcPEiRMxbtw4dOzYEWvWrIGrqyvWr19f4/o//PAD+vbti2eeeQYhISEYMmQIRo4cWWtvT3l5OQoKCqo9bEbO7/pn37bS1kFERORAJAs3Go0GR44cQURExF/FyOWIiIjAoUOHatymT58+OHLkiCHMnDt3DomJiRg2bNhdj5OQkAAvLy/DIzg42LxfpD4M4YY9N0REROYi2V3jcnJyoNVq4e/vX225v78/Tp8+XeM2zzzzDHJyctCvX
z8IIVBZWYlJkybVelpq9uzZiIuLM7wvKCiwnYDD01JERERmJ/mAYmPs27cPb775Jj744AMcPXoU27Ztw65du/D666/fdRuVSgVPT89qD5ug095yGThPSxEREZmLZD03vr6+UCgUyM7OrrY8OzsbAQEBNW4zb948jB49GhMmTAAAdO7cGcXFxXjuuecwZ84cyOV2lNXyLgBaDeCkBrxspCeJiIjIAUiWBpRKJXr27Ink5GTDMp1Oh+TkZISHh9e4TUlJyR0BRqFQAACEEJYr1hKqTkk1bgPIFdLWQkRE5EAknakxLi4OMTExCA0NRe/evbF8+XIUFxdj3LhxAIAxY8agadOmSEhIAABERUVh2bJl6N69O8LCwpCeno558+YhKirKEHLsBq+UIiIisghJw010dDSuXbuG+fPnIysrC926dcPu3bsNg4wzMzOr9dTMnTsXMpkMc+fOxaVLl9CkSRNERUXhjTfekOormI5XShEREVmETNjd+Zz6KSgogJeXF/Lz86UdXLx+KJB5CHhyHdD5KenqICIisgPG/H7b0QhcB8PTUkRERBbBcCOF4utAyXX9a06YSUREZFYMN1K4fvNKKa9gQOkmbS1EREQOhuFGCjwlRUREZDEMN1LglVJEREQWw3AjBcOcUuy5ISIiMjeGGymw54aIiMhiGG6srbIcuHFe/5rhhoiIyOwYbqwt9xwgdIDKE3D3l7oaIiIih8NwY223Xiklk0lbCxERkQNiuLE2jrchIiKyKIYba+OVUkRERBbFcGNt7LkhIiKyKIYbaxICyEnXv2a4ISIisgiGG2sqygY0hYBMAfi0lLoaIiIih8RwY003LuifvZoCTkppayEiInJQDDfWlJepf/ZqLm0dREREDozhxprybvbceDPcEBERWQrDjTXlX9Q/M9wQERFZDMONNVWdlmK4ISIishiGG2syhJtgaesgIiJyYAw31iIEkP+n/jV7boiIiCyG4cZaiq4ClWWATA54NpW6GiIiIofFcGMtVaekPIIAhbO0tRARETkwhhtr4WXgREREVsFwYy28DJyIiMgqGG6shZeBExERWQXDjbXwMnAiIiKrYLixFvbcEBERWQXDjTUIAeRxzA0REZE1MNxYQ3EOUFkKQAZ4NpO6GiIiIofGcGMNhnvcBAJOSmlrISIicnAMN9aQz/E2RERE1sJwYw0cTExERGQ1DDfWwMvAiYiIrIbhxhrYc0NERGQ1DDfWwMvAiYiIrIbhxtKE+KvnxovhhoiIyNIYbiytJBeoKNa/9uI9boiIiCyN4cbS8i7on90DAGe1tLUQERE1AAw3lpbP8TZERETWxHBjabwMnIiIyKoYbiyNl4ETERFZFcONpfEycCIiIqtiuLE0XgZORERkVQw3lnTrPW7Yc0NERGQVDDeWVHoD0BTqX3NAMRERkVUw3FhS1WXgbn6As4u0tRARETUQDDeWxMvAiYiIrI7hxpI43oaIiMjqGG4sieGGiIjI6hhuLCn/T/2zF09LERERWQvDjSUVXNI/czZwIiIiq2G4saSCy/pnzyBp6yAiImpAGG4spVIDFF3Vv/ZsKm0tREREDQjDjaUUXgEgAIUScG0sdTVEREQNhuThZtWqVQgJCYFarUZYWBh+/vnnWtfPy8tDbGwsAgMDoVKp0K5dOyQmJlqpWiPcekpKJpO2FiIiogbEScqDb926FXFxcVizZg3CwsKwfPlyREZG4syZM/Dz87tjfY1Gg7/97W/w8/PDF198gaZNm+LChQvw9va2fvH3UjWYmKekiIiIrErScLNs2TJMnDgR48aNAwCsWbMGu3btwvr16zFr1qw71l+/fj1yc3Pxww8/wNnZGQAQEhJS6zHKy8tRXl5ueF9QUGC+L1AbDiYmIiKShGSnpTQaDY4cOYKIiIi/ipHLERERgUOHDtW4zY4dOxAeHo7Y2Fj4+/ujU6dOePPNN6HVau96nISEBHh5eRkewcFWuucMww0REZEkTAo3e/fur
feBc3JyoNVq4e/vX225v78/srKyatzm3Llz+OKLL6DVapGYmIh58+Zh6dKl+Oc//3nX48yePRv5+fmGx8WLF+tde53wtBQREZEkTDotNXToUDRr1gzjxo1DTEyM1XpDdDod/Pz88OGHH0KhUKBnz564dOkS3n77bcTHx9e4jUqlgkqlskp91bDnhoiISBIm9dxcunQJU6ZMwRdffIFWrVohMjISn332GTQaTZ334evrC4VCgezs7GrLs7OzERAQUOM2gYGBaNeuHRQKhWFZhw4dkJWVZdSxrYLhhoiISBImhRtfX1/MnDkTx44dw08//YR27drhhRdeQFBQEKZNm4bjx4/fcx9KpRI9e/ZEcnKyYZlOp0NycjLCw8Nr3KZv375IT0+HTqczLPv9998RGBgIpVJpylexDG0lUHTz1BpPSxEREVlVvQcU9+jRA7Nnz8aUKVNQVFSE9evXo2fPnujfvz9OnTpV67ZxcXFYu3YtNm7ciLS0NEyePBnFxcWGq6fGjBmD2bNnG9afPHkycnNzMX36dPz+++/YtWsX3nzzTcTGxtb3a5hXUTYgdIDcCXBrInU1REREDYrJ4aaiogJffPEFhg0bhhYtWiApKQkrV65EdnY20tPT0aJFCzz99NO17iM6OhrvvPMO5s+fj27duuHYsWPYvXu3YZBxZmYmrly5Ylg/ODgYSUlJOHz4MLp06YJp06Zh+vTpNV42LqmqU1IegYBcUfu6REREZFYyIYQwdqOpU6fi008/hRACo0ePxoQJE9CpU6dq62RlZSEoKKjaKSRbUFBQAC8vL+Tn58PT09MyBzm1Hfg8Bgh+ABifZJljEBERNSDG/H6bdLVUamoq3n//fTzxxBN3vRLJ19fXLJeM2yUOJiYiIpKMSael4uPj8fTTT98RbCorK7F//34AgJOTEwYOHFj/Cu2R4R43DDdERETWZlLPzYMPPogrV67cMf9Tfn4+HnzwwVrvGNwgGHpueKUUEZmPVqtFRUWF1GUQWYxSqYRcXv/JE0wKN0IIyGqY6fr69etwc3Ord1F2j6eliMiMhBDIyspCXl6e1KUQWZRcLkfLli3rfXsXo8LNE088AQCQyWQYO3ZstdNSWq0Wv/32G/r06VOvghwCp14gIjOqCjZ+fn5wdXWt8R+XRPZOp9Ph8uXLuHLlCpo3b16vv3Ojwo2XlxcA/b8iPDw84OLiYvhMqVTigQcewMSJE00uxiHotEDhzcvX2XNDRPWk1WoNwaZx48ZSl0NkUU2aNMHly5dRWVkJZ2dnk/djVLj5+OOPAQAhISF46aWXeAqqJsXXAF0lIJMD7v73Xp+IqBZVY2xcXV0lroTI8qpOR2m1WuuFmyp3m6SS8NcpKfcAQGFS8xIR3YGnoqghMNffeZ1/fXv06IHk5GT4+Pige/futRZw9OhRsxRnlziYmIiISFJ1DjfDhw83DCB+/PHHLVWP/WO4ISIikpZoYPLz8wUAkZ+fb5kDfDNPiHhPIRJftcz+iahBKS0tFampqaK0tFTqUizi4MGDolOnTsLJyUkMHz78rssc0dy5c8XEiRON2saS7VVeXi5atGghDh8+bPI+6qu2v3djfr/rf6ccqo49N0REdRYXF4du3bohIyMDGzZsuOsyawsJCcG+fftM2nbBggWQyWSQyWRwcnJCSEgIZs6ciaKiIsM6WVlZeO+99zBnzhyj9m3J9lIqlXjppZfw6quvmrwPW1Hn01I+Pj51HuiTm5trckF2j+GGiKjOzp49i0mTJqFZs2a1LjOWRqMx6UZwpm53u/vvvx/fffcdKisrkZKSgmeffRYlJSX417/+BQD46KOP0KdPH7Ro0cKo/Vq6vUaNGoUXX3wRp06dwv3332/y/qRW556b5cuX4913363To0GrulrKy/Q/MiKi2gghUKKplOQhhKhzneXl5Zg2bRr8/PygVqvRr18/HD58GABw/vx5yGQyXL9+Hc8++yxkMhk2bNhQ4zIAOHnyJB5++
GG4u7vD398fo0ePRk5OjuFYgwYNwpQpUzBjxgz4+voiMjKyXtvdSqPRYMqUKQgMDIRarUaLFi2QkJBQ63d3cnJCQEAAmjVrhujoaIwaNQo7duwwfL5lyxZERUXZXHv5+Pigb9++2LJlS63fz9bVuecmJibGknU4BiHYc0NEFldaoUXH+UmSHDt1USRclXX76XjllVfw5ZdfYuPGjWjRogWWLFmCyMhIpKenIzg4GFeuXEH79u2xaNEiREdHw8PDA0OHDq22zMvLC3l5eXjooYcwYcIEvPvuuygtLcWrr76KESNGYM+ePYbjbdy4EZMnT0ZKSgoAmLzd7VasWIEdO3bgs88+Q/PmzXHx4kVcvHjRqHZzcXGBRqMBoD+7kZqaitDQUJtqryq9e/fGgQMHjPp+tqbO4aagoACenp6G17WpWq/BKbkOaDUAZPr73BARNVDFxcVYvXo1NmzYgIcffhgAsHbtWnz77bdYt24dXn75ZQQEBEAmk8HLywsBAfr/z3Rzc7tj2dKlS9G9e3e8+eabhv2vX78ewcHB+P3339GuXTsAQNu2bbFkyRLDOv/85z9N2g7Q95RUyczMRNu2bdGvXz/IZDKjTyUdOXIEmzdvxkMPPWTYnxACQUF//SPYFtqrSlBQEC5cuGDUd7Q1Ro25qZoJ3Nvbu8bxN+LmhJoNdlZwww38/ACn+p+zJSKqiYuzAqmL7jx9Yq1j18XZs2dRUVGBvn37GpY5Ozujd+/eSEtLM+qYx48fx969e+Hu7l7jcap+rHv27GmW7W43duxY/O1vf0P79u0xdOhQPProoxgyZEit25w4cQLu7u7QarXQaDR45JFHsHLlSgBAaWkpAECtVlerR+r2quLi4oKSkhKjjmlr6hxu9uzZg0aNGgEA9u7da7GC7BpPSRGRFchksjqfGnIERUVFiIqKwltvvXXHZ4GBgYbXt08JZOp2t+vRowcyMjLwv//9D9999x1GjBiBiIgIfPHFF3fdpn379tixYwecnJwQFBRUbZCyr68vAODGjRto0qRJrcc2RX2/d25urkXqsqY6/9cxcODAGl/TLTgbOBERAKB169ZQKpVISUkxnMapqKjA4cOHMWPGDKP21aNHD3z55ZcICQmBk1PdQ52p29XE09MT0dHRiI6OxlNPPYWhQ4ciNzfX8I/+2ymVSrRp06bGz1q3bg1PT0+kpqYaelFsob2qnDx5Et27dzd6O1ti8n1ubty4gXfeeQfjx4/H+PHjsXTp0oZ9CTjAnhsiopvc3NwwefJkvPzyy9i9ezdSU1MxceJElJSUYPz48UbtKzY2Frm5uRg5ciQOHz6Ms2fPIikpCePGjat1GISp291u2bJl+PTTT3H69Gn8/vvv+PzzzxEQEABvb2+jvkcVuVyOiIgIHDx40LDMFtqryoEDB+552s3WmRRu9u/fj5CQEKxYsQI3btzAjRs3sGLFCrRs2RL79+83d432I7+q54bhhoho8eLFePLJJzF69Gj06NED6enpSEpKgo+Pj1H7CQoKQkpKCrRaLYYMGYLOnTtjxowZ8Pb2hlx+958xU7e7nYeHB5YsWYLQ0FD06tUL58+fR2JiolH7uN2ECROwZcsW6HQ6wzKp2wsADh06hPz8fDz11FMmfS9bIRPG3LTgps6dOyM8PByrV6+GQqEfXKbVavHCCy/ghx9+wIkTJ8xeqLkUFBTAy8sL+fn55r+qa8OjwPkDwBNrgS4jzLtvImqQysrKkJGRgZYtW1YbgEr2TQiBsLAwzJw5EyNHjpS6HIPo6Gh07doVr732miTHr+3v3Zjfb5NiZ3p6Ol588UVDsAEAhUKBuLg4pKenm7JLx8DTUkREVAcymQwffvghKisrpS7FQKPRoHPnzpg5c6bUpdSbSSOsevTogbS0NLRv377a8rS0NHTt2tUshdkd3sCPiIiM0K1bN3Tr1k3qMgyUSiXmzp0rdRlmUedw89tvvxleT5s2DdOnT0d6ejoeeOABAMCPP
/6IVatWYfHixeav0h6U3gAq9fcugAfDDRERkVTqHG66desGmUxWbV6RV1555Y71nnnmGURHR5unOntS1Wvj2hhw5nlxIiIiqdQ53GRkZFiyDvtXlg+ovXhKioiISGJ1DjfGzqXR4IT0BWZlApUaqSshIiJq0Op1y8bU1FRkZmYaZjqt8thjj9WrKLvGOaWIiIgkZVK4OXfuHP7+97/jxIkT1cbhVE2m2WAnziQiIiLJmXSfm+nTp6Nly5a4evUqXF1dcerUKezfvx+hoaHYt2+fmUskIiKqu7Fjx+Lxxx+XugyzGTBgADZv3ix1GfWWmpqKZs2aobi42OLHMincHDp0CIsWLYKvry/kcjnkcjn69euHhIQETJs2zdw1EhER1dl7772HDRs2WPw458+fN5yxMMWgQYMgk8kgk8mgVqvRsWNHfPDBB9XW2bFjB7Kzs/GPf/yjvuVa1aBBg+6Y8LNjx4544IEHsGzZMosf36Rwo9Vq4eHhAUA/dfvly/rLoFu0aIEzZ86YrzoiInIYt4/PtBQvLy+TJ7Wsq4qKCrPsZ+LEibhy5QpSU1MxYsQIxMbG4tNPPzV8vmLFCowbN65e81jZknHjxmH16tUWvzOzSa3VqVMnHD9+HAAQFhaGJUuWICUlBYsWLUKrVq3MWiAREdmnQYMGYcqUKZgxYwZ8fX0RGRkJADh58iQefvhhuLu7w9/fH6NHj0ZOTo5hO51OhyVLlqBNmzZQqVRo3rw53njjDcPnFy9exIgRI+Dt7Y1GjRph+PDhOH/+vOHzW09LffjhhwgKCqo2QSUADB8+HM8++6zh/ddff40ePXpArVajVatWWLhwYbUfYJlMhtWrV+Oxxx6Dm5tbtXqqXLhwAVFRUfDx8YGbmxvuv/9+JCYm1tpGrq6uCAgIQKtWrbBgwQK0bdsWO3bsAABcu3YNe/bsQVRUVLVtli1bhs6dO8PNzQ3BwcF44YUXUFRUVG2dlJQUDBo0CK6urvDx8UFkZCRu3Lhh1vZduHAhmjRpAk9PT0yaNMkQXseOHYvvv/8e7733nqFnqmr7v/3tb8jNzcX3339fa7vUl0nhZu7cuYY/lEWLFiEjIwP9+/dHYmIiVqxYYdYCiYjoNkIAmmJpHkbOtbxx40YolUqkpKRgzZo1yMvLw0MPPYTu3bvjl19+we7du5GdnY0RI/6abHj27NlYvHgx5s2bh9TUVGzevBn+/v4A9D0mkZGR8PDwwIEDB5CSkgJ3d3cMHTq0xp6hp59+GtevX8fevXsNy3Jzc7F7926MGjUKAHDgwAGMGTMG06dPR2pqKv71r39hw4YNdwSYBQsWGC6muTUYVYmNjUV5eTn279+PEydO4K233oK7u7tR7eXi4mL4HgcPHoSrqys6dOhQbR25XI4VK1bg1KlT2LhxI/bs2VPtprrHjh3D4MGD0bFjRxw6dAgHDx5EVFSU4WIfc7RvcnIy0tLSsG/fPnz66afYtm0bFi5cCEB/WjA8PNzQK3XlyhUEBwcD0E/x0K1bNxw4cMCodjGaMJPr168LnU5nrt1ZTH5+vgAg8vPzpS6FiOieSktLRWpqqigtLf1rYXmREPGe0jzKi+pc+8CBA0X37t2rLXv99dfFkCFDqi27ePGiACDOnDkjCgoKhEqlEmvXrq1xn//+979F+/btq/3elJeXCxcXF5GUlCSEECImJkYMHz7c8Pnw4cPFs88+a3j/r3/9SwQFBQmtViuEEGLw4MHizTffvOM4gYGBhvcAxIwZM2r9vp07dxYLFiyodZ1bDRw4UEyfPl0IIURlZaX497//LQCIlStXCiGEePfdd0WrVq3uuZ/PP/9cNG7c2PB+5MiRom/fvjWua672bdSokSguLjass3r1auHu7m5o01u/2+3+/ve/i7Fjx9b4WY1/7zcZ8/tdr/vcAPruKwCGVEZERFSlZ8+e1d4fP34ce/furbFH4+zZs8jLy0N5eTkGDx5c4/6OHz+O9PR0w7jPKmVlZ
Th79myN24waNQoTJ07EBx98AJVKhU2bNuEf//iHYRzL8ePHkZKSUq2nRqvVoqysDCUlJXB1dQUAhIaG1vpdp02bhsmTJ+Obb75BREQEnnzySXTp0qXWbT744AN89NFH0Gg0UCgUmDlzJiZPngwAKC0thVp953Q+3333HRISEnD69GkUFBSgsrKyWq3Hjh3D008/XePx0tLSzNK+Xbt2NbQLAISHh6OoqAgXL168501/XVxcUFJSUus69WVSuKmsrMTChQuxYsUKw3k+d3d3TJ06FfHx8XB2djZrkUREdAtnV+C1y9Id2whubm7V3hcVFSEqKgpvvfXWHesGBgbi3Llzte6vqKgIPXv2xKZNm+74rEmTJjVuExUVBSEEdu3ahV69euHAgQN49913q+1z4cKFeOKJJ+7Y9tZwcft3ud2ECRMQGRmJXbt24ZtvvkFCQgKWLl2KqVOn3nWbUaNGYc6cOXBxcUFgYGC1gcO+vr6GcTJVzp8/j0cffRSTJ0/GG2+8gUaNGuHgwYMYP348NBoNXF1d4eLictfj1fYZYFr7Gis3NxetW7c2y77uxqRwM3XqVGzbtg1LlixBeHg4AP3l4QsWLMD169exevVqsxZJRES3kMkAZe0/tLaqR48e+PLLLxESEgInpzt/gtq2bQsXFxckJydjwoQJNW6/detW+Pn5wdPTs07HVKvVeOKJJ7Bp0yakp6ejffv26NGjR7V9njlzBm3atDH9i90UHByMSZMmYdKkSZg9ezbWrl1ba7jx8vK663G7d++OrKws3LhxAz4+PgCAI0eOQKfTYenSpYYg9Nlnn1XbrkuXLkhOTjaMgbmVudr3+PHjKC0tNYSlH3/8Ee7u7tXG1tzthr4nT57EU089ddd9m4NJA4o3b96MDRs24Pnnn0eXLl3QpUsXPP/881i3bp1D3GiIiIgsIzY2Frm5uRg5ciQOHz6Ms2fPIikpCePGjYNWq4Varcarr76KV155BZ988gnOnj2LH3/8EevWrQOg7+nw9fXF8OHDceDAAWRkZGDfvn2YNm0a/vzzz7sed9SoUdi1axfWr19vGEhcZf78+fjkk0+wcOFCnDp1CmlpadiyZQvmzp1r1HebMWMGkpKSkJGRgaNHj2Lv3r13DAY2Rvfu3eHr64uUlBTDsjZt2qCiogLvv/8+zp07h3//+99Ys2ZNte1mz56Nw4cP44UXXsBvv/2G06dPY/Xq1cjJyTFb+2o0GowfPx6pqalITExEfHw8pkyZYghcISEh+Omnn3D+/Hnk5OQYLkI6f/48Ll26hIiICJPbpS5MCjcqlQohISF3LG/ZsiWUSs6tRERENQsKCkJKSgq0Wi2GDBmCzp07Y8aMGfD29jb8MM6bNw8vvvgi5s+fjw4dOiA6OhpXr14FoL90ev/+/WjevDmeeOIJdOjQAePHj0dZWVmtPQ0PPfQQGjVqhDNnzuCZZ56p9llkZCR27tyJb775Br169cIDDzyAd9991+gJo7VaLWJjY9GhQwcMHToU7dq1u+OmfMZQKBQYN25ctVNEXbt2xbJly/DWW2+hU6dO2LRpExISEqpt165dO3zzzTc4fvw4evfujfDwcHz99deGnjJztO/gwYPRtm1bDBgwANHR0XjsscewYMECw+cvvfQSFAoFOnbsiCZNmiAzMxMA8Omnn2LIkCEWn4xbJoSR1/VBf/n36dOn8fHHH0OlUgEAysvLMX78eLRt2xbx8fFmL9RcCgoK4OXlhfz8/Dp3aRIRSaWsrAwZGRlo2bJljYNLybFlZWXh/vvvx9GjRy0eCOpq7NixyMvLw/bt243aTqPRoG3btti8eTP69u1b4zq1/b0b8/td5zE3tw+0+u6779CsWTN07doVgP78m0ajuesIbCIiIjJOQEAA1q1bh8zMTJsJN6bKzMzEa6+9dtdgY051DjdeXl7V3j/55JPV3vNScCIiIvNzlElA27RpY5ZB23VR53Dz8ccfW7IOIiIisgPWmJS0v
up1E79r164ZJsps37692a6BJyIiIjKVSVdLFRcX49lnn0VgYCAGDBiAAQMGICgoCOPHj7f4XQeJiBoiE679ILI75vo7NyncxMXF4fvvv8d///tf5OXlIS8vD19//TW+//57vPjii2YpjIiIYLjjO//hSA1B1eScCoWiXvsx6bTUl19+iS+++AKDBg0yLBs2bBhcXFwwYsQI3qGYiMhMFAoFvL29q92HRCaTSVwVkfnpdDpcu3YNrq6uNd692hgmbV1SUmKYHv1Wfn5+/NcFEZGZBQQEAIAh4BA5KrlcjubNm9c7wJt0E7/BgwejcePG+OSTTww32SktLUVMTAxyc3Px3Xff1asoS+JN/IjIXmm1WlRUVEhdBpHFKJXKapOH3soiN/G71fLlyzF06NA7buKnVquRlJRkyi6JiOgeFApFvcciEDUEJvXcAPpTU5s2bcLp06cBAB06dMCoUaPuOZ261NhzQ0REZH8s2nNTUVGB++67Dzt37sTEiRNNLpKIiIjIEoy+FNzZ2RllZWWWqIWIiIio3ky6z01sbCzeeustVFZWmrseIiIionoxKdwcPnwY27ZtQ/PmzREZGYknnnii2sNYq1atQkhICNRqNcLCwvDzzz/XabstW7ZAJpM5zKRiREREVH8mXS3l7e19x6zgptq6dSvi4uKwZs0ahIWFYfny5YiMjMSZM2fg5+d31+3Onz+Pl156Cf379zdLHUREROQYjLpaSqfT4e2338aOHTug0Wjw0EMPYcGCBfW6QiosLAy9evXCypUrDccIDg7G1KlTMWvWrBq30Wq1GDBgAJ599lkcOHAAeXl52L59e52Ox6uliIiI7I8xv99GnZZ644038Nprr8Hd3R1NmzbFihUrEBsba3KhGo0GR44cQURExF8FyeWIiIjAoUOH7rrdokWL4Ofnh/Hjx9/zGOXl5SgoKKj2ICIiIsdlVLj55JNP8MEHHyApKQnbt2/Hf//7X2zatAk6nc6kg+fk5ECr1d4xlYO/vz+ysrJq3ObgwYNYt24d1q5dW6djJCQkwMvLy/AIDg42qVYiIiKyD0aFm8zMTAwbNszwPiIiAjKZDJcvXzZ7YTUpLCzE6NGjsXbtWvj6+tZpm9mzZyM/P9/wuHjxooWrJCIiIikZNaC4srLSMJdUFWdnZ5PnOvH19YVCoUB2dna15dnZ2YaJ4m519uxZnD9/HlFRUYZlVb1GTk5OOHPmDFq3bl1tG5VKBZVKZVJ9REREZH+MCjdCCIwdO7ZaWCgrK8OkSZPg5uZmWLZt27Y67U+pVKJnz55ITk42XM6t0+mQnJyMKVOm3LH+fffdhxMnTlRbNnfuXBQWFuK9997jKSciIiIyLtzExMTcsez//u//6lVAXFwcYmJiEBoait69e2P58uUoLi7GuHHjAABjxoxB06ZNkZCQALVajU6dOlXb3tvbGwDuWE5EREQNk1Hh5uOPPzZ7AdHR0bh27Rrmz5+PrKwsdOvWDbt37zYMMs7MzLzr9OdEREREtzN5VnB7xfvcEBER2R+L3eeGiIiIyNYx3BAREZFDYbghIiIih8JwQ0RERA6F4YaIiIgcCsMNERERORSGGyIiInIoDDdERETkUBhuiIiIyKEw3BAREZFDYbghIiIih8JwQ0RERA6F4YaIiIgcCsMNERERORSGGyIiInIoDDdERETkUBhuiIiIyKEw3BAREZFDYbghIiIih8JwQ0RERA6F4YaIiIgcCsMNERERORSGGyIiInIoDDdERETkUBhuiIiIyKEw3BAREZFDYbghIiIih8JwQ0RERA6F4YaIiIgcCsMNERERORSGGyIiInIoDDdERETkUBhuiIiIyKEw3BAREZFDYbghIiIih8JwQ0RERA6F4YaIiIgcCsMNERERORSGGyIiInIoDDdERETkUBhuiIiIyKEw3BAREZFDYbghIiIih8JwQ0RERA6F4YaIiIgcCsMNERERORSGGyIiInIoDDdERETkUBhuiIiIy
KEw3BAREZFDYbghIiIih8JwQ0RERA6F4YaIiIgcCsMNERERORSGGyIiInIoDDdERETkUGwi3KxatQohISFQq9UICwvDzz//fNd1165di/79+8PHxwc+Pj6IiIiodX0iIiJqWCQPN1u3bkVcXBzi4+Nx9OhRdO3aFZGRkbh69WqN6+/btw8jR47E3r17cejQIQQHB2PIkCG4dOmSlSsnIiIiWyQTQggpCwgLC0OvXr2wcuVKAIBOp0NwcDCmTp2KWbNm3XN7rVYLHx8frFy5EmPGjLnn+gUFBfDy8kJ+fj48PT3rXT8RERFZnjG/35L23Gg0Ghw5cgQRERGGZXK5HBERETh06FCd9lFSUoKKigo0atSoxs/Ly8tRUFBQ7UFERESOS9Jwk5OTA61WC39//2rL/f39kZWVVad9vPrqqwgKCqoWkG6VkJAALy8vwyM4OLjedRMREZHtknzMTX0sXrwYW7ZswVdffQW1Wl3jOrNnz0Z+fr7hcfHiRStXSURERNbkJOXBfX19oVAokJ2dXW15dnY2AgICat32nXfeweLFi/Hdd9+hS5cud11PpVJBpVKZpV4iIiKyfZL23CiVSvTs2RPJycmGZTqdDsnJyQgPD7/rdkuWLMHrr7+O3bt3IzQ01BqlEhERkZ2QtOcGAOLi4hATE4PQ0FD07t0by5cvR3FxMcaNGwcAGDNmDJo2bYqEhAQAwFtvvYX58+dj8+bNCAkJMYzNcXd3h7u7u2Tfg4iIiGyD5OEmOjoa165dw/z585GVlYVu3bph9+7dhkHGmZmZkMv/6mBavXo1NBoNnnrqqWr7iY+Px4IFC6xZuk0TQiAjpxgarQ73BfCSdyIiajgkv8+NtTnqfW6EEPgpIxeHzl7HsYt5OHYxD/mlFQCA/4wPQ7+2vhJXSEREZDpjfr8l77kh8zjwRw7GrK95Goo1359luCEiogaD4cZBJKfprzjr0swLT/Vshm7B3vBQO2Pw0n04mJ6D01kFPD1FREQNgl3f54b+cjA9BwDwwqA2GBMegi7NvNHS1w1DO+kvqV9/MEPK8oiIiKyG4cYBXMkvxdlrxZDLgPBWjat9Nr5fSwDA9mOXkVNULkV5REREVsVw4wBS0q8DADo384aXq3O1z3o090HXYG9oKnX4z48XpCiPiIjIqhhuHEDKzVNSfVs3vuMzmUxm6L35z48XUFahtWptRERE1sZwY+eEEIbxNv3a1HxF1MOdAhDopUZOkQY7jl+2ZnlERERWx3Bj59KvFuFaYTlUTnL0aOFT4zrOCjli+oQA0A8sbmC3NiIiogaG4cbOVfXa9G7ZCGpnxV3XG9mrOVycFTidVYgfzl63VnlERERWx3Bj5wzjbe5ySqqKl6szng5tBoCXhRMRkWNjuLFjFVodfjyXC+Du421uVXVqat/v15BbrLFkaURERJJhuLFjv/2Zh6LySni7OqNj4L3vPty6iTvuD/KEView+2SWFSokIiKyPoYbO3bwD/3YmT6tG0Mul9Vpm0e7BAEAdv7Gq6aIiMgxMdzYsbqOt7nVo10CAQA/nruOa4W8YzERETkehhs7VVxeiV8v3gBQt/E2VYIbuaJrsDd0AvjfySuWKo+IiEgyDDd26ufzuajQCjTzcUHzRq5GbRt1s/dm53GGGyIicjwMN3Yq5Y+/7kosk9VtvE2VYZ314ebwhVxk5ZeZvTYiIiIpMdzYqZSbN+IzZrxNlSBvF4S28IEQQOIJ9t4QEZFjYbixQ6UaLc5kFQAAeoU0Mmkfj1SdmuJVU0RE5GAYbuzQmexC6ATQ2E0Jf0+VSfsY1jkQMhlwNDMPl/JKzVwhERGRdBhu7NCpy/kAgI5BnkaPt6ni76lG75u9PrvYe0NERA6E4cYOpV7Wn5LqGHTvuxLX5tGuVTf047gbIiJyHAw3dujUzXBzf5BXvfbzcKcAyGXAb3/m48L1YnOURkREJDmGGzuj1QmcvjmYuC7zSdXG112FPq31V1t9fYynpoiIy
DEw3NiZjJwilFXo4OKsQEtft3rv78meTQEAn/1yETqdqPf+iIiIpMZwY2eqTkndF+gBRR0ny6zNw50C4aF2wp83SnHo3PV674+IiEhqDDd2JtUw3qZ+p6SqqJ0VGN5NP7B46+GLZtknERGRlBhu7EzqlarxNvUbTHyr6NDmAIDdp7KQV6Ix236JiIikwHBjR4QQt1wpZZ6eGwDo1NQTHQI9oanUYfuvl8y2XyIiIikw3NiR7IJy5BZroJDL0D7Aw2z7lclkiA5tBgDY+sufEIIDi4mIyH4x3NiRqjsTt27iBrWzwqz7frx7Uyid5Ei7UoCTlwrMum8iIiJrYrixI6lmunlfTbxdlYi8PwAAsPWXTLPvn4iIyFoYbuxI1Xib+t68726iQ4MB6G/oV1ahtcgxiIiILI3hxo5UXSllzsHEt+rTujGa+bigsKwS/zvJ+aaIiMg+MdzYiYKyCmTmlgCo/4SZdyOXy/B0T33vzZafec8bIiKyTww3diLt5imppt4u8HZVWuw4T4c2g0Iuw08ZuThyIddixyEiIrIUhhs7UTXepoOFxttUCfJ2wdM99ZeFL9l9hpeFExGR3WG4sROWHm9zq6mD20KpkOOnjFwcTM+x+PGIiIjMieHGThiulLJCuGnq7YJRD+inZHg7ib03RERkXxhu7ICmUof0q4UArNNzAwAvDGoDV6UCv/2Zj29Ss61yTCIiInNguLEDv2cXokIr4OXijKbeLlY5ZhMPFZ7t2xIAsPSbM9Dq2HtDRET2geHGDlRNu9Ax0BMymcxqx504oBU81U74PbsIO45zQk0iIrIPDDd2YN+ZawCAni18rHpcLxdnPD+wNQDg3W//QIVWZ9XjExERmYLhxsaVVWgN4aZq7idrGtc3BL7uSmTmlmBDynmrH5+IiMhYDDc2bv/v11BaoUVTbxd0amqdwcS3clU6Yebf2gEAliSdxtHMG1avgYiIyBgMNzYu6ZT+SqUh9/tbdbzNrZ7p3RyPdA5EhVYgdtNRXC8ql6QOIiKiumC4sWEVWh2+S9OHm6ESnJKqIpPJsPjJzmjl64Yr+WWYsfUYr54iIiKbxXBjw37OyEV+aQUauykRGtJI0lo81M5Y/X894eKswIE/cvBe8h+S1kNERHQ3DDc2LOlUFgAgooM/FHJpTkndqn2AB958ohMA4P09f2DfmasSV0RERHQnhhsbpdMJQ7gZ2km6U1K3+3v3ZhgV1hxCAFM//RU/nOXcU0REZFsYbmzU8T/zkF1QDneVE/q0aSx1OdXMj+qIXiE+KCyrxJh1P+PTnzOlLomIiMiA4cZG7b7Za/PgfX5QOSkkrqY6lZMC/x4fhse6BqFSJzB72wm8vjOVg4yJiMgmMNzYICEEkk7qw03k/f4SV1MztbMC7/2jG+Ju3gNn3cEMTPzkF+SXVkhcGRERNXQMNzbo9+winL9eAqWTHIPa+0ldzl3JZDJMG9wWK5/pDpWTHHtOX8XAt/di7f5zKKvQSl0eERE1UAw3NqhqIHH/Nr5wVzlJXM29PdolCJ89H47WTdyQV1KBNxLTMOjtfdj8UybnoyIiIqtjuLExJZpK/Pf4ZQDSzCVlqq7B3kiaMQBLnuqCIC81sgrK8NpXJzDo7X1Y/L/T+O3PPAjBMTlERGR5NhFuVq1ahZCQEKjVaoSFheHnn3+udf3PP/8c9913H9RqNTp37ozExEQrVWpZp7MKEPX+QfxxtQguzgpEdLTN8TZ346SQY0RoMPa8NAjzHu2IRm5KXMorxZrvz+KxlSno99Ze/HNnKvaczsbVwjKpyyUiIgclExL/c3rr1q0YM2YM1qxZg7CwMCxfvhyff/45zpw5Az+/O8eb/PDDDxgwYAASEhLw6KOPYvPmzXjrrbdw9OhRdOrU6Z7HKygogJeXF/Lz8+Hpaf2JKGsihMCWwxexYMcplFfq4OehwoqR3fFAK9u6BNxYJZpK7Dl9Ff87mYU9aVdRets4HH9PFTo39ULHQE80a
+SKZt4uaOrjggAvtc1dIUZERNIy5vdb8nATFhaGXr16YeXKlQAAnU6H4OBgTJ06FbNmzbpj/ejoaBQXF2Pnzp2GZQ888AC6deuGNWvW3PN4lgo35ZVaXCu8c0LJmlq3rEKLwvJKFJZVorCsArtPZmHnb1cAAIPaN8HSp7uisbvKbLXZglKNFt//fg3fpWXj+MU8nL1WhLtdOS6TAV4uzvBxVcLb9eazizNcVQq4Kp3gqlTAVamA2lkBpUIOpdPNh0IOJ4UMCrkcTnIZFDcfclnVMyCX6d/Lbr6WyQDZzWMCt77X3xH6r88AGWR31HkvEs11SkQkKaWTHH4earPu05jfb0lHq2o0Ghw5cgSzZ882LJPL5YiIiMChQ4dq3ObQoUOIi4urtiwyMhLbt2+vcf3y8nKUl/8VOgoKCupfeA1OXS7AEx/8YPL2TnIZXo5sj4n9W0FuA1MtmJuLUoGhnQIMd1su0VQi9XIBTlzKx+/ZRbiUV4pLN0pwKa8UZRU65JVUIK+El5UTEdmjHs29se2FvpIdX9Jwk5OTA61WC3//6mNL/P39cfr06Rq3ycrKqnH9rKysGtdPSEjAwoULzVNwLWQAVE41D2G6/V/+Kmc5PNROcFc5w0PtBF93JSb0b4UezX0sXqetcFU6ITSk0R0TggohkFuswfViDW4Ua3CjpAJ5JRrkl1agRKNFiaby5rMW5ZVaaCp1KK/UQVOpg0arg1YnUKkV+med/r1OAFqdgBACWiEgBG72Guk/E0JAQN/LVvVaXwwMr2/v4Kyp0+n2XjpR41r3xnHXRGTvnBXSDum1/euM62n27NnVenoKCgoQHBxs9uN0b+6DM/982Oz7bWhkMhkau6sc7rQcERFZj6ThxtfXFwqFAtnZ2dWWZ2dnIyCg5sugAwICjFpfpVJBpeIPJRERUUMhab+RUqlEz549kZycbFim0+mQnJyM8PDwGrcJDw+vtj4AfPvtt3ddn4iIiBoWyU9LxcXFISYmBqGhoejduzeWL1+O4uJijBs3DgAwZswYNG3aFAkJCQCA6dOnY+DAgVi6dCkeeeQRbNmyBb/88gs+/PBDKb8GERER2QjJw010dDSuXbuG+fPnIysrC926dcPu3bsNg4YzMzMhl//VwdSnTx9s3rwZc+fOxWuvvYa2bdti+/btdbrHDRERETk+ye9zY222eBM/IiIiqp0xv982Mf0CERERkbkw3BAREZFDYbghIiIih8JwQ0RERA6F4YaIiIgcCsMNERERORSGGyIiInIoDDdERETkUBhuiIiIyKFIPv2CtVXdkLmgoEDiSoiIiKiuqn636zKxQoMLN4WFhQCA4OBgiSshIiIiYxUWFsLLy6vWdRrc3FI6nQ6XL1+Gh4cHZDKZWfddUFCA4OBgXLx4kfNWWRDb2TrYztbBdrYetrV1WKqdhRAoLCxEUFBQtQm1a9Lgem7kcjmaNWtm0WN4enryPxwrYDtbB9vZOtjO1sO2tg5LtPO9emyqcEAxERERORSGGyIiInIoDDdmpFKpEB8fD5VKJXUpDo3tbB1sZ+tgO1sP29o6bKGdG9yAYiIiInJs7LkhIiIih8JwQ0RERA6F4YaIiIgcCsMNERERORSGGyOtWrUKISEhUKvVCAsLw88//1zr+p9//jnuu+8+qNVqdO7cGYmJiVaq1L4Z085r165F//794ePjAx8fH0RERNzzfxfSM/bvucqWLVsgk8nw+OOPW7ZAB2FsO+fl5SE2NhaBgYFQqVRo164d/7+jDoxt5+XLl6N9+/ZwcXFBcHAwZs6cibKyMitVa5/279+PqKgoBAUFQSaTYfv27ffcZt++fejRowdUKhXatGmDDRs2WLxOCKqzLVu2CKVSKdavXy9OnTolJk6cKLy9vUV2dnaN66ekpAiFQiGWLFkiUlNTxdy5c4Wzs7M4ceKElSu3L8a28zPPPCNWrVolfv31V5GWlibGjh0rvLy8xJ9//mnlyu2Lse1cJSMjQzRt2
lT0799fDB8+3DrF2jFj27m8vFyEhoaKYcOGiYMHD4qMjAyxb98+cezYMStXbl+MbedNmzYJlUolNm3aJDIyMkRSUpIIDAwUM2fOtHLl9iUxMVHMmTNHbNu2TQAQX331Va3rnzt3Tri6uoq4uDiRmpoq3n//faFQKMTu3bstWifDjRF69+4tYmNjDe+1Wq0ICgoSCQkJNa4/YsQI8cgjj1RbFhYWJp5//nmL1mnvjG3n21VWVgoPDw+xceNGS5XoEExp58rKStGnTx/x0UcfiZiYGIabOjC2nVevXi1atWolNBqNtUp0CMa2c2xsrHjooYeqLYuLixN9+/a1aJ2OpC7h5pVXXhH3339/tWXR0dEiMjLSgpUJwdNSdaTRaHDkyBFEREQYlsnlckRERODQoUM1bnPo0KFq6wNAZGTkXdcn09r5diUlJaioqECjRo0sVabdM7WdFy1aBD8/P4wfP94aZdo9U9p5x44dCA8PR2xsLPz9/dGpUye8+eab0Gq11irb7pjSzn369MGRI0cMp67OnTuHxMREDBs2zCo1NxRS/Q42uIkzTZWTkwOtVgt/f/9qy/39/XH69Okat8nKyqpx/aysLIvVae9MaefbvfrqqwgKCrrjPyj6iyntfPDgQaxbtw7Hjh2zQoWOwZR2PnfuHPbs2YNRo0YhMTER6enpeOGFF1BRUYH4+HhrlG13TGnnZ555Bjk5OejXrx+EEKisrMSkSZPw2muvWaPkBuNuv4MFBQUoLS2Fi4uLRY7LnhtyKIsXL8aWLVvw1VdfQa1WS12OwygsLMTo0aOxdu1a+Pr6Sl2OQ9PpdPDz88OHH36Inj17Ijo6GnPmzMGaNWukLs2h7Nu3D2+++SY++OADHD16FNu2bcOuXbvw+uuvS10amQF7burI19cXCoUC2dnZ1ZZnZ2cjICCgxm0CAgKMWp9Ma+cq77zzDhYvXozvvvsOXbp0sWSZds/Ydj579izOnz+PqKgowzKdTgcAcHJywpkzZ9C6dWvLFm2HTPl7DgwMhLOzMxQKhWFZhw4dkJWVBY1GA6VSadGa7ZEp7Txv3jyMHj0aEyZMAAB07twZxcXFeO655zBnzhzI5fy3vznc7XfQ09PTYr02AHtu6kypVKJnz55ITk42LNPpdEhOTkZ4eHiN24SHh1dbHwC+/fbbu65PprUzACxZsgSvv/46du/ejdDQUGuUateMbef77rsPJ06cwLFjxwyPxx57DA8++CCOHTuG4OBga5ZvN0z5e+7bty/S09MN4REAfv/9dwQGBjLY3IUp7VxSUnJHgKkKlIJTLpqNZL+DFh2u7GC2bNkiVCqV2LBhg0hNTRXPPfec8Pb2FllZWUIIIUaPHi1mzZplWD8lJUU4OTmJd955R6SlpYn4+HheCl4Hxrbz4sWLhVKpFF988YW4cuWK4VFYWCjVV7ALxrbz7Xi1VN0Y286ZmZnCw8NDTJkyRZw5c0bs3LlT+Pn5iX/+859SfQW7YGw7x8fHCw8PD/Hpp5+Kc+fOiW+++Ua0bt1ajBgxQqqvYBcKCwvFr7/+Kn799VcBQCxbtkz8+uuv4sKFC0IIIWbNmiVGjx5tWL/qUvCXX35ZpKWliVWrVvFScFv0/vvvi+bNmwulUil69+4tfvzxR8NnAwcOFDExMdXW/+yzz0S7du2EUqkU999/v9i1a5eVK7ZPxrRzixYtBIA7HvHx8dYv3M4Y+/d8K4abujO2nX/44QcRFhYmVCqVaNWqlXjjjTdEZWWllau2P8a0c0VFhViwYIFo3bq1UKvVIjg4WLzwwgvixo0b1i/cjuzdu7fG/7+tatuYmBgxcODAO7bp1q2bUCqVolWrVuLjjz+2eJ0yIdj/RkRERI6DY26IiIjIoTDcEBERkUNhuCEiIiKHwnBDREREDoXhhoiIiBwKww0RERE5FIYbIiIicigMN0RERORQGG6IiEywYMECdOvWTeoyiKgGDDdEZFFjx46FTCaDTCaDUqlEmzZts
GjRIlRWVkpdGhE5KCepCyAixzd06FB8/PHHKC8vR2JiImJjY+Hs7IzZs2dXW0+j0XDmayKqN/bcEJHFqVQqBAQEoEWLFpg8eTIiIiKwY8cOjB07Fo8//jjeeOMNBAUFoX379gCAEydO4KGHHoKLiwsaN26M5557DkVFRYb9VW23cOFCNGnSBJ6enpg0aRI0Go1hnfLyckybNg1+fn5Qq9Xo168fDh8+bPj8xo0bGDVqFJo0aQIXFxe0bdsWH3/8seHzV199Fe3atYOrqytatWqFefPmoaKiwgqtRUT1xZ4bIrI6FxcXXL9+HQCQnJwMT09PfPvttwCA4uJiREZGIjw8HIcPH8bVq1cxYcIETJkyBRs2bDDsIzk5GWq1Gvv27cP58+cxbtw4NG7cGG+88QYA4JVXXsGXX36JjRs3okWLFliyZAkiIyORnp6ORo0aYd68eUhNTcX//vc/+Pr6Ij09HaWlpYb9e3h4YMOGDQgKCsKJEycwceJEeHh44JVXXrFeQxGRaSw+7zgRNWgxMTFi+PDhQgghdDqd+Pbbb4VKpRIvvfSSiImJEf7+/qK8vNyw/ocffih8fHxEUVGRYdmuXbuEXC4XWVlZhn02atRIFBcXG9ZZvXq1cHd3F1qtVhQVFQlnZ2exadMmw+cajUYEBQWJJUuWCCGEiIqKEuPGjavz93j77bdFz549De/j4+NF165djWoLIrIO9twQkcXt3LkT7u7uqKiogE6nwzPPPIMFCxYgNjYWnTt3rjbOJi0tDV27doWbm5thWd++faHT6XDmzBn4+/sDALp27QpXV1fDOuHh4SgqKsLFixeRn5+PiooK9O3b1/C5s7MzevfujbS0NADA5MmT8eSTT+Lo0aMYMmQIHn/8cfTp08ew/tatW7FixQqcPXsWRUVFqKyshKenp8XaiIjMh2NuiMjiHnzwQRw7dgx//PEHSktLsXHjRkN4uTXEWNPDDz+MCxcuYObMmbh8+TIGDx6Ml156CQBw6NAhjBo1CsOGDcPOnTvx66+/Ys6cOdXG9BCR7WK4ISKLc3NzQ5s2bdC8eXM4OdXeYdyhQwccP34cxcXFhmUpKSmQy+WGAccAcPz48WpjZH788Ue4u7sjODgYrVu3hlKpREpKiuHziooKHD58GB07djQsa9KkCWJiYvCf//wHy5cvx4cffggA+OGHH9CiRQvMmTMHoaGhaNu2LS5cuFDvdiAi62C4ISKbMmrUKKjVasTExODkyZPYu3cvpk6ditGjRxtOSQH6y8bHjx+P1NRUJCYmIj4+HlOmTIFcLoebmxsmT56Ml19+Gbt370ZqaiomTpyIkpISjB8/HgAwf/58fP3110hPT8epU6ewc+dOdOjQAQDQtm1bZGZmYsuWLTh79ixWrFiBr776SpL2ICLjccwNEdkUV1dXJCUlYfr06ejVqxdcXV3x5JNPYtmyZdXWGzx4MNq2bYsBAwagvLwcI0eOxIIFCwyfL168GDqdDqNHj0ZhYSFCQ0ORlJQEHx8fAIBSqcTs2bNx/vx5uLi4oH///tiyZQsA4LHHHsPMmTMxZcoUlJeX45FHHsG8efOq7Z+IbJdMCCGkLoKIyBhjx45FXl4etm/fLnUpRGSDeFqKiIiIHArDDRERETkUnpYiIiIih8KeGyIiInIoDDdERETkUBhuiIiIyKEw3BAREZFDYbghIiIih8JwQ0RERA6F4YaIiIgcCsMNEREROZT/B4DVxEzPS2NxAAAAAElFTkSuQmCC", 98 | "text/plain": [ 99 | "
" 100 | ] 101 | }, 102 | "metadata": {}, 103 | "output_type": "display_data" 104 | } 105 | ], 106 | "source": [ 107 | "from matplotlib import pyplot as plt\n", 108 | "\n", 109 | "plt.plot(Proposal, o, label=\"offerer's P(offer)\")\n", 110 | "plt.plot(Proposal, r[:, 0], label=\"receiver's P(accept)\")\n", 111 | "\n", 112 | "plt.xlabel('Proposal')\n", 113 | "plt.ylabel('Probability')\n", 114 | "\n", 115 | "plt.legend()" 116 | ] 117 | } 118 | ], 119 | "metadata": { 120 | "language_info": { 121 | "name": "python" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 5 126 | } 127 | -------------------------------------------------------------------------------- /demo/memo: -------------------------------------------------------------------------------- 1 | ../memo -------------------------------------------------------------------------------- /demo/test.py: -------------------------------------------------------------------------------- 1 | from memo import memo, memo_test, make_module 2 | 3 | mod = make_module('test_suite') 4 | mod.install(''' 5 | import jax 6 | import jax.numpy as np 7 | X = np.arange(3) 8 | Y = np.arange(3) 9 | N = 5 10 | 11 | @jax.jit 12 | def f(n): 13 | return n + 1 14 | 15 | Z = np.arange(1000) 16 | R = np.linspace(-10, 10, 1000) 17 | 18 | @jax.jit 19 | def normpdf(x, mu, sigma): 20 | return jax.scipy.stats.norm.pdf(x, mu, sigma) 21 | ''') 22 | 23 | @memo(install_module=mod.install) 24 | def test_[x: X](t): 25 | return x 26 | 27 | @memo_test(mod, expect='ce') 28 | def chooses_multiple(): 29 | bob: chooses(x in X, wpp=1) 30 | bob: chooses(x in X, wpp=1) 31 | return 1 32 | 33 | @memo_test(mod) 34 | def observes_call[x: X](): 35 | a: thinks[ b: chooses(x in X, wpp=1) ] 36 | a: observes [b.x] is x 37 | return a[ b[ test_[x](0) ] ] 38 | 39 | @memo_test(mod) 40 | def observes_other(): 41 | alice: thinks[ bob: chooses(x in X, wpp=1) ] 42 | charlie: chooses(x in X, wpp=1) 43 | alice: observes [bob.x] is charlie.x 44 | return E[alice[bob.x]] 45 | 
46 | @memo_test(mod) 47 | def observes_other_imagine(): 48 | alice: thinks[ bob: chooses(x in X, wpp=1) ] 49 | charlie: chooses(x in X, wpp=1) 50 | alice: observes [bob.x] is charlie.x 51 | return E[alice[ 52 | imagine[ 53 | env: knows(bob.x), 54 | bob.x + 1 55 | ] 56 | ]] 57 | 58 | @memo_test(mod) 59 | def imagine_ok(): 60 | return alice[ 61 | imagine[ 62 | bob: chooses(y in X, wpp=1), 63 | E[bob.y] 64 | ] 65 | ] 66 | 67 | @memo_test(mod) 68 | def inline(): 69 | return {N} 70 | 71 | @memo_test(mod) 72 | def inline_call(): 73 | return f({N}) 74 | 75 | @memo_test(mod) 76 | def inline_memo[x: X](): 77 | return test_[x]({N}) 78 | 79 | @memo_test(mod) 80 | def memo_call_ellipsis(t=0): 81 | alice: chooses(x in X, wpp=test_[x](...)) 82 | return E[alice.x] 83 | 84 | @memo_test(mod) 85 | def imagine_ok(): 86 | return alice[ 87 | imagine[ 88 | bob: chooses(y in X, wpp=1), 89 | E[bob.y] 90 | ] 91 | ] 92 | 93 | @memo_test(mod, expect='ce') 94 | def imagine_unknown_err(): 95 | return alice[ 96 | imagine[ 97 | bob: chooses(y in X, wpp=1), 98 | bob.y 99 | ] 100 | ] 101 | 102 | @memo_test(mod, expect='ce') 103 | def imagine_unknown_err_expect(): 104 | return alice[ 105 | E[imagine[ 106 | bob: chooses(y in X, wpp=1), 107 | bob.y 108 | ]] 109 | ] 110 | 111 | @memo_test(mod, expect='ce') 112 | def imagine_unknown_err_expect_future(): 113 | return alice[ 114 | E[imagine[ 115 | future_alice: chooses(y in X, wpp=1), 116 | future_alice.y 117 | ]] 118 | ] 119 | 120 | @memo_test(mod) 121 | def imagine_knows(): 122 | alice: chooses(x in X, wpp=1) 123 | return E[alice[ 124 | imagine[ 125 | world: knows(x), 126 | world.x 127 | ] 128 | ]] 129 | 130 | @memo_test(mod) 131 | def imagine_knows_other[z: X](): 132 | alice: chooses(x in X, wpp=1) 133 | alice: thinks[ bob: chooses(z in X, wpp=1) ] 134 | alice: observes [bob.z] is z 135 | return alice[ 136 | imagine[ 137 | world: knows(bob.z), 138 | world[bob.z] 139 | ] 140 | ] 141 | 142 | @memo_test(mod) 143 | def imagine_future_stress(): 144 | 
alice: chooses(x in X, wpp=1) 145 | alice: thinks[ bob: chooses(z in X, wpp=1) ] 146 | return E[alice[ 147 | imagine[ 148 | world: knows(x, bob.z), 149 | world: chooses(z in X, wpp=1), 150 | future_alice: chooses(y in X, wpp=x + y), 151 | future_alice: thinks[ world: chooses(z in X, wpp=1) ], 152 | future_alice: observes [world.z] is world.z, 153 | E[future_alice.y + future_alice[world.z + y] + world.z + world[bob.z]] 154 | ] 155 | ]] 156 | 157 | @memo_test(mod) 158 | def imagine_toplevel(): 159 | return imagine[ 160 | alice: chooses(x in X, wpp=1), 161 | E[alice.x] 162 | ] 163 | 164 | mod.install(''' 165 | @jax.jit 166 | def returns_scalar(x): 167 | return np.cos(x) + np.array([0, 1, 2])[x] 168 | @jax.jit 169 | def returns_scalar_no_arg(): 170 | return np.cos(3.14) 171 | @jax.jit 172 | def returns_nonscalar0(): 173 | return np.array([0, 1]) 174 | @jax.jit 175 | def returns_nonscalar1(x): 176 | return np.array([0, 1]) 177 | ''') 178 | 179 | @memo_test(mod) 180 | def ffi_ok(): 181 | alice: chooses(x in X, wpp=1) 182 | return E[returns_scalar(alice.x)] + 12 183 | 184 | @memo_test(mod) 185 | def ffi_ok_no_arg(): 186 | return returns_scalar_no_arg() + 15 187 | 188 | @memo_test(mod, expect='ce') 189 | def ffi_scalar0(): 190 | return returns_nonscalar0() 191 | 192 | @memo_test(mod, expect='ce') 193 | def ffi_scalar1(): 194 | return returns_nonscalar1(1.0) 195 | 196 | @memo_test(mod) 197 | def observes_const(): 198 | alice: thinks[ bob: chooses(x in X, wpp=1) ] 199 | alice: observes_that [bob.x == 0] 200 | return alice[E[bob.x]] 201 | 202 | @memo_test(mod) 203 | def observes_const_float(): 204 | alice: thinks[ bob: chooses(x in X, wpp=1) ] 205 | alice: observes_event(wpp=bob.x / 3) 206 | return alice[E[bob.x]] 207 | 208 | @memo_test(mod) 209 | def observes_const_void_choose(): 210 | alice: chooses(x in X, wpp=1) 211 | alice: observes_event(wpp=x / 3) 212 | return E[alice.x] 213 | 214 | @memo_test(mod) 215 | def observes_const_void(): 216 | alice: observes_event(wpp=3.14) 
217 | return alice[2] 218 | 219 | @memo_test(mod) 220 | def pr_joint(): 221 | alice: chooses(x in X, wpp=1) 222 | alice: chooses(y in X, wpp=1) 223 | return Pr[alice.x == 0, alice.y == 0] 224 | 225 | @memo_test(mod) 226 | def choose_many(): 227 | alice: chooses(x in X, y in Y, wpp=1) 228 | return Pr[alice.x == 0, alice.y == 0] 229 | 230 | @memo_test(mod) 231 | def choose_max(): 232 | alice: chooses(x in X, y in Y, to_maximize=x + y) 233 | return Pr[alice.x == 0, alice.y == 0] 234 | 235 | @memo_test(mod) 236 | def choose_min(): 237 | alice: chooses(x in X, y in Y, to_minimize=x + y) 238 | return Pr[alice.x == 0, alice.y == 0] 239 | 240 | @memo_test(mod, expect='ce') 241 | def choose_err(): 242 | alice: chooses(x in X, y in Y, to_eat=x + y) 243 | return Pr[alice.x == 0, alice.y == 0] 244 | 245 | @memo_test(mod) # crashes without post optim 246 | def post_optim[z1: Z, z2: Z](): 247 | alice: chooses(z1 in Z, wpp=1) 248 | alice: chooses(z2 in Z, wpp=1) 249 | return Pr[z1 == alice.z1, alice.z2 == z2] 250 | 251 | @memo_test(mod) # https://stackoverflow.com/a/22348885 252 | def post_optim_distinctness[z: Z](): 253 | alice: chooses(z1 in Z, wpp=1) 254 | alice: chooses(z2 in Z, wpp=1) 255 | return Pr[z == alice.z1, z == alice.z2] 256 | 257 | from math import log 258 | @memo_test(mod, item=log(2/1) + 1/8 - 1/2) 259 | def kl(): 260 | alice: chooses(p in R, wpp=normpdf(p, 0, 1)) 261 | alice: chooses(q in R, wpp=normpdf(q, 0, 2)) 262 | return KL[alice.p | alice.q] 263 | 264 | @memo_test(mod, expect='ce') 265 | def kl_fail_unknown(): 266 | alice: chooses(p in R, wpp=normpdf(p, 0, 1)) 267 | alice: chooses(q in R, wpp=normpdf(q, 0, 2)) 268 | return KL[alice.r | alice.q] 269 | 270 | @memo_test(mod, expect='ce') 271 | def kl_fail_known[r: R](): 272 | alice: knows(r) 273 | alice: chooses(q in R, wpp=normpdf(q, 0, 2)) 274 | return KL[alice.r | alice.q] 275 | 276 | @memo_test(mod, expect='ce') 277 | def kl_fail_dom(): 278 | alice: chooses(p in Z, wpp=normpdf(p, 0, 1)) 279 | alice: 
chooses(q in R, wpp=normpdf(q, 0, 2)) 280 | return KL[alice.p | alice.q] 281 | 282 | @memo_test(mod) 283 | def kl_victor[x: X](): 284 | bob: thinks[ 285 | alice: given(x in X, wpp=1), 286 | env: chooses(x in X, wpp=1) 287 | ] 288 | return bob[KL[alice.x | env.x]] 289 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | (Under construction.) 4 | -------------------------------------------------------------------------------- /docs/tomalot/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | Click here to download the paper (PDF format). 4 | -------------------------------------------------------------------------------- /docs/tomalot/tomalot.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kach/memo/cce50349ae7708e37218244f50c9b1aa64f60fdf/docs/tomalot/tomalot.pdf -------------------------------------------------------------------------------- /memo/__init__.py: -------------------------------------------------------------------------------- 1 | from .codegen import memo, memo_test 2 | from .utils import domain, make_module 3 | from .version import __version__ -------------------------------------------------------------------------------- /memo/codegen.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .core import * 4 | from .parse import * 5 | from .version import __version__ 6 | 7 | import textwrap 8 | import os, sys, platform, inspect 9 | from io import StringIO 10 | from typing import Any, Optional, Literal, Protocol, overload, cast, TYPE_CHECKING 11 | from collections.abc import Callable 12 | import warnings 13 | import linecache 14 | 15 | if TYPE_CHECKING: 16 | import jax 17 | 18 | from . 
import lib 19 | lib_dir = ', '.join([key for key in dir(lib) if not key.startswith('_')]) 20 | 21 | class MemoCompiled(Protocol): 22 | @overload 23 | def __call__( 24 | self, 25 | *args: jax.typing.ArrayLike, 26 | return_aux: Literal[False] = ..., 27 | return_pandas: bool = ..., 28 | return_xarray: bool = ..., 29 | return_cost: bool = ..., 30 | print_table: bool = ..., 31 | **kwargs: jax.typing.ArrayLike 32 | ) -> jax.Array: 33 | ... 34 | 35 | @overload 36 | def __call__( 37 | self, 38 | *args: jax.typing.ArrayLike, 39 | return_aux: Literal[True] = ..., 40 | return_pandas: bool = ..., 41 | return_xarray: bool = ..., 42 | return_cost: bool = ..., 43 | print_table: bool = ..., 44 | **kwargs: jax.typing.ArrayLike 45 | ) -> memo_result: 46 | ... 47 | 48 | def __call__( 49 | self, 50 | *args: jax.typing.ArrayLike, 51 | return_aux: bool = False, 52 | return_pandas: bool = False, 53 | return_xarray: bool = False, 54 | return_cost: bool = False, 55 | print_table: bool = False, 56 | **kwargs: jax.typing.ArrayLike 57 | ) -> jax.Array | memo_result: 58 | ... 
59 | 60 | def make_static_parameter_list(pctxt: ParsingContext) -> str: 61 | out = '' 62 | for sp, sd in zip(pctxt.static_parameters, pctxt.static_defaults): 63 | if sd is None: 64 | out += f'{sp}' 65 | else: 66 | out += f'{sp}={sd}' 67 | out += ', ' 68 | return out 69 | 70 | def codegen( 71 | pctxt: ParsingContext, 72 | stmts: list[Stmt], 73 | retval: Expr, 74 | debug_print_compiled: bool=False, 75 | debug_trace: bool=False, 76 | save_comic: Optional[str]=None, 77 | install_module: Optional[Callable[[str], Any]] = None, 78 | cache: bool = False 79 | ) -> MemoCompiled: 80 | f_name = pctxt.loc_name 81 | ctxt = Context(frame=Frame(name=ROOT_FRAME_NAME)) 82 | ctxt.hoisted_syms.extend(pctxt.static_parameters) 83 | with ctxt.hoist(): 84 | ctxt.emit(f"cost_ = 0") 85 | if debug_trace: 86 | ctxt.emit(f"""_time_ = time.time()""") 87 | ctxt.emit(f"""print(f' --> {pctxt.loc_name}({{ {", ".join(pctxt.static_parameters) if len(pctxt.static_parameters) > 0 else '""'} }})')""") 88 | for stmt_ in stmts: 89 | eval_stmt(stmt_, ctxt) 90 | 91 | val = eval_expr(retval, ctxt) 92 | # for ax_name, ax_dom in pctxt.axes: 93 | # if (Name('self'), ax_name) not in val.deps: 94 | # warnings.warn(f"memo {pctxt.loc_name}'s return value does not depend on axis {ax_name} (of type {ax_dom}). Are you sure this is what you want? Please note that memo will avoid redundant work by returning an array where the dimension along that axis is of length 1.") 95 | 96 | if not val.known: 97 | raise MemoError( 98 | "Returning a value that the observer has uncertainty over", 99 | hint="Did you mean to use E[...] 
after your return statement?", 100 | ctxt=ctxt, 101 | user=True, 102 | loc=retval.loc 103 | ) 104 | squeeze_axes = [ 105 | -1 - i 106 | for i in range(ctxt.next_idx) 107 | if i not in [z[0] for z in ctxt.forall_idxs] 108 | ] 109 | ctxt.emit(f"# prepare output") 110 | ctxt.emit(f"{val.tag} = jnp.array({val.tag}) # ensure output is an array") 111 | ctxt.emit(f"{val.tag} = pad({val.tag}, {ctxt.next_idx})") 112 | ctxt.emit(f"{val.tag} = {val.tag}.squeeze(axis={tuple(squeeze_axes)}).transpose()") 113 | 114 | with ctxt.hoist(): 115 | ctxt.emit(f"""\ 116 | _out_ = _jit_{f_name}({", ".join(ctxt.hoisted_syms)}) 117 | """) 118 | 119 | ctxt.emit(f"""\ 120 | if return_cost: 121 | # https://jax.readthedocs.io/en/latest/aot.html 122 | _lowered_ = _jit_{f_name}.lower({", ".join(ctxt.hoisted_syms)}) 123 | _cost_ = _lowered_.cost_analysis() 124 | _cost_ = dict( 125 | flops=_cost_.get('flops', 0), 126 | transcendentals=_cost_.get('transcendentals', 0), 127 | bytes=_cost_.get('bytes accessed', 0) 128 | ) 129 | aux.cost += _cost_['flops'] + _cost_['transcendentals'] 130 | """) 131 | if debug_trace: 132 | ctxt.emit(f"""print(f'<-- {pctxt.loc_name}({{ {", ".join(pctxt.static_parameters) if len(pctxt.static_parameters) > 0 else '""'} }}) has shape {{ _out_.shape }}')""") 133 | ctxt.emit(f"""\ 134 | if return_cost: 135 | print(f' cost = {{aux.cost}} operations') 136 | """) 137 | ctxt.emit(f"""print(f' time = {{time.time() - _time_:.6f}} sec')""") 138 | ctxt.emit(f"if print_table: pprint_table(_out_{f_name}, _out_)") 139 | ctxt.emit(f"if return_pandas: aux.pandas = make_pandas_data(_out_{f_name}, _out_)") 140 | ctxt.emit(f"if return_xarray: aux.xarray = make_xarray_data(_out_{f_name}, _out_)") 141 | ctxt.emit(f"""return memo_result(data=_out_, aux=aux) if return_aux else _out_""") 142 | ctxt.emit(f"return {val.tag}") 143 | 144 | out = f"""\ 145 | def _make_{f_name}(): 146 | from memo.lib import {lib_dir} 147 | 148 | @jax.jit 149 | def _jit_{f_name}({", ".join(ctxt.hoisted_syms)}): 150 | 
{textwrap.indent(ctxt.regular_buf.getvalue(), " " * 2)} 151 | 152 | {" @cache" if cache else ""} 153 | def _out_{f_name}( 154 | {make_static_parameter_list(pctxt)}*, 155 | return_aux=False, 156 | return_pandas=False, 157 | return_xarray=False, 158 | return_cost=False, 159 | print_table=False 160 | ): 161 | aux = AuxInfo() 162 | if return_pandas or return_xarray: 163 | return_aux = True 164 | if return_cost: 165 | return_aux = True 166 | aux.cost = 0. 167 | {textwrap.indent(ctxt.hoisted_buf.getvalue(), " " * 2)} 168 | 169 | _out_{f_name}._shape = tuple([{", ".join(f"len({p[1]})" for p in pctxt.axes)}]) 170 | _out_{f_name}._axes = tuple([{", ".join(f"{repr(p[0])}" for p in pctxt.axes)}]) 171 | _out_{f_name}._doms = tuple([{", ".join(f"{repr(p[1])}" for p in pctxt.axes)}]) 172 | _out_{f_name}._vals = tuple([{", ".join(f"{p[1]}" for p in pctxt.axes)}]) 173 | return _out_{f_name} 174 | 175 | {f_name} = _make_{f_name}() 176 | {f_name}.__name__ = '{f_name}' 177 | {f_name}.__qualname__ = '{pctxt.qualname}' 178 | {f_name}.__doc__ = {repr(pctxt.doc)} 179 | """ 180 | 181 | if debug_print_compiled: 182 | for i, line in enumerate(out.splitlines()): 183 | print(f"{i + 1: 5d} {line}") 184 | 185 | if save_comic is not None: 186 | from .comic import comic 187 | comic(ctxt.frame, fname=save_comic) 188 | 189 | globals_of_caller = inspect.stack()[3].frame.f_globals 190 | locals_of_caller = inspect.stack()[3].frame.f_locals 191 | if globals_of_caller != locals_of_caller and install_module is None: 192 | warnings.warn(f"memo works best in the global (module) scope. 
Defining memos within function definitions is currently not officially supported, though if you know what you are doing then go ahead and do it!") 193 | if install_module is not None: 194 | ret = install_module(out)[f"{f_name}"] 195 | return cast(MemoCompiled, lambda _: print("Call me from inside the module!")) 196 | 197 | retvals: dict[Any, Any] = {} 198 | 199 | exec(out, globals_of_caller, retvals) 200 | return cast(MemoCompiled, retvals[f"{f_name}"]) 201 | 202 | def memo_(f: Callable[..., Any], **kwargs: Any) -> MemoCompiled: 203 | try: 204 | pctxt, stmts, retval = parse_memo(f) 205 | return codegen(pctxt, stmts, retval, **kwargs) 206 | except MemoError as e: 207 | if e.loc: 208 | e.add_note(f" file: \"{os.path.basename(e.loc.file)}\", line {e.loc.line}, in @memo {e.loc.name}") 209 | e.add_note(f" {linecache.getline(e.loc.file, e.loc.line)[:-1]}") 210 | e.add_note(f" {' ' * e.loc.offset}^") 211 | if e.hint is not None: 212 | e.add_note('') 213 | for line in textwrap.wrap( 214 | e.hint, initial_indent=" hint: ", subsequent_indent=" " 215 | ): 216 | e.add_note(line) 217 | if e.ctxt: # TODO 218 | e.add_note('') 219 | frame_name = f"{e.ctxt.frame.name}" 220 | z = e.ctxt.frame 221 | while z.parent is not None and z.parent.name != ROOT_FRAME_NAME: 222 | z = z.parent 223 | frame_name += f", as modeled by {z.name}" 224 | ctxt_note = f'''\ 225 | This error was encountered in the frame of {frame_name}. 226 | 227 | In that frame, {e.ctxt.frame.name} is currently modeling the following {len(e.ctxt.frame.choices)} choices: {", ".join([v if k == Name("self") else f"{k}.{v}" for k, v in e.ctxt.frame.choices.keys()])}. 
228 | ''' 229 | for line in textwrap.wrap( 230 | ctxt_note, initial_indent=" ctxt: ", subsequent_indent=" " 231 | ): 232 | e.add_note(line) 233 | if not e.user: 234 | e.add_note("") 235 | e.add_note( 236 | "[We think this may be a bug in memo: if you don't understand what is going on, please get in touch with us!]" 237 | ) 238 | e.add_note("") 239 | 240 | # Describe environment... 241 | import jax 242 | e.add_note(f" info: You are using memo {__version__}, JAX {jax.__version__}, Python {platform.python_version()} on {platform.system()}.") 243 | 244 | raise e.with_traceback(None) from None 245 | 246 | import sys, traceback 247 | old_excepthook = sys.excepthook 248 | def new_excepthook(typ, val, tb): # type: ignore 249 | if typ is MemoError: 250 | return traceback.print_exception(val, limit=0) 251 | old_excepthook(typ, val, tb) 252 | sys.excepthook = new_excepthook 253 | 254 | warnings.showwarning = lambda msg, *args: print('Warning:', msg, file=sys.stderr) 255 | 256 | try: 257 | ipython = get_ipython() # type: ignore 258 | old_showtraceback = ipython.showtraceback 259 | def new_showtraceback(*args, **kwargs): # type: ignore 260 | info = sys.exc_info() 261 | if info[0] is MemoError: 262 | return traceback.print_exception(info[1], limit=0) 263 | old_showtraceback(sys.exc_info(), **kwargs) 264 | ipython.showtraceback = new_showtraceback 265 | except NameError: 266 | pass 267 | 268 | @overload 269 | def memo(f: None=None, **kwargs: Any) -> Callable[[Callable[..., Any]], MemoCompiled]: 270 | ... 271 | 272 | @overload 273 | def memo(f: Callable[..., Any], **kwargs: Any) -> MemoCompiled: 274 | ... 
275 | 276 | def memo(f: None | Callable[..., Any] = None, **kwargs: Any) -> Callable[[Callable[..., Any]], MemoCompiled] | MemoCompiled: 277 | if f is None: 278 | return lambda f: memo_(f, **kwargs) 279 | return memo_(f, **kwargs) 280 | 281 | def memo_test(mod, expect='pass', item=None, *args, **kwargs): # type: ignore 282 | def helper(f): # type: ignore 283 | name = f.__name__ 284 | outcome = None 285 | err: BaseException 286 | try: 287 | memo(f, install_module=mod.install, **kwargs) 288 | f = mod.__getattribute__(name) 289 | out = f(*args) 290 | except MemoError as e: 291 | outcome = 'ce' 292 | err = e 293 | except Exception as e: 294 | outcome = 're' 295 | err = e 296 | else: 297 | outcome = 'pass' 298 | 299 | if outcome == 'pass' and item is not None and abs(out.item() - item) > 1e-6: 300 | print(f'[fail {name}, {out} != {item}]') 301 | elif expect == outcome: 302 | print(f'[ pass {name} ]') 303 | return f 304 | else: 305 | print(f'[!fail {name}, {outcome} != {expect} ]') 306 | raise err 307 | return helper 308 | -------------------------------------------------------------------------------- /memo/comic.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | import shutil, os 3 | 4 | def frame_name(frame): 5 | return f'cluster_{frame.name}_{id(frame)}' 6 | 7 | def node_name(frame, who, id_): 8 | return f'{who}_{id_}_{id(frame)}' 9 | 10 | def comic_frame_edges(frame: Frame, io: StringIO) -> None: 11 | for c in frame.children.keys(): 12 | comic_frame_edges(frame.children[c], io) 13 | 14 | if frame.parent is not None: 15 | for too, frm in frame.conditions.items(): 16 | print(f'{node_name(frame.parent, *frm)} -> {node_name(frame, *too)}[style="dashed"];', file=io) 17 | 18 | def comic_frame_nodes(frame: Frame, io: StringIO) -> None: 19 | for c in frame.children.values(): 20 | comic_frame_nodes(c, io) 21 | 22 | print(f'subgraph {frame_name(frame)} {{', file=io) 23 | print(f'label="{frame.name}\'s frame";', 
file=io) 24 | print(f'labelloc="b";', file=io) 25 | print(f'''{frame_name(frame)}_dummy[style=invis];''', file=io) 26 | 27 | for c in frame.children.values(): 28 | print(f'''{frame_name(frame)}_dummy -> {frame_name(c)}_dummy[ltail={frame_name(frame)}, lhead={frame_name(c)}, arrowhead="tee"];''', file=io) 29 | 30 | for (who, id), ch in frame.choices.items(): 31 | color = "lightblue" if ch.known else "orange" 32 | label = f'{who}.{id}' if who != 'self' else f'{id}' 33 | print(f'''{node_name(frame, who, id)}[label="{label} : {ch.domain}", color={color}];''', file=io) 34 | 35 | print('}', file=io) 36 | 37 | def comic(frame: Frame, fname: str) -> None: 38 | io = StringIO() 39 | print('digraph G {', file=io) 40 | print('rankdir=LR; compound=true;', file=io) 41 | print('node[shape="cds", style="filled"];', file=io) 42 | comic_frame_nodes(frame, io) 43 | comic_frame_edges(frame, io) 44 | print('}', file=io) 45 | 46 | with open(f'{fname}.dot', 'w') as f: 47 | io.seek(0) 48 | shutil.copyfileobj(io, f) 49 | if shutil.which('dot') is not None: 50 | os.system(f'dot {fname}.dot -Tpng -o {fname}.png') 51 | # os.remove(fname) 52 | else: 53 | print(f"memo couldn't find a graphviz installation, so it only produced the .dot file. 
If you don't have graphviz installed, you can paste the .dot file into an online editor, such as https://dreampuf.github.io/GraphvizOnline/") 54 | -------------------------------------------------------------------------------- /memo/lib.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import time 4 | from functools import cache 5 | 6 | from .core import MemoError, AuxInfo, memo_result 7 | 8 | def marg(t, dims): 9 | if dims == (): 10 | return t 11 | return t.sum(axis=tuple(-1 - d for d in dims), keepdims=True) 12 | 13 | def maxx(t, dims): 14 | if dims == (): 15 | return t 16 | return jnp.max(t, axis=tuple(-1 - d for d in dims), keepdims=True) 17 | 18 | def pad(t, total): 19 | count = total - len(t.shape) 20 | for _ in range(count): 21 | t = jnp.expand_dims(t, 0) 22 | return t 23 | 24 | def ffi(f, *args): 25 | if jax.eval_shape(f, *[jax.ShapeDtypeStruct((), jnp.int32) for z in args]).shape != (): 26 | raise MemoError( 27 | f"The function {f.__name__}(...) is not scalar-in-scalar-out. memo can only handle external (@jax.jit) functions that take scalars as input and return a single scalar as output.", 28 | hint=None, 29 | user=True, 30 | ctxt=None, 31 | loc=None 32 | ) 33 | if not isinstance(f, jax.lib.xla_extension.PjitFunction): 34 | raise MemoError( 35 | f"Tried to call non-JAX function `{f.__name__}`. 
Use @jax.jit to mark as JAX.", 36 | hint=None, 37 | user=True, 38 | ctxt=None, 39 | loc=None 40 | ) 41 | if len(args) == 0: 42 | return f() 43 | args = jax.numpy.broadcast_arrays(*args) 44 | target_shape = args[0].shape 45 | args = [arg.reshape(-1) for arg in args] 46 | return jax.vmap(f)(*args).reshape(target_shape) 47 | 48 | def check_domains(tgt, src): 49 | if len(tgt) > len(src): 50 | raise Exception("Not enough arguments to memo call!") 51 | if len(src) > len(tgt): 52 | raise Exception("Too many arguments to memo call!") 53 | for i, (t, s) in enumerate(zip(tgt, src)): 54 | if t != s: 55 | raise Exception(f"Domain mismatch in memo call argument {i + 1}: {t} != {s}.") 56 | 57 | 58 | 59 | def pprint_table(f, z): 60 | z = z.at[jnp.isclose(z, 1., atol=1e-5)].set(1) 61 | z = z.at[jnp.isclose(z, 0., atol=1e-5)].set(0) 62 | 63 | def pprint(val): 64 | if isinstance(val, jnp.ndarray): 65 | return str(val.item()) 66 | from enum import Enum 67 | if isinstance(val, Enum): 68 | return f'{val.name}' 69 | return str(val) 70 | 71 | rows = [] 72 | rows.append(tuple([f'{ax}: {dom}' for ax, dom in zip(f._axes, f._doms)]) + (f"{f.__name__}",)) # header 73 | import itertools 74 | for row in itertools.product(*[enumerate(v) for v in f._vals]): 75 | idx = tuple([r[0] for r in row]) 76 | lead = tuple([pprint(r[1]) for r in row]) 77 | rows.append(lead + (pprint(z[idx]),)) 78 | 79 | widths = [] 80 | for col in range(len(rows[0])): 81 | widths.append(max([len(row[col]) for row in rows])) 82 | 83 | def hr(): 84 | for w, c in zip(widths, rows[0]): 85 | print('+', end='-') 86 | print('-' * w, end='-') 87 | print('-+') 88 | 89 | hr() 90 | for ri, row in enumerate(rows): 91 | for w, c in zip(widths, row): 92 | print('|', end=' ') 93 | print(c + ' ' * (w - len(c)), end=' ') 94 | print(' |') 95 | if ri == 0: 96 | hr() 97 | hr() 98 | 99 | 100 | def make_pandas_data(f, z): 101 | import itertools 102 | def pprint(val): 103 | if isinstance(val, jnp.ndarray): 104 | return val.item() 105 | from enum 
import Enum 106 | if isinstance(val, Enum): 107 | return val.name 108 | return val 109 | 110 | data = dict() 111 | for ax, dom in zip(f._axes, f._doms): 112 | data[f"{ax}"] = list() 113 | data[f"{f.__name__[5:]}"] = list() 114 | 115 | 116 | for row in itertools.product(*[enumerate(v) for v in f._vals]): 117 | idx = tuple([r[0] for r in row]) 118 | lead = tuple([pprint(r[1]) for r in row]) 119 | row_data = lead + (pprint(z[idx]),) 120 | 121 | for (dom, val) in zip(data.keys(), row_data): 122 | data[dom].append(val) 123 | 124 | import pandas as pd 125 | return pd.DataFrame(data) 126 | 127 | 128 | def make_xarray_data(f, z): 129 | def parse(val): 130 | if isinstance(val, jnp.ndarray): 131 | return val.item() 132 | from enum import Enum 133 | if isinstance(val, Enum): 134 | return val.name 135 | return val 136 | 137 | coords = {} 138 | for (ax, dom, vals) in zip(f._axes, f._doms, f._vals): 139 | coords[f"{ax}"] = [parse(v) for v in vals] 140 | 141 | import xarray as xr 142 | return xr.DataArray(name=f"{f.__name__[5:]}", data=z, coords=coords) -------------------------------------------------------------------------------- /memo/utils.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | class domain(list): 4 | def __init__(self, **kwargs): 5 | self.n = 1 6 | self.place_values = {} 7 | self.place_moduli = {} 8 | self.keys = list(kwargs.keys()) 9 | 10 | for k, v in reversed(kwargs.items()): 11 | assert isinstance(v, int) 12 | assert v > 1 13 | assert not k.startswith('_') 14 | self.place_values[k] = self.n 15 | self.n *= v 16 | self.place_moduli[k] = v 17 | 18 | for k in kwargs.keys(): 19 | self.__setattr__(k, partial(self.unpack, k=k)) 20 | 21 | def _update(self, z, **kwargs): 22 | for k in kwargs.keys(): 23 | assert k in self.keys 24 | 25 | kwargs_ = {k: kwargs.get(k, self.unpack(z, k)) for k in self.keys} 26 | return self.pack(**kwargs_) 27 | 28 | def _tuple(self, z): 29 | return 
tuple(self.unpack(z, k) for k in self.keys) 30 | 31 | def __iter__(self): # fool JAX into thinking I'm an array 32 | return iter(range(self.n)) 33 | 34 | def __len__(self): 35 | return self.n 36 | 37 | def __call__(self, *args, **kwargs): 38 | return self.pack(*args, **kwargs) 39 | 40 | def pack(self, *args, **kwargs): 41 | z = 0 42 | if len(args) > 0: 43 | assert len(args) == len(self.keys) 44 | for i, v in enumerate(args): 45 | z = z + v * self.place_values[self.keys[i]] 46 | 47 | else: 48 | assert len(kwargs) == len(self.keys) 49 | for j, (k, v) in enumerate(kwargs.items()): 50 | assert self.keys[j + len(args)] == k 51 | z = z + v * self.place_values[k] 52 | 53 | return z 54 | 55 | def unpack(self, z, k): 56 | assert k in self.keys 57 | return (z // self.place_values[k]) % self.place_moduli[k] 58 | 59 | 60 | import importlib.abc, importlib.util 61 | 62 | class StringLoader(importlib.abc.SourceLoader): 63 | def __init__(self, data): 64 | self.data = data 65 | def get_source(self, fullname): 66 | return self.data 67 | def get_data(self, path): 68 | return self.data.encode("utf-8") 69 | def get_filename(self, fullname): 70 | return "" 71 | 72 | def make_module(name): 73 | loader = StringLoader(''' 74 | def install(x): 75 | exec(x, globals()) 76 | return globals() 77 | ''') 78 | spec = importlib.util.spec_from_loader(name, loader, origin="built-in") 79 | module = importlib.util.module_from_spec(spec) 80 | # import sys 81 | # sys.modules[name] = module 82 | spec.loader.exec_module(module) 83 | return module -------------------------------------------------------------------------------- /memo/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.2.0' 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # python -m build 2 | [build-system] 3 | requires = ["setuptools"] 4 | 
build-backend = "setuptools.build_meta" 5 | 6 | [project] 7 | name = "memo-lang" 8 | dynamic = ["version"] 9 | authors = [ 10 | {name = "Kartik Chandra"}, 11 | {name = "Tony Chen"} 12 | ] 13 | license = {file = "LICENSE"} 14 | description = "A language for mental models" 15 | readme = "README.md" 16 | requires-python = ">=3.12" 17 | dependencies = [ "jax" ] 18 | classifiers = [ 19 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 20 | "Topic :: Software Development :: Compilers", 21 | "Topic :: Software Development :: Libraries :: Python Modules", 22 | "License :: OSI Approved :: MIT License" 23 | ] 24 | 25 | [project.urls] 26 | Homepage = "http://github.com/kach/memo" 27 | Repository = "http://github.com/kach/memo.git" 28 | Issues = "http://github.com/kach/memo/issues" 29 | 30 | [tool.setuptools] 31 | packages = ["memo"] 32 | 33 | [tool.setuptools.dynamic] 34 | version = {attr = "memo.version.__version__"} 35 | 36 | [tool.mypy] 37 | strict = true 38 | exclude = [ 39 | "demo/.*", 40 | "utils\\.py", 41 | "comic\\.py", 42 | "lib\\.py" 43 | ] 44 | 45 | [[tool.mypy.overrides]] 46 | module = [ "memo.*" ] 47 | follow_imports = "skip" # or "error" 48 | --------------------------------------------------------------------------------