├── .gitattributes ├── .github └── workflows │ └── main.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── dataset ├── chickenpox.json ├── england_covid.json ├── montevideo_bus.json ├── mtm_1.json ├── pedalme_london.json ├── twitter_tennis_rg17.json ├── twitter_tennis_uo17.json └── wikivital_mathematics.json ├── docs ├── Makefile ├── index.html ├── mapping.py ├── requirements.txt └── source │ ├── _figures │ └── build.sh │ ├── _static │ ├── css │ │ └── custom.css │ └── img │ │ ├── logo.jpg │ │ └── text_logo.jpg │ ├── conf.py │ ├── index.rst │ ├── modules │ ├── dataset.rst │ ├── root.rst │ └── signal.rst │ └── notes │ ├── installation.rst │ ├── introduction.rst │ └── resources.rst ├── examples ├── indexBatching │ ├── A3TGCN │ │ ├── metr_la_main.py │ │ └── pems_ddp.py │ ├── DCRNN │ │ ├── chicken_pox_main.py │ │ ├── pems_allLA_main.py │ │ ├── pems_bay_main.py │ │ ├── pems_ddp.py │ │ ├── pems_main.py │ │ ├── submit.sh │ │ ├── utils.py │ │ └── windmill_main.py │ └── README.md └── recurrent │ ├── a3tgcn2_example.py │ ├── a3tgcn_example.py │ ├── agcrn_example.py │ ├── dcrnn_example.py │ ├── dygrencoder_example.py │ ├── evolvegcnh_example.py │ ├── evolvegcno_example.py │ ├── gclstm_example.py │ ├── gconvgru_example.py │ ├── gconvlstm_example.py │ ├── lightning_example.py │ ├── lrgcn_example.py │ ├── mpnnlstm_example.py │ └── tgcn_example.py ├── notebooks ├── a3tgcn_for_traffic_forecasting.ipynb ├── astgcn_for_traffic_flow_forecasting.ipynb └── processing_traffic_data_for_deep_learning_projects.ipynb ├── readthedocs.yml ├── setup.py ├── test ├── attention_test.py ├── batch_test.py ├── dataset_test.py ├── heterogeneous_test.py └── recurrent_test.py └── torch_geometric_temporal ├── __init__.py ├── dataset ├── __init__.py ├── chickenpox.py ├── encovid.py ├── metr_la.py ├── montevideo_bus.py ├── mtm.py ├── pedalme.py ├── pems.py ├── pemsAllLA.py ├── pems_bay.py ├── twitter_tennis.py ├── wikimath.py ├── windmilllarge.py ├── windmillmedium.py └── windmillsmall.py ├── nn ├── __init__.py ├── attention │ ├── __init__.py │ ├── astgcn.py │ ├── dnntsp.py │ ├── gman.py │ ├── mstgcn.py │ ├── mtgnn.py │ ├── stgcn.py │ └── tsagcn.py ├── hetero │ ├── __init__.py │ └── heterogclstm.py └── recurrent │ ├── __init__.py │ ├── agcrn.py │ ├── attentiontemporalgcn.py │ ├── dcrnn.py │ ├── dygrae.py │ ├── evolvegcnh.py │ ├── evolvegcno.py │ ├── gc_lstm.py │ ├── gconv_gru.py │ ├── gconv_lstm.py │ ├── lrgcn.py │ ├── mpnn_lstm.py │ └── temporalgcn.py └── signal ├── __init__.py ├── dynamic_graph_static_signal.py ├── dynamic_graph_static_signal_batch.py ├── dynamic_graph_temporal_signal.py ├── dynamic_graph_temporal_signal_batch.py ├── dynamic_hetero_graph_static_signal.py ├── dynamic_hetero_graph_static_signal_batch.py ├── dynamic_hetero_graph_temporal_signal.py ├── dynamic_hetero_graph_temporal_signal_batch.py ├── index_dataset.py ├── static_graph_temporal_signal.py ├── static_graph_temporal_signal_batch.py ├── static_hetero_graph_temporal_signal.py ├── static_hetero_graph_temporal_signal_batch.py └── train_test_split.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-detectable=false 2 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build: 12 | runs-on: 
${{ matrix.os }} 13 | 14 | strategy: 15 | matrix: 16 | os: [ubuntu-22.04] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - uses: actions/setup-python@v4 21 | with: 22 | python-version: 3.8 23 | - uses: s-weigand/setup-conda@v1 24 | with: 25 | activate-conda: true 26 | python-version: 3.8 27 | - run: conda --version 28 | - run: which python 29 | - name: Install main dependencies 30 | run: | 31 | python -m pip install torch==2.3.0 torchvision torchaudio -f https://download.pytorch.org/whl/cpu/torch_stable.html 32 | python -m pip install torch-sparse -f https://data.pyg.org/whl/torch-2.3.0+cpu.html 33 | python -m pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cpu.html 34 | python -m pip install torch-geometric 35 | python -m pip install sphinx sphinx_rtd_theme 36 | - name: Install main package 37 | run: | 38 | python -m pip install -e .[test] 39 | - name: Run test-suite 40 | run: | 41 | python -m pytest 42 | - name: Generate coverage report 43 | if: success() 44 | run: | 45 | pip install coverage 46 | coverage run -m pytest 47 | coverage xml 48 | - name: Upload coverage report to codecov 49 | uses: codecov/codecov-action@v1 50 | if: success() 51 | with: 52 | file: coverage.xml 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # vs code 132 | .vscode/ 133 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | I'm really glad you're reading this, because we need volunteer developers to help this project come to fruition. 4 | 5 | Here are some important things to check when you contribute: 6 | 7 | * Please make sure that you write tests. 8 | * Update the documentation. 9 | * Add the new model to the readme. 10 | * If your contribution is a paper, please update the resource documentation file. 11 | 12 | ## Testing 13 | 14 | 15 | PyTorch Geometric Temporal's tests are located under `test/`. 16 | Run the entire test suite with 17 | 18 | ``` 19 | python setup.py test 20 | ``` 21 | 22 | ## Submitting changes 23 | 24 | Please send a [GitHub Pull Request to PyTorch Geometric Temporal](https://github.com/benedekrozemberczki/pytorch_geometric_temporal/pull/new/master) with a clear list of what you've done (read more about [pull requests](http://help.github.com/pull-requests/)). Please follow our coding conventions (below) and make sure all of your commits are atomic (one feature per commit). 25 | 26 | Always write a clear log message for your commits. One-line messages are fine for small changes, but bigger changes should look like this: 27 | 28 | $ git commit -m "A brief summary of the commit 29 | > 30 | > A paragraph describing what changed and its impact." 31 | 32 | ## Coding conventions 33 | 34 | Start reading our code and you'll get the hang of it. We optimize for readability: 35 | 36 | * We write tests for the data loaders, iterators and layers. 37 | * We use the type hinting feature of Python. 38 | * We avoid the use of public methods and variables in the classes. 39 | * Hyperparameters belong in the constructors. 40 | * Auxiliary layer instances should have long, descriptive names. 41 | * Write linear algebra operations line by line. 42 | 43 | Thanks, 44 | Benedek 45 | 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Benedek Rozemberczki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | SPHINXBUILD = sphinx-build 2 | SPHINXPROJ = pytorch_geometric_temporal 3 | SOURCEDIR = source 4 | BUILDDIR = build 5 | 6 | .PHONY: help Makefile 7 | 8 | %: Makefile 9 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" 10 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Redirect 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /docs/mapping.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ast 3 | from collections import defaultdict 4 | 5 | def find_defined_classes_in_file(filepath): 6 | with open(filepath, 'r', encoding='utf-8') as file: 7 | try: 8 | tree = ast.parse(file.read(), filename=filepath) 9 | except SyntaxError: 10 | return [] 11 | 12 | return [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)] 13 | 14 | def map_defined_classes(directories): 15 | file_to_classes = defaultdict(list) 16 | all_classes = set() 17 | for directory in directories: 18 | for root, _, files in os.walk(directory): 19 | for file in files: 20 | if file.endswith('.py'): 21 | full_path = os.path.join(root, file) 22 | class_names = find_defined_classes_in_file(full_path) 23 | if class_names: 24 | file_to_classes[full_path[3:].replace("/",".")].extend(class_names) 25 | all_classes.update(class_names) 26 | 27 | return all_classes, file_to_classes 28 | 29 | if __name__ == "__main__": 30 | # Replace these with the paths you want to analyze 31 | directories_to_scan = [ 32 | "../torch_geometric_temporal/nn/recurrent", 33 | "../torch_geometric_temporal/nn/attention", 34 | "../torch_geometric_temporal/nn/hetero" 35 | ] 36 | 37 | all_classes, mapping = map_defined_classes(directories_to_scan) 38 | 39 | 40 | order = [ 41 | 42 | "torch_geometric_temporal.nn.recurrent.gconv_gru", 43 | "torch_geometric_temporal.nn.recurrent.gconv_lstm", 44 | "torch_geometric_temporal.nn.recurrent.gc_lstm", 45 | "torch_geometric_temporal.nn.recurrent.lrgcn", 46 | "torch_geometric_temporal.nn.recurrent.dygrae", 47 | "torch_geometric_temporal.nn.recurrent.evolvegcnh", 48 | "torch_geometric_temporal.nn.recurrent.evolvegcno", 49 | "torch_geometric_temporal.nn.recurrent.temporalgcn", 50 | "torch_geometric_temporal.nn.recurrent.attentiontemporalgcn", 51 | "torch_geometric_temporal.nn.recurrent.mpnn_lstm", 52 | "torch_geometric_temporal.nn.recurrent.dcrnn", 53 | "torch_geometric_temporal.nn.recurrent.agcrn", 54 | 55 | "torch_geometric_temporal.nn.attention.stgcn", 56 | "torch_geometric_temporal.nn.attention.astgcn", 57 | "torch_geometric_temporal.nn.attention.mstgcn", 58 | "torch_geometric_temporal.nn.attention.gman", 59 | "torch_geometric_temporal.nn.attention.mtgnn", 60 | 
"torch_geometric_temporal.nn.attention.tsagcn", 61 | "torch_geometric_temporal.nn.attention.dnntsp", 62 | 63 | "torch_geometric_temporal.nn.hetero.heterogclstm" 64 | 65 | ] 66 | model = [ 67 | "GConvGRU", 68 | "GConvLSTM", 69 | "GCLSTM", 70 | "LRGCN", 71 | "DyGrEncoder", 72 | "EvolveGCNH", 73 | "EvolveGCNO", 74 | "GCNConv_Fixed_W", 75 | "TGCN", 76 | "TGCN2", 77 | "A3TGCN", 78 | "A3TGCN2", 79 | "MPNNLSTM", 80 | "DCRNN", 81 | "BatchedDCRNN", 82 | "AGCRN", 83 | "STConv", 84 | "ASTGCN", 85 | "MSTGCN", 86 | "GMAN", 87 | "SpatioTemporalAttention", 88 | "GraphConstructor", 89 | "MTGNN", 90 | "AAGCN", 91 | "DNNTSP" 92 | ] 93 | aux = [ 94 | "TemporalConv", 95 | "DConv", 96 | "BatchedDConv", 97 | "ChebConvAttention", 98 | "AVWGCN", 99 | "UnitGCN", 100 | "UnitTCN" 101 | ] 102 | 103 | het = [ 104 | "HeteroGCLSTM" 105 | ] 106 | # print(mapping.keys()) 107 | target = {} 108 | for file, classes in mapping.items(): 109 | 110 | 111 | line = "" 112 | 113 | for c in all_classes: 114 | if c not in classes: 115 | line += f"{c}, " 116 | 117 | line = line[:-2] 118 | target[file[:-3]] = line 119 | 120 | 121 | for key in order: 122 | print(f".. autoapimodule:: {key}") 123 | print("\t:members:") 124 | print(f"\t:exclude-members: {target[key]}, LayerNormalization, AggregateTemporalNodeFeatures, GlobalGatedUpdater, MaskedSelfAttention, WeightedGCNBlock, LayerNormalization, K, bias, in_channels, out_channels, normalization, num_bases, num_relations, conv_aggr, conv_num_layers, conv_out_channels, lstm_num_layers, lstm_out_channels, add_self_loops, cached, improved, initial_weight, normalize, num_of_nodes, reinitialize_weight, reset_parameters, weight, batch_size, periods, dropout, hidden_size, num_nodes, window, number_of_nodes, bias_pool, weights_pool, hidden_channels, A, attention, edge_index, gcn1, graph, relu, tcn1, kernel_size, conv_1, conv_2, conv_3, nb_time_filter, adaptive, bn, conv_d, in_c, inter_c, num_jpts, num_subset, out_c, sigmoid, soft, tan, conv, embedding_dimensions, Wq, global_gated_updater, item_embedding, item_embedding_dim, items_total, masked_self_attention, stacked_gcn, in_channels_dict, meta") 125 | print() -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==7.0 2 | sphinx-rtd-theme 3 | sphinx-autoapi 4 | Jinja2>=3.0 -------------------------------------------------------------------------------- /docs/source/_figures/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | for filename in *.tex; do 4 | basename=$(basename $filename .tex) 5 | pdflatex "$basename.tex" 6 | pdf2svg "$basename.pdf" "$basename.svg" 7 | done 8 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Use white for logo background. */ 2 | .wy-side-nav-search { 3 | background-color: #fff; 4 | } 5 | 6 | .wy-side-nav-search > div.version { 7 | color: #000; 8 | } 9 | 10 | /* Justify the text. 
*/ 11 | 12 | .section #basic-2-flip-flop-synchronizer{ 13 | text-align:justify; 14 | } 15 | -------------------------------------------------------------------------------- /docs/source/_static/img/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/9a44e623c484c613a850e16474e03ae4242aaab8/docs/source/_static/img/logo.jpg -------------------------------------------------------------------------------- /docs/source/_static/img/text_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/9a44e623c484c613a850e16474e03ae4242aaab8/docs/source/_static/img/text_logo.jpg -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import sphinx_rtd_theme 3 | import doctest 4 | import inspect 5 | 6 | 7 | 8 | 9 | 10 | extensions = [ 11 | 'autoapi.extension', 12 | 'sphinx.ext.doctest', 13 | 'sphinx.ext.intersphinx', 14 | 'sphinx.ext.mathjax', 15 | 'sphinx.ext.napoleon', 16 | 'sphinx.ext.viewcode', 17 | 'sphinx.ext.githubpages', 18 | 'sphinx_rtd_theme', 19 | ] 20 | 21 | source_suffix = '.rst' 22 | master_doc = 'index' 23 | 24 | autoapi_python_use_implicit_namespaces = False 25 | 26 | author = 'Benedek Rozemberczki' 27 | project = 'PyTorch Geometric Temporal' 28 | copyright = '{}, {}'.format(datetime.datetime.now().year, author) 29 | 30 | html_theme = 'sphinx_rtd_theme' 31 | autoapi_add_toctree_entry = True 32 | 33 | doctest_default_flags = doctest.NORMALIZE_WHITESPACE 34 | intersphinx_mapping = {'python': ('https://docs.python.org/', None)} 35 | 36 | html_theme_options = { 37 | 'collapse_navigation': False, 38 | 'display_version': True, 39 | 'logo_only': True, 40 | 'navigation_depth': 2, 41 | } 42 | 43 | 44 | html_logo = '_static/img/text_logo.jpg' 45 | html_static_path = ['_static'] 46 | 47 | add_module_names = False 48 | autoapi_generate_api_docs = False 49 | 50 | # --- AutoAPI config --- 51 | autoapi_type = 'python' 52 | autoapi_dirs = ['../../torch_geometric_temporal/'] -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/benedekrozemberczki/pytorch_geometric_temporal 2 | 3 | PyTorch Geometric Temporal Documentation 4 | ======================================== 5 | 6 | PyTorch Geometric Temporal is a temporal graph neural network extension library for `PyTorch Geometric `_. It builds on open-source deep-learning and graph processing libraries. *PyTorch Geometric Temporal* consists of state-of-the-art deep learning and parametric learning methods to process spatio-temporal signals. It is the first open-source library for temporal deep learning on geometric structures and provides constant time difference graph neural networks on dynamic and static graphs. We make this happen with the use of discrete time graph snapshots. Implemented methods cover a wide range of data mining (`WWW `_, `KDD `_), artificial intelligence and machine learning (`AAAI `_, `ICONIP `_, `ICLR `_) conferences, workshops, and pieces from prominent journals. 
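As a quick orientation, here is a minimal end-to-end sketch of the discrete-time workflow. It uses the Chickenpox dataset loader and the GConvGRU recurrent layer shipped with this package; the hyperparameters (lags, filter count, learning rate, epoch count) are purely illustrative:

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
    from torch_geometric_temporal.signal import temporal_signal_split
    from torch_geometric_temporal.nn.recurrent import GConvGRU

    # Load a discrete-time dataset and split it along the temporal axis.
    dataset = ChickenpoxDatasetLoader().get_dataset()
    train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)

    class RecurrentGCN(torch.nn.Module):
        def __init__(self, node_features, filters):
            super().__init__()
            self.recurrent = GConvGRU(node_features, filters, K=2)
            self.linear = torch.nn.Linear(filters, 1)

        def forward(self, x, edge_index, edge_weight):
            h = F.relu(self.recurrent(x, edge_index, edge_weight))
            # Squeeze the channel dimension so predictions match the target shape.
            return self.linear(h).squeeze(-1)

    model = RecurrentGCN(node_features=4, filters=32)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    model.train()
    for epoch in range(50):
        cost = 0
        # Iterate over the temporally ordered graph snapshots.
        for t, snapshot in enumerate(train_dataset):
            y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
            cost = cost + torch.mean((y_hat - snapshot.y) ** 2)
        cost = cost / (t + 1)
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()

Each snapshot behaves like a regular PyTorch Geometric ``Data`` object, which is what makes the snapshot-based design interoperable with the rest of the PyTorch Geometric ecosystem.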
7 | 8 | 9 | 10 | PyTorch Geometric Temporal includes support for index-batching, a new batching technique that improves spatiotemporal memory efficiency without any impact on accuracy. Additionally, PyTorch Geometric Temporal supports memory-efficient distributed data parallel training using Dask-DDP in combination with index-batching. 11 | 12 | .. The package interfaces well with `PyTorch Lightning `_ which allows training on CPUs, and single and multiple GPUs out-of-the-box. Take a look at this introductory example of using PyTorch Geometric Temporal with PyTorch Lightning. 13 | 14 | .. code-block:: latex 15 | 16 | @inproceedings{rozemberczki2021pytorch, 17 | author = {Benedek Rozemberczki and Paul Scherer and Yixuan He and George Panagopoulos and Alexander Riedel and Maria Astefanoaei and Oliver Kiss and Ferenc Beres and Guzman Lopez and Nicolas Collignon and Rik Sarkar}, 18 | title = {{PyTorch Geometric Temporal: Spatiotemporal Signal Processing with Neural Machine Learning Models}}, 19 | year = {2021}, 20 | booktitle={Proceedings of the 30th ACM International Conference on Information and Knowledge Management}, 21 | pages = {4564–4573}, 22 | } 23 | 24 | .. toctree:: 25 | :glob: 26 | :maxdepth: 2 27 | :caption: Notes 28 | 29 | notes/installation 30 | notes/introduction 31 | notes/resources 32 | 33 | .. toctree:: 34 | :glob: 35 | :maxdepth: 2 36 | :caption: Package Reference 37 | 38 | modules/root 39 | modules/signal 40 | modules/dataset 41 | -------------------------------------------------------------------------------- /docs/source/modules/dataset.rst: -------------------------------------------------------------------------------- 1 | PyTorch Geometric Temporal Dataset 2 | ================================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | Datasets 8 | ----------------------- 9 | 10 | .. autoapimodule:: torch_geometric_temporal.dataset.chickenpox 11 | :members: 12 | :undoc-members: 13 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 14 | 15 | .. autoapimodule:: torch_geometric_temporal.dataset.pedalme 16 | :members: 17 | :undoc-members: 18 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 19 | 20 | .. autoapimodule:: torch_geometric_temporal.dataset.wikimath 21 | :members: 22 | :undoc-members: 23 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 24 | 25 | .. autoapimodule:: torch_geometric_temporal.dataset.windmilllarge 26 | :members: 27 | :undoc-members: 28 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 29 | 30 | .. autoapimodule:: torch_geometric_temporal.dataset.windmillmedium 31 | :members: 32 | :undoc-members: 33 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 34 | 35 | ..
autoapimodule:: torch_geometric_temporal.dataset.windmillsmall 36 | :members: 37 | :undoc-members: 38 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 39 | 40 | .. autoapimodule:: torch_geometric_temporal.dataset.metr_la 41 | :members: 42 | :undoc-members: 43 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 44 | 45 | .. autoapimodule:: torch_geometric_temporal.dataset.pems_bay 46 | :members: 47 | :undoc-members: 48 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 49 | 50 | .. autoapimodule:: torch_geometric_temporal.dataset.pemsAllLA 51 | :members: 52 | :undoc-members: 53 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 54 | 55 | .. autoapimodule:: torch_geometric_temporal.dataset.pems 56 | :members: 57 | :undoc-members: 58 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 59 | 60 | .. autoapimodule:: torch_geometric_temporal.dataset.encovid 61 | :members: 62 | :undoc-members: 63 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 64 | 65 | .. autoapimodule:: torch_geometric_temporal.dataset.montevideo_bus 66 | :members: 67 | :undoc-members: 68 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 69 | 70 | .. autoapimodule:: torch_geometric_temporal.dataset.twitter_tennis 71 | :members: 72 | :undoc-members: 73 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal, transform_degree, transform_transitivity, encode_features, onehot_encoding 74 | 75 | 76 | .. autoapimodule:: torch_geometric_temporal.dataset.mtm 77 | :members: 78 | :undoc-members: 79 | :exclude-members: index,additional_feature_keys,edge_index,edge_weight,features,targets, raw_data_dir, edge_indicies, edge_weights, StaticGraphTemporalSignal, DynamicGraphTemporalSignal 80 | -------------------------------------------------------------------------------- /docs/source/notes/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | Installing PyTorch Geometric Temporal requires certain prerequisites, which are described in detail in the PyTorch Geometric installation documentation. Please follow the instructions laid out `here `_. You might also take a look at the `readme file `_ of the PyTorch Geometric Temporal repository. 5 | 6 | Once the required versions of PyTorch and PyTorch Geometric are installed, simply run: 7 | 8 | ..
code-block:: none 9 | 10 | $ pip install torch-geometric-temporal 11 | 12 | **Updating the Library** 13 | 14 | The package itself can be installed via pip: 15 | 16 | .. code-block:: none 17 | 18 | $ pip install torch-geometric-temporal 19 | 20 | Upgrade your outdated PyTorch Geometric Temporal version by using: 21 | 22 | .. code-block:: none 23 | 24 | $ pip install torch-geometric-temporal --upgrade 25 | 26 | 27 | To check your current package version, simply run: 28 | 29 | .. code-block:: none 30 | 31 | $ pip freeze | grep torch-geometric-temporal 32 | 33 | **Index-Batching** 34 | 35 | The package was recently updated to include index-batching, a new method of batching that improves 36 | memory efficiency without any impact on accuracy. To install the needed packages for index-batching, 37 | run the following command: 38 | 39 | .. code-block:: none 40 | 41 | $ pip install torch-geometric-temporal[index] 42 | 43 | **Distributed Data Parallel** 44 | 45 | Alongside index-batching, PyTorch Geometric Temporal was recently updated with features to support distributed data parallel training. 46 | To install the needed packages, run the following: 47 | 48 | .. code-block:: none 49 | 50 | $ pip install torch-geometric-temporal[ddp] 51 | 52 | -------------------------------------------------------------------------------- /docs/source/notes/resources.rst: -------------------------------------------------------------------------------- 1 | External Resources - Architectures 2 | ================================== 3 | 4 | * Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu: **Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting** `Paper `_, `TensorFlow Code `_, `PyTorch Code `_ 5 | 6 | * Youngjoo Seo, Michaël Defferrard, Xavier Bresson, Pierre Vandergheynst: **Structured Sequence Modeling With Graph Convolutional Recurrent Networks** `Paper `_, `Code `_, `TensorFlow Code `_ 7 | 8 | * Jinyin Chen, Xuanheng Xu, Yangyang Wu, Haibin Zheng: **GC-LSTM: Graph Convolution Embedded LSTM for Dynamic Link Prediction** `Paper `_ 9 | 10 | * Jia Li, Zhichao Han, Hong Cheng, Jiao Su, Pengyun Wang, Jianfeng Zhang, Lujia Pan: **Predicting Path Failure In Time-Evolving Graphs** `Paper `_, `Code `_ 11 | 12 | * Aynaz Taheri, Tanya Berger-Wolf: **Predictive Temporal Embedding of Dynamic Graphs** `Paper `_ 13 | 14 | * Aynaz Taheri, Kevin Gimpel, Tanya Berger-Wolf: **Learning to Represent the Evolution of Dynamic Graphs with Recurrent Models** `Paper `_, `Code `_ 15 | 16 | * Aldo Pareja, Giacomo Domeniconi, Jie Chen, Tengfei Ma, Toyotaro Suzumura, Hiroki Kanezashi, Tim Kaler, Tao B. Schardl, Charles E.
Leiserson: **EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs** `Paper `_, `Code `_ 17 | 18 | * Bing Yu, Haoteng Yin, Zhanxing Zhu: **Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting** `Paper `_, `Code `_ 19 | 20 | * Ling Zhao, Yujiao Song, Chao Zhang, Yu Liu, Pu Wang, Tao Lin, Min Deng, Haifeng Li: **T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction** `Paper `_, `Code `_ 21 | 22 | * Jiawei Zhu, Yujiao Song, Ling Zhao, Haifeng Li: **A3T-GCN: Attention Temporal Graph Convolutional Network for Traffic Forecasting** `Paper `_, `Code `_ 23 | 24 | * Shengnan Guo, Youfang Lin, Ning Feng, Chao Song, Huaiyu Wan: **Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting** `Paper `_, `Code `_ 25 | 26 | * Chuanpan Zheng, Xiaoliang Fan, Cheng Wang, Jianzhong Qi: **GMAN: A Graph Multi-Attention Network for Traffic Prediction** `Paper `_, `Code `_ 27 | 28 | * Zonghan Wu, Shirui Pan, Guodong Long, Jing Jiang, Xiaojun Chang, Chengqi Zhang: **Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks** `Paper `_, `Code `_ 29 | 30 | * Lei Bai, Lina Yao, Can Li, Xianzhi Wang, Can Wang: **Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting** `Paper `_, `Code `_ 31 | 32 | * George Panagopoulos, Giannis Nikolentzos, Michalis Vazirgiannis: **Transfer Graph Neural Networks for Pandemic Forecasting** `Paper `_, `Code `_ 33 | 34 | * Lei Shi, Yifan Zhang, Jian Cheng, Hanqing Lu: **Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based Action Recognition** `Paper `_, `Code `_ 35 | 36 | * Le Yu, Leilei Sun, Bowen Du, Chuanren Liu, Hui Xiong, Weifeng Lv: **Predicting Temporal Sets with Deep Neural Networks** `Paper `_, `Code `_ 37 | 38 | External Resources - Datasets 39 | ============================= 40 | 41 | * Benedek Rozemberczki, Paul Scherer, Oliver Kiss, Rik Sarkar, Tamas Ferenci: **Chickenpox Cases in Hungary: a Benchmark Dataset for Spatiotemporal Signal Processing with Graph Neural Networks** `Paper `_, `Dataset `_ 42 | 43 | * Ferenc Béres, Róbert Pálovics, Anna Oláh, András A. Benczúr: **Temporal Walk Based Centrality Metric for Graph Streams** `Paper `_, `Dataset `_ 44 | 45 | * Ferenc Béres, Domokos M. Kelen, Róbert Pálovics, András A. 
Benczúr: **Node Embeddings in Dynamic Graphs** `Paper `_, `Dataset `_ 46 | 47 | * Tanwi Mallick, Prasanna Balaprakash, Eric Rask, Jane Macfarlane: **Graph-Partitioning-Based Diffusion Convolutional Recurrent Neural Network for Large-Scale Traffic Forecasting** `Paper `_, `Dataset `_ 48 | 49 | -------------------------------------------------------------------------------- /examples/indexBatching/A3TGCN/metr_la_main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import csv 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_geometric.nn import GCNConv 7 | from torch_geometric_temporal.nn.recurrent import A3TGCN2 8 | from torch_geometric_temporal.dataset import METRLADatasetLoader 9 | import argparse 10 | 11 | 12 | def parse_arguments(): 13 | parser = argparse.ArgumentParser(description="Demo of index batching with the METR-LA dataset") 14 | 15 | parser.add_argument( 16 | "-e", "--epochs", type=int, default=100, help="The desired number of training epochs" 17 | ) 18 | parser.add_argument( 19 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size" 20 | ) 21 | parser.add_argument( 22 | "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU" 23 | ) 24 | parser.add_argument( 25 | "-d", "--debug", type=str, default="False", help="Print values for debugging" 26 | ) 27 | return parser.parse_args() 28 | 29 | # Making the model 30 | class TemporalGNN(torch.nn.Module): 31 | def __init__(self, node_features, periods, batch_size): 32 | super(TemporalGNN, self).__init__() 33 | # Attention Temporal Graph Convolutional Cell 34 | self.tgnn = A3TGCN2(in_channels=node_features, out_channels=32, periods=periods, batch_size=batch_size) # node_features=2, periods=12 35 | # Equals single-shot prediction 36 | self.linear = torch.nn.Linear(32, periods) 37 | 38 | def forward(self, x, edge_index): 39 | """ 40 | x = Node features for T time steps 41 | edge_index = Graph edge indices 42 | """ 43 | h = self.tgnn(x, edge_index) # x [b, 207, 2, 12] returns h [b, 207, 12] 44 | h = F.relu(h) 45 | h = self.linear(h) 46 | return h 47 | 48 | 49 | 50 | def train(train_dataloader, val_dataloader, batch_size, epochs, edges, DEVICE, allGPU=False, debug=False): 51 | 52 | # Create the model and optimizer 53 | model = TemporalGNN(node_features=2, periods=12, batch_size=batch_size).to(DEVICE) 54 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 55 | loss_fn = torch.nn.MSELoss() 56 | 57 | stats = [] 58 | t_mse = [] 59 | v_mse = [] 60 | 61 | 62 | edges = edges.to(DEVICE) 63 | for epoch in range(epochs): 64 | step = 0 65 | loss_list = [] 66 | t1 = time.time() 67 | i = 1 68 | total = len(train_dataloader) 69 | for batch in train_dataloader: 70 | X_batch, y_batch = batch 71 | 72 | # Need to permute based on expected input shape for ATGCN 73 | if allGPU: 74 | X_batch = X_batch.permute(0, 2, 3, 1) 75 | y_batch = y_batch[...,0].permute(0, 2, 1) 76 | else: 77 | X_batch = X_batch.permute(0, 2, 3, 1).to(DEVICE) 78 | y_batch = y_batch[...,0].permute(0, 2, 1).to(DEVICE) 79 | 80 | 81 | 82 | y_hat = model(X_batch, edges) # Get model predictions 83 | loss = loss_fn(y_hat, y_batch) # Mean squared error (take the square root for RMSE) 84 | loss.backward() 85 | optimizer.step() 86 | optimizer.zero_grad() 87 | step = step + 1 88 | loss_list.append(loss.item()) 89 | 90 | if debug: 91 | print(f"Train Batch: {i}/{total}", end="\r") 92 | i += 1
93 | 94 | 95 | model.eval() 96 | step = 0 97 | # Store for analysis 98 | total_loss = [] 99 | i = 1 100 | total = len(val_dataloader) 101 | if debug: 102 | print(" ", end="\r") 103 | with torch.no_grad(): 104 | for batch in val_dataloader: 105 | X_batch, y_batch = batch 106 | 107 | 108 | # Need to permute based on expected input shape for ATGCN 109 | if allGPU: 110 | X_batch = X_batch.permute(0, 2, 3, 1) 111 | y_batch = y_batch[...,0].permute(0, 2, 1) 112 | else: 113 | X_batch = X_batch.permute(0, 2, 3, 1).to(DEVICE) 114 | y_batch = y_batch[...,0].permute(0, 2, 1).to(DEVICE) 115 | 116 | # Get model predictions 117 | y_hat = model(X_batch, edges) 118 | # Mean squared error 119 | loss = loss_fn(y_hat, y_batch) 120 | total_loss.append(loss.item()) 121 | 122 | if debug: 123 | print(f"Val Batch: {i}/{total}", end="\r") 124 | i += 1 125 | 126 | 127 | t2 = time.time() 128 | print("Epoch {} time: {:.4f} Train MSE: {:.4f} Val MSE: {:.4f}".format(epoch, t2 - t1, sum(loss_list)/len(loss_list), sum(total_loss)/len(total_loss))) 129 | stats.append([epoch, t2-t1, sum(loss_list)/len(loss_list), sum(total_loss)/len(total_loss)]) 130 | t_mse.append(sum(loss_list)/len(loss_list)) 131 | v_mse.append(sum(total_loss)/len(total_loss)) 132 | return min(t_mse), min(v_mse) 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | def main(): 142 | args = parse_arguments() 143 | allGPU = args.gpu.lower() in ["true", "y", "t", "yes"] 144 | debug = args.debug.lower() in ["true", "y", "t", "yes"] 145 | batch_size = args.batch_size 146 | epochs = args.epochs 147 | 148 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 149 | shuffle = True 150 | 151 | 152 | start = time.time() 153 | p1 = time.time() 154 | indexLoader = METRLADatasetLoader(index=True) 155 | if allGPU: 156 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = indexLoader.get_index_dataset(batch_size=batch_size, shuffle=shuffle, allGPU=0) 157 | else: 158 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = indexLoader.get_index_dataset(batch_size=batch_size, shuffle=shuffle) 159 | p2 = time.time() 160 | t_mse, v_mse = train(train_dataloader, val_dataloader, batch_size, epochs, edges, device, allGPU=allGPU, debug=debug) 161 | end = time.time() 162 | 163 | print(f"Runtime: {round(end - start,2)}; T-MSE: {round(t_mse, 3)}; V-MSE: {round(v_mse, 3)}") 164 | 165 | if __name__ == "__main__": 166 | main() -------------------------------------------------------------------------------- /examples/indexBatching/DCRNN/chicken_pox_main.py: -------------------------------------------------------------------------------- 1 | 2 | from torch_geometric_temporal.nn.recurrent import BatchedDCRNN as DCRNN 3 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 4 | import torch.optim as optim 5 | 6 | import argparse 7 | import csv 8 | import os 9 | import time 10 | from utils import * 11 | 12 | 13 | def parse_arguments(): 14 | """Parse command-line arguments.""" 15 | parser = argparse.ArgumentParser(description="Demo of index batching with chickenpox dataset") 16 | 17 | parser.add_argument( 18 | "-e", "--epochs", type=int, default=100, help="The desired number of training epochs" 19 | ) 20 | parser.add_argument( 21 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size" 22 | ) 23 | parser.add_argument( 24 | "-m", "--mode", type=str, default="base", help="Which version to run" 25 | ) 26 | parser.add_argument( 27 | "-g", "--gpu", type=str, default="False", help="Should data be
preprocessed and migrated directly to the GPU" 28 | ) 29 | parser.add_argument( 30 | "-d", "--debug", type=str, default="False", help="Print values for debugging" 31 | ) 32 | 33 | return parser.parse_args() 34 | 35 | def train(train_dataloader, val_dataloader, edge_index, edge_weight, epochs, seq_length, num_nodes, num_features, 36 | allGPU=False, debug=False): 37 | 38 | 39 | # Move to GPU 40 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 41 | edge_index = edge_index.to(device) 42 | edge_weight = edge_weight.to(device) 43 | 44 | 45 | # Initialize model 46 | model = DCRNN(num_features, num_features, K=3).to(device) 47 | 48 | # Define the optimizer 49 | optimizer = optim.Adam(model.parameters(), lr=0.001) 50 | 51 | # Training loop 52 | stats = [] 53 | min_t = 9999 54 | min_v = 9999 55 | for epoch in range(epochs): 56 | # Training phase 57 | model.train() 58 | train_loss = 0.0 59 | i = 1 60 | total = len(train_dataloader) 61 | t1 = time.time() 62 | for batch in train_dataloader: 63 | X_batch, y_batch = batch 64 | 65 | if allGPU == False: 66 | 67 | X_batch = X_batch.to(device).float() 68 | y_batch = y_batch.to(device).float() 69 | 70 | # Forward pass 71 | outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels) 72 | 73 | # Compute the masked MAE loss 74 | loss = masked_mae_loss(outputs, y_batch) 75 | 76 | # Backward pass 77 | optimizer.zero_grad() 78 | loss.backward() 79 | optimizer.step() 80 | 81 | train_loss += loss.item() 82 | if debug: 83 | print(f"Train Batch: {i}/{total}", end="\r") 84 | i += 1 85 | 86 | 87 | train_loss /= len(train_dataloader) 88 | 89 | # Validation phase 90 | model.eval() 91 | val_loss = 0.0 92 | i = 0 93 | if debug: 94 | print(" ", end="\r") 95 | total = len(val_dataloader) 96 | with torch.no_grad(): 97 | for batch in val_dataloader: 98 | X_batch, y_batch = batch 99 | 100 | if allGPU == False: 101 | X_batch = X_batch.to(device).float() 102 | y_batch = y_batch.to(device).float() 103 | 104 | # Forward pass 105 | outputs = model(X_batch, edge_index, edge_weight) 106 | 107 | # Compute the masked MAE loss 108 | loss = masked_mae_loss(outputs, y_batch) 109 | val_loss += loss.item() 110 | 111 | if debug: 112 | print(f"Val Batch: {i}/{total}", end="\r") 113 | i += 1 114 | 115 | val_loss /= len(val_dataloader) 116 | t2 = time.time() 117 | # Print epoch metrics 118 | print(f"Epoch {epoch + 1}/{epochs}, Runtime: {t2 - t1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}", flush=True) 119 | stats.append([epoch+1, t2 - t1, train_loss, val_loss]) 120 | 121 | min_t = min(min_t, train_loss) 122 | min_v = min(min_v, val_loss) 123 | 124 | return min_t, min_v 125 | 126 | def main(): 127 | 128 | args = parse_arguments() 129 | allGPU = args.gpu.lower() in ["true", "y", "t"] 130 | debug = args.debug.lower() in ["true", "y", "t"] 131 | batch_size = args.batch_size 132 | epochs = args.epochs 133 | 134 | t1 = time.time() 135 | loader = ChickenpoxDatasetLoader(index=True) 136 | if allGPU == True: 137 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights = loader.get_index_dataset(allGPU=0, batch_size=batch_size) 138 | else: 139 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights = loader.get_index_dataset(batch_size=batch_size) 140 | 141 | t_min, v_min = train(train_dataloader, val_dataloader, edges, edge_weights, epochs, 4, 20, 1, allGPU=allGPU, debug=debug) 142 | t2 = time.time() 143 | print(f"Runtime: {round(t2 - t1,2)};
Best Train MAE: {t_min}; Best Validation MAE: {v_min}") 144 | 145 | if __name__ == "__main__": 146 | main() -------------------------------------------------------------------------------- /examples/indexBatching/DCRNN/pems_allLA_main.py: -------------------------------------------------------------------------------- 1 | from torch_geometric_temporal.nn.recurrent import BatchedDCRNN as DCRNN 2 | from torch_geometric_temporal.dataset import PemsAllLADatasetLoader 3 | import torch.optim as optim 4 | 5 | import argparse 6 | import csv 7 | import os 8 | import time 9 | from utils import * 10 | 11 | 12 | def parse_arguments(): 13 | parser = argparse.ArgumentParser(description="Demo of index batching with PemsAllLA dataset") 14 | 15 | parser.add_argument( 16 | "-e", "--epochs", type=int, default=30, help="The desired number of training epochs" 17 | ) 18 | parser.add_argument( 19 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size" 20 | ) 21 | parser.add_argument( 22 | "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU" 23 | ) 24 | parser.add_argument( 25 | "-d", "--debug", type=str, default="False", help="Print values for debugging" 26 | ) 27 | return parser.parse_args() 28 | 29 | 30 | 31 | 32 | def train(train_dataloader, val_dataloader, mean, std, edges, edge_weights, epochs, seq_length, num_nodes, num_features, allGPU=False, debug=False): 33 | 34 | # Move to GPU 35 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 36 | edge_index = edges.to(device) 37 | edge_weight = edge_weights.to(device) 38 | 39 | if allGPU == False: 40 | mean = mean.to(device) 41 | std = std.to(device) 42 | 43 | # Initialize model 44 | model = DCRNN(num_features, num_features, K=3).to(device) 45 | 46 | # Define the optimizer 47 | optimizer = optim.Adam(model.parameters(), lr=0.001) 48 | 49 | # Training loop 50 | stats = [] 51 | min_t = 9999 52 | min_v = 9999 53 | for epoch in range(epochs): 54 | # Training phase 55 | model.train() 56 | train_loss = 0.0 57 | i = 1 58 | total = len(train_dataloader) 59 | t1 = time.time() 60 | for batch in train_dataloader: 61 | X_batch, y_batch = batch 62 | 63 | if allGPU == False: 64 | 65 | X_batch = X_batch.to(device).float() 66 | y_batch = y_batch.to(device).float() 67 | 68 | # Forward pass 69 | outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels) 70 | 71 | # Compute the masked MAE loss on the de-normalized values 72 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean) 73 | 74 | # Backward pass 75 | optimizer.zero_grad() 76 | loss.backward() 77 | optimizer.step() 78 | 79 | train_loss += loss.item() 80 | if debug: 81 | print(f"Train Batch: {i}/{total}", end="\r") 82 | i += 1 83 | 84 | 85 | 86 | train_loss /= len(train_dataloader) 87 | 88 | # Validation phase 89 | model.eval() 90 | val_loss = 0.0 91 | i = 0 92 | if debug: 93 | print(" ", end="\r") 94 | total = len(val_dataloader) 95 | with torch.no_grad(): 96 | for batch in val_dataloader: 97 | X_batch, y_batch = batch 98 | 99 | if allGPU == False: 100 | X_batch = X_batch.to(device).float() 101 | y_batch = y_batch.to(device).float() 102 | 103 | # Forward pass 104 | outputs = model(X_batch, edge_index, edge_weight) 105 | 106 | # Compute the masked MAE loss on the de-normalized values 107 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean) 108 | val_loss += loss.item() 109 | if debug: 110 | print(f"Val Batch: {i}/{total}",
end="\r") 111 | i += 1 112 | 113 | 114 | val_loss /= len(val_dataloader) 115 | t2 = time.time() 116 | # Print epoch metrics 117 | print(f"Epoch {epoch + 1}/{epochs}, Runtime: {t2 - t1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}", flush=True) 118 | stats.append([epoch+1, t2 - t1, train_loss, val_loss]) 119 | 120 | min_t = min(min_t, train_loss) 121 | min_v = min(min_v, val_loss) 122 | 123 | return min_t, min_v 124 | 125 | def main(): 126 | 127 | 128 | 129 | args = parse_arguments() 130 | allGPU = args.gpu.lower() in ["true", "y", "t", "yes"] 131 | debug = args.debug.lower() in ["true", "y", "t", "yes"] 132 | batch_size = args.batch_size 133 | epochs = args.epochs 134 | 135 | t1 = time.time() 136 | loader = PemsAllLADatasetLoader(index=True) 137 | if allGPU: 138 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, means, stds = loader.get_index_dataset(allGPU=0, batch_size=batch_size) 139 | else: 140 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, means, stds = loader.get_index_dataset(batch_size=batch_size) 141 | 142 | 143 | t_min, v_min = train(train_dataloader, val_dataloader, means, stds, edges, edge_weights, epochs, 12,11160,2, allGPU=allGPU, debug=debug) 144 | t2 = time.time() 145 | print(f"Runtime: {round(t2 - t1,2)}; Best Train MSE: {t_min}; Best Validation MSE: {v_min}") 146 | 147 | if __name__ == "__main__": 148 | main() -------------------------------------------------------------------------------- /examples/indexBatching/DCRNN/pems_bay_main.py: -------------------------------------------------------------------------------- 1 | from torch_geometric_temporal.nn.recurrent import BatchedDCRNN as DCRNN 2 | from torch_geometric_temporal.dataset import PemsBayDatasetLoader 3 | import torch.optim as optim 4 | 5 | import argparse 6 | import csv 7 | import os 8 | import time 9 | from utils import * 10 | 11 | 12 | 13 | def parse_arguments(): 14 | parser = argparse.ArgumentParser(description="Demo of index batching with PemsBay dataset") 15 | 16 | parser.add_argument( 17 | "-e", "--epochs", type=int, default=100, help="The desired number of training epochs" 18 | ) 19 | parser.add_argument( 20 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size" 21 | ) 22 | parser.add_argument( 23 | "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU" 24 | ) 25 | parser.add_argument( 26 | "-d", "--debug", type=str, default="False", help="Print values for debugging" 27 | ) 28 | return parser.parse_args() 29 | 30 | def train(train_dataloader, val_dataloader, mean, std, edge_index, edge_weight, epochs, seq_length, num_nodes, num_features, 31 | allGPU=False, debug=False): 32 | 33 | 34 | 35 | # Move to GPU 36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 37 | edge_index = edge_index.to(device) 38 | edge_weight = edge_weight.to(device) 39 | 40 | if allGPU == False: 41 | mean = mean.to(device) 42 | std = std.to(device) 43 | 44 | # Initialize model 45 | model = DCRNN(num_features, num_features, K=3).to(device) 46 | 47 | # Define optimizer and loss function 48 | optimizer = optim.Adam(model.parameters(), lr=0.001) 49 | 50 | # Training loop 51 | stats = [] 52 | min_t = 9999 53 | min_v = 9999 54 | for epoch in range(epochs): 55 | # Training phase 56 | model.train() 57 | train_loss = 0.0 58 | i = 1 59 | total = len(train_dataloader) 60 | t1 = time.time() 61 | for batch in train_dataloader: 62 | X_batch, y_batch = batch 63 | 64 | if allGPU == False: 65 | 
X_batch = X_batch.to(device).float() 66 | y_batch = y_batch.to(device).float() 67 | # Forward pass 68 | outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels) 69 | 70 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean) 71 | 72 | # Backward pass 73 | optimizer.zero_grad() 74 | loss.backward() 75 | optimizer.step() 76 | 77 | train_loss += loss.item() 78 | if debug: 79 | print(f"Train Batch: {i}/{total}", end="\r") 80 | i += 1 81 | 82 | 83 | train_loss /= len(train_dataloader) 84 | 85 | # Validation phase 86 | model.eval() 87 | val_loss = 0.0 88 | i = 0 89 | if debug: 90 | print(" ", end="\r") 91 | total = len(val_dataloader) 92 | with torch.no_grad(): 93 | for batch in val_dataloader: 94 | X_batch, y_batch = batch 95 | 96 | if allGPU == False: 97 | X_batch = X_batch.to(device).float() 98 | y_batch = y_batch.to(device).float() 99 | 100 | # Forward pass 101 | outputs = model(X_batch, edge_index, edge_weight) 102 | 103 | # Compute the masked MAE loss on the de-normalized values 104 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean) 105 | val_loss += loss.item() 106 | if debug: 107 | print(f"Val Batch: {i}/{total}", end="\r") 108 | i += 1 109 | 110 | val_loss /= len(val_dataloader) 111 | t2 = time.time() 112 | # Print epoch metrics 113 | print(f"Epoch {epoch + 1}/{epochs}, Runtime: {t2 - t1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}", flush=True) 114 | stats.append([epoch+1, t2 - t1, train_loss, val_loss]) 115 | 116 | min_t = min(min_t, train_loss) 117 | min_v = min(min_v, val_loss) 118 | 119 | 120 | 121 | return min_t, min_v 122 | 123 | def main(): 124 | 125 | 126 | 127 | args = parse_arguments() 128 | allGPU = args.gpu.lower() in ["true", "y", "t", "yes"] 129 | debug = args.debug.lower() in ["true", "y", "t", "yes"] 130 | batch_size = args.batch_size 131 | epochs = args.epochs 132 | t1 = time.time() 133 | loader = PemsBayDatasetLoader(index=True) 134 | 135 | if allGPU == True: 136 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = loader.get_index_dataset(allGPU=0, batch_size=batch_size) 137 | else: 138 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = loader.get_index_dataset(batch_size=batch_size) 139 | 140 | t_min, v_min = train(train_dataloader, val_dataloader, mean, std, edges, edge_weights, epochs, 12, 325, 2, allGPU=allGPU, debug=debug) 141 | t2 = time.time() 142 | print(f"Runtime: {round(t2 - t1,2)}; Best Train MAE: {t_min}; Best Validation MAE: {v_min}") 143 | 144 | 145 | if __name__ == "__main__": 146 | main() -------------------------------------------------------------------------------- /examples/indexBatching/DCRNN/pems_main.py: -------------------------------------------------------------------------------- 1 | from torch_geometric_temporal.nn.recurrent import BatchedDCRNN as DCRNN 2 | from torch_geometric_temporal.dataset import PemsDatasetLoader 3 | import torch.optim as optim 4 | 5 | import argparse 6 | import csv 7 | import os 8 | import time 9 | from utils import * 10 | 11 | 12 | def parse_arguments(): 13 | parser = argparse.ArgumentParser(description="Demo of index batching with Pems dataset") 14 | 15 | parser.add_argument( 16 | "-e", "--epochs", type=int, default=30, help="The desired number of training epochs" 17 | ) 18 | parser.add_argument( 19 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size" 20 | ) 21 | parser.add_argument( 22 | "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and
migrated directly to the GPU" 23 | ) 24 | parser.add_argument( 25 | "-d", "--debug", type=str, default="False", help="Print values for debugging" 26 | ) 27 | return parser.parse_args() 28 | 29 | 30 | 31 | 32 | def train(train_dataloader, val_dataloader, mean, std, edges, edge_weights, epochs, seq_length, num_nodes, num_features, allGPU=False, debug=False): 33 | 34 | # Move to GPU 35 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 36 | edge_index = edges.to(device) 37 | edge_weight = edge_weights.to(device) 38 | 39 | if allGPU == False: 40 | mean = mean.to(device) 41 | std = std.to(device) 42 | 43 | # Initialize model 44 | model = DCRNN(num_features, num_features, K=3).to(device) 45 | 46 | # Define the optimizer 47 | optimizer = optim.Adam(model.parameters(), lr=0.001) 48 | 49 | # Training loop 50 | stats = [] 51 | min_t = 9999 52 | min_v = 9999 53 | for epoch in range(epochs): 54 | # Training phase 55 | model.train() 56 | train_loss = 0.0 57 | i = 1 58 | total = len(train_dataloader) 59 | t1 = time.time() 60 | for batch in train_dataloader: 61 | X_batch, y_batch = batch 62 | 63 | if allGPU == False: 64 | 65 | X_batch = X_batch.to(device).float() 66 | y_batch = y_batch.to(device).float() 67 | 68 | # Forward pass 69 | outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels) 70 | 71 | # Compute the masked MAE loss on the de-normalized values 72 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean) 73 | 74 | # Backward pass 75 | optimizer.zero_grad() 76 | loss.backward() 77 | optimizer.step() 78 | 79 | train_loss += loss.item() 80 | if debug: 81 | print(f"Train Batch: {i}/{total}", end="\r") 82 | i += 1 83 | 84 | 85 | 86 | train_loss /= len(train_dataloader) 87 | 88 | # Validation phase 89 | model.eval() 90 | val_loss = 0.0 91 | i = 0 92 | if debug: 93 | print(" ", end="\r") 94 | total = len(val_dataloader) 95 | with torch.no_grad(): 96 | for batch in val_dataloader: 97 | X_batch, y_batch = batch 98 | 99 | if allGPU == False: 100 | X_batch = X_batch.to(device).float() 101 | y_batch = y_batch.to(device).float() 102 | 103 | # Forward pass 104 | outputs = model(X_batch, edge_index, edge_weight) 105 | 106 | # Compute the masked MAE loss on the de-normalized values 107 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean) 108 | val_loss += loss.item() 109 | if debug: 110 | print(f"Val Batch: {i}/{total}", end="\r") 111 | i += 1 112 | 113 | 114 | val_loss /= len(val_dataloader) 115 | t2 = time.time() 116 | # Print epoch metrics 117 | print(f"Epoch {epoch + 1}/{epochs}, Runtime: {t2 - t1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}", flush=True) 118 | stats.append([epoch+1, t2 - t1, train_loss, val_loss]) 119 | 120 | min_t = min(min_t, train_loss) 121 | min_v = min(min_v, val_loss) 122 | 123 | return min_t, min_v 124 | 125 | def main(): 126 | 127 | 128 | 129 | args = parse_arguments() 130 | allGPU = args.gpu.lower() in ["true", "y", "t", "yes"] 131 | debug = args.debug.lower() in ["true", "y", "t", "yes"] 132 | batch_size = args.batch_size 133 | epochs = args.epochs 134 | 135 | t1 = time.time() 136 | loader = PemsDatasetLoader(index=True) 137 | if allGPU: 138 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, means, stds = loader.get_index_dataset(allGPU=0, batch_size=batch_size) 139 | else: 140 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, means, stds = loader.get_index_dataset(batch_size=batch_size) 141 |
142 | 
143 | t_min, v_min = train(train_dataloader, val_dataloader, means, stds, edges, edge_weights, epochs, 12, 11160, 2, allGPU=allGPU, debug=debug)
144 | t2 = time.time()
145 | print(f"Runtime: {round(t2 - t1, 2)}; Best Train MAE: {t_min}; Best Validation MAE: {v_min}")
146 | 
147 | if __name__ == "__main__":
148 | main()
--------------------------------------------------------------------------------
/examples/indexBatching/DCRNN/submit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #PBS -l select=5:system=polaris
3 | #PBS -l place=scatter
4 | #PBS -l filesystems=home:eagle
5 | #PBS -l walltime=00:30:00
6 | #PBS -q 
7 | #PBS -A 
8 | #PBS -o train.out
9 | #PBS -e train.err
10 | 
11 | nodes=4  # worker nodes; the extra allocated node hosts the Dask scheduler and client
12 | gpus_per_node=4
13 | 
14 | # should the training utilize GPU-index-batching
15 | allGPU=True
16 | 
17 | # which dataset to train on; valid options include "pems-bay", "pemsAllLA", and "pems"
18 | dataset="pems"
19 | 
20 | # total workers
21 | total=$((gpus_per_node * nodes))
22 | 
23 | readarray -t all_nodes < "$NODEFILE"
24 | 
25 | # use the first node for the dask scheduler and client
26 | scheduler_node=${all_nodes[0]}
27 | monitor_node=${all_nodes[1]}
28 | 
29 | # all nodes but the first are used for workers
30 | tail -n +2 $NODEFILE > worker_nodefile.txt
31 | 
32 | echo "Launching scheduler"
33 | mpiexec -n 1 --ppn 1 --cpu-bind none --hosts $scheduler_node dask scheduler --scheduler-file cluster.info &
34 | scheduler_pid=$!
35 | 
36 | # wait for the scheduler to generate the cluster config file
37 | while ! [ -f cluster.info ]; do
38 | sleep 1
39 | echo .
40 | done
41 | 
42 | echo "$total workers launching"
43 | mpiexec -n $total --ppn $gpus_per_node --cpu-bind none --hostfile worker_nodefile.txt dask worker --local-directory /local/scratch --scheduler-file cluster.info --nthreads 8 --memory-limit 512GB &
44 | 
45 | # give workers a bit to launch
46 | sleep 5
47 | 
48 | echo "Launching client"
49 | mpiexec -n 1 --ppn 1 --cpu-bind none --hosts $scheduler_node `which python3` pems_ddp.py --dask-cluster-file cluster.info -np $gpus_per_node -g $allGPU --dataset $dataset
--------------------------------------------------------------------------------
/examples/indexBatching/DCRNN/utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import torch
3 | import numpy as np
4 | import pandas as pd
5 | import csv
6 | import os
7 | import time
8 | import scipy.sparse as sp
9 | 
10 | def masked_mae_loss(y_pred, y_true):
11 | # mask out entries where the ground truth is zero (missing readings), then rescale
12 | mask = (y_true != 0).float()
13 | mask /= mask.mean()
14 | loss = torch.abs(y_pred - y_true)
15 | loss = loss * mask
16 | # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
17 | loss[loss != loss] = 0
18 | return loss.mean()
19 | 
20 | 
21 | 
22 | def load_graph_data(pkl_filename):
23 | sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename)
24 | return sensor_ids, sensor_id_to_ind, adj_mx
25 | 
26 | 
27 | def load_pickle(pickle_file):
28 | try:
29 | with open(pickle_file, 'rb') as f:
30 | pickle_data = pickle.load(f)
31 | except UnicodeDecodeError as e:
32 | with open(pickle_file, 'rb') as f:
33 | pickle_data = pickle.load(f, encoding='latin1')
34 | except Exception as e:
35 | print('Unable to load data ', pickle_file, ':', e)
36 | raise
37 | return pickle_data
38 | 
39 | def adjacency_to_edge_index(adj_mx):
40 | """
41 | Convert an adjacency matrix to edge_index and edge_weight.
42 | Args:
43 | adj_mx (np.ndarray): Adjacency matrix of shape (num_nodes, num_nodes).
44 | Returns:
45 | edge_index (torch.LongTensor): Shape (2, num_edges), source and target nodes.
46 | edge_weight (torch.FloatTensor): Shape (num_edges,), edge weights.
47 | """
48 | # Convert to sparse matrix
49 | adj_sparse = sp.coo_matrix(adj_mx)
50 | 
51 | # Extract edge indices and weights
52 | edge_index = torch.tensor(
53 | np.vstack((adj_sparse.row, adj_sparse.col)), dtype=torch.long
54 | )
55 | edge_weight = torch.tensor(adj_sparse.data, dtype=torch.float)
56 | 
57 | return edge_index, edge_weight
58 | 
59 | # standard approach: see https://github.com/chnsh/DCRNN_PyTorch
60 | def benchmark_preprocess(h5File, dataset, key=None):
61 | 
62 | if "pems" in dataset.lower():
63 | df = pd.read_hdf(h5File)
64 | 
65 | x_offsets = np.sort(
66 | # np.concatenate(([-week_size + 1, -day_size + 1], np.arange(-11, 1, 1)))
67 | np.concatenate((np.arange(-11, 1, 1),))
68 | )
69 | # Predict the next one hour
70 | y_offsets = np.sort(np.arange(1, 13, 1))
71 | num_samples, num_nodes = df.shape
72 | 
73 | data = np.expand_dims(df.values, axis=-1)
74 | data_list = [data]
75 | add_time_in_day = True
76 | if add_time_in_day:
77 | time_ind = (df.index.values - df.index.values.astype("datetime64[D]")) / np.timedelta64(1, "D")
78 | time_in_day = np.tile(time_ind, [1, num_nodes, 1]).transpose((2, 1, 0))
79 | data_list.append(time_in_day)
80 | 
81 | data = np.concatenate(data_list, axis=-1)
82 | 
83 | 
84 | x, y = [], []
85 | # t is the index of the last observation.
86 | min_t = abs(min(x_offsets))
87 | max_t = abs(num_samples - abs(max(y_offsets)))  # Exclusive
88 | for t in range(min_t, max_t):
89 | x_t = data[t + x_offsets, ...]
90 | 
91 | 
92 | y_t = data[t + y_offsets, ...]
93 | x.append(x_t)
94 | y.append(y_t)
95 | 
96 | 
97 | 
98 | x = np.stack(x, axis=0)
99 | y = np.stack(y, axis=0)
100 | 
101 | return torch.tensor(x, dtype=torch.float), torch.tensor(y, dtype=torch.float)
102 | 
103 | def collect_metrics():
104 | import psutil
105 | from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlShutdown
106 | 
107 | try:
108 | # Initialize NVML for GPU metrics
109 | nvmlInit()
110 | 
111 | # Sample memory usage once per second until the run signals completion
112 | 
113 | data = []
114 | max_gpu_mem = -1
115 | max_system_mem = -1
116 | while True:
117 | # Collect system memory usage
118 | timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
119 | mem = psutil.virtual_memory()
120 | total_rss = sum(proc.memory_info().rss for proc in psutil.process_iter(attrs=['memory_info']))
121 | system_memory_used = total_rss / (1024**2) # Convert to MB
122 | system_memory_total = mem.total / (1024**2) # Convert to MB
123 | 
124 | # Collect GPU memory usage (device 0)
125 | 
126 | handle = nvmlDeviceGetHandleByIndex(0)
127 | info = nvmlDeviceGetMemoryInfo(handle)
128 | gpu_memory_used = info.used / (1024**2) # Convert to MB
129 | gpu_memory_total = info.total / (1024**2) # Convert to MB
130 | 
131 | max_gpu_mem = max(gpu_memory_used, max_gpu_mem)
132 | max_system_mem = max(system_memory_used, max_system_mem)
133 | data.append([timestamp, system_memory_used, system_memory_total, gpu_memory_used, gpu_memory_total])
134 | 
135 | if os.path.isfile("flag.txt"):
136 | os.remove("flag.txt")
137 | break
138 | 
139 | if os.path.isfile("stats.csv"):
140 | with open("system_stats.csv", mode="w", newline="") as f:
141 | writer = csv.writer(f)
142 | 
143 | # Write headers to the CSV file
144 | headers = [
145 | "Timestamp",
146 | "System_Memory_Used",
147 | "System_Memory_Total",
148 | "GPU_Memory_Used",
"GPU_Memory_Used", 149 | "GPU_Memory_Total" 150 | ] 151 | writer.writerow(headers) 152 | writer.writerows(data) 153 | 154 | 155 | df = pd.read_csv("stats.csv") 156 | df['system_memory_total'] = system_memory_total 157 | df['max_system_mem'] = max_system_mem 158 | df['gpu_memory_total'] = gpu_memory_total 159 | df['max_gpu_mem'] = max_gpu_mem 160 | 161 | df.to_csv("stats.csv", index=False) 162 | break 163 | time.sleep(1) 164 | 165 | except Exception as e: 166 | print("Error in collecting metrics:", str(e)) 167 | 168 | finally: 169 | # Shutdown NVML 170 | nvmlShutdown() -------------------------------------------------------------------------------- /examples/indexBatching/DCRNN/windmill_main.py: -------------------------------------------------------------------------------- 1 | from torch_geometric_temporal.nn.recurrent import BatchedDCRNN as DCRNN 2 | from torch_geometric_temporal.dataset import WindmillOutputLargeDatasetLoader 3 | import torch.optim as optim 4 | 5 | import argparse 6 | import time 7 | import csv 8 | import os 9 | from utils import * 10 | 11 | 12 | 13 | def parse_arguments(): 14 | """Parse command-line arguments.""" 15 | parser = argparse.ArgumentParser(description="Demo of index batching with WindmillLarge dataset") 16 | 17 | parser.add_argument( 18 | "-e", "--epochs", type=int, default=100, help="The desired number of training epochs" 19 | ) 20 | parser.add_argument( 21 | "-bs", "--batch-size", type=int, default=64, help="The desired batch size" 22 | ) 23 | parser.add_argument( 24 | "-m", "--mode", type=str, default="base", help="Which version to run" 25 | ) 26 | parser.add_argument( 27 | "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU" 28 | ) 29 | parser.add_argument( 30 | "-d", "--debug", type=str, default="False", help="Print values for debugging" 31 | ) 32 | 33 | return parser.parse_args() 34 | 35 | 36 | def train(train_dataloader, val_dataloader, mean, std, edge_index,edge_weight, epochs, seq_length, num_nodes, num_features, 37 | allGPU=False, debug=False): 38 | 39 | 40 | # Move to GPU 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | edge_index = edge_index.to(device) 43 | edge_weight = edge_weight.to(device) 44 | 45 | if allGPU == False: 46 | mean = mean.to(device) 47 | std = std.to(device) 48 | 49 | # Initialize model 50 | model = DCRNN(num_features, num_features, K=3).to(device) 51 | 52 | # Define optimizer and loss function 53 | optimizer = optim.Adam(model.parameters(), lr=0.001) 54 | 55 | # Training loop 56 | stats = [] 57 | min_t = 9999 58 | min_v = 9999 59 | for epoch in range(epochs): 60 | # Training phase 61 | model.train() 62 | train_loss = 0.0 63 | i = 1 64 | total = len(train_dataloader) 65 | t1 = time.time() 66 | for batch in train_dataloader: 67 | X_batch, y_batch = batch 68 | 69 | if allGPU == False: 70 | # print("casting") 71 | X_batch = X_batch.to(device).float() 72 | y_batch = y_batch.to(device).float() 73 | 74 | # Forward pass 75 | outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels) 76 | 77 | # Calculate loss (use only the first output channel, assuming it's the target) 78 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean) 79 | 80 | # Backward pass 81 | optimizer.zero_grad() 82 | loss.backward() 83 | optimizer.step() 84 | 85 | train_loss += loss.item() 86 | if debug: 87 | print(f"Train Batch: {i}/{total}", end="\r") 88 | i+=1 89 | 90 | 91 | 92 | train_loss /= 
93 | 
94 | # Validation phase
95 | model.eval()
96 | val_loss = 0.0
97 | i = 0
98 | if debug:
99 | print(" ", end="\r")
100 | total = len(val_dataloader)
101 | with torch.no_grad():
102 | for batch in val_dataloader:
103 | X_batch, y_batch = batch
104 | 
105 | if not allGPU:
106 | X_batch = X_batch.to(device).float()
107 | y_batch = y_batch.to(device).float()
108 | 
109 | # Forward pass
110 | outputs = model(X_batch, edge_index, edge_weight)
111 | 
112 | # Calculate loss
113 | loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
114 | val_loss += loss.item()
115 | 
116 | if debug:
117 | print(f"Val Batch: {i}/{total}", end="\r")
118 | i += 1
119 | 
120 | val_loss /= len(val_dataloader)
121 | t2 = time.time()
122 | 
123 | # Print epoch metrics
124 | print(f"Epoch {epoch + 1}/{epochs}, Runtime: {t2 - t1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}", flush=True)
125 | stats.append([epoch + 1, t2 - t1, train_loss, val_loss])
126 | 
127 | min_t = min(min_t, train_loss)
128 | min_v = min(min_v, val_loss)
129 | 
130 | return min_t, min_v
131 | 
132 | def main():
133 | 
134 | args = parse_arguments()
135 | allGPU = args.gpu.lower() in ["true", "y", "t", "yes"]
136 | debug = args.debug.lower() in ["true", "y", "t", "yes"]
137 | batch_size = args.batch_size
138 | epochs = args.epochs
139 | 
140 | t1 = time.time()
141 | loader = WindmillOutputLargeDatasetLoader(index=True)
142 | if allGPU:
143 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = loader.get_index_dataset(allGPU=0, batch_size=batch_size)  # preprocess directly on GPU device 0
144 | else:
145 | train_dataloader, val_dataloader, test_dataloader, edges, edge_weights, mean, std = loader.get_index_dataset(batch_size=batch_size)
146 | 
147 | 
148 | 
149 | t_min, v_min = train(train_dataloader, val_dataloader, mean, std, edges, edge_weights, epochs, 8, 319, 1, allGPU=allGPU, debug=debug)
150 | t2 = time.time()
151 | print(f"Runtime: {round(t2 - t1, 2)}; Best Train MAE: {t_min}; Best Validation MAE: {v_min}")
152 | 
153 | if __name__ == "__main__":
154 | main()
--------------------------------------------------------------------------------
/examples/indexBatching/README.md:
--------------------------------------------------------------------------------
1 | ## PyTorch Geometric Temporal - Index
2 | 
3 | Index-batching is a technique that reduces the memory cost of training ST-GNNs on spatiotemporal data with no impact on accuracy, enabling greater scalability and training on the full PeMS dataset without graph partitioning for the first time. Leveraging the reduced memory footprint, this technique also enables GPU-index-batching - a technique that performs preprocessing entirely in GPU memory and utilizes a single CPU-to-GPU mem-copy in place of batch-level CPU-to-GPU transfers throughout training. We implemented GPU-index-batching and index-batching for the following existing datasets and added two new datasets (highlighted in bold) to PyTorch Geometric Temporal (PGT):
4 | 
5 | * PeMS-Bay
6 | * Metr-LA
7 | * WindmillLarge
8 | * HungaryChickenpox
9 | * **PeMSAllLA**
10 | * **PeMS**
11 | 
12 | This folder contains examples with DCRNN and A3TGCN. We hope to build out our examples over time.
13 | 
14 | 
15 | Utilizing index-batching requires minimal modifications to the existing PGT workflow.
For example, the following is a sample training loop for a static-graph dataset with a temporal signal:
16 | 
17 | ```
18 | train_dataloader, _, _, edges, edge_weights, means, stds = loader.get_index_dataset(batch_size=batch_size)
19 | 
20 | for batch in train_dataloader:
21 | X_batch, y_batch = batch
22 | 
23 | # Forward pass
24 | outputs = model(X_batch, edges, edge_weights)
25 | 
26 | # Calculate loss
27 | loss = masked_mae_loss((outputs * stds) + means, (y_batch * stds) + means)
28 | 
29 | # Backward pass
30 | optimizer.zero_grad()
31 | loss.backward()
32 | optimizer.step()
33 | 
34 | 
35 | ```
36 | 
37 | The single-GPU examples in this repo can be executed as `python3 <script_name>.py` and have the following parameters:
38 | 
39 | | Argument | Short Form | Type | Default | Description |
40 | |----------|-----------|------|---------|-------------|
41 | | `--epochs` | `-e` | `int` | `30` | The desired number of training epochs. |
42 | | `--batch-size` | `-bs` | `int` | `64` | The desired batch size for training. |
43 | | `--gpu` | `-g` | `str` | `"False"` | Indicates whether data should be preprocessed and migrated directly to the GPU. Use `"True"` to enable GPU processing. |
44 | | `--debug` | `-d` | `str` | `"False"` | Enables debug mode, printing values for debugging. Use `"True"` to enable debugging. |
45 | 
46 | 
47 | 
48 | We also provide a multi-node, multi-GPU Dask-DDP training implementation for PeMS-Bay, PeMSAllLA, and the full PeMS dataset. It has the following parameters:
49 | 
50 | | Argument | Short Form | Type | Default | Description |
51 | |----------|-----------|------|---------|-------------|
52 | | `--epochs` | `-e` | `int` | `100` | The desired number of training epochs. |
53 | | `--batch-size` | `-bs` | `int` | `64` | The desired batch size for training. |
54 | | `--gpu` | `-g` | `str` | `"False"` | Indicates whether data should be preprocessed and migrated directly to the GPU. Use `"True"` to enable GPU processing. |
55 | | `--debug` | `-d` | `str` | `"False"` | Enables debug mode, printing values for debugging. Use `"True"` to enable debugging. |
56 | | `--dask-cluster-file` | N/A | `str` | `""` | Path to the Dask scheduler file for the Dask CLI interface. |
57 | | `--npar` | `-np` | `int` | `1` | The number of GPUs or workers per node. |
58 | | `--dataset` | N/A | `str` | `"pems-bay"` | Specifies which dataset is in use (e.g., PeMS-Bay, PeMS-All-LA, PeMS). |
59 | 
60 | To execute on a single node, omit the `--dask-cluster-file` argument. To run multi-node, set up a Dask scheduler and Dask workers using the [Dask command line interface](https://docs.dask.org/en/latest/deploying-cli.html) and pass the scheduler file via `--dask-cluster-file`. An example multi-GPU, multi-node script for Argonne's [Polaris supercomputer](https://www.alcf.anl.gov/polaris) is shown in `submit.sh`.
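
Under the hood, each sample produced by these dataloaders is just an integer index into one shared, preprocessed signal tensor; the input and target windows are sliced out at batch time instead of being materialized up front. The sketch below mirrors the pattern the loaders use internally (see `ChickenpoxDatasetLoader.get_index_dataset`); the synthetic data, shapes, and split ratio here are illustrative assumptions, not a fixed API:

```
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch_geometric_temporal.signal.index_dataset import IndexDataset

lags = 4                                              # window length (illustrative)
data = np.random.rand(500, 20, 1).astype(np.float32)  # stand-in (timesteps, nodes, features) signal

# Each candidate sample is an index t such that data[t : t + lags] is the input
# window and the following lags steps form the target window.
num_windows = data.shape[0] - (2 * lags - 1)
indices = np.arange(num_windows)
train_idx = indices[: round(0.7 * num_windows)]       # simple 70% temporal split

# gpu=False keeps slicing on the CPU; with GPU-index-batching the loaders instead
# move `data` to the device once and construct the dataset with gpu=True.
train_dataset = IndexDataset(train_idx, data, lags, gpu=False)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

for X_batch, y_batch in train_dataloader:
    pass  # X_batch/y_batch are the sliced input and target windows for the batch
```

Because every batch is gathered from the single shared array, peak host memory stays close to one copy of the preprocessed signal rather than one copy per sample window.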
61 | 62 | -------------------------------------------------------------------------------- /examples/recurrent/a3tgcn_example.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | except ImportError: 4 | def tqdm(iterable): 5 | return iterable 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric_temporal.nn.recurrent import A3TGCN 10 | 11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 12 | from torch_geometric_temporal.signal import temporal_signal_split 13 | 14 | loader = ChickenpoxDatasetLoader() 15 | 16 | dataset = loader.get_dataset() 17 | 18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2) 19 | 20 | class RecurrentGCN(torch.nn.Module): 21 | def __init__(self, node_features, periods): 22 | super(RecurrentGCN, self).__init__() 23 | self.recurrent = A3TGCN(node_features, 32, periods) 24 | self.linear = torch.nn.Linear(32, 1) 25 | 26 | def forward(self, x, edge_index, edge_weight): 27 | h = self.recurrent(x.view(x.shape[0], 1, x.shape[1]), edge_index, edge_weight) 28 | h = F.relu(h) 29 | h = self.linear(h) 30 | return h 31 | 32 | model = RecurrentGCN(node_features = 1, periods = 4) 33 | 34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 35 | 36 | model.train() 37 | 38 | for epoch in tqdm(range(50)): 39 | cost = 0 40 | for time, snapshot in enumerate(train_dataset): 41 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 42 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 43 | cost = cost / (time+1) 44 | cost.backward() 45 | optimizer.step() 46 | optimizer.zero_grad() 47 | 48 | model.eval() 49 | cost = 0 50 | for time, snapshot in enumerate(test_dataset): 51 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 52 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 53 | cost = cost / (time+1) 54 | cost = cost.item() 55 | print("MSE: {:.4f}".format(cost)) 56 | -------------------------------------------------------------------------------- /examples/recurrent/agcrn_example.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | except ImportError: 4 | def tqdm(iterable): 5 | return iterable 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric_temporal.nn.recurrent import AGCRN 10 | 11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 12 | from torch_geometric_temporal.signal import temporal_signal_split 13 | 14 | loader = ChickenpoxDatasetLoader() 15 | 16 | dataset = loader.get_dataset(lags=8) 17 | 18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2) 19 | 20 | class RecurrentGCN(torch.nn.Module): 21 | def __init__(self, node_features): 22 | super(RecurrentGCN, self).__init__() 23 | self.recurrent = AGCRN(number_of_nodes = 20, 24 | in_channels = node_features, 25 | out_channels = 2, 26 | K = 2, 27 | embedding_dimensions = 4) 28 | self.linear = torch.nn.Linear(2, 1) 29 | 30 | def forward(self, x, e, h): 31 | h_0 = self.recurrent(x, e, h) 32 | y = F.relu(h_0) 33 | y = self.linear(y) 34 | return y, h_0 35 | 36 | model = RecurrentGCN(node_features = 8) 37 | 38 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 39 | 40 | model.train() 41 | 42 | e = torch.empty(20, 4) 43 | 44 | torch.nn.init.xavier_uniform_(e) 45 | 46 | for epoch in tqdm(range(200)): 47 | cost = 0 48 | h = None 49 | for time, snapshot in enumerate(train_dataset): 50 | x = 
snapshot.x.view(1, 20, 8) 51 | y_hat, h = model(x, e, h) 52 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 53 | cost = cost / (time+1) 54 | cost.backward() 55 | optimizer.step() 56 | optimizer.zero_grad() 57 | 58 | model.eval() 59 | cost = 0 60 | for time, snapshot in enumerate(test_dataset): 61 | x = snapshot.x.view(1, 20, 8) 62 | y_hat, h = model(x, e, h) 63 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 64 | cost = cost / (time+1) 65 | cost = cost.item() 66 | print("MSE: {:.4f}".format(cost)) 67 | -------------------------------------------------------------------------------- /examples/recurrent/dcrnn_example.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | except ImportError: 4 | def tqdm(iterable): 5 | return iterable 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric_temporal.nn.recurrent import DCRNN 10 | 11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 12 | from torch_geometric_temporal.signal import temporal_signal_split 13 | 14 | loader = ChickenpoxDatasetLoader() 15 | 16 | dataset = loader.get_dataset() 17 | 18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2) 19 | 20 | class RecurrentGCN(torch.nn.Module): 21 | def __init__(self, node_features): 22 | super(RecurrentGCN, self).__init__() 23 | self.recurrent = DCRNN(node_features, 32, 1) 24 | self.linear = torch.nn.Linear(32, 1) 25 | 26 | def forward(self, x, edge_index, edge_weight): 27 | h = self.recurrent(x, edge_index, edge_weight) 28 | h = F.relu(h) 29 | h = self.linear(h) 30 | return h 31 | 32 | model = RecurrentGCN(node_features = 4) 33 | 34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 35 | 36 | model.train() 37 | 38 | for epoch in tqdm(range(200)): 39 | cost = 0 40 | for time, snapshot in enumerate(train_dataset): 41 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 42 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 43 | cost = cost / (time+1) 44 | cost.backward() 45 | optimizer.step() 46 | optimizer.zero_grad() 47 | 48 | model.eval() 49 | cost = 0 50 | for time, snapshot in enumerate(test_dataset): 51 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 52 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 53 | cost = cost / (time+1) 54 | cost = cost.item() 55 | print("MSE: {:.4f}".format(cost)) 56 | -------------------------------------------------------------------------------- /examples/recurrent/dygrencoder_example.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | except ImportError: 4 | def tqdm(iterable): 5 | return iterable 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric_temporal.nn.recurrent import DyGrEncoder 10 | 11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 12 | from torch_geometric_temporal.signal import temporal_signal_split 13 | 14 | loader = ChickenpoxDatasetLoader() 15 | 16 | dataset = loader.get_dataset() 17 | 18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2) 19 | 20 | class RecurrentGCN(torch.nn.Module): 21 | def __init__(self, node_features): 22 | super(RecurrentGCN, self).__init__() 23 | self.recurrent = DyGrEncoder(conv_out_channels=4, conv_num_layers=1, conv_aggr="mean", lstm_out_channels=32, lstm_num_layers=1) 24 | self.linear = torch.nn.Linear(32, 1) 25 | 26 | def forward(self, x, edge_index, edge_weight, h_0, c_0): 
27 | h, h_0, c_0 = self.recurrent(x, edge_index, edge_weight, h_0, c_0) 28 | h = F.relu(h) 29 | h = self.linear(h) 30 | return h, h_0, c_0 31 | 32 | model = RecurrentGCN(node_features = 4) 33 | 34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 35 | 36 | model.train() 37 | 38 | for epoch in tqdm(range(200)): 39 | cost = 0 40 | h, c = None, None 41 | for time, snapshot in enumerate(train_dataset): 42 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c) 43 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 44 | cost = cost / (time+1) 45 | cost.backward() 46 | optimizer.step() 47 | optimizer.zero_grad() 48 | 49 | model.eval() 50 | cost = 0 51 | h, c = None, None 52 | for time, snapshot in enumerate(test_dataset): 53 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c) 54 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 55 | cost = cost / (time+1) 56 | cost = cost.item() 57 | print("MSE: {:.4f}".format(cost)) 58 | -------------------------------------------------------------------------------- /examples/recurrent/evolvegcnh_example.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | except ImportError: 4 | def tqdm(iterable): 5 | return iterable 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric_temporal.nn.recurrent import EvolveGCNH 10 | 11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 12 | from torch_geometric_temporal.signal import temporal_signal_split 13 | 14 | loader = ChickenpoxDatasetLoader() 15 | 16 | dataset = loader.get_dataset() 17 | 18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2) 19 | 20 | class RecurrentGCN(torch.nn.Module): 21 | def __init__(self, node_count, node_features): 22 | super(RecurrentGCN, self).__init__() 23 | self.recurrent = EvolveGCNH(node_count, node_features) 24 | self.linear = torch.nn.Linear(node_features, 1) 25 | 26 | def forward(self, x, edge_index, edge_weight): 27 | h = self.recurrent(x, edge_index, edge_weight) 28 | h = F.relu(h) 29 | h = self.linear(h) 30 | return h 31 | 32 | model = RecurrentGCN(node_features = 4, node_count = 20) 33 | 34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 35 | 36 | model.train() 37 | 38 | for epoch in tqdm(range(200)): 39 | cost = 0 40 | for time, snapshot in enumerate(train_dataset): 41 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 42 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 43 | cost = cost / (time+1) 44 | cost.backward() 45 | optimizer.step() 46 | optimizer.zero_grad() 47 | 48 | model.eval() 49 | cost = 0 50 | for time, snapshot in enumerate(test_dataset): 51 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 52 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 53 | cost = cost / (time+1) 54 | cost = cost.item() 55 | print("MSE: {:.4f}".format(cost)) 56 | -------------------------------------------------------------------------------- /examples/recurrent/evolvegcno_example.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | except ImportError: 4 | def tqdm(iterable): 5 | return iterable 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric_temporal.nn.recurrent import EvolveGCNO 10 | 11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 12 | from torch_geometric_temporal.signal import temporal_signal_split 13 | 14 
| loader = ChickenpoxDatasetLoader() 15 | 16 | dataset = loader.get_dataset() 17 | 18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2) 19 | 20 | class RecurrentGCN(torch.nn.Module): 21 | def __init__(self, node_features): 22 | super(RecurrentGCN, self).__init__() 23 | self.recurrent = EvolveGCNO(node_features) 24 | self.linear = torch.nn.Linear(node_features, 1) 25 | 26 | def forward(self, x, edge_index, edge_weight): 27 | h = self.recurrent(x, edge_index, edge_weight) 28 | h = F.relu(h) 29 | h = self.linear(h) 30 | return h 31 | 32 | model = RecurrentGCN(node_features = 4) 33 | for param in model.parameters(): 34 | param.retain_grad() 35 | 36 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 37 | 38 | model.train() 39 | 40 | for epoch in tqdm(range(200)): 41 | cost = 0 42 | for time, snapshot in enumerate(train_dataset): 43 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 44 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 45 | cost = cost / (time+1) 46 | cost.backward(retain_graph=True) 47 | optimizer.step() 48 | optimizer.zero_grad() 49 | 50 | model.eval() 51 | cost = 0 52 | for time, snapshot in enumerate(test_dataset): 53 | if time == 0: 54 | model.recurrent.weight = None 55 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 56 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 57 | cost = cost / (time+1) 58 | cost = cost.item() 59 | print("MSE: {:.4f}".format(cost)) 60 | -------------------------------------------------------------------------------- /examples/recurrent/gclstm_example.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | except ImportError: 4 | def tqdm(iterable): 5 | return iterable 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric_temporal.nn.recurrent import GCLSTM 10 | 11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 12 | from torch_geometric_temporal.signal import temporal_signal_split 13 | 14 | loader = ChickenpoxDatasetLoader() 15 | 16 | dataset = loader.get_dataset() 17 | 18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2) 19 | 20 | class RecurrentGCN(torch.nn.Module): 21 | def __init__(self, node_features): 22 | super(RecurrentGCN, self).__init__() 23 | self.recurrent = GCLSTM(node_features, 32, 1) 24 | self.linear = torch.nn.Linear(32, 1) 25 | 26 | def forward(self, x, edge_index, edge_weight, h, c): 27 | h_0, c_0 = self.recurrent(x, edge_index, edge_weight, h, c) 28 | h = F.relu(h_0) 29 | h = self.linear(h) 30 | return h, h_0, c_0 31 | 32 | model = RecurrentGCN(node_features=4) 33 | 34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 35 | 36 | model.train() 37 | 38 | for epoch in tqdm(range(200)): 39 | cost = 0 40 | h, c = None, None 41 | for time, snapshot in enumerate(train_dataset): 42 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c) 43 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 44 | cost = cost / (time+1) 45 | cost.backward() 46 | optimizer.step() 47 | optimizer.zero_grad() 48 | 49 | model.eval() 50 | cost = 0 51 | for time, snapshot in enumerate(test_dataset): 52 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c) 53 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 54 | cost = cost / (time+1) 55 | cost = cost.item() 56 | print("MSE: {:.4f}".format(cost)) 57 | -------------------------------------------------------------------------------- 
/examples/recurrent/gconvgru_example.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | except ImportError: 4 | def tqdm(iterable): 5 | return iterable 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric_temporal.nn.recurrent import GConvGRU 10 | 11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 12 | from torch_geometric_temporal.signal import temporal_signal_split 13 | 14 | loader = ChickenpoxDatasetLoader() 15 | 16 | dataset = loader.get_dataset() 17 | 18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2) 19 | 20 | class RecurrentGCN(torch.nn.Module): 21 | def __init__(self, node_features): 22 | super(RecurrentGCN, self).__init__() 23 | self.recurrent = GConvGRU(node_features, 32, 1) 24 | self.linear = torch.nn.Linear(32, 1) 25 | 26 | def forward(self, x, edge_index, edge_weight): 27 | h = self.recurrent(x, edge_index, edge_weight) 28 | h = F.relu(h) 29 | h = self.linear(h) 30 | return h 31 | 32 | model = RecurrentGCN(node_features = 4) 33 | 34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 35 | 36 | model.train() 37 | 38 | for epoch in tqdm(range(200)): 39 | cost = 0 40 | for time, snapshot in enumerate(train_dataset): 41 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 42 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 43 | cost = cost / (time+1) 44 | cost.backward() 45 | optimizer.step() 46 | optimizer.zero_grad() 47 | 48 | model.eval() 49 | cost = 0 50 | for time, snapshot in enumerate(test_dataset): 51 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 52 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 53 | cost = cost / (time+1) 54 | cost = cost.item() 55 | print("MSE: {:.4f}".format(cost)) 56 | -------------------------------------------------------------------------------- /examples/recurrent/gconvlstm_example.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | except ImportError: 4 | def tqdm(iterable): 5 | return iterable 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric_temporal.nn.recurrent import GConvLSTM 10 | 11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 12 | from torch_geometric_temporal.signal import temporal_signal_split 13 | 14 | loader = ChickenpoxDatasetLoader() 15 | 16 | dataset = loader.get_dataset() 17 | 18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2) 19 | 20 | class RecurrentGCN(torch.nn.Module): 21 | def __init__(self, node_features): 22 | super(RecurrentGCN, self).__init__() 23 | self.recurrent = GConvLSTM(node_features, 32, 1) 24 | self.linear = torch.nn.Linear(32, 1) 25 | 26 | def forward(self, x, edge_index, edge_weight, h, c): 27 | h_0, c_0 = self.recurrent(x, edge_index, edge_weight, h, c) 28 | h = F.relu(h_0) 29 | h = self.linear(h) 30 | return h, h_0, c_0 31 | 32 | model = RecurrentGCN(node_features=4) 33 | 34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 35 | 36 | model.train() 37 | 38 | for epoch in tqdm(range(200)): 39 | cost = 0 40 | h, c = None, None 41 | for time, snapshot in enumerate(train_dataset): 42 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c) 43 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 44 | cost = cost / (time+1) 45 | cost.backward() 46 | optimizer.step() 47 | optimizer.zero_grad() 48 | 49 | model.eval() 50 | cost 
= 0
51 | for time, snapshot in enumerate(test_dataset):
52 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c)
53 | cost = cost + torch.mean((y_hat-snapshot.y)**2)
54 | cost = cost / (time+1)
55 | cost = cost.item()
56 | print("MSE: {:.4f}".format(cost))
57 | 
--------------------------------------------------------------------------------
/examples/recurrent/lightning_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | 
4 | import pytorch_lightning as pl
5 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
6 | 
7 | from torch_geometric_temporal.nn.recurrent import DCRNN
8 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
9 | from torch_geometric_temporal.signal import temporal_signal_split
10 | 
11 | 
12 | class LitDiffConvModel(pl.LightningModule):
13 | 
14 | def __init__(self, node_features, filters):
15 | super().__init__()
16 | self.recurrent = DCRNN(node_features, filters, 1)
17 | self.linear = torch.nn.Linear(filters, 1)
18 | 
19 | 
20 | def configure_optimizers(self):
21 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
22 | return optimizer
23 | 
24 | def training_step(self, train_batch, batch_idx):
25 | x = train_batch.x
26 | y = train_batch.y.view(-1, 1)
27 | edge_index = train_batch.edge_index
28 | h = self.recurrent(x, edge_index)
29 | h = F.relu(h)
30 | h = self.linear(h)
31 | loss = F.mse_loss(h, y)
32 | return loss
33 | 
34 | def validation_step(self, val_batch, batch_idx):
35 | x = val_batch.x
36 | y = val_batch.y.view(-1, 1)
37 | edge_index = val_batch.edge_index
38 | h = self.recurrent(x, edge_index)
39 | h = F.relu(h)
40 | h = self.linear(h)
41 | loss = F.mse_loss(h, y)
42 | metrics = {'val_loss': loss}
43 | self.log_dict(metrics)
44 | return metrics
45 | 
46 | 
47 | loader = ChickenpoxDatasetLoader()
48 | 
49 | dataset_loader = loader.get_dataset(lags=32)
50 | 
51 | train_loader, val_loader = temporal_signal_split(dataset_loader,
52 | train_ratio=0.2)
53 | 
54 | model = LitDiffConvModel(node_features=32,
55 | filters=16)
56 | 
57 | early_stop_callback = EarlyStopping(monitor='val_loss',
58 | min_delta=0.00,
59 | patience=10,
60 | verbose=False,
61 | mode='min')  # lower validation loss is better
62 | 
63 | trainer = pl.Trainer(callbacks=[early_stop_callback])
64 | 
65 | trainer.fit(model, train_loader, val_loader)
66 | 
--------------------------------------------------------------------------------
/examples/recurrent/lrgcn_example.py:
--------------------------------------------------------------------------------
1 | try:
2 | from tqdm import tqdm
3 | except ImportError:
4 | def tqdm(iterable):
5 | return iterable
6 | 
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric_temporal.nn.recurrent import LRGCN
10 | 
11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
12 | from torch_geometric_temporal.signal import temporal_signal_split
13 | 
14 | loader = ChickenpoxDatasetLoader()
15 | 
16 | dataset = loader.get_dataset()
17 | 
18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)
19 | 
20 | class RecurrentGCN(torch.nn.Module):
21 | def __init__(self, node_features):
22 | super(RecurrentGCN, self).__init__()
23 | self.recurrent = LRGCN(node_features, 32, 1, 1)
24 | self.linear = torch.nn.Linear(32, 1)
25 | 
26 | def forward(self, x, edge_index, edge_weight, h_0, c_0):
27 | h_0, c_0 = self.recurrent(x, edge_index, edge_weight, h_0, c_0)
28 | h = F.relu(h_0)
29 | h =
self.linear(h) 30 | return h, h_0, c_0 31 | 32 | model = RecurrentGCN(node_features = 4) 33 | 34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 35 | 36 | model.train() 37 | 38 | for epoch in tqdm(range(200)): 39 | cost = 0 40 | h, c = None, None 41 | for time, snapshot in enumerate(train_dataset): 42 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c) 43 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 44 | cost = cost / (time+1) 45 | cost.backward() 46 | optimizer.step() 47 | optimizer.zero_grad() 48 | 49 | model.eval() 50 | cost = 0 51 | h, c = None, None 52 | for time, snapshot in enumerate(test_dataset): 53 | y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c) 54 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 55 | cost = cost / (time+1) 56 | cost = cost.item() 57 | print("MSE: {:.4f}".format(cost)) 58 | -------------------------------------------------------------------------------- /examples/recurrent/mpnnlstm_example.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | except ImportError: 4 | def tqdm(iterable): 5 | return iterable 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric_temporal.nn.recurrent import MPNNLSTM 10 | 11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 12 | from torch_geometric_temporal.signal import temporal_signal_split 13 | 14 | loader = ChickenpoxDatasetLoader() 15 | 16 | dataset = loader.get_dataset() 17 | 18 | train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2) 19 | 20 | class RecurrentGCN(torch.nn.Module): 21 | def __init__(self, node_features): 22 | super(RecurrentGCN, self).__init__() 23 | self.recurrent = MPNNLSTM(node_features, 32, 20, 1, 0.5) 24 | self.linear = torch.nn.Linear(2*32 + node_features, 1) 25 | 26 | def forward(self, x, edge_index, edge_weight): 27 | h = self.recurrent(x, edge_index, edge_weight) 28 | h = F.relu(h) 29 | h = self.linear(h) 30 | return h 31 | 32 | model = RecurrentGCN(node_features = 4) 33 | 34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 35 | 36 | model.train() 37 | 38 | for epoch in tqdm(range(50)): 39 | cost = 0 40 | for time, snapshot in enumerate(train_dataset): 41 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 42 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 43 | cost = cost / (time+1) 44 | cost.backward() 45 | optimizer.step() 46 | optimizer.zero_grad() 47 | 48 | model.eval() 49 | cost = 0 50 | for time, snapshot in enumerate(test_dataset): 51 | y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr) 52 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 53 | cost = cost / (time+1) 54 | cost = cost.item() 55 | print("MSE: {:.4f}".format(cost)) 56 | -------------------------------------------------------------------------------- /examples/recurrent/tgcn_example.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | except ImportError: 4 | def tqdm(iterable): 5 | return iterable 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric_temporal.nn.recurrent import TGCN 10 | 11 | from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader 12 | from torch_geometric_temporal.signal import temporal_signal_split 13 | 14 | loader = ChickenpoxDatasetLoader() 15 | 16 | dataset = loader.get_dataset() 17 | 18 | train_dataset, test_dataset = 
temporal_signal_split(dataset, train_ratio=0.2) 19 | 20 | class RecurrentGCN(torch.nn.Module): 21 | def __init__(self, node_features): 22 | super(RecurrentGCN, self).__init__() 23 | self.recurrent = TGCN(node_features, 32) 24 | self.linear = torch.nn.Linear(32, 1) 25 | 26 | def forward(self, x, edge_index, edge_weight, prev_hidden_state): 27 | h = self.recurrent(x, edge_index, edge_weight, prev_hidden_state) 28 | y = F.relu(h) 29 | y = self.linear(y) 30 | return y, h 31 | 32 | model = RecurrentGCN(node_features = 4) 33 | 34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 35 | 36 | model.train() 37 | 38 | for epoch in tqdm(range(50)): 39 | cost = 0 40 | hidden_state = None 41 | for time, snapshot in enumerate(train_dataset): 42 | y_hat, hidden_state = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr,hidden_state) 43 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 44 | cost = cost / (time+1) 45 | cost.backward() 46 | optimizer.step() 47 | optimizer.zero_grad() 48 | 49 | model.eval() 50 | cost = 0 51 | hidden_state = None 52 | for time, snapshot in enumerate(test_dataset): 53 | y_hat, hidden_state = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, hidden_state) 54 | cost = cost + torch.mean((y_hat-snapshot.y)**2) 55 | cost = cost / (time+1) 56 | cost = cost.item() 57 | print("MSE: {:.4f}".format(cost)) 58 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.12" 7 | 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | install_requires = [ 4 | "decorator==4.4.2", 5 | "torch", 6 | "cython", 7 | "torch_sparse", 8 | "torch_scatter", 9 | "torch_geometric", 10 | "numpy", 11 | "networkx", 12 | ] 13 | tests_require = ["pytest", "pytest-cov", "mock", "networkx", "tqdm"] 14 | index_require = ['dask', "pandas", "tables"] 15 | ddp_require = ["dask[distributed]", "dask_pytorch_ddp", "pandas", "tables"] 16 | 17 | keywords = [ 18 | "machine-learning", 19 | "deep-learning", 20 | "deeplearning", 21 | "deep learning", 22 | "machine learning", 23 | "signal processing", 24 | "temporal signal", 25 | "graph", 26 | "dynamic graph", 27 | "embedding", 28 | "dynamic embedding", 29 | "graph convolution", 30 | "gcn", 31 | "graph neural network", 32 | "graph attention", 33 | "lstm", 34 | "temporal network", 35 | "representation learning", 36 | "learning", 37 | ] 38 | 39 | setup( 40 | name="torch_geometric_temporal", 41 | packages=find_packages(), 42 | version="0.55.0", 43 | license="MIT", 44 | description="A Temporal Extension Library for PyTorch Geometric.", 45 | author="Benedek Rozemberczki", 46 | author_email="benedek.rozemberczki@gmail.com", 47 | url="https://github.com/benedekrozemberczki/pytorch_geometric_temporal", 48 | download_url="https://github.com/benedekrozemberczki/pytorch_geometric_temporal/archive/v0.54.0.tar.gz", 49 | keywords=keywords, 50 | install_requires=install_requires, 51 | extras_require={ 52 | "test": tests_require, 53 | "index": index_require, 54 | "ddp": ddp_require 55 | }, 56 | python_requires=">=3.6", 57 | classifiers=[ 58 | "Development Status :: 3 - Alpha", 59 
| "Intended Audience :: Developers", 60 | "Topic :: Software Development :: Build Tools", 61 | "License :: OSI Approved :: MIT License", 62 | "Programming Language :: Python :: 3.6", 63 | ], 64 | ) 65 | -------------------------------------------------------------------------------- /test/heterogeneous_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import networkx as nx 4 | import torch_geometric.transforms as T 5 | from torch_geometric.data import HeteroData 6 | from torch_geometric_temporal.nn.hetero import HeteroGCLSTM 7 | 8 | 9 | def get_edge_array(n_count): 10 | return np.array([edge for edge in nx.gnp_random_graph(n_count, 0.1).edges()]).T 11 | 12 | 13 | def create_hetero_mock_data(n_count, feature_dict): 14 | _x_dict = {'author': torch.FloatTensor(np.random.uniform(0, 1, (n_count, feature_dict['author']))), 15 | 'paper': torch.FloatTensor(np.random.uniform(0, 1, (n_count, feature_dict['paper'])))} 16 | _edge_index_dict = {('author', 'writes', 'paper'): torch.LongTensor(get_edge_array(n_count))} 17 | 18 | data = HeteroData() 19 | data['author'].x = _x_dict['author'] 20 | data['paper'].x = _x_dict['paper'] 21 | data[('author', 'writes', 'paper')].edge_index = _edge_index_dict[('author', 'writes', 'paper')] 22 | data = T.ToUndirected()(data) 23 | 24 | return data.x_dict, data.edge_index_dict, data.metadata() 25 | 26 | 27 | def test_hetero_gclstm_layer(): 28 | """ 29 | Testing the HeteroGCLSTM Layer. 30 | """ 31 | number_of_nodes = 50 32 | feature_dict = {'author': 20, 'paper': 30} 33 | out_channels = 32 34 | 35 | x_dict, edge_index_dict, metadata = create_hetero_mock_data(number_of_nodes, feature_dict) 36 | 37 | layer = HeteroGCLSTM(in_channels_dict=feature_dict, out_channels=out_channels, metadata=metadata) 38 | 39 | h_dict, c_dict = layer(x_dict, edge_index_dict) 40 | 41 | assert h_dict['author'].shape == (number_of_nodes, out_channels) 42 | assert h_dict['paper'].shape == (number_of_nodes, out_channels) 43 | assert c_dict['author'].shape == (number_of_nodes, out_channels) 44 | assert c_dict['paper'].shape == (number_of_nodes, out_channels) 45 | 46 | h_dict, c_dict = layer(x_dict, edge_index_dict, h_dict, c_dict) 47 | 48 | assert h_dict['author'].shape == (number_of_nodes, out_channels) 49 | assert h_dict['paper'].shape == (number_of_nodes, out_channels) 50 | assert c_dict['author'].shape == (number_of_nodes, out_channels) 51 | assert c_dict['paper'].shape == (number_of_nodes, out_channels) 52 | -------------------------------------------------------------------------------- /torch_geometric_temporal/__init__.py: -------------------------------------------------------------------------------- 1 | from torch_geometric_temporal.nn import * 2 | from torch_geometric_temporal.dataset import * 3 | from torch_geometric_temporal.signal import * 4 | 5 | __version__ = "0.54.0" 6 | 7 | __all__ = [ 8 | "torch_geometric", 9 | "__version__", 10 | ] 11 | -------------------------------------------------------------------------------- /torch_geometric_temporal/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .chickenpox import ChickenpoxDatasetLoader 2 | from .pedalme import PedalMeDatasetLoader 3 | from .metr_la import METRLADatasetLoader 4 | from .pems_bay import PemsBayDatasetLoader 5 | from .wikimath import WikiMathsDatasetLoader 6 | from .windmilllarge import WindmillOutputLargeDatasetLoader 7 | from .windmillmedium import 
WindmillOutputMediumDatasetLoader
8 | from .windmillsmall import WindmillOutputSmallDatasetLoader
9 | from .encovid import EnglandCovidDatasetLoader
10 | from .twitter_tennis import TwitterTennisDatasetLoader
11 | from .montevideo_bus import MontevideoBusDatasetLoader
12 | from .mtm import MTMDatasetLoader
13 | 
14 | from .pemsAllLA import PemsAllLADatasetLoader
15 | from .pems import PemsDatasetLoader
16 | 
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/chickenpox.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ssl
3 | import urllib.request
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader
7 | 
8 | from ..signal import StaticGraphTemporalSignal
9 | 
10 | 
11 | class ChickenpoxDatasetLoader(object):
12 | """A dataset of county-level chickenpox cases in Hungary between 2004
13 | and 2014. We made it public during the development of PyTorch Geometric
14 | Temporal. The underlying graph is static - vertices are counties and
15 | edges are neighbourhoods. Vertex features are lagged weekly counts of the
16 | chickenpox cases (we included 4 lags). The target is the weekly number of
17 | cases for the upcoming week (signed integers). Our dataset consists of more
18 | than 500 snapshots (weeks).
19 | 
20 | Args:
21 | index (bool, optional): If True, initializes the dataloader to use index-based batching.
22 | Defaults to False.
23 | """
24 | def __init__(self, index=False):
25 | self._read_web_data()
26 | self.index = index
27 | 
28 | if index:
29 | from ..signal.index_dataset import IndexDataset
30 | self.IndexDataset = IndexDataset
31 | 
32 | def _read_web_data(self):
33 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/chickenpox.json"
34 | 
35 | context = ssl._create_unverified_context()
36 | self._dataset = json.loads(
37 | urllib.request.urlopen(url, context=context).read()
38 | )
39 | 
40 | def _get_edges(self):
41 | self._edges = np.array(self._dataset["edges"]).T
42 | 
43 | def _get_edge_weights(self):
44 | self._edge_weights = np.ones(self._edges.shape[1])
45 | 
46 | def _get_targets_and_features(self):
47 | stacked_target = np.array(self._dataset["FX"])
48 | self.features = [
49 | stacked_target[i : i + self.lags, :].T
50 | for i in range(stacked_target.shape[0] - self.lags)
51 | ]
52 | self.targets = [
53 | stacked_target[i + self.lags, :].T
54 | for i in range(stacked_target.shape[0] - self.lags)
55 | ]
56 | 
57 | def get_dataset(self, lags: int = 4) -> StaticGraphTemporalSignal:
58 | """Returning the Chickenpox Hungary data iterator.
59 | 
60 | Args types:
61 | * **lags** *(int)* - The number of time lags.
62 | Return types:
63 | * **dataset** *(StaticGraphTemporalSignal)* - The Chickenpox Hungary dataset.
64 | """
65 | self.lags = lags
66 | self._get_edges()
67 | self._get_edge_weights()
68 | self._get_targets_and_features()
69 | dataset = StaticGraphTemporalSignal(
70 | self._edges, self._edge_weights, self.features, self.targets
71 | )
72 | return dataset
73 | 
74 | def get_index_dataset(self, lags=4, batch_size=4, shuffle=False, allGPU=-1, ratio=(0.7, 0.1, 0.2), dask_batching=False):
75 | """
76 | Returns torch dataloaders using index batching for the Chickenpox Hungary dataset.
77 | 
78 | Args:
79 | lags (int, optional): The number of time lags. Defaults to 4.
80 | batch_size (int, optional): Batch size. Defaults to 4.
81 | shuffle (bool, optional): If the data should be shuffled. Defaults to False.
82 | allGPU (int, optional): GPU device ID for performing preprocessing in GPU memory.
83 | If -1, computation is done on CPU. Defaults to -1.
84 | ratio (tuple of float, optional): The desired train, validation, and test split ratios, respectively. Defaults to (0.7, 0.1, 0.2).
85 | 
86 | Returns:
87 | Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.Tensor, torch.Tensor]:
88 | 
89 | A 5-tuple containing:
90 | - **train_dataLoader** (*torch.utils.data.DataLoader*): Dataloader for the training set.
91 | - **val_dataLoader** (*torch.utils.data.DataLoader*): Dataloader for the validation set.
92 | - **test_dataLoader** (*torch.utils.data.DataLoader*): Dataloader for the test set.
93 | - **edges** (*torch.Tensor*): The graph edges as a 2D matrix, shape `[2, num_edges]`.
94 | - **edge_weights** (*torch.Tensor*): Each graph edge's weight, shape `[num_edges]`.
95 | """
96 | 
97 | if not self.index:
98 | raise ValueError("get_index_dataset requires 'index=True' in the constructor.")
99 | 
100 | data = np.array(self._dataset["FX"])
101 | edges = torch.tensor(self._dataset["edges"], dtype=torch.int64).T
102 | edge_weights = torch.ones(edges.shape[1], dtype=torch.float)
103 | num_samples = data.shape[0]
104 | 
105 | if allGPU != -1:
106 | data = torch.tensor(data, dtype=torch.float).to(f"cuda:{allGPU}")
107 | data = data.unsqueeze(-1)
108 | else:
109 | data = np.expand_dims(data, axis=-1)
110 | 
111 | 
112 | x_i = np.arange(num_samples - (2 * lags - 1))
113 | 
114 | num_samples = x_i.shape[0]
115 | num_train = round(num_samples * ratio[0])
116 | num_test = round(num_samples * ratio[2])
117 | num_val = num_samples - num_train - num_test
118 | 
119 | x_train = x_i[:num_train]
120 | x_val = x_i[num_train: num_train + num_val]
121 | x_test = x_i[-num_test:]
122 | 
123 | train_dataset = self.IndexDataset(x_train, data, lags, gpu=(allGPU != -1), lazy=dask_batching)
124 | val_dataset = self.IndexDataset(x_val, data, lags, gpu=(allGPU != -1), lazy=dask_batching)
125 | test_dataset = self.IndexDataset(x_test, data, lags, gpu=(allGPU != -1), lazy=dask_batching)
126 | 
127 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
128 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)
129 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)
130 | 
131 | 
132 | return train_dataloader, val_dataloader, test_dataloader, edges, edge_weights
133 | 
134 | 
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/encovid.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ssl
3 | import urllib.request
4 | import numpy as np
5 | from ..signal import DynamicGraphTemporalSignal
6 | 
7 | 
8 | class EnglandCovidDatasetLoader(object):
9 | """A dataset of mobility and history of reported cases of COVID-19
10 | in England NUTS3 regions, from 3 March to 12 May. The dataset is
11 | segmented in days and the graph is directed and weighted. The graph
12 | indicates how many people moved from one region to the other each day,
13 | based on Facebook Data For Good disease prevention maps.
14 | The node features correspond to the number of COVID-19 cases
15 | in the region in the past **window** days. The task is to predict the
16 | number of cases in each node after 1 day.
For details see this paper:
17 | `"Transfer Graph Neural Networks for Pandemic Forecasting." `_
18 | """
19 | 
20 | def __init__(self):
21 | self._read_web_data()
22 | 
23 | def _read_web_data(self):
24 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/england_covid.json"
25 | context = ssl._create_unverified_context()
26 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read())
27 | 
28 | def _get_edges(self):
29 | self._edges = []
30 | for time in range(self._dataset["time_periods"] - self.lags):
31 | self._edges.append(
32 | np.array(self._dataset["edge_mapping"]["edge_index"][str(time)]).T
33 | )
34 | 
35 | def _get_edge_weights(self):
36 | self._edge_weights = []
37 | for time in range(self._dataset["time_periods"] - self.lags):
38 | self._edge_weights.append(
39 | np.array(self._dataset["edge_mapping"]["edge_weight"][str(time)])
40 | )
41 | 
42 | def _get_targets_and_features(self):
43 | 
44 | stacked_target = np.array(self._dataset["y"])
45 | standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / (
46 | np.std(stacked_target, axis=0) + 10 ** -10
47 | )
48 | self.features = [
49 | standardized_target[i : i + self.lags, :].T
50 | for i in range(self._dataset["time_periods"] - self.lags)
51 | ]
52 | self.targets = [
53 | standardized_target[i + self.lags, :].T
54 | for i in range(self._dataset["time_periods"] - self.lags)
55 | ]
56 | 
57 | def get_dataset(self, lags: int = 8) -> DynamicGraphTemporalSignal:
58 | """Returning the England COVID19 data iterator.
59 | 
60 | Args types:
61 | * **lags** *(int)* - The number of time lags.
62 | Return types:
63 | * **dataset** *(DynamicGraphTemporalSignal)* - The England Covid dataset.
64 | """
65 | self.lags = lags
66 | self._get_edges()
67 | self._get_edge_weights()
68 | self._get_targets_and_features()
69 | dataset = DynamicGraphTemporalSignal(
70 | self._edges, self._edge_weights, self.features, self.targets
71 | )
72 | return dataset
73 | 
--------------------------------------------------------------------------------
/torch_geometric_temporal/dataset/montevideo_bus.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import json
3 | import ssl
4 | import urllib.request
5 | import numpy as np
6 | from torch_geometric_temporal.signal import StaticGraphTemporalSignal
7 | 
8 | 
9 | class MontevideoBusDatasetLoader(object):
10 | """A dataset of passenger inflow at bus stop level from Montevideo city.
11 | This dataset comprises hourly passenger inflow data at bus stop level for 11 bus lines during
12 | October 2020 from Montevideo city (Uruguay). The bus lines selected are the ones that carry
13 | people to the center of the city and they load more than 25% of the total daily inflow traffic.
14 | Vertices are bus stops, edges are links between bus stops when a bus line connects them, and the
15 | weights represent the road distance. The target is the passenger inflow. This is a curated
16 | dataset made from different data sources of the Metropolitan Transportation System (STM) of
17 | Montevideo. These datasets are freely available to anyone in the National Catalog of Open Data
18 | from the government of Uruguay (https://catalogodatos.gub.uy/).
19 | """ 20 | 21 | def __init__(self): 22 | self._read_web_data() 23 | 24 | def _read_web_data(self): 25 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/montevideo_bus.json" 26 | context = ssl._create_unverified_context() 27 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read()) 28 | 29 | def _get_node_ids(self): 30 | return [node.get('bus_stop') for node in self._dataset["nodes"]] 31 | 32 | def _get_edges(self): 33 | node_ids = self._get_node_ids() 34 | node_id_map = dict(zip(node_ids, range(len(node_ids)))) 35 | self._edges = np.array( 36 | [(node_id_map[d["source"]], node_id_map[d["target"]]) for d in self._dataset["links"]] 37 | ).T 38 | 39 | def _get_edge_weights(self): 40 | self._edge_weights = np.array([(d["weight"]) for d in self._dataset["links"]]).T 41 | 42 | def _get_features(self, feature_vars: List[str] = ["y"]): 43 | features = [] 44 | for node in self._dataset["nodes"]: 45 | X = node.get("X") 46 | for feature_var in feature_vars: 47 | features.append(np.array(X.get(feature_var))) 48 | stacked_features = np.stack(features).T 49 | standardized_features = ( 50 | stacked_features - np.mean(stacked_features, axis=0) 51 | ) / np.std(stacked_features, axis=0) 52 | self.features = [ 53 | standardized_features[i : i + self.lags, :].T 54 | for i in range(len(standardized_features) - self.lags) 55 | ] 56 | 57 | def _get_targets(self, target_var: str = "y"): 58 | targets = [] 59 | for node in self._dataset["nodes"]: 60 | y = node.get(target_var) 61 | targets.append(np.array(y)) 62 | stacked_targets = np.stack(targets).T 63 | standardized_targets = ( 64 | stacked_targets - np.mean(stacked_targets, axis=0) 65 | ) / np.std(stacked_targets, axis=0) 66 | self.targets = [ 67 | standardized_targets[i + self.lags, :].T 68 | for i in range(len(standardized_targets) - self.lags) 69 | ] 70 | 71 | def get_dataset( 72 | self, lags: int = 4, target_var: str = "y", feature_vars: List[str] = ["y"] 73 | ) -> StaticGraphTemporalSignal: 74 | """Returning the MontevideoBus passenger inflow data iterator. 75 | 76 | Parameters 77 | ---------- 78 | lags : int, optional 79 | The number of time lags, by default 4. 80 | target_var : str, optional 81 | Target variable name, by default "y". 82 | feature_vars : List[str], optional 83 | List of feature variables, by default ["y"]. 84 | 85 | Returns 86 | ------- 87 | StaticGraphTemporalSignal 88 | The MontevideoBus dataset. 89 | """ 90 | self.lags = lags 91 | self._get_edges() 92 | self._get_edge_weights() 93 | self._get_features(feature_vars) 94 | self._get_targets(target_var) 95 | dataset = StaticGraphTemporalSignal( 96 | self._edges, self._edge_weights, self.features, self.targets 97 | ) 98 | return dataset 99 | -------------------------------------------------------------------------------- /torch_geometric_temporal/dataset/mtm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import ssl 3 | import urllib.request 4 | import numpy as np 5 | from torch_geometric_temporal.signal import StaticGraphTemporalSignal 6 | 7 | 8 | class MTMDatasetLoader: 9 | """ 10 | A dataset of `Methods-Time Measurement-1 `_ 11 | (MTM-1) motions, signalled as consecutive video frames of 21 3D hand keypoints, acquired via 12 | `MediaPipe Hands `_ from RGB-Video 13 | material. Vertices are the finger joints of the human hand and edges are the bones connecting 14 | them. 
The targets are manually labeled for each frame according to one of the five MTM-1
15 | motions (classes :math:`C`): Grasp, Release, Move, Reach, Position, plus a negative class for
16 | frames without graph signals (no hand present). This is a classification task where :math:`T`
17 | consecutive frames need to be assigned to the corresponding class :math:`C`. The data x is
18 | returned in shape :obj:`(3, 21, T)`; the target is returned one-hot-encoded in shape :obj:`(T, 6)`.
19 | """
20 |
21 | def __init__(self):
22 | self._read_web_data()
23 |
24 | def _read_web_data(self):
25 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/mtm_1.json"
26 | context = ssl._create_unverified_context()
27 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read())
28 |
29 | def _get_edges(self):
30 | self._edges = np.array(self._dataset["edges"]).T
31 |
32 | def _get_edge_weights(self):
33 | self._edge_weights = np.array([1 for d in self._dataset["edges"]]).T
34 |
35 | def _get_features(self):
36 | dic = self._dataset
37 | joints = [str(n) for n in range(21)]
38 | dataset_length = len(dic["0"].values())
39 | features = np.zeros((dataset_length, 21, 3))
40 |
41 | for j, joint in enumerate(joints):
42 | for t, xyz in enumerate(dic[joint].values()):
43 | xyz_tuple = list(map(float, xyz.strip("()").split(",")))
44 | features[t, j, :] = xyz_tuple
45 |
46 | self.features = [
47 | features[i : i + self.frames, :].T
48 | for i in range(len(features) - self.frames)
49 | ]
50 |
51 | def _get_targets(self):
52 | # target encoding: {0 : 'Grasp', 1 : 'Move', 2 : 'Negative',
53 | # 3 : 'Position', 4 : 'Reach', 5 : 'Release'}
54 | targets = []
55 | for _, y in self._dataset["LABEL"].items():
56 | targets.append(y)
57 |
58 | n_values = np.max(targets) + 1
59 | targets_ohe = np.eye(n_values)[targets]
60 |
61 | self.targets = [
62 | targets_ohe[i : i + self.frames, :]
63 | for i in range(len(targets_ohe) - self.frames)
64 | ]
65 |
66 | def get_dataset(self, frames: int = 16) -> StaticGraphTemporalSignal:
67 | """Returning the MTM-1 motion data iterator.
68 |
69 | Args types:
70 | * **frames** *(int)* - The number of consecutive frames T, default 16.
71 | Return types:
72 | * **dataset** *(StaticGraphTemporalSignal)* - The MTM-1 dataset.
73 | """
74 | self.frames = frames
75 | self._get_edges()
76 | self._get_edge_weights()
77 | self._get_features()
78 | self._get_targets()
79 |
80 | dataset = StaticGraphTemporalSignal(
81 | self._edges, self._edge_weights, self.features, self.targets
82 | )
83 | return dataset
84 |
-------------------------------------------------------------------------------- /torch_geometric_temporal/dataset/pedalme.py: --------------------------------------------------------------------------------
1 | import json
2 | import ssl
3 | import urllib.request
4 | import numpy as np
5 | from ..signal import StaticGraphTemporalSignal
6 |
7 |
8 | class PedalMeDatasetLoader(object):
9 | """A dataset of PedalMe Bicycle delivery orders in London between 2020
10 | and 2021. We made it public during the development of PyTorch Geometric
11 | Temporal. The underlying graph is static - vertices are localities and
12 | edges are spatial connections. Vertex features are lagged weekly counts of the
13 | delivery demands (we included 4 lags). The target is the weekly number of
14 | deliveries for the upcoming week. Our dataset consists of more than 30 snapshots (weeks).
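
Example (a minimal usage sketch; downloading the JSON requires network
access, and ``temporal_signal_split`` is the split helper exported by the
signal package):

    >>> from torch_geometric_temporal.dataset.pedalme import PedalMeDatasetLoader
    >>> from torch_geometric_temporal.signal import temporal_signal_split
    >>> loader = PedalMeDatasetLoader()
    >>> dataset = loader.get_dataset(lags=4)
    >>> train, test = temporal_signal_split(dataset, train_ratio=0.8)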
15 | """ 16 | 17 | def __init__(self): 18 | self._read_web_data() 19 | 20 | def _read_web_data(self): 21 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/pedalme_london.json" 22 | context = ssl._create_unverified_context() 23 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read()) 24 | 25 | def _get_edges(self): 26 | self._edges = np.array(self._dataset["edges"]).T 27 | 28 | def _get_edge_weights(self): 29 | self._edge_weights = np.array(self._dataset["weights"]).T 30 | 31 | def _get_targets_and_features(self): 32 | stacked_target = np.array(self._dataset["X"]) 33 | self.features = [ 34 | stacked_target[i : i + self.lags, :].T 35 | for i in range(stacked_target.shape[0] - self.lags) 36 | ] 37 | self.targets = [ 38 | stacked_target[i + self.lags, :].T 39 | for i in range(stacked_target.shape[0] - self.lags) 40 | ] 41 | 42 | def get_dataset(self, lags: int = 4) -> StaticGraphTemporalSignal: 43 | """Returning the PedalMe London demand data iterator. 44 | 45 | Args types: 46 | * **lags** *(int)* - The number of time lags. 47 | Return types: 48 | * **dataset** *(StaticGraphTemporalSignal)* - The PedalMe dataset. 49 | """ 50 | self.lags = lags 51 | self._get_edges() 52 | self._get_edge_weights() 53 | self._get_targets_and_features() 54 | dataset = StaticGraphTemporalSignal( 55 | self._edges, self._edge_weights, self.features, self.targets 56 | ) 57 | return dataset 58 | -------------------------------------------------------------------------------- /torch_geometric_temporal/dataset/twitter_tennis.py: -------------------------------------------------------------------------------- 1 | import json 2 | import ssl 3 | import urllib.request 4 | import numpy as np 5 | from ..signal import DynamicGraphTemporalSignal 6 | 7 | 8 | def transform_degree(x, cutoff=4): 9 | log_deg = np.ceil(np.log(x + 1.0)) 10 | return np.minimum(log_deg, cutoff) 11 | 12 | 13 | def transform_transitivity(x): 14 | trans = x * 10 15 | return np.floor(trans) 16 | 17 | 18 | def onehot_encoding(x, unique_vals): 19 | E = np.zeros((len(x), len(unique_vals))) 20 | for i, val in enumerate(x): 21 | E[i, unique_vals.index(val)] = 1.0 22 | return E 23 | 24 | 25 | def encode_features(X, log_degree_cutoff=4): 26 | X_arr = np.array(X) 27 | a = transform_degree(X_arr[:, 0], log_degree_cutoff) 28 | b = transform_transitivity(X_arr[:, 1]) 29 | A = onehot_encoding(a, range(log_degree_cutoff + 1)) 30 | B = onehot_encoding(b, range(11)) 31 | return np.concatenate((A, B), axis=1) 32 | 33 | 34 | class TwitterTennisDatasetLoader(object): 35 | """ 36 | Twitter mention graphs related to major tennis tournaments from 2017. 37 | Nodes are Twitter accounts and edges are mentions between them. 38 | Each snapshot contains the graph induced by the most popular nodes 39 | of the original dataset. Node labels encode the number of mentions 40 | received in the original dataset for the next snapshot. Read more 41 | on the original Twitter data in the 'Temporal Walk Based Centrality Metric for Graph Streams' paper. 42 | 43 | Parameters 44 | ---------- 45 | event_id : str 46 | Choose to load the mention network for Roland-Garros 2017 ("rg17") or USOpen 2017 ("uo17") 47 | N : int <= 1000 48 | Number of most popular nodes to load. By default N=1000. Each snapshot contains the graph induced by these nodes. 
49 | feature_mode : str 50 | None : load raw degree and transitivity node features 51 | "encoded" : load onehot encoded degree and transitivity node features 52 | "diagonal" : set identity matrix as node features 53 | target_offset : int 54 | Set the snapshot offset for the node labels to be predicted. By default node labels for the next snapshot are predicted (target_offset=1). 55 | """ 56 | 57 | def __init__( 58 | self, event_id="rg17", N=None, feature_mode="encoded", target_offset=1 59 | ): 60 | self.N = N 61 | self.target_offset = target_offset 62 | if event_id in ["rg17", "uo17"]: 63 | self.event_id = event_id 64 | else: 65 | raise ValueError( 66 | "Invalid 'event_id'! Choose 'rg17' or 'uo17' to load the Roland-Garros 2017 or the USOpen 2017 Twitter tennis dataset respectively." 67 | ) 68 | if feature_mode in [None, "diagonal", "encoded"]: 69 | self.feature_mode = feature_mode 70 | else: 71 | raise ValueError( 72 | "Choose feature_mode from values [None, 'diagonal', 'encoded']." 73 | ) 74 | self._read_web_data() 75 | 76 | def _read_web_data(self): 77 | fname = "twitter_tennis_%s.json" % self.event_id 78 | url = ( 79 | "https://raw.githubusercontent.com/ferencberes/pytorch_geometric_temporal/developer/dataset/" 80 | + fname 81 | ) 82 | context = ssl._create_unverified_context() 83 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read()) 84 | # with open("/home/fberes/git/pytorch_geometric_temporal/dataset/"+fname) as f: 85 | # self._dataset = json.load(f) 86 | 87 | def _get_edges(self): 88 | edge_indices = [] 89 | self.edges = [] 90 | for time in range(self._dataset["time_periods"]): 91 | E = np.array(self._dataset[str(time)]["edges"]) 92 | if self.N != None: 93 | selector = np.where((E[:, 0] < self.N) & (E[:, 1] < self.N)) 94 | E = E[selector] 95 | edge_indices.append(selector) 96 | self.edges.append(E.T) 97 | self.edge_indices = edge_indices 98 | 99 | def _get_edge_weights(self): 100 | edge_indices = self.edge_indices 101 | self.edge_weights = [] 102 | for i, time in enumerate(range(self._dataset["time_periods"])): 103 | W = np.array(self._dataset[str(time)]["weights"]) 104 | if self.N != None: 105 | W = W[edge_indices[i]] 106 | self.edge_weights.append(W) 107 | 108 | def _get_features(self): 109 | self.features = [] 110 | for time in range(self._dataset["time_periods"]): 111 | X = np.array(self._dataset[str(time)]["X"]) 112 | if self.N != None: 113 | X = X[: self.N] 114 | if self.feature_mode == "diagonal": 115 | X = np.identity(X.shape[0]) 116 | elif self.feature_mode == "encoded": 117 | X = encode_features(X) 118 | self.features.append(X) 119 | 120 | def _get_targets(self): 121 | self.targets = [] 122 | T = self._dataset["time_periods"] 123 | for time in range(T): 124 | # predict node degrees in advance 125 | snapshot_id = min(time + self.target_offset, T - 1) 126 | y = np.array(self._dataset[str(snapshot_id)]["y"]) 127 | # logarithmic transformation for node degrees 128 | y = np.log(1.0 + y) 129 | if self.N != None: 130 | y = y[: self.N] 131 | self.targets.append(y) 132 | 133 | def get_dataset(self) -> DynamicGraphTemporalSignal: 134 | """Returning the TennisDataset data iterator. 135 | 136 | Return types: 137 | * **dataset** *(DynamicGraphTemporalSignal)* - Selected Twitter tennis dataset (Roland-Garros 2017 or USOpen 2017). 
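
Example (a minimal usage sketch; it assumes network access to the hosted
JSON file):

    >>> from torch_geometric_temporal.dataset.twitter_tennis import TwitterTennisDatasetLoader
    >>> loader = TwitterTennisDatasetLoader(event_id="rg17", N=200, feature_mode="encoded")
    >>> dataset = loader.get_dataset()
    >>> for snapshot in dataset:
    ...     x, y = snapshot.x, snapshot.y  # features and log-degree targets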
138 | """ 139 | self._get_edges() 140 | self._get_edge_weights() 141 | self._get_features() 142 | self._get_targets() 143 | dataset = DynamicGraphTemporalSignal( 144 | self.edges, self.edge_weights, self.features, self.targets 145 | ) 146 | return dataset 147 | -------------------------------------------------------------------------------- /torch_geometric_temporal/dataset/wikimath.py: -------------------------------------------------------------------------------- 1 | import json 2 | import ssl 3 | import urllib.request 4 | import numpy as np 5 | from ..signal import StaticGraphTemporalSignal 6 | 7 | 8 | class WikiMathsDatasetLoader(object): 9 | """A dataset of vital mathematics articles from Wikipedia. We made it 10 | public during the development of PyTorch Geometric Temporal. The 11 | underlying graph is static - vertices are Wikipedia pages and edges are 12 | links between them. The graph is directed and weighted. Weights represent 13 | the number of links found at the source Wikipedia page linking to the target 14 | Wikipedia page. The target is the daily user visits to the Wikipedia pages 15 | between March 16th 2019 and March 15th 2021 which results in 731 periods. 16 | """ 17 | 18 | def __init__(self): 19 | self._read_web_data() 20 | 21 | def _read_web_data(self): 22 | url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/wikivital_mathematics.json" 23 | context = ssl._create_unverified_context() 24 | self._dataset = json.loads(urllib.request.urlopen(url, context=context).read()) 25 | 26 | def _get_edges(self): 27 | self._edges = np.array(self._dataset["edges"]).T 28 | 29 | def _get_edge_weights(self): 30 | self._edge_weights = np.array(self._dataset["weights"]).T 31 | 32 | def _get_targets_and_features(self): 33 | 34 | targets = [] 35 | for time in range(self._dataset["time_periods"]): 36 | targets.append(np.array(self._dataset[str(time)]["y"])) 37 | stacked_target = np.stack(targets) 38 | standardized_target = ( 39 | stacked_target - np.mean(stacked_target, axis=0) 40 | ) / np.std(stacked_target, axis=0) 41 | self.features = [ 42 | standardized_target[i : i + self.lags, :].T 43 | for i in range(len(targets) - self.lags) 44 | ] 45 | self.targets = [ 46 | standardized_target[i + self.lags, :].T 47 | for i in range(len(targets) - self.lags) 48 | ] 49 | 50 | def get_dataset(self, lags: int = 8) -> StaticGraphTemporalSignal: 51 | """Returning the Wikipedia Vital Mathematics data iterator. 52 | 53 | Args types: 54 | * **lags** *(int)* - The number of time lags. 55 | Return types: 56 | * **dataset** *(StaticGraphTemporalSignal)* - The Wiki Maths dataset. 57 | """ 58 | self.lags = lags 59 | self._get_edges() 60 | self._get_edge_weights() 61 | self._get_targets_and_features() 62 | dataset = StaticGraphTemporalSignal( 63 | self._edges, self._edge_weights, self.features, self.targets 64 | ) 65 | return dataset 66 | -------------------------------------------------------------------------------- /torch_geometric_temporal/dataset/windmillmedium.py: -------------------------------------------------------------------------------- 1 | import json 2 | import ssl 3 | import urllib.request 4 | import numpy as np 5 | from ..signal import StaticGraphTemporalSignal 6 | 7 | 8 | class WindmillOutputMediumDatasetLoader(object): 9 | """Hourly energy output of windmills from a European country 10 | for more than 2 years. Vertices represent 26 windmills and 11 | weighted edges describe the strength of relationships. 
The target 12 | variable allows for regression tasks. 13 | """ 14 | 15 | def __init__(self): 16 | self._read_web_data() 17 | 18 | def _read_web_data(self): 19 | url = "https://graphmining.ai/temporal_datasets/windmill_output_medium.json" 20 | context = ssl._create_unverified_context() 21 | self._dataset = json.loads( 22 | urllib.request.urlopen(url, context=context).read().decode() 23 | ) 24 | 25 | def _get_edges(self): 26 | self._edges = np.array(self._dataset["edges"]).T 27 | 28 | def _get_edge_weights(self): 29 | self._edge_weights = np.array(self._dataset["weights"]).T 30 | 31 | def _get_targets_and_features(self): 32 | stacked_target = np.stack(self._dataset["block"]) 33 | standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / ( 34 | np.std(stacked_target, axis=0) + 10 ** -10 35 | ) 36 | self.features = [ 37 | standardized_target[i : i + self.lags, :].T 38 | for i in range(standardized_target.shape[0] - self.lags) 39 | ] 40 | self.targets = [ 41 | standardized_target[i + self.lags, :].T 42 | for i in range(standardized_target.shape[0] - self.lags) 43 | ] 44 | 45 | def get_dataset(self, lags: int = 8) -> StaticGraphTemporalSignal: 46 | """Returning the Windmill Output data iterator. 47 | 48 | Args types: 49 | * **lags** *(int)* - The number of time lags. 50 | Return types: 51 | * **dataset** *(StaticGraphTemporalSignal)* - The Windmill Output dataset. 52 | """ 53 | self.lags = lags 54 | self._get_edges() 55 | self._get_edge_weights() 56 | self._get_targets_and_features() 57 | dataset = StaticGraphTemporalSignal( 58 | self._edges, self._edge_weights, self.features, self.targets 59 | ) 60 | return dataset 61 | -------------------------------------------------------------------------------- /torch_geometric_temporal/dataset/windmillsmall.py: -------------------------------------------------------------------------------- 1 | import json 2 | import ssl 3 | import urllib.request 4 | import numpy as np 5 | from ..signal import StaticGraphTemporalSignal 6 | 7 | 8 | class WindmillOutputSmallDatasetLoader(object): 9 | """Hourly energy output of windmills from a European country 10 | for more than 2 years. Vertices represent 11 windmills and 11 | weighted edges describe the strength of relationships. The target 12 | variable allows for regression tasks. 13 | """ 14 | 15 | def __init__(self): 16 | self._read_web_data() 17 | 18 | def _read_web_data(self): 19 | url = "https://graphmining.ai/temporal_datasets/windmill_output_small.json" 20 | context = ssl._create_unverified_context() 21 | self._dataset = json.loads( 22 | urllib.request.urlopen(url, context=context).read().decode() 23 | ) 24 | 25 | def _get_edges(self): 26 | self._edges = np.array(self._dataset["edges"]).T 27 | 28 | def _get_edge_weights(self): 29 | self._edge_weights = np.array(self._dataset["weights"]).T 30 | 31 | def _get_targets_and_features(self): 32 | stacked_target = np.stack(self._dataset["block"]) 33 | standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / ( 34 | np.std(stacked_target, axis=0) + 10 ** -10 35 | ) 36 | self.features = [ 37 | standardized_target[i : i + self.lags, :].T 38 | for i in range(standardized_target.shape[0] - self.lags) 39 | ] 40 | self.targets = [ 41 | standardized_target[i + self.lags, :].T 42 | for i in range(standardized_target.shape[0] - self.lags) 43 | ] 44 | 45 | def get_dataset(self, lags: int = 8) -> StaticGraphTemporalSignal: 46 | """Returning the Windmill Output data iterator. 47 | 48 | Args types: 49 | * **lags** *(int)* - The number of time lags. 
50 | Return types:
51 | * **dataset** *(StaticGraphTemporalSignal)* - The Windmill Output dataset.
52 | """
53 | self.lags = lags
54 | self._get_edges()
55 | self._get_edge_weights()
56 | self._get_targets_and_features()
57 | dataset = StaticGraphTemporalSignal(
58 | self._edges, self._edge_weights, self.features, self.targets
59 | )
60 | return dataset
61 |
-------------------------------------------------------------------------------- /torch_geometric_temporal/nn/__init__.py: --------------------------------------------------------------------------------
1 | from .recurrent import *
2 | from .attention import *
3 | from .hetero import *
-------------------------------------------------------------------------------- /torch_geometric_temporal/nn/attention/__init__.py: --------------------------------------------------------------------------------
1 | from .stgcn import STConv, TemporalConv
2 | from .astgcn import ASTGCN, ChebConvAttention
3 | from .mstgcn import MSTGCN
4 | from .gman import GMAN, SpatioTemporalEmbedding, SpatioTemporalAttention
5 | from .mtgnn import MTGNN, MixProp, GraphConstructor
6 | from .tsagcn import GraphAAGCN, AAGCN
7 | from .dnntsp import DNNTSP
-------------------------------------------------------------------------------- /torch_geometric_temporal/nn/attention/stgcn.py: --------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch_geometric.nn import ChebConv
6 |
7 |
8 | class TemporalConv(nn.Module):
9 | r"""Temporal convolution block applied to nodes in the STGCN Layer.
10 | For details see: `"Spatio-Temporal Graph Convolutional Networks:
11 | A Deep Learning Framework for Traffic Forecasting."
12 | `_ Based on the temporal convolution
13 | introduced in "Convolutional Sequence to Sequence Learning" `_
14 |
15 | Args:
16 | in_channels (int): Number of input features.
17 | out_channels (int): Number of output features.
18 | kernel_size (int): Convolutional kernel size.
19 | """
20 |
21 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
22 | super(TemporalConv, self).__init__()
23 | self.conv_1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
24 | self.conv_2 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
25 | self.conv_3 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
26 |
27 | def forward(self, X: torch.FloatTensor) -> torch.FloatTensor:
28 | """Forward pass through temporal convolution block.
29 |
30 | Arg types:
31 | * **X** (torch.FloatTensor) - Input data of shape
32 | (batch_size, input_time_steps, num_nodes, in_channels).
33 |
34 | Return types:
35 | * **H** (torch.FloatTensor) - Output data of shape
36 | (batch_size, input_time_steps - (kernel_size - 1), num_nodes, out_channels).
37 | """
38 | X = X.permute(0, 3, 2, 1)
39 | P = self.conv_1(X)
40 | Q = torch.sigmoid(self.conv_2(X))
41 | PQ = P * Q
42 | H = F.relu(PQ + self.conv_3(X))
43 | H = H.permute(0, 3, 2, 1)
44 | return H
45 |
46 |
47 | class STConv(nn.Module):
48 | r"""Spatio-temporal convolution block using ChebConv Graph Convolutions.
49 | For details see: `"Spatio-Temporal Graph Convolutional Networks:
50 | A Deep Learning Framework for Traffic Forecasting"
51 | `_
52 |
53 | NB. The ST-Conv block contains two temporal convolutions (TemporalConv)
54 | with kernel size k. Hence for an input sequence of length m,
55 | the output sequence will be length m-2(k-1).
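For example, with an input sequence of m = 12 time steps and kernel size k = 3,
the output sequence has length 12 - 2(3 - 1) = 8.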
56 |
57 | Args:
58 | num_nodes (int): Number of vertices. in_channels (int): Number of input features.
59 | hidden_channels (int): Number of hidden units output by graph convolution block.
60 | out_channels (int): Number of output features.
61 | kernel_size (int): Size of the kernel considered.
62 | K (int): Chebyshev filter size :math:`K`.
63 | normalization (str, optional): The normalization scheme for the graph
64 | Laplacian (default: :obj:`"sym"`):
65 |
66 | 1. :obj:`None`: No normalization
67 | :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`
68 |
69 | 2. :obj:`"sym"`: Symmetric normalization
70 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
71 | \mathbf{D}^{-1/2}`
72 |
73 | 3. :obj:`"rw"`: Random-walk normalization
74 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
75 |
76 | You need to pass :obj:`lambda_max` to the :meth:`forward` method of
77 | this operator in case the normalization is non-symmetric.
78 | :obj:`\lambda_max` should be a :class:`torch.Tensor` of size
79 | :obj:`[num_graphs]` in a mini-batch scenario and a
80 | scalar/zero-dimensional tensor when operating on single graphs.
81 | You can pre-compute :obj:`lambda_max` via the
82 | :class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
83 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
84 | an additive bias. (default: :obj:`True`)
85 |
86 | """
87 |
88 | def __init__(
89 | self,
90 | num_nodes: int,
91 | in_channels: int,
92 | hidden_channels: int,
93 | out_channels: int,
94 | kernel_size: int,
95 | K: int,
96 | normalization: str = "sym",
97 | bias: bool = True,
98 | ):
99 | super(STConv, self).__init__()
100 | self.num_nodes = num_nodes
101 | self.in_channels = in_channels
102 | self.hidden_channels = hidden_channels
103 | self.out_channels = out_channels
104 | self.kernel_size = kernel_size
105 | self.K = K
106 | self.normalization = normalization
107 | self.bias = bias
108 |
109 | self._temporal_conv1 = TemporalConv(
110 | in_channels=in_channels,
111 | out_channels=hidden_channels,
112 | kernel_size=kernel_size,
113 | )
114 |
115 | self._graph_conv = ChebConv(
116 | in_channels=hidden_channels,
117 | out_channels=hidden_channels,
118 | K=K,
119 | normalization=normalization,
120 | bias=bias,
121 | )
122 |
123 | self._temporal_conv2 = TemporalConv(
124 | in_channels=hidden_channels,
125 | out_channels=out_channels,
126 | kernel_size=kernel_size,
127 | )
128 |
129 | self._batch_norm = nn.BatchNorm2d(num_nodes)
130 |
131 | def forward(
132 | self,
133 | X: torch.FloatTensor,
134 | edge_index: torch.LongTensor,
135 | edge_weight: torch.FloatTensor = None,
136 | ) -> torch.FloatTensor:
137 |
138 | r"""Forward pass. If edge weights are not present the forward pass
139 | defaults to an unweighted graph.
140 |
141 | Arg types:
142 | * **X** (PyTorch FloatTensor) - Sequence of node features of shape (Batch size X Input time steps X Num nodes X In channels).
143 | * **edge_index** (PyTorch LongTensor) - Graph edge indices.
144 | * **edge_weight** (PyTorch FloatTensor, optional) - Edge weight vector.
145 |
146 | Return types:
147 | * **T** (PyTorch FloatTensor) - Sequence of node features.
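
A shape note, derived from the block structure above: for an input of shape
(batch_size, M, num_nodes, in_channels), the output has shape
(batch_size, M - 2(kernel_size - 1), num_nodes, out_channels).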
148 | """ 149 | T_0 = self._temporal_conv1(X) 150 | T = torch.zeros_like(T_0).to(T_0.device) 151 | for b in range(T_0.size(0)): 152 | for t in range(T_0.size(1)): 153 | T[b][t] = self._graph_conv(T_0[b][t], edge_index, edge_weight) 154 | 155 | T = F.relu(T) 156 | T = self._temporal_conv2(T) 157 | T = T.permute(0, 2, 1, 3) 158 | T = self._batch_norm(T) 159 | T = T.permute(0, 2, 1, 3) 160 | return T 161 | -------------------------------------------------------------------------------- /torch_geometric_temporal/nn/hetero/__init__.py: -------------------------------------------------------------------------------- 1 | from .heterogclstm import HeteroGCLSTM 2 | -------------------------------------------------------------------------------- /torch_geometric_temporal/nn/recurrent/__init__.py: -------------------------------------------------------------------------------- 1 | from .gconv_gru import GConvGRU 2 | from .gconv_lstm import GConvLSTM 3 | from .lrgcn import LRGCN 4 | from .gc_lstm import GCLSTM 5 | from .dygrae import DyGrEncoder 6 | from .evolvegcnh import EvolveGCNH 7 | from .evolvegcno import EvolveGCNO 8 | from .dcrnn import DCRNN, BatchedDCRNN 9 | from .temporalgcn import TGCN 10 | from .temporalgcn import TGCN2 11 | from .attentiontemporalgcn import A3TGCN 12 | from .attentiontemporalgcn import A3TGCN2 13 | from .mpnn_lstm import MPNNLSTM 14 | from .agcrn import AGCRN 15 | -------------------------------------------------------------------------------- /torch_geometric_temporal/nn/recurrent/agcrn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn.inits import glorot, zeros 5 | 6 | 7 | class AVWGCN(nn.Module): 8 | r"""An implementation of the Node Adaptive Graph Convolution Layer. 9 | For details see: `"Adaptive Graph Convolutional Recurrent Network 10 | for Traffic Forecasting" `_ 11 | Args: 12 | in_channels (int): Number of input features. 13 | out_channels (int): Number of output features. 14 | K (int): Filter size :math:`K`. 15 | embedding_dimensions (int): Number of node embedding dimensions. 16 | """ 17 | 18 | def __init__( 19 | self, in_channels: int, out_channels: int, K: int, embedding_dimensions: int 20 | ): 21 | super(AVWGCN, self).__init__() 22 | self.K = K 23 | self.weights_pool = torch.nn.Parameter( 24 | torch.Tensor(embedding_dimensions, K, in_channels, out_channels) 25 | ) 26 | self.bias_pool = torch.nn.Parameter( 27 | torch.Tensor(embedding_dimensions, out_channels) 28 | ) 29 | glorot(self.weights_pool) 30 | zeros(self.bias_pool) 31 | 32 | def forward(self, X: torch.FloatTensor, E: torch.FloatTensor) -> torch.FloatTensor: 33 | r"""Making a forward pass. 34 | Arg types: 35 | * **X** (PyTorch Float Tensor) - Node features. 36 | * **E** (PyTorch Float Tensor) - Node embeddings. 37 | Return types: 38 | * **X_G** (PyTorch Float Tensor) - Hidden state matrix for all nodes. 
39 | """ 40 | 41 | number_of_nodes = E.shape[0] 42 | supports = F.softmax(F.relu(torch.mm(E, E.transpose(0, 1))), dim=1) 43 | support_set = [torch.eye(number_of_nodes).to(supports.device), supports] 44 | for _ in range(2, self.K): 45 | support = torch.matmul(2 * supports, support_set[-1]) - support_set[-2] 46 | support_set.append(support) 47 | supports = torch.stack(support_set, dim=0) 48 | W = torch.einsum("nd,dkio->nkio", E, self.weights_pool) 49 | bias = torch.matmul(E, self.bias_pool) 50 | X_G = torch.einsum("knm,bmc->bknc", supports, X) 51 | X_G = X_G.permute(0, 2, 1, 3) 52 | X_G = torch.einsum("bnki,nkio->bno", X_G, W) + bias 53 | return X_G 54 | 55 | 56 | class AGCRN(nn.Module): 57 | r"""An implementation of the Adaptive Graph Convolutional Recurrent Unit. 58 | For details see: `"Adaptive Graph Convolutional Recurrent Network 59 | for Traffic Forecasting" `_ 60 | Args: 61 | number_of_nodes (int): Number of vertices. 62 | in_channels (int): Number of input features. 63 | out_channels (int): Number of output features. 64 | K (int): Filter size :math:`K`. 65 | embedding_dimensions (int): Number of node embedding dimensions. 66 | """ 67 | 68 | def __init__( 69 | self, 70 | number_of_nodes: int, 71 | in_channels: int, 72 | out_channels: int, 73 | K: int, 74 | embedding_dimensions: int, 75 | ): 76 | super(AGCRN, self).__init__() 77 | 78 | self.number_of_nodes = number_of_nodes 79 | self.in_channels = in_channels 80 | self.out_channels = out_channels 81 | self.K = K 82 | self.embedding_dimensions = embedding_dimensions 83 | self._setup_layers() 84 | 85 | def _setup_layers(self): 86 | self._gate = AVWGCN( 87 | in_channels=self.in_channels + self.out_channels, 88 | out_channels=2 * self.out_channels, 89 | K=self.K, 90 | embedding_dimensions=self.embedding_dimensions, 91 | ) 92 | 93 | self._update = AVWGCN( 94 | in_channels=self.in_channels + self.out_channels, 95 | out_channels=self.out_channels, 96 | K=self.K, 97 | embedding_dimensions=self.embedding_dimensions, 98 | ) 99 | 100 | def _set_hidden_state(self, X, H): 101 | if H is None: 102 | H = torch.zeros(X.shape[0], X.shape[1], self.out_channels).to(X.device) 103 | return H 104 | 105 | def forward( 106 | self, X: torch.FloatTensor, E: torch.FloatTensor, H: torch.FloatTensor = None 107 | ) -> torch.FloatTensor: 108 | r"""Making a forward pass. 109 | Arg types: 110 | * **X** (PyTorch Float Tensor) - Node feature matrix. 111 | * **E** (PyTorch Float Tensor) - Node embedding matrix. 112 | * **H** (PyTorch Float Tensor) - Node hidden state matrix. Default is None. 113 | Return types: 114 | * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes. 115 | """ 116 | H = self._set_hidden_state(X, H) 117 | X_H = torch.cat((X, H), dim=-1) 118 | Z_R = torch.sigmoid(self._gate(X_H, E)) 119 | Z, R = torch.split(Z_R, self.out_channels, dim=-1) 120 | C = torch.cat((X, Z * H), dim=-1) 121 | HC = torch.tanh(self._update(C, E)) 122 | H = R * H + (1 - R) * HC 123 | return H 124 | -------------------------------------------------------------------------------- /torch_geometric_temporal/nn/recurrent/attentiontemporalgcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .temporalgcn import TGCN 3 | from .temporalgcn import TGCN2 4 | from torch_geometric.nn import GCNConv 5 | 6 | 7 | class A3TGCN(torch.nn.Module): 8 | r"""An implementation of the Attention Temporal Graph Convolutional Cell. 
9 | For details see this paper: `"A3T-GCN: Attention Temporal Graph Convolutional
10 | Network for Traffic Forecasting." `_
11 |
12 | Args:
13 | in_channels (int): Number of input features.
14 | out_channels (int): Number of output features.
15 | periods (int): Number of time periods.
16 | improved (bool): Stronger self loops (default :obj:`False`).
17 | cached (bool): Caching the message weights (default :obj:`False`).
18 | add_self_loops (bool): Adding self-loops for smoothing (default :obj:`True`).
19 | """
20 |
21 | def __init__(
22 | self,
23 | in_channels: int,
24 | out_channels: int,
25 | periods: int,
26 | improved: bool = False,
27 | cached: bool = False,
28 | add_self_loops: bool = True
29 | ):
30 | super(A3TGCN, self).__init__()
31 |
32 | self.in_channels = in_channels
33 | self.out_channels = out_channels
34 | self.periods = periods
35 | self.improved = improved
36 | self.cached = cached
37 | self.add_self_loops = add_self_loops
38 | self._setup_layers()
39 |
40 | def _setup_layers(self):
41 | self._base_tgcn = TGCN(
42 | in_channels=self.in_channels,
43 | out_channels=self.out_channels,
44 | improved=self.improved,
45 | cached=self.cached,
46 | add_self_loops=self.add_self_loops,
47 | )
48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
49 | self._attention = torch.nn.Parameter(torch.empty(self.periods, device=device))
50 | torch.nn.init.uniform_(self._attention)
51 |
52 | def forward(
53 | self,
54 | X: torch.FloatTensor,
55 | edge_index: torch.LongTensor,
56 | edge_weight: torch.FloatTensor = None,
57 | H: torch.FloatTensor = None,
58 | ) -> torch.FloatTensor:
59 | """
60 | Making a forward pass. If edge weights are not present the forward pass
61 | defaults to an unweighted graph. If the hidden state matrix is not present
62 | when the forward pass is called it is initialized with zeros.
63 |
64 | Arg types:
65 | * **X** (PyTorch Float Tensor): Node features for T time periods.
66 | * **edge_index** (PyTorch Long Tensor): Graph edge indices.
67 | * **edge_weight** (PyTorch Float Tensor, optional): Edge weight vector.
68 | * **H** (PyTorch Float Tensor, optional): Hidden state matrix for all nodes.
69 |
70 | Return types:
71 | * **H** (PyTorch Float Tensor): Hidden state matrix for all nodes.
72 | """
73 | H_accum = 0
74 | probs = torch.nn.functional.softmax(self._attention, dim=0)
75 | for period in range(self.periods):
76 | H_accum = H_accum + probs[period] * self._base_tgcn(
77 | X[:, :, period], edge_index, edge_weight, H
78 | )
79 | return H_accum
80 |
81 |
82 |
83 | class A3TGCN2(torch.nn.Module):
84 | r"""An implementation THAT SUPPORTS BATCHES of the Attention Temporal Graph Convolutional Cell.
85 | For details see this paper: `"A3T-GCN: Attention Temporal Graph Convolutional
86 | Network for Traffic Forecasting." `_
87 |
88 | Args:
89 | in_channels (int): Number of input features.
90 | out_channels (int): Number of output features.
91 | periods (int): Number of time periods. batch_size (int): Size of the batch dimension.
92 | improved (bool): Stronger self loops (default :obj:`False`).
93 | cached (bool): Caching the message weights (default :obj:`False`).
94 | add_self_loops (bool): Adding self-loops for smoothing (default :obj:`True`).
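
Example (a minimal sketch with placeholder tensors, shown for a CPU-only
setup; note that the attention vector is placed on CUDA when available, so
inputs must live on the same device):

    >>> import torch
    >>> X = torch.rand(8, 207, 2, 12)  # (batch_size, num_nodes, in_channels, periods)
    >>> edge_index = torch.tensor([[0, 1], [1, 2]])
    >>> model = A3TGCN2(in_channels=2, out_channels=32, periods=12, batch_size=8)
    >>> H = model(X, edge_index)       # (8, 207, 32)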
95 | """ 96 | 97 | def __init__( 98 | self, 99 | in_channels: int, 100 | out_channels: int, 101 | periods: int, 102 | batch_size:int, 103 | improved: bool = False, 104 | cached: bool = False, 105 | add_self_loops: bool = True): 106 | super(A3TGCN2, self).__init__() 107 | 108 | self.in_channels = in_channels # 2 109 | self.out_channels = out_channels # 32 110 | self.periods = periods # 12 111 | self.improved = improved 112 | self.cached = cached 113 | self.add_self_loops = add_self_loops 114 | self.batch_size = batch_size 115 | self._setup_layers() 116 | 117 | def _setup_layers(self): 118 | self._base_tgcn = TGCN2( 119 | in_channels=self.in_channels, 120 | out_channels=self.out_channels, 121 | batch_size=self.batch_size, 122 | improved=self.improved, 123 | cached=self.cached, 124 | add_self_loops=self.add_self_loops) 125 | 126 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 127 | self._attention = torch.nn.Parameter(torch.empty(self.periods, device=device)) 128 | torch.nn.init.uniform_(self._attention) 129 | 130 | def forward( 131 | self, 132 | X: torch.FloatTensor, 133 | edge_index: torch.LongTensor, 134 | edge_weight: torch.FloatTensor = None, 135 | H: torch.FloatTensor = None 136 | ) -> torch.FloatTensor: 137 | """ 138 | Making a forward pass. If edge weights are not present the forward pass 139 | defaults to an unweighted graph. If the hidden state matrix is not present 140 | when the forward pass is called it is initialized with zeros. 141 | 142 | Arg types: 143 | * **X** (PyTorch Float Tensor): Node features for T time periods. 144 | * **edge_index** (PyTorch Long Tensor): Graph edge indices. 145 | * **edge_weight** (PyTorch Long Tensor, optional)*: Edge weight vector. 146 | * **H** (PyTorch Float Tensor, optional): Hidden state matrix for all nodes. 147 | 148 | Return types: 149 | * **H** (PyTorch Float Tensor): Hidden state matrix for all nodes. 150 | """ 151 | H_accum = 0 152 | probs = torch.nn.functional.softmax(self._attention, dim=0) 153 | for period in range(self.periods): 154 | 155 | H_accum = H_accum + probs[period] * self._base_tgcn( X[:, :, :, period], edge_index, edge_weight, H) #([32, 207, 32] 156 | 157 | return H_accum -------------------------------------------------------------------------------- /torch_geometric_temporal/nn/recurrent/dygrae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import LSTM 3 | from torch_geometric.nn import GatedGraphConv 4 | 5 | 6 | class DyGrEncoder(torch.nn.Module): 7 | r"""An implementation of the integrated Gated Graph Convolution Long Short 8 | Term Memory Layer. For details see this paper: `"Predictive Temporal Embedding 9 | of Dynamic Graphs." `_ 10 | 11 | Args: 12 | conv_out_channels (int): Number of output channels for the GGCN. 13 | conv_num_layers (int): Number of Gated Graph Convolutions. 14 | conv_aggr (str): Aggregation scheme to use 15 | (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). 16 | lstm_out_channels (int): Number of LSTM channels. 17 | lstm_num_layers (int): Number of neurons in LSTM. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | conv_out_channels: int, 23 | conv_num_layers: int, 24 | conv_aggr: str, 25 | lstm_out_channels: int, 26 | lstm_num_layers: int, 27 | ): 28 | super(DyGrEncoder, self).__init__() 29 | assert conv_aggr in ["mean", "add", "max"], "Wrong aggregator." 
30 | self.conv_out_channels = conv_out_channels 31 | self.conv_num_layers = conv_num_layers 32 | self.conv_aggr = conv_aggr 33 | self.lstm_out_channels = lstm_out_channels 34 | self.lstm_num_layers = lstm_num_layers 35 | self._create_layers() 36 | 37 | def _create_layers(self): 38 | self.conv_layer = GatedGraphConv( 39 | out_channels=self.conv_out_channels, 40 | num_layers=self.conv_num_layers, 41 | aggr=self.conv_aggr, 42 | bias=True, 43 | ) 44 | 45 | self.recurrent_layer = LSTM( 46 | input_size=self.conv_out_channels, 47 | hidden_size=self.lstm_out_channels, 48 | num_layers=self.lstm_num_layers, 49 | ) 50 | 51 | def forward( 52 | self, 53 | X: torch.FloatTensor, 54 | edge_index: torch.LongTensor, 55 | edge_weight: torch.FloatTensor = None, 56 | H: torch.FloatTensor = None, 57 | C: torch.FloatTensor = None, 58 | ) -> torch.FloatTensor: 59 | """ 60 | Making a forward pass. If the hidden state and cell state matrices are 61 | not present when the forward pass is called these are initialized with zeros. 62 | 63 | Arg types: 64 | * **X** *(PyTorch Float Tensor)* - Node features. 65 | * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices. 66 | * **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector. 67 | * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes. 68 | * **C** *(PyTorch Float Tensor, optional)* - Cell state matrix for all nodes. 69 | 70 | Return types: 71 | * **H_tilde** *(PyTorch Float Tensor)* - Output matrix for all nodes. 72 | * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes. 73 | * **C** *(PyTorch Float Tensor)* - Cell state matrix for all nodes. 74 | """ 75 | H_tilde = self.conv_layer(X, edge_index, edge_weight) 76 | H_tilde = H_tilde[None, :, :] 77 | if H is None and C is None: 78 | H_tilde, (H, C) = self.recurrent_layer(H_tilde) 79 | elif H is not None and C is not None: 80 | H = H[None, :, :] 81 | C = C[None, :, :] 82 | H_tilde, (H, C) = self.recurrent_layer(H_tilde, (H, C)) 83 | else: 84 | raise ValueError("Invalid hidden state and cell matrices.") 85 | H_tilde = H_tilde.squeeze() 86 | H = H.squeeze() 87 | C = C.squeeze() 88 | return H_tilde, H, C 89 | -------------------------------------------------------------------------------- /torch_geometric_temporal/nn/recurrent/evolvegcnh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import GRU 3 | from torch_geometric.nn import TopKPooling 4 | 5 | from .evolvegcno import glorot, GCNConv_Fixed_W 6 | 7 | 8 | class EvolveGCNH(torch.nn.Module): 9 | r"""An implementation of the Evolving Graph Convolutional Hidden Layer. 10 | For details see this paper: `"EvolveGCN: Evolving Graph Convolutional 11 | Networks for Dynamic Graph." `_ 12 | 13 | Args: 14 | num_of_nodes (int): Number of vertices. 15 | in_channels (int): Number of filters. 16 | improved (bool, optional): If set to :obj:`True`, the layer computes 17 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 18 | (default: :obj:`False`) 19 | cached (bool, optional): If set to :obj:`True`, the layer will cache 20 | the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 21 | \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the 22 | cached version for further executions. 23 | This parameter should only be set to :obj:`True` in transductive 24 | learning scenarios. (default: :obj:`False`) 25 | normalize (bool, optional): Whether to add self-loops and apply 26 | symmetric normalization. 
(default: :obj:`True`) 27 | add_self_loops (bool, optional): If set to :obj:`False`, will not add 28 | self-loops to the input graph. (default: :obj:`True`) 29 | """ 30 | 31 | def __init__( 32 | self, 33 | num_of_nodes: int, 34 | in_channels: int, 35 | improved: bool = False, 36 | cached: bool = False, 37 | normalize: bool = True, 38 | add_self_loops: bool = True, 39 | ): 40 | super(EvolveGCNH, self).__init__() 41 | 42 | self.num_of_nodes = num_of_nodes 43 | self.in_channels = in_channels 44 | self.improved = improved 45 | self.cached = cached 46 | self.normalize = normalize 47 | self.add_self_loops = add_self_loops 48 | self.weight = None 49 | self.initial_weight = torch.nn.Parameter(torch.Tensor(1, in_channels, in_channels)) 50 | self._create_layers() 51 | self.reset_parameters() 52 | 53 | def reset_parameters(self): 54 | glorot(self.initial_weight) 55 | 56 | def reinitialize_weight(self): 57 | self.weight = None 58 | 59 | def _create_layers(self): 60 | 61 | self.ratio = self.in_channels / self.num_of_nodes 62 | 63 | self.pooling_layer = TopKPooling(self.in_channels, self.ratio) 64 | 65 | self.recurrent_layer = GRU( 66 | input_size=self.in_channels, hidden_size=self.in_channels, num_layers=1 67 | ) 68 | 69 | self.conv_layer = GCNConv_Fixed_W( 70 | in_channels=self.in_channels, 71 | out_channels=self.in_channels, 72 | improved=self.improved, 73 | cached=self.cached, 74 | normalize=self.normalize, 75 | add_self_loops=self.add_self_loops 76 | ) 77 | 78 | def forward( 79 | self, 80 | X: torch.FloatTensor, 81 | edge_index: torch.LongTensor, 82 | edge_weight: torch.FloatTensor = None, 83 | ) -> torch.FloatTensor: 84 | """ 85 | Making a forward pass. 86 | 87 | Arg types: 88 | * **X** *(PyTorch Float Tensor)* - Node embedding. 89 | * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices. 90 | * **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector. 91 | 92 | Return types: 93 | * **X** *(PyTorch Float Tensor)* - Output matrix for all nodes. 94 | """ 95 | X_tilde = self.pooling_layer(X, edge_index) 96 | X_tilde = X_tilde[0][None, :, :] 97 | if self.weight is None: 98 | _, self.weight = self.recurrent_layer(X_tilde, self.initial_weight) 99 | else: 100 | _, self.weight = self.recurrent_layer(X_tilde, self.weight) 101 | X = self.conv_layer(self.weight.squeeze(dim=0), X, edge_index, edge_weight) 102 | return X 103 | -------------------------------------------------------------------------------- /torch_geometric_temporal/nn/recurrent/lrgcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_geometric.nn import RGCNConv 4 | from torch_geometric.nn.inits import glorot, zeros 5 | 6 | 7 | class LRGCN(torch.nn.Module): 8 | r"""An implementation of the Long Short Term Memory Relational 9 | Graph Convolution Layer. For details see this paper: `"Predicting Path 10 | Failure In Time-Evolving Graphs." `_ 11 | 12 | Args: 13 | in_channels (int): Number of input features. 14 | out_channels (int): Number of output features. 15 | num_relations (int): Number of relations. 16 | num_bases (int): Number of bases. 
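
Example (a minimal sketch with placeholder tensors; ``edge_type`` holds one
relation id per edge):

    >>> import torch
    >>> X = torch.rand(4, 8)
    >>> edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
    >>> edge_type = torch.tensor([0, 1, 2, 0])
    >>> model = LRGCN(in_channels=8, out_channels=16, num_relations=3, num_bases=2)
    >>> H, C = model(X, edge_index, edge_type)          # states start at zero
    >>> H, C = model(X, edge_index, edge_type, H, C)    # carry the states forward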
17 | """ 18 | 19 | def __init__( 20 | self, in_channels: int, out_channels: int, num_relations: int, num_bases: int 21 | ): 22 | super(LRGCN, self).__init__() 23 | 24 | self.in_channels = in_channels 25 | self.out_channels = out_channels 26 | self.num_relations = num_relations 27 | self.num_bases = num_bases 28 | self._create_layers() 29 | 30 | def _create_input_gate_layers(self): 31 | 32 | self.conv_x_i = RGCNConv( 33 | in_channels=self.in_channels, 34 | out_channels=self.out_channels, 35 | num_relations=self.num_relations, 36 | num_bases=self.num_bases, 37 | ) 38 | 39 | self.conv_h_i = RGCNConv( 40 | in_channels=self.out_channels, 41 | out_channels=self.out_channels, 42 | num_relations=self.num_relations, 43 | num_bases=self.num_bases, 44 | ) 45 | 46 | def _create_forget_gate_layers(self): 47 | 48 | self.conv_x_f = RGCNConv( 49 | in_channels=self.in_channels, 50 | out_channels=self.out_channels, 51 | num_relations=self.num_relations, 52 | num_bases=self.num_bases, 53 | ) 54 | 55 | self.conv_h_f = RGCNConv( 56 | in_channels=self.out_channels, 57 | out_channels=self.out_channels, 58 | num_relations=self.num_relations, 59 | num_bases=self.num_bases, 60 | ) 61 | 62 | def _create_cell_state_layers(self): 63 | 64 | self.conv_x_c = RGCNConv( 65 | in_channels=self.in_channels, 66 | out_channels=self.out_channels, 67 | num_relations=self.num_relations, 68 | num_bases=self.num_bases, 69 | ) 70 | 71 | self.conv_h_c = RGCNConv( 72 | in_channels=self.out_channels, 73 | out_channels=self.out_channels, 74 | num_relations=self.num_relations, 75 | num_bases=self.num_bases, 76 | ) 77 | 78 | def _create_output_gate_layers(self): 79 | 80 | self.conv_x_o = RGCNConv( 81 | in_channels=self.in_channels, 82 | out_channels=self.out_channels, 83 | num_relations=self.num_relations, 84 | num_bases=self.num_bases, 85 | ) 86 | 87 | self.conv_h_o = RGCNConv( 88 | in_channels=self.out_channels, 89 | out_channels=self.out_channels, 90 | num_relations=self.num_relations, 91 | num_bases=self.num_bases, 92 | ) 93 | 94 | def _create_layers(self): 95 | self._create_input_gate_layers() 96 | self._create_forget_gate_layers() 97 | self._create_cell_state_layers() 98 | self._create_output_gate_layers() 99 | 100 | def _set_hidden_state(self, X, H): 101 | if H is None: 102 | H = torch.zeros(X.shape[0], self.out_channels).to(X.device) 103 | return H 104 | 105 | def _set_cell_state(self, X, C): 106 | if C is None: 107 | C = torch.zeros(X.shape[0], self.out_channels).to(X.device) 108 | return C 109 | 110 | def _calculate_input_gate(self, X, edge_index, edge_type, H, C): 111 | I = self.conv_x_i(X, edge_index, edge_type) 112 | I = I + self.conv_h_i(H, edge_index, edge_type) 113 | I = torch.sigmoid(I) 114 | return I 115 | 116 | def _calculate_forget_gate(self, X, edge_index, edge_type, H, C): 117 | F = self.conv_x_f(X, edge_index, edge_type) 118 | F = F + self.conv_h_f(H, edge_index, edge_type) 119 | F = torch.sigmoid(F) 120 | return F 121 | 122 | def _calculate_cell_state(self, X, edge_index, edge_type, H, C, I, F): 123 | T = self.conv_x_c(X, edge_index, edge_type) 124 | T = T + self.conv_h_c(H, edge_index, edge_type) 125 | T = torch.tanh(T) 126 | C = F * C + I * T 127 | return C 128 | 129 | def _calculate_output_gate(self, X, edge_index, edge_type, H, C): 130 | O = self.conv_x_o(X, edge_index, edge_type) 131 | O = O + self.conv_h_o(H, edge_index, edge_type) 132 | O = torch.sigmoid(O) 133 | return O 134 | 135 | def _calculate_hidden_state(self, O, C): 136 | H = O * torch.tanh(C) 137 | return H 138 | 139 | def forward( 140 | self, 141 | 
X: torch.FloatTensor, 142 | edge_index: torch.LongTensor, 143 | edge_type: torch.LongTensor, 144 | H: torch.FloatTensor = None, 145 | C: torch.FloatTensor = None, 146 | ) -> torch.FloatTensor: 147 | """ 148 | Making a forward pass. If the hidden state and cell state matrices are 149 | not present when the forward pass is called these are initialized with zeros. 150 | 151 | Arg types: 152 | * **X** *(PyTorch Float Tensor)* - Node features. 153 | * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices. 154 | * **edge_type** *(PyTorch Long Tensor)* - Edge type vector. 155 | * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes. 156 | * **C** *(PyTorch Float Tensor, optional)* - Cell state matrix for all nodes. 157 | 158 | Return types: 159 | * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes. 160 | * **C** *(PyTorch Float Tensor)* - Cell state matrix for all nodes. 161 | """ 162 | H = self._set_hidden_state(X, H) 163 | C = self._set_cell_state(X, C) 164 | I = self._calculate_input_gate(X, edge_index, edge_type, H, C) 165 | F = self._calculate_forget_gate(X, edge_index, edge_type, H, C) 166 | C = self._calculate_cell_state(X, edge_index, edge_type, H, C, I, F) 167 | O = self._calculate_output_gate(X, edge_index, edge_type, H, C) 168 | H = self._calculate_hidden_state(O, C) 169 | return H, C 170 | -------------------------------------------------------------------------------- /torch_geometric_temporal/nn/recurrent/mpnn_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GCNConv 5 | 6 | 7 | class MPNNLSTM(nn.Module): 8 | r"""An implementation of the Message Passing Neural Network with Long Short Term Memory. 9 | For details see this paper: `"Transfer Graph Neural Networks for Pandemic Forecasting." `_ 10 | 11 | Args: 12 | in_channels (int): Number of input features. 13 | hidden_size (int): Dimension of hidden representations. 14 | num_nodes (int): Number of nodes in the network. 15 | window (int): Number of past samples included in the input. 16 | dropout (float): Dropout rate. 
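
Note: as implied by the reshape at the start of :meth:`forward`, the node
features of a full window are expected stacked along the node dimension,
i.e. ``X`` of shape (window * num_nodes, in_channels).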
17 | """ 18 | 19 | def __init__( 20 | self, 21 | in_channels: int, 22 | hidden_size: int, 23 | num_nodes: int, 24 | window: int, 25 | dropout: float, 26 | ): 27 | super(MPNNLSTM, self).__init__() 28 | 29 | self.window = window 30 | self.num_nodes = num_nodes 31 | self.hidden_size = hidden_size 32 | self.dropout = dropout 33 | self.in_channels = in_channels 34 | 35 | self._create_parameters_and_layers() 36 | 37 | def _create_parameters_and_layers(self): 38 | 39 | self._convolution_1 = GCNConv(self.in_channels, self.hidden_size) 40 | self._convolution_2 = GCNConv(self.hidden_size, self.hidden_size) 41 | 42 | self._batch_norm_1 = nn.BatchNorm1d(self.hidden_size) 43 | self._batch_norm_2 = nn.BatchNorm1d(self.hidden_size) 44 | 45 | self._recurrent_1 = nn.LSTM(2 * self.hidden_size, self.hidden_size, 1) 46 | self._recurrent_2 = nn.LSTM(self.hidden_size, self.hidden_size, 1) 47 | 48 | def _graph_convolution_1(self, X, edge_index, edge_weight): 49 | X = F.relu(self._convolution_1(X, edge_index, edge_weight)) 50 | X = self._batch_norm_1(X) 51 | X = F.dropout(X, p=self.dropout, training=self.training) 52 | return X 53 | 54 | def _graph_convolution_2(self, X, edge_index, edge_weight): 55 | X = F.relu(self._convolution_2(X, edge_index, edge_weight)) 56 | X = self._batch_norm_2(X) 57 | X = F.dropout(X, p=self.dropout, training=self.training) 58 | return X 59 | 60 | def forward( 61 | self, 62 | X: torch.FloatTensor, 63 | edge_index: torch.LongTensor, 64 | edge_weight: torch.FloatTensor, 65 | ) -> torch.FloatTensor: 66 | """ 67 | Making a forward pass through the whole architecture. 68 | 69 | Arg types: 70 | * **X** *(PyTorch FloatTensor)* - Node features. 71 | * **edge_index** *(PyTorch LongTensor)* - Graph edge indices. 72 | * **edge_weight** *(PyTorch LongTensor, optional)* - Edge weight vector. 73 | 74 | Return types: 75 | * **H** *(PyTorch FloatTensor)* - The hidden representation of size 2*nhid+in_channels+window-1 for each node. 
76 | """ 77 | R = list() 78 | 79 | S = X.view(-1, self.window, self.num_nodes, self.in_channels) 80 | S = torch.transpose(S, 1, 2) 81 | S = S.reshape(-1, self.window, self.in_channels) 82 | O = [S[:, 0, :]] 83 | 84 | for l in range(1, self.window): 85 | O.append(S[:, l, self.in_channels - 1].unsqueeze(1)) 86 | 87 | S = torch.cat(O, dim=1) 88 | 89 | X = self._graph_convolution_1(X, edge_index, edge_weight) 90 | R.append(X) 91 | 92 | X = self._graph_convolution_2(X, edge_index, edge_weight) 93 | R.append(X) 94 | 95 | X = torch.cat(R, dim=1) 96 | 97 | X = X.view(-1, self.window, self.num_nodes, X.size(1)) 98 | X = torch.transpose(X, 0, 1) 99 | X = X.contiguous().view(self.window, -1, X.size(3)) 100 | 101 | X, (H_1, _) = self._recurrent_1(X) 102 | X, (H_2, _) = self._recurrent_2(X) 103 | 104 | H = torch.cat([H_1[0, :, :], H_2[0, :, :], S], dim=1) 105 | return H 106 | -------------------------------------------------------------------------------- /torch_geometric_temporal/signal/__init__.py: -------------------------------------------------------------------------------- 1 | from .dynamic_graph_temporal_signal import * 2 | from .dynamic_graph_temporal_signal_batch import * 3 | 4 | from .static_graph_temporal_signal import * 5 | from .static_graph_temporal_signal_batch import * 6 | 7 | from .dynamic_graph_static_signal import * 8 | from .dynamic_graph_static_signal_batch import * 9 | 10 | from .dynamic_hetero_graph_temporal_signal import * 11 | from .dynamic_hetero_graph_temporal_signal_batch import * 12 | 13 | from .static_hetero_graph_temporal_signal import * 14 | from .static_hetero_graph_temporal_signal_batch import * 15 | 16 | from .dynamic_hetero_graph_static_signal import * 17 | from .dynamic_hetero_graph_static_signal_batch import * 18 | 19 | from .train_test_split import * 20 | 21 | # from .index_dataset import * 22 | -------------------------------------------------------------------------------- /torch_geometric_temporal/signal/dynamic_graph_static_signal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Sequence, Union 4 | from torch_geometric.data import Data 5 | 6 | 7 | Edge_Indices = Sequence[Union[np.ndarray, None]] 8 | Edge_Weights = Sequence[Union[np.ndarray, None]] 9 | Node_Feature = Union[np.ndarray, None] 10 | Targets = Sequence[Union[np.ndarray, None]] 11 | Additional_Features = Sequence[np.ndarray] 12 | 13 | 14 | class DynamicGraphStaticSignal(object): 15 | r"""A data iterator object to contain a dynamic graph with a 16 | changing edge set and weights . The node labels 17 | (target) are also dynamic. The iterator returns a single discrete temporal 18 | snapshot for a time period (e.g. day or week). This single snapshot is a 19 | Pytorch Geometric Data object. Between two temporal snapshots the edges, 20 | edge weights, target matrices and optionally passed attributes might change. 21 | 22 | Args: 23 | edge_indices (Sequence of Numpy arrays): Sequence of edge index tensors. 24 | edge_weights (Sequence of Numpy arrays): Sequence of edge weight tensors. 25 | feature (Numpy array): Node feature tensor. 26 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors. 27 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes. 
28 | """ 29 | 30 | def __init__( 31 | self, 32 | edge_indices: Edge_Indices, 33 | edge_weights: Edge_Weights, 34 | feature: Node_Feature, 35 | targets: Targets, 36 | **kwargs: Additional_Features 37 | ): 38 | self.edge_indices = edge_indices 39 | self.edge_weights = edge_weights 40 | self.feature = feature 41 | self.targets = targets 42 | self.additional_feature_keys = [] 43 | for key, value in kwargs.items(): 44 | setattr(self, key, value) 45 | self.additional_feature_keys.append(key) 46 | self._check_temporal_consistency() 47 | self._set_snapshot_count() 48 | 49 | def _check_temporal_consistency(self): 50 | assert len(self.edge_indices) == len( 51 | self.edge_weights 52 | ), "Temporal dimension inconsistency." 53 | assert len(self.targets) == len( 54 | self.edge_indices 55 | ), "Temporal dimension inconsistency." 56 | for key in self.additional_feature_keys: 57 | assert len(self.targets) == len( 58 | getattr(self, key) 59 | ), "Temporal dimension inconsistency." 60 | 61 | def _set_snapshot_count(self): 62 | self.snapshot_count = len(self.targets) 63 | 64 | def _get_edge_index(self, time_index: int): 65 | if self.edge_indices[time_index] is None: 66 | return self.edge_indices[time_index] 67 | else: 68 | return torch.LongTensor(self.edge_indices[time_index]) 69 | 70 | def _get_edge_weight(self, time_index: int): 71 | if self.edge_weights[time_index] is None: 72 | return self.edge_weights[time_index] 73 | else: 74 | return torch.FloatTensor(self.edge_weights[time_index]) 75 | 76 | def _get_feature(self): 77 | if self.feature is None: 78 | return self.feature 79 | else: 80 | return torch.FloatTensor(self.feature) 81 | 82 | def _get_target(self, time_index: int): 83 | if self.targets[time_index] is None: 84 | return self.targets[time_index] 85 | else: 86 | if self.targets[time_index].dtype.kind == "i": 87 | return torch.LongTensor(self.targets[time_index]) 88 | elif self.targets[time_index].dtype.kind == "f": 89 | return torch.FloatTensor(self.targets[time_index]) 90 | 91 | def _get_additional_feature(self, time_index: int, feature_key: str): 92 | feature = getattr(self, feature_key)[time_index] 93 | if feature.dtype.kind == "i": 94 | return torch.LongTensor(feature) 95 | elif feature.dtype.kind == "f": 96 | return torch.FloatTensor(feature) 97 | 98 | def _get_additional_features(self, time_index: int): 99 | additional_features = { 100 | key: self._get_additional_feature(time_index, key) 101 | for key in self.additional_feature_keys 102 | } 103 | return additional_features 104 | 105 | def __len__(self): 106 | return len(self.targets) 107 | 108 | def __getitem__(self, time_index: Union[int, slice]): 109 | if isinstance(time_index, slice): 110 | snapshot = DynamicGraphStaticSignal( 111 | self.edge_indices[time_index], 112 | self.edge_weights[time_index], 113 | self.feature, 114 | self.targets[time_index], 115 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys} 116 | ) 117 | else: 118 | x = self._get_feature() 119 | edge_index = self._get_edge_index(time_index) 120 | edge_weight = self._get_edge_weight(time_index) 121 | y = self._get_target(time_index) 122 | additional_features = self._get_additional_features(time_index) 123 | 124 | snapshot = Data(x=x, edge_index=edge_index, edge_attr=edge_weight, 125 | y=y, **additional_features) 126 | return snapshot 127 | 128 | def __next__(self): 129 | if self.t < len(self.targets): 130 | snapshot = self[self.t] 131 | self.t = self.t + 1 132 | return snapshot 133 | else: 134 | self.t = 0 135 | raise StopIteration 136 | 137 | 
def __iter__(self): 138 | self.t = 0 139 | return self 140 | -------------------------------------------------------------------------------- /torch_geometric_temporal/signal/dynamic_graph_static_signal_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Sequence, Union 4 | from torch_geometric.data import Batch 5 | 6 | 7 | Edge_Indices = Sequence[Union[np.ndarray, None]] 8 | Edge_Weights = Sequence[Union[np.ndarray, None]] 9 | Node_Feature = Union[np.ndarray, None] 10 | Targets = Sequence[Union[np.ndarray, None]] 11 | Batches = Sequence[Union[np.ndarray, None]] 12 | Additional_Features = Sequence[np.ndarray] 13 | 14 | 15 | class DynamicGraphStaticSignalBatch(object): 16 | r"""A batch iterator object to contain a dynamic graph with a 17 | changing edge set and weights. The node labels 18 | (target) are also dynamic. The iterator returns a single discrete temporal 19 | snapshot for a time period (e.g. day or week). This single snapshot is a 20 | Pytorch Geometric Batch object. Between two temporal snapshots the edges, 21 | batch memberships, edge weights, target matrices and optionally passed 22 | attributes might change. 23 | 24 | Args: 25 | edge_indices (Sequence of Numpy arrays): Sequence of edge index tensors. 26 | edge_weights (Sequence of Numpy arrays): Sequence of edge weight tensors. 27 | feature (Numpy array): Node feature tensor. 28 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors. 29 | batches (Sequence of Numpy arrays): Sequence of batch index tensors. 30 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | edge_indices: Edge_Indices, 36 | edge_weights: Edge_Weights, 37 | feature: Node_Feature, 38 | targets: Targets, 39 | batches: Batches, 40 | **kwargs: Additional_Features 41 | ): 42 | self.edge_indices = edge_indices 43 | self.edge_weights = edge_weights 44 | self.feature = feature 45 | self.targets = targets 46 | self.batches = batches 47 | self.additional_feature_keys = [] 48 | for key, value in kwargs.items(): 49 | setattr(self, key, value) 50 | self.additional_feature_keys.append(key) 51 | self._check_temporal_consistency() 52 | self._set_snapshot_count() 53 | 54 | def _check_temporal_consistency(self): 55 | assert len(self.edge_indices) == len( 56 | self.edge_weights 57 | ), "Temporal dimension inconsistency." 58 | assert len(self.targets) == len( 59 | self.edge_indices 60 | ), "Temporal dimension inconsistency." 61 | assert len(self.batches) == len( 62 | self.edge_indices 63 | ), "Temporal dimension inconsistency." 64 | for key in self.additional_feature_keys: 65 | assert len(self.targets) == len( 66 | getattr(self, key) 67 | ), "Temporal dimension inconsistency." 
68 | 69 | def _set_snapshot_count(self): 70 | self.snapshot_count = len(self.targets) 71 | 72 | def _get_edge_index(self, time_index: int): 73 | if self.edge_indices[time_index] is None: 74 | return self.edge_indices[time_index] 75 | else: 76 | return torch.LongTensor(self.edge_indices[time_index]) 77 | 78 | def _get_batch_index(self, time_index: int): 79 | if self.batches[time_index] is None: 80 | return self.batches[time_index] 81 | else: 82 | return torch.LongTensor(self.batches[time_index]) 83 | 84 | def _get_edge_weight(self, time_index: int): 85 | if self.edge_weights[time_index] is None: 86 | return self.edge_weights[time_index] 87 | else: 88 | return torch.FloatTensor(self.edge_weights[time_index]) 89 | 90 | def _get_feature(self): 91 | if self.feature is None: 92 | return self.feature 93 | else: 94 | return torch.FloatTensor(self.feature) 95 | 96 | def _get_target(self, time_index: int): 97 | if self.targets[time_index] is None: 98 | return self.targets[time_index] 99 | else: 100 | if self.targets[time_index].dtype.kind == "i": 101 | return torch.LongTensor(self.targets[time_index]) 102 | elif self.targets[time_index].dtype.kind == "f": 103 | return torch.FloatTensor(self.targets[time_index]) 104 | 105 | def _get_additional_feature(self, time_index: int, feature_key: str): 106 | feature = getattr(self, feature_key)[time_index] 107 | if feature.dtype.kind == "i": 108 | return torch.LongTensor(feature) 109 | elif feature.dtype.kind == "f": 110 | return torch.FloatTensor(feature) 111 | 112 | def _get_additional_features(self, time_index: int): 113 | additional_features = { 114 | key: self._get_additional_feature(time_index, key) 115 | for key in self.additional_feature_keys 116 | } 117 | return additional_features 118 | 119 | def __getitem__(self, time_index: Union[int, slice]): 120 | if isinstance(time_index, slice): 121 | snapshot = DynamicGraphStaticSignalBatch( 122 | self.edge_indices[time_index], 123 | self.edge_weights[time_index], 124 | self.feature, 125 | self.targets[time_index], 126 | self.batches[time_index], 127 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys} 128 | ) 129 | else: 130 | x = self._get_feature() 131 | edge_index = self._get_edge_index(time_index) 132 | edge_weight = self._get_edge_weight(time_index) 133 | batch = self._get_batch_index(time_index) 134 | y = self._get_target(time_index) 135 | additional_features = self._get_additional_features(time_index) 136 | 137 | snapshot = Batch(x=x, edge_index=edge_index, edge_attr=edge_weight, 138 | y=y, batch=batch, **additional_features) 139 | return snapshot 140 | 141 | def __next__(self): 142 | if self.t < len(self.targets): 143 | snapshot = self[self.t] 144 | self.t = self.t + 1 145 | return snapshot 146 | else: 147 | self.t = 0 148 | raise StopIteration 149 | 150 | def __iter__(self): 151 | self.t = 0 152 | return self 153 | -------------------------------------------------------------------------------- /torch_geometric_temporal/signal/dynamic_graph_temporal_signal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Sequence, Union 4 | from torch_geometric.data import Data 5 | 6 | 7 | Edge_Indices = Sequence[Union[np.ndarray, None]] 8 | Edge_Weights = Sequence[Union[np.ndarray, None]] 9 | Node_Features = Sequence[Union[np.ndarray, None]] 10 | Targets = Sequence[Union[np.ndarray, None]] 11 | Additional_Features = Sequence[np.ndarray] 12 | 13 | 14 | class 
DynamicGraphTemporalSignal(object): 15 | r"""A data iterator object to contain a dynamic graph with a 16 | changing edge set and weights. The feature set and node labels 17 | (target) are also dynamic. The iterator returns a single discrete temporal 18 | snapshot for a time period (e.g. day or week). This single snapshot is a 19 | Pytorch Geometric Data object. Between two temporal snapshots the edges, 20 | edge weights, the feature matrix, target matrices and optionally passed attributes might change. 21 | 22 | Args: 23 | edge_indices (Sequence of Numpy arrays): Sequence of edge index tensors. 24 | edge_weights (Sequence of Numpy arrays): Sequence of edge weight tensors. 25 | features (Sequence of Numpy arrays): Sequence of node feature tensors. 26 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors. 27 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | edge_indices: Edge_Indices, 33 | edge_weights: Edge_Weights, 34 | features: Node_Features, 35 | targets: Targets, 36 | **kwargs: Additional_Features 37 | ): 38 | self.edge_indices = edge_indices 39 | self.edge_weights = edge_weights 40 | self.features = features 41 | self.targets = targets 42 | self.additional_feature_keys = [] 43 | for key, value in kwargs.items(): 44 | setattr(self, key, value) 45 | self.additional_feature_keys.append(key) 46 | self._check_temporal_consistency() 47 | self._set_snapshot_count() 48 | 49 | def _check_temporal_consistency(self): 50 | assert len(self.features) == len( 51 | self.targets 52 | ), "Temporal dimension inconsistency." 53 | assert len(self.edge_indices) == len( 54 | self.edge_weights 55 | ), "Temporal dimension inconsistency." 56 | assert len(self.features) == len( 57 | self.edge_weights 58 | ), "Temporal dimension inconsistency." 59 | for key in self.additional_feature_keys: 60 | assert len(self.targets) == len( 61 | getattr(self, key) 62 | ), "Temporal dimension inconsistency." 
63 | 64 | def _set_snapshot_count(self): 65 | self.snapshot_count = len(self.features) 66 | 67 | def _get_edge_index(self, time_index: int): 68 | if self.edge_indices[time_index] is None: 69 | return self.edge_indices[time_index] 70 | else: 71 | return torch.LongTensor(self.edge_indices[time_index]) 72 | 73 | def _get_edge_weight(self, time_index: int): 74 | if self.edge_weights[time_index] is None: 75 | return self.edge_weights[time_index] 76 | else: 77 | return torch.FloatTensor(self.edge_weights[time_index]) 78 | 79 | def _get_features(self, time_index: int): 80 | if self.features[time_index] is None: 81 | return self.features[time_index] 82 | else: 83 | return torch.FloatTensor(self.features[time_index]) 84 | 85 | def _get_target(self, time_index: int): 86 | if self.targets[time_index] is None: 87 | return self.targets[time_index] 88 | else: 89 | if self.targets[time_index].dtype.kind == "i": 90 | return torch.LongTensor(self.targets[time_index]) 91 | elif self.targets[time_index].dtype.kind == "f": 92 | return torch.FloatTensor(self.targets[time_index]) 93 | 94 | def _get_additional_feature(self, time_index: int, feature_key: str): 95 | feature = getattr(self, feature_key)[time_index] 96 | if feature.dtype.kind == "i": 97 | return torch.LongTensor(feature) 98 | elif feature.dtype.kind == "f": 99 | return torch.FloatTensor(feature) 100 | 101 | def _get_additional_features(self, time_index: int): 102 | additional_features = { 103 | key: self._get_additional_feature(time_index, key) 104 | for key in self.additional_feature_keys 105 | } 106 | return additional_features 107 | 108 | def __getitem__(self, time_index: Union[int, slice]): 109 | if isinstance(time_index, slice): 110 | snapshot = DynamicGraphTemporalSignal( 111 | self.edge_indices[time_index], 112 | self.edge_weights[time_index], 113 | self.features[time_index], 114 | self.targets[time_index], 115 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys} 116 | ) 117 | else: 118 | x = self._get_features(time_index) 119 | edge_index = self._get_edge_index(time_index) 120 | edge_weight = self._get_edge_weight(time_index) 121 | y = self._get_target(time_index) 122 | additional_features = self._get_additional_features(time_index) 123 | 124 | snapshot = Data(x=x, edge_index=edge_index, edge_attr=edge_weight, 125 | y=y, **additional_features) 126 | return snapshot 127 | 128 | def __next__(self): 129 | if self.t < len(self.features): 130 | snapshot = self[self.t] 131 | self.t = self.t + 1 132 | return snapshot 133 | else: 134 | self.t = 0 135 | raise StopIteration 136 | 137 | def __iter__(self): 138 | self.t = 0 139 | return self 140 | -------------------------------------------------------------------------------- /torch_geometric_temporal/signal/dynamic_graph_temporal_signal_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Sequence, Union 4 | from torch_geometric.data import Batch 5 | 6 | 7 | Edge_Indices = Sequence[Union[np.ndarray, None]] 8 | Edge_Weights = Sequence[Union[np.ndarray, None]] 9 | Node_Features = Sequence[Union[np.ndarray, None]] 10 | Targets = Sequence[Union[np.ndarray, None]] 11 | Batches = Sequence[Union[np.ndarray, None]] 12 | Additional_Features = Sequence[np.ndarray] 13 | 14 | 15 | class DynamicGraphTemporalSignalBatch(object): 16 | r"""A data iterator object to contain a dynamic graph with a 17 | changing edge set and weights. 
The feature set and node labels 18 | (target) are also dynamic. The iterator returns a single discrete temporal 19 | snapshot for a time period (e.g. day or week). This single snapshot is a 20 | Pytorch Geometric Batch object. Between two temporal snapshots the edges, 21 | edge weights, the feature matrix, target matrices and optionally passed 22 | attributes might change. 23 | 24 | Args: 25 | edge_indices (Sequence of Numpy arrays): Sequence of edge index tensors. 26 | edge_weights (Sequence of Numpy arrays): Sequence of edge weight tensors. 27 | features (Sequence of Numpy arrays): Sequence of node feature tensors. 28 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors. 29 | batches (Sequence of Numpy arrays): Sequence of batch index tensors. 30 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | edge_indices: Edge_Indices, 36 | edge_weights: Edge_Weights, 37 | features: Node_Features, 38 | targets: Targets, 39 | batches: Batches, 40 | **kwargs: Additional_Features 41 | ): 42 | self.edge_indices = edge_indices 43 | self.edge_weights = edge_weights 44 | self.features = features 45 | self.targets = targets 46 | self.batches = batches 47 | self.additional_feature_keys = [] 48 | for key, value in kwargs.items(): 49 | setattr(self, key, value) 50 | self.additional_feature_keys.append(key) 51 | self._check_temporal_consistency() 52 | self._set_snapshot_count() 53 | 54 | def _check_temporal_consistency(self): 55 | assert len(self.features) == len( 56 | self.targets 57 | ), "Temporal dimension inconsistency." 58 | assert len(self.edge_indices) == len( 59 | self.edge_weights 60 | ), "Temporal dimension inconsistency." 61 | assert len(self.features) == len( 62 | self.edge_weights 63 | ), "Temporal dimension inconsistency." 64 | assert len(self.features) == len( 65 | self.batches 66 | ), "Temporal dimension inconsistency." 67 | for key in self.additional_feature_keys: 68 | assert len(self.targets) == len( 69 | getattr(self, key) 70 | ), "Temporal dimension inconsistency." 
71 | 72 | def _set_snapshot_count(self): 73 | self.snapshot_count = len(self.features) 74 | 75 | def _get_edge_index(self, time_index: int): 76 | if self.edge_indices[time_index] is None: 77 | return self.edge_indices[time_index] 78 | else: 79 | return torch.LongTensor(self.edge_indices[time_index]) 80 | 81 | def _get_batch_index(self, time_index: int): 82 | if self.batches[time_index] is None: 83 | return self.batches[time_index] 84 | else: 85 | return torch.LongTensor(self.batches[time_index]) 86 | 87 | def _get_edge_weight(self, time_index: int): 88 | if self.edge_weights[time_index] is None: 89 | return self.edge_weights[time_index] 90 | else: 91 | return torch.FloatTensor(self.edge_weights[time_index]) 92 | 93 | def _get_feature(self, time_index: int): 94 | if self.features[time_index] is None: 95 | return self.features[time_index] 96 | else: 97 | return torch.FloatTensor(self.features[time_index]) 98 | 99 | def _get_target(self, time_index: int): 100 | if self.targets[time_index] is None: 101 | return self.targets[time_index] 102 | else: 103 | if self.targets[time_index].dtype.kind == "i": 104 | return torch.LongTensor(self.targets[time_index]) 105 | elif self.targets[time_index].dtype.kind == "f": 106 | return torch.FloatTensor(self.targets[time_index]) 107 | 108 | def _get_additional_feature(self, time_index: int, feature_key: str): 109 | feature = getattr(self, feature_key)[time_index] 110 | if feature.dtype.kind == "i": 111 | return torch.LongTensor(feature) 112 | elif feature.dtype.kind == "f": 113 | return torch.FloatTensor(feature) 114 | 115 | def _get_additional_features(self, time_index: int): 116 | additional_features = { 117 | key: self._get_additional_feature(time_index, key) 118 | for key in self.additional_feature_keys 119 | } 120 | return additional_features 121 | 122 | def __getitem__(self, time_index: Union[int, slice]): 123 | if isinstance(time_index, slice): 124 | snapshot = DynamicGraphTemporalSignalBatch( 125 | self.edge_indices[time_index], 126 | self.edge_weights[time_index], 127 | self.features[time_index], 128 | self.targets[time_index], 129 | self.batches[time_index], 130 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys} 131 | ) 132 | else: 133 | x = self._get_feature(time_index) 134 | edge_index = self._get_edge_index(time_index) 135 | edge_weight = self._get_edge_weight(time_index) 136 | batch = self._get_batch_index(time_index) 137 | y = self._get_target(time_index) 138 | additional_features = self._get_additional_features(time_index) 139 | 140 | snapshot = Batch(x=x, edge_index=edge_index, edge_attr=edge_weight, 141 | y=y, batch=batch, **additional_features) 142 | return snapshot 143 | 144 | def __next__(self): 145 | if self.t < len(self.features): 146 | snapshot = self[self.t] 147 | self.t = self.t + 1 148 | return snapshot 149 | else: 150 | self.t = 0 151 | raise StopIteration 152 | 153 | def __iter__(self): 154 | self.t = 0 155 | return self 156 | -------------------------------------------------------------------------------- /torch_geometric_temporal/signal/index_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import dask.array as da 4 | import numpy as np 5 | 6 | 7 | 8 | class IndexDataset(Dataset): 9 | """ 10 | A custom pytorch-compatible dataset that implements index-batching and DDP-index-batching. 11 | It also supports GPU-index-batching and lazy-index-batching. 
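    For a horizon value ``h``, index ``i`` yields ``data[i:i+h]`` as the input
    sequence and ``data[i+h:i+2*h]`` as the target sequence (see ``__getitem__``
    below); both spans therefore have length ``horizon``.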
12 | 13 | Args: 14 | indices (array-like): Indices corresponding to the time slices. 15 | data (array-like or Dask array): The dataset to be indexed. 16 | horizon (int): The prediction period for the dataset. 17 | lazy (bool, optional): Whether to use Dask lazy loading (distribute the data across all workers). Defaults to False. 18 | gpu (bool, optional): Whether the data is already resident on the GPU. Defaults to False. 19 | """ 20 | def __init__(self, indices, data, horizon, lazy=False, gpu=False): 21 | self.indices = indices 22 | self.data = data 23 | self.horizon = horizon 24 | self.lazy = lazy 25 | self.gpu = gpu 26 | 27 | def __len__(self): 28 | 29 | # Return the number of samples 30 | return self.indices.shape[0] 31 | 32 | def __getitem__(self, x): 33 | """ 34 | Retrieve a data sample and its corresponding target based on the index. 35 | 36 | Args: 37 | x (int): The index of the sample to retrieve. 38 | 39 | Returns: 40 | tuple: A tuple (x, y), where `x` is the input sequence and `y` is the target sequence. 41 | """ 42 | 43 | idx = self.indices[x] 44 | 45 | # Calculate the offset based on the horizon value 46 | y_start = idx + self.horizon 47 | 48 | # If the data is already on the GPU (likely due to using index-gpu-preprocessing), return tensor slices 49 | if self.gpu: 50 | return self.data[idx:y_start,...], self.data[y_start:y_start + self.horizon,...] 51 | 52 | else: 53 | # if utilizing DDP-batching, gather the data onto this worker and convert it to tensors 54 | if self.lazy: 55 | return torch.from_numpy(self.data[idx:y_start,...].compute()), torch.from_numpy(self.data[y_start:y_start + self.horizon,...].compute()) 56 | else: 57 | return torch.from_numpy(self.data[idx:y_start,...]), torch.from_numpy(self.data[y_start:y_start + self.horizon,...]) 58 | -------------------------------------------------------------------------------- /torch_geometric_temporal/signal/static_graph_temporal_signal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Sequence, Union 4 | from torch_geometric.data import Data 5 | 6 | 7 | Edge_Index = Union[np.ndarray, None] 8 | Edge_Weight = Union[np.ndarray, None] 9 | Node_Features = Sequence[Union[np.ndarray, None]] 10 | Targets = Sequence[Union[np.ndarray, None]] 11 | Additional_Features = Sequence[np.ndarray] 12 | 13 | 14 | class StaticGraphTemporalSignal(object): 15 | r"""A data iterator object to contain a static graph with a dynamically 16 | changing constant time difference temporal feature set (multiple signals). 17 | The node labels (target) are also temporal. The iterator returns a single 18 | constant time difference temporal snapshot for a time period (e.g. day or week). 19 | This single temporal snapshot is a Pytorch Geometric Data object. Between two 20 | temporal snapshots the features and optionally passed attributes might change. 21 | However, the underlying graph is the same. 22 | 23 | Args: 24 | edge_index (Numpy array): Index tensor of edges. 25 | edge_weight (Numpy array): Edge weight tensor. 26 | features (Sequence of Numpy arrays): Sequence of node feature tensors. 27 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors. 28 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes. 
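    Example:
        A minimal construction sketch (shapes assumed for illustration: a fixed
        3-node graph with 8 features per node, observed over 2 time steps)::

            import numpy as np

            edge_index = np.array([[0, 1, 2], [1, 2, 0]])
            edge_weight = np.ones(3)
            features = [np.random.rand(3, 8) for _ in range(2)]
            targets = [np.random.rand(3) for _ in range(2)]
            dataset = StaticGraphTemporalSignal(edge_index, edge_weight, features, targets)
            for snapshot in dataset:  # each snapshot is a PyTorch Geometric Data object
                pass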
29 | """ 30 | 31 | def __init__( 32 | self, 33 | edge_index: Edge_Index, 34 | edge_weight: Edge_Weight, 35 | features: Node_Features, 36 | targets: Targets, 37 | **kwargs: Additional_Features 38 | ): 39 | self.edge_index = edge_index 40 | self.edge_weight = edge_weight 41 | self.features = features 42 | self.targets = targets 43 | self.additional_feature_keys = [] 44 | for key, value in kwargs.items(): 45 | setattr(self, key, value) 46 | self.additional_feature_keys.append(key) 47 | self._check_temporal_consistency() 48 | self._set_snapshot_count() 49 | 50 | def _check_temporal_consistency(self): 51 | assert len(self.features) == len( 52 | self.targets 53 | ), "Temporal dimension inconsistency." 54 | for key in self.additional_feature_keys: 55 | assert len(self.targets) == len( 56 | getattr(self, key) 57 | ), "Temporal dimension inconsistency." 58 | 59 | def _set_snapshot_count(self): 60 | self.snapshot_count = len(self.features) 61 | 62 | def _get_edge_index(self): 63 | if self.edge_index is None: 64 | return self.edge_index 65 | else: 66 | return torch.LongTensor(self.edge_index) 67 | 68 | def _get_edge_weight(self): 69 | if self.edge_weight is None: 70 | return self.edge_weight 71 | else: 72 | return torch.FloatTensor(self.edge_weight) 73 | 74 | def _get_features(self, time_index: int): 75 | if self.features[time_index] is None: 76 | return self.features[time_index] 77 | else: 78 | return torch.FloatTensor(self.features[time_index]) 79 | 80 | def _get_target(self, time_index: int): 81 | if self.targets[time_index] is None: 82 | return self.targets[time_index] 83 | else: 84 | if self.targets[time_index].dtype.kind == "i": 85 | return torch.LongTensor(self.targets[time_index]) 86 | elif self.targets[time_index].dtype.kind == "f": 87 | return torch.FloatTensor(self.targets[time_index]) 88 | 89 | def _get_additional_feature(self, time_index: int, feature_key: str): 90 | feature = getattr(self, feature_key)[time_index] 91 | if feature.dtype.kind == "i": 92 | return torch.LongTensor(feature) 93 | elif feature.dtype.kind == "f": 94 | return torch.FloatTensor(feature) 95 | 96 | def _get_additional_features(self, time_index: int): 97 | additional_features = { 98 | key: self._get_additional_feature(time_index, key) 99 | for key in self.additional_feature_keys 100 | } 101 | return additional_features 102 | 103 | def __getitem__(self, time_index: Union[int, slice]): 104 | if isinstance(time_index, slice): 105 | snapshot = StaticGraphTemporalSignal( 106 | self.edge_index, 107 | self.edge_weight, 108 | self.features[time_index], 109 | self.targets[time_index], 110 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys} 111 | ) 112 | else: 113 | x = self._get_features(time_index) 114 | edge_index = self._get_edge_index() 115 | edge_weight = self._get_edge_weight() 116 | y = self._get_target(time_index) 117 | additional_features = self._get_additional_features(time_index) 118 | 119 | snapshot = Data(x=x, edge_index=edge_index, edge_attr=edge_weight, 120 | y=y, **additional_features) 121 | return snapshot 122 | 123 | def __next__(self): 124 | if self.t < len(self.features): 125 | snapshot = self[self.t] 126 | self.t = self.t + 1 127 | return snapshot 128 | else: 129 | self.t = 0 130 | raise StopIteration 131 | 132 | def __iter__(self): 133 | self.t = 0 134 | return self 135 | -------------------------------------------------------------------------------- /torch_geometric_temporal/signal/static_graph_temporal_signal_batch.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Sequence, Union 4 | from torch_geometric.data import Batch 5 | 6 | 7 | Edge_Index = Union[np.ndarray, None] 8 | Edge_Weight = Union[np.ndarray, None] 9 | Node_Features = Sequence[Union[np.ndarray, None]] 10 | Targets = Sequence[Union[np.ndarray, None]] 11 | Batches = Union[np.ndarray, None] 12 | Additional_Features = Sequence[np.ndarray] 13 | 14 | 15 | class StaticGraphTemporalSignalBatch(object): 16 | r"""A data iterator object to contain a static graph with a dynamically 17 | changing constant time difference temporal feature set (multiple signals). 18 | The node labels (target) are also temporal. The iterator returns a single 19 | constant time difference temporal snapshot for a time period (e.g. day or week). 20 | This single temporal snapshot is a Pytorch Geometric Batch object. Between two 21 | temporal snapshots the feature matrix, target matrices and optionally passed 22 | attributes might change. However, the underlying graph is the same. 23 | 24 | Args: 25 | edge_index (Numpy array): Index tensor of edges. 26 | edge_weight (Numpy array): Edge weight tensor. 27 | features (Sequence of Numpy arrays): Sequence of node feature tensors. 28 | targets (Sequence of Numpy arrays): Sequence of node label (target) tensors. 29 | batches (Numpy array): Batch index tensor. 30 | **kwargs (optional Sequence of Numpy arrays): Sequence of additional attributes. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | edge_index: Edge_Index, 36 | edge_weight: Edge_Weight, 37 | features: Node_Features, 38 | targets: Targets, 39 | batches: Batches, 40 | **kwargs: Additional_Features 41 | ): 42 | self.edge_index = edge_index 43 | self.edge_weight = edge_weight 44 | self.features = features 45 | self.targets = targets 46 | self.batches = batches 47 | self.additional_feature_keys = [] 48 | for key, value in kwargs.items(): 49 | setattr(self, key, value) 50 | self.additional_feature_keys.append(key) 51 | self._check_temporal_consistency() 52 | self._set_snapshot_count() 53 | 54 | def _check_temporal_consistency(self): 55 | assert len(self.features) == len( 56 | self.targets 57 | ), "Temporal dimension inconsistency." 58 | for key in self.additional_feature_keys: 59 | assert len(self.targets) == len( 60 | getattr(self, key) 61 | ), "Temporal dimension inconsistency." 
62 | 63 | def _set_snapshot_count(self): 64 | self.snapshot_count = len(self.features) 65 | 66 | def _get_edge_index(self): 67 | if self.edge_index is None: 68 | return self.edge_index 69 | else: 70 | return torch.LongTensor(self.edge_index) 71 | 72 | def _get_batch_index(self): 73 | if self.batches is None: 74 | return self.batches 75 | else: 76 | return torch.LongTensor(self.batches) 77 | 78 | def _get_edge_weight(self): 79 | if self.edge_weight is None: 80 | return self.edge_weight 81 | else: 82 | return torch.FloatTensor(self.edge_weight) 83 | 84 | def _get_feature(self, time_index: int): 85 | if self.features[time_index] is None: 86 | return self.features[time_index] 87 | else: 88 | return torch.FloatTensor(self.features[time_index]) 89 | 90 | def _get_target(self, time_index: int): 91 | if self.targets[time_index] is None: 92 | return self.targets[time_index] 93 | else: 94 | if self.targets[time_index].dtype.kind == "i": 95 | return torch.LongTensor(self.targets[time_index]) 96 | elif self.targets[time_index].dtype.kind == "f": 97 | return torch.FloatTensor(self.targets[time_index]) 98 | 99 | def _get_additional_feature(self, time_index: int, feature_key: str): 100 | feature = getattr(self, feature_key)[time_index] 101 | if feature.dtype.kind == "i": 102 | return torch.LongTensor(feature) 103 | elif feature.dtype.kind == "f": 104 | return torch.FloatTensor(feature) 105 | 106 | def _get_additional_features(self, time_index: int): 107 | additional_features = { 108 | key: self._get_additional_feature(time_index, key) 109 | for key in self.additional_feature_keys 110 | } 111 | return additional_features 112 | 113 | def __getitem__(self, time_index: Union[int, slice]): 114 | if isinstance(time_index, slice): 115 | snapshot = StaticGraphTemporalSignalBatch( 116 | self.edge_index, 117 | self.edge_weight, 118 | self.features[time_index], 119 | self.targets[time_index], 120 | self.batches, 121 | **{key: getattr(self, key)[time_index] for key in self.additional_feature_keys} 122 | ) 123 | else: 124 | x = self._get_feature(time_index) 125 | edge_index = self._get_edge_index() 126 | edge_weight = self._get_edge_weight() 127 | batch = self._get_batch_index() 128 | y = self._get_target(time_index) 129 | additional_features = self._get_additional_features(time_index) 130 | 131 | snapshot = Batch(x=x, edge_index=edge_index, edge_attr=edge_weight, 132 | y=y, batch=batch, **additional_features) 133 | return snapshot 134 | 135 | def __next__(self): 136 | if self.t < len(self.features): 137 | snapshot = self[self.t] 138 | self.t = self.t + 1 139 | return snapshot 140 | else: 141 | self.t = 0 142 | raise StopIteration 143 | 144 | def __iter__(self): 145 | self.t = 0 146 | return self 147 | -------------------------------------------------------------------------------- /torch_geometric_temporal/signal/train_test_split.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple 2 | 3 | from .static_graph_temporal_signal import StaticGraphTemporalSignal 4 | from .dynamic_graph_temporal_signal import DynamicGraphTemporalSignal 5 | from .dynamic_graph_static_signal import DynamicGraphStaticSignal 6 | 7 | from .static_graph_temporal_signal_batch import StaticGraphTemporalSignalBatch 8 | from .dynamic_graph_temporal_signal_batch import DynamicGraphTemporalSignalBatch 9 | from .dynamic_graph_static_signal_batch import DynamicGraphStaticSignalBatch 10 | 11 | from .static_hetero_graph_temporal_signal import StaticHeteroGraphTemporalSignal 12 | from 
.dynamic_hetero_graph_temporal_signal import DynamicHeteroGraphTemporalSignal 13 | from .dynamic_hetero_graph_static_signal import DynamicHeteroGraphStaticSignal 14 | 15 | from .static_hetero_graph_temporal_signal_batch import StaticHeteroGraphTemporalSignalBatch 16 | from .dynamic_hetero_graph_temporal_signal_batch import DynamicHeteroGraphTemporalSignalBatch 17 | from .dynamic_hetero_graph_static_signal_batch import DynamicHeteroGraphStaticSignalBatch 18 | 19 | 20 | Discrete_Signal = Union[ 21 | StaticGraphTemporalSignal, 22 | StaticGraphTemporalSignalBatch, 23 | DynamicGraphTemporalSignal, 24 | DynamicGraphTemporalSignalBatch, 25 | DynamicGraphStaticSignal, 26 | DynamicGraphStaticSignalBatch, 27 | StaticHeteroGraphTemporalSignal, 28 | StaticHeteroGraphTemporalSignalBatch, 29 | DynamicHeteroGraphTemporalSignal, 30 | DynamicHeteroGraphTemporalSignalBatch, 31 | DynamicHeteroGraphStaticSignal, 32 | DynamicHeteroGraphStaticSignalBatch, 33 | ] 34 | 35 | 36 | def temporal_signal_split( 37 | data_iterator, train_ratio: float = 0.8 38 | ) -> Tuple[Discrete_Signal, Discrete_Signal]: 39 | r"""Function to split a data iterator according to a fixed ratio. 40 | 41 | Arg types: 42 | * **data_iterator** *(Signal Iterator)* - The temporal signal iterator to split. 43 | * **train_ratio** *(float)* - Fraction of snapshots assigned to the training iterator. 44 | 45 | Return types: 46 | * **(train_iterator, test_iterator)** *(tuple of Signal Iterators)* - Train and test data iterators. 47 | """ 48 | 49 | train_snapshots = int(train_ratio * data_iterator.snapshot_count) 50 | 51 | train_iterator = data_iterator[0:train_snapshots] 52 | test_iterator = data_iterator[train_snapshots:] 53 | 54 | return train_iterator, test_iterator 55 | --------------------------------------------------------------------------------
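# Usage sketch for ``temporal_signal_split`` (illustrative; shown with the
# package's ChickenpoxDatasetLoader, though any discrete signal iterator above
# splits the same way):
#
#     from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
#     from torch_geometric_temporal.signal import temporal_signal_split
#
#     dataset = ChickenpoxDatasetLoader().get_dataset()
#     train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)
#     # the first 80% of snapshots form the train iterator, the remainder the test iterator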