├── .github
│   └── workflows
│       └── pytest_and_autopublish.yml
├── .gitignore
├── .pylintrc
├── .readthedocs.yaml
├── .vscode
│   ├── extensions.json
│   └── settings.json
├── AUTHORS
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── Makefile
├── README.md
├── coix
│   ├── __init__.py
│   ├── algo.py
│   ├── algo_test.py
│   ├── api.py
│   ├── api_test.py
│   ├── core.py
│   ├── core_test.py
│   ├── loss.py
│   ├── loss_test.py
│   ├── numpyro.py
│   ├── oryx.py
│   ├── oryx_test.py
│   ├── util.py
│   └── util_test.py
├── docs
│   ├── Makefile
│   ├── _static
│   │   ├── anneal.png
│   │   ├── anneal_oryx.png
│   │   ├── bmnist.gif
│   │   ├── dmm.png
│   │   ├── dmm_oryx.png
│   │   ├── gmm.png
│   │   ├── gmm_oryx.png
│   │   ├── tutorial_part1_vae.png
│   │   ├── tutorial_part2_api.png
│   │   └── tutorial_part3_smcs.png
│   ├── _templates
│   │   └── breadcrumbs.html
│   ├── algo.rst
│   ├── api.rst
│   ├── conf.py
│   ├── core.rst
│   ├── index.rst
│   ├── loss.rst
│   ├── requirements.txt
│   └── util.rst
├── examples
│   ├── README.rst
│   ├── anneal.py
│   ├── anneal_oryx.py
│   ├── bmnist.py
│   ├── dmm.py
│   ├── dmm_oryx.py
│   ├── gmm.py
│   └── gmm_oryx.py
├── notebooks
│   ├── figures
│   │   └── smcs_nvi.png
│   ├── tutorial_part1_vae.ipynb
│   ├── tutorial_part2_api.ipynb
│   └── tutorial_part3_smcs.ipynb
└── pyproject.toml

/.github/workflows/pytest_and_autopublish.yml:
--------------------------------------------------------------------------------
1 | name: Unittests & Auto-publish
2 | 
3 | # Allow triggering the workflow manually (e.g. when dependencies change)
4 | on:
5 |   push:
6 |     branches:
7 |       - main
8 |   pull_request:
9 |     branches:
10 |       - main
11 |   workflow_dispatch:
12 | 
13 | jobs:
14 |   lint:
15 |     runs-on: ubuntu-latest
16 |     timeout-minutes: 30
17 | 
18 |     steps:
19 |       - uses: actions/checkout@v3
20 | 
21 |       # Install deps
22 |       - uses: actions/setup-python@v4
23 |         with:
24 |           python-version: '3.11'
25 | 
26 |       - run: sudo apt install -y pandoc gsfonts
27 |       - run: pip --version
28 |       - run: pip install -e .[dev,doc]
29 |       - run: pip freeze
30 | 
31 |       - name: Run lint
32 |         run: make lint
33 |       - name: Build documentation
34 |         run: make docs
35 | 
36 |   pytest-job:
37 |     needs: lint
38 |     runs-on: ubuntu-latest
39 |     timeout-minutes: 30
40 | 
41 |     concurrency:
42 |       group: ${{ github.workflow }}-${{ github.ref }}
43 |       cancel-in-progress: true
44 | 
45 |     steps:
46 |       - uses: actions/checkout@v3
47 | 
48 |       # Install deps
49 |       - uses: actions/setup-python@v4
50 |         with:
51 |           python-version: '3.11'
52 |           # Uncomment to cache pip dependencies (if tests are too slow)
53 |           # cache: pip
54 |           # cache-dependency-path: '**/pyproject.toml'
55 | 
56 |       - run: pip install -e .[dev]
57 |       - run: pip install "git+https://github.com/jax-ml/oryx.git"
58 | 
59 |       # Run tests (in parallel)
60 |       - name: Run core tests
61 |         run: pytest -vv -n auto
62 | 
63 |       # Run custom prng tests
64 |       - name: Run custom prng tests
65 |         run: JAX_ENABLE_CUSTOM_PRNG=1 pytest -vv -n auto
66 | 
67 |   # Auto-publish when version is increased
68 |   publish-job:
69 |     # Only try to publish if:
70 |     # * Repo is self (prevents running from forks)
71 |     # * Branch is `main`
72 |     if: |
73 |       github.repository == 'jax-ml/coix'
74 |       && github.ref == 'refs/heads/main'
75 |     needs: pytest-job  # Only publish after tests are successful
76 |     runs-on: ubuntu-latest
77 |     permissions:
78 |       contents: write
79 |     timeout-minutes: 30
80 | 
81 |     steps:
82 |       # Publish the package (if local `__version__` > pip version)
83 |       - uses: etils-actions/pypi-auto-publish@v1
84 |         with:
85 |           pypi-token: ${{ secrets.PYPI_API_TOKEN }}
86 |           gh-token: ${{ secrets.GITHUB_TOKEN }}
87 |           parse-changelog: true
88 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | docs/notebooks
2 | docs/examples
3 | docs/getting_started.rst
4 | docs/sg_execution_times.rst
5 | venv/
6 | 
7 | # Compiled python modules.
8 | *.pyc
9 | 
10 | # Byte-compiled
11 | __pycache__/
12 | .cache/
13 | 
14 | # Poetry, setuptools, PyPI distribution artifacts.
15 | /*.egg-info
16 | .eggs/
17 | build/
18 | dist/
19 | poetry.lock
20 | 
21 | # Tests
22 | .pytest_cache/
23 | 
24 | # Type checking
25 | .pytype/
26 | 
27 | # Other
28 | *.DS_Store
29 | *~
30 | .ipynb_checkpoints/
31 | 
32 | # PyCharm
33 | .idea
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | # This Pylint rcfile contains a best-effort configuration to uphold the
2 | # best-practices and style described in the Google Python style guide:
3 | #   https://google.github.io/styleguide/pyguide.html
4 | #
5 | # Its canonical open-source location is:
6 | #   https://google.github.io/styleguide/pylintrc
7 | 
8 | [MASTER]
9 | 
10 | # Add files or directories to the ignore list. They should be base names, not
11 | # paths.
12 | ignore=third_party
13 | 
14 | # Add files or directories matching the regex patterns to the ignore list. The
15 | # regex matches against base names, not paths.
16 | ignore-patterns=
17 | 
18 | # Pickle collected data for later comparisons.
19 | persistent=no
20 | 
21 | # List of plugins (as comma separated values of python module names) to load,
22 | # usually to register additional checkers.
23 | load-plugins=
24 | 
25 | # Use multiple processes to speed up Pylint.
26 | jobs=4
27 | 
28 | # Allow loading of arbitrary C extensions. Extensions are imported into the
29 | # active Python interpreter and may run arbitrary code.
30 | unsafe-load-any-extension=no
31 | 
32 | # A comma-separated list of package or module names from where C extensions may
33 | # be loaded. Extensions are loaded into the active Python interpreter and may
34 | # run arbitrary code.
35 | extension-pkg-allow-list=
36 | 
37 | 
38 | [MESSAGES CONTROL]
39 | 
40 | # Only show warnings with the listed confidence levels. Leave empty to show
41 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
42 | confidence=
43 | 
44 | # Enable the message, report, category or checker with the given id(s). You can
45 | # either give multiple identifiers separated by comma (,) or put this option
46 | # multiple times (only on the command line, not in the configuration file where
47 | # it should appear only once). See also the "--disable" option for examples.
48 | #enable=
49 | 
50 | # Disable the message, report, category or checker with the given id(s). You
51 | # can either give multiple identifiers separated by comma (,) or put this
52 | # option multiple times (only on the command line, not in the configuration
53 | # file where it should appear only once). You can also use "--disable=all" to
54 | # disable everything first and then reenable specific checks. For example, if
55 | # you want to run only the similarities checker, you can use "--disable=all
56 | # --enable=similarities".
If you want to run only the classes checker, but have 57 | # no Warning level messages displayed, use"--disable=all --enable=classes 58 | # --disable=W" 59 | disable=abstract-method, 60 | apply-builtin, 61 | arguments-differ, 62 | attribute-defined-outside-init, 63 | backtick, 64 | bad-option-value, 65 | basestring-builtin, 66 | buffer-builtin, 67 | c-extension-no-member, 68 | consider-using-enumerate, 69 | cmp-builtin, 70 | cmp-method, 71 | coerce-builtin, 72 | coerce-method, 73 | delslice-method, 74 | div-method, 75 | duplicate-code, 76 | eq-without-hash, 77 | execfile-builtin, 78 | file-builtin, 79 | filter-builtin-not-iterating, 80 | fixme, 81 | getslice-method, 82 | global-statement, 83 | hex-method, 84 | idiv-method, 85 | implicit-str-concat-in-sequence, 86 | import-error, 87 | import-self, 88 | import-star-module-level, 89 | inconsistent-return-statements, 90 | input-builtin, 91 | intern-builtin, 92 | invalid-str-codec, 93 | locally-disabled, 94 | long-builtin, 95 | long-suffix, 96 | map-builtin-not-iterating, 97 | misplaced-comparison-constant, 98 | # missing-function-docstring, 99 | metaclass-assignment, 100 | next-method-called, 101 | next-method-defined, 102 | no-absolute-import, 103 | no-else-break, 104 | no-else-continue, 105 | no-else-raise, 106 | no-else-return, 107 | no-init, # added 108 | no-member, 109 | no-name-in-module, 110 | no-self-use, 111 | nonzero-method, 112 | oct-method, 113 | old-division, 114 | old-ne-operator, 115 | old-octal-literal, 116 | old-raise-syntax, 117 | parameter-unpacking, 118 | print-statement, 119 | raising-string, 120 | range-builtin-not-iterating, 121 | raw_input-builtin, 122 | rdiv-method, 123 | reduce-builtin, 124 | relative-import, 125 | reload-builtin, 126 | round-builtin, 127 | setslice-method, 128 | signature-differs, 129 | standarderror-builtin, 130 | suppressed-message, 131 | sys-max-int, 132 | too-few-public-methods, 133 | too-many-ancestors, 134 | too-many-arguments, 135 | too-many-boolean-expressions, 136 | too-many-branches, 137 | too-many-instance-attributes, 138 | too-many-locals, 139 | too-many-nested-blocks, 140 | too-many-positional-arguments, 141 | too-many-public-methods, 142 | too-many-return-statements, 143 | too-many-statements, 144 | trailing-newlines, 145 | unichr-builtin, 146 | unicode-builtin, 147 | unnecessary-pass, 148 | unpacking-in-except, 149 | useless-else-on-loop, 150 | useless-object-inheritance, 151 | useless-suppression, 152 | using-cmp-argument, 153 | wrong-import-order, 154 | xrange-builtin, 155 | zip-builtin-not-iterating, 156 | 157 | 158 | [REPORTS] 159 | 160 | # Set the output format. Available formats are text, parseable, colorized, msvs 161 | # (visual studio) and html. You can also give a reporter class, eg 162 | # mypackage.mymodule.MyReporterClass. 163 | output-format=text 164 | 165 | # Put messages in a separate file for each module / package specified on the 166 | # command line instead of printing them on stdout. Reports (if any) will be 167 | # written in a file name "pylint_global.[txt|html]". This option is deprecated 168 | # and it will be removed in Pylint 2.0. 169 | # files-output=no 170 | 171 | # Tells whether to display a full report or only the messages 172 | reports=no 173 | 174 | # Python expression which should return a note less than 10 (10 is the highest 175 | # note). You have access to the variables errors warning, statement which 176 | # respectively contain the number of errors / warnings messages and the total 177 | # number of statements analyzed. 
This is used by the global evaluation report
178 | # (RP0004).
179 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
180 | 
181 | # Template used to display messages. This is a python new-style format string
182 | # used to format the message information. See doc for all details
183 | #msg-template=
184 | 
185 | 
186 | [BASIC]
187 | 
188 | # Good variable names which should always be accepted, separated by a comma
189 | good-names=main,_
190 | 
191 | # Bad variable names which should always be refused, separated by a comma
192 | bad-names=
193 | 
194 | # Colon-delimited sets of names that determine each other's naming style when
195 | # the name regexes allow several styles.
196 | name-group=
197 | 
198 | # Include a hint for the correct naming format with invalid-name
199 | include-naming-hint=no
200 | 
201 | # List of decorators that produce properties, such as abc.abstractproperty. Add
202 | # to this list to register other decorators that produce valid properties.
203 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
204 | 
205 | # Regular expression matching correct function names
206 | function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
207 | 
208 | # Regular expression matching correct variable names
209 | variable-rgx=^[a-z][a-z0-9_]*$
210 | 
211 | # Regular expression matching correct constant names
212 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
213 | 
214 | # Regular expression matching correct attribute names
215 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
216 | 
217 | # Regular expression matching correct argument names
218 | argument-rgx=^[a-z][a-z0-9_]*$
219 | 
220 | # Regular expression matching correct class attribute names
221 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
222 | 
223 | # Regular expression matching correct inline iteration names
224 | inlinevar-rgx=^[a-z][a-z0-9_]*$
225 | 
226 | # Regular expression matching correct class names
227 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$
228 | 
229 | # Regular expression matching correct module names
230 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
231 | 
232 | # Regular expression matching correct method names
233 | method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
234 | 
235 | # Regular expression which should only match function or class names that do
236 | # not require a docstring.
237 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
238 | 
239 | # Minimum line length for functions/classes that require docstrings, shorter
240 | # ones are exempt.
241 | docstring-min-length=10
242 | 
243 | 
244 | [TYPECHECK]
245 | 
246 | # List of decorators that produce context managers, such as
247 | # contextlib.contextmanager. Add to this list to register other decorators that
248 | # produce valid context managers.
249 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
250 | 
251 | # Tells whether missing members accessed in mixin class should be ignored. A
252 | # mixin class is detected if its name ends with "mixin" (case insensitive).
253 | ignore-mixin-members=yes
254 | 
255 | # List of module names for which member attributes should not be checked
256 | # (useful for modules/projects where namespaces are manipulated during runtime
257 | # and thus existing member attributes cannot be deduced by static analysis. It
258 | # supports qualified module names, as well as Unix pattern matching.
259 | ignored-modules=
260 | 
261 | # List of class names for which member attributes should not be checked (useful
262 | # for classes with dynamically set attributes). This supports the use of
263 | # qualified names.
264 | ignored-classes=optparse.Values,thread._local,_thread._local
265 | 
266 | # List of members which are set dynamically and missed by pylint inference
267 | # system, and so shouldn't trigger E1101 when accessed. Python regular
268 | # expressions are accepted.
269 | generated-members=
270 | 
271 | 
272 | [FORMAT]
273 | 
274 | # Maximum number of characters on a single line.
275 | max-line-length=80
276 | 
277 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
278 | # lines made too long by directives to pytype.
279 | 
280 | # Regexp for a line that is allowed to be longer than the limit.
281 | ignore-long-lines=(?x)(
282 |   ^\s*(\#\ )?<?https?://\S+>?$|
283 |   ^\s*(from\s+\S+\s+)?import\s+.+$)
284 | 
285 | # Allow the body of an if to be on the same line as the test if there is no
286 | # else.
287 | single-line-if-stmt=yes
288 | 
289 | # List of optional constructs for which whitespace checking is disabled. `dict-
290 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
291 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
292 | # `empty-line` allows space-only lines.
293 | # no-space-check=
294 | 
295 | # Maximum number of lines in a module
296 | max-module-lines=99999
297 | 
298 | # String used as indentation unit. The internal Google style guide mandates 2
299 | # spaces. Google's externally-published style guide says 4, consistent with
300 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
301 | # projects (like TensorFlow).
302 | indent-string='  '
303 | 
304 | # Number of spaces of indent required inside a hanging or continued line.
305 | indent-after-paren=4
306 | 
307 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
308 | expected-line-ending-format=
309 | 
310 | 
311 | [MISCELLANEOUS]
312 | 
313 | # List of note tags to take in consideration, separated by a comma.
314 | notes=TODO
315 | 
316 | 
317 | [STRING]
318 | 
319 | # This flag controls whether inconsistent-quotes generates a warning when the
320 | # character used as a quote delimiter is used inconsistently within a module.
321 | check-quote-consistency=yes
322 | 
323 | 
324 | [VARIABLES]
325 | 
326 | # Tells whether we should check for unused import in __init__ files.
327 | init-import=no
328 | 
329 | # A regular expression matching the name of dummy variables (i.e. expectedly
330 | # not used).
331 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
332 | 
333 | # List of additional names supposed to be defined in builtins. Remember that
334 | # you should avoid defining new builtins when possible.
335 | additional-builtins=
336 | 
337 | # List of strings which can identify a callback function by name. A callback
338 | # name must start or end with one of those strings.
339 | callbacks=cb_,_cb
340 | 
341 | # List of qualified module names which can have objects that can redefine
342 | # builtins.
343 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools 344 | 345 | 346 | [LOGGING] 347 | 348 | # Logging modules to check that the string format arguments are in logging 349 | # function parameter format 350 | logging-modules=logging,absl.logging,tensorflow.io.logging 351 | 352 | 353 | [SIMILARITIES] 354 | 355 | # Minimum lines number of a similarity. 356 | min-similarity-lines=4 357 | 358 | # Ignore comments when computing similarities. 359 | ignore-comments=yes 360 | 361 | # Ignore docstrings when computing similarities. 362 | ignore-docstrings=yes 363 | 364 | # Ignore imports when computing similarities. 365 | ignore-imports=no 366 | 367 | 368 | [SPELLING] 369 | 370 | # Spelling dictionary name. Available dictionaries: none. To make it working 371 | # install python-enchant package. 372 | spelling-dict= 373 | 374 | # List of comma separated words that should not be checked. 375 | spelling-ignore-words= 376 | 377 | # A path to a file that contains private dictionary; one word per line. 378 | spelling-private-dict-file= 379 | 380 | # Tells whether to store unknown words to indicated private dictionary in 381 | # --spelling-private-dict-file option instead of raising a message. 382 | spelling-store-unknown-words=no 383 | 384 | 385 | [IMPORTS] 386 | 387 | # Deprecated modules which should not be used, separated by a comma 388 | deprecated-modules=regsub, 389 | TERMIOS, 390 | Bastion, 391 | rexec, 392 | sets 393 | 394 | # Create a graph of every (i.e. internal and external) dependencies in the 395 | # given file (report RP0402 must not be disabled) 396 | import-graph= 397 | 398 | # Create a graph of external dependencies in the given file (report RP0402 must 399 | # not be disabled) 400 | ext-import-graph= 401 | 402 | # Create a graph of internal dependencies in the given file (report RP0402 must 403 | # not be disabled) 404 | int-import-graph= 405 | 406 | # Force import order to recognize a module as part of the standard 407 | # compatibility libraries. 408 | known-standard-library= 409 | 410 | # Force import order to recognize a module as part of a third party library. 411 | known-third-party=enchant, absl 412 | 413 | # Analyse import fallback blocks. This can be used to support both Python 2 and 414 | # 3 compatible code, which means that the block might have code that exists 415 | # only in one or another interpreter, leading to false positives when analysed. 416 | analyse-fallback-blocks=no 417 | 418 | 419 | [CLASSES] 420 | 421 | # List of method names used to declare (i.e. assign) instance attributes. 422 | defining-attr-methods=__init__, 423 | __new__, 424 | setUp 425 | 426 | # List of member names, which should be excluded from the protected access 427 | # warning. 428 | exclude-protected=_asdict, 429 | _fields, 430 | _replace, 431 | _source, 432 | _make 433 | 434 | # List of valid names for the first argument in a class method. 435 | valid-classmethod-first-arg=cls, 436 | class_ 437 | 438 | # List of valid names for the first argument in a metaclass class method. 439 | valid-metaclass-classmethod-first-arg=mcs 440 | 441 | 442 | [EXCEPTIONS] 443 | 444 | # Exceptions that will emit a warning when being caught. 
Defaults to 445 | # "Exception" 446 | overgeneral-exceptions=builtins.StandardError, 447 | builtins.Exception, 448 | builtins.BaseException 449 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.9" 7 | 8 | sphinx: 9 | configuration: docs/conf.py 10 | 11 | formats: 12 | - pdf 13 | - epub 14 | 15 | python: 16 | install: 17 | - requirements: docs/requirements.txt 18 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.black-formatter" 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.insertFinalNewline": true, 3 | "files.trimFinalNewlines": true, 4 | "files.trimTrailingWhitespace": true, 5 | "files.associations": { 6 | ".pylintrc": "ini" 7 | }, 8 | "python.testing.unittestEnabled": false, 9 | "python.testing.nosetestsEnabled": false, 10 | "python.testing.pytestEnabled": true, 11 | "python.linting.pylintUseMinimalCheckers": false, 12 | "[python]": { 13 | "editor.rulers": [80], 14 | "editor.tabSize": 2, 15 | "editor.defaultFormatter": "ms-python.black-formatter", 16 | "editor.formatOnSave": true, 17 | "editor.detectIndentation": false 18 | }, 19 | "python.formatting.provider": "none", 20 | "black-formatter.path": ["pyink"], 21 | "files.watcherExclude": { 22 | "**/.git/**": true 23 | }, 24 | "files.exclude": { 25 | "**/__pycache__": true, 26 | "**/.pytest_cache": true, 27 | "**/*.egg-info": true 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of coix's significant contributors. 2 | # 3 | # This does not necessarily list everyone who has contributed code, 4 | # especially since many employees of one corporation may be contributing. 5 | # To see the full list of contributors, see the revision history in 6 | # source control. 7 | Google LLC 8 | Du Phan 9 | Heiko Zimmermann 10 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | 23 | 24 | ## [Unreleased] 25 | 26 | ## [0.1.0] - 2024-04-17 27 | 28 | * First stable release, including documentations, tutorials, and examples. 29 | 30 | ## [0.0.1] - 2023-04-25 31 | 32 | * Initial release for testing (please don't use) 33 | 34 | [Unreleased]: https://github.com/jax-ml/coix/compare/v0.1.0...HEAD 35 | [0.1.0]: https://github.com/jax-ml/coix/releases/tag/v0.1.0 36 | [0.0.1]: https://github.com/jax-ml/coix/releases/tag/0.0.1 37 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 | 
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 | 
19 | ## Code Reviews
20 | 
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 | 
26 | ## Community Guidelines
27 | 
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 
2 |                                  Apache License
3 |                            Version 2.0, January 2004
4 |                         http://www.apache.org/licenses/
5 | 
6 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 |    1. Definitions.
9 | 
10 |       "License" shall mean the terms and conditions for use, reproduction,
11 |       and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 |       "Licensor" shall mean the copyright owner or entity authorized by
14 |       the copyright owner that is granting the License.
15 | 
16 |       "Legal Entity" shall mean the union of the acting entity and all
17 |       other entities that control, are controlled by, or are under common
18 |       control with that entity. For the purposes of this definition,
19 |       "control" means (i) the power, direct or indirect, to cause the
20 |       direction or management of such entity, whether by contract or
21 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 |       outstanding shares, or (iii) beneficial ownership of such entity.
23 | 
24 |       "You" (or "Your") shall mean an individual or Legal Entity
25 |       exercising permissions granted by this License.
26 | 
27 |       "Source" form shall mean the preferred form for making modifications,
28 |       including but not limited to software source code, documentation
29 |       source, and configuration files.
30 | 
31 |       "Object" form shall mean any form resulting from mechanical
32 |       transformation or translation of a Source form, including but
33 |       not limited to compiled object code, generated documentation,
34 |       and conversions to other media types.
35 | 
36 |       "Work" shall mean the work of authorship, whether in Source or
37 |       Object form, made available under the License, as indicated by a
38 |       copyright notice that is included in or attached to the work
39 |       (an example is provided in the Appendix below).
40 | 
41 |       "Derivative Works" shall mean any work, whether in Source or Object
42 |       form, that is based on (or derived from) the Work and for which the
43 |       editorial revisions, annotations, elaborations, or other modifications
44 |       represent, as a whole, an original work of authorship. For the purposes
45 |       of this License, Derivative Works shall not include works that remain
46 |       separable from, or merely link (or bind by name) to the interfaces of,
47 |       the Work and Derivative Works thereof.
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: test 2 | 3 | format: FORCE 4 | pyink . --exclude=docs 5 | isort . 6 | 7 | install: FORCE 8 | pip install -e .[dev] 9 | 10 | lint: FORCE 11 | pylint coix 12 | pyink . --check --exclude=docs 13 | isort --check . 
14 | 
15 | test: lint FORCE
16 | 	pytest -vv -n auto
17 | 
18 | docs: FORCE
19 | 	$(MAKE) -C docs html
20 | 
21 | FORCE:
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # coix
2 | 
3 | [![Unittests](https://github.com/jax-ml/coix/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/jax-ml/coix/actions/workflows/pytest_and_autopublish.yml)
4 | [![Documentation Status](https://readthedocs.org/projects/coix/badge/?version=latest)](https://coix.readthedocs.io/en/latest/?badge=latest)
5 | [![PyPI version](https://badge.fury.io/py/coix.svg)](https://badge.fury.io/py/coix)
6 | 
7 | Coix (COmbinators In jaX) is a flexible and backend-agnostic implementation of inference combinators [(Stites and Zimmermann et al., 2021)](https://arxiv.org/abs/2103.00668), a set of program transformations for compositional inference with probabilistic programs. Coix ships with backends for numpyro and oryx, and a set of pre-implemented losses and utility functions that allow you to implement and run a wide variety of inference algorithms out of the box.
8 | 
9 | Coix is a lightweight framework which includes the following main components:
10 | 
11 | - **coix.api:** Implementation of the program combinators.
12 | - **coix.core:** Basic program transformations which are used to modify the behavior of a stochastic program.
13 | - **coix.loss:** Common objectives for variational inference.
14 | - **coix.algo:** Example inference algorithms.
15 | 
16 | Currently, we support the [numpyro](https://github.com/pyro-ppl/numpyro) and [oryx](https://github.com/jax-ml/oryx) backends, but other backends can be easily added via the [coix.register_backend](https://coix.readthedocs.io/en/latest/core.html#coix.core.register_backend) utility.
17 | 
18 | *This is not an officially supported Google product.*
19 | 
20 | ## Installation
21 | 
22 | To install Coix, you can use pip:
23 | 
24 | ```
25 | pip install coix
26 | ```
27 | 
28 | or you can clone the repository:
29 | 
30 | ```
31 | git clone https://github.com/jax-ml/coix.git
32 | cd coix
33 | pip install -e .[dev,doc]
34 | ```
35 | 
36 | Many examples run faster on accelerators. You can follow the [JAX installation](https://jax.readthedocs.io/en/latest/installation.html) instructions to install JAX with GPU or TPU support.
37 | 
--------------------------------------------------------------------------------
/coix/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The coix Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
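
# A minimal usage sketch of the API re-exported below (hedged: the toy
# `model`/`guide` pair here is hypothetical, but the combinator calls mirror
# `coix.algo.svi` and the patterns in `coix/algo_test.py`):
#
#   import coix
#   import numpyro
#   import numpyro.distributions as dist
#   from jax import random
#
#   coix.set_backend("coix.numpyro")
#
#   def model(key):
#     x = numpyro.sample("x", dist.Normal(0.0, 1.0), rng_key=key)
#     numpyro.sample("obs", dist.Normal(x, 1.0), obs=0.5)
#     return x
#
#   def guide(key):
#     return numpyro.sample("x", dist.Normal(0.0, 2.0), rng_key=key)
#
#   # Combine target and proposal, then run the traced program.
#   program = coix.propose(model, guide, loss_fn=coix.loss.elbo_loss)
#   out, trace, metrics = coix.traced_evaluate(program)(random.PRNGKey(0))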
14 | 
15 | """coix API."""
16 | 
17 | from coix import algo
18 | from coix import loss
19 | from coix import util
20 | from coix.api import compose
21 | from coix.api import extend
22 | from coix.api import fori_loop
23 | from coix.api import memoize
24 | from coix.api import propose
25 | from coix.api import resample
26 | from coix.core import detach
27 | from coix.core import empirical
28 | from coix.core import prng_key
29 | from coix.core import register_backend
30 | from coix.core import set_backend
31 | from coix.core import stick_the_landing
32 | from coix.core import suffix
33 | from coix.core import traced_evaluate
34 | 
35 | __all__ = [
36 |     "__version__",
37 |     "empirical",
38 |     "algo",
39 |     "compose",
40 |     "detach",
41 |     "extend",
42 |     "fori_loop",
43 |     "loss",
44 |     "memoize",
45 |     "prng_key",
46 |     "propose",
47 |     "register_backend",
48 |     "resample",
49 |     "set_backend",
50 |     "stick_the_landing",
51 |     "suffix",
52 |     "traced_evaluate",
53 |     "util",
54 | ]
55 | 
56 | # A new PyPI release will be pushed every time `__version__` is increased.
57 | # When changing this, also update the CHANGELOG.md.
58 | __version__ = "0.1.0"
59 | 
--------------------------------------------------------------------------------
/coix/algo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The coix Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Inference algorithms."""
16 | 
17 | import functools
18 | 
19 | from coix.api import compose
20 | from coix.api import extend
21 | from coix.api import fori_loop
22 | from coix.api import propose
23 | from coix.api import resample
24 | from coix.core import detach
25 | from coix.core import stick_the_landing
26 | from coix.loss import apg_loss
27 | from coix.loss import avo_loss
28 | from coix.loss import elbo_loss
29 | from coix.loss import fkl_loss
30 | from coix.loss import iwae_loss
31 | from coix.loss import rkl_loss
32 | from coix.loss import rws_loss
33 | 
34 | __all__ = [
35 |     "aft",
36 |     "apgs",
37 |     "dais",
38 |     "nasmc",
39 |     "nvi_avo",
40 |     "nvi_fkl",
41 |     "nvi_rkl",
42 |     "rws",
43 |     "svi",
44 |     "svi_iwae",
45 |     "svi_stl",
46 |     "vsmc",
47 | ]
48 | 
49 | 
50 | def _use_fori_loop(targets, num_targets, *fns):
51 |   """Whether or not to use the rolling fori_loop."""
52 |   if callable(targets):
53 |     if num_targets is None:
54 |       raise ValueError("To use fori_loop, num_targets needs to be specified.")
55 |     for fn in fns:
56 |       if not callable(fn):
57 |         raise ValueError(
58 |             "To use fori_loop, input programs need to be callable,"
59 |             f" but got {type(fn)}."
60 |         )
61 |     return True
62 |   return False
63 | 
64 | 
65 | def aft(targets, flows, *, num_targets=None):
66 |   """Annealed Flow Transport.
67 | 
68 |   1. *Annealed Flow Transport Monte Carlo*,
69 |      Michael Arbel, Alexander G. D. G.
Matthews, Arnaud Doucet 70 | https://arxiv.org/abs/2102.07501 71 | 72 | Args: 73 | targets: a list of target programs 74 | flows: a list of flows 75 | num_targets: the number of targets 76 | 77 | Returns: 78 | q: the inference program 79 | """ 80 | if _use_fori_loop(targets, num_targets, flows): 81 | 82 | def body_fun(i, q): 83 | p, q = targets(i + 1), compose(flows(i), resample(q)) 84 | return propose(p, q, loss_fn=elbo_loss, detach=True) 85 | 86 | return fori_loop(0, num_targets - 1, body_fun, targets(0)) 87 | 88 | q = targets[0] 89 | for p, flow in zip(targets[1:], flows): 90 | q = propose(p, compose(flow, resample(q)), loss_fn=elbo_loss, detach=True) 91 | return q 92 | 93 | 94 | def apgs(target, kernels, *, num_sweeps=1): 95 | """Amortized Population Gibbs Sampler. 96 | 97 | 1. *Amortized Population Gibbs Samplers with Neural Sufficient Statistics*, 98 | Hao Wu, Heiko Zimmermann, Eli Sennesh, Tuan Anh Le, Jan-Willem van de 99 | Meent 100 | https://arxiv.org/abs/1911.01382 101 | 102 | Args: 103 | target: the target program 104 | kernels: the Gibbs kernels 105 | num_sweeps: the number of sweeps 106 | 107 | Returns: 108 | q: the inference program 109 | """ 110 | kernels = [detach(k) for k in kernels] 111 | q = functools.reduce(lambda a, b: compose(b, a), kernels[1:], kernels[0]) 112 | q = propose(target, q, loss_fn=rws_loss) 113 | 114 | def body_fn(_, q): 115 | for k in kernels: 116 | q = compose(k, resample(q)) 117 | q = propose(extend(target, k), q, loss_fn=apg_loss) 118 | return q 119 | 120 | return fori_loop(0, num_sweeps, body_fn, q) 121 | 122 | 123 | def dais(targets, momentum, leapfrog, refreshment, *, num_targets=None): 124 | """Differentiable Annealed Importance Sampling. 125 | 126 | 1. *MCMC Variational Inference via Uncorrected Hamiltonian Annealing*, 127 | Tomas Geffner, Justin Domke 128 | https://arxiv.org/abs/2107.04150 129 | 2. *Differentiable Annealed Importance Sampling and the Perils of Gradient 130 | Noise*, 131 | Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse 132 | https://arxiv.org/abs/2107.10211 133 | 134 | Args: 135 | targets: a list of target programs 136 | momentum: the momentum program which calculates kinetic energy 137 | leapfrog: the program which performs leapfrog update 138 | refreshment: the momentum refreshment program 139 | num_targets: the number of targets 140 | 141 | Returns: 142 | q: the inference program 143 | """ 144 | if _use_fori_loop(targets, num_targets): 145 | 146 | def body_fun(i, q): 147 | assert callable(targets) 148 | p = extend(compose(momentum, targets(i), suffix=False), refreshment) 149 | return propose(p, compose(refreshment, compose(leapfrog, q))) 150 | 151 | q = compose(momentum, targets(0), suffix=False) 152 | q = fori_loop(1, num_targets - 1, body_fun, q) 153 | p = compose(momentum, targets(num_targets - 1), suffix=False) 154 | q = compose(refreshment, compose(leapfrog, q)) 155 | return propose(extend(p, refreshment), q, loss_fn=iwae_loss) 156 | 157 | targets = [compose(momentum, p, suffix=False) for p in targets] 158 | q = targets[0] 159 | loss_fns = (None,) * (len(targets) - 2) + (iwae_loss,) 160 | for p, loss_fn in zip(targets[1:], loss_fns): 161 | q = compose(refreshment, compose(leapfrog, q)) 162 | q = propose(extend(p, refreshment), q, loss_fn=loss_fn) 163 | return q 164 | 165 | 166 | def nasmc(targets, proposals, *, num_targets=None): 167 | """Neural Adaptive Sequential Monte Carlo. 168 | 169 | 1. *Neural Adaptive Sequential Monte Carlo*, 170 | Shixiang Gu, Zoubin Ghahramani, Richard E. 
Turner
171 |      https://arxiv.org/abs/1506.03338
172 | 
173 |   Args:
174 |     targets: a list of target programs
175 |     proposals: a list of proposal programs
176 |     num_targets: the number of targets
177 | 
178 |   Returns:
179 |     q: the inference program
180 |   """
181 |   if _use_fori_loop(targets, num_targets, proposals):
182 | 
183 |     def body_fun(i, q):
184 |       p, q = targets(i), compose(detach(proposals(i)), resample(q))
185 |       return propose(p, q, loss_fn=rws_loss)
186 | 
187 |     q = propose(targets(0), detach(proposals(0)), loss_fn=rws_loss)
188 |     return fori_loop(1, num_targets, body_fun, q)
189 | 
190 |   q = propose(targets[0], detach(proposals[0]), loss_fn=rws_loss)
191 |   for p, fwd in zip(targets[1:], proposals[1:]):
192 |     q = propose(p, compose(detach(fwd), resample(q)), loss_fn=rws_loss)
193 |   return q
194 | 
195 | 
196 | def nvi_avo(targets, forwards, reverses, *, num_targets=None):
197 |   """AIS with Annealed Variational Objective.
198 | 
199 |   1. *Improving Explorability in Variational Inference with Annealed Variational
200 |      Objectives*,
201 |      Chin-Wei Huang, Shawn Tan, Alexandre Lacoste, Aaron Courville
202 |      https://arxiv.org/abs/1809.01818
203 | 
204 |   Args:
205 |     targets: a list of target programs
206 |     forwards: a list of forward kernels
207 |     reverses: a list of reverse kernels
208 |     num_targets: the number of targets
209 | 
210 |   Returns:
211 |     q: the inference program
212 |   """
213 |   if _use_fori_loop(targets, num_targets, forwards, reverses):
214 | 
215 |     def body_fun(i, q):
216 |       p, q = extend(targets(i + 1), reverses(i)), compose(forwards(i), q)
217 |       return propose(p, q, loss_fn=avo_loss, detach=True)
218 | 
219 |     return fori_loop(0, num_targets - 1, body_fun, targets(0))
220 | 
221 |   q = targets[0]
222 |   for p, fwd, rev in zip(targets[1:], forwards, reverses):
223 |     q = propose(extend(p, rev), compose(fwd, q), loss_fn=avo_loss, detach=True)
224 |   return q
225 | 
226 | 
227 | def nvi_fkl(targets, proposals, *, num_targets=None):
228 |   """Nested Variational Inference with forward KL objective.
229 | 
230 |   Note: The implementation assumes that targets are smoothing distributions.
231 |   This is different from `nasmc`, where we assume that the targets are filtering
232 |   distributions. We also assume that the final target does not have parameters.
233 | 
234 |   1. *Nested Variational Inference*,
235 |      Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
236 |      https://arxiv.org/abs/2106.11302
237 | 
238 |   Args:
239 |     targets: a list of target programs
240 |     proposals: a list of proposal programs
241 |     num_targets: the number of targets
242 | 
243 |   Returns:
244 |     q: the inference program
245 |   """
246 |   if _use_fori_loop(targets, num_targets, proposals):
247 | 
248 |     def body_fun(i, q):
249 |       p, q = targets(i), compose(detach(proposals(i)), resample(q))
250 |       return propose(p, q, loss_fn=fkl_loss)
251 | 
252 |     q = propose(targets(0), detach(proposals(0)), loss_fn=fkl_loss)
253 |     return fori_loop(1, num_targets, body_fun, q)
254 | 
255 |   q = propose(targets[0], detach(proposals[0]), loss_fn=fkl_loss)
256 |   for p, fwd in zip(targets[1:], proposals[1:]):
257 |     q = propose(p, compose(detach(fwd), resample(q)), loss_fn=fkl_loss)
258 |   return q
259 | 
260 | 
261 | def nvi_rkl(targets, forwards, reverses, *, num_targets=None):
262 |   """Nested Variational Inference with reverse KL objective.
263 | 
264 |   If `targets` is a callable which takes an integer index and returns the i-th
265 |   target, we will use the `fori_loop` combinator to improve the compilation time. This
266 |   requires `num_targets` to be a concrete value.
267 | 
268 |   Note: In nested VI, we typically assume that the final target does not have
269 |   parameters. This allows us to optimize intermediate KLs to bridge from the
270 |   initial target to the final target. Here we use the ELBO loss in the last step
271 |   to also maximize likelihood in case there are parameters in the final target.
272 | 
273 |   1. *Nested Variational Inference*,
274 |      Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
275 |      https://arxiv.org/abs/2106.11302
276 | 
277 |   Args:
278 |     targets: a list of target programs
279 |     forwards: a list of forward kernels
280 |     reverses: a list of reverse kernels
281 |     num_targets: the number of targets
282 | 
283 |   Returns:
284 |     q: the inference program
285 |   """
286 |   if _use_fori_loop(targets, num_targets, forwards, reverses):
287 | 
288 |     def body_fun(i, q):
289 |       p, fwd, rev = targets(i + 1), forwards(i), reverses(i)
290 |       p, q = extend(p, rev), compose(stick_the_landing(fwd), resample(q))
291 |       return propose(p, q, loss_fn=rkl_loss, detach=True)
292 | 
293 |     return fori_loop(0, num_targets - 1, body_fun, targets(0))
294 | 
295 |   q = targets[0]
296 |   for p, fwd, rev in zip(targets[1:], forwards, reverses):
297 |     p, q = extend(p, rev), compose(stick_the_landing(fwd), resample(q))
298 |     q = propose(p, q, loss_fn=rkl_loss, detach=True)
299 |   return q
300 | 
301 | 
302 | def rws(target, proposal):
303 |   """Reweighted Wake-Sleep.
304 | 
305 |   1. *Reweighted Wake-Sleep*,
306 |      Jörg Bornschein, Yoshua Bengio
307 |      https://arxiv.org/abs/1406.2751
308 |   2. *Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow*,
309 |      Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood
310 |      https://arxiv.org/abs/1805.10469
311 | 
312 |   Args:
313 |     target: the target program
314 |     proposal: the proposal program
315 | 
316 |   Returns:
317 |     q: the inference program
318 |   """
319 |   return propose(target, detach(proposal), loss_fn=rws_loss)
320 | 
321 | 
322 | def svi(target, proposal):
323 |   """Stochastic Variational Inference.
324 | 
325 |   1. *Auto-Encoding Variational Bayes*,
326 |      Diederik P Kingma, Max Welling
327 |      https://arxiv.org/abs/1312.6114
328 |   2. *Stochastic Backpropagation and Approximate Inference in Deep Generative
329 |      Models*,
330 |      Danilo Jimenez Rezende, Shakir Mohamed, Daan Wierstra
331 |      https://arxiv.org/abs/1401.4082
332 | 
333 |   Args:
334 |     target: the target program
335 |     proposal: the proposal program
336 | 
337 |   Returns:
338 |     q: the inference program
339 |   """
340 |   return propose(target, proposal, loss_fn=elbo_loss)
341 | 
342 | 
343 | def svi_iwae(target, proposal):
344 |   """SVI with Importance Weighted Autoencoder objective.
345 | 
346 |   1. *Importance Weighted Autoencoders*,
347 |      Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
348 |      https://arxiv.org/abs/1509.00519
349 | 
350 |   Args:
351 |     target: the target program
352 |     proposal: the proposal program
353 | 
354 |   Returns:
355 |     q: the inference program
356 |   """
357 |   return propose(target, proposal, loss_fn=iwae_loss)
358 | 
359 | 
360 | def svi_stl(target, proposal):
361 |   """SVI with Sticking-the-Landing objective.
362 | 
363 |   1.
*Sticking the Landing: Simple, Lower-Variance Gradient Estimators for 364 | Variational Inference*, 365 | Geoffrey Roeder, Yuhuai Wu, David Duvenaud 366 | https://arxiv.org/abs/1703.09194 367 | 368 | Args: 369 | target: the target program 370 | proposal: the proposal program 371 | 372 | Returns: 373 | q: the inference program 374 | """ 375 | return propose(target, stick_the_landing(proposal), loss_fn=elbo_loss) 376 | 377 | 378 | def vsmc(targets, proposals, *, num_targets=None): 379 | """Variational Sequential Monte Carlo. 380 | 381 | Note: Here, we assume that the dimension of variables is constant (modulo 382 | masking) during SMC steps. The targets can be filtering distributions or 383 | smoothing distributions (as in [4]). 384 | 385 | 1. *Filtering Variational Objectives*, 386 | Chris J. Maddison, Dieterich Lawson, George Tucker, Nicolas Heess, 387 | Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, Yee Whye Teh 388 | https://arxiv.org/abs/1705.09279 389 | 2. *Auto-Encoding Sequential Monte Carlo*, 390 | Tuan Anh Le, Maximilian Igl, Tom Rainforth, Tom Jin, Frank Wood 391 | https://arxiv.org/abs/1705.10306 392 | 3. *Variational Sequential Monte Carlo*, 393 | Christian A. Naesseth, Scott W. Linderman, Rajesh Ranganath, David M. Blei 394 | https://arxiv.org/abs/1705.11140 395 | 4. *Twisted Variational Sequential Monte Carlo*, 396 | Dieterich Lawson, George Tucker, Christian A Naesseth, Chris J Maddison, 397 | Ryan P Adams, Yee Whye Teh 398 | http://bayesiandeeplearning.org/2018/papers/111.pdf 399 | 400 | Args: 401 | targets: a list of target programs 402 | proposals: a list of proposal programs 403 | num_targets: the number of targets 404 | 405 | Returns: 406 | q: the inference program 407 | """ 408 | if _use_fori_loop(targets, num_targets, proposals): 409 | 410 | def body_fun(i, q): 411 | return propose(targets(i), compose(proposals(i), resample(q))) 412 | 413 | q = propose(targets(0), proposals(0)) 414 | q = fori_loop(1, num_targets - 1, body_fun, q) 415 | q = compose(proposals(num_targets - 1), resample(q)) 416 | return propose(targets(num_targets - 1), q, loss_fn=iwae_loss) 417 | 418 | q = propose(targets[0], proposals[0]) 419 | loss_fns = (None,) * (len(proposals) - 2) + (iwae_loss,) 420 | for p, fwd, loss_fn in zip(targets[1:], proposals[1:], loss_fns): 421 | q = propose(p, compose(fwd, resample(q)), loss_fn=loss_fn) 422 | return q 423 | -------------------------------------------------------------------------------- /coix/algo_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
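
# A condensed sketch of the test pattern used throughout this file (hedged:
# `make_p` and `make_q` are hypothetical stand-ins for the vmapped
# model/guide factories defined below; see `run_inference`):
#
#   def loss_fn(params, key):
#     program = coix.algo.svi(make_p(params), make_q(params))
#     metrics = coix.traced_evaluate(program)(key)[2]
#     return metrics["loss"], metrics
#
#   params, _ = coix.util.train(
#       loss_fn, init_params, optax.adam(0.01), num_steps=1000
#   )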
14 | 15 | """Tests for algo.py.""" 16 | 17 | import functools 18 | 19 | import coix 20 | from jax import random 21 | import jax.numpy as jnp 22 | import numpy as np 23 | import numpyro 24 | import numpyro.distributions as dist 25 | import optax 26 | 27 | coix.set_backend("coix.numpyro") 28 | 29 | np.random.seed(0) 30 | num_data, dim = 4, 2 31 | data = np.random.randn(num_data, dim).astype(np.float32) 32 | loc_p = np.random.randn(dim).astype(np.float32) 33 | precision_p = np.random.rand(dim).astype(np.float32) 34 | scale_p = np.sqrt(1 / precision_p) 35 | precision_x = np.random.rand(dim).astype(np.float32) 36 | scale_x = np.sqrt(1 / precision_x) 37 | precision_q = precision_p + num_data * precision_x 38 | loc_q = (data.sum(0) * precision_x + loc_p * precision_p) / precision_q 39 | log_scale_q = -0.5 * np.log(precision_q) 40 | 41 | 42 | def vmap(p): 43 | return numpyro.handlers.plate("N", 5)(p) 44 | 45 | 46 | def model(params, key): 47 | del params 48 | key_z, key_next = random.split(key) 49 | z = numpyro.sample("z", dist.Normal(loc_p, scale_p).to_event(), rng_key=key_z) 50 | z = jnp.repeat(z[..., None, :], num_data, axis=-2) 51 | x = numpyro.sample("x", dist.Normal(z, scale_x).to_event(2), obs=data) 52 | return key_next, z, x 53 | 54 | 55 | def guide(params, key, *args): 56 | del args 57 | key, _ = random.split(key) # split here to test tie_in 58 | scale_q = jnp.exp(params["log_scale_q"]) 59 | z = numpyro.sample( 60 | "z", dist.Normal(params["loc_q"], scale_q).to_event(), rng_key=key 61 | ) 62 | return z 63 | 64 | 65 | def check_ess(make_program): 66 | params = {"loc_q": loc_q, "log_scale_q": log_scale_q} 67 | p = vmap(functools.partial(model, params)) 68 | q = vmap(functools.partial(guide, params)) 69 | program = make_program(p, q) 70 | 71 | key = random.PRNGKey(0) 72 | ess = coix.traced_evaluate(program)(key)[2]["ess"] 73 | np.testing.assert_allclose(ess, 5.0) 74 | 75 | 76 | def run_inference(make_program, num_steps=1000): 77 | """Performs inference given an algorithm `make_program`.""" 78 | 79 | def loss_fn(params, key): 80 | p = vmap(functools.partial(model, params)) 81 | q = vmap(functools.partial(guide, params)) 82 | program = make_program(p, q) 83 | 84 | metrics = coix.traced_evaluate(program)(key)[2] 85 | return metrics["loss"], metrics 86 | 87 | init_params = { 88 | "loc_q": jnp.zeros_like(loc_q), 89 | "log_scale_q": jnp.zeros_like(log_scale_q), 90 | } 91 | params, _ = coix.util.train( 92 | loss_fn, init_params, optax.adam(0.01), num_steps=num_steps 93 | ) 94 | 95 | np.testing.assert_allclose(params["loc_q"], loc_q, atol=0.2) 96 | np.testing.assert_allclose(params["log_scale_q"], log_scale_q, atol=0.2) 97 | 98 | 99 | def test_apgs(): 100 | check_ess(lambda p, q: coix.algo.apgs(p, [q])) 101 | run_inference(lambda p, q: coix.algo.apgs(p, [q])) 102 | 103 | 104 | def test_rws(): 105 | check_ess(coix.algo.rws) 106 | run_inference(coix.algo.rws) 107 | 108 | 109 | def test_svi_elbo(): 110 | check_ess(coix.algo.svi) 111 | run_inference(coix.algo.svi) 112 | 113 | 114 | def test_svi_iwae(): 115 | check_ess(coix.algo.svi_iwae) 116 | run_inference(coix.algo.svi_iwae) 117 | 118 | 119 | def test_svi_stl(): 120 | check_ess(coix.algo.svi_stl) 121 | run_inference(coix.algo.svi_stl) 122 | -------------------------------------------------------------------------------- /coix/api_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for api.py.""" 16 | 17 | import coix 18 | import jax 19 | from jax import random 20 | import jax.numpy as jnp 21 | import numpy as np 22 | import numpyro 23 | import numpyro.distributions as dist 24 | import pytest 25 | 26 | coix.set_backend("coix.numpyro") 27 | 28 | 29 | def test_compose(): 30 | def p(key): 31 | key, subkey = random.split(key) 32 | x = numpyro.sample("x", dist.Normal(0, 1), rng_key=subkey) 33 | return key, x 34 | 35 | def f(key, x): 36 | return numpyro.sample("z", dist.Normal(x, 1), rng_key=key) 37 | 38 | _, p_trace, _ = coix.traced_evaluate(coix.compose(f, p))(random.PRNGKey(0)) 39 | assert set(p_trace.keys()) == {"x", "z"} 40 | 41 | 42 | def test_extend(): 43 | def p(key): 44 | key, subkey = random.split(key) 45 | x = numpyro.sample("x", dist.Normal(0, 1), rng_key=subkey) 46 | return key, x 47 | 48 | def f(key, x): 49 | return (numpyro.sample("z", dist.Normal(x, 1), rng_key=key),) 50 | 51 | def g(z): 52 | return z + 1 53 | 54 | key = random.PRNGKey(0) 55 | out, trace, _ = coix.traced_evaluate(coix.extend(p, f))(key) 56 | assert set(trace.keys()) == {"x", "z"} 57 | 58 | expected_key, expected_x = p(key) 59 | expected_key = random.key_data(expected_key) 60 | actual_key = random.key_data(out[0]) 61 | np.testing.assert_allclose(actual_key, expected_key) 62 | np.testing.assert_allclose(out[1], expected_x) 63 | 64 | marginal_pfg = coix.traced_evaluate(coix.extend(p, coix.compose(g, f)))(key)[ 65 | 0 66 | ] 67 | actual_key2, actual_x2 = marginal_pfg 68 | actual_key2 = random.key_data(actual_key2) 69 | np.testing.assert_allclose(actual_key2, expected_key) 70 | np.testing.assert_allclose(actual_x2, expected_x) 71 | 72 | 73 | def test_propose(): 74 | def p(key): 75 | key, subkey = random.split(key) 76 | x = numpyro.sample("x", dist.Normal(0, 1), rng_key=subkey) 77 | return key, x 78 | 79 | def f(key, x): 80 | return numpyro.sample("z", dist.Normal(x, 1), rng_key=key) 81 | 82 | def q(key): 83 | return numpyro.sample("x", dist.Normal(1, 2), rng_key=key) 84 | 85 | program = coix.propose(coix.extend(p, f), q) 86 | key = random.PRNGKey(0) 87 | out, trace, metrics = coix.traced_evaluate(program)(key) 88 | assert set(trace.keys()) == {"x", "z"} 89 | assert isinstance(out, tuple) and len(out) == 2 90 | assert out[0].shape == key.shape 91 | with np.testing.assert_raises(AssertionError): 92 | np.testing.assert_allclose(metrics["log_density"], 0.0) 93 | 94 | def vmap(p): 95 | return numpyro.handlers.plate("N", 3)(p) 96 | 97 | particle_program = coix.propose(vmap(coix.extend(p, f)), vmap(q)) 98 | particle_out = particle_program(key) 99 | assert isinstance(particle_out, tuple) and len(particle_out) == 2 100 | assert particle_out[1].shape == (3,) 101 | 102 | 103 | def test_resample(): 104 | def q(key): 105 | return numpyro.sample("x", dist.Normal(1, 2), rng_key=key) 106 | 107 | particle_program = numpyro.handlers.plate("N", 3)(q) 108 | key = random.PRNGKey(0) 109 | particle_out = 
coix.resample(particle_program)(key) 110 | assert particle_out.shape == (3,) 111 | 112 | 113 | def test_resample_one(): 114 | def q(key): 115 | x = numpyro.sample("x", dist.Normal(1, 2), rng_key=key) 116 | return numpyro.sample("z", dist.Normal(x, 1), obs=0.0) 117 | 118 | particle_program = numpyro.handlers.plate("N", 3)(q) 119 | key = random.PRNGKey(0) 120 | particle_out = coix.resample(particle_program, num_samples=())(key) 121 | assert not jnp.shape(particle_out) 122 | 123 | 124 | def test_fori_loop(): 125 | def drift(key, x): 126 | key_out, key = random.split(key) 127 | x_new = numpyro.sample("x", dist.Normal(x, 1.0), rng_key=key) 128 | return key_out, x_new 129 | 130 | compile_time = {"value": 0} 131 | 132 | def body_fun(_, q): 133 | compile_time["value"] += 1 134 | return coix.propose(drift, coix.compose(drift, q)) 135 | 136 | q = drift 137 | for i in range(5): 138 | q = body_fun(i, q) 139 | x_init = np.zeros(3, np.float32) 140 | q(random.PRNGKey(0), x_init) 141 | assert compile_time["value"] == 5 142 | 143 | random_walk = coix.fori_loop(0, 5, body_fun, drift) 144 | random_walk(random.PRNGKey(0), x_init) 145 | assert compile_time["value"] == 6 146 | 147 | 148 | # TODO(phandu): Support memoised arrays. 149 | @pytest.mark.skip(reason="Currently, we only support memoised lists.") 150 | def test_memoize(): 151 | def model(key): 152 | x = numpyro.sample("x", dist.Normal(0, 1), rng_key=key) 153 | y = numpyro.sample("y", dist.Normal(x, 1), obs=0.0) 154 | return x, y 155 | 156 | def guide(key): 157 | return numpyro.sample("x", dist.Normal(1, 2), rng_key=key) 158 | 159 | def vmodel(key): 160 | return jax.vmap(model)(random.split(key, 5)) 161 | 162 | def vguide(key): 163 | return jax.vmap(guide)(random.split(key, 3)) 164 | 165 | memory = {"x": np.array([2, 4])} 166 | program = coix.memoize(vmodel, vguide, memory) 167 | out, trace, metrics = coix.traced_evaluate(program)(random.PRNGKey(0)) 168 | assert set(trace.keys()) == {"x"} 169 | assert "memory" in metrics 170 | assert metrics["memory"]["x"].shape == (2,) 171 | assert out[0].shape == (2,) 172 | -------------------------------------------------------------------------------- /coix/core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
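# --- Illustrative sketch (not part of the source tree) ----------------------
# The module below dispatches primitives such as `traced_evaluate`, `detach`,
# and `suffix` to a pluggable backend. A minimal sketch of selecting and
# registering backends; `my.backend` and `my_traced_evaluate` are hypothetical
# names:
#
#   import coix
#
#   coix.set_backend("coix.numpyro")  # import and use a known backend
#   coix.core.register_backend("my.backend",
#                              traced_evaluate=my_traced_evaluate)
#   coix.set_backend("my.backend")    # use the custom implementation
#
# Optional primitives left unregistered fall back to returning the program
# unchanged (and `prng_key` falls back to returning None).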
14 | 15 | """Program transforms.""" 16 | 17 | import importlib 18 | 19 | __all__ = [ 20 | "detach", 21 | "empirical", 22 | "prng_key", 23 | "register_backend", 24 | "set_backend", 25 | "stick_the_landing", 26 | "suffix", 27 | "traced_evaluate", 28 | ] 29 | 30 | _BACKENDS = {} 31 | _COIX_BACKEND = None 32 | 33 | 34 | # pylint:disable=redefined-outer-name 35 | def register_backend( 36 | backend, 37 | traced_evaluate=None, 38 | empirical=None, 39 | suffix=None, 40 | prng_key=None, 41 | detach=None, 42 | stick_the_landing=None, 43 | ): 44 | """Register backend.""" 45 | fn_map = { 46 | "traced_evaluate": traced_evaluate, 47 | "empirical": empirical, 48 | "suffix": suffix, 49 | "prng_key": prng_key, 50 | "detach": detach, 51 | "stick_the_landing": stick_the_landing, 52 | } 53 | _BACKENDS[backend] = fn_map 54 | 55 | 56 | # pylint:enable=redefined-outer-name 57 | 58 | 59 | def set_backend(backend): 60 | """Set backend.""" 61 | global _COIX_BACKEND 62 | 63 | if backend not in _BACKENDS: 64 | module = importlib.import_module(backend) 65 | fn_map = {} 66 | for fn in [ 67 | "traced_evaluate", 68 | "empirical", 69 | "suffix", 70 | "prng_key", 71 | "detach", 72 | "stick_the_landing", 73 | ]: 74 | fn_map[fn] = getattr(module, fn, None) 75 | register_backend(backend, **fn_map) 76 | 77 | _COIX_BACKEND = backend 78 | 79 | 80 | def get_backend_name(): 81 | return _COIX_BACKEND 82 | 83 | 84 | def get_backend(): 85 | backend = _COIX_BACKEND 86 | if backend is None: 87 | set_backend("coix.numpyro") 88 | return _BACKENDS["coix.numpyro"] 89 | else: 90 | return _BACKENDS[backend] 91 | 92 | 93 | ######################################## 94 | # Program transforms 95 | ######################################## 96 | 97 | 98 | def _remove_suffix(name): 99 | i = 0 100 | while name.endswith("_PREV_"): 101 | i += len("_PREV_") 102 | name = name[: -len("_PREV_")] 103 | return name, i 104 | 105 | 106 | def desuffix(trace): 107 | """Remove unnecessary suffix terms added to the trace.""" 108 | names_to_raw_names = {} 109 | num_suffix_min = {} 110 | for name in trace: 111 | raw_name, num_suffix = _remove_suffix(name) 112 | names_to_raw_names[name] = raw_name 113 | if raw_name in num_suffix_min: 114 | num_suffix_min[raw_name] = min(num_suffix_min[raw_name], num_suffix) 115 | else: 116 | num_suffix_min[raw_name] = num_suffix 117 | new_trace = {} 118 | for name in trace: 119 | raw_name = names_to_raw_names[name] 120 | if raw_name != name and isinstance(trace[name], dict): 121 | trace[name]["suffix"] = True 122 | new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name] 123 | return new_trace 124 | 125 | 126 | def traced_evaluate(p, latents=None, seed=None, **kwargs): 127 | """Performs traced evaluation for a program `p`.""" 128 | # Work around some backends not having `seed` keyword. 
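  # Only pass `seed` through when it is provided: backends whose
  # `traced_evaluate` has no `seed` parameter (e.g. the oryx backend)
  # would otherwise fail on an unexpected keyword argument.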
129 | kwargs = kwargs.copy() 130 | if seed is not None: 131 | kwargs["seed"] = seed 132 | fn = get_backend()["traced_evaluate"](p, latents=latents, **kwargs) 133 | 134 | def wrapped(*args, **kwargs): 135 | out, trace, metrics = fn(*args, **kwargs) 136 | return out, desuffix(trace), metrics 137 | 138 | return wrapped 139 | 140 | 141 | def empirical(out, trace, metrics): 142 | """Creates an empirical program given a trace.""" 143 | return get_backend()["empirical"](out, trace, metrics) 144 | 145 | 146 | def suffix(p): 147 | """Adds suffix `_PREV_` to variable names of `p`.""" 148 | fn = get_backend()["suffix"] 149 | if fn is not None: 150 | return fn(p) 151 | else: 152 | return p 153 | 154 | 155 | def detach(p): 156 | """Makes random variables in `p` become non-reparameterized.""" 157 | fn = get_backend()["detach"] 158 | if fn is not None: 159 | return fn(p) 160 | else: 161 | return p 162 | 163 | 164 | def stick_the_landing(p): 165 | """Stops gradient of distributions' parameters before computing log prob.""" 166 | fn = get_backend()["stick_the_landing"] 167 | if fn is not None: 168 | return fn(p) 169 | else: 170 | return p 171 | 172 | 173 | def prng_key(): 174 | """Generates a random JAX PRNGKey.""" 175 | fn = get_backend()["prng_key"] 176 | if fn is not None: 177 | return fn() 178 | else: 179 | return None 180 | -------------------------------------------------------------------------------- /coix/core_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for core.py.""" 16 | 17 | import coix.core 18 | 19 | 20 | def test_desuffix(): 21 | trace = { 22 | "z_PREV__PREV_": 0, 23 | "v_PREV__PREV_": 1, 24 | "z_PREV_": 2, 25 | "v_PREV_": 3, 26 | "v": 4, 27 | } 28 | desuffix_trace = { 29 | "z_PREV_": 0, 30 | "v_PREV__PREV_": 1, 31 | "z": 2, 32 | "v_PREV_": 3, 33 | "v": 4, 34 | } 35 | assert coix.core.desuffix(trace) == desuffix_trace 36 | -------------------------------------------------------------------------------- /coix/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
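# --- Illustrative sketch (not part of the source tree) ----------------------
# Every objective below shares the signature
#   loss_fn(q_trace, p_trace, incoming_log_weight, incremental_log_weight)
# where traces map site names to dicts with a "log_prob" entry and the log
# weights carry the particle dimension along axis 0. A toy call, mirroring
# the fixtures in loss_test.py:
#
#   import numpy as np
#   import coix
#
#   q_trace = {"x": {"log_prob": np.ones(3)}}
#   p_trace = {"x": {"log_prob": np.full(3, 2.0)}}
#   incoming = np.zeros(3)
#   incremental = np.log(np.array([1 / 6, 1 / 3, 1 / 2]))
#   loss = coix.loss.elbo_loss(q_trace, p_trace, incoming, incremental)
#   # == -(softmax(incoming) * incremental).sum() == -incremental.sum() / 3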
14 | 15 | """Inference objectives.""" 16 | 17 | from coix import util 18 | import jax 19 | import jax.numpy as jnp 20 | 21 | __all__ = [ 22 | "apg_loss", 23 | "avo_loss", 24 | "elbo_loss", 25 | "fkl_loss", 26 | "iwae_loss", 27 | "rkl_loss", 28 | "rws_loss", 29 | ] 30 | 31 | 32 | def apg_loss( 33 | q_trace, 34 | p_trace, 35 | incoming_log_weight, 36 | incremental_log_weight, 37 | aggregate=True, 38 | ): 39 | """RWS objective that exploits conditional dependency.""" 40 | del incoming_log_weight, incremental_log_weight 41 | p_log_probs = { 42 | name: util.get_site_log_prob(site) for name, site in p_trace.items() 43 | } 44 | q_log_probs = { 45 | name: util.get_site_log_prob(site) for name, site in q_trace.items() 46 | } 47 | forward_sites = [name[:-6] for name in p_trace if name.endswith("_PREV_")] 48 | observed = [ 49 | name for name, site in p_trace.items() if util.is_observed_site(site) 50 | ] 51 | 52 | log_probs = [ 53 | lp 54 | for name, lp in p_log_probs.items() 55 | if (name in forward_sites) or name in observed 56 | ] 57 | min_ndim = min(jnp.ndim(lp) for lp in log_probs) 58 | batch_shape = () 59 | for i in range(min_ndim): 60 | dims = set(jnp.shape(lp)[i] for lp in log_probs) 61 | if len(dims) > 1: 62 | break 63 | batch_shape = batch_shape + tuple(dims) 64 | batch_ndims = len(batch_shape) 65 | 66 | global_sites = [ 67 | name 68 | for name, lp in p_log_probs.items() 69 | if (not name.endswith("_PREV_")) 70 | and (lp.shape[:batch_ndims] != batch_shape) 71 | ] 72 | target_lp = sum( 73 | lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) 74 | for name, lp in p_log_probs.items() 75 | if not (name.endswith("_PREV_") or name in global_sites) 76 | ) 77 | reverse_lp = sum( 78 | lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) 79 | for name, lp in p_log_probs.items() 80 | if name.endswith("_") 81 | ) 82 | forward_lp = sum( 83 | lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) 84 | for name, lp in q_log_probs.items() 85 | if name in forward_sites 86 | ) 87 | proposal_lp = sum( 88 | lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) 89 | for name, lp in q_log_probs.items() 90 | if (name not in forward_sites) and name not in global_sites 91 | ) 92 | 93 | surrogate_loss = target_lp + forward_lp 94 | log_weight = target_lp + reverse_lp - (forward_lp + proposal_lp) 95 | w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0)) 96 | loss = -(w * surrogate_loss) 97 | if aggregate: 98 | loss = loss.sum() 99 | return loss 100 | 101 | 102 | def avo_loss( 103 | q_trace, 104 | p_trace, 105 | incoming_log_weight, 106 | incremental_log_weight, 107 | aggregate=True, 108 | ): 109 | """Annealed Variational Objective.""" 110 | del q_trace, p_trace 111 | surrogate_loss = incremental_log_weight 112 | if jnp.ndim(incoming_log_weight) > 0: 113 | w1 = 1.0 / incoming_log_weight.shape[0] 114 | else: 115 | w1 = 1.0 116 | loss = -(w1 * surrogate_loss) 117 | if aggregate: 118 | loss = loss.sum() 119 | return loss 120 | 121 | 122 | def elbo_loss( 123 | q_trace, 124 | p_trace, 125 | incoming_log_weight, 126 | incremental_log_weight, 127 | aggregate=True, 128 | ): 129 | """Evidence Lower Bound objective.""" 130 | del q_trace, p_trace 131 | surrogate_loss = incremental_log_weight 132 | if jnp.ndim(incoming_log_weight) > 0: 133 | w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0)) 134 | else: 135 | w1 = 1.0 136 | loss = -(w1 * surrogate_loss) 137 | if aggregate: 138 | loss = loss.sum() 139 | return loss 140 | 141 | 142 | def fkl_loss( 143 | q_trace, 144 | p_trace, 145 | incoming_log_weight, 146 | 
incremental_log_weight, 147 | aggregate=True, 148 | ): 149 | """Forward KL objective. Here we do not optimize p.""" 150 | del p_trace 151 | batch_ndims = incoming_log_weight.ndim 152 | q_log_probs = { 153 | name: util.get_site_log_prob(site) for name, site in q_trace.items() 154 | } 155 | proposal_sites = [ 156 | name 157 | for name, site in q_trace.items() 158 | if name.endswith("_PREV_") 159 | or (isinstance(site, dict) and "suffix" in site) 160 | ] 161 | 162 | proposal_lp = sum( 163 | lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) 164 | for name, lp in q_log_probs.items() 165 | if name in proposal_sites 166 | ) 167 | forward_lp = sum( 168 | lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) 169 | for name, lp in q_log_probs.items() 170 | if name not in proposal_sites 171 | ) 172 | 173 | surrogate_loss = forward_lp + proposal_lp 174 | w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0)) 175 | log_weight = incoming_log_weight + incremental_log_weight 176 | w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0)) 177 | loss = -(w * surrogate_loss - w1 * proposal_lp) 178 | if aggregate: 179 | loss = loss.sum() 180 | return loss 181 | 182 | 183 | def iwae_loss( 184 | q_trace, 185 | p_trace, 186 | incoming_log_weight, 187 | incremental_log_weight, 188 | aggregate=True, 189 | ): 190 | """Importance Weighted Autoencoder objective.""" 191 | del q_trace, p_trace 192 | log_weight = incoming_log_weight + incremental_log_weight 193 | surrogate_loss = log_weight 194 | if jnp.ndim(incoming_log_weight) > 0: 195 | w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0)) 196 | else: 197 | w = 1.0 198 | loss = -(w * surrogate_loss) 199 | if aggregate: 200 | loss = loss.sum() 201 | return loss 202 | 203 | 204 | def rkl_loss( 205 | q_trace, 206 | p_trace, 207 | incoming_log_weight, 208 | incremental_log_weight, 209 | aggregate=True, 210 | ): 211 | """Reverse KL objective.""" 212 | batch_ndims = incoming_log_weight.ndim 213 | p_log_probs = { 214 | name: util.get_site_log_prob(site) for name, site in p_trace.items() 215 | } 216 | q_log_probs = { 217 | name: util.get_site_log_prob(site) for name, site in q_trace.items() 218 | } 219 | proposal_sites = [ 220 | name 221 | for name, site in q_trace.items() 222 | if name.endswith("_PREV_") 223 | or (isinstance(site, dict) and "suffix" in site) 224 | ] 225 | 226 | proposal_lp = sum( 227 | lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) 228 | for name, lp in q_log_probs.items() 229 | if name in proposal_sites 230 | ) 231 | target_lp = sum( 232 | lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) 233 | for name, lp in p_log_probs.items() 234 | if not name.endswith("_PREV_") 235 | ) 236 | 237 | w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0)) 238 | v = jax.lax.stop_gradient(incremental_log_weight) 239 | surrogate_loss = ( 240 | incremental_log_weight + (1 + v - (w1 * v).sum(0)) * proposal_lp 241 | ) 242 | log_weight = incoming_log_weight + incremental_log_weight 243 | w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0)) 244 | loss = -(w1 * surrogate_loss - w * target_lp) 245 | if aggregate: 246 | loss = loss.sum() 247 | return loss 248 | 249 | 250 | def rws_loss( 251 | q_trace, 252 | p_trace, 253 | incoming_log_weight, 254 | incremental_log_weight, 255 | aggregate=True, 256 | ): 257 | """Reweighted Wake-Sleep objective.""" 258 | batch_ndims = incoming_log_weight.ndim 259 | p_log_probs = { 260 | name: util.get_site_log_prob(site) for name, site in p_trace.items() 261 | } 262 | q_log_probs = { 263 | 
name: util.get_site_log_prob(site) for name, site in q_trace.items() 264 | } 265 | proposal_sites = [ 266 | name 267 | for name, site in q_trace.items() 268 | if name.endswith("_PREV_") 269 | or (isinstance(site, dict) and "suffix" in site) 270 | ] 271 | 272 | proposal_lp = sum( 273 | lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) 274 | for name, lp in q_log_probs.items() 275 | if name in proposal_sites 276 | ) 277 | forward_lp = sum( 278 | lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) 279 | for name, lp in q_log_probs.items() 280 | if name not in proposal_sites 281 | ) 282 | target_lp = sum( 283 | lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) 284 | for name, lp in p_log_probs.items() 285 | if not name.endswith("_PREV_") 286 | ) 287 | 288 | surrogate_loss = (target_lp - proposal_lp) + forward_lp 289 | log_weight = incoming_log_weight + incremental_log_weight 290 | w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0)) 291 | loss = -(w * surrogate_loss) 292 | if aggregate: 293 | loss = loss.sum() 294 | return loss 295 | -------------------------------------------------------------------------------- /coix/loss_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for loss.py.""" 16 | 17 | import coix 18 | import jax.numpy as jnp 19 | import numpy as np 20 | 21 | p_trace = { 22 | "x": {"log_prob": np.full((3, 2), 2.0)}, 23 | "y": {"log_prob": np.array([3.0, 0.0, -2.0])}, 24 | "x_PREV_": {"log_prob": np.ones((3, 2))}, 25 | } 26 | q_trace = { 27 | "x": {"log_prob": np.ones((3, 2))}, 28 | "y": {"log_prob": np.array([1.0, 1.0, 0.0]), "suffix": True}, 29 | "x_PREV_": {"log_prob": np.full((3, 2), 3.0)}, 30 | } 31 | incoming_weight = np.zeros(3) 32 | incremental_weight = np.log(np.array([1 / 6, 1 / 3, 1 / 2])) 33 | 34 | 35 | def test_apg(): 36 | result = coix.loss.apg_loss( 37 | q_trace, p_trace, incoming_weight, incremental_weight 38 | ) 39 | np.testing.assert_allclose(result, -6.0) 40 | 41 | 42 | def test_elbo(): 43 | result = coix.loss.elbo_loss( 44 | q_trace, p_trace, incoming_weight, incremental_weight 45 | ) 46 | expected = -incremental_weight.sum() / 3 47 | np.testing.assert_allclose(result, expected) 48 | 49 | 50 | def test_iwae(): 51 | result = coix.loss.iwae_loss( 52 | q_trace, p_trace, incoming_weight, incremental_weight 53 | ) 54 | w = incoming_weight + incremental_weight 55 | expected = -(jnp.exp(w) * w).sum() 56 | np.testing.assert_allclose(result, expected, rtol=1e-6) 57 | 58 | 59 | def test_rws(): 60 | result = coix.loss.rws_loss( 61 | q_trace, p_trace, incoming_weight, incremental_weight 62 | ) 63 | np.testing.assert_allclose(result, 1.0, rtol=1e-6) 64 | -------------------------------------------------------------------------------- /coix/numpyro.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Backend implementation for NumPyro.""" 16 | 17 | from coix.util import get_batch_ndims 18 | from coix.util import get_log_weight 19 | from coix.util import get_site_log_prob 20 | import jax 21 | import jax.numpy as jnp 22 | import numpyro 23 | from numpyro import handlers 24 | import numpyro.distributions as dist 25 | 26 | __all__ = [ 27 | "detach", 28 | "empirical", 29 | "prng_key", 30 | "stick_the_landing", 31 | "suffix", 32 | "traced_evaluate", 33 | ] 34 | 35 | prng_key = numpyro.prng_key 36 | 37 | 38 | def traced_evaluate(p, latents=None, seed=None): 39 | """Performs traced evaluation for a program `p`.""" 40 | 41 | def wrapped(*args, **kwargs): 42 | data = {} if latents is None else latents 43 | rng_seed = numpyro.prng_key() if seed is None else seed 44 | subs_model = handlers.seed( 45 | handlers.substitute(p, data=data), rng_seed=rng_seed 46 | ) 47 | with handlers.block(), handlers.trace() as tr: 48 | out = subs_model(*args, **kwargs) 49 | trace = {} 50 | for name, site in tr.items(): 51 | if site["type"] == "sample": 52 | value = site["value"] 53 | log_prob = site["fn"].log_prob(value) 54 | event_dim_holder = jnp.empty([1] * site["fn"].event_dim) 55 | trace[name] = { 56 | "value": value, 57 | "log_prob": log_prob, 58 | "_event_dim_holder": event_dim_holder, 59 | } 60 | if site.get("is_observed", False): 61 | trace[name]["is_observed"] = True 62 | metrics = { 63 | name: site["value"] 64 | for name, site in tr.items() 65 | if site["type"] == "metric" 66 | } 67 | # add log_weight to metrics 68 | if "log_weight" not in metrics: 69 | log_probs = [get_site_log_prob(site) for site in trace.values()] 70 | weight = get_log_weight(trace, get_batch_ndims(log_probs)) 71 | metrics = {**metrics, "log_weight": weight} 72 | return out, trace, metrics 73 | 74 | return wrapped 75 | 76 | 77 | def add_metric(name, value): 78 | """A NumPyro primitive to add `metric` type to a program.""" 79 | if numpyro.primitives._PYRO_STACK: # pylint:disable=protected-access 80 | msg = {"type": "metric", "value": value, "name": name} 81 | numpyro.primitives.apply_stack(msg) 82 | 83 | 84 | def empirical(out, trace, metrics): 85 | """A program that produces `out`, `trace`, and `metrics` under evaluation.""" 86 | 87 | def wrapped(*args, **kwargs): 88 | del args, kwargs 89 | for name, site in trace.items(): 90 | value, lp = site["value"], site["log_prob"] 91 | event_dim = jnp.ndim(site["_event_dim_holder"]) 92 | obs = value if "is_observed" in site else None 93 | numpyro.sample(name, dist.Delta(value, lp, event_dim=event_dim), obs=obs) 94 | for name, value in metrics.items(): 95 | add_metric(name, value) 96 | return out 97 | 98 | return wrapped 99 | 100 | 101 | class suffix(numpyro.primitives.Messenger): # pylint:disable=invalid-name 102 | 103 | def process_message(self, msg): 104 | if msg["type"] == "sample": 105 | msg["name"] = msg["name"] + "_PREV_" 106 | 107 | 108 | class StopGradient(dist.Distribution): 109 | 
"""Nonreparameterized or stick-the-landing distribution.""" 110 | 111 | def __init__(self, base_dist, detach_sample=False, detach_args=False): 112 | self.base_dist = base_dist 113 | self.detach_sample, self.detach_args = detach_sample, detach_args 114 | super().__init__(base_dist.batch_shape, base_dist.event_shape) 115 | 116 | def sample(self, key, sample_shape=()): 117 | samples = self.base_dist.sample(key, sample_shape=sample_shape) 118 | return jax.lax.stop_gradient(samples) if self.detach_sample else samples 119 | 120 | def log_prob(self, value): 121 | d = ( 122 | jax.lax.stop_gradient(self.base_dist) 123 | if self.detach_args 124 | else self.base_dist 125 | ) 126 | return d.log_prob(value) 127 | 128 | def tree_flatten(self): 129 | params, treedef = jax.tree.flatten(self.base_dist) 130 | return params, (treedef, self.detach_sample, self.detach_args) 131 | 132 | @classmethod 133 | def tree_unflatten(cls, aux_data, params): 134 | treedef, detach_sample, detach_args = aux_data 135 | base_dist = jax.tree.unflatten(treedef, params) 136 | return cls(base_dist, detach_sample=detach_sample, detach_args=detach_args) 137 | 138 | 139 | class detach(numpyro.primitives.Messenger): # pylint:disable=invalid-name 140 | 141 | def process_message(self, msg): 142 | if msg["type"] == "sample" and not msg.get("is_observed", False): 143 | msg["fn"] = StopGradient(msg["fn"], detach_sample=True) 144 | 145 | 146 | class stick_the_landing(numpyro.primitives.Messenger): # pylint:disable=invalid-name 147 | 148 | def process_message(self, msg): 149 | if msg["type"] == "sample" and not msg.get("is_observed", False): 150 | msg["fn"] = StopGradient(msg["fn"], detach_args=True) 151 | -------------------------------------------------------------------------------- /coix/oryx.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Program primitives and transforms.""" 16 | 17 | import functools 18 | import inspect 19 | import itertools 20 | from typing import Any, Callable, Iterable, TypeVar 21 | 22 | from coix.util import get_batch_ndims 23 | from coix.util import get_log_weight 24 | from coix.util import get_site_log_prob 25 | import jax 26 | import jax.numpy as jnp 27 | 28 | from oryx.core import ppl 29 | from oryx.core import primitive 30 | from oryx.core import trace_util 31 | from oryx.core.interpreters import harvest 32 | from oryx.core.ppl import effect_handler 33 | from oryx.distributions import distribution_extensions 34 | 35 | random_variable_p = distribution_extensions.random_variable_p 36 | 37 | __all__ = [ 38 | "detach", 39 | "empirical", 40 | "factor", 41 | "prng_key", 42 | "rv", 43 | "stick_the_landing", 44 | "suffix", 45 | "traced_evaluate", 46 | ] 47 | 48 | DISTRIBUTION = "distribution" 49 | METRIC = "metric" 50 | OBSERVED = "observed" 51 | RANDOM_VARIABLE = ppl.RANDOM_VARIABLE 52 | 53 | ALL_TAGS = (RANDOM_VARIABLE, OBSERVED, DISTRIBUTION, METRIC) 54 | 55 | T = TypeVar("T") 56 | 57 | 58 | def safe_map(f: Callable[..., T], *args: Iterable[Any]) -> list[T]: 59 | """Like map(), but checks argument lengths and returns a list.""" 60 | return [f(*a) for a in zip(*args, strict=True)] 61 | 62 | 63 | ######################################## 64 | # Override Oryx behaviors 65 | ######################################## 66 | 67 | 68 | # Patch Oryx behavior to handle custom_jvp properly. 69 | def _process_custom_jvp_call( 70 | self, trace, prim, fun, jvp, tracers, *, symbolic_zeros 71 | ): 72 | """Patch harvest.ReapContext.process_custom_jvp_call.""" 73 | del self 74 | vals_in = [t.val for t in tracers] 75 | out_flat = prim.bind(fun, jvp, *vals_in, symbolic_zeros=symbolic_zeros) 76 | out_tracer = safe_map(trace.pure, out_flat) 77 | return out_tracer 78 | 79 | 80 | harvest.ReapContext.process_custom_jvp_call = _process_custom_jvp_call 81 | 82 | 83 | def _eval_jaxpr_with_state(jaxpr, rules, consts, state, *args): 84 | """Patch effect_handler.eval_jaxpr_with_state.""" 85 | env = effect_handler.Environment() 86 | 87 | safe_map(env.write, jaxpr.constvars, consts) 88 | safe_map(env.write, jaxpr.invars, args) 89 | 90 | for eqn in jaxpr.eqns: 91 | invals = safe_map(env.read, eqn.invars) 92 | call_jaxpr, params = trace_util.extract_call_jaxpr( 93 | eqn.primitive, eqn.params 94 | ) 95 | if eqn.primitive.name == "custom_jvp_call": 96 | subfuns, bind_params = eqn.primitive.get_bind_params(params) 97 | ans = eqn.primitive.bind(*subfuns, *invals, **bind_params) 98 | elif call_jaxpr: 99 | call_rule = effect_handler._effect_handler_call_rules.get( # pylint: disable=protected-access 100 | eqn.primitive, 101 | functools.partial( 102 | effect_handler.default_call_interpreter_rule, eqn.primitive 103 | ), 104 | ) 105 | ans, state = call_rule(rules, state, invals, call_jaxpr, **params) 106 | elif eqn.primitive in rules: 107 | ans, state = rules[eqn.primitive](state, *invals, **params) 108 | else: 109 | ans = eqn.primitive.bind(*invals, **params) 110 | if eqn.primitive.multiple_results: 111 | safe_map(env.write, eqn.outvars, ans) 112 | else: 113 | env.write(eqn.outvars[0], ans) 114 | return safe_map(env.read, jaxpr.outvars), state 115 | 116 | 117 | effect_handler.eval_jaxpr_with_state = _eval_jaxpr_with_state 118 | 119 | 120 | def identity(value, dist): 121 | del dist 122 | return value 123 | 124 | 125 | def _dist_sample(key, dist): 126 | if "seed" in inspect.getfullargspec(dist.sample).args: 127 | return 
dist.sample(seed=key) 128 | else: 129 | return dist.sample(key) 130 | 131 | 132 | def rv(dist, *, obs=None, name=None): 133 | """Declares a random variable.""" 134 | 135 | # This behaves like oryx.core.ppl.rv but allows observed declaration 136 | # and batched distribution. 137 | 138 | def sample(key): 139 | if obs is None: 140 | sample_fn = _dist_sample 141 | sample_args = (key, dist) 142 | else: 143 | sample_fn = identity 144 | sample_args = (obs, dist) 145 | 146 | result = primitive.initial_style_bind( 147 | random_variable_p, 148 | batch_ndims=0, 149 | distribution_name=dist.__class__.__name__, 150 | name=name, 151 | mode="strict", 152 | )(sample_fn)(*sample_args) 153 | return harvest.sow(result, tag=RANDOM_VARIABLE, name=name) 154 | 155 | if obs is None: 156 | return sample 157 | else: 158 | return harvest.sow(sample(0), tag=OBSERVED, name=name) 159 | 160 | 161 | @jax.tree_util.register_pytree_node_class 162 | class Delta: 163 | """Dirac Delta distribution.""" 164 | 165 | def __init__(self, value, log_density): 166 | self.value = value 167 | self.log_density = log_density 168 | 169 | def sample(self, key): 170 | del key 171 | return self.value 172 | 173 | def log_prob(self, value): 174 | del value 175 | return self.log_density 176 | 177 | def tree_flatten(self): 178 | return ((self.value, self.log_density), None) 179 | 180 | @classmethod 181 | def tree_unflatten(cls, aux_data, children): 182 | del aux_data 183 | return cls(*children) 184 | 185 | 186 | # We follow Pyro's approach of using Unit distributions for factors. 187 | @jax.tree_util.register_pytree_node_class 188 | class Unit: 189 | """Unit Factor distribution.""" 190 | 191 | def __init__(self, log_factor): 192 | self.log_factor = log_factor 193 | 194 | def sample(self, key): 195 | del key 196 | return jnp.empty((0,)) 197 | 198 | def log_prob(self, value): 199 | del value 200 | return self.log_factor 201 | 202 | def tree_flatten(self): 203 | return ((self.log_factor,), None) 204 | 205 | @classmethod 206 | def tree_unflatten(cls, aux_data, children): 207 | del aux_data 208 | return cls(*children) 209 | 210 | 211 | def factor(log_factor, *, name=None): 212 | """Declares a factor to be added to a program.""" 213 | return rv(Unit(log_factor), name=name)(0) 214 | 215 | 216 | ######################################## 217 | # Effect Handlers 218 | ######################################## 219 | 220 | 221 | def _split_list(args, num_consts): 222 | return args[num_consts:] 223 | 224 | 225 | def substitute_rule(state, *args, **kwargs): 226 | """Rule for substitute handler.""" 227 | name = kwargs.get("name") 228 | if name in state: 229 | flat_args = _split_list(args, kwargs["num_consts"]) 230 | _, dist = jax.tree.unflatten(kwargs["in_tree"], flat_args) 231 | value = state[name] 232 | value = primitive.tie_in(flat_args, value) 233 | jaxpr, _ = trace_util.stage(identity, dynamic=True)(value, dist) 234 | kwargs["jaxpr"] = jaxpr.jaxpr 235 | kwargs["num_consts"] = len(jaxpr.literals) 236 | args = itertools.chain(jaxpr.literals, (value,), flat_args[1:]) 237 | return random_variable_p.bind(*args, **kwargs), state 238 | 239 | 240 | substitute_handler = ppl.make_effect_handler( 241 | {random_variable_p: substitute_rule} 242 | ) 243 | 244 | 245 | def substitute(f, latents): 246 | """Runs `f` with latent values substituted from `latents`.""" 247 | 248 | def wrapped(*args, **kwargs): 249 | return substitute_handler(f)(latents, *args, **kwargs)[0] 250 | 251 | return wrapped 252 | 253 | 254 | def distribution_rule(state, *args, **kwargs): 255 | """Rule for 
distribution handler.""" 256 | name = kwargs.get("name") 257 | if name is not None: 258 | flat_args = _split_list(args, kwargs["num_consts"]) 259 | _, dist = jax.tree.unflatten(kwargs["in_tree"], flat_args) 260 | dist_flat, dist_tree = jax.tree.flatten(dist) 261 | state[name] = {dist_tree: dist_flat} 262 | return random_variable_p.bind(*args, **kwargs), state 263 | 264 | 265 | distribution_handler = ppl.make_effect_handler( 266 | {random_variable_p: distribution_rule} 267 | ) 268 | 269 | 270 | def tag_distribution(f): 271 | """Executes `f` with its distributions tagged.""" 272 | 273 | def wrapped(*args, **kwargs): 274 | out, fns = distribution_handler(f)({}, *args, **kwargs) 275 | for name, fn in fns.items(): 276 | harvest.sow(fn, tag=DISTRIBUTION, name=name) 277 | return out 278 | 279 | return wrapped 280 | 281 | 282 | def suffix_rule(state, *args, **kwargs): 283 | """Suffix rule for `sow_p` primitive.""" 284 | if kwargs["tag"] in [OBSERVED, RANDOM_VARIABLE]: 285 | if kwargs["name"]: 286 | kwargs["name"] = kwargs["name"] + "_PREV_" 287 | return harvest.sow_p.bind(*args, **kwargs), state 288 | 289 | 290 | def suffix_rv_rule(state, *args, **kwargs): 291 | """Suffix rule for `random_variable_p` primitive.""" 292 | if kwargs.get("name"): 293 | kwargs["name"] = kwargs["name"] + "_PREV_" 294 | return random_variable_p.bind(*args, **kwargs), state 295 | 296 | 297 | suffix_handler = ppl.make_effect_handler( 298 | {harvest.sow_p: suffix_rule, random_variable_p: suffix_rv_rule} 299 | ) 300 | 301 | 302 | def suffix(f): 303 | """Adds the suffix `_PREV_` to the names of random variables in `f`.""" 304 | 305 | def wrapped(*args, **kwargs): 306 | return suffix_handler(f)(None, *args, **kwargs)[0] 307 | 308 | return wrapped 309 | 310 | 311 | def detach_rule(state, *args, **kwargs): 312 | """Rule for detach handler.""" 313 | consts = args[: kwargs["num_consts"]] 314 | run_args = args[kwargs["num_consts"] :] 315 | 316 | def _run(*args): 317 | return jax.lax.stop_gradient( 318 | random_variable_p.bind(*itertools.chain(consts, args), **kwargs) 319 | ) 320 | 321 | detach_jaxpr, _ = trace_util.stage(_run, dynamic=True)(*run_args) 322 | kwargs["jaxpr"] = detach_jaxpr.jaxpr 323 | return random_variable_p.bind(*args, **kwargs), state 324 | 325 | 326 | detach_handler = ppl.make_effect_handler({random_variable_p: detach_rule}) 327 | 328 | 329 | def detach(f): 330 | """Makes random variables in `f` non-reparameterized.""" 331 | 332 | def wrapped(*args, **kwargs): 333 | return detach_handler(f)(None, *args, **kwargs)[0] 334 | 335 | return wrapped 336 | 337 | 338 | @jax.tree_util.register_pytree_node_class 339 | class STLDistribution: 340 | """Sticking-the-landing log density.""" 341 | 342 | def __init__(self, base_dist): 343 | self.base_dist = base_dist 344 | 345 | def sample(self, key): 346 | return _dist_sample(key, self.base_dist) 347 | 348 | def log_prob(self, value): 349 | return jax.lax.stop_gradient(self.base_dist).log_prob(value) 350 | 351 | def tree_flatten(self): 352 | params, treedef = jax.tree.flatten(self.base_dist) 353 | return (params, treedef) 354 | 355 | @classmethod 356 | def tree_unflatten(cls, aux_data, children): 357 | base_dist = jax.tree.unflatten(aux_data, children) 358 | return cls(base_dist) 359 | 360 | 361 | def stl_rule(state, *args, **kwargs): 362 | flat_args = _split_list(args, kwargs["num_consts"]) 363 | key, dist = jax.tree.unflatten(kwargs["in_tree"], flat_args) 364 | stl_dist = STLDistribution(dist) 365 | _, in_tree = jax.tree.flatten((key, stl_dist)) 366 | kwargs["in_tree"] = in_tree 367 | out = random_variable_p.bind(*args, **kwargs)
368 | return out, state 369 | 370 | 371 | stl_handler = ppl.make_effect_handler({random_variable_p: stl_rule}) 372 | 373 | 374 | def stick_the_landing(f): 375 | def wrapped(*args, **kwargs): 376 | return stl_handler(f)(None, *args, **kwargs)[0] 377 | 378 | return wrapped 379 | 380 | 381 | ######################################## 382 | # Reap helpers 383 | ######################################## 384 | 385 | 386 | def call_and_reap_tags(f, tags): 387 | """A helper to collect values from a sequence of tags.""" 388 | tags = [tags] if isinstance(tags, str) else tags 389 | 390 | def wrapped(*args, **kwargs): 391 | f_with_tag = f 392 | for tag in tags: 393 | f_with_tag = harvest.call_and_reap(f_with_tag, tag=tag) 394 | out_with_tag = f_with_tag(*args, **kwargs) 395 | tags_dict = {} 396 | for tag in tags[::-1]: 397 | out_with_tag, values = out_with_tag 398 | tags_dict[tag] = values 399 | return out_with_tag, tags_dict 400 | 401 | return wrapped 402 | 403 | 404 | def traced_evaluate(p, latents=None): 405 | """Performs traced evaluation. 406 | 407 | Args: 408 | p: a program 409 | latents: optional values to be substituted into `p` 410 | 411 | Returns: 412 | (out, p_trace, p_metrics): a tuple of marginal output, trace, and metrics 413 | """ 414 | 415 | def wrapped(*args, **kwargs): 416 | p_subs = substitute(p, latents=latents) if latents is not None else p 417 | p_tagged = tag_distribution(p_subs) 418 | out, tags = call_and_reap_tags(p_tagged, ALL_TAGS)(*args, **kwargs) 419 | trace = {} 420 | for name, value in tags[RANDOM_VARIABLE].items(): 421 | dist_tree, dist_flat = list(tags[DISTRIBUTION][name].items())[0] 422 | dist = jax.tree.unflatten(dist_tree, dist_flat) 423 | trace[name] = {"value": value, "log_prob": dist.log_prob(value)} 424 | if name in tags[OBSERVED]: 425 | trace[name]["is_observed"] = True 426 | metrics = tags[METRIC] 427 | if "loss" not in metrics: 428 | metrics["loss"] = jnp.array(0.0) 429 | if "log_density" not in metrics: 430 | log_density = sum(jnp.sum(site["log_prob"]) for site in trace.values()) 431 | metrics["log_density"] = jnp.array(0.0) + log_density 432 | if "log_weight" not in metrics: 433 | log_probs = [get_site_log_prob(site) for site in trace.values()] 434 | weight = get_log_weight(trace, get_batch_ndims(log_probs)) 435 | metrics = {**metrics, "log_weight": weight} 436 | return out, trace, metrics 437 | 438 | return wrapped 439 | 440 | 441 | def sow_metric(value, name): 442 | return harvest.sow(value, tag=METRIC, name=name) 443 | 444 | 445 | def empirical(out, trace, metrics): 446 | """Creates a deterministic program with Delta variables.""" 447 | 448 | def wrapped(*args, **kwargs): 449 | tie_trace, tie_metrics = primitive.tie_in((args, kwargs), (trace, metrics)) 450 | for name, site in tie_trace.items(): 451 | value, lp = site["value"], site["log_prob"] 452 | if "is_observed" in site: 453 | rv(Delta(value, lp), obs=value, name=name) 454 | else: 455 | rv(Delta(value, lp), name=name)(0) 456 | for name, value in tie_metrics.items(): 457 | sow_metric(value, name) 458 | return out 459 | 460 | return wrapped 461 | 462 | 463 | def prng_key(): 464 | raise ValueError("Cannot generate a random key under the oryx backend.") 465 | -------------------------------------------------------------------------------- /coix/oryx_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for oryx.py.""" 16 | 17 | import coix 18 | import coix.core 19 | 20 | # pylint: disable=g-import-not-at-top 21 | try: 22 | import coix.oryx as coryx 23 | except (ModuleNotFoundError, ImportError): 24 | coryx = None 25 | import jax 26 | from jax import random 27 | import jax.numpy as jnp 28 | import numpy as np 29 | import numpyro.distributions as dist 30 | import pytest 31 | 32 | pytest.skip("oryx backend is broken", allow_module_level=True) 33 | 34 | 35 | def test_call_and_reap_tags(): 36 | coix.set_backend("coix.oryx") 37 | 38 | def model(key): 39 | return coryx.rv(dist.Normal(0, 1), name="x")(key) 40 | 41 | _, trace, _ = coix.traced_evaluate(model)(random.PRNGKey(0)) 42 | assert set(trace.keys()) == {"x"} 43 | assert set(trace["x"].keys()) == {"value", "log_prob"} 44 | 45 | 46 | def test_delta_distribution(): 47 | coix.set_backend("coix.oryx") 48 | 49 | def model(key): 50 | x = random.normal(key) 51 | return coryx.rv(dist.Delta(x, 5.0), name="x")(key) 52 | 53 | _, trace, _ = coix.traced_evaluate(model)(random.PRNGKey(0)) 54 | assert set(trace.keys()) == {"x"} 55 | 56 | 57 | def test_detach(): 58 | coix.set_backend("coix.oryx") 59 | 60 | def model(x): 61 | return coryx.rv(dist.Delta(x, 0.0), name="x")(None) * x 62 | 63 | x = 2.0 64 | np.testing.assert_allclose(jax.grad(coix.detach(model))(x), x) 65 | 66 | 67 | def test_detach_vmap(): 68 | coix.set_backend("coix.oryx") 69 | 70 | def model(x): 71 | return coryx.rv(dist.Normal(x, 1.0), name="x")(random.PRNGKey(0)) 72 | 73 | outs = coix.detach(jax.vmap(model))(jnp.ones(2)) 74 | np.testing.assert_allclose(outs[0], outs[1]) 75 | 76 | 77 | def test_distribution(): 78 | coix.set_backend("coix.oryx") 79 | 80 | def model(key): 81 | x = random.normal(key) 82 | return coryx.rv(dist.Delta(x, 5.0), name="x")(key) 83 | 84 | f = coix.oryx.call_and_reap_tags( 85 | coix.oryx.tag_distribution(model), coix.oryx.DISTRIBUTION 86 | ) 87 | assert set(f(random.PRNGKey(0))[1][coix.oryx.DISTRIBUTION].keys()) == {"x"} 88 | 89 | 90 | def test_empirical_program(): 91 | coix.set_backend("coix.oryx") 92 | 93 | def model(x): 94 | trace = { 95 | "x": {"value": x, "log_prob": 11.0}, 96 | "y": {"value": x + 1, "log_prob": 9.0, "is_observed": True}, 97 | } 98 | return coix.empirical(0.0, trace, {})() 99 | 100 | _, trace, _ = coix.traced_evaluate(model)(1.0) 101 | samples = {name: site["value"] for name, site in trace.items()} 102 | jax.tree.map(np.testing.assert_allclose, samples, {"x": 1.0, "y": 2.0}) 103 | assert "is_observed" not in trace["x"] 104 | assert trace["y"]["is_observed"] 105 | 106 | 107 | def test_factor(): 108 | coix.set_backend("coix.oryx") 109 | 110 | def model(x): 111 | return coryx.factor(x, name="x") 112 | 113 | _, trace, _ = coix.traced_evaluate(model)(10.0) 114 | assert "x" in trace 115 | np.testing.assert_allclose(trace["x"]["log_prob"], 10.0) 116 | 117 | 118 | def test_log_prob_detach(): 119 | coix.set_backend("coix.oryx") 120 | 121 | def 
model(loc): 122 | x = coryx.rv(dist.Normal(loc, 1), name="x")(random.PRNGKey(0)) 123 | return x 124 | 125 | def actual_fn(x): 126 | return coix.traced_evaluate(coix.detach(model))(x)[1]["x"]["log_prob"] 127 | 128 | def expected_fn(x): 129 | return dist.Normal(x, 1).log_prob(model(1.0)) 130 | 131 | actual = jax.grad(actual_fn)(1.0) 132 | expect = jax.grad(expected_fn)(1.0) 133 | np.testing.assert_allclose(actual, expect) 134 | 135 | 136 | def test_observed(): 137 | coix.set_backend("coix.oryx") 138 | 139 | def model(a): 140 | return coryx.rv(dist.Delta(a, 3.0), obs=1.0, name="x") + a 141 | 142 | _, trace, _ = coix.traced_evaluate(model)(2.0) 143 | assert "x" in trace 144 | np.testing.assert_allclose(trace["x"]["value"], 1.0) 145 | assert trace["x"]["is_observed"] 146 | 147 | 148 | def test_stick_the_landing(): 149 | coix.set_backend("coix.oryx") 150 | 151 | def model(lp): 152 | return coryx.rv(dist.Delta(0.0, lp), name="x")(None) 153 | 154 | def p(x): 155 | return coix.traced_evaluate(coix.detach(model))(x)[1]["x"]["log_prob"] 156 | 157 | def q(x): 158 | model_stl = coix.detach(coix.stick_the_landing(model)) 159 | return coix.traced_evaluate(model_stl)(x)[1]["x"]["log_prob"] 160 | 161 | np.testing.assert_allclose(jax.grad(p)(5.0), 1.0) 162 | np.testing.assert_allclose(jax.grad(q)(5.0), 0.0) 163 | 164 | 165 | def test_substitute(): 166 | coix.set_backend("coix.oryx") 167 | 168 | def model(key): 169 | return coryx.rv(dist.Delta(1.0, 5.0), name="x")(key) 170 | 171 | expected = {"x": 9.0} 172 | _, trace, _ = coix.traced_evaluate(model, expected)(random.PRNGKey(0)) 173 | actual = {"x": trace["x"]["value"]} 174 | jax.tree.map(np.testing.assert_allclose, actual, expected) 175 | 176 | 177 | def test_suffix(): 178 | coix.set_backend("coix.oryx") 179 | 180 | def model(x): 181 | return coryx.rv(dist.Delta(x, 5.0), name="x")(None) 182 | 183 | f = coix.oryx.call_and_reap_tags( 184 | coix.core.suffix(model), coix.oryx.RANDOM_VARIABLE 185 | ) 186 | jax.tree.map( 187 | np.testing.assert_allclose, 188 | f(1.0)[1][coix.oryx.RANDOM_VARIABLE], 189 | {"x_PREV_": 1.0}, 190 | ) 191 | -------------------------------------------------------------------------------- /coix/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities.""" 16 | 17 | import functools 18 | import time 19 | 20 | import jax 21 | from jax import random 22 | import jax.numpy as jnp 23 | import numpy as np 24 | 25 | 26 | def get_systematic_resampling_indices(log_weights, rng_key, num_samples): 27 | """Gets resampling indices based on systematic resampling.""" 28 | n = log_weights.shape[0] 29 | # TODO(phandu): It might be more numerically stable if we work in log space.
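  # Systematic resampling: one shared uniform draw per batch element is used
  # to place `num_samples` evenly spaced positions (u + i) / num_samples on
  # [0, 1); indices are read off by searching these positions in the
  # normalized cumulative weights. The `2 * shift` offsets below separate the
  # per-batch CDFs so a single flat `searchsorted` call can serve all batches.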
30 | weight = jax.nn.softmax(log_weights, axis=0) 31 | cumulative_weight = weight.cumsum(axis=0) 32 | cumulative_weight = cumulative_weight / cumulative_weight[-1] 33 | cumulative_weight = cumulative_weight.reshape((n, -1)).swapaxes(0, 1) 34 | m = cumulative_weight.shape[0] 35 | if rng_key is not None: 36 | uniform = jax.random.uniform(rng_key, (m,)) 37 | else: 38 | uniform = np.random.rand(m) 39 | positions = (uniform[:, None] + np.arange(num_samples)) / num_samples 40 | shift = np.arange(m)[:, None] 41 | cumulative_weight = (cumulative_weight + 2 * shift).reshape(-1) 42 | positions = (positions + 2 * shift).reshape(-1) 43 | index = cumulative_weight.searchsorted(positions) 44 | index = (index.reshape(m, num_samples) - n * shift).swapaxes(0, 1) 45 | return index.reshape((num_samples,) + log_weights.shape[1:]) 46 | 47 | 48 | def get_site_log_prob(site): 49 | if hasattr(site, "log_density"): 50 | return site.log_density 51 | else: 52 | return site["log_prob"] 53 | 54 | 55 | def get_site_value(site, detach=False): 56 | if hasattr(site, "value"): 57 | value = site.value 58 | else: 59 | value = site["value"] 60 | if detach and isinstance(value, jnp.ndarray): 61 | return jax.lax.stop_gradient(value) 62 | else: 63 | return value 64 | 65 | 66 | def is_observed_site(site): 67 | if hasattr(site, "tag"): 68 | return site.tag == "observed" 69 | else: 70 | return "is_observed" in site 71 | 72 | 73 | def can_extract_key(args): 74 | return ( 75 | args 76 | and isinstance(args[0], jnp.ndarray) 77 | and ( 78 | jax.dtypes.issubdtype(args[0].dtype, jax.dtypes.prng_key) 79 | or ( 80 | (args[0].dtype == jnp.uint32) 81 | and (jnp.ndim(args[0]) >= 1) 82 | and (args[0].shape[-1] == 2) 83 | ) 84 | ) 85 | ) 86 | 87 | 88 | class _ChildModule: 89 | """A child of a bind module.""" 90 | 91 | def __init__(self, module, params, name): 92 | self.module = module 93 | self.params = params 94 | self.name = name 95 | 96 | def __getitem__(self, i): 97 | return functools.partial( 98 | self.module.apply, 99 | self.params, 100 | method=lambda n, *a, **kw: getattr(n, self.name)[i](*a, **kw), 101 | ) 102 | 103 | def __call__(self, *args, **kwargs): 104 | return self.module.apply( 105 | self.params, 106 | *args, 107 | method=lambda n, *a, **kw: getattr(n, self.name)(*a, **kw), 108 | **kwargs, 109 | ) 110 | 111 | 112 | class BindModule: 113 | """Like Flax's `module.bind(params)` but composable with JAX transforms.""" 114 | 115 | def __init__(self, module, params): 116 | self.module = module 117 | self.params = params 118 | for submodule in params["params"]: 119 | setattr( 120 | self, submodule, _ChildModule(self.module, self.params, submodule) 121 | ) 122 | for submodule in params["params"]: 123 | if "_" in submodule and submodule.split("_")[-1].isnumeric(): 124 | maybe_submodule_list = "_".join(submodule.split("_")[:-1]) 125 | if not hasattr(self, maybe_submodule_list): 126 | setattr( 127 | self, 128 | maybe_submodule_list, 129 | _ChildModule(self.module, self.params, maybe_submodule_list), 130 | ) 131 | for field in module.__annotations__: 132 | if field not in ("parent", "name"): 133 | setattr(self, field, getattr(module, field)) 134 | 135 | def __call__(self, *args, **kwargs): 136 | return self.module.apply(self.params, *args, **kwargs) 137 | 138 | 139 | def _skip_update(grad, opt_state, params): 140 | del params 141 | return jax.tree.map(jnp.zeros_like, grad), opt_state 142 | 143 | 144 | @functools.partial(jax.jit, donate_argnums=(0, 1, 2), static_argnums=(3,)) 145 | def _optimizer_update(params, opt_state, grads, 
optimizer): 146 | """Updates the parameters and the optimizer state.""" 147 | # Helpful metric to print out during training. 148 | squared_grad_norm = sum(jnp.square(p).sum() for p in jax.tree.leaves(grads)) 149 | do_update = jnp.isfinite(squared_grad_norm) 150 | grads = jax.tree.map(lambda x, y: x.astype(y.dtype), grads, params) 151 | updates, new_opt_state = optimizer.update(grads, opt_state, params) 152 | opt_state = jax.tree.map( 153 | lambda x, y: jnp.where(do_update, x, y), new_opt_state, opt_state 154 | ) 155 | params = jax.tree.map( 156 | lambda p, u: jnp.where(do_update, p + u, p), params, updates 157 | ) 158 | return params, opt_state, squared_grad_norm 159 | 160 | 161 | def train( 162 | loss_fn, 163 | init_params, 164 | optimizer, 165 | num_steps, 166 | dataloader=None, 167 | seed=0, 168 | jit_compile=True, 169 | eval_fn=None, 170 | log_every=None, 171 | init_step=0, 172 | opt_state=None, 173 | **kwargs, 174 | ): 175 | """Optimizes the parameters.""" 176 | 177 | def step_fn(params, opt_state, *args, **kwargs): 178 | (_, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)( 179 | params, *args, **kwargs 180 | ) 181 | params, opt_state, squared_grad_norm = _optimizer_update( 182 | params, opt_state, grads, optimizer 183 | ) 184 | metrics["squared_grad_norm"] = squared_grad_norm 185 | return params, opt_state, metrics 186 | 187 | if callable(jit_compile): 188 | maybe_jitted_step_fn = jit_compile(step_fn) 189 | else: 190 | maybe_jitted_step_fn = jax.jit(step_fn) if jit_compile else step_fn 191 | opt_state = optimizer.init(init_params) if opt_state is None else opt_state 192 | params = init_params 193 | run_key = random.PRNGKey(seed) if isinstance(seed, int) else seed 194 | log_every = max(num_steps // 20, 1) if log_every is None else log_every 195 | space = str(len(str(num_steps - 1))) 196 | kwargs = kwargs.copy() 197 | if eval_fn is not None: 198 | print("Evaluating with the initial params...", flush=True) 199 | tic = time.time() 200 | eval_fn(init_step, params, opt_state, metrics=None) 201 | print("Time to compile an eval step:", time.time() - tic, flush=True) 202 | print("Compiling the first train step...", flush=True) 203 | tic = time.time() 204 | metrics = None 205 | for step in range(init_step + 1, num_steps + 1): 206 | key = random.fold_in(run_key, step) 207 | args = (key, next(dataloader)) if dataloader is not None else (key,) 208 | params, opt_state, metrics = maybe_jitted_step_fn( 209 | params, opt_state, *args, **kwargs 210 | ) 211 | for name in kwargs: 212 | if name in metrics: 213 | kwargs[name] = metrics[name] 214 | if step == 1: 215 | print("Time to compile a train step:", time.time() - tic, flush=True) 216 | print("=====", flush=True) 217 | if (step == num_steps) or (step % log_every == 0): 218 | log = ("Step {:<" + space + "d}").format(step) 219 | for name, value in sorted(metrics.items()): 220 | if np.isscalar(value) or ( 221 | isinstance(value, (np.ndarray, jnp.ndarray)) and (value.ndim == 0) 222 | ): 223 | log += f" | {name} {float(value):10.4f}" 224 | print(log, flush=True) 225 | if eval_fn is not None: 226 | eval_fn(step, params, opt_state, metrics) 227 | return params, metrics 228 | 229 | 230 | def get_batch_ndims(xs): 231 | """Gets the number of same-size leading dimensions of the elements in xs.""" 232 | if not xs: 233 | return 0 234 | min_ndim = min(jnp.ndim(lp) for lp in xs) 235 | batch_ndims = 0 236 | for i in range(min_ndim): 237 | if len(set(jnp.shape(lp)[i] for lp in xs)) > 1: 238 | break 239 | batch_ndims = batch_ndims + 1 240 | return 
batch_ndims 241 | 242 | 243 | def get_log_weight(trace, batch_ndims): 244 | """Computes log weight of the trace and keeps its batch dimensions.""" 245 | log_weight = jnp.zeros((1,) * batch_ndims) 246 | for site in trace.values(): 247 | lp = get_site_log_prob(site) 248 | if is_observed_site(site): 249 | log_weight = log_weight + jnp.sum( 250 | lp, axis=tuple(range(batch_ndims - jnp.ndim(lp), 0)) 251 | ) 252 | else: 253 | log_weight = log_weight + jnp.zeros(jnp.shape(lp)[:batch_ndims]) 254 | return log_weight 255 | -------------------------------------------------------------------------------- /coix/util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for util.py.""" 16 | 17 | import coix 18 | import jax 19 | import numpy as np 20 | import pytest 21 | 22 | 23 | @pytest.mark.parametrize("seed", [0, None]) 24 | def test_systematic_resampling_uniform(seed): 25 | log_weights = np.zeros(5) 26 | rng_key = jax.random.PRNGKey(seed) if seed is not None else None 27 | num_samples = 5 28 | resample_indices = coix.util.get_systematic_resampling_indices( 29 | log_weights, rng_key, num_samples 30 | ) 31 | np.testing.assert_allclose(resample_indices, np.arange(5)) 32 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -W 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = coix 8 | SOURCEDIR = . 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
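# The `git clean` invocations in this rule remove the generated copies of the
# notebooks, gallery examples, and getting_started.rst that conf.py creates
# inside docs/, so each build starts from (and leaves behind) a clean tree.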
19 | %: Makefile 20 | git clean -dfx build 21 | git clean -dfx examples 22 | git clean -dfx notebooks 23 | git clean -f getting_started.rst 24 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 25 | git clean -dfx examples 26 | git clean -dfx notebooks 27 | git clean -f getting_started.rst 28 | -------------------------------------------------------------------------------- /docs/_static/anneal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/coix/84daae70474d47fcd221f60a1ea97867e143968f/docs/_static/anneal.png -------------------------------------------------------------------------------- /docs/_static/anneal_oryx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/coix/84daae70474d47fcd221f60a1ea97867e143968f/docs/_static/anneal_oryx.png -------------------------------------------------------------------------------- /docs/_static/bmnist.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/coix/84daae70474d47fcd221f60a1ea97867e143968f/docs/_static/bmnist.gif -------------------------------------------------------------------------------- /docs/_static/dmm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/coix/84daae70474d47fcd221f60a1ea97867e143968f/docs/_static/dmm.png -------------------------------------------------------------------------------- /docs/_static/dmm_oryx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/coix/84daae70474d47fcd221f60a1ea97867e143968f/docs/_static/dmm_oryx.png -------------------------------------------------------------------------------- /docs/_static/gmm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/coix/84daae70474d47fcd221f60a1ea97867e143968f/docs/_static/gmm.png -------------------------------------------------------------------------------- /docs/_static/gmm_oryx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/coix/84daae70474d47fcd221f60a1ea97867e143968f/docs/_static/gmm_oryx.png -------------------------------------------------------------------------------- /docs/_static/tutorial_part1_vae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/coix/84daae70474d47fcd221f60a1ea97867e143968f/docs/_static/tutorial_part1_vae.png -------------------------------------------------------------------------------- /docs/_static/tutorial_part2_api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/coix/84daae70474d47fcd221f60a1ea97867e143968f/docs/_static/tutorial_part2_api.png -------------------------------------------------------------------------------- /docs/_static/tutorial_part3_smcs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/coix/84daae70474d47fcd221f60a1ea97867e143968f/docs/_static/tutorial_part3_smcs.png -------------------------------------------------------------------------------- /docs/_templates/breadcrumbs.html: 
-------------------------------------------------------------------------------- 1 | {%- extends "sphinx_rtd_theme/breadcrumbs.html" %} 2 | 3 | {% set display_vcs_links = display_vcs_links if display_vcs_links is defined else True %} 4 | 5 | {% block breadcrumbs_aside %} 6 |
  • 7 | {% if hasdoc(pagename) and display_vcs_links %} 8 | {% if display_github %} 9 | {% if check_meta and 'github_url' in meta %} 10 | 11 | {{ _('Edit on GitHub') }} 12 | {% else %} 13 | {% if 'examples/index' in pagename %} 14 | {{ _('Edit on GitHub') }} 15 | {% elif 'examples/' in pagename %} 16 | {{ _('Edit on GitHub') }} 17 | {% else %} 18 | {{ _('Edit on GitHub') }} 19 | {% endif %} 20 | {% endif %} 21 | {% elif show_source and source_url_prefix %} 22 | {{ _('View page source') }} 23 | {% elif show_source and has_source and sourcename %} 24 | {{ _('View page source') }} 25 | {% endif %} 26 | {% endif %} 27 |
  • 28 | {% endblock %} 29 | -------------------------------------------------------------------------------- /docs/algo.rst: -------------------------------------------------------------------------------- 1 | Inference algorithms 2 | ==================== 3 | 4 | .. automodule:: coix.algo 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | Program combinators 2 | =================== 3 | 4 | .. automodule:: coix.api 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | import sys 5 | 6 | import nbsphinx 7 | import sphinx_rtd_theme 8 | 9 | # import pkg_resources 10 | 11 | # -*- coding: utf-8 -*- 12 | # 13 | # Configuration file for the Sphinx documentation builder. 14 | # 15 | # This file does only contain a selection of the most common options. For a 16 | # full list see the documentation: 17 | # http://www.sphinx-doc.org/en/master/config 18 | 19 | # -- Path setup -------------------------------------------------------------- 20 | 21 | # If extensions (or modules to document with autodoc) are in another directory, 22 | # add these directories to sys.path here. If the directory is relative to the 23 | # documentation root, use os.path.abspath to make it absolute, like shown here. 24 | # 25 | sys.path.insert(0, os.path.abspath("..")) 26 | 27 | 28 | os.environ["SPHINX_BUILD"] = "1" 29 | 30 | # -- Project information ----------------------------------------------------- 31 | 32 | project = "coix" 33 | copyright = "The coix Authors" 34 | author = "The coix Authors" 35 | 36 | version = "" 37 | 38 | if "READTHEDOCS" not in os.environ: 39 | # if developing locally, use coix.__version__ as version 40 | from coix import __version__ # noqaE402 41 | 42 | version = __version__ 43 | 44 | # Add "Edit on GitHub" button on the upper right corner of local docs. 45 | html_context = {"github_version": "main", "display_github": True} 46 | 47 | # release version 48 | release = version 49 | 50 | 51 | # -- General configuration --------------------------------------------------- 52 | 53 | # If your documentation needs a minimal Sphinx version, state it here. 54 | # 55 | # needs_sphinx = '1.0' 56 | 57 | # Add any Sphinx extension module names here, as strings. They can be 58 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 59 | # ones. 60 | extensions = [ 61 | "nbsphinx", 62 | "sphinxcontrib.jquery", 63 | "sphinx.ext.autodoc", 64 | "sphinx.ext.doctest", 65 | "sphinx.ext.imgconverter", 66 | "sphinx.ext.intersphinx", 67 | "sphinx.ext.mathjax", 68 | "sphinx.ext.napoleon", 69 | "sphinx.ext.viewcode", 70 | "sphinx_gallery.gen_gallery", 71 | "sphinx_search.extension", 72 | ] 73 | 74 | # Enable documentation inheritance 75 | 76 | autodoc_inherit_docstrings = True 77 | 78 | # autodoc_default_options = { 79 | # 'member-order': 'bysource', 80 | # 'show-inheritance': True, 81 | # 'special-members': True, 82 | # 'undoc-members': True, 83 | # 'exclude-members': '__dict__,__module__,__weakref__', 84 | # } 85 | 86 | # Add any paths that contain templates here, relative to this directory. 87 | templates_path = ["_templates"] 88 | 89 | # The suffix(es) of source filenames. 
90 | # You can specify multiple suffix as a list of string: 91 | # 92 | # source_suffix = ['.rst', '.md'] 93 | # NOTE: `.rst` is the default suffix of sphinx, and nbsphinx will 94 | # automatically add support for `.ipynb` suffix. 95 | 96 | # do not execute cells 97 | nbsphinx_execute = "never" 98 | 99 | # Don't add .txt suffix to source files: 100 | html_sourcelink_suffix = "" 101 | 102 | # The master toctree document. 103 | master_doc = "index" 104 | 105 | # The language for content autogenerated by Sphinx. Refer to documentation 106 | # for a list of supported languages. 107 | # 108 | # This is also used if you do content translation via gettext catalogs. 109 | # Usually you set "language" from the command line for these cases. 110 | language = "en" 111 | 112 | # List of patterns, relative to source directory, that match files and 113 | # directories to ignore when looking for source files. 114 | # This pattern also affects html_static_path and html_extra_path . 115 | exclude_patterns = [ 116 | ".ipynb_checkpoints", 117 | "examples/*ipynb", 118 | "examples/*py", 119 | ] 120 | 121 | # The name of the Pygments (syntax highlighting) style to use. 122 | pygments_style = "sphinx" 123 | 124 | 125 | # do not prepend module name to functions 126 | add_module_names = False 127 | 128 | 129 | # This is processed by Jinja2 and inserted before each notebook 130 | nbsphinx_prolog = r""" 131 | {% set docname = 'notebooks/' + env.doc2path(env.docname, base=None).split('/')[-1] %} 132 | :github_url: https://github.com/jax-ml/coix/blob/main/{{ docname }} 133 | 134 | .. raw:: html 135 | 136 |
    137 | Interactive online version: 138 | 139 | 140 | Open In Colab 142 | 143 | 144 |
    145 | """ # noqa: E501 146 | 147 | 148 | # -- Copy README files 149 | 150 | # replace "# coix" by "# Getting Started with Coix" 151 | with open("../README.md", "rt") as f: 152 | lines = f.readlines() 153 | start_line = 0 154 | for i, line in enumerate(lines): 155 | start_line = i 156 | if "# coix" == line.rstrip(): 157 | break 158 | lines = lines[start_line:] 159 | lines[0] = "# Getting Started with Coix\n" 160 | text = "\n".join(lines) 161 | 162 | with open("getting_started.rst", "wt") as f: 163 | f.write(nbsphinx.markdown2rst(text)) 164 | 165 | 166 | # -- Copy notebook files 167 | 168 | if not os.path.exists("notebooks"): 169 | os.makedirs("notebooks") 170 | 171 | if not os.path.exists("notebooks/figures"): 172 | os.makedirs("notebooks/figures") 173 | 174 | for src_file in glob.glob("../notebooks/*.ipynb"): 175 | shutil.copy(src_file, "notebooks/") 176 | 177 | for src_file in glob.glob("../notebooks/figures/*"): 178 | shutil.copy(src_file, "notebooks/figures/") 179 | 180 | 181 | # add index file to `notebooks` path, `:orphan:` is used to 182 | # tell sphinx that this rst file needs not to be appeared in toctree 183 | # with open("../notebooks/index.rst", "rt") as f1: 184 | # with open("tutorials/index.rst", "wt") as f2: 185 | # f2.write(":orphan:\n\n") 186 | # f2.write(f1.read()) 187 | 188 | 189 | # -- Convert scripts to notebooks 190 | 191 | sphinx_gallery_conf = { 192 | "examples_dirs": ["../examples"], 193 | "gallery_dirs": ["examples"], 194 | # only execute files beginning with plot_ 195 | "filename_pattern": "/plot_", 196 | "ignore_pattern": "__init__", 197 | # not display Total running time of the script because we do not execute it 198 | "min_reported_time": 1, 199 | } 200 | 201 | 202 | # -- Resolve sphinx 7.3.5 warnings 203 | 204 | suppress_warnings = ["config.cache"] 205 | 206 | 207 | # -- Add thumbnails images 208 | 209 | nbsphinx_thumbnails = {} 210 | 211 | for src_file in glob.glob("../notebooks/*.ipynb") + glob.glob( 212 | "../examples/*.py" 213 | ): 214 | toctree_path = "notebooks/" if src_file.endswith("ipynb") else "examples/" 215 | filename = os.path.splitext(src_file.split("/")[-1])[0] 216 | img_path = "_static/" + filename + ".png" 217 | # use Coix logo if not exist png file 218 | if not os.path.exists(img_path): 219 | img_path = "_static/" + filename + ".gif" 220 | if not os.path.exists(img_path): 221 | img_path = "_static/coix_logo.png" 222 | nbsphinx_thumbnails[toctree_path + filename] = img_path 223 | 224 | 225 | # -- Options for HTML output ------------------------------------------------- 226 | 227 | # logo 228 | # html_logo = "_static/img/coix_logo_wide.png" 229 | 230 | # logo 231 | # html_favicon = "_static/img/favicon/favicon.ico" 232 | 233 | # The theme to use for HTML and HTML Help pages. See the documentation for 234 | # a list of builtin themes. 235 | # 236 | html_theme = "sphinx_rtd_theme" 237 | 238 | # Theme options are theme-specific and customize the look and feel of a theme 239 | # further. For a list of options available for each theme, see the 240 | # documentation. 241 | # 242 | # html_theme_options = {} 243 | 244 | # Add any paths that contain custom static files (such as style sheets) here, 245 | # relative to this directory. They are copied after the builtin static files, 246 | # so a file named "default.css" will overwrite the builtin "default.css". 247 | html_static_path = ["_static"] 248 | # html_style = "css/coix.css" 249 | 250 | # Custom sidebar templates, must be a dictionary that maps document names 251 | # to template names. 
252 | # 253 | # The default sidebars (for documents that don't match any pattern) are 254 | # defined by theme itself. Builtin themes are using these templates by 255 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 256 | # 'searchbox.html']``. 257 | # 258 | # html_sidebars = {} 259 | 260 | 261 | # -- Options for HTMLHelp output --------------------------------------------- 262 | 263 | # Output file base name for HTML help builder. 264 | htmlhelp_basename = "coixdoc" 265 | 266 | 267 | # -- Options for LaTeX output ------------------------------------------------ 268 | 269 | latex_elements = { 270 | # The paper size ('letterpaper' or 'a4paper'). 271 | # 272 | # 'papersize': 'letterpaper', 273 | # The font size ('10pt', '11pt' or '12pt'). 274 | # 275 | # 'pointsize': '10pt', 276 | # Additional stuff for the LaTeX preamble. 277 | # 278 | "preamble": r""" 279 | \usepackage{pmboxdraw} 280 | \usepackage{alphabeta} 281 | """, 282 | # Latex figure (float) alignment 283 | # 284 | # 'figure_align': 'htbp', 285 | } 286 | 287 | # Grouping the document tree into LaTeX files. List of tuples 288 | # (source start file, target name, title, 289 | # author, documentclass [howto, manual, or own class]). 290 | latex_documents = [( 291 | master_doc, 292 | "coix.tex", 293 | "Coix Documentation", 294 | "The coix Authors", 295 | "manual", 296 | )] 297 | 298 | # -- Options for manual page output ------------------------------------------ 299 | 300 | # One entry per manual page. List of tuples 301 | # (source start file, name, description, authors, manual section). 302 | man_pages = [(master_doc, "Coix", "Coix Documentation", [author], 1)] 303 | 304 | # -- Options for Texinfo output ---------------------------------------------- 305 | 306 | # Grouping the document tree into Texinfo files. List of tuples 307 | # (source start file, target name, title, author, 308 | # dir menu entry, description, category) 309 | texinfo_documents = [( 310 | master_doc, 311 | "Coix", 312 | "Coix Documentation", 313 | author, 314 | "Coix", 315 | "Inference Combinators in JAX", 316 | "Miscellaneous", 317 | )] 318 | 319 | 320 | # -- Extension configuration ------------------------------------------------- 321 | 322 | # -- Options for intersphinx extension --------------------------------------- 323 | 324 | # Example configuration for intersphinx: refer to the Python standard library. 325 | intersphinx_mapping = { 326 | "python": ("https://docs.python.org/3/", None), 327 | "numpy": ("http://docs.scipy.org/doc/numpy/", None), 328 | "jax": ("https://jax.readthedocs.io/en/latest/", None), 329 | "numpyro": ("http://num.pyro.ai/en/stable/", None), 330 | } 331 | -------------------------------------------------------------------------------- /docs/core.rst: -------------------------------------------------------------------------------- 1 | Program transforms 2 | ================== 3 | 4 | .. automodule:: coix.core 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/jax-ml/coix 2 | 3 | 4 | Coix Documentation 5 | ================== 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | getting_started 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | :caption: API and Developer Reference 15 | 16 | algo 17 | api 18 | core 19 | loss 20 | util 21 | Change Log 22 | 23 | .. 
nbgallery:: 24 | :maxdepth: 1 25 | :caption: Tutorials and Examples 26 | :name: tutorials 27 | 28 | notebooks/tutorial_part1_vae 29 | notebooks/tutorial_part2_api 30 | notebooks/tutorial_part3_smcs 31 | examples/anneal 32 | examples/gmm 33 | examples/dmm 34 | examples/bmnist 35 | examples/anneal_oryx 36 | examples/gmm_oryx 37 | examples/dmm_oryx 38 | 39 | Indices and tables 40 | ================== 41 | 42 | * :ref:`genindex` 43 | * :ref:`modindex` 44 | * :ref:`search` 45 | -------------------------------------------------------------------------------- /docs/loss.rst: -------------------------------------------------------------------------------- 1 | Inference objectives 2 | ==================== 3 | 4 | .. automodule:: coix.loss 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | ipython 2 | jax 3 | jaxlib 4 | Jinja2 5 | matplotlib 6 | nbsphinx 7 | numpy 8 | numpyro 9 | pillow 10 | pylab-sdk 11 | pyyaml 12 | readthedocs-sphinx-search 13 | sphinx>=5 14 | sphinx-gallery 15 | sphinx_rtd_theme 16 | -------------------------------------------------------------------------------- /docs/util.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========= 3 | 4 | .. automodule:: coix.util 5 | :members: 6 | -------------------------------------------------------------------------------- /examples/README.rst: -------------------------------------------------------------------------------- 1 | Code Examples 2 | ============= 3 | 4 | Examples for Coix. 5 | 6 | `View source files on github`__ 7 | 8 | .. _github: https://github.com/jax-ml/coix/tree/main/examples 9 | 10 | __ github_ 11 | -------------------------------------------------------------------------------- /examples/anneal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Example: Annealed Variational Inference in NumPyro 17 | ================================================== 18 | 19 | This example illustrates how to construct an inference program based on the NVI 20 | algorithm [1] for AVI. The details of AVI can be found in the sections E.1 of 21 | the reference. We will use the NumPyro (default) backend for this example. 22 | 23 | **References** 24 | 25 | 1. Zimmermann, Heiko, et al. "Nested variational inference." NeuRIPS 2021. 26 | 27 | .. 
image:: ../_static/anneal.png 28 | :align: center 29 | 30 | """ 31 | 32 | import argparse 33 | from functools import partial 34 | 35 | import coix 36 | import flax 37 | import flax.linen as nn 38 | import jax 39 | from jax import random 40 | import jax.numpy as jnp 41 | import matplotlib.pyplot as plt 42 | import numpy as np 43 | import numpyro 44 | import numpyro.distributions as dist 45 | import optax 46 | 47 | # %% 48 | # First, we define the neural networks for the targets and kernels. 49 | 50 | 51 | class AnnealKernel(nn.Module): 52 | 53 | @nn.compact 54 | def __call__(self, x): 55 | h = nn.Dense(50)(x) 56 | h = nn.relu(h) 57 | loc = nn.Dense(2, kernel_init=nn.initializers.zeros)(h) + x 58 | scale_raw = nn.Dense(2, kernel_init=nn.initializers.zeros)(h) 59 | return loc, nn.softplus(scale_raw) 60 | 61 | 62 | class AnnealDensity(nn.Module): 63 | M = 8 64 | 65 | @nn.compact 66 | def __call__(self, x, index=0): 67 | beta_raw = self.param("beta_raw", lambda _: -jnp.ones(self.M - 2)) 68 | beta = nn.sigmoid( 69 | beta_raw[0] + jnp.pad(jnp.cumsum(nn.softplus(beta_raw[1:])), (1, 0)) 70 | ) 71 | beta = jnp.pad(beta, (1, 1), constant_values=(0, 1)) 72 | beta_k = beta[index] 73 | 74 | angles = 2 * jnp.arange(1, self.M + 1) * jnp.pi / self.M 75 | mu = 10 * jnp.stack([jnp.sin(angles), jnp.cos(angles)], -1) 76 | sigma = jnp.sqrt(0.5) 77 | target_density = nn.logsumexp( 78 | dist.Normal(mu, sigma).log_prob(x[..., None, :]).sum(-1), -1 79 | ) 80 | init_proposal = dist.Normal(0, 5).log_prob(x).sum(-1) 81 | return beta_k * target_density + (1 - beta_k) * init_proposal 82 | 83 | 84 | class AnnealKernelList(nn.Module): 85 | M = 8 86 | 87 | @nn.compact 88 | def __call__(self, x, index=0): 89 | if self.is_mutable_collection("params"): 90 | vmap_net = nn.vmap( 91 | AnnealKernel, variable_axes={"params": 0}, split_rngs={"params": True} 92 | ) 93 | out = vmap_net(name="kernel")( 94 | jnp.broadcast_to(x, (self.M - 1,) + x.shape) 95 | ) 96 | return jax.tree.map(lambda x: x[index], out) 97 | params = self.scope.get_variable("params", "kernel") 98 | params_i = jax.tree.map(lambda x: x[index], params) 99 | return AnnealKernel(name="kernel").apply( 100 | flax.core.freeze({"params": params_i}), x 101 | ) 102 | 103 | 104 | class AnnealNetwork(nn.Module): 105 | 106 | def setup(self): 107 | self.forward_kernels = AnnealKernelList() 108 | self.reverse_kernels = AnnealKernelList() 109 | self.anneal_density = AnnealDensity() 110 | 111 | def __call__(self, x): 112 | self.reverse_kernels(x) 113 | self.anneal_density(x) 114 | return self.forward_kernels(x) 115 | 116 | 117 | # %% 118 | # Then, we define the targets and kernels as in Section E.1. 119 | 120 | 121 | def anneal_target(network, k=0): 122 | x = numpyro.sample("x", dist.Normal(0, 5).expand([2]).mask(False).to_event()) 123 | anneal_density = network.anneal_density(x, index=k) 124 | # We make "anneal_density" a latent site so that it does not contribute 125 | # to the likelihood weighting of the first proposal. 
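  # (`dist.Unit` is NumPyro's unit-measure distribution: sampling from it
  # contributes the given log-factor to the joint log-density without drawing
  # a new random value; this is the same mechanism `numpyro.factor` uses,
  # except that here the site stays latent rather than observed.)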
126 | numpyro.sample("anneal_density", dist.Unit(anneal_density)) 127 | return ({"x": x},) 128 | 129 | 130 | def anneal_forward(network, inputs, k=0): 131 | mu, sigma = network.forward_kernels(inputs["x"], index=k) 132 | return numpyro.sample("x", dist.Normal(mu, sigma).to_event(1)) 133 | 134 | 135 | def anneal_reverse(network, inputs, k=0): 136 | mu, sigma = network.reverse_kernels(inputs["x"], index=k) 137 | return numpyro.sample("x", dist.Normal(mu, sigma).to_event(1)) 138 | 139 | 140 | # %% 141 | # Finally, we create the anneal inference program, define the loss function, 142 | # run the training loop, and plot the results. 143 | 144 | 145 | def make_anneal(params, unroll=False, num_particles=10): 146 | network = coix.util.BindModule(AnnealNetwork(), params) 147 | # Add particle dimension and construct a program. 148 | vmap = lambda p: numpyro.plate("particle", num_particles, dim=-1)(p) 149 | targets = lambda k: vmap(partial(anneal_target, network, k=k)) 150 | forwards = lambda k: vmap(partial(anneal_forward, network, k=k)) 151 | reverses = lambda k: vmap(partial(anneal_reverse, network, k=k)) 152 | if unroll: # to unroll the algorithm, we provide a list of programs 153 | targets = [targets(k) for k in range(8)] 154 | forwards = [forwards(k) for k in range(7)] 155 | reverses = [reverses(k) for k in range(7)] 156 | program = coix.algo.nvi_rkl(targets, forwards, reverses, num_targets=8) 157 | return program 158 | 159 | 160 | def loss_fn(params, key, num_particles, unroll=False): 161 | # Run the program and get metrics. 162 | program = make_anneal(params, num_particles=num_particles, unroll=unroll) 163 | _, _, metrics = coix.traced_evaluate(program, seed=key)() 164 | return metrics["loss"], metrics 165 | 166 | 167 | def main(args): 168 | lr = args.learning_rate 169 | num_steps = args.num_steps 170 | num_particles = args.num_particles 171 | unroll = args.unroll_loop 172 | 173 | anneal_net = AnnealNetwork() 174 | init_params = anneal_net.init(random.PRNGKey(0), jnp.zeros(2)) 175 | 176 | anneal_params, _ = coix.util.train( 177 | partial(loss_fn, num_particles=num_particles, unroll=unroll), 178 | init_params, 179 | optax.adam(lr), 180 | num_steps, 181 | jit_compile=True, 182 | ) 183 | 184 | rng_keys = random.split(random.PRNGKey(1), 100) 185 | 186 | def eval_program(seed): 187 | p = make_anneal(anneal_params, unroll=unroll, num_particles=1000) 188 | out, trace, metrics = coix.traced_evaluate(p, seed=seed)() 189 | return out, trace, metrics 190 | 191 | _, trace, metrics = jax.vmap(eval_program)(rng_keys) 192 | 193 | metrics.pop("log_weight") 194 | anneal_metrics = jax.tree.map(lambda x: round(float(jnp.mean(x)), 4), metrics) 195 | print(anneal_metrics) 196 | 197 | plt.figure(figsize=(8, 8)) 198 | x = trace["x"]["value"].reshape((-1, 2)) 199 | H, _, _ = np.histogram2d(x[:, 0], x[:, 1], bins=100) 200 | plt.imshow(H.T) 201 | plt.show() 202 | 203 | 204 | if __name__ == "__main__": 205 | parser = argparse.ArgumentParser(description="Annealing example") 206 | parser.add_argument("--num_particles", nargs="?", default=36, type=int) 207 | parser.add_argument("--learning-rate", nargs="?", default=1e-3, type=float) 208 | parser.add_argument("--num-steps", nargs="?", default=20000, type=int) 209 | parser.add_argument("--unroll-loop", action="store_true") 210 | parser.add_argument( 211 | "--device", default="cpu", type=str, help='use "cpu" or "gpu".' 
212 | ) 213 | args = parser.parse_args() 214 | 215 | numpyro.set_platform(args.device) 216 | 217 | main(args) 218 | -------------------------------------------------------------------------------- /examples/anneal_oryx.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Example: Annealed Variational Inference in Oryx 17 | =============================================== 18 | 19 | This example illustrates how to construct an inference program based on the NVI 20 | algorithm [1] for AVI. The details of AVI can be found in the sections E.1 of 21 | the reference. We will use the Oryx backend for this example. 22 | 23 | **References** 24 | 25 | 1. Zimmermann, Heiko, et al. "Nested variational inference." NeuRIPS 2021. 26 | 27 | .. image:: ../_static/anneal_oryx.png 28 | :align: center 29 | 30 | """ 31 | 32 | import argparse 33 | from functools import partial 34 | 35 | import coix 36 | import coix.oryx as coryx 37 | import flax 38 | import flax.linen as nn 39 | import jax 40 | from jax import random 41 | import jax.numpy as jnp 42 | import matplotlib.pyplot as plt 43 | import numpy as np 44 | import numpyro 45 | import numpyro.distributions as dist 46 | import optax 47 | 48 | # %% 49 | # First, we define the neural networks for the targets and kernels. 
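# They are identical to the networks in examples/anneal.py: only the way the
# probabilistic programs are expressed changes between the two backends.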
50 | 51 | 52 | class AnnealKernel(nn.Module): 53 | 54 | @nn.compact 55 | def __call__(self, x): 56 | h = nn.Dense(50)(x) 57 | h = nn.relu(h) 58 | loc = nn.Dense(2, kernel_init=nn.initializers.zeros)(h) + x 59 | scale_raw = nn.Dense(2, kernel_init=nn.initializers.zeros)(h) 60 | return loc, nn.softplus(scale_raw) 61 | 62 | 63 | class AnnealDensity(nn.Module): 64 | M = 8 65 | 66 | @nn.compact 67 | def __call__(self, x, index=0): 68 | beta_raw = self.param("beta_raw", lambda _: -jnp.ones(self.M - 2)) 69 | beta = nn.sigmoid( 70 | beta_raw[0] + jnp.pad(jnp.cumsum(nn.softplus(beta_raw[1:])), (1, 0)) 71 | ) 72 | beta = jnp.pad(beta, (1, 1), constant_values=(0, 1)) 73 | beta_k = beta[index] 74 | 75 | angles = 2 * jnp.arange(1, self.M + 1) * jnp.pi / self.M 76 | mu = 10 * jnp.stack([jnp.sin(angles), jnp.cos(angles)], -1) 77 | sigma = jnp.sqrt(0.5) 78 | target_density = nn.logsumexp( 79 | dist.Normal(mu, sigma).log_prob(x[..., None, :]).sum(-1), -1 80 | ) 81 | init_proposal = dist.Normal(0, 5).log_prob(x).sum(-1) 82 | return beta_k * target_density + (1 - beta_k) * init_proposal 83 | 84 | 85 | class AnnealKernelList(nn.Module): 86 | M = 8 87 | 88 | @nn.compact 89 | def __call__(self, x, index=0): 90 | if self.is_mutable_collection("params"): 91 | vmap_net = nn.vmap( 92 | AnnealKernel, variable_axes={"params": 0}, split_rngs={"params": True} 93 | ) 94 | out = vmap_net(name="kernel")( 95 | jnp.broadcast_to(x, (self.M - 1,) + x.shape) 96 | ) 97 | return jax.tree.map(lambda x: x[index], out) 98 | params = self.scope.get_variable("params", "kernel") 99 | params_i = jax.tree.map(lambda x: x[index], params) 100 | return AnnealKernel(name="kernel").apply( 101 | flax.core.freeze({"params": params_i}), x 102 | ) 103 | 104 | 105 | class AnnealNetwork(nn.Module): 106 | 107 | def setup(self): 108 | self.forward_kernels = AnnealKernelList() 109 | self.reverse_kernels = AnnealKernelList() 110 | self.anneal_density = AnnealDensity() 111 | 112 | def __call__(self, x): 113 | self.reverse_kernels(x) 114 | self.anneal_density(x) 115 | return self.forward_kernels(x) 116 | 117 | 118 | # %% 119 | # Then, we define the targets and kernels as in Section E.1. 120 | 121 | 122 | def anneal_target(network, key, k=0): 123 | key_out, key = random.split(key) 124 | x = coryx.rv(dist.Normal(0, 5).expand([2]).mask(False), name="x")(key) 125 | coryx.factor(network.anneal_density(x, index=k), name="anneal_density") 126 | return key_out, {"x": x} 127 | 128 | 129 | def anneal_forward(network, key, inputs, k=0): 130 | mu, sigma = network.forward_kernels(inputs["x"], index=k) 131 | return coryx.rv(dist.Normal(mu, sigma), name="x")(key) 132 | 133 | 134 | def anneal_reverse(network, key, inputs, k=0): 135 | mu, sigma = network.reverse_kernels(inputs["x"], index=k) 136 | return coryx.rv(dist.Normal(mu, sigma), name="x")(key) 137 | 138 | 139 | # %% 140 | # Finally, we create the anneal inference program, define the loss function, 141 | # run the training loop, and plot the results. 142 | 143 | 144 | def make_anneal(params, unroll=False): 145 | network = coix.util.BindModule(AnnealNetwork(), params) 146 | # Add particle dimension and construct a program. 
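  # Unlike the NumPyro backend, which adds the particle dimension with a
  # numpyro.plate, the Oryx backend vmaps each program over a batch of PRNG
  # keys (prepared in loss_fn below).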
147 | targets = lambda k: jax.vmap(partial(anneal_target, network, k=k)) 148 | forwards = lambda k: jax.vmap(partial(anneal_forward, network, k=k)) 149 | reverses = lambda k: jax.vmap(partial(anneal_reverse, network, k=k)) 150 | if unroll: # to unroll the algorithm, we provide a list of programs 151 | targets = [targets(k) for k in range(8)] 152 | forwards = [forwards(k) for k in range(7)] 153 | reverses = [reverses(k) for k in range(7)] 154 | program = coix.algo.nvi_rkl(targets, forwards, reverses, num_targets=8) 155 | return program 156 | 157 | 158 | def loss_fn(params, key, num_particles, unroll=False): 159 | # Prepare data for the program. 160 | rng_keys = random.split(key, num_particles) 161 | 162 | # Run the program and get metrics. 163 | program = make_anneal(params, unroll=unroll) 164 | _, _, metrics = coix.traced_evaluate(program)(rng_keys) 165 | return metrics["loss"], metrics 166 | 167 | 168 | def main(args): 169 | lr = args.learning_rate 170 | num_steps = args.num_steps 171 | num_particles = args.num_particles 172 | unroll = args.unroll_loop 173 | 174 | anneal_net = AnnealNetwork() 175 | init_params = anneal_net.init(random.PRNGKey(0), jnp.zeros(2)) 176 | 177 | anneal_params, _ = coix.util.train( 178 | partial(loss_fn, num_particles=num_particles, unroll=unroll), 179 | init_params, 180 | optax.adam(lr), 181 | num_steps, 182 | jit_compile=True, 183 | ) 184 | 185 | rng_keys = random.split(random.PRNGKey(1), 100000).reshape((100, 1000, 2)) 186 | _, trace, metrics = coix.traced_evaluate( 187 | jax.vmap(make_anneal(anneal_params, unroll=unroll)) 188 | )(rng_keys) 189 | 190 | metrics.pop("log_weight") 191 | anneal_metrics = jax.tree.map(lambda x: round(float(jnp.mean(x)), 4), metrics) 192 | print(anneal_metrics) 193 | 194 | plt.figure(figsize=(8, 8)) 195 | x = trace["x"]["value"].reshape((-1, 2)) 196 | H, _, _ = np.histogram2d(x[:, 0], x[:, 1], bins=100) 197 | plt.imshow(H.T) 198 | plt.show() 199 | 200 | 201 | if __name__ == "__main__": 202 | parser = argparse.ArgumentParser(description="Annealing example") 203 | parser.add_argument("--num_particles", nargs="?", default=36, type=int) 204 | parser.add_argument("--learning-rate", nargs="?", default=1e-3, type=float) 205 | parser.add_argument("--num-steps", nargs="?", default=20000, type=int) 206 | parser.add_argument("--unroll-loop", action="store_true") 207 | parser.add_argument( 208 | "--device", default="cpu", type=str, help='use "cpu" or "gpu".' 209 | ) 210 | args = parser.parse_args() 211 | 212 | numpyro.set_platform(args.device) 213 | coix.set_backend("coix.oryx") 214 | 215 | main(args) 216 | -------------------------------------------------------------------------------- /examples/bmnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The coix Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ 16 | Example: Time Series Model - Bouncing MNIST in NumPyro 17 | ====================================================== 18 | 19 | This example illustrates how to construct an inference program based on the APGS 20 | sampler [1] for BMNIST. The details of BMNIST can be found in the sections 21 | 6.4 and F.3 of the reference. We will use the NumPyro (default) backend for this 22 | example. 23 | 24 | **References** 25 | 26 | 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural 27 | sufficient statistics. ICML 2020. 28 | 29 | .. image:: ../_static/bmnist.gif 30 | :align: center 31 | 32 | """ 33 | 34 | import argparse 35 | from functools import partial 36 | 37 | import coix 38 | import flax.linen as nn 39 | import jax 40 | from jax import random 41 | import jax.numpy as jnp 42 | import matplotlib.animation as animation 43 | from matplotlib.patches import Rectangle 44 | import matplotlib.pyplot as plt 45 | import numpyro 46 | import numpyro.distributions as dist 47 | import optax 48 | import tensorflow as tf 49 | import tensorflow_datasets as tfds 50 | 51 | # %% 52 | # First, let's load the moving mnist dataset. 53 | 54 | 55 | def load_dataset(*, is_training, batch_size): 56 | ds = tfds.load("moving_mnist:1.0.0", split="test") 57 | ds = ds.repeat() 58 | if is_training: 59 | ds = ds.shuffle(10 * batch_size, seed=0) 60 | map_fn = lambda x: x["image_sequence"][..., :10, :, :, 0] / 255 61 | else: 62 | map_fn = lambda x: x["image_sequence"][..., 0] / 255 63 | ds = ds.batch(batch_size) 64 | ds = ds.map(map_fn) 65 | return iter(tfds.as_numpy(ds)) 66 | 67 | 68 | def get_digit_mean(): 69 | ds, ds_info = tfds.load("mnist:3.0.1", split="train", with_info=True) 70 | ds = tfds.as_numpy(ds.batch(ds_info.splits["train"].num_examples)) 71 | digit_mean = next(iter(ds))["image"].squeeze(-1).mean(axis=0) 72 | return digit_mean / 255 73 | 74 | 75 | # %% 76 | # Next, we define the neural proposals for the Gibbs kernels and the neural 77 | # decoder for the generative model. 78 | 79 | 80 | def scale_and_translate(image, where, out_size): 81 | translate = abs(image.shape[-1] - out_size) * (where[..., ::-1] + 1) / 2 82 | return jax.image.scale_and_translate( 83 | image, 84 | (out_size, out_size), 85 | (0, 1), 86 | jnp.ones(2), 87 | translate, 88 | method="cubic", 89 | antialias=False, 90 | ) 91 | 92 | 93 | def crop_frames(frames, z_where, digit_size=28): 94 | # frames: time.frame_size.frame_size 95 | # z_where: (digits).time.2 96 | # out: (digits).time.digit_size.digit_size 97 | if frames.ndim == 2 and z_where.ndim == 1: 98 | return scale_and_translate(frames, z_where, out_size=digit_size) 99 | elif frames.ndim == 3 and z_where.ndim == 2: 100 | in_axes = (0, 0) 101 | elif frames.ndim == 3 and z_where.ndim == 3: 102 | in_axes = (None, 0) 103 | elif frames.ndim == z_where.ndim: 104 | in_axes = (0, 0) 105 | elif frames.ndim > z_where.ndim: 106 | in_axes = (0, None) 107 | else: 108 | in_axes = (None, 0) 109 | return jax.vmap(partial(crop_frames, digit_size=digit_size), in_axes)( 110 | frames, z_where 111 | ) 112 | 113 | 114 | def embed_digits(digits, z_where, frame_size=64): 115 | # digits: (digits). 
.digit_size.digit_size 116 | # z_where: (digits).(time).2 117 | # out: (digits).(time).frame_size.frame_size 118 | if digits.ndim == 2 and z_where.ndim == 1: 119 | return scale_and_translate(digits, z_where, out_size=frame_size) 120 | elif digits.ndim == 2 and z_where.ndim == 2: 121 | in_axes = (None, 0) 122 | elif digits.ndim >= z_where.ndim: 123 | in_axes = (0, 0) 124 | else: 125 | in_axes = (None, 0) 126 | return jax.vmap(partial(embed_digits, frame_size=frame_size), in_axes)( 127 | digits, z_where 128 | ) 129 | 130 | 131 | def conv2d(frames, digits): 132 | # frames: (time).frame_size.frame_size 133 | # digits: (digits). .digit_size.digit_size 134 | # out: (digits).(time).conv_size .conv_size 135 | if frames.ndim == 2 and digits.ndim == 2: 136 | return jax.scipy.signal.convolve2d(frames, digits, mode="valid") 137 | elif frames.ndim == digits.ndim: 138 | in_axes = (0, 0) 139 | elif frames.ndim > digits.ndim: 140 | in_axes = (0, None) 141 | else: 142 | in_axes = (None, 0) 143 | return jax.vmap(conv2d, in_axes=in_axes)(frames, digits) 144 | 145 | 146 | class EncoderWhat(nn.Module): 147 | 148 | @nn.compact 149 | def __call__(self, digits): 150 | x = digits.reshape(digits.shape[:-2] + (-1,)) 151 | x = nn.Dense(400)(x) 152 | x = nn.relu(x) 153 | x = nn.Dense(200)(x) 154 | x = nn.relu(x) 155 | 156 | x = x.sum(-2) # sum/mean across time 157 | loc_raw = nn.Dense(10)(x) 158 | scale_raw = 0.5 * nn.Dense(10)(x) 159 | return loc_raw, jnp.exp(scale_raw) 160 | 161 | 162 | class EncoderWhere(nn.Module): 163 | 164 | @nn.compact 165 | def __call__(self, frame_conv): 166 | x = frame_conv.reshape(frame_conv.shape[:-2] + (-1,)) 167 | x = nn.softmax(x, -1) 168 | x = nn.Dense(200)(x) 169 | x = nn.relu(x) 170 | x = nn.Dense(200)(x) 171 | x = x.reshape(x.shape[:-1] + (2, 100)) 172 | x = nn.relu(x) 173 | loc_raw = nn.Dense(2)(x[..., 0, :]) 174 | scale_raw = 0.5 * nn.Dense(2)(x[..., 1, :]) 175 | return nn.tanh(loc_raw), jnp.exp(scale_raw) 176 | 177 | 178 | class DecoderWhat(nn.Module): 179 | 180 | @nn.compact 181 | def __call__(self, z_what): 182 | x = nn.Dense(200)(z_what) 183 | x = nn.relu(x) 184 | x = nn.Dense(400)(x) 185 | x = nn.relu(x) 186 | x = nn.Dense(784)(x) 187 | logits = x.reshape(x.shape[:-1] + (28, 28)) 188 | return nn.sigmoid(logits) 189 | 190 | 191 | class BMNISTAutoEncoder(nn.Module): 192 | digit_mean: jnp.ndarray 193 | frame_size: int 194 | 195 | def setup(self): 196 | self.encode_what = EncoderWhat() 197 | self.encode_where = EncoderWhere() 198 | self.decode_what = DecoderWhat() 199 | 200 | def __call__(self, frames): 201 | # Heuristic procedure to setup initial parameters. 202 | frames_conv = conv2d(frames, self.digit_mean) 203 | z_where, _ = self.encode_where(frames_conv) 204 | 205 | digits = crop_frames(frames, z_where, 28) 206 | z_what, _ = self.encode_what(digits) 207 | 208 | digit_recon = self.decode_what(z_what) 209 | frames_recon = embed_digits(digit_recon, z_where, self.frame_size) 210 | return frames_recon 211 | 212 | 213 | # %% 214 | # Then, we define the target and kernels as in Section 6.4. 
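#
# As an editorial aside (not part of the original example): the target places
# a Gaussian random-walk prior on each digit's location, i.e.
# z_where_{d,0} ~ Normal(0, 1) and z_where_{d,t} ~ Normal(z_where_{d,t-1}, 0.1)
# for t > 0. A minimal standalone sketch of one such trajectory, with
# hypothetical names:


def _location_prior_sketch(key, T=10):
  """Illustrative only: samples one location trajectory of shape (T, 2)."""
  z, zs = jnp.zeros(2), []
  for t in range(T):
    key, subkey = random.split(key)
    scale = 1.0 if t == 0 else 0.1  # wide prior at t=0, small steps afterwards
    z = z + scale * random.normal(subkey, (2,))
    zs.append(z)
  return jnp.stack(zs)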
215 | 216 | 217 | def bmnist_target(network, inputs, D=2, T=10): 218 | z_what = numpyro.sample( 219 | "z_what", dist.Normal(0, 1).expand([D, 10]).to_event() 220 | ) 221 | digits = network.decode_what(z_what) # can cache this 222 | 223 | z_where = [] 224 | # p = [] 225 | for d in range(D): 226 | z_where_d = [] 227 | z_where_d_t = jnp.zeros(2) 228 | for t in range(T): 229 | scale = 1 if t == 0 else 0.1 230 | z_where_d_t = numpyro.sample( 231 | f"z_where_{d}_{t}", dist.Normal(z_where_d_t, scale).to_event(1) 232 | ) 233 | z_where_d.append(z_where_d_t) 234 | z_where_d = jnp.stack(z_where_d, -2) 235 | z_where.append(z_where_d) 236 | z_where = jnp.stack(z_where, -3) 237 | 238 | p = embed_digits(digits, z_where, network.frame_size) 239 | p = dist.util.clamp_probs(p.sum(-4)) # sum across digits 240 | frames = numpyro.sample("frames", dist.Bernoulli(p).to_event(3), obs=inputs) 241 | 242 | out = { 243 | "frames": frames, 244 | "frames_recon": p, 245 | "z_what": z_what, 246 | "digits": jax.lax.stop_gradient(digits), 247 | **{f"z_where_{t}": z_where[..., t, :] for t in range(T)}, 248 | } 249 | return (out,) 250 | 251 | 252 | def kernel_where(network, inputs, D=2, t=0): 253 | if not isinstance(inputs, dict): 254 | inputs = { 255 | "frames": inputs, 256 | "digits": jnp.repeat(jnp.expand_dims(network.digit_mean, -3), D, -3), 257 | } 258 | 259 | frame = inputs["frames"][..., t, :, :] 260 | z_where_t = [] 261 | for d in range(D): 262 | digit = inputs["digits"][..., d, :, :] 263 | x_conv = conv2d(frame, digit) 264 | loc, scale = network.encode_where(x_conv) 265 | z_where_d_t = numpyro.sample( 266 | f"z_where_{d}_{t}", dist.Normal(loc, scale).to_event(1) 267 | ) 268 | z_where_t.append(z_where_d_t) 269 | frame_recon = embed_digits(digit, z_where_d_t, network.frame_size) 270 | frame = frame - frame_recon 271 | z_where_t = jnp.stack(z_where_t, -2) 272 | 273 | out = {**inputs, **{f"z_where_{t}": z_where_t}} 274 | return (out,) 275 | 276 | 277 | def kernel_what(network, inputs, T=10): 278 | z_where = jnp.stack([inputs[f"z_where_{t}"] for t in range(T)], -2) 279 | digits = crop_frames(inputs["frames"], z_where, 28) 280 | loc, scale = network.encode_what(digits) 281 | z_what = numpyro.sample("z_what", dist.Normal(loc, scale).to_event(2)) 282 | 283 | out = {**inputs, **{"z_what": z_what}} 284 | return (out,) 285 | 286 | 287 | # %% 288 | # Finally, we create the bmnist inference program, define the loss function, 289 | # run the training loop, and plot the results. 290 | 291 | 292 | def make_bmnist(params, bmnist_net, T=10, num_sweeps=5, num_particles=10): 293 | network = coix.util.BindModule(bmnist_net, params) 294 | # Add particle dimension and construct a program. 295 | vmap = lambda p: numpyro.plate("particle", num_particles, dim=-2)(p) 296 | target = vmap(partial(bmnist_target, network, D=2, T=T)) 297 | kernels = [] 298 | for t in range(T): 299 | kernels.append(vmap(partial(kernel_where, network, D=2, t=t))) 300 | kernels.append(vmap(partial(kernel_what, network, T=T))) 301 | program = coix.algo.apgs(target, kernels, num_sweeps=num_sweeps) 302 | return program 303 | 304 | 305 | def loss_fn(params, key, batch, bmnist_net, num_sweeps, num_particles): 306 | # Prepare data for the program. 307 | shuffle_rng, rng_key = random.split(key) 308 | batch = random.permutation(shuffle_rng, batch, axis=1) 309 | T = batch.shape[-3] 310 | 311 | # Run the program and get metrics. 
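  # `coix.traced_evaluate` returns a triple (outputs, trace, metrics); only
  # the metrics are needed here, and the size-dependent ones (log_Z,
  # log_density, loss) are averaged per data instance below.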
312 |   program = make_bmnist(params, bmnist_net, T, num_sweeps, num_particles)
313 |   _, _, metrics = coix.traced_evaluate(program, seed=rng_key)(batch)
314 |   for metric_name in ["log_Z", "log_density", "loss"]:
315 |     metrics[metric_name] = metrics[metric_name] / batch.shape[0]
316 |   return metrics["loss"], metrics
317 | 
318 | 
319 | def main(args):
320 |   lr = args.learning_rate
321 |   num_steps = args.num_steps
322 |   batch_size = args.batch_size
323 |   num_sweeps = args.num_sweeps
324 |   num_particles = args.num_particles
325 | 
326 |   train_ds = load_dataset(is_training=True, batch_size=batch_size)
327 |   test_ds = load_dataset(is_training=False, batch_size=1)
328 |   digit_mean = get_digit_mean()
329 | 
330 |   test_data = next(test_ds)
331 |   frame_size = test_data.shape[-1]
332 |   bmnist_net = BMNISTAutoEncoder(digit_mean=digit_mean, frame_size=frame_size)
333 |   init_params = bmnist_net.init(jax.random.PRNGKey(0), test_data[0])
334 |   bmnist_params, _ = coix.util.train(
335 |       partial(
336 |           loss_fn,
337 |           bmnist_net=bmnist_net,
338 |           num_sweeps=num_sweeps,
339 |           num_particles=num_particles,
340 |       ),
341 |       init_params,
342 |       optax.adam(lr),
343 |       num_steps,
344 |       train_ds,
345 |   )
346 | 
347 |   T_test = test_data.shape[-3]
348 |   program = make_bmnist(
349 |       bmnist_params, bmnist_net, T_test, num_sweeps, num_particles
350 |   )
351 |   out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))(
352 |       test_data
353 |   )
354 |   out = out[0]
355 | 
356 |   prop_cycle = plt.rcParams["axes.prop_cycle"]
357 |   colors = prop_cycle.by_key()["color"]
358 |   fig, axes = plt.subplots(1, 2, figsize=(12, 6))
359 | 
360 |   def animate(i):
361 |     axes[0].cla()
362 |     axes[0].imshow(test_data[0, i])
363 |     axes[1].cla()
364 |     axes[1].imshow(out["frames_recon"][0, 0, i])
365 |     for d in range(2):
366 |       where = 0.5 * (out[f"z_where_{i}"][0, 0, d] + 1) * (frame_size - 28) - 0.5
367 |       color = colors[d]
368 |       axes[0].add_patch(
369 |           Rectangle(where, 28, 28, edgecolor=color, lw=3, fill=False)
370 |       )
371 | 
372 |   plt.rc("animation", html="jshtml")
373 |   plt.tight_layout()
374 |   ani = animation.FuncAnimation(fig, animate, frames=range(20), interval=300)
375 |   writer = animation.PillowWriter(fps=15)
376 |   ani.save("bmnist.gif", writer=writer)
377 |   plt.show()
378 | 
379 | 
380 | if __name__ == "__main__":
381 |   parser = argparse.ArgumentParser(description="Bouncing MNIST example")
382 |   parser.add_argument("--batch-size", nargs="?", default=5, type=int)
383 |   parser.add_argument("--num-sweeps", nargs="?", default=5, type=int)
384 |   parser.add_argument("--num_particles", nargs="?", default=10, type=int)
385 |   parser.add_argument("--learning-rate", nargs="?", default=1e-4, type=float)
386 |   parser.add_argument("--num-steps", nargs="?", default=20000, type=int)
387 |   parser.add_argument(
388 |       "--device", default="gpu", type=str, help='use "cpu" or "gpu".'
389 |   )
390 |   args = parser.parse_args()
391 | 
392 |   tf.config.experimental.set_visible_devices([], "GPU")  # Disable GPU for TF.
393 |   numpyro.set_platform(args.device)
394 | 
395 |   main(args)
396 | 
--------------------------------------------------------------------------------
/examples/dmm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The coix Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Example: Deep Generative Mixture Model in NumPyro 16 | ================================================= 17 | 18 | This example illustrates how to construct an inference program based on the APGS 19 | sampler [1] for DMM. The details of DMM can be found in the sections 6.3 and 20 | F.2 of the reference. We will use the NumPyro (default) backend for this 21 | example. 22 | 23 | **References** 24 | 25 | 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural 26 | sufficient statistics. ICML 2020. 27 | 28 | .. image:: ../_static/dmm.png 29 | :align: center 30 | 31 | """ 32 | 33 | import argparse 34 | from functools import partial 35 | 36 | import coix 37 | import flax.linen as nn 38 | import jax 39 | from jax import random 40 | import jax.numpy as jnp 41 | import matplotlib.pyplot as plt 42 | import numpy as np 43 | import numpyro 44 | import numpyro.distributions as dist 45 | from numpyro.ops.indexing import Vindex 46 | import optax 47 | import tensorflow as tf 48 | 49 | # %% 50 | # First, let's simulate a synthetic dataset of 2D ring-shaped mixtures. 51 | 52 | 53 | def simulate_rings(num_instances=1, N=200, seed=0): 54 | np.random.seed(seed) 55 | mu = np.random.normal(0, 3, (num_instances, 1, 4, 2)) 56 | angle = np.linspace(0, 2 * np.pi, N // 8, endpoint=False) 57 | shift = np.random.uniform( 58 | 0, (2 * np.pi) // (N // 8), size=(num_instances, 1, 2, 4) 59 | ) 60 | angle = angle[:, None, None] + shift 61 | angle = angle.reshape((num_instances, N // 4, 4)) 62 | loc = np.stack([np.cos(angle), np.sin(angle)], -1) 63 | noise = np.random.normal(0, 0.1, loc.shape) 64 | x = (mu + loc + noise).reshape((num_instances, N, 2)) 65 | shuffle_idx = np.random.uniform(size=x.shape[:2] + (1,)).argsort(axis=1) 66 | return np.take_along_axis(x, shuffle_idx, axis=1) 67 | 68 | 69 | def load_dataset(split, *, batch_size): 70 | if split == "train": 71 | num_data = 20000 72 | num_points = 200 73 | seed = 0 74 | else: 75 | num_data = batch_size 76 | num_points = 600 77 | seed = 1 78 | data = simulate_rings(num_data, num_points, seed=seed) 79 | ds = tf.data.Dataset.from_tensor_slices(data) 80 | ds = ds.repeat() 81 | if split == "train": 82 | ds = ds.shuffle(10 * batch_size, seed=0) 83 | ds = ds.batch(batch_size) 84 | return ds.as_numpy_iterator() 85 | 86 | 87 | # %% 88 | # Next, we define the neural proposals for the Gibbs kernels and the neural 89 | # decoder for the generative model. 
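# In brief: EncoderMu proposes Gaussian parameters for the mixture means,
# EncoderC proposes the cluster assignments, EncoderH proposes Beta parameters
# for the position-on-ring variable, and DecoderH maps that variable to a
# point on the unit circle.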
90 | 91 | 92 | class EncoderMu(nn.Module): 93 | 94 | @nn.compact 95 | def __call__(self, x): 96 | s = nn.Dense(32)(x) 97 | s = nn.tanh(s) 98 | s = nn.Dense(8)(s) 99 | 100 | t = nn.Dense(32)(x) 101 | t = nn.tanh(t) 102 | t = nn.Dense(4)(t) 103 | t = nn.softmax(t, -1) 104 | 105 | s, t = jnp.expand_dims(s, -2), jnp.expand_dims(t, -1) 106 | st = (s * t).sum(-3) / t.sum(-3) 107 | 108 | shape = st.shape[:-1] + (2,) 109 | x = jnp.concatenate([st, jnp.zeros(shape), jnp.full(shape, 10.0)], -1) 110 | x = nn.Dense(64)(x) 111 | x = x.reshape(x.shape[:-1] + (2, 32)) 112 | x = nn.tanh(x) 113 | loc = nn.Dense(2)(x[..., 0, :]) 114 | scale_raw = 0.5 * nn.Dense(2)(x[..., 1, :]) 115 | return loc, jnp.exp(scale_raw) 116 | 117 | 118 | class EncoderC(nn.Module): 119 | 120 | @nn.compact 121 | def __call__(self, x): 122 | x = nn.Dense(32)(x) 123 | x = nn.relu(x) # nn.tanh(x) 124 | logits = nn.Dense(1)(x).squeeze(-1) 125 | return logits + jnp.log(jnp.ones(4) / 4) 126 | 127 | 128 | class EncoderH(nn.Module): 129 | 130 | @nn.compact 131 | def __call__(self, x): 132 | x = nn.Dense(64)(x) 133 | x = x.reshape(x.shape[:-1] + (2, 32)) 134 | x = nn.tanh(x) 135 | alpha_raw = nn.Dense(1)(x[..., 0, :]).squeeze(-1) 136 | beta_raw = nn.Dense(1)(x[..., 1, :]).squeeze(-1) 137 | return jnp.exp(alpha_raw), jnp.exp(beta_raw) 138 | 139 | 140 | class DecoderH(nn.Module): 141 | 142 | @nn.compact 143 | def __call__(self, x): 144 | x = nn.Dense(32)(jnp.expand_dims(x, -1)) 145 | x = nn.tanh(x) 146 | x = nn.Dense(2)(x) 147 | angle = x / jnp.linalg.norm(x, axis=-1, keepdims=True) 148 | radius = 1.0 # self.param("radius", nn.initializers.ones, (1,)) 149 | return radius * angle 150 | 151 | 152 | class DMMAutoEncoder(nn.Module): 153 | 154 | def setup(self): 155 | self.encode_initial_mu = EncoderMu() 156 | self.encode_mu = EncoderMu() 157 | self.encode_c = EncoderC() 158 | self.encode_h = EncoderH() 159 | self.decode_h = DecoderH() 160 | 161 | def __call__(self, x): # N x D 162 | # Heuristic procedure to setup initial parameters. 163 | mu, _ = self.encode_initial_mu(x) # M x D 164 | 165 | xmu = jnp.expand_dims(x, -2) - mu 166 | logits = self.encode_c(xmu) # N x M 167 | c = jnp.argmax(logits, -1) # N 168 | 169 | loc = mu[c] # N x D 170 | alpha, beta = self.encode_h(x - loc) # N 171 | h = alpha / (alpha + beta) # N 172 | 173 | xch = jnp.concatenate([x, jax.nn.one_hot(c, 4), jnp.expand_dims(h, -1)], -1) 174 | mu, _ = self.encode_mu(xch) # M x D 175 | 176 | angle = self.decode_h(h) # N x D 177 | x_recon = mu[c] + angle # N x D 178 | return x_recon 179 | 180 | 181 | # %% 182 | # Then, we define the target and kernels as in Section 6.3. 
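#
# Concretely, the generative story below is: the four rows of mu are ring
# centers drawn from Normal(0, 10); each point picks a ring via
# c_n ~ Categorical(1/4) and a position variable h_n ~ Beta(1, 1); and
# x_n ~ Normal(mu[c_n] + decode_h(h_n), 0.1), where the learned decode_h maps
# h_n to a point on the unit circle. An idealized, hand-coded stand-in for
# decode_h (hypothetical, for intuition only) would be:


def _ring_point_sketch(h):
  """Illustrative only: maps h in [0, 1] to a point on the unit circle."""
  angle = 2 * jnp.pi * h
  return jnp.stack([jnp.cos(angle), jnp.sin(angle)], -1)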
183 | 184 | 185 | def dmm_target(network, inputs): 186 | mu = numpyro.sample("mu", dist.Normal(0, 10).expand([4, 2]).to_event()) 187 | with numpyro.plate("N", inputs.shape[-2], dim=-1): 188 | c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4)) 189 | h = numpyro.sample("h", dist.Beta(1, 1)) 190 | x_recon = network.decode_h(h) + Vindex(mu)[..., c, :] 191 | x = numpyro.sample("x", dist.Normal(x_recon, 0.1).to_event(1), obs=inputs) 192 | 193 | out = {"mu": mu, "c": c, "h": h, "x_recon": x_recon, "x": x} 194 | return (out,) 195 | 196 | 197 | def dmm_kernel_mu(network, inputs): 198 | if not isinstance(inputs, dict): 199 | inputs = {"x": inputs} 200 | 201 | if "c" in inputs: 202 | x = jnp.broadcast_to(inputs["x"], inputs["h"].shape + (2,)) 203 | c = jax.nn.one_hot(inputs["c"], 4) 204 | h = jnp.expand_dims(inputs["h"], -1) 205 | xch = jnp.concatenate([x, c, h], -1) 206 | loc, scale = network.encode_mu(xch) 207 | else: 208 | loc, scale = network.encode_initial_mu(inputs["x"]) 209 | loc, scale = jnp.expand_dims(loc, -3), jnp.expand_dims(scale, -3) 210 | mu = numpyro.sample("mu", dist.Normal(loc, scale).to_event(2)) 211 | 212 | out = {**inputs, **{"mu": mu}} 213 | return (out,) 214 | 215 | 216 | def dmm_kernel_c_h(network, inputs): 217 | x, mu = inputs["x"], inputs["mu"] 218 | xmu = jnp.expand_dims(x, -2) - mu 219 | logits = network.encode_c(xmu) 220 | with numpyro.plate("N", logits.shape[-2], dim=-1): 221 | c = numpyro.sample("c", dist.Categorical(logits=logits)) 222 | alpha, beta = network.encode_h(inputs["x"] - Vindex(mu)[..., c, :]) 223 | h = numpyro.sample("h", dist.Beta(alpha, beta)) 224 | 225 | out = {**inputs, **{"c": c, "h": h}} 226 | return (out,) 227 | 228 | 229 | # %% 230 | # Finally, we create the dmm inference program, define the loss function, 231 | # run the training loop, and plot the results. 232 | 233 | 234 | def make_dmm(params, num_sweeps=5, num_particles=10): 235 | network = coix.util.BindModule(DMMAutoEncoder(), params) 236 | # Add particle dimension and construct a program. 237 | vmap = lambda p: numpyro.plate("particle", num_particles, dim=-3)(p) 238 | target = vmap(partial(dmm_target, network)) 239 | kernel_mu = vmap(partial(dmm_kernel_mu, network)) 240 | kernel_c_h = vmap(partial(dmm_kernel_c_h, network)) 241 | kernels = [kernel_mu, kernel_c_h] 242 | program = coix.algo.apgs(target, kernels, num_sweeps=num_sweeps) 243 | return program 244 | 245 | 246 | def loss_fn(params, key, batch, num_sweeps, num_particles): 247 | # Prepare data for the program. 248 | shuffle_rng, rng_key = random.split(key) 249 | batch = random.permutation(shuffle_rng, batch, axis=1) 250 | 251 | # Run the program and get metrics. 
252 | program = make_dmm(params, num_sweeps, num_particles)
253 | _, _, metrics = coix.traced_evaluate(program, seed=rng_key)(batch)
254 | for metric_name in ["log_Z", "log_density", "loss"]:
255 | metrics[metric_name] = metrics[metric_name] / batch.shape[0]
256 | return metrics["loss"], metrics
257 |
258 |
259 | def main(args):
260 | lr = args.learning_rate
261 | num_steps = args.num_steps
262 | batch_size = args.batch_size
263 | num_sweeps = args.num_sweeps
264 | num_particles = args.num_particles
265 |
266 | train_ds = load_dataset("train", batch_size=batch_size)
267 | test_ds = load_dataset("test", batch_size=batch_size)
268 |
269 | init_params = DMMAutoEncoder().init(
270 | jax.random.PRNGKey(0), jnp.zeros((200, 2))
271 | )
272 | dmm_params, _ = coix.util.train(
273 | partial(loss_fn, num_sweeps=num_sweeps, num_particles=num_particles),
274 | init_params,
275 | optax.adam(lr),
276 | num_steps,
277 | train_ds,
278 | )
279 |
280 | program = make_dmm(dmm_params, num_sweeps, num_particles)
281 | batch = next(test_ds)
282 | out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))(batch)
283 | out = out[0]
284 |
285 | _, axes = plt.subplots(2, 3, figsize=(15, 10))
286 | for i in range(3):
287 | axes[0][i].scatter(out["x"][i, :, 0], out["x"][i, :, 1], marker=".")
288 | axes[1][i].scatter(
289 | out["x_recon"][0, i, :, 0],
290 | out["x_recon"][0, i, :, 1],
291 | c=out["c"][0, i],
292 | cmap="Accent",
293 | marker=".",
294 | )
295 | axes[1][i].scatter(
296 | out["mu"][0, i, 0, :, 0],
297 | out["mu"][0, i, 0, :, 1],
298 | c=range(4),
299 | marker="x",
300 | cmap="Accent",
301 | )
302 | plt.show()
303 |
304 |
305 | if __name__ == "__main__":
306 | parser = argparse.ArgumentParser(description="DMM example")
307 | parser.add_argument("--batch-size", nargs="?", default=20, type=int)
308 | parser.add_argument("--num-sweeps", nargs="?", default=5, type=int)
309 | parser.add_argument("--num-particles", nargs="?", default=10, type=int)
310 | parser.add_argument("--learning-rate", nargs="?", default=1e-3, type=float)
311 | parser.add_argument("--num-steps", nargs="?", default=30000, type=int)
312 | parser.add_argument(
313 | "--device", default="gpu", type=str, help='use "cpu" or "gpu".'
314 | )
315 | args = parser.parse_args()
316 |
317 | tf.config.experimental.set_visible_devices([], "GPU") # Disable GPU for TF.
318 | numpyro.set_platform(args.device)
319 |
320 | main(args)
321 |
-------------------------------------------------------------------------------- /examples/dmm_oryx.py: --------------------------------------------------------------------------------
1 | # Copyright 2024 The coix Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Example: Deep Generative Mixture Model in Oryx
17 | ==============================================
18 |
19 | This example illustrates how to construct an inference program based on the APGS
20 | sampler [1] for DMM.
The details of DMM can be found in Sections 6.3 and
21 | F.2 of the reference. We will use the Oryx backend for this example.
22 |
23 | **References**
24 |
25 | 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural
26 | sufficient statistics. ICML 2020.
27 |
28 | .. image:: ../_static/dmm_oryx.png
29 | :align: center
30 |
31 | """
32 |
33 | # %%
34 | # **Note:** The metrics seem to be incorrect in this example.
35 |
36 | import argparse
37 | from functools import partial
38 |
39 | import coix
40 | import coix.oryx as coryx
41 | import flax.linen as nn
42 | import jax
43 | from jax import random
44 | import jax.numpy as jnp
45 | import matplotlib.pyplot as plt
46 | import numpy as np
47 | import numpyro
48 | import numpyro.distributions as dist
49 | import optax
50 | import tensorflow as tf
51 |
52 | # %%
53 | # First, let's simulate a synthetic dataset of 2D ring-shaped mixtures.
54 |
55 |
56 | def simulate_rings(num_instances=1, N=200, seed=0):
57 | np.random.seed(seed)
58 | mu = np.random.normal(0, 3, (num_instances, 1, 4, 2))
59 | angle = np.linspace(0, 2 * np.pi, N // 8, endpoint=False)
60 | shift = np.random.uniform(
61 | 0, (2 * np.pi) / (N // 8), size=(num_instances, 1, 2, 4)
62 | ) # random phase within one angular step (floor division would zero it out)
63 | angle = angle[:, None, None] + shift
64 | angle = angle.reshape((num_instances, N // 4, 4))
65 | loc = np.stack([np.cos(angle), np.sin(angle)], -1)
66 | noise = np.random.normal(0, 0.1, loc.shape)
67 | x = (mu + loc + noise).reshape((num_instances, N, 2))
68 | shuffle_idx = np.random.uniform(size=x.shape[:2] + (1,)).argsort(axis=1)
69 | return np.take_along_axis(x, shuffle_idx, axis=1)
70 |
71 |
72 | def load_dataset(split, *, is_training, batch_size):
73 | if split == "train":
74 | num_data = 20000
75 | num_points = 200
76 | seed = 0
77 | else:
78 | num_data = batch_size
79 | num_points = 600
80 | seed = 1
81 | data = simulate_rings(num_data, num_points, seed=seed)
82 | ds = tf.data.Dataset.from_tensor_slices(data)
83 | ds = ds.repeat()
84 | if split == "train":
85 | ds = ds.shuffle(10 * batch_size, seed=0)
86 | ds = ds.batch(batch_size)
87 | return ds.as_numpy_iterator()
88 |
89 |
90 | # %%
91 | # Next, we define the neural proposals for the Gibbs kernels and the neural
92 | # decoder for the generative model.
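# %% (editor's summary of the modules below)
# EncoderMu pools per-point features into per-cluster statistics and emits a
# Normal proposal (loc, scale) for the ring centers; EncoderC scores the
# displacement of each point from each center and emits Categorical logits for
# the assignments; EncoderH emits Beta parameters for a point's angular
# position on its ring; DecoderH maps that Beta variable back to a point on
# the unit circle, so that `mu[c] + decode_h(h)` reconstructs the data.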
93 | 94 | 95 | class EncoderMu(nn.Module): 96 | 97 | @nn.compact 98 | def __call__(self, x): 99 | s = nn.Dense(32)(x) 100 | s = nn.tanh(s) 101 | s = nn.Dense(8)(s) 102 | 103 | t = nn.Dense(32)(x) 104 | t = nn.tanh(t) 105 | t = nn.Dense(4)(t) 106 | t = nn.softmax(t, -1) 107 | 108 | s, t = jnp.expand_dims(s, -2), jnp.expand_dims(t, -1) 109 | st = (s * t).sum(-3) / t.sum(-3) 110 | 111 | shape = st.shape[:-1] + (2,) 112 | x = jnp.concatenate([st, jnp.zeros(shape), jnp.full(shape, 10.0)], -1) 113 | x = nn.Dense(64)(x) 114 | x = x.reshape(x.shape[:-1] + (2, 32)) 115 | x = nn.tanh(x) 116 | loc = nn.Dense(2)(x[..., 0, :]) 117 | scale_raw = 0.5 * nn.Dense(2)(x[..., 1, :]) 118 | return loc, jnp.exp(scale_raw) 119 | 120 | 121 | class EncoderC(nn.Module): 122 | 123 | @nn.compact 124 | def __call__(self, x): 125 | x = nn.Dense(32)(x) 126 | x = nn.relu(x) # nn.tanh(x) 127 | logits = nn.Dense(1)(x).squeeze(-1) 128 | return logits + jnp.log(jnp.ones(4) / 4) 129 | 130 | 131 | class EncoderH(nn.Module): 132 | 133 | @nn.compact 134 | def __call__(self, x): 135 | x = nn.Dense(64)(x) 136 | x = x.reshape(x.shape[:-1] + (2, 32)) 137 | x = nn.tanh(x) 138 | alpha_raw = nn.Dense(1)(x[..., 0, :]).squeeze(-1) 139 | beta_raw = nn.Dense(1)(x[..., 1, :]).squeeze(-1) 140 | return jnp.exp(alpha_raw), jnp.exp(beta_raw) 141 | 142 | 143 | class DecoderH(nn.Module): 144 | 145 | @nn.compact 146 | def __call__(self, x): 147 | x = nn.Dense(32)(jnp.expand_dims(x, -1)) 148 | x = nn.tanh(x) 149 | x = nn.Dense(2)(x) 150 | angle = x / jnp.linalg.norm(x, axis=-1, keepdims=True) 151 | radius = 1.0 # self.param("radius", nn.initializers.ones, (1,)) 152 | return radius * angle 153 | 154 | 155 | class DMMAutoEncoder(nn.Module): 156 | 157 | def setup(self): 158 | self.encode_initial_mu = EncoderMu() 159 | self.encode_mu = EncoderMu() 160 | self.encode_c = EncoderC() 161 | self.encode_h = EncoderH() 162 | self.decode_h = DecoderH() 163 | 164 | def __call__(self, x): # N x D 165 | # Heuristic procedure to setup initial parameters. 166 | mu, _ = self.encode_initial_mu(x) # M x D 167 | 168 | xmu = jnp.expand_dims(x, -2) - mu 169 | logits = self.encode_c(xmu) # N x M 170 | c = jnp.argmax(logits, -1) # N 171 | 172 | loc = mu[c] # N x D 173 | alpha, beta = self.encode_h(x - loc) # N 174 | h = alpha / (alpha + beta) # N 175 | 176 | xch = jnp.concatenate([x, jax.nn.one_hot(c, 4), jnp.expand_dims(h, -1)], -1) 177 | mu, _ = self.encode_mu(xch) # M x D 178 | 179 | angle = self.decode_h(h) # N x D 180 | x_recon = mu[c] + angle # N x D 181 | return x_recon 182 | 183 | 184 | # %% 185 | # Then, we define the target and kernels as in Section 6.3. 
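# %% (editor's note)
# Unlike the NumPyro version of this example, the Oryx backend threads PRNG
# keys explicitly: every program takes a `key`, splits it, and returns a fresh
# `key_out` alongside its outputs. A random site is declared by wrapping a
# distribution with `coryx.rv` and calling the result on a subkey, e.g.
#
#   key, subkey = random.split(key)
#   z = coryx.rv(dist.Normal(0.0, 1.0), name="z")(subkey)
#
# while observed sites pass `obs=...` instead of a key, as `dmm_target` below
# does for "x".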
186 |
187 |
188 | def dmm_target(network, key, inputs):
189 | key_out, key_mu, key_c, key_h = random.split(key, 4)
190 | N = inputs.shape[-2]
191 |
192 | mu = coryx.rv(dist.Normal(0, 10).expand([4, 2]), name="mu")(key_mu)
193 | c = coryx.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c)
194 | h = coryx.rv(dist.Beta(1, 1).expand([N]), name="h")(key_h)
195 | x_recon = mu[c] + network.decode_h(h)
196 | x = coryx.rv(dist.Normal(x_recon, 0.1), obs=inputs, name="x")
197 |
198 | out = {"mu": mu, "c": c, "h": h, "x_recon": x_recon, "x": x}
199 | return key_out, out
200 |
201 |
202 | def dmm_kernel_mu(network, key, inputs):
203 | if not isinstance(inputs, dict):
204 | inputs = {"x": inputs}
205 | key_out, key_mu = random.split(key)
206 |
207 | if "c" in inputs:
208 | x = inputs["x"]
209 | c = jax.nn.one_hot(inputs["c"], 4)
210 | h = jnp.expand_dims(inputs["h"], -1)
211 | xch = jnp.concatenate([x, c, h], -1)
212 | loc, scale = network.encode_mu(xch)
213 | else:
214 | loc, scale = network.encode_initial_mu(inputs["x"])
215 | mu = coryx.rv(dist.Normal(loc, scale), name="mu")(key_mu)
216 |
217 | out = {**inputs, **{"mu": mu}}
218 | return key_out, out
219 |
220 |
221 | def dmm_kernel_c_h(network, key, inputs):
222 | key_out, key_c, key_h = random.split(key, 3)
223 |
224 | x, mu = inputs["x"], inputs["mu"]
225 | xmu = jnp.expand_dims(x, -2) - mu
226 | logits = network.encode_c(xmu)
227 | c = coryx.rv(dist.Categorical(logits=logits), name="c")(key_c)
228 | alpha, beta = network.encode_h(x - mu[c])
229 | h = coryx.rv(dist.Beta(alpha, beta), name="h")(key_h)
230 |
231 | out = {**inputs, **{"c": c, "h": h}}
232 | return key_out, out
233 |
234 |
235 | # %%
236 | # Finally, we create the dmm inference program, define the loss function,
237 | # run the training loop, and plot the results. Note that we use
238 | # 10x fewer steps than the paper.
239 |
240 |
241 | def make_dmm(params, num_sweeps):
242 | network = coix.util.BindModule(DMMAutoEncoder(), params)
243 | # Add particle dimension and construct a program.
244 | target = jax.vmap(partial(dmm_target, network))
245 | kernels = [
246 | jax.vmap(partial(dmm_kernel_mu, network)),
247 | jax.vmap(partial(dmm_kernel_c_h, network)),
248 | ]
249 | program = coix.algo.apgs(target, kernels, num_sweeps=num_sweeps)
250 | return program
251 |
252 |
253 | def loss_fn(params, key, batch, num_sweeps, num_particles):
254 | # Prepare data for the program.
255 | shuffle_rng, rng_key = random.split(key)
256 | batch = random.permutation(shuffle_rng, batch, axis=1)
257 | batch_rng = random.split(rng_key, batch.shape[0])
258 | batch = jnp.repeat(batch[:, None], num_particles, axis=1)
259 | rng_keys = jax.vmap(partial(random.split, num=num_particles))(batch_rng)
260 |
261 | # Run the program and get metrics.
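# (Editor's note) At this point `batch` has shape
# (batch_size, num_particles, N, 2) and `rng_keys` carries one PRNG key per
# (instance, particle) pair, so the outer `jax.vmap` below maps the traced
# evaluation over dataset instances while the `jax.vmap` inside `make_dmm`
# handles the particle dimension.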
262 | program = make_dmm(params, num_sweeps)
263 | _, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch)
264 | metrics = jax.tree.map(
265 | partial(jnp.mean, axis=0), metrics
266 | ) # mean across batch
267 | return metrics["loss"], metrics
268 |
269 |
270 | def main(args):
271 | lr = args.learning_rate
272 | num_steps = args.num_steps
273 | batch_size = args.batch_size
274 | num_sweeps = args.num_sweeps
275 | num_particles = args.num_particles
276 |
277 | train_ds = load_dataset("train", is_training=True, batch_size=batch_size)
278 | test_ds = load_dataset("test", is_training=False, batch_size=batch_size)
279 |
280 | init_params = DMMAutoEncoder().init(
281 | jax.random.PRNGKey(0), jnp.zeros((200, 2))
282 | )
283 | dmm_params, _ = coix.util.train(
284 | partial(loss_fn, num_sweeps=num_sweeps, num_particles=num_particles),
285 | init_params,
286 | optax.adam(lr),
287 | num_steps,
288 | train_ds,
289 | )
290 |
291 | program = make_dmm(dmm_params, num_sweeps)
292 | batch = jnp.repeat(next(test_ds)[:, None], num_particles, axis=1)
293 | rng_keys = jax.vmap(partial(random.split, num=num_particles))(
294 | random.split(jax.random.PRNGKey(1), batch.shape[0])
295 | )
296 | _, out = jax.vmap(program)(rng_keys, batch)
297 |
298 | _, axes = plt.subplots(2, 3, figsize=(15, 10))
299 | for i in range(3):
300 | n = i
301 | axes[0][i].scatter(out["x"][n, 0, :, 0], out["x"][n, 0, :, 1], marker=".")
302 | axes[1][i].scatter(
303 | out["x_recon"][n, 0, :, 0],
304 | out["x_recon"][n, 0, :, 1],
305 | c=out["c"][n, 0],
306 | cmap="Accent",
307 | marker=".",
308 | )
309 | axes[1][i].scatter(
310 | out["mu"][n, 0, :, 0],
311 | out["mu"][n, 0, :, 1],
312 | c=range(4),
313 | marker="x",
314 | cmap="Accent",
315 | )
316 | plt.show()
317 |
318 |
319 | if __name__ == "__main__":
320 | parser = argparse.ArgumentParser(description="DMM example")
321 | parser.add_argument("--batch-size", nargs="?", default=20, type=int)
322 | parser.add_argument("--num-sweeps", nargs="?", default=8, type=int)
323 | parser.add_argument("--num-particles", nargs="?", default=10, type=int)
324 | parser.add_argument("--learning-rate", nargs="?", default=1e-3, type=float)
325 | parser.add_argument("--num-steps", nargs="?", default=30000, type=int)
326 | parser.add_argument(
327 | "--device", default="gpu", type=str, help='use "cpu" or "gpu".'
328 | )
329 | args = parser.parse_args()
330 |
331 | tf.config.experimental.set_visible_devices([], "GPU") # Disable GPU for TF.
332 | numpyro.set_platform(args.device)
333 | coix.set_backend("coix.oryx")
334 |
335 | main(args)
336 |
-------------------------------------------------------------------------------- /examples/gmm.py: --------------------------------------------------------------------------------
1 | # Copyright 2024 The coix Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """ 15 | Example: Gaussian Mixture Model in NumPyro 16 | ========================================== 17 | 18 | This example illustrates how to construct an inference program based on the APGS 19 | sampler [1] for GMM. The details of GMM can be found in the sections 6.2 and 20 | F.1 of the reference. We will use the NumPyro (default) backend for this 21 | example. 22 | 23 | **References** 24 | 25 | 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural 26 | sufficient statistics. ICML 2020. 27 | 28 | .. image:: ../_static/gmm_oryx.png 29 | :align: center 30 | 31 | """ 32 | 33 | import argparse 34 | from functools import partial 35 | 36 | import coix 37 | import flax.linen as nn 38 | import jax 39 | from jax import random 40 | import jax.numpy as jnp 41 | from matplotlib.patches import Ellipse 42 | import matplotlib.pyplot as plt 43 | import numpy as np 44 | import numpyro 45 | import numpyro.distributions as dist 46 | from numpyro.ops.indexing import Vindex 47 | import optax 48 | import tensorflow as tf 49 | 50 | # %% 51 | # First, let's simulate a synthetic dataset of 2D Gaussian mixtures. 52 | 53 | 54 | def simulate_clusters(num_instances=1, N=60, seed=0): 55 | np.random.seed(seed) 56 | tau = np.random.gamma(2, 0.5, (num_instances, 4, 2)) 57 | mu_base = np.random.normal(0, 1, (num_instances, 4, 2)) 58 | mu = mu_base / np.sqrt(0.1 * tau) 59 | c = np.random.choice(np.arange(3), (num_instances, N)) 60 | mu_ = np.take_along_axis(mu, c[..., None], axis=1) 61 | tau_ = np.take_along_axis(tau, c[..., None], axis=1) 62 | eps = np.random.normal(0, 1, (num_instances, N, 2)) 63 | x = mu_ + eps / np.sqrt(tau_) 64 | return x, c 65 | 66 | 67 | def load_dataset(split, *, batch_size): 68 | if split == "train": 69 | num_data = 20000 70 | num_points = 60 71 | seed = 0 72 | else: 73 | num_data = batch_size 74 | num_points = 100 75 | seed = 1 76 | data, label = simulate_clusters(num_data, num_points, seed=seed) 77 | if split == "train": 78 | ds = tf.data.Dataset.from_tensor_slices(data) 79 | ds = ds.repeat() 80 | ds = ds.shuffle(10 * batch_size, seed=0) 81 | else: 82 | ds = tf.data.Dataset.from_tensor_slices((data, label)) 83 | ds = ds.repeat() 84 | ds = ds.batch(batch_size) 85 | return ds.as_numpy_iterator() 86 | 87 | 88 | # %% 89 | # Next, we define the neural proposals for the Gibbs kernels. 
90 |
91 |
92 | class GMMEncoderMeanTau(nn.Module):
93 |
94 | @nn.compact
95 | def __call__(self, x):
96 | s = nn.Dense(2)(x)
97 |
98 | t = nn.Dense(3)(x)
99 | t = nn.softmax(t, -1)
100 |
101 | s, t = jnp.expand_dims(s, -2), jnp.expand_dims(t, -1)
102 | N = t.sum(-3)
103 | x = (t * s).sum(-3)
104 | x2 = (t * s**2).sum(-3)
105 | mu0, nu0, alpha0, beta0 = (0, 0.1, 2, 2)
106 | alpha = alpha0 + 0.5 * N
107 | beta = (
108 | beta0
109 | + 0.5 * (x2 - x**2 / N)
110 | + 0.5 * N * nu0 / (N + nu0) * (x / N - mu0) ** 2
111 | )
112 | mu = (mu0 * nu0 + x) / (nu0 + N)
113 | nu = nu0 + N
114 | return alpha, beta, mu, nu
115 |
116 |
117 | class GMMEncoderC(nn.Module):
118 |
119 | @nn.compact
120 | def __call__(self, x):
121 | x = nn.Dense(32)(x)
122 | x = nn.tanh(x)
123 | logits = nn.Dense(1)(x).squeeze(-1)
124 | return logits + jnp.log(jnp.ones(3) / 3)
125 |
126 |
127 | def broadcast_concatenate(*xs):
128 | shape = jnp.broadcast_shapes(*[x.shape[:-1] for x in xs])
129 | xs = [jnp.broadcast_to(x, shape + x.shape[-1:]) for x in xs]
130 | return jnp.concatenate(xs, -1)
131 |
132 |
133 | class GMMEncoder(nn.Module):
134 |
135 | def setup(self):
136 | self.encode_initial_mean_tau = GMMEncoderMeanTau()
137 | self.encode_mean_tau = GMMEncoderMeanTau()
138 | self.encode_c = GMMEncoderC()
139 |
140 | def __call__(self, x): # N x D
141 | # Heuristic procedure to setup initial parameters.
142 | alpha, beta, mean, _ = self.encode_initial_mean_tau(x) # M x D
143 | tau = alpha / beta # M x D
144 |
145 | xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
146 | logits = self.encode_c(xmt) # N x M
147 | c = jnp.argmax(logits, -1) # N
148 |
149 | xc = jnp.concatenate([x, jax.nn.one_hot(c, 3)], axis=-1)
150 | return self.encode_mean_tau(xc)
151 |
152 |
153 | # %%
154 | # Then, we define the target and kernels as in Section 6.2.
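# %% (editor's note)
# The target below relies on `Vindex` (imported from numpyro.ops.indexing)
# for vectorized advanced indexing: in `Vindex(mean)[..., c, :]` the
# assignment array `c` is broadcast against the leading batch dimensions of
# `mean`, so each data point selects its own cluster row while extra
# particle/batch dimensions are carried through. A toy shape check (editor's
# example, using the imports above):

_mean_demo = jnp.zeros((10, 1, 3, 2)) # (particles, 1, clusters, D)
_c_demo = jnp.zeros((10, 60), dtype=jnp.int32) # (particles, N)
assert Vindex(_mean_demo)[..., _c_demo, :].shape == (10, 60, 2)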
155 |
156 |
157 | def gmm_target(inputs):
158 | tau = numpyro.sample("tau", dist.Gamma(2, 2).expand([3, 2]).to_event())
159 | mean = numpyro.sample(
160 | "mean", dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)).to_event()
161 | )
162 | with numpyro.plate("N", inputs.shape[-2], dim=-1):
163 | c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(3) / 3))
164 | loc = Vindex(mean)[..., c, :]
165 | scale = 1 / jnp.sqrt(Vindex(tau)[..., c, :])
166 | x = numpyro.sample("x", dist.Normal(loc, scale).to_event(1), obs=inputs)
167 |
168 | out = {"mean": mean, "tau": tau, "c": c, "x": x}
169 | return (out,)
170 |
171 |
172 | def gmm_kernel_mean_tau(network, inputs):
173 | if not isinstance(inputs, dict):
174 | inputs = {"x": inputs}
175 |
176 | if "c" in inputs:
177 | x = inputs["x"]
178 | c = jax.nn.one_hot(inputs["c"], 3)
179 | xc = broadcast_concatenate(x, c)
180 | alpha, beta, mu, nu = network.encode_mean_tau(xc)
181 | else:
182 | alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"])
183 | alpha, beta, mu, nu = jax.tree.map(
184 | lambda x: jnp.expand_dims(x, -3), (alpha, beta, mu, nu)
185 | )
186 | tau = numpyro.sample("tau", dist.Gamma(alpha, beta).to_event(2))
187 | mean = numpyro.sample(
188 | "mean", dist.Normal(mu, 1 / jnp.sqrt(tau * nu)).to_event(2)
189 | )
190 |
191 | out = {**inputs, **{"mean": mean, "tau": tau}}
192 | return (out,)
193 |
194 |
195 | def gmm_kernel_c(network, inputs):
196 | x, mean, tau = inputs["x"], inputs["mean"], inputs["tau"]
197 | xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
198 | logits = network.encode_c(xmt)
199 | with numpyro.plate("N", logits.shape[-2], dim=-1):
200 | c = numpyro.sample("c", dist.Categorical(logits=logits))
201 |
202 | out = {**inputs, **{"c": c}}
203 | return (out,)
204 |
205 |
206 | # %%
207 | # Finally, we create the gmm inference program, define the loss function,
208 | # run the training loop, and plot the results.
209 |
210 |
211 | def make_gmm(params, num_sweeps, num_particles):
212 | network = coix.util.BindModule(GMMEncoder(), params)
213 | # Add particle dimension and construct a program.
214 | vmap = lambda p: numpyro.plate("particle", num_particles, dim=-3)(p)
215 | target = vmap(gmm_target)
216 | kernel_mean_tau = vmap(partial(gmm_kernel_mean_tau, network))
217 | kernel_c = vmap(partial(gmm_kernel_c, network))
218 | kernels = [kernel_mean_tau, kernel_c]
219 | program = coix.algo.apgs(target, kernels, num_sweeps=num_sweeps)
220 | return program
221 |
222 |
223 | def loss_fn(params, key, batch, num_sweeps, num_particles):
224 | # Prepare data for the program.
225 | shuffle_rng, rng_key = random.split(key)
226 | batch = random.permutation(shuffle_rng, batch, axis=1)
227 |
228 | # Run the program and get metrics.
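# (Editor's note) With the NumPyro backend the whole mini-batch is pushed
# through the program in one traced evaluation, so the metrics appear to be
# accumulated over the batch; the division by `batch.shape[0]` below rescales
# them to per-instance values.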
229 | program = make_gmm(params, num_sweeps, num_particles)
230 | _, _, metrics = coix.traced_evaluate(program, seed=rng_key)(batch)
231 | for metric_name in ["log_Z", "log_density", "loss"]:
232 | metrics[metric_name] = metrics[metric_name] / batch.shape[0]
233 | return metrics["loss"], metrics
234 |
235 |
236 | def main(args):
237 | lr = args.learning_rate
238 | num_steps = args.num_steps
239 | batch_size = args.batch_size
240 | num_sweeps = args.num_sweeps
241 | num_particles = args.num_particles
242 |
243 | train_ds = load_dataset("train", batch_size=batch_size)
244 | test_ds = load_dataset("test", batch_size=batch_size)
245 |
246 | init_params = GMMEncoder().init(jax.random.PRNGKey(0), jnp.zeros((60, 2)))
247 | gmm_params, _ = coix.util.train(
248 | partial(loss_fn, num_sweeps=num_sweeps, num_particles=num_particles),
249 | init_params,
250 | optax.adam(lr),
251 | num_steps,
252 | train_ds,
253 | )
254 |
255 | program = make_gmm(gmm_params, num_sweeps, num_particles)
256 | batch, label = next(test_ds)
257 | out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))(batch)
258 | out = out[0]
259 |
260 | _, axes = plt.subplots(2, 3, figsize=(15, 10))
261 | for i in range(6):
262 | axes[i // 3][i % 3].scatter(
263 | batch[i, :, 0],
264 | batch[i, :, 1],
265 | marker=".",
266 | color=np.array(["c", "m", "y"])[label[i]],
267 | )
268 | for j, c in enumerate(["r", "g", "b"]):
269 | ellipse = Ellipse(
270 | xy=(out["mean"][0, i, 0, j, 0], out["mean"][0, i, 0, j, 1]),
271 | width=4 / jnp.sqrt(out["tau"][0, i, 0, j, 0]),
272 | height=4 / jnp.sqrt(out["tau"][0, i, 0, j, 1]),
273 | fc=c,
274 | alpha=0.3,
275 | )
276 | axes[i // 3][i % 3].add_patch(ellipse)
277 | plt.show()
278 |
279 |
280 | if __name__ == "__main__":
281 | parser = argparse.ArgumentParser(description="GMM example")
282 | parser.add_argument("--batch-size", nargs="?", default=20, type=int)
283 | parser.add_argument("--num-sweeps", nargs="?", default=5, type=int)
284 | parser.add_argument("--num-particles", nargs="?", default=10, type=int)
285 | parser.add_argument("--learning-rate", nargs="?", default=2.5e-4, type=float)
286 | parser.add_argument("--num-steps", nargs="?", default=200000, type=int)
287 | parser.add_argument(
288 | "--device", default="gpu", type=str, help='use "cpu" or "gpu".'
289 | )
290 | args = parser.parse_args()
291 |
292 | tf.config.experimental.set_visible_devices([], "GPU") # Disable GPU for TF.
293 | numpyro.set_platform(args.device)
294 |
295 | main(args)
296 |
-------------------------------------------------------------------------------- /examples/gmm_oryx.py: --------------------------------------------------------------------------------
1 | # Copyright 2024 The coix Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Example: Gaussian Mixture Model in Oryx
17 | =======================================
18 |
19 | This example illustrates how to construct an inference program for GMM, based on
20 | the APGS sampler [1].
The details of GMM can be found in Sections 6.2 and
21 | F.1 of the reference. We will use the Oryx backend for this example.
22 |
23 | **References**
24 |
25 | 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural
26 | sufficient statistics. ICML 2020.
27 |
28 | .. image:: ../_static/gmm_oryx.png
29 | :align: center
30 |
31 | """
32 |
33 | # %%
34 | # **Note:** The metrics seem to be incorrect in this example.
35 |
36 | import argparse
37 | from functools import partial
38 |
39 | import coix
40 | import coix.oryx as coryx
41 | import flax.linen as nn
42 | import jax
43 | from jax import random
44 | import jax.numpy as jnp
45 | from matplotlib.patches import Ellipse
46 | import matplotlib.pyplot as plt
47 | import numpy as np
48 | import numpyro
49 | import numpyro.distributions as dist
50 | import optax
51 | import tensorflow as tf
52 |
53 | # %%
54 | # First, let's simulate a synthetic dataset of Gaussian clusters.
55 |
56 |
57 | def simulate_clusters(num_instances=1, N=60, seed=0):
58 | np.random.seed(seed)
59 | tau = np.random.gamma(2, 0.5, (num_instances, 4, 2))
60 | mu_base = np.random.normal(0, 1, (num_instances, 4, 2))
61 | mu = mu_base / np.sqrt(0.1 * tau)
62 | c = np.random.choice(np.arange(3), (num_instances, N))
63 | mu_ = np.take_along_axis(mu, c[..., None], axis=1)
64 | tau_ = np.take_along_axis(tau, c[..., None], axis=1)
65 | eps = np.random.normal(0, 1, (num_instances, N, 2))
66 | x = mu_ + eps / np.sqrt(tau_)
67 | return x, c
68 |
69 |
70 | def load_dataset(split, *, batch_size):
71 | if split == "train":
72 | num_data = 20000
73 | num_points = 60
74 | seed = 0
75 | else:
76 | num_data = batch_size
77 | num_points = 100
78 | seed = 1
79 | data, label = simulate_clusters(num_data, num_points, seed=seed)
80 | if split == "train":
81 | ds = tf.data.Dataset.from_tensor_slices(data)
82 | ds = ds.repeat()
83 | ds = ds.shuffle(10 * batch_size, seed=0)
84 | else:
85 | ds = tf.data.Dataset.from_tensor_slices((data, label))
86 | ds = ds.repeat()
87 | ds = ds.batch(batch_size)
88 | return ds.as_numpy_iterator()
89 |
90 |
91 | # %%
92 | # Next, we define the neural proposals for the Gibbs kernels.
93 |
94 |
95 | class GMMEncoderMeanTau(nn.Module):
96 |
97 | @nn.compact
98 | def __call__(self, x):
99 | s = nn.Dense(2)(x)
100 |
101 | t = nn.Dense(3)(x)
102 | t = nn.softmax(t, -1)
103 |
104 | s, t = jnp.expand_dims(s, -2), jnp.expand_dims(t, -1)
105 | N = t.sum(-3)
106 | x = (t * s).sum(-3)
107 | x2 = (t * s**2).sum(-3)
108 | mu0, nu0, alpha0, beta0 = (0, 0.1, 2, 2)
109 | alpha = alpha0 + 0.5 * N
110 | beta = (
111 | beta0
112 | + 0.5 * (x2 - x**2 / N)
113 | + 0.5 * N * nu0 / (N + nu0) * (x / N - mu0) ** 2
114 | )
115 | mu = (mu0 * nu0 + x) / (nu0 + N)
116 | nu = nu0 + N
117 | return alpha, beta, mu, nu
118 |
119 |
120 | class GMMEncoderC(nn.Module):
121 |
122 | @nn.compact
123 | def __call__(self, x):
124 | x = nn.Dense(32)(x)
125 | x = nn.tanh(x)
126 | logits = nn.Dense(1)(x).squeeze(-1)
127 | return logits + jnp.log(jnp.ones(3) / 3)
128 |
129 |
130 | def broadcast_concatenate(*xs):
131 | shape = jnp.broadcast_shapes(*[x.shape[:-1] for x in xs])
132 | xs = [jnp.broadcast_to(x, shape + x.shape[-1:]) for x in xs]
133 | return jnp.concatenate(xs, -1)
134 |
135 |
136 | class GMMEncoder(nn.Module):
137 |
138 | def setup(self):
139 | self.encode_initial_mean_tau = GMMEncoderMeanTau()
140 | self.encode_mean_tau = GMMEncoderMeanTau()
141 | self.encode_c = GMMEncoderC()
142 |
143 | def __call__(self, x): # N x D
144 | # Heuristic procedure to setup initial parameters.
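# (Editor's note) This forward pass mirrors a single Gibbs sweep: propose
# initial (mean, tau) from the raw points, hard-assign each point to its most
# likely cluster, then re-encode (mean, tau) from the labeled points. Its role
# is to give `GMMEncoder().init(...)` a single call path that creates
# parameters for all three submodules.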
145 | alpha, beta, mean, _ = self.encode_initial_mean_tau(x) # M x D
146 | tau = alpha / beta # M x D
147 |
148 | xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
149 | logits = self.encode_c(xmt) # N x M
150 | c = jnp.argmax(logits, -1) # N
151 |
152 | xc = jnp.concatenate([x, jax.nn.one_hot(c, 3)], axis=-1)
153 | return self.encode_mean_tau(xc)
154 |
155 |
156 | # %%
157 | # Then, we define the target and kernels as in Section 6.2.
158 |
159 |
160 | def gmm_target(network, key, inputs):
161 | key_out, key_mean, key_tau, key_c = random.split(key, 4)
162 | N = inputs.shape[-2]
163 |
164 | tau = coryx.rv(dist.Gamma(2, 2).expand([3, 2]), name="tau")(key_tau)
165 | mean = coryx.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")(
166 | key_mean
167 | )
168 | c = coryx.rv(dist.DiscreteUniform(0, 2).expand([N]), name="c")(key_c)
169 | x = coryx.rv(dist.Normal(mean[c], 1 / jnp.sqrt(tau[c])), obs=inputs, name="x")
170 |
171 | out = {"mean": mean, "tau": tau, "c": c, "x": x}
172 | return key_out, out
173 |
174 |
175 | def gmm_kernel_mean_tau(network, key, inputs):
176 | if not isinstance(inputs, dict):
177 | inputs = {"x": inputs}
178 | key_out, key_mean, key_tau = random.split(key, 3)
179 |
180 | if "c" in inputs:
181 | x = inputs["x"]
182 | c = jax.nn.one_hot(inputs["c"], 3)
183 | xc = jnp.concatenate([x, c], -1)
184 | alpha, beta, mu, nu = network.encode_mean_tau(xc)
185 | else:
186 | alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"])
187 | tau = coryx.rv(dist.Gamma(alpha, beta), name="tau")(key_tau)
188 | mean = coryx.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")(
189 | key_mean
190 | )
191 |
192 | out = {**inputs, **{"mean": mean, "tau": tau}}
193 | return key_out, out
194 |
195 |
196 | def gmm_kernel_c(network, key, inputs):
197 | key_out, key_c = random.split(key, 2)
198 |
199 | x, mean, tau = inputs["x"], inputs["mean"], inputs["tau"]
200 | xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
201 | logits = network.encode_c(xmt)
202 | c = coryx.rv(dist.Categorical(logits=logits), name="c")(key_c)
203 |
204 | out = {**inputs, **{"c": c}}
205 | return key_out, out
206 |
207 |
208 | # %%
209 | # Finally, we create the gmm inference program, define the loss function,
210 | # run the training loop, and plot the results.
211 |
212 |
213 | def make_gmm(params, num_sweeps):
214 | network = coix.util.BindModule(GMMEncoder(), params)
215 | # Add particle dimension and construct a program.
216 | target = jax.vmap(partial(gmm_target, network))
217 | kernels = [
218 | jax.vmap(partial(gmm_kernel_mean_tau, network)),
219 | jax.vmap(partial(gmm_kernel_c, network)),
220 | ]
221 | program = coix.algo.apgs(target, kernels, num_sweeps=num_sweeps)
222 | return program
223 |
224 |
225 | def loss_fn(params, key, batch, num_sweeps, num_particles):
226 | # Prepare data for the program.
227 | shuffle_rng, rng_key = random.split(key)
228 | batch = random.permutation(shuffle_rng, batch, axis=1)
229 | batch_rng = random.split(rng_key, batch.shape[0])
230 | batch = jnp.repeat(batch[:, None], num_particles, axis=1)
231 | rng_keys = jax.vmap(partial(random.split, num=num_particles))(batch_rng)
232 |
233 | # Run the program and get metrics.
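# (Editor's note) Here the traced evaluation is vmapped over the batch axis,
# so each metric has one entry per instance; the `jax.tree.map` below then
# averages every metric across the batch, whereas the NumPyro version of this
# example divides summed metrics by the batch size instead.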
234 | program = make_gmm(params, num_sweeps)
235 | _, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch)
236 | metrics = jax.tree.map(
237 | partial(jnp.mean, axis=0), metrics
238 | ) # mean across batch
239 | return metrics["loss"], metrics
240 |
241 |
242 | def main(args):
243 | lr = args.learning_rate
244 | num_steps = args.num_steps
245 | batch_size = args.batch_size
246 | num_sweeps = args.num_sweeps
247 | num_particles = args.num_particles
248 |
249 | train_ds = load_dataset("train", batch_size=batch_size)
250 | test_ds = load_dataset("test", batch_size=batch_size)
251 |
252 | init_params = GMMEncoder().init(jax.random.PRNGKey(0), jnp.zeros((60, 2)))
253 | gmm_params, _ = coix.util.train(
254 | partial(loss_fn, num_sweeps=num_sweeps, num_particles=num_particles),
255 | init_params,
256 | optax.adam(lr),
257 | num_steps,
258 | train_ds,
259 | )
260 |
261 | program = make_gmm(gmm_params, num_sweeps)
262 | batch, label = next(test_ds)
263 | batch = jnp.repeat(batch[:, None], num_particles, axis=1)
264 | rng_keys = jax.vmap(partial(random.split, num=num_particles))(
265 | random.split(jax.random.PRNGKey(1), batch.shape[0])
266 | )
267 | _, out = jax.vmap(program)(rng_keys, batch)
268 |
269 | _, axes = plt.subplots(2, 3, figsize=(15, 10))
270 | for i in range(6):
271 | axes[i // 3][i % 3].scatter(
272 | batch[i, 0, :, 0],
273 | batch[i, 0, :, 1],
274 | marker=".",
275 | color=np.array(["c", "m", "y"])[label[i]],
276 | )
277 | for j, c in enumerate(["r", "g", "b"]):
278 | ellipse = Ellipse(
279 | xy=(out["mean"][i, 0, j, 0], out["mean"][i, 0, j, 1]),
280 | width=4 / jnp.sqrt(out["tau"][i, 0, j, 0]),
281 | height=4 / jnp.sqrt(out["tau"][i, 0, j, 1]),
282 | fc=c,
283 | alpha=0.3,
284 | )
285 | axes[i // 3][i % 3].add_patch(ellipse)
286 | plt.show()
287 |
288 |
289 | if __name__ == "__main__":
290 | parser = argparse.ArgumentParser(description="GMM example")
291 | parser.add_argument("--batch-size", nargs="?", default=20, type=int)
292 | parser.add_argument("--num-sweeps", nargs="?", default=5, type=int)
293 | parser.add_argument("--num-particles", nargs="?", default=10, type=int)
294 | parser.add_argument("--learning-rate", nargs="?", default=2.5e-4, type=float)
295 | parser.add_argument("--num-steps", nargs="?", default=200000, type=int)
296 | parser.add_argument(
297 | "--device", default="gpu", type=str, help='use "cpu" or "gpu".'
298 | )
299 | args = parser.parse_args()
300 |
301 | tf.config.experimental.set_visible_devices([], "GPU") # Disable GPU for TF.
302 | numpyro.set_platform(args.device)
303 | coix.set_backend("coix.oryx")
304 |
305 | main(args)
306 |
-------------------------------------------------------------------------------- /notebooks/figures/smcs_nvi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/coix/84daae70474d47fcd221f60a1ea97867e143968f/notebooks/figures/smcs_nvi.png -------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
1 | [project]
2 | # Project metadata.
Available keys are documented at:
3 | # https://packaging.python.org/en/latest/specifications/declaring-project-metadata
4 | name = "coix"
5 | description = "Inference Combinators in JAX"
6 | readme = "README.md"
7 | requires-python = ">=3.9"
8 | license = {file = "LICENSE"}
9 | authors = [{name = "coix authors", email="coix-dev@google.com"}]
10 | classifiers = [ # List of https://pypi.org/classifiers/
11 | "License :: OSI Approved :: Apache Software License",
12 | "Intended Audience :: Science/Research",
13 | ]
14 | keywords = [
15 | "probabilistic machine learning",
16 | "bayesian statistics",
17 | ]
18 |
19 | # pip dependencies of the project
20 | # Installed locally with `pip install -e .`
21 | dependencies = [
22 | "jax",
23 | "jaxlib",
24 | "numpy",
25 | "numpyro",
26 | ]
27 |
28 | # `version` is automatically set by flit to use `coix.__version__`
29 | dynamic = ["version"]
30 |
31 | [project.urls]
32 | homepage = "https://github.com/jax-ml/coix"
33 | repository = "https://github.com/jax-ml/coix"
34 | changelog = "https://github.com/jax-ml/coix/blob/main/CHANGELOG.md"
35 | documentation = "https://coix.readthedocs.io"
36 |
37 | [project.optional-dependencies]
38 | # Development deps (unittest, linting, formatting, ...)
39 | # Installed through `pip install -e .[dev]`
40 | dev = [
41 | "flax",
42 | "isort",
43 | "matplotlib",
44 | "numpyro",
45 | "optax",
46 | "pytest",
47 | "pytest-xdist",
48 | "pylint>=2.6.0",
49 | "pyink",
50 | ]
51 | doc = [
52 | "ipython",
53 | "nbsphinx",
54 | "readthedocs-sphinx-search",
55 | "sphinx>=5",
56 | "sphinx_rtd_theme",
57 | "sphinx-gallery",
58 | ]
59 | oryx = [
60 | "oryx",
61 | ]
62 |
63 | [tool.pyink]
64 | # Formatting configuration to follow the Google style guide
65 | line-length = 80
66 | preview = true
67 | pyink-indentation = 2
68 | pyink-use-majority-quotes = true
69 |
70 | [tool.isort]
71 | profile = "google"
72 | known_third_party = ["coix", "numpyro"]
73 | src_paths = ["examples", "coix"]
74 |
75 | [build-system]
76 | # The build system specifies which backend is used to build/install the project
77 | # (flit, poetry, setuptools, ...). All backends are supported by `pip install`.
78 | requires = ["flit_core>=3.8,<4"]
79 | build-backend = "flit_core.buildapi"
80 |
--------------------------------------------------------------------------------