├── tensorflow_gnn ├── data │ ├── __init__.py │ └── BUILD ├── graph │ ├── __init__.py │ ├── dict_utils.py │ ├── dict_utils_test.py │ └── graph_tensor_pprint_test.py ├── tools │ ├── __init__.py │ ├── BUILD │ ├── generate_training_data_test.py │ └── validate_graph_schema.py ├── converters │ ├── __init__.py │ ├── ogb │ │ └── __init__.py │ └── triples_test.py ├── utils │ ├── __init__.py │ └── BUILD ├── api_def │ ├── gcn-symbols.txt │ ├── hgt-symbols.txt │ ├── mt_albis-symbols.txt │ ├── vanilla_mpnn-symbols.txt │ ├── graph_sage-symbols.txt │ ├── gat_v2-symbols.txt │ ├── multi_head_attention-symbols.txt │ ├── contrastive_losses-symbols.txt │ ├── sampler-symbols.txt │ └── runner-symbols.txt ├── docs │ ├── guide │ │ └── images │ │ │ ├── homogeneous.png │ │ │ └── heterogeneous.png │ └── api_docs │ │ ├── python │ │ ├── models │ │ │ ├── gcn │ │ │ │ └── all_symbols.md │ │ │ ├── mt_albis │ │ │ │ ├── all_symbols.md │ │ │ │ └── graph_update_get_config_dict.md │ │ │ ├── vanilla_mpnn │ │ │ │ ├── all_symbols.md │ │ │ │ └── graph_update_get_config_dict.md │ │ │ ├── gat_v2 │ │ │ │ ├── graph_update_get_config_dict.md │ │ │ │ └── all_symbols.md │ │ │ ├── multi_head_attention │ │ │ │ ├── graph_update_get_config_dict.md │ │ │ │ └── all_symbols.md │ │ │ ├── contrastive_losses │ │ │ │ ├── DeepGraphInfomaxLogits.md │ │ │ │ ├── TripletEmbeddingSquaredDistances.md │ │ │ │ ├── DropoutFeatures.md │ │ │ │ ├── Corruptor.md │ │ │ │ ├── ShuffleFeaturesGlobally.md │ │ │ │ └── coherence.md │ │ │ ├── graph_sage │ │ │ │ └── all_symbols.md │ │ │ ├── gcn.md │ │ │ ├── vanilla_mpnn.md │ │ │ ├── mt_albis.md │ │ │ ├── graph_sage.md │ │ │ ├── gat_v2.md │ │ │ └── multi_head_attention.md │ │ ├── runner │ │ │ ├── Loss.md │ │ │ ├── Predictions.md │ │ │ ├── Losses.md │ │ │ ├── ContextLabelFn.md │ │ │ ├── Metrics.md │ │ │ ├── one_node_per_component.md │ │ │ ├── RootNodeLabelFn.md │ │ │ ├── GraphTensorProcessorFn.md │ │ │ ├── DatasetProvider.md │ │ │ ├── GraphTensorPadding.md │ │ │ ├── ModelExporter.md │ │ │ ├── 
PassthruDatasetProvider.md │ │ │ ├── incrementing_model_dir.md │ │ │ ├── TightPadding.md │ │ │ ├── PassthruSampleDatasetsProvider.md │ │ │ ├── FitOrSkipPadding.md │ │ │ └── TFDataServiceConfig.md │ │ └── tfgnn │ │ │ ├── Field.md │ │ │ ├── FieldSpec.md │ │ │ ├── Fields.md │ │ │ ├── IncidentNodeOrContextTag.md │ │ │ ├── FieldsSpec.md │ │ │ ├── FieldOrFields.md │ │ │ ├── reverse_tag.md │ │ │ ├── enable_graph_tensor_validation.md │ │ │ ├── is_graph_tensor.md │ │ │ ├── enable_graph_tensor_validation_at_runtime.md │ │ │ ├── get_registered_reduce_operation_names.md │ │ │ ├── is_dense_tensor.md │ │ │ ├── disable_graph_tensor_validation_at_runtime.md │ │ │ ├── is_ragged_tensor.md │ │ │ ├── check_scalar_graph_tensor.md │ │ │ ├── ValidationError.md │ │ │ ├── keras.md │ │ │ ├── check_homogeneous_graph_tensor.md │ │ │ ├── keras │ │ │ └── layers │ │ │ │ ├── AddSelfLoops.md │ │ │ │ ├── ParseExample.md │ │ │ │ ├── ParseSingleExample.md │ │ │ │ ├── PadToTotalSizes.md │ │ │ │ ├── SingleInputNextState.md │ │ │ │ └── NextStateFromConcat.md │ │ │ ├── disable_graph_tensor_validation.md │ │ │ ├── proto │ │ │ ├── Metadata │ │ │ │ └── KeyValue.md │ │ │ ├── BigQuery │ │ │ │ └── TableSpec.md │ │ │ ├── OriginInfo.md │ │ │ ├── Context.md │ │ │ ├── NodeSet.md │ │ │ ├── Metadata.md │ │ │ ├── Feature.md │ │ │ ├── GraphSchema.md │ │ │ └── EdgeSet.md │ │ │ ├── experimental.md │ │ │ ├── softmax_edges_per_node.md │ │ │ ├── sampler │ │ │ ├── SamplingSpec.md │ │ │ └── SamplingOp.md │ │ │ ├── sampler.md │ │ │ ├── write_schema.md │ │ │ ├── FeatureDefaultValues.md │ │ │ ├── parse_schema.md │ │ │ ├── read_schema.md │ │ │ ├── SizeConstraints.md │ │ │ ├── graph_tensor_to_values.md │ │ │ ├── assert_constraints.md │ │ │ ├── iter_sets.md │ │ │ ├── add_self_loops.md │ │ │ ├── combine_values.md │ │ │ ├── satisfies_size_constraints.md │ │ │ └── iter_features.md │ │ └── README.md ├── sampler │ ├── __init__.py │ └── BUILD ├── experimental │ ├── sampler │ │ ├── BUILD │ │ └── proto │ │ │ ├── BUILD │ │ │ └── 
__init__.py │ ├── BUILD │ └── __init__.py ├── models │ ├── hgt │ │ ├── README.md │ │ ├── hparams_vizier_test.py │ │ └── __init__.py │ ├── gat_v2 │ │ ├── README.md │ │ ├── hparams_vizier_test.py │ │ └── __init__.py │ ├── graph_sage │ │ ├── README.md │ │ ├── BUILD │ │ └── __init__.py │ ├── gcn │ │ ├── README.md │ │ ├── BUILD │ │ └── __init__.py │ ├── contrastive_losses │ │ └── README.md │ ├── vanilla_mpnn │ │ ├── hparams_vizier_test.py │ │ └── __init__.py │ ├── multi_head_attention │ │ ├── README.md │ │ ├── hparams_vizier_test.py │ │ └── __init__.py │ ├── mt_albis │ │ └── __init__.py │ └── README.md ├── runner │ ├── trainers │ │ └── BUILD │ ├── utils │ │ ├── saved_model_test.sh │ │ └── model_dir.py │ └── input │ │ └── BUILD ├── proto │ ├── examples.proto │ └── BUILD └── keras │ └── __init__.py ├── .gitignore ├── testdata ├── heterogeneous │ ├── invalid_customer.csv │ ├── one_customer.csv │ ├── two_customers.csv │ ├── BUILD │ ├── owns_card.csv │ ├── transactions.csv │ ├── paid_with.csv │ ├── creditcard.csv │ └── customer.csv ├── homogeneous │ ├── one_seed.csv │ ├── two_seeds.csv │ ├── tastelike.csv │ ├── fruits.csv │ ├── BUILD │ ├── citrus.pbtxt │ ├── tastelike.recordio.ascii │ └── tastelike.sstable.ascii ├── node_vs_edge │ ├── node_set_one.csv │ ├── node_set_two.csv │ ├── edge_set_two_to_two.csv │ ├── edge_set_one_to_two.csv │ ├── spec.pbtxt │ ├── BUILD │ └── schema.pbtxt ├── README.md ├── BUILD └── feature_repr.pbtxt ├── requirements-dev.txt ├── MANIFEST.in ├── AUTHORS ├── kokoro └── github │ └── ubuntu │ └── cpu │ ├── continuous.cfg │ ├── presubmit.cfg │ ├── newest_stable │ ├── presubmit.cfg │ └── continuous.cfg │ └── oldest │ ├── continuous.cfg │ └── presubmit.cfg ├── examples ├── sampler │ ├── creditcard │ │ ├── sampling_spec.pbtxt │ │ ├── owns_card.csv │ │ ├── graph_schema.pbtxt │ │ ├── creditcard.csv │ │ └── customer.csv │ └── mag │ │ └── sampling_spec.pbtxt └── schemas │ ├── graph_nets.pbtxt │ ├── latent.pbtxt │ └── mpnn.pbtxt ├── package ├── BUILD ├── 
tfdep.bzl └── move_generated_files.sh └── Dockerfile /tensorflow_gnn/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_gnn/graph/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_gnn/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_gnn/converters/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorflow_gnn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tensorflow_gnn/converters/ogb/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bazel-* 2 | build 3 | dist 4 | *.egg-info 5 | -------------------------------------------------------------------------------- /testdata/heterogeneous/invalid_customer.csv: -------------------------------------------------------------------------------- 1 | #id 2 | 1 3 | -------------------------------------------------------------------------------- /testdata/homogeneous/one_seed.csv: -------------------------------------------------------------------------------- 1 | #id 2 | amanatsu 3 | -------------------------------------------------------------------------------- /testdata/heterogeneous/one_customer.csv: 
-------------------------------------------------------------------------------- 1 | #id 2 | 1876448 3 | -------------------------------------------------------------------------------- /testdata/node_vs_edge/node_set_one.csv: -------------------------------------------------------------------------------- 1 | #id,name 2 | a,A 3 | -------------------------------------------------------------------------------- /testdata/homogeneous/two_seeds.csv: -------------------------------------------------------------------------------- 1 | #id 2 | amanatsu 3 | daidai 4 | -------------------------------------------------------------------------------- /testdata/node_vs_edge/node_set_two.csv: -------------------------------------------------------------------------------- 1 | #id,name 2 | b,B 3 | c,C 4 | -------------------------------------------------------------------------------- /testdata/heterogeneous/two_customers.csv: -------------------------------------------------------------------------------- 1 | #id 2 | 1876448 3 | 1372437 4 | -------------------------------------------------------------------------------- /tensorflow_gnn/api_def/gcn-symbols.txt: -------------------------------------------------------------------------------- 1 | gcn.GCNConv 2 | gcn.GCNHomGraphUpdate 3 | -------------------------------------------------------------------------------- /testdata/node_vs_edge/edge_set_two_to_two.csv: -------------------------------------------------------------------------------- 1 | #source,#target,weight 2 | b,c,1.0 3 | -------------------------------------------------------------------------------- /testdata/node_vs_edge/edge_set_one_to_two.csv: -------------------------------------------------------------------------------- 1 | #source,#target,weight 2 | a,b,1.0 3 | a,c,1.0 4 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 
| mock 2 | wheel 3 | # includes glibcxx older than 3.4.29 4 | ai-edge-litert-nightly -------------------------------------------------------------------------------- /tensorflow_gnn/api_def/hgt-symbols.txt: -------------------------------------------------------------------------------- 1 | hgt.HGTGraphUpdate 2 | hgt.graph_update_from_config_dict 3 | hgt.graph_update_get_config_dict 4 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/guide/images/homogeneous.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/gnn/HEAD/tensorflow_gnn/docs/guide/images/homogeneous.png -------------------------------------------------------------------------------- /tensorflow_gnn/docs/guide/images/heterogeneous.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/gnn/HEAD/tensorflow_gnn/docs/guide/images/heterogeneous.png -------------------------------------------------------------------------------- /tensorflow_gnn/api_def/mt_albis-symbols.txt: -------------------------------------------------------------------------------- 1 | mt_albis.MtAlbisGraphUpdate 2 | mt_albis.graph_update_from_config_dict 3 | mt_albis.graph_update_get_config_dict 4 | -------------------------------------------------------------------------------- /tensorflow_gnn/api_def/vanilla_mpnn-symbols.txt: -------------------------------------------------------------------------------- 1 | vanilla_mpnn.VanillaMPNNGraphUpdate 2 | vanilla_mpnn.graph_update_from_config_dict 3 | vanilla_mpnn.graph_update_get_config_dict 4 | -------------------------------------------------------------------------------- /testdata/homogeneous/tastelike.csv: -------------------------------------------------------------------------------- 1 | #source,#target,weight 2 | amanatsu,daidai,0.1 3 | amanatsu,lumia,0.2 4 | kiyomi,komikan,0.3 
5 | mandora,komikan,0.4 6 | mandora,tangelo,0.5 7 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include ./WORKSPACE 2 | include ./package/move_generated_files.sh 3 | include ./package/tfdep.bzl 4 | include ./tensorflow_gnn/tensorflow_gnn.bzl 5 | global-include *.proto 6 | global-include BUILD -------------------------------------------------------------------------------- /testdata/homogeneous/fruits.csv: -------------------------------------------------------------------------------- 1 | #id,name 2 | amanatsu,Amanatsu 3 | daidai,Daidai 4 | hassaku,Hassaku 5 | kiyomi,Kiyomi 6 | komikan,Komikan 7 | lumia,Lumia 8 | mandora,Mandora 9 | reikou,Reikou 10 | tangelo,Tangelo 11 | -------------------------------------------------------------------------------- /tensorflow_gnn/api_def/graph_sage-symbols.txt: -------------------------------------------------------------------------------- 1 | graph_sage.GCNGraphSAGENodeSetUpdate 2 | graph_sage.GraphSAGEAggregatorConv 3 | graph_sage.GraphSAGEGraphUpdate 4 | graph_sage.GraphSAGENextState 5 | graph_sage.GraphSAGEPoolingConv 6 | -------------------------------------------------------------------------------- /testdata/README.md: -------------------------------------------------------------------------------- 1 | # Test Data 2 | 3 | ## Synthetic Data 4 | 5 | All test data contained in the `/heterogeneous` and `/homogeneous` directories 6 | are synthetic datasets, and do not contain private or confidential information. 
7 | 8 | -------------------------------------------------------------------------------- /tensorflow_gnn/api_def/gat_v2-symbols.txt: -------------------------------------------------------------------------------- 1 | gat_v2.GATv2Conv 2 | gat_v2.GATv2EdgePool 3 | gat_v2.GATv2GraphUpdate 4 | gat_v2.GATv2HomGraphUpdate 5 | gat_v2.GATv2MPNNGraphUpdate 6 | gat_v2.graph_update_from_config_dict 7 | gat_v2.graph_update_get_config_dict 8 | -------------------------------------------------------------------------------- /testdata/node_vs_edge/spec.pbtxt: -------------------------------------------------------------------------------- 1 | seed_op { 2 | op_name: "seed" 3 | node_set_name: "node_set_one" 4 | } 5 | sampling_ops { 6 | op_name: "hop1" 7 | input_op_names: ["seed"] 8 | strategy: TOP_K 9 | sample_size: 2 10 | edge_set_name: "one_to_two" 11 | } 12 | -------------------------------------------------------------------------------- /testdata/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | 3 | package( 4 | default_applicable_licenses = ["//tensorflow_gnn:license"], 5 | default_visibility = ["//visibility:public"], 6 | ) 7 | 8 | filegroup( 9 | name = "feature_repr", 10 | srcs = [ 11 | "feature_repr.pbtxt", 12 | ], 13 | ) 14 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of TensorFlow GNN authors for copyright purposes. 2 | # 3 | # This does not necessarily list everyone who has contributed code, since in 4 | # some cases, their employer may be the copyright holder. To see the full list 5 | # of contributors, see the revision history in source control. 6 | 7 | Google Inc. 
8 | -------------------------------------------------------------------------------- /kokoro/github/ubuntu/cpu/continuous.cfg: -------------------------------------------------------------------------------- 1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build.sh" 2 | 3 | env_vars: { 4 | key: "USE_BAZEL_VERSION" 5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works. 6 | } 7 | 8 | action { 9 | define_artifacts { 10 | regex: "**/sponge_log.log" 11 | regex: "**/sponge_log.xml" 12 | } 13 | } -------------------------------------------------------------------------------- /kokoro/github/ubuntu/cpu/presubmit.cfg: -------------------------------------------------------------------------------- 1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build.sh" 2 | 3 | env_vars: { 4 | key: "USE_BAZEL_VERSION" 5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works. 6 | } 7 | 8 | action { 9 | define_artifacts { 10 | regex: "**/sponge_log.log" 11 | regex: "**/sponge_log.xml" 12 | } 13 | } -------------------------------------------------------------------------------- /tensorflow_gnn/api_def/multi_head_attention-symbols.txt: -------------------------------------------------------------------------------- 1 | multi_head_attention.MultiHeadAttentionConv 2 | multi_head_attention.MultiHeadAttentionEdgePool 3 | multi_head_attention.MultiHeadAttentionHomGraphUpdate 4 | multi_head_attention.MultiHeadAttentionMPNNGraphUpdate 5 | multi_head_attention.graph_update_from_config_dict 6 | multi_head_attention.graph_update_get_config_dict 7 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/gcn/all_symbols.md: -------------------------------------------------------------------------------- 1 | # All symbols in TensorFlow GNN Models: fGCN 2 | 3 | 4 | 5 | ## Primary symbols 6 | 7 | * gcn 8 | * gcn.GCNConv 9 | * gcn.GCNHomGraphUpdate 10 | 
-------------------------------------------------------------------------------- /testdata/homogeneous/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | filegroup( 6 | name = "homogeneous", 7 | srcs = [ 8 | ":citrus.pbtxt", 9 | ":fruits.csv", 10 | ":one_seed.csv", 11 | ":sampler_golden.ascii", 12 | ":tastelike.csv", 13 | ":two_seeds.csv", 14 | ], 15 | ) 16 | -------------------------------------------------------------------------------- /examples/sampler/creditcard/sampling_spec.pbtxt: -------------------------------------------------------------------------------- 1 | # proto-file: tensorflow_gnn/sampler/sampling_spec.proto 2 | # proto-message: PipelineSpec 3 | 4 | seed_op < 5 | op_name: "seed" 6 | node_set_name: "customer" 7 | > 8 | sampling_ops < 9 | op_name: "seed->creditcard" 10 | input_op_names: "seed" 11 | edge_set_name: "owns_card" 12 | sample_size: 3 13 | strategy: RANDOM_UNIFORM 14 | > -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/Loss.md: -------------------------------------------------------------------------------- 1 | # runner.Loss 2 | 3 | 4 | 5 | This symbol is a **type alias**. 6 | 7 | #### Source: 8 | 9 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/Field.md: -------------------------------------------------------------------------------- 1 | 2 | # tfgnn.Field 3 | 4 | 5 | This symbol is a **type alias**. 
6 | 7 | 8 | 9 | #### Source: 10 | 11 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/FieldSpec.md: -------------------------------------------------------------------------------- 1 | 2 | # tfgnn.FieldSpec 3 | 4 | 5 | This symbol is a **type alias**. 6 | 7 | 8 | 9 | #### Source: 10 | 11 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/Fields.md: -------------------------------------------------------------------------------- 1 | 2 | # tfgnn.Fields 3 | 4 | 5 | This symbol is a **type alias**. 6 | 7 | 8 | 9 | #### Source: 10 | 11 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/IncidentNodeOrContextTag.md: -------------------------------------------------------------------------------- 1 | 2 | # tfgnn.IncidentNodeOrContextTag 3 | 4 | 5 | This symbol is a **type alias**. 
6 | 7 | 8 | 9 | #### Source: 10 | 11 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /testdata/node_vs_edge/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | 3 | package( 4 | default_applicable_licenses = ["//tensorflow_gnn:license"], 5 | default_visibility = ["//visibility:public"], 6 | ) 7 | 8 | filegroup( 9 | name = "node_vs_edge", 10 | srcs = [ 11 | "edge_set_one_to_two.csv", 12 | "edge_set_two_to_two.csv", 13 | ":node_set_one.csv", 14 | ":node_set_two.csv", 15 | ], 16 | data = glob(["*.pbtxt"]), 17 | ) 18 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/Predictions.md: -------------------------------------------------------------------------------- 1 | # runner.Predictions 2 | 3 | 4 | 5 | This symbol is a **type alias**. 6 | 7 | #### Source: 8 | 9 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/FieldsSpec.md: -------------------------------------------------------------------------------- 1 | 2 | # tfgnn.FieldsSpec 3 | 4 | 5 | This symbol is a **type alias**. 6 | 7 | 8 | 9 | #### Source: 10 | 11 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/Losses.md: -------------------------------------------------------------------------------- 1 | # runner.Losses 2 | 3 | 4 | 5 | This symbol is a **type alias**. 
6 | 7 | #### Source: 8 | 9 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /package/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | sh_binary( 4 | name = "move_generated_files", 5 | srcs = ["move_generated_files.sh"], 6 | data = [ 7 | "//tensorflow_gnn/experimental/sampler/proto:eval_dag_py_proto", 8 | "//tensorflow_gnn/proto:examples_py_proto", 9 | "//tensorflow_gnn/proto:graph_schema_py_proto", 10 | "//tensorflow_gnn/sampler:sampling_spec_py_proto", 11 | "//tensorflow_gnn/tools:sampled_stats_py_proto", 12 | ], 13 | ) 14 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/FieldOrFields.md: -------------------------------------------------------------------------------- 1 | 2 | # tfgnn.FieldOrFields 3 | 4 | 5 | This symbol is a **type alias**. 6 | 7 | 8 | 9 | #### Source: 10 | 11 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /testdata/heterogeneous/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | filegroup( 6 | name = "heterogeneous", 7 | srcs = [ 8 | ":creditcard.csv", 9 | ":customer.csv", 10 | ":graph.pbtxt", 11 | ":invalid_customer.csv", 12 | ":one_customer.csv", 13 | ":owns_card.csv", 14 | ":paid_with.csv", 15 | ":sampler_golden.ascii", 16 | ":transactions.csv", 17 | ":two_customers.csv", 18 | ], 19 | ) 20 | -------------------------------------------------------------------------------- /tensorflow_gnn/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """Public interface for GNN Sampler.""" 2 | from tensorflow_gnn.sampler import sampling_spec_builder 3 | from tensorflow_gnn.sampler import sampling_spec_pb2 
4 | 5 | SamplingOp = sampling_spec_pb2.SamplingOp 6 | SamplingSpec = sampling_spec_pb2.SamplingSpec 7 | SamplingSpecBuilder = sampling_spec_builder.SamplingSpecBuilder 8 | SamplingStrategy = sampling_spec_pb2.SamplingStrategy 9 | make_sampling_spec_tree = sampling_spec_builder.make_sampling_spec_tree 10 | 11 | del sampling_spec_pb2 12 | del sampling_spec_builder 13 | -------------------------------------------------------------------------------- /tensorflow_gnn/experimental/sampler/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | py_test( 6 | name = "core_test", 7 | srcs = ["core_test.py"], 8 | python_version = "PY3", 9 | deps = [], 10 | ) 11 | 12 | py_test( 13 | name = "ext_ops_test", 14 | srcs = ["ext_ops_test.py"], 15 | python_version = "PY3", 16 | deps = [], 17 | ) 18 | 19 | py_test( 20 | name = "eval_dag_test", 21 | srcs = ["eval_dag_test.py"], 22 | python_version = "PY3", 23 | deps = [], 24 | ) -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/mt_albis/all_symbols.md: -------------------------------------------------------------------------------- 1 | # All symbols in TensorFlow GNN Models: fMtAlbis 2 | 3 | 4 | 5 | ## Primary symbols 6 | 7 | * mt_albis 8 | * mt_albis.MtAlbisGraphUpdate 9 | * mt_albis.graph_update_from_config_dict 10 | * mt_albis.graph_update_get_config_dict 11 | -------------------------------------------------------------------------------- /kokoro/github/ubuntu/cpu/newest_stable/presubmit.cfg: -------------------------------------------------------------------------------- 1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build_versioned.sh" 2 | 3 | env_vars: { 4 | key: "USE_BAZEL_VERSION" 5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works. 
6 | } 7 | env_vars: { 8 | key: "PYTHON_VERSION" 9 | value: "3.11" 10 | } 11 | env_vars: { 12 | key: "TF_VERSION" 13 | value: "2.20.*" 14 | } 15 | env_vars: { 16 | key: "TF_USE_LEGACY_KERAS" 17 | value: "1" 18 | } 19 | 20 | action { 21 | define_artifacts { 22 | regex: "**/sponge_log.log" 23 | regex: "**/sponge_log.xml" 24 | } 25 | } -------------------------------------------------------------------------------- /kokoro/github/ubuntu/cpu/newest_stable/continuous.cfg: -------------------------------------------------------------------------------- 1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build_versioned.sh" 2 | 3 | env_vars: { 4 | key: "USE_BAZEL_VERSION" 5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works. 6 | } 7 | env_vars: { 8 | key: "PYTHON_VERSION" 9 | value: "3.11" 10 | } 11 | env_vars: { 12 | key: "TF_VERSION" 13 | value: "2.20.*" 14 | } 15 | env_vars: { 16 | key: "TF_USE_LEGACY_KERAS" 17 | value: "1" 18 | } 19 | 20 | action { 21 | define_artifacts { 22 | regex: "**/sponge_log.log" 23 | regex: "**/sponge_log.xml" 24 | } 25 | } -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/reverse_tag.md: -------------------------------------------------------------------------------- 1 | # tfgnn.reverse_tag 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Flips tfgnn.SOURCE to tfgnn.TARGET and vice versa. 
10 | 11 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/all_symbols.md: -------------------------------------------------------------------------------- 1 | # All symbols in TensorFlow GNN Models: fVanillaMPNN 2 | 3 | 4 | 5 | ## Primary symbols 6 | 7 | * vanilla_mpnn 8 | * vanilla_mpnn.VanillaMPNNGraphUpdate 9 | * vanilla_mpnn.graph_update_from_config_dict 10 | * vanilla_mpnn.graph_update_get_config_dict 11 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/enable_graph_tensor_validation.md: -------------------------------------------------------------------------------- 1 | # tfgnn.enable_graph_tensor_validation 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Enables static checks of graph tensors. 10 | 11 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/ContextLabelFn.md: -------------------------------------------------------------------------------- 1 | # runner.ContextLabelFn 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Reads out a `tfgnn.Field` from the `GraphTensor` context. 10 | 11 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/Metrics.md: -------------------------------------------------------------------------------- 1 | # runner.Metrics 2 | 3 | 4 | 5 | This symbol is a **type alias**. 
6 | 7 | #### Source: 8 | 9 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /kokoro/github/ubuntu/cpu/oldest/continuous.cfg: -------------------------------------------------------------------------------- 1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build_versioned.sh" 2 | 3 | env_vars: { 4 | key: "USE_BAZEL_VERSION" 5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works. 6 | } 7 | env_vars: { 8 | key: "PYTHON_VERSION" 9 | value: "3.9" 10 | } 11 | env_vars: { 12 | key: "TF_VERSION" 13 | value: "2.12.*" 14 | } 15 | env_vars: { 16 | key: "TF_USE_LEGACY_KERAS" 17 | value: "0" 18 | } 19 | env_vars: { 20 | key: "TAG_FILTERS" 21 | value: ",-tf_at_least_2_13" 22 | } 23 | 24 | action { 25 | define_artifacts { 26 | regex: "**/sponge_log.log" 27 | regex: "**/sponge_log.xml" 28 | } 29 | } -------------------------------------------------------------------------------- /kokoro/github/ubuntu/cpu/oldest/presubmit.cfg: -------------------------------------------------------------------------------- 1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build_versioned.sh" 2 | 3 | env_vars: { 4 | key: "USE_BAZEL_VERSION" 5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works. 
6 | } 7 | env_vars: { 8 | key: "PYTHON_VERSION" 9 | value: "3.9" 10 | } 11 | env_vars: { 12 | key: "TF_VERSION" 13 | value: "2.12.*" 14 | } 15 | env_vars: { 16 | key: "TF_USE_LEGACY_KERAS" 17 | value: "0" 18 | } 19 | env_vars: { 20 | key: "TAG_FILTERS" 21 | value: ",-tf_at_least_2_13" 22 | } 23 | 24 | action { 25 | define_artifacts { 26 | regex: "**/sponge_log.log" 27 | regex: "**/sponge_log.xml" 28 | } 29 | } -------------------------------------------------------------------------------- /tensorflow_gnn/models/hgt/README.md: -------------------------------------------------------------------------------- 1 | # Heterogeneous Graph Transformers 2 | 3 | ## Overview 4 | 5 | This code implements the method in Heterogeneous Graph Transformers, originally 6 | implemented by 7 | 8 | * Ziniu Hu, Yuxiao Dong, Kuansan Wang, Yizhou Sun: 9 | ["Heterogeneous Graph Transformer"](https://arxiv.org/abs/2003.01332), 2020. 10 | 11 | 12 | ## Usage 13 | 14 | TensorFlow programs can import and use this model as described in its 15 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/hgt.md). 16 | 17 | ## API stability 18 | 19 | The API of this model may change between OSS library versions. 20 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/is_graph_tensor.md: -------------------------------------------------------------------------------- 1 | # tfgnn.is_graph_tensor 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Returns whether `value` is a GraphTensor (possibly wrapped for Keras). 
10 | 11 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/enable_graph_tensor_validation_at_runtime.md: -------------------------------------------------------------------------------- 1 | # tfgnn.enable_graph_tensor_validation_at_runtime 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Enables both static and runtime checks of graph tensors. 10 | 11 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/get_registered_reduce_operation_names.md: -------------------------------------------------------------------------------- 1 | # tfgnn.get_registered_reduce_operation_names 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Returns the registered list of supported reduce operation names. 10 | 11 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/is_dense_tensor.md: -------------------------------------------------------------------------------- 1 | # tfgnn.is_dense_tensor 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Returns whether a tensor (TF or Keras) is a Tensor. 10 | 11 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/gat_v2/graph_update_get_config_dict.md: -------------------------------------------------------------------------------- 1 | # gat_v2.graph_update_get_config_dict 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Returns ConfigDict for graph_update_from_config_dict() with defaults. 
10 | 11 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/disable_graph_tensor_validation_at_runtime.md: -------------------------------------------------------------------------------- 1 | # tfgnn.disable_graph_tensor_validation_at_runtime 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Disables runtime checks (`tf.debugging.Assert`) of graph tensors. 10 | 11 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/is_ragged_tensor.md: -------------------------------------------------------------------------------- 1 | # tfgnn.is_ragged_tensor 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Returns whether a tensor (TF or Keras) is a RaggedTensor. 10 | 11 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/mt_albis/graph_update_get_config_dict.md: -------------------------------------------------------------------------------- 1 | # mt_albis.graph_update_get_config_dict 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Returns ConfigDict for graph_update_from_config_dict() with defaults. 10 | 11 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /tensorflow_gnn/experimental/BUILD: -------------------------------------------------------------------------------- 1 | # Copybara rewrites load() statements back and forth; do not reformat. 
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load 3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library") 4 | 5 | licenses(["notice"]) 6 | 7 | package( 8 | default_applicable_licenses = ["//tensorflow_gnn:license"], 9 | default_visibility = [ 10 | "//tensorflow_gnn:__subpackages__", 11 | ], 12 | ) 13 | 14 | pytype_strict_library( 15 | name = "experimental", 16 | srcs = ["__init__.py"], 17 | deps = [ 18 | "//tensorflow_gnn/graph:readout", 19 | "//tensorflow_gnn/graph:tensor_utils", 20 | ], 21 | ) 22 | -------------------------------------------------------------------------------- /testdata/homogeneous/citrus.pbtxt: -------------------------------------------------------------------------------- 1 | # TODO(blais): Test context features. 2 | 3 | node_sets { 4 | key: "fruits" 5 | value { 6 | metadata { 7 | filename: "fruits.csv" 8 | } 9 | features { 10 | key: "name" 11 | value: { 12 | description: "Fruit name" 13 | dtype: DT_STRING 14 | } 15 | } 16 | } 17 | } 18 | 19 | edge_sets { 20 | key: "tastelike" 21 | value { 22 | source: "fruits" 23 | target: "fruits" 24 | description: "Similar taste" 25 | metadata { 26 | filename: "tastelike.csv" 27 | } 28 | features { 29 | key: "weight" 30 | value: { 31 | dtype: DT_FLOAT 32 | } 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/one_node_per_component.md: -------------------------------------------------------------------------------- 1 | # runner.one_node_per_component 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Returns a `Mapping` `node_set_name: 1` for every node set in `gtspec`. 
10 | 11 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/graph_update_get_config_dict.md: -------------------------------------------------------------------------------- 1 | # vanilla_mpnn.graph_update_get_config_dict 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Returns ConfigDict for graph_update_from_config_dict() with defaults. 10 | 11 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/check_scalar_graph_tensor.md: -------------------------------------------------------------------------------- 1 | # tfgnn.check_scalar_graph_tensor 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Checks that graph tensor is scalar (has rank 0). 10 | 11 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tensorflow_gnn/runner/trainers/BUILD: -------------------------------------------------------------------------------- 1 | # Copybara rewrites load() statements back and forth; do not reformat. 
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load 3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library") 4 | 5 | licenses(["notice"]) 6 | 7 | package( 8 | default_applicable_licenses = ["//tensorflow_gnn:license"], 9 | default_visibility = ["//visibility:public"], 10 | ) 11 | 12 | pytype_strict_library( 13 | name = "keras_fit", 14 | srcs = ["keras_fit.py"], 15 | visibility = ["//tensorflow_gnn/runner:__pkg__"], 16 | deps = [ 17 | "//:expect_tensorflow_installed", 18 | "//tensorflow_gnn/runner:interfaces", 19 | ], 20 | ) 21 | -------------------------------------------------------------------------------- /tensorflow_gnn/api_def/contrastive_losses-symbols.txt: -------------------------------------------------------------------------------- 1 | contrastive_losses.AllSvdMetrics 2 | contrastive_losses.BarlowTwinsTask 3 | contrastive_losses.ContrastiveLossTask 4 | contrastive_losses.CorruptionSpec 5 | contrastive_losses.Corruptor 6 | contrastive_losses.DeepGraphInfomaxLogits 7 | contrastive_losses.DeepGraphInfomaxTask 8 | contrastive_losses.DropoutFeatures 9 | contrastive_losses.ShuffleFeaturesGlobally 10 | contrastive_losses.TripletEmbeddingSquaredDistances 11 | contrastive_losses.TripletLossTask 12 | contrastive_losses.VicRegTask 13 | contrastive_losses.coherence 14 | contrastive_losses.numerical_rank 15 | contrastive_losses.pseudo_condition_number 16 | contrastive_losses.rankme 17 | contrastive_losses.self_clustering 18 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/gat_v2/README.md: -------------------------------------------------------------------------------- 1 | # Graph Attention Networks v2 2 | 3 | ## Overview 4 | 5 | This code implements Graph Attention Networks v2, originally published by 6 | 7 | * Shaked Brody, Uri Alon, Eran Yahav: 8 | ["How Attentive are Graph Attention 9 | Networks?"](https://arxiv.org/abs/2105.14491), 2021. 
10 | 11 | TensorFlow programs can import and use it as described in its 12 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/gat_v2.md). 13 | 14 | ## API stability 15 | 16 | This model is covered by [semantic 17 | versioning](https://semver.org/spec/v2.0.0.html) of TensorFlow GNN's 18 | open-source releases: new minor versions do not break existing users. 19 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/graph_sage/README.md: -------------------------------------------------------------------------------- 1 | # GraphSAGE 2 | 3 | ## Overview 4 | 5 | This code implements the GraphSAGE model, originally published by 6 | 7 | * William L. Hamilton, Rex Ying, Jure Leskovec: 8 | ["Inductive Representation Learning 9 | on Large Graphs"](https://arxiv.org/abs/1706.02216), 2017. 10 | 11 | TensorFlow programs can import and use it as described in its 12 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/graph_sage.md). 13 | 14 | ## API stability 15 | 16 | This model is covered by [semantic 17 | versioning](https://semver.org/spec/v2.0.0.html) of TensorFlow GNN's 18 | open-source releases: new minor versions do not break existing users. 19 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/gcn/README.md: -------------------------------------------------------------------------------- 1 | # Graph Convolutional Network 2 | 3 | ## Overview 4 | 5 | This code implements Graph Convolutional Networks, originally published by 6 | 7 | * Thomas N. Kipf and Max Welling: 8 | ["Semi-Supervised Classification with Graph Convolutional 9 | Networks"](https://arxiv.org/abs/1609.02907), 2016. 10 | 11 | TensorFlow programs can import and use it as described in its 12 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/gcn.md). 
13 | 14 | ## API stability 15 | 16 | This model is covered by [semantic 17 | versioning](https://semver.org/spec/v2.0.0.html) of TensorFlow GNN's 18 | open-source releases: new minor versions do not break existing users. 19 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/ValidationError.md: -------------------------------------------------------------------------------- 1 | # tfgnn.ValidationError 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | A schema validation error. 10 | 11 | 16 | 17 | 18 | 19 | 20 | 21 | This exception is raised if in the course of validating the schema for 22 | correctness some errors are found. 23 | 24 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/graph_update_get_config_dict.md: -------------------------------------------------------------------------------- 1 | # multi_head_attention.graph_update_get_config_dict 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Returns ConfigDict for graph_update_from_config_dict() with defaults. 10 | 11 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/RootNodeLabelFn.md: -------------------------------------------------------------------------------- 1 | # runner.RootNodeLabelFn 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Reads out a `tfgnn.Field` from the `GraphTensor` root (i.e. first) node. 
10 | 11 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /examples/sampler/creditcard/owns_card.csv: -------------------------------------------------------------------------------- 1 | source,target 2 | 1876448,16827485386298040 3 | 1372437,11470379189154620 4 | 1368305,11163838768727470 5 | 1974494,16011471358128450 6 | 1257724,18569067217418250 7 | 1758057,17396883707513070 8 | 1531660,14844931107602160 9 | 1489311,1238474857489384 10 | 1407706,11290312140467510 11 | 196838,17861046738135650 12 | 1195675,8878522895102384 13 | 1659366,13019350102369400 14 | 1499004,11470379189154620 15 | 1344333,16283233487191600 16 | 1443888,9991040399813057 17 | 1108778,14912408563871390 18 | 175583,11290312140467510 19 | 1251872,12948957000457930 20 | 1493851,3549061668422198 21 | 1599418,9991040399813057 22 | 1768701,18362223127059380 23 | 1549489,1238474857489384 24 | 1879799,18569067217418250 25 | 125454,18526138896540830 -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DeepGraphInfomaxLogits.md: -------------------------------------------------------------------------------- 1 | # contrastive_losses.DeepGraphInfomaxLogits 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Computes clean and corrupted logits for Deep Graph Infomax (DGI). 10 | 11 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/keras.md: -------------------------------------------------------------------------------- 1 | # Module: tfgnn.keras 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | The tfgnn.keras package. 10 | 11 | 12 | 13 | ## Modules 14 | 15 | [`layers`](../tfgnn/keras/layers.md) module: The tfgnn.keras.layers package. 
16 | 17 | ## Classes 18 | 19 | [`class ConvGNNBuilder`](../tfgnn/keras/ConvGNNBuilder.md): Factory of layers that do convolutions on a graph. 20 | 21 | ## Functions 22 | 23 | [`clone_initializer(...)`](../tfgnn/keras/clone_initializer.md): Clones an 24 | initializer to ensure a new default seed. 25 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/graph_sage/all_symbols.md: -------------------------------------------------------------------------------- 1 | # All symbols in TensorFlow GNN Models: GraphSAGE 2 | 3 | 4 | 5 | ## Primary symbols 6 | 7 | * graph_sage 8 | * graph_sage.GCNGraphSAGENodeSetUpdate 9 | * graph_sage.GraphSAGEAggregatorConv 10 | * graph_sage.GraphSAGEGraphUpdate 11 | * graph_sage.GraphSAGENextState 12 | * graph_sage.GraphSAGEPoolingConv 13 | -------------------------------------------------------------------------------- /testdata/heterogeneous/owns_card.csv: -------------------------------------------------------------------------------- 1 | source,target 2 | 1876448,16827485386298040 3 | 1372437,11470379189154620 4 | 1368305,11163838768727470 5 | 1974494,16011471358128450 6 | 1257724,18569067217418250 7 | 1758057,17396883707513070 8 | 1531660,14844931107602160 9 | 1489311,1238474857489384 10 | 1407706,11290312140467510 11 | 196838,17861046738135650 12 | 1195675,8878522895102384 13 | 1659366,13019350102369400 14 | 1499004,11470379189154620 15 | 1344333,16283233487191600 16 | 1443888,9991040399813057 17 | 1108778,14912408563871390 18 | 175583,11290312140467510 19 | 1251872,12948957000457930 20 | 1493851,3549061668422198 21 | 1599418,9991040399813057 22 | 1768701,18362223127059380 23 | 1549489,1238474857489384 24 | 1879799,18569067217418250 25 | 125454,18526138896540830 -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/check_homogeneous_graph_tensor.md:
-------------------------------------------------------------------------------- 1 | # tfgnn.check_homogeneous_graph_tensor 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Raises ValueError when tfgnn.get_homogeneous_node_and_edge_set_name() does. 10 | 11 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/TripletEmbeddingSquaredDistances.md: -------------------------------------------------------------------------------- 1 | # contrastive_losses.TripletEmbeddingSquaredDistances 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Computes embeddings distance between positive and negative pairs. 10 | 11 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tensorflow_gnn/sampler/BUILD: -------------------------------------------------------------------------------- 1 | # Copybara rewrites some of these load() statements back and forth; do not reformat. 2 | # buildifier: disable=out-of-order-load, disable=same-origin-load 3 | load("@rules_proto//proto:defs.bzl", "proto_library") 4 | 5 | # buildifier: disable=out-of-order-load, disable=same-origin-load 6 | load("@rules_python//python:proto.bzl", "py_proto_library") 7 | 8 | package(default_visibility = ["//visibility:public"]) 9 | 10 | licenses(["notice"]) # Apache 2.0 11 | 12 | proto_library( 13 | name = "sampling_spec_proto", 14 | srcs = ["sampling_spec.proto"], 15 | deps = ["@org_tensorflow//tensorflow/core:protos_all"], 16 | ) 17 | 18 | py_proto_library( 19 | name = "sampling_spec_py_proto", 20 | deps = [ 21 | ":sampling_spec_proto", 22 | ], 23 | ) -------------------------------------------------------------------------------- /tensorflow_gnn/tools/BUILD: -------------------------------------------------------------------------------- 1 | # Copybara rewrites some of these load() statements back and forth; do not reformat. 
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load 3 | load("@rules_proto//proto:defs.bzl", "proto_library") 4 | 5 | # buildifier: disable=out-of-order-load, disable=same-origin-load 6 | load("@rules_python//python:proto.bzl", "py_proto_library") 7 | 8 | package(default_visibility = ["//visibility:public"]) 9 | 10 | licenses(["notice"]) # Apache 2.0 11 | 12 | proto_library( 13 | name = "sampled_stats_proto", 14 | srcs = ["sampled_stats.proto"], 15 | deps = ["@org_tensorflow//tensorflow/core:protos_all"], 16 | ) 17 | 18 | py_proto_library( 19 | name = "sampled_stats_py_proto", 20 | deps = [ 21 | ":sampled_stats_proto", 22 | ], 23 | ) 24 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/gat_v2/all_symbols.md: -------------------------------------------------------------------------------- 1 | # All symbols in TensorFlow GNN Models: GATv2 2 | 3 | 4 | 5 | ## Primary symbols 6 | 7 | * gat_v2 8 | * gat_v2.GATv2Conv 9 | * gat_v2.GATv2EdgePool 10 | * gat_v2.GATv2HomGraphUpdate 11 | * gat_v2.GATv2MPNNGraphUpdate 12 | * gat_v2.graph_update_from_config_dict 13 | * gat_v2.graph_update_get_config_dict 14 | -------------------------------------------------------------------------------- /tensorflow_gnn/utils/BUILD: -------------------------------------------------------------------------------- 1 | # Copybara rewrites load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load 3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library") 4 | 5 | licenses(["notice"]) 6 | 7 | package( 8 | default_applicable_licenses = ["//tensorflow_gnn:license"], 9 | default_visibility = ["//visibility:public"], 10 | ) 11 | 12 | pytype_strict_library( 13 | name = "test_utils", 14 | srcs = ["test_utils.py"], 15 | ) 16 | 17 | pytype_strict_library( 18 | name = "api_utils", 19 | srcs = ["api_utils.py"], 20 | ) 21 | 22 | pytype_strict_library( 23 | name = "tf_test_utils", 24 | srcs = ["tf_test_utils.py"], 25 | deps = [ 26 | "//:expect_tensorflow_installed", 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/gcn.md: -------------------------------------------------------------------------------- 1 | # Module: gcn 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Graph Convolutional Networks. 10 | 11 | Users of TF-GNN can use this model by importing it next to the core library as 12 | 13 | ```python 14 | import tensorflow_gnn as tfgnn 15 | from tensorflow_gnn.models import gcn 16 | ``` 17 | 18 | ## Classes 19 | 20 | [`class GCNConv`](./gcn/GCNConv.md): Implements the Graph Convolutional Network 21 | by Kipf&Welling (2016). 22 | 23 | ## Functions 24 | 25 | [`GCNHomGraphUpdate(...)`](./gcn/GCNHomGraphUpdate.md): Returns a graph update 26 | layer for GCN convolution. 27 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/AddSelfLoops.md: -------------------------------------------------------------------------------- 1 | # tfgnn.keras.layers.AddSelfLoops 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Adds self-loops to scalar graphs. 10 | 11 | 16 | 17 | 18 | 19 | The edge_set_name is expected to be a homogeneous edge (connects a node pair of 20 | the node set). 
NOTE: Self-connections will always be added, regardless of whether 21 | self-connections already exist. 22 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/disable_graph_tensor_validation.md: -------------------------------------------------------------------------------- 1 | # tfgnn.disable_graph_tensor_validation 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Disables both static and runtime checks of graph tensors. 10 | 11 | 14 | 15 | 16 | 17 | IMPORTANT: This is a temporary workaround for legacy code (before the TF-GNN 1.0 18 | release) that may rely on an inconsistent number of graph tensor items, or on 19 | edges whose adjacency indices refer to non-existing nodes. **DO NOT USE**. 20 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/proto/Metadata/KeyValue.md: -------------------------------------------------------------------------------- 1 | # tfgnn.proto.Metadata.KeyValue 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | A ProtocolMessage 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 22 | 25 | 26 | 29 | 32 | 33 |
20 | key 21 | 23 | string key 24 |
27 | value 28 | 30 | string value 31 |
34 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/GraphTensorProcessorFn.md: -------------------------------------------------------------------------------- 1 | # runner.GraphTensorProcessorFn 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | A class for `GraphTensor` processing. 10 | 11 | 12 | 13 | ## Methods 14 | 15 |

