├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── docs
├── Makefile
├── README.md
├── make.bat
├── run.sh
└── source
│ ├── api
│ └── python
│ │ ├── index.rst
│ │ ├── tglite.batch.rst
│ │ ├── tglite.block.rst
│ │ ├── tglite.context.rst
│ │ ├── tglite.graph.rst
│ │ ├── tglite.mailbox.rst
│ │ ├── tglite.memory.rst
│ │ ├── tglite.nn.rst
│ │ ├── tglite.op.rst
│ │ ├── tglite.rst
│ │ └── tglite.sampler.rst
│ ├── conf.py
│ ├── generated
│ ├── tglite.Mailbox.rst
│ ├── tglite.Memory.rst
│ ├── tglite.TBatch.block.rst
│ ├── tglite.TBatch.block_adj.rst
│ ├── tglite.TBatch.edges.rst
│ ├── tglite.TBatch.eids.rst
│ ├── tglite.TBatch.g.rst
│ ├── tglite.TBatch.neg_nodes.rst
│ ├── tglite.TBatch.nodes.rst
│ ├── tglite.TBatch.split_data.rst
│ ├── tglite.TBatch.times.rst
│ ├── tglite.TBlock.allnodes.rst
│ ├── tglite.TBlock.apply.rst
│ ├── tglite.TBlock.clear_hooks.rst
│ ├── tglite.TBlock.clear_nbrs.rst
│ ├── tglite.TBlock.dstdata.rst
│ ├── tglite.TBlock.dstfeat.rst
│ ├── tglite.TBlock.dstindex.rst
│ ├── tglite.TBlock.dstnodes.rst
│ ├── tglite.TBlock.dsttimes.rst
│ ├── tglite.TBlock.edata.rst
│ ├── tglite.TBlock.efeat.rst
│ ├── tglite.TBlock.eid.rst
│ ├── tglite.TBlock.ets.rst
│ ├── tglite.TBlock.g.rst
│ ├── tglite.TBlock.has_nbrs.rst
│ ├── tglite.TBlock.layer.rst
│ ├── tglite.TBlock.mail.rst
│ ├── tglite.TBlock.mem_data.rst
│ ├── tglite.TBlock.next.rst
│ ├── tglite.TBlock.next_block.rst
│ ├── tglite.TBlock.nfeat.rst
│ ├── tglite.TBlock.num_dst.rst
│ ├── tglite.TBlock.num_edges.rst
│ ├── tglite.TBlock.num_src.rst
│ ├── tglite.TBlock.prev.rst
│ ├── tglite.TBlock.register_hook.rst
│ ├── tglite.TBlock.run_hooks.rst
│ ├── tglite.TBlock.set_nbrs.rst
│ ├── tglite.TBlock.srcdata.rst
│ ├── tglite.TBlock.srcfeat.rst
│ ├── tglite.TBlock.srcnodes.rst
│ ├── tglite.TBlock.time_deltas.rst
│ ├── tglite.TBlock.uniq_src.rst
│ ├── tglite.TContext.enable_embed_caching.rst
│ ├── tglite.TContext.enable_time_precompute.rst
│ ├── tglite.TContext.eval.rst
│ ├── tglite.TContext.graph.rst
│ ├── tglite.TContext.need_sampling.rst
│ ├── tglite.TContext.set_cache_limit.rst
│ ├── tglite.TContext.set_time_window.rst
│ ├── tglite.TContext.train.rst
│ ├── tglite.TGraph.compute_device.rst
│ ├── tglite.TGraph.edata.rst
│ ├── tglite.TGraph.efeat.rst
│ ├── tglite.TGraph.mailbox.rst
│ ├── tglite.TGraph.mem.rst
│ ├── tglite.TGraph.move_data.rst
│ ├── tglite.TGraph.ndata.rst
│ ├── tglite.TGraph.nfeat.rst
│ ├── tglite.TGraph.num_edges.rst
│ ├── tglite.TGraph.num_nodes.rst
│ ├── tglite.TGraph.set_compute.rst
│ ├── tglite.TGraph.storage_device.rst
│ ├── tglite.from_csv.rst
│ ├── tglite.iter_edges.rst
│ ├── tglite.op.aggregate.rst
│ ├── tglite.op.cache.rst
│ ├── tglite.op.coalesce.rst
│ ├── tglite.op.dedup.rst
│ ├── tglite.op.edge_reduce.rst
│ ├── tglite.op.edge_softmax.rst
│ ├── tglite.op.edge_view.rst
│ ├── tglite.op.precomputed_times.rst
│ ├── tglite.op.precomputed_zeros.rst
│ ├── tglite.op.preload.rst
│ ├── tglite.op.propagate.rst
│ └── tglite.op.src_scatter.rst
│ ├── img
│ ├── blank.png
│ ├── colab.svg
│ ├── github.svg
│ ├── tblock-structure.png
│ ├── tblock-workflow.png
│ └── train.png
│ ├── index.rst
│ ├── install
│ └── index.rst
│ └── tutorial
│ ├── quickstart.ipynb
│ └── tblock.rst
├── environment.yml
├── examples
├── apan
│ ├── apan.py
│ └── train.py
├── download-data.sh
├── exp
│ ├── apan-gdelt.sh
│ ├── apan.sh
│ ├── jodie-gdelt.sh
│ ├── jodie.sh
│ ├── tgat-gdelt.sh
│ ├── tgat.sh
│ ├── tgn-gdelt.sh
│ └── tgn.sh
├── gen-data-files.py
├── jodie
│ ├── jodie.py
│ └── train.py
├── requirements.txt
├── support.py
├── tgat
│ ├── TGAT.ipynb
│ ├── tgat.py
│ └── train.py
└── tgn
│ ├── tgn.py
│ └── train.py
├── include
└── tglite
│ ├── cache.h
│ ├── core.h
│ ├── dedup.h
│ ├── sampler.h
│ ├── tcsr.h
│ └── utils.h
├── lib
├── bind.cpp
├── cache.cpp
├── dedup.cpp
├── sampler.cpp
├── tcsr.cpp
└── utils.cpp
├── pyproject.toml
├── python
└── tglite
│ ├── __init__.py
│ ├── _batch.py
│ ├── _block.py
│ ├── _context.py
│ ├── _core.py
│ ├── _frame.py
│ ├── _graph.py
│ ├── _mailbox.py
│ ├── _memory.py
│ ├── _sampler.py
│ ├── _stats.py
│ ├── _utils.py
│ ├── nn.py
│ └── op.py
├── scripts
├── run-exp-large.sh
├── run-exp-slurm.sh
├── run-exp.sh
├── setup-aws.sh
└── setup-repo.sh
├── setup.py
└── tests
└── python
├── data
└── edges.csv
├── test_block.py
├── test_frame.py
├── test_graph.py
├── test_tglite.py
└── test_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | *.DS_Store
3 |
4 | __pycache__/
5 | .pytest_cache/
6 | .coverage
7 |
8 | dist/
9 | build/
10 | *.egg-info/
11 | *.so
12 |
13 | docs/_build/
14 |
15 | examples/data
16 | examples/models/
17 | examples/out-*.csv
18 | examples/out-*.txt
19 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file for Sphinx projects
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | # Set the OS, Python version and other tools you might need
8 | build:
9 | os: ubuntu-20.04
10 | tools:
11 | # python: "mambaforge-22.9"
12 | python: "3.7"
13 |
14 | commands:
15 | # - pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
16 | # - pip install torch-scatter==2.1.0+pt112cu116 -f https://data.pyg.org/whl/torch-1.12.1+cu116.html
17 | # - pip install .[docs]
18 | # - cd docs/ && make html
19 | # - mkdir _readthedocs
20 | # - cp -r docs/build/html _readthedocs/
21 | # - python setup.py install
22 | # Install dependencies
23 | # - cd docs/ && pip install -r requirements.txt
24 | # Build the site
25 | # - cd docs/ && make html
26 | # Copy generated files into Read the Docs directory
27 | # - cd docs/ && ls
28 | # - cd docs/build && ls
29 | # - mkdir _readthedocs
30 | # - cp --recursive docs/build/html _readthedocs/
31 | - ls _readthedocs
32 |
33 | # conda:
34 | # environment: environment.yml
35 |
36 | # Build documentation in the "docs/" directory with Sphinx
37 | # sphinx:
38 | # configuration: null
39 | # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs
40 | # builder: "dirhtml"
41 | # Fail on all warnings to avoid broken references
42 | # fail_on_warning: true
43 |
44 | # Optionally build your docs in additional formats such as PDF and ePub
45 | # formats:
46 | # - pdf
47 | # - epub
48 |
49 | # Optional but recommended, declare the Python requirements required
50 | # to build your documentation
51 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
52 | # python:
53 | # install:
54 | # - method: pip
55 | # path: .
56 | # extra_requirements:
57 | # - docs
58 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2024 TGLite Authors
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # Add header files to sdist
2 | graft include/tglite
3 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: build clean develop install test uninstall
2 |
3 | build:
4 | python setup.py build
5 |
6 | develop:
7 | python setup.py develop
8 |
9 | install:
10 | python setup.py install
11 |
12 | uninstall:
13 | pip uninstall --yes tglite
14 |
15 | test:
16 | pytest --cov=tglite
17 |
18 | clean:
19 | rm -rf **/*/__pycache__ .pytest_cache
20 | rm -rf build dist **/*.egg-info
21 | rm -rf .coverage **/*/*.so
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TGLite - A Framework for Temporal GNNs
2 |
3 | TGLite is a lightweight framework that provides core abstractions and building blocks for practitioners and researchers to implement efficient TGNN models. TGNNs, or Temporal Graph Neural Networks, learn node embeddings for graphs that dynamically change over time by jointly aggregating structural and temporal information from neighboring nodes. TGLite employs an abstraction called a _TBlock_ to represent the temporal graph dependencies when aggregating from neighbors, with explicit support for capturing temporal details like edge timestamps, as well as composable operators and optimizations. Compared to prior art, TGLite can outperform the [TGL][tgl] framework by [up to 3x](#publication) in terms of training time.
4 |
5 |
6 |
7 | End-to-end training epoch time comparison on an Nvidia A100 GPU.
8 |
9 |
10 | [tgl]: https://github.com/amazon-science/tgl
11 |
12 | ## Installation
13 |
14 | See our [documentation][docs] for instructions on how to install the TGLite package, as well as examples and references for supported functionality. To install from source or for local development, see the [Building from source][build-src] section; it also explains how to run the [examples][exp].
15 |
16 | [docs]: https://tglite.readthedocs.io
17 | [build-src]: https://tglite.readthedocs.io/en/latest/install/index.html#building-from-source
18 | [exp]: examples
19 |
20 | ## Getting Started
21 |
22 | TGLite is currently designed to be used with PyTorch as a training backend, typically with GPU devices. A TGNN model can be defined and trained in the usual way using PyTorch, with the computations constructed using a mix of PyTorch functions and operators/optimizations from TGLite. Below is a simple example (not a real network architecture, just for demonstration purposes):
23 |
24 | ```python
25 | import torch
26 | import tglite as tg
27 |
28 | class TGNN(torch.nn.Module):
29 | def __init__(self, ctx: tg.TContext, dim_node=100, dim_time=100):
30 | super().__init__()
31 | self.ctx = ctx
32 | self.linear = torch.nn.Linear(dim_node + dim_time, dim_node)
33 | self.sampler = tg.TSampler(num_nbrs=10, strategy='recent')
34 | self.encoder = tg.nn.TimeEncode(dim_time)
35 |
36 | def forward(self, batch: tg.TBatch):
37 | blk = batch.block(self.ctx)
38 | blk = tg.op.dedup(blk)
39 | blk = self.sampler.sample(blk)
40 | blk.srcdata['h'] = blk.srcfeat()
41 | return tg.op.aggregate(blk, self.compute, key='h')
42 |
43 | def compute(self, blk: tg.TBlock):
44 | feats = self.encoder(blk.time_deltas())
45 | feats = torch.cat([blk.srcdata['h'], feats], dim=1)
46 | embeds = self.linear(feats)
47 | embeds = tg.op.edge_reduce(blk, embeds, op='sum')
48 | return torch.relu(embeds)
49 |
50 | graph = tg.from_csv(...)
51 | ctx = tg.TContext(graph)
52 | model = TGNN(ctx)
53 | train(model)
54 | ```
55 |
56 | The example model is defined to first construct the graph dependencies for nodes in the current batch of edges. The `dedup()` optimization is applied before sampling for 10 recent neighbors. Node embeddings are computed by simply combining node and time features, applying a linear layer and summing across neighbors. More complex computations and aggregations, such as temporal self-attention often used with TGNNs, can be defined using the provided building blocks.
57 |
58 | ## Publication
59 |
60 | * Yufeng Wang and Charith Mendis. 2024. [TGLite: A Lightweight Programming Framework for Continuous-Time Temporal Graph Neural Networks][tglite-paper]. In Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2 (ASPLOS '24), April 2024, La Jolla, CA, USA. (To Appear)
61 |
62 | * Yufeng Wang and Charith Mendis. 2023. [TGOpt: Redundancy-Aware Optimizations for Temporal Graph Attention Networks][tgopt-paper]. In Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming (PPoPP '23), February 2023, Montreal, QC, Canada.
63 |
64 | If you find TGLite useful, please consider citing it with the following citation:
65 |
66 | ```bibtex
67 | @inproceedings{wang2024tglite,
68 | author = {Wang, Yufeng and Mendis, Charith},
69 | title = {TGLite: A Lightweight Programming Framework for Continuous-Time Temporal Graph Neural Networks},
70 | year = {2024},
71 | booktitle = {Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2},
72 | doi = {10.1145/3620665.3640414}
73 | }
74 | ```
75 |
76 | [tglite-paper]: https://charithmendis.com/assets/pdf/asplos24-tglite.pdf
77 | [tgopt-paper]: https://charithmendis.com/assets/pdf/ppopp23-tgopt.pdf
78 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | ## Build TGLite Documentation
2 | See [Build the documentation locally](https://tglite.readthedocs.io/en/latest/install/index.html#building-the-document-locally) in the doc.
3 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Script to open TGLite documentation using detected browser.
4 | #
5 |
6 | # Get the directory where the script is located.
7 | SCRIPT_DIR="$(dirname "$0")"
8 |
9 | # Path to the HTML file you want to open.
10 | HTML_PATH="$SCRIPT_DIR/build/html/index.html"
11 |
12 | # Function to open the HTML file using a browser.
13 | open_html() {
14 | local path=$1
15 | if command -v xdg-open &> /dev/null; then
16 | # Preferred way to open files on Desktop Linux
17 | xdg-open "$path"
18 | elif command -v gnome-open &> /dev/null; then
19 | # For systems with Gnome.
20 | gnome-open "$path"
21 | elif command -v x-www-browser &> /dev/null; then
22 | # A generic way to open files, might work when xdg-open and gnome-open are unavailable.
23 | x-www-browser "$path"
24 | else
25 | echo "Could not detect the web browser to open the documentation."
26 | return 1
27 | fi
28 | }
29 |
30 | # Detect the operating system and open the HTML file.
31 | case "$(uname -s)" in
32 | Linux*)
33 | open_html "$HTML_PATH"
34 | ;;
35 | Darwin*)
36 | open "$HTML_PATH"
37 | ;;
38 | CYGWIN*|MINGW32*|MSYS*|MINGW*)
39 | start "$HTML_PATH"
40 | ;;
41 | *)
42 | echo "Unknown operating system. Cannot open the documentation automatically."
43 | ;;
44 | esac
45 |
--------------------------------------------------------------------------------
/docs/source/api/python/index.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | =============
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 |
7 | tglite
8 | tglite.batch
9 | tglite.block
10 | tglite.context
11 | tglite.graph
12 | tglite.mailbox
13 | tglite.memory
14 | tglite.sampler
15 | tglite.nn
16 | tglite.op
--------------------------------------------------------------------------------
/docs/source/api/python/tglite.batch.rst:
--------------------------------------------------------------------------------
1 | .. _api-batch:
2 |
3 | tglite.TBatch
4 | =============
5 |
6 | .. currentmodule:: tglite
7 | .. autoclass:: TBatch
8 |
9 | .. automethod:: __init__
10 |
11 | .. currentmodule:: tglite.TBatch
12 |
13 | Get graph data
14 | --------------
15 | .. autosummary::
16 | :toctree: ../../generated/
17 |
18 | g
19 | neg_nodes
20 | eids
21 | edges
22 | nodes
23 | times
24 |
25 |
26 | Get TBlock
27 | -----------
28 | .. autosummary::
29 | :toctree: ../../generated/
30 |
31 | block
32 | block_adj
33 |
34 | Split data
35 | ----------
36 | .. autosummary::
37 | :toctree: ../../generated/
38 |
39 | split_data
40 |
41 |
--------------------------------------------------------------------------------
/docs/source/api/python/tglite.block.rst:
--------------------------------------------------------------------------------
1 | .. _api-block:
2 |
3 | tglite.TBlock
4 | =============
5 |
6 | TBlock captures 1-hop relationships between node/time pairs and their neighbors for doing computations. :ref:`Figure
7 | <tblock figure>` shows the internal structure of a TBlock and the doubly-linked list design with next and prev pointing
8 | to sampled neighbors' TBlock. To have a general idea of how TBlock works, please refer to the :ref:`tutorial <tutorial-tblock>`.
9 |
10 | .. _tblock figure:
11 | .. figure:: ../../img/tblock-structure.png
12 | :alt: tblock-structure
13 | :align: center
14 | :figwidth: 60 %
15 |
16 | Diagram of the doubly-linked list design and internal structure of a TBlock (destination node-time is denoted as
17 | ``<node, time>``).
18 |
19 |
20 | .. currentmodule:: tglite
21 | .. autoclass:: TBlock
22 |
23 | .. automethod:: __init__
24 |
25 |
26 | .. currentmodule:: tglite.TBlock
27 |
28 |
29 | Query TBlock attributes
30 | -----------------------
31 | .. autosummary::
32 | :toctree: ../../generated/
33 |
34 | g
35 | layer
36 | dstnodes
37 | dsttimes
38 | num_dst
39 |
40 | Query neighbor attributes
41 | -------------------------
42 | .. autosummary::
43 | :toctree: ../../generated/
44 |
45 | dstindex
46 | srcnodes
47 | num_src
48 | eid
49 | ets
50 | num_edges
51 | has_nbrs
52 |
53 | Query data
54 | ----------
55 | .. autosummary::
56 | :toctree: ../../generated/
57 |
58 | dstdata
59 | srcdata
60 | edata
61 |
62 | Query Cache
63 | -----------
64 | .. autosummary::
65 | :toctree: ../../generated/
66 |
67 | allnodes
68 | uniq_src
69 | efeat
70 | nfeat
71 | srcfeat
72 | dstfeat
73 | mem_data
74 | mail
75 | time_deltas
76 |
77 | Neighbor ops
78 | ------------
79 | .. autosummary::
80 | :toctree: ../../generated/
81 |
82 | set_nbrs
83 | clear_nbrs
84 |
85 | Linked list ops
86 | ---------------
87 | .. autosummary::
88 | :toctree: ../../generated/
89 |
90 | prev
91 | next
92 | next_block
93 |
94 | Execution ops
95 | -------------
96 | .. autosummary::
97 | :toctree: ../../generated/
98 |
99 | apply
100 | register_hook
101 | run_hooks
102 | clear_hooks
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
--------------------------------------------------------------------------------
/docs/source/api/python/tglite.context.rst:
--------------------------------------------------------------------------------
1 | .. _api-context:
2 |
3 | tglite.TContext
4 | ===============
5 | .. currentmodule:: tglite
6 | .. autoclass:: TContext
7 |
8 | .. automethod:: __init__
9 |
10 |
11 | .. currentmodule:: tglite.TContext
12 |
13 | Basic settings
14 | --------------
15 | .. autosummary::
16 | :toctree: ../../generated/
17 |
18 | train
19 | eval
20 | need_sampling
21 |
22 | Set node embedding cache
23 | ------------------------
24 | .. autosummary::
25 | :toctree: ../../generated/
26 |
27 | enable_embed_caching
28 | set_cache_limit
29 |
30 | Set time precomputation
31 | -----------------------
32 | .. autosummary::
33 | :toctree: ../../generated/
34 |
35 | enable_time_precompute
36 | set_time_window
37 |
38 | Query graph
39 | -----------
40 | .. autosummary::
41 | :toctree: ../../generated/
42 |
43 | graph
44 |
--------------------------------------------------------------------------------
/docs/source/api/python/tglite.graph.rst:
--------------------------------------------------------------------------------
1 | .. _api-graph:
2 |
3 | tglite.TGraph
4 | =============
5 | .. currentmodule:: tglite
6 | .. autoclass:: TGraph
7 |
8 | .. automethod:: __init__
9 |
10 |
11 | .. currentmodule:: tglite.TGraph
12 |
13 | Query graph structure
14 | ---------------------
15 |
16 | .. autosummary::
17 | :toctree: ../../generated/
18 |
19 | num_nodes
20 | num_edges
21 |
22 | Query graph data
23 | ----------------
24 |
25 | .. autosummary::
26 | :toctree: ../../generated/
27 |
28 | efeat
29 | nfeat
30 | edata
31 | ndata
32 |
33 | Query memory-based TGN data
34 | ---------------------------
35 |
36 | .. autosummary::
37 | :toctree: ../../generated/
38 |
39 | mem
40 | mailbox
41 |
42 | Set device
43 | ----------
44 |
45 | .. autosummary::
46 | :toctree: ../../generated/
47 |
48 | storage_device
49 | compute_device
50 | set_compute
51 | move_data
52 |
53 |
--------------------------------------------------------------------------------
/docs/source/api/python/tglite.mailbox.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: tglite
2 |
3 | .. autosummary::
4 | :toctree: ../../generated/
5 | :nosignatures:
6 |
7 | Mailbox
8 |
--------------------------------------------------------------------------------
/docs/source/api/python/tglite.memory.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: tglite
2 |
3 | .. autosummary::
4 | :toctree: ../../generated/
5 | :nosignatures:
6 |
7 | Memory
8 |
--------------------------------------------------------------------------------
/docs/source/api/python/tglite.nn.rst:
--------------------------------------------------------------------------------
1 | .. _api-nn:
2 |
3 | tglite.nn
4 | =========
5 | .. currentmodule:: tglite.nn
6 | .. autoclass:: TimeEncode
7 |
8 | .. automethod:: __init__
9 | .. automethod:: zeros
10 | .. automethod:: forward
11 |
12 |
13 | .. autoclass:: TemporalAttnLayer
14 |
15 | .. automethod:: __init__
16 | .. automethod:: forward
--------------------------------------------------------------------------------
/docs/source/api/python/tglite.op.rst:
--------------------------------------------------------------------------------
1 | .. _api-op:
2 |
3 | tglite.op
4 | =========
5 | .. currentmodule:: tglite.op
6 |
7 | .. autosummary::
8 | :toctree: ../../generated/
9 |
10 | edge_view
11 | edge_softmax
12 | edge_reduce
13 | src_scatter
14 | coalesce
15 | preload
16 | aggregate
17 | propagate
18 | dedup
19 | cache
20 | precomputed_zeros
21 | precomputed_times
22 |
--------------------------------------------------------------------------------
/docs/source/api/python/tglite.rst:
--------------------------------------------------------------------------------
1 | .. _api-tglite:
2 |
3 | tglite
4 | ======
5 | .. currentmodule:: tglite
6 | .. automodule:: tglite
7 |
8 | .. _api-edge-iteration-ops:
9 |
10 | Edge Iteration Ops
11 | ------------------
12 | .. autosummary::
13 | :toctree: ../../generated/
14 |
15 | iter_edges
16 |
17 | .. _api-graph-creation-ops:
18 |
19 | Create Graph Ops
20 | ----------------
21 | .. autosummary::
22 | :toctree: ../../generated/
23 |
24 | from_csv
25 |
--------------------------------------------------------------------------------
/docs/source/api/python/tglite.sampler.rst:
--------------------------------------------------------------------------------
1 | .. _api-sampler:
2 |
3 | tglite.TSampler
4 | ===============
5 | .. currentmodule:: tglite
6 | .. autoclass:: TSampler
7 |
8 | .. automethod:: __init__
9 | .. automethod:: sample
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 |
7 | # If extensions (or modules to document with autodoc) are in another directory,
8 | # add these directories to sys.path here.
9 |
10 | import pathlib
11 | import sys
12 | import os
13 | import site
14 | sys.path.insert(0, pathlib.Path(__file__).parents[2].resolve().as_posix())  # repo root: two levels above docs/source/
15 | sys.path.insert(0, os.path.abspath("../../python"))  # ../../python relative to the cwd — presumably the in-tree tglite sources; TODO confirm it resolves when sphinx-build runs from another directory
16 | sys.path.insert(0, site.getsitepackages()[0])  # first site-packages dir, so an installed tglite/its deps are importable by autodoc
17 |
18 | # -- Project information -----------------------------------------------------
19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
20 |
21 | project = 'TGLite'
22 | copyright = '2024, ADAPT Group'
23 | author = 'ADAPT Group'
24 | release = '0.1.0'
25 |
26 |
27 | # -- General configuration ---------------------------------------------------
28 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
29 |
30 | extensions = [
31 |     'sphinx.ext.autodoc',     # pull API docs from docstrings (.. autoclass:: / .. automethod::)
32 |     'sphinx.ext.autosummary', # generate the stub pages under docs/source/generated/
33 |     'sphinx.ext.doctest',
34 |     'sphinx.ext.duration',
35 |     "sphinx.ext.mathjax",  # NOTE(review): double quotes inconsistent with the rest of this list
36 |     'sphinx_rtd_theme',
37 |     'nbsphinx'
38 | ]
39 |
40 | templates_path = ['_templates']
41 | exclude_patterns = []
42 |
43 |
44 |
45 | # -- Options for HTML output -------------------------------------------------
46 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
47 |
48 | html_theme = 'sphinx_rtd_theme'
49 | html_static_path = ['_static']
50 | html_context = {
51 |     "display_github": True,  # Integrate GitHub
52 |     "github_user": "ADAPT-uiuc",  # Username
53 |     "github_repo": "tglite",  # Repo name
54 |     "github_version": "main",  # Version
55 |     "conf_py_path": "/docs/source/",  # Path in the checkout to the docs root
56 | }
57 |
--------------------------------------------------------------------------------
/docs/source/generated/tglite.Mailbox.rst:
--------------------------------------------------------------------------------
1 | tglite.Mailbox
2 | ==============
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoclass:: Mailbox
7 |
8 |
9 | .. automethod:: __init__
10 |
11 |
12 | .. rubric:: Methods
13 |
14 | .. autosummary::
15 |
16 | ~Mailbox.__init__
17 | ~Mailbox.dims
18 | ~Mailbox.move_to
19 | ~Mailbox.reset
20 | ~Mailbox.store
21 |
22 |
23 |
24 |
25 |
26 | .. rubric:: Attributes
27 |
28 | .. autosummary::
29 |
30 | ~Mailbox.device
31 | ~Mailbox.mail
32 | ~Mailbox.time
33 |
34 |
--------------------------------------------------------------------------------
/docs/source/generated/tglite.Memory.rst:
--------------------------------------------------------------------------------
1 | tglite.Memory
2 | =============
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoclass:: Memory
7 |
8 |
9 | .. automethod:: __init__
10 |
11 |
12 | .. rubric:: Methods
13 |
14 | .. autosummary::
15 |
16 | ~Memory.__init__
17 | ~Memory.backup
18 | ~Memory.dim
19 | ~Memory.move_to
20 | ~Memory.reset
21 | ~Memory.restore
22 | ~Memory.update
23 |
24 |
25 |
26 |
27 |
28 | .. rubric:: Attributes
29 |
30 | .. autosummary::
31 |
32 | ~Memory.data
33 | ~Memory.device
34 | ~Memory.time
35 |
36 |
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBatch.block.rst:
--------------------------------------------------------------------------------
1 | tglite.TBatch.block
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBatch.block
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBatch.block_adj.rst:
--------------------------------------------------------------------------------
1 | tglite.TBatch.block\_adj
2 | ========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBatch.block_adj
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBatch.edges.rst:
--------------------------------------------------------------------------------
1 | tglite.TBatch.edges
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBatch.edges
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBatch.eids.rst:
--------------------------------------------------------------------------------
1 | tglite.TBatch.eids
2 | ==================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBatch.eids
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBatch.g.rst:
--------------------------------------------------------------------------------
1 | tglite.TBatch.g
2 | ===============
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBatch.g
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBatch.neg_nodes.rst:
--------------------------------------------------------------------------------
1 | tglite.TBatch.neg\_nodes
2 | ========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBatch.neg_nodes
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBatch.nodes.rst:
--------------------------------------------------------------------------------
1 | tglite.TBatch.nodes
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBatch.nodes
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBatch.split_data.rst:
--------------------------------------------------------------------------------
1 | tglite.TBatch.split\_data
2 | =========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBatch.split_data
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBatch.times.rst:
--------------------------------------------------------------------------------
1 | tglite.TBatch.times
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBatch.times
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.allnodes.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.allnodes
2 | ======================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.allnodes
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.apply.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.apply
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.apply
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.clear_hooks.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.clear\_hooks
2 | ==========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.clear_hooks
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.clear_nbrs.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.clear\_nbrs
2 | =========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.clear_nbrs
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.dstdata.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.dstdata
2 | =====================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.dstdata
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.dstfeat.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.dstfeat
2 | =====================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.dstfeat
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.dstindex.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.dstindex
2 | ======================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.dstindex
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.dstnodes.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.dstnodes
2 | ======================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.dstnodes
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.dsttimes.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.dsttimes
2 | ======================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.dsttimes
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.edata.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.edata
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.edata
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.efeat.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.efeat
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.efeat
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.eid.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.eid
2 | =================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.eid
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.ets.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.ets
2 | =================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.ets
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.g.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.g
2 | ===============
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.g
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.has_nbrs.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.has\_nbrs
2 | =======================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.has_nbrs
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.layer.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.layer
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.layer
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.mail.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.mail
2 | ==================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.mail
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.mem_data.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.mem\_data
2 | =======================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.mem_data
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.next.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.next
2 | ==================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.next
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.next_block.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.next\_block
2 | =========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.next_block
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.nfeat.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.nfeat
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.nfeat
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.num_dst.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.num\_dst
2 | ======================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.num_dst
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.num_edges.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.num\_edges
2 | ========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.num_edges
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.num_src.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.num\_src
2 | ======================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.num_src
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.prev.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.prev
2 | ==================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.prev
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.register_hook.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.register\_hook
2 | ============================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.register_hook
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.run_hooks.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.run\_hooks
2 | ========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.run_hooks
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.set_nbrs.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.set\_nbrs
2 | =======================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.set_nbrs
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.srcdata.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.srcdata
2 | =====================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.srcdata
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.srcfeat.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.srcfeat
2 | =====================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.srcfeat
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.srcnodes.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.srcnodes
2 | ======================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TBlock.srcnodes
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.time_deltas.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.time\_deltas
2 | ==========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.time_deltas
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TBlock.uniq_src.rst:
--------------------------------------------------------------------------------
1 | tglite.TBlock.uniq\_src
2 | =======================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TBlock.uniq_src
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TContext.enable_embed_caching.rst:
--------------------------------------------------------------------------------
1 | tglite.TContext.enable\_embed\_caching
2 | ======================================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TContext.enable_embed_caching
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TContext.enable_time_precompute.rst:
--------------------------------------------------------------------------------
1 | tglite.TContext.enable\_time\_precompute
2 | ========================================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TContext.enable_time_precompute
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TContext.eval.rst:
--------------------------------------------------------------------------------
1 | tglite.TContext.eval
2 | ====================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TContext.eval
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TContext.graph.rst:
--------------------------------------------------------------------------------
1 | tglite.TContext.graph
2 | =====================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TContext.graph
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TContext.need_sampling.rst:
--------------------------------------------------------------------------------
1 | tglite.TContext.need\_sampling
2 | ==============================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TContext.need_sampling
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TContext.set_cache_limit.rst:
--------------------------------------------------------------------------------
1 | tglite.TContext.set\_cache\_limit
2 | =================================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TContext.set_cache_limit
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TContext.set_time_window.rst:
--------------------------------------------------------------------------------
1 | tglite.TContext.set\_time\_window
2 | =================================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TContext.set_time_window
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TContext.train.rst:
--------------------------------------------------------------------------------
1 | tglite.TContext.train
2 | =====================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TContext.train
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.compute_device.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.compute\_device
2 | =============================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TGraph.compute_device
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.edata.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.edata
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TGraph.edata
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.efeat.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.efeat
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TGraph.efeat
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.mailbox.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.mailbox
2 | =====================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TGraph.mailbox
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.mem.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.mem
2 | =================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TGraph.mem
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.move_data.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.move\_data
2 | ========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TGraph.move_data
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.ndata.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.ndata
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TGraph.ndata
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.nfeat.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.nfeat
2 | ===================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autoproperty:: TGraph.nfeat
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.num_edges.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.num\_edges
2 | ========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TGraph.num_edges
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.num_nodes.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.num\_nodes
2 | ========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TGraph.num_nodes
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.set_compute.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.set\_compute
2 | ==========================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TGraph.set_compute
--------------------------------------------------------------------------------
/docs/source/generated/tglite.TGraph.storage_device.rst:
--------------------------------------------------------------------------------
1 | tglite.TGraph.storage\_device
2 | =============================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. automethod:: TGraph.storage_device
--------------------------------------------------------------------------------
/docs/source/generated/tglite.from_csv.rst:
--------------------------------------------------------------------------------
1 | tglite.from\_csv
2 | ================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autofunction:: from_csv
--------------------------------------------------------------------------------
/docs/source/generated/tglite.iter_edges.rst:
--------------------------------------------------------------------------------
1 | tglite.iter\_edges
2 | ==================
3 |
4 | .. currentmodule:: tglite
5 |
6 | .. autofunction:: iter_edges
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.aggregate.rst:
--------------------------------------------------------------------------------
1 | tglite.op.aggregate
2 | ===================
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: aggregate
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.cache.rst:
--------------------------------------------------------------------------------
1 | tglite.op.cache
2 | ===============
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: cache
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.coalesce.rst:
--------------------------------------------------------------------------------
1 | tglite.op.coalesce
2 | ==================
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: coalesce
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.dedup.rst:
--------------------------------------------------------------------------------
1 | tglite.op.dedup
2 | ===============
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: dedup
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.edge_reduce.rst:
--------------------------------------------------------------------------------
1 | tglite.op.edge\_reduce
2 | ======================
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: edge_reduce
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.edge_softmax.rst:
--------------------------------------------------------------------------------
1 | tglite.op.edge\_softmax
2 | =======================
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: edge_softmax
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.edge_view.rst:
--------------------------------------------------------------------------------
1 | tglite.op.edge\_view
2 | ====================
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: edge_view
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.precomputed_times.rst:
--------------------------------------------------------------------------------
1 | tglite.op.precomputed\_times
2 | ============================
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: precomputed_times
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.precomputed_zeros.rst:
--------------------------------------------------------------------------------
1 | tglite.op.precomputed\_zeros
2 | ============================
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: precomputed_zeros
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.preload.rst:
--------------------------------------------------------------------------------
1 | tglite.op.preload
2 | =================
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: preload
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.propagate.rst:
--------------------------------------------------------------------------------
1 | tglite.op.propagate
2 | ===================
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: propagate
--------------------------------------------------------------------------------
/docs/source/generated/tglite.op.src_scatter.rst:
--------------------------------------------------------------------------------
1 | tglite.op.src\_scatter
2 | ======================
3 |
4 | .. currentmodule:: tglite.op
5 |
6 | .. autofunction:: src_scatter
--------------------------------------------------------------------------------
/docs/source/img/blank.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ADAPT-uiuc/tglite/760fa9a3f96663fff50a4e1b92bd7413a90fa4fd/docs/source/img/blank.png
--------------------------------------------------------------------------------
/docs/source/img/colab.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 |
9 |
10 |
13 |
15 |
17 |
20 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/docs/source/img/github.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 |
8 |
15 |
16 |
--------------------------------------------------------------------------------
/docs/source/img/tblock-structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ADAPT-uiuc/tglite/760fa9a3f96663fff50a4e1b92bd7413a90fa4fd/docs/source/img/tblock-structure.png
--------------------------------------------------------------------------------
/docs/source/img/tblock-workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ADAPT-uiuc/tglite/760fa9a3f96663fff50a4e1b92bd7413a90fa4fd/docs/source/img/tblock-workflow.png
--------------------------------------------------------------------------------
/docs/source/img/train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ADAPT-uiuc/tglite/760fa9a3f96663fff50a4e1b92bd7413a90fa4fd/docs/source/img/train.png
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. TGLite documentation master file, created by
2 | sphinx-quickstart on Wed Nov 1 15:04:59 2023.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to TGLite's documentation!
7 | ==================================
8 |
9 | **TGLite** is a lightweight framework that provides core abstractions and building blocks for practitioners and researchers to implement efficient TGNN models. *TGNNs*, or *Temporal Graph Neural Networks*, learn node embeddings for graphs that dynamically change over time by jointly aggregating structural and temporal information from neighboring nodes.
10 |
11 | TGLite employs an abstraction called a :ref:`TBlock ` to represent the temporal graph dependencies when aggregating from neighbors, with explicit support for capturing temporal details like edge timestamps, as well as composable operators and optimizations. Compared to prior art, TGLite can outperform the `TGL `_ framework by up to *3x* in terms of training time.
12 |
13 | .. _train figure:
14 | .. figure:: img/train.png
15 | :alt: End-to-end training epoch time comparison on an Nvidia A100 GPU
16 | :align: center
17 | :figwidth: 85 %
18 |
19 | End-to-end training epoch time comparison on an Nvidia A100 GPU
20 |
21 | Install TGLite
22 | --------------
23 | See :ref:`Getting started ` for instructions on how to install the TGLite binaries. To install from source or for local development, refer to :ref:`Building from source ` and :ref:`Development mode `.
24 |
25 | Tutorials
26 | ---------
27 | We provide a set of tutorials to help you get started with TGLite. These tutorials cover the basics of using TGLite, as well as more advanced topics.
28 |
29 | 0. Quickstart_: A step-by-step guide to train a TGNN model using TGLite.
30 | 1. :ref:`How does TBlock work? `: A tutorial on how to use the TBlock abstraction to implement TGNN models.
31 |
32 | .. _Quickstart: tutorial/quickstart.ipynb
33 |
34 | .. toctree::
35 | :maxdepth: 1
36 | :caption: TGLite
37 | :hidden:
38 | :glob:
39 |
40 | install/index
41 | tutorial/quickstart
42 | tutorial/tblock
43 |
44 | .. toctree::
45 | :maxdepth: 2
46 | :caption: API
47 | :glob:
48 |
49 | api/python/tglite
50 | api/python/tglite.batch
51 | api/python/tglite.block
52 | api/python/tglite.context
53 | api/python/tglite.graph
54 | api/python/tglite.sampler
55 | api/python/tglite.mailbox
56 | api/python/tglite.memory
57 | api/python/tglite.nn
58 | api/python/tglite.op
59 |
60 |
61 | .. note::
62 | This project is under active development.
63 |
64 |
65 | Indices and tables
66 | ==================
67 |
68 | * :ref:`genindex`
69 | * :ref:`modindex`
70 | * :ref:`search`
71 |
--------------------------------------------------------------------------------
/docs/source/install/index.rst:
--------------------------------------------------------------------------------
1 | .. _getting-started:
2 |
3 | Getting Started
4 | ---------------
5 |
6 | TGLite currently only supports PyTorch as a backend.
7 |
8 | Installation with pip
9 | ``````````````````````
10 |
11 | Prerequisites
12 | ^^^^^^^^^^^^^
13 |
14 | .. list-table::
15 | :widths: 25 25 25 25
16 |
17 | * - **python**
18 | - **gcc**
19 | - **torch**
20 | - **torch-scatter**
21 | * - ≥ 3.7
22 | - ≥ 6.1
23 | - ≥ 1.12.1
24 | - ≥ 2.1.0
25 |
26 | .. * python 3.7 or later
27 | .. * gcc 6.1 or later
28 | .. * torch 1.12.1 or later
29 | .. * torch-scatter 2.1.0 or later
30 |
31 | Installation
32 | ^^^^^^^^^^^^
33 | Ensure that at least PyTorch 1.12.1 and torch-scatter 2.1.0 are installed (refer to `PyTorch `_ and `torch-scatter `_ for installation instructions), then simply run
34 |
35 | .. code-block:: console
36 |
37 | $ pip install tglite
38 |
39 | Verification
40 | ^^^^^^^^^^^^
41 | To verify the installation, run the following in Python:
42 |
43 | .. code-block:: python
44 |
45 | import tglite
46 | print(tglite.__version__)
47 |
48 | .. note::
49 | Currently, the library is only tested on Linux. MacOS and Windows support is not guaranteed.
50 |
51 | .. _build-from-source:
52 |
53 | Building from source
54 | `````````````````````
55 | To install the latest TGLite code for testing or development on the core, you will need to build TGLite from source. Here, we show how to build TGLite with Python 3.7, PyTorch 1.12.1 and torch-scatter 2.1.0.
56 |
57 | Create and activate a python environment:
58 |
59 | .. code-block:: console
60 |
61 | $ conda create -n tglite python=3.7
62 | $ conda activate tglite
63 |
64 | Install dependencies that have CUDA versions:
65 |
66 | .. code-block:: console
67 |
68 | $ pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
69 | $ pip install torch-scatter==2.1.0+pt112cu116 -f https://data.pyg.org/whl/torch-1.12.1+cu116.html
70 |
71 | Get the TGLite source:
72 |
73 | .. code-block:: console
74 |
75 | $ git clone https://github.com/ADAPT-uiuc/tglite.git
76 | $ cd tglite
77 |
78 | Then install the package locally:
79 |
80 | .. code-block:: console
81 |
82 | $ python setup.py install
83 |
84 | This will build the C++ extension (which requires C++14 and OpenMP), install
85 | the rest of the dependencies (as listed in `pyproject.toml`), and then install
86 | the `tglite` package.
87 |
88 | .. _development-mode:
89 |
90 | Development Mode
91 | ````````````````
92 |
93 | Development mode allows easily editing the code without having to re-install
94 | the package. However, this only applies to the python code. When editing the
95 | C++ extension code, it needs to be re-compiled again. Use `develop` instead of `install` to use dev mode:
96 |
97 | .. code-block:: console
98 |
99 | $ python setup.py develop
100 |
101 |
102 | Running Tests
103 | ^^^^^^^^^^^^^
104 |
105 | Unit tests are located in `tests` directory. First, install the testing
106 | dependencies specified in `pyproject.toml`. Doing so might overwrite the dev
107 | mode install, so you might need to re-enable dev mode. Then, exercise the tests
108 | using the `pytest` utility.
109 |
110 | .. code-block:: console
111 |
112 | # install test dependencies
113 | $ pip install '.[test]'
114 |
115 | # re-enable dev mode install
116 | $ pip uninstall -y tglite
117 | $ python setup.py develop
118 |
119 | # run with test coverage report
120 | $ pytest --cov=tglite
121 |
122 |
123 | Running Examples
124 | ````````````````
125 | Inside the `examples `_ directory of the repository, several CTDG models have been implemented using `tglite`.
126 | To run these example models, install the additional dependencies and download the datasets:
127 |
128 | .. code-block:: console
129 |
130 | $ cd examples
131 | $ pip install -r requirements.txt # or "conda install -c conda-forge pandas scikit-learn" using conda
132 | $ ./download-data.sh
133 | $ python gen-data-files.py --data wiki-talk
134 |
135 | This will download the datasets inside `examples/data/`, one can also download data to other places.
136 |
137 | Use the scripts in `examples/exp` as a starting point, e.g.:
138 |
139 | .. code-block:: console
140 |
141 | $ ./exp/tgat.sh --data-path . -d wiki --epochs 3
142 |
143 |
144 | Building this document locally
145 | ```````````````````````````````
146 | .. code-block:: console
147 |
148 | # install doc dependencies
149 | $ pip install '.[docs]'
150 |
151 | # build docs
152 | $ cd docs
153 | $ make html
154 |
155 | # launch in browser
156 | $ sh run.sh
157 |
--------------------------------------------------------------------------------
/docs/source/tutorial/tblock.rst:
--------------------------------------------------------------------------------
1 | .. _tutorial-tblock:
2 |
3 | How does TBlock work?
4 | =====================
5 |
6 | Introduction
7 | ------------
8 |
9 | This tutorial provides an overview of `TBlock`, a key component of the TGLite framework. TBlocks, or temporal blocks, capture the message-flow dependencies between target node-time pairs and their temporally sampled neighbors. This tutorial will explain the design choices and features of TBlock to help you understand its usage within the TGLite framework.
10 |
11 | Overview
12 | --------
13 |
14 | TBlocks are motivated by the MFG (Message Flow Graph) objects available in the DGL (Deep Graph Library) but provide additional capabilities for CTDG (Continuous-Time Dynamic Graph) models. The following sections will explain the three key design choices that distinguish TBlocks from MFGs.
15 |
16 | Doubly-Linked List Structure
17 | ----------------------------
18 |
19 | One key distinction of TBlocks is the use of a doubly-linked list structure. This structure explicitly captures the multi-hop neighbor sampling/aggregation relationship that TBlocks are used for. Unlike standalone MFG objects in DGL/TGL, TBlocks maintain links to related blocks, enabling efficient multi-hop aggregation operations.
20 |
21 | Target and Neighbor Information
22 | -------------------------------
23 |
24 | TBlocks primarily focus on target destination nodes and optionally include information about neighbor source nodes. By separating neighbor information as optional, TBlocks allow for easier manipulation of target node information. This flexibility enables optimizations such as deduplication and caching to be applied effectively, as they can be performed on destination nodes before sampling for neighbors.
25 |
26 | Hooks Mechanism for Post-Processing
27 | -----------------------------------
28 |
29 | TBlocks provide a hooks mechanism for running post-processing procedures. These hooks are callable functions that are invoked after computations are performed on the block. The hooks mechanism enables scheduling of transformations on computed output, such as deduplication and preserving output semantics. TGLite runtime automatically handles the execution of registered hooks, simplifying the post-processing step.
30 |
31 | .. _tblock-structure figure:
32 | .. figure:: ../img/tblock-structure.png
33 | :alt: tblock-structure
34 | :align: center
35 | :figwidth: 60 %
36 |
37 | Diagram of the doubly-linked list design and internal structure of a TBlock (each destination is a node-time pair).
38 |
39 | Block Lifecycle and Usage
40 | -------------------------
41 |
42 | The block lifecycle involves creating a TBlock, applying optimizations, sampling neighbors, performing computations, and accessing cached data. The following steps outline the typical usage of a TBlock:
43 |
44 | 1. Create a TBlock using various methods or construct it directly.
45 | 2. Apply optimizations to the block to minimize subgraph size and potential computations.
46 | 3. Sample neighbors of the block to capture message-flow dependencies.
47 | 4. Manipulate the block in-place, register hooks, or cache data.
48 | 5. Use the block for computations and access cached data for specific nodes and edges.
49 |
50 | .. _tblock-workflow figure:
51 | .. figure:: ../img/tblock-workflow.png
52 | :alt: tblock-workflow
53 | :align: center
54 | :figwidth: 60 %
55 |
56 | Typical flow of constructing and using a TBlock object.
57 |
58 | Example Usage
59 | -------------
60 |
61 | Here's an example code snippet demonstrating the usage of TBlock:
62 |
63 | .. code-block:: python
64 |
65 | # Create the head TBlock from a batch data
66 | head = batch.block(self.ctx)
67 |
68 | # Create the next TBlock iteratively
69 | for i in range(self.num_layers):
70 | tail = head if i == 0 else tail.next_block(...)
71 | # Apply optimizations
72 | tail = tg.op.dedup(tail)
73 | tail = tg.op.cache(self.ctx, tail, ...)
74 | # Sample neighbors
75 | tail = self.sampler.sample(tail)
76 |
77 | # Load data
78 | tg.op.preload(head, use_pin=True)
79 | tail.dstdata['h'] = tail.dstfeat()
80 | tail.srcdata['h'] = tail.srcfeat()
81 | # Perform computations
82 | emb = tg.op.aggregate(head, self.compute, key='h')
83 |
84 |
85 |
86 | In this tutorial, you learned about TBlock, a key component of the TGLite framework. TBlocks provide a powerful mechanism for capturing and analyzing message-flow dependencies in a continuous-time dynamic graph. By understanding the design choices and features of TBlock, you can effectively leverage its capabilities within your applications.
87 |
88 | For more details and advanced usage, refer to the :ref:`TBlock `.
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - conda-forge
3 | - defaults
4 | dependencies:
5 | - pip=22.3.1=py37h06a4308_0
6 | - python=3.7.16=h7a1cb2a_0
7 | - pandoc
8 | - pip:
9 | - torch==1.12.1
--------------------------------------------------------------------------------
/examples/apan/apan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tglite as tg
3 |
4 | from torch import nn, Tensor
5 | from tglite._stats import tt
6 |
7 | import sys, os
8 | sys.path.append(os.path.join(os.getcwd(), '..'))
9 | import support
10 |
11 |
class APAN(nn.Module):
    """APAN (Asynchronous Propagation Attention Network) for temporal link prediction.

    Each forward pass: (1) updates node memory via attention over mailbox
    mails, (2) scores positive (and optionally negative) edges from the
    updated memory, (3) creates new mails for the batch edges and
    propagates them to sampled neighbors for later memory updates.
    """

    def __init__(self, ctx: tg.TContext,
                 dim_mem: int, dim_edge: int, dim_time: int,
                 sampler: tg.TSampler, num_heads=2, dropout=0.1):
        super().__init__()
        self.ctx = ctx
        self.dim_edge = dim_edge
        # Mail layout built in create_mails() is [own mem | other mem | efeat],
        # hence dim_msg = 2 * dim_mem + dim_edge.
        self.mem_updater = AttnMemoryUpdater(ctx,
            dim_mem=dim_mem,
            dim_msg=2 * dim_mem + dim_edge,
            dim_time=dim_time,
            num_heads=num_heads,
            dropout=dropout)
        self.sampler = sampler
        self.edge_predictor = support.EdgePredictor(dim_mem)

    def forward(self, batch: tg.TBatch):
        """Return edge scores for the batch: a tensor of positive-pair scores,
        or a (pos_scores, neg_scores) tuple when the batch has negative samples."""
        size = len(batch)
        t_start = tt.start()
        mem = self.mem_updater(batch)
        tt.t_mem_update += tt.elapsed(t_start)  # stats: time spent in memory update

        # Persist updated memory for src/dst nodes only (negatives excluded).
        # NOTE(review): mem[:2 * size] assumes rows are ordered [src | dst | neg]
        # so the first 2*size rows are the non-negative nodes — the same layout
        # split_data() relies on below; confirm against TBatch.
        nodes = batch.nodes(include_negs=False)
        times = batch.times(include_negs=False)
        batch.g.mem.update(nodes, mem[:2 * size], torch.from_numpy(times))

        src, dst, neg = batch.split_data(mem)
        scores = self.edge_predictor(src, dst)
        if batch.neg_nodes is not None:
            scores = (scores, self.edge_predictor(src, neg))
        del src
        del dst
        del neg

        # Build a block over the batch nodes, sample neighbors, and attach
        # detached memory as mails (stored on the graph's storage device).
        blk = tg.TBlock(self.ctx, 0, nodes, times)
        blk = self.sampler.sample(blk)
        mem = mem[:2 * size].detach().to(batch.g.storage_device())
        self.create_mails(batch, blk, mem)
        del mem

        # Asynchronously deliver the mails to sampled neighbors' mailboxes.
        tg.op.propagate(blk, self.send_mails)
        return scores

    def create_mails(self, batch: tg.TBatch, blk: tg.TBlock, mem: Tensor):
        """Attach per-edge mails to blk's dst nodes.

        Each mail is [own mem | counterpart mem | edge feat] (edge features
        included only when dim_edge > 0); src mails are stacked before dst mails.
        """
        size = len(batch)
        mem_src = mem[:size]
        mem_dst = mem[size:]

        if self.dim_edge > 0:
            efeat = batch.g.efeat[batch.eids()]
            src_mail = torch.cat([mem_src, mem_dst, efeat], dim=1)
            dst_mail = torch.cat([mem_dst, mem_src, efeat], dim=1)
        else:
            src_mail = torch.cat([mem_src, mem_dst], dim=1)
            dst_mail = torch.cat([mem_dst, mem_src], dim=1)

        blk.dstdata['mail'] = torch.cat([src_mail, dst_mail], dim=0)

    def send_mails(self, blk: tg.TBlock):
        """Callback for tg.op.propagate: mean-scatter each dst node's mail and
        timestamp to its sampled source neighbors and store them in the mailbox."""
        sdev = blk.g.storage_device()
        if blk.num_edges() == 0:  # no sampled neighbors -> nothing to deliver
            return

        mail = blk.dstdata['mail'][blk.dstindex]
        mail = tg.op.src_scatter(blk, mail, op='mean')

        mail_ts = torch.from_numpy(blk.dsttimes)
        mail_ts = mail_ts.to(sdev)[blk.dstindex]
        mail_ts = tg.op.src_scatter(blk, mail_ts, op='mean')

        blk.g.mailbox.store(blk.uniq_src()[0], mail, mail_ts)
83 |
84 |
class AttnMemoryUpdater(nn.Module):
    """Updates node memory by attending over each node's stored mails.

    Queries come from current memory; keys/values come from mails concatenated
    with a time encoding of the elapsed time since each mail was stored.
    """

    def __init__(self, ctx: tg.TContext,
                 dim_mem: int, dim_msg: int, dim_time: int,
                 num_heads=2, dropout=0.1):
        super().__init__()
        assert (dim_mem % num_heads == 0)  # heads must evenly split the memory dim
        self.ctx = ctx
        self.num_heads = num_heads
        self.time_encode = tg.nn.TimeEncode(dim_time)
        self.w_q = nn.Linear(dim_mem, dim_mem)
        self.w_k = nn.Linear(dim_msg + dim_time, dim_mem)
        self.w_v = nn.Linear(dim_msg + dim_time, dim_mem)
        self.mlp = nn.Linear(dim_mem, dim_mem)
        self.attn_act = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(dim_mem)

    def forward(self, batch: tg.TBatch) -> Tensor:
        """Return updated memory embeddings for all batch nodes,
        shape (len(batch.nodes()), dim_mem)."""
        sdev = batch.g.storage_device()
        cdev = batch.g.compute_device()
        nodes = batch.nodes()
        times = batch.times()

        size = len(nodes)
        mem = batch.g.mem
        mailbox = batch.g.mailbox
        mail_size = mailbox.dims()[0]  # mail slots per node

        # Query from current memory: (size, num_heads, dim_mem/num_heads).
        mem_data = mem.data[nodes].to(cdev)
        Q = self.w_q(mem_data).reshape(size, self.num_heads, -1)

        mail = mailbox.mail[nodes].to(cdev)
        mail = mail.reshape(size, mail_size, -1)
        # Elapsed time since each mail was stored, encoded per mail slot
        # (precomputed_times may reuse cached encodings when enabled on ctx).
        time_feat = torch.from_numpy(times).to(sdev).reshape(-1, 1)
        time_feat = (time_feat - mailbox.time[nodes]).to(cdev)
        time_feat = tg.op.precomputed_times(self.ctx, 0, self.time_encode, time_feat)
        time_feat = time_feat.reshape(size, mail_size, -1)
        mail = torch.cat([mail, time_feat], dim=2)
        del time_feat

        # Keys/values: (size, mail_size, num_heads, dim_mem/num_heads).
        K = self.w_k(mail).reshape(size, mail_size, self.num_heads, -1)
        V = self.w_v(mail).reshape(size, mail_size, self.num_heads, -1)
        del mail

        # Dot-product scores per head: (size, mail_size, num_heads).
        attn = torch.sum(Q[:, None, :, :] * K, dim=3)
        del Q

        attn = self.attn_act(attn)
        attn = nn.functional.softmax(attn, dim=1)  # normalize across mail slots
        attn = self.dropout(attn)
        out = torch.sum(attn[:, :, :, None] * V, dim=1)
        del attn

        out = out.reshape(size, -1)
        out = out + mem_data  # residual connection to previous memory
        del mem_data

        out = self.layer_norm(out)
        out = self.mlp(out)
        out = self.dropout(out)
        out = nn.functional.relu(out)

        return out
--------------------------------------------------------------------------------
/examples/apan/train.py:
--------------------------------------------------------------------------------
"""Training entry point for the APAN example model.

Parses CLI arguments, loads the temporal graph and features, builds the
APAN model with a neighbor sampler, and runs link-prediction training.
"""

import argparse
import os
import numpy as np
import torch
import tglite as tg

from apan import APAN
import support

### arguments

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', type=str, required=True, help='dataset name')
parser.add_argument('--data-path', type=str, default='', help='path to data folder')
parser.add_argument('--prefix', type=str, default='', help='name for saving trained model')
parser.add_argument('--gpu', type=int, default=0, help='gpu device to use (or -1 for cpu)')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs (default: 100)')
parser.add_argument('--bsize', type=int, default=200, help='batch size (default: 200)')
# fix: numeric args parsed as their real types so argparse rejects bad values
# early and scientific notation works (previously int('1e4') would raise).
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate (default: 1e-4)')
parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate (default: 0.1)')
# parser.add_argument('--n-layers', type=int, default=2, help='number of layers (default: 2)')
parser.add_argument('--n-heads', type=int, default=2, help='number of attention heads (default: 2)')
parser.add_argument('--n-nbrs', type=int, default=20, help='number of neighbors to sample (default: 20)')
parser.add_argument('--n-mail', type=int, default=10, help='max number of mails (default: 10)')
parser.add_argument('--dim-time', type=int, default=100, help='dimension of time features (default: 100)')
parser.add_argument('--dim-embed', type=int, default=100, help='dimension of embeddings (default: 100)')
parser.add_argument('--seed', type=int, default=-1, help='random seed to use')
parser.add_argument('--move', action='store_true', help='move data to device')
parser.add_argument('--n-threads', type=int, default=32, help='number of threads for sampler (default: 32)')
parser.add_argument('--sampling', type=str, default='recent', choices=['recent', 'uniform'], help='sampling strategy (default: recent)')
parser.add_argument('--opt-time', action='store_true', help='enable precomputing time encodings')
parser.add_argument('--time-window', type=float, default=1e4, help='time window to precompute (default: 1e4)')
parser.add_argument('--opt-all', action='store_true', help='enable all available optimizations')
args = parser.parse_args()
print(args)

device = support.make_device(args.gpu)
model_path = support.make_model_path('apan', args.prefix, args.data)
model_mem_path = support.make_model_mem_path('apan', args.prefix, args.data)
if args.seed >= 0:
    support.set_seed(args.seed)

DATA: str = args.data
DATA_PATH: str = args.data_path
EPOCHS: int = args.epochs
BATCH_SIZE: int = args.bsize
LEARN_RATE: float = args.lr        # already a float thanks to type=float
DROPOUT: float = args.dropout
# N_LAYERS: int = args.n_layers
N_HEADS: int = args.n_heads
N_NBRS: int = args.n_nbrs
N_MAIL: int = args.n_mail
DIM_TIME: int = args.dim_time
DIM_EMBED: int = args.dim_embed
N_THREADS: int = args.n_threads
SAMPLING: str = args.sampling
OPT_TIME: bool = args.opt_time or args.opt_all
TIME_WINDOW: int = int(args.time_window)  # float CLI value truncated to int


### load data

g = support.load_graph(os.path.join(DATA_PATH, f'data/{DATA}/edges.csv'))
support.load_feats(g, DATA, DATA_PATH)
dim_efeat = 0 if g.efeat is None else g.efeat.shape[1]
g.nfeat = None  # APAN uses memory embeddings instead of node features
dim_mem = DIM_EMBED

# Mail message layout is [src mem | dst mem | efeat] -> 2 * dim_mem + dim_efeat.
g.mailbox = tg.Mailbox(g.num_nodes(), N_MAIL, 2 * dim_mem + dim_efeat)
g.mem = tg.Memory(g.num_nodes(), dim_mem)

g.set_compute(device)
if args.move:
    g.move_data(device)

ctx = tg.TContext(g)
ctx.need_sampling(True)
ctx.enable_time_precompute(OPT_TIME)
ctx.set_time_window(TIME_WINDOW)


### model

sampler = tg.TSampler(N_NBRS, strategy=SAMPLING, num_threads=N_THREADS)
model = APAN(ctx,
    dim_mem=dim_mem,
    dim_edge=dim_efeat,
    dim_time=DIM_TIME,
    sampler=sampler,
    num_heads=N_HEADS,
    dropout=DROPOUT)
model = model.to(device)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE)


### training

# 70% train / 15% validation split by edge index; the rest is the test set.
train_end, val_end = support.data_split(g.num_edges(), 0.7, 0.15)
neg_sampler = lambda size: np.random.randint(0, g.num_nodes(), size)

trainer = support.LinkPredTrainer(
    ctx, model, criterion, optimizer, neg_sampler,
    EPOCHS, BATCH_SIZE, train_end, val_end,
    model_path, model_mem_path)

trainer.train()
trainer.test()
--------------------------------------------------------------------------------
/examples/download-data.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Download the benchmark datasets used by the example models into ./data/<name>/.
# Most datasets come from the DGL/TGL S3 bucket; wiki-talk comes from SNAP.
# The MAG downloads are commented out because the dataset is very large.

wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_train.npz
wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_full.npz
wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/node_features.pt
wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/labels.csv
wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/ext_full.npz
wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edges.csv
wget -P ./data/gdelt https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edge_features.pt
wget -P ./data/lastfm https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/edges.csv
wget -P ./data/lastfm https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/ext_full.npz
wget -P ./data/lastfm https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/int_full.npz
wget -P ./data/lastfm https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/int_train.npz
# wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/int_train.npz
# wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/labels.csv
# wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/int_full.npz
# wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/ext_full.npz
# wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/edges.csv
# wget -P ./data/mag https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/node_features.pt
wget -P ./data/mooc https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/edges.csv
wget -P ./data/mooc https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/ext_full.npz
wget -P ./data/mooc https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/int_full.npz
wget -P ./data/mooc https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/int_train.npz
wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/edge_features.pt
wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/edges.csv
wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/ext_full.npz
wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/int_full.npz
wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/int_train.npz
wget -P ./data/reddit https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/labels.csv
wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/edge_features.pt
wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/edges.csv
wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/ext_full.npz
wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/int_full.npz
wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/int_train.npz
wget -P ./data/wiki https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/labels.csv

# wiki-talk ships gzipped; decompress it in place (gen-data-files.py converts
# the raw .txt afterwards).
wget -P ./data/wiki-talk http://snap.stanford.edu/data/wiki-talk-temporal.txt.gz
cd ./data/wiki-talk && gzip -d wiki-talk-temporal.txt.gz
--------------------------------------------------------------------------------
/examples/exp/apan-gdelt.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)"
4 | cd "$examples_dir"
5 | export PYTHONPATH="$examples_dir"
6 |
7 | python apan/train.py --seed 0 --prefix exp \
8 | --epochs 3 --bsize 4000 --n-threads 64 \
9 | --n-nbrs 10 --n-mail 10 \
10 | --sampling recent "$@"
11 |
--------------------------------------------------------------------------------
/examples/exp/apan.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)"
4 | cd "$examples_dir"
5 | export PYTHONPATH="$examples_dir"
6 |
7 | python apan/train.py --seed 0 --prefix exp \
8 | --epochs 10 --bsize 600 --n-threads 64 \
9 | --n-nbrs 10 --n-mail 10 \
10 | --sampling recent "$@"
11 |
--------------------------------------------------------------------------------
/examples/exp/jodie-gdelt.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)"
4 | cd "$examples_dir"
5 | export PYTHONPATH="$examples_dir"
6 |
7 | python jodie/train.py --seed 0 --prefix exp \
8 | --epochs 3 --bsize 4000 "$@"
9 |
--------------------------------------------------------------------------------
/examples/exp/jodie.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)"
4 | cd "$examples_dir"
5 | export PYTHONPATH="$examples_dir"
6 |
7 | python jodie/train.py --seed 0 --prefix exp \
8 | --epochs 10 --bsize 600 "$@"
9 |
--------------------------------------------------------------------------------
/examples/exp/tgat-gdelt.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)"
4 | cd "$examples_dir"
5 | export PYTHONPATH="$examples_dir"
6 |
7 | python tgat/train.py --seed 0 --prefix exp \
8 | --epochs 3 --bsize 4000 --n-threads 64 \
9 | --n-layers 2 --n-heads 2 --n-nbrs 10 \
10 | --sampling recent "$@"
11 |
--------------------------------------------------------------------------------
/examples/exp/tgat.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run the TGAT example with the experiment hyperparameters.

# Resolve the examples/ directory (the parent of this script) and run from
# there so relative data/model paths and Python imports resolve correctly.
examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)"
cd "$examples_dir"
export PYTHONPATH="$examples_dir"

# Any extra CLI flags passed to this script are forwarded via "$@".
python tgat/train.py --seed 0 --prefix exp \
    --epochs 10 --bsize 600 --n-threads 64 \
    --n-layers 2 --n-heads 2 --n-nbrs 10 \
    --sampling recent "$@"
--------------------------------------------------------------------------------
/examples/exp/tgn-gdelt.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run the TGN example sized for the (large) GDELT dataset.

# Resolve the examples/ directory (the parent of this script) and run from
# there so relative data/model paths and Python imports resolve correctly.
examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)"
cd "$examples_dir"
export PYTHONPATH="$examples_dir"

# Any extra CLI flags passed to this script are forwarded via "$@".
python tgn/train.py --seed 0 --prefix exp \
    --epochs 3 --bsize 4000 --n-threads 64 \
    --n-layers 2 --n-heads 2 --n-nbrs 10 \
    --sampling recent "$@"
--------------------------------------------------------------------------------
/examples/exp/tgn.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run the TGN example with the experiment hyperparameters.

# Resolve the examples/ directory (the parent of this script) and run from
# there so relative data/model paths and Python imports resolve correctly.
examples_dir="$(cd "$(dirname "$0")"; cd ..; pwd)"
cd "$examples_dir"
export PYTHONPATH="$examples_dir"

# Any extra CLI flags passed to this script are forwarded via "$@".
python tgn/train.py --seed 0 --prefix exp \
    --epochs 10 --bsize 600 --n-threads 64 \
    --n-layers 2 --n-heads 2 --n-nbrs 10 \
    --sampling recent "$@"
--------------------------------------------------------------------------------
/examples/gen-data-files.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import pandas as pd
4 |
5 |
# CLI: select which raw dataset to convert into the examples' CSV format.
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', type=str, required=True, help='dataset name')
args = parser.parse_args()
9 |
10 |
def gen_edges_wiki_talk():
    """Convert the raw wiki-talk temporal edge list into the edges.csv layout
    used by the examples, with a chronological 70/15/15 train/val/test split.

    Reads data/wiki-talk/wiki-talk-temporal.txt (space-separated src dst time)
    and writes data/wiki-talk/edges.csv.
    """
    df = pd.read_csv('data/wiki-talk/wiki-talk-temporal.txt',
        sep=' ', header=None, names=['src', 'dst', 'time'],
        dtype={'src': np.int32, 'dst': np.int32, 'time': np.float32})

    # Node ids are assumed to be 0-based, so count is max id + 1.
    num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
    num_edges = df.shape[0]
    train_end = int(np.ceil(num_edges * 0.70))
    valid_end = int(np.ceil(num_edges * 0.85))
    print('num_nodes:', num_nodes)
    print('num_edges:', num_edges)
    print('train_end:', train_end)
    print('valid_end:', valid_end)

    # int_roll stays all zeros; ext_roll marks the split of each edge:
    # 0 = train, 1 = validation, 2 = test (edges are in time order).
    df['int_roll'] = np.zeros(num_edges, dtype=np.int32)
    ext_roll = np.zeros(num_edges, dtype=np.int32)
    ext_roll[train_end:] = 1
    ext_roll[valid_end:] = 2
    df['ext_roll'] = ext_roll

    # NOTE(review): index=False is NOT passed, so the row index is written as
    # an extra unnamed column — downstream loaders read columns by name, so
    # this appears intentional/harmless; confirm before changing.
    df.to_csv('data/wiki-talk/edges.csv')
32 |
33 |
# Dispatch on the requested dataset; only wiki-talk needs conversion here.
if args.data == 'wiki-talk':
    gen_edges_wiki_talk()
else:
    print('not handling dataset:', args.data)
38 |
--------------------------------------------------------------------------------
/examples/jodie/jodie.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import tglite as tg
5 |
6 | from typing import Tuple
7 | from torch import nn, Tensor
8 | from tglite._stats import tt
9 |
10 | import sys, os
11 | sys.path.append(os.path.join(os.getcwd(), '..'))
12 | import support
13 |
14 |
class NormalLinear(nn.Linear):
    """Linear layer whose parameters are drawn from N(0, 1/sqrt(fan_in))
    instead of the default nn.Linear initialization."""

    def reset_parameters(self):
        # fan_in is the second dimension of the weight matrix (in_features).
        stdv = 1. / math.sqrt(self.weight.size(1))
        nn.init.normal_(self.weight, 0, stdv)
        if self.bias is not None:
            nn.init.normal_(self.bias, 0, stdv)
21 |
22 |
class JODIE(nn.Module):
    """Temporal link-prediction model that keeps one embedding per node in the
    graph's memory, refreshes it with an RNN cell from the node's mailbox
    message, and projects it forward in time before scoring candidate edges.
    """

    def __init__(self, ctx: tg.TContext, dim_embed: int, dim_node: int, dim_edge: int, dim_time: int):
        """
        :param ctx: tglite context shared with the training driver
        :param dim_embed: dimension of the dynamic node embeddings
        :param dim_node: dimension of raw node features
        :param dim_edge: dimension of raw edge features (0 if none)
        :param dim_time: dimension of the time encoding
        """
        super().__init__()
        self.ctx = ctx
        self.dim_embed = dim_embed
        self.dim_node = dim_node
        self.dim_edge = dim_edge
        self.dim_time = dim_time

        # RNN input is [mailbox message (memory + edge feats) | time encoding].
        dim_input = dim_embed + dim_edge + dim_time
        self.updater = nn.RNNCell(dim_input, dim_embed)
        self.time_encode = tg.nn.TimeEncode(dim_time)
        # Learns the time-scaling factor used by project_embed().
        self.time_linear = NormalLinear(1, dim_embed)

        # Only needed to map raw node features into the embedding space
        # when their dimensions differ.
        if dim_node != dim_embed:
            self.node_linear = nn.Linear(dim_node, dim_embed)
        self.norm = nn.LayerNorm(dim_embed)
        self.edge_predictor = support.EdgePredictor(dim_embed)

    def forward(self, batch: tg.TBatch):
        """Process one batch: refresh node embeddings from pending messages,
        score edges, and store new raw messages for the next interaction.

        Returns the positive-edge scores, or a (pos, neg) pair when the batch
        carries negative nodes.
        """
        size = len(batch)
        nodes = batch.nodes()

        t_start = tt.start()
        embed, embed_ts = self.update_embed(batch, nodes)
        tt.t_mem_update += tt.elapsed(t_start)
        embed = self.normalize_embed(batch, nodes, embed)
        # Persist updated embeddings only for the positive src/dst nodes
        # (first 2*size entries); negative-sample nodes are transient.
        batch.g.mem.update(batch.nodes(include_negs=False), embed[:2 * size], embed_ts[:2 * size])

        embed = self.project_embed(batch, embed, embed_ts)
        scores = self.edge_predictor(embed[:size], embed[size:2 * size])
        if batch.neg_nodes is not None:
            scores = (scores, self.edge_predictor(embed[:size], embed[2 * size: ]))
        del embed_ts
        del embed

        self.save_raw_msgs(batch)
        return scores

    def update_embed(self, batch: tg.TBatch, nodes: np.ndarray) -> Tuple[Tensor, Tensor]:
        """Run the RNN cell over each node's pending mailbox message.

        Returns the updated embeddings and their timestamps (the mail times),
        both on the compute device.
        """
        cdev = batch.g.compute_device()

        # Encode the time elapsed since each node's embedding was last updated.
        embed_time = batch.g.mem.time[nodes]
        mail_ts = batch.g.mailbox.time[nodes]
        time_feat = (mail_ts - embed_time).to(cdev)
        time_feat = self.time_encode(time_feat.squeeze())
        input = batch.g.mailbox.mail[nodes].to(cdev)
        input = torch.cat([input, time_feat], dim=1)

        embed = batch.g.mem.data[nodes].to(cdev)
        embed = self.updater(input, embed)

        return embed, mail_ts.to(cdev)

    def normalize_embed(self, batch: tg.TBatch, nodes: np.ndarray, embed: Tensor) -> Tensor:
        """Add (projected) raw node features to the embeddings and layer-norm."""
        nfeat = batch.g.nfeat[nodes].to(embed.device)
        if self.dim_node != self.dim_embed:
            embed = embed + self.node_linear(nfeat)
        else:
            embed = embed + nfeat
        return self.norm(embed)

    def project_embed(self, batch: tg.TBatch, embed: Tensor, embed_ts: Tensor) -> Tensor:
        """Project embeddings forward to the batch times via a learned,
        elapsed-time-dependent scaling."""
        times = torch.from_numpy(batch.times()).to(embed_ts.device)
        delta = times - embed_ts
        # NOTE(review): the +1 presumably guards against division by zero at
        # time 0 — confirm timestamps are non-negative.
        time_diff = (delta / (times + 1)).to(embed.device)
        return embed * (1 + self.time_linear(time_diff.reshape(-1, 1)))

    def save_raw_msgs(self, batch: tg.TBatch):
        """Store raw messages (current memory state, plus edge features when
        present) and their timestamps into destination-node mailboxes."""
        sdev = batch.g.storage_device()
        blk = batch.block_adj(self.ctx)

        adj_nodes = torch.from_numpy(blk.srcnodes).long().to(sdev)
        mail = batch.g.mem.data[adj_nodes]
        if self.dim_edge > 0:
            eids = torch.from_numpy(blk.eid).long().to(sdev)
            mail = torch.cat([mail, batch.g.efeat[eids]], dim=1)
        mail_ts = torch.from_numpy(blk.ets).to(sdev)
        batch.g.mailbox.store(blk.dstnodes, mail, mail_ts)
--------------------------------------------------------------------------------
/examples/jodie/train.py:
--------------------------------------------------------------------------------
"""Training entry point for the JODIE example (temporal link prediction)."""

import argparse
import os
import torch
import numpy as np
import tglite as tg

from jodie import JODIE
import support


### arguments

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', type=str, required=True, help='dataset name')
parser.add_argument('--data-path', type=str, default='', help='path to data folder')
parser.add_argument('--prefix', type=str, default='', help='name for saving trained model')
parser.add_argument('--gpu', type=int, default=0, help='gpu device to use (or -1 for cpu)')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs (default: 100)')
parser.add_argument('--bsize', type=int, default=200, help='batch size (default: 200)')
# Fix: parse the learning rate as a float directly instead of type=str.
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate (default: 1e-4)')
parser.add_argument('--dim-time', type=int, default=100, help='dimension of time features (default: 100)')
parser.add_argument('--dim-embed', type=int, default=100, help='dimension of embeddings (default: 100)')
parser.add_argument('--seed', type=int, default=-1, help='random seed to use')
parser.add_argument('--move', action='store_true', help='move data to device')
args = parser.parse_args()
print(args)

device = support.make_device(args.gpu)
model_path = support.make_model_path('jodie', args.prefix, args.data)
model_mem_path = support.make_model_mem_path('jodie', args.prefix, args.data)
if args.seed >= 0:
    support.set_seed(args.seed)

DATA: str = args.data
DATA_PATH: str = args.data_path
EPOCHS: int = args.epochs
BATCH_SIZE: int = args.bsize
LEARN_RATE: float = float(args.lr)
DIM_TIME: int = args.dim_time
DIM_EMBED: int = args.dim_embed


### load data

g = support.load_graph(os.path.join(DATA_PATH, f'data/{DATA}/edges.csv'))
support.load_feats(g, DATA, DATA_PATH)
dim_efeat = 0 if g.efeat is None else g.efeat.shape[1]
dim_nfeat = g.nfeat.shape[1]

# JODIE keeps one embedding per node in memory, plus a single-slot mailbox of
# raw messages (memory state concatenated with edge features, when present).
g.mem = tg.Memory(g.num_nodes(), DIM_EMBED)
g.mailbox = tg.Mailbox(g.num_nodes(), 1, DIM_EMBED + dim_efeat)

g.set_compute(device)
if args.move:
    g.move_data(device)

# Fix: create the TContext exactly once so the model and the trainer share the
# same context. Previously a second TContext was constructed for the trainer,
# leaving the model holding a separate instance that never saw the trainer's
# train()/eval() mode changes.
ctx = tg.TContext(g)


### model

model = JODIE(ctx,
    dim_embed=DIM_EMBED,
    dim_node=dim_nfeat,
    dim_edge=dim_efeat,
    dim_time=DIM_TIME)
model = model.to(device)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE)


### training

# Chronological 70/15/15 split; negatives are sampled uniformly over nodes.
train_end, val_end = support.data_split(g.num_edges(), 0.7, 0.15)
neg_sampler = lambda size: np.random.randint(0, g.num_nodes(), size)

trainer = support.LinkPredTrainer(
    ctx, model, criterion, optimizer, neg_sampler,
    EPOCHS, BATCH_SIZE, train_end, val_end,
    model_path, model_mem_path)

trainer.train()
trainer.test()
--------------------------------------------------------------------------------
/examples/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==1.3.5
2 | scikit-learn==1.0.2
3 |
--------------------------------------------------------------------------------
/examples/support.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | import os
4 | import torch
5 | import numpy as np
6 | import pandas as pd
7 | from torch import nn, Tensor
8 | from pathlib import Path
9 | from sklearn.metrics import average_precision_score, roc_auc_score
10 | from typing import Callable, Optional, Tuple, Union
11 |
12 | import tglite as tg
13 | from tglite._stats import tt
14 |
15 |
def set_seed(seed: int):
    """Seed every RNG the examples rely on (python, numpy, torch, torch.cuda)."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
21 |
22 |
def make_device(gpu: int) -> torch.device:
    """Return the torch device for the given GPU index; negative means CPU."""
    if gpu < 0:
        return torch.device('cpu')
    return torch.device(f'cuda:{gpu}')
25 |
26 |
def make_model_path(model: str, prefix: str, data: str) -> str:
    """Return the checkpoint path for this model/dataset, creating the
    'models/{model}' directory if needed.

    With a non-empty prefix the file is '{prefix}-{data}.pt'; otherwise the
    current timestamp is appended to keep the name unique."""
    Path(f'models/{model}').mkdir(parents=True, exist_ok=True)
    name = f'{prefix}-{data}.pt' if prefix else f'{data}-{time.time()}.pt'
    return f'models/{model}/{name}'
35 |
36 |
def make_model_mem_path(model: str, prefix: str, data: str) -> str:
    """Return the path for saving graph-memory state alongside a checkpoint,
    creating the 'models/{model}' directory if needed. Mirrors
    make_model_path() but with a '-mem' marker in the filename."""
    Path(f'models/{model}').mkdir(parents=True, exist_ok=True)
    if prefix:
        name = f'{prefix}-{data}-mem.pt'
    else:
        name = f'{data}-mem-{time.time()}.pt'
    return f'models/{model}/{name}'
43 |
44 |
def load_graph(path: Union[str, Path]) -> tg.TGraph:
    """Build a TGraph from a CSV file at `path`.

    The CSV must provide 'src', 'dst' and 'time' columns; edges are stored as
    an (N, 2) int32 array and timestamps as float32."""
    frame = pd.read_csv(str(path))
    etime = frame['time'].to_numpy().astype(np.float32)
    edges = np.stack(
        [frame['src'].to_numpy().astype(np.int32),
         frame['dst'].to_numpy().astype(np.int32)],
        axis=1)
    # Drop the dataframe promptly to keep peak memory down on large datasets.
    del frame

    g = tg.TGraph(edges, etime)
    print('num edges:', g.num_edges())
    print('num nodes:', g.num_nodes())
    return g
63 |
64 |
def load_feats(g: tg.TGraph, d: str, data_path: str=''):
    """
    Load edge and node features into g from data/{d}/edge_features.pt and
    data/{d}/node_features.pt. When a file is missing: random edge features
    are generated for 'mooc', 'lastfm' and 'wiki-talk'; random node features
    are generated for 'wiki', 'mooc', 'reddit', 'lastfm' and 'wiki-talk';
    otherwise the corresponding features are left as None.
    """
    edge_feats = None
    node_feats = None

    if Path(os.path.join(data_path, f'data/{d}/edge_features.pt')).exists():
        edge_feats = torch.load(os.path.join(data_path, f'data/{d}/edge_features.pt'))
        edge_feats = edge_feats.type(torch.float32)
    elif d in ['mooc', 'lastfm', 'wiki-talk']:
        edge_feats = torch.randn(g.num_edges(), 128, dtype=torch.float32)

    if Path(os.path.join(data_path, f'data/{d}/node_features.pt')).exists():
        node_feats = torch.load(os.path.join(data_path, f'data/{d}/node_features.pt'))
        node_feats = node_feats.type(torch.float32)
    elif d in ['wiki', 'mooc', 'reddit', 'lastfm', 'wiki-talk']:
        # NOTE(review): the random node features reuse edge_feats.shape[1];
        # for 'wiki'/'reddit' this assumes the edge-features file exists —
        # if it doesn't, edge_feats is None here and this line raises. Confirm
        # whether that is intended.
        node_feats = torch.randn(g.num_nodes(), edge_feats.shape[1], dtype=torch.float32)

    print('edge feat:', None if edge_feats is None else edge_feats.shape)
    print('node feat:', None if node_feats is None else node_feats.shape)
    g.efeat = edge_feats
    g.nfeat = node_feats
91 |
92 |
def data_split(num_samples: int, train_percent: float, val_percent: float) -> Tuple[int, int]:
    """Return (train_end, val_end): cut indices for a chronological split of
    num_samples items into train / validation / remainder-as-test."""
    def cut(fraction: float) -> int:
        # Round up so small datasets still get at least one sample per split.
        return int(np.ceil(num_samples * fraction))
    return cut(train_percent), cut(train_percent + val_percent)
97 |
98 |
class EdgePredictor(nn.Module):
    """Two-tower MLP that scores a candidate edge from its two endpoint
    embeddings. Returns unnormalized logits (apply sigmoid for probability)."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        # Submodule names are kept stable: they are part of the state_dict
        # keys used when saving/loading checkpoints.
        self.src_fc = nn.Linear(dim, dim)
        self.dst_fc = nn.Linear(dim, dim)
        self.out_fc = nn.Linear(dim, 1)
        self.act = nn.ReLU()

    def forward(self, src: Tensor, dst: Tensor) -> Tensor:
        """Score each (src, dst) pair; shape (batch, 1)."""
        hidden = self.act(self.src_fc(src) + self.dst_fc(dst))
        return self.out_fc(hidden)
113 |
114 |
class LinkPredTrainer(object):
    """Drives training, validation, and testing of a temporal link-prediction
    model.

    The model is called as ``model(batch)`` and must return a pair of scores
    ``(pred_pos, pred_neg)`` once ``batch.neg_nodes`` has been filled in by
    the negative sampler. The best checkpoint (by validation AP) is written to
    ``model_path``; when the graph uses a memory module, its state is saved to
    ``model_mem_path`` alongside it.
    """

    def __init__(self, ctx: tg.TContext, model: nn.Module,
                 criterion: nn.Module, optimizer: torch.optim.Optimizer,
                 neg_sampler: Callable, epochs: int, bsize: int,
                 train_end: int, val_end: int,
                 model_path: str, model_mem_path: Optional[str]):
        self.ctx = ctx
        self.g = ctx.graph
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        # neg_sampler: maps a batch size to an array of negative node ids.
        self.neg_sampler = neg_sampler
        self.epochs = epochs
        self.bsize = bsize
        # Edge-index boundaries of the chronological train/validation splits.
        self.train_end = train_end
        self.val_end = val_end
        self.model_path = model_path
        self.model_mem_path = model_mem_path

    def train(self):
        """Train for the configured number of epochs, validating after each
        one and checkpointing the best model by validation AP."""
        tt.csv_open('out-stats.csv')
        tt.csv_write_header()
        best_epoch = 0
        best_ap = 0
        for e in range(self.epochs):
            print(f'epoch {e}:')
            # Fix: only synchronize when CUDA is present, so CPU-only runs
            # (--gpu -1) no longer crash on torch.cuda.synchronize().
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            t_epoch = tt.start()

            self.ctx.train()
            self.model.train()
            # Memory/mailbox state from the previous epoch is stale.
            if self.g.mem is not None:
                self.g.mem.reset()
            if self.g.mailbox is not None:
                self.g.mailbox.reset()

            epoch_loss = 0.0
            t_loop = tt.start()
            for batch in tg.iter_edges(self.g, size=self.bsize, end=self.train_end):
                t_start = tt.start()
                batch.neg_nodes = self.neg_sampler(len(batch))
                tt.t_prep_batch += tt.elapsed(t_start)

                t_start = tt.start()
                self.optimizer.zero_grad()
                pred_pos, pred_neg = self.model(batch)
                tt.t_forward += tt.elapsed(t_start)

                t_start = tt.start()
                # Positive edges are labeled 1, sampled negatives 0.
                loss = self.criterion(pred_pos, torch.ones_like(pred_pos))
                loss += self.criterion(pred_neg, torch.zeros_like(pred_neg))
                epoch_loss += float(loss)
                loss.backward()
                self.optimizer.step()
                tt.t_backward += tt.elapsed(t_start)
            tt.t_loop = tt.elapsed(t_loop)

            t_eval = tt.start()
            ap, auc = self.eval(start_idx=self.train_end, end_idx=self.val_end)
            tt.t_eval = tt.elapsed(t_eval)

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            tt.t_epoch = tt.elapsed(t_epoch)
            if e == 0 or ap > best_ap:
                best_epoch = e
                best_ap = ap
                torch.save(self.model.state_dict(), self.model_path)
                if self.g.mem is not None:
                    torch.save(self.g.mem.backup(), self.model_mem_path)
            print(' loss:{:.4f} val ap:{:.4f} val auc:{:.4f}'.format(epoch_loss, ap, auc))
            tt.csv_write_line(epoch=e)
            tt.print_epoch()
            tt.reset_epoch()
        tt.csv_close()
        print('best model at epoch {}'.format(best_epoch))

    @torch.no_grad()
    def eval(self, start_idx: int, end_idx: Optional[int] = None):
        """Evaluate on edges [start_idx, end_idx); returns (mean AP, mean AUC)
        averaged over batches. end_idx=None means to the end of the graph."""
        self.ctx.eval()
        self.model.eval()
        val_aps = []
        val_auc = []
        for batch in tg.iter_edges(self.g, size=self.bsize, start=start_idx, end=end_idx):
            size = len(batch)
            batch.neg_nodes = self.neg_sampler(size)
            prob_pos, prob_neg = self.model(batch)
            prob_pos = prob_pos.cpu()
            prob_neg = prob_neg.cpu()
            pred_score = torch.cat([prob_pos, prob_neg], dim=0).sigmoid()
            true_label = torch.cat([torch.ones(size), torch.zeros(size)])
            val_aps.append(average_precision_score(true_label, pred_score))
            val_auc.append(roc_auc_score(true_label, pred_score))
        return np.mean(val_aps), np.mean(val_auc)

    def test(self):
        """Reload the best checkpoint (and memory state, if any) and report
        AP/AUC on the test split (everything after val_end)."""
        print('loading saved checkpoint and testing model...')
        self.model.load_state_dict(torch.load(self.model_path))
        if self.g.mem is not None:
            self.g.mem.restore(torch.load(self.model_mem_path))
        t_test = tt.start()
        ap, auc = self.eval(start_idx=self.val_end)
        t_test = tt.elapsed(t_test)
        print(' test time:{:.2f}s AP:{:.4f} AUC:{:.4f}'.format(t_test, ap, auc))
218 |
--------------------------------------------------------------------------------
/examples/tgat/TGAT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "id": "F4O-2MjGfOMr"
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import torch\n",
12 | "import os\n",
13 | "import numpy as np\n",
14 | "import tglite as tg\n",
15 | "\n",
16 | "from tgat import TGAT\n",
17 | "import support"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "metadata": {
24 | "id": "tH_6K5u3iQhy"
25 | },
26 | "outputs": [],
27 | "source": [
28 | "DATA: str = 'wiki' # 'wiki', 'reddit', 'mooc', 'mag', 'lastfm', 'gdelt', 'wiki-talk'\n",
29 | "DATA_PATH: str = '/shared'\n",
30 | "EPOCHS: int = 10\n",
31 | "BATCH_SIZE: int = 200\n",
32 | "LEARN_RATE: float = 0.0001\n",
33 | "DROPOUT: float = 0.1\n",
34 | "N_LAYERS: int = 2\n",
35 | "N_HEADS: int = 2\n",
36 | "N_NBRS: int = 20\n",
37 | "DIM_TIME: int = 100\n",
38 | "DIM_EMBED: int = 100\n",
39 | "N_THREADS: int = 32\n",
    "SAMPLING: str = 'recent' # 'recent' or 'uniform'\n",
41 | "OPT_DEDUP = True\n",
42 | "OPT_CACHE = True\n",
43 | "OPT_TIME = True\n",
44 | "OPT_ALL = True\n",
45 | "OPT_DEDUP: bool = OPT_DEDUP or OPT_ALL\n",
46 | "OPT_CACHE: bool = OPT_CACHE or OPT_ALL\n",
47 | "OPT_TIME: bool = OPT_TIME or OPT_ALL\n",
48 | "CACHE_LIMIT: int = int(2e6)\n",
49 | "TIME_WINDOW: int = int(1e4)\n",
50 | "\n",
51 | "MOVE = True\n",
52 | "GPU = 0\n",
53 | "SEED = 1\n",
54 | "PREFIX = ''"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 3,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "device = support.make_device(GPU)\n",
64 | "model_path = support.make_model_path('tgat', PREFIX, DATA)\n",
65 | "if SEED >= 0:\n",
66 | " support.set_seed(SEED)"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 4,
72 | "metadata": {
73 | "id": "COgTQXmcgCEr"
74 | },
75 | "outputs": [
76 | {
77 | "name": "stdout",
78 | "output_type": "stream",
79 | "text": [
80 | "num edges: 157474\n",
81 | "num nodes: 9228\n",
82 | "edge feat: torch.Size([157474, 172])\n",
83 | "node feat: torch.Size([9228, 172])\n"
84 | ]
85 | }
86 | ],
87 | "source": [
88 | "### load data\n",
89 | "\n",
90 | "g = support.load_graph(os.path.join(DATA_PATH, f'data/{DATA}/edges.csv'))\n",
91 | "support.load_feats(g, DATA, DATA_PATH)\n",
92 | "dim_efeat = 0 if g.efeat is None else g.efeat.shape[1]\n",
93 | "dim_nfeat = g.nfeat.shape[1]\n",
94 | "\n",
95 | "g.set_compute(device)\n",
96 | "if MOVE:\n",
97 | " g.move_data(device)\n",
98 | "\n",
99 | "ctx = tg.TContext(g)\n",
100 | "ctx.need_sampling(True)\n",
101 | "ctx.enable_embed_caching(OPT_CACHE, DIM_EMBED)\n",
102 | "ctx.enable_time_precompute(OPT_TIME)\n",
103 | "ctx.set_cache_limit(CACHE_LIMIT)\n",
104 | "ctx.set_time_window(TIME_WINDOW)"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 5,
110 | "metadata": {
111 | "id": "5-YKifBwLuMS"
112 | },
113 | "outputs": [],
114 | "source": [
115 | "### model\n",
116 | "\n",
117 | "sampler = tg.TSampler(N_NBRS, strategy=SAMPLING, num_threads=N_THREADS)\n",
118 | "model = TGAT(ctx,\n",
119 | " dim_node=dim_nfeat,\n",
120 | " dim_edge=dim_efeat,\n",
121 | " dim_time=DIM_TIME,\n",
122 | " dim_embed=DIM_EMBED,\n",
123 | " sampler=sampler,\n",
124 | " num_layers=N_LAYERS,\n",
125 | " num_heads=N_HEADS,\n",
126 | " dropout=DROPOUT,\n",
127 | " dedup=OPT_DEDUP,)\n",
128 | "model = model.to(device)\n",
129 | "criterion = torch.nn.BCEWithLogitsLoss()\n",
130 | "optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE)"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 6,
136 | "metadata": {
137 | "id": "DLZ8cipJLx-u"
138 | },
139 | "outputs": [
140 | {
141 | "name": "stdout",
142 | "output_type": "stream",
143 | "text": [
144 | "epoch 0:\n",
145 | " loss:295.7572 val ap:0.9739 val auc:0.9782\n",
146 | " epoch | total:14.08s loop:12.34s eval:1.74s\n",
147 | " loop | forward:7.82s backward:4.46s sample:0.49s prep_batch:0.05s prep_input:1.03s post_update:0.00s\n",
148 | " comp | mem_update:0.00s time_zero:1.15s time_nbrs:0.93s self_attn:3.64s\n",
149 | "epoch 1:\n",
150 | " loss:170.5700 val ap:0.9828 val auc:0.9853\n",
151 | " epoch | total:13.69s loop:11.75s eval:1.94s\n",
152 | " loop | forward:6.86s backward:4.83s sample:0.50s prep_batch:0.05s prep_input:0.93s post_update:0.00s\n",
153 | " comp | mem_update:0.00s time_zero:0.21s time_nbrs:1.18s self_attn:3.71s\n",
154 | "epoch 2:\n",
155 | " loss:142.4620 val ap:0.9828 val auc:0.9855\n",
156 | " epoch | total:14.43s loop:12.56s eval:1.86s\n",
157 | " loop | forward:7.24s backward:5.26s sample:0.48s prep_batch:0.05s prep_input:0.82s post_update:0.00s\n",
158 | " comp | mem_update:0.00s time_zero:0.25s time_nbrs:1.30s self_attn:3.99s\n",
159 | "epoch 3:\n",
160 | " loss:128.9173 val ap:0.9850 val auc:0.9872\n",
161 | " epoch | total:13.71s loop:12.58s eval:1.12s\n",
162 | " loop | forward:6.61s backward:5.91s sample:0.52s prep_batch:0.06s prep_input:0.55s post_update:0.00s\n",
163 | " comp | mem_update:0.00s time_zero:0.27s time_nbrs:0.96s self_attn:3.47s\n",
164 | "epoch 4:\n",
165 | " loss:123.7265 val ap:0.9852 val auc:0.9874\n",
166 | " epoch | total:16.85s loop:15.26s eval:1.58s\n",
167 | " loop | forward:7.32s backward:7.84s sample:0.82s prep_batch:0.09s prep_input:0.52s post_update:0.00s\n",
168 | " comp | mem_update:0.00s time_zero:0.41s time_nbrs:1.28s self_attn:3.50s\n",
169 | "epoch 5:\n",
170 | " loss:116.0836 val ap:0.9843 val auc:0.9871\n",
171 | " epoch | total:16.34s loop:14.75s eval:1.58s\n",
172 | " loop | forward:7.10s backward:7.56s sample:0.79s prep_batch:0.08s prep_input:0.51s post_update:0.00s\n",
173 | " comp | mem_update:0.00s time_zero:0.40s time_nbrs:1.27s self_attn:3.39s\n",
174 | "epoch 6:\n",
175 | " loss:113.4501 val ap:0.9877 val auc:0.9894\n",
176 | " epoch | total:18.37s loop:17.14s eval:1.21s\n",
177 | " loop | forward:7.92s backward:9.13s sample:0.81s prep_batch:0.08s prep_input:0.51s post_update:0.00s\n",
178 | " comp | mem_update:0.00s time_zero:0.41s time_nbrs:1.21s self_attn:3.82s\n",
179 | "epoch 7:\n",
180 | " loss:108.1715 val ap:0.9869 val auc:0.9889\n",
181 | " epoch | total:18.27s loop:17.03s eval:1.23s\n",
182 | " loop | forward:7.81s backward:9.12s sample:0.65s prep_batch:0.09s prep_input:0.52s post_update:0.00s\n",
183 | " comp | mem_update:0.00s time_zero:0.41s time_nbrs:1.20s self_attn:3.81s\n",
184 | "epoch 8:\n",
185 | " loss:103.1291 val ap:0.9882 val auc:0.9897\n",
186 | " epoch | total:17.87s loop:16.68s eval:1.18s\n",
187 | " loop | forward:7.59s backward:8.99s sample:0.64s prep_batch:0.09s prep_input:0.50s post_update:0.00s\n",
188 | " comp | mem_update:0.00s time_zero:0.39s time_nbrs:1.17s self_attn:3.71s\n",
189 | "epoch 9:\n",
190 | " loss:101.0907 val ap:0.9899 val auc:0.9912\n",
191 | " epoch | total:18.85s loop:17.23s eval:1.61s\n",
192 | " loop | forward:7.75s backward:9.36s sample:0.85s prep_batch:0.11s prep_input:0.56s post_update:0.00s\n",
193 | " comp | mem_update:0.00s time_zero:0.43s time_nbrs:1.32s self_attn:3.64s\n",
194 | "best model at epoch 9\n",
195 | "loading saved checkpoint and testing model...\n",
196 | " test time:1.54s AP:0.9854 AUC:0.9876\n"
197 | ]
198 | }
199 | ],
200 | "source": [
201 | "### training\n",
202 | "\n",
203 | "train_end, val_end = support.data_split(g.num_edges(), 0.7, 0.15)\n",
204 | "neg_sampler = lambda size: np.random.randint(0, g.num_nodes(), size)\n",
205 | "\n",
206 | "trainer = support.LinkPredTrainer(\n",
207 | " ctx, model, criterion, optimizer, neg_sampler,\n",
208 | " EPOCHS, BATCH_SIZE, train_end, val_end,\n",
209 | " model_path, None)\n",
210 | "\n",
211 | "trainer.train()\n",
212 | "trainer.test()"
213 | ]
214 | }
215 | ],
216 | "metadata": {
217 | "colab": {
218 | "provenance": []
219 | },
220 | "kernelspec": {
221 | "display_name": "Python 3",
222 | "name": "python3"
223 | },
224 | "language_info": {
225 | "codemirror_mode": {
226 | "name": "ipython",
227 | "version": 3
228 | },
229 | "file_extension": ".py",
230 | "mimetype": "text/x-python",
231 | "name": "python",
232 | "nbconvert_exporter": "python",
233 | "pygments_lexer": "ipython3",
234 | "version": "3.7.16"
235 | }
236 | },
237 | "nbformat": 4,
238 | "nbformat_minor": 0
239 | }
240 |
--------------------------------------------------------------------------------
/examples/tgat/tgat.py:
--------------------------------------------------------------------------------
1 | import tglite as tg
2 |
3 | from torch import nn, Tensor
4 | from tglite.nn import TemporalAttnLayer
5 |
6 | import os
7 | import sys
8 | sys.path.append(os.path.join(os.getcwd(), '..'))
9 | import support
10 |
11 |
class TGAT(nn.Module):
    """Temporal graph attention model: stacks TemporalAttnLayer modules over
    sampled temporal neighborhoods to compute node embeddings, then scores
    candidate edges with an MLP edge predictor."""

    def __init__(self, ctx: tg.TContext,
                 dim_node: int, dim_edge: int, dim_time: int, dim_embed: int,
                 sampler: tg.TSampler, num_layers=2, num_heads=2, dropout=0.1,
                 dedup: bool = True):
        """
        :param ctx: tglite context shared with the training driver
        :param sampler: temporal neighborhood sampler used per hop
        :param dedup: deduplicate repeated destination entries before sampling
        """
        super().__init__()
        self.ctx = ctx
        self.num_layers = num_layers
        # One attention layer per hop; only the first consumes raw node
        # features, deeper layers consume embeddings of dimension dim_embed.
        self.attn = nn.ModuleList([
            TemporalAttnLayer(ctx,
                num_heads=num_heads,
                dim_node=dim_node if i == 0 else dim_embed,
                dim_edge=dim_edge,
                dim_time=dim_time,
                dim_out=dim_embed,
                dropout=dropout)
            for i in range(num_layers)])
        self.sampler = sampler
        self.edge_predictor = support.EdgePredictor(dim=dim_embed)
        self.dedup = dedup

    def forward(self, batch: tg.TBatch) -> Tensor:
        """Build the multi-hop block chain for the batch, aggregate the
        attention layers over it, and return edge scores — a single tensor,
        or a (pos, neg) pair when the batch carries negative nodes."""
        # head is the outermost block (batch nodes); each iteration extends
        # the chain one hop deeper, optionally deduping and caching first.
        head = batch.block(self.ctx)
        for i in range(self.num_layers):
            tail = head if i == 0 \
                else tail.next_block(include_dst=True)
            tail = tg.op.dedup(tail) if self.dedup else tail
            tail = tg.op.cache(self.ctx, tail.layer, tail)
            tail = self.sampler.sample(tail)

        tg.op.preload(head, use_pin=True)
        if tail.num_dst() > 0:
            tail.dstdata['h'] = tail.dstfeat()
            tail.srcdata['h'] = tail.srcfeat()
        # Layers are applied in reverse: the last attn layer produces the
        # embeddings for the head (batch-level) block.
        embeds = tg.op.aggregate(head, list(reversed(self.attn)), key='h')
        del head
        del tail

        src, dst, neg = batch.split_data(embeds)
        scores = self.edge_predictor(src, dst)
        if batch.neg_nodes is not None:
            scores = (scores, self.edge_predictor(src, neg))

        return scores
--------------------------------------------------------------------------------
/examples/tgat/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 | import tglite as tg
6 |
7 | from tgat import TGAT
8 | import support
9 |
10 | ### arguments
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('-d', '--data', type=str, required=True, help='dataset name')
14 | parser.add_argument('--data-path', type=str, default='', help='path to data folder')
15 | parser.add_argument('--prefix', type=str, default='', help='name for saving trained model')
16 | parser.add_argument('--gpu', type=int, default=0, help='gpu device to use (or -1 for cpu)')
17 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs (default: 100)')
18 | parser.add_argument('--bsize', type=int, default=200, help='batch size (default: 200)')
19 | parser.add_argument('--lr', type=str, default=0.0001, help='learning rate (default: 1e-4)')
20 | parser.add_argument('--dropout', type=str, default=0.1, help='dropout rate (default: 0.1)')
21 | parser.add_argument('--n-layers', type=int, default=2, help='number of layers (default: 2)')
22 | parser.add_argument('--n-heads', type=int, default=2, help='number of attention heads (default: 2)')
23 | parser.add_argument('--n-nbrs', type=int, default=20, help='number of neighbors to sample (default: 20)')
24 | parser.add_argument('--dim-time', type=int, default=100, help='dimension of time features (default: 100)')
25 | parser.add_argument('--dim-embed', type=int, default=100, help='dimension of embeddings (default: 100)')
26 | parser.add_argument('--seed', type=int, default=-1, help='random seed to use')
27 | parser.add_argument('--move', action='store_true', help='move data to device')
28 | parser.add_argument('--n-threads', type=int, default=32, help='number of threads for sampler (default: 32)')
29 | parser.add_argument('--sampling', type=str, default='recent', choices=['recent', 'uniform'], help='sampling strategy (default: recent)')
30 | parser.add_argument('--opt-dedup', action='store_true', help='enable dedup optimization')
31 | parser.add_argument('--opt-cache', action='store_true', help='enable caching optimization')
32 | parser.add_argument('--opt-time', action='store_true', help='enable precomputing time encodings')
33 | parser.add_argument('--cache-limit', type=str, default=2e6, help='max number of embeds to cache (default: 2e6)')
34 | parser.add_argument('--time-window', type=str, default=1e4, help='time window to precompute (default: 1e4)')
35 | parser.add_argument('--opt-all', action='store_true', help='enable all available optimizations')
36 | args = parser.parse_args()
37 | print(args)
38 |
39 | device = support.make_device(args.gpu)
40 | model_path = support.make_model_path('tgat', args.prefix, args.data)
41 | if args.seed >= 0:
42 | support.set_seed(args.seed)
43 |
44 | DATA: str = args.data
45 | DATA_PATH: str = args.data_path
46 | EPOCHS: int = args.epochs
47 | BATCH_SIZE: int = args.bsize
48 | LEARN_RATE: float = float(args.lr)
49 | DROPOUT: float = float(args.dropout)
50 | N_LAYERS: int = args.n_layers
51 | N_HEADS: int = args.n_heads
52 | N_NBRS: int = args.n_nbrs
53 | DIM_TIME: int = args.dim_time
54 | DIM_EMBED: int = args.dim_embed
55 | N_THREADS: int = args.n_threads
56 | SAMPLING: str = args.sampling
57 | OPT_DEDUP: bool = args.opt_dedup or args.opt_all
58 | OPT_CACHE: bool = args.opt_cache or args.opt_all
59 | OPT_TIME: bool = args.opt_time or args.opt_all
60 | CACHE_LIMIT: int = int(args.cache_limit)
61 | TIME_WINDOW: int = int(args.time_window)
62 |
63 |
### load data

# Load the temporal graph from its edge list, then attach node/edge features.
g = support.load_graph(os.path.join(DATA_PATH, f'data/{DATA}/edges.csv'))
support.load_feats(g, DATA, DATA_PATH)
# Edge features are optional; node features are required (shape[1] is used).
dim_efeat = 0 if g.efeat is None else g.efeat.shape[1]
dim_nfeat = g.nfeat.shape[1]

g.set_compute(device)
if args.move:
    # Optionally keep all graph data resident on the compute device.
    g.move_data(device)

# Context object carrying sampling and optimization settings for tglite ops.
ctx = tg.TContext(g)
ctx.need_sampling(True)
ctx.enable_embed_caching(OPT_CACHE, DIM_EMBED)
ctx.enable_time_precompute(OPT_TIME)
ctx.set_cache_limit(CACHE_LIMIT)
ctx.set_time_window(TIME_WINDOW)
81 |
82 |
### model

# Temporal neighbor sampler shared by all layers of the model.
sampler = tg.TSampler(N_NBRS, strategy=SAMPLING, num_threads=N_THREADS)
model = TGAT(ctx,
             dim_node=dim_nfeat,
             dim_edge=dim_efeat,
             dim_time=DIM_TIME,
             dim_embed=DIM_EMBED,
             sampler=sampler,
             num_layers=N_LAYERS,
             num_heads=N_HEADS,
             dropout=DROPOUT)
model = model.to(device)
# Link prediction is trained as binary classification over raw logits.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE)
98 |
99 |
### training

# Chronological split: 70% train / 15% validation / remainder test.
train_end, val_end = support.data_split(g.num_edges(), 0.7, 0.15)
# Negative sampler draws random node ids uniformly over the whole graph.
neg_sampler = lambda size: np.random.randint(0, g.num_nodes(), size)

trainer = support.LinkPredTrainer(
    ctx, model, criterion, optimizer, neg_sampler,
    EPOCHS, BATCH_SIZE, train_end, val_end,
    model_path, None)

trainer.train()
trainer.test()
--------------------------------------------------------------------------------
/examples/tgn/tgn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tglite as tg
3 |
4 | from torch import nn, Tensor
5 | from tglite.nn import TemporalAttnLayer
6 | from tglite._stats import tt
7 |
8 | import sys, os
9 | sys.path.append(os.path.join(os.getcwd(), '..'))
10 | import support
11 |
12 |
class TGN(nn.Module):
    """Temporal Graph Network example model built on tglite.

    Combines a GRU-based node memory with temporal attention layers and an
    edge predictor for link prediction.
    """

    def __init__(self, ctx: tg.TContext,
                 dim_node: int, dim_edge: int, dim_time: int, dim_embed: int,
                 sampler: tg.TSampler, num_layers=2, num_heads=2, dropout=0.1,
                 dedup: bool = True):
        """Create the model.

        :param ctx: tglite context holding the graph and optimization settings.
        :param dim_node: raw node feature dimension.
        :param dim_edge: raw edge feature dimension (0 if none).
        :param dim_time: time-encoding dimension.
        :param dim_embed: embedding (and memory) dimension.
        :param sampler: temporal neighbor sampler used at every layer.
        :param num_layers: number of attention layers / sampling hops.
        :param num_heads: attention heads per layer.
        :param dropout: dropout rate inside the attention layers.
        :param dedup: whether to apply tg.op.dedup to each block.
        """
        super().__init__()
        self.ctx = ctx
        self.dim_edge = dim_edge
        self.num_layers = num_layers
        # Project raw node features to the embedding size only when needed.
        self.nfeat_map = None if dim_node == dim_embed else nn.Linear(dim_node, dim_embed)
        # Memory update cell: input is [dst mem, src mem, edge feat, time enc].
        self.mem_cell = nn.GRUCell(2 * dim_embed + dim_edge + dim_time, dim_embed)
        self.mem_time_encode = tg.nn.TimeEncode(dim_time)
        self.attn = nn.ModuleList([
            TemporalAttnLayer(ctx,
                num_heads=num_heads,
                dim_node=dim_embed,
                dim_edge=dim_edge,
                dim_time=dim_time,
                dim_out=dim_embed,
                dropout=dropout)
            for i in range(num_layers)])
        self.sampler = sampler
        self.edge_predictor = support.EdgePredictor(dim=dim_embed)
        self.dedup = dedup

    def forward(self, batch: tg.TBatch) -> Tensor:
        """Compute link-prediction scores for a batch of temporal edges.

        Returns a single logits tensor, or a (pos, neg) tuple when the batch
        carries negative samples. Also updates the node memory and stores raw
        messages into the mailbox as a side effect.
        """
        # setup message passing: chain num_layers blocks, dedup + sample each
        head = batch.block(self.ctx)
        for i in range(self.num_layers):
            tail = head if i == 0 \
                else tail.next_block(include_dst=True, use_dst_times=False)
            tail = tg.op.dedup(tail) if self.dedup else tail
            tail = self.sampler.sample(tail)

        # load data / feats (pinned-memory transfers), then mix node features
        # with the memory state at the innermost block
        tg.op.preload(head, use_pin=True)
        if tail.num_dst() > 0:
            t_start = tt.start()
            mem = self.update_memory(tail)
            nfeat = tail.nfeat() if self.nfeat_map is None else self.nfeat_map(tail.nfeat())
            # First num_dst rows correspond to destination nodes, rest to sources.
            tail.dstdata['h'] = nfeat[:tail.num_dst()] + mem[:tail.num_dst()]
            tail.srcdata['h'] = nfeat[tail.num_dst():] + mem[tail.num_dst():]
            tt.t_mem_update += tt.elapsed(t_start)
            del nfeat
            del mem

        # compute embeddings: apply attention layers innermost-first
        embeds = tg.op.aggregate(head, list(reversed(self.attn)), key='h')
        del head
        del tail

        # compute scores for positive (and optional negative) pairs
        src, dst, neg = batch.split_data(embeds)
        scores = self.edge_predictor(src, dst)
        if neg is not None:
            scores = (scores, self.edge_predictor(src, neg))
        del embeds
        del src
        del dst
        del neg

        # memory messages: record this batch's raw mails for future updates
        t_start = tt.start()
        self.save_raw_msgs(batch)
        tt.t_post_update += tt.elapsed(t_start)

        return scores

    def update_memory(self, blk: tg.TBlock) -> Tensor:
        """Run the GRU memory update for all nodes of a block.

        Builds each node's mail as [stored mail, time-encoding of the gap
        since its last memory update], feeds it through the GRU cell, and
        writes the new memory/timestamps back. Returns the updated memory
        rows for the block's nodes.
        """
        cdev = blk.g.compute_device()
        nodes = blk.allnodes()

        mail_ts = blk.g.mailbox.time[nodes]
        # Time elapsed between the mail and the node's last memory update.
        delta = mail_ts - blk.g.mem.time[nodes]
        delta = delta.squeeze().to(cdev)
        mail = tg.op.precomputed_times(self.ctx, 0, self.mem_time_encode, delta)
        mail = torch.cat([blk.mail(), mail], dim=1)

        mem = blk.mem_data()
        mem = self.mem_cell(mail, mem)
        blk.g.mem.update(nodes, mem, mail_ts)
        return mem

    def save_raw_msgs(self, batch: tg.TBatch):
        """Store raw messages for this batch's nodes into the graph mailbox.

        Keeps only the latest interaction per destination node (coalesce by
        'latest'); each mail is [dst mem, src mem, edge feat?] on the storage
        device.
        """
        sdev = batch.g.storage_device()
        mem = batch.g.mem.data

        blk = batch.block_adj(self.ctx)
        blk = tg.op.coalesce(blk, by='latest')

        uniq = torch.from_numpy(blk.dstnodes).long().to(sdev)
        nbrs = torch.from_numpy(blk.srcnodes).long().to(sdev)
        if self.dim_edge > 0:
            eids = torch.from_numpy(blk.eid).long().to(sdev)
            mail = torch.cat([mem[uniq], mem[nbrs], batch.g.efeat[eids]], dim=1)
        else:
            # No edge features on this graph; mail is just the two memories.
            mail = torch.cat([mem[uniq], mem[nbrs]], dim=1)
        mail_ts = torch.from_numpy(blk.ets).to(sdev)
        batch.g.mailbox.store(uniq, mail, mail_ts)
--------------------------------------------------------------------------------
/examples/tgn/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import numpy as np
5 | import tglite as tg
6 |
7 | import support
8 | from tgn import TGN
9 |
10 |
### arguments

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', type=str, required=True, help='dataset name')
parser.add_argument('--data-path', type=str, default='', help='path to data folder')
parser.add_argument('--prefix', type=str, default='', help='name for saving trained model')
parser.add_argument('--gpu', type=int, default=0, help='gpu device to use (or -1 for cpu)')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs (default: 100)')
parser.add_argument('--bsize', type=int, default=200, help='batch size (default: 200)')
# Numeric options below were declared type=str and relied on later float()/
# int() casts; declaring them as float lets argparse validate the value up
# front (and accepts scientific notation like 1e4), with identical behavior
# for valid inputs.
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate (default: 1e-4)')
parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate (default: 0.1)')
parser.add_argument('--n-layers', type=int, default=2, help='number of layers (default: 2)')
parser.add_argument('--n-heads', type=int, default=2, help='number of attention heads (default: 2)')
parser.add_argument('--n-nbrs', type=int, default=20, help='number of neighbors to sample (default: 20)')
parser.add_argument('--dim-time', type=int, default=100, help='dimension of time features (default: 100)')
parser.add_argument('--dim-embed', type=int, default=100, help='dimension of embeddings (default: 100)')
parser.add_argument('--seed', type=int, default=-1, help='random seed to use')
parser.add_argument('--move', action='store_true', help='move data to device')
parser.add_argument('--n-threads', type=int, default=32, help='number of threads for sampler (default: 32)')
parser.add_argument('--sampling', type=str, default='recent', choices=['recent', 'uniform'], help='sampling strategy (default: recent)')
parser.add_argument('--opt-dedup', action='store_true', help='enable dedup optimization')
parser.add_argument('--opt-time', action='store_true', help='enable precomputing time encodings')
parser.add_argument('--time-window', type=float, default=1e4, help='time window to precompute (default: 1e4)')
parser.add_argument('--opt-all', action='store_true', help='enable all available optimizations')
args = parser.parse_args()
print(args)
37 |
# Resolve compute device and the paths where model and memory state are saved.
device = support.make_device(args.gpu)
model_path = support.make_model_path('tgn', args.prefix, args.data)
model_mem_path = support.make_model_mem_path('tgn', args.prefix, args.data)
# Negative seed (the default -1) means "do not fix the RNG seed".
if args.seed >= 0:
    support.set_seed(args.seed)

# Hoist parsed CLI arguments into module-level constants used below.
DATA: str = args.data
DATA_PATH: str = args.data_path
EPOCHS: int = args.epochs
BATCH_SIZE: int = args.bsize
LEARN_RATE: float = float(args.lr)      # arg may arrive as str; normalize
DROPOUT: float = float(args.dropout)    # arg may arrive as str; normalize
N_LAYERS: int = args.n_layers
N_HEADS: int = args.n_heads
N_NBRS: int = args.n_nbrs
DIM_TIME: int = args.dim_time
DIM_EMBED: int = args.dim_embed
N_THREADS: int = args.n_threads
SAMPLING: str = args.sampling
# --opt-all switches every optimization on regardless of individual flags.
OPT_DEDUP: bool = args.opt_dedup or args.opt_all
OPT_TIME: bool = args.opt_time or args.opt_all
TIME_WINDOW: int = int(args.time_window)
60 |
61 |
### load data

# Load the temporal graph from its edge list, then attach node/edge features.
g = support.load_graph(os.path.join(DATA_PATH, f'data/{DATA}/edges.csv'))
support.load_feats(g, DATA, DATA_PATH)
# Edge features are optional; node features are required (shape[1] is used).
dim_efeat = 0 if g.efeat is None else g.efeat.shape[1]
dim_nfeat = g.nfeat.shape[1]

# One mail slot per node; mail dim = dst mem + src mem + edge features
# (matching what TGN.save_raw_msgs concatenates).
g.mailbox = tg.Mailbox(g.num_nodes(), 1, 2 * DIM_EMBED + dim_efeat)
g.mem = tg.Memory(g.num_nodes(), DIM_EMBED)

g.set_compute(device)
if args.move:
    # Optionally keep all graph data resident on the compute device.
    g.move_data(device)

# Context object carrying sampling and optimization settings for tglite ops.
ctx = tg.TContext(g)
ctx.need_sampling(True)
ctx.enable_time_precompute(OPT_TIME)
ctx.set_time_window(TIME_WINDOW)
80 |
81 |
### model

# Temporal neighbor sampler shared by all layers of the model.
sampler = tg.TSampler(N_NBRS, strategy=SAMPLING, num_threads=N_THREADS)
model = TGN(ctx,
            dim_node=dim_nfeat,
            dim_edge=dim_efeat,
            dim_time=DIM_TIME,
            dim_embed=DIM_EMBED,
            sampler=sampler,  # PEP 8: no spaces around '=' in keyword args
            num_layers=N_LAYERS,
            num_heads=N_HEADS,
            dropout=DROPOUT)
model = model.to(device)
# Link prediction is trained as binary classification over raw logits.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE)
96 |
97 |
### training

# Chronological split: 70% train / 15% validation / remainder test.
train_end, val_end = support.data_split(g.num_edges(), 0.7, 0.15)
# Negative sampler draws random node ids uniformly over the whole graph.
neg_sampler = lambda size: np.random.randint(0, g.num_nodes(), size)

trainer = support.LinkPredTrainer(
    ctx, model, criterion, optimizer, neg_sampler,
    EPOCHS, BATCH_SIZE, train_end, val_end,
    model_path, model_mem_path)

trainer.train()
trainer.test()
--------------------------------------------------------------------------------
/include/tglite/cache.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "tglite/core.h"
4 |
5 | namespace tglite {
6 |
7 | py::tuple find_dedup_time_hits(
8 | torch::Tensor ×,
9 | torch::Tensor &time_table,
10 | int time_window);
11 |
12 | py::array_t compute_cache_keys(py::array_t &nodes, py::array_t ×);
13 |
14 | /// Table for caching computed embeddings.
15 | class EmbedTable {
16 | public:
17 | EmbedTable(ssize_t dim_emb, ssize_t limit);
18 |
19 | py::tuple lookup(py::array_t &keys, torch::Device &device);
20 |
21 | void store(py::array_t &keys, torch::Tensor &values);
22 |
23 | private:
24 | ssize_t _dim_emb;
25 | ssize_t _limit;
26 |
27 | ssize_t _start = 0;
28 | torch::Tensor _table;
29 | std::vector _keys;
30 | std::unordered_map _key2idx;
31 | // torch::Tensor _pin;
32 | };
33 |
34 | } // namespace tglite
35 |
--------------------------------------------------------------------------------
/include/tglite/core.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 | #include
5 |
6 | namespace tglite {
7 |
8 | typedef int32_t IdI32;
9 | typedef float TsF32;
10 |
11 | } // namespace tglite
12 |
--------------------------------------------------------------------------------
/include/tglite/dedup.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "tglite/core.h"
4 |
5 | namespace tglite {
6 |
7 | py::tuple dedup_targets(py::array_t &nodes, py::array_t ×);
8 |
9 | // py::tuple dedup_indices(torch::Tensor &nodes, torch::Tensor ×);
10 |
11 | // bool dedup_targets(
12 | // const IdI32 *nodes_ptr,
13 | // const TsF32 *times_ptr,
14 | // size_t len,
15 | // const IdI32 *pre_nodes,
16 | // const TsF32 *pre_times,
17 | // size_t pre_len,
18 | // std::vector &uniq_nodes,
19 | // std::vector &uniq_times,
20 | // std::vector &inv_idx
21 | // );
22 |
23 | } // namespace tglite
24 |
--------------------------------------------------------------------------------
/include/tglite/sampler.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "tglite/core.h"
4 | #include "tglite/tcsr.h"
5 |
6 | namespace tglite {
7 |
8 | class TemporalBlock {
9 | public:
10 | size_t num_edges = 0;
11 | std::vector dstindex;
12 | std::vector srcnodes;
13 | std::vector eid;
14 | std::vector ets;
15 |
16 | TemporalBlock() {}
17 |
18 | py::array_t dstindex_copy() const { return to_pyarray_copy(dstindex); }
19 | py::array_t srcnodes_copy() const { return to_pyarray_copy(srcnodes); }
20 | py::array_t eid_copy() const { return to_pyarray_copy(eid); }
21 | py::array_t ets_copy() const { return to_pyarray_copy(ets); }
22 |
23 | // py::array_t dstindex_owned() {
24 | // auto *ptr = dstindex;
25 | // dstindex = nullptr;
26 | // return to_pyarray_owned(ptr);
27 | // }
28 |
29 | // py::array_t srcnodes_owned() {
30 | // auto *ptr = srcnodes;
31 | // srcnodes = nullptr;
32 | // return to_pyarray_owned(ptr);
33 | // }
34 |
35 | // py::array_t eid_owned() {
36 | // auto *ptr = eid;
37 | // eid = nullptr;
38 | // return to_pyarray_owned(ptr);
39 | // }
40 |
41 | // py::array_t ets_owned() {
42 | // auto *ptr = ets;
43 | // ets = nullptr;
44 | // return to_pyarray_owned(ptr);
45 | // }
46 | };
47 |
48 | class TemporalSampler {
49 | public:
50 | TemporalSampler(int num_threads, int num_nbrs, bool recent);
51 |
52 | TemporalBlock sample(TCSR &tcsr,
53 | py::array_t &nodes,
54 | py::array_t ×);
55 |
56 | private:
57 | int _num_threads;
58 | int _num_nbrs;
59 | bool _recent;
60 |
61 | void sample_layer(TCSR &tcsr, TemporalBlock &block,
62 | const IdI32 *nodes_ptr, const TsF32 *times_ptr, size_t size);
63 |
64 | void add_neighbor(TCSR &tcsr,
65 | std::vector *eid,
66 | std::vector *ets,
67 | std::vector *srcnodes,
68 | std::vector *dstindex,
69 | IdI32 &k, IdI32 &dst_idx);
70 |
71 | void combine_coo(
72 | TemporalBlock &block,
73 | std::vector **eid,
74 | std::vector **ets,
75 | std::vector **srcnodes,
76 | std::vector **dstindex,
77 | std::vector &out_nodes);
78 | };
79 |
80 | } // namespace tglite
81 |
--------------------------------------------------------------------------------
/include/tglite/tcsr.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "tglite/core.h"
4 | #include "tglite/utils.h"
5 |
6 | namespace tglite {
7 |
8 | class TCSR {
9 | public:
10 | std::vector ind;
11 | std::vector nbr;
12 | std::vector eid;
13 | std::vector ets;
14 |
15 | TCSR() {
16 | ind.push_back(0);
17 | }
18 |
19 | py::array_t ind_view() const { return to_pyarray_view(ind); }
20 | py::array_t nbr_view() const { return to_pyarray_view(nbr); }
21 | py::array_t eid_view() const { return to_pyarray_view(eid); }
22 | py::array_t ets_view() const { return to_pyarray_view(ets); }
23 | };
24 |
25 | TCSR create_tcsr(py::array_t &edges, py::array_t ×, size_t num_nodes);
26 |
27 | } // namespace tglite
28 |
--------------------------------------------------------------------------------
/include/tglite/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "tglite/core.h"
4 |
5 | namespace tglite {
6 |
7 | // torch::Tensor index_pinned(torch::Tensor &input, torch::Tensor &index);
8 |
9 | // py::tuple find_last_message(py::array_t &uniq_nodes, py::array_t &edges);
10 | py::array_t find_latest_uniq(py::array_t &uniq, py::array_t &nodes, py::array_t ×);
11 |
12 | /// Custom hash function for collision-free keys.
13 | inline int64_t opt_hash(int32_t &s, float &t) {
14 | return (static_cast(s) << 32) | static_cast(t);
15 | }
16 |
17 | template
18 | inline py::array_t to_pyarray_view(const T &seq) {
19 | if (seq.size() > 0) {
20 | auto capsule = py::capsule(&seq, [](void* p) { /* borrowed */ });
21 | return py::array(seq.size(), seq.data(), capsule);
22 | } else {
23 | return py::array();
24 | }
25 | }
26 |
27 | template
28 | inline py::array_t to_pyarray_copy(const T &seq) {
29 | if (seq.size() > 0) {
30 | T* copy_ptr = new T(seq);
31 | auto capsule = py::capsule(copy_ptr, [](void* p) { delete reinterpret_cast(p); });
32 | return py::array(copy_ptr->size(), copy_ptr->data(), capsule);
33 | } else {
34 | return py::array();
35 | }
36 | }
37 |
38 | template
39 | inline py::array_t to_pyarray_owned(T *seq_ptr) {
40 | if (seq_ptr && seq_ptr->size() > 0) {
41 | auto capsule = py::capsule(seq_ptr, [](void* p) { delete reinterpret_cast(p); });
42 | return py::array(seq_ptr->size(), seq_ptr->data(), capsule);
43 | } else {
44 | return py::array();
45 | }
46 | }
47 |
48 | } // namespace tglite
49 |
--------------------------------------------------------------------------------
/lib/bind.cpp:
--------------------------------------------------------------------------------
1 | #include "tglite/core.h"
2 | #include "tglite/cache.h"
3 | #include "tglite/dedup.h"
4 | #include "tglite/sampler.h"
5 | #include "tglite/tcsr.h"
6 | #include "tglite/utils.h"
7 |
8 | using namespace tglite;
9 |
10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
11 | py::class_(m, "TCSR")
12 | .def(py::init<>())
13 | .def_property_readonly("ind", &TCSR::ind_view)
14 | .def_property_readonly("nbr", &TCSR::nbr_view)
15 | .def_property_readonly("eid", &TCSR::eid_view)
16 | .def_property_readonly("ets", &TCSR::ets_view);
17 |
18 | py::class_(m, "TemporalBlock")
19 | .def(py::init<>())
20 | .def("copy_eid", &TemporalBlock::eid_copy)
21 | .def("copy_ets", &TemporalBlock::ets_copy)
22 | .def("copy_srcnodes", &TemporalBlock::srcnodes_copy)
23 | .def("copy_dstindex", &TemporalBlock::dstindex_copy)
24 | .def("num_edges", [](const TemporalBlock &b) { return b.num_edges; });
25 |
26 | py::class_(m, "TemporalSampler")
27 | .def(py::init())
28 | .def("sample", &TemporalSampler::sample);
29 |
30 | py::class_(m, "EmbedTable")
31 | .def(py::init())
32 | // .def("_table", [](const EmbedTable &et) { return et._table; })
33 | // .def("_keys", [](const EmbedTable &et) { return et._keys; })
34 | // .def("_map", [](const EmbedTable &et) { return et._key2idx; })
35 | .def("lookup", &EmbedTable::lookup, "tglite::EmbedTable::lookup")
36 | .def("store", &EmbedTable::store, "tglite::EmbedTable::store");
37 |
38 | m.def("create_tcsr", &create_tcsr, "tglite::create_tcsr");
39 | m.def("dedup_targets", &dedup_targets, "tglite::dedup_targets");
40 | m.def("find_latest_uniq", &find_latest_uniq, "tglite::find_latest_uniq");
41 | // m.def("find_last_message", &find_last_message, "tglite::find_last_message");
42 | m.def("find_dedup_time_hits", &find_dedup_time_hits, "tglite::find_dedup_time_hits");
43 | m.def("compute_cache_keys", &compute_cache_keys, "tglite::compute_cache_keys");
44 |
45 | // m.def("dedup_indices", &dedup_indices, "tglite::dedup_indices");
46 | // m.def("index_pinned", &index_pinned, "tglite::index_pinned");
47 | }
48 |
--------------------------------------------------------------------------------
/lib/cache.cpp:
--------------------------------------------------------------------------------
1 | #include "tglite/cache.h"
2 | #include "tglite/utils.h"
3 |
4 | namespace tglite {
5 |
6 | py::tuple find_dedup_time_hits(
7 | torch::Tensor ×,
8 | torch::Tensor &time_table,
9 | int time_window
10 | ) {
11 | auto tup = torch::_unique(times.flatten(), /*sorted=*/true, /*return_inverse=*/true);
12 |
13 | times = std::get<0>(tup);
14 | auto inv_idx = std::get<1>(tup);
15 |
16 | auto delta_int = times.to(torch::kInt64);
17 | auto hit_idx = (delta_int == times) & (0 <= delta_int) & (delta_int <= time_window);
18 | auto hit_delta = delta_int.index({hit_idx});
19 |
20 | int64_t hit_count = torch::sum(hit_idx).item().toLong();
21 | int64_t uniq_size = times.size(0);
22 |
23 | torch::Tensor out_embeds;
24 | if (hit_count == uniq_size) {
25 | out_embeds = time_table.index({hit_delta});
26 | } else {
27 | int64_t time_dim = time_table.size(1);
28 | auto opts = torch::TensorOptions().device(time_table.device());
29 | out_embeds = torch::zeros({uniq_size, time_dim}, opts);
30 | out_embeds.index_put_({hit_idx}, time_table.index({hit_delta}));
31 | }
32 |
33 | return py::make_tuple(hit_count, hit_idx, out_embeds, times, inv_idx);
34 | }
35 |
36 | py::array_t compute_cache_keys(py::array_t &nodes, py::array_t ×) {
37 | ssize_t size = nodes.size();
38 | auto keys = py::array_t(size);
39 | auto keys_ptr = static_cast(keys.request().ptr);
40 | auto node_ptr = static_cast(nodes.request().ptr);
41 | auto time_ptr = static_cast(times.request().ptr);
42 | for (ssize_t i = 0; i < size; i++) {
43 | keys_ptr[i] = opt_hash(node_ptr[i], time_ptr[i]);
44 | }
45 | return keys;
46 | }
47 |
// Allocates the embedding cache: starts at up to 1024 rows and grows on
// demand in store(), never exceeding `limit` rows of `dim_emb` floats.
EmbedTable::
EmbedTable(ssize_t dim_emb, ssize_t limit)
  : _dim_emb(dim_emb), _limit(limit) {
  ssize_t start_capacity = std::min((ssize_t)1024, limit);
  _table = torch::zeros({start_capacity, dim_emb}, torch::TensorOptions().dtype(torch::kFloat32));
  _key2idx.reserve(start_capacity);
  _keys.resize(start_capacity);
}
56 |
57 | py::tuple EmbedTable::
58 | lookup(py::array_t &keys, torch::Device &device) {
59 | ssize_t size = keys.size();
60 | auto output = torch::zeros({size, _dim_emb});
61 | auto hit_idx = torch::zeros(size, torch::TensorOptions().dtype(torch::kBool));
62 |
63 | auto *keys_ptr = static_cast(keys.request().ptr);
64 | auto *hit_ptr = hit_idx.accessor().data();
65 | std::vector indices;
66 | indices.reserve(size);
67 |
68 | for (ssize_t i = 0; i < size; i++) {
69 | auto it = _key2idx.find(keys_ptr[i]);
70 | if (it != _key2idx.end()) {
71 | indices.push_back(it->second);
72 | hit_ptr[i] = true;
73 | }
74 | }
75 |
76 | output.index_put_({hit_idx}, _table.index(
77 | {torch::from_blob(indices.data(), indices.size(),
78 | torch::TensorOptions().dtype(torch::kLong))}));
79 |
80 | output = output.to(device);
81 | hit_idx = hit_idx.to(device);
82 | return py::make_tuple(hit_idx, output);
83 | }
84 |
85 | void EmbedTable::
86 | store(py::array_t &keys, torch::Tensor &values) {
87 | ssize_t size = keys.size();
88 | auto *keys_ptr = static_cast(keys.request().ptr);
89 | auto embeds = values.detach().cpu();
90 |
91 | ssize_t nrows = _table.size(0);
92 | if (_start + size > nrows && nrows < _limit) {
93 | ssize_t grow_size = std::max(_start + size, nrows * 2);
94 | grow_size = std::min(grow_size, _limit);
95 | _table.resize_({grow_size, _dim_emb});
96 | _keys.resize(grow_size);
97 | nrows = grow_size;
98 | }
99 |
100 | ssize_t inp_idx = 0;
101 | while (size > 0) {
102 | ssize_t nslots = std::min(nrows - _start, size);
103 | ssize_t inp_end = inp_idx + nslots;
104 | ssize_t out_end = _start + nslots;
105 |
106 | _table.index_put_({torch::indexing::Slice(_start, out_end)},
107 | embeds.index({torch::indexing::Slice(inp_idx, inp_end)}));
108 |
109 | for (ssize_t i = 0; i < nslots; i++) {
110 | ssize_t slot = _start + i;
111 | int64_t old_key = _keys[slot];
112 | int64_t new_key = keys_ptr[inp_idx + i];
113 | _key2idx.erase(old_key);
114 | _key2idx.emplace(new_key, slot);
115 | _keys[slot] = new_key;
116 | }
117 |
118 | _start = out_end % nrows;
119 | inp_idx = inp_end;
120 | size -= nslots;
121 | }
122 | }
123 |
124 | } // namespace tglite
125 |
--------------------------------------------------------------------------------
/lib/dedup.cpp:
--------------------------------------------------------------------------------
1 | #include "tglite/dedup.h"
2 | #include "tglite/utils.h"
3 |
4 | namespace tglite {
5 |
6 | py::tuple dedup_targets(py::array_t &nodes, py::array_t ×) {
7 | ssize_t size = nodes.size();
8 | auto inv_idx = py::array_t(size);
9 |
10 | auto *node_ptr = static_cast(nodes.request().ptr);
11 | auto *time_ptr = static_cast(times.request().ptr);
12 | auto *inv_ptr = static_cast(inv_idx.request().ptr);
13 |
14 | std::unordered_map key2idx;
15 | auto *uniq_node = new std::vector;
16 | auto *uniq_time = new std::vector;
17 | uniq_node->reserve(size);
18 | uniq_time->reserve(size);
19 | key2idx.reserve(size);
20 |
21 | bool has_dups = false;
22 | for (ssize_t i = 0; i < size; i++) {
23 | IdI32 nid = node_ptr[i];
24 | TsF32 nts = time_ptr[i];
25 | int64_t key = opt_hash(nid, nts);
26 | auto iter = key2idx.find(key);
27 | if (iter != key2idx.end()) {
28 | auto uniq_idx = iter->second;
29 | inv_ptr[i] = uniq_idx;
30 | has_dups = true;
31 | } else {
32 | auto idx = uniq_node->size();
33 | uniq_node->push_back(nid);
34 | uniq_time->push_back(nts);
35 | key2idx.emplace(key, idx);
36 | inv_ptr[i] = idx;
37 | }
38 | }
39 |
40 | py::array_t res_nodes = to_pyarray_owned(uniq_node);
41 | py::array_t res_times = to_pyarray_owned(uniq_time);
42 | return py::make_tuple(has_dups, res_nodes, res_times, inv_idx);
43 | }
44 |
45 | // py::tuple dedup_indices(torch::Tensor &nodes, torch::Tensor ×) {
46 | // ssize_t size = nodes.size(0);
47 | //
48 | // auto nodes_ptr = nodes.accessor().data();
49 | // auto times_ptr = times.accessor().data();
50 | //
51 | // auto opt_i64 = torch::TensorOptions().dtype(torch::kInt64);
52 | // auto inv_idx = torch::zeros(size, opt_i64);
53 | // auto inv_ptr = inv_idx.accessor().data();
54 | //
55 | // std::unordered_map key2idx;
56 | // std::vector indices;
57 | // key2idx.reserve(size);
58 | //
59 | // for (ssize_t i = 0; i < size; i++) {
60 | // IdI32 nid = nodes_ptr[i];
61 | // TsF32 nts = times_ptr[i];
62 | //
63 | // int64_t key = opt_hash(nid, nts);
64 | // auto it = key2idx.find(key);
65 | //
66 | // if (it != key2idx.end()) {
67 | // auto uniq_idx = it->second;
68 | // inv_ptr[i] = uniq_idx;
69 | // } else {
70 | // auto idx = indices.size();
71 | // indices.push_back(i);
72 | // key2idx.emplace(key, idx);
73 | // inv_ptr[i] = idx;
74 | // }
75 | // }
76 | //
77 | // py::array_t filter_idx = py::cast(indices);
78 | // return py::make_tuple(filter_idx, inv_idx);
79 | // }
80 |
81 | // bool dedup_targets(
82 | // const IdI32 *nodes_ptr, const TsF32 *times_ptr, size_t len,
83 | // const IdI32 *pre_nodes, const TsF32 *pre_times, size_t pre_len,
84 | // std::vector &uniq_nodes,
85 | // std::vector &uniq_times,
86 | // std::vector &inv_idx) {
87 | //
88 | // std::unordered_map key2idx;
89 | // key2idx.reserve(len + pre_len);
90 | // bool has_dups = false;
91 | //
92 | // for (size_t i = 0; i < len + pre_len; i++) {
93 | // IdI32 nid = i < pre_len ? pre_nodes[i] : nodes_ptr[i - pre_len];
94 | // TsF32 nts = i < pre_len ? pre_times[i] : times_ptr[i - pre_len];
95 | //
96 | // int64_t key = opt_hash(nid, nts);
97 | // auto it = key2idx.find(key);
98 | //
99 | // if (it != key2idx.end()) {
100 | // auto uniq_idx = it->second;
101 | // inv_idx.push_back(uniq_idx);
102 | // has_dups = true;
103 | // } else {
104 | // auto idx = uniq_nodes.size();
105 | // uniq_nodes.push_back(nid);
106 | // uniq_times.push_back(nts);
107 | // key2idx.emplace(key, idx);
108 | // inv_idx.push_back(idx);
109 | // }
110 | // }
111 | //
112 | // return has_dups;
113 | // }
114 |
115 | } // namespace tglite
116 |
--------------------------------------------------------------------------------
/lib/sampler.cpp:
--------------------------------------------------------------------------------
1 | #include "tglite/sampler.h"
2 | #include
3 |
4 | namespace tglite {
5 |
// Configures the sampler: thread count, neighbors per node, and whether to
// take the most-recent neighbors (true) or sample uniformly (false).
TemporalSampler::
TemporalSampler(int num_threads, int num_nbrs, bool recent)
  : _num_threads(num_threads), _num_nbrs(num_nbrs), _recent(recent) { }
9 |
10 | TemporalBlock TemporalSampler::
11 | sample(TCSR &tcsr, py::array_t &nodes, py::array_t ×) {
12 | omp_set_num_threads(_num_threads);
13 |
14 | const IdI32 *nodes_ptr = static_cast(nodes.request().ptr);
15 | const TsF32 *times_ptr = static_cast(times.request().ptr);
16 | size_t size = nodes.size();
17 |
18 | TemporalBlock block;
19 | sample_layer(tcsr, block, nodes_ptr, times_ptr, size);
20 |
21 | return block;
22 | }
23 |
24 | void TemporalSampler::
25 | sample_layer(TCSR &tcsr, TemporalBlock &block,
26 | const IdI32 *nodes_ptr, const TsF32 *times_ptr, size_t size) {
27 | std::vector *eid[_num_threads];
28 | std::vector *ets[_num_threads];
29 | std::vector *srcnodes[_num_threads];
30 | std::vector *dstindex[_num_threads];
31 | std::vector out_nodes(_num_threads, 0);
32 |
33 | int nodes_per_thread = int(ceil(static_cast(size) / _num_threads));
34 | int reserve_capacity = nodes_per_thread * _num_nbrs;
35 |
36 | #pragma omp parallel
37 | {
38 | int tid = omp_get_thread_num();
39 | unsigned int loc_seed = tid;
40 |
41 | eid[tid] = new std::vector;
42 | ets[tid] = new std::vector;
43 | srcnodes[tid] = new std::vector;
44 | dstindex[tid] = new std::vector;
45 |
46 | eid[tid]->reserve(reserve_capacity);
47 | ets[tid]->reserve(reserve_capacity);
48 | srcnodes[tid]->reserve(reserve_capacity);
49 | dstindex[tid]->reserve(reserve_capacity);
50 |
51 | #pragma omp for schedule(static, nodes_per_thread)
52 | for (size_t j = 0; j < size; j++) {
53 | IdI32 nid = nodes_ptr[j];
54 | TsF32 nts = times_ptr[j];
55 |
56 | IdI32 s_search = tcsr.ind[nid];
57 | auto e_it = std::lower_bound(tcsr.ets.begin() + s_search,
58 | tcsr.ets.begin() + tcsr.ind[nid + 1], nts);
59 | IdI32 e_search = std::max(int(e_it - tcsr.ets.begin()) - 1, s_search);
60 |
61 | if (_recent || (e_search - s_search + 1 < _num_nbrs)) {
62 | for (IdI32 k = e_search; k >= std::max(s_search, e_search - _num_nbrs + 1); k--) {
63 | if (tcsr.ets[k] < nts - 1e-7f) {
64 | add_neighbor(tcsr, eid[tid], ets[tid], srcnodes[tid], dstindex[tid], k, out_nodes[tid]);
65 | }
66 | }
67 | } else {
68 | for (int k = 0; k < _num_nbrs; k++) {
69 | IdI32 picked = s_search + rand_r(&loc_seed) % (e_search - s_search + 1);
70 | if (tcsr.ets[picked] < nts - 1e-7f) {
71 | add_neighbor(tcsr, eid[tid], ets[tid], srcnodes[tid], dstindex[tid], picked, out_nodes[tid]);
72 | }
73 | }
74 | }
75 |
76 | out_nodes[tid] += 1;
77 | }
78 | }
79 |
80 | combine_coo(block, eid, ets, srcnodes, dstindex, out_nodes);
81 | }
82 |
83 | inline void TemporalSampler::
84 | add_neighbor(TCSR &tcsr,
85 | std::vector *eid, std::vector *ets,
86 | std::vector *srcnodes, std::vector *dstindex,
87 | IdI32 &k, IdI32 &dst_idx) {
88 | eid->push_back(tcsr.eid[k]);
89 | ets->push_back(tcsr.ets[k]);
90 | srcnodes->push_back(tcsr.nbr[k]);
91 | dstindex->push_back(dst_idx);
92 | }
93 |
94 | inline void TemporalSampler::
95 | combine_coo(TemporalBlock &block,
96 | std::vector **eid,
97 | std::vector **ets,
98 | std::vector **srcnodes,
99 | std::vector **dstindex,
100 | std::vector &out_nodes) {
101 |
102 | std::vector scan_nodes;
103 | std::vector scan_edges;
104 | scan_nodes.push_back(0);
105 | scan_edges.push_back(0);
106 | for (int tid = 0; tid < _num_threads; tid++) {
107 | scan_nodes.push_back(scan_nodes.back() + out_nodes[tid]);
108 | scan_edges.push_back(scan_edges.back() + eid[tid]->size());
109 | }
110 |
111 | IdI32 num_edges = scan_edges.back();
112 | block.dstindex.resize(num_edges);
113 | block.srcnodes.resize(num_edges);
114 | block.eid.resize(num_edges);
115 | block.ets.resize(num_edges);
116 | block.num_edges = num_edges;
117 |
118 | #pragma omp parallel for schedule(static, 1)
119 | for (int tid = 0; tid < _num_threads; tid++) {
120 | std::transform(dstindex[tid]->begin(), dstindex[tid]->end(),
121 | dstindex[tid]->begin(), [&](auto &v) { return v + scan_nodes[tid]; });
122 | std::copy(eid[tid]->begin(), eid[tid]->end(), block.eid.begin() + scan_edges[tid]);
123 | std::copy(ets[tid]->begin(), ets[tid]->end(), block.ets.begin() + scan_edges[tid]);
124 | std::copy(srcnodes[tid]->begin(), srcnodes[tid]->end(), block.srcnodes.begin() + scan_edges[tid]);
125 | std::copy(dstindex[tid]->begin(), dstindex[tid]->end(), block.dstindex.begin() + scan_edges[tid]);
126 | delete eid[tid];
127 | delete ets[tid];
128 | delete srcnodes[tid];
129 | delete dstindex[tid];
130 | }
131 | }
132 |
133 | } // namespace tglite
134 |
--------------------------------------------------------------------------------
/lib/tcsr.cpp:
--------------------------------------------------------------------------------
1 | #include "tglite/tcsr.h"
2 |
3 | namespace tglite {
4 |
// Adjacency-list entry used while building a TCSR: the neighboring node,
// the originating edge id, and the edge timestamp.
struct ETuple {
  IdI32 nbr;  // neighbor node id
  IdI32 eid;  // edge id this entry came from
  TsF32 ets;  // edge timestamp
  // Strict-weak ordering by timestamp, used to sort adjacency lists.
  static bool cmp_ts(const ETuple &a, const ETuple &b) {
    return a.ets < b.ets;
  }
};
13 |
14 | TCSR create_tcsr(py::array_t &edges, py::array_t ×, size_t num_nodes) {
15 | auto *edges_ptr = static_cast(edges.request().ptr);
16 | auto *times_ptr = static_cast(times.request().ptr);
17 |
18 | std::vector> adj_list(num_nodes);
19 | for (IdI32 eid = 0; eid < edges.shape(0); eid++) {
20 | IdI32 src = edges_ptr[eid * 2];
21 | IdI32 dst = edges_ptr[eid * 2 + 1];
22 | TsF32 ets = times_ptr[eid];
23 | adj_list[src].push_back({dst, eid, ets});
24 | adj_list[dst].push_back({src, eid, ets});
25 | }
26 |
27 | TCSR tcsr;
28 | for (auto &adj : adj_list) {
29 | std::sort(adj.begin(), adj.end(), ETuple::cmp_ts);
30 | tcsr.ind.push_back(tcsr.ind.back() + adj.size());
31 | for (auto &tuple : adj) {
32 | tcsr.nbr.push_back(tuple.nbr);
33 | tcsr.eid.push_back(tuple.eid);
34 | tcsr.ets.push_back(tuple.ets);
35 | }
36 | adj.clear();
37 | adj.shrink_to_fit();
38 | }
39 |
40 | return tcsr;
41 | }
42 |
43 | } // namespace tglite
44 |
--------------------------------------------------------------------------------
/lib/utils.cpp:
--------------------------------------------------------------------------------
1 | #include "tglite/utils.h"
2 | #include
3 |
4 | namespace tglite {
5 |
6 | // torch::Tensor index_pinned(torch::Tensor &input, torch::Tensor &index) {
7 | // int64_t nrows = index.size(0);
8 | // int64_t nfeats = input.size(1);
9 | //
10 | // auto opts = torch::TensorOptions()
11 | // .device(input.device())
12 | // .dtype(input.dtype())
13 | // .pinned_memory(true);
14 | // auto out = torch::zeros({nrows, nfeats}, opts);
15 | //
16 | // auto *out_ptr = out.accessor().data();
17 | // auto *inp_ptr = input.accessor().data();
18 | // auto *idx_ptr = index.accessor().data();
19 | //
20 | // #pragma omp parallel for
21 | // for (int64_t i = 0; i < nrows; i++) {
22 | // auto idx = idx_ptr[i];
23 | // auto inp_start = inp_ptr + idx * nfeats;
24 | // std::copy(inp_start, inp_start + nfeats, out_ptr + i * nfeats);
25 | // }
26 | //
27 | // return out;
28 | // }
29 |
30 | // py::tuple find_last_message(py::array_t &uniq_nodes, py::array_t &edges) {
31 | // ssize_t num_nodes = uniq_nodes.size();
32 | // ssize_t num_edges = edges.shape(0);
33 | //
34 | // auto *msg_order = new std::vector;
35 | // auto *msg_index = new std::vector;
36 | // msg_order->resize(num_nodes * 2);
37 | // msg_index->resize(num_nodes);
38 | //
39 | // auto *nodes_ptr = static_cast(uniq_nodes.request().ptr);
40 | // auto *edges_ptr = static_cast(edges.request().ptr);
41 | //
42 | // #pragma omp parallel for schedule(static)
43 | // for (ssize_t i = 0; i < num_nodes; i++) {
44 | // IdI32 nid = nodes_ptr[i];
45 | // for (ssize_t e = num_edges - 1; e >= 0; e--) {
46 | // if (edges_ptr[e * 2 + 0] == nid) {
47 | // // is src node, order is same
48 | // (*msg_order)[i * 2 + 0] = edges_ptr[e * 2 + 0];
49 | // (*msg_order)[i * 2 + 1] = edges_ptr[e * 2 + 1];
50 | // (*msg_index)[i] = e;
51 | // break;
52 | // } else if (edges_ptr[e * 2 + 1] == nid) {
53 | // // is dst node, order is flipped
54 | // (*msg_order)[i * 2 + 0] = edges_ptr[e * 2 + 1];
55 | // (*msg_order)[i * 2 + 1] = edges_ptr[e * 2 + 0];
56 | // (*msg_index)[i] = e;
57 | // break;
58 | // }
59 | // }
60 | // }
61 | //
62 | // py::array_t res_order = to_pyarray_owned(msg_order);
63 | // py::array_t res_index = to_pyarray_owned(msg_index);
64 | // return py::make_tuple(res_order, res_index);
65 | // }
66 |
67 | py::array_t find_latest_uniq(py::array_t &uniq, py::array_t &nodes, py::array_t ×) {
68 | ssize_t num_uniq = uniq.size();
69 | ssize_t num_nodes = nodes.size();
70 |
71 | auto *index = new std::vector;
72 | index->resize(num_uniq);
73 |
74 | auto *uniq_ptr = static_cast(uniq.request().ptr);
75 | auto *node_ptr = static_cast(nodes.request().ptr);
76 | auto *time_ptr = static_cast(times.request().ptr);
77 |
78 | #pragma omp parallel for schedule(static)
79 | for (ssize_t i = 0; i < num_uniq; i++) {
80 | IdI32 nid = uniq_ptr[i];
81 | TsF32 max = -1.0f;
82 | for (ssize_t j = num_nodes - 1; j >= 0; j--) {
83 | if (node_ptr[j] == nid && time_ptr[j] > max) {
84 | max = time_ptr[j];
85 | (*index)[i] = j;
86 | }
87 | }
88 | }
89 |
90 | py::array_t res = to_pyarray_owned(index);
91 | return res;
92 | }
93 |
94 | } // namespace tglite
95 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "tglite"
3 | # dynamic = ["version"]
4 | version = "0.0.4"
5 | description = "Temporal GNN Lightweight Framework"
6 | readme = "README.md"
7 | license = {file = "LICENSE"}
8 | authors = [
9 | {name = "Yufeng Wang", email = "yufengwang05@gmail.com"},
10 | {name = "Charith Mendis", email = "charithm@illinois.edu"}
11 | ]
12 | maintainers = [
13 | {name = "Wanyu Zhao", email = "wanyu2@illinois.edu"}
14 | ]
15 | keywords = [
16 | "machine learning", "TGNN",
17 | ]
18 | classifiers = [
19 | "Development Status :: 3 - Alpha",
20 | "Intended Audience :: Developers",
21 | "Intended Audience :: Science/Research",
22 | "License :: OSI Approved :: Apache Software License",
23 | "Operating System :: POSIX :: Linux",
24 | "Programming Language :: Python :: 3",
25 | "Programming Language :: Python :: 3.7",
26 | "Programming Language :: Python :: 3.8",
27 | "Programming Language :: Python :: 3.9",
28 | "Programming Language :: Python :: 3.10",
29 | ]
30 | requires-python = ">=3.7, <3.11"
31 | dependencies = [
32 | 'numpy==1.21.6; python_version == "3.7"',
33 | 'numpy>=1.21.6, <1.25.0; python_version == "3.8"',
34 | 'numpy>=1.21.6, <1.26.0; python_version == "3.9"',
35 | 'numpy>=1.21.6; python_version == "3.10"',
36 | 'torch>=1.12.1, <2.0.0; python_version == "3.7"',
37 | 'torch>=1.12.1; python_version >= "3.8"',
38 | 'torch-scatter>=2.1.0, <2.1.2; python_version == "3.7"',
39 | 'torch-scatter>=2.1.0; python_version == "3.8"',
40 | ]
41 |
42 | [project.optional-dependencies]
43 | dev = ["sphinx"]
44 | test = ["pytest", "pytest-cov"]
45 | docs = ["sphinx_rtd_theme", "nbsphinx", "ipykernel"]
46 |
47 | [project.urls]
48 | homepage = "https://github.com/ADAPT-uiuc/tglite"
49 | documentation = "https://readthedocs.org"
50 | repository = "https://github.com/ADAPT-uiuc/tglite"
51 | issues = "https://github.com/ADAPT-uiuc/tglite/issues"
52 |
53 | # [tool.setuptools.dynamic]
54 | # version = {attr = "tglite.__version__"}
55 |
56 | [tool.setuptools.packages.find]
57 | where = ["python"]
58 | include = ["tglite*"]
59 | namespaces = false
60 |
61 | [build-system]
62 | requires = ["setuptools>=61.0", "wheel", "torch>=1.12.1"]
63 | build-backend = "setuptools.build_meta"
64 |
--------------------------------------------------------------------------------
/python/tglite/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | tglite: Temporal GNN Lightweight Framework
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | __version__ = "0.0.4"
8 |
9 | ### Re-exports
10 |
11 | import torch
12 | from ._graph import TGraph, from_csv
13 | from ._batch import TBatch
14 | from ._block import TBlock
15 | from ._frame import TFrame
16 | from ._memory import Memory
17 | from ._mailbox import Mailbox
18 | from ._sampler import TSampler
19 | from ._context import TContext
20 | from ._core import TError
21 | from . import _utils as utils
22 | from . import nn
23 | from . import op
24 |
25 |
def iter_edges(g: TGraph, size=1, start=None, end=None) -> EdgesIter:
    """
    Return an iterator that yields the graph's edges as TBatch mini-batches.

    :param TGraph g: The graph to iterate on.
    :param int size: Number of edges in each mini-batch.
    :param start: The starting edge index (defaults to the first edge).
    :type start: int or None
    :param end: The ending edge index (defaults to the number of edges).
    :type end: int or None
    :rtype: EdgesIter
    """
    return EdgesIter(g, size=size, start=start, end=end)
