├── .gitkeep
├── dgnn
│   ├── train
│   │   ├── old
│   │   │   ├── paravg_plus.py
│   │   │   ├── __init__.py
│   │   │   ├── historic.py
│   │   │   ├── paravg.py
│   │   │   ├── base.py
│   │   │   ├── serial_fullserver.py
│   │   │   ├── serial_fedavg.py
│   │   │   ├── serial_sampleserver.py
│   │   │   └── serial_paravg.py
│   │   ├── dist
│   │   │   ├── __init__.py
│   │   │   └── dgl.py
│   │   ├── serial
│   │   │   ├── __init__.py
│   │   │   ├── distgnn_stale.py
│   │   │   ├── distgnn_correction.py
│   │   │   ├── distgnn.py
│   │   │   ├── distgnn_full_correction.py
│   │   │   └── distgnn_full.py
│   │   ├── __init__.py
│   │   ├── sampling.py
│   │   ├── full.py
│   │   └── base.py
│   ├── utils
│   │   ├── cython
│   │   │   ├── __init__.py
│   │   │   ├── extension
│   │   │   │   ├── __init__.py
│   │   │   │   ├── utils
│   │   │   │   │   ├── __init__.pxd
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── array.pxd
│   │   │   │   │   └── array.pyx
│   │   │   │   ├── .gitignore
│   │   │   │   └── sparse
│   │   │   │       └── __init__.py
│   │   │   ├── makefile
│   │   │   ├── README.md
│   │   │   └── setup.py
│   │   ├── __init__.py
│   │   ├── helpers.py
│   │   ├── dist_operations.py
│   │   ├── stats.py
│   │   └── config.py
│   ├── data
│   │   ├── samplers
│   │   │   ├── __init__.py
│   │   │   ├── subgraph.py
│   │   │   ├── neighbors.py
│   │   │   └── minibatch.py
│   │   ├── partition
│   │   │   ├── metis_overhead.py
│   │   │   ├── random.py
│   │   │   ├── metis.py
│   │   │   ├── overhead.py
│   │   │   └── __init__.py
│   │   ├── __init__.py
│   │   ├── nodeblocks.py
│   │   ├── transforms.py
│   │   └── dataset.py
│   ├── layers
│   │   ├── residual.py
│   │   ├── __init__.py
│   │   ├── mlp.py
│   │   ├── appnpconv.py
│   │   ├── sage_conv.py
│   │   ├── gconv.py
│   │   ├── dist_gconv.py
│   │   ├── gatconv.py
│   │   └── dh_gconv.py
│   └── models
│       ├── __init__.py
│       ├── dist_gcn.py
│       ├── dh_gcn.py
│       ├── res_gcn.py
│       ├── appnp.py
│       ├── gat.py
│       ├── gcn.py
│       └── custom.py
├── figures
│   ├── psgd-exp-arxiv.pdf
│   ├── srv-mb-arxiv.pdf
│   ├── srv-mb-reddit.pdf
│   ├── psgd-exp-reddit.pdf
│   ├── ogb-mag240m-acc-vs-comm.pdf
│   └── ogb-products-acc-vs-comm.pdf
├── scripts
│   ├── configs
│   │   ├── mag.json
│   │   ├── papers.json
│   │   ├── arxiv3.json
│   │   ├── yelp.json
│   │   ├── arxiv2.json
│   │   ├── flickr.json
│   │   ├── flickr2.json
│   │   ├── proteins2.json
│   │   ├── cora.json
│   │   ├── reddit3.json
│   │   ├── products.json
│   │   ├── proteins.json
│   │   ├── reddit2.json
│   │   ├── arxiv.json
│   │   └── reddit.json
│   ├── partition.py
│   └── run-config.py
├── requirements.txt
└── README.md

/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dgnn/train/old/paravg_plus.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dgnn/utils/cython/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dgnn/utils/cython/extension/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dgnn/utils/cython/extension/utils/__init__.pxd:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dgnn/utils/cython/extension/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dgnn/utils/cython/extension/.gitignore:
--------------------------------------------------------------------------------
*.c
*.cpp
*.so
*.html
--------------------------------------------------------------------------------
/dgnn/utils/cython/extension/sparse/__init__.py:
--------------------------------------------------------------------------------
from .sample_neighbors import *
--------------------------------------------------------------------------------
/dgnn/utils/__init__.py:
--------------------------------------------------------------------------------
from .config import *
from .dist_operations import *
from .stats import Stats
--------------------------------------------------------------------------------
/figures/psgd-exp-arxiv.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MortezaRamezani/llcg/HEAD/figures/psgd-exp-arxiv.pdf
--------------------------------------------------------------------------------
/figures/srv-mb-arxiv.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MortezaRamezani/llcg/HEAD/figures/srv-mb-arxiv.pdf
--------------------------------------------------------------------------------
/figures/srv-mb-reddit.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MortezaRamezani/llcg/HEAD/figures/srv-mb-reddit.pdf
--------------------------------------------------------------------------------
/dgnn/data/samplers/__init__.py:
--------------------------------------------------------------------------------
from .subgraph import SubGraphSampler
from .neighbors import NeighborSampler
--------------------------------------------------------------------------------
/figures/psgd-exp-reddit.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MortezaRamezani/llcg/HEAD/figures/psgd-exp-reddit.pdf
--------------------------------------------------------------------------------
/figures/ogb-mag240m-acc-vs-comm.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MortezaRamezani/llcg/HEAD/figures/ogb-mag240m-acc-vs-comm.pdf
--------------------------------------------------------------------------------
/figures/ogb-products-acc-vs-comm.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MortezaRamezani/llcg/HEAD/figures/ogb-products-acc-vs-comm.pdf
--------------------------------------------------------------------------------
/dgnn/utils/cython/makefile:
--------------------------------------------------------------------------------

all:
	python setup.py build_ext --inplace

clean:
	rm -rf build/ *.so *.cpp
	rm -rf */*/*.cpp */*/*.c */*/*.so
--------------------------------------------------------------------------------
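Running `make` builds the extension in place (`python setup.py build_ext --inplace`), after which the compiled sampler is reachable through ordinary package imports. `sample_neighbors.pyx` itself is not included in this dump, so the snippet below only illustrates how `sparse/__init__.py` exposes it:

    # sparse/__init__.py does `from .sample_neighbors import *`, so importing
    # the subpackage pulls in the compiled Cython module.
    from dgnn.utils.cython.extension import sparse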
/dgnn/data/__init__.py:
--------------------------------------------------------------------------------

from .dataset import Dataset
from .nodeblocks import NodeBlocks
from .transforms import *
from .partition import *
from .samplers import *
--------------------------------------------------------------------------------
/dgnn/train/dist/__init__.py:
--------------------------------------------------------------------------------
from .distgnn import DistGNN
from .distgnn_correction import DistGNNCorrection
from .distgnn_full import DistGNNFull
from .dgl import DistDGL
--------------------------------------------------------------------------------
/dgnn/train/__init__.py:
--------------------------------------------------------------------------------
from .base import Base
from .full import Full
from .sampling import Sampling

from .serial import *
from .dist import *
from .old import *
--------------------------------------------------------------------------------
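Judging by the trainer names and the config keys below, the exported trainers cover parameter averaging (the `paravg`/`fedavg` variants under `old/`), local training with a server-side correction step (`distgnn_correction`), and full-graph training (`distgnn_full`). As a generic illustration of the parameter-averaging step such trainers build on — not the repo's actual code — each worker trains locally and the models are then averaged with an all-reduce, assuming `torch.distributed` is already initialized:

    import torch.distributed as dist

    def average_parameters(model, world_size):
        # Sum each parameter tensor across all workers, then divide by the
        # number of workers to obtain the element-wise average.
        for param in model.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data /= world_size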
/scripts/configs/mag.json:
--------------------------------------------------------------------------------
// Mag, 40?
{
    "dataset": "mag",
    "loss": "bceloss",
    "num_layers": 3,
    "hidden_size": 256,
    "num_epochs": 500,
    "layer_norm": false
}
--------------------------------------------------------------------------------
/dgnn/train/serial/__init__.py:
--------------------------------------------------------------------------------
from .distgnn import DistGNN
from .distgnn_full import DistGNNFull
from .distgnn_correction import DistGNNCorr
from .distgnn_full_correction import DistGNNFullCorr
from .distgnn_stale import DistGNNStale
--------------------------------------------------------------------------------
/dgnn/utils/cython/README.md:
--------------------------------------------------------------------------------
# TODO

Later, make a separate package out of this code and change the cimport statements accordingly, from

    from ..utils import array

to

    from cysparse.utils import array

and in the other places as well.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.7.0
torch-geometric==1.7.0  # from source
torch-cluster==1.5.8    # from source
torch-scatter==2.0.5    # from source
torch-sparse==0.6.8     # from source, compiled with METIS enabled
ogb==1.3.1
tqdm==4.54.1
commentjson==0.9.0
--------------------------------------------------------------------------------
/dgnn/utils/helpers.py:
--------------------------------------------------------------------------------
import torch

def rank2dev(rank, num_gpus):
    # Map a worker rank to a device, round-robin over the available GPUs;
    # fall back to CPU when no GPUs are present.
    if num_gpus == 0:
        device = torch.device('cpu')
    else:
        dev_id = rank % num_gpus
        device = torch.device('cuda:{}'.format(dev_id))
    return device
--------------------------------------------------------------------------------
/scripts/configs/papers.json:
--------------------------------------------------------------------------------
// Papers100M
{
    "dataset": "papers100M",
    "num_layers": 3,
    "hidden_size": 256,
    "dropout": 0.5,
    "layer_norm": false,
    "cpu_val": true,
    "lr": 1e-2,

    "num_epochs": 500,
    "val_patience": 50,

    // Local Sampler
    "num_samplers": 1,
    "sampler": "neighbor",
    "local_updates": 8,
    "minibatch_size": 8192,
    "num_neighbors": [10, 10, 10]
}
--------------------------------------------------------------------------------
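The configs under scripts/configs/ use `//`-style comments, which the stdlib `json` module rejects; `commentjson` (pinned in requirements.txt) parses them. A minimal loading sketch follows — the actual entry point, scripts/run-config.py, is not included in this dump, so this is illustrative rather than the repo's own loading code:

    import commentjson

    # commentjson mirrors the json module's API but tolerates // comments.
    with open('scripts/configs/papers.json') as f:
        config = commentjson.load(f)

    print(config['dataset'], config['minibatch_size'], config['local_updates'])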
/dgnn/train/old/__init__.py:
--------------------------------------------------------------------------------
from .base import Base
from .distributed import Distributed
# from .historic import Historic
# from .paravg import ParamsAvg
from .serial_paravg import SerializedParamsAvg
# from .serial_gravg import SerializedGrAvg
# from .paravg_plus import ParamsAvgPlus
# from .serial_fullserver import SerializedFullServer
# from .serial_sampleserver import SerializedSampleServer
from .serial_fedavg import SerializedFedAvg
# from .serial_fedavg_correct import SerializedFedAvgCorrection
--------------------------------------------------------------------------------
/dgnn/layers/residual.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

class ResidualLayer(nn.Module):

    def __init__(self,
                 *args,
                 **kwargs,
                 ):
        super().__init__()

        self.residual = None

        if 'layer_id' in kwargs:
            self.layer_id = kwargs['layer_id']

    def forward(self, h, *args):
        # Identity pass-through that stashes its input so a later layer can
        # add it back as a skip connection.
        self.residual = h
        return h

    def __repr__(self):
        return self.__class__.__name__ + f"[{self.layer_id}]"
--------------------------------------------------------------------------------
/dgnn/layers/__init__.py:
--------------------------------------------------------------------------------
# from .gconv import GConv
from .dist_gconv import DistGConv
from .dh_gconv import DHGConv
from .gconv import GConv
from .mlp import MLPLayer
from .sage_conv import SAGEConv
from .residual import ResidualLayer
from .gatconv import GATConv
from .appnpconv import APPNPConv

def layer_selector(layer_str):
    # Selecting the layer
    if layer_str == 'gconv':
        layer = GConv
    elif layer_str == 'gatconv':
        layer = GATConv
    elif layer_str == 'appnpconv':
        layer = APPNPConv
    elif layer_str == 'mlp':
        layer = MLPLayer
    elif layer_str == 'sageconv':
        layer = SAGEConv
    else:
        # raise, not return: returning the class would silently hand the
        # caller the exception type instead of signaling an error
        raise NotImplementedError(layer_str)
    return layer
--------------------------------------------------------------------------------
/dgnn/models/__init__.py:
--------------------------------------------------------------------------------
from .gcn import GCN
from .res_gcn import ResGCN
from .custom import Custom
from .dist_gcn import DistGCN
from .dh_gcn import DHGCN
from .gat import GAT
from .appnp import APPNP

def model_selector(model_str):
    # Selecting Model
    if model_str == 'gcn':
        model = GCN
    elif model_str == 'appnp':
        model = APPNP
    elif model_str == 'gat':
        model = GAT
    elif model_str == 'resgcn':
        model = ResGCN
    elif model_str == 'custom':
        model = Custom
    elif model_str == 'distgcn':
        model = DistGCN
    elif model_str == 'dhgcn':
        model = DHGCN
    else:
        # raise, not return: returning the class would silently hand the
        # caller the exception type instead of signaling an error
        raise NotImplementedError(model_str)

    return model
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Distributed Graph Neural Network Learning

## Install Metis

Download and extract METIS:

    http://glaros.dtc.umn.edu/gkhome/metis/metis/download
    gunzip metis-5.x.y.tar.gz
    tar -xvf metis-5.x.y.tar

If you don't have root access, compile METIS with a `prefix`:

    make config prefix=

## Build PyTorch Sparse

Build PyTorch Sparse with METIS support:

    cd pytorch_sparse
    export CPATH=/export/local/mfr5226/lib/metis/include/:/usr/local/cuda/include/:$CPATH
    export LD_LIBRARY_PATH=/export/local/mfr5226/lib/metis/lib/:$LD_LIBRARY_PATH
    WITH_METIS=1 LDFLAGS='-L/export/local/mfr5226/lib/metis/lib/' python setup.py build -j 10
    python setup.py install

If METIS is installed in a standard location, just compile with:

    WITH_METIS=1 python setup.py install -j 10
--------------------------------------------------------------------------------
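A quick way to confirm the resulting torch-sparse build really has METIS enabled is to partition a toy graph. This sketch assumes the `SparseTensor.partition` API of torch-sparse 0.6.x, which fails at runtime when the library was compiled without METIS:

    import torch
    from torch_sparse import SparseTensor

    # Tiny 3-node path graph; partitioning it exercises the METIS binding.
    row = torch.tensor([0, 1, 1, 2])
    col = torch.tensor([1, 0, 2, 1])
    adj = SparseTensor(row=row, col=col, sparse_sizes=(3, 3))
    adj.partition(num_parts=2, recursive=False)  # raises if METIS is missing
    print('torch-sparse METIS support: OK')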
/scripts/configs/arxiv3.json:
--------------------------------------------------------------------------------
{
    "dataset": "arxiv",
    "num_layers": 3,
    "hidden_size": 256,
    "layer_norm": true,
    "dropout": 0.5,
    "val_patience": 1500,
    "lr": 1e-3,

    // Local Sampler
    "num_samplers": 4,
    "sampler": "neighbor",
    "num_neighbors": [10, 10, 10],
    // "num_neighbors": [1, 1, 1],
    // "num_neighbors": [-1, -1, -1],

    "num_epochs": 100,
    "local_updates": 64,
    "minibatch_size": 256,
    "rho": 1.0,
    "inc_k": false,

    // for Correction
    "server_sampler": "neighbor",
    "server_minibatch_size": 2048,
    "server_num_neighbors": [10, 10, 10],
    // "server_num_neighbors": [-1, -1, -1],
    "server_minibatch": "random",
    "server_updates": 1,
    "server_lr": 1e-2,
    "server_start_epoch": 0
}
--------------------------------------------------------------------------------
/dgnn/layers/mlp.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

class MLPLayer(nn.Module):

    def __init__(self,
                 input_dim,
                 output_dim,
                 *args,
                 **kwargs,
                 ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim, bias=True)

        if 'layer_id' in kwargs:
            self.layer_id = kwargs['layer_id']

    def forward(self, adj, h, *args):
        # adj is ignored; MLPLayer keeps the (adj, h) signature so it stays
        # interchangeable with the graph-convolution layers.
        h = self.linear(h)
        return h

    def __repr__(self):
        return self.__class__.__name__ + "[{}] ({}->{})".format(
            self.layer_id,
            self.input_dim,
            self.output_dim)
--------------------------------------------------------------------------------
2 | { 3 | "dataset": "yelp", 4 | "loss": "bceloss", 5 | "model": "custom", 6 | "arch": "ssl", 7 | "num_layers": 2, 8 | "hidden_size": 512, 9 | "layer_norm": true, 10 | "input_norm": true, 11 | "dropout": 0.1, 12 | "lr": 1e-2, 13 | "val_patience": 500, 14 | 15 | "num_samplers": 4, 16 | "sampler": "neighbor", 17 | "num_neighbors": [10,10], 18 | 19 | // "num_epochs": 100, 20 | // "val_step": 1, 21 | 22 | // Local Sampler 23 | "minibatch_size": 1024, 24 | "num_epochs": 100, 25 | "local_updates": 16, 26 | 27 | // Server Correction 28 | "server_sampler": "neighbor", 29 | "server_minibatch_size": 4096, 30 | "server_num_neighbors": [10,10], 31 | // "server_minibatch": "stratified", 32 | "server_minibatch": "random", 33 | "server_updates": 1, 34 | } -------------------------------------------------------------------------------- /scripts/configs/arxiv2.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "arxiv", 3 | "num_layers": 3, 4 | "hidden_size": 256, 5 | "layer_norm": true, 6 | "dropout": 0.5, 7 | "val_patience": 1500, 8 | "lr": 1e-3, 9 | 10 | // Local Sampler 11 | "num_samplers": 4, 12 | "sampler": "neighbor", 13 | "num_neighbors": [10, 10, 10], 14 | 15 | "num_epochs": 100, 16 | "local_updates": 64, 17 | "minibatch_size": 256, 18 | "rho": 1.0, 19 | "inc_k": false, 20 | 21 | // for Correction 22 | "server_sampler": "neighbor", 23 | "server_minibatch_size": 2048, 24 | "server_num_neighbors": [10,10,10], 25 | "server_minibatch": "random", 26 | "server_updates": 1, 27 | "server_lr": 1e-2, 28 | "server_start_epoch": 0, 29 | 30 | // "model": "gat", 31 | // "layer": "gatconv", 32 | // "model": "appnp", 33 | // "layer": "appnpconv", 34 | // "num_epochs": 5, 35 | } 36 | -------------------------------------------------------------------------------- /dgnn/utils/cython/extension/utils/array.pxd: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | # distutils: extra_compile_args = -std=c++11 4 | # distutils: extra_link_args = 5 | 6 | from libcpp.vector cimport vector 7 | from cython cimport Py_buffer 8 | 9 | import numpy as np 10 | cimport numpy as np 11 | 12 | cdef void npy2vec_int(np.ndarray[int, ndim=1, mode='c'] nda, vector[int] & vec) 13 | cdef void npy2vec_long(np.ndarray[long, ndim=1, mode='c'] nda, vector[long] & vec) 14 | # cdef void npy2vec_float(np.ndarray[int, ndim=1, mode='c'] nda, vector[int] & vec) 15 | # cdef void npy2vec_double(np.ndarray[int, ndim=1, mode='c'] nda, vector[int] & vec) 16 | 17 | 18 | # https://stackoverflow.com/questions/45133276/passing-c-vector-to-numpy-through-cython-without-copying-and-taking-care-of-me 19 | cdef class ArrayWrapperInt: 20 | cdef vector[int] vec 21 | cdef Py_ssize_t shape[1] 22 | cdef Py_ssize_t strides[1] 23 | cdef void set_data(self, vector[int] & data) 24 | -------------------------------------------------------------------------------- /scripts/configs/flickr.json: -------------------------------------------------------------------------------- 1 | // Flickr, 52% 2 | { 3 | "dataset": "flickr", 4 | "num_layers": 2, 5 | "model": "custom", 6 | "arch": "ssl", 7 | "hidden_size": 256, 8 | "dropout": 0.2, 9 | "input_norm": true, 10 | "layer_norm": true, 11 | 12 | "val_patience": 500, 13 | 14 | 15 | // Local Sampler 16 | "num_sampler": 4, 17 | "sampler": "neighbor", 18 | "num_neighbors": [10,10], 19 | 20 | "lr": 1e-3, 21 | // "minibatch_size": 128, 22 | // "local_updates": 1, 23 | 24 | "num_epochs": 50, 25 
| "minibatch_size": 64, 26 | "local_updates": 8, 27 | "rho": 1.0, 28 | "inc_k": true, 29 | 30 | // Server Correction 31 | "server_sampler": "neighbor", 32 | "server_num_neighbors": [10, 10], 33 | "server_updates": 1, 34 | "server_minibatch_size": 512, 35 | "server_minibatch": "random", 36 | "server_start_epoch": 0, 37 | "server_lr": 5e-3, 38 | 39 | 40 | // "model": "gat", 41 | // "layer": "gatconv", 42 | // "model": "appnp", 43 | // "layer": "appnpconv", 44 | "num_epochs": 5, 45 | } -------------------------------------------------------------------------------- /scripts/configs/flickr2.json: -------------------------------------------------------------------------------- 1 | // Flickr, 52% 2 | { 3 | "dataset": "flickr", 4 | "num_layers": 2, 5 | "model": "custom", 6 | "arch": "ssl", 7 | "hidden_size": 256, 8 | "dropout": 0.2, 9 | "input_norm": true, 10 | "layer_norm": true, 11 | 12 | "val_patience": 500, 13 | 14 | 15 | // Local Sampler 16 | "num_sampler": 4, 17 | "sampler": "neighbor", 18 | "num_neighbors": [10,10], 19 | 20 | "lr": 1e-3, 21 | // "minibatch_size": 128, 22 | // "local_updates": 1, 23 | 24 | "num_epochs": 50, 25 | "minibatch_size": 64, 26 | "local_updates": 8, 27 | "rho": 1.0, 28 | "inc_k": true, 29 | 30 | // Server Correction 31 | "server_sampler": "neighbor", 32 | "server_num_neighbors": [10, 10], 33 | "server_updates": 1, 34 | "server_minibatch_size": 512, 35 | "server_minibatch": "random", 36 | "server_start_epoch": 0, 37 | "server_lr": 5e-3, 38 | 39 | 40 | // "model": "gat", 41 | // "layer": "gatconv", 42 | // "model": "appnp", 43 | // "layer": "appnpconv", 44 | // "num_epochs": 5, 45 | } -------------------------------------------------------------------------------- /scripts/partition.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import sys 5 | import argparse 6 | 7 | sys.path.append('..') 8 | sys.path.append('../dgnn/utils/cython/') 9 | from dgnn import data 10 | import dgnn.data.partition as P 11 | 12 | if os.environ['LOGNAME'] == 'mfr5226': 13 | os.environ['GNN_DATASET_DIR'] = '/export/local/mfr5226/datasets/pyg_dist/' 14 | else: 15 | os.environ['GNN_DATASET_DIR'] = '/home/weilin/Downloads/GCN_datasets/' 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser(description='') 20 | 21 | parser.add_argument('--dataset', type=str, default='cora') 22 | parser.add_argument('--num-parts', type=int, default=4) 23 | parser.add_argument('--mode', type=str, default='random') 24 | parser.add_argument('--overhead', type=int, default=10) 25 | 26 | 27 | config = parser.parse_args() 28 | 29 | dataset = data.Dataset(config.dataset) 30 | if config.mode == 'random': 31 | P.random(dataset, num_parts=config.num_parts) 32 | elif config.mode == 'metis': 33 | P.metis(dataset, num_parts=config.num_parts) 34 | elif config.mode == 'overhead': 35 | P.overhead(dataset, num_parts=config.num_parts, overhead=config.overhead) -------------------------------------------------------------------------------- /dgnn/models/dist_gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..layers import DistGConv 4 | 5 | 6 | class DistGCN(torch.nn.Module): 7 | 8 | def __init__(self, 9 | input_size, 10 | hidden_size, 11 | output_size, 12 | num_layers, 13 | activation, 14 | ): 15 | super().__init__() 16 | 17 | self.num_layers = num_layers 18 | self.layers = torch.nn.ModuleList() 19 | self.activation = activation 20 | 21 | 
self.layers.append(DistGConv(input_size, hidden_size)) 22 | for _ in range(1, num_layers-1): 23 | self.layers.append(DistGConv(hidden_size, hidden_size)) 24 | self.layers.append(DistGConv(hidden_size, output_size)) 25 | 26 | self.rank = 0 27 | 28 | def update_rank(self, rank): 29 | 30 | self.rank = rank 31 | for layer in self.layers: 32 | layer.rank = rank 33 | 34 | def forward(self, x, adj, *args, **kwargs): 35 | 36 | h = x 37 | for i, layer in enumerate(self.layers): 38 | h = layer(h, adj) 39 | if i < self.num_layers - 1: 40 | h = self.activation(h) 41 | 42 | return h -------------------------------------------------------------------------------- /scripts/configs/proteins2.json: -------------------------------------------------------------------------------- 1 | // # Protein: 2 | // 74% gcn, 3, 256 3 | // 77% sage, 3, 256 4 | // 79.41% SAGE 5 | { 6 | "dataset": "proteins", 7 | "loss": "bceloss", 8 | "layer": "sageconv", 9 | "num_layers": 3, 10 | "hidden_size": 256, 11 | "input_norm": false, 12 | "layer_norm": false, 13 | 14 | // "num_epochs": 1000, 15 | // "val_step": 5, 16 | 17 | "val_patience": 500, 18 | 19 | // Local Sampler 20 | "num_samplers": 4, 21 | "sampler": "neighbor", 22 | "num_neighbors": [10, 10, 10], 23 | "lr": 1e-3, 24 | 25 | // "num_epochs": 200, 26 | // "local_updates": 1, 27 | 28 | "num_epochs": 100, 29 | "local_updates": 32, 30 | "minibatch_size": 512, 31 | "rho": 1.00, 32 | "inc_k": false, 33 | 34 | // Server Correction 35 | "server_sampler": "neighbor", 36 | "server_minibatch": "random", 37 | "server_num_neighbors": [10, 10, 10], 38 | "server_updates": 1, 39 | "server_minibatch_size": 4096, 40 | "server_start_epoch": 0, 41 | "server_lr": 2e-2, 42 | 43 | // "model": "gat", 44 | // "layer": "gatconv", 45 | "model": "appnp", 46 | "layer": "appnpconv", 47 | // "num_epochs": 5, 48 | 49 | 50 | } 51 | -------------------------------------------------------------------------------- /scripts/configs/cora.json: -------------------------------------------------------------------------------- 1 | // Cora, 87% 2 | { 3 | "dataset": "cora", 4 | "num_layers": 2, 5 | "hidden_size": 64, 6 | "val_patience": 100, 7 | "lr": 1e-3, 8 | 9 | // "model": "gat", 10 | // "layer": "gatconv", 11 | 12 | "model": "appnp", 13 | "layer": "appnpconv", 14 | 15 | // Local Sampler 16 | "num_samplers": 4, 17 | "sampler": "neighbor", 18 | "num_neighbors": [10,10], 19 | 20 | "num_epochs": 100, 21 | "minibatch_size": 32, 22 | "local_updates": 16, 23 | 24 | // // for Correction 25 | // "server_sampler": "neighbor", 26 | // "server_updates": 1, 27 | // "server_minibatch_size": 256, 28 | // "server_num_neighbors": [10,10], 29 | // "server_minibatch": "random", 30 | // // "server_minibatch": "stratified", 31 | // "server_start_epoch": 0, 32 | // "server_lr": 5e-2, 33 | } 34 | 35 | // { 36 | // "dataset": "cora", 37 | // "num_layers": 2, 38 | // "hidden_size": 64, 39 | // "num_epochs": 30, 40 | // "val_patience": 100, 41 | // "lr": 1e-2, 42 | 43 | // // Local Sampler 44 | // "sampler": "neighbor", 45 | // "minibatch_size": 32, 46 | // "num_neighbors": [10,10], 47 | // "local_updates": 1, 48 | // } -------------------------------------------------------------------------------- /dgnn/layers/appnpconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch_sparse import spmm 6 | import math 7 | 8 | 9 | class APPNPConv(nn.Module): 10 | """[summary] 11 | 12 | Arguments: 13 | nn {[type]} -- 
[description] 14 | """ 15 | 16 | def __init__(self, 17 | input_dim, 18 | output_dim, 19 | *args, 20 | **kwargs, 21 | ): 22 | super().__init__() 23 | 24 | self.input_dim = input_dim 25 | self.output_dim = output_dim 26 | self.linear = nn.Linear(input_dim, output_dim, bias=True) 27 | 28 | # use to distinguish from other layers 29 | self.graph_layer = True 30 | 31 | self.layer_id = kwargs['layer_id'] if 'layer_id' in kwargs else '0' 32 | self.alpha = kwargs['alpha'] if 'alpha' in kwargs else 0.2 33 | 34 | def forward(self, adj, h, h0, *args): 35 | output = (1 - self.alpha) * adj.spmm(h) + self.alpha * h0[:adj.size(0)] 36 | return output 37 | 38 | def __repr__(self): 39 | return self.__class__.__name__ + "[{}] ({}->{})".format( 40 | self.layer_id, 41 | self.input_dim, 42 | self.output_dim) -------------------------------------------------------------------------------- /scripts/configs/reddit3.json: -------------------------------------------------------------------------------- 1 | // Reddit, 96.44% @ Epoch #198, w/ Full 2 | // 96.6 w/ sss and layer norm and sampling 3 | // "arch": "slsl", 4 | // "num_layers": 2, 5 | { 6 | "dataset": "reddit", 7 | "model": "custom", 8 | "arch": "sss", 9 | "num_layers": 3, 10 | "hidden_size": 256, 11 | "layer_norm": true, 12 | "dropout": 0.5, 13 | 14 | // "lr": 1e-3, 15 | "val_patience": 500, 16 | 17 | // Local Sampler 18 | "num_samplers": 8, 19 | "sampler": "neighbor", 20 | // "num_neighbors": [10, 10, 10], 21 | // "num_neighbors": [-1, -1, -1], 22 | "num_neighbors": [1, 1, 1], 23 | 24 | // "num_epochs": 500, 25 | // "local_updates": 64, 26 | 27 | // "local_updates": 256, 28 | // "minibatch_size": 32, 29 | 30 | "num_epochs": 75, 31 | "local_updates": 64, 32 | "minibatch_size": 256, 33 | "rho": 1.1, 34 | 35 | // Server Correction 36 | "server_minibatch": "random", 37 | // "server_sampler": "subgraph", 38 | // "server_minibatch_size": 2048, 39 | // "server_num_neighbors": [-1, -1, -1], 40 | "server_sampler": "neighbor", 41 | "server_minibatch_size": 2048, 42 | "server_num_neighbors": [10, 10, 10], 43 | "server_updates": 2, 44 | "server_start_epoch": 0, 45 | "server_lr": 1e-2, 46 | } -------------------------------------------------------------------------------- /dgnn/models/dh_gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..layers import DistGConv, DHGConv 4 | 5 | 6 | class DHGCN(torch.nn.Module): 7 | 8 | def __init__(self, 9 | input_size, 10 | hidden_size, 11 | output_size, 12 | num_layers, 13 | activation, 14 | ): 15 | super().__init__() 16 | 17 | self.num_layers = num_layers 18 | self.layers = torch.nn.ModuleList() 19 | self.activation = activation 20 | 21 | self.layers.append(DHGConv(input_size, hidden_size, 0, 0, num_layers)) 22 | for lid in range(1, num_layers-1): 23 | self.layers.append(DHGConv(hidden_size, hidden_size, 0, lid, num_layers)) 24 | self.layers.append(DHGConv(hidden_size, output_size, 0, num_layers-1, num_layers)) 25 | 26 | self.rank = 0 27 | 28 | def update_rank(self, rank): 29 | 30 | self.rank = rank 31 | for layer in self.layers: 32 | layer.rank = rank 33 | 34 | def forward(self, x, adj, use_hist=False, *args, **kwargs): 35 | 36 | h = x 37 | for i, layer in enumerate(self.layers): 38 | h = layer(h, adj, use_hist) 39 | if i < self.num_layers - 1: 40 | h = self.activation(h) 41 | 42 | return h -------------------------------------------------------------------------------- /dgnn/utils/cython/extension/utils/array.pyx: 
-------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | # distutils: extra_compile_args = -std=c++11 4 | # distutils: extra_link_args = 5 | 6 | 7 | cdef void npy2vec_int(np.ndarray[int, ndim=1, mode='c'] nda, vector[int] & vec): 8 | cdef int size = nda.size 9 | vec.assign(& (nda[0]), & (nda[0]) + size) 10 | 11 | cdef void npy2vec_long(np.ndarray[long, ndim=1, mode='c'] nda, vector[long] & vec): 12 | cdef int size = nda.size 13 | vec.assign(& (nda[0]), & (nda[0]) + size) 14 | 15 | cdef class ArrayWrapperInt: 16 | cdef void set_data(self, vector[int] & data): 17 | self.vec.swap(data) 18 | 19 | def __getbuffer__(self, Py_buffer * buffer, int flags): 20 | cdef Py_ssize_t itemsize = sizeof(self.vec[0]) 21 | self.shape[0] = self.vec.size() 22 | self.strides[0] = sizeof(int) 23 | buffer.buf = & (self.vec[0]) 24 | buffer.format = 'i' 25 | buffer.internal = NULL 26 | buffer.itemsize = itemsize 27 | buffer.len = self.vec.size() * itemsize 28 | buffer.ndim = 1 29 | buffer.obj = self 30 | buffer.readonly = 0 31 | buffer.shape = self.shape 32 | buffer.strides = self.strides 33 | buffer.suboffsets = NULL 34 | 35 | def __releasebuffer__(self, Py_buffer * buffer): 36 | pass 37 | -------------------------------------------------------------------------------- /dgnn/layers/sage_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch_sparse import spmm 6 | import math 7 | 8 | class SAGEConv(nn.Module): 9 | """[summary] 10 | 11 | Arguments: 12 | nn {[type]} -- [description] 13 | """ 14 | 15 | def __init__(self, 16 | input_dim, 17 | output_dim, 18 | *args, 19 | **kwargs, 20 | ): 21 | super().__init__() 22 | 23 | self.input_dim = input_dim 24 | self.output_dim = output_dim 25 | self.linear_self = nn.Linear(input_dim, output_dim, bias=True) 26 | self.linear_neighbors = nn.Linear(input_dim, output_dim, bias=True) 27 | 28 | # use to distinguish from other layers 29 | self.graph_layer = True 30 | 31 | if 'layer_id' in kwargs: 32 | self.layer_id = kwargs['layer_id'] 33 | 34 | def forward(self, adj, h, *args): 35 | 36 | out_nodes = adj.size(0) 37 | self_h = self.linear_self(h[:out_nodes]) 38 | neighbor_h = adj.spmm(h) 39 | h = self_h + self.linear_neighbors(neighbor_h) 40 | return h 41 | 42 | def __repr__(self): 43 | return self.__class__.__name__ + "[{}] ({}->{})".format( 44 | self.layer_id, 45 | self.input_dim, 46 | self.output_dim) -------------------------------------------------------------------------------- /dgnn/data/partition/metis_overhead.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import math 5 | import numpy as np 6 | 7 | def metis_overhead(dataset, num_parts, overhead=10): 8 | 9 | data = dataset[0] 10 | 11 | dist_dir = os.path.join(dataset.processed_dir, '../partitioned/') 12 | metis_meta_path = os.path.join(dist_dir, f'metis-{num_parts}/perm.pt') 13 | 14 | metis_perm, metis_partptr = torch.load(metis_meta_path) 15 | 16 | full_adj = data.adj_t
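# Each METIS partition p owns metis_perm[partptr[p]:partptr[p+1]] (the layout saved by metis.py).
# The loop below mirrors overhead.py: it appends up to overhead% extra "halo" nodes taken from the
# partition's outside neighbors before slicing the adjacency, features, labels, and masks, and
# saves them in the adj_<p>.pt / fela_<p>.pt format shared with random.py and overhead.py.
# The 'metis-overhead-{overhead}-{num_parts}' directory name is a provisional choice;
# config.partitioned_dir defines no naming scheme for this mode.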
17 | partitioned_dir = os.path.join(dist_dir, f'metis-overhead-{overhead}-{num_parts}') 18 | if not os.path.exists(partitioned_dir): 19 | os.mkdir(partitioned_dir) 20 | 21 | start = metis_partptr[0] 22 | part_cnt = 0 23 | for end in metis_partptr[1:]: 24 | 25 | part_idx = metis_perm[start:end] 26 | tmp_adj = full_adj[part_idx] 27 | tmp_row = torch.unique(tmp_adj.storage.row()) 28 | tmp_col = torch.unique(tmp_adj.storage.col()) 29 | tmp_diff = tmp_col[~tmp_col.unsqueeze(1).eq(tmp_row).any(1)] 30 | num_overhead = int(overhead * part_idx.size(0) / 100) 31 | 32 | overhead_nodes = tmp_diff[:num_overhead] 33 | new_part_idx = torch.cat((part_idx, overhead_nodes)) 34 | 35 | part_adj = full_adj[new_part_idx, new_part_idx] 36 | part_feats = data.x[new_part_idx] 37 | part_labels = data.y[new_part_idx] 38 | part_train_mask = data.train_mask[new_part_idx] 39 | part_val_mask = data.val_mask[new_part_idx] 40 | part_test_mask = data.test_mask[new_part_idx] 41 | 42 | torch.save(part_adj, partitioned_dir+'/adj_{}.pt'.format(part_cnt)) 43 | torch.save((part_feats, part_labels, part_train_mask, part_val_mask, part_test_mask), 44 | partitioned_dir+'/fela_{}.pt'.format(part_cnt)) 45 | 46 | start = end 47 | part_cnt += 1 48 | -------------------------------------------------------------------------------- /dgnn/layers/gconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch_sparse import spmm 6 | import math 7 | 8 | 9 | def glorot(tensor): 10 | if tensor is not None: 11 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 12 | tensor.data.uniform_(-stdv, stdv) 13 | 14 | class GConv(nn.Module): 15 | """[summary] 16 | 17 | Arguments: 18 | nn {[type]} -- [description] 19 | """ 20 | 21 | def __init__(self, 22 | input_dim, 23 | output_dim, 24 | *args, 25 | **kwargs, 26 | ): 27 | super().__init__() 28 | 29 | self.input_dim = input_dim 30 | self.output_dim = output_dim 31 | self.linear = nn.Linear(input_dim, output_dim, bias=True) 32 | 33 | # use to distinguish from other layers 34 | self.graph_layer = True 35 | 36 | if 'layer_id' in kwargs: 37 | self.layer_id = kwargs['layer_id'] 38 | 39 | # self.linear.bias.data.fill_(0) 40 | # glorot(self.linear.weight) 41 | 42 | def forward(self, adj, h, *args): 43 | 44 | h = adj.spmm(h) 45 | h = self.linear(h) 46 | return h 47 | 48 | def __repr__(self): 49 | return self.__class__.__name__ + "[{}] ({}->{})".format( 50 | self.layer_id, 51 | self.input_dim, 52 | self.output_dim) -------------------------------------------------------------------------------- /scripts/configs/products.json: -------------------------------------------------------------------------------- 1 | // Products, 75.22 @ 496 full 2 | // 76.87 sampling 1k, k=8 and e=1000 3 | 4 | // { 5 | // "dataset": "products", 6 | // "num_layers": 3, 7 | // "hidden_size": 128, 8 | // "layer_norm": false, 9 | // "dropout": 0.5, 10 | // "lr": 1e-3, 11 | // "num_epochs": 1000, 12 | // "val_patience": 500, 13 | 14 | // // Local Sampler 15 | // "sampler": "neighbor", 16 | // // "minibatch_size": 4096, 17 | // "minibatch_size": 1024, 18 | // "num_neighbors": [10,10,10], 19 | // "local_updates": 8, 20 | // } 21 | 22 | 23 | { 24 | "dataset": "products", 25 | "num_layers": 3, 26 | 27 | "model": "custom", 28 | "arch": "ggg", 29 | 30 | "hidden_size": 128, 31 | "layer_norm": false, 32 | "dropout": 0.5, 33 | "lr": 1e-3, 34 | "val_patience": 500, 35 | 36 | "num_samplers": 4, 37 | "sampler": "neighbor", 38 | "num_neighbors": [10,10,10], 39 | 40 | // Local Sampler 41 | "num_epochs": 50, 42 | "minibatch_size": 1024, 43 | "local_updates": 16, 44 | 45 | // Server Correction 46 | "server_sampler": "neighbor", 47 | // "server_minibatch_size": 4096, 48 | "server_minibatch_size": 8192, 49 | "server_num_neighbors": [10,10,10], 50 | // "server_minibatch": "stratified", 51 | "server_minibatch": "random", 52 | "server_updates": 1, 53 | "server_start_epoch": 0, 54 | "server_lr": 1e-3, 55 | 56 | } -------------------------------------------------------------------------------- /scripts/configs/proteins.json: -------------------------------------------------------------------------------- 1 | // # Protein: 2 | // 74% gcn, 3, 256 3 | // 77% sage, 3, 256 4 | // 79.41%
SAGE 5 | { 6 | "dataset": "proteins", 7 | "loss": "bceloss", 8 | "layer": "sageconv", 9 | "num_layers": 3, 10 | "hidden_size": 256, 11 | "input_norm": false, 12 | "layer_norm": false, 13 | 14 | // "num_epochs": 1000, 15 | // "val_step": 5, 16 | 17 | "val_patience": 500, 18 | 19 | // Local Sampler 20 | "num_samplers": 4, 21 | "sampler": "neighbor", 22 | "num_neighbors": [10, 10, 10], 23 | "lr": 1e-3, 24 | 25 | // "num_epochs": 200, 26 | // "local_updates": 1, 27 | 28 | "num_epochs": 100, 29 | "local_updates": 32, 30 | "minibatch_size": 512, 31 | "rho": 1.00, 32 | "inc_k": false, 33 | 34 | // Server Correction 35 | "server_sampler": "neighbor", 36 | "server_minibatch": "random", 37 | "server_num_neighbors": [10, 10, 10], 38 | "server_updates": 1, 39 | "server_minibatch_size": 4096, 40 | "server_start_epoch": 0, 41 | "server_lr": 2e-2, 42 | 43 | 44 | 45 | } 46 | 47 | // // Local Sampler 48 | // "sampler": "neighbor", 49 | // "minibatch_size": 4096, 50 | // "num_neighbors": [10,10,10], 51 | // "local_updates": 8, 52 | 53 | // // for Correction 54 | // "server_sampler": "neighbor", 55 | // "server_minibatch_size": 4096, 56 | // "server_num_neighbors": [10,10,10], 57 | // "server_minibatch": "random", 58 | // "server_updates": 1, 59 | // "server_start_epoch": 0, 60 | // "server_lr": 5e-2, 61 | -------------------------------------------------------------------------------- /dgnn/train/serial/distgnn_stale.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | from . import DistGNNCorr 7 | from ...data import samplers 8 | 9 | 10 | class DistGNNStale(DistGNNCorr): 11 | 12 | def __init__(self, config, dataset): 13 | super().__init__(config, dataset) 14 | 15 | self.staled_model = None 16 | 17 | 18 | def server_average(self): 19 | self.staled_model = copy.deepcopy(self.model) 20 | DistGNNCorr.server_average(self) 21 | 22 | def server_correction(self): 23 | 24 | old_params = self.staled_model.state_dict() 25 | 26 | self.staled_model.train() 27 | 28 | for input_nid, nodeblocks, output_nid in self.server_trainloader: 29 | 30 | nodeblocks.to(self.device) 31 | features = self.full_features[input_nid] 32 | labels = self.full_labels[output_nid] 33 | train_mask = self.full_train_mask[output_nid] 34 | 35 | self.optimizer.zero_grad() 36 | output = self.staled_model(features, nodeblocks) 37 | loss = self.loss_fnc(output[train_mask], labels[train_mask]) 38 | loss.backward() 39 | self.optimizer.step() 40 | 41 | updated_params = self.staled_model.state_dict() 42 | new_params = self.model.state_dict() 43 | 44 | for params in new_params: 45 | new_params[params] = new_params[params] + (updated_params[params] - old_params[params]) 46 | 47 | self.model.load_state_dict(new_params) -------------------------------------------------------------------------------- /scripts/configs/reddit2.json: -------------------------------------------------------------------------------- 1 | // Reddit, 96.44% @ Epoch #198, w/ Full 2 | // 96.6 w/ sss and layer norm and sampling 3 | // "arch": "slsl", 4 | // "num_layers": 2, 5 | { 6 | "dataset": "reddit", 7 | "model": "custom", 8 | "arch": "sss", 9 | "num_layers": 3, 10 | "hidden_size": 256, 11 | "layer_norm": true, 12 | "dropout": 0.5, 13 | 14 | // "lr": 1e-3, 15 | "val_patience": 500, 16 | 17 | // "num_epochs": 1000, 18 | // "minibatch_size": 32, 19 | // "local_updates": 1, 20 | 21 | // correction alone test 22 | // "num_epochs": 50, 23 | // "minibatch_size": 256, 24 | // 
"local_updates": 1, 25 | // "lr": 1e-2, 26 | 27 | // Local Sampler 28 | "num_samplers": 4, 29 | "sampler": "neighbor", 30 | "num_neighbors": [10, 10, 10], 31 | 32 | // "num_epochs": 500, 33 | // "local_updates": 64, 34 | 35 | // "local_updates": 256, 36 | // "minibatch_size": 32, 37 | 38 | "num_epochs": 75, 39 | "local_updates": 64, 40 | "minibatch_size": 256, 41 | "rho": 1.1, 42 | 43 | // Server Correction 44 | "server_sampler": "neighbor", 45 | "server_minibatch": "random", 46 | "server_num_neighbors": [10, 10, 10], 47 | "server_updates": 1, 48 | // "server_minibatch_size": 256, 49 | "server_minibatch_size": 2048, 50 | "server_start_epoch": 0, 51 | "server_lr": 1e-2, 52 | 53 | "model": "gat", 54 | "layer": "gatconv", 55 | // "model": "appnp", 56 | // "layer": "appnpconv", 57 | "num_epochs": 5, 58 | } 59 | -------------------------------------------------------------------------------- /dgnn/utils/dist_operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | from torch_sparse.matmul import matmul 5 | from torch_sparse import spmm 6 | 7 | 8 | def dist_sum(grad_weight): 9 | dist.all_reduce(grad_weight, op=dist.ReduceOp.SUM) 10 | # print('gw', grad_weight.numel() * grad_weight.element_size()) 11 | return grad_weight 12 | 13 | 14 | def dist_spmm(adjs, inputs, rank, world_size): 15 | device = adjs[0].device() 16 | 17 | t_buf = torch.FloatTensor(adjs[0].size( 18 | 0), inputs.size(1)).fill_(0).to(device) 19 | 20 | # input_buf = torch.FloatTensor(adjs[0].size(0), inputs.size(1)).fill_(0).to(device) 21 | 22 | for i in range(world_size): 23 | if i == rank: 24 | input_buf = inputs.clone() 25 | else: 26 | # other parts maybe differnt sizes, otherwise broadcast freezes 27 | input_buf = torch.FloatTensor(adjs[i].size( 28 | 1), inputs.size(1)).fill_(0).to(device) 29 | 30 | dist.broadcast(input_buf, src=i) 31 | # print('ib', input_buf.numel()*input_buf.element_size() ) 32 | 33 | # buggy with CPU and GLOO 34 | t_buf += adjs[i].spmm(input_buf) 35 | # t_buf += matmul(adjs[i], input_buf) 36 | 37 | # new_adj = adjs[i].to_torch_sparse_coo_tensor() 38 | # t_buf += torch.spmm(new_adj, input_buf) 39 | 40 | # new_adj = adjs[i].to_torch_sparse_coo_tensor().coalesce() 41 | # t_buf += spmm(new_adj.indices(), new_adj.values(), 42 | # new_adj.size(0), new_adj.size(1), input_buf) 43 | 44 | # Slow 45 | # index = torch.stack([adjs[i].storage.row(), adjs[i].storage.col()], dim=0) 46 | # value = adjs[i].storage.value() 47 | # t_buf += spmm(index, value, adjs[i].size(0), adjs[i].size(1), input_buf) 48 | 49 | return t_buf 50 | -------------------------------------------------------------------------------- /dgnn/data/partition/random.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import math 5 | import numpy as np 6 | 7 | import torch_geometric.transforms as T 8 | 9 | def random(dataset, num_parts): 10 | 11 | data = dataset[0] 12 | dist_dir = os.path.join(dataset.processed_dir, '../partitioned/') 13 | 14 | partitioned_dir = os.path.join(dist_dir, 'random-{}'.format(num_parts)) 15 | print(partitioned_dir) 16 | if not os.path.exists(partitioned_dir): 17 | os.mkdir(partitioned_dir) 18 | 19 | train_idx = data.train_mask.nonzero(as_tuple=True)[0] 20 | val_idx = data.val_mask.nonzero(as_tuple=True)[0] 21 | test_idx = data.test_mask.nonzero(as_tuple=True)[0] 22 | 23 | train_npp = math.ceil(train_idx.shape[0] / num_parts) 24 | val_npp = 
math.ceil(val_idx.shape[0] / num_parts) 25 | test_npp = math.ceil(test_idx.shape[0] / num_parts) 26 | 27 | train_parts = train_idx.split(train_npp) 28 | val_parts = val_idx.split(val_npp) 29 | test_parts = test_idx.split(test_npp) 30 | 31 | part_meta = [] 32 | 33 | for i in range(num_parts): 34 | 35 | part_idx = torch.cat((train_parts[i], val_parts[i], test_parts[i])) 36 | part_meta.append(part_idx) 37 | 38 | # Only saving the diagonal! Fix later! 39 | part_adj = data.adj_t[part_idx, part_idx] 40 | part_feats = data.x[part_idx] 41 | part_labels = data.y[part_idx] 42 | part_train_mask = data.train_mask[part_idx] 43 | part_val_mask = data.val_mask[part_idx] 44 | part_test_mask = data.test_mask[part_idx] 45 | 46 | torch.save(part_adj, partitioned_dir+'/adj_{}.pt'.format(i)) 47 | torch.save((part_feats, part_labels, part_train_mask, part_val_mask, part_test_mask), 48 | partitioned_dir+'/fela_{}.pt'.format(i)) 49 | 50 | torch.save(part_meta, partitioned_dir+'/perm.pt') -------------------------------------------------------------------------------- /dgnn/models/res_gcn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | import torch.nn.functional as f 6 | 7 | from ..layers import GConv 8 | from ..data import NodeBlocks 9 | 10 | from . import GCN 11 | 12 | 13 | class ResGCN(GCN): 14 | """ 15 | Residual GCN model 16 | """ 17 | 18 | def __init__(self, 19 | features_dim, 20 | hidden_dim, 21 | num_classes, 22 | num_layers, 23 | activation, 24 | layer=GConv, 25 | dropout=0, 26 | input_norm=False, 27 | layer_norm=False, 28 | *args, 29 | **kwargs): 30 | 31 | super().__init__(features_dim, 32 | hidden_dim, 33 | num_classes, 34 | num_layers, 35 | activation, 36 | layer, 37 | dropout, 38 | input_norm, 39 | layer_norm, 40 | *args, 41 | **kwargs) 42 | 43 | def forward(self, x, adjs): 44 | 45 | h = x 46 | adj = adjs 47 | gcl_cnt = 0 48 | h_res = None 49 | 50 | for i, layer in enumerate(self.layers): 51 | 52 | if hasattr(layer, 'graph_layer'): 53 | 54 | if gcl_cnt > 0: # TODO: don't clone if it's last gconv layer! 
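# h_res snapshots the activations flowing into every graph conv after the first; once at least
# two convs have run, the snapshot is added back right after the activation layer that follows a
# conv (the isinstance check below), i.e. a skip connection around each intermediate conv block.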
55 | h_res = h.clone() 56 | 57 | if isinstance(adjs, NodeBlocks): 58 | adj = adjs[gcl_cnt] 59 | 60 | h = layer(adj, h) 61 | gcl_cnt += 1 62 | 63 | else: 64 | h = layer(h) 65 | if i < len(self.layers) - 1 and gcl_cnt > 1 and isinstance(layer, type(self.activation)): 66 | # print('ResAdded') 67 | h = h + h_res 68 | 69 | return h 70 | -------------------------------------------------------------------------------- /scripts/configs/arxiv.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "arxiv", 3 | "num_layers": 3, 4 | "hidden_size": 256, 5 | "layer_norm": true, 6 | "dropout": 0.5, 7 | "val_patience": 1500, 8 | "lr": 1e-3, 9 | 10 | // Local Sampler 11 | "num_samplers": 4, 12 | "sampler": "neighbor", 13 | "num_neighbors": [10, 10, 10], 14 | 15 | "num_epochs": 100, 16 | "local_updates": 64, 17 | "minibatch_size": 256, 18 | "rho": 1.0, 19 | "inc_k": false, 20 | 21 | // for Correction 22 | "server_sampler": "neighbor", 23 | "server_minibatch_size": 2048, 24 | "server_num_neighbors": [10,10,10], 25 | "server_minibatch": "random", 26 | "server_updates": 1, 27 | "server_lr": 1e-2, 28 | "server_start_epoch": 0, 29 | 30 | } 31 | 32 | // // / Local Sampler 33 | // "sampler": "neighbor", 34 | // "num_neighbors": [-1,-1,-1], 35 | // // "minibatch_size": 512, 36 | // // "local_updates": 16, 37 | // "num_epochs": 5000, 38 | // "minibatch_size": 64, 39 | // "local_updates": 1, 40 | 41 | // Arxiv, 71% 42 | // { 43 | // "dataset": "arxiv", 44 | // "num_layers": 3, 45 | // "hidden_size": 256, 46 | // "num_epochs": 300, 47 | // "val_patience": 100, 48 | 49 | // // Local Sampler 50 | // "sampler": "neighbor", 51 | // "minibatch_size": 512, 52 | // "num_neighbors": [10,10,10], 53 | // "local_updates": 8, 54 | 55 | // "layer_norm": true, 56 | // "dropout": 0.5, 57 | // "lr": 1e-3 58 | // } 59 | 60 | // // Arxiv, 71% 61 | // { 62 | // "dataset": "arxiv", 63 | // "num_layers": 3, 64 | // "hidden_size": 256, 65 | // "num_epochs": 1000, 66 | // "val_patience": 100, 67 | 68 | // // Local Sampler 69 | // "sampler": "neighbor", 70 | // "minibatch_size": 256, 71 | // "num_neighbors": [10,10,10], 72 | // "local_updates": 64, 73 | // // "local_updates": 1, 74 | 75 | // "layer_norm": true, 76 | // "dropout": 0.5, 77 | // "lr": 1e-3 78 | // } -------------------------------------------------------------------------------- /dgnn/models/appnp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torch.nn.functional as f 5 | 6 | from ..layers import APPNPConv 7 | from ..data import NodeBlocks 8 | 9 | class APPNP(nn.Module): 10 | """ 11 | APPNP model with simple APPNPConv layers at all layers 12 | """ 13 | 14 | def __init__(self, 15 | features_dim, 16 | hidden_dim, 17 | num_classes, 18 | num_layers, 19 | activation, 20 | layer=APPNPConv, 21 | dropout=0, 22 | input_norm=False, 23 | layer_norm=False, 24 | *args, 25 | **kwargs): 26 | 27 | super().__init__() 28 | 29 | self.num_layers = num_layers 30 | self.layers = nn.ModuleList() 31 | 32 | self.layer_type = layer 33 | 34 | self.dropout = nn.Dropout(p=dropout) 35 | self.activation = activation 36 | self.batch_norm = nn.BatchNorm1d(features_dim, affine=False) 37 | self.input_fc = nn.Linear(features_dim, hidden_dim) 38 | self.output_fc = nn.Linear(hidden_dim, num_classes) 39 | 40 | for i in range(num_layers): 41 | self.layers.append(layer(hidden_dim, hidden_dim,)) 42 | 43 | 44 | 45 | def forward(self, x, adjs): 46 | 47 | h = x 48 | adj = 
adjs 49 | 50 | h = self.activation(self.input_fc(self.batch_norm(h))) 51 | h_0 = h.clone() 52 | 53 | gcn_cnt = 0 54 | 55 | for i, layer in enumerate(self.layers): 56 | 57 | if type(layer) == self.layer_type: 58 | 59 | if type(adjs) == NodeBlocks: 60 | adj = adjs[gcn_cnt] 61 | gcn_cnt += 1 62 | 63 | h = layer(adj, h, h_0) 64 | 65 | else: 66 | h = layer(h) 67 | 68 | h = self.dropout(h) 69 | h = self.output_fc(h) 70 | 71 | return h -------------------------------------------------------------------------------- /scripts/configs/reddit.json: -------------------------------------------------------------------------------- 1 | // Reddit, 96.44% @ Epoch #198, w/ Full 2 | // 96.6 w/ sss and layer norm and sampling 3 | // "arch": "slsl", 4 | // "num_layers": 2, 5 | { 6 | "dataset": "reddit", 7 | "model": "custom", 8 | "arch": "sss", 9 | "num_layers": 3, 10 | "hidden_size": 256, 11 | "layer_norm": true, 12 | "dropout": 0.5, 13 | 14 | // "lr": 1e-3, 15 | "val_patience": 500, 16 | 17 | // "num_epochs": 1000, 18 | // "minibatch_size": 32, 19 | // "local_updates": 1, 20 | 21 | // correction alone test 22 | // "num_epochs": 50, 23 | // "minibatch_size": 256, 24 | // "local_updates": 1, 25 | // "lr": 1e-2, 26 | 27 | // Local Sampler 28 | "num_samplers": 4, 29 | "sampler": "neighbor", 30 | "num_neighbors": [10, 10, 10], 31 | 32 | // "num_epochs": 500, 33 | // "local_updates": 64, 34 | 35 | // "local_updates": 256, 36 | // "minibatch_size": 32, 37 | 38 | "num_epochs": 75, 39 | "local_updates": 64, 40 | "minibatch_size": 256, 41 | "rho": 1.1, 42 | 43 | // Server Correction 44 | "server_sampler": "neighbor", 45 | "server_minibatch": "random", 46 | "server_num_neighbors": [10, 10, 10], 47 | "server_updates": 1, 48 | // "server_minibatch_size": 256, 49 | "server_minibatch_size": 2048, 50 | "server_start_epoch": 0, 51 | "server_lr": 1e-2, 52 | } 53 | 54 | // "server_minibatch": "stratified", 55 | 56 | // { 57 | // "dataset": "reddit", 58 | // "model": "custom", 59 | // "arch": "sss", 60 | // "num_layers": 3, 61 | // "hidden_size": 256, 62 | // "layer_norm": true, 63 | 64 | // "dropout": 0.5, 65 | // "lr": 1e-2, 66 | // "num_epochs": 300, 67 | // "val_patience": 100, 68 | 69 | // // Local Sampler 70 | // "sampler": "neighbor", 71 | // // "minibatch_size": 256, 72 | // "minibatch_size": 2048, 73 | // // "local_updates": 8, 74 | // "local_updates": 1, 75 | // "num_neighbors": [10, 10, 10], 76 | // } -------------------------------------------------------------------------------- /dgnn/layers/dist_gconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import torch_sparse 4 | 5 | from ..utils import dist_sum, dist_spmm 6 | 7 | class distgconv(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, inputs, weight, adjs, rank, world_size): 10 | 11 | ctx.save_for_backward(inputs, weight) 12 | ctx.adjs = adjs 13 | ctx.rank = rank 14 | ctx.world_size = world_size 15 | 16 | # T = AH 17 | T = dist_spmm(adjs, inputs, rank, world_size) 18 | # Z = TW 19 | Z = torch.mm(T, weight) 20 | 21 | return Z 22 | 23 | @staticmethod 24 | def backward(ctx, grad_output): 25 | 26 | inputs, weight = ctx.saved_tensors 27 | adjs = ctx.adjs 28 | rank = ctx.rank 29 | world_size = ctx.world_size 30 | 31 | ag = dist_spmm(adjs, grad_output, rank , world_size) 32 | 33 | grad_input = torch.mm(ag, weight.t()) 34 | 35 | # grad_weight = dist_sum(ag, inputs.t(), rank, world_size) 36 | 37 | grad_weight = torch.mm(inputs.t(), ag) 38 | 39 | grad_weight = 
dist_sum(grad_weight) 40 | 41 | return grad_input, grad_weight, None, None, None 42 | 43 | 44 | class DistGConv(torch.nn.Module): 45 | 46 | def __init__(self, 47 | input_dim, 48 | output_dim, 49 | rank=0, 50 | ): 51 | super().__init__() 52 | 53 | self.input_dim = input_dim 54 | self.output_dim = output_dim 55 | self.rank = rank 56 | 57 | self.fn = distgconv.apply 58 | self.weight = torch.nn.Parameter(torch.rand(input_dim, output_dim)) 59 | 60 | 61 | def forward(self, x, adj): 62 | world_size = dist.get_world_size() 63 | x = self.fn(x, self.weight, adj, self.rank, world_size) 64 | return x 65 | 66 | def __repr__(self): 67 | return self.__class__.__name__ + "[{}] ({}->{})".format( 68 | self.rank, 69 | self.input_dim, 70 | self.output_dim) 71 | -------------------------------------------------------------------------------- /dgnn/data/samplers/subgraph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 5 | 6 | from ..nodeblocks import NodeBlocks 7 | from .minibatch import RandomBatchSampler, StratifiedMiniBatch 8 | 9 | class SubGraphSampler(torch.utils.data.DataLoader): 10 | 11 | def __init__(self, 12 | adj, 13 | batch_size, 14 | shuffle=False, 15 | num_batches=1, 16 | num_layers=1, 17 | minibatch='random', 18 | **kwargs): 19 | 20 | self.data = copy.deepcopy(adj) 21 | self.num_layers = num_layers 22 | self.partition_meta = '' 23 | if 'part_meta' in kwargs: 24 | self.partition_meta = kwargs['part_meta'] 25 | kwargs.pop('part_meta') 26 | 27 | if minibatch == 'random': 28 | batch_sampler = RandomBatchSampler(adj.size(0), batch_size, shuffle, num_batches) 29 | elif minibatch == 'stratified': 30 | batch_sampler = StratifiedMiniBatch(adj.size(0), batch_size, shuffle, num_batches, self.partition_meta) 31 | 32 | super().__init__( 33 | self, 34 | batch_size=1, 35 | sampler=batch_sampler, 36 | collate_fn=self.__collate__, 37 | **kwargs, 38 | ) 39 | 40 | def __getitem__(self, idx): 41 | # Gets the next minibatch (from (minibatch) sampler) 42 | return idx 43 | 44 | def __collate__(self, batch_idx): 45 | # This function is executed in parallel and creates/modifies the graphs...
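# saint_subgraph extracts the subgraph induced on batch_nodes, gcn_norm re-normalizes its
# adjacency (values reset via set_value(None)), and NodeBlocks(self.num_layers, batch_adj)
# reuses that single adjacency for every layer, so each minibatch trains on a closed subgraph;
# contrast with NeighborSampler in neighbors.py, which samples a fresh bipartite block per layer.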
46 | # batch_nodes, _ = torch.sort(batch_idx[0]) 47 | batch_nodes = batch_idx[0] 48 | 49 | batch_adj, _ = self.data.saint_subgraph(batch_nodes) 50 | # batch_adj = batch_adj.to_symmetric() 51 | batch_adj = gcn_norm(batch_adj.set_value(None)) 52 | 53 | batch_nb = NodeBlocks(self.num_layers, batch_adj) 54 | batch_nb.set_output_nid(batch_nodes) 55 | 56 | # return input_nid, nodeblocks and output_nid 57 | return batch_nodes, batch_nb, batch_nodes -------------------------------------------------------------------------------- /dgnn/data/partition/metis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import math 5 | import numpy as np 6 | 7 | def metis(dataset, num_parts): 8 | 9 | data = dataset[0] 10 | 11 | # Normalize the Adjacency 12 | # data = T.GCNNorm()(data) 13 | 14 | dist_dir = os.path.join(dataset.processed_dir, '../partitioned/') 15 | 16 | # save clustering info for later 17 | cluster_data = data.adj_t.partition(num_parts=num_parts) 18 | 19 | adj = cluster_data[0] 20 | partptr = cluster_data[1] 21 | perm = cluster_data[2] 22 | 23 | # if not os.path.exists(dist_dir): 24 | # os.mkdir(dist_dir) 25 | # part_info = os.path.join(dist_dir, 'partition_{}.pt'.format(num_parts)) 26 | # print(part_info) 27 | # # FIXME: I changed it, fix it later 28 | 29 | partitioned_dir = os.path.join(dist_dir, 'metis-{}'.format(num_parts)) 30 | print(partitioned_dir) 31 | if not os.path.exists(partitioned_dir): 32 | os.mkdir(partitioned_dir) 33 | 34 | torch.save((perm, partptr), partitioned_dir+'/perm.pt') 35 | 36 | start = partptr[0] 37 | part_cnt = 0 38 | for end in partptr[1:]: 39 | 40 | # Adj Partitioning 41 | part_adj = [] 42 | adj_pbn = adj.narrow(0, start, end-start) 43 | i = partptr[0] 44 | for j in partptr[1:]: 45 | adj_pbp = adj_pbn.narrow(1, i, j-i) 46 | part_adj.append(adj_pbp) 47 | i = j 48 | torch.save(part_adj, partitioned_dir+'/adj_{}.pt'.format(part_cnt)) 49 | 50 | # Features and Labels Partitioning 51 | part_feats = data.x[perm[start:end]] 52 | part_labels = data.y[perm[start:end]] 53 | part_train_mask = data.train_mask[perm[start:end]] 54 | part_val_mask = data.val_mask[perm[start:end]] 55 | part_test_mask = data.test_mask[perm[start:end]] 56 | torch.save((part_feats, part_labels, part_train_mask, part_val_mask, part_test_mask), 57 | partitioned_dir+'/fela_{}.pt'.format(part_cnt)) 58 | 59 | start = end 60 | part_cnt += 1 61 | 62 | # TODO: save meta info for the dataset to avoid loading everything -------------------------------------------------------------------------------- /dgnn/layers/gatconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch_sparse import spmm, SparseTensor 6 | from torch_geometric.utils import softmax 7 | import math 8 | 9 | 10 | class GATConv(nn.Module): 11 | """[summary] 12 | 13 | Arguments: 14 | nn {[type]} -- [description] 15 | """ 16 | 17 | def __init__(self, 18 | input_dim, 19 | output_dim, 20 | *args, 21 | **kwargs, 22 | ): 23 | super().__init__() 24 | 25 | self.input_dim = input_dim 26 | self.output_dim = output_dim 27 | self.linear = nn.Linear(input_dim, output_dim, bias=True) 28 | 29 | # use to distinguish from other layers 30 | self.graph_layer = True 31 | 32 | self.layer_id = kwargs['layer_id'] if 'layer_id' in kwargs else '0' 33 | self.num_heads = kwargs['num_heads'] if 'num_heads' in kwargs else 1 34 | 35 | self.attn_src = nn.ModuleList() 
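# Per head, the score of edge (row, col) is attn_src[head](h)[row] + attn_dst[head](h)[col],
# the additive GAT attention a^T [Wh_i || Wh_j] split into two scalar projections, followed by
# a LeakyReLU and a softmax grouped by row so each output node's incoming weights sum to one
# (see forward below); head outputs are concatenated along the feature dimension.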
36 | self.attn_dst = nn.ModuleList() 37 | 38 | for _ in range(self.num_heads): 39 | self.attn_src.append(nn.Linear(self.output_dim, 1, bias=False)) 40 | self.attn_dst.append(nn.Linear(self.output_dim, 1, bias=False)) 41 | 42 | self.negative_slope = kwargs['negative_slope'] if 'negative_slope' in kwargs else 0.2 43 | 44 | 45 | def forward(self, adj, h, *args): 46 | 47 | h = self.linear(h) 48 | h_list = [] 49 | 50 | row = adj.storage.row() 51 | col = adj.storage.col() 52 | sparse_sizes = adj.sizes() 53 | 54 | for head in range(self.num_heads): 55 | attn = self.attn_src[head](h)[row] + self.attn_dst[head](h)[col] 56 | attn = F.leaky_relu(attn, negative_slope=self.negative_slope) 57 | attn = softmax(attn, row).flatten() 58 | attn = SparseTensor(row=row, col=col, value=attn, sparse_sizes=sparse_sizes) 59 | h_list.append(attn.spmm(h)) 60 | 61 | h = torch.cat(h_list, dim=1) 62 | return h 63 | 64 | def __repr__(self): 65 | return self.__class__.__name__ + "[{}] ({}->{})".format( 66 | self.layer_id, 67 | self.input_dim, 68 | self.output_dim) -------------------------------------------------------------------------------- /dgnn/utils/cython/setup.py: -------------------------------------------------------------------------------- 1 | # REF: https://github.com/cython/cython/wiki/PackageHierarchy 2 | 3 | import os 4 | import sys 5 | import numpy 6 | 7 | from distutils.core import setup 8 | from distutils.extension import Extension 9 | from distutils.sysconfig import customize_compiler 10 | 11 | from Cython.Distutils import build_ext 12 | # from distutils.command.build_ext import build_ext 13 | # from Cython.Build import cythonize 14 | 15 | 16 | # scan the directory for extension files, converting 17 | # them to extension names in dotted notation 18 | def scandir(dir, files=[]): 19 | for file in os.listdir(dir): 20 | path = os.path.join(dir, file) 21 | if os.path.isfile(path) and path.endswith(".pyx"): 22 | files.append(path.replace(os.path.sep, ".")[:-4]) 23 | elif os.path.isdir(path): 24 | scandir(path, files) 25 | return files 26 | 27 | 28 | # generate an Extension object from its dotted name 29 | def make_extension(ext_name): 30 | extPath = ext_name.replace(".", os.path.sep)+".pyx" 31 | return Extension( 32 | ext_name, 33 | [extPath], 34 | language='c++', 35 | include_dirs=[numpy.get_include(), "."], # adding the '.' to include_dirs is CRUCIAL!! 
36 | extra_compile_args=['-O3', '-Wall', '-fopenmp'], 37 | extra_link_args=['-g', '-fopenmp'], 38 | # libraries = ['',], 39 | define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")], 40 | ) 41 | 42 | # remove annoying warning 43 | 44 | 45 | class my_build_ext(build_ext): 46 | def build_extensions(self): 47 | customize_compiler(self.compiler) 48 | try: 49 | self.compiler.compiler_so.remove("-Wstrict-prototypes") 50 | except (AttributeError, ValueError): 51 | pass 52 | build_ext.build_extensions(self) 53 | 54 | 55 | # get the list of extensions 56 | ext_names = scandir('extension') 57 | 58 | # and build up the set of Extension objects 59 | extensions = [make_extension(name) for name in ext_names] 60 | 61 | # finally, we can pass all this to distutils 62 | setup( 63 | name="extension", 64 | packages=['utils', 'sparse'], 65 | cmdclass={'build_ext': my_build_ext}, 66 | ext_modules=extensions, 67 | # ext_modules=cythonize(extensions), 68 | ) 69 | -------------------------------------------------------------------------------- /dgnn/data/partition/overhead.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import copy 4 | import torch 5 | import math 6 | import numpy as np 7 | 8 | import torch_geometric.transforms as T 9 | 10 | def overhead(dataset, num_parts, overhead=10): 11 | 12 | data = dataset[0] 13 | dist_dir = os.path.join(dataset.processed_dir, '../partitioned/') 14 | 15 | partitioned_dir = os.path.join(dist_dir, 'overhead-{}-{}'.format(overhead, num_parts)) 16 | print(partitioned_dir) 17 | if not os.path.exists(partitioned_dir): 18 | os.mkdir(partitioned_dir) 19 | 20 | train_idx = data.train_mask.nonzero(as_tuple=True)[0] 21 | val_idx = data.val_mask.nonzero(as_tuple=True)[0] 22 | test_idx = data.test_mask.nonzero(as_tuple=True)[0] 23 | 24 | train_npp = math.ceil(train_idx.shape[0] / num_parts) 25 | val_npp = math.ceil(val_idx.shape[0] / num_parts) 26 | test_npp = math.ceil(test_idx.shape[0] / num_parts) 27 | 28 | train_parts = train_idx.split(train_npp) 29 | val_parts = val_idx.split(val_npp) 30 | test_parts = test_idx.split(test_npp) 31 | 32 | for i in range(num_parts): 33 | 34 | part_idx = torch.cat((train_parts[i], val_parts[i], test_parts[i])) 35 | 36 | # add neighbors of part_idx to the part_idx up to overhead% 37 | 38 | tmp_adj = data.adj_t[part_idx] 39 | tmp_row = torch.unique(tmp_adj.storage.row()) 40 | tmp_col = torch.unique(tmp_adj.storage.col()) 41 | tmp_diff = tmp_col[~tmp_col.unsqueeze(1).eq(tmp_row).any(1)] 42 | num_overhead = int(overhead * part_idx.size(0) /100) 43 | 44 | # TODO: random or more walk? 45 | #! only pre-overhead training nodes? 
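# tmp_diff is meant to collect the partition's outside neighbors: endpoints reachable from
# part_idx rows that are not already inside the partition. Note that adj_t[part_idx] re-indexes
# rows to 0..len(part_idx)-1 while columns keep their global ids, so comparing tmp_col against
# part_idx itself may be the more direct membership test. The cap is overhead% of the partition
# size, e.g. a 10,000-node partition with overhead=10 admits at most 1,000 halo nodes.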
46 | overhead_nodes = tmp_diff[:num_overhead] 47 | part_idx = torch.cat((part_idx, overhead_nodes)) 48 | 49 | part_adj = data.adj_t[part_idx, part_idx] 50 | part_feats = data.x[part_idx] 51 | part_labels = data.y[part_idx] 52 | part_train_mask = data.train_mask[part_idx] 53 | part_val_mask = data.val_mask[part_idx] 54 | part_test_mask = data.test_mask[part_idx] 55 | 56 | torch.save(part_adj, partitioned_dir+'/adj_{}.pt'.format(i)) 57 | torch.save((part_feats, part_labels, part_train_mask, part_val_mask, part_test_mask), 58 | partitioned_dir+'/fela_{}.pt'.format(i)) -------------------------------------------------------------------------------- /dgnn/models/gat.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | import torch.nn.functional as f 6 | 7 | from ..layers import GATConv 8 | from ..data import NodeBlocks 9 | 10 | class GAT(nn.Module): 11 | """ 12 | GAT model with simple GATConv layers at all layers 13 | """ 14 | 15 | def __init__(self, 16 | features_dim, 17 | hidden_dim, 18 | num_classes, 19 | num_layers, 20 | activation, 21 | layer=GATConv, 22 | dropout=0, 23 | input_norm=False, 24 | layer_norm=False, 25 | *args, 26 | **kwargs): 27 | 28 | super().__init__() 29 | 30 | self.num_layers = num_layers 31 | self.layers = nn.ModuleList() 32 | 33 | dropout = nn.Dropout(p=dropout) 34 | self.activation = activation 35 | 36 | self.layer_type = layer 37 | 38 | 39 | if input_norm: 40 | self.layers.append(nn.BatchNorm1d(features_dim, affine=False)) 41 | 42 | 43 | self.layers.append(layer(features_dim, hidden_dim, layer_id=1)) 44 | if layer_norm: 45 | self.layers.append(torch.nn.BatchNorm1d(hidden_dim)) 46 | self.layers.append(activation) 47 | self.layers.append(dropout) 48 | 49 | for i in range(1, num_layers-1): 50 | self.layers.append(layer(hidden_dim, hidden_dim, layer_id=i+1)) 51 | if layer_norm: 52 | self.layers.append(torch.nn.BatchNorm1d(hidden_dim)) 53 | self.layers.append(activation) 54 | self.layers.append(dropout) 55 | 56 | self.layers.append( 57 | layer(hidden_dim, num_classes, layer_id=num_layers)) 58 | 59 | 60 | def forward(self, x, adjs): 61 | 62 | h = x 63 | adj = adjs 64 | gcn_cnt = 0 65 | 66 | for i, layer in enumerate(self.layers): 67 | 68 | if type(layer) == self.layer_type: 69 | 70 | if type(adjs) == NodeBlocks: 71 | adj = adjs[gcn_cnt] 72 | gcn_cnt += 1 73 | 74 | h = layer(adj, h) 75 | 76 | else: 77 | h = layer(h) 78 | 79 | return h -------------------------------------------------------------------------------- /dgnn/models/gcn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | import torch.nn.functional as f 6 | 7 | from ..layers import GConv 8 | from ..data import NodeBlocks 9 | 10 | class GCN(nn.Module): 11 | """ 12 | GCN model with simple GConv layers at all layers 13 | """ 14 | 15 | def __init__(self, 16 | features_dim, 17 | hidden_dim, 18 | num_classes, 19 | num_layers, 20 | activation, 21 | layer=GConv, 22 | dropout=0, 23 | input_norm=False, 24 | layer_norm=False, 25 | *args, 26 | **kwargs): 27 | 28 | super().__init__() 29 | 30 | self.num_layers = num_layers 31 | self.layers = nn.ModuleList() 32 | 33 | dropout = nn.Dropout(p=dropout) 34 | self.activation = activation 35 | 36 | self.layer_type = layer 37 | 38 | 39 | if input_norm: 40 | self.layers.append(nn.BatchNorm1d(features_dim, affine=False)) 41 | 42 | 43 | self.layers.append(layer(features_dim, hidden_dim, layer_id=1)) 44 | if
layer_norm: 45 | self.layers.append(torch.nn.BatchNorm1d(hidden_dim)) 46 | self.layers.append(activation) 47 | self.layers.append(dropout) 48 | 49 | for i in range(1, num_layers-1): 50 | self.layers.append(layer(hidden_dim, hidden_dim, layer_id=i+1)) 51 | if layer_norm: 52 | self.layers.append(torch.nn.BatchNorm1d(hidden_dim)) 53 | self.layers.append(activation) 54 | self.layers.append(dropout) 55 | 56 | self.layers.append( 57 | layer(hidden_dim, num_classes, layer_id=num_layers)) 58 | 59 | # for layer in self.layers: 60 | # layer.linear.weight.data.fill_(0.01) 61 | 62 | def forward(self, x, adjs): 63 | 64 | h = x 65 | adj = adjs 66 | gcn_cnt = 0 67 | 68 | for i, layer in enumerate(self.layers): 69 | 70 | if type(layer) == self.layer_type: 71 | 72 | if type(adjs) == NodeBlocks: 73 | adj = adjs[gcn_cnt] 74 | gcn_cnt += 1 75 | 76 | h = layer(adj, h) 77 | 78 | else: 79 | h = layer(h) 80 | 81 | return h -------------------------------------------------------------------------------- /dgnn/utils/stats.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | import yaml 5 | import copy 6 | import numpy as np 7 | 8 | from multiprocessing import Value, Array, Manager 9 | 10 | class Stats(object): 11 | 12 | def __init__(self, config): 13 | 14 | 15 | self.config = config 16 | 17 | # Loss and Scores 18 | # manager = Manager() 19 | # self.train_loss = manager.list() 20 | # self.train_scores = manager.list() 21 | # self.val_loss = manager.list() 22 | # self.val_scores = manager.list() 23 | # self.test_score = manager.Value('d', 0) 24 | 25 | self.train_loss = [] 26 | self.train_scores = [] 27 | self.val_loss = [] 28 | self.val_scores = [] 29 | self.test_score = [] 30 | 31 | # best model 32 | self.best_model = [] 33 | self.best_val_score = 0 34 | self.best_val_loss = 1e10 35 | self.best_val_epoch = 0 36 | 37 | self.best_val_buff = [] 38 | 39 | # TODO 40 | # Timing info 41 | self.train_time = [] 42 | self.val_time = [] 43 | self.test_time = 0 44 | self.comm_cost = [] 45 | 46 | 47 | @property 48 | def run_id(self): 49 | current_counter = 1 50 | 51 | if os.path.exists(self.config.output_dir): 52 | for fn in os.listdir(self.config.output_dir): 53 | if fn.startswith(self.config.run_name+'-') and fn.endswith('npz'): 54 | current_counter += 1 55 | 56 | return '{}-{:03d}'.format(self.config.run_name, current_counter) 57 | 58 | @property 59 | def run_output(self): 60 | output = os.path.join(self.config.output_dir, self.run_id) 61 | return output 62 | 63 | 64 | def save(self): 65 | 66 | if self.config.output_dir == '': 67 | return None, None 68 | 69 | config_vars = vars(self.config) 70 | 71 | stats_vars = copy.copy(vars(self)) 72 | stats_vars.pop('config', None) 73 | 74 | # create output folder 75 | if not os.path.exists(self.config.output_dir): 76 | os.makedirs(self.config.output_dir) 77 | 78 | # save model to torch TODO: later 79 | # remove from stats 80 | stats_vars.pop('best_model', None) 81 | 82 | # print(stats_vars) 83 | 84 | # save config and stats to npy 85 | np.savez(self.run_output, config=config_vars, stats=stats_vars) 86 | 87 | return config_vars, stats_vars 88 | 89 | @staticmethod 90 | def load(stats_file): 91 | all_data = np.load(stats_file, allow_pickle=True) 92 | config = all_data['config'][()] 93 | stats = all_data['stats'][()] 94 | return config, stats 95 | -------------------------------------------------------------------------------- /dgnn/data/nodeblocks.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Optional, Tuple 3 | from torch_sparse import SparseTensor 4 | 5 | class NodeBlocks(): 6 | """ NodeBlocks holds the adjacency matrix per layer for GCN 7 | propagation 8 | Arguments: 9 | object {[type]} -- [description] 10 | """ 11 | 12 | def __init__(self, num_layers, from_graph=None): 13 | self.num_layers = num_layers 14 | self.layers_adj: List[SparseTensor] = [] 15 | self.layers_nodes = [] 16 | self.output_nid = None 17 | self.is_subgraph = False 18 | 19 | if from_graph is not None: 20 | self.from_graph(from_graph) 21 | 22 | def __getitem__(self, layer_id): 23 | return self.layers_adj[layer_id] 24 | 25 | def __len__(self): 26 | return self.num_layers 27 | 28 | def __repr__(self): 29 | rep = f'{self.__class__.__name__}({self.layers_adj})' 30 | return rep 31 | 32 | def to(self, device='cpu'): 33 | """ Move nodeblock to the device 34 | In the full-graph case, only the first layer is moved and 35 | the rest point to the same adjacency 36 | Keyword Arguments: 37 | device {str} -- [description] (default: {'cpu'}) 38 | """ 39 | if not self.is_subgraph: 40 | for i, adj in enumerate(self.layers_adj): 41 | self.layers_adj[i] = adj.to(device, non_blocking=True) 42 | else: 43 | for i, adj in enumerate(self.layers_adj): 44 | if i == 0: 45 | self.layers_adj[i] = adj.to(device, non_blocking=True) 46 | else: 47 | self.layers_adj[i] = self.layers_adj[0] 48 | 49 | def from_graph(self, graph): 50 | """ Create nodeblocks from a complete graph 51 | The full adjacency is repeated for all layers. 52 | On CPU this doesn't copy, but on GPU it does; 53 | to avoid that overhead, the .to function handles the repeat. 54 | Arguments: 55 | graph {[type]} -- [description] 56 | """ 57 | 58 | self.layers_adj = [graph] * self.num_layers 59 | self.layers_nodes = [torch.arange(graph.size(0))] * self.num_layers 60 | self.is_subgraph = True 61 | 62 | def add_layers(self, spmx, nodes): 63 | """ Prepend a new layer's adjacency to the nodeblocks 64 | Arguments: 65 | spmx {[type]} -- [description] 66 | nodes {[type]} -- [description] 67 | """ 68 | 69 | self.layers_adj.insert(0, spmx) 70 | self.layers_nodes.insert(0, nodes) 71 | 72 | def set_output_nid(self, nodes): 73 | """ set output nid 74 | Arguments: 75 | nodes {[type]} -- [description] 76 | """ 77 | self.output_nid = nodes 78 | 79 | @property 80 | def input_nid(self): 81 | return self.layers_nodes[-1] -------------------------------------------------------------------------------- /dgnn/data/samplers/neighbors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 5 | 6 | from .minibatch import RandomBatchSampler 7 | from ..nodeblocks import NodeBlocks 8 | from ..transforms import row_norm, col_norm 9 | 10 | from .minibatch import RandomBatchSampler, StratifiedMiniBatch, DGLBatchSampler 11 | 12 | from ...utils.cython.extension.sparse import sample_neighbors 13 | class NeighborSampler(torch.utils.data.DataLoader): 14 | 15 | def __init__(self, 16 | adj, 17 | batch_size, 18 | shuffle=False, 19 | num_batches=1, 20 | num_layers=1, 21 | num_neighbors=[], 22 | minibatch='random', 23 | **kwargs): 24 | 25 | self.data = copy.deepcopy(adj) 26 | self.num_layers = num_layers 27 | self.num_neighbors = num_neighbors 28 | 29 | 30 | self.partition_meta = '' 31 | if 'part_meta' in kwargs: 32 | self.partition_meta = kwargs['part_meta'] 33 | kwargs.pop('part_meta') 34 | 35 | if 'node_idx' in kwargs: 36 | node_idx = kwargs['node_idx'] 37 | kwargs.pop('node_idx') 38 | 39 | 40 | if minibatch == 'random': 41 | self.sampler = RandomBatchSampler(adj.size(0), batch_size, shuffle, num_batches) 42 | elif minibatch == 'stratified': 43 | self.sampler = StratifiedMiniBatch(adj.size(0), batch_size, shuffle, num_batches, self.partition_meta) 44 | elif minibatch == 'dglsim': 45 | self.sampler = DGLBatchSampler(node_idx, batch_size, shuffle, num_batches) 46 | 47 | super().__init__( 48 | self, 49 | batch_size=1, 50 | sampler=self.sampler, 51 | collate_fn=self.__collate__, 52 | **kwargs, 53 | ) 54 | 55 | def __getitem__(self, idx): 56 | # Gets the next minibatch (from (minibatch) sampler) 57 | return idx 58 | 59 | def __collate__(self, batch_idx): 60 | # This function is executed in parallel and creates/modifies the graphs... 61 |
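# Sampling proceeds from the output side inward: starting from the minibatch nodes, sample_adj
# draws up to num_neighbors[i] neighbors per node at each of the num_layers hops, row-normalizes
# the resulting bipartite adjacency, and prepends it via add_layers, so layers_adj[0] is the
# input-side block consumed by the first graph layer while output_nid keeps the original batch.
# Example (hypothetical sizes): with num_layers=3, num_neighbors=[10, 10, 10] and a 256-node
# batch, the frontier grows to at most roughly 2,560 nodes after one hop; the final frontier is
# returned first so the trainer can slice input features with it.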
kwargs: 36 | node_idx = kwargs['node_idx'] 37 | kwargs.pop('node_idx') 38 | 39 | 40 | if minibatch == 'random': 41 | self.sampler = RandomBatchSampler(adj.size(0), batch_size, shuffle, num_batches) 42 | elif minibatch == 'stratified': 43 | self.sampler = StratifiedMiniBatch(adj.size(0), batch_size, shuffle, num_batches, self.partition_meta) 44 | elif minibatch == 'dglsim': 45 | self.sampler = DGLBatchSampler(node_idx, batch_size, shuffle, num_batches) 46 | 47 | super().__init__( 48 | self, 49 | batch_size=1, 50 | sampler=self.sampler, 51 | collate_fn=self.__collate__, 52 | **kwargs, 53 | ) 54 | 55 | def __getitem__(self, idx): 56 | # Gets the next minibatch (from the minibatch sampler) 57 | return idx 58 | 59 | def __collate__(self, batch_idx): 60 | # This function is executed in parallel and creates/modifies the graphs... 61 | 62 | batch_nodes = batch_idx[0] 63 | nodeblocks = NodeBlocks(self.num_layers) 64 | nodeblocks.set_output_nid(batch_nodes) 65 | 66 | for i in range(self.num_layers): 67 | batch_adj, next_nodes = self.data.sample_adj(batch_nodes, self.num_neighbors[i]) 68 | # batch_adj = batch_adj.to_symmetric() 69 | batch_adj = row_norm(batch_adj) 70 | # batch_adj = gcn_norm(batch_adj) 71 | 72 | nodeblocks.add_layers(batch_adj, next_nodes) 73 | batch_nodes = next_nodes 74 | 75 | # return input_nid, nodeblocks and output_nid 76 | return batch_nodes, nodeblocks, nodeblocks.output_nid 77 | 78 | 79 | def update_k(self, k): 80 | # print(self.sampler) 81 | self.sampler.update_k(k) -------------------------------------------------------------------------------- /dgnn/utils/config.py: -------------------------------------------------------------------------------- 1 | class Config(object): 2 | def __init__(self, config): 3 | 4 | self.dataset = '' 5 | self.output_dir = '' 6 | self.run_name = 'run' 7 | 8 | self.model = 'gcn' 9 | self.layer = 'gconv' 10 | self.activation = 'relu' 11 | self.input_norm = False 12 | self.layer_norm = False 13 | self.num_layers = 2 14 | self.hidden_size = 16 15 | 16 | # g for gconv, l for linear, a for attention, s for sageconv 17 | self.arch = 'gg' 18 | self.residual = False 19 | 20 | 21 | self.num_samplers = 5 22 | self.sampler = 'subgraph' 23 | self.num_neighbors = [10,10] 24 | self.minibatch = 'random' 25 | self.minibatch_size = 256 26 | self.local_updates = 5 27 | 28 | # Correction server settings 29 | self.server_sampler = 'subgraph' 30 | self.server_num_neighbors = [10,10] 31 | self.server_minibatch = 'random' 32 | self.server_minibatch_size = 256 33 | self.server_updates = 1 34 | self.server_lr = 1e-3 35 | self.server_start_epoch = 0 36 | self.server_opt_sync = False 37 | self.rho = 1 38 | self.inc_k = False 39 | 40 | self.loss = 'xentropy' 41 | self.optim = 'adam' 42 | self.lr = 2e-2 43 | self.dropout = 0 44 | self.wd = 0 45 | self.num_epochs = 200 46 | self.val_patience = 2 47 | self.val_step = 1 48 | 49 | self.gpu = 0 50 | self.cpu = False 51 | self.cpu_val = False 52 | self.num_gpus = 4 53 | 54 | self.part_method = 'random' 55 | self.part_args = '' 56 | self.num_workers = 2 57 | 58 | # Mostly unused!
59 | self.hist_period = 0 60 | self.hist_exp = False 61 | self.stratified = False 62 | self.sync_local = True 63 | self.full_correct = False 64 | self.use_sampling = False 65 | self.weight_avg = True 66 | 67 | for key, value in config.items(): 68 | setattr(self, key, value) 69 | 70 | if self.dataset is not None and type(self.dataset) != str: 71 | self.dataset = self.dataset.name 72 | 73 | 74 | def __repr__(self): 75 | all_config = vars(self) 76 | for c in all_config: 77 | print(c, all_config[c]) 78 | 79 | return "" 80 | 81 | @property 82 | def world_size(self): 83 | return self.num_workers 84 | 85 | @property 86 | def partitioned_dir(self): 87 | # TODO: better handling 88 | if self.part_method == 'overhead': 89 | return f'partitioned/overhead-{self.part_args}-{self.num_workers}' 90 | else: 91 | return f'partitioned/{self.part_method}-{self.num_workers}' 92 | 93 | @property 94 | def processed_filename(self): 95 | # if self.dataset in ['proteins', 'arxiv', 'products']: 96 | # return 'processed/geometric_data_processed.pt' 97 | # else: 98 | # return 'processed/data.pt' 99 | return 'processed/data.pt' -------------------------------------------------------------------------------- /dgnn/models/custom.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | import torch.nn.functional as f 6 | 7 | from ..layers import GConv, MLPLayer, SAGEConv # ...and more 8 | 9 | from ..data import NodeBlocks 10 | 11 | class Custom(nn.Module): 12 | """Custom GNN model builder 13 | 14 | Builds the layer stack from an architecture string such as 'gg', 15 | with 'g' for GConv, 'l' for Linear and 's' for SAGEConv. 16 | """ 17 | 18 | def __init__(self, 19 | features_dim, 20 | hidden_dim, 21 | num_classes, 22 | num_layers, 23 | activation, 24 | layer=None, 25 | dropout=0, 26 | input_norm=False, 27 | layer_norm=False, 28 | arch='', 29 | *args, 30 | **kwargs): 31 | 32 | super().__init__() 33 | 34 | self.num_layers = num_layers 35 | self.arch = arch 36 | self.layers = nn.ModuleList() 37 | 38 | self.dropout = nn.Dropout(p=dropout) 39 | self.activation = activation 40 | self.residual = kwargs['residual'] 41 | 42 | self.model_builder(arch, features_dim, hidden_dim, num_classes, input_norm, layer_norm) 43 | 44 | def model_builder(self, arch, feat_dim, hid_dim, output_dim, input_norm, layer_norm): 45 | 46 | layers_tokens = list(arch) 47 | num_layers = len(layers_tokens) 48 | 49 | in_dim = feat_dim 50 | out_dim = hid_dim 51 | 52 | if input_norm: 53 | self.layers.append(nn.BatchNorm1d(feat_dim, affine=False)) 54 | 55 | for i, l in enumerate(layers_tokens): 56 | 57 | if l == 'l': 58 | layer = nn.Linear(in_dim, out_dim) 59 | elif l == 'g': 60 | layer = GConv(in_dim, out_dim, layer_id=i) 61 | elif l == 's': 62 | layer = SAGEConv(in_dim, out_dim, layer_id=i) 63 | 64 | self.layers.append(layer) 65 | 66 | if i < num_layers - 1: 67 | if layer_norm: 68 | self.layers.append(torch.nn.BatchNorm1d(out_dim)) 69 | self.layers.append(self.activation) 70 | self.layers.append(self.dropout) 71 | 72 | in_dim = hid_dim 73 | out_dim = hid_dim if i + 1 < num_layers - 1 else output_dim 74 | 75 | 76 | def forward(self, x, adjs): 77 | 78 | h = x 79 | adj = adjs 80 | gcn_cnt = 0 81 | h_res = None 82 | 83 | for i, layer in enumerate(self.layers): 84 | 85 | if hasattr(layer, 'graph_layer'): 86 | 87 | if gcn_cnt > 0 and self.residual: 88 | h_res = h.clone() 89 | 90 | if type(adjs) == NodeBlocks: 91 | adj = adjs[gcn_cnt] 92 | 93 | h = layer(adj, h) 94 | gcn_cnt += 1 95 | 96 | else: 97 | h = layer(h) 98 | 99 | if self.residual and i < len(self.layers) - 1
and gcn_cnt > 1 and isinstance(layer, type(self.activation)): 100 | h = h + h_res 101 | 102 | return h 103 | -------------------------------------------------------------------------------- /dgnn/data/partition/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .metis import metis 4 | from .random import random 5 | from .overhead import overhead 6 | 7 | 8 | def load_partitions(parted_path, rank): 9 | 10 | adj = torch.load(parted_path+'/adj_{}.pt'.format(rank)) 11 | features, labels, train_mask, val_mask, test_mask = torch.load( 12 | parted_path+'/fela_{}.pt'.format(rank)) 13 | return adj, features, labels, train_mask, val_mask, test_mask 14 | 15 | def load_meta(parted_path): 16 | perm = torch.load(parted_path+'/perm.pt') 17 | if type(perm) == tuple: 18 | perm, partptr = perm 19 | all_perm = [] 20 | start = 0 21 | for end in partptr[1:]: 22 | all_perm.append(perm[start:end]) 23 | start=end 24 | return all_perm 25 | else: 26 | return perm 27 | 28 | 29 | # def load_fixed_part(self, rank): 30 | 31 | # if not self.parted: 32 | # # do it once! 33 | # self.num_nodes = self.dataset[0].num_nodes 34 | # self.node_per_part = math.ceil(self.num_nodes / self.config.num_procs) 35 | # self.part_nodes = torch.split(torch.arange(self.num_nodes), self.node_per_part) 36 | 37 | # self.part_ptr = [self.part_nodes[0][0]] 38 | # for part in self.part_nodes: 39 | # self.part_ptr.append(part[-1]) 40 | 41 | # self.parted = True 42 | 43 | # start = self.part_ptr[rank] 44 | # if start > 0: 45 | # start += 1 46 | # end = self.part_ptr[rank+1] 47 | # adj = self.dataset[0].adj_t.narrow(0, start, end-start+1).narrow(1, start, end-start+1) 48 | 49 | 50 | # part_feats = self.dataset[0].x[start:end+1] 51 | # part_labels = self.dataset[0].y[start:end+1] 52 | # part_train_mask = self.dataset[0].train_mask[start:end+1] 53 | # part_val_mask = self.dataset[0].val_mask[start:end+1] 54 | # part_test_mask = self.dataset[0].test_mask[start:end+1] 55 | 56 | # return adj, part_feats, part_labels, part_train_mask, part_val_mask, part_test_mask 57 | 58 | 59 | # def load_random_part(self, rank): 60 | 61 | # if not self.parted: 62 | # self.parted = True 63 | 64 | # train_idx = self.dataset[0].train_mask.nonzero(as_tuple=True)[0] 65 | # val_idx = self.dataset[0].val_mask.nonzero(as_tuple=True)[0] 66 | # test_idx = self.dataset[0].test_mask.nonzero(as_tuple=True)[0] 67 | 68 | # train_npp = math.ceil(train_idx.shape[0] / self.config.num_procs) 69 | # val_npp = math.ceil(val_idx.shape[0] / self.config.num_procs) 70 | # test_npp = math.ceil(test_idx.shape[0] / self.config.num_procs) 71 | 72 | # train_parts = train_idx.split(train_npp) 73 | # val_parts = val_idx.split(val_npp) 74 | # test_parts = test_idx.split(test_npp) 75 | 76 | 77 | # self.part_idx = [] 78 | # for i in range(self.config.num_procs): 79 | # self.part_idx.append(torch.cat((train_parts[i], val_parts[i], test_parts[i]))) 80 | 81 | # part_idx = self.part_idx[rank] 82 | # part_adj = self.dataset[0].adj_t[part_idx, part_idx] 83 | # part_feats = self.dataset[0].x[part_idx] 84 | # part_labels = self.dataset[0].y[part_idx] 85 | # part_train_mask = self.dataset[0].train_mask[part_idx] 86 | # part_val_mask = self.dataset[0].val_mask[part_idx] 87 | # part_test_mask = self.dataset[0].test_mask[part_idx] 88 | 89 | # return part_adj, part_feats, part_labels, part_train_mask, part_val_mask, part_test_mask -------------------------------------------------------------------------------- 
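A minimal usage sketch for the partition helpers above. This is illustrative only: the partitioned directory name and the rank below are assumptions (compare Config.partitioned_dir and DistGNNFull, which consume these helpers the same way).

import torch
from dgnn.data import partition

# Hypothetical output directory of scripts/partition.py for 2 random parts
parted_path = 'partitioned/random-2'
rank = 0

# Each rank loads its own adjacency block plus node-level tensors
adj, features, labels, train_mask, val_mask, test_mask = \
    partition.load_partitions(parted_path, rank)

# perm.pt stores either a flat permutation (random partitioning) or a
# (perm, partptr) tuple (metis); in the tuple case load_meta returns a
# list of per-partition node-id tensors
perm = partition.load_meta(parted_path)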
/dgnn/train/serial/distgnn_correction.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import math 4 | import numpy as np 5 | import torch 6 | 7 | from . import DistGNN, DistGNNFull 8 | from ...data import samplers 9 | from ...data.transforms import row_norm 10 | 11 | 12 | class DistGNNCorr(DistGNN): 13 | 14 | def __init__(self, config, dataset): 15 | super().__init__(config, dataset) 16 | 17 | full_adj = self.dataset[0].adj_t 18 | if self.dataset.name.startswith('ogbn'): 19 | full_adj = self.dataset[0].adj_t.to_symmetric() 20 | 21 | if self.config.server_sampler == 'subgraph': 22 | self.server_trainloader = samplers.SubGraphSampler(full_adj, 23 | self.config.server_minibatch_size, 24 | num_workers=self.config.server_num_samplers, 25 | num_layers=self.config.num_layers, 26 | num_batches=self.config.server_updates, 27 | minibatch=self.config.server_minibatch, 28 | part_meta=self.dataset_dir, 29 | persistent_workers=True, 30 | ) 31 | elif config.server_sampler == 'neighbor': 32 | self.server_trainloader = samplers.NeighborSampler(full_adj, 33 | self.config.minibatch_size, 34 | num_workers=self.config.num_samplers, 35 | num_layers=self.config.num_layers, 36 | num_batches=self.config.server_updates, 37 | num_neighbors=self.config.server_num_neighbors, 38 | minibatch=self.config.server_minibatch, 39 | part_meta=self.dataset_dir, 40 | persistent_workers=True, 41 | ) 42 | 43 | # self.full_adj = row_norm(self.full_adj).to(self.device) 44 | 45 | def train(self, epoch): 46 | 47 | self.model.train() 48 | 49 | for rank in range(self.config.world_size): 50 | self.local_train(rank, epoch) 51 | 52 | self.server_average() 53 | self.server_correction(epoch) 54 | 55 | def server_correction(self, epoch): 56 | 57 | self.model.train() 58 | 59 | for input_nid, nodeblocks, output_nid in self.server_trainloader: 60 | if epoch == 0: 61 | print('Server correction!') 62 | 63 | nodeblocks.to(self.device) 64 | features = self.full_features[input_nid] 65 | labels = self.full_labels[output_nid] 66 | train_mask = self.full_train_mask[output_nid] 67 | 68 | self.optimizer.zero_grad() 69 | output = self.model(features, nodeblocks) 70 | loss = self.loss_fnc(output[train_mask], labels[train_mask]) 71 | loss.backward() 72 | self.optimizer.step() 73 | -------------------------------------------------------------------------------- /dgnn/data/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 4 | 5 | from torch_geometric.utils import add_self_loops 6 | from torch_sparse import SparseTensor, matmul, fill_diag, sum, mul, coalesce 7 | from torch_scatter import scatter_add 8 | 9 | 10 | def row_norm(adj): 11 | if isinstance(adj, SparseTensor): 12 | # Add self loop 13 | adj_t = fill_diag(adj, 1) 14 | deg = sum(adj_t, dim=1) 15 | deg_inv = 1. / deg 16 | deg_inv.masked_fill_(deg_inv == float('inf'), 0.) 17 | adj_t = mul(adj_t, deg_inv.view(-1, 1)) 18 | return adj_t 19 | 20 | 21 | def col_norm(adj): 22 | if isinstance(adj, SparseTensor): 23 | # Add self loop 24 | adj_t = fill_diag(adj, 1) 25 | deg = sum(adj_t, dim=0) 26 | deg_inv = 1. / deg 27 | deg_inv.masked_fill_(deg_inv == float('inf'), 0.) 
28 | adj_t = mul(adj_t, deg_inv.view(1, -1)) # broadcast over columns (col norm) 29 | return adj_t 30 | 31 | 32 | def sym_norm(adj): 33 | if isinstance(adj, SparseTensor): 34 | adj_t = gcn_norm(adj) 35 | return adj_t 36 | 37 | 38 | class PrepareArxiv(object): 39 | """ Transformation for Arxiv for faster loading""" 40 | 41 | def __call__(self, data): 42 | data.adj_t = data.adj_t.to_symmetric() 43 | data.y = data.y.squeeze() 44 | del data.node_year 45 | return data 46 | 47 | def __repr__(self): 48 | return '{}()'.format(self.__class__.__name__) 49 | 50 | 51 | class PrepareProducts(object): 52 | """ Transformation for Products for faster loading""" 53 | 54 | def __call__(self, data): 55 | # data.x = data.x / data.x.sum(1, keepdim=True).clamp(min=1) 56 | data.y = data.y.squeeze() 57 | return data 58 | 59 | def __repr__(self): 60 | return '{}()'.format(self.__class__.__name__) 61 | 62 | class PrepareProteins(object): 63 | """ Prepare features and adjacency for Proteins """ 64 | 65 | def __call__(self, data): 66 | 67 | ##! preprocessing from DeepGCN 68 | # le = preprocessing.LabelEncoder() 69 | # all_species = data.node_species # if I use dataset[0] here, dataset won't get updated! 70 | # species_unique = torch.unique(all_species) 71 | # max_no = species_unique.max() 72 | # le.fit(species_unique % max_no) 73 | # species = le.transform(all_species.squeeze() % max_no) 74 | # species = np.expand_dims(species, axis=1) 75 | # enc = preprocessing.OneHotEncoder() 76 | # enc.fit(species) 77 | # one_hot_encoding = enc.transform(species).toarray() 78 | # data.x = torch.from_numpy(one_hot_encoding).type(torch.FloatTensor) 79 | 80 | data.x = data.adj_t.mean(dim=1) 81 | data.adj_t.set_value_(None) 82 | data.y = data.y.to(torch.float) 83 | data.y = data.y.squeeze() 84 | 85 | # save space? 86 | del data.node_species 87 | 88 | return data 89 | 90 | 91 | def __repr__(self): 92 | return '{}()'.format(self.__class__.__name__) 93 | 94 | 95 | class PreparePapers100M(object): 96 | """ Transformation for Papers100M for faster loading""" 97 | 98 | def __call__(self, data): 99 | data.adj_t = data.adj_t.to_symmetric() 100 | data.y = data.y.squeeze() 101 | return data 102 | 103 | def __repr__(self): 104 | return '{}()'.format(self.__class__.__name__) -------------------------------------------------------------------------------- /dgnn/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from torch_geometric.datasets import * 5 | import torch_geometric.transforms as T 6 | from ogb.nodeproppred import PygNodePropPredDataset 7 | 8 | from sklearn import preprocessing 9 | import numpy as np 10 | 11 | from .transforms import * 12 | 13 | # https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html 14 | 15 | class Dataset(): 16 | """Factory that returns a PyG/OGB dataset by (short) name 17 | """ 18 | # def __init__(self, dataset_name, split=None): 19 | # # __new__ is called before __init__ 20 | # # hence it can return another class object 21 | 22 | def __new__(cls, dataset_name, split=None): 23 | default_dir = os.path.join(os.path.expanduser('~'), '.gnn') 24 | dataset_dir = os.environ.get('GNN_DATASET_DIR', default_dir) 25 | 26 | # transform = T.Compose([T.NormalizeFeatures(), T.ToSparseTensor()]) 27 | transform = T.Compose([T.ToSparseTensor()]) 28 | 29 | # support shortened versions of OGB dataset names 30 | if dataset_name in ['arxiv', 'proteins', 'mag', 'products', 'papers100M']: 31 | dataset_name = 'ogbn-' + dataset_name 32 | 33 | if dataset_name in ['cora', 'citeseer', 'pubmed']: 34 | dataset = 
Planetoid(root=dataset_dir, name=dataset_name, split='full', pre_transform=transform) 35 | elif dataset_name == 'reddit': 36 | dataset = Reddit(root=dataset_dir+'/reddit/', pre_transform=transform) 37 | setattr(dataset, 'name', 'reddit') 38 | elif dataset_name == 'yelp': 39 | dataset = Yelp(root=dataset_dir+'/yelp/', pre_transform=transform) 40 | setattr(dataset, 'name', 'yelp') 41 | elif dataset_name == 'flickr': 42 | dataset = Flickr(root=dataset_dir+'/flickr/', pre_transform=transform) 43 | setattr(dataset, 'name', 'flickr') 44 | elif dataset_name.startswith('ogbn'): 45 | 46 | if dataset_name == 'ogbn-proteins': 47 | transform = T.Compose([T.ToSparseTensor(), PrepareProteins()]) 48 | elif dataset_name == 'ogbn-arxiv': 49 | transform = T.Compose([T.ToSparseTensor(), PrepareArxiv()]) 50 | elif dataset_name == 'ogbn-products': 51 | transform = T.Compose([T.ToSparseTensor(), PrepareProducts()]) 52 | elif dataset_name == 'ogbn-papers100M': 53 | transform = T.Compose([T.ToSparseTensor(), PreparePapers100M()]) 54 | 55 | dataset = PygNodePropPredDataset(dataset_name, dataset_dir, pre_transform=transform) 56 | 57 | splitted_idx = dataset.get_idx_split() 58 | data = dataset.data 59 | 60 | # Fix a few things about the ogbn-proteins meta_info 61 | if dataset_name == 'ogbn-proteins': 62 | dataset.slices['x'] = dataset.slices['y'] 63 | dataset.__num_classes__ = 112 64 | 65 | # Add split info to Data object 66 | for split in ['train', 'val', 'test']: 67 | mask = torch.zeros(data.num_nodes, dtype=torch.bool) 68 | if split == 'val': 69 | mask[splitted_idx['valid']] = True 70 | else: 71 | mask[splitted_idx[split]] = True 72 | data[f'{split}_mask'] = mask 73 | dataset.slices[f'{split}_mask'] = dataset.slices['x'] 74 | 75 | # data['val_mask'] = data['valid_mask'] 76 | # dataset.slices['val_mask'] = dataset.slices['x'] 77 | 78 | else: 79 | print('dataset {} is not supported!'.format(dataset_name)) 80 | raise NotImplementedError 81 | 82 | return dataset -------------------------------------------------------------------------------- /dgnn/train/sampling.py: -------------------------------------------------------------------------------- 1 | from .base import Base 2 | from .full import Full 3 | 4 | from ..data import samplers 5 | from ..data.transforms import row_norm 6 | 7 | class Sampling(Full, Base): 8 | 9 | def __init__(self, config, dataset): 10 | 11 | # use Full init for full-graph inference; Base init would do otherwise 12 | Full.__init__(self, config, dataset) 13 | # else: 14 | # Base.__init__(self, config, dataset) 15 | 16 | # Re-fetch the raw adjacency (Full.__init__ normalized its own copy) 17 | self.full_adj = self.dataset[0].adj_t 18 | 19 | # if self.dataset.name.startswith('ogbn'): 20 | # self.full_adj = self.dataset[0].adj_t.to_symmetric() 21 | 22 | if config.sampler == 'subgraph': 23 | self.train_loader = samplers.SubGraphSampler(self.full_adj, 24 | self.config.minibatch_size, 25 | num_workers=self.config.num_samplers, 26 | num_batches=self.config.local_updates, 27 | num_layers=self.config.num_layers, 28 | persistent_workers=True, 29 | ) 30 | elif config.sampler == 'neighbor': 31 | self.train_loader = samplers.NeighborSampler(self.full_adj, 32 | self.config.minibatch_size, 33 | num_workers=self.config.num_samplers, 34 | num_batches=self.config.local_updates, 35 | num_layers=self.config.num_layers, 36 | num_neighbors=self.config.num_neighbors, 37 | persistent_workers=True, 38 | ) 39 | # use row_norm for full inference 40 | self.full_adj = row_norm(self.full_adj).to(self.full_device) 41 | 42 | print(f'K={len(self.train_loader)}') 43 | 44 
| def train(self, epoch): 45 | self.model.train() 46 | for input_nid, nodeblocks, output_nid in self.train_loader: 47 | 48 | # do this to train sampling with an MLP 49 | #input_nid = output_nid 50 | 51 | # import pdb; pdb.set_trace() 52 | nodeblocks.to(self.device) 53 | features = self.full_features[input_nid] 54 | labels = self.full_labels[output_nid] 55 | train_mask = self.full_train_mask[output_nid] 56 | 57 | if self.config.cpu_val: 58 | features = features.to(self.device) 59 | labels = labels.to(self.device) 60 | train_mask = train_mask.to(self.device) 61 | 62 | self.optimizer.zero_grad() 63 | output = self.model(features, nodeblocks) 64 | loss = self.loss_fnc(output[train_mask], labels[train_mask]) 65 | 66 | loss.backward() 67 | self.optimizer.step() 68 | 69 | train_score = self.calc_score(output[train_mask], labels[train_mask]) 70 | self.stats.train_loss.append(loss.item()) 71 | self.stats.train_scores.append(train_score) 72 | 73 | # Sampling validation 74 | # @torch.no_grad() 75 | # def validation(self, epoch): 76 | # raise NotImplementedError 77 | 78 | # Sampling inference 79 | # @torch.no_grad() 80 | # def inference(self): 81 | # if self.config.sampling_infe: 82 | # Full.... 83 | # raise NotImplementedError 84 | 85 | -------------------------------------------------------------------------------- /dgnn/train/serial/distgnn.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import copy 4 | import math 5 | import numpy as np 6 | import torch 7 | 8 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 9 | 10 | from ..base import Base 11 | from ..full import Full 12 | from ...data.transforms import row_norm 13 | from ...data import samplers, partition 14 | 15 | from .distgnn_full import DistGNNFull 16 | 17 | 18 | class DistGNN(DistGNNFull): 19 | def __init__(self, config, dataset): 20 | super().__init__(config, dataset) 21 | self.clients_trainloader = [] 22 | 23 | for rank in range(self.config.world_size): 24 | 25 | tmp_adj = self.clients_adj[rank] 26 | 27 | if self.config.sampler == 'subgraph': 28 | tmp_train_loader = samplers.SubGraphSampler(tmp_adj, 29 | self.config.minibatch_size, 30 | num_workers=self.config.num_samplers, 31 | num_batches=self.config.local_updates, 32 | num_layers=self.config.num_layers, 33 | persistent_workers=True, 34 | ) 35 | elif config.sampler == 'neighbor': 36 | tmp_train_loader = samplers.NeighborSampler(tmp_adj, 37 | self.config.minibatch_size, 38 | num_workers=self.config.num_samplers, 39 | num_batches=self.config.local_updates, 40 | num_layers=self.config.num_layers, 41 | num_neighbors=self.config.num_neighbors, 42 | persistent_workers=True, 43 | ) 44 | 45 | self.clients_trainloader.append(tmp_train_loader) 46 | 47 | if self.config.sampler == 'neighbor': 48 | print('Neighbor sampling: using symmetric row-normalized full adjacency for inference') 49 | self.full_adj = row_norm(self.dataset[0].adj_t.to_symmetric()).to(self.device) 50 | 51 | def local_train(self, rank, epoch): 52 | 53 | # To speed up serial training, move all features and labels to the device once 54 | if epoch == 0: 55 | print('Local Train', rank, len(self.clients_trainloader[rank])) 56 | self.clients_features[rank] = self.clients_features[rank].to(self.device) 57 | self.clients_labels[rank] = self.clients_labels[rank].to(self.device) 58 | 59 | self.client_sync(rank) 60 | 61 | self.clients_model[rank].train() 62 | 63 | for input_nid, nodeblocks, output_nid in self.clients_trainloader[rank]: 64 | 65 | nodeblocks.to(self.device) 66 | 67 | features = self.clients_features[rank][input_nid] 
#.to(self.device) 68 | labels = self.clients_labels[rank][output_nid] #.to(self.device) 69 | train_mask = self.clients_train_mask[rank][output_nid] 70 | 71 | # import pdb; pdb.set_trace() 72 | 73 | self.clients_optimizer[rank].zero_grad() 74 | 75 | output = self.clients_model[rank](features, nodeblocks) 76 | loss = self.loss_fnc(output[train_mask], labels[train_mask]) 77 | 78 | loss.backward() 79 | self.clients_optimizer[rank].step() 80 | 81 | if not self.config.weight_avg: 82 | for i, cp in enumerate(self.clients_model[rank].parameters()): 83 | self.clients_grads[rank][i] += cp.grad 84 | -------------------------------------------------------------------------------- /dgnn/train/serial/distgnn_full_correction.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import math 4 | import numpy as np 5 | import torch 6 | 7 | from . import DistGNN, DistGNNFull 8 | from ...data import samplers 9 | from ...data.transforms import row_norm 10 | 11 | class DistGNNFullCorr(DistGNNFull): 12 | 13 | def __init__(self, config, dataset): 14 | super().__init__(config, dataset) 15 | 16 | # self.full_adj = self.dataset[0].adj_t 17 | # if self.dataset.name.startswith('ogbn'): 18 | # self.full_adj = self.dataset[0].adj_t.to_symmetric() 19 | 20 | if self.config.sampler == 'subgraph': 21 | self.server_trainloader = samplers.SubGraphSampler(self.dataset[0].adj_t.to_symmetric(), 22 | self.config.server_minibatch_size, 23 | num_workers=self.config.num_samplers, 24 | num_batches=self.config.server_updates, 25 | num_layers=self.config.num_layers, 26 | minibatch=self.config.server_minibatch, 27 | part_meta=self.dataset_dir, 28 | persistent_workers=True, 29 | ) 30 | elif config.sampler == 'neighbor': 31 | self.server_trainloader = samplers.NeighborSampler(self.dataset[0].adj_t.to_symmetric(), 32 | self.config.server_minibatch_size, 33 | num_workers=self.config.num_samplers, 34 | num_batches=self.config.server_updates, 35 | num_layers=self.config.num_layers, 36 | num_neighbors=self.config.server_num_neighbors, 37 | minibatch=self.config.server_minibatch, 38 | part_meta=self.dataset_dir, 39 | persistent_workers=True, 40 | ) 41 | 42 | 43 | def train(self, epoch): 44 | 45 | self.model.train() 46 | 47 | for rank in range(self.config.world_size): 48 | self.local_train(rank, epoch) 49 | 50 | self.server_average() 51 | self.server_correction(epoch) 52 | 53 | def server_correction(self, epoch): 54 | 55 | self.model.train() 56 | 57 | for input_nid, nodeblocks, output_nid in self.server_trainloader: 58 | if epoch == 0: 59 | print('Server correction!', len(self.server_trainloader)) 60 | 61 | nodeblocks.to(self.device) 62 | features = self.full_features[input_nid] 63 | labels = self.full_labels[output_nid] 64 | train_mask = self.full_train_mask[output_nid] 65 | 66 | self.optimizer.zero_grad() 67 | output = self.model(features, nodeblocks) 68 | loss = self.loss_fnc(output[train_mask], labels[train_mask]) 69 | loss.backward() 70 | self.optimizer.step() 71 | 72 | # if epoch == 0: 73 | # print('Server Correction!!') 74 | 75 | # nodeblocks = self.full_adj 76 | # features = self.full_features 77 | # labels = self.full_labels 78 | # train_mask = self.full_train_mask 79 | 80 | # self.optimizer.zero_grad() 81 | # output = self.model(features, nodeblocks) 82 | # loss = self.loss_fnc(output[train_mask], labels[train_mask]) 83 | # loss.backward() 84 | # self.optimizer.step() -------------------------------------------------------------------------------- /dgnn/layers/dh_gconv.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import torch_sparse 4 | 5 | from ..utils import dist_sum, dist_spmm 6 | 7 | 8 | class dhgconv(torch.autograd.Function): 9 | 10 | fw_hist_t = [] 11 | bw_hist_ag = [] 12 | bw_hist_gw = [] 13 | 14 | @staticmethod 15 | def forward(ctx, inputs, weight, adjs, rank, world_size, layerid, num_layers, use_hist=False): 16 | 17 | ctx.save_for_backward(inputs, weight) 18 | ctx.adjs = adjs 19 | ctx.rank = rank 20 | ctx.world_size = world_size 21 | ctx.use_hist = use_hist 22 | ctx.layerid = layerid 23 | ctx.num_layers = num_layers 24 | 25 | if not use_hist: 26 | T = dist_spmm(adjs, inputs, rank, world_size) 27 | tmp_hist_t = T - adjs[rank].spmm(inputs) 28 | if len(dhgconv.fw_hist_t) < layerid + 1: 29 | dhgconv.fw_hist_t.append(tmp_hist_t) 30 | else: 31 | dhgconv.fw_hist_t[layerid] = tmp_hist_t 32 | else: 33 | # print('LayLay', layerid, len(dhgconv.fw_hist_t)) 34 | T = adjs[rank].spmm(inputs) + dhgconv.fw_hist_t[layerid] 35 | 36 | # Z = TW 37 | Z = torch.mm(T, weight) 38 | 39 | return Z 40 | 41 | @staticmethod 42 | def backward(ctx, grad_output): 43 | 44 | inputs, weight = ctx.saved_tensors 45 | adjs = ctx.adjs 46 | rank = ctx.rank 47 | world_size = ctx.world_size 48 | use_hist = ctx.use_hist 49 | layerid = ctx.layerid 50 | num_layers = ctx.num_layers 51 | 52 | new_lid = num_layers - (layerid + 1) 53 | 54 | if not ctx.use_hist: 55 | ag = dist_spmm(adjs, grad_output, rank , world_size) 56 | tmp_hist_ag = ag - adjs[rank].spmm(grad_output) 57 | if len(dhgconv.bw_hist_ag) < new_lid + 1: 58 | dhgconv.bw_hist_ag.append(tmp_hist_ag) 59 | else: 60 | dhgconv.bw_hist_ag[new_lid] = tmp_hist_ag 61 | else: 62 | ag = adjs[rank].spmm(grad_output) + dhgconv.bw_hist_ag[new_lid] 63 | 64 | grad_input = torch.mm(ag, weight.t()) 65 | 66 | grad_weight = torch.mm(inputs.t(), ag) 67 | 68 | if not ctx.use_hist: 69 | tmp_gw = grad_weight 70 | grad_weight = dist_sum(grad_weight) 71 | tmp_hist_gw = grad_weight - tmp_gw 72 | 73 | if len(dhgconv.bw_hist_gw) < new_lid + 1: 74 | dhgconv.bw_hist_gw.append(tmp_hist_gw) 75 | else: 76 | dhgconv.bw_hist_gw[new_lid] = tmp_hist_gw 77 | else: 78 | grad_weight = grad_weight + dhgconv.bw_hist_gw[new_lid] 79 | 80 | 81 | return grad_input, grad_weight, None, None, None, None, None, None 82 | 83 | 84 | # https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/ 85 | # https://towardsdatascience.com/the-one-pytorch-trick-which-you-should-know-2d5e9c1da2ca 86 | # https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd 87 | # https://www.kaggle.com/sironghuang/understanding-pytorch-hooks 88 | 89 | class DHGConv(torch.nn.Module): 90 | 91 | def __init__(self, 92 | input_dim, 93 | output_dim, 94 | rank=0, 95 | layerid=0, 96 | num_layers=0, 97 | ): 98 | super().__init__() 99 | 100 | self.input_dim = input_dim 101 | self.output_dim = output_dim 102 | self.rank = rank 103 | self.layerid = layerid 104 | self.num_layers = num_layers 105 | 106 | self.fn = dhgconv.apply 107 | self.weight = torch.nn.Parameter(torch.rand(input_dim, output_dim)) 108 | 109 | 110 | def forward(self, x, adj, use_hist): 111 | world_size = dist.get_world_size() 112 | x = self.fn(x, self.weight, adj, self.rank, world_size, self.layerid, self.num_layers, use_hist) 113 | return x 114 | 115 | def __repr__(self): 116 | return self.__class__.__name__ + "[{}] ({}->{})".format( 117 | self.rank, 118 | self.input_dim, 119 | self.output_dim) 120 | 
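The caching scheme in dhgconv.forward/backward can be illustrated with plain dense tensors: on a refresh step (use_hist=False) the full aggregation is computed and the remote partitions' contribution is cached; on a historic step (use_hist=True) only the local block is recomputed and the stale remainder is reused, skipping the communication. A minimal sketch, with dense matmuls standing in for dist_spmm and all shapes/names illustrative:

import torch

num_nodes, dim, rank, world_size = 6, 4, 0, 2
# block-columns of the adjacency, one per worker (dense stand-ins)
A = [torch.rand(num_nodes, num_nodes) for _ in range(world_size)]
X = torch.rand(num_nodes, dim)

# refresh step: full aggregation, then cache the remote remainder
T = sum(A_r @ X for A_r in A)       # what dist_spmm would return
hist = T - A[rank] @ X              # contribution of the other workers

# historic step: local SpMM plus stale cached remainder, no communication
X_new = torch.rand(num_nodes, dim)  # inputs have changed since the refresh
T_approx = A[rank] @ X_new + hist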
-------------------------------------------------------------------------------- /dgnn/train/old/historic.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import copy 4 | 5 | import torch 6 | import torch.multiprocessing as mp 7 | import torch.distributed as dist 8 | 9 | from ..utils import Stats 10 | from ..models import model_selector 11 | from ..data import partition as P 12 | from ..utils import helpers as H 13 | 14 | from . import Distributed 15 | 16 | class Historic(Distributed): 17 | 18 | def __init__(self, *args, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | 21 | def train(self, rank, *args, **kwargs): 22 | 23 | # Load the rank-th partition, to rank-th device both adj and features 24 | adj, features, labels, train_mask, val_mask, test_mask = P.load_partitions( 25 | self.dataset_dir, rank) 26 | 27 | # Init the local model, 28 | model = copy.deepcopy(self.model) 29 | model.update_rank(rank) 30 | device = H.rank2dev(rank, self.num_gpus) 31 | 32 | print(device, flush=True) 33 | 34 | model = model.to(device) 35 | adj = [a.to(device) for a in adj] 36 | features = features.to(device) 37 | train_mask = train_mask.to(device) 38 | val_mask = val_mask.to(device) 39 | test_mask = test_mask.to(device) 40 | labels = labels.to(device) 41 | 42 | loss_fnc = torch.nn.CrossEntropyLoss() 43 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) 44 | 45 | best_model = copy.deepcopy(model) 46 | best_val_score = 0 47 | 48 | for epoch in range(self.config.num_epochs): 49 | # train 50 | model.train() 51 | optimizer.zero_grad() 52 | 53 | use_hist = None 54 | 55 | if self.config.hist_period > 1: 56 | use_hist = True 57 | 58 | if epoch == self.config.num_epochs - 1: 59 | use_hist = False 60 | 61 | if not self.config.hist_exp: 62 | if epoch % self.config.hist_period == 0: 63 | use_hist = False 64 | else: 65 | exp_pow = int(epoch / self.config.hist_period) 66 | if epoch % pow(2, exp_pow) == 0: 67 | use_hist = False 68 | 69 | output = model(features, adj, use_hist) 70 | loss = loss_fnc(output[train_mask], labels[train_mask]) 71 | 72 | loss.backward() 73 | optimizer.step() 74 | 75 | # Distributed Validation 76 | val_pred = output[val_mask].argmax(dim=1) 77 | val_acc = torch.stack( 78 | (val_pred.eq(labels[val_mask]).sum(), val_mask.sum())) 79 | 80 | all_val_acc = [torch.ones_like(val_acc) for _ in range(self.config.num_procs)] 81 | dist.all_gather(all_val_acc, val_acc) 82 | 83 | tmp_score = torch.stack(all_val_acc, dim=0).sum(dim=0) 84 | val_score = (tmp_score[0]/tmp_score[1]).item() 85 | 86 | if val_score > best_val_score: 87 | best_val_score = val_score 88 | best_model = copy.deepcopy(model) 89 | 90 | # End of Epoch 91 | if rank == 0: 92 | self.stats.train_loss.append(loss.item()) 93 | self.stats.val_scores.append(val_score) 94 | 95 | print(f'Epoch #{epoch}:', 96 | f'train loss {loss.item():.3f}', 97 | f'val accuracy {val_score*100:.2f}%', 98 | '*' if not use_hist else '', 99 | flush=True) 100 | 101 | # Testing 102 | best_model.eval() 103 | test_output = best_model(features, adj) 104 | 105 | test_pred = test_output[test_mask].argmax(dim=1) 106 | test_acc = torch.stack( 107 | (test_pred.eq(labels[test_mask]).sum(), test_mask.sum())) 108 | 109 | all_test_acc = [torch.ones_like(test_acc) for _ in range(self.config.num_procs)] 110 | dist.all_gather(all_test_acc, test_acc) 111 | 112 | if rank == 0: 113 | tmp_score = torch.stack(all_test_acc, dim=0).sum(dim=0) 114 | test_score = (tmp_score[0]/tmp_score[1]).item() 115 | 116 | self.stats.test_score 
= test_score 117 | 118 | print(f'Best model test score is: {test_score*100:.2f}%', 119 | flush=True) 120 | 121 | self.save() -------------------------------------------------------------------------------- /scripts/run-config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pdb 4 | import argparse 5 | import commentjson 6 | 7 | sys.path.append('..') 8 | sys.path.append('../dgnn/utils/cython/') # dirty hack to make up for relative import in pxd 9 | from dgnn import data, utils, train 10 | 11 | if os.environ.get('LOGNAME') == 'mfr5226': 12 | os.environ['GNN_DATASET_DIR'] = '/export/local/mfr5226/datasets/pyg_dist/' 13 | 14 | # base_dir = '../../../outputs/dist-gnn/721/' 15 | # base_dir = '../../../outputs/dist-gnn/722/' 16 | # base_dir = '../../../outputs/dist-gnn/723/' 17 | # base_dir = '../../../outputs/dist-gnn/724/' 18 | 19 | # base_dir = '../../../outputs/dist-gnn/801/' # main table 20 | base_dir = '../../../outputs/dist-gnn/802/' # ablation 21 | 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser(description='') 26 | parser.add_argument('--config', type=str, default='cora') 27 | parser.add_argument('--mode', type=str, default='full') 28 | parser.add_argument('--np', type=int, default=8) 29 | parser.add_argument('--rep', type=int, default=1) 30 | parser.add_argument('--strat', action='store_true', default=False) 31 | parser.add_argument('--metis', action='store_true', default=False) 32 | parser.add_argument('--nosave', action='store_true', default=False) 33 | 34 | parser.add_argument('--bs', type=int, default=2048) 35 | parser.add_argument('--k', type=int, default=64) 36 | parser.add_argument('--s', type=int, default=1) 37 | 38 | args = parser.parse_args() 39 | 40 | 41 | if args.mode == 'full': 42 | trainer = train.Full 43 | elif args.mode == 'dist': 44 | trainer = train.old.Distributed 45 | elif args.mode == 'dgnnfull': 46 | trainer = train.serial.DistGNNFull 47 | elif args.mode == 'dgnnfullcor': 48 | trainer = train.serial.DistGNNFullCorr 49 | elif args.mode == 'sampling': 50 | trainer = train.Sampling 51 | elif args.mode == 'dgnn': 52 | trainer = train.serial.DistGNN 53 | elif args.mode == 'dgnncor': 54 | trainer = train.serial.DistGNNCorr 55 | elif args.mode == 'dgnnstale': 56 | trainer = train.serial.DistGNNStale 57 | elif args.mode == 'd2gnn': 58 | trainer = train.dist.DistGNN 59 | elif args.mode == 'd2gnnfull': 60 | trainer = train.dist.DistGNNFull 61 | elif args.mode == 'd2gnncor': 62 | trainer = train.dist.DistGNNCorrection 63 | elif args.mode == 'dgl': 64 | trainer = train.dist.DistDGL 65 | else: 66 | raise NotImplementedError 67 | 68 | global_config = { 69 | 'num_workers': args.np, 70 | 'part_method': 'random' if not args.metis else 'metis', 71 | # 'part_method': 'metis', 72 | # 'part_method': 'overhead', 'part_args': 10, 73 | 'weight_avg': True, 74 | 'server_updates': 1, 75 | } 76 | 77 | 78 | with open(f'./configs/{args.config}.json', 'r') as config_file: 79 | local_config = commentjson.load(config_file) 80 | 81 | dataset_name = local_config['dataset'] 82 | dataset = data.Dataset(dataset_name) 83 | print('Done loading dataset...', dataset) 84 | out_dir = f'{base_dir}/{dataset_name}/' 85 | # run_name = trainer.__name__.lower() 86 | run_name = trainer.__module__.split('train.')[-1] 87 | tmp_global = { 88 | 'dataset': dataset_name, 89 | 'run_name': run_name, 90 | 'output_dir': out_dir, 91 | } 92 | 93 | # import pdb; pdb.set_trace() 94 | 95 | if args.mode == 'd2gnncor': 
96 | tmp_global['server_opt_sync'] = True 97 | 98 | tmp_global['server_minibatch_size'] = args.bs 99 | tmp_global['local_updates'] = args.k 100 | # tmp_global['server_updates'] = args.s 101 | 102 | global_config.update(local_config) 103 | global_config.update(tmp_global) 104 | 105 | 106 | print(trainer.__name__.lower(), dataset_name, global_config['num_workers']) 107 | print(global_config) 108 | 109 | train_config = utils.Config(global_config) 110 | 111 | for i in range(args.rep): 112 | exp = trainer(train_config, dataset) 113 | if i == 0: 114 | print(exp.model) 115 | 116 | print(f'Run #{i}...') 117 | exp.run() 118 | # import pdb; pdb.set_trace() 119 | if not args.nosave: 120 | exp.save() 121 | 122 | 123 | # CUDA_VISIBLE_DEVICES=0,1,2 python run-config.py --config test --mode d2gnn --np 4 --nosave 2>/dev/null -------------------------------------------------------------------------------- /dgnn/train/old/paravg.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import copy 4 | 5 | import torch 6 | import torch.multiprocessing as mp 7 | import torch.distributed as dist 8 | import torch_geometric.transforms as T 9 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 10 | 11 | from ..utils import Stats 12 | from ..models import model_selector 13 | from ..data import partition as P 14 | from ..utils import helpers as H 15 | 16 | from . import Distributed 17 | 18 | class ParamsAvg(Distributed): 19 | 20 | def __init__(self, config, dataset): 21 | super().__init__(config, dataset) 22 | 23 | self.dataset = dataset[0] 24 | self.dataset = T.GCNNorm()(self.dataset) 25 | 26 | def train(self, rank, *args, **kwargs): 27 | 28 | # Load the rank-th partition, to rank-th device both adj and features 29 | adj, features, labels, train_mask, val_mask, test_mask = P.load_partitions( 30 | self.dataset_dir, rank) 31 | 32 | # Init the local model, 33 | model = copy.deepcopy(self.model) 34 | # model.update_rank(rank) 35 | device = H.rank2dev(rank, self.num_gpus) 36 | 37 | print(device, flush=True) 38 | 39 | # renormalize this rank adjacency again 40 | adj = adj[rank] 41 | # adj.storage._value = None 42 | adj.set_value(None) 43 | adj = gcn_norm(adj) 44 | 45 | model = model.to(device) 46 | adj = adj.to(device) 47 | features = features.to(device) 48 | train_mask = train_mask.to(device) 49 | val_mask = val_mask.to(device) 50 | test_mask = test_mask.to(device) 51 | labels = labels.to(device) 52 | 53 | loss_fnc = torch.nn.CrossEntropyLoss() 54 | optimizer = torch.optim.Adam(model.parameters(), lr=self.config.lr) 55 | 56 | best_model = copy.deepcopy(model) 57 | best_val_score = 0 58 | 59 | if rank == 0: 60 | full_adj = self.dataset.adj_t.to(device) 61 | full_features = self.dataset.x.to(device) 62 | full_labels = self.dataset.y.to(device) 63 | full_val_mask = self.dataset.val_mask 64 | 65 | for epoch in range(self.config.num_epochs): 66 | # train 67 | model.train() 68 | optimizer.zero_grad() 69 | 70 | output = model(features, adj) 71 | loss = loss_fnc(output[train_mask], labels[train_mask]) 72 | 73 | loss.backward() 74 | optimizer.step() 75 | 76 | # Fed Average 77 | # print(model.state_dict()) 78 | # print(len(model.state_dict())) 79 | 80 | params = copy.deepcopy(model.state_dict()) 81 | 82 | for layer in params.keys(): 83 | dist.all_reduce(params[layer], op=dist.ReduceOp.SUM) 84 | # print('pa', params[layer].numel() * params[layer].element_size()) 85 | params[layer] = torch.div(params[layer], self.config.num_procs) 86 | 87 | model.load_state_dict(params) 88 | 89 
| # End of Epoch 90 | if rank == 0: 91 | self.stats.train_loss.append(loss.item()) 92 | 93 | model.eval() 94 | val_output = model(full_features, full_adj) 95 | val_pred = val_output[full_val_mask].argmax(dim=1) 96 | val_score = (val_pred.eq( 97 | full_labels[full_val_mask]).sum() / full_val_mask.sum()).item() 98 | self.stats.val_scores.append(val_score) 99 | 100 | print(f'Epoch #{epoch}:', 101 | f'train loss {loss.item():.3f}', 102 | f'val accuracy {val_score*100:.2f}%', 103 | flush=True) 104 | 105 | 106 | # Testing 107 | if rank == 0: 108 | print('End of training on rank 0, testing on full graph') 109 | # print(self.dataset) 110 | # adj = self.dataset.adj_t.to(device) 111 | # features = self.dataset.x.to(device) 112 | # labels = self.dataset.y.to(device) 113 | test_mask = self.dataset.test_mask 114 | 115 | model.eval() 116 | test_output = model(full_features, full_adj) 117 | test_pred = test_output[test_mask].argmax(dim=1) 118 | 119 | test_score = (test_pred.eq( 120 | full_labels[test_mask]).sum() / test_mask.sum()).item() 121 | 122 | print('Test accuracy is {:.2f}'.format(test_score*100)) 123 | 124 | 125 | -------------------------------------------------------------------------------- /dgnn/data/samplers/minibatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .. import partition as P 4 | class RandomBatchSampler(torch.utils.data.Sampler): 5 | 6 | def __init__(self, num_nodes, batch_size, shuffle, num_batches=1): 7 | self.num_nodes = num_nodes 8 | self.batch_size = batch_size 9 | self.shuffle = shuffle 10 | self.num_parts = self.num_nodes // self.batch_size 11 | self.num_batches = num_batches 12 | self.batched_nodes = self.split_nodes() 13 | # print(len(self.batched_nodes)) 14 | 15 | def split_nodes(self): 16 | # Randomly assigns every node to one of num_parts batches, so each node appears at least once; __iter__ needs the prepared list to iterate over 17 | 18 | nodes_id = torch.randint(self.num_parts, (self.num_nodes,), dtype=torch.long) 19 | split_ids = [(nodes_id == i).nonzero(as_tuple=False).view(-1) for i in range(self.num_parts)] 20 | return split_ids 21 | 22 | def select_nodes(self): 23 | # node_id = torch.randperm(self.num_nodes) 24 | # return node_id[:self.batch_size] 25 | nodes_id = torch.randint(self.num_nodes, (self.batch_size, ), dtype=torch.long) 26 | nodes_id, _ = torch.sort(nodes_id) 27 | return nodes_id 28 | 29 | def __iter__(self): 30 | # Generates the next minibatches; do shuffling here if necessary 31 | 32 | # print('In iter', self.num_batches) 33 | 34 | if self.num_batches < 1: 35 | self.batched_nodes = self.split_nodes() 36 | else: 37 | self.batched_nodes = [] 38 | for _ in range(self.num_batches): 39 | self.batched_nodes.append(self.select_nodes()) 40 | 41 | # print(self.batched_nodes) 42 | return iter(self.batched_nodes) 43 | 44 | def __len__(self): 45 | # Number of minibatches generated 46 | if self.num_batches < 1: 47 | return self.num_parts 48 | return self.num_batches 49 | 50 | def update_k(self, k): 51 | self.num_batches = k 52 | 53 | class StratifiedMiniBatch(torch.utils.data.Sampler): 54 | def __init__(self, num_nodes, batch_size, shuffle, num_batches, part_meta, *args): 55 | self.num_nodes = num_nodes 56 | self.batch_size = batch_size 57 | self.shuffle = shuffle 58 | self.num_parts = self.num_nodes // self.batch_size 59 | self.num_batches = num_batches 60 | 61 | # load the permutation data 62 | self.partition_idx = P.load_meta(part_meta) 63 | self.batched_nodes = None 64 | 65 | def split_nodes(self): 66 | 
sampled_nodes = [] 67 | num_partitions = len(self.partition_idx) 68 | nodes_per_part = self.batch_size // num_partitions 69 | for i in range(num_partitions): 70 | perm = torch.randperm(self.partition_idx[i].size(0)) 71 | samples = self.partition_idx[i][perm[:nodes_per_part]] 72 | sampled_nodes.append(samples) 73 | 74 | batch_nodes, _ = torch.sort(torch.cat(sampled_nodes)) 75 | return batch_nodes 76 | 77 | def __iter__(self): 78 | if self.num_batches < 1: 79 | self.batched_nodes = [self.split_nodes()] 80 | else: 81 | self.batched_nodes = [] 82 | for _ in range(self.num_batches): 83 | self.batched_nodes.append(self.split_nodes()) 84 | 85 | return iter(self.batched_nodes) 86 | 87 | def __len__(self): 88 | return len(self.batched_nodes) 89 | 90 | 91 | class DGLBatchSampler(torch.utils.data.Sampler): 92 | 93 | def __init__(self, nodes_idx, batch_size, shuffle, num_batches=1): 94 | self.node_idx = nodes_idx 95 | self.num_nodes = len(nodes_idx) 96 | self.batch_size = batch_size 97 | self.shuffle = shuffle 98 | self.num_parts = self.num_nodes // self.batch_size 99 | self.num_batches = num_batches 100 | self.batched_nodes = [] 101 | 102 | def select_nodes(self): 103 | nodes_id = torch.randint(self.num_nodes, (self.batch_size, ), dtype=torch.long) 104 | nodes_id, _ = torch.sort(self.node_idx[nodes_id]) 105 | return nodes_id 106 | 107 | def __iter__(self): 108 | self.batched_nodes = [] 109 | for _ in range(self.num_batches): 110 | self.batched_nodes.append(self.select_nodes()) 111 | 112 | # print(self.batched_nodes) 113 | return iter(self.batched_nodes) 114 | 115 | def __len__(self): 116 | # Number of minibatches generated 117 | if self.num_batches < 1: 118 | return self.num_parts 119 | return self.num_batches -------------------------------------------------------------------------------- /dgnn/train/dist/dgl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import time 4 | import torch 5 | import torch.multiprocessing as mp 6 | import torch.distributed as dist 7 | 8 | import numpy as np 9 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 10 | 11 | from multiprocessing import Value 12 | from ctypes import c_bool 13 | from tqdm import trange 14 | 15 | from ..base import Base 16 | from .distgnn import DistGNN 17 | from ...data import samplers, partition, Dataset 18 | from ...utils import helpers as H 19 | from ...data.transforms import row_norm 20 | 21 | 22 | class DistDGL(DistGNN): 23 | 24 | def __init__(self, config, dataset): 25 | super().__init__(config, dataset) 26 | 27 | # self.raw_adj = self.dataset[0].adj_t 28 | 29 | 30 | # worker training 31 | @staticmethod 32 | def workers(rank, params_queue, ready_flag, config, dataset_rawdir, 33 | global_model, loss_fnc, meta_queue, end_train, comm_cost): 34 | 35 | ready_flag = ready_flag[rank] 36 | 37 | dataset_dir = os.path.join(dataset_rawdir[:-3], config.partitioned_dir) 38 | dataset_processed = os.path.join(dataset_rawdir[:-3], config.processed_filename) 39 | 40 | # part_idx for metis is loaded separately 41 | perm, part_ptr = torch.load(dataset_dir+'/perm.pt') 42 | start = rank 43 | end = rank + 1 44 | part_idx = perm[part_ptr[start]:part_ptr[end]] 45 | 46 | # load the full adjacency, features, labels and masks 47 | dataset = torch.load(dataset_processed) 48 | 49 | adj = dataset[0].adj_t 50 | if type(adj) == list: 51 | adj = adj[0] 52 | 53 | feat = dataset[0].x 54 | lab = dataset[0].y 55 | full_train_mask = dataset[0].train_mask 56 | 57 | tr_mask = full_train_mask[part_idx] 58 | 59 | # if 
rank == 0: 60 | # print(dataset_dir, dataset_processed) 61 | # print(adj, feat, lab, part_idx.shape, tr_mask.shape) 62 | 63 | meta_queue.put((rank, tr_mask.count_nonzero())) 64 | 65 | device = H.rank2dev(rank, config.num_gpus) 66 | 67 | if config.sampler == 'neighbor': 68 | train_loader = samplers.NeighborSampler(adj, 69 | config.minibatch_size, 70 | num_workers=config.num_samplers, 71 | num_batches=config.local_updates, 72 | num_layers=config.num_layers, 73 | num_neighbors=config.num_neighbors, 74 | minibatch='dglsim', 75 | node_idx=part_idx, 76 | persistent_workers=True, 77 | ) 78 | else: 79 | raise NotImplementedError 80 | 81 | model = copy.deepcopy(global_model).to(device) 82 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 83 | 84 | model_size = 0 85 | params = model.state_dict() 86 | for k in params: 87 | model_size += params[k].element_size() * params[k].nelement() 88 | 89 | feat_size = feat[0].element_size() * feat[0].nelement() 90 | 91 | 92 | for epoch in range(config.num_epochs): 93 | 94 | feat_cost = 0 95 | 96 | if end_train.value: 97 | break 98 | 99 | if epoch > 0: 100 | # Sync with Param Server 101 | model.load_state_dict(global_model.state_dict()) 102 | 103 | model.train() 104 | ready_flag.clear() 105 | 106 | # Train Locally for K iterations 107 | for input_nid, nodeblocks, output_nid in train_loader: 108 | nodeblocks.to(device) 109 | features = feat[input_nid].to(device) 110 | labels = lab[output_nid].to(device) 111 | train_mask = full_train_mask[output_nid].to(device) 112 | 113 | # compute remote nodes cost 114 | diff = input_nid[~input_nid.unsqueeze(1).eq(part_idx).any(1)] 115 | 116 | # if rank == 0: 117 | # print(part_idx) 118 | # print(input_nid) 119 | # print(diff) 120 | 121 | if diff.shape[0] > 0: 122 | feat_cost += diff.shape[0] * feat_size 123 | 124 | optimizer.zero_grad() 125 | output = model(features, nodeblocks) 126 | loss = loss_fnc(output[train_mask], labels[train_mask]) 127 | loss.backward() 128 | optimizer.step() 129 | 130 | # Move to CPU and put on the Queue 131 | params_dict = {} 132 | tmp_params = model.state_dict() 133 | for key in tmp_params: 134 | params_dict[key] = tmp_params[key].clone().cpu() 135 | 136 | params_queue.put(params_dict) 137 | 138 | # Cost of communication 139 | comm_cost[epoch] = model_size + feat_cost 140 | 141 | # Wait for server to continue with new global_model 142 | ready_flag.wait() 143 | 144 | -------------------------------------------------------------------------------- /dgnn/train/full.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import time 4 | 5 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 6 | import numpy as np 7 | 8 | from .base import Base 9 | from ..data.transforms import row_norm 10 | 11 | from sklearn.preprocessing import StandardScaler 12 | 13 | class Full(Base): 14 | 15 | def __init__(self, config, dataset): 16 | super().__init__(config, dataset) 17 | 18 | # check if full inference is set 19 | self.full_adj = self.dataset[0].adj_t 20 | 21 | # if self.dataset.name.startswith('ogbn') and self.dataset.name != 'ogbn-proteins': 22 | # self.full_adj = self.full_adj.to_symmetric() 23 | 24 | self.full_adj = gcn_norm(self.full_adj) 25 | # self.full_adj = row_norm(self.full_adj) 26 | 27 | self.full_device = self.val_device if self.config.cpu_val else self.device 28 | 29 | 30 | # import pdb; pdb.set_trace() 31 | 32 | self.full_adj = self.full_adj.to(self.full_device) 33 | # self.full_features = 
self.dataset[0].x.to(self.full_device) 34 | self.full_labels = self.dataset[0].y.to(self.full_device) 35 | 36 | train_feats = self.dataset[0].x[self.dataset[0].train_mask] 37 | scaler = StandardScaler() 38 | scaler.fit(train_feats) 39 | self.full_features = torch.FloatTensor(scaler.transform(self.dataset[0].x)).to(self.full_device) 40 | # import pdb; pdb.set_trace() 41 | 42 | self.full_train_mask = self.dataset[0].train_mask #.to(self.device) 43 | self.full_val_mask = self.dataset[0].val_mask 44 | self.full_test_mask = self.dataset[0].test_mask 45 | 46 | def train(self, epoch): 47 | 48 | self.model.train() 49 | 50 | adj = self.full_adj 51 | features = self.full_features 52 | labels = self.full_labels 53 | train_mask = self.full_train_mask 54 | 55 | self.optimizer.zero_grad() 56 | output = self.model(features, adj) 57 | loss = self.loss_fnc(output[train_mask], labels[train_mask]) 58 | 59 | loss.backward() 60 | self.optimizer.step() 61 | 62 | self.stats.train_loss.append(loss.item()) 63 | # train_score = self.calc_score(output[train_mask], labels[train_mask]) 64 | # self.stats.train_scores.append(train_score) 65 | 66 | self.train_output = output 67 | # self.train_loss = loss.item() 68 | 69 | # print(end_time-start_time) 70 | 71 | @torch.no_grad() 72 | def validation(self, epoch): 73 | 74 | if epoch > 0 and epoch % self.config.val_step != 0: 75 | return True 76 | 77 | if self.config.cpu_val: 78 | model = copy.deepcopy(self.model).cpu() 79 | else: 80 | model = self.model 81 | 82 | model.eval() 83 | 84 | if self.train_output is None or self.config.dropout > 0: 85 | val_output = model(self.full_features, self.full_adj) 86 | else: 87 | val_output = self.train_output 88 | 89 | val_loss = self.loss_fnc(val_output[self.full_val_mask], self.full_labels[self.full_val_mask]) 90 | val_score = self.calc_score(val_output[self.full_val_mask], self.full_labels[self.full_val_mask]) 91 | 92 | if val_score > self.stats.best_val_score: 93 | self.stats.best_val_epoch = epoch 94 | self.stats.best_val_loss = val_loss.item() 95 | self.stats.best_val_score = val_score 96 | self.stats.best_model = copy.deepcopy(model) 97 | 98 | self.stats.val_loss.append(val_loss.item()) 99 | self.stats.val_scores.append(val_score) 100 | 101 | # If train doesn't provide these (useful for other classes) 102 | if len(self.stats.train_loss) < epoch+1: 103 | train_loss = self.loss_fnc(val_output[self.full_train_mask], self.full_labels[self.full_train_mask]) 104 | self.stats.train_loss.append(train_loss.item()) 105 | 106 | if len(self.stats.train_scores) < epoch+1: 107 | train_score = self.calc_score(val_output[self.full_train_mask], self.full_labels[self.full_train_mask]) 108 | self.stats.train_scores.append(train_score) 109 | 110 | return self.patience() 111 | 112 | # if len(self.stats.val_scores) > self.config.val_patience and \ 113 | # np.max(self.stats.val_scores[-1*self.config.val_patience:]) < self.stats.best_val_score: 114 | # print('Run out of patience!') 115 | # return False 116 | 117 | # return True 118 | 119 | # test_score = self.calc_f1(val_output[self.full_test_mask], self.full_labels[self.full_test_mask]) 120 | # print(f'#{epoch} ' 121 | # f'Loss: {self.stats.train_loss[-1]:.3f}, ' 122 | # # f'Train Score: {self.stats.train_scores[-1]*100:.2f}, ' 123 | # f'Val Score: {val_score*100:.2f}, ' 124 | # f'Test Score: {test_score*100:.2f}' 125 | # ) 126 | 127 | @torch.no_grad() 128 | def inference(self): 129 | self.stats.best_model.eval() 130 | test_preds = self.stats.best_model(self.full_features, 
self.full_adj)[self.full_test_mask] 131 | test_labels = self.full_labels[self.full_test_mask] 132 | test_score = self.calc_score(test_preds, test_labels) 133 | 134 | self.stats.test_score = test_score 135 | -------------------------------------------------------------------------------- /dgnn/train/serial/distgnn_full.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import math 4 | import numpy as np 5 | import torch 6 | 7 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 8 | 9 | from ..base import Base 10 | from ..full import Full 11 | from ...data.transforms import row_norm 12 | from ...data import samplers, partition 13 | 14 | class DistGNNFull(Full, Base): 15 | 16 | def __init__(self, config, dataset): 17 | 18 | # if full inference this, else Base init 19 | Full.__init__(self, config, dataset) 20 | 21 | self.dataset_dir = os.path.join(dataset.raw_dir[:-3], self.config.partitioned_dir) 22 | 23 | self.clients_adj = [] 24 | self.clients_features = [] 25 | self.clients_labels = [] 26 | self.clients_train_mask = [] 27 | self.clients_train_sizes = [] 28 | 29 | self.clients_model = [] 30 | self.clients_optimizer = [] 31 | self.clients_grads = [] 32 | 33 | for rank in range(self.config.world_size): 34 | # load partitions 35 | tmp_adj, tmp_feat, tmp_lab, tmp_tr, _, _ = partition.load_partitions(self.dataset_dir, rank) 36 | 37 | if self.config.part_method == 'metis': 38 | # import pdb; pdb.set_trace() 39 | tmp_adj = tmp_adj[rank] 40 | 41 | if self.dataset.name.startswith('ogbn') and self.dataset.name != 'ogbn-proteins': 42 | tmp_adj = tmp_adj.to_symmetric() 43 | 44 | self.clients_adj.append(tmp_adj) 45 | self.clients_features.append(tmp_feat) 46 | self.clients_labels.append(tmp_lab) 47 | self.clients_train_mask.append(tmp_tr) 48 | self.clients_train_sizes.append(tmp_tr.count_nonzero()) 49 | 50 | # model and grads... 
51 | tmp_model = copy.deepcopy(self.model).to(self.device) 52 | tmp_opt = torch.optim.Adam(tmp_model.parameters(), lr=self.config.lr) 53 | tmp_grads = [] 54 | 55 | for p in tmp_model.parameters(): 56 | tmp_grads.append(torch.zeros_like(p)) 57 | 58 | self.clients_model.append(tmp_model) 59 | self.clients_optimizer.append(tmp_opt) 60 | self.clients_grads.append(tmp_grads) 61 | 62 | self.total_train_size = np.sum(self.clients_train_sizes) 63 | 64 | def train(self, epoch): 65 | self.model.train() 66 | for rank in range(self.config.world_size): 67 | self.local_train(rank, epoch) 68 | self.server_average() 69 | 70 | def local_train(self, rank, epoch): 71 | 72 | if epoch == 0: # one-time setup: normalize this partition's adjacency and move it to the device 73 | print('Local Train', rank) 74 | self.clients_adj[rank] = gcn_norm(self.clients_adj[rank]).to(self.device) 75 | # self.clients_adj[rank] = row_norm(self.clients_adj[rank]).to(self.device) 76 | self.clients_features[rank] = self.clients_features[rank].to(self.device) 77 | self.clients_labels[rank] = self.clients_labels[rank].to(self.device) 78 | 79 | 80 | 81 | self.client_sync(rank) 82 | 83 | self.clients_model[rank].train() 84 | 85 | adj = self.clients_adj[rank] 86 | features = self.clients_features[rank] 87 | labels = self.clients_labels[rank] 88 | train_mask = self.clients_train_mask[rank] 89 | 90 | for inner_epoch in range(self.config.local_updates): 91 | 92 | self.clients_optimizer[rank].zero_grad() 93 | 94 | output = self.clients_model[rank](features, adj) 95 | loss = self.loss_fnc(output[train_mask], labels[train_mask]) 96 | 97 | loss.backward() 98 | self.clients_optimizer[rank].step() 99 | 100 | if not self.config.weight_avg: 101 | for i, cp in enumerate(self.clients_model[rank].parameters()): 102 | self.clients_grads[rank][i] += cp.grad 103 | 104 | def client_sync(self, rank): 105 | # Pull the latest global weights from the server and reset the gradient accumulators 106 | self.clients_model[rank].load_state_dict(self.model.state_dict()) 107 | 108 | if not self.config.weight_avg: 109 | for i, cp in enumerate(self.clients_model[rank].parameters()): 110 | self.clients_grads[rank][i] = torch.zeros_like(cp) 111 | 112 | def server_average(self): 113 | 114 | if self.config.weight_avg: # FedAvg-style parameter average, weighted by training-set size 115 | params = self.model.state_dict() 116 | for k in params.keys(): 117 | params[k] = torch.zeros_like(params[k], dtype=torch.float) 118 | 119 | for rank in range(self.config.world_size): 120 | for k in params: 121 | params[k] += torch.div(self.clients_model[rank].state_dict()[k] * 122 | self.clients_train_sizes[rank], self.total_train_size) 123 | 124 | self.model.load_state_dict(params) 125 | 126 | else: # aggregate the accumulated client gradients, weighted the same way 127 | print('Grad Agg') 128 | server_model = self.model 129 | for sp, cp in zip(server_model.parameters(), self.clients_model[0].parameters()): 130 | sp.grad = torch.zeros_like(cp.grad, dtype=torch.float) 131 | 132 | for rank in range(self.config.world_size): 133 | for i, sp in enumerate(server_model.parameters()): 134 | sp.grad += torch.div(self.clients_grads[rank][i], 135 | self.total_train_size/self.clients_train_sizes[rank]) 136 | 137 | self.optimizer.step() -------------------------------------------------------------------------------- /dgnn/train/old/base.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import copy 4 | 5 | import torch 6 | import torch.multiprocessing as mp 7 | import torch.distributed as dist 8 | import torch_geometric.transforms as T 9 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 10 | 11 | from ...utils import Stats 12 | from ...models import model_selector 13 | from ...layers import
layer_selector 14 | from ...data import partition as P 15 | from ...utils import helpers as H 16 | 17 | 18 | class Base(object): 19 | """Base class for training a GNN on a single GPU (single process). 20 | 21 | Arguments: 22 | config -- run configuration; dataset -- a PyTorch Geometric dataset 23 | """ 24 | 25 | def __init__(self, 26 | config, 27 | dataset, 28 | ): 29 | 30 | 31 | self.config = config 32 | self.stats = Stats(config) 33 | self.dataset = dataset[0] 34 | self.dataset = T.GCNNorm()(self.dataset) 35 | 36 | # Set device 37 | if self.config.gpu >= 0: 38 | self.device = self.config.gpu 39 | else: 40 | self.device = 'cpu' 41 | 42 | model = model_selector(self.config.model) 43 | layer = layer_selector(self.config.layer) 44 | 45 | self.activation = torch.nn.ReLU(True) 46 | 47 | self.model = model( 48 | dataset.num_features, 49 | self.config.hidden_size, 50 | dataset.num_classes, 51 | self.config.num_layers, 52 | self.activation, 53 | layer=layer 54 | ) 55 | 56 | self.model = self.model.to(self.device) 57 | 58 | self.loss_fnc = torch.nn.CrossEntropyLoss() 59 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr) 60 | # self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.config.lr, momentum=0) 61 | 62 | 63 | 64 | def start(self): 65 | full_adj = self.dataset.adj_t.to(self.device) 66 | full_features = self.dataset.x.to(self.device) 67 | full_labels = self.dataset.y.to(self.device) 68 | 69 | full_train_mask, val_mask, test_mask = self.dataset.train_mask, self.dataset.val_mask, self.dataset.test_mask 70 | 71 | for epoch in range(self.config.num_epochs): 72 | 73 | if self.config.use_sampling: # train on the induced subgraph of a uniform random node batch 74 | sampled_nodes = torch.randint(0, self.dataset.num_nodes, (self.config.minibatch_size, ), dtype=torch.long) 75 | sampled_nodes, _ = torch.sort(sampled_nodes) 76 | sampled_adj, _ = self.dataset.adj_t.saint_subgraph(sampled_nodes) 77 | sampled_feat = self.dataset.x[sampled_nodes] 78 | sampled_label = self.dataset.y[sampled_nodes] 79 | sampled_train_mask = self.dataset.train_mask[sampled_nodes] 80 | 81 | sampled_adj = sampled_adj.set_value(None) 82 | 83 | sampled_adj = gcn_norm(sampled_adj) 84 | adj = sampled_adj.to(self.device) 85 | features = sampled_feat.to(self.device) 86 | labels = sampled_label.to(self.device) 87 | train_mask = sampled_train_mask 88 | 89 | else: 90 | adj = full_adj 91 | features = full_features 92 | labels = full_labels 93 | train_mask = full_train_mask 94 | 95 | self.model.train() 96 | self.optimizer.zero_grad() 97 | 98 | output = self.model(features, adj) 99 | loss = self.loss_fnc(output[train_mask], labels[train_mask]) 100 | 101 | loss.backward() 102 | self.optimizer.step() 103 | 104 | # Validation 105 | if not self.config.use_sampling: 106 | val_loss = self.loss_fnc(output[val_mask], labels[val_mask]) 107 | val_pred = output[val_mask].detach().argmax(dim=1) 108 | val_score = (val_pred.eq( 109 | labels[val_mask]).sum() / val_mask.sum()).item() 110 | else: 111 | self.model.eval() 112 | val_output = self.model(full_features, full_adj) 113 | val_loss = self.loss_fnc(val_output[val_mask], full_labels[val_mask]) 114 | val_pred = val_output[val_mask].detach().argmax(dim=1) 115 | val_score = (val_pred.eq( 116 | full_labels[val_mask]).sum() / val_mask.sum()).item() 117 | 118 | self.stats.val_scores.append(val_score) 119 | if val_score > self.stats.best_val_score: 120 | self.stats.best_val_epoch = epoch 121 | self.stats.best_val_score = val_score 122 | self.stats.best_model = copy.deepcopy(self.model) 123 | 124 | print(f'Epoch #{epoch}, train loss {loss:.2f} and val
score {val_score*100:.2f}, val loss: {val_loss:.4f}') 125 | 126 | 127 | # Testing with the best validation model 128 | self.stats.best_model.eval() 129 | test_output = self.stats.best_model(full_features, full_adj) 130 | test_pred = test_output[test_mask].argmax(dim=1) 131 | 132 | test_score = (test_pred.eq( 133 | full_labels[test_mask]).sum() / test_mask.sum()).item() 134 | 135 | self.stats.test_score = test_score 136 | 137 | print('Test accuracy is {:.2f}'.format(test_score*100)) 138 | 139 | def validation(self): 140 | pass 141 | 142 | def inference(self): 143 | pass 144 | 145 | def save(self, *args, **kwargs): 146 | self.stats.save() 147 | -------------------------------------------------------------------------------- /dgnn/train/old/serial_fullserver.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import copy 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.multiprocessing as mp 8 | import torch.distributed as dist 9 | import torch_geometric.transforms as T 10 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 11 | import numpy as np 12 | 13 | from ...utils import Stats 14 | from ...models import model_selector 15 | from ...data import partition as P 16 | from ...utils import helpers as H 17 | 18 | from . import SerializedParamsAvg 19 | 20 | class SerializedFullServer(SerializedParamsAvg): 21 | """ 22 | This class is for testing purposes only. 23 | """ 24 | 25 | def __init__(self, config, dataset): 26 | super().__init__(config, dataset) 27 | 28 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr) 29 | self.model = self.model.to(self.device) 30 | 31 | def train(self, rank, *args, **kwargs): 32 | 33 | loss_fnc = torch.nn.CrossEntropyLoss() 34 | 35 | adjs = [] 36 | features = [] 37 | labels = [] 38 | train_masks = [] 39 | train_sizes = [] 40 | 41 | client_models = [] 42 | client_optimizers = [] 43 | 44 | device = self.device 45 | 46 | # Load all partitions and build one client replica per partition 47 | for rank in range(self.world_size): 48 | 49 | adj, feat, lab, tr, _, _ = P.load_partitions(self.dataset_dir, rank) 50 | 51 | adj = adj[rank] 52 | adj = adj.set_value(None) 53 | adj = gcn_norm(adj) 54 | adj = adj.to(device) 55 | 56 | feat = feat.to(device) 57 | lab = lab.to(device) 58 | 59 | model = copy.deepcopy(self.model) 60 | model = model.to(device) 61 | 62 | optimizer = torch.optim.Adam(model.parameters(), lr=self.config.lr) 63 | 64 | adjs.append(adj) 65 | features.append(feat) 66 | labels.append(lab) 67 | train_masks.append(tr) 68 | train_sizes.append(tr.count_nonzero()) 69 | 70 | client_models.append(model) 71 | client_optimizers.append(optimizer) 72 | 73 | 74 | # params = OrderedDict() 75 | # for layer in self.model.state_dict().keys(): 76 | # params[layer] = None 77 | 78 | self.model = self.model.to(device) 79 | params = self.model.state_dict() 80 | total_train_size = np.sum(train_sizes) 81 | server_optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr) 82 | 83 | for epoch in range(self.config.num_epochs): 84 | # for rank in range(self.world_size): 85 | 86 | #!
Communication, load from the server 87 | # if epoch > 0: 88 | # if self.config.sync_local or epoch % self.config.local_updates == 0: 89 | # # if rank == 0: 90 | # # print('Sync clients...') 91 | # client_models[rank].load_state_dict(params) 92 | 93 | # # train on clients 94 | # client_models[rank].train() 95 | # client_optimizers[rank].zero_grad() 96 | 97 | # output = client_models[rank](features[rank], adjs[rank]) 98 | # loss = loss_fnc(output[train_masks[rank]], labels[rank][train_masks[rank]]) 99 | 100 | # loss.backward() 101 | # # client_optimizers[rank].step() 102 | 103 | 104 | # # get the fedavg params on servers 105 | # for layer in params.keys(): 106 | # tmp = [] 107 | # for rank in range(self.world_size): 108 | # tmp.append(client_models[rank].state_dict()[layer]) 109 | # params[layer] = torch.div(torch.stack(tmp, dim=0).sum(dim=0), self.world_size) 110 | 111 | server_model = self.model 112 | # for sp, cp in zip(server_model.parameters(), client_models[0].parameters()): 113 | # sp.grad = torch.div(cp.grad , total_train_size/train_sizes[0]) 114 | 115 | # for rank in range(1, self.world_size): 116 | # for sp, cp in zip(server_model.parameters(), client_models[rank].parameters()): 117 | # sp.grad += torch.div(cp.grad , total_train_size/train_sizes[rank]) 118 | # # server_optimizer = torch.optim.Adam(server_model.parameters(), lr=self.config.lr) 119 | # server_optimizer.step() 120 | params = server_model.state_dict() 121 | 122 | if epoch != 0 and epoch % self.config.local_updates == 0 or epoch == self.config.num_epochs - 1: 123 | 124 | print(f'Doing server pass for {self.config.server_epochs} epoch...') 125 | for server_epoch in range(self.config.server_epochs): 126 | 127 | self.model.load_state_dict(params) 128 | 129 | self.model.train() 130 | self.optimizer.zero_grad() 131 | server_output = self.model(self.full_features, self.full_adj) 132 | server_loss = loss_fnc(server_output[self.full_train_mask], self.full_labels[self.full_train_mask]) 133 | server_loss.backward() 134 | self.optimizer.step() 135 | 136 | params = self.model.state_dict() 137 | 138 | 139 | val_score, val_loss = self.validation(params, epoch) 140 | print(f'Training Epoch #{epoch}, val score {val_score*100:.2f}, val loss {val_loss:.4f}') 141 | 142 | 143 | test_score = self.inference(params) 144 | print(f'Test accuracy is {test_score*100:.2f} at epoch {self.stats.best_val_epoch}') 145 | 146 | -------------------------------------------------------------------------------- /dgnn/train/old/serial_fedavg.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import copy 4 | import math 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.multiprocessing as mp 9 | import torch.distributed as dist 10 | import torch_geometric.transforms as T 11 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 12 | import numpy as np 13 | 14 | from ...utils import Stats 15 | from ...models import model_selector 16 | from ...data import partition as P 17 | from ...utils import helpers as H 18 | 19 | from . import SerializedParamsAvg 20 | 21 | class SerializedFedAvg(SerializedParamsAvg): 22 | """ 23 | This class is only for testing purpose. 
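It runs FedAvg serially in one process: each client takes config.local_updates local steps on its partition (METIS or fixed contiguous split) while accumulating gradients, and the server then applies the train-size-weighted sum of those gradients with its own optimizer.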
24 | """ 25 | 26 | def __init__(self, config, dataset): 27 | super().__init__(config, dataset) 28 | 29 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr) 30 | self.model = self.model.to(self.device) 31 | 32 | self.parted = False 33 | 34 | def load_fixed_part(self, rank): 35 | 36 | if not self.parted: 37 | self.num_nodes = self.dataset.num_nodes 38 | self.node_per_part = math.ceil(self.num_nodes / self.config.num_procs) 39 | self.part_nodes = torch.split(torch.arange(self.num_nodes), self.node_per_part) 40 | 41 | self.part_ptr = [self.part_nodes[0][0]] 42 | for part in self.part_nodes: 43 | self.part_ptr.append(part[-1]) 44 | 45 | self.parted = True 46 | 47 | 48 | start = self.part_ptr[rank] 49 | if start > 0: 50 | start += 1 51 | end = self.part_ptr[rank+1] 52 | adj = self.dataset.adj_t.narrow(0, start, end-start+1).narrow(1, start, end-start+1) 53 | 54 | 55 | part_feats = self.dataset.x[start:end+1] 56 | part_labels = self.dataset.y[start:end+1] 57 | part_train_mask = self.dataset.train_mask[start:end+1] 58 | part_val_mask = self.dataset.val_mask[start:end+1] 59 | part_test_mask = self.dataset.test_mask[start:end+1] 60 | 61 | return adj, part_feats, part_labels, part_train_mask, part_val_mask, part_test_mask 62 | 63 | def train(self, _, *args, **kwargs): 64 | 65 | 66 | adjs = [] 67 | features = [] 68 | labels = [] 69 | train_masks = [] 70 | train_sizes = [] 71 | 72 | client_models = [] 73 | client_optimizers = [] 74 | client_grads = [] 75 | 76 | device = self.device 77 | 78 | # Load all partitions, 79 | for rank in range(self.world_size): 80 | 81 | # device = H.rank2dev(rank, self.num_gpus) 82 | if self.config.part_method == 'metis': 83 | adj, feat, lab, tr, va, te = P.load_partitions(self.dataset_dir, rank) 84 | adj = adj[rank] 85 | else: 86 | adj, feat, lab, tr, va, te = self.load_fixed_part(rank) 87 | 88 | adj = adj.set_value(None) 89 | adj = gcn_norm(adj) 90 | adj = adj.to(device) 91 | 92 | feat = feat.to(device) 93 | lab = lab.to(device) 94 | 95 | model = copy.deepcopy(self.model) 96 | model = model.to(device) 97 | 98 | optimizer = torch.optim.Adam(model.parameters(), lr=self.config.lr) 99 | # optimizer = torch.optim.SGD(model.parameters(), lr=self.config.lr, momentum=0) 100 | 101 | adjs.append(adj) 102 | features.append(feat) 103 | labels.append(lab) 104 | train_masks.append(tr) 105 | train_sizes.append(tr.count_nonzero()) 106 | 107 | client_models.append(model) 108 | client_optimizers.append(optimizer) 109 | 110 | tmp_cgrad = [] 111 | for p in model.parameters(): 112 | tmp_cgrad.append(torch.zeros_like(p)) 113 | client_grads.append(tmp_cgrad) 114 | 115 | # import pdb; pdb.set_trace() 116 | 117 | total_train_size = np.sum(train_sizes) 118 | 119 | self.model = self.model.to(device) 120 | params = self.model.state_dict() 121 | server_optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr) 122 | 123 | for epoch in range(self.config.num_epochs): 124 | 125 | for inner_epoch in range(self.config.local_updates): 126 | for rank in range(self.world_size): 127 | #! 
Communication, load from the server 128 | if inner_epoch == 0: 129 | if rank == 0: 130 | print('Sync clients...') 131 | client_models[rank].load_state_dict(params) 132 | 133 | for i, cp in enumerate(client_models[rank].parameters()): 134 | client_grads[rank][i] = torch.zeros_like(cp) 135 | 136 | # train on clients 137 | client_models[rank].train() 138 | client_optimizers[rank].zero_grad() 139 | 140 | output = client_models[rank](features[rank], adjs[rank]) 141 | loss = self.loss_fnc(output[train_masks[rank]], labels[rank][train_masks[rank]]) 142 | 143 | loss.backward() 144 | client_optimizers[rank].step() 145 | 146 | for i, cp in enumerate(client_models[rank].parameters()): 147 | client_grads[rank][i] += cp.grad 148 | 149 | 150 | # Server update: apply the train-size-weighted sum of the accumulated client gradients 151 | server_model = self.model 152 | for sp, cp in zip(server_model.parameters(), client_models[0].parameters()): 153 | sp.grad = torch.zeros_like(cp.grad) 154 | 155 | for rank in range(0, self.world_size): 156 | for i, sp in enumerate(server_model.parameters()): 157 | sp.grad += torch.div(client_grads[rank][i], total_train_size/train_sizes[rank]) 158 | 159 | server_optimizer.step() 160 | params = server_model.state_dict() 161 | 162 | val_score, val_loss = self.validation(params, epoch) 163 | print(f'Training Epoch #{epoch}, val score {val_score*100:.2f}, val loss {val_loss:.4f}') 164 | 165 | 166 | test_score = self.inference(params) 167 | print(f'Test accuracy is {test_score*100:.2f} at epoch {self.stats.best_val_epoch}') 168 | -------------------------------------------------------------------------------- /dgnn/train/old/serial_sampleserver.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import copy 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.multiprocessing as mp 8 | import torch.distributed as dist 9 | import torch_geometric.transforms as T 10 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 11 | import numpy as np 12 | 13 | from ...utils import Stats 14 | from ...models import model_selector 15 | from ...data import partition as P 16 | from ...utils import helpers as H 17 | 18 | from . import SerializedParamsAvg 19 | 20 | class SerializedSampleServer(SerializedParamsAvg): 21 | """ 22 | This class is for testing purposes only.
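Each epoch, every client computes local gradients which the server aggregates weighted by partition training size; every config.local_updates epochs (and at the final epoch) the server additionally trains for config.server_epochs passes on the induced subgraph of a random node sample.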
23 | """ 24 | 25 | def __init__(self, config, dataset): 26 | super().__init__(config, dataset) 27 | 28 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr) 29 | self.model = self.model.to(self.device) 30 | 31 | def train(self, rank, *args, **kwargs): 32 | 33 | loss_fnc = torch.nn.CrossEntropyLoss() 34 | 35 | adjs = [] 36 | features = [] 37 | labels = [] 38 | train_masks = [] 39 | train_sizes = [] 40 | 41 | client_models = [] 42 | client_optimizers = [] 43 | 44 | device = self.device 45 | 46 | # Load all partitions, 47 | for rank in range(self.world_size): 48 | 49 | adj, feat, lab, tr, va, te = P.load_partitions(self.dataset_dir, rank) 50 | 51 | adj = adj[rank] 52 | adj = adj.set_value(None) 53 | adj = gcn_norm(adj) 54 | adj = adj.to(device) 55 | 56 | feat = feat.to(device) 57 | lab = lab.to(device) 58 | 59 | model = copy.deepcopy(self.model) 60 | model = model.to(device) 61 | 62 | optimizer = torch.optim.Adam(model.parameters(), lr=self.config.lr) 63 | 64 | adjs.append(adj) 65 | features.append(feat) 66 | labels.append(lab) 67 | train_masks.append(tr) 68 | train_sizes.append(tr.count_nonzero()) 69 | 70 | client_models.append(model) 71 | client_optimizers.append(optimizer) 72 | 73 | 74 | # params = OrderedDict() 75 | # for layer in self.model.state_dict().keys(): 76 | # params[layer] = None 77 | 78 | self.model = self.model.to(device) 79 | params = self.model.state_dict() 80 | total_train_size = np.sum(train_sizes) 81 | server_optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr) 82 | 83 | for epoch in range(self.config.num_epochs): 84 | for rank in range(self.world_size): 85 | #! Communication, load from the server 86 | if epoch > 0: 87 | if self.config.sync_local or epoch % self.config.local_updates == 0: 88 | if rank == 0: 89 | print('Sync clients...') 90 | client_models[rank].load_state_dict(params) 91 | # client_optimizers[rank] = torch.optim.Adam(client_models[rank].parameters(), lr=self.config.lr) 92 | 93 | # train on clients 94 | client_models[rank].train() 95 | client_optimizers[rank].zero_grad() 96 | 97 | output = client_models[rank](features[rank], adjs[rank]) 98 | loss = loss_fnc(output[train_masks[rank]], labels[rank][train_masks[rank]]) 99 | 100 | loss.backward() 101 | # client_optimizers[rank].step() 102 | 103 | # # get the fedavg params on servers 104 | # for layer in params.keys(): 105 | # tmp = [] 106 | # for rank in range(self.world_size): 107 | # tmp.append(client_models[rank].state_dict()[layer]) 108 | # params[layer] = torch.div(torch.stack(tmp, dim=0).sum(dim=0), self.world_size) 109 | 110 | ## Attempt 3 111 | server_model = self.model 112 | for sp, cp in zip(server_model.parameters(), client_models[0].parameters()): 113 | sp.grad = torch.div(cp.grad , total_train_size/train_sizes[0]) 114 | 115 | for rank in range(1, self.world_size): 116 | for sp, cp in zip(server_model.parameters(), client_models[rank].parameters()): 117 | sp.grad += torch.div(cp.grad , total_train_size/train_sizes[rank]) 118 | # server_optimizer = torch.optim.Adam(server_model.parameters(), lr=self.config.lr) 119 | server_optimizer.step() 120 | params = server_model.state_dict() 121 | 122 | 123 | if epoch != 0 and (epoch % self.config.local_updates == 0 or epoch == self.config.num_epochs - 1): 124 | 125 | print(f'Doing server pass for {self.config.server_epochs} epoch...') 126 | for server_epoch in range(self.config.server_epochs): 127 | 128 | sampled_nodes = torch.randint(0, self.dataset.num_nodes, (self.config.minibatch_size, ), dtype=torch.long) 129 
| # Temp Fix for subgraph bug 130 | sampled_nodes, _ = torch.sort(sampled_nodes) 131 | sampled_adj, _ = self.dataset.adj_t.saint_subgraph(sampled_nodes) 132 | sampled_feat = self.dataset.x[sampled_nodes] 133 | sampled_label = self.dataset.y[sampled_nodes] 134 | sampled_train_mask = self.dataset.train_mask[sampled_nodes] 135 | 136 | sampled_adj = sampled_adj.set_value(None) 137 | 138 | sampled_adj = gcn_norm(sampled_adj) 139 | sampled_adj = sampled_adj.to(device) 140 | sampled_feat = sampled_feat.to(device) 141 | sampled_label = sampled_label.to(device) 142 | 143 | self.model.load_state_dict(params) 144 | self.model.train() 145 | self.optimizer.zero_grad() 146 | 147 | server_output = self.model(sampled_feat, sampled_adj) 148 | server_loss = loss_fnc(server_output[sampled_train_mask], sampled_label[sampled_train_mask]) 149 | 150 | server_loss.backward() 151 | self.optimizer.step() 152 | 153 | params = self.model.state_dict() 154 | 155 | 156 | val_score, val_loss = self.validation(params, epoch) 157 | print(f'Training Epoch #{epoch}, val score {val_score*100:.2f}, val loss {val_loss:.4f}') 158 | 159 | 160 | test_score = self.inference(params) 161 | print(f'Test accuracy is {test_score*100:.2f} at epoch {self.stats.best_val_epoch}') 162 | 163 | -------------------------------------------------------------------------------- /dgnn/train/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import time 4 | import torch 5 | import torchmetrics 6 | import numpy as np 7 | 8 | import torch_geometric.transforms as T 9 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 10 | 11 | from tqdm import trange 12 | from sklearn.metrics import f1_score, accuracy_score, roc_auc_score 13 | from ogb.nodeproppred import Evaluator 14 | 15 | from ..utils import Stats 16 | from ..models import model_selector 17 | from ..layers import layer_selector 18 | 19 | 20 | class Base(object): 21 | 22 | def __init__(self, 23 | config, 24 | dataset 25 | ): 26 | 27 | self.config = config 28 | self.stats = Stats(config) 29 | self.dataset = dataset 30 | 31 | # Set device 32 | if self.config.gpu >= 0: 33 | self.device = self.config.gpu 34 | else: 35 | self.device = 'cpu' 36 | 37 | if self.config.cpu_val: 38 | self.val_device = 'cpu' 39 | else: 40 | self.val_device = self.device 41 | 42 | model = model_selector(self.config.model) 43 | layer = layer_selector(self.config.layer) 44 | 45 | self.activation = torch.nn.ReLU(True) 46 | 47 | self.model = model( 48 | dataset.num_features, 49 | self.config.hidden_size, 50 | dataset.num_classes, 51 | self.config.num_layers, 52 | self.activation, 53 | layer=layer, 54 | input_norm=self.config.input_norm, 55 | layer_norm=self.config.layer_norm, 56 | arch=self.config.arch, 57 | residual=self.config.residual, 58 | dropout=self.config.dropout, 59 | ) 60 | 61 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr) 62 | if config.loss == 'xentropy': 63 | self.loss_fnc = torch.nn.CrossEntropyLoss() 64 | elif config.loss == 'bceloss': 65 | self.loss_fnc = torch.nn.BCEWithLogitsLoss() 66 | 67 | self.train_output = None 68 | 69 | if self.dataset.name.startswith('ogbn'): 70 | self.evaluator = Evaluator(name=self.dataset.name) 71 | 72 | def run(self): 73 | 74 | self.model = self.model.to(self.device) 75 | 76 | self.tbar = trange(self.config.num_epochs, desc='Epochs') 77 | 78 | 79 | # for epoch in range(self.config.num_epochs): 80 | for epoch in self.tbar: 81 | start_time = time.perf_counter() 82 | 83 | 
self.train(epoch) 84 | end_train_time = time.perf_counter() 85 | 86 | if not self.validation(epoch): 87 | break 88 | end_val_time = time.perf_counter() 89 | 90 | self.tbar.set_postfix(loss=f'{self.stats.train_loss[-1]:.4f}' if len(self.stats.train_loss) > 0 else '-', 91 | score=f'{self.stats.train_scores[-1]*100:.2f}' if len(self.stats.train_scores) > 0 else '-', 92 | val_loss=f'{self.stats.val_loss[-1]:.4f}', 93 | val_score=f'{self.stats.val_scores[-1]*100:.2f}' 94 | ) 95 | 96 | self.stats.train_time.append(end_train_time-start_time) 97 | self.stats.val_time.append(end_val_time-end_train_time) 98 | 99 | self.inference() 100 | end_inf_time = time.perf_counter() 101 | self.stats.test_time = end_inf_time - end_val_time 102 | 103 | print(f'Test Score: {self.stats.test_score * 100:.2f}% ' 104 | f'@ Epoch #{self.stats.best_val_epoch}, ' 105 | f'Highest Val: {self.stats.best_val_score * 100:.2f}%' 106 | ) 107 | 108 | print(f'Total training time: {np.sum(self.stats.train_time):.3f} sec.') 109 | 110 | 111 | # start_time = time.perf_counter() 112 | # for epoch in self.tbar: 113 | # self.train(epoch) 114 | # end_time = time.perf_counter() 115 | # print('Total train:', end_time-start_time) 116 | 117 | def calc_score(self, pred, batch_labels): 118 | 119 | # Same device (GPU) metrics 120 | if self.dataset.name == 'ogbn-proteins': 121 | # ROC 122 | pred_labels = torch.nn.Sigmoid()(pred) 123 | score = torchmetrics.functional.auroc(pred_labels, batch_labels.int(), 124 | num_classes=batch_labels.shape[1]).cpu().item() 125 | elif self.dataset.name == 'yelp': 126 | pred_labels = torch.nn.Sigmoid()(pred) 127 | score = torchmetrics.functional.f1(pred_labels, batch_labels.int()).cpu().item() 128 | else: 129 | pred_labels = pred.argmax(dim=-1, keepdim=True) 130 | # score = torchmetrics.functional.f1(pred_labels, batch_labels.int()).cpu().item() 131 | score = pred_labels.eq(batch_labels.unsqueeze(-1)).sum().cpu().item() / batch_labels.size(0) 132 | 133 | return score 134 | 135 | def patience(self): 136 | # terminate after num_patience of not increasing val_score 137 | if len(self.stats.val_scores) > self.config.val_patience and \ 138 | np.max(self.stats.val_scores[-1*self.config.val_patience:]) < self.stats.best_val_score: 139 | print('Run out of patience!') 140 | return False 141 | 142 | return True 143 | 144 | def save(self, *args, **kwargs): 145 | self.stats.save() 146 | 147 | def train(self, epoch): 148 | raise NotImplementedError 149 | 150 | @torch.no_grad() 151 | def validation(self, epoch): 152 | raise NotImplementedError 153 | 154 | @torch.no_grad() 155 | def inference(self): 156 | raise NotImplementedError 157 | 158 | 159 | # if self.dataset.name.startswith('ogbn'): 160 | # # Default OGB, on CPU using Numpy 161 | # if self.dataset.name == 'ogbn-proteins': 162 | # pred_labels = pred 163 | # score = self.evaluator.eval({ 164 | # 'y_true': batch_labels, 165 | # 'y_pred': pred_labels, 166 | # })['rocauc'] 167 | # else: 168 | # pred_labels = pred.detach().argmax(dim=-1, keepdim=True) 169 | # score = self.evaluator.eval({ 170 | # 'y_true': batch_labels.unsqueeze(-1), 171 | # 'y_pred': pred_labels, 172 | # })['acc'] 173 | # elif self.dataset.name == 'yelp': 174 | # pred_labels = torch.nn.Sigmoid()(pred).detach().cpu() > 0.5 175 | # score = f1_score(batch_labels.cpu(), pred_labels, average='micro') 176 | # else: 177 | # pred_labels = pred.detach().cpu().argmax(dim=1) 178 | # score = f1_score(batch_labels.cpu(), pred_labels, average='micro') 179 | 180 | # TODO: https://github.com/tqdm/tqdm/issues/630, 
https://github.com/KimythAnly/qqdm/ 181 | # from qqdm import qqdm, format_str 182 | # self.tbar = qqdm(range(self.config.num_epochs), desc=format_str('bold', 'Training')) 183 | # self.tbar.set_infos({ 184 | # 'loss': f'{train_loss:.4f}', 185 | # 'score': f'{train_score*100:.2f}', 186 | # 'val_loss': f'{self.stats.val_loss[-1]:.4f}', 187 | # 'val_score': f'{self.stats.val_scores[-1]*100:.2f}', 188 | # }) -------------------------------------------------------------------------------- /dgnn/train/old/serial_paravg.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import copy 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.multiprocessing as mp 8 | import torch.distributed as dist 9 | import torch_geometric.transforms as T 10 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 11 | 12 | import numpy as np 13 | 14 | from ...utils import Stats 15 | from ...models import model_selector 16 | from ...data import partition as P 17 | from ...utils import helpers as H 18 | 19 | from . import Distributed 20 | 21 | class SerializedParamsAvg(Distributed): 22 | """ 23 | This class is only for testing purpose. 24 | The calculation is done on single GPU/CPU and in serial mode. 25 | """ 26 | 27 | def __init__(self, config, dataset): 28 | super().__init__(config, dataset) 29 | 30 | self.dataset = dataset[0] 31 | self.dataset = T.GCNNorm()(self.dataset) 32 | 33 | self.world_size = self.config.num_procs 34 | 35 | self.device = H.rank2dev(0, self.num_gpus) 36 | # self.device = H.rank2dev(0, 0) 37 | 38 | self.full_adj = self.dataset.adj_t.to(self.device) 39 | self.full_features = self.dataset.x.to(self.device) 40 | self.full_labels = self.dataset.y.to(self.device) 41 | self.full_train_mask = self.dataset.train_mask 42 | self.full_val_mask = self.dataset.val_mask 43 | self.full_test_mask = self.dataset.test_mask 44 | 45 | self.loss_fnc = torch.nn.CrossEntropyLoss() 46 | 47 | def start(self): 48 | self.train(0) 49 | 50 | 51 | def train(self, _, *args, **kwargs): 52 | 53 | 54 | adjs = [] 55 | features = [] 56 | labels = [] 57 | train_masks = [] 58 | train_sizes = [] 59 | 60 | client_models = [] 61 | client_optimizers = [] 62 | 63 | device = self.device 64 | 65 | # Load all partitions, 66 | for rank in range(self.world_size): 67 | 68 | # device = H.rank2dev(rank, self.num_gpus) 69 | adj, feat, lab, tr, va, te = P.load_partitions(self.dataset_dir, rank) 70 | adj = adj[rank] 71 | adj = adj.set_value(None) 72 | adj = gcn_norm(adj) 73 | adj = adj.to(device) 74 | 75 | feat = feat.to(device) 76 | lab = lab.to(device) 77 | 78 | model = copy.deepcopy(self.model) 79 | model = model.to(device) 80 | 81 | optimizer = torch.optim.Adam(model.parameters(), lr=self.config.lr) 82 | # optimizer = torch.optim.SGD(model.parameters(), lr=self.config.lr, momentum=0) 83 | 84 | adjs.append(adj) 85 | features.append(feat) 86 | labels.append(lab) 87 | train_masks.append(tr) 88 | train_sizes.append(tr.count_nonzero()) 89 | 90 | client_models.append(model) 91 | client_optimizers.append(optimizer) 92 | 93 | total_train_size = np.sum(train_sizes) 94 | 95 | self.model = self.model.to(device) 96 | params = self.model.state_dict() 97 | server_optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr) 98 | 99 | for epoch in range(self.config.num_epochs): 100 | for rank in range(self.world_size): 101 | 102 | #! 
Communication, load from the server 103 | if epoch > 0: 104 | if self.config.sync_local or epoch % self.config.local_updates == 0: 105 | # if rank == 0: 106 | # print('Sync clients...') 107 | client_models[rank].load_state_dict(params) 108 | # client_optimizers[rank] = torch.optim.Adam(client_models[rank].parameters(), lr=self.config.lr) 109 | 110 | # train on clients 111 | client_models[rank].train() 112 | client_optimizers[rank].zero_grad() 113 | 114 | output = client_models[rank](features[rank], adjs[rank]) 115 | loss = self.loss_fnc(output[train_masks[rank]], labels[rank][train_masks[rank]]) 116 | 117 | loss.backward() 118 | # client_optimizers[rank].step() 119 | 120 | 121 | ## Attempt 1: get the fedavg params on servers 122 | # for layer in params.keys(): 123 | # tmp = [] 124 | # for rank in range(self.world_size): 125 | # tmp.append(client_models[rank].state_dict()[layer]) 126 | # params[layer] = torch.div(torch.stack(tmp, dim=0).sum(dim=0), self.world_size) 127 | 128 | ## Attempt 2 129 | # for k in params.keys(): 130 | # params[k] = torch.zeros_like(params[k]) 131 | 132 | # for rank in range(self.world_size): 133 | # for k in params: 134 | # # params[k] += torch.div(client_models[rank].state_dict()[k], self.world_size) 135 | # params[k] += torch.div(client_models[rank].state_dict()[k] * train_sizes[rank], total_train_size) 136 | 137 | # Attempt 3 138 | server_model = self.model 139 | for sp, cp in zip(server_model.parameters(), client_models[0].parameters()): 140 | sp.grad = torch.div(cp.grad , total_train_size/train_sizes[0]) 141 | 142 | for rank in range(1, self.world_size): 143 | for sp, cp in zip(server_model.parameters(), client_models[rank].parameters()): 144 | sp.grad += torch.div(cp.grad , total_train_size/train_sizes[rank]) 145 | server_optimizer.step() 146 | params = server_model.state_dict() 147 | 148 | val_score, val_loss = self.validation(params, epoch) 149 | print(f'Training Epoch #{epoch}, val score {val_score*100:.2f}, val loss {val_loss:.4f}') 150 | 151 | 152 | test_score = self.inference(params) 153 | print(f'Test accuracy is {test_score*100:.2f} at epoch {self.stats.best_val_epoch}') 154 | 155 | 156 | 157 | def validation(self, params, epoch): 158 | 159 | self.model.load_state_dict(params) 160 | model = self.model.to(self.device) 161 | 162 | model.eval() 163 | val_output = model(self.full_features, self.full_adj) 164 | 165 | val_loss = self.loss_fnc(val_output[self.full_val_mask], self.full_labels[self.full_val_mask]) 166 | 167 | val_pred = val_output[self.full_val_mask].detach().argmax(dim=1) 168 | 169 | val_score = (val_pred.eq( 170 | self.full_labels[self.full_val_mask]).sum() / self.full_val_mask.sum()).item() 171 | 172 | 173 | if self.stats.best_val_score == 0 or val_score > self.stats.best_val_score: 174 | self.stats.best_val_score = val_score 175 | self.stats.best_model = copy.deepcopy(model) 176 | self.stats.best_val_epoch = epoch 177 | 178 | self.stats.val_scores.append(val_score) 179 | 180 | return val_score, val_loss 181 | 182 | def inference(self, params): 183 | # self.model.load_state_dict(params) 184 | # model = self.model.to(self.device) 185 | model = self.stats.best_model.to(self.device) 186 | 187 | test_output = model(self.full_features, self.full_adj) 188 | test_pred = test_output[self.full_test_mask].argmax(dim=1) 189 | 190 | test_score = (test_pred.eq( 191 | self.full_labels[self.full_test_mask]).sum() / self.full_test_mask.sum()).item() 192 | 193 | self.stats.test_score = test_score 194 | 195 | return test_score 
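# Note on the Attempt 3 aggregation in train() above: torch.div(cp.grad,
# total_train_size / train_sizes[rank]) equals (train_sizes[rank] / total_train_size) * cp.grad,
# so the server gradient is a convex combination of client gradients weighted by each
# partition's share of training nodes; with sizes [60, 40] it is 0.6 * g_0 + 0.4 * g_1.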
--------------------------------------------------------------------------------
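As a self-contained reference for the aggregation rule shared by server_average and the serial trainers above, here is a minimal sketch of train-size-weighted gradient averaging. It is not part of the repository: the module name toy_weighted_fedavg.py, the Linear stand-in model, and the partition sizes are illustrative assumptions.

# toy_weighted_fedavg.py -- illustrative sketch, not a repository module
import copy
import torch

global_model = torch.nn.Linear(4, 2)                       # stand-in for a GCN
clients = [copy.deepcopy(global_model) for _ in range(2)]  # one replica per partition
train_sizes = [60, 40]                                     # training nodes per partition
total = sum(train_sizes)

# Pretend each client ran loss.backward() on its own partition.
for client in clients:
    for p in client.parameters():
        p.grad = torch.randn_like(p)

# Server step: the torch.div(grad, total/size) pattern, i.e. sum_r (n_r / n) * grad_r.
for i, sp in enumerate(global_model.parameters()):
    sp.grad = torch.zeros_like(sp)
    for client, size in zip(clients, train_sizes):
        sp.grad += list(client.parameters())[i].grad * (size / total)

torch.optim.Adam(global_model.parameters(), lr=1e-2).step()

Because the weights n_r / n sum to one, this matches the gradient of a loss averaged over all training nodes, up to the cross-partition edges dropped when the graph is split.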