├── jax ├── py.typed ├── _src │ ├── third_party │ │ ├── __init__.py │ │ ├── numpy │ │ │ ├── __init__.py │ │ │ └── LICENSE │ │ ├── scipy │ │ │ ├── __init__.py │ │ │ └── LICENSE.txt │ │ └── README.md │ ├── __init__.py │ ├── image │ │ └── __init__.py │ ├── nn │ │ └── __init__.py │ ├── numpy │ │ └── __init__.py │ ├── ops │ │ └── __init__.py │ ├── scipy │ │ ├── __init__.py │ │ ├── stats │ │ │ ├── __init__.py │ │ │ ├── expon.py │ │ │ ├── gennorm.py │ │ │ ├── uniform.py │ │ │ ├── geom.py │ │ │ ├── cauchy.py │ │ │ ├── bernoulli.py │ │ │ ├── pareto.py │ │ │ ├── gamma.py │ │ │ ├── chi2.py │ │ │ ├── poisson.py │ │ │ ├── beta.py │ │ │ ├── nbinom.py │ │ │ ├── logistic.py │ │ │ ├── laplace.py │ │ │ └── t.py │ │ ├── cluster │ │ │ └── __init__.py │ │ ├── optimize │ │ │ └── __init__.py │ │ ├── sparse │ │ │ └── __init__.py │ │ └── interpolate │ │ │ └── __init__.py │ ├── lib │ │ └── mlir │ │ │ ├── __init__.py │ │ │ └── dialects │ │ │ └── __init__.py │ ├── lax │ │ └── __init__.py │ ├── debugger │ │ └── __init__.py │ ├── state │ │ └── __init__.py │ └── custom_api_util.py ├── experimental │ ├── jax2tf │ │ ├── OWNERS │ │ ├── examples │ │ │ ├── requirements.txt │ │ │ ├── tf_js │ │ │ │ ├── README.md │ │ │ │ └── quickdraw │ │ │ │ │ └── third_party │ │ │ │ │ └── zaidalyafeai.github.io │ │ │ │ │ ├── LICENSE │ │ │ │ │ └── class_names.txt │ │ │ ├── tflite │ │ │ │ └── README.md │ │ │ ├── __init__.py │ │ │ ├── serving │ │ │ │ └── __init__.py │ │ │ └── keras_reuse_main_test.py │ │ ├── converters_eval │ │ │ ├── converters_results.md.template │ │ │ └── test_models │ │ │ │ └── flax │ │ │ │ └── cnn.py │ │ ├── tests │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── JAX2TF_getting_started.ipynb │ │ └── BUILD │ ├── gda_serialization │ │ ├── README │ │ └── __init__.py │ ├── compilation_cache │ │ ├── __init__.py │ │ └── cache_interface.py │ ├── __init__.py │ └── checkify.py ├── tools │ ├── __init__.py │ ├── BUILD │ └── colab_tpu.py ├── example_libraries │ └── __init__.py ├── flatten_util.py ├── ipu │ ├── 
random │ │ ├── __init__.py │ │ └── prng.py │ ├── debug │ │ └── __init__.py │ ├── __init__.py │ └── primitive │ │ └── __init__.py ├── cloud_tpu_init.py ├── distributed.py ├── scipy │ ├── sparse │ │ ├── __init__.py │ │ └── linalg.py │ ├── cluster │ │ ├── __init__.py │ │ └── vq.py │ ├── fft.py │ ├── stats │ │ ├── nbinom.py │ │ ├── t.py │ │ ├── beta.py │ │ ├── chi2.py │ │ ├── expon.py │ │ ├── gamma.py │ │ ├── geom.py │ │ ├── cauchy.py │ │ ├── pareto.py │ │ ├── uniform.py │ │ ├── bernoulli.py │ │ ├── betabinom.py │ │ ├── dirichlet.py │ │ ├── gennorm.py │ │ ├── laplace.py │ │ ├── poisson.py │ │ ├── multivariate_normal.py │ │ ├── norm.py │ │ ├── logistic.py │ │ └── __init__.py │ ├── ndimage.py │ ├── optimize │ │ └── __init__.py │ ├── interpolate │ │ └── __init__.py │ ├── signal.py │ ├── __init__.py │ ├── linalg.py │ └── special.py ├── dlpack.py ├── custom_transpose.py ├── custom_batching.py ├── config.py ├── stages.py ├── interpreters │ └── __init__.py ├── ad_checkpoint.py ├── ops │ └── __init__.py ├── debug.py ├── abstract_arrays.py ├── lib │ ├── __init__.py │ └── xla_bridge.py ├── api_util.py ├── prng.py ├── image │ └── __init__.py ├── errors.py ├── util.py ├── lax │ └── linalg.py ├── profiler.py ├── version.py ├── dtypes.py ├── custom_derivatives.py ├── nn │ ├── initializers.py │ └── __init__.py └── numpy │ ├── fft.py │ └── linalg.py ├── .bazelversion ├── docs ├── changelog.md ├── _static │ ├── mesh.jpg │ ├── debugger.gif │ ├── favicon.png │ ├── perfetto.png │ ├── style.css │ ├── xla_spmd.jpg │ ├── multi_host.jpg │ ├── jax_logo_250px.png │ ├── partition_spec_x_y.png │ ├── partition_spec_xy.png │ ├── partition_spec_none_y.png │ ├── partition_spec_x_none.png │ ├── partition_spec_y_none.png │ └── tensorboard_profiler.png ├── _templates │ └── layout.html ├── README.md ├── notebooks │ └── README.md ├── jax.experimental.jet.rst ├── jax.experimental.pjit.rst ├── jax.example_libraries.stax.rst ├── jax.dlpack.rst ├── jax.example_libraries.optimizers.rst ├── jax_internal_api.rst 
├── jax.experimental.maps.rst ├── jax.distributed.rst ├── jax.flatten_util.rst ├── jax.experimental.global_device_array.rst ├── installation.rst ├── jax.experimental.host_callback.rst ├── jax.debug.rst ├── requirements.txt ├── jax.image.rst ├── jax-101 │ └── index.rst ├── errors.rst ├── jax.experimental.checkify.rst ├── jax.config.rst ├── deprecation.md ├── jax.ops.rst ├── jax.experimental.sparse.rst ├── jax.tree_util.rst ├── concurrency.rst ├── jax.profiler.rst ├── jax.nn.rst ├── jax.lib.rst ├── jax.experimental.rst ├── jax.nn.initializers.rst ├── jax.example_libraries.rst ├── jax.random.rst └── rank_promotion_warning.rst ├── jaxlib ├── setup.cfg ├── README.md ├── init.py ├── pocketfft_kernels.h ├── pocketfft.fbs ├── rocm │ ├── hip_prng_kernels.h │ ├── hip_prng.cc │ ├── hip_lu_pivot_kernels.h │ └── hip_prng_kernels.cc ├── cuda │ ├── cuda_prng_kernels.h │ ├── cuda_prng.cc │ ├── cuda_lu_pivot_kernels.h │ ├── cuda_prng_kernels.cc │ └── cuda_linalg.cc ├── kernel_pybind11_helpers.h └── kernel_helpers.h ├── images ├── jax_logo.png ├── lifecycle.png ├── jax_logo_250px.png └── jax_logo_500px.png ├── cloud_tpu_colabs └── images │ ├── lorentz.png │ ├── wave_movie.gif │ └── nested_pmap.png ├── .clang-format ├── .style.yapf ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── Feature_request.md │ └── bug-report.yml └── workflows │ └── jax-ipu-release-public.yaml ├── third_party └── pocketfft │ ├── BUILD.bazel │ └── workspace.bzl ├── .gradient └── available_ipus.py ├── CITATION.bib ├── tests ├── filecheck │ ├── README.md │ ├── names.filecheck.py │ └── jax_filecheck_helpers.py ├── benchmarks │ └── xla.py ├── xla_interpreter_test.py ├── clear_backends_test.py ├── heap_profiler_test.py ├── third_party │ └── scipy │ │ └── LICENSE ├── stack_test.py ├── remote_transfer_test.py └── ipu │ └── primitive │ └── custom_primitive_test.py ├── .gitignore ├── setup.cfg ├── examples ├── __init__.py └── jax_cpp │ ├── prog.py │ └── BUILD ├── .readthedocs.yml ├── mypy.ini ├── CONTRIBUTING.md ├── 
conftest.py ├── .pre-commit-config.yaml ├── pytest.ini ├── setup.sh └── pylintrc /jax/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.bazelversion: -------------------------------------------------------------------------------- 1 | 5.1.1 2 | -------------------------------------------------------------------------------- /jax/_src/third_party/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jax/_src/third_party/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jax/_src/third_party/scipy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/OWNERS: -------------------------------------------------------------------------------- 1 | marcvanzee 2 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CHANGELOG.md 2 | ``` 3 | -------------------------------------------------------------------------------- /jaxlib/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_files = LICENSE.txt -------------------------------------------------------------------------------- /jax/experimental/jax2tf/examples/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow_datasets 2 | tensorflow_hub 3 | flax 4 | 
-------------------------------------------------------------------------------- /docs/_static/mesh.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/mesh.jpg -------------------------------------------------------------------------------- /images/jax_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/images/jax_logo.png -------------------------------------------------------------------------------- /images/lifecycle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/images/lifecycle.png -------------------------------------------------------------------------------- /docs/_static/debugger.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/debugger.gif -------------------------------------------------------------------------------- /docs/_static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/favicon.png -------------------------------------------------------------------------------- /docs/_static/perfetto.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/perfetto.png -------------------------------------------------------------------------------- /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | .wy-side-nav-search { 4 | background-color: #fff; 5 | } 6 | 
-------------------------------------------------------------------------------- /docs/_static/xla_spmd.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/xla_spmd.jpg -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% set css_files = css_files + ["_static/style.css"] %} 3 | -------------------------------------------------------------------------------- /images/jax_logo_250px.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/images/jax_logo_250px.png -------------------------------------------------------------------------------- /images/jax_logo_500px.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/images/jax_logo_500px.png -------------------------------------------------------------------------------- /docs/_static/multi_host.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/multi_host.jpg -------------------------------------------------------------------------------- /docs/_static/jax_logo_250px.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/jax_logo_250px.png -------------------------------------------------------------------------------- /cloud_tpu_colabs/images/lorentz.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/cloud_tpu_colabs/images/lorentz.png -------------------------------------------------------------------------------- /docs/_static/partition_spec_x_y.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/partition_spec_x_y.png -------------------------------------------------------------------------------- /docs/_static/partition_spec_xy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/partition_spec_xy.png -------------------------------------------------------------------------------- /cloud_tpu_colabs/images/wave_movie.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/cloud_tpu_colabs/images/wave_movie.gif -------------------------------------------------------------------------------- /docs/_static/partition_spec_none_y.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/partition_spec_none_y.png -------------------------------------------------------------------------------- /docs/_static/partition_spec_x_none.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/partition_spec_x_none.png -------------------------------------------------------------------------------- /docs/_static/partition_spec_y_none.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/partition_spec_y_none.png 
-------------------------------------------------------------------------------- /docs/_static/tensorboard_profiler.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/docs/_static/tensorboard_profiler.png -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | # Run manually to reformat a file: 2 | # clang-format -i --style=file 3 | BasedOnStyle: Google 4 | DerivePointerAlignment: false 5 | -------------------------------------------------------------------------------- /cloud_tpu_colabs/images/nested_pmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-experimental/HEAD/cloud_tpu_colabs/images/nested_pmap.png -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | To rebuild the documentation, 2 | see [Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). 3 | -------------------------------------------------------------------------------- /docs/notebooks/README.md: -------------------------------------------------------------------------------- 1 | For instructions on how to change and test notebooks, see 2 | [Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). 3 | -------------------------------------------------------------------------------- /docs/jax.experimental.jet.rst: -------------------------------------------------------------------------------- 1 | jax.experimental.jet module 2 | =========================== 3 | 4 | .. automodule:: jax.experimental.jet 5 | 6 | API 7 | --- 8 | 9 | .. 
autofunction:: jet 10 | -------------------------------------------------------------------------------- /docs/jax.experimental.pjit.rst: -------------------------------------------------------------------------------- 1 | jax.experimental.pjit module 2 | ============================ 3 | 4 | .. automodule:: jax.experimental.pjit 5 | 6 | API 7 | --- 8 | 9 | .. autofunction:: pjit 10 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/examples/tf_js/README.md: -------------------------------------------------------------------------------- 1 | This directory contains examples of using the jax2tf converter to produce 2 | models that can be used with TensorFlow.js. Note that this is still highly 3 | experimental. 4 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style: yapf 3 | column_limit: 88 4 | indent_width: 2 5 | split_before_named_assigns: True 6 | spaces_around_power_operator: True 7 | dedent_closing_brackets: True 8 | coalesce_brackets: True 9 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/examples/tflite/README.md: -------------------------------------------------------------------------------- 1 | This directory contains examples of using the jax2tf converter to produce 2 | models that can be used with TensorFlow Lite. 3 | Note that this is still highly experimental. 4 | -------------------------------------------------------------------------------- /docs/jax.example_libraries.stax.rst: -------------------------------------------------------------------------------- 1 | jax.example_libraries.stax module 2 | ================================= 3 | 4 | .. 
automodule:: jax.example_libraries.stax 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Have questions or need support? 4 | url: https://github.com/google/jax/discussions 5 | about: Please ask questions on the Discussions tab 6 | -------------------------------------------------------------------------------- /docs/jax.dlpack.rst: -------------------------------------------------------------------------------- 1 | jax.dlpack module 2 | ================= 3 | 4 | .. currentmodule:: jax.dlpack 5 | 6 | .. automodule:: jax.dlpack 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | from_dlpack 12 | to_dlpack -------------------------------------------------------------------------------- /docs/jax.example_libraries.optimizers.rst: -------------------------------------------------------------------------------- 1 | jax.example_libraries.optimizers module 2 | ======================================= 3 | 4 | .. automodule:: jax.example_libraries.optimizers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/jax_internal_api.rst: -------------------------------------------------------------------------------- 1 | Internal APIs 2 | ============= 3 | 4 | core 5 | ---- 6 | 7 | .. currentmodule:: jax.core 8 | .. automodule:: jax.core 9 | 10 | .. autosummary:: 11 | :toctree: _autosummary 12 | 13 | Jaxpr 14 | ClosedJaxpr 15 | -------------------------------------------------------------------------------- /docs/jax.experimental.maps.rst: -------------------------------------------------------------------------------- 1 | jax.experimental.maps module 2 | ============================ 3 | 4 | .. 
automodule:: jax.experimental.maps 5 | 6 | API 7 | --- 8 | 9 | .. autosummary:: 10 | :toctree: _autosummary 11 | 12 | Mesh 13 | xmap 14 | -------------------------------------------------------------------------------- /docs/jax.distributed.rst: -------------------------------------------------------------------------------- 1 | jax.distributed module 2 | ====================== 3 | 4 | .. currentmodule:: jax.distributed 5 | 6 | .. automodule:: jax.distributed 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | initialize 12 | shutdown -------------------------------------------------------------------------------- /docs/jax.flatten_util.rst: -------------------------------------------------------------------------------- 1 | jax.flatten_util package 2 | ======================== 3 | 4 | .. currentmodule:: jax.flatten_util 5 | 6 | .. automodule:: jax.flatten_util 7 | 8 | List of Functions 9 | ----------------- 10 | 11 | .. autosummary:: 12 | :toctree: _autosummary 13 | 14 | ravel_pytree -------------------------------------------------------------------------------- /docs/jax.experimental.global_device_array.rst: -------------------------------------------------------------------------------- 1 | jax.experimental.global_device_array module 2 | =========================================== 3 | 4 | .. automodule:: jax.experimental.global_device_array 5 | 6 | API 7 | --- 8 | 9 | .. autoclass:: GlobalDeviceArray 10 | :members: 11 | .. autoclass:: Shard 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/Feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 'Feature Request' 3 | about: 'Suggest a new idea or improvement for JAX' 4 | labels: 'enhancement' 5 | --- 6 | 7 | Please: 8 | 9 | - [ ] Check for duplicate requests. 10 | - [ ] Describe your goal, and if possible provide a code snippet with a motivating example. 
11 | -------------------------------------------------------------------------------- /jax/_src/third_party/README.md: -------------------------------------------------------------------------------- 1 | This sub-directory contains third-party code for which Google does not have 2 | copyright. Each sub-directory should correspond to a third-party library and 3 | must contain the appropriate LICENSE file. 4 | See [instructions](https://opensource.google/docs/releasing/preparing/#third-party-components). 5 | -------------------------------------------------------------------------------- /third_party/pocketfft/BUILD.bazel: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_library( 6 | name = "pocketfft", 7 | hdrs = ["pocketfft_hdronly.h"], 8 | copts = ["-fexceptions"], 9 | features = ["-use_header_modules"], 10 | include_prefix = "pocketfft", 11 | ) 12 | -------------------------------------------------------------------------------- /jax/experimental/gda_serialization/README: -------------------------------------------------------------------------------- 1 | # Serialization and De-serialization of GlobalDeviceArray via tensorstore 2 | 3 | Warning: This directory is going to move in the near future. Please use at your 4 | own risk. 5 | 6 | To use this library, please install tensorstore and JAX. 7 | 8 | ```bash 9 | pip install -U tensorstore 10 | ``` -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installing JAX 2 | ============== 3 | 4 | JAX is available to install via the `Python Package Index`_. 5 | For full installation instructions, please refer to the `Install Guide`_ in the project README. 6 | 7 | .. _Python Package Index: https://pypi.org/project/jax/ 8 | .. 
_Install Guide: https://github.com/google/jax#installation 9 | -------------------------------------------------------------------------------- /docs/jax.experimental.host_callback.rst: -------------------------------------------------------------------------------- 1 | jax.experimental.host_callback module 2 | ===================================== 3 | 4 | 5 | .. automodule:: jax.experimental.host_callback 6 | 7 | API 8 | --- 9 | 10 | .. autosummary:: 11 | :toctree: _autosummary 12 | 13 | id_tap 14 | id_print 15 | call 16 | barrier_wait 17 | CallbackException 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /.gradient/available_ipus.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | import subprocess 3 | import json 4 | 5 | j = subprocess.check_output(['gc-monitor', '-j']) 6 | data = json.loads(j) 7 | num_ipuMs = len(data["cards"]) 8 | num_ipus = 4 * num_ipuMs 9 | 10 | # to be captured as a variable in the bash script that calls this python script 11 | print(num_ipus) 12 | -------------------------------------------------------------------------------- /docs/jax.debug.rst: -------------------------------------------------------------------------------- 1 | 2 | jax.debug package 3 | ================= 4 | 5 | .. currentmodule:: jax.debug 6 | 7 | .. automodule:: jax.debug 8 | 9 | Debugging utilities 10 | -------------------------- 11 | 12 | :doc:`debugging/print_breakpoint` describes how to make use of JAX's debugging features. 13 | 14 | .. 
autosummary:: 15 | :toctree: _autosummary 16 | 17 | callback 18 | print 19 | breakpoint 20 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx >=3 2 | sphinx-autodoc-typehints 3 | sphinx-book-theme>=0.3.3 4 | sphinx-copybutton>=0.5.0 5 | sphinx-remove-toctrees 6 | jupyter-sphinx>=0.3.2 7 | myst-nb 8 | 9 | # Packages used for CI tests. 10 | pytest 11 | pytest-xdist 12 | 13 | # Packages used for notebook execution 14 | matplotlib 15 | scikit-learn 16 | numpy 17 | .[ci] # Install jax from the current directory; jaxlib from pypi. 18 | -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @software{jax2018github, 2 | author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, 3 | title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, 4 | url = {http://github.com/google/jax}, 5 | version = {0.3.13}, 6 | year = {2018}, 7 | } 8 | -------------------------------------------------------------------------------- /jaxlib/README.md: -------------------------------------------------------------------------------- 1 | # jaxlib: support library for JAX 2 | 3 | jaxlib is the support library for JAX. While JAX itself is a pure Python package, 4 | jaxlib contains the binary (C/C++) parts of the library, including Python bindings, 5 | the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. 6 | For more information, including installation and build instructions, refer to the main 7 | JAX README: https://github.com/google/jax/.
8 | -------------------------------------------------------------------------------- /tests/filecheck/README.md: -------------------------------------------------------------------------------- 1 | This directory contains LLVM 2 | [FileCheck](https://llvm.org/docs/CommandGuide/FileCheck.html) tests that verify 3 | that JAX primitives can be lowered to MHLO. 4 | 5 | These tests are intended to be a quick and easy-to-understand way to catch 6 | regressions from changes due to the MLIR Python bindings and from changes to the 7 | various MLIR dialects used by JAX, without needing to run the full JAX test 8 | suite. 9 | -------------------------------------------------------------------------------- /docs/jax.image.rst: -------------------------------------------------------------------------------- 1 | jax.image package 2 | ================= 3 | 4 | .. currentmodule:: jax.image 5 | 6 | .. automodule:: jax.image 7 | 8 | 9 | Image manipulation functions 10 | ---------------------------- 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | 15 | resize 16 | scale_and_translate 17 | 18 | Argument classes 19 | ---------------- 20 | 21 | .. currentmodule:: jax.image 22 | 23 | .. autoclass:: ResizeMethod 24 | -------------------------------------------------------------------------------- /docs/jax-101/index.rst: -------------------------------------------------------------------------------- 1 | Tutorial: JAX 101 2 | ================= 3 | 4 | This is a tutorial developed by engineers and researchers at DeepMind_. 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | :caption: Tutorials 9 | 10 | 01-jax-basics 11 | 02-jitting 12 | 03-vectorization 13 | 04-advanced-autodiff 14 | 05-random-numbers 15 | 05.1-pytrees 16 | 06-parallelism 17 | 07-state 18 | 08-pjit 19 | 20 | 21 | ..
_Deepmind: http://deepmind.com 22 | -------------------------------------------------------------------------------- /docs/errors.rst: -------------------------------------------------------------------------------- 1 | .. _jax-errors: 2 | 3 | JAX Errors 4 | ========== 5 | This page lists a few of the errors you might encounter when using JAX, 6 | along with representative examples of how one might fix them. 7 | 8 | .. currentmodule:: jax.errors 9 | .. autoclass:: ConcretizationTypeError 10 | .. autoclass:: NonConcreteBooleanIndexError 11 | .. autoclass:: TracerArrayConversionError 12 | .. autoclass:: TracerIntegerConversionError 13 | .. autoclass:: UnexpectedTracerError 14 | -------------------------------------------------------------------------------- /docs/jax.experimental.checkify.rst: -------------------------------------------------------------------------------- 1 | jax.experimental.checkify module 2 | ===================================== 3 | 4 | 5 | .. automodule:: jax.experimental.checkify 6 | 7 | API 8 | --- 9 | 10 | .. autosummary:: 11 | :toctree: _autosummary 12 | 13 | checkify 14 | check 15 | check_error 16 | Error 17 | ErrorCategory 18 | user_checks 19 | nan_checks 20 | index_checks 21 | div_checks 22 | float_checks 23 | automatic_checks 24 | all_checks 25 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/converters_eval/converters_results.md.template: -------------------------------------------------------------------------------- 1 | # JAX Converters Evaluation Results 2 | 3 | *Last generated on: {{generation_date}}* (YYYY-MM-DD) 4 | 5 | This file contains the evaluation results for all converters in table format. 6 | Please see [README.md](README.md) for more details. 7 | 8 | ## Summary Table 9 | 10 | {{table}} 11 | 12 | ## Errors 13 | 14 | {{errors}} 15 | 16 | See `models_test.py` for instructions on how to regenerate this table. 
17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.so 3 | *.egg-info 4 | *.whl 5 | build/bazel* 6 | dist/ 7 | .ipynb_checkpoints 8 | /bazel-* 9 | .jax_configure.bazelrc 10 | /tensorflow 11 | .DS_Store 12 | .mypy_cache/ 13 | .pytype/ 14 | docs/build 15 | *_pb2.py 16 | docs/notebooks/.ipynb_checkpoints/ 17 | docs/_autosummary 18 | .idea 19 | .vscode 20 | .devcontainer 21 | jax.iml 22 | 23 | # virtualenv/venv directories 24 | venv/ 25 | bin/ 26 | include/ 27 | lib/ 28 | share/ 29 | 30 | # IPU specific. 31 | *.rendered.*.cpp 32 | -------------------------------------------------------------------------------- /docs/jax.config.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: jax 2 | 3 | JAX configuration 4 | ================= 5 | 6 | .. autosummary:: 7 | :toctree: _autosummary 8 | 9 | config 10 | check_tracer_leaks 11 | checking_leaks 12 | debug_nans 13 | debug_infs 14 | default_device 15 | default_matmul_precision 16 | default_prng_impl 17 | enable_checks 18 | enable_custom_prng 19 | enable_custom_vjp_by_custom_transpose 20 | log_compiles 21 | numpy_rank_promotion 22 | transfer_guard 23 | -------------------------------------------------------------------------------- /docs/deprecation.md: -------------------------------------------------------------------------------- 1 | # Python and NumPy version support policy 2 | 3 | 4 | JAX follows NumPy's [NEP-29 deprecation policy](https://numpy.org/neps/nep-0029-deprecation_policy.html). JAX supports at least: 5 | 6 | * All minor versions of Python released 42 months prior to the project, and at minimum the two latest minor versions. 7 | 8 | * All minor versions of numpy released in the 24 months prior to the project, and at minimum the last three minor versions. 
9 | 10 | JAX may support older versions of Python and NumPy, but support for older versions may be dropped at any time. -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | ignore = 4 | C901 # object names too complex 5 | E111, E114 # four-space indents 6 | E121 # line continuations 7 | W503, W504 # line breaks around binary operators 8 | max-complexity = 18 9 | select = B,C,F,W,T4,B9,E225,E227,E228 10 | exclude = 11 | .git, 12 | build, 13 | __pycache__ 14 | per-file-ignores = 15 | docs/autodidax.py:F811 16 | jax/*.py:F401 17 | jax/**/__init__.py:F401 18 | jax/experimental/*.py:F401 19 | jax/lax/*.py:F401 20 | jax/nn/*.py:F401 21 | jax/numpy/*.py:F401 22 | jax/scipy/**/*.py:F401 23 | -------------------------------------------------------------------------------- /docs/jax.ops.rst: -------------------------------------------------------------------------------- 1 | 2 | jax.ops package 3 | =============== 4 | 5 | .. currentmodule:: jax.ops 6 | 7 | .. automodule:: jax.ops 8 | 9 | .. _syntactic-sugar-for-ops: 10 | 11 | The functions ``jax.ops.index_update``, ``jax.ops.index_add``, etc., which were 12 | deprecated in JAX 0.2.22, have been removed. Please use the 13 | :attr:`jax.numpy.ndarray.at` property on JAX arrays instead. 14 | 15 | Segment reduction operators 16 | --------------------------- 17 | 18 | .. autosummary:: 19 | :toctree: _autosummary 20 | 21 | segment_max 22 | segment_min 23 | segment_prod 24 | segment_sum 25 | -------------------------------------------------------------------------------- /docs/jax.experimental.sparse.rst: -------------------------------------------------------------------------------- 1 | jax.experimental.sparse module 2 | ============================== 3 | 4 | .. automodule:: jax.experimental.sparse 5 | 6 | API 7 | --- 8 | 9 | .. 
autosummary:: 10 | :toctree: _autosummary 11 | 12 | BCOO 13 | sparsify 14 | bcoo_broadcast_in_dim 15 | bcoo_concatenate 16 | bcoo_dot_general 17 | bcoo_dot_general_sampled 18 | bcoo_extract 19 | bcoo_fromdense 20 | bcoo_multiply_dense 21 | bcoo_multiply_sparse 22 | bcoo_reduce_sum 23 | bcoo_reshape 24 | bcoo_sort_indices 25 | bcoo_sum_duplicates 26 | bcoo_todense 27 | bcoo_transpose 28 | -------------------------------------------------------------------------------- /docs/jax.tree_util.rst: -------------------------------------------------------------------------------- 1 | jax.tree_util package 2 | ===================== 3 | 4 | .. currentmodule:: jax.tree_util 5 | 6 | .. automodule:: jax.tree_util 7 | 8 | List of Functions 9 | ----------------- 10 | 11 | .. autosummary:: 12 | :toctree: _autosummary 13 | 14 | Partial 15 | all_leaves 16 | build_tree 17 | register_pytree_node 18 | register_pytree_node_class 19 | tree_all 20 | tree_flatten 21 | tree_leaves 22 | tree_map 23 | tree_reduce 24 | tree_structure 25 | tree_transpose 26 | tree_unflatten 27 | treedef_children 28 | treedef_is_leaf 29 | treedef_tuple 30 | 31 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /jax/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jax/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /jax/_src/image/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jax/_src/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /jax/_src/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jax/_src/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /jax/_src/scipy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /jax/_src/scipy/cluster/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jax/_src/scipy/optimize/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /jax/_src/scipy/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jax/example_libraries/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /jax/_src/scipy/interpolate/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /jax/experimental/compilation_cache/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jax/experimental/gda_serialization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/examples/serving/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /docs/concurrency.rst: -------------------------------------------------------------------------------- 1 | Concurrency 2 | =========== 3 | 4 | JAX has limited support for Python concurrency. 
5 | 6 | Clients may call JAX APIs (e.g., :func:`~jax.jit` or :func:`~jax.grad`) 7 | concurrently from separate Python threads. 8 | 9 | It is not permitted to manipulate JAX traced values concurrently from multiple 10 | threads. In other words, while it is permissible to call functions that use JAX 11 | tracing (e.g., :func:`~jax.jit`) from multiple threads, you must not use 12 | threading to manipulate JAX values inside the implementation of a function 13 | ``f`` that is passed to :func:`~jax.jit`. The most likely outcome if you do this 14 | is a mysterious error from JAX. 15 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: "ubuntu-20.04" 10 | tools: 11 | python: "3.9" 12 | 13 | # Build documentation in the docs/ directory with Sphinx 14 | sphinx: 15 | configuration: docs/conf.py 16 | fail_on_warning: true 17 | 18 | # Optionally build your docs in additional formats such as PDF and ePub 19 | formats: 20 | - htmlzip 21 | 22 | # Optionally set the version of Python and requirements required to build your docs 23 | python: 24 | install: 25 | - requirements: docs/requirements.txt 26 | -------------------------------------------------------------------------------- /jaxlib/init.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .version import __version__ # noqa: F401 16 | -------------------------------------------------------------------------------- /jax/flatten_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.flatten_util import ravel_pytree 16 | -------------------------------------------------------------------------------- /jax/ipu/random/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . import prng 15 | -------------------------------------------------------------------------------- /jax/cloud_tpu_init.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.cloud_tpu_init import cloud_tpu_init 16 | -------------------------------------------------------------------------------- /jax/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.distributed import (initialize, shutdown) 16 | -------------------------------------------------------------------------------- /jax/scipy/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax.scipy.sparse import linalg as linalg 16 | -------------------------------------------------------------------------------- /jax/_src/lib/mlir/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # flake8: noqa: F401 16 | import jaxlib.mlir.ir as ir 17 | -------------------------------------------------------------------------------- /jax/dlpack.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.dlpack import (to_dlpack, from_dlpack, SUPPORTED_DTYPES) 16 | -------------------------------------------------------------------------------- /jax/scipy/cluster/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax.scipy.cluster import vq as vq 16 | -------------------------------------------------------------------------------- /jax/scipy/cluster/vq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.cluster.vq import vq as vq 16 | -------------------------------------------------------------------------------- /jax/custom_transpose.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.custom_transpose import ( 16 | custom_transpose, 17 | ) 18 | -------------------------------------------------------------------------------- /jax/scipy/fft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.fft import ( 16 | dct as dct, 17 | dctn as dctn, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/scipy/stats/nbinom.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.nbinom import ( 16 | logpmf, 17 | pmf, 18 | ) 19 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | show_error_codes = True 3 | disable_error_code = attr-defined 4 | 5 | [mypy-absl.*] 6 | ignore_missing_imports = True 7 | [mypy-colorama.*] 8 | ignore_missing_imports = True 9 | [mypy-numpy.*] 10 | ignore_missing_imports = True 11 | [mypy-opt_einsum.*] 12 | ignore_missing_imports = True 13 | [mypy-scipy.*] 14 | ignore_missing_imports = True 15 | [mypy-jax.interpreters.autospmd] 16 | ignore_errors = True 17 | [mypy-jax.lax.lax_parallel] 18 | ignore_errors = True 19 | [mypy-jax.experimental.jax2tf.tests.primitive_harness] 20 | ignore_errors = True 21 | [mypy-libtpu.*] 22 | ignore_missing_imports = True 23 | [mypy-jaxlib.mlir.*] 24 | ignore_missing_imports = True 25 | [mypy-iree.*] 26 | ignore_missing_imports = True 27 | -------------------------------------------------------------------------------- /jax/scipy/ndimage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.ndimage import ( 16 | map_coordinates as map_coordinates, 17 | ) 18 | -------------------------------------------------------------------------------- /jax/scipy/stats/t.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.t import ( 16 | logpdf as logpdf, 17 | pdf as pdf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/custom_batching.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.custom_batching import ( 16 | custom_vmap, 17 | sequential_vmap, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/ipu/debug/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from .ipu_python_callback import ipu_debug_callback_custom_call 15 | -------------------------------------------------------------------------------- /jax/scipy/stats/beta.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.beta import ( 16 | logpdf as logpdf, 17 | pdf as pdf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/scipy/stats/chi2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.chi2 import ( 16 | logpdf as logpdf, 17 | pdf as pdf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/scipy/stats/expon.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.expon import ( 16 | logpdf as logpdf, 17 | pdf as pdf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/scipy/stats/gamma.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.gamma import ( 16 | logpdf as logpdf, 17 | pdf as pdf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/scipy/stats/geom.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.geom import ( 16 | logpmf as logpmf, 17 | pmf as pmf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/ipu/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from . import primitive 15 | from . import debug 16 | from . import random 17 | -------------------------------------------------------------------------------- /jax/scipy/stats/cauchy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.cauchy import ( 16 | logpdf as logpdf, 17 | pdf as pdf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/scipy/stats/pareto.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.pareto import ( 16 | logpdf as logpdf, 17 | pdf as pdf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/scipy/stats/uniform.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.uniform import ( 16 | logpdf as logpdf, 17 | pdf as pdf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # TODO(phawkins): fix users of this alias and delete this file. 16 | 17 | from jax._src.config import config 18 | -------------------------------------------------------------------------------- /jax/scipy/stats/bernoulli.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.bernoulli import ( 16 | logpmf as logpmf, 17 | pmf as pmf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/scipy/stats/betabinom.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.betabinom import ( 16 | logpmf as logpmf, 17 | pmf as pmf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/scipy/stats/dirichlet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.dirichlet import ( 16 | logpdf as logpdf, 17 | pdf as pdf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/stages.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.stages import ( 16 | Compiled as Compiled, 17 | Lowered as Lowered, 18 | Wrapped as Wrapped, 19 | ) 20 | -------------------------------------------------------------------------------- /jax/scipy/stats/gennorm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.gennorm import ( 16 | cdf as cdf, 17 | logpdf as logpdf, 18 | pdf as pdf, 19 | ) 20 | -------------------------------------------------------------------------------- /jax/scipy/stats/laplace.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.laplace import ( 16 | cdf as cdf, 17 | logpdf as logpdf, 18 | pdf as pdf, 19 | ) 20 | -------------------------------------------------------------------------------- /jax/scipy/stats/poisson.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.poisson import ( 16 | logpmf as logpmf, 17 | pmf as pmf, 18 | cdf as cdf, 19 | ) 20 | -------------------------------------------------------------------------------- /jax/scipy/sparse/linalg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.sparse.linalg import ( 16 | cg as cg, 17 | gmres as gmres, 18 | bicgstab as bicgstab, 19 | ) 20 | -------------------------------------------------------------------------------- /jax/scipy/stats/multivariate_normal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.multivariate_normal import ( 16 | logpdf as logpdf, 17 | pdf as pdf, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/_src/lax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | from jax._src import traceback_util 18 | traceback_util.register_exclusion(os.path.dirname(__file__)) 19 | -------------------------------------------------------------------------------- /jax/interpreters/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | from jax._src import traceback_util 18 | traceback_util.register_exclusion(os.path.dirname(__file__)) 19 | -------------------------------------------------------------------------------- /examples/jax_cpp/prog.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example function to be jitted.""" 16 | import jax.numpy as jnp 17 | 18 | def fn(x, y, z):  # Example computation: jnp.dot(x, y) divided by z; input shapes/dtypes are not constrained here. 19 | return jnp.dot(x, y) / z 20 | -------------------------------------------------------------------------------- /jax/scipy/optimize/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.optimize.minimize import ( 16 | minimize as minimize, 17 | OptimizeResults as OptimizeResults, 18 | ) 19 | -------------------------------------------------------------------------------- /jax/ad_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.ad_checkpoint import ( 16 | checkpoint, 17 | checkpoint_policies, 18 | checkpoint_name, 19 | print_saved_residuals, 20 | remat, 21 | ) 22 | -------------------------------------------------------------------------------- /jax/scipy/stats/norm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.norm import ( 16 | cdf as cdf, 17 | logcdf as logcdf, 18 | logpdf as logpdf, 19 | pdf as pdf, 20 | ppf as ppf, 21 | ) 22 | -------------------------------------------------------------------------------- /jax/scipy/stats/logistic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.stats.logistic import ( 16 | cdf as cdf, 17 | isf as isf, 18 | logpdf as logpdf, 19 | pdf as pdf, 20 | ppf as ppf, 21 | sf as sf, 22 | ) 23 | -------------------------------------------------------------------------------- /docs/jax.profiler.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: jax.profiler 2 | 3 | jax.profiler module 4 | =================== 5 | 6 | .. 
automodule:: jax.profiler 7 | 8 | Tracing and time profiling 9 | -------------------------- 10 | 11 | :doc:`profiling` describes how to make use of JAX's tracing and time profiling 12 | features. 13 | 14 | .. autosummary:: 15 | :toctree: _autosummary 16 | 17 | start_server 18 | start_trace 19 | stop_trace 20 | trace 21 | annotate_function 22 | TraceAnnotation 23 | StepTraceAnnotation 24 | 25 | 26 | Device memory profiling 27 | ----------------------- 28 | 29 | See :doc:`device_memory_profiling` for an introduction to JAX's device memory 30 | profiling features. 31 | 32 | .. autosummary:: 33 | :toctree: _autosummary 34 | 35 | device_memory_profile 36 | save_device_memory_profile 37 | -------------------------------------------------------------------------------- /docs/jax.nn.rst: -------------------------------------------------------------------------------- 1 | 2 | jax.nn package 3 | ================= 4 | 5 | .. currentmodule:: jax.nn 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | jax.nn.initializers 11 | 12 | .. automodule:: jax.nn 13 | 14 | 15 | Activation functions 16 | ------------------------ 17 | 18 | .. autosummary:: 19 | :toctree: _autosummary 20 | 21 | relu 22 | relu6 23 | sigmoid 24 | softplus 25 | soft_sign 26 | silu 27 | swish 28 | log_sigmoid 29 | leaky_relu 30 | hard_sigmoid 31 | hard_silu 32 | hard_swish 33 | hard_tanh 34 | elu 35 | celu 36 | selu 37 | gelu 38 | glu 39 | 40 | Other functions 41 | --------------- 42 | 43 | .. autosummary:: 44 | :toctree: _autosummary 45 | 46 | softmax 47 | log_softmax 48 | logsumexp 49 | normalize 50 | one_hot 51 | -------------------------------------------------------------------------------- /jax/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.ops.scatter import ( 16 | segment_sum as segment_sum, 17 | segment_prod as segment_prod, 18 | segment_min as segment_min, 19 | segment_max as segment_max, 20 | ) 21 | -------------------------------------------------------------------------------- /docs/jax.lib.rst: -------------------------------------------------------------------------------- 1 | jax.lib package 2 | =============== 3 | The `jax.lib` package is a set of internal tools and types for bridging between 4 | JAX's Python frontend and its XLA backend. 5 | 6 | jax.lib.xla_bridge 7 | ------------------ 8 | 9 | .. currentmodule:: jax.lib.xla_bridge 10 | 11 | .. autosummary:: 12 | :toctree: _autosummary 13 | 14 | default_backend 15 | device_count 16 | get_backend 17 | get_compile_options 18 | local_device_count 19 | process_index 20 | 21 | jax.lib.xla_client 22 | ------------------ 23 | 24 | .. currentmodule:: jaxlib.xla_client 25 | 26 | .. autosummary:: 27 | :toctree: _autosummary 28 | 29 | jax.lib.xla_extension 30 | --------------------- 31 | 32 | .. currentmodule:: jaxlib.xla_extension 33 | 34 | .. autosummary:: 35 | :toctree: _autosummary 36 | 37 | Device 38 | TpuDevice 39 | -------------------------------------------------------------------------------- /jax/debug.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from jax._src.debugging import debug_callback as callback 15 | from jax._src.debugging import debug_print as print 16 | from jax._src.debugging import DebugEffect 17 | from jax._src.debugger import breakpoint 18 | -------------------------------------------------------------------------------- /jax/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from jax.interpreters.pxla import PartitionSpec as PartitionSpec 16 | from jax.experimental.x64_context import ( 17 | enable_x64 as enable_x64, 18 | disable_x64 as disable_x64, 19 | ) 20 | -------------------------------------------------------------------------------- /jax/scipy/interpolate/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Already-deprecated namespaces that will be removed in SciPy v2.0.0 16 | 17 | from jax._src.third_party.scipy.interpolate import ( 18 | RegularGridInterpolator as RegularGridInterpolator) 19 | -------------------------------------------------------------------------------- /jax/abstract_arrays.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # TODO(phawkins): fix users of these aliases and delete this file. 16 | 17 | from jax._src.abstract_arrays import array_types 18 | from jax.core import ( 19 | ShapedArray, 20 | raise_to_shaped, 21 | ) 22 | -------------------------------------------------------------------------------- /jax/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # flake8: noqa: F401 16 | from jax._src.lib import ( 17 | version_str as __version__, 18 | xla_client as xla_client, 19 | xla_extension as xla_extension, 20 | ) 21 | from jax.lib import xla_bridge as xla_bridge 22 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax.experimental.jax2tf.jax2tf import (convert, dtype_of_val, 16 | split_to_logical_devices, PolyShape) 17 | from jax.experimental.jax2tf.call_tf import call_tf 18 | -------------------------------------------------------------------------------- /jax/api_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from jax._src.api_util import ( 17 | argnums_partial, 18 | donation_vector, 19 | flatten_axes, 20 | flatten_fun, 21 | flatten_fun_nokwargs, 22 | rebase_donate_argnums, 23 | safe_map, 24 | shaped_abstractify, 25 | ) 26 | -------------------------------------------------------------------------------- /jaxlib/pocketfft_kernels.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | #pragma once 16 | #include "tensorflow/compiler/xla/service/custom_call_status.h" 17 | 18 | namespace jax { 19 | 20 | void PocketFft(void* out, void** in, XlaCustomCallStatus*); // Custom-call style kernel entry point (output buffer, array of input buffers, XLA custom-call status); presumably registered as an XLA CPU custom-call target, confirm in pocketfft_kernels.cc. 21 | 22 | } // namespace jax 23 | -------------------------------------------------------------------------------- /jax/experimental/compilation_cache/cache_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | 17 | class CacheInterface(ABC): # Abstract key/value storage backend for the compilation cache: str keys, bytes values; concrete caches must implement get() and put(). 18 | @abstractmethod 19 | def get(self, key: str): # Look up the entry stored under `key`. NOTE(review): the return contract (presumably the stored bytes, or None when absent) is left to subclasses; confirm. 20 | pass 21 | 22 | @abstractmethod 23 | def put(self, key: str, value: bytes): # Store the serialized `value` under `key`. 24 | pass 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to experimental JAX on IPU 2 | 3 | ## Issues 4 | 5 | To help respond effectively to issues: 6 | 7 | * Please tag your issue with `bug`, `feature request`, or `question`. 8 | * Please include the version of JAX you are running, and the IPU hardware type. 9 | * Please provide a minimal reproducing example. 10 | 11 | ## Pull Requests 12 | 13 | We'd love to accept your patches and contributions to this project. There are 14 | just a few small guidelines you need to follow. 15 | 16 | **Contributor License Agreement** 17 | 18 | You need to agree to Graphcore contributor license agreement: 19 | 20 | **Code reviews** 21 | 22 | All submissions require review. We 23 | use GitHub pull requests for this purpose. Consult 24 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 25 | information on using pull requests.
26 | -------------------------------------------------------------------------------- /jax/scipy/signal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.scipy.signal import ( 16 | convolve as convolve, 17 | convolve2d as convolve2d, 18 | correlate as correlate, 19 | correlate2d as correlate2d, 20 | detrend as detrend, 21 | csd as csd, 22 | istft as istft, 23 | stft as stft, 24 | welch as welch, 25 | ) 26 | -------------------------------------------------------------------------------- /docs/jax.experimental.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: jax.experimental 2 | 3 | jax.experimental package 4 | ======================== 5 | 6 | ``jax.experimental.optix`` has been moved into its own Python package 7 | (https://github.com/deepmind/optax). 8 | 9 | ``jax.experimental.ann`` has been moved into ``jax.lax``. 10 | 11 | Experimental Modules 12 | -------------------- 13 | 14 | .. toctree:: 15 | :maxdepth: 1 16 | 17 | jax.experimental.checkify 18 | jax.experimental.global_device_array 19 | jax.experimental.host_callback 20 | jax.experimental.maps 21 | jax.experimental.pjit 22 | jax.experimental.sparse 23 | jax.experimental.jet 24 | 25 | Experimental APIs 26 | ----------------- 27 | 28 | .. 
autosummary:: 29 | :toctree: _autosummary 30 | 31 | enable_x64 32 | disable_x64 33 | 34 | jax.experimental.checkify.checkify 35 | jax.experimental.checkify.check 36 | jax.experimental.checkify.check_error 37 | -------------------------------------------------------------------------------- /jax/prng.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.prng import ( 16 | PRNGImpl as PRNGImpl, 17 | seed_with_impl as seed_with_impl, 18 | threefry2x32_p as threefry2x32_p, 19 | threefry_2x32 as threefry_2x32, 20 | threefry_prng_impl as threefry_prng_impl, 21 | rbg_prng_impl as rbg_prng_impl, 22 | unsafe_rbg_prng_impl as unsafe_rbg_prng_impl, 23 | ) 24 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """pytest configuration""" 15 | 16 | import pytest 17 | 18 | 19 | @pytest.fixture(autouse=True) 20 | def add_imports(doctest_namespace): # Autouse fixture: pre-populate pytest's doctest namespace so doctest examples can use jax, lax, jnp and np without importing them first. 21 | import jax 22 | import numpy 23 | doctest_namespace["jax"] = jax 24 | doctest_namespace["lax"] = jax.lax 25 | doctest_namespace["jnp"] = jax.numpy 26 | doctest_namespace["np"] = numpy 27 | -------------------------------------------------------------------------------- /jax/_src/debugger/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | from jax._src.debugger.core import breakpoint 15 | from jax._src.debugger import cli_debugger 16 | from jax._src.debugger import colab_debugger 17 | from jax._src.debugger import web_debugger 18 | 19 | del cli_debugger # For registration only 20 | del colab_debugger # For registration only 21 | del web_debugger # For registration only 22 | -------------------------------------------------------------------------------- /jax/_src/state/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module for state.""" 15 | from jax._src.state.types import ShapedArrayRef, StateEffect 16 | from jax._src.state.primitives import (ref_get, ref_set, ref_swap, 17 | ref_addupdate, get_p, swap_p, 18 | addupdate_p) 19 | from jax._src.state.discharge import discharge_state 20 | -------------------------------------------------------------------------------- /docs/jax.nn.initializers.rst: -------------------------------------------------------------------------------- 1 | 2 | jax.nn.initializers package 3 | =========================== 4 | 5 | .. currentmodule:: jax.nn.initializers 6 | 7 | .. automodule:: jax.nn.initializers 8 | 9 | 10 | Initializers 11 | ------------ 12 | 13 | This module provides common neural network layer initializers, 14 | consistent with definitions used in Keras and Sonnet. 
15 | 16 | An initializer is a function that takes three arguments: 17 | ``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and 18 | data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random 19 | key used when generating random numbers to initialize the array. 20 | 21 | .. autosummary:: 22 | :toctree: _autosummary 23 | 24 | constant 25 | delta_orthogonal 26 | glorot_normal 27 | glorot_uniform 28 | he_normal 29 | he_uniform 30 | lecun_normal 31 | lecun_uniform 32 | normal 33 | ones 34 | orthogonal 35 | uniform 36 | variance_scaling 37 | zeros 38 | -------------------------------------------------------------------------------- /jax/_src/custom_api_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | _custom_wrapper_types = set() 17 | 18 | def register_custom_decorator_type(cls): 19 | _custom_wrapper_types.add(cls) 20 | return cls 21 | 22 | def forward_attr(self_, name): 23 | if name.startswith('def') and type(self_.fun) in _custom_wrapper_types: 24 | return getattr(self_.fun, name) 25 | else: 26 | raise AttributeError 27 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/JAX2TF_getting_started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "JAX2TF_getting_started.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "toc_visible": true 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "name": "python3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "mB5eSZXZIO9W" 21 | }, 22 | "source": [ 23 | "JAX-TensorFlow interoperation with JAX2TF\n", 24 | "===========================================\n", 25 | "\n", 26 | "Link: go/jax2tf-colab\n", 27 | "\n", 28 | "The JAX2TF colab has been deprecated, and the example code has\n", 29 | "been moved to [jax2tf/examples](https://github.com/google/jax/tree/main/jax/experimental/jax2tf/examples). \n" 30 | ] 31 | } 32 | ] 33 | } 34 | -------------------------------------------------------------------------------- /jax/image/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Image manipulation functions. 16 | 17 | More image manipulation functions can be found in libraries built on top of 18 | JAX, such as `PIX`_. 19 | 20 | .. _PIX: https://github.com/deepmind/dm_pix 21 | """ 22 | 23 | from jax._src.image.scale import ( 24 | resize as resize, 25 | ResizeMethod as ResizeMethod, 26 | scale_and_translate as scale_and_translate, 27 | ) 28 | -------------------------------------------------------------------------------- /jax/lib/xla_bridge.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # flake8: noqa: F401 16 | from jax._src.lib.xla_bridge import ( 17 | default_backend as default_backend, 18 | device_count as device_count, 19 | get_backend as get_backend, 20 | get_compile_options as get_compile_options, 21 | local_device_count as local_device_count, 22 | process_index as process_index, 23 | xla_client as xla_client, 24 | _backends as _backends, 25 | ) 26 | -------------------------------------------------------------------------------- /jax/scipy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax.scipy import interpolate as interpolate 16 | from jax.scipy import linalg as linalg 17 | from jax.scipy import ndimage as ndimage 18 | from jax.scipy import signal as signal 19 | from jax.scipy import sparse as sparse 20 | from jax.scipy import special as special 21 | from jax.scipy import stats as stats 22 | from jax.scipy import fft as fft 23 | from jax.scipy import cluster as cluster 24 | -------------------------------------------------------------------------------- /jax/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.errors import ( 16 | JAXTypeError as JAXTypeError, 17 | JAXIndexError as JAXIndexError, 18 | ConcretizationTypeError as ConcretizationTypeError, 19 | NonConcreteBooleanIndexError as NonConcreteBooleanIndexError, 20 | TracerArrayConversionError as TracerArrayConversionError, 21 | TracerIntegerConversionError as TracerIntegerConversionError, 22 | UnexpectedTracerError as UnexpectedTracerError, 23 | ) 24 | -------------------------------------------------------------------------------- /docs/jax.example_libraries.rst: -------------------------------------------------------------------------------- 1 | jax.example_libraries package 2 | ============================= 3 | 4 | JAX provides some small, experimental libraries for machine learning. These 5 | libraries are in part about providing tools and in part about serving as 6 | examples for how to build such libraries using JAX. Each one is only <300 source 7 | lines of code, so take a look inside and adapt them as you need! 8 | 9 | .. note:: 10 | Each mini-library is meant to be an *inspiration*, but not a prescription. 11 | 12 | To serve that purpose, it is best to keep their code samples minimal; so we 13 | generally **will not merge PRs** adding new features. Instead, please send your 14 | lovely pull requests and design ideas to more fully-featured libraries like 15 | `Haiku`_ or `Flax`_. 16 | 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | 21 | jax.example_libraries.optimizers 22 | jax.example_libraries.stax 23 | 24 | .. 
automodule:: jax.example_libraries 25 | 26 | 27 | .. _Haiku: https://github.com/deepmind/dm-haiku 28 | .. _Flax: https://github.com/google/flax 29 | -------------------------------------------------------------------------------- /jax/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.util import ( 16 | HashableFunction as HashableFunction, 17 | as_hashable_function as as_hashable_function, 18 | cache as cache, 19 | safe_map as safe_map, 20 | safe_zip as safe_zip, 21 | split_dict as split_dict, 22 | split_list as split_list, 23 | split_merge as split_merge, 24 | subvals as subvals, 25 | toposort as toposort, 26 | unzip2 as unzip2, 27 | wrap_name as wrap_name, 28 | wraps as wraps, 29 | ) 30 | -------------------------------------------------------------------------------- /jax/lax/linalg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.lax.linalg import ( 16 | cholesky, 17 | cholesky_p, 18 | eig, 19 | eig_p, 20 | eigh, 21 | eigh_p, 22 | lu, 23 | lu_p, 24 | lu_pivots_to_permutation, 25 | qr, 26 | qr_p, 27 | svd, 28 | svd_p, 29 | triangular_solve, 30 | triangular_solve_p, 31 | tridiagonal_solve, 32 | tridiagonal_solve_p, 33 | schur, 34 | schur_p 35 | ) 36 | 37 | 38 | from jax._src.lax.qdwh import ( 39 | qdwh as qdwh 40 | ) 41 | -------------------------------------------------------------------------------- /jax/profiler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from jax._src.profiler import ( 16 | StepTraceAnnotation as StepTraceAnnotation, 17 | TraceAnnotation as TraceAnnotation, 18 | device_memory_profile as device_memory_profile, 19 | save_device_memory_profile as save_device_memory_profile, 20 | start_server as start_server, 21 | stop_server as stop_server, 22 | start_trace as start_trace, 23 | stop_trace as stop_trace, 24 | trace as trace, 25 | annotate_function as annotate_function, 26 | ) 27 | -------------------------------------------------------------------------------- /jax/ipu/primitive/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os.path 15 | 16 | from .cppimport_utils import cppimport_append_include_dirs 17 | from .ipu_custom_primitive_utils import ipu_mlir_lowering_custom_primitive, PrimitiveMetadata 18 | from .xla_utils import dtype_to_primitive_type, dtype_to_tf_datatype_enum, xla_shape_to_aval 19 | 20 | # Update default `cppimport` library dirs to include headers in this directory. 
21 | cppimport_append_include_dirs([os.path.dirname(__file__)]) 22 | -------------------------------------------------------------------------------- /jax/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This file is included as part of both jax and jaxlib. It is also 16 | # eval()-ed by setup.py, so it should not have any dependencies. 17 | 18 | __version__ = "0.3.16" 19 | _minimum_jaxlib_version = "0.3.14" 20 | 21 | def _version_as_tuple(version_str): # Convert a dotted version string to a tuple of ints for ordering comparisons, e.g. "0.3.16" -> (0, 3, 16). 22 | return tuple(int(i) for i in version_str.split(".") if i.isdigit()) # NOTE(review): components that are not purely digits are dropped entirely, e.g. "0.3.16rc1" -> (0, 3); confirm this is acceptable for pre-release versions. 23 | 24 | __version_info__ = _version_as_tuple(__version__) 25 | _minimum_jaxlib_version_info = _version_as_tuple(_minimum_jaxlib_version) 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Install the pre-commit hooks below with 2 | # 'pre-commit install' 3 | 4 | # Auto-update the version of the hooks with 5 | # 'pre-commit autoupdate' 6 | 7 | # Run the hooks on all files with 8 | # 'pre-commit run --all' 9 | 10 | repos: 11 | - repo: https://github.com/pre-commit/mirrors-clang-format 12 | rev: 'v15.0.7' 13 | hooks: 14 | - id: clang-format 15 | # Automatic formatting only on IPU files.
16 | files: .*/ipu 17 | 18 | - repo: https://github.com/google/yapf 19 | rev: v0.32.0 20 | hooks: 21 | - id: yapf 22 | additional_dependencies: [toml] 23 | # Automatic formatting only on IPU files. 24 | files: .*/ipu 25 | 26 | - repo: https://github.com/pycqa/flake8 27 | rev: '4.0.1' 28 | hooks: 29 | - id: flake8 30 | 31 | - repo: https://github.com/pre-commit/mirrors-mypy 32 | rev: 'v0.942' 33 | hooks: 34 | - id: mypy 35 | files: jax/ 36 | additional_dependencies: [types-requests==2.27.16, jaxlib==0.3.5] 37 | 38 | - repo: https://github.com/mwouts/jupytext 39 | rev: v1.13.8 40 | hooks: 41 | - id: jupytext 42 | args: [--sync] 43 | -------------------------------------------------------------------------------- /docs/jax.random.rst: -------------------------------------------------------------------------------- 1 | jax.random package 2 | ================== 3 | 4 | .. automodule:: jax.random 5 | 6 | List of Available Functions 7 | --------------------------- 8 | 9 | .. Generate the list below as follows: 10 | >>> from jax import random 11 | >>> fns = (x for x in sorted(dir(random)) if x != 'threefry_2x32') 12 | >>> fns = (x for x in fns if callable(getattr(random, x))) 13 | >>> print('\n'.join(' ' + x for x in fns)) # doctest: +SKIP 14 | 15 | .. 
autosummary:: 16 | :toctree: _autosummary 17 | 18 | PRNGKey 19 | ball 20 | bernoulli 21 | beta 22 | categorical 23 | cauchy 24 | choice 25 | dirichlet 26 | double_sided_maxwell 27 | exponential 28 | fold_in 29 | gamma 30 | generalized_normal 31 | gumbel 32 | laplace 33 | loggamma 34 | logistic 35 | maxwell 36 | multivariate_normal 37 | normal 38 | orthogonal 39 | pareto 40 | permutation 41 | poisson 42 | rademacher 43 | randint 44 | shuffle 45 | split 46 | t 47 | truncated_normal 48 | uniform 49 | weibull_min 50 | 51 | -------------------------------------------------------------------------------- /jax/dtypes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.dtypes import ( 16 | _jax_types, # TODO(phawkins): fix users and remove? 17 | bfloat16 as bfloat16, 18 | canonicalize_dtype as canonicalize_dtype, 19 | finfo, # TODO(phawkins): switch callers to jnp.finfo? 20 | float0 as float0, 21 | iinfo, # TODO(phawkins): switch callers to jnp.iinfo? 22 | issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype? 
23 | result_type as result_type, 24 | scalar_type_of as scalar_type_of, 25 | ) 26 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | error 4 | ignore:No GPU/TPU found, falling back to CPU.:UserWarning 5 | ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning 6 | # xmap 7 | ignore:xmap is an experimental feature and probably has bugs! 8 | # The rest are for experimental/jax_to_tf 9 | ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning 10 | ignore:can't resolve package from __spec__ or __package__:ImportWarning 11 | ignore:Using or importing the ABCs.*:DeprecationWarning 12 | # jax2tf tests due to mix of JAX and TF 13 | ignore:numpy.ufunc size changed 14 | ignore:.*experimental feature 15 | ignore:index.*is deprecated.*:DeprecationWarning 16 | ignore:jax.experimental.* is deprecated, import jax.example_libraries.* instead:FutureWarning 17 | # numpy uses distutils which is deprecated 18 | ignore:The distutils.* is deprecated.*:DeprecationWarning 19 | ignore:`sharded_jit` is deprecated. Please use `pjit` instead.*:DeprecationWarning 20 | doctest_optionflags = NUMBER NORMALIZE_WHITESPACE 21 | addopts = --doctest-glob="*.rst" 22 | -------------------------------------------------------------------------------- /jax/experimental/checkify.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.checkify import ( 16 | Error as Error, 17 | ErrorCategory as ErrorCategory, 18 | all_checks as all_checks, 19 | automatic_checks as automatic_checks, 20 | check as check, 21 | check_error as check_error, 22 | checkify as checkify, 23 | div_checks as div_checks, 24 | float_checks as float_checks, 25 | index_checks as index_checks, 26 | init_error as init_error, 27 | nan_checks as nan_checks, 28 | user_checks as user_checks, 29 | ) 30 | -------------------------------------------------------------------------------- /jax/custom_derivatives.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from jax._src.custom_derivatives import ( 16 | _initial_style_jaxpr, 17 | _sum_tangents, 18 | _zeros_like_pytree, 19 | closure_convert as closure_convert, 20 | custom_gradient as custom_gradient, 21 | custom_jvp as custom_jvp, 22 | custom_jvp_call_p as custom_jvp_call_p, 23 | custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p, 24 | custom_vjp as custom_vjp, 25 | custom_vjp_call_p as custom_vjp_call_p, 26 | custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p, 27 | linear_call as linear_call, 28 | ) 29 | -------------------------------------------------------------------------------- /jaxlib/pocketfft.fbs: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | namespace jax; 17 | 18 | enum PocketFftDtype : byte { 19 | COMPLEX64 = 0, 20 | COMPLEX128 = 1, 21 | } 22 | 23 | enum PocketFftType : byte { 24 | C2C = 0, 25 | C2R = 1, 26 | R2C = 2, 27 | } 28 | 29 | table PocketFftDescriptor { 30 | dtype:PocketFftDtype; 31 | fft_type:PocketFftType; 32 | shape:[uint64]; 33 | strides_in:[uint64]; 34 | strides_out:[uint64]; 35 | axes:[uint32]; 36 | forward:bool; 37 | scale:double; 38 | } 39 | 40 | root_type PocketFftDescriptor; 41 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/examples/tf_js/quickdraw/third_party/zaidalyafeai.github.io/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Zaid Alyafeai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 3 | # Script to be sourced on launch of the Gradient Notebook 4 | 5 | DETECTED_NUMBER_OF_IPUS=$(python .gradient/available_ipus.py) 6 | if [[ "$1" == "test" ]]; then 7 | IPU_ARG="${DETECTED_NUMBER_OF_IPUS}" 8 | else 9 | IPU_ARG=${1:-"${DETECTED_NUMBER_OF_IPUS}"} 10 | fi 11 | 12 | export NUM_AVAILABLE_IPU=${IPU_ARG} 13 | export GRAPHCORE_POD_TYPE="pod${IPU_ARG}" 14 | 15 | export POPLAR_EXECUTABLE_CACHE_DIR="/tmp/exe_cache" 16 | export DATASET_DIR="/tmp/dataset_cache" 17 | export CHECKPOINT_DIR="/tmp/checkpoints" 18 | 19 | # mounted public dataset directory (path in the container) 20 | # in the Paperspace environment this would be ="/datasets" 21 | export PUBLIC_DATASET_DIR="/datasets" 22 | 23 | export POPTORCH_CACHE_DIR="${POPLAR_EXECUTABLE_CACHE_DIR}" 24 | export POPTORCH_LOG_LEVEL=ERR 25 | export RDMAV_FORK_SAFE=1 26 | 27 | export PIP_DISABLE_PIP_VERSION_CHECK=1 CACHE_DIR=/tmp 28 | jupyter lab --allow-root --ip=0.0.0.0 --no-browser --ServerApp.trust_xheaders=True \ 29 | --ServerApp.disable_check_xsrf=False --ServerApp.allow_remote_access=True \ 30 | --ServerApp.allow_origin='*' --ServerApp.allow_credentials=True 31 | -------------------------------------------------------------------------------- /tests/benchmarks/xla.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import enum 17 | import pytest 18 | import numpy as np 19 | 20 | from jax import numpy as jnp 21 | from jax.interpreters import xla 22 | 23 | 24 | class AnEnum(enum.IntEnum): 25 | A = 123 26 | B = 456 27 | 28 | 29 | _abstractify_args = [ 30 | 3, 31 | 3.5, 32 | np.int32(3), 33 | np.uint32(7), 34 | np.random.randn(3, 4, 5, 6), 35 | np.arange(100, dtype=np.float32), 36 | jnp.int64(-3), 37 | jnp.array([1, 2, 3]), 38 | AnEnum.B, 39 | ] 40 | 41 | @pytest.mark.parametrize("arg", _abstractify_args) 42 | def test_abstractify(benchmark, arg): 43 | benchmark(xla.abstractify, arg) 44 | -------------------------------------------------------------------------------- /.github/workflows/jax-ipu-release-public.yaml: -------------------------------------------------------------------------------- 1 | name: CI_jax_ipu_release_public 2 | 3 | env: 4 | GIT_MAIN_BRANCH: "jax-v0.3.16-ipu" 5 | 6 | # Controls when the workflow will run. 7 | on: 8 | # Only main IPU branch 9 | push: 10 | branches: [ "jax-v0.3.16-ipu" ] 11 | pull_request: 12 | branches: [ "jax-v0.3.16-ipu" ] 13 | types: 14 | - closed 15 | release: 16 | types: [edited, deleted, published] 17 | 18 | # Allows you to run this workflow manually from the Actions tab. 
19 | workflow_dispatch: 20 | 21 | jobs: 22 | public_pages: 23 | if: github.repository == 'graphcore-research/jax-experimental' 24 | runs-on: ubuntu-latest 25 | timeout-minutes: 10 26 | steps: 27 | - uses: actions/checkout@v3 28 | - name: Set up Python 3.8 29 | uses: actions/setup-python@v4 30 | with: 31 | python-version: 3.8 32 | - name: Install dependencies 33 | run: | 34 | pip install github3.py 35 | - name: Build HTML wheels page 36 | run: | 37 | python ./build/ipu/generate_wheels_html.py 38 | ls _site/ 39 | - name: Publish pages 40 | uses: Cecilapp/GitHub-Pages-deploy@v3 41 | env: { GITHUB_TOKEN: "${{ github.token }}" } 42 | with: 43 | build_dir: _site 44 | -------------------------------------------------------------------------------- /jax/tools/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | load( 16 | "//jaxlib:jax.bzl", 17 | "py_deps", 18 | ) 19 | 20 | licenses(["notice"]) 21 | 22 | package(default_visibility = ["//visibility:public"]) 23 | 24 | py_library( 25 | name = "jax_to_ir", 26 | srcs = ["jax_to_ir.py"], 27 | tags = [ 28 | "ignore_for_dep=third_party.py.jax.experimental.jax2tf", 29 | "ignore_for_dep=third_party.py.tensorflow", 30 | ], 31 | deps = [ 32 | "//jax", 33 | ], 34 | ) 35 | 36 | py_library( 37 | name = "jax_to_ir_with_tensorflow", 38 | srcs = ["jax_to_ir.py"], 39 | deps = [ 40 | "//jax", 41 | "//jax/experimental/jax2tf", 42 | ] + py_deps("tensorflow"), 43 | ) 44 | -------------------------------------------------------------------------------- /tests/xla_interpreter_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from absl.testing import absltest 16 | 17 | import jax 18 | from jax._src import test_util as jtu 19 | from jax._src import dispatch 20 | 21 | 22 | class XlaInterpreterTest(jtu.JaxTestCase): 23 | 24 | def test_prune_jit_args(self): 25 | def f(*args): 26 | return args[0] 27 | 28 | closed_jaxpr = jax.make_jaxpr(f)(*range(10)) 29 | pruned_jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs( 30 | closed_jaxpr.jaxpr) 31 | assert len(pruned_jaxpr.invars) == 1 32 | assert kept_const_idx == set() 33 | assert kept_var_idx == {0} 34 | 35 | 36 | if __name__ == '__main__': 37 | absltest.main(testLoader=jtu.JaxTestLoader()) 38 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/expon.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import scipy.stats as osp_stats 16 | 17 | from jax import lax 18 | from jax._src.numpy.util import _wraps 19 | from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf 20 | 21 | 22 | @_wraps(osp_stats.expon.logpdf, update_doc=False) 23 | def logpdf(x, loc=0, scale=1): 24 | x, loc, scale = _promote_args_inexact("expon.logpdf", x, loc, scale) 25 | log_scale = lax.log(scale) 26 | linear_term = lax.div(lax.sub(x, loc), scale) 27 | log_probs = lax.neg(lax.add(linear_term, log_scale)) 28 | return where(lax.lt(x, loc), -inf, log_probs) 29 | 30 | @_wraps(osp_stats.expon.pdf, update_doc=False) 31 | def pdf(x, loc=0, scale=1): 32 | return lax.exp(logpdf(x, loc, scale)) 33 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/gennorm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import scipy.stats as osp_stats 16 | from jax import lax 17 | from jax._src.numpy.util import _wraps 18 | from jax._src.numpy.lax_numpy import _promote_args_inexact 19 | 20 | @_wraps(osp_stats.gennorm.logpdf, update_doc=False) 21 | def logpdf(x, p): 22 | x, p = _promote_args_inexact("gennorm.logpdf", x, p) 23 | return lax.log(.5 * p) - lax.lgamma(1/p) - lax.abs(x)**p 24 | 25 | @_wraps(osp_stats.gennorm.cdf, update_doc=False) 26 | def cdf(x, p): 27 | x, p = _promote_args_inexact("gennorm.cdf", x, p) 28 | return .5 * (1 + lax.sign(x) * lax.igamma(1/p, lax.abs(x)**p)) 29 | 30 | @_wraps(osp_stats.gennorm.pdf, update_doc=False) 31 | def pdf(x, p): 32 | return lax.exp(logpdf(x, p)) 33 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/uniform.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import scipy.stats as osp_stats 17 | 18 | from jax import lax 19 | from jax._src.numpy.util import _wraps 20 | from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or 21 | 22 | 23 | @_wraps(osp_stats.uniform.logpdf, update_doc=False) 24 | def logpdf(x, loc=0, scale=1): 25 | x, loc, scale = _promote_args_inexact("uniform.logpdf", x, loc, scale) 26 | log_probs = lax.neg(lax.log(scale)) 27 | return where(logical_or(lax.gt(x, lax.add(loc, scale)), 28 | lax.lt(x, loc)), 29 | -inf, log_probs) 30 | 31 | @_wraps(osp_stats.uniform.pdf, update_doc=False) 32 | def pdf(x, loc=0, scale=1): 33 | return lax.exp(logpdf(x, loc, scale)) 34 | -------------------------------------------------------------------------------- /third_party/pocketfft/workspace.bzl: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Bazel workspace for PocketFFT.""" 16 | 17 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 18 | 19 | def repo(): 20 | http_archive( 21 | name = "pocketfft", 22 | sha256 = "66eda977b195965d27aeb9d74f46e0029a6a02e75fbbc47bb554aad68615a260", 23 | strip_prefix = "pocketfft-f800d91ba695b6e19ae2687dd60366900b928002", 24 | urls = [ 25 | "https://github.com/mreineck/pocketfft/archive/f800d91ba695b6e19ae2687dd60366900b928002.tar.gz", 26 | "https://storage.googleapis.com/jax-releases/mirror/pocketfft/pocketfft-f800d91ba695b6e19ae2687dd60366900b928002.tar.gz", 27 | ], 28 | build_file = "@//third_party/pocketfft:BUILD.bazel", 29 | ) 30 | -------------------------------------------------------------------------------- /jax/_src/lib/mlir/dialects/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # flake8: noqa: F401 16 | import jaxlib.mlir.dialects.builtin as builtin 17 | import jaxlib.mlir.dialects.chlo as chlo 18 | import jaxlib.mlir.dialects.mhlo as mhlo 19 | import jaxlib.mlir.dialects.func as func 20 | 21 | try: 22 | import jaxlib.mlir.dialects.ml_program as ml_program 23 | except (ModuleNotFoundError, ImportError): 24 | # TODO(phawkins): make this unconditional when jaxlib > 0.3.14 25 | # is the minimum version. 
26 | pass 27 | 28 | try: 29 | import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor 30 | except (ModuleNotFoundError, ImportError): 31 | # TODO(ajcbik,phawkins): make this unconditional when jaxlib > 0.3.7 32 | # is the minimum version. 33 | pass 34 | 35 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/geom.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import scipy.stats as osp_stats 16 | 17 | from jax import lax 18 | from jax._src.lax.lax import _const as _lax_const 19 | from jax._src.numpy import lax_numpy as jnp 20 | from jax._src.numpy.util import _wraps 21 | from jax.scipy.special import xlog1py 22 | 23 | @_wraps(osp_stats.geom.logpmf, update_doc=False) 24 | def logpmf(k, p, loc=0): 25 | k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc) 26 | zero = _lax_const(k, 0) 27 | one = _lax_const(k, 1) 28 | x = lax.sub(k, loc) 29 | log_probs = xlog1py(lax.sub(x, one), -p) + lax.log(p) 30 | return jnp.where(lax.le(x, zero), -jnp.inf, log_probs) 31 | 32 | @_wraps(osp_stats.geom.pmf, update_doc=False) 33 | def pmf(k, p, loc=0): 34 | return jnp.exp(logpmf(k, p, loc)) 35 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | load( 16 | "//jaxlib:jax.bzl", 17 | "jax2tf_deps", 18 | "py_deps", 19 | ) 20 | 21 | licenses(["notice"]) # Apache 2 22 | 23 | package( 24 | default_visibility = ["//visibility:private"], 25 | ) 26 | 27 | py_library( 28 | name = "jax2tf", 29 | srcs = ["__init__.py"], 30 | srcs_version = "PY3", 31 | visibility = ["//visibility:public"], 32 | deps = [":jax2tf_internal"], 33 | ) 34 | 35 | py_library( 36 | name = "jax2tf_internal", 37 | srcs = [ 38 | "call_tf.py", 39 | "impl_no_xla.py", 40 | "jax2tf.py", 41 | "shape_poly.py", 42 | ], 43 | srcs_version = "PY3", 44 | deps = [ 45 | "//jax", 46 | ] + py_deps("numpy") + py_deps("tensorflow") + jax2tf_deps, 47 | ) 48 | -------------------------------------------------------------------------------- /jax/scipy/linalg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from jax._src.scipy.linalg import ( 16 | block_diag as block_diag, 17 | cholesky as cholesky, 18 | cho_factor as cho_factor, 19 | cho_solve as cho_solve, 20 | det as det, 21 | eigh as eigh, 22 | eigh_tridiagonal as eigh_tridiagonal, 23 | expm as expm, 24 | expm_frechet as expm_frechet, 25 | inv as inv, 26 | lu as lu, 27 | lu_factor as lu_factor, 28 | lu_solve as lu_solve, 29 | polar as polar, 30 | polar_unitary as polar_unitary, 31 | qr as qr, 32 | rsf2csf as rsf2csf, 33 | schur as schur, 34 | sqrtm as sqrtm, 35 | solve as solve, 36 | solve_triangular as solve_triangular, 37 | svd as svd, 38 | tril as tril, 39 | triu as triu, 40 | ) 41 | 42 | from jax._src.third_party.scipy.linalg import ( 43 | funm as funm, 44 | ) 45 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/examples/tf_js/quickdraw/third_party/zaidalyafeai.github.io/class_names.txt: -------------------------------------------------------------------------------- 1 | screwdriver 2 | wristwatch 3 | butterfly 4 | sword 5 | cat 6 | shorts 7 | eyeglasses 8 | lollipop 9 | baseball 10 | traffic_light 11 | sun 12 | helmet 13 | bridge 14 | alarm_clock 15 | drums 16 | book 17 | broom 18 | fan 19 | scissors 20 | cloud 21 | tent 22 | clock 23 | headphones 24 | bicycle 25 | stop_sign 26 | table 27 | donut 28 | umbrella 29 | smiley_face 30 | pillow 31 | bed 32 | saw 33 | light_bulb 34 | shovel 35 | bird 36 | syringe 37 | coffee_cup 38 | moon 39 | ice_cream 40 | moustache 41 | cell_phone 42 | pants 43 | anvil 44 | radio 45 | chair 46 | star 47 | door 48 | face 49 | mushroom 50 | tree 51 | rifle 52 | camera 53 | lightning 54 | flower 55 | basketball 56 | wheel 57 | hammer 58 | hat 59 | knife 60 | diving_board 61 | square 62 | cup 63 | mountain 64 | apple 65 | spoon 66 | key 67 | pencil 68 | line 69 | ladder 70 | triangle 71 | t-shirt 72 | dumbbell 73 | microphone 74 | snake 75 | sock 76 | suitcase 77 | laptop 78 | paper_clip 79 | rainbow 80 | candle 81 | 
bread 82 | spider 83 | envelope 84 | circle 85 | power_outlet 86 | tooth 87 | hot_dog 88 | frying_pan 89 | bench 90 | ceiling_fan 91 | tennis_racquet 92 | car 93 | beard 94 | axe 95 | baseball_bat 96 | pizza 97 | grapes 98 | eye 99 | cookie 100 | airplane 101 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/cauchy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import numpy as np 17 | import scipy.stats as osp_stats 18 | 19 | from jax import lax 20 | from jax._src.lax.lax import _const as _lax_const 21 | from jax._src.numpy.util import _wraps 22 | from jax._src.numpy.lax_numpy import _promote_args_inexact 23 | 24 | 25 | @_wraps(osp_stats.cauchy.logpdf, update_doc=False) 26 | def logpdf(x, loc=0, scale=1): 27 | x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale) 28 | pi = _lax_const(x, np.pi) 29 | scaled_x = lax.div(lax.sub(x, loc), scale) 30 | normalize_term = lax.log(lax.mul(pi, scale)) 31 | return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x)))) 32 | 33 | @_wraps(osp_stats.cauchy.pdf, update_doc=False) 34 | def pdf(x, loc=0, scale=1): 35 | return lax.exp(logpdf(x, loc, scale)) 36 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/converters_eval/test_models/flax/cnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Convolutional Neural Network from Flax MNIST example, see: 15 | 16 | https://github.com/google/flax/tree/main/examples/mnist 17 | """ 18 | 19 | from flax import linen as nn 20 | 21 | 22 | class CNN(nn.Module): 23 | """A simple CNN model.""" 24 | 25 | @nn.compact 26 | def __call__(self, x): 27 | x = nn.Conv(features=32, kernel_size=(3, 3))(x) 28 | x = nn.relu(x) 29 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 30 | x = nn.Conv(features=64, kernel_size=(3, 3))(x) 31 | x = nn.relu(x) 32 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 33 | x = x.reshape((x.shape[0], -1)) # flatten 34 | x = nn.Dense(features=256)(x) 35 | x = nn.relu(x) 36 | x = nn.Dense(features=10)(x) 37 | return x 38 | -------------------------------------------------------------------------------- /jax/scipy/special.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from jax._src.scipy.special import ( 16 | betainc as betainc, 17 | betaln as betaln, 18 | digamma as digamma, 19 | entr as entr, 20 | erf as erf, 21 | erfc as erfc, 22 | erfinv as erfinv, 23 | exp1 as exp1, 24 | expi as expi, 25 | expit as expit, 26 | expn as expn, 27 | gammainc as gammainc, 28 | gammaincc as gammaincc, 29 | gammaln as gammaln, 30 | i0 as i0, 31 | i0e as i0e, 32 | i1 as i1, 33 | i1e as i1e, 34 | logit as logit, 35 | logsumexp as logsumexp, 36 | lpmn as lpmn, 37 | lpmn_values as lpmn_values, 38 | multigammaln as multigammaln, 39 | log_ndtr as log_ndtr, 40 | ndtr as ndtr, 41 | ndtri as ndtri, 42 | polygamma as polygamma, 43 | sph_harm as sph_harm, 44 | xlogy as xlogy, 45 | xlog1py as xlog1py, 46 | zeta as zeta, 47 | ) 48 | -------------------------------------------------------------------------------- /tests/filecheck/names.filecheck.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Tests for naming of modules when lowering JAX into MLIR. 
# RUN: %PYTHON %s | FileCheck %s

from absl import app

import jax
from jax import lax
import numpy as np

from jax.tests.filecheck.jax_filecheck_helpers import print_ir

# NOTE(review): enables 64-bit types globally for this test; the cases below
# only use int32 inputs, so the reason is not visible here -- confirm.
jax.config.update("jax_enable_x64", True)


# Each case prints a "TEST: <name> <dtype[shape]>" banner and then the lowered
# module. The directive comments directly above each case are matched, in
# order, against that printed output, so cases must not be reordered.
def main(_):
  # CHECK-LABEL: TEST: neg int32[7]
  # CHECK: module @jit_neg
  # CHECK: func public @main
  print_ir(np.empty([7], np.int32))(lax.neg)

  # CHECK-LABEL: TEST: foo int32[7]
  # CHECK: module @jit_foo
  # CHECK: func public @main
  # Decorator form: the IR is printed at definition time. print_ir's inner
  # function returns nothing, so the local name `foo` ends up bound to None.
  @print_ir(np.empty([7], np.int32))
  @jax.jit
  def foo(x): return x + 2


if __name__ == "__main__":
  app.run(main)
# --------------------------------------------------------------------------------
# /jaxlib/rocm/hip_prng_kernels.h:
# --------------------------------------------------------------------------------
# /* Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | ==============================================================================*/ 15 | 16 | #ifndef JAXLIB_HIP_PRNG_KERNELS_H_ 17 | #define JAXLIB_HIP_PRNG_KERNELS_H_ 18 | 19 | #include 20 | #include 21 | 22 | #include "rocm/include/hip/hip_runtime_api.h" 23 | #include "tensorflow/compiler/xla/service/custom_call_status.h" 24 | 25 | namespace jax { 26 | 27 | struct ThreeFry2x32Descriptor { 28 | std::int64_t n; 29 | }; 30 | 31 | void LaunchThreeFry2x32Kernel(hipStream_t stream, void** buffers, 32 | ThreeFry2x32Descriptor descriptor); 33 | 34 | void HipThreeFry2x32(hipStream_t stream, void** buffers, const char* opaque, 35 | size_t opaque_len, XlaCustomCallStatus* status); 36 | 37 | } // namespace jax 38 | 39 | #endif // JAXLIB_HIP_PRNG_KERNELS_H_ -------------------------------------------------------------------------------- /tests/filecheck/jax_filecheck_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Helpers for writing JAX filecheck tests. 16 | 17 | import jax 18 | import jax.tree_util as tree_util 19 | import numpy as np 20 | 21 | def print_ir(*prototypes): 22 | def lower(f): 23 | """Prints the MHLO IR that results from lowering `f`. 
24 | 25 | The arguments to `f` are taken to be arrays shaped like `prototypes`.""" 26 | inputs = tree_util.tree_map(np.array, prototypes) 27 | flat_inputs, _ = tree_util.tree_flatten(inputs) 28 | shape_strs = " ".join([f"{x.dtype.name}[{','.join(map(str, x.shape))}]" 29 | for x in flat_inputs]) 30 | name = f.func.__name__ if hasattr(f, "func") else f.__name__ 31 | print(f"\nTEST: {name} {shape_strs}") 32 | print(jax.jit(f).lower(*inputs).compiler_ir(dialect="mhlo")) 33 | return lower 34 | -------------------------------------------------------------------------------- /tests/clear_backends_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""Tests for jax.clear_backends."""

from absl.testing import absltest

import jax
from jax.config import config
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc

config.parse_flags_with_absl()


class ClearBackendsTest(jtu.JaxTestCase):
  """Checks that clearing backends drops executables without breaking jit."""

  def test_clear_backends(self):
    multiply = jax.jit(lambda x, y: x * y)
    self.assertEqual(multiply(1, 2), 2)

    # The remaining checks only run on xla_client version >= 79 (presumably
    # the first version exposing live_executables / clear_backends support).
    if xc._version < 79:
      return

    self.assertNotEmpty(xb.get_backend().live_executables())
    jax.clear_backends()
    # get_backend() is called again on purpose: after clearing, a new backend
    # is expected to be created with no cached executables.
    self.assertEmpty(xb.get_backend().live_executables())
    # The same jitted function must still work (it recompiles as needed).
    self.assertEqual(multiply(1, 2), 2)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
# --------------------------------------------------------------------------------
# /tests/heap_profiler_test.py:
# --------------------------------------------------------------------------------
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from absl.testing import absltest

import jax
import jax._src.lib.xla_bridge
from jax.config import config
import jax._src.test_util as jtu


config.parse_flags_with_absl()


class HeapProfilerTest(unittest.TestCase):
  """Smoke tests for the backend heap-profiler API.

  They only verify that taking a heap profile does not crash; the contents of
  the returned profiles are intentionally not inspected.
  """

  def testBasics(self):
    backend = jax._src.lib.xla_bridge.get_backend()
    take_heap_profile = backend.heap_profile

    # Profile before this test allocates anything.
    take_heap_profile()

    # Profile while a device buffer created here is alive.
    buf = jax.device_put(1)
    take_heap_profile()

    # Profile after that buffer has been explicitly deleted.
    buf.delete()
    take_heap_profile()


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
# --------------------------------------------------------------------------------
# /jax/nn/initializers.py:
# --------------------------------------------------------------------------------
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# """
# Common neural network layer initializers, consistent with definitions
# used in Keras and Sonnet.
18 | """ 19 | 20 | from jax._src.nn.initializers import ( 21 | constant as constant, 22 | Initializer as Initializer, 23 | delta_orthogonal as delta_orthogonal, 24 | glorot_normal as glorot_normal, 25 | glorot_uniform as glorot_uniform, 26 | he_normal as he_normal, 27 | he_uniform as he_uniform, 28 | kaiming_normal as kaiming_normal, 29 | kaiming_uniform as kaiming_uniform, 30 | lecun_normal as lecun_normal, 31 | lecun_uniform as lecun_uniform, 32 | normal as normal, 33 | ones as ones, 34 | orthogonal as orthogonal, 35 | uniform as uniform, 36 | variance_scaling as variance_scaling, 37 | xavier_normal as xavier_normal, 38 | xavier_uniform as xavier_uniform, 39 | zeros as zeros, 40 | ) 41 | -------------------------------------------------------------------------------- /jaxlib/cuda/cuda_prng_kernels.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef JAXLIB_CUDA_PRNG_KERNELS_H_ 17 | #define JAXLIB_CUDA_PRNG_KERNELS_H_ 18 | 19 | #include 20 | #include 21 | 22 | #include "third_party/gpus/cuda/include/cuda_runtime_api.h" 23 | #include "tensorflow/compiler/xla/service/custom_call_status.h" 24 | 25 | namespace jax { 26 | 27 | struct ThreeFry2x32Descriptor { 28 | std::int64_t n; 29 | }; 30 | 31 | void LaunchThreeFry2x32Kernel(cudaStream_t stream, void** buffers, 32 | ThreeFry2x32Descriptor descriptor); 33 | 34 | void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque, 35 | size_t opaque_len, XlaCustomCallStatus* status); 36 | 37 | } // namespace jax 38 | 39 | #endif // JAXLIB_CUDA_PRNG_KERNELS_H_ 40 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/bernoulli.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import scipy.stats as osp_stats

from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax.scipy.special import xlogy, xlog1py


@_wraps(osp_stats.bernoulli.logpmf, update_doc=False)
def logpmf(k, p, loc=0):
  # With x = k - loc, the Bernoulli log-pmf on its support {0, 1} is
  #   log P(x) = x * log(p) + (1 - x) * log1p(-p),
  # written with xlogy / xlog1py so the 0 * log(0) cases evaluate to 0.
  k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc)
  zero = _lax_const(k, 0)
  one = _lax_const(k, 1)
  x = lax.sub(k, loc)
  log_probs = xlogy(x, p) + xlog1py(lax.sub(one, x), -p)
  # Outside the support the probability is zero (log-probability -inf). This
  # includes non-integral x such as 0.5, matching scipy.stats.bernoulli.
  # floor(x) < x is used for the integrality test because it is False for NaN,
  # so NaN inputs still propagate as NaN instead of becoming -inf.
  non_integer = lax.lt(lax.floor(x), x)
  outside_support = jnp.logical_or(
      jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)), non_integer)
  return jnp.where(outside_support, -jnp.inf, log_probs)

@_wraps(osp_stats.bernoulli.pmf, update_doc=False)
def pmf(k, p, loc=0):
  # Probability mass as the exponential of the log-pmf above.
  return jnp.exp(logpmf(k, p, loc))
# --------------------------------------------------------------------------------
# /jax/nn/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | 15 | """Common functions for neural network libraries.""" 16 | 17 | from jax.numpy import tanh as tanh 18 | from jax.nn import initializers as initializers 19 | from jax._src.nn.functions import ( 20 | celu as celu, 21 | elu as elu, 22 | gelu as gelu, 23 | glu as glu, 24 | hard_sigmoid as hard_sigmoid, 25 | hard_silu as hard_silu, 26 | hard_swish as hard_swish, 27 | hard_tanh as hard_tanh, 28 | leaky_relu as leaky_relu, 29 | log_sigmoid as log_sigmoid, 30 | log_softmax as log_softmax, 31 | logsumexp as logsumexp, 32 | normalize as normalize, 33 | standardize as standardize, 34 | one_hot as one_hot, 35 | relu as relu, 36 | relu6 as relu6, 37 | selu as selu, 38 | sigmoid as sigmoid, 39 | soft_sign as soft_sign, 40 | softmax as softmax, 41 | softplus as softplus, 42 | silu as silu, 43 | swish as swish, 44 | ) 45 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/pareto.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import scipy.stats as osp_stats

from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, inf, where


@_wraps(osp_stats.pareto.logpdf, update_doc=False)
def logpdf(x, b, loc=0, scale=1):
  # With z = (x - loc) / scale, on the support x >= loc + scale:
  #   log f(x) = -[log(scale / b) + (b + 1) * log(z)]
  # and the log-density is -inf below the support.
  x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
  z = lax.div(lax.sub(x, loc), scale)
  tail_exponent = lax.add(b, _lax_const(x, 1))
  log_normalizer = lax.log(lax.div(scale, b))
  log_probs = lax.neg(
      lax.add(log_normalizer, lax.mul(tail_exponent, lax.log(z))))
  below_support = lax.lt(x, lax.add(loc, scale))
  return where(below_support, -inf, log_probs)


@_wraps(osp_stats.pareto.pdf, update_doc=False)
def pdf(x, b, loc=0, scale=1):
  # The density is the exponential of the log-density computed by logpdf.
  log_density = logpdf(x, b, loc, scale)
  return lax.exp(log_density)
# --------------------------------------------------------------------------------
# /jax/_src/scipy/stats/gamma.py:
# --------------------------------------------------------------------------------
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import scipy.stats as osp_stats

from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
from jax.scipy.special import gammaln, xlogy


@_wraps(osp_stats.gamma.logpdf, update_doc=False)
def logpdf(x, a, loc=0, scale=1):
  # With y = (x - loc) / scale, for x >= loc:
  #   log f(x) = (a - 1) * log(y) - y - gammaln(a) - log(scale)
  # xlogy keeps the (a - 1) * log(y) term at 0 when a == 1 and y == 0.
  x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
  y = lax.div(lax.sub(x, loc), scale)
  shape_minus_one = lax.sub(a, _lax_const(x, 1))
  log_kernel = lax.sub(xlogy(shape_minus_one, y), y)
  log_normalizer = lax.add(gammaln(a), lax.log(scale))
  log_probs = lax.sub(log_kernel, log_normalizer)
  return where(lax.lt(x, loc), -inf, log_probs)


@_wraps(osp_stats.gamma.pdf, update_doc=False)
def pdf(x, a, loc=0, scale=1):
  # The density is the exponential of the log-density computed by logpdf.
  log_density = logpdf(x, a, loc, scale)
  return lax.exp(log_density)
# --------------------------------------------------------------------------------
# /jaxlib/rocm/hip_prng.cc:
# --------------------------------------------------------------------------------
# /* Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | ==============================================================================*/ 15 | 16 | #include "jaxlib/rocm/hip_prng_kernels.h" 17 | 18 | #include "jaxlib/rocm/hip_gpu_kernel_helpers.h" 19 | #include "jaxlib/kernel_pybind11_helpers.h" 20 | #include "include/pybind11/pybind11.h" 21 | 22 | namespace jax { 23 | namespace { 24 | 25 | std::string BuildHipThreeFry2x32Descriptor(std::int64_t n) { 26 | return PackDescriptorAsString(ThreeFry2x32Descriptor{n}); 27 | } 28 | pybind11::dict Registrations() { 29 | pybind11::dict dict; 30 | dict["hip_threefry2x32"] = EncapsulateFunction(HipThreeFry2x32); 31 | return dict; 32 | } 33 | 34 | PYBIND11_MODULE(_hip_prng, m) { 35 | m.def("registrations", &Registrations); 36 | m.def("threefry2x32_descriptor", [](std::int64_t n) { 37 | std::string result = BuildHipThreeFry2x32Descriptor(n); 38 | return pybind11::bytes(result); 39 | }); 40 | } 41 | 42 | } // namespace 43 | } // namespace jax 44 | -------------------------------------------------------------------------------- /jaxlib/cuda/cuda_prng.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "jaxlib/cuda/cuda_prng_kernels.h" 17 | 18 | #include "jaxlib/cuda/cuda_gpu_kernel_helpers.h" 19 | #include "jaxlib/kernel_pybind11_helpers.h" 20 | #include "include/pybind11/pybind11.h" 21 | 22 | namespace jax { 23 | namespace { 24 | 25 | std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n) { 26 | return PackDescriptorAsString(ThreeFry2x32Descriptor{n}); 27 | } 28 | pybind11::dict Registrations() { 29 | pybind11::dict dict; 30 | dict["cuda_threefry2x32"] = EncapsulateFunction(CudaThreeFry2x32); 31 | return dict; 32 | } 33 | 34 | PYBIND11_MODULE(_cuda_prng, m) { 35 | m.def("registrations", &Registrations); 36 | m.def("threefry2x32_descriptor", [](std::int64_t n) { 37 | std::string result = BuildCudaThreeFry2x32Descriptor(n); 38 | return pybind11::bytes(result); 39 | }); 40 | } 41 | 42 | } // namespace 43 | } // namespace jax 44 | -------------------------------------------------------------------------------- /jaxlib/rocm/hip_lu_pivot_kernels.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef JAXLIB_HIP_LU_PIVOT_KERNELS_H_ 17 | #define JAXLIB_HIP_LU_PIVOT_KERNELS_H_ 18 | 19 | #include 20 | #include 21 | 22 | #include "rocm/include/hip/hip_runtime_api.h" 23 | #include "tensorflow/compiler/xla/service/custom_call_status.h" 24 | 25 | namespace jax { 26 | 27 | struct LuPivotsToPermutationDescriptor { 28 | std::int64_t batch_size; 29 | std::int32_t pivot_size; 30 | std::int32_t permutation_size; 31 | }; 32 | 33 | void LaunchLuPivotsToPermutationKernel( 34 | hipStream_t stream, void** buffers, 35 | LuPivotsToPermutationDescriptor descriptor); 36 | 37 | void HipLuPivotsToPermutation(hipStream_t stream, void** buffers, 38 | const char* opaque, size_t opaque_len, 39 | XlaCustomCallStatus* status); 40 | 41 | } // namespace jax 42 | 43 | #endif // JAXLIB_HIP_LU_PIVOT_KERNELS_H_ -------------------------------------------------------------------------------- /jaxlib/cuda/cuda_lu_pivot_kernels.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef JAXLIB_CUDA_LU_PIVOT_KERNELS_H_ 17 | #define JAXLIB_CUDA_LU_PIVOT_KERNELS_H_ 18 | 19 | #include 20 | #include 21 | 22 | #include "third_party/gpus/cuda/include/cuda_runtime_api.h" 23 | #include "tensorflow/compiler/xla/service/custom_call_status.h" 24 | 25 | namespace jax { 26 | 27 | struct LuPivotsToPermutationDescriptor { 28 | std::int64_t batch_size; 29 | std::int32_t pivot_size; 30 | std::int32_t permutation_size; 31 | }; 32 | 33 | void LaunchLuPivotsToPermutationKernel( 34 | cudaStream_t stream, void** buffers, 35 | LuPivotsToPermutationDescriptor descriptor); 36 | 37 | void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers, 38 | const char* opaque, size_t opaque_len, 39 | XlaCustomCallStatus* status); 40 | 41 | } // namespace jax 42 | 43 | #endif // JAXLIB_CUDA_LU_PIVOT_KERNELS_H_ 44 | -------------------------------------------------------------------------------- /jax/numpy/fft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from jax._src.numpy.fft import (
    ifft as ifft,
    ifft2 as ifft2,
    ifftn as ifftn,
    ifftshift as ifftshift,
    ihfft as ihfft,
    irfft as irfft,
    irfft2 as irfft2,
    irfftn as irfftn,
    fft as fft,
    fft2 as fft2,
    fftfreq as fftfreq,
    fftn as fftn,
    fftshift as fftshift,
    hfft as hfft,
    rfft as rfft,
    rfft2 as rfft2,
    rfftfreq as rfftfreq,
    rfftn as rfftn,
)

# Names of numpy.fft functions that have no JAX implementation; _init binds
# each of them in this module to a stub from lax_numpy._not_implemented.
_NOT_IMPLEMENTED = []


def _init():
  """Installs stubs for every numpy.fft function not re-exported above.

  Kept inside a function so its helper imports and loop variables do not leak
  into the module namespace.
  """
  import numpy as np
  from jax._src.numpy import lax_numpy
  from jax._src import util

  module_namespace = globals()
  numpy_fft_functions = util.get_module_functions(np.fft)
  for name, np_func in numpy_fft_functions.items():
    if name in module_namespace:
      continue  # Already provided by jax._src.numpy.fft (or defined here).
    _NOT_IMPLEMENTED.append(name)
    module_namespace[name] = lax_numpy._not_implemented(np_func)


_init()
del _init
# --------------------------------------------------------------------------------
# /tests/third_party/scipy/LICENSE:
# --------------------------------------------------------------------------------
# Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above
#    copyright notice, this list of conditions and the following
#    disclaimer in the documentation and/or other materials provided
#    with the distribution.
#
# 3.
Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /examples/jax_cpp/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
load(
    "@org_tensorflow//tensorflow:tensorflow.bzl",
    "tf_cc_binary",
)

licenses(["notice"])

# Example C++ binary built from main.cc. It depends on XLA's literal/shape
# utilities, the PjRt CPU client, and the HLO module loader.
# Tagged "manual" so it is built only when named explicitly, not by wildcard
# target patterns such as //... .
tf_cc_binary(
    name = "main",
    srcs = ["main.cc"],
    tags = ["manual"],
    deps = [
        "@org_tensorflow//tensorflow/compiler/xla:literal",
        "@org_tensorflow//tensorflow/compiler/xla:literal_util",
        "@org_tensorflow//tensorflow/compiler/xla:shape_util",
        "@org_tensorflow//tensorflow/compiler/xla:status",
        "@org_tensorflow//tensorflow/compiler/xla:statusor",
        "@org_tensorflow//tensorflow/compiler/xla/pjrt:cpu_device",
        "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client",
        "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
        "@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader",
        "@org_tensorflow//tensorflow/core/platform:logging",
        "@org_tensorflow//tensorflow/core/platform:platform_port",
    ],
)
# --------------------------------------------------------------------------------
# /jax/_src/third_party/numpy/LICENSE:
# --------------------------------------------------------------------------------
# Copyright (c) 2005-2019, NumPy Developers.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above
#        copyright notice, this list of conditions and the following
#        disclaimer in the documentation and/or other materials provided
#        with the distribution.
#
#     * Neither the name of the NumPy Developers nor the names of any
#        contributors may be used to endorse or promote products derived
#        from this software without specific prior written permission.
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /jax/scipy/stats/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from jax.scipy.stats import bernoulli as bernoulli 16 | from jax.scipy.stats import beta as beta 17 | from jax.scipy.stats import cauchy as cauchy 18 | from jax.scipy.stats import dirichlet as dirichlet 19 | from jax.scipy.stats import expon as expon 20 | from jax.scipy.stats import gamma as gamma 21 | from jax.scipy.stats import geom as geom 22 | from jax.scipy.stats import laplace as laplace 23 | from jax.scipy.stats import logistic as logistic 24 | from jax.scipy.stats import multivariate_normal as multivariate_normal 25 | from jax.scipy.stats import nbinom as nbinom 26 | from jax.scipy.stats import norm as norm 27 | from jax.scipy.stats import pareto as pareto 28 | from jax.scipy.stats import poisson as poisson 29 | from jax.scipy.stats import t as t 30 | from jax.scipy.stats import uniform as uniform 31 | from jax.scipy.stats import chi2 as chi2 32 | from jax.scipy.stats import betabinom as betabinom 33 | from jax.scipy.stats import gennorm as gennorm 34 | from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde 35 | -------------------------------------------------------------------------------- /jax/_src/third_party/scipy/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution. 15 | 16 | 3. 
Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/chi2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | 16 | import scipy.stats as osp_stats 17 | 18 | from jax import lax 19 | from jax._src.lax.lax import _const as _lax_const 20 | from jax._src.numpy.util import _wraps 21 | from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf 22 | 23 | 24 | @_wraps(osp_stats.chi2.logpdf, update_doc=False) 25 | def logpdf(x, df, loc=0, scale=1): 26 | x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale) 27 | one = _lax_const(x, 1) 28 | two = _lax_const(x, 2) 29 | y = lax.div(lax.sub(x, loc), scale) 30 | df_on_two = lax.div(df, two) 31 | 32 | kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two)) 33 | 34 | nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two))) 35 | 36 | log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) 37 | return where(lax.lt(x, loc), -inf, log_probs) 38 | 39 | @_wraps(osp_stats.chi2.pdf, update_doc=False) 40 | def pdf(x, df, loc=0, scale=1): 41 | return lax.exp(logpdf(x, df, loc, scale)) 42 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/poisson.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import scipy.stats as osp_stats 17 | 18 | from jax import lax 19 | from jax._src.lax.lax import _const as _lax_const 20 | from jax._src.numpy.util import _wraps 21 | from jax._src.numpy import lax_numpy as jnp 22 | from jax.scipy.special import xlogy, gammaln, gammaincc 23 | 24 | 25 | @_wraps(osp_stats.poisson.logpmf, update_doc=False) 26 | def logpmf(k, mu, loc=0): 27 | k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc) 28 | zero = _lax_const(k, 0) 29 | x = lax.sub(k, loc) 30 | log_probs = xlogy(x, mu) - gammaln(x + 1) - mu 31 | return jnp.where(lax.lt(x, zero), -jnp.inf, log_probs) 32 | 33 | @_wraps(osp_stats.poisson.pmf, update_doc=False) 34 | def pmf(k, mu, loc=0): 35 | return jnp.exp(logpmf(k, mu, loc)) 36 | 37 | @_wraps(osp_stats.poisson.cdf, update_doc=False) 38 | def cdf(k, mu, loc=0): 39 | k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc) 40 | zero = _lax_const(k, 0) 41 | x = lax.sub(k, loc) 42 | p = gammaincc(jnp.floor(1 + x), mu) 43 | return jnp.where(lax.lt(x, zero), zero, p) 44 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | description: >- 4 | Report a bug or unexpected behavior to help us improve the package 5 | title: 'BUG: ' 6 | labels: 7 | - bug 8 | 9 | body: 10 | - type: markdown 11 | attributes: 12 | value: > 13 | ## Thank you for helping us improve JAX! 14 | 15 | * Please first verify that your issue is not already reported using the 16 | [Issue search][issue search]. 17 | 18 | * If you are asking a question or seeking support, please 19 | consider [starting a discussion][Discussions]. 20 | 21 | * If you prefer a non-templated issue report, click [here][Raw report]. 
22 | 23 | 24 | [Discussions]: https://github.com/google/jax/discussions 25 | 26 | [issue search]: https://github.com/google/jax/search?q=is%3Aissue&type=issues 27 | 28 | [Raw report]: http://github.com/google/jax/issues/new 29 | - type: textarea 30 | attributes: 31 | label: Description 32 | description: >- 33 | A concise description of the bug, preferably including self-contained 34 | code to reproduce the issue. 35 | placeholder: | 36 | Text may use markdown formatting. 37 | ```python 38 | # for codeblocks, use triple backticks 39 | ``` 40 | validations: 41 | required: true 42 | - type: input 43 | attributes: 44 | label: What jax/jaxlib version are you using? 45 | placeholder: For example jax v0.3.0, jaxlib v0.3.0 46 | - type: input 47 | attributes: 48 | label: Which accelerator(s) are you using? 49 | placeholder: CPU/GPU/TPU 50 | - type: input 51 | attributes: 52 | label: Additional System Info 53 | placeholder: Python version, OS (Linux/Mac/Windows/WSL), etc. 54 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/beta.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import scipy.stats as osp_stats 16 | 17 | from jax import lax 18 | from jax._src.lax.lax import _const as _lax_const 19 | from jax._src.numpy.util import _wraps 20 | from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or 21 | from jax.scipy.special import betaln, xlogy, xlog1py 22 | 23 | 24 | @_wraps(osp_stats.beta.logpdf, update_doc=False) 25 | def logpdf(x, a, b, loc=0, scale=1): 26 | x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc, scale) 27 | one = _lax_const(x, 1) 28 | shape_term = lax.neg(betaln(a, b)) 29 | y = lax.div(lax.sub(x, loc), scale) 30 | log_linear_term = lax.add(xlogy(lax.sub(a, one), y), 31 | xlog1py(lax.sub(b, one), lax.neg(y))) 32 | log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale)) 33 | return where(logical_or(lax.gt(x, lax.add(loc, scale)), 34 | lax.lt(x, loc)), -inf, log_probs) 35 | 36 | @_wraps(osp_stats.beta.pdf, update_doc=False) 37 | def pdf(x, a, b, loc=0, scale=1): 38 | return lax.exp(logpdf(x, a, b, loc, scale)) 39 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/nbinom.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | 16 | import scipy.stats as osp_stats 17 | 18 | from jax import lax 19 | from jax._src.lax.lax import _const as _lax_const 20 | from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf 21 | from jax._src.numpy.util import _wraps 22 | from jax._src.scipy.special import gammaln, xlogy 23 | 24 | 25 | @_wraps(osp_stats.nbinom.logpmf, update_doc=False) 26 | def logpmf(k, n, p, loc=0): 27 | """JAX implementation of scipy.stats.nbinom.logpmf.""" 28 | k, n, p, loc = _promote_args_inexact("nbinom.logpmf", k, n, p, loc) 29 | one = _lax_const(k, 1) 30 | y = lax.sub(k, loc) 31 | comb_term = lax.sub( 32 | lax.sub(gammaln(lax.add(y, n)), gammaln(n)), gammaln(lax.add(y, one)) 33 | ) 34 | log_linear_term = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p))) 35 | log_probs = lax.add(comb_term, log_linear_term) 36 | return where(lax.lt(k, loc), -inf, log_probs) 37 | 38 | 39 | @_wraps(osp_stats.nbinom.pmf, update_doc=False) 40 | def pmf(k, n, p, loc=0): 41 | """JAX implementation of scipy.stats.nbinom.pmf.""" 42 | return lax.exp(logpmf(k, n, p, loc)) 43 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/logistic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import scipy.stats as osp_stats 16 | from jax.scipy.special import expit, logit 17 | 18 | from jax import lax 19 | from jax._src.lax.lax import _const as _lax_const 20 | from jax._src.numpy.util import _wraps 21 | from jax._src.numpy.lax_numpy import _promote_args_inexact 22 | from jax._src.numpy import lax_numpy as jnp 23 | 24 | 25 | @_wraps(osp_stats.logistic.logpdf, update_doc=False) 26 | def logpdf(x): 27 | x, = _promote_args_inexact("logistic.logpdf", x) 28 | two = _lax_const(x, 2) 29 | half_x = lax.div(x, two) 30 | return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))) 31 | 32 | @_wraps(osp_stats.logistic.pdf, update_doc=False) 33 | def pdf(x): 34 | return lax.exp(logpdf(x)) 35 | 36 | @_wraps(osp_stats.logistic.ppf, update_doc=False) 37 | def ppf(x): 38 | return logit(x) 39 | 40 | @_wraps(osp_stats.logistic.sf, update_doc=False) 41 | def sf(x): 42 | return expit(lax.neg(x)) 43 | 44 | @_wraps(osp_stats.logistic.isf, update_doc=False) 45 | def isf(x): 46 | return -logit(x) 47 | 48 | @_wraps(osp_stats.logistic.cdf, update_doc=False) 49 | def cdf(x): 50 | return expit(x) 51 | -------------------------------------------------------------------------------- /jax/ipu/random/prng.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from .ipu_random_primitive import ipu_threefry2x32_lowering 15 | 16 | from jax.interpreters import mlir 17 | from jax._src.lib.mlir.dialects import mhlo 18 | from jax._src.prng import threefry2x32_p 19 | 20 | 21 | def _ipu_threefry2x32_broadcast_lowering(ctx, k1, k2, x1, x2): 22 | aval_out, _ = ctx.avals_out 23 | k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in 24 | rank = len(aval_out.shape) 25 | if 0 in aval_out.shape: 26 | zeros = mlir.full_like_aval(0, aval_out) 27 | return [zeros, zeros] 28 | 29 | def _broadcast(x, aval): 30 | return mhlo.BroadcastInDimOp( 31 | mlir.aval_to_ir_type(aval_out), x, 32 | mlir.dense_int_elements(range(rank - len(aval.shape), rank)) 33 | ).result 34 | 35 | ctx = ctx.replace(avals_in=[aval_out] * 4) 36 | # TODO: optimize in the case of scalar keys. 37 | return ipu_threefry2x32_lowering( 38 | ctx, _broadcast(k1, k1_aval), _broadcast(k2, k2_aval), _broadcast(x1, x1_aval), 39 | _broadcast(x2, x2_aval) 40 | ) 41 | 42 | 43 | mlir.register_lowering( 44 | threefry2x32_p, _ipu_threefry2x32_broadcast_lowering, platform="ipu" 45 | ) 46 | -------------------------------------------------------------------------------- /jax/numpy/linalg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax._src.numpy.linalg import ( 16 | cholesky as cholesky, 17 | det as det, 18 | eig as eig, 19 | eigh as eigh, 20 | eigvals as eigvals, 21 | eigvalsh as eigvalsh, 22 | inv as inv, 23 | lstsq as lstsq, 24 | matrix_power as matrix_power, 25 | matrix_rank as matrix_rank, 26 | norm as norm, 27 | pinv as pinv, 28 | qr as qr, 29 | slogdet as slogdet, 30 | solve as solve, 31 | svd as svd, 32 | ) 33 | from jax._src.third_party.numpy.linalg import ( 34 | cond as cond, 35 | multi_dot as multi_dot, 36 | tensorinv as tensorinv, 37 | tensorsolve as tensorsolve, 38 | ) 39 | 40 | # Module initialization is encapsulated in a function to avoid accidental 41 | # namespace pollution. 42 | _NOT_IMPLEMENTED = [] 43 | def _init(): 44 | import numpy as np 45 | from jax._src.numpy import lax_numpy 46 | from jax._src import util 47 | # Builds a set of all unimplemented NumPy functions. 48 | for name, func in util.get_module_functions(np.linalg).items(): 49 | if name not in globals(): 50 | _NOT_IMPLEMENTED.append(name) 51 | globals()[name] = lax_numpy._not_implemented(func) 52 | 53 | _init() 54 | del _init 55 | -------------------------------------------------------------------------------- /tests/stack_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Tests for stack.""" 17 | 18 | from absl.testing import absltest 19 | 20 | import jax.numpy as jnp 21 | from jax._src.lax.stack import Stack 22 | from jax._src import test_util as jtu 23 | 24 | 25 | from jax.config import config 26 | config.parse_flags_with_absl() 27 | 28 | 29 | class StackTest(jtu.JaxTestCase): 30 | 31 | def test_empty(self): 32 | stack = Stack.create(7, jnp.zeros((), jnp.int32)) 33 | self.assertTrue(stack.empty()) 34 | 35 | def test_pushes_and_pops(self): 36 | stack = Stack.create(7, jnp.zeros((), jnp.int32)) 37 | stack = stack.push(jnp.int32(7)) 38 | self.assertFalse(stack.empty()) 39 | stack = stack.push(jnp.int32(8)) 40 | self.assertFalse(stack.empty()) 41 | x, stack = stack.pop() 42 | self.assertFalse(stack.empty()) 43 | self.assertEqual(8, x) 44 | stack = stack.push(jnp.int32(9)) 45 | x, stack = stack.pop() 46 | self.assertFalse(stack.empty()) 47 | self.assertEqual(9, x) 48 | x, stack = stack.pop() 49 | self.assertTrue(stack.empty()) 50 | self.assertEqual(7, x) 51 | 52 | 53 | if __name__ == '__main__': 54 | absltest.main(testLoader=jtu.JaxTestLoader()) 55 | -------------------------------------------------------------------------------- /jax/tools/colab_tpu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for running JAX on Cloud TPUs via Colab.""" 16 | 17 | import requests 18 | import os 19 | 20 | from jax.config import config 21 | 22 | TPU_DRIVER_MODE = 0 23 | 24 | 25 | def setup_tpu(tpu_driver_version='tpu_driver_nightly'): 26 | """Sets up Colab to run on TPU. 27 | 28 | Note: make sure the Colab Runtime is set to Accelerator: TPU. 29 | 30 | Args 31 | ---- 32 | tpu_driver_version : (str) specify the version identifier for the tpu driver. 33 | Defaults to "tpu_driver_nightly". Occasionally the nightly release contains bugs, 34 | in which case a workaround is to use a known working version from a previous date, 35 | for example "tpu_driver-0.1dev20211031". 36 | """ 37 | global TPU_DRIVER_MODE 38 | 39 | if not TPU_DRIVER_MODE: 40 | colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0] 41 | url = f'http://{colab_tpu_addr}:8475/requestversion/{tpu_driver_version}' 42 | requests.post(url) 43 | TPU_DRIVER_MODE = 1 44 | 45 | # The following is required to use TPU Driver as JAX's backend. 46 | config.FLAGS.jax_xla_backend = "tpu_driver" 47 | config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR'] 48 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/laplace.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import scipy.stats as osp_stats 16 | 17 | from jax import lax 18 | from jax._src.lax.lax import _const as _lax_const 19 | from jax._src.numpy.util import _wraps 20 | from jax._src.numpy.lax_numpy import _promote_args_inexact 21 | 22 | 23 | @_wraps(osp_stats.laplace.logpdf, update_doc=False) 24 | def logpdf(x, loc=0, scale=1): 25 | x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale) 26 | two = _lax_const(x, 2) 27 | linear_term = lax.div(lax.abs(lax.sub(x, loc)), scale) 28 | return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale)))) 29 | 30 | @_wraps(osp_stats.laplace.pdf, update_doc=False) 31 | def pdf(x, loc=0, scale=1): 32 | return lax.exp(logpdf(x, loc, scale)) 33 | 34 | @_wraps(osp_stats.laplace.cdf, update_doc=False) 35 | def cdf(x, loc=0, scale=1): 36 | x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale) 37 | half = _lax_const(x, 0.5) 38 | one = _lax_const(x, 1) 39 | zero = _lax_const(x, 0) 40 | diff = lax.div(lax.sub(x, loc), scale) 41 | return lax.select(lax.le(diff, zero), 42 | lax.mul(half, lax.exp(diff)), 43 | lax.sub(one, lax.mul(half, lax.exp(lax.neg(diff))))) 44 | -------------------------------------------------------------------------------- /pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # A comma-separated list of package or module names from where C extensions may 4 | # be loaded. Extensions are loading into the active Python interpreter and may 5 | # run arbitrary code 6 | extension-pkg-whitelist=numpy 7 | 8 | 9 | [MESSAGES CONTROL] 10 | 11 | # Disable the message, report, category or checker with the given id(s). 
You 12 | # can either give multiple identifiers separated by comma (,) or put this 13 | # option multiple times (only on the command line, not in the configuration 14 | # file where it should appear only once).You can also use "--disable=all" to 15 | # disable everything first and then reenable specific checks. For example, if 16 | # you want to run only the similarities checker, you can use "--disable=all 17 | # --enable=similarities". If you want to run only the classes checker, but have 18 | # no Warning level messages displayed, use"--disable=all --enable=classes 19 | # --disable=W" 20 | disable=missing-docstring, 21 | too-many-locals, 22 | invalid-name, 23 | redefined-outer-name, 24 | redefined-builtin, 25 | protected-name, 26 | no-else-return, 27 | fixme, 28 | protected-access, 29 | too-many-arguments, 30 | blacklisted-name, 31 | too-few-public-methods, 32 | unnecessary-lambda, 33 | 34 | 35 | # Enable the message, report, category or checker with the given id(s). You can 36 | # either give multiple identifier separated by comma (,) or put this option 37 | # multiple time (only on the command line, not in the configuration file where 38 | # it should appear only once). See also the "--disable" option for examples. 39 | enable=c-extension-no-member 40 | 41 | 42 | [FORMAT] 43 | 44 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 45 | # tab). 46 | indent-string=" " -------------------------------------------------------------------------------- /docs/rank_promotion_warning.rst: -------------------------------------------------------------------------------- 1 | Rank promotion warning 2 | ====================== 3 | 4 | `NumPy broadcasting rules 5 | `_ 6 | allow the automatic promotion of arguments from one rank (number of array axes) 7 | to another. This behavior can be convenient when intended but can also lead to 8 | surprising bugs where a silent rank promotion masks an underlying shape error. 
9 | 10 | Here's an example of rank promotion: 11 | 12 | >>> import numpy as np 13 | >>> x = np.arange(12).reshape(4, 3) 14 | >>> y = np.array([0, 1, 0]) 15 | >>> x + y 16 | array([[ 0, 2, 2], 17 | [ 3, 5, 5], 18 | [ 6, 8, 8], 19 | [ 9, 11, 11]]) 20 | 21 | To avoid potential surprises, :code:`jax.numpy` is configurable so that 22 | expressions requiring rank promotion can lead to a warning, error, or can be 23 | allowed just like regular NumPy. The configuration option is named 24 | :code:`jax_numpy_rank_promotion` and it can take on string values 25 | :code:`allow`, :code:`warn`, and :code:`raise`. The default setting is 26 | :code:`warn`, which raises a warning on the first occurrence of rank promotion. 27 | The :code:`raise` setting raises an error on rank promotion, and :code:`allow` 28 | allows rank promotion without warning or error. 29 | 30 | As with most other JAX configuration options, you can set this option in 31 | several ways. One is by using :code:`jax.config` in your code: 32 | 33 | .. code-block:: python 34 | 35 | from jax.config import config 36 | config.update("jax_numpy_rank_promotion", "allow") 37 | 38 | You can also set the option using the environment variable 39 | :code:`JAX_NUMPY_RANK_PROMOTION`, for example as 40 | :code:`JAX_NUMPY_RANK_PROMOTION='raise'`. Finally, when using :code:`absl-py` 41 | the option can be set with a command-line flag. 42 | -------------------------------------------------------------------------------- /jax/_src/scipy/stats/t.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | import scipy.stats as osp_stats 18 | 19 | from jax import lax 20 | from jax._src.lax.lax import _const as _lax_const 21 | from jax._src.numpy.util import _wraps 22 | from jax._src.numpy.lax_numpy import _promote_args_inexact 23 | 24 | 25 | @_wraps(osp_stats.t.logpdf, update_doc=False) 26 | def logpdf(x, df, loc=0, scale=1): 27 | x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale) 28 | two = _lax_const(x, 2) 29 | scaled_x = lax.div(lax.sub(x, loc), scale) 30 | df_over_two = lax.div(df, two) 31 | df_plus_one_over_two = lax.add(df_over_two, _lax_const(x, 0.5)) 32 | normalize_term_const = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi)) 33 | normalize_term_tmp = lax.div(lax.log(lax.mul(normalize_term_const, df)), two) 34 | normalize_term = lax.sub(lax.add(lax.lgamma(df_over_two), normalize_term_tmp), 35 | lax.lgamma(df_plus_one_over_two)) 36 | quadratic = lax.div(lax.mul(scaled_x, scaled_x), df) 37 | return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic)))) 38 | 39 | @_wraps(osp_stats.t.pdf, update_doc=False) 40 | def pdf(x, df, loc=0, scale=1): 41 | return lax.exp(logpdf(x, df, loc, scale)) 42 | -------------------------------------------------------------------------------- /jax/experimental/jax2tf/examples/keras_reuse_main_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you 
may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from absl import flags 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from jax._src import test_util as jtu 20 | from jax.config import config 21 | 22 | from jax.experimental.jax2tf.examples import keras_reuse_main 23 | from jax.experimental.jax2tf.tests import tf_test_util 24 | 25 | config.parse_flags_with_absl() 26 | FLAGS = flags.FLAGS 27 | 28 | 29 | class KerasReuseMainTest(tf_test_util.JaxToTfTestCase): 30 | 31 | def setUp(self): 32 | super().setUp() 33 | FLAGS.model_path = os.path.join(absltest.get_default_test_tmpdir(), 34 | "saved_models") 35 | FLAGS.num_epochs = 1 36 | FLAGS.test_savedmodel = True 37 | FLAGS.mock_data = True 38 | FLAGS.show_images = False 39 | FLAGS.serving_batch_size = 1 40 | 41 | @parameterized.named_parameters( 42 | dict(testcase_name=f"_{model}", model=model) 43 | for model in ["mnist_pure_jax", "mnist_flax"]) 44 | def test_keras_reuse(self, model="mnist_pure_jax"): 45 | FLAGS.model = model 46 | keras_reuse_main.main(None) 47 | 48 | 49 | if __name__ == "__main__": 50 | absltest.main(testLoader=jtu.JaxTestLoader()) 51 | -------------------------------------------------------------------------------- /jaxlib/rocm/hip_prng_kernels.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with 
the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "jaxlib/rocm/hip_prng_kernels.h" 17 | 18 | #include 19 | 20 | #include "jaxlib/rocm/hip_gpu_kernel_helpers.h" 21 | #include "jaxlib/kernel_helpers.h" 22 | #include "tensorflow/compiler/xla/service/custom_call_status.h" 23 | 24 | namespace jax { 25 | namespace { 26 | 27 | absl::Status HipThreeFry2x32_(hipStream_t stream, void** buffers, 28 | const char* opaque, std::size_t opaque_len) { 29 | auto s = UnpackDescriptor(opaque, opaque_len); 30 | JAX_RETURN_IF_ERROR(s.status()); 31 | LaunchThreeFry2x32Kernel(stream, buffers, **s); 32 | JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipGetLastError())); 33 | return absl::OkStatus(); 34 | } 35 | 36 | } // namespace 37 | 38 | void HipThreeFry2x32(hipStream_t stream, void** buffers, const char* opaque, 39 | size_t opaque_len, XlaCustomCallStatus* status) { 40 | auto s = HipThreeFry2x32_(stream, buffers, opaque, opaque_len); 41 | if (!s.ok()) { 42 | std::string_view message = s.message(); 43 | XlaCustomCallStatusSetFailure(status, message.data(), message.length()); 44 | } 45 | } 46 | 47 | } // namespace jax 48 | -------------------------------------------------------------------------------- /jaxlib/cuda/cuda_prng_kernels.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "jaxlib/cuda/cuda_prng_kernels.h" 17 | 18 | #include 19 | 20 | #include "jaxlib/cuda/cuda_gpu_kernel_helpers.h" 21 | #include "jaxlib/kernel_helpers.h" 22 | #include "tensorflow/compiler/xla/service/custom_call_status.h" 23 | 24 | namespace jax { 25 | namespace { 26 | 27 | absl::Status CudaThreeFry2x32_(cudaStream_t stream, void** buffers, 28 | const char* opaque, std::size_t opaque_len) { 29 | auto s = UnpackDescriptor(opaque, opaque_len); 30 | JAX_RETURN_IF_ERROR(s.status()); 31 | LaunchThreeFry2x32Kernel(stream, buffers, **s); 32 | JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError())); 33 | return absl::OkStatus(); 34 | } 35 | 36 | } // namespace 37 | 38 | void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque, 39 | size_t opaque_len, XlaCustomCallStatus* status) { 40 | auto s = CudaThreeFry2x32_(stream, buffers, opaque, opaque_len); 41 | if (!s.ok()) { 42 | std::string_view message = s.message(); 43 | XlaCustomCallStatusSetFailure(status, message.data(), message.length()); 44 | } 45 | } 46 | 47 | } // namespace jax 48 | -------------------------------------------------------------------------------- /tests/remote_transfer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for cross host device transfer."""