__call__

16 | 17 | View 18 | source 19 | 20 | 25 | 26 | Processes a `GraphTensor`. 27 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/experimental.md: -------------------------------------------------------------------------------- 1 | # Module: tfgnn.experimental 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Experimental (unstable) parts of the public interface of TensorFlow GNN. 10 | 11 | A symbol `foo` exposed here is available to library users as 12 | 13 | ``` 14 | import tensorflow_gnn as tfgnn 15 | 16 | tfgnn.experimental.foo() 17 | ``` 18 | 19 | This is the preferred way to expose individual functions on track to inclusion 20 | into the stable public interface of TensorFlow GNN. 21 | 22 | Beyond these symbols, there are also experimental sub-libraries that need to be 23 | imported separately (`from tensorflow_gnn.experimental import foo`). That is for 24 | special cases only. 25 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/ParseExample.md: -------------------------------------------------------------------------------- 1 | # tfgnn.keras.layers.ParseExample 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Applies tfgnn.parse_example(graph_tensor_spec, _) to a batch of strings. 10 | 11 | 17 | 18 | 19 | 20 | 21 | 22 | This layer can be restored from config by `tf.keras.models.load_model()` when 23 | saved as part of a Keras model using `save_format="tf"`. 24 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/ParseSingleExample.md: -------------------------------------------------------------------------------- 1 | # tfgnn.keras.layers.ParseSingleExample 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Applies tfgnn.parse_single_example(graph_tensor_spec, _). 
10 | 11 | 17 | 18 | 19 | 20 | 21 | 22 | This layer can be restored from config by `tf.keras.models.load_model()` when 23 | saved as part of a Keras model using `save_format="tf"`. 24 | -------------------------------------------------------------------------------- /testdata/node_vs_edge/schema.pbtxt: -------------------------------------------------------------------------------- 1 | node_sets { 2 | key: "node_set_one" 3 | value { 4 | metadata { 5 | filename: "node_set_one.csv" 6 | } 7 | features { 8 | key: "#id" 9 | value { 10 | dtype: DT_STRING 11 | } 12 | } 13 | } 14 | } 15 | node_sets { 16 | key: "node_set_two" 17 | value { 18 | metadata { 19 | filename: "node_set_two.csv" 20 | } 21 | features { 22 | key: "#id" 23 | value { 24 | dtype: DT_STRING 25 | } 26 | } 27 | } 28 | } 29 | edge_sets { 30 | key: "one_to_two" 31 | value { 32 | source: "node_set_one" 33 | target: "node_set_two" 34 | metadata { 35 | filename: "edge_set_one_to_two.csv" 36 | } 37 | } 38 | } 39 | edge_sets { 40 | key: "two_to_two" 41 | value { 42 | source: "node_set_two" 43 | target: "node_set_two" 44 | metadata { 45 | filename: "edge_set_two_to_two.csv" 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/softmax_edges_per_node.md: -------------------------------------------------------------------------------- 1 | # tfgnn.softmax_edges_per_node 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Returns softmax() of edge values per common `node_tag` node. 10 | 11 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/DatasetProvider.md: -------------------------------------------------------------------------------- 1 | # runner.DatasetProvider 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Helper class that provides a standard way to create an ABC using inheritance. 
10 | 11 | 12 | 13 | ## Methods 14 | 15 |

get_dataset

16 | 17 | View 18 | source 19 | 20 | 26 | 27 | Get a `tf.data.Dataset` by `context` per replica. 28 | -------------------------------------------------------------------------------- /package/tfdep.bzl: -------------------------------------------------------------------------------- 1 | """Dependency on TensorFlow for build.""" 2 | 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | 5 | def tf_setup(): 6 | """Define tensorflow>=2.13.0 dependency for Bazel build. 7 | 8 | This downloads the TensorFlow files required for building TFGNN protos 9 | (examples and graph_schema).This TF version should always be within our supported range and gets 10 | updated manually as our TF dependency advances. This version is somewhat flexible since TFGNN 11 | protos only depend on tensorflow.Example, tensorflow.Feature, tensorflow.TensorShapeProto, and 12 | tensorflow.DataType, which are all very stable definitions in TensorFlow. 13 | """ 14 | http_archive( 15 | name = "org_tensorflow", 16 | sha256 = "e58c939079588623e6fa1d054aec2f90f95018266e0a970fd353a5244f5173dc", 17 | urls = [ 18 | "https://github.com/tensorflow/tensorflow/archive/refs/tags/v2.13.0.tar.gz", 19 | ], 20 | strip_prefix = "tensorflow-2.13.0", 21 | ) 22 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/proto/BigQuery/TableSpec.md: -------------------------------------------------------------------------------- 1 | # tfgnn.proto.BigQuery.TableSpec 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | A ProtocolMessage 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 22 | 25 | 26 | 29 | 32 | 33 | 36 | 39 | 40 |
20 | dataset 21 | 23 | string dataset 24 |
27 | project 28 | 30 | string project 31 |
34 | table 35 | 37 | string table 38 |
41 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/proto/OriginInfo.md: -------------------------------------------------------------------------------- 1 | # tfgnn.proto.OriginInfo 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Metadata about the origin of the graph data. 10 | 11 | 12 | 13 | For detailed documentation, see the comments in the `graph_schema.proto` file. 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 25 | 28 | 29 | 32 | 35 | 36 |
23 | graph_type 24 | 26 | GraphType graph_type 27 |
30 | root_set 31 | 33 | repeated string root_set 34 |
37 | -------------------------------------------------------------------------------- /tensorflow_gnn/api_def/sampler-symbols.txt: -------------------------------------------------------------------------------- 1 | sampler.Artifacts 2 | sampler.CompositeLayer 3 | sampler.ConnectingEdgesSampler 4 | sampler.InMemIndexToFeaturesAccessor 5 | sampler.InMemIntegerKeyToBytesAccessor 6 | sampler.InMemStringKeyToBytesAccessor 7 | sampler.InMemUniformEdgesSampler 8 | sampler.KeyToBytesAccessor 9 | sampler.KeyToFeaturesAccessor 10 | sampler.KeyToTfExampleAccessor 11 | sampler.OutgoingEdgesSampler 12 | sampler.TfExamplesParser 13 | sampler.UniformEdgesSampler 14 | sampler.build_graph_tensor 15 | sampler.create_link_sampling_model_from_spec 16 | sampler.create_program 17 | sampler.create_sampling_model_from_spec 18 | sampler.ragged_choice 19 | sampler.ragged_lookup 20 | sampler.ragged_unique 21 | sampler.save_model 22 | sampler.set_ext_ops_implementation 23 | sampler.proto.Program 24 | sampler.proto.EvalDAG 25 | sampler.proto.Stage 26 | sampler.proto.Layer 27 | sampler.proto.ValueSpec 28 | sampler.proto.TensorSpec 29 | sampler.proto.RaggedTensorSpec 30 | sampler.proto.FlattenedSpec 31 | sampler.proto.EdgeSamplingConfig 32 | sampler.proto.IOFeatures 33 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/proto/Context.md: -------------------------------------------------------------------------------- 1 | # tfgnn.proto.Context 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | The schema for the features that apply across the entire input graph. 10 | 11 | 12 | 13 | For detailed documentation, see the comments in the `graph_schema.proto` file. 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 25 | 28 | 29 | 32 | 35 | 36 |
23 | features 24 | 26 | repeated FeaturesEntry features 27 |
30 | metadata 31 | 33 | Metadata metadata 34 |
37 | -------------------------------------------------------------------------------- /tensorflow_gnn/experimental/sampler/proto/BUILD: -------------------------------------------------------------------------------- 1 | # Copybara rewrites some of these load() statements back and forth; do not reformat. 2 | # buildifier: disable=out-of-order-load, disable=same-origin-load 3 | load("@rules_proto//proto:defs.bzl", "proto_library") 4 | 5 | # buildifier: disable=out-of-order-load, disable=same-origin-load 6 | load("@rules_python//python:proto.bzl", "py_proto_library") 7 | 8 | package(default_visibility = ["//visibility:public"]) 9 | 10 | licenses(["notice"]) # Apache 2.0 11 | 12 | py_library( 13 | name = "proto", 14 | srcs = ["__init__.py"], 15 | srcs_version = "PY3", 16 | visibility = ["//visibility:public"], 17 | deps = [ 18 | ":eval_dag_py_proto", 19 | "//tensorflow_gnn/utils:api_utils", 20 | ], 21 | ) 22 | 23 | 24 | proto_library( 25 | name = "eval_dag_proto", 26 | srcs = ["eval_dag.proto"], 27 | deps = [ 28 | "@org_tensorflow//tensorflow/core:protos_all", 29 | ], 30 | ) 31 | 32 | py_proto_library( 33 | name = "eval_dag_py_proto", 34 | deps = [ 35 | ":eval_dag_proto", 36 | ], 37 | ) 38 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Example of building the docker image locally: 16 | # docker build . -t tfgnn:latest 17 | # 18 | # You can then start an interactive python interpreter shell and 19 | # import tensorflow-gnn with: 20 | # docker run -it tfgnn:latest 21 | FROM python:3.9-slim 22 | # tzdata asks questions. 23 | ENV DEBIAN_FRONTEND="noninteractive" 24 | ENV TZ="America/New_York" 25 | 26 | RUN pip3 install --upgrade pip 27 | 28 | RUN pip3 install "tensorflow-gnn==1.0.0" httplib2 notebook ogb 29 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/all_symbols.md: -------------------------------------------------------------------------------- 1 | # All symbols in TensorFlow GNN Models: MultiHeadAttention 2 | 3 | 4 | 5 | ## Primary symbols 6 | 7 | * multi_head_attention 8 | * multi_head_attention.MultiHeadAttentionConv 9 | * multi_head_attention.MultiHeadAttentionEdgePool 10 | * multi_head_attention.MultiHeadAttentionHomGraphUpdate 11 | * multi_head_attention.MultiHeadAttentionMPNNGraphUpdate 12 | * multi_head_attention.graph_update_from_config_dict 13 | * multi_head_attention.graph_update_get_config_dict 14 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/sampler/SamplingSpec.md: -------------------------------------------------------------------------------- 1 | # tfgnn.sampler.SamplingSpec 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | A ProtocolMessage 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 23 | 26 | 27 | 30 | 33 | 34 | 37 | 40 | 41 |
21 | sampling_ops 22 | 24 | repeated SamplingOp sampling_ops 25 |
28 | seed_op 29 | 31 | SeedOp seed_op 32 |
35 | symmetric_link_seed_op 36 | 38 | SymmetricLinkSeedOp symmetric_link_seed_op 39 |
42 | -------------------------------------------------------------------------------- /tensorflow_gnn/proto/examples.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // ============================================================================= 15 | syntax = "proto2"; 16 | 17 | package tensorflow_gnn.testdata; 18 | 19 | import "tensorflow/core/example/example.proto"; 20 | 21 | // Specifies one or more Examples. This is used to aid testing readability by 22 | // allowing us to specify multiple Examples in an ASCII proto file. 23 | message ExampleList { 24 | repeated tensorflow.Example examples = 1; 25 | } 26 | -------------------------------------------------------------------------------- /tensorflow_gnn/runner/utils/saved_model_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | 18 | set -e -u -o pipefail # Let no failure go undetected. 19 | 20 | saved_model_gen_testdata=$1 21 | saved_model_load_testdata=$2 22 | use_legacy_model_save=$3 23 | 24 | $saved_model_gen_testdata --filepath=${TEST_TMPDIR}/saved_model_testdata \ 25 | --use_legacy_model_save=$use_legacy_model_save 26 | $saved_model_load_testdata --filepath=${TEST_TMPDIR}/saved_model_testdata 27 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow GNN: API docs 2 | 3 | TensorFlow GNN is a Python package that consists of multiple sub-libraries. 4 | 5 | * Core library [[API docs](python/tfgnn.md)]:
6 | `import tensorflow_gnn as tfgnn` 7 | 8 | * Training orchestration [[API docs](python/runner.md)]:
9 | `from tensorflow_gnn import runner` 10 | 11 | * Graph sampler:
12 | `from tensorflow_gnn.experimental import sampler` 13 | 14 | * Models from the [models collection](../../models/README.md):
15 | `from tensorflow_gnn.models import ` 16 | 17 | The TensorFlow GNN package uses [semantic 18 | versioning](https://semver.org/spec/v2.0.0.html) for its numbered releases. 19 | Its stable Python API consists of the Python identifiers exposed in these 20 | imported modules and their sub-modules, with the following exceptions: 21 | 22 | * identifiers or modules that contain `experimental` in their name, 23 | that are expressly documented to be unstable, or not documented at all; 24 | * private identifiers beginning with a single underscore (`_`); 25 | * models from the models collection whose documentation opts out of 26 | semantic versioning. -------------------------------------------------------------------------------- /examples/sampler/mag/sampling_spec.pbtxt: -------------------------------------------------------------------------------- 1 | # proto-file: tensorflow_gnn/sampler/sampling_spec.proto 2 | # proto-message: PipelineSpec 3 | 4 | seed_op < 5 | op_name: "seed" 6 | node_set_name: "paper" 7 | > 8 | sampling_ops < 9 | op_name: "seed->paper" 10 | input_op_names: "seed" 11 | edge_set_name: "cites" 12 | sample_size: 32 13 | strategy: RANDOM_UNIFORM 14 | > 15 | sampling_ops < 16 | op_name: "paper->author" 17 | input_op_names: "seed" 18 | input_op_names: "seed->paper" 19 | edge_set_name: "written" 20 | sample_size: 8 21 | strategy: RANDOM_UNIFORM 22 | > 23 | sampling_ops < 24 | op_name: "author->paper" 25 | input_op_names: "paper->author" 26 | edge_set_name: "writes" 27 | sample_size: 16 28 | strategy: RANDOM_UNIFORM 29 | > 30 | sampling_ops < 31 | op_name: "author->institution" 32 | input_op_names: "paper->author" 33 | edge_set_name: "affiliated_with" 34 | sample_size: 16 35 | strategy: RANDOM_UNIFORM 36 | > 37 | sampling_ops < 38 | op_name: "paper->field_of_study" 39 | input_op_names: "seed" 40 | input_op_names: "seed->paper" 41 | input_op_names: "author->paper" 42 | edge_set_name: "has_topic" 43 | sample_size: 16 44 | strategy: RANDOM_UNIFORM 45 | > 
-------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn.md: -------------------------------------------------------------------------------- 1 | # Module: vanilla_mpnn 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | TF-GNN's "Vanilla MPNN" model. 10 | 11 | Users of TF-GNN can use this model by importing it next to the core library as 12 | 13 | ```python 14 | import tensorflow_gnn as tfgnn 15 | from tensorflow_gnn.models import vanilla_mpnn 16 | ``` 17 | 18 | This model ties together some simple convolutions from the TF-GNN core library, 19 | so it does not define any Conv class by itself. 20 | 21 | ## Functions 22 | 23 | [`VanillaMPNNGraphUpdate(...)`](./vanilla_mpnn/VanillaMPNNGraphUpdate.md): 24 | Returns a GraphUpdate layer for a Vanilla MPNN. 25 | 26 | [`graph_update_from_config_dict(...)`](./vanilla_mpnn/graph_update_from_config_dict.md): 27 | Returns a VanillaMPNNGraphUpdate initialized from `cfg`. 28 | 29 | [`graph_update_get_config_dict(...)`](./vanilla_mpnn/graph_update_get_config_dict.md): 30 | Returns ConfigDict for graph_update_from_config_dict() with defaults. 31 | -------------------------------------------------------------------------------- /tensorflow_gnn/runner/input/BUILD: -------------------------------------------------------------------------------- 1 | # Copybara rewrites load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load 3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library") 4 | 5 | # buildifier: disable=out-of-order-load, disable=same-origin-load 6 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "py_strict_test") 7 | 8 | licenses(["notice"]) 9 | 10 | package( 11 | default_applicable_licenses = ["//tensorflow_gnn:license"], 12 | default_visibility = ["//visibility:public"], 13 | ) 14 | 15 | pytype_strict_library( 16 | name = "datasets", 17 | srcs = ["datasets.py"], 18 | visibility = ["//tensorflow_gnn/runner:__pkg__"], 19 | deps = [ 20 | "//:expect_tensorflow_installed", 21 | "//tensorflow_gnn/runner:interfaces", 22 | ], 23 | ) 24 | 25 | py_strict_test( 26 | name = "datasets_test", 27 | srcs = ["datasets_test.py"], 28 | deps = [ 29 | ":datasets", 30 | "//:expect_absl_installed_testing", 31 | "//third_party/py/google/protobuf:use_fast_cpp_protos", # Automatically added go/proto_python_upb_flip 32 | "//:expect_tensorflow_installed", 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/mt_albis.md: -------------------------------------------------------------------------------- 1 | # Module: mt_albis 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | TF-GNN's Model Template "Albis". 10 | 11 | The TF-GNN Model Template "Albis" provides a small selection of field-tested GNN 12 | architectures through the 13 | mt_albis.MtAlbisGraphUpdate 14 | class. 15 | 16 | Users of TF-GNN can use it by importing it next to the core library as 17 | 18 | ```python 19 | import tensorflow_gnn as tfgnn 20 | from tensorflow_gnn.models import mt_albis 21 | ``` 22 | 23 | ## Functions 24 | 25 | [`MtAlbisGraphUpdate(...)`](./mt_albis/MtAlbisGraphUpdate.md): Returns 26 | GraphUpdate layer for message passing with Model Template "Albis". 
27 | 28 | [`graph_update_from_config_dict(...)`](./mt_albis/graph_update_from_config_dict.md): 29 | Constructs a MtAlbisGraphUpdate from a ConfigDict. 30 | 31 | [`graph_update_get_config_dict(...)`](./mt_albis/graph_update_get_config_dict.md): 32 | Returns ConfigDict for graph_update_from_config_dict() with defaults. 33 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/sampler.md: -------------------------------------------------------------------------------- 1 | # Module: tfgnn.sampler 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Public interface for GNN Sampler. 10 | 11 | ## Classes 12 | 13 | [`class SamplingOp`](../tfgnn/sampler/SamplingOp.md): A ProtocolMessage 14 | 15 | [`class SamplingSpec`](../tfgnn/sampler/SamplingSpec.md): A ProtocolMessage 16 | 17 | [`class SamplingSpecBuilder`](../tfgnn/sampler/SamplingSpecBuilder.md): Mimics 18 | builder pattern that eases creation of `tfgnn.SamplingSpec`. 19 | 20 | ## Functions 21 | 22 | [`make_sampling_spec_tree(...)`](../tfgnn/sampler/make_sampling_spec_tree.md): 23 | Automatically creates `SamplingSpec` by starting from seed node set. 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 35 | 38 | 39 |
33 | SamplingStrategy 34 | 36 | ['TOP_K', 'RANDOM_UNIFORM', 'RANDOM_WEIGHTED'] 37 |
40 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/write_schema.md: -------------------------------------------------------------------------------- 1 | # tfgnn.write_schema 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Write a `GraphSchema` to a text-formatted proto file. 10 | 11 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 30 | 33 | 34 | 37 | 41 | 42 |
28 | schema 29 | 31 | A GraphSchema instance to write out. 32 |
35 | filename 36 | 38 | A string, the path to a file to render a text-formatted rendition 39 | of the GraphSchema message to. 40 |
43 | -------------------------------------------------------------------------------- /examples/schemas/graph_nets.pbtxt: -------------------------------------------------------------------------------- 1 | # An example graph schema matching DeepMind's GraphsTuple. 2 | # //third_party/py/tensorflow_gnn/proto/graph_schema.proto:GraphSchema 3 | 4 | context { 5 | features { 6 | key: "embedding" 7 | value: { 8 | description: "Globals feature vector" 9 | dtype: DT_FLOAT 10 | shape: { dim { size: 128 } } 11 | } 12 | } 13 | } 14 | 15 | node_sets { 16 | key: "default" 17 | value { 18 | features { 19 | key: "embedding" 20 | value: { 21 | description: "Encoded node features vector" 22 | dtype: DT_FLOAT 23 | shape: { dim { size: 64 } } 24 | } 25 | } 26 | 27 | features { 28 | key: "labels" 29 | value: { 30 | description: "Multiple ground truth text labels" 31 | dtype: DT_STRING 32 | shape: { dim { size: -1 } } 33 | } 34 | } 35 | 36 | context: "embedding" 37 | description: "Main set of node features" 38 | } 39 | } 40 | 41 | edge_sets { 42 | key: "default" 43 | value { 44 | features { 45 | key: "embedding" 46 | value: { 47 | description: "Encoded edge features vector" 48 | dtype: DT_FLOAT 49 | shape: { dim { size: 32 } } 50 | } 51 | } 52 | source: "default" 53 | target: "default" 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/proto/NodeSet.md: -------------------------------------------------------------------------------- 1 | # tfgnn.proto.NodeSet 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | The schema shared by a set of nodes in the graph. 10 | 11 | 12 | 13 | For detailed documentation, see the comments in the `graph_schema.proto` file. 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 25 | 28 | 29 | 32 | 35 | 36 | 39 | 42 | 43 | 46 | 49 | 50 |
23 | context 24 | 26 | repeated string context 27 |
30 | description 31 | 33 | string description 34 |
37 | features 38 | 40 | repeated FeaturesEntry features 41 |
44 | metadata 45 | 47 | Metadata metadata 48 |
51 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/sampler/SamplingOp.md: -------------------------------------------------------------------------------- 1 | # tfgnn.sampler.SamplingOp 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | A ProtocolMessage 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 22 | 25 | 26 | 29 | 32 | 33 | 36 | 39 | 40 | 43 | 46 | 47 | 50 | 53 | 54 |
20 | edge_set_name 21 | 23 | string edge_set_name 24 |
27 | input_op_names 28 | 30 | repeated string input_op_names 31 |
34 | op_name 35 | 37 | string op_name 38 |
41 | sample_size 42 | 44 | int32 sample_size 45 |
48 | strategy 49 | 51 | SamplingStrategy strategy 52 |
55 | -------------------------------------------------------------------------------- /testdata/feature_repr.pbtxt: -------------------------------------------------------------------------------- 1 | context { 2 | features { 3 | key: "rankings" 4 | value { 5 | description: "Scores for each quartile" 6 | dtype: DT_FLOAT 7 | shape { dim { size: 4 } } 8 | } 9 | } 10 | } 11 | 12 | node_sets { 13 | key: "items" 14 | value { 15 | features { 16 | key: "category" 17 | value { 18 | description: "Purchase category" 19 | dtype: DT_STRING 20 | } 21 | } 22 | features { 23 | key: "amounts" 24 | value { 25 | description: "Purchase price" 26 | dtype: DT_FLOAT 27 | shape { dim { size: -1 } } 28 | } 29 | } 30 | } 31 | } 32 | 33 | node_sets { 34 | key: "persons" 35 | value { 36 | features { 37 | key: "name" 38 | value { 39 | description: "First name" 40 | dtype: DT_STRING 41 | } 42 | } 43 | features { 44 | key: "age" 45 | value { 46 | description: "Age" 47 | dtype: DT_INT64 48 | } 49 | } 50 | features { 51 | key: "country" 52 | value { 53 | description: "Country of origin" 54 | dtype: DT_STRING 55 | } 56 | } 57 | } 58 | } 59 | 60 | edge_sets { 61 | key: "purchased" 62 | value { source: "items" target: "persons" } 63 | } 64 | 65 | edge_sets { 66 | key: "is-friend" 67 | value { source: "persons" target: "persons" } 68 | } 69 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/GraphTensorPadding.md: -------------------------------------------------------------------------------- 1 | # runner.GraphTensorPadding 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Collects `GraphTensor` padding helpers. 10 | 11 | 12 | 13 | ## Methods 14 | 15 |

get_filter_fn

16 | 17 | View 18 | source 19 | 20 | 26 | 27 |

get_size_constraints

28 | 29 | View 30 | source 31 | 32 | 38 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/FeatureDefaultValues.md: -------------------------------------------------------------------------------- 1 | # tfgnn.FeatureDefaultValues 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Default values for graph context, node sets and edge sets features. 10 | 11 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 33 | 36 | 37 | 40 | 43 | 44 | 47 | 50 | 51 |
31 | context 32 | 34 | A namedtuple alias for field number 0 35 |
38 | node_sets 39 | 41 | A namedtuple alias for field number 1 42 |
45 | edge_sets 46 | 48 | A namedtuple alias for field number 2 49 |
52 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/graph_sage.md: -------------------------------------------------------------------------------- 1 | # Module: graph_sage 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | GraphSAGE. 10 | 11 | Users of TF-GNN can use this model by importing it next to the core library as 12 | 13 | ```python 14 | import tensorflow_gnn as tfgnn 15 | from tensorflow_gnn.models import graph_sage 16 | ``` 17 | 18 | ## Classes 19 | 20 | [`class GCNGraphSAGENodeSetUpdate`](./graph_sage/GCNGraphSAGENodeSetUpdate.md): 21 | GCNGraphSAGENodeSetUpdate is an extension of the mean aggregator operator. 22 | 23 | [`class GraphSAGEAggregatorConv`](./graph_sage/GraphSAGEAggregatorConv.md): 24 | GraphSAGE: element-wise aggregation of neighbors and their linear 25 | transformation. 26 | 27 | [`class GraphSAGENextState`](./graph_sage/GraphSAGENextState.md): 28 | GraphSAGENextState: compute new node states with GraphSAGE algorithm. 29 | 30 | [`class GraphSAGEPoolingConv`](./graph_sage/GraphSAGEPoolingConv.md): GraphSAGE: 31 | pooling aggregator transform of neighbors followed by linear transformation. 32 | 33 | ## Functions 34 | 35 | [`GraphSAGEGraphUpdate(...)`](./graph_sage/GraphSAGEGraphUpdate.md): Returns a 36 | GraphSAGE GraphUpdater layer for nodes in node_set_names. 37 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/PadToTotalSizes.md: -------------------------------------------------------------------------------- 1 | # tfgnn.keras.layers.PadToTotalSizes 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Applies tfgnn.pad_to_total_sizes() to a GraphTensor. 
10 | 11 | 19 | 20 | 21 | 22 | 23 | 24 | This Keras layer maps a GraphTensor to a GraphTensor by calling 25 | tfgnn.pad_to_total_sizes() with the additional arguments, notably 26 | `sizes_constraints`, passed at initialization time. See that function 27 | for detailed documentation. 28 | 29 | This layer can be restored from config by `tf.keras.models.load_model()` when 30 | saved as part of a Keras model using `save_format="tf"`. Serialization to a 31 | Keras model config requires the `sizes_constraints` to contain Python integers 32 | or eager Tensors, not symbolic Tensors. 33 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/gcn/BUILD: -------------------------------------------------------------------------------- 1 | # Copybara rewrites load() statements back and forth; do not reformat. 2 | # buildifier: disable=out-of-order-load, disable=same-origin-load 3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library") 4 | 5 | # buildifier: disable=out-of-order-load, disable=same-origin-load 6 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "tf_py_test") 7 | 8 | licenses(["notice"]) 9 | 10 | package( 11 | default_visibility = [":__subpackages__"], 12 | ) 13 | 14 | package_group(name = "users") 15 | 16 | pytype_strict_library( 17 | name = "gcn", 18 | srcs = ["__init__.py"], 19 | visibility = [ 20 | ":__subpackages__", 21 | ":users", 22 | ], 23 | deps = [ 24 | ":gcn_conv", 25 | "//tensorflow_gnn/utils:api_utils", 26 | ], 27 | ) 28 | 29 | pytype_strict_library( 30 | name = "gcn_conv", 31 | srcs = ["gcn_conv.py"], 32 | deps = [ 33 | "//:expect_tensorflow_installed", 34 | "//tensorflow_gnn", 35 | ], 36 | ) 37 | 38 | tf_py_test( 39 | name = "gcn_conv_test", 40 | srcs = ["gcn_conv_test.py"], 41 | deps = [ 42 | ":gcn_conv", 43 | "//:expect_absl_installed_testing", 44 | "//:expect_tensorflow_installed", 45 | "//tensorflow_gnn", 46 | "//tensorflow_gnn/utils:tf_test_utils", 47 | 
"//:expect_ai_edge_litert_installed", 48 | ], 49 | ) 50 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/graph_sage/BUILD: -------------------------------------------------------------------------------- 1 | # Copybara rewrites load() statements back and forth; do not reformat. 2 | # buildifier: disable=out-of-order-load, disable=same-origin-load 3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library") 4 | 5 | # buildifier: disable=out-of-order-load, disable=same-origin-load 6 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "tf_py_test") 7 | 8 | licenses(["notice"]) 9 | 10 | package(default_visibility = [ 11 | "//tensorflow_gnn:__pkg__", 12 | "//tensorflow_gnn/graph:__subpackages__", 13 | ]) 14 | 15 | package_group(name = "users") 16 | 17 | pytype_strict_library( 18 | name = "graph_sage", 19 | srcs = ["__init__.py"], 20 | visibility = [":users"], 21 | deps = [ 22 | ":layers", 23 | "//tensorflow_gnn/utils:api_utils", 24 | ], 25 | ) 26 | 27 | pytype_strict_library( 28 | name = "layers", 29 | srcs = ["layers.py"], 30 | deps = [ 31 | "//:expect_tensorflow_installed", 32 | "//tensorflow_gnn", 33 | ], 34 | ) 35 | 36 | tf_py_test( 37 | name = "layers_test", 38 | srcs = ["layers_test.py"], 39 | deps = [ 40 | ":layers", 41 | "//:expect_absl_installed_testing", 42 | "//:expect_tensorflow_installed", 43 | "//tensorflow_gnn", 44 | "//tensorflow_gnn/utils:tf_test_utils", 45 | "//:expect_ai_edge_litert_installed", 46 | ], 47 | ) 48 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/proto/Metadata.md: -------------------------------------------------------------------------------- 1 | # tfgnn.proto.Metadata 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Extra information optionally provided on a context, node set or edge set. 
10 | 11 | 12 | 13 | For detailed documentation, see the comments in the `graph_schema.proto` file. 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 25 | 28 | 29 | 32 | 35 | 36 | 39 | 42 | 43 | 46 | 49 | 50 |
23 | bigquery 24 | 26 | BigQuery bigquery 27 |
30 | cardinality 31 | 33 | int64 cardinality 34 |
37 | extra 38 | 40 | repeated KeyValue extra 41 |
44 | filename 45 | 47 | string filename 48 |
51 | 52 | ## Child Classes 53 | 54 | [`class KeyValue`](../../tfgnn/proto/Metadata/KeyValue.md) 55 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/contrastive_losses/README.md: -------------------------------------------------------------------------------- 1 | # Contrastive Losses 2 | 3 | ## Overview 4 | 5 | This code implements and collects various contrastive losses for 6 | self-supervised learning. This code is under *active development*. An overview 7 | of the included losses follows: 8 | 9 | ### Deep Graph Infomax 10 | 11 | Deep Graph Infomax [1] attempts to learn a bilinear 12 | layer capable of discriminating between positive examples (any input 13 | `GraphTensor`) and negative examples (the input `GraphTensor` but with perturbed 14 | features: this implementation, as in the original paper, shuffles features 15 | across the batch, that is, across the components of the merged `GraphTensor`). 16 | 17 | Deep Graph Infomax is particularly useful in unsupervised tasks that wish to 18 | learn latent representations informed primarily by a node's neighborhood 19 | attributes (vs. its structure). 20 | 21 | * [1] Petar Veličković, William Fedus, William L. Hamilton, Pietro Liò, 22 | Yoshua Bengio, R Devon Hjelm: 23 | ["Deep Graph Infomax"](https://arxiv.org/abs/1809.10341), 2018. 24 | 25 | ## Usage 26 | 27 | TensorFlow programs can import and use this model as described in its 28 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses.md). 29 | 30 | ## API stability 31 | 32 | The API of this model may change between OSS library versions.
33 | 34 | 35 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/SingleInputNextState.md: -------------------------------------------------------------------------------- 1 | # tfgnn.keras.layers.SingleInputNextState 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Replaces a state from a single input. 10 | 11 | 16 | 17 | 18 | 19 | In a NodeSetUpdate, it replaces the node state with a single edge set input. For 20 | an EdgeSetUpdate, it replaces the edge_state with the incident node set's input. 21 | For a ContextUpdate, it replaces the context state with a single node set input. 22 | 23 | This layer can be restored from config by `tf.keras.models.load_model()` when 24 | saved as part of a Keras model using `save_format="tf"`. 25 | 26 | 27 | 28 | 29 | 30 | 31 | 34 | 35 | 36 |
32 | A tensor to use as the new state. 33 |
37 | -------------------------------------------------------------------------------- /tensorflow_gnn/graph/dict_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utilities for Python dictionaries.""" 16 | 17 | from typing import Any, Dict, Mapping, MutableMapping 18 | 19 | 20 | def with_key_prefix(d: Mapping[str, Any], key_prefix: str) -> Dict[str, Any]: 21 | """Returns {key_prefix+k: v for k, v in d.items()}.""" 22 | return {key_prefix+k: v for k, v in d.items()} 23 | 24 | 25 | def pop_by_prefix( 26 | d: MutableMapping[str, Any], key_prefix: str) -> Dict[str, Any]: 27 | """Returns {k: v for key_prefix+k, v in d.items()} and removes them from d.""" 28 | popped = {} 29 | for key in list(d.keys()): 30 | if key.startswith(key_prefix): 31 | popped[key[len(key_prefix):]] = d.pop(key) 32 | return popped 33 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/proto/Feature.md: -------------------------------------------------------------------------------- 1 | # tfgnn.proto.Feature 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | The schema entry for a single feature. 
10 | 11 | 18 | 19 | 20 | 21 | For detailed documentation, see the comments in the `graph_schema.proto` file. 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 33 | 36 | 37 | 40 | 43 | 44 | 47 | 50 | 51 | 54 | 57 | 58 |
31 | description 32 | 34 | string description 35 |
38 | dtype 39 | 41 | DataType dtype 42 |
45 | shape 46 | 48 | TensorShapeProto shape 49 |
52 | source 53 | 55 | string source 56 |
59 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/gat_v2.md: -------------------------------------------------------------------------------- 1 | # Module: gat_v2 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Graph Attention Networks v2. 10 | 11 | Users of TF-GNN can use this model by importing it next to the core library as 12 | 13 | ```python 14 | import tensorflow_gnn as tfgnn 15 | from tensorflow_gnn.models import gat_v2 16 | ``` 17 | 18 | ## Classes 19 | 20 | [`class GATv2Conv`](./gat_v2/GATv2Conv.md): The multi-head attention from Graph 21 | Attention Networks v2 (GATv2). 22 | 23 | ## Functions 24 | 25 | [`GATv2EdgePool(...)`](./gat_v2/GATv2EdgePool.md): Returns a layer for pooling 26 | edges with GATv2-style attention. 27 | 28 | [`GATv2HomGraphUpdate(...)`](./gat_v2/GATv2HomGraphUpdate.md): Returns a 29 | GraphUpdate layer with a Graph Attention Network V2 (GATv2). 30 | 31 | [`GATv2MPNNGraphUpdate(...)`](./gat_v2/GATv2MPNNGraphUpdate.md): Returns a 32 | GraphUpdate layer for message passing with GATv2 pooling. 33 | 34 | [`graph_update_from_config_dict(...)`](./gat_v2/graph_update_from_config_dict.md): 35 | Returns a GATv2MPNNGraphUpdate initialized from `cfg`. 36 | 37 | [`graph_update_get_config_dict(...)`](./gat_v2/graph_update_get_config_dict.md): 38 | Returns ConfigDict for graph_update_from_config_dict() with defaults. 39 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/parse_schema.md: -------------------------------------------------------------------------------- 1 | # tfgnn.parse_schema 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Parse a schema from text-formatted protos. 10 | 11 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 29 | 33 | 34 |
27 | schema_text 28 | 30 | A string containing a text-formatted protocol buffer rendition 31 | of a GraphSchema message. 32 |
35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 45 | 46 | 47 |
43 | A GraphSchema instance. 44 |
48 | 49 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/read_schema.md: -------------------------------------------------------------------------------- 1 | # tfgnn.read_schema 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Read a proto schema from a file with text-formatted contents. 10 | 11 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 29 | 33 | 34 |
27 | filename 28 | 30 | A string, the path to a file containing a text-formatted protocol 31 | buffer rendition of a GraphSchema message. 32 |
35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 45 | 46 | 47 |
43 | A GraphSchema instance. 44 |
48 | 49 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/vanilla_mpnn/hparams_vizier_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for hparams_vizier.""" 16 | 17 | from absl.testing import absltest 18 | from tensorflow_gnn.models.vanilla_mpnn import hparams_vizier 19 | 20 | from vizier.service import pyvizier as vz 21 | 22 | 23 | class HparamsVizierTest(absltest.TestCase): 24 | 25 | def test_regularization(self): 26 | problem = vz.ProblemStatement() 27 | hparams_vizier.add_params_regularization(problem.search_space, 28 | prefix="foo.") 29 | self.assertCountEqual( 30 | [p.name for p in problem.search_space.parameters], 31 | ["foo.dropout_rate", "foo.l2_regularization"]) 32 | 33 | 34 | if __name__ == "__main__": 35 | absltest.main() 36 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/ModelExporter.md: -------------------------------------------------------------------------------- 1 | # runner.ModelExporter 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Saves a Keras model. 10 | 11 | 12 | 13 | ## Methods 14 | 15 |

save

16 | 17 | View 18 | source 19 | 20 | 26 | 27 | Saves a Keras model. 28 | 29 | All persistence decisions are left to the implementation: e.g., a Keras model 30 | with full API or a simple `tf.train.Checkpoint` may be saved. 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 42 | 45 | 46 | 49 | 52 | 53 |
Args
40 | run_result 41 | 43 | A RunResult from training. 44 |
47 | export_dir 48 | 50 | A destination directory. 51 |
54 | -------------------------------------------------------------------------------- /tensorflow_gnn/runner/utils/model_dir.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Model directory methods.""" 16 | import os 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def incrementing_model_dir(dirname: str, start: int = 0) -> str: 22 | """Create, given some `dirname`, an incrementing model directory. 23 | 24 | Args: 25 | dirname: The base directory name. 26 | start: The starting integer. 27 | 28 | Returns: 29 | A model directory `dirname/n` where 'n' is one more than the maximum integer in `dirname` (or `start` if there is none). 30 | """ 31 | if not tf.io.gfile.isdir(dirname): 32 | return os.path.join(dirname, str(start)) 33 | files = tf.io.gfile.listdir(dirname) 34 | integers = [int(f) for f in files if f.isdigit()] 35 | return os.path.join(dirname, str(max(integers) + 1 if integers else start)) 36 | -------------------------------------------------------------------------------- /tensorflow_gnn/data/BUILD: -------------------------------------------------------------------------------- 1 | # Copybara rewrites load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load 3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_contrib_test") 4 | 5 | # buildifier: disable=out-of-order-load, disable=same-origin-load 6 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library") 7 | 8 | licenses(["notice"]) 9 | 10 | package( 11 | default_applicable_licenses = ["//tensorflow_gnn:license"], 12 | default_visibility = ["//visibility:public"], 13 | ) 14 | 15 | GOOGLE_INTERNAL_UNIGRAPH_DEPENDENCIES = [] 16 | 17 | pytype_strict_library( 18 | name = "unigraph", 19 | srcs = ["unigraph.py"], 20 | deps = [ 21 | "//third_party/py/apache_beam", 22 | "//third_party/py/pyarrow", 23 | "//:expect_tensorflow_installed", 24 | "//tensorflow_gnn", 25 | ] + GOOGLE_INTERNAL_UNIGRAPH_DEPENDENCIES, 26 | ) 27 | 28 | pytype_strict_contrib_test( 29 | name = "unigraph_test", 30 | srcs = ["unigraph_test.py"], 31 | data = [ 32 | "@tensorflow_gnn//testdata/heterogeneous", 33 | "@tensorflow_gnn//testdata/homogeneous", 34 | ], 35 | deps = [ 36 | ":unigraph", 37 | "//:expect_absl_installed_testing", 38 | "//third_party/py/apache_beam", 39 | "//:expect_tensorflow_installed", 40 | "//tensorflow_gnn", 41 | "//tensorflow_gnn/utils:test_utils", 42 | ], 43 | ) 44 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/PassthruDatasetProvider.md: -------------------------------------------------------------------------------- 1 | # runner.PassthruDatasetProvider 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Builds a `tf.data.Dataset` from a pass thru dataset. 10 | 11 | Inherits From: [`DatasetProvider`](../runner/DatasetProvider.md) 12 | 13 | 21 | 22 | 23 | 24 | Passes any `dataset` thru: omitting any sharding. For detailed documentation, 25 | see the filename dataset provider complement: `SimpleDatasetProvider`. 26 | 27 | ## Methods 28 | 29 |

get_dataset

30 | 31 | View 32 | source 33 | 34 | 39 | 40 | Gets a `tf.data.Dataset` omitting any input context. 41 | -------------------------------------------------------------------------------- /tensorflow_gnn/graph/dict_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for dict_utils.""" 16 | 17 | from absl.testing import absltest 18 | from tensorflow_gnn.graph import dict_utils 19 | 20 | 21 | class KeyPrefixTest(absltest.TestCase): 22 | 23 | def testWithKeyPrefix(self): 24 | d1 = {"a": 1, "b": 2} 25 | d2 = dict_utils.with_key_prefix(d1, "p/") 26 | self.assertDictEqual(d1, {"a": 1, "b": 2}) # Unchanged. 27 | self.assertDictEqual(d2, {"p/a": 1, "p/b": 2}) 28 | 29 | def testPopByPrefix(self): 30 | d1 = {"p/a": 1, "p/b": 2, "q/c": 3, "q/d": 4} 31 | d2 = dict_utils.pop_by_prefix(d1, "p/") 32 | self.assertDictEqual(d1, {"q/c": 3, "q/d": 4}) # Changed in-place. 
33 | self.assertDictEqual(d2, {"a": 1, "b": 2}) 34 | 35 | 36 | if __name__ == "__main__": 37 | absltest.main() 38 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/proto/GraphSchema.md: -------------------------------------------------------------------------------- 1 | # tfgnn.proto.GraphSchema 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | The top-level container for the schema of a graph dataset. 10 | 11 | 18 | 19 | 20 | 21 | For detailed documentation, see the comments in the `graph_schema.proto` file. 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 33 | 36 | 37 | 40 | 43 | 44 | 47 | 50 | 51 | 54 | 57 | 58 |
31 | context 32 | 34 | Context context 35 |
38 | edge_sets 39 | 41 | repeated EdgeSetsEntry edge_sets 42 |
45 | info 46 | 48 | OriginInfo info 49 |
52 | node_sets 53 | 55 | repeated NodeSetsEntry node_sets 56 |
59 | -------------------------------------------------------------------------------- /examples/schemas/latent.pbtxt: -------------------------------------------------------------------------------- 1 | # An example graph schema with some latent nodes. 2 | # This defines three node sets and three edge sets: 3 | # dish -> restaurant -> user 4 | # dish ---------------> user 5 | # with only the 'dish' node type having features. 6 | 7 | node_sets { 8 | key: "dish" 9 | value { 10 | description: "A canonical dish" 11 | features { 12 | key: "embedding" 13 | value: { 14 | description: "Some dish embedding feature" 15 | dtype: DT_FLOAT 16 | shape: { 17 | dim { 18 | size: 32 19 | } 20 | } 21 | } 22 | } 23 | } 24 | } 25 | 26 | # This is how you define a latent set of nodes. 27 | node_sets { 28 | key: "restaurant" 29 | value { 30 | description: "A restaurant that offers dishes" 31 | } 32 | } 33 | 34 | node_sets { 35 | key: "user" 36 | value { 37 | description: "A person who visits restaurants and orders dishes" 38 | } 39 | } 40 | 41 | edge_sets { 42 | key: "dish->restaurant" 43 | value { 44 | description: "Dishes available at restaurants" 45 | source: "dish" 46 | target: "restaurant" 47 | } 48 | } 49 | 50 | edge_sets { 51 | key: "restaurant->user" 52 | value { 53 | description: "Restaurant visits by users" 54 | source: "restaurant" 55 | target: "user" 56 | } 57 | } 58 | 59 | edge_sets { 60 | key: "dish->user" 61 | value { 62 | description: "Dishes users have ordered in the past" 63 | source: "dish" 64 | target: "user" 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/incrementing_model_dir.md: -------------------------------------------------------------------------------- 1 | # runner.incrementing_model_dir 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Create, given some `dirname`, an incrementing model directory.
10 | 11 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 28 | 31 | 32 | 35 | 38 | 39 |
26 | dirname 27 | 29 | The base directory name. 30 |
33 | start 34 | 36 | The starting integer. 37 |
40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 50 | 51 | 52 |
48 | A model directory dirname/n where 'n' is one more than the maximum integer in dirname (or start if there is none). 49 |
53 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/gcn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Graph Convolutional Networks. 16 | 17 | Users of TF-GNN can use this model by importing it next to the core library as 18 | 19 | ```python 20 | import tensorflow_gnn as tfgnn 21 | from tensorflow_gnn.models import gcn 22 | ``` 23 | """ 24 | from tensorflow_gnn.models.gcn import gcn_conv 25 | from tensorflow_gnn.utils import api_utils 26 | 27 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py. 28 | # Please see there for instructions how to reflect API changes. 29 | # LINT.IfChange 30 | 31 | GCNConv = gcn_conv.GCNConv 32 | GCNHomGraphUpdate = gcn_conv.GCNHomGraphUpdate 33 | 34 | # Remove all names added by module imports, unless explicitly allowed here. 
35 | api_utils.remove_submodules_except(__name__, []) 36 | # LINT.ThenChange(../../api_def/gcn-symbols.txt) 37 | -------------------------------------------------------------------------------- /tensorflow_gnn/proto/BUILD: -------------------------------------------------------------------------------- 1 | # Copybara rewrites some of these load() statements back and forth; do not reformat. 2 | # buildifier: disable=out-of-order-load, disable=same-origin-load 3 | load("@rules_proto//proto:defs.bzl", "proto_library") 4 | 5 | # buildifier: disable=out-of-order-load, disable=same-origin-load 6 | load("@rules_python//python:proto.bzl", "py_proto_library") 7 | 8 | package(default_visibility = ["//visibility:public"]) 9 | 10 | licenses(["notice"]) # Apache 2.0 11 | 12 | py_library( 13 | name = "proto", 14 | srcs = ["__init__.py"], 15 | srcs_version = "PY3", 16 | visibility = ["//visibility:public"], 17 | deps = [ 18 | ":graph_schema", 19 | "//tensorflow_gnn/utils:api_utils", 20 | ], 21 | ) 22 | 23 | proto_library( 24 | name = "graph_schema_proto", 25 | srcs = ["graph_schema.proto"], 26 | deps = [ 27 | "@org_tensorflow//tensorflow/core:protos_all", 28 | ], 29 | ) 30 | 31 | py_proto_library( 32 | name = "graph_schema_py_proto", 33 | deps = [ 34 | ":graph_schema_proto", 35 | ], 36 | ) 37 | 38 | py_library( 39 | name = "graph_schema", 40 | srcs = ["graph_schema.py"], 41 | srcs_version = "PY3", 42 | deps = [ 43 | ":graph_schema_py_proto", 44 | ], 45 | ) 46 | 47 | proto_library( 48 | name = "examples_proto", 49 | srcs = ["examples.proto"], 50 | deps = ["@org_tensorflow//tensorflow/core:protos_all"], 51 | ) 52 | 53 | py_proto_library( 54 | name = "examples_py_proto", 55 | deps = [ 56 | ":examples_proto", 57 | ], 58 | ) 59 | -------------------------------------------------------------------------------- /tensorflow_gnn/tools/generate_training_data_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN 
Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Unit tests for generate training data test.""" 16 | 17 | from os import path 18 | 19 | from absl import flags 20 | import tensorflow as tf 21 | from tensorflow_gnn.tools import generate_training_data 22 | from tensorflow_gnn.utils import test_utils 23 | 24 | 25 | FLAGS = flags.FLAGS 26 | 27 | 28 | class GenerateDataTest(tf.test.TestCase): 29 | 30 | def test_generate_training_data(self): 31 | schema_filename = test_utils.get_resource("examples/schemas/mpnn.pbtxt") 32 | output_filename = path.join(FLAGS.test_tmpdir, "examples.tfrecords") 33 | generate_training_data.generate_training_data( 34 | schema_filename, output_filename, "tfrecord", 64) 35 | self.assertTrue(path.exists(output_filename)) 36 | 37 | 38 | if __name__ == "__main__": 39 | tf.test.main() 40 | -------------------------------------------------------------------------------- /examples/schemas/mpnn.pbtxt: -------------------------------------------------------------------------------- 1 | # An example graph schema matching demo graphs from MPNN modeling. 
2 | # //third_party/py/tensorflow_gnn/proto/graph_schema.proto:GraphSchema 3 | 4 | context { 5 | features { 6 | key: "embedding" 7 | value: { 8 | description: "Global feature vector" 9 | dtype: DT_FLOAT 10 | shape: { dim { size: 128 } } 11 | } 12 | } 13 | } 14 | 15 | node_sets { 16 | key: "videos" 17 | value { 18 | features { 19 | key: "features" 20 | value: { 21 | description: "Encoded video features vector" 22 | dtype: DT_FLOAT 23 | shape: { dim { size: 256 } } 24 | } 25 | } 26 | } 27 | } 28 | 29 | node_sets { 30 | key: "channels" 31 | value { 32 | description: "User or Channel in YouTube." 33 | context: "embedding" 34 | 35 | features { 36 | key: "features" 37 | value: { 38 | description: "Encoded channel features vector" 39 | dtype: DT_FLOAT 40 | shape: { dim { size: 128 } } 41 | } 42 | } 43 | features { 44 | key: "labels" 45 | value: { 46 | description: "Multiple ground truth text labels" 47 | dtype: DT_STRING 48 | shape: { dim { size: -1 } } 49 | } 50 | } 51 | } 52 | } 53 | 54 | edge_sets { 55 | key: "videos->channels" 56 | value { 57 | features { 58 | key: "embedding" 59 | value: { 60 | description: "Encoded edge features vector" 61 | dtype: DT_FLOAT 62 | shape: { dim { size: 32 } } 63 | } 64 | } 65 | source: "videos" 66 | target: "channels" 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/proto/EdgeSet.md: -------------------------------------------------------------------------------- 1 | # tfgnn.proto.EdgeSet 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | The schema shared by a set of edges that connect the same pair of node sets. 10 | 11 | 12 | 13 | For detailed documentation, see the comments in the `graph_schema.proto` file. 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 25 | 28 | 29 | 32 | 35 | 36 | 39 | 42 | 43 | 46 | 49 | 50 | 53 | 56 | 57 | 60 | 63 | 64 |
23 | context 24 | 26 | repeated string context 27 |
30 | description 31 | 33 | string description 34 |
37 | features 38 | 40 | repeated FeaturesEntry features 41 |
44 | metadata 45 | 47 | Metadata metadata 48 |
51 | source 52 | 54 | string source 55 |
58 | target 59 | 61 | string target 62 |
65 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DropoutFeatures.md: -------------------------------------------------------------------------------- 1 | # contrastive_losses.DropoutFeatures 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Base class for graph corruptor. 10 | 11 | Inherits From: [`Corruptor`](../contrastive_losses/Corruptor.md) 12 | 13 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 30 | 33 | 34 | 37 | 40 | 41 | 44 | 48 | 49 | 52 | 55 | 56 |
28 | corruption_spec 29 | 31 | A spec for corruption application. 32 |
35 | corruption_fn 36 | 38 | Corruption function. 39 |
42 | default 43 | 45 | Global application default of the corruptor. This is only used 46 | when corruption_spec is None. 47 |
50 | **kwargs 51 | 53 | Additional keyword arguments. 54 |
57 | -------------------------------------------------------------------------------- /tensorflow_gnn/keras/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The tfgnn.keras package.""" 16 | 17 | from tensorflow_gnn.keras import builders 18 | from tensorflow_gnn.keras import initializers 19 | from tensorflow_gnn.keras import keras_tensors # To register the types. pylint: disable=unused-import 20 | from tensorflow_gnn.keras import layers # Exposed as submodule. pylint: disable=unused-import 21 | from tensorflow_gnn.utils import api_utils 22 | 23 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py. 24 | # Please see there for instructions how to reflect API changes. 25 | # LINT.IfChange 26 | 27 | ConvGNNBuilder = builders.ConvGNNBuilder 28 | clone_initializer = initializers.clone_initializer 29 | 30 | # Remove all names added by module imports, unless explicitly allowed here. 
31 | api_utils.remove_submodules_except(__name__, [ 32 | "layers", 33 | ]) 34 | # LINT.ThenChange(../api_def/tfgnn-symbols.txt) 35 | -------------------------------------------------------------------------------- /examples/sampler/creditcard/graph_schema.pbtxt: -------------------------------------------------------------------------------- 1 | # proto-file: tensorflow_gnn/proto/graph_schema.proto 2 | # proto-message: GraphSchema 3 | # Example graph of transactions, credit cards, customer. 4 | 5 | node_sets { 6 | key: "customer" 7 | value { 8 | features { 9 | key: "name" 10 | value: { 11 | description: "Name" 12 | dtype: DT_STRING 13 | } 14 | } 15 | features { 16 | key: "address" 17 | value: { 18 | description: "address" 19 | dtype: DT_STRING 20 | } 21 | } 22 | features { 23 | key: "zipcode" 24 | value: { 25 | description: "Zipcode" 26 | dtype: DT_INT64 27 | } 28 | } 29 | features { 30 | key: "score" 31 | value: { 32 | description: "Credit score" 33 | dtype: DT_FLOAT 34 | } 35 | } 36 | metadata { 37 | filename: "customer.csv" 38 | } 39 | } 40 | } 41 | 42 | node_sets { 43 | key: "creditcard" 44 | value { 45 | metadata { 46 | filename: "creditcard.csv" 47 | } 48 | features { 49 | key: "number" 50 | value: { 51 | description: "Credit card number" 52 | dtype: DT_INT64 53 | } 54 | } 55 | features { 56 | key: "issuer" 57 | value: { 58 | description: "Credit card issuer institution" 59 | dtype: DT_STRING 60 | } 61 | } 62 | } 63 | } 64 | 65 | edge_sets { 66 | key: "owns_card" 67 | value { 68 | description: "Owns and uses the credit card." 69 | source: "customer" 70 | target: "creditcard" 71 | metadata { 72 | filename: "owns_card.csv" 73 | } 74 | } 75 | } -------------------------------------------------------------------------------- /tensorflow_gnn/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Experimental (unstable) parts of the public interface of TensorFlow GNN. 16 | 17 | A symbol `foo` exposed here is available to library users as 18 | 19 | ``` 20 | import tensorflow_gnn as tfgnn 21 | 22 | tfgnn.experimental.foo() 23 | ``` 24 | 25 | This is the preferred way to expose individual functions on track to inclusion 26 | into the stable public interface of TensorFlow GNN. 27 | 28 | Beyond these symbols, there are also experimental sub-libraries that 29 | need to be imported separately (`from tensorflow_gnn.experimental import foo`). 30 | That is for special cases only. 31 | """ 32 | 33 | from tensorflow_gnn.graph import readout 34 | from tensorflow_gnn.graph import tensor_utils 35 | 36 | context_readout_into_feature = readout.context_readout_into_feature 37 | segment_random_index_shuffle = tensor_utils.segment_random_index_shuffle 38 | 39 | del readout 40 | del tensor_utils 41 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/Corruptor.md: -------------------------------------------------------------------------------- 1 | # contrastive_losses.Corruptor 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Base class for graph corruptor. 
10 | 11 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 32 | 35 | 36 | 39 | 42 | 43 | 46 | 50 | 51 | 54 | 57 | 58 |
30 | corruption_spec 31 | 33 | A spec for corruption application. 34 |
37 | corruption_fn 38 | 40 | Corruption function. 41 |
44 | default 45 | 47 | Global application default of the corruptor. This is only used 48 | when corruption_spec is None. 49 |
52 | **kwargs 53 | 55 | Additional keyword arguments. 56 |
59 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/SizeConstraints.md: -------------------------------------------------------------------------------- 1 | # tfgnn.SizeConstraints 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Constraints on the number of entities in the graph. 10 | 11 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 36 | 39 | 40 | 43 | 46 | 47 | 50 | 53 | 54 | 57 | 60 | 61 |
34 | total_num_components 35 | 37 | A namedtuple alias for field number 0 38 |
41 | total_num_nodes 42 | 44 | A namedtuple alias for field number 1 45 |
48 | total_num_edges 49 | 51 | A namedtuple alias for field number 2 52 |
55 | min_nodes_per_component 56 | 58 | A namedtuple alias for field number 3 59 |
62 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/multi_head_attention/README.md: -------------------------------------------------------------------------------- 1 | # Transformer-style multi-head attention 2 | 3 | ## Overview 4 | 5 | This code implements transformer-style (dot-product) multi-head attention, 6 | with different variants and optional attention score leaks. 7 | 8 | Some publications in the GNN context that either use this multi-head attention 9 | as a component ([1]&[2]) or a baseline ([3]) of their method: 10 | 11 | * [1] Vijay Prakash Dwivedi, Xavier Bresson: ["A Generalization of Transformer 12 | Networks to Graphs"](https://arxiv.org/abs/2012.09699), 2021. 13 | (We only implement their attention, not their position encoding.) 14 | * [2] Dongkwan Kim, Alice Oh: ["How to Find Your Friendly Neighborhood: Graph 15 | Attention Design with Self-Supervision"](https://arxiv.org/abs/2204.04879) 16 | , 2022. (They call it "DP" attention.) 17 | * [3] Shaked Brody, Uri Alon, Eran Yahav: ["How Attentive are Graph Attention 18 | Networks?"](https://arxiv.org/abs/2105.14491), 2021. 19 | (They discuss "DPGAT" as a baseline in the appendix, citing further uses. 20 | Their main contribution "GATv2" is implemented [elsewhere](../gat_v2) 21 | in TF-GNN.) 22 | 23 | ## Usage 24 | 25 | TensorFlow programs can import and use this model as described in its 26 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention.md). 27 | 28 | ## API stability 29 | 30 | The API of this model may change between OSS library versions. 31 | 32 | TF-GNN's [Model Template "Albis"](../mt_albis/README.md) offers a stable and 33 | simplified API for a subset of this model's configuration options. 
34 | 35 | -------------------------------------------------------------------------------- /testdata/homogeneous/tastelike.recordio.ascii: -------------------------------------------------------------------------------- 1 | features { 2 | feature { 3 | key: "#source" 4 | value: { bytes_list: { value: [ "amanatsu" ] } } 5 | } 6 | feature { 7 | key: "#target" 8 | value: { bytes_list: { value: [ "daidai" ] } } 9 | } 10 | feature { 11 | key: "weights" 12 | value: { float_list: { value: [ 0.1 ] } } 13 | } 14 | } 15 | 16 | features { 17 | feature { 18 | key: "#source" 19 | value: { bytes_list: { value: [ "amanatsu" ] } } 20 | } 21 | feature { 22 | key: "#target" 23 | value: { bytes_list: { value: [ "lumia" ] } } 24 | } 25 | feature { 26 | key: "weights" 27 | value: { float_list: { value: [ 0.2 ] } } 28 | } 29 | } 30 | 31 | features { 32 | feature { 33 | key: "#source" 34 | value: { bytes_list: { value: [ "kiyomi" ] } } 35 | } 36 | feature { 37 | key: "#target" 38 | value: { bytes_list: { value: [ "komikan" ] } } 39 | } 40 | feature { 41 | key: "weights" 42 | value: { float_list: { value: [ 0.3 ] } } 43 | } 44 | } 45 | 46 | features { 47 | feature { 48 | key: "#source" 49 | value: { bytes_list: { value: [ "mandora" ] } } 50 | } 51 | feature { 52 | key: "#target" 53 | value: { bytes_list: { value: [ "komikan" ] } } 54 | } 55 | feature { 56 | key: "weights" 57 | value: { float_list: { value: [ 0.4 ] } } 58 | } 59 | } 60 | 61 | features { 62 | feature { 63 | key: "#source" 64 | value: { bytes_list: { value: [ "mandora" ] } } 65 | } 66 | feature { 67 | key: "#target" 68 | value: { bytes_list: { value: [ "tangelo" ] } } 69 | } 70 | feature { 71 | key: "weights" 72 | value: { float_list: { value: [ 0.5 ] } } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /testdata/heterogeneous/transactions.csv: -------------------------------------------------------------------------------- 1 | id,merchant,amount 2 | 
5338667949,AcmeCorp,1156.80 3 | 2485246926,Plumbing Co,206.92 4 | 7574807079,HW Store Big,7.97 5 | 1935037613,Cons. Depot,2.38 6 | 3719329604,The Oil Tanker,172.99 7 | 1015902584,Light Inc,924.72 8 | 6360467569,Roofing Etc.,6.27 9 | 3672845476,Dig-a-Hole Co,91.68 10 | 7577999036,Smile Networks,257.00 11 | 4077264491,Jewel Corporation,1162.72 12 | 6932577020,Melon Systems,32.17 13 | 2952830085,Explority,1.02 14 | 4045329101,Apachicorp,1085.31 15 | 2187034790,Gnomelectrics,44.41 16 | 4071088591,Amazystems,19.96 17 | 3301719092,Herocast,329.38 18 | 3630070609,Gorillalife,689.12 19 | 7220792354,Desertcoms,2.19 20 | 9418989087,Yellow Sports,41.13 21 | 4751453767,Vertex Coms,30.05 22 | 5488583952,Ice Records,67.77 23 | 7738736966,Happindustries,37.15 24 | 7118748138,Antelligence,1.03 25 | 2507142984,Cycloration,299.41 26 | 6967991082,Herolutions,16.24 27 | 9099370006,Signcloud,9.63 28 | 2674535393,Diamondlight,18.76 29 | 8806410342,Accentmart,21.09 30 | 8140734876,Padlock Productions,59.83 31 | 2432513727,Buck Intelligence,132.79 32 | 5420798237,Root Aviation,648.87 33 | 6060907986,Joytechs,161.45 34 | 3315349239,Driftonics,3.66 35 | 5221186582,Honeydustries,21.87 36 | 8597703614,Tigressystems,204.30 37 | 2277618242,Cubebank,3.32 38 | 8179757481,Voidtechs,9.99 39 | 2561247267,Houndbooks,11.17 40 | 5675097657,Slick Co.,41.98 41 | 8790765633,Oak Softwares,213.11 42 | 4719730577,Hero Motors,7.86 43 | 9232165517,Elecoms,124.75 44 | 8334230482,Microlutions,12.26 45 | 5930862377,Tundracoustics,0.64 46 | 8603978315,Tempestechnologies,842.09 47 | 5339238148,Stormex,2.61 48 | 2346560845,Heartair,2.34 49 | 4415620884,Gnomebank,12.14 -------------------------------------------------------------------------------- /tensorflow_gnn/models/hgt/hparams_vizier_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The TensorFlow GNN Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for hparams_vizier.""" 16 | 17 | from absl.testing import absltest 18 | from tensorflow_gnn.models.hgt import hparams_vizier 19 | 20 | from vizier.service import pyvizier as vz 21 | 22 | 23 | class HparamsVizierTest(absltest.TestCase): 24 | 25 | def test_regularization(self): 26 | problem = vz.ProblemStatement() 27 | hparams_vizier.add_params_regularization( 28 | problem.search_space, prefix="foo." 29 | ) 30 | self.assertCountEqual( 31 | [p.name for p in problem.search_space.parameters], ["foo.dropout_rate"] 32 | ) 33 | 34 | def test_hgt_attention(self): 35 | problem = vz.ProblemStatement() 36 | hparams_vizier.add_params_attention( 37 | problem.search_space, prefix="foo." 38 | ) 39 | self.assertCountEqual( 40 | [p.name for p in problem.search_space.parameters], ["foo.num_heads"] 41 | ) 42 | 43 | 44 | if __name__ == "__main__": 45 | absltest.main() 46 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/graph_tensor_to_values.md: -------------------------------------------------------------------------------- 1 | # tfgnn.graph_tensor_to_values 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Convert an eager `GraphTensor` to a mapping of mappings of PODTs. 
10 | 11 | 16 | 17 | 18 | 19 | 20 | 21 | This is used for pretty-printing. Convert your graph tensor with this and run 22 | the result through `pprint.pprint()` or `pprint.pformat()` for display of its 23 | contents. 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 34 | 37 | 38 |
32 | graph 33 | 35 | An eager GraphTensor instance to be pprinted. 36 |
39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 50 | 51 | 52 |
47 | A dict of plain-old data types that can be run through pprint.pprint() or 48 | a JSON conversion library. 49 |
53 | 54 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/hgt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Heterogeneous Graph Transformers. 16 | 17 | Users of TF-GNN can use this model by importing it next to the core library as 18 | 19 | ```python 20 | import tensorflow_gnn as tfgnn 21 | from tensorflow_gnn.models import hgt 22 | ``` 23 | """ 24 | from tensorflow_gnn.models.hgt import config_dict 25 | from tensorflow_gnn.models.hgt import layers 26 | from tensorflow_gnn.utils import api_utils 27 | 28 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py. 29 | # Please see there for instructions how to reflect API changes. 30 | # LINT.IfChange 31 | 32 | HGTGraphUpdate = layers.HGTGraphUpdate 33 | graph_update_get_config_dict = config_dict.graph_update_get_config_dict 34 | graph_update_from_config_dict = config_dict.graph_update_from_config_dict 35 | 36 | # Remove all names added by module imports, unless explicitly allowed here. 
37 | api_utils.remove_submodules_except(__name__, []) 38 | # LINT.ThenChange(../../api_def/hgt-symbols.txt) 39 | -------------------------------------------------------------------------------- /testdata/homogeneous/tastelike.sstable.ascii: -------------------------------------------------------------------------------- 1 | 0 2 | features { 3 | feature { 4 | key: "#source" 5 | value: { bytes_list: { value: [ "amanatsu" ] } } 6 | } 7 | feature { 8 | key: "#target" 9 | value: { bytes_list: { value: [ "daidai" ] } } 10 | } 11 | feature { 12 | key: "weights" 13 | value: { float_list: { value: [ 0.1 ] } } 14 | } 15 | } 16 | 17 | 1 18 | features { 19 | feature { 20 | key: "#source" 21 | value: { bytes_list: { value: [ "amanatsu" ] } } 22 | } 23 | feature { 24 | key: "#target" 25 | value: { bytes_list: { value: [ "lumia" ] } } 26 | } 27 | feature { 28 | key: "weights" 29 | value: { float_list: { value: [ 0.2 ] } } 30 | } 31 | } 32 | 33 | 2 34 | features { 35 | feature { 36 | key: "#source" 37 | value: { bytes_list: { value: [ "kiyomi" ] } } 38 | } 39 | feature { 40 | key: "#target" 41 | value: { bytes_list: { value: [ "komikan" ] } } 42 | } 43 | feature { 44 | key: "weights" 45 | value: { float_list: { value: [ 0.3 ] } } 46 | } 47 | } 48 | 49 | 3 50 | features { 51 | feature { 52 | key: "#source" 53 | value: { bytes_list: { value: [ "mandora" ] } } 54 | } 55 | feature { 56 | key: "#target" 57 | value: { bytes_list: { value: [ "komikan" ] } } 58 | } 59 | feature { 60 | key: "weights" 61 | value: { float_list: { value: [ 0.4 ] } } 62 | } 63 | } 64 | 65 | 4 66 | features { 67 | feature { 68 | key: "#source" 69 | value: { bytes_list: { value: [ "mandora" ] } } 70 | } 71 | feature { 72 | key: "#target" 73 | value: { bytes_list: { value: [ "tangelo" ] } } 74 | } 75 | feature { 76 | key: "weights" 77 | value: { float_list: { value: [ 0.5 ] } } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- 
/tensorflow_gnn/experimental/sampler/proto/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The protocol message (protobuf) types defined by TF-GNN Sampler.""" 16 | 17 | from tensorflow_gnn.experimental.sampler.proto import eval_dag_pb2 18 | from tensorflow_gnn.utils import api_utils 19 | 20 | 21 | # Program computation DAG, its stages and layers. 22 | Program = eval_dag_pb2.Program 23 | EvalDAG = eval_dag_pb2.EvalDAG 24 | Stage = eval_dag_pb2.Stage 25 | Layer = eval_dag_pb2.Layer 26 | 27 | # Specifications of input/output values of layers. 28 | ValueSpec = eval_dag_pb2.ValueSpec 29 | TensorSpec = eval_dag_pb2.TensorSpec 30 | RaggedTensorSpec = eval_dag_pb2.RaggedTensorSpec 31 | FlattenedSpec = eval_dag_pb2.FlattenedSpec 32 | 33 | # Layer configs. 34 | EdgeSamplingConfig = eval_dag_pb2.EdgeSamplingConfig 35 | IOFeatures = eval_dag_pb2.IOFeatures 36 | 37 | 38 | # Remove all names added by module imports, unless explicitly allowed here. 
39 | api_utils.remove_submodules_except(__name__, []) 40 | # LINT.ThenChange(../api_def/sampler-symbols.txt) 41 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/multi_head_attention.md: -------------------------------------------------------------------------------- 1 | # Module: multi_head_attention 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Transformer-style multi-head attention. 10 | 11 | Users of TF-GNN can use this model by importing it next to the core library as 12 | 13 | ```python 14 | import tensorflow_gnn as tfgnn 15 | from tensorflow_gnn.models import multi_head_attention 16 | ``` 17 | 18 | ## Classes 19 | 20 | [`class MultiHeadAttentionConv`](./multi_head_attention/MultiHeadAttentionConv.md): 21 | Transformer-style (dot-product) multi-head attention on GNNs. 22 | 23 | ## Functions 24 | 25 | [`MultiHeadAttentionEdgePool(...)`](./multi_head_attention/MultiHeadAttentionEdgePool.md): 26 | Returns a layer for pooling edges with Transformer-style Multi-Head Attention. 27 | 28 | [`MultiHeadAttentionHomGraphUpdate(...)`](./multi_head_attention/MultiHeadAttentionHomGraphUpdate.md): 29 | Returns a GraphUpdate layer with a transformer-style multihead attention. 30 | 31 | [`MultiHeadAttentionMPNNGraphUpdate(...)`](./multi_head_attention/MultiHeadAttentionMPNNGraphUpdate.md): 32 | Returns a GraphUpdate layer for message passing with MultiHeadAttention pooling. 33 | 34 | [`graph_update_from_config_dict(...)`](./multi_head_attention/graph_update_from_config_dict.md): 35 | Returns a MultiHeadAttentionMPNNGraphUpdate initialized from `cfg`. 36 | 37 | [`graph_update_get_config_dict(...)`](./multi_head_attention/graph_update_get_config_dict.md): 38 | Returns ConfigDict for graph_update_from_config_dict() with defaults.
39 | -------------------------------------------------------------------------------- /tensorflow_gnn/graph/graph_tensor_pprint_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for pretty-printing.""" 16 | 17 | import pprint 18 | 19 | import tensorflow as tf 20 | from tensorflow_gnn.graph import graph_tensor_pprint as gpp 21 | from tensorflow_gnn.graph import graph_tensor_random as gr 22 | from tensorflow_gnn.graph import schema_utils as su 23 | import tensorflow_gnn.proto.graph_schema_pb2 as schema_pb2 24 | from tensorflow_gnn.utils import test_utils 25 | 26 | 27 | class TestConvertForPprint(tf.test.TestCase): 28 | 29 | def test_graph_tensor_to_values(self): 30 | schema = test_utils.get_proto_resource( 31 | 'testdata/feature_repr.pbtxt', schema_pb2.GraphSchema()) 32 | spec = su.create_graph_spec_from_schema_pb(schema) 33 | graph = gr.random_graph_tensor(spec) 34 | values = gpp.graph_tensor_to_values(graph) 35 | text = pprint.pformat(values) 36 | # This just ensures there is no error in the generation.
37 | self.assertIsInstance(text, str) 38 | 39 | 40 | if __name__ == '__main__': 41 | tf.test.main() 42 | -------------------------------------------------------------------------------- /testdata/heterogeneous/paid_with.csv: -------------------------------------------------------------------------------- 1 | target,source,retries 2 | 5338667949,1238474857489384,0 3 | 2485246926,12968701241275060,0 4 | 7574807079,12441028369470600,0 5 | 1935037613,12968701241275060,0 6 | 3719329604,14990890937985390,0 7 | 1015902584,14912408563871390,2 8 | 6360467569,18362223127059380,3 9 | 3672845476,3549061668422198,0 10 | 7577999036,11385846637304370,0 11 | 4077264491,14844931107602160,0 12 | 6932577020,14844931107602160,0 13 | 2952830085,11739198589848540,0 14 | 4045329101,13916484476264770,0 15 | 2187034790,8878522895102384,0 16 | 4071088591,14990890937985390,1 17 | 3301719092,11290312140467510,0 18 | 3630070609,13019350102369400,0 19 | 7220792354,18569067217418250,0 20 | 9418989087,3549061668422198,7 21 | 4751453767,4541017563963442,4 22 | 5488583952,1238474857489384,0 23 | 7738736966,11739198589848540,0 24 | 7118748138,18526138896540830,0 25 | 2507142984,17035680063294790,4 26 | 6967991082,16073125141142750,0 27 | 9099370006,17035680063294790,0 28 | 2674535393,9991040399813057,0 29 | 8806410342,8889177882781586,1 30 | 8140734876,11739198589848540,0 31 | 2432513727,11470379189154620,0 32 | 5420798237,11584989140147230,0 33 | 6060907986,14990890937985390,0 34 | 3315349239,18569067217418250,0 35 | 5221186582,11739198589848540,0 36 | 8597703614,11584989140147230,0 37 | 2277618242,11771673810809530,1 38 | 8179757481,8878522895102384,0 39 | 2561247267,16827485386298040,0 40 | 5675097657,18526138896540830,0 41 | 8790765633,16283233487191600,1 42 | 4719730577,14844931107602160,0 43 | 9232165517,16011471358128450,0 44 | 8334230482,15054318664602640,1 45 | 5930862377,12948957000457930,0 46 | 8603978315,14453480592564160,5 47 | 5339238148,11739198589848540,0 48 | 
2346560845,11584989140147230,1 49 | 4415620884,12441028369470600,0 -------------------------------------------------------------------------------- /tensorflow_gnn/models/graph_sage/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """GraphSAGE. 16 | 17 | Users of TF-GNN can use this model by importing it next to the core library as 18 | 19 | ```python 20 | import tensorflow_gnn as tfgnn 21 | from tensorflow_gnn.models import graph_sage 22 | ``` 23 | """ 24 | 25 | from tensorflow_gnn.models.graph_sage import layers 26 | from tensorflow_gnn.utils import api_utils 27 | 28 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py. 29 | # Please see there for instructions how to reflect API changes. 30 | # LINT.IfChange 31 | 32 | GCNGraphSAGENodeSetUpdate = layers.GCNGraphSAGENodeSetUpdate 33 | GraphSAGEAggregatorConv = layers.GraphSAGEAggregatorConv 34 | GraphSAGEPoolingConv = layers.GraphSAGEPoolingConv 35 | GraphSAGENextState = layers.GraphSAGENextState 36 | GraphSAGEGraphUpdate = layers.GraphSAGEGraphUpdate 37 | 38 | # Remove all names added by module imports, unless explicitly allowed here. 
39 | api_utils.remove_submodules_except(__name__, []) 40 | # LINT.ThenChange(../../api_def/graph_sage-symbols.txt) 41 | -------------------------------------------------------------------------------- /examples/sampler/creditcard/creditcard.csv: -------------------------------------------------------------------------------- 1 | id,number,issuer 2 | 11238474857489380,11238474857489380,BofBC 3 | 14216252633958570,14216252633958570,HeyBank 4 | 14541017563963440,14541017563963440,BellsGarbo 5 | 13549061668422190,13549061668422190,BellsGarbo 6 | 12948957000457930,12948957000457930,GDBank 7 | 11163838768727470,11163838768727470,BellsGarbo 8 | 11191576325053580,11191576325053580,BofBC 9 | 11290312140467510,11290312140467510,GDBank 10 | 11385846637304370,11385846637304370,BellsGarbo 11 | 11470379189154620,11470379189154620,HeyBank 12 | 11584989140147230,11584989140147230,BofBC 13 | 11739198589848540,11739198589848540,GDBank 14 | 11771673810809530,11771673810809530,BellsGarbo 15 | 12441028369470600,12441028369470600,BofBC 16 | 12968701241275060,12968701241275060,BellsGarbo 17 | 12982257258547830,12982257258547830,BellsGarbo 18 | 13019350102369400,13019350102369400,BellsGarbo 19 | 13916484476264770,13916484476264770,GDBank 20 | 14453480592564160,14453480592564160,BofBC 21 | 14844931107602160,14844931107602160,BellsGarbo 22 | 14912408563871390,14912408563871390,BofBC 23 | 14990890937985390,14990890937985390,BellsGarbo 24 | 15054318664602640,15054318664602640,HeyBank 25 | 16011471358128450,16011471358128450,BellsGarbo 26 | 16073125141142750,16073125141142750,GDBank 27 | 16283233487191600,16283233487191600,BellsGarbo 28 | 16827485386298040,16827485386298040,BellsGarbo 29 | 17035680063294790,17035680063294790,BellsGarbo 30 | 17396883707513070,17396883707513070,BellsGarbo 31 | 17861046738135650,17861046738135650,HeyBank 32 | 18362223127059380,18362223127059380,GDBank 33 | 18526138896540830,18526138896540830,GDBank 34 | 18569067217418250,18569067217418250,GDBank 35 | 
18878522895102380,18878522895102380,HeyBank 36 | 18889177882781580,18889177882781580,BofBC 37 | 19991040399813050,19991040399813050,BofBC -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/NextStateFromConcat.md: -------------------------------------------------------------------------------- 1 | # tfgnn.keras.layers.NextStateFromConcat 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Computes a new state by concatenating inputs and applying a Keras Layer. 10 | 11 | 16 | 17 | 18 | 19 | 20 | 21 | This layer flattens all inputs into a list (forgetting their origin), 22 | concatenates them and sends them through a user-supplied feed-forward network. 23 | 24 | This layer can be restored from config by `tf.keras.models.load_model()` when 25 | saved as part of a Keras model using `save_format="tf"`. 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 36 | 40 | 41 |
34 | transformation 35 | 37 | Required. A Keras Layer to transform the combined inputs 38 | into the new state. 39 |
42 | 43 | 44 | 45 | 46 | 47 | 48 | 51 | 52 | 53 |
49 | The result of transformation. 50 |
54 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/assert_constraints.md: -------------------------------------------------------------------------------- 1 | # tfgnn.assert_constraints 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Validate the shape constraints of a graph's features at runtime. 10 | 11 | 16 | 17 | 18 | 19 | 20 | 21 | This code returns a TensorFlow op with debugging assertions that ensure the 22 | parsed data has valid shape constraints for a graph. This can be instantiated 23 | in your TensorFlow graph while debugging if you believe that your data may be 24 | incorrectly shaped, or simply applied to a manually produced dataset to ensure 25 | that those constraints have been applied correctly. 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 36 | 39 | 40 |
34 | graph 35 | 37 | An instance of a GraphTensor. 38 |
41 | 42 | 43 | 44 | 45 | 46 | 47 | 50 | 51 | 52 |
48 | A list of check operations. 49 |
53 | 54 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/gat_v2/hparams_vizier_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for hparams_vizier.""" 16 | 17 | from absl.testing import absltest 18 | from tensorflow_gnn.models.gat_v2 import hparams_vizier 19 | 20 | from vizier.service import pyvizier as vz 21 | 22 | 23 | class HparamsVizierTest(absltest.TestCase): 24 | 25 | def test_regularization(self): 26 | problem = vz.ProblemStatement() 27 | hparams_vizier.add_params_regularization( 28 | problem.search_space, prefix="foo.") 29 | self.assertCountEqual([p.name for p in problem.search_space.parameters], [ 30 | "foo.state_dropout_rate", "foo.edge_dropout_rate", 31 | "foo.l2_regularization" 32 | ]) 33 | 34 | def test_attention(self): 35 | problem = vz.ProblemStatement() 36 | hparams_vizier.add_params_attention(problem.search_space, 37 | prefix="foo.") 38 | self.assertCountEqual( 39 | [p.name for p in problem.search_space.parameters], 40 | ["foo.num_heads"]) 41 | 42 | 43 | if __name__ == "__main__": 44 | absltest.main() 45 | -------------------------------------------------------------------------------- 
/package/move_generated_files.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Moves the bazel generated files needed for packaging the wheel to the source 17 | # tree. 18 | 19 | function _is_windows() { 20 | [[ "$(uname -s | tr 'A-Z' 'a-z')" =~ (cygwin|mingw32|mingw64|msys)_nt* ]] 21 | } 22 | 23 | function tfgnn::move_generated_files() { 24 | if _is_windows; then 25 | # See https://github.com/bazelbuild/bazel/issues/6761 for bazel-bin. 26 | GENFILES=${BUILD_WORKSPACE_DIRECTORY}/bazel-genfiles 27 | if [[ ! -d ${GENFILES} ]]; then 28 | GENFILES=${BUILD_WORKSPACE_DIRECTORY}/bazel-bin 29 | fi 30 | else 31 | # If run by "bazel run", $(pwd) is the .runfiles dir that contains all the 32 | # data dependencies. 
33 | GENFILES=$(pwd) 34 | fi 35 | 36 | FILES=" 37 | tensorflow_gnn/experimental/sampler/proto/eval_dag_pb2.py 38 | tensorflow_gnn/proto/graph_schema_pb2.py 39 | tensorflow_gnn/proto/examples_pb2.py 40 | tensorflow_gnn/sampler/sampling_spec_pb2.py 41 | tensorflow_gnn/tools/sampled_stats_pb2.py 42 | " 43 | for FILE in ${FILES}; do 44 | cp -f ${GENFILES}/${FILE} ${BUILD_WORKSPACE_DIRECTORY}/${FILE} 45 | done 46 | } 47 | 48 | tfgnn::move_generated_files 49 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/ShuffleFeaturesGlobally.md: -------------------------------------------------------------------------------- 1 | # contrastive_losses.ShuffleFeaturesGlobally 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | A corruptor that shuffles features. 10 | 11 | Inherits From: [`Corruptor`](../contrastive_losses/Corruptor.md) 12 | 13 | 18 | 19 | 20 | 21 | NOTE: this function does not currently support TPUs. Consider using other 22 | corruptor functions if executing on TPUs. See b/269249455 for reference. 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 34 | 37 | 38 | 41 | 44 | 45 | 48 | 52 | 53 | 56 | 59 | 60 |
32 | corruption_spec 33 | 35 | A spec for corruption application. 36 |
39 | corruption_fn 40 | 42 | Corruption function. 43 |
46 | default 47 | 49 | Global application default of the corruptor. This is only used 50 | when corruption_spec is None. 51 |
54 | **kwargs 55 | 57 | Additional keyword arguments. 58 |
61 | -------------------------------------------------------------------------------- /testdata/heterogeneous/creditcard.csv: -------------------------------------------------------------------------------- 1 | id,number,issuer 2 | 11238474857489380,11238474857489380,BofBC 3 | 14216252633958570,14216252633958570,HeyBank 4 | 14541017563963440,14541017563963440,BellsGarbo 5 | 13549061668422190,13549061668422190,BellsGarbo 6 | 12948957000457930,12948957000457930,GDBank 7 | 11163838768727470,11163838768727470,BellsGarbo 8 | 11191576325053580,11191576325053580,BofBC 9 | 11290312140467510,11290312140467510,GDBank 10 | 11385846637304370,11385846637304370,BellsGarbo 11 | 11470379189154620,11470379189154620,HeyBank 12 | 11584989140147230,11584989140147230,BofBC 13 | 11739198589848540,11739198589848540,GDBank 14 | 11771673810809530,11771673810809530,BellsGarbo 15 | 12441028369470600,12441028369470600,BofBC 16 | 12968701241275060,12968701241275060,BellsGarbo 17 | 12982257258547830,12982257258547830,BellsGarbo 18 | 13019350102369400,13019350102369400,BellsGarbo 19 | 13916484476264770,13916484476264770,GDBank 20 | 14453480592564160,14453480592564160,BofBC 21 | 14844931107602160,14844931107602160,BellsGarbo 22 | 14912408563871390,14912408563871390,BofBC 23 | 14990890937985390,14990890937985390,BellsGarbo 24 | 15054318664602640,15054318664602640,HeyBank 25 | 16011471358128450,16011471358128450,BellsGarbo 26 | 16073125141142750,16073125141142750,GDBank 27 | 16283233487191600,16283233487191600,BellsGarbo 28 | 16827485386298040,16827485386298040,BellsGarbo 29 | 17035680063294790,17035680063294790,BellsGarbo 30 | 17396883707513070,17396883707513070,BellsGarbo 31 | 17861046738135650,17861046738135650,HeyBank 32 | 18362223127059380,18362223127059380,GDBank 33 | 18526138896540830,18526138896540830,GDBank 34 | 18569067217418250,18569067217418250,GDBank 35 | 18878522895102380,18878522895102380,HeyBank 36 | 18889177882781580,18889177882781580,BofBC 37 | 
19991040399813050,19991040399813050,BofBC -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/iter_sets.md: -------------------------------------------------------------------------------- 1 | # tfgnn.iter_sets 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Utility function to iterate over all the sets present in a graph schema. 10 | 11 | 16 | 17 | 18 | 19 | This function iterates over the context set, each of the node sets, and 20 | finally each of the edge sets. 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 31 | 34 | 35 |
29 | schema 30 | 32 | An instance of a GraphSchema proto message. 33 |
36 | 37 | 38 | 39 | 40 | 41 | 42 | 49 | 50 | 51 |
43 | Triplets of (set-type, set-name, features) where 44 | 45 | * set-type: A type of set, which is either of "context", "nodes" or "edges". 46 | * set-name: A string, the name of the set. 47 | * features: A dict of feature-name to feature-value. 48 |
52 | 53 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/multi_head_attention/hparams_vizier_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for hparams_vizier.""" 16 | 17 | from absl.testing import absltest 18 | from tensorflow_gnn.models.multi_head_attention import hparams_vizier 19 | 20 | from vizier.service import pyvizier as vz 21 | 22 | 23 | class HparamsVizierTest(absltest.TestCase): 24 | 25 | def test_regularization(self): 26 | problem = vz.ProblemStatement() 27 | hparams_vizier.add_params_regularization( 28 | problem.search_space, prefix="foo.") 29 | self.assertCountEqual([p.name for p in problem.search_space.parameters], [ 30 | "foo.state_dropout_rate", "foo.edge_dropout_rate", 31 | "foo.l2_regularization" 32 | ]) 33 | 34 | def test_attention(self): 35 | problem = vz.ProblemStatement() 36 | hparams_vizier.add_params_attention(problem.search_space, 37 | prefix="foo.") 38 | self.assertCountEqual( 39 | [p.name for p in problem.search_space.parameters], 40 | ["foo.num_heads"]) 41 | 42 | 43 | if __name__ == "__main__": 44 | absltest.main() 45 | 
-------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/TightPadding.md: -------------------------------------------------------------------------------- 1 | # runner.TightPadding 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Calculates tight `SizeConstraints` for `GraphTensor` padding. 10 | 11 | Inherits From: [`GraphTensorPadding`](../runner/GraphTensorPadding.md) 12 | 13 | 20 | 21 | 22 | 23 | See: `tfgnn.find_tight_size_constraints`. 24 | 25 | ## Methods 26 | 27 |

get_filter_fn

28 | 29 | View 30 | source 31 | 32 | 37 | 38 |

get_size_constraints

39 | 40 | View 41 | source 42 | 43 | 48 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/PassthruSampleDatasetsProvider.md: -------------------------------------------------------------------------------- 1 | # runner.PassthruSampleDatasetsProvider 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Builds a sampled `tf.data.Dataset` from multiple pass thru datasets. 10 | 11 | Inherits From: [`DatasetProvider`](../runner/DatasetProvider.md) 12 | 13 | 26 | 27 | 28 | 29 | Passes any `principal_dataset` and `extra_datasets` thru: omitting any sharding. 30 | For detailed documentation, see the filename dataset provider complement: 31 | `SimpleSampleDatasetsProvider`. 32 | 33 | ## Methods 34 | 35 |

get_dataset

36 | 37 | View 38 | source 39 | 40 | 45 | 46 | Gets a sampled `tf.data.Dataset` omitting any input context. 47 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/add_self_loops.md: -------------------------------------------------------------------------------- 1 | # tfgnn.add_self_loops 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Adds self-loops for `edge_set_name` EVEN if they already exist. 10 | 11 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 29 | 32 | 33 | 36 | 40 | 41 |
27 | graph 28 | 30 | A scalar GraphTensor. 31 |
34 | edge_set_name 35 | 37 | An edge set in graph that has the same node set as source 38 | and target. 39 |
42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 56 | 57 | 58 |
50 | A GraphTensor with self-loops added. A self-loop is added at each node, 51 | even if some or all of these nodes already have a loop. All feature tensors 52 | of the edge set are extended to cover the newly added edges with values 53 | that are all zeros (for numeric features), false (for boolean features), or 54 | empty (for string features), respectively. 55 |
59 | -------------------------------------------------------------------------------- /tensorflow_gnn/converters/triples_test.py: -------------------------------------------------------------------------------- 1 | """Tests for triples.""" 2 | 3 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | import tensorflow as tf 18 | import tensorflow_gnn as tfgnn 19 | from tensorflow_gnn.converters import triples 20 | 21 | 22 | class TriplesTest(tf.test.TestCase): 23 | 24 | def setUp(self): 25 | super().setUp() 26 | tfgnn.enable_graph_tensor_validation_at_runtime() 27 | 28 | def test_triples_from_array(self): 29 | spos = [["gnns", "are", "awesome"], ["tfgnn", "is", "awesome"], 30 | ["nns", "are", "awesome"]] 31 | 32 | gt = triples.triple_to_graphtensor(triples=spos) 33 | is_source = int(gt.edge_sets["is"].adjacency.source) 34 | is_target = int(gt.edge_sets["is"].adjacency.target) 35 | 36 | self.assertLen(gt.node_sets, 1) 37 | self.assertLen(gt.edge_sets, 2) 38 | 39 | self.assertEqual(gt.node_sets["nodes"].sizes, 4) 40 | self.assertEqual(gt.node_sets["nodes"].features["#id"].shape, 4) 41 | 42 | self.assertEqual(gt.node_sets["nodes"].features["#id"][is_source], "tfgnn") 43 | self.assertEqual(gt.node_sets["nodes"].features["#id"][is_target], 44 | "awesome") 45 | 46 | 47 | if __name__ == 
"__main__": 48 | tf.test.main() 49 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/mt_albis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF-GNN's Model Template "Albis". 16 | 17 | The TF-GNN Model Template "Albis" provides a small selection of field-tested 18 | GNN architectures through the `mt_albis.MtAlbisGraphUpdate` class. 19 | 20 | Users of TF-GNN can use it by importing it next to the core library as 21 | 22 | ```python 23 | import tensorflow_gnn as tfgnn 24 | from tensorflow_gnn.models import mt_albis 25 | ``` 26 | """ 27 | 28 | from tensorflow_gnn.models.mt_albis import config_dict 29 | from tensorflow_gnn.models.mt_albis import layers 30 | from tensorflow_gnn.utils import api_utils 31 | 32 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py. 33 | # Please see there for instructions how to reflect API changes. 
34 | # LINT.IfChange 35 | 36 | MtAlbisGraphUpdate = layers.MtAlbisGraphUpdate 37 | graph_update_get_config_dict = config_dict.graph_update_get_config_dict 38 | graph_update_from_config_dict = config_dict.graph_update_from_config_dict 39 | 40 | # Remove all names added by module imports, unless explicitly allowed here. 41 | api_utils.remove_submodules_except(__name__, []) 42 | # LINT.ThenChange(../../api_def/mt_albis-symbols.txt) 43 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/vanilla_mpnn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF-GNN's "Vanilla MPNN" model. 16 | 17 | Users of TF-GNN can use this model by importing it next to the core library as 18 | 19 | ```python 20 | import tensorflow_gnn as tfgnn 21 | from tensorflow_gnn.models import vanilla_mpnn 22 | ``` 23 | 24 | This model ties together some simple convolutions from the TF-GNN core library, 25 | so it does not define any Conv class by itself.
26 | """ 27 | 28 | from tensorflow_gnn.models.vanilla_mpnn import config_dict 29 | from tensorflow_gnn.models.vanilla_mpnn import layers 30 | from tensorflow_gnn.utils import api_utils 31 | 32 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py. 33 | # Please see there for instructions how to reflect API changes. 34 | # LINT.IfChange 35 | 36 | VanillaMPNNGraphUpdate = layers.VanillaMPNNGraphUpdate 37 | graph_update_get_config_dict = config_dict.graph_update_get_config_dict 38 | graph_update_from_config_dict = config_dict.graph_update_from_config_dict 39 | 40 | # Remove all names added by module imports, unless explicitly allowed here. 41 | api_utils.remove_submodules_except(__name__, []) 42 | # LINT.ThenChange(../../api_def/vanilla_mpnn-symbols.txt) 43 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/gat_v2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Graph Attention Networks v2. 
16 | 17 | Users of TF-GNN can use this model by importing it next to the core library as 18 | 19 | ```python 20 | import tensorflow_gnn as tfgnn 21 | from tensorflow_gnn.models import gat_v2 22 | ``` 23 | """ 24 | 25 | from tensorflow_gnn.models.gat_v2 import config_dict 26 | from tensorflow_gnn.models.gat_v2 import layers 27 | from tensorflow_gnn.utils import api_utils 28 | 29 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py. 30 | # Please see there for instructions how to reflect API changes. 31 | # LINT.IfChange 32 | 33 | GATv2Conv = layers.GATv2Conv 34 | GATv2EdgePool = layers.GATv2EdgePool 35 | GATv2GraphUpdate = layers.GATv2GraphUpdate # Deprecated. 36 | GATv2HomGraphUpdate = layers.GATv2HomGraphUpdate 37 | GATv2MPNNGraphUpdate = layers.GATv2MPNNGraphUpdate 38 | graph_update_get_config_dict = config_dict.graph_update_get_config_dict 39 | graph_update_from_config_dict = config_dict.graph_update_from_config_dict 40 | 41 | # Remove all names added by module imports, unless explicitly allowed here. 42 | api_utils.remove_submodules_except(__name__, []) 43 | # LINT.ThenChange(../../api_def/gat_v2-symbols.txt) 44 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/FitOrSkipPadding.md: -------------------------------------------------------------------------------- 1 | # runner.FitOrSkipPadding 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Calculates fit or skip `SizeConstraints` for `GraphTensor` padding. 10 | 11 | Inherits From: [`GraphTensorPadding`](../runner/GraphTensorPadding.md) 12 | 13 | 22 | 23 | 24 | 25 | See: `tfgnn.learn_fit_or_skip_size_constraints`. 26 | 27 | ## Methods 28 | 29 |

get_filter_fn

30 | 31 | View 32 | source 33 | 34 | 39 | 40 |

get_size_constraints

41 | 42 | View 43 | source 44 | 45 | 50 | -------------------------------------------------------------------------------- /tensorflow_gnn/models/README.md: -------------------------------------------------------------------------------- 1 | # TF-GNN Models 2 | 3 | ## Introduction 4 | 5 | This directory contains a collection of GNN models implemented with the 6 | TF-GNN library. Some of them offer reusable pieces that can be imported 7 | _next to_ the core TF-GNN library, which effectively makes them little 8 | libraries of their own. 9 | 10 | ### Usage 11 | 12 | If, for example, the hypothetical FancyNet model offered a graph update layer, 13 | its use would look like 14 | 15 | ```python 16 | import tensorflow_gnn as tfgnn 17 | from tensorflow_gnn.models import fancynet 18 | 19 | graph = fancynet.FancyGraphUpdate(units=42, fanciness=0.99, ...)(graph) 20 | ``` 21 | 22 | ...and require a separate dependency for `fancynet` in a BUILD file. 23 | 24 | ### API stability 25 | 26 | Each model comes with a README file that describes its intended level of 27 | API stability. Not all models are covered by the [semantic 28 | versioning](https://semver.org/spec/v2.0.0.html) of the TF-GNN package. 29 | 30 | ## List of Models 31 | 32 | 33 | 34 | * [Contrastive Losses](contrastive_losses/README.md): Contrastive losses for 35 | self-supervised learning. 36 | * [GATv2](gat_v2/README.md): Graph Attention Networks v2 37 | (Brody&al, 2021). 38 | * [GCN](gcn/README.md): Graph Convolutional Networks 39 | (Kipf&Welling, 2016), for homogeneous graphs only. 40 | * [GraphSAGE](graph_sage/README.md) (Hamilton&al., 2017). 41 | * [MtAlbis](mt_albis/README.md): Model Template "Albis" for easy configuration 42 | of a few field-tested GNN architectures, generalizing VanillaMPNN. 43 | * [MultiHeadAttention](multi_head_attention/README.md): Transformer-style 44 | multi-head attention on graph (Dwivedi&Bresson, 2021). 
45 | * [VanillaMPNN](vanilla_mpnn/README.md): TF-GNN's classic baseline model, 46 | based on (Gilmer&al., 2016). 47 | 48 | Unsure? For generic node prediction tasks on relational data, we recommend 49 | to start with MtAlbis. -------------------------------------------------------------------------------- /tensorflow_gnn/api_def/runner-symbols.txt: -------------------------------------------------------------------------------- 1 | runner.ContextLabelFn 2 | runner.DatasetProvider 3 | runner.DotProductLinkPrediction 4 | runner.FitOrSkipPadding 5 | runner.GraphBinaryClassification 6 | runner.GraphMeanAbsoluteError 7 | runner.GraphMeanAbsolutePercentageError 8 | runner.GraphMeanSquaredError 9 | runner.GraphMeanSquaredLogScaledError 10 | runner.GraphMeanSquaredLogarithmicError 11 | runner.GraphMulticlassClassification 12 | runner.GraphTensorPadding 13 | runner.GraphTensorProcessorFn 14 | runner.HadamardProductLinkPrediction 15 | runner.IntegratedGradientsExporter 16 | runner.KerasModelExporter 17 | runner.KerasTrainer 18 | runner.KerasTrainerCheckpointOptions 19 | runner.KerasTrainerOptions 20 | runner.Loss 21 | runner.Losses 22 | runner.Metric 23 | runner.Metrics 24 | runner.ModelExporter 25 | runner.NodeMeanAbsoluteError 26 | runner.NodeMeanAbsolutePercentageError 27 | runner.NodeMeanSquaredError 28 | runner.NodeMeanSquaredLogScaledError 29 | runner.NodeMeanSquaredLogarithmicError 30 | runner.ParameterServerStrategy 31 | runner.PassthruDatasetProvider 32 | runner.PassthruSampleDatasetsProvider 33 | runner.Predictions 34 | runner.NodeBinaryClassification 35 | runner.RootNodeBinaryClassification 36 | runner.RootNodeLabelFn 37 | runner.RootNodeMeanAbsoluteError 38 | runner.RootNodeMeanAbsolutePercentageError 39 | runner.RootNodeMeanSquaredError 40 | runner.RootNodeMeanSquaredLogScaledError 41 | runner.RootNodeMeanSquaredLogarithmicError 42 | runner.NodeMulticlassClassification 43 | runner.RootNodeMulticlassClassification 44 | runner.RunResult 45 | 
runner.SampleTFRecordDatasetsProvider 46 | runner.export_model 47 | runner.SimpleDatasetProvider 48 | runner.SimpleSampleDatasetsProvider 49 | runner.SubmoduleExporter 50 | runner.TFDataServiceConfig 51 | runner.TFRecordDatasetProvider 52 | runner.TPUStrategy 53 | runner.Task 54 | runner.TightPadding 55 | runner.Trainer 56 | runner.incrementing_model_dir 57 | runner.integrated_gradients 58 | runner.one_node_per_component 59 | runner.run 60 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/runner/TFDataServiceConfig.md: -------------------------------------------------------------------------------- 1 | # runner.TFDataServiceConfig 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Provides tf.data service related configuration options. 10 | 11 | 18 | 19 | 20 | 21 | tf.data service has flexible data visitation guarantees; its impact on your 22 | training pipelines must be determined empirically. Check out the tf.data service internals 23 | and operation details at 24 | https://www.tensorflow.org/api_docs/python/tf/data/experimental/service. 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 36 | 39 | 40 | 43 | 46 | 47 | 50 | 53 | 54 |
34 | tf_data_service_address 35 | 37 | Dataclass field 38 |
41 | tf_data_service_job_name 42 | 44 | Dataclass field 45 |
48 | tf_data_service_mode 49 | 51 | Dataclass field 52 |
55 | 56 | ## Methods 57 | 58 |

__eq__

59 | 60 | 65 | 66 | Return self==value. 67 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/combine_values.md: -------------------------------------------------------------------------------- 1 | # tfgnn.combine_values 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Combines a list of tensors into one (by concatenation or otherwise). 10 | 11 | 16 | 17 | 18 | 19 | This is a convenience wrapper around standard TensorFlow operations, to 20 | provide standard names for common types of combining. 21 | 22 | 23 | 24 | 25 | 26 | 27 | 39 | 40 |
inputs a list of Tensors or 28 | RaggedTensors, with shapes and types that are compatible for the selected 29 | combine_type.
30 | combine_type one of the 31 | following string values, to select the method for combining the inputs: 32 | 33 | * "sum": The input tensors are added. Their dtypes and shapes must 34 | match. 35 | * "concat": The input tensors are concatenated along the last axis. 36 | Their dtypes and shapes must match, except for the number of elements 37 | along the last axis. 38 |
41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 52 | 53 | 54 |
50 | A tensor with the combined value of the inputs. 51 |
55 | 56 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/coherence.md: -------------------------------------------------------------------------------- 1 | # contrastive_losses.coherence 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Coherence metric implementation. 10 | 11 | 20 | 21 | 22 | 23 | Coherence measures how easy it is to construct a linear classifier on top of 24 | data without knowing downstream labels. Refer to 25 | https://arxiv.org/abs/2305.16562 for more details. 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 37 | 40 | 41 | 44 | 47 | 48 | 51 | 55 | 56 |
35 | representations 36 | 38 | Input representations, a rank-2 tensor. 39 |
42 | sigma 43 | 45 | Unused. 46 |
49 | u 50 | 52 | An optional tensor with left singular vectors of representations. If not 53 | present, computes a SVD of representations. 54 |
57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 67 | 68 | 69 |
65 | Metric value as scalar tf.Tensor. 66 |
70 | -------------------------------------------------------------------------------- /tensorflow_gnn/tools/validate_graph_schema.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Validate a schema's features and shapes. 16 | 17 | This script ensures that a schema is valid, has correct shapes, and isn't 18 | clobbering reserved feature names.
19 | """ 20 | 21 | import sys 22 | 23 | from absl import app 24 | from absl import flags 25 | from absl import logging 26 | import tensorflow_gnn as tfgnn 27 | 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | 32 | def define_flags(): 33 | """Define program flags.""" 34 | 35 | flags.DEFINE_string("graph_schema", None, 36 | ("A filename to a text-formatted schema proto describing " 37 | "the available graph features.")) 38 | 39 | flags.mark_flag_as_required("graph_schema") 40 | 41 | 42 | def app_main(unused_argv): 43 | """App runner main function.""" 44 | schema = tfgnn.read_schema(FLAGS.graph_schema) 45 | try: 46 | warnings = tfgnn.validate_schema(schema) 47 | for warning in warnings: 48 | logging.warning(warning) 49 | logging.info("Schema validated correctly.") 50 | except tfgnn.ValidationError as exc: 51 | logging.error("Schema validation error: %s", exc) 52 | sys.exit(1) 53 | 54 | 55 | def main(): 56 | define_flags() 57 | app.run(app_main) 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /examples/sampler/creditcard/customer.csv: -------------------------------------------------------------------------------- 1 | id,name,address,zipcode,score 2 | 1876448,Ji Grindstaff,"343 Third St. Houston, TX 77016",77016,0.5174975367 3 | 1372437,Augustina Uren,"9940 Prairie Ave. Deer Park, NY 11729",11729,0.9055414601 4 | 1368305,Yolonda Nave,"478 Grove Drive Hicksville, NY 11801",11801,0.2407929709 5 | 1974494,Adriana Mcburney,"909 Vermont St. Livonia, MI 48150",48150,0.3432604494 6 | 1257724,Dione Reeb,"7758 West Devon St. Algonquin, IL 60102",60102,0.5395817998 7 | 1758057,Geri Bones,"720 Marsh Road Tucker, GA 30084",30084,0.9270786453 8 | 1531660,Krystal Pablo,"227 Bishop St. Bemidji, MN 56601",56601,0.1871476497 9 | 1489311,Tonia Behnke,"51 Ramblewood St. Hobart, IN 46342",46342,0.1171643897 10 | 1407706,Fidel Speers,"714 Hilldale Ave. 
Cumming, GA 30040",30040,0.2359154945 11 | 196838,Necole Hunkins,"14 South Grove St. Coatesville, PA 19320",19320,0.06037660472 12 | 1195675,Tona Crays,"136 Somerset Dr. Chester, PA 19013",19013,0.6864072435 13 | 1659366,Mary Zeitz,"6 Thatcher St. Hillsboro, OR 97124",97124,0.643440095 14 | 1499004,Rudolph Sinquefield,"673 Elmwood Drive Fairburn, GA 30213",30213,0.9164648402 15 | 1344333,Marhta Rodrigue,"7284 Young Lane Upland, CA 91784",91784,0.6359473777 16 | 1443888,Ricardo Bundrick,"19 Rock Creek St. Sulphur, LA 70663",70663,0.4666323814 17 | 1108778,Myron Barrick,"450 Boston Street Solon, OH 44139",44139,0.7162866466 18 | 175583,Nichol Poulton,"34 Old 8th Drive Carmel, NY 10512",10512,0.2922429056 19 | 1251872,Boyd Padilla,"7351 North Spring Ave. Oshkosh, WI 54901",54901,0.4839780673 20 | 1493851,Orpha Yokoyama,"7805 Newport Street Sylvania, OH 43560",43560,0.7630316714 21 | 1599418,Ulysses Harps,"9138 Gates Street Braintree, MA 02184",1284,0.3491518542 22 | 1768701,Sulema Aguero,"42 South Wayne St. Hollywood, FL 33020",33020,0.7137308073 23 | 1549489,Meredith Warman,"98 Corona Court Morton Grove, IL 60053",60053,0.5057767373 24 | 1879799,Vonda Borth,"97 Tunnel Dr. Elyria, OH 44035",44035,0.6586585941 25 | 125454,Candida Uvalle,"608 Heritage Street Harrison Township, MI 48045",48045,0.6696324006 -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/satisfies_size_constraints.md: -------------------------------------------------------------------------------- 1 | # tfgnn.satisfies_size_constraints 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Returns whether the input `graph_tensor` satisfies `total_sizes`. 10 | 11 | 18 | 19 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 40 | 43 | 44 | 47 | 50 | 51 |
38 | graph_tensor 39 | 41 | a graph tensor to check against target total sizes. 42 |
45 | total_sizes 46 | 48 | target total sizes for each graph piece. 49 |
52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 63 | 64 | 65 |
60 | A scalar boolean tensor equal to True if the graph_tensor satisfies 61 | total_sizes, and False if not. 62 |
66 | 67 | -------------------------------------------------------------------------------- /tensorflow_gnn/docs/api_docs/python/tfgnn/iter_features.md: -------------------------------------------------------------------------------- 1 | # tfgnn.iter_features 2 | 3 | 4 | 5 | 6 | View source 7 | on GitHub 8 | 9 | Utility function to iterate over the features of a graph schema. 10 | 11 | 16 | 17 | 18 | 19 | This function iterates over all the feature values of each of the context set, 20 | each of the node sets, and each of the edge sets. 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 31 | 34 | 35 |
29 | schema 30 | 32 | An instance of a GraphSchema proto message. 33 |
36 | 37 | 38 | 39 | 40 | 41 | 42 | 50 | 51 |
43 | Tuples of (set-type, set-name, feature-name, feature-value) where 44 | 45 | * set-type: A type of set, which is one of "context", "nodes" or "edges". 46 | * set-name: A string, the name of the set. 47 | * feature-name: A string, the name of the feature in the set. 48 | * feature-value: A potentially ragged tensor (either a tf.Tensor 49 | or a tf.RaggedTensor).
52 | 53 | -------------------------------------------------------------------------------- /testdata/heterogeneous/customer.csv: -------------------------------------------------------------------------------- 1 | id,name,address,zipcode,score 2 | 1876448,Ji Grindstaff,"343 Third St. Houston, TX 77016",77016,0.5174975367 3 | 1372437,Augustina Uren,"9940 Prairie Ave. Deer Park, NY 11729",11729,0.9055414601 4 | 1368305,Yolonda Nave,"478 Grove Drive Hicksville, NY 11801",11801,0.2407929709 5 | 1974494,Adriana Mcburney,"909 Vermont St. Livonia, MI 48150",48150,0.3432604494 6 | 1257724,Dione Reeb,"7758 West Devon St. Algonquin, IL 60102",60102,0.5395817998 7 | 1758057,Geri Bones,"720 Marsh Road Tucker, GA 30084",30084,0.9270786453 8 | 1531660,Krystal Pablo,"227 Bishop St. Bemidji, MN 56601",56601,0.1871476497 9 | 1489311,Tonia Behnke,"51 Ramblewood St. Hobart, IN 46342",46342,0.1171643897 10 | 1407706,Fidel Speers,"714 Hilldale Ave. Cumming, GA 30040",30040,0.2359154945 11 | 196838,Necole Hunkins,"14 South Grove St. Coatesville, PA 19320",19320,0.06037660472 12 | 1195675,Tona Crays,"136 Somerset Dr. Chester, PA 19013",19013,0.6864072435 13 | 1659366,Mary Zeitz,"6 Thatcher St. Hillsboro, OR 97124",97124,0.643440095 14 | 1499004,Rudolph Sinquefield,"673 Elmwood Drive Fairburn, GA 30213",30213,0.9164648402 15 | 1344333,Marhta Rodrigue,"7284 Young Lane Upland, CA 91784",91784,0.6359473777 16 | 1443888,Ricardo Bundrick,"19 Rock Creek St. Sulphur, LA 70663",70663,0.4666323814 17 | 1108778,Myron Barrick,"450 Boston Street Solon, OH 44139",44139,0.7162866466 18 | 175583,Nichol Poulton,"34 Old 8th Drive Carmel, NY 10512",10512,0.2922429056 19 | 1251872,Boyd Padilla,"7351 North Spring Ave. Oshkosh, WI 54901",54901,0.4839780673 20 | 1493851,Orpha Yokoyama,"7805 Newport Street Sylvania, OH 43560",43560,0.7630316714 21 | 1599418,Ulysses Harps,"9138 Gates Street Braintree, MA 02184",1284,0.3491518542 22 | 1768701,Sulema Aguero,"42 South Wayne St. 
Hollywood, FL 33020",33020,0.7137308073 23 | 1549489,Meredith Warman,"98 Corona Court Morton Grove, IL 60053",60053,0.5057767373 24 | 1879799,Vonda Borth,"97 Tunnel Dr. Elyria, OH 44035",44035,0.6586585941 25 | 125454,Candida Uvalle,"608 Heritage Street Harrison Township, MI 48045",48045,0.6696324006 -------------------------------------------------------------------------------- /tensorflow_gnn/models/multi_head_attention/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Transformer-style multi-head attention. 16 | 17 | Users of TF-GNN can use this model by importing it next to the core library as 18 | 19 | ```python 20 | import tensorflow_gnn as tfgnn 21 | from tensorflow_gnn.models import multi_head_attention 22 | ``` 23 | """ 24 | 25 | from tensorflow_gnn.models.multi_head_attention import config_dict 26 | from tensorflow_gnn.models.multi_head_attention import layers 27 | from tensorflow_gnn.utils import api_utils 28 | 29 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py. 30 | # Please see there for instructions how to reflect API changes. 
31 | # LINT.IfChange 32 | 33 | MultiHeadAttentionConv = layers.MultiHeadAttentionConv 34 | MultiHeadAttentionEdgePool = layers.MultiHeadAttentionEdgePool 35 | MultiHeadAttentionHomGraphUpdate = layers.MultiHeadAttentionHomGraphUpdate 36 | MultiHeadAttentionMPNNGraphUpdate = layers.MultiHeadAttentionMPNNGraphUpdate 37 | graph_update_get_config_dict = config_dict.graph_update_get_config_dict 38 | graph_update_from_config_dict = config_dict.graph_update_from_config_dict 39 | 40 | # Remove all names added by module imports, unless explicitly allowed here. 41 | api_utils.remove_submodules_except(__name__, []) 42 | # LINT.ThenChange(../../api_def/multi_head_attention-symbols.txt) 43 | --------------------------------------------------------------------------------