├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── doc ├── bgpq.md ├── core_concepts.md ├── hashtable.md ├── queue.md └── stack.md ├── requirements.txt ├── setup.py ├── tests ├── dataclass_test.py ├── hash_test.py ├── heap_test.py ├── queue_test.py └── stack_test.py └── xtructure ├── __init__.py ├── bgpq ├── __init__.py ├── benchmark_merges.py ├── bgpq.py └── merge_split │ ├── __init__.py │ ├── common.py │ ├── loop.py │ ├── parallel.py │ └── split.py ├── core ├── __init__.py ├── field_descriptors.py ├── protocol.py ├── structuredtype.py └── xtructure_decorators │ ├── __init__.py │ ├── annotate.py │ ├── default.py │ ├── hash.py │ ├── indexing.py │ ├── ops.py │ ├── shape.py │ ├── string_format.py │ └── structure_util.py ├── hashtable ├── __init__.py └── hashtable.py ├── queue ├── __init__.py └── queue.py └── stack ├── __init__.py └── stack.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | 9 | - repo: https://github.com/PyCQA/flake8 10 | rev: 6.0.0 11 | hooks: 12 | - id: flake8 13 | name: Check PEP8 14 | args: [--max-line-length=120, "--ignore=E121,E123,E126,E203,E226,E24,E704,W503,W504", "--per-file-ignores=__init__.py:F401"] 15 | 16 | - repo: https://github.com/psf/black 17 | rev: 22.10.0 18 | hooks: 19 | - id: black 20 | args: [--line-length=100] 21 | name: Format code 22 | exclude: docs/source-app 23 | 24 | - repo: https://github.com/asottile/blacken-docs 25 | rev: 1.13.0 26 | hooks: 27 | - id: blacken-docs 28 | args: [--line-length=120] 29 | additional_dependencies: [black==23.1.0] 30 | exclude: docs/source-app 31 | 32 | - repo: https://github.com/PyCQA/isort 33 | rev: 5.12.0 34 | hooks: 35 | - id: isort 36 | name: Format imports 37 | args: [--profile=black] 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 KyuSeok Jung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to 
permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Xtructure
2 |
3 | A Python package providing JAX-optimized data structures, including a batched priority queue and a cuckoo hash table.
4 |
5 | ## Features
6 |
7 | - Stack (`Stack`): A LIFO (Last-In, First-Out) data structure.
8 | - Queue (`Queue`): A FIFO (First-In, First-Out) data structure.
9 | - Batched GPU Priority Queue (`BGPQ`): A batched priority queue optimized for GPU operations.
10 | - Cuckoo Hash Table (`HashTable`): A cuckoo hash table optimized for GPU operations.
11 | - Optimized for JAX.
12 |
13 | ## Installation
14 |
15 | ```bash
16 | pip install xtructure
17 | pip install git+https://github.com/tinker495/xtructure.git # recommended
18 | ```
19 |
20 | The package is under active development, with frequent updates and bug fixes. For the most up-to-date version, installing directly from the Git repository is recommended.
21 |
22 | ## Documentation
23 |
24 | Detailed documentation on how to use Xtructure is available in the `doc/` directory:
25 |
26 | * **[Core Concepts](./doc/core_concepts.md)**: Learn how to define custom data structures using `@xtructure_dataclass` and `FieldDescriptor`.
27 | * **[Stack Usage](./doc/stack.md)**: Guide to using the Stack data structure.
28 | * **[Queue Usage](./doc/queue.md)**: Guide to using the Queue data structure.
29 | * **[BGPQ Usage](./doc/bgpq.md)**: Guide to using the Batched GPU Priority Queue.
30 | * **[HashTable Usage](./doc/hashtable.md)**: Guide to using the Cuckoo hash table.
31 |
32 | Quick examples are provided below for a brief overview.
33 |
34 | ## Quick Examples
35 |
36 | ```python
37 | import jax
38 | import jax.numpy as jnp
39 |
40 | from xtructure import xtructure_dataclass, FieldDescriptor
41 | from xtructure import HashTable, BGPQ
42 |
43 |
44 | # Define a custom data structure using @xtructure_dataclass
45 | @xtructure_dataclass
46 | class MyDataValue:
47 |     a: FieldDescriptor[jnp.uint8]
48 |     b: FieldDescriptor[jnp.uint32, (1, 2)]
49 |
50 |
51 | # --- HashTable Example ---
52 | print("--- HashTable Example ---")
53 |
54 | # Build a HashTable for a custom data structure
55 | key = jax.random.PRNGKey(0)
56 | key, subkey = jax.random.split(key)
57 | hash_table = HashTable.build(MyDataValue, 1, capacity=1000)
58 |
59 | # Insert random data
60 | items_to_insert = MyDataValue.random((100,), key=subkey)
61 | hash_table, inserted_mask, _, _, _ = HashTable.parallel_insert(hash_table, items_to_insert)
62 | print(f"HashTable: Inserted {jnp.sum(inserted_mask)} items. 
Current size: {hash_table.size}") 63 | 64 | # Lookup an item 65 | item_to_find = items_to_insert[0] 66 | _, _, found = HashTable.lookup(hash_table, item_to_find) 67 | print(f"HashTable: Item found? {found}") 68 | 69 | 70 | # --- Batched GPU Priority Queue (BGPQ) Example --- 71 | print("\n--- BGPQ Example ---") 72 | 73 | # Build a BGPQ with a specific batch size 74 | key = jax.random.PRNGKey(1) 75 | pq_batch_size = 64 76 | priority_queue = BGPQ.build( 77 | max_size=2000, 78 | batch_size=pq_batch_size, 79 | pytree_def_type_for_values_class=MyDataValue, 80 | ) 81 | print(f"BGPQ: Built with max_size={priority_queue.max_size}, batch_size={priority_queue.batch_size}") 82 | 83 | # Prepare a batch of keys and values to insert 84 | key, subkey1, subkey2 = jax.random.split(key, 3) 85 | keys_to_insert = jax.random.uniform(subkey1, (pq_batch_size,)).astype(jnp.bfloat16) 86 | values_to_insert = MyDataValue.random((pq_batch_size,), key=subkey2) 87 | 88 | # Insert data 89 | priority_queue = BGPQ.insert(priority_queue, keys_to_insert, values_to_insert) 90 | print(f"BGPQ: Inserted a batch. Current size: {priority_queue.size}") 91 | 92 | # Delete a batch of minimums 93 | priority_queue, min_keys, _ = BGPQ.delete_mins(priority_queue) 94 | valid_mask = jnp.isfinite(min_keys) 95 | print(f"BGPQ: Deleted {jnp.sum(valid_mask)} items. Size after deletion: {priority_queue.size}") 96 | ``` 97 | 98 | ## Working Example 99 | 100 | For a fully functional example using `Xtructure`, check out the [JAxtar](https://github.com/tinker495/JAxtar) repository. `JAxtar` demonstrates how to use `Xtructure` to build a JAX-native, parallelizable A* and Q* solver for neural heuristic search research, showcasing the library in a real, high-performance computing workflow. 101 | 102 | ## Citation 103 | 104 | If you use this code in your research, please cite: 105 | 106 | ``` 107 | @software{kyuseokjung2025xtructure, 108 | title={xtructure: JAX-optimized Data Structures}, 109 | author={Kyuseok Jung}, 110 | url = {https://github.com/tinker495/Xtructure}, 111 | year={2025}, 112 | } 113 | ``` 114 | -------------------------------------------------------------------------------- /doc/bgpq.md: -------------------------------------------------------------------------------- 1 | # `BGPQ` (Batched GPU Priority Queue) Usage 2 | 3 | A priority queue optimized for batched operations on GPUs. It maintains items sorted by a key. 4 | 5 | ```python 6 | import jax 7 | import jax.numpy as jnp 8 | from xtructure import BGPQ, xtructure_dataclass, FieldDescriptor 9 | 10 | 11 | # Define a data structure for BGPQ values (as an example from core_concepts.md) 12 | @xtructure_dataclass 13 | class MyHeapItem: 14 | task_id: FieldDescriptor[jnp.int32] 15 | payload: FieldDescriptor[jnp.float64, (2, 2)] 16 | 17 | 18 | # 1. Build a BGPQ 19 | # BGPQ.build(total_size, batch_size, value_pytree_def_type) 20 | pq_total_size = 2000 # Max number of items 21 | pq_batch_size = 64 # Items to insert/delete per operation 22 | priority_queue = BGPQ.build(pq_total_size, pq_batch_size, MyHeapItem) 23 | # Note: MyHeapItem (the class itself) is passed. 24 | 25 | print(f"BGPQ: Built with max_size={priority_queue.max_size}, batch_size={priority_queue.batch_size}") 26 | 27 | # 2. 
Prepare keys and values to insert 28 | num_items_to_insert_pq = 150 29 | prng_key = jax.random.PRNGKey(10) 30 | keys_for_pq = jax.random.uniform(prng_key, (num_items_to_insert_pq,)).astype(jnp.float16) 31 | prng_key, subkey = jax.random.split(prng_key) 32 | values_for_pq = MyHeapItem.random(shape=(num_items_to_insert_pq,), key=subkey) 33 | 34 | # 3. Insert data into BGPQ in batches 35 | # BGPQ.insert expects keys and values to be shaped to pq_batch_size. 36 | # Loop through data in chunks and use BGPQ.make_batched for padding. 37 | print(f"BGPQ: Starting to insert {num_items_to_insert_pq} items.") 38 | for i in range(0, num_items_to_insert_pq, pq_batch_size): 39 | start_idx = i 40 | end_idx = min(i + pq_batch_size, num_items_to_insert_pq) 41 | 42 | current_keys_chunk = keys_for_pq[start_idx:end_idx] 43 | # For PyTrees (like our MyHeapItem), slice each leaf array 44 | current_values_chunk = jax.tree_util.tree_map(lambda arr: arr[start_idx:end_idx], values_for_pq) 45 | 46 | # Pad the chunk if it's smaller than pq_batch_size 47 | keys_to_insert, values_to_insert = BGPQ.make_batched(current_keys_chunk, current_values_chunk, pq_batch_size) 48 | 49 | priority_queue = BGPQ.insert(priority_queue, keys_to_insert, values_to_insert) 50 | 51 | print(f"BGPQ: Inserted items. Current size: {priority_queue.size}") 52 | 53 | # 4. Delete minimums (deletes a batch of batch_size items) 54 | # BGPQ.delete_mins(heap) 55 | if priority_queue.size > 0: 56 | priority_queue, min_keys, min_values = BGPQ.delete_mins(priority_queue) 57 | # min_keys and min_values will have shape (pq_batch_size, ...) 58 | 59 | # Filter out padded items (keys will be jnp.inf for padding) 60 | valid_mask = jnp.isfinite(min_keys) 61 | actual_min_keys = min_keys[valid_mask] 62 | actual_min_values = jax.tree_util.tree_map(lambda x: x[valid_mask], min_values) 63 | 64 | print(f"BGPQ: Deleted {jnp.sum(valid_mask)} items.") 65 | if jnp.sum(valid_mask) > 0: 66 | print(f"BGPQ: Smallest key deleted: {actual_min_keys[0]}") 67 | # print(f"BGPQ: Corresponding value: {actual_min_values[0]}") # If you want to see the value 68 | print(f"BGPQ: Size after deletion: {priority_queue.size}") 69 | else: 70 | print("BGPQ: Heap is empty, cannot delete.") 71 | ``` 72 | 73 | ## Key `BGPQ` Details 74 | 75 | * **Batched Operations**: All operations (insert, delete_mins) are designed to work on batches of data of size `batch_size`. 76 | * **`BGPQ.build(total_size, batch_size, value_class)`**: 77 | * `total_size`: Desired maximum capacity. The actual `max_size` of the queue might be slightly larger to be an exact multiple of `batch_size` (calculated as `ceil(total_size / batch_size) * batch_size`). 78 | * `batch_size`: The fixed size for all batch operations. 79 | * `value_class`: The *class* of your custom `@xtructure_dataclass` used for storing values in the queue. This class must have a `.default()` method. 80 | * **`BGPQ.make_batched(keys, values, batch_size)`**: (Static method) 81 | * A crucial helper to prepare data for `BGPQ.insert()`. It takes a chunk of keys and corresponding values and pads them to the required `batch_size`. 82 | * Keys are padded with `jnp.inf`. 83 | * Values are padded using `value_class.default()` for the padding portion. 84 | * Returns `batched_keys, batched_values`. 85 | * **`BGPQ.insert(heap, block_key, block_val, added_size=None)`**: 86 | * Inserts a batch of keys and values. Inputs (`block_key`, `block_val`) *must* be pre-batched, typically using `BGPQ.make_batched()`. 
87 | * `added_size` is an optional integer; if not provided, the function counts the number of finite keys in `block_key` to determine how many items are being added. 88 | * **`BGPQ.delete_mins(heap)`**: 89 | * Returns the modified queue, a batch of `batch_size` smallest keys, and their corresponding values. 90 | * **Important**: If the queue contains fewer than `batch_size` items, the returned `min_keys` and `min_values` arrays will be padded (keys with `jnp.inf`, values with their defaults). You **must** use a filter like `valid_mask = jnp.isfinite(min_keys)` to identify and use only the actual (non-padded) items returned. 91 | * **Internal Structure**: The BGPQ maintains a min-heap structure. This heap is composed of multiple sorted blocks, each of size `batch_size`, allowing for efficient batched heap operations. 92 | -------------------------------------------------------------------------------- /doc/core_concepts.md: -------------------------------------------------------------------------------- 1 | # Core Concepts: Defining Custom Data Structures 2 | 3 | Before using `HashTable` or `BGPQ` in xtructure, you often need to define the structure of the data you want to store. This is done using the `@xtructure_dataclass` decorator and `FieldDescriptor`. 4 | 5 | ```python 6 | import jax 7 | import jax.numpy as jnp 8 | from xtructure import xtructure_dataclass, FieldDescriptor 9 | 10 | 11 | # Example: Defining a data structure for HashTable values 12 | @xtructure_dataclass 13 | class MyDataValue: 14 | id: FieldDescriptor[jnp.uint32] 15 | position: FieldDescriptor[jnp.float32, (3,)] # A 3-element vector 16 | flags: FieldDescriptor[jnp.bool_, (4,)] # A 4-element boolean array 17 | 18 | 19 | # Example: Defining a data structure for BGPQ values 20 | @xtructure_dataclass 21 | class MyHeapItem: 22 | task_id: FieldDescriptor[jnp.int32] 23 | payload: FieldDescriptor[jnp.float64, (2, 2)] # A 2x2 matrix 24 | ``` 25 | 26 | ## `@xtructure_dataclass` 27 | 28 | This decorator transforms a Python class into a JAX-compatible structure (specifically, a `chex.dataclass`) and adds several helpful methods and properties: 29 | 30 | * **`shape`** (property): Returns a namedtuple showing the JAX shapes of all fields. 31 | * **`dtype`** (property): Returns a namedtuple showing the JAX dtypes of all fields. 32 | * **`__getitem__(self, index)`**: Allows indexing or slicing an instance (e.g., `my_data_instance[0]`). The operation is applied to each field. 33 | * **`__len__(self)`**: Returns the size of the first dimension of the *first* field, typically used for batch size. 34 | * **`default(cls, shape=())`** (classmethod): Creates an instance with default values for all fields. 35 | * The optional `shape` argument (e.g., `(10,)` or `(5, 2)`) creates a "batched" instance. This means the provided `shape` tuple is prepended to the `intrinsic_shape` of each field defined in the dataclass. 36 | * For example, if a field is `data: FieldDescriptor[jnp.float32, (3,)]` (intrinsic shape `(3,)`): 37 | * Calling `YourClass.default()` or `YourClass.default(shape=())` results in `instance.data.shape` being `(3,)`. 38 | * Calling `YourClass.default(shape=(10,))` results in `instance.data.shape` being `(10, 3)`. 39 | * Calling `YourClass.default(shape=(5, 2))` results in `instance.data.shape` being `(5, 2, 3)`. 40 | * Each field in the instance will be filled with its default value, tiled or broadcasted to this new batched shape. 
41 |     * This method is auto-generated based on `FieldDescriptor` definitions if not explicitly provided.
42 | * **`random(cls, shape=(), key: jax.random.PRNGKey = ...)`** (classmethod): Creates an instance with random data.
43 |     * `shape`: Specifies batch dimensions (e.g., `(10,)` or `(5, 2)`), which are prepended to the `intrinsic_shape` of each field.
44 |         * For example, if a field is `data: FieldDescriptor[jnp.float32, (3,)]` (intrinsic shape `(3,)`):
45 |             * Calling `YourClass.random(key=k)` or `YourClass.random(shape=(), key=k)` results in `instance.data.shape` being `(3,)`.
46 |             * Calling `YourClass.random(shape=(10,), key=k)` results in `instance.data.shape` being `(10, 3)`.
47 |             * Calling `YourClass.random(shape=(5, 2), key=k)` results in `instance.data.shape` being `(5, 2, 3)`.
48 |         * Each field will be filled with random values according to its JAX dtype, and the field arrays will have these new batched shapes.
49 |     * `key`: A JAX PRNG key is required for random number generation.
50 | * **`structured_type`** (property): An enum (`StructuredType.SINGLE`, `StructuredType.BATCHED`, `StructuredType.UNSTRUCTURED`) indicating instance structure relative to its default.
51 | * **`batch_shape`** (property): Shape of batch dimensions if `structured_type` is `BATCHED`.
52 | * **`reshape(self, new_shape)`**: Reshapes batch dimensions.
53 | * **`flatten(self)`**: Flattens batch dimensions.
54 | * **`__str__(self)` / `str(self)`**: Provides a string representation.
55 |     * Handles instances based on their `structured_type`:
56 |         * `SINGLE`: Uses the original `__str__` method of the instance or a custom pretty formatter for a detailed field-by-field view.
57 |         * `BATCHED`: For small batches, all items are formatted. For large batches (controlled by `MAX_PRINT_BATCH_SIZE` and `SHOW_BATCH_SIZE`), it provides a summarized view showing the first few and last few elements, along with the batch shape, using `tabulate` for neat formatting.
58 |         * `UNSTRUCTURED`: Indicates that the data is unstructured relative to its default shape.
59 | * **`default_shape`** (property): Returns a namedtuple showing the JAX shapes of all fields as they would be in an instance created by `cls.default()` (i.e., without any batch dimensions).
60 | * **`at[index_or_slice]`** (property): Provides access to an updater object for out-of-place modifications of the instance's fields at the given `index_or_slice`.
61 |     * `set(values_to_set)`: Returns a new instance with the fields at the specified `index_or_slice` updated with `values_to_set`. If `values_to_set` is an instance of the same dataclass, corresponding fields are used for the update; otherwise, `values_to_set` is applied to all selected field slices.
62 |     * `set_as_condition(condition, value_to_conditionally_set)`: Returns a new instance where fields at the specified `index_or_slice` are updated based on a JAX boolean `condition`. If an element in `condition` is true, the corresponding element in the field slice is updated with `value_to_conditionally_set`.
63 |
64 | ## `FieldDescriptor`
65 |
66 | Defines the type and shape of each field within an `@xtructure_dataclass`. 
67 |
68 | * **Syntax**:
69 |     * `field_name: FieldDescriptor[jax_dtype]`
70 |     * `field_name: FieldDescriptor[jax_dtype, intrinsic_shape_tuple]`
71 |     * `field_name: FieldDescriptor[jax_dtype, intrinsic_shape_tuple, default_fill_value]`
72 |     * Or direct instantiation: `FieldDescriptor(dtype=..., intrinsic_shape=..., fill_value=...)`
73 | * **Parameters**:
74 |     * `dtype`: The JAX dtype (e.g., `jnp.int32`, `jnp.float32`, `jnp.bool_`). Can also be another `@xtructure_dataclass` type for nesting.
75 |     * `intrinsic_shape` (optional): A tuple defining the field's shape *excluding* batch dimensions (e.g., `(3,)` for a vector, `(2,2)` for a matrix). Defaults to `()` for a scalar.
76 |     * `fill_value` (optional): The value used when `cls.default()` is called.
77 |         * Defaults: `-1` (max value) for unsigned integers, `jnp.inf` for signed integers and floats. `None` for nested structures (their own default applies).
78 |
--------------------------------------------------------------------------------
/doc/hashtable.md:
--------------------------------------------------------------------------------
1 | # `HashTable` Usage
2 |
3 | A Cuckoo hash table optimized for JAX.
4 |
5 | ```python
6 | import jax
7 | import jax.numpy as jnp
8 | from xtructure import HashTable, xtructure_dataclass, FieldDescriptor
9 |
10 |
11 | # Define a data structure (as an example from core_concepts.md)
12 | @xtructure_dataclass
13 | class MyDataValue:
14 |     id: FieldDescriptor[jnp.uint32]
15 |     position: FieldDescriptor[jnp.float32, (3,)]
16 |     flags: FieldDescriptor[jnp.bool_, (4,)]
17 |
18 |
19 | # 1. Build the HashTable
20 | # HashTable.build(pytree_def_type, initial_hash_seed, capacity)
21 | table_capacity = 1000
22 | hash_table = HashTable.build(MyDataValue, 123, table_capacity)
23 | # Note: MyDataValue (the class itself) is passed, not an instance, for build.
24 |
25 | # 2. Prepare data to insert
26 | # Let's create some random data.
27 | num_items_to_insert = 100
28 | key = jax.random.PRNGKey(0)
29 | sample_data = MyDataValue.random(shape=(num_items_to_insert,), key=key)
30 |
31 | # 3. Insert data
32 | # HashTable.parallel_insert(table, samples, filled_mask)
33 | # 'filled_mask' indicates which items in 'sample_data' are valid.
34 | filled_mask = jnp.ones(num_items_to_insert, dtype=jnp.bool_)
35 | hash_table, inserted_mask, unique_mask, idxs, table_idxs = HashTable.parallel_insert(
36 |     hash_table, sample_data, filled_mask
37 | )
38 |
39 | print(f"HashTable: Inserted {jnp.sum(inserted_mask)} items.")
40 | print(f"HashTable: Unique items inserted: {jnp.sum(unique_mask)}")  # Number of items that were not already present
41 | print(f"HashTable size: {hash_table.size}")
42 |
43 | # inserted_mask: boolean array, true if the item at the corresponding input index was successfully inserted.
44 | # unique_mask: boolean array, true if the inserted item was unique (not a duplicate).
45 | # idxs: primary indices in the hash table where items were stored.
46 | # table_idxs: cuckoo table indices (0 to CUCKOO_TABLE_N-1) used for each stored item.
47 |
48 | # 4. 
Lookup data
49 | # HashTable.lookup(table, item_to_lookup)
50 | item_to_check = sample_data[0]  # Let's check the first item we inserted
51 | idx, table_idx, found = HashTable.lookup(hash_table, item_to_check)
52 |
53 | if found:
54 |     retrieved_item = hash_table.table[idx, table_idx]  # Accessing the item from the internal table
55 |     print(f"HashTable: Item found at primary index {idx}, cuckoo_index {table_idx}.")
56 |     # You can then compare retrieved_item with item_to_check
57 | else:
58 |     print("HashTable: Item not found.")
59 |
60 | # (Optional) Batching data for insertion if your data isn't already batched appropriately:
61 | # batch_size_for_insert = 50  # Example internal batch size if HashTable has one
62 | # batched_sample_data, filled_mask_for_batched = HashTable.make_batched(
63 | #     MyDataValue,  # The class, not an instance
64 | #     sample_data,
65 | #     batch_size_for_insert
66 | # )
67 | # Then use batched_sample_data and filled_mask_for_batched in parallel_insert.
68 | # Note: The `parallel_insert` in the README example did not require pre-batching with `HashTable.make_batched`.
69 | # The provided `hash.py` also seems to handle arbitrary input sizes for `parallel_insert` with a `filled` mask.
70 | # `HashTable.make_batched` is available if manual batch control is needed.
71 | ```
72 |
73 | ## Key `HashTable` Details
74 |
75 | * **Cuckoo Hashing**: Uses `CUCKOO_TABLE_N` (an internal constant, typically small, e.g., 2-4) hash functions/slots per primary index to resolve collisions. This means an item can be stored in one of `N` locations.
76 | * **`HashTable.build(dataclass, seed, capacity)`**:
77 |     * `dataclass`: The *class* of your custom data structure (e.g., `MyDataValue`). An instance of this class (e.g., `MyDataValue.default()`) is used internally to define the table structure.
78 |     * `seed`: Integer seed for hashing.
79 |     * `capacity`: Desired user capacity. The internal capacity (`_capacity`) will be larger to accommodate Cuckoo hashing (specifically, `int(HASH_SIZE_MULTIPLIER * capacity / CUCKOO_TABLE_N)`).
80 | * **`HashTable.parallel_insert(table, inputs, filled_mask)`**:
81 |     * `inputs`: A PyTree (or batch of PyTrees) of items to insert.
82 |     * `filled_mask`: A boolean JAX array indicating which entries in `inputs` are valid.
83 |     * Returns the updated table, `inserted_mask` (boolean array for successful insertions for each input), `unique_mask` (boolean array, true if the item was new and not a duplicate), `idxs` (main table indices where items were stored), and `table_idxs` (Cuckoo slot indices used).
84 | * **`HashTable.lookup(table, item_to_lookup)`**:
85 |     * Returns `idx` (main table index), `table_idx` (Cuckoo slot index), and `found` (boolean).
86 |     * If `found` is true, the item can be retrieved from `table.table[idx, table_idx]`.
87 | * **`HashTable.make_batched(dataclass, inputs, batch_size)`**: (Static method)
88 |     * A helper to reshape and pad input data into fixed-size batches if needed for specific batch processing workflows, though `parallel_insert` itself can handle variable-sized inputs with a `filled_mask`.
89 |     * `dataclass`: The class of your custom data structure.
90 |     * Returns `batched_pytree, filled_mask_for_batched`.
91 |
--------------------------------------------------------------------------------
/doc/queue.md:
--------------------------------------------------------------------------------
1 | # `Queue` Usage
2 |
3 | A JAX-compatible batched Queue data structure, designed for FIFO (First-In, First-Out) operations. 
It is optimized for parallel execution on hardware like GPUs.
4 |
5 | ```python
6 | import jax
7 | import jax.numpy as jnp
8 | from xtructure import Queue, xtructure_dataclass, FieldDescriptor
9 |
10 |
11 | # Define a data structure to store in the queue
12 | @xtructure_dataclass
13 | class Point:
14 |     x: FieldDescriptor[jnp.uint32]
15 |     y: FieldDescriptor[jnp.uint32]
16 |
17 |
18 | # 1. Build the Queue
19 | # Queue.build(max_size, value_class)
20 | queue = Queue.build(max_size=1000, value_class=Point)
21 |
22 | # 2. Enqueue a single item
23 | p1 = Point(x=jnp.array(1, dtype=jnp.uint32), y=jnp.array(2, dtype=jnp.uint32))
24 | queue = queue.enqueue(p1)
25 | print(f"Queue size after enqueuing one item: {queue.size}")
26 | print(f"Queue head: {queue.head}, Queue tail: {queue.tail}")
27 |
28 |
29 | # 3. Enqueue a batch of items
30 | batch_points = Point(x=jnp.arange(10, dtype=jnp.uint32), y=jnp.arange(10, 20, dtype=jnp.uint32))
31 | queue = queue.enqueue(batch_points)
32 | print(f"Queue size after enqueuing a batch: {queue.size}")
33 | print(f"Queue head: {queue.head}, Queue tail: {queue.tail}")
34 |
35 | # 4. Peek at the front item
36 | # Does not modify the queue
37 | peeked_item = queue.peek()
38 | print("Peeked item:", peeked_item)
39 | assert queue.size == 11  # Unchanged
40 |
41 | # 5. Dequeue a batch of items
42 | # Removes the first 5 items from the queue
43 | queue, dequeued_items = queue.dequeue(5)
44 | print(f"Queue size after dequeuing 5 items: {queue.size}")
45 | print(f"Queue head: {queue.head}, Queue tail: {queue.tail}")
46 | print("Dequeued items (x-values):", dequeued_items.x)
47 |
48 | # 6. Dequeue a single item
49 | queue, dequeued_item = queue.dequeue()
50 | print(f"Queue size after dequeuing one item: {queue.size}")
51 | print("Dequeued item:", dequeued_item)
52 |
53 | # 7. Clear the queue
54 | queue = queue.clear()
55 | print(f"Queue size after clearing: {queue.size}")
56 | print(f"Queue head: {queue.head}, Queue tail: {queue.tail}")
57 | ```
58 |
59 | ## Key `Queue` Details
60 |
61 | * **FIFO Principle**: The first element added to the queue will be the first one to be removed.
62 | * **API Style**: The methods (`enqueue`, `dequeue`, `clear`) are pure: they do not modify the queue in place but return an updated instance, allowing for a chained, functional-style usage pattern.
63 |
64 | * **`Queue.build(max_size, value_class)`**:
65 |     * `max_size` (int): The maximum number of elements the queue can hold.
66 |     * `value_class` (Xtructurable): The class of the data structure to be stored (e.g., `Point`).
67 |
68 | * **`queue.enqueue(items)`**:
69 |     * `items` (Xtructurable): An instance or a batch of instances to add to the end of the queue.
70 |     * Returns the updated `Queue` instance.
71 |
72 | * **`queue.dequeue(num_items=1)`**:
73 |     * `num_items` (int): The number of items to remove from the front of the queue.
74 |     * Returns a tuple containing:
75 |         1. The updated `Queue` instance.
76 |         2. The `Xtructurable` containing the dequeued items.
77 |
78 | * **`queue.peek(num_items=1)`**:
79 |     * `num_items` (int): The number of items to view from the front of the queue.
80 |     * Returns the `Xtructurable` containing the front items without modifying the queue.
81 |
82 | * **`queue.clear()`**:
83 |     * Resets the `head` and `tail` of the queue to 0, effectively emptying it.
84 |     * Returns the updated `Queue` instance.
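85 |
86 | Because the API is pure, a `Queue` can be threaded through `jax.jit`-compiled code as ordinary pytree state. The snippet below is a minimal sketch of this pattern, not part of the documented API above; it assumes only `build`/`enqueue`/`dequeue` as shown, and that `Queue` instances are valid JAX pytrees (which the functional, JIT-friendly design implies).
87 |
88 | ```python
89 | import jax
90 | import jax.numpy as jnp
91 | from xtructure import Queue, xtructure_dataclass, FieldDescriptor
92 |
93 |
94 | @xtructure_dataclass
95 | class Point:
96 |     x: FieldDescriptor[jnp.uint32]
97 |     y: FieldDescriptor[jnp.uint32]
98 |
99 |
100 | @jax.jit
101 | def rotate(queue: Queue) -> Queue:
102 |     # Move the front item to the back: dequeue returns (new_queue, item)
103 |     # and enqueue returns a new Queue, so the whole step stays pure.
104 |     queue, front = queue.dequeue()
105 |     return queue.enqueue(front)
106 |
107 |
108 | queue = Queue.build(max_size=16, value_class=Point)
109 | queue = queue.enqueue(Point(x=jnp.arange(3, dtype=jnp.uint32), y=jnp.arange(3, dtype=jnp.uint32)))
110 | queue = rotate(queue)  # runs under jit; the front item is now at the back, size unchanged
111 | ```
112 |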
--------------------------------------------------------------------------------
/doc/stack.md:
--------------------------------------------------------------------------------
1 | # `Stack` Usage
2 |
3 | A JAX-compatible batched Stack data structure, designed for LIFO (Last-In, First-Out) operations. It is optimized for parallel execution on hardware like GPUs.
4 |
5 | ```python
6 | import jax
7 | import jax.numpy as jnp
8 | from xtructure import Stack, xtructure_dataclass, FieldDescriptor
9 |
10 |
11 | # Define a data structure to store in the stack
12 | @xtructure_dataclass
13 | class Point:
14 |     x: FieldDescriptor[jnp.uint32]
15 |     y: FieldDescriptor[jnp.uint32]
16 |
17 |
18 | # 1. Build the Stack
19 | # Stack.build(max_size, value_class)
20 | stack = Stack.build(max_size=1000, value_class=Point)
21 |
22 | # 2. Push a single item
23 | p1 = Point(x=jnp.array(1, dtype=jnp.uint32), y=jnp.array(2, dtype=jnp.uint32))
24 | stack = stack.push(p1)
25 | print(f"Stack size after pushing one item: {stack.size}")
26 |
27 | # 3. Push a batch of items
28 | batch_points = Point(x=jnp.arange(10, dtype=jnp.uint32), y=jnp.arange(10, 20, dtype=jnp.uint32))
29 | stack = stack.push(batch_points)
30 | print(f"Stack size after pushing a batch: {stack.size}")
31 |
32 | # 4. Peek at the top item
33 | # Does not modify the stack
34 | peeked_item = stack.peek()
35 | print("Peeked item:", peeked_item)
36 | assert stack.size == 11  # Unchanged
37 |
38 | # 5. Pop a batch of items
39 | # Removes the top 5 items from the stack
40 | stack, popped_items = stack.pop(5)
41 | print(f"Stack size after popping 5 items: {stack.size}")
42 | print("Popped items (y-values):", popped_items.y)
43 |
44 | # 6. Pop a single item
45 | stack, popped_item = stack.pop()
46 | print(f"Stack size after popping one item: {stack.size}")
47 | print("Popped item:", popped_item)
48 | ```
49 |
50 | ## Key `Stack` Details
51 |
52 | * **LIFO Principle**: The last element added to the stack will be the first one to be removed.
53 | * **Pure Functional API**: All methods (`push`, `pop`) are pure functions. They do not modify the stack in-place but instead return a new `Stack` instance with the updated state. This is essential for compatibility with JAX's JIT compilation.
54 |
55 | * **`Stack.build(max_size, value_class)`**:
56 |     * `max_size` (int): The maximum number of elements the stack can hold.
57 |     * `value_class` (Xtructurable): The class of the data structure to be stored (e.g., `Point`). This defines the structure of the internal value store.
58 |
59 | * **`stack.push(items)`**:
60 |     * `items` (Xtructurable): An instance or a batch of instances to push onto the stack. If a batch is provided, its first dimension is treated as the batch dimension.
61 |     * Returns a new `Stack` instance with the items added.
62 |
63 | * **`stack.pop(num_items=1)`**:
64 |     * `num_items` (int): The number of items to pop from the top of the stack.
65 |     * Returns a tuple containing:
66 |         1. A new `Stack` instance with the items removed.
67 |         2. The `Xtructurable` containing the popped items.
68 |
69 | * **`stack.peek(num_items=1)`**:
70 |     * `num_items` (int): The number of items to view from the top of the stack.
71 |     * Returns the `Xtructurable` containing the top items without modifying the stack.
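72 |
73 | Because `push` and `pop` are pure, a `Stack` can be carried as loop state in JAX control flow. The sketch below is a minimal illustration of that pattern, not part of the documented API; it assumes only the `build`/`push`/`pop` API above and that `Stack` instances are valid JAX pytrees (which the JIT-compatibility note implies).
74 |
75 | ```python
76 | import jax
77 | import jax.numpy as jnp
78 | from xtructure import Stack, xtructure_dataclass, FieldDescriptor
79 |
80 |
81 | @xtructure_dataclass
82 | class Point:
83 |     x: FieldDescriptor[jnp.uint32]
84 |     y: FieldDescriptor[jnp.uint32]
85 |
86 |
87 | def push_n(stack: Stack, n: int) -> Stack:
88 |     # Push Point(i, 2*i) for i in [0, n). The stack is threaded through
89 |     # fori_loop as loop-carried state, which is why the pure API matters.
90 |     def body(i, s):
91 |         return s.push(Point(x=i.astype(jnp.uint32), y=(2 * i).astype(jnp.uint32)))
92 |
93 |     return jax.lax.fori_loop(0, n, body, stack)
94 |
95 |
96 | stack = Stack.build(max_size=64, value_class=Point)
97 | stack = jax.jit(push_n, static_argnums=1)(stack, 10)
98 | stack, top = stack.pop()  # LIFO: the last item pushed (x == 9) comes out first
99 | ```
100 |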
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jax[cuda]>=0.4.0
2 | chex>=0.1.0
3 | pytest>=7.0.0
4 | tabulate>=0.9.0
5 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 |     long_description = fh.read()
5 |
6 | setuptools.setup(
7 |     name="xtructure",
8 |     version="0.0.17",
9 |     author="tinker495",
10 |     author_email="wjdrbtjr495@gmail.com",
11 |     description="JAX-optimized data structures",
12 |     long_description=long_description,
13 |     long_description_content_type="text/markdown",
14 |     url="https://github.com/tinker495/Xtructure",
15 |     packages=setuptools.find_packages(),
16 |     install_requires=[
17 |         "jax[cuda]>=0.4.0",
18 |         "chex>=0.1.0",
19 |         "tabulate>=0.9.0",
20 |     ],
21 |     extras_require={
22 |         "dev": [
23 |             "pytest>=7.0.0",
24 |         ]
25 |     },
26 |     classifiers=[
27 |         "Programming Language :: Python :: 3",
28 |         "License :: OSI Approved :: MIT License",
29 |         "Operating System :: OS Independent",
30 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
31 |     ],
32 |     python_requires=">=3.8",
33 | )
34 |
--------------------------------------------------------------------------------
/tests/dataclass_test.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 | from xtructure import FieldDescriptor, StructuredType, xtructure_dataclass
5 |
6 |
7 | # Test data structures
8 | @xtructure_dataclass
9 | class SimpleData:
10 |     id: FieldDescriptor[jnp.uint32]
11 |     value: FieldDescriptor[jnp.float32]
12 |
13 |
14 | @xtructure_dataclass
15 | class VectorData:
16 |     position: FieldDescriptor[jnp.float32, (3,)]
17 |     velocity: FieldDescriptor[jnp.float32, (3,)]
18 |
19 |
20 | @xtructure_dataclass
21 | class MatrixData:
22 |     matrix: FieldDescriptor[jnp.float32, (2, 2)]
23 |     flags: FieldDescriptor[jnp.bool_, (4,), False]
24 |
25 |
26 | @xtructure_dataclass
27 | class NestedData:
28 |     simple: FieldDescriptor[SimpleData]
29 |     vector: FieldDescriptor[VectorData]
30 |
31 |
32 | def test_dataclass_default():
33 |     # Test default creation
34 |     simple = SimpleData.default()
35 |     assert simple.id.shape == ()
36 |     assert simple.value.shape == ()
37 |     assert simple.id.dtype == jnp.uint32
38 |     assert simple.value.dtype == jnp.float32
39 |
40 |     # Test batched creation
41 |     batched = SimpleData.default(shape=(10,))
42 |     assert batched.id.shape == (10,)
43 |     assert batched.value.shape == (10,)
44 |
45 |
46 | def test_dataclass_random():
47 |     key = jax.random.PRNGKey(0)
48 |
49 |     # Test random creation
50 |     simple = SimpleData.random(key=key)
51 |     assert simple.id.shape == ()
52 |     assert simple.value.shape == ()
53 |
54 |     # Test batched random creation
55 |     batched = SimpleData.random(shape=(5,), key=key)
56 |     assert batched.id.shape == (5,)
57 |     assert batched.value.shape == (5,)
58 |
59 |
60 | def test_vector_data():
61 |     # Test vector data structure
62 |     vector = VectorData.default()
63 |     assert vector.position.shape == (3,)
64 |     assert vector.velocity.shape == (3,)
65 |
66 |     # Test batched vector data
67 |     batched = VectorData.default(shape=(4,))
68 |     assert batched.position.shape == (4, 3)
69 |     assert batched.velocity.shape == (4, 3)
70 |
71 |
72 | def test_matrix_data():
73 |     # 
Test matrix data structure 74 | matrix = MatrixData.default() 75 | assert matrix.matrix.shape == (2, 2) 76 | assert matrix.flags.shape == (4,) 77 | 78 | # Test batched matrix data 79 | batched = MatrixData.default(shape=(3,)) 80 | assert batched.matrix.shape == (3, 2, 2) 81 | assert batched.flags.shape == (3, 4) 82 | 83 | 84 | def test_nested_data(): 85 | # Test nested data structure 86 | nested = NestedData.default() 87 | assert nested.simple.id.shape == () 88 | assert nested.simple.value.shape == () 89 | assert nested.vector.position.shape == (3,) 90 | 91 | # Test batched nested data 92 | batched = NestedData.default(shape=(2,)) 93 | assert batched.simple.id.shape == (2,) 94 | assert batched.simple.value.shape == (2,) 95 | assert batched.vector.position.shape == (2, 3) 96 | 97 | 98 | def test_structured_type(): 99 | # Test structured type property 100 | simple = SimpleData.default() 101 | assert simple.structured_type == StructuredType.SINGLE 102 | 103 | batched = SimpleData.default(shape=(5,)) 104 | assert batched.structured_type == StructuredType.BATCHED 105 | assert batched.shape.batch == (5,) 106 | 107 | batched2d = SimpleData.default(shape=(5, 10)) 108 | assert batched2d.structured_type == StructuredType.BATCHED 109 | assert batched2d.shape.batch == (5, 10) 110 | 111 | vector = VectorData.default(shape=(5, 10)) 112 | assert vector.structured_type == StructuredType.BATCHED 113 | assert vector.shape.batch == (5, 10) 114 | 115 | matrix = MatrixData.default(shape=(5, 10)) 116 | assert matrix.structured_type == StructuredType.BATCHED 117 | assert matrix.shape.batch == (5, 10) 118 | 119 | nested = NestedData.default(shape=(5, 10)) 120 | assert nested.structured_type == StructuredType.BATCHED 121 | assert nested.shape.batch == (5, 10) 122 | 123 | 124 | def test_reshape(): 125 | # Test reshape functionality 126 | batched = SimpleData.default(shape=(10,)) 127 | reshaped = batched.reshape((2, 5)) 128 | assert reshaped.structured_type == StructuredType.BATCHED 129 | assert reshaped.shape.batch == (2, 5) 130 | assert reshaped.id.shape == (2, 5) 131 | assert reshaped.value.shape == (2, 5) 132 | 133 | batched2d = SimpleData.default(shape=(2, 3)) 134 | reshaped2d = batched2d.reshape((6,)) 135 | assert reshaped2d.structured_type == StructuredType.BATCHED 136 | assert reshaped2d.shape.batch == (6,) 137 | assert reshaped2d.id.shape == (6,) 138 | assert reshaped2d.value.shape == (6,) 139 | 140 | vector = VectorData.default(shape=(10,)) 141 | reshaped_vector = vector.reshape((2, 5)) 142 | assert reshaped_vector.structured_type == StructuredType.BATCHED 143 | assert reshaped_vector.shape.batch == (2, 5) 144 | assert reshaped_vector.position.shape == (2, 5, 3) 145 | assert reshaped_vector.velocity.shape == (2, 5, 3) 146 | 147 | vector2d = VectorData.default(shape=(2, 3)) 148 | reshaped_vector2d = vector2d.reshape((6,)) 149 | assert reshaped_vector2d.structured_type == StructuredType.BATCHED 150 | assert reshaped_vector2d.shape.batch == (6,) 151 | assert reshaped_vector2d.position.shape == (6, 3) 152 | assert reshaped_vector2d.velocity.shape == (6, 3) 153 | 154 | matrix = MatrixData.default(shape=(10,)) 155 | reshaped_matrix = matrix.reshape((2, 5)) 156 | assert reshaped_matrix.structured_type == StructuredType.BATCHED 157 | assert reshaped_matrix.shape.batch == (2, 5) 158 | assert reshaped_matrix.matrix.shape == (2, 5, 2, 2) 159 | 160 | matrix2d = MatrixData.default(shape=(2, 3)) 161 | reshaped_matrix2d = matrix2d.reshape((6,)) 162 | assert reshaped_matrix2d.structured_type == StructuredType.BATCHED 163 
| assert reshaped_matrix2d.shape.batch == (6,) 164 | assert reshaped_matrix2d.matrix.shape == (6, 2, 2) 165 | assert reshaped_matrix2d.flags.shape == (6, 4) 166 | 167 | nested = NestedData.default(shape=(10,)) 168 | reshaped_nested = nested.reshape((2, 5)) 169 | assert reshaped_nested.structured_type == StructuredType.BATCHED 170 | assert reshaped_nested.shape.batch == (2, 5) 171 | assert reshaped_nested.simple.id.shape == (2, 5) 172 | assert reshaped_nested.simple.value.shape == (2, 5) 173 | 174 | nested2d = NestedData.default(shape=(2, 3)) 175 | reshaped_nested2d = nested2d.reshape((6,)) 176 | assert reshaped_nested2d.structured_type == StructuredType.BATCHED 177 | assert reshaped_nested2d.shape.batch == (6,) 178 | assert reshaped_nested2d.simple.id.shape == (6,) 179 | assert reshaped_nested2d.simple.value.shape == (6,) 180 | 181 | 182 | def test_flatten(): 183 | # Test flatten functionality 184 | batched = SimpleData.default(shape=(2, 3)) 185 | flattened = batched.flatten() 186 | print(flattened.structured_type) 187 | assert flattened.structured_type == StructuredType.BATCHED 188 | assert flattened.shape.batch == (6,) 189 | assert flattened.id.shape == (6,) 190 | assert flattened.value.shape == (6,) 191 | 192 | batched2d = SimpleData.default(shape=(2, 3)) 193 | flattened2d = batched2d.flatten() 194 | assert flattened2d.structured_type == StructuredType.BATCHED 195 | assert flattened2d.shape.batch == (6,) 196 | assert flattened2d.id.shape == (6,) 197 | assert flattened2d.value.shape == (6,) 198 | 199 | vector = VectorData.default(shape=(2, 3)) 200 | flattened_vector = vector.flatten() 201 | assert flattened_vector.structured_type == StructuredType.BATCHED 202 | assert flattened_vector.shape.batch == (6,) 203 | assert flattened_vector.position.shape == (6, 3) 204 | assert flattened_vector.velocity.shape == (6, 3) 205 | 206 | matrix = MatrixData.default(shape=(2, 3)) 207 | flattened_matrix = matrix.flatten() 208 | assert flattened_matrix.structured_type == StructuredType.BATCHED 209 | assert flattened_matrix.shape.batch == (6,) 210 | assert flattened_matrix.matrix.shape == (6, 2, 2) 211 | assert flattened_matrix.flags.shape == (6, 4) 212 | 213 | nested = NestedData.default(shape=(2, 3)) 214 | flattened_nested = nested.flatten() 215 | assert flattened_nested.structured_type == StructuredType.BATCHED 216 | assert flattened_nested.shape.batch == (6,) 217 | assert flattened_nested.simple.id.shape == (6,) 218 | assert flattened_nested.simple.value.shape == (6,) 219 | 220 | 221 | def test_indexing(): 222 | # Test indexing functionality 223 | batched = SimpleData.default(shape=(5,)) 224 | single = batched[0] 225 | assert single.structured_type == StructuredType.SINGLE 226 | assert single.id.shape == () 227 | assert single.value.shape == () 228 | 229 | # Test slicing 230 | sliced = batched[1:3] 231 | assert sliced.structured_type == StructuredType.BATCHED 232 | assert sliced.shape.batch == (2,) 233 | assert sliced.id.shape == (2,) 234 | assert sliced.value.shape == (2,) 235 | 236 | 237 | def test_unstructured_generation(): 238 | 239 | unstructured = SimpleData(id=jnp.array(1), value=jnp.array([2.0, 3.0, 4.0])) 240 | assert unstructured.structured_type == StructuredType.UNSTRUCTURED 241 | assert unstructured.id.shape == () 242 | assert unstructured.value.shape == (3,) 243 | 244 | batched_unstructured = SimpleData( 245 | id=jnp.array([1, 2]), value=jnp.array([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]) 246 | ) 247 | assert batched_unstructured.structured_type == StructuredType.UNSTRUCTURED 248 | assert 
batched_unstructured.id.shape == (2,) 249 | assert batched_unstructured.value.shape == (2, 3) 250 | 251 | try: 252 | batched_unstructured.reshape((2, 3)) 253 | assert False, "unstructured data should not be reshaped" 254 | except ValueError: 255 | pass 256 | 257 | try: 258 | batched_unstructured.flatten() 259 | assert False, "flatten operation is only supported for BATCHED structured types" 260 | except ValueError: 261 | pass 262 | 263 | 264 | def test_at_set_simple_data(): 265 | # Test .at[...].set(...) for SimpleData 266 | original_data = SimpleData.default(shape=(3,)) 267 | 268 | # Create new data to set 269 | data_to_set_scalar = SimpleData( 270 | id=jnp.array(100, dtype=jnp.uint32), value=jnp.array(99.9, dtype=jnp.float32) 271 | ) 272 | 273 | # Update a single element with another SimpleData instance 274 | updated_data_single = original_data.at[1].set(data_to_set_scalar) 275 | 276 | assert updated_data_single.id[0] == original_data.id[0] 277 | assert updated_data_single.value[0] == original_data.value[0] 278 | assert updated_data_single.id[1] == data_to_set_scalar.id 279 | assert updated_data_single.value[1] == data_to_set_scalar.value 280 | assert updated_data_single.id[2] == original_data.id[2] 281 | assert updated_data_single.value[2] == original_data.value[2] 282 | 283 | # Ensure original data is unchanged 284 | assert original_data.id[1] != data_to_set_scalar.id 285 | assert original_data.value[1] != data_to_set_scalar.value 286 | 287 | # Update using a scalar value for all fields (if JAX supports it for specific dtype) 288 | # For SimpleData, id is uint32 and value is float32. 289 | # JAX .at[idx].set(scalar) will broadcast if the scalar is compatible. 290 | updated_data_scalar_id = original_data.at[0].set(jnp.uint32(50)) 291 | assert updated_data_scalar_id.id[0] == 50 292 | # The value field should also be updated with 50 if broadcast works, or remain original if not. 
293 | # Given current implementation, value_for_this_field = values_to_set (which is 50) 294 | # jnp.array(0.0, dtype=jnp.float32).at[()].set(50) would make it 50.0 295 | assert updated_data_scalar_id.value[0] == 50.0 296 | assert updated_data_scalar_id.id[1] == original_data.id[1] 297 | assert updated_data_scalar_id.value[1] == original_data.value[1] 298 | 299 | # Test setting with a slice 300 | slice_data_to_set = SimpleData.default(shape=(2,)) # id=0, value=0.0 301 | updated_data_slice = original_data.at[0:2].set(slice_data_to_set) 302 | assert updated_data_slice.id[0] == slice_data_to_set.id[0] 303 | assert updated_data_slice.value[0] == slice_data_to_set.value[0] 304 | assert updated_data_slice.id[1] == slice_data_to_set.id[1] 305 | assert updated_data_slice.value[1] == slice_data_to_set.value[1] 306 | assert updated_data_slice.id[2] == original_data.id[2] 307 | assert updated_data_slice.value[2] == original_data.value[2] 308 | 309 | 310 | def test_at_set_vector_data(): 311 | original_data = VectorData.default( 312 | shape=(3,) 313 | ) # position and velocity are (3,3) filled with 0.0 314 | 315 | # Data to set for a single batch element 316 | # position and velocity should be (3,) 317 | vector_to_set = VectorData( 318 | position=jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), 319 | velocity=jnp.array([4.0, 5.0, 6.0], dtype=jnp.float32), 320 | ) 321 | 322 | updated_data = original_data.at[1].set(vector_to_set) 323 | 324 | assert jnp.array_equal(updated_data.position[0], original_data.position[0]) 325 | assert jnp.array_equal(updated_data.velocity[0], original_data.velocity[0]) 326 | 327 | assert jnp.array_equal(updated_data.position[1], vector_to_set.position) 328 | assert jnp.array_equal(updated_data.velocity[1], vector_to_set.velocity) 329 | 330 | assert jnp.array_equal(updated_data.position[2], original_data.position[2]) 331 | assert jnp.array_equal(updated_data.velocity[2], original_data.velocity[2]) 332 | 333 | # Ensure original data is unchanged 334 | assert not jnp.array_equal(original_data.position[1], vector_to_set.position) 335 | 336 | # Test setting with a scalar (will broadcast to the (3,) shape of position/velocity) 337 | updated_data_scalar = original_data.at[0].set(jnp.float32(7.0)) 338 | assert jnp.all(updated_data_scalar.position[0] == 7.0) 339 | assert jnp.all(updated_data_scalar.velocity[0] == 7.0) 340 | assert jnp.array_equal(updated_data_scalar.position[1], original_data.position[1]) 341 | 342 | 343 | def test_at_set_nested_data(): 344 | original_data = NestedData.default(shape=(2,)) 345 | # original_data.simple.id shape (2,), value (2,) 346 | # original_data.vector.position shape (2,3), velocity (2,3) 347 | 348 | # Create a single NestedData instance to set at one index 349 | # This instance itself should NOT be batched. Its internal fields are single. 
350 | data_to_set_single_nested = ( 351 | NestedData.default() 352 | ) # scalar id, value, (3,) position, (3,) velocity 353 | data_to_set_single_nested = data_to_set_single_nested.replace( 354 | simple=SimpleData( 355 | id=jnp.array(10, dtype=jnp.uint32), value=jnp.array(1.1, dtype=jnp.float32) 356 | ), 357 | vector=VectorData( 358 | position=jnp.ones(3, dtype=jnp.float32), velocity=jnp.ones(3, dtype=jnp.float32) * 2 359 | ), 360 | ) 361 | 362 | updated_data = original_data.at[0].set(data_to_set_single_nested) 363 | 364 | # Check updated part 365 | assert updated_data.simple.id[0] == data_to_set_single_nested.simple.id 366 | assert updated_data.simple.value[0] == data_to_set_single_nested.simple.value 367 | assert jnp.array_equal( 368 | updated_data.vector.position[0], data_to_set_single_nested.vector.position 369 | ) 370 | assert jnp.array_equal( 371 | updated_data.vector.velocity[0], data_to_set_single_nested.vector.velocity 372 | ) 373 | 374 | # Check unchanged part 375 | assert updated_data.simple.id[1] == original_data.simple.id[1] 376 | assert updated_data.simple.value[1] == original_data.simple.value[1] 377 | assert jnp.array_equal(updated_data.vector.position[1], original_data.vector.position[1]) 378 | assert jnp.array_equal(updated_data.vector.velocity[1], original_data.vector.velocity[1]) 379 | 380 | # Ensure original data is unchanged 381 | assert original_data.simple.id[0] != data_to_set_single_nested.simple.id 382 | -------------------------------------------------------------------------------- /tests/hash_test.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from xtructure import FieldDescriptor, HashTable, xtructure_dataclass 5 | 6 | 7 | @xtructure_dataclass 8 | class XtructureValue: 9 | a: FieldDescriptor(jnp.uint8) # type: ignore 10 | b: FieldDescriptor(jnp.uint32, (1, 2)) # type: ignore 11 | 12 | 13 | def test_hash_table_lookup(): 14 | count = 1000 15 | sample = XtructureValue.random((count,)) 16 | table = HashTable.build(XtructureValue, 1, int(1e4)) 17 | 18 | lookup = jax.jit(lambda table, sample: HashTable.lookup(table, sample)) 19 | idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample) 20 | 21 | assert idx.shape == (count,) 22 | assert table_idx.shape == (count,) 23 | assert found.shape == (count,) 24 | assert not jnp.any(found) # Initially all should be not found 25 | 26 | 27 | def test_hash_table_insert(): 28 | count = 1000 29 | batch = 4000 30 | table = HashTable.build(XtructureValue, 1, int(1e4)) 31 | 32 | sample = XtructureValue.random((count,)) 33 | 34 | lookup = jax.jit(lambda table, sample: HashTable.lookup(table, sample)) 35 | parallel_insert = jax.jit( 36 | lambda table, sample, filled: HashTable.parallel_insert(table, sample, filled) 37 | ) 38 | 39 | # Check initial state 40 | _, _, old_found = jax.vmap(lookup, in_axes=(None, 0))(table, sample) 41 | assert not jnp.any(old_found) 42 | 43 | # Insert states 44 | batched_sample, filled = HashTable.make_batched(XtructureValue, sample, batch) 45 | table, inserted, _, _, _ = parallel_insert(table, batched_sample, filled) 46 | 47 | # Verify insertion 48 | _, _, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample) 49 | assert jnp.all(found) # All states should be found after insertion 50 | assert jnp.mean(inserted) > 0 # Some states should have been inserted 51 | 52 | 53 | def test_same_state_insert_at_batch(): 54 | batch = 5000 55 | table = HashTable.build(XtructureValue, 1, int(1e5)) 56 | 
parallel_insert = jax.jit( 57 | lambda table, sample, filled: HashTable.parallel_insert(table, sample, filled) 58 | ) 59 | lookup = jax.jit(lambda table, sample: HashTable.lookup(table, sample)) 60 | 61 | num = 10 62 | counts = 0 63 | all_samples = [] 64 | for i in range(num): 65 | key = jax.random.PRNGKey(i) 66 | samples = XtructureValue.random((batch,)) 67 | cloned_sample_num = jax.random.randint(key, (), 1, batch // 2) 68 | cloned_sample_idx = jax.random.randint(key, (cloned_sample_num,), 0, batch) 69 | cloned_sample_idx = jnp.sort(cloned_sample_idx) 70 | new_clone_idx = jax.random.randint(key, (cloned_sample_num,), 0, batch) 71 | 72 | # Create deliberate duplicates within the batch 73 | samples = samples.at[new_clone_idx].set(samples[cloned_sample_idx]) 74 | h, bytesed = jax.vmap(lambda x: x.hash(0))(samples) 75 | unique_count = jnp.unique(bytesed, axis=0).shape[0] 76 | # after this, some states are duplicated 77 | all_samples.append(samples) 78 | 79 | batched_sample, filled = HashTable.make_batched(XtructureValue, samples, batch) 80 | table, updatable, unique, idxs, table_idxs = parallel_insert(table, batched_sample, filled) 81 | counts += jnp.sum(updatable) 82 | 83 | # Verify uniqueness tracking 84 | unique_idxs = jnp.unique(jnp.stack([idxs, table_idxs], axis=1), axis=0) 85 | assert ( 86 | unique_idxs.shape[0] == unique_count 87 | ), f"unique_idxs.shape: {unique_idxs.shape}, unique_count: {unique_count}" 88 | assert unique_idxs.shape[0] == jnp.sum(unique), "Unique index mismatch" 89 | assert jnp.all( 90 | jnp.unique(unique_idxs, axis=0) == unique_idxs 91 | ), "Duplicate indices in unique set" 92 | 93 | # Verify inserted states exist in table 94 | _, _, found = jax.vmap(lookup, in_axes=(None, 0))(table, samples) 95 | assert jnp.all(found), ( 96 | "Inserted states not found in table\n", 97 | f"unique_count: {unique_count}\n", 98 | f"unique_idxs.shape: {unique_idxs.shape}, unique: {jnp.sum(unique)}\n", 99 | f"found: {jnp.sum(found)}\n", 100 | f"not_found_idxs: {jnp.where(~found)[0]}\n", 101 | f"cloned_sample_idx: {cloned_sample_idx}\n", 102 | ) 103 | 104 | # Final validation 105 | assert table.size == counts, f"Size mismatch: {table.size} vs {counts}" 106 | 107 | # Verify cross-batch duplicates 108 | for samples in all_samples: 109 | idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, samples) 110 | assert jnp.all(found), "Cross-batch state missing" 111 | contents = table.table[idx, table_idx] 112 | assert jnp.all( 113 | jax.vmap(lambda x, y: x == y)(contents, samples) 114 | ), "Inserted states not found in table" 115 | 116 | 117 | def test_large_hash_table(): 118 | count = int(1e7) 119 | batch = int(1e4) 120 | table = HashTable.build(XtructureValue, 1, count) 121 | 122 | sample = XtructureValue.random((count,)) 123 | hash, bytes = jax.vmap(lambda x: x.hash(0))(sample) 124 | unique_bytes = jnp.unique(bytes, axis=0, return_index=True)[1] 125 | unique_bytes_len = unique_bytes.shape[0] 126 | unique_hash = jnp.unique(hash, axis=0, return_index=True)[1] 127 | unique_hash_len = unique_hash.shape[0] 128 | print(f"unique_bytes_len: {unique_bytes_len}, unique_hash_len: {unique_hash_len}") 129 | 130 | parallel_insert = jax.jit( 131 | lambda table, sample, filled: HashTable.parallel_insert(table, sample, filled) 132 | ) 133 | lookup = jax.jit(lambda table, sample: HashTable.lookup(table, sample)) 134 | 135 | # Insert in batches 136 | inserted_count = 0 137 | for i in range(0, count, batch): 138 | batch_sample = sample[i : i + batch] 139 | table, inserted, _, _, _ = parallel_insert( 
140 | table, batch_sample, jnp.ones(len(batch_sample), dtype=jnp.bool_)
141 | )
142 | inserted_count += jnp.sum(inserted)
143 | 
144 | assert (
145 | inserted_count == unique_bytes_len
146 | ), f"inserted_count: {inserted_count}, unique_bytes_len: {unique_bytes_len}, unique_hash_len: {unique_hash_len}"
147 | 
148 | # Verify all states can be found
149 | _, _, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
150 | assert jnp.mean(found) == 1.0  # All states should be found
151 | 
--------------------------------------------------------------------------------
/tests/heap_test.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import pytest
4 | import random
5 | 
6 | from xtructure import BGPQ, FieldDescriptor, xtructure_dataclass
7 | 
8 | @xtructure_dataclass
9 | class XtructureValue:
10 | """
11 | A sample value type used as the heap payload in these tests, with three fields:
12 | 1. a: scalar uint8
13 | 2. b: uint32 array of shape (1, 2)
14 | 3. c: float32 array of shape (1, 2, 3)
15 | """
16 | 
17 | a: FieldDescriptor(jnp.uint8)  # type: ignore
18 | b: FieldDescriptor(jnp.uint32, (1, 2))  # type: ignore
19 | c: FieldDescriptor(jnp.float32, (1, 2, 3))  # type: ignore
20 | 
21 | 
22 | @jax.jit
23 | def key_gen(x: XtructureValue) -> float:
24 | uint32_hash, _ = x.hash()
25 | key = uint32_hash % (2**12) / (2**8)
26 | return key.astype(jnp.float32)
27 | 
28 | 
29 | @pytest.fixture
30 | def heap_setup():
31 | batch_size = int(1e4)
32 | max_size = int(2e7)
33 | heap = BGPQ.build(max_size, batch_size, XtructureValue, jnp.float32)
34 | 
35 | _key_gen = jax.jit(jax.vmap(key_gen))
36 | 
37 | return heap, batch_size, max_size, _key_gen
38 | 
39 | 
40 | def test_heap_initialization(heap_setup):
41 | heap, batch_size, max_size, _key_gen = heap_setup
42 | assert heap is not None
43 | assert heap.size == 0
44 | assert heap.batch_size == batch_size
45 | 
46 | @pytest.mark.parametrize("N", [128, 256, 311, 512, 707] + [random.randint(1, 700) for _ in range(5)])
47 | def test_heap_insert_and_delete_batch_size(heap_setup, N):
48 | heap, batch_size, max_size, _key_gen = heap_setup
49 | rnd_key = jax.random.PRNGKey(random.randint(0, 1000000))
50 | 
51 | # Test batch insertion
52 | total_size = 0
53 | for i in range(0, N, 1):
54 | rnd_key, seed1 = jax.random.split(rnd_key, 2)
55 | value = XtructureValue.random(shape=(batch_size,), key=seed1)
56 | key = _key_gen(value)
57 | heap = heap.insert(key, value)
58 | total_size += batch_size
59 | assert heap.size == total_size, (
60 | f"Expected size {total_size}, got {heap.size}, "
61 | f"heap.heap_size: {heap.heap_size}, heap.buffer_size: {heap.buffer_size}"
62 | )
63 | 
64 | stacked_val = heap.val_store[:N]
65 | stacked_key = heap.key_store[:N]
66 | 
67 | stacked_val_key = jax.vmap(_key_gen)(stacked_val)
68 | isclose = jnp.isclose(stacked_key, stacked_val_key)
69 | assert jnp.all(isclose), (
70 | "Inserted keys and values mismatch; insert is corrupted. "
71 | f"Key and value mismatch, \nstacked_key: \n{stacked_key[jnp.where(~isclose)]},"
72 | f"\nstacked_val_key: \n{stacked_val_key[jnp.where(~isclose)]},"
73 | f"\nstacked_val: \n{stacked_val[jnp.where(~isclose)][:3]},"
74 | f"\nindices: \n{jnp.where(~isclose)}"
75 | )
76 | 
77 | # Test batch deletion
78 | all_keys = []
79 | last_maximum_key = -jnp.inf
80 | while heap.size > 0:
81 | heap, min_key, min_val = BGPQ.delete_mins(heap)
82 | filled = jnp.isfinite(min_key)
83 | assert jnp.any(filled), (
84 | "delete_mins is corrupted. "
85 | f"No keys to delete, \nheap: \n{heap},"
\n{heap}," 86 | f"\nheap.size: \n{heap.size}," 87 | ) 88 | 89 | # check key and value matching 90 | isclose = jnp.isclose(min_key, _key_gen(min_val)) | ~filled 91 | assert jnp.all(isclose), ( 92 | f"delete_mins is corrupted" 93 | f"Key and value mismatch, \nmin_key: \n{min_key}," 94 | f"\nmin_val_key: \n{_key_gen(min_val)}," 95 | f"\nidexs: \n{jnp.where(~isclose)}" 96 | ) 97 | all_keys.append(min_key) 98 | is_larger = min_key >= last_maximum_key 99 | assert jnp.sum(~is_larger) < 1, ( 100 | f"delete_mins is corrupted" 101 | f"Key is not in ascending order, \nmin_key: \n{min_key}," 102 | f"\nlast_maximum_key: \n{last_maximum_key}," 103 | ) 104 | last_maximum_key = jnp.max(min_key) 105 | 106 | all_keys = jnp.concatenate(all_keys) 107 | diff = all_keys[1:] - all_keys[:-1] 108 | decreasing = diff < 0 109 | # Verify that elements are in ascending order 110 | assert jnp.sum(decreasing) < 1, ( 111 | f"Keys are not in ascending order: {decreasing}" 112 | f"\nfailed_idxs: {jnp.where(decreasing)}" 113 | f"\nincorrect_keys: ({all_keys[jnp.where(decreasing)[0]]}," 114 | f"{all_keys[jnp.where(decreasing)[0] + 1]})" 115 | ) 116 | 117 | 118 | @pytest.mark.parametrize("N", [128, 256, 311, 512, 707] + [random.randint(1, 700) for _ in range(5)]) 119 | def test_heap_insert_and_delete_random_size(heap_setup, N): 120 | heap, batch_size, max_size, _key_gen = heap_setup 121 | rnd_key = jax.random.PRNGKey(random.randint(0, 1000000)) 122 | 123 | # Test batch insertion 124 | total_size = 0 125 | for i in range(0, N, 1): 126 | rnd_key, seed1, seed2 = jax.random.split(rnd_key, 3) 127 | size = jax.random.randint( 128 | seed1, minval=1, maxval=8, shape=() 129 | ) * batch_size // 8 130 | value = XtructureValue.random(shape=(size,), key=seed2) 131 | key = _key_gen(value) 132 | key, value = BGPQ.make_batched(key, value, batch_size) 133 | heap = heap.insert(key, value) 134 | total_size += size 135 | assert heap.size == total_size, ( 136 | f"Expected size {total_size}, got {heap.size}," 137 | f"heap.heap_size: {heap.heap_size}, heap.buffer_size: {heap.buffer_size}" 138 | ) 139 | 140 | stacked_val = heap.val_store[: total_size // batch_size] 141 | stacked_key = heap.key_store[: total_size // batch_size] 142 | 143 | stacked_val_key = jax.vmap(_key_gen)(stacked_val) 144 | isclose = jnp.isclose(stacked_key, stacked_val_key) 145 | assert jnp.all(isclose), ( 146 | f"inserted keys and values mismatch, this means that insert is corrupted" 147 | f"Key and value mismatch, \nstacked_key: \n{stacked_key[jnp.where(~isclose)]}," 148 | f"\nstacked_val_key: \n{stacked_val_key[jnp.where(~isclose)]}," 149 | f"\nstacked_val: \n{stacked_val[jnp.where(~isclose)][:3]}," 150 | f"\nidexs: \n{jnp.where(~isclose)}" 151 | ) 152 | 153 | # Test batch deletion 154 | all_keys = [] 155 | last_maximum_key = -jnp.inf 156 | while heap.size > 0: 157 | heap, min_key, min_val = BGPQ.delete_mins(heap) 158 | filled = jnp.isfinite(min_key) 159 | assert jnp.any(filled), ( 160 | f"delete_mins is corrupted" 161 | f"No keys to delete, \nheap: \n{heap}," 162 | f"\nheap.size: \n{heap.size}," 163 | ) 164 | 165 | # check key and value matching 166 | isclose = jnp.isclose(min_key, _key_gen(min_val)) | ~filled 167 | assert jnp.all(isclose), ( 168 | f"delete_mins is corrupted" 169 | f"Key and value mismatch, \nmin_key: \n{min_key}," 170 | f"\nmin_val_key: \n{_key_gen(min_val)}," 171 | f"\nidexs: \n{jnp.where(~isclose)}" 172 | ) 173 | all_keys.append(min_key) 174 | is_larger = min_key >= last_maximum_key 175 | assert jnp.sum(~is_larger) < 1, ( 176 | f"delete_mins is corrupted" 
177 | f"Key is not in ascending order, \nmin_key: \n{min_key}," 178 | f"\nlast_maximum_key: \n{last_maximum_key}," 179 | ) 180 | last_maximum_key = jnp.max(min_key) 181 | 182 | all_keys = jnp.concatenate(all_keys) 183 | diff = all_keys[1:] - all_keys[:-1] 184 | decreasing = diff < 0 185 | # Verify that elements are in ascending order 186 | assert jnp.sum(decreasing) < 1, ( 187 | f"Keys are not in ascending order: {decreasing}" 188 | f"\nfailed_idxs: {jnp.where(decreasing)}" 189 | f"\nincorrect_keys: ({all_keys[jnp.where(decreasing)[0]]}," 190 | f"{all_keys[jnp.where(decreasing)[0] + 1]})" 191 | ) 192 | -------------------------------------------------------------------------------- /tests/queue_test.py: -------------------------------------------------------------------------------- 1 | import chex 2 | import jax 3 | import jax.numpy as jnp 4 | import pytest 5 | 6 | from xtructure import FieldDescriptor, Queue, xtructure_dataclass 7 | 8 | 9 | @xtructure_dataclass 10 | class Point: 11 | x: FieldDescriptor[jnp.uint32] 12 | y: FieldDescriptor[jnp.uint32] 13 | 14 | 15 | LARGE_MAX_SIZE = 100_000 16 | 17 | 18 | @pytest.fixture 19 | def queue(): 20 | """Provides a fresh queue for each test.""" 21 | return Queue.build(max_size=LARGE_MAX_SIZE, value_class=Point) 22 | 23 | 24 | def test_build(queue): 25 | """Tests the initial state of a newly built queue.""" 26 | assert queue.size == 0 27 | assert queue.max_size == LARGE_MAX_SIZE 28 | assert queue.head == 0 29 | assert queue.tail == 0 30 | 31 | 32 | def test_enqueue_single_item(queue): 33 | """Tests enqueuing a single item.""" 34 | p1 = Point(x=jnp.array(1, dtype=jnp.uint32), y=jnp.array(2, dtype=jnp.uint32)) 35 | queue = queue.enqueue(p1) 36 | assert queue.size == 1 37 | assert queue.tail == 1 38 | peeked = queue.peek() 39 | chex.assert_trees_all_equal(peeked, p1) 40 | 41 | 42 | def test_enqueue_batch(queue): 43 | """Tests enqueuing a batch of items.""" 44 | batch_size = 5000 45 | points = Point( 46 | x=jnp.arange(batch_size, dtype=jnp.uint32), 47 | y=jnp.arange(batch_size, batch_size * 2, dtype=jnp.uint32), 48 | ) 49 | queue = queue.enqueue(points) 50 | assert queue.size == batch_size 51 | assert queue.tail == batch_size 52 | peeked = queue.peek(batch_size) 53 | chex.assert_trees_all_equal(peeked, points) 54 | 55 | 56 | def test_dequeue_single(queue): 57 | """Tests dequeuing items one by one.""" 58 | p1 = Point(x=jnp.array(1, dtype=jnp.uint32), y=jnp.array(2, dtype=jnp.uint32)) 59 | p2 = Point(x=jnp.array(3, dtype=jnp.uint32), y=jnp.array(4, dtype=jnp.uint32)) 60 | 61 | queue = queue.enqueue(p1) 62 | queue = queue.enqueue(p2) 63 | assert queue.size == 2 64 | 65 | queue, dequeued = queue.dequeue() 66 | assert queue.size == 1 67 | assert queue.head == 1 68 | chex.assert_trees_all_equal(dequeued, p1) 69 | 70 | queue, dequeued = queue.dequeue() 71 | assert queue.size == 0 72 | assert queue.head == 2 73 | chex.assert_trees_all_equal(dequeued, p2) 74 | 75 | 76 | def test_dequeue_batch(queue): 77 | """Tests dequeuing a batch of items.""" 78 | batch_size = 5000 79 | points = Point( 80 | x=jnp.arange(batch_size, dtype=jnp.uint32), 81 | y=jnp.arange(batch_size, batch_size * 2, dtype=jnp.uint32), 82 | ) 83 | queue = queue.enqueue(points) 84 | 85 | dequeue_count = 3000 86 | queue, dequeued = queue.dequeue(dequeue_count) 87 | 88 | assert queue.size == batch_size - dequeue_count 89 | assert queue.head == dequeue_count 90 | expected_dequeued = Point(x=points.x[:dequeue_count], y=points.y[:dequeue_count]) 91 | chex.assert_trees_all_equal(dequeued, 
expected_dequeued) 92 | 93 | 94 | def test_peek(queue): 95 | """Tests peeking without modifying the queue.""" 96 | batch_size = 5000 97 | points = Point( 98 | x=jnp.arange(batch_size, dtype=jnp.uint32), 99 | y=jnp.arange(batch_size, batch_size * 2, dtype=jnp.uint32), 100 | ) 101 | queue = queue.enqueue(points) 102 | 103 | original_size = queue.size 104 | original_head = queue.head 105 | peek_count = 3000 106 | peeked = queue.peek(peek_count) 107 | 108 | assert queue.size == original_size 109 | assert queue.head == original_head 110 | 111 | expected_peeked = Point(x=points.x[:peek_count], y=points.y[:peek_count]) 112 | chex.assert_trees_all_equal(peeked, expected_peeked) 113 | 114 | 115 | def test_clear(queue): 116 | """Tests clearing the queue.""" 117 | points = Point(x=jnp.arange(5, dtype=jnp.uint32), y=jnp.arange(5, 10, dtype=jnp.uint32)) 118 | queue = queue.enqueue(points) 119 | assert queue.size == 5 120 | 121 | queue = queue.clear() 122 | assert queue.size == 0 123 | assert queue.head == 0 124 | assert queue.tail == 0 125 | 126 | 127 | def test_jit_compatibility(queue): 128 | @jax.jit 129 | def sequence(q): 130 | p1 = Point(x=jnp.array(1, dtype=jnp.uint32), y=jnp.array(2, dtype=jnp.uint32)) 131 | batch_points = Point( 132 | x=jnp.arange(2, dtype=jnp.uint32), y=jnp.arange(2, 4, dtype=jnp.uint32) 133 | ) 134 | 135 | q = q.enqueue(p1) 136 | q = q.enqueue(batch_points) 137 | q, _ = q.dequeue(2) 138 | return q 139 | 140 | final_queue = sequence(queue) 141 | assert final_queue.size == 1 142 | assert final_queue.head == 2 143 | assert final_queue.tail == 3 144 | -------------------------------------------------------------------------------- /tests/stack_test.py: -------------------------------------------------------------------------------- 1 | import chex 2 | import jax 3 | import jax.numpy as jnp 4 | import pytest 5 | 6 | from xtructure import FieldDescriptor, Stack, xtructure_dataclass 7 | 8 | 9 | @xtructure_dataclass 10 | class Point: 11 | x: FieldDescriptor[jnp.uint32] 12 | y: FieldDescriptor[jnp.uint32] 13 | 14 | 15 | # Use a much larger max_size for more robust testing 16 | LARGE_MAX_SIZE = 100_000 17 | 18 | 19 | @pytest.fixture 20 | def stack(): 21 | """Provides a fresh stack for each test.""" 22 | return Stack.build(max_size=LARGE_MAX_SIZE, value_class=Point) 23 | 24 | 25 | def test_build(stack): 26 | """Tests the initial state of a newly built stack.""" 27 | assert stack.size == 0 28 | assert stack.max_size == LARGE_MAX_SIZE 29 | 30 | 31 | def test_push_single_item(stack): 32 | """Tests pushing a single item onto the stack.""" 33 | p1 = Point(x=jnp.array(1, dtype=jnp.uint32), y=jnp.array(2, dtype=jnp.uint32)) 34 | 35 | stack = stack.push(p1) 36 | 37 | assert stack.size == 1 38 | peeked = stack.peek() 39 | # peek returns a batch of 1 40 | chex.assert_trees_all_equal(peeked, p1) 41 | 42 | 43 | def test_push_batch(stack): 44 | """Tests pushing a batch of items onto the stack.""" 45 | batch_size = 5000 46 | points = Point( 47 | x=jnp.arange(batch_size, dtype=jnp.uint32), 48 | y=jnp.arange(batch_size, batch_size * 2, dtype=jnp.uint32), 49 | ) 50 | 51 | stack = stack.push(points) 52 | 53 | assert stack.size == batch_size 54 | peeked = stack.peek(batch_size) 55 | chex.assert_trees_all_equal(peeked, points) 56 | 57 | 58 | def test_pop_single(stack): 59 | """Tests popping items one by one.""" 60 | p1 = Point(x=jnp.array(1, dtype=jnp.uint32), y=jnp.array(2, dtype=jnp.uint32)) 61 | p2 = Point(x=jnp.array(3, dtype=jnp.uint32), y=jnp.array(4, dtype=jnp.uint32)) 62 | 63 | stack = 
stack.push(p1)
64 | stack = stack.push(p2)
65 | 
66 | assert stack.size == 2
67 | 
68 | stack, popped = stack.pop()
69 | assert stack.size == 1
70 | # pop returns a batch of 1
71 | chex.assert_trees_all_equal(popped, p2)
72 | 
73 | stack, popped = stack.pop()
74 | assert stack.size == 0
75 | chex.assert_trees_all_equal(popped, p1)
76 | 
77 | 
78 | def test_pop_batch(stack):
79 | """Tests popping a batch of items."""
80 | batch_size = 5000
81 | points = Point(
82 | x=jnp.arange(batch_size, dtype=jnp.uint32),
83 | y=jnp.arange(batch_size, batch_size * 2, dtype=jnp.uint32),
84 | )
85 | stack = stack.push(points)
86 | 
87 | pop_count = 3000
88 | stack, popped = stack.pop(pop_count)
89 | 
90 | assert stack.size == batch_size - pop_count
91 | chex.assert_trees_all_equal(
92 | popped,
93 | Point(
94 | x=jnp.arange(batch_size - pop_count, batch_size, dtype=jnp.uint32),
95 | y=jnp.arange(batch_size * 2 - pop_count, batch_size * 2, dtype=jnp.uint32),
96 | ),
97 | )
98 | 
99 | 
100 | def test_peek(stack):
101 | """Tests peeking without modifying the stack."""
102 | batch_size = 5000
103 | points = Point(
104 | x=jnp.arange(batch_size, dtype=jnp.uint32),
105 | y=jnp.arange(batch_size, batch_size * 2, dtype=jnp.uint32),
106 | )
107 | stack = stack.push(points)
108 | 
109 | original_size = stack.size
110 | peek_count = 3000
111 | peeked = stack.peek(peek_count)
112 | 
113 | assert stack.size == original_size
114 | chex.assert_trees_all_equal(
115 | peeked,
116 | Point(
117 | x=jnp.arange(batch_size - peek_count, batch_size, dtype=jnp.uint32),
118 | y=jnp.arange(batch_size * 2 - peek_count, batch_size * 2, dtype=jnp.uint32),
119 | ),
120 | )
121 | 
122 | 
123 | def test_jit_compatibility(stack):
124 | @jax.jit
125 | def sequence(stack):
126 | p1 = Point(x=jnp.array(1, dtype=jnp.uint32), y=jnp.array(2, dtype=jnp.uint32))
127 | p2 = Point(x=jnp.arange(2, dtype=jnp.uint32), y=jnp.arange(2, 4, dtype=jnp.uint32))
128 | 
129 | stack = stack.push(p1)
130 | stack = stack.push(p2)
131 | stack, _ = stack.pop(2)
132 | return stack
133 | 
134 | final_stack = sequence(stack)
135 | assert final_stack.size == 1
136 | 
--------------------------------------------------------------------------------
/xtructure/__init__.py:
--------------------------------------------------------------------------------
1 | from .bgpq import BGPQ
2 | from .core import FieldDescriptor, StructuredType, Xtructurable, xtructure_dataclass
3 | from .hashtable import HashTable
4 | from .queue import Queue
5 | from .stack import Stack
6 | 
7 | __all__ = [
8 | # bgpq
9 | "BGPQ",
10 | # hashtable
11 | "HashTable",
12 | # queue
13 | "Queue",
14 | # stack
15 | "Stack",
16 | # core
17 | "Xtructurable",
18 | "xtructure_dataclass",
19 | "StructuredType",
20 | "FieldDescriptor",
21 | ]
22 | 
--------------------------------------------------------------------------------
/xtructure/bgpq/__init__.py:
--------------------------------------------------------------------------------
1 | from .bgpq import BGPQ
2 | 
3 | __all__ = ["BGPQ"]
4 | 
--------------------------------------------------------------------------------
/xtructure/bgpq/benchmark_merges.py:
--------------------------------------------------------------------------------
1 | import time
2 | from functools import partial
3 | 
4 | import jax.numpy as jnp
5 | import jax.random as jr
6 | from jax import jit
7 | 
8 | from .merge_split import merge_arrays_indices_loop, merge_arrays_parallel
9 | 
10 | 
11 | def 
run_correctness_tests(): 12 | """Runs a series of correctness tests with small, fixed inputs.""" 13 | print("\n--- Running correctness tests ---") 14 | # Test case 1: Identical arrays 15 | a1 = jnp.array([1, 2, 3, 4]) 16 | b1 = jnp.array([1, 2, 3, 4]) 17 | merged_keys1, merged_indices1 = merge_arrays_indices_loop(a1, b1) 18 | concatenated1 = jnp.concatenate([a1, b1]) 19 | assert jnp.array_equal(merged_keys1, concatenated1[merged_indices1]) 20 | print("✅ Test 1 (Identical) PASSED") 21 | 22 | # Test case 2: Interleaved arrays 23 | a2 = jnp.array([1, 5, 9]) 24 | b2 = jnp.array([2, 6, 10]) 25 | merged_keys2, merged_indices2 = merge_arrays_indices_loop(a2, b2) 26 | concatenated2 = jnp.concatenate([a2, b2]) 27 | assert jnp.array_equal(merged_keys2, concatenated2[merged_indices2]) 28 | print("✅ Test 2 (Interleaved) PASSED") 29 | 30 | # Test case 3: One array exhausted first 31 | a3 = jnp.array([1, 2]) 32 | b3 = jnp.array([3, 4, 5, 6]) 33 | merged_keys3, merged_indices3 = merge_arrays_indices_loop(a3, b3) 34 | concatenated3 = jnp.concatenate([a3, b3]) 35 | assert jnp.array_equal(merged_keys3, concatenated3[merged_indices3]) 36 | print("✅ Test 3 (Exhaustion) PASSED") 37 | 38 | # Test case 4: Empty array 39 | a4 = jnp.array([], dtype=jnp.int32) 40 | b4 = jnp.array([1, 2, 3]) 41 | merged_keys4a, merged_indices4a = merge_arrays_indices_loop(a4, b4) 42 | concatenated4a = jnp.concatenate([a4, b4]) 43 | assert jnp.array_equal(merged_keys4a, concatenated4a[merged_indices4a]) 44 | print("✅ Test 4a (Empty Left) PASSED") 45 | 46 | merged_keys4b, merged_indices4b = merge_arrays_indices_loop(b4, a4) 47 | concatenated4b = jnp.concatenate([b4, a4]) 48 | assert jnp.array_equal(merged_keys4b, concatenated4b[merged_indices4b]) 49 | print("✅ Test 4b (Empty Right) PASSED") 50 | 51 | # Test case 5: Arrays with duplicate values across them 52 | a5 = jnp.array([10, 20, 30]) 53 | b5 = jnp.array([10, 25, 30]) 54 | merged_keys5, merged_indices5 = merge_arrays_indices_loop(a5, b5) 55 | concatenated5 = jnp.concatenate([a5, b5]) 56 | assert jnp.array_equal(merged_keys5, concatenated5[merged_indices5]) 57 | print("✅ Test 5 (Duplicates) PASSED") 58 | print("--- All correctness tests passed ---") 59 | 60 | 61 | @partial(jit, static_argnums=(1, 2, 3)) 62 | def generate_sorted_test_data(key, size_ak, size_bk, dtype): 63 | """JIT-compiled function to generate and sort random test arrays.""" 64 | key_ak, key_bk = jr.split(key) 65 | 66 | if jnp.issubdtype(dtype, jnp.integer): 67 | ak_rand = jr.randint(key_ak, (size_ak,), minval=0, maxval=max(1, size_ak * 10), dtype=dtype) 68 | bk_rand = jr.randint(key_bk, (size_bk,), minval=0, maxval=max(1, size_bk * 10), dtype=dtype) 69 | elif jnp.issubdtype(dtype, jnp.floating): 70 | ak_rand = jr.uniform( 71 | key_ak, (size_ak,), dtype=dtype, minval=0.0, maxval=float(max(1, size_ak * 10)) 72 | ) 73 | bk_rand = jr.uniform( 74 | key_bk, (size_bk,), dtype=dtype, minval=0.0, maxval=float(max(1, size_bk * 10)) 75 | ) 76 | else: 77 | raise TypeError(f"Unsupported dtype for random generation: {dtype}") 78 | 79 | ak_sorted = jnp.sort(ak_rand) 80 | bk_sorted = jnp.sort(bk_rand) 81 | return ak_sorted, bk_sorted 82 | 83 | 84 | @jit 85 | def jax_baseline_merge(ak, bk): 86 | """A JIT-compiled baseline merge implementation using standard JAX ops.""" 87 | return jnp.sort(jnp.concatenate([ak, bk])) 88 | 89 | 90 | def verify_and_time_merge(key, size_ak, size_bk, dtype=jnp.int32): 91 | """Generates random data and benchmarks all merge implementations.""" 92 | print(f"\nTesting with ak_size={size_ak}, 
bk_size={size_bk}, dtype={dtype}") 93 | 94 | # Use the JIT-compiled function for faster data generation 95 | ak_sorted, bk_sorted = generate_sorted_test_data(key, size_ak, size_bk, dtype) 96 | ak_sorted.block_until_ready() 97 | bk_sorted.block_until_ready() 98 | 99 | if size_ak < 10 and size_bk < 10: 100 | print(f" Sorted ak: {ak_sorted}") 101 | print(f" Sorted bk: {bk_sorted}") 102 | 103 | implementations_to_test = { 104 | "pallas_loop": merge_arrays_indices_loop, 105 | "pallas_parallel": merge_arrays_parallel, 106 | } 107 | 108 | reference_merged_jax = jnp.sort(jnp.concatenate([ak_sorted, bk_sorted])) 109 | reference_merged_jax.block_until_ready() 110 | concatenated_inputs = jnp.concatenate([ak_sorted, bk_sorted]) 111 | 112 | all_passed = True 113 | for name, merge_fn in implementations_to_test.items(): 114 | print(f" Verifying implementation: {name}") 115 | try: 116 | merged_keys, merged_indices = merge_fn(ak_sorted, bk_sorted) 117 | merged_keys.block_until_ready() 118 | merged_indices.block_until_ready() 119 | reconstructed = concatenated_inputs[merged_indices] 120 | reconstructed.block_until_ready() 121 | assert jnp.array_equal(merged_keys, reference_merged_jax) 122 | assert jnp.array_equal(reconstructed, reference_merged_jax) 123 | print(" ✅ Correctness check PASSED.") 124 | except AssertionError as e: 125 | all_passed = False 126 | print(f" ❌ Correctness check FAILED for {name}.") 127 | if size_ak < 20 and size_bk < 20: 128 | print(f" Pallas merged keys: {merged_keys}") 129 | print(f" Pallas reconstructed: {reconstructed}") 130 | print(f" JAX reference sorted: {reference_merged_jax}") 131 | print(f" Pallas indices: {merged_indices}") 132 | print(f" ❌ {e}") 133 | except Exception as e: 134 | all_passed = False 135 | print(f" ❌ An exception occurred during {name} execution: {e}") 136 | 137 | if not all_passed: 138 | print(" Skipping timing due to correctness failure.") 139 | return 140 | 141 | print("\n --- Timing Comparison ---") 142 | 143 | # JAX baseline 144 | _ = jax_baseline_merge(ak_sorted, bk_sorted).block_until_ready() 145 | start_time_jax = time.perf_counter() 146 | for _ in range(10): 147 | jax_output_timing = jax_baseline_merge(ak_sorted, bk_sorted) 148 | jax_output_timing.block_until_ready() 149 | end_time_jax = time.perf_counter() 150 | jax_time = (end_time_jax - start_time_jax) / 10 151 | print(f" ⏱️ JAX Baseline (JIT) avg time: {jax_time*1000:.4f} ms") 152 | 153 | # Pallas versions 154 | for name, merge_fn in implementations_to_test.items(): 155 | _keys, _indices = merge_fn(ak_sorted, bk_sorted) 156 | _keys.block_until_ready() 157 | _indices.block_until_ready() 158 | start_time_pallas = time.perf_counter() 159 | for _ in range(10): 160 | pallas_keys_timing, pallas_indices_timing = merge_fn(ak_sorted, bk_sorted) 161 | pallas_keys_timing.block_until_ready() 162 | pallas_indices_timing.block_until_ready() 163 | end_time_pallas = time.perf_counter() 164 | pallas_time = (end_time_pallas - start_time_pallas) / 10 165 | print(f" ⏱️ Pallas '{name}' avg time: {pallas_time*1000:.4f} ms") 166 | 167 | 168 | if __name__ == "__main__": 169 | run_correctness_tests() 170 | 171 | print("\n\n--- Running benchmarks with random values and timing ---") 172 | master_key = jr.PRNGKey(42) 173 | 174 | sizes_to_test = [8, 200, 1000, 5000, int(1e5)] 175 | dtypes_to_test = [jnp.int32, jnp.float32] 176 | 177 | for size in sizes_to_test: 178 | for dtype in dtypes_to_test: 179 | key, subkey = jr.split(master_key) 180 | verify_and_time_merge(subkey, size_ak=size, size_bk=size, dtype=dtype) 181 | 182 | 
print("\n--- Benchmark complete ---") 183 | -------------------------------------------------------------------------------- /xtructure/bgpq/bgpq.py: -------------------------------------------------------------------------------- 1 | """ 2 | Batched GPU Priority Queue (BGPQ) Implementation 3 | This module provides a JAX-compatible priority queue optimized for GPU operations. 4 | Key features: 5 | - Fully batched operations for GPU efficiency 6 | - Supports custom value types through dataclass 7 | - Uses infinity padding for unused slots 8 | - Maintains sorted order for efficient min/max operations 9 | """ 10 | 11 | from functools import partial 12 | 13 | import chex 14 | import jax 15 | import jax.numpy as jnp 16 | 17 | from ..core import Xtructurable 18 | from .merge_split import merge_arrays_parallel 19 | 20 | SORT_STABLE = True # Use stable sorting to maintain insertion order for equal keys 21 | SIZE_DTYPE = jnp.uint32 22 | 23 | 24 | @jax.jit 25 | def merge_sort_split( 26 | ak: chex.Array, av: Xtructurable, bk: chex.Array, bv: Xtructurable 27 | ) -> tuple[chex.Array, Xtructurable, chex.Array, Xtructurable]: 28 | """ 29 | Merge and split two sorted arrays while maintaining their relative order. 30 | This is a key operation for maintaining heap property in batched operations. 31 | 32 | Args: 33 | ak: First array of keys 34 | av: First array of values 35 | bk: Second array of keys 36 | bv: Second array of values 37 | 38 | Returns: 39 | tuple containing: 40 | - First half of merged and sorted keys 41 | - First half of corresponding values 42 | - Second half of merged and sorted keys 43 | - Second half of corresponding values 44 | """ 45 | n = ak.shape[-1] # size of group 46 | val = jax.tree_util.tree_map(lambda a, b: jnp.concatenate([a, b]), av, bv) 47 | sorted_key, sorted_idx = merge_arrays_parallel(ak, bk) 48 | sorted_val = val[sorted_idx] 49 | return sorted_key[:n], sorted_val[:n], sorted_key[n:], sorted_val[n:] 50 | 51 | 52 | def sort_arrays(k: chex.Array, v: Xtructurable): 53 | sorted_k, sorted_idx = jax.lax.sort_key_val(k, jnp.arange(k.shape[0]), is_stable=SORT_STABLE) 54 | sorted_v = v[sorted_idx] 55 | return sorted_k, sorted_v 56 | 57 | 58 | @jax.jit 59 | def _next(current, target): 60 | """ 61 | Calculate the next index in the heap traversal path. 62 | Uses leading zero count (clz) for efficient binary tree navigation. 63 | 64 | Args: 65 | current: Current index in the heap 66 | target: Target index to reach 67 | 68 | Returns: 69 | Next index in the path from current to target 70 | """ 71 | clz_current = jax.lax.clz(current) 72 | clz_target = jax.lax.clz(target) 73 | shift_amount = clz_current - clz_target - 1 74 | next_index = target.astype(SIZE_DTYPE) >> shift_amount 75 | return next_index 76 | 77 | 78 | @chex.dataclass 79 | class BGPQ: 80 | """ 81 | Batched GPU Priority Queue implementation. 82 | Optimized for parallel operations on GPU using JAX. 
83 | 
84 | Attributes:
85 | max_size: Maximum number of elements the queue can hold
86 | heap_size: Number of heapified batch nodes currently in the tree
87 | buffer_size: Number of elements waiting in the insertion buffer
88 | branch_size: Number of batch nodes (branches) in the heap tree
89 | batch_size: Size of batched operations
90 | key_store: Array storing keys in a binary heap structure
91 | val_store: Array storing associated values
92 | key_buffer / val_buffer: Buffers for keys and values waiting to be inserted
93 | """
94 | 
95 | max_size: int
96 | heap_size: int
97 | buffer_size: int
98 | branch_size: int
99 | batch_size: int
100 | key_store: chex.Array  # shape = (branch_size, batch_size)
101 | val_store: Xtructurable  # shape = (branch_size, batch_size, ...)
102 | key_buffer: chex.Array  # shape = (batch_size - 1,)
103 | val_buffer: Xtructurable  # shape = (batch_size - 1, ...)
104 | 
105 | @staticmethod
106 | @partial(jax.jit, static_argnums=(0, 1, 2, 3))
107 | def build(total_size, batch_size, value_class=Xtructurable, key_dtype=jnp.float16):
108 | """
109 | Create a new BGPQ instance with specified capacity.
110 | 
111 | Args:
112 | total_size: Total number of elements the queue can store
113 | batch_size: Size of batched operations
114 | value_class: Class to use for storing values (must implement default())
115 | key_dtype: dtype of the key arrays (defaults to jnp.float16)
116 | Returns:
117 | BGPQ: A new priority queue instance initialized with empty storage
118 | """
119 | 
120 | # Calculate branch size, rounding up if total_size not divisible by batch_size
121 | branch_size = (
122 | total_size // batch_size
123 | if total_size % batch_size == 0
124 | else total_size // batch_size + 1
125 | )
126 | max_size = branch_size * batch_size
127 | heap_size = SIZE_DTYPE(0)
128 | buffer_size = SIZE_DTYPE(0)
129 | 
130 | # Initialize storage arrays with infinity for unused slots
131 | key_store = jnp.full((branch_size, batch_size), jnp.inf, dtype=key_dtype)
132 | val_store = value_class.default((branch_size, batch_size))
133 | key_buffer = jnp.full((batch_size - 1,), jnp.inf, dtype=key_dtype)
134 | val_buffer = value_class.default((batch_size - 1,))
135 | 
136 | return BGPQ(
137 | max_size=max_size,
138 | heap_size=heap_size,
139 | buffer_size=buffer_size,
140 | branch_size=branch_size,
141 | batch_size=batch_size,
142 | key_store=key_store,
143 | val_store=val_store,
144 | key_buffer=key_buffer,
145 | val_buffer=val_buffer,
146 | )
147 | 
148 | @property
149 | def size(self):
150 | return jnp.where(
151 | self.heap_size == 0,
152 | jnp.sum(jnp.isfinite(self.key_store[0])) + self.buffer_size,
153 | (self.heap_size + 1) * self.batch_size + self.buffer_size,
154 | )
155 | 
156 | @jax.jit
157 | def merge_buffer(heap: "BGPQ", blockk: chex.Array, blockv: Xtructurable):
158 | """
159 | Merge buffer contents with block contents, handling overflow conditions.
160 | 
161 | This method is crucial for maintaining the heap property when inserting new elements.
162 | It handles the case where the buffer might overflow into the main storage.
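
Example (illustrative): with batch_size = 4, merging block keys [1, 3, 5, 7]
with buffer keys [2, 4, 6] yields 7 finite keys, which overflows one full
block: the smallest four, [1, 2, 3, 4], become the new block and [5, 6, 7]
refill the buffer.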
163 | 
164 | Args:
165 | heap: The priority queue instance (the buffer contents are read from
166 | heap.key_buffer and heap.val_buffer)
167 | blockk: Block keys array
168 | blockv: Block values
169 | 
170 | Returns:
171 | tuple containing:
172 | - Updated heap (with refreshed key_buffer, val_buffer and buffer_size)
173 | - Updated block keys
174 | - Updated block values
175 | - Boolean indicating if buffer overflow occurred
176 | 
177 | """
178 | n = blockk.shape[0]
179 | # Merge the incoming block with the buffer contents
180 | sorted_key, sorted_idx = merge_arrays_parallel(blockk, heap.key_buffer)
181 | val = jax.tree_util.tree_map(lambda a, b: jnp.concatenate([a, b]), blockv, heap.val_buffer)
182 | val = val[sorted_idx]
183 | 
184 | # Check for active elements (non-infinity)
185 | filled = jnp.isfinite(sorted_key)
186 | n_filled = jnp.sum(filled)
187 | buffer_overflow = n_filled >= n
188 | 
189 | def overflowed(key, val):
190 | """Buffer overflowed: emit the smallest n keys as a full block, keep the rest"""
191 | return key[:n], val[:n], key[n:], val[n:]
192 | 
193 | def not_overflowed(key, val):
194 | return key[-n:], val[-n:], key[:-n], val[:-n]  # block becomes inf padding
195 | 
196 | blockk, blockv, heap.key_buffer, heap.val_buffer = jax.lax.cond(
197 | buffer_overflow,
198 | overflowed,
199 | not_overflowed,
200 | sorted_key,
201 | val,
202 | )
203 | heap.buffer_size = jnp.sum(jnp.isfinite(heap.key_buffer), dtype=SIZE_DTYPE)
204 | return heap, blockk, blockv, buffer_overflow
205 | 
206 | @staticmethod
207 | @partial(jax.jit, static_argnums=(2,))
208 | def make_batched(key: chex.Array, val: Xtructurable, batch_size: int):
209 | """
210 | Convert unbatched arrays into batched format suitable for the queue.
211 | 
212 | Args:
213 | key: Array of keys to batch
214 | val: Xtructurable of values to batch
215 | batch_size: Desired batch size
216 | 
217 | Returns:
218 | tuple containing:
219 | - Batched key array
220 | - Batched value array
221 | """
222 | n = key.shape[0]
223 | # Pad arrays to match batch size (assumes n <= batch_size)
224 | key_dtype = key.dtype
225 | key = jnp.concatenate([key, jnp.full((batch_size - n,), jnp.inf, dtype=key_dtype)])
226 | val = jax.tree_util.tree_map(
227 | lambda x, y: jnp.concatenate([x, y]),
228 | val,
229 | val.default((batch_size - n,)),
230 | )
231 | return key, val
232 | 
233 | @staticmethod
234 | def _insert_heapify(heap: "BGPQ", block_key: chex.Array, block_val: Xtructurable):
235 | """
236 | Internal method to maintain heap property after insertion.
237 | Performs heapification by merging batches along the path from the root down to the new node.
238 | 
239 | Args:
240 | heap: The priority queue instance
241 | block_key: Keys to insert
242 | block_val: Values to insert
243 | 
244 | Returns:
245 | tuple containing:
246 | - Updated heap
247 | - Boolean indicating if insertion was successful
248 | """
249 | last_node = SIZE_DTYPE(heap.heap_size + 1)
250 | 
251 | def _cond(var):
252 | """Continue while not reached last node"""
253 | _, _, _, n = var
254 | return n < last_node
255 | 
256 | def insert_heapify(var):
257 | """Perform one step of heapification"""
258 | heap, keys, values, n = var
259 | head, hvalues, keys, values = merge_sort_split(
260 | heap.key_store[n], heap.val_store[n], keys, values
261 | )
262 | heap.key_store = heap.key_store.at[n].set(head)
263 | heap.val_store = heap.val_store.at[n].set(hvalues)
264 | return heap, keys, values, _next(n, last_node)
265 | 
266 | heap, keys, values, _ = jax.lax.while_loop(
267 | _cond,
268 | insert_heapify,
269 | (
270 | heap,
271 | block_key,
272 | block_val,
273 | _next(SIZE_DTYPE(0), last_node),
274 | ),
275 | )
276 | 
277 | def _size_not_full(heap, keys, values):
278 | """Insert remaining elements if heap not full"""
279 | heap.key_store = heap.key_store.at[last_node].set(keys)
280 | heap.val_store = heap.val_store.at[last_node].set(values)
281 | return heap
282 | 
283 | added = last_node < heap.branch_size
284 | heap = jax.lax.cond(
285 | added, _size_not_full, lambda heap, keys, values: heap, heap, keys, values
286 | )
287 | return heap, added
288 | 
289 | @jax.jit
290 | def insert(heap: "BGPQ", block_key: chex.Array, block_val: Xtructurable):
291 | """
292 | Insert new elements into the priority queue.
293 | Maintains heap property through merge operations and heapification.
294 | 
295 | Args:
296 | heap: The priority queue instance
297 | block_key: Keys to insert (must already be padded to batch_size;
298 | see BGPQ.make_batched)
299 | block_val: Values to insert
300 | 
301 | Returns:
302 | Updated heap instance
303 | """
304 | block_key, block_val = sort_arrays(block_key, block_val)
305 | # Merge with root node
306 | root_key, root_val, block_key, block_val = merge_sort_split(
307 | heap.key_store[0], heap.val_store[0], block_key, block_val
308 | )
309 | heap.key_store = heap.key_store.at[0].set(root_key)
310 | heap.val_store = heap.val_store.at[0].set(root_val)
311 | 
312 | # Handle buffer overflow
313 | heap, block_key, block_val, buffer_overflow = heap.merge_buffer(block_key, block_val)
314 | 
315 | # Perform heapification if needed
316 | heap, added = jax.lax.cond(
317 | buffer_overflow,
318 | BGPQ._insert_heapify,
319 | lambda heap, block_key, block_val: (heap, False),
320 | heap,
321 | block_key,
322 | block_val,
323 | )
324 | heap.heap_size = SIZE_DTYPE(heap.heap_size + added)
325 | return heap
326 | 
327 | @staticmethod
328 | def delete_heapify(heap: "BGPQ"):
329 | """
330 | Maintain heap property after deletion of minimum elements.
331 | 332 | Args: 333 | heap: The priority queue instance 334 | 335 | Returns: 336 | Updated heap instance 337 | """ 338 | 339 | last = heap.heap_size 340 | heap.heap_size = SIZE_DTYPE(last - 1) 341 | 342 | # Move last node to root and clear last position 343 | last_key = heap.key_store[last] 344 | last_val = heap.val_store[last] 345 | 346 | heap.key_store = heap.key_store.at[last].set(jnp.inf) 347 | 348 | root_key, root_val, heap.key_buffer, heap.val_buffer = merge_sort_split( 349 | last_key, last_val, heap.key_buffer, heap.val_buffer 350 | ) 351 | 352 | heap.key_store = heap.key_store.at[0].set(root_key) 353 | heap.val_store = heap.val_store.at[0].set(root_val) 354 | 355 | def _lr(n): 356 | """Get left and right child indices""" 357 | left_child = n * 2 + 1 358 | right_child = n * 2 + 2 359 | return left_child, right_child 360 | 361 | def _cond(var): 362 | """Continue while heap property is violated""" 363 | heap, c, l, r = var 364 | max_c = heap.key_store[c][-1] 365 | min_l = heap.key_store[l][0] 366 | min_r = heap.key_store[r][0] 367 | min_lr = jnp.minimum(min_l, min_r) 368 | return max_c > min_lr 369 | 370 | def _f(var): 371 | """Perform one step of heapification""" 372 | heap, current_node, left_child, right_child = var 373 | max_left_child = heap.key_store[left_child][-1] 374 | max_right_child = heap.key_store[right_child][-1] 375 | 376 | # Choose child with smaller key 377 | x, y = jax.lax.cond( 378 | max_left_child > max_right_child, 379 | lambda _: (left_child, right_child), 380 | lambda _: (right_child, left_child), 381 | None, 382 | ) 383 | 384 | # Merge and swap nodes 385 | ky, vy, kx, vx = merge_sort_split( 386 | heap.key_store[left_child], 387 | heap.val_store[left_child], 388 | heap.key_store[right_child], 389 | heap.val_store[right_child], 390 | ) 391 | kc, vc, ky, vy = merge_sort_split( 392 | heap.key_store[current_node], heap.val_store[current_node], ky, vy 393 | ) 394 | heap.key_store = heap.key_store.at[y].set(ky) 395 | heap.key_store = heap.key_store.at[current_node].set(kc) 396 | heap.key_store = heap.key_store.at[x].set(kx) 397 | heap.val_store = heap.val_store.at[y].set(vy) 398 | heap.val_store = heap.val_store.at[current_node].set(vc) 399 | heap.val_store = heap.val_store.at[x].set(vx) 400 | 401 | nc = y 402 | nl, nr = _lr(y) 403 | return heap, nc, nl, nr 404 | 405 | c = SIZE_DTYPE(0) 406 | l, r = _lr(c) 407 | heap, _, _, _ = jax.lax.while_loop(_cond, _f, (heap, c, l, r)) 408 | return heap 409 | 410 | @jax.jit 411 | def delete_mins(heap: "BGPQ"): 412 | """ 413 | Remove and return the minimum elements from the queue. 
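
Example (illustrative): `heap, keys, vals = BGPQ.delete_mins(heap)` returns the
batch_size smallest keys currently stored (the root batch); unfilled slots come
back as jnp.inf, which callers can mask with jnp.isfinite as the tests do.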
414 | 415 | Args: 416 | heap: The priority queue instance 417 | 418 | Returns: 419 | tuple containing: 420 | - Updated heap instance 421 | - Array of minimum keys removed 422 | - Xtructurable of corresponding values 423 | """ 424 | min_keys = heap.key_store[0] 425 | min_values = heap.val_store[0] 426 | 427 | def make_empty(heap: "BGPQ"): 428 | """Handle case where heap becomes empty""" 429 | root_key, root_val, heap.key_buffer, heap.val_buffer = merge_sort_split( 430 | jnp.full_like(heap.key_store[0], jnp.inf), 431 | heap.val_store[0], 432 | heap.key_buffer, 433 | heap.val_buffer, 434 | ) 435 | heap.key_store = heap.key_store.at[0].set(root_key) 436 | heap.val_store = heap.val_store.at[0].set(root_val) 437 | heap.buffer_size = SIZE_DTYPE(0) 438 | return heap 439 | 440 | heap = jax.lax.cond(heap.heap_size == 0, make_empty, BGPQ.delete_heapify, heap) 441 | return heap, min_keys, min_values 442 | -------------------------------------------------------------------------------- /xtructure/bgpq/merge_split/__init__.py: -------------------------------------------------------------------------------- 1 | from .loop import merge_arrays_indices_loop 2 | from .parallel import merge_arrays_parallel 3 | from .split import merge_sort_split_idx 4 | 5 | __all__ = [ 6 | "merge_arrays_indices_loop", 7 | "merge_arrays_parallel", 8 | "merge_sort_split_idx", 9 | ] 10 | -------------------------------------------------------------------------------- /xtructure/bgpq/merge_split/common.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.experimental import pallas as pl 4 | 5 | 6 | def _get_sentinels(dtype): 7 | """Returns the min and max sentinel values for a given dtype.""" 8 | if jnp.issubdtype(dtype, jnp.integer): 9 | return jnp.iinfo(dtype).min, jnp.iinfo(dtype).max 10 | if jnp.issubdtype(dtype, jnp.floating): 11 | return dtype.type(-jnp.inf), dtype.type(jnp.inf) 12 | raise TypeError(f"Unsupported dtype for sentinel values: {dtype}") 13 | 14 | 15 | def binary_search_partition(k, a, b): 16 | """ 17 | Finds the partition of k elements between sorted arrays a and b. 18 | 19 | This function implements the core logic of the "Merge Path" algorithm. It 20 | uses binary search to find a split point (i, j) such that i elements from 21 | array `a` and j elements from array `b` constitute the first k elements 22 | of the merged array. Thus, i + j = k. 23 | 24 | The search finds an index `i` in `[0, n]` that satisfies the condition: 25 | `a[i-1] <= b[j]` and `b[j-1] <= a[i]`, where `j = k - i`. These checks 26 | define a valid merge partition. The binary search below finds the 27 | largest `i` that satisfies `a[i-1] <= b[k-i]`. 28 | 29 | Args: 30 | k: The total number of elements in the target partition (the "diagonal" 31 | of the merge path grid). 32 | a: A sorted JAX array or a Pallas Ref to one. 33 | b: A sorted JAX array or a Pallas Ref to one. 34 | 35 | Returns: 36 | A tuple (i, j) where i is the number of elements to take from a and j 37 | is the number of elements from b, satisfying i + j = k. 38 | """ 39 | n = a.shape[0] 40 | m = b.shape[0] 41 | 42 | # The number of elements from `a`, `i`, must be in the range [low, high]. 43 | low = jnp.maximum(0, k - m) 44 | high = jnp.minimum(n, k) 45 | 46 | # Binary search for the correct partition index `i`. We are looking for the 47 | # largest `i` in `[low, high]` such that `a[i-1] <= b[k-i]`. 
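# Worked example (illustrative): for a = [1, 3, 5, 7], b = [2, 4, 6, 8] and
# k = 4, the largest i with a[i-1] <= b[k-i] is i = 2 (a[1] = 3 <= b[2] = 6,
# while i = 3 fails because a[2] = 5 > b[1] = 4), so the partition is (2, 2):
# the first four merged elements [1, 2, 3, 4] take two from each array.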
48 | def cond_fn(state): 49 | low_i, high_i = state 50 | return low_i < high_i 51 | 52 | def body_fn(state): 53 | low_i, high_i = state 54 | # Bias the midpoint to the right to ensure the loop terminates when 55 | # searching for the "last true" condition. 56 | i = low_i + (high_i - low_i + 1) // 2 57 | j = k - i 58 | 59 | min_val, max_val = _get_sentinels(a.dtype) 60 | is_a_safe = i > 0 61 | is_b_safe = j < m 62 | 63 | # A more robust way to handle conditional loading in Pallas to avoid 64 | # the `scf.yield` lowering error. 65 | # 1. Select a safe index to load from (0 if out of bounds). 66 | # 2. Perform the load unconditionally. 67 | # 3. Use `where` to replace the loaded value with a sentinel if the 68 | # original index was out of bounds. 69 | safe_a_idx = jnp.where(is_a_safe, i - 1, 0) 70 | a_val_loaded = pl.load(a, (safe_a_idx,)) 71 | a_val = jnp.where(is_a_safe, a_val_loaded, min_val) 72 | 73 | safe_b_idx = jnp.where(is_b_safe, j, 0) 74 | b_val_loaded = pl.load(b, (safe_b_idx,)) 75 | b_val = jnp.where(is_b_safe, b_val_loaded, max_val) 76 | 77 | # The condition for a valid partition from `a`'s perspective. 78 | # If `a[i-1] <= b[j]`, then `i` is a valid candidate, and we can 79 | # potentially take even more from `a`. So, we search in `[i, high]`. 80 | # Otherwise, `i` is too high, and we must search in `[low, i-1]`. 81 | is_partition_valid = a_val <= b_val 82 | new_low = jnp.where(is_partition_valid, i, low_i) 83 | new_high = jnp.where(is_partition_valid, high_i, i - 1) 84 | return new_low, new_high 85 | 86 | # The loop terminates when low == high, and `final_low` is our desired `i`. 87 | final_low, _ = jax.lax.while_loop(cond_fn, body_fn, (low, high)) 88 | return final_low, k - final_low 89 | -------------------------------------------------------------------------------- /xtructure/bgpq/merge_split/loop.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax.experimental import pallas as pl 6 | 7 | 8 | def merge_indices_kernel_loop(ak_ref, bk_ref, merged_keys_ref, merged_indices_ref): 9 | """ 10 | Pallas kernel to merge two sorted arrays (ak, bk) and write the indices 11 | of the merged elements (relative to a conceptual [ak, bk] concatenation) 12 | into merged_indices_ref. Uses explicit loops and Pallas memory operations. 
13 | """ 14 | n = ak_ref.shape[0] 15 | m = bk_ref.shape[0] 16 | 17 | def true_branch_body_fn(cond_operands): 18 | ( 19 | current_idx_a, 20 | current_idx_b, 21 | current_out_ptr_val, 22 | val_a_to_store, 23 | _, 24 | _, 25 | merged_keys_ref_from_cond, 26 | merged_indices_ref_from_cond, 27 | ) = cond_operands 28 | val_a_casted = val_a_to_store.astype(merged_keys_ref_from_cond.dtype) 29 | pl.store( 30 | merged_keys_ref_from_cond, 31 | (current_out_ptr_val,), 32 | val_a_casted, 33 | eviction_policy="evict_last", 34 | ) 35 | pl.store( 36 | merged_indices_ref_from_cond, 37 | (current_out_ptr_val,), 38 | current_idx_a, 39 | eviction_policy="evict_last", 40 | ) 41 | return current_idx_a + 1, current_idx_b 42 | 43 | def false_branch_body_fn(cond_operands): 44 | ( 45 | current_idx_a, 46 | current_idx_b, 47 | current_out_ptr_val, 48 | _, 49 | val_b_to_store, 50 | _, 51 | merged_keys_ref_from_cond, 52 | merged_indices_ref_from_cond, 53 | ) = cond_operands 54 | val_b_casted = val_b_to_store.astype(merged_keys_ref_from_cond.dtype) 55 | pl.store( 56 | merged_keys_ref_from_cond, 57 | (current_out_ptr_val,), 58 | val_b_casted, 59 | eviction_policy="evict_last", 60 | ) 61 | pl.store( 62 | merged_indices_ref_from_cond, 63 | (current_out_ptr_val,), 64 | n + current_idx_b, 65 | eviction_policy="evict_last", 66 | ) 67 | return current_idx_a, current_idx_b + 1 68 | 69 | initial_main_loop_state = (0, 0, 0, ak_ref, bk_ref, merged_keys_ref, merged_indices_ref) 70 | 71 | def main_loop_condition(state): 72 | idx_a, idx_b, _, _, _, _, _ = state 73 | return jnp.logical_and(idx_a < n, idx_b < m) 74 | 75 | def main_loop_body(state): 76 | ( 77 | idx_a, 78 | idx_b, 79 | out_ptr, 80 | loop_ak_ref, 81 | loop_bk_ref, 82 | loop_merged_keys_ref, 83 | loop_merged_indices_ref, 84 | ) = state 85 | val_a = pl.load(loop_ak_ref, (idx_a,)) 86 | val_b = pl.load(loop_bk_ref, (idx_b,)) 87 | pred = val_a <= val_b 88 | 89 | updated_idx_a, updated_idx_b = jax.lax.cond( 90 | pred, 91 | true_branch_body_fn, 92 | false_branch_body_fn, 93 | ( 94 | idx_a, 95 | idx_b, 96 | out_ptr, 97 | val_a, 98 | val_b, 99 | loop_ak_ref, 100 | loop_merged_keys_ref, 101 | loop_merged_indices_ref, 102 | ), 103 | ) 104 | return ( 105 | updated_idx_a, 106 | updated_idx_b, 107 | out_ptr + 1, 108 | loop_ak_ref, 109 | loop_bk_ref, 110 | loop_merged_keys_ref, 111 | loop_merged_indices_ref, 112 | ) 113 | 114 | final_state_after_main_loop = jax.lax.while_loop( 115 | main_loop_condition, main_loop_body, initial_main_loop_state 116 | ) 117 | 118 | ( 119 | idx_a, 120 | idx_b, 121 | out_ptr, 122 | _, 123 | _, 124 | final_loop_merged_keys_ref, 125 | final_loop_merged_indices_ref, 126 | ) = final_state_after_main_loop 127 | 128 | initial_ak_loop_state = ( 129 | idx_a, 130 | out_ptr, 131 | ak_ref, 132 | final_loop_merged_keys_ref, 133 | final_loop_merged_indices_ref, 134 | ) 135 | 136 | def ak_loop_condition(state): 137 | current_idx_a, _, _, _, _ = state 138 | return current_idx_a < n 139 | 140 | def ak_loop_body(state): 141 | ( 142 | current_idx_a, 143 | current_out_ptr, 144 | loop_ak_ref, 145 | loop_merged_keys_ref, 146 | loop_merged_indices_ref, 147 | ) = state 148 | val_to_store = pl.load(loop_ak_ref, (current_idx_a,)) 149 | val_casted = val_to_store.astype(loop_merged_keys_ref.dtype) 150 | pl.store( 151 | loop_merged_keys_ref, 152 | (current_out_ptr,), 153 | val_casted, 154 | eviction_policy="evict_last", 155 | ) 156 | pl.store( 157 | loop_merged_indices_ref, (current_out_ptr,), current_idx_a, eviction_policy="evict_last" 158 | ) 159 | return ( 160 | current_idx_a + 1, 
161 | current_out_ptr + 1, 162 | loop_ak_ref, 163 | loop_merged_keys_ref, 164 | loop_merged_indices_ref, 165 | ) 166 | 167 | final_state_after_ak_loop = jax.lax.while_loop( 168 | ak_loop_condition, ak_loop_body, initial_ak_loop_state 169 | ) 170 | ( 171 | idx_a, 172 | out_ptr, 173 | _, 174 | final_loop_merged_keys_ref, 175 | final_loop_merged_indices_ref, 176 | ) = final_state_after_ak_loop 177 | 178 | initial_bk_loop_state = ( 179 | idx_b, 180 | out_ptr, 181 | bk_ref, 182 | final_loop_merged_keys_ref, 183 | final_loop_merged_indices_ref, 184 | ) 185 | 186 | def bk_loop_condition(state): 187 | current_idx_b, _, _, _, _ = state 188 | return current_idx_b < m 189 | 190 | def bk_loop_body(state): 191 | ( 192 | current_idx_b, 193 | current_out_ptr, 194 | loop_bk_ref, 195 | loop_merged_keys_ref, 196 | loop_merged_indices_ref, 197 | ) = state 198 | val_to_store = pl.load(loop_bk_ref, (current_idx_b,)) 199 | val_casted = val_to_store.astype(loop_merged_keys_ref.dtype) 200 | pl.store( 201 | loop_merged_keys_ref, 202 | (current_out_ptr,), 203 | val_casted, 204 | eviction_policy="evict_last", 205 | ) 206 | pl.store( 207 | loop_merged_indices_ref, 208 | (current_out_ptr,), 209 | n + current_idx_b, 210 | eviction_policy="evict_last", 211 | ) 212 | return ( 213 | current_idx_b + 1, 214 | current_out_ptr + 1, 215 | loop_bk_ref, 216 | loop_merged_keys_ref, 217 | loop_merged_indices_ref, 218 | ) 219 | 220 | jax.lax.while_loop(bk_loop_condition, bk_loop_body, initial_bk_loop_state) 221 | 222 | 223 | @jax.jit 224 | def merge_arrays_indices_loop(ak: jax.Array, bk: jax.Array) -> Tuple[jax.Array, jax.Array]: 225 | """ 226 | Merges two sorted JAX arrays ak and bk using a loop-based Pallas kernel 227 | and returns a tuple containing: 228 | - merged_keys: The sorted merged array of keys. 229 | - merged_indices: An array of indices representing the merged order. 230 | The indices refer to the positions in a conceptual concatenation [ak, bk]. 231 | """ 232 | if ak.ndim != 1 or bk.ndim != 1: 233 | raise ValueError("Input arrays ak and bk must be 1D.") 234 | 235 | n = ak.shape[0] 236 | m = bk.shape[0] 237 | 238 | key_dtype = jnp.result_type(ak.dtype, bk.dtype) 239 | out_keys_shape_dtype = jax.ShapeDtypeStruct((n + m,), key_dtype) 240 | out_idx_shape_dtype = jax.ShapeDtypeStruct((n + m,), jnp.int32) 241 | 242 | return pl.pallas_call( 243 | merge_indices_kernel_loop, out_shape=(out_keys_shape_dtype, out_idx_shape_dtype) 244 | )(ak, bk) 245 | -------------------------------------------------------------------------------- /xtructure/bgpq/merge_split/parallel.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax.experimental import pallas as pl 6 | 7 | from .common import binary_search_partition 8 | 9 | BLOCK_SIZE = 64 10 | 11 | 12 | def merge_parallel_kernel(ak_ref, bk_ref, merged_keys_ref, merged_indices_ref): 13 | """ 14 | Pallas kernel that merges two sorted arrays in parallel using the 15 | Merge Path algorithm for block-level partitioning. 
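
Each program instance owns one BLOCK_SIZE-wide diagonal slice [k_start, k_end)
of the merged output and uses binary_search_partition to locate, independently
of all other blocks, the slices of ak and bk it has to merge.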
16 | """ 17 | block_idx = pl.program_id(axis=0) 18 | 19 | n, m = ak_ref.shape[0], bk_ref.shape[0] 20 | total_len = n + m 21 | 22 | k_start = block_idx * BLOCK_SIZE 23 | k_end = jnp.minimum(k_start + BLOCK_SIZE, total_len) 24 | 25 | a_start, b_start = binary_search_partition(k_start, ak_ref, bk_ref) 26 | a_end, b_end = binary_search_partition(k_end, ak_ref, bk_ref) 27 | 28 | initial_main_loop_state = (a_start, b_start, k_start) 29 | 30 | def main_loop_cond(state): 31 | idx_a, idx_b, _ = state 32 | return jnp.logical_and(idx_a < a_end, idx_b < b_end) 33 | 34 | def main_loop_body(state): 35 | idx_a, idx_b, out_ptr = state 36 | val_a = pl.load(ak_ref, (idx_a,)) 37 | val_b = pl.load(bk_ref, (idx_b,)) 38 | is_a_le_b = val_a <= val_b 39 | 40 | key_to_store = jnp.where(is_a_le_b, val_a, val_b) 41 | idx_to_store = jnp.where(is_a_le_b, idx_a, n + idx_b) 42 | 43 | key_casted = key_to_store.astype(merged_keys_ref.dtype) 44 | pl.store(merged_keys_ref, (out_ptr,), key_casted) 45 | pl.store(merged_indices_ref, (out_ptr,), idx_to_store) 46 | 47 | next_idx_a = jnp.where(is_a_le_b, idx_a + 1, idx_a) 48 | next_idx_b = jnp.where(is_a_le_b, idx_b, idx_b + 1) 49 | return next_idx_a, next_idx_b, out_ptr + 1 50 | 51 | idx_a, idx_b, out_ptr = jax.lax.while_loop( 52 | main_loop_cond, main_loop_body, initial_main_loop_state 53 | ) 54 | 55 | initial_ak_loop_state = (idx_a, out_ptr) 56 | 57 | def ak_loop_cond(state): 58 | current_idx_a, _ = state 59 | return current_idx_a < a_end 60 | 61 | def ak_loop_body(state): 62 | current_idx_a, current_out_ptr = state 63 | val_to_store = pl.load(ak_ref, (current_idx_a,)) 64 | val_casted = val_to_store.astype(merged_keys_ref.dtype) 65 | pl.store(merged_keys_ref, (current_out_ptr,), val_casted) 66 | pl.store(merged_indices_ref, (current_out_ptr,), current_idx_a) 67 | return current_idx_a + 1, current_out_ptr + 1 68 | 69 | idx_a, out_ptr = jax.lax.while_loop(ak_loop_cond, ak_loop_body, initial_ak_loop_state) 70 | 71 | initial_bk_loop_state = (idx_b, out_ptr) 72 | 73 | def bk_loop_cond(state): 74 | current_idx_b, _ = state 75 | return current_idx_b < b_end 76 | 77 | def bk_loop_body(state): 78 | current_idx_b, current_out_ptr = state 79 | val_to_store = pl.load(bk_ref, (current_idx_b,)) 80 | val_casted = val_to_store.astype(merged_keys_ref.dtype) 81 | pl.store(merged_keys_ref, (current_out_ptr,), val_casted) 82 | pl.store(merged_indices_ref, (current_out_ptr,), n + current_idx_b) 83 | return current_idx_b + 1, current_out_ptr + 1 84 | 85 | jax.lax.while_loop(bk_loop_cond, bk_loop_body, initial_bk_loop_state) 86 | 87 | 88 | @jax.jit 89 | def merge_arrays_parallel(ak: jax.Array, bk: jax.Array) -> Tuple[jax.Array, jax.Array]: 90 | """ 91 | Merges two sorted JAX arrays using the parallel Merge Path Pallas kernel. 
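
Returns the merged keys together with the index of each output element in the
conceptual concatenation [ak, bk]. Example (illustrative):
merge_arrays_parallel(jnp.array([1, 3]), jnp.array([2, 4])) gives keys
[1, 2, 3, 4] and indices [0, 2, 1, 3].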
92 | """ 93 | if ak.ndim != 1 or bk.ndim != 1: 94 | raise ValueError("Input arrays ak and bk must be 1D.") 95 | 96 | n, m = ak.shape[0], bk.shape[0] 97 | total_len = n + m 98 | if total_len == 0: 99 | key_dtype = jnp.result_type(ak.dtype, bk.dtype) 100 | return jnp.array([], dtype=key_dtype), jnp.array([], dtype=jnp.int32) 101 | 102 | key_dtype = jnp.result_type(ak.dtype, bk.dtype) 103 | out_keys_shape_dtype = jax.ShapeDtypeStruct((total_len,), key_dtype) 104 | out_idx_shape_dtype = jax.ShapeDtypeStruct((total_len,), jnp.int32) 105 | 106 | grid_size = (total_len + BLOCK_SIZE - 1) // BLOCK_SIZE 107 | 108 | return pl.pallas_call( 109 | merge_parallel_kernel, 110 | grid=(grid_size,), 111 | out_shape=(out_keys_shape_dtype, out_idx_shape_dtype), 112 | )(ak, bk) 113 | -------------------------------------------------------------------------------- /xtructure/bgpq/merge_split/split.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax.experimental import pallas as pl 6 | 7 | 8 | def merge_sort_split_kernel( 9 | ak_ref, 10 | bk_ref, 11 | res_key0_ref, 12 | res_idx0_ref, 13 | res_key1_ref, 14 | res_idx1_ref, 15 | ): 16 | """ 17 | Merge and split two sorted arrays while maintaining their relative order. 18 | This version is Pallas-compliant: writes to output refs, no return. 19 | Outputs are: 20 | - res_key0: First N keys of the merged and sorted array. 21 | - res_idx0: Corresponding original indices for res_key0. 22 | - res_key1: Remaining keys of the merged and sorted array. 23 | - res_idx1: Corresponding original indices for res_key1. 24 | """ 25 | ak_val, bk_val = ak_ref[...], bk_ref[...] 26 | n_split = ak_val.shape[-1] 27 | key_concat = jnp.concatenate([ak_val, bk_val]) 28 | indices_payload = jnp.arange(key_concat.shape[0], dtype=jnp.int32) 29 | sorted_key_full, sorted_idx_full = jax.lax.sort_key_val(key_concat, indices_payload) 30 | res_key0_ref[...] = sorted_key_full[:n_split] 31 | res_idx0_ref[...] = sorted_idx_full[:n_split] 32 | res_key1_ref[...] = sorted_key_full[n_split:] 33 | res_idx1_ref[...] 
= sorted_idx_full[n_split:]
34 | 
35 | 
36 | @jax.jit
37 | def merge_sort_split_idx(
38 | ak: jax.Array, bk: jax.Array
39 | ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
40 | len_ak_part = ak.shape[-1]
41 | len_bk_part = bk.shape[-1]
42 | key_dtype = jnp.result_type(ak.dtype, bk.dtype)
43 | shape_key0 = jax.ShapeDtypeStruct(shape=ak.shape[:-1] + (len_ak_part,), dtype=key_dtype)
44 | shape_idx0 = jax.ShapeDtypeStruct(shape=ak.shape[:-1] + (len_ak_part,), dtype=jnp.int32)
45 | shape_key1 = jax.ShapeDtypeStruct(shape=bk.shape[:-1] + (len_bk_part,), dtype=key_dtype)
46 | shape_idx1 = jax.ShapeDtypeStruct(shape=bk.shape[:-1] + (len_bk_part,), dtype=jnp.int32)
47 | 
48 | return pl.pallas_call(
49 | merge_sort_split_kernel,
50 | out_shape=(shape_key0, shape_idx0, shape_key1, shape_idx1),
51 | )(ak, bk)
52 | 
--------------------------------------------------------------------------------
/xtructure/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .field_descriptors import FieldDescriptor
2 | from .protocol import Xtructurable
3 | from .structuredtype import StructuredType
4 | from .xtructure_decorators import xtructure_dataclass
5 | 
6 | __all__ = ["Xtructurable", "StructuredType", "xtructure_dataclass", "FieldDescriptor"]
7 | 
--------------------------------------------------------------------------------
/xtructure/core/field_descriptors.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple
2 | 
3 | import jax.numpy as jnp
4 | 
5 | # Represents a JAX dtype, can be a specific type like jnp.int32 or a more generic jnp.dtype
6 | DType = Any
7 | 
8 | 
9 | class FieldDescriptor:
10 | """
11 | A descriptor for fields in an xtructure_dataclass.
12 | 
13 | This class is used to define the properties of fields in a dataclass decorated with
14 | @xtructure_dataclass. It specifies the JAX dtype, shape, and default fill value
15 | for each field.
16 | 
17 | Example usage:
18 | ```python
19 | @xtructure_dataclass
20 | class MyData:
21 | # A scalar uint8 field
22 | a: FieldDescriptor[jnp.uint8]
23 | 
24 | # A field with shape (1, 2) of uint32 values
25 | b: FieldDescriptor[jnp.uint32, (1, 2)]
26 | 
27 | # A float field with custom fill value
28 | c: FieldDescriptor(dtype=jnp.float32, fill_value=0.0)
29 | 
30 | # A nested xtructure_dataclass field
31 | d: FieldDescriptor[AnotherDataClass]
32 | ```
33 | 
34 | The FieldDescriptor can be used with type annotation syntax using square brackets
35 | or instantiated directly with the constructor for more explicit parameter naming.
36 | Each descriptor records the field's JAX dtype, its default fill value, and its
37 | intrinsic (non-batched) shape, which lets @xtructure_dataclass auto-generate
38 | the .default() classmethod.
39 | """
40 | 
41 | def __init__(self, dtype: DType, intrinsic_shape: Tuple[int, ...] = (), fill_value: Any = None):
42 | """
43 | Initializes a FieldDescriptor.
44 | 
45 | Args:
46 | dtype: The JAX dtype of the field (e.g., jnp.int32, jnp.float32).
47 | intrinsic_shape: The shape of the field itself, before any batching.
48 | Defaults to () for a scalar field.
49 | fill_value: The default value to fill the field's array with
50 | (e.g., -1, 0.0). When None, a dtype-dependent sentinel is chosen.
51 | """ 52 | self.dtype: DType = dtype 53 | # Set default fill values based on dtype 54 | if fill_value is None: 55 | if hasattr(dtype, "dataclass"): 56 | # Handle xtructure_dataclass types 57 | self.fill_value = fill_value 58 | elif jnp.issubdtype(dtype, jnp.unsignedinteger): 59 | # For unsigned integers, use -1 (which wraps to max value) 60 | self.fill_value = -1 61 | elif jnp.issubdtype(dtype, jnp.integer) or jnp.issubdtype(dtype, jnp.floating): 62 | # For signed integers and floats, use infinity 63 | self.fill_value = jnp.inf 64 | else: 65 | # For other types, keep None 66 | self.fill_value = fill_value 67 | else: 68 | # Use the explicitly provided fill_value 69 | self.fill_value = fill_value 70 | self.intrinsic_shape: Tuple[int, ...] = intrinsic_shape 71 | 72 | def __repr__(self) -> str: 73 | return ( 74 | f"FieldDescriptor(dtype={self.dtype}, " 75 | f"fill_value={self.fill_value}, " 76 | f"intrinsic_shape={self.intrinsic_shape})" 77 | ) 78 | 79 | @classmethod 80 | def __class_getitem__(cls, item: Any) -> "FieldDescriptor": 81 | """ 82 | Allows for syntax like FieldDescriptor[dtype, intrinsic_shape, fill_value]. 83 | """ 84 | if isinstance(item, tuple): 85 | if len(item) == 1: 86 | return cls(item[0]) 87 | elif len(item) == 2: 88 | # Assuming item[1] is intrinsic_shape or fill_value. 89 | # Heuristic: if it's a tuple, it's intrinsic_shape. Otherwise, it could be fill_value. 90 | # This could be ambiguous. For clarity, users might prefer named args with __init__ 91 | # or a more structured approach if this becomes complex. 92 | if isinstance(item[1], tuple): 93 | return cls(item[0], intrinsic_shape=item[1]) 94 | else: # Assuming it's a fill_value, and intrinsic_shape is default 95 | return cls(item[0], fill_value=item[1]) 96 | elif len(item) == 3: 97 | return cls(item[0], intrinsic_shape=item[1], fill_value=item[2]) 98 | else: 99 | raise ValueError( 100 | "FieldDescriptor[...] expects 1 to 3 arguments: " 101 | "dtype, optional intrinsic_shape, optional fill_value" 102 | ) 103 | else: 104 | # Single item is treated as dtype 105 | return cls(item) 106 | 107 | 108 | # Example usage (to be placed in your class definitions later): 109 | # 110 | # from xtructure.field_descriptors import FieldDescriptor 111 | # 112 | # @xtructure_data 113 | # class MyData: 114 | # my_scalar_int: FieldDescriptor[jnp.int32, (), -1] 115 | # my_vector_float: FieldDescriptor[jnp.float32, (10,), 0.0] 116 | # my_default_shape_int: FieldDescriptor[jnp.uint8] 117 | # # ... other fields 118 | # 119 | # # The .default() method would be auto-generated by @xtructure_data 120 | # # using these descriptors. 121 | -------------------------------------------------------------------------------- /xtructure/core/protocol.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Protocol 2 | from typing import Tuple as TypingTuple 3 | from typing import Type, TypeVar 4 | 5 | import chex 6 | 7 | from .structuredtype import StructuredType 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | # Protocol defining the interface added by @xtructure_data 13 | class Xtructurable(Protocol[T]): 14 | # Fields from the original class that chex.dataclass would process 15 | # These are implicitly part of T. For the protocol to be complete, 16 | # it assumes T will have __annotations__. 
17 | __annotations__: Dict[str, Any] 18 | # __dict__ is used by the __getitem__ implementation 19 | __dict__: Dict[str, Any] 20 | 21 | # Methods and properties added by add_shape_dtype_len 22 | @property 23 | def shape(self) -> Any: # Actual type is a dynamically generated namedtuple 24 | ... 25 | 26 | @property 27 | def dtype(self) -> Any: # Actual type is a dynamically generated namedtuple 28 | ... 29 | 30 | # Method added by add_indexing_methods (responsible for __getitem__) 31 | def __getitem__(self: T, index: Any) -> T: 32 | ... 33 | 34 | # Method added by add_shape_dtype_len 35 | def __len__(self) -> int: 36 | ... 37 | 38 | # Methods and properties added by add_structure_utilities 39 | # Assumes the class T has a 'default' classmethod as per the decorator's assertion 40 | @classmethod 41 | def default(cls: Type[T], shape: Any = ...) -> T: 42 | ... 43 | 44 | @property 45 | def default_shape(self) -> Any: # Derived from self.default().shape 46 | ... 47 | 48 | @property 49 | def structured_type(self) -> "StructuredType": # Forward reference for StructuredType 50 | ... 51 | 52 | @property 53 | def batch_shape(self) -> TypingTuple[int, ...]: 54 | ... 55 | 56 | def reshape(self: T, new_shape: TypingTuple[int, ...]) -> T: 57 | ... 58 | 59 | def flatten(self: T) -> T: 60 | ... 61 | 62 | @classmethod 63 | def random( 64 | cls: Type[T], shape: TypingTuple[int, ...] = ..., key: Any = ... 65 | ) -> T: # Ellipsis for default value 66 | ... 67 | 68 | # Methods and properties added by add_string_representation_methods 69 | def __str__( 70 | self, 71 | ) -> str: # The actual implementation takes **kwargs, but signature can be simpler for Protocol 72 | ... 73 | 74 | def str(self) -> str: # Alias for __str__ 75 | ... 76 | 77 | # Method added by add_indexing_methods 78 | def at(self: T, index: Any) -> "AtIndexer": 79 | ... 80 | 81 | @property 82 | def bytes(self: T) -> chex.Array: 83 | ... 84 | 85 | def hash(self: T, seed: int = 0) -> TypingTuple[int, chex.Array]: 86 | ... 87 | 88 | 89 | class AtIndexer(Protocol[T]): 90 | def __getitem__(self: T, index: Any) -> "Updater": 91 | ... 92 | 93 | 94 | class Updater(Protocol[T]): 95 | def set(self: T, value: Any) -> T: 96 | ... 97 | 98 | def set_as_condition(self: T, condition: chex.Array, value: Any) -> T: 99 | ... 
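...


# Example usage (illustrative sketch, not part of the source; `Point` and its
# fields are hypothetical). Any class decorated with @xtructure_dataclass is
# expected to satisfy this protocol:
#
# import jax.numpy as jnp
# from xtructure.core import FieldDescriptor, xtructure_dataclass
#
# @xtructure_dataclass
# class Point:
#     x: FieldDescriptor[jnp.float32]
#     y: FieldDescriptor[jnp.float32]
#
# batch = Point.default(shape=(8,))             # Xtructurable.default
# assert batch.shape.batch == (8,)              # shape property
# updated = batch.at[0].set(Point.random())     # AtIndexer -> Updater.set
# hash_value, byte_repr = updated.hash(seed=0)  # hash / bytes support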
100 | 
--------------------------------------------------------------------------------
/xtructure/core/structuredtype.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | 
3 | 
4 | # enum for state type
5 | class StructuredType(Enum):
6 |     SINGLE = 0
7 |     BATCHED = 1
8 |     UNSTRUCTURED = 2
9 | 
--------------------------------------------------------------------------------
/xtructure/core/xtructure_decorators/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Type, TypeVar
2 | 
3 | import chex
4 | 
5 | from xtructure.core.protocol import Xtructurable
6 | 
7 | from .default import add_default_method
8 | from .hash import hash_function_decorator
9 | from .indexing import add_indexing_methods
10 | from .ops import add_comparison_operators
11 | from .shape import add_shape_dtype_len
12 | from .string_format import add_string_representation_methods
13 | from .structure_util import add_structure_utilities
14 | 
15 | T = TypeVar("T")
16 | 
17 | 
18 | def xtructure_dataclass(cls: Type[T]) -> Type[Xtructurable[T]]:
19 |     """
20 |     Decorator that ensures the input class is a `chex.dataclass` (or converts
21 |     it to one) and then augments it with additional functionality related to its
22 |     structure, type, and operations like indexing, default instance creation,
23 |     random instance generation, and string representation.
24 | 
25 |     It adds properties like `shape`, `dtype`, `default_shape`, `structured_type`,
26 |     `batch_shape`, and methods like `__getitem__`, `__len__`, `reshape`,
27 |     `flatten`, `random`, and `__str__`.
28 | 
29 |     Args:
30 |         cls: The class to be decorated. It is expected to have a `default`
31 |             classmethod for some functionalities.
32 | 
33 |     Returns:
34 |         The decorated class with the aforementioned additional functionalities.
35 |     """
36 |     cls = chex.dataclass(cls)
37 | 
38 |     # Ensure the class has a default method for initialization
39 |     cls = add_default_method(cls)
40 | 
41 |     # Verify that the default method was added successfully
42 |     assert hasattr(cls, "default"), "xtructure_dataclass must have a default method."
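    # Taken together, the steps below compose roughly as follows
    # (illustrative sketch only; `MyData` is a hypothetical user class):
    #
    #   MyData = add_comparison_operators(
    #       hash_function_decorator(
    #           add_string_representation_methods(
    #               add_structure_utilities(
    #                   add_indexing_methods(
    #                       add_shape_dtype_len(
    #                           add_default_method(chex.dataclass(MyData))))))))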
43 | 44 | # add shape and dtype and len 45 | cls = add_shape_dtype_len(cls) 46 | 47 | # add indexing methods 48 | cls = add_indexing_methods(cls) 49 | 50 | # add structure utilities and random 51 | cls = add_structure_utilities(cls) 52 | 53 | # add string representation methods 54 | cls = add_string_representation_methods(cls) 55 | 56 | # add hash function 57 | cls = hash_function_decorator(cls) 58 | 59 | # add comparison operators 60 | cls = add_comparison_operators(cls) 61 | 62 | setattr(cls, "is_xtructed", True) 63 | 64 | return cls 65 | -------------------------------------------------------------------------------- /xtructure/core/xtructure_decorators/annotate.py: -------------------------------------------------------------------------------- 1 | MAX_PRINT_BATCH_SIZE = 4 2 | SHOW_BATCH_SIZE = 2 3 | -------------------------------------------------------------------------------- /xtructure/core/xtructure_decorators/default.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, NamedTuple 2 | from typing import Tuple as TypingTuple 3 | from typing import Type, TypeVar, Union 4 | 5 | import jax.numpy as jnp 6 | 7 | from xtructure.core.field_descriptors import FieldDescriptor 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | def is_xtructure_class(dtype: Any) -> bool: 13 | if isinstance(dtype, type): 14 | if hasattr(dtype, "is_xtructed"): 15 | return True 16 | return False 17 | else: 18 | return False 19 | 20 | 21 | class FieldInfo(NamedTuple): 22 | """Pre-computed field information for efficient default value generation.""" 23 | 24 | name: str 25 | field_type: str 26 | # 'jax_primitive_descriptor', 'jax_dtype_descriptor', 'nested_class_descriptor', 'nested_class_direct' 27 | descriptor: Union[FieldDescriptor, None] 28 | dtype: Any 29 | fill_value: Any 30 | intrinsic_shape: TypingTuple[int, ...] 31 | nested_class_type: Union[Type, None] 32 | 33 | 34 | def add_default_method(cls: Type[T]) -> Type[T]: 35 | 36 | if any([not isinstance(i, FieldDescriptor) for i in cls.__annotations__.values()]): 37 | invalid_annotations = [ 38 | (name, type(annotation).__name__) 39 | for name, annotation in cls.__annotations__.items() 40 | if not isinstance(annotation, FieldDescriptor) 41 | ] 42 | raise ValueError( 43 | f"xtructure_dataclass can only have FieldDescriptor annotations." 44 | f"Found invalid annotations: {invalid_annotations}" 45 | ) 46 | 47 | # add default method to class 48 | setattr(cls, "default", _create_default_method(cls)) 49 | return cls 50 | 51 | 52 | def _create_default_method(cls_to_modify: Type[T]) -> Callable[..., T]: 53 | annotations = getattr(cls_to_modify, "__annotations__", {}) 54 | 55 | # Pre-compute field information during method creation 56 | field_infos: List[FieldInfo] = [] 57 | 58 | for field_name, annotation_obj in annotations.items(): 59 | descriptor = annotation_obj 60 | dtype_of_field_descriptor = descriptor.dtype 61 | 62 | if is_xtructure_class(dtype_of_field_descriptor): 63 | # It's a user-defined xtructure class. Use its .default() method. 64 | nested_class_type = dtype_of_field_descriptor 65 | if not hasattr(nested_class_type, "default"): 66 | raise TypeError( 67 | f"Error during method creation for '{cls_to_modify.__name__}': " 68 | f"Nested field '{field_name}' (type '{nested_class_type.__name__}' " 69 | f"via FieldDescriptor.dtype) does not have a .default() method. " 70 | f"Ensure it's an @xtructure_data class." 
71 | ) 72 | intrinsic_shape = ( 73 | descriptor.intrinsic_shape 74 | if isinstance(descriptor.intrinsic_shape, tuple) 75 | else (descriptor.intrinsic_shape,) 76 | ) 77 | field_infos.append( 78 | FieldInfo( 79 | name=field_name, 80 | field_type="nested_class_descriptor", 81 | descriptor=descriptor, 82 | dtype=None, 83 | fill_value=None, 84 | intrinsic_shape=intrinsic_shape, 85 | nested_class_type=nested_class_type, 86 | ) 87 | ) 88 | elif isinstance(dtype_of_field_descriptor, type): 89 | # Check if it's a JAX primitive type class 90 | is_jax_primitive_type_class = False 91 | try: 92 | if jnp.issubdtype(dtype_of_field_descriptor, jnp.number) or jnp.issubdtype( 93 | dtype_of_field_descriptor, jnp.bool_ 94 | ): 95 | is_jax_primitive_type_class = True 96 | except TypeError: # Not a type that jnp.issubdtype recognizes as a primitive base 97 | is_jax_primitive_type_class = False 98 | 99 | if is_jax_primitive_type_class: 100 | # It's like jnp.int32, jnp.float32. Use jnp.full. 101 | intrinsic_shape = ( 102 | descriptor.intrinsic_shape 103 | if isinstance(descriptor.intrinsic_shape, tuple) 104 | else (descriptor.intrinsic_shape,) 105 | ) 106 | field_infos.append( 107 | FieldInfo( 108 | name=field_name, 109 | field_type="jax_primitive_descriptor", 110 | descriptor=descriptor, 111 | dtype=dtype_of_field_descriptor, 112 | fill_value=descriptor.fill_value, 113 | intrinsic_shape=intrinsic_shape, 114 | nested_class_type=None, 115 | ) 116 | ) 117 | else: 118 | # It's some other type class that we don't support 119 | raise TypeError( 120 | f"Error during method creation for '{cls_to_modify.__name__}': " 121 | f"Field '{field_name}' uses FieldDescriptor with an unsupported " 122 | f"type class: '{dtype_of_field_descriptor}' " 123 | f"(type: {type(dtype_of_field_descriptor).__name__}). " 124 | f"Expected a JAX primitive type/class (like jnp.int32) or an @xtructure_data class type." 125 | ) 126 | elif isinstance(dtype_of_field_descriptor, jnp.dtype): 127 | # dtype_of_field_descriptor is a JAX dtype INSTANCE (e.g., jnp.dtype('int32')). Use jnp.full. 128 | intrinsic_shape = ( 129 | descriptor.intrinsic_shape 130 | if isinstance(descriptor.intrinsic_shape, tuple) 131 | else (descriptor.intrinsic_shape,) 132 | ) 133 | field_infos.append( 134 | FieldInfo( 135 | name=field_name, 136 | field_type="jax_dtype_descriptor", 137 | descriptor=descriptor, 138 | dtype=dtype_of_field_descriptor, 139 | fill_value=descriptor.fill_value, 140 | intrinsic_shape=intrinsic_shape, 141 | nested_class_type=None, 142 | ) 143 | ) 144 | else: 145 | # FieldDescriptor.dtype is neither a recognized class nor a jnp.dtype instance. 146 | raise TypeError( 147 | f"Error during method creation for '{cls_to_modify.__name__}': " 148 | f"Field '{field_name}' uses FieldDescriptor with an unsupported " 149 | f".dtype attribute: '{dtype_of_field_descriptor}' " 150 | f"(type: {type(dtype_of_field_descriptor).__name__}). " 151 | f"Expected a JAX primitive type/class (like jnp.int32 or " 152 | f"jnp.dtype('int32')), or an @xtructure_data class type (like Parent)." 153 | ) 154 | 155 | @classmethod 156 | def default(cls: Type[T], shape: TypingTuple[int, ...] 
= ()) -> T: 157 | default_values: Dict[str, Any] = {} 158 | 159 | # Use pre-computed field information for efficient value generation 160 | for field_info in field_infos: 161 | if field_info.field_type == "jax_primitive_descriptor": 162 | field_shape = shape + field_info.intrinsic_shape 163 | default_values[field_info.name] = jnp.full( 164 | field_shape, 165 | field_info.fill_value, 166 | dtype=field_info.dtype, 167 | ) 168 | elif field_info.field_type == "jax_dtype_descriptor": 169 | field_shape = shape + field_info.intrinsic_shape 170 | default_values[field_info.name] = jnp.full( 171 | field_shape, field_info.fill_value, dtype=field_info.dtype 172 | ) 173 | elif field_info.field_type == "nested_class_descriptor": 174 | field_shape = shape + field_info.intrinsic_shape 175 | default_values[field_info.name] = field_info.nested_class_type.default( 176 | shape=field_shape 177 | ) 178 | elif field_info.field_type == "nested_class_direct": 179 | default_values[field_info.name] = field_info.nested_class_type.default(shape=shape) 180 | return cls(**default_values) 181 | 182 | return default 183 | -------------------------------------------------------------------------------- /xtructure/core/xtructure_decorators/hash.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from .. import Xtructurable 5 | 6 | 7 | def rotl(x, n): 8 | """Rotate left operation for 32-bit integers.""" 9 | return (x << n) | (x >> (32 - n)) 10 | 11 | 12 | @jax.jit 13 | def xxhash(x, seed): 14 | """ 15 | Implementation of xxHash algorithm for 32-bit integers. 16 | Args: 17 | x: Input value to hash 18 | seed: Seed value for hash function 19 | Returns: 20 | 32-bit hash value 21 | """ 22 | prime_1 = jnp.uint32(0x9E3779B1) 23 | prime_2 = jnp.uint32(0x85EBCA77) 24 | prime_3 = jnp.uint32(0xC2B2AE3D) 25 | prime_5 = jnp.uint32(0x165667B1) 26 | acc = jnp.uint32(seed) + prime_5 27 | for _ in range(4): 28 | lane = x & 255 29 | acc = acc + lane * prime_5 30 | acc = rotl(acc, 11) * prime_1 31 | x = x >> 8 32 | acc = acc ^ (acc >> 15) 33 | acc = acc * prime_2 34 | acc = acc ^ (acc >> 13) 35 | acc = acc * prime_3 36 | acc = acc ^ (acc >> 16) 37 | return acc 38 | 39 | 40 | def byterize_hash_func_builder(x: Xtructurable): 41 | """ 42 | Build a hash function for the pytree. 43 | This function creates a JIT-compiled hash function that converts pytree to bytes 44 | and then to uint32 arrays for hashing. 
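    Schematically, the pipeline built here is (illustrative; the intermediate
    names are assumptions, not functions defined in this module):
        bytes = byterize(instance)        # leaves bitcast to uint8, concatenated
        words = to_uint32(pad(bytes))     # zero-pad to a multiple of 4, view as uint32
        hash_value = lax.scan(xxhash, seed, words)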
45 | 46 | Args: 47 | x: Example pytree to determine the structure 48 | Returns: 49 | JIT-compiled hash function that takes a pytree and seed 50 | """ 51 | 52 | @jax.jit 53 | def _to_bytes(x): 54 | """Convert input to byte array.""" 55 | # Check if x is a JAX boolean array and cast to uint8 if true 56 | if x.dtype == jnp.bool_: 57 | x = x.astype(jnp.uint8) 58 | return jax.lax.bitcast_convert_type(x, jnp.uint8).reshape(-1) 59 | 60 | @jax.jit 61 | def _byterize(x): 62 | """Convert entire state tree to flattened byte array.""" 63 | x = jax.tree_util.tree_map(_to_bytes, x) 64 | x, _ = jax.tree_util.tree_flatten(x) 65 | if len(x) == 0: 66 | return jnp.array([], dtype=jnp.uint8) 67 | return jnp.concatenate(x) 68 | 69 | default_bytes = _byterize(x.default()) 70 | bytes_len = default_bytes.shape[0] 71 | # Calculate padding needed to make byte length multiple of 4 72 | pad_len = jnp.where(bytes_len % 4 != 0, 4 - (bytes_len % 4), 0) 73 | 74 | if pad_len > 0: 75 | 76 | def _to_uint32(bytes): 77 | """Convert padded bytes to uint32 array.""" 78 | x_padded = jnp.pad(bytes, (pad_len, 0), mode="constant", constant_values=0) 79 | x_reshaped = jnp.reshape(x_padded, (-1, 4)) 80 | return jax.vmap(lambda x: jax.lax.bitcast_convert_type(x, jnp.uint32))( 81 | x_reshaped 82 | ).reshape(-1) 83 | 84 | else: 85 | 86 | def _to_uint32(bytes): 87 | """Convert bytes directly to uint32 array.""" 88 | x_reshaped = jnp.reshape(bytes, (-1, 4)) 89 | return jax.vmap(lambda x: jax.lax.bitcast_convert_type(x, jnp.uint32))( 90 | x_reshaped 91 | ).reshape(-1) 92 | 93 | def _h(x, seed=0): 94 | """ 95 | Main hash function that converts state to bytes and applies xxhash. 96 | Returns both hash value and byte representation. 97 | """ 98 | bytes = x.bytes 99 | uint32ed = _to_uint32(bytes) 100 | 101 | def scan_body(seed, x): 102 | result = xxhash(x, seed) 103 | return result, result 104 | 105 | hash_value, _ = jax.lax.scan(scan_body, seed, uint32ed) 106 | return hash_value, bytes 107 | 108 | return jax.jit(_byterize), jax.jit(_h) 109 | 110 | 111 | def hash_function_decorator(cls): 112 | """ 113 | Decorator to add a hash function to a class. 114 | """ 115 | byterize, hash_func = byterize_hash_func_builder(cls) 116 | 117 | setattr(cls, "bytes", property(byterize)) 118 | setattr(cls, "hash", hash_func) 119 | 120 | return cls 121 | -------------------------------------------------------------------------------- /xtructure/core/xtructure_decorators/indexing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Type, TypeVar 2 | 3 | import jax.numpy as jnp 4 | 5 | T = TypeVar("T") 6 | 7 | 8 | class _Updater: 9 | def __init__(self, obj_instance, index): 10 | self.obj_instance = obj_instance 11 | self.indices = index 12 | self.cls = obj_instance.__class__ 13 | 14 | def set(self, values_to_set): 15 | new_field_data = {} 16 | 17 | if not hasattr(self.cls, "__dataclass_fields__"): 18 | raise TypeError( 19 | f"Class {self.cls.__name__} is not a recognized dataclass or does not have __dataclass_fields__. " 20 | f"The .at[...].set(...) feature expects a dataclass structure." 
21 | ) 22 | 23 | for field_name in self.cls.__dataclass_fields__: 24 | current_field_value = getattr(self.obj_instance, field_name) 25 | 26 | try: 27 | updater_ref = current_field_value.at[self.indices] 28 | if hasattr(updater_ref, "set"): 29 | value_for_this_field = None 30 | if isinstance(values_to_set, self.cls): 31 | value_for_this_field = getattr(values_to_set, field_name) 32 | else: 33 | value_for_this_field = values_to_set 34 | 35 | new_field_data[field_name] = updater_ref.set(value_for_this_field) 36 | else: 37 | new_field_data[field_name] = current_field_value 38 | except Exception: 39 | new_field_data[field_name] = current_field_value 40 | 41 | return self.cls(**new_field_data) 42 | 43 | def set_as_condition(self, condition: jnp.ndarray, value_to_conditionally_set: Any): 44 | """ 45 | Sets parts of the fields of the dataclass instance based on a condition. 46 | This is an out-of-place update. 47 | 48 | Args: 49 | condition: A JAX boolean array. Its shape should be compatible with 50 | the slice of the fields selected by `self.indices` through broadcasting. 51 | It determines element-wise whether to use the new value 52 | or the original value. 53 | value_to_conditionally_set: The value(s) to set if the condition is true. 54 | - If it's an instance of the same dataclass type (`self.cls`), 55 | the corresponding fields from this instance are used for updates. 56 | - Otherwise (e.g., a scalar or a JAX array), this value is used 57 | for updating all applicable fields (it must be broadcast-compatible 58 | with the slice of each field). 59 | Returns: 60 | A new instance of the dataclass with updated fields. 61 | """ 62 | new_field_data = {} 63 | 64 | if not hasattr(self.cls, "__dataclass_fields__"): 65 | raise TypeError( 66 | f"Class {self.cls.__name__} is not a recognized dataclass or does not have __dataclass_fields__. " 67 | f"The .at[...].set_as_condition(...) feature expects a dataclass structure." 68 | ) 69 | 70 | for field_name in self.cls.__dataclass_fields__: 71 | original_field_value = getattr(self.obj_instance, field_name) 72 | 73 | update_val_for_this_field_if_true = None 74 | if isinstance(value_to_conditionally_set, self.cls): 75 | update_val_for_this_field_if_true = getattr(value_to_conditionally_set, field_name) 76 | else: 77 | update_val_for_this_field_if_true = value_to_conditionally_set 78 | 79 | try: 80 | if isinstance(getattr(original_field_value, "at", None), AtIndexer): 81 | nested_updater = original_field_value.at[self.indices] 82 | new_field_data[field_name] = nested_updater.set_as_condition( 83 | condition, update_val_for_this_field_if_true 84 | ) 85 | elif hasattr(original_field_value, "at") and hasattr( 86 | original_field_value.at[self.indices], "set" 87 | ): 88 | original_slice_of_field = original_field_value[self.indices] 89 | 90 | # Ensure condition is a JAX array to get its ndim property 91 | cond_array = jnp.asarray(condition) 92 | data_rank = original_slice_of_field.ndim 93 | condition_rank = cond_array.ndim 94 | 95 | reshaped_cond = cond_array 96 | if data_rank > condition_rank: 97 | num_new_axes = data_rank - condition_rank 98 | reshaped_cond = cond_array.reshape(cond_array.shape + (1,) * num_new_axes) 99 | # If condition_rank >= data_rank, jnp.where will handle broadcasting or error appropriately. 
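                    # Worked example of the rank promotion above (illustrative):
                    # a field slice of shape (4, 3) combined with a condition of
                    # shape (4,) gives data_rank=2 and condition_rank=1, so the
                    # condition is reshaped to (4, 1) and broadcasts across the
                    # trailing axis in the jnp.where call below.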
100 | 
101 |                 conditionally_updated_slice = jnp.where(
102 |                     reshaped_cond, update_val_for_this_field_if_true, original_slice_of_field
103 |                 )
104 |                 new_field_data[field_name] = original_field_value.at[self.indices].set(
105 |                     conditionally_updated_slice
106 |                 )
107 |             else:
108 |                 new_field_data[field_name] = original_field_value
109 |         except Exception as e:
110 |             import sys
111 | 
112 |             print(
113 |                 f"Warning: Could not apply conditional set to field '{field_name}' "
114 |                 f"of class '{self.cls.__name__}'. Error: {e}",
115 |                 file=sys.stderr,
116 |             )
117 |             new_field_data[field_name] = original_field_value
118 | 
119 |         return self.cls(**new_field_data)
120 | 
121 | 
122 | class AtIndexer:
123 |     def __init__(self, obj_instance):
124 |         self.obj_instance = obj_instance
125 | 
126 |     def __getitem__(self, index):
127 |         return _Updater(self.obj_instance, index)
128 | 
129 | 
130 | def add_indexing_methods(cls: Type[T]) -> Type[T]:
131 |     """
132 |     Augments the class with an `__getitem__` method for indexing/slicing
133 |     and an `at` property that enables JAX-like out-of-place updates
134 |     (e.g., `instance.at[index].set(value)`).
135 | 
136 |     The `__getitem__` method allows instances to be indexed, applying the
137 |     index to each field.
138 |     The `at` property provides access to an updater object for specific indices.
139 |     """
140 | 
141 |     def getitem(self, index):
142 |         """Support indexing operations on the dataclass"""
143 |         new_values = {}
144 |         for field_name, field_value in self.__dict__.items():
145 |             if hasattr(field_value, "__getitem__"):
146 |                 new_values[field_name] = field_value[index]
147 |             else:
148 |                 new_values[field_name] = field_value
149 |         return cls(**new_values)
150 | 
151 |     setattr(cls, "__getitem__", getitem)
152 |     setattr(cls, "at", property(AtIndexer))
153 | 
154 |     return cls
155 | 
--------------------------------------------------------------------------------
/xtructure/core/xtructure_decorators/ops.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Type, TypeVar
2 | 
3 | import jax
4 | import jax.numpy as jnp
5 | 
6 | T = TypeVar("T")
7 | 
8 | 
9 | def add_comparison_operators(cls: Type[T]) -> Type[T]:
10 |     """
11 |     Adds custom __eq__ and __ne__ methods to the class. Each compares the
12 |     corresponding fields of two instances and reduces the result to a single
13 |     scalar boolean: __eq__ is True only if every element of every field
14 |     matches, and __ne__ is True if any element differs.
15 |     """
16 | 
17 |     def _xtructure_eq(self, other: Any) -> T:
18 |         if not isinstance(other, self.__class__):
19 |             # When comparing against a different type, one might return
20 |             # False or NotImplemented; for an element-wise comparison one
21 |             # could also raise an error or return a structure of False
22 |             # values. JAX's __eq__ on raw arrays would raise or broadcast
23 |             # if shapes are incompatible, and because users may expect a
24 |             # single boolean from ==, any such override can be surprising.
25 |             # A more robust approach for general pytrees might involve
26 |             # checking tree-structure compatibility before comparing
27 |             # leaves. For now, returning NotImplemented is the safest
28 |             # choice when 'other' is not an instance of the same class.
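            # Concretely (illustrative): for two instances a and b of the
            # same decorated class, `a == b` evaluates jnp.all(x == y) per
            # leaf and folds the results with logical_and, yielding one
            # scalar jnp.bool_ for the whole structure, not a field-wise mask.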
29 | return NotImplemented 30 | 31 | # Element-wise comparison for each field 32 | tree_equal = jax.tree_util.tree_map(lambda x, y: jnp.all(x == y), self, other) 33 | return jax.tree_util.tree_reduce(jnp.logical_and, tree_equal) 34 | 35 | def _xtructure_ne(self, other: Any) -> T: 36 | if not isinstance(other, self.__class__): 37 | return NotImplemented 38 | 39 | # Element-wise comparison for each field 40 | tree_equal = jax.tree_util.tree_map(lambda x, y: jnp.any(x != y), self, other) 41 | return jax.tree_util.tree_reduce(jnp.logical_or, tree_equal) 42 | 43 | setattr(cls, "__eq__", _xtructure_eq) 44 | setattr(cls, "__ne__", _xtructure_ne) 45 | 46 | return cls 47 | -------------------------------------------------------------------------------- /xtructure/core/xtructure_decorators/shape.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import Type, TypeVar 3 | 4 | from xtructure.core.field_descriptors import FieldDescriptor 5 | from xtructure.core.protocol import StructuredType 6 | 7 | T = TypeVar("T") 8 | 9 | 10 | def add_shape_dtype_len(cls: Type[T]) -> Type[T]: 11 | """ 12 | Augments the class with `shape` and `dtype` properties to inspect its 13 | fields, and a `__len__` method. 14 | 15 | The `shape` and `dtype` properties return namedtuples reflecting the 16 | structure of the dataclass fields. 17 | The `__len__` method conventionally returns the size of the first 18 | dimension of the first field of the instance, which is often useful 19 | for determining batch sizes. 20 | """ 21 | shape_tuple = namedtuple("shape", ["batch"] + list(cls.__annotations__.keys())) 22 | field_descriptors: dict[str, FieldDescriptor] = cls.__annotations__ 23 | default_shape = namedtuple("default_shape", cls.__annotations__.keys())( 24 | *[fd.intrinsic_shape for fd in field_descriptors.values()] 25 | ) 26 | default_dtype = namedtuple("default_dtype", cls.__annotations__.keys())( 27 | *[fd.dtype for fd in field_descriptors.values()] 28 | ) 29 | 30 | cls.default_shape = default_shape 31 | cls.default_dtype = default_dtype 32 | 33 | def get_shape(self) -> shape_tuple: 34 | """ 35 | Returns a namedtuple containing the batch shape (if present) and the shapes of all fields. 36 | If a field is itself a xtructure_dataclass, its shape is included as a nested namedtuple. 37 | """ 38 | # Determine batch: if all fields have a leading batch dimension of the same size, use it. 39 | # Otherwise, batch is (). 
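        # Worked example (illustrative): a field declared with
        # intrinsic_shape=(2,) whose stored array has shape (5, 2)
        # contributes a batch of (5,); if every field agrees, shape.batch
        # is (5,). Any disagreement between fields, or trailing dims that
        # do not match a field's intrinsic shape, collapses batch to -1,
        # which get_structured_type below reports as UNSTRUCTURED.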
40 | field_shapes = [] 41 | batch_shapes = [] 42 | for field_name in cls.__annotations__.keys(): 43 | shape = getattr(self, field_name).shape 44 | default_shape_field = getattr(default_shape, field_name) 45 | if ( 46 | isinstance(shape, tuple) 47 | and hasattr(shape, "_fields") 48 | and shape.__class__.__name__ == "shape" 49 | ): 50 | # If the field is itself a xtructure_dataclass (nested shape_tuple) 51 | if default_shape_field == (): 52 | batch_shapes.append(shape.batch) 53 | shape = shape.__class__((), *shape[1:]) 54 | elif shape.batch[-len(default_shape_field) :] == default_shape_field: 55 | batch_shapes.append(shape.batch[: -len(default_shape_field)]) 56 | cuted_batch_shape = shape.batch[-len(default_shape_field) :] 57 | shape = shape.__class__(cuted_batch_shape, *shape[1:]) 58 | else: 59 | batch_shapes.append(-1) 60 | else: 61 | if default_shape_field == (): 62 | batch_shapes.append(shape) 63 | shape = () 64 | elif shape[-len(default_shape_field) :] == default_shape_field: 65 | batch_shapes.append(shape[: -len(default_shape_field)]) 66 | shape = shape[-len(default_shape_field) :] 67 | else: 68 | batch_shapes.append(-1) 69 | field_shapes.append(shape) 70 | 71 | final_batch_shape = batch_shapes[0] 72 | for batch_shape in batch_shapes[1:]: 73 | if batch_shape == -1: 74 | final_batch_shape = -1 75 | break 76 | if final_batch_shape != batch_shape: 77 | final_batch_shape = -1 78 | break 79 | return shape_tuple(final_batch_shape, *field_shapes) 80 | 81 | setattr(cls, "shape", property(get_shape)) 82 | 83 | type_tuple = namedtuple("dtype", cls.__annotations__.keys()) 84 | 85 | def get_type(self) -> type_tuple: 86 | """Get dtypes of all fields in the dataclass""" 87 | return type_tuple( 88 | *[getattr(self, field_name).dtype for field_name in cls.__annotations__.keys()] 89 | ) 90 | 91 | setattr(cls, "dtype", property(get_type)) 92 | 93 | def get_len(self): 94 | """Get length of the first field's first dimension""" 95 | return self.shape[0][0] 96 | 97 | setattr(cls, "__len__", get_len) 98 | 99 | def get_structured_type(self) -> StructuredType: 100 | shape = self.shape 101 | if shape.batch == (): 102 | return StructuredType.SINGLE 103 | elif shape.batch == -1: 104 | return StructuredType.UNSTRUCTURED 105 | else: 106 | return StructuredType.BATCHED 107 | 108 | setattr(cls, "structured_type", property(get_structured_type)) 109 | 110 | return cls 111 | -------------------------------------------------------------------------------- /xtructure/core/xtructure_decorators/string_format.py: -------------------------------------------------------------------------------- 1 | from typing import Type, TypeVar 2 | 3 | import jax.numpy as jnp 4 | from tabulate import tabulate 5 | 6 | from xtructure.core.structuredtype import StructuredType 7 | 8 | from .annotate import MAX_PRINT_BATCH_SIZE, SHOW_BATCH_SIZE 9 | 10 | T = TypeVar("T") 11 | 12 | 13 | def add_string_representation_methods(cls: Type[T]) -> Type[T]: 14 | """ 15 | Adds custom `__str__` and `str` methods to the class for generating 16 | a more informative string representation. 17 | 18 | It handles instances categorized by `structured_type` differently: 19 | - `SINGLE`: Uses the original `__str__` (or `repr` if basic) of the instance. 20 | - `BATCHED`: Provides a summarized view if the batch is large, showing 21 | the first few and last few elements, along with the batch shape. 22 | Uses `tabulate` for formatting. 23 | - `UNSTRUCTURED`: Indicates that the data is unstructured relative to its 24 | default shape. 
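
    For a batch larger than MAX_PRINT_BATCH_SIZE, the summarized output is
    roughly of the form (illustrative):
        item[0] item[1] ... (batch : (10000,)) item[9998] item[9999]
    i.e. SHOW_BATCH_SIZE entries from each end around a batch-shape marker.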
25 | """ 26 | 27 | # Capture the class's __str__ method as it exists *before* this decorator replaces it. 28 | # This will typically be the __str__ provided by chex.dataclass (similar to its __repr__), 29 | # or a user-defined __str__ if the user added one before @xtructure_data. 30 | _original_str_method = getattr(cls, "__str__", None) 31 | 32 | # Determine the function to use for formatting a single item. 33 | # If the original __str__ is just the basic one from `object`, it's not very informative. 34 | # In such cases, or if no __str__ was found, `repr` is a better fallback for dataclasses. 35 | if _original_str_method is None or _original_str_method == object.__str__: 36 | _single_item_formatter = _custom_pretty_formatter 37 | else: 38 | # Use the captured original __str__ method. 39 | def _single_item_formatter(item, **k): 40 | return _original_str_method(item, **k) 41 | 42 | # Note: Original __str__ methods typically don't take **kwargs. 43 | # If kwargs support is needed for the single item formatter, 44 | # the user would need to define a specific method and the decorator would look for that. 45 | # For now, we assume the original __str__ doesn't use kwargs from get_str. 46 | 47 | def get_str(self, use_kwargs: bool = False, **kwargs) -> str: 48 | # This 'self' is an instance of the decorated class 'cls' 49 | # 'kwargs' are passed from the print(instance) or str(instance) call. 50 | 51 | structured_type = self.structured_type # This must be a valid property 52 | 53 | if structured_type == StructuredType.SINGLE: 54 | # For a single item, call the chosen formatter. 55 | if use_kwargs: 56 | return _single_item_formatter(self, **kwargs) 57 | else: 58 | return _single_item_formatter(self) # **kwargs will be an empty dict 59 | 60 | elif structured_type == StructuredType.BATCHED: 61 | batch_shape = self.shape.batch 62 | batch_len_val = ( 63 | jnp.prod(jnp.array(batch_shape)) if len(batch_shape) != 1 else batch_shape[0] 64 | ) 65 | py_batch_len = int(batch_len_val) 66 | 67 | results = [] 68 | if py_batch_len <= MAX_PRINT_BATCH_SIZE: 69 | for i in range(py_batch_len): 70 | index = jnp.unravel_index(i, batch_shape) 71 | current_state_slice = self[index] 72 | # kwargs_idx = {k: v[index] for k, v in kwargs.items()} # Index kwargs if they are batched 73 | # For now, assume single_item_formatter doesn't use these indexed kwargs 74 | if use_kwargs: 75 | results.append(_single_item_formatter(current_state_slice, **kwargs)) 76 | else: 77 | results.append(_single_item_formatter(current_state_slice)) 78 | else: 79 | for i in range(SHOW_BATCH_SIZE): 80 | index = jnp.unravel_index(i, batch_shape) 81 | current_state_slice = self[index] 82 | if use_kwargs: 83 | results.append(_single_item_formatter(current_state_slice, **kwargs)) 84 | else: 85 | results.append(_single_item_formatter(current_state_slice)) 86 | 87 | results.append("...\n(batch : " + f"{batch_shape})") 88 | 89 | for i in range(py_batch_len - SHOW_BATCH_SIZE, py_batch_len): 90 | index = jnp.unravel_index(i, batch_shape) 91 | current_state_slice = self[index] 92 | if use_kwargs: 93 | results.append(_single_item_formatter(current_state_slice, **kwargs)) 94 | else: 95 | results.append(_single_item_formatter(current_state_slice)) 96 | return tabulate([results], tablefmt="plain") 97 | else: # UNSTRUCTURED or any other case 98 | # Fallback for unstructured or unexpected types to avoid errors, 99 | # or re-raise the original error if preferred. 
100 | # The original code raised: ValueError(f"State is not structured: {self.shape} != {self.default_shape}") 101 | # Using repr as a safe fallback: 102 | return f"" 103 | 104 | setattr(cls, "__str__", lambda self, **kwargs: get_str(self, use_kwargs=False, **kwargs)) 105 | setattr( 106 | cls, "str", lambda self, **kwargs: get_str(self, use_kwargs=True, **kwargs) 107 | ) # Alias .str to the new __str__ 108 | return cls 109 | 110 | 111 | def _custom_pretty_formatter(item, **_kwargs): # Accepts and ignores _kwargs for now 112 | class_name = item.__class__.__name__ 113 | 114 | field_values = {} 115 | # Prioritize __dataclass_fields__ for declared fields in dataclasses 116 | if hasattr(item, "__dataclass_fields__"): 117 | for field_name_df in getattr(item, "__dataclass_fields__", {}).keys(): 118 | try: 119 | field_values[field_name_df] = getattr(item, field_name_df) 120 | except AttributeError: 121 | # Field declared but not present; should be rare for dataclasses 122 | pass 123 | elif hasattr(item, "__dict__"): 124 | # Fallback for non-dataclasses or if __dataclass_fields__ is not found/empty 125 | field_values = item.__dict__ 126 | else: 127 | # No way to access fields, fallback to simple repr 128 | return repr(item) 129 | 130 | if not field_values: 131 | return f"{class_name}()" 132 | 133 | parts = [] 134 | for name, value in field_values.items(): 135 | try: 136 | value_str = str(value) # Use str() to leverage our enhanced __str__ for nested items 137 | except Exception: 138 | value_str = "" 139 | 140 | if "\n" in value_str: 141 | # Indent all lines of the multi-line value string for better readability 142 | indented_value = "\n".join([" " + line for line in value_str.split("\n")]) 143 | parts.append(f" {name}: \n{indented_value}") 144 | else: 145 | parts.append(f" {name}: {value_str}") 146 | 147 | return f"{class_name}(\n" + ",\n".join(parts) + "\n)" 148 | -------------------------------------------------------------------------------- /xtructure/core/xtructure_decorators/structure_util.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Type, TypeVar 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from xtructure.core.field_descriptors import FieldDescriptor 7 | from xtructure.core.structuredtype import StructuredType 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | def is_nested_xtructure(dtype: Any) -> bool: 13 | if isinstance(dtype, type): 14 | if hasattr(dtype, "is_xtructed"): 15 | return True 16 | return False 17 | else: 18 | return False 19 | 20 | 21 | def add_structure_utilities(cls: Type[T]) -> Type[T]: 22 | """ 23 | Augments the class with utility methods and properties related to its 24 | structural representation (based on a 'default' instance), batch operations, 25 | and random instance generation. 26 | 27 | Requires the class to have a `default` classmethod, which is used to 28 | determine default shapes, dtypes, and behaviors. 29 | 30 | Adds: 31 | - Properties: 32 | - `default_shape`: Shape of the instance returned by `cls.default()`. 33 | - `structured_type`: An enum (`StructuredType`) indicating if the 34 | instance is SINGLE, BATCHED, or UNSTRUCTURED relative to its 35 | default shape. 36 | - `batch_shape`: The shape of the batch dimensions if `structured_type` 37 | is BATCHED. 38 | - Instance Methods: 39 | - `reshape(new_shape)`: Reshapes the batch dimensions of a BATCHED instance. 40 | - `flatten()`: Flattens the batch dimensions of a BATCHED instance. 
41 | - Classmethod: 42 | - `random(shape=(), key=None)`: Generates an instance with random data. 43 | The `shape` argument specifies the desired batch shape, which is 44 | prepended to the default field shapes. 45 | """ 46 | assert hasattr(cls, "default"), "There is no default method." 47 | 48 | field_descriptors: dict[str, FieldDescriptor] = cls.__annotations__ 49 | default_shape = dict([(fn, fd.intrinsic_shape) for fn, fd in field_descriptors.items()]) 50 | default_dtype = dict([(fn, fd.dtype) for fn, fd in field_descriptors.items()]) 51 | 52 | # Pre-calculate generation configurations for the random method 53 | _field_generation_configs = [] 54 | # Ensure consistent order for key splitting, matching __annotations__ 55 | _field_names_for_random = list(cls.__annotations__.keys()) 56 | 57 | for field_name_cfg in _field_names_for_random: 58 | cfg = {} 59 | cfg["name"] = field_name_cfg 60 | # Retrieve the dtype or nested dtype tuple for the current field 61 | actual_dtype_or_nested_dtype_tuple = default_dtype[field_name_cfg] 62 | cfg["default_field_shape"] = default_shape[field_name_cfg] 63 | 64 | if is_nested_xtructure(actual_dtype_or_nested_dtype_tuple): 65 | # This field is a nested xtructure_data instance 66 | cfg["type"] = "xtructure" 67 | # Store the actual nested class type (e.g., Parent, Current) 68 | cfg["nested_class_type"] = cls.__annotations__[field_name_cfg].dtype 69 | # Store the namedtuple of dtypes for the nested structure 70 | cfg["actual_dtype"] = actual_dtype_or_nested_dtype_tuple 71 | else: 72 | # This field is a regular JAX array 73 | actual_dtype = actual_dtype_or_nested_dtype_tuple # It's a single JAX dtype here 74 | cfg["actual_dtype"] = actual_dtype # Store the single JAX dtype 75 | 76 | if jnp.issubdtype(actual_dtype, jnp.integer): 77 | cfg["type"] = "bits_int" # Unified type for all full-range integers via bits 78 | if jnp.issubdtype(actual_dtype, jnp.unsignedinteger): 79 | cfg["bits_gen_dtype"] = actual_dtype # Generate bits of this same unsigned type 80 | cfg["view_as_signed"] = False 81 | else: # It's a signed integer 82 | unsigned_equivalent_str = f"uint{actual_dtype.itemsize * 8}" 83 | cfg["bits_gen_dtype"] = jnp.dtype( 84 | unsigned_equivalent_str 85 | ) # Generate bits of corresponding unsigned type 86 | cfg["view_as_signed"] = True # And then view them as the actual signed type 87 | elif jnp.issubdtype(actual_dtype, jnp.floating): 88 | cfg["type"] = "float" 89 | cfg["gen_dtype"] = actual_dtype 90 | elif actual_dtype == jnp.bool_: 91 | cfg["type"] = "bool" 92 | else: 93 | cfg["type"] = "other" # Fallback 94 | cfg["gen_dtype"] = actual_dtype 95 | _field_generation_configs.append(cfg) 96 | 97 | def reshape(self, new_shape: tuple[int, ...]) -> T: 98 | if self.structured_type == StructuredType.BATCHED: 99 | total_length = jnp.prod(jnp.array(self.shape.batch)) 100 | new_total_length = jnp.prod(jnp.array(new_shape)) 101 | batch_dim = len(self.shape.batch) 102 | if total_length != new_total_length: 103 | raise ValueError( 104 | f"Total length of the state and new shape does not match: {total_length} != {new_total_length}" 105 | ) 106 | return jax.tree_util.tree_map( 107 | lambda x: jnp.reshape(x, new_shape + x.shape[batch_dim:]), self 108 | ) 109 | else: 110 | raise ValueError( 111 | f"Reshape is only supported for BATCHED structured_type. Current type: '{self.structured_type}'." 
112 | f"Shape: {self.shape}, Default Shape: {self.default_shape}" 113 | ) 114 | 115 | def flatten(self): 116 | if self.structured_type != StructuredType.BATCHED: 117 | raise ValueError( 118 | f"Flatten operation is only supported for BATCHED structured types. " 119 | f"Current type: {self.structured_type}" 120 | ) 121 | 122 | current_batch_shape = self.shape.batch 123 | # jnp.prod of an empty tuple array is 1, which is correct for total_length 124 | # if current_batch_shape is (). 125 | total_length = jnp.prod(jnp.array(current_batch_shape)) 126 | len_current_batch_shape = len(current_batch_shape) 127 | 128 | return jax.tree_util.tree_map( 129 | # Reshape each leaf: flatten batch dims, keep core dims. 130 | # core_dims are obtained by stripping batch_dims from the start of x.shape. 131 | lambda x: jnp.reshape(x, (total_length,) + x.shape[len_current_batch_shape:]), 132 | self, 133 | ) 134 | 135 | def random(cls, shape=(), key=None): 136 | if key is None: 137 | key = jax.random.PRNGKey(0) 138 | 139 | data = {} 140 | keys = jax.random.split(key, len(_field_generation_configs)) 141 | 142 | for i, cfg in enumerate(_field_generation_configs): 143 | field_key = keys[i] 144 | field_name = cfg["name"] 145 | 146 | if cfg["type"] == "xtructure": 147 | nested_class = cfg["nested_class_type"] 148 | # For nested xtructures, combine batch shape with field shape 149 | current_default_shape = cfg["default_field_shape"] 150 | target_shape = shape + current_default_shape 151 | # Recursively call random for the nested xtructure_data class. 152 | data[field_name] = nested_class.random(shape=target_shape, key=field_key) 153 | else: 154 | # This branch handles primitive JAX array fields. 155 | current_default_shape = cfg["default_field_shape"] 156 | if not isinstance(current_default_shape, tuple): 157 | current_default_shape = ( 158 | current_default_shape, 159 | ) # Ensure it's a tuple for concatenation 160 | 161 | target_shape = shape + current_default_shape 162 | 163 | if cfg["type"] == "bits_int": 164 | generated_bits = jax.random.bits( 165 | field_key, shape=target_shape, dtype=cfg["bits_gen_dtype"] 166 | ) 167 | if cfg["view_as_signed"]: 168 | data[field_name] = generated_bits.view(cfg["actual_dtype"]) 169 | else: 170 | data[field_name] = generated_bits 171 | elif cfg["type"] == "float": 172 | data[field_name] = jax.random.uniform( 173 | field_key, target_shape, dtype=cfg["gen_dtype"] 174 | ) 175 | elif cfg["type"] == "bool": 176 | data[field_name] = jax.random.bernoulli( 177 | field_key, shape=target_shape # p=0.5 by default 178 | ) 179 | else: # Fallback for 'other' dtypes (cfg['type'] == 'other') 180 | try: 181 | data[field_name] = jnp.zeros(target_shape, dtype=cfg["gen_dtype"]) 182 | except TypeError: 183 | raise NotImplementedError( 184 | f"Random generation for dtype {cfg['gen_dtype']} " 185 | f"(field: {field_name}) is not implemented robustly." 
186 | ) 187 | return cls(**data) 188 | 189 | # add method based on default state 190 | setattr(cls, "reshape", reshape) 191 | setattr(cls, "flatten", flatten) 192 | setattr(cls, "random", classmethod(random)) 193 | return cls 194 | -------------------------------------------------------------------------------- /xtructure/hashtable/__init__.py: -------------------------------------------------------------------------------- 1 | from .hashtable import HashTable 2 | 3 | __all__ = ["HashTable"] 4 | -------------------------------------------------------------------------------- /xtructure/hashtable/hashtable.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hash table implementation using Cuckoo hashing technique for efficient state storage and lookup. 3 | This module provides functionality for hashing Xtructurables and managing collisions. 4 | """ 5 | 6 | from functools import partial 7 | from typing import TypeVar 8 | 9 | import chex 10 | import jax 11 | import jax.numpy as jnp 12 | 13 | from ..core import Xtructurable 14 | 15 | SIZE_DTYPE = jnp.uint32 16 | HASH_TABLE_IDX_DTYPE = jnp.uint8 17 | 18 | T = TypeVar("T") 19 | 20 | 21 | @chex.dataclass 22 | class HashTable: 23 | """ 24 | Cuckoo Hash Table Implementation 25 | 26 | This implementation uses multiple hash functions (specified by n_table) 27 | to resolve collisions. Each item can be stored in one of n_table possible positions. 28 | 29 | Attributes: 30 | seed: Initial seed for hash functions 31 | capacity: User-specified capacity 32 | _capacity: Actual internal capacity (larger than specified to handle collisions) 33 | size: Current number of items in table 34 | table: The actual storage for states 35 | table_idx: Indices tracking which hash function was used for each entry 36 | """ 37 | 38 | seed: int 39 | capacity: int 40 | _capacity: int 41 | cuckoo_table_n: int 42 | size: int 43 | table: Xtructurable # shape = State("args" = (capacity, cuckoo_len, ...), ...) 44 | table_idx: chex.Array # shape = (capacity, ) is the index of the table in the cuckoo table. 45 | 46 | @staticmethod 47 | @partial(jax.jit, static_argnums=(0, 1, 2, 3, 4)) 48 | def build( 49 | dataclass: Xtructurable, 50 | seed: int, 51 | capacity: int, 52 | cuckoo_table_n: int = 2, 53 | hash_size_multiplier: int = 2, 54 | ): 55 | """ 56 | Initialize a new hash table with specified parameters. 57 | 58 | Args: 59 | dataclass: Example Xtructurable to determine the structure 60 | seed: Initial seed for hash functions 61 | capacity: Desired capacity of the table 62 | 63 | Returns: 64 | Initialized HashTable instance 65 | """ 66 | _capacity = int( 67 | hash_size_multiplier * capacity / cuckoo_table_n 68 | ) # Convert to concrete integer 69 | size = SIZE_DTYPE(0) 70 | # Initialize table with default states 71 | table = dataclass.default((_capacity + 1, cuckoo_table_n)) 72 | table_idx = jnp.zeros((_capacity + 1), dtype=HASH_TABLE_IDX_DTYPE) 73 | return HashTable( 74 | seed=seed, 75 | capacity=capacity, 76 | _capacity=_capacity, 77 | cuckoo_table_n=cuckoo_table_n, 78 | size=size, 79 | table=table, 80 | table_idx=table_idx, 81 | ) 82 | 83 | @staticmethod 84 | def get_new_idx( 85 | table: "HashTable", 86 | input: Xtructurable, 87 | seed: int, 88 | ): 89 | """ 90 | Calculate new index for input state using the hash function. 
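        The index is the xxhash-based hash value of the state reduced modulo
        the internal capacity, i.e. idx = input.hash(seed)[0] % table._capacity.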
91 | 92 | Args: 93 | table: Hash table instance 94 | input: State to hash 95 | seed: Seed for hash function 96 | 97 | Returns: 98 | Index in the table for the input state 99 | """ 100 | hash_value, _ = input.hash(seed) 101 | idx = hash_value % table._capacity 102 | return idx 103 | 104 | @staticmethod 105 | def get_new_idx_byterized( 106 | table: "HashTable", 107 | input: Xtructurable, 108 | seed: int, 109 | ): 110 | """ 111 | Calculate new index and return byte representation of input state. 112 | Similar to get_new_idx but also returns the byte representation for 113 | equality comparison. 114 | """ 115 | hash_value, bytes = input.hash(seed) 116 | idx = hash_value % table._capacity 117 | return idx, bytes 118 | 119 | @staticmethod 120 | def _lookup( 121 | table: "HashTable", 122 | input: Xtructurable, 123 | idx: int, 124 | table_idx: int, 125 | seed: int, 126 | found: bool, 127 | ): 128 | """ 129 | Internal lookup method that searches for a state in the table. 130 | Uses cuckoo hashing technique to check multiple possible locations. 131 | 132 | Args: 133 | table: Hash table instance 134 | input: State to look up 135 | idx: Initial index to check 136 | table_idx: Which hash function to start with 137 | seed: Initial seed 138 | found: Whether the state has been found 139 | 140 | Returns: 141 | Tuple of (seed, idx, table_idx, found) 142 | """ 143 | 144 | def _cond(val): 145 | seed, idx, table_idx, found = val 146 | filled_idx = table.table_idx[idx] 147 | in_empty = table_idx >= filled_idx 148 | return jnp.logical_and(~found, ~in_empty) 149 | 150 | def _while(val): 151 | seed, idx, table_idx, found = val 152 | 153 | def get_new_idx_and_table_idx(seed, idx, table_idx): 154 | next_table = table_idx >= (table.cuckoo_table_n - 1) 155 | seed, idx, table_idx = jax.lax.cond( 156 | next_table, 157 | lambda _: ( 158 | seed + 1, 159 | HashTable.get_new_idx(table, input, seed + 1), 160 | HASH_TABLE_IDX_DTYPE(0), 161 | ), 162 | lambda _: (seed, idx, HASH_TABLE_IDX_DTYPE(table_idx + 1)), 163 | None, 164 | ) 165 | return seed, idx, table_idx 166 | 167 | state = table.table[idx, table_idx] 168 | found = state == input 169 | seed, idx, table_idx = jax.lax.cond( 170 | found, 171 | lambda _: (seed, idx, table_idx), 172 | lambda _: get_new_idx_and_table_idx(seed, idx, table_idx), 173 | None, 174 | ) 175 | return seed, idx, table_idx, found 176 | 177 | state = table.table[idx, table_idx] 178 | found = jnp.logical_or(found, state == input) 179 | update_seed, idx, table_idx, found = jax.lax.while_loop( 180 | _cond, _while, (seed, idx, table_idx, found) 181 | ) 182 | return update_seed, idx, table_idx, found 183 | 184 | def lookup(table: "HashTable", input: Xtructurable): 185 | """ 186 | Finds the state in the hash table using Cuckoo hashing. 187 | 188 | Args: 189 | table: The HashTable instance. 190 | input: The Xtructurable state to look up. 191 | 192 | Returns: 193 | A tuple (idx, table_idx, found): 194 | - idx (int): The primary hash index in the table. 195 | - table_idx (int): The cuckoo table index (which hash function/slot was used or probed). 196 | - found (bool): True if the state was found, False otherwise. 197 | If not found, idx and table_idx indicate the first empty slot encountered 198 | during the Cuckoo search path where an insertion could occur. 
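
        Example (illustrative):
            idx, table_idx, found = HashTable.lookup(table, state)
            stored = table.table[idx, table_idx]  # meaningful only when found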
199 | """ 200 | index = HashTable.get_new_idx(table, input, table.seed) 201 | _, idx, table_idx, found = HashTable._lookup( 202 | table, input, index, HASH_TABLE_IDX_DTYPE(0), table.seed, False 203 | ) 204 | return idx, table_idx, found 205 | 206 | def insert(table: "HashTable", input: Xtructurable): 207 | """ 208 | insert the state in the table 209 | """ 210 | 211 | def _update_table(table: "HashTable", input: Xtructurable, idx: int, table_idx: int): 212 | """ 213 | insert the state in the table 214 | """ 215 | table.table = table.table.at[idx, table_idx].set(input) 216 | table.table_idx = table.table_idx.at[idx].add(1) 217 | return table 218 | 219 | idx, table_idx, found = HashTable.lookup(table, input) 220 | return ( 221 | jax.lax.cond( 222 | found, lambda _: table, lambda _: _update_table(table, input, idx, table_idx), None 223 | ), 224 | ~found, 225 | ) 226 | 227 | @staticmethod 228 | @partial( 229 | jax.jit, 230 | static_argnums=( 231 | 0, 232 | 2, 233 | ), 234 | ) 235 | def make_batched(statecls: Xtructurable, inputs: Xtructurable, batch_size: int): 236 | """ 237 | make a batched version of the inputs 238 | """ 239 | count = len(inputs) 240 | batched = jax.tree_util.tree_map( 241 | lambda x, y: jnp.concatenate([x, y]), 242 | inputs, 243 | statecls.default((batch_size - count,)), 244 | ) 245 | filled = jnp.concatenate([jnp.ones(count), jnp.zeros(batch_size - count)], dtype=jnp.bool_) 246 | return batched, filled 247 | 248 | @staticmethod 249 | def _parallel_insert( 250 | table: "HashTable", 251 | inputs: Xtructurable, 252 | seeds: chex.Array, 253 | index: chex.Array, 254 | updatable: chex.Array, 255 | batch_len: int, 256 | ): 257 | def _next_idx(seeds, _idxs, unupdateds): 258 | def get_new_idx_and_table_idx(seed, idx, table_idx, state): 259 | next_table = table_idx >= (table.cuckoo_table_n - 1) 260 | 261 | def next_table_fn(seed, table): 262 | next_idx = HashTable.get_new_idx(table, state, seed) 263 | seed = seed + 1 264 | return seed, next_idx, table.table_idx[next_idx].astype(jnp.uint32) 265 | 266 | seed, idx, table_idx = jax.lax.cond( 267 | next_table, 268 | next_table_fn, 269 | lambda seed, _: (seed, idx, table_idx + 1), 270 | seed, 271 | table, 272 | ) 273 | return seed, idx, table_idx 274 | 275 | idxs = _idxs[:, 0] 276 | table_idxs = _idxs[:, 1] 277 | seeds, idxs, table_idxs = jax.vmap( 278 | lambda unupdated, seed, idx, table_idx, state: jax.lax.cond( 279 | unupdated, 280 | lambda _: get_new_idx_and_table_idx(seed, idx, table_idx, state), 281 | lambda _: (seed, idx, table_idx), 282 | None, 283 | ) 284 | )(unupdateds, seeds, idxs, table_idxs, inputs) 285 | _idxs = jnp.stack((idxs, table_idxs), axis=1) 286 | return seeds, _idxs 287 | 288 | def _cond(val): 289 | _, _, unupdated = val 290 | return jnp.any(unupdated) 291 | 292 | def _while(val): 293 | seeds, _idxs, unupdated = val 294 | seeds, _idxs = _next_idx(seeds, _idxs, unupdated) 295 | 296 | overflowed = jnp.logical_and( 297 | _idxs[:, 1] >= table.cuckoo_table_n, unupdated 298 | ) # Overflowed index must be updated 299 | _idxs = jnp.where(updatable[:, jnp.newaxis], _idxs, jnp.full_like(_idxs, -1)) 300 | unique_idxs = jnp.unique(_idxs, axis=0, size=batch_len, return_index=True)[ 301 | 1 302 | ] # val = (unique_len, 2), unique_idxs = (unique_len,) 303 | not_uniques = ( 304 | jnp.ones((batch_len,), dtype=jnp.bool_).at[unique_idxs].set(False) 305 | ) # set the unique index to True 306 | 307 | unupdated = jnp.logical_and(updatable, not_uniques) 308 | unupdated = jnp.logical_or(unupdated, overflowed) 309 | return seeds, _idxs, 
unupdated 310 | 311 | _idxs = jnp.where(updatable[:, jnp.newaxis], index, jnp.full_like(index, -1)) 312 | unique_idxs = jnp.unique(_idxs, axis=0, size=batch_len, return_index=True)[ 313 | 1 314 | ] # val = (unique_len, 2), unique_idxs = (unique_len,) 315 | not_uniques = ( 316 | jnp.ones((batch_len,), dtype=jnp.bool_).at[unique_idxs].set(False) 317 | ) # set the unique index to True 318 | unupdated = jnp.logical_and( 319 | updatable, not_uniques 320 | ) # remove the unique index from the unupdated index 321 | 322 | seeds, index, _ = jax.lax.while_loop(_cond, _while, (seeds, _idxs, unupdated)) 323 | 324 | idx, table_idx = index[:, 0], index[:, 1].astype(HASH_TABLE_IDX_DTYPE) 325 | table.table = table.table.at[idx, table_idx].set_as_condition(updatable, inputs) 326 | table.table_idx = table.table_idx.at[idx].add(updatable) 327 | table.size += jnp.sum(updatable, dtype=SIZE_DTYPE) 328 | return table, idx, table_idx 329 | 330 | def parallel_insert(table: "HashTable", inputs: Xtructurable, filled: chex.Array): 331 | """ 332 | Parallel insertion of multiple states into the hash table. 333 | 334 | Args: 335 | table: Hash table instance 336 | inputs: States to insert 337 | filled: Boolean array indicating which inputs are valid 338 | 339 | Returns: 340 | Tuple of (updated_table, updatable, unique_filled, idx, table_idx) 341 | """ 342 | 343 | # Get initial indices and byte representations 344 | initial_idx, bytes = jax.vmap( 345 | partial(HashTable.get_new_idx_byterized), in_axes=(None, 0, None) 346 | )(table, inputs, table.seed) 347 | 348 | batch_len = filled.shape[0] 349 | 350 | # Find unique states to avoid duplicates 351 | unique_bytes_idx = jnp.unique(bytes, axis=0, size=batch_len, return_index=True)[1] 352 | unique = jnp.zeros((batch_len,), dtype=jnp.bool_).at[unique_bytes_idx].set(True) 353 | unique_filled = jnp.logical_and(filled, unique) 354 | 355 | # Look up each state 356 | seeds, idx, table_idx, found = jax.vmap( 357 | partial(HashTable._lookup), in_axes=(None, 0, 0, None, None, 0) 358 | )(table, inputs, initial_idx, HASH_TABLE_IDX_DTYPE(0), table.seed, ~unique_filled) 359 | 360 | idxs = jnp.stack([idx, table_idx], axis=1, dtype=SIZE_DTYPE) 361 | updatable = jnp.logical_and(~found, unique_filled) 362 | 363 | # Perform parallel insertion 364 | table, idx, table_idx = HashTable._parallel_insert( 365 | table, inputs, seeds, idxs, updatable, batch_len 366 | ) 367 | 368 | # Get final indices 369 | _, idx, table_idx, _ = jax.vmap( 370 | partial(HashTable._lookup), in_axes=(None, 0, 0, None, None, 0) 371 | )(table, inputs, initial_idx, HASH_TABLE_IDX_DTYPE(0), table.seed, ~filled) 372 | 373 | return table, updatable, unique_filled, idx, table_idx 374 | -------------------------------------------------------------------------------- /xtructure/queue/__init__.py: -------------------------------------------------------------------------------- 1 | from .queue import Queue 2 | 3 | __all__ = ["Queue"] 4 | -------------------------------------------------------------------------------- /xtructure/queue/queue.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import chex 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from ..core import Xtructurable 8 | 9 | SIZE_DTYPE = jnp.uint32 10 | 11 | 12 | @chex.dataclass 13 | class Queue: 14 | """ 15 | A JAX-compatible batched Queue data structure. 16 | Optimized for parallel operations on GPU using JAX. 
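
    A minimal round trip (illustrative sketch; `Value` is a hypothetical
    xtructure_dataclass):
        q = Queue.build(max_size=1024, value_class=Value)
        q = q.enqueue(Value.default((8,)))
        q, items = q.dequeue(4)  # `items` holds the four oldest entries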
--------------------------------------------------------------------------------
/xtructure/stack/__init__.py:
--------------------------------------------------------------------------------
1 | from .stack import Stack
2 |
3 | __all__ = ["Stack"]
4 |
--------------------------------------------------------------------------------
/xtructure/stack/stack.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import chex
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | from ..core import Xtructurable
8 |
9 | SIZE_DTYPE = jnp.uint32
10 |
11 |
12 | @chex.dataclass
13 | class Stack:
14 |     """
15 |     A JAX-compatible batched Stack data structure.
16 |     Optimized for parallel operations on GPU using JAX.
17 |
18 |     Attributes:
19 |         max_size: Maximum number of elements the stack can hold.
20 |         size: Current number of elements in the stack.
21 |         val_store: Array storing the values in the stack.
22 |     """
23 |
24 |     max_size: int
25 |     size: SIZE_DTYPE
26 |     val_store: Xtructurable
27 |
28 |     @staticmethod
29 |     @partial(jax.jit, static_argnums=(0, 1))
30 |     def build(max_size: int, value_class: Xtructurable):
31 |         """
32 |         Creates a new Stack instance.
33 |
34 |         Args:
35 |             max_size: The maximum number of elements the stack can hold.
36 |             value_class: The class of values to be stored in the stack.
37 |                 It must be a subclass of Xtructurable.
38 |
39 |         Returns:
40 |             A new, empty Stack instance.
41 |         """
42 |         size = SIZE_DTYPE(0)
43 |         val_store = value_class.default((max_size,))
44 |         return Stack(max_size=max_size, size=size, val_store=val_store)
45 |
46 |     @jax.jit
47 |     def push(self, items: Xtructurable):
48 |         """
49 |         Pushes a single item or a batch of items onto the stack.
50 |
51 |         Args:
52 |             items: An Xtructurable containing the items to push. The first
53 |                 dimension is the batch dimension.
54 |
55 |         Returns:
56 |             A new Stack instance with the items pushed onto it.
57 |         """
58 |         batch_size = items.shape.batch
59 |         if batch_size == ():
60 |             new_size = self.size + 1
61 |             indices = self.size
62 |         else:
63 |             assert len(batch_size) == 1, "items must have a single batch dimension"
64 |             new_size = self.size + batch_size[0]
65 |             indices = self.size + jnp.arange(batch_size[0])
66 |         self.val_store = self.val_store.at[indices].set(items)
67 |         self.size = new_size  # no overflow check; caller respects max_size
68 |         return self
69 |
70 |     @partial(jax.jit, static_argnums=(1,))
71 |     def pop(self, num_items: int = 1):
72 |         """
73 |         Pops a number of items from the top of the stack.
74 |
75 |         Args:
76 |             num_items: The number of items to pop.
77 |
78 |         Returns:
79 |             A tuple containing:
80 |             - A new Stack instance with items removed.
81 |             - The popped items, in stored order (deepest first, top last).
82 |         """
83 |         new_size = self.size - num_items
84 |         if num_items == 1:
85 |             indices = self.size - 1
86 |         else:
87 |             indices = self.size - jnp.arange(num_items, 0, -1)
88 |         popped_items = self.val_store[indices]
89 |         self.size = new_size  # no underflow check; caller tracks size
90 |         return self, popped_items
91 |
92 |     @partial(jax.jit, static_argnums=(1,))
93 |     def peek(self, num_items: int = 1):
94 |         """
95 |         Peeks at the top items of the stack without removing them.
96 |
97 |         Args:
98 |             num_items: The number of items to peek at. Defaults to 1.
99 |
100 |         Returns:
101 |             The top `num_items` from the stack, in stored order (the last is the top).
102 |         """
103 |         if num_items == 1:
104 |             indices = self.size - 1
105 |         else:
106 |             indices = self.size - jnp.arange(num_items, 0, -1)
107 |         peeked_items = self.val_store[indices]
108 |         return peeked_items
109 |
--------------------------------------------------------------------------------
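# --- Illustrative usage sketch (editor's addition, not library source). ---
# Reuses the hypothetical `Point` value class from the hashtable sketch above;
# only `Stack.build`, `push`, `pop`, and `peek` from this file are exercised.
import jax.numpy as jnp

from xtructure import Stack

stack = Stack.build(1024, Point)

items = Point(
    x=jnp.array([1, 2, 3], dtype=jnp.uint32),
    y=jnp.array([4, 5, 6], dtype=jnp.uint32),
)
stack = stack.push(items)     # size becomes 3; x = 3 sits on top

top = stack.peek()            # x = 3, not removed
stack, popped = stack.pop(2)  # popped in stored order: x = [2, 3]
assert stack.size == 1        # x = 1 remains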