├── .bazeliskrc ├── .bazelrc ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE ├── PULL_REQUEST_TEMPLATE ├── actions │ ├── benchmarks │ │ └── action.yml │ ├── buildcache │ │ └── action.yml │ ├── lint │ │ └── action.yml │ ├── test │ │ └── action.yml │ └── wheel │ │ └── action.yml └── workflows │ ├── benchmark.yml │ ├── ci.yml │ ├── lint.yml │ ├── test.yml │ └── wheel.yml ├── .gitignore ├── .mypy.ini ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── BUILD ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── WORKSPACE ├── config ├── BUILD ├── configs.bzl ├── hadoop.BUILD ├── json.BUILD ├── jvm.BUILD ├── mimalloc.BUILD ├── openssl.BUILD ├── openssl_deps.bzl ├── openssl_setup.bzl └── variables.bzl ├── docs ├── BUILD ├── advanced │ ├── hdfs.rst │ ├── index.rst │ └── sql_shard.rst ├── conf.py ├── doctest_template.py ├── graph_engine │ ├── custom_decoder.rst │ ├── data_spec.rst │ ├── from_networkx.rst │ ├── index.rst │ ├── overview.rst │ ├── spark_converter.rst │ └── temporal.rst ├── index.rst ├── make_docs.py ├── requirements.txt ├── tf │ ├── index.rst │ ├── link_pred.rst │ ├── node_class.rst │ └── ray_usage.rst └── torch │ ├── distrib.rst │ ├── index.rst │ ├── link_pred.rst │ └── node_class.rst ├── examples ├── __init__.py ├── hdfs_setup.sh ├── pytorch │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── aml.py │ ├── gat.py │ ├── gcn.py │ ├── hetgnn │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── evaluation.py │ │ ├── graph.py │ │ ├── main.py │ │ ├── model.py │ │ └── sampler.py │ ├── pyg_interface.py │ ├── sage.py │ └── tgn.py └── tensorflow │ ├── __init__.py │ ├── gat │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── gat.py │ ├── main.py │ ├── run.sh │ └── test_gat.py │ ├── gcn │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── gcn.py │ ├── main.py │ ├── run.sh │ └── test_gcn.py │ ├── han │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── han.py │ ├── main.py │ ├── run.sh │ └── test_han.py │ ├── requirements.txt │ └── sage │ ├── BUILD │ ├── README.md │ ├── __init__.py │ ├── main.py │ ├── main_linkprediction.py │ ├── main_unsup.py │ ├── run.sh │ ├── sage.py │ ├── sage_linkprediction.py │ ├── sage_unsupervised.py │ ├── test_sage.py │ └── test_sage_link.py ├── requirements.txt ├── src ├── cc │ ├── lib │ │ ├── BUILD │ │ ├── benchmark │ │ │ ├── BUILD │ │ │ ├── grpc_benchmark.cc │ │ │ ├── neighbor_sampler_benchmark.cc │ │ │ ├── partition_benchmark.cc │ │ │ ├── sampler_benchmark.cc │ │ │ └── search_benchmark.cc │ │ ├── distributed │ │ │ ├── BUILD │ │ │ ├── call_data.cc │ │ │ ├── call_data.h │ │ │ ├── client.cc │ │ │ ├── client.h │ │ │ ├── graph_engine.cc │ │ │ ├── graph_engine.h │ │ │ ├── graph_sampler.cc │ │ │ ├── graph_sampler.h │ │ │ ├── server.cc │ │ │ ├── server.h │ │ │ └── service.proto │ │ ├── graph │ │ │ ├── BUILD │ │ │ ├── graph.cc │ │ │ ├── graph.h │ │ │ ├── hdfs_wrap.cc │ │ │ ├── hdfs_wrap.h │ │ │ ├── locator.cc │ │ │ ├── locator.h │ │ │ ├── logger.cc │ │ │ ├── logger.h │ │ │ ├── merger.h │ │ │ ├── metadata.cc │ │ │ ├── metadata.h │ │ │ ├── partition.cc │ │ │ ├── partition.h │ │ │ ├── reservoir.cc │ │ │ ├── reservoir.h │ │ │ ├── sampler.cc │ │ │ ├── sampler.h │ │ │ ├── storage.h │ │ │ ├── types.h │ │ │ └── xoroshiro.h │ │ ├── py_graph.cc │ │ ├── py_graph.h │ │ ├── py_server.cc │ │ ├── version-script.darwin.lds │ │ └── version-script.linux.lds │ └── tests │ │ ├── BUILD │ │ ├── core-site.xml │ │ ├── distributed_test.cc │ │ ├── graph_test.cc │ │ ├── hdfs_test.cc │ │ ├── mocks.cc │ │ ├── mocks.h │ │ └── temporal_test.cc └── python 
│ ├── deepgnn │ ├── BUILD │ ├── __init__.py │ ├── arg_types.py │ ├── graph_engine │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── _adl_reader.py │ │ ├── _base.py │ │ ├── adl_uploader.py │ │ ├── backends │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── options.py │ │ │ └── snark │ │ │ │ ├── BUILD │ │ │ │ ├── __init__.py │ │ │ │ ├── client.py │ │ │ │ ├── synchronized.py │ │ │ │ ├── test_snark_client.py │ │ │ │ └── test_synchronized.py │ │ ├── data │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── citation.py │ │ │ ├── citeseer.py │ │ │ ├── citeseer.zip │ │ │ ├── citeseer_full.zip │ │ │ ├── cora.py │ │ │ ├── cora.zip │ │ │ ├── cora_full.zip │ │ │ ├── data_util.py │ │ │ ├── mooc.py │ │ │ ├── ppi.py │ │ │ ├── reddit.py │ │ │ └── test_graph_dataset.py │ │ ├── graph_dataset.py │ │ ├── graph_ops.py │ │ ├── multihop.py │ │ ├── prefetch.py │ │ ├── samplers.py │ │ ├── snark │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── _downloader.py │ │ │ ├── _lib.py │ │ │ ├── alias.py │ │ │ ├── client.py │ │ │ ├── convert.py │ │ │ ├── converter │ │ │ │ ├── __init__.py │ │ │ │ ├── options.py │ │ │ │ ├── process.py │ │ │ │ └── writers.py │ │ │ ├── decoders.py │ │ │ ├── dispatcher.py │ │ │ ├── distributed.py │ │ │ ├── local.py │ │ │ ├── meta.py │ │ │ ├── meta_merger.py │ │ │ ├── preprocess │ │ │ │ ├── BUILD │ │ │ │ ├── __init__.py │ │ │ │ └── sampler │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── forest_fire.py │ │ │ │ │ └── metric.py │ │ │ ├── server.py │ │ │ └── tests │ │ │ │ ├── BUILD │ │ │ │ ├── alias_test.py │ │ │ │ ├── convert_test.py │ │ │ │ ├── downloader_test.py │ │ │ │ ├── e2e_test.py │ │ │ │ ├── forest_fire_test.py │ │ │ │ ├── hdfs_test.py │ │ │ │ ├── metric_test.py │ │ │ │ ├── neighbor_sampler_test.py │ │ │ │ ├── ppr_benchmark_test.py │ │ │ │ ├── random_walk_test.py │ │ │ │ ├── requirements.txt │ │ │ │ ├── snark_test.py │ │ │ │ ├── sparse_features_test.py │ │ │ │ ├── temporal_test.py │ │ │ │ └── util_test.py │ │ ├── test_adl_reader.py │ │ ├── test_multihop.py │ │ ├── test_prefetch.py │ │ └── utils.py │ ├── log_consts.py │ ├── logging_utils.py │ ├── migrate │ │ ├── 0_1_56.py │ │ ├── 0_1_57.py │ │ └── __init__.py │ ├── pytorch │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── common │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── aggregators.py │ │ │ ├── args.py │ │ │ ├── consts.py │ │ │ ├── dataset.py │ │ │ ├── metrics.py │ │ │ ├── optimization.py │ │ │ ├── test_metrics.py │ │ │ ├── test_utils.py │ │ │ └── utils.py │ │ ├── encoding │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── feature_encoder.py │ │ │ ├── gnn_encoder_gat.py │ │ │ ├── gnn_encoder_hetgnn.py │ │ │ ├── gnn_encoder_lgcl.py │ │ │ ├── gnn_encoder_lightgcn.py │ │ │ ├── gnn_encoder_sage.py │ │ │ ├── test_feature_encoder.py │ │ │ ├── test_gnn_encoders.py │ │ │ └── twinbert │ │ │ │ ├── BUILD │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration.py │ │ │ │ ├── deepspeed │ │ │ │ ├── BUILD │ │ │ │ ├── __init__.py │ │ │ │ ├── convert_bert_ckpt_to_deepspeed.py │ │ │ │ ├── file_utils.py │ │ │ │ ├── loss.py │ │ │ │ ├── nvidia_modeling.py │ │ │ │ └── nvidia_modeling_no_apex.py │ │ │ │ ├── embedding.py │ │ │ │ ├── encoder.py │ │ │ │ ├── pooler.py │ │ │ │ ├── test_encoder.py │ │ │ │ ├── test_tokenization.py │ │ │ │ └── tokenization.py │ │ ├── modeling │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ └── base_model.py │ │ ├── nn │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── gat_conv.py │ │ │ └── test_conv.py │ │ └── training │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── args.py │ │ │ ├── factory.py │ │ │ ├── test_trainer.py │ │ │ ├── trainer.py │ │ │ ├── 
trainer_ddp.py │ │ │ ├── trainer_fp16.py │ │ │ ├── trainer_hvd.py │ │ │ └── utils.py │ ├── tf │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── common │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── args.py │ │ │ ├── base_trainer.py │ │ │ ├── dataset.py │ │ │ ├── dist_sync.py │ │ │ ├── hooks.py │ │ │ ├── horovod_trainer.py │ │ │ ├── ps_trainer.py │ │ │ ├── test_helper.py │ │ │ ├── test_trainer.py │ │ │ ├── test_trainer_tf2.py │ │ │ ├── tf2_horovod_trainer.py │ │ │ ├── tf2_trainer.py │ │ │ ├── trainer.py │ │ │ ├── trainer_factory.py │ │ │ ├── unittest │ │ │ │ ├── BUILD │ │ │ │ ├── test_dist_sync.py │ │ │ │ └── testserver.py │ │ │ └── utils.py │ │ ├── encoders │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── att_encoder.py │ │ │ └── han_encoder.py │ │ ├── layers │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── attention_header.py │ │ │ └── base.py │ │ └── nn │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── gat_conv.py │ │ │ ├── gcn_conv.py │ │ │ ├── metrics.py │ │ │ ├── sage_conv.py │ │ │ └── test_conv.py │ └── train_types.py │ └── setup.py └── tools ├── BUILD ├── manylinux ├── README.md ├── fixlinks.sh ├── install-gcc11.sh └── rpm-patch.sh └── toolchain ├── BUILD └── cc_toolchain_config.bzl /.bazeliskrc: -------------------------------------------------------------------------------- 1 | USE_BAZEL_VERSION=6.4.0 2 | -------------------------------------------------------------------------------- /.bazelrc: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # We want a dynamic library without any external dependencies, 5 | # which is why we statically link the standard libs. 6 | build:linux --repo_env=CC=gcc-12 7 | # Needed to build grpc. 8 | build:linux --cxxopt="-std=c++20" 9 | 10 | # TODO(mortid0): Remove opts below when https://bugs.chromium.org/p/boringssl/issues/detail?id=492 is fixed 11 | build:linux --copt="-D_XOPEN_SOURCE=700" 12 | build:linux --copt="-Wno-array-bounds" 13 | build:linux --copt="-Wno-stringop-overflow" 14 | build:linux --copt="-Wno-unknown-warning-option" 15 | 16 | build:linux --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a 17 | build:linux --action_env=BAZEL_LINKOPTS=-static-libgcc:-lm:-pthread 18 | build:linux --action_env=SETUPTOOLS_USE_DISTUTILS="false" 19 | build:linux --action_env=HADOOP_HOME="/tmp/snark/" 20 | build:linux --action_env=LD_LIBRARY_PATH="./external/jvm/jre/lib/amd64/server" 21 | 22 | build:manylinux --crosstool_top=//tools/toolchain:manylinux 23 | build:manylinux --cpu=k8 24 | build:manylinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain 25 | 26 | build:macos --cxxopt="-std=c++20" 27 | build:macos --action_env=SETUPTOOLS_USE_DISTUTILS="false" 28 | build:macos --action_env=HADOOP_HOME="/tmp/snark/" 29 | 30 | build:windows --action_env=BAZEL_LINKOPTS=/MT 31 | build:windows --action_env=SETUPTOOLS_USE_DISTUTILS="false" 32 | build:windows --action_env=HADOOP_HOME="C:/usr/local/hadoop/" 33 | build:windows --action_env=APPDATA="C:/usr/local/hadoop/datalake" 34 | build:windows --action_env=LIB="C:/Program Files/OpenSSL-Win64/lib" 35 | build:windows --action_env=INCLUDE="C:/Program Files/OpenSSL-Win64/include" 36 | build:windows --action_env=USERPROFILE="."
# needed for python3.8 pathlib Path.home() 37 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @microsoft/deepgnn 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE: -------------------------------------------------------------------------------- 1 | - [ ] Issue is labeled using the label menu on the right side. 2 | 3 | Environment 4 | ----------- 5 | * Python version: (python -V) 6 | * deepgnn-ge Version: (python -m pip show deepgnn-ge) 7 | * OS: (Windows, Linux, ...) 8 | 9 | Issue Details 10 | ------------- 11 | * What you did - code sample or commands run 12 | 13 | * Expected behavior 14 | 15 | * Actual behavior 16 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE: -------------------------------------------------------------------------------- 1 | - [ ] Forked repo is synced with upstream -> GitHub shows no code delta outside of the desired changes. 2 | https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork 3 | - [ ] Tests are passing? https://github.com/microsoft/DeepGNN/blob/main/CONTRIBUTING.md#run-tests 4 | - [ ] Changelog and documentation updated. 5 | - [ ] PR is labeled using the label menu on the right side. 6 | 7 | Previous Behavior 8 | ---------------- 9 | * Provide the relevant issue number if applicable, e.g. #44. 10 | 11 | New Behavior 12 | ---------------- 13 | -------------------------------------------------------------------------------- /.github/actions/benchmarks/action.yml: -------------------------------------------------------------------------------- 1 | name: "Benchmarks" 2 | description: "Run benchmarks" 3 | inputs: 4 | cache_address: 5 | description: 'Address of bazel remote cache.' 6 | required: true 7 | default: localhost:8080 8 | runs: 9 | using: "composite" 10 | steps: 11 | - run: | 12 | if ping -c 1 ${{ inputs.cache_address }} &> /dev/null && [ $(curl ${{ inputs.cache_address }}/status | python -c "import sys, json; print(json.load(sys.stdin)['NumFiles'])") -ne 0 ] 13 | then 14 | echo "REMOTE_CACHE=--remote_cache=${{ inputs.cache_address }}" >> $GITHUB_ENV 15 | fi 16 | bazel run -c opt src/cc/lib/benchmark:grpc_benchmark --config=linux ${{ env.REMOTE_CACHE }} 17 | bazel run -c opt src/cc/lib/benchmark:sampler_benchmark --config=linux ${{ env.REMOTE_CACHE }} 18 | bazel run -c opt src/cc/lib/benchmark:neighbor_sampler_benchmark --config=linux ${{ env.REMOTE_CACHE }} 19 | shell: bash 20 | if: runner.os == 'Linux' 21 | -------------------------------------------------------------------------------- /.github/actions/buildcache/action.yml: -------------------------------------------------------------------------------- 1 | name: "Build cache" 2 | description: "Start remote bazel cache proxy to azure storage" 3 | inputs: 4 | cache_address: 5 | description: 'Address of bazel remote cache.'
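  # Design note on the caching setup: bazel-remote (started in the step below)
  # speaks Bazel's HTTP remote-cache protocol and proxies it to Azure blob
  # storage. The test/wheel/benchmark actions consume it by passing
  # --remote_cache=<cache_address> to bazel once the proxy's /status endpoint
  # reports a non-empty cache.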
6 | required: true 7 | default: localhost:8080 8 | runs: 9 | using: "composite" 10 | steps: 11 | - uses: actions/setup-go@v4 12 | with: 13 | go-version: '^1.20.7' 14 | - name: Start cache proxy 15 | shell: bash 16 | if: runner.os == 'Linux' 17 | run: | 18 | go install github.com/buchgr/bazel-remote/v2@v2.4.1 19 | bazel-remote --azblob.tenant_id=bazelcache --azblob.storage_account=bazelcache --azblob.container_name=cache --azblob.auth_method=shared_key --dir=/tmp/bazelcache --max_size=2 --http_address=${{ inputs.cache_address }} & 20 | sleep 5 # Give the server time to start 21 | echo $(curl ${{ inputs.cache_address }}/status) 22 | -------------------------------------------------------------------------------- /.github/actions/lint/action.yml: -------------------------------------------------------------------------------- 1 | name: "Lint" 2 | description: "run pre-commit linters" 3 | runs: 4 | using: "composite" 5 | steps: 6 | - run: sudo apt-get install clang-format 7 | shell: bash 8 | if: runner.os == 'Linux' 9 | - run: $env:PATH+=";"+${env:ProgramFiles(x86)}+"\Microsoft Visual Studio\2019\Enterprise\VC\Tools\Llvm\bin" 10 | shell: pwsh 11 | if: runner.os == 'Windows' 12 | - run: brew install clang-format 13 | shell: bash 14 | if: runner.os == 'macOS' 15 | - run: pip install wheel pre-commit==2.17.0 mypy==0.971 numpy==1.22.4 torch==1.13.1 tensorflow==2.13.0 ray==2.9.1 16 | shell: bash 17 | name: install dependencies 18 | if: runner.os != 'macOS' 19 | - run: pip install wheel pre-commit==2.17.0 mypy==0.971 numpy==1.22.4 torch==1.13.1 20 | shell: bash 21 | name: install dependencies 22 | if: runner.os == 'macOS' 23 | - run: pre-commit install 24 | shell: bash 25 | name: initialize pre-commit 26 | - run: pre-commit run --all-files 27 | shell: bash 28 | name: run linters 29 | -------------------------------------------------------------------------------- /.github/actions/test/action.yml: -------------------------------------------------------------------------------- 1 | name: "Test" 2 | description: "Run unit tests" 3 | inputs: 4 | cache_address: 5 | description: 'Address of bazel remote cache.' 
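  # The Linux steps below probe the cache proxy first (ping plus the /status
  # NumFiles check) and only set REMOTE_CACHE when the cache is reachable and
  # non-empty, so the tests fall back to a cold build when the proxy is absent.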
6 | required: true 7 | default: localhost:8080 8 | runs: 9 | using: "composite" 10 | steps: 11 | - run: echo "BAZEL_CONFIG=linux" >> $GITHUB_ENV 12 | shell: bash 13 | if: runner.os == 'Linux' 14 | - run: | 15 | if ping -c 1 ${{ inputs.cache_address }} &> /dev/null && [ $(curl ${{ inputs.cache_address }}/status | python -c "import sys, json; print(json.load(sys.stdin)['NumFiles'])") -ne 0 ] 16 | then 17 | echo "REMOTE_CACHE=--remote_cache=${{ inputs.cache_address }}" >> $GITHUB_ENV 18 | fi 19 | shell: bash 20 | if: runner.os == 'Linux' 21 | - run: echo "BAZEL_CONFIG=windows" >> $GITHUB_ENV 22 | shell: bash 23 | if: runner.os == 'Windows' 24 | - run: echo "BAZEL_CONFIG=macos" >> $GITHUB_ENV 25 | shell: bash 26 | if: runner.os == 'macOS' 27 | - run: | 28 | bazel test -c dbg //src/cc/tests:* --test_output=all --test_timeout 30 --config=${{ env.BAZEL_CONFIG }} ${{ env.REMOTE_CACHE }} --verbose_failures 29 | shell: bash 30 | name: run cpp tests 31 | - run: | 32 | bazel test -c dbg //src/python/deepgnn/...:* --jobs 1 --test_output=all --test_timeout 600 --config=${{ env.BAZEL_CONFIG }} ${{ env.REMOTE_CACHE }} --verbose_failures 33 | bazel clean 34 | shell: bash 35 | name: run python tests 36 | - run: | 37 | bazel run -c dbg //examples/pytorch:aml --config=${{ env.BAZEL_CONFIG }} 38 | bazel run -c dbg //examples/pytorch:gcn --config=${{ env.BAZEL_CONFIG }} 39 | bazel run -c dbg //examples/pytorch:gat --config=${{ env.BAZEL_CONFIG }} 40 | bazel run -c dbg //examples/pytorch:sage --config=${{ env.BAZEL_CONFIG }} 41 | bazel run -c dbg //examples/pytorch:pyg_interface --config=${{ env.BAZEL_CONFIG }} 42 | bazel test -c dbg //docs:* --test_output=all --jobs 1 --config=${{ env.BAZEL_CONFIG }} ${{ env.REMOTE_CACHE }} --verbose_failures 43 | shell: bash 44 | name: run python examples and doctests 45 | if: runner.os == 'Linux' 46 | - run: | 47 | bazel run -c dbg //docs:make_docs --config=linux ${{ env.REMOTE_CACHE }} --verbose_failures 48 | bazel build -c dbg src/cc/lib/benchmark:* --config=linux ${{ env.REMOTE_CACHE }} --verbose_failures 49 | shell: bash 50 | name: build documentation and benchmarks 51 | if: runner.os == 'Linux' 52 | -------------------------------------------------------------------------------- /.github/actions/wheel/action.yml: -------------------------------------------------------------------------------- 1 | name: "Wheel" 2 | description: "Build wheels" 3 | inputs: 4 | package_version: 5 | description: 'Package version' 6 | required: true 7 | default: '0.1.1' 8 | cache_address: 9 | description: 'Address of bazel remote cache.'
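  # Overall flow of this action: build libwrapper.so (on Linux with the
  # manylinux crosstool and the cached gcc-11 toolchain under /dt11), copy it
  # into the python package tree, produce a platform-tagged wheel via
  # `setup.py bdist_wheel`, and audit the result with auditwheel on Linux.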
10 | required: true 11 | default: localhost:8080 12 | runs: 13 | using: "composite" 14 | steps: 15 | - run: echo "BUILD_VERSION=${{ inputs.package_version }}" >> $GITHUB_ENV 16 | name: Set package version 17 | shell: bash 18 | - if: runner.os == 'Linux' 19 | run: | 20 | echo "BAZEL_CONFIG=manylinux" >> $GITHUB_ENV 21 | echo "PLAT_NAME=manylinux1_x86_64" >> $GITHUB_ENV 22 | if ping -c 1 ${{ inputs.cache_address }} &> /dev/null && [ $(curl ${{ inputs.cache_address }}/status | python -c "import sys, json; print(json.load(sys.stdin)['NumFiles'])") -ne 0 ] 23 | then 24 | echo "REMOTE_CACHE=--remote_cache=${{ inputs.cache_address }}" >> $GITHUB_ENV 25 | fi 26 | sudo mkdir /dt11 27 | sudo chown $USER /dt11 28 | shell: bash 29 | name: Configure env variables for linux 30 | - uses: actions/cache@v4 31 | id: compiler-cache 32 | with: 33 | path: /dt11 34 | key: ${{ runner.os }}-${{ hashFiles('tools/manylinux/install-gcc11.sh') }} 35 | 36 | - name: Install Dependencies 37 | if: runner.os == 'Linux' && steps.compiler-cache.outputs.cache-hit != 'true' 38 | run: | 39 | cd tools/manylinux 40 | sudo ./install-gcc11.sh 41 | shell: bash 42 | 43 | - if: runner.os == 'Windows' 44 | run: | 45 | echo "BAZEL_CONFIG=windows" >> $GITHUB_ENV 46 | echo "PLAT_NAME=win-amd64" >> $GITHUB_ENV 47 | shell: bash 48 | name: Configure env variables for windows 49 | 50 | - if: runner.os == 'macOS' 51 | run: | 52 | echo "BAZEL_CONFIG=macos" >> $GITHUB_ENV 53 | echo "PLAT_NAME=macosx-10.9-x86_64" >> $GITHUB_ENV 54 | shell: bash 55 | name: Configure env variables for macos 56 | 57 | - run: bazelisk build -c opt //src/cc/lib:wrapper --config=${{ env.BAZEL_CONFIG }} ${{ env.REMOTE_CACHE }} 58 | shell: bash 59 | name: Build shared library 60 | 61 | - if: runner.os == 'Linux' 62 | run: | 63 | cp -f ./bazel-bin/src/cc/lib/libwrapper.so src/python/deepgnn/graph_engine/snark/ 64 | sudo chmod -R a+rw src/python 65 | shell: bash 66 | name: tweak linux dependencies 67 | 68 | - run: | 69 | cd src/python 70 | pip install wheel twine artifacts-keyring 71 | python setup.py bdist_wheel --plat-name "${PLAT_NAME}" clean --all 72 | name: build wheels 73 | shell: bash 74 | 75 | - if: runner.os == 'Linux' 76 | run: | 77 | pip install auditwheel 78 | auditwheel show src/python/dist/deepgnn_ge-"${BUILD_VERSION}"-py3-none-"${PLAT_NAME}".whl 79 | name: Audit linux wheel 80 | shell: bash 81 | -------------------------------------------------------------------------------- /.github/workflows/benchmark.yml: -------------------------------------------------------------------------------- 1 | name: Benchmark 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | linux: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v3.5.3 11 | - name: Start build cache proxy 12 | uses: ./.github/actions/buildcache 13 | env: 14 | BAZEL_REMOTE_AZBLOB_SHARED_KEY : ${{ secrets.BAZEL_REMOTE_AZBLOB_SHARED_KEY }} 15 | - name: Run benchmarks 16 | uses: ./.github/actions/benchmarks 17 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | types: [opened, reopened, synchronize] 9 | secrets: 10 | BAZEL_REMOTE_AZBLOB_SHARED_KEY: 11 | required: true 12 | workflow_dispatch: 13 | 14 | jobs: 15 | pre-commit: 16 | strategy: 17 | matrix: 18 | python-version: ["3.8", "3.9", "3.10"] 19 | os: ["ubuntu-22.04", "windows-2019"] 20 | runs-on: ${{ matrix.os }} 21
| steps: 22 | - uses: actions/checkout@v3.5.3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v4.7.0 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Run lint 28 | uses: ./.github/actions/lint 29 | tests: 30 | needs: pre-commit 31 | strategy: 32 | matrix: 33 | python-version: ["3.10"] 34 | os: ["ubuntu-22.04", "windows-2019"] 35 | runs-on: ${{ matrix.os }} 36 | steps: 37 | - uses: actions/checkout@v3.5.3 38 | - name: Set up Python ${{ matrix.python-version }} 39 | uses: actions/setup-python@v4.7.0 40 | with: 41 | python-version: ${{ matrix.python-version }} 42 | - name: Start build cache proxy 43 | uses: ./.github/actions/buildcache 44 | env: 45 | BAZEL_REMOTE_AZBLOB_SHARED_KEY : ${{ secrets.BAZEL_REMOTE_AZBLOB_SHARED_KEY }} 46 | - name: Run tests 47 | uses: ./.github/actions/test 48 | benchmarks: 49 | needs: tests 50 | runs-on: "ubuntu-22.04" 51 | steps: 52 | - uses: actions/checkout@v3.5.3 53 | - name: Start build cache proxy 54 | uses: ./.github/actions/buildcache 55 | env: 56 | BAZEL_REMOTE_AZBLOB_SHARED_KEY : ${{ secrets.BAZEL_REMOTE_AZBLOB_SHARED_KEY }} 57 | - name: Run benchmarks 58 | uses: ./.github/actions/benchmarks 59 | wheel: 60 | runs-on: ubuntu-22.04 61 | needs: tests 62 | steps: 63 | - uses: actions/checkout@v3.5.3 64 | - name: Set up Python 3.10 65 | uses: actions/setup-python@v4.7.0 66 | with: 67 | python-version: "3.10" 68 | - name: Upload examples 69 | uses: actions/upload-artifact@v4 70 | with: 71 | name: examples 72 | path: examples/* 73 | - name: Start build cache proxy 74 | uses: ./.github/actions/buildcache 75 | env: 76 | BAZEL_REMOTE_AZBLOB_SHARED_KEY : ${{ secrets.BAZEL_REMOTE_AZBLOB_SHARED_KEY }} 77 | - name: build wheel 78 | uses: ./.github/actions/wheel 79 | with: 80 | package_version: "0.1.1" 81 | - name: Upload wheel file 82 | uses: actions/upload-artifact@v4 83 | with: 84 | name: deepgnn 85 | path: src/python/dist/*.whl 86 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: workflow_dispatch 4 | 5 | jobs: 6 | pre-commit: 7 | strategy: 8 | matrix: 9 | python-version: ["3.8", "3.9", "3.10"] 10 | os: ["ubuntu-22.04", "windows-2022"] 11 | runs-on: ${{ matrix.os }} 12 | steps: 13 | - uses: actions/checkout@v3.5.3 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v4.7.0 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | - name: Run lint 19 | uses: ./.github/actions/lint 20 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | tests: 8 | runs-on: ${{ matrix.os }} 9 | if: ${{ github.event.workflow_run == null || github.event.workflow_run.conclusion == 'success'}} 10 | strategy: 11 | matrix: 12 | python-version: ["3.10"] 13 | os: ["ubuntu-22.04", "windows-2019"] 14 | steps: 15 | - uses: actions/checkout@v3.5.3 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v4.7.0 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Start build cache proxy 21 | uses: ./.github/actions/buildcache 22 | env: 23 | BAZEL_REMOTE_AZBLOB_SHARED_KEY : ${{ secrets.BAZEL_REMOTE_AZBLOB_SHARED_KEY }} 24 | - name: Run tests 25 | uses: 
./.github/actions/test 26 | -------------------------------------------------------------------------------- /.github/workflows/wheel.yml: -------------------------------------------------------------------------------- 1 | name: Wheel 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | package_version: 7 | required: true 8 | type: string 9 | description: DeepGNN version to put in wheel 10 | 11 | jobs: 12 | wheels: 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | os: ["ubuntu-22.04", "windows-2019"] 17 | steps: 18 | - uses: actions/checkout@v3.5.3 19 | - name: Set up Python 3.10 20 | uses: actions/setup-python@v4.7.0 21 | with: 22 | python-version: "3.10" 23 | - name: Start build cache proxy 24 | uses: ./.github/actions/buildcache 25 | env: 26 | BAZEL_REMOTE_AZBLOB_SHARED_KEY : ${{ secrets.BAZEL_REMOTE_AZBLOB_SHARED_KEY }} 27 | - name: build wheel 28 | uses: ./.github/actions/wheel 29 | with: 30 | package_version: ${{ github.event.inputs.package_version }} 31 | - name: Upload wheel file 32 | uses: actions/upload-artifact@v4 33 | with: 34 | name: deepgnn-${{ matrix.os }} 35 | path: src/python/dist/*.whl 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # Bazel output folders 5 | bazel-bin* 6 | bazel-out* 7 | bazel-DeepGNN* 8 | bazel-testlogs* 9 | 10 | # Python package artifacts 11 | src/python/dist/* 12 | src/python/build/* 13 | *.egg-info 14 | 15 | #Python cache files 16 | *.pyc 17 | 18 | # Binary files 19 | *.so 20 | *.dll 21 | 22 | #IDE related 23 | .vscode 24 | 25 | # compiler toolchain 26 | .dt11* 27 | -------------------------------------------------------------------------------- /.mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | warn_return_any = False 3 | warn_unused_configs = True 4 | mypy_path = src/python/deepgnn 5 | 6 | files=src/python/deepgnn,examples 7 | # exclude files starting with test_ or ending with _test.py 8 | exclude = (?x)(_test\.py$|test_.) 
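# Pre-commit runs mypy against this file (the mirrors-mypy hook in
# .pre-commit-config.yaml passes `--config-file .mypy.ini`), so the same
# checks can be reproduced locally with, for example:
#   python -m pip install mypy
#   mypy --config-file .mypy.ini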
9 | 10 | [mypy-azureml.*] 11 | ignore_missing_imports = True 12 | 13 | [mypy-numpy.*] 14 | ignore_missing_imports = True 15 | 16 | [mypy-pytest.*] 17 | ignore_missing_imports = True 18 | 19 | [mypy-setuptools.*] 20 | ignore_missing_imports = True 21 | 22 | [mypy-networkx.*] 23 | ignore_missing_imports = True 24 | 25 | [mypy-fsspec.*] 26 | ignore_missing_imports = True 27 | 28 | [mypy-aiohttp.*] 29 | ignore_missing_imports = True 30 | 31 | [mypy-grpc.*] 32 | ignore_missing_imports = True 33 | 34 | [mypy-grpc_health.v1.*] 35 | ignore_missing_imports = True 36 | 37 | [mypy-tensorflow.*] 38 | ignore_missing_imports = True 39 | 40 | [mypy-tensorflow_addons.*] 41 | ignore_missing_imports = True 42 | 43 | [mypy-tensorboard.*] 44 | ignore_missing_imports = True 45 | 46 | [mypy-opencensus.*] 47 | ignore_missing_imports = True 48 | 49 | [mypy-azure.*] 50 | ignore_missing_imports = True 51 | 52 | [mypy-horovod.*] 53 | ignore_missing_imports = True 54 | 55 | [mypy-sklearn.*] 56 | ignore_missing_imports = True 57 | 58 | [mypy-torch.*] 59 | ignore_missing_imports = True 60 | 61 | [mypy-torch_geometric.*] 62 | ignore_missing_imports = True 63 | 64 | [mypy-deepgnn.tf.encoders.*] 65 | ignore_errors = True 66 | 67 | [mypy-kubernetes.*] 68 | ignore_missing_imports = True 69 | 70 | [mypy-ray.*] 71 | ignore_missing_imports = True 72 | 73 | [mypy-ray_on_aml.*] 74 | ignore_missing_imports = True 75 | 76 | [mypy-pasta.*] 77 | ignore_missing_imports = True 78 | 79 | [mypy-tenacity.*] 80 | ignore_missing_imports = True 81 | 82 | [mypy-torch_sparse.tensor.*] 83 | ignore_missing_imports = True 84 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-merge-conflict 6 | - id: check-yaml 7 | - id: check-json 8 | - id: detect-private-key 9 | - id: end-of-file-fixer 10 | - id: forbid-new-submodules 11 | - id: mixed-line-ending 12 | - id: name-tests-test 13 | - id: pretty-format-json 14 | args: ["--autofix"] 15 | - id: requirements-txt-fixer 16 | - id: trailing-whitespace 17 | - repo: https://github.com/psf/black 18 | rev: 22.12.0 19 | hooks: 20 | - id: black 21 | - repo: https://github.com/pre-commit/mirrors-mypy 22 | rev: v0.991 23 | hooks: 24 | - id: mypy 25 | args: [--config-file, .mypy.ini] 26 | pass_filenames: false 27 | - repo: https://github.com/pycqa/flake8 28 | rev: '6.0.0' 29 | hooks: 30 | - id: flake8 31 | additional_dependencies: [flake8-docstrings] 32 | # https://black.readthedocs.io/en/stable/compatible_configs.html#id2 33 | args: [--ignore,"E501", --extend-ignore,"E203,W503", --max-line-length, "88"] 34 | exclude: (test_|_test.py|deepspeed|twinbert|main|conftest|testserver|_adl_reader) 35 | - repo: https://github.com/pocc/pre-commit-hooks 36 | rev: v1.3.5 37 | hooks: 38 | - id: clang-format 39 | args: [--style=Microsoft, -i] 40 | - repo: https://github.com/google/pre-commit-tool-hooks 41 | rev: v1.2.4 42 | hooks: 43 | - id: check-copyright 44 | args: 45 | - --copyright 46 | - |+ 47 | Copyright (c) Microsoft Corporation. 48 | Licensed under the MIT License. 
49 | - --skip_pattern 50 | - \..*|LICENSE 51 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-20.04 5 | tools: 6 | python: "3.9" 7 | sphinx: 8 | configuration: docs/conf.py 9 | 10 | python: 11 | install: 12 | - requirements: docs/requirements.txt 13 | -------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepGNN Overview 2 | 3 | DeepGNN is a framework for training machine learning models on large-scale graph data. DeepGNN provides all the necessary features, including: 4 | 5 | * Distributed GNN training and inference on both CPU and GPU. 6 | * Custom graph neural network design. 7 | * Online sampling: the Graph Engine (GE) loads all graph data, and each training worker calls GE to fetch node/edge/neighbor features and labels. 8 | * Automatic graph partitioning. 9 | * High performance and scalability. 10 | 11 | The project is in alpha: there may be breaking changes in the future, and they will be documented in the changelog.
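To give a flavor of the graph engine API described above, here is a minimal sketch. It assumes the `deepgnn-ge` package is installed and uses the bundled Cora dataset from `deepgnn.graph_engine.data`; exact class names and signatures may vary between releases:

```python
import numpy as np
from deepgnn.graph_engine.data.citation import Cora

g = Cora()  # downloads Cora and converts it to the graph engine's binary format
# Fetch feature 0 (Cora's 1433-dimensional bag-of-words vector) for two nodes.
features = g.node_features(
    np.array([0, 1], dtype=np.int64),       # node ids
    np.array([[0, 1433]], dtype=np.int32),  # (feature id, dimension) pairs
    np.float32,
)
# Sample up to 5 neighbors per node across edge type 0.
neighbors = g.sample_neighbors(
    np.array([0, 1], dtype=np.int64), np.array([0], dtype=np.int32), 5
)
```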
12 | 13 | ## Usage 14 | 15 | Install the pip package: 16 | ```bash 17 | python -m pip install deepgnn 18 | ``` 19 | If you want to build the package from source, see the instructions in [`CONTRIBUTING.md`](CONTRIBUTING.md). 20 | 21 | Train and evaluate a GraphSAGE model with PyTorch on the Cora dataset: 22 | ```bash 23 | cd examples/pytorch 24 | python sage.py 25 | ``` 26 | 27 | ## Migrating Scripts 28 | 29 | We provide a Python module to help you upgrade your scripts to new DeepGNN versions. 30 | 31 | ```bash 32 | pip install google-pasta 33 | python -m deepgnn.migrate.0_1_56 --script_dir directory_to_migrate 34 | ``` 35 | 36 | See [`CHANGELOG.md`](CHANGELOG.md) for full change details. 37 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 4 | 5 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 6 | 7 | ## Reporting Security Issues 8 | 9 | **Please do not report security vulnerabilities through public GitHub issues.** 10 | 11 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 12 | 13 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 14 | 15 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 16 | 17 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 18 | 19 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 20 | * Full paths of source file(s) related to the manifestation of the issue 21 | * The location of the affected source code (tag/branch/commit or direct URL) 22 | * Any special configuration required to reproduce the issue 23 | * Step-by-step instructions to reproduce the issue 24 | * Proof-of-concept or exploit code (if possible) 25 | * Impact of the issue, including how an attacker might exploit the issue 26 | 27 | This information will help us triage your report more quickly. 28 | 29 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 30 | 31 | ## Preferred Languages 32 | 33 | We prefer all communications to be in English.
34 | 35 | ## Policy 36 | 37 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 38 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | For help and questions about using this project, please contact [deepgnn@microsoft.com](mailto:deepgnn@microsoft.com). 10 | 11 | ## Microsoft Support Policy 12 | 13 | Support for DeepGNN is limited to the resources listed above. 14 | -------------------------------------------------------------------------------- /config/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | -------------------------------------------------------------------------------- /config/configs.bzl: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Configurations to build manylinux pip package""" 5 | 6 | def _manylinux_config(name, compiler, sysroot = None, cuda_version = None, rocm_version = None): 7 | if cuda_version != None and rocm_version != None: 8 | fail("Specifying both cuda_version and rocm_version is not supported.") 9 | 10 | env = { 11 | "ABI_VERSION": "gcc", 12 | "ABI_LIBC_VERSION": "glibc_2.17", 13 | "BAZEL_COMPILER": compiler, 14 | "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", 15 | "BAZEL_TARGET_LIBC": "glibc_2.17", 16 | "BAZEL_TARGET_CPU": "k8", 17 | "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", 18 | "CC_TOOLCHAIN_NAME": "linux_gnu_x86", 19 | "CC": compiler, 20 | "CLEAR_CACHE": "1", 21 | "HOST_CXX_COMPILER": compiler, 22 | "HOST_C_COMPILER": compiler, 23 | } 24 | 25 | 26 | manylinux_config = _manylinux_config 27 | 28 | 29 | def initialize_manylinux_configs(): 30 | manylinux_config("manylinux", "gcc-11") 31 | -------------------------------------------------------------------------------- /config/hadoop.BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License.
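# This build file is presumably attached to an external Hadoop distribution
# from the WORKSPACE; it wraps the prebuilt libhdfs.so as a cc_import so the
# graph engine (see src/cc/lib/graph/hdfs_wrap.cc) can read graph data from
# HDFS, pulling in the JVM shared library through the @jvm repository.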
3 | 4 | cc_import( 5 | name = "hadoop_so", 6 | shared_library = "lib/native/libhdfs.so", 7 | visibility = ["//visibility:public"], 8 | ) 9 | 10 | cc_library( 11 | name = "hadoop", 12 | deps = [ 13 | ":hadoop_so", 14 | "@jvm//:jvm", 15 | ], 16 | data = [":."], 17 | visibility = ["//visibility:public"], 18 | ) 19 | 20 | cc_import( # For #include "hdfs.h" 21 | name = "hadoop_include", 22 | hdrs = ["include/hdfs.h"], 23 | shared_library = "lib/native/libhdfs.so.0.0.0", 24 | visibility = ["//visibility:public"], 25 | ) 26 | 27 | genrule( 28 | name = "gen_hadoop_py", 29 | srcs = [], 30 | outs = [ 31 | "hadoop_py.py" 32 | ], 33 | cmd = "echo '' > $(@D)/hadoop_py.py", 34 | visibility = ["//visibility:public"], 35 | ) 36 | 37 | py_binary( 38 | name = "hadoop_py", 39 | srcs = ["hadoop_py.py"], 40 | data = ["lib/native/libhdfs.so", "@jvm//:jvm_py"], 41 | deps = [], 42 | visibility = ["//visibility:public"], 43 | ) 44 | -------------------------------------------------------------------------------- /config/json.BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | cc_library( 5 | name = "json", 6 | hdrs = ["single_include/nlohmann/json.hpp"], 7 | includes = ["single_include"], 8 | visibility = ["//visibility:public"], 9 | ) 10 | -------------------------------------------------------------------------------- /config/jvm.BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | cc_import( 5 | name = "jvm", 6 | shared_library = "jre/lib/amd64/server/libjvm.so", 7 | visibility = ["//visibility:public"], 8 | ) 9 | 10 | filegroup( 11 | name = "jvm_py", 12 | srcs = [ 13 | "jre/lib/amd64/server/libjvm.so", 14 | ], 15 | visibility = ["//visibility:public"], 16 | ) 17 | -------------------------------------------------------------------------------- /config/mimalloc.BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | load("@rules_foreign_cc//foreign_cc:defs.bzl", "cmake") 5 | 6 | # TODO(alsamylk): use mimalloc on windows and mac. 7 | filegroup( 8 | name = "all", 9 | srcs = glob(["**"]), 10 | visibility = ["//visibility:public"], 11 | ) 12 | 13 | config_setting( 14 | name = "debug_build", 15 | constraint_values = ["@platforms//os:linux"], 16 | values = { 17 | "compilation_mode": "dbg", 18 | }, 19 | ) 20 | 21 | config_setting( 22 | name = "optimized_build", 23 | constraint_values = ["@platforms//os:linux"], 24 | values = { 25 | "compilation_mode": "opt", 26 | }, 27 | ) 28 | 29 | cmake( 30 | name = "mimalloc", 31 | cache_entries = { 32 | "CMAKE_C_FLAGS": "-fPIC", 33 | "MI_OVERRIDE": "ON", 34 | "MI_USE_CXX": "ON", 35 | "MI_BUILD_STATIC": "ON", 36 | "MI_BUILD_TESTS": "OFF", 37 | "MI_INSTALL_TOPLEVEL": "ON", 38 | "CMAKE_INSTALL_LIBDIR": "libdir", 39 | }, 40 | out_lib_dir = "libdir", 41 | lib_source = "@mimalloc//:all", 42 | out_static_libs = select({ 43 | ":optimized_build": ["libmimalloc.a"], 44 | ":debug_build": ["libmimalloc-debug.a"], 45 | }), 46 | visibility = ["//visibility:public"], 47 | ) 48 | -------------------------------------------------------------------------------- /config/openssl.BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT License. 3 | 4 | load("@rules_foreign_cc//foreign_cc:defs.bzl", "configure_make") 5 | 6 | filegroup( 7 | name = "all_srcs", 8 | srcs = glob( 9 | include = ["**"], 10 | exclude = ["*.bazel"], 11 | ), 12 | ) 13 | 14 | configure_make( 15 | name = "openssl", 16 | configure_command = "config", 17 | configure_in_place = True, 18 | env = select({ 19 | "@platforms//os:macos": { 20 | "AR": "", 21 | "PERL": "$$EXT_BUILD_ROOT$$/$(PERL)", 22 | }, 23 | "//conditions:default": { 24 | "PERL": "$$EXT_BUILD_ROOT$$/$(PERL)", 25 | }, 26 | }), 27 | lib_name = "openssl", 28 | lib_source = ":all_srcs", 29 | out_static_libs = [ 30 | "libssl.a", 31 | "libcrypto.a", 32 | ], 33 | out_lib_dir = select({ 34 | "@platforms//os:macos": "../copy_openssl/openssl/lib", 35 | "//conditions:default": "../copy_openssl/openssl/lib64", 36 | }), 37 | targets = ["install_sw"], 38 | toolchains = ["@rules_perl//:current_toolchain"], 39 | visibility = ["//visibility:public"], 40 | ) 41 | -------------------------------------------------------------------------------- /config/openssl_deps.bzl: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """A module defining the third party dependency OpenSSL""" 5 | 6 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 7 | load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") 8 | 9 | def openssl_deps(): 10 | maybe( 11 | http_archive, 12 | name = "rules_perl", 13 | sha256 = "7ad2510e54d530f75058e55f38e3e44acb682d65051514be88636adb1779b383", 14 | strip_prefix = "rules_perl-0.2.1", 15 | urls = [ 16 | "https://github.com/bazelbuild/rules_perl/archive/refs/tags/0.2.1.tar.gz", 17 | ], 18 | ) 19 | -------------------------------------------------------------------------------- /config/openssl_setup.bzl: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # openssl needs Perl and Text::Template module https://github.com/openssl/openssl/blob/master/INSTALL.md#prerequisites 5 | load("@rules_perl//perl:deps.bzl", "perl_register_toolchains", "perl_rules_dependencies") 6 | 7 | def openssl_setup(): 8 | perl_rules_dependencies() 9 | perl_register_toolchains() 10 | -------------------------------------------------------------------------------- /config/variables.bzl: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Shared variables for build configuration based on the host platform.""" 5 | 6 | CXX_OPTS = select({ 7 | "@platforms//os:macos": [ 8 | "-std=c++20", 9 | "-Werror", 10 | "-fvisibility=hidden", 11 | "-fvisibility-inlines-hidden", 12 | "-Wno-error=non-pod-varargs", 13 | ], 14 | "@platforms//os:windows": ["/std:c++20", "/W3", "/guard:cf", "/Qspectre"], 15 | "//conditions:default": ["-std=c++20", "-Werror", "-fvisibility=hidden", "-fvisibility-inlines-hidden"], 16 | }) 17 | 18 | PLATFORM_DEFINES = select({ 19 | "@platforms//os:linux": ["SNARK_PLATFORM_LINUX"], 20 | "@platforms//os:windows": ["SNARK_PLATFORM_WINDOWS"], 21 | "//conditions:default": [], 22 | }) 23 | -------------------------------------------------------------------------------- /docs/advanced/index.rst: -------------------------------------------------------------------------------- 1 | Advanced 2 | ======== 3 | 4 | 5 | .. 
toctree:: 6 | :maxdepth: 2 7 | 8 | hdfs 9 | sql_shard 10 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Configuration for the Sphinx documentation builder.""" 5 | project = "DeepGNN" 6 | copyright = "2022, Microsoft" 7 | author = "DeepGNN team" 8 | 9 | # The short X.Y version 10 | version = "0.1" 11 | # The full version, including alpha/beta/rc tags 12 | release = "0.1" 13 | 14 | extensions = [ 15 | "sphinx.ext.autodoc", 16 | "sphinx.ext.doctest", 17 | "sphinx.ext.todo", 18 | "sphinx.ext.coverage", 19 | "sphinx_copybutton", 20 | ] 21 | 22 | templates_path = ["_templates"] 23 | 24 | source_suffix = ".rst" 25 | master_doc = "index" 26 | language = None 27 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 28 | pygments_style = None 29 | html_theme = "alabaster" 30 | html_static_path = ["../_build/_static"] 31 | htmlhelp_basename = "DeepGNNdoc" 32 | man_pages = [(master_doc, "deepgnn", "DeepGNN Documentation", [author], 1)] 33 | 34 | copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " 35 | copybutton_prompt_is_regexp = True 36 | 37 | todo_include_todos = False 38 | -------------------------------------------------------------------------------- /docs/doctest_template.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import doctest 5 | import os 6 | import platform 7 | import sys 8 | import deepgnn 9 | 10 | if __name__ == "__main__": 11 | res = doctest.testfile(sys.argv[1], report=True, optionflags=doctest.ELLIPSIS) 12 | assert res.failed == 0 13 | -------------------------------------------------------------------------------- /docs/graph_engine/index.rst: -------------------------------------------------------------------------------- 1 | Graph Engine 2 | ============ 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | 8 | overview 9 | data_spec 10 | from_networkx 11 | custom_decoder 12 | spark_converter 13 | temporal 14 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to DeepGNN's documentation! 2 | =================================== 3 | DeepGNN is a package for training/evaluating ML models on graph data. It is a Python library that provides: 4 | 5 | * A graph engine object designed for ML tasks with an assortment of routines for sampling nodes, edges and neighbors as well as feature fetching. 6 | * Various aggregators, encoders and decoders to pass graph data to neural nets. 7 | * Basic NN layers for training: convolution, attention and bindings to pytorch-geometric library. 8 | * A collection of trainers to work with models in local and distributed environments. 9 | 10 | Documentation 11 | ------------- 12 | .. toctree:: 13 | :maxdepth: 2 14 | :titlesonly: 15 | 16 | graph_engine/index 17 | 18 | torch/index 19 | tf/index 20 | 21 | advanced/index 22 | -------------------------------------------------------------------------------- /docs/make_docs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
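# Builds the HTML docs with the same Sphinx configuration that Read the Docs
# uses (docs/conf.py; see .readthedocs.yaml). `-n` turns on nitpicky warnings
# about broken references, and output goes to the _build directory.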
3 | 4 | """Helper class to generate docs with sphinx.""" 5 | import sphinx.cmd.build as build 6 | 7 | build.build_main(argv=["-n", "-b", "html", "docs", "_build"]) 8 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | alabaster==0.7.12 2 | babel==2.10.3 3 | certifi==2024.7.4 4 | chardet==3.0.4 5 | charset-normalizer==2.1.1 6 | docutils==0.18.1 7 | idna==3.7 8 | imagesize==1.2.0 9 | importlib-metadata==4.6.1 10 | Jinja2==3.0.1 11 | MarkupSafe==2.0.1 12 | packaging==21.3 13 | Pygments==2.15.0 14 | pyparsing==2.4.7 15 | pytz==2021.1 16 | requests==2.32.2 17 | setuptools==70.0.0 18 | six==1.15.0 19 | snowballstemmer==2.2.0 20 | sphinx==1.8.5 21 | sphinx-copybutton==0.5.0 22 | sphinxcontrib-applehelp==1.0.2 23 | sphinxcontrib-devhelp==1.0.2 24 | sphinxcontrib-htmlhelp==2.0.0 25 | sphinxcontrib-jsmath==1.0.1 26 | sphinxcontrib-qthelp==1.0.3 27 | sphinxcontrib-serializinghtml==1.1.5 28 | sphinxcontrib-websupport==1.2.4 29 | urllib3==1.26.19 30 | zipp==3.19.1 31 | -------------------------------------------------------------------------------- /docs/tf/index.rst: -------------------------------------------------------------------------------- 1 | Tensorflow 2 | ========== 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | 8 | node_class 9 | link_pred 10 | ray_usage 11 | -------------------------------------------------------------------------------- /docs/torch/index.rst: -------------------------------------------------------------------------------- 1 | Pytorch 2 | ======= 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | 8 | link_pred 9 | node_class 10 | distrib 11 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """DeepGNN model examples.""" 4 | -------------------------------------------------------------------------------- /examples/hdfs_setup.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | sudo apt-get update 5 | sudo apt-get install openjdk-8-jre 6 | export HADOOP_VERSION=hadoop-3.3.6 7 | wget -nc https://dlcdn.apache.org/hadoop/common/$HADOOP_VERSION/$HADOOP_VERSION.tar.gz 8 | tar -xvzf $HADOOP_VERSION.tar.gz 9 | export HADOOP_HOME=$(pwd)/$HADOOP_VERSION 10 | export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 11 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$JAVA_HOME/jre/lib/amd64/server/" 12 | -------------------------------------------------------------------------------- /examples/pytorch/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | load("@rules_python//python:defs.bzl", "py_binary") 4 | load("@pip_deps//:requirements.bzl", "requirement") 5 | 6 | common_deps = [ 7 | "//src/python/deepgnn:deepgnn_ge_wheel_library", 8 | requirement("azure-datalake-store"), 9 | requirement("fsspec"), 10 | requirement("grpcio"), 11 | requirement("msgpack"), 12 | requirement("numpy"), 13 | requirement("networkx"), 14 | requirement("opencensus"), 15 | requirement("opencensus-context"), 16 | requirement("opencensus-ext-azure"), 17 | requirement("packaging"), 18 | requirement("pyyaml"), 19 | requirement("ray"), 20 | requirement("referencing"), 21 | requirement("rpds-py"), 22 | requirement("scikit-learn"), 23 | requirement("torch"), 24 | requirement("torch_geometric"), 25 | requirement("tenacity"), 26 | ] 27 | 28 | sparse_deps = common_deps + [ 29 | requirement("torch-sparse"), 30 | requirement("torch-scatter"), 31 | requirement("torch-cluster"), 32 | ] 33 | 34 | py_binary( 35 | name = "gcn", 36 | srcs = [ 37 | "gcn.py", 38 | ], 39 | deps = sparse_deps, 40 | ) 41 | 42 | py_binary( 43 | name = "gat", 44 | srcs = [ 45 | "gat.py", 46 | ], 47 | deps = common_deps, 48 | ) 49 | 50 | py_binary( 51 | name = "tgn", 52 | srcs = [ 53 | "tgn.py", 54 | ], 55 | deps = sparse_deps, 56 | ) 57 | 58 | py_binary( 59 | name = "sage", 60 | srcs = [ 61 | "sage.py", 62 | ], 63 | deps = sparse_deps, 64 | ) 65 | 66 | py_binary( 67 | name = "aml", 68 | srcs = [ 69 | "aml.py", 70 | ], 71 | deps = common_deps, 72 | ) 73 | 74 | py_binary( 75 | name = "pyg_interface", 76 | srcs = [ 77 | "pyg_interface.py", 78 | ], 79 | deps = sparse_deps, 80 | ) 81 | -------------------------------------------------------------------------------- /examples/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """DeepGNN Torch model examples.""" 4 | -------------------------------------------------------------------------------- /examples/pytorch/hetgnn/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | load("@rules_python//python:defs.bzl", "py_binary") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_binary( 8 | name = "main", 9 | srcs = [ 10 | "__init__.py", 11 | "evaluation.py", 12 | "graph.py", 13 | "main.py", 14 | "model.py", 15 | "sampler.py", 16 | ], 17 | deps = [ 18 | "//src/python/deepgnn:deepgnn_ge_wheel_library", 19 | requirement("numpy"), 20 | requirement("scikit-learn"), 21 | requirement("fsspec"), 22 | requirement("networkx"), 23 | requirement("opencensus"), 24 | requirement("opencensus-context"), 25 | requirement("opencensus-ext-azure"), 26 | requirement("azure-datalake-store"), 27 | requirement("torch"), 28 | requirement("torch_geometric"), 29 | requirement("tenacity"), 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /examples/pytorch/hetgnn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """DeepGNN Torch HetGNN model example.""" 4 | -------------------------------------------------------------------------------- /examples/tensorflow/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | """DeepGNN TF model examples.""" 4 | -------------------------------------------------------------------------------- /examples/tensorflow/gat/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_library( 8 | name = "example_tf_gat", 9 | srcs = [ 10 | "gat.py", 11 | "main.py", 12 | ], 13 | deps = [ 14 | "//src/python/deepgnn/graph_engine/backends:graph_engine_backends", 15 | "//src/python/deepgnn/graph_engine/snark:graph_engine_snark", 16 | "//src/python/deepgnn/graph_engine/data:graph_engine_data", 17 | "//src/python/deepgnn/tf:deepgnn_tf", 18 | "//src/python/deepgnn/tf/common:deepgnn_tf_common", 19 | "//src/python/deepgnn/tf/nn:deepgnn_tf_nn", 20 | ], 21 | deprecation = "This target is deprecated", 22 | ) 23 | 24 | py_test( 25 | name = "test_gat", 26 | srcs = ["test_gat.py"], 27 | imports = ["../../../src/python/"], 28 | main = "test_gat.py", 29 | python_version = "PY3", 30 | srcs_version = "PY3", 31 | deps = [ 32 | ":example_tf_gat", 33 | requirement("numpy"), 34 | requirement("pytest"), 35 | requirement("scikit-learn"), 36 | requirement("tensorflow-addons"), 37 | requirement("keras"), 38 | requirement("tensorflow"), 39 | requirement("fsspec"), 40 | requirement("networkx"), 41 | requirement("opencensus"), 42 | requirement("opencensus-context"), 43 | requirement("opencensus-ext-azure"), 44 | requirement("azure-datalake-store"), 45 | requirement("tenacity"), 46 | ], 47 | tags = ["manual"], 48 | deprecation = "This test is deprecated", 49 | ) 50 | -------------------------------------------------------------------------------- /examples/tensorflow/gat/README.md: -------------------------------------------------------------------------------- 1 | ## Graph Attention Networks (GAT) 2 | - Reference : [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903) 3 | - Author's code: [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT) 4 | 5 | ### How to run 6 | see [run.sh](./run.sh) 7 | 8 | 9 | #### Results 10 | 11 | | Dataset | Test Accuracy | Baseline (Paper) | 12 | | -------- | ------------- | ---------------- | 13 | | Cora | 83.3 | 83.0 (+/-0.5) | 14 | | Citeseer | 71.8 | 72.5 (+/-0.7) | 15 | 16 | Training time: (300 epochs) 17 | | Dataset | CPU | GPU | 18 | | -------- | ---- | --- | 19 | | Cora | 21.9 | 6.1 | 20 | | Citeseer | 30.7 | 6.8 | 21 | 22 | Cora test 23 | ```shell 24 | 25 | # prepare cora dataset 26 | python -m deepgnn.graph_engine.data.citation --dataset cora --data_dir /tmp/citation/cora 27 | 28 | # train 29 | python3 /examples/tensorflow/gat/main.py --mode train --seed 123 --model_dir /tmp/tmp70oef8fd --data_dir /tmp/citation/cora --eager --batch_size 140 --learning_rate 0.005 --epochs 300 --neighbor_edge_types 0 --attn_drop 0.6 --ffd_drop 0.6 --head_num 8,1 --l2_coef 0.0005 --hidden_dim 8 --gpu --feature_idx 0 --feature_dim 1433 --label_idx 1 --label_dim 1 --num_classes 7 --prefetch_worker_size 1 --log_save_steps 20 --summary_save_steps 1 30 | 31 | # evaluate 32 | python3 /examples/tensorflow/gat/main.py --mode evaluate --seed 123 --model_dir /tmp/tmp70oef8fd --data_dir /tmp/citation/cora --eager --batch_size 1000 --evaluate_node_files /tmp/citation/cora/test.nodes --neighbor_edge_types 0 --attn_drop 0.0 --ffd_drop 0.0 --head_num 8,1 --l2_coef 0.0005 --hidden_dim 8 --gpu --feature_idx 0 --feature_dim 
1433 --label_idx 1 --label_dim 1 --num_classes 7 --prefetch_worker_size 1 --log_save_steps 1 --summary_save_steps 1
33 | ```
34 | 
35 | ### Run GAT with your graph
36 | * [Prepare Graph Data](../../../docs/graph_engine/data_spec.rst)
37 | 
--------------------------------------------------------------------------------
/examples/tensorflow/gat/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | """DeepGNN TF GAT model example."""
4 | 
--------------------------------------------------------------------------------
/examples/tensorflow/gat/run.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | set -ex
5 | 
6 | DIR_NAME=$(dirname "$0")
7 | 
8 | ## DEVICE support: ["cpu", "gpu"]
9 | DEVICE=${1:-cpu}
10 | 
11 | if [[ ${DEVICE} == "gpu" ]]
12 | then
13 |     PLATFORM_DEVICE=--gpu
14 |     export CUDA_VISIBLE_DEVICES=0
15 | fi
16 | 
17 | CLEANUP=${2:-"no_cleanup"}
18 | 
19 | DATA_DIR=/tmp/cora/
20 | python -m deepgnn.graph_engine.data.citation --data_dir $DATA_DIR
21 | 
22 | BASE_DIR=$HOME/tmp/gat-cora-$(date +"%Y%m%d_%H%M%N")
23 | ### =================== EagerTrainer (single worker) ===================
24 | MODEL_DIR=$BASE_DIR/eager
25 | rm -rf $MODEL_DIR
26 | python $DIR_NAME/main.py \
27 |     --mode train --trainer ps \
28 |     --seed 123 \
29 |     --model_dir $MODEL_DIR \
30 |     --data_dir $DATA_DIR \
31 |     --eager \
32 |     --batch_size 140 \
33 |     --learning_rate 0.005 \
34 |     --epochs 300 \
35 |     --neighbor_edge_types 0 \
36 |     --attn_drop 0.6 \
37 |     --ffd_drop 0.6 \
38 |     --head_num 8,1 \
39 |     --l2_coef 0.0005 \
40 |     --hidden_dim 8 \
41 |     --feature_idx 0 \
42 |     --feature_dim 1433 \
43 |     --label_idx 1 \
44 |     --label_dim 1 \
45 |     --num_classes 7 \
46 |     --backend snark \
47 |     --converter skip ${PLATFORM_DEVICE}
48 | 
49 | ### =================== HorovodEagerTrainer (single worker) ===================
50 | MODEL_DIR=$BASE_DIR/hvd
51 | rm -rf $MODEL_DIR
52 | python $DIR_NAME/main.py \
53 |     --mode train --trainer hvd \
54 |     --seed 123 \
55 |     --model_dir $MODEL_DIR \
56 |     --data_dir $DATA_DIR \
57 |     --eager \
58 |     --batch_size 140 \
59 |     --learning_rate 0.005 \
60 |     --epochs 300 \
61 |     --neighbor_edge_types 0 \
62 |     --attn_drop 0.6 \
63 |     --ffd_drop 0.6 \
64 |     --head_num 8,1 \
65 |     --l2_coef 0.0005 \
66 |     --hidden_dim 8 \
67 |     --feature_idx 0 \
68 |     --feature_dim 1433 \
69 |     --label_idx 1 \
70 |     --label_dim 1 \
71 |     --num_classes 7 \
72 |     --backend snark \
73 |     --converter skip ${PLATFORM_DEVICE}
74 | 
75 | 
76 | if [[ "${CLEANUP}" != "no_cleanup" ]]; then
77 |     rm -rf $BASE_DIR
78 | fi
79 | 
--------------------------------------------------------------------------------
/examples/tensorflow/gcn/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
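# Same layout as the GAT example above: a py_library bundles the model and its
# entry point, and a deprecated py_test pulls its runtime requirements from the
# pinned pip lockfile via requirement(...).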
3 | 4 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_library( 8 | name = "example_tf_gcn", 9 | srcs = [ 10 | "gcn.py", 11 | "main.py", 12 | ], 13 | deps = [ 14 | "//src/python/deepgnn/graph_engine/backends:graph_engine_backends", 15 | "//src/python/deepgnn/graph_engine/snark:graph_engine_snark", 16 | "//src/python/deepgnn/graph_engine/data:graph_engine_data", 17 | "//src/python/deepgnn/tf:deepgnn_tf", 18 | "//src/python/deepgnn/tf/common:deepgnn_tf_common", 19 | "//src/python/deepgnn/tf/nn:deepgnn_tf_nn", 20 | ], 21 | deprecation = "This target is deprecated", 22 | ) 23 | 24 | py_test( 25 | name = "test_gcn", 26 | srcs = ["test_gcn.py"], 27 | imports = ["../../../src/python/"], 28 | main = "test_gcn.py", 29 | python_version = "PY3", 30 | srcs_version = "PY3", 31 | deps = [ 32 | ":example_tf_gcn", 33 | requirement("numpy"), 34 | requirement("pytest"), 35 | requirement("scikit-learn"), 36 | requirement("fsspec"), 37 | requirement("tensorflow"), 38 | requirement("tensorflow-addons"), 39 | requirement("keras"), 40 | requirement("networkx"), 41 | requirement("opencensus"), 42 | requirement("opencensus-context"), 43 | requirement("opencensus-ext-azure"), 44 | requirement("azure-datalake-store"), 45 | requirement("tenacity"), 46 | ], 47 | tags = ["manual"], 48 | deprecation = "This test is deprecated", 49 | ) 50 | -------------------------------------------------------------------------------- /examples/tensorflow/gcn/README.md: -------------------------------------------------------------------------------- 1 | ## Graph Convolutional Networks (GCN) 2 | - Reference : https://arxiv.org/abs/1609.02907 3 | - Author's code: https://github.com/tkipf/gcn 4 | 5 | ### How to run 6 | see [run.sh](./run.sh) 7 | 8 | 9 | #### Results 10 | 11 | | Dataset | Test Accuracy | Baseline (Paper) | 12 | | -------- | ------------- | ---------------- | 13 | | Cora | 80.8 | 81.5 | 14 | | Citeseer | 70.4 | 70.3 | 15 | 16 | 17 | Training time: (200 epochs) 18 | | Dataset | CPU | GPU | 19 | | -------- | ---- | --- | 20 | | Cora | 2.9 | 2.3 | 21 | | Citeseer | 4.0 | 2.5 | 22 | 23 | * CPU: E5-2690 v4 @ 2.60GHz (6 cores) 24 | * GPU: P100 16GB 25 | 26 | ### Run GCN with your graph 27 | * [Prepare Graph Data](../../../docs/graph_engine/data_spec.rst) 28 | -------------------------------------------------------------------------------- /examples/tensorflow/gcn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """DeepGNN GCN model example.""" 4 | -------------------------------------------------------------------------------- /examples/tensorflow/gcn/run.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 
4 | set -ex
5 | 
6 | DIR_NAME=$(dirname "$0")
7 | 
8 | ## DEVICE support: ["cpu", "gpu"]
9 | DEVICE=${1:-cpu}
10 | 
11 | if [[ ${DEVICE} == "gpu" ]]
12 | then
13 |     PLATFORM_DEVICE=--gpu
14 |     export CUDA_VISIBLE_DEVICES=0
15 | fi
16 | 
17 | CLEANUP=${2:-"no_cleanup"}
18 | 
19 | DATA_DIR=/tmp/cora/
20 | python -m deepgnn.graph_engine.data.citation --data_dir $DATA_DIR
21 | 
22 | BASE_DIR=$HOME/tmp/gcn-cora-$(date +"%Y%m%d_%H%M%N")
23 | ### =================== EagerTrainer (single worker) ===================
24 | MODEL_DIR=$BASE_DIR/eager
25 | rm -rf $MODEL_DIR
26 | python $DIR_NAME/main.py \
27 |     --mode train --trainer ps \
28 |     --seed 123 \
29 |     --model_dir $MODEL_DIR \
30 |     --data_dir $DATA_DIR \
31 |     --eager \
32 |     --batch_size 140 \
33 |     --learning_rate 0.01 \
34 |     --epochs 200 \
35 |     --neighbor_edge_types 0 \
36 |     --dropout 0.5 \
37 |     --l2_coef 0.0005 \
38 |     --hidden_dim 16 \
39 |     --feature_idx 0 \
40 |     --feature_dim 1433 \
41 |     --label_idx 1 \
42 |     --label_dim 1 \
43 |     --num_classes 7 \
44 |     --backend snark \
45 |     --converter skip ${PLATFORM_DEVICE}
46 | 
47 | python $DIR_NAME/main.py \
48 |     --mode evaluate --trainer ps \
49 |     --seed 123 \
50 |     --model_dir $MODEL_DIR \
51 |     --data_dir $DATA_DIR \
52 |     --eager \
53 |     --batch_size 1000 \
54 |     --evaluate_node_files $DATA_DIR/test.nodes \
55 |     --neighbor_edge_types 0 \
56 |     --dropout 0.0 \
57 |     --hidden_dim 16 \
58 |     --feature_idx 0 \
59 |     --feature_dim 1433 \
60 |     --label_idx 1 \
61 |     --label_dim 1 \
62 |     --num_classes 7 \
63 |     --backend snark \
64 |     --converter skip ${PLATFORM_DEVICE}
65 | 
66 | if [[ "${CLEANUP}" != "no_cleanup" ]]; then
67 |     rm -rf $BASE_DIR
68 | fi
69 | 
--------------------------------------------------------------------------------
/examples/tensorflow/han/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 4 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_library( 8 | name = "example_tf_han", 9 | srcs = [ 10 | "han.py", 11 | "main.py", 12 | ], 13 | deps = [ 14 | "//src/python/deepgnn/graph_engine/backends:graph_engine_backends", 15 | "//src/python/deepgnn/graph_engine/snark:graph_engine_snark", 16 | "//src/python/deepgnn/graph_engine/data:graph_engine_data", 17 | "//src/python/deepgnn/tf:deepgnn_tf", 18 | "//src/python/deepgnn/tf/common:deepgnn_tf_common", 19 | "//src/python/deepgnn/tf/encoders:deepgnn_tf_encoders", 20 | "//src/python/deepgnn/tf/nn:deepgnn_tf_nn", 21 | ], 22 | deprecation = "This target is deprecated", 23 | ) 24 | 25 | py_test( 26 | name = "test_han", 27 | srcs = ["test_han.py"], 28 | imports = ["../../../src/python/"], 29 | main = "test_han.py", 30 | python_version = "PY3", 31 | srcs_version = "PY3", 32 | deps = [ 33 | ":example_tf_han", 34 | requirement("numpy"), 35 | requirement("fsspec"), 36 | requirement("pytest"), 37 | requirement("scikit-learn"), 38 | requirement("tensorflow"), 39 | requirement("tensorflow-addons"), 40 | requirement("keras"), 41 | requirement("networkx"), 42 | requirement("opencensus"), 43 | requirement("opencensus-context"), 44 | requirement("opencensus-ext-azure"), 45 | requirement("azure-datalake-store"), 46 | requirement("tenacity"), 47 | ], 48 | tags = ["manual"], 49 | deprecation = "This test is deprecated", 50 | ) 51 | -------------------------------------------------------------------------------- /examples/tensorflow/han/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | Reference: [Paper](https://arxiv.org/pdf/1903.07293). 3 | 4 | # Generate Graph Data 5 | * [Prepare Graph Data](../../../docs/graph_engine/data_spec.rst) 6 | 7 | # Job Augmentations 8 | ## HAN 9 | Training 10 | > --mode=train --model=han --num_epochs=2 --max_id=56994 --feature_idx=1 --learning_rate=0.001 --feature_dim=50 --label_idx=0 --label_dim=121 --fanouts=10 --all_edge_type=0,1 --batch_size=300 --head_num=8 --layer_dims=8 11 | 12 | ### Evaluate 13 | > --mode=evaluate --model=han --max_id=56994 --feature_idx=1 --learning_rate=0.001 --feature_dim=50 --label_idx=0 --label_dim=121 --fanouts=10 --all_edge_type=0,1 --batch_size=300 --head_num=8 --layer_dims=8 14 | 15 | ### Inference 16 | > --mode=inference --model=han --max_id=56994 --feature_idx=1 --learning_rate=0.001 --feature_dim=50 --label_idx=0 --label_dim=121 --fanouts=10 --all_edge_type=0,1 --batch_size=300 --head_num=8 --layer_dims=8 17 | 18 | # Parameters 19 | Code reference: 20 | - Create [models.HAN(source code)](https://github.com/microsoft/DeepGNN/blob/main/examples/tensorflow/han/model.py) 21 | 22 | | Parameters | Default | Description | 23 | | ----- | ----------- | ------- | 24 | | **mode** | train | Run mode. ["train", "evaluate", "save_embedding", "inference"] | 25 | | **model_dir** | ckpt | Model checkpoint. | 26 | | **num_epochs** | 20 | Number of epochs for training. | 27 | | **batch_size** | 512 | Mini-batch size. | 28 | | **learning_rate** | 0.01 | Learning rate. | 29 | | **optimizer** | adam | TF Optimizer. ["adagrad", "adam", "adadelta", "rmsprop", "ftrl", "sgd", "momentum"] | 30 | | **feature_idx** | -1 | Feature index. | 31 | | **feature_dim** | 0 | Feature dimension. | 32 | | **max_id** | -1 | Max node id. | 33 | | **use_id** | False | Whether to use identity feature. | 34 | | **label_idx** | -1 | Label index. 
35 | | **label_dim** | 0 | Label dimension. |
36 | | **all_edge_type** | [0] | All edge types of training set for HAN metapath. |
37 | | **fanouts** | [10, 10] | Neighbor fanout parameters for one metapath. HAN supports multi-hop neighbors. |
38 | | **head_num** | [1] | Number of attention heads for each layer. HAN can support multiple layers. |
39 | | **layer_dims** | [8] | Hidden dimension for each layer. |
40 | 
--------------------------------------------------------------------------------
/examples/tensorflow/han/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | """DeepGNN HAN model example."""
4 | 
--------------------------------------------------------------------------------
/examples/tensorflow/han/run.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | set -ex
5 | 
6 | DIR_NAME=$(dirname "$0")
7 | 
8 | ## DEVICE support: ["cpu", "gpu"]
9 | DEVICE=${1:-cpu}
10 | 
11 | if [[ ${DEVICE} == "gpu" ]]
12 | then
13 |     PLATFORM_DEVICE=--gpu
14 |     export CUDA_VISIBLE_DEVICES=0
15 | fi
16 | 
17 | CLEANUP=${2:-"no_cleanup"}
18 | 
19 | MODEL_DIR=$HOME/tmp/han_local-$(date +"%Y%m%d_%H%M%N")
20 | DATA_DIR=/tmp/cora/
21 | python -m deepgnn.graph_engine.data.citation --data_dir $DATA_DIR
22 | 
23 | ## training, save model checkpoint to $MODEL_DIR
24 | rm -rf $MODEL_DIR
25 | python3 $DIR_NAME/main.py --mode train \
26 |     --model_dir $MODEL_DIR \
27 |     --data_dir $DATA_DIR \
28 |     --training_node_types 0 \
29 |     --edge_types "0;1" \
30 |     --num_nodes 300 \
31 |     --feature_idx 0 \
32 |     --feature_dim 1870 \
33 |     --batch_size 24 \
34 |     --learning_rate 0.001 \
35 |     --epochs 10 \
36 |     --head_num 8 \
37 |     --hidden_dim 128 \
38 |     --label_idx 1 \
39 |     --label_dim 3 \
40 |     --fanouts 10 \
41 |     --seed 123 \
42 |     --backend snark \
43 |     --converter skip ${PLATFORM_DEVICE}
44 | 
45 | 
46 | ## Evaluation
47 | python3 $DIR_NAME/main.py --mode evaluate \
48 |     --model_dir $MODEL_DIR \
49 |     --data_dir $DATA_DIR \
50 |     --evaluate_node_types 1 \
51 |     --edge_types "0;1" \
52 |     --num_nodes 2000 \
53 |     --feature_idx 0 \
54 |     --feature_dim 1870 \
55 |     --batch_size 24 \
56 |     --head_num 8 \
57 |     --hidden_dim 128 \
58 |     --label_idx 1 \
59 |     --label_dim 3 \
60 |     --fanouts 10 \
61 |     --seed 123 \
62 |     --backend snark \
63 |     --converter skip ${PLATFORM_DEVICE}
64 | 
65 | ## Inference
66 | python3 $DIR_NAME/main.py --mode inference \
67 |     --model_dir $MODEL_DIR \
68 |     --data_dir $DATA_DIR \
69 |     --edge_types "0;1" \
70 |     --num_nodes 2000 \
71 |     --feature_idx 0 \
72 |     --feature_dim 1870 \
73 |     --batch_size 24 \
74 |     --head_num 8 \
75 |     --hidden_dim 128 \
76 |     --fanouts 10 \
77 |     --seed 123 \
78 |     --backend snark \
79 |     --converter skip ${PLATFORM_DEVICE}
80 | 
81 | if [[ "${CLEANUP}" != "no_cleanup" ]]; then
82 |     rm -rf $MODEL_DIR
83 | fi
84 | 
--------------------------------------------------------------------------------
/examples/tensorflow/requirements.txt:
--------------------------------------------------------------------------------
1 | # Different TF versions require different tensorflow-addons versions;
2 | # please follow the instructions at https://github.com/tensorflow/addons
3 | # to find the correct addons version.
4 | horovod[tensorflow]==0.28.1;sys_platform != 'win32' 5 | tensorflow-addons==0.21.0 6 | -------------------------------------------------------------------------------- /examples/tensorflow/sage/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_library( 8 | name = "example_tf_sage", 9 | srcs = [ 10 | "main.py", 11 | "main_linkprediction.py", 12 | "main_unsup.py", 13 | "sage_linkprediction.py", 14 | "sage_unsupervised.py", 15 | "sage.py", 16 | ], 17 | deps = [ 18 | "//src/python/deepgnn/graph_engine/backends:graph_engine_backends", 19 | "//src/python/deepgnn/graph_engine/snark:graph_engine_snark", 20 | "//src/python/deepgnn/graph_engine/data:graph_engine_data", 21 | "//src/python/deepgnn/tf:deepgnn_tf", 22 | "//src/python/deepgnn/tf/common:deepgnn_tf_common", 23 | "//src/python/deepgnn/tf/nn:deepgnn_tf_nn", 24 | ], 25 | tags = ["manual"], 26 | deprecation = "This target is deprecated", 27 | ) 28 | 29 | py_test( 30 | name = "test_sage", 31 | srcs = ["test_sage.py"], 32 | imports = ["../../../src/python/"], 33 | main = "test_sage.py", 34 | python_version = "PY3", 35 | srcs_version = "PY3", 36 | deps = [ 37 | ":example_tf_sage", 38 | requirement("numpy"), 39 | requirement("pytest"), 40 | requirement("scikit-learn"), 41 | requirement("tensorflow"), 42 | requirement("fsspec"), 43 | requirement("tensorflow-addons"), 44 | requirement("networkx"), 45 | requirement("opencensus"), 46 | requirement("opencensus-context"), 47 | requirement("opencensus-ext-azure"), 48 | requirement("azure-datalake-store"), 49 | requirement("tenacity"), 50 | ], 51 | tags = ["manual"], 52 | deprecation = "This test is deprecated", 53 | ) 54 | 55 | 56 | py_test( 57 | name = "test_sage_link", 58 | srcs = ["test_sage_link.py"], 59 | imports = ["../../../src/python/"], 60 | main = "test_sage_link.py", 61 | python_version = "PY3", 62 | srcs_version = "PY3", 63 | deps = [ 64 | ":example_tf_sage", 65 | requirement("numpy"), 66 | requirement("pytest"), 67 | requirement("scikit-learn"), 68 | requirement("fsspec"), 69 | requirement("tensorflow"), 70 | requirement("tensorflow-addons"), 71 | requirement("networkx"), 72 | requirement("opencensus"), 73 | requirement("opencensus-context"), 74 | requirement("opencensus-ext-azure"), 75 | requirement("azure-datalake-store"), 76 | requirement("tenacity"), 77 | ], 78 | tags = ["manual"], 79 | deprecation = "This test is deprecated", 80 | ) 81 | -------------------------------------------------------------------------------- /examples/tensorflow/sage/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | __GraphSAGE__ is a framework for inductive representation learning on large graphs. GraphSAGE is used to generate low-dimensional vector representations for nodes, and is especially useful for graphs that have rich node attribute information. 
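A minimal NumPy sketch of the mean-aggregation idea (illustrative only; the
function and variable names here are hypothetical and not part of this repo):

```python
import numpy as np

def sage_mean_layer(h_self, h_neigh, w_self, w_neigh):
    """One GraphSAGE-style layer: combine a node's own features with the
    mean of its sampled neighbors' features, then apply a ReLU."""
    aggregated = h_neigh.mean(axis=1)  # (batch, dim)
    return np.maximum(h_self @ w_self + aggregated @ w_neigh, 0.0)

# Toy usage: 2 nodes, 3 sampled neighbors each, 4 input features, 8 hidden units.
rng = np.random.default_rng(0)
out = sage_mean_layer(
    rng.normal(size=(2, 4)),
    rng.normal(size=(2, 3, 4)),
    rng.normal(size=(4, 8)),
    rng.normal(size=(4, 8)),
)
```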
3 | 4 | Reference: [Inductive Representation Learning on Large Graphs](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf) 5 | 6 | 7 | ### How to run 8 | see [run.sh](./run.sh) 9 | 10 | 11 | ### Run GraphSage with your graph 12 | * [Prepare Graph Data](../../../docs/graph_engine/data_spec.rst) 13 | -------------------------------------------------------------------------------- /examples/tensorflow/sage/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """DeepGNN GraphSAGE model example.""" 4 | -------------------------------------------------------------------------------- /src/cc/lib/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | load("@rules_cc//cc:defs.bzl", "cc_binary") 4 | load("//config:variables.bzl", "CXX_OPTS") 5 | 6 | cc_binary( 7 | name = "wrapper", 8 | srcs = [ 9 | "py_graph.cc", 10 | "py_graph.h", 11 | "py_server.cc", 12 | ], 13 | copts = CXX_OPTS, 14 | defines = ["GLOG_NO_ABBREVIATED_SEVERITIES"], 15 | # ERROR macro is defined in glog and windows.h 16 | linkopts = select({ 17 | "@platforms//os:windows": ["/DEBUG", "/guard:cf"], 18 | # Explicitly list functions we want to export. 19 | "@platforms//os:macos": [ 20 | "-Wl,-exported_symbols_list,$(location version-script.darwin.lds)", 21 | ], 22 | "@platforms//os:linux": [ 23 | "-Wl,--version-script=$(location version-script.linux.lds)", 24 | "-fwhole-program", 25 | ], 26 | }), 27 | linkshared = True, 28 | linkstatic = True, 29 | visibility = ["//visibility:public"], 30 | deps = select({ 31 | "@platforms//os:windows": [ 32 | "//src/cc/lib/distributed:grpc", 33 | "//src/cc/lib/graph", 34 | "@com_github_google_glog//:glog", 35 | "@com_google_absl//absl/container:flat_hash_map", 36 | "@com_google_absl//absl/container:flat_hash_set", 37 | "@boost//:random", 38 | ], 39 | "@platforms//os:macos": [ 40 | "version-script.darwin.lds", 41 | "//src/cc/lib/distributed:grpc", 42 | "//src/cc/lib/graph", 43 | "@com_github_google_glog//:glog", 44 | "@com_google_absl//absl/container:flat_hash_map", 45 | "@com_google_absl//absl/container:flat_hash_set", 46 | "@boost//:random", 47 | ], 48 | "@platforms//os:linux": [ 49 | "@mimalloc//:mimalloc", # mimalloc should go first to ensure malloc is overridden everywhere 50 | "version-script.linux.lds", 51 | "//src/cc/lib/distributed:grpc", 52 | "//src/cc/lib/graph", 53 | "@com_github_google_glog//:glog", 54 | "@com_google_absl//absl/container:flat_hash_map", 55 | "@com_google_absl//absl/container:flat_hash_set", 56 | "@boost//:random", 57 | ], 58 | }), 59 | ) 60 | -------------------------------------------------------------------------------- /src/cc/lib/benchmark/search_benchmark.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 
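// These micro-benchmarks compare binary search (std::lower_bound) against a
// linear scan (std::find) over sorted id ranges; for very small ranges the
// linear scan is typically faster, which motivates the split threshold
// mentioned at the bottom of this file.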
3 | 
4 | #include <algorithm>
5 | #include <numeric>
6 | #include <random>
7 | #include <vector>
8 | 
9 | #include <benchmark/benchmark.h>
10 | 
11 | #ifdef SNARK_PLATFORM_LINUX
12 | #include <mimalloc-new-delete.h>
13 | #endif
14 | 
15 | static void BM_LOWER_BOUND(benchmark::State &state)
16 | {
17 |     size_t max_count = state.range(0);
18 |     std::vector<int64_t> elements(max_count);
19 |     std::iota(elements.begin(), elements.end(), 0);
20 |     std::vector<int64_t> ids(elements);
21 |     int64_t seed = 42;
22 |     std::shuffle(ids.begin(), ids.end(), std::mt19937_64(seed));
23 |     size_t curr = 0;
24 |     for (auto _ : state)
25 |     {
26 |         ++curr;
27 |         if (curr == max_count)
28 |         {
29 |             curr = 0;
30 |         }
31 |         const auto res = std::lower_bound(elements.cbegin(), elements.cend(), ids[curr]);
32 |         if (*res != ids[curr])
33 |         {
34 |             throw std::runtime_error("not found " + std::to_string(ids[curr]) + " got " + std::to_string(*res));
35 |         }
36 |     }
37 | }
38 | 
39 | static void BM_STD_FIND(benchmark::State &state)
40 | {
41 |     size_t max_count = state.range(0);
42 |     std::vector<int64_t> elements(max_count);
43 |     std::iota(elements.begin(), elements.end(), 0);
44 |     std::vector<int64_t> ids(elements);
45 |     int64_t seed = 42;
46 |     std::shuffle(ids.begin(), ids.end(), std::mt19937_64(seed));
47 |     size_t curr = 0;
48 |     for (auto _ : state)
49 |     {
50 |         ++curr;
51 |         if (curr == max_count)
52 |         {
53 |             curr = 0;
54 |         }
55 |         const auto res = std::find(elements.cbegin(), elements.cend(), ids[curr]);
56 |         if (*res != ids[curr])
57 |         {
58 |             throw std::runtime_error("not found " + std::to_string(ids[curr]) + " got " + std::to_string(*res));
59 |         }
60 |     }
61 | }
62 | 
63 | // Results of this benchmark are used to set the split value for edge feature fetching
64 | BENCHMARK(BM_LOWER_BOUND)->RangeMultiplier(4)->Range(1 << 3, 1 << 10);
65 | BENCHMARK(BM_STD_FIND)->RangeMultiplier(4)->Range(1 << 3, 1 << 10);
66 | 
67 | BENCHMARK_MAIN();
68 | 
--------------------------------------------------------------------------------
/src/cc/lib/distributed/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | load("@rules_cc//cc:defs.bzl", "cc_library")
4 | load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
5 | load("//config:variables.bzl", "CXX_OPTS")
6 | 
7 | proto_library(
8 |     name = "service_proto",
9 |     srcs = ["service.proto"],
10 |     visibility = ["//visibility:public"],
11 | )
12 | 
13 | cc_proto_library(
14 |     name = "service_cc_proto",
15 |     deps = [":service_proto"],
16 | )
17 | 
18 | cc_grpc_library(
19 |     name = "service_cc_grpc",
20 |     srcs = [":service_proto"],
21 |     grpc_only = True,
22 |     deps = [":service_cc_proto"],
23 | )
24 | 
25 | windows_deps = [
26 |     "//src/cc/lib/graph",
27 |     # Order matters on Windows: we want the latest abseil instead of the transitive version from grpc.
28 |     "@com_google_absl//absl/container:flat_hash_map",
29 |     "@com_google_absl//absl/container:flat_hash_set",
30 |     "@boost//:random",
31 |     ":service_cc_grpc",
32 |     "@com_github_grpc_grpc//:grpc++",
33 |     "@com_google_benchmark//:benchmark",
34 |     "@com_github_google_glog//:glog"
35 | ]
36 | 
37 | # Windows builds use a system-provided openssl defined in the .bazelrc file.
38 | macos_linux_deps = windows_deps + ["@openssl"] 39 | 40 | cc_library( 41 | name = "grpc", 42 | srcs = [ 43 | "call_data.cc", 44 | "client.cc", 45 | "graph_engine.cc", 46 | "graph_sampler.cc", 47 | "server.cc", 48 | ], 49 | hdrs = [ 50 | "call_data.h", 51 | "client.h", 52 | "graph_engine.h", 53 | "graph_sampler.h", 54 | "server.h", 55 | ], 56 | copts = CXX_OPTS, 57 | features = ["fully_static_link"], 58 | linkstatic = True, 59 | visibility = ["//visibility:public"], 60 | # ERROR macro is defined in glog and windows.h 61 | defines = ["GLOG_NO_ABBREVIATED_SEVERITIES", "OPENSSL_IS_BORINGSSL"], 62 | deps = select({ 63 | "@platforms//os:windows": windows_deps, 64 | "//conditions:default": macos_linux_deps, 65 | }), 66 | ) 67 | -------------------------------------------------------------------------------- /src/cc/lib/distributed/graph_sampler.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | #ifndef SNARK_GRAPH_SAMPLER_SERVICE_H 5 | #define SNARK_GRAPH_SAMPLER_SERVICE_H 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "absl/container/flat_hash_map.h" 12 | #include 13 | #include 14 | 15 | #include "src/cc/lib/distributed/service.grpc.pb.h" 16 | #include "src/cc/lib/graph/graph.h" 17 | #include "src/cc/lib/graph/logger.h" 18 | 19 | namespace snark 20 | { 21 | 22 | // Stub value to indicate there are no items for a sampler 23 | // and client is safe to skip requests to such shards. 24 | const uint64_t empty_sampler_id = std::numeric_limits::max(); 25 | 26 | class GraphSamplerServiceImpl final : public snark::GraphSampler::Service 27 | { 28 | public: 29 | GraphSamplerServiceImpl(snark::Metadata metadata, std::vector partition_paths, 30 | std::vector partition_indices, std::shared_ptr logger = nullptr); 31 | 32 | grpc::Status Create(::grpc::ServerContext *context, const snark::CreateSamplerRequest *request, 33 | snark::CreateSamplerReply *response) override; 34 | 35 | grpc::Status Sample(::grpc::ServerContext *context, const snark::SampleRequest *request, 36 | snark::SampleReply *response) override; 37 | 38 | private: 39 | absl::flat_hash_map> 40 | m_node_sampler_factory; 41 | absl::flat_hash_map> 42 | m_edge_sampler_factory; 43 | std::vector> m_samplers; 44 | snark::Metadata m_metadata; 45 | std::vector m_partition_indices; 46 | std::vector m_patrition_paths; 47 | std::mutex m_mutex; 48 | std::shared_ptr m_logger; 49 | }; 50 | 51 | } // namespace snark 52 | #endif // SNARK_GRAPH_SAMPLER_SERVICE_H 53 | -------------------------------------------------------------------------------- /src/cc/lib/distributed/server.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 
3 | 4 | #ifndef SNARK_SERVER_H 5 | #define SNARK_SERVER_H 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include "src/cc/lib/distributed/graph_engine.h" 15 | #include "src/cc/lib/distributed/graph_sampler.h" 16 | #include "src/cc/lib/graph/graph.h" 17 | 18 | namespace snark 19 | { 20 | class GRPCServer final 21 | { 22 | public: 23 | GRPCServer(std::shared_ptr engine_service_impl, 24 | std::shared_ptr sampler_service_impl, std::string host_name, 25 | std::string ssl_key, std::string ssl_cert, std::string ssl_root); 26 | 27 | ~GRPCServer(); 28 | 29 | std::shared_ptr InProcessChannel(); 30 | 31 | void HandleRpcs(size_t index); 32 | 33 | private: 34 | std::vector> m_cqs; 35 | 36 | // Sampler/Engine split helps us to manage runtime: 37 | // * Resource heavy components can be deployed to separate machine types. 38 | // * Adding new server side samplers doesn't require to restart a service 39 | // and interrupt existing clients, new clients can connect to old and new 40 | // endpoints. 41 | snark::GraphEngine::AsyncService m_engine_service; 42 | std::shared_ptr m_engine_service_impl; 43 | snark::GraphSampler::AsyncService m_sampler_service; 44 | std::shared_ptr m_sampler_service_impl; 45 | std::unique_ptr m_server; 46 | std::vector m_runner_threads; 47 | std::atomic m_shutdown; 48 | }; 49 | } // namespace snark 50 | #endif // SNARK_SERVER_H 51 | -------------------------------------------------------------------------------- /src/cc/lib/graph/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | load("@rules_cc//cc:defs.bzl", "cc_library") 4 | load("//config:variables.bzl", "CXX_OPTS", "PLATFORM_DEFINES") 5 | 6 | cc_library( 7 | name = "graph", 8 | srcs = [ 9 | "graph.cc", 10 | "locator.cc", 11 | "metadata.cc", 12 | "partition.cc", 13 | "sampler.cc", 14 | "hdfs_wrap.cc", 15 | "reservoir.cc", 16 | "logger.cc", 17 | ], 18 | hdrs = [ 19 | "graph.h", 20 | "locator.h", 21 | "logger.h", 22 | "metadata.h", 23 | "partition.h", 24 | "sampler.h", 25 | "storage.h", 26 | "hdfs_wrap.h", 27 | "types.h", 28 | "xoroshiro.h", 29 | "reservoir.h", 30 | "merger.h", 31 | ], 32 | copts = CXX_OPTS, 33 | linkopts = select({ 34 | "@platforms//os:linux": ["-ldl"], 35 | "//conditions:default": [], 36 | }), 37 | # ERROR macro is defined in glog and windows.h 38 | defines = PLATFORM_DEFINES + ["GLOG_NO_ABBREVIATED_SEVERITIES"], 39 | features = ["fully_static_link"], 40 | linkstatic = True, 41 | visibility = ["//visibility:public"], 42 | deps = [ 43 | "@boost//:random", 44 | "@json//:json", 45 | "@com_github_google_glog//:glog", 46 | "@com_google_absl//absl/container:flat_hash_map", 47 | "@com_google_absl//absl/container:flat_hash_set", 48 | "@com_google_absl//absl/container:inlined_vector", 49 | ], 50 | ) 51 | -------------------------------------------------------------------------------- /src/cc/lib/graph/hdfs_wrap.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 
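// The *_so types below mirror a minimal subset of the libhdfs C API so the
// graph library can bind to a Hadoop installation at runtime (note the -ldl
// linkopt in the graph BUILD file above) instead of taking a compile-time
// dependency on Hadoop headers.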
3 | 4 | #ifndef SNARK_HDFS_WRAP_H 5 | #define SNARK_HDFS_WRAP_H 6 | 7 | #include 8 | #include 9 | 10 | typedef int32_t hdfs_int; 11 | 12 | struct hdfs_internal_so; 13 | typedef struct hdfs_internal_so *hdfsFS_so; 14 | 15 | enum hdfsStreamType_so 16 | { 17 | HDFS_STREAM_UNINITIALIZED = 0, 18 | HDFS_STREAM_INPUT = 1, 19 | HDFS_STREAM_OUTPUT = 2, 20 | }; 21 | struct hdfsFile_internal_so 22 | { 23 | void *file; 24 | enum hdfsStreamType_so type; 25 | hdfs_int flags; 26 | }; 27 | typedef hdfsFile_internal_so *hdfsFile_so; 28 | 29 | typedef enum tObjectKind_so 30 | { 31 | kObjectKindFile_so = 'F', 32 | kObjectKindDirectory_so = 'D', 33 | } tObjectKind_so; 34 | 35 | typedef struct 36 | { 37 | tObjectKind_so mKind; /* file or directory */ 38 | char *mName; /* the name of the file */ 39 | time_t mLastMod; /* the last modification time for the file in seconds */ 40 | int64_t mSize; /* the size of the file in bytes */ 41 | short mReplication; /* the count of replicas */ 42 | int64_t mBlockSize; /* the block size for the file */ 43 | char *mOwner; /* the owner of the file */ 44 | char *mGroup; /* the group associated with the file */ 45 | short mPermissions; /* the permissions associated with the file */ 46 | time_t mLastAccess; /* the last access time for the file in seconds */ 47 | } hdfsFileInfo_so; 48 | 49 | class hdfsBindings; 50 | 51 | class HDFSConnection 52 | { 53 | public: 54 | HDFSConnection(); 55 | HDFSConnection(std::string data_path, std::string config_path); 56 | 57 | int64_t get_file_size(const char *path, const char *host, int port); 58 | 59 | std::vector list_directory(const char *full_path); 60 | 61 | hdfsFile_so open_file(const char *path); 62 | void close_file(hdfsFile_so readFile); 63 | void read(hdfsFile_so readFile, int64_t read_size, void *output); 64 | 65 | private: 66 | std::shared_ptr hdfs_bindings; 67 | hdfsFS_so fs = nullptr; 68 | std::string m_data_path = ""; 69 | void *m_buffer = nullptr; 70 | }; 71 | 72 | void parse_hdfs_path(std::string full_path, std::string &data_path, std::string &host, int &port); 73 | 74 | std::vector hdfs_list_directory(std::string full_path, std::string config_path); 75 | 76 | template std::vector read_hdfs(std::string full_path, std::string config_path); 77 | 78 | bool is_hdfs_path(std::filesystem::path path); 79 | 80 | #endif 81 | -------------------------------------------------------------------------------- /src/cc/lib/graph/locator.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 
3 | 4 | #ifndef SNARK_LOCATOR_H 5 | #define SNARK_LOCATOR_H 6 | 7 | #include 8 | #include 9 | 10 | #include "hdfs_wrap.h" 11 | #include "logger.h" 12 | #include "types.h" 13 | 14 | namespace snark 15 | { 16 | FILE *open_file(std::filesystem::path s, const char *mode, std::shared_ptr logger = nullptr); 17 | FILE *open_meta(std::filesystem::path path, std::string mode, std::shared_ptr logger = nullptr); 18 | FILE *open_node_map(std::filesystem::path path, std::string suffix, std::shared_ptr logger = nullptr); 19 | FILE *open_node_index(std::filesystem::path path, std::string suffix, std::shared_ptr logger = nullptr); 20 | FILE *open_node_features_index(std::filesystem::path path, std::string suffix, 21 | std::shared_ptr logger = nullptr); 22 | FILE *open_node_features_data(std::filesystem::path path, std::string suffix, std::shared_ptr logger = nullptr); 23 | FILE *open_neighbor_index(std::filesystem::path path, std::string suffix, std::shared_ptr logger = nullptr); 24 | FILE *open_edge_timestamps(std::filesystem::path path, std::string suffix, std::shared_ptr logger = nullptr); 25 | FILE *open_edge_index(std::filesystem::path path, std::string suffix, std::shared_ptr logger = nullptr); 26 | FILE *open_edge_features_index(std::filesystem::path path, std::string suffix, 27 | std::shared_ptr logger = nullptr); 28 | FILE *open_edge_features_data(std::filesystem::path path, std::string suffix, std::shared_ptr logger = nullptr); 29 | FILE *open_edge_alias(std::filesystem::path path, size_t partition, Type type, 30 | std::shared_ptr logger = nullptr); 31 | FILE *open_node_alias(std::filesystem::path path, size_t partition, Type type, 32 | std::shared_ptr logger = nullptr); 33 | 34 | void platform_fseek(FILE *f, int offset, int origin); 35 | size_t platform_ftell(FILE *f); 36 | }; // namespace snark 37 | 38 | #endif // SNARK_LOCATOR_H 39 | -------------------------------------------------------------------------------- /src/cc/lib/graph/logger.cc: -------------------------------------------------------------------------------- 1 | #include "src/cc/lib/graph/logger.h" 2 | 3 | #include 4 | // Use raw log to avoid possible initialization conflicts with glog from other libraries. 5 | #include 6 | #include 7 | 8 | namespace snark 9 | { 10 | 11 | void GLogger::log_info(const char *format, ...) 12 | { 13 | va_list args; 14 | va_start(args, format); 15 | std::string msg; 16 | char buffer[256]; 17 | #ifdef _WIN32 18 | vsnprintf_s(buffer, sizeof(buffer), sizeof(buffer) - 1, format, args); 19 | #else 20 | vsnprintf(buffer, sizeof(buffer), format, args); 21 | #endif 22 | va_end(args); 23 | msg = buffer; 24 | RAW_LOG_INFO("%s", msg.c_str()); 25 | } 26 | 27 | void GLogger::log_error(const char *format, ...) 28 | { 29 | va_list args; 30 | va_start(args, format); 31 | std::string msg; 32 | char buffer[256]; 33 | #ifdef _WIN32 34 | vsnprintf_s(buffer, sizeof(buffer), sizeof(buffer) - 1, format, args); 35 | #else 36 | vsnprintf(buffer, sizeof(buffer), format, args); 37 | #endif 38 | va_end(args); 39 | msg = buffer; 40 | RAW_LOG_ERROR("%s", msg.c_str()); 41 | } 42 | 43 | void GLogger::log_warning(const char *format, ...) 
44 | { 45 | va_list args; 46 | va_start(args, format); 47 | std::string msg; 48 | char buffer[256]; 49 | #ifdef _WIN32 50 | vsnprintf_s(buffer, sizeof(buffer), sizeof(buffer) - 1, format, args); 51 | #else 52 | vsnprintf(buffer, sizeof(buffer), format, args); 53 | #endif 54 | va_end(args); 55 | msg = buffer; 56 | RAW_LOG_WARNING("%s", msg.c_str()); 57 | } 58 | 59 | void GLogger::log_fatal(const char *format, ...) 60 | { 61 | va_list args; 62 | va_start(args, format); 63 | std::string msg; 64 | char buffer[256]; 65 | #ifdef _WIN32 66 | vsnprintf_s(buffer, sizeof(buffer), sizeof(buffer) - 1, format, args); 67 | #else 68 | vsnprintf(buffer, sizeof(buffer), format, args); 69 | #endif 70 | va_end(args); 71 | msg = buffer; 72 | RAW_LOG_FATAL("%s", msg.c_str()); 73 | } 74 | 75 | } // namespace snark 76 | -------------------------------------------------------------------------------- /src/cc/lib/graph/logger.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | #ifndef SNARK_LOGGER_H 5 | #define SNARK_LOGGER_H 6 | 7 | #include 8 | 9 | namespace snark 10 | { 11 | 12 | // Simple logger interface to allow non glog logging. 13 | struct Logger 14 | { 15 | virtual void log_info(const char *format, ...) = 0; 16 | virtual void log_error(const char *format, ...) = 0; 17 | virtual void log_warning(const char *format, ...) = 0; 18 | virtual void log_fatal(const char *format, ...) = 0; 19 | virtual ~Logger() = default; 20 | }; 21 | 22 | // Logger implementation that uses glog. 23 | struct GLogger : public Logger 24 | { 25 | void log_info(const char *format, ...) override; 26 | void log_error(const char *format, ...) override; 27 | void log_warning(const char *format, ...) override; 28 | void log_fatal(const char *format, ...) override; 29 | }; 30 | 31 | } // namespace snark 32 | 33 | #endif // SNARK_LOGGER_H 34 | -------------------------------------------------------------------------------- /src/cc/lib/graph/metadata.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | #ifndef SNARK_METADATA_H 5 | #define SNARK_METADATA_H 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "logger.h" 12 | 13 | namespace snark 14 | { 15 | 16 | const size_t MINIMUM_SUPPORTED_VERSION = 2; 17 | 18 | struct Metadata 19 | { 20 | Metadata() = default; 21 | explicit Metadata(std::filesystem::path path, std::string config_path = "", bool skip_feature_loading = false, 22 | bool skip_temporal_loading = false, std::shared_ptr logger = nullptr); 23 | void Write(std::filesystem::path path) const; 24 | 25 | // Graph information. 26 | size_t m_version; 27 | size_t m_node_count; 28 | size_t m_edge_count; 29 | size_t m_edge_type_count; 30 | size_t m_node_type_count; 31 | size_t m_node_feature_count; 32 | size_t m_edge_feature_count; 33 | 34 | // Infrastructure related information about graph. 
35 |     size_t m_partition_count;
36 |     std::string m_path;
37 |     std::string m_config_path;
38 | 
39 |     std::vector<std::vector<float>> m_partition_node_weights;
40 |     std::vector<std::vector<float>> m_partition_edge_weights;
41 |     std::vector<size_t> m_node_count_per_type;
42 |     std::vector<size_t> m_edge_count_per_type;
43 |     int64_t m_watermark;
44 | };
45 | } // namespace snark
46 | 
47 | #endif // SNARK_METADATA_H
48 | 
--------------------------------------------------------------------------------
/src/cc/lib/graph/reservoir.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT License.
3 | 
4 | #ifndef SNARK_RESERVOIR_H
5 | #define SNARK_RESERVOIR_H
6 | 
7 | #include <cstddef>
8 | #include <functional>
9 | 
10 | #include "boost/random/uniform_real_distribution.hpp"
11 | 
12 | #include "xoroshiro.h"
13 | 
14 | namespace snark
15 | {
16 | 
17 | // Implementation of an optimal algorithm for reservoir sampling:
18 | // https://en.wikipedia.org/wiki/Reservoir_sampling#Optimal:_Algorithm_L
19 | class AlgorithmL
20 | {
21 |   public:
22 |     // `k` is the size of the reservoir; the short name is kept for consistency with the reference.
23 |     // gen is the random generator to use for sampling; it allows clients to set seeds.
24 |     AlgorithmL(size_t k, snark::Xoroshiro128PlusGenerator &gen);
25 | 
26 |     // Add n elements to the reservoir. Every time an element is selected to be put in the reservoir,
27 |     // the update function is called. We use a callback approach, because we usually need to fetch
28 |     // elements from multiple sources (edge type, destination, timestamps). Arguments passed to the callback
29 |     // are (pick, offset). Pick is the index of the element in the reservoir to be replaced, in the range [0; k).
30 |     // Offset is the offset in the stream, in the range [0, n).
31 |     // The method might be called multiple times, which will result in merging multiple streams into one.
32 |     void add(size_t n, std::function<void(size_t, size_t)> update);
33 | 
34 |   private:
35 |     size_t m_k;
36 |     float m_W;
37 |     size_t m_next;
38 |     size_t m_seen;
39 |     snark::Xoroshiro128PlusGenerator &m_gen;
40 |     boost::random::uniform_real_distribution<float> m_dist;
41 | };
42 | 
43 | // Following the paper "Reservoir-based Random Sampling with Replacement from Data Stream" by BH Park, 2004
44 | class WithReplacement
45 | {
46 |   public:
47 |     WithReplacement(size_t k, snark::Xoroshiro128PlusGenerator &gen);
48 | 
49 |     void add(size_t n, std::function<void(size_t, size_t)> update);
50 | 
51 |     void reset();
52 | 
53 |   private:
54 |     size_t m_seen;
55 |     size_t m_k;
56 |     snark::Xoroshiro128PlusGenerator &m_gen;
57 |     boost::random::uniform_real_distribution<float> m_dist;
58 | };
59 | 
60 | // Used for merging multiple sampled neighbors lists into one. We can't use WithReplacement directly,
61 | // because we need to consider intervals with fewer than k elements: if we have two lists
62 | // of equal sizes 10, we can't use Bernoulli trials to merge them into one list of size 15,
63 | // because we need to backfill first and then sample from the merged reservoir, but with updated weights.
64 | class WithoutReplacementMerge
65 | {
66 |   public:
67 |     WithoutReplacementMerge(size_t k, snark::Xoroshiro128PlusGenerator &gen);
68 | 
69 |     // w in this case is the weight of the interval, not the number of elements as in the classes above.
70 | void add(size_t w, std::function update); 71 | 72 | void reset(); 73 | 74 | private: 75 | size_t m_seen; 76 | size_t m_k; 77 | snark::Xoroshiro128PlusGenerator &m_gen; 78 | boost::random::uniform_real_distribution m_dist; 79 | }; 80 | 81 | } // namespace snark 82 | 83 | #endif 84 | -------------------------------------------------------------------------------- /src/cc/lib/graph/types.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | #ifndef SNARK_TYPES_H 5 | #define SNARK_TYPES_H 6 | #include 7 | #include 8 | 9 | namespace snark 10 | { 11 | 12 | using NodeId = int64_t; 13 | using Type = int32_t; 14 | using FeatureId = int32_t; 15 | using FeatureSize = uint32_t; 16 | using Timestamp = int64_t; 17 | using FeatureMeta = std::pair; 18 | 19 | const int32_t PLACEHOLDER_NODE_TYPE = -1; 20 | const Timestamp PLACEHOLDER_TIMESTAMP = -1; 21 | 22 | // Enum ordering should match PyPartitionStorageType in py_graph.h. 23 | enum PartitionStorageType 24 | { 25 | memory, 26 | disk, 27 | }; 28 | 29 | } // namespace snark 30 | #endif 31 | -------------------------------------------------------------------------------- /src/cc/lib/py_server.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | #include "py_graph.h" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include "distributed/server.h" 14 | 15 | namespace deep_graph 16 | { 17 | namespace python 18 | { 19 | 20 | namespace 21 | { 22 | std::string safe_convert(const char *buffer) 23 | { 24 | if (buffer) 25 | { 26 | return std::string(buffer); 27 | } 28 | return std::string(); 29 | } 30 | } // namespace 31 | 32 | int32_t StartServer(PyServer *graph, const char *meta_location, size_t count, uint32_t *partition_indices, 33 | const char **partition_locations, const char *host_name, const char *ssl_key, const char *ssl_cert, 34 | const char *ssl_root, const PyPartitionStorageType storage_type_, const char *config_path, 35 | bool skip_feature_loading, bool skip_temporal_loading) 36 | { 37 | snark::PartitionStorageType storage_type = static_cast(storage_type_); 38 | snark::Metadata metadata(safe_convert(meta_location), safe_convert(config_path), skip_feature_loading, 39 | skip_temporal_loading); 40 | std::vector partition_paths; 41 | partition_paths.reserve(count); 42 | for (size_t i = 0; i < count; ++i) 43 | { 44 | partition_paths.emplace_back(safe_convert(partition_locations[i])); 45 | } 46 | graph->server = std::make_unique( 47 | std::make_shared( 48 | metadata, partition_paths, std::vector(partition_indices, partition_indices + count), 49 | static_cast(storage_type)), 50 | std::make_shared( 51 | metadata, partition_paths, std::vector(partition_indices, partition_indices + count)), 52 | safe_convert(host_name), safe_convert(ssl_key), safe_convert(ssl_cert), safe_convert(ssl_root)); 53 | return 0; 54 | } 55 | 56 | int32_t ResetServer(PyServer *py_graph) 57 | { 58 | py_graph->server.reset(); 59 | return 0; 60 | } 61 | 62 | } // namespace python 63 | } // namespace deep_graph 64 | -------------------------------------------------------------------------------- /src/cc/lib/version-script.darwin.lds: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | Licensed under the MIT License. 
*/ 3 | _CreateLocalGraph 4 | _StartServer 5 | _CreateRemoteClient 6 | _GetNodeFeature 7 | _GetNodeSparseFeature 8 | _GetNodeStringFeature 9 | _GetEdgeFeature 10 | _GetEdgeSparseFeature 11 | _GetEdgeStringFeature 12 | _StartServer 13 | _NeighborCount 14 | _GetNeighbors 15 | _WeightedSampleNeighbor 16 | _UniformSampleNeighbor 17 | _PPRSampleNeighbor 18 | _LastNCreatedNeighbor 19 | _CreateWeightedNodeSampler 20 | _CreateUniformNodeSampler 21 | _CreateUniformNodeSamplerWithoutReplacement 22 | _SampleNodes 23 | _CreateWeightedEdgeSampler 24 | _CreateUniformEdgeSampler 25 | _CreateUniformEdgeSamplerWithoutReplacement 26 | _SampleEdges 27 | _ResetSampler 28 | _ResetGraph 29 | _ResetServer 30 | _RandomWalk 31 | _GetNodeType 32 | _HDFSMoveMeta 33 | -------------------------------------------------------------------------------- /src/cc/lib/version-script.linux.lds: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | Licensed under the MIT License. */ 3 | { 4 | global: 5 | CreateLocalGraph; 6 | StartServer; 7 | CreateRemoteClient; 8 | GetNodeFeature; 9 | GetNodeSparseFeature; 10 | GetNodeStringFeature; 11 | GetEdgeFeature; 12 | GetEdgeSparseFeature; 13 | GetEdgeStringFeature; 14 | StartServer; 15 | NeighborCount; 16 | GetNeighbors; 17 | WeightedSampleNeighbor; 18 | UniformSampleNeighbor; 19 | PPRSampleNeighbor; 20 | LastNCreatedNeighbor; 21 | CreateWeightedNodeSampler; 22 | CreateUniformNodeSampler; 23 | CreateUniformNodeSamplerWithoutReplacement; 24 | SampleNodes; 25 | CreateWeightedEdgeSampler; 26 | CreateUniformEdgeSampler; 27 | CreateUniformEdgeSamplerWithoutReplacement; 28 | SampleEdges; 29 | ResetSampler; 30 | ResetGraph; 31 | ResetServer; 32 | RandomWalk; 33 | GetNodeType; 34 | HDFSMoveMeta; 35 | local: *; 36 | }; 37 | -------------------------------------------------------------------------------- /src/cc/tests/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
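# The version scripts above keep the wrapper's exported C symbol list in sync
# across platforms; macOS symbol names carry a leading underscore, otherwise
# the two export lists are identical.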
3 | load("@rules_cc//cc:defs.bzl", "cc_test") 4 | load("//config:variables.bzl", "CXX_OPTS", "PLATFORM_DEFINES") 5 | 6 | cc_library( 7 | name = "mocks", 8 | srcs = ["mocks.cc"], 9 | hdrs = ["mocks.h"], 10 | copts = CXX_OPTS, 11 | linkstatic = True, 12 | visibility = ["//visibility:public"], 13 | deps = ["//src/cc/lib/graph"], 14 | ) 15 | 16 | cc_test( 17 | name = "native_tests", 18 | srcs = [ 19 | "graph_test.cc", 20 | "mocks.cc", 21 | "mocks.h", 22 | ], 23 | copts = CXX_OPTS, 24 | defines = PLATFORM_DEFINES, 25 | linkopts = ["-lm"], 26 | deps = [ 27 | ":mocks", 28 | "@googletest//:gtest_main", 29 | ], 30 | ) 31 | 32 | cc_test( 33 | name = "temporal_tests", 34 | srcs = [ 35 | "temporal_test.cc", 36 | "mocks.cc", 37 | "mocks.h", 38 | ], 39 | copts = CXX_OPTS, 40 | defines = PLATFORM_DEFINES, 41 | linkopts = ["-lm"], 42 | deps = [ 43 | ":mocks", 44 | "@googletest//:gtest_main", 45 | "//src/cc/lib/distributed:grpc", 46 | "//src/cc/lib/graph", 47 | ], 48 | ) 49 | 50 | cc_test( 51 | name = "distributed_tests", 52 | srcs = [ 53 | "distributed_test.cc", 54 | ], 55 | copts = CXX_OPTS, 56 | defines = PLATFORM_DEFINES, 57 | linkopts = ["-lm"], 58 | deps = [ 59 | ":mocks", 60 | "//src/cc/lib/distributed:grpc", 61 | "//src/cc/lib/graph", 62 | "@googletest//:gtest_main", 63 | ], 64 | ) 65 | 66 | cc_test( 67 | name = "hdfs_tests", 68 | srcs = [ 69 | "hdfs_test.cc", 70 | ], 71 | copts = CXX_OPTS, 72 | defines = PLATFORM_DEFINES, 73 | deps = [ 74 | ":mocks", 75 | "//src/cc/lib/graph", 76 | "@googletest//:gtest_main", 77 | "@hadoop//:hadoop", 78 | "@hadoop//:hadoop_include", 79 | ], 80 | data = ["core-site.xml",], 81 | target_compatible_with = ["@platforms//os:linux"], 82 | ) 83 | -------------------------------------------------------------------------------- /src/cc/tests/core-site.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | fs.defaultFShdfs://localhost:9000 4 | -------------------------------------------------------------------------------- /src/cc/tests/mocks.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | #include "src/cc/lib/graph/graph.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace TestGraph 11 | { 12 | 13 | using NeighborRecord = std::tuple; 14 | struct Node 15 | { 16 | int64_t m_id; 17 | int32_t m_type; 18 | float m_weight; 19 | std::vector> m_float_features; // first dimension is id, second is feature vector. 20 | std::vector m_neighbors; 21 | 22 | // ordered in the same way as m_neighbors, 1st dimension is edge, 2nd - feature_id, 3rd actual data. 23 | std::vector>> m_edge_features; 24 | }; 25 | 26 | struct MemoryGraph 27 | { 28 | std::vector m_nodes; 29 | 30 | // Temporal information. 31 | snark::Timestamp m_watermark = -1; // use -1 to flag a non-temporal graph. 32 | std::vector> m_edge_timestamps; 33 | }; 34 | 35 | snark::Partition convert(std::filesystem::path path, std::string suffix, MemoryGraph t, size_t node_types); 36 | 37 | std::vector serialize_temporal_features(std::vector timestamps, 38 | std::vector> features); 39 | } // namespace TestGraph 40 | -------------------------------------------------------------------------------- /src/python/deepgnn/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | load("@rules_python//python:defs.bzl", "py_library") 5 | load("@rules_python//python:packaging.bzl", "py_package", "py_wheel") 6 | load("@jvolkman_rules_pycross//pycross:defs.bzl", "pycross_wheel_library") 7 | 8 | py_library( 9 | name = "deepgnn", 10 | srcs = [ 11 | "__init__.py", 12 | "arg_types.py", 13 | "log_consts.py", 14 | "logging_utils.py", 15 | "train_types.py", 16 | ], 17 | visibility = ["//visibility:public"], 18 | ) 19 | 20 | py_library( 21 | name = "deepgnn_exports", 22 | srcs = [ 23 | "__init__.py", 24 | ], 25 | visibility = ["//visibility:public"], 26 | ) 27 | 28 | py_library( 29 | name = "deepgnn_ge_library", 30 | srcs = [ 31 | "__init__.py", 32 | ":deepgnn", 33 | "//src/python/deepgnn/graph_engine:graph_engine", 34 | "//src/python/deepgnn/graph_engine/data:graph_engine_data", 35 | "//src/python/deepgnn/graph_engine:graph_engine_exports", 36 | "//src/python/deepgnn/graph_engine/snark:graph_engine_snark", 37 | "//src/python/deepgnn/graph_engine/snark/preprocess:snark_sampler", 38 | "//src/python/deepgnn/graph_engine/backends:graph_engine_backends", 39 | "//src/python/deepgnn/graph_engine/backends/snark:graph_engine_backends_snark", 40 | ], 41 | visibility = ["//visibility:public"], 42 | ) 43 | 44 | 45 | py_package( 46 | name = "deepgnn_ge_package", 47 | packages = ["src.python.deepgnn.graph_engine", "src.python.deepgnn"], 48 | deps = [ 49 | "deepgnn_ge_library", 50 | 51 | ], 52 | visibility = ["//visibility:public"], 53 | ) 54 | 55 | py_wheel( 56 | name = "deepgnn_ge_wheel", 57 | distribution = "deepgnn-ge", 58 | python_tag = "py3", 59 | strip_path_prefixes = [ 60 | "src/python", 61 | ], 62 | version = "0.0.1", 63 | deps = [ 64 | ":deepgnn_ge_package", 65 | ], 66 | visibility = ["//visibility:public"], 67 | ) 68 | 69 | pycross_wheel_library( 70 | name = "deepgnn_ge_wheel_library", 71 | wheel = "//src/python/deepgnn:deepgnn_ge_wheel", 72 | visibility = ["//visibility:public"], 73 | ) 74 | -------------------------------------------------------------------------------- /src/python/deepgnn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # flake8: noqa 5 | from .arg_types import ( 6 | vec2str, 7 | str2bool, 8 | str2list_int, 9 | str2list2_int, 10 | str2list2, 11 | str2list_str, 12 | ) 13 | from .log_consts import ( 14 | LOG_NAME_DEEPGNN, 15 | LOG_PROPS_CUSTOM_DIMENSIONS, 16 | LOG_PROPS_EVENT_END_JOB, 17 | LOG_PROPS_EVENT_END_WORKER, 18 | LOG_PROPS_EVENT_START_JOB, 19 | LOG_PROPS_EVENT_START_WORKER, 20 | LOG_PROPS_KEY_ERR_CODE, 21 | LOG_PROPS_KEY_EVENT_TYPE, 22 | LOG_PROPS_KEY_JOB_ID, 23 | LOG_PROPS_KEY_MODE, 24 | LOG_PROPS_KEY_MODEL, 25 | LOG_PROPS_KEY_NUM_WORKERS, 26 | LOG_PROPS_KEY_PLATFORM, 27 | LOG_PROPS_KEY_USER_NAME, 28 | LOG_PROPS_KEY_WORKER_INDEX, 29 | LOG_PROPS_PLATFORM_PYTORCH, 30 | LOG_PROPS_PLATFORM_TF, 31 | ) 32 | from .train_types import TrainerType, TrainMode 33 | from .logging_utils import ( 34 | get_current_user, 35 | log_telemetry, 36 | get_logger, 37 | setup_default_logging_config, 38 | ) 39 | -------------------------------------------------------------------------------- /src/python/deepgnn/arg_types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | """Type conversions to parse command line arguments.""" 4 | import argparse 5 | 6 | 7 | def vec2str(vec): 8 | """Concat 1 or 2 dimensionanl ndarray to str. 9 | 10 | It is used when output embeddings to files, e.g.: 11 | vec = [1,2,3] 12 | res = vec2str(vec) 13 | # res is "1 2 3" 14 | """ 15 | if len(vec.shape) == 2: 16 | return ",".join([" ".join([str(i) for i in j]) for j in vec.tolist()]) 17 | elif len(vec.shape) == 1: 18 | return " ".join([str(i) for i in vec.tolist()]) 19 | else: 20 | raise RuntimeError("wrong shape: " + str(vec.shape)) 21 | 22 | 23 | def str2bool(v): 24 | """Convert the string value to bool. 25 | 26 | For example: 27 | str2bool("yes") # True 28 | str2bool("Yes") # True 29 | str2bool("True") # True 30 | str2bool("true") # True 31 | str2bool("False") # False 32 | str2bool("0") # False 33 | """ 34 | if isinstance(v, bool): 35 | return v 36 | if v.lower() in ("yes", "true", "t", "y", "1"): 37 | return True 38 | elif v.lower() in ("no", "false", "f", "n", "0"): 39 | return False 40 | else: 41 | raise argparse.ArgumentTypeError("Boolean value expected.") 42 | 43 | 44 | # str to 1d list 45 | def str2list_int(v): 46 | """Convert a comma separated string to int list. 47 | 48 | For example: 49 | str2list_int("1,2,3") # result is [1,2,3] 50 | str2list_int([1,2,3]) # result is [1,2,3] 51 | """ 52 | if isinstance(v, list): 53 | return v 54 | if v == "": 55 | return [] 56 | return [int(x) for x in v.split(",")] 57 | 58 | 59 | # str to 2d int list 60 | def str2list2_int(v): 61 | """Convert string to 2d-list. 62 | 63 | It is used to parse the metapath of the node. 64 | metapath = "1;2;3,4" 65 | str2list2_int(metapath) # [[1],[2],[3, 4]] 66 | """ 67 | if isinstance(v, list): 68 | return v 69 | ret = [] 70 | for y in v.split(";"): 71 | if y != "": 72 | ret.append([int(x) for x in y.split(",")]) 73 | return ret 74 | 75 | 76 | # str to 2d list 77 | def str2list2(v): 78 | """Convert string to string list. 79 | 80 | It is used to parse the edge types of the node. 81 | edges = "q;k;s" 82 | str2list2(edges) # [['q'],['k'],['s']] 83 | """ 84 | if isinstance(v, list): 85 | return v 86 | ret = [] 87 | for y in v.split(";"): 88 | if y != "": 89 | ret.append([x for x in y.split(",")]) 90 | return ret 91 | 92 | 93 | # 94 | def str2list_str(v): 95 | """Convert string to string list, default separator is ",".""" 96 | return v.split(",") 97 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_library( 8 | name = "graph_engine", 9 | srcs = [ 10 | "_adl_reader.py", 11 | "_base.py", 12 | "adl_uploader.py", 13 | "graph_dataset.py", 14 | "graph_ops.py", 15 | "multihop.py", 16 | "prefetch.py", 17 | "samplers.py", 18 | "utils.py", 19 | ], 20 | data = ["//src/cc/lib:wrapper"], 21 | deps = [ 22 | "//src/python/deepgnn:deepgnn", 23 | ], 24 | visibility = ["//visibility:public"], 25 | ) 26 | 27 | py_library( 28 | name = "graph_engine_exports", 29 | srcs = [ 30 | "__init__.py", 31 | ], 32 | deps = [ 33 | "//src/python/deepgnn:deepgnn", 34 | ], 35 | visibility = ["//visibility:public"], 36 | ) 37 | 38 | py_library( 39 | name = "graph_engine_testlib", 40 | srcs = [ 41 | "test_adl_reader.py", 42 | ], 43 | deps = [ 44 | "//src/python/deepgnn:deepgnn", 45 | ], 46 | visibility = ["//visibility:public"], 47 | ) 48 | 49 | py_test( 50 | name = "test_prefetch", 51 | srcs = ["test_prefetch.py"], 52 | imports = ["../../"], 53 | main = "test_prefetch.py", 54 | python_version = "PY3", 55 | srcs_version = "PY3", 56 | deps = [ 57 | ":graph_engine", 58 | requirement("numpy"), 59 | requirement("fsspec"), 60 | requirement("pytest"), 61 | requirement("opencensus"), 62 | requirement("opencensus-context"), 63 | requirement("opencensus-ext-azure"), 64 | requirement("azure-datalake-store"), 65 | ], 66 | ) 67 | 68 | py_test( 69 | name = "test_adl_reader", 70 | srcs = ["test_adl_reader.py"], 71 | imports = ["../../"], 72 | main = "test_adl_reader.py", 73 | python_version = "PY3", 74 | srcs_version = "PY3", 75 | deps = [ 76 | ":graph_engine", 77 | requirement("adlfs"), 78 | requirement("numpy"), 79 | requirement("fsspec"), 80 | requirement("pytest"), 81 | requirement("opencensus"), 82 | requirement("opencensus-context"), 83 | requirement("opencensus-ext-azure"), 84 | requirement("azure-datalake-store"), 85 | ], 86 | ) 87 | 88 | py_test( 89 | name = "test_multihop", 90 | srcs = ["test_multihop.py"], 91 | imports = ["../../"], 92 | main = "test_multihop.py", 93 | python_version = "PY3", 94 | srcs_version = "PY3", 95 | deps = [ 96 | ":graph_engine", 97 | requirement("numpy"), 98 | requirement("fsspec"), 99 | requirement("pytest"), 100 | requirement("opencensus"), 101 | requirement("opencensus-context"), 102 | requirement("opencensus-ext-azure"), 103 | requirement("azure-datalake-store"), 104 | ], 105 | ) 106 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Graph engine package contains key components such as 5 | backend, dataset, samplers. 6 | 7 | Put these components into a flatten namespace for easy usability. 8 | 9 | e.g. 10 | 11 | import deepgnn.graph_engine as ge 12 | 13 | dataset = ge.DeepGNNDataset( 14 | sampler=ge.GENodeSampler, 15 | backend_options=ge.BackendOptions(args), 16 | ... 
17 | ) 18 | 19 | for i, data in enumerate(dataset): 20 | # train 21 | 22 | """ 23 | 24 | # flake8: noqa 25 | from deepgnn.graph_engine._base import * 26 | from deepgnn.graph_engine._adl_reader import ( 27 | TextFileIterator, 28 | TextFileSplitIterator, 29 | AdlCredentialParser, 30 | ) 31 | from deepgnn.graph_engine.samplers import * 32 | from deepgnn.graph_engine import graph_ops 33 | from deepgnn.graph_engine import multihop 34 | from deepgnn.graph_engine import backends 35 | from deepgnn.graph_engine.graph_dataset import * 36 | from deepgnn.graph_engine.utils import define_param_graph_engine 37 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/backends/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | load("@rules_python//python:defs.bzl", "py_library") 5 | 6 | py_library( 7 | name = "graph_engine_backends", 8 | srcs = [ 9 | "__init__.py", 10 | "common.py", 11 | "options.py", 12 | ], 13 | visibility = ["//visibility:public"], 14 | deps = [ 15 | "//src/python/deepgnn/graph_engine:graph_engine", 16 | ], 17 | ) 18 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/backends/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from deepgnn.graph_engine.backends.options import BackendOptions, GraphType 3 | from deepgnn.graph_engine.backends.common import GraphEngineBackend 4 | from deepgnn.graph_engine.backends.snark import * 5 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/backends/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """Base graph engine backend class.""" 4 | import abc 5 | from deepgnn.graph_engine._base import Graph 6 | 7 | 8 | class GraphEngineBackend(abc.ABC): 9 | """Interface class of all backends for graph engine.""" 10 | 11 | @property 12 | def graph(self) -> Graph: 13 | """Get the graph client.""" 14 | raise NotImplementedError 15 | 16 | def close(self): 17 | """Close backend object.""" 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/backends/options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """Options to initialize graph engine.""" 4 | import argparse 5 | from typing import List, Optional, Tuple 6 | from enum import Enum 7 | from deepgnn.graph_engine.snark.converter.options import ConverterOptions 8 | from deepgnn.graph_engine.snark.client import PartitionStorageType 9 | 10 | 11 | class GraphType(Enum): 12 | """Graph engine types.""" 13 | 14 | LOCAL = "local" 15 | REMOTE = "remote" 16 | 17 | def __str__(self): 18 | """Convert enum to string.""" 19 | return self.value 20 | 21 | 22 | class BackendOptions: 23 | """Options to start graph engine backend.""" 24 | 25 | def __init__(self, params: argparse.Namespace): 26 | """Initialize options from command line arguments.""" 27 | self.backend = None 28 | self.data_dir = "" 29 | # local GE only for local debugging. 
30 | self.graph_type = GraphType.REMOTE 31 | self.model_dir = "" 32 | # Snark parameters 33 | self.ge_start_timeout = 30 34 | self.num_ge = 0 35 | self.partitions: List[int] = [] 36 | self.servers: List[str] = [] 37 | self.server_idx = -1 38 | self.client_rank = -1 39 | self.skip_ge_start = False 40 | self.sync_dir = "" 41 | self.enable_ssl = False 42 | self.ssl_cert = "" 43 | self.storage_type = PartitionStorageType.memory 44 | self.config_path = "" 45 | self.stream = False 46 | self.grpc_options: List[Tuple] = [] 47 | self.num_threads: Optional[int] = None 48 | self.num_cq_per_thread: Optional[int] = None 49 | 50 | # Sometimes users need to implement their own backend; with this custom 51 | # field they can start the graph engine with their own code. 52 | self.custom_backendclass = None 53 | 54 | for arg in vars(params): 55 | setattr(self, arg, getattr(params, arg)) 56 | 57 | self.converter_options = ConverterOptions(params) 58 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/backends/snark/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_library( 8 | name = "graph_engine_backends_snark", 9 | srcs = [ 10 | "__init__.py", 11 | "client.py", 12 | "synchronized.py", 13 | ], 14 | visibility = ["//visibility:public"], 15 | deps = [ 16 | "//src/python/deepgnn/graph_engine/backends:graph_engine_backends", 17 | "//src/python/deepgnn/graph_engine/snark:graph_engine_snark", 18 | ], 19 | ) 20 | 21 | py_test( 22 | name = "test_snark_client", 23 | srcs = ["test_snark_client.py"], 24 | imports = ["../../../../"], 25 | main = "test_snark_client.py", 26 | python_version = "PY3", 27 | srcs_version = "PY3", 28 | deps = [ 29 | ":graph_engine_backends_snark", 30 | requirement("numpy"), 31 | requirement("pytest"), 32 | requirement("opencensus"), 33 | requirement("opencensus-context"), 34 | requirement("opencensus-ext-azure"), 35 | requirement("azure-datalake-store"), 36 | requirement("networkx"), 37 | requirement("fsspec"), 38 | requirement("tenacity"), 39 | ], 40 | ) 41 | 42 | py_test( 43 | name = "test_synchronized", 44 | srcs = ["test_synchronized.py"], 45 | imports = ["../../../../"], 46 | main = "test_synchronized.py", 47 | python_version = "PY3", 48 | srcs_version = "PY3", 49 | deps = [ 50 | ":graph_engine_backends_snark", 51 | requirement("numpy"), 52 | requirement("pytest"), 53 | requirement("opencensus"), 54 | requirement("opencensus-context"), 55 | requirement("azure-datalake-store"), 56 | requirement("opencensus-ext-azure"), 57 | requirement("tenacity"), 58 | requirement("fsspec"), 59 | ], 60 | ) 61 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/backends/snark/__init__.py: -------------------------------------------------------------------------------- 1 | """C++ implementation of graph engine interface.""" 2 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/backends/snark/test_synchronized.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License.
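Since BackendOptions copies every attribute of the parsed namespace onto itself before building ConverterOptions, a sketch of constructing it from a parser is enough to see the flow; the flags and values below are illustrative only:

import argparse
from deepgnn.graph_engine.backends import BackendOptions, GraphType

parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str, default="/tmp/cora")  # hypothetical flag
parser.add_argument("--servers", type=str, nargs="*", default=["localhost:50051"])
options = BackendOptions(parser.parse_args([]))  # namespace attributes override the defaults
options.graph_type = GraphType.LOCAL             # local GE is intended for local debugging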
3 | 4 | from concurrent.futures.thread import ThreadPoolExecutor 5 | import multiprocessing as mp 6 | import tempfile 7 | 8 | import pytest 9 | 10 | import deepgnn.graph_engine.backends.snark.synchronized as synchronized 11 | 12 | 13 | def test_simple_client_server_initialized_in_correct_order(): 14 | server_event = mp.Event() 15 | 16 | class MockServer: 17 | def __init__(self): 18 | server_event.set() 19 | 20 | def reset(self): 21 | assert server_event.is_set() 22 | 23 | client_event = mp.Event() 24 | 25 | class MockClient: 26 | def __init__(self): 27 | client_event.set() 28 | 29 | def reset(self): 30 | assert client_event.is_set() 31 | 32 | working_dir = tempfile.TemporaryDirectory() 33 | server = synchronized.SynchronizedServer(working_dir.name, 0, None, MockServer) 34 | client = synchronized.SynchronizedClient(working_dir.name, 0, 1, None, MockClient) 35 | server_event.wait(1) 36 | client_event.wait(1) 37 | client.reset() 38 | server.reset() 39 | 40 | 41 | def test_client_initialization_timeout(): 42 | # if a client is started without a server it should throw a timeout exception 43 | class MockClass: 44 | def reset(self): 45 | pass 46 | 47 | working_dir = tempfile.TemporaryDirectory() 48 | with pytest.raises(TimeoutError): 49 | wrapper = synchronized.SynchronizedClient( 50 | working_dir.name, rank=0, num_servers=1, timeout=1, klass=MockClass 51 | ) 52 | wrapper.client 53 | 54 | 55 | def test_server_waits_for_client_to_stop(): 56 | server_event = mp.Event() 57 | 58 | class MockServer: 59 | def __init__(self): 60 | server_event.set() 61 | 62 | def reset(self): 63 | assert server_event.is_set() 64 | 65 | client_event = mp.Event() 66 | 67 | class MockClient: 68 | def __init__(self): 69 | client_event.set() 70 | 71 | def reset(self): 72 | assert client_event.is_set() 73 | 74 | working_dir = tempfile.TemporaryDirectory() 75 | server = synchronized.SynchronizedServer(working_dir.name, 0, None, MockServer) 76 | client = synchronized.SynchronizedClient(working_dir.name, 0, 1, None, MockClient) 77 | server_event.wait() 78 | server_finished_event = mp.Event() 79 | client.client 80 | client_event.wait() 81 | with ThreadPoolExecutor() as executor: 82 | 83 | def server_done(): 84 | server.reset() 85 | server_finished_event.set() 86 | 87 | executor.submit(server_done) 88 | client.reset() 89 | server_finished_event.wait() 90 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/data/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_library( 8 | name = "graph_engine_data", 9 | srcs = [ 10 | "citation.py", 11 | "citeseer.py", 12 | "cora.py", 13 | "data_util.py", 14 | "mooc.py", 15 | "ppi.py", 16 | "reddit.py", 17 | ], 18 | data = ["cora_full.zip", "citeseer_full.zip", "cora.zip", "citeseer.zip"], 19 | visibility = ["//visibility:public"], 20 | deps = [ 21 | "//src/python/deepgnn/graph_engine:graph_engine", 22 | "//src/python/deepgnn/graph_engine/snark:graph_engine_snark", 23 | ], 24 | ) 25 | 26 | py_test( 27 | name = "test_graph_dataset", 28 | srcs = ["test_graph_dataset.py"], 29 | imports = ["../../../"], 30 | main = "test_graph_dataset.py", 31 | python_version = "PY3", 32 | srcs_version = "PY3", 33 | deps = [ 34 | ":graph_engine_data", 35 | requirement("numpy"), 36 | requirement("azure-datalake-store"), 37 | requirement("pytest"), 38 | requirement("opencensus"), 39 | requirement("opencensus-context"), 40 | requirement("opencensus-ext-azure"), 41 | requirement("scipy"), 42 | requirement("scikit-learn"), 43 | requirement("networkx"), 44 | requirement("fsspec"), 45 | requirement("tenacity"), 46 | ], 47 | ) 48 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # flake8: noqa 5 | from .cora import CoraFull 6 | from .citeseer import CiteseerFull 7 | from .citation import Cora, Citeseer 8 | from .ppi import PPI 9 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/data/citeseer.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/DeepGNN/6a8b5ce521f4908ac1d2e4401bb5e56df86f5074/src/python/deepgnn/graph_engine/data/citeseer.zip -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/data/citeseer_full.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/DeepGNN/6a8b5ce521f4908ac1d2e4401bb5e56df86f5074/src/python/deepgnn/graph_engine/data/citeseer_full.zip -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/data/cora.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/DeepGNN/6a8b5ce521f4908ac1d2e4401bb5e56df86f5074/src/python/deepgnn/graph_engine/data/cora.zip -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/data/cora_full.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/DeepGNN/6a8b5ce521f4908ac1d2e4401bb5e56df86f5074/src/python/deepgnn/graph_engine/data/cora_full.zip -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | load("@rules_python//python:defs.bzl", "py_library") 5 | load("@bazel_skylib//rules:copy_file.bzl", "copy_file") 6 | 7 | # out attribute is not configurable, hence the need for three rules 8 | copy_file( 9 | name = "library_as_data_windows", 10 | src = "//src/cc/lib:wrapper", 11 | out = ":wrapper.dll", 12 | ) 13 | 14 | copy_file( 15 | name = "library_as_data_macos", 16 | src = "//src/cc/lib:wrapper", 17 | out = ":libwrapper.dylib", 18 | ) 19 | 20 | copy_file( 21 | name = "library_as_data_linux", 22 | src = "//src/cc/lib:wrapper", 23 | out = ":libwrapper.so", 24 | ) 25 | 26 | py_library( 27 | name = "graph_engine_snark", 28 | srcs = [ 29 | "__init__.py", 30 | "_downloader.py", 31 | "_lib.py", 32 | "alias.py", 33 | "client.py", 34 | "convert.py", 35 | "meta_merger.py", 36 | "converter/__init__.py", 37 | "converter/process.py", 38 | "converter/writers.py", 39 | "converter/options.py", 40 | "decoders.py", 41 | "dispatcher.py", 42 | "distributed.py", 43 | "local.py", 44 | "meta.py", 45 | "server.py", 46 | ], 47 | data = select({ 48 | "@platforms//os:windows": [":wrapper.dll"], 49 | "@platforms//os:macos": [":libwrapper.dylib"], 50 | "@platforms//os:linux": [":libwrapper.so"], 51 | }), 52 | deps = [ 53 | "//src/python/deepgnn/graph_engine:graph_engine", 54 | "//src/python/deepgnn/graph_engine/snark/preprocess:snark_sampler", 55 | ], 56 | visibility = ["//visibility:public"], 57 | ) 58 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """Graph engine implementation with compressed sparse format.""" 4 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/_lib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import platform 6 | import functools 7 | 8 | if platform.system() == "Windows": 9 | from ctypes import WinDLL # type: ignore 10 | else: 11 | from ctypes import CDLL # type: ignore 12 | 13 | _LIB_FILE_NAME = "libwrapper.so" 14 | if platform.system() == "Windows": 15 | _LIB_FILE_NAME = "wrapper.dll" 16 | elif platform.system() == "Darwin": 17 | _LIB_FILE_NAME = "libwrapper.dylib" 18 | 19 | _LIB_PATH = os.path.join(os.path.dirname(__file__), _LIB_FILE_NAME) 20 | 21 | # Use an environment variable to locate the library when loading with the multiprocessing module 22 | _SNARK_LIB_PATH_ENV_KEY = "SNARK_LIB_PATH" 23 | 24 | 25 | # Use lru_cache to load the library only once, in a thread-safe way. 26 | @functools.lru_cache(maxsize=1) 27 | def _get_c_lib(): 28 | global _LIB_PATH 29 | 30 | if _SNARK_LIB_PATH_ENV_KEY in os.environ: 31 | _LIB_PATH = os.environ[_SNARK_LIB_PATH_ENV_KEY] 32 | 33 | if platform.system() == "Windows": 34 | lib = WinDLL(_LIB_PATH) # type: ignore 35 | else: 36 | lib = CDLL(_LIB_PATH) # type: ignore 37 | return lib 38 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/alias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License.
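Because _get_c_lib in _lib.py above caches its result with lru_cache(maxsize=1), the SNARK_LIB_PATH override must be set before the first load; a minimal sketch (the path shown is hypothetical):

import os
os.environ["SNARK_LIB_PATH"] = "/opt/deepgnn/libwrapper.so"  # must be set before first use

from deepgnn.graph_engine.snark._lib import _get_c_lib
lib = _get_c_lib()  # subsequent calls return the same cached handle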
3 | 4 | """Alias table generators.""" 5 | import random 6 | import typing 7 | 8 | import numpy as np 9 | 10 | 11 | class Vose: 12 | """Generate alias tables with Vose method.""" 13 | 14 | def __init__(self, elements: typing.List, weights: np.ndarray): 15 | """Create alias tables for weight sampling. 16 | 17 | Args: 18 | elements (typing.List): elements to sample 19 | weights (np.array): corresponging elements weights 20 | """ 21 | self.elements = elements 22 | weights = np.multiply(weights, len(weights) / np.sum(weights)) 23 | self.alias = np.empty(len(elements), dtype=np.uint64) 24 | self.prob = np.empty(len(elements), dtype=np.float32) 25 | self._generate_table(weights) 26 | 27 | def _generate_table(self, weights: np.ndarray): 28 | small = [] 29 | large = [] 30 | for i, w in enumerate(weights): 31 | if w < 1: 32 | small.append(i) 33 | else: 34 | large.append(i) 35 | 36 | while small and large: 37 | small_element = small.pop() 38 | large_element = large.pop() 39 | self.alias[small_element] = large_element 40 | self.prob[small_element] = weights[small_element] 41 | 42 | weights[large_element] = ( 43 | weights[large_element] + weights[small_element] 44 | ) - 1 45 | if weights[large_element] < 1: 46 | small.append(large_element) 47 | else: 48 | large.append(large_element) 49 | 50 | while large: 51 | self.prob[large.pop()] = 1 52 | 53 | while small: 54 | self.prob[small.pop()] = 1 55 | 56 | def sample(self) -> typing.Any: 57 | """Sample from alias tables. 58 | 59 | Returns: 60 | typing.Any: element from the original list 61 | """ 62 | n = random.randrange(len(self.alias)) 63 | if random.uniform(0, 1) > self.prob[n]: 64 | n = self.alias[n] 65 | return self.elements[n] 66 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/converter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Converters from various formats to the binary format for graph engine.""" 5 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/converter/options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Snark converter option.""" 5 | from enum import Enum 6 | 7 | 8 | class DataConverterType(Enum): 9 | """DeepGNN's graph engine data converter type. 10 | 11 | Supported: 12 | Skip: binary data is already converted and no need to convert any more. 13 | """ 14 | 15 | SKIP = "skip" 16 | LOCAL = "local" 17 | 18 | def __str__(self): 19 | """Convert to string.""" 20 | return self.value 21 | 22 | 23 | class ConverterOptions: 24 | """All the data converter related configurations are here. 25 | 26 | Converters supported: 27 | Skip 28 | """ 29 | 30 | def __init__(self, params): 31 | """Init the coverter option.""" 32 | # default values. 33 | self.converter = DataConverterType.SKIP 34 | 35 | for arg in vars(params): 36 | if hasattr(self, arg): 37 | setattr(self, arg, getattr(params, arg)) 38 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/preprocess/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | load("@rules_python//python:defs.bzl", "py_library") 5 | 6 | py_library( 7 | name = "snark_sampler", 8 | srcs = [ 9 | "__init__.py", 10 | "sampler/__init__.py", 11 | "sampler/forest_fire.py", 12 | "sampler/metric.py", 13 | ], 14 | visibility = ["//visibility:public"], 15 | ) 16 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """An implementation of graph sampling with plain data files.""" 5 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/preprocess/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """An implementation of graph sampling with plain data files.""" 5 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/preprocess/sampler/metric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Diagnostic tool built to obtain the properties of a NetworkX graph.""" 5 | import networkx as nx 6 | import math 7 | import random 8 | 9 | 10 | def densification(g: nx.DiGraph) -> float: 11 | """Return densification constant in NetworkX directed graph.""" 12 | return math.log(nx.number_of_edges(g), nx.number_of_nodes(g)) 13 | 14 | 15 | def diameter(g: nx.DiGraph) -> int: 16 | """Return effective 90% diameter in NetworkX directed graph.""" 17 | largest_strongly_connected_component = g.subgraph( 18 | max(nx.strongly_connected_components(g), key=len) 19 | ) 20 | return nx.diameter(largest_strongly_connected_component) 21 | 22 | 23 | def largest_connected_component(g: nx.DiGraph) -> float: 24 | """Return the scaled size of the largest strongly connected component in NetworkX directed graph.""" 25 | largest_strongly_connected_component = g.subgraph( 26 | max(nx.strongly_connected_components(g), key=len) 27 | ) 28 | return nx.number_of_nodes( 29 | largest_strongly_connected_component 30 | ) / nx.number_of_nodes(g) 31 | 32 | 33 | def max_adjacency(g: nx.DiGraph) -> float: 34 | """Return the largest eigenvalue of the adjacency matrix in NetworkX directed graph.""" 35 | return max(nx.adjacency_spectrum(g)) 36 | 37 | 38 | def average_clustering(input_graph: nx.DiGraph, trials: int) -> float: 39 | """Return the average clustering coefficient in NetworkX directed graph.""" 40 | g = nx.to_undirected(input_graph) 41 | n = len(g) 42 | triangles = 0 43 | nodes = list(g.nodes()) 44 | for i in [random.randrange(n) for _ in range(trials)]: 45 | nbrs = list(g[nodes[i]]) 46 | if len(nbrs) < 2: 47 | continue 48 | u, v = random.sample(nbrs, 2) 49 | if u in g[v]: 50 | triangles += 1 51 | return triangles / float(trials) 52 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/tests/alias_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License.
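Note that densification above returns log base |V| of |E|, i.e. the exponent a in |E| = |V|**a, so values above 1 indicate edges growing superlinearly with nodes. A tiny sketch of the metrics module (graph parameters invented):

import networkx as nx
import deepgnn.graph_engine.snark.preprocess.sampler.metric as metric

g = nx.scale_free_graph(100, seed=1)          # directed graph with a heavy-tailed degree distribution
print(metric.densification(g))                # > 1: superlinear edge growth
print(metric.largest_connected_component(g))  # fraction of nodes in the largest SCC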
3 | 4 | import sys 5 | import random 6 | import pytest 7 | import deepgnn.graph_engine.snark.alias as alias 8 | import os 9 | 10 | 11 | def test_sanity_alias(): 12 | random.seed(5) 13 | t = alias.Vose([0, 1, 2, 3], [1.0, 2.4, 0.5, 0.1]) 14 | # expected ratios: [0.25, 0.6, 0.125, 0.025] 15 | counts = [0, 0, 0, 0] 16 | num_trials = 10000 17 | for _ in range(num_trials): 18 | counts[t.sample()] += 1 19 | assert counts[0] == 2454 20 | assert counts[1] == 6006 21 | assert counts[2] == 1279 22 | assert counts[3] == 261 23 | 24 | 25 | if __name__ == "__main__": 26 | sys.exit( 27 | pytest.main( 28 | [__file__, "--junitxml", os.environ["XML_OUTPUT_FILE"], *sys.argv[1:]] 29 | ) 30 | ) 31 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/tests/metric_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import sys 5 | import pytest 6 | import os 7 | import networkx as nx 8 | import random 9 | import deepgnn.graph_engine.snark.preprocess.sampler.metric as metr 10 | 11 | 12 | @pytest.fixture 13 | def graph_choice_one(): 14 | random.seed(5) 15 | g = nx.scale_free_graph(1000, seed=100) 16 | return g 17 | 18 | 19 | @pytest.fixture 20 | def graph_choice_two(): 21 | g = nx.star_graph(100).to_directed() 22 | return g 23 | 24 | 25 | def test_densification_one(graph_choice_one): 26 | d = metr.densification(graph_choice_one) 27 | assert d == pytest.approx(1.11181930002297) 28 | 29 | 30 | def test_densification_two(graph_choice_two): 31 | d = metr.densification(graph_choice_two) 32 | assert d == pytest.approx(1.1480344548346) 33 | 34 | 35 | def test_effective_diameter_one(graph_choice_one): 36 | d = metr.diameter(graph_choice_one) 37 | assert d == 7 38 | 39 | 40 | def test_effective_diameter_two(graph_choice_two): 41 | d = metr.diameter(graph_choice_two) 42 | assert d == 2 43 | 44 | 45 | def test_largest_eigenvalue_one(graph_choice_one): 46 | d = metr.max_adjacency(graph_choice_one) 47 | assert d == pytest.approx(52.7454614388038) 48 | 49 | 50 | def test_largest_eigenvalue_two(graph_choice_two): 51 | d = metr.max_adjacency(graph_choice_two) 52 | assert d == pytest.approx(10) 53 | 54 | 55 | if __name__ == "__main__": 56 | sys.exit( 57 | pytest.main( 58 | [__file__, "--junitxml", os.environ["XML_OUTPUT_FILE"], *sys.argv[1:]] 59 | ) 60 | ) 61 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/tests/ppr_benchmark_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from dataclasses import dataclass 5 | import os 6 | import sys 7 | 8 | import numpy as np 9 | import pytest 10 | 11 | from deepgnn.graph_engine.data.cora import CoraFull 12 | import deepgnn.graph_engine.snark.server as server 13 | import deepgnn.graph_engine.snark.distributed as distributed 14 | 15 | 16 | def weighted_sample(graph, batches): 17 | total = 0 18 | for batch in batches: 19 | nodes = graph.sample_neighbors( 20 | strategy="byweight", 21 | nodes=batch, 22 | edge_types=np.array([0], dtype=np.int32), 23 | count=50, 24 | )[0] 25 | total += len(nodes) 26 | return total 27 | 28 | 29 | def sample(graph, batches, alpha, epsilon): 30 | total = 0 31 | for batch in batches: 32 | nodes = graph.sample_neighbors( 33 | strategy="ppr-go", 34 | nodes=batch, 35 | edge_types=np.array([0], dtype=np.int32), 36 | count=50, 37 | alpha=alpha, 38 | eps=epsilon, 39 | )[0] 40 | total += len(nodes) 41 | return total 42 | 43 | 44 | @dataclass 45 | class BenchmarkData: 46 | graph: str 47 | data_dir: str 48 | inputs: np.array 49 | returned_nodes_count: int = 0 50 | 51 | 52 | @pytest.fixture(scope="session") 53 | def dataset(): 54 | data_dir = "/tmp/cora" 55 | graph = CoraFull(data_dir) 56 | batch_size = 256 57 | nodes = np.arange( 58 | graph.NUM_NODES + (batch_size - graph.NUM_NODES % batch_size), 59 | dtype=np.int64, 60 | ) 61 | batches = nodes.reshape(-1, batch_size) 62 | return BenchmarkData(graph, data_dir, batches, len(nodes)) 63 | 64 | 65 | def test_ppr_on_cora_distributed(benchmark, dataset): 66 | s = server.Server(dataset.data_dir, partitions=[0], hostname="localhost:50051") 67 | c = distributed.Client(["localhost:50051"]) 68 | result = benchmark( 69 | sample, 70 | c, 71 | dataset.inputs, 72 | 0.85, 73 | 0.0001, 74 | ) 75 | c.reset() 76 | s.reset() 77 | assert result == dataset.returned_nodes_count 78 | 79 | 80 | def test_ppr_on_cora_in_memory(benchmark, dataset): 81 | result = benchmark(sample, dataset.graph, dataset.inputs, 0.85, 0.0001) 82 | assert result == dataset.returned_nodes_count 83 | 84 | 85 | def test_weighted_on_cora_in_memory(benchmark, dataset): 86 | result = benchmark(weighted_sample, dataset.graph, dataset.inputs) 87 | assert result == dataset.returned_nodes_count 88 | 89 | 90 | if __name__ == "__main__": 91 | sys.exit( 92 | pytest.main( 93 | [__file__, "--junitxml", os.environ["XML_OUTPUT_FILE"], *sys.argv[1:]] 94 | ) 95 | ) 96 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/snark/tests/requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.10.11 2 | attr==0.3.1 3 | fsspec==2021.8.1 4 | # Preferably the vershion should match C++ version defined in the WORKSPACE file. 5 | grpcio==1.53.2 6 | grpcio-health-checking==1.35.0 7 | importlib_metadata==2.0.0 8 | iniconfig==1.0.1 9 | networkx==2.5.1 10 | numpy==1.22.0 11 | packaging==20.4 12 | pluggy==0.13.1 13 | py==1.10.0 14 | 15 | # For tests 16 | pytest==6.1.2 17 | requests==2.32.0 18 | scipy==1.10.0 19 | -------------------------------------------------------------------------------- /src/python/deepgnn/graph_engine/test_prefetch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import unittest.mock as mock 5 | import math 6 | import pytest 7 | import numpy as np 8 | 9 | from deepgnn.graph_engine._base import Graph 10 | from deepgnn.graph_engine.samplers import RangeNodeSampler, BaseSampler 11 | from deepgnn.graph_engine import prefetch 12 | 13 | 14 | @pytest.fixture(scope="module") 15 | def create_graph(): 16 | g = Graph() 17 | yield g 18 | 19 | 20 | class ExceptionSampler(BaseSampler): 21 | def __init__(self, exception_msg): 22 | self.msg = exception_msg 23 | 24 | def __iter__(self): 25 | return self 26 | 27 | def __next__(self): 28 | raise Exception(self.msg) 29 | 30 | 31 | def test_handle_sampler_exception(create_graph): 32 | s = ExceptionSampler("test sampler exception") 33 | model = mock.MagicMock() 34 | model.sampler = s 35 | 36 | gen = prefetch.Generator(create_graph, model.sampler, model.query) 37 | num_iterations = 0 38 | 39 | with pytest.raises(Exception): 40 | for _ in gen(): 41 | num_iterations += 1 42 | 43 | assert num_iterations == 0 44 | gen.join() 45 | 46 | 47 | def test_handle_query_exception(create_graph): 48 | model = mock.MagicMock() 49 | attrs = {"query.side_effect": Exception("test query exception")} 50 | model = mock.MagicMock(**attrs) 51 | model.sampler = RangeNodeSampler(0, 10, 6, 0, 1, backfill_id=-1) 52 | gen = prefetch.Generator(create_graph, model.sampler, model.query) 53 | 54 | with pytest.raises(Exception): 55 | for _ in gen(): 56 | pass 57 | 58 | gen.join() 59 | 60 | 61 | def test_generator_join_with_break(create_graph): 62 | s = RangeNodeSampler(0, 1024, 32, 0, 1, backfill_id=-1) 63 | model = mock.MagicMock() 64 | model.sampler = s 65 | 66 | gen = prefetch.Generator(create_graph, model.sampler, model.query) 67 | num_iterations = 0 68 | 69 | for i, _ in enumerate(gen): 70 | if i >= 20: 71 | break 72 | num_iterations += 1 73 | 74 | gen.join() 75 | # make sure the join will not block the process. 76 | assert True 77 | 78 | 79 | def test_generator_join(create_graph): 80 | for total_size in range(1, 50): 81 | for batch_size in range(1, total_size): 82 | s = RangeNodeSampler(0, total_size, batch_size, 0, 1, backfill_id=-1) 83 | attrs = {"query.return_value": {"inputs": np.ndarray([1], np.int64)}} 84 | model = mock.MagicMock(**attrs) 85 | model.sampler = s 86 | 87 | gen = prefetch.Generator(create_graph, model.sampler, model.query) 88 | num_iterations = 0 89 | for _, minibatch in enumerate(gen): 90 | num_iterations += 1 91 | 92 | assert num_iterations == math.ceil(total_size / batch_size) 93 | gen.join() 94 | -------------------------------------------------------------------------------- /src/python/deepgnn/log_consts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | """Common strings used for logging.""" 4 | 5 | LOG_NAME_DEEPGNN = "deepgnn" 6 | LOG_PROPS_CUSTOM_DIMENSIONS = "custom_dimensions" 7 | LOG_PROPS_EVENT_END_JOB = "deepgnn.end_job" 8 | LOG_PROPS_EVENT_END_WORKER = "deepgnn.end_worker" 9 | LOG_PROPS_EVENT_START_JOB = "deepgnn.start_job" 10 | LOG_PROPS_EVENT_START_WORKER = "deepgnn.start_worker" 11 | LOG_PROPS_KEY_ERR_CODE = "deepgnn.error_code" 12 | LOG_PROPS_KEY_EVENT_TYPE = "deepgnn.event_type" 13 | LOG_PROPS_KEY_JOB_ID = "deepgnn.job_id" 14 | LOG_PROPS_KEY_MODE = "deepgnn.mode" 15 | LOG_PROPS_KEY_MODEL = "deepgnn.model" 16 | LOG_PROPS_KEY_NUM_WORKERS = "deepgnn.num_workers" 17 | LOG_PROPS_KEY_PLATFORM = "deepgnn.platform" 18 | LOG_PROPS_KEY_USER_NAME = "deepgnn.user_name" 19 | LOG_PROPS_KEY_WORKER_INDEX = "deepgnn.worker_index" 20 | LOG_PROPS_PLATFORM_PYTORCH = "deepgnn.pytorch" 21 | LOG_PROPS_PLATFORM_TF = "deepgnn.tensorflow" 22 | -------------------------------------------------------------------------------- /src/python/deepgnn/migrate/0_1_56.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """Script to migrate to new versions of DeepGNN.""" 4 | import argparse 5 | from pathlib import Path 6 | import pasta 7 | from pasta.augment import rename 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser(description="Migrate to new versions of DeepgNN.") 12 | parser.add_argument( 13 | "--script_dir", type=str, required=True, help="Directory to migrate." 14 | ) 15 | args = parser.parse_args() 16 | 17 | for filename in Path(args.script_dir).glob("**/*.py"): 18 | with open(filename, "r") as file: 19 | raw_input = file.read() 20 | 21 | raw_input = raw_input.replace("FeatureType.FLOAT", "np.float32") 22 | raw_input = raw_input.replace("FeatureType.INT64", "np.int64") 23 | raw_input = raw_input.replace("FeatureType.BINARY", "np.uint8") 24 | 25 | tree = pasta.parse(raw_input) 26 | 27 | rename.rename_external( 28 | tree, "deepgnn.graph_engine.FeatureType", "numpy.dtype_temp" 29 | ) 30 | 31 | raw_output = pasta.dump(tree) 32 | 33 | raw_output = raw_output.replace( 34 | "from numpy import dtype_temp", "import numpy as np" 35 | ) 36 | raw_output = raw_output.replace("dtype_temp", "np.dtype") 37 | 38 | with open(filename, "w") as file: 39 | file.write(raw_output) 40 | -------------------------------------------------------------------------------- /src/python/deepgnn/migrate/0_1_57.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """Script to migrate to new versions of DeepGNN.""" 4 | import argparse 5 | from pathlib import Path 6 | import pasta 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser(description="Migrate to new versions of DeepgNN.") 11 | parser.add_argument( 12 | "--script_dir", type=str, required=True, help="Directory to migrate." 
13 | ) 14 | args = parser.parse_args() 15 | 16 | for filename in Path(args.script_dir).glob("**/*.py"): 17 | with open(filename, "r") as file: 18 | raw_input = file.read() 19 | 20 | tree = pasta.parse(raw_input) 21 | 22 | raw_output = pasta.dump(tree) 23 | 24 | raw_output = raw_output.replace("get_feature_type", "get_python_type") 25 | 26 | with open(filename, "w") as file: 27 | file.write(raw_output) 28 | -------------------------------------------------------------------------------- /src/python/deepgnn/migrate/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """DeepGNN migration modules.""" 4 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | load("@rules_python//python:defs.bzl", "py_library") 5 | 6 | py_library( 7 | name = "deepgnn_pytorch", 8 | srcs = [ 9 | "__init__.py", 10 | ], 11 | deps = [ 12 | "//src/python/deepgnn/graph_engine:graph_engine", 13 | "//src/python/deepgnn/graph_engine:graph_engine_exports", 14 | "//src/python/deepgnn:deepgnn", 15 | ], 16 | visibility = ["//visibility:public"], 17 | ) 18 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """DeepGNN modules to work with torch models.""" 4 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/common/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_library( 8 | name = "deepgnn_pytorch_common", 9 | srcs = [ 10 | "__init__.py", 11 | "aggregators.py", 12 | "args.py", 13 | "consts.py", 14 | "dataset.py", 15 | "metrics.py", 16 | "optimization.py", 17 | "utils.py", 18 | ], 19 | deps = [ 20 | "//src/python/deepgnn/pytorch:deepgnn_pytorch", 21 | ], 22 | visibility = ["//visibility:public"], 23 | ) 24 | 25 | py_test( 26 | name = "test_metrics", 27 | srcs = ["test_metrics.py"], 28 | imports = ["../../../"], 29 | main = "test_metrics.py", 30 | python_version = "PY3", 31 | srcs_version = "PY3", 32 | deps = [ 33 | ":deepgnn_pytorch_common", 34 | "//src/python/deepgnn/graph_engine/backends/snark:graph_engine_backends_snark", 35 | "//src/python/deepgnn/graph_engine/snark:graph_engine_snark", 36 | requirement("numpy"), 37 | requirement("pytest"), 38 | requirement("torch"), 39 | requirement("fsspec"), 40 | requirement("networkx"), 41 | requirement("scikit-learn"), 42 | requirement("opencensus"), 43 | requirement("opencensus-context"), 44 | requirement("opencensus-ext-azure"), 45 | requirement("azure-datalake-store"), 46 | requirement("tenacity"), 47 | ], 48 | ) 49 | 50 | py_test( 51 | name = "test_utils", 52 | srcs = ["test_utils.py"], 53 | imports = ["../../../", "./"], 54 | main = "test_utils.py", 55 | python_version = "PY3", 56 | srcs_version = "PY3", 57 | deps = [ 58 | ":deepgnn_pytorch_common", 59 | "//src/python/deepgnn/graph_engine/backends/snark:graph_engine_backends_snark", 60 | "//src/python/deepgnn/graph_engine/snark:graph_engine_snark", 61 | requirement("numpy"), 62 | requirement("fsspec"), 63 | requirement("pytest"), 64 | requirement("scikit-learn"), 65 | requirement("torch"), 66 | requirement("networkx"), 67 | requirement("opencensus"), 68 | requirement("opencensus-context"), 69 | requirement("opencensus-ext-azure"), 70 | requirement("azure-datalake-store"), 71 | requirement("tenacity"), 72 | ], 73 | ) 74 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # flake8: noqa 5 | from .aggregators import MeanAggregator 6 | from .args import init_common_args 7 | from .metrics import BaseMetric, MRR, F1Score, ROC, Accuracy 8 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/common/aggregators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Set of modules for aggregating embeddings of neighbors.""" 5 | 6 | from typing import Callable 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class MeanAggregator(nn.Module): 13 | """Aggregates a node's embeddings using mean of neighbors' embeddings.""" 14 | 15 | def __init__(self, features: Callable[[torch.Tensor], torch.Tensor]): 16 | """ 17 | Initialize the aggregator for a specific graph. 18 | 19 | features -- function mapping LongTensor of node ids to FloatTensor of feature values. 20 | """ 21 | super(MeanAggregator, self).__init__() 22 | 23 | self.features = features 24 | 25 | def forward(self, neighs: torch.Tensor, node_count: int) -> torch.Tensor: 26 | """ 27 | Propagate node features to NN Layer. 
28 | 29 | neighs --- flattened tensor of neighbor ids with shape [node_count * nb_count] 30 | """ 31 | neigh_feats = self.features(neighs) 32 | 33 | nb_count = int(neigh_feats.shape[0] / node_count) 34 | fv_by_node = neigh_feats.view(node_count, nb_count, neigh_feats.shape[-1]) 35 | return fv_by_node.mean(1) 36 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/common/consts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """Common strings used across models.""" 4 | 5 | DOWNSCALE = "downscale" 6 | EMBEDDING_TYPE = "embedding_type" 7 | EMBS_LIST = "embs_list" 8 | ENCODER_MASK = "encoder_mask" 9 | ENCODER_SEQ = "encoder_seq" 10 | ENCODER_TYPES = "encoder_types" 11 | FANOUTS = "fanouts" 12 | FEATURE_ENCODER_STR = "feature_encoder" 13 | FP16_AMP = "amp" 14 | FP16_APEX = "apex" 15 | FP16_NO = "no" 16 | INPUTS = "inputs" 17 | MAX_SENT_CHARS = "max_sentence_characters" 18 | MAX_SEQ_LEN = "max_seq_len" 19 | MODEL_RESIDUAL_ADD = "add" 20 | MODEL_RESIDUAL_CONCAT = "concat" 21 | META_DIR = "metadir" 22 | NODE_DST = "dst" 23 | NODE_FEATURES = "node_feats" 24 | NODE_SRC = "src" 25 | OUTPUT = "output" 26 | PREFIX_CHECKPOINT = "gnnmodel" 27 | PREFIX_EMBEDDING = "embedding" 28 | TERM_TENSOR = "term_tensor" 29 | TRILETTER = "triletter" 30 | TRILETTER_MAX_LETTERS_IN_WORD = "triletter_max_letters_in_word" 31 | VOCAB_FILE = "vocab_file" 32 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/common/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """Dataset implementation for torch models.""" 4 | 5 | import torch 6 | from deepgnn.graph_engine import ( 7 | DeepGNNDataset, 8 | GraphEngineBackend, 9 | ) 10 | from torch.utils.data import IterableDataset 11 | from typing import Optional, Callable, Iterator 12 | 13 | 14 | class TorchDeepGNNDataset(IterableDataset, DeepGNNDataset): 15 | """Implementation of TorchDeepGNNDataset for use in a Torch Dataloader. 16 | 17 | TorchDeepGNNDataset initializes and executes a node or edge sampler given as 18 | sampler_class. For every batch of data requested, batch_size items are sampled 19 | from the sampler and passed to the given query_fn which pulls all necessary 20 | information about the samples using the graph engine API. The output from 21 | the query function is passed to the trainer worker as the input to the 22 | model forward function. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | sampler_class, 28 | query_fn: Callable, 29 | backend: Optional[GraphEngineBackend] = None, 30 | num_workers: int = 1, 31 | worker_index: int = 0, 32 | batch_size: int = 1, 33 | epochs: int = 1, 34 | enable_prefetch: bool = False, 35 | # parameters to initialize samplers 36 | **kwargs, 37 | ): 38 | """Initialize TorchDeepGNNDataset.""" 39 | super(TorchDeepGNNDataset, self).__init__( 40 | sampler_class, 41 | query_fn, 42 | backend, 43 | num_workers, 44 | worker_index, 45 | batch_size, 46 | epochs, 47 | enable_prefetch, 48 | **kwargs, 49 | ) 50 | 51 | def init_graph_client(self): 52 | """No-op function. 53 | 54 | When using multiple processes to load the data in 55 | parallel, each process should have its own copy of the graph 56 | client, otherwise there will be a segmentation fault.
Here 57 | we return None to postpone the initialization of the 58 | graph/sampler to __iter__. 59 | """ 60 | pass 61 | 62 | def init_sampler(self): 63 | """No-op function. 64 | 65 | Override the base method to postpone the sampler initialization to __iter__. 66 | """ 67 | return 68 | 69 | def _torch_init_sampler(self): 70 | # get the 'deep' copy of the graph. 71 | self.graph = self.backend.graph 72 | 73 | worker_info = torch.utils.data.get_worker_info() 74 | if worker_info is not None: 75 | self.kwargs.update( 76 | { 77 | "data_parallel_index": worker_info.id, 78 | "data_parallel_num": worker_info.num_workers, 79 | } 80 | ) 81 | super().init_sampler() 82 | 83 | def __iter__(self) -> Iterator: 84 | """Create sampler and start iteration.""" 85 | self._torch_init_sampler() 86 | return DeepGNNDataset.__iter__(self) 87 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/common/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Metrics implementation for graphsage models.""" 5 | 6 | import torch 7 | from sklearn.metrics import f1_score, roc_auc_score, accuracy_score 8 | 9 | 10 | class BaseMetric(object): 11 | """Base class for metrics.""" 12 | 13 | def name(self) -> str: 14 | """Return name of the metric.""" 15 | return self.__class__.__name__ 16 | 17 | def compute(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 18 | """Compute metric value based on model scores and expected labels.""" 19 | raise NotImplementedError 20 | 21 | 22 | class F1Score(BaseMetric): 23 | """F1 score implementation.""" 24 | 25 | def compute(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 26 | """Passthrough to scikit.""" 27 | return torch.tensor( 28 | f1_score(labels.squeeze(), scores.detach().cpu().numpy(), average="micro") 29 | ) 30 | 31 | 32 | class Accuracy(BaseMetric): 33 | """Accuracy classification score.""" 34 | 35 | def compute(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 36 | """Passthrough to scikit.""" 37 | return torch.tensor( 38 | accuracy_score(y_true=labels.cpu(), y_pred=scores.detach().cpu().numpy()) 39 | ) 40 | 41 | 42 | class MRR(BaseMetric): 43 | """MRR score implementation.""" 44 | 45 | def __init__(self, rank_in_ascending_order: bool = False): 46 | """ 47 | Initialize MRR metric. 48 | 49 | rank_in_ascending_order: 50 | Whether to compute the rank in ascending or 51 | descending order; if set to True, the rank is 52 | calculated in ascending order.
53 | """ 54 | super().__init__() 55 | self.rank_in_ascending_order = rank_in_ascending_order 56 | 57 | def compute(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 58 | """Compute metric based on logit scores.""" 59 | assert len(scores.shape) > 1 60 | assert scores.size() == labels.size() 61 | 62 | size = scores.shape[-1] 63 | if self.rank_in_ascending_order: 64 | scores = -1 * scores 65 | _, indices_of_ranks = torch.topk(scores, k=size) 66 | _, ranks = torch.topk(-indices_of_ranks, k=size) 67 | return torch.mean( 68 | torch.reciprocal( 69 | torch.matmul(ranks.float(), torch.transpose(labels, -2, -1).float()) + 1 70 | ) 71 | ) 72 | 73 | 74 | class ROC(BaseMetric): 75 | """ROC score implementation with scikit.""" 76 | 77 | def compute(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 78 | """Compute metric value based on model scores and expected labels.""" 79 | return torch.tensor(roc_auc_score(labels.cpu(), scores.cpu().detach().numpy())) 80 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/common/test_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import pytest 5 | 6 | import torch 7 | import deepgnn.pytorch.common.metrics as metrics 8 | 9 | 10 | @pytest.fixture 11 | def mrr_test_tensors(): 12 | # Return list of scores, labels, ascending expected value, descending expected value 13 | return [ 14 | ( 15 | torch.tensor([[1, 2.3]]), 16 | torch.tensor([[0, 1]]), # Positive example in the end. 17 | 0.5, # Ascending expected value 18 | 1.0, # Descending expected value 19 | ), 20 | ( 21 | torch.tensor([[1.0, 0.08, 0.01]]), 22 | torch.tensor([[0, 1, 0]]), # Positive example in the middle. 23 | 0.5, # Ascending expected value 24 | 0.5, # Descending expected value 25 | ), 26 | ( 27 | torch.tensor([[0.2, 0.8, 0.1]]), 28 | torch.tensor([[0, 0, 1]]), # Positive example in the end. 29 | 1.0, # Ascending expected value 30 | (1.0 / 3.0), # Descending expected value 31 | ), 32 | ( 33 | torch.tensor([[0.2, 0.8, 0.1, 0.3, 0.4]]), 34 | torch.tensor([[1, 0, 0, 0, 0]]), # Positive example in the front. 35 | (1.0 / 2.0), # Ascending expected value 36 | (1.0 / 4.0), # Descending expected value 37 | ), 38 | ] 39 | 40 | 41 | def test_mrr_implementation(mrr_test_tensors): 42 | mrr_asc = metrics.MRR(rank_in_ascending_order=True) 43 | mrr_desc = metrics.MRR(rank_in_ascending_order=False) 44 | for scores, labels, asc_expected, desc_expected in mrr_test_tensors: 45 | asc_pred = mrr_asc.compute(scores, labels) 46 | assert asc_pred == asc_expected 47 | desc_pred = mrr_desc.compute(scores, labels) 48 | assert desc_pred == desc_expected 49 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/common/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from deepgnn.pytorch.common.utils import get_store_name_and_path 5 | 6 | 7 | def test_get_store_name_and_path(): 8 | input_path = "adl://grocery.azuredatalakestore.net/local/test/" 9 | store_name, relative_path = get_store_name_and_path(input_path) 10 | assert store_name == "grocery" 11 | assert relative_path == "/local/test/" 12 | 13 | input_path = "/local/test1/" 14 | store_name, relative_path = get_store_name_and_path(input_path) 15 | assert store_name == "" 16 | assert relative_path == "/local/test1/" 17 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_library( 8 | name = "deepgnn_pytorch_encoding", 9 | srcs = [ 10 | "__init__.py", 11 | "feature_encoder.py", 12 | "gnn_encoder_gat.py", 13 | "gnn_encoder_hetgnn.py", 14 | "gnn_encoder_lgcl.py", 15 | "gnn_encoder_lightgcn.py", 16 | "gnn_encoder_sage.py", 17 | ], 18 | deps = [ 19 | "//src/python/deepgnn/pytorch/encoding/twinbert:deepgnn_pytorch_encoding_twinbert", 20 | "//src/python/deepgnn/graph_engine/backends/snark:graph_engine_backends_snark", 21 | "//src/python/deepgnn/graph_engine/snark:graph_engine_snark", 22 | ], 23 | visibility = ["//visibility:public"], 24 | ) 25 | 26 | py_test( 27 | name = "test_feature_encoder", 28 | srcs = ["test_feature_encoder.py"], 29 | imports = ["../../../"], 30 | main = "test_feature_encoder.py", 31 | python_version = "PY3", 32 | srcs_version = "PY3", 33 | deps = [ 34 | ":deepgnn_pytorch_encoding", 35 | requirement("numpy"), 36 | requirement("pytest"), 37 | requirement("scikit-learn"), 38 | requirement("torch"), 39 | requirement("fsspec"), 40 | requirement("transformers"), 41 | requirement("networkx"), 42 | requirement("opencensus"), 43 | requirement("opencensus-context"), 44 | requirement("opencensus-ext-azure"), 45 | requirement("azure-datalake-store"), 46 | requirement("tenacity"), 47 | ], 48 | ) 49 | 50 | py_test( 51 | name = "test_gnn_encoders", 52 | srcs = ["test_gnn_encoders.py"], 53 | imports = ["../../../"], 54 | main = "test_gnn_encoders.py", 55 | python_version = "PY3", 56 | srcs_version = "PY3", 57 | deps = [ 58 | ":deepgnn_pytorch_encoding", 59 | requirement("numpy"), 60 | requirement("pytest"), 61 | requirement("fsspec"), 62 | requirement("scikit-learn"), 63 | requirement("torch"), 64 | requirement("transformers"), 65 | requirement("networkx"), 66 | requirement("opencensus"), 67 | requirement("opencensus-context"), 68 | requirement("opencensus-ext-azure"), 69 | requirement("azure-datalake-store"), 70 | requirement("tenacity"), 71 | ], 72 | ) 73 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | """Collection of encoders for GNNs.""" 4 | # flake8: noqa 5 | from .feature_encoder import ( 6 | FeatureEncoder, 7 | TwinBERTEncoder, 8 | TwinBERTFeatureEncoder, 9 | MultiTypeFeatureEncoder, 10 | get_feature_encoder, 11 | ) 12 | 13 | from .gnn_encoder_gat import GatEncoder 14 | from .gnn_encoder_hetgnn import HetGnnEncoder 15 | from .gnn_encoder_lgcl import LgclEncoder 16 | from .gnn_encoder_sage import SageEncoder 17 | from .gnn_encoder_lightgcn import LightGCNEncoder 18 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/gnn_encoder_gat.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """Encoder for GAT model.""" 4 | from typing import Optional, Callable 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class GatEncoder(nn.Module): 11 | """Reference: https://arxiv.org/pdf/1710.10903.pdf.""" 12 | 13 | def __init__( 14 | self, 15 | in_features: int, 16 | out_features: int, 17 | dropout: float = 0.2, 18 | negative_slope: float = 1e-2, 19 | concat: bool = True, 20 | act: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, 21 | ): 22 | """Initialize encoder.""" 23 | super(GatEncoder, self).__init__() 24 | 25 | self.dropout = dropout 26 | self.in_features = in_features 27 | self.out_features = out_features 28 | self.negative_slope = negative_slope 29 | self.concat = concat 30 | self.act = act 31 | 32 | self.leakyrelu = nn.LeakyReLU(self.negative_slope) 33 | self.fc = nn.Linear(self.in_features, self.out_features, bias=False) 34 | self.attn_l = nn.Linear(self.out_features, 1) 35 | self.attn_r = nn.Linear(self.out_features, 1) 36 | 37 | def forward(self, combind_feats: torch.Tensor) -> torch.Tensor: 38 | """Evaluate encoder.""" 39 | feats = self.fc(combind_feats) 40 | f_1 = self.attn_l(feats) 41 | f_2 = self.attn_r(feats) 42 | logits = f_1 + f_2.permute(0, 2, 1) 43 | coefs = F.softmax(self.leakyrelu(logits), dim=1) 44 | ret = torch.matmul(coefs, feats) 45 | 46 | if self.act is None: 47 | return ret 48 | else: 49 | return self.act(ret) 50 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/gnn_encoder_lightgcn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | """Reference: http://staff.ustc.edu.cn/~hexn/papers/sigir20-LightGCN.pdf.""" 4 | import torch 5 | import torch.nn as nn 6 | from deepgnn.pytorch.common.consts import INPUTS, FANOUTS 7 | 8 | 9 | class LightGCNEncoder(nn.Module): 10 | """Encoder for lightGCN model.""" 11 | 12 | def __init__(self): 13 | """Initialize underlying nn.Module.""" 14 | super(LightGCNEncoder, self).__init__() 15 | 16 | def forward(self, context: dict) -> torch.Tensor: 17 | """Evaluate encoder.""" 18 | samples = context[INPUTS] 19 | fanouts = context[FANOUTS] 20 | 21 | num_layers = len(fanouts) 22 | if num_layers == 0: 23 | return samples[0] 24 | 25 | assert num_layers == 1 26 | neighbor = torch.reshape(samples[1], [-1, fanouts[0], samples[1].shape[-1]]) 27 | neighbor = torch.mean(neighbor, dim=1) 28 | seq = (samples[0] + neighbor) / 2 29 | return seq 30 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/twinbert/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 4 | load("@pip_deps//:requirements.bzl", "requirement") 5 | 6 | py_library( 7 | name = "deepgnn_pytorch_encoding_twinbert", 8 | srcs = [ 9 | "__init__.py", 10 | "configuration.py", 11 | "embedding.py", 12 | "encoder.py", 13 | "pooler.py", 14 | "tokenization.py", 15 | ], 16 | deps = [ 17 | "//src/python/deepgnn/pytorch:deepgnn_pytorch", 18 | "//src/python/deepgnn/pytorch/common:deepgnn_pytorch_common", 19 | ], 20 | visibility = ["//visibility:public"], 21 | ) 22 | 23 | py_test( 24 | name = "test_encoder", 25 | srcs = ["test_encoder.py"], 26 | imports = ["../../../../"], 27 | main = "test_encoder.py", 28 | python_version = "PY3", 29 | srcs_version = "PY3", 30 | deps = [ 31 | ":deepgnn_pytorch_encoding_twinbert", 32 | requirement("numpy"), 33 | requirement("fsspec"), 34 | requirement("pytest"), 35 | requirement("scikit-learn"), 36 | requirement("torch"), 37 | requirement("networkx"), 38 | requirement("transformers"), 39 | requirement("opencensus"), 40 | requirement("opencensus-context"), 41 | requirement("opencensus-ext-azure"), 42 | requirement("azure-datalake-store"), 43 | ], 44 | ) 45 | 46 | py_test( 47 | name = "test_tokenization", 48 | srcs = ["test_tokenization.py"], 49 | imports = ["../../../../"], 50 | main = "test_tokenization.py", 51 | python_version = "PY3", 52 | srcs_version = "PY3", 53 | deps = [ 54 | ":deepgnn_pytorch_encoding_twinbert", 55 | requirement("numpy"), 56 | requirement("pytest"), 57 | requirement("fsspec"), 58 | requirement("scikit-learn"), 59 | requirement("transformers"), 60 | requirement("torch"), 61 | requirement("networkx"), 62 | requirement("opencensus"), 63 | requirement("opencensus-context"), 64 | requirement("opencensus-ext-azure"), 65 | requirement("azure-datalake-store"), 66 | ], 67 | ) 68 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/twinbert/__init__.py: -------------------------------------------------------------------------------- 1 | # DeepGNN note, TriLetterTokenizer, TriletterEmbeddings and WeightPooler were taken from Author's implementation of TwinBERT. 2 | # Details please refer to https://arxiv.org/abs/2002.06275. 
3 | 4 | # flake8: noqa 5 | from .encoder import TwinBERTEncoder 6 | from .tokenization import TriLetterTokenizer, StdBertTokenizer 7 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/twinbert/configuration.py: -------------------------------------------------------------------------------- 1 | class DeepSpeedArgs: 2 | def __init__(self, config: dict, local_rank=0): 3 | 4 | dsconfig = config["deepspeed"] if "deepspeed" in config else {} 5 | 6 | # Use DeepSpeed transformer kernel to accelerate. 7 | self.deepspeed_transformer_kernel = ( 8 | dsconfig["transformer_kernel"] 9 | if "transformer_kernel" in dsconfig 10 | else False 11 | ) 12 | 13 | # Total batch size for training, only used for summary writer. 14 | self.train_batch_size = ( 15 | dsconfig["train_batch_size"] if "train_batch_size" in dsconfig else 1024 16 | ) 17 | 18 | self.train_micro_batch_size_per_gpu = ( 19 | dsconfig["train_micro_batch_size_per_gpu"] 20 | if "train_micro_batch_size_per_gpu" in dsconfig 21 | else 1024 22 | ) 23 | 24 | # Use stochastic mode for high-performance transformer kernel. 25 | self.stochastic_mode = ( 26 | dsconfig["stochastic_mode"] if "stochastic_mode" in dsconfig else False 27 | ) 28 | 29 | # Use DeepSpeed transformer kernel memory optimization to perform invertible normalize 30 | # backpropagation. 31 | self.normalize_invertible = ( 32 | dsconfig["normalize_invertible"] 33 | if "normalize_invertible" in dsconfig 34 | else False 35 | ) 36 | 37 | # random seed for initialization 38 | self.seed = dsconfig["seed"] if "seed" in dsconfig else 42 39 | 40 | self.local_rank = local_rank 41 | 42 | # use global fp16 setting 43 | self.fp16_enabled = config["enable_fp16"] if "enable_fp16" in config else False 44 | 45 | self.apex = dsconfig["apex"] if "apex" in dsconfig else True 46 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/twinbert/deepspeed/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | load("@rules_python//python:defs.bzl", "py_library") 4 | 5 | py_library( 6 | name = "deepgnn_pytorch_encoding_twinbert_ds", 7 | srcs = [ 8 | "__init__.py", 9 | "convert_bert_ckpt_to_deepspeed.py", 10 | "file_utils.py", 11 | "loss.py", 12 | "nvidia_modeling_no_apex.py", 13 | "nvidia_modeling.py", 14 | ], 15 | visibility = ["//visibility:public"], 16 | ) 17 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/twinbert/deepspeed/__init__.py: -------------------------------------------------------------------------------- 1 | # DeepGNN note, code in this package was taken from commit 1bee84f6eb75ed7e39e34601bfdd66d79cafe99a. 2 | # https://github.com/microsoft/DeepSpeedExamples/tree/1bee84f6eb75ed7e39e34601bfdd66d79cafe99a/BingBertSquad/turing 3 | # Several trivial changes were made on top: 4 | # 1. Removed the dependency for `deepspeed_config` in nvidia_modeling.py. 5 | # 2. Removed `max_seq_length` from the parameter of DeepSpeedTransformerConfig in nvidia_modeling.py 6 | # 3. Replaced `dense_act` with `dense` in nvidia_modeling.py to make it compatible to checkpoint published by transformers@huggingface. 
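`DeepSpeedArgs` in configuration.py reads every option with the verbose `value if key in dict else default` pattern. A behaviorally equivalent sketch using `dict.get` with the same keys and defaults (a hypothetical refactoring, not the repository's code):

```python
class DeepSpeedArgsSketch:
    """Equivalent parsing with dict.get; mirrors the defaults in configuration.py."""

    def __init__(self, config: dict, local_rank: int = 0):
        ds = config.get("deepspeed", {})
        self.deepspeed_transformer_kernel = ds.get("transformer_kernel", False)
        self.train_batch_size = ds.get("train_batch_size", 1024)
        self.train_micro_batch_size_per_gpu = ds.get("train_micro_batch_size_per_gpu", 1024)
        self.stochastic_mode = ds.get("stochastic_mode", False)
        self.normalize_invertible = ds.get("normalize_invertible", False)
        self.seed = ds.get("seed", 42)
        self.local_rank = local_rank
        # fp16 comes from the top-level config, not the deepspeed sub-dict.
        self.fp16_enabled = config.get("enable_fp16", False)
        self.apex = ds.get("apex", True)
```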
7 | 
--------------------------------------------------------------------------------
/src/python/deepgnn/pytorch/encoding/twinbert/deepspeed/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | 
6 | # fmt: off
7 | class FocalLoss(nn.Module):
8 |     r"""
9 |     This criterion is an implementation of Focal Loss, proposed in
10 |     Focal Loss for Dense Object Detection.
11 | 
12 |         Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
13 | 
14 |     The losses are averaged across observations for each minibatch.
15 |     Args:
16 |         alpha(1D Tensor, Variable) : the scalar factor for this criterion
17 |         gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
18 |                                putting more focus on hard, misclassified examples
19 |         size_average(bool): By default, the losses are averaged over observations for each minibatch.
20 |                             However, if the field size_average is set to False, the losses are
21 |                             instead summed for each minibatch.
22 |     """
23 |     def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
24 |         super(FocalLoss, self).__init__()
25 |         if alpha is None:
26 |             self.alpha = torch.ones(class_num, 1)
27 |         else:
28 |             if isinstance(alpha, Variable):
29 |                 self.alpha = alpha
30 |             else:
31 |                 self.alpha = Variable(alpha)
32 |         self.gamma = gamma
33 |         self.class_num = class_num
34 |         self.size_average = size_average
35 | 
36 |     def forward(self, inputs, targets):
37 |         N = inputs.size(0)
38 |         C = inputs.size(1)
39 |         P = F.softmax(inputs, dim=1)
40 | 
41 |         class_mask = inputs.data.new(N, C).fill_(0)
42 |         # class_mask = Variable(class_mask)
43 |         ids = targets.view(-1, 1)
44 |         class_mask.scatter_(1, ids.data, 1.)
45 | 
46 |         if inputs.is_cuda and not self.alpha.is_cuda:
47 |             self.alpha = self.alpha.cuda()
48 |         alpha = self.alpha[ids.data.view(-1)]
49 | 
50 |         probs = (P * class_mask).sum(1).view(-1, 1)
51 | 
52 |         log_p = probs.log()
53 | 
54 |         batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p
55 | 
56 |         if self.size_average:
57 |             loss = batch_loss.mean()
58 |         else:
59 |             loss = batch_loss.sum()
60 |         return loss
61 | # fmt: on
--------------------------------------------------------------------------------
/src/python/deepgnn/pytorch/encoding/twinbert/embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class TriletterEmbeddings(nn.Module):
6 |     """Taken from TwinBERT Author's implementation."""
7 | 
8 |     def __init__(self, config):
9 |         super(TriletterEmbeddings, self).__init__()
10 |         self.max_letters_in_word = config["triletter_max_letters_in_word"]
11 |         self.triletter_embeddings = nn.Embedding(
12 |             config["vocab_size"] + 1, config["hidden_size"], padding_idx=0
13 |         )
14 |         self.position_embeddings = nn.Embedding(
15 |             config["max_position_embeddings"] + 1, config["hidden_size"], padding_idx=0
16 |         )
17 | 
18 |     def forward(self, input_ids, token_type_ids=None):
19 |         seq_len = input_ids.shape[1] // self.max_letters_in_word
20 | 
21 |         position_ids = (
22 |             torch.arange(seq_len, dtype=torch.long, device=input_ids.device) + 1
23 |         )
24 |         position_ids = position_ids.unsqueeze(0).repeat(input_ids.shape[0], 1)
25 | 
26 |         # below two lines may be useful when we want to convert to onnx.
27 | # position_ids[attention_mask == 0] = 0 28 | # position_ids = position_ids.type(torch.float).masked_fill(attention_mask==0, 0.0).type(torch.int64) 29 | 30 | position_embeddings = self.position_embeddings(position_ids) 31 | 32 | # [batch_size, max_seq_len*max_letters_in_word, hidden_size] 33 | embeddings = self.triletter_embeddings(input_ids) 34 | 35 | # [batch_size, max_seq_len, max_letters_in_word, hidden_size] 36 | embeddings = embeddings.view( 37 | -1, seq_len, self.max_letters_in_word, embeddings.shape[-1] 38 | ) 39 | embeddings = embeddings.sum(dim=2).view(-1, seq_len, embeddings.shape[-1]) 40 | embeddings = embeddings + position_embeddings 41 | 42 | return embeddings 43 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/twinbert/pooler.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as fun 3 | 4 | 5 | class WeightPooler(nn.Module): 6 | """Taken from TwinBERT Author's implementation.""" 7 | 8 | def __init__(self, config): 9 | super(WeightPooler, self).__init__() 10 | self.weighting = nn.Linear(config["hidden_size"], 1) 11 | # scale down to 1e-4 to avoid fp16 overflow. 12 | self.weight_factor = ( 13 | 1e-4 if "enable_fp16" in config and config["enable_fp16"] else 1e-8 14 | ) 15 | 16 | def forward(self, term_tensor, mask): 17 | weights = self.weighting(term_tensor) 18 | weights = ( 19 | weights + (mask - 1).type(weights.dtype).unsqueeze(2) / self.weight_factor 20 | ) 21 | weights = fun.softmax(weights, dim=1) 22 | return (term_tensor * weights).sum(dim=1) 23 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/twinbert/test_tokenization.py: -------------------------------------------------------------------------------- 1 | import os 2 | from deepgnn.pytorch.encoding.twinbert.tokenization import ( 3 | StdBertTokenizer, 4 | TriLetterTokenizer, 5 | ) 6 | import urllib.request 7 | import zipfile 8 | import tempfile 9 | import pytest 10 | 11 | 12 | @pytest.fixture(scope="module") 13 | def prepare_local_test_files(): 14 | name = "twinbert.zip" 15 | working_dir = tempfile.TemporaryDirectory() 16 | zip_file = os.path.join(working_dir.name, name) 17 | urllib.request.urlretrieve( 18 | f"https://deepgraphpub.blob.core.windows.net/public/testdata/{name}", zip_file 19 | ) 20 | with zipfile.ZipFile(zip_file, "r") as zip_ref: 21 | zip_ref.extractall(working_dir.name) 22 | 23 | yield working_dir.name 24 | working_dir.cleanup() 25 | 26 | 27 | @pytest.mark.skip(reason="Deprecated") 28 | def test_stdberttokenizer(prepare_local_test_files): 29 | sentence = "hello world" 30 | tokenizer = StdBertTokenizer( 31 | os.path.join(prepare_local_test_files, "twinbert", "uncased_eng_vocab.tsv") 32 | ) 33 | seq, mask = tokenizer.extract_from_sentence(sentence) 34 | assert seq == [7592, 2088, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 35 | assert mask == [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 36 | 37 | 38 | @pytest.mark.skip(reason="Deprecated") 39 | def test_trilettertokenizer(prepare_local_test_files): 40 | sentence = "hello world" 41 | tokenizer = TriLetterTokenizer( 42 | os.path.join(prepare_local_test_files, "twinbert", "l3g.txt") 43 | ) 44 | seq, mask = tokenizer.extract_from_sentence( 45 | sentence, max_seq_len=5, max_n_letters=3 46 | ) 47 | assert seq == [1121, 1450, 766, 684, 685, 1164, 0, 0, 0, 0, 0, 0, 0, 0, 0] 48 | assert mask == [1, 1, 0, 0, 0] 49 | 
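In `WeightPooler` above, `(mask - 1) / weight_factor` adds zero to valid positions and a very large negative number to padded ones, so the subsequent softmax gives padding effectively zero weight (the factor is relaxed to 1e-4 under fp16 to avoid overflow). A standalone check:

```python
import torch
import torch.nn.functional as fun

weights = torch.tensor([[[0.5], [0.2], [0.9]]])  # [batch=1, seq=3, 1]
mask = torch.tensor([[1.0, 1.0, 0.0]])           # last position is padding
weight_factor = 1e-8

masked = weights + (mask - 1).unsqueeze(2) / weight_factor  # padding -> ~ -1e8
probs = fun.softmax(masked, dim=1)
print(probs.squeeze())  # padded position receives ~0 attention weight
```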
-------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/encoding/twinbert/tokenization.py: -------------------------------------------------------------------------------- 1 | import re 2 | from transformers.models.bert.tokenization_bert import BertTokenizer # type: ignore 3 | 4 | 5 | class StdBertTokenizer(BertTokenizer): 6 | def __init__(self, vocab_file): 7 | super(StdBertTokenizer, self).__init__(vocab_file) 8 | 9 | def extract_from_sentence(self, sentence, max_seq_len=12): 10 | seq = [0] * max_seq_len 11 | mask = [0] * max_seq_len 12 | ret = self.encode(sentence, add_special_tokens=False) 13 | seq_len = min(max_seq_len, len(ret)) 14 | seq[0:seq_len] = ret[0:seq_len] 15 | mask[0:seq_len] = [1] * seq_len 16 | return seq, mask 17 | 18 | 19 | class TriLetterTokenizer: 20 | """Taken from TwinBERT Author's implementation.""" 21 | 22 | def __init__(self, l3g_path): 23 | self._init_lg3_dict(l3g_path) 24 | self.invalid = re.compile("[^a-zA-Z0-9 ]") 25 | self.multispace = re.compile(" +") 26 | 27 | def _init_lg3_dict(self, l3g_path): 28 | self.l3g_dict = {} 29 | with open(l3g_path, "r", encoding="utf-8") as fin: 30 | for i, token in enumerate(fin): 31 | token = token.strip("\r\n") 32 | if len(token) == 0: 33 | continue 34 | # reserve 0 as default, start from 1 35 | self.l3g_dict[token] = i + 1 36 | 37 | def extract_from_sentence(self, text, max_seq_len=12, max_n_letters=20): 38 | step1 = text.lower() 39 | step2 = self.invalid.sub("", step1) 40 | step3 = self.multispace.sub(" ", step2) 41 | step4 = step3.strip() 42 | words = step4.split(" ") 43 | return self._from_words_to_id_sequence(words, max_seq_len, max_n_letters) 44 | 45 | def _from_words_to_id_sequence(self, words, max_seq_len=12, max_n_letters=20): 46 | n_seq = min(len(words), max_seq_len) 47 | n_letter = max_n_letters 48 | feature_seq = [0] * (max_seq_len * max_n_letters) 49 | seq_mask = [0] * max_seq_len 50 | for i in range(n_seq): 51 | if words[i] == "": 52 | words[i] = "#" 53 | word = "#" + words[i] + "#" 54 | n_letter = min(len(word) - 2, max_n_letters) 55 | for j in range(n_letter): 56 | s = word[j : (j + 3)] 57 | if s in self.l3g_dict: 58 | feature_seq[i * max_n_letters + j] = self.l3g_dict[s] 59 | seq_mask[i] = 1 60 | return feature_seq, seq_mask 61 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/modeling/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | load("@rules_python//python:defs.bzl", "py_library") 5 | 6 | py_library( 7 | name = "deepgnn_pytorch_modeling", 8 | srcs = [ 9 | "__init__.py", 10 | "base_model.py", 11 | ], 12 | visibility = ["//visibility:public"], 13 | ) 14 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # flake8: noqa 5 | from .base_model import BaseModel, BaseSupervisedModel, BaseUnsupervisedModel 6 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/nn/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
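To make the tri-letter scheme of `TriLetterTokenizer` above concrete: each word is padded with `#` on both sides and a 3-character window slides over it; every trigram is then looked up in `l3g_dict` (the ids in the tests, such as 1121 and 1450, depend entirely on the vocabulary file). The windowing itself:

```python
word = "#" + "hello" + "#"
trigrams = [word[j : j + 3] for j in range(len(word) - 2)]
print(trigrams)  # ['#he', 'hel', 'ell', 'llo', 'lo#']
# In the tokenizer, at most max_n_letters trigrams per word are kept.
```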
3 | 4 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_library( 8 | name = "deepgnn_pytorch_nn", 9 | srcs = [ 10 | "__init__.py", 11 | "gat_conv.py", 12 | ], 13 | visibility = ["//visibility:public"], 14 | ) 15 | 16 | py_test( 17 | name = "test_conv", 18 | srcs = ["test_conv.py"], 19 | imports = ["../../../"], 20 | main = "test_conv.py", 21 | python_version = "PY3", 22 | srcs_version = "PY3", 23 | deps = [ 24 | ":deepgnn_pytorch_nn", 25 | requirement("numpy"), 26 | requirement("fsspec"), 27 | requirement("pytest"), 28 | requirement("scikit-learn"), 29 | requirement("torch"), 30 | requirement("networkx"), 31 | requirement("opencensus"), 32 | requirement("opencensus-context"), 33 | requirement("opencensus-ext-azure"), 34 | requirement("azure-datalake-store"), 35 | ], 36 | ) 37 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """Shared torch.nn modules.""" 4 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/nn/test_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | import numpy as np 6 | import random 7 | 8 | from deepgnn.pytorch.nn.gat_conv import AttnHead 9 | 10 | 11 | def set_seed(seed): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | 17 | 18 | def test_attn_head(): 19 | N = 5 20 | D = 3 21 | 22 | set_seed(123) 23 | x = np.random.rand(N, D).astype(np.float32) # [N, D] 24 | adj = np.random.randint(2, size=N * N).astype(np.float32).reshape(N, N) 25 | 26 | def run_dense_version(): 27 | set_seed(123) 28 | attn_layer = AttnHead(D, 4, in_drop=0.2, coef_drop=0.0) 29 | x2 = torch.from_numpy(x) 30 | adj2 = torch.from_numpy(adj) 31 | out = attn_layer(x2, adj2) 32 | return out 33 | 34 | def run_sparse_version(): 35 | set_seed(123) 36 | attn_layer = AttnHead(D, 4, in_drop=0.2, coef_drop=0.0) 37 | row, col = np.nonzero(adj) 38 | edge = np.concatenate([row.reshape(-1, 1), col.reshape(-1, 1)], axis=1) 39 | edge_value = np.ones(edge.shape[0], np.float32) 40 | edge = np.transpose(edge) 41 | sp_adj = torch.sparse_coo_tensor(edge, edge_value, [N, N]) 42 | x2 = torch.from_numpy(x) 43 | out = attn_layer(x2, sp_adj) 44 | return out 45 | 46 | dense_out = run_dense_version() 47 | sparse_out = run_sparse_version() 48 | np.testing.assert_allclose( 49 | dense_out.detach().numpy(), sparse_out.detach().numpy(), atol=0.001 50 | ) 51 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/training/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | load("@rules_python//python:defs.bzl", "py_test", "py_library") 5 | load("@pip_deps//:requirements.bzl", "requirement") 6 | 7 | py_library( 8 | name = "deepgnn_pytorch_training", 9 | srcs = [ 10 | "__init__.py", 11 | "args.py", 12 | "factory.py", 13 | "trainer_ddp.py", 14 | "trainer_fp16.py", 15 | "trainer_hvd.py", 16 | "trainer.py", 17 | "utils.py", 18 | ], 19 | deps = [ 20 | "//src/python/deepgnn/pytorch/modeling:deepgnn_pytorch_modeling", 21 | "//src/python/deepgnn/pytorch/encoding:deepgnn_pytorch_encoding", 22 | ], 23 | visibility = ["//visibility:public"], 24 | ) 25 | 26 | py_test( 27 | name = "test_trainer", 28 | srcs = ["test_trainer.py"], 29 | imports = ["../../../"], 30 | main = "test_trainer.py", 31 | python_version = "PY3", 32 | srcs_version = "PY3", 33 | deps = select({ 34 | "@platforms//os:linux": [ 35 | ":deepgnn_pytorch_training", 36 | requirement("horovod"), 37 | requirement("fsspec"), 38 | requirement("numpy"), 39 | requirement("pytest"), 40 | requirement("scikit-learn"), 41 | requirement("torch"), 42 | requirement("networkx"), 43 | requirement("tensorboard"), 44 | requirement("transformers"), 45 | requirement("opencensus"), 46 | requirement("opencensus-context"), 47 | requirement("opencensus-ext-azure"), 48 | requirement("azure-datalake-store"), 49 | requirement("tenacity"), 50 | ], 51 | "@platforms//os:macos": [ 52 | ":deepgnn_pytorch_training", 53 | requirement("horovod"), 54 | requirement("numpy"), 55 | requirement("fsspec"), 56 | requirement("pytest"), 57 | requirement("scikit-learn"), 58 | requirement("torch"), 59 | requirement("networkx"), 60 | requirement("tensorboard"), 61 | requirement("transformers"), 62 | requirement("opencensus"), 63 | requirement("opencensus-context"), 64 | requirement("opencensus-ext-azure"), 65 | requirement("azure-datalake-store"), 66 | requirement("tenacity"), 67 | ], 68 | "@platforms//os:windows": [ 69 | ":deepgnn_pytorch_training", 70 | requirement("numpy"), 71 | requirement("pytest"), 72 | requirement("scikit-learn"), 73 | requirement("fsspec"), 74 | requirement("torch"), 75 | requirement("networkx"), 76 | requirement("tensorboard"), 77 | requirement("transformers"), 78 | requirement("opencensus"), 79 | requirement("opencensus-context"), 80 | requirement("opencensus-ext-azure"), 81 | requirement("azure-datalake-store"), 82 | requirement("tenacity"), 83 | ], 84 | }), 85 | ) 86 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # flake8: noqa 5 | from .factory import run_dist 6 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/training/trainer_hvd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | """Distributed training with horovod.""" 4 | 5 | from typing import Tuple 6 | import argparse 7 | import torch 8 | from torch.optim import Optimizer 9 | from deepgnn.pytorch.training.trainer_fp16 import FP16Trainer, BaseModel 10 | from deepgnn.pytorch.training.utils import disable_infini_band 11 | import horovod.torch as hvd # type: ignore 12 | 13 | 14 | class HVDTrainer(FP16Trainer): 15 | """Horovod based distributed trainer.""" 16 | 17 | def __init__(self, args: argparse.Namespace): 18 | """Initialize horovod.""" 19 | super().__init__(args) 20 | self._init_hvd() 21 | 22 | def _evaluate(self, model: BaseModel) -> Tuple[torch.Tensor, torch.Tensor]: 23 | metric, loss = super()._evaluate(model) 24 | metric = hvd.allreduce(metric) 25 | loss = hvd.allreduce(loss) 26 | self.logger.info( 27 | self._wrap_log( 28 | f"AllReduced {self.model.metric_name()}: {metric:.4f}; loss: {loss:.4f}" 29 | ) 30 | ) 31 | return metric, loss 32 | 33 | def _init_hvd(self): 34 | if self.args.disable_ib: 35 | disable_infini_band() 36 | hvd.init() 37 | self.rank = hvd.rank() 38 | self.local_rank = hvd.local_rank() 39 | self.world_size = hvd.size() 40 | self.logger.info( 41 | f"Initialized horovod trainer. rank:{self.rank}, local_rank:{self.local_rank}," 42 | f" world_size:{self.world_size}" 43 | ) 44 | 45 | def _init_model(self, model: BaseModel) -> BaseModel: 46 | model = super()._init_model(model) 47 | hvd.broadcast_parameters(model.state_dict(), root_rank=0) 48 | return model 49 | 50 | def _init_optimizer(self, optimizer: Optimizer) -> Optimizer: 51 | optimizer = super()._init_optimizer(optimizer) 52 | hvd.broadcast_optimizer_state(optimizer, root_rank=0) 53 | compression = ( 54 | hvd.Compression.fp16 if self.fp16_enabled() else hvd.Compression.none 55 | ) 56 | return hvd.DistributedOptimizer( 57 | optimizer=optimizer, 58 | named_parameters=self.model.named_parameters(), 59 | compression=compression, 60 | op=hvd.Average, 61 | ) 62 | 63 | def _train_one_epoch(self, model: BaseModel, epoch: int): 64 | super()._train_one_epoch(model, epoch) 65 | hvd.join() 66 | 67 | def _inference(self, model: BaseModel): 68 | super()._inference(model) 69 | hvd.join() 70 | 71 | def _apex_backward(self, scaled_loss: torch.Tensor): 72 | scaled_loss.backward() 73 | self.optimizer.synchronize() # type: ignore 74 | 75 | def _apex_step(self): 76 | with self.optimizer.skip_synchronize(): 77 | self.optimizer.step() 78 | 79 | def _amp_backward(self, loss): 80 | self.grad_scaler.scale(loss).backward() 81 | self.optimizer.synchronize() 82 | 83 | def _amp_step(self): 84 | with self.optimizer.skip_synchronize(): 85 | self.grad_scaler.step(self.optimizer) 86 | -------------------------------------------------------------------------------- /src/python/deepgnn/pytorch/training/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | """Training utility functions.""" 4 | 5 | import os 6 | from deepgnn import get_logger 7 | 8 | 9 | def disable_infini_band(): 10 | """Disable InifiniBand for communication.""" 11 | os.environ["GLOO_SOCKET_IFNAME"] = "eth0" 12 | os.environ["NCCL_SOCKET_IFNAME"] = "eth0" 13 | os.environ["NCCL_IB_DISABLE"] = "1" 14 | get_logger().warn("InfiniBand(IB) has been disabled, use eth0 instead.") 15 | -------------------------------------------------------------------------------- /src/python/deepgnn/tf/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | load("@rules_python//python:defs.bzl", "py_library") 5 | 6 | py_library( 7 | name = "deepgnn_tf", 8 | srcs = [ 9 | "__init__.py", 10 | ], 11 | deps = [ 12 | "//src/python/deepgnn:deepgnn", 13 | "//src/python/deepgnn/graph_engine:graph_engine", 14 | "//src/python/deepgnn/graph_engine:graph_engine_exports", 15 | ], 16 | visibility = ["//visibility:public"], 17 | deprecation = "This target is deprecated", 18 | ) 19 | -------------------------------------------------------------------------------- /src/python/deepgnn/tf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """DeepGNN modules to work with Tensorflow.""" 4 | -------------------------------------------------------------------------------- /src/python/deepgnn/tf/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # flake8: noqa 5 | from deepgnn.tf.common.hooks import * 6 | from deepgnn.tf.common.dist_sync import * 7 | from deepgnn.tf.common.args import * 8 | from deepgnn.tf.common.utils import * 9 | -------------------------------------------------------------------------------- /src/python/deepgnn/tf/common/dist_sync.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """Synchronization for distributed training.""" 4 | import tensorflow as tf 5 | import os 6 | import datetime 7 | import time 8 | import glob 9 | 10 | 11 | class DistributedSync: 12 | """Distributed synchronization with plain files.""" 13 | 14 | SYNC_FILE = "sync" 15 | 16 | def __init__(self, folder: str, task_index: int, num_tasks: int): 17 | """ 18 | Initialize folder. 19 | 20 | Task with index 0 must be called first to start from a fresh state. 21 | """ 22 | self.folder = folder 23 | self.task_index = task_index 24 | self.num_tasks = num_tasks 25 | if self.task_index == 0: 26 | self._cleanup_sync_files() 27 | 28 | def _cleanup_sync_files(self): 29 | retry = 100 30 | filelist = glob.glob(os.path.join(self.folder, "{}.*".format(self.SYNC_FILE))) 31 | for fname in filelist: 32 | while os.path.exists(fname) and retry > 0: 33 | try: 34 | tf.compat.v1.logging.info("remove {}".format(fname)) 35 | os.remove(fname) 36 | except FileNotFoundError as err: 37 | tf.compat.v1.logging.info( 38 | "Oops! Delete file ({0}). 
OSError: {1}".format(fname, err) 39 | ) 40 | time.sleep(60) 41 | retry -= 1 42 | 43 | def sync(self, tag: str): 44 | """Block until all workers are ready.""" 45 | with open( 46 | os.path.join( 47 | self.folder, "{}.{}.{}".format(self.SYNC_FILE, tag, self.task_index) 48 | ), 49 | "w", 50 | ) as w: 51 | w.write(str(datetime.datetime.now())) 52 | for i in range(self.num_tasks): 53 | while not os.path.exists( 54 | os.path.join(self.folder, "{}.{}.{}".format(self.SYNC_FILE, tag, i)) 55 | ): 56 | time.sleep(30) 57 | tf.compat.v1.logging.info("worker {}-{} is not ready...".format(i, tag)) 58 | tf.compat.v1.logging.info("all workers-{} are done.".format(tag)) 59 | -------------------------------------------------------------------------------- /src/python/deepgnn/tf/common/hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | """Hooks for distributed training.""" 4 | from typing import Optional, List 5 | import tensorflow as tf 6 | import os 7 | import sys 8 | from deepgnn.tf.common.dist_sync import DistributedSync 9 | 10 | 11 | class ChiefCheckpointSaverHook(tf.estimator.CheckpointSaverHook): 12 | """ 13 | Chief only hooks, chief session will call `dist_sync.sync()` to wait all workers session. 14 | 15 | Once all workers finished, chief session run `end(session)` to save the final checkpoint. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | dist_sync: DistributedSync, 21 | checkpoint_dir: str, 22 | save_secs: Optional[int] = None, 23 | save_steps: Optional[int] = None, 24 | saver: Optional[object] = None, 25 | checkpoint_basename: str = "model.ckpt", 26 | scaffold: Optional[object] = None, 27 | listeners: Optional[List[object]] = None, 28 | ): 29 | """Initialize hook.""" 30 | super().__init__( 31 | checkpoint_dir, 32 | save_secs, 33 | save_steps, 34 | saver, 35 | checkpoint_basename, 36 | scaffold, 37 | listeners, 38 | ) 39 | assert dist_sync.task_index == 0 40 | self.dist_sync = dist_sync 41 | 42 | def end(self, session): 43 | """End session.""" 44 | self.dist_sync.sync("session") 45 | super().end(session) 46 | 47 | # WORKAROUND: in windows we need to create this folder otherwise we will get this error: 48 | # Failed to create a NewWriteableFile 49 | def _save(self, session, step): 50 | if sys.platform == "win32": 51 | path = f"{self._save_path}-{step}_temp" 52 | os.makedirs(path, exist_ok=True) 53 | super()._save(session, step) 54 | 55 | 56 | class SessionExitHook(tf.estimator.SessionRunHook): 57 | """Synchronize training at exit.""" 58 | 59 | def __init__(self, dist_sync: DistributedSync): 60 | """Create lock.""" 61 | self.dist_sync = dist_sync 62 | 63 | def end(self, session): 64 | """Wait for all workers to finish.""" 65 | self.dist_sync.sync("session") 66 | -------------------------------------------------------------------------------- /src/python/deepgnn/tf/common/test_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
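A minimal sketch of how `DistributedSync` above acts as a file-based barrier (illustrative; in real usage each task runs this in its own process, and task 0 must construct the object first so stale sync files get removed):

```python
import tempfile
from deepgnn.tf.common.dist_sync import DistributedSync

folder = tempfile.mkdtemp()
# Task 0 of 1: construction wipes old sync.* files; sync() writes
# sync.session.0 and then blocks until sync.session.<i> exists for every task.
sync = DistributedSync(folder, task_index=0, num_tasks=1)
sync.sync("session")
```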
3 | 
4 | import glob
5 | import numpy as np
6 | import os
7 | import tensorflow as tf
8 | 
9 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
10 | from deepgnn import get_logger
11 | 
12 | 
13 | class TestHelper:
14 |     @staticmethod
15 |     def get_tf2_summary_value(summary_dir, metric_name):
16 | 
17 |         events_file_pattern = os.path.join(summary_dir, "events*")
18 |         events_files = sorted(glob.glob(events_file_pattern))
19 |         get_logger().info(f"event files: {events_files}")
20 |         events = EventAccumulator(summary_dir)
21 |         events.Reload()
22 | 
23 |         metric_values = []
24 |         for w, s, t in events.Tensors(metric_name):
25 |             val = tf.make_ndarray(t)
26 |             metric_values.append(val.item())  # np.asscalar was removed in NumPy 1.23
27 |         return metric_values
28 | 
--------------------------------------------------------------------------------
/src/python/deepgnn/tf/common/unittest/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | load("@rules_python//python:defs.bzl", "py_test")
5 | load("//config:variables.bzl", "PLATFORM_DEFINES")
6 | load("@rules_python//python:defs.bzl", "py_library")
7 | load("@pip_deps//:requirements.bzl", "requirement")
8 | 
9 | py_test(
10 |     name = "test_dist_sync",
11 |     srcs = ["test_dist_sync.py"],
12 |     imports = ["../../../../"],
13 |     main = "test_dist_sync.py",
14 |     python_version = "PY3",
15 |     srcs_version = "PY3",
16 |     deps = [
17 |         "//src/python/deepgnn/tf/common:deepgnn_tf_common",
18 |         requirement("numpy"),
19 |         requirement("pytest"),
20 |         requirement("fsspec"),
21 |         requirement("scikit-learn"),
22 |         requirement("tensorflow"),
23 |         requirement("networkx"),
24 |         requirement("opencensus"),
25 |         requirement("opencensus-context"),
26 |         requirement("opencensus-ext-azure"),
27 |         requirement("azure-datalake-store"),
28 |         requirement("tenacity"),
29 |     ],
30 |     tags = ["manual"],
31 |     deprecation = "This test is deprecated",
32 | )
--------------------------------------------------------------------------------
/src/python/deepgnn/tf/encoders/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | load("@rules_python//python:defs.bzl", "py_library")
5 | 
6 | py_library(
7 |     name = "deepgnn_tf_encoders",
8 |     srcs = [
9 |         "__init__.py",
10 |         "att_encoder.py",
11 |         "han_encoder.py",
12 |     ],
13 |     visibility = ["//visibility:public"],
14 |     deprecation = "This target is deprecated",
15 | )
--------------------------------------------------------------------------------
/src/python/deepgnn/tf/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | # flake8: noqa
5 | from .att_encoder import AttEncoder
6 | from .han_encoder import HANEncoder
--------------------------------------------------------------------------------
/src/python/deepgnn/tf/layers/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | load("@rules_python//python:defs.bzl", "py_library")
5 | 
6 | py_library(
7 |     name = "deepgnn_tf_layers",
8 |     srcs = [
9 |         "__init__.py",
10 |         "attention_header.py",
11 |         "base.py",
12 |     ],
13 |     visibility = ["//visibility:public"],
14 |     deprecation = "This target is deprecated",
15 | )
--------------------------------------------------------------------------------
/src/python/deepgnn/tf/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Alibaba Group Holding Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | # flake8: noqa
17 | from deepgnn.tf.layers.base import Layer, Dense, Embedding, SparseEmbedding
18 | from .attention_header import AttentionHeader
--------------------------------------------------------------------------------
/src/python/deepgnn/tf/layers/attention_header.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | """GAT model layer."""
4 | from .base import Layer
5 | import tensorflow as tf
6 | 
7 | 
8 | class AttentionHeader(Layer):
9 |     """Attention header for GAT, reference: https://github.com/PetarV-/GAT/blob/master/utils/layers.py."""
10 | 
11 |     def __init__(self, out_size, act=None):
12 |         """Initialize layer."""
13 |         super().__init__()
14 |         self.out_dim = out_size
15 |         self.act = act
16 | 
17 |     def build(self, input_shape):
18 |         """Create convolution layers."""
19 |         self.conv0 = tf.keras.layers.Conv1D(self.out_dim, 1, use_bias=False)
20 |         self.conv1 = tf.keras.layers.Conv1D(1, 1)
21 |         self.conv2 = tf.keras.layers.Conv1D(1, 1)
22 |         self.bias = tf.compat.v1.get_variable(
23 |             "att.bias",
24 |             shape=[self.out_dim],
25 |             initializer=tf.compat.v1.zeros_initializer(),
26 |         )
27 |         self.built = True
28 | 
29 |     def call(self, seq):
30 |         """Compute embeddings."""
31 |         seq_fts = self.conv0(seq)
32 |         f_1 = self.conv1(seq_fts)
33 |         f_2 = self.conv2(seq_fts)
34 |         logits = f_1 + tf.transpose(a=f_2, perm=[0, 2, 1])
35 |         coefs = tf.nn.softmax(tf.nn.leaky_relu(logits))
36 |         vals = tf.matmul(coefs, seq_fts)
37 |         tf.compat.v1.logging.info(
38 |             "AttentionHeader Tensor Shape: seq_fts {0}, f_1 {1}, f_2 {2}, logits {3}, coefs {4}, vals {5}".format(
39 |                 seq_fts.shape,
40 |                 f_1.shape,
41 |                 f_2.shape,
42 |                 logits.shape,
43 |                 coefs.shape,
44 |                 vals.shape,
45 |             )
46 |         )
47 |         ret = tf.nn.bias_add(vals, self.bias)
48 |         if self.act is None:
49 |             return ret
50 |         else:
51 |             return self.act(ret)
--------------------------------------------------------------------------------
/src/python/deepgnn/tf/nn/BUILD:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | load("@rules_python//python:defs.bzl", "py_test", "py_library")
5 | load("@pip_deps//:requirements.bzl", "requirement")
6 | 
7 | py_library(
8 |     name = "deepgnn_tf_nn",
9 |     srcs = [
10 |         "__init__.py",
11 |         "gat_conv.py",
12 |         "gcn_conv.py",
13 |         "metrics.py",
14 |         "sage_conv.py",
15 |     ],
16 |     visibility = ["//visibility:public"],
17 |     deprecation = "This target is deprecated",
18 | )
19 | 
20 | py_test(
21 |     name = "test_conv",
22 |     srcs = ["test_conv.py"],
23 |     imports = ["../../../"],
24 |     main = "test_conv.py",
25 |     python_version = "PY3",
26 |     srcs_version = "PY3",
27 |     deps = [
28 |         ":deepgnn_tf_nn",
29 |         requirement("numpy"),
30 |         requirement("fsspec"),
31 |         requirement("pytest"),
32 |         requirement("scikit-learn"),
33 |         requirement("tensorflow"),
34 |         requirement("networkx"),
35 |         requirement("opencensus"),
36 |         requirement("opencensus-context"),
37 |         requirement("opencensus-ext-azure"),
38 |         requirement("azure-datalake-store"),
39 |     ],
40 |     tags = ["manual"],
41 |     deprecation = "This test is deprecated",
42 | )
--------------------------------------------------------------------------------
/src/python/deepgnn/tf/nn/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | """NN modules for TF models."""
--------------------------------------------------------------------------------
/src/python/deepgnn/tf/nn/gcn_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | """Convolution functions for GCN model."""
4 | import tensorflow as tf
5 | from typing import Callable
6 | 
7 | 
8 | def gcn_norm_adj(adj: tf.Tensor) -> tf.Tensor:
9 |     """Symmetrically normalize adjacency matrix."""
10 |     if isinstance(adj, tf.SparseTensor):
11 |         edges = adj.indices
12 |         edges_value = adj.values
13 |         adj_shape = adj.dense_shape
14 |         degree = tf.scatter_nd(
15 |             tf.reshape(edges[:, 0], (-1, 1)), edges_value, [adj_shape[0]]
16 |         )
17 |         d_inv = tf.math.pow(degree, -0.5)
18 |         d_inv = tf.where(tf.math.is_inf(d_inv), tf.zeros_like(d_inv), d_inv)
19 |         d_inv_src = tf.nn.embedding_lookup(d_inv, edges[:, 0])
20 |         d_inv_dst = tf.nn.embedding_lookup(d_inv, edges[:, 1])
21 |         edge_weight = d_inv_src * d_inv_dst
22 |         sp_adj = tf.sparse.SparseTensor(edges, edge_weight, dense_shape=adj_shape)
23 |         return sp_adj
24 |     else:
25 |         rowsum = tf.reduce_sum(adj, axis=1)
26 |         d_inv_sqrt = tf.math.pow(rowsum, -0.5)
27 |         d_inv_sqrt = tf.where(
28 |             tf.math.is_inf(d_inv_sqrt), tf.zeros_like(d_inv_sqrt), d_inv_sqrt
29 |         )
30 |         d_mat_inv_sqrt = tf.linalg.diag(d_inv_sqrt)
31 |         return d_mat_inv_sqrt @ adj @ d_mat_inv_sqrt
32 | 
33 | 
34 | class GCNConv(tf.keras.layers.Layer):
35 |     """Graph Conv Layer."""
36 | 
37 |     def __init__(
38 |         self,
39 |         out_dim: int,
40 |         dropout: float = 0.0,
41 |         act: Callable = tf.nn.relu,
42 |         use_bias: bool = False,
43 |     ):
44 |         """Initialize convolution layer."""
45 |         # TODO: support sparse input
46 |         super().__init__()
47 | 
48 |         self.out_dim = out_dim
49 |         self.act = act
50 |         self.dropout = dropout
51 |         self.use_bias = use_bias
52 | 
53 |         self.w = tf.keras.layers.Dense(self.out_dim, use_bias=False, name="w")
54 |         self.bias = self.add_weight(
55 |             name="gcn.bias",
56 |             shape=[self.out_dim],
57 |             initializer="zeros",
58 |             dtype=tf.float32,
59 |             trainable=True,
60 |         )
61 | 
62 |     def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
63 |         """Compute embeddings."""
64 |         x, adj = inputs
65 |         x = tf.nn.dropout(x, rate=self.dropout)
66 |         x = self.w(x)
67 |         if isinstance(adj, tf.SparseTensor):
68 |             support = tf.sparse.sparse_dense_matmul(adj, x)
69 |         else:
70 |             support = tf.matmul(adj, x, a_is_sparse=True)
71 |         # output = tf.add_n(supports)  # skip this because len(support) == 1
72 |         output = support
73 | 
74 |         if self.use_bias:
75 |             output = tf.nn.bias_add(output, self.bias)  # [N, F']
76 |         return output
77 | 
--------------------------------------------------------------------------------
/src/python/deepgnn/tf/nn/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | """Various metrics implementations."""
4 | import tensorflow as tf
5 | from typing import Optional
6 | 
7 | 
8 | def masked_softmax_cross_entropy(
9 |     preds: tf.Tensor, labels: tf.Tensor, mask: Optional[tf.Tensor] = None
10 | ) -> tf.Tensor:
11 |     """Softmax cross-entropy loss with masking."""
12 |     if mask is not None:
13 |         preds = preds[mask]
14 |         labels = labels[mask]
15 |     loss = tf.nn.softmax_cross_entropy_with_logits(
16 |         logits=preds, labels=tf.stop_gradient(labels)
17 |     )
18 |     return tf.reduce_mean(loss)
19 | 
20 | 
21 | def masked_accuracy(
22 |     preds: tf.Tensor,
23 |     labels: tf.Tensor,
24 |     mask: Optional[tf.Tensor] = None,
25 |     dtype: tf.dtypes.DType = tf.float32,
26 | ) -> tf.Tensor:
27 |     """Accuracy with masking."""
28 |     if mask is not None:
29 |         preds = preds[mask]
30 |         labels = labels[mask]
31 |     correct_prediction = tf.equal(tf.argmax(preds, 1), tf.argmax(labels, 1))
32 |     accuracy_all = tf.cast(correct_prediction, dtype)
33 |     return tf.reduce_mean(accuracy_all)
--------------------------------------------------------------------------------
/src/python/deepgnn/tf/nn/test_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
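For reference, the dense branch of `gcn_norm_adj` above computes D^{-1/2} A D^{-1/2}, which the test below verifies against the sparse branch. A standalone NumPy check on a tiny graph:

```python
import numpy as np

adj = np.array([[0.0, 1.0, 1.0],
                [1.0, 0.0, 0.0],
                [1.0, 0.0, 0.0]], dtype=np.float32)
deg = adj.sum(axis=1)  # degrees: [2, 1, 1]
d_inv_sqrt = deg ** -0.5
norm = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]  # D^-1/2 A D^-1/2
print(norm[0, 1])  # 0.7071... == 1 / sqrt(2 * 1)
```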
3 | 
4 | import tensorflow as tf
5 | import numpy as np
6 | 
7 | from deepgnn.tf.nn.gcn_conv import gcn_norm_adj
8 | from deepgnn.tf.nn.gat_conv import AttnHead
9 | 
10 | 
11 | def set_seed(seed):
12 |     tf.random.set_seed(seed)
13 |     np.random.seed(seed)
14 | 
15 | 
16 | def test_attn_head():
17 |     N = 5
18 |     D = 3
19 | 
20 |     set_seed(123)
21 |     x = np.random.rand(N, D).astype(np.float32)  # [N, D]
22 |     adj = np.random.randint(2, size=N * N).astype(np.float32).reshape(N, N)
23 | 
24 |     def run_dense_version():
25 |         set_seed(123)
26 |         attn_layer = AttnHead(4, act=tf.nn.elu, in_drop=0.1, coef_drop=0.0)
27 |         out = attn_layer([x, adj])
28 |         return out
29 | 
30 |     def run_sparse_version():
31 |         set_seed(123)
32 |         attn_layer = AttnHead(4, act=tf.nn.elu, in_drop=0.1, coef_drop=0.0)
33 |         row, col = np.nonzero(adj)
34 |         edge = np.concatenate([row.reshape(-1, 1), col.reshape(-1, 1)], axis=1)
35 |         edge_value = np.ones(edge.shape[0], np.float32)
36 |         adj_shape = np.array([N, N], np.int64)
37 |         sp_adj = tf.SparseTensor(edge, edge_value, adj_shape)
38 |         attn_layer = AttnHead(4, act=tf.nn.elu, in_drop=0.1, coef_drop=0.0)
39 |         out = attn_layer([x, sp_adj])
40 |         return out
41 | 
42 |     # fmt: off
43 |     expected_out = np.array(
44 |         [[0.13966814, -0.43618077, 1.1495067, 0.35430294],
45 |          [0.10271902, -0.52295756, 0.5240147, 0.25755188],
46 |          [-0.0497849, -0.49666244, 0.68464315, 0.2244271],
47 |          [0.4595074, -0.5716846, 0.7725301, 0.],
48 |          [0.22125529, -0.539732, 0.6065793, 0.1719851]],
49 |         np.float32,
50 |     )
51 |     # fmt: on
52 |     dense_out = run_dense_version()
53 |     tf.debugging.assert_near(dense_out, expected_out, atol=0.0001)
54 | 
55 |     _ = run_sparse_version()
56 |     tf.debugging.assert_near(dense_out, expected_out, atol=0.0001)
57 | 
58 | 
59 | def test_gcn_norm():
60 |     N = 5
61 |     D = 3
62 |     set_seed(123)
63 |     _ = np.random.rand(N, D)  # [N, D]
64 |     adj = np.random.randint(2, size=N * N).astype(np.float32).reshape(N, N)
65 | 
66 |     def run_dense_adj(raw_adj):
67 |         adj1 = gcn_norm_adj(raw_adj)
68 |         return adj1
69 | 
70 |     def run_sparse_adj(raw_adj):
71 |         row, col = np.nonzero(raw_adj)
72 |         edge = np.concatenate([row.reshape(-1, 1), col.reshape(-1, 1)], axis=1)
73 |         edge_value = np.ones(edge.shape[0], np.float32)
74 |         adj_shape = np.array([N, N], np.int64)
75 |         sp_adj_raw = tf.sparse.SparseTensor(edge, edge_value, adj_shape)
76 |         sp_adj = gcn_norm_adj(sp_adj_raw)
77 |         return sp_adj
78 | 
79 |     norm_adj1 = run_dense_adj(adj)
80 |     assert isinstance(norm_adj1, tf.Tensor)
81 | 
82 |     norm_adj2_sp = run_sparse_adj(adj)
83 |     norm_adj2 = tf.sparse.to_dense(norm_adj2_sp)
84 |     assert isinstance(norm_adj2_sp, tf.SparseTensor)
85 |     tf.debugging.assert_equal(norm_adj1, norm_adj2)
86 | 
--------------------------------------------------------------------------------
/src/python/deepgnn/train_types.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | """Enums to define training."""
4 | from enum import Enum
5 | 
6 | 
7 | class TrainerType(Enum):
8 |     """Trainer types.
9 | 
10 |     DeepGNN currently supports the following trainer types:
11 |     * BASE: The most basic local trainer, used for simple experiments.
12 |     * PS: Parameter server based distributed trainer, used for tensorflow models only.
13 |     * HVD: Horovod based distributed trainer, supports both tensorflow and pytorch models.
14 |     * DDP: DistributedDataParallel based distributed trainer, used for pytorch models only.
15 | """ 16 | 17 | BASE = "base" 18 | PS = "ps" 19 | MULTINODE = "multinode" 20 | HVD = "hvd" 21 | DDP = "ddp" 22 | 23 | def __str__(self): 24 | """Convert enum to string.""" 25 | return self.value 26 | 27 | 28 | class TrainMode(Enum): 29 | """What to do with a model.""" 30 | 31 | TRAIN = "train" 32 | EVALUATE = "evaluate" 33 | INFERENCE = "inference" 34 | 35 | def __str__(self): 36 | """Convert enum to string.""" 37 | return self.value 38 | -------------------------------------------------------------------------------- /tools/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | -------------------------------------------------------------------------------- /tools/manylinux/README.md: -------------------------------------------------------------------------------- 1 | This folder contains scripts to install gcc11 compiler to generate manylinux14 compatible wheels taken from TF repo. 2 | To install it locally run: 3 | ```bash 4 | sudo ./install-gcc11.sh 5 | ``` 6 | 7 | Use `manylinux` configuration and `--force_pic` argument to build an almost self contained shared library with bazel: 8 | ```bash 9 | bazel build -c opt //src/cc/lib:wrapper --config=manylinux --force_pic 10 | ``` 11 | -------------------------------------------------------------------------------- /tools/manylinux/fixlinks.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | # 18 | # Re-direct all links in $1 that point to /lib... to point to $1/lib... instead. 19 | 20 | BASE="$1" 21 | find "${BASE}" -type l | \ 22 | while read l ; do 23 | if [[ "$(readlink "$l")" == /lib* ]]; then 24 | ORIG="$(readlink "$l")"; 25 | rm "$l"; 26 | ln -s "${BASE}${ORIG}" "$l" 27 | fi 28 | done 29 | -------------------------------------------------------------------------------- /tools/manylinux/rpm-patch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | # 17 | # Given an RPM spec file $1, apply its patches. 18 | 19 | SPEC="$1" 20 | grep '%patch' "${SPEC}" |while read cmd ; do 21 | N=$(echo "${cmd}" |sed 's,%patch\([0-9]\+\).*,\1,') 22 | file=$(grep "Patch$N:" "${SPEC}" |sed 's,.*: ,,') 23 | parg=$(echo "${cmd}" |sed 's,.*\(-p[0-9]\).*,\1,') 24 | if [[ ! "${file}" =~ doxygen && "${cmd}" != \#* ]]; then 25 | echo "patch ${parg} -s < ${file}" 26 | patch ${parg} -s < "${file}" 27 | fi 28 | done 29 | -------------------------------------------------------------------------------- /tools/toolchain/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | load(":cc_toolchain_config.bzl", "cc_toolchain_config") 5 | 6 | package(default_visibility = ["//visibility:public"]) 7 | 8 | cc_toolchain_config( 9 | name = "k8_toolchain_config", 10 | ) 11 | 12 | filegroup(name = "empty") 13 | 14 | cc_toolchain( 15 | name = "k8_toolchain", 16 | all_files = ":empty", 17 | compiler_files = ":empty", 18 | dwp_files = ":empty", 19 | linker_files = ":empty", 20 | objcopy_files = ":empty", 21 | strip_files = ":empty", 22 | supports_param_files = 0, 23 | toolchain_config = ":k8_toolchain_config", 24 | toolchain_identifier = "k8-toolchain", 25 | ) 26 | 27 | cc_toolchain_suite( 28 | name = "manylinux", 29 | toolchains = { 30 | "k8": ":k8_toolchain", 31 | }, 32 | ) 33 | --------------------------------------------------------------------------------