├── .gitattributes
├── .github
└── workflows
│ └── main.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── dataset
├── chickenpox.json
├── england_covid.json
├── montevideo_bus.json
├── mtm_1.json
├── pedalme_london.json
├── twitter_tennis_rg17.json
├── twitter_tennis_uo17.json
└── wikivital_mathematics.json
├── docs
├── Makefile
├── index.html
├── mapping.py
├── requirements.txt
└── source
│ ├── _figures
│ └── build.sh
│ ├── _static
│ ├── css
│ │ └── custom.css
│ └── img
│ │ ├── logo.jpg
│ │ └── text_logo.jpg
│ ├── conf.py
│ ├── index.rst
│ ├── modules
│ ├── dataset.rst
│ ├── root.rst
│ └── signal.rst
│ └── notes
│ ├── installation.rst
│ ├── introduction.rst
│ └── resources.rst
├── examples
├── indexBatching
│ ├── A3TGCN
│ │ ├── metr_la_main.py
│ │ └── pems_ddp.py
│ ├── DCRNN
│ │ ├── chicken_pox_main.py
│ │ ├── pems_allLA_main.py
│ │ ├── pems_bay_main.py
│ │ ├── pems_ddp.py
│ │ ├── pems_main.py
│ │ ├── submit.sh
│ │ ├── utils.py
│ │ └── windmill_main.py
│ └── README.md
└── recurrent
│ ├── a3tgcn2_example.py
│ ├── a3tgcn_example.py
│ ├── agcrn_example.py
│ ├── dcrnn_example.py
│ ├── dygrencoder_example.py
│ ├── evolvegcnh_example.py
│ ├── evolvegcno_example.py
│ ├── gclstm_example.py
│ ├── gconvgru_example.py
│ ├── gconvlstm_example.py
│ ├── lightning_example.py
│ ├── lrgcn_example.py
│ ├── mpnnlstm_example.py
│ └── tgcn_example.py
├── notebooks
├── a3tgcn_for_traffic_forecasting.ipynb
├── astgcn_for_traffic_flow_forecasting.ipynb
└── processing_traffic_data_for_deep_learning_projects.ipynb
├── readthedocs.yml
├── setup.py
├── test
├── attention_test.py
├── batch_test.py
├── dataset_test.py
├── heterogeneous_test.py
└── recurrent_test.py
└── torch_geometric_temporal
├── __init__.py
├── dataset
├── __init__.py
├── chickenpox.py
├── encovid.py
├── metr_la.py
├── montevideo_bus.py
├── mtm.py
├── pedalme.py
├── pems.py
├── pemsAllLA.py
├── pems_bay.py
├── twitter_tennis.py
├── wikimath.py
├── windmilllarge.py
├── windmillmedium.py
└── windmillsmall.py
├── nn
├── __init__.py
├── attention
│ ├── __init__.py
│ ├── astgcn.py
│ ├── dnntsp.py
│ ├── gman.py
│ ├── mstgcn.py
│ ├── mtgnn.py
│ ├── stgcn.py
│ └── tsagcn.py
├── hetero
│ ├── __init__.py
│ └── heterogclstm.py
└── recurrent
│ ├── __init__.py
│ ├── agcrn.py
│ ├── attentiontemporalgcn.py
│ ├── dcrnn.py
│ ├── dygrae.py
│ ├── evolvegcnh.py
│ ├── evolvegcno.py
│ ├── gc_lstm.py
│ ├── gconv_gru.py
│ ├── gconv_lstm.py
│ ├── lrgcn.py
│ ├── mpnn_lstm.py
│ └── temporalgcn.py
└── signal
├── __init__.py
├── dynamic_graph_static_signal.py
├── dynamic_graph_static_signal_batch.py
├── dynamic_graph_temporal_signal.py
├── dynamic_graph_temporal_signal_batch.py
├── dynamic_hetero_graph_static_signal.py
├── dynamic_hetero_graph_static_signal_batch.py
├── dynamic_hetero_graph_temporal_signal.py
├── dynamic_hetero_graph_temporal_signal_batch.py
├── index_dataset.py
├── static_graph_temporal_signal.py
├── static_graph_temporal_signal_batch.py
├── static_hetero_graph_temporal_signal.py
├── static_hetero_graph_temporal_signal_batch.py
└── train_test_split.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-detectable=false
2 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches: [ master ]
6 | pull_request:
7 | branches: [ master ]
8 | workflow_dispatch:
9 |
10 | jobs:
11 | build:
12 | runs-on: ${{ matrix.os }}
13 |
14 | strategy:
15 | matrix:
16 | os: [ubuntu-22.04]
17 |
18 | steps:
19 | - uses: actions/checkout@v2
20 | - uses: actions/setup-python@v4
21 | with:
22 | python-version: 3.8
23 | - uses: s-weigand/setup-conda@v1
24 | with:
25 | activate-conda: true
26 | python-version: 3.8
27 | - run: conda --version
28 | - run: which python
29 | - name: Install main dependencies
30 | run: |
31 | python -m pip install torch==2.3.0 torchvision torchaudio -f https://download.pytorch.org/whl/cpu/torch_stable.html
32 | python -m pip install torch-sparse -f https://data.pyg.org/whl/torch-2.3.0+cpu.html
33 | python -m pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cpu.html
34 | python -m pip install torch-geometric
35 | python -m pip install sphinx sphinx_rtd_theme
36 | - name: Install main package
37 | run: |
38 | python -m pip install -e .[test]
39 | - name: Run test-suite
40 | run: |
41 | python -m pytest
42 | - name: Generate coverage report
43 | if: success()
44 | run: |
45 | pip install coverage
46 | coverage run -m pytest
47 | coverage xml
48 | - name: Upload coverage report to codecov
49 | uses: codecov/codecov-action@v1
50 | if: success()
51 | with:
52 | file: coverage.xml
53 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # vs code
132 | .vscode/
133 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | I'm really glad you're reading this, because we need volunteer developers to help this project come to fruition.
4 |
5 | Here are some important things to check when you contribute:
6 |
7 | * Please make sure that you write tests.
8 | * Update the documentation.
9 | * Add the new model to the readme.
10 | * If your contribution is a paper please update the resource documentation file.
11 |
12 | ## Testing
13 |
14 |
15 | PyTorch Geometric Temporal's testing is located under `test/`.
16 | Run the entire test suite with
17 |
18 | ```
19 | python setup.py test
20 | ```
21 |
22 | ## Submitting changes
23 |
24 | Please send a [GitHub Pull Request to PyTorch Geometric Temporal](https://github.com/benedekrozemberczki/pytorch_geometric_temporal/pull/new/master) with a clear list of what you've done (read more about [pull requests](http://help.github.com/pull-requests/)). Please follow our coding conventions (below) and make sure all of your commits are atomic (one feature per commit).
25 |
26 | Always write a clear log message for your commits. One-line messages are fine for small changes, but bigger changes should look like this:
27 |
28 | $ git commit -m "A brief summary of the commit
29 | >
30 | > A paragraph describing what changed and its impact."
31 |
32 | ## Coding conventions
33 |
34 | Start reading our code and you'll get the hang of it. We optimize for readability:
35 |
36 | * We write tests for the data loaders, iterators and layers.
37 | * We use the type hinting feature of Python.
38 | * We avoid the uses of public methods and vaiarbles in the classes.
39 | * Hyperparameters belong to the constructors.
40 | * Auxiliiary layer instances should have long names.
41 | * Make linear algebra operations line-by-line.
42 |
43 | Thanks,
44 | Benedek
45 |
46 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Benedek Rozemberczki
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | SPHINXBUILD = sphinx-build
2 | SPHINXPROJ = pytorch_geometric_temporal
3 | SOURCEDIR = source
4 | BUILDDIR = build
5 |
6 | .PHONY: help Makefile
7 |
8 | %: Makefile
9 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)"
10 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Redirect
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/docs/mapping.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ast
3 | from collections import defaultdict
4 |
5 | def find_defined_classes_in_file(filepath):
6 | with open(filepath, 'r', encoding='utf-8') as file:
7 | try:
8 | tree = ast.parse(file.read(), filename=filepath)
9 | except SyntaxError:
10 | return []
11 |
12 | return [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)]
13 |
14 | def map_defined_classes(directories):
15 | file_to_classes = defaultdict(list)
16 | all_classes = set()
17 | for directory in directories:
18 | for root, _, files in os.walk(directory):
19 | for file in files:
20 | if file.endswith('.py'):
21 | full_path = os.path.join(root, file)
22 | class_names = find_defined_classes_in_file(full_path)
23 | if class_names:
24 | file_to_classes[full_path[3:].replace("/",".")].extend(class_names)
25 | all_classes.update(class_names)
26 |
27 | return all_classes, file_to_classes
28 |
29 | if __name__ == "__main__":
30 | # Replace these with the paths you want to analyze
31 | directories_to_scan = [
32 | "../torch_geometric_temporal/nn/recurrent",
33 | "../torch_geometric_temporal/nn/attention",
34 | "../torch_geometric_temporal/nn/hetero"
35 | ]
36 |
37 | all_classes, mapping = map_defined_classes(directories_to_scan)
38 |
39 |
40 | order = [
41 |
42 | "torch_geometric_temporal.nn.recurrent.gconv_gru",
43 | "torch_geometric_temporal.nn.recurrent.gconv_lstm",
44 | "torch_geometric_temporal.nn.recurrent.gc_lstm",
45 | "torch_geometric_temporal.nn.recurrent.lrgcn",
46 | "torch_geometric_temporal.nn.recurrent.dygrae",
47 | "torch_geometric_temporal.nn.recurrent.evolvegcnh",
48 | "torch_geometric_temporal.nn.recurrent.evolvegcno",
49 | "torch_geometric_temporal.nn.recurrent.temporalgcn",
50 | "torch_geometric_temporal.nn.recurrent.attentiontemporalgcn",
51 | "torch_geometric_temporal.nn.recurrent.mpnn_lstm",
52 | "torch_geometric_temporal.nn.recurrent.dcrnn",
53 | "torch_geometric_temporal.nn.recurrent.agcrn",
54 |
55 | "torch_geometric_temporal.nn.attention.stgcn",
56 | "torch_geometric_temporal.nn.attention.astgcn",
57 | "torch_geometric_temporal.nn.attention.mstgcn",
58 | "torch_geometric_temporal.nn.attention.gman",
59 | "torch_geometric_temporal.nn.attention.mtgnn",
60 | "torch_geometric_temporal.nn.attention.tsagcn",
61 | "torch_geometric_temporal.nn.attention.dnntsp",
62 |
63 | "torch_geometric_temporal.nn.hetero.heterogclstm"
64 |
65 | ]
66 | model = [
67 | "GConvGRU",
68 | "GConvLSTM",
69 | "GCLSTM",
70 | "LRGCN",
71 | "DyGrEncoder",
72 | "EvolveGCNH",
73 | "EvolveGCNO",
74 | "GCNConv_Fixed_W",
75 | "TGCN",
76 | "TGCN2",
77 | "A3TGCN",
78 | "A3TGCN2",
79 | "MPNNLSTM",
80 | "DCRNN",
81 | "BatchedDCRNN",
82 | "AGCRN",
83 | "STConv",
84 | "ASTGCN",
85 | "MSTGCN",
86 | "GMAN",
87 | "SpatioTemporalAttention",
88 | "GraphConstructor",
89 | "MTGNN",
90 | "AAGCN",
91 | "DNNTSP"
92 | ]
93 | aux = [
94 | "TemporalConv",
95 | "DConv",
96 | "BatchedDConv",
97 | "ChebConvAttention",
98 | "AVWGCN",
99 | "UnitGCN",
100 | "UnitTCN"
101 | ]
102 |
103 | het = [
104 | "HeteroGCLSTM"
105 | ]
106 | # print(mapping.keys())
107 | target = {}
108 | for file, classes in mapping.items():
109 |
110 |
111 | line = ""
112 |
113 | for c in all_classes:
114 | if c not in classes:
115 | line += f"{c}, "
116 |
117 | line = line[:-2]
118 | target[file[:-3]] = line
119 |
120 |
121 | for key in order:
122 | print(f".. autoapimodule:: {key}")
123 | print("\t:members:")
124 | print(f"\t:exclude-members: {target[key]}, LayerNormalization, AggregateTemporalNodeFeatures, GlobalGatedUpdater, MaskedSelfAttention, WeightedGCNBlock, LayerNormalization, K, bias, in_channels, out_channels, normalization, num_bases, num_relations, conv_aggr, conv_num_layers, conv_out_channels, lstm_num_layers, lstm_out_channels, add_self_loops, cached, improved, initial_weight, normalize, num_of_nodes, reinitialize_weight, reset_parameters, weight, batch_size, periods, dropout, hidden_size, num_nodes, window, number_of_nodes, bias_pool, weights_pool, hidden_channels, A, attention, edge_index, gcn1, graph, relu, tcn1, kernel_size, conv_1, conv_2, conv_3, nb_time_filter, adaptive, bn, conv_d, in_c, inter_c, num_jpts, num_subset, out_c, sigmoid, soft, tan, conv, embedding_dimensions, Wq, global_gated_updater, item_embedding, item_embedding_dim, items_total, masked_self_attention, stacked_gcn, in_channels_dict, meta")
125 | print()
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==7.0
2 | sphinx-rtd-theme
3 | sphinx-autoapi
4 | Jinja2>=3.0
--------------------------------------------------------------------------------
/docs/source/_figures/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | for filename in *.tex; do
4 | basename=$(basename $filename .tex)
5 | pdflatex "$basename.tex"
6 | pdf2svg "$basename.pdf" "$basename.svg"
7 | done
8 |
--------------------------------------------------------------------------------
/docs/source/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* Use white for logo background. */
2 | .wy-side-nav-search {
3 | background-color: #fff;
4 | }
5 |
6 | .wy-side-nav-search > div.version {
7 | color: #000;
8 | }
9 |
10 | /* Justify the text. */
11 |
12 | .section #basic-2-flip-flop-synchronizer{
13 | text-align:justify;
14 | }
15 |
--------------------------------------------------------------------------------
/docs/source/_static/img/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/9a44e623c484c613a850e16474e03ae4242aaab8/docs/source/_static/img/logo.jpg
--------------------------------------------------------------------------------
/docs/source/_static/img/text_logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/9a44e623c484c613a850e16474e03ae4242aaab8/docs/source/_static/img/text_logo.jpg
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import sphinx_rtd_theme
3 | import doctest
4 | import inspect
5 |
6 |
7 |
8 |
9 |
10 | extensions = [
11 | 'autoapi.extension',
12 | 'sphinx.ext.doctest',
13 | 'sphinx.ext.intersphinx',
14 | 'sphinx.ext.mathjax',
15 | 'sphinx.ext.napoleon',
16 | 'sphinx.ext.viewcode',
17 | 'sphinx.ext.githubpages',
18 | 'sphinx_rtd_theme',
19 | ]
20 |
21 | source_suffix = '.rst'
22 | master_doc = 'index'
23 |
24 | autoapi_python_use_implicit_namespaces = False
25 |
26 | author = 'Benedek Rozemberczki'
27 | project = 'PyTorch Geometric Temporal'
28 | copyright = '{}, {}'.format(datetime.datetime.now().year, author)
29 |
30 | html_theme = 'sphinx_rtd_theme'
31 | autoapi_add_toctree_entry = True
32 |
33 | doctest_default_flags = doctest.NORMALIZE_WHITESPACE
34 | intersphinx_mapping = {'python': ('https://docs.python.org/', None)}
35 |
36 | html_theme_options = {
37 | 'collapse_navigation': False,
38 | 'display_version': True,
39 | 'logo_only': True,
40 | 'navigation_depth': 2,
41 | }
42 |
43 |
44 | html_logo = '_static/img/text_logo.jpg'
45 | html_static_path = ['_static']
46 |
47 | add_module_names = False
48 | autoapi_generate_api_docs = False
49 |
50 | # --- AutoAPI config ---
51 | autoapi_type = 'python'
52 | autoapi_dirs = ['../../torch_geometric_temporal/']
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/benedekrozemberczki/pytorch_geometric_temporal
2 |
3 | PyTorch Geometric Temporal Documentation
4 | ========================================
5 |
6 | PyTorch Geometric Temporal is a temporal graph neural network extension library for `PyTorch Geometric `_. It builds on open-source deep-learning and graph processing libraries. *PyTorch Geometric Temporal* consists of state-of-the-art deep learning and parametric learning methods to process spatio-temporal signals. It is the first open-source library for temporal deep learning on geometric structures and provides constant time difference graph neural networks on dynamic and static graphs. We make this happen with the use of discrete time graph snapshots. Implemented methods cover a wide range of data mining (`WWW `_, `KDD `_), artificial intelligence and machine learning (`AAAI `_, `ICONIP `_, `ICLR `_) conferences, workshops, and pieces from prominent journals.
7 |
8 |
9 |
10 | PyTorch Geometric Temporal includes support for index-batching - a new batching technique that improves spatiotemporal memory efficiency without any impact on accuracy. Additionally, PyTorch Geometric Temporal supports memory-efficient distributed data parallel training using Dask-DDP in combination with index-batching.
11 |
12 | .. The package interfaces well with `Pytorch Lightning `_ which allows training on CPUs, single and multiple GPUs out-of-the-box. Take a look at this introductory example of using PyTorch Geometric Temporal with Pytorch Lighning.
13 |
14 | .. code-block:: latex
15 |
16 | >@inproceedings{rozemberczki2021pytorch,
17 | author = {Benedek Rozemberczki and Paul Scherer and Yixuan He and George Panagopoulos and Alexander Riedel and Maria Astefanoaei and Oliver Kiss and Ferenc Beres and and Guzman Lopez and Nicolas Collignon and Rik Sarkar},
18 | title = {{PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models}},
19 | year = {2021},
20 | booktitle={Proceedings of the 30th ACM International Conference on Information and Knowledge Management},
21 | pages = {4564–4573},
22 | }
23 |
24 | .. toctree::
25 | :glob:
26 | :maxdepth: 2
27 | :caption: Notes
28 |
29 | notes/installation
30 | notes/introduction
31 | notes/resources
32 |
33 | .. toctree::
34 | :glob:
35 | :maxdepth: 2
36 | :caption: Package Reference
37 |
38 | modules/root
39 | modules/signal
40 | modules/dataset
41 |
--------------------------------------------------------------------------------
/docs/source/modules/dataset.rst:
--------------------------------------------------------------------------------
1 | PyTorch Geometric Temporal Dataset
2 | ========================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | Datasets
8 | -----------------------
9 |
10 | .. autoapimodule:: torch_geometric_temporal.dataset.chickenpox
11 | :members:
12 | :undoc-members:
13 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
14 |
15 | .. autoapimodule:: torch_geometric_temporal.dataset.pedalme
16 | :members:
17 | :undoc-members:
18 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
19 |
20 | .. autoapimodule:: torch_geometric_temporal.dataset.wikimath
21 | :members:
22 | :undoc-members:
23 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
24 |
25 | .. autoapimodule:: torch_geometric_temporal.dataset.windmilllarge
26 | :members:
27 | :undoc-members:
28 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
29 |
30 | .. autoapimodule:: torch_geometric_temporal.dataset.windmillmedium
31 | :members:
32 | :undoc-members:
33 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
34 |
35 | .. autoapimodule:: torch_geometric_temporal.dataset.windmillsmall
36 | :members:
37 | :undoc-members:
38 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
39 |
40 | .. autoapimodule:: torch_geometric_temporal.dataset.metr_la
41 | :members:
42 | :undoc-members:
43 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
44 |
45 | .. autoapimodule:: torch_geometric_temporal.dataset.pems_bay
46 | :members:
47 | :undoc-members:
48 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
49 |
50 | .. autoapimodule:: torch_geometric_temporal.dataset.pemsAllLA
51 | :members:
52 | :undoc-members:
53 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
54 |
55 | .. autoapimodule:: torch_geometric_temporal.dataset.pems
56 | :members:
57 | :undoc-members:
58 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
59 |
60 | .. autoapimodule:: torch_geometric_temporal.dataset.encovid
61 | :members:
62 | :undoc-members:
63 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
64 |
65 | .. autoapimodule:: torch_geometric_temporal.dataset.montevideo_bus
66 | :members:
67 | :undoc-members:
68 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
69 |
70 | .. autoapimodule:: torch_geometric_temporal.dataset.twitter_tennis
71 | :members:
72 | :undoc-members:
73 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
74 | :exclude-members: transform_degree, transform_transitivity, encode_features, onehot_encoding
75 |
76 | .. autoapimodule:: torch_geometric_temporal.dataset.mtm
77 | :members:
78 | :undoc-members:
79 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal
80 |
--------------------------------------------------------------------------------
/docs/source/notes/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | The installation of PyTorch Geometric Temporal requires the presence of certain prerequisites. These are described in great detail in the installation description of PyTorch Geometric. Please follow the instructions laid out `here `_. You might also take a look at the `readme file `_ of the PyTorch Geometric Temporal repository.
5 |
6 | Once the required versions of PyTorch and PyTorch Geometric are installed, simply run:
7 |
8 | .. code-block:: none
9 |
10 | $ pip install torch-geometric-temporal
11 |
12 | **Updating the Library**
13 |
14 | The package itself can be installed via pip:
15 |
16 | .. code-block:: none
17 |
18 | $ pip install torch-geometric-temporal
19 |
20 | Upgrade your outdated PyTorch Geometric Temporal version by using:
21 |
22 | .. code-block:: none
23 |
24 | $ pip install torch-geometric-temporal --upgrade
25 |
26 |
27 | To check your current package version just simply run:
28 |
29 | .. code-block:: none
30 |
31 | $ pip freeze | grep torch-geometric-temporal
32 |
33 | **Index-Batching**
34 |
35 | The package was recently updated to include index-batching, a new method of batching that improves
36 | memory efficiency without any impact on accuracy. To install the needed packages for index-batching,
37 | run the following command:
38 |
39 | .. code-block:: none
40 |
41 | $ pip install torch-geometric-temporal[index]
42 |
43 | **Distributed Data Parallel**
44 |
45 | Alongside index-batching, PGT was recently updated with features to support distributed data parallel training.
46 | To install the needed packages, run the following:
47 |
48 | .. code-block:: none
49 |
50 | $ pip install torch-geometric-temporal[ddp]
51 |
52 |
--------------------------------------------------------------------------------
/docs/source/notes/resources.rst:
--------------------------------------------------------------------------------
1 | External Resources - Architectures
2 | ==================================
3 |
4 | * Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu: **Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting** `Paper `_, `TensorFlow Code `_, `PyTorch Code `_
5 |
6 | * Youngjoo Seo, Michaël Defferrard, Xavier Bresson, Pierre Vandergheynst: **Structured Sequence Modeling With Graph Convolutional Recurrent Networks** `Paper `_, `Code `_, `TensorFlow Code `_
7 |
8 | * Jinyin Chen, Xuanheng Xu, Yangyang Wu, Haibin Zheng: **GC-LSTM: Graph Convolution Embedded LSTM for Dynamic Link Prediction** `Paper `_
9 |
10 | * Jia Li, Zhichao Han, Hong Cheng, Jiao Su, Pengyun Wang, Jianfeng Zhang, Lujia Pan: **Predicting Path Failure In Time-Evolving Graphs** `Paper `_, `Code `_
11 |
12 | * Aynaz Taheri, Tanya Berger-Wolf: **Predictive Temporal Embedding of Dynamic Graphs** `Paper `_
13 |
14 | * Aynaz Taheri, Kevin Gimpel, Tanya Berger-Wolf: **Learning to Represent the Evolution of Dynamic Graphs with Recurrent Models** `Paper `_, `Code `_
15 |
16 | * Aldo Pareja, Giacomo Domeniconi, Jie Chen, Tengfei Ma, Toyotaro Suzumura, Hiroki Kanezashi, Tim Kaler, Tao B. Schardl, Charles E. Leiserson: **EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs** `Paper `_, `Code `_
17 |
18 | * Bing Yu, Haoteng Yin, Zhanxing Zhu: **Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting** `Paper `_, `Code `_
19 |
20 | * Ling Zhao, Yujiao Song, Chao Zhang, Yu Liu, Pu Wang, Tao Lin, Min Deng, Haifeng Li: **T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction** `Paper `_, `Code `_
21 |
22 | * Jiawei Zhu, Yujiao Song, Ling Zhao, Haifeng Li: **A3T-GCN: Attention Temporal Graph Convolutional Network for Traffic Forecasting** `Paper `_, `Code `_
23 |
24 | * Shengnan Guo, Youfang Lin, Ning Feng, Chao Song, Huaiyu Wan: **Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting** `Paper `_, `Code `_
25 |
26 | * Chuanpan Zheng, Xiaoliang Fan, Cheng Wang, Jianzhong Qi: **GMAN: A Graph Multi-Attention Network for Traffic Prediction** `Paper `_, `Code `_
27 |
28 | * Zonghan Wu, Shirui Pan, Guodong Long, Jing Jiang, Xiaojun Chang, Chengqi Zhang: **Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks** `Paper `_, `Code `_
29 |
30 | * Lei Bai, Lina Yao, Can Li, Xianzhi Wang, Can Wang: **Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting** `Paper `_, `Code `_
31 |
32 | * George Panagopoulos, Giannis Nikolentzos, Michalis Vazirgiannis: **Transfer Graph Neural Networks for Pandemic Forecasting** `Paper `_, `Code `_
33 |
34 | * Lei Shi, Yifan Zhang, Jian Cheng, Hanqing Lu: **Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based Action Recognition** `Paper `_, `Code `_
35 |
36 | * Le Yu, Leilei Sun, Bowen Du, Chuanren Liu, Hui Xiong, Weifeng Lv: **Predicting Temporal Sets with Deep Neural Networks** `Paper `_, `Code `_
37 |
38 | External Resources - Datasets
39 | =============================
40 |
41 | * Benedek Rozemberczki, Paul Scherer, Oliver Kiss, Rik Sarkar, Tamas Ferenci: **Chickenpox Cases in Hungary: a Benchmark Dataset for Spatiotemporal Signal Processing with Graph Neural Networks** `Paper `_, `Dataset `_
42 |
43 | * Ferenc Béres, Róbert Pálovics, Anna Oláh, András A. Benczúr: **Temporal Walk Based Centrality Metric for Graph Streams** `Paper `_, `Dataset `_
44 |
45 | * Ferenc Béres, Domokos M. Kelen, Róbert Pálovics, András A. Benczúr : **Node Embeddings in Dynamic Graphs** `Paper `_, `Dataset `_
46 |
47 | * Mallick, Tanwi and Balaprakash, Prasanna and Rask, Eric and Macfarlane, Jane: **Graph-partitioning-based diffusion convolutional recurrent neural network for large-scale traffic forecasting** `Paper `_, `Dataset `_
48 |
49 |
--------------------------------------------------------------------------------
/examples/indexBatching/A3TGCN/metr_la_main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import csv
4 | import torch
5 | import torch.nn.functional as F
6 | from torch_geometric.nn import GCNConv
7 | from torch_geometric_temporal.nn.recurrent import A3TGCN2
8 | from torch_geometric_temporal.dataset import METRLADatasetLoader
9 | import argparse
10 |
11 |
12 | def parse_arguments():
13 | parser = argparse.ArgumentParser(description="Demo of index batching with PemsBay dataset")
14 |
15 | parser.add_argument(
16 | "-e", "--epochs", type=int, default=100, help="The desired number of training epochs"
17 | )
18 | parser.add_argument(
19 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size"
20 | )
21 | parser.add_argument(
22 | "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU"
23 | )
24 | parser.add_argument(
25 | "-d", "--debug", type=str, default="False", help="Print values for debugging"
26 | )
27 | return parser.parse_args()
28 |
29 | # Making the model
30 | class TemporalGNN(torch.nn.Module):
31 | def __init__(self, node_features, periods, batch_size):
32 | super(TemporalGNN, self).__init__()
33 | # Attention Temporal Graph Convolutional Cell
34 | self.tgnn = A3TGCN2(in_channels=node_features, out_channels=32, periods=periods,batch_size=batch_size) # node_features=2, periods=12
35 | # Equals single-shot prediction
36 | self.linear = torch.nn.Linear(32, periods)
37 |
38 | def forward(self, x, edge_index):
39 | """
40 | x = Node features for T time steps
41 | edge_index = Graph edge indices
42 | """
43 | h = self.tgnn(x, edge_index) # x [b, 207, 2, 12] returns h [b, 207, 12]
44 | h = F.relu(h)
45 | h = self.linear(h)
46 | return h
47 |
48 |
49 |
50 | def train(train_dataloader, val_dataloader, batch_size, epochs, edges, DEVICE, allGPU=False, debug=False):
51 |
52 | # Create model and optimizers
53 | model = TemporalGNN(node_features=2, periods=12, batch_size=batch_size).to(DEVICE)
54 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
55 | loss_fn = torch.nn.MSELoss()
56 |
57 | stats = []
58 | t_mse = []
59 | v_mse = []
60 |
61 |
62 | edges = edges.to(DEVICE)
63 | for epoch in range(epochs):
64 | step = 0
65 | loss_list = []
66 | t1 = time.time()
67 | i = 1
68 | total = len(train_dataloader)
69 | for batch in train_dataloader:
70 | X_batch, y_batch = batch
71 |
72 | # Need to permute based on expected input shape for ATGCN
73 | if allGPU:
74 | X_batch = X_batch.permute(0, 2, 3, 1)
75 | y_batch = y_batch[...,0].permute(0, 2, 1)
76 | else:
77 | X_batch = X_batch.permute(0, 2, 3, 1).to(DEVICE)
78 | y_batch = y_batch[...,0].permute(0, 2, 1).to(DEVICE)
79 |
80 |
81 |
82 | y_hat = model(X_batch, edges) # Get model predictions
83 | loss = loss_fn(y_hat, y_batch) # Mean squared error #loss = torch.mean((y_hat-labels)**2) sqrt to change it to rmse
84 | loss.backward()
85 | optimizer.step()
86 | optimizer.zero_grad()
87 | step= step+ 1
88 | loss_list.append(loss.item())
89 |
90 | if debug:
91 | print(f"Train Batch: {i}/{total}", end="\r")
92 | i+=1
93 |
94 |
95 | model.eval()
96 | step = 0
97 | # Store for analysis
98 | total_loss = []
99 | i = 1
100 | total = len(val_dataloader)
101 | if debug:
102 | print(" ", end="\r")
103 | with torch.no_grad():
104 | for batch in val_dataloader:
105 | X_batch, y_batch = batch
106 |
107 |
108 | # Need to permute based on expected input shape for ATGCN
109 | if allGPU:
110 | X_batch = X_batch.permute(0, 2, 3, 1)
111 | y_batch = y_batch[...,0].permute(0, 2, 1)
112 | else:
113 | X_batch = X_batch.permute(0, 2, 3, 1).to(DEVICE)
114 | y_batch = y_batch[...,0].permute(0, 2, 1).to(DEVICE)
115 |
116 | # Get model predictions
117 | y_hat = model(X_batch, edges)
118 | # Mean squared error
119 | loss = loss_fn(y_hat, y_batch)
120 | total_loss.append(loss.item())
121 |
122 | if debug:
123 | print(f"Val Batch: {i}/{total}", end="\r")
124 | i += 1
125 |
126 |
127 | t2 = time.time()
128 | print("Epoch {} time: {:.4f} train RMSE: {:.4f} Test MSE: {:.4f}".format(epoch,t2 - t1, sum(loss_list)/len(loss_list), sum(total_loss)/len(total_loss)))
129 | stats.append([epoch, t2-t1, sum(loss_list)/len(loss_list), sum(total_loss)/len(total_loss)])
130 | t_mse.append(sum(loss_list)/len(loss_list))
131 | v_mse.append(sum(total_loss)/len(total_loss))
132 | return min(t_mse), min(v_mse)
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 | def main():
142 | args = parse_arguments()
143 | allGPU = args.gpu.lower() in ["true", "y", "t", "yes"]
144 | debug = args.debug.lower() in ["true", "y", "t", "yes"]
145 | batch_size = args.batch_size
146 | epochs = args.epochs
147 |
148 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
149 | shuffle= True
150 |
151 |
152 | start = time.time()
153 | p1 = time.time()
154 | indexLoader = METRLADatasetLoader(index=True)
155 | if allGPU:
156 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = indexLoader.get_index_dataset(batch_size=batch_size, shuffle=shuffle, allGPU=0)
157 | else:
158 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = indexLoader.get_index_dataset(batch_size=batch_size, shuffle=shuffle)
159 | p2 = time.time()
160 | t_mse, v_mse = train(train_dataloader, val_dataloader, batch_size, epochs, edges, device, debug=debug)
161 | end = time.time()
162 |
163 | print(f"Runtime: {round(end - start,2)}; T-MSE: {round(t_mse, 3)}; V-MSE: {round(v_mse, 3)}")
164 |
165 | if __name__ == "__main__":
166 | main()
--------------------------------------------------------------------------------
/examples/indexBatching/DCRNN/chicken_pox_main.py:
--------------------------------------------------------------------------------
1 |
2 | from torch_geometric_temporal.nn.recurrent import BatchedDCRNN as DCRNN
3 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
4 | import torch.optim as optim
5 |
6 | import argparse
7 | import csv
8 | import os
9 | import time
10 | from utils import *
11 |
12 |
13 | def parse_arguments():
14 | """Parse command-line arguments."""
15 | parser = argparse.ArgumentParser(description="Demo of index batching with chickenpox dataset")
16 |
17 | parser.add_argument(
18 | "-e", "--epochs", type=int, default=100, help="The desired number of training epochs"
19 | )
20 | parser.add_argument(
21 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size"
22 | )
23 | parser.add_argument(
24 | "-m", "--mode", type=str, default="base", help="Which version to run"
25 | )
26 | parser.add_argument(
27 | "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU"
28 | )
29 | parser.add_argument(
30 | "-d", "--debug", type=str, default="False", help="Print values for debugging"
31 | )
32 |
33 | return parser.parse_args()
34 |
35 | def train(train_dataloader, val_dataloader, edge_index,edge_weight, epochs, seq_length, num_nodes, num_features,
36 | allGPU=False, debug=False):
37 |
38 |
39 | # Move to GPU
40 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
41 | edge_index = edge_index.to(device)
42 | edge_weight = edge_weight.to(device)
43 |
44 |
45 | # Initialize model
46 | model = DCRNN(num_features, num_features, K=3).to(device)
47 |
48 | # Define optimizer and loss function
49 | optimizer = optim.Adam(model.parameters(), lr=0.001)
50 |
51 | # Training loop
52 | stats = []
53 | min_t = 9999
54 | min_v = 9999
55 | for epoch in range(epochs):
56 | # Training phase
57 | model.train()
58 | train_loss = 0.0
59 | i = 1
60 | total = len(train_dataloader)
61 | t1 = time.time()
62 | for batch in train_dataloader:
63 | X_batch, y_batch = batch
64 |
65 | if allGPU == False:
66 |
67 | X_batch = X_batch.to(device).float()
68 | y_batch = y_batch.to(device).float()
69 |
70 | # Forward pass
71 | outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels)
72 |
73 | # Calculate loss (use only the first output channel, assuming it's the target)
74 | loss = masked_mae_loss(outputs,y_batch )
75 |
76 | # Backward pass
77 | optimizer.zero_grad()
78 | loss.backward()
79 | optimizer.step()
80 |
81 | train_loss += loss.item()
82 | if debug:
83 | print(f"Train Batch: {i}/{total}", end="\r")
84 | i+=1
85 |
86 |
87 | train_loss /= len(train_dataloader)
88 |
89 | # Validation phase
90 | model.eval()
91 | val_loss = 0.0
92 | i = 0
93 | if debug:
94 | print(" ", end="\r")
95 | total = len(val_dataloader)
96 | with torch.no_grad():
97 | for batch in val_dataloader:
98 | X_batch, y_batch = batch
99 |
100 | if allGPU == False:
101 | X_batch = X_batch.to(device).float()
102 | y_batch = y_batch.to(device).float()
103 |
104 | # Forward pass
105 | outputs = model(X_batch, edge_index, edge_weight)
106 |
107 | # Calculate loss
108 | loss = masked_mae_loss(outputs,y_batch)
109 | val_loss += loss.item()
110 |
111 | if debug:
112 | print(f"Val Batch: {i}/{total}", end="\r")
113 | i += 1
114 |
115 | val_loss /= len(val_dataloader)
116 | t2 = time.time()
117 | # Print epoch metrics
118 | print(f"Epoch {epoch + 1}/{epochs}, Runtime: {t2 - t1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}", flush=True)
119 | stats.append([epoch+1, t2 - t1, train_loss, val_loss])
120 |
121 | min_t = min(min_t, train_loss)
122 | min_v = min(min_v, val_loss)
123 |
124 | return min_t, min_v
125 |
126 | def main():
127 |
128 | args = parse_arguments()
129 | allGPU = args.gpu.lower() in ["true", "y", "t"]
130 | debug = args.debug.lower() in ["true", "y", "t"]
131 | batch_size = args.batch_size
132 | epochs = args.epochs
133 |
134 | t1 = time.time()
135 | loader = ChickenpoxDatasetLoader(index=True)
136 | if allGPU == True:
137 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights = loader.get_index_dataset(allGPU=0,batch_size=batch_size)
138 | else:
139 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights = loader.get_index_dataset(batch_size=batch_size)
140 |
141 | t_min, v_min = train(train_dataloader, val_dataloader, edges, edge_weights, epochs, 4,20,1, allGPU=allGPU, debug=debug)
142 | t2 = time.time()
143 | print(f"Runtime: {round(t2 - t1,2)}; Best Train MSE: {t_min}; Best Validation MSE: {v_min}")
144 |
145 | if __name__ == "__main__":
146 | main()
--------------------------------------------------------------------------------
/examples/indexBatching/DCRNN/pems_allLA_main.py:
--------------------------------------------------------------------------------
1 | from torch_geometric_temporal.nn.recurrent import BatchedDCRNN as DCRNN
2 | from torch_geometric_temporal.dataset import PemsAllLADatasetLoader
3 | import torch.optim as optim
4 |
5 | import argparse
6 | import csv
7 | import os
8 | import time
9 | from utils import *
10 |
11 |
12 | def parse_arguments():
13 | parser = argparse.ArgumentParser(description="Demo of index batching with PemsAllLA dataset")
14 |
15 | parser.add_argument(
16 | "-e", "--epochs", type=int, default=30, help="The desired number of training epochs"
17 | )
18 | parser.add_argument(
19 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size"
20 | )
21 | parser.add_argument(
22 | "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU"
23 | )
24 | parser.add_argument(
25 | "-d", "--debug", type=str, default="False", help="Print values for debugging"
26 | )
27 | return parser.parse_args()
28 |
29 |
30 |
31 |
32 | def train(train_dataloader, val_dataloader, mean, std, edges, edge_weights, epochs, seq_length, num_nodes, num_features, allGPU=False, debug=False):
33 |
34 | # Move to GPU
35 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36 | edge_index = edges.to(device)
37 | edge_weight = edge_weights.to(device)
38 |
39 | if allGPU == False:
40 | mean = mean.to(device)
41 | std = std.to(device)
42 |
43 | # Initialize model
44 | model = DCRNN(num_features, num_features, K=3).to(device)
45 |
46 | # Define optimizer and loss function
47 | optimizer = optim.Adam(model.parameters(), lr=0.001)
48 |
49 | # Training loop
50 | stats = []
51 | min_t = 9999
52 | min_v = 9999
53 | for epoch in range(epochs):
54 | # Training phase
55 | model.train()
56 | train_loss = 0.0
57 | i = 1
58 | total = len(train_dataloader)
59 | t1 = time.time()
60 | for batch in train_dataloader:
61 | X_batch, y_batch = batch
62 |
63 | if allGPU == False:
64 | # print("casting")
65 | X_batch = X_batch.to(device).float()
66 | y_batch = y_batch.to(device).float()
67 |
68 | # Forward pass
69 | outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels)
70 |
71 | # Calculate loss (use only the first output channel, assuming it's the target)
72 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
73 |
74 | # Backward pass
75 | optimizer.zero_grad()
76 | loss.backward()
77 | optimizer.step()
78 |
79 | train_loss += loss.item()
80 | if debug:
81 | print(f"Train Batch: {i}/{total}", end="\r")
82 | i+=1
83 | # break
84 |
85 |
86 | train_loss /= len(train_dataloader)
87 |
88 | # Validation phase
89 | model.eval()
90 | val_loss = 0.0
91 | i = 0
92 | if debug:
93 | print(" ", end="\r")
94 | total = len(val_dataloader)
95 | with torch.no_grad():
96 | for batch in val_dataloader:
97 | X_batch, y_batch = batch
98 |
99 | if allGPU == False:
100 | X_batch = X_batch.to(device).float()
101 | y_batch = y_batch.to(device).float()
102 |
103 | # Forward pass
104 | outputs = model(X_batch, edge_index, edge_weight)
105 |
106 | # Calculate loss
107 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
108 | val_loss += loss.item()
109 | if debug:
110 | print(f"Val Batch: {i}/{total}", end="\r")
111 | i += 1
112 |
113 |
114 | val_loss /= len(val_dataloader)
115 | t2 = time.time()
116 | # Print epoch metrics
117 | print(f"Epoch {epoch + 1}/{epochs}, Runtime: {t2 - t1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}", flush=True)
118 | stats.append([epoch+1, t2 - t1, train_loss, val_loss])
119 |
120 | min_t = min(min_t, train_loss)
121 | min_v = min(min_v, val_loss)
122 |
123 | return min_t, min_v
124 |
125 | def main():
126 |
127 |
128 |
129 | args = parse_arguments()
130 | allGPU = args.gpu.lower() in ["true", "y", "t", "yes"]
131 | debug = args.debug.lower() in ["true", "y", "t", "yes"]
132 | batch_size = args.batch_size
133 | epochs = args.epochs
134 |
135 | t1 = time.time()
136 | loader = PemsAllLADatasetLoader(index=True)
137 | if allGPU:
138 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, means, stds = loader.get_index_dataset(allGPU=0, batch_size=batch_size)
139 | else:
140 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, means, stds = loader.get_index_dataset(batch_size=batch_size)
141 |
142 |
143 | t_min, v_min = train(train_dataloader, val_dataloader, means, stds, edges, edge_weights, epochs, 12,11160,2, allGPU=allGPU, debug=debug)
144 | t2 = time.time()
145 | print(f"Runtime: {round(t2 - t1,2)}; Best Train MSE: {t_min}; Best Validation MSE: {v_min}")
146 |
147 | if __name__ == "__main__":
148 | main()
--------------------------------------------------------------------------------
/examples/indexBatching/DCRNN/pems_bay_main.py:
--------------------------------------------------------------------------------
1 | from torch_geometric_temporal.nn.recurrent import BatchedDCRNN as DCRNN
2 | from torch_geometric_temporal.dataset import PemsBayDatasetLoader
3 | import torch.optim as optim
4 |
5 | import argparse
6 | import csv
7 | import os
8 | import time
9 | from utils import *
10 |
11 |
12 |
13 | def parse_arguments():
14 | parser = argparse.ArgumentParser(description="Demo of index batching with PemsBay dataset")
15 |
16 | parser.add_argument(
17 | "-e", "--epochs", type=int, default=100, help="The desired number of training epochs"
18 | )
19 | parser.add_argument(
20 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size"
21 | )
22 | parser.add_argument(
23 | "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU"
24 | )
25 | parser.add_argument(
26 | "-d", "--debug", type=str, default="False", help="Print values for debugging"
27 | )
28 | return parser.parse_args()
29 |
30 | def train(train_dataloader, val_dataloader, mean, std, edge_index, edge_weight, epochs, seq_length, num_nodes, num_features,
31 | allGPU=False, debug=False):
32 |
33 |
34 |
35 | # Move to GPU
36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37 | edge_index = edge_index.to(device)
38 | edge_weight = edge_weight.to(device)
39 |
40 | if allGPU == False:
41 | mean = mean.to(device)
42 | std = std.to(device)
43 |
44 | # Initialize model
45 | model = DCRNN(num_features, num_features, K=3).to(device)
46 |
47 | # Define optimizer and loss function
48 | optimizer = optim.Adam(model.parameters(), lr=0.001)
49 |
50 | # Training loop
51 | stats = []
52 | min_t = 9999
53 | min_v = 9999
54 | for epoch in range(epochs):
55 | # Training phase
56 | model.train()
57 | train_loss = 0.0
58 | i = 1
59 | total = len(train_dataloader)
60 | t1 = time.time()
61 | for batch in train_dataloader:
62 | X_batch, y_batch = batch
63 |
64 | if allGPU == False:
65 | X_batch = X_batch.to(device).float()
66 | y_batch = y_batch.to(device).float()
67 | # Forward pass
68 | outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels)
69 |
70 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
71 |
72 | # Backward pass
73 | optimizer.zero_grad()
74 | loss.backward()
75 | optimizer.step()
76 |
77 | train_loss += loss.item()
78 | if debug:
79 | print(f"Train Batch: {i}/{total}", end="\r")
80 | i+=1
81 |
82 |
83 | train_loss /= len(train_dataloader)
84 |
85 | # Validation phase
86 | model.eval()
87 | val_loss = 0.0
88 | i = 0
89 | if debug:
90 | print(" ", end="\r")
91 | total = len(val_dataloader)
92 | with torch.no_grad():
93 | for batch in val_dataloader:
94 | X_batch, y_batch = batch
95 |
96 | if allGPU == False:
97 | X_batch = X_batch.to(device).float()
98 | y_batch = y_batch.to(device).float()
99 |
100 | # Forward pass
101 | outputs = model(X_batch, edge_index, edge_weight)
102 |
103 | # Calculate loss
104 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
105 | val_loss += loss.item()
106 | if debug:
107 | print(f"Val Batch: {i}/{total}", end="\r")
108 | i += 1
109 |
110 | val_loss /= len(val_dataloader)
111 | t2 = time.time()
112 | # Print epoch metrics
113 | print(f"Epoch {epoch + 1}/{epochs}, Runtime: {t2 - t1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}", flush=True)
114 | stats.append([epoch+1, t2 - t1, train_loss, val_loss])
115 |
116 | min_t = min(min_t, train_loss)
117 | min_v = min(min_v, val_loss)
118 |
119 |
120 |
121 | return min_t, min_v
122 |
123 | def main():
124 |
125 |
126 |
127 | args = parse_arguments()
128 | allGPU = args.gpu.lower() in ["true", "y", "t", "yes"]
129 | debug = args.debug.lower() in ["true", "y", "t", "yes"]
130 | batch_size = args.batch_size
131 | epochs = args.epochs
132 | t1 = time.time()
133 | loader = PemsBayDatasetLoader(index=True)
134 |
135 | if allGPU == True:
136 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = loader.get_index_dataset(allGPU=0, batch_size=batch_size)
137 | else:
138 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = loader.get_index_dataset(batch_size=batch_size)
139 |
140 | t_min, v_min = train(train_dataloader, val_dataloader, mean, std, edges, edge_weights, epochs, 12,325,2, allGPU=allGPU,debug=debug)
141 | t2 = time.time()
142 | print(f"Runtime: {round(t2 - t1,2)}; Best Train MSE: {t_min}; Best Validation MSE: {v_min}")
143 |
144 |
145 | if __name__ == "__main__":
146 | main()
--------------------------------------------------------------------------------
/examples/indexBatching/DCRNN/pems_main.py:
--------------------------------------------------------------------------------
1 | from torch_geometric_temporal.nn.recurrent import BatchedDCRNN as DCRNN
2 | from torch_geometric_temporal.dataset import PemsDatasetLoader
3 | import torch.optim as optim
4 |
5 | import argparse
6 | import csv
7 | import os
8 | import time
9 | from utils import *
10 |
11 |
12 | def parse_arguments():
13 | parser = argparse.ArgumentParser(description="Demo of index batching with Pems dataset")
14 |
15 | parser.add_argument(
16 | "-e", "--epochs", type=int, default=30, help="The desired number of training epochs"
17 | )
18 | parser.add_argument(
19 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size"
20 | )
21 | parser.add_argument(
22 | "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU"
23 | )
24 | parser.add_argument(
25 | "-d", "--debug", type=str, default="False", help="Print values for debugging"
26 | )
27 | return parser.parse_args()
28 |
29 |
30 |
31 |
32 | def train(train_dataloader, val_dataloader, mean, std, edges, edge_weights, epochs, seq_length, num_nodes, num_features, allGPU=False, debug=False):
33 |
34 | # Move to GPU
35 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36 | edge_index = edges.to(device)
37 | edge_weight = edge_weights.to(device)
38 |
39 | if allGPU == False:
40 | mean = mean.to(device)
41 | std = std.to(device)
42 |
43 | # Initialize model
44 | model = DCRNN(num_features, num_features, K=3).to(device)
45 |
46 | # Define optimizer and loss function
47 | optimizer = optim.Adam(model.parameters(), lr=0.001)
48 |
49 | # Training loop
50 | stats = []
51 | min_t = 9999
52 | min_v = 9999
53 | for epoch in range(epochs):
54 | # Training phase
55 | model.train()
56 | train_loss = 0.0
57 | i = 1
58 | total = len(train_dataloader)
59 | t1 = time.time()
60 | for batch in train_dataloader:
61 | X_batch, y_batch = batch
62 |
63 | if allGPU == False:
64 | # print("casting")
65 | X_batch = X_batch.to(device).float()
66 | y_batch = y_batch.to(device).float()
67 |
68 | # Forward pass
69 | outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels)
70 |
71 | # Calculate loss (use only the first output channel, assuming it's the target)
72 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
73 |
74 | # Backward pass
75 | optimizer.zero_grad()
76 | loss.backward()
77 | optimizer.step()
78 |
79 | train_loss += loss.item()
80 | if debug:
81 | print(f"Train Batch: {i}/{total}", end="\r")
82 | i+=1
83 | # break
84 |
85 |
86 | train_loss /= len(train_dataloader)
87 |
88 | # Validation phase
89 | model.eval()
90 | val_loss = 0.0
91 | i = 0
92 | if debug:
93 | print(" ", end="\r")
94 | total = len(val_dataloader)
95 | with torch.no_grad():
96 | for batch in val_dataloader:
97 | X_batch, y_batch = batch
98 |
99 | if allGPU == False:
100 | X_batch = X_batch.to(device).float()
101 | y_batch = y_batch.to(device).float()
102 |
103 | # Forward pass
104 | outputs = model(X_batch, edge_index, edge_weight)
105 |
106 | # Calculate loss
107 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
108 | val_loss += loss.item()
109 | if debug:
110 | print(f"Val Batch: {i}/{total}", end="\r")
111 | i += 1
112 |
113 |
114 | val_loss /= len(val_dataloader)
115 | t2 = time.time()
116 | # Print epoch metrics
117 | print(f"Epoch {epoch + 1}/{epochs}, Runtime: {t2 - t1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}", flush=True)
118 | stats.append([epoch+1, t2 - t1, train_loss, val_loss])
119 |
120 | min_t = min(min_t, train_loss)
121 | min_v = min(min_v, val_loss)
122 |
123 | return min_t, min_v
124 |
125 | def main():
126 |
127 |
128 |
129 | args = parse_arguments()
130 | allGPU = args.gpu.lower() in ["true", "y", "t", "yes"]
131 | debug = args.debug.lower() in ["true", "y", "t", "yes"]
132 | batch_size = args.batch_size
133 | epochs = args.epochs
134 |
135 | t1 = time.time()
136 | loader = PemsDatasetLoader(index=True)
137 | if allGPU:
138 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, means, stds = loader.get_index_dataset(allGPU=0, batch_size=batch_size)
139 | else:
140 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, means, stds = loader.get_index_dataset(batch_size=batch_size)
141 |
142 |
143 | t_min, v_min = train(train_dataloader, val_dataloader, means, stds, edges, edge_weights, epochs, 12,11160,2, allGPU=allGPU, debug=debug)
144 | t2 = time.time()
145 | print(f"Runtime: {round(t2 - t1,2)}; Best Train MSE: {t_min}; Best Validation MSE: {v_min}")
146 |
147 | if __name__ == "__main__":
148 | main()
--------------------------------------------------------------------------------
/examples/indexBatching/DCRNN/submit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l select=5:system=polaris
3 | #PBS -l place=scatter
4 | #PBS -l filesystems=home:eagle
5 | #PBS -l walltime=00:30:00
6 | #PBS -q
7 | #PBS -A
8 | #PBS -o train.out
9 | #PBS -e train.err
10 |
11 | nodes=4
12 | gpus_per_node=4
13 |
14 | # should the training utilize GPU-index-batching
15 | allGPU=True
16 |
17 | # which dataset to train on; valid options include "pems-bay", 'pemsAllLA', and "pems"
18 | dataset="pems"
19 |
20 | # total workers
21 | total=$((gpus_per_node * nodes))
22 |
23 | readarray -t all_nodes < "$NODEFILE"
24 |
25 | # use first node for dask scheduler and client
26 | scheduler_node=${all_nodes[0]}
27 | monitor_node=${all_nodes[1]}
28 |
29 | # all nodes but first for workers
30 | tail -n +2 $NODEFILE > worker_nodefile.txt
31 |
32 | echo "Launching scheduler"
33 | mpiexec -n 1 --ppn 1 --cpu-bind none --hosts $scheduler_node dask scheduler --scheduler-file cluster.info &
34 | scheduler_pid=$!
35 |
36 | # wait for the scheduler to generate the cluster config file
37 | while ! [ -f cluster.info ]; do
38 | sleep 1
39 | echo .
40 | done
41 |
42 | echo "$total workers launching"
43 | mpiexec -n $total --ppn $gpus_per_node --cpu-bind none --hostfile worker_nodefile.txt dask worker --local-directory /local/scratch --scheduler-file cluster.info --nthreads 8 --memory-limit 512GB &
44 |
45 | # give workers a bit to launch
46 | sleep 5
47 |
48 | echo "Launching client"
49 | mpiexec -n 1 --ppn 1 --cpu-bind none --hosts $scheduler_node `which python3` pems_ddp.py --dask-cluster-file cluster.info -np $gpus_per_node -g $allGPU --dataset $dataset
50 |
--------------------------------------------------------------------------------
/examples/indexBatching/DCRNN/utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import torch
3 | import numpy as np
4 | import pandas as pd
5 | import csv
6 | import os
7 | import time
8 | import scipy.sparse as sp
9 |
10 | def masked_mae_loss(y_pred, y_true):
11 |
12 | mask = (y_true != 0).float()
13 | mask /= mask.mean()
14 | loss = torch.abs(y_pred - y_true)
15 | loss = loss * mask
16 | # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
17 | loss[loss != loss] = 0
18 | return loss.mean()
19 |
20 |
21 |
22 | def load_graph_data(pkl_filename):
23 | sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename)
24 | return sensor_ids, sensor_id_to_ind, adj_mx
25 |
26 |
27 | def load_pickle(pickle_file):
28 | try:
29 | with open(pickle_file, 'rb') as f:
30 | pickle_data = pickle.load(f)
31 | except UnicodeDecodeError as e:
32 | with open(pickle_file, 'rb') as f:
33 | pickle_data = pickle.load(f, encoding='latin1')
34 | except Exception as e:
35 | print('Unable to load data ', pickle_file, ':', e)
36 | raise
37 | return pickle_data
38 |
39 | def adjacency_to_edge_index(adj_mx):
40 | """
41 | Convert an adjacency matrix to edge_index and edge_weight.
42 | Args:
43 | adj_mx (np.ndarray): Adjacency matrix of shape (num_nodes, num_nodes).
44 | Returns:
45 | edge_index (torch.LongTensor): Shape (2, num_edges), source and target nodes.
46 | edge_weight (torch.FloatTensor): Shape (num_edges,), edge weights.
47 | """
48 | # Convert to sparse matrix
49 | adj_sparse = sp.coo_matrix(adj_mx)
50 |
51 | # Extract edge indices and weights
52 | edge_index = torch.tensor(
53 | np.vstack((adj_sparse.row, adj_sparse.col)), dtype=torch.long
54 | )
55 | edge_weight = torch.tensor(adj_sparse.data, dtype=torch.float)
56 |
57 | return edge_index, edge_weight
58 |
59 | # standard approach: see https://github.com/chnsh/DCRNN_PyTorch
60 | def benchmark_preprocess(h5File, dataset, key=None):
61 |
62 | if "pems" in dataset.lower():
63 | df = pd.read_hdf(h5File)
64 |
65 | x_offsets = np.sort(
66 | # np.concatenate(([-week_size + 1, -day_size + 1], np.arange(-11, 1, 1)))
67 | np.concatenate((np.arange(-11, 1, 1),))
68 | )
69 | # Predict the next one hour
70 | y_offsets = np.sort(np.arange(1, 13, 1))
71 | num_samples, num_nodes = df.shape
72 |
73 | data = np.expand_dims(df.values, axis=-1)
74 | data_list = [data]
75 | add_time_in_day= True
76 | if add_time_in_day:
77 | time_ind = (df.index.values - df.index.values.astype("datetime64[D]")) / np.timedelta64(1, "D")
78 | time_in_day = np.tile(time_ind, [1, num_nodes, 1]).transpose((2, 1, 0))
79 | data_list.append(time_in_day)
80 |
81 | data = np.concatenate(data_list, axis=-1)
82 |
83 |
84 | x, y = [], []
85 | # t is the index of the last observation.
86 | min_t = abs(min(x_offsets))
87 | max_t = abs(num_samples - abs(max(y_offsets))) # Exclusive
88 | for t in range(min_t, max_t):
89 | x_t = data[t + x_offsets, ...]
90 |
91 |
92 | y_t = data[t + y_offsets, ...]
93 | x.append(x_t)
94 | y.append(y_t)
95 |
96 |
97 |
98 | x = np.stack(x, axis=0)
99 | y = np.stack(y, axis=0)
100 |
101 | return torch.tensor(x,dtype=torch.float),torch.tensor(y,dtype=torch.float)
102 |
103 | def collect_metrics():
104 | import psutil
105 | from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlShutdown
106 |
107 | try:
108 | # Initialize NVML for GPU metrics
109 | nvmlInit()
110 |
111 | # Open the CSV file in append mode
112 |
113 | data = []
114 | max_gpu_mem = -1
115 | max_system_mem = -1
116 | while True:
117 | # Collect system memory usage
118 | timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
119 | mem = psutil.virtual_memory()
120 | total_rss = sum(proc.memory_info().rss for proc in psutil.process_iter(attrs=['memory_info']))
121 | system_memory_used = total_rss / (1024**2) # Convert to MB
122 | system_memory_total = mem.total / (1024**2) # Convert to MB
123 |
124 | # Collect GPU memory usage
125 | gpu_metrics = []
126 | handle = nvmlDeviceGetHandleByIndex(0)
127 | info = nvmlDeviceGetMemoryInfo(handle)
128 | gpu_memory_used = info.used / (1024**2) # Convert to MB
129 | gpu_memory_total = info.total / (1024**2) # Convert to MB
130 |
131 | max_gpu_mem = max(gpu_memory_used, max_gpu_mem)
132 | max_system_mem = max(system_memory_used, max_system_mem)
133 | data.append([timestamp,system_memory_used,system_memory_total, gpu_memory_used, gpu_memory_total])
134 |
135 | if os.path.isfile("flag.txt"):
136 | os.remove("flag.txt")
137 | break
138 |
139 | if os.path.isfile("stats.csv"):
140 | with open("system_stats.csv", mode="w", newline="") as f:
141 | writer = csv.writer(f)
142 |
143 | # Write headers to the CSV file
144 | headers = [
145 | "Timestamp",
146 | "System_Memory_Used",
147 | "System_Memory_Total",
148 | "GPU_Memory_Used",
149 | "GPU_Memory_Total"
150 | ]
151 | writer.writerow(headers)
152 | writer.writerows(data)
153 |
154 |
155 | df = pd.read_csv("stats.csv")
156 | df['system_memory_total'] = system_memory_total
157 | df['max_system_mem'] = max_system_mem
158 | df['gpu_memory_total'] = gpu_memory_total
159 | df['max_gpu_mem'] = max_gpu_mem
160 |
161 | df.to_csv("stats.csv", index=False)
162 | break
163 | time.sleep(1)
164 |
165 | except Exception as e:
166 | print("Error in collecting metrics:", str(e))
167 |
168 | finally:
169 | # Shutdown NVML
170 | nvmlShutdown()
--------------------------------------------------------------------------------
/examples/indexBatching/DCRNN/windmill_main.py:
--------------------------------------------------------------------------------
1 | from torch_geometric_temporal.nn.recurrent import BatchedDCRNN as DCRNN
2 | from torch_geometric_temporal.dataset import WindmillOutputLargeDatasetLoader
3 | import torch.optim as optim
4 |
5 | import argparse
6 | import time
7 | import csv
8 | import os
9 | from utils import *
10 |
11 |
12 |
13 | def parse_arguments():
14 | """Parse command-line arguments."""
15 | parser = argparse.ArgumentParser(description="Demo of index batching with WindmillLarge dataset")
16 |
17 | parser.add_argument(
18 | "-e", "--epochs", type=int, default=100, help="The desired number of training epochs"
19 | )
20 | parser.add_argument(
21 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size"
22 | )
23 | parser.add_argument(
24 | "-m", "--mode", type=str, default="base", help="Which version to run"
25 | )
26 | parser.add_argument(
27 | "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU"
28 | )
29 | parser.add_argument(
30 | "-d", "--debug", type=str, default="False", help="Print values for debugging"
31 | )
32 |
33 | return parser.parse_args()
34 |
35 |
36 | def train(train_dataloader, val_dataloader, mean, std, edge_index,edge_weight, epochs, seq_length, num_nodes, num_features,
37 | allGPU=False, debug=False):
38 |
39 |
40 | # Move to GPU
41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 | edge_index = edge_index.to(device)
43 | edge_weight = edge_weight.to(device)
44 |
45 | if allGPU == False:
46 | mean = mean.to(device)
47 | std = std.to(device)
48 |
49 | # Initialize model
50 | model = DCRNN(num_features, num_features, K=3).to(device)
51 |
52 | # Define optimizer and loss function
53 | optimizer = optim.Adam(model.parameters(), lr=0.001)
54 |
55 | # Training loop
56 | stats = []
57 | min_t = 9999
58 | min_v = 9999
59 | for epoch in range(epochs):
60 | # Training phase
61 | model.train()
62 | train_loss = 0.0
63 | i = 1
64 | total = len(train_dataloader)
65 | t1 = time.time()
66 | for batch in train_dataloader:
67 | X_batch, y_batch = batch
68 |
69 | if allGPU == False:
70 | # print("casting")
71 | X_batch = X_batch.to(device).float()
72 | y_batch = y_batch.to(device).float()
73 |
74 | # Forward pass
75 | outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels)
76 |
77 | # Calculate loss (use only the first output channel, assuming it's the target)
78 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
79 |
80 | # Backward pass
81 | optimizer.zero_grad()
82 | loss.backward()
83 | optimizer.step()
84 |
85 | train_loss += loss.item()
86 | if debug:
87 | print(f"Train Batch: {i}/{total}", end="\r")
88 | i+=1
89 |
90 |
91 |
92 | train_loss /= len(train_dataloader)
93 |
94 | # Validation phase
95 | model.eval()
96 | val_loss = 0.0
97 | i = 0
98 | if debug:
99 | print(" ", end="\r")
100 | total = len(val_dataloader)
101 | with torch.no_grad():
102 | for batch in val_dataloader:
103 | X_batch, y_batch = batch
104 |
105 | if allGPU == False:
106 | X_batch = X_batch.to(device).float()
107 | y_batch = y_batch.to(device).float()
108 |
109 | # Forward pass
110 | outputs = model(X_batch, edge_index, edge_weight)
111 |
112 | # Calculate loss
113 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
114 | val_loss += loss.item()
115 |
116 | if debug:
117 | print(f"Val Batch: {i}/{total}", end="\r")
118 | i += 1
119 |
120 | val_loss /= len(val_dataloader)
121 | t2 = time.time()
122 |
123 | # Print epoch metrics
124 | print(f"Epoch {epoch + 1}/{epochs}, Runtime: {t2 - t1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}", flush=True)
125 | stats.append([epoch+1, t2 - t1, train_loss, val_loss])
126 |
127 | min_t = min(min_t, train_loss)
128 | min_v = min(min_v, val_loss)
129 |
130 | return min_t, min_v
131 |
132 | def main():
133 |
134 | args = parse_arguments()
135 | allGPU = args.gpu.lower() in ["true", "y", "t"]
136 | debug = args.debug.lower() in ["true", "y", "t"]
137 | batch_size = args.batch_size
138 | epochs = args.epochs
139 |
140 | t1 = time.time()
141 | loader = WindmillOutputLargeDatasetLoader(index=True)
142 | if allGPU == True:
143 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = loader.get_index_dataset(allGPU=0, batch_size=batch_size)
144 | else:
145 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = loader.get_index_dataset(batch_size=batch_size)
146 |
147 |
148 |
149 | t_min, v_min = train(train_dataloader, val_dataloader, mean, std, edges, edge_weights, epochs, 8,319,1, allGPU=allGPU, debug=debug)
150 | t2 = time.time()
151 | print(f"Runtime: {round(t2 - t1,2)}; Best Train MSE: {t_min}; Best Validation MSE: {v_min}")
152 |
153 | if __name__ == "__main__":
154 | main()
--------------------------------------------------------------------------------
/examples/indexBatching/README.md:
--------------------------------------------------------------------------------
1 | ## PyTorch Geometric Temporal - Index
2 |
3 | Index-batching is a technique that reduces the memory cost of training ST-GNNs with spatiotemporal data with no impact on accurary, enabling greater scalability and training on the full PeMS dataset without graph partioning for the first time. Leveraging the reduced memory footprint, this techique also enables GPU-index-batching - a technique that performs preprocessing entirely in GPU memory and utilizes a single CPU-to-GPU mem-copy in place of batch-level CPU-to-GPU transfers throughout training. We implemented GPU-index-batching and index-batching for the following existing datasets and added two new datasets (highlighted in bold) to PyTorch Geometric Temporal (PGT):
4 |
5 | * PeMs-Bay
6 | * Metr-LA
7 | * WindmillLarge
8 | * HungaryChickenpox
9 | * **PeMSAllLA**
10 | * **PeMS**
11 |
12 | This folder contains examples with DCRNN and A3TGCN. We hope to build out our examples over time.
13 |
14 |
15 | Utilizing index-batching requires minimal modifications to the existing PGT workflow. For example, the following is a sample training loop with static graph dataset with temporal signal:
16 |
17 | ```
18 | train_dataloader, _, _, edges, edge_weights, means, stds = loader.get_index_dataset(batch_size=batch_size)
19 |
20 | for batch in train_dataloader:
21 | X_batch, y_batch = batch
22 |
23 | # Forward pass
24 | outputs = model(X_batch, edges, edge_weights)
25 |
26 | # Calculate loss
27 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
28 |
29 | # Backward pass
30 | optimizer.zero_grad()
31 | loss.backward()
32 | optimizer.step()
33 |
34 |
35 | ```
36 |
37 | The single-GPU examples in this repo can be executed as follows: `python3 .py` and have the following parameters:
38 |
39 | | Argument | Short Form | Type | Default | Description |
40 | |----------|-----------|------|---------|-------------|
41 | | `--epochs` | `-e` | `int` | `30` | The desired number of training epochs. |
42 | | `--batch-size` | `-bs` | `int` | `64` | The desired batch size for training. |
43 | | `--gpu` | `-g` | `str` | `"False"` | Indicates whether data should be preprocessed and migrated directly to the GPU. Use `"True"` to enable GPU processing. |
44 | | `--debug` | `-d` | `str` | `"False"` | Enables debug mode, printing values for debugging. Use `"True"` to enable debugging. |
45 |
46 |
47 |
48 | We also provide a multi-node, multi-GPU Dask-DDP training implementation for PeMS-Bay, PemsAllLA, and the full PeMS dataset. It has the following parameters:
49 |
50 | | Argument | Short Form | Type | Default | Description |
51 | |----------|-----------|------|---------|-------------|
52 | | `--epochs` | `-e` | `int` | `100` | The desired number of training epochs. |
53 | | `--batch-size` | `-bs` | `int` | `64` | The desired batch size for training. |
54 | | `--gpu` | `-g` | `str` | `"False"` | Indicates whether data should be preprocessed and migrated directly to the GPU. Use `"True"` to enable GPU processing. |
55 | | `--debug` | `-d` | `str` | `"False"` | Enables debug mode, printing values for debugging. Use `"True"` to enable debugging. |
56 | | `--dask-cluster-file` | N/A | `str` | `""` | Path to the Dask scheduler file for the Dask CLI interface. |
57 | | `--npar` | `-np` | `int` | `1` | The number of GPUs or workers per node. |
58 | | `--dataset` | N/A | `str` | `"pems-bay"` | Specifies which dataset is in use (e.g., PeMS-Bay, PeMS-All-LA, PeMS). |
59 |
60 | To execute in a single node, omit the `--dask-cluster-file` argument. To run multi-node, setup a Dask scheduler and dask workers user the [Dask command line interface](https://docs.dask.org/en/latest/deploying-cli.html) and pass the scheduler file via `--dask-cluster-file`. An example multi-GPU, multi-node script for Argonne's [Polaris supercomputer](https://www.alcf.anl.gov/polaris) is shown in `submit.sh`.
61 |
62 |
--------------------------------------------------------------------------------
/examples/recurrent/a3tgcn_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import A3TGCN
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset()
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features, periods):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = A3TGCN(node_features, 32, periods)
24 | self.linear = torch.nn.Linear(32, 1)
25 |
26 | def forward(self, x, edge_index, edge_weight):
27 | h = self.recurrent(x.view(x.shape[0], 1, x.shape[1]), edge_index, edge_weight)
28 | h = F.relu(h)
29 | h = self.linear(h)
30 | return h
31 |
32 | model = RecurrentGCN(node_features = 1, periods = 4)
33 |
34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
35 |
36 | model.train()
37 |
38 | for epoch in tqdm(range(50)):
39 | cost = 0
40 | for time, snapshot in enumerate(train_dataset):
41 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
42 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
43 | cost = cost / (time+1)
44 | cost.backward()
45 | optimizer.step()
46 | optimizer.zero_grad()
47 |
48 | model.eval()
49 | cost = 0
50 | for time, snapshot in enumerate(test_dataset):
51 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
52 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
53 | cost = cost / (time+1)
54 | cost = cost.item()
55 | print("MSE: {:.4f}".format(cost))
56 |
--------------------------------------------------------------------------------
/examples/recurrent/agcrn_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import AGCRN
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset(lags=8)
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = AGCRN(number_of_nodes = 20,
24 | in_channels = node_features,
25 | out_channels = 2,
26 | K = 2,
27 | embedding_dimensions = 4)
28 | self.linear = torch.nn.Linear(2, 1)
29 |
30 | def forward(self, x, e, h):
31 | h_0 = self.recurrent(x, e, h)
32 | y = F.relu(h_0)
33 | y = self.linear(y)
34 | return y, h_0
35 |
36 | model = RecurrentGCN(node_features = 8)
37 |
38 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
39 |
40 | model.train()
41 |
42 | e = torch.empty(20, 4)
43 |
44 | torch.nn.init.xavier_uniform_(e)
45 |
46 | for epoch in tqdm(range(200)):
47 | cost = 0
48 | h = None
49 | for time, snapshot in enumerate(train_dataset):
50 | x = snapshot.x.view(1, 20, 8)
51 | y_hat, h = model(x, e, h)
52 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
53 | cost = cost / (time+1)
54 | cost.backward()
55 | optimizer.step()
56 | optimizer.zero_grad()
57 |
58 | model.eval()
59 | cost = 0
60 | for time, snapshot in enumerate(test_dataset):
61 | x = snapshot.x.view(1, 20, 8)
62 | y_hat, h = model(x, e, h)
63 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
64 | cost = cost / (time+1)
65 | cost = cost.item()
66 | print("MSE: {:.4f}".format(cost))
67 |
--------------------------------------------------------------------------------
/examples/recurrent/dcrnn_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import DCRNN
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset()
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = DCRNN(node_features, 32, 1)
24 | self.linear = torch.nn.Linear(32, 1)
25 |
26 | def forward(self, x, edge_index, edge_weight):
27 | h = self.recurrent(x, edge_index, edge_weight)
28 | h = F.relu(h)
29 | h = self.linear(h)
30 | return h
31 |
32 | model = RecurrentGCN(node_features = 4)
33 |
34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
35 |
36 | model.train()
37 |
38 | for epoch in tqdm(range(200)):
39 | cost = 0
40 | for time, snapshot in enumerate(train_dataset):
41 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
42 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
43 | cost = cost / (time+1)
44 | cost.backward()
45 | optimizer.step()
46 | optimizer.zero_grad()
47 |
48 | model.eval()
49 | cost = 0
50 | for time, snapshot in enumerate(test_dataset):
51 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
52 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
53 | cost = cost / (time+1)
54 | cost = cost.item()
55 | print("MSE: {:.4f}".format(cost))
56 |
--------------------------------------------------------------------------------
/examples/recurrent/dygrencoder_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import DyGrEncoder
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset()
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = DyGrEncoder(conv_out_channels=4, conv_num_layers=1, conv_aggr="mean", lstm_out_channels=32, lstm_num_layers=1)
24 | self.linear = torch.nn.Linear(32, 1)
25 |
26 | def forward(self, x, edge_index, edge_weight, h_0, c_0):
27 | h, h_0, c_0 = self.recurrent(x, edge_index, edge_weight, h_0, c_0)
28 | h = F.relu(h)
29 | h = self.linear(h)
30 | return h, h_0, c_0
31 |
32 | model = RecurrentGCN(node_features = 4)
33 |
34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
35 |
36 | model.train()
37 |
38 | for epoch in tqdm(range(200)):
39 | cost = 0
40 | h, c = None, None
41 | for time, snapshot in enumerate(train_dataset):
42 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c)
43 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
44 | cost = cost / (time+1)
45 | cost.backward()
46 | optimizer.step()
47 | optimizer.zero_grad()
48 |
49 | model.eval()
50 | cost = 0
51 | h, c = None, None
52 | for time, snapshot in enumerate(test_dataset):
53 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c)
54 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
55 | cost = cost / (time+1)
56 | cost = cost.item()
57 | print("MSE: {:.4f}".format(cost))
58 |
--------------------------------------------------------------------------------
/examples/recurrent/evolvegcnh_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import EvolveGCNH
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset()
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_count, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = EvolveGCNH(node_count, node_features)
24 | self.linear = torch.nn.Linear(node_features, 1)
25 |
26 | def forward(self, x, edge_index, edge_weight):
27 | h = self.recurrent(x, edge_index, edge_weight)
28 | h = F.relu(h)
29 | h = self.linear(h)
30 | return h
31 |
32 | model = RecurrentGCN(node_features = 4, node_count = 20)
33 |
34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
35 |
36 | model.train()
37 |
38 | for epoch in tqdm(range(200)):
39 | cost = 0
40 | for time, snapshot in enumerate(train_dataset):
41 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
42 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
43 | cost = cost / (time+1)
44 | cost.backward()
45 | optimizer.step()
46 | optimizer.zero_grad()
47 |
48 | model.eval()
49 | cost = 0
50 | for time, snapshot in enumerate(test_dataset):
51 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
52 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
53 | cost = cost / (time+1)
54 | cost = cost.item()
55 | print("MSE: {:.4f}".format(cost))
56 |
--------------------------------------------------------------------------------
/examples/recurrent/evolvegcno_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import EvolveGCNO
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset()
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = EvolveGCNO(node_features)
24 | self.linear = torch.nn.Linear(node_features, 1)
25 |
26 | def forward(self, x, edge_index, edge_weight):
27 | h = self.recurrent(x, edge_index, edge_weight)
28 | h = F.relu(h)
29 | h = self.linear(h)
30 | return h
31 |
32 | model = RecurrentGCN(node_features = 4)
33 | for param in model.parameters():
34 | param.retain_grad()
35 |
36 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
37 |
38 | model.train()
39 |
40 | for epoch in tqdm(range(200)):
41 | cost = 0
42 | for time, snapshot in enumerate(train_dataset):
43 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
44 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
45 | cost = cost / (time+1)
46 | cost.backward(retain_graph=True)
47 | optimizer.step()
48 | optimizer.zero_grad()
49 |
50 | model.eval()
51 | cost = 0
52 | for time, snapshot in enumerate(test_dataset):
53 | if time == 0:
54 | model.recurrent.weight = None
55 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
56 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
57 | cost = cost / (time+1)
58 | cost = cost.item()
59 | print("MSE: {:.4f}".format(cost))
60 |
--------------------------------------------------------------------------------
/examples/recurrent/gclstm_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import GCLSTM
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset()
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = GCLSTM(node_features, 32, 1)
24 | self.linear = torch.nn.Linear(32, 1)
25 |
26 | def forward(self, x, edge_index, edge_weight, h, c):
27 | h_0, c_0 = self.recurrent(x, edge_index, edge_weight, h, c)
28 | h = F.relu(h_0)
29 | h = self.linear(h)
30 | return h, h_0, c_0
31 |
32 | model = RecurrentGCN(node_features=4)
33 |
34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
35 |
36 | model.train()
37 |
38 | for epoch in tqdm(range(200)):
39 | cost = 0
40 | h, c = None, None
41 | for time, snapshot in enumerate(train_dataset):
42 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c)
43 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
44 | cost = cost / (time+1)
45 | cost.backward()
46 | optimizer.step()
47 | optimizer.zero_grad()
48 |
49 | model.eval()
50 | cost = 0
51 | for time, snapshot in enumerate(test_dataset):
52 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c)
53 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
54 | cost = cost / (time+1)
55 | cost = cost.item()
56 | print("MSE: {:.4f}".format(cost))
57 |
--------------------------------------------------------------------------------
/examples/recurrent/gconvgru_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import GConvGRU
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset()
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = GConvGRU(node_features, 32, 1)
24 | self.linear = torch.nn.Linear(32, 1)
25 |
26 | def forward(self, x, edge_index, edge_weight):
27 | h = self.recurrent(x, edge_index, edge_weight)
28 | h = F.relu(h)
29 | h = self.linear(h)
30 | return h
31 |
32 | model = RecurrentGCN(node_features = 4)
33 |
34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
35 |
36 | model.train()
37 |
38 | for epoch in tqdm(range(200)):
39 | cost = 0
40 | for time, snapshot in enumerate(train_dataset):
41 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
42 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
43 | cost = cost / (time+1)
44 | cost.backward()
45 | optimizer.step()
46 | optimizer.zero_grad()
47 |
48 | model.eval()
49 | cost = 0
50 | for time, snapshot in enumerate(test_dataset):
51 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
52 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
53 | cost = cost / (time+1)
54 | cost = cost.item()
55 | print("MSE: {:.4f}".format(cost))
56 |
--------------------------------------------------------------------------------
/examples/recurrent/gconvlstm_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import GConvLSTM
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset()
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = GConvLSTM(node_features, 32, 1)
24 | self.linear = torch.nn.Linear(32, 1)
25 |
26 | def forward(self, x, edge_index, edge_weight, h, c):
27 | h_0, c_0 = self.recurrent(x, edge_index, edge_weight, h, c)
28 | h = F.relu(h_0)
29 | h = self.linear(h)
30 | return h, h_0, c_0
31 |
32 | model = RecurrentGCN(node_features=4)
33 |
34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
35 |
36 | model.train()
37 |
38 | for epoch in tqdm(range(200)):
39 | cost = 0
40 | h, c = None, None
41 | for time, snapshot in enumerate(train_dataset):
42 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c)
43 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
44 | cost = cost / (time+1)
45 | cost.backward()
46 | optimizer.step()
47 | optimizer.zero_grad()
48 |
49 | model.eval()
50 | cost = 0
51 | for time, snapshot in enumerate(test_dataset):
52 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c)
53 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
54 | cost = cost / (time+1)
55 | cost = cost.item()
56 | print("MSE: {:.4f}".format(cost))
57 |
--------------------------------------------------------------------------------
/examples/recurrent/lightning_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import pytorch_lightning as pl
5 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
6 |
7 | from torch_geometric_temporal.nn.recurrent import DCRNN
8 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
9 | from torch_geometric_temporal.signal import temporal_signal_split
10 |
11 |
12 | class LitDiffConvModel(pl.LightningModule):
13 |
14 | def __init__(self, node_features, filters):
15 | super().__init__()
16 | self.recurrent = DCRNN(node_features, filters, 1)
17 | self.linear = torch.nn.Linear(filters, 1)
18 |
19 |
20 | def configure_optimizers(self):
21 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
22 | return optimizer
23 |
24 | def training_step(self, train_batch, batch_idx):
25 | x = train_batch.x
26 | y = train_batch.y.view(-1, 1)
27 | edge_index = train_batch.edge_index
28 | h = self.recurrent(x, edge_index)
29 | h = F.relu(h)
30 | h = self.linear(h)
31 | loss = F.mse_loss(h, y)
32 | return loss
33 |
34 | def validation_step(self, val_batch, batch_idx):
35 | x = val_batch.x
36 | y = val_batch.y.view(-1, 1)
37 | edge_index = val_batch.edge_index
38 | h = self.recurrent(x, edge_index)
39 | h = F.relu(h)
40 | h = self.linear(h)
41 | loss = F.mse_loss(h, y)
42 | metrics = {'val_loss': loss}
43 | self.log_dict(metrics)
44 | return metrics
45 |
46 |
47 | loader = ChickenpoxDatasetLoader()
48 |
49 | dataset_loader = loader.get_dataset(lags=32)
50 |
51 | train_loader, val_loader = temporal_signal_split(dataset_loader,
52 | train_ratio=0.2)
53 |
54 | model = LitDiffConvModel(node_features=32,
55 | filters=16)
56 |
57 | early_stop_callback = EarlyStopping(monitor='val_loss',
58 | min_delta=0.00,
59 | patience=10,
60 | verbose=False,
61 | mode='max')
62 |
63 | trainer = pl.Trainer(callbacks=[early_stop_callback])
64 |
65 | trainer.fit(model, train_loader, val_loader)
66 |
--------------------------------------------------------------------------------
/examples/recurrent/lrgcn_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import LRGCN
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset()
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = LRGCN(node_features, 32, 1, 1)
24 | self.linear = torch.nn.Linear(32, 1)
25 |
26 | def forward(self, x, edge_index, edge_weight, h_0, c_0):
27 | h_0, c_0 = self.recurrent(x, edge_index, edge_weight, h_0, c_0)
28 | h = F.relu(h_0)
29 | h = self.linear(h)
30 | return h, h_0, c_0
31 |
32 | model = RecurrentGCN(node_features = 4)
33 |
34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
35 |
36 | model.train()
37 |
38 | for epoch in tqdm(range(200)):
39 | cost = 0
40 | h, c = None, None
41 | for time, snapshot in enumerate(train_dataset):
42 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c)
43 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
44 | cost = cost / (time+1)
45 | cost.backward()
46 | optimizer.step()
47 | optimizer.zero_grad()
48 |
49 | model.eval()
50 | cost = 0
51 | h, c = None, None
52 | for time, snapshot in enumerate(test_dataset):
53 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c)
54 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
55 | cost = cost / (time+1)
56 | cost = cost.item()
57 | print("MSE: {:.4f}".format(cost))
58 |
--------------------------------------------------------------------------------
/examples/recurrent/mpnnlstm_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import MPNNLSTM
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset()
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = MPNNLSTM(node_features, 32, 20, 1, 0.5)
24 | self.linear = torch.nn.Linear(2*32 + node_features, 1)
25 |
26 | def forward(self, x, edge_index, edge_weight):
27 | h = self.recurrent(x, edge_index, edge_weight)
28 | h = F.relu(h)
29 | h = self.linear(h)
30 | return h
31 |
32 | model = RecurrentGCN(node_features = 4)
33 |
34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
35 |
36 | model.train()
37 |
38 | for epoch in tqdm(range(50)):
39 | cost = 0
40 | for time, snapshot in enumerate(train_dataset):
41 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
42 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
43 | cost = cost / (time+1)
44 | cost.backward()
45 | optimizer.step()
46 | optimizer.zero_grad()
47 |
48 | model.eval()
49 | cost = 0
50 | for time, snapshot in enumerate(test_dataset):
51 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
52 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
53 | cost = cost / (time+1)
54 | cost = cost.item()
55 | print("MSE: {:.4f}".format(cost))
56 |
--------------------------------------------------------------------------------
/examples/recurrent/tgcn_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import TGCN
10 |
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 |
14 | loader = ChickenpoxDatasetLoader()
15 |
16 | dataset = loader.get_dataset()
17 |
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 |
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = TGCN(node_features, 32)
24 | self.linear = torch.nn.Linear(32, 1)
25 |
26 | def forward(self, x, edge_index, edge_weight, prev_hidden_state):
27 | h = self.recurrent(x, edge_index, edge_weight, prev_hidden_state)
28 | y = F.relu(h)
29 | y = self.linear(y)
30 | return y, h
31 |
32 | model = RecurrentGCN(node_features = 4)
33 |
34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
35 |
36 | model.train()
37 |
38 | for epoch in tqdm(range(50)):
39 | cost = 0
40 | hidden_state = None
41 | for time, snapshot in enumerate(train_dataset):
42 | y_hat, hidden_state = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr,hidden_state)
43 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
44 | cost = cost / (time+1)
45 | cost.backward()
46 | optimizer.step()
47 | optimizer.zero_grad()
48 |
49 | model.eval()
50 | cost = 0
51 | hidden_state = None
52 | for time, snapshot in enumerate(test_dataset):
53 | y_hat, hidden_state = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, hidden_state)
54 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
55 | cost = cost / (time+1)
56 | cost = cost.item()
57 | print("MSE: {:.4f}".format(cost))
58 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: ubuntu-22.04
5 | tools:
6 | python: "3.12"
7 |
8 | sphinx:
9 | configuration: docs/source/conf.py
10 |
11 | python:
12 | install:
13 | - requirements: docs/requirements.txt
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | install_requires = [
4 | "decorator==4.4.2",
5 | "torch",
6 | "cython",
7 | "torch_sparse",
8 | "torch_scatter",
9 | "torch_geometric",
10 | "numpy",
11 | "networkx",
12 | ]
13 | tests_require = ["pytest", "pytest-cov", "mock", "networkx", "tqdm"]
14 | index_require = ['dask', "pandas", "tables"]
15 | ddp_require = ["dask[distributed]", "dask_pytorch_ddp", "pandas", "tables"]
16 |
17 | keywords = [
18 | "machine-learning",
19 | "deep-learning",
20 | "deeplearning",
21 | "deep learning",
22 | "machine learning",
23 | "signal processing",
24 | "temporal signal",
25 | "graph",
26 | "dynamic graph",
27 | "embedding",
28 | "dynamic embedding",
29 | "graph convolution",
30 | "gcn",
31 | "graph neural network",
32 | "graph attention",
33 | "lstm",
34 | "temporal network",
35 | "representation learning",
36 | "learning",
37 | ]
38 |
39 | setup(
40 | name="torch_geometric_temporal",
41 | packages=find_packages(),
42 | version="0.55.0",
43 | license="MIT",
44 | description="A Temporal Extension Library for PyTorch Geometric.",
45 | author="Benedek Rozemberczki",
46 | author_email="benedek.rozemberczki@gmail.com",
47 | url="https://github.com/benedekrozemberczki/pytorch_geometric_temporal",
48 | download_url="https://github.com/benedekrozemberczki/pytorch_geometric_temporal/archive/v0.54.0.tar.gz",
49 | keywords=keywords,
50 | install_requires=install_requires,
51 | extras_require={
52 | "test": tests_require,
53 | "index": index_require,
54 | "ddp": ddp_require
55 | },
56 | python_requires=">=3.6",
57 | classifiers=[
58 | "Development Status :: 3 - Alpha",
59 | "Intended Audience :: Developers",
60 | "Topic :: Software Development :: Build Tools",
61 | "License :: OSI Approved :: MIT License",
62 | "Programming Language :: Python :: 3.6",
63 | ],
64 | )
65 |
--------------------------------------------------------------------------------
/test/heterogeneous_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import networkx as nx
4 | import torch_geometric.transforms as T
5 | from torch_geometric.data import HeteroData
6 | from torch_geometric_temporal.nn.hetero import HeteroGCLSTM
7 |
8 |
9 | def get_edge_array(n_count):
10 | return np.array([edge for edge in nx.gnp_random_graph(n_count, 0.1).edges()]).T
11 |
12 |
13 | def create_hetero_mock_data(n_count, feature_dict):
14 | _x_dict = {'author': torch.FloatTensor(np.random.uniform(0, 1, (n_count, feature_dict['author']))),
15 | 'paper': torch.FloatTensor(np.random.uniform(0, 1, (n_count, feature_dict['paper'])))}
16 | _edge_index_dict = {('author', 'writes', 'paper'): torch.LongTensor(get_edge_array(n_count))}
17 |
18 | data = HeteroData()
19 | data['author'].x = _x_dict['author']
20 | data['paper'].x = _x_dict['paper']
21 | data[('author', 'writes', 'paper')].edge_index = _edge_index_dict[('author', 'writes', 'paper')]
22 | data = T.ToUndirected()(data)
23 |
24 | return data.x_dict, data.edge_index_dict, data.metadata()
25 |
26 |
27 | def test_hetero_gclstm_layer():
28 | """
29 | Testing the HeteroGCLSTM Layer.
30 | """
31 | number_of_nodes = 50
32 | feature_dict = {'author': 20, 'paper': 30}
33 | out_channels = 32
34 |
35 | x_dict, edge_index_dict, metadata = create_hetero_mock_data(number_of_nodes, feature_dict)
36 |
37 | layer = HeteroGCLSTM(in_channels_dict=feature_dict, out_channels=out_channels, metadata=metadata)
38 |
39 | h_dict, c_dict = layer(x_dict, edge_index_dict)
40 |
41 | assert h_dict['author'].shape == (number_of_nodes, out_channels)
42 | assert h_dict['paper'].shape == (number_of_nodes, out_channels)
43 | assert c_dict['author'].shape == (number_of_nodes, out_channels)
44 | assert c_dict['paper'].shape == (number_of_nodes, out_channels)
45 |
46 | h_dict, c_dict = layer(x_dict, edge_index_dict, h_dict, c_dict)
47 |
48 | assert h_dict['author'].shape == (number_of_nodes, out_channels)
49 | assert h_dict['paper'].shape == (number_of_nodes, out_channels)
50 | assert c_dict['author'].shape == (number_of_nodes, out_channels)
51 | assert c_dict['paper'].shape == (number_of_nodes, out_channels)
52 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/__init__.py:
--------------------------------------------------------------------------------
1 | from torch_geometric_temporal.nn import *
2 | from torch_geometric_temporal.dataset import *
3 | from torch_geometric_temporal.signal import *
4 |
5 | __version__ = "0.54.0"
6 |
7 | __all__ = [
8 | "torch_geometric",
9 | "__version__",
10 | ]
11 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .chickenpox import ChickenpoxDatasetLoader
2 | from .pedalme import PedalMeDatasetLoader
3 | from .metr_la import METRLADatasetLoader
4 | from .pems_bay import PemsBayDatasetLoader
5 | from .wikimath import WikiMathsDatasetLoader
6 | from .windmilllarge import WindmillOutputLargeDatasetLoader
7 | from .windmillmedium import WindmillOutputMediumDatasetLoader
8 | from .windmillsmall import WindmillOutputSmallDatasetLoader
9 | from .encovid import EnglandCovidDatasetLoader
10 | from .twitter_tennis import TwitterTennisDatasetLoader
11 | from .montevideo_bus import MontevideoBusDatasetLoader
12 | from .mtm import MTMDatasetLoader
13 |
14 | from .pemsAllLA import PemsAllLADatasetLoader
15 | from .pems import PemsDatasetLoader
16 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/chickenpox.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ssl
3 | import urllib.request
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader
7 |
8 | from ..signal import StaticGraphTemporalSignal
9 |
10 |
11 | class ChickenpoxDatasetLoader(object):
12 | """A dataset of county level chicken pox cases in Hungary between 2004
13 | and 2014. We made it public during the development of PyTorch Geometric
14 | Temporal. The underlying graph is static - vertices are counties and
15 | edges are neighbourhoods. Vertex features are lagged weekly counts of the
16 | chickenpox cases (we included 4 lags). The target is the weekly number of
17 | cases for the upcoming week (signed integers). Our dataset consist of more
18 | than 500 snapshots (weeks).
19 |
20 | Args:
21 | index (bool, optional): If True, initializes the dataloader to use index-based batching.
22 | Defaults to False.
23 | """
24 | def __init__(self, index=False):
25 | self._read_web_data()
26 | self.index = index
27 |
28 | if index == True:
29 | from ..signal.index_dataset import IndexDataset
30 | self.IndexDataset = IndexDataset
31 |
32 | def _read_web_data(self):
33 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/chickenpox.json"
34 |
35 | context = ssl._create_unverified_context()
36 | self._dataset = json.loads(
37 | urllib.request.urlopen(url, context=context).read()
38 | )
39 |
40 | def _get_edges(self):
41 | self._edges = np.array(self._dataset["edges"]).T
42 |
43 | def _get_edge_weights(self):
44 | self._edge_weights = np.ones(self._edges.shape[1])
45 |
46 | def _get_targets_and_features(self):
47 | stacked_target = np.array(self._dataset["FX"])
48 | self.features = [
49 | stacked_target[i : i + self.lags, :].T
50 | for i in range(stacked_target.shape[0] - self.lags)
51 | ]
52 | self.targets = [
53 | stacked_target[i + self.lags, :].T
54 | for i in range(stacked_target.shape[0] - self.lags)
55 | ]
56 |
57 | def get_dataset(self, lags: int = 4) -> StaticGraphTemporalSignal:
58 | """Returning the Chickenpox Hungary data iterator.
59 |
60 | Args types:
61 | * **lags** *(int)* - The number of time lags.
62 | Return types:
63 | * **dataset** *(StaticGraphTemporalSignal)* - The Chickenpox Hungary dataset.
64 | """
65 | self.lags = lags
66 | self._get_edges()
67 | self._get_edge_weights()
68 | self._get_targets_and_features()
69 | dataset = StaticGraphTemporalSignal(
70 | self._edges, self._edge_weights, self.features, self.targets
71 | )
72 | return dataset
73 |
74 | def get_index_dataset(self, lags=4, batch_size=4, shuffle=False, allGPU=-1, ratio=(0.7, 0.1, 0.2),dask_batching=False):
75 | """
76 | Returns torch dataloaders using index batching for Chickenpox Hungary dataset.
77 |
78 | Args:
79 | lags (int, optional): The number of time lags. Defaults to 4.
80 | batch_size (int, optional): Batch size. Defaults to 4.
81 | shuffle (bool, optional): If the data should be shuffled. Defaults to False.
82 | allGPU (int, optional): GPU device ID for performing preprocessing in GPU memory.
83 | If -1, computation is done on CPU. Defaults to -1.
84 | ratio (tuple of float, optional): The desired train, validation, and test split ratios, respectively.
85 |
86 | Returns:
87 | Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.Tensor, torch.Tensor]:
88 |
89 | A 5-tuple containing:
90 | - **train_dataLoader** (*torch.utils.data.DataLoader*): Dataloader for the training set.
91 | - **val_dataLoader** (*torch.utils.data.DataLoader*): Dataloader for the validation set.
92 | - **test_dataLoader** (*torch.utils.data.DataLoader*): Dataloader for the test set.
93 | - **edges** (*torch.Tensor*): The graph edges as a 2D matrix, shape `[2, num_edges]`.
94 | - **edge_weights** (*torch.Tensor*): Each graph edge's weight, shape `[num_edges]`.
95 | """
96 |
97 | if not self.index:
98 | raise ValueError("get_index_dataset requires 'index=True' in the constructor.")
99 |
100 | data = np.array(self._dataset["FX"])
101 | edges = torch.tensor(self._dataset["edges"],dtype=torch.int64).T
102 | edge_weights = torch.ones(edges.shape[1],dtype=torch.float)
103 | num_samples = data.shape[0]
104 |
105 | if allGPU != -1:
106 | data = torch.tensor(data, dtype=torch.float).to(f"cuda:{allGPU}")
107 | data = data.unsqueeze(-1)
108 | else:
109 | data = np.expand_dims(data, axis=-1)
110 |
111 |
112 | x_i = np.arange(num_samples - (2 * lags - 1))
113 |
114 | num_samples = x_i.shape[0]
115 | num_train = round(num_samples * ratio[0])
116 | num_test = round(num_samples * ratio[2])
117 | num_val = num_samples - num_train - num_test
118 |
119 | x_train = x_i[:num_train]
120 | x_val = x_i[num_train: num_train + num_val]
121 | x_test = x_i[-num_test:]
122 |
123 | train_dataset = self.IndexDataset(x_train,data,lags,gpu=not (allGPU == -1), lazy=dask_batching)
124 | val_dataset = self.IndexDataset(x_val,data,lags,gpu=not (allGPU == -1), lazy=dask_batching)
125 | test_dataset = self.IndexDataset(x_test,data,lags,gpu=not (allGPU == -1),lazy=dask_batching)
126 |
127 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
128 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)
129 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)
130 |
131 |
132 | return train_dataloader, val_dataloader, test_dataloader, edges, edge_weights
133 |
134 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/encovid.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ssl
3 | import urllib.request
4 | import numpy as np
5 | from ..signal import DynamicGraphTemporalSignal
6 |
7 |
8 | class EnglandCovidDatasetLoader(object):
9 | """A dataset of mobility and history of reported cases of COVID-19
10 | in England NUTS3 regions, from 3 March to 12 of May. The dataset is
11 | segmented in days and the graph is directed and weighted. The graph
12 | indicates how many people moved from one region to the other each day,
13 | based on Facebook Data For Good disease prevention maps.
14 | The node features correspond to the number of COVID-19 cases
15 | in the region in the past **window** days. The task is to predict the
16 | number of cases in each node after 1 day. For details see this paper:
17 | `"Transfer Graph Neural Networks for Pandemic Forecasting." `_
18 | """
19 |
20 | def __init__(self):
21 | self._read_web_data()
22 |
23 | def _read_web_data(self):
24 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/england_covid.json"
25 | context = ssl._create_unverified_context()
26 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read())
27 |
28 | def _get_edges(self):
29 | self._edges = []
30 | for time in range(self._dataset["time_periods"] - self.lags):
31 | self._edges.append(
32 | np.array(self._dataset["edge_mapping"]["edge_index"][str(time)]).T
33 | )
34 |
35 | def _get_edge_weights(self):
36 | self._edge_weights = []
37 | for time in range(self._dataset["time_periods"] - self.lags):
38 | self._edge_weights.append(
39 | np.array(self._dataset["edge_mapping"]["edge_weight"][str(time)])
40 | )
41 |
42 | def _get_targets_and_features(self):
43 |
44 | stacked_target = np.array(self._dataset["y"])
45 | standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / (
46 | np.std(stacked_target, axis=0) + 10 ** -10
47 | )
48 | self.features = [
49 | standardized_target[i : i + self.lags, :].T
50 | for i in range(self._dataset["time_periods"] - self.lags)
51 | ]
52 | self.targets = [
53 | standardized_target[i + self.lags, :].T
54 | for i in range(self._dataset["time_periods"] - self.lags)
55 | ]
56 |
57 | def get_dataset(self, lags: int = 8) -> DynamicGraphTemporalSignal:
58 | """Returning the England COVID19 data iterator.
59 |
60 | Args types:
61 | * **lags** *(int)* - The number of time lags.
62 | Return types:
63 | * **dataset** *(StaticGraphTemporalSignal)* - The England Covid dataset.
64 | """
65 | self.lags = lags
66 | self._get_edges()
67 | self._get_edge_weights()
68 | self._get_targets_and_features()
69 | dataset = DynamicGraphTemporalSignal(
70 | self._edges, self._edge_weights, self.features, self.targets
71 | )
72 | return dataset
73 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/montevideo_bus.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import json
3 | import ssl
4 | import urllib.request
5 | import numpy as np
6 | from torch_geometric_temporal.signal import StaticGraphTemporalSignal
7 |
8 |
9 | class MontevideoBusDatasetLoader(object):
10 | """A dataset of inflow passenger at bus stop level from Montevideo city.
11 | This dataset comprises hourly inflow passenger data at bus stop level for 11 bus lines during
12 | October 2020 from Montevideo city (Uruguay). The bus lines selected are the ones that carry
13 | people to the center of the city and they load more than 25% of the total daily inflow traffic.
14 | Vertices are bus stops, edges are links between bus stops when a bus line connects them and the
15 | weight represent the road distance. The target is the passenger inflow. This is a curated
16 | dataset made from different data sources of the Metropolitan Transportation System (STM) of
17 | Montevideo. These datasets are freely available to anyone in the National Catalog of Open Data
18 | from the government of Uruguay (https://catalogodatos.gub.uy/).
19 | """
20 |
21 | def __init__(self):
22 | self._read_web_data()
23 |
24 | def _read_web_data(self):
25 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/montevideo_bus.json"
26 | context = ssl._create_unverified_context()
27 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read())
28 |
29 | def _get_node_ids(self):
30 | return [node.get('bus_stop') for node in self._dataset["nodes"]]
31 |
32 | def _get_edges(self):
33 | node_ids = self._get_node_ids()
34 | node_id_map = dict(zip(node_ids, range(len(node_ids))))
35 | self._edges = np.array(
36 | [(node_id_map[d["source"]], node_id_map[d["target"]]) for d in self._dataset["links"]]
37 | ).T
38 |
39 | def _get_edge_weights(self):
40 | self._edge_weights = np.array([(d["weight"]) for d in self._dataset["links"]]).T
41 |
42 | def _get_features(self, feature_vars: List[str] = ["y"]):
43 | features = []
44 | for node in self._dataset["nodes"]:
45 | X = node.get("X")
46 | for feature_var in feature_vars:
47 | features.append(np.array(X.get(feature_var)))
48 | stacked_features = np.stack(features).T
49 | standardized_features = (
50 | stacked_features - np.mean(stacked_features, axis=0)
51 | ) / np.std(stacked_features, axis=0)
52 | self.features = [
53 | standardized_features[i : i + self.lags, :].T
54 | for i in range(len(standardized_features) - self.lags)
55 | ]
56 |
57 | def _get_targets(self, target_var: str = "y"):
58 | targets = []
59 | for node in self._dataset["nodes"]:
60 | y = node.get(target_var)
61 | targets.append(np.array(y))
62 | stacked_targets = np.stack(targets).T
63 | standardized_targets = (
64 | stacked_targets - np.mean(stacked_targets, axis=0)
65 | ) / np.std(stacked_targets, axis=0)
66 | self.targets = [
67 | standardized_targets[i + self.lags, :].T
68 | for i in range(len(standardized_targets) - self.lags)
69 | ]
70 |
71 | def get_dataset(
72 | self, lags: int = 4, target_var: str = "y", feature_vars: List[str] = ["y"]
73 | ) -> StaticGraphTemporalSignal:
74 | """Returning the MontevideoBus passenger inflow data iterator.
75 |
76 | Parameters
77 | ----------
78 | lags : int, optional
79 | The number of time lags, by default 4.
80 | target_var : str, optional
81 | Target variable name, by default "y".
82 | feature_vars : List[str], optional
83 | List of feature variables, by default ["y"].
84 |
85 | Returns
86 | -------
87 | StaticGraphTemporalSignal
88 | The MontevideoBus dataset.
89 | """
90 | self.lags = lags
91 | self._get_edges()
92 | self._get_edge_weights()
93 | self._get_features(feature_vars)
94 | self._get_targets(target_var)
95 | dataset = StaticGraphTemporalSignal(
96 | self._edges, self._edge_weights, self.features, self.targets
97 | )
98 | return dataset
99 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/mtm.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ssl
3 | import urllib.request
4 | import numpy as np
5 | from torch_geometric_temporal.signal import StaticGraphTemporalSignal
6 |
7 |
8 | class MTMDatasetLoader:
9 | """
10 | A dataset of `Methods-Time Measurement-1 `_
11 | (MTM-1) motions, signalled as consecutive video frames of 21 3D hand keypoints, acquired via
12 | `MediaPipe Hands `_ from RGB-Video
13 | material. Vertices are the finger joints of the human hand and edges are the bones connecting
14 | them. The targets are manually labeled for each frame, according to one of the five MTM-1
15 | motions (classes :math:`C`): Grasp, Release, Move, Reach, Position plus a negative class for
16 | frames without graph signals (no hand present). This is a classification task where :math:`T`
17 | consecutive frames need to be assigned to the corresponding class :math:`C`. The data x is
18 | returned in shape :obj:`(3, 21, T)`, the target is returned one-hot-encoded in shape :obj:`(T, 6)`.
19 | """
20 |
21 | def __init__(self):
22 | self._read_web_data()
23 |
24 | def _read_web_data(self):
25 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/mtm_1.json"
26 | context = ssl._create_unverified_context()
27 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read())
28 |
29 | def _get_edges(self):
30 | self._edges = np.array(self._dataset["edges"]).T
31 |
32 | def _get_edge_weights(self):
33 | self._edge_weights = np.array([1 for d in self._dataset["edges"]]).T
34 |
35 | def _get_features(self):
36 | dic = self._dataset
37 | joints = [str(n) for n in range(21)]
38 | dataset_length = len(dic["0"].values())
39 | features = np.zeros((dataset_length, 21, 3))
40 |
41 | for j, joint in enumerate(joints):
42 | for t, xyz in enumerate(dic[joint].values()):
43 | xyz_tuple = list(map(float, xyz.strip("()").split(",")))
44 | features[t, j, :] = xyz_tuple
45 |
46 | self.features = [
47 | features[i : i + self.frames, :].T
48 | for i in range(len(features) - self.frames)
49 | ]
50 |
51 | def _get_targets(self):
52 | # target eoncoding: {0 : 'Grasp', 1 : 'Move', 2 : 'Negative',
53 | # 3 : 'Position', 4 : 'Reach', 5 : 'Release'}
54 | targets = []
55 | for _, y in self._dataset["LABEL"].items():
56 | targets.append(y)
57 |
58 | n_values = np.max(targets) + 1
59 | targets_ohe = np.eye(n_values)[targets]
60 |
61 | self.targets = [
62 | targets_ohe[i : i + self.frames, :]
63 | for i in range(len(targets_ohe) - self.frames)
64 | ]
65 |
66 | def get_dataset(self, frames: int = 16) -> StaticGraphTemporalSignal:
67 | """Returning the MTM-1 motion data iterator.
68 |
69 | Args types:
70 | * **frames** *(int)* - The number of consecutive frames T, default 16.
71 | Return types:
72 | * **dataset** *(StaticGraphTemporalSignal)* - The MTM-1 dataset.
73 | """
74 | self.frames = frames
75 | self._get_edges()
76 | self._get_edge_weights()
77 | self._get_features()
78 | self._get_targets()
79 |
80 | dataset = StaticGraphTemporalSignal(
81 | self._edges, self._edge_weights, self.features, self.targets
82 | )
83 | return dataset
84 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/pedalme.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ssl
3 | import urllib.request
4 | import numpy as np
5 | from ..signal import StaticGraphTemporalSignal
6 |
7 |
8 | class PedalMeDatasetLoader(object):
9 | """A dataset of PedalMe Bicycle deliver orders in London between 2020
10 | and 2021. We made it public during the development of PyTorch Geometric
11 | Temporal. The underlying graph is static - vertices are localities and
12 | edges are spatial_connections. Vertex features are lagged weekly counts of the
13 | delivery demands (we included 4 lags). The target is the weekly number of
14 | deliveries the upcoming week. Our dataset consist of more than 30 snapshots (weeks).
15 | """
16 |
17 | def __init__(self):
18 | self._read_web_data()
19 |
20 | def _read_web_data(self):
21 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/pedalme_london.json"
22 | context = ssl._create_unverified_context()
23 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read())
24 |
25 | def _get_edges(self):
26 | self._edges = np.array(self._dataset["edges"]).T
27 |
28 | def _get_edge_weights(self):
29 | self._edge_weights = np.array(self._dataset["weights"]).T
30 |
31 | def _get_targets_and_features(self):
32 | stacked_target = np.array(self._dataset["X"])
33 | self.features = [
34 | stacked_target[i : i + self.lags, :].T
35 | for i in range(stacked_target.shape[0] - self.lags)
36 | ]
37 | self.targets = [
38 | stacked_target[i + self.lags, :].T
39 | for i in range(stacked_target.shape[0] - self.lags)
40 | ]
41 |
42 | def get_dataset(self, lags: int = 4) -> StaticGraphTemporalSignal:
43 | """Returning the PedalMe London demand data iterator.
44 |
45 | Args types:
46 | * **lags** *(int)* - The number of time lags.
47 | Return types:
48 | * **dataset** *(StaticGraphTemporalSignal)* - The PedalMe dataset.
49 | """
50 | self.lags = lags
51 | self._get_edges()
52 | self._get_edge_weights()
53 | self._get_targets_and_features()
54 | dataset = StaticGraphTemporalSignal(
55 | self._edges, self._edge_weights, self.features, self.targets
56 | )
57 | return dataset
58 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/twitter_tennis.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ssl
3 | import urllib.request
4 | import numpy as np
5 | from ..signal import DynamicGraphTemporalSignal
6 |
7 |
8 | def transform_degree(x, cutoff=4):
9 | log_deg = np.ceil(np.log(x + 1.0))
10 | return np.minimum(log_deg, cutoff)
11 |
12 |
13 | def transform_transitivity(x):
14 | trans = x * 10
15 | return np.floor(trans)
16 |
17 |
18 | def onehot_encoding(x, unique_vals):
19 | E = np.zeros((len(x), len(unique_vals)))
20 | for i, val in enumerate(x):
21 | E[i, unique_vals.index(val)] = 1.0
22 | return E
23 |
24 |
25 | def encode_features(X, log_degree_cutoff=4):
26 | X_arr = np.array(X)
27 | a = transform_degree(X_arr[:, 0], log_degree_cutoff)
28 | b = transform_transitivity(X_arr[:, 1])
29 | A = onehot_encoding(a, range(log_degree_cutoff + 1))
30 | B = onehot_encoding(b, range(11))
31 | return np.concatenate((A, B), axis=1)
32 |
33 |
34 | class TwitterTennisDatasetLoader(object):
35 | """
36 | Twitter mention graphs related to major tennis tournaments from 2017.
37 | Nodes are Twitter accounts and edges are mentions between them.
38 | Each snapshot contains the graph induced by the most popular nodes
39 | of the original dataset. Node labels encode the number of mentions
40 | received in the original dataset for the next snapshot. Read more
41 | on the original Twitter data in the 'Temporal Walk Based Centrality Metric for Graph Streams' paper.
42 |
43 | Parameters
44 | ----------
45 | event_id : str
46 | Choose to load the mention network for Roland-Garros 2017 ("rg17") or USOpen 2017 ("uo17")
47 | N : int <= 1000
48 | Number of most popular nodes to load. By default N=1000. Each snapshot contains the graph induced by these nodes.
49 | feature_mode : str
50 | None : load raw degree and transitivity node features
51 | "encoded" : load onehot encoded degree and transitivity node features
52 | "diagonal" : set identity matrix as node features
53 | target_offset : int
54 | Set the snapshot offset for the node labels to be predicted. By default node labels for the next snapshot are predicted (target_offset=1).
55 | """
56 |
57 | def __init__(
58 | self, event_id="rg17", N=None, feature_mode="encoded", target_offset=1
59 | ):
60 | self.N = N
61 | self.target_offset = target_offset
62 | if event_id in ["rg17", "uo17"]:
63 | self.event_id = event_id
64 | else:
65 | raise ValueError(
66 | "Invalid 'event_id'! Choose 'rg17' or 'uo17' to load the Roland-Garros 2017 or the USOpen 2017 Twitter tennis dataset respectively."
67 | )
68 | if feature_mode in [None, "diagonal", "encoded"]:
69 | self.feature_mode = feature_mode
70 | else:
71 | raise ValueError(
72 | "Choose feature_mode from values [None, 'diagonal', 'encoded']."
73 | )
74 | self._read_web_data()
75 |
76 | def _read_web_data(self):
77 | fname = "twitter_tennis_%s.json" % self.event_id
78 | url = (
79 | "https://raw.githubusercontent.com/ferencberes/pytorch_geometric_temporal/developer/dataset/"
80 | + fname
81 | )
82 | context = ssl._create_unverified_context()
83 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read())
84 | # with open("/home/fberes/git/pytorch_geometric_temporal/dataset/"+fname) as f:
85 | # self._dataset = json.load(f)
86 |
87 | def _get_edges(self):
88 | edge_indices = []
89 | self.edges = []
90 | for time in range(self._dataset["time_periods"]):
91 | E = np.array(self._dataset[str(time)]["edges"])
92 | if self.N != None:
93 | selector = np.where((E[:, 0] < self.N) & (E[:, 1] < self.N))
94 | E = E[selector]
95 | edge_indices.append(selector)
96 | self.edges.append(E.T)
97 | self.edge_indices = edge_indices
98 |
99 | def _get_edge_weights(self):
100 | edge_indices = self.edge_indices
101 | self.edge_weights = []
102 | for i, time in enumerate(range(self._dataset["time_periods"])):
103 | W = np.array(self._dataset[str(time)]["weights"])
104 | if self.N != None:
105 | W = W[edge_indices[i]]
106 | self.edge_weights.append(W)
107 |
108 | def _get_features(self):
109 | self.features = []
110 | for time in range(self._dataset["time_periods"]):
111 | X = np.array(self._dataset[str(time)]["X"])
112 | if self.N != None:
113 | X = X[: self.N]
114 | if self.feature_mode == "diagonal":
115 | X = np.identity(X.shape[0])
116 | elif self.feature_mode == "encoded":
117 | X = encode_features(X)
118 | self.features.append(X)
119 |
120 | def _get_targets(self):
121 | self.targets = []
122 | T = self._dataset["time_periods"]
123 | for time in range(T):
124 | # predict node degrees in advance
125 | snapshot_id = min(time + self.target_offset, T - 1)
126 | y = np.array(self._dataset[str(snapshot_id)]["y"])
127 | # logarithmic transformation for node degrees
128 | y = np.log(1.0 + y)
129 | if self.N != None:
130 | y = y[: self.N]
131 | self.targets.append(y)
132 |
133 | def get_dataset(self) -> DynamicGraphTemporalSignal:
134 | """Returning the TennisDataset data iterator.
135 |
136 | Return types:
137 | * **dataset** *(DynamicGraphTemporalSignal)* - Selected Twitter tennis dataset (Roland-Garros 2017 or USOpen 2017).
138 | """
139 | self._get_edges()
140 | self._get_edge_weights()
141 | self._get_features()
142 | self._get_targets()
143 | dataset = DynamicGraphTemporalSignal(
144 | self.edges, self.edge_weights, self.features, self.targets
145 | )
146 | return dataset
147 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/wikimath.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ssl
3 | import urllib.request
4 | import numpy as np
5 | from ..signal import StaticGraphTemporalSignal
6 |
7 |
8 | class WikiMathsDatasetLoader(object):
9 | """A dataset of vital mathematics articles from Wikipedia. We made it
10 | public during the development of PyTorch Geometric Temporal. The
11 | underlying graph is static - vertices are Wikipedia pages and edges are
12 | links between them. The graph is directed and weighted. Weights represent
13 | the number of links found at the source Wikipedia page linking to the target
14 | Wikipedia page. The target is the daily user visits to the Wikipedia pages
15 | between March 16th 2019 and March 15th 2021 which results in 731 periods.
16 | """
17 |
18 | def __init__(self):
19 | self._read_web_data()
20 |
21 | def _read_web_data(self):
22 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/wikivital_mathematics.json"
23 | context = ssl._create_unverified_context()
24 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read())
25 |
26 | def _get_edges(self):
27 | self._edges = np.array(self._dataset["edges"]).T
28 |
29 | def _get_edge_weights(self):
30 | self._edge_weights = np.array(self._dataset["weights"]).T
31 |
32 | def _get_targets_and_features(self):
33 |
34 | targets = []
35 | for time in range(self._dataset["time_periods"]):
36 | targets.append(np.array(self._dataset[str(time)]["y"]))
37 | stacked_target = np.stack(targets)
38 | standardized_target = (
39 | stacked_target - np.mean(stacked_target, axis=0)
40 | ) / np.std(stacked_target, axis=0)
41 | self.features = [
42 | standardized_target[i : i + self.lags, :].T
43 | for i in range(len(targets) - self.lags)
44 | ]
45 | self.targets = [
46 | standardized_target[i + self.lags, :].T
47 | for i in range(len(targets) - self.lags)
48 | ]
49 |
50 | def get_dataset(self, lags: int = 8) -> StaticGraphTemporalSignal:
51 | """Returning the Wikipedia Vital Mathematics data iterator.
52 |
53 | Args types:
54 | * **lags** *(int)* - The number of time lags.
55 | Return types:
56 | * **dataset** *(StaticGraphTemporalSignal)* - The Wiki Maths dataset.
57 | """
58 | self.lags = lags
59 | self._get_edges()
60 | self._get_edge_weights()
61 | self._get_targets_and_features()
62 | dataset = StaticGraphTemporalSignal(
63 | self._edges, self._edge_weights, self.features, self.targets
64 | )
65 | return dataset
66 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/windmillmedium.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ssl
3 | import urllib.request
4 | import numpy as np
5 | from ..signal import StaticGraphTemporalSignal
6 |
7 |
8 | class WindmillOutputMediumDatasetLoader(object):
9 | """Hourly energy output of windmills from a European country
10 | for more than 2 years. Vertices represent 26 windmills and
11 | weighted edges describe the strength of relationships. The target
12 | variable allows for regression tasks.
13 | """
14 |
15 | def __init__(self):
16 | self._read_web_data()
17 |
18 | def _read_web_data(self):
19 | url = "https://graphmining.ai/temporal_datasets/windmill_output_medium.json"
20 | context = ssl._create_unverified_context()
21 | self._dataset = json.loads(
22 | urllib.request.urlopen(url, context=context).read().decode()
23 | )
24 |
25 | def _get_edges(self):
26 | self._edges = np.array(self._dataset["edges"]).T
27 |
28 | def _get_edge_weights(self):
29 | self._edge_weights = np.array(self._dataset["weights"]).T
30 |
31 | def _get_targets_and_features(self):
32 | stacked_target = np.stack(self._dataset["block"])
33 | standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / (
34 | np.std(stacked_target, axis=0) + 10 ** -10
35 | )
36 | self.features = [
37 | standardized_target[i : i + self.lags, :].T
38 | for i in range(standardized_target.shape[0] - self.lags)
39 | ]
40 | self.targets = [
41 | standardized_target[i + self.lags, :].T
42 | for i in range(standardized_target.shape[0] - self.lags)
43 | ]
44 |
45 | def get_dataset(self, lags: int = 8) -> StaticGraphTemporalSignal:
46 | """Returning the Windmill Output data iterator.
47 |
48 | Args types:
49 | * **lags** *(int)* - The number of time lags.
50 | Return types:
51 | * **dataset** *(StaticGraphTemporalSignal)* - The Windmill Output dataset.
52 | """
53 | self.lags = lags
54 | self._get_edges()
55 | self._get_edge_weights()
56 | self._get_targets_and_features()
57 | dataset = StaticGraphTemporalSignal(
58 | self._edges, self._edge_weights, self.features, self.targets
59 | )
60 | return dataset
61 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/windmillsmall.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ssl
3 | import urllib.request
4 | import numpy as np
5 | from ..signal import StaticGraphTemporalSignal
6 |
7 |
8 | class WindmillOutputSmallDatasetLoader(object):
9 | """Hourly energy output of windmills from a European country
10 | for more than 2 years. Vertices represent 11 windmills and
11 | weighted edges describe the strength of relationships. The target
12 | variable allows for regression tasks.
13 | """
14 |
15 | def __init__(self):
16 | self._read_web_data()
17 |
18 | def _read_web_data(self):
19 | url = "https://graphmining.ai/temporal_datasets/windmill_output_small.json"
20 | context = ssl._create_unverified_context()
21 | self._dataset = json.loads(
22 | urllib.request.urlopen(url, context=context).read().decode()
23 | )
24 |
25 | def _get_edges(self):
26 | self._edges = np.array(self._dataset["edges"]).T
27 |
28 | def _get_edge_weights(self):
29 | self._edge_weights = np.array(self._dataset["weights"]).T
30 |
31 | def _get_targets_and_features(self):
32 | stacked_target = np.stack(self._dataset["block"])
33 | standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / (
34 | np.std(stacked_target, axis=0) + 10 ** -10
35 | )
36 | self.features = [
37 | standardized_target[i : i + self.lags, :].T
38 | for i in range(standardized_target.shape[0] - self.lags)
39 | ]
40 | self.targets = [
41 | standardized_target[i + self.lags, :].T
42 | for i in range(standardized_target.shape[0] - self.lags)
43 | ]
44 |
45 | def get_dataset(self, lags: int = 8) -> StaticGraphTemporalSignal:
46 | """Returning the Windmill Output data iterator.
47 |
48 | Args types:
49 | * **lags** *(int)* - The number of time lags.
50 | Return types:
51 | * **dataset** *(StaticGraphTemporalSignal)* - The Windmill Output dataset.
52 | """
53 | self.lags = lags
54 | self._get_edges()
55 | self._get_edge_weights()
56 | self._get_targets_and_features()
57 | dataset = StaticGraphTemporalSignal(
58 | self._edges, self._edge_weights, self.features, self.targets
59 | )
60 | return dataset
61 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .recurrent import *
2 | from .attention import *
3 | from .hetero import *
4 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/nn/attention/__init__.py:
--------------------------------------------------------------------------------
1 | from .stgcn import STConv, TemporalConv
2 | from .astgcn import ASTGCN, ChebConvAttention
3 | from .mstgcn import MSTGCN
4 | from .gman import GMAN, SpatioTemporalEmbedding, SpatioTemporalAttention
5 | from .mtgnn import MTGNN, MixProp, GraphConstructor
6 | from .tsagcn import GraphAAGCN, AAGCN
7 | from .dnntsp import DNNTSP
8 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/nn/attention/stgcn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch_geometric.nn import ChebConv
6 |
7 |
8 | class TemporalConv(nn.Module):
9 | r"""Temporal convolution block applied to nodes in the STGCN Layer
10 | For details see: `"Spatio-Temporal Graph Convolutional Networks:
11 | A Deep Learning Framework for Traffic Forecasting."
12 | `_ Based off the temporal convolution
13 | introduced in "Convolutional Sequence to Sequence Learning" `_
14 |
15 | Args:
16 | in_channels (int): Number of input features.
17 | out_channels (int): Number of output features.
18 | kernel_size (int): Convolutional kernel size.
19 | """
20 |
21 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
22 | super(TemporalConv, self).__init__()
23 | self.conv_1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
24 | self.conv_2 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
25 | self.conv_3 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
26 |
27 | def forward(self, X: torch.FloatTensor) -> torch.FloatTensor:
28 | """Forward pass through temporal convolution block.
29 |
30 | Arg types:
31 | * **X** (torch.FloatTensor) - Input data of shape
32 | (batch_size, input_time_steps, num_nodes, in_channels).
33 |
34 | Return types:
35 | * **H** (torch.FloatTensor) - Output data of shape
36 | (batch_size, in_channels, num_nodes, input_time_steps).
37 | """
38 | X = X.permute(0, 3, 2, 1)
39 | P = self.conv_1(X)
40 | Q = torch.sigmoid(self.conv_2(X))
41 | PQ = P * Q
42 | H = F.relu(PQ + self.conv_3(X))
43 | H = H.permute(0, 3, 2, 1)
44 | return H
45 |
46 |
47 | class STConv(nn.Module):
48 | r"""Spatio-temporal convolution block using ChebConv Graph Convolutions.
49 | For details see: `"Spatio-Temporal Graph Convolutional Networks:
50 | A Deep Learning Framework for Traffic Forecasting"
51 | `_
52 |
53 | NB. The ST-Conv block contains two temporal convolutions (TemporalConv)
54 | with kernel size k. Hence for an input sequence of length m,
55 | the output sequence will be length m-2(k-1).
56 |
57 | Args:
58 | in_channels (int): Number of input features.
59 | hidden_channels (int): Number of hidden units output by graph convolution block
60 | out_channels (int): Number of output features.
61 | kernel_size (int): Size of the kernel considered.
62 | K (int): Chebyshev filter size :math:`K`.
63 | normalization (str, optional): The normalization scheme for the graph
64 | Laplacian (default: :obj:`"sym"`):
65 |
66 | 1. :obj:`None`: No normalization
67 | :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`
68 |
69 | 2. :obj:`"sym"`: Symmetric normalization
70 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
71 | \mathbf{D}^{-1/2}`
72 |
73 | 3. :obj:`"rw"`: Random-walk normalization
74 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
75 |
76 | You need to pass :obj:`lambda_max` to the :meth:`forward` method of
77 | this operator in case the normalization is non-symmetric.
78 | :obj:`\lambda_max` should be a :class:`torch.Tensor` of size
79 | :obj:`[num_graphs]` in a mini-batch scenario and a
80 | scalar/zero-dimensional tensor when operating on single graphs.
81 | You can pre-compute :obj:`lambda_max` via the
82 | :class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
83 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
84 | an additive bias. (default: :obj:`True`)
85 |
86 | """
87 |
88 | def __init__(
89 | self,
90 | num_nodes: int,
91 | in_channels: int,
92 | hidden_channels: int,
93 | out_channels: int,
94 | kernel_size: int,
95 | K: int,
96 | normalization: str = "sym",
97 | bias: bool = True,
98 | ):
99 | super(STConv, self).__init__()
100 | self.num_nodes = num_nodes
101 | self.in_channels = in_channels
102 | self.hidden_channels = hidden_channels
103 | self.out_channels = out_channels
104 | self.kernel_size = kernel_size
105 | self.K = K
106 | self.normalization = normalization
107 | self.bias = bias
108 |
109 | self._temporal_conv1 = TemporalConv(
110 | in_channels=in_channels,
111 | out_channels=hidden_channels,
112 | kernel_size=kernel_size,
113 | )
114 |
115 | self._graph_conv = ChebConv(
116 | in_channels=hidden_channels,
117 | out_channels=hidden_channels,
118 | K=K,
119 | normalization=normalization,
120 | bias=bias,
121 | )
122 |
123 | self._temporal_conv2 = TemporalConv(
124 | in_channels=hidden_channels,
125 | out_channels=out_channels,
126 | kernel_size=kernel_size,
127 | )
128 |
129 | self._batch_norm = nn.BatchNorm2d(num_nodes)
130 |
131 | def forward(
132 | self,
133 | X: torch.FloatTensor,
134 | edge_index: torch.LongTensor,
135 | edge_weight: torch.FloatTensor = None,
136 | ) -> torch.FloatTensor:
137 |
138 | r"""Forward pass. If edge weights are not present the forward pass
139 | defaults to an unweighted graph.
140 |
141 | Arg types:
142 | * **X** (PyTorch FloatTensor) - Sequence of node features of shape (Batch size X Input time steps X Num nodes X In channels).
143 | * **edge_index** (PyTorch LongTensor) - Graph edge indices.
144 | * **edge_weight** (PyTorch LongTensor, optional)- Edge weight vector.
145 |
146 | Return types:
147 | * **T** (PyTorch FloatTensor) - Sequence of node features.
148 | """
149 | T_0 = self._temporal_conv1(X)
150 | T = torch.zeros_like(T_0).to(T_0.device)
151 | for b in range(T_0.size(0)):
152 | for t in range(T_0.size(1)):
153 | T[b][t] = self._graph_conv(T_0[b][t], edge_index, edge_weight)
154 |
155 | T = F.relu(T)
156 | T = self._temporal_conv2(T)
157 | T = T.permute(0, 2, 1, 3)
158 | T = self._batch_norm(T)
159 | T = T.permute(0, 2, 1, 3)
160 | return T
161 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/nn/hetero/__init__.py:
--------------------------------------------------------------------------------
1 | from .heterogclstm import HeteroGCLSTM
2 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/nn/recurrent/__init__.py:
--------------------------------------------------------------------------------
1 | from .gconv_gru import GConvGRU
2 | from .gconv_lstm import GConvLSTM
3 | from .lrgcn import LRGCN
4 | from .gc_lstm import GCLSTM
5 | from .dygrae import DyGrEncoder
6 | from .evolvegcnh import EvolveGCNH
7 | from .evolvegcno import EvolveGCNO
8 | from .dcrnn import DCRNN, BatchedDCRNN
9 | from .temporalgcn import TGCN
10 | from .temporalgcn import TGCN2
11 | from .attentiontemporalgcn import A3TGCN
12 | from .attentiontemporalgcn import A3TGCN2
13 | from .mpnn_lstm import MPNNLSTM
14 | from .agcrn import AGCRN
15 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/nn/recurrent/agcrn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn.inits import glorot, zeros
5 |
6 |
7 | class AVWGCN(nn.Module):
8 | r"""An implementation of the Node Adaptive Graph Convolution Layer.
9 | For details see: `"Adaptive Graph Convolutional Recurrent Network
10 | for Traffic Forecasting" `_
11 | Args:
12 | in_channels (int): Number of input features.
13 | out_channels (int): Number of output features.
14 | K (int): Filter size :math:`K`.
15 | embedding_dimensions (int): Number of node embedding dimensions.
16 | """
17 |
18 | def __init__(
19 | self, in_channels: int, out_channels: int, K: int, embedding_dimensions: int
20 | ):
21 | super(AVWGCN, self).__init__()
22 | self.K = K
23 | self.weights_pool = torch.nn.Parameter(
24 | torch.Tensor(embedding_dimensions, K, in_channels, out_channels)
25 | )
26 | self.bias_pool = torch.nn.Parameter(
27 | torch.Tensor(embedding_dimensions, out_channels)
28 | )
29 | glorot(self.weights_pool)
30 | zeros(self.bias_pool)
31 |
32 | def forward(self, X: torch.FloatTensor, E: torch.FloatTensor) -> torch.FloatTensor:
33 | r"""Making a forward pass.
34 | Arg types:
35 | * **X** (PyTorch Float Tensor) - Node features.
36 | * **E** (PyTorch Float Tensor) - Node embeddings.
37 | Return types:
38 | * **X_G** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
39 | """
40 |
41 | number_of_nodes = E.shape[0]
42 | supports = F.softmax(F.relu(torch.mm(E, E.transpose(0, 1))), dim=1)
43 | support_set = [torch.eye(number_of_nodes).to(supports.device), supports]
44 | for _ in range(2, self.K):
45 | support = torch.matmul(2 * supports, support_set[-1]) - support_set[-2]
46 | support_set.append(support)
47 | supports = torch.stack(support_set, dim=0)
48 | W = torch.einsum("nd,dkio->nkio", E, self.weights_pool)
49 | bias = torch.matmul(E, self.bias_pool)
50 | X_G = torch.einsum("knm,bmc->bknc", supports, X)
51 | X_G = X_G.permute(0, 2, 1, 3)
52 | X_G = torch.einsum("bnki,nkio->bno", X_G, W) + bias
53 | return X_G
54 |
55 |
56 | class AGCRN(nn.Module):
57 | r"""An implementation of the Adaptive Graph Convolutional Recurrent Unit.
58 | For details see: `"Adaptive Graph Convolutional Recurrent Network
59 | for Traffic Forecasting" `_
60 | Args:
61 | number_of_nodes (int): Number of vertices.
62 | in_channels (int): Number of input features.
63 | out_channels (int): Number of output features.
64 | K (int): Filter size :math:`K`.
65 | embedding_dimensions (int): Number of node embedding dimensions.
66 | """
67 |
68 | def __init__(
69 | self,
70 | number_of_nodes: int,
71 | in_channels: int,
72 | out_channels: int,
73 | K: int,
74 | embedding_dimensions: int,
75 | ):
76 | super(AGCRN, self).__init__()
77 |
78 | self.number_of_nodes = number_of_nodes
79 | self.in_channels = in_channels
80 | self.out_channels = out_channels
81 | self.K = K
82 | self.embedding_dimensions = embedding_dimensions
83 | self._setup_layers()
84 |
85 | def _setup_layers(self):
86 | self._gate = AVWGCN(
87 | in_channels=self.in_channels + self.out_channels,
88 | out_channels=2 * self.out_channels,
89 | K=self.K,
90 | embedding_dimensions=self.embedding_dimensions,
91 | )
92 |
93 | self._update = AVWGCN(
94 | in_channels=self.in_channels + self.out_channels,
95 | out_channels=self.out_channels,
96 | K=self.K,
97 | embedding_dimensions=self.embedding_dimensions,
98 | )
99 |
100 | def _set_hidden_state(self, X, H):
101 | if H is None:
102 | H = torch.zeros(X.shape[0], X.shape[1], self.out_channels).to(X.device)
103 | return H
104 |
105 | def forward(
106 | self, X: torch.FloatTensor, E: torch.FloatTensor, H: torch.FloatTensor = None
107 | ) -> torch.FloatTensor:
108 | r"""Making a forward pass.
109 | Arg types:
110 | * **X** (PyTorch Float Tensor) - Node feature matrix.
111 | * **E** (PyTorch Float Tensor) - Node embedding matrix.
112 | * **H** (PyTorch Float Tensor) - Node hidden state matrix. Default is None.
113 | Return types:
114 | * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
115 | """
116 | H = self._set_hidden_state(X, H)
117 | X_H = torch.cat((X, H), dim=-1)
118 | Z_R = torch.sigmoid(self._gate(X_H, E))
119 | Z, R = torch.split(Z_R, self.out_channels, dim=-1)
120 | C = torch.cat((X, Z * H), dim=-1)
121 | HC = torch.tanh(self._update(C, E))
122 | H = R * H + (1 - R) * HC
123 | return H
124 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/nn/recurrent/attentiontemporalgcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .temporalgcn import TGCN
3 | from .temporalgcn import TGCN2
4 | from torch_geometric.nn import GCNConv
5 |
6 |
7 | class A3TGCN(torch.nn.Module):
8 | r"""An implementation of the Attention Temporal Graph Convolutional Cell.
9 | For details see this paper: `"A3T-GCN: Attention Temporal Graph Convolutional
10 | Network for Traffic Forecasting." `_
11 |
12 | Args:
13 | in_channels (int): Number of input features.
14 | out_channels (int): Number of output features.
15 | periods (int): Number of time periods.
16 | improved (bool): Stronger self loops (default :obj:`False`).
17 | cached (bool): Caching the message weights (default :obj:`False`).
18 | add_self_loops (bool): Adding self-loops for smoothing (default :obj:`True`).
19 | """
20 |
21 | def __init__(
22 | self,
23 | in_channels: int,
24 | out_channels: int,
25 | periods: int,
26 | improved: bool = False,
27 | cached: bool = False,
28 | add_self_loops: bool = True
29 | ):
30 | super(A3TGCN, self).__init__()
31 |
32 | self.in_channels = in_channels
33 | self.out_channels = out_channels
34 | self.periods = periods
35 | self.improved = improved
36 | self.cached = cached
37 | self.add_self_loops = add_self_loops
38 | self._setup_layers()
39 |
40 | def _setup_layers(self):
41 | self._base_tgcn = TGCN(
42 | in_channels=self.in_channels,
43 | out_channels=self.out_channels,
44 | improved=self.improved,
45 | cached=self.cached,
46 | add_self_loops=self.add_self_loops,
47 | )
48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
49 | self._attention = torch.nn.Parameter(torch.empty(self.periods, device=device))
50 | torch.nn.init.uniform_(self._attention)
51 |
52 | def forward(
53 | self,
54 | X: torch.FloatTensor,
55 | edge_index: torch.LongTensor,
56 | edge_weight: torch.FloatTensor = None,
57 | H: torch.FloatTensor = None,
58 | ) -> torch.FloatTensor:
59 | """
60 | Making a forward pass. If edge weights are not present the forward pass
61 | defaults to an unweighted graph. If the hidden state matrix is not present
62 | when the forward pass is called it is initialized with zeros.
63 |
64 | Arg types:
65 | * **X** (PyTorch Float Tensor): Node features for T time periods.
66 | * **edge_index** (PyTorch Long Tensor): Graph edge indices.
67 | * **edge_weight** (PyTorch Long Tensor, optional)*: Edge weight vector.
68 | * **H** (PyTorch Float Tensor, optional): Hidden state matrix for all nodes.
69 |
70 | Return types:
71 | * **H** (PyTorch Float Tensor): Hidden state matrix for all nodes.
72 | """
73 | H_accum = 0
74 | probs = torch.nn.functional.softmax(self._attention, dim=0)
75 | for period in range(self.periods):
76 | H_accum = H_accum + probs[period] * self._base_tgcn(
77 | X[:, :, period], edge_index, edge_weight, H
78 | )
79 | return H_accum
80 |
81 |
82 |
83 | class A3TGCN2(torch.nn.Module):
84 | r"""An implementation THAT SUPPORTS BATCHES of the Attention Temporal Graph Convolutional Cell.
85 | For details see this paper: `"A3T-GCN: Attention Temporal Graph Convolutional
86 | Network for Traffic Forecasting." `_
87 |
88 | Args:
89 | in_channels (int): Number of input features.
90 | out_channels (int): Number of output features.
91 | periods (int): Number of time periods.
92 | improved (bool): Stronger self loops (default :obj:`False`).
93 | cached (bool): Caching the message weights (default :obj:`False`).
94 | add_self_loops (bool): Adding self-loops for smoothing (default :obj:`True`).
95 | """
96 |
97 | def __init__(
98 | self,
99 | in_channels: int,
100 | out_channels: int,
101 | periods: int,
102 | batch_size:int,
103 | improved: bool = False,
104 | cached: bool = False,
105 | add_self_loops: bool = True):
106 | super(A3TGCN2, self).__init__()
107 |
108 | self.in_channels = in_channels # 2
109 | self.out_channels = out_channels # 32
110 | self.periods = periods # 12
111 | self.improved = improved
112 | self.cached = cached
113 | self.add_self_loops = add_self_loops
114 | self.batch_size = batch_size
115 | self._setup_layers()
116 |
117 | def _setup_layers(self):
118 | self._base_tgcn = TGCN2(
119 | in_channels=self.in_channels,
120 | out_channels=self.out_channels,
121 | batch_size=self.batch_size,
122 | improved=self.improved,
123 | cached=self.cached,
124 | add_self_loops=self.add_self_loops)
125 |
126 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
127 | self._attention = torch.nn.Parameter(torch.empty(self.periods, device=device))
128 | torch.nn.init.uniform_(self._attention)
129 |
130 | def forward(
131 | self,
132 | X: torch.FloatTensor,
133 | edge_index: torch.LongTensor,
134 | edge_weight: torch.FloatTensor = None,
135 | H: torch.FloatTensor = None
136 | ) -> torch.FloatTensor:
137 | """
138 | Making a forward pass. If edge weights are not present the forward pass
139 | defaults to an unweighted graph. If the hidden state matrix is not present
140 | when the forward pass is called it is initialized with zeros.
141 |
142 | Arg types:
143 | * **X** (PyTorch Float Tensor): Node features for T time periods.
144 | * **edge_index** (PyTorch Long Tensor): Graph edge indices.
145 | * **edge_weight** (PyTorch Long Tensor, optional)*: Edge weight vector.
146 | * **H** (PyTorch Float Tensor, optional): Hidden state matrix for all nodes.
147 |
148 | Return types:
149 | * **H** (PyTorch Float Tensor): Hidden state matrix for all nodes.
150 | """
151 | H_accum = 0
152 | probs = torch.nn.functional.softmax(self._attention, dim=0)
153 | for period in range(self.periods):
154 |
155 | H_accum = H_accum + probs[period] * self._base_tgcn( X[:, :, :, period], edge_index, edge_weight, H) #([32, 207, 32]
156 |
157 | return H_accum
--------------------------------------------------------------------------------
/torch_geometric_temporal/nn/recurrent/dygrae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import LSTM
3 | from torch_geometric.nn import GatedGraphConv
4 |
5 |
6 | class DyGrEncoder(torch.nn.Module):
7 | r"""An implementation of the integrated Gated Graph Convolution Long Short
8 | Term Memory Layer. For details see this paper: `"Predictive Temporal Embedding
9 | of Dynamic Graphs." `_
10 |
11 | Args:
12 | conv_out_channels (int): Number of output channels for the GGCN.
13 | conv_num_layers (int): Number of Gated Graph Convolutions.
14 | conv_aggr (str): Aggregation scheme to use
15 | (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
16 | lstm_out_channels (int): Number of LSTM channels.
17 | lstm_num_layers (int): Number of neurons in LSTM.
18 | """
19 |
20 | def __init__(
21 | self,
22 | conv_out_channels: int,
23 | conv_num_layers: int,
24 | conv_aggr: str,
25 | lstm_out_channels: int,
26 | lstm_num_layers: int,
27 | ):
28 | super(DyGrEncoder, self).__init__()
29 | assert conv_aggr in ["mean", "add", "max"], "Wrong aggregator."
30 | self.conv_out_channels = conv_out_channels
31 | self.conv_num_layers = conv_num_layers
32 | self.conv_aggr = conv_aggr
33 | self.lstm_out_channels = lstm_out_channels
34 | self.lstm_num_layers = lstm_num_layers
35 | self._create_layers()
36 |
37 | def _create_layers(self):
38 | self.conv_layer = GatedGraphConv(
39 | out_channels=self.conv_out_channels,
40 | num_layers=self.conv_num_layers,
41 | aggr=self.conv_aggr,
42 | bias=True,
43 | )
44 |
45 | self.recurrent_layer = LSTM(
46 | input_size=self.conv_out_channels,
47 | hidden_size=self.lstm_out_channels,
48 | num_layers=self.lstm_num_layers,
49 | )
50 |
51 | def forward(
52 | self,
53 | X: torch.FloatTensor,
54 | edge_index: torch.LongTensor,
55 | edge_weight: torch.FloatTensor = None,
56 | H: torch.FloatTensor = None,
57 | C: torch.FloatTensor = None,
58 | ) -> torch.FloatTensor:
59 | """
60 | Making a forward pass. If the hidden state and cell state matrices are
61 | not present when the forward pass is called these are initialized with zeros.
62 |
63 | Arg types:
64 | * **X** *(PyTorch Float Tensor)* - Node features.
65 | * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
66 | * **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector.
67 | * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.
68 | * **C** *(PyTorch Float Tensor, optional)* - Cell state matrix for all nodes.
69 |
70 | Return types:
71 | * **H_tilde** *(PyTorch Float Tensor)* - Output matrix for all nodes.
72 | * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
73 | * **C** *(PyTorch Float Tensor)* - Cell state matrix for all nodes.
74 | """
75 | H_tilde = self.conv_layer(X, edge_index, edge_weight)
76 | H_tilde = H_tilde[None, :, :]
77 | if H is None and C is None:
78 | H_tilde, (H, C) = self.recurrent_layer(H_tilde)
79 | elif H is not None and C is not None:
80 | H = H[None, :, :]
81 | C = C[None, :, :]
82 | H_tilde, (H, C) = self.recurrent_layer(H_tilde, (H, C))
83 | else:
84 | raise ValueError("Invalid hidden state and cell matrices.")
85 | H_tilde = H_tilde.squeeze()
86 | H = H.squeeze()
87 | C = C.squeeze()
88 | return H_tilde, H, C
89 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/nn/recurrent/evolvegcnh.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import GRU
3 | from torch_geometric.nn import TopKPooling
4 |
5 | from .evolvegcno import glorot, GCNConv_Fixed_W
6 |
7 |
8 | class EvolveGCNH(torch.nn.Module):
9 | r"""An implementation of the Evolving Graph Convolutional Hidden Layer.
10 | For details see this paper: `"EvolveGCN: Evolving Graph Convolutional
11 | Networks for Dynamic Graph." `_
12 |
13 | Args:
14 | num_of_nodes (int): Number of vertices.
15 | in_channels (int): Number of filters.
16 | improved (bool, optional): If set to :obj:`True`, the layer computes
17 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
18 | (default: :obj:`False`)
19 | cached (bool, optional): If set to :obj:`True`, the layer will cache
20 | the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
21 | \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the
22 | cached version for further executions.
23 | This parameter should only be set to :obj:`True` in transductive
24 | learning scenarios. (default: :obj:`False`)
25 | normalize (bool, optional): Whether to add self-loops and apply
26 | symmetric normalization. (default: :obj:`True`)
27 | add_self_loops (bool, optional): If set to :obj:`False`, will not add
28 | self-loops to the input graph. (default: :obj:`True`)
29 | """
30 |
31 | def __init__(
32 | self,
33 | num_of_nodes: int,
34 | in_channels: int,
35 | improved: bool = False,
36 | cached: bool = False,
37 | normalize: bool = True,
38 | add_self_loops: bool = True,
39 | ):
40 | super(EvolveGCNH, self).__init__()
41 |
42 | self.num_of_nodes = num_of_nodes
43 | self.in_channels = in_channels
44 | self.improved = improved
45 | self.cached = cached
46 | self.normalize = normalize
47 | self.add_self_loops = add_self_loops
48 | self.weight = None
49 | self.initial_weight = torch.nn.Parameter(torch.Tensor(1, in_channels, in_channels))
50 | self._create_layers()
51 | self.reset_parameters()
52 |
53 | def reset_parameters(self):
54 | glorot(self.initial_weight)
55 |
56 | def reinitialize_weight(self):
57 | self.weight = None
58 |
59 | def _create_layers(self):
60 |
61 | self.ratio = self.in_channels / self.num_of_nodes
62 |
63 | self.pooling_layer = TopKPooling(self.in_channels, self.ratio)
64 |
65 | self.recurrent_layer = GRU(
66 | input_size=self.in_channels, hidden_size=self.in_channels, num_layers=1
67 | )
68 |
69 | self.conv_layer = GCNConv_Fixed_W(
70 | in_channels=self.in_channels,
71 | out_channels=self.in_channels,
72 | improved=self.improved,
73 | cached=self.cached,
74 | normalize=self.normalize,
75 | add_self_loops=self.add_self_loops
76 | )
77 |
78 | def forward(
79 | self,
80 | X: torch.FloatTensor,
81 | edge_index: torch.LongTensor,
82 | edge_weight: torch.FloatTensor = None,
83 | ) -> torch.FloatTensor:
84 | """
85 | Making a forward pass.
86 |
87 | Arg types:
88 | * **X** *(PyTorch Float Tensor)* - Node embedding.
89 | * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
90 | * **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector.
91 |
92 | Return types:
93 | * **X** *(PyTorch Float Tensor)* - Output matrix for all nodes.
94 | """
95 | X_tilde = self.pooling_layer(X, edge_index)
96 | X_tilde = X_tilde[0][None, :, :]
97 | if self.weight is None:
98 | _, self.weight = self.recurrent_layer(X_tilde, self.initial_weight)
99 | else:
100 | _, self.weight = self.recurrent_layer(X_tilde, self.weight)
101 | X = self.conv_layer(self.weight.squeeze(dim=0), X, edge_index, edge_weight)
102 | return X
103 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/nn/recurrent/lrgcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from torch_geometric.nn import RGCNConv
4 | from torch_geometric.nn.inits import glorot, zeros
5 |
6 |
7 | class LRGCN(torch.nn.Module):
8 | r"""An implementation of the Long Short Term Memory Relational
9 | Graph Convolution Layer. For details see this paper: `"Predicting Path
10 | Failure In Time-Evolving Graphs." `_
11 |
12 | Args:
13 | in_channels (int): Number of input features.
14 | out_channels (int): Number of output features.
15 | num_relations (int): Number of relations.
16 | num_bases (int): Number of bases.
17 | """
18 |
19 | def __init__(
20 | self, in_channels: int, out_channels: int, num_relations: int, num_bases: int
21 | ):
22 | super(LRGCN, self).__init__()
23 |
24 | self.in_channels = in_channels
25 | self.out_channels = out_channels
26 | self.num_relations = num_relations
27 | self.num_bases = num_bases
28 | self._create_layers()
29 |
30 | def _create_input_gate_layers(self):
31 |
32 | self.conv_x_i = RGCNConv(
33 | in_channels=self.in_channels,
34 | out_channels=self.out_channels,
35 | num_relations=self.num_relations,
36 | num_bases=self.num_bases,
37 | )
38 |
39 | self.conv_h_i = RGCNConv(
40 | in_channels=self.out_channels,
41 | out_channels=self.out_channels,
42 | num_relations=self.num_relations,
43 | num_bases=self.num_bases,
44 | )
45 |
46 | def _create_forget_gate_layers(self):
47 |
48 | self.conv_x_f = RGCNConv(
49 | in_channels=self.in_channels,
50 | out_channels=self.out_channels,
51 | num_relations=self.num_relations,
52 | num_bases=self.num_bases,
53 | )
54 |
55 | self.conv_h_f = RGCNConv(
56 | in_channels=self.out_channels,
57 | out_channels=self.out_channels,
58 | num_relations=self.num_relations,
59 | num_bases=self.num_bases,
60 | )
61 |
62 | def _create_cell_state_layers(self):
63 |
64 | self.conv_x_c = RGCNConv(
65 | in_channels=self.in_channels,
66 | out_channels=self.out_channels,
67 | num_relations=self.num_relations,
68 | num_bases=self.num_bases,
69 | )
70 |
71 | self.conv_h_c = RGCNConv(
72 | in_channels=self.out_channels,
73 | out_channels=self.out_channels,
74 | num_relations=self.num_relations,
75 | num_bases=self.num_bases,
76 | )
77 |
78 | def _create_output_gate_layers(self):
79 |
80 | self.conv_x_o = RGCNConv(
81 | in_channels=self.in_channels,
82 | out_channels=self.out_channels,
83 | num_relations=self.num_relations,
84 | num_bases=self.num_bases,
85 | )
86 |
87 | self.conv_h_o = RGCNConv(
88 | in_channels=self.out_channels,
89 | out_channels=self.out_channels,
90 | num_relations=self.num_relations,
91 | num_bases=self.num_bases,
92 | )
93 |
94 | def _create_layers(self):
95 | self._create_input_gate_layers()
96 | self._create_forget_gate_layers()
97 | self._create_cell_state_layers()
98 | self._create_output_gate_layers()
99 |
100 | def _set_hidden_state(self, X, H):
101 | if H is None:
102 | H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
103 | return H
104 |
105 | def _set_cell_state(self, X, C):
106 | if C is None:
107 | C = torch.zeros(X.shape[0], self.out_channels).to(X.device)
108 | return C
109 |
110 | def _calculate_input_gate(self, X, edge_index, edge_type, H, C):
111 | I = self.conv_x_i(X, edge_index, edge_type)
112 | I = I + self.conv_h_i(H, edge_index, edge_type)
113 | I = torch.sigmoid(I)
114 | return I
115 |
116 | def _calculate_forget_gate(self, X, edge_index, edge_type, H, C):
117 | F = self.conv_x_f(X, edge_index, edge_type)
118 | F = F + self.conv_h_f(H, edge_index, edge_type)
119 | F = torch.sigmoid(F)
120 | return F
121 |
122 | def _calculate_cell_state(self, X, edge_index, edge_type, H, C, I, F):
123 | T = self.conv_x_c(X, edge_index, edge_type)
124 | T = T + self.conv_h_c(H, edge_index, edge_type)
125 | T = torch.tanh(T)
126 | C = F * C + I * T
127 | return C
128 |
129 | def _calculate_output_gate(self, X, edge_index, edge_type, H, C):
130 | O = self.conv_x_o(X, edge_index, edge_type)
131 | O = O + self.conv_h_o(H, edge_index, edge_type)
132 | O = torch.sigmoid(O)
133 | return O
134 |
135 | def _calculate_hidden_state(self, O, C):
136 | H = O * torch.tanh(C)
137 | return H
138 |
139 | def forward(
140 | self,
141 | X: torch.FloatTensor,
142 | edge_index: torch.LongTensor,
143 | edge_type: torch.LongTensor,
144 | H: torch.FloatTensor = None,
145 | C: torch.FloatTensor = None,
146 | ) -> torch.FloatTensor:
147 | """
148 | Making a forward pass. If the hidden state and cell state matrices are
149 | not present when the forward pass is called these are initialized with zeros.
150 |
151 | Arg types:
152 | * **X** *(PyTorch Float Tensor)* - Node features.
153 | * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
154 | * **edge_type** *(PyTorch Long Tensor)* - Edge type vector.
155 | * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.
156 | * **C** *(PyTorch Float Tensor, optional)* - Cell state matrix for all nodes.
157 |
158 | Return types:
159 | * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
160 | * **C** *(PyTorch Float Tensor)* - Cell state matrix for all nodes.
161 | """
162 | H = self._set_hidden_state(X, H)
163 | C = self._set_cell_state(X, C)
164 | I = self._calculate_input_gate(X, edge_index, edge_type, H, C)
165 | F = self._calculate_forget_gate(X, edge_index, edge_type, H, C)
166 | C = self._calculate_cell_state(X, edge_index, edge_type, H, C, I, F)
167 | O = self._calculate_output_gate(X, edge_index, edge_type, H, C)
168 | H = self._calculate_hidden_state(O, C)
169 | return H, C
170 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/nn/recurrent/mpnn_lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import GCNConv
5 |
6 |
7 | class MPNNLSTM(nn.Module):
8 | r"""An implementation of the Message Passing Neural Network with Long Short Term Memory.
9 | For details see this paper: `"Transfer Graph Neural Networks for Pandemic Forecasting." `_
10 |
11 | Args:
12 | in_channels (int): Number of input features.
13 | hidden_size (int): Dimension of hidden representations.
14 | num_nodes (int): Number of nodes in the network.
15 | window (int): Number of past samples included in the input.
16 | dropout (float): Dropout rate.
17 | """
18 |
19 | def __init__(
20 | self,
21 | in_channels: int,
22 | hidden_size: int,
23 | num_nodes: int,
24 | window: int,
25 | dropout: float,
26 | ):
27 | super(MPNNLSTM, self).__init__()
28 |
29 | self.window = window
30 | self.num_nodes = num_nodes
31 | self.hidden_size = hidden_size
32 | self.dropout = dropout
33 | self.in_channels = in_channels
34 |
35 | self._create_parameters_and_layers()
36 |
37 | def _create_parameters_and_layers(self):
38 |
39 | self._convolution_1 = GCNConv(self.in_channels, self.hidden_size)
40 | self._convolution_2 = GCNConv(self.hidden_size, self.hidden_size)
41 |
42 | self._batch_norm_1 = nn.BatchNorm1d(self.hidden_size)
43 | self._batch_norm_2 = nn.BatchNorm1d(self.hidden_size)
44 |
45 | self._recurrent_1 = nn.LSTM(2 * self.hidden_size, self.hidden_size, 1)
46 | self._recurrent_2 = nn.LSTM(self.hidden_size, self.hidden_size, 1)
47 |
48 | def _graph_convolution_1(self, X, edge_index, edge_weight):
49 | X = F.relu(self._convolution_1(X, edge_index, edge_weight))
50 | X = self._batch_norm_1(X)
51 | X = F.dropout(X, p=self.dropout, training=self.training)
52 | return X
53 |
54 | def _graph_convolution_2(self, X, edge_index, edge_weight):
55 | X = F.relu(self._convolution_2(X, edge_index, edge_weight))
56 | X = self._batch_norm_2(X)
57 | X = F.dropout(X, p=self.dropout, training=self.training)
58 | return X
59 |
60 | def forward(
61 | self,
62 | X: torch.FloatTensor,
63 | edge_index: torch.LongTensor,
64 | edge_weight: torch.FloatTensor,
65 | ) -> torch.FloatTensor:
66 | """
67 | Making a forward pass through the whole architecture.
68 |
69 | Arg types:
70 | * **X** *(PyTorch FloatTensor)* - Node features.
71 | * **edge_index** *(PyTorch LongTensor)* - Graph edge indices.
72 | * **edge_weight** *(PyTorch LongTensor, optional)* - Edge weight vector.
73 |
74 | Return types:
75 | * **H** *(PyTorch FloatTensor)* - The hidden representation of size 2*nhid+in_channels+window-1 for each node.
76 | """
77 | R = list()
78 |
79 | S = X.view(-1, self.window, self.num_nodes, self.in_channels)
80 | S = torch.transpose(S, 1, 2)
81 | S = S.reshape(-1, self.window, self.in_channels)
82 | O = [S[:, 0, :]]
83 |
84 | for l in range(1, self.window):
85 | O.append(S[:, l, self.in_channels - 1].unsqueeze(1))
86 |
87 | S = torch.cat(O, dim=1)
88 |
89 | X = self._graph_convolution_1(X, edge_index, edge_weight)
90 | R.append(X)
91 |
92 | X = self._graph_convolution_2(X, edge_index, edge_weight)
93 | R.append(X)
94 |
95 | X = torch.cat(R, dim=1)
96 |
97 | X = X.view(-1, self.window, self.num_nodes, X.size(1))
98 | X = torch.transpose(X, 0, 1)
99 | X = X.contiguous().view(self.window, -1, X.size(3))
100 |
101 | X, (H_1, _) = self._recurrent_1(X)
102 | X, (H_2, _) = self._recurrent_2(X)
103 |
104 | H = torch.cat([H_1[0, :, :], H_2[0, :, :], S], dim=1)
105 | return H
106 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/signal/__init__.py:
--------------------------------------------------------------------------------
1 | from .dynamic_graph_temporal_signal import *
2 | from .dynamic_graph_temporal_signal_batch import *
3 |
4 | from .static_graph_temporal_signal import *
5 | from .static_graph_temporal_signal_batch import *
6 |
7 | from .dynamic_graph_static_signal import *
8 | from .dynamic_graph_static_signal_batch import *
9 |
10 | from .dynamic_hetero_graph_temporal_signal import *
11 | from .dynamic_hetero_graph_temporal_signal_batch import *
12 |
13 | from .static_hetero_graph_temporal_signal import *
14 | from .static_hetero_graph_temporal_signal_batch import *
15 |
16 | from .dynamic_hetero_graph_static_signal import *
17 | from .dynamic_hetero_graph_static_signal_batch import *
18 |
19 | from .train_test_split import *
20 |
21 | # from .index_dataset import *
22 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/signal/dynamic_graph_static_signal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Sequence, Union
4 | from torch_geometric.data import Data
5 |
6 |
7 | Edge_Indices = Sequence[Union[np.ndarray, None]]
8 | Edge_Weights = Sequence[Union[np.ndarray, None]]
9 | Node_Feature = Union[np.ndarray, None]
10 | Targets = Sequence[Union[np.ndarray, None]]
11 | Additional_Features = Sequence[np.ndarray]
12 |
13 |
14 | class DynamicGraphStaticSignal(object):
15 | r"""A data iterator object to contain a dynamic graph with a
16 | changing edge set and weights . The node labels
17 | (target) are also dynamic. The iterator returns a single discrete temporal
18 | snapshot for a time period (e.g. day or week). This single snapshot is a
19 | Pytorch Geometric Data object. Between two temporal snapshots the edges,
20 | edge weights, target matrices and optionally passed attributes might change.
21 |
22 | Args:
23 | edge_indices (Sequence of Numpy arrays): Sequence of edge index tensors.
24 | edge_weights (Sequence of Numpy arrays): Sequence of edge weight tensors.
25 | feature (Numpy array): Node feature tensor.
26 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors.
27 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes.
28 | """
29 |
30 | def __init__(
31 | self,
32 | edge_indices: Edge_Indices,
33 | edge_weights: Edge_Weights,
34 | feature: Node_Feature,
35 | targets: Targets,
36 | **kwargs: Additional_Features
37 | ):
38 | self.edge_indices = edge_indices
39 | self.edge_weights = edge_weights
40 | self.feature = feature
41 | self.targets = targets
42 | self.additional_feature_keys = []
43 | for key, value in kwargs.items():
44 | setattr(self, key, value)
45 | self.additional_feature_keys.append(key)
46 | self._check_temporal_consistency()
47 | self._set_snapshot_count()
48 |
49 | def _check_temporal_consistency(self):
50 | assert len(self.edge_indices) == len(
51 | self.edge_weights
52 | ), "Temporal dimension inconsistency."
53 | assert len(self.targets) == len(
54 | self.edge_indices
55 | ), "Temporal dimension inconsistency."
56 | for key in self.additional_feature_keys:
57 | assert len(self.targets) == len(
58 | getattr(self, key)
59 | ), "Temporal dimension inconsistency."
60 |
61 | def _set_snapshot_count(self):
62 | self.snapshot_count = len(self.targets)
63 |
64 | def _get_edge_index(self, time_index: int):
65 | if self.edge_indices[time_index] is None:
66 | return self.edge_indices[time_index]
67 | else:
68 | return torch.LongTensor(self.edge_indices[time_index])
69 |
70 | def _get_edge_weight(self, time_index: int):
71 | if self.edge_weights[time_index] is None:
72 | return self.edge_weights[time_index]
73 | else:
74 | return torch.FloatTensor(self.edge_weights[time_index])
75 |
76 | def _get_feature(self):
77 | if self.feature is None:
78 | return self.feature
79 | else:
80 | return torch.FloatTensor(self.feature)
81 |
82 | def _get_target(self, time_index: int):
83 | if self.targets[time_index] is None:
84 | return self.targets[time_index]
85 | else:
86 | if self.targets[time_index].dtype.kind == "i":
87 | return torch.LongTensor(self.targets[time_index])
88 | elif self.targets[time_index].dtype.kind == "f":
89 | return torch.FloatTensor(self.targets[time_index])
90 |
91 | def _get_additional_feature(self, time_index: int, feature_key: str):
92 | feature = getattr(self, feature_key)[time_index]
93 | if feature.dtype.kind == "i":
94 | return torch.LongTensor(feature)
95 | elif feature.dtype.kind == "f":
96 | return torch.FloatTensor(feature)
97 |
98 | def _get_additional_features(self, time_index: int):
99 | additional_features = {
100 | key: self._get_additional_feature(time_index, key)
101 | for key in self.additional_feature_keys
102 | }
103 | return additional_features
104 |
105 | def __len__(self):
106 | return len(self.targets)
107 |
108 | def __getitem__(self, time_index: Union[int, slice]):
109 | if isinstance(time_index, slice):
110 | snapshot = DynamicGraphStaticSignal(
111 | self.edge_indices[time_index],
112 | self.edge_weights[time_index],
113 | self.feature,
114 | self.targets[time_index],
115 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys}
116 | )
117 | else:
118 | x = self._get_feature()
119 | edge_index = self._get_edge_index(time_index)
120 | edge_weight = self._get_edge_weight(time_index)
121 | y = self._get_target(time_index)
122 | additional_features = self._get_additional_features(time_index)
123 |
124 | snapshot = Data(x=x, edge_index=edge_index, edge_attr=edge_weight,
125 | y=y, **additional_features)
126 | return snapshot
127 |
128 | def __next__(self):
129 | if self.t < len(self.targets):
130 | snapshot = self[self.t]
131 | self.t = self.t + 1
132 | return snapshot
133 | else:
134 | self.t = 0
135 | raise StopIteration
136 |
137 | def __iter__(self):
138 | self.t = 0
139 | return self
140 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/signal/dynamic_graph_static_signal_batch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Sequence, Union
4 | from torch_geometric.data import Batch
5 |
6 |
7 | Edge_Indices = Sequence[Union[np.ndarray, None]]
8 | Edge_Weights = Sequence[Union[np.ndarray, None]]
9 | Node_Feature = Union[np.ndarray, None]
10 | Targets = Sequence[Union[np.ndarray, None]]
11 | Batches = Sequence[Union[np.ndarray, None]]
12 | Additional_Features = Sequence[np.ndarray]
13 |
14 |
15 | class DynamicGraphStaticSignalBatch(object):
16 | r"""A batch iterator object to contain a dynamic graph with a
17 | changing edge set and weights . The node labels
18 | (target) are also dynamic. The iterator returns a single discrete temporal
19 | snapshot for a time period (e.g. day or week). This single snapshot is a
20 | Pytorch Geometric Batch object. Between two temporal snapshots the edges,
21 | batch memberships, edge weights, target matrices and optionally passed
22 | attributes might change.
23 |
24 | Args:
25 | edge_indices (Sequence of Numpy arrays): Sequence of edge index tensors.
26 | edge_weights (Sequence of Numpy arrays): Sequence of edge weight tensors.
27 | feature (Numpy array): Node feature tensor.
28 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors.
29 | batches (Sequence of Numpy arrays): Sequence of batch index tensors.
30 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes.
31 | """
32 |
33 | def __init__(
34 | self,
35 | edge_indices: Edge_Indices,
36 | edge_weights: Edge_Weights,
37 | feature: Node_Feature,
38 | targets: Targets,
39 | batches: Batches,
40 | **kwargs: Additional_Features
41 | ):
42 | self.edge_indices = edge_indices
43 | self.edge_weights = edge_weights
44 | self.feature = feature
45 | self.targets = targets
46 | self.batches = batches
47 | self.additional_feature_keys = []
48 | for key, value in kwargs.items():
49 | setattr(self, key, value)
50 | self.additional_feature_keys.append(key)
51 | self._check_temporal_consistency()
52 | self._set_snapshot_count()
53 |
54 | def _check_temporal_consistency(self):
55 | assert len(self.edge_indices) == len(
56 | self.edge_weights
57 | ), "Temporal dimension inconsistency."
58 | assert len(self.targets) == len(
59 | self.edge_indices
60 | ), "Temporal dimension inconsistency."
61 | assert len(self.batches) == len(
62 | self.edge_indices
63 | ), "Temporal dimension inconsistency."
64 | for key in self.additional_feature_keys:
65 | assert len(self.targets) == len(
66 | getattr(self, key)
67 | ), "Temporal dimension inconsistency."
68 |
69 | def _set_snapshot_count(self):
70 | self.snapshot_count = len(self.targets)
71 |
72 | def _get_edge_index(self, time_index: int):
73 | if self.edge_indices[time_index] is None:
74 | return self.edge_indices[time_index]
75 | else:
76 | return torch.LongTensor(self.edge_indices[time_index])
77 |
78 | def _get_batch_index(self, time_index: int):
79 | if self.batches[time_index] is None:
80 | return self.batches[time_index]
81 | else:
82 | return torch.LongTensor(self.batches[time_index])
83 |
84 | def _get_edge_weight(self, time_index: int):
85 | if self.edge_weights[time_index] is None:
86 | return self.edge_weights[time_index]
87 | else:
88 | return torch.FloatTensor(self.edge_weights[time_index])
89 |
90 | def _get_feature(self):
91 | if self.feature is None:
92 | return self.feature
93 | else:
94 | return torch.FloatTensor(self.feature)
95 |
96 | def _get_target(self, time_index: int):
97 | if self.targets[time_index] is None:
98 | return self.targets[time_index]
99 | else:
100 | if self.targets[time_index].dtype.kind == "i":
101 | return torch.LongTensor(self.targets[time_index])
102 | elif self.targets[time_index].dtype.kind == "f":
103 | return torch.FloatTensor(self.targets[time_index])
104 |
105 | def _get_additional_feature(self, time_index: int, feature_key: str):
106 | feature = getattr(self, feature_key)[time_index]
107 | if feature.dtype.kind == "i":
108 | return torch.LongTensor(feature)
109 | elif feature.dtype.kind == "f":
110 | return torch.FloatTensor(feature)
111 |
112 | def _get_additional_features(self, time_index: int):
113 | additional_features = {
114 | key: self._get_additional_feature(time_index, key)
115 | for key in self.additional_feature_keys
116 | }
117 | return additional_features
118 |
119 | def __getitem__(self, time_index: Union[int, slice]):
120 | if isinstance(time_index, slice):
121 | snapshot = DynamicGraphStaticSignalBatch(
122 | self.edge_indices[time_index],
123 | self.edge_weights[time_index],
124 | self.feature,
125 | self.targets[time_index],
126 | self.batches[time_index],
127 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys}
128 | )
129 | else:
130 | x = self._get_feature()
131 | edge_index = self._get_edge_index(time_index)
132 | edge_weight = self._get_edge_weight(time_index)
133 | batch = self._get_batch_index(time_index)
134 | y = self._get_target(time_index)
135 | additional_features = self._get_additional_features(time_index)
136 |
137 | snapshot = Batch(x=x, edge_index=edge_index, edge_attr=edge_weight,
138 | y=y, batch=batch, **additional_features)
139 | return snapshot
140 |
141 | def __next__(self):
142 | if self.t < len(self.targets):
143 | snapshot = self[self.t]
144 | self.t = self.t + 1
145 | return snapshot
146 | else:
147 | self.t = 0
148 | raise StopIteration
149 |
150 | def __iter__(self):
151 | self.t = 0
152 | return self
153 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/signal/dynamic_graph_temporal_signal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Sequence, Union
4 | from torch_geometric.data import Data
5 |
6 |
7 | Edge_Indices = Sequence[Union[np.ndarray, None]]
8 | Edge_Weights = Sequence[Union[np.ndarray, None]]
9 | Node_Features = Sequence[Union[np.ndarray, None]]
10 | Targets = Sequence[Union[np.ndarray, None]]
11 | Additional_Features = Sequence[np.ndarray]
12 |
13 |
14 | class DynamicGraphTemporalSignal(object):
15 | r"""A data iterator object to contain a dynamic graph with a
16 | changing edge set and weights . The feature set and node labels
17 | (target) are also dynamic. The iterator returns a single discrete temporal
18 | snapshot for a time period (e.g. day or week). This single snapshot is a
19 | Pytorch Geometric Data object. Between two temporal snapshots the edges,
20 | edge weights, target matrices and optionally passed attributes might change.
21 |
22 | Args:
23 | edge_indices (Sequence of Numpy arrays): Sequence of edge index tensors.
24 | edge_weights (Sequence of Numpy arrays): Sequence of edge weight tensors.
25 | features (Sequence of Numpy arrays): Sequence of node feature tensors.
26 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors.
27 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes.
28 | """
29 |
30 | def __init__(
31 | self,
32 | edge_indices: Edge_Indices,
33 | edge_weights: Edge_Weights,
34 | features: Node_Features,
35 | targets: Targets,
36 | **kwargs: Additional_Features
37 | ):
38 | self.edge_indices = edge_indices
39 | self.edge_weights = edge_weights
40 | self.features = features
41 | self.targets = targets
42 | self.additional_feature_keys = []
43 | for key, value in kwargs.items():
44 | setattr(self, key, value)
45 | self.additional_feature_keys.append(key)
46 | self._check_temporal_consistency()
47 | self._set_snapshot_count()
48 |
49 | def _check_temporal_consistency(self):
50 | assert len(self.features) == len(
51 | self.targets
52 | ), "Temporal dimension inconsistency."
53 | assert len(self.edge_indices) == len(
54 | self.edge_weights
55 | ), "Temporal dimension inconsistency."
56 | assert len(self.features) == len(
57 | self.edge_weights
58 | ), "Temporal dimension inconsistency."
59 | for key in self.additional_feature_keys:
60 | assert len(self.targets) == len(
61 | getattr(self, key)
62 | ), "Temporal dimension inconsistency."
63 |
64 | def _set_snapshot_count(self):
65 | self.snapshot_count = len(self.features)
66 |
67 | def _get_edge_index(self, time_index: int):
68 | if self.edge_indices[time_index] is None:
69 | return self.edge_indices[time_index]
70 | else:
71 | return torch.LongTensor(self.edge_indices[time_index])
72 |
73 | def _get_edge_weight(self, time_index: int):
74 | if self.edge_weights[time_index] is None:
75 | return self.edge_weights[time_index]
76 | else:
77 | return torch.FloatTensor(self.edge_weights[time_index])
78 |
79 | def _get_features(self, time_index: int):
80 | if self.features[time_index] is None:
81 | return self.features[time_index]
82 | else:
83 | return torch.FloatTensor(self.features[time_index])
84 |
85 | def _get_target(self, time_index: int):
86 | if self.targets[time_index] is None:
87 | return self.targets[time_index]
88 | else:
89 | if self.targets[time_index].dtype.kind == "i":
90 | return torch.LongTensor(self.targets[time_index])
91 | elif self.targets[time_index].dtype.kind == "f":
92 | return torch.FloatTensor(self.targets[time_index])
93 |
94 | def _get_additional_feature(self, time_index: int, feature_key: str):
95 | feature = getattr(self, feature_key)[time_index]
96 | if feature.dtype.kind == "i":
97 | return torch.LongTensor(feature)
98 | elif feature.dtype.kind == "f":
99 | return torch.FloatTensor(feature)
100 |
101 | def _get_additional_features(self, time_index: int):
102 | additional_features = {
103 | key: self._get_additional_feature(time_index, key)
104 | for key in self.additional_feature_keys
105 | }
106 | return additional_features
107 |
108 | def __getitem__(self, time_index: Union[int, slice]):
109 | if isinstance(time_index, slice):
110 | snapshot = DynamicGraphTemporalSignal(
111 | self.edge_indices[time_index],
112 | self.edge_weights[time_index],
113 | self.features[time_index],
114 | self.targets[time_index],
115 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys}
116 | )
117 | else:
118 | x = self._get_features(time_index)
119 | edge_index = self._get_edge_index(time_index)
120 | edge_weight = self._get_edge_weight(time_index)
121 | y = self._get_target(time_index)
122 | additional_features = self._get_additional_features(time_index)
123 |
124 | snapshot = Data(x=x, edge_index=edge_index, edge_attr=edge_weight,
125 | y=y, **additional_features)
126 | return snapshot
127 |
128 | def __next__(self):
129 | if self.t < len(self.features):
130 | snapshot = self[self.t]
131 | self.t = self.t + 1
132 | return snapshot
133 | else:
134 | self.t = 0
135 | raise StopIteration
136 |
137 | def __iter__(self):
138 | self.t = 0
139 | return self
140 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/signal/dynamic_graph_temporal_signal_batch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Sequence, Union
4 | from torch_geometric.data import Batch
5 |
6 |
7 | Edge_Indices = Sequence[Union[np.ndarray, None]]
8 | Edge_Weights = Sequence[Union[np.ndarray, None]]
9 | Node_Features = Sequence[Union[np.ndarray, None]]
10 | Targets = Sequence[Union[np.ndarray, None]]
11 | Batches = Sequence[Union[np.ndarray, None]]
12 | Additional_Features = Sequence[np.ndarray]
13 |
14 |
15 | class DynamicGraphTemporalSignalBatch(object):
16 | r"""A data iterator object to contain a dynamic graph with a
17 | changing edge set and weights . The feature set and node labels
18 | (target) are also dynamic. The iterator returns a single discrete temporal
19 | snapshot for a time period (e.g. day or week). This single snapshot is a
20 | Pytorch Geometric Batch object. Between two temporal snapshots the edges,
21 | edge weights, the feature matrix, target matrices and optionally passed
22 | attributes might change.
23 |
24 | Args:
25 | edge_indices (Sequence of Numpy arrays): Sequence of edge index tensors.
26 | edge_weights (Sequence of Numpy arrays): Sequence of edge weight tensors.
27 | features (Sequence of Numpy arrays): Sequence of node feature tensors.
28 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors.
29 | batches (Sequence of Numpy arrays): Sequence of batch index tensors.
30 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes.
31 | """
32 |
33 | def __init__(
34 | self,
35 | edge_indices: Edge_Indices,
36 | edge_weights: Edge_Weights,
37 | features: Node_Features,
38 | targets: Targets,
39 | batches: Batches,
40 | **kwargs: Additional_Features
41 | ):
42 | self.edge_indices = edge_indices
43 | self.edge_weights = edge_weights
44 | self.features = features
45 | self.targets = targets
46 | self.batches = batches
47 | self.additional_feature_keys = []
48 | for key, value in kwargs.items():
49 | setattr(self, key, value)
50 | self.additional_feature_keys.append(key)
51 | self._check_temporal_consistency()
52 | self._set_snapshot_count()
53 |
54 | def _check_temporal_consistency(self):
55 | assert len(self.features) == len(
56 | self.targets
57 | ), "Temporal dimension inconsistency."
58 | assert len(self.edge_indices) == len(
59 | self.edge_weights
60 | ), "Temporal dimension inconsistency."
61 | assert len(self.features) == len(
62 | self.edge_weights
63 | ), "Temporal dimension inconsistency."
64 | assert len(self.features) == len(
65 | self.batches
66 | ), "Temporal dimension inconsistency."
67 | for key in self.additional_feature_keys:
68 | assert len(self.targets) == len(
69 | getattr(self, key)
70 | ), "Temporal dimension inconsistency."
71 |
72 | def _set_snapshot_count(self):
73 | self.snapshot_count = len(self.features)
74 |
75 | def _get_edge_index(self, time_index: int):
76 | if self.edge_indices[time_index] is None:
77 | return self.edge_indices[time_index]
78 | else:
79 | return torch.LongTensor(self.edge_indices[time_index])
80 |
81 | def _get_batch_index(self, time_index: int):
82 | if self.batches[time_index] is None:
83 | return self.batches[time_index]
84 | else:
85 | return torch.LongTensor(self.batches[time_index])
86 |
87 | def _get_edge_weight(self, time_index: int):
88 | if self.edge_weights[time_index] is None:
89 | return self.edge_weights[time_index]
90 | else:
91 | return torch.FloatTensor(self.edge_weights[time_index])
92 |
93 | def _get_feature(self, time_index: int):
94 | if self.features[time_index] is None:
95 | return self.features[time_index]
96 | else:
97 | return torch.FloatTensor(self.features[time_index])
98 |
99 | def _get_target(self, time_index: int):
100 | if self.targets[time_index] is None:
101 | return self.targets[time_index]
102 | else:
103 | if self.targets[time_index].dtype.kind == "i":
104 | return torch.LongTensor(self.targets[time_index])
105 | elif self.targets[time_index].dtype.kind == "f":
106 | return torch.FloatTensor(self.targets[time_index])
107 |
108 | def _get_additional_feature(self, time_index: int, feature_key: str):
109 | feature = getattr(self, feature_key)[time_index]
110 | if feature.dtype.kind == "i":
111 | return torch.LongTensor(feature)
112 | elif feature.dtype.kind == "f":
113 | return torch.FloatTensor(feature)
114 |
115 | def _get_additional_features(self, time_index: int):
116 | additional_features = {
117 | key: self._get_additional_feature(time_index, key)
118 | for key in self.additional_feature_keys
119 | }
120 | return additional_features
121 |
122 | def __getitem__(self, time_index: Union[int, slice]):
123 | if isinstance(time_index, slice):
124 | snapshot = DynamicGraphTemporalSignalBatch(
125 | self.edge_indices[time_index],
126 | self.edge_weights[time_index],
127 | self.features[time_index],
128 | self.targets[time_index],
129 | self.batches[time_index],
130 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys}
131 | )
132 | else:
133 | x = self._get_feature(time_index)
134 | edge_index = self._get_edge_index(time_index)
135 | edge_weight = self._get_edge_weight(time_index)
136 | batch = self._get_batch_index(time_index)
137 | y = self._get_target(time_index)
138 | additional_features = self._get_additional_features(time_index)
139 |
140 | snapshot = Batch(x=x, edge_index=edge_index, edge_attr=edge_weight,
141 | y=y, batch=batch, **additional_features)
142 | return snapshot
143 |
144 | def __next__(self):
145 | if self.t < len(self.features):
146 | snapshot = self[self.t]
147 | self.t = self.t + 1
148 | return snapshot
149 | else:
150 | self.t = 0
151 | raise StopIteration
152 |
153 | def __iter__(self):
154 | self.t = 0
155 | return self
156 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/signal/index_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import dask.array as da
4 | import numpy as np
5 |
6 |
7 |
8 | class IndexDataset(Dataset):
9 | """
10 | A custom pytorch-compatible dataset that implements index-batching and DDP-index-batching.
11 | It also supports GPU-index-batching and lazy-index-batching.
12 |
13 | Args:
14 | indices (array-like): Indices corresponding to the time slicies.
15 | data (array-like or Dask array): The dataset to be indexed.
16 | horizon (int): The prediction period for the dataset.
17 | lazy (bool, optional): Whether to use Dask lazy loading (distribute the data across all workers). Defaults to False.
18 | gpu (bool, optional): If the data is already on the GPU. Defaults to False.
19 | """
20 | def __init__(self, indices, data, horizon, lazy=False, gpu=False):
21 | self.indices = indices
22 | self.data = data
23 | self.horizon = horizon
24 | self.lazy = lazy
25 | self.gpu = gpu
26 |
27 | def __len__(self):
28 |
29 | # Return the number of samples
30 | return self.indices.shape[0]
31 |
32 | def __getitem__(self, x):
33 | """
34 | Retrieve a data sample and its corresponding target based on the index.
35 |
36 | Args:
37 | x (int): The index of the sample to retrieve.
38 |
39 | Returns:
40 | tuple: A tuple (x, y), where `x` is the input sequence and `y` is the target sequence.
41 | """
42 |
43 | idx = self.indices[x]
44 |
45 | # Calculate the offset based on the horizon value
46 | y_start = idx + self.horizon
47 |
48 | # If the data is already on the gpu (likely due to using index-gpu-preprocessing), return tensor-slice
49 | if self.gpu:
50 | return self.data[idx:y_start,...], self.data[y_start:y_start + self.horizon,...]
51 |
52 | else:
53 | # if utilizing DDP-batching, gather the data on to this worker and convert to tensor
54 | if self.lazy:
55 | return torch.from_numpy(self.data[idx:y_start,...].compute()),torch.from_numpy(self.data[y_start:y_start + self.horizon,...].compute())
56 | else:
57 | return torch.from_numpy(self.data[idx:y_start,...]), torch.from_numpy(self.data[y_start:y_start + self.horizon,...])
58 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/signal/static_graph_temporal_signal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Sequence, Union
4 | from torch_geometric.data import Data
5 |
6 |
7 | Edge_Index = Union[np.ndarray, None]
8 | Edge_Weight = Union[np.ndarray, None]
9 | Node_Features = Sequence[Union[np.ndarray, None]]
10 | Targets = Sequence[Union[np.ndarray, None]]
11 | Additional_Features = Sequence[np.ndarray]
12 |
13 |
14 | class StaticGraphTemporalSignal(object):
15 | r"""A data iterator object to contain a static graph with a dynamically
16 | changing constant time difference temporal feature set (multiple signals).
17 | The node labels (target) are also temporal. The iterator returns a single
18 | constant time difference temporal snapshot for a time period (e.g. day or week).
19 | This single temporal snapshot is a Pytorch Geometric Data object. Between two
20 | temporal snapshots the features and optionally passed attributes might change.
21 | However, the underlying graph is the same.
22 |
23 | Args:
24 | edge_index (Numpy array): Index tensor of edges.
25 | edge_weight (Numpy array): Edge weight tensor.
26 | features (Sequence of Numpy arrays): Sequence of node feature tensors.
27 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors.
28 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes.
29 | """
30 |
31 | def __init__(
32 | self,
33 | edge_index: Edge_Index,
34 | edge_weight: Edge_Weight,
35 | features: Node_Features,
36 | targets: Targets,
37 | **kwargs: Additional_Features
38 | ):
39 | self.edge_index = edge_index
40 | self.edge_weight = edge_weight
41 | self.features = features
42 | self.targets = targets
43 | self.additional_feature_keys = []
44 | for key, value in kwargs.items():
45 | setattr(self, key, value)
46 | self.additional_feature_keys.append(key)
47 | self._check_temporal_consistency()
48 | self._set_snapshot_count()
49 |
50 | def _check_temporal_consistency(self):
51 | assert len(self.features) == len(
52 | self.targets
53 | ), "Temporal dimension inconsistency."
54 | for key in self.additional_feature_keys:
55 | assert len(self.targets) == len(
56 | getattr(self, key)
57 | ), "Temporal dimension inconsistency."
58 |
59 | def _set_snapshot_count(self):
60 | self.snapshot_count = len(self.features)
61 |
62 | def _get_edge_index(self):
63 | if self.edge_index is None:
64 | return self.edge_index
65 | else:
66 | return torch.LongTensor(self.edge_index)
67 |
68 | def _get_edge_weight(self):
69 | if self.edge_weight is None:
70 | return self.edge_weight
71 | else:
72 | return torch.FloatTensor(self.edge_weight)
73 |
74 | def _get_features(self, time_index: int):
75 | if self.features[time_index] is None:
76 | return self.features[time_index]
77 | else:
78 | return torch.FloatTensor(self.features[time_index])
79 |
80 | def _get_target(self, time_index: int):
81 | if self.targets[time_index] is None:
82 | return self.targets[time_index]
83 | else:
84 | if self.targets[time_index].dtype.kind == "i":
85 | return torch.LongTensor(self.targets[time_index])
86 | elif self.targets[time_index].dtype.kind == "f":
87 | return torch.FloatTensor(self.targets[time_index])
88 |
89 | def _get_additional_feature(self, time_index: int, feature_key: str):
90 | feature = getattr(self, feature_key)[time_index]
91 | if feature.dtype.kind == "i":
92 | return torch.LongTensor(feature)
93 | elif feature.dtype.kind == "f":
94 | return torch.FloatTensor(feature)
95 |
96 | def _get_additional_features(self, time_index: int):
97 | additional_features = {
98 | key: self._get_additional_feature(time_index, key)
99 | for key in self.additional_feature_keys
100 | }
101 | return additional_features
102 |
103 | def __getitem__(self, time_index: Union[int, slice]):
104 | if isinstance(time_index, slice):
105 | snapshot = StaticGraphTemporalSignal(
106 | self.edge_index,
107 | self.edge_weight,
108 | self.features[time_index],
109 | self.targets[time_index],
110 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys}
111 | )
112 | else:
113 | x = self._get_features(time_index)
114 | edge_index = self._get_edge_index()
115 | edge_weight = self._get_edge_weight()
116 | y = self._get_target(time_index)
117 | additional_features = self._get_additional_features(time_index)
118 |
119 | snapshot = Data(x=x, edge_index=edge_index, edge_attr=edge_weight,
120 | y=y, **additional_features)
121 | return snapshot
122 |
123 | def __next__(self):
124 | if self.t < len(self.features):
125 | snapshot = self[self.t]
126 | self.t = self.t + 1
127 | return snapshot
128 | else:
129 | self.t = 0
130 | raise StopIteration
131 |
132 | def __iter__(self):
133 | self.t = 0
134 | return self
135 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/signal/static_graph_temporal_signal_batch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Sequence, Union
4 | from torch_geometric.data import Batch
5 |
6 |
7 | Edge_Index = Union[np.ndarray, None]
8 | Edge_Weight = Union[np.ndarray, None]
9 | Node_Features = Sequence[Union[np.ndarray, None]]
10 | Targets = Sequence[Union[np.ndarray, None]]
11 | Batches = Union[np.ndarray, None]
12 | Additional_Features = Sequence[np.ndarray]
13 |
14 |
15 | class StaticGraphTemporalSignalBatch(object):
16 | r"""A data iterator object to contain a static graph with a dynamically
17 | changing constant time difference temporal feature set (multiple signals).
18 | The node labels (target) are also temporal. The iterator returns a single
19 | constant time difference temporal snapshot for a time period (e.g. day or week).
20 | This single temporal snapshot is a Pytorch Geometric Batch object. Between two
21 | temporal snapshots the feature matrix, target matrices and optionally passed
22 | attributes might change. However, the underlying graph is the same.
23 |
24 | Args:
25 | edge_index (Numpy array): Index tensor of edges.
26 | edge_weight (Numpy array): Edge weight tensor.
27 | features (Sequence of Numpy arrays): Sequence of node feature tensors.
28 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors.
29 | batches (Numpy array): Batch index tensor.
30 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes.
31 | """
32 |
33 | def __init__(
34 | self,
35 | edge_index: Edge_Index,
36 | edge_weight: Edge_Weight,
37 | features: Node_Features,
38 | targets: Targets,
39 | batches: Batches,
40 | **kwargs: Additional_Features
41 | ):
42 | self.edge_index = edge_index
43 | self.edge_weight = edge_weight
44 | self.features = features
45 | self.targets = targets
46 | self.batches = batches
47 | self.additional_feature_keys = []
48 | for key, value in kwargs.items():
49 | setattr(self, key, value)
50 | self.additional_feature_keys.append(key)
51 | self._check_temporal_consistency()
52 | self._set_snapshot_count()
53 |
54 | def _check_temporal_consistency(self):
55 | assert len(self.features) == len(
56 | self.targets
57 | ), "Temporal dimension inconsistency."
58 | for key in self.additional_feature_keys:
59 | assert len(self.targets) == len(
60 | getattr(self, key)
61 | ), "Temporal dimension inconsistency."
62 |
63 | def _set_snapshot_count(self):
64 | self.snapshot_count = len(self.features)
65 |
66 | def _get_edge_index(self):
67 | if self.edge_index is None:
68 | return self.edge_index
69 | else:
70 | return torch.LongTensor(self.edge_index)
71 |
72 | def _get_batch_index(self):
73 | if self.batches is None:
74 | return self.batches
75 | else:
76 | return torch.LongTensor(self.batches)
77 |
78 | def _get_edge_weight(self):
79 | if self.edge_weight is None:
80 | return self.edge_weight
81 | else:
82 | return torch.FloatTensor(self.edge_weight)
83 |
84 | def _get_feature(self, time_index: int):
85 | if self.features[time_index] is None:
86 | return self.features[time_index]
87 | else:
88 | return torch.FloatTensor(self.features[time_index])
89 |
90 | def _get_target(self, time_index: int):
91 | if self.targets[time_index] is None:
92 | return self.targets[time_index]
93 | else:
94 | if self.targets[time_index].dtype.kind == "i":
95 | return torch.LongTensor(self.targets[time_index])
96 | elif self.targets[time_index].dtype.kind == "f":
97 | return torch.FloatTensor(self.targets[time_index])
98 |
99 | def _get_additional_feature(self, time_index: int, feature_key: str):
100 | feature = getattr(self, feature_key)[time_index]
101 | if feature.dtype.kind == "i":
102 | return torch.LongTensor(feature)
103 | elif feature.dtype.kind == "f":
104 | return torch.FloatTensor(feature)
105 |
106 | def _get_additional_features(self, time_index: int):
107 | additional_features = {
108 | key: self._get_additional_feature(time_index, key)
109 | for key in self.additional_feature_keys
110 | }
111 | return additional_features
112 |
113 | def __getitem__(self, time_index: Union[int, slice]):
114 | if isinstance(time_index, slice):
115 | snapshot = StaticGraphTemporalSignalBatch(
116 | self.edge_index,
117 | self.edge_weight,
118 | self.features[time_index],
119 | self.targets[time_index],
120 | self.batches,
121 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys}
122 | )
123 | else:
124 | x = self._get_feature(time_index)
125 | edge_index = self._get_edge_index()
126 | edge_weight = self._get_edge_weight()
127 | batch = self._get_batch_index()
128 | y = self._get_target(time_index)
129 | additional_features = self._get_additional_features(time_index)
130 |
131 | snapshot = Batch(x=x, edge_index=edge_index, edge_attr=edge_weight,
132 | y=y, batch=batch, **additional_features)
133 | return snapshot
134 |
135 | def __next__(self):
136 | if self.t < len(self.features):
137 | snapshot = self[self.t]
138 | self.t = self.t + 1
139 | return snapshot
140 | else:
141 | self.t = 0
142 | raise StopIteration
143 |
144 | def __iter__(self):
145 | self.t = 0
146 | return self
147 |
--------------------------------------------------------------------------------
/torch_geometric_temporal/signal/train_test_split.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Tuple
2 |
3 | from .static_graph_temporal_signal import StaticGraphTemporalSignal
4 | from .dynamic_graph_temporal_signal import DynamicGraphTemporalSignal
5 | from .dynamic_graph_static_signal import DynamicGraphStaticSignal
6 |
7 | from .static_graph_temporal_signal_batch import StaticGraphTemporalSignalBatch
8 | from .dynamic_graph_temporal_signal_batch import DynamicGraphTemporalSignalBatch
9 | from .dynamic_graph_static_signal_batch import DynamicGraphStaticSignalBatch
10 |
11 | from .static_hetero_graph_temporal_signal import StaticHeteroGraphTemporalSignal
12 | from .dynamic_hetero_graph_temporal_signal import DynamicHeteroGraphTemporalSignal
13 | from .dynamic_hetero_graph_static_signal import DynamicHeteroGraphStaticSignal
14 |
15 | from .static_hetero_graph_temporal_signal_batch import StaticHeteroGraphTemporalSignalBatch
16 | from .dynamic_hetero_graph_temporal_signal_batch import DynamicHeteroGraphTemporalSignalBatch
17 | from .dynamic_hetero_graph_static_signal_batch import DynamicHeteroGraphStaticSignalBatch
18 |
19 |
20 | Discrete_Signal = Union[
21 | StaticGraphTemporalSignal,
22 | StaticGraphTemporalSignalBatch,
23 | DynamicGraphTemporalSignal,
24 | DynamicGraphTemporalSignalBatch,
25 | DynamicGraphStaticSignal,
26 | DynamicGraphStaticSignalBatch,
27 | StaticHeteroGraphTemporalSignal,
28 | StaticHeteroGraphTemporalSignalBatch,
29 | DynamicHeteroGraphTemporalSignal,
30 | DynamicHeteroGraphTemporalSignalBatch,
31 | DynamicHeteroGraphStaticSignal,
32 | DynamicHeteroGraphStaticSignalBatch,
33 | ]
34 |
35 |
36 | def temporal_signal_split(
37 | data_iterator, train_ratio: float = 0.8
38 | ) -> Tuple[Discrete_Signal, Discrete_Signal]:
39 | r"""Function to split a data iterator according to a fixed ratio.
40 |
41 | Arg types:
42 | * **data_iterator** *(Signal Iterator)* - Node features.
43 | * **train_ratio** *(float)* - Graph edge indices.
44 |
45 | Return types:
46 | * **(train_iterator, test_iterator)** *(tuple of Signal Iterators)* - Train and test data iterators.
47 | """
48 |
49 | train_snapshots = int(train_ratio * data_iterator.snapshot_count)
50 |
51 | train_iterator = data_iterator[0:train_snapshots]
52 | test_iterator = data_iterator[train_snapshots:]
53 |
54 | return train_iterator, test_iterator
55 |
--------------------------------------------------------------------------------