├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── __init__.py
├── talk_like_a_graph
│   ├── README.md
│   ├── __init__.py
│   ├── graph_generators.py
│   ├── graph_generators_runner.py
│   ├── graph_generators_test.py
│   ├── graph_metrics.py
│   ├── graph_metrics_test.py
│   ├── graph_tasks.py
│   ├── graph_tasks_generator.py
│   ├── graph_tasks_utils.py
│   ├── graph_text_encoders.py
│   ├── graph_text_encoders_test.py
│   └── name_dictionaries.py
└── tutorial
    ├── KDD-Tutorial-1-Talk-Like-a-Graph.ipynb
    ├── KDD-Tutorial-2-Let-Your-Graph-Do-The-Talking.ipynb
    ├── README.md
    └── imgs
        ├── let_your_graph_do_the_talking.png
        └── talk_like_a_graph_colab.png
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution;
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Encoding Graphs and Structured Data in Language Models
2 |
3 | This repository contains code for
4 | [Talk like a Graph: Encoding Graphs for Large Language Models](https://arxiv.org/abs/2310.04560)
5 | and
6 | [Let Your Graph Do the Talking: Encoding Structured Data for LLMs](https://arxiv.org/abs/2402.05862).
7 |
8 | ## Cite us
9 |
10 | If you use this package for published work, please cite the following.
11 |
12 | For the "Talk like a Graph" project:
13 |
14 | ```
15 | @inproceedings{fatemi2024talk,
16 | title={Talk like a Graph: Encoding Graphs for Large Language Models},
17 | author={Bahare Fatemi and Jonathan Halcrow and Bryan Perozzi},
18 | booktitle={International Conference on Learning Representations (ICLR)},
19 | year={2024}
20 | }
21 | ```
22 |
23 | For Graph Token (a.k.a. "Let Your Graph Do the Talking"):
24 |
25 | ```
26 | @misc{perozzi2024letgraphtalkingencoding,
27 | title={Let Your Graph Do the Talking: Encoding Structured Data for LLMs},
28 | author={Bryan Perozzi and Bahare Fatemi and Dustin Zelle and Anton Tsitsulin and Mehran Kazemi and Rami Al-Rfou and Jonathan Halcrow},
29 | year={2024},
30 | eprint={2402.05862},
31 | archivePrefix={arXiv},
32 | primaryClass={cs.LG},
33 | url={https://arxiv.org/abs/2402.05862},
34 | }
35 | ```
36 |
37 | ## Disclaimer
38 |
39 | This is not an official Google product.
40 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/talk-like-a-graph/36af51e19ef7a44049d64306e3cae56c07067e81/__init__.py
--------------------------------------------------------------------------------
/talk_like_a_graph/README.md:
--------------------------------------------------------------------------------
1 | # Using Large Language Models to Solve Graph Problems
2 |
3 | This repository contains the code to generate graph reasoning problems with
4 | different graph generation algorithms and graph encoding methods, as well as
5 | different prompting techniques.
6 |
7 | The graph tasks are `edge existence`, `node degree`, `node count`, `edge count`,
8 | `connected nodes`, `disconnected nodes`, `cycle check`, `reachability`,
9 | `shortest path`, `maximum flow`, `node classification`, and `triangle counting`.
10 |
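Several of these tasks (e.g. `edge existence`, `cycle check`, `reachability`) have yes/no answers, which `graph_metrics.yes_no_accuracy` scores by binarizing the first line of the model output. A minimal self-contained sketch of that binarization rule (the helper name `binarize_first_line` is illustrative, not the repo's API):

```python
def binarize_first_line(prediction: str) -> str:
    """Classify an LLM output as 'yes', 'no', 'ambiguous', or 'indeterminate'
    by inspecting only its first line, mirroring the rule used in
    graph_metrics.yes_no_accuracy."""
    lines = prediction.splitlines()
    first = lines[0].lower() if lines else ""
    has_yes, has_no = "yes" in first, "no" in first
    if has_yes and has_no:
        return "ambiguous"       # counted toward yes_no_ambiguous
    if not has_yes and not has_no:
        return "indeterminate"   # counted toward yes_no_indeterminate
    return "yes" if has_yes else "no"

print(binarize_first_line("Yes, there is a cycle."))  # yes
print(binarize_first_line("Yes or no?\nNo."))         # ambiguous
print(binarize_first_line("Unsure."))                 # indeterminate
```

Only the first line is inspected, so a chain-of-thought continuation after the answer line does not make a prediction ambiguous.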
11 | The datasets used here are proposed in our paper:
12 | [Talk like a Graph: Encoding Graphs for Large Language Models](https://arxiv.org/abs/2310.04560).
13 |
14 | ### Generating graphs
15 |
16 | ```sh
17 | ./graphqa/graph_generator.sh
18 | ```
19 |
20 | ### Generating files for tasks
21 |
22 | ```sh
23 | ./graphqa/task_generator.sh
24 | ```
25 |
26 | ## Contact us
27 |
28 | For questions or comments about the implementation, please contact
29 | baharef@google.com.
30 |
31 | ## Cite us
32 |
33 | If you use this package for published work, please cite the following:
34 |
35 | ```
36 | @inproceedings{fatemi2024talk,
37 | title={Talk like a Graph: Encoding Graphs for Large Language Models},
38 | author={Bahare Fatemi and Jonathan Halcrow and Bryan Perozzi},
39 | booktitle={International Conference on Learning Representations (ICLR)},
40 | year={2024}
41 | }
42 | ```
43 |
44 | ## Disclaimer
45 |
46 | This is not an official Google product.
47 |
48 | # Placeholder for internal data notes.
--------------------------------------------------------------------------------
/talk_like_a_graph/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/talk-like-a-graph/36af51e19ef7a44049d64306e3cae56c07067e81/talk_like_a_graph/__init__.py
--------------------------------------------------------------------------------
/talk_like_a_graph/graph_generators.py:
--------------------------------------------------------------------------------
1 | r"""Random graph generation."""
2 |
3 | import random
4 |
5 | import networkx as nx
6 | import numpy as np
7 |
8 |
9 | _NUMBER_OF_NODES_RANGE = {
10 | "small": np.arange(5, 10),
11 | "medium": np.arange(10, 15),
12 | "large": np.arange(15, 20),
13 | }
14 | _NUMBER_OF_COMMUNITIES_RANGE = {
15 | "small": np.arange(2, 4),
16 | "medium": np.arange(2, 8),
17 | "large": np.arange(2, 10),
18 | }
19 |
20 |
21 | def generate_graphs(
22 | number_of_graphs: int,
23 | algorithm: str,
24 | directed: bool,
25 | random_seed: int = 1234,
26 | er_min_sparsity: float = 0.0,
27 | er_max_sparsity: float = 1.0,
28 | ) -> list[nx.Graph]:
29 | """Generating multiple graphs using the provided algorithms.
30 |
31 | Args:
32 | number_of_graphs: number of graphs to generate
33 | algorithm: the random graph generator algorithm
34 | directed: whether to generate directed or undirected graphs.
35 | random_seed: the random seed to generate graphs with.
36 | er_min_sparsity: minimum sparsity of er graphs.
37 | er_max_sparsity: maximum sparsity of er graphs.
38 |
39 | Returns:
40 | generated_graphs: a list of nx graphs.
41 | Raises:
42 | NotImplementedError: if the algorithm is not yet implemented.
43 | """
44 |
45 | random.seed(random_seed)
46 | np.random.seed(random_seed)
47 |
48 | generated_graphs = []
49 | graph_sizes = random.choices(
50 | list(_NUMBER_OF_NODES_RANGE.keys()), k=number_of_graphs
51 | )
52 | random_state = np.random.RandomState(random_seed)
53 | if algorithm == "er":
54 | for i in range(number_of_graphs):
55 | sparsity = random.uniform(er_min_sparsity, er_max_sparsity)
56 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]])
57 | generated_graphs.append(
58 | nx.erdos_renyi_graph(
59 | number_of_nodes, sparsity, seed=random_state, directed=directed
60 | )
61 | )
62 | elif algorithm == "ba":
63 | for i in range(number_of_graphs):
64 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]])
65 | m = random.randint(1, number_of_nodes - 1)
66 | generated_graph = nx.barabasi_albert_graph(
67 | number_of_nodes, m, seed=random_state
68 | )
69 | if directed:
70 | generated_graphs.append(randomize_directions(generated_graph))
71 | else:
72 | generated_graphs.append(generated_graph)
73 | elif algorithm == "sbm":
74 | for i in range(number_of_graphs):
75 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]])
76 | number_of_communities = random.choice(
77 | _NUMBER_OF_COMMUNITIES_RANGE[graph_sizes[i]]
78 | )
79 |       # sizes holds the number of nodes in each community.
80 | sizes = []
81 | for _ in range(number_of_communities - 1):
82 | sizes.append(
83 | random.randint(
84 | 1,
85 | max(
86 | 1,
87 | number_of_nodes - sum(sizes) - (number_of_communities - 1),
88 | ),
89 | )
90 | )
91 | sizes.append(number_of_nodes - sum(sizes))
92 |
93 |       # p holds the connection probabilities between communities.
94 | p = np.random.uniform(size=(number_of_communities, number_of_communities))
95 | if random.uniform(0, 1) < 0.5:
96 | p = np.maximum(p, p.transpose())
97 | else:
98 | p = np.minimum(p, p.transpose())
99 | sbm_graph = nx.stochastic_block_model(
100 | sizes, p, seed=random_state, directed=directed
101 | )
102 | # sbm graph generator automatically adds dictionary attributes.
103 | sbm_graph = remove_graph_data(sbm_graph)
104 | generated_graphs.append(sbm_graph)
105 | elif algorithm == "sfn":
106 | for i in range(number_of_graphs):
107 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]])
108 | generated_graph = nx.scale_free_graph(number_of_nodes, seed=random_state)
109 |       # sfn graphs are by default directed.
110 | if not directed:
111 | generated_graphs.append(remove_directions(generated_graph))
112 | else:
113 | generated_graphs.append(generated_graph)
114 | elif algorithm == "complete":
115 | for i in range(number_of_graphs):
116 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]])
117 | create_using = nx.DiGraph if directed else nx.Graph
118 | generated_graphs.append(
119 | nx.complete_graph(number_of_nodes, create_using=create_using)
120 | )
121 | elif algorithm == "star":
122 | for i in range(number_of_graphs):
123 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]])
124 |       # nx.star_graph(n) creates n + 1 nodes (n leaves plus the center).
125 | generated_graph = nx.star_graph(number_of_nodes - 1)
126 | if directed:
127 | generated_graphs.append(randomize_directions(generated_graph))
128 | else:
129 | generated_graphs.append(generated_graph)
130 | elif algorithm == "path":
131 | for i in range(number_of_graphs):
132 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]])
133 | create_using = nx.DiGraph if directed else nx.Graph
134 | generated_graphs.append(
135 | nx.path_graph(number_of_nodes, create_using=create_using)
136 | )
137 | else:
138 |     raise NotImplementedError(f"Unknown algorithm: {algorithm}")
139 | return generated_graphs
140 |
141 |
142 | def remove_graph_data(graph: nx.Graph) -> nx.Graph:
143 | # GraphML writer does not support dictionary data for nodes or graphs.
144 |   for ind in range(graph.number_of_nodes()):
145 | graph.nodes[ind].pop("block", None)
146 | graph_data_keys = list(graph.graph.keys())
147 |   for key in graph_data_keys:
148 |     graph.graph.pop(key, None)
149 | return graph
150 |
151 |
152 | def randomize_directions(graph: nx.Graph) -> nx.DiGraph:
153 | # Converting the undirected graph to a directed graph.
154 | directed_graph = graph.to_directed()
155 | # For each edge, randomly choose a direction.
156 | edges = list(graph.edges())
157 | for u, v in edges:
158 | if random.random() < 0.5:
159 | directed_graph.remove_edge(u, v)
160 | else:
161 | directed_graph.remove_edge(v, u)
162 |
163 | return directed_graph
164 |
165 |
166 | def remove_directions(graph: nx.Graph) -> nx.Graph:
167 |   # Converting the directed graph to an undirected one by removing directions.
168 | undirected_graph = nx.Graph()
169 | undirected_graph.add_nodes_from(graph.nodes())
170 | # Add edges between nodes, ignoring directions.
171 | for u, v in graph.edges():
172 | undirected_graph.add_edge(u, v)
173 |
174 | return undirected_graph
175 |
--------------------------------------------------------------------------------
/talk_like_a_graph/graph_generators_runner.py:
--------------------------------------------------------------------------------
1 | r"""Random graph generation.
2 |
3 | This code generates random graphs using different algorithms.
4 |
5 | # Placeholder for Google-internal comments.
6 | """
7 |
8 | from collections.abc import Sequence
9 | import os
10 |
11 | from absl import app
12 | from absl import flags
13 | import networkx as nx
14 |
15 | # Internal import.
16 | from . import graph_generators
17 |
18 | _ALGORITHM = flags.DEFINE_string(
19 | "algorithm",
20 | None,
21 | "The graph generating algorithm to use.",
22 | required=True,
23 | )
24 | _NUMBER_OF_GRAPHS = flags.DEFINE_integer(
25 | "number_of_graphs",
26 | None,
27 | "The number of graphs to generate.",
28 | required=True,
29 | )
30 | _DIRECTED = flags.DEFINE_bool(
31 | "directed", False, "Whether to generate directed graphs."
32 | )
33 | _OUTPUT_PATH = flags.DEFINE_string(
34 | "output_path", None, "The output path to write the graphs.", required=True
35 | )
36 | _SPLIT = flags.DEFINE_string(
37 | "split", None, "The dataset split to generate.", required=True
38 | )
39 | _MIN_SPARSITY = flags.DEFINE_float("min_sparsity", 0.0, "The minimum sparsity.")
40 | _MAX_SPARSITY = flags.DEFINE_float("max_sparsity", 1.0, "The maximum sparsity.")
41 |
42 |
43 | def write_graphs(graphs: list[nx.Graph], output_dir: str) -> None:
44 | if not os.path.isdir(output_dir):
45 | os.makedirs(output_dir)
46 | for ind, graph in enumerate(graphs):
47 | nx.write_graphml(
48 | graph,
49 |       open(
50 | os.path.join(output_dir, str(ind) + ".graphml"),
51 | "wb",
52 | ),
53 | )
54 |
55 |
56 | def main(argv: Sequence[str]) -> None:
57 | if len(argv) > 1:
58 | raise app.UsageError("Too many command-line arguments.")
59 |
60 | if _SPLIT.value == "train":
61 | random_seed = 9876
62 | elif _SPLIT.value == "test":
63 | random_seed = 1234
64 | elif _SPLIT.value == "validation":
65 | random_seed = 5432
66 | else:
67 |     raise NotImplementedError(f"Unknown split: {_SPLIT.value}")
68 |
69 | generated_graphs = graph_generators.generate_graphs(
70 | number_of_graphs=_NUMBER_OF_GRAPHS.value,
71 | algorithm=_ALGORITHM.value,
72 | directed=_DIRECTED.value,
73 | random_seed=random_seed,
74 | er_min_sparsity=_MIN_SPARSITY.value,
75 | er_max_sparsity=_MAX_SPARSITY.value,
76 | )
77 | write_graphs(
78 | graphs=generated_graphs,
79 | output_dir=os.path.join(
80 | _OUTPUT_PATH.value,
81 | "directed" if _DIRECTED.value else "undirected",
82 | _ALGORITHM.value,
83 | _SPLIT.value,
84 | ),
85 | )
86 |
87 |
88 | if __name__ == "__main__":
89 | app.run(main)
90 |
--------------------------------------------------------------------------------
/talk_like_a_graph/graph_generators_test.py:
--------------------------------------------------------------------------------
1 | from absl.testing import absltest
2 | from absl.testing import parameterized
3 | from . import graph_generators
4 |
5 |
6 | class GraphGenerationTest(parameterized.TestCase):
7 |
8 | @parameterized.named_parameters(
9 | dict(
10 | testcase_name='er_undirected_1',
11 | algorithm='er',
12 | directed=False,
13 | k=1,
14 | ),
15 | dict(
16 | testcase_name='er_directed_1',
17 | algorithm='er',
18 | directed=True,
19 | k=1,
20 | ),
21 | dict(
22 | testcase_name='ba_undirected_5',
23 | algorithm='ba',
24 | directed=False,
25 | k=5,
26 | ),
27 | dict(
28 | testcase_name='ba_directed_5',
29 | algorithm='ba',
30 | directed=True,
31 | k=5,
32 | ),
33 | )
34 | def test_number_of_graphs(self, algorithm, directed, k):
35 | generated_graph = graph_generators.generate_graphs(k, algorithm, directed)
36 | self.assertLen(generated_graph, k)
37 |
38 | @parameterized.named_parameters(
39 | dict(
40 | testcase_name='er_undirected',
41 | algorithm='er',
42 | directed=False,
43 | ),
44 | dict(
45 | testcase_name='er_directed',
46 | algorithm='er',
47 | directed=True,
48 | ),
49 | dict(
50 | testcase_name='ba_undirected',
51 | algorithm='ba',
52 | directed=False,
53 | ),
54 | dict(
55 | testcase_name='ba_directed',
56 | algorithm='ba',
57 | directed=True,
58 | ),
59 | dict(
60 | testcase_name='sbm_undirected',
61 | algorithm='sbm',
62 | directed=False,
63 | ),
64 | dict(
65 | testcase_name='sbm_directed',
66 | algorithm='sbm',
67 | directed=True,
68 | ),
69 | dict(
70 | testcase_name='sfn_undirected',
71 | algorithm='sfn',
72 | directed=False,
73 | ),
74 | dict(
75 | testcase_name='sfn_directed',
76 | algorithm='sfn',
77 | directed=True,
78 | ),
79 | dict(
80 | testcase_name='complete_undirected',
81 | algorithm='complete',
82 | directed=False,
83 | ),
84 | dict(
85 | testcase_name='complete_directed',
86 | algorithm='complete',
87 | directed=True,
88 | ),
89 | dict(
90 | testcase_name='star_undirected',
91 | algorithm='star',
92 | directed=False,
93 | ),
94 | dict(
95 | testcase_name='star_directed',
96 | algorithm='star',
97 | directed=True,
98 | ),
99 | dict(
100 | testcase_name='path_undirected',
101 | algorithm='path',
102 | directed=False,
103 | ),
104 | dict(
105 | testcase_name='path_directed',
106 | algorithm='path',
107 | directed=True,
108 | ),
109 | )
110 | def test_directions(self, algorithm, directed):
111 | generated_graph = graph_generators.generate_graphs(1, algorithm, directed)
112 | self.assertEqual(generated_graph[0].is_directed(), directed)
113 |
114 |
115 | if __name__ == '__main__':
116 |   absltest.main()
117 |
--------------------------------------------------------------------------------
/talk_like_a_graph/graph_metrics.py:
--------------------------------------------------------------------------------
1 | """Metrics for seqio tasks over graph data.
2 |
3 | This module contains definitions of metric_fns to be used for scoring
4 | graph tasks from nlgraph and graphqa.
5 | """
6 |
7 | from typing import Mapping, Sequence
8 |
9 |
10 | def yes_no_accuracy(
11 | targets: Sequence[str], predictions: Sequence[str]
12 | ) -> Mapping[str, float]:
13 | """Assesses the accuracy of LLM outputs on Yes/No tasks.
14 |
15 | Targets must contain either the word 'yes' or the word 'no' but not both.
16 |
17 | Predictions are binarized by checking for 'yes' or 'no' in the first line.
18 |
19 | Args:
20 | targets: The expected output strings.
21 | predictions: The LLM outputs.
22 |
23 | Returns:
24 | Returns a dict of the following metrics:
25 | yes_no_accuracy: The % where the target and prediction match.
26 |       yes_no_ambiguous: The % where the prediction contained both yes and no.
27 |       yes_no_indeterminate: The % where the prediction contained neither yes
28 |         nor no.
29 |
30 | Raises:
31 |     ValueError: If a target string contains both 'yes' and 'no', or neither.
32 | """
33 | num_correct = 0
34 | num_ambiguous = 0
35 | num_indeterminate = 0
36 | for target, prediction in zip(targets, predictions):
37 | normalized_target = target.lower()
38 | binarized_target = 'yes' in normalized_target
39 |     if binarized_target and 'no' in normalized_target:
40 |       raise ValueError(f'Ambiguous target string, {target}')
41 |     if not binarized_target and 'no' not in normalized_target:
42 |       raise ValueError(f'Indeterminate target string, {target}')
43 |     normalized_prediction = prediction.splitlines()
44 |     if not normalized_prediction:
45 |       normalized_prediction = ''
46 |     else:
47 |       normalized_prediction = normalized_prediction[0]
48 |     normalized_prediction = normalized_prediction.lower()
49 |     binarized_prediction = 'yes' in normalized_prediction
50 |     if binarized_prediction and 'no' in normalized_prediction:
51 |       num_ambiguous += 1
52 |       continue
53 |     if not binarized_prediction and 'no' not in normalized_prediction:
54 |       num_indeterminate += 1
55 |       continue
56 |     if binarized_prediction == binarized_target:
57 |       num_correct += 1
58 |   return {
59 |       'yes_no_accuracy': num_correct / len(targets),
60 |       'yes_no_ambiguous': num_ambiguous / len(targets),
61 |       'yes_no_indeterminate': num_indeterminate / len(targets),
62 |   }
63 |
--------------------------------------------------------------------------------
/talk_like_a_graph/graph_metrics_test.py:
--------------------------------------------------------------------------------
1 | from absl.testing import absltest
2 | from . import graph_metrics
3 |
4 |
5 | class GraphMetricsTest(absltest.TestCase):
6 |
7 | def test_yes_no_correct(self):
8 | result = graph_metrics.yes_no_accuracy(
9 | targets=["Yes", "The answer is yes.", "No", "The answer is no.", "No"],
10 | predictions=[
11 | "yes",
12 | "Yes\nNo",
13 | "No\nYes",
14 | "That's gonna be no from me.",
15 | "yes",
16 | ],
17 | )
18 | self.assertEqual(result["yes_no_ambiguous"], 0)
19 | self.assertEqual(result["yes_no_indeterminate"], 0)
20 | self.assertAlmostEqual(result["yes_no_accuracy"], 0.8)
21 |
22 | def test_yes_no_ambiguous(self):
23 | result = graph_metrics.yes_no_accuracy(
24 | targets=["Yes", "The answer is yes.", "No", "The answer is no.", "No"],
25 | predictions=[
26 | "yes",
27 | "Yes No",
28 | "No Yes",
29 | "That's gonna be no from me.",
30 | "yes",
31 | ],
32 | )
33 | self.assertEqual(result["yes_no_ambiguous"], 0.4)
34 | self.assertEqual(result["yes_no_indeterminate"], 0)
35 | self.assertAlmostEqual(result["yes_no_accuracy"], 0.4)
36 |
37 | def test_yes_no_indeterminate(self):
38 | result = graph_metrics.yes_no_accuracy(
39 | targets=["Yes", "The answer is yes.", "No", "The answer is no.", "No"],
40 | predictions=[
41 | "yes",
42 | "\n No",
43 | "",
44 | "That's gonna be no from me.",
45 | "yes",
46 | ],
47 | )
48 | self.assertEqual(result["yes_no_ambiguous"], 0)
49 | self.assertEqual(result["yes_no_indeterminate"], 0.4)
50 | self.assertAlmostEqual(result["yes_no_accuracy"], 0.4)
51 |
52 | def test_yes_no_accuracy_raises_on_ambiguous_target(self):
53 | with self.assertRaises(ValueError):
54 | graph_metrics.yes_no_accuracy(
55 | targets=["Yes but maybe no"],
56 | predictions=["yes"],
57 | )
58 |
59 | def test_yes_no_accuracy_raises_on_indeterminate_target(self):
60 | with self.assertRaises(ValueError):
61 | graph_metrics.yes_no_accuracy(
62 | targets=["Hmm?"],
63 | predictions=["yes"],
64 | )
65 |
66 |
67 | if __name__ == "__main__":
68 |   absltest.main()
69 |
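The `"\n No"` and `""` cases in the tests above score as indeterminate because the metric keeps only the first line of each prediction; a standalone sketch of that normalization step:

```python
def first_line_lower(prediction: str) -> str:
  # Mirrors the prediction normalization in graph_metrics.yes_no_accuracy:
  # keep only the first line, lowercased; empty input stays empty.
  lines = prediction.splitlines()
  return lines[0].lower() if lines else ''

print(first_line_lower('Yes\nNo'))  # 'yes' -> scored as a "yes"
print(first_line_lower('\n No'))    # ''    -> scored as indeterminate
```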
--------------------------------------------------------------------------------
/talk_like_a_graph/graph_tasks.py:
--------------------------------------------------------------------------------
1 | """The graph tasks to be tried with LLMs."""
2 |
3 | import random
4 |
5 | import networkx as nx
6 | import numpy as np
7 |
8 | from . import graph_text_encoders
9 |
10 |
11 | class GraphTask:
12 | """The parent class for all the graph tasks."""
13 |
14 | def __init__(self):
15 | self.name = 'default'
16 | self.maximum_nnodes_cot_graph = 10
17 |
18 | def prepare_examples_dict(
19 | self,
20 | graphs: list[nx.Graph],
21 | generator_algorithms: list[str],
22 | encoding_method: str,
23 | ) -> dict[int, dict[str, str | list[int]]]:
24 | raise NotImplementedError()
25 |
26 | def create_few_shot_example(
27 | self, graph: nx.Graph, encoding_method: str, cot: bool
28 | ):
29 | raise NotImplementedError()
30 |
31 |
32 | class CycleCheck(GraphTask):
33 | """The graph task to check if there is at least one cycle or not."""
34 |
35 | def __init__(self):
36 | super().__init__()
37 | self.name = 'cycle_check'
38 | self._task_description = 'Q: Is there a cycle in this graph?\nA: '
39 |
40 | def prepare_examples_dict(
41 | self,
42 | graphs: list[nx.Graph],
43 | generator_algorithms: list[str],
44 | encoding_method: str,
45 | ) -> dict[int, dict[str, str | list[int]]]:
46 | examples_dict = {}
47 | for ind, graph in enumerate(graphs):
48 | question = (
49 | graph_text_encoders.encode_graph(graph, encoding_method)
50 | + self._task_description
51 | )
52 | try:
53 | nx.find_cycle(graph)
54 | answer = 'Yes, there is a cycle.'
55 | except nx.NetworkXNoCycle:
56 | answer = 'No, there is no cycle.'
57 | examples_dict[ind] = {
58 | 'question': question,
59 | 'answer': answer,
60 | 'nnodes': str(len(graph.nodes())),
61 | 'nedges': str(len(graph.edges())),
62 | 'task_description': self._task_description,
63 | 'graph': graph,
64 | 'algorithm': generator_algorithms[ind],
65 | 'node_ids': [],
66 | }
67 | return examples_dict
68 |
69 | def create_few_shot_example(
70 | self, graph: nx.Graph, encoding_method: str, cot: bool
71 | ) -> str:
72 | """Create a few shot example w or w/o cot for the graph graph."""
73 | name_dict = graph_text_encoders.get_tlag_node_encoder(
74 | graph, encoding_method
75 | )
76 | question = (
77 | graph_text_encoders.encode_graph(graph, encoding_method)
78 | + self._task_description
79 | )
80 | try:
81 | cycle = nx.find_cycle(graph)
82 | cycle_text = ''
83 | answer = 'Yes, there is a cycle. '
84 | if cot:
85 | for pair in cycle:
86 | cycle_text += (
87 | name_dict[pair[0]]
88 | + ' is connected to '
89 | + name_dict[pair[1]]
90 | + ', '
91 | )
92 | cycle_cot = 'The cycle is: %s.' % cycle_text[:-2]
93 | answer += cycle_cot
94 | except nx.NetworkXNoCycle:
95 | answer = 'No, there is no cycle.'
96 | return question + answer
97 |
98 | def choose_few_shot_examples(
99 | self,
100 | few_shots_dict: dict[tuple[str, str], list[str]],
101 | encoding_method: str,
102 | k: int = 2,
103 | ) -> str:
104 | """Choose few shot examples for each algorithm."""
105 | pos_cycle_algorithms = ['er', 'ba', 'sbm', 'sfn', 'complete']
106 | neg_cycle_algorithms = ['star', 'path']
107 | few_shots_str = ''
108 | # choose k-1 shots for pos algorithms and one negative.
109 | positive_algorithms = random.choices(pos_cycle_algorithms, k=k - 1)
110 | for positive_algorithm in positive_algorithms:
111 | example_list = few_shots_dict[(positive_algorithm, encoding_method)]
112 | few_shots_str += 'Example: ' + random.choice(example_list) + '\n'
113 | negative_algorithm = random.choice(neg_cycle_algorithms)
114 | example_list = few_shots_dict[(negative_algorithm, encoding_method)]
115 | few_shots_str += 'Example: ' + random.choice(example_list) + '\n'
116 | return few_shots_str
117 |
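The detection step in `CycleCheck` above relies on `nx.find_cycle` raising `NetworkXNoCycle` on acyclic graphs; a minimal standalone sketch (the example graphs are illustrative):

```python
import networkx as nx

def has_cycle(graph: nx.Graph) -> bool:
  # nx.find_cycle returns the edges of one cycle, or raises
  # NetworkXNoCycle when the graph is acyclic.
  try:
    nx.find_cycle(graph)
    return True
  except nx.NetworkXNoCycle:
    return False

print(has_cycle(nx.cycle_graph(3)))  # True: a triangle
print(has_cycle(nx.path_graph(3)))   # False: a simple path
```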
118 |
119 | class EdgeExistence(GraphTask):
120 | """The graph task to check if an edge exist in a graph or not."""
121 |
122 | def __init__(self):
123 | super().__init__()
124 | self.name = 'edge_existence'
125 |
126 | def prepare_examples_dict(
127 | self,
128 | graphs: list[nx.Graph],
129 | generator_algorithms: list[str],
130 | encoding_method: str,
131 | ) -> dict[int, dict[str, str | list[int]]]:
132 | examples_dict = {}
133 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method)
134 |
135 | for ind, graph in enumerate(graphs):
136 | source, target = random.sample(list(graph.nodes()), k=2)
137 | question = graph_text_encoders.encode_graph(graph, encoding_method)
138 | task_description = 'Q: Is node %s connected to node %s?\nA: ' % (
139 | name_dict[source],
140 | name_dict[target],
141 | )
142 | question += task_description
143 | if ((source, target) in graph.edges()) or (
144 | (target, source) in graph.edges()
145 | ):
146 | answer = 'Yes.'
147 | else:
148 | answer = 'No.'
149 | examples_dict[ind] = {
150 | 'question': question,
151 | 'answer': answer,
152 | 'nnodes': str(len(graph.nodes())),
153 | 'nedges': str(len(graph.edges())),
154 | 'task_description': task_description,
155 | 'graph': graph,
156 | 'algorithm': generator_algorithms[ind],
157 | 'node_ids': [source, target],
158 | }
159 | return examples_dict
160 |
161 | def create_few_shot_example(
162 | self, graph: nx.Graph, encoding_method: str, cot: bool
163 | ) -> str:
164 | name_dict = graph_text_encoders.get_tlag_node_encoder(
165 | graph, encoding_method
166 | )
167 | source, target = random.sample(list(graph.nodes()), k=2)
168 | question = graph_text_encoders.encode_graph(graph, encoding_method)
169 | question += 'Q: Is node %s connected to node %s?\nA: ' % (
170 | name_dict[source],
171 | name_dict[target],
172 | )
173 | if ((source, target) in graph.edges()) or (
174 | (target, source) in graph.edges()
175 | ):
176 | answer = 'Yes.'
177 | if cot:
178 | answer += (
179 |               ' Because there is an edge from %s to %s in the graph description.'
180 | % (name_dict[source], name_dict[target])
181 | )
182 | else:
183 | answer = 'No.'
184 | if cot:
185 | answer += (
186 |               ' Because there is no edge from %s to %s in the graph description.'
187 | % (name_dict[source], name_dict[target])
188 | )
189 | return question + answer
190 |
191 |
192 | class NodeCount(GraphTask):
193 | """The graph task for finding number of nodes in a graph."""
194 |
195 | def __init__(self):
196 | super().__init__()
197 | self.name = 'node_count'
198 | self._task_description = 'Q: How many nodes are in this graph?\nA: '
199 |
200 | def prepare_examples_dict(
201 | self,
202 | graphs: list[nx.Graph],
203 | generator_algorithms: list[str],
204 | encoding_method: str,
205 | ) -> dict[int, dict[str, str | list[int]]]:
206 | examples_dict = {}
207 | for ind, graph in enumerate(graphs):
208 | question = graph_text_encoders.encode_graph(graph, encoding_method)
209 | question += self._task_description
210 | answer = ' %d.' % len(graph.nodes())
211 | examples_dict[ind] = {
212 | 'question': question,
213 | 'answer': answer,
214 | 'nnodes': str(len(graph.nodes())),
215 | 'nedges': str(len(graph.edges())),
216 | 'task_description': self._task_description,
217 | 'graph': graph,
218 | 'algorithm': generator_algorithms[ind],
219 | 'node_ids': [],
220 | }
221 | return examples_dict
222 |
223 | def get_nodes_string(self, name_dict: dict[int, str], nnodes: int) -> str:
224 | node_string = ''
225 | for i in range(nnodes - 1):
226 | node_string += name_dict[i] + ', '
227 | node_string += 'and ' + name_dict[nnodes - 1]
228 | return node_string
229 |
230 | def create_few_shot_example(
231 | self, graph: nx.Graph, encoding_method: str, cot: bool
232 | ) -> str:
233 | name_dict = graph_text_encoders.get_tlag_node_encoder(
234 | graph, encoding_method
235 | )
236 | question = graph_text_encoders.encode_graph(graph, encoding_method)
237 | question += self._task_description
238 | answer = '%d.' % len(graph.nodes())
239 | if cot:
240 | answer += ' The nodes are %s.' % self.get_nodes_string(
241 | name_dict, len(graph.nodes())
242 | )
243 |
244 | return question + answer
245 |
246 |
247 | class NodeDegree(GraphTask):
248 | """The graph task for finding degree of a node in a graph."""
249 |
250 | def __init__(self):
251 | super().__init__()
252 | self.name = 'node_degree'
253 |
254 | def prepare_examples_dict(
255 | self,
256 | graphs: list[nx.Graph],
257 | generator_algorithms: list[str],
258 | encoding_method: str,
259 | ) -> dict[int, dict[str, str | list[int]]]:
260 | examples_dict = {}
261 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method)
262 | for ind, graph in enumerate(graphs):
263 | question = graph_text_encoders.encode_graph(graph, encoding_method)
264 | source_node = random.sample(list(graph.nodes()), k=1)[0]
265 | task_description = (
266 | 'Q: What is the degree of node %s?\nA: ' % name_dict[source_node]
267 | )
268 | question += task_description
269 | answer = '%d.' % graph.degree[source_node]
270 | examples_dict[ind] = {
271 | 'question': question,
272 | 'answer': answer,
273 | 'nnodes': str(len(graph.nodes())),
274 | 'nedges': str(len(graph.edges())),
275 | 'task_description': task_description,
276 | 'graph': graph,
277 | 'algorithm': generator_algorithms[ind],
278 | 'node_ids': [source_node],
279 | }
280 | return examples_dict
281 |
282 | def get_edge_string(
283 | self, name_dict: dict[int, str], graph: nx.Graph, source_node: int
284 | ) -> str:
285 | """Gets a string identifying the edges a given node is connected to."""
286 | edge_string = ''
287 | target_edges = graph.edges(source_node)
288 | target_nodes = []
289 | for edge in target_edges:
290 | target_nodes.append(edge[1])
291 | if target_nodes:
292 | for i in range(len(target_nodes) - 1):
293 | edge_string += name_dict[target_nodes[i]] + ', '
294 | edge_string += 'and ' + name_dict[target_nodes[-1]]
295 | else:
296 | edge_string = 'no nodes'
297 | return edge_string
298 |
299 | def create_few_shot_example(
300 | self, graph: nx.Graph, encoding_method: str, cot: bool
301 | ) -> str:
302 | name_dict = graph_text_encoders.get_tlag_node_encoder(
303 | graph, encoding_method
304 | )
305 | question = graph_text_encoders.encode_graph(graph, encoding_method)
306 | source_node = random.sample(list(graph.nodes()), k=1)[0]
307 | question += (
308 | 'Q: What is the degree of node %s?\nA: ' % name_dict[source_node]
309 | )
310 | answer = '%d.' % graph.degree[source_node]
311 | if cot:
312 | answer += ' This is because %s is connected to %s.' % (
313 | name_dict[source_node],
314 | self.get_edge_string(name_dict, graph, source_node),
315 | )
316 | return question + answer
317 |
318 |
319 | class EdgeCount(GraphTask):
320 | """The graph task for finding number of edges in a graph."""
321 |
322 | def __init__(self):
323 | super().__init__()
324 | self.name = 'edge_count'
325 | self._task_description = 'Q: How many edges are in this graph?\nA: '
326 |
327 | def prepare_examples_dict(
328 | self,
329 | graphs: list[nx.Graph],
330 | generator_algorithms: list[str],
331 | encoding_method: str,
332 | ) -> dict[int, dict[str, str | list[int]]]:
333 | examples_dict = {}
334 | for ind, graph in enumerate(graphs):
335 | question = graph_text_encoders.encode_graph(graph, encoding_method)
336 | question += self._task_description
337 | answer = ' %d.' % len(graph.edges())
338 | examples_dict[ind] = {
339 | 'question': question,
340 | 'answer': answer,
341 | 'nnodes': str(len(graph.nodes())),
342 | 'nedges': str(len(graph.edges())),
343 | 'task_description': self._task_description,
344 | 'graph': graph,
345 | 'algorithm': generator_algorithms[ind],
346 | 'node_ids': [],
347 | }
348 | return examples_dict
349 |
350 | def get_edges_string(
351 | self, name_dict: dict[int, str], edges: list[tuple[int, int]]
352 | ) -> str:
353 | edges_string = ''
354 | for edge in edges:
355 | edges_string += (
356 | '(' + name_dict[edge[0]] + ', ' + name_dict[edge[1]] + '), '
357 | )
358 | return edges_string.strip()[:-1]
359 |
360 | def create_few_shot_example(
361 | self, graph: nx.Graph, encoding_method: str, cot: bool
362 | ) -> str:
363 | name_dict = graph_text_encoders.get_tlag_node_encoder(
364 | graph, encoding_method
365 | )
366 | question = graph_text_encoders.encode_graph(graph, encoding_method)
367 | question += self._task_description
368 | answer = '%d.' % len(graph.edges())
369 | if cot:
370 | answer += ' The edges are %s.' % self.get_edges_string(
371 | name_dict, list(graph.edges())
372 | )
373 | return question + answer
374 |
375 |
376 | class ConnectedNodes(GraphTask):
377 | """The graph task for finding connected nodes to a given node in a graph."""
378 |
379 | def __init__(self):
380 | super().__init__()
381 | self.name = 'connected_nodes'
382 |
383 | def prepare_examples_dict(
384 | self,
385 | graphs: list[nx.Graph],
386 | generator_algorithms: list[str],
387 | encoding_method: str,
388 | ) -> dict[int, dict[str, str | list[int]]]:
389 | examples_dict = {}
390 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method)
391 | for ind, graph in enumerate(graphs):
392 | question = graph_text_encoders.encode_graph(graph, encoding_method)
393 | source_node = random.sample(list(graph.nodes()), k=1)[0]
394 | task_description = (
395 | 'Q: List all the nodes connected to %s in alphabetical order.\nA: '
396 | % name_dict[source_node]
397 | )
398 | question += task_description
399 | outgoing_edges = list(graph.edges(source_node))
400 | if outgoing_edges:
401 | answer = self.get_connected_nodes(outgoing_edges, name_dict) + '.'
402 | else:
403 |         answer = 'No nodes.'
404 | examples_dict[ind] = {
405 | 'question': question,
406 | 'answer': answer,
407 | 'nnodes': str(len(graph.nodes())),
408 | 'nedges': str(len(graph.edges())),
409 | 'task_description': task_description,
410 | 'graph': graph,
411 | 'algorithm': generator_algorithms[ind],
412 | 'node_ids': [source_node],
413 | }
414 | return examples_dict
415 |
416 | def get_connected_nodes(
417 | self, edges: list[tuple[int, int]], name_dict: dict[int, str]
418 | ) -> str:
419 | """Gets a string including all the nodes that are connected to source."""
420 | connected_nodes = []
421 | for edge in edges:
422 | connected_nodes.append(name_dict[edge[1]])
423 | connected_nodes_string = ''
424 | if connected_nodes:
425 | try:
426 | int(connected_nodes[0])
427 | connected_nodes_string = ', '.join(map(str, connected_nodes))
428 | except ValueError:
429 |         # Non-integer node names are sorted alphabetically.
430 | connected_nodes_string = ', '.join(map(str, sorted(connected_nodes)))
431 | return connected_nodes_string
432 |
433 | def create_few_shot_example(
434 | self, graph: nx.Graph, encoding_method: str, cot: bool
435 | ) -> str:
436 | name_dict = graph_text_encoders.get_tlag_node_encoder(
437 | graph, encoding_method
438 | )
439 | question = graph_text_encoders.encode_graph(graph, encoding_method)
440 | source_node = random.sample(list(graph.nodes()), k=1)[0]
441 | question += (
442 | 'Q: List all the nodes connected to %s in alphabetical order.\nA: '
443 | % name_dict[source_node]
444 | )
445 | outgoing_edges = list(graph.edges(source_node))
446 | answer = ''
447 | if outgoing_edges:
448 | answer = self.get_connected_nodes(outgoing_edges, name_dict) + '.'
449 | if cot:
450 | answer += ' This is because there is an edge from %s to %s.' % (
451 | name_dict[source_node],
452 |             self.get_connected_nodes(outgoing_edges, name_dict),
453 | )
454 | else:
455 | answer = 'No nodes.'
456 | if cot:
457 | answer += (
458 | ' This is because %s is not connected to any node.'
459 | % name_dict[source_node]
460 | )
461 | return question + answer
462 |
463 |
464 | class DisconnectedNodes(GraphTask):
465 | """The task for finding disconnected nodes for a given node in a graph."""
466 |
467 | def __init__(self):
468 | super().__init__()
469 | self.name = 'disconnected_nodes'
470 |
471 | def prepare_examples_dict(
472 | self,
473 | graphs: list[nx.Graph],
474 | generator_algorithms: list[str],
475 | encoding_method: str,
476 | ) -> dict[int, dict[str, str | list[int]]]:
477 | examples_dict = {}
478 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method)
479 | for ind, graph in enumerate(graphs):
480 | question = graph_text_encoders.encode_graph(graph, encoding_method)
481 | source_node = random.sample(list(graph.nodes()), k=1)[0]
482 | task_description = (
483 | 'Q: List all the nodes that are not connected to %s in alphabetical'
484 | ' order.\nA: '
485 | % name_dict[source_node]
486 | )
487 | question += task_description
488 | outgoing_edges = list(graph.edges(source_node))
489 | answer = self.get_disconnected_nodes(
490 | source_node, outgoing_edges, name_dict, list(graph.nodes())
491 | )
492 | if not answer:
493 | answer = 'No nodes'
494 |
495 | answer += '.'
496 | examples_dict[ind] = {
497 | 'question': question,
498 | 'answer': answer,
499 | 'nnodes': str(len(graph.nodes())),
500 | 'nedges': str(len(graph.edges())),
501 | 'task_description': task_description,
502 | 'graph': graph,
503 | 'algorithm': generator_algorithms[ind],
504 | 'node_ids': [source_node],
505 | }
506 | return examples_dict
507 |
508 | def get_disconnected_nodes(
509 | self,
510 | source: int,
511 | edges: list[tuple[int, int]],
512 | name_dict: dict[int, str],
513 | all_nodes: list[int],
514 | ) -> str:
515 | """Gets a string with all the nodes that are not connected to source."""
516 | for edge in edges:
517 | if edge[1] in all_nodes:
518 | all_nodes.remove(edge[1])
519 | if source in all_nodes:
520 | all_nodes.remove(source)
521 | all_nodes_names = []
522 | for node in all_nodes:
523 | all_nodes_names.append(name_dict[node])
524 |     # The sort order differs for integer vs string node names.
525 |     if all_nodes_names:
526 |       try:
527 |         # Integer node names are sorted numerically.
528 |         all_nodes_names = [
529 |             str(value)
530 |             for value in sorted(int(name) for name in all_nodes_names)
531 |         ]
532 |       except ValueError:
533 |         # Non-integer node names are sorted alphabetically.
534 |         all_nodes_names = sorted(all_nodes_names)
535 |     return ', '.join(all_nodes_names)
536 |
537 | def create_few_shot_example(
538 | self, graph: nx.Graph, encoding_method: str, cot: bool
539 | ) -> str:
540 | name_dict = graph_text_encoders.get_tlag_node_encoder(
541 | graph, encoding_method
542 | )
543 | question = graph_text_encoders.encode_graph(graph, encoding_method)
544 | source_node = random.sample(list(graph.nodes()), k=1)[0]
545 | question += (
546 | 'Q: List all the nodes that are not connected to %s in alphabetical'
547 | ' order.\nA: '
548 | % name_dict[source_node]
549 | )
550 | outgoing_edges = list(graph.edges(source_node))
551 | answer = ''
552 | disconnected_nodes_string = self.get_disconnected_nodes(
553 | source_node, outgoing_edges, name_dict, list(graph.nodes())
554 | )
555 |     if disconnected_nodes_string:
556 |       answer = disconnected_nodes_string + '.'
557 |       if cot:
558 |         answer += ' This is because there is no edge from %s to %s.' % (
559 |             name_dict[source_node],
560 |             disconnected_nodes_string,
561 |         )
562 |     else:
563 |       answer = 'No nodes.'
564 |       if cot:
565 |         answer += (
566 |             ' This is because %s is connected to all nodes.'
567 |             % name_dict[source_node]
568 |         )
571 | return question + answer
572 |
573 |
574 | class Reachability(GraphTask):
575 | """The graph task to check if there is a path from a source to target."""
576 |
577 | def __init__(self):
578 | super().__init__()
579 | self.name = 'reachability'
580 |
581 | def prepare_examples_dict(
582 | self,
583 | graphs: list[nx.Graph],
584 | generator_algorithms: list[str],
585 | encoding_method: str,
586 | ) -> dict[int, dict[str, str | list[int]]]:
587 | examples_dict = {}
588 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method)
589 |
590 | for ind, graph in enumerate(graphs):
591 | source, target = random.sample(list(graph.nodes()), k=2)
592 | question = graph_text_encoders.encode_graph(graph, encoding_method)
593 | task_description = 'Q: Is there a path from node %s to node %s?\nA: ' % (
594 | name_dict[source],
595 | name_dict[target],
596 | )
597 | question += task_description
598 | if nx.has_path(graph, source, target):
599 | answer = 'Yes.'
600 | else:
601 | answer = 'No.'
602 | examples_dict[ind] = {
603 | 'question': question,
604 | 'answer': answer,
605 | 'nnodes': str(len(graph.nodes())),
606 | 'nedges': str(len(graph.edges())),
607 | 'task_description': task_description,
608 | 'graph': graph,
609 | 'algorithm': generator_algorithms[ind],
610 | 'node_ids': [source, target],
611 | }
612 | return examples_dict
613 |
614 | def create_few_shot_example(
615 | self, graph: nx.Graph, encoding_method: str, cot: bool
616 | ) -> str:
617 | name_dict = graph_text_encoders.get_tlag_node_encoder(
618 | graph, encoding_method
619 | )
620 | source, target = random.sample(list(graph.nodes()), k=2)
621 | question = graph_text_encoders.encode_graph(graph, encoding_method)
622 | question += 'Q: Is there a path from node %s to node %s?\nA: ' % (
623 | name_dict[source],
624 | name_dict[target],
625 | )
626 | if nx.has_path(graph, source, target):
627 | answer = 'Yes.'
628 | if cot:
629 | path = nx.shortest_path(graph, source, target)
630 | explanation = ' Because'
631 | for i in range(len(path) - 1):
632 | # The only edge or the non-last edges in the path.
633 | if len(path) == 2 or i < len(path) - 2:
634 | sep = ','
635 | # The last edge in a path with more than one edge.
636 | else:
637 | sep = ', and'
638 |           explanation += '%s there is an edge from node %s to node %s' % (
639 |               sep,
640 |               name_dict[path[i]],
641 |               name_dict[path[i + 1]],
642 |           )
643 |         explanation += '.'
644 | answer += explanation
645 | else:
646 | answer = 'No.'
647 | if cot:
648 | answer += (
649 |             ' Because there is no path connecting node %s to node %s based on'
650 | ' the graph description.' % (name_dict[source], name_dict[target])
651 | )
652 | return question + answer
653 |
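The `Reachability` answers above come directly from `nx.has_path` and `nx.shortest_path`; a small standalone check (the two-component graph is illustrative):

```python
import networkx as nx

# Two components: 0-1-2 and 3-4.
graph = nx.Graph([(0, 1), (1, 2), (3, 4)])

print(nx.has_path(graph, 0, 2))  # True
print(nx.has_path(graph, 0, 4))  # False
path = nx.shortest_path(graph, 0, 2)  # [0, 1, 2]
print(len(path) - 1)  # path length in edges: 2
```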
654 |
655 | class ShortestPath(GraphTask):
656 | """The graph task to check if there is a path from a source to target."""
657 |
658 | def __init__(self):
659 | super().__init__()
660 | self.name = 'shortest_path'
661 |
662 | def prepare_examples_dict(
663 | self,
664 | graphs: list[nx.Graph],
665 | generator_algorithms: list[str],
666 | encoding_method: str,
667 | ) -> dict[int, dict[str, str | list[int]]]:
668 | examples_dict = {}
669 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method)
670 |
671 | for ind, graph in enumerate(graphs):
672 | source, target = random.sample(list(graph.nodes()), k=2)
673 | question = graph_text_encoders.encode_graph(graph, encoding_method)
674 | task_description = (
675 | 'Q: What is the length of the shortest path from node %s to node'
676 | ' %s?\nA: '
677 | % (
678 | name_dict[source],
679 | name_dict[target],
680 | )
681 | )
682 | question += task_description
683 | try:
684 | path = nx.shortest_path(graph, source, target)
685 | answer = str(len(path) - 1) + '.'
686 | except nx.NetworkXNoPath:
687 | answer = 'There is no path from node %s to node %s.' % (
688 | name_dict[source],
689 | name_dict[target],
690 | )
691 | examples_dict[ind] = {
692 | 'question': question,
693 | 'answer': answer,
694 | 'nnodes': str(len(graph.nodes())),
695 | 'nedges': str(len(graph.edges())),
696 | 'task_description': task_description,
697 | 'graph': graph,
698 | 'algorithm': generator_algorithms[ind],
699 | 'node_ids': [source, target],
700 | }
701 | return examples_dict
702 |
703 | def create_few_shot_example(
704 | self, graph: nx.Graph, encoding_method: str, cot: bool
705 | ) -> str:
706 | name_dict = graph_text_encoders.get_tlag_node_encoder(
707 | graph, encoding_method
708 | )
709 | source, target = random.sample(list(graph.nodes()), k=2)
710 | question = graph_text_encoders.encode_graph(graph, encoding_method)
711 | question += (
712 | 'Q: What is the length of the shortest path from node %s to node'
713 | ' %s?\nA: '
714 | % (
715 | name_dict[source],
716 | name_dict[target],
717 | )
718 | )
719 | if nx.has_path(graph, source, target):
720 | path = nx.shortest_path(graph, source, target)
721 | answer = str(len(path) - 1) + '.'
722 | if cot:
723 | explanation = ' Because'
724 | for i in range(len(path) - 1):
725 | # The only edge or the non-last edges in the path.
726 | if len(path) == 2 or i < len(path) - 2:
727 | sep = ','
728 | # The last edge in a path with more than one edge.
729 | else:
730 | sep = ', and'
731 |           explanation += '%s there is an edge from node %s to node %s' % (
732 |               sep,
733 |               name_dict[path[i]],
734 |               name_dict[path[i + 1]],
735 |           )
736 |         explanation += '.'
737 | answer += explanation
738 | else:
739 | answer = 'There is no path from node %s to node %s.' % (
740 | name_dict[source],
741 | name_dict[target],
742 | )
743 | if cot:
744 | answer += (
745 |             ' Because there is no path connecting node %s to node %s based on'
746 | ' the graph description.' % (name_dict[source], name_dict[target])
747 | )
748 | return question + answer
749 |
750 |
751 | class TriangleCounting(GraphTask):
752 | """The graph task to count the number of triangles in a graph."""
753 |
754 | def __init__(self):
755 | super().__init__()
756 | self.name = 'triangle_counting'
757 | self._task_description = 'Q: How many triangles are in this graph?\nA: '
758 |
759 | def prepare_examples_dict(
760 | self,
761 | graphs: list[nx.Graph],
762 | generator_algorithms: list[str],
763 | encoding_method: str,
764 | ) -> dict[int, dict[str, str | list[int]]]:
765 | examples_dict = {}
766 | for ind, graph in enumerate(graphs):
767 | question = (
768 | graph_text_encoders.encode_graph(graph, encoding_method)
769 | + self._task_description
770 | )
771 | ntriangles = int(np.sum(list(nx.triangles(graph).values())) / 3)
772 |
773 | answer = '%i.' % ntriangles
774 | examples_dict[ind] = {
775 | 'question': question,
776 | 'answer': answer,
777 | 'nnodes': str(len(graph.nodes())),
778 | 'nedges': str(len(graph.edges())),
779 | 'task_description': self._task_description,
780 | 'graph': graph,
781 | 'algorithm': generator_algorithms[ind],
782 | 'node_ids': [],
783 | }
784 | return examples_dict
785 |
786 | def create_few_shot_example(
787 | self, graph: nx.Graph, encoding_method: str, cot: bool
788 | ) -> str:
789 | """Create a few shot example w or w/o cot for the graph graph."""
790 | name_dict = graph_text_encoders.get_tlag_node_encoder(
791 | graph, encoding_method
792 | )
793 | question = (
794 | graph_text_encoders.encode_graph(graph, encoding_method)
795 | + self._task_description
796 | )
797 | triangles_dict = nx.triangles(graph)
798 | ntriangles = int(np.sum(list(triangles_dict.values())) / 3)
799 |
800 | if ntriangles > 0:
801 | answer = '%i.' % ntriangles
802 | if cot:
803 | ntriangles_cot = ''
804 | for key, value in triangles_dict.items():
805 | if value > 0:
806 | if value == 1:
807 | ntriangles_cot += (
808 | 'There is %i triangle including node %s as a vertex.\n'
809 | % (value, name_dict[key])
810 | )
811 | else:
812 | ntriangles_cot += (
813 | 'There are %i triangles including node %s as a vertex.\n'
814 | % (value, name_dict[key])
815 | )
816 | ntriangles_cot += (
817 | 'Summing the number of triangles for all nodes and dividing them by'
818 | ' three gives us %i triangles in total.' % ntriangles
819 | )
820 | answer += ntriangles_cot
821 | else:
822 | answer = '0.'
823 | if cot:
824 | ntriangles_cot = 'No three nodes form a triangle of edges.'
825 | answer += ntriangles_cot
826 | return question + answer
827 |
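`TriangleCounting` above divides the per-node counts from `nx.triangles` by three, since each triangle is counted once per vertex; a quick standalone check on K4 (the example graph is illustrative):

```python
import networkx as nx
import numpy as np

graph = nx.complete_graph(4)    # K4 contains C(4, 3) = 4 triangles
per_node = nx.triangles(graph)  # each of the 4 nodes sits in 3 triangles
ntriangles = int(np.sum(list(per_node.values())) / 3)
print(ntriangles)  # 4
```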
828 |
829 | class MaximumFlow(GraphTask):
830 | """The graph task to compute the maximum flow from a source to a target."""
831 |
832 | def __init__(self):
833 | super().__init__()
834 | self.name = 'maximum_flow'
835 |
836 | def prepare_examples_dict(
837 | self,
838 | graphs: list[nx.Graph],
839 | generator_algorithms: list[str],
840 | encoding_method: str,
841 | ) -> dict[int, dict[str, str | list[int]]]:
842 | examples_dict = {}
843 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method)
844 |
845 | for ind, graph in enumerate(graphs):
846 | graph = add_edge_weight(graph)
847 | source, target = random.sample(list(graph.nodes()), k=2)
848 | question = graph_text_encoders.encode_graph(graph, encoding_method)
849 | task_description = (
850 | 'Q: What is the maximum capacity of the flow from node %s to node'
851 | ' %s?\nA: ' % (name_dict[source], name_dict[target])
852 | )
853 | question += task_description
854 | maximum_flow_value = nx.maximum_flow(
855 | graph, source, target, capacity='weight'
856 | )[0]
857 | answer = str(maximum_flow_value) + '.'
858 | examples_dict[ind] = {
859 | 'question': question,
860 | 'answer': answer,
861 | 'nnodes': str(len(graph.nodes())),
862 | 'nedges': str(len(graph.edges())),
863 | 'task_description': task_description,
864 | 'graph': graph,
865 | 'algorithm': generator_algorithms[ind],
866 | 'node_ids': [source, target],
867 | }
868 | return examples_dict
869 |
870 | def create_few_shot_example(
871 | self, graph: nx.Graph, encoding_method: str, cot: bool
872 | ) -> str:
873 | graph = add_edge_weight(graph)
874 | name_dict = graph_text_encoders.get_tlag_node_encoder(
875 | graph, encoding_method
876 | )
877 | source, target = random.sample(list(graph.nodes()), k=2)
878 | question = graph_text_encoders.encode_graph(graph, encoding_method)
879 | question += (
880 | 'Q: What is the maximum capacity of the flow from node %s to'
881 | ' node %s?\nA: ' % (name_dict[source], name_dict[target])
882 | )
883 | flow_value, flow_dict = nx.maximum_flow(
884 | graph, source, target, capacity='weight'
885 | )
886 | answer = str(flow_value) + '.'
887 | if flow_value > 0:
888 | if cot:
889 | explanation = ' This is because of the following edges: '
890 |       for edge_source, outflows in flow_dict.items():
891 |         for edge_target, value in outflows.items():
892 |           if value > 0:
893 |             explanation += (
894 |                 'the edge from node %s to node %s with capacity %i, '
895 |                 % (
896 |                     name_dict[edge_source],
897 |                     name_dict[edge_target],
898 | value,
899 | )
900 | )
901 | explanation = explanation.strip()[:-1] + '.'
902 | answer += explanation
903 | else:
904 | if cot:
905 | answer += (
906 |             ' Because there is no path connecting node %s to node %s based on'
907 | ' the graph description.' % (name_dict[source], name_dict[target])
908 | )
909 | return question + answer
910 |
911 |
912 | def has_edge_weights(graph):
913 | for _, _, data in graph.edges(data=True):
914 | if 'weight' not in data:
915 | return False
916 | return True
917 |
918 |
919 | def add_edge_weight(graph):
920 | if has_edge_weights(graph):
921 | return graph
922 | else:
923 | for edge in graph.edges():
924 | graph[edge[0]][edge[1]]['weight'] = random.randint(1, 10)
925 | return graph
926 |
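The two helpers above fill in missing capacities: if any edge lacks a `'weight'` attribute, every edge is assigned a fresh uniform random integer in [1, 10]. A minimal stand-alone sketch of that behavior, with a plain dict standing in for the `nx.Graph` edge attributes (the function name here is hypothetical):

```python
import random


def add_edge_weight_sketch(edge_attrs):
    # edge_attrs: dict mapping (u, v) -> attribute dict, mirroring graph[u][v].
    # If every edge already carries a weight, the input is returned untouched;
    # otherwise all edges are (re)assigned a random integer weight in [1, 10].
    if all('weight' in data for data in edge_attrs.values()):
        return edge_attrs
    for data in edge_attrs.values():
        data['weight'] = random.randint(1, 10)
    return edge_attrs


g = add_edge_weight_sketch({(0, 1): {}, (1, 2): {'weight': 3}})
print(all(1 <= d['weight'] <= 10 for d in g.values()))  # True
```

Note that, as in `add_edge_weight` above, a single unweighted edge causes all edges (including already-weighted ones) to be re-randomized.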
927 |
928 | class NodeClassification(GraphTask):
929 | """The graph task to classify a given node in the graph."""
930 |
931 | def __init__(self):
932 | super().__init__()
933 | self.name = 'node_classification'
934 | self.classes = [
935 | 'soccer',
936 | 'baseball',
937 | 'tennis',
938 | 'golf',
939 | 'football',
940 | 'surfing',
941 | ]
942 |
943 | def prepare_examples_dict(
944 | self,
945 | graphs: list[nx.Graph],
946 | generator_algorithms: list[str],
947 | encoding_method: str,
948 | ) -> dict[int, dict[str, str | list[int]]]:
949 | classes = random.sample(list(self.classes), k=2)
950 | examples_dict = {}
951 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method)
952 | for ind, graph in enumerate(graphs):
953 | question = graph_text_encoders.encode_graph(graph, encoding_method)
954 | nnodes = len(graph.nodes())
955 | # Sampling nnodes // 2 + 1 nodes.
956 | sampled_nodes = random.sample(
957 | list(graph.nodes(data=True)), k=nnodes // 2 + 1
958 | )
959 | # Adding the class of half of the nodes.
960 | for node_data in sampled_nodes[:-1]:
961 | node_class = classes[node_data[1]['block']]
962 | question += (
963 | 'Node ' + name_dict[node_data[0]] + ' likes ' + node_class + '.\n'
964 | )
965 | # Reserving the last sampled node for the question.
966 | task_description = 'Q: Does node %s like %s or %s?\nA: ' % (
967 | name_dict[sampled_nodes[-1][0]],
968 | classes[0],
969 | classes[1],
970 | )
971 | question += task_description
972 | answer = classes[sampled_nodes[-1][1]['block']]
973 |
974 | examples_dict[ind] = {
975 | 'question': question,
976 | 'answer': answer,
977 | 'nnodes': str(nnodes),
978 | 'nedges': str(len(graph.edges())),
979 | 'task_description': task_description,
980 | 'graph': graph,
981 | 'algorithm': generator_algorithms[ind],
982 |           # ID of the last sampled node.
983 | 'node_ids': [sampled_nodes[-1][0]],
984 | }
985 |
986 | return examples_dict
987 |
988 | def create_few_shot_example(
989 | self, graph: nx.Graph, encoding_method: str, cot: bool
990 | ) -> str:
991 | classes = random.sample(list(self.classes), k=2)
992 | name_dict = graph_text_encoders.get_tlag_node_encoder(
993 | graph, encoding_method
994 | )
995 | question = graph_text_encoders.encode_graph(graph, encoding_method)
996 | nnodes = len(graph.nodes())
997 | sampled_nodes = random.sample(
998 | list(graph.nodes(data=True)), k=nnodes // 2 + 1
999 | )
1000 | for node_data in sampled_nodes[:-1]:
1001 | node_class = classes[node_data[1]['block']]
1002 | question += (
1003 | 'Node ' + name_dict[node_data[0]] + ' likes ' + node_class + '.\n'
1004 | )
1005 | task_description = 'Q: Does node %s like %s or %s?\nA: ' % (
1006 | name_dict[sampled_nodes[-1][0]],
1007 | classes[0],
1008 | classes[1],
1009 | )
1010 | question += task_description
1011 | answer = classes[sampled_nodes[-1][1]['block']]
1012 |
1013 | if cot:
1014 | explanation = (
1015 |           ' This is because most of the nodes that are connected to node'
1016 |           ' %s like %s.'
1017 |           % (
1018 |               name_dict[sampled_nodes[-1][0]],
1019 |               classes[sampled_nodes[-1][1]['block']],
1020 |           )
1018 | )
1019 | answer += explanation
1020 | return question + answer
1021 |
--------------------------------------------------------------------------------
/talk_like_a_graph/graph_tasks_generator.py:
--------------------------------------------------------------------------------
1 | r"""The graph tasks to be tried with LLMs.
2 |
3 | This code loads graphs, creates graph tasks, and outputs them as tf examples
4 | in a recordio file in the task directory provided.
5 |
6 | # Placeholder for Google-internal comments.
7 | """
8 |
9 | from collections.abc import Sequence
10 | import os
11 | import random
12 |
13 | from absl import app
14 | from absl import flags
15 | import networkx as nx
16 | import numpy as np
17 |
18 | from . import graph_tasks
19 | from . import graph_tasks_utils as utils
20 |
21 | _TASK_DIR = flags.DEFINE_string(
22 | 'task_dir', None, 'The directory to write tasks.', required=True
23 | )
24 | _GRAPHS_DIR = flags.DEFINE_string(
25 | 'graphs_dir', None, 'The directory containing the graphs.', required=True
26 | )
27 | _RANDOM_SEED = flags.DEFINE_integer(
28 | 'random_seed',
29 | None,
30 | 'The random seed to use for task generation.',
31 | required=True,
32 | )
33 |
34 |
35 | def zero_shot(
36 | task: graph_tasks.GraphTask,
37 | graphs: list[nx.Graph],
38 | algorithms: list[str],
39 | text_encoders: list[str],
40 | cot: bool,
41 | random_seed: int,
42 | split: str,
43 | ) -> None:
44 | """Creating zero-shot or zero-cot examples for the given task.
45 |
46 | Args:
47 | task: the corresponding graph task.
48 | graphs: the list of graphs to use for the task.
49 | algorithms: the algorithm used to generate the graphs.
50 | text_encoders: the encoders to use in the tasks.
51 | cot: whether to apply cot or not.
52 | random_seed: the random seed to use in the process.
53 | split: whether we are creating a train or test split.
54 | """
55 | random.seed(random_seed)
56 | zero_shot_examples = utils.create_zero_shot_task(
57 | task, graphs, algorithms, text_encoders, cot=cot
58 | )
59 |
60 | file_name = task.name + ('_zero_cot_' if cot else '_zero_shot_')
61 |
62 | file_name += split + '.recordio'
63 | utils.write_examples(
64 | zero_shot_examples,
65 | os.path.join(_TASK_DIR.value, file_name),
66 | )
67 |
68 |
69 | def few_shot(
70 | task: graph_tasks.GraphTask,
71 | graphs: list[nx.Graph],
72 | few_shot_graphs: list[nx.Graph],
73 | algorithms: list[str],
74 | text_encoders: list[str],
75 | cot: bool,
76 | bag: bool,
77 | random_seed: int,
78 | ) -> None:
79 | """Creating few-shot, cot, or cot-bag examples for the given task.
80 |
81 | Args:
82 | task: the corresponding graph task.
83 | graphs: the list of graphs to use for the task.
84 | few_shot_graphs: the list of graphs to generate few shot examples for.
85 | algorithms: the algorithm used to generate the graphs.
86 | text_encoders: the encoders to use in the tasks.
87 | cot: whether to apply cot or not.
88 | bag: whether to apply build-a-graph method or not.
89 | random_seed: the random seed to use in the process.
90 | """
91 | random.seed(random_seed)
92 | few_shot_examples = utils.create_few_shot_task(
93 | task,
94 | graphs,
95 | algorithms,
96 | few_shot_graphs,
97 | text_encoders,
98 | cot=cot,
99 | bag=bag,
100 | random_seed=random_seed,
101 | )
102 | file_name = task.name
103 | if cot and bag:
104 | file_name += '_few_shot_cot_bag_test.recordio'
105 | elif cot:
106 | file_name += '_few_shot_cot_test.recordio'
107 | else:
108 | file_name += '_few_shot_test.recordio'
109 |
110 | utils.write_examples(
111 | few_shot_examples,
112 | os.path.join(_TASK_DIR.value, file_name),
113 | )
114 |
115 |
116 | def generate_random_sbm_graph(random_state: np.random.RandomState):
117 | # Sampling a small number as the probability of the two nodes in different
118 | # communities being connected.
119 | small_number = random.uniform(0, 0.05)
120 |   # Sampling a large number as the probability of the nodes in one community
121 | # being connected.
122 | large_number = random.uniform(0.6, 0.8)
123 | number_of_nodes = random.choice(np.arange(5, 20))
124 | sizes = [number_of_nodes // 2, number_of_nodes // 2]
125 | probs = [[large_number, small_number], [small_number, large_number]]
126 | return nx.stochastic_block_model(sizes, probs, seed=random_state)
127 |
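The SBM generator above samples two equal-size communities with a dense intra-community edge probability (0.6–0.8) and a sparse inter-community probability (0–0.05). A stdlib-only sketch of the parameter sampling, without the `nx.stochastic_block_model` call (the function name is illustrative):

```python
import random


def sample_sbm_params():
    # Mirrors generate_random_sbm_graph's parameter choices: sparse across
    # communities, dense within a community, 5-19 nodes split into two blocks.
    p_between = random.uniform(0, 0.05)
    p_within = random.uniform(0.6, 0.8)
    n = random.randrange(5, 20)
    sizes = [n // 2, n // 2]
    probs = [[p_within, p_between], [p_between, p_within]]
    return sizes, probs


sizes, probs = sample_sbm_params()
print(sizes, probs[0][0] > probs[0][1])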
128 |
129 | def main(argv: Sequence[str]) -> None:
130 | if len(argv) > 1:
131 | raise app.UsageError('Too many command-line arguments.')
132 |
133 | algorithms = ['er']
134 | directions = ['undirected']
135 | text_encoders = ['adjacency']
136 |
137 | # Loading the graphs.
138 | graphs = []
139 | generator_algorithms = []
140 | for algorithm in algorithms:
141 | for direction in directions:
142 | loaded_graphs = utils.load_graphs(
143 | _GRAPHS_DIR.value,
144 | algorithm,
145 | 'train',
146 | direction,
147 | )
148 | graphs += loaded_graphs
149 | generator_algorithms += [algorithm] * len(loaded_graphs)
150 |
151 |   # Defining a task on the graphs.
152 | task = graph_tasks.ShortestPath()
153 |
154 | if isinstance(task, graph_tasks.NodeClassification):
155 | # The node classification task requires SBM graphs. As it's not possible to
156 |   # write graphs with data (e.g., block data as in SBM graphs), we regenerate
157 | # graphs.
158 |
159 | random_state = np.random.RandomState(_RANDOM_SEED.value)
160 | print('Generating sbm graphs')
161 | graphs = [
162 | generate_random_sbm_graph(random_state) for _ in range(len(graphs))
163 | ]
164 |
165 | zero_shot(
166 | task,
167 | graphs,
168 | generator_algorithms,
169 | text_encoders,
170 | cot=False,
171 | random_seed=_RANDOM_SEED.value,
172 | split='test',
173 | )
174 | zero_shot(
175 | task,
176 | graphs,
177 | generator_algorithms,
178 | text_encoders,
179 | cot=True,
180 | random_seed=_RANDOM_SEED.value,
181 | split='test',
182 | )
183 |
184 | # Loading few-shot graphs.
185 | few_shot_graphs = []
186 | for algorithm in algorithms:
187 | for direction in directions:
188 | few_shot_graphs += utils.load_graphs(
189 | _GRAPHS_DIR.value,
190 | algorithm,
191 | 'train',
192 | direction,
193 | )
194 |
195 | if isinstance(task, graph_tasks.NodeClassification):
196 | # The node classification task requires SBM graphs. As it's not possible to
197 |   # write graphs with data (e.g., block data as in SBM graphs), we regenerate
198 | # graphs.
199 | random_state = np.random.RandomState(_RANDOM_SEED.value + 1)
200 | print('Generating few shot sbm graphs')
201 | few_shot_graphs = [
202 | generate_random_sbm_graph(random_state)
203 | for _ in range(len(few_shot_graphs))
204 | ]
205 |
206 | few_shot(
207 | task,
208 | graphs,
209 | few_shot_graphs,
210 | generator_algorithms,
211 | text_encoders,
212 | cot=False,
213 | bag=False,
214 | random_seed=_RANDOM_SEED.value,
215 | )
216 |
217 | few_shot(
218 | task,
219 | graphs,
220 | few_shot_graphs,
221 | generator_algorithms,
222 | text_encoders,
223 | cot=True,
224 | bag=False,
225 | random_seed=_RANDOM_SEED.value,
226 | )
227 |
228 | few_shot(
229 | task,
230 | graphs,
231 | few_shot_graphs,
232 | generator_algorithms,
233 | text_encoders,
234 | cot=True,
235 | bag=True,
236 | random_seed=_RANDOM_SEED.value,
237 | )
238 |
239 |
240 | if __name__ == '__main__':
241 | app.run(main)
242 |
--------------------------------------------------------------------------------
/talk_like_a_graph/graph_tasks_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for creating and writing graph task examples."""
2 |
3 | import os
4 | import random
5 |
6 | import networkx as nx
7 | import numpy as np
8 | import seqio
9 | import tensorflow as tf
10 | import tensorflow_gnn as tfgnn
11 |
12 | # Google-internal import(s).
13 | # Internal import.
14 | from . import graph_tasks
15 | from tensorflow.core.example import example_pb2
16 | from tensorflow.core.example import feature_pb2
17 |
18 |
19 | def laplacian_pos_embedding(graph: nx.Graph, units: int = 4) -> nx.Graph:
20 | """Adds the laplacian positional encoding."""
21 | m = nx.normalized_laplacian_matrix(
22 | graph, nodelist=sorted(graph.nodes), weight=None
23 | ).astype(np.float32)
24 | u, _, _ = np.linalg.svd(m.todense(), compute_uv=True)
25 | if units > u.shape[1]:
26 | u = np.pad(u, ((0, 0), (0, units - u.shape[1])))
27 | nx.set_node_attributes(
28 | graph, dict(zip(sorted(graph.nodes), u[:, :units])), name='lpe'
29 | )
30 | return graph
31 |
32 |
33 | def to_tfgnn(graph: nx.Graph, node_ids: list[int]) -> tfgnn.GraphTensor:
34 | """Convert a given nx graph to a tfgnn graph."""
35 | if graph.edges(data=True):
36 | s, t, w = zip(*[
37 | (s, t, (d['weight'] if d and 'weight' in d else None))
38 | for s, t, d in graph.edges(data=True)
39 | ])
40 | else:
41 | s, t, w = (), (), ()
42 | # tfgnn assumes graphs are directed. Adding the rev edges for an undirected
43 | # graph.
44 | if not graph.is_directed():
45 | s, t, w = s + t, t + s, w + w
46 |
47 | graph = laplacian_pos_embedding(graph, units=4)
48 | features = set(k for n in graph.nodes for k in graph.nodes[n].keys()) # pylint: disable=g-complex-comprehension
49 | node_features = {
50 | f: tf.convert_to_tensor([graph.nodes[n][f] for n in graph.nodes])
51 | for f in features
52 | }
53 | # If all edges have a non-trivial weight, then we record the weights.
54 | if all(w):
55 | edge_features = {'weights': tf.convert_to_tensor(w, dtype=tf.int32)}
56 | gt = tfgnn.homogeneous(
57 | tf.convert_to_tensor(s, dtype=tf.int32),
58 | tf.convert_to_tensor(t, dtype=tf.int32),
59 | node_features=node_features,
60 | edge_features=edge_features,
61 | )
62 | else:
63 | gt = tfgnn.homogeneous(
64 | tf.convert_to_tensor(s, dtype=tf.int32),
65 | tf.convert_to_tensor(t, dtype=tf.int32),
66 | node_features=node_features,
67 | )
68 |
69 | if not node_ids:
70 | # No node is mentioned in the task description.
71 | return gt
72 | node_sets = {
73 | **gt.node_sets,
74 | '_readout': tfgnn.NodeSet.from_fields(sizes=[1]),
75 | }
76 | if len(node_ids) == 1:
77 | # Tasks requiring only one node id e.g., computing node degree.
78 | edge_sets = {
79 | **gt.edge_sets,
80 | '_readout/node': tfgnn.EdgeSet.from_fields(
81 | sizes=[1],
82 | adjacency=tfgnn.Adjacency.from_indices(
83 | source=('nodes', node_ids),
84 | target=('_readout', [0]),
85 | ),
86 | ),
87 | }
88 | elif len(node_ids) == 2:
89 | # Tasks requiring two nodes e.g., shortest path from one node to the other.
90 | edge_sets = {
91 | **gt.edge_sets,
92 | '_readout/source': tfgnn.EdgeSet.from_fields(
93 | sizes=[1],
94 | adjacency=tfgnn.Adjacency.from_indices(
95 | source=('nodes', node_ids[:1]),
96 | target=('_readout', [0]),
97 | ),
98 | ),
99 | '_readout/target': tfgnn.EdgeSet.from_fields(
100 | sizes=[1],
101 | adjacency=tfgnn.Adjacency.from_indices(
102 | source=('nodes', node_ids[1:]),
103 | target=('_readout', [0]),
104 | ),
105 | ),
106 | }
107 | else:
108 |     # Raising an error if more than two nodes are mentioned.
109 | raise ValueError(f'Invalid number of integers: {len(node_ids)}')
110 |
111 | return tfgnn.GraphTensor.from_pieces(
112 | context=gt.context, node_sets=node_sets, edge_sets=edge_sets
113 | )
114 |
115 |
116 | def create_example_feature(
117 | key: int,
118 | question: str,
119 | answer: str,
120 | algorithm: str,
121 | encoding_method: str,
122 | nnodes: str,
123 | nedges: str,
124 | task_description: str,
125 | graph: nx.Graph,
126 | node_ids: list[int],
127 | ) -> example_pb2.Example:
128 | """Create a tensorflow example from a datapoint."""
129 | key_feature = feature_pb2.Feature(
130 | bytes_list=tf.train.BytesList(value=[str(key).encode()])
131 | )
132 | question_feature = feature_pb2.Feature(
133 | bytes_list=tf.train.BytesList(value=[question.encode()])
134 | )
135 | answer_feature = feature_pb2.Feature(
136 | bytes_list=tf.train.BytesList(value=[answer.encode()])
137 | )
138 | algorithm_feature = feature_pb2.Feature(
139 | bytes_list=tf.train.BytesList(value=[algorithm.encode()])
140 | )
141 | encoding_method_feature = feature_pb2.Feature(
142 | bytes_list=tf.train.BytesList(value=[encoding_method.encode()])
143 | )
144 | nnodes_feature = feature_pb2.Feature(
145 | bytes_list=tf.train.BytesList(value=[nnodes.encode()])
146 | )
147 | nedges_feature = feature_pb2.Feature(
148 | bytes_list=tf.train.BytesList(value=[nedges.encode()])
149 | )
150 | task_description_feature = feature_pb2.Feature(
151 | bytes_list=tf.train.BytesList(value=[task_description.encode()])
152 | )
153 | gt = to_tfgnn(graph, node_ids)
154 | graph_feature = feature_pb2.Feature(
155 | bytes_list=tf.train.BytesList(
156 | value=[tfgnn.write_example(gt).SerializeToString()]
157 | )
158 | )
159 | directed_feature = feature_pb2.Feature(
160 | bytes_list=tf.train.BytesList(value=[str(graph.is_directed()).encode()])
161 | )
162 | example_feats = tf.train.Features(
163 | feature={
164 | 'id': key_feature,
165 | 'question': question_feature,
166 | 'answer': answer_feature,
167 | 'algorithm': algorithm_feature,
168 | 'text_encoding': encoding_method_feature,
169 | 'nnodes': nnodes_feature,
170 | 'nedges': nedges_feature,
171 | 'task_description': task_description_feature,
172 | 'graph': graph_feature,
173 | 'directed': directed_feature,
174 | }
175 | )
176 | return example_pb2.Example(features=example_feats)
177 |
178 |
179 | def load_graphs(
180 | base_path: str,
181 | algorithm: str,
182 | split: str,
183 | direction: str,
184 | max_nnodes: int = 20,
185 | ) -> list[nx.Graph]:
186 | """Load a list of graphs from a given algorithm and split."""
187 | graphs_path = os.path.join(
188 | base_path,
189 | direction,
190 | algorithm,
191 | split,
192 | )
193 | loaded_graphs = []
194 |   all_files = tf.io.gfile.listdir(graphs_path)
195 | for file in all_files:
196 | if file.endswith('.graphml'):
197 | path = os.path.join(graphs_path, file)
198 |       graph = nx.read_graphml(tf.io.gfile.GFile(path, 'rb'), node_type=int)
199 | if graph.number_of_nodes() <= max_nnodes:
200 | loaded_graphs.append(graph)
201 | return loaded_graphs
202 |
203 |
204 | def prepare_examples(
205 | examples_dict: dict[int, dict[str, str | list[int]]],
206 | encoding_method: str,
207 | ) -> list[example_pb2.Example]:
208 | """Create a list of tf.train.Example from a dict of examples."""
209 | examples = []
210 | for key, value in examples_dict.items():
211 | (
212 | question,
213 | answer,
214 | nnodes,
215 | nedges,
216 | task_description,
217 | graph,
218 | algorithm,
219 | node_ids,
220 | ) = (
221 | value['question'],
222 | value['answer'],
223 | value['nnodes'],
224 | value['nedges'],
225 | value['task_description'],
226 | value['graph'],
227 | value['algorithm'],
228 | value['node_ids'],
229 | )
230 | examples.append(
231 | create_example_feature(
232 | key,
233 | question,
234 | answer,
235 | algorithm,
236 | encoding_method,
237 | nnodes,
238 | nedges,
239 | task_description,
240 | graph,
241 | node_ids,
242 | )
243 | )
244 | return examples
245 |
246 |
247 | def create_zero_shot_task(
248 | task: graph_tasks.GraphTask,
249 | graphs: list[nx.Graph],
250 | generator_algorithms: list[str],
251 | text_encoders: list[str],
252 | cot: bool = False,
253 | ) -> list[example_pb2.Example]:
254 | """Create a recordio file with zero-shot examples for the task."""
255 | examples = []
256 | for encoding_method in text_encoders:
257 | examples_dict = task.prepare_examples_dict(
258 | graphs, generator_algorithms, encoding_method
259 | )
260 | if cot:
261 | for key in examples_dict.keys():
262 | examples_dict[key]['question'] += "Let's think step by step. "
263 | examples += prepare_examples(examples_dict, encoding_method)
264 | return examples
265 |
266 |
267 | def write_examples(examples: list[example_pb2.Example], output_path: str):
268 |   with tf.io.TFRecordWriter(output_path) as output_file:
269 |     for example in examples:
270 |       output_file.write(example.SerializeToString())
271 |
272 |
273 | def prepare_few_shots(
274 | task: graph_tasks.GraphTask,
275 | graphs: list[nx.Graph],
276 | text_encoders: list[str],
277 | cot: bool,
278 | ) -> dict[str, list[str]]:
279 | """Create a dict of few-shot examples with their cot for the task."""
280 | few_shots_examples_dict = {}
281 | for encoding_method in text_encoders:
282 | if encoding_method not in few_shots_examples_dict:
283 |       few_shots_examples_dict[encoding_method] = []
284 |     for graph in graphs:
285 |       few_shots_examples_dict[encoding_method].append(
286 | task.create_few_shot_example(graph, encoding_method, cot)
287 | )
288 | return few_shots_examples_dict
289 |
290 |
291 | def choose_few_shot_examples(
292 | few_shots_dict: dict[str, list[str]],
293 | encoding_method: str,
294 | k: int = 2,
295 | ) -> str:
296 | """Choose few shot examples for each algorithm."""
297 | few_shots_str = ''
298 | for _ in range(k):
299 | example_list = few_shots_dict[encoding_method]
300 | few_shots_str += 'Example: ' + random.choice(example_list) + '\n'
301 | return few_shots_str
302 |
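`choose_few_shot_examples` above draws each of the `k` prompts independently with `random.choice`, i.e. sampling with replacement, so the same example can appear twice in a prompt. A stand-alone sketch of that selection step (the function name here is illustrative):

```python
import random


def choose_examples_sketch(example_list, k=2):
    # Each draw is independent (with replacement), mirroring the
    # random.choice call inside the loop of choose_few_shot_examples.
    few_shots_str = ''
    for _ in range(k):
        few_shots_str += 'Example: ' + random.choice(example_list) + '\n'
    return few_shots_str


print(choose_examples_sketch(['Q: ...\nA: 1.'], k=2))
```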
303 |
304 | def create_few_shot_task(
305 | task: graph_tasks.GraphTask,
306 | graphs: list[nx.Graph],
307 | generator_algorithms: list[str],
308 | few_shots_graphs: list[nx.Graph],
309 | text_encoders: list[str],
310 | cot: bool,
311 | bag: bool,
312 | random_seed: int,
313 | ) -> list[example_pb2.Example]:
314 | """Create a recordio file with few-shot examples for the task."""
315 | # LINT.IfChange
316 | vocab_path = None
317 | # LINT.ThenChange(//research/graph/llm/graphqa/copy.bara.sky)
318 |   # Loading the PaLM tokenizer to calculate the number of tokens in the sequence.
319 | sp_vocab = seqio.SentencePieceVocabulary(vocab_path)
320 | number_of_tokens = {}
321 | examples = []
322 | print('prepare few shot task', 'cot', cot, 'bag', bag)
323 | few_shots_examples_dict = prepare_few_shots(
324 | task,
325 | few_shots_graphs,
326 | text_encoders,
327 | cot,
328 | )
329 | for encoding_method in text_encoders:
330 | random.seed(random_seed)
331 | examples_dict = task.prepare_examples_dict(
332 | graphs, generator_algorithms, encoding_method
333 | )
334 | for key in examples_dict.keys():
335 | few_shots_examples = choose_few_shot_examples(
336 | few_shots_examples_dict,
337 | encoding_method,
338 | )
339 | examples_dict[key]['question'] = (
340 | few_shots_examples + 'Example: ' + examples_dict[key]['question']
341 | )
342 | if bag:
343 | examples_dict[key]['question'] = examples_dict[key]['question'].replace(
344 | '\nQ: ',
345 | "\nLet's construct the graph with the nodes and edges first.\nQ: ",
346 | ) # pytype: disable=attribute-error
347 | if encoding_method not in number_of_tokens:
348 | number_of_tokens[encoding_method] = []
349 | number_of_tokens[encoding_method].append(
350 | len(sp_vocab.encode(examples_dict[key]['question']))
351 | )
352 | examples += prepare_examples(examples_dict, encoding_method)
353 |
354 |   # Printing the maximum number of tokens per encoding method.
355 | for key, value in number_of_tokens.items():
356 | print(key, np.max(value))
357 |
358 | return examples
359 |
--------------------------------------------------------------------------------
/talk_like_a_graph/graph_text_encoders.py:
--------------------------------------------------------------------------------
1 | """Library for encoding graphs in text."""
2 |
3 | import networkx as nx
4 |
5 | from . import name_dictionaries
6 |
7 |
8 | def create_node_string(name_dict, nnodes: int) -> str:
9 | node_string = ""
10 | sorted_keys = list(sorted(name_dict.keys()))
11 | for i in sorted_keys[: nnodes - 1]:
12 | node_string += name_dict[i] + ", "
13 | node_string += "and " + name_dict[sorted_keys[nnodes - 1]]
14 | return node_string
15 |
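For illustration, `create_node_string` above produces an Oxford-comma list of the first `nnodes` node names. A quick stand-alone sketch of the same joining logic (the name dict here is made up):

```python
def node_list_sketch(name_dict, nnodes):
    # Same joining logic as create_node_string: "a, b, and c".
    keys = sorted(name_dict.keys())
    return ", ".join(
        name_dict[k] for k in keys[: nnodes - 1]
    ) + ", and " + name_dict[keys[nnodes - 1]]


print(node_list_sketch({0: "James", 1: "Mary", 2: "Robert"}, 3))
# James, Mary, and Robert
```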
16 |
17 | def nx_encoder(graph: nx.Graph, _: dict[int, str], edge_type="id") -> str:
18 |   """Encoding a graph as (source, edge type, target) triples."""
19 | if graph.is_directed():
20 | output = (
21 | "In a directed graph, (s,p,o) means that there is an edge from node s"
22 | " to node o of type p. "
23 | )
24 | else:
25 | output = (
26 | "In an undirected graph, (s,p,o) means that node s and node o are"
27 | " connected with an undirected edge of type p. "
28 | )
29 |
30 | name_dict = {x: str(x) for x in graph.nodes()}
31 |
32 | nodes_string = create_node_string(name_dict, nnodes=len(graph.nodes()))
33 | output += "G describes a graph among nodes %s.\n" % nodes_string
34 | if graph.edges():
35 | output += "The edges in G are: "
36 | for i, j in graph.edges():
37 |       etype = graph.get_edge_data(i, j).get(edge_type)
38 |       if etype is None:
39 |         etype = "linked"
40 |       output += "(%s, %s, %s) " % (name_dict[i], etype, name_dict[j])
41 | return output.strip() + ".\n"
42 |
43 |
44 | def adjacency_encoder(graph: nx.Graph, name_dict: dict[int, str]) -> str:
45 | """Encoding a graph as entries of an adjacency matrix."""
46 | if graph.is_directed():
47 | output = (
48 | "In a directed graph, (i,j) means that there is an edge from node i to"
49 | " node j. "
50 | )
51 | else:
52 | output = (
53 | "In an undirected graph, (i,j) means that node i and node j are"
54 | " connected with an undirected edge. "
55 | )
56 | nodes_string = create_node_string(name_dict, len(graph.nodes()))
57 | output += "G describes a graph among nodes %s.\n" % nodes_string
58 | if graph.edges():
59 | output += "The edges in G are: "
60 | for i, j in graph.edges():
61 | output += "(%s, %s) " % (name_dict[i], name_dict[j])
62 | return output.strip() + ".\n"
63 |
64 |
65 | def friendship_encoder(graph: nx.Graph, name_dict: dict[int, str]) -> str:
66 | """Encoding a graph as a friendship graph."""
67 | if graph.is_directed():
68 | raise ValueError("Friendship encoder is not defined for directed graphs.")
69 | nodes_string = create_node_string(name_dict, len(graph.nodes()))
70 | output = (
71 | "G describes a friendship graph among nodes %s.\n" % nodes_string.strip()
72 | )
73 | if graph.edges():
74 | output += "We have the following edges in G:\n"
75 | for i, j in graph.edges():
76 | output += "%s and %s are friends.\n" % (name_dict[i], name_dict[j])
77 | return output
78 |
79 |
80 | def coauthorship_encoder(graph: nx.Graph, name_dict: dict[int, str]) -> str:
81 | """Encoding a graph as a coauthorship graph."""
82 | if graph.is_directed():
83 | raise ValueError("Coauthorship encoder is not defined for directed graphs.")
84 | nodes_string = create_node_string(name_dict, len(graph.nodes()))
85 | output = (
86 | "G describes a coauthorship graph among nodes %s.\n"
87 | % nodes_string.strip()
88 | )
89 | if graph.edges():
90 | output += "In this coauthorship graph:\n"
91 | for i, j in graph.edges():
92 | output += "%s and %s wrote a paper together.\n" % (
93 | name_dict[i],
94 | name_dict[j],
95 | )
96 |   return output
97 |
98 |
99 | def incident_encoder(graph: nx.Graph, name_dict: dict[int, str]) -> str:
100 |   """Encoding a graph with its incidence lists."""
101 | nodes_string = create_node_string(name_dict, len(graph.nodes()))
102 | output = "G describes a graph among nodes %s.\n" % nodes_string
103 | if graph.edges():
104 | output += "In this graph:\n"
105 | for source_node in graph.nodes():
106 | target_nodes = graph.neighbors(source_node)
107 | target_nodes_str = ""
108 | nedges = 0
109 | for target_node in target_nodes:
110 | target_nodes_str += name_dict[target_node] + ", "
111 | nedges += 1
112 | if nedges > 1:
113 | output += "Node %s is connected to nodes %s.\n" % (
114 |           name_dict[source_node],
115 | target_nodes_str[:-2],
116 | )
117 | elif nedges == 1:
118 |       output += "Node %s is connected to node %s.\n" % (
119 |           name_dict[source_node],
120 | target_nodes_str[:-2],
121 | )
122 | return output
123 |
124 |
125 | def social_network_encoder(graph: nx.Graph, name_dict: dict[int, str]) -> str:
126 | """Encoding a graph as a social network graph."""
127 | if graph.is_directed():
128 | raise ValueError(
129 | "Social network encoder is not defined for directed graphs."
130 | )
131 | nodes_string = create_node_string(name_dict, len(graph.nodes()))
132 | output = (
133 | "G describes a social network graph among nodes %s.\n"
134 | % nodes_string.strip()
135 | )
136 | if graph.edges():
137 | output += "We have the following edges in G:\n"
138 | for i, j in graph.edges():
139 | output += "%s and %s are connected.\n" % (name_dict[i], name_dict[j])
140 | return output
141 |
142 |
143 | def expert_encoder(graph: nx.Graph, name_dict: dict[int, str]) -> str:
144 | nodes_string = create_node_string(name_dict, len(graph.nodes()))
145 | output = (
146 | "You are a graph analyst and you have been given a graph G among nodes"
147 | " %s.\n"
148 | % nodes_string.strip()
149 | )
150 | output += "G has the following undirected edges:\n" if graph.edges() else ""
151 | for i, j in graph.edges():
152 | output += "%s -> %s\n" % (name_dict[i], name_dict[j])
153 | return output
154 |
155 |
156 | def nodes_to_text(graph, encoding_type):
157 | """Get dictionary converting node ids to text."""
158 | if encoding_type == "integer":
159 | return name_dictionaries.create_name_dict(graph, "integer", nnodes=1000)
160 | elif encoding_type == "popular":
161 | return name_dictionaries.create_name_dict(graph, "popular")
162 | elif encoding_type == "alphabet":
163 | return name_dictionaries.create_name_dict(graph, "alphabet")
164 | elif encoding_type == "got":
165 | return name_dictionaries.create_name_dict(graph, "got")
166 | elif encoding_type == "south_park":
167 | return name_dictionaries.create_name_dict(graph, "south_park")
168 | elif encoding_type == "politician":
169 | return name_dictionaries.create_name_dict(graph, "politician")
170 | elif encoding_type == "random":
171 | return name_dictionaries.create_name_dict(
172 | graph, "random_integer", nnodes=1000
173 | )
174 | elif encoding_type == "nx_node_name":
175 | return name_dictionaries.create_name_dict(graph, "nx_node_name")
176 | else:
177 | raise ValueError("Unknown encoding type: %s" % encoding_type)
178 |
179 |
180 | def get_tlag_node_encoder(graph, encoder_name):
181 | """Find the node encoder used in the 'Talk Like a Graph' paper."""
182 | if encoder_name == "adjacency":
183 | return nodes_to_text(graph, "integer")
184 | elif encoder_name == "incident":
185 | return nodes_to_text(graph, "integer")
186 | elif encoder_name == "friendship":
187 | return nodes_to_text(graph, "popular")
188 | elif encoder_name == "south_park":
189 | return nodes_to_text(graph, "south_park")
190 | elif encoder_name == "got":
191 | return nodes_to_text(graph, "got")
192 | elif encoder_name == "politician":
193 | return nodes_to_text(graph, "politician")
194 | elif encoder_name == "social_network":
195 | return nodes_to_text(graph, "popular")
196 | elif encoder_name == "expert":
197 |     return nodes_to_text(graph, "alphabet")
198 | elif encoder_name == "coauthorship":
199 | return nodes_to_text(graph, "popular")
200 | elif encoder_name == "random":
201 | return nodes_to_text(graph, "random")
202 | elif encoder_name == "nx_node_name":
203 | return nodes_to_text(graph, "nx_node_name")
204 | else:
205 | raise ValueError("Unknown graph encoder strategy: %s" % encoder_name)
206 |
207 |
208 | # A dictionary from edge encoder name to the corresponding function.
209 | EDGE_ENCODER_FN = {
210 | "adjacency": adjacency_encoder,
211 | "incident": incident_encoder,
212 | "friendship": friendship_encoder,
213 | "south_park": friendship_encoder,
214 | "got": friendship_encoder,
215 | "politician": social_network_encoder,
216 | "social_network": social_network_encoder,
217 | "expert": expert_encoder,
218 | "coauthorship": coauthorship_encoder,
219 | "random": adjacency_encoder,
220 | "nx_edge_encoder": nx_encoder,
221 | }
222 |
223 |
224 | def with_ids(graph: nx.Graph, node_encoder: str) -> nx.Graph:
225 | nx.set_node_attributes(graph, nodes_to_text(graph, node_encoder), name="id")
226 | return graph
227 |
228 |
229 | def encode_graph(
230 | graph: nx.Graph, graph_encoder=None, node_encoder=None, edge_encoder=None
231 | ) -> str:
232 | r"""Encodes a graph as text.
233 |
234 | This relies on choosing:
235 |     a node_encoder and an edge_encoder,
236 | or
237 | a graph_encoder (a predefined pair of node and edge encoding strategies).
238 |
239 | Note that graph_encoders may assume that the graph has some properties
240 | (e.g. integer keys).
241 |
242 |   Example usage:
243 |
244 |   .. code-block:: python
245 |
246 |     # Use a predefined graph encoder from the paper.
247 |     >>> G = nx.karate_club_graph()
248 |     >>> encode_graph(G, graph_encoder="adjacency")
249 |     'In an undirected graph, (i,j) means that node i and node j are
250 |     connected with an undirected edge. G describes a graph among nodes 0, 1,
251 |     2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
252 |     22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, and 33.\nThe edges in G
253 |     are: (0, 1) (0, 2) (0, 3) ...'
254 |
255 |     # Use the node's name in the graph as the node identifier.
256 |     >>> G = nx.les_miserables_graph()
257 |     >>> encode_graph(G, node_encoder="nx_node_name", edge_encoder="friendship")
258 |     'G describes a friendship graph among nodes Anzelma, Babet, Bahorel,
259 |     Bamatabois, BaronessT, Blacheville, Bossuet, Boulatruelle, Brevet, ...
260 |     We have the following edges in G:
261 |     Napoleon and Myriel are friends. Myriel and MlleBaptistine are friends...'
262 |
263 |     # Use the `id` feature from the edges to describe the edge type.
264 |     >>> G = nx.karate_club_graph()
265 |     >>> encode_graph(G, node_encoder="nx_node_name", edge_encoder="nx_edge_encoder")
266 |     'In an undirected graph, (s,p,o) means that node s and node o are
267 |     connected with an undirected edge of type p. G describes a graph among
268 |     nodes 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
269 |     19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, and 33.
270 |     The edges in G are: (0, linked, 1) (0, linked, 2) (0, linked, 3) ...'
273 |
274 | Args:
275 | graph: the graph to be encoded.
276 | graph_encoder: the name of the graph encoder to use.
277 | node_encoder: the name of the node encoder to use.
278 | edge_encoder: the name of the edge encoder to use.
279 |
280 | Returns:
281 | The encoded graph as a string.
282 | """
283 |
284 | # Check that only one of graph_encoder or (node_encoder, edge_encoder) is set.
285 | if graph_encoder and (node_encoder or edge_encoder):
286 | raise ValueError(
287 | "Only one of graph_encoder or (node_encoder, edge_encoder) can be set."
288 | )
289 |
290 | if graph_encoder:
291 | if isinstance(graph_encoder, str):
292 | node_encoder_dict = get_tlag_node_encoder(graph, graph_encoder)
293 | return EDGE_ENCODER_FN[graph_encoder](graph, node_encoder_dict)
294 | else:
295 | return graph_encoder(graph)
296 |
297 | else:
298 | node_encoder_dict = nodes_to_text(graph, node_encoder)
299 | return EDGE_ENCODER_FN[edge_encoder](graph, node_encoder_dict)
300 |
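As a rough, dependency-free sketch of the string the "adjacency" encoder produces for the 4-node cycle used in the tests (the real implementation is `adjacency_encoder`, dispatched through `EDGE_ENCODER_FN`; the literal construction below is illustrative, not the library code):

```python
nodes = [0, 1, 2, 3]
edges = [(0, 1), (0, 3), (1, 2), (2, 3)]

# Node encoder: map every node to its integer name ("integer" encoding).
node_text = {n: str(n) for n in nodes}

# Adjacency-style edge encoding, mirroring the docstring examples above.
desc = (
    "In an undirected graph, (i,j) means that node i and node j are"
    " connected with an undirected edge. "
)
desc += (
    "G describes a graph among nodes "
    + ", ".join(node_text[n] for n in nodes[:-1])
    + ", and " + node_text[nodes[-1]] + ".\n"
)
desc += "The edges in G are: " + " ".join(f"({s}, {t})" for s, t in edges) + ".\n"

print(desc)
```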
--------------------------------------------------------------------------------
/talk_like_a_graph/graph_text_encoders_test.py:
--------------------------------------------------------------------------------
1 | """Testing for graph_text_encoders.py."""
2 |
3 | from absl.testing import parameterized
4 | import networkx as nx
5 |
6 | from . import graph_text_encoders
7 | from absl.testing import absltest
8 |
9 | _G = nx.Graph()
10 | _G.add_node(0)
11 | _G.add_node(1)
12 | _G.add_node(2)
13 | _G.add_node(3)
14 | _G.add_edge(0, 1)
15 | _G.add_edge(1, 2)
16 | _G.add_edge(2, 3)
17 | _G.add_edge(3, 0)
18 |
19 |
20 | class GraphTextEncodersTest(parameterized.TestCase):
21 |
22 | @parameterized.named_parameters(
23 | dict(
24 | testcase_name='adjacency_integer',
25 | encoding_method='adjacency',
26 | expected_result=(
27 | 'In an undirected graph, (i,j) means that node i and node j are'
28 | ' connected with an undirected edge. G describes a graph among'
29 | ' nodes 0, 1, 2, and 3.\nThe edges in G are: (0, 1) (0, 3) (1, 2)'
30 | ' (2, 3).\n'
31 | ),
32 | ),
33 | dict(
34 | testcase_name='incident_integer',
35 | encoding_method='incident',
36 | expected_result=(
37 | 'G describes a graph among nodes 0, 1, 2, and 3.\nIn this'
38 | ' graph:\nNode 0 is connected to nodes 1, 3.\nNode 1 is connected'
39 | ' to nodes 0, 2.\nNode 2 is connected to nodes 1, 3.\nNode 3 is'
40 | ' connected to nodes 2, 0.\n'
41 | ),
42 | ),
43 | dict(
44 | testcase_name='friendship_per_line_popular',
45 | encoding_method='friendship',
46 | expected_result=(
47 | 'G describes a friendship graph among nodes James, Robert, John,'
48 | ' and Michael.\nWe have the following edges in G:\nJames and'
49 | ' Robert are friends.\nJames and Michael are friends.\nRobert and'
50 | ' John are friends.\nJohn and Michael are friends.\n'
51 | ),
52 | ),
53 | dict(
54 | testcase_name='social_network_politician',
55 | encoding_method='politician',
56 | expected_result=(
57 | 'G describes a social network graph among nodes Barack, Jimmy,'
58 | ' Arnold, and Bernie.\nWe have the following edges in'
59 | ' G:\nBarack and Jimmy are connected.\nBarack and Bernie are'
60 | ' connected.\nJimmy and Arnold are connected.\nArnold and'
61 | ' Bernie are connected.\n'
62 | ),
63 | ),
64 | )
65 | def test_encoders(self, encoding_method, expected_result):
66 | self.assertEqual(
67 | graph_text_encoders.encode_graph(_G, encoding_method),
68 | expected_result,
69 | )
70 |
71 |
72 | if __name__ == '__main__':
73 |   absltest.main()
74 |
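The table-driven pattern above (one expected string per encoder name) can be sketched with only the standard library, using `unittest.subTest` in place of absl's `parameterized`. The encoder table here is a hypothetical stand-in, not the real `EDGE_ENCODER_FN`:

```python
import unittest

# Hypothetical stand-in encoders; the real test sweeps graph_text_encoders.
ENCODERS = {"upper": str.upper, "lower": str.lower}


class EncoderTableTest(unittest.TestCase):

    def test_all_encoders(self):
        cases = {"upper": "AB", "lower": "ab"}
        for name, expected in cases.items():
            # subTest reports each failing case separately, much like
            # parameterized.named_parameters does.
            with self.subTest(name=name):
                self.assertEqual(ENCODERS[name]("Ab"), expected)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(EncoderTableTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```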
--------------------------------------------------------------------------------
/talk_like_a_graph/name_dictionaries.py:
--------------------------------------------------------------------------------
1 | """Creates a dictionary mapping integers to node names."""
2 |
3 | import random
4 |
5 | _RANDOM_SEED = 1234
6 | random.seed(_RANDOM_SEED)
7 |
8 | _INTEGER_NAMES = [str(x) for x in range(10000)]
9 |
10 | _POPULAR_NAMES = [
11 | "James",
12 | "Robert",
13 | "John",
14 | "Michael",
15 | "David",
16 | "Mary",
17 | "Patricia",
18 | "Jennifer",
19 | "Linda",
20 | "Elizabeth",
21 | "William",
22 | "Richard",
23 | "Joseph",
24 | "Thomas",
25 | "Christopher",
26 | "Barbara",
27 | "Susan",
28 | "Jessica",
29 | "Sarah",
30 | "Karen",
31 | "Daniel",
32 | "Lisa",
33 | "Matthew",
34 | "Nancy",
35 | "Anthony",
36 | "Betty",
37 | "Mark",
38 | "Margaret",
39 | "Donald",
40 | "Sandra",
41 | "Steven",
42 | "Ashley",
43 | "Paul",
44 | "Kimberly",
45 | "Andrew",
46 | "Emily",
47 | "Joshua",
48 | "Donna",
49 | "Kenneth",
50 | "Michelle",
51 | "Kevin",
52 | "Carol",
53 | "Brian",
54 | "Amanda",
55 | "George",
56 | "Melissa",
57 | "Edward",
58 | "Deborah",
59 | "Ronald",
60 | "Stephanie",
61 | "Timothy",
62 | "Rebecca",
63 | "Jason",
64 | "Sharon",
65 | "Jeffrey",
66 | "Laura",
67 | "Ryan",
68 | "Cynthia",
69 | "Jacob",
70 | "Dorothy",
71 | "Gary",
72 | "Olivia",
73 | "Nicholas",
74 | "Emma",
75 | "Eric",
76 | "Sophia",
77 | "Jonathan",
78 | "Ava",
79 | "Stephen",
80 | "Isabella",
81 | "Scott",
82 | "Mia",
83 | "Justin",
84 | "Abigail",
85 | "Brandon",
86 | "Madison",
87 | "Frank",
88 | "Chloe",
89 | "Benjamin",
90 | "Victoria",
91 | "Samuel",
92 | "Lauren",
93 | "Gregory",
94 | "Hannah",
95 | "Alexander",
96 | "Grace",
97 |     "Larry",
98 | "Alexis",
99 | "Raymond",
100 | "Alice",
101 | "Patrick",
102 | "Samantha",
103 | "Jack",
104 | "Natalie",
105 | "Dennis",
106 | "Anna",
107 | "Jerry",
108 | "Taylor",
109 | "Tyler",
110 | "Kayla",
111 | "Henry",
112 | "Hailey",
113 | "Douglas",
114 | "Jasmine",
115 | "Peter",
116 | "Nicole",
117 | "Adam",
118 | "Amy",
119 | "Nathan",
120 | "Christina",
121 | "Zachary",
122 | "Andrea",
123 | "Jose",
124 | "Leah",
125 | "Walter",
126 | "Angelina",
127 | "Harold",
128 | "Valerie",
129 | "Kyle",
130 | "Veronica",
131 | "Ethan",
132 | "Carl",
133 | "Arthur",
134 | "Roger",
135 | "Noah",
136 | ]
137 |
138 |
139 | _SOUTH_PARK_NAMES = [
140 | "Eric",
141 | "Kenny",
142 | "Kyle",
143 | "Stan",
144 | "Tolkien",
145 | "Heidi",
146 | "Bebe",
147 | "Liane",
148 | "Sharon",
149 | "Linda",
150 | "Gerald",
151 | "Veronica",
152 | "Michael",
153 | "Jimbo",
154 | "Herbert",
155 | "Malcolm",
156 | "Gary",
157 | "Steve",
158 | "Chris",
159 | "Wendy",
160 | ]
161 |
162 | _GOT_NAMES = [
163 | "Ned",
164 | "Cat",
165 | "Daenerys",
166 | "Jon",
167 | "Bran",
168 | "Sansa",
169 | "Arya",
170 | "Cersei",
171 | "Jaime",
172 | "Petyr",
173 | "Robert",
174 | "Jorah",
175 | "Viserys",
176 | "Joffrey",
177 | "Maester",
178 | "Theon",
179 | "Rodrik",
180 | "Lysa",
181 | "Stannis",
182 | "Osha",
183 | ]
184 |
185 |
186 | _POLITICIAN_NAMES = [
187 | "Barack",
188 | "Jimmy",
189 | "Arnold",
190 | "Bernie",
191 | "Bill",
192 | "Kamala",
193 | "Hillary",
194 | "Elizabeth",
195 | "John",
196 | "Ben",
197 | "Joe",
198 | "Alexandria",
199 | "George",
200 | "Nancy",
201 | "Pete",
202 | "Madeleine",
203 | "Elijah",
204 | "Gabrielle",
205 | "Al",
206 | ]
207 |
208 |
209 | _ALPHABET_NAMES = [
210 | "A",
211 | "B",
212 | "C",
213 | "D",
214 | "E",
215 | "F",
216 | "G",
217 | "H",
218 | "I",
219 | "J",
220 | "K",
221 | "L",
222 | "M",
223 | "N",
224 | "O",
225 | "P",
226 | "Q",
227 | "R",
228 | "S",
229 | "T",
230 | "U",
231 | "V",
232 | "W",
233 | "X",
234 | "Y",
235 | "Z",
236 | "AA",
237 | "BB",
238 | "CC",
239 | "DD",
240 | "EE",
241 | "FF",
242 | "GG",
243 | "HH",
244 | "II",
245 | "JJ",
246 | "KK",
247 | "LL",
248 | "MM",
249 | "NN",
250 | "OO",
251 | "PP",
252 | "QQ",
253 | "RR",
254 | "SS",
255 | "TT",
256 | "UU",
257 | "VV",
258 | "WW",
259 | "XX",
260 | "YY",
261 | "ZZ",
262 | ]
263 |
264 |
265 | def create_name_dict(graph, name: str, nnodes: int = 20) -> dict[int, str]:
266 | """The runner function to map integers to node names.
267 |
268 | Args:
269 | graph: the graph to be encoded.
270 | name: name of the approach for mapping.
271 | nnodes: optionally provide nnodes in the graph to be encoded.
272 |
273 | Returns:
274 | A dictionary from integers to strings.
275 | """
276 | if name == "alphabet":
277 | names_list = _ALPHABET_NAMES
278 | elif name == "integer":
279 | names_list = _INTEGER_NAMES
280 | elif name == "random_integer":
281 | names_list = []
282 | for _ in range(nnodes):
283 | names_list.append(str(random.randint(0, 1000000)))
284 | elif name == "popular":
285 | names_list = _POPULAR_NAMES
286 | elif name == "south_park":
287 | names_list = _SOUTH_PARK_NAMES
288 | elif name == "got":
289 | names_list = _GOT_NAMES
290 | elif name == "politician":
291 | names_list = _POLITICIAN_NAMES
292 | elif name == "nx_node_name":
293 | return {x: str(x) for x in graph.nodes()}
294 | else:
295 | raise ValueError(f"Unknown approach: {name}")
296 | name_dict = {}
297 | for ind, value in enumerate(names_list):
298 | name_dict[ind] = value
299 |
300 | return name_dict
301 |
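The core of `create_name_dict` is simply an index-to-name mapping over the chosen list. A minimal sketch of that final loop (the four names below are the first entries of `_POPULAR_NAMES`):

```python
popular = ["James", "Robert", "John", "Michael"]

# Mirrors the final loop in create_name_dict: integer index -> name.
name_dict = {ind: value for ind, value in enumerate(popular)}

print(name_dict[0], name_dict[3])  # prints "James Michael"
```

Note that for the fixed lists, a graph with more nodes than names would leave some nodes unmapped, so these encoders are intended for small graphs.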
--------------------------------------------------------------------------------
/tutorial/KDD-Tutorial-2-Let-Your-Graph-Do-The-Talking.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "V28"
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | },
13 | "language_info": {
14 | "name": "python"
15 | },
16 | "accelerator": "TPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "source": [
22 |         "This is a notebook that illustrates how to use Graph Neural Networks to encode structured data for use in Large Language Models. It is Part 2 of a two-part tutorial from KDD'24.\n",
23 | "\n",
24 | "**This notebook requires a TPUv2 runtime**\n",
25 | "\n",
26 | "If you find this tutorial useful or want to know more, please consider our publication:\n",
27 | "Let your graph do the talking: Encoding structured data for LLMs\n",
28 | "```\n",
29 | "@article{perozzi2024let,\n",
30 | " title={Let your graph do the talking: Encoding structured data for llms},\n",
31 | " author={Perozzi, Bryan and Fatemi, Bahare and Zelle, Dustin and Tsitsulin, Anton and Kazemi, Mehran and Al-Rfou, Rami and Halcrow, Jonathan},\n",
32 | " journal={arXiv preprint arXiv:2402.05862},\n",
33 | " year={2024}\n",
34 | "}\n",
35 | "```\n",
36 | "\n",
37 | "## Tutorial Part II: GNN Encoding of Graph Information\n",
38 | "This notebook takes the work we did in the first part of the tutorial and extends it to using a Graph Neural Network to directly encode a representation of a graph into a prompt (vs using a text encoding as we did in the previous part).\n",
39 | "\n",
40 | "## Notebook Outline:\n",
41 | "\n",
42 |         "1. Setup (Install Dependencies, download Gemma weights)\n",
43 |         "2. Dataset creation\n",
44 |         "3. Graph-to-Text conversion\n",
45 |         "4. Evaluation\n",
46 |         "5. Exercise: Graph Encoding Challenge\n",
47 |         "6. Exercise: DBLP Dataset\n",
49 | "\n",
50 | "## Prework!\n",
51 | "\n",
52 |         "Sign up for Kaggle and consent to the Gemma TOS (this is a requirement to download the Gemma weights used in this notebook).\n",
53 | "https://www.kaggle.com/models/google/gemma/license/consent?returnUrl=%2Fmodels%2Fgoogle%2Fgemma%2FFlax%2F2b-it%2F2"
54 | ],
55 | "metadata": {
56 | "id": "EdaaBjBQ3c5g"
57 | }
58 | },
59 | {
60 | "cell_type": "code",
61 | "source": [
62 | "%%capture\n",
63 | "# @title Install Dependencies\n",
64 | "!pip install git+https://github.com/google-deepmind/gemma.git\n",
65 | "!pip install --user kaggle\n",
66 | "!pip install sparse_deferred\n",
67 | "!git clone https://github.com/google-research/talk-like-a-graph.git\n",
68 | "import sys\n",
69 | "sys.path.insert(0, \"/content/talk-like-a-graph\")\n"
70 | ],
71 | "metadata": {
72 | "id": "BtYsxQuUQ-K0"
73 | },
74 | "execution_count": null,
75 | "outputs": []
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "source": [
80 | "## Login to Kaggle\n",
81 |         "Follow the link in the login dialog to get an API key if you don't already have one. Make sure to approve the [Gemma TOS](https://www.kaggle.com/models/google/gemma/license/consent?returnUrl=%2Fmodels%2Fgoogle%2Fgemma%2FFlax%2F2b-it%2F2) as well."
82 | ],
83 | "metadata": {
84 | "id": "42csum1-l0bn"
85 | }
86 | },
87 | {
88 | "cell_type": "code",
89 | "source": [
90 | "import kagglehub\n",
91 | "\n",
92 | "kagglehub.login()"
93 | ],
94 | "metadata": {
95 | "id": "xmM6CSh7RbIk"
96 | },
97 | "execution_count": null,
98 | "outputs": []
99 | },
100 | {
101 | "cell_type": "markdown",
102 | "source": [
103 | "## Download Gemma"
104 | ],
105 | "metadata": {
106 | "id": "500UNvQOmcMn"
107 | }
108 | },
109 | {
110 | "cell_type": "code",
111 | "source": [
112 | "import os\n",
113 | "VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n",
114 | "weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n",
115 | "ckpt_path = os.path.join(weights_dir, VARIANT)\n",
116 | "vocab_path = os.path.join(weights_dir, 'tokenizer.model')"
117 | ],
118 | "metadata": {
119 | "id": "W0BzTj2tPROY"
120 | },
121 | "execution_count": null,
122 | "outputs": []
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "source": [
127 | "## Import dependencies"
128 | ],
129 | "metadata": {
130 | "id": "WGFZHBaCRa77"
131 | }
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "metadata": {
137 | "id": "ycYOnVRpHZuJ",
138 | "cellView": "form"
139 | },
140 | "outputs": [],
141 | "source": [
142 | "# @title\n",
143 | "import os\n",
144 | "from collections.abc import Sequence\n",
145 | "import dataclasses\n",
146 | "from typing import Any, Callable, Mapping\n",
147 | "import sys\n",
148 | "\n",
149 | "\n",
150 | "import chex\n",
151 | "from flax import linen as nn\n",
152 | "import jax\n",
153 | "import jax.numpy as jnp\n",
154 | "import networkx as nx\n",
155 | "import numpy as np\n",
156 | "\n",
157 | "\n",
158 | "from gemma import params as params_lib\n",
159 | "from gemma import transformer as transformer_lib\n",
160 | "from gemma import sampler as sampler_lib\n",
161 | "import sentencepiece as spm\n",
162 | "import sparse_deferred as sd\n",
163 | "from sparse_deferred import jax as sdjnp\n",
164 | "from sparse_deferred.structs import graph_struct\n",
165 | "from sparse_deferred import np as sdnp\n"
166 | ]
167 | },
168 | {
169 | "cell_type": "markdown",
170 | "source": [
171 | "## GraphToken library code"
172 | ],
173 | "metadata": {
174 | "id": "DzQUN-bCmrjg"
175 | }
176 | },
177 | {
178 | "cell_type": "code",
179 | "source": [
180 | "# @title\n",
181 | "import collections\n",
182 | "from collections.abc import Iterable\n",
183 | "import io\n",
184 | "import json\n",
185 | "from typing import Any, Callable, NamedTuple, Sequence\n",
186 | "\n",
187 | "import numpy as np\n",
188 | "import tqdm\n",
189 | "\n",
190 | "# Code for converting NetworkX graphs to graph tensor\n",
191 | "def laplacian_pos_embedding(graph: nx.Graph, units: int = 4) -\u003e nx.Graph:\n",
192 | " \"\"\"Adds the laplacian positional encoding.\"\"\"\n",
193 | " m = nx.normalized_laplacian_matrix(\n",
194 | " graph, nodelist=sorted(graph.nodes), weight=None\n",
195 | " ).astype(np.float32)\n",
196 | " u, _, _ = np.linalg.svd(m.todense(), compute_uv=True)\n",
197 | " if units \u003e u.shape[1]:\n",
198 | " u = np.pad(u, ((0, 0), (0, units - u.shape[1])))\n",
199 | " nx.set_node_attributes(\n",
200 | " graph, dict(zip(sorted(graph.nodes), u[:, :units])), name='lpe'\n",
201 | " )\n",
202 | " return graph\n",
203 | "\n",
204 | "\n",
205 |         "def to_graph_struct(graph: nx.Graph, node_ids: list[int] | None = None) -\u003e graph_struct.GraphStruct:\n",
206 | " if graph.edges(data=True):\n",
207 | " s, t, w = zip(*[\n",
208 | " (s, t, (d['weight'] if d and 'weight' in d else None))\n",
209 | " for s, t, d in graph.edges(data=True)\n",
210 | " ])\n",
211 | " else:\n",
212 | " s, t, w = (), (), ()\n",
213 | " # tfgnn assumes graphs are directed. Adding the rev edges for an undirected\n",
214 | " # graph.\n",
215 | " if not graph.is_directed():\n",
216 | " s, t, w = s + t, t + s, w + w\n",
217 | "\n",
218 | " graph = laplacian_pos_embedding(graph, units=4)\n",
219 | " return graph_struct.GraphStruct.new(\n",
220 | " nodes={'nodes': {'lpe': np.stack([graph.nodes('lpe')[i] for i in range(graph.number_of_nodes())])}},\n",
221 | " edges={'edges': ((np.array(s, dtype=np.int32), np.array(t, dtype=np.int32)), {})}\n",
222 | " )\n",
223 | "\n",
224 | "\n",
225 | "Tensor = sd.matrix.Tensor\n",
226 | "Features = dict[str, Tensor]\n",
227 | "FeatureSets = dict[str, Features]\n",
228 | "Edge = tuple[tuple[Tensor, ...], Features] # (endpoints, edge features)\n",
229 | "Edges = dict[str, Edge]\n",
230 | "Nodes = FeatureSets\n",
231 | "Schema = dict[str, tuple[str, ...]]\n",
232 | "_Schema = dict[str, tuple[dict[str, int], ...]]\n",
233 | "\n",
234 | "\n",
235 | "\n",
236 | "class FixedSizePadder:\n",
237 | " \"\"\"Adds padding to `GraphStruct` instances for fixed-sized tensors.\n",
238 | "\n",
239 | " Fixed-size tensors can be preferred when running on TPU accelerators.\n",
240 | "\n",
241 |         "  To use this class, you must first initialize it with statistics of your graphs,\n",
242 |         "  then use it to pad graphs. The statistics can be initialized by invoking\n",
243 |         "  `calculate_pad_statistics`: this function records the *maximum* observed size\n",
244 |         "  of every node and edge set, as well as the standard deviation (std) of sizes.\n",
245 | "\n",
246 | " Once initialized, the function: `pad_graph()` will add padding to the graph.\n",
247 | " Specifically, the node feature (tensors) will be padded with zeros. Similarly,\n",
248 | " edges will be inserted, among newly-added virtual nodes.\n",
249 | "\n",
250 | " Each node (or edge) size will become:\n",
251 | "\n",
252 | " `max observed [per calculate_pad_statistics] + slack*std + 1`\n",
253 | "\n",
254 | " NOTE: there will always be at least one more node or edge, even if the\n",
255 | " statistics show zero std. This is required for making virtual nodes.\n",
256 | "\n",
257 |         "  All node-set (feature) and edge-set (feature and adjacency list) sizes become fixed.\n",
258 | " \"\"\"\n",
259 | "\n",
260 | " def __init__(self, engine: sd.ComputeEngine, slack: float = 1.0):\n",
261 | " # `('edge'|'node', NodeOrEdgeName) -\u003e target size`\n",
262 | " # where `target size` is maximum observed size for node (or edge) set, plus\n",
263 | " # one, plus slack-times-std of observed sizes.\n",
264 | " self.sizes: dict[tuple[str, str], int] = {}\n",
265 | " self.slack = slack\n",
266 | " self._engine = engine\n",
267 | "\n",
268 | " def calculate_pad_statistics(\n",
269 | " self, examples: Iterable[graph_struct.GraphStruct], num_steps: int = 100):\n",
270 | " \"\"\"Measures the max and std of node \u0026 edge sizes of elements of `examples`.\n",
271 | "\n",
272 | " Calling this function is necessary before invoking `pad_graph`.\n",
273 | "\n",
274 | " Args:\n",
275 | " examples: iterable that yields `GraphStruct` examples.\n",
276 | " num_steps: If positive, considers this many samples of `examples`.\n",
277 | " Otherwise, iterates over all `examples`. Warning: this may run\n",
278 | " infinitely on infinite iterators (e.g., `dataset.repeat()`).\n",
279 | " \"\"\"\n",
280 | " sizes: dict[tuple[str, str], list[int]] = collections.defaultdict(list)\n",
281 | " for i, graph in enumerate(examples):\n",
282 | " assert isinstance(graph, graph_struct.GraphStruct)\n",
283 | " if i \u003e 0 and i \u003e= num_steps:\n",
284 | " break\n",
285 | " for node_name, features in graph.nodes.items():\n",
286 | " value_list = sizes[('nodes', node_name)]\n",
287 | " if not features:\n",
288 | " value_list.append(0)\n",
289 | " else:\n",
290 | " value_list.append(list(features.values())[0].shape[0])\n",
291 | "\n",
292 | " for edge_name, edges_tuple in graph.edges.items():\n",
293 | " value_list = sizes[('edges', edge_name)]\n",
294 | " source_nodes = edges_tuple[0][0]\n",
295 | " # if len(value_list) and edge_set.sizes.shape != value_list[-1].shape:\n",
296 | " # continue\n",
297 | " value_list.append(source_nodes.shape[0])\n",
298 | "\n",
299 | " self.sizes = {k: int(1 + max(v) + self.slack * np.std(v))\n",
300 | " for k, v in sizes.items()}\n",
301 | "\n",
302 | " def pad_graph(self, graph: graph_struct.GraphStruct) -\u003e graph_struct.GraphStruct:\n",
303 | " \"\"\"Pads node-sets and edge-sets, with zeros, to max-seen during `calc..`.\n",
304 | "\n",
305 | " This function is useful for running on TPU hardware.\n",
306 | "\n",
307 | " Args:\n",
308 | " graph: contains any number of nodes and edges.\n",
309 | "\n",
310 | " Returns:\n",
311 | " graph with deterministic number of nodes and edges. See class docstring.\n",
312 | " \"\"\"\n",
313 | " if not self.sizes:\n",
314 | " raise ValueError(\n",
315 | " 'No statistics have been initialized. '\n",
316 | " 'Perhaps you forgot to invoke \"calculate_pad_statistics\"?')\n",
317 | " # Edge set name -\u003e (1D vectors containing endpoints**), {\"feature\": Tensor})\n",
318 | " edges: Edges = {}\n",
319 | " # ** tuple should have 2 entries for directed graphs\n",
320 | "\n",
321 | " nodes: Nodes = {}\n",
322 | "\n",
323 | " # For every key in `edges`, store names of node sets that `key` edge\n",
324 | " # connects.\n",
325 | " schema = graph.schema\n",
326 | "\n",
327 | " e = self._engine # for short.\n",
328 | " for node_name, node_features in graph.nodes.items():\n",
329 | " padded_features = {}\n",
330 | " desired_size = self.sizes[('nodes', node_name)]\n",
331 | "\n",
332 | " for feature_name, feature in node_features.items():\n",
333 | " feature = feature[:desired_size] # if `is_oversized`.\n",
334 | " pad = self._engine.maximum(\n",
335 | " desired_size - self._engine.shape(feature)[0], 0)\n",
336 | " zeros = e.zeros(\n",
337 | " tuple([pad] + list(feature.shape[1:])), dtype=feature.dtype)\n",
338 | " padded_feature = e.concat([feature, zeros], axis=0)\n",
339 | " padded_feature = e.reshape(\n",
340 | " padded_feature, [desired_size] + list(padded_feature.shape[1:]))\n",
341 | " padded_features[feature_name] = padded_feature\n",
342 | "\n",
343 | " nodes[node_name] = padded_features\n",
344 | "\n",
345 | " for edge_name, (edge_endpoints, features) in graph.edges.items():\n",
346 | " padded_features = {}\n",
347 | " padded_endpoints = []\n",
348 | " desired_size = self.sizes[('edges', edge_name)]\n",
349 | " current_size = e.shape(edge_endpoints[0])[0]\n",
350 | "\n",
351 | " pad = e.maximum(desired_size - current_size, 0)\n",
352 | " e.assert_greater(pad, -1)\n",
353 | "\n",
354 | " for feature_name, feature in features.items():\n",
355 | " feature = feature[:desired_size] # if `is_oversized`.\n",
356 | " zeros = e.zeros(\n",
357 | " tuple([pad] + list(feature.shape[1:])), dtype=feature.dtype\n",
358 | " )\n",
359 | " padded_feature = e.concat([feature, zeros], axis=0)\n",
360 | " padded_feature = e.reshape(\n",
361 | " padded_feature, [desired_size] + list(padded_feature.shape[1:])\n",
362 | " )\n",
363 | " padded_features[feature_name] = padded_feature\n",
364 | "\n",
365 | " edge_endpoints = [node_ids[:desired_size] for node_ids in edge_endpoints]\n",
366 | " # [[src1_is_valid, src2_is_valid, ...], [tgt1_is_valid, ...]]\n",
367 | " valid = e.cast(\n",
368 | " [\n",
369 | " ids \u003c self.sizes[('nodes', node_name)]\n",
370 | " for ids, node_name in zip(edge_endpoints, schema[edge_name])\n",
371 | " ],\n",
372 | " dtype=bool,\n",
373 | " )\n",
374 | " valid = e.reduce_all(valid, axis=0)\n",
375 | "\n",
376 | " for node_ids, node_name in zip(edge_endpoints, schema[edge_name]):\n",
377 | " # Universe size (e.g., of source or target).\n",
378 | " max_endpoint = self.sizes[('nodes', node_name)] - 1\n",
379 | " node_ids = node_ids[:desired_size]\n",
380 | " node_ids = e.boolean_mask(node_ids, valid)\n",
381 | " pad = desired_size - e.shape(node_ids)[0] # Need only to compute once.\n",
382 | "\n",
383 | " padded_ids = e.concat([\n",
384 | " node_ids,\n",
385 | " e.ones((pad), dtype=node_ids.dtype) * max_endpoint\n",
386 | " ], axis=0)\n",
387 | " padded_ids = e.reshape(padded_ids, [desired_size])\n",
388 | " padded_endpoints.append(padded_ids)\n",
389 | "\n",
390 | " edges[edge_name] = (tuple(padded_endpoints), padded_features)\n",
391 | "\n",
392 | " graph = graph_struct.GraphStruct.new(nodes=nodes, edges=edges, schema=schema)\n",
393 | " return graph\n",
394 | "\n",
395 | "\n",
396 | "\n",
397 | "## gnn.py\n",
398 | "class GIN(nn.Module):\n",
399 | " \"\"\"Graph Isomorphism Network: https://arxiv.org/pdf/1810.00826.pdf.\"\"\"\n",
400 | "\n",
401 | " output_dim: int\n",
402 | " num_hidden_layers: int = 1\n",
403 | " hidden_dim: int = 32\n",
404 | " epsilon: float = 0.1 # See GIN paper (link above)\n",
405 | "\n",
406 | " def setup(self):\n",
407 | " layer_dims = [self.hidden_dim] * self.num_hidden_layers\n",
408 | " self.layers = [\n",
409 | " nn.Dense(dim, use_bias=False, dtype=jnp.bfloat16) for dim in layer_dims\n",
410 | " ]\n",
411 | " self.out_layer = nn.Dense(\n",
412 | " self.output_dim, use_bias=False, dtype=jnp.bfloat16\n",
413 | " )\n",
414 | "\n",
415 | " def __call__(self, graph: graph_struct.GraphStruct) -\u003e jax.Array:\n",
416 | " x = graph.nodes['nodes']['lpe']\n",
417 | " adj = graph.adj(sdjnp.engine, 'edges')\n",
418 | " adj = adj.add_eye(1 + self.epsilon) # self connections with 1+eps weight.\n",
419 | "\n",
420 | " for i, layer in enumerate(self.layers):\n",
421 | " x = layer(adj @ x)\n",
422 | " if i \u003c self.num_hidden_layers:\n",
423 | " x = nn.relu(x)\n",
424 | " x = jnp.concat(x, axis=-1)\n",
425 | " return self.out_layer(x)\n",
426 | "\n",
427 | "\n",
428 | "class GCN(nn.Module):\n",
429 | " \"\"\"Graph convolutional network: https://arxiv.org/pdf/1609.02907.pdf.\"\"\"\n",
430 | "\n",
431 | " output_dim: int\n",
432 | " num_hidden_layers: int = 1\n",
433 | " hidden_dim: int = 32\n",
434 | "\n",
435 | " def setup(self):\n",
436 | " layer_dims = [self.hidden_dim] * self.num_hidden_layers\n",
437 | " self.layers = [nn.Dense(dim, use_bias=False) for dim in layer_dims]\n",
438 | " self.out_layer = nn.Dense(\n",
439 | " self.output_dim, use_bias=False, dtype=jnp.bfloat16\n",
440 | " )\n",
441 | "\n",
442 | " def __call__(self, graph: graph_struct.GraphStruct) -\u003e jax.Array:\n",
443 | " x = graph.nodes['nodes']['lpe']\n",
444 | " adj = graph.adj(sdjnp.engine, 'edges')\n",
445 | " adj_symnorm = (adj + adj.transpose()).add_eye().normalize_symmetric()\n",
446 | "\n",
447 | " for i, layer in enumerate(self.layers):\n",
448 | " x = layer(adj_symnorm @ x)\n",
449 | " if i \u003c self.num_hidden_layers:\n",
450 | " x = nn.relu(x)\n",
451 | " x = jnp.concat(x, axis=-1)\n",
452 | " return self.out_layer(x)\n",
453 | "\n",
454 | "\n",
455 | "## sampler.py\n",
456 | "\n",
457 | "@dataclasses.dataclass\n",
458 | "class SamplerOutput:\n",
459 | "\n",
460 | " # Decoded samples from the model.\n",
461 | " text: list[str]\n",
462 | "\n",
463 | " # Per-step logits used during sampling.\n",
464 | " logits: list[list[float]]\n",
465 | "\n",
466 | " # Tokens corresponding to the generated samples.\n",
467 | " tokens: list[list[int]]\n",
468 | "\n",
469 | " graph_embeddings: list[jnp.ndarray]\n",
470 | "\n",
471 | "\n",
472 | "class GraphTokenSampler:\n",
473 | " \"\"\"Sampler for GraphToken.\"\"\"\n",
474 | "\n",
475 | " def __init__(\n",
476 | " self,\n",
477 | " gnn: nn.Module,\n",
478 | " llm: transformer_lib.Transformer,\n",
479 | " vocab: spm.SentencePieceProcessor,\n",
480 | " params: Mapping[str, Any],\n",
481 | " gnn_token_template: str = r'\u003cunused%d\u003e',\n",
482 | " ):\n",
483 | " \"\"\"Initializes the sampler.\n",
484 | "\n",
485 | " Args:\n",
486 | " gnn: The GNN model.\n",
487 | " llm: The LLM model.\n",
488 | " vocab: The vocab used by the LLM.\n",
489 | " params: The parameters for the GNN and LLM. This should contain the params\n",
490 | " for the gnn under params['gnn'] and the params for the llm under\n",
491 | " params['transformer']\n",
492 | " gnn_token_template: The token used to represent the GNN embedding.\n",
493 | " \"\"\"\n",
494 | "\n",
495 | " self._gnn = gnn\n",
496 | " self._llm = llm\n",
497 | " self._params = params\n",
498 | " self._vocab = vocab\n",
499 | " self._gnn_token_template = gnn_token_template\n",
500 | " self._sampler = sampler_lib.Sampler(\n",
501 | " transformer=self._llm,\n",
502 | " vocab=self._vocab,\n",
503 | " params=self._params['transformer'],\n",
504 | " )\n",
505 | "\n",
506 | " def __call__(\n",
507 | " self,\n",
508 | " input_strings: Sequence[str],\n",
509 | " input_graphs: Sequence[graph_struct.GraphStruct],\n",
510 | " total_generation_steps: int,\n",
511 | " echo: bool = False,\n",
512 | " return_logits: bool = True,\n",
513 | " forbidden_tokens: Sequence[str] | None = None,\n",
514 | " ) -\u003e SamplerOutput:\n",
515 | " \"\"\"Samples from the model.\n",
516 | "\n",
517 | " Args:\n",
518 | " input_strings: The input strings.\n",
519 | " input_graphs: The input graphs.\n",
520 | " total_generation_steps: The number of steps to generate.\n",
521 | " echo: Whether to echo the input.\n",
522 | " return_logits: Whether to return the logits.\n",
523 | " forbidden_tokens: Tokens that are forbidden, in addition to the GNN token.\n",
524 | "\n",
525 | " Returns:\n",
526 | " The sampled output.\n",
527 | " \"\"\"\n",
528 | " assert len(input_graphs) == len(input_strings), (\n",
529 | " len(input_graphs),\n",
530 | " len(input_strings),\n",
531 | " )\n",
532 | " augmented_inputs = []\n",
533 | " full_forbidden_tokens = []\n",
534 | " if forbidden_tokens is not None:\n",
535 | " full_forbidden_tokens += forbidden_tokens\n",
536 | " graph_embeddings = []\n",
537 | " augmented_transformer_params = self._params['transformer']\n",
538 | "\n",
539 | " placeholder_token = PLACEHOLDER_TOKEN\n",
540 | " full_forbidden_tokens.append(placeholder_token)\n",
541 | " placeholder_token_id = self._vocab.EncodeAsIds(placeholder_token)\n",
542 | " assert len(placeholder_token_id) == 1, placeholder_token\n",
543 | " placeholder_token_id = placeholder_token_id[0]\n",
544 | "\n",
545 | " for prompt, graph in zip(input_strings, input_graphs):\n",
546 | " embed = self._gnn.apply(self._params['gnn'], graph)\n",
547 | " assert (\n",
548 | " self._params['transformer']['embedder']['input_embedding'][\n",
549 | " placeholder_token_id\n",
550 | " ].shape\n",
551 | " == embed.shape\n",
552 | " )\n",
553 | " augmented_transformer_params['embedder']['input_embedding'] = (\n",
554 | " augmented_transformer_params['embedder']['input_embedding']\n",
555 | " .at[placeholder_token_id]\n",
556 | " .set(embed)\n",
557 | " )\n",
558 | " graph_embeddings.append(embed)\n",
559 | " augmented_inputs.append(placeholder_token + prompt)\n",
560 | "\n",
561 | " self._sampler.params = augmented_transformer_params\n",
562 | " o = self._sampler(\n",
563 | " input_strings=augmented_inputs,\n",
564 | " total_generation_steps=total_generation_steps,\n",
565 | " echo=echo,\n",
566 | " return_logits=return_logits,\n",
567 | " forbidden_tokens=full_forbidden_tokens,\n",
568 | " )\n",
569 | " return SamplerOutput(\n",
570 | " **dataclasses.asdict(o),\n",
571 | " graph_embeddings=graph_embeddings,\n",
572 | " )\n",
573 | "\n",
574 | "\n",
575 | "\n",
576 | "@chex.dataclass(frozen=True)\n",
577 | "class TrainingInput:\n",
578 | " \"\"\"Batch of training data for a GraphToken model.\"\"\"\n",
579 | "\n",
580 | " # Input tokens given to the model\n",
581 | " input_tokens: np.ndarray # size [B, L]\n",
582 | "\n",
583 | " # A mask that determines which tokens contribute to the target loss\n",
584 | " # calculation.\n",
585 | " target_mask: np.ndarray # size [B, L]\n",
586 | "\n",
587 | " input_graphs: list[graph_struct.GraphStruct] # size [B]\n",
588 | "\n",
589 | " # Ground truth for the input tokens, if representable as an integer.\n",
590 | " # For boolean classification tasks, this is 0/1.\n",
591 | " parsed_ground_truth: np.ndarray | None # size [B]\n",
592 | "\n",
593 | "\n",
594 | "def parse_int(s: str) -\u003e int:\n",
595 | " \"\"\"Parse a string as an integer.\"\"\"\n",
596 | " return int(float(s.strip()))\n",
597 | "\n",
598 | "\n",
599 | "def parse_yes_no(s: str) -\u003e bool:\n",
600 | " \"\"\"Parse a string as a yes/no answer, looking at the first 10 chars.\"\"\"\n",
601 | " return 'yes' in s.lower()[:10]\n",
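602 | "\n",
603 | "# Illustrative examples (comments only): parse_int('14.0') == 14 and\n",
604 | "# parse_yes_no('Yes, there is a cycle.') is True.\n",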
602 | "\n",
603 | "PLACEHOLDER_TOKEN = '\u003cunused0\u003e'\n",
604 | "\n",
605 | "def graphqa_ds(\n",
606 | " vocab: spm.SentencePieceProcessor,\n",
607 | " encoded_examples: list,\n",
608 | " padder: graph_struct.FixedSizePadder | None = None,\n",
609 | " max_tokens: int = 100,\n",
610 | " gt_parser: Callable[[str], Any] | None = None,\n",
611 | ") -\u003e tuple[graph_struct.FixedSizePadder, list[TrainingInput]]:\n",
612 | " \"\"\"Load a GraphQA dataset as a list of TrainingInput.\n",
613 | "\n",
614 | " Args:\n",
615 | " vocab: The vocab to use for tokenization.\n",
616 | "    encoded_examples: List of encoded examples generated by GraphQA.\n",
617 | "    padder: The padder to use for padding the graphs. If None, a new padder\n",
618 | "      will be created and returned. This is so a padder can be shared across\n",
619 | "      multiple datasets / splits.\n",
620 | "    max_tokens: The maximum number of tokens to allow in the input. For\n",
621 | "      'task_only' prompting, this can be quite small (100 tokens is plenty).\n",
622 | " gt_parser: A function to parse the ground truth from the answer string, used\n",
623 | " to supply the 'parsed_ground_truth' field in the TrainingInput.\n",
624 | "\n",
625 | " Returns:\n",
626 | " The padder used for padding the graphs, and a list of TrainingInput.\n",
627 | " \"\"\"\n",
628 | "\n",
629 | " output = []\n",
630 | " for ex in encoded_examples:\n",
631 | " query = PLACEHOLDER_TOKEN + ex['question'][ex['question'].find('Q:'):]\n",
632 | " answer = ex['answer']\n",
633 | " graph = to_graph_struct(ex['graph'])\n",
634 | " query_tokens = vocab.EncodeAsIds(query)\n",
635 | " answer_tokens = vocab.EncodeAsIds(answer) + [vocab.eos_id()]\n",
636 | " input_tokens = np.array([vocab.bos_id()] + query_tokens + answer_tokens)\n",
637 | "    target_mask = np.zeros_like(input_tokens, dtype=np.int32)\n",
638 | " # Add one for BOS token\n",
639 | " target_mask[len(query_tokens) + 1 :] = 1\n",
640 | " orig_len = len(query_tokens) + len(answer_tokens) + 1\n",
641 | " input_tokens = np.pad(\n",
642 | " input_tokens,\n",
643 | " [[0, max_tokens - orig_len]],\n",
644 | " constant_values=vocab.pad_id(),\n",
645 | " )\n",
646 | "\n",
647 | " target_mask = np.pad(target_mask, [[0, max_tokens - orig_len]])\n",
648 | "\n",
649 | "    # The GNN library we use may require a global feature. If so, an\n",
650 | "    # otherwise-unused fake value can be set here:\n",
651 | "    # graph = graph.update(nodes={'g': {'foo': np.zeros([1])}})\n",
653 | "\n",
654 | " output.append(\n",
655 | " TrainingInput(\n",
656 | " input_tokens=np.array([input_tokens]),\n",
657 | " target_mask=np.array([target_mask]),\n",
658 | " input_graphs=[graph],\n",
659 | " parsed_ground_truth=np.array(gt_parser(answer)) if gt_parser else None,\n",
660 | " )\n",
661 | " )\n",
662 | " if padder is None:\n",
663 | " padder = FixedSizePadder(sdnp.engine)\n",
664 | " padder.calculate_pad_statistics(\n",
665 | " [e.input_graphs[0] for e in output], len(output)\n",
666 | " )\n",
667 | " for o in output:\n",
668 | " o.input_graphs[0] = padder.pad_graph(o.input_graphs[0])\n",
669 | " return padder, output\n",
670 | "\n",
671 | "\n",
672 | "def decode_questions(\n",
673 | " training_input: TrainingInput, vocab: spm.SentencePieceProcessor\n",
674 | ") -\u003e list[str]:\n",
675 | "  \"\"\"Decode the questions from the input tokens, ignoring the first two.\"\"\"\n",
676 | " b, l = training_input.input_tokens.shape\n",
677 | " question_tokens = []\n",
678 | " for i in range(b):\n",
679 | " question_tokens.append([])\n",
680 | " # Skip the first two tokens (BOS and control token).\n",
681 | " for j in range(2, l):\n",
682 | " if training_input.target_mask[i, j] == 1:\n",
683 | " break\n",
684 | " question_tokens[i].append(int(training_input.input_tokens[i, j]))\n",
685 | "  return [vocab.DecodeIds(q) for q in question_tokens]\n",
686 | "\n",
687 | "\n",
688 | "## training_loop\n",
689 | "import functools\n",
690 | "from typing import Any, MutableMapping\n",
691 | "\n",
692 | "import chex\n",
693 | "from flax import linen as nn\n",
694 | "from gemma import transformer as transformer_lib\n",
695 | "import jax\n",
696 | "import jax.numpy as jnp\n",
697 | "import optax\n",
698 | "import tqdm\n",
699 | "\n",
700 | "Params = MutableMapping[str, Any]\n",
701 | "\n",
702 | "\n",
703 | "def get_attention_mask_and_positions(\n",
704 | " example: jax.Array,\n",
705 | " pad_id: int,\n",
706 | ") -\u003e tuple[jax.Array, jax.Array]:\n",
707 | " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n",
708 | " pad_mask = example != pad_id\n",
709 | " current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n",
710 | " attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n",
711 | " return current_token_position, attention_mask\n",
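712 | "\n",
713 | "# Illustrative (an assumed example): tokens [[2, 15, 9, 0]] with pad_id=0\n",
714 | "# give positions [[0, 1, 2, 2]] and a causal attention mask of shape\n",
715 | "# [1, 4, 4] restricted to the non-pad tokens.\n",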
712 | "\n",
713 | "\n",
714 | "def forward_and_loss_fn(\n",
715 | " params: Params,\n",
716 | " *,\n",
717 | " gnn: nn.Module,\n",
718 | " llm: transformer_lib.Transformer,\n",
719 | " input_tokens: jax.Array, # Shape [B, L]\n",
720 | " input_graphs: list[graph_struct.GraphStruct], # Shape [B]\n",
721 | " input_mask: jax.Array, # Shape [B, L]\n",
722 | " positions: jax.Array, # Shape [B, L]\n",
723 | " attention_mask: jax.Array, # [B, L, L]\n",
724 | " placeholder_token_id: int,\n",
725 | ") -\u003e jax.Array:\n",
726 | " \"\"\"Forward pass and loss function.\n",
727 | "\n",
728 | " Args:\n",
729 | " params: Params for the gnn and transformer. The gnn params are stored in\n",
730 | " params['gnn'] and the llm params are stored in params['transformer'].\n",
731 | " gnn: gnn model to call.\n",
732 | " llm: gemma transformer model to call.\n",
733 | " input_tokens: input tokens sequence, shape [B, L].\n",
734 | " input_graphs: input graphs.\n",
735 | " input_mask: tokens to ignore when computing the loss, shape [B, L].\n",
736 | " positions: relative position of each token, shape [B, L].\n",
737 | "    attention_mask: input attention mask, shape [B, L, L].\n",
738 | "    placeholder_token_id: Index in the LLM vocabulary used to pass graph\n",
739 | "      embeddings.\n",
740 | "\n",
741 | " Returns:\n",
742 | " Softmax cross-entropy loss for the next-token prediction task.\n",
743 | " \"\"\"\n",
744 | " # Right now we only support batch_size = 1\n",
745 | " chex.assert_axis_dimension(input_tokens, 0, 1)\n",
746 | " chex.assert_equal_shape([input_tokens, input_mask, positions])\n",
747 | " chex.assert_axis_dimension(attention_mask, 0, 1)\n",
748 | " chex.assert_equal(len(input_graphs), 1)\n",
749 | "\n",
750 | " # Get the GNN embedding and update the transformer input embedding for a\n",
751 | " # control token.\n",
752 | " graph_embed = gnn.apply(params['gnn'], input_graphs[0])\n",
753 | " params['transformer']['embedder']['input_embedding'] = (\n",
754 | " params['transformer']['embedder']['input_embedding']\n",
755 | " .at[placeholder_token_id]\n",
756 | " .set(graph_embed)\n",
757 | " )\n",
758 | " # Forward pass on the input data.\n",
759 | " # No attention cache is needed here.\n",
760 | " logits, _ = llm.apply(\n",
761 | " {'params': params['transformer']},\n",
762 | " input_tokens,\n",
763 | " positions,\n",
764 | " None, # Attention cache is None.\n",
765 | " attention_mask,\n",
766 | " )\n",
767 | "\n",
768 | " # Exclude the last step as it does not appear in the targets.\n",
769 | " logits = logits[0, :-1]\n",
770 | "\n",
771 | " # Similarly, the first token cannot be predicted.\n",
772 | " target_tokens = input_tokens[0, 1:]\n",
773 | " target_mask = input_mask[0, 1:]\n",
774 | "\n",
775 | " # Convert the target labels into one-hot encoded vectors.\n",
776 | " one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])\n",
777 | "\n",
778 | " # Don't update on unwanted tokens.\n",
779 | " one_hot = one_hot * target_mask.astype(one_hot.dtype)[..., jnp.newaxis]\n",
780 | "\n",
781 | " # Normalisation factor.\n",
782 | " norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)\n",
783 | "\n",
784 | " # Return the nll loss.\n",
785 | " return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor\n",
786 | "\n",
787 | "\n",
788 | "@functools.partial(\n",
789 | " jax.jit,\n",
790 | " static_argnames=['gnn', 'llm', 'optimizer', 'pad_id', 'placeholder_token_id'],\n",
791 | ")\n",
792 | "def train_step(\n",
793 | " llm: transformer_lib.Transformer,\n",
794 | " gnn: nn.Module,\n",
795 | " params: MutableMapping[str, Any],\n",
796 | " optimizer: optax.GradientTransformation,\n",
797 | " opt_state: optax.OptState,\n",
798 | " pad_id: int,\n",
799 | " example: TrainingInput,\n",
800 | " placeholder_token_id: int,\n",
801 | ") -\u003e tuple[jax.Array, Params, optax.OptState]:\n",
802 | " \"\"\"Train step.\n",
803 | "\n",
804 | " Args:\n",
805 | " llm: gemma transformer model.\n",
806 | " gnn: gnn model.\n",
807 | " params: model's input parameters.\n",
808 | " optimizer: optax optimizer to use.\n",
809 | " opt_state: input optimizer's state.\n",
810 | " pad_id: id of the pad token.\n",
811 | " example: input batch.\n",
812 | "    placeholder_token_id: Index in the LLM vocabulary used to pass graph\n",
813 | "      embeddings.\n",
814 | "\n",
815 | " Returns:\n",
816 | " Training loss, updated parameters, updated optimizer state.\n",
817 | " \"\"\"\n",
818 | "\n",
819 | " # Build the position and attention mask vectors.\n",
820 | " positions, attention_mask = get_attention_mask_and_positions(\n",
821 | " jnp.array(example.input_tokens), pad_id\n",
822 | " )\n",
823 | "\n",
824 | " # Forward and backward passes\n",
825 | " train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(\n",
826 | " params,\n",
827 | " gnn=gnn,\n",
828 | " llm=llm,\n",
829 | " input_tokens=example.input_tokens,\n",
830 | " input_mask=example.target_mask,\n",
831 | " input_graphs=example.input_graphs,\n",
832 | " positions=positions,\n",
833 | " attention_mask=attention_mask,\n",
834 | " placeholder_token_id=placeholder_token_id,\n",
835 | " )\n",
836 | "\n",
837 | "  # Only the GNN parameters are updated; the LLM weights stay frozen.\n",
838 | "  updates, opt_state = optimizer.update(\n",
839 | "      grads['gnn'], opt_state, params=params['gnn']\n",
840 | "  )\n",
841 | "  params['gnn'] = optax.apply_updates(params['gnn'], updates)\n",
841 | "\n",
842 | " return train_loss, params, opt_state\n",
843 | "\n",
844 | "\n",
845 | "@functools.partial(\n",
846 | " jax.jit, static_argnames=['gnn', 'llm', 'pad_id', 'placeholder_token_id']\n",
847 | ")\n",
848 | "def validation_step(\n",
849 | " gnn: nn.Module,\n",
850 | " llm: transformer_lib.Transformer,\n",
851 | " params: MutableMapping[str, Any],\n",
852 | " pad_id: int,\n",
853 | " example: TrainingInput,\n",
854 | " placeholder_token_id: int,\n",
855 | ") -\u003e jax.Array:\n",
856 | " \"\"\"Validation step.\n",
857 | "\n",
858 | " Args:\n",
859 | " gnn: gnn model.\n",
860 | " llm: gemma transformer model.\n",
861 | " params: model's input parameters. The gnn params are stored in params['gnn']\n",
862 | " and the llm params are stored in params['transformer'].\n",
863 | " pad_id: id of the pad token.\n",
864 | "    example: input batch.\n",
865 | "    placeholder_token_id: Index in the LLM vocabulary used to pass graph\n",
866 | "      embeddings.\n",
867 | "\n",
868 | " Returns:\n",
869 | " Validation loss.\n",
870 | " \"\"\"\n",
871 | " jax_input = jax.tree.map(jnp.array, example)\n",
872 | " positions, attention_mask = get_attention_mask_and_positions(\n",
873 | " jax_input.input_tokens, pad_id\n",
874 | " )\n",
875 | " val_loss = forward_and_loss_fn(\n",
876 | " params,\n",
877 | " gnn=gnn,\n",
878 | " llm=llm,\n",
879 | " input_tokens=jax_input.input_tokens,\n",
880 | " input_mask=jax_input.target_mask,\n",
881 | " input_graphs=jax_input.input_graphs,\n",
882 | " positions=positions,\n",
883 | " attention_mask=attention_mask,\n",
884 | " placeholder_token_id=placeholder_token_id,\n",
885 | " )\n",
886 | " return val_loss\n",
887 | "\n",
888 | "\n",
889 | "@chex.dataclass(frozen=True)\n",
890 | "class TrainingConfig:\n",
891 | " learning_rate: float\n",
892 | " num_epochs: int\n",
893 | " eval_every_n: int\n",
894 | " batch_size: int\n",
895 | " max_steps: int | None = None\n",
896 | "\n",
897 | "\n",
898 | "def train_loop(\n",
899 | " llm: transformer_lib.Transformer,\n",
900 | " gnn: nn.Module,\n",
901 | " train_ds: list[TrainingInput],\n",
902 | " validation_ds: list[TrainingInput],\n",
903 | " params: Params,\n",
904 | " training_cfg: TrainingConfig,\n",
905 | " vocab: spm.SentencePieceProcessor,\n",
906 | ") -\u003e Params:\n",
907 | " \"\"\"Main training loop for GraphToken.\n",
908 | "\n",
909 | " Args:\n",
910 | " llm: Gemma transformer model.\n",
911 | " gnn: gnn model.\n",
912 | " train_ds: training dataset.\n",
913 | " validation_ds: validation dataset.\n",
914 | " params: Combined params for both the LLM and GNN. The GNN params are stored\n",
915 | " in params['gnn'] and the LLM params are stored in params['transformer'].\n",
916 | " training_cfg: training configuration.\n",
917 | " vocab: sentence piece vocabulary.\n",
918 | "\n",
919 | " Returns:\n",
920 | " Updated model's input parameters.\n",
921 | " \"\"\"\n",
922 | " optimizer = optax.lion(training_cfg.learning_rate)\n",
923 | " opt_state = optimizer.init(params['gnn'])\n",
924 | "\n",
925 | " avg_loss = 0\n",
926 | "\n",
927 | " placeholder_token_id = vocab.EncodeAsIds(PLACEHOLDER_TOKEN)\n",
928 | " assert (\n",
929 | " len(placeholder_token_id) == 1\n",
930 | " ), f'Placeholder token multiple ids: {placeholder_token_id}'\n",
931 | " placeholder_token_id = placeholder_token_id[0]\n",
935 | "\n",
936 | " with tqdm.tqdm(range(training_cfg.num_epochs * len(train_ds))) as pbar:\n",
937 | " averaged_steps = 0\n",
938 | " for n_steps in pbar:\n",
939 | " train_example = train_ds[n_steps % len(train_ds)]\n",
940 | " train_loss, params, opt_state = train_step(\n",
941 | " gnn=gnn,\n",
942 | " llm=llm,\n",
943 | " params=params,\n",
944 | " optimizer=optimizer,\n",
945 | " opt_state=opt_state,\n",
946 | " pad_id=vocab.pad_id(),\n",
947 | " example=train_example,\n",
948 | " placeholder_token_id=placeholder_token_id,\n",
949 | " )\n",
950 | " averaged_steps += 1\n",
951 | " avg_loss += train_loss\n",
952 | "      if n_steps and n_steps % training_cfg.eval_every_n == 0:\n",
953 | "        avg_loss /= averaged_steps\n",
954 | "        averaged_steps = 0\n",
955 | "        # Also compute the loss over the validation set.\n",
956 | "        eval_loss = sum(\n",
957 | "            validation_step(\n",
958 | "                gnn=gnn, llm=llm, params=params, pad_id=vocab.pad_id(),\n",
959 | "                example=v, placeholder_token_id=placeholder_token_id,\n",
960 | "            )\n",
961 | "            for v in validation_ds\n",
962 | "        ) / max(len(validation_ds), 1)\n",
963 | "        pbar.write(\n",
964 | "            f'STEP {n_steps} training loss: {avg_loss},'\n",
965 | "            f' validation loss: {eval_loss}'\n",
966 | "        )\n",
967 | "        avg_loss = 0\n",
960 | " if (\n",
961 | " training_cfg.max_steps is not None\n",
962 | " and n_steps \u003e training_cfg.max_steps\n",
963 | " ):\n",
964 | " break\n",
965 | " if averaged_steps != 0:\n",
966 | " avg_loss /= averaged_steps\n",
967 | " pbar.write(\n",
968 | " f'STEP {n_steps} training loss: {avg_loss}'\n",
969 | " )\n",
970 | " return params\n",
971 | "\n",
972 | "def merge_params(llm_params, gnn_params):\n",
973 | " out = {}\n",
974 | " out.update(llm_params)\n",
975 | " out['gnn'] = gnn_params\n",
976 | " return out"
977 | ],
978 | "metadata": {
979 | "id": "RZKA-lU7IsZ8",
980 | "cellView": "form"
981 | },
982 | "execution_count": null,
983 | "outputs": []
984 | },
985 | {
986 | "cell_type": "markdown",
987 | "source": [
988 | "## Train a GraphToken Model"
989 | ],
990 | "metadata": {
991 | "id": "NSuoM8KbnXbF"
992 | }
993 | },
994 | {
995 | "cell_type": "markdown",
996 | "source": [
997 | "Load Gemma Weights"
998 | ],
999 | "metadata": {
1000 | "id": "oknoqsTMI78J"
1001 | }
1002 | },
1003 | {
1004 | "cell_type": "code",
1005 | "source": [
1006 | "params = params_lib.load_and_format_params(ckpt_path)\n",
1007 | "\n",
1008 | "# Reshard params over the TPU device mesh.\n",
1009 | "# Note: the 4x2 mesh assumes 8 accelerator devices are available.\n",
1010 | "from jax.sharding import PartitionSpec as P\n",
1011 | "mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))\n",
1012 | "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n",
1013 | "def try_to_shard(x):\n",
1014 | "  try:\n",
1015 | "    return jax.device_put(x, sharding)\n",
1016 | "  except Exception:  # Leave params that cannot be resharded as-is.\n",
1017 | "    return x\n",
1018 | "params = jax.tree.map(try_to_shard, params)\n",
1018 | "\n",
1019 | "\n",
1020 | "config_2b = transformer_lib.TransformerConfig.from_params(\n",
1021 | " params,\n",
1022 | " cache_size=128 # Number of time steps in the transformer's cache\n",
1023 | ")\n",
1024 | "model_2b = transformer_lib.Transformer(config=config_2b)\n",
1025 | "\n",
1026 | "# Load vocabulary\n",
1027 | "vocab = spm.SentencePieceProcessor()\n",
1028 | "assert vocab.Load(vocab_path)"
1029 | ],
1030 | "metadata": {
1031 | "id": "U1q6PiLSCTTa"
1032 | },
1033 | "execution_count": null,
1034 | "outputs": []
1035 | },
1036 | {
1037 | "cell_type": "markdown",
1038 | "source": [
1039 | "Generate some training data, for the CycleCheck task"
1040 | ],
1041 | "metadata": {
1042 | "id": "NUDd2EisJHX_"
1043 | }
1044 | },
1045 | {
1046 | "cell_type": "code",
1047 | "source": [
1048 | "from talk_like_a_graph import graph_generators\n",
1049 | "from talk_like_a_graph import graph_tasks\n",
1050 | "random_seed = 9876\n",
1051 | "\n",
1052 | "train_graphs = graph_generators.generate_graphs(number_of_graphs=500,\n",
1053 | "                                             algorithm='er', # Erdos-Renyi random graphs\n",
1054 | " directed=False,\n",
1055 | " random_seed=random_seed)\n",
1056 | "test_graphs = graph_generators.generate_graphs(number_of_graphs=10,\n",
1057 | "                                            algorithm='er', # Erdos-Renyi random graphs\n",
1058 | " directed=False,\n",
1059 | " random_seed=random_seed + 12385)\n",
1060 | "task = graph_tasks.CycleCheck()\n",
1061 | "train_examples = list(task.prepare_examples_dict(\n",
1062 | " train_graphs,\n",
1063 | " generator_algorithms = ['er']*len(train_graphs),\n",
1064 | " encoding_method='adjacency').values())\n",
1065 | "test_examples = list(task.prepare_examples_dict(\n",
1066 | " test_graphs,\n",
1067 | " generator_algorithms = ['er']*len(test_graphs),\n",
1068 | " encoding_method='adjacency').values())\n",
1069 | "padder, train_ds = graphqa_ds(vocab, train_examples, max_tokens=25)\n",
1070 | "_, test_ds = graphqa_ds(vocab, test_examples, max_tokens=25, padder=padder)"
1071 | ],
1072 | "metadata": {
1073 | "id": "93oCqfLYnXFr"
1074 | },
1075 | "execution_count": null,
1076 | "outputs": []
1077 | },
1078 | {
1079 | "cell_type": "markdown",
1080 | "source": [
1081 | "Train GraphToken"
1082 | ],
1083 | "metadata": {
1084 | "id": "dJTlxG9yUx9y"
1085 | }
1086 | },
1087 | {
1088 | "cell_type": "code",
1089 | "source": [
1090 | "gin = GIN(config_2b.embed_dim, num_hidden_layers=3, hidden_dim=4)\n",
1091 | "key = jax.random.PRNGKey(0)\n",
1092 | "gnn_params = gin.init(key, train_ds[0].input_graphs[0])\n",
1093 | "\n",
1094 | "\n",
1095 | "train_config = TrainingConfig(\n",
1096 | " learning_rate=0.0001, num_epochs=3, eval_every_n=250, batch_size=1\n",
1097 | ")\n",
1098 | "params_learned = train_loop(\n",
1099 | " llm=model_2b,\n",
1100 | " gnn=gin,\n",
1101 | " train_ds=train_ds,\n",
1102 | " validation_ds=test_ds,\n",
1103 | " params=merge_params(params, gnn_params),\n",
1104 | " training_cfg=train_config,\n",
1105 | " vocab=vocab,\n",
1106 | ")"
1107 | ],
1108 | "metadata": {
1109 | "id": "Uq3ZF3v7Ivo7"
1110 | },
1111 | "execution_count": null,
1112 | "outputs": []
1113 | },
1114 | {
1115 | "cell_type": "markdown",
1116 | "source": [
1117 | "Sample outputs"
1118 | ],
1119 | "metadata": {
1120 | "id": "4eaBMWhPchoJ"
1121 | }
1122 | },
1123 | {
1124 | "cell_type": "code",
1125 | "source": [
1126 | "from IPython.display import Markdown, display\n",
1127 | "\n",
1128 | "graph_token_sampler = GraphTokenSampler(\n",
1129 | " params=params_learned, llm=model_2b, gnn=gin, vocab=vocab\n",
1130 | ")\n",
1131 | "\n",
1132 | "\n",
1133 | "def get_graph_qa_question(ex):\n",
1134 | " with_graph = ex['question']\n",
1135 | " q_index = with_graph.find('Q:')\n",
1136 | " return with_graph[q_index:]\n",
1137 | "\n",
1138 | "for i in range(len(test_examples)):\n",
1139 | " tokenized_input = test_ds[i]\n",
1140 | " ex = test_examples[i]\n",
1141 | " prompt = get_graph_qa_question(ex)\n",
1142 | "  # The textual prompt is the same for every example, so show it only once.\n",
1143 | "  if i == 0:\n",
1143 | " display(\n",
1144 | " Markdown(\n",
1145 | " '**Prompt:** '\n",
1146 | " + prompt\n",
1147 | " + '\\n\\n'\n",
1148 | " )\n",
1149 | " )\n",
1150 | " llm_output = graph_token_sampler(\n",
1151 | " [prompt],\n",
1152 | " tokenized_input.input_graphs,\n",
1153 | " total_generation_steps=15,\n",
1154 | " return_logits=False,\n",
1155 | " ).text[0]\n",
1156 | " display(Markdown(f'**LLM Output:** \"{llm_output}\"'))\n",
1157 | " display(\n",
1158 | " Markdown(f\"**Ground Truth:** {ex['answer']}\")\n",
1159 | " )\n",
1160 | " display(Markdown('-' * 80))\n",
1161 | " print()\n",
1162 | "\n"
1163 | ],
1164 | "metadata": {
1165 | "id": "0p4CxB2Lcilw"
1166 | },
1167 | "execution_count": null,
1168 | "outputs": []
1169 | },
1170 | {
1171 | "cell_type": "markdown",
1172 | "source": [
1173 | "## Exercise: Train a model for a different task.\n",
1174 | "\n",
1175 | "Take the above code and modify it for the NodeCount task.\n",
1176 | "How does your model perform?\n",
1177 | "\n",
1178 | "If your kernel runs out of memory, run the following code to clear the TPU memory, then re-run the code block labeled 'Load Gemma Weights' and retry.\n",
1179 | "```\n",
1180 | "for a in jax.live_arrays():\n",
1181 | " a.delete()\n",
1182 | "```"
1183 | ],
1184 | "metadata": {
1185 | "id": "uuW3rA3-qKy9"
1186 | }
1187 | }
1188 | ]
1189 | }
1190 |
--------------------------------------------------------------------------------
/tutorial/README.md:
--------------------------------------------------------------------------------
1 | # Tutorial on Graph Reasoning with LLMs (GReaL)
2 |
3 | ## **KDD'24 Tutorial**
4 |
5 | ### Sunday, August 25th, 2024
6 |
7 | Centre de Convencions Internacional de Barcelona
8 | P1 122-123
9 | 2:00 pm - 5:00 pm CEST
10 |
11 | ### Overview
12 |
13 | Graphs are a powerful tool for representing and analyzing complex relationships
14 | in real-world applications. Large Language Models (LLMs) have demonstrated
15 | impressive capabilities, advancing the state of the art on many language-based
16 | benchmarks. Their ability to process and understand natural language opens
17 | exciting possibilities in various domains. Despite the remarkable progress in
18 | automated reasoning with natural text, reasoning on graphs with LLMs remains an
19 | understudied problem that has recently gained more attention.
20 |
21 | This tutorial builds upon recent advances in expressing reasoning problems
22 | through the lens of tasks on graph data. The first part of the tutorial will
23 | provide an in-depth discussion of techniques for representing graphs as inputs
24 | to LLMs. The second, hands-on, portion will demonstrate these techniques in a
25 | practical setting.
26 |
27 | ### Schedule
28 |
29 | This tutorial is being given at [KDD 2024](https://kdd2024.kdd.org/) at 2pm in
30 | room P1 122-123 at the Centre de Convencions Internacional de Barcelona.
31 |
32 | ### Tutorial Material
33 |
34 | The slides and notebooks used in the tutorial are available here:
35 |
36 | - [Tutorial Slides](https://drive.google.com/file/d/16rHZVyCeGDfY3djVafKHFPhnabOVtd03/)
37 |
38 | - [Talk Like a Graph Notebook](https://github.com/google-research/talk-like-a-graph/blob/main/tutorial/KDD-Tutorial-1-Talk-Like-a-Graph.ipynb)
39 |
40 | - [GraphToken (Let Your Graph Do the Talking) Notebook](https://github.com/google-research/talk-like-a-graph/blob/main/tutorial/KDD-Tutorial-2-Let-Your-Graph-Do-The-Talking.ipynb)
41 |
42 |
43 |
44 | ### Additional Resources
45 |
46 | *coming soon*
47 |
--------------------------------------------------------------------------------
/tutorial/imgs/let_your_graph_do_the_talking.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/talk-like-a-graph/36af51e19ef7a44049d64306e3cae56c07067e81/tutorial/imgs/let_your_graph_do_the_talking.png
--------------------------------------------------------------------------------
/tutorial/imgs/talk_like_a_graph_colab.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/talk-like-a-graph/36af51e19ef7a44049d64306e3cae56c07067e81/tutorial/imgs/talk_like_a_graph_colab.png
--------------------------------------------------------------------------------