from absl.testing import absltest
import unittest
import numpy as np

import jax
from jax._src import test_util as jtu

from jax.config import config

config.parse_flags_with_absl()


class RemoteTransferTest(jtu.JaxTestCase):

  # TODO(jheek): this test crashes on multi-GPU.
  @jtu.skip_on_devices("gpu")
  def test_remote_transfer(self):
    # Sends a buffer from one local device to a cross-host receive buffer
    # created on a second local device, then checks the contents match.
    #
    # Both endpoints come from jax.local_devices(), so the guard must count
    # *local* devices: in a multi-process run jax.device_count() can be >= 2
    # while this process sees a single device, and the unpacking below would
    # raise ValueError instead of skipping.
    if jax.local_device_count() < 2:
      raise unittest.SkipTest(
          "Remote transfer requires at least 2 local devices")
    dev_a, dev_b = jax.local_devices()[:2]
    if "libtpu" in dev_a.client.platform_version:
      raise unittest.SkipTest("Test does not yet work on cloud TPU")
    send_buf = jax.device_put(np.ones((32,)), dev_a)
    shapes = [send_buf.xla_shape()]
    (tag, recv_buf), = dev_b.client.make_cross_host_receive_buffers(
        shapes, dev_b)
    status, dispatched = send_buf.copy_to_remote_device(tag)
    self.assertIsNone(status)
    self.assertTrue(dispatched)
    self.assertArraysEqual(send_buf, recv_buf)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())

