├── .clang-format ├── .gitattributes ├── .github ├── action │ ├── Dockerfile │ └── entrypoint.sh └── workflows │ └── tests.yml ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── MANIFEST.in ├── README.md ├── demo.ipynb ├── lib ├── cpu_ops.cc ├── gpu_ops.cc ├── kepler.h ├── kernel_helpers.h ├── kernels.cc.cu ├── kernels.h └── pybind11_kernel_helpers.h ├── pyproject.toml ├── src └── kepler_jax │ ├── __init__.py │ └── kepler_jax.py └── tests └── test_kepler_jax.py /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | ColumnLimit: 99 3 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation 2 | -------------------------------------------------------------------------------- /.github/action/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 2 | 3 | RUN apt-get update && \ 4 | DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip 5 | 6 | RUN pip install --upgrade pip && \ 7 | pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 8 | 9 | COPY entrypoint.sh /entrypoint.sh 10 | 11 | ENTRYPOINT ["/entrypoint.sh"] 12 | -------------------------------------------------------------------------------- /.github/action/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh -l 2 | 3 | cd /github/workspace 4 | KEPLER_JAX_CUDA=yes python3 -m pip install -v . 
5 | python3 -c 'import kepler_jax;print(kepler_jax.__version__)' 6 | python3 -c 'import kepler_jax.gpu_ops' 7 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | tests: 11 | name: ${{ matrix.os }} 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: [ubuntu-latest, macos-latest] 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | with: 21 | fetch-depth: 0 22 | 23 | - name: Set up Python 24 | uses: actions/setup-python@v4 25 | with: 26 | python-version: '3.10' 27 | 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install -U pip 31 | python -m pip install -v .[test] 32 | 33 | - name: Run tests 34 | run: python -m pytest -v tests 35 | 36 | cuda: 37 | name: CUDA 38 | runs-on: ubuntu-latest 39 | steps: 40 | - uses: actions/checkout@v4 41 | with: 42 | fetch-depth: 0 43 | - uses: ./.github/action 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | src/**/*_version.py 2 | *.ipynb 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.15...3.26) 2 | project(${SKBUILD_PROJECT_NAME} LANGUAGES C CXX) 3 | message(STATUS "Using CMake version: " ${CMAKE_VERSION}) 4 | 5 | # Find pybind11 6 | set(PYBIND11_NEWPYTHON ON) 7 | find_package(pybind11 CONFIG REQUIRED) 8 | 9 | include_directories(${CMAKE_CURRENT_LIST_DIR}/lib) 10 | 11 | # CPU op library 12 | pybind11_add_module(cpu_ops ${CMAKE_CURRENT_LIST_DIR}/lib/cpu_ops.cc) 13 | install(TARGETS cpu_ops LIBRARY DESTINATION .) 14 | 15 | # Include the CUDA extensions if possible 16 | include(CheckLanguage) 17 | check_language(CUDA) 18 | 19 | if(CMAKE_CUDA_COMPILER) 20 | enable_language(CUDA) 21 | include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 22 | pybind11_add_module( 23 | gpu_ops 24 | ${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu 25 | ${CMAKE_CURRENT_LIST_DIR}/lib/gpu_ops.cc) 26 | install(TARGETS gpu_ops LIBRARY DESTINATION .) 
27 | else() 28 | message(STATUS "Building without CUDA") 29 | endif() 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Dan Foreman-Mackey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE *.toml 2 | graft lib 3 | exclude .* 4 | prune .github 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > [!WARNING] 2 | > The JAX documentation now includes a supported interface for interfacing with C++ and 3 | > CUDA libraries.
Check out [the official tutorial](https://jax.readthedocs.io/en/latest/ffi.html), 4 | > which should be preferred to the methods described here. 5 | 6 | # Extending JAX with custom C++ and CUDA code 7 | 8 | [![Tests](https://github.com/dfm/extending-jax/workflows/Tests/badge.svg)](https://github.com/dfm/extending-jax/actions?query=workflow%3ATests) 9 | 10 | This repository is meant as a tutorial demonstrating the infrastructure required 11 | to provide custom ops in JAX when you have an existing implementation in C++ 12 | and, optionally, CUDA. I originally wanted to write this as a blog post, but 13 | there's enough boilerplate code that I ended up deciding that it made more sense 14 | to just share it as a repo with the tutorial in the README, so here we are! 15 | 16 | The motivation for this is that in my work I want to use libraries like JAX to 17 | fit models to data in astrophysics. In these models, there is often at least one 18 | part of the model specification that is physically motivated and while there are 19 | generally existing implementations of these model elements, it is often 20 | inefficient or impractical to re-implement these as a high-level JAX function. 21 | Instead, I want to expose a well-tested and optimized implementation in C 22 | directly to JAX. In my work, this often includes things like iterative 23 | algorithms or special functions that are not well suited to implementation using 24 | JAX directly. 25 | 26 | So, as part of updating my [exoplanet](https://docs.exoplanet.codes) library to 27 | interface with JAX, I had to learn what infrastructure was required to support 28 | this use case, and since I couldn't find a tutorial that covered all the pieces 29 | that I needed in one place, I wanted to put this together. 
Pretty much 30 | everything that I'll talk about is covered in more detail somewhere else (even 31 | if that somewhere is just a comment in some source code), but hopefully this 32 | summary can point you in the right direction if you have a use case like this. 33 | 34 | **A warning**: I'm writing this in January 2021 (most recent update November 2023; see 35 | github for the full revision history) and much of what I'm talking about is based on 36 | essentially undocumented APIs that are likely to change. 37 | Furthermore, I'm not affiliated with the JAX project and I'm far from an expert 38 | so I'm sure there are wrong things that I say. I'll try to update this if I 39 | notice things changing or if I learn of issues, but no promises! So, MIT license 40 | and all that: use at your own risk. 41 | 42 | ## Related reading 43 | 44 | As I mentioned previously, this tutorial is built on a lot of existing 45 | literature and I won't reproduce all the details of those documents here, so I 46 | wanted to start by listing the key resources that I found useful: 47 | 48 | 1. The [How primitives work][jax-primitives] tutorial in the JAX documentation 49 | includes almost all the details about how to expose a custom op to JAX and 50 | spending some quality time with that tutorial is not wasted time. The only 51 | thing missing from that document is a description of how to use the XLA 52 | CustomCall interface. 53 | 54 | 2. Which brings us to the [XLA custom calls][xla-custom] documentation. This 55 | page is pretty telegraphic, but it includes a description of the interface 56 | that your custom call functions need to support. In particular, this is where 57 | the differences in interface between the CPU and GPU are described, including 58 | things like the "opaque" parameter and how multiple outputs are handled. 59 | 60 | 3. 
I originally learned how to write the pybind11 interface for an XLA custom 61 | call from the [danieljtait/jax_xla_adventures][xla-adventures] repository by 62 | Dan Tait on GitHub. Again, this doesn't include very many details, but that's 63 | really a benefit here because it really distills the infrastructure to a 64 | place where I could understand what was going on. 65 | 66 | 4. Finally, much of what I know about this topic, I learned from spelunking in 67 | the [jaxlib source code][jaxlib] on GitHub. That code is pretty readable and 68 | includes good comments most of the time so that's a good place to look if you 69 | get stuck since folks there might have already faced the issue. 70 | 71 | ## What is an "op" 72 | 73 | In frameworks like JAX (or Theano, or TensorFlow, or PyTorch, to name a few), 74 | models are defined as a collection of operations or "ops" that can be chained, 75 | fused, or differentiated in clever ways. For our purposes, an op defines a 76 | function that knows: 77 | 78 | 1. how the input and output parameter shapes and types are related, 79 | 2. how to compute the output from a set of inputs, and 80 | 3. how to propagate derivatives using the chain rule. 81 | 82 | There are a lot of choices about where you draw the lines around a single op and 83 | there will be tradeoffs in terms of performance, generality, ease of use, and 84 | other factors when making these decisions. In my experience, it is often best to 85 | define the minimal scope ops and then allow your framework of choice to combine 86 | it efficiently with the rest of your model, but there will always be counter 87 | examples. 88 | 89 | ## Our example application: solving Kepler's equation 90 | 91 | In this section I'll describe the application presented in this project. Feel 92 | free to skip this if you just want to get to the technical details. 
93 | 94 | This project exposes a single jit-able and differentiable JAX operation to solve 95 | [Kepler's equation][keplers-equation], a tool that is used for computing 96 | gravitational orbits in astronomy. This is basically the "hello world" example 97 | that I use whenever learning about something like this. For example, I have 98 | previously written [about how to expose such an op when using Stan][stan-cpp]. 99 | The implementation used in that post and the one used here are not meant to be 100 | the most robust or efficient, but it is relatively simple and it exposes some of 101 | the interesting issues that one might face when writing custom JAX ops. If 102 | you're interested in the mathematical details, take a look at [my blog 103 | post][stan-cpp], but the key point for now is that this operation involves 104 | solving a transcendental equation, and in this tutorial we'll use a simple 105 | iterative method that you'll find in the [kepler.h][kepler-h] header file. Then, 106 | the derivatives of this operation can be evaluated using implicit 107 | differentiation. Unlike in the previously mentioned blog post, our operation 108 | will actually return the sine and cosine of the eccentric anomaly, since that's 109 | what most high performance versions of this function would return and because 110 | the way XLA handles ops with multiple outputs is a little funky. 111 | 112 | ## The cost/benefit analysis 113 | 114 | One important question to answer first is: "should I actually write a custom JAX 115 | extension?" If you're here, you've probably already thought about that, but I 116 | wanted to emphasize a few points to consider. 117 | 118 | 1. **Performance**: The main reason why you might want to implement a custom op 119 | for JAX is performance. JAX's JIT compiler can get great performance in a 120 | broad range of applications, but for some of the problems I work on, 121 | finely-tuned C++ can be much faster. 
In my experience, iterative algorithms, 122 | other special functions, or code with complicated logic are all examples of 123 | places where a custom op might greatly improve performance. I'm not always 124 | good at doing this, but it's probably worth benchmarking performance of a 125 | version of your code implemented directly in high-level JAX against your 126 | custom op. 127 | 128 | 2. **Autodiff**: One thing that is important to realize is that the extension 129 | that we write won't magically know how to propagate derivatives. Instead, 130 | we'll be required to provide a JAX interface for applying the chain rule to 131 | our op. In other words, if you're setting out to wrap that huge Fortran 132 | library that has been passed down through the generations, the payoff might 133 | not be as great as you hoped unless (a) the code already provides operations 134 | for propagating derivatives (in which case your JAX op probably won't support 135 | second and higher order differentiation), or (b) you can easily compute the 136 | differentiation rules using the algorithm that you already have (which is the 137 | case we have for our example application here). In my work, I try (sometimes 138 | unsuccessfully) to identify the minimum number and size of ops that I can get 139 | away with and then implement most of my models directly in JAX. In our demo 140 | application, for example, I could have chosen to make an XLA op generating a 141 | full radial velocity model, instead of just solving Kepler's equation, and 142 | that might (or might not) give better performance. But, the differentiation 143 | rules are _much_ simpler the way it is implemented. 144 | 145 | ## Summary of the relevant files 146 | 147 | The files in this repo come in three categories: 148 | 149 | 1. In the root directory, there are the standard packaging files like a 150 | `pyproject.toml`.
Most of this setup is pretty standard, but 151 | I'll highlight some unique elements in the packaging section below. 152 | 153 | 2. Next, the `src/kepler_jax` directory is a Python module with the definition 154 | of our JAX primitive roughly following the JAX [How primitives 155 | work][jax-primitives] tutorial. 156 | 157 | 3. Finally, the C++ and CUDA code implementing our XLA op live in the `lib` 158 | directory. The `pybind11_kernel_helpers.h` and `kernel_helpers.h` headers are 159 | boilerplate necessary for building in the interface. The rest of the files 160 | include the code specific for this implementation, but I'll describe this in 161 | more detail below. 162 | 163 | ## Defining an XLA custom call on the CPU 164 | 165 | The algorithm for our example problem is implemented in the `lib/kepler.h` 166 | header and I won't go into details about the algorithm here, but the main point 167 | is that this could be an implementation built on any external library that you 168 | can call from C++ and, if you want to support GPU usage, CUDA. That header file 169 | includes a single function `compute_eccentric_anomaly` with the following 170 | signature: 171 | 172 | ```c++ 173 | template <typename T> 174 | void compute_eccentric_anomaly( 175 | const T& mean_anom, const T& ecc, T* sin_ecc_anom, T* cos_ecc_anom 176 | ); 177 | ``` 178 | 179 | This is the function that we want to expose to JAX. 180 | 181 | As described in the [XLA documentation][xla-custom], the signature for a CPU XLA 182 | custom call in C++ is: 183 | 184 | ```c++ 185 | void custom_call(void* out, const void** in); 186 | ``` 187 | 188 | where, as you might expect, the elements of `in` point to the input values. So, 189 | in our case, the inputs are an integer giving the dimension of the problem 190 | `size`, an array with the mean anomalies `mean_anomaly`, and an array of 191 | eccentricities `ecc`.
Therefore, we might parse the input as follows: 192 | 193 | ```c++ 194 | #include <cstdint>  // int64_t 195 | 196 | template <typename T> 197 | void cpu_kepler(void *out, const void **in) { 198 | const std::int64_t size = *reinterpret_cast<const std::int64_t *>(in[0]); 199 | const T *mean_anom = reinterpret_cast<const T *>(in[1]); 200 | const T *ecc = reinterpret_cast<const T *>(in[2]); 201 | } 202 | ``` 203 | 204 | Here we have used a template so that we can support both single and double 205 | precision versions of the op. 206 | 207 | The output parameter is somewhat more complicated. If your op only has one 208 | output, you would access it using 209 | 210 | ```c++ 211 | T *result = reinterpret_cast<T *>(out); 212 | ``` 213 | 214 | but when you have multiple outputs, things get a little hairy. In our example, 215 | we have two outputs, the sine `sin_ecc_anom` and cosine `cos_ecc_anom` of the 216 | eccentric anomaly. Therefore, our `out` parameter -- even though it looks like a 217 | `void*` -- is actually a `void**`! Therefore, we will access the output as 218 | follows: 219 | 220 | ```c++ 221 | template <typename T> 222 | void cpu_kepler(void *out_tuple, const void **in) { 223 | // ...
224 | void **out = reinterpret_cast(out_tuple); 225 | T *sin_ecc_anom = reinterpret_cast(out[0]); 226 | T *cos_ecc_anom = reinterpret_cast(out[1]); 227 | } 228 | ``` 229 | 230 | Then finally, we actually apply the op and the full implementation, which you 231 | can find in `lib/cpu_ops.cc` is: 232 | 233 | ```c++ 234 | // lib/cpu_ops.cc 235 | #include 236 | 237 | template 238 | void cpu_kepler(void *out_tuple, const void **in) { 239 | const std::int64_t size = *reinterpret_cast(in[0]); 240 | const T *mean_anom = reinterpret_cast(in[1]); 241 | const T *ecc = reinterpret_cast(in[2]); 242 | 243 | void **out = reinterpret_cast(out_tuple); 244 | T *sin_ecc_anom = reinterpret_cast(out[0]); 245 | T *cos_ecc_anom = reinterpret_cast(out[1]); 246 | 247 | for (std::int64_t n = 0; n < size; ++n) { 248 | compute_eccentric_anomaly(mean_anom[n], ecc[n], sin_ecc_anom + n, cos_ecc_anom + n); 249 | } 250 | } 251 | ``` 252 | 253 | and that's it! 254 | 255 | ## Building & packaging for the CPU 256 | 257 | Now that we have an implementation of our XLA custom call target, we need to 258 | expose it to JAX. This is done by compiling a CPython module that wraps this 259 | function as a [`PyCapsule`][capsule] type. This can be done using pybind11, 260 | Cython, SWIG, or the Python C API directly, but for this example we'll use 261 | pybind11 since that's what I'm most familiar with. The [LAPACK ops in 262 | jaxlib][jaxlib-lapack] are implemented using Cython if you'd like to see an 263 | example of how to do that. 264 | 265 | Another choice that I've made is to use [scikit-build-core](scikit-build-core) 266 | and [CMake](https://cmake.org) to build the extensions. Another build option 267 | would be to use [bazel](https://bazel.build) to compile the code, like the JAX 268 | project, but I don't have any experience with it, so I decided to stick with 269 | what I know. 
_The key point is that we're just compiling a regular old Python 270 | module, so you can use whatever infrastructure you're familiar with!_ 271 | 272 | With these choices out of the way, the boilerplate code required to define the 273 | interface is, using the `cpu_kepler` function defined in the previous section as 274 | follows: 275 | 276 | ```c++ 277 | // lib/cpu_ops.cc 278 | #include 279 | 280 | // If you're looking for it, this function is actually implemented in 281 | // lib/pybind11_kernel_helpers.h 282 | template 283 | pybind11::capsule EncapsulateFunction(T* fn) { 284 | return pybind11::capsule((void*)fn, "xla._CUSTOM_CALL_TARGET"); 285 | } 286 | 287 | pybind11::dict Registrations() { 288 | pybind11::dict dict; 289 | dict["cpu_kepler_f32"] = EncapsulateFunction(cpu_kepler); 290 | dict["cpu_kepler_f64"] = EncapsulateFunction(cpu_kepler); 291 | return dict; 292 | } 293 | 294 | PYBIND11_MODULE(cpu_ops, m) { m.def("registrations", &Registrations); } 295 | ``` 296 | 297 | In this case, we're exporting a separate function for both single and double 298 | precision. Another option would be to pass the data type to the function and 299 | perform the dispatch logic directly in C++, but I find it cleaner to do it like 300 | this. 301 | 302 | With that out of the way, the actual build routine is defined in the following 303 | files: 304 | 305 | - In `./pyproject.toml`, we specify that `pybind11` and `scikit-build-core` are 306 | required build dependencies and that we'll use `scikit-build-core` as the 307 | build backend. 308 | 309 | - Then, `CMakeLists.txt` defines the build process for CMake using [pybind11's 310 | support for CMake builds][pybind11-cmake]. This will also, optionally, build 311 | the GPU ops as discussed below. 312 | 313 | With these files in place, we can now compile our XLA custom call ops using 314 | 315 | ```bash 316 | pip install . 
317 | ``` 318 | 319 | The final thing that I wanted to reiterate in this section is that 320 | `kepler_jax.cpu_ops` is just a regular old CPython extension module, so anything 321 | that you already know about packaging C extensions or any other resources that 322 | you can find on that topic can be applied. This wasn't obvious when I first 323 | started learning about this so I definitely went down some rabbit holes that 324 | hopefully you can avoid. 325 | 326 | ## Exposing this op as a JAX primitive 327 | 328 | The main components that are required to now call our custom op from JAX are 329 | well covered by the [How primitives work][jax-primitives] tutorial, so I won't 330 | reproduce all of that here. Instead I'll summarize the key points and then 331 | provide the missing part. If you haven't already, you should definitely read 332 | that tutorial before getting started on this part. 333 | 334 | In summary, we will define a `jax.core.Primitive` object with an "abstract 335 | evaluation" rule (see `src/kepler_jax/kepler_jax.py` for all the details) 336 | following the primitives tutorial. Then, we'll add a "translation rule" and a 337 | "JVP rule". We're lucky in this case, and we don't need to add a "transpose 338 | rule". JAX can actually work that out automatically, since our primitive is not 339 | itself used in the calculation of the output tangents. This won't always be 340 | true, and the [How primitives work][jax-primitives] tutorial includes an example 341 | of what to do in that case. 342 | 343 | Before defining these rules, we need to register the custom call target with 344 | JAX. 
To do that, we import our compiled `cpu_ops` extension module from above 345 | and use the `registrations` dictionary that we defined: 346 | 347 | ```python 348 | from jax.lib import xla_client 349 | from kepler_jax import cpu_ops 350 | 351 | for _name, _value in cpu_ops.registrations().items(): 352 | xla_client.register_custom_call_target(_name, _value, platform="cpu") 353 | ``` 354 | 355 | Then, the **lowering rule** is defined roughly as follows (the one you'll 356 | find in the source code is a little more complicated since it supports both CPU 357 | and GPU translation): 358 | 359 | ```python 360 | # src/kepler_jax/kepler_jax.py 361 | import numpy as np 362 | from jax.interpreters import mlir 363 | from jaxlib.mhlo_helpers import custom_call 364 | 365 | def _kepler_lowering(ctx, mean_anom, ecc): 366 | 367 | # Checking that input types and shape agree 368 | assert mean_anom.type == ecc.type 369 | 370 | # Extract the numpy type of the inputs 371 | mean_anom_aval, ecc_aval = ctx.avals_in 372 | np_dtype = np.dtype(mean_anom_aval.dtype) 373 | 374 | # The inputs and outputs all have the same shape and memory layout 375 | # so let's predefine this specification 376 | dtype = mlir.ir.RankedTensorType(mean_anom.type) 377 | dims = dtype.shape 378 | layout = tuple(range(len(dims) - 1, -1, -1)) 379 | 380 | # The total size of the input is the product across dimensions 381 | size = np.prod(dims).astype(np.int64) 382 | 383 | # We dispatch a different call depending on the dtype 384 | if np_dtype == np.float32: 385 | op_name = "cpu_kepler_f32" 386 | elif np_dtype == np.float64: 387 | op_name = "cpu_kepler_f64" 388 | else: 389 | raise NotImplementedError(f"Unsupported dtype {np_dtype}") 390 | 391 | return custom_call( 392 | op_name, 393 | # Output types 394 | result_types=[dtype, dtype], 395 | # The inputs: 396 | operands=[mlir.ir_constant(size), mean_anom, ecc], 397 | # Layout specification: 398 | operand_layouts=[(), layout, layout], 399 | result_layouts=[layout, layout] 
400 | ).results 401 | 402 | mlir.register_lowering( 403 | _kepler_prim, 404 | _kepler_lowering, 405 | platform="cpu") 406 | ``` 407 | 408 | There appears to be a lot going on here, but most of it is just type checking. 409 | The main meat of it is the `custom_call` function which is a thin convenience 410 | wrapper around the `mhlo.CustomCallOp` (documented 411 | [here](https://www.tensorflow.org/mlir/hlo_ops#mhlocustom_call_mlirmhlocustomcallop)). 412 | Here's a summary of its arguments: 413 | 414 | - The first argument is the name that you gave your `PyCapsule` 415 | in the `registrations` dictionary in `lib/cpu_ops.cc`. You can check what 416 | names your capsules had by looking at `cpu_ops.registrations().keys()`. 417 | 418 | - Then, the two following arguments give the "type" of the outputs, and 419 | specify the input arguments (operands). In this context, a "type" is 420 | specified by a data type defining the size of each dimension (what I 421 | would normally call the shape), and the type of the array (e.g. float32). 422 | In this case, both our outputs have the same type/shape. 423 | 424 | - Finally, with the last two arguments, we specify the memory layout 425 | of both input and output buffers. 426 | 427 | It's worth remembering that we're expecting the first argument to our function 428 | to be the size of the arrays, and you'll see that that is included as a 429 | `mlir.ir_constant` parameter. 430 | 431 | I'm not going to talk about the **JVP rule** here since it's quite problem 432 | specific, but I've tried to comment the code reasonably thoroughly so check out 433 | the code in `src/kepler_jax/kepler_jax.py` if you're interested, and open an 434 | issue if anything isn't clear. 435 | 436 | ## Defining an XLA custom call on the GPU 437 | 438 | The custom call on the GPU isn't terribly different from the CPU version above, 439 | but the syntax is somewhat different and there's a heck of a lot more 440 | boilerplate required. 
Since we need to compile and link CUDA code, there are 441 | also a few more packaging steps, but we'll get to that in the next section. The 442 | description in this section is a little all over the place, but the key files to 443 | look at to get more info are (a) `lib/gpu_ops.cc` for the dispatch functions 444 | called from Python, and (b) `lib/kernels.cc.cu` for the CUDA code implementing 445 | the kernel. 446 | 447 | The signature for the GPU custom call is: 448 | 449 | ```c++ 450 | // lib/kernels.cc.cu 451 | template 452 | void gpu_kepler( 453 | cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len 454 | ); 455 | ``` 456 | 457 | The first parameter is a CUDA stream, which I won't talk about at all because I 458 | don't really know very much about GPU programming and we don't really need to 459 | worry about it for now. Then you'll notice that the inputs and outputs are all 460 | provided as a single `void**` buffer. These will be ordered such that our access 461 | code from above is replaced by: 462 | 463 | ```c++ 464 | // lib/kernels.cc.cu 465 | template 466 | void gpu_kepler( 467 | cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len 468 | ) { 469 | const T *mean_anom = reinterpret_cast(buffers[0]); 470 | const T *ecc = reinterpret_cast(buffers[1]); 471 | T *sin_ecc_anom = reinterpret_cast(buffers[2]); 472 | T *cos_ecc_anom = reinterpret_cast(buffers[3]); 473 | } 474 | ``` 475 | 476 | where you might notice that the `size` parameter is no longer one of the inputs. 477 | Instead the array size is passed using the `opaque` parameter since its value is 478 | required on the CPU and within the GPU kernel (see the [XLA custom 479 | calls][xla-custom] documentation for more details). 
To use this `opaque` 480 | parameter, we will define a type to hold `size`: 481 | 482 | ```c++ 483 | // lib/kernels.h 484 | struct KeplerDescriptor { 485 | std::int64_t size; 486 | }; 487 | ``` 488 | 489 | and then the following boilerplate to serialize it: 490 | 491 | ```c++ 492 | // lib/kernel_helpers.h 493 | #include 494 | 495 | // Note that bit_cast is only available in recent C++ standards so you might need 496 | // to provide a shim like the one in lib/kernel_helpers.h 497 | template 498 | std::string PackDescriptorAsString(const T& descriptor) { 499 | return std::string(bit_cast(&descriptor), sizeof(T)); 500 | } 501 | 502 | // lib/pybind11_kernel_helpers.h 503 | #include 504 | 505 | template 506 | pybind11::bytes PackDescriptor(const T& descriptor) { 507 | return pybind11::bytes(PackDescriptorAsString(descriptor)); 508 | } 509 | ``` 510 | 511 | This serialization procedure should then be exposed in the Python module using: 512 | 513 | ```c++ 514 | // lib/gpu_ops.cc 515 | #include 516 | 517 | PYBIND11_MODULE(gpu_ops, m) { 518 | // ... 519 | m.def("build_kepler_descriptor", 520 | [](std::int64_t size) { 521 | return PackDescriptor(KeplerDescriptor{size}); 522 | }); 523 | } 524 | ``` 525 | 526 | Then, to deserialize this descriptor, we can use the following procedure: 527 | 528 | ```c++ 529 | // lib/kernel_helpers.h 530 | template 531 | const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) { 532 | if (opaque_len != sizeof(T)) { 533 | throw std::runtime_error("Invalid opaque object size"); 534 | } 535 | return bit_cast(opaque); 536 | } 537 | 538 | // lib/kernels.cc.cu 539 | template 540 | void gpu_kepler( 541 | cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len 542 | ) { 543 | // ... 
544 | const KeplerDescriptor &d = *UnpackDescriptor<KeplerDescriptor>(opaque, opaque_len); 545 | const std::int64_t size = d.size; 546 | } 547 | ``` 548 | 549 | Once we have these parameters, the full procedure for launching the CUDA kernel 550 | is: 551 | 552 | ```c++ 553 | // lib/kernels.cc.cu 554 | template <typename T> 555 | void gpu_kepler( 556 | cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len 557 | ) { 558 | const T *mean_anom = reinterpret_cast<const T *>(buffers[0]); 559 | const T *ecc = reinterpret_cast<const T *>(buffers[1]); 560 | T *sin_ecc_anom = reinterpret_cast<T *>(buffers[2]); 561 | T *cos_ecc_anom = reinterpret_cast<T *>(buffers[3]); 562 | const KeplerDescriptor &d = *UnpackDescriptor<KeplerDescriptor>(opaque, opaque_len); 563 | const std::int64_t size = d.size; 564 | 565 | // Select block sizes, etc., no promises that these numbers are the right choices 566 | const int block_dim = 128; 567 | const int grid_dim = std::min<int>(1024, (size + block_dim - 1) / block_dim); 568 | 569 | // Launch the kernel 570 | kepler_kernel<T> 571 | <<<grid_dim, block_dim, 0, stream>>>(size, mean_anom, ecc, sin_ecc_anom, cos_ecc_anom); 572 | 573 | cudaError_t error = cudaGetLastError(); 574 | if (error != cudaSuccess) { 575 | throw std::runtime_error(cudaGetErrorString(error)); 576 | } 577 | } 578 | ``` 579 | 580 | Finally, the kernel itself is relatively simple: 581 | 582 | ```c++ 583 | // lib/kernels.cc.cu 584 | template <typename T> 585 | __global__ void kepler_kernel(
I've chosen to enable GPU builds whenever CMake can detect 599 | CUDA support using `CheckLanguage` in `CMakelists.txt`: 600 | 601 | ```cmake 602 | include(CheckLanguage) 603 | check_language(CUDA) 604 | ``` 605 | 606 | Then, to expose this to JAX, we need to update the translation rule from above as follows: 607 | 608 | ```python 609 | # src/kepler_jax/kepler_jax.py 610 | import numpy as np 611 | from jax.lib import xla_client 612 | from kepler_jax import gpu_ops 613 | 614 | for _name, _value in gpu_ops.registrations().items(): 615 | xla_client.register_custom_call_target(_name, _value, platform="gpu") 616 | 617 | def _kepler_lowering_gpu(ctx, mean_anom, ecc): 618 | # Most of this function is the same as the CPU version above... 619 | 620 | # ... 621 | 622 | # The name of the op is now prefaced with 'gpu' (our choice, see lib/gpu_ops.cc, 623 | # not a requirement) 624 | if np_dtype == np.float32: 625 | op_name = "gpu_kepler_f32" 626 | elif np_dtype == np.float64: 627 | op_name = "gpu_kepler_f64" 628 | else: 629 | raise NotImplementedError(f"Unsupported dtype {dtype}") 630 | 631 | # We need to serialize the array size using a descriptor 632 | opaque = gpu_ops.build_kepler_descriptor(size) 633 | 634 | # The syntax is *almost* the same as the CPU version, but we need to pass the 635 | # size using 'opaque' rather than as an input 636 | return custom_call( 637 | op_name, 638 | # Output types 639 | result_types=[dtype, dtype], 640 | # The inputs: 641 | operands=[mean_anom, ecc], 642 | # Layout specification: 643 | operand_layouts=[layout, layout], 644 | result_layouts=[layout, layout], 645 | # GPU-specific additional data for the kernel 646 | backend_config=opaque 647 | ).results 648 | 649 | mlir.register_lowering( 650 | _kepler_prim, 651 | _kepler_lowering_gpu, 652 | platform="gpu") 653 | ``` 654 | 655 | Otherwise, everything else from our CPU implementation doesn't need to change. 
656 | 657 | ## Testing 658 | 659 | As usual, you should always test your code and this repo includes some unit 660 | tests in the `tests` directory for inspiration. You can also see an example of 661 | how to run these tests using the GitHub Actions CI service and the workflow in 662 | `.github/workflows/tests.yml`. I don't know of any public CI servers that 663 | provide GPU support, but I do include a test to confirm that the GPU ops can be 664 | compiled. You can see the infrastructure for that test in the `.github/action` 665 | directory. 666 | 667 | ## See this in action 668 | 669 | To demo the use of this custom op, I put together a notebook, based on [an 670 | example from the exoplanet docs][exoplanet-tutorial]. You can see this notebook 671 | in the `demo.ipynb` file in the root of this repository or open it on Google 672 | Colab: 673 | 674 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dfm/extending-jax/blob/main/demo.ipynb) 675 | 676 | ## References 677 | 678 | [jax-primitives]: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html "How primitives work" 679 | [xla-custom]: https://www.tensorflow.org/xla/custom_call "XLA custom calls" 680 | [xla-adventures]: https://github.com/danieljtait/jax_xla_adventures "JAX XLA adventures" 681 | [jaxlib]: https://github.com/google/jax/tree/master/jaxlib "jaxlib source code" 682 | [keplers-equation]: https://en.wikipedia.org/wiki/Kepler%27s_equation "Kepler's equation" 683 | [stan-cpp]: https://dfm.io/posts/stan-c++/ "Using external C++ functions with PyStan & radial velocity exoplanets" 684 | [kepler-h]: https://github.com/dfm/extending-jax/blob/main/lib/kepler.h 685 | [capsule]: https://docs.python.org/3/c-api/capsule.html "Capsules" 686 | [jaxlib-lapack]: https://github.com/google/jax/blob/master/jaxlib/lapack.pyx "jax/lapack.pyx" 687 | [scikit-build-core]: https://github.com/scikit-build/scikit-build-core 
"scikit-build-core" 688 | [pybind11-cmake]: https://pybind11.readthedocs.io/en/stable/compiling.html#building-with-cmake "Building with CMake" 689 | [exoplanet-tutorial]: https://docs.exoplanet.codes/en/stable/tutorials/intro-to-pymc3/#A-more-realistic-example:-radial-velocity-exoplanets "A more realistic example: radial velocity exoplanets" 690 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "extending-jax-demo.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "Vgw63bl7SQTf" 20 | }, 21 | "source": [ 22 | "# Demo: Extending JAX with custom C++ and CUDA code\n", 23 | "\n", 24 | "This demo is adapted from an example in [the exoplanet project's documentation](https://docs.exoplanet.codes/en/stable/tutorials/intro-to-pymc3/#A-more-realistic-example:-radial-velocity-exoplanets) to work with [numpyro](http://num.pyro.ai/) and the custom C++ op defined in the [Extending JAX with custom C++ and CUDA code](https://github.com/dfm/extending-jax/) tutorial. See those tutorial for all the details." 
25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "metadata": { 30 | "colab": { 31 | "base_uri": "https://localhost:8080/" 32 | }, 33 | "id": "CQ1VJptcACy0", 34 | "outputId": "3f3726e5-8c41-4ffa-b933-04acbc26133b" 35 | }, 36 | "source": [ 37 | "%matplotlib inline\n", 38 | "!python -m pip install -q numpyro\n", 39 | "!python -m pip install -q git+https://github.com/dfm/extending-jax.git" 40 | ], 41 | "execution_count": 1, 42 | "outputs": [ 43 | { 44 | "output_type": "stream", 45 | "text": [ 46 | "\u001b[K |████████████████████████████████| 184kB 4.5MB/s \n", 47 | "\u001b[K |████████████████████████████████| 481kB 7.6MB/s \n", 48 | "\u001b[K |████████████████████████████████| 32.1MB 150kB/s \n", 49 | "\u001b[?25h Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 50 | " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", 51 | " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", 52 | " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n", 53 | " Building wheel for kepler-jax (PEP 517) ... 
\u001b[?25l\u001b[?25hdone\n" 54 | ], 55 | "name": "stdout" 56 | } 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": { 62 | "id": "w02puAcQSyG-" 63 | }, 64 | "source": [ 65 | "Download some data (see [here](https://docs.exoplanet.codes/en/stable/tutorials/intro-to-pymc3/#A-more-realistic-example:-radial-velocity-exoplanets) for more info):" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "metadata": { 71 | "colab": { 72 | "base_uri": "https://localhost:8080/", 73 | "height": 279 74 | }, 75 | "id": "jCWrLqr3AzMS", 76 | "outputId": "b9ddf8b1-1cf1-4dfa-da9e-aec11674970d" 77 | }, 78 | "source": [ 79 | "import requests\n", 80 | "import numpy as np\n", 81 | "import pandas as pd\n", 82 | "import matplotlib.pyplot as plt\n", 83 | "\n", 84 | "# Download the dataset from the Exoplanet Archive:\n", 85 | "url = \"https://exoplanetarchive.ipac.caltech.edu/data/ExoData/0113/0113357/data/UID_0113357_RVC_001.tbl\"\n", 86 | "r = requests.get(url)\n", 87 | "if r.status_code != requests.codes.ok:\n", 88 | " r.raise_for_status()\n", 89 | "data = np.array(\n", 90 | " [\n", 91 | " l.split()\n", 92 | " for l in r.text.splitlines()\n", 93 | " if not l.startswith(\"\\\\\") and not l.startswith(\"|\")\n", 94 | " ],\n", 95 | " dtype=float,\n", 96 | ")\n", 97 | "t, rv, rv_err = data.T\n", 98 | "t -= np.mean(t)\n", 99 | "\n", 100 | "# Plot the observations \"folded\" on the published period:\n", 101 | "# Butler et al. 
(2006) https://arxiv.org/abs/astro-ph/0607493\n", 102 | "lit_period = 4.230785\n", 103 | "plt.errorbar(\n", 104 | " (t % lit_period) / lit_period, rv, yerr=rv_err, fmt=\".k\", capsize=0\n", 105 | ")\n", 106 | "plt.xlim(0, 1)\n", 107 | "plt.ylim(-110, 110)\n", 108 | "plt.annotate(\n", 109 | " \"period = {0:.6f} days\".format(lit_period),\n", 110 | " xy=(1, 0),\n", 111 | " xycoords=\"axes fraction\",\n", 112 | " xytext=(-5, 5),\n", 113 | " textcoords=\"offset points\",\n", 114 | " ha=\"right\",\n", 115 | " va=\"bottom\",\n", 116 | " fontsize=12,\n", 117 | ")\n", 118 | "plt.ylabel(\"radial velocity [m/s]\")\n", 119 | "_ = plt.xlabel(\"phase\")" 120 | ], 121 | "execution_count": 2, 122 | "outputs": [ 123 | { 124 | "output_type": "display_data", 125 | "data": { 126 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAEGCAYAAACtqQjWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2de3hU9bX3v2sykwsDDMglJAIhIEREQmKjNUp9Q/G8rfVU7bGHaqFQTnuo11aP72vLsRfbnoqtWmmltaX1BtJa3qLVKvXGMVo80TbKJAURCUTCJQkYYCCBTGYn6/1jZm9+e2fPZCeZWybr8zzzZGbfZmXPzF573YmZIQiCIAjxwJVqAQRBEITMQZSKIAiCEDdEqQiCIAhxQ5SKIAiCEDdEqQiCIAhxw51qARLJ+PHjedq0aakWQxAEYUjxzjvvfMTMEwayb0YrlWnTpqG2tjbVYgiCIAwpiGjfQPcV95cgCIIQN0SpCIIgCHFDlIogCIIQN0SpCIIgCHFDlIogCIIQN1KqVIjoUSI6TETblWVnEdErRLQ78ndsZDkR0c+JqIGI6onogtRJLgiCINiRakvlcQCftiz7FoAtzDwTwJbIawC4AsDMyGMFgIeTJKMgCILgkJQqFWZ+A8BRy+KrATwRef4EgGuU5es4zFsAxhBRQXIkFQRBEJyQakvFjnxmbo48bwGQH3l+NoD9ynYHIstMENEKIqolotojR44kVlJBEATBRDoqFQMOTxDr1xQxZl7LzBXMXDFhwoC6DAiCIAgDJB2VSqvu1or8PRxZfhDAFGW7yZFlgiAIQpqQjkrlOQDLIs+XAXhWWb40kgV2MYCA4iYTBEEQ0oCUNpQkot8DqAIwnogOAPgegHsBbCSirwDYB2BRZPPNAD4DoAHAKQDLky6wIAiCEJOUKhVmvj7KqoU22zKAmxMrkSAIgjAY0tH9JQiCIAxRRKkIgiAIcUOUiiAIghA3RKkIgiAIcUOUiiAIghA3RKkIgiAIcUOUiiAIghA3RKkIgiAIcUOUiiAIghA3RKkIgiAIcUOUiiAIghA3RKkIgiAIcUOUiiAIghA3RKkIgiAIcSOjlcquXbtQVVWVajEEQRCGDRmtVFJNVVWVKDVBEIYVKR3SFQ0iKgHwB2XRdADfBTAGwL8DOBJZ/p/MvDnJ4gmCIAhRSEulwsy7AJQBABFlA
TgI4BmERwg/yMz3p1A8QRAEIQpDwf21EMAeZt6XakEEQRCE2KSlpWLhOgC/V17fQkRLAdQCuIOZj6kbE9EKACsAICcnJ2lC2uH3+1P6/oIgCMkmrS0VIsoGcBWA/xdZ9DCAGQi7xpoBPGDdh5nXMnMFM1d4PJ6kyRoLCdgLgjBcSGulAuAKAO8ycysAMHMrM3czcw+A3wC4KKXS9YGmaejs7EQgEEi1KIIgCEkh3ZXK9VBcX0RUoKz7HIDtSZfIITU1Nejo6EAwGER9ff2gFYtYO4IgDAXSVqkQkRfAPwF4Wln8EyL6BxHVA1gA4PZYx+ju7kZTUxNqamoSKKk91dXVxvOenp64WyuiZARBSEfSNlDPzB0AxlmWfak/xzh9+jQaGxuxcOFCbNmyBZWVlXGVMRbqBd/lcsHn8yX8vVRFJgiCkArS1lKJJ11dXUm/4FZWViI3NxdZWVkoLCxEIBBIicUkCIKQTIaFUsnOzk66q6impgadnZ3o7u7GgQMHDItJFIsgCJlMRiuVvLw8FBcXJ931Bdi7ouwspnjGRiTOIghCqknbmEo8yMrKwtSpU5OuUADYXtx1iykeMZBAIGC41PQiy7KysgEfTxCE4UMi47AZrVRSSWVlJbxeLzRNw9SpU6FpGjZs2BAXBVdTU4P6+nr09PRg4cKFcLlccLvloxQEoX8kQrnIlciC3Unu74nXt3e73XC73SgsLASAuFlM1dXV6OnpARB2qenvIwiCkGoyOqaSTJIZz6iqqoLLFf7osrOzRaEIgpA2yNWoHzixWKqqquD3+1FWVtbvGEes4+sxlPLycvh8PpSWliIQCGDDhg34p3/6J6MdTCLrYQRByAz064nP54v7NUMsFQuBQCBlVfhAuLOx1eLRYyiNjY1Gyxefz4epU6cCQFzbwQiCkNmo15O6ujp88MEHcb3eDXulorqt1JOt1pToisbpBVvTNDQ1NWHVqlVxCYCpMRRryxdrO5h9+/ZJLYwgCFFd8ur1hJnR3Nwc1xo6cX8pWAPgixcvhs/nMzKtXC4X8vLyUFVVZass/H4/NE1DR0eHqT1MtG2rqqpMqcHR0GMougy6uer3+3spumPHjqWkLY0gCEMD9Xqio9fQxeOakdGWSklJSZ+WgurusgbAfT4fAoGAyUroq5W9pmnG877awwQCAZNlFO24lZWVKC0tRXFxMUpLS00+UJ/Ph9zcXBCR4/cVBGH4ol9PCgoKjOtGPLuOZLRS6Yvy8nL4/X7jog7AuHhv2bLFCGLpigYIdz6ur6/vZVkEAgF0dnaaLu59fVCqwurq6kJra2tUpaXGUHTXWllZmfG+zOz4fQVBGN74fD7MmjUL8+bNM3UdiUcW67BWKq2trcZz3d3V2NhoVOHrrqkZM2Zg7NixxrY9PT1Yt26d8VqPxQSDQXR2diI3NzdmexhdEbjdbkNhud1utLS0IBgMoq6uDmvXrrXdz2rZWBVQVlaWuL4EQXCUdKTfrMbzejFslUpNTQ1aWlqM12632+RWUoP2e/bswfjx401WyGOPPWZ8WGosBggrKJ/PZ/qg9DsAVQHt2bMHM2bMQHFxMZYvX25YG8yMW265xdYaUi0buxRil8uFm266SSwVQRiGWK8zqWhkO2yVSnV1tclltHz5cvh8PpSVlaG6urpXxlVTUxMmTZpkbK9pmhG3qKqqMimcnp4e1NXV2X6Q1uO2t7dj6tSpWLp0qekY3d3dveIiqitOj/n4fD54vV54PB4AQCgUktRiQRjmWJOOrNcS/RqXCNJWqRDRh5Epj34iqo0sO4uIXiGi3ZG/Y/s6TjTUoLzL5cLSpUujBu2BsCWTn59vvFbjFpWVlSaFA4StDbsPzaqAWlpaEAgEUFlZiXPOOcdYnpOT08va0Ise1ZiPLpsqayImTQqCMHSwJh1F81wkQrmkrVKJsICZy5i5IvL6WwC2MPNMAFsirweEmlG1detWADCZi8CZoH1ZWRkqKioMq8AuXqIqHAAgo
l4fpN/vx8qVK00KiJkNBVBYWBj1+DrRfKBqq5ZET5oUBCG9Ua9vTmKsaoBe73o+UNJdqVi5GsATkedPALhmMAdTL9B25qI14yoQCMDtdtte1FU3VEFBAd58882oH6SqgKwKINrxY6ErPa/Xi5ycnF5px4IgDA9Ub8tAgvCBQACnTp0CgKkDlSGdlQoDeJmI3iGiFZFl+czcHHneAiDfuhMRrSCiWiKqPXLkSMw3UE2/aOaimnFVX1+P4uLiqOai2+3GiBEjMGvWrF4fpJ7xpQfXdYtEVQB+vx/t7e29jqv3E7NirfR3u93Izc0FgJS2mhEEIflYg/P9dYEHAgHU1dUhFAoBwISByhG1op6IznKwfw8zHx/om/fBfGY+SEQTAbxCRO+rK5mZiYitOzHzWgBrAaCioqLX+mjo5qLepFG3XlatWmVc0IkIixYt6vc/os4/qa+vR2lpqWGR9BddoanHdLlcKC0tNepW1Fkrkl4sCMMDq7elPw1mA4EA9u3bZ0peGiix2rQcijwoxjZZGISZFAtmPhj5e5iIngFwEYBWIipg5mYiKgBwOJ7vqWdTWVOB9ZYGfQW8olkUixcvdtS7S9M0uN1urFq1qtcx9I7HeosYu35g1g4A8Wy9IAhCemO9VjkdCqjeoMaDWO6vncw8nZmLoz0AtMVFCgtE5CWiUfpzAP8bwHYAzwFYFtlsGYBnE/H+Kv0NeNmh3i1EC6LX1NTg9OnTCAaDjvLKVXddXl4eNmzYgOrqamzYsMHYRirrBWHoMNhq9oFeq6x1dqNGjQKA2LGDGMSyVJxIlKhb4HwAz0RSb90AfsfMLxLR3wFsJKKvANgHoP++qBhEi5XYWTDR9rf7UugxFE3TMHv2bFulolozTiwMO3edvlx/L3F9CcLQpL/TZnWcXqus76U2rJ0xYwb8fn9Tv95YIapSYeZOACCiGQAOMHOQiKoAlAJYx8zH9W3iDTPvBTDPZnkbgIWJeM9Eo6f8WoNn+pdHVTRWC0MdqNPY2Gi4wKJ9gfTxwqJQBCEzUJWMncKJh4UTr6FdTrK/NgHoJqJzEA6ATwHwu0G96xAjHgVCakt8uywyNSOspKQEK1euBNB7QJfaBTnWe3V2dhpFnOICE4TMYyC/7Wj76OnH8ShFcDJPpYeZNSL6HICHmPkhIto26HfOUKIpH1UZRKt4t8sIswbku7q6YqYL19TUoKOjAwCwcOFClJSUSM2KIAxRoikNPSEo2sjyMWPGAACOH09Ucm50nFgqISK6HuHA+PORZZ7EiZR5qBlZwJnAur5OrzXR+46pWNvFhEKhmHno6v56WqEgCEOL/k6b1dG9Kpqm4cSJEygvLx/Q/oPBiaWyHMANAH7EzI1EVAxg/aDeNckMNOgVD9R0PSLCpEmTsGnTJlRWVtqus1ogur+ztbUVzc3NxvJgMGibh25NKxQrRRDSn6qqKtTW1mLixIlYu3Ztrxo0AKYJsZqmQdM0Q+lYJ8jq3gp99pOTJKN4Eav4cS2AvwB4lZm/ri9n5kYAP46bBBmOdR50bm6u8QFb1+mzoq0uK73+RCUrK8tWYViDbrt378bEiRMdfbEEQUgNgUDAiLnefPPNJpe3ekM5f/58ZGdno7MznCNVX19vUkILFy7EsmXLjOPqs5/0376qfBJ1PYjl/noE4QyszUS0hYi+SUS9MrKE2Fi7IauKwOraAqK7rKwTKNesWRPVCtGDbh0dHcYXNdkzFQRBcI76m9c9FzofffSRaV2kjYrxetOmTaZyBCv67KdkzViJqlSY+W1mvpuZP4FwPUgTgDsiregfJaK41ohkKrHmy0ebFW2nLPS290QEr9eLFStW9NpGJRAIoKGhwXgdDAZlbr0gpCnqbz4nJ8d4eL1eZGdnG+tcLheysrJMr6+99lpT30JrHCUUCvXqwmE3YyVeOImp6PUhv488QEQfA/DphEiUgcTK/dbX5efnG4WMejqx3bYul8uoeYn1pQgEAqY+PllZWZJaL
Ahphvqb9Hg8GD9+PDZt2oQrrrgCAHDy5EnDbc7MmDp1Knbv3m3sM3bsWMydO9dUCG29LjCz8T5OWk4Nlj6VChGNAbAUwDR1ezXOIgwetZDR6vdUOymPHDkyahqhTnV1NWpqajB//nzDlF6zZo3EVAQhDVGbwLa2thrLNU0zLIvOzk54vd5eXczb2tqMOKza5l5XHgBwzjnnGMvtunDEGyeWymYAbwH4B4D4dBwTDFSFAZizxey6DGuaZtSp9NXGZevWrVi8eDF8Pl+f7jJBEFKDWnLQ09ODBQsWGM1lVU6fPm1kdalYOxLHqpAfSBuX/uJEqeQy838kTIJhhBMfpp3fU83cUAPvffX2qqysHFB7fUEQkofu1tZ/926321AqwWDQ2C5aF2G7jsTxaLcyUJwolfVE9O8IFz4a/yEzH02YVBmGU2UChC2VaH7PRYsWGZW00tZeEDIDPQln586dRt++kSNHori4OOpoX73xo6ZpMV1Zuiu9vLw8aUrGiVLpAnAfgLsQnsaIyN/piRIq3iQjNzteROs+DDif7aIiGV+CkJ6oBY8+nw8XX3wxgHALFk3TsG/fvl77EBHcbjfmzJljcnfZocZq9CLKbdsS32HLiVK5A8A5zPxRn1umIX3FKNKRaH7PWApHEIShherOVivnT506hVAoZIqf6B038vPz0djY6MjqsMZqktWyyUnvrwYApxItSKJIVm52stALG0WhCMLQxlrw2Nraivr6elNxo868efMwa9Ysxy4sfWBftMLrROJEqXQA8BPRr4no5/ojUQIR0RQieo2I3iOiHUT0jcjyu4noYKT40k9En3FyPLVqXSYhCoKQKPrbil69yOfl5eHqq6+2Dcar1fX9QS28njFjhqk3WCJx4v76U+SRLDQAdzDzu5GRwu8Q0SuRdQ8y8/39OVimuYzSwdJKZYNOQUh3nP4+9BlKEydONLqW/+pXv+q1HTOjvr7eiIkMRHElMwTQp1Jh5icS9u7279cMoDny/CQR7QRw9mCOmYzc7ExElIcgJI5AIABN00zXJo/HY+v+GkxMRI2tJCNrNKr7K9KlOCZOthkMRDQNQDmAtyOLbiGi+kjvsbFR9llBRLVEVHvkyJFEiicIgjAg9ASiYDAIv99v9OtS+3oRkeH6Umcw6X28nMw+scZWkhECiGWpXENEsWbQE4AFcZbnzMGJRiI8yvg2Zj5BRA8D+CHC6cw/BPAAgH+z7sfMaxEee4yKigq2rhcSi1g3gtA3agIRcGZWit7SHgi3V/F6vYN23Sc7BBBLqfxfB/v/NV6CqBCRB2GFsoGZnwYAZm5V1v8GZ6ZQCjYk8uI+lOp+BCFZ6L8LJ9Xsas2Znpll/a3qrrF4uO6TGQKIqlSSHUvRobC99wiAncz8U2V5QSTeAgCfA7A9FfIlg3S5y7dTHmrdz/z585NWUCUI6Yz6uwCAgoKCmDdddv25VEUDIGVtVgaLk5TiZHMpgC8B+KQlffgnRPQPIqpH2O12e0qlzHCiDfRRzXY9eNjfVEpByDSs7ix9imtNTU3U34dec2ZtBJmTk4OysrK43qzFY/a8UxzNU0kmzLwV4XiNlc3JlmU4E62xpfrjSGZBlSCkkr7cyVYrA3BebK16BHw+H3Jzc4f076pPS4WI5iZDECG+BAIBo0X+QIhWNFpZWQmv14ucnJxekyzj8b6CMBSwWh+6leHxeIxlTjKt9P5cukcgWa1UEokT99cviehvRHQTEQ1d9TmMiMcsarUat6SkxDSN0u12295NJWsGtiCkG1VVVWhsbMSIESPg9XptfzdWqqursWjRIpNHYNGiRTh+/HjaxFUHgpPix08Q0UyE03ffIaK/AXiMmV/pY9e0YSh/QAMh1kyW/hAti0XTNGiaZgQZdfN93bp1SS2yEoTBEI8MSTtLxO12m+YYRcuWrKqqQiAQSMqI32TidEb9biL6NoBaAD8HUB7J0vpPPeVXSB8G0iLfKTU1NUb31Lq6Opx11lk4e
vQomBmPPvqosd1A31fqXIRMQm0/b9ciRZ+lEq2GZCj+DpzEVEqJ6EEAOwF8EsBnmXl25PmDCZZPGACq6yrefX7ULzkzo62tDczhGtPu7m54PB4QEUpKSsRKETIGp7FCTdPQ2dlpxEbsWqRYybTO404slYcA/BZhq+S0vpCZD0WsFyENiUexk/4DKC8vN8x3uywXnezsbGPdUM5eEQQVu5lMdoWOxcXFxnb19fWYMWMGOjs7QURg5oxxb/WFk0D9M8y8XlUoejt6Zl6fMMmEtMAafAeA0tJSFBQUmLbzeDzYsmULOjs70d3d3euuTmpZhHTDqfVhjVGuW7fO+E3U19fbWiU9PT1oaGhAc3O4XrugoKCX10B//0zI+FJxYqksBbDasuzLAH4Wd2mEtMMu6K/fnZ04cQKaphlztSsrK41tdSU0FCZtCsMPq/VRUlJi2yoF6B2jBGBSHvv27UNRURF8Pp+xndvthqZpAMJu4tzcXNPvQH3/vLw8bNmyJfH/dJKI1aX4eiL6M4BiInpOebwG4GjyRBQGQrwqaPsz5KympsaIrwBAMBgckoFGIfOx3izFshbUSveSkhIsXbrU+E0AwLFjx1BXV4fW1lbMmDEDxcXF+MUvfhFz6mKmTaRViWWp/A/Cc03GI9wRWOckgPpECiWkD3YdTqurq1FTU4NLLrkEQFh5eDwe/OQnPzHtS0SGEgoEAti5cyfKy8ulV5iQcqzWR18xQLXSfeXKlcjLy0N2djaOHTsGIGyNNDc3G7PmV6xYgYcffjhqVlciMzRTTVRLhZn3MXM1M1cy8+vK411m1pIppJBa7LJTrHdWoVAIf/7zn03LPvvZz6KystI0O6K+vl6KIoWUY82Q7G9iidvtRlFRkcliAczDtGJldSUyQzPVxHJ/bY38PUlEJ5THSSI6kTwRhXTE7s5KdX0REe68804AvZtQxjL1pc2LkCwGm8qr15gUFBQYw7T60w8v01KJdWK1vp8f+TsqeeIIQwW9B5heCAkAOTk5mDx5MpqamjB79mzjx2KdHRHN1LdL3cy0H5yQnvQ1I6i6uhrl5eV46623jMQU4Ezqfn5+vuNZKpmOk+LHi4lolPJ6FBF9PLFiCelEtKB/RUUFvF4vPB6PkTJZWFiIiy++2BQ3UQOdpaWlURVFJgcvhfTFSc861YXb0dFhZHbpWNvYA30nyySzHX0ycVKn8jCAduV1R2SZMMwJBALo6OhAKBRCa2trzG3VQKdas2J9nsxZ2oIAOLuZsc5Lyc7OxqpVq5Il4pDCSZ0KseIsZ+YeIkq7OSxC8lm0aBH8fj+AvmMlQLiFRVNTU1QXQbJnaQvDG/37WlNTEzMTy9r4EQinEav1LdZjDmecWCp7iejrROSJPL4BYG+iBYsGEX2aiHYRUQMRfStVcghmyyIvLy+mZbFq1SqcPn26VxWylUwNXgrpi5NMLD0on5WVZSzTW9WLIjHjRKncAOASAAcjj48DWJFIoaJBRFkAfgHgCgDnAbieiM5LhSyC/Y8xmp/YbgwxINleQnrg5GbG5/P1ewjXcMTJPJXDAK5LgixOuAhAAzPvBQAiegrA1QDeS6lUwxinjSutGWA+n88226u/SKt8YTDo3x/djRuLQCCAzs5OAOGU+dWrV4tFbUOfSoWIJiPcqfjSyKK/AvgGMx9IpGBROBvAfuX1AYQtJwMiWoGIJaUOyhFSixovcbvdUYd6xUJVIHKHKCSKqqoq+P1+lJWVobq62kg31hUKcGbsg47c1JzBifvrMQDPASiMPP4cWZaWMPNaZq5g5ooJEyakWhxBQbdq9uzZg8bGRjz66KNG0Vh2djY2btwY9Y5R/6FbnwvCQHDSNdvv96O8vNxIN25paTHWxaq3Gu44USoTmPkxZtYij8cBpOpqfRDAFOX15MgyIUX0N9debQ/e3d2NSZMm9WqV4ff75QcrJJWysjKUlZX1Wq5+X10uFwoKCvqstxruOFEqbUS0hIiyIo8lANr63
Csx/B3ATCIqJqJshGM9z6VIFmEA6O3BgbB1kp+fL9leQtqhT3B0u929vq96vZVgjxOl8m8AFgFoQbhr8ecBLE+kUNGINLK8BcBLCI833sjMO1IhizAw9NTMaI38ot0xOkWGgQn9Qc8+PHTokJGFeOjQIXR0dCAYDGLPnj1GO3v9+6rHWgR7nGR/7QNwVRJkcQQzbwawOdVyCP1H/yFWVVUNaNSxWjypowdRy8vL5e5RsMUuQzAQCKC1tRUtLS2mRqhVVVXo6uoyXvf09KC9vR2zZs0Sa9ohUZUKET0EgKOtZ+avJ0QiYdjQn3Rg/e6xsbERRASXy4Xt27fj6NGjYGZjjoUoFqEv1FR2K6FQqNeylpYW5OfnJ0O0jCCWpVKbNCkEAWesDn0Uq94xtqamBg0NDcZ2zIzu7m5TSqdeUClKRegLax8vFY/Hg1AoZLJemNko1hW3V9/Ean3/hPqaiEYw86nEiyQMR+zuHvWCyOrqatOP3I7+zLEQhjdqIS4QViTTpk2DpmnYsGEDbrrpJuzYscOwWvLy8rBhw4ZUijykcNL6vpKI3gPwfuT1PCL6ZcIlE4YVdnePekGk2mPMjoKCAnF9CVGxtgLSxwEXFxfD6/VixIgRKCwsNLIQfT4fLrnkEpSVlWXkZMZE46Tb8GoAn0IkdZeZ64josoRKJWQ0asBex3r3CJzpraRX4+/cudMYjtTV1YXx48cjPz/fVN8CYFDZY0LmoHcXtmsF5Ha7cfToUZw+fRoej8fWdeq0BZFgxlELe2ber1c+R+hOjDjCcMI6bU9t46K7IvQftD6PRcftdmPWrFmmY3V2doKIjLtSuRgIavGibvmq35Wenh4Eg0HU19ejtLQ0xdJmBk6Uyn4iugQAE5EHwDcQrhERhAFj10zSOmfFqVJQ70YBGBP8xG0h6MW2+qyUcePG2WZ+EREWLVoEwN6SFpzjtPX9zQg3czwIoCzyWhAGzEBHB+uVzuo4V/VuVEfGEWce/S1s1S1htXixra3NNvNL2tjHD6eTHxcnXBJhWKHGUPQf9EsvvRRzn+LiYuMuMxgMGn5w9W5Ux3qRkBb56UmiPhfVEtZrmHSr1fpdidbGXr4rA8OJpfImEb1MRF8hojEJl0gYFjiZtqdSXV2NRYsWGReDrKwsw13h8/kwY8YMY1uZdZE59GWdRFsfbSic/r0bO3assa21jb0wOPpUKsw8C8C3AcwB8C4RPR9pKikIg8I6ba+vjsdqarFuiej7qO4wuUgI6nfF5XLh8OHDhgJqbGxEUVGRab24vuKHE0sFzPw3Zv4PhCcvHgXwRB+7CMKAiaZcYlk3aoBfLhKZibXepKqqCrW1tWhqakJ5ebnpM1e/K1u3bkVFRYXpWGpj061bt4pVG0ecFD+OJqJlRPQXAP+DcKfiixIumSDYYDdLXL/79Hq9A5p1IZ2N0xN1GJseI9Ez+2pqahAIBIx+cPX19YaLSyfa3Hm9y7CTufRC/3ESqK8D8CcAP2DmmgTLIwwjVGskHgFbt9sNt9vtqLJeAvdDg/b2dvj9fttsQVWJ9PT0YN++fVHrk/RMMDVtXT77xODE/TWdmW8XhSIIQqpQYyTMjI0bNxrdFXSOHTtmWDHAGaWhjgS2s2iE+OIkUB+7k58gpBhrexYZRzx08Pv9xucXC7Vfl175vmfPnl7bqf3i9O+AWsekZoIJicFRoD5ZENF9RPQ+EdUT0TN6CjMRTSOi00Tkjzx+lWpZhfTh5MmTOHHiBFatWoXjx49L768hiF1cS1cGgUAAtbW1Rr+uxsZG24JXIOwCtR5HHWFNROjs7DSsGSH+pJVSAfAKgPOZuRTABwBWKuv2MHNZ5HFDasQTUo01M6ympgY9PT1gZpProy+smURCelFTU4O6ujpj5EFHR4eRNg0AK6MAACAASURBVK53VbD0IwQALF++vFdMRc/0KigoAAA0Nzf367si9I+0mvzIzC8rL98C8Pl4v4eQWagKJhgMGgFctVGlF
bXaev78+cjLy+uVciokB03ToGlary7BdjN0Tp8+jezsbHR2dqKjowNEhHHjxpmmfy5dutS0jxqgz83NNY6pu8kk8yv+pPPkx38D8AfldTERbQNwAsC3mfmvdjsR0QoAKwBg6tSpCRdSiA99KQKgd4O/6upqjBs3znjd09OD48eP92pUWVlZaTq+tdpaLZwUkkdNTQ06OjoAoFeXYLuYWE9PDzo7O43XzIzRo0djypQpCAQCpq7W1hn0LpcLM2bM6NUaSIg/jic/xgsiehXAJJtVdzHzs5Ft7gKgAdDHrTUDmMrMbUT0MQB/IqI5zHzCRu61ANYCQEVFhSQZDAHsOhY7vYNUK+ddLhd+85vf2DaqVC0T9eLicrmMLCJJM3ZOPM6Vuq81JbiyshIzZ87E7t27o+6vT/u0zj2xmyJKRFi+fDk2btzYSwEJ8aXPOhUimgDgmwDOA2AMtGDmTw7kDZn58j7e78sA/hnAQj3zjJmDAIKR5+8Q0R4As5B6a0qIA3Y1CLF+8KrVYbVUvF4vjh07BsDcysVqmeizW3w+H3bv3o2mpqZerfeFxGK1FPSU4JKSEvh8PhQWFqKhocF2lLTH48GcOXNsPy+7KaJq01IZvJVYnBQ/bkDYDXUlwm3wlwE4kghhiOjTAO4E8L+Y+ZSyfAKAo8zcTUTTAcwEsDcRMgjJx65jcTSsk/yWLVtmrHO5XPB4PPB6vZg4caLpblS1TNS7W7UqW+9mK6SOrq4u2ymMVkKhEH75y1/aKgf1+0REmDRpEjZt2iSKJEk4USrjmPkRIvoGM78O4HUi+nuC5FkDIAfAK5HMjrcimV6XAfgBEYUA9AC4gZmPJkgGIcmoUx9juSUCgQD27dtnsmpUcnJy4PP5cPToUaP9hq6gVMtEvWBZq7KlhiF52LnOsrOz4Xa7DcsxJyfHFEdRWbdune13Rf0+yUjg5OMkpTgU+dtMRFcSUTmAsxIhDDOfw8xTrKnDzLyJmedEll3AzH9OxPsLqaOvPky6n1x3bQHhC9DSpUuNnl/6/PHOzs6ofaAAoKmpyVivVmXrVkwspE9YmHikZFvPY0FBAVavXo09e/agsbERu3btMrWo7w/65y3uzOTjxFL5LyLyAbgDwEMARgO4PaFSCYIFq5987NixeOGFF1BZWWlKB9ZdY7t27ep1wVNdZ3l5eVi9ejVuvPFGAOFA7owZM+Dz+SRg3weDSaxQqayshNfrRUdHBzweD/Lz802TGa2WqJXy8vIByS8klj6VCjM/H3kaALAgseIIgj2qn9zlcqGoqMh0IfP7/Vi8eHHMEcVqFXZXVxc2bdpkvGZmtLe3o6mpCYcPHzZVZqszy/WWInoCwHCkv4kVTgiFQqivrzeUPBC2RPPz89Ha2oqsrCxommYK2t92222YO3eu7XtH+2yG62eWTKK6v4jozsjfh4jo59ZH8kQUBPN8jNLSUlu3htqOwy7gb11/7bXXGq+BcKV1Y2OjqXrbKVVVVRgzZsywcI3ZDUsbKOp57unpQVtbm8md6fP5QEQYMWIE5s2bZ9rX7sZBSD2xYio7I39rAbxj8xCEpKL7ybdt22Z7MVEHL9m5ZKzrV6xYgdLSUuTk5MDj8Zi2lYLI6DgdBe0k/mSNadltP3LkSJSVlZluCoDBKzQhMcQqfvxz5K9MeRQSzmCL6PSLizXbxzpHw7peb9+haRpCoZBxzFAohLfffhs5OTm93k/TNCNIPVyziuKVVeV2uw135oYN4VpnvcreGq+pqqoCEZm2H67nP52J1fvrz4jd++uqhEgkCHFCDSjrNSjRsoHcbje8Xi86OzvR3d1ttAQJhUIm5REMBtHZ2YnGxkbMnz9/2I2ijbdlUFZWBr/fb2T+TZ8+3VhnjdeosS0Aw+q8DyViBervj/z9F4TbqjwZeX09gNZECiUIdsSyZqJN9rPO0YimVPTGhtZK7O7ubuOO+dChQ6aaiZ6eHmlKOABiZdepn4+4t4YmsdxfrwMAET3Az
GoL1z8TkbRHEdIG1SLJy8sz6lWsGWM+n8/2QqZpmuFysaOrqwvr1q1DQ0ODaTkRmQZBnTp1Ch988EGfbrFMTlnWM+SczrRRXZc6Ho8H48ePN6rgM/l8ZSJOih+9kdYoAAAiKgbgTZxIgtA/7FJcAWcZY9XV1Zg4cWLM42dlZQFArx5U55xzDiorK43ZH6FQCM3NzViwYMGwn9Xh9/tRW1trWyBpLZzUZ+ToNwehUAitreIMGao4USq3A6gmomoieh3AawBuS6xYguCcWCmufWWMAcCGDRtMWUVWiAjl5eW9Mo80TTNa6asKJ9Wpromq+lfrdPp6P936a2xsNA3E0hWHdTlgvjkgIrFMhihOZtS/iHADx28A+DqAEmZ+KdGCCYJTnKa4Otl/5syZICJTirGmaWhrazPSj2fOnAlN04wL47hx40xTCJ3EAvx+/5Coa7EqDD3zza5Hmqp01JRsVclGsyr1/Z3Uv1infwrphdNxwjMBlACYB+ALRLS0j+0FIan01TvM6f6FhYVwuVzIzs421mVnZ2Pjxo1obGw00o/1C+Pp06dx7733Yt68efB4PCgoKMDs2bOxcuXKaG+VcgZaqKl2dPb7/Ub8yI5oSlZVHMxskiHazYGMfh5a9KlUiOh7CPf8egjhNi0/ASDpxELGMnLkSFRUVMDlcoGIjMpuHbUIT22lP2LECMyaNctRE0N9znoiuiIn6iJslbW5uRmXXHIJamvP5O1omob29nYjS46IsHr1akNB6IojJycHpaWltgWq6s1BLHeZkJ44sVQ+D2AhgBZmXo6wtSKtP4WMpaysLKZ7Ra3Mj1X7Eg39jj8YDKK+vr7XhXIwMRH1Ijx//vy4Nl2M9n92dXWhqakJhw4dQkdHB7q7u411zGyazgkA27Ztw+zZs41Ba7GI5S4T0hMnXYpPM3MPEWlENBrAYQBTEiyXIPSLwV5srPvX1NQYF7OFCxdi8uTJ6OzsxOzZs4002IG2VVczm3p6eqLOBXGKmnJrV5tj3U7H7/f32RhTnbLp8/ng9XrR1dXVq/tAY2Oj7f5q2rVOrC7HVlnUtHCpWxkaOLFUaoloDIDfINzz610ACbNBiehuIjpIRP7I4zPKupVE1EBEu4joU4mSQcgcBhrUVfcJBoNoaGhAMBjs1VJfdzWtWrUKx48fR3V1dUz3U01NDVpaWkzLfv3rX6O8vNxU8zJQ95V60bWbD+P3+20zuOyOU15ebnI9HTp0CJqmITs722j6aO2ZZkVPu1bpj/Ux2CQMIfnEtFQoHG1bxczHAfyKiF4EMJqZ6xMs14PMfL+6gIjOA3AdgDkACgG8SkSzmLnb7gCCMBjUO+SsrCzjzly9CFpHG+tFl7FmjVjTj4Gwi0iv9h/srBJ9RommaZg9e7atNaVpmjGrJFZMRx0VoCtWZkYwGITX64Xb7Y4588TlcqGwsLDX8v5aHzK9cWgR01Lh8Ld/s/L6wyQolGhcDeApZg4ycyOABgAXpUgWIcNR75DXrFljm+pqnc9idT/Z3YXbXUBVi8Ju/4HGWDo6OkwWj9/vx8mTJ9HR0YFQKIRQKAS/34+1a9f2ktHv9/dKSFCV4enTp43jAMDo0aNNGV96WvaqVat6ySXWR2bjxP31LhFdmHBJzNxCRPVE9CgR6fNEzwawX9nmQGSZCSJaQUS1RFR75MiRZMgqZCh6JpLeIl+9CFZXV5uKJnVl01ethW5JEBFyc3ONLChdqTit1YhGTU2NkQSwe/duU9aUXW8zALjllltsXW1qQkJxcbFpnfU4J0+exDnnnAMiAhEZFk20jK1t27Zh7969olAyECdK5eMAaohoT+RC/w8iGpS1QkSvEtF2m8fVAB4GMANAGYBmAA/059jMvJaZK5i5YsKECYMRUxAM7Opg7O64nd6FZ2dn49xzz0Vubi58Pp8RRwGAvLw8Y0jVypUr+4yBqDEYu/iE3rssWn+z7u5u2/38fj8aGxuN+h2v1xt1ZjwzQ9M0Q6Go7
x2PJArJ+ho6OMn+intAnJkvd7IdEf0GgD7O+CDMWWeTI8sEIWXY+ftjxQB0SwIIx16ys7PxwQcfoKWlBcyMhQsXwuVyITc319g/1vwWawxm9erVvd5TLeS0Iycnx7bNiqZppiFabrcbRUVFhttPVSAulwtut7uXBSMZW8MPJ21a9tk9EiUQERUoLz8HYHvk+XMAriOinEhTy5kA/pYoOQRBvUOO192yegx9Zktzc7Nxce7q6sLEiRONLr9qFbudK8kag1HH8c6cOdOwmJYu7d0Ew+PxoLi4GCUlJaYOAGodTUdHhymYr7rE5s2bB6/Xi+LiYmzdurXXtMyxY8dKzGQY4rRNSzL5ieJiW4BwQ0sw8w4AGwG8B+BFADdL5pcw1FDv2tXAtk52djbcbrdhmagX9GiBf2sMxu12Izc3F4WFhTFb12RnZ2Pq1Km9MsSsHYJ37tyJ2tpaowOA7gr0+Xxwu93Ge6gxJn06oyiU4YcT91dSYeYvxVj3IwA/SqI4ghCT/lovasrvz3/+c9x4442GK2nSpEm4++67jWV6s0qdaIH/0tJS42IfreeYVU6Px2NybenY1dFMmDABBw4cABB22UXrImCVZaDFocLQJu2UiiAMdfpSNG63G263GytWrMDDDz9sugi3tbWZakMOHgyHDa09tFTUC7ieNqyn86rz3XVcLhdef/11WwVkV0ejywCYJ2jq6c7RZBGGJ+no/hKEjMGuxuT48eM4fvw4AJhcSfr2OllZWcYF3q6Hlo5aJa+nDVvTeSsrK1FWVmZq5KjHidTsMdWdpqMqmby8PFEaQkxEqQhCmuD3+7Fy5Uoj0L5mzRpjncvlcpRFFW2OCRBWYHoKs461C/BNN92EvLw8U/sVIoLb7UZBQUGvjs3REhgkDXj4Iu4vQUgCatNHp8ydOxe5ubkIhUL45S9/2WfQOxAImDoEO0nntWaPBQIBwz2naRpGjx6NlpYWaJpmBPBj/Q+iSARRKoKQQPQuv9FiDepFuKqqykjnBYAFCxYgGAwCAG677TbMnTvXVrGUlZUhEAigrq7O5KrSYzCqYrG29bf24fL5fDh69CgAGFlk+jF7enqwePFi7N27d8DnQ8h8xP0lCAlCdS3V19f3OZArEAhg374zJWBqs8ZYlen6ftYAe7QYjIq1A4BV8an9v/TXghALsVQEIUHEmm1iRa2M18nOzjYslWiuLLv9APs5JtGUkl0HAL34EoCkCQv9QiwVQUgQaiYVEaGzszPqjBRVAQHhzK/XXnsNZWVlMfuIWfdzuVzweDyYN2/egAsPdReZ/rBmqAlCLESpCEKC0F1LBQXhzkPNzc1Ru/ZaU3n13l92jSxj7ZeXl4cRI0aYFEB/hn7F2nYww8OE4YMoFUFIIHoar9rby84NpcY2cnNzoWmao4u3vl9OTo4xOEsNxltThmMdM9a2+kAyJ8cRhjeiVAQhwajB7lhpvnrMoquryyhe1IdcxUoN1hWXXdsVJ6N7VVeX3bbV1dVYtGiR4xHAwvBGAvWCkGD0zr6BQAAbNmyIGeuwmybpFGu6MNC/0b2xtu3vCGBh+CKWiiAkEGuwO5ZCiTZN0sl7FBcX28Y7+jO6N9a2MgJYcIpYKoKQRqidfvuyanSsg7qsF/1YQ8OsxNq2P8cRhi9iqQhCmuHEqlFxEjcRhGQhloogJIHBXOj1Vi9244SB5MU7RFkJTkgrpUJEfwBQEnk5BsBxZi4jomkAdgLYFVn3FjPfkHwJBSHxqBfvvlxbQN8uM1EGQjJJK6XCzF/QnxPRAwDUvhZ7mLms916CkLnYubYk3iGkM2mlVHQoPLx7EYBPploWQUglyU7lFatGGCzpGqj/BIBWZt6tLCsmom1E9DoRfSLajkS0gohqiaj2yJEjiZdUEBKIpPIKQ42kWypE9CqASTar7mLmZyPPrwfwe2VdM4CpzNxGRB8D8CcimsPMJ6wHYea1ANYCQEVFBVvXC8JQQ1xbwlAi6UqFmS+PtZ6I3
AD+BcDHlH2CAIKR5+8Q0R4AswDUJlBUQRgyiNtKSBfS0f11OYD3mfmAvoCIJhBRVuT5dAAzAcj4OUEQhDQjHQP118Hs+gKAywD8gIhCAHoA3MDMR5MumSAIghCTtFMqzPxlm2WbAGxKvjSCIAhCf0g7pSIIghmJlwhDiXSMqQiCIAhDFFEqgiAIQtwQpSIIgiDEDVEqgiAIQtwQpSIIgiDEDVEqgiAIQtwQpSIIgiDEDVEqgiAIQtwQpSIIgiDEDVEqgiAIQtwQpSIIgiDEDVEqgiCYmDNnzoD7jRERGhoa4iuQgGnTpuHVV19NtRiOEKUiCIKJHTt2oKqqKtVi9GL37t3Izc3FkiVLom5z33334fzzz8eoUaNQXFyM++67z7R+wYIFmDBhAkaPHo158+bh2WefNa3/3e9+h6KiIni9XlxzzTU4evTMhI2RI0eaHllZWbj11luN9Rs3bsTs2bMxatQonHfeefjTn/5krHv88ceRlZVl2j9TG4WKUhEEAQCgaVqqRYjJzTffjAsvvDDmNsyMdevW4dixY3jxxRexZs0aPPXUU8b6n/3sZ2hubsaJEyewdu1aLFmyBM3NzQDCyvRrX/sa1q9fj9bWVowYMQI33XSTsW97e7vxaGlpQV5eHv71X/8VAHDw4EEsWbIEP/3pT3HixAncd999+OIXv4jDhw8b+1dWVpqOkY6KOx6IUhGEIc60adOwatUqnHfeeRg7diyWL1+Ozs5OY/3zzz+PsrIyjBkzBpdccgnq6+tN+/74xz9GaWkpvF4vNE0zuVqCwSBuu+02FBYWorCwELfddhuCwaCx/3333YeCggIUFhbi0UcfTdj/+NRTT2HMmDFYuHBhzO3uvPNOXHDBBXC73SgpKcHVV1+NN99801hfWloKtzs88YOIEAqFsH//fgDAhg0b8NnPfhaXXXYZRo4ciR/+8Id4+umncfLkyV7vs2nTJkycOBGf+MQnAAAHDhzAmDFjcMUVV4CIcOWVV8Lr9WLPnj0D+n/Xr1+PoqIijBs3Dj/60Y9M6/72t7+hsrISY8aMQUFBAW655RZ0dXUBCCveO+64w7T9VVddhQcffBAA8OMf/xhnn302Ro0ahZKSEmzZsmVA8sWEmZP+APCvAHYgPMWxwrJuJYAGALsAfEpZ/unIsgYA33LyPh/72MdYEDKdoqIinjNnDjc1NXFbWxtfcsklfNdddzEz87vvvssTJkzgt956izVN48cff5yLioq4s7PT2HfevHnc1NTEp06dMpa98sorzMz8ne98hz/+8Y9za2srHz58mCsrK/nb3/42MzP/5S9/4YkTJ/I//vEPbm9v5+uvv54B8O7du23lvPHGG9nn89k+5s6dG/X/CwQCPHPmTN6/fz9/73vf48WLFzs6Lz09PVxWVsYPP/ywafmVV17JOTk5DIA/9alPcXd3NzMzX3XVVXzvvfeatvV6vVxbW9vr2AsWLODvfe97xmtN0/iyyy7jZ599ljVN42eeeYbPPvtsbm9vZ2bmxx57jEeMGMHjxo3jmTNn8g9+8AMOhUK2cu/YsYO9Xi+//vrr3NnZybfffjtnZWUZn0ltbS3X1NRwKBTixsZGPvfcc/nBBx9kZua3336bCwoKjP/pyJEjnJeXxy0tLfz+++/z5MmT+eDBg8zM3NjYyA0NDbYyAKjlgV7fB7rjYB4AZgMoAVCtKhUA5wGoA5ADoBjAHgBZkcceANMBZEe2Oa+v9xGlIgwHioqKTBfOF154gadPn87MzDfccIOhBHRmzZrF1dXVxr6PPPJIr+PpF7Dp06fzCy+8YKx78cUXuaioiJmZly9fzt/85jeNdbt27YqpVAbK17/+deNi3x+l8t3vfpdLS0sNBarS1dXFmzdv5gceeMBY9slPfrKXAiosLOTXXnvNtOzDDz9kl8vFe/fuNS3/7W9/y16vl7OysjgvL4+ff/55Y92ePXt479693N3dzfX19Tx79my+5557bOX+/ve/z1/4w
heM1+3t7ezxeIzPxMqDDz7I11xzjfH63HPP5ZdffpmZmR966CG+4oormJl59+7dPGHCBH7llVe4q6vL9lg6g1EqKXF/MfNOZt5ls+pqAE8xc5CZGxG2Si6KPBqYeS8zdwF4KrKtIAgApkyZYjwvKirCoUOHAAD79u3DAw88gDFjxhiP/fv3G+ut+1o5dOgQioqKbI996NChXu8bb/x+P1599VXcfvvt/dpvzZo1WLduHV544QXk5OT0Wu/xeHDFFVfg5ZdfxnPPPQcgHIg/ceKEabsTJ05g1KhRpmXr16/H/PnzUVxcbCx79dVXceedd6K6uhpdXV14/fXX8dWvfhV+vx8AMH36dBQXF8PlcmHu3Ln47ne/iz/+8Y+2slvPq9frxbhx44zXH3zwAf75n/8ZkyZNwujRo/Gf//mf+Oijj4z1y5Ytw5NPPgkAePLJJ/GlL30JAHDOOedg9erVuPvuuzFx4kRcd911pu9BvEi3mMrZAPYrrw9ElkVb3gsiWkFEtURUe+TIkYQJKgjphB4XAICmpiYUFhYCCCuMu+66C8ePHzcep06dwvXXX29sT0RRj1tYWIh9+/bZHrugoKDX+8bihhtu6JVBpT/mzJlju091dTU+/PBDTJ06FZMmTcL999+PTZs24YILLoj6Po8++ijuvfdebNmyBZMnT44pk6ZpRtxjzpw5qKurM9bt3bsXwWAQs2bNMu2zbt06LFu2zLTM7/fjsssuQ0VFBVwuFy688EJ8/OMfj5oGTES6d6YX1vN66tQptLW1Ga9vvPFGnHvuudi9ezdOnDiBe+65x3SsJUuW4Nlnn0VdXR127tyJa665xlj3xS9+EVu3bsW+fftARPjmN78Z8/wMiIGaOH09ALwKYLvN42plm2qY3V9rACxRXj8C4PORx2+V5V8CsKYvGcT9JQwHioqK+Pzzz+f9+/dzW1sbX3rppbxy5UpmZv773//OkydP5rfeeot7enq4vb2dn3/+eT5x4oSxr9Wtoi676667uLKykg8fPsxHjhzhSy+91IjXbN68mfPz83nHjh3c0dHBixcvjrv7q6Ojg5ubm43HHXfcwddeey0fPnzYdvsnn3yS8/Pz+b333uu1bufOnbx582Y+deoUd3V18fr169nj8fA777zDzMzbt2/nUaNG8RtvvMHt7e28ePFikxuKmfnNN9/kESNGGOdPp7q6mseNG8fbtm1j5nAs66yzzuKXXnqJmcPnqqWlxZBjzpw5fPfdd9v+D9u3b2ev18t//etfORgM8h133GGKqVx44YX8/e9/n3t6enjnzp08a9YsvvTSS03HuPzyy3nu3Lm8fPlyY9n777/PW7Zs4c7OTg4Gg7x8+XJeunSprQwYajEV4817K5WVAFYqr18CUBl5vBRtu2gPUSrCcKCoqIjvuecenj17Nvt8Pl66dCl3dHQY6//yl79wRUUF+3w+njRpEn/+8593rFROnz7Nt956K0+aNIknTZrEt956K58+fdrYdtWqVZyfn88FBQX8yCOPJCSmomKNqbzxxhvs9XqN19OmTWO3281er9d4fO1rX2Nm5vfee48vuugiHjlyJPt8Pq6oqOCnn37adPwNGzbwlClTeMSIEXzVVVdxW1ubaf2KFSt4yZIltrI99NBDPGPGDB45ciQXFxfz/fffb6y74447eOLEiTxixAguLi7m73znOzHjGo8//jhPmTKFzzrrLP6v//ov02fy+uuvc0lJCXu9Xp4/fz5/5zvf6aVU1q9fzwD4v//7v41ldXV1fOGFF/LIkSN57NixfOWVVxpBeyuDUSrEUUywZEBE1QD+DzPXRl7PAfA7hGMohQC2AJgJgAB8AGAhgIMA/g7gi8y8I9bxKyoquLa2NmHyC0I6MG3aNPz2t7/F5ZdfnmpRhDThjTfewJIlSww3V38honeYuWIg7+0eyE6DhYg+B+AhABMAvEBEfmb+FDPvIKKNAN4DoAG4mZm7I/vcgrDlkgXg0b4UiiAIw
nAkFArhZz/7Gb761a8OSKEMlpQoFWZ+BsAzUdb9CMCPbJZvBrA5waIJgiAMWXbu3ImKigrMmzcPjz32WEpkSIlSEQQhfnz44YepFkFIE2bPno2Ojo6UypBuKcWCIAjCEEaUiiAIghA3RKkIgiAIcSOlKcWJhohOItyEUgDGA/ioz62GB3IuziDn4gxyLs5Qwsyj+t6sN5keqN810FzrTIOIauVchJFzcQY5F2eQc3EGIhpwgZ+4vwRBEIS4IUpFEARBiBuZrlTWplqANELOxRnkXJxBzsUZ5FycYcDnIqMD9YIgCEJyyXRLRRAEQUgiolQEQRCEuJERSoWIPk1Eu4iogYi+ZbM+h4j+EFn/NhFNS76UycHBufgPInqPiOqJaAsRxX8GbJrQ17lQtruWiJiIMjad1Mm5IKJFke/GDiL6XbJlTBYOfiNTieg1ItoW+Z18JhVyJhoiepSIDhPR9ijriYh+HjlP9UQUfdymykAHsaTLA+FW+HsATAeQDaAOwHmWbW4C8KvI8+sA/CHVcqfwXCwAMCLy/MbhfC4i240C8AaAt6AMjMukh8PvxUwA2wCMjbyemGq5U3gu1gK4MfL8PAAfplruBJ2LywBcAGB7lPWfAfAXhOdZXQzgbSfHzQRL5SIADcy8l5m7ADwF4GrLNlcDeCLy/I8AFlIqBg0knj7PBTO/xsynIi/fAhB7iPfQxcn3AgB+CODHADqTKVyScXIu/h3AL5j5GAAw8+Eky5gsnJwLBjA68twH4FAS5UsazPwGgKMxNrkawDoO8xaAMURU0NdxM0GpnA1gv/L6QGSZ7TbMrAEIABiXFOmSi5NzofIVhO9ERZOsqAAAA/ZJREFUMpE+z0XEnJ/CzC8kU7AU4OR7MQvALCJ6k4jeIqJPJ0265OLkXNwNYAkRHUB4htOtyREt7ejv9QRA5rdpEaJAREsAVAD4X6mWJRUQkQvATwF8OcWipAtuhF1gVQhbr28Q0VxmPp5SqVLD9QAeZ+YHiKgSwHoiOp+Ze1It2FAgEyyVgwCmKK8nR5bZbkNEboRN2rakSJdcnJwLENHlAO4CcBUzB5MkW7Lp61yMAnA+gGoi+hBhn/FzGRqsd/K9OADgOWYOMXMjgA8QVjKZhpNz8RUAGwGAmWsA5CLcbHK44eh6YiUTlMrfAcwkomIiykY4EP+cZZvnACyLPP88gP/mSCQqw+jzXBBROYBfI6xQMtVvDvRxLpg5wMzjmXkaM09DOL50FTMPuJFeGuPkN/InhK0UENF4hN1he5MpZJJwci6aACwEACKajbBSOZJUKdOD5wAsjWSBXQwgwMzNfe005N1fzKwR0S0AXkI4s+NRZt5BRD8AUMvMzwF4BGETtgHhwNR1qZM4cTg8F/cBGAng/0VyFZqY+aqUCZ0gHJ6LYYHDc/ESgP9NRO8B6Abwf5k546x5h+fiDgC/IaLbEQ7afzkTb0KJ6PcI30iMj8SPvgfAAwDM/CuE40mfAdAA4BSA5Y6Om4HnShAEQUgRmeD+EgRBENIEUSqCIAhC3BClIgiCIMQNUSqCIAhC3BClIgiCIMQNUSqCkCCI6MNIzYcgDBtEqQiCIAhxQ5SKIAwSIppGRO8T0QYi2klEfySiEZHVtxLRu0T0DyI6N7L9RURUE5nX8T9EVBJZPoeI/kZE/sj8ipmR5UuU5b8moqwU/auC0CeiVAQhPpQA+CUzzwZwAuEZPgDwETNfAOBhAP8nsux9AJ9g5nIA3wVwT2T5DQB+xsxlCDf7PBBpE/IFAJdGlncDWJyMf0gQBsKQb9MiCGnCfmZ+M/L8SQBfjzx/OvL3HQD/EnnuA/BExBJhRFpjAKgBcBcRTQbwNDPvJqKFAD4G4O+Rtjp5ADK5Z5swxBGlIgjxwdrvSH+td4Huxpnf2w8BvMbMn6PwaOtqAGDm3xHR2wCuBLCZiL6G8NS9J5h5ZeJEF4T4Ie4vQYgPUyOzNwDgiwC2xtjWhzMtxL+sLySi6QD2MvPPA
TwLoBTAFgCfJ6KJkW3OIqKiOMsuCHFDlIogxIddAG4mop0AxiIcQ4nGTwCsIqJtMHsLFgHYTkR+hGe9rGPm9wB8G8DLRFQP4BUAfY50FYRUIV2KBWGQRFxYzzPz+SkWRRBSjlgqgiAIQtwQS0UQBEGIG2KpCIIgCHFDlIogCIIQN0SpCIIgCHFDlIogCIIQN0SpCIIgCHHj/wOCjk0CqC82NwAAAABJRU5ErkJggg==\n", 127 | "text/plain": [ 128 | "
" 129 | ] 130 | }, 131 | "metadata": { 132 | "tags": [], 133 | "needs_background": "light" 134 | } 135 | } 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": { 141 | "id": "nasxBsFpS63I" 142 | }, 143 | "source": [ 144 | "Fit these data using numpyro:" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "metadata": { 150 | "colab": { 151 | "base_uri": "https://localhost:8080/", 152 | "height": 334 153 | }, 154 | "id": "4Vn6SYstAUdV", 155 | "outputId": "820bf5e5-dd46-480a-f5b0-26c2cb5bf2c7" 156 | }, 157 | "source": [ 158 | "from jax.config import config\n", 159 | "\n", 160 | "config.update(\"jax_enable_x64\", True)\n", 161 | "\n", 162 | "from jax import random\n", 163 | "import jax.numpy as jnp\n", 164 | "\n", 165 | "import numpyro\n", 166 | "import numpyro.distributions as dist\n", 167 | "from numpyro.infer import MCMC, NUTS, init_to_value\n", 168 | "\n", 169 | "from kepler_jax import kepler\n", 170 | "\n", 171 | "\n", 172 | "def model(t, rv_err, rv=None):\n", 173 | " # Parameters\n", 174 | " K = numpyro.sample(\"K\", dist.Uniform(10.0, 100.0))\n", 175 | " P = numpyro.sample(\"P\", dist.LogNormal(np.log(4.23), 5.0))\n", 176 | " ecc = numpyro.sample(\"ecc\", dist.Uniform(0.0, 1.0))\n", 177 | "\n", 178 | " # Handle wrapping of angles appropriately\n", 179 | " phi_angle = numpyro.sample(\"phi_angle\", dist.Normal(0.0, 1.0), sample_shape=(2,))\n", 180 | " phi = numpyro.deterministic(\"phi\", jnp.arctan2(phi_angle[0], phi_angle[1]))\n", 181 | " w_angle = numpyro.sample(\"w_angle\", dist.Normal(0.0, 1.0), sample_shape=(2,))\n", 182 | " norm = jnp.sqrt(jnp.sum(w_angle ** 2))\n", 183 | " sinw = w_angle[0] / norm\n", 184 | " cosw = w_angle[1] / norm\n", 185 | " omega = numpyro.deterministic(\"omega\", jnp.arctan2(sinw, cosw))\n", 186 | "\n", 187 | " # RV trend parameters\n", 188 | " rv0 = numpyro.sample(\"rv0\", dist.Normal(0.0, 10.0))\n", 189 | " rv_trend = numpyro.sample(\"rv_trend\", dist.Normal(0.0, 10.0))\n", 190 | "\n", 191 | " # 
Deterministic transformations\n", 192 | " bkg = numpyro.deterministic(\"bkg\", rv0 + rv_trend * t / 365.25)\n", 193 | " mean_anom = 2 * np.pi * t / P - (phi + omega)\n", 194 | "\n", 195 | " # Solve Kepler's equation\n", 196 | " sinE, cosE = kepler(mean_anom, ecc)\n", 197 | "\n", 198 | " # MAGIC: convert to true anomaly\n", 199 | " tanf2 = jnp.sqrt((1 + ecc) / (1 - ecc)) * sinE / (1 + cosE)\n", 200 | " sinf = 2 * tanf2 / (1 + tanf2 ** 2)\n", 201 | " cosf = (1 - tanf2 ** 2) / (1 + tanf2 ** 2)\n", 202 | "\n", 203 | " # Evaluate the RV model\n", 204 | " rv_model = numpyro.deterministic(\n", 205 | " \"rv_model\", bkg + K * (cosw * (cosf + ecc) - sinw * sinf)\n", 206 | " )\n", 207 | "\n", 208 | " # Condition on the observations\n", 209 | " numpyro.sample(\"obs\", dist.Normal(bkg + rv_model, rv_err), obs=rv)\n", 210 | "\n", 211 | "\n", 212 | "# It's often useful to initialize well\n", 213 | "init_values = {\n", 214 | " \"K\": 56.0,\n", 215 | " \"P\": 4.230785,\n", 216 | " \"ecc\": 0.01,\n", 217 | " \"phi_angle\": np.array([0.85, 0.5]),\n", 218 | " \"w_angle\": np.array([0.0, 1.0]),\n", 219 | " \"rv0\": -1.8,\n", 220 | " \"rv_trend\": -1.6,\n", 221 | "}\n", 222 | "\n", 223 | "nuts_kernel = NUTS(\n", 224 | " model,\n", 225 | " dense_mass=True,\n", 226 | " target_accept_prob=0.95,\n", 227 | " init_strategy=init_to_value(values=init_values),\n", 228 | ")\n", 229 | "mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=1)\n", 230 | "rng_key = random.PRNGKey(42307)\n", 231 | "%time mcmc.run(rng_key, t, rv_err, rv=rv)\n", 232 | "\n", 233 | "samples = mcmc.get_samples()\n", 234 | "plt.hist(samples[\"K\"], 20, histtype=\"step\")\n", 235 | "plt.yticks([])\n", 236 | "plt.xlabel(\"RV semiamplitude [m/s]\");" 237 | ], 238 | "execution_count": 3, 239 | "outputs": [ 240 | { 241 | "output_type": "stream", 242 | "text": [ 243 | "sample: 100%|██████████| 2000/2000 [02:16<00:00, 14.65it/s, 1023 steps of size 2.19e-03. acc. 
prob=0.97]\n" 244 | ], 245 | "name": "stderr" 246 | }, 247 | { 248 | "output_type": "stream", 249 | "text": [ 250 | "CPU times: user 2min 17s, sys: 1.31 s, total: 2min 18s\n", 251 | "Wall time: 2min 18s\n" 252 | ], 253 | "name": "stdout" 254 | }, 255 | { 256 | "output_type": "display_data", 257 | "data": { 258 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAEGCAYAAABbzE8LAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOgElEQVR4nO3dfYxldX3H8c9HtrBTtQ52ocVBGDFitdgCZenDSgurpSursQ+kVWsr2sTYKv6hfZjUhJDSJpOSxrSx2qgxJF2DkPqIi8U1y2K78WGXx+XJAtshOKkKtaMYx1Lh2z9+vwnX2Tt3Zu7DfM+5834lk7lz7vnd8z2/mXzm3N+553ccEQIAbLxnZBcAAJsVAQwASQhgAEhCAANAEgIYAJJsWc/K27Zti+np6RGVAgDj6dZbb30sIk5avnxdATw9Pa3Dhw8PryoA2ARsP9xtOUMQAJCEAAaAJAQwACQhgAEgCQEMAEkIYABIQgADQBICGACSEMAAkGRdV8IB42jH7H7NLyz23X5qckIHZ3YOsSJsFgQwNr35hUXNze7uu/30zN4hVoPNhCEIAEhCAANAEgIYAJIwBoyxMMiJtKnJiSFXA6wNAYyxMOiJNCADQxAAkIQABoAkBDAAJCGAASAJAQwASQhgAEhCAANAEgIYAJIQwACQhAAGgCQEMAAkIYABIAkBDABJCGAASEIAA0ASAhgAkhDAAJCEAAaAJAQwACQhgAEgCQEMAEkIYABIQgADQBICGACSEMAAkIQABoAkBDAAJCGAASAJAQwASQhgAEhCAANAEgIYAJIQwACQhAAGgCQEMAAkIYABIMmW7AKAJTtm92t+YbGvtlOTE0OuBhg9AhiNMb+wqLnZ3dllABuGIQgASEIAA0AShiCAAU1NTmh6Zm/fbQ/O7BxyRWgLAhgY0CAB2m9wYzwwBAEASQhgAEhCAANAEgIYAJIQwACQhAAGgCQEMAAkIYABIAkBDABJCGAASEIAA0ASAhgAkhDAAJCEAAaAJAQwACQhgAEgCQEMAEkIYABIQgADQBICGACSEMAAkIS7IgOJuKX95kYAA4m4pf3mxhAEACQhgAEgCQEMAEkIYABIwkk4DNWO2f2aX1jsq+3U5MSQqwGajQDGUM0vLGpudnd2GUArMAQBAEkIYABIQgADQBICGACSEMAAkIQABoAkBDAAJCGAASAJAQwASQhgAEhCAANAEgIYAJIQwACQhAAGgCQEMAAkIYABIAkBDABJCGAASEIAA0ASAhgAkhDAAJCEAAaAJAQwACQhgAEgCQEMAEkIYABIQgADQJIt2QWgWXbM7tf8wmLf7acmJ4ZYDTDeCGD8iPmFRc3N7s4uA9gUGIIAgCQEMAAkIYABIAkBDABJCGAASEIAA0ASAhgAkhDAAJCEAAaAJAQwACQhgAEgCXNBAC01NTmh6Zm9fbc9OLNzyBVhvQhgoKUGCdB+gxvDxRAEACQhgAEgCQEMAEkIYABIQgADQBICGACSEMAAkIQABoAkXIgxhga5tTy3lQc2DgE8hri1PNAODEEAQBICGACSEMAAkIQABoAkBDAAJCGAASAJAQwASQhgAEhCAANAEgIYAJIQwACQhAAGgCQEMAAkIYABIAkBDABJCGAASEIAA0ASAhgAkhDAAJCE
AAaAJAQwACQhgAEgCQEMAEkIYABIQgADQBICGACSEMAAkIQABoAkW7ILALDxpiYnND2zt++2B2d2DrmizYkABjahQQK03+DGsRiCAIAkBDAAJCGAASAJAQwASQhgAEhCAANAEgIYAJIQwACQhAAGgCQEMAAkIYABIAkBDABJCGAASEIAA0ASpqMEsC6DzCW81J75hAsCGMC6DBqezCf8NIYgACAJAQwASRiCaKgds/s1v7DYV9upyYkhVwNgFAjghppfWNTc7O7sMgCMEEMQAJCEAAaAJAQwACQhgAEgCQEMAEkIYABIQgADQBICGACSEMAAkIQABoAkBDAAJGEuiBFiQh0AvRDAI8SEOsCxBrmjxrjdTYMABrChBgnQcbubBmPAAJCEAAaAJAQwACQhgAEgCQEMAEkIYABIwsfQAGwKg14YNYrPH499AA/S6YPiajagOQa5MGpUnz/esADO+u/D1WjA+Bj0Krqm2bAAbuJ/HwDtMk6XIUuchAOANAQwACQhgAEgiSNi7Svbj0p6eHTl9LRN0mNJ2+5X22puW71S+2puW71S+2puYr2nR8RJyxeuK4Az2T4cEedl17Eebau5bfVK7au5bfVK7au5TfUyBAEASQhgAEjSpgD+YHYBfWhbzW2rV2pfzW2rV2pfza2ptzVjwAAwbtp0BAwAY4UABoAkjQhg23O2j9i+w/bhZc+923bY3rZC2zfZfqB+vWljKh645idruztsfyarXttX2p7vqOWSFdrusv012w/antmIeodQ84q/n42sty6/3Pb9tu+x/bcrtG1MH6+j5kb0se3rOv4e5mzfsULblD7uKSLSvyTNSdrWZfnzJd2kcvFHt+efK+lo/X5ifXxik2uu63yvCX0s6UpJf7pKu+MkPSTpDEnHS7pT0kubXHOv309CvRdJ+oKkE+rPJ7egj1etuUl9vOz5v5N0RZP6uNdXI46Ae3ivpD+XtNKZwt+QtC8ivh0R/yNpn6RdG1XcClaruW3Ol/RgRByNiCckfUzSa5NrapM/ljQbEf8rSRHxrS7rNK2P11Jz49i2pN+VdG2Xp5vWx5IaMgShElaft32r7bdKku3XSpqPiDt7tJuS9EjHz1+vyzZCvzVL0lbbh21/2fZvjrzS4ph6q3fYvsv2R2yf2KVdo/q4Wq3mXm1Hqds2z5R0ge2v2L7F9vYu7ZrWx2upeaW2o9ZrmxdI+mZEPNClXWYfr6gpd8R4eUTM2z5Z0j7b90v6S0kXJ9fVyyA1n17bniFpv+0jEfHQSKvtXu8HJF2l8kd9lcrbt7eMuI71GKTmY9pGxBcT6t2iMkT2S5K2S7re9hlR3xc3wCA1N6KPO7b5enU/+m2sRhwBR8R8/f4tSZ+U9GuSXiDpTttzkk6VdJvtn17WdF5lzHXJqXXZyA1Qc2fbo5IOSDonod7zI+KbEfFkRDwl6UMqb9OWa1Ifr7Xmrm0z6lU50vpEFF+V9JTKZDGdGtXHWlvNTepj2d4i6bclXbdC07Q+7iU9gG0/0/azlx6rHEEeioiTI2I6IqZV/iDOjYhvLGt+k6SLbZ9Y34peXJc1tuZa6wn18TZJOyTdm1Dv3bZP6VjttyTd3aX5IUkvsv0C28dLep2kkX9yY5CaV2qbUa+kT6mc1JLtM1VOAC2fqatRfbyWmhvWx5L0Skn3R8TXV2ie0seryj4LqHJW8s76dY+k93RZZ071zKek8yR9uOO5t0h6sH69uek1S/oVSUdq2yOS/iirXkn/XGu4S+WP8ZS6/HmSbuxof4mk/1A5i3zMvjat5rX8fjaw3uMl7VEJitsk7WxBH69ac5P6uD53jaS3LVs/vY9X++JSZABIkj4EAQCbFQEMAEkIYABIQgADQBICGACSEMBjzE/Puna37RtsT7rMHnftsvW22X506fPJI67pr2y/csTbOGD7vPr4xrrfk7b/pI/XutD2Z9e5/nds39jHtk6x/fkez99s+3tL+4b2I4DH22JEnB0RZ0n6
tqS3q1w99Ou2f7xjvUsl3RB18pVRiogrIuILo95Ox/YuiYgFSZOS1h3Affq3iOg6TeYqdqnHhUQRcZGkDZn2ERuDAN48viRpKiK+K+kWSa/peO516nINve132r63TnzzsbrsmXUSnK/avt1lAiLZvsz2p2zvc5mT9R2231XX+bLt59b1rrF9aX18he1D9Qj9g7Zdlx+w/V6XCYvus73d9idc5nz+67rOtMt8tR+t6/zLsn8qS/swV684nJX0wvqO4OrlR7a232f7svp4V33t21Qub11ap+u+91K3c4vtT9s+anvW9u/X1zhi+4Udq++S9Ll6JPzFjncvF6y2HbQTAbwJ2D5O0iv09KWX16qErmw/T2X2q/1dms5IOicifk7S2+qy90jaHxHnq1yuenW9LFSSzlIJrO2S/kbS9yPiHJXw/8Mur/++iNhej9AnJL2647knIuI8Sf8k6dMqR+9nSbrM9k/WdV4s6f0R8RJJ31XvI9wZSQ/VdwR/ttJKtreqzDHxGkm/IKlzLo9e+97Lz6v030sk/YGkM+trfFjS5XW7x0l6cUTcK+kNkm6KiLNr264TjKP9CODxNuFyd4BvSPoplfmSJWmvpB22f0Jl/tSPR8STXdrfJemjtt8o6Yd12cWSZurrHpC0VdJp9bmbI+LxiHhU0nck3VCXH5E03eX1L3KZ8vCIpJ2SfrbjuaV/Fkck3RMR/1WHSI7q6UlVHomIg/XxHkkv79kba/Mzkv4zIh6Icpnono7neu17L4c66n9I0tI4b2e//KKkryytL+nNtq+U9LKIeLz/3UGTEcDjbbEeRZ0uySpHkYqIRUn/qjKZTdfhh2q3pH+UdK6kQy4zTlnS79QjybMj4rSIuK+u3zmG/FTHz09p2dSn9Ujz/ZIujYiXqRx1bu1YpbPt8tddeq3l19Gv57r6H+pH//63rrRih1773sta+uVVKr8TRZle8VdVZuu6xna3dw8YAwTwJhAR35f0TknvriEqldB9l8qR8ZeWt7H9DEnPj4ibJf2FpOdIepbKSaLLO8Zr+51KcynwHrP9LJUTget1mu1fro/fIOnfe6z7uKRnd/z8sKSX2j7B9qTKEI0k3S9pumNs9vUdbYa17928QuU2QLJ9usrE4h9SGaY4d4jbQYMQwJtERNyuMqSwFCj7VGaLui66z8h0nKQ9dXjgdkn/UD9NcJWkH5N0l+176s/91LOgctR7t0qwHerjZb4m6e2271O5J+AHemzvvyUdrCe1ro6IRyRdX7d/vco+KiJ+IOmtkvbWk3Cdt+MZyr4vZ/skST/oGGq4UGVe6dsl/Z6kvx/GdtA8zIaGVrI9Lemz9QReY9i+UOWmoa9ebd2ONm+UdGpEzK5h3QP19fk42hhoyi2JgHHxhKSzbN+41s8CR8Se1dcqF2KozIn7fwPUhwbhCBgAkjAGDABJCGAASEIAA0ASAhgAkhDAAJDk/wERYADW8r3OBgAAAABJRU5ErkJggg==\n", 259 | "text/plain": [ 260 | "
" 261 | ] 262 | }, 263 | "metadata": { 264 | "tags": [], 265 | "needs_background": "light" 266 | } 267 | } 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": { 273 | "id": "JYFjEFNbTB5Z" 274 | }, 275 | "source": [ 276 | "This works, but samples a lot slower than the _exoplanet_ example so I'm not totally sure what's up with that!" 277 | ] 278 | } 279 | ] 280 | } -------------------------------------------------------------------------------- /lib/cpu_ops.cc: -------------------------------------------------------------------------------- 1 | // This file defines the Python interface to the XLA custom call implemented on the CPU. 2 | // It is exposed as a standard pybind11 module defining "capsule" objects containing our 3 | // method. For simplicity, we export a separate capsule for each supported dtype. 4 | 5 | #include "kepler.h" 6 | #include "pybind11_kernel_helpers.h" 7 | 8 | using namespace kepler_jax; 9 | 10 | namespace { 11 | 12 | template 13 | void cpu_kepler(void *out_tuple, const void **in) { 14 | // Parse the inputs 15 | const std::int64_t size = *reinterpret_cast(in[0]); 16 | const T *mean_anom = reinterpret_cast(in[1]); 17 | const T *ecc = reinterpret_cast(in[2]); 18 | 19 | // The output is stored as a list of pointers since we have multiple outputs 20 | void **out = reinterpret_cast(out_tuple); 21 | T *sin_ecc_anom = reinterpret_cast(out[0]); 22 | T *cos_ecc_anom = reinterpret_cast(out[1]); 23 | 24 | for (std::int64_t n = 0; n < size; ++n) { 25 | compute_eccentric_anomaly(mean_anom[n], ecc[n], sin_ecc_anom + n, cos_ecc_anom + n); 26 | } 27 | } 28 | 29 | pybind11::dict Registrations() { 30 | pybind11::dict dict; 31 | dict["cpu_kepler_f32"] = EncapsulateFunction(cpu_kepler); 32 | dict["cpu_kepler_f64"] = EncapsulateFunction(cpu_kepler); 33 | return dict; 34 | } 35 | 36 | PYBIND11_MODULE(cpu_ops, m) { m.def("registrations", &Registrations); } 37 | 38 | } // namespace 39 | 
-------------------------------------------------------------------------------- /lib/gpu_ops.cc: -------------------------------------------------------------------------------- 1 | // This file defines the Python interface to the XLA custom call implemented on the GPU. 2 | // Like in cpu_ops.cc, we export a separate capsule for each supported dtype, but we also 3 | // include one extra method "build_kepler_descriptor" to generate an opaque representation 4 | // of the problem size that will be passed to the op. The actually implementation of the 5 | // custom call can be found in kernels.cc.cu. 6 | 7 | #include "kernels.h" 8 | #include "pybind11_kernel_helpers.h" 9 | 10 | using namespace kepler_jax; 11 | 12 | namespace { 13 | pybind11::dict Registrations() { 14 | pybind11::dict dict; 15 | dict["gpu_kepler_f32"] = EncapsulateFunction(gpu_kepler_f32); 16 | dict["gpu_kepler_f64"] = EncapsulateFunction(gpu_kepler_f64); 17 | return dict; 18 | } 19 | 20 | PYBIND11_MODULE(gpu_ops, m) { 21 | m.def("registrations", &Registrations); 22 | m.def("build_kepler_descriptor", 23 | [](std::int64_t size) { return PackDescriptor(KeplerDescriptor{size}); }); 24 | } 25 | } // namespace 26 | -------------------------------------------------------------------------------- /lib/kepler.h: -------------------------------------------------------------------------------- 1 | // This header defines the actual algorithm for our op. It is reused in cpu_ops.cc and 2 | // kernels.cc.cu to expose this as a XLA custom call. The details aren't too important 3 | // except that directly implementing this algorithm as a higher-level JAX function 4 | // probably wouldn't be very efficient. That being said, this is not meant as a 5 | // particularly efficient or robust implementation. It's just here to demonstrate the 6 | // infrastructure required to extend JAX. 
7 | 8 | #ifndef _KEPLER_JAX_KEPLER_H_ 9 | #define _KEPLER_JAX_KEPLER_H_ 10 | 11 | #include 12 | 13 | namespace kepler_jax { 14 | 15 | #ifndef M_PI 16 | #define M_PI 3.14159265358979323846264338327950288 17 | #endif 18 | 19 | #ifdef __CUDACC__ 20 | #define KEPLER_JAX_INLINE_OR_DEVICE __host__ __device__ 21 | #else 22 | #define KEPLER_JAX_INLINE_OR_DEVICE inline 23 | 24 | template 25 | inline void sincos(const T& x, T* sx, T* cx) { 26 | *sx = sin(x); 27 | *cx = cos(x); 28 | } 29 | #endif 30 | 31 | template 32 | KEPLER_JAX_INLINE_OR_DEVICE void compute_eccentric_anomaly(const T& mean_anom, const T& ecc, 33 | T* sin_ecc_anom, T* cos_ecc_anom) { 34 | const T tol = 1e-12; 35 | T g, E = (mean_anom < M_PI) ? mean_anom + 0.85 * ecc : mean_anom - 0.85 * ecc; 36 | for (int i = 0; i < 20; ++i) { 37 | sincos(E, sin_ecc_anom, cos_ecc_anom); 38 | g = E - ecc * (*sin_ecc_anom) - mean_anom; 39 | if (fabs(g) <= tol) return; 40 | E -= g / (1 - ecc * (*cos_ecc_anom)); 41 | } 42 | } 43 | 44 | } // namespace kepler_jax 45 | 46 | #endif -------------------------------------------------------------------------------- /lib/kernel_helpers.h: -------------------------------------------------------------------------------- 1 | // This header is not specific to our application and you'll probably want something like this 2 | // for any extension you're building. This includes the infrastructure needed to serialize 3 | // descriptors that are used with the "opaque" parameter of the GPU custom call. In our example 4 | // we'll use this parameter to pass the size of our problem. 
5 | 6 | #ifndef _KEPLER_JAX_KERNEL_HELPERS_H_ 7 | #define _KEPLER_JAX_KERNEL_HELPERS_H_ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace kepler_jax { 15 | 16 | // https://en.cppreference.com/w/cpp/numeric/bit_cast 17 | template 18 | typename std::enable_if::value && 19 | std::is_trivially_copyable::value, 20 | To>::type 21 | bit_cast(const From& src) noexcept { 22 | static_assert( 23 | std::is_trivially_constructible::value, 24 | "This implementation additionally requires destination type to be trivially constructible"); 25 | 26 | To dst; 27 | memcpy(&dst, &src, sizeof(To)); 28 | return dst; 29 | } 30 | 31 | template 32 | std::string PackDescriptorAsString(const T& descriptor) { 33 | return std::string(bit_cast(&descriptor), sizeof(T)); 34 | } 35 | 36 | template 37 | const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) { 38 | if (opaque_len != sizeof(T)) { 39 | throw std::runtime_error("Invalid opaque object size"); 40 | } 41 | return bit_cast(opaque); 42 | } 43 | 44 | } // namespace kepler_jax 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /lib/kernels.cc.cu: -------------------------------------------------------------------------------- 1 | // This file contains the GPU implementation of our op. It's a pretty typical CUDA kernel 2 | // and I make no promises about the quality of the code or the choices made therein, but 3 | // it should get the point accross. 
4 | 5 | #include "kepler.h" 6 | #include "kernel_helpers.h" 7 | #include "kernels.h" 8 | 9 | namespace kepler_jax { 10 | 11 | namespace { 12 | 13 | template 14 | __global__ void kepler_kernel(std::int64_t size, const T *mean_anom, const T *ecc, T *sin_ecc_anom, 15 | T *cos_ecc_anom) { 16 | for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; 17 | idx += blockDim.x * gridDim.x) { 18 | compute_eccentric_anomaly(mean_anom[idx], ecc[idx], sin_ecc_anom + idx, cos_ecc_anom + idx); 19 | } 20 | } 21 | 22 | void ThrowIfError(cudaError_t error) { 23 | if (error != cudaSuccess) { 24 | throw std::runtime_error(cudaGetErrorString(error)); 25 | } 26 | } 27 | 28 | template 29 | inline void apply_kepler(cudaStream_t stream, void **buffers, const char *opaque, 30 | std::size_t opaque_len) { 31 | const KeplerDescriptor &d = *UnpackDescriptor(opaque, opaque_len); 32 | const std::int64_t size = d.size; 33 | 34 | const T *mean_anom = reinterpret_cast(buffers[0]); 35 | const T *ecc = reinterpret_cast(buffers[1]); 36 | T *sin_ecc_anom = reinterpret_cast(buffers[2]); 37 | T *cos_ecc_anom = reinterpret_cast(buffers[3]); 38 | 39 | const int block_dim = 128; 40 | const int grid_dim = std::min(1024, (size + block_dim - 1) / block_dim); 41 | kepler_kernel 42 | <<>>(size, mean_anom, ecc, sin_ecc_anom, cos_ecc_anom); 43 | 44 | ThrowIfError(cudaGetLastError()); 45 | } 46 | 47 | } // namespace 48 | 49 | void gpu_kepler_f32(cudaStream_t stream, void **buffers, const char *opaque, 50 | std::size_t opaque_len) { 51 | apply_kepler(stream, buffers, opaque, opaque_len); 52 | } 53 | 54 | void gpu_kepler_f64(cudaStream_t stream, void **buffers, const char *opaque, 55 | std::size_t opaque_len) { 56 | apply_kepler(stream, buffers, opaque, opaque_len); 57 | } 58 | 59 | } // namespace kepler_jax 60 | -------------------------------------------------------------------------------- /lib/kernels.h: -------------------------------------------------------------------------------- 1 | #ifndef 
_KEPLER_JAX_KERNELS_H_ 2 | #define _KEPLER_JAX_KERNELS_H_ 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | namespace kepler_jax { 10 | struct KeplerDescriptor { 11 | std::int64_t size; 12 | }; 13 | 14 | void gpu_kepler_f32(cudaStream_t stream, void** buffers, const char* opaque, 15 | std::size_t opaque_len); 16 | void gpu_kepler_f64(cudaStream_t stream, void** buffers, const char* opaque, 17 | std::size_t opaque_len); 18 | 19 | } // namespace kepler_jax 20 | 21 | #endif -------------------------------------------------------------------------------- /lib/pybind11_kernel_helpers.h: -------------------------------------------------------------------------------- 1 | // This header extends kernel_helpers.h with the pybind11 specific interface to 2 | // serializing descriptors. It also adds a pybind11 function for wrapping our 3 | // custom calls in a Python capsule. This is separate from kernel_helpers so that 4 | // the CUDA code itself doesn't include pybind11. I don't think that this is 5 | // strictly necessary, but they do it in jaxlib, so let's do it here too. 
6 | 7 | #ifndef _KEPLER_JAX_PYBIND11_KERNEL_HELPERS_H_ 8 | #define _KEPLER_JAX_PYBIND11_KERNEL_HELPERS_H_ 9 | 10 | #include 11 | 12 | #include "kernel_helpers.h" 13 | 14 | namespace kepler_jax { 15 | 16 | template 17 | pybind11::bytes PackDescriptor(const T& descriptor) { 18 | return pybind11::bytes(PackDescriptorAsString(descriptor)); 19 | } 20 | 21 | template 22 | pybind11::capsule EncapsulateFunction(T* fn) { 23 | return pybind11::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); 24 | } 25 | 26 | } // namespace kepler_jax 27 | 28 | #endif 29 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "kepler_jax" 3 | description = "A simple demonstration of how you can extend JAX with custom C++ and CUDA ops" 4 | readme = "README.md" 5 | authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] 6 | requires-python = ">=3.9" 7 | license = { file = "LICENSE" } 8 | urls = { Homepage = "https://github.com/dfm/extending-jax" } 9 | dependencies = ["jax>=0.4.16", "jaxlib>=0.4.16"] 10 | dynamic = ["version"] 11 | 12 | [project.optional-dependencies] 13 | test = ["pytest"] 14 | 15 | [build-system] 16 | requires = ["pybind11>=2.6", "scikit-build-core>=0.5"] 17 | build-backend = "scikit_build_core.build" 18 | 19 | [tool.scikit-build] 20 | metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" 21 | sdist.include = ["src/kepler_jax/kepler_jax_version.py"] 22 | wheel.install-dir = "kepler_jax" 23 | minimum-version = "0.5" 24 | build-dir = "build/{wheel_tag}" 25 | 26 | [tool.setuptools_scm] 27 | write_to = "src/kepler_jax/kepler_jax_version.py" 28 | -------------------------------------------------------------------------------- /src/kepler_jax/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | __all__ = 
["__version__", "kepler"] 4 | 5 | from .kepler_jax import kepler 6 | from .kepler_jax_version import version as __version__ 7 | -------------------------------------------------------------------------------- /src/kepler_jax/kepler_jax.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | __all__ = ["kepler"] 4 | 5 | from functools import partial 6 | 7 | import numpy as np 8 | from jax import core, dtypes, lax 9 | from jax import numpy as jnp 10 | from jax.core import ShapedArray 11 | from jax.interpreters import ad, batching, mlir, xla 12 | from jax.lib import xla_client 13 | from jaxlib.hlo_helpers import custom_call 14 | 15 | # Register the CPU XLA custom calls 16 | from . import cpu_ops 17 | 18 | for _name, _value in cpu_ops.registrations().items(): 19 | xla_client.register_custom_call_target(_name, _value, platform="cpu") 20 | 21 | # If the GPU version exists, also register those 22 | try: 23 | from . import gpu_ops 24 | except ImportError: 25 | gpu_ops = None 26 | else: 27 | for _name, _value in gpu_ops.registrations().items(): 28 | xla_client.register_custom_call_target(_name, _value, platform="gpu") 29 | 30 | # This function exposes the primitive to user code and this is the only 31 | # public-facing function in this module 32 | 33 | 34 | def kepler(mean_anom, ecc): 35 | # We're going to apply array broadcasting here since the logic of our op 36 | # is much simpler if we require the inputs to all have the same shapes 37 | mean_anom_, ecc_ = jnp.broadcast_arrays(mean_anom, ecc) 38 | 39 | # Then we need to wrap into the range [0, 2*pi) 40 | M_mod = jnp.mod(mean_anom_, 2 * np.pi) 41 | 42 | return _kepler_prim.bind(M_mod, ecc_) 43 | 44 | 45 | # ********************************* 46 | # * SUPPORT FOR JIT COMPILATION * 47 | # ********************************* 48 | 49 | # For JIT compilation we need a function to evaluate the shape and dtype of the 50 | # outputs of our op for some given inputs 51 | 
def _kepler_abstract(mean_anom, ecc): 52 | shape = mean_anom.shape 53 | dtype = dtypes.canonicalize_dtype(mean_anom.dtype) 54 | assert dtypes.canonicalize_dtype(ecc.dtype) == dtype 55 | assert ecc.shape == shape 56 | return (ShapedArray(shape, dtype), ShapedArray(shape, dtype)) 57 | 58 | 59 | # We also need a lowering rule to provide an MLIR "lowering" of out primitive. 60 | # This provides a mechanism for exposing our custom C++ and/or CUDA interfaces 61 | # to the JAX XLA backend. We're wrapping two translation rules into one here: 62 | # one for the CPU and one for the GPU 63 | def _kepler_lowering(ctx, mean_anom, ecc, *, platform="cpu"): 64 | 65 | # Checking that input types and shape agree 66 | assert mean_anom.type == ecc.type 67 | 68 | # Extract the numpy type of the inputs 69 | mean_anom_aval, _ = ctx.avals_in 70 | np_dtype = np.dtype(mean_anom_aval.dtype) 71 | 72 | # The inputs and outputs all have the same shape and memory layout 73 | # so let's predefine this specification 74 | dtype = mlir.ir.RankedTensorType(mean_anom.type) 75 | dims = dtype.shape 76 | layout = tuple(range(len(dims) - 1, -1, -1)) 77 | 78 | # The total size of the input is the product across dimensions 79 | size = np.prod(dims).astype(np.int64) 80 | 81 | # We dispatch a different call depending on the dtype 82 | if np_dtype == np.float32: 83 | op_name = platform + "_kepler_f32" 84 | elif np_dtype == np.float64: 85 | op_name = platform + "_kepler_f64" 86 | else: 87 | raise NotImplementedError(f"Unsupported dtype {np_dtype}") 88 | 89 | # And then the following is what changes between the GPU and CPU 90 | if platform == "cpu": 91 | # On the CPU, we pass the size of the data as a the first input 92 | # argument 93 | return custom_call( 94 | op_name, 95 | # Output types 96 | result_types=[dtype, dtype], 97 | # The inputs: 98 | operands=[mlir.ir_constant(size), mean_anom, ecc], 99 | # Layout specification: 100 | operand_layouts=[(), layout, layout], 101 | result_layouts=[layout, layout] 102 | 
).results 103 | 104 | elif platform == "gpu": 105 | if gpu_ops is None: 106 | raise ValueError( 107 | "The 'kepler_jax' module was not compiled with CUDA support" 108 | ) 109 | # On the GPU, we do things a little differently and encapsulate the 110 | # dimension using the 'opaque' parameter 111 | opaque = gpu_ops.build_kepler_descriptor(size) 112 | 113 | return custom_call( 114 | op_name, 115 | # Output types 116 | result_types=[dtype, dtype], 117 | # The inputs: 118 | operands=[mean_anom, ecc], 119 | # Layout specification: 120 | operand_layouts=[layout, layout], 121 | result_layouts=[layout, layout], 122 | # GPU specific additional data 123 | backend_config=opaque 124 | ).results 125 | 126 | raise ValueError( 127 | "Unsupported platform; this must be either 'cpu' or 'gpu'" 128 | ) 129 | 130 | 131 | # ********************************** 132 | # * SUPPORT FOR FORWARD AUTODIFF * 133 | # ********************************** 134 | 135 | # Here we define the differentiation rules using a JVP derived using implicit 136 | # differentiation of Kepler's equation: 137 | # 138 | # M = E - e * sin(E) 139 | # -> dM = dE * (1 - e * cos(E)) - de * sin(E) 140 | # -> dE/dM = 1 / (1 - e * cos(E)) and de/dM = sin(E) / (1 - e * cos(E)) 141 | # 142 | # In this case we don't need to define a transpose rule in order to support 143 | # reverse and higher order differentiation. This might not be true in other 144 | # applications, so check out the "How JAX primitives work" tutorial in the JAX 145 | # documentation for more info as necessary. 
def _kepler_jvp(args, tangents):
    """Forward-mode (JVP) rule for the ``kepler`` primitive.

    Derived by implicitly differentiating Kepler's equation M = E - e*sin(E):
    dE = (dM + de * sin(E)) / (1 - e * cos(E)); the chain rule then gives the
    tangents of the two primal outputs sin(E) and cos(E).
    """
    mean_anom, ecc = args
    d_mean_anom, d_ecc = tangents

    # Bind the primitive directly rather than calling the public `kepler`
    # wrapper, which would mod the mean anomaly a second time.
    sin_ecc_anom, cos_ecc_anom = _kepler_prim.bind(mean_anom, ecc)

    def _as_concrete(tangent, like):
        # Replace a symbolic ad.Zero with concrete zeros shaped like `like`.
        if type(tangent) is ad.Zero:
            return lax.zeros_like_array(like)
        return tangent

    # dE via the implicit-function theorem.
    denom = 1 - ecc * cos_ecc_anom
    d_ecc_anom = (
        _as_concrete(d_mean_anom, mean_anom)
        + _as_concrete(d_ecc, ecc) * sin_ecc_anom
    ) / denom

    primals_out = (sin_ecc_anom, cos_ecc_anom)
    tangents_out = (cos_ecc_anom * d_ecc_anom, -sin_ecc_anom * d_ecc_anom)
    return primals_out, tangents_out


# ************************************
# * SUPPORT FOR BATCHING WITH VMAP *
# ************************************

# The op already supports arbitrary input dimensions, so the batching rule is
# trivial; jax.lax.linalg contains examples of less trivial batching rules if
# you need such a thing.
175 | def _kepler_batch(args, axes): 176 | assert axes[0] == axes[1] 177 | return kepler(*args), axes 178 | 179 | 180 | # ********************************************* 181 | # * BOILERPLATE TO REGISTER THE OP WITH JAX * 182 | # ********************************************* 183 | _kepler_prim = core.Primitive("kepler") 184 | _kepler_prim.multiple_results = True 185 | _kepler_prim.def_impl(partial(xla.apply_primitive, _kepler_prim)) 186 | _kepler_prim.def_abstract_eval(_kepler_abstract) 187 | 188 | # Connect the XLA translation rules for JIT compilation 189 | for platform in ["cpu", "gpu"]: 190 | mlir.register_lowering( 191 | _kepler_prim, 192 | partial(_kepler_lowering, platform=platform), 193 | platform=platform) 194 | 195 | # Connect the JVP and batching rules 196 | ad.primitive_jvps[_kepler_prim] = _kepler_jvp 197 | batching.primitive_batchers[_kepler_prim] = _kepler_batch 198 | -------------------------------------------------------------------------------- /tests/test_kepler_jax.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | import jax 7 | from jax.test_util import check_grads 8 | 9 | from kepler_jax import kepler 10 | 11 | 12 | jax.config.update("jax_enable_x64", True) 13 | 14 | 15 | @pytest.fixture(params=[np.float32, np.float64]) 16 | def kepler_data(request): 17 | # Note about precision: the precision of the mod function in float32 is not 18 | # great so we're only going to test values in the range ~0-2*pi. 
In real 19 | # world applications, the mod should be done in float64 even if the solve 20 | # is done in float32 21 | ecc = np.linspace(0, 0.9, 55) 22 | true_ecc_anom = np.linspace(-np.pi, np.pi, 101) 23 | mean_anom = true_ecc_anom - ecc[:, None] * np.sin(true_ecc_anom) 24 | dtype = request.param 25 | return ( 26 | mean_anom.astype(dtype), 27 | ecc.astype(dtype), 28 | (true_ecc_anom + np.zeros_like(mean_anom)).astype(dtype), 29 | ) 30 | 31 | 32 | def check_kepler(sin_ecc_anom, cos_ecc_anom, true_ecc_anom): 33 | assert np.all(np.isfinite(sin_ecc_anom)) 34 | np.testing.assert_allclose(sin_ecc_anom, np.sin(true_ecc_anom), atol=1e-5) 35 | assert np.all(np.isfinite(cos_ecc_anom)) 36 | np.testing.assert_allclose(cos_ecc_anom, np.cos(true_ecc_anom), atol=1e-5) 37 | 38 | 39 | def test_kepler(kepler_data): 40 | mean_anom, ecc, true_ecc_anom = kepler_data 41 | sin_ecc_anom, cos_ecc_anom = kepler( 42 | mean_anom, ecc[:, None] + np.zeros_like(mean_anom) 43 | ) 44 | check_kepler(sin_ecc_anom, cos_ecc_anom, true_ecc_anom) 45 | 46 | 47 | def test_kepler_broadcast(kepler_data): 48 | mean_anom, ecc, true_ecc_anom = kepler_data 49 | sin_ecc_anom, cos_ecc_anom = kepler(mean_anom, ecc[:, None]) 50 | check_kepler(sin_ecc_anom, cos_ecc_anom, true_ecc_anom) 51 | 52 | 53 | def test_kepler_jit(kepler_data): 54 | mean_anom, ecc, true_ecc_anom = kepler_data 55 | sin_ecc_anom, cos_ecc_anom = jax.jit(kepler)(mean_anom, ecc[:, None]) 56 | check_kepler(sin_ecc_anom, cos_ecc_anom, true_ecc_anom) 57 | 58 | 59 | def test_kepler_vmap(kepler_data): 60 | mean_anom, ecc, true_ecc_anom = kepler_data 61 | sin_ecc_anom, cos_ecc_anom = jax.vmap(kepler)(mean_anom, ecc) 62 | check_kepler(sin_ecc_anom, cos_ecc_anom, true_ecc_anom) 63 | 64 | 65 | def test_kepler_grad(kepler_data): 66 | mean_anom, ecc, true_ecc_anom = kepler_data 67 | if mean_anom.dtype != np.float64: 68 | pytest.skip("Gradients only stable in double precision") 69 | 70 | m = ecc > 0.01 71 | check_grads( 72 | lambda *args: kepler(*args)[0], 73 
| [mean_anom[m], ecc[m][:, None]], 74 | 2, 75 | eps=1e-6, 76 | ) 77 | check_grads( 78 | lambda *args: kepler(*args)[1], 79 | [mean_anom[m], ecc[m][:, None]], 80 | 2, 81 | eps=1e-6, 82 | ) 83 | --------------------------------------------------------------------------------