├── tensorflow_gnn
├── data
│ ├── __init__.py
│ └── BUILD
├── graph
│ ├── __init__.py
│ ├── dict_utils.py
│ ├── dict_utils_test.py
│ └── graph_tensor_pprint_test.py
├── tools
│ ├── __init__.py
│ ├── BUILD
│ ├── generate_training_data_test.py
│ └── validate_graph_schema.py
├── converters
│ ├── __init__.py
│ ├── ogb
│ │ └── __init__.py
│ └── triples_test.py
├── utils
│ ├── __init__.py
│ └── BUILD
├── api_def
│ ├── gcn-symbols.txt
│ ├── hgt-symbols.txt
│ ├── mt_albis-symbols.txt
│ ├── vanilla_mpnn-symbols.txt
│ ├── graph_sage-symbols.txt
│ ├── gat_v2-symbols.txt
│ ├── multi_head_attention-symbols.txt
│ ├── contrastive_losses-symbols.txt
│ ├── sampler-symbols.txt
│ └── runner-symbols.txt
├── docs
│ ├── guide
│ │ └── images
│ │ │ ├── homogeneous.png
│ │ │ └── heterogeneous.png
│ └── api_docs
│ │ ├── python
│ │ ├── models
│ │ │ ├── gcn
│ │ │ │ └── all_symbols.md
│ │ │ ├── mt_albis
│ │ │ │ ├── all_symbols.md
│ │ │ │ └── graph_update_get_config_dict.md
│ │ │ ├── vanilla_mpnn
│ │ │ │ ├── all_symbols.md
│ │ │ │ └── graph_update_get_config_dict.md
│ │ │ ├── gat_v2
│ │ │ │ ├── graph_update_get_config_dict.md
│ │ │ │ └── all_symbols.md
│ │ │ ├── multi_head_attention
│ │ │ │ ├── graph_update_get_config_dict.md
│ │ │ │ └── all_symbols.md
│ │ │ ├── contrastive_losses
│ │ │ │ ├── DeepGraphInfomaxLogits.md
│ │ │ │ ├── TripletEmbeddingSquaredDistances.md
│ │ │ │ ├── DropoutFeatures.md
│ │ │ │ ├── Corruptor.md
│ │ │ │ ├── ShuffleFeaturesGlobally.md
│ │ │ │ └── coherence.md
│ │ │ ├── graph_sage
│ │ │ │ └── all_symbols.md
│ │ │ ├── gcn.md
│ │ │ ├── vanilla_mpnn.md
│ │ │ ├── mt_albis.md
│ │ │ ├── graph_sage.md
│ │ │ ├── gat_v2.md
│ │ │ └── multi_head_attention.md
│ │ ├── runner
│ │ │ ├── Loss.md
│ │ │ ├── Predictions.md
│ │ │ ├── Losses.md
│ │ │ ├── ContextLabelFn.md
│ │ │ ├── Metrics.md
│ │ │ ├── one_node_per_component.md
│ │ │ ├── RootNodeLabelFn.md
│ │ │ ├── GraphTensorProcessorFn.md
│ │ │ ├── DatasetProvider.md
│ │ │ ├── GraphTensorPadding.md
│ │ │ ├── ModelExporter.md
│ │ │ ├── PassthruDatasetProvider.md
│ │ │ ├── incrementing_model_dir.md
│ │ │ ├── TightPadding.md
│ │ │ ├── PassthruSampleDatasetsProvider.md
│ │ │ ├── FitOrSkipPadding.md
│ │ │ └── TFDataServiceConfig.md
│ │ └── tfgnn
│ │ │ ├── Field.md
│ │ │ ├── FieldSpec.md
│ │ │ ├── Fields.md
│ │ │ ├── IncidentNodeOrContextTag.md
│ │ │ ├── FieldsSpec.md
│ │ │ ├── FieldOrFields.md
│ │ │ ├── reverse_tag.md
│ │ │ ├── enable_graph_tensor_validation.md
│ │ │ ├── is_graph_tensor.md
│ │ │ ├── enable_graph_tensor_validation_at_runtime.md
│ │ │ ├── get_registered_reduce_operation_names.md
│ │ │ ├── is_dense_tensor.md
│ │ │ ├── disable_graph_tensor_validation_at_runtime.md
│ │ │ ├── is_ragged_tensor.md
│ │ │ ├── check_scalar_graph_tensor.md
│ │ │ ├── ValidationError.md
│ │ │ ├── keras.md
│ │ │ ├── check_homogeneous_graph_tensor.md
│ │ │ ├── keras
│ │ │ └── layers
│ │ │ │ ├── AddSelfLoops.md
│ │ │ │ ├── ParseExample.md
│ │ │ │ ├── ParseSingleExample.md
│ │ │ │ ├── PadToTotalSizes.md
│ │ │ │ ├── SingleInputNextState.md
│ │ │ │ └── NextStateFromConcat.md
│ │ │ ├── disable_graph_tensor_validation.md
│ │ │ ├── proto
│ │ │ ├── Metadata
│ │ │ │ └── KeyValue.md
│ │ │ ├── BigQuery
│ │ │ │ └── TableSpec.md
│ │ │ ├── OriginInfo.md
│ │ │ ├── Context.md
│ │ │ ├── NodeSet.md
│ │ │ ├── Metadata.md
│ │ │ ├── Feature.md
│ │ │ ├── GraphSchema.md
│ │ │ └── EdgeSet.md
│ │ │ ├── experimental.md
│ │ │ ├── softmax_edges_per_node.md
│ │ │ ├── sampler
│ │ │ ├── SamplingSpec.md
│ │ │ └── SamplingOp.md
│ │ │ ├── sampler.md
│ │ │ ├── write_schema.md
│ │ │ ├── FeatureDefaultValues.md
│ │ │ ├── parse_schema.md
│ │ │ ├── read_schema.md
│ │ │ ├── SizeConstraints.md
│ │ │ ├── graph_tensor_to_values.md
│ │ │ ├── assert_constraints.md
│ │ │ ├── iter_sets.md
│ │ │ ├── add_self_loops.md
│ │ │ ├── combine_values.md
│ │ │ ├── satisfies_size_constraints.md
│ │ │ └── iter_features.md
│ │ └── README.md
├── sampler
│ ├── __init__.py
│ └── BUILD
├── experimental
│ ├── sampler
│ │ ├── BUILD
│ │ └── proto
│ │ │ ├── BUILD
│ │ │ └── __init__.py
│ ├── BUILD
│ └── __init__.py
├── models
│ ├── hgt
│ │ ├── README.md
│ │ ├── hparams_vizier_test.py
│ │ └── __init__.py
│ ├── gat_v2
│ │ ├── README.md
│ │ ├── hparams_vizier_test.py
│ │ └── __init__.py
│ ├── graph_sage
│ │ ├── README.md
│ │ ├── BUILD
│ │ └── __init__.py
│ ├── gcn
│ │ ├── README.md
│ │ ├── BUILD
│ │ └── __init__.py
│ ├── contrastive_losses
│ │ └── README.md
│ ├── vanilla_mpnn
│ │ ├── hparams_vizier_test.py
│ │ └── __init__.py
│ ├── multi_head_attention
│ │ ├── README.md
│ │ ├── hparams_vizier_test.py
│ │ └── __init__.py
│ ├── mt_albis
│ │ └── __init__.py
│ └── README.md
├── runner
│ ├── trainers
│ │ └── BUILD
│ ├── utils
│ │ ├── saved_model_test.sh
│ │ └── model_dir.py
│ └── input
│ │ └── BUILD
├── proto
│ ├── examples.proto
│ └── BUILD
└── keras
│ └── __init__.py
├── .gitignore
├── testdata
├── heterogeneous
│ ├── invalid_customer.csv
│ ├── one_customer.csv
│ ├── two_customers.csv
│ ├── BUILD
│ ├── owns_card.csv
│ ├── transactions.csv
│ ├── paid_with.csv
│ ├── creditcard.csv
│ └── customer.csv
├── homogeneous
│ ├── one_seed.csv
│ ├── two_seeds.csv
│ ├── tastelike.csv
│ ├── fruits.csv
│ ├── BUILD
│ ├── citrus.pbtxt
│ ├── tastelike.recordio.ascii
│ └── tastelike.sstable.ascii
├── node_vs_edge
│ ├── node_set_one.csv
│ ├── node_set_two.csv
│ ├── edge_set_two_to_two.csv
│ ├── edge_set_one_to_two.csv
│ ├── spec.pbtxt
│ ├── BUILD
│ └── schema.pbtxt
├── README.md
├── BUILD
└── feature_repr.pbtxt
├── requirements-dev.txt
├── MANIFEST.in
├── AUTHORS
├── kokoro
└── github
│ └── ubuntu
│ └── cpu
│ ├── continuous.cfg
│ ├── presubmit.cfg
│ ├── newest_stable
│ ├── presubmit.cfg
│ └── continuous.cfg
│ └── oldest
│ ├── continuous.cfg
│ └── presubmit.cfg
├── examples
├── sampler
│ ├── creditcard
│ │ ├── sampling_spec.pbtxt
│ │ ├── owns_card.csv
│ │ ├── graph_schema.pbtxt
│ │ ├── creditcard.csv
│ │ └── customer.csv
│ └── mag
│ │ └── sampling_spec.pbtxt
└── schemas
│ ├── graph_nets.pbtxt
│ ├── latent.pbtxt
│ └── mpnn.pbtxt
├── package
├── BUILD
├── tfdep.bzl
└── move_generated_files.sh
└── Dockerfile
/tensorflow_gnn/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tensorflow_gnn/graph/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tensorflow_gnn/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tensorflow_gnn/converters/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tensorflow_gnn/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tensorflow_gnn/converters/ogb/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | bazel-*
2 | build
3 | dist
4 | *.egg-info
5 |
--------------------------------------------------------------------------------
/testdata/heterogeneous/invalid_customer.csv:
--------------------------------------------------------------------------------
1 | #id
2 | 1
3 |
--------------------------------------------------------------------------------
/testdata/homogeneous/one_seed.csv:
--------------------------------------------------------------------------------
1 | #id
2 | amanatsu
3 |
--------------------------------------------------------------------------------
/testdata/heterogeneous/one_customer.csv:
--------------------------------------------------------------------------------
1 | #id
2 | 1876448
3 |
--------------------------------------------------------------------------------
/testdata/node_vs_edge/node_set_one.csv:
--------------------------------------------------------------------------------
1 | #id,name
2 | a,A
3 |
--------------------------------------------------------------------------------
/testdata/homogeneous/two_seeds.csv:
--------------------------------------------------------------------------------
1 | #id
2 | amanatsu
3 | daidai
4 |
--------------------------------------------------------------------------------
/testdata/node_vs_edge/node_set_two.csv:
--------------------------------------------------------------------------------
1 | #id,name
2 | b,B
3 | c,C
4 |
--------------------------------------------------------------------------------
/testdata/heterogeneous/two_customers.csv:
--------------------------------------------------------------------------------
1 | #id
2 | 1876448
3 | 1372437
4 |
--------------------------------------------------------------------------------
/tensorflow_gnn/api_def/gcn-symbols.txt:
--------------------------------------------------------------------------------
1 | gcn.GCNConv
2 | gcn.GCNHomGraphUpdate
3 |
--------------------------------------------------------------------------------
/testdata/node_vs_edge/edge_set_two_to_two.csv:
--------------------------------------------------------------------------------
1 | #source,#target,weight
2 | b,c,1.0
3 |
--------------------------------------------------------------------------------
/testdata/node_vs_edge/edge_set_one_to_two.csv:
--------------------------------------------------------------------------------
1 | #source,#target,weight
2 | a,b,1.0
3 | a,c,1.0
4 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | mock
2 | wheel
3 | # includes glibcxx older than 3.4.29
4 | ai-edge-litert-nightly
--------------------------------------------------------------------------------
/tensorflow_gnn/api_def/hgt-symbols.txt:
--------------------------------------------------------------------------------
1 | hgt.HGTGraphUpdate
2 | hgt.graph_update_from_config_dict
3 | hgt.graph_update_get_config_dict
4 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/guide/images/homogeneous.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/gnn/HEAD/tensorflow_gnn/docs/guide/images/homogeneous.png
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/guide/images/heterogeneous.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/gnn/HEAD/tensorflow_gnn/docs/guide/images/heterogeneous.png
--------------------------------------------------------------------------------
/tensorflow_gnn/api_def/mt_albis-symbols.txt:
--------------------------------------------------------------------------------
1 | mt_albis.MtAlbisGraphUpdate
2 | mt_albis.graph_update_from_config_dict
3 | mt_albis.graph_update_get_config_dict
4 |
--------------------------------------------------------------------------------
/tensorflow_gnn/api_def/vanilla_mpnn-symbols.txt:
--------------------------------------------------------------------------------
1 | vanilla_mpnn.VanillaMPNNGraphUpdate
2 | vanilla_mpnn.graph_update_from_config_dict
3 | vanilla_mpnn.graph_update_get_config_dict
4 |
--------------------------------------------------------------------------------
/testdata/homogeneous/tastelike.csv:
--------------------------------------------------------------------------------
1 | #source,#target,weight
2 | amanatsu,daidai,0.1
3 | amanatsu,lumia,0.2
4 | kiyomi,komikan,0.3
5 | mandora,komikan,0.4
6 | mandora,tangelo,0.5
7 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include ./WORKSPACE
2 | include ./package/move_generated_files.sh
3 | include ./package/tfdep.bzl
4 | include ./tensorflow_gnn/tensorflow_gnn.bzl
5 | global-include *.proto
6 | global-include BUILD
--------------------------------------------------------------------------------
/testdata/homogeneous/fruits.csv:
--------------------------------------------------------------------------------
1 | #id,name
2 | amanatsu,Amanatsu
3 | daidai,Daidai
4 | hassaku,Hassaku
5 | kiyomi,Kiyomi
6 | komikan,Komikan
7 | lumia,Lumia
8 | mandora,Mandora
9 | reikou,Reikou
10 | tangelo,Tangelo
11 |
--------------------------------------------------------------------------------
/tensorflow_gnn/api_def/graph_sage-symbols.txt:
--------------------------------------------------------------------------------
1 | graph_sage.GCNGraphSAGENodeSetUpdate
2 | graph_sage.GraphSAGEAggregatorConv
3 | graph_sage.GraphSAGEGraphUpdate
4 | graph_sage.GraphSAGENextState
5 | graph_sage.GraphSAGEPoolingConv
6 |
--------------------------------------------------------------------------------
/testdata/README.md:
--------------------------------------------------------------------------------
1 | # Test Data
2 |
3 | ## Synthetic Data
4 |
5 | All test data in the `/heterogeneous` and `/homogeneous` directories is
6 | synthetic and contains no private or confidential information.
7 |
8 |
--------------------------------------------------------------------------------
/tensorflow_gnn/api_def/gat_v2-symbols.txt:
--------------------------------------------------------------------------------
1 | gat_v2.GATv2Conv
2 | gat_v2.GATv2EdgePool
3 | gat_v2.GATv2GraphUpdate
4 | gat_v2.GATv2HomGraphUpdate
5 | gat_v2.GATv2MPNNGraphUpdate
6 | gat_v2.graph_update_from_config_dict
7 | gat_v2.graph_update_get_config_dict
8 |
--------------------------------------------------------------------------------
/testdata/node_vs_edge/spec.pbtxt:
--------------------------------------------------------------------------------
1 | seed_op {
2 | op_name: "seed"
3 | node_set_name: "node_set_one"
4 | }
5 | sampling_ops {
6 | op_name: "hop1"
7 | input_op_names: ["seed"]
8 | strategy: TOP_K
9 | sample_size: 2
10 | edge_set_name: "one_to_two"
11 | }
12 |
--------------------------------------------------------------------------------
/testdata/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
3 | package(
4 | default_applicable_licenses = ["//tensorflow_gnn:license"],
5 | default_visibility = ["//visibility:public"],
6 | )
7 |
8 | filegroup(
9 | name = "feature_repr",
10 | srcs = [
11 | "feature_repr.pbtxt",
12 | ],
13 | )
14 |
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of TensorFlow GNN authors for copyright purposes.
2 | #
3 | # This does not necessarily list everyone who has contributed code, since in
4 | # some cases, their employer may be the copyright holder. To see the full list
5 | # of contributors, see the revision history in source control.
6 |
7 | Google Inc.
8 |
--------------------------------------------------------------------------------
/kokoro/github/ubuntu/cpu/continuous.cfg:
--------------------------------------------------------------------------------
1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build.sh"
2 |
3 | env_vars: {
4 | key: "USE_BAZEL_VERSION"
5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works.
6 | }
7 |
8 | action {
9 | define_artifacts {
10 | regex: "**/sponge_log.log"
11 | regex: "**/sponge_log.xml"
12 | }
13 | }
--------------------------------------------------------------------------------
/kokoro/github/ubuntu/cpu/presubmit.cfg:
--------------------------------------------------------------------------------
1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build.sh"
2 |
3 | env_vars: {
4 | key: "USE_BAZEL_VERSION"
5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works.
6 | }
7 |
8 | action {
9 | define_artifacts {
10 | regex: "**/sponge_log.log"
11 | regex: "**/sponge_log.xml"
12 | }
13 | }
--------------------------------------------------------------------------------
/tensorflow_gnn/api_def/multi_head_attention-symbols.txt:
--------------------------------------------------------------------------------
1 | multi_head_attention.MultiHeadAttentionConv
2 | multi_head_attention.MultiHeadAttentionEdgePool
3 | multi_head_attention.MultiHeadAttentionHomGraphUpdate
4 | multi_head_attention.MultiHeadAttentionMPNNGraphUpdate
5 | multi_head_attention.graph_update_from_config_dict
6 | multi_head_attention.graph_update_get_config_dict
7 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/gcn/all_symbols.md:
--------------------------------------------------------------------------------
1 | # All symbols in TensorFlow GNN Models: GCN
2 |
3 |
4 |
5 | ## Primary symbols
6 |
7 | * gcn
8 | * gcn.GCNConv
9 | * gcn.GCNHomGraphUpdate
10 |
--------------------------------------------------------------------------------
/testdata/homogeneous/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
3 | package(default_visibility = ["//visibility:public"])
4 |
5 | filegroup(
6 | name = "homogeneous",
7 | srcs = [
8 | ":citrus.pbtxt",
9 | ":fruits.csv",
10 | ":one_seed.csv",
11 | ":sampler_golden.ascii",
12 | ":tastelike.csv",
13 | ":two_seeds.csv",
14 | ],
15 | )
16 |
--------------------------------------------------------------------------------
/examples/sampler/creditcard/sampling_spec.pbtxt:
--------------------------------------------------------------------------------
1 | # proto-file: tensorflow_gnn/sampler/sampling_spec.proto
2 | # proto-message: SamplingSpec
3 |
4 | seed_op <
5 | op_name: "seed"
6 | node_set_name: "customer"
7 | >
8 | sampling_ops <
9 | op_name: "seed->creditcard"
10 | input_op_names: "seed"
11 | edge_set_name: "owns_card"
12 | sample_size: 3
13 | strategy: RANDOM_UNIFORM
14 | >
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/Loss.md:
--------------------------------------------------------------------------------
1 | # runner.Loss
2 |
3 |
4 |
5 | This symbol is a **type alias**.
6 |
7 | #### Source:
8 |
9 |
10 | Loss = Callable[
11 | [tf.Tensor,
12 | tf.Tensor],
13 | tf.Tensor
14 | ]
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/Field.md:
--------------------------------------------------------------------------------
1 |
2 | # tfgnn.Field
3 |
4 |
5 | This symbol is a **type alias**.
6 |
7 |
8 |
9 | #### Source:
10 |
11 |
12 | Field = Union[
13 | tf.Tensor,
14 | tf.RaggedTensor
15 | ]
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/FieldSpec.md:
--------------------------------------------------------------------------------
1 |
2 | # tfgnn.FieldSpec
3 |
4 |
5 | This symbol is a **type alias**.
6 |
7 |
8 |
9 | #### Source:
10 |
11 |
12 | FieldSpec = Union[
13 | tf.TensorSpec,
14 | tf.RaggedTensorSpec
15 | ]
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/Fields.md:
--------------------------------------------------------------------------------
1 |
2 | # tfgnn.Fields
3 |
4 |
5 | This symbol is a **type alias**.
6 |
7 |
8 |
9 | #### Source:
10 |
11 |
12 | Fields = Mapping[
13 | str,
14 | tfgnn.Field
15 | ]
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/IncidentNodeOrContextTag.md:
--------------------------------------------------------------------------------
1 |
2 | # tfgnn.IncidentNodeOrContextTag
3 |
4 |
5 | This symbol is a **type alias**.
6 |
7 |
8 |
9 | #### Source:
10 |
11 |
12 | IncidentNodeOrContextTag = Union[
13 | int,
14 | str
15 | ]
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/testdata/node_vs_edge/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
3 | package(
4 | default_applicable_licenses = ["//tensorflow_gnn:license"],
5 | default_visibility = ["//visibility:public"],
6 | )
7 |
8 | filegroup(
9 | name = "node_vs_edge",
10 | srcs = [
11 | "edge_set_one_to_two.csv",
12 | "edge_set_two_to_two.csv",
13 | ":node_set_one.csv",
14 | ":node_set_two.csv",
15 | ],
16 | data = glob(["*.pbtxt"]),
17 | )
18 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/Predictions.md:
--------------------------------------------------------------------------------
1 | # runner.Predictions
2 |
3 |
4 |
5 | This symbol is a **type alias**.
6 |
7 | #### Source:
8 |
9 |
10 | Predictions = Union[
11 | tf.Tensor,
12 | tf.RaggedTensor,
13 | Mapping[str, Union[tf.Tensor, tf.RaggedTensor]]
14 | ]
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/FieldsSpec.md:
--------------------------------------------------------------------------------
1 |
2 | # tfgnn.FieldsSpec
3 |
4 |
5 | This symbol is a **type alias**.
6 |
7 |
8 |
9 | #### Source:
10 |
11 |
12 | FieldsSpec = Mapping[
13 | str,
14 | tfgnn.FieldSpec
15 | ]
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/Losses.md:
--------------------------------------------------------------------------------
1 | # runner.Losses
2 |
3 |
4 |
5 | This symbol is a **type alias**.
6 |
7 | #### Source:
8 |
9 |
10 | Losses = Union[
11 | runner.Loss,
12 | Mapping[str, runner.Loss]
13 | ]
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/package/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"]) # Apache 2.0
2 |
3 | sh_binary(
4 | name = "move_generated_files",
5 | srcs = ["move_generated_files.sh"],
6 | data = [
7 | "//tensorflow_gnn/experimental/sampler/proto:eval_dag_py_proto",
8 | "//tensorflow_gnn/proto:examples_py_proto",
9 | "//tensorflow_gnn/proto:graph_schema_py_proto",
10 | "//tensorflow_gnn/sampler:sampling_spec_py_proto",
11 | "//tensorflow_gnn/tools:sampled_stats_py_proto",
12 | ],
13 | )
14 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/FieldOrFields.md:
--------------------------------------------------------------------------------
1 |
2 | # tfgnn.FieldOrFields
3 |
4 |
5 | This symbol is a **type alias**.
6 |
7 |
8 |
9 | #### Source:
10 |
11 |
12 | FieldOrFields = Union[
13 | tf.Tensor,
14 | tf.RaggedTensor,
15 | tfgnn.Fields
16 | ]
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/testdata/heterogeneous/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
3 | package(default_visibility = ["//visibility:public"])
4 |
5 | filegroup(
6 | name = "heterogeneous",
7 | srcs = [
8 | ":creditcard.csv",
9 | ":customer.csv",
10 | ":graph.pbtxt",
11 | ":invalid_customer.csv",
12 | ":one_customer.csv",
13 | ":owns_card.csv",
14 | ":paid_with.csv",
15 | ":sampler_golden.ascii",
16 | ":transactions.csv",
17 | ":two_customers.csv",
18 | ],
19 | )
20 |
--------------------------------------------------------------------------------
/tensorflow_gnn/sampler/__init__.py:
--------------------------------------------------------------------------------
1 | """Public interface for GNN Sampler."""
2 | from tensorflow_gnn.sampler import sampling_spec_builder
3 | from tensorflow_gnn.sampler import sampling_spec_pb2
4 |
5 | SamplingOp = sampling_spec_pb2.SamplingOp
6 | SamplingSpec = sampling_spec_pb2.SamplingSpec
7 | SamplingSpecBuilder = sampling_spec_builder.SamplingSpecBuilder
8 | SamplingStrategy = sampling_spec_pb2.SamplingStrategy
9 | make_sampling_spec_tree = sampling_spec_builder.make_sampling_spec_tree
10 |
11 | del sampling_spec_pb2
12 | del sampling_spec_builder
13 |
--------------------------------------------------------------------------------
/tensorflow_gnn/experimental/sampler/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//visibility:public"])
2 |
3 | licenses(["notice"]) # Apache 2.0
4 |
5 | py_test(
6 | name = "core_test",
7 | srcs = ["core_test.py"],
8 | python_version = "PY3",
9 | deps = [],
10 | )
11 |
12 | py_test(
13 | name = "ext_ops_test",
14 | srcs = ["ext_ops_test.py"],
15 | python_version = "PY3",
16 | deps = [],
17 | )
18 |
19 | py_test(
20 | name = "eval_dag_test",
21 | srcs = ["eval_dag_test.py"],
22 | python_version = "PY3",
23 | deps = [],
24 | )
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/mt_albis/all_symbols.md:
--------------------------------------------------------------------------------
1 | # All symbols in TensorFlow GNN Models: MtAlbis
2 |
3 |
4 |
5 | ## Primary symbols
6 |
7 | * mt_albis
8 | * mt_albis.MtAlbisGraphUpdate
9 | * mt_albis.graph_update_from_config_dict
10 | * mt_albis.graph_update_get_config_dict
11 |
--------------------------------------------------------------------------------
/kokoro/github/ubuntu/cpu/newest_stable/presubmit.cfg:
--------------------------------------------------------------------------------
1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build_versioned.sh"
2 |
3 | env_vars: {
4 | key: "USE_BAZEL_VERSION"
5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works.
6 | }
7 | env_vars: {
8 | key: "PYTHON_VERSION"
9 | value: "3.11"
10 | }
11 | env_vars: {
12 | key: "TF_VERSION"
13 | value: "2.20.*"
14 | }
15 | env_vars: {
16 | key: "TF_USE_LEGACY_KERAS"
17 | value: "1"
18 | }
19 |
20 | action {
21 | define_artifacts {
22 | regex: "**/sponge_log.log"
23 | regex: "**/sponge_log.xml"
24 | }
25 | }
--------------------------------------------------------------------------------
/kokoro/github/ubuntu/cpu/newest_stable/continuous.cfg:
--------------------------------------------------------------------------------
1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build_versioned.sh"
2 |
3 | env_vars: {
4 | key: "USE_BAZEL_VERSION"
5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works.
6 | }
7 | env_vars: {
8 | key: "PYTHON_VERSION"
9 | value: "3.11"
10 | }
11 | env_vars: {
12 | key: "TF_VERSION"
13 | value: "2.20.*"
14 | }
15 | env_vars: {
16 | key: "TF_USE_LEGACY_KERAS"
17 | value: "1"
18 | }
19 |
20 | action {
21 | define_artifacts {
22 | regex: "**/sponge_log.log"
23 | regex: "**/sponge_log.xml"
24 | }
25 | }
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/reverse_tag.md:
--------------------------------------------------------------------------------
1 | # tfgnn.reverse_tag
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Flips tfgnn.SOURCE to tfgnn.TARGET and vice versa.
10 |
11 |
12 | tfgnn.reverse_tag(
13 | tag
14 | )
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/all_symbols.md:
--------------------------------------------------------------------------------
1 | # All symbols in TensorFlow GNN Models: VanillaMPNN
2 |
3 |
4 |
5 | ## Primary symbols
6 |
7 | * vanilla_mpnn
8 | * vanilla_mpnn.VanillaMPNNGraphUpdate
9 | * vanilla_mpnn.graph_update_from_config_dict
10 | * vanilla_mpnn.graph_update_get_config_dict
11 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/enable_graph_tensor_validation.md:
--------------------------------------------------------------------------------
1 | # tfgnn.enable_graph_tensor_validation
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Enables static checks of graph tensors.
10 |
11 |
12 | tfgnn.enable_graph_tensor_validation()
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/ContextLabelFn.md:
--------------------------------------------------------------------------------
1 | # runner.ContextLabelFn
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Reads out a `tfgnn.Field` from the `GraphTensor` context.
10 |
11 |
12 | runner.ContextLabelFn(
13 | feature_name: str, **kwargs
14 | )
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/Metrics.md:
--------------------------------------------------------------------------------
1 | # runner.Metrics
2 |
3 |
4 |
5 | This symbol is a **type alias**.
6 |
7 | #### Source:
8 |
9 |
10 | Metrics = Union[
11 | runner.Loss,
12 | Sequence[runner.Loss],
13 | Mapping[str, Union[runner.Loss, Sequence[runner.Loss]]]
14 | ]
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/kokoro/github/ubuntu/cpu/oldest/continuous.cfg:
--------------------------------------------------------------------------------
1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build_versioned.sh"
2 |
3 | env_vars: {
4 | key: "USE_BAZEL_VERSION"
5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works.
6 | }
7 | env_vars: {
8 | key: "PYTHON_VERSION"
9 | value: "3.9"
10 | }
11 | env_vars: {
12 | key: "TF_VERSION"
13 | value: "2.12.*"
14 | }
15 | env_vars: {
16 | key: "TF_USE_LEGACY_KERAS"
17 | value: "0"
18 | }
19 | env_vars: {
20 | key: "TAG_FILTERS"
21 | value: ",-tf_at_least_2_13"
22 | }
23 |
24 | action {
25 | define_artifacts {
26 | regex: "**/sponge_log.log"
27 | regex: "**/sponge_log.xml"
28 | }
29 | }
--------------------------------------------------------------------------------
/kokoro/github/ubuntu/cpu/oldest/presubmit.cfg:
--------------------------------------------------------------------------------
1 | build_file: "gnn/kokoro/github/ubuntu/cpu/build_versioned.sh"
2 |
3 | env_vars: {
4 | key: "USE_BAZEL_VERSION"
5 | value: "7.4.1" # TODO - b/390391579: Unpin once bazel 8 works.
6 | }
7 | env_vars: {
8 | key: "PYTHON_VERSION"
9 | value: "3.9"
10 | }
11 | env_vars: {
12 | key: "TF_VERSION"
13 | value: "2.12.*"
14 | }
15 | env_vars: {
16 | key: "TF_USE_LEGACY_KERAS"
17 | value: "0"
18 | }
19 | env_vars: {
20 | key: "TAG_FILTERS"
21 | value: ",-tf_at_least_2_13"
22 | }
23 |
24 | action {
25 | define_artifacts {
26 | regex: "**/sponge_log.log"
27 | regex: "**/sponge_log.xml"
28 | }
29 | }
--------------------------------------------------------------------------------
/tensorflow_gnn/models/hgt/README.md:
--------------------------------------------------------------------------------
1 | # Heterogeneous Graph Transformers
2 |
3 | ## Overview
4 |
5 | This code implements the Heterogeneous Graph Transformer (HGT) method, originally
6 | proposed in
7 |
8 | * Ziniu Hu, Yuxiao Dong, Kuansan Wang, Yizhou Sun:
9 | ["Heterogeneous Graph Transformer"](https://arxiv.org/abs/2003.01332), 2020.
10 |
11 |
12 | ## Usage
13 |
14 | TensorFlow programs can import and use this model as described in its
15 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/hgt.md).
16 |
17 | ## API stability
18 |
19 | The API of this model may change between OSS library versions.
20 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/is_graph_tensor.md:
--------------------------------------------------------------------------------
1 | # tfgnn.is_graph_tensor
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Returns whether `value` is a GraphTensor (possibly wrapped for Keras).
10 |
11 |
12 | tfgnn.is_graph_tensor(
13 | value: Any
14 | ) -> bool
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/enable_graph_tensor_validation_at_runtime.md:
--------------------------------------------------------------------------------
1 | # tfgnn.enable_graph_tensor_validation_at_runtime
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Enables both static and runtime checks of graph tensors.
10 |
11 |
12 | tfgnn.enable_graph_tensor_validation_at_runtime()
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/get_registered_reduce_operation_names.md:
--------------------------------------------------------------------------------
1 | # tfgnn.get_registered_reduce_operation_names
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Returns the registered list of supported reduce operation names.
10 |
11 |
12 | tfgnn.get_registered_reduce_operation_names() -> list[str]
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/is_dense_tensor.md:
--------------------------------------------------------------------------------
1 | # tfgnn.is_dense_tensor
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Returns whether a tensor (TF or Keras) is a dense Tensor.
10 |
11 |
12 | tfgnn.is_dense_tensor(
13 | value: tfgnn.Field
14 | ) -> bool
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/gat_v2/graph_update_get_config_dict.md:
--------------------------------------------------------------------------------
1 | # gat_v2.graph_update_get_config_dict
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Returns ConfigDict for graph_update_from_config_dict() with defaults.
10 |
11 |
12 | gat_v2.graph_update_get_config_dict() -> config_dict.ConfigDict
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/disable_graph_tensor_validation_at_runtime.md:
--------------------------------------------------------------------------------
1 | # tfgnn.disable_graph_tensor_validation_at_runtime
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Disables runtime checks (`tf.debugging.Assert`) of graph tensors.
10 |
11 |
12 | tfgnn.disable_graph_tensor_validation_at_runtime()
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/is_ragged_tensor.md:
--------------------------------------------------------------------------------
1 | # tfgnn.is_ragged_tensor
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Returns whether a tensor (TF or Keras) is a RaggedTensor.
10 |
11 |
12 | tfgnn.is_ragged_tensor(
13 | value: tfgnn.Field
14 | ) -> bool
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/mt_albis/graph_update_get_config_dict.md:
--------------------------------------------------------------------------------
1 | # mt_albis.graph_update_get_config_dict
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Returns ConfigDict for graph_update_from_config_dict() with defaults.
10 |
11 |
12 | mt_albis.graph_update_get_config_dict() -> config_dict.ConfigDict
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/tensorflow_gnn/experimental/BUILD:
--------------------------------------------------------------------------------
1 | # Copybara rewrites load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load
3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
4 |
5 | licenses(["notice"])
6 |
7 | package(
8 | default_applicable_licenses = ["//tensorflow_gnn:license"],
9 | default_visibility = [
10 | "//tensorflow_gnn:__subpackages__",
11 | ],
12 | )
13 |
14 | pytype_strict_library(
15 | name = "experimental",
16 | srcs = ["__init__.py"],
17 | deps = [
18 | "//tensorflow_gnn/graph:readout",
19 | "//tensorflow_gnn/graph:tensor_utils",
20 | ],
21 | )
22 |
--------------------------------------------------------------------------------
/testdata/homogeneous/citrus.pbtxt:
--------------------------------------------------------------------------------
1 | # TODO(blais): Test context features.
2 |
3 | node_sets {
4 | key: "fruits"
5 | value {
6 | metadata {
7 | filename: "fruits.csv"
8 | }
9 | features {
10 | key: "name"
11 | value: {
12 | description: "Fruit name"
13 | dtype: DT_STRING
14 | }
15 | }
16 | }
17 | }
18 |
19 | edge_sets {
20 | key: "tastelike"
21 | value {
22 | source: "fruits"
23 | target: "fruits"
24 | description: "Similar taste"
25 | metadata {
26 | filename: "tastelike.csv"
27 | }
28 | features {
29 | key: "weight"
30 | value: {
31 | dtype: DT_FLOAT
32 | }
33 | }
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/one_node_per_component.md:
--------------------------------------------------------------------------------
1 | # runner.one_node_per_component
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Returns a `Mapping` `node_set_name: 1` for every node set in `gtspec`.
10 |
11 |
12 | runner.one_node_per_component(
13 | gtspec: tfgnn.GraphTensorSpec
14 | ) -> Mapping[str, int]
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/graph_update_get_config_dict.md:
--------------------------------------------------------------------------------
1 | # vanilla_mpnn.graph_update_get_config_dict
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Returns ConfigDict for graph_update_from_config_dict() with defaults.
10 |
11 |
12 | vanilla_mpnn.graph_update_get_config_dict() -> config_dict.ConfigDict
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/check_scalar_graph_tensor.md:
--------------------------------------------------------------------------------
1 | # tfgnn.check_scalar_graph_tensor
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Checks that graph tensor is scalar (has rank 0).
10 |
11 |
12 | tfgnn.check_scalar_graph_tensor(
13 | graph: Union[GraphTensor, GraphTensorSpec], name='This operation'
14 | ) -> None
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tensorflow_gnn/runner/trainers/BUILD:
--------------------------------------------------------------------------------
1 | # Copybara rewrites load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load
3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
4 |
5 | licenses(["notice"])
6 |
7 | package(
8 | default_applicable_licenses = ["//tensorflow_gnn:license"],
9 | default_visibility = ["//visibility:public"],
10 | )
11 |
12 | pytype_strict_library(
13 | name = "keras_fit",
14 | srcs = ["keras_fit.py"],
15 | visibility = ["//tensorflow_gnn/runner:__pkg__"],
16 | deps = [
17 | "//:expect_tensorflow_installed",
18 | "//tensorflow_gnn/runner:interfaces",
19 | ],
20 | )
21 |
--------------------------------------------------------------------------------
/tensorflow_gnn/api_def/contrastive_losses-symbols.txt:
--------------------------------------------------------------------------------
1 | contrastive_losses.AllSvdMetrics
2 | contrastive_losses.BarlowTwinsTask
3 | contrastive_losses.ContrastiveLossTask
4 | contrastive_losses.CorruptionSpec
5 | contrastive_losses.Corruptor
6 | contrastive_losses.DeepGraphInfomaxLogits
7 | contrastive_losses.DeepGraphInfomaxTask
8 | contrastive_losses.DropoutFeatures
9 | contrastive_losses.ShuffleFeaturesGlobally
10 | contrastive_losses.TripletEmbeddingSquaredDistances
11 | contrastive_losses.TripletLossTask
12 | contrastive_losses.VicRegTask
13 | contrastive_losses.coherence
14 | contrastive_losses.numerical_rank
15 | contrastive_losses.pseudo_condition_number
16 | contrastive_losses.rankme
17 | contrastive_losses.self_clustering
18 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/gat_v2/README.md:
--------------------------------------------------------------------------------
1 | # Graph Attention Networks v2
2 |
3 | ## Overview
4 |
5 | This code implements Graph Attention Networks v2, originally published by
6 |
7 | * Shaked Brody, Uri Alon, Eran Yahav:
8 | ["How Attentive are Graph Attention
9 | Networks?"](https://arxiv.org/abs/2105.14491), 2021.
10 |
11 | TensorFlow programs can import and use it as described in its
12 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/gat_v2.md).
13 |
14 | ## API stability
15 |
16 | This model is covered by [semantic
17 | versioning](https://semver.org/spec/v2.0.0.html) of TensorFlow GNN's
18 | open-source releases: new minor versions do not break existing users.
19 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/graph_sage/README.md:
--------------------------------------------------------------------------------
1 | # GraphSAGE
2 |
3 | ## Overview
4 |
5 | This code implements the GraphSAGE model, originally published by
6 |
7 | * William L. Hamilton, Rex Ying, Jure Leskovec:
8 | ["Inductive Representation Learning
9 | on Large Graphs"](https://arxiv.org/abs/1706.02216), 2017.
10 |
11 | TensorFlow programs can import and use it as described in its
12 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/graph_sage.md).
13 |
14 | ## API stability
15 |
16 | This model is covered by [semantic
17 | versioning](https://semver.org/spec/v2.0.0.html) of TensorFlow GNN's
18 | open-source releases: new minor versions do not break existing users.
19 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/gcn/README.md:
--------------------------------------------------------------------------------
1 | # Graph Convolutional Network
2 |
3 | ## Overview
4 |
5 | This code implements Graph Convolutional Networks, originally published by
6 |
7 | * Thomas N. Kipf and Max Welling:
8 | ["Semi-Supervised Classification with Graph Convolutional
9 | Networks"](https://arxiv.org/abs/1609.02907), 2016.
10 |
11 | TensorFlow programs can import and use it as described in its
12 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/gcn.md).
13 |
14 | ## API stability
15 |
16 | This model is covered by [semantic
17 | versioning](https://semver.org/spec/v2.0.0.html) of TensorFlow GNN's
18 | open-source releases: new minor versions do not break existing users.
19 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/ValidationError.md:
--------------------------------------------------------------------------------
1 | # tfgnn.ValidationError
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | A schema validation error.
10 |
11 |
12 | tfgnn.ValidationError(
13 | *args, **kwargs
14 | )
15 |
16 |
17 |
18 |
19 |
20 |
21 | This exception is raised if in the course of validating the schema for
22 | correctness some errors are found.
23 |
24 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/graph_update_get_config_dict.md:
--------------------------------------------------------------------------------
1 | # multi_head_attention.graph_update_get_config_dict
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Returns ConfigDict for graph_update_from_config_dict() with defaults.
10 |
11 |
12 | multi_head_attention.graph_update_get_config_dict() -> config_dict.ConfigDict
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/RootNodeLabelFn.md:
--------------------------------------------------------------------------------
1 | # runner.RootNodeLabelFn
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Reads out a `tfgnn.Field` from the `GraphTensor` root (i.e. first) node.
10 |
11 |
12 | runner.RootNodeLabelFn(
13 | node_set_name: tfgnn.NodeSetName,
14 | *,
15 | feature_name: tfgnn.FieldName = tfgnn.HIDDEN_STATE,
16 | **kwargs
17 | )
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/examples/sampler/creditcard/owns_card.csv:
--------------------------------------------------------------------------------
1 | source,target
2 | 1876448,16827485386298040
3 | 1372437,11470379189154620
4 | 1368305,11163838768727470
5 | 1974494,16011471358128450
6 | 1257724,18569067217418250
7 | 1758057,17396883707513070
8 | 1531660,14844931107602160
9 | 1489311,1238474857489384
10 | 1407706,11290312140467510
11 | 196838,17861046738135650
12 | 1195675,8878522895102384
13 | 1659366,13019350102369400
14 | 1499004,11470379189154620
15 | 1344333,16283233487191600
16 | 1443888,9991040399813057
17 | 1108778,14912408563871390
18 | 175583,11290312140467510
19 | 1251872,12948957000457930
20 | 1493851,3549061668422198
21 | 1599418,9991040399813057
22 | 1768701,18362223127059380
23 | 1549489,1238474857489384
24 | 1879799,18569067217418250
25 | 125454,18526138896540830
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DeepGraphInfomaxLogits.md:
--------------------------------------------------------------------------------
1 | # contrastive_losses.DeepGraphInfomaxLogits
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Computes clean and corrupted logits for Deep Graph Infomax (DGI).
10 |
11 |
12 | contrastive_losses.DeepGraphInfomaxLogits(
13 | trainable=True, name=None, dtype=None, dynamic=False, **kwargs
14 | )
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/keras.md:
--------------------------------------------------------------------------------
1 | # Module: tfgnn.keras
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | The tfgnn.keras package.
10 |
11 |
12 |
13 | ## Modules
14 |
15 | [`layers`](../tfgnn/keras/layers.md) module: The tfgnn.keras.layers package.
16 |
17 | ## Classes
18 |
19 | [`class ConvGNNBuilder`](../tfgnn/keras/ConvGNNBuilder.md): Factory of layers that do convolutions on a graph.
20 |
21 | ## Functions
22 |
23 | [`clone_initializer(...)`](../tfgnn/keras/clone_initializer.md): Clones an
24 | initializer to ensure a new default seed.
25 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/graph_sage/all_symbols.md:
--------------------------------------------------------------------------------
1 | # All symbols in TensorFlow GNN Models: GraphSAGE
2 |
3 |
4 |
5 | ## Primary symbols
6 |
7 | * graph_sage
8 | * graph_sage.GCNGraphSAGENodeSetUpdate
9 | * graph_sage.GraphSAGEAggregatorConv
10 | * graph_sage.GraphSAGEGraphUpdate
11 | * graph_sage.GraphSAGENextState
12 | * graph_sage.GraphSAGEPoolingConv
13 |
--------------------------------------------------------------------------------
/testdata/heterogeneous/owns_card.csv:
--------------------------------------------------------------------------------
1 | source,target
2 | 1876448,16827485386298040
3 | 1372437,11470379189154620
4 | 1368305,11163838768727470
5 | 1974494,16011471358128450
6 | 1257724,18569067217418250
7 | 1758057,17396883707513070
8 | 1531660,14844931107602160
9 | 1489311,1238474857489384
10 | 1407706,11290312140467510
11 | 196838,17861046738135650
12 | 1195675,8878522895102384
13 | 1659366,13019350102369400
14 | 1499004,11470379189154620
15 | 1344333,16283233487191600
16 | 1443888,9991040399813057
17 | 1108778,14912408563871390
18 | 175583,11290312140467510
19 | 1251872,12948957000457930
20 | 1493851,3549061668422198
21 | 1599418,9991040399813057
22 | 1768701,18362223127059380
23 | 1549489,1238474857489384
24 | 1879799,18569067217418250
25 | 125454,18526138896540830
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/check_homogeneous_graph_tensor.md:
--------------------------------------------------------------------------------
1 | # tfgnn.check_homogeneous_graph_tensor
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Raises ValueError when tfgnn.get_homogeneous_node_and_edge_set_name() does.
10 |
11 |
12 | tfgnn.check_homogeneous_graph_tensor(
13 | graph: Union[GraphTensor, GraphTensorSpec],
14 | name: str = 'This operation'
15 | ) -> None
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/TripletEmbeddingSquaredDistances.md:
--------------------------------------------------------------------------------
1 | # contrastive_losses.TripletEmbeddingSquaredDistances
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Computes embeddings distance between positive and negative pairs.
10 |
11 |
12 | contrastive_losses.TripletEmbeddingSquaredDistances(
13 | trainable=True, name=None, dtype=None, dynamic=False, **kwargs
14 | )
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tensorflow_gnn/sampler/BUILD:
--------------------------------------------------------------------------------
1 | # Copybara rewrites some of these load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load
3 | load("@rules_proto//proto:defs.bzl", "proto_library")
4 |
5 | # buildifier: disable=out-of-order-load, disable=same-origin-load
6 | load("@rules_python//python:proto.bzl", "py_proto_library")
7 |
8 | package(default_visibility = ["//visibility:public"])
9 |
10 | licenses(["notice"]) # Apache 2.0
11 |
12 | proto_library(
13 | name = "sampling_spec_proto",
14 | srcs = ["sampling_spec.proto"],
15 | deps = ["@org_tensorflow//tensorflow/core:protos_all"],
16 | )
17 |
18 | py_proto_library(
19 | name = "sampling_spec_py_proto",
20 | deps = [
21 | ":sampling_spec_proto",
22 | ],
23 | )
--------------------------------------------------------------------------------
/tensorflow_gnn/tools/BUILD:
--------------------------------------------------------------------------------
1 | # Copybara rewrites some of these load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load
3 | load("@rules_proto//proto:defs.bzl", "proto_library")
4 |
5 | # buildifier: disable=out-of-order-load, disable=same-origin-load
6 | load("@rules_python//python:proto.bzl", "py_proto_library")
7 |
8 | package(default_visibility = ["//visibility:public"])
9 |
10 | licenses(["notice"]) # Apache 2.0
11 |
12 | proto_library(
13 | name = "sampled_stats_proto",
14 | srcs = ["sampled_stats.proto"],
15 | deps = ["@org_tensorflow//tensorflow/core:protos_all"],
16 | )
17 |
18 | py_proto_library(
19 | name = "sampled_stats_py_proto",
20 | deps = [
21 | ":sampled_stats_proto",
22 | ],
23 | )
24 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/gat_v2/all_symbols.md:
--------------------------------------------------------------------------------
1 | # All symbols in TensorFlow GNN Models: GATv2
2 |
3 |
4 |
5 | ## Primary symbols
6 |
7 | * gat_v2
8 | * gat_v2.GATv2Conv
9 | * gat_v2.GATv2EdgePool
10 | * gat_v2.GATv2HomGraphUpdate
11 | * gat_v2.GATv2MPNNGraphUpdate
12 | * gat_v2.graph_update_from_config_dict
13 | * gat_v2.graph_update_get_config_dict
14 |
--------------------------------------------------------------------------------
/tensorflow_gnn/utils/BUILD:
--------------------------------------------------------------------------------
1 | # Copybara rewrites load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load
3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
4 |
5 | licenses(["notice"])
6 |
7 | package(
8 | default_applicable_licenses = ["//tensorflow_gnn:license"],
9 | default_visibility = ["//visibility:public"],
10 | )
11 |
12 | pytype_strict_library(
13 | name = "test_utils",
14 | srcs = ["test_utils.py"],
15 | )
16 |
17 | pytype_strict_library(
18 | name = "api_utils",
19 | srcs = ["api_utils.py"],
20 | )
21 |
22 | pytype_strict_library(
23 | name = "tf_test_utils",
24 | srcs = ["tf_test_utils.py"],
25 | deps = [
26 | "//:expect_tensorflow_installed",
27 | ],
28 | )
29 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/gcn.md:
--------------------------------------------------------------------------------
1 | # Module: gcn
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Graph Convolutional Networks.
10 |
11 | Users of TF-GNN can use this model by importing it next to the core library as
12 |
13 | ```python
14 | import tensorflow_gnn as tfgnn
15 | from tensorflow_gnn.models import gcn
16 | ```
17 |
18 | ## Classes
19 |
20 | [`class GCNConv`](./gcn/GCNConv.md): Implements the Graph Convolutional Network
21 | by Kipf&Welling (2016).
22 |
23 | ## Functions
24 |
25 | [`GCNHomGraphUpdate(...)`](./gcn/GCNHomGraphUpdate.md): Returns a graph update
26 | layer for GCN convolution.
27 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/AddSelfLoops.md:
--------------------------------------------------------------------------------
1 | # tfgnn.keras.layers.AddSelfLoops
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Adds self-loops to scalar graphs.
10 |
11 |
12 | tfgnn.keras.layers.AddSelfLoops(
13 | edge_set_name
14 | )
15 |
16 |
17 |
18 |
19 | The edge_set_name is expected to name a homogeneous edge set (one that connects
20 | nodes of a single node set to each other). NOTE: Self-connections will always be
21 | added, regardless of whether self-connections already exist or not.
22 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/disable_graph_tensor_validation.md:
--------------------------------------------------------------------------------
1 | # tfgnn.disable_graph_tensor_validation
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Disables both static and runtime checks of graph tensors.
10 |
11 |
12 | tfgnn.disable_graph_tensor_validation()
13 |
14 |
15 |
16 |
17 | IMPORTANT: This is a temporary workaround for legacy code (written before the
18 | TF-GNN 1.0 release) that may rely on an inconsistent number of graph tensor items
19 | or on edges with adjacency indices for non-existing nodes. **DO NOT USE**.
20 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/proto/Metadata/KeyValue.md:
--------------------------------------------------------------------------------
1 | # tfgnn.proto.Metadata.KeyValue
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | A ProtocolMessage
10 |
11 |
12 |
13 |
14 |
15 |
16 | Attributes |
17 |
18 |
19 |
20 | key
21 | |
22 |
23 | string key
24 | |
25 |
26 |
27 | value
28 | |
29 |
30 | string value
31 | |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/GraphTensorProcessorFn.md:
--------------------------------------------------------------------------------
1 | # runner.GraphTensorProcessorFn
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | A class for `GraphTensor` processing.
10 |
11 |
12 |
13 | ## Methods
14 |
15 | __call__
16 |
17 | View
18 | source
19 |
20 |
21 | __call__(
22 | inputs: GraphTensor
23 | ) -> GraphTensor
24 |
25 |
26 | Processes a `GraphTensor`.
27 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/experimental.md:
--------------------------------------------------------------------------------
1 | # Module: tfgnn.experimental
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Experimental (unstable) parts of the public interface of TensorFlow GNN.
10 |
11 | A symbol `foo` exposed here is available to library users as
12 |
13 | ```
14 | import tensorflow_gnn as tfgnn
15 |
16 | tfgnn.experimental.foo()
17 | ```
18 |
19 | This is the preferred way to expose individual functions on track to inclusion
20 | into the stable public interface of TensorFlow GNN.
21 |
22 | Beyond these symbols, there are also experimental sub-libraries that need to be
23 | imported separately (`from tensorflow_gnn.experimental import foo`). That is for
24 | special cases only.
25 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/ParseExample.md:
--------------------------------------------------------------------------------
1 | # tfgnn.keras.layers.ParseExample
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Applies tfgnn.parse_example(graph_tensor_spec, _) to a batch of strings.
10 |
11 |
12 | tfgnn.keras.layers.ParseExample(
13 | graph_tensor_spec: tfgnn.GraphTensorSpec,
14 | **kwargs
15 | )
16 |
17 |
18 |
19 |
20 |
21 |
22 | This layer can be restored from config by `tf.keras.models.load_model()` when
23 | saved as part of a Keras model using `save_format="tf"`.
24 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/ParseSingleExample.md:
--------------------------------------------------------------------------------
1 | # tfgnn.keras.layers.ParseSingleExample
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Applies tfgnn.parse_single_example(graph_tensor_spec, _).
10 |
11 |
12 | tfgnn.keras.layers.ParseSingleExample(
13 | graph_tensor_spec: tfgnn.GraphTensorSpec,
14 | **kwargs
15 | )
16 |
17 |
18 |
19 |
20 |
21 |
22 | This layer can be restored from config by `tf.keras.models.load_model()` when
23 | saved as part of a Keras model using `save_format="tf"`.
24 |
--------------------------------------------------------------------------------
/testdata/node_vs_edge/schema.pbtxt:
--------------------------------------------------------------------------------
1 | node_sets {
2 | key: "node_set_one"
3 | value {
4 | metadata {
5 | filename: "node_set_one.csv"
6 | }
7 | features {
8 | key: "#id"
9 | value {
10 | dtype: DT_STRING
11 | }
12 | }
13 | }
14 | }
15 | node_sets {
16 | key: "node_set_two"
17 | value {
18 | metadata {
19 | filename: "node_set_two.csv"
20 | }
21 | features {
22 | key: "#id"
23 | value {
24 | dtype: DT_STRING
25 | }
26 | }
27 | }
28 | }
29 | edge_sets {
30 | key: "one_to_two"
31 | value {
32 | source: "node_set_one"
33 | target: "node_set_two"
34 | metadata {
35 | filename: "edge_set_one_to_two.csv"
36 | }
37 | }
38 | }
39 | edge_sets {
40 | key: "two_to_two"
41 | value {
42 | source: "node_set_two"
43 | target: "node_set_two"
44 | metadata {
45 | filename: "edge_set_two_to_two.csv"
46 | }
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/softmax_edges_per_node.md:
--------------------------------------------------------------------------------
1 | # tfgnn.softmax_edges_per_node
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Returns softmax() of edge values per common `node_tag` node.
10 |
11 |
12 | tfgnn.softmax_edges_per_node(
13 | graph_tensor: tfgnn.GraphTensor,
14 | edge_set_name: EdgeSetName,
15 | node_tag: IncidentNodeTag,
16 | *,
17 | feature_value: Optional[Field] = None,
18 | feature_name: Optional[FieldName] = None
19 | ) -> tfgnn.Field
20 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/DatasetProvider.md:
--------------------------------------------------------------------------------
1 | # runner.DatasetProvider
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | An abstract interface for providing a `tf.data.Dataset` per replica.
10 |
11 |
12 |
13 | ## Methods
14 |
15 | get_dataset
16 |
17 | View
18 | source
19 |
20 |
21 | @abc.abstractmethod
22 | get_dataset(
23 | context: tf.distribute.InputContext
24 | ) -> tf.data.Dataset
25 |
26 |
27 | Get a `tf.data.Dataset` by `context` per replica.
28 |
--------------------------------------------------------------------------------
/package/tfdep.bzl:
--------------------------------------------------------------------------------
1 | """Dependency on TensorFlow for build."""
2 |
3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
4 |
5 | def tf_setup():
6 | """Defines the TensorFlow v2.13.0 source archive (@org_tensorflow) for the Bazel build.
7 |
8 | This downloads the TensorFlow files required for building TFGNN protos
9 | (examples and graph_schema). This TF version should always be within our supported range and gets
10 | updated manually as our TF dependency advances. This version is somewhat flexible since TFGNN
11 | protos only depend on tensorflow.Example, tensorflow.Feature, tensorflow.TensorShapeProto, and
12 | tensorflow.DataType, which are all very stable definitions in TensorFlow.
13 | """
14 | http_archive(
15 | name = "org_tensorflow",
16 | sha256 = "e58c939079588623e6fa1d054aec2f90f95018266e0a970fd353a5244f5173dc",
17 | urls = [
18 | "https://github.com/tensorflow/tensorflow/archive/refs/tags/v2.13.0.tar.gz",
19 | ],
20 | strip_prefix = "tensorflow-2.13.0",
21 | )
22 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/proto/BigQuery/TableSpec.md:
--------------------------------------------------------------------------------
1 | # tfgnn.proto.BigQuery.TableSpec
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | A ProtocolMessage
10 |
11 |
12 |
13 |
14 |
15 |
16 | Attributes |
17 |
18 |
19 |
20 | dataset
21 | |
22 |
23 | string dataset
24 | |
25 |
26 |
27 | project
28 | |
29 |
30 | string project
31 | |
32 |
33 |
34 | table
35 | |
36 |
37 | string table
38 | |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/proto/OriginInfo.md:
--------------------------------------------------------------------------------
1 | # tfgnn.proto.OriginInfo
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Metadata about the origin of the graph data.
10 |
11 |
12 |
13 | For detailed documentation, see the comments in the `graph_schema.proto` file.
14 |
15 |
16 |
17 |
18 |
19 | Attributes |
20 |
21 |
22 |
23 | graph_type
24 | |
25 |
26 | GraphType graph_type
27 | |
28 |
29 |
30 | root_set
31 | |
32 |
33 | repeated string root_set
34 | |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/tensorflow_gnn/api_def/sampler-symbols.txt:
--------------------------------------------------------------------------------
1 | sampler.Artifacts
2 | sampler.CompositeLayer
3 | sampler.ConnectingEdgesSampler
4 | sampler.InMemIndexToFeaturesAccessor
5 | sampler.InMemIntegerKeyToBytesAccessor
6 | sampler.InMemStringKeyToBytesAccessor
7 | sampler.InMemUniformEdgesSampler
8 | sampler.KeyToBytesAccessor
9 | sampler.KeyToFeaturesAccessor
10 | sampler.KeyToTfExampleAccessor
11 | sampler.OutgoingEdgesSampler
12 | sampler.TfExamplesParser
13 | sampler.UniformEdgesSampler
14 | sampler.build_graph_tensor
15 | sampler.create_link_sampling_model_from_spec
16 | sampler.create_program
17 | sampler.create_sampling_model_from_spec
18 | sampler.ragged_choice
19 | sampler.ragged_lookup
20 | sampler.ragged_unique
21 | sampler.save_model
22 | sampler.set_ext_ops_implementation
23 | sampler.proto.Program
24 | sampler.proto.EvalDAG
25 | sampler.proto.Stage
26 | sampler.proto.Layer
27 | sampler.proto.ValueSpec
28 | sampler.proto.TensorSpec
29 | sampler.proto.RaggedTensorSpec
30 | sampler.proto.FlattenedSpec
31 | sampler.proto.EdgeSamplingConfig
32 | sampler.proto.IOFeatures
33 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/proto/Context.md:
--------------------------------------------------------------------------------
1 | # tfgnn.proto.Context
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | The schema for the features that apply across the entire input graph.
10 |
11 |
12 |
13 | For detailed documentation, see the comments in the `graph_schema.proto` file.
14 |
15 |
16 |
17 |
18 |
19 | Attributes |
20 |
21 |
22 |
23 | features
24 | |
25 |
26 | repeated FeaturesEntry features
27 | |
28 |
29 |
30 | metadata
31 | |
32 |
33 | Metadata metadata
34 | |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/tensorflow_gnn/experimental/sampler/proto/BUILD:
--------------------------------------------------------------------------------
1 | # Copybara rewrites some of these load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load
3 | load("@rules_proto//proto:defs.bzl", "proto_library")
4 |
5 | # buildifier: disable=out-of-order-load, disable=same-origin-load
6 | load("@rules_python//python:proto.bzl", "py_proto_library")
7 |
8 | package(default_visibility = ["//visibility:public"])
9 |
10 | licenses(["notice"]) # Apache 2.0
11 |
12 | py_library(
13 | name = "proto",
14 | srcs = ["__init__.py"],
15 | srcs_version = "PY3",
16 | visibility = ["//visibility:public"],
17 | deps = [
18 | ":eval_dag_py_proto",
19 | "//tensorflow_gnn/utils:api_utils",
20 | ],
21 | )
22 |
23 |
24 | proto_library(
25 | name = "eval_dag_proto",
26 | srcs = ["eval_dag.proto"],
27 | deps = [
28 | "@org_tensorflow//tensorflow/core:protos_all",
29 | ],
30 | )
31 |
32 | py_proto_library(
33 | name = "eval_dag_py_proto",
34 | deps = [
35 | ":eval_dag_proto",
36 | ],
37 | )
38 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Example of building the docker image locally:
16 | # docker build . -t tfgnn:latest
17 | #
18 | # You can then start an interactive python interpreter shell and
19 | # import tensorflow-gnn with:
20 | # docker run -it tfgnn:latest
21 | FROM python:3.9-slim
22 | # tzdata asks questions.
23 | ENV DEBIAN_FRONTEND="noninteractive"
24 | ENV TZ="America/New_York"
25 |
26 | RUN pip3 install --upgrade pip
27 |
28 | RUN pip3 install "tensorflow-gnn==1.0.0" httplib2 notebook ogb
29 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/all_symbols.md:
--------------------------------------------------------------------------------
1 | # All symbols in TensorFlow GNN Models: multi_head_attention
2 |
3 |
4 |
5 | ## Primary symbols
6 |
7 | * multi_head_attention
8 | * multi_head_attention.MultiHeadAttentionConv
9 | * multi_head_attention.MultiHeadAttentionEdgePool
10 | * multi_head_attention.MultiHeadAttentionHomGraphUpdate
11 | * multi_head_attention.MultiHeadAttentionMPNNGraphUpdate
12 | * multi_head_attention.graph_update_from_config_dict
13 | * multi_head_attention.graph_update_get_config_dict
14 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/sampler/SamplingSpec.md:
--------------------------------------------------------------------------------
1 | # tfgnn.sampler.SamplingSpec
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | A ProtocolMessage
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | Attributes |
18 |
19 |
20 |
21 | sampling_ops
22 | |
23 |
24 | repeated SamplingOp sampling_ops
25 | |
26 |
27 |
28 | seed_op
29 | |
30 |
31 | SeedOp seed_op
32 | |
33 |
34 |
35 | symmetric_link_seed_op
36 | |
37 |
38 | SymmetricLinkSeedOp symmetric_link_seed_op
39 | |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/tensorflow_gnn/proto/examples.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | // =============================================================================
15 | syntax = "proto2";
16 |
17 | package tensorflow_gnn.testdata;
18 |
19 | import "tensorflow/core/example/example.proto";
20 |
21 | // Specifies one or more Examples. This is used to aid testing readability by
22 | // allowing us to specify multiple Examples in an ASCII proto file.
23 | message ExampleList {
24 | repeated tensorflow.Example examples = 1;
25 | }
26 |
--------------------------------------------------------------------------------
/tensorflow_gnn/runner/utils/saved_model_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 |
18 | set -e -u -o pipefail # Let no failure go undetected.
19 |
20 | saved_model_gen_testdata=$1
21 | saved_model_load_testdata=$2
22 | use_legacy_model_save=$3
23 |
24 | $saved_model_gen_testdata --filepath=${TEST_TMPDIR}/saved_model_testdata \
25 | --use_legacy_model_save=$use_legacy_model_save
26 | $saved_model_load_testdata --filepath=${TEST_TMPDIR}/saved_model_testdata
27 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow GNN: API docs
2 |
3 | TensorFlow GNN is a Python package that consists of multiple sub-libraries.
4 |
5 | * Core library [[API docs](python/tfgnn.md)]:
6 | `import tensorflow_gnn as tfgnn`
7 |
8 | * Training orchestration [[API docs](python/runner.md)]:
9 | `from tensorflow_gnn import runner`
10 |
11 | * Graph sampler:
12 | `from tensorflow_gnn.experimental import sampler`
13 |
14 | * Models from the [models collection](../../models/README.md):
15 | `from tensorflow_gnn.models import <model_name>`
16 |
17 | The TensorFlow GNN package uses [semantic
18 | versioning](https://semver.org/spec/v2.0.0.html) for its numbered releases.
19 | Its stable Python API consists of the Python identifiers exposed in these
20 | imported modules and their sub-modules, with the following exceptions:
21 |
22 | * identifiers or modules that contain `experimental` in their name,
23 | that are expressly documented to be unstable, or not documented at all;
24 | * private identifiers beginning with a single underscore (`_`);
25 | * models from the models collection whose documentation opts out of
26 | semantic versioning.
--------------------------------------------------------------------------------
/examples/sampler/mag/sampling_spec.pbtxt:
--------------------------------------------------------------------------------
1 | # proto-file: tensorflow_gnn/sampler/sampling_spec.proto
2 | # proto-message: SamplingSpec
3 |
4 | seed_op <
5 | op_name: "seed"
6 | node_set_name: "paper"
7 | >
8 | sampling_ops <
9 | op_name: "seed->paper"
10 | input_op_names: "seed"
11 | edge_set_name: "cites"
12 | sample_size: 32
13 | strategy: RANDOM_UNIFORM
14 | >
15 | sampling_ops <
16 | op_name: "paper->author"
17 | input_op_names: "seed"
18 | input_op_names: "seed->paper"
19 | edge_set_name: "written"
20 | sample_size: 8
21 | strategy: RANDOM_UNIFORM
22 | >
23 | sampling_ops <
24 | op_name: "author->paper"
25 | input_op_names: "paper->author"
26 | edge_set_name: "writes"
27 | sample_size: 16
28 | strategy: RANDOM_UNIFORM
29 | >
30 | sampling_ops <
31 | op_name: "author->institution"
32 | input_op_names: "paper->author"
33 | edge_set_name: "affiliated_with"
34 | sample_size: 16
35 | strategy: RANDOM_UNIFORM
36 | >
37 | sampling_ops <
38 | op_name: "paper->field_of_study"
39 | input_op_names: "seed"
40 | input_op_names: "seed->paper"
41 | input_op_names: "author->paper"
42 | edge_set_name: "has_topic"
43 | sample_size: 16
44 | strategy: RANDOM_UNIFORM
45 | >
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn.md:
--------------------------------------------------------------------------------
1 | # Module: vanilla_mpnn
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | TF-GNN's "Vanilla MPNN" model.
10 |
11 | Users of TF-GNN can use this model by importing it next to the core library as
12 |
13 | ```python
14 | import tensorflow_gnn as tfgnn
15 | from tensorflow_gnn.models import vanilla_gnn
16 | ```
17 |
18 | This model ties together some simple convolutions from the TF-GNN core library,
19 | so it does not define any Conv class by itself.
20 |
21 | ## Functions
22 |
23 | [`VanillaMPNNGraphUpdate(...)`](./vanilla_mpnn/VanillaMPNNGraphUpdate.md):
24 | Returns a GraphUpdate layer for a Vanilla MPNN.
25 |
26 | [`graph_update_from_config_dict(...)`](./vanilla_mpnn/graph_update_from_config_dict.md):
27 | Returns a VanillaMPNNGraphUpdate initialized from `cfg`.
28 |
29 | [`graph_update_get_config_dict(...)`](./vanilla_mpnn/graph_update_get_config_dict.md):
30 | Returns ConfigDict for graph_update_from_config_dict() with defaults.
31 |
--------------------------------------------------------------------------------
/tensorflow_gnn/runner/input/BUILD:
--------------------------------------------------------------------------------
1 | # Copybara rewrites load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load
3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
4 |
5 | # buildifier: disable=out-of-order-load, disable=same-origin-load
6 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "py_strict_test")
7 |
8 | licenses(["notice"])
9 |
10 | package(
11 | default_applicable_licenses = ["//tensorflow_gnn:license"],
12 | default_visibility = ["//visibility:public"],
13 | )
14 |
15 | pytype_strict_library(
16 | name = "datasets",
17 | srcs = ["datasets.py"],
18 | visibility = ["//tensorflow_gnn/runner:__pkg__"],
19 | deps = [
20 | "//:expect_tensorflow_installed",
21 | "//tensorflow_gnn/runner:interfaces",
22 | ],
23 | )
24 |
25 | py_strict_test(
26 | name = "datasets_test",
27 | srcs = ["datasets_test.py"],
28 | deps = [
29 | ":datasets",
30 | "//:expect_absl_installed_testing",
31 | "//third_party/py/google/protobuf:use_fast_cpp_protos", # Automatically added go/proto_python_upb_flip
32 | "//:expect_tensorflow_installed",
33 | ],
34 | )
35 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/mt_albis.md:
--------------------------------------------------------------------------------
1 | # Module: mt_albis
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | TF-GNN's Model Template "Albis".
10 |
11 | The TF-GNN Model Template "Albis" provides a small selection of field-tested GNN
12 | architectures through the
13 | mt_albis.MtAlbisGraphUpdate
14 | class.
15 |
16 | Users of TF-GNN can use it by importing it next to the core library as
17 |
18 | ```python
19 | import tensorflow_gnn as tfgnn
20 | from tensorflow_gnn.models import mt_albis
21 | ```
22 |
23 | ## Functions
24 |
25 | [`MtAlbisGraphUpdate(...)`](./mt_albis/MtAlbisGraphUpdate.md): Returns
26 | GraphUpdate layer for message passing with Model Template "Albis".
27 |
28 | [`graph_update_from_config_dict(...)`](./mt_albis/graph_update_from_config_dict.md):
29 | Constructs a MtAlbisGraphUpdate from a ConfigDict.
30 |
31 | [`graph_update_get_config_dict(...)`](./mt_albis/graph_update_get_config_dict.md):
32 | Returns ConfigDict for graph_update_from_config_dict() with defaults.
33 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/sampler.md:
--------------------------------------------------------------------------------
1 | # Module: tfgnn.sampler
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Public interface for GNN Sampler.
10 |
11 | ## Classes
12 |
13 | [`class SamplingOp`](../tfgnn/sampler/SamplingOp.md): A ProtocolMessage
14 |
15 | [`class SamplingSpec`](../tfgnn/sampler/SamplingSpec.md): A ProtocolMessage
16 |
17 | [`class SamplingSpecBuilder`](../tfgnn/sampler/SamplingSpecBuilder.md): Mimics
18 | builder pattern that eases creation of `tfgnn.SamplingSpec`.
19 |
20 | ## Functions
21 |
22 | [`make_sampling_spec_tree(...)`](../tfgnn/sampler/make_sampling_spec_tree.md):
23 | Automatically creates `SamplingSpec` by starting from seed node set.
24 |
25 |
26 |
27 |
28 |
29 | Other Members |
30 |
31 |
32 | |
33 | SamplingStrategy
34 | |
35 |
36 | ['TOP_K', 'RANDOM_UNIFORM', 'RANDOM_WEIGHTED']
37 | |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/write_schema.md:
--------------------------------------------------------------------------------
1 | # tfgnn.write_schema
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Write a `GraphSchema` to a text-formatted proto file.
10 |
11 |
12 | tfgnn.write_schema(
13 | schema: tfgnn.proto.GraphSchema,
14 | filename: str
15 | )
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | Args |
25 |
26 |
27 |
28 | schema
29 | |
30 |
31 | A GraphSchema instance to write out.
32 | |
33 |
34 |
35 | filename
36 | |
37 |
38 | A string, the path to a file to render a text-formatted rendition
39 | of the GraphSchema message to.
40 | |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/examples/schemas/graph_nets.pbtxt:
--------------------------------------------------------------------------------
1 | # An example graph schema matching DeepMind's GraphsTuple.
2 | # //third_party/py/tensorflow_gnn/proto/graph_schema.proto:GraphSchema
3 |
4 | context {
5 | features {
6 | key: "embedding"
7 | value: {
8 | description: "Globals feature vector"
9 | dtype: DT_FLOAT
10 | shape: { dim { size: 128 } }
11 | }
12 | }
13 | }
14 |
15 | node_sets {
16 | key: "default"
17 | value {
18 | features {
19 | key: "embedding"
20 | value: {
21 | description: "Encoded node features vector"
22 | dtype: DT_FLOAT
23 | shape: { dim { size: 64 } }
24 | }
25 | }
26 |
27 | features {
28 | key: "labels"
29 | value: {
30 | description: "Multiple ground truth text labels"
31 | dtype: DT_STRING
32 | shape: { dim { size: -1 } }
33 | }
34 | }
35 |
36 | context: "embedding"
37 | description: "Main set of node features"
38 | }
39 | }
40 |
41 | edge_sets {
42 | key: "default"
43 | value {
44 | features {
45 | key: "embedding"
46 | value: {
47 | description: "Encoded edge features vector"
48 | dtype: DT_FLOAT
49 | shape: { dim { size: 32 } }
50 | }
51 | }
52 | source: "default"
53 | target: "default"
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/proto/NodeSet.md:
--------------------------------------------------------------------------------
1 | # tfgnn.proto.NodeSet
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | The schema shared by a set of nodes in the graph.
10 |
11 |
12 |
13 | For detailed documentation, see the comments in the `graph_schema.proto` file.
14 |
15 |
16 |
17 |
18 |
19 | Attributes |
20 |
21 |
22 |
23 | context
24 | |
25 |
26 | repeated string context
27 | |
28 |
29 |
30 | description
31 | |
32 |
33 | string description
34 | |
35 |
36 |
37 | features
38 | |
39 |
40 | repeated FeaturesEntry features
41 | |
42 |
43 |
44 | metadata
45 | |
46 |
47 | Metadata metadata
48 | |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/sampler/SamplingOp.md:
--------------------------------------------------------------------------------
1 | # tfgnn.sampler.SamplingOp
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | A ProtocolMessage
10 |
11 |
12 |
13 |
14 |
15 |
16 | Attributes |
17 |
18 |
19 |
20 | edge_set_name
21 | |
22 |
23 | string edge_set_name
24 | |
25 |
26 |
27 | input_op_names
28 | |
29 |
30 | repeated string input_op_names
31 | |
32 |
33 |
34 | op_name
35 | |
36 |
37 | string op_name
38 | |
39 |
40 |
41 | sample_size
42 | |
43 |
44 | int32 sample_size
45 | |
46 |
47 |
48 | strategy
49 | |
50 |
51 | SamplingStrategy strategy
52 | |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/testdata/feature_repr.pbtxt:
--------------------------------------------------------------------------------
1 | context {
2 | features {
3 | key: "rankings"
4 | value {
5 | description: "Scores for each quartile"
6 | dtype: DT_FLOAT
7 | shape { dim { size: 4 } }
8 | }
9 | }
10 | }
11 |
12 | node_sets {
13 | key: "items"
14 | value {
15 | features {
16 | key: "category"
17 | value {
18 | description: "Purchase category"
19 | dtype: DT_STRING
20 | }
21 | }
22 | features {
23 | key: "amounts"
24 | value {
25 | description: "Purchase price"
26 | dtype: DT_FLOAT
27 | shape { dim { size: -1 } }
28 | }
29 | }
30 | }
31 | }
32 |
33 | node_sets {
34 | key: "persons"
35 | value {
36 | features {
37 | key: "name"
38 | value {
39 | description: "First name"
40 | dtype: DT_STRING
41 | }
42 | }
43 | features {
44 | key: "age"
45 | value {
46 | description: "Age"
47 | dtype: DT_INT64
48 | }
49 | }
50 | features {
51 | key: "country"
52 | value {
53 | description: "Country of origin"
54 | dtype: DT_STRING
55 | }
56 | }
57 | }
58 | }
59 |
60 | edge_sets {
61 | key: "purchased"
62 | value { source: "items" target: "persons" }
63 | }
64 |
65 | edge_sets {
66 | key: "is-friend"
67 | value { source: "persons" target: "persons" }
68 | }
69 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/GraphTensorPadding.md:
--------------------------------------------------------------------------------
1 | # runner.GraphTensorPadding
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Collects `GraphTensor` padding helpers.
10 |
11 |
12 |
13 | ## Methods
14 |
15 | get_filter_fn
16 |
17 | View
18 | source
19 |
20 |
21 | @abc.abstractmethod
22 | get_filter_fn(
23 | size_constraints: SizeConstraints
24 | ) -> Callable[..., bool]
25 |
26 |
27 | get_size_constraints
28 |
29 | View
30 | source
31 |
32 |
33 | @abc.abstractmethod
34 | get_size_constraints(
35 | target_batch_size: int
36 | ) -> SizeConstraints
37 |
38 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/FeatureDefaultValues.md:
--------------------------------------------------------------------------------
1 | # tfgnn.FeatureDefaultValues
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Default values for graph context, node sets and edge sets features.
10 |
11 |
12 | tfgnn.FeatureDefaultValues(
13 | context=None, node_sets=None, edge_sets=None
14 | )
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 | Attributes |
28 |
29 |
30 |
31 | context
32 | |
33 |
34 | A namedtuple alias for field number 0
35 | |
36 |
37 |
38 | node_sets
39 | |
40 |
41 | A namedtuple alias for field number 1
42 | |
43 |
44 |
45 | edge_sets
46 | |
47 |
48 | A namedtuple alias for field number 2
49 | |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/graph_sage.md:
--------------------------------------------------------------------------------
1 | # Module: graph_sage
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | GraphSAGE.
10 |
11 | Users of TF-GNN can use this model by importing it next to the core library as
12 |
13 | ```python
14 | import tensorflow_gnn as tfgnn
15 | from tensorflow_gnn.models import graph_sage
16 | ```
17 |
18 | ## Classes
19 |
20 | [`class GCNGraphSAGENodeSetUpdate`](./graph_sage/GCNGraphSAGENodeSetUpdate.md):
21 | GCNGraphSAGENodeSetUpdate is an extension of the mean aggregator operator.
22 |
23 | [`class GraphSAGEAggregatorConv`](./graph_sage/GraphSAGEAggregatorConv.md):
24 | GraphSAGE: element-wise aggregation of neighbors and their linear
25 | transformation.
26 |
27 | [`class GraphSAGENextState`](./graph_sage/GraphSAGENextState.md):
28 | GraphSAGENextState: compute new node states with GraphSAGE algorithm.
29 |
30 | [`class GraphSAGEPoolingConv`](./graph_sage/GraphSAGEPoolingConv.md): GraphSAGE:
31 | pooling aggregator transform of neighbors followed by linear transformation.
32 |
33 | ## Functions
34 |
35 | [`GraphSAGEGraphUpdate(...)`](./graph_sage/GraphSAGEGraphUpdate.md): Returns a
36 | GraphSAGE GraphUpdater layer for nodes in node_set_names.
37 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/PadToTotalSizes.md:
--------------------------------------------------------------------------------
1 | # tfgnn.keras.layers.PadToTotalSizes
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Applies tfgnn.pad_to_total_sizes() to a GraphTensor.
10 |
11 |
12 | tfgnn.keras.layers.PadToTotalSizes(
13 | sizes_constraints: tfgnn.SizeConstraints,
14 | *,
15 | validate: bool = True,
16 | **kwargs
17 | )
18 |
19 |
20 |
21 |
22 |
23 |
24 | This Keras layer maps a GraphTensor to a GraphTensor by calling
25 | tfgnn.pad_to_total_sizes() with the additional arguments, notably
26 | `sizes_constraints`, passed at initialization time. See that function
27 | for detailed documentation.
28 |
29 | This layer can be restored from config by `tf.keras.models.load_model()` when
30 | saved as part of a Keras model using `save_format="tf"`. Serialization to a
31 | Keras model config requires the `sizes_constraints` to contain Python integers
32 | or eager Tensors, not symbolic Tensors.
33 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/gcn/BUILD:
--------------------------------------------------------------------------------
1 | # Copybara rewrites load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load
3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
4 |
5 | # buildifier: disable=out-of-order-load, disable=same-origin-load
6 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "tf_py_test")
7 |
8 | licenses(["notice"])
9 |
10 | package(
11 | default_visibility = [":__subpackages__"],
12 | )
13 |
14 | package_group(name = "users")
15 |
16 | pytype_strict_library(
17 | name = "gcn",
18 | srcs = ["__init__.py"],
19 | visibility = [
20 | ":__subpackages__",
21 | ":users",
22 | ],
23 | deps = [
24 | ":gcn_conv",
25 | "//tensorflow_gnn/utils:api_utils",
26 | ],
27 | )
28 |
29 | pytype_strict_library(
30 | name = "gcn_conv",
31 | srcs = ["gcn_conv.py"],
32 | deps = [
33 | "//:expect_tensorflow_installed",
34 | "//tensorflow_gnn",
35 | ],
36 | )
37 |
38 | tf_py_test(
39 | name = "gcn_conv_test",
40 | srcs = ["gcn_conv_test.py"],
41 | deps = [
42 | ":gcn_conv",
43 | "//:expect_absl_installed_testing",
44 | "//:expect_tensorflow_installed",
45 | "//tensorflow_gnn",
46 | "//tensorflow_gnn/utils:tf_test_utils",
47 | "//:expect_ai_edge_litert_installed",
48 | ],
49 | )
50 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/graph_sage/BUILD:
--------------------------------------------------------------------------------
1 | # Copybara rewrites load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load
3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
4 |
5 | # buildifier: disable=out-of-order-load, disable=same-origin-load
6 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "tf_py_test")
7 |
8 | licenses(["notice"])
9 |
10 | package(default_visibility = [
11 | "//tensorflow_gnn:__pkg__",
12 | "//tensorflow_gnn/graph:__subpackages__",
13 | ])
14 |
15 | package_group(name = "users")
16 |
17 | pytype_strict_library(
18 | name = "graph_sage",
19 | srcs = ["__init__.py"],
20 | visibility = [":users"],
21 | deps = [
22 | ":layers",
23 | "//tensorflow_gnn/utils:api_utils",
24 | ],
25 | )
26 |
27 | pytype_strict_library(
28 | name = "layers",
29 | srcs = ["layers.py"],
30 | deps = [
31 | "//:expect_tensorflow_installed",
32 | "//tensorflow_gnn",
33 | ],
34 | )
35 |
36 | tf_py_test(
37 | name = "layers_test",
38 | srcs = ["layers_test.py"],
39 | deps = [
40 | ":layers",
41 | "//:expect_absl_installed_testing",
42 | "//:expect_tensorflow_installed",
43 | "//tensorflow_gnn",
44 | "//tensorflow_gnn/utils:tf_test_utils",
45 | "//:expect_ai_edge_litert_installed",
46 | ],
47 | )
48 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/proto/Metadata.md:
--------------------------------------------------------------------------------
1 | # tfgnn.proto.Metadata
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Extra information optionally provided on a context, node set or edge set.
10 |
11 |
12 |
13 | For detailed documentation, see the comments in the `graph_schema.proto` file.
14 |
15 |
16 |
17 |
18 |
19 | Attributes |
20 |
21 |
22 |
23 | bigquery
24 | |
25 |
26 | BigQuery bigquery
27 | |
28 |
29 |
30 | cardinality
31 | |
32 |
33 | int64 cardinality
34 | |
35 |
36 |
37 | extra
38 | |
39 |
40 | repeated KeyValue extra
41 | |
42 |
43 |
44 | filename
45 | |
46 |
47 | string filename
48 | |
49 |
50 |
51 |
52 | ## Child Classes
53 |
54 | [`class KeyValue`](../../tfgnn/proto/Metadata/KeyValue.md)
55 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/contrastive_losses/README.md:
--------------------------------------------------------------------------------
1 | # Contrastive Losses
2 |
3 | ## Overview
4 |
5 | This code implements and collects various contrastive losses for
6 | self-supervised learning. This code is under *active development*. An overview
7 | of the included losses:
8 |
9 | ### Deep Graph Infomax
10 |
11 | Deep Graph Infomax [1] attempts to learn a bilinear
12 | layer capable of discriminating between positive examples (any input
13 | `GraphTensor`) and negative examples (the input `GraphTensor` but with perturbed
14 | features: this implementation, as in the original paper, shuffles features
15 | across the batch, that is, across the components of the merged `GraphTensor`).
16 |
17 | Deep Graph Infomax is particularly useful in unsupervised tasks that wish to
18 | learn latent representations informed primarily by a node's neighborhood
19 | attributes (vs. its structure).
20 |
21 | * [1] Petar Veličković, William Fedus, William L. Hamilton, Pietro Liò,
22 | Yoshua Bengio, R Devon Hjelm:
23 | ["Deep Graph Infomax"](https://arxiv.org/abs/1809.10341), 2018.
24 |
25 | ## Usage
26 |
27 | TensorFlow programs can import and use this model as described in its
28 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses.md).
29 |
30 | ## API stability
31 |
32 | The API of this model may change between OSS library versions.
33 |
34 |
35 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/SingleInputNextState.md:
--------------------------------------------------------------------------------
1 | # tfgnn.keras.layers.SingleInputNextState
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Replaces a state from a single input.
10 |
11 |
12 | tfgnn.keras.layers.SingleInputNextState(
13 | trainable=True, name=None, dtype=None, dynamic=False, **kwargs
14 | )
15 |
16 |
17 |
18 |
19 | In a NodeSetUpdate, it replaces the node state with a single edge set input. For
20 | an EdgeSetUpdate, it replaces the edge_state with the incident node set's input.
21 | For a ContextUpdate, it replaces the context state with a single node set input.
22 |
23 | This layer can be restored from config by `tf.keras.models.load_model()` when
24 | saved as part of a Keras model using `save_format="tf"`.
25 |
26 |
27 |
28 |
29 | Call returns |
30 |
31 | |
32 | A tensor to use as the new state.
33 | |
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/tensorflow_gnn/graph/dict_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utilities for Python dictionaries."""
16 |
17 | from typing import Any, Dict, Mapping, MutableMapping
18 |
19 |
def with_key_prefix(d: Mapping[str, Any], key_prefix: str) -> Dict[str, Any]:
  """Returns a copy of `d` with `key_prefix` prepended to every key.

  The input mapping `d` is not modified.

  Args:
    d: The mapping whose keys get prefixed.
    key_prefix: The string to prepend to each key.

  Returns:
    A new dict equal to `{key_prefix+k: v for k, v in d.items()}`, in the
    iteration order of `d`.
  """
  prefixed = {}
  for key, value in d.items():
    prefixed[key_prefix + key] = value
  return prefixed
23 |
24 |
def pop_by_prefix(
    d: MutableMapping[str, Any], key_prefix: str) -> Dict[str, Any]:
  """Removes the entries of `d` whose key starts with `key_prefix`.

  Args:
    d: The mapping to modify in place.
    key_prefix: Entries whose key starts with this string are removed from `d`.

  Returns:
    A dict mapping each removed key, with `key_prefix` stripped off, to its
    value; i.e., `{k: v for key_prefix+k, v in d.items()}` as of before the
    call.
  """
  prefix_len = len(key_prefix)
  removed = {}
  # Iterate over a snapshot of the keys, because `d` shrinks as we go.
  for full_key in tuple(d):
    if not full_key.startswith(key_prefix):
      continue
    removed[full_key[prefix_len:]] = d.pop(full_key)
  return removed
33 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/proto/Feature.md:
--------------------------------------------------------------------------------
1 | # tfgnn.proto.Feature
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | The schema entry for a single feature.
10 |
11 |
12 | View aliases
13 |
14 | Main aliases
15 |
`tfgnn.Feature`
16 |
17 |
18 |
19 |
20 |
21 | For detailed documentation, see the comments in the `graph_schema.proto` file.
22 |
23 |
24 |
25 |
26 |
27 | Attributes |
28 |
29 |
30 |
31 | description
32 | |
33 |
34 | string description
35 | |
36 |
37 |
38 | dtype
39 | |
40 |
41 | DataType dtype
42 | |
43 |
44 |
45 | shape
46 | |
47 |
48 | TensorShapeProto shape
49 | |
50 |
51 |
52 | source
53 | |
54 |
55 | string source
56 | |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/gat_v2.md:
--------------------------------------------------------------------------------
1 | # Module: gat_v2
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Graph Attention Networks v2.
10 |
11 | Users of TF-GNN can use this model by importing it next to the core library as
12 |
13 | ```python
14 | import tensorflow_gnn as tfgnn
15 | from tensorflow_gnn.models import gat_v2
16 | ```
17 |
18 | ## Classes
19 |
20 | [`class GATv2Conv`](./gat_v2/GATv2Conv.md): The multi-head attention from Graph
21 | Attention Networks v2 (GATv2).
22 |
23 | ## Functions
24 |
25 | [`GATv2EdgePool(...)`](./gat_v2/GATv2EdgePool.md): Returns a layer for pooling
26 | edges with GATv2-style attention.
27 |
28 | [`GATv2HomGraphUpdate(...)`](./gat_v2/GATv2HomGraphUpdate.md): Returns a
29 | GraphUpdate layer with a Graph Attention Network V2 (GATv2).
30 |
31 | [`GATv2MPNNGraphUpdate(...)`](./gat_v2/GATv2MPNNGraphUpdate.md): Returns a
32 | GraphUpdate layer for message passing with GATv2 pooling.
33 |
34 | [`graph_update_from_config_dict(...)`](./gat_v2/graph_update_from_config_dict.md):
35 | Returns a GATv2MPNNGraphUpdate initialized from `cfg`.
36 |
37 | [`graph_update_get_config_dict(...)`](./gat_v2/graph_update_get_config_dict.md):
38 | Returns ConfigDict for graph_update_from_config_dict() with defaults.
39 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/parse_schema.md:
--------------------------------------------------------------------------------
1 | # tfgnn.parse_schema
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Parse a schema from text-formatted protos.
10 |
11 |
12 | tfgnn.parse_schema(
13 | schema_text: Union[bytes, str]
14 | ) -> tfgnn.proto.GraphSchema
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | Args |
24 |
25 |
26 |
27 | schema_text
28 | |
29 |
30 | A string containing a text-formatted protocol buffer rendition
31 | of a GraphSchema message.
32 | |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 | Returns |
41 |
42 |
43 | A GraphSchema instance.
44 | |
45 |
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/read_schema.md:
--------------------------------------------------------------------------------
1 | # tfgnn.read_schema
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Read a proto schema from a file with text-formatted contents.
10 |
11 |
12 | tfgnn.read_schema(
13 | filename: str
14 | ) -> tfgnn.proto.GraphSchema
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | Args |
24 |
25 |
26 |
27 | filename
28 | |
29 |
30 | A string, the path to a file containing a text-formatted protocol
31 | buffer rendition of a GraphSchema message.
32 | |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 | Returns |
41 |
42 |
43 | A GraphSchema instance.
44 | |
45 |
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/vanilla_mpnn/hparams_vizier_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for hparams_vizier."""
16 |
17 | from absl.testing import absltest
18 | from tensorflow_gnn.models.vanilla_mpnn import hparams_vizier
19 |
20 | from vizier.service import pyvizier as vz
21 |
22 |
class HparamsVizierTest(absltest.TestCase):
  """Tests the Vizier search-space helpers in `hparams_vizier`."""

  def test_regularization(self):
    """Checks that the regularization params are added under the prefix."""
    expected_names = ["foo.dropout_rate", "foo.l2_regularization"]
    statement = vz.ProblemStatement()
    search_space = statement.search_space
    hparams_vizier.add_params_regularization(search_space, prefix="foo.")
    actual_names = [param.name for param in search_space.parameters]
    self.assertCountEqual(actual_names, expected_names)


if __name__ == "__main__":
  absltest.main()
36 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/ModelExporter.md:
--------------------------------------------------------------------------------
1 | # runner.ModelExporter
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Saves a Keras model.
10 |
11 |
12 |
13 | ## Methods
14 |
15 | save
16 |
17 | View
18 | source
19 |
20 |
21 | @abc.abstractmethod
22 | save(
23 | run_result: RunResult, export_dir: str
24 | )
25 |
26 |
27 | Saves a Keras model.
28 |
29 | All persistence decisions are left to the implementation: e.g., a Keras model
30 | with full API or a simple `tf.train.Checkpoint` may be saved.
31 |
32 |
33 |
34 |
35 |
36 | | Args |
37 |
38 |
39 |
40 | run_result
41 | |
42 |
43 | A RunResult from training.
44 | |
45 |
46 |
47 | export_dir
48 | |
49 |
50 | A destination directory.
51 | |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/tensorflow_gnn/runner/utils/model_dir.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Model directory methods."""
16 | import os
17 |
18 | import tensorflow as tf
19 |
20 |
def incrementing_model_dir(dirname: str, start: int = 0) -> str:
  """Returns the next incrementing model directory path under `dirname`.

  Entries of `dirname` whose names are integers (e.g. "0", "7") are taken into
  account; all other entries are ignored. The returned directory is not
  created by this function.

  Args:
    dirname: The base directory name.
    start: The integer to use if `dirname` does not exist or contains no
      integer-named entries.

  Returns:
    A model directory `dirname/n`, where `n` is one more than the largest
    integer-named entry in `dirname`, or `start` if there is none.
  """
  if not tf.io.gfile.isdir(dirname):
    return os.path.join(dirname, str(start))
  # Tolerate a trailing "/" in case a filesystem backend reports subdirectories
  # that way. `isdecimal()` (unlike `isdigit()`) accepts only characters that
  # `int()` can parse; e.g. "²".isdigit() is True but int("²") raises.
  names = (entry.rstrip("/") for entry in tf.io.gfile.listdir(dirname))
  integers = [int(name) for name in names if name.isdecimal()]
  return os.path.join(dirname, str(max(integers) + 1 if integers else start))
36 |
--------------------------------------------------------------------------------
/tensorflow_gnn/data/BUILD:
--------------------------------------------------------------------------------
1 | # Copybara rewrites load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load
3 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_contrib_test")
4 |
5 | # buildifier: disable=out-of-order-load, disable=same-origin-load
6 | load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
7 |
8 | licenses(["notice"])
9 |
10 | package(
11 | default_applicable_licenses = ["//tensorflow_gnn:license"],
12 | default_visibility = ["//visibility:public"],
13 | )
14 | # Extra deps appended to :unigraph; empty here (presumably filled in for internal builds - TODO confirm).
15 | GOOGLE_INTERNAL_UNIGRAPH_DEPENDENCIES = []
16 | # Library target for unigraph.py (depends on Apache Beam, PyArrow and TensorFlow).
17 | pytype_strict_library(
18 | name = "unigraph",
19 | srcs = ["unigraph.py"],
20 | deps = [
21 | "//third_party/py/apache_beam",
22 | "//third_party/py/pyarrow",
23 | "//:expect_tensorflow_installed",
24 | "//tensorflow_gnn",
25 | ] + GOOGLE_INTERNAL_UNIGRAPH_DEPENDENCIES,
26 | )
27 | # Tests for :unigraph, with the homogeneous and heterogeneous test graphs as data.
28 | pytype_strict_contrib_test(
29 | name = "unigraph_test",
30 | srcs = ["unigraph_test.py"],
31 | data = [
32 | "@tensorflow_gnn//testdata/heterogeneous",
33 | "@tensorflow_gnn//testdata/homogeneous",
34 | ],
35 | deps = [
36 | ":unigraph",
37 | "//:expect_absl_installed_testing",
38 | "//third_party/py/apache_beam",
39 | "//:expect_tensorflow_installed",
40 | "//tensorflow_gnn",
41 | "//tensorflow_gnn/utils:test_utils",
42 | ],
43 | )
44 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/PassthruDatasetProvider.md:
--------------------------------------------------------------------------------
1 | # runner.PassthruDatasetProvider
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Builds a `tf.data.Dataset` from a pass thru dataset.
10 |
11 | Inherits From: [`DatasetProvider`](../runner/DatasetProvider.md)
12 |
13 |
14 | runner.PassthruDatasetProvider(
15 | dataset: tf.data.Dataset,
16 | *,
17 | shuffle_datasets: bool = False,
18 | examples_shuffle_size: Optional[int] = None
19 | )
20 |
21 |
22 |
23 |
24 | Passes any `dataset` through, omitting any sharding. For detailed documentation,
25 | see its filename-based counterpart: `SimpleDatasetProvider`.
26 |
27 | ## Methods
28 |
29 | get_dataset
30 |
31 | View
32 | source
33 |
34 |
35 | get_dataset(
36 | _: tf.distribute.InputContext
37 | ) -> tf.data.Dataset
38 |
39 |
40 | Gets a `tf.data.Dataset` omitting any input context.
41 |
--------------------------------------------------------------------------------
/tensorflow_gnn/graph/dict_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for dict_utils."""
16 |
17 | from absl.testing import absltest
18 | from tensorflow_gnn.graph import dict_utils
19 |
20 |
class KeyPrefixTest(absltest.TestCase):
  """Tests for the key-prefix helpers in `dict_utils`."""

  def testWithKeyPrefix(self):
    """with_key_prefix() returns a prefixed copy and leaves its input alone."""
    original = {"a": 1, "b": 2}
    prefixed = dict_utils.with_key_prefix(original, "p/")
    # The input must be left untouched.
    self.assertDictEqual(original, {"a": 1, "b": 2})
    self.assertDictEqual(prefixed, {"p/a": 1, "p/b": 2})

  def testPopByPrefix(self):
    """pop_by_prefix() strips the prefix and removes matches in place."""
    source = {"p/a": 1, "p/b": 2, "q/c": 3, "q/d": 4}
    popped = dict_utils.pop_by_prefix(source, "p/")
    # Matching entries are removed from the input in place.
    self.assertDictEqual(source, {"q/c": 3, "q/d": 4})
    self.assertDictEqual(popped, {"a": 1, "b": 2})


if __name__ == "__main__":
  absltest.main()
38 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/proto/GraphSchema.md:
--------------------------------------------------------------------------------
1 | # tfgnn.proto.GraphSchema
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | The top-level container for the schema of a graph dataset.
10 |
11 |
12 | View aliases
13 |
14 | Main aliases
15 |
`tfgnn.GraphSchema`
16 |
17 |
18 |
19 |
20 |
21 | For detailed documentation, see the comments in the `graph_schema.proto` file.
22 |
23 |
24 |
25 |
26 |
27 | Attributes |
28 |
29 |
30 |
31 | context
32 | |
33 |
34 | Context context
35 | |
36 |
37 |
38 | edge_sets
39 | |
40 |
41 | repeated EdgeSetsEntry edge_sets
42 | |
43 |
44 |
45 | info
46 | |
47 |
48 | OriginInfo info
49 | |
50 |
51 |
52 | node_sets
53 | |
54 |
55 | repeated NodeSetsEntry node_sets
56 | |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/examples/schemas/latent.pbtxt:
--------------------------------------------------------------------------------
1 | # An example graph schema with some latent nodes.
2 | # This defines three node sets and three edge sets:
3 | # dish -> restaurant -> user
4 | # dish ---------------> user
5 | # with only the 'dish' node type having features.
6 |
7 | node_sets {
8 | key: "dish"
9 | value {
10 | description: "A canonical dish"
11 | features {
12 | key: "embedding"
13 | value: {
14 | description: "Some dish embedding feature"
15 | dtype: DT_FLOAT
16 | shape: {
17 | dim {
18 | size: 32
19 | }
20 | }
21 | }
22 | }
23 | }
24 | }
25 |
26 | # This is how you define a latent set of nodes.
27 | node_sets {
28 | key: "restaurant"
29 | value {
30 | description: "A restaurant that offers dishes"
31 | }
32 | }
33 |
34 | node_sets {
35 | key: "user"
36 | value {
37 | description: "A person who visits restaurants and orders dishes"
38 | }
39 | }
40 |
41 | edge_sets {
42 | key: "dish->restaurant"
43 | value {
44 | description: "Dishes available at restaurants"
45 | source: "dish"
46 | target: "restaurant"
47 | }
48 | }
49 |
50 | edge_sets {
51 | key: "restaurant->user"
52 | value {
53 | description: "Restaurant visits by users"
54 | source: "restaurant"
55 | target: "user"
56 | }
57 | }
58 |
59 | edge_sets {
60 | key: "dish->user"
61 | value {
62 | description: "Dishes the user has ordered in the past"
63 | source: "dish"
64 | target: "user"
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/incrementing_model_dir.md:
--------------------------------------------------------------------------------
1 | # runner.incrementing_model_dir
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Create, given some `dirname`, an incrementing model directory.
10 |
11 |
12 | runner.incrementing_model_dir(
13 | dirname: str, start: int = 0
14 | ) -> str
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | Args |
23 |
24 |
25 |
26 | dirname
27 | |
28 |
29 | The base directory name.
30 | |
31 |
32 |
33 | start
34 | |
35 |
36 | The starting integer.
37 | |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 | Returns |
46 |
47 |
48 | A model directory dirname/n where 'n' is one more than the largest integer-named entry in dirname, or start if there is none.
49 | |
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/gcn/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Graph Convolutional Networks.
16 |
17 | Users of TF-GNN can use this model by importing it next to the core library as
18 |
19 | ```python
20 | import tensorflow_gnn as tfgnn
21 | from tensorflow_gnn.models import gcn
22 | ```
23 | """
24 | from tensorflow_gnn.models.gcn import gcn_conv
25 | from tensorflow_gnn.utils import api_utils
26 |
27 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py.
28 | # Please see there for instructions how to reflect API changes.
29 | # LINT.IfChange
30 | # Public API symbols of this package, re-exported from the gcn_conv module.
31 | GCNConv = gcn_conv.GCNConv
32 | GCNHomGraphUpdate = gcn_conv.GCNHomGraphUpdate
33 |
34 | # Remove all names added by module imports, unless explicitly allowed here.
35 | api_utils.remove_submodules_except(__name__, [])  # Empty allow-list.
36 | # LINT.ThenChange(../../api_def/gcn-symbols.txt)
37 |
--------------------------------------------------------------------------------
/tensorflow_gnn/proto/BUILD:
--------------------------------------------------------------------------------
1 | # Copybara rewrites some of these load() statements back and forth; do not reformat.
2 | # buildifier: disable=out-of-order-load, disable=same-origin-load
3 | load("@rules_proto//proto:defs.bzl", "proto_library")
4 |
5 | # buildifier: disable=out-of-order-load, disable=same-origin-load
6 | load("@rules_python//python:proto.bzl", "py_proto_library")
7 |
8 | package(default_visibility = ["//visibility:public"])
9 |
10 | licenses(["notice"]) # Apache 2.0
11 | # Package-level Python target for tensorflow_gnn/proto (its __init__.py).
12 | py_library(
13 | name = "proto",
14 | srcs = ["__init__.py"],
15 | srcs_version = "PY3",
16 | visibility = ["//visibility:public"],
17 | deps = [
18 | ":graph_schema",
19 | "//tensorflow_gnn/utils:api_utils",
20 | ],
21 | )
22 | # Protocol buffer definitions from graph_schema.proto (the graph schema).
23 | proto_library(
24 | name = "graph_schema_proto",
25 | srcs = ["graph_schema.proto"],
26 | deps = [
27 | "@org_tensorflow//tensorflow/core:protos_all",
28 | ],
29 | )
30 | # Generated Python bindings for :graph_schema_proto.
31 | py_proto_library(
32 | name = "graph_schema_py_proto",
33 | deps = [
34 | ":graph_schema_proto",
35 | ],
36 | )
37 | # Python library graph_schema.py, built on the generated bindings above.
38 | py_library(
39 | name = "graph_schema",
40 | srcs = ["graph_schema.py"],
41 | srcs_version = "PY3",
42 | deps = [
43 | ":graph_schema_py_proto",
44 | ],
45 | )
46 | # Protocol buffer definitions from examples.proto.
47 | proto_library(
48 | name = "examples_proto",
49 | srcs = ["examples.proto"],
50 | deps = ["@org_tensorflow//tensorflow/core:protos_all"],
51 | )
52 | # Generated Python bindings for :examples_proto.
53 | py_proto_library(
54 | name = "examples_py_proto",
55 | deps = [
56 | ":examples_proto",
57 | ],
58 | )
59 |
--------------------------------------------------------------------------------
/tensorflow_gnn/tools/generate_training_data_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Unit tests for generate_training_data."""
16 |
17 | from os import path
18 |
19 | from absl import flags
20 | import tensorflow as tf
21 | from tensorflow_gnn.tools import generate_training_data
22 | from tensorflow_gnn.utils import test_utils
23 |
24 |
25 | FLAGS = flags.FLAGS
26 |
27 |
class GenerateDataTest(tf.test.TestCase):
  """End-to-end test of `generate_training_data`."""

  def test_generate_training_data(self):
    """Generates data for the MPNN example schema; checks the output exists."""
    schema_path = test_utils.get_resource("examples/schemas/mpnn.pbtxt")
    output_path = path.join(FLAGS.test_tmpdir, "examples.tfrecords")
    generate_training_data.generate_training_data(
        schema_path, output_path, "tfrecord", 64)
    self.assertTrue(path.exists(output_path))


if __name__ == "__main__":
  tf.test.main()
40 |
--------------------------------------------------------------------------------
/examples/schemas/mpnn.pbtxt:
--------------------------------------------------------------------------------
1 | # An example graph schema matching demo graphs from MPNN modeling.
2 | # //third_party/py/tensorflow_gnn/proto/graph_schema.proto:GraphSchema
3 |
4 | context {
5 | features {
6 | key: "embedding"
7 | value: {
8 | description: "Global feature vector"
9 | dtype: DT_FLOAT
10 | shape: { dim { size: 128 } }
11 | }
12 | }
13 | }
14 |
15 | node_sets {
16 | key: "videos"
17 | value {
18 | features {
19 | key: "features"
20 | value: {
21 | description: "Encoded video features vector"
22 | dtype: DT_FLOAT
23 | shape: { dim { size: 256 } }
24 | }
25 | }
26 | }
27 | }
28 |
29 | node_sets {
30 | key: "channels"
31 | value {
32 | description: "User or Channel in YouTube."
33 | context: "embedding"
34 |
35 | features {
36 | key: "features"
37 | value: {
38 | description: "Encoded channel features vector"
39 | dtype: DT_FLOAT
40 | shape: { dim { size: 128 } }
41 | }
42 | }
43 | features {
44 | key: "labels"
45 | value: {
46 | description: "Multiple ground truth text labels"
47 | dtype: DT_STRING
48 | shape: { dim { size: -1 } }
49 | }
50 | }
51 | }
52 | }
53 |
54 | edge_sets {
55 | key: "videos->channels"
56 | value {
57 | features {
58 | key: "embedding"
59 | value: {
60 | description: "Encoded edge features vector"
61 | dtype: DT_FLOAT
62 | shape: { dim { size: 32 } }
63 | }
64 | }
65 | source: "videos"
66 | target: "channels"
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/proto/EdgeSet.md:
--------------------------------------------------------------------------------
1 | # tfgnn.proto.EdgeSet
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | The schema shared by a set of edges that connect the same pair of node sets.
10 |
11 |
12 |
13 | For detailed documentation, see the comments in the `graph_schema.proto` file.
14 |
15 |
16 |
17 |
18 |
19 | Attributes |
20 |
21 |
22 |
23 | context
24 | |
25 |
26 | repeated string context
27 | |
28 |
29 |
30 | description
31 | |
32 |
33 | string description
34 | |
35 |
36 |
37 | features
38 | |
39 |
40 | repeated FeaturesEntry features
41 | |
42 |
43 |
44 | metadata
45 | |
46 |
47 | Metadata metadata
48 | |
49 |
50 |
51 | source
52 | |
53 |
54 | string source
55 | |
56 |
57 |
58 | target
59 | |
60 |
61 | string target
62 | |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/DropoutFeatures.md:
--------------------------------------------------------------------------------
1 | # contrastive_losses.DropoutFeatures
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Base class for graph corruptor.
10 |
11 | Inherits From: [`Corruptor`](../contrastive_losses/Corruptor.md)
12 |
13 |
14 | contrastive_losses.DropoutFeatures(
15 | *args, seed: Optional[float] = None, **kwargs
16 | )
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | Args |
25 |
26 |
27 |
28 | corruption_spec
29 | |
30 |
31 | A spec for corruption application.
32 | |
33 |
34 |
35 | corruption_fn
36 | |
37 |
38 | Corruption function.
39 | |
40 |
41 |
42 | default
43 | |
44 |
45 | Global application default of the corruptor. This is only used
46 | when corruption_spec is None.
47 | |
48 |
49 |
50 | **kwargs
51 | |
52 |
53 | Additional keyword arguments.
54 | |
55 |
56 |
57 |
--------------------------------------------------------------------------------
/tensorflow_gnn/keras/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The tfgnn.keras package."""
16 |
17 | from tensorflow_gnn.keras import builders
18 | from tensorflow_gnn.keras import initializers
19 | from tensorflow_gnn.keras import keras_tensors # To register the types. pylint: disable=unused-import
20 | from tensorflow_gnn.keras import layers # Exposed as submodule. pylint: disable=unused-import
21 | from tensorflow_gnn.utils import api_utils
22 |
23 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py.
24 | # Please see there for instructions how to reflect API changes.
25 | # LINT.IfChange
26 | # Public API symbols of tfgnn.keras, re-exported from their defining modules.
27 | ConvGNNBuilder = builders.ConvGNNBuilder
28 | clone_initializer = initializers.clone_initializer
29 |
30 | # Remove all names added by module imports, unless explicitly allowed here.
31 | api_utils.remove_submodules_except(__name__, [
32 | "layers",  # Kept: exposed as the submodule tfgnn.keras.layers.
33 | ])
34 | # LINT.ThenChange(../api_def/tfgnn-symbols.txt)
35 |
--------------------------------------------------------------------------------
/examples/sampler/creditcard/graph_schema.pbtxt:
--------------------------------------------------------------------------------
1 | # proto-file: tensorflow_gnn/proto/graph_schema.proto
2 | # proto-message: GraphSchema
3 | # Example graph of transactions, credit cards, customer.
4 |
5 | node_sets {
6 | key: "customer"
7 | value {
8 | features {
9 | key: "name"
10 | value: {
11 | description: "Name"
12 | dtype: DT_STRING
13 | }
14 | }
15 | features {
16 | key: "address"
17 | value: {
18 | description: "address"
19 | dtype: DT_STRING
20 | }
21 | }
22 | features {
23 | key: "zipcode"
24 | value: {
25 | description: "Zipcode"
26 | dtype: DT_INT64
27 | }
28 | }
29 | features {
30 | key: "score"
31 | value: {
32 | description: "Credit score"
33 | dtype: DT_FLOAT
34 | }
35 | }
36 | metadata {
37 | filename: "customer.csv"
38 | }
39 | }
40 | }
41 |
42 | node_sets {
43 | key: "creditcard"
44 | value {
45 | metadata {
46 | filename: "creditcard.csv"
47 | }
48 | features {
49 | key: "number"
50 | value: {
51 | description: "Credit card number"
52 | dtype: DT_INT64
53 | }
54 | }
55 | features {
56 | key: "issuer"
57 | value: {
58 | description: "Credit card issuer institution"
59 | dtype: DT_STRING
60 | }
61 | }
62 | }
63 | }
64 |
65 | edge_sets {
66 | key: "owns_card"
67 | value {
68 | description: "Owns and uses the credit card."
69 | source: "customer"
70 | target: "creditcard"
71 | metadata {
72 | filename: "owns_card.csv"
73 | }
74 | }
75 | }
--------------------------------------------------------------------------------
/tensorflow_gnn/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Experimental (unstable) parts of the public interface of TensorFlow GNN.
16 |
17 | A symbol `foo` exposed here is available to library users as
18 |
19 | ```
20 | import tensorflow_gnn as tfgnn
21 |
22 | tfgnn.experimental.foo()
23 | ```
24 |
25 | This is the preferred way to expose individual functions on track to inclusion
26 | into the stable public interface of TensorFlow GNN.
27 |
28 | Beyond these symbols, there are also experimental sub-libraries that
29 | need to be imported separately (`from tensorflow_gnn.experimental import foo`).
30 | That is for special cases only.
31 | """
32 |
33 | from tensorflow_gnn.graph import readout
34 | from tensorflow_gnn.graph import tensor_utils
35 |
# Symbols exposed to users as tfgnn.experimental.<name>.
context_readout_into_feature = readout.context_readout_into_feature
segment_random_index_shuffle = tensor_utils.segment_random_index_shuffle

# Unbind the source modules so they are not attributes of this package.
del readout, tensor_utils
41 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/Corruptor.md:
--------------------------------------------------------------------------------
1 | # contrastive_losses.Corruptor
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Base class for graph corruptor.
10 |
11 |
12 | contrastive_losses.Corruptor(
13 | corruption_spec: Optional[CorruptionSpec[T]] = None,
14 | *,
15 | corruption_fn: Callable[[tfgnn.Field, T], tfgnn.Field],
16 | default: Optional[T] = None,
17 | **kwargs
18 | )
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 | Args |
27 |
28 |
29 |
30 | corruption_spec
31 | |
32 |
33 | A spec for corruption application.
34 | |
35 |
36 |
37 | corruption_fn
38 | |
39 |
40 | Corruption function.
41 | |
42 |
43 |
44 | default
45 | |
46 |
47 | Global application default of the corruptor. This is only used
48 | when corruption_spec is None.
49 | |
50 |
51 |
52 | **kwargs
53 | |
54 |
55 | Additional keyword arguments.
56 | |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/SizeConstraints.md:
--------------------------------------------------------------------------------
1 | # tfgnn.SizeConstraints
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Constraints on the number of entities in the graph.
10 |
11 |
12 | tfgnn.SizeConstraints(
13 | total_num_components,
14 | total_num_nodes,
15 | total_num_edges,
16 | min_nodes_per_component=()
17 | )
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | Attributes |
31 |
32 |
33 |
34 | total_num_components
35 | |
36 |
37 | A namedtuple alias for field number 0
38 | |
39 |
40 |
41 | total_num_nodes
42 | |
43 |
44 | A namedtuple alias for field number 1
45 | |
46 |
47 |
48 | total_num_edges
49 | |
50 |
51 | A namedtuple alias for field number 2
52 | |
53 |
54 |
55 | min_nodes_per_component
56 | |
57 |
58 | A namedtuple alias for field number 3
59 | |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/multi_head_attention/README.md:
--------------------------------------------------------------------------------
1 | # Transformer-style multi-head attention
2 |
3 | ## Overview
4 |
5 | This code implements transformer-style (dot-product) multi-head attention,
6 | with different variants and optional attention score leaks.
7 |
8 | Some publications in the GNN context that either use this multi-head attention
9 | as a component ([1] and [2]) or a baseline ([3]) of their method:
10 |
11 | * [1] Vijay Prakash Dwivedi, Xavier Bresson: ["A Generalization of Transformer
12 | Networks to Graphs"](https://arxiv.org/abs/2012.09699), 2021.
13 | (We only implement their attention, not their position encoding.)
14 | * [2] Dongkwan Kim, Alice Oh: ["How to Find Your Friendly Neighborhood: Graph
15 | Attention Design with Self-Supervision"](https://arxiv.org/abs/2204.04879),
16 | 2022. (They call it "DP" attention.)
17 | * [3] Shaked Brody, Uri Alon, Eran Yahav: ["How Attentive are Graph Attention
18 | Networks?"](https://arxiv.org/abs/2105.14491), 2021.
19 | (They discuss "DPGAT" as a baseline in the appendix, citing further uses.
20 | Their main contribution "GATv2" is implemented [elsewhere](../gat_v2)
21 | in TF-GNN.)
22 |
23 | ## Usage
24 |
25 | TensorFlow programs can import and use this model as described in its
26 | [API docs](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention.md).
27 |
28 | ## API stability
29 |
30 | The API of this model may change between OSS library versions.
31 |
32 | TF-GNN's [Model Template "Albis"](../mt_albis/README.md) offers a stable and
33 | simplified API for a subset of this model's configuration options.
34 |
35 |
--------------------------------------------------------------------------------
/testdata/homogeneous/tastelike.recordio.ascii:
--------------------------------------------------------------------------------
1 | features {
2 | feature {
3 | key: "#source"
4 | value: { bytes_list: { value: [ "amanatsu" ] } }
5 | }
6 | feature {
7 | key: "#target"
8 | value: { bytes_list: { value: [ "daidai" ] } }
9 | }
10 | feature {
11 | key: "weights"
12 | value: { float_list: { value: [ 0.1 ] } }
13 | }
14 | }
15 |
16 | features {
17 | feature {
18 | key: "#source"
19 | value: { bytes_list: { value: [ "amanatsu" ] } }
20 | }
21 | feature {
22 | key: "#target"
23 | value: { bytes_list: { value: [ "lumia" ] } }
24 | }
25 | feature {
26 | key: "weights"
27 | value: { float_list: { value: [ 0.2 ] } }
28 | }
29 | }
30 |
31 | features {
32 | feature {
33 | key: "#source"
34 | value: { bytes_list: { value: [ "kiyomi" ] } }
35 | }
36 | feature {
37 | key: "#target"
38 | value: { bytes_list: { value: [ "komikan" ] } }
39 | }
40 | feature {
41 | key: "weights"
42 | value: { float_list: { value: [ 0.3 ] } }
43 | }
44 | }
45 |
46 | features {
47 | feature {
48 | key: "#source"
49 | value: { bytes_list: { value: [ "mandora" ] } }
50 | }
51 | feature {
52 | key: "#target"
53 | value: { bytes_list: { value: [ "komikan" ] } }
54 | }
55 | feature {
56 | key: "weights"
57 | value: { float_list: { value: [ 0.4 ] } }
58 | }
59 | }
60 |
61 | features {
62 | feature {
63 | key: "#source"
64 | value: { bytes_list: { value: [ "mandora" ] } }
65 | }
66 | feature {
67 | key: "#target"
68 | value: { bytes_list: { value: [ "tangelo" ] } }
69 | }
70 | feature {
71 | key: "weights"
72 | value: { float_list: { value: [ 0.5 ] } }
73 | }
74 | }
75 |
--------------------------------------------------------------------------------
/testdata/heterogeneous/transactions.csv:
--------------------------------------------------------------------------------
1 | id,merchant,amount
2 | 5338667949,AcmeCorp,1156.80
3 | 2485246926,Plumbing Co,206.92
4 | 7574807079,HW Store Big,7.97
5 | 1935037613,Cons. Depot,2.38
6 | 3719329604,The Oil Tanker,172.99
7 | 1015902584,Light Inc,924.72
8 | 6360467569,Roofing Etc.,6.27
9 | 3672845476,Dig-a-Hole Co,91.68
10 | 7577999036,Smile Networks,257.00
11 | 4077264491,Jewel Corporation,1162.72
12 | 6932577020,Melon Systems,32.17
13 | 2952830085,Explority,1.02
14 | 4045329101,Apachicorp,1085.31
15 | 2187034790,Gnomelectrics,44.41
16 | 4071088591,Amazystems,19.96
17 | 3301719092,Herocast,329.38
18 | 3630070609,Gorillalife,689.12
19 | 7220792354,Desertcoms,2.19
20 | 9418989087,Yellow Sports,41.13
21 | 4751453767,Vertex Coms,30.05
22 | 5488583952,Ice Records,67.77
23 | 7738736966,Happindustries,37.15
24 | 7118748138,Antelligence,1.03
25 | 2507142984,Cycloration,299.41
26 | 6967991082,Herolutions,16.24
27 | 9099370006,Signcloud,9.63
28 | 2674535393,Diamondlight,18.76
29 | 8806410342,Accentmart,21.09
30 | 8140734876,Padlock Productions,59.83
31 | 2432513727,Buck Intelligence,132.79
32 | 5420798237,Root Aviation,648.87
33 | 6060907986,Joytechs,161.45
34 | 3315349239,Driftonics,3.66
35 | 5221186582,Honeydustries,21.87
36 | 8597703614,Tigressystems,204.30
37 | 2277618242,Cubebank,3.32
38 | 8179757481,Voidtechs,9.99
39 | 2561247267,Houndbooks,11.17
40 | 5675097657,Slick Co.,41.98
41 | 8790765633,Oak Softwares,213.11
42 | 4719730577,Hero Motors,7.86
43 | 9232165517,Elecoms,124.75
44 | 8334230482,Microlutions,12.26
45 | 5930862377,Tundracoustics,0.64
46 | 8603978315,Tempestechnologies,842.09
47 | 5339238148,Stormex,2.61
48 | 2346560845,Heartair,2.34
49 | 4415620884,Gnomebank,12.14
--------------------------------------------------------------------------------
/tensorflow_gnn/models/hgt/hparams_vizier_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for hparams_vizier."""
16 |
17 | from absl.testing import absltest
18 | from tensorflow_gnn.models.hgt import hparams_vizier
19 |
20 | from vizier.service import pyvizier as vz
21 |
22 |
23 | class HparamsVizierTest(absltest.TestCase):
24 |
25 | def test_regularization(self):
26 | problem = vz.ProblemStatement()
27 | hparams_vizier.add_params_regularization(
28 | problem.search_space, prefix="foo."
29 | )
30 | self.assertCountEqual(
31 | [p.name for p in problem.search_space.parameters], ["foo.dropout_rate"]
32 | )
33 |
34 | def test_hgt_attention(self):
35 | problem = vz.ProblemStatement()
36 | hparams_vizier.add_params_attention(
37 | problem.search_space, prefix="foo."
38 | )
39 | self.assertCountEqual(
40 | [p.name for p in problem.search_space.parameters], ["foo.num_heads"]
41 | )
42 |
43 |
44 | if __name__ == "__main__":
45 | absltest.main()
46 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/graph_tensor_to_values.md:
--------------------------------------------------------------------------------
1 | # tfgnn.graph_tensor_to_values
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Convert an eager `GraphTensor` to a mapping of mappings of plain-old data types (PODTs).
10 |
11 |
12 | tfgnn.graph_tensor_to_values(
13 | graph: tfgnn.GraphTensor
14 | ) -> Dict[str, Any]
15 |
16 |
17 |
18 |
19 |
20 |
21 | This is used for pretty-printing. Convert your graph tensor with this and run
22 | the result through `pprint.pprint()` or `pprint.pformat()` for display of its
23 | contents.
24 |
25 |
26 |
27 |
28 | Args |
29 |
30 |
31 |
32 | graph
33 | |
34 |
35 | An eager GraphTensor instance to be pprinted.
36 | |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 | Returns |
45 |
46 |
47 | A dict of plain-old data types that can be run through pprint.pprint() or
48 | a JSON conversion library.
49 | |
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/hgt/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Heterogeneous Graph Transformers.
16 |
17 | Users of TF-GNN can use this model by importing it next to the core library as
18 |
19 | ```python
20 | import tensorflow_gnn as tfgnn
21 | from tensorflow_gnn.models import hgt
22 | ```
23 | """
24 | from tensorflow_gnn.models.hgt import config_dict
25 | from tensorflow_gnn.models.hgt import layers
26 | from tensorflow_gnn.utils import api_utils
27 |
28 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py.
29 | # Please see there for instructions how to reflect API changes.
30 | # LINT.IfChange
31 |
32 | HGTGraphUpdate = layers.HGTGraphUpdate
33 | graph_update_get_config_dict = config_dict.graph_update_get_config_dict
34 | graph_update_from_config_dict = config_dict.graph_update_from_config_dict
35 |
36 | # Remove all names added by module imports, unless explicitly allowed here.
37 | api_utils.remove_submodules_except(__name__, [])
38 | # LINT.ThenChange(../../api_def/hgt-symbols.txt)
39 |
--------------------------------------------------------------------------------
/testdata/homogeneous/tastelike.sstable.ascii:
--------------------------------------------------------------------------------
1 | 0
2 | features {
3 | feature {
4 | key: "#source"
5 | value: { bytes_list: { value: [ "amanatsu" ] } }
6 | }
7 | feature {
8 | key: "#target"
9 | value: { bytes_list: { value: [ "daidai" ] } }
10 | }
11 | feature {
12 | key: "weights"
13 | value: { float_list: { value: [ 0.1 ] } }
14 | }
15 | }
16 |
17 | 1
18 | features {
19 | feature {
20 | key: "#source"
21 | value: { bytes_list: { value: [ "amanatsu" ] } }
22 | }
23 | feature {
24 | key: "#target"
25 | value: { bytes_list: { value: [ "lumia" ] } }
26 | }
27 | feature {
28 | key: "weights"
29 | value: { float_list: { value: [ 0.2 ] } }
30 | }
31 | }
32 |
33 | 2
34 | features {
35 | feature {
36 | key: "#source"
37 | value: { bytes_list: { value: [ "kiyomi" ] } }
38 | }
39 | feature {
40 | key: "#target"
41 | value: { bytes_list: { value: [ "komikan" ] } }
42 | }
43 | feature {
44 | key: "weights"
45 | value: { float_list: { value: [ 0.3 ] } }
46 | }
47 | }
48 |
49 | 3
50 | features {
51 | feature {
52 | key: "#source"
53 | value: { bytes_list: { value: [ "mandora" ] } }
54 | }
55 | feature {
56 | key: "#target"
57 | value: { bytes_list: { value: [ "komikan" ] } }
58 | }
59 | feature {
60 | key: "weights"
61 | value: { float_list: { value: [ 0.4 ] } }
62 | }
63 | }
64 |
65 | 4
66 | features {
67 | feature {
68 | key: "#source"
69 | value: { bytes_list: { value: [ "mandora" ] } }
70 | }
71 | feature {
72 | key: "#target"
73 | value: { bytes_list: { value: [ "tangelo" ] } }
74 | }
75 | feature {
76 | key: "weights"
77 | value: { float_list: { value: [ 0.5 ] } }
78 | }
79 | }
80 |
--------------------------------------------------------------------------------
/tensorflow_gnn/experimental/sampler/proto/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The protocol message (protobuf) types defined by TF-GNN Sampler."""
16 |
17 | from tensorflow_gnn.experimental.sampler.proto import eval_dag_pb2
18 | from tensorflow_gnn.utils import api_utils
19 |
20 |
21 | # Program computation DAG, its stages and layers.
22 | Program = eval_dag_pb2.Program
23 | EvalDAG = eval_dag_pb2.EvalDAG
24 | Stage = eval_dag_pb2.Stage
25 | Layer = eval_dag_pb2.Layer
26 |
27 | # Specifications of input/output values of layers.
28 | ValueSpec = eval_dag_pb2.ValueSpec
29 | TensorSpec = eval_dag_pb2.TensorSpec
30 | RaggedTensorSpec = eval_dag_pb2.RaggedTensorSpec
31 | FlattenedSpec = eval_dag_pb2.FlattenedSpec
32 |
33 | # Layer configs.
34 | EdgeSamplingConfig = eval_dag_pb2.EdgeSamplingConfig
35 | IOFeatures = eval_dag_pb2.IOFeatures
36 |
37 |
38 | # Remove all names added by module imports, unless explicitly allowed here.
39 | api_utils.remove_submodules_except(__name__, [])
40 | # LINT.ThenChange(../../../api_def/sampler-symbols.txt)
41 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention.md:
--------------------------------------------------------------------------------
1 | # Module: multi_head_attention
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Transformer-style multi-head attention.
10 |
11 | Users of TF-GNN can use this model by importing it next to the core library as
12 |
13 | ```python
14 | import tensorflow_gnn as tfgnn
15 | from tensorflow_gnn.models import multi_head_attention
16 | ```
17 |
18 | ## Classes
19 |
20 | [`class MultiHeadAttentionConv`](./multi_head_attention/MultiHeadAttentionConv.md):
21 | Transformer-style (dot-product) multi-head attention on GNNs.
22 |
23 | ## Functions
24 |
25 | [`MultiHeadAttentionEdgePool(...)`](./multi_head_attention/MultiHeadAttentionEdgePool.md):
26 | Returns a layer for pooling edges with Transformer-style Multi-Head Attention.
27 |
28 | [`MultiHeadAttentionHomGraphUpdate(...)`](./multi_head_attention/MultiHeadAttentionHomGraphUpdate.md):
29 | Returns a GraphUpdate layer with transformer-style multi-head attention.
30 |
31 | [`MultiHeadAttentionMPNNGraphUpdate(...)`](./multi_head_attention/MultiHeadAttentionMPNNGraphUpdate.md):
32 | Returns a GraphUpdate layer for message passing with MultiHeadAttention pooling.
33 |
34 | [`graph_update_from_config_dict(...)`](./multi_head_attention/graph_update_from_config_dict.md):
35 | Returns a MultiHeadAttentionMPNNGraphUpdate initialized from `cfg`.
36 |
37 | [`graph_update_get_config_dict(...)`](./multi_head_attention/graph_update_get_config_dict.md):
38 | Returns ConfigDict for graph_update_from_config_dict() with defaults.
39 |
--------------------------------------------------------------------------------
/tensorflow_gnn/graph/graph_tensor_pprint_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for pretty-printing."""
16 |
17 | import pprint
18 |
19 | import tensorflow as tf
20 | from tensorflow_gnn.graph import graph_tensor_pprint as gpp
21 | from tensorflow_gnn.graph import graph_tensor_random as gr
22 | from tensorflow_gnn.graph import schema_utils as su
23 | import tensorflow_gnn.proto.graph_schema_pb2 as schema_pb2
24 | from tensorflow_gnn.utils import test_utils
25 |
26 |
27 | class TestConvertForPprint(tf.test.TestCase):
28 |
29 | def test_graph_tensor_to_values(self):
30 | schema = test_utils.get_proto_resource(
31 | 'testdata/feature_repr.pbtxt', schema_pb2.GraphSchema())
32 | spec = su.create_graph_spec_from_schema_pb(schema)
33 | graph = gr.random_graph_tensor(spec)
34 | values = gpp.graph_tensor_to_values(graph)
35 | text = pprint.pformat(values)
36 | # This just ensures there is no error in the generation.
37 | self.assertIsInstance(text, str)
38 |
39 |
40 | if __name__ == '__main__':
41 | tf.test.main()
42 |
--------------------------------------------------------------------------------
/testdata/heterogeneous/paid_with.csv:
--------------------------------------------------------------------------------
1 | target,source,retries
2 | 5338667949,1238474857489384,0
3 | 2485246926,12968701241275060,0
4 | 7574807079,12441028369470600,0
5 | 1935037613,12968701241275060,0
6 | 3719329604,14990890937985390,0
7 | 1015902584,14912408563871390,2
8 | 6360467569,18362223127059380,3
9 | 3672845476,3549061668422198,0
10 | 7577999036,11385846637304370,0
11 | 4077264491,14844931107602160,0
12 | 6932577020,14844931107602160,0
13 | 2952830085,11739198589848540,0
14 | 4045329101,13916484476264770,0
15 | 2187034790,8878522895102384,0
16 | 4071088591,14990890937985390,1
17 | 3301719092,11290312140467510,0
18 | 3630070609,13019350102369400,0
19 | 7220792354,18569067217418250,0
20 | 9418989087,3549061668422198,7
21 | 4751453767,4541017563963442,4
22 | 5488583952,1238474857489384,0
23 | 7738736966,11739198589848540,0
24 | 7118748138,18526138896540830,0
25 | 2507142984,17035680063294790,4
26 | 6967991082,16073125141142750,0
27 | 9099370006,17035680063294790,0
28 | 2674535393,9991040399813057,0
29 | 8806410342,8889177882781586,1
30 | 8140734876,11739198589848540,0
31 | 2432513727,11470379189154620,0
32 | 5420798237,11584989140147230,0
33 | 6060907986,14990890937985390,0
34 | 3315349239,18569067217418250,0
35 | 5221186582,11739198589848540,0
36 | 8597703614,11584989140147230,0
37 | 2277618242,11771673810809530,1
38 | 8179757481,8878522895102384,0
39 | 2561247267,16827485386298040,0
40 | 5675097657,18526138896540830,0
41 | 8790765633,16283233487191600,1
42 | 4719730577,14844931107602160,0
43 | 9232165517,16011471358128450,0
44 | 8334230482,15054318664602640,1
45 | 5930862377,12948957000457930,0
46 | 8603978315,14453480592564160,5
47 | 5339238148,11739198589848540,0
48 | 2346560845,11584989140147230,1
49 | 4415620884,12441028369470600,0
--------------------------------------------------------------------------------
/tensorflow_gnn/models/graph_sage/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """GraphSAGE.
16 |
17 | Users of TF-GNN can use this model by importing it next to the core library as
18 |
19 | ```python
20 | import tensorflow_gnn as tfgnn
21 | from tensorflow_gnn.models import graph_sage
22 | ```
23 | """
24 |
25 | from tensorflow_gnn.models.graph_sage import layers
26 | from tensorflow_gnn.utils import api_utils
27 |
28 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py.
29 | # Please see there for instructions how to reflect API changes.
30 | # LINT.IfChange
31 |
32 | GCNGraphSAGENodeSetUpdate = layers.GCNGraphSAGENodeSetUpdate
33 | GraphSAGEAggregatorConv = layers.GraphSAGEAggregatorConv
34 | GraphSAGEPoolingConv = layers.GraphSAGEPoolingConv
35 | GraphSAGENextState = layers.GraphSAGENextState
36 | GraphSAGEGraphUpdate = layers.GraphSAGEGraphUpdate
37 |
38 | # Remove all names added by module imports, unless explicitly allowed here.
39 | api_utils.remove_submodules_except(__name__, [])
40 | # LINT.ThenChange(../../api_def/graph_sage-symbols.txt)
41 |
--------------------------------------------------------------------------------
/examples/sampler/creditcard/creditcard.csv:
--------------------------------------------------------------------------------
1 | id,number,issuer
2 | 11238474857489380,11238474857489380,BofBC
3 | 14216252633958570,14216252633958570,HeyBank
4 | 14541017563963440,14541017563963440,BellsGarbo
5 | 13549061668422190,13549061668422190,BellsGarbo
6 | 12948957000457930,12948957000457930,GDBank
7 | 11163838768727470,11163838768727470,BellsGarbo
8 | 11191576325053580,11191576325053580,BofBC
9 | 11290312140467510,11290312140467510,GDBank
10 | 11385846637304370,11385846637304370,BellsGarbo
11 | 11470379189154620,11470379189154620,HeyBank
12 | 11584989140147230,11584989140147230,BofBC
13 | 11739198589848540,11739198589848540,GDBank
14 | 11771673810809530,11771673810809530,BellsGarbo
15 | 12441028369470600,12441028369470600,BofBC
16 | 12968701241275060,12968701241275060,BellsGarbo
17 | 12982257258547830,12982257258547830,BellsGarbo
18 | 13019350102369400,13019350102369400,BellsGarbo
19 | 13916484476264770,13916484476264770,GDBank
20 | 14453480592564160,14453480592564160,BofBC
21 | 14844931107602160,14844931107602160,BellsGarbo
22 | 14912408563871390,14912408563871390,BofBC
23 | 14990890937985390,14990890937985390,BellsGarbo
24 | 15054318664602640,15054318664602640,HeyBank
25 | 16011471358128450,16011471358128450,BellsGarbo
26 | 16073125141142750,16073125141142750,GDBank
27 | 16283233487191600,16283233487191600,BellsGarbo
28 | 16827485386298040,16827485386298040,BellsGarbo
29 | 17035680063294790,17035680063294790,BellsGarbo
30 | 17396883707513070,17396883707513070,BellsGarbo
31 | 17861046738135650,17861046738135650,HeyBank
32 | 18362223127059380,18362223127059380,GDBank
33 | 18526138896540830,18526138896540830,GDBank
34 | 18569067217418250,18569067217418250,GDBank
35 | 18878522895102380,18878522895102380,HeyBank
36 | 18889177882781580,18889177882781580,BofBC
37 | 19991040399813050,19991040399813050,BofBC
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/keras/layers/NextStateFromConcat.md:
--------------------------------------------------------------------------------
1 | # tfgnn.keras.layers.NextStateFromConcat
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Computes a new state by concatenating inputs and applying a Keras Layer.
10 |
11 |
12 | tfgnn.keras.layers.NextStateFromConcat(
13 | transformation: tf.keras.layers.Layer, **kwargs
14 | )
15 |
16 |
17 |
18 |
19 |
20 |
21 | This layer flattens all inputs into a list (forgetting their origin),
22 | concatenates them and sends them through a user-supplied feed-forward network.
23 |
24 | This layer can be restored from config by `tf.keras.models.load_model()` when
25 | saved as part of a Keras model using `save_format="tf"`.
26 |
27 |
28 |
29 |
30 | Init args |
31 |
32 |
33 |
34 | transformation
35 | |
36 |
37 | Required. A Keras Layer to transform the combined inputs
38 | into the new state.
39 | |
40 |
41 |
42 |
43 |
44 |
45 |
46 | Call returns |
47 |
48 | |
49 | The result of transformation.
50 | |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/assert_constraints.md:
--------------------------------------------------------------------------------
1 | # tfgnn.assert_constraints
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Validate the shape constraints of a graph's features at runtime.
10 |
11 |
12 | tfgnn.assert_constraints(
13 | graph: tfgnn.GraphTensor
14 | ) -> tf.Operation
15 |
16 |
17 |
18 |
19 |
20 |
21 | This code returns a TensorFlow op with debugging assertions that ensure the
22 | parsed data has valid shape constraints for a graph. This can be instantiated
23 | in your TensorFlow graph while debugging if you believe that your data may be
24 | incorrectly shaped, or simply applied to a manually produced dataset to ensure
25 | that those constraints have been applied correctly.
26 |
27 |
28 |
29 |
30 | Args |
31 |
32 |
33 |
34 | graph
35 | |
36 |
37 | An instance of a GraphTensor.
38 | |
39 |
40 |
41 |
42 |
43 |
44 |
45 | Returns |
46 |
47 | |
48 | A TensorFlow op (`tf.Operation`) with the debugging assertions.
49 | |
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/gat_v2/hparams_vizier_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for hparams_vizier."""
16 |
17 | from absl.testing import absltest
18 | from tensorflow_gnn.models.gat_v2 import hparams_vizier
19 |
20 | from vizier.service import pyvizier as vz
21 |
22 |
23 | class HparamsVizierTest(absltest.TestCase):
24 |
25 | def test_regularization(self):
26 | problem = vz.ProblemStatement()
27 | hparams_vizier.add_params_regularization(
28 | problem.search_space, prefix="foo.")
29 | self.assertCountEqual([p.name for p in problem.search_space.parameters], [
30 | "foo.state_dropout_rate", "foo.edge_dropout_rate",
31 | "foo.l2_regularization"
32 | ])
33 |
34 | def test_attention(self):
35 | problem = vz.ProblemStatement()
36 | hparams_vizier.add_params_attention(problem.search_space,
37 | prefix="foo.")
38 | self.assertCountEqual(
39 | [p.name for p in problem.search_space.parameters],
40 | ["foo.num_heads"])
41 |
42 |
43 | if __name__ == "__main__":
44 | absltest.main()
45 |
--------------------------------------------------------------------------------
/package/move_generated_files.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Moves the bazel generated files needed for packaging the wheel to the source
17 | # tree.
18 |
19 | function _is_windows() {
20 | [[ "$(uname -s | tr 'A-Z' 'a-z')" =~ (cygwin|mingw32|mingw64|msys)_nt* ]]
21 | }
22 |
23 | function tfgnn::move_generated_files() {
24 | if _is_windows; then
25 | # See https://github.com/bazelbuild/bazel/issues/6761 for bazel-bin.
26 | GENFILES=${BUILD_WORKSPACE_DIRECTORY}/bazel-genfiles
27 | if [[ ! -d ${GENFILES} ]]; then
28 | GENFILES=${BUILD_WORKSPACE_DIRECTORY}/bazel-bin
29 | fi
30 | else
31 | # If run by "bazel run", $(pwd) is the .runfiles dir that contains all the
32 | # data dependencies.
33 | GENFILES=$(pwd)
34 | fi
35 |
36 | FILES="
37 | tensorflow_gnn/experimental/sampler/proto/eval_dag_pb2.py
38 | tensorflow_gnn/proto/graph_schema_pb2.py
39 | tensorflow_gnn/proto/examples_pb2.py
40 | tensorflow_gnn/sampler/sampling_spec_pb2.py
41 | tensorflow_gnn/tools/sampled_stats_pb2.py
42 | "
43 | for FILE in ${FILES}; do
44 | cp -f ${GENFILES}/${FILE} ${BUILD_WORKSPACE_DIRECTORY}/${FILE}
45 | done
46 | }
47 |
48 | tfgnn::move_generated_files
49 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/ShuffleFeaturesGlobally.md:
--------------------------------------------------------------------------------
1 | # contrastive_losses.ShuffleFeaturesGlobally
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | A corruptor that shuffles features.
10 |
11 | Inherits From: [`Corruptor`](../contrastive_losses/Corruptor.md)
12 |
13 |
14 | contrastive_losses.ShuffleFeaturesGlobally(
15 | *args, seed: Optional[float] = None, **kwargs
16 | )
17 |
18 |
19 |
20 |
21 | NOTE: this corruptor does not currently support TPUs. Consider using other
22 | corruptors if executing on TPUs. See b/269249455 for reference.
23 |
24 |
25 |
26 |
27 |
28 | Args |
29 |
30 |
31 |
32 | corruption_spec
33 | |
34 |
35 | A spec for corruption application.
36 | |
37 |
38 |
39 | corruption_fn
40 | |
41 |
42 | Corruption function.
43 | |
44 |
45 |
46 | default
47 | |
48 |
49 | Global application default of the corruptor. This is only used
50 | when corruption_spec is None.
51 | |
52 |
53 |
54 | **kwargs
55 | |
56 |
57 | Additional keyword arguments.
58 | |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/testdata/heterogeneous/creditcard.csv:
--------------------------------------------------------------------------------
1 | id,number,issuer
2 | 11238474857489380,11238474857489380,BofBC
3 | 14216252633958570,14216252633958570,HeyBank
4 | 14541017563963440,14541017563963440,BellsGarbo
5 | 13549061668422190,13549061668422190,BellsGarbo
6 | 12948957000457930,12948957000457930,GDBank
7 | 11163838768727470,11163838768727470,BellsGarbo
8 | 11191576325053580,11191576325053580,BofBC
9 | 11290312140467510,11290312140467510,GDBank
10 | 11385846637304370,11385846637304370,BellsGarbo
11 | 11470379189154620,11470379189154620,HeyBank
12 | 11584989140147230,11584989140147230,BofBC
13 | 11739198589848540,11739198589848540,GDBank
14 | 11771673810809530,11771673810809530,BellsGarbo
15 | 12441028369470600,12441028369470600,BofBC
16 | 12968701241275060,12968701241275060,BellsGarbo
17 | 12982257258547830,12982257258547830,BellsGarbo
18 | 13019350102369400,13019350102369400,BellsGarbo
19 | 13916484476264770,13916484476264770,GDBank
20 | 14453480592564160,14453480592564160,BofBC
21 | 14844931107602160,14844931107602160,BellsGarbo
22 | 14912408563871390,14912408563871390,BofBC
23 | 14990890937985390,14990890937985390,BellsGarbo
24 | 15054318664602640,15054318664602640,HeyBank
25 | 16011471358128450,16011471358128450,BellsGarbo
26 | 16073125141142750,16073125141142750,GDBank
27 | 16283233487191600,16283233487191600,BellsGarbo
28 | 16827485386298040,16827485386298040,BellsGarbo
29 | 17035680063294790,17035680063294790,BellsGarbo
30 | 17396883707513070,17396883707513070,BellsGarbo
31 | 17861046738135650,17861046738135650,HeyBank
32 | 18362223127059380,18362223127059380,GDBank
33 | 18526138896540830,18526138896540830,GDBank
34 | 18569067217418250,18569067217418250,GDBank
35 | 18878522895102380,18878522895102380,HeyBank
36 | 18889177882781580,18889177882781580,BofBC
37 | 19991040399813050,19991040399813050,BofBC
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/iter_sets.md:
--------------------------------------------------------------------------------
1 | # tfgnn.iter_sets
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Utility function to iterate over all the sets present in a graph schema.
10 |
11 |
12 | tfgnn.iter_sets(
13 | schema: Union[tfgnn.proto.GraphSchema, tfgnn.GraphTensor]
14 | ) -> Iterator[Tuple[str, str, Any]]
15 |
16 |
17 |
18 |
19 | This function iterates over the context set, each of the node sets, and
20 | finally each of the edge sets.
21 |
22 |
23 |
24 |
25 | Args |
26 |
27 |
28 |
29 | schema
30 | |
31 |
32 | An instance of a GraphSchema proto message or a GraphTensor.
33 | |
34 |
35 |
36 |
37 |
38 |
39 |
40 | Yields |
41 |
42 | |
43 | Triplets of (set-type, set-name, features) where
44 |
45 | * set-type: A type of set, which is either of "context", "nodes" or "edges".
46 | * set-name: A string, the name of the set.
47 | * features: A dict of feature-name to feature-value.
48 | |
49 |
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/multi_head_attention/hparams_vizier_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for hparams_vizier."""
16 |
17 | from absl.testing import absltest
18 | from tensorflow_gnn.models.multi_head_attention import hparams_vizier
19 |
20 | from vizier.service import pyvizier as vz
21 |
22 |
class HparamsVizierTest(absltest.TestCase):
  """Checks which Vizier parameters the `hparams_vizier` helpers register."""

  def _param_names(self, problem):
    """Returns the names of all parameters in `problem`'s search space."""
    return [param.name for param in problem.search_space.parameters]

  def test_regularization(self):
    problem = vz.ProblemStatement()
    hparams_vizier.add_params_regularization(
        problem.search_space, prefix="foo.")
    expected_names = [
        "foo.state_dropout_rate",
        "foo.edge_dropout_rate",
        "foo.l2_regularization",
    ]
    self.assertCountEqual(self._param_names(problem), expected_names)

  def test_attention(self):
    problem = vz.ProblemStatement()
    hparams_vizier.add_params_attention(problem.search_space, prefix="foo.")
    self.assertCountEqual(self._param_names(problem), ["foo.num_heads"])


if __name__ == "__main__":
  absltest.main()
45 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/TightPadding.md:
--------------------------------------------------------------------------------
1 | # runner.TightPadding
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Calculates tight `SizeConstraints` for `GraphTensor` padding.
10 |
11 | Inherits From: [`GraphTensorPadding`](../runner/GraphTensorPadding.md)
12 |
13 |
14 | runner.TightPadding(
15 | gtspec: tfgnn.GraphTensorSpec,
16 | dataset_provider: runner.DatasetProvider,
17 | min_nodes_per_component: Optional[Mapping[str, int]] = None
18 | )
19 |
20 |
21 |
22 |
23 | See: `tfgnn.find_tight_size_constraints`.
24 |
25 | ## Methods
26 |
27 | get_filter_fn
28 |
29 | View
30 | source
31 |
32 |
33 | get_filter_fn(
34 | size_constraints: SizeConstraints
35 | ) -> Callable[..., bool]
36 |
37 |
38 | get_size_constraints
39 |
40 | View
41 | source
42 |
43 |
44 | get_size_constraints(
45 | target_batch_size: int
46 | ) -> SizeConstraints
47 |
48 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/PassthruSampleDatasetsProvider.md:
--------------------------------------------------------------------------------
1 | # runner.PassthruSampleDatasetsProvider
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Builds a sampled `tf.data.Dataset` from multiple pass thru datasets.
10 |
11 | Inherits From: [`DatasetProvider`](../runner/DatasetProvider.md)
12 |
13 |
14 | runner.PassthruSampleDatasetsProvider(
15 | principal_dataset: tf.data.Dataset,
16 | extra_datasets: Sequence[tf.data.Dataset],
17 | principal_weight: Optional[float] = None,
18 | extra_weights: Optional[Sequence[float]] = None,
19 | *,
20 | principal_cardinality: Optional[int] = None,
21 | fixed_cardinality: bool = False,
22 | shuffle_dataset: bool = False,
23 | examples_shuffle_size: Optional[int] = None
24 | )
25 |
26 |
27 |
28 |
29 | Passes `principal_dataset` and `extra_datasets` through as-is, without any sharding.
30 | For detailed documentation, see the filename dataset provider complement:
31 | `SimpleSampleDatasetsProvider`.
32 |
33 | ## Methods
34 |
35 | get_dataset
36 |
37 | View
38 | source
39 |
40 |
41 | get_dataset(
42 | _: tf.distribute.InputContext
43 | ) -> tf.data.Dataset
44 |
45 |
46 | Gets a sampled `tf.data.Dataset` omitting any input context.
47 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/add_self_loops.md:
--------------------------------------------------------------------------------
1 | # tfgnn.add_self_loops
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Adds self-loops for `edge_set_name` EVEN if they already exist.
10 |
11 |
12 | tfgnn.add_self_loops(
13 | graph: GraphTensor, edge_set_name: gt.EdgeSetName
14 | ) -> GraphTensor
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | Args |
24 |
25 |
26 |
27 | graph
28 | |
29 |
30 | A scalar GraphTensor.
31 | |
32 |
33 |
34 | edge_set_name
35 | |
36 |
37 | An edge set in graph that has the same node set as source
38 | and target.
39 | |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 | Returns |
48 |
49 | |
50 | A GraphTensor with self-loops added. A self-loop is added at each node,
51 | even if some or all of these nodes already have a loop. All feature tensors
52 | of the edge set are extended to cover the newly added edges with values
53 | that are all zeros (for numeric features), false (for boolean features), or
54 | empty (for string features), respectively.
55 | |
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/tensorflow_gnn/converters/triples_test.py:
--------------------------------------------------------------------------------
1 | """Tests for triples."""
2 |
3 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | import tensorflow as tf
18 | import tensorflow_gnn as tfgnn
19 | from tensorflow_gnn.converters import triples
20 |
21 |
class TriplesTest(tf.test.TestCase):
  """Tests converting (subject, predicate, object) triples to a GraphTensor."""

  def setUp(self):
    super().setUp()
    tfgnn.enable_graph_tensor_validation_at_runtime()

  def test_triples_from_array(self):
    triple_list = [
        ["gnns", "are", "awesome"],
        ["tfgnn", "is", "awesome"],
        ["nns", "are", "awesome"],
    ]

    graph = triples.triple_to_graphtensor(triples=triple_list)
    # The "is" edge set holds exactly one edge: tfgnn -> awesome.
    is_adjacency = graph.edge_sets["is"].adjacency
    source_index = int(is_adjacency.source)
    target_index = int(is_adjacency.target)

    self.assertLen(graph.node_sets, 1)
    self.assertLen(graph.edge_sets, 2)

    nodes = graph.node_sets["nodes"]
    node_ids = nodes.features["#id"]
    self.assertEqual(nodes.sizes, 4)
    self.assertEqual(node_ids.shape, 4)

    self.assertEqual(node_ids[source_index], "tfgnn")
    self.assertEqual(node_ids[target_index], "awesome")


if __name__ == "__main__":
  tf.test.main()
49 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/mt_albis/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF-GNN's Model Template "Albis".
16 |
17 | The TF-GNN Model Template "Albis" provides a small selection of field-tested
18 | GNN architectures through the `mt_albis.MtAlbisGraphUpdate` class.
19 |
20 | Users of TF-GNN can use it by importing it next to the core library as
21 |
22 | ```python
23 | import tensorflow_gnn as tfgnn
24 | from tensorflow_gnn.models import mt_albis
25 | ```
26 | """
27 |
28 | from tensorflow_gnn.models.mt_albis import config_dict
29 | from tensorflow_gnn.models.mt_albis import layers
30 | from tensorflow_gnn.utils import api_utils
31 |
# NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py.
# Please see there for instructions on how to reflect API changes.
# LINT.IfChange

# Public API of this package: aliases for the implementations defined in the
# `layers` and `config_dict` submodules.
MtAlbisGraphUpdate = layers.MtAlbisGraphUpdate
graph_update_get_config_dict = config_dict.graph_update_get_config_dict
graph_update_from_config_dict = config_dict.graph_update_from_config_dict

# Remove all names added by module imports, unless explicitly allowed here.
# The empty list allows no exceptions, so none of the imported submodule names
# are kept.
api_utils.remove_submodules_except(__name__, [])
# LINT.ThenChange(../../api_def/mt_albis-symbols.txt)
43 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/vanilla_mpnn/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
"""TF-GNN's "Vanilla MPNN" model.

Users of TF-GNN can use this model by importing it next to the core library as

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import vanilla_mpnn
```

This model ties together some simple convolutions from the TF-GNN core library,
so it does not define any Conv class by itself.
"""

from tensorflow_gnn.models.vanilla_mpnn import config_dict
from tensorflow_gnn.models.vanilla_mpnn import layers
from tensorflow_gnn.utils import api_utils

# NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py.
# Please see there for instructions on how to reflect API changes.
# LINT.IfChange

# Public API of this package: aliases for the implementations defined in the
# `layers` and `config_dict` submodules.
VanillaMPNNGraphUpdate = layers.VanillaMPNNGraphUpdate
graph_update_get_config_dict = config_dict.graph_update_get_config_dict
graph_update_from_config_dict = config_dict.graph_update_from_config_dict

# Remove all names added by module imports, unless explicitly allowed here.
# The empty list allows no exceptions, so none of the imported submodule names
# are kept.
api_utils.remove_submodules_except(__name__, [])
# LINT.ThenChange(../../api_def/vanilla_mpnn-symbols.txt)
43 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/gat_v2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Graph Attention Networks v2.
16 |
17 | Users of TF-GNN can use this model by importing it next to the core library as
18 |
19 | ```python
20 | import tensorflow_gnn as tfgnn
21 | from tensorflow_gnn.models import gat_v2
22 | ```
23 | """
24 |
25 | from tensorflow_gnn.models.gat_v2 import config_dict
26 | from tensorflow_gnn.models.gat_v2 import layers
27 | from tensorflow_gnn.utils import api_utils
28 |
# NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py.
# Please see there for instructions on how to reflect API changes.
# LINT.IfChange

# Public API of this package: aliases for the implementations defined in the
# `layers` and `config_dict` submodules.
GATv2Conv = layers.GATv2Conv
GATv2EdgePool = layers.GATv2EdgePool
GATv2GraphUpdate = layers.GATv2GraphUpdate  # Deprecated.
GATv2HomGraphUpdate = layers.GATv2HomGraphUpdate
GATv2MPNNGraphUpdate = layers.GATv2MPNNGraphUpdate
graph_update_get_config_dict = config_dict.graph_update_get_config_dict
graph_update_from_config_dict = config_dict.graph_update_from_config_dict

# Remove all names added by module imports, unless explicitly allowed here.
# The empty list allows no exceptions, so none of the imported submodule names
# are kept.
api_utils.remove_submodules_except(__name__, [])
# LINT.ThenChange(../../api_def/gat_v2-symbols.txt)
44 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/FitOrSkipPadding.md:
--------------------------------------------------------------------------------
1 | # runner.FitOrSkipPadding
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Calculates fit or skip `SizeConstraints` for `GraphTensor` padding.
10 |
11 | Inherits From: [`GraphTensorPadding`](../runner/GraphTensorPadding.md)
12 |
13 |
14 | runner.FitOrSkipPadding(
15 | gtspec: tfgnn.GraphTensorSpec,
16 | dataset_provider: runner.DatasetProvider,
17 | min_nodes_per_component: Optional[Mapping[str, int]] = None,
18 | fit_or_skip_sample_sample_size: int = 10000,
19 | fit_or_skip_success_ratio: float = 0.99
20 | )
21 |
22 |
23 |
24 |
25 | See: `tfgnn.learn_fit_or_skip_size_constraints`.
26 |
27 | ## Methods
28 |
29 | get_filter_fn
30 |
31 | View
32 | source
33 |
34 |
35 | get_filter_fn(
36 | size_constraints: SizeConstraints
37 | ) -> Callable[..., bool]
38 |
39 |
40 | get_size_constraints
41 |
42 | View
43 | source
44 |
45 |
46 | get_size_constraints(
47 | target_batch_size: int
48 | ) -> SizeConstraints
49 |
50 |
--------------------------------------------------------------------------------
/tensorflow_gnn/models/README.md:
--------------------------------------------------------------------------------
1 | # TF-GNN Models
2 |
3 | ## Introduction
4 |
5 | This directory contains a collection of GNN models implemented with the
6 | TF-GNN library. Some of them offer reusable pieces that can be imported
7 | _next to_ the core TF-GNN library, which effectively makes them little
8 | libraries of their own.
9 |
10 | ### Usage
11 |
12 | If, for example, the hypothetical FancyNet model offered a graph update layer,
13 | its use would look like
14 |
15 | ```python
16 | import tensorflow_gnn as tfgnn
17 | from tensorflow_gnn.models import fancynet
18 |
19 | graph = fancynet.FancyGraphUpdate(units=42, fanciness=0.99, ...)(graph)
20 | ```
21 |
22 | ...and require a separate dependency for `fancynet` in a BUILD file.
23 |
24 | ### API stability
25 |
26 | Each model comes with a README file that describes its intended level of
27 | API stability. Not all models are covered by the [semantic
28 | versioning](https://semver.org/spec/v2.0.0.html) of the TF-GNN package.
29 |
30 | ## List of Models
31 |
32 |
33 |
34 | * [Contrastive Losses](contrastive_losses/README.md): Contrastive losses for
35 | self-supervised learning.
36 | * [GATv2](gat_v2/README.md): Graph Attention Networks v2
37 | (Brody&al., 2021).
38 | * [GCN](gcn/README.md): Graph Convolutional Networks
39 | (Kipf&Welling, 2016), for homogeneous graphs only.
40 | * [GraphSAGE](graph_sage/README.md) (Hamilton&al., 2017).
41 | * [MtAlbis](mt_albis/README.md): Model Template "Albis" for easy configuration
42 | of a few field-tested GNN architectures, generalizing VanillaMPNN.
43 | * [MultiHeadAttention](multi_head_attention/README.md): Transformer-style
44 | multi-head attention on graph (Dwivedi&Bresson, 2021).
45 | * [VanillaMPNN](vanilla_mpnn/README.md): TF-GNN's classic baseline model,
46 | based on (Gilmer&al., 2016).
47 |
48 | Unsure? For generic node prediction tasks on relational data, we recommend
49 | starting with MtAlbis.
--------------------------------------------------------------------------------
/tensorflow_gnn/api_def/runner-symbols.txt:
--------------------------------------------------------------------------------
1 | runner.ContextLabelFn
2 | runner.DatasetProvider
3 | runner.DotProductLinkPrediction
4 | runner.FitOrSkipPadding
5 | runner.GraphBinaryClassification
6 | runner.GraphMeanAbsoluteError
7 | runner.GraphMeanAbsolutePercentageError
8 | runner.GraphMeanSquaredError
9 | runner.GraphMeanSquaredLogScaledError
10 | runner.GraphMeanSquaredLogarithmicError
11 | runner.GraphMulticlassClassification
12 | runner.GraphTensorPadding
13 | runner.GraphTensorProcessorFn
14 | runner.HadamardProductLinkPrediction
15 | runner.IntegratedGradientsExporter
16 | runner.KerasModelExporter
17 | runner.KerasTrainer
18 | runner.KerasTrainerCheckpointOptions
19 | runner.KerasTrainerOptions
20 | runner.Loss
21 | runner.Losses
22 | runner.Metric
23 | runner.Metrics
24 | runner.ModelExporter
25 | runner.NodeMeanAbsoluteError
26 | runner.NodeMeanAbsolutePercentageError
27 | runner.NodeMeanSquaredError
28 | runner.NodeMeanSquaredLogScaledError
29 | runner.NodeMeanSquaredLogarithmicError
30 | runner.ParameterServerStrategy
31 | runner.PassthruDatasetProvider
32 | runner.PassthruSampleDatasetsProvider
33 | runner.Predictions
34 | runner.NodeBinaryClassification
35 | runner.RootNodeBinaryClassification
36 | runner.RootNodeLabelFn
37 | runner.RootNodeMeanAbsoluteError
38 | runner.RootNodeMeanAbsolutePercentageError
39 | runner.RootNodeMeanSquaredError
40 | runner.RootNodeMeanSquaredLogScaledError
41 | runner.RootNodeMeanSquaredLogarithmicError
42 | runner.NodeMulticlassClassification
43 | runner.RootNodeMulticlassClassification
44 | runner.RunResult
45 | runner.SampleTFRecordDatasetsProvider
46 | runner.export_model
47 | runner.SimpleDatasetProvider
48 | runner.SimpleSampleDatasetsProvider
49 | runner.SubmoduleExporter
50 | runner.TFDataServiceConfig
51 | runner.TFRecordDatasetProvider
52 | runner.TPUStrategy
53 | runner.Task
54 | runner.TightPadding
55 | runner.Trainer
56 | runner.incrementing_model_dir
57 | runner.integrated_gradients
58 | runner.one_node_per_component
59 | runner.run
60 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/runner/TFDataServiceConfig.md:
--------------------------------------------------------------------------------
1 | # runner.TFDataServiceConfig
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Provides tf.data service related configuration options.
10 |
11 |
12 | runner.TFDataServiceConfig(
13 | tf_data_service_address: str,
14 | tf_data_service_job_name: str,
15 | tf_data_service_mode: Union[str, tf.data.experimental.service.ShardingPolicy]
16 | )
17 |
18 |
19 |
20 |
21 | tf.data service offers flexible data visitation guarantees, so its impact on your
22 | training pipelines should be evaluated empirically. Check out the tf.data service internals
23 | and operation details from
24 | https://www.tensorflow.org/api_docs/python/tf/data/experimental/service.
25 |
26 |
27 |
28 |
29 |
30 | Attributes |
31 |
32 |
33 |
34 | tf_data_service_address
35 | |
36 |
37 | Dataclass field
38 | |
39 |
40 |
41 | tf_data_service_job_name
42 | |
43 |
44 | Dataclass field
45 | |
46 |
47 |
48 | tf_data_service_mode
49 | |
50 |
51 | Dataclass field
52 | |
53 |
54 |
55 |
56 | ## Methods
57 |
58 | __eq__
59 |
60 |
61 | __eq__(
62 | other
63 | )
64 |
65 |
66 | Return self==value.
67 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/combine_values.md:
--------------------------------------------------------------------------------
1 | # tfgnn.combine_values
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Combines a list of tensors into one (by concatenation or otherwise).
10 |
11 |
12 | tfgnn.combine_values(
13 | inputs: List[Field], combine_type: str
14 | ) -> Field
15 |
16 |
17 |
18 |
19 | This is a convenience wrapper around standard TensorFlow operations, to
20 | provide standard names for common types of combining.
21 |
22 |
23 |
24 |
25 | Args |
26 |
27 | inputs | a list of Tensors or
28 | RaggedTensors, with shapes and types that are compatible for the selected
29 | combine_type. |
30 | combine_type | one of the
31 | following string values, to select the method for combining the inputs:
32 |
33 | * "sum": The input tensors are added. Their dtypes and shapes must
34 | match.
35 | * "concat": The input tensors are concatenated along the last axis.
36 | Their dtypes and shapes must match, except for the number of elements
37 | along the last axis.
38 | |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 | Returns |
48 |
49 | |
50 | A tensor with the combined value of the inputs.
51 | |
52 |
53 |
54 |
55 |
56 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/coherence.md:
--------------------------------------------------------------------------------
1 | # contrastive_losses.coherence
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Coherence metric implementation.
10 |
11 |
12 | @tf.function
13 | contrastive_losses.coherence(
14 | representations: tf.Tensor,
15 | *,
16 | sigma: Optional[tf.Tensor] = None,
17 | u: Optional[tf.Tensor] = None
18 | ) -> tf.Tensor
19 |
20 |
21 |
22 |
23 | Coherence measures how easy it is to construct a linear classifier on top of
24 | data without knowing downstream labels. Refer to
25 | https://arxiv.org/abs/2305.16562 for more details.
26 |
27 |
28 |
29 |
30 |
31 | Args |
32 |
33 |
34 |
35 | representations
36 | |
37 |
38 | Input representations, a rank-2 tensor.
39 | |
40 |
41 |
42 | sigma
43 | |
44 |
45 | Unused.
46 | |
47 |
48 |
49 | u
50 | |
51 |
52 | An optional tensor with left singular vectors of representations. If not
53 | present, computes an SVD of representations.
54 | |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 | Returns |
63 |
64 |
65 | Metric value as scalar tf.Tensor.
66 | |
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
/tensorflow_gnn/tools/validate_graph_schema.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Validate a schema's features and shapes.
16 |
17 | This script ensures that a schema is valid, has correct shapes, and isn't
18 | clobbering over reserve feature names.
19 | """
20 |
21 | import sys
22 |
23 | from absl import app
24 | from absl import flags
25 | from absl import logging
26 | import tensorflow_gnn as tfgnn
27 |
28 |
FLAGS = flags.FLAGS


def define_flags():
  """Registers this tool's command-line flags.

  `--graph_schema` names a text-format GraphSchema file; it is marked as
  required via `flags.mark_flag_as_required`.
  """
  schema_help = ("A filename to a text-formatted schema proto describing "
                 "the available graph features.")
  flags.DEFINE_string("graph_schema", None, schema_help)
  flags.mark_flag_as_required("graph_schema")
40 |
41 |
def app_main(unused_argv):
  """Reads the schema named by `--graph_schema` and validates it.

  Every warning produced by `tfgnn.validate_schema` is logged at WARNING level.
  A `tfgnn.ValidationError` is logged at ERROR level and the process exits
  with status 1. Otherwise a success message is logged.
  """
  schema = tfgnn.read_schema(FLAGS.graph_schema)
  try:
    # Iterate inside the try so errors raised while producing warnings are
    # still reported as validation errors.
    for schema_warning in tfgnn.validate_schema(schema):
      logging.warning(schema_warning)
  except tfgnn.ValidationError as exc:
    logging.error("Schema validation error: %s", exc)
    sys.exit(1)
  else:
    logging.info("Schema validated correctly.")
53 |
54 |
def main():
  """Entry point: registers flags, then runs `app_main` under absl."""
  # Flags must be defined before app.run() parses the command line.
  define_flags()
  app.run(app_main)


if __name__ == "__main__":
  main()
62 |
--------------------------------------------------------------------------------
/examples/sampler/creditcard/customer.csv:
--------------------------------------------------------------------------------
1 | id,name,address,zipcode,score
2 | 1876448,Ji Grindstaff,"343 Third St. Houston, TX 77016",77016,0.5174975367
3 | 1372437,Augustina Uren,"9940 Prairie Ave. Deer Park, NY 11729",11729,0.9055414601
4 | 1368305,Yolonda Nave,"478 Grove Drive Hicksville, NY 11801",11801,0.2407929709
5 | 1974494,Adriana Mcburney,"909 Vermont St. Livonia, MI 48150",48150,0.3432604494
6 | 1257724,Dione Reeb,"7758 West Devon St. Algonquin, IL 60102",60102,0.5395817998
7 | 1758057,Geri Bones,"720 Marsh Road Tucker, GA 30084",30084,0.9270786453
8 | 1531660,Krystal Pablo,"227 Bishop St. Bemidji, MN 56601",56601,0.1871476497
9 | 1489311,Tonia Behnke,"51 Ramblewood St. Hobart, IN 46342",46342,0.1171643897
10 | 1407706,Fidel Speers,"714 Hilldale Ave. Cumming, GA 30040",30040,0.2359154945
11 | 196838,Necole Hunkins,"14 South Grove St. Coatesville, PA 19320",19320,0.06037660472
12 | 1195675,Tona Crays,"136 Somerset Dr. Chester, PA 19013",19013,0.6864072435
13 | 1659366,Mary Zeitz,"6 Thatcher St. Hillsboro, OR 97124",97124,0.643440095
14 | 1499004,Rudolph Sinquefield,"673 Elmwood Drive Fairburn, GA 30213",30213,0.9164648402
15 | 1344333,Marhta Rodrigue,"7284 Young Lane Upland, CA 91784",91784,0.6359473777
16 | 1443888,Ricardo Bundrick,"19 Rock Creek St. Sulphur, LA 70663",70663,0.4666323814
17 | 1108778,Myron Barrick,"450 Boston Street Solon, OH 44139",44139,0.7162866466
18 | 175583,Nichol Poulton,"34 Old 8th Drive Carmel, NY 10512",10512,0.2922429056
19 | 1251872,Boyd Padilla,"7351 North Spring Ave. Oshkosh, WI 54901",54901,0.4839780673
20 | 1493851,Orpha Yokoyama,"7805 Newport Street Sylvania, OH 43560",43560,0.7630316714
21 | 1599418,Ulysses Harps,"9138 Gates Street Braintree, MA 02184",1284,0.3491518542
22 | 1768701,Sulema Aguero,"42 South Wayne St. Hollywood, FL 33020",33020,0.7137308073
23 | 1549489,Meredith Warman,"98 Corona Court Morton Grove, IL 60053",60053,0.5057767373
24 | 1879799,Vonda Borth,"97 Tunnel Dr. Elyria, OH 44035",44035,0.6586585941
25 | 125454,Candida Uvalle,"608 Heritage Street Harrison Township, MI 48045",48045,0.6696324006
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/satisfies_size_constraints.md:
--------------------------------------------------------------------------------
1 | # tfgnn.satisfies_size_constraints
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Returns whether the input `graph_tensor` satisfies `total_sizes`.
10 |
11 |
12 | View aliases
13 |
14 | Main aliases
15 |
`tfgnn.satisfies_total_sizes`
16 |
17 |
18 |
19 |
20 | tfgnn.satisfies_size_constraints(
21 | graph_tensor: tfgnn.GraphTensor,
22 | total_sizes: tfgnn.SizeConstraints
23 | ) -> tf.Tensor
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | Args |
35 |
36 |
37 |
38 | graph_tensor
39 | |
40 |
41 | a graph tensor to check against target total sizes.
42 | |
43 |
44 |
45 | total_sizes
46 | |
47 |
48 | target total sizes for each graph piece.
49 | |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 | Returns |
58 |
59 |
60 | A scalar boolean tensor equal to True if the graph_tensor satisfies
61 | total_sizes, and False if not.
62 | |
63 |
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
/tensorflow_gnn/docs/api_docs/python/tfgnn/iter_features.md:
--------------------------------------------------------------------------------
1 | # tfgnn.iter_features
2 |
3 |
4 |
5 |
6 |
View source
7 | on GitHub
8 |
9 | Utility function to iterate over the features of a graph schema.
10 |
11 |
12 | tfgnn.iter_features(
13 | schema: Union[tfgnn.proto.GraphSchema, tfgnn.GraphTensor]
14 | ) -> Iterator[Tuple[Text, Text, Text, Union[schema_pb2.Feature, gt.Field]]]
15 |
16 |
17 |
18 |
19 | This function iterates over all the feature values of each of the context set,
20 | each of the node sets, and each of the edge sets.
21 |
22 |
23 |
24 |
25 | Args |
26 |
27 |
28 |
29 | schema
30 | |
31 |
32 | An instance of a GraphSchema proto message or a GraphTensor.
33 | |
34 |
35 |
36 |
37 |
38 |
39 |
40 | Yields |
41 |
42 |
43 | Tuples of (set-type, set-name, feature-name, feature-value) where
44 |
45 | * set-type: A type of set, which is one of "context", "nodes" or "edges".
46 | * set-name: A string, the name of the set.
47 | * feature-name: A string, the name of the feature in the set.
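48 | * feature-value: The feature itself: a Feature proto message when iterating a GraphSchema,
49 | or a potentially ragged tensor (either a tf.Tensor or a tf.RaggedTensor) when iterating a GraphTensor. |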
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
/testdata/heterogeneous/customer.csv:
--------------------------------------------------------------------------------
1 | id,name,address,zipcode,score
2 | 1876448,Ji Grindstaff,"343 Third St. Houston, TX 77016",77016,0.5174975367
3 | 1372437,Augustina Uren,"9940 Prairie Ave. Deer Park, NY 11729",11729,0.9055414601
4 | 1368305,Yolonda Nave,"478 Grove Drive Hicksville, NY 11801",11801,0.2407929709
5 | 1974494,Adriana Mcburney,"909 Vermont St. Livonia, MI 48150",48150,0.3432604494
6 | 1257724,Dione Reeb,"7758 West Devon St. Algonquin, IL 60102",60102,0.5395817998
7 | 1758057,Geri Bones,"720 Marsh Road Tucker, GA 30084",30084,0.9270786453
8 | 1531660,Krystal Pablo,"227 Bishop St. Bemidji, MN 56601",56601,0.1871476497
9 | 1489311,Tonia Behnke,"51 Ramblewood St. Hobart, IN 46342",46342,0.1171643897
10 | 1407706,Fidel Speers,"714 Hilldale Ave. Cumming, GA 30040",30040,0.2359154945
11 | 196838,Necole Hunkins,"14 South Grove St. Coatesville, PA 19320",19320,0.06037660472
12 | 1195675,Tona Crays,"136 Somerset Dr. Chester, PA 19013",19013,0.6864072435
13 | 1659366,Mary Zeitz,"6 Thatcher St. Hillsboro, OR 97124",97124,0.643440095
14 | 1499004,Rudolph Sinquefield,"673 Elmwood Drive Fairburn, GA 30213",30213,0.9164648402
15 | 1344333,Marhta Rodrigue,"7284 Young Lane Upland, CA 91784",91784,0.6359473777
16 | 1443888,Ricardo Bundrick,"19 Rock Creek St. Sulphur, LA 70663",70663,0.4666323814
17 | 1108778,Myron Barrick,"450 Boston Street Solon, OH 44139",44139,0.7162866466
18 | 175583,Nichol Poulton,"34 Old 8th Drive Carmel, NY 10512",10512,0.2922429056
19 | 1251872,Boyd Padilla,"7351 North Spring Ave. Oshkosh, WI 54901",54901,0.4839780673
20 | 1493851,Orpha Yokoyama,"7805 Newport Street Sylvania, OH 43560",43560,0.7630316714
21 | 1599418,Ulysses Harps,"9138 Gates Street Braintree, MA 02184",1284,0.3491518542
22 | 1768701,Sulema Aguero,"42 South Wayne St. Hollywood, FL 33020",33020,0.7137308073
23 | 1549489,Meredith Warman,"98 Corona Court Morton Grove, IL 60053",60053,0.5057767373
24 | 1879799,Vonda Borth,"97 Tunnel Dr. Elyria, OH 44035",44035,0.6586585941
25 | 125454,Candida Uvalle,"608 Heritage Street Harrison Township, MI 48045",48045,0.6696324006
--------------------------------------------------------------------------------
/tensorflow_gnn/models/multi_head_attention/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Transformer-style multi-head attention.
16 |
17 | Users of TF-GNN can use this model by importing it next to the core library as
18 |
19 | ```python
20 | import tensorflow_gnn as tfgnn
21 | from tensorflow_gnn.models import multi_head_attention
22 | ```
23 | """
24 |
25 | from tensorflow_gnn.models.multi_head_attention import config_dict
26 | from tensorflow_gnn.models.multi_head_attention import layers
27 | from tensorflow_gnn.utils import api_utils
28 |
29 | # NOTE: This package is covered by tensorflow_gnn/api_def/api_symbols_test.py.
30 | # Please see there for instructions how to reflect API changes.
31 | # LINT.IfChange
32 |
# Public API of this package: re-export the layer classes from the `layers`
# submodule and the config-dict helpers from the `config_dict` submodule under
# the package namespace, so users can write e.g.
# `multi_head_attention.MultiHeadAttentionConv`.
# The set of names defined here must be kept in sync with the symbols file
# named in the LINT.ThenChange marker that closes this section (this package is
# covered by tensorflow_gnn/api_def/api_symbols_test.py).
MultiHeadAttentionConv = layers.MultiHeadAttentionConv
MultiHeadAttentionEdgePool = layers.MultiHeadAttentionEdgePool
MultiHeadAttentionHomGraphUpdate = layers.MultiHeadAttentionHomGraphUpdate
MultiHeadAttentionMPNNGraphUpdate = layers.MultiHeadAttentionMPNNGraphUpdate
graph_update_get_config_dict = config_dict.graph_update_get_config_dict
graph_update_from_config_dict = config_dict.graph_update_from_config_dict

# Remove all names added by module imports, unless explicitly allowed here.
# NOTE(review): judging by its name, this strips submodule attributes such as
# `config_dict` and `layers` from the package namespace, leaving only the
# aliases above; the empty list allows no exceptions. Keep this call after the
# aliases, since they look names up through those submodules. Confirm exact
# behavior in tensorflow_gnn/utils/api_utils.py.
api_utils.remove_submodules_except(__name__, [])
42 | # LINT.ThenChange(../../api_def/multi_head_attention-symbols.txt)
43 |
--------------------------------------------------------------------------------