# --------------------------------------------------------------------------------
# /jaxlib/kernel_pybind11_helpers.h:
# --------------------------------------------------------------------------------
# /* Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in
compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef JAXLIB_KERNEL_PYBIND11_HELPERS_H_ 17 | #define JAXLIB_KERNEL_PYBIND11_HELPERS_H_ 18 | 19 | #include "absl/base/casts.h" 20 | #include "jaxlib/kernel_helpers.h" 21 | #include "include/pybind11/pybind11.h" 22 | 23 | namespace jax { 24 | 25 | // Descriptor objects are opaque host-side objects used to pass data from JAX 26 | // to the custom kernel launched by XLA. Currently simply treat host-side 27 | // structures as byte-strings; this is not portable across architectures. If 28 | // portability is needed, we could switch to using a representation such as 29 | // protocol buffers or flatbuffers. 30 | 31 | // Packs a descriptor object into a pybind11::bytes structure. 32 | // UnpackDescriptor() is available in kernel_helpers.h. 33 | template 34 | pybind11::bytes PackDescriptor(const T& descriptor) { 35 | return pybind11::bytes(PackDescriptorAsString(descriptor)); 36 | } 37 | 38 | template 39 | pybind11::capsule EncapsulateFunction(T* fn) { 40 | return pybind11::capsule(absl::bit_cast(fn), 41 | "xla._CUSTOM_CALL_TARGET"); 42 | } 43 | 44 | } // namespace jax 45 | 46 | #endif // JAXLIB_KERNEL_PYBIND11_HELPERS_H_ 47 | -------------------------------------------------------------------------------- /tests/ipu/primitive/custom_primitive_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
from absl.testing import absltest, parameterized
from jax._src import test_util as jtu