39 |
40 |
class EdgesIter(object):
    """Iterator over a TGraph's edges, yielding TBatch mini-batches."""

    def __init__(self, g: TGraph, size=1, start=None, end=None):
        """
        Create an edge iterator.

        :param TGraph g: The graph it iterates on.
        :param int size: Number of edges in each mini-batch.
        :param start: The starting edge index (defaults to the first edge).
        :type start: int or None
        :param end: The ending edge index (defaults to the number of edges).
        :type end: int or None
        """
        self._g = g
        self._size = size
        self._curr = start if start is not None else 0
        self._last = end if end is not None else g.num_edges()

    def __iter__(self) -> EdgesIter:
        return self

    def __next__(self) -> TBatch:
        # guard clause: stop once the cursor reaches the end index
        if self._curr >= self._last:
            raise StopIteration
        beg = self._curr
        self._curr = beg + self._size
        return TBatch(self._g, range=(beg, min(self._curr, self._last)))
69 |
--------------------------------------------------------------------------------
/python/tglite/_batch.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Optional, Tuple
2 | if TYPE_CHECKING:
3 | from ._graph import TGraph
4 | from ._context import TContext
5 |
6 | import numpy as np
7 | from torch import Tensor
8 |
9 | from ._core import TError
10 | from ._block import TBlock
11 | from ._stats import tt
12 |
13 |
class TBatch(object):
    """
    Represents a batch of temporal edges to process. A thin wrapper with a TGraph reference and without actually
    materializing any arrays until they are needed.
    """

    def __init__(self, g: 'TGraph', range: Tuple[int, int]):
        """
        Internal constructor for creating a TBatch.

        :param TGraph g: The TGraph.
        :param Tuple[int, int] range: The range of edge indices: beginning and ending edge index.
        """
        # NOTE: the parameter name `range` shadows the builtin; kept as-is
        # for backwards compatibility with keyword callers.
        self._g = g
        self._beg_idx = range[0]
        self._end_idx = range[1]
        self._neg_nodes = None  # optional 1-D array of negative node samples

    def __len__(self) -> int:
        """Returns the total number of edges in the batch."""
        return self._end_idx - self._beg_idx

    @property
    def g(self) -> 'TGraph':
        """Returns the TGraph."""
        return self._g

    @property
    def neg_nodes(self) -> Optional[np.ndarray]:
        """Get the negative nodes (None if not set)."""
        return self._neg_nodes

    @neg_nodes.setter
    def neg_nodes(self, value: np.ndarray):
        """
        Set the negative nodes.

        :param np.ndarray value: An array of negative node samples.
        :raises TError: if value is not a 1-dimensional ndarray.
        """
        if not isinstance(value, np.ndarray):
            raise TError('negative samples must be an ndarray')
        if len(value.shape) != 1:
            raise TError('negative samples must be 1-dimensional')
        self._neg_nodes = value

    def block(self, ctx: 'TContext') -> TBlock:
        """Creates the head TBlock of the batch, including negative nodes if set."""
        # timed via the shared `tt` stats object (input-preparation bucket)
        t_start = tt.start()
        blk = TBlock(ctx, 0, self.nodes(), self.times())
        tt.t_prep_input += tt.elapsed(t_start)
        return blk

    def block_adj(self, ctx: 'TContext') -> TBlock:
        """Creates the head TBlock with batch edges as neighbors (excluding negative nodes)."""
        # dstnodes is [srcs..., dsts...] and srcnodes is [dsts..., srcs...],
        # so each batch edge contributes a neighbor entry in both directions;
        # eids are tiled twice to match that doubled layout.
        dstnodes = self.nodes(include_negs=False)
        srcnodes = self.nodes(include_negs=False, reverse=True)
        dstindex = np.arange(len(dstnodes))
        eids = np.tile(self.eids(), 2)
        ets = self.times(include_negs=False)
        return TBlock(ctx, 0, dstnodes, ets, dstindex, srcnodes, eids, ets)

    def eids(self) -> np.ndarray:
        """
        Returns edge ids of the batch.

        :rtype: np.ndarray
        """
        return np.arange(self._beg_idx, self._end_idx, dtype=np.int32)

    def edges(self) -> np.ndarray:
        """
        Returns the edges in the batch as a two-column ndarray, where the first column represents the source
        node index and the second column represents the destination node index.

        :rtype: np.ndarray
        """
        return self.g._edges[self._beg_idx:self._end_idx]

    def nodes(self, include_negs=True, reverse=False) -> np.ndarray:
        """
        Returns a node index array: [src, des(, neg)] if reverse is False or [des, src(, src)] if reverse is True.

        :param bool include_negs: Whether to include negative nodes.
        :param bool reverse: Whether to reverse the edges.
        :rtype: np.ndarray
        """
        nids = self.g._edges[self._beg_idx:self._end_idx]
        nids = np.flip(nids, axis=1) if reverse else nids
        # transpose then flatten: layout is [first-column..., second-column...]
        nids = nids.T.reshape(-1)
        if self._neg_nodes is not None and include_negs:
            # when reversed, the negatives slot repeats the src nodes
            # (second half of nids) instead of the sampled negatives
            negs = nids[len(self):] if reverse else self._neg_nodes
            nids = np.concatenate([nids, negs])
        return nids.astype(np.int32)

    def times(self, include_dsts=True, include_negs=True) -> np.ndarray:
        """
        Returns timestamps corresponding to the nodes. It retrieves timestamps of the batch edges (as the timestamps for source nodes),
        repeating the timestamps to include destination nodes and negative nodes.

        :param bool include_dsts: Whether to include destination nodes of the edges as positive nodes.
        :param bool include_negs: Whether to include negative nodes.
        :rtype: np.ndarray
        """
        n_repeats = 2 if include_dsts else 1
        if include_negs and self._neg_nodes is not None:
            n_repeats += 1
        times = self.g._times[self._beg_idx:self._end_idx]
        return np.tile(times, n_repeats).astype(np.float32)

    def split_data(self, data: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        """
        Splits the data into multiple arrays, with each array containing a number of rows equal to the batch size.

        :param Tensor data: The source data to be split.
        :raises TError: If the length of data is not three times the batch size when negative nodes are included or two times otherwise.
        :return: A tuple (src, dst, neg), where neg is None if no negative nodes are specified.
        """
        size = len(self)
        if self._neg_nodes is not None:
            if data.shape[0] != 3 * size:
                raise TError('expected data to have 3 times batch size')
            dst = data[size:2 * size]
            neg = data[2 * size:]
        else:
            if data.shape[0] != 2 * size:
                raise TError('expected data to have 2 times batch size')
            dst = data[size:]
            neg = None
        src = data[:size]
        return (src, dst, neg)
--------------------------------------------------------------------------------
/python/tglite/_context.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, List
2 | if TYPE_CHECKING:
3 | from ._graph import TGraph
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 | from ._core import TError
9 |
10 |
class TContext(object):
    """Graph-level context and scratch space used by the tglite runtime."""

    def __init__(self, g: 'TGraph'):
        """
        Internal constructor for creating a TContext.

        :param TGraph g: The TGraph to operate on.
        """
        self._g = g
        self._training = True

        # pinned-memory staging buffers, keyed by layer
        self._efeat_pins = {}
        self._nfeat_pins = {}
        self._mem_data_pins = {}
        self._mail_pins = {}

        # embedding-caching state
        self._cache_enabled = False
        self._cache_dim_emb = None
        self._cache_limit = int(2e6)
        self._cache_tables = {}

        # time-precomputation state
        self._time_enabled = False
        self._time_window = int(1e4)
        self._time_tables = {}

    @property
    def graph(self) -> 'TGraph':
        """Returns the TGraph it is associated with."""
        return self._g

    def train(self):
        """Enables training mode and clears time tables and embedding cache tables."""
        self._training = True
        self._time_tables.clear()
        self._cache_tables.clear()

    def eval(self):
        """Disables training mode."""
        self._training = False

    def need_sampling(self, need: bool):
        """
        Creates the tcsr within the TGraph if sampling is needed, otherwise drops it.

        :param bool need: Whether sampling is required.
        """
        if need:
            self._g._init_tcsr()
        else:
            self._g._tcsr = None

    def enable_embed_caching(self, enabled: bool, dim_embed: int = None):
        """
        Performs embedding cache settings and clears cache tables.

        Validation happens before any state is mutated, so a failed call
        leaves the context unchanged.

        :param bool enabled: Whether to enable embedding caching.
        :param dim_embed: Dimension of node embeddings (required when enabling).
        :raises TError: If enabled is True and dim_embed is None.
        """
        # fix: validate before mutating; previously _cache_enabled was set
        # to True even when the call then raised for a missing dim_embed
        if enabled and dim_embed is None:
            raise TError('need dimension of embeddings')
        self._cache_enabled = enabled
        if enabled:
            self._cache_dim_emb = dim_embed
        self._cache_tables.clear()

    def set_cache_limit(self, limit: int):
        """
        Sets embedding cache limit and clears cache tables.

        :param int limit: Number of embeddings to cache.
        """
        self._cache_limit = limit
        self._cache_tables.clear()

    def enable_time_precompute(self, enabled: bool):
        """
        Performs time precomputation settings and clears time tables.

        :param bool enabled: Whether to enable time precomputation.
        """
        self._time_enabled = enabled
        self._time_tables.clear()

    def set_time_window(self, window: int):
        """
        Sets length of time window and clears time tables.

        :param int window: Length of time window.
        :raises TError: If window is negative.
        """
        if window < 0:
            raise TError('time window must be non-negative')
        self._time_window = window
        self._time_tables.clear()

    def _get_efeat_pin(self, layer: int, rows: int, dim: int) -> Tensor:
        """Pinned staging buffer for edge features at the given layer."""
        return self._get_pin(self._efeat_pins, layer, rows, [dim])

    def _get_nfeat_pin(self, layer: int, rows: int, dim: int) -> Tensor:
        """Pinned staging buffer for node features at the given layer."""
        return self._get_pin(self._nfeat_pins, layer, rows, [dim])

    def _get_mem_data_pin(self, layer: int, rows: int) -> Tensor:
        """Pinned staging buffer for node-memory data at the given layer."""
        return self._get_pin(self._mem_data_pins, layer, rows, [self._g.mem.dim()])

    def _get_mail_pin(self, layer: int, rows: int) -> Tensor:
        """Pinned staging buffer for mailbox data at the given layer."""
        return self._get_pin(self._mail_pins, layer, rows, self._g.mailbox.dims())

    def _get_pin(self, cache: dict, layer: int, rows: int, dims: List[int]) -> Tensor:
        """
        Creates/reshapes a pinned buffer and returns a view of the first `rows` rows.

        :param dict cache: The dictionary of buffers with layer numbers as keys.
        :param int layer: Which layer the pinned data belongs to.
        :param int rows: Number of rows of pinned data.
        :param List[int] dims: Trailing dimensions of each pinned entry.
        :return: Pinned buffer view.
        :rtype: Tensor
        """
        if layer not in cache:
            shape = tuple([rows] + dims)
            pin = torch.zeros(shape, pin_memory=True)
            cache[layer] = pin
            return pin
        pin = cache[layer]
        # grow (or reshape) in place when the cached buffer is too small
        # or has the wrong trailing dimensions
        if pin.shape[0] < rows or list(pin.shape[1:]) != dims:
            shape = tuple([rows] + dims)
            pin.resize_(shape)
        return pin[:rows]
145 |
--------------------------------------------------------------------------------
/python/tglite/_core.py:
--------------------------------------------------------------------------------
1 |
2 |
class TError(Exception):
    """Common base class for all errors raised by tglite."""
5 |
--------------------------------------------------------------------------------
/python/tglite/_frame.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import Tensor
4 |
5 | from ._core import TError
6 |
7 |
class TFrame(dict):
    """A dict of named tensors that all share the same leading dimension."""

    def __init__(self, dim=None):
        """Initialize an empty frame with the given leading dimension (0 if None)."""
        super().__init__()
        self._dim = dim if dim is not None else 0

    def dim(self) -> int:
        """Return the leading dimension shared by all stored tensors."""
        return self._dim

    def to(self, device, **kwargs) -> TFrame:
        """Return a new TFrame with every tensor moved to the given device.

        :param device: the device to which the tensor data should be moved
        :param **kwargs: additional keyword arguments for the pytorch Tensor.to() method
        :returns: a new TFrame object with the tensor data copied to the specified device
        """
        moved = TFrame(dim=self._dim)
        for name in self:
            moved[name] = self[name].to(device, **kwargs)
        return moved

    def __setitem__(self, key, value):
        """
        Store a tensor under `key`, validating type and leading dimension.

        :raises TError: if the value is not a tensor or if it does not have the expected leading dimension
        """
        if not isinstance(value, Tensor):
            raise TError("expected value to be a tensor")
        if len(value) != self._dim:
            raise TError(f"expected value to have leading dimension of {self._dim}, got {len(value)}")
        super().__setitem__(key, value)
43 |
--------------------------------------------------------------------------------
/python/tglite/_graph.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import Tensor
4 | from pathlib import Path
5 | from typing import Optional, Union
6 |
7 | from ._core import TError
8 | from ._frame import TFrame
9 | from ._memory import Memory
10 | from ._mailbox import Mailbox
11 | from ._utils import create_tcsr, check_edges_times, check_num_nodes
12 |
13 |
class TGraph(object):
    """A container for temporal graph and related tensor data. It initially stores temporal edges in COO format,
    sorted based on timestamp. While performing neighborhood sampling, it uses CSR format for faster lookups. TGLite
    automatically handles the construction and management of these graph formats without intervention from the user."""

    def __init__(self, edges: np.ndarray, times: np.ndarray, num_nodes: int = None):
        """
        Internal constructor for creating a TGraph.

        :param np.ndarray edges: Two-column array of (src, dst) node ids, one row per temporal edge.
        :param np.ndarray times: Timestamp for each edge row.
        :param int num_nodes: Total number of nodes (validated/inferred by check_num_nodes if None).
        """
        check_edges_times(edges, times)
        self._num_nodes = check_num_nodes(edges, num_nodes)
        self._efeat_frame = TFrame(dim=edges.shape[0])
        self._nfeat_frame = TFrame(dim=self._num_nodes)
        self._edata = TFrame(dim=edges.shape[0])
        self._ndata = TFrame(dim=self._num_nodes)
        self._edges = edges
        self._times = times
        self._tcsr = None  # CSR form, built lazily by _init_tcsr()
        self._mem = None
        self._mailbox = None
        self._storage_dev = torch.device('cpu')
        self._compute_dev = torch.device('cpu')

    @property
    def efeat(self) -> Optional[Tensor]:
        """Returns edge feature (None if not set)"""
        return self._efeat_frame.get('f')

    @efeat.setter
    def efeat(self, value):
        """
        Sets edge feature (clears it when value is None)

        :param value: edge feature
        """
        if value is None:
            self._efeat_frame.clear()
        else:
            self._efeat_frame['f'] = value

    @property
    def nfeat(self) -> Optional[Tensor]:
        """Returns node feature (None if not set)"""
        return self._nfeat_frame.get('f')

    @nfeat.setter
    def nfeat(self, value):
        """
        Sets node feature (clears it when value is None)

        :param value: node feature
        """
        if value is None:
            self._nfeat_frame.clear()
        else:
            self._nfeat_frame['f'] = value

    @property
    def edata(self) -> TFrame:
        """Returns edge data"""
        return self._edata

    @property
    def ndata(self) -> TFrame:
        """Returns node data"""
        return self._ndata

    @property
    def mem(self) -> Optional[Memory]:
        """Returns node memory"""
        return self._mem

    @mem.setter
    def mem(self, value: Memory):
        """
        Sets node memory

        :param Memory value: node memory to set
        :raises TError: if value is not a Memory instance or its length doesn't equal to number of nodes,
            or value is not on this TGraph's storage device.
        """
        if not isinstance(value, Memory):
            raise TError('invalid memory object')
        if len(value) != self._num_nodes:
            raise TError('memory number of nodes mismatch')
        if value.device != self._storage_dev:
            raise TError('memory storage device mismatch')
        self._mem = value

    @property
    def mailbox(self) -> Optional[Mailbox]:
        """Returns node mailbox"""
        return self._mailbox

    @mailbox.setter
    def mailbox(self, value: Mailbox):
        """
        Sets mailbox

        :param Mailbox value: mailbox to set
        :raises TError: if value is not a Mailbox instance or it is not on this TGraph's storage device.
        """
        if not isinstance(value, Mailbox):
            raise TError('invalid mailbox object')
        if value.device != self._storage_dev:
            raise TError('mailbox storage device mismatch')
        # NOTE(review): unlike the memory setter, the mailbox size is not
        # validated against num_nodes here ("more checks" TODO in original).
        self._mailbox = value

    def storage_device(self) -> torch.device:
        """Returns TGraph's storage device"""
        return self._storage_dev

    def compute_device(self) -> torch.device:
        """Returns TGraph's computing device"""
        return self._compute_dev

    def num_nodes(self) -> int:
        """
        Total number of nodes

        :rtype: int
        """
        return self._num_nodes

    def num_edges(self) -> int:
        """
        Total number of edges

        :rtype: int
        """
        return self._edges.shape[0]

    def set_compute(self, device):
        """Sets computing device"""
        self._compute_dev = torch.device(device)

    def move_data(self, device, **kwargs):
        """Moves tensor data to device while keeping graph structure on CPU"""
        if self._storage_dev == device:
            return
        self._efeat_frame = self._efeat_frame.to(device, **kwargs)
        self._nfeat_frame = self._nfeat_frame.to(device, **kwargs)
        # fix: these two assignments previously ended with trailing commas,
        # which silently wrapped the frames in 1-tuples and broke later access
        self._edata = self._edata.to(device, **kwargs)
        self._ndata = self._ndata.to(device, **kwargs)
        if self._mem is not None:
            self._mem.move_to(device, **kwargs)
        if self._mailbox is not None:
            self._mailbox.move_to(device, **kwargs)
        self._storage_dev = device

    def _init_tcsr(self):
        """Creates tcsr of the graph if it doesn't exist"""
        if self._tcsr is None:
            self._tcsr = create_tcsr(self._edges, self._times, num_nodes=self._num_nodes)

    def _get_tcsr(self):
        """Returns the tcsr of the graph, creating it if needed"""
        self._init_tcsr()
        return self._tcsr
179 |
180 |
def from_csv(path: Union[str, Path], skip_first=True) -> TGraph:
    """
    Creates a TGraph from a csv file whose columns are: src node, dst node, timestamp.

    :param path: csv file path
    :type path: str or Path
    :param bool skip_first: whether to skip the first (header) line
    :rtype: TGraph
    :raises TError: if path doesn't exist
    """
    path = Path(path)
    if not path.exists():
        raise TError(f'file does not exist: {path}')

    srcs, dsts, stamps = [], [], []
    with path.open() as file:
        if skip_first:
            next(file)
        for line in file:
            cols = line.strip().split(',')
            srcs.append(int(cols[0]))
            dsts.append(int(cols[1]))
            stamps.append(float(cols[2]))

    # assemble the two id columns into an (n, 2) int32 edge array
    edges = np.concatenate(
        [
            np.array(srcs, dtype=np.int32).reshape(-1, 1),
            np.array(dsts, dtype=np.int32).reshape(-1, 1),
        ],
        axis=1,
    )
    etime = np.array(stamps, dtype=np.float32)
    return TGraph(edges, etime)
216 |
--------------------------------------------------------------------------------
/python/tglite/_mailbox.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import Tensor
4 | from typing import List, Union
5 |
6 |
class Mailbox(object):
    """A container for node mailbox messages.

    Stores up to `size` messages (and their timestamps) per node. When
    `size > 1` the per-node slots form a ring buffer whose next write
    position is tracked in `_next`.
    """

    def __init__(self, num_nodes: int, size: int, dim: int, device=None):
        """
        Create a zero-initialized mailbox.

        :param int num_nodes: Number of nodes.
        :param int size: Number of message slots per node.
        :param int dim: Dimension of a single message.
        :param device: Device for the mailbox tensors (CPU if None).
        """
        self._size = size
        self._device = torch.device('cpu' if device is None else device)

        # With size == 1 the slot axis is squeezed away, so mail is
        # (num_nodes, dim) and time is (num_nodes,).
        self._mail = torch.zeros((num_nodes, size, dim), device=self._device).squeeze(dim=1)
        self._time = torch.zeros((num_nodes, size), device=self._device).squeeze(dim=1)
        if size > 1:
            # ring-buffer write position per node
            self._next = torch.zeros(num_nodes, dtype=torch.long, device=self._device)

    @property
    def mail(self) -> Tensor:
        """Stored messages."""
        return self._mail

    @property
    def time(self) -> Tensor:
        """Timestamps of stored messages."""
        return self._time

    @property
    def device(self) -> torch.device:
        """Device where the mailbox tensors live."""
        return self._device

    def dims(self) -> List[int]:
        """Shape of a single node's mailbox entry (excluding the node axis)."""
        return list(self.mail.shape[1:])

    def reset(self):
        """Zero out all messages, timestamps, and (if present) ring positions."""
        self._mail.zero_()
        self._time.zero_()
        # fix: also rewind the ring-buffer positions so a reset mailbox
        # behaves exactly like a freshly constructed one
        if self._size > 1:
            self._next.zero_()

    def store(self, nids: Union[np.ndarray, Tensor], mail: Tensor, mail_ts: Tensor):
        """Write messages for the given nodes, overwriting the oldest slot when size > 1."""
        if not isinstance(nids, Tensor):
            nids = torch.from_numpy(nids).long()
        nids = nids.to(self._device)
        mail = mail.detach().to(self._device)
        mail_ts = mail_ts.detach().to(self._device)
        if self._size == 1:
            self._mail[nids] = mail
            self._time[nids] = mail_ts
        else:
            pos = self._next[nids]
            self._mail[nids, pos] = mail
            self._time[nids, pos] = mail_ts
            self._next[nids] = torch.remainder(pos + 1, self._size)

    def move_to(self, device, **kwargs):
        """Move mailbox tensors to another device (no-op if already there)."""
        if device is None or self._device == device:
            return
        self._mail = self._mail.to(device, **kwargs)
        self._time = self._time.to(device, **kwargs)
        # fix: the ring positions must follow the data, otherwise store()
        # would index device-local tensors with a CPU position tensor
        if self._size > 1:
            self._next = self._next.to(device, **kwargs)
        self._device = device
59 |
--------------------------------------------------------------------------------
/python/tglite/_memory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import Tensor
4 | from typing import Tuple, Union
5 |
6 | from ._core import TError
7 |
8 |
class Memory(object):
    """A container holding per-node memory vectors and their update timestamps."""

    def __init__(self, num_nodes: int, dim: int, device=None):
        """
        Internal constructor for creating a Memory. Initialize node memory and timestamps to zero.

        :param int num_nodes: Number of nodes
        :param int dim: Length of memory vector for a single node
        :param device: Which device to put node memory
        :type device: None or str or torch.device
        """
        self._device = torch.device(device if device is not None else 'cpu')

        self._data = torch.zeros((num_nodes, dim), device=self._device)
        self._time = torch.zeros(num_nodes, device=self._device)

        # sanity checks on the allocated shapes
        if list(self._data.shape) != [num_nodes, dim]:
            raise TError('memory data dimension mismatch')
        if self._time.shape[0] != num_nodes:
            raise TError('memory timestamp dimension mismatch')

    def __len__(self) -> int:
        """
        Return number of nodes in Memory

        :rtype: int
        """
        return self._data.size(0)

    @property
    def data(self) -> Tensor:
        """Return node memory

        :rtype: Tensor
        """
        return self._data

    @property
    def time(self) -> Tensor:
        """Return timestamps of current node memory

        :rtype: Tensor
        """
        return self._time

    @property
    def device(self) -> torch.device:
        """Return the device where Memory is located

        :rtype: torch.device
        """
        return self._device

    def dim(self) -> int:
        """Return length of memory vector for a single node"""
        return self._data.size(1)

    def reset(self):
        """Reset node memory and timestamps to zero"""
        self._data.zero_()
        self._time.zero_()

    def update(self, nids: Union[np.ndarray, Tensor], newdata: Tensor, newtime: Tensor):
        """Overwrite the memory vectors and timestamps for the given node ids."""
        idx = nids if isinstance(nids, Tensor) else torch.from_numpy(nids).long()
        idx = idx.to(self._device)
        self._data[idx] = newdata.detach().to(self._device)
        self._time[idx] = newtime.detach().to(self._device)

    def move_to(self, device, **kwargs):
        """Move memory tensors to another device (no-op if already there)."""
        if device is None or self._device == device:
            return
        self._data = self._data.to(device, **kwargs)
        self._time = self._time.to(device, **kwargs)
        self._device = device

    def backup(self) -> Tuple[Tensor, Tensor]:
        """Return a CPU copy of (data, time) suitable for restore()."""
        data = self._data.cpu()
        time = self._time.cpu()
        return (data.clone(), time.clone())

    def restore(self, state: Tuple[Tensor, Tensor]):
        """Restore a (data, time) snapshot previously produced by backup()."""
        data, time = state
        if self._data.shape != data.shape:
            raise TError('memory data dimension mismatch')
        if self._time.shape != time.shape:
            raise TError('memory timestamp dimension mismatch')
        self._data = data.clone().to(self._device)
        self._time = time.clone().to(self._device)
97 |
--------------------------------------------------------------------------------
/python/tglite/_sampler.py:
--------------------------------------------------------------------------------
1 | from . import _c
2 | from ._core import TError
3 | from ._block import TBlock
4 | from ._utils import get_num_cpus
5 | from ._stats import tt
6 |
7 |
class TSampler(object):
    """Samples temporal 1-hop neighborhoods for message-passing blocks."""

    def __init__(self, num_nbrs: int, strategy='recent', num_threads: int = None):
        """
        Internal constructor for creating a TSampler

        :param int num_nbrs: number of neighbors to sample per destination node
        :param str strategy: sampling strategy, 'recent' or 'uniform'
        :param int num_threads: number of threads for parallel sampling, set to number of cpus if not provided
        :raises TError: if strategy is not in ['recent', 'uniform']
        """
        if strategy not in ('recent', 'uniform'):
            raise TError(f'sampling strategy not supported: {strategy}')

        self._n_nbrs = num_nbrs
        self._strategy = strategy
        self._n_threads = num_threads if num_threads is not None else get_num_cpus()

        # Native sampler; third argument selects most-recent-neighbor mode.
        self._sampler = _c.TemporalSampler(
            self._n_threads,
            self._n_nbrs,
            self._strategy == 'recent')

    def sample(self, blk: TBlock) -> TBlock:
        """Updates block with sampled 1-hop source neighbors

        :param TBlock blk: block whose destination nodes are sampled for
        :returns: updated block
        """
        begin = tt.start()
        if blk.num_dst() > 0:
            raw = self._sampler.sample(blk._g._get_tcsr(), blk._dstnodes, blk._dsttimes)
            blk.set_nbrs(
                raw.copy_dstindex(),
                raw.copy_srcnodes(),
                raw.copy_eid(),
                raw.copy_ets())
        # timing is accumulated even when there was nothing to sample
        tt.t_sample += tt.elapsed(begin)
        return blk
48 |
--------------------------------------------------------------------------------
/python/tglite/_stats.py:
--------------------------------------------------------------------------------
1 | import time
2 | from pathlib import Path
3 |
4 |
class TimeTable(object):
    """Accumulates wall-clock timings for the phases of a training epoch."""

    # All per-epoch accumulators, in the order used for CSV output.
    _FIELDS = (
        't_epoch', 't_loop', 't_eval',
        't_forward', 't_backward', 't_sample',
        't_prep_batch', 't_prep_input', 't_post_update',
        't_mem_update', 't_time_zero', 't_time_nbrs', 't_self_attn',
    )

    def __init__(self):
        # Optional CSV file handle, opened/closed explicitly via csv_open/csv_close.
        self.csv = None
        self.reset_epoch()

    def reset_epoch(self):
        """Set all time records to zeros"""
        for name in self._FIELDS:
            setattr(self, name, 0.0)

    def start(self):
        """Return a starting timestamp for use with :meth:`elapsed`."""
        # Uncomment for better breakdown timings
        #torch.cuda.synchronize()
        return time.perf_counter()

    def elapsed(self, start):
        """Return seconds elapsed since ``start``."""
        # Uncomment for better breakdown timings
        #torch.cuda.synchronize()
        return time.perf_counter() - start

    def print_epoch(self, prefix='  '):
        """Print the timing breakdown of different components in an epoch"""
        lines = f'' \
            f'{prefix}epoch | total:{self.t_epoch:.2f}s loop:{self.t_loop:.2f}s eval:{self.t_eval:.2f}s\n' \
            f'{prefix} loop | forward:{self.t_forward:.2f}s backward:{self.t_backward:.2f}s sample:{self.t_sample:.2f}s prep_batch:{self.t_prep_batch:.2f}s prep_input:{self.t_prep_input:.2f}s post_update:{self.t_post_update:.2f}s\n' \
            f'{prefix} comp | mem_update:{self.t_mem_update:.2f}s time_zero:{self.t_time_zero:.2f}s time_nbrs:{self.t_time_nbrs:.2f}s self_attn:{self.t_self_attn:.2f}s\n'
        print(lines, end='')

    def csv_open(self, path):
        """Close the opened file (if any) and open a new file in write mode"""
        self.csv_close()
        self.csv = Path(path).open('w')

    def csv_close(self):
        """Close the opened file (if any)"""
        if self.csv is not None:
            self.csv.close()
            self.csv = None

    def csv_write_header(self):
        """Write the header line to the CSV file"""
        header = 'epoch,total,loop,eval,' \
            'forward,backward,sample,prep_batch,prep_input,post_update,' \
            'mem_update,time_zero,time_nbrs,self_attn'
        self.csv.write(header + '\n')

    def csv_write_line(self, epoch):
        """Write a line of timing information to the CSV file"""
        cells = [str(epoch)]
        cells.extend(str(getattr(self, name)) for name in self._FIELDS)
        self.csv.write(','.join(cells) + '\n')


# Global for accumulating timings.
tt = TimeTable()
72 |
--------------------------------------------------------------------------------
/python/tglite/_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from . import _c
5 | from ._core import TError
6 |
7 |
def get_num_cpus(default=16) -> int:
    """Return the machine's CPU count, or ``default`` when it cannot be determined."""
    detected = os.cpu_count()
    return detected if detected is not None else default
11 |
12 |
def check_edges_times(edges: np.ndarray, times: np.ndarray):
    """Validate that an edge array and its timestamp array are compatible.

    :param edges: (N, 2) array of int32 (src, dst) node-id pairs
    :param times: length-N array of float32 timestamps
    :raises TError: on mismatched lengths, wrong column count, or unsupported dtypes
    """
    n_edges = edges.shape[0]
    n_times = times.shape[0]
    if n_edges != n_times:
        raise TError("edge list and timestamps must have same leading dimension")
    if edges.shape[1] != 2:
        raise TError("edge list must have only 2 columns")
    if edges.dtype != np.int32:
        raise TError("currently only supports int32 node/edge ids")
    if times.dtype != np.float32:
        raise TError("currently only supports float32 timestamps")
22 |
23 |
def check_num_nodes(edges: np.ndarray, num_nodes: int = None) -> int:
    """Returns the number of nodes in the graph represented by the given edges

    When ``num_nodes`` is not given, it is inferred as ``max node id + 1``.

    :param edges: array of (src, dst) node-id pairs
    :param num_nodes: explicit node count; must be greater than the max node id
    :raises TError: if the specified number of nodes is not greater than the
        maximum node id appearing in the edges
    """
    max_nid = int(edges.max())
    if num_nodes is None:
        num_nodes = max_nid + 1
    if num_nodes <= max_nid:
        raise TError("number of nodes must be greater than max node id")
    return num_nodes
35 |
36 |
def create_tcsr(edges: np.ndarray, times: np.ndarray, num_nodes: int = None):
    """Validate inputs and build a temporal CSR structure via the C extension.

    :param edges: (N, 2) int32 edge array
    :param times: length-N float32 timestamp array
    :param num_nodes: optional explicit node count (inferred when omitted)
    """
    check_edges_times(edges, times)
    return _c.create_tcsr(edges, times, check_num_nodes(edges, num_nodes))
41 |
--------------------------------------------------------------------------------
/python/tglite/nn.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING
3 | if TYPE_CHECKING:
4 | from ._block import TBlock
5 | from ._context import TContext
6 |
7 | import torch
8 | import numpy as np
9 | from torch import Tensor
10 |
11 | from ._stats import tt
12 | from .op import precomputed_zeros, precomputed_times, edge_reduce, edge_view, edge_softmax
13 |
14 |
class TimeEncode(torch.nn.Module):
    """Encodes scalar timestamps into a ``dim_time``-dimensional cosine feature."""

    # A hidden tag used to detect if we have an instance of this builtin encoder
    # so that we can call the more optimized method of generating zeros, without
    # running into situation with circular dependencies.
    __tg_builtin_encoder__ = True

    def __init__(self, dim_time: int):
        '''
        Initializes the TimeEncode module, which encodes time information into a higher-dimensional space.

        :param dim_time: dimensionality of the encoded time
        '''
        super().__init__()
        # Frequencies spread log-uniformly over 1 .. 1e-9, one per output dim.
        freqs = 1 / 10 ** np.linspace(0, 9, dim_time)
        self.w = torch.nn.Linear(1, dim_time)
        self.w.weight = torch.nn.Parameter(
            torch.from_numpy(freqs).float().reshape(dim_time, 1))
        self.w.bias = torch.nn.Parameter(torch.zeros(dim_time).float())
        # Cached single zero used by zeros() to avoid allocating full vectors.
        self._z = torch.zeros(1).float()

    def zeros(self, size: int, device):
        '''
        Generates a tensor of zeros with the encoded time dimensionality.

        :param size: number of zero timestamps to encode
        :param device: device on which the result should live
        '''
        if torch.device(device) != self._z.device:
            self._z = self._z.to(device)
        # expand shares the single cached element instead of allocating memory
        return self(self._z.expand(size))

    def forward(self, ts: Tensor) -> Tensor:
        '''
        Forward pass of the TimeEncode module. Encodes the input time stamps into a high-dimensional space.

        :param ts: input time stamps
        '''
        return torch.cos(self.w(ts.unsqueeze(-1)))
56 |
57 |
class TemporalAttnLayer(torch.nn.Module):
    """Multi-head attention layer that mixes node, edge, and time features of a block."""

    def __init__(self, ctx: TContext, num_heads: int,
                 dim_node: int, dim_edge: int, dim_time: int, dim_out: int,
                 dropout=0.1):
        """
        Initializes the Temporal Attention Layer for processing dynamic graphs with temporal features.
        This layer uses multi-head attention mechanism to incorporate node, edge, and time features.

        :param ctx: context object
        :param num_heads: number of heads
        :param dim_node: dimension of node features
        :param dim_edge: dimension of edge features (0 disables edge features)
        :param dim_time: dimension of time features
        :param dim_out: dimension of output features; must be divisible by num_heads
        :param dropout: dropout rate
        """
        super().__init__()
        # each head gets dim_out / num_heads channels
        assert (dim_out % num_heads == 0)
        self.ctx = ctx
        self.num_heads = num_heads
        self.dim_edge = dim_edge
        self.dim_out = dim_out
        self.time_encode = TimeEncode(dim_time)
        # query: dst node features concatenated with encoded zero time delta
        self.w_q = torch.nn.Linear(dim_node + dim_time, dim_out)
        # key/value (fused): src node features (+ optional edge features) + encoded time delta
        self.w_kv = torch.nn.Linear(dim_node + dim_edge + dim_time, dim_out * 2)
        self.w_out = torch.nn.Linear(dim_node + dim_out, dim_out)
        self.attn_act = torch.nn.LeakyReLU(0.2)
        self.dropout = torch.nn.Dropout(dropout)
        self.layer_norm = torch.nn.LayerNorm(dim_out)

    def forward(self, blk: TBlock) -> Tensor:
        '''
        Forward pass of the Temporal Attention Layer. Applies a time-sensitive attention mechanism over
        the input graph block (blk) to produce node embeddings. The method handles both cases of graph blocks
        with and without edges.

        If the block has no edges, a zero-initialized tensor is concatenated with the destination node features.
        For blocks with edges, the method computes attention scores and aggregates neighbor features using
        the computed attention.

        :param blk: input graph block
        '''
        if blk.num_edges() == 0:
            # No neighbors: the aggregated part is all zeros; keep dst features.
            dev = blk.dstdata['h'].device
            out = torch.zeros(blk.num_dst(), self.dim_out, dtype=torch.float32, device=dev)
            out = torch.cat([out, blk.dstdata['h']], dim=1)
        else:
            # Encode time deltas (zero for dst self-time, actual deltas for nbrs),
            # optionally served from the context's precomputed cache.
            t_start = tt.start()
            zero_time_feat = precomputed_zeros(self.ctx, blk.layer, self.time_encode, blk.num_dst())
            tt.t_time_zero += tt.elapsed(t_start)
            t_start = tt.start()
            nbrs_time_feat = precomputed_times(self.ctx, blk.layer, self.time_encode, blk.time_deltas())
            tt.t_time_nbrs += tt.elapsed(t_start)
            t_start = tt.start()
            # Q rows are per-dst; Z rows are per-edge (src side) — presumably
            # edge_view below broadcasts Q to edge granularity; confirm in op module.
            Q = torch.cat([blk.dstdata['h'], zero_time_feat], dim=1)
            if self.dim_edge > 0:
                Z = torch.cat([blk.srcdata['h'], blk.efeat(), nbrs_time_feat], dim=1)
            else:
                Z = torch.cat([blk.srcdata['h'], nbrs_time_feat], dim=1)
            # drop references early to reduce peak memory before the projections
            del zero_time_feat
            del nbrs_time_feat

            Q = self.w_q(Q)
            Z = self.w_kv(Z)
            # the fused kv projection is split into key and value halves
            K = Z[:, :self.dim_out]
            V = Z[:, self.dim_out:]
            del Z

            Q = edge_view(blk, Q)
            # reshape to (rows, heads, channels-per-head) for per-head attention
            Q = torch.reshape(Q, (Q.shape[0], self.num_heads, -1))
            K = torch.reshape(K, (K.shape[0], self.num_heads, -1))
            V = torch.reshape(V, (V.shape[0], self.num_heads, -1))

            # dot-product score per edge and head
            attn = torch.sum(Q * K, dim=2)
            del Q
            del K

            attn = self.attn_act(attn)
            # normalize scores over each destination's incident edges
            attn = edge_softmax(blk, attn)
            attn = self.dropout(attn)

            # weight values by attention, then flatten heads back together
            out = torch.reshape(V * attn[:, :, None], (V.shape[0], -1))
            del attn

            # sum per-edge messages into their destination nodes
            out = edge_reduce(blk, out, op='sum')
            out = torch.cat([out, blk.dstdata['h']], dim=1)
            tt.t_self_attn += tt.elapsed(t_start)

        # final projection + residual-style mix of dst features, then norm
        out = self.w_out(out)
        out = torch.nn.functional.relu(self.dropout(out))
        out = self.layer_norm(out)
        return out
150 |
151 |
--------------------------------------------------------------------------------
/scripts/run-exp-large.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
#
# Run experiments for larger benchmarks.
#

# Resolve the repository root from this script's location.
tglite="$(cd "$(dirname "$0")"; cd ..; pwd)"
cd "$tglite/examples"

echo "start: $(date)"

for data in wiki-talk gdelt; do
    for model in apan jodie tgat tgn; do
        echo "tglite $data $model"
        # jodie takes no threading/optimization flags; the others get both.
        if [[ "$model" == "jodie" ]]; then
            "./exp/$model-gdelt.sh" --data "$data"
        else
            "./exp/$model-gdelt.sh" --data "$data" --n-threads "$(nproc)" --opt-all
        fi
        mv out-stats.csv "out-tglite-$data-$model.csv"
        echo
        echo "time: $(date)"
        echo
    done
done

echo "end: $(date)"
27 |
--------------------------------------------------------------------------------
/scripts/run-exp-slurm.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#SBATCH --mem=64g
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64     # <- match to OMP_NUM_THREADS
#SBATCH --partition=gpuA100x4
#SBATCH --time=00:10:00
#SBATCH --account=bbzw-delta-gpu
#SBATCH --job-name=jodie-wiki
#SBATCH --output=test.out
#SBATCH --error=test.err
### GPU options ###
#SBATCH --gpus-per-node=1
#SBATCH --gpus-per-task=1
#SBATCH --gpu-bind=verbose,per_task:1
###SBATCH --gpu-bind=none     # <- or closest

# Reset the software environment to a known state, then activate tglite.
source ~/.bashrc
conda deactivate
module purge
module load anaconda3_gpu
module list

conda activate tglite
conda info -e

# $(...) replaces deprecated backticks for command substitution.
echo "job is starting on $(hostname)"

# Resolve the repository root from this script's location.
tglite="$(cd "$(dirname "$0")"; cd ..; pwd)"
cd "$tglite/examples"

echo "start: $(date)"

srun python3 \
    jodie.py \
    --seed 0 \
    --prefix exp \
    --epochs 50 \
    --bsize 2000 \
    -d wiki

# Colon added for consistency with the "start:"/"end:" labels used by the
# other run-exp scripts.
echo "end: $(date)"

exit
45 |
--------------------------------------------------------------------------------
/scripts/run-exp.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
#
# Run experiments for standard benchmarks.
#

# Resolve the repository root from this script's location.
tglite="$(cd "$(dirname "$0")"; cd ..; pwd)"
cd "$tglite/examples"

echo "start: $(date)"

for data in wiki mooc reddit lastfm; do
  for model in apan jodie tgat tgn; do
    # Build the shared flags as an array: each element stays one argument
    # even if a value ever contains whitespace. The previous unquoted string
    # expansion depended on word-splitting and was also subject to globbing.
    common_flags=(--data "$data")
    if [[ "$model" != "jodie" ]]; then
      common_flags+=(--n-threads "$(nproc)")
    fi

    # Baseline run.
    echo "tglite $data $model"
    "./exp/$model.sh" "${common_flags[@]}"
    mv out-stats.csv "out-tglite-$data-$model.csv"
    echo
    echo "time: $(date)"
    echo

    # Run again with --move (results recorded as the "allgpu" variant).
    "./exp/$model.sh" "${common_flags[@]}" --move
    mv out-stats.csv "out-tglite-allgpu-$data-$model.csv"
    echo
    echo "time: $(date)"
    echo

    # Run with all optimizations enabled (jodie has none).
    if [[ "$model" != "jodie" ]]; then
      "./exp/$model.sh" "${common_flags[@]}" --opt-all
      mv out-stats.csv "out-tglite-opt-$data-$model.csv"
      echo
      echo "time: $(date)"
      echo
    fi
  done
done

echo "end: $(date)"
42 |
--------------------------------------------------------------------------------
/scripts/setup-aws.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
#
# Script to setup an AWS EC2 instance with the following expected provisioning:
#
# - AMI: Amazon Linux 2 AMI with NVIDIA TESLA GPU Driver
# - Instance Type: p3.8xlarge (32 vCPU, 244 GiB, Tesla V100 GPU)
# - Storage: at least 80 GB
#

# Install a private miniconda under $HOME so no system-wide changes are needed.
conda_dir="$HOME/.conda"
conda_bin="$conda_dir/bin/conda"

echo
echo ">> installing conda"
echo

# -b: non-interactive (batch) install; -p: target prefix.
curl -sL -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py37_4.12.0-Linux-x86_64.sh
bash ~/miniconda.sh -b -p "$conda_dir"
rm ~/miniconda.sh

"$conda_bin" init
# Do not auto-activate base on login; environments are activated explicitly.
"$conda_bin" config --set auto_activate_base false
source ~/.conda/etc/profile.d/conda.sh

echo
echo ">> setting up cuda11"
echo

# CUDA 11.8 runtime libraries in a dedicated environment.
conda create -n cuda11
conda activate cuda11
conda install -c "nvidia/label/cuda-11.8.0" cuda-libraries
# Make the CUDA libs visible to dynamically-linked binaries in future shells.
echo 'export LD_LIBRARY_PATH="$HOME/.conda/envs/cuda11/lib:$LD_LIBRARY_PATH"' >> ~/.bashrc

echo
echo ">> done! please restart your shell session"
36 |
--------------------------------------------------------------------------------
/scripts/setup-repo.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
#
# Script to setup environment for this repo.
#

# Resolve the repository root from this script's location.
tglite="$(cd "$(dirname "$0")"; cd ..; pwd)"
cd "$tglite"

echo
echo ">> setting up environment"
echo

source ~/.conda/etc/profile.d/conda.sh
conda create -n tglite python=3.7
conda activate tglite

echo
echo ">> installing python packages"
echo

# Pinned torch / torch-scatter builds matching CUDA 11.6.
pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install torch-scatter==2.1.0+pt112cu116 -f https://data.pyg.org/whl/torch-1.12.1+cu116.html

echo
echo ">> installing tglite package"
echo

# Development (editable) install so local changes take effect without reinstalling.
python setup.py develop

echo
echo ">> setting up example applications"
echo

cd "$tglite/examples"
pip install -r requirements.txt
./download-data.sh
# Pre-generate data files for the larger wiki-talk benchmark.
python gen-data-files.py --data wiki-talk

echo
echo ">> done!"
41 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CppExtension
4 |
5 |
def main():
    """Build the tglite package together with its C++ extension module."""
    root = Path(__file__).absolute().parent
    extension = CppExtension(
        name="tglite._c",
        sources=[
            "lib/bind.cpp",
            "lib/cache.cpp",
            "lib/dedup.cpp",
            "lib/sampler.cpp",
            "lib/tcsr.cpp",
            "lib/utils.cpp",
        ],
        include_dirs=[root / "include/"],
        # OpenMP enabled for both compile and link steps.
        extra_compile_args=["-std=c++14", "-fopenmp"],
        extra_link_args=["-fopenmp"],
    )
    setup(
        ext_modules=[extension],
        cmdclass={"build_ext": BuildExtension},
    )


if __name__ == "__main__":
    main()
31 |
--------------------------------------------------------------------------------
/tests/python/data/edges.csv:
--------------------------------------------------------------------------------
1 | src,dst,time
2 | 1,8228,0.0
3 | 2,8229,36.0
4 | 2,8229,77.0
5 | 3,8230,131.0
6 | 2,8229,150.0
7 | 3,8230,153.0
8 | 4,8231,169.0
9 | 2,8229,217.0
10 | 5,8232,218.0
11 | 5,8232,242.0
12 | 6,8233,295.0
13 | 2,8229,300.0
14 | 2,8229,376.0
15 | 3,8230,380.0
16 | 7,8234,384.0
17 | 8,8228,384.0
18 | 4,8231,387.0
19 | 9,8235,395.0
20 | 2,8229,402.0
21 | 3,8230,402.0
22 | 3,8230,432.0
23 | 10,8236,454.0
24 | 2,8229,465.0
25 | 7,8234,563.0
26 | 5,8232,578.0
27 | 2,8229,623.0
28 | 4,8231,646.0
29 | 3,8230,674.0
30 | 2,8229,692.0
31 | 3,8230,715.0
32 | 11,8237,729.0
33 | 12,8238,742.0
34 | 4,8231,742.0
35 | 4,8231,809.0
36 | 13,8239,854.0
37 | 2,8229,854.0
38 | 5,8232,872.0
39 | 14,8240,904.0
40 | 15,8241,939.0
41 | 13,8239,959.0
42 | 2,8229,989.0
43 | 15,8241,1026.0
44 | 15,8241,1072.0
45 | 16,8242,1142.0
46 | 13,8239,1151.0
47 | 17,8243,1153.0
48 | 18,8244,1177.0
49 | 5,8232,1199.0
50 | 5,8232,1214.0
51 | 4,8231,1226.0
52 | 19,8245,1255.0
53 | 12,8238,1301.0
54 | 20,8246,1307.0
55 | 21,8247,1308.0
56 | 22,8248,1314.0
57 | 23,8249,1340.0
58 | 4,8231,1401.0
59 | 18,8250,1451.0
60 | 12,8238,1496.0
61 | 5,8232,1508.0
62 | 24,8251,1529.0
63 | 5,8232,1531.0
64 | 25,8252,1543.0
65 | 13,8239,1555.0
66 | 26,8253,1600.0
67 | 4,8231,1607.0
68 | 5,8232,1630.0
69 | 3,8230,1643.0
70 | 27,8254,1651.0
71 | 5,8232,1664.0
72 | 5,8232,1677.0
73 | 28,8255,1682.0
74 | 12,8238,1706.0
75 | 13,8239,1713.0
76 | 29,8256,1724.0
77 | 19,8245,1729.0
78 | 30,8257,1771.0
79 | 12,8238,1789.0
80 | 3,8230,1797.0
81 | 9,8235,1805.0
82 | 5,8232,1814.0
83 | 5,8232,1827.0
84 | 31,8258,1908.0
85 | 32,8259,1914.0
86 | 29,8256,1925.0
87 | 29,8256,1996.0
88 | 17,8250,2022.0
89 | 33,8260,2053.0
90 | 5,8232,2063.0
91 | 4,8231,2075.0
92 | 34,8261,2079.0
93 | 9,8235,2098.0
94 | 5,8232,2127.0
95 | 17,8250,2131.0
96 | 25,8252,2233.0
97 | 5,8232,2285.0
98 | 4,8231,2290.0
99 | 2,8229,2313.0
100 | 4,8231,2333.0
101 | 35,8262,2343.0
102 |
--------------------------------------------------------------------------------
/tests/python/test_block.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from pathlib import Path
4 |
5 | import tglite as tg
6 |
7 |
def test_tblock():
    """A fresh block built from a batch exposes endpoints but no sampled neighbors."""
    graph = tg.from_csv(Path(__file__).parent / 'data/edges.csv')
    ctx = tg.TContext(graph)
    batch = next(tg.iter_edges(graph, size=10))
    assert len(batch) == 10

    block = batch.block(ctx)
    assert block.layer == 0
    # each batch edge contributes both its src and dst as destinations
    assert block.num_dst() == 2 * len(batch)
    assert block.num_dst() == len(block.dsttimes)
    assert block.has_nbrs() == False  # no neighbor info before sampling
19 |
--------------------------------------------------------------------------------
/tests/python/test_frame.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | import tglite as tg
5 |
6 |
def test_tframe():
    """Tensors stored in a TFrame keep their shape and contents."""
    fr = tg.TFrame(dim=3)
    fr['f'] = torch.ones((3, 2))
    assert fr['f'].shape == (3, 2)
    assert fr['f'].sum() == 6
11 | assert frame['f'].sum() == 6
12 |
13 |
def test_tframe_dim():
    """dim() reflects the constructor argument, defaulting to 0."""
    assert tg.TFrame().dim() == 0
    assert tg.TFrame(dim=16).dim() == 16
18 | assert frame.dim() == 16
19 |
20 |
def test_tframe_checks_dim():
    """Assigning a tensor with the wrong leading dimension raises TError."""
    with pytest.raises(tg.TError) as exinfo:
        fr = tg.TFrame(dim=3)
        fr['f'] = torch.ones((1, 2))
    assert "dimension of 3, got 1" in str(exinfo.value)
25 | assert "dimension of 3, got 1" in str(exinfo.value)
26 |
27 |
def test_tframe_only_tensors():
    """Non-tensor values are rejected with TError."""
    fr = tg.TFrame(dim=3)
    fr['tensor'] = torch.randn(3, 2)
    with pytest.raises(tg.TError) as exinfo:
        fr['list'] = [1, 2, 3]
    assert "expected value to be a tensor" in str(exinfo.value)
33 | assert "expected value to be a tensor" in str(exinfo.value)
34 |
35 |
def test_tframe_dict_behavior():
    """TFrame supports dict-style membership, length, iteration, and clearing."""
    fr = tg.TFrame(dim=3)
    widths = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    for key, width in widths.items():
        fr[key] = torch.randn((3, width))

    assert 'a' in fr
    assert len(fr) == 4
    for key, val in fr.items():
        assert val.shape[1] == widths[key]

    # clearing removes entries but preserves the configured dim
    fr.clear()
    assert fr.dim() == 3
    assert len(fr) == 0
53 | assert len(frame) == 0
54 |
--------------------------------------------------------------------------------
/tests/python/test_graph.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | from pathlib import Path
5 |
6 | import tglite as tg
7 |
8 |
def test_from_csv_nofile():
    """Loading a non-existent path raises a TError mentioning the missing file."""
    with pytest.raises(tg.TError, match="file does not exist"):
        tg.from_csv("foobar")
13 |
14 |
def test_from_csv():
    """The bundled sample CSV loads with the expected edge and node counts."""
    csv_path = Path(__file__).parent / 'data/edges.csv'
    graph = tg.from_csv(csv_path)
    assert graph.num_edges() == 100
    assert graph.num_nodes() == 8263
20 |
21 |
def test_tgraph():
    """A TGraph built from arrays exposes counts, feature frames, and devices."""
    edge_list = np.array([[0, 1], [0, 2], [1, 2]], dtype=np.int32)
    stamps = np.array([10, 11, 12], dtype=np.float32)

    graph = tg.TGraph(edge_list, stamps)
    graph.edata['f'] = torch.randn((3, 2))
    graph.ndata['f'] = torch.randn((3, 2))

    assert graph.num_edges() == 3
    assert graph.num_nodes() == 3
    assert graph.edata['f'].shape[1] == 2
    assert graph.ndata['f'].shape[1] == 2
    assert str(graph.storage_device()) == 'cpu'
    assert str(graph.compute_device()) == 'cpu'
36 |
--------------------------------------------------------------------------------
/tests/python/test_tglite.py:
--------------------------------------------------------------------------------
1 | import tglite as tg
2 |
3 |
def test_terror():
    """TError stringifies to the message it was constructed with."""
    assert str(tg.TError("message")) == "message"
7 |
--------------------------------------------------------------------------------
/tests/python/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | import tglite as tg
5 |
6 |
def test_tcsr():
    """create_tcsr builds symmetric adjacency with per-edge ids and timestamps."""
    edge_list = np.array([[0, 1], [0, 2], [1, 2]], dtype=np.int32)
    stamps = np.array([10, 11, 12], dtype=np.float32)
    tcsr = tg.utils.create_tcsr(edge_list, stamps)

    # one index entry per node plus the trailing sentinel
    assert len(tcsr.ind) == len(edge_list) + 1
    assert list(tcsr.ind) == [0, 2, 4, 6]
    assert list(tcsr.nbr) == [1, 2, 0, 2, 0, 1]
    assert list(tcsr.eid) == [0, 1, 0, 2, 1, 2]
    assert list(tcsr.ets) == [10, 11, 10, 12, 11, 12]
17 |
--------------------------------------------------------------------------------