├── README.md ├── .gitignore ├── initial_custom_call ├── setup.py ├── custom_call_for_test.pyx ├── test.py ├── README.md └── multiply_add.py └── pybind11_register_custom_call ├── custom_call_for_test.cpp ├── test.py └── README.md /README.md: -------------------------------------------------------------------------------- 1 | # jax_xla_adventures 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | 4 | # Cython generates C++ sources from .pyx files; hand-written ones stay tracked. 5 | *.cpp 6 | !pybind11_register_custom_call/*.cpp 7 | *.so 8 | build 9 | -------------------------------------------------------------------------------- /initial_custom_call/setup.py: -------------------------------------------------------------------------------- 1 | # Build in place with: python setup.py build_ext --inplace 2 | from setuptools import setup 3 | from Cython.Build import cythonize 4 | 5 | setup( 6 | ext_modules=cythonize("custom_call_for_test.pyx") 7 | ) -------------------------------------------------------------------------------- /pybind11_register_custom_call/custom_call_for_test.cpp: -------------------------------------------------------------------------------- 1 | #include <pybind11/pybind11.h> 2 | 3 | namespace py = pybind11; 4 | 5 | // XLA CPU custom-call target computing out = x * y + z. 6 | // data_ptr[0..2] point at the x, y, z operand buffers and out_ptr at the result; 7 | // only the first float32 element of each buffer is read or written. 8 | void multiply_add_f32(void* out_ptr, void** data_ptr) { 9 | float x = ((float*) data_ptr[0])[0]; 10 | float y = ((float*) data_ptr[1])[0]; 11 | float z = ((float*) data_ptr[2])[0]; 12 | float* out = (float*) out_ptr; 13 | out[0] = x*y + z; 14 | } 15 | 16 | PYBIND11_MODULE(custom_call_for_test, m) { 17 | m.doc() = "pybind11 capsuling for registering XLA custom calls"; 18 | // The capsule name marks it as an XLA custom-call target when passed to xla_client.register_cpu_custom_call_target (see test.py). 19 | m.def("return_multiply_add_f32_capsule", 20 | []() { 21 | const char* name = "xla._CUSTOM_CALL_TARGET"; 22 | return py::capsule((void *) &multiply_add_f32, name);}, "Returns a capsule."); 23 | } 24 | -------------------------------------------------------------------------------- /initial_custom_call/custom_call_for_test.pyx: -------------------------------------------------------------------------------- 1 | #
cython: language_level=2 2 | # distutils: language = c++ 3 | 4 | # Test case for defining an XLA custom call target in Cython, and registering 5 | # it via the xla_client SWIG API. 6 | from cpython.pycapsule cimport PyCapsule_New 7 | 8 | 9 | # out_ptr points at the result buffer and data_ptr[i] at operand i; only the 10 | # first float32 element of each buffer is read or written. 11 | cdef void multiply_add_f32(void* out_ptr, void** data_ptr) nogil: 12 | cdef float x = (<float*>(data_ptr[0]))[0] 13 | cdef float y = (<float*>(data_ptr[1]))[0] 14 | cdef float z = (<float*>(data_ptr[2]))[0] 15 | cdef float* out = <float*>(out_ptr) 16 | out[0] = x*y + z 17 | 18 | 19 | # Maps target name (bytes) -> PyCapsule; iterated by test.py to register each target. 20 | cpu_custom_call_targets = {} 21 | 22 | # Wrap `fn` in a PyCapsule named "xla._CUSTOM_CALL_TARGET" and store it under fn_name. 23 | cdef register_custom_call_target(fn_name, void* fn): 24 | cdef const char* name = "xla._CUSTOM_CALL_TARGET" 25 | cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL) 26 | 27 | 28 | register_custom_call_target(b"multiply_add_f32", <void*>(multiply_add_f32)) 29 | -------------------------------------------------------------------------------- /initial_custom_call/test.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jaxlib import xla_client 3 | import custom_call_for_test 4 | 5 | # register the function 6 | for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): 7 | xla_client.register_cpu_custom_call_target(name, fn) 8 | 9 | c = xla_client.ComputationBuilder('comp_builder') 10 | 11 | x, y, z = (2., 0.5, 2.5) 12 | 13 | # The kernel reads and writes a single float32 element, so every shape is a rank-0 f32 array. 14 | f32_scalar = xla_client.Shape.array_shape(jnp.dtype(jnp.float32), (), ()) 15 | 16 | c.CustomCall(b'multiply_add_f32', 17 | operands=(c.ConstantF32Scalar(x), c.ConstantF32Scalar(y), c.ConstantF32Scalar(z)), 18 | shape_with_layout=f32_scalar, 19 | operand_shapes_with_layout=(f32_scalar, f32_scalar, f32_scalar)) 20 | 21 | compiled_c = c.Build().Compile() 22 | result = xla_client.execute_with_python_values(compiled_c, ()) 23 | print("Result: {} Expected: {}".format(result, x*y + z))
-------------------------------------------------------------------------------- /pybind11_register_custom_call/test.py: -------------------------------------------------------------------------------- 1 | import custom_call_for_test 2 | import jax.numpy as jnp 3 | from jaxlib import xla_client 4 | 5 | # Register the C++ kernel, exposed by the extension module as a PyCapsule, under the name used below. 6 | xla_client.register_cpu_custom_call_target( 7 | b'multiply_add_f32', 8 | custom_call_for_test.return_multiply_add_f32_capsule()) 9 | 10 | c = xla_client.ComputationBuilder('comp_builder') 11 | 12 | x, y, z = (0.6, 5., 0.14) 13 | 14 | # The kernel reads and writes a single float32 element, so every shape is a rank-0 f32 array. 15 | f32_scalar = xla_client.Shape.array_shape(jnp.dtype(jnp.float32), (), ()) 16 | 17 | c.CustomCallWithLayout(b'multiply_add_f32', 18 | operands=(c.ConstantF32Scalar(x), c.ConstantF32Scalar(y), c.ConstantF32Scalar(z)), 19 | shape_with_layout=f32_scalar, 20 | operand_shapes_with_layout=(f32_scalar, f32_scalar, f32_scalar)) 21 | 22 | compiled_c = c.Build().Compile() 23 | result = xla_client.execute_with_python_values(compiled_c, ()) 24 | print("Result: {} Expected: {}".format(result, x*y + z)) 25 | -------------------------------------------------------------------------------- /initial_custom_call/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This provides a small working example of registering a 4 | `CustomCall` for JIT compilation of a `JAX` primitive. 5 | 6 | First run 7 | 8 | ```shell script 9 | $ python setup.py build_ext --inplace 10 | ``` 11 | 12 | to compile the `Cython` extension module `custom_call_for_test` in place. Then 13 | 14 | ```shell script 15 | $ python test.py 16 | ``` 17 | 18 | should run without error. However, 19 | 20 | ```shell script 21 | $ python multiply_add.py 22 | ``` 23 | 24 | will fail, because `multiply_add.py` does not register the custom call target with XLA.
Uncomment the lines in `multiply_add.py` 25 | ```python 26 | # for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): 27 | # xla_client.register_cpu_custom_call_target(name, fn) 28 | ``` 29 | to register the `multiply_add_f32` custom call target, and the code should 30 | now run successfully! 31 | 32 | ## References 33 | 34 | This code draws on several sources: 35 | * For an explanation of JAX primitives, see [the JAX docs](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html) 36 | * The `.pyx` file and the implementation of a CustomCall were found in the 37 | test suite of the [TensorFlow XLA compiler](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla), 38 | in particular `tensorflow/compiler/xla/python/custom_call_for_test.pyx` 39 | * Much less helpful was the [XLA CustomCall documentation](https://www.tensorflow.org/xla/custom_call) :( -------------------------------------------------------------------------------- /pybind11_register_custom_call/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | Similar to [the Cython example](https://github.com/danieljtait/jax_xla_adventures/tree/master/initial_custom_call), this gives 4 | a simple example of registering a `CustomCall` using 5 | `pybind11`. 6 | 7 | The basic idea is the same: we 8 | 9 | 1. write our C++ function definition in `custom_call_for_test.cpp` 10 | 2. bundle it up in a `PyCapsule` 11 | 3. finally create a `pybind11` module with a function that can return the capsule. 12 | 13 | We then compile the C++ code, for example using 14 | 15 | ```shell script 16 | % c++ -O3 -Wall -shared -std=c++11 -undefined dynamic_lookup `python3 -m pybind11 --includes` custom_call_for_test.cpp -o custom_call_for_test`python3-config --extension-suffix` 17 | ``` 18 | This command is for macOS; on Linux, replace `-undefined dynamic_lookup` with `-fPIC`. See [the pybind11 docs](https://pybind11.readthedocs.io/en/stable/compiling.html#building-manually) 19 | for any additional guidance needed here.
20 | 21 | We can now get access to this function by importing the module 22 | ```python 23 | >>> import custom_call_for_test 24 | >>> custom_call_for_test.return_multiply_add_f32_capsule() 25 | <capsule object "xla._CUSTOM_CALL_TARGET" at 0x...> 26 | ``` 27 | and register it with XLA in the usual way (as in `test.py`) 28 | ```python 29 | xla_client.register_cpu_custom_call_target( 30 | b'multiply_add_f32', 31 | custom_call_for_test.return_multiply_add_f32_capsule()) 32 | ``` -------------------------------------------------------------------------------- /initial_custom_call/multiply_add.py: -------------------------------------------------------------------------------- 1 | from jax import lax 2 | from jax import abstract_arrays, core, xla, api 3 | import numpy as onp 4 | import jax.numpy as jnp 5 | import custom_call_for_test 6 | from jaxlib import xla_client 7 | """ 8 | See https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html 9 | for an explanation of most of the primitives used here. 10 | """ 11 | multiply_add_p = core.Primitive("multiply_add") # Create the primitive 12 | 13 | 14 | # register the function -- uncomment these lines for this to run successfully 15 | # for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): 16 | # xla_client.register_cpu_custom_call_target(name, fn) 17 | 18 | 19 | def multiply_add_prim(x, y, z): 20 | """The JAX-traceable way to use the JAX primitive. 21 | 22 | Note that the traced arguments must be passed as positional arguments 23 | to `bind`. 24 | """ 25 | return multiply_add_p.bind(x, y, z) 26 | 27 | 28 | def multiply_add_impl(x, y, z): 29 | """Concrete implementation of the primitive. 30 | 31 | This function does not need to be JAX traceable. 32 | Args: 33 | x, y, z: the concrete arguments of the primitive. Will only be called with 34 | concrete values. 35 | Returns: 36 | the concrete result of the primitive, x * y + z.
37 | """ 38 | # Note that we can use the original numpy, which is not JAX traceable 39 | return onp.add(onp.multiply(x, y), z) 40 | 41 | 42 | def multiply_add_abstract_eval(xs, ys, zs): 43 | """Abstract evaluation of the primitive. 44 | 45 | This function does not need to be JAX traceable. It will be invoked with 46 | abstractions of the actual arguments. 47 | Args: 48 | xs, ys, zs: abstractions of the arguments. 49 | Returns: 50 | a ShapedArray for the result of the primitive. 51 | Raises: 52 | ValueError: if the three arguments do not share the same shape. 53 | """ 54 | # Raise rather than assert so the check survives `python -O`. 55 | if not (xs.shape == ys.shape == zs.shape): 56 | raise ValueError("multiply_add: operand shapes must match, got {}, {}, {}".format(xs.shape, ys.shape, zs.shape)) 57 | return abstract_arrays.ShapedArray(xs.shape, xs.dtype) 58 | 59 | 60 | def multiply_add_xla_translation(c, xc, yc, zc): 61 | """The compilation to XLA of the primitive. 62 | 63 | Given an XlaBuilder and XlaOps for each argument, return the XlaOp for the 64 | result of the function. 65 | 66 | Does not need to be a JAX-traceable function. 67 | 68 | NOTE: the registered `multiply_add_f32` kernel reads and writes a single 69 | float32 element, so only float32 scalar operands are supported here. 70 | """ 71 | scalar_f32 = xla_client.Shape.array_shape(jnp.dtype(jnp.float32), (), ()) 72 | return c.CustomCall(b'multiply_add_f32', 73 | operands=(xc, yc, zc), 74 | shape_with_layout=scalar_f32, 75 | operand_shapes_with_layout=(scalar_f32, scalar_f32, scalar_f32)) 76 | 77 | 78 | # Define the concrete implementation 79 | multiply_add_p.def_impl(multiply_add_impl) 80 | # Define the abstract evaluation 81 | multiply_add_p.def_abstract_eval(multiply_add_abstract_eval) 82 | # Register XLA compilation rule 83 | xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation 84 | 85 | x, y, z = (1., 2., 3.) 86 | 87 | jit_res = api.jit(multiply_add_prim)(x, y, z) 88 | res = multiply_add_prim(x, y, z) 89 | 90 | print("Result {} Expected {}".format(jit_res, res)) 91 | --------------------------------------------------------------------------------