import numpy as np
import numpy.testing as npt

import jax
from ipu_custom_activation import custom_activation, CustomActivationPrimitive


class IPUCustomPrimitiveTest(jtu.JaxTestCase):
    """Tests for the example IPU custom activation primitive."""

    def setUp(self):
        super().setUp()
        # Check we have at least one IPU device to run these tests.
        # unittest assertions are used throughout instead of bare `assert`,
        # which `python -O` strips (making the checks silently pass).
        self.assertGreaterEqual(len(jax.devices("ipu")), 1)

    def test__custom_activation__metadata(self):
        # Proper metadata for the custom JAX IPU primitive.
        metadata = CustomActivationPrimitive.metadata(num_inputs=2)
        self.assertEqual(metadata.num_inputs, 2)
        self.assertTrue(metadata.is_elementwise)
        self.assertTrue(metadata.is_stateless)
        self.assertTrue(metadata.is_hashable)
        self.assertEqual(metadata.input_to_output_tensor_aliasing, {0: 0})
        self.assertEmpty(metadata.allocating_indices)
        self.assertEmpty(metadata.replica_identical_output_indices)

    def test__custom_activation__numpy_array_implementation(self):
        # Eager call on NumPy inputs matches the reference |in0| * in1.
        in0 = np.random.rand(2, 3, 4).astype(np.float32)
        in1 = np.random.rand(2, 3, 4).astype(np.float32)
        output = custom_activation(in0, in1)
        npt.assert_array_equal(output, np.abs(in0) * in1)

    @parameterized.parameters(["cpu", "ipu"])
    def test__custom_activation__multi_backends_jitting(self, backend):
        # Jitted call on each backend matches the reference |in0| * in1.
        in0 = np.random.rand(2, 3, 4).astype(np.float32)
        in1 = np.random.rand(2, 3, 4).astype(np.float32)

        custom_activation_jit = jax.jit(custom_activation, backend=backend)
        output = custom_activation_jit(in0, in1)
        npt.assert_array_equal(output, np.abs(in0) * in1)


