├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── docs ├── Makefile ├── README.md ├── make.bat ├── run.sh └── source │ ├── api │ └── python │ │ ├── index.rst │ │ ├── tglite.batch.rst │ │ ├── tglite.block.rst │ │ ├── tglite.context.rst │ │ ├── tglite.graph.rst │ │ ├── tglite.mailbox.rst │ │ ├── tglite.memory.rst │ │ ├── tglite.nn.rst │ │ ├── tglite.op.rst │ │ ├── tglite.rst │ │ └── tglite.sampler.rst │ ├── conf.py │ ├── generated │ ├── tglite.Mailbox.rst │ ├── tglite.Memory.rst │ ├── tglite.TBatch.block.rst │ ├── tglite.TBatch.block_adj.rst │ ├── tglite.TBatch.edges.rst │ ├── tglite.TBatch.eids.rst │ ├── tglite.TBatch.g.rst │ ├── tglite.TBatch.neg_nodes.rst │ ├── tglite.TBatch.nodes.rst │ ├── tglite.TBatch.split_data.rst │ ├── tglite.TBatch.times.rst │ ├── tglite.TBlock.allnodes.rst │ ├── tglite.TBlock.apply.rst │ ├── tglite.TBlock.clear_hooks.rst │ ├── tglite.TBlock.clear_nbrs.rst │ ├── tglite.TBlock.dstdata.rst │ ├── tglite.TBlock.dstfeat.rst │ ├── tglite.TBlock.dstindex.rst │ ├── tglite.TBlock.dstnodes.rst │ ├── tglite.TBlock.dsttimes.rst │ ├── tglite.TBlock.edata.rst │ ├── tglite.TBlock.efeat.rst │ ├── tglite.TBlock.eid.rst │ ├── tglite.TBlock.ets.rst │ ├── tglite.TBlock.g.rst │ ├── tglite.TBlock.has_nbrs.rst │ ├── tglite.TBlock.layer.rst │ ├── tglite.TBlock.mail.rst │ ├── tglite.TBlock.mem_data.rst │ ├── tglite.TBlock.next.rst │ ├── tglite.TBlock.next_block.rst │ ├── tglite.TBlock.nfeat.rst │ ├── tglite.TBlock.num_dst.rst │ ├── tglite.TBlock.num_edges.rst │ ├── tglite.TBlock.num_src.rst │ ├── tglite.TBlock.prev.rst │ ├── tglite.TBlock.register_hook.rst │ ├── tglite.TBlock.run_hooks.rst │ ├── tglite.TBlock.set_nbrs.rst │ ├── tglite.TBlock.srcdata.rst │ ├── tglite.TBlock.srcfeat.rst │ ├── tglite.TBlock.srcnodes.rst │ ├── tglite.TBlock.time_deltas.rst │ ├── tglite.TBlock.uniq_src.rst │ ├── tglite.TContext.enable_embed_caching.rst │ ├── tglite.TContext.enable_time_precompute.rst │ ├── tglite.TContext.eval.rst │ ├── 
tglite.TContext.graph.rst │ ├── tglite.TContext.need_sampling.rst │ ├── tglite.TContext.set_cache_limit.rst │ ├── tglite.TContext.set_time_window.rst │ ├── tglite.TContext.train.rst │ ├── tglite.TGraph.compute_device.rst │ ├── tglite.TGraph.edata.rst │ ├── tglite.TGraph.efeat.rst │ ├── tglite.TGraph.mailbox.rst │ ├── tglite.TGraph.mem.rst │ ├── tglite.TGraph.move_data.rst │ ├── tglite.TGraph.ndata.rst │ ├── tglite.TGraph.nfeat.rst │ ├── tglite.TGraph.num_edges.rst │ ├── tglite.TGraph.num_nodes.rst │ ├── tglite.TGraph.set_compute.rst │ ├── tglite.TGraph.storage_device.rst │ ├── tglite.from_csv.rst │ ├── tglite.iter_edges.rst │ ├── tglite.op.aggregate.rst │ ├── tglite.op.cache.rst │ ├── tglite.op.coalesce.rst │ ├── tglite.op.dedup.rst │ ├── tglite.op.edge_reduce.rst │ ├── tglite.op.edge_softmax.rst │ ├── tglite.op.edge_view.rst │ ├── tglite.op.precomputed_times.rst │ ├── tglite.op.precomputed_zeros.rst │ ├── tglite.op.preload.rst │ ├── tglite.op.propagate.rst │ └── tglite.op.src_scatter.rst │ ├── img │ ├── blank.png │ ├── colab.svg │ ├── github.svg │ ├── tblock-structure.png │ ├── tblock-workflow.png │ └── train.png │ ├── index.rst │ ├── install │ └── index.rst │ └── tutorial │ ├── quickstart.ipynb │ └── tblock.rst ├── environment.yml ├── examples ├── apan │ ├── apan.py │ └── train.py ├── download-data.sh ├── exp │ ├── apan-gdelt.sh │ ├── apan.sh │ ├── jodie-gdelt.sh │ ├── jodie.sh │ ├── tgat-gdelt.sh │ ├── tgat.sh │ ├── tgn-gdelt.sh │ └── tgn.sh ├── gen-data-files.py ├── jodie │ ├── jodie.py │ └── train.py ├── requirements.txt ├── support.py ├── tgat │ ├── TGAT.ipynb │ ├── tgat.py │ └── train.py └── tgn │ ├── tgn.py │ └── train.py ├── include └── tglite │ ├── cache.h │ ├── core.h │ ├── dedup.h │ ├── sampler.h │ ├── tcsr.h │ └── utils.h ├── lib ├── bind.cpp ├── cache.cpp ├── dedup.cpp ├── sampler.cpp ├── tcsr.cpp └── utils.cpp ├── pyproject.toml ├── python └── tglite │ ├── __init__.py │ ├── _batch.py │ ├── _block.py │ ├── _context.py │ ├── _core.py │ ├── _frame.py │ 
├── _graph.py │ ├── _mailbox.py │ ├── _memory.py │ ├── _sampler.py │ ├── _stats.py │ ├── _utils.py │ ├── nn.py │ └── op.py ├── scripts ├── run-exp-large.sh ├── run-exp-slurm.sh ├── run-exp.sh ├── setup-aws.sh └── setup-repo.sh ├── setup.py └── tests └── python ├── data └── edges.csv ├── test_block.py ├── test_frame.py ├── test_graph.py ├── test_tglite.py └── test_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | *.DS_Store 3 | 4 | __pycache__/ 5 | .pytest_cache/ 6 | .coverage 7 | 8 | dist/ 9 | build/ 10 | *.egg-info/ 11 | *.so 12 | 13 | docs/_build/ 14 | 15 | examples/data 16 | examples/models/ 17 | examples/out-*.csv 18 | examples/out-*.txt 19 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for Sphinx projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version and other tools you might need 8 | build: 9 | os: ubuntu-20.04 10 | tools: 11 | # python: "mambaforge-22.9" 12 | python: "3.7" 13 | 14 | commands: 15 | # - pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 16 | # - pip install torch-scatter==2.1.0+pt112cu116 -f https://data.pyg.org/whl/torch-1.12.1+cu116.html 17 | # - pip install .[docs] 18 | # - cd docs/ && make html 19 | # - mkdir _readthedocs 20 | # - cp -r docs/build/html _readthedocs/ 21 | # - python setup.py install 22 | # Install dependencies 23 | # - cd docs/ && pip install -r requirements.txt 24 | # Build the site 25 | # - cd docs/ && make html 26 | # Copy generated files into Read the Docs directory 27 | # - cd docs/ && ls 28 | # - cd docs/build && ls 29 | # - mkdir _readthedocs 30 | # - cp --recursive docs/build/html _readthedocs/ 31 | - ls _readthedocs 32 | 33 | # 
conda: 34 | # environment: environment.yml 35 | 36 | # Build documentation in the "docs/" directory with Sphinx 37 | # sphinx: 38 | # configuration: null 39 | # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs 40 | # builder: "dirhtml" 41 | # Fail on all warnings to avoid broken references 42 | # fail_on_warning: true 43 | 44 | # Optionally build your docs in additional formats such as PDF and ePub 45 | # formats: 46 | # - pdf 47 | # - epub 48 | 49 | # Optional but recommended, declare the Python requirements required 50 | # to build your documentation 51 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 52 | # python: 53 | # install: 54 | # - method: pip 55 | # path: . 56 | # extra_requirements: 57 | # - docs 58 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 TGLite Authors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Add header files to sdist 2 | graft include/tglite 3 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build clean develop install test uninstall 2 | 3 | build: 4 | python setup.py build 5 | 6 | develop: 7 | python setup.py develop 8 | 9 | install: 10 | python setup.py install 11 | 12 | uninstall: 13 | pip uninstall --yes tglite 14 | 15 | test: 16 | pytest --cov=tglite 17 | 18 | clean: 19 | rm -rf **/*/__pycache__ .pytest_cache 20 | rm -rf build dist **/*.egg-info 21 | rm -rf .coverage **/*/*.so 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TGLite - A Framework for 
Temporal GNNs 2 | 3 | TGLite is a lightweight framework that provides core abstractions and building blocks for practitioners and researchers to implement efficient TGNN models. TGNNs, or Temporal Graph Neural Networks, learn node embeddings for graphs that dynamically change over time by jointly aggregating structural and temporal information from neighboring nodes. TGLite employs an abstraction called a _TBlock_ to represent the temporal graph dependencies when aggregating from neighbors, with explicit support for capturing temporal details like edge timestamps, as well as composable operators and optimizations. Compared to prior art, TGLite can outperform the [TGL][tgl] framework by [up to 3x](#publication) in terms of training time. 4 | 5 |
6 | 7 | End-to-end training epoch time comparison on an Nvidia A100 GPU. 8 |
9 | 10 | [tgl]: https://github.com/amazon-science/tgl 11 | 12 | ## Installation 13 | 14 | See our [documentation][docs] for instructions on how to install the TGLite package, as well as examples and references for supported functionality. To install from source or for local development, go to the [Building from source][build-src] section, it also explains how to run [examples][exp]. 15 | 16 | [docs]: https://tglite.readthedocs.io 17 | [build-src]: https://tglite.readthedocs.io/en/latest/install/index.html#building-from-source 18 | [exp]: examples 19 | 20 | ## Getting Started 21 | 22 | TGLite is currently designed to be used with PyTorch as a training backend, typically with GPU devices. A TGNN model can be defined and trained in the usual way using PyTorch, with the computations constructed using a mix of PyTorch functions and operators/optimizations from TGLite. Below is a simple example (not a real network architecture, just for demonstration purposes): 23 | 24 | ```python 25 | import torch 26 | import tglite as tg 27 | 28 | class TGNN(torch.nn.Module): 29 | def __init__(self, ctx: tg.TContext, dim_node=100, dim_time=100): 30 | super().__init__() 31 | self.ctx = ctx 32 | self.linear = torch.nn.Linear(dim_node + dim_time, dim_node) 33 | self.sampler = tg.TSampler(num_nbrs=10, strategy='recent') 34 | self.encoder = tg.nn.TimeEncode(dim_time) 35 | 36 | def forward(self, batch: tg.TBatch): 37 | blk = batch.block(self.ctx) 38 | blk = tg.op.dedup(blk) 39 | blk = self.sampler.sample(blk) 40 | blk.srcdata['h'] = blk.srcfeat() 41 | return tg.op.aggregate(blk, self.compute, key='h') 42 | 43 | def compute(self, blk: tg.TBlock): 44 | feats = self.encoder(blk.time_deltas()) 45 | feats = torch.cat([blk.srcdata['h'], feats], dim=1) 46 | embeds = self.linear(feats) 47 | embeds = tg.op.edge_reduce(blk, embeds, op='sum') 48 | return torch.relu(embeds) 49 | 50 | graph = tg.from_csv(...) 
51 | ctx = tg.TContext(graph) 52 | model = TGNN(ctx) 53 | train(model) 54 | ``` 55 | 56 | The example model is defined to first construct the graph dependencies for nodes in the current batch of edges. The `dedup()` optimization is applied before sampling for 10 recent neighbors. Node embeddings are computed by simply combining node and time features, applying a linear layer and summing across neighbors. More complex computations and aggregations, such as temporal self-attention often used with TGNNs, can be defined using the provided building blocks. 57 | 58 | ## Publication 59 | 60 | * Yufeng Wang and Charith Mendis. 2024. [TGLite: A Lightweight Programming Framework for Continuous-Time Temporal Graph Neural Networks][tglite-paper]. In 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2 (ASPLOS '24), April 2024, La Jolla, CA, USA. (To Appear) 61 | 62 | * Yufeng Wang and Charith Mendis. 2023. [TGOpt: Redundancy-Aware Optimizations for Temporal Graph Attention Networks][tgopt-paper]. In Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming (PPoPP '23), February 2023, Montreal, QC, Canada. 
63 | 64 | If you find TGLite useful, please consider attributing to the following citation: 65 | 66 | ```bibtex 67 | @inproceedings{wang2024tglite, 68 | author = {Wang, Yufeng and Mendis, Charith}, 69 | title = {TGLite: A Lightweight Programming Framework for Continuous-Time Temporal Graph Neural Networks}, 70 | year = {2024}, 71 | booktitle = {Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2}, 72 | doi = {10.1145/3620665.3640414} 73 | } 74 | ``` 75 | 76 | [tglite-paper]: https://charithmendis.com/assets/pdf/asplos24-tglite.pdf 77 | [tgopt-paper]: https://charithmendis.com/assets/pdf/ppopp23-tgopt.pdf 78 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | ## Build TGLite Documentation 2 | See [Build the documentation locally](https://tglite.readthedocs.io/en/latest/install/index.html#building-the-document-locally) in the doc. 
3 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Script to open TGLite documentation using detected browser. 4 | # 5 | 6 | # Get the directory where the script is located. 7 | SCRIPT_DIR="$(dirname "$0")" 8 | 9 | # Path to the HTML file you want to open. 10 | HTML_PATH="$SCRIPT_DIR/build/html/index.html" 11 | 12 | # Function to open the HTML file using a browser. 13 | open_html() { 14 | local path=$1 15 | if command -v xdg-open &> /dev/null; then 16 | # Preferred way to open files on Desktop Linux 17 | xdg-open "$path" 18 | elif command -v gnome-open &> /dev/null; then 19 | # For systems with Gnome. 
20 | gnome-open "$path" 21 | elif command -v x-www-browser &> /dev/null; then 22 | # A generic way to open files, might work when xdg-open and gnome-open are unavailable. 23 | x-www-browser "$path" 24 | else 25 | echo "Could not detect the web browser to open the documentation." 26 | return 1 27 | fi 28 | } 29 | 30 | # Detect the operating system and open the HTML file. 31 | case "$(uname -s)" in 32 | Linux*) 33 | open_html "$HTML_PATH" 34 | ;; 35 | Darwin*) 36 | open "$HTML_PATH" 37 | ;; 38 | CYGWIN*|MINGW32*|MSYS*|MINGW*) 39 | start "$HTML_PATH" 40 | ;; 41 | *) 42 | echo "Unknown operating system. Cannot open the documentation automatically." 43 | ;; 44 | esac 45 | -------------------------------------------------------------------------------- /docs/source/api/python/index.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | tglite 8 | tglite.batch 9 | tglite.block 10 | tglite.context 11 | tglite.graph 12 | tglite.mailbox 13 | tglite.memory 14 | tglite.sampler 15 | tglite.nn 16 | tglite.op -------------------------------------------------------------------------------- /docs/source/api/python/tglite.batch.rst: -------------------------------------------------------------------------------- 1 | .. _api-batch: 2 | 3 | tglite.TBatch 4 | ============= 5 | 6 | .. currentmodule:: tglite 7 | .. autoclass:: TBatch 8 | 9 | .. automethod:: __init__ 10 | 11 | .. currentmodule:: tglite.TBatch 12 | 13 | Get graph data 14 | -------------- 15 | .. autosummary:: 16 | :toctree: ../../generated/ 17 | 18 | g 19 | neg_nodes 20 | eids 21 | edges 22 | nodes 23 | times 24 | 25 | 26 | Get TBlock 27 | ----------- 28 | .. autosummary:: 29 | :toctree: ../../generated/ 30 | 31 | block 32 | block_adj 33 | 34 | Split data 35 | ---------- 36 | .. 
autosummary:: 37 | :toctree: ../../generated/ 38 | 39 | split_data 40 | 41 | -------------------------------------------------------------------------------- /docs/source/api/python/tglite.block.rst: -------------------------------------------------------------------------------- 1 | .. _api-block: 2 | 3 | tglite.TBlock 4 | ============= 5 | 6 | TBlock captures 1-hop relationships between node/time pairs and their neighbors for doing computations. :ref:`Figure 7 | ` shows the internal structure of a TBlock and the doubly-linked list design with next and prev pointing 8 | to sampled neighbors' TBlock. To have a general idea of how TBlock works, please refer to the :ref:`tutorial `. 9 | 10 | .. _tblock figure: 11 | .. figure:: ../../img/tblock-structure.png 12 | :alt: tblock-structure 13 | :align: center 14 | :figwidth: 60 % 15 | 16 | Diagram of the doubly-linked list design and internal structure of a TBlock (destination node-time is denoted as 17 | ). 18 | 19 | 20 | .. currentmodule:: tglite 21 | .. autoclass:: TBlock 22 | 23 | .. automethod:: __init__ 24 | 25 | 26 | .. currentmodule:: tglite.TBlock 27 | 28 | 29 | Query TBlock attributes 30 | ----------------------- 31 | .. autosummary:: 32 | :toctree: ../../generated/ 33 | 34 | g 35 | layer 36 | dstnodes 37 | dsttimes 38 | num_dst 39 | 40 | Query neighbor attributes 41 | ------------------------- 42 | .. autosummary:: 43 | :toctree: ../../generated/ 44 | 45 | dstindex 46 | srcnodes 47 | num_src 48 | eid 49 | ets 50 | num_edges 51 | has_nbrs 52 | 53 | Query data 54 | ---------- 55 | .. autosummary:: 56 | :toctree: ../../generated/ 57 | 58 | dstdata 59 | srcdata 60 | edata 61 | 62 | Query Cache 63 | ----------- 64 | .. autosummary:: 65 | :toctree: ../../generated/ 66 | 67 | allnodes 68 | uniq_src 69 | efeat 70 | nfeat 71 | srcfeat 72 | dstfeat 73 | mem_data 74 | mail 75 | time_deltas 76 | 77 | Neighbor ops 78 | ------------ 79 | .. 
autosummary:: 80 | :toctree: ../../generated/ 81 | 82 | set_nbrs 83 | clear_nbrs 84 | 85 | Linked list ops 86 | --------------- 87 | .. autosummary:: 88 | :toctree: ../../generated/ 89 | 90 | prev 91 | next 92 | next_block 93 | 94 | Execution ops 95 | ------------- 96 | .. autosummary:: 97 | :toctree: ../../generated/ 98 | 99 | apply 100 | register_hook 101 | run_hooks 102 | clear_hooks 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /docs/source/api/python/tglite.context.rst: -------------------------------------------------------------------------------- 1 | .. _api-context: 2 | 3 | tglite.TContext 4 | =============== 5 | .. currentmodule:: tglite 6 | .. autoclass:: TContext 7 | 8 | .. automethod:: __init__ 9 | 10 | 11 | .. currentmodule:: tglite.TContext 12 | 13 | Basic settings 14 | -------------- 15 | .. autosummary:: 16 | :toctree: ../../generated/ 17 | 18 | train 19 | eval 20 | need_sampling 21 | 22 | Set node embedding cache 23 | ------------------------ 24 | .. autosummary:: 25 | :toctree: ../../generated/ 26 | 27 | enable_embed_caching 28 | set_cache_limit 29 | 30 | Set time precomputation 31 | ----------------------- 32 | .. autosummary:: 33 | :toctree: ../../generated/ 34 | 35 | enable_time_precompute 36 | set_time_window 37 | 38 | Query graph 39 | ----------- 40 | .. autosummary:: 41 | :toctree: ../../generated/ 42 | 43 | graph 44 | -------------------------------------------------------------------------------- /docs/source/api/python/tglite.graph.rst: -------------------------------------------------------------------------------- 1 | .. _api-graph: 2 | 3 | tglite.TGraph 4 | ============= 5 | .. currentmodule:: tglite 6 | .. autoclass:: TGraph 7 | 8 | .. automethod:: __init__ 9 | 10 | 11 | .. currentmodule:: tglite.TGraph 12 | 13 | Query graph structure 14 | --------------------- 15 | 16 | .. 
autosummary:: 17 | :toctree: ../../generated/ 18 | 19 | num_nodes 20 | num_edges 21 | 22 | Query graph data 23 | ---------------- 24 | 25 | .. autosummary:: 26 | :toctree: ../../generated/ 27 | 28 | efeat 29 | nfeat 30 | edata 31 | ndata 32 | 33 | Query memory-based TGN data 34 | --------------------------- 35 | 36 | .. autosummary:: 37 | :toctree: ../../generated/ 38 | 39 | mem 40 | mailbox 41 | 42 | Set device 43 | ---------- 44 | 45 | .. autosummary:: 46 | :toctree: ../../generated/ 47 | 48 | storage_device 49 | compute_device 50 | set_compute 51 | move_data 52 | 53 | -------------------------------------------------------------------------------- /docs/source/api/python/tglite.mailbox.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: tglite 2 | 3 | .. autosummary:: 4 | :toctree: ../../generated/ 5 | :nosignatures: 6 | 7 | Mailbox 8 | -------------------------------------------------------------------------------- /docs/source/api/python/tglite.memory.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: tglite 2 | 3 | .. autosummary:: 4 | :toctree: ../../generated/ 5 | :nosignatures: 6 | 7 | Memory 8 | -------------------------------------------------------------------------------- /docs/source/api/python/tglite.nn.rst: -------------------------------------------------------------------------------- 1 | .. _api-nn: 2 | 3 | tglite.nn 4 | ==================== 5 | .. currentmodule:: tglite.nn 6 | .. autoclass:: TimeEncode 7 | 8 | .. automethod:: __init__ 9 | .. automethod:: zeros 10 | .. automethod:: forward 11 | 12 | 13 | .. autoclass:: TemporalAttnLayer 14 | 15 | .. automethod:: __init__ 16 | .. automethod:: forward -------------------------------------------------------------------------------- /docs/source/api/python/tglite.op.rst: -------------------------------------------------------------------------------- 1 | .. 
_api-op: 2 | 3 | tglite.op 4 | ========= 5 | .. currentmodule:: tglite.op 6 | 7 | .. autosummary:: 8 | :toctree: ../../generated/ 9 | 10 | edge_view 11 | edge_softmax 12 | edge_reduce 13 | src_scatter 14 | coalesce 15 | preload 16 | aggregate 17 | propagate 18 | dedup 19 | cache 20 | precomputed_zeros 21 | precomputed_times 22 | -------------------------------------------------------------------------------- /docs/source/api/python/tglite.rst: -------------------------------------------------------------------------------- 1 | .. _api-tglite: 2 | 3 | tglite 4 | ====== 5 | .. currentmodule:: tglite 6 | .. automodule:: tglite 7 | 8 | .. _api-edge-iteration-ops: 9 | 10 | Edge Iterate Ops 11 | ---------------- 12 | .. autosummary:: 13 | :toctree: ../../generated/ 14 | 15 | iter_edges 16 | 17 | .. _api-graph-creation-ops: 18 | 19 | Create Graph Ops 20 | ---------------- 21 | .. autosummary:: 22 | :toctree: ../../generated/ 23 | 24 | from_csv 25 | -------------------------------------------------------------------------------- /docs/source/api/python/tglite.sampler.rst: -------------------------------------------------------------------------------- 1 | .. _api-sampler: 2 | 3 | tglite.TSampler 4 | =============== 5 | .. currentmodule:: tglite 6 | .. autoclass:: TSampler 7 | 8 | .. automethod:: __init__ 9 | .. automethod:: sample -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | 7 | # If extensions (or modules to document with autodoc) are in another directory, 8 | # add these directories to sys.path here. 
9 | 10 | import pathlib 11 | import sys 12 | import os 13 | import site 14 | sys.path.insert(0, pathlib.Path(__file__).parents[2].resolve().as_posix()) 15 | sys.path.insert(0, os.path.abspath("../../python")) 16 | sys.path.insert(0, site.getsitepackages()[0]) 17 | 18 | # -- Project information ----------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 20 | 21 | project = 'TGLite' 22 | copyright = '2024, ADAPT Group' 23 | author = 'ADAPT Group' 24 | release = '0.1.0' 25 | 26 | 27 | # -- General configuration --------------------------------------------------- 28 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 29 | 30 | extensions = [ 31 | 'sphinx.ext.autodoc', 32 | 'sphinx.ext.autosummary', 33 | 'sphinx.ext.doctest', 34 | 'sphinx.ext.duration', 35 | "sphinx.ext.mathjax", 36 | 'sphinx_rtd_theme', 37 | 'nbsphinx' 38 | ] 39 | 40 | templates_path = ['_templates'] 41 | exclude_patterns = [] 42 | 43 | 44 | 45 | # -- Options for HTML output ------------------------------------------------- 46 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 47 | 48 | html_theme = 'sphinx_rtd_theme' 49 | html_static_path = ['_static'] 50 | html_context = { 51 | "display_github": True, # Integrate GitHub 52 | "github_user": "ADAPT-uiuc", # Username 53 | "github_repo": "tglite", # Repo name 54 | "github_version": "main", # Version 55 | "conf_py_path": "/docs/source/", # Path in the checkout to the docs root 56 | } 57 | -------------------------------------------------------------------------------- /docs/source/generated/tglite.Mailbox.rst: -------------------------------------------------------------------------------- 1 | tglite.Mailbox 2 | ============== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoclass:: Mailbox 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. 
autosummary:: 15 | 16 | ~Mailbox.__init__ 17 | ~Mailbox.dims 18 | ~Mailbox.move_to 19 | ~Mailbox.reset 20 | ~Mailbox.store 21 | 22 | 23 | 24 | 25 | 26 | .. rubric:: Attributes 27 | 28 | .. autosummary:: 29 | 30 | ~Mailbox.device 31 | ~Mailbox.mail 32 | ~Mailbox.time 33 | 34 | -------------------------------------------------------------------------------- /docs/source/generated/tglite.Memory.rst: -------------------------------------------------------------------------------- 1 | tglite.Memory 2 | ============= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoclass:: Memory 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~Memory.__init__ 17 | ~Memory.backup 18 | ~Memory.dim 19 | ~Memory.move_to 20 | ~Memory.reset 21 | ~Memory.restore 22 | ~Memory.update 23 | 24 | 25 | 26 | 27 | 28 | .. rubric:: Attributes 29 | 30 | .. autosummary:: 31 | 32 | ~Memory.data 33 | ~Memory.device 34 | ~Memory.time 35 | 36 | -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBatch.block.rst: -------------------------------------------------------------------------------- 1 | tglite.TBatch.block 2 | =================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBatch.block -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBatch.block_adj.rst: -------------------------------------------------------------------------------- 1 | tglite.TBatch.block\_adj 2 | ======================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBatch.block_adj -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBatch.edges.rst: -------------------------------------------------------------------------------- 1 | tglite.TBatch.edges 2 | =================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. 
automethod:: TBatch.edges -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBatch.eids.rst: -------------------------------------------------------------------------------- 1 | tglite.TBatch.eids 2 | ================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBatch.eids -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBatch.g.rst: -------------------------------------------------------------------------------- 1 | tglite.TBatch.g 2 | =============== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBatch.g -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBatch.neg_nodes.rst: -------------------------------------------------------------------------------- 1 | tglite.TBatch.neg\_nodes 2 | ======================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBatch.neg_nodes -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBatch.nodes.rst: -------------------------------------------------------------------------------- 1 | tglite.TBatch.nodes 2 | =================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBatch.nodes -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBatch.split_data.rst: -------------------------------------------------------------------------------- 1 | tglite.TBatch.split\_data 2 | ========================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBatch.split_data -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBatch.times.rst: -------------------------------------------------------------------------------- 1 | tglite.TBatch.times 2 | =================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. 
automethod:: TBatch.times -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.allnodes.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.allnodes 2 | ====================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.allnodes -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.apply.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.apply 2 | =================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.apply -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.clear_hooks.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.clear\_hooks 2 | ========================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.clear_hooks -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.clear_nbrs.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.clear\_nbrs 2 | ========================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.clear_nbrs -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.dstdata.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.dstdata 2 | ===================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. 
autoproperty:: TBlock.dstdata -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.dstfeat.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.dstfeat 2 | ===================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.dstfeat -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.dstindex.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.dstindex 2 | ====================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.dstindex -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.dstnodes.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.dstnodes 2 | ====================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.dstnodes -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.dsttimes.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.dsttimes 2 | ====================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.dsttimes -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.edata.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.edata 2 | =================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.edata -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.efeat.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.efeat 2 | =================== 3 | 4 | .. 
currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.efeat -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.eid.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.eid 2 | ================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.eid -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.ets.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.ets 2 | ================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.ets -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.g.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.g 2 | =============== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.g -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.has_nbrs.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.has\_nbrs 2 | ======================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.has_nbrs -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.layer.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.layer 2 | =================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.layer -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.mail.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.mail 2 | ================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. 
automethod:: TBlock.mail -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.mem_data.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.mem\_data 2 | ======================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.mem_data -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.next.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.next 2 | ================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.next -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.next_block.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.next\_block 2 | ========================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.next_block -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.nfeat.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.nfeat 2 | =================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.nfeat -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.num_dst.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.num\_dst 2 | ====================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.num_dst -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.num_edges.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.num\_edges 2 | ======================== 3 | 4 | .. 
currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.num_edges -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.num_src.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.num\_src 2 | ====================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.num_src -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.prev.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.prev 2 | ================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.prev -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.register_hook.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.register\_hook 2 | ============================ 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.register_hook -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.run_hooks.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.run\_hooks 2 | ======================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.run_hooks -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.set_nbrs.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.set\_nbrs 2 | ======================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. 
automethod:: TBlock.set_nbrs -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.srcdata.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.srcdata 2 | ===================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.srcdata -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.srcfeat.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.srcfeat 2 | ===================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.srcfeat -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.srcnodes.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.srcnodes 2 | ====================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TBlock.srcnodes -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.time_deltas.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.time\_deltas 2 | ========================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TBlock.time_deltas -------------------------------------------------------------------------------- /docs/source/generated/tglite.TBlock.uniq_src.rst: -------------------------------------------------------------------------------- 1 | tglite.TBlock.uniq\_src 2 | ======================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. 
automethod:: TBlock.uniq_src -------------------------------------------------------------------------------- /docs/source/generated/tglite.TContext.enable_embed_caching.rst: -------------------------------------------------------------------------------- 1 | tglite.TContext.enable\_embed\_caching 2 | ====================================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TContext.enable_embed_caching -------------------------------------------------------------------------------- /docs/source/generated/tglite.TContext.enable_time_precompute.rst: -------------------------------------------------------------------------------- 1 | tglite.TContext.enable\_time\_precompute 2 | ======================================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TContext.enable_time_precompute -------------------------------------------------------------------------------- /docs/source/generated/tglite.TContext.eval.rst: -------------------------------------------------------------------------------- 1 | tglite.TContext.eval 2 | ==================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TContext.eval -------------------------------------------------------------------------------- /docs/source/generated/tglite.TContext.graph.rst: -------------------------------------------------------------------------------- 1 | tglite.TContext.graph 2 | ===================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TContext.graph -------------------------------------------------------------------------------- /docs/source/generated/tglite.TContext.need_sampling.rst: -------------------------------------------------------------------------------- 1 | tglite.TContext.need\_sampling 2 | ============================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. 
automethod:: TContext.need_sampling -------------------------------------------------------------------------------- /docs/source/generated/tglite.TContext.set_cache_limit.rst: -------------------------------------------------------------------------------- 1 | tglite.TContext.set\_cache\_limit 2 | ================================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TContext.set_cache_limit -------------------------------------------------------------------------------- /docs/source/generated/tglite.TContext.set_time_window.rst: -------------------------------------------------------------------------------- 1 | tglite.TContext.set\_time\_window 2 | ================================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TContext.set_time_window -------------------------------------------------------------------------------- /docs/source/generated/tglite.TContext.train.rst: -------------------------------------------------------------------------------- 1 | tglite.TContext.train 2 | ===================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TContext.train -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.compute_device.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.compute\_device 2 | ============================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TGraph.compute_device -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.edata.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.edata 2 | =================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. 
autoproperty:: TGraph.edata -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.efeat.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.efeat 2 | =================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TGraph.efeat -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.mailbox.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.mailbox 2 | ===================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TGraph.mailbox -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.mem.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.mem 2 | ================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TGraph.mem -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.move_data.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.move\_data 2 | ======================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TGraph.move_data -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.ndata.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.ndata 2 | =================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autoproperty:: TGraph.ndata -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.nfeat.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.nfeat 2 | =================== 3 | 4 | .. 
currentmodule:: tglite 5 | 6 | .. autoproperty:: TGraph.nfeat -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.num_edges.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.num\_edges 2 | ======================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TGraph.num_edges -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.num_nodes.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.num\_nodes 2 | ======================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TGraph.num_nodes -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.set_compute.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.set\_compute 2 | ========================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TGraph.set_compute -------------------------------------------------------------------------------- /docs/source/generated/tglite.TGraph.storage_device.rst: -------------------------------------------------------------------------------- 1 | tglite.TGraph.storage\_device 2 | ============================= 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. automethod:: TGraph.storage_device -------------------------------------------------------------------------------- /docs/source/generated/tglite.from_csv.rst: -------------------------------------------------------------------------------- 1 | tglite.from\_csv 2 | ================ 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. 
autofunction:: from_csv -------------------------------------------------------------------------------- /docs/source/generated/tglite.iter_edges.rst: -------------------------------------------------------------------------------- 1 | tglite.iter\_edges 2 | ================== 3 | 4 | .. currentmodule:: tglite 5 | 6 | .. autofunction:: iter_edges -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.aggregate.rst: -------------------------------------------------------------------------------- 1 | tglite.op.aggregate 2 | =================== 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. autofunction:: aggregate -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.cache.rst: -------------------------------------------------------------------------------- 1 | tglite.op.cache 2 | =============== 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. autofunction:: cache -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.coalesce.rst: -------------------------------------------------------------------------------- 1 | tglite.op.coalesce 2 | ================== 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. autofunction:: coalesce -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.dedup.rst: -------------------------------------------------------------------------------- 1 | tglite.op.dedup 2 | =============== 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. autofunction:: dedup -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.edge_reduce.rst: -------------------------------------------------------------------------------- 1 | tglite.op.edge\_reduce 2 | ====================== 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. 
autofunction:: edge_reduce -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.edge_softmax.rst: -------------------------------------------------------------------------------- 1 | tglite.op.edge\_softmax 2 | ======================= 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. autofunction:: edge_softmax -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.edge_view.rst: -------------------------------------------------------------------------------- 1 | tglite.op.edge\_view 2 | ==================== 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. autofunction:: edge_view -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.precomputed_times.rst: -------------------------------------------------------------------------------- 1 | tglite.op.precomputed\_times 2 | ============================ 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. autofunction:: precomputed_times -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.precomputed_zeros.rst: -------------------------------------------------------------------------------- 1 | tglite.op.precomputed\_zeros 2 | ============================ 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. autofunction:: precomputed_zeros -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.preload.rst: -------------------------------------------------------------------------------- 1 | tglite.op.preload 2 | ================= 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. 
autofunction:: preload -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.propagate.rst: -------------------------------------------------------------------------------- 1 | tglite.op.propagate 2 | =================== 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. autofunction:: propagate -------------------------------------------------------------------------------- /docs/source/generated/tglite.op.src_scatter.rst: -------------------------------------------------------------------------------- 1 | tglite.op.src\_scatter 2 | ====================== 3 | 4 | .. currentmodule:: tglite.op 5 | 6 | .. autofunction:: src_scatter -------------------------------------------------------------------------------- /docs/source/img/blank.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ADAPT-uiuc/tglite/760fa9a3f96663fff50a4e1b92bd7413a90fa4fd/docs/source/img/blank.png -------------------------------------------------------------------------------- /docs/source/img/colab.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 9 | 10 | 13 | 15 | 17 | 20 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /docs/source/img/github.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 8 | 15 | 16 | -------------------------------------------------------------------------------- /docs/source/img/tblock-structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ADAPT-uiuc/tglite/760fa9a3f96663fff50a4e1b92bd7413a90fa4fd/docs/source/img/tblock-structure.png -------------------------------------------------------------------------------- /docs/source/img/tblock-workflow.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ADAPT-uiuc/tglite/760fa9a3f96663fff50a4e1b92bd7413a90fa4fd/docs/source/img/tblock-workflow.png -------------------------------------------------------------------------------- /docs/source/img/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ADAPT-uiuc/tglite/760fa9a3f96663fff50a4e1b92bd7413a90fa4fd/docs/source/img/train.png -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. TGLite documentation master file, created by 2 | sphinx-quickstart on Wed Nov 1 15:04:59 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to TGLite's documentation! 7 | ================================== 8 | 9 | **TGLite** is a lightweight framework that provides core abstractions and building blocks for practitioners and researchers to implement efficient TGNN models. *TGNNs*, or *Temporal Graph Neural Networks*, learn node embeddings for graphs that dynamically change over time by jointly aggregating structural and temporal information from neighboring nodes. 10 | 11 | TGLite employs an abstraction called a :ref:`TBlock ` to represent the temporal graph dependencies when aggregating from neighbors, with explicit support for capturing temporal details like edge timestamps, as well as composable operators and optimizations. Compared to prior art, TGLite can outperform the `TGL `_ framework by up to *3x* in terms of training time. 12 | 13 | .. _train figure: 14 | .. 
figure:: img/train.png 15 | :alt: End-to-end training epoch time comparison on an Nvidia A100 GPU 16 | :align: center 17 | :figwidth: 85 % 18 | 19 | End-to-end training epoch time comparison on an Nvidia A100 GPU 20 | 21 | Install TGLite 22 | -------------- 23 | See :ref:`Getting started ` for instructions on how to install the TGLite binaries. To install from source or for local development, refer to :ref:`Building from source ` and :ref:`Development mode `. 24 | 25 | Tutorials 26 | --------- 27 | We provide a set of tutorials to help you get started with TGLite. These tutorials cover the basics of using TGLite, as well as more advanced topics. 28 | 29 | 0. Quickstart_: A step-by-step guide to train a TGNN model using TGLite. 30 | 1. :ref:`How does TBlock work? `: A tutorial on how to use the TBlock abstraction to implement TGNN models. 31 | 32 | .. _Quickstart: tutorial/quickstart.ipynb 33 | 34 | .. toctree:: 35 | :maxdepth: 1 36 | :caption: TGLite 37 | :hidden: 38 | :glob: 39 | 40 | install/index 41 | tutorial/quickstart 42 | tutorial/tblock 43 | 44 | .. toctree:: 45 | :maxdepth: 2 46 | :caption: API 47 | :glob: 48 | 49 | api/python/tglite 50 | api/python/tglite.batch 51 | api/python/tglite.block 52 | api/python/tglite.context 53 | api/python/tglite.graph 54 | api/python/tglite.sampler 55 | api/python/tglite.mailbox 56 | api/python/tglite.memory 57 | api/python/tglite.nn 58 | api/python/tglite.op 59 | 60 | 61 | .. note:: 62 | This project is under active development. 63 | 64 | 65 | Indices and tables 66 | ================== 67 | 68 | * :ref:`genindex` 69 | * :ref:`modindex` 70 | * :ref:`search` 71 | -------------------------------------------------------------------------------- /docs/source/install/index.rst: -------------------------------------------------------------------------------- 1 | .. _getting-started: 2 | 3 | Getting Started 4 | --------------- 5 | 6 | TGLite currently only support PyTorch as backend. 
7 | 8 | Installation with pip 9 | `````````````````````` 10 | 11 | Prerequisites 12 | ^^^^^^^^^^^^^ 13 | 14 | .. list-table:: 15 | :widths: 25 25 25 25 16 | 17 | * - **python** 18 | - **gcc** 19 | - **torch** 20 | - **torch-scatter** 21 | * - ≥ 3.7 22 | - ≥ 6.1 23 | - ≥ 1.12.1 24 | - ≥ 2.1.0 25 | 26 | .. * python 3.7 or later 27 | .. * gcc 6.1 or later 28 | .. * torch 1.12.1 or later 29 | .. * torch-scatter 2.1.0 or later 30 | 31 | Installation 32 | ^^^^^^^^^^^^ 33 | Ensure at least PyTorch 1.12.1 and torch-scatter 2.1.0 are installed (refer to `PyTorch `_ and `torch-scatter `_ for installation instructions), simply run 34 | 35 | .. code-block:: console 36 | 37 | $ pip install tglite 38 | 39 | Verification 40 | ^^^^^^^^^^^^ 41 | To verify the installation, run the following in Python: 42 | 43 | .. code-block:: python 44 | 45 | import tglite 46 | print(tglite.__version__) 47 | 48 | .. note:: 49 | Currently, the library is only tested on Linux. MacOS and Windows support is not guaranteed. 50 | 51 | .. _build-from-source: 52 | 53 | Building from source 54 | ````````````````````` 55 | To install the latest TGLite code for testing or development on the core, you will need to build TGlite from source. Here, we show how to build TGLite with Python 3.7, PyTorch 1.12.1 and torch-scatter 2.1.0. 56 | 57 | Create and activate a python environment: 58 | 59 | .. code-block:: console 60 | 61 | $ conda create -n tglite python=3.7 62 | $ conda activate tglite 63 | 64 | Install dependencies that have CUDA versions: 65 | 66 | .. code-block:: console 67 | 68 | $ pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 69 | $ pip install torch-scatter==2.1.0+pt112cu116 -f https://data.pyg.org/whl/torch-1.12.1+cu116.html 70 | 71 | Get the TGLite source: 72 | 73 | .. code-block:: console 74 | 75 | $ git clone https://github.com/ADAPT-uiuc/tglite.git 76 | $ cd tglite 77 | 78 | Then install the package locally: 79 | 80 | .. 
code-block:: console 81 | 82 | $ python setup.py install 83 | 84 | This will build the C++ extension (which requires C++14 and OpenMP), install 85 | the rest of the dependencies (as listed in `pyproject.toml`), and then install 86 | the `tglite` package. 87 | 88 | .. _development-mode: 89 | 90 | Development Mode 91 | ```````````````` 92 | 93 | Development mode allows easily editing the code without having to re-install 94 | the package. However, this only applies to the python code. When editing the 95 | C++ extension code, it needs to be re-compiled again. Use `develop` instead of `install` to use dev mode: 96 | 97 | .. code-block:: console 98 | 99 | $ python setup.py develop 100 | 101 | 102 | Running Tests 103 | ^^^^^^^^^^^^^ 104 | 105 | Unit tests are located in `tests` directory. First, install the testing 106 | dependencies specified in `pyproject.toml`. Doing so might overwrite the dev 107 | mode install, so you might need to re-enable dev mode. Then, exercise the tests 108 | using the `pytest` utility. 109 | 110 | .. code-block:: console 111 | 112 | # install test dependencies 113 | $ pip install '.[test]' 114 | 115 | # re-enable dev mode install 116 | $ pip uninstall -y tglite 117 | $ python setup.py develop 118 | 119 | # run with test coverage report 120 | $ pytest --cov=tglite 121 | 122 | 123 | Running Examples 124 | ```````````````` 125 | Inside the `examples `_ directory of the repository, several CTDG models have been implemented using `tglite`. 126 | To run these example models, install the additional dependencies and download the datasets: 127 | 128 | .. code-block:: console 129 | 130 | $ cd examples 131 | $ pip install -r requirements.txt # or "conda install -c conda-forge pandas scikit-learn" using conda 132 | $ ./download-data.sh 133 | $ python gen-data-files.py --data wiki-talk 134 | 135 | This will download the datasets inside `examples/data/`, one can also download data to other places. 
136 | 137 | Use the scripts in `examples/exp` as a starting point, e.g.: 138 | 139 | .. code-block:: console 140 | 141 | $ ./exp/tgat.sh --data-path . -d wiki --epochs 3 142 | 143 | 144 | Building this document locally 145 | ``````````````````````````````` 146 | .. code-block:: console 147 | 148 | # install doc dependencies 149 | $ pip install '.[docs]' 150 | 151 | # build docs 152 | $ cd docs 153 | $ make html 154 | 155 | # launch in browser 156 | $ sh run.sh 157 | -------------------------------------------------------------------------------- /docs/source/tutorial/tblock.rst: -------------------------------------------------------------------------------- 1 | .. _tutorial-tblock: 2 | 3 | How does TBlock work? 4 | ===================== 5 | 6 | Introduction 7 | ------------ 8 | 9 | This tutorial provides an overview of `TBlock`, a key component of the TGLite framework. TBlocks, or temporal blocks, capture the message-flow dependencies between target node-time pairs and their temporally sampled neighbors. This tutorial will explain the design choices and features of TBlock to help you understand its usage within the TGLite framework. 10 | 11 | Overview 12 | -------- 13 | 14 | TBlocks are motivated by the MFG (Message Flow Graph) objects available in the DGL (Deep Graph Library) but provide additional capabilities for CTDG (Continuous-Time Dynamic Graph) models. The following sections will explain the three key design choices that distinguish TBlocks from MFGs. 15 | 16 | Doubly-Linked List Structure 17 | ---------------------------- 18 | 19 | One key distinction of TBlocks is the use of a doubly-linked list structure. This structure explicitly captures the multi-hop neighbor sampling/aggregation relationship that TBlocks are used for. Unlike standalone MFG objects in DGL/TGL, TBlocks maintain links to related blocks, enabling efficient multi-hop aggregation operations. 
20 | 21 | Target and Neighbor Information 22 | ------------------------------- 23 | 24 | TBlocks primarily focus on target destination nodes and optionally include information about neighbor source nodes. By separating neighbor information as optional, TBlocks allow for easier manipulation of target node information. This flexibility enables optimizations such as deduplication and caching to be applied effectively, as they can be performed on destination nodes before sampling for neighbors. 25 | 26 | Hooks Mechanism for Post-Processing 27 | ----------------------------------- 28 | 29 | TBlocks provide a hooks mechanism for running post-processing procedures. These hooks are callable functions that are invoked after computations are performed on the block. The hooks mechanism enables scheduling of transformations on computed output, such as deduplication and preserving output semantics. TGLite runtime automatically handles the execution of registered hooks, simplifying the post-processing step. 30 | 31 | .. _tblock-structure figure: 32 | .. figure:: ../img/tblock-structure.png 33 | :alt: tblock-structure 34 | :align: center 35 | :figwidth: 60 % 36 | 37 | Diagram of the doubly-linked list design and internal structure of a TBlock (destination node-time is denoted as ). 38 | 39 | Block Lifecycle and Usage 40 | ------------------------- 41 | 42 | The block lifecycle involves creating a TBlock, applying optimizations, sampling neighbors, performing computations, and accessing cached data. The following steps outline the typical usage of a TBlock: 43 | 44 | 1. Create a TBlock using various methods or construct it directly. 45 | 2. Apply optimizations to the block to minimize subgraph size and potential computations. 46 | 3. Sample neighbors of the block to capture message-flow dependencies. 47 | 4. Manipulate the block in-place, register hooks, or cache data. 48 | 5. Use the block for computations and access cached data for specific nodes and edges. 49 | 50 | .. 
_tblock-workflow figure: 51 | .. figure:: ../img/tblock-workflow.png 52 | :alt: tblock-workflow 53 | :align: center 54 | :figwidth: 60 % 55 | 56 | Typical flow of constructing and using a TBlock object. 57 | 58 | Example Usage 59 | ------------- 60 | 61 | Here's an example code snippet demonstrating the usage of TBlock: 62 | 63 | .. code-block:: python 64 | 65 | # Create the head TBlock from a batch data 66 | head = batch.block(self.ctx) 67 | 68 | # Create the next TBlock iteratively 69 | for i in range(self.num_layers): 70 | tail = head if i == 0 else tail.next_block(...) 71 | # Apply optimizations 72 | tail = tg.op.dedup(tail) 73 | tail = tg.op.cache(self.ctx, tail, ...) 74 | # Sample neighbors 75 | tail = self.sampler.sample(tail) 76 | 77 | # Load data 78 | tg.op.preload(head, use_pin=True) 79 | tail.dstdata['h'] = tail.dstfeat() 80 | tail.srcdata['h'] = tail.srcfeat() 81 | # Perform computations 82 | emb = tg.op.aggregate(head, self.compute, key='h') 83 | 84 | 85 | 86 | In this tutorial, you learned about TBlock, a key component of the TGLite framework. TBlocks provide a powerful mechanism for capturing and analyzing message-flow dependencies in a continuous-time dynamic graph. By understanding the design choices and features of TBlock, you can effectively leverage its capabilities within your applications. 87 | 88 | For more details and advanced usage, refer to the :ref:`TBlock `. 
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - defaults 4 | dependencies: 5 | - pip=22.3.1=py37h06a4308_0 6 | - python=3.7.16=h7a1cb2a_0 7 | - pandoc 8 | - pip: 9 | - torch==1.12.1 -------------------------------------------------------------------------------- /examples/apan/apan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tglite as tg 3 | 4 | from torch import nn, Tensor 5 | from tglite._stats import tt 6 | 7 | import sys, os 8 | sys.path.append(os.path.join(os.getcwd(), '..')) 9 | import support 10 | 11 | 12 | class APAN(nn.Module): 13 | def __init__(self, ctx: tg.TContext, 14 | dim_mem: int, dim_edge: int, dim_time: int, 15 | sampler: tg.TSampler, num_heads=2, dropout=0.1): 16 | super().__init__() 17 | self.ctx = ctx 18 | self.dim_edge = dim_edge 19 | self.mem_updater = AttnMemoryUpdater(ctx, 20 | dim_mem=dim_mem, 21 | dim_msg=2 * dim_mem + dim_edge, 22 | dim_time=dim_time, 23 | num_heads=num_heads, 24 | dropout=dropout) 25 | self.sampler = sampler 26 | self.edge_predictor = support.EdgePredictor(dim_mem) 27 | 28 | def forward(self, batch: tg.TBatch): 29 | size = len(batch) 30 | t_start = tt.start() 31 | mem = self.mem_updater(batch) 32 | tt.t_mem_update += tt.elapsed(t_start) 33 | 34 | nodes = batch.nodes(include_negs=False) 35 | times = batch.times(include_negs=False) 36 | batch.g.mem.update(nodes, mem[:2 * size], torch.from_numpy(times)) 37 | 38 | src, dst, neg = batch.split_data(mem) 39 | scores = self.edge_predictor(src, dst) 40 | if batch.neg_nodes is not None: 41 | scores = (scores, self.edge_predictor(src, neg)) 42 | del src 43 | del dst 44 | del neg 45 | 46 | blk = tg.TBlock(self.ctx, 0, nodes, times) 47 | blk = self.sampler.sample(blk) 48 | mem = mem[:2 * size].detach().to(batch.g.storage_device()) 49 | 
self.create_mails(batch, blk, mem) 50 | del mem 51 | 52 | tg.op.propagate(blk, self.send_mails) 53 | return scores 54 | 55 | def create_mails(self, batch: tg.TBatch, blk: tg.TBlock, mem: Tensor): 56 | size = len(batch) 57 | mem_src = mem[:size] 58 | mem_dst = mem[size:] 59 | 60 | if self.dim_edge > 0: 61 | efeat = batch.g.efeat[batch.eids()] 62 | src_mail = torch.cat([mem_src, mem_dst, efeat], dim=1) 63 | dst_mail = torch.cat([mem_dst, mem_src, efeat], dim=1) 64 | else: 65 | src_mail = torch.cat([mem_src, mem_dst], dim=1) 66 | dst_mail = torch.cat([mem_dst, mem_src], dim=1) 67 | 68 | blk.dstdata['mail'] = torch.cat([src_mail, dst_mail], dim=0) 69 | 70 | def send_mails(self, blk: tg.TBlock): 71 | sdev = blk.g.storage_device() 72 | if blk.num_edges() == 0: 73 | return 74 | 75 | mail = blk.dstdata['mail'][blk.dstindex] 76 | mail = tg.op.src_scatter(blk, mail, op='mean') 77 | 78 | mail_ts = torch.from_numpy(blk.dsttimes) 79 | mail_ts = mail_ts.to(sdev)[blk.dstindex] 80 | mail_ts = tg.op.src_scatter(blk, mail_ts, op='mean') 81 | 82 | blk.g.mailbox.store(blk.uniq_src()[0], mail, mail_ts) 83 | 84 | 85 | class AttnMemoryUpdater(nn.Module): 86 | def __init__(self, ctx: tg.TContext, 87 | dim_mem: int, dim_msg: int, dim_time: int, 88 | num_heads=2, dropout=0.1): 89 | super().__init__() 90 | assert (dim_mem % num_heads == 0) 91 | self.ctx = ctx 92 | self.num_heads = num_heads 93 | self.time_encode = tg.nn.TimeEncode(dim_time) 94 | self.w_q = nn.Linear(dim_mem, dim_mem) 95 | self.w_k = nn.Linear(dim_msg + dim_time, dim_mem) 96 | self.w_v = nn.Linear(dim_msg + dim_time, dim_mem) 97 | self.mlp = nn.Linear(dim_mem, dim_mem) 98 | self.attn_act = nn.LeakyReLU(0.2) 99 | self.dropout = nn.Dropout(dropout) 100 | self.layer_norm = nn.LayerNorm(dim_mem) 101 | 102 | def forward(self, batch: tg.TBatch) -> Tensor: 103 | sdev = batch.g.storage_device() 104 | cdev = batch.g.compute_device() 105 | nodes = batch.nodes() 106 | times = batch.times() 107 | 108 | size = len(nodes) 109 | mem = 
batch.g.mem 110 | mailbox = batch.g.mailbox 111 | mail_size = mailbox.dims()[0] 112 | 113 | mem_data = mem.data[nodes].to(cdev) 114 | Q = self.w_q(mem_data).reshape(size, self.num_heads, -1) 115 | 116 | mail = mailbox.mail[nodes].to(cdev) 117 | mail = mail.reshape(size, mail_size, -1) 118 | time_feat = torch.from_numpy(times).to(sdev).reshape(-1, 1) 119 | time_feat = (time_feat - mailbox.time[nodes]).to(cdev) 120 | time_feat = tg.op.precomputed_times(self.ctx, 0, self.time_encode, time_feat) 121 | time_feat = time_feat.reshape(size, mail_size, -1) 122 | mail = torch.cat([mail, time_feat], dim=2) 123 | del time_feat 124 | 125 | K = self.w_k(mail).reshape(size, mail_size, self.num_heads, -1) 126 | V = self.w_v(mail).reshape(size, mail_size, self.num_heads, -1) 127 | del mail 128 | 129 | attn = torch.sum(Q[:, None, :, :] * K, dim=3) 130 | del Q 131 | 132 | attn = self.attn_act(attn) 133 | attn = nn.functional.softmax(attn, dim=1) 134 | attn = self.dropout(attn) 135 | out = torch.sum(attn[:, :, :, None] * V, dim=1) 136 | del attn 137 | 138 | out = out.reshape(size, -1) 139 | out = out + mem_data 140 | del mem_data 141 | 142 | out = self.layer_norm(out) 143 | out = self.mlp(out) 144 | out = self.dropout(out) 145 | out = nn.functional.relu(out) 146 | 147 | return out 148 | -------------------------------------------------------------------------------- /examples/apan/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import tglite as tg 6 | 7 | from apan import APAN 8 | import support 9 | 10 | ### arguments 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('-d', '--data', type=str, required=True, help='dataset name') 14 | parser.add_argument('--data-path', type=str, default='', help='path to data folder') 15 | parser.add_argument('--prefix', type=str, default='', help='name for saving trained model') 16 | parser.add_argument('--gpu', 
type=int, default=0, help='gpu device to use (or -1 for cpu)') 17 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs (default: 100)') 18 | parser.add_argument('--bsize', type=int, default=200, help='batch size (default: 200)') 19 | parser.add_argument('--lr', type=str, default=0.0001, help='learning rate (default: 1e-4)') 20 | parser.add_argument('--dropout', type=str, default=0.1, help='dropout rate (default: 0.1)') 21 | # parser.add_argument('--n-layers', type=int, default=2, help='number of layers (default: 2)') 22 | parser.add_argument('--n-heads', type=int, default=2, help='number of attention heads (default: 2)') 23 | parser.add_argument('--n-nbrs', type=int, default=20, help='number of neighbors to sample (default: 20)') 24 | parser.add_argument('--n-mail', type=int, default=10, help='max number of mails (default: 10)') 25 | parser.add_argument('--dim-time', type=int, default=100, help='dimension of time features (default: 100)') 26 | parser.add_argument('--dim-embed', type=int, default=100, help='dimension of embeddings (default: 100)') 27 | parser.add_argument('--seed', type=int, default=-1, help='random seed to use') 28 | parser.add_argument('--move', action='store_true', help='move data to device') 29 | parser.add_argument('--n-threads', type=int, default=32, help='number of threads for sampler (default: 32)') 30 | parser.add_argument('--sampling', type=str, default='recent', choices=['recent', 'uniform'], help='sampling strategy (default: recent)') 31 | parser.add_argument('--opt-time', action='store_true', help='enable precomputing time encodings') 32 | parser.add_argument('--time-window', type=str, default=1e4, help='time window to precompute (default: 1e4)') 33 | parser.add_argument('--opt-all', action='store_true', help='enable all available optimizations') 34 | args = parser.parse_args() 35 | print(args) 36 | 37 | device = support.make_device(args.gpu) 38 | model_path = support.make_model_path('apan', args.prefix, 
args.data) 39 | model_mem_path = support.make_model_mem_path('apan', args.prefix, args.data) 40 | if args.seed >= 0: 41 | support.set_seed(args.seed) 42 | 43 | DATA: str = args.data 44 | DATA_PATH: str = args.data_path 45 | EPOCHS: int = args.epochs 46 | BATCH_SIZE: int = args.bsize 47 | LEARN_RATE: float = float(args.lr) 48 | DROPOUT: float = float(args.dropout) 49 | # N_LAYERS: int = args.n_layers 50 | N_HEADS: int = args.n_heads 51 | N_NBRS: int = args.n_nbrs 52 | N_MAIL: int = args.n_mail 53 | DIM_TIME: int = args.dim_time 54 | DIM_EMBED: int = args.dim_embed 55 | N_THREADS: int = args.n_threads 56 | SAMPLING: str = args.sampling 57 | OPT_TIME: bool = args.opt_time or args.opt_all 58 | TIME_WINDOW: int = int(args.time_window) 59 | 60 | 61 | ### load data 62 | 63 | g = support.load_graph(os.path.join(DATA_PATH, f'data/{DATA}/edges.csv')) 64 | support.load_feats(g, DATA, DATA_PATH) 65 | dim_efeat = 0 if g.efeat is None else g.efeat.shape[1] 66 | g.nfeat = None 67 | dim_mem = DIM_EMBED 68 | 69 | g.mailbox = tg.Mailbox(g.num_nodes(), N_MAIL, 2 * dim_mem + dim_efeat) 70 | g.mem = tg.Memory(g.num_nodes(), dim_mem) 71 | 72 | g.set_compute(device) 73 | if args.move: 74 | g.move_data(device) 75 | 76 | ctx = tg.TContext(g) 77 | ctx.need_sampling(True) 78 | ctx.enable_time_precompute(OPT_TIME) 79 | ctx.set_time_window(TIME_WINDOW) 80 | 81 | 82 | ### model 83 | 84 | sampler = tg.TSampler(N_NBRS, strategy=SAMPLING, num_threads=N_THREADS) 85 | model = APAN(ctx, 86 | dim_mem=dim_mem, 87 | dim_edge=dim_efeat, 88 | dim_time=DIM_TIME, 89 | sampler = sampler, 90 | num_heads=N_HEADS, 91 | dropout=DROPOUT) 92 | model = model.to(device) 93 | criterion = torch.nn.BCEWithLogitsLoss() 94 | optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE) 95 | 96 | 97 | ### training 98 | 99 | train_end, val_end = support.data_split(g.num_edges(), 0.7, 0.15) 100 | neg_sampler = lambda size: np.random.randint(0, g.num_nodes(), size) 101 | 102 | trainer = support.LinkPredTrainer( 103 | ctx, 
model, criterion, optimizer, neg_sampler, 104 | EPOCHS, BATCH_SIZE, train_end, val_end, 105 | model_path, model_mem_path) 106 | 107 | trainer.train() 108 | trainer.test() -------------------------------------------------------------------------------- /examples/download-data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_train.npz 4 | wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_full.npz 5 | wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/node_features.pt 6 | wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/labels.csv 7 | wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/ext_full.npz 8 | wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edges.csv 9 | wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edge_features.pt 10 | wget -P ./data/lastfm https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/edges.csv 11 | wget -P ./data/lastfm https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/ext_full.npz 12 | wget -P ./data/lastfm https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/int_full.npz 13 | wget -P ./data/lastfm https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/int_train.npz 14 | # wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/int_train.npz 15 | # wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/labels.csv 16 | # wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/int_full.npz 17 | # wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/ext_full.npz 18 | # wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/edges.csv 19 | 
# wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/node_features.pt 20 | wget -P ./data/mooc https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/edges.csv 21 | wget -P ./data/mooc https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/ext_full.npz 22 | wget -P ./data/mooc https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/int_full.npz 23 | wget -P ./data/mooc https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/int_train.npz 24 | wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/edge_features.pt 25 | wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/edges.csv 26 | wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/ext_full.npz 27 | wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/int_full.npz 28 | wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/int_train.npz 29 | wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/labels.csv 30 | wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/edge_features.pt 31 | wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/edges.csv 32 | wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/ext_full.npz 33 | wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/int_full.npz 34 | wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/int_train.npz 35 | wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/labels.csv 36 | 37 | wget -P ./data/wiki-talk http://snap.stanford.edu/data/wiki-talk-temporal.txt.gz 38 | cd ./data/wiki-talk && gzip -d wiki-talk-temporal.txt.gz 39 | -------------------------------------------------------------------------------- /examples/exp/apan-gdelt.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)" 4 | cd "$examples_dir" 5 | export PYTHONPATH="$examples_dir" 6 | 7 | python apan/train.py --seed 0 --prefix exp \ 8 | --epochs 3 --bsize 4000 --n-threads 64 \ 9 | --n-nbrs 10 --n-mail 10 \ 10 | --sampling recent "$@" 11 | -------------------------------------------------------------------------------- /examples/exp/apan.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)" 4 | cd "$examples_dir" 5 | export PYTHONPATH="$examples_dir" 6 | 7 | python apan/train.py --seed 0 --prefix exp \ 8 | --epochs 10 --bsize 600 --n-threads 64 \ 9 | --n-nbrs 10 --n-mail 10 \ 10 | --sampling recent "$@" 11 | -------------------------------------------------------------------------------- /examples/exp/jodie-gdelt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)" 4 | cd "$examples_dir" 5 | export PYTHONPATH="$examples_dir" 6 | 7 | python jodie/train.py --seed 0 --prefix exp \ 8 | --epochs 3 --bsize 4000 "$@" 9 | -------------------------------------------------------------------------------- /examples/exp/jodie.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)" 4 | cd "$examples_dir" 5 | export PYTHONPATH="$examples_dir" 6 | 7 | python jodie/train.py --seed 0 --prefix exp \ 8 | --epochs 10 --bsize 600 "$@" 9 | -------------------------------------------------------------------------------- /examples/exp/tgat-gdelt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)" 4 
| cd "$examples_dir" 5 | export PYTHONPATH="$examples_dir" 6 | 7 | python tgat/train.py --seed 0 --prefix exp \ 8 | --epochs 3 --bsize 4000 --n-threads 64 \ 9 | --n-layers 2 --n-heads 2 --n-nbrs 10 \ 10 | --sampling recent "$@" 11 | -------------------------------------------------------------------------------- /examples/exp/tgat.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)" 4 | cd "$examples_dir" 5 | export PYTHONPATH="$examples_dir" 6 | 7 | python tgat/train.py --seed 0 --prefix exp \ 8 | --epochs 10 --bsize 600 --n-threads 64 \ 9 | --n-layers 2 --n-heads 2 --n-nbrs 10 \ 10 | --sampling recent "$@" 11 | -------------------------------------------------------------------------------- /examples/exp/tgn-gdelt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)" 4 | cd "$examples_dir" 5 | export PYTHONPATH="$examples_dir" 6 | 7 | python tgn/train.py --seed 0 --prefix exp \ 8 | --epochs 3 --bsize 4000 --n-threads 64 \ 9 | --n-layers 2 --n-heads 2 --n-nbrs 10 \ 10 | --sampling recent "$@" 11 | -------------------------------------------------------------------------------- /examples/exp/tgn.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)" 4 | cd "$examples_dir" 5 | export PYTHONPATH="$examples_dir" 6 | 7 | python tgn/train.py --seed 0 --prefix exp \ 8 | --epochs 10 --bsize 600 --n-threads 64 \ 9 | --n-layers 2 --n-heads 2 --n-nbrs 10 \ 10 | --sampling recent "$@" 11 | -------------------------------------------------------------------------------- /examples/gen-data-files.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as 
"""Generate tglite-format edge files for raw datasets.

Only the 'wiki-talk' dataset is handled at the moment.
"""

import argparse

import numpy as np
import pandas as pd


parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', type=str, required=True, help='dataset name')
args = parser.parse_args()


def gen_edges_wiki_talk():
    """Convert the raw wiki-talk edge list into tglite's edges.csv layout.

    Adds an 'int_roll' column (all zeros) and an 'ext_roll' split column
    (0 = train, 1 = validation, 2 = test), split chronologically at the
    70% / 85% marks, then writes data/wiki-talk/edges.csv.
    """
    frame = pd.read_csv(
        'data/wiki-talk/wiki-talk-temporal.txt',
        sep=' ', header=None, names=['src', 'dst', 'time'],
        dtype={'src': np.int32, 'dst': np.int32, 'time': np.float32})

    edge_count = frame.shape[0]
    # node ids appear to be 0-based, so the count is max id + 1
    node_count = max(int(frame['src'].max()), int(frame['dst'].max())) + 1
    split_train = int(np.ceil(edge_count * 0.70))
    split_valid = int(np.ceil(edge_count * 0.85))
    print('num_nodes:', node_count)
    print('num_edges:', edge_count)
    print('train_end:', split_train)
    print('valid_end:', split_valid)

    frame['int_roll'] = np.zeros(edge_count, dtype=np.int32)
    roll = np.zeros(edge_count, dtype=np.int32)
    roll[split_train:] = 1  # validation region
    roll[split_valid:] = 2  # test region
    frame['ext_roll'] = roll

    frame.to_csv('data/wiki-talk/edges.csv')


if args.data == 'wiki-talk':
    gen_edges_wiki_talk()
else:
    print('not handling dataset:', args.data)
import math
import numpy as np
import torch
import tglite as tg

from typing import Tuple
from torch import nn, Tensor
from tglite._stats import tt

import sys, os
sys.path.append(os.path.join(os.getcwd(), '..'))
import support


class NormalLinear(nn.Linear):
    """Linear layer whose weight/bias are re-initialized from N(0, 1/sqrt(fan_in))."""

    def reset_parameters(self):
        std = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.normal_(0, std)
        if self.bias is not None:
            self.bias.data.normal_(0, std)


class JODIE(nn.Module):
    """JODIE temporal model: RNN-cell memory update plus a time-projection head."""

    def __init__(self, ctx: tg.TContext, dim_embed: int, dim_node: int, dim_edge: int, dim_time: int):
        super().__init__()
        self.ctx = ctx
        self.dim_embed = dim_embed
        self.dim_node = dim_node
        self.dim_edge = dim_edge
        self.dim_time = dim_time

        # RNN input is the raw mailbox message plus a time encoding
        self.updater = nn.RNNCell(dim_embed + dim_edge + dim_time, dim_embed)
        self.time_encode = tg.nn.TimeEncode(dim_time)
        self.time_linear = NormalLinear(1, dim_embed)

        if dim_node != dim_embed:
            self.node_linear = nn.Linear(dim_node, dim_embed)
        self.norm = nn.LayerNorm(dim_embed)
        self.edge_predictor = support.EdgePredictor(dim_embed)

    def forward(self, batch: tg.TBatch):
        bsize = len(batch)
        nodes = batch.nodes()

        t0 = tt.start()
        embed, embed_ts = self.update_embed(batch, nodes)
        tt.t_mem_update += tt.elapsed(t0)
        embed = self.normalize_embed(batch, nodes, embed)
        # persist updated memory for the positive src/dst nodes only
        batch.g.mem.update(batch.nodes(include_negs=False), embed[:2 * bsize], embed_ts[:2 * bsize])

        embed = self.project_embed(batch, embed, embed_ts)
        scores = self.edge_predictor(embed[:bsize], embed[bsize:2 * bsize])
        if batch.neg_nodes is not None:
            # second score tensor pairs each src with its sampled negative
            scores = (scores, self.edge_predictor(embed[:bsize], embed[2 * bsize:]))
        del embed_ts
        del embed

        self.save_raw_msgs(batch)
        return scores

    def update_embed(self, batch: tg.TBatch, nodes: np.ndarray) -> Tuple[Tensor, Tensor]:
        """Run the RNN cell over each node's mailbox message.

        Returns the updated embeddings and their timestamps (on the compute device).
        """
        cdev = batch.g.compute_device()

        last_update = batch.g.mem.time[nodes]
        mail_ts = batch.g.mailbox.time[nodes]
        delta = (mail_ts - last_update).to(cdev)
        tfeat = self.time_encode(delta.squeeze())
        rnn_in = batch.g.mailbox.mail[nodes].to(cdev)
        rnn_in = torch.cat([rnn_in, tfeat], dim=1)

        hidden = batch.g.mem.data[nodes].to(cdev)
        hidden = self.updater(rnn_in, hidden)

        return hidden, mail_ts.to(cdev)

    def normalize_embed(self, batch: tg.TBatch, nodes: np.ndarray, embed: Tensor) -> Tensor:
        """Mix static node features into the embeddings and layer-normalize."""
        feats = batch.g.nfeat[nodes].to(embed.device)
        if self.dim_node != self.dim_embed:
            embed = embed + self.node_linear(feats)
        else:
            embed = embed + feats
        return self.norm(embed)

    def project_embed(self, batch: tg.TBatch, embed: Tensor, embed_ts: Tensor) -> Tensor:
        """Project embeddings forward to the batch times via the learned time scaling."""
        now = torch.from_numpy(batch.times()).to(embed_ts.device)
        elapsed = now - embed_ts
        ratio = (elapsed / (now + 1)).to(embed.device)
        return embed * (1 + self.time_linear(ratio.reshape(-1, 1)))

    def save_raw_msgs(self, batch: tg.TBatch):
        """Store raw messages (memory state + optional edge feats) into the mailbox."""
        sdev = batch.g.storage_device()
        blk = batch.block_adj(self.ctx)

        nbr_ids = torch.from_numpy(blk.srcnodes).long().to(sdev)
        mail = batch.g.mem.data[nbr_ids]
        if self.dim_edge > 0:
            eids = torch.from_numpy(blk.eid).long().to(sdev)
            mail = torch.cat([mail, batch.g.efeat[eids]], dim=1)
        mail_ts = torch.from_numpy(blk.ets).to(sdev)
        batch.g.mailbox.store(blk.dstnodes, mail, mail_ts)
"""Training entry point for the JODIE example model."""

import argparse
import os
import torch
import numpy as np
import tglite as tg

from jodie import JODIE
import support


### arguments

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', type=str, required=True, help='dataset name')
parser.add_argument('--data-path', type=str, default='', help='path to data folder')
parser.add_argument('--prefix', type=str, default='', help='name for saving trained model')
parser.add_argument('--gpu', type=int, default=0, help='gpu device to use (or -1 for cpu)')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs (default: 100)')
parser.add_argument('--bsize', type=int, default=200, help='batch size (default: 200)')
# FIX: was type=str; float lets argparse reject malformed values early
# (still accepts the same CLI strings, so this is backward compatible)
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate (default: 1e-4)')
parser.add_argument('--dim-time', type=int, default=100, help='dimension of time features (default: 100)')
parser.add_argument('--dim-embed', type=int, default=100, help='dimension of embeddings (default: 100)')
parser.add_argument('--seed', type=int, default=-1, help='random seed to use')
parser.add_argument('--move', action='store_true', help='move data to device')
args = parser.parse_args()
print(args)

device = support.make_device(args.gpu)
model_path = support.make_model_path('jodie', args.prefix, args.data)
model_mem_path = support.make_model_mem_path('jodie', args.prefix, args.data)
if args.seed >= 0:
    support.set_seed(args.seed)

DATA: str = args.data
DATA_PATH: str = args.data_path
EPOCHS: int = args.epochs
BATCH_SIZE: int = args.bsize
LEARN_RATE: float = float(args.lr)
DIM_TIME: int = args.dim_time
DIM_EMBED: int = args.dim_embed


### load data

g = support.load_graph(os.path.join(DATA_PATH, f'data/{DATA}/edges.csv'))
support.load_feats(g, DATA, DATA_PATH)
dim_efeat = 0 if g.efeat is None else g.efeat.shape[1]
dim_nfeat = g.nfeat.shape[1]

g.mem = tg.Memory(g.num_nodes(), DIM_EMBED)
g.mailbox = tg.Mailbox(g.num_nodes(), 1, DIM_EMBED + dim_efeat)

g.set_compute(device)
if args.move:
    g.move_data(device)

# BUGFIX: TContext was constructed twice (once here, once again just before
# the trainer), leaving the model and trainer holding different contexts.
# Create it exactly once and share it.
ctx = tg.TContext(g)


### model

model = JODIE(ctx,
    dim_embed=DIM_EMBED,
    dim_node=dim_nfeat,
    dim_edge=dim_efeat,
    dim_time=DIM_TIME)
model = model.to(device)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE)


### training

train_end, val_end = support.data_split(g.num_edges(), 0.7, 0.15)
neg_sampler = lambda size: np.random.randint(0, g.num_nodes(), size)

trainer = support.LinkPredTrainer(
    ctx, model, criterion, optimizer, neg_sampler,
    EPOCHS, BATCH_SIZE, train_end, val_end,
    model_path, model_mem_path)

trainer.train()
trainer.test()
def make_model_mem_path(model: str, prefix: str, data: str) -> str:
    """Like make_model_path, but for the saved node-memory state ('...-mem.pt')."""
    Path(f'models/{model}').mkdir(parents=True, exist_ok=True)
    if prefix:
        return f'models/{model}/{prefix}-{data}-mem.pt'
    else:
        return f'models/{model}/{data}-mem-{time.time()}.pt'


def load_graph(path: Union[str, Path]) -> 'tg.TGraph':
    """Create a TGraph with edges and timestamps loaded from path. Provided data
    should include 'src', 'dst' and 'time' columns."""
    df = pd.read_csv(str(path))

    src = df['src'].to_numpy().astype(np.int32).reshape(-1, 1)
    dst = df['dst'].to_numpy().astype(np.int32).reshape(-1, 1)
    etime = df['time'].to_numpy().astype(np.float32)
    del df

    edges = np.concatenate([src, dst], axis=1)
    del src
    del dst

    g = tg.TGraph(edges, etime)
    print('num edges:', g.num_edges())
    print('num nodes:', g.num_nodes())
    return g


def load_feats(g: 'tg.TGraph', d: str, data_path: str = ''):
    """
    Load edge features and node features into g from data/{d}/edge_features.pt
    and data/{d}/node_features.pt (FIX: docstring previously said
    edge_features.pt twice). If no file exists, create random edge features for
    'mooc', 'lastfm' and 'wiki-talk', and random node features for 'wiki',
    'mooc', 'reddit', 'lastfm' and 'wiki-talk'; features remain None for other
    datasets.
    """
    edge_feats = None
    node_feats = None

    if Path(os.path.join(data_path, f'data/{d}/edge_features.pt')).exists():
        edge_feats = torch.load(os.path.join(data_path, f'data/{d}/edge_features.pt'))
        edge_feats = edge_feats.type(torch.float32)
    elif d in ['mooc', 'lastfm', 'wiki-talk']:
        edge_feats = torch.randn(g.num_edges(), 128, dtype=torch.float32)

    if Path(os.path.join(data_path, f'data/{d}/node_features.pt')).exists():
        node_feats = torch.load(os.path.join(data_path, f'data/{d}/node_features.pt'))
        node_feats = node_feats.type(torch.float32)
    elif d in ['wiki', 'mooc', 'reddit', 'lastfm', 'wiki-talk']:
        # BUGFIX: edge_feats can still be None here (e.g. 'wiki'/'reddit' with
        # no feature files on disk); fall back to the default 128 dims instead
        # of crashing on None.shape.
        dim = 128 if edge_feats is None else edge_feats.shape[1]
        node_feats = torch.randn(g.num_nodes(), dim, dtype=torch.float32)

    print('edge feat:', None if edge_feats is None else edge_feats.shape)
    print('node feat:', None if node_feats is None else node_feats.shape)
    g.efeat = edge_feats
    g.nfeat = node_feats


def data_split(num_samples: int, train_percent: float, val_percent: float) -> Tuple[int, int]:
    """Return (train_end, val_end) index boundaries for a chronological split."""
    train_end = int(np.ceil(num_samples * train_percent))
    val_end = int(np.ceil(num_samples * (train_percent + val_percent)))
    return train_end, val_end


class EdgePredictor(nn.Module):
    """Two-tower MLP link predictor: score = out_fc(ReLU(src_fc(src) + dst_fc(dst)))."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.src_fc = nn.Linear(dim, dim)
        self.dst_fc = nn.Linear(dim, dim)
        self.out_fc = nn.Linear(dim, 1)
        self.act = nn.ReLU()

    def forward(self, src: Tensor, dst: Tensor) -> Tensor:
        h_src = self.src_fc(src)
        h_dst = self.dst_fc(dst)
        h_out = self.act(h_src + h_dst)
        return self.out_fc(h_out)
class LinkPredTrainer(object):
    """Train/validate/test harness for temporal link-prediction models.

    Runs the epoch loop with timing stats (tt), tracks the best validation AP,
    checkpoints the model (and optional node memory), and reports average
    precision and ROC-AUC.
    """

    def __init__(self, ctx: tg.TContext, model: nn.Module,
                 criterion: nn.Module, optimizer: torch.optim.Optimizer,
                 neg_sampler: Callable, epochs: int, bsize: int,
                 train_end: int, val_end: int,
                 model_path: str, model_mem_path: Optional[str]):
        self.ctx = ctx
        self.g = ctx.graph
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.neg_sampler = neg_sampler
        self.epochs = epochs
        self.bsize = bsize
        self.train_end = train_end
        self.val_end = val_end
        self.model_path = model_path
        self.model_mem_path = model_mem_path

    def train(self):
        """Train for self.epochs epochs, validating after each and keeping the best."""
        tt.csv_open('out-stats.csv')
        tt.csv_write_header()
        best_epoch = 0
        best_ap = 0
        for e in range(self.epochs):
            print(f'epoch {e}:')
            # BUGFIX: only synchronize when CUDA is actually available;
            # the unconditional call raised on CPU-only runs
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            t_epoch = tt.start()

            self.ctx.train()
            self.model.train()
            if self.g.mem is not None:
                self.g.mem.reset()
            if self.g.mailbox is not None:
                self.g.mailbox.reset()

            epoch_loss = 0.0
            t_loop = tt.start()
            for batch in tg.iter_edges(self.g, size=self.bsize, end=self.train_end):
                t_start = tt.start()
                batch.neg_nodes = self.neg_sampler(len(batch))
                tt.t_prep_batch += tt.elapsed(t_start)

                t_start = tt.start()
                self.optimizer.zero_grad()
                pred_pos, pred_neg = self.model(batch)
                tt.t_forward += tt.elapsed(t_start)

                t_start = tt.start()
                loss = self.criterion(pred_pos, torch.ones_like(pred_pos))
                loss += self.criterion(pred_neg, torch.zeros_like(pred_neg))
                epoch_loss += float(loss)
                loss.backward()
                self.optimizer.step()
                tt.t_backward += tt.elapsed(t_start)
                # running total since loop start; last iteration wins
                tt.t_loop = tt.elapsed(t_loop)

            t_eval = tt.start()
            ap, auc = self.eval(start_idx=self.train_end, end_idx=self.val_end)
            tt.t_eval = tt.elapsed(t_eval)

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            tt.t_epoch = tt.elapsed(t_epoch)
            if e == 0 or ap > best_ap:
                best_epoch = e
                best_ap = ap
                torch.save(self.model.state_dict(), self.model_path)
                if self.g.mem is not None:
                    torch.save(self.g.mem.backup(), self.model_mem_path)
            print(' loss:{:.4f} val ap:{:.4f} val auc:{:.4f}'.format(epoch_loss, ap, auc))
            tt.csv_write_line(epoch=e)
            tt.print_epoch()
            tt.reset_epoch()
        tt.csv_close()
        print('best model at epoch {}'.format(best_epoch))

    @torch.no_grad()
    def eval(self, start_idx: int, end_idx: Optional[int] = None):
        """Evaluate on edges [start_idx, end_idx); returns (mean AP, mean AUC).

        FIX: annotation was `end_idx: int = None`; Optional[int] is correct.
        """
        self.ctx.eval()
        self.model.eval()
        val_aps = []
        val_auc = []
        for batch in tg.iter_edges(self.g, size=self.bsize, start=start_idx, end=end_idx):
            size = len(batch)
            batch.neg_nodes = self.neg_sampler(size)
            prob_pos, prob_neg = self.model(batch)
            prob_pos = prob_pos.cpu()
            prob_neg = prob_neg.cpu()
            pred_score = torch.cat([prob_pos, prob_neg], dim=0).sigmoid()
            true_label = torch.cat([torch.ones(size), torch.zeros(size)])
            val_aps.append(average_precision_score(true_label, pred_score))
            val_auc.append(roc_auc_score(true_label, pred_score))
        return np.mean(val_aps), np.mean(val_auc)

    def test(self):
        """Reload the best checkpoint (and memory state) and evaluate on the test split."""
        print('loading saved checkpoint and testing model...')
        self.model.load_state_dict(torch.load(self.model_path))
        if self.g.mem is not None:
            self.g.mem.restore(torch.load(self.model_mem_path))
        t_test = tt.start()
        ap, auc = self.eval(start_idx=self.val_end)
        t_test = tt.elapsed(t_test)
        print(' test time:{:.2f}s AP:{:.4f} AUC:{:.4f}'.format(t_test, ap, auc))
| "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": { 24 | "id": "tH_6K5u3iQhy" 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "DATA: str = 'wiki' # 'wiki', 'reddit', 'mooc', 'mag', 'lastfm', 'gdelt', 'wiki-talk'\n", 29 | "DATA_PATH: str = '/shared'\n", 30 | "EPOCHS: int = 10\n", 31 | "BATCH_SIZE: int = 200\n", 32 | "LEARN_RATE: float = 0.0001\n", 33 | "DROPOUT: float = 0.1\n", 34 | "N_LAYERS: int = 2\n", 35 | "N_HEADS: int = 2\n", 36 | "N_NBRS: int = 20\n", 37 | "DIM_TIME: int = 100\n", 38 | "DIM_EMBED: int = 100\n", 39 | "N_THREADS: int = 32\n", 40 | "SAMPLING: str = 'recent' # 'recent'or 'uniform'\n", 41 | "OPT_DEDUP = True\n", 42 | "OPT_CACHE = True\n", 43 | "OPT_TIME = True\n", 44 | "OPT_ALL = True\n", 45 | "OPT_DEDUP: bool = OPT_DEDUP or OPT_ALL\n", 46 | "OPT_CACHE: bool = OPT_CACHE or OPT_ALL\n", 47 | "OPT_TIME: bool = OPT_TIME or OPT_ALL\n", 48 | "CACHE_LIMIT: int = int(2e6)\n", 49 | "TIME_WINDOW: int = int(1e4)\n", 50 | "\n", 51 | "MOVE = True\n", 52 | "GPU = 0\n", 53 | "SEED = 1\n", 54 | "PREFIX = ''" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "device = support.make_device(GPU)\n", 64 | "model_path = support.make_model_path('tgat', PREFIX, DATA)\n", 65 | "if SEED >= 0:\n", 66 | " support.set_seed(SEED)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": { 73 | "id": "COgTQXmcgCEr" 74 | }, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "num edges: 157474\n", 81 | "num nodes: 9228\n", 82 | "edge feat: torch.Size([157474, 172])\n", 83 | "node feat: torch.Size([9228, 172])\n" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "### load data\n", 89 | "\n", 90 | "g = support.load_graph(os.path.join(DATA_PATH, f'data/{DATA}/edges.csv'))\n", 91 | "support.load_feats(g, DATA, DATA_PATH)\n", 92 | "dim_efeat = 0 if g.efeat is None else g.efeat.shape[1]\n", 93 | "dim_nfeat 
= g.nfeat.shape[1]\n", 94 | "\n", 95 | "g.set_compute(device)\n", 96 | "if MOVE:\n", 97 | " g.move_data(device)\n", 98 | "\n", 99 | "ctx = tg.TContext(g)\n", 100 | "ctx.need_sampling(True)\n", 101 | "ctx.enable_embed_caching(OPT_CACHE, DIM_EMBED)\n", 102 | "ctx.enable_time_precompute(OPT_TIME)\n", 103 | "ctx.set_cache_limit(CACHE_LIMIT)\n", 104 | "ctx.set_time_window(TIME_WINDOW)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 5, 110 | "metadata": { 111 | "id": "5-YKifBwLuMS" 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "### model\n", 116 | "\n", 117 | "sampler = tg.TSampler(N_NBRS, strategy=SAMPLING, num_threads=N_THREADS)\n", 118 | "model = TGAT(ctx,\n", 119 | " dim_node=dim_nfeat,\n", 120 | " dim_edge=dim_efeat,\n", 121 | " dim_time=DIM_TIME,\n", 122 | " dim_embed=DIM_EMBED,\n", 123 | " sampler=sampler,\n", 124 | " num_layers=N_LAYERS,\n", 125 | " num_heads=N_HEADS,\n", 126 | " dropout=DROPOUT,\n", 127 | " dedup=OPT_DEDUP,)\n", 128 | "model = model.to(device)\n", 129 | "criterion = torch.nn.BCEWithLogitsLoss()\n", 130 | "optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 6, 136 | "metadata": { 137 | "id": "DLZ8cipJLx-u" 138 | }, 139 | "outputs": [ 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "epoch 0:\n", 145 | " loss:295.7572 val ap:0.9739 val auc:0.9782\n", 146 | " epoch | total:14.08s loop:12.34s eval:1.74s\n", 147 | " loop | forward:7.82s backward:4.46s sample:0.49s prep_batch:0.05s prep_input:1.03s post_update:0.00s\n", 148 | " comp | mem_update:0.00s time_zero:1.15s time_nbrs:0.93s self_attn:3.64s\n", 149 | "epoch 1:\n", 150 | " loss:170.5700 val ap:0.9828 val auc:0.9853\n", 151 | " epoch | total:13.69s loop:11.75s eval:1.94s\n", 152 | " loop | forward:6.86s backward:4.83s sample:0.50s prep_batch:0.05s prep_input:0.93s post_update:0.00s\n", 153 | " comp | mem_update:0.00s time_zero:0.21s 
time_nbrs:1.18s self_attn:3.71s\n", 154 | "epoch 2:\n", 155 | " loss:142.4620 val ap:0.9828 val auc:0.9855\n", 156 | " epoch | total:14.43s loop:12.56s eval:1.86s\n", 157 | " loop | forward:7.24s backward:5.26s sample:0.48s prep_batch:0.05s prep_input:0.82s post_update:0.00s\n", 158 | " comp | mem_update:0.00s time_zero:0.25s time_nbrs:1.30s self_attn:3.99s\n", 159 | "epoch 3:\n", 160 | " loss:128.9173 val ap:0.9850 val auc:0.9872\n", 161 | " epoch | total:13.71s loop:12.58s eval:1.12s\n", 162 | " loop | forward:6.61s backward:5.91s sample:0.52s prep_batch:0.06s prep_input:0.55s post_update:0.00s\n", 163 | " comp | mem_update:0.00s time_zero:0.27s time_nbrs:0.96s self_attn:3.47s\n", 164 | "epoch 4:\n", 165 | " loss:123.7265 val ap:0.9852 val auc:0.9874\n", 166 | " epoch | total:16.85s loop:15.26s eval:1.58s\n", 167 | " loop | forward:7.32s backward:7.84s sample:0.82s prep_batch:0.09s prep_input:0.52s post_update:0.00s\n", 168 | " comp | mem_update:0.00s time_zero:0.41s time_nbrs:1.28s self_attn:3.50s\n", 169 | "epoch 5:\n", 170 | " loss:116.0836 val ap:0.9843 val auc:0.9871\n", 171 | " epoch | total:16.34s loop:14.75s eval:1.58s\n", 172 | " loop | forward:7.10s backward:7.56s sample:0.79s prep_batch:0.08s prep_input:0.51s post_update:0.00s\n", 173 | " comp | mem_update:0.00s time_zero:0.40s time_nbrs:1.27s self_attn:3.39s\n", 174 | "epoch 6:\n", 175 | " loss:113.4501 val ap:0.9877 val auc:0.9894\n", 176 | " epoch | total:18.37s loop:17.14s eval:1.21s\n", 177 | " loop | forward:7.92s backward:9.13s sample:0.81s prep_batch:0.08s prep_input:0.51s post_update:0.00s\n", 178 | " comp | mem_update:0.00s time_zero:0.41s time_nbrs:1.21s self_attn:3.82s\n", 179 | "epoch 7:\n", 180 | " loss:108.1715 val ap:0.9869 val auc:0.9889\n", 181 | " epoch | total:18.27s loop:17.03s eval:1.23s\n", 182 | " loop | forward:7.81s backward:9.12s sample:0.65s prep_batch:0.09s prep_input:0.52s post_update:0.00s\n", 183 | " comp | mem_update:0.00s time_zero:0.41s time_nbrs:1.20s 
self_attn:3.81s\n", 184 | "epoch 8:\n", 185 | " loss:103.1291 val ap:0.9882 val auc:0.9897\n", 186 | " epoch | total:17.87s loop:16.68s eval:1.18s\n", 187 | " loop | forward:7.59s backward:8.99s sample:0.64s prep_batch:0.09s prep_input:0.50s post_update:0.00s\n", 188 | " comp | mem_update:0.00s time_zero:0.39s time_nbrs:1.17s self_attn:3.71s\n", 189 | "epoch 9:\n", 190 | " loss:101.0907 val ap:0.9899 val auc:0.9912\n", 191 | " epoch | total:18.85s loop:17.23s eval:1.61s\n", 192 | " loop | forward:7.75s backward:9.36s sample:0.85s prep_batch:0.11s prep_input:0.56s post_update:0.00s\n", 193 | " comp | mem_update:0.00s time_zero:0.43s time_nbrs:1.32s self_attn:3.64s\n", 194 | "best model at epoch 9\n", 195 | "loading saved checkpoint and testing model...\n", 196 | " test time:1.54s AP:0.9854 AUC:0.9876\n" 197 | ] 198 | } 199 | ], 200 | "source": [ 201 | "### training\n", 202 | "\n", 203 | "train_end, val_end = support.data_split(g.num_edges(), 0.7, 0.15)\n", 204 | "neg_sampler = lambda size: np.random.randint(0, g.num_nodes(), size)\n", 205 | "\n", 206 | "trainer = support.LinkPredTrainer(\n", 207 | " ctx, model, criterion, optimizer, neg_sampler,\n", 208 | " EPOCHS, BATCH_SIZE, train_end, val_end,\n", 209 | " model_path, None)\n", 210 | "\n", 211 | "trainer.train()\n", 212 | "trainer.test()" 213 | ] 214 | } 215 | ], 216 | "metadata": { 217 | "colab": { 218 | "provenance": [] 219 | }, 220 | "kernelspec": { 221 | "display_name": "Python 3", 222 | "name": "python3" 223 | }, 224 | "language_info": { 225 | "codemirror_mode": { 226 | "name": "ipython", 227 | "version": 3 228 | }, 229 | "file_extension": ".py", 230 | "mimetype": "text/x-python", 231 | "name": "python", 232 | "nbconvert_exporter": "python", 233 | "pygments_lexer": "ipython3", 234 | "version": "3.7.16" 235 | } 236 | }, 237 | "nbformat": 4, 238 | "nbformat_minor": 0 239 | } 240 | -------------------------------------------------------------------------------- /examples/tgat/tgat.py: 
import tglite as tg

from torch import nn, Tensor
from tglite.nn import TemporalAttnLayer

import os
import sys
sys.path.append(os.path.join(os.getcwd(), '..'))
import support


class TGAT(nn.Module):
    """TGAT link-prediction model: stacked temporal attention + edge predictor."""

    def __init__(self, ctx: tg.TContext,
                 dim_node: int, dim_edge: int, dim_time: int, dim_embed: int,
                 sampler: tg.TSampler, num_layers=2, num_heads=2, dropout=0.1,
                 dedup: bool = True):
        super().__init__()
        self.ctx = ctx
        self.num_layers = num_layers
        layers = []
        for idx in range(num_layers):
            # the first layer consumes raw node features, deeper ones embeddings
            layers.append(TemporalAttnLayer(
                ctx,
                num_heads=num_heads,
                dim_node=dim_node if idx == 0 else dim_embed,
                dim_edge=dim_edge,
                dim_time=dim_time,
                dim_out=dim_embed,
                dropout=dropout))
        self.attn = nn.ModuleList(layers)
        self.sampler = sampler
        self.edge_predictor = support.EdgePredictor(dim=dim_embed)
        self.dedup = dedup

    def forward(self, batch: tg.TBatch) -> Tensor:
        # build the chain of message-passing blocks, head = outermost layer
        head = batch.block(self.ctx)
        tail = head
        for depth in range(self.num_layers):
            if depth != 0:
                tail = tail.next_block(include_dst=True)
            if self.dedup:
                tail = tg.op.dedup(tail)
            tail = tg.op.cache(self.ctx, tail.layer, tail)
            tail = self.sampler.sample(tail)

        # stage features, then aggregate through the attention stack
        tg.op.preload(head, use_pin=True)
        if tail.num_dst() > 0:
            tail.dstdata['h'] = tail.dstfeat()
            tail.srcdata['h'] = tail.srcfeat()
        embeds = tg.op.aggregate(head, list(reversed(self.attn)), key='h')
        del head
        del tail

        # score positive pairs, and negatives when the batch carries them
        src, dst, neg = batch.split_data(embeds)
        scores = self.edge_predictor(src, dst)
        if batch.neg_nodes is not None:
            scores = (scores, self.edge_predictor(src, neg))

        return scores
"""Training entry point for the TGAT example model."""

import argparse
import os
import numpy as np
import torch
import tglite as tg

from tgat import TGAT
import support

### arguments

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', type=str, required=True, help='dataset name')
parser.add_argument('--data-path', type=str, default='', help='path to data folder')
parser.add_argument('--prefix', type=str, default='', help='name for saving trained model')
parser.add_argument('--gpu', type=int, default=0, help='gpu device to use (or -1 for cpu)')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs (default: 100)')
parser.add_argument('--bsize', type=int, default=200, help='batch size (default: 200)')
# FIX: numeric options were declared type=str; proper numeric types make
# argparse validate them while accepting the same CLI strings
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate (default: 1e-4)')
parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate (default: 0.1)')
parser.add_argument('--n-layers', type=int, default=2, help='number of layers (default: 2)')
parser.add_argument('--n-heads', type=int, default=2, help='number of attention heads (default: 2)')
parser.add_argument('--n-nbrs', type=int, default=20, help='number of neighbors to sample (default: 20)')
parser.add_argument('--dim-time', type=int, default=100, help='dimension of time features (default: 100)')
parser.add_argument('--dim-embed', type=int, default=100, help='dimension of embeddings (default: 100)')
parser.add_argument('--seed', type=int, default=-1, help='random seed to use')
parser.add_argument('--move', action='store_true', help='move data to device')
parser.add_argument('--n-threads', type=int, default=32, help='number of threads for sampler (default: 32)')
parser.add_argument('--sampling', type=str, default='recent', choices=['recent', 'uniform'], help='sampling strategy (default: recent)')
parser.add_argument('--opt-dedup', action='store_true', help='enable dedup optimization')
parser.add_argument('--opt-cache', action='store_true', help='enable caching optimization')
parser.add_argument('--opt-time', action='store_true', help='enable precomputing time encodings')
parser.add_argument('--cache-limit', type=float, default=2e6, help='max number of embeds to cache (default: 2e6)')
parser.add_argument('--time-window', type=float, default=1e4, help='time window to precompute (default: 1e4)')
parser.add_argument('--opt-all', action='store_true', help='enable all available optimizations')
args = parser.parse_args()
print(args)

device = support.make_device(args.gpu)
model_path = support.make_model_path('tgat', args.prefix, args.data)
if args.seed >= 0:
    support.set_seed(args.seed)

DATA: str = args.data
DATA_PATH: str = args.data_path
EPOCHS: int = args.epochs
BATCH_SIZE: int = args.bsize
LEARN_RATE: float = float(args.lr)
DROPOUT: float = float(args.dropout)
N_LAYERS: int = args.n_layers
N_HEADS: int = args.n_heads
N_NBRS: int = args.n_nbrs
DIM_TIME: int = args.dim_time
DIM_EMBED: int = args.dim_embed
N_THREADS: int = args.n_threads
SAMPLING: str = args.sampling
OPT_DEDUP: bool = args.opt_dedup or args.opt_all
OPT_CACHE: bool = args.opt_cache or args.opt_all
OPT_TIME: bool = args.opt_time or args.opt_all
CACHE_LIMIT: int = int(args.cache_limit)
TIME_WINDOW: int = int(args.time_window)


### load data

g = support.load_graph(os.path.join(DATA_PATH, f'data/{DATA}/edges.csv'))
support.load_feats(g, DATA, DATA_PATH)
dim_efeat = 0 if g.efeat is None else g.efeat.shape[1]
dim_nfeat = g.nfeat.shape[1]

g.set_compute(device)
if args.move:
    g.move_data(device)

ctx = tg.TContext(g)
ctx.need_sampling(True)
ctx.enable_embed_caching(OPT_CACHE, DIM_EMBED)
ctx.enable_time_precompute(OPT_TIME)
ctx.set_cache_limit(CACHE_LIMIT)
ctx.set_time_window(TIME_WINDOW)


### model

sampler = tg.TSampler(N_NBRS, strategy=SAMPLING, num_threads=N_THREADS)
# BUGFIX: forward dedup=OPT_DEDUP — the --opt-dedup/--opt-all flags were
# parsed and folded into OPT_DEDUP but never passed to the model, so TGAT
# silently always used its default dedup=True (the notebook version of this
# script does pass it)
model = TGAT(ctx,
    dim_node=dim_nfeat,
    dim_edge=dim_efeat,
    dim_time=DIM_TIME,
    dim_embed=DIM_EMBED,
    sampler=sampler,
    num_layers=N_LAYERS,
    num_heads=N_HEADS,
    dropout=DROPOUT,
    dedup=OPT_DEDUP)
model = model.to(device)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE)


### training

train_end, val_end = support.data_split(g.num_edges(), 0.7, 0.15)
neg_sampler = lambda size: np.random.randint(0, g.num_nodes(), size)

trainer = support.LinkPredTrainer(
    ctx, model, criterion, optimizer, neg_sampler,
    EPOCHS, BATCH_SIZE, train_end, val_end,
    model_path, None)

trainer.train()
trainer.test()
TemporalAttnLayer(ctx, 27 | num_heads=num_heads, 28 | dim_node=dim_embed, 29 | dim_edge=dim_edge, 30 | dim_time=dim_time, 31 | dim_out=dim_embed, 32 | dropout=dropout) 33 | for i in range(num_layers)]) 34 | self.sampler = sampler 35 | self.edge_predictor = support.EdgePredictor(dim=dim_embed) 36 | self.dedup = dedup 37 | 38 | def forward(self, batch: tg.TBatch) -> Tensor: 39 | # setup message passing 40 | head = batch.block(self.ctx) 41 | for i in range(self.num_layers): 42 | tail = head if i == 0 \ 43 | else tail.next_block(include_dst=True, use_dst_times=False) 44 | tail = tg.op.dedup(tail) if self.dedup else tail 45 | tail = self.sampler.sample(tail) 46 | 47 | # load data / feats 48 | tg.op.preload(head, use_pin=True) 49 | if tail.num_dst() > 0: 50 | t_start = tt.start() 51 | mem = self.update_memory(tail) 52 | nfeat = tail.nfeat() if self.nfeat_map is None else self.nfeat_map(tail.nfeat()) 53 | tail.dstdata['h'] = nfeat[:tail.num_dst()] + mem[:tail.num_dst()] 54 | tail.srcdata['h'] = nfeat[tail.num_dst():] + mem[tail.num_dst():] 55 | tt.t_mem_update += tt.elapsed(t_start) 56 | del nfeat 57 | del mem 58 | 59 | # compute embeddings 60 | embeds = tg.op.aggregate(head, list(reversed(self.attn)), key='h') 61 | del head 62 | del tail 63 | 64 | # compute scores 65 | src, dst, neg = batch.split_data(embeds) 66 | scores = self.edge_predictor(src, dst) 67 | if neg is not None: 68 | scores = (scores, self.edge_predictor(src, neg)) 69 | del embeds 70 | del src 71 | del dst 72 | del neg 73 | 74 | # memory messages 75 | t_start = tt.start() 76 | self.save_raw_msgs(batch) 77 | tt.t_post_update += tt.elapsed(t_start) 78 | 79 | return scores 80 | 81 | def update_memory(self, blk: tg.TBlock) -> Tensor: 82 | cdev = blk.g.compute_device() 83 | nodes = blk.allnodes() 84 | 85 | mail_ts = blk.g.mailbox.time[nodes] 86 | delta = mail_ts - blk.g.mem.time[nodes] 87 | delta = delta.squeeze().to(cdev) 88 | mail = tg.op.precomputed_times(self.ctx, 0, self.mem_time_encode, delta) 89 | mail = 
torch.cat([blk.mail(), mail], dim=1) 90 | 91 | mem = blk.mem_data() 92 | mem = self.mem_cell(mail, mem) 93 | blk.g.mem.update(nodes, mem, mail_ts) 94 | return mem 95 | 96 | def save_raw_msgs(self, batch: tg.TBatch): 97 | sdev = batch.g.storage_device() 98 | mem = batch.g.mem.data 99 | 100 | blk = batch.block_adj(self.ctx) 101 | blk = tg.op.coalesce(blk, by='latest') 102 | 103 | uniq = torch.from_numpy(blk.dstnodes).long().to(sdev) 104 | nbrs = torch.from_numpy(blk.srcnodes).long().to(sdev) 105 | if self.dim_edge > 0: 106 | eids = torch.from_numpy(blk.eid).long().to(sdev) 107 | mail = torch.cat([mem[uniq], mem[nbrs], batch.g.efeat[eids]], dim=1) 108 | else: 109 | mail = torch.cat([mem[uniq], mem[nbrs]], dim=1) 110 | mail_ts = torch.from_numpy(blk.ets).to(sdev) 111 | batch.g.mailbox.store(uniq, mail, mail_ts) -------------------------------------------------------------------------------- /examples/tgn/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import numpy as np 5 | import tglite as tg 6 | 7 | import support 8 | from tgn import TGN 9 | 10 | 11 | ### arguments 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-d', '--data', type=str, required=True, help='dataset name') 15 | parser.add_argument('--data-path', type=str, default='', help='path to data folder') 16 | parser.add_argument('--prefix', type=str, default='', help='name for saving trained model') 17 | parser.add_argument('--gpu', type=int, default=0, help='gpu device to use (or -1 for cpu)') 18 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs (default: 100)') 19 | parser.add_argument('--bsize', type=int, default=200, help='batch size (default: 200)') 20 | parser.add_argument('--lr', type=str, default=0.0001, help='learning rate (default: 1e-4)') 21 | parser.add_argument('--dropout', type=str, default=0.1, help='dropout rate (default: 0.1)') 22 | 
parser.add_argument('--n-layers', type=int, default=2, help='number of layers (default: 2)')
parser.add_argument('--n-heads', type=int, default=2, help='number of attention heads (default: 2)')
parser.add_argument('--n-nbrs', type=int, default=20, help='number of neighbors to sample (default: 20)')
parser.add_argument('--dim-time', type=int, default=100, help='dimension of time features (default: 100)')
parser.add_argument('--dim-embed', type=int, default=100, help='dimension of embeddings (default: 100)')
parser.add_argument('--seed', type=int, default=-1, help='random seed to use')
parser.add_argument('--move', action='store_true', help='move data to device')
parser.add_argument('--n-threads', type=int, default=32, help='number of threads for sampler (default: 32)')
parser.add_argument('--sampling', type=str, default='recent', choices=['recent', 'uniform'], help='sampling strategy (default: recent)')
parser.add_argument('--opt-dedup', action='store_true', help='enable dedup optimization')
parser.add_argument('--opt-time', action='store_true', help='enable precomputing time encodings')
# type=float (was type=str): a CLI value like "--time-window 1e4" would make
# int(args.time_window) below raise ValueError; argparse now parses it first.
parser.add_argument('--time-window', type=float, default=1e4, help='time window to precompute (default: 1e4)')
parser.add_argument('--opt-all', action='store_true', help='enable all available optimizations')
args = parser.parse_args()
print(args)

device = support.make_device(args.gpu)
model_path = support.make_model_path('tgn', args.prefix, args.data)
model_mem_path = support.make_model_mem_path('tgn', args.prefix, args.data)
if args.seed >= 0:
    support.set_seed(args.seed)

# Snapshot CLI arguments into typed module-level constants.
DATA: str = args.data
DATA_PATH: str = args.data_path
EPOCHS: int = args.epochs
BATCH_SIZE: int = args.bsize
LEARN_RATE: float = float(args.lr)
DROPOUT: float = float(args.dropout)
N_LAYERS: int = args.n_layers
N_HEADS: int = args.n_heads
N_NBRS: int = args.n_nbrs
DIM_TIME: int = args.dim_time
DIM_EMBED: int = args.dim_embed
N_THREADS: int = args.n_threads
SAMPLING: str = args.sampling
OPT_DEDUP: bool = args.opt_dedup or args.opt_all
OPT_TIME: bool = args.opt_time or args.opt_all
TIME_WINDOW: int = int(args.time_window)


### load data

g = support.load_graph(os.path.join(DATA_PATH, f'data/{DATA}/edges.csv'))
support.load_feats(g, DATA, DATA_PATH)
dim_efeat = 0 if g.efeat is None else g.efeat.shape[1]
dim_nfeat = g.nfeat.shape[1]

# TGN needs per-node memory plus a one-slot mailbox whose message layout is
# [self mem | nbr mem | edge feat].
g.mailbox = tg.Mailbox(g.num_nodes(), 1, 2 * DIM_EMBED + dim_efeat)
g.mem = tg.Memory(g.num_nodes(), DIM_EMBED)

g.set_compute(device)
if args.move:
    g.move_data(device)

ctx = tg.TContext(g)
ctx.need_sampling(True)
ctx.enable_time_precompute(OPT_TIME)
ctx.set_time_window(TIME_WINDOW)


### model
sampler = tg.TSampler(N_NBRS, strategy=SAMPLING, num_threads=N_THREADS)
model = TGN(ctx,
    dim_node=dim_nfeat,
    dim_edge=dim_efeat,
    dim_time=DIM_TIME,
    dim_embed=DIM_EMBED,
    sampler=sampler,
    num_layers=N_LAYERS,
    num_heads=N_HEADS,
    dropout=DROPOUT,
    # fix: OPT_DEDUP was computed from --opt-dedup but never forwarded, so the
    # flag had no effect (TGN defaulted to dedup=True regardless).
    dedup=OPT_DEDUP)
model = model.to(device)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE)


### training

train_end, val_end = support.data_split(g.num_edges(), 0.7, 0.15)
neg_sampler = lambda size: np.random.randint(0, g.num_nodes(), size)

trainer = support.LinkPredTrainer(
    ctx, model, criterion, optimizer, neg_sampler,
    EPOCHS, BATCH_SIZE, train_end, val_end,
    model_path, model_mem_path)

trainer.train()
trainer.test()


# ---- include/tglite/cache.h (continues past this chunk boundary) ----
#pragma once

#include
"tglite/core.h" 4 | 5 | namespace tglite { 6 | 7 | py::tuple find_dedup_time_hits( 8 | torch::Tensor ×, 9 | torch::Tensor &time_table, 10 | int time_window); 11 | 12 | py::array_t compute_cache_keys(py::array_t &nodes, py::array_t ×); 13 | 14 | /// Table for caching computed embeddings. 15 | class EmbedTable { 16 | public: 17 | EmbedTable(ssize_t dim_emb, ssize_t limit); 18 | 19 | py::tuple lookup(py::array_t &keys, torch::Device &device); 20 | 21 | void store(py::array_t &keys, torch::Tensor &values); 22 | 23 | private: 24 | ssize_t _dim_emb; 25 | ssize_t _limit; 26 | 27 | ssize_t _start = 0; 28 | torch::Tensor _table; 29 | std::vector _keys; 30 | std::unordered_map _key2idx; 31 | // torch::Tensor _pin; 32 | }; 33 | 34 | } // namespace tglite 35 | -------------------------------------------------------------------------------- /include/tglite/core.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace tglite { 7 | 8 | typedef int32_t IdI32; 9 | typedef float TsF32; 10 | 11 | } // namespace tglite 12 | -------------------------------------------------------------------------------- /include/tglite/dedup.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tglite/core.h" 4 | 5 | namespace tglite { 6 | 7 | py::tuple dedup_targets(py::array_t &nodes, py::array_t ×); 8 | 9 | // py::tuple dedup_indices(torch::Tensor &nodes, torch::Tensor ×); 10 | 11 | // bool dedup_targets( 12 | // const IdI32 *nodes_ptr, 13 | // const TsF32 *times_ptr, 14 | // size_t len, 15 | // const IdI32 *pre_nodes, 16 | // const TsF32 *pre_times, 17 | // size_t pre_len, 18 | // std::vector &uniq_nodes, 19 | // std::vector &uniq_times, 20 | // std::vector &inv_idx 21 | // ); 22 | 23 | } // namespace tglite 24 | -------------------------------------------------------------------------------- /include/tglite/sampler.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tglite/core.h" 4 | #include "tglite/tcsr.h" 5 | 6 | namespace tglite { 7 | 8 | class TemporalBlock { 9 | public: 10 | size_t num_edges = 0; 11 | std::vector dstindex; 12 | std::vector srcnodes; 13 | std::vector eid; 14 | std::vector ets; 15 | 16 | TemporalBlock() {} 17 | 18 | py::array_t dstindex_copy() const { return to_pyarray_copy(dstindex); } 19 | py::array_t srcnodes_copy() const { return to_pyarray_copy(srcnodes); } 20 | py::array_t eid_copy() const { return to_pyarray_copy(eid); } 21 | py::array_t ets_copy() const { return to_pyarray_copy(ets); } 22 | 23 | // py::array_t dstindex_owned() { 24 | // auto *ptr = dstindex; 25 | // dstindex = nullptr; 26 | // return to_pyarray_owned(ptr); 27 | // } 28 | 29 | // py::array_t srcnodes_owned() { 30 | // auto *ptr = srcnodes; 31 | // srcnodes = nullptr; 32 | // return to_pyarray_owned(ptr); 33 | // } 34 | 35 | // py::array_t eid_owned() { 36 | // auto *ptr = eid; 37 | // eid = nullptr; 38 | // return to_pyarray_owned(ptr); 39 | // } 40 | 41 | // py::array_t ets_owned() { 42 | // auto *ptr = ets; 43 | // ets = nullptr; 44 | // return to_pyarray_owned(ptr); 45 | // } 46 | }; 47 | 48 | class TemporalSampler { 49 | public: 50 | TemporalSampler(int num_threads, int num_nbrs, bool recent); 51 | 52 | TemporalBlock sample(TCSR &tcsr, 53 | py::array_t &nodes, 54 | py::array_t ×); 55 | 56 | private: 57 | int _num_threads; 58 | int _num_nbrs; 59 | bool _recent; 60 | 61 | void sample_layer(TCSR &tcsr, TemporalBlock &block, 62 | const IdI32 *nodes_ptr, const TsF32 *times_ptr, size_t size); 63 | 64 | void add_neighbor(TCSR &tcsr, 65 | std::vector *eid, 66 | std::vector *ets, 67 | std::vector *srcnodes, 68 | std::vector *dstindex, 69 | IdI32 &k, IdI32 &dst_idx); 70 | 71 | void combine_coo( 72 | TemporalBlock &block, 73 | std::vector **eid, 74 | std::vector **ets, 75 | std::vector **srcnodes, 76 | std::vector 
**dstindex, 77 | std::vector &out_nodes); 78 | }; 79 | 80 | } // namespace tglite 81 | -------------------------------------------------------------------------------- /include/tglite/tcsr.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tglite/core.h" 4 | #include "tglite/utils.h" 5 | 6 | namespace tglite { 7 | 8 | class TCSR { 9 | public: 10 | std::vector ind; 11 | std::vector nbr; 12 | std::vector eid; 13 | std::vector ets; 14 | 15 | TCSR() { 16 | ind.push_back(0); 17 | } 18 | 19 | py::array_t ind_view() const { return to_pyarray_view(ind); } 20 | py::array_t nbr_view() const { return to_pyarray_view(nbr); } 21 | py::array_t eid_view() const { return to_pyarray_view(eid); } 22 | py::array_t ets_view() const { return to_pyarray_view(ets); } 23 | }; 24 | 25 | TCSR create_tcsr(py::array_t &edges, py::array_t ×, size_t num_nodes); 26 | 27 | } // namespace tglite 28 | -------------------------------------------------------------------------------- /include/tglite/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tglite/core.h" 4 | 5 | namespace tglite { 6 | 7 | // torch::Tensor index_pinned(torch::Tensor &input, torch::Tensor &index); 8 | 9 | // py::tuple find_last_message(py::array_t &uniq_nodes, py::array_t &edges); 10 | py::array_t find_latest_uniq(py::array_t &uniq, py::array_t &nodes, py::array_t ×); 11 | 12 | /// Custom hash function for collision-free keys. 
13 | inline int64_t opt_hash(int32_t &s, float &t) { 14 | return (static_cast(s) << 32) | static_cast(t); 15 | } 16 | 17 | template 18 | inline py::array_t to_pyarray_view(const T &seq) { 19 | if (seq.size() > 0) { 20 | auto capsule = py::capsule(&seq, [](void* p) { /* borrowed */ }); 21 | return py::array(seq.size(), seq.data(), capsule); 22 | } else { 23 | return py::array(); 24 | } 25 | } 26 | 27 | template 28 | inline py::array_t to_pyarray_copy(const T &seq) { 29 | if (seq.size() > 0) { 30 | T* copy_ptr = new T(seq); 31 | auto capsule = py::capsule(copy_ptr, [](void* p) { delete reinterpret_cast(p); }); 32 | return py::array(copy_ptr->size(), copy_ptr->data(), capsule); 33 | } else { 34 | return py::array(); 35 | } 36 | } 37 | 38 | template 39 | inline py::array_t to_pyarray_owned(T *seq_ptr) { 40 | if (seq_ptr && seq_ptr->size() > 0) { 41 | auto capsule = py::capsule(seq_ptr, [](void* p) { delete reinterpret_cast(p); }); 42 | return py::array(seq_ptr->size(), seq_ptr->data(), capsule); 43 | } else { 44 | return py::array(); 45 | } 46 | } 47 | 48 | } // namespace tglite 49 | -------------------------------------------------------------------------------- /lib/bind.cpp: -------------------------------------------------------------------------------- 1 | #include "tglite/core.h" 2 | #include "tglite/cache.h" 3 | #include "tglite/dedup.h" 4 | #include "tglite/sampler.h" 5 | #include "tglite/tcsr.h" 6 | #include "tglite/utils.h" 7 | 8 | using namespace tglite; 9 | 10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 11 | py::class_(m, "TCSR") 12 | .def(py::init<>()) 13 | .def_property_readonly("ind", &TCSR::ind_view) 14 | .def_property_readonly("nbr", &TCSR::nbr_view) 15 | .def_property_readonly("eid", &TCSR::eid_view) 16 | .def_property_readonly("ets", &TCSR::ets_view); 17 | 18 | py::class_(m, "TemporalBlock") 19 | .def(py::init<>()) 20 | .def("copy_eid", &TemporalBlock::eid_copy) 21 | .def("copy_ets", &TemporalBlock::ets_copy) 22 | .def("copy_srcnodes", 
&TemporalBlock::srcnodes_copy)
      .def("copy_dstindex", &TemporalBlock::dstindex_copy)
      .def("num_edges", [](const TemporalBlock &b) { return b.num_edges; });

  py::class_(m, "TemporalSampler")
      .def(py::init())
      .def("sample", &TemporalSampler::sample);

  py::class_(m, "EmbedTable")
      .def(py::init())
      // .def("_table", [](const EmbedTable &et) { return et._table; })
      // .def("_keys", [](const EmbedTable &et) { return et._keys; })
      // .def("_map", [](const EmbedTable &et) { return et._key2idx; })
      .def("lookup", &EmbedTable::lookup, "tglite::EmbedTable::lookup")
      .def("store", &EmbedTable::store, "tglite::EmbedTable::store");

  m.def("create_tcsr", &create_tcsr, "tglite::create_tcsr");
  m.def("dedup_targets", &dedup_targets, "tglite::dedup_targets");
  m.def("find_latest_uniq", &find_latest_uniq, "tglite::find_latest_uniq");
  // m.def("find_last_message", &find_last_message, "tglite::find_last_message");
  m.def("find_dedup_time_hits", &find_dedup_time_hits, "tglite::find_dedup_time_hits");
  m.def("compute_cache_keys", &compute_cache_keys, "tglite::compute_cache_keys");

  // m.def("dedup_indices", &dedup_indices, "tglite::dedup_indices");
  // m.def("index_pinned", &index_pinned, "tglite::index_pinned");
}


// ---- lib/cache.cpp ----
#include "tglite/cache.h"
#include "tglite/utils.h"

namespace tglite {

// For each unique time delta in `times`, checks whether a precomputed
// encoding exists in `time_table` (delta must be non-negative, integral,
// and within the time window). Returns (hit_count, hit mask, gathered
// embeddings, unique times, inverse index). Note: `times` is reassigned to
// its unique values, but the caller's tensor itself is not modified.
py::tuple find_dedup_time_hits(
    torch::Tensor &times,
    torch::Tensor &time_table,
    int time_window
) {
  auto tup = torch::_unique(times.flatten(), /*sorted=*/true, /*return_inverse=*/true);

  times = std::get<0>(tup);
  auto inv_idx = std::get<1>(tup);

  auto delta_int = times.to(torch::kInt64);
  // hit := integral-valued delta (cast round-trips exactly) within [0, window]
  auto hit_idx = (delta_int == times) & (0 <= delta_int) & (delta_int <= time_window);
  auto hit_delta = delta_int.index({hit_idx});

  int64_t hit_count = torch::sum(hit_idx).item().toLong();
  int64_t uniq_size = times.size(0);

  torch::Tensor out_embeds;
  if (hit_count == uniq_size) {
    // fast path: every delta is precomputed, gather directly
    out_embeds = time_table.index({hit_delta});
  } else {
    // scatter hits into a zero-filled buffer; misses stay zero for the
    // caller to fill in
    int64_t time_dim = time_table.size(1);
    auto opts = torch::TensorOptions().device(time_table.device());
    out_embeds = torch::zeros({uniq_size, time_dim}, opts);
    out_embeds.index_put_({hit_idx}, time_table.index({hit_delta}));
  }

  return py::make_tuple(hit_count, hit_idx, out_embeds, times, inv_idx);
}

// Hashes each (node, time) pair into an int64 cache key via opt_hash.
py::array_t compute_cache_keys(py::array_t &nodes, py::array_t &times) {
  ssize_t size = nodes.size();
  auto keys = py::array_t(size);
  auto keys_ptr = static_cast(keys.request().ptr);
  auto node_ptr = static_cast(nodes.request().ptr);
  auto time_ptr = static_cast(times.request().ptr);
  for (ssize_t i = 0; i < size; i++) {
    keys_ptr[i] = opt_hash(node_ptr[i], time_ptr[i]);
  }
  return keys;
}

// Starts small (<=1024 rows) and grows on demand up to `limit`.
EmbedTable::
EmbedTable(ssize_t dim_emb, ssize_t limit)
  : _dim_emb(dim_emb), _limit(limit) {
  ssize_t start_capacity = std::min((ssize_t)1024, limit);
  _table = torch::zeros({start_capacity, dim_emb}, torch::TensorOptions().dtype(torch::kFloat32));
  _key2idx.reserve(start_capacity);
  _keys.resize(start_capacity);
}

// Looks up keys; rows not found are zero. Returns (hit mask, embeddings),
// both moved to `device`.
py::tuple EmbedTable::
lookup(py::array_t &keys, torch::Device &device) {
  ssize_t size = keys.size();
  auto output = torch::zeros({size, _dim_emb});
  auto hit_idx = torch::zeros(size, torch::TensorOptions().dtype(torch::kBool));

  auto *keys_ptr = static_cast(keys.request().ptr);
  auto *hit_ptr = hit_idx.accessor().data();
  std::vector indices;
  indices.reserve(size);

  for (ssize_t i = 0; i < size; i++) {
    auto it = _key2idx.find(keys_ptr[i]);
    if (it != _key2idx.end()) {
      indices.push_back(it->second);
      hit_ptr[i] = true;
    }
  }

  // Gather the hit rows in one indexed copy; from_blob borrows `indices`
  // which stays alive for the duration of this call.
  output.index_put_({hit_idx}, _table.index(
      {torch::from_blob(indices.data(), indices.size(),
          torch::TensorOptions().dtype(torch::kLong))}));

  output = output.to(device);
  hit_idx = hit_idx.to(device);
  return py::make_tuple(hit_idx, output);
}

// Ring-buffer insert: grows the table (up to _limit) when needed, then
// overwrites slots starting at _start, evicting the keys previously there.
void EmbedTable::
store(py::array_t &keys, torch::Tensor &values) {
  ssize_t size = keys.size();
  auto *keys_ptr = static_cast(keys.request().ptr);
  auto embeds = values.detach().cpu();

  ssize_t nrows = _table.size(0);
  if (_start + size > nrows && nrows < _limit) {
    // double capacity (at least enough for this batch), capped at _limit
    ssize_t grow_size = std::max(_start + size, nrows * 2);
    grow_size = std::min(grow_size, _limit);
    _table.resize_({grow_size, _dim_emb});
    _keys.resize(grow_size);
    nrows = grow_size;
  }

  ssize_t inp_idx = 0;
  while (size > 0) {
    // number of slots available before wrapping around
    ssize_t nslots = std::min(nrows - _start, size);
    ssize_t inp_end = inp_idx + nslots;
    ssize_t out_end = _start + nslots;

    _table.index_put_({torch::indexing::Slice(_start, out_end)},
        embeds.index({torch::indexing::Slice(inp_idx, inp_end)}));

    for (ssize_t i = 0; i < nslots; i++) {
      ssize_t slot = _start + i;
      int64_t old_key = _keys[slot];
      int64_t new_key = keys_ptr[inp_idx + i];
      _key2idx.erase(old_key);   // evict whatever occupied this slot
      _key2idx.emplace(new_key, slot);
      _keys[slot] = new_key;
    }

    _start = out_end % nrows;
    inp_idx = inp_end;
    size -= nslots;
  }
}

} // namespace tglite


// ---- lib/dedup.cpp ----
#include "tglite/dedup.h"
#include "tglite/utils.h"

namespace tglite {

// Deduplicates (node, time) pairs keyed by opt_hash. Returns
// (has_dups, uniq_nodes, uniq_times, inv_idx) where inv_idx maps each input
// position to its row in the unique arrays.
// NOTE(review): correctness assumes opt_hash is collision-free for the
// observed (node, time) pairs — see the note on opt_hash in utils.h.
py::tuple dedup_targets(py::array_t &nodes, py::array_t &times) {
  ssize_t size = nodes.size();
  auto inv_idx = py::array_t(size);

  auto *node_ptr = static_cast(nodes.request().ptr);
  auto *time_ptr = static_cast(times.request().ptr);
  auto *inv_ptr = static_cast(inv_idx.request().ptr);

  std::unordered_map key2idx;
  // heap-allocated so to_pyarray_owned can hand ownership to numpy
  auto *uniq_node = new std::vector;
  auto *uniq_time = new std::vector;
  uniq_node->reserve(size);
  uniq_time->reserve(size);
  key2idx.reserve(size);

  bool has_dups = false;
  for (ssize_t i = 0; i < size; i++) {
    IdI32 nid = node_ptr[i];
    TsF32 nts = time_ptr[i];
    int64_t key = opt_hash(nid, nts);
    auto iter = key2idx.find(key);
    if (iter != key2idx.end()) {
      auto uniq_idx = iter->second;
      inv_ptr[i] = uniq_idx;
      has_dups = true;
    } else {
      auto idx = uniq_node->size();
      uniq_node->push_back(nid);
      uniq_time->push_back(nts);
      key2idx.emplace(key, idx);
      inv_ptr[i] = idx;
    }
  }

  py::array_t res_nodes = to_pyarray_owned(uniq_node);
  py::array_t res_times = to_pyarray_owned(uniq_time);
  return py::make_tuple(has_dups, res_nodes, res_times, inv_idx);
}

// (commented-out experimental variants elided for clarity: dedup_indices
//  over torch tensors, and a raw-pointer dedup_targets overload that merged
//  pre-existing unique arrays)

} // namespace tglite


// ---- lib/sampler.cpp ----
#include "tglite/sampler.h"
#include

namespace tglite {

TemporalSampler::
TemporalSampler(int num_threads, int num_nbrs, bool recent)
  : _num_threads(num_threads), _num_nbrs(num_nbrs), _recent(recent) { }

// Samples one layer of temporal neighbors for (nodes, times) targets.
TemporalBlock TemporalSampler::
sample(TCSR &tcsr, py::array_t &nodes, py::array_t &times) {
  omp_set_num_threads(_num_threads);

  const IdI32 *nodes_ptr = static_cast(nodes.request().ptr);
  const TsF32 *times_ptr = static_cast(times.request().ptr);
  size_t size = nodes.size();

  TemporalBlock block;
  sample_layer(tcsr, block, nodes_ptr, times_ptr, size);

  return block;
}

// Parallel sampling: each thread fills private COO buffers for its slice of
// target nodes, which combine_coo then merges.
void TemporalSampler::
sample_layer(TCSR &tcsr, TemporalBlock &block,
    const IdI32 *nodes_ptr, const TsF32 *times_ptr, size_t size) {
  // NOTE(review): variable-length arrays are a GNU extension, not standard
  // C++ — works under gcc/clang; _num_threads is fixed per sampler.
  std::vector *eid[_num_threads];
  std::vector *ets[_num_threads];
  std::vector *srcnodes[_num_threads];
  std::vector *dstindex[_num_threads];
  std::vector out_nodes(_num_threads, 0);

  int nodes_per_thread = int(ceil(static_cast(size) / _num_threads));
  int reserve_capacity = nodes_per_thread * _num_nbrs;

  #pragma omp parallel
  {
    int tid = omp_get_thread_num();
    // per-thread RNG seed keeps uniform sampling deterministic per thread
    unsigned int loc_seed = tid;

    eid[tid] = new std::vector;
    ets[tid] = new std::vector;
    srcnodes[tid] = new std::vector;
    dstindex[tid] = new std::vector;

    eid[tid]->reserve(reserve_capacity);
    ets[tid]->reserve(reserve_capacity);
    srcnodes[tid]->reserve(reserve_capacity);
    dstindex[tid]->reserve(reserve_capacity);

    #pragma omp for schedule(static, nodes_per_thread)
    for (size_t j = 0; j < size; j++) {
      IdI32 nid = nodes_ptr[j];
      TsF32 nts = times_ptr[j];

      // binary-search the node's time-sorted edge range for edges before nts
      IdI32 s_search = tcsr.ind[nid];
      auto e_it = std::lower_bound(tcsr.ets.begin() + s_search,
          tcsr.ets.begin() + tcsr.ind[nid + 1], nts);
      IdI32 e_search = std::max(int(e_it - tcsr.ets.begin()) - 1, s_search);

      if (_recent || (e_search - s_search + 1 < _num_nbrs)) {
        // take the most recent neighbors (also used when there are fewer
        // candidates than _num_nbrs)
        for (IdI32 k = e_search; k >= std::max(s_search, e_search - _num_nbrs + 1); k--) {
          // epsilon guard: only edges strictly earlier than the query time
          if (tcsr.ets[k] < nts - 1e-7f) {
            add_neighbor(tcsr, eid[tid], ets[tid], srcnodes[tid], dstindex[tid], k, out_nodes[tid]);
          }
        }
      } else {
        // uniform sampling with replacement over the candidate range
        for (int k = 0; k < _num_nbrs; k++) {
          IdI32 picked = s_search + rand_r(&loc_seed) % (e_search - s_search + 1);
          if (tcsr.ets[picked] < nts - 1e-7f) {
            add_neighbor(tcsr, eid[tid], ets[tid], srcnodes[tid], dstindex[tid], picked, out_nodes[tid]);
          }
        }
      }

      out_nodes[tid] += 1;
    }
  }

  combine_coo(block, eid, ets, srcnodes, dstindex, out_nodes);
}

// Appends one sampled edge to the thread-local COO buffers.
inline void TemporalSampler::
add_neighbor(TCSR &tcsr,
    std::vector *eid, std::vector *ets,
    std::vector *srcnodes, std::vector *dstindex,
    IdI32 &k, IdI32 &dst_idx) {
  eid->push_back(tcsr.eid[k]);
  ets->push_back(tcsr.ets[k]);
  srcnodes->push_back(tcsr.nbr[k]);
  dstindex->push_back(dst_idx);
}

// Merges per-thread buffers into `block`, offsetting each thread's local
// destination indices by a prefix sum of nodes handled by earlier threads.
inline void TemporalSampler::
combine_coo(TemporalBlock &block,
    std::vector **eid,
    std::vector **ets,
    std::vector **srcnodes,
    std::vector **dstindex,
    std::vector &out_nodes) {

  // exclusive prefix sums over node and edge counts per thread
  std::vector scan_nodes;
  std::vector scan_edges;
  scan_nodes.push_back(0);
  scan_edges.push_back(0);
  for (int tid = 0; tid < _num_threads; tid++) {
    scan_nodes.push_back(scan_nodes.back() + out_nodes[tid]);
    scan_edges.push_back(scan_edges.back() + eid[tid]->size());
  }

  IdI32 num_edges = scan_edges.back();
  block.dstindex.resize(num_edges);
  block.srcnodes.resize(num_edges);
  block.eid.resize(num_edges);
  block.ets.resize(num_edges);
  block.num_edges = num_edges;

  #pragma omp parallel for schedule(static, 1)
  for (int tid = 0; tid < _num_threads; tid++) {
    // shift local dst indices into the global numbering before copying
    std::transform(dstindex[tid]->begin(), dstindex[tid]->end(),
        dstindex[tid]->begin(), [&](auto &v) { return v + scan_nodes[tid]; });
    std::copy(eid[tid]->begin(), eid[tid]->end(), block.eid.begin() + scan_edges[tid]);
    std::copy(ets[tid]->begin(), ets[tid]->end(), block.ets.begin() + scan_edges[tid]);
    std::copy(srcnodes[tid]->begin(), srcnodes[tid]->end(), block.srcnodes.begin() + scan_edges[tid]);
    std::copy(dstindex[tid]->begin(), dstindex[tid]->end(), block.dstindex.begin() + scan_edges[tid]);
    delete eid[tid];
    delete ets[tid];
    delete srcnodes[tid];
    delete dstindex[tid];
  }
}

} // namespace tglite


// ---- lib/tcsr.cpp ----
#include "tglite/tcsr.h"

namespace tglite {

// One adjacency entry while building; sorted by timestamp per node.
struct ETuple {
  IdI32 nbr;
  IdI32 eid;
  TsF32 ets;
  static bool cmp_ts(const ETuple &a, const ETuple &b) {
    return a.ets < b.ets;
  }
};

// Builds a temporal CSR from an edge list; edges are inserted in BOTH
// directions (undirected adjacency), then each node's list is time-sorted.
TCSR create_tcsr(py::array_t &edges, py::array_t &times, size_t num_nodes) {
  auto *edges_ptr = static_cast(edges.request().ptr);
  auto *times_ptr = static_cast(times.request().ptr);

  std::vector> adj_list(num_nodes);
  for (IdI32 eid = 0; eid < edges.shape(0); eid++) {
    IdI32 src = edges_ptr[eid * 2];
    IdI32 dst = edges_ptr[eid * 2 + 1];
    TsF32 ets = times_ptr[eid];
    adj_list[src].push_back({dst, eid, ets});
    adj_list[dst].push_back({src, eid, ets});
  }

  TCSR tcsr;
  for (auto &adj : adj_list) {
    std::sort(adj.begin(), adj.end(), ETuple::cmp_ts);
    tcsr.ind.push_back(tcsr.ind.back() + adj.size());
    for (auto &tuple : adj) {
      tcsr.nbr.push_back(tuple.nbr);
      tcsr.eid.push_back(tuple.eid);
      tcsr.ets.push_back(tuple.ets);
    }
    // release the per-node buffer eagerly to cap peak memory
    adj.clear();
    adj.shrink_to_fit();
  }

  return tcsr;
}

} // namespace tglite


// ---- lib/utils.cpp ----
#include "tglite/utils.h"
#include

namespace tglite {

// (commented-out implementations elided for clarity: index_pinned — pinned
//  parallel gather of feature rows; find_last_message — per-node scan for
//  the most recent incident edge)

py::array_t
find_latest_uniq(py::array_t &uniq, py::array_t &nodes, py::array_t ×) { 68 | ssize_t num_uniq = uniq.size(); 69 | ssize_t num_nodes = nodes.size(); 70 | 71 | auto *index = new std::vector; 72 | index->resize(num_uniq); 73 | 74 | auto *uniq_ptr = static_cast(uniq.request().ptr); 75 | auto *node_ptr = static_cast(nodes.request().ptr); 76 | auto *time_ptr = static_cast(times.request().ptr); 77 | 78 | #pragma omp parallel for schedule(static) 79 | for (ssize_t i = 0; i < num_uniq; i++) { 80 | IdI32 nid = uniq_ptr[i]; 81 | TsF32 max = -1.0f; 82 | for (ssize_t j = num_nodes - 1; j >= 0; j--) { 83 | if (node_ptr[j] == nid && time_ptr[j] > max) { 84 | max = time_ptr[j]; 85 | (*index)[i] = j; 86 | } 87 | } 88 | } 89 | 90 | py::array_t res = to_pyarray_owned(index); 91 | return res; 92 | } 93 | 94 | } // namespace tglite 95 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "tglite" 3 | # dynamic = ["version"] 4 | version = "0.0.4" 5 | description = "Temporal GNN Lightweight Framework" 6 | readme = "README.md" 7 | license = {file = "LICENSE"} 8 | authors = [ 9 | {name = "Yufeng Wang", email = "yufengwang05@gmail.com"}, 10 | {name = "Charith Mendis", email = "charithm@illinois.edu"} 11 | ] 12 | maintainers = [ 13 | {name = "Wanyu Zhao", email = "wanyu2@illinois.edu"} 14 | ] 15 | keywords = [ 16 | "machine learning", "TGNN", 17 | ] 18 | classifiers = [ 19 | "Development Status :: 3 - Alpha", 20 | "Intended Audience :: Developers", 21 | "Intended Audience :: Science/Research", 22 | "License :: OSI Approved :: Apache Software License", 23 | "Operating System :: POSIX :: Linux", 24 | "Programming Language :: Python :: 3", 25 | "Programming Language :: Python :: 3.7", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | ] 30 | 
requires-python = ">=3.7, <3.11" 31 | dependencies = [ 32 | 'numpy==1.21.6; python_version == "3.7"', 33 | 'numpy>=1.21.6, <1.25.0; python_version == "3.8"', 34 | 'numpy>=1.21.6, <1.26.0; python_version == "3.9"', 35 | 'numpy>=1.21.6; python_version == "3.10"', 36 | 'torch>=1.12.1, <2.0.0; python_version == "3.7"', 37 | 'torch>=1.12.1; python_version >= "3.8"', 38 | 'torch-scatter>=2.1.0, <2.1.2; python_version == "3.7"', 39 | 'torch-scatter>=2.1.0; python_version == "3.8"', 40 | ] 41 | 42 | [project.optional-dependencies] 43 | dev = ["sphinx"] 44 | test = ["pytest", "pytest-cov"] 45 | docs = ["sphinx_rtd_theme", "nbsphinx", "ipykernel"] 46 | 47 | [project.urls] 48 | homepage = "https://github.com/ADAPT-uiuc/tglite" 49 | documentation = "https://readthedocs.org" 50 | repository = "https://github.com/ADAPT-uiuc/tglite" 51 | issues = "https://github.com/me/ADAPT-uiuc/tglite/issues" 52 | 53 | # [tool.setuptools.dynamic] 54 | # version = {attr = "tglite.__version__"} 55 | 56 | [tool.setuptools.packages.find] 57 | where = ["python"] 58 | include = ["tglite*"] 59 | namespaces = false 60 | 61 | [build-system] 62 | requires = ["setuptools>=61.0", "wheel", "torch>=1.12.1"] 63 | build-backend = "setuptools.build_meta" 64 | -------------------------------------------------------------------------------- /python/tglite/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | tglite: Temporal GNN Lightweight Framework 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | __version__ = "0.0.4" 8 | 9 | ### Re-exports 10 | 11 | import torch 12 | from ._graph import TGraph, from_csv 13 | from ._batch import TBatch 14 | from ._block import TBlock 15 | from ._frame import TFrame 16 | from ._memory import Memory 17 | from ._mailbox import Mailbox 18 | from ._sampler import TSampler 19 | from ._context import TContext 20 | from ._core import TError 21 | from . import _utils as utils 22 | from . import nn 23 | from . 
def iter_edges(g: TGraph, size=1, start=None, end=None) -> EdgesIter:
    """
    Create and return an iterator that yields the graph's edges as TBatch
    mini-batches.

    :param TGraph g: The graph to iterate on.
    :param int size: Number of edges in each mini-batch.
    :rtype: EdgesIter
    :param start: The starting edge index.
    :type start: int or None
    :param end: The ending edge index.
    :type end: int or None
    """
    return EdgesIter(g, size=size, start=start, end=end)


class EdgesIter(object):
    """ An edge iterator of a TGraph."""

    def __init__(self, g: TGraph, size=1, start=None, end=None):
        """
        Create an edge iterator.

        :param TGraph g: The graph it iterates on.
        :param int size: Number of edges in each mini-batch.
        :param start: The starting edge index (defaults to 0).
        :type start: int or None
        :param end: The ending edge index (defaults to the graph's edge count).
        :type end: int or None
        """
        self._g = g
        self._size = size
        self._cursor = start if start is not None else 0
        self._stop = end if end is not None else g.num_edges()

    def __iter__(self) -> EdgesIter:
        return self

    def __next__(self) -> TBatch:
        # Guard clause: the window is exhausted.
        if self._cursor >= self._stop:
            raise StopIteration
        beg = self._cursor
        self._cursor = beg + self._size
        # The final batch may be short; clamp its end to the stop index.
        return TBatch(self._g, range=(beg, min(self._cursor, self._stop)))
    def __init__(self, g: 'TGraph', range: Tuple[int, int]):
        """
        Internal constructor for creating a TBatch.

        :param TGraph g: The TGraph that owns the edges; edge/time arrays
            are sliced from it lazily by the accessors rather than copied here.
        :param Tuple[int, int] range: The (begin, end) pair of edge indices;
            begin is inclusive and end is exclusive (the accessors slice
            ``g._edges[begin:end]``).
        """
        self._g = g
        # Half-open edge-index window [beg, end) into the graph's edge list.
        self._beg_idx = range[0]
        self._end_idx = range[1]
        # Optional 1-D array of negative node samples; assigned through the
        # `neg_nodes` property setter, which validates type and shape.
        self._neg_nodes = None
    def block_adj(self, ctx: 'TContext') -> TBlock:
        """Creates the head TBlock with batch edges as neighbors (excluding negative nodes).

        Unlike :meth:`block`, the batch's own edges are installed directly
        as the 1-hop neighborhood instead of leaving sampling to a TSampler:

        - ``dstnodes`` is ``[srcs, dsts]`` (batch nodes, negatives excluded);
        - ``srcnodes`` is the reversed view ``[dsts, srcs]``, so position i
          pairs each destination with its counterpart on the same edge;
        - ``eids`` tiles the batch edge ids twice, once per endpoint/direction;
        - the batch edge timestamps serve as both destination times and
          neighbor-edge times.

        :param TContext ctx: The runtime context for the new block.
        :rtype: TBlock
        """
        dstnodes = self.nodes(include_negs=False)
        srcnodes = self.nodes(include_negs=False, reverse=True)
        # Identity mapping: the i-th neighbor attaches to the i-th dst node.
        dstindex = np.arange(len(dstnodes))
        # Each edge id appears twice: once for each of its endpoints.
        eids = np.tile(self.eids(), 2)
        ets = self.times(include_negs=False)
        return TBlock(ctx, 0, dstnodes, ets, dstindex, srcnodes, eids, ets)
99 | :rtype: np.ndarray 100 | """ 101 | nids = self.g._edges[self._beg_idx:self._end_idx] 102 | nids = np.flip(nids, axis=1) if reverse else nids 103 | nids = nids.T.reshape(-1) 104 | if self._neg_nodes is not None and include_negs: 105 | negs = nids[len(self):] if reverse else self._neg_nodes 106 | nids = np.concatenate([nids, negs]) 107 | return nids.astype(np.int32) 108 | 109 | def times(self, include_dsts=True, include_negs=True) -> np.ndarray: 110 | """ 111 | Returns timestamps corresponding to the nodes. It retrieves timestamps of the batch edges (as the timestamps for source nodes), 112 | repeating the timestamps to include destination nodes and negative nodes. 113 | 114 | :param bool include_dst: Whether to include destination nodes of the edges as positive nodes. 115 | :param bool include_negs: Whether to include negative nodes. 116 | """ 117 | n_repeats = 2 if include_dsts else 1 118 | if include_negs and self._neg_nodes is not None: 119 | n_repeats += 1 120 | times = self.g._times[self._beg_idx:self._end_idx] 121 | return np.tile(times, n_repeats).astype(np.float32) 122 | 123 | def split_data(self, data: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor]]: 124 | """ 125 | Splits the data into multiple arrays, with each array containing a number of rows equal to the batch size. 126 | 127 | :param Tensor data: The source data to be split. 128 | :raises TError: If the length of data is not three times the batch size when negative nodes are included or two times otherwise. 129 | :return: A tuple (src, dst, neg), where neg is None if no negative nodes are specified. 
130 | """ 131 | size = len(self) 132 | if self._neg_nodes is not None: 133 | if data.shape[0] != 3 * size: 134 | raise TError('expected data to have 3 times batch size') 135 | dst = data[size:2 * size] 136 | neg = data[2 * size:] 137 | else: 138 | if data.shape[0] != 2 * size: 139 | raise TError('expected data to have 2 times batch size') 140 | dst = data[size:] 141 | neg = None 142 | src = data[:size] 143 | return (src, dst, neg) -------------------------------------------------------------------------------- /python/tglite/_context.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, List 2 | if TYPE_CHECKING: 3 | from ._graph import TGraph 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from ._core import TError 9 | 10 | 11 | class TContext(object): 12 | """Graph-level context and scratch space used by the tglite runtime.""" 13 | 14 | def __init__(self, g: 'TGraph'): 15 | """ 16 | Internal constructor for creating a TContext. 17 | 18 | :param TGraph g: The TGraph to operate on. 
19 | """ 20 | self._g = g 21 | self._training = True 22 | 23 | # pin buffers 24 | self._efeat_pins = {} 25 | self._nfeat_pins = {} 26 | self._mem_data_pins = {} 27 | self._mail_pins = {} 28 | 29 | # embed caching 30 | self._cache_enabled = False 31 | self._cache_dim_emb = None 32 | self._cache_limit = int(2e6) 33 | self._cache_tables = {} 34 | 35 | # time precomputation 36 | self._time_enabled = False 37 | self._time_window = int(1e4) 38 | self._time_tables = {} 39 | 40 | @property 41 | def graph(self) -> 'TGraph': 42 | """Returns the TGraph it associated with.""" 43 | return self._g 44 | 45 | def train(self): 46 | """Enables training mode and clear time tables and embedding cache tables.""" 47 | self._training = True 48 | self._time_tables.clear() 49 | self._cache_tables.clear() 50 | 51 | def eval(self): 52 | """Disables training mode.""" 53 | self._training = False 54 | 55 | def need_sampling(self, need: bool): 56 | """ 57 | Creates tcsr within the TGraph if sampling is needed. 58 | 59 | :param bool need: Whether sampling is required. 60 | """ 61 | if need: 62 | self._g._init_tcsr() 63 | else: 64 | self._g._tcsr = None 65 | 66 | def enable_embed_caching(self, enabled: bool, dim_embed: int = None): 67 | """ 68 | Performs embedding cache settings and clear cache tables. 69 | 70 | :param bool enabled: Whether to enable embedding caching. 71 | :param dim_embed: Dimension of node embeddings. 72 | :raises TError: If enable is True and dim_embed is None. 73 | """ 74 | self._cache_enabled = enabled 75 | if enabled and dim_embed is None: 76 | raise TError('need dimension of embeddings') 77 | elif enabled: 78 | self._cache_dim_emb = dim_embed 79 | self._cache_tables.clear() 80 | 81 | def set_cache_limit(self, limit: int): 82 | """ 83 | Sets embedding cache limit and clear cache tables. 84 | 85 | :param int limit: Number of embeddings to cache. 
86 | """ 87 | self._cache_limit = limit 88 | self._cache_tables.clear() 89 | 90 | def enable_time_precompute(self, enabled: bool): 91 | """ 92 | Performs time precomputation settings and clear time tables. 93 | 94 | :param bool enabled: Whether to enable embedding caching. 95 | """ 96 | self._time_enabled = enabled 97 | self._time_tables.clear() 98 | 99 | def set_time_window(self, window: int): 100 | """ 101 | Sets length of time window and clear time tables. 102 | 103 | :param int window: Length of time window. 104 | :raises TError: If int is negative. 105 | """ 106 | 107 | if window < 0: 108 | raise TError('time window must be non-negative') 109 | self._time_window = window 110 | self._time_tables.clear() 111 | 112 | def _get_efeat_pin(self, layer: int, rows: int, dim: int) -> Tensor: 113 | return self._get_pin(self._efeat_pins, layer, rows, [dim]) 114 | 115 | def _get_nfeat_pin(self, layer: int, rows: int, dim: int) -> Tensor: 116 | return self._get_pin(self._nfeat_pins, layer, rows, [dim]) 117 | 118 | def _get_mem_data_pin(self, layer: int, rows: int) -> Tensor: 119 | return self._get_pin(self._mem_data_pins, layer, rows, [self._g.mem.dim()]) 120 | 121 | def _get_mail_pin(self, layer: int, rows: int) -> Tensor: 122 | return self._get_pin(self._mail_pins, layer, rows, self._g.mailbox.dims()) 123 | 124 | def _get_pin(self, cache: dict, layer: int, rows: int, dims: List[int]) -> Tensor: 125 | """ 126 | Creates/reshapes pinned buffer and returns it. 127 | 128 | :param dict cache: The dictionary of the buffer with layer numbers as keys. 129 | :param int layer: Which layer pinned data belongs to. 130 | :param int rows: Number of rows of pinned data. 131 | :param List[int] dims: Size of each pinned data. 132 | :return: Pinned buffer. 
class TFrame(dict):
    """A dict-backed container for tensor features that share one leading dimension."""

    def __init__(self, dim=None):
        """Initialize a TFrame object

        :param dim: leading dimension required of every stored tensor
            (defaults to 0 when not given)
        """
        super().__init__()
        self._dim = dim if dim is not None else 0

    def dim(self) -> int:
        """Get the leading dimension of stored tensors"""
        return self._dim

    def to(self, device, **kwargs) -> TFrame:
        """Move tensor data in this frame to the specified device.

        :param device: the device to which the tensor data should be moved
        :param **kwargs: additional keyword arguments for the pytorch Tensor.to() method
        :returns: a new TFrame object with the tensor data copied to the specified device
        """
        moved = TFrame(dim=self._dim)
        for name in self:
            moved[name] = self[name].to(device, **kwargs)
        return moved

    def __setitem__(self, key, value):
        """
        Set an item in the TFrame

        :raises TError: if the value is not a tensor or if it does not have the expected leading dimension
        """
        if not isinstance(value, Tensor):
            raise TError("expected value to be a tensor")
        if len(value) != self._dim:
            raise TError(f"expected value to have leading dimension of {self._dim}, got {len(value)}")
        super().__setitem__(key, value)
    def __init__(self, edges: np.ndarray, times: np.ndarray, num_nodes: int = None):
        """
        Internal constructor for creating a TGraph

        :param np.ndarray edges: (E, 2) int32 array of (src, dst) node ids,
            one row per temporal edge (validated by check_edges_times).
        :param np.ndarray times: (E,) float32 array of edge timestamps,
            aligned row-for-row with edges.
        :param int num_nodes: total number of nodes; when None it is
            inferred as max node id + 1 (see check_num_nodes).
        """
        check_edges_times(edges, times)
        self._num_nodes = check_num_nodes(edges, num_nodes)
        # Feature frames pin their leading dimension to the edge/node counts.
        self._efeat_frame = TFrame(dim=edges.shape[0])
        self._nfeat_frame = TFrame(dim=self._num_nodes)
        self._edata = TFrame(dim=edges.shape[0])
        self._ndata = TFrame(dim=self._num_nodes)
        self._edges = edges
        self._times = times
        # CSR form is built lazily (see _init_tcsr) only when sampling needs it.
        self._tcsr = None
        # Optional node memory / mailbox modules, attached via their setters.
        self._mem = None
        self._mailbox = None
        # Tensor storage and computation both start on CPU until moved.
        self._storage_dev = torch.device('cpu')
        self._compute_dev = torch.device('cpu')
93 | Sets node memory 94 | 95 | :param Memory value: node memory to set 96 | :raises TError: if value is not a Memory instance or its length doesn't equal to number of nodes, 97 | or value is not on this TGraph's storage device. 98 | """ 99 | if not isinstance(value, Memory): 100 | raise TError('invalid memory object') 101 | if len(value) != self._num_nodes: 102 | raise TError('memory number of nodes mismatch') 103 | if value.device != self._storage_dev: 104 | raise TError('memory storage device mismatch') 105 | self._mem = value 106 | 107 | @property 108 | def mailbox(self) -> Optional[Mailbox]: 109 | """Returns node mailbox""" 110 | return self._mailbox 111 | 112 | @mailbox.setter 113 | def mailbox(self, value: Mailbox): 114 | """ 115 | Sets mailbox 116 | 117 | :param Mailbox value: mailbox to set 118 | :raises TError: if value is not a Mailbox instance or its length doesn't equal to number of nodes, 119 | or value is not on this TGraph's storage device. 120 | """ 121 | if not isinstance(value, Mailbox): 122 | raise TError('invalid mailbox object') 123 | if value.device != self._storage_dev: 124 | raise TError('mailbox storage device mismatch') 125 | # ... more checks here ... 
126 | self._mailbox = value 127 | 128 | def storage_device(self) -> torch.device: 129 | """Returns TGraph's storage device""" 130 | return self._storage_dev 131 | 132 | def compute_device(self) -> torch.device: 133 | """Returns TGraph's computing device""" 134 | return self._compute_dev 135 | 136 | def num_nodes(self) -> int: 137 | """ 138 | Total number of nodes 139 | 140 | :rtype: int 141 | """ 142 | return self._num_nodes 143 | 144 | def num_edges(self) -> int: 145 | """ 146 | Total number of edges 147 | 148 | :rtype: int 149 | """ 150 | return self._edges.shape[0] 151 | 152 | def set_compute(self, device): 153 | """Sets computing device""" 154 | self._compute_dev = torch.device(device) 155 | 156 | def move_data(self, device, **kwargs): 157 | """Moves tensor data to device while keeping graph on CPU""" 158 | if self._storage_dev == device: 159 | return 160 | self._efeat_frame = self._efeat_frame.to(device, **kwargs) 161 | self._nfeat_frame = self._nfeat_frame.to(device, **kwargs) 162 | self._edata = self._edata.to(device, **kwargs), 163 | self._ndata = self._ndata.to(device, **kwargs), 164 | if self._mem is not None: 165 | self._mem.move_to(device, **kwargs) 166 | if self._mailbox is not None: 167 | self._mailbox.move_to(device, **kwargs) 168 | self._storage_dev = device 169 | 170 | def _init_tcsr(self): 171 | """Creates tcsr of the graph if it doesn't exist""" 172 | if self._tcsr is None: 173 | self._tcsr = create_tcsr(self._edges, self._times, num_nodes=self._num_nodes) 174 | 175 | def _get_tcsr(self): 176 | """Returns the tcsr of the graph""" 177 | self._init_tcsr() 178 | return self._tcsr 179 | 180 | 181 | def from_csv(path: Union[str, Path], skip_first=True) -> TGraph: 182 | """ 183 | Creates a TGraph from a csv file 184 | 185 | :param path: csv file path 186 | :type path: str or Path 187 | :param bool skip_first: whether to skip the first line 188 | :rtype: TGraph 189 | :raises TError: if path doesn't exist 190 | """ 191 | src, dst, ts = [], [], [] 192 | 
class Mailbox(object):
    """A container for node mailbox messages.

    Keeps the most recent ``size`` messages (with timestamps) per node.
    With ``size == 1`` the storage is squeezed to (N, dim)/(N,) tensors and
    writes simply overwrite; with ``size > 1`` a per-node cursor (``_next``)
    turns each node's row into a circular buffer.
    """

    def __init__(self, num_nodes: int, size: int, dim: int, device=None):
        """
        Create a zero-initialized mailbox.

        :param int num_nodes: number of nodes with a mailbox slot.
        :param int size: number of messages kept per node.
        :param int dim: feature dimension of a single message.
        :param device: where to allocate storage; defaults to CPU.
        """
        self._size = size
        self._device = torch.device('cpu' if device is None else device)

        # squeeze(dim=1) collapses the size axis only when size == 1, so
        # mail is (N, dim) in that case and (N, size, dim) otherwise
        # (likewise time: (N,) vs (N, size)).
        self._mail = torch.zeros((num_nodes, size, dim), device=self._device).squeeze(dim=1)
        self._time = torch.zeros((num_nodes, size), device=self._device).squeeze(dim=1)
        if size > 1:
            # Per-node write cursor into the circular buffer.
            self._next = torch.zeros(num_nodes, dtype=torch.long, device=self._device)

    @property
    def mail(self) -> Tensor:
        # Message storage: (N, dim) when size == 1, else (N, size, dim).
        return self._mail

    @property
    def time(self) -> Tensor:
        # Message timestamps: (N,) when size == 1, else (N, size).
        return self._time

    @property
    def device(self) -> torch.device:
        # Device on which the mailbox tensors live.
        return self._device

    def dims(self) -> List[int]:
        """Per-node shape of the mail storage (everything after the node axis)."""
        return list(self.mail.shape[1:])

    def reset(self):
        """Zero out all messages and timestamps (write cursors are left as-is)."""
        self._mail.zero_()
        self._time.zero_()

    def store(self, nids: Union[np.ndarray, Tensor], mail: Tensor, mail_ts: Tensor):
        """
        Write one message (and its timestamp) per node id.

        :param nids: node ids to write to, as ndarray or tensor.
        :param Tensor mail: message rows, aligned with nids.
        :param Tensor mail_ts: message timestamps, aligned with nids.

        NOTE(review): with duplicate ids in ``nids`` the advanced-index
        assignment keeps an arbitrary duplicate's value — presumably
        callers pass unique ids; confirm against call sites.
        """
        if not isinstance(nids, Tensor):
            nids = torch.from_numpy(nids).long()
        nids = nids.to(self._device)
        # detach(): stored messages must not keep autograd history alive.
        mail = mail.detach().to(self._device)
        mail_ts = mail_ts.detach().to(self._device)
        if self._size == 1:
            self._mail[nids] = mail
            self._time[nids] = mail_ts
        else:
            # Circular-buffer write at each node's cursor, then advance
            # the cursor modulo the buffer size.
            pos = self._next[nids]
            self._mail[nids, pos] = mail
            self._time[nids, pos] = mail_ts
            self._next[nids] = torch.remainder(pos + 1, self._size)

    def move_to(self, device, **kwargs):
        """
        Move mailbox tensors to another device (no-op when already there).

        :param device: target device; None is treated as a no-op.
        :param **kwargs: extra keyword arguments forwarded to Tensor.to().
        """
        if device is None or self._device == device:
            return
        self._mail = self._mail.to(device, **kwargs)
        self._time = self._time.to(device, **kwargs)
        # NOTE(review): _next (present when size > 1) is not moved here and
        # stays on the previous device — flagging for confirmation rather
        # than changing behavior.
        self._device = device
15 | 16 | :param int num_nodes: Number of nodes 17 | :param int dim: Length of memory vector for a single node 18 | :param device: Which device to put node memory 19 | :type device: None or str or torch.device 20 | """ 21 | self._device = torch.device('cpu' if device is None else device) 22 | 23 | self._data = torch.zeros((num_nodes, dim), device=self._device) 24 | self._time = torch.zeros(num_nodes, device=self._device) 25 | 26 | if list(self._data.shape) != [num_nodes, dim]: 27 | raise TError('memory data dimension mismatch') 28 | if self._time.shape[0] != num_nodes: 29 | raise TError('memory timestamp dimension mismatch') 30 | 31 | def __len__(self) -> int: 32 | """ 33 | Return number of nodes in Memory 34 | 35 | :rtype: int 36 | """ 37 | return self._data.shape[0] 38 | 39 | @property 40 | def data(self) -> Tensor: 41 | """Return node memory 42 | 43 | :rtype: Tensor 44 | """ 45 | return self._data 46 | 47 | @property 48 | def time(self) -> Tensor: 49 | """Return timestamps of current node memory 50 | 51 | :rtype: Tensor 52 | """ 53 | return self._time 54 | 55 | @property 56 | def device(self) -> torch.device: 57 | """Return the device where Memory is located 58 | 59 | :rtype: torch.device 60 | """ 61 | return self._device 62 | 63 | def dim(self) -> int: 64 | """Return length of memory vector for a single node""" 65 | return self._data.shape[1] 66 | 67 | def reset(self): 68 | """Reset node memory and timestamps to zero""" 69 | self._data.zero_() 70 | self._time.zero_() 71 | 72 | def update(self, nids: Union[np.ndarray, Tensor], newdata: Tensor, newtime: Tensor): 73 | if not isinstance(nids, Tensor): 74 | nids = torch.from_numpy(nids).long() 75 | nids = nids.to(self._device) 76 | self._data[nids] = newdata.detach().to(self._device) 77 | self._time[nids] = newtime.detach().to(self._device) 78 | 79 | def move_to(self, device, **kwargs): 80 | if device is None or self._device == device: 81 | return 82 | self._data = self._data.to(device, **kwargs) 83 | self._time = 
self._time.to(device, **kwargs) 84 | self._device = device 85 | 86 | def backup(self) -> Tuple[Tensor, Tensor]: 87 | return (self._data.cpu().clone(), self._time.cpu().clone()) 88 | 89 | def restore(self, state: Tuple[Tensor, Tensor]): 90 | data, time = state 91 | if self._data.shape != data.shape: 92 | raise TError('memory data dimension mismatch') 93 | if self._time.shape != time.shape: 94 | raise TError('memory timestamp dimension mismatch') 95 | self._data = data.clone().to(self._device) 96 | self._time = time.clone().to(self._device) 97 | -------------------------------------------------------------------------------- /python/tglite/_sampler.py: -------------------------------------------------------------------------------- 1 | from . import _c 2 | from ._core import TError 3 | from ._block import TBlock 4 | from ._utils import get_num_cpus 5 | from ._stats import tt 6 | 7 | 8 | class TSampler(object): 9 | 10 | def __init__(self, num_nbrs: int, strategy='recent', num_threads: int = None): 11 | """ 12 | Internal constructor for creating a TSampler 13 | 14 | :param int num_nbrs: number of neighbors 15 | :param str strategy: sampling strategy, 'recent' or 'uniform' 16 | :param int num_threads: number of threads for parallel sampling, set to number of cpus if not provided 17 | :raises TError: if strategy is not in ['recent', 'uniform'] 18 | """ 19 | 20 | if strategy not in ['recent', 'uniform']: 21 | raise TError(f'sampling strategy not supported: {strategy}') 22 | 23 | self._n_nbrs = num_nbrs 24 | self._strategy = strategy 25 | self._n_threads = get_num_cpus() \ 26 | if num_threads is None else num_threads 27 | 28 | self._sampler = _c.TemporalSampler( 29 | self._n_threads, 30 | self._n_nbrs, 31 | self._strategy == 'recent') 32 | 33 | def sample(self, blk: TBlock) -> TBlock: 34 | """Updates block with sampled 1-hop source neighbors 35 | 36 | :returns: updated block 37 | """ 38 | t_start = tt.start() 39 | if blk.num_dst() > 0: 40 | block = 
self._sampler.sample(blk._g._get_tcsr(), blk._dstnodes, blk._dsttimes) 41 | blk.set_nbrs( 42 | block.copy_dstindex(), 43 | block.copy_srcnodes(), 44 | block.copy_eid(), 45 | block.copy_ets()) 46 | tt.t_sample += tt.elapsed(t_start) 47 | return blk 48 | -------------------------------------------------------------------------------- /python/tglite/_stats.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | 5 | class TimeTable(object): 6 | def __init__(self): 7 | self.csv = None 8 | self.reset_epoch() 9 | 10 | def reset_epoch(self): 11 | """Set all time records to zeros""" 12 | self.t_epoch = 0.0 13 | self.t_loop = 0.0 14 | self.t_eval = 0.0 15 | self.t_forward = 0.0 16 | self.t_backward = 0.0 17 | self.t_sample = 0.0 18 | self.t_prep_batch = 0.0 19 | self.t_prep_input = 0.0 20 | self.t_post_update = 0.0 21 | self.t_mem_update = 0.0 22 | self.t_time_zero = 0.0 23 | self.t_time_nbrs = 0.0 24 | self.t_self_attn = 0.0 25 | 26 | def start(self): 27 | # Uncomment for better breakdown timings 28 | #torch.cuda.synchronize() 29 | return time.perf_counter() 30 | 31 | def elapsed(self, start): 32 | # Uncomment for better breakdown timings 33 | #torch.cuda.synchronize() 34 | return time.perf_counter() - start 35 | 36 | def print_epoch(self, prefix=' '): 37 | """Print the timing breakdown of different components in an epoch""" 38 | lines = f'' \ 39 | f'{prefix}epoch | total:{self.t_epoch:.2f}s loop:{self.t_loop:.2f}s eval:{self.t_eval:.2f}s\n' \ 40 | f'{prefix} loop | forward:{self.t_forward:.2f}s backward:{self.t_backward:.2f}s sample:{self.t_sample:.2f}s prep_batch:{self.t_prep_batch:.2f}s prep_input:{self.t_prep_input:.2f}s post_update:{self.t_post_update:.2f}s\n' \ 41 | f'{prefix} comp | mem_update:{self.t_mem_update:.2f}s time_zero:{self.t_time_zero:.2f}s time_nbrs:{self.t_time_nbrs:.2f}s self_attn:{self.t_self_attn:.2f}s\n' 42 | print(lines, end='') 43 | 44 | def csv_open(self, path): 45 | 
"""Close the opened file (if any) and open a new file in write mode""" 46 | self.csv_close() 47 | self.csv = Path(path).open('w') 48 | 49 | def csv_close(self): 50 | """Close the opened file (if any)""" 51 | if self.csv is not None: 52 | self.csv.close() 53 | self.csv = None 54 | 55 | def csv_write_header(self): 56 | """Write the header line to the CSV file""" 57 | header = 'epoch,total,loop,eval,' \ 58 | 'forward,backward,sample,prep_batch,prep_input,post_update,' \ 59 | 'mem_update,time_zero,time_nbrs,self_attn' 60 | self.csv.write(header + '\n') 61 | 62 | def csv_write_line(self, epoch): 63 | """Write a line of timing information to the CSV file""" 64 | line = f'{epoch},{self.t_epoch},{self.t_loop},{self.t_eval},' \ 65 | f'{self.t_forward},{self.t_backward},{self.t_sample},{self.t_prep_batch},{self.t_prep_input},{self.t_post_update},' \ 66 | f'{self.t_mem_update},{self.t_time_zero},{self.t_time_nbrs},{self.t_self_attn}' 67 | self.csv.write(line + '\n') 68 | 69 | 70 | # Global for accumulating timings. 71 | tt = TimeTable() 72 | -------------------------------------------------------------------------------- /python/tglite/_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from . 
def get_num_cpus(default=16) -> int:
    """Return the number of CPUs available for parallel work.

    :param default: fallback count when the platform cannot report one
        (``os.cpu_count()`` may return None)
    """
    cpus = os.cpu_count()
    return default if cpus is None else cpus


def check_edges_times(edges: np.ndarray, times: np.ndarray):
    """Validate that *edges* and *times* form a well-formed temporal edge list.

    :param edges: int32 array of shape (num_edges, 2) holding (src, dst) ids
    :param times: float32 array of shape (num_edges,) holding edge timestamps
    :raises TError: on mismatched lengths, wrong column count, or wrong dtypes
    """
    if edges.shape[0] != times.shape[0]:
        raise TError("edge list and timestamps must have same leading dimension")
    if edges.shape[1] != 2:
        raise TError("edge list must have only 2 columns")
    if edges.dtype != np.int32:
        raise TError("currently only supports int32 node/edge ids")
    if times.dtype != np.float32:
        raise TError("currently only supports float32 timestamps")


def check_num_nodes(edges: np.ndarray, num_nodes: int = None) -> int:
    """Returns the number of nodes in the graph represented by the given edges.

    When *num_nodes* is None it is inferred as ``max node id + 1``.

    :raises TError: if the specified number of nodes is less than or equal to
        the maximum node id present in the edges
    """
    max_nid = int(edges.max())
    if num_nodes is None:
        num_nodes = max_nid + 1
    if num_nodes <= max_nid:
        raise TError("number of nodes must be greater than max node id")
    return num_nodes


def create_tcsr(edges: np.ndarray, times: np.ndarray, num_nodes: int = None):
    """Validate inputs and build the native temporal-CSR structure.

    :param edges: int32 array of shape (num_edges, 2)
    :param times: float32 array of shape (num_edges,)
    :param num_nodes: optional explicit node count (inferred when None)
    :raises TError: if the inputs fail validation
    """
    check_edges_times(edges, times)
    num_nodes = check_num_nodes(edges, num_nodes)
    return _c.create_tcsr(edges, times, num_nodes)
class TimeEncode(torch.nn.Module):
    """Encodes scalar time deltas into a ``dim_time``-dimensional feature via
    cosine features with fixed geometric frequencies."""

    # A hidden tag used to detect if we have an instance of this builtin encoder
    # so that we can call the more optimized method of generating zeros, without
    # running into situation with circular dependencies.
    __tg_builtin_encoder__ = True

    def __init__(self, dim_time: int):
        '''
        Initializes the TimeEncode module, which encodes time information into a
        higher-dimensional space.

        :param dim_time: dimensionality of the encoded time
        '''
        super().__init__()
        self.w = torch.nn.Linear(1, dim_time)
        # Geometrically spaced frequencies 1 / 10^k for k in [0, 9]; the bias
        # starts at zero, so a time delta of 0 encodes to cos(0) = all ones.
        self.w.weight = torch.nn.Parameter(torch
            .from_numpy(1 / 10 ** np.linspace(0, 9, dim_time))
            .float().reshape(dim_time, 1))
        self.w.bias = torch.nn.Parameter(torch.zeros(dim_time).float())
        # Cached scalar zero, lazily moved to the requested device by zeros().
        self._z = torch.zeros(1).float()

    def zeros(self, size: int, device) -> Tensor:
        '''
        Generates the time encoding for a batch of zero time deltas.

        :param size: number of zero entries to encode (leading dimension)
        :param device: device on which the encoding should be computed
        :return: encoded zeros of shape (size, dim_time)
        '''
        if self._z.device != torch.device(device):
            self._z = self._z.to(device)
        # expand does not allocate memory
        view = self._z.expand(size)
        return self(view)

    def forward(self, ts: Tensor) -> Tensor:
        '''
        Forward pass of the TimeEncode module. Encodes the input time stamps
        into a high-dimensional space.

        :param ts: input time stamps of shape (...,)
        :return: encoded times of shape (..., dim_time)
        '''
        return torch.cos(self.w(ts.unsqueeze(-1)))
class TemporalAttnLayer(torch.nn.Module):
    def __init__(self, ctx: TContext, num_heads: int,
                 dim_node: int, dim_edge: int, dim_time: int, dim_out: int,
                 dropout=0.1):
        """
        Initializes the Temporal Attention Layer for processing dynamic graphs with temporal features.
        This layer uses multi-head attention mechanism to incorporate node, edge, and time features.

        :param ctx: context object
        :param num_heads: number of heads
        :param dim_node: dimension of node features
        :param dim_edge: dimension of edge features
        :param dim_time: dimension of time features
        :param dim_out: dimension of output features
        :param dropout: dropout rate
        """
        super().__init__()
        # Each head receives an equal slice of the output dimension.
        assert (dim_out % num_heads == 0)
        self.ctx = ctx
        self.num_heads = num_heads
        self.dim_edge = dim_edge
        self.dim_out = dim_out
        self.time_encode = TimeEncode(dim_time)
        # Query is built from destination features + time encoding; key/value
        # from source features (+ optional edge features) + time encoding.
        # w_kv produces K and V concatenated; forward() splits them apart.
        self.w_q = torch.nn.Linear(dim_node + dim_time, dim_out)
        self.w_kv = torch.nn.Linear(dim_node + dim_edge + dim_time, dim_out * 2)
        self.w_out = torch.nn.Linear(dim_node + dim_out, dim_out)
        self.attn_act = torch.nn.LeakyReLU(0.2)
        self.dropout = torch.nn.Dropout(dropout)
        self.layer_norm = torch.nn.LayerNorm(dim_out)

    def forward(self, blk: TBlock) -> Tensor:
        '''
        Forward pass of the Temporal Attention Layer. Applies a time-sensitive attention mechanism over
        the input graph block (blk) to produce node embeddings. The method handles both cases of graph blocks
        with and without edges.

        If the block has no edges, a zero-initialized tensor is concatenated with the destination node features.
        For blocks with edges, the method computes attention scores and aggregates neighbor features using
        the computed attention.

        :param blk: input graph block
        '''
        if blk.num_edges() == 0:
            # No neighbors in this block: the attention output is all zeros,
            # so only the destinations' own features feed the output layer.
            dev = blk.dstdata['h'].device
            out = torch.zeros(blk.num_dst(), self.dim_out, dtype=torch.float32, device=dev)
            out = torch.cat([out, blk.dstdata['h']], dim=1)
        else:
            # Time features: zero deltas for destinations, actual deltas for
            # neighbor edges; both may be served from precomputed caches via
            # the context (hence the timing counters around each call).
            t_start = tt.start()
            zero_time_feat = precomputed_zeros(self.ctx, blk.layer, self.time_encode, blk.num_dst())
            tt.t_time_zero += tt.elapsed(t_start)
            t_start = tt.start()
            nbrs_time_feat = precomputed_times(self.ctx, blk.layer, self.time_encode, blk.time_deltas())
            tt.t_time_nbrs += tt.elapsed(t_start)
            t_start = tt.start()
            Q = torch.cat([blk.dstdata['h'], zero_time_feat], dim=1)
            if self.dim_edge > 0:
                Z = torch.cat([blk.srcdata['h'], blk.efeat(), nbrs_time_feat], dim=1)
            else:
                Z = torch.cat([blk.srcdata['h'], nbrs_time_feat], dim=1)
            # Free the time features eagerly to reduce peak memory.
            del zero_time_feat
            del nbrs_time_feat

            Q = self.w_q(Q)
            Z = self.w_kv(Z)
            # Split the fused key/value projection into its two halves.
            K = Z[:, :self.dim_out]
            V = Z[:, self.dim_out:]
            del Z

            # Broadcast per-destination queries onto edges, then reshape all
            # three projections to (num_edges, num_heads, dim_out / num_heads).
            Q = edge_view(blk, Q)
            Q = torch.reshape(Q, (Q.shape[0], self.num_heads, -1))
            K = torch.reshape(K, (K.shape[0], self.num_heads, -1))
            V = torch.reshape(V, (V.shape[0], self.num_heads, -1))

            # Per-edge, per-head attention logits via dot product.
            # NOTE(review): logits are not scaled by sqrt(d_head) here —
            # confirm this matches the intended attention formulation.
            attn = torch.sum(Q * K, dim=2)
            del Q
            del K

            # Normalize logits over each destination's incoming edges.
            attn = self.attn_act(attn)
            attn = edge_softmax(blk, attn)
            attn = self.dropout(attn)

            # Weight values by attention and merge heads back into one axis.
            out = torch.reshape(V * attn[:, :, None], (V.shape[0], -1))
            del attn

            # Sum the weighted neighbor messages per destination node, then
            # append the destinations' own features for the output projection.
            out = edge_reduce(blk, out, op='sum')
            out = torch.cat([out, blk.dstdata['h']], dim=1)
            tt.t_self_attn += tt.elapsed(t_start)

        out = self.w_out(out)
        out = torch.nn.functional.relu(self.dropout(out))
        out = self.layer_norm(out)
        return out
#!/usr/bin/env bash
#
# Run experiments for larger benchmarks.
#

# Resolve the repository root from this script's location, then run from
# the examples directory (the per-model exp scripts live under exp/).
tglite="$(cd "$(dirname "$0")"; cd ..; pwd)"
cd "$tglite/examples"

echo "start: $(date)"

for data in wiki-talk gdelt; do
  for model in apan jodie tgat tgn; do
    echo "tglite $data $model";
    # jodie gets no threading/optimization flags; the other models run
    # with all optimizations and one thread per available CPU.
    if [[ "$model" != "jodie" ]]; then
      "./exp/$model-gdelt.sh" --data "$data" --n-threads "$(nproc)" --opt-all
    else
      "./exp/$model-gdelt.sh" --data "$data"
    fi
    # Each run writes out-stats.csv; rename it so the next run can't clobber it.
    mv out-stats.csv "out-tglite-$data-$model.csv";
    echo;
    echo "time: $(date)"
    echo;
  done
done

echo "end: $(date)"
#!/usr/bin/env bash
#
# Run experiments for standard benchmarks.
#

# Resolve the repository root from this script's location, then run from
# the examples directory (the per-model exp scripts live under exp/).
tglite="$(cd "$(dirname "$0")"; cd ..; pwd)"
cd "$tglite/examples"

echo "start: $(date)"

for data in wiki mooc reddit lastfm; do
  for model in apan jodie tgat tgn; do
    # NOTE: $common_flags is expanded unquoted below on purpose, so that
    # it word-splits into separate arguments.
    common_flags="--data $data"
    if [[ "$model" != "jodie" ]]; then
      common_flags+=" --n-threads $(nproc)"
    fi

    # Run 1: baseline configuration.
    echo "tglite $data $model";
    "./exp/$model.sh" $common_flags
    mv out-stats.csv "out-tglite-$data-$model.csv";
    echo;
    echo "time: $(date)"
    echo;

    # Run 2: with --move (results tagged "allgpu").
    "./exp/$model.sh" $common_flags --move
    mv out-stats.csv "out-tglite-allgpu-$data-$model.csv";
    echo;
    echo "time: $(date)"
    echo;

    # Run 3: with --opt-all (results tagged "opt"); skipped for jodie,
    # consistent with the flag handling above.
    if [[ "$model" != "jodie" ]]; then
      "./exp/$model.sh" $common_flags --opt-all
      mv out-stats.csv "out-tglite-opt-$data-$model.csv";
      echo;
      echo "time: $(date)"
      echo;
    fi
  done
done

echo "end: $(date)"
#!/usr/bin/env bash
#
# Script to setup environment for this repo.
#

# Resolve the repository root from this script's location.
tglite="$(cd "$(dirname "$0")"; cd ..; pwd)"
cd "$tglite"

echo
echo ">> setting up environment"
echo

# Assumes conda was installed under ~/.conda (see setup-aws.sh).
source ~/.conda/etc/profile.d/conda.sh
conda create -n tglite python=3.7
conda activate tglite

echo
echo ">> installing python packages"
echo

# Pinned torch/torch-scatter builds for CUDA 11.6.
pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install torch-scatter==2.1.0+pt112cu116 -f https://data.pyg.org/whl/torch-1.12.1+cu116.html

echo
echo ">> installing tglite package"
echo

# Editable/development install so local changes take effect immediately.
python setup.py develop

echo
echo ">> setting up example applications"
echo

cd "$tglite/examples"
pip install -r requirements.txt
./download-data.sh
python gen-data-files.py --data wiki-talk

echo
echo ">> done!"
from pathlib import Path
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension


def main():
    """Configure and run setup() for tglite's native C++ extension module."""
    root_dir = Path(__file__).absolute().parent
    # All native sources are compiled into the single tglite._c module.
    cpp_sources = [
        "lib/bind.cpp",
        "lib/cache.cpp",
        "lib/dedup.cpp",
        "lib/sampler.cpp",
        "lib/tcsr.cpp",
        "lib/utils.cpp"
    ]
    extension = CppExtension(
        name="tglite._c",
        sources=cpp_sources,
        include_dirs=[root_dir / "include/"],
        extra_compile_args=["-std=c++14", "-fopenmp"],
        extra_link_args=["-fopenmp"])
    # BuildExtension supplies the torch-aware build_ext command.
    setup(ext_modules=[extension], cmdclass={"build_ext": BuildExtension})


if __name__ == "__main__":
    main()
def test_tblock():
    # Build a graph from the bundled sample edge list and take one mini-batch.
    g = tg.from_csv(Path(__file__).parent / 'data/edges.csv')
    ctx = tg.TContext(g)
    batch = next(tg.iter_edges(g, size=10))
    assert len(batch) == 10

    # A fresh block starts at layer 0 and covers both endpoints of each edge.
    blk = batch.block(ctx)
    assert blk.layer == 0
    assert blk.num_dst() == len(batch) * 2  # src, dst nodes of batch edges
    assert blk.num_dst() == len(blk.dsttimes)
    assert blk.has_nbrs() == False  # no neighbor info before sampling
def test_tframe_dim():
    # dim defaults to 0 and is otherwise whatever was passed in.
    frame = tg.TFrame()
    assert frame.dim() == 0
    frame = tg.TFrame(dim=16)
    assert frame.dim() == 16


def test_tframe_checks_dim():
    # Storing a tensor whose leading dimension disagrees with the frame's
    # declared dim must raise a TError naming both sizes.
    with pytest.raises(tg.TError) as exinfo:
        frame = tg.TFrame(dim=3)
        frame['f'] = torch.ones((1, 2))
    assert "dimension of 3, got 1" in str(exinfo.value)


def test_tframe_only_tensors():
    # Only torch tensors are accepted as frame values.
    frame = tg.TFrame(dim=3)
    frame['tensor'] = torch.randn(3, 2)
    with pytest.raises(tg.TError) as exinfo:
        frame['list'] = [1, 2, 3]
    assert "expected value to be a tensor" in str(exinfo.value)


def test_tframe_dict_behavior():
    # TFrame supports dict-style membership, length, iteration, and clear();
    # clear() empties the entries but keeps the declared dim.
    frame = tg.TFrame(dim=3)
    frame['a'] = torch.randn((3, 1))
    frame['b'] = torch.randn((3, 2))
    frame['c'] = torch.randn((3, 3))
    frame['d'] = torch.randn((3, 4))

    assert 'a' in frame
    assert len(frame) == 4
    for key, val in frame.items():
        if key == 'a': assert val.shape[1] == 1
        if key == 'b': assert val.shape[1] == 2
        if key == 'c': assert val.shape[1] == 3
        if key == 'd': assert val.shape[1] == 4

    frame.clear()
    assert frame.dim() == 3
    assert len(frame) == 0
def test_tcsr():
    # Three undirected-style temporal edges over nodes {0, 1, 2}.
    edges = np.array([[0,1], [0,2], [1,2]], dtype=np.int32)
    etime = np.array([10, 11, 12], dtype=np.float32)
    tcsr = tg.utils.create_tcsr(edges, etime)

    # NOTE(review): ind looks like a CSR index-pointer array of length
    # num_nodes + 1; it equals len(edges) + 1 here only because this graph
    # happens to have num_nodes == num_edges == 3 — confirm intent.
    assert len(tcsr.ind) == len(edges) + 1
    assert list(tcsr.ind) == [0, 2, 4, 6]
    # Each edge appears in both endpoints' adjacency, with matching edge
    # ids and timestamps.
    assert list(tcsr.nbr) == [1, 2, 0, 2, 0, 1]
    assert list(tcsr.eid) == [0, 1, 0, 2, 1, 2]
    assert list(tcsr.ets) == [10, 11, 10, 12, 11, 12]