if __name__ == "__main__":
    absltest.main(testLoader=jtu.JaxTestLoader())

# --------------------------------------------------------------------------------
# /jaxlib/kernel_helpers.h:
# --------------------------------------------------------------------------------
# /* Copyright 2019 Google LLC
#
# Licensed under
the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef JAXLIB_KERNEL_HELPERS_H_ 17 | #define JAXLIB_KERNEL_HELPERS_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | #include "absl/base/casts.h" 24 | #include "absl/status/statusor.h" 25 | 26 | namespace jax { 27 | 28 | // See kernel_pybind11_helpers.h for info on descriptor objects. We separate out 29 | // the functionality that doesn't require pybind11 for building CUDA libraries, 30 | // since older versions nvcc don't seem to be able to compile pybind11. 31 | 32 | // Packs a descriptor object into a byte string. 33 | template 34 | std::string PackDescriptorAsString(const T& descriptor) { 35 | return std::string(absl::bit_cast(&descriptor), sizeof(T)); 36 | } 37 | 38 | // Unpacks a descriptor object from a byte string. 
39 | template 40 | absl::StatusOr UnpackDescriptor(const char* opaque, 41 | std::size_t opaque_len) { 42 | if (opaque_len != sizeof(T)) { 43 | return absl::InternalError("Invalid size for operation descriptor."); 44 | } 45 | return absl::bit_cast(opaque); 46 | } 47 | 48 | } // namespace jax 49 | 50 | #endif // JAXLIB_KERNEL_HELPERS_H_ 51 | -------------------------------------------------------------------------------- /jaxlib/cuda/cuda_linalg.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 Google LLC 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | 17 | #include "jaxlib/cuda/cuda_gpu_kernel_helpers.h" 18 | #include "jaxlib/cuda/cuda_lu_pivot_kernels.h" 19 | #include "jaxlib/kernel_pybind11_helpers.h" 20 | #include "include/pybind11/pybind11.h" 21 | 22 | namespace jax { 23 | namespace { 24 | 25 | std::string BuildCudaLuPivotsToPermutationDescriptor( 26 | std::int64_t batch_size, std::int32_t pivot_size, 27 | std::int32_t permutation_size) { 28 | return PackDescriptorAsString(LuPivotsToPermutationDescriptor{ 29 | batch_size, pivot_size, permutation_size}); 30 | } 31 | 32 | pybind11::dict Registrations() { 33 | pybind11::dict dict; 34 | dict["cuda_lu_pivots_to_permutation"] = 35 | EncapsulateFunction(CudaLuPivotsToPermutation); 36 | return dict; 37 | } 38 | 39 | PYBIND11_MODULE(_cuda_linalg, m) { 40 | m.def("registrations", &Registrations); 41 | m.def("lu_pivots_to_permutation_descriptor", 42 | [](std::int64_t batch_size, std::int32_t pivot_size, 43 | std::int32_t permutation_size) { 44 | std::string result = BuildCudaLuPivotsToPermutationDescriptor( 45 | batch_size, pivot_size, permutation_size); 46 | return pybind11::bytes(result); 47 | }); 48 | } 49 | 50 | } // namespace 51 | } // namespace jax 52 | --------------------------------------------------------------------------------