├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── talk_like_a_graph ├── README.md ├── __init__.py ├── graph_generators.py ├── graph_generators_runner.py ├── graph_generators_test.py ├── graph_metrics.py ├── graph_metrics_test.py ├── graph_tasks.py ├── graph_tasks_generator.py ├── graph_tasks_utils.py ├── graph_text_encoders.py ├── graph_text_encoders_test.py └── name_dictionaries.py └── tutorial ├── KDD-Tutorial-1-Talk-Like-a-Graph.ipynb ├── KDD-Tutorial-2-Let-Your-Graph-Do-The-Talking.ipynb ├── README.md └── imgs ├── let_your_graph_do_the_talking.png └── talk_like_a_graph_colab.png /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution; 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to <https://cla.developers.google.com/> to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Encoding Graphs and Structured Data in Language Models 2 | 3 | This repository contains code for 4 | [Talk like a Graph: Encoding Graphs for Large Language Models](https://arxiv.org/abs/2310.04560) 5 | and 6 | [Let Your Graph Do the Talking: Encoding Structured Data for LLMs](https://arxiv.org/abs/2402.05862). 7 | 8 | ## Cite us 9 | 10 | If you use this package for published work, please cite the following. 
11 | 12 | For the "Talk like a Graph" project: 13 | 14 | ``` 15 | @inproceedings{fatemi2024talk, 16 | title={Talk like a Graph: Encoding Graphs for Large Language Models}, 17 | author={Bahare Fatemi and Jonathan Halcrow and Bryan Perozzi}, 18 | booktitle={International Conference on Learning Representations (ICLR)}, 19 | year={2024} 20 | } 21 | ``` 22 | 23 | For Graph Token (a.k.a. "Let Your Graph Do the Talking"): 24 | 25 | ``` 26 | @misc{perozzi2024letgraphtalkingencoding, 27 | title={Let Your Graph Do the Talking: Encoding Structured Data for LLMs}, 28 | author={Bryan Perozzi and Bahare Fatemi and Dustin Zelle and Anton Tsitsulin and Mehran Kazemi and Rami Al-Rfou and Jonathan Halcrow}, 29 | year={2024}, 30 | eprint={2402.05862}, 31 | archivePrefix={arXiv}, 32 | primaryClass={cs.LG}, 33 | url={https://arxiv.org/abs/2402.05862}, 34 | } 35 | ``` 36 | 37 | ## Disclaimer 38 | 39 | This is not an official Google product. 40 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/talk-like-a-graph/36af51e19ef7a44049d64306e3cae56c07067e81/__init__.py -------------------------------------------------------------------------------- /talk_like_a_graph/README.md: -------------------------------------------------------------------------------- 1 | # Using Large Language Models to Solve Graph Problems 2 | 3 | This repository contains the code to generate graph reasoning problems with 4 | different graph generator algorithms and graph encoding methods, as well as 5 | different prompting techniques. 6 | 7 | The graph tasks are `edge existence`, `node degree`, `node count`, `edge count`, 8 | `connected nodes`, `disconnected nodes`, `cycle check`, `reachability`, 9 | `shortest path`, `maximum flow`, `node classification`, and `triangle counting`. 
10 | 11 | The datasets used here are proposed in our paper: 12 | [Talk like a Graph: Encoding Graphs for Large Language Models](https://arxiv.org/abs/2310.04560). 13 | 14 | ### Generating graphs 15 | 16 | ```sh 17 | ./graphqa/graph_generator.sh 18 | ``` 19 | 20 | ### Generating files for tasks 21 | 22 | ```sh 23 | ./graphqa/task_generator.sh 24 | ``` 25 | 26 | ## Contact us 27 | 28 | For questions or comments about the implementation, please contact 29 | baharef@google.com. 30 | 31 | ## Cite us 32 | 33 | If you use this package for published work, please cite the following: 34 | 35 | ``` 36 | @inproceedings{fatemi2024talk, 37 | title={Talk like a Graph: Encoding Graphs for Large Language Models}, 38 | author={Bahare Fatemi and Jonathan Halcrow and Bryan Perozzi}, 39 | booktitle={International Conference on Learning Representations (ICLR)}, 40 | year={2024} 41 | } 42 | ``` 43 | 44 | ## Disclaimer 45 | 46 | This is not an official Google product. 47 | 48 | # Placeholder for internal data notes. 
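The scripts above drive the graph generators and task builders in this directory. For intuition, here is a purely illustrative, self-contained sketch of the kind of text prompt these tasks produce; it hand-rolls an adjacency-style encoding of a tiny graph and is not the actual `graph_text_encoders` or `graph_tasks` API:

```python
# Illustrative sketch only: builds an adjacency-style text encoding of a tiny
# triangle graph and appends a cycle-check question, mimicking (not using) the
# real encoders and task descriptions in this package.
edges = [(0, 1), (1, 2), (2, 0)]

lines = [
    "In an undirected graph, (i,j) means that node i and node j are"
    " connected with an undirected edge. The edges are:"
]
for u, v in edges:
    lines.append(f"({u}, {v})")

# The task description mirrors the cycle-check question in graph_tasks.py.
prompt = "\n".join(lines) + "\nQ: Is there a cycle in this graph?\nA: "
print(prompt)
```

The real pipeline generates question/answer pairs like this at scale, for every task, graph generator, and encoding method listed above.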
-------------------------------------------------------------------------------- /talk_like_a_graph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/talk-like-a-graph/36af51e19ef7a44049d64306e3cae56c07067e81/talk_like_a_graph/__init__.py -------------------------------------------------------------------------------- /talk_like_a_graph/graph_generators.py: -------------------------------------------------------------------------------- 1 | r"""Random graph generation.""" 2 | 3 | import random 4 | 5 | import networkx as nx 6 | import numpy as np 7 | 8 | 9 | _NUMBER_OF_NODES_RANGE = { 10 | "small": np.arange(5, 10), 11 | "medium": np.arange(10, 15), 12 | "large": np.arange(15, 20), 13 | } 14 | _NUMBER_OF_COMMUNITIES_RANGE = { 15 | "small": np.arange(2, 4), 16 | "medium": np.arange(2, 8), 17 | "large": np.arange(2, 10), 18 | } 19 | 20 | 21 | def generate_graphs( 22 | number_of_graphs: int, 23 | algorithm: str, 24 | directed: bool, 25 | random_seed: int = 1234, 26 | er_min_sparsity: float = 0.0, 27 | er_max_sparsity: float = 1.0, 28 | ) -> list[nx.Graph]: 29 | """Generating multiple graphs using the provided algorithms. 30 | 31 | Args: 32 | number_of_graphs: number of graphs to generate 33 | algorithm: the random graph generator algorithm 34 | directed: whether to generate directed or undirected graphs. 35 | random_seed: the random seed to generate graphs with. 36 | er_min_sparsity: minimum sparsity of er graphs. 37 | er_max_sparsity: maximum sparsity of er graphs. 38 | 39 | Returns: 40 | generated_graphs: a list of nx graphs. 41 | Raises: 42 | NotImplementedError: if the algorithm is not yet implemented. 
43 | """ 44 | 45 | random.seed(random_seed) 46 | np.random.seed(random_seed) 47 | 48 | generated_graphs = [] 49 | graph_sizes = random.choices( 50 | list(_NUMBER_OF_NODES_RANGE.keys()), k=number_of_graphs 51 | ) 52 | random_state = np.random.RandomState(random_seed) 53 | if algorithm == "er": 54 | for i in range(number_of_graphs): 55 | sparsity = random.uniform(er_min_sparsity, er_max_sparsity) 56 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]]) 57 | generated_graphs.append( 58 | nx.erdos_renyi_graph( 59 | number_of_nodes, sparsity, seed=random_state, directed=directed 60 | ) 61 | ) 62 | elif algorithm == "ba": 63 | for i in range(number_of_graphs): 64 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]]) 65 | m = random.randint(1, number_of_nodes - 1) 66 | generated_graph = nx.barabasi_albert_graph( 67 | number_of_nodes, m, seed=random_state 68 | ) 69 | if directed: 70 | generated_graphs.append(randomize_directions(generated_graph)) 71 | else: 72 | generated_graphs.append(generated_graph) 73 | elif algorithm == "sbm": 74 | for i in range(number_of_graphs): 75 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]]) 76 | number_of_communities = random.choice( 77 | _NUMBER_OF_COMMUNITIES_RANGE[graph_sizes[i]] 78 | ) 79 | # sizes forms number of nodes in communities. 80 | sizes = [] 81 | for _ in range(number_of_communities - 1): 82 | sizes.append( 83 | random.randint( 84 | 1, 85 | max( 86 | 1, 87 | number_of_nodes - sum(sizes) - (number_of_communities - 1), 88 | ), 89 | ) 90 | ) 91 | sizes.append(number_of_nodes - sum(sizes)) 92 | 93 | # p forms probabilities of communities connecting each other. 
94 | p = np.random.uniform(size=(number_of_communities, number_of_communities)) 95 | if random.uniform(0, 1) < 0.5: 96 | p = np.maximum(p, p.transpose()) 97 | else: 98 | p = np.minimum(p, p.transpose()) 99 | sbm_graph = nx.stochastic_block_model( 100 | sizes, p, seed=random_state, directed=directed 101 | ) 102 | # sbm graph generator automatically adds dictionary attributes. 103 | sbm_graph = remove_graph_data(sbm_graph) 104 | generated_graphs.append(sbm_graph) 105 | elif algorithm == "sfn": 106 | for i in range(number_of_graphs): 107 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]]) 108 | generated_graph = nx.scale_free_graph(number_of_nodes, seed=random_state) 109 | # sfn graphs are by default directed. 110 | if not directed: 111 | generated_graphs.append(remove_directions(generated_graph)) 112 | else: 113 | generated_graphs.append(generated_graph) 114 | elif algorithm == "complete": 115 | for i in range(number_of_graphs): 116 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]]) 117 | create_using = nx.DiGraph if directed else nx.Graph 118 | generated_graphs.append( 119 | nx.complete_graph(number_of_nodes, create_using=create_using) 120 | ) 121 | elif algorithm == "star": 122 | for i in range(number_of_graphs): 123 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]]) 124 | # number_of_nodes for star is the input + a center node. 
125 | generated_graph = nx.star_graph(number_of_nodes - 1) 126 | if directed: 127 | generated_graphs.append(randomize_directions(generated_graph)) 128 | else: 129 | generated_graphs.append(generated_graph) 130 | elif algorithm == "path": 131 | for i in range(number_of_graphs): 132 | number_of_nodes = random.choice(_NUMBER_OF_NODES_RANGE[graph_sizes[i]]) 133 | create_using = nx.DiGraph if directed else nx.Graph 134 | generated_graphs.append( 135 | nx.path_graph(number_of_nodes, create_using=create_using) 136 | ) 137 | else: 138 | raise NotImplementedError() 139 | return generated_graphs 140 | 141 | 142 | def remove_graph_data(graph: nx.Graph) -> nx.Graph: 143 | # GraphML writer does not support dictionary data for nodes or graphs. 144 | for ind in range(graph.number_of_nodes()): 145 | graph.nodes[ind].pop("block", None) 146 | graph_data_keys = list(graph.graph.keys()) 147 | for _, node in enumerate(graph_data_keys): 148 | graph.graph.pop(node, None) 149 | return graph 150 | 151 | 152 | def randomize_directions(graph: nx.Graph) -> nx.DiGraph: 153 | # Converting the undirected graph to a directed graph. 154 | directed_graph = graph.to_directed() 155 | # For each edge, randomly choose a direction. 156 | edges = list(graph.edges()) 157 | for u, v in edges: 158 | if random.random() < 0.5: 159 | directed_graph.remove_edge(u, v) 160 | else: 161 | directed_graph.remove_edge(v, u) 162 | 163 | return directed_graph 164 | 165 | 166 | def remove_directions(graph: nx.Graph) -> nx.Graph: 167 | # Converting the directed graph to an undirected one by removing directions. 168 | undirected_graph = nx.Graph() 169 | undirected_graph.add_nodes_from(graph.nodes()) 170 | # Add edges between nodes, ignoring directions. 
171 | for u, v in graph.edges(): 172 | undirected_graph.add_edge(u, v) 173 | 174 | return undirected_graph 175 | -------------------------------------------------------------------------------- /talk_like_a_graph/graph_generators_runner.py: -------------------------------------------------------------------------------- 1 | r"""Random graph generation. 2 | 3 | This code generates random graphs using different algorithms. 4 | 5 | # Placeholder for Google-internal comments. 6 | """ 7 | 8 | from collections.abc import Sequence 9 | import os 10 | 11 | from absl import app 12 | from absl import flags 13 | import networkx as nx 14 | 15 | # Internal import. 16 | from . import graph_generators 17 | 18 | _ALGORITHM = flags.DEFINE_string( 19 | "algorithm", 20 | None, 21 | "The graph generating algorithm to use.", 22 | required=True, 23 | ) 24 | _NUMBER_OF_GRAPHS = flags.DEFINE_integer( 25 | "number_of_graphs", 26 | None, 27 | "The number of graphs to generate.", 28 | required=True, 29 | ) 30 | _DIRECTED = flags.DEFINE_bool( 31 | "directed", False, "Whether to generate directed graphs." 
32 | ) 33 | _OUTPUT_PATH = flags.DEFINE_string( 34 | "output_path", None, "The output path to write the graphs.", required=True 35 | ) 36 | _SPLIT = flags.DEFINE_string( 37 | "split", None, "The dataset split to generate.", required=True 38 | ) 39 | _MIN_SPARSITY = flags.DEFINE_float("min_sparsity", 0.0, "The minimum sparsity.") 40 | _MAX_SPARSITY = flags.DEFINE_float("max_sparsity", 1.0, "The maximum sparsity.") 41 | 42 | 43 | def write_graphs(graphs: list[nx.Graph], output_dir: str) -> None: 44 | if not os.path.isdir(output_dir): 45 | os.makedirs(output_dir) 46 | for ind, graph in enumerate(graphs): 47 | nx.write_graphml( 48 | graph, 49 | open( 50 | os.path.join(output_dir, str(ind) + ".graphml"), 51 | "wb", 52 | ), 53 | ) 54 | 55 | 56 | def main(argv: Sequence[str]) -> None: 57 | if len(argv) > 1: 58 | raise app.UsageError("Too many command-line arguments.") 59 | 60 | if _SPLIT.value == "train": 61 | random_seed = 9876 62 | elif _SPLIT.value == "test": 63 | random_seed = 1234 64 | elif _SPLIT.value == "validation": 65 | random_seed = 5432 66 | else: 67 | raise NotImplementedError() 68 | 69 | generated_graphs = graph_generators.generate_graphs( 70 | number_of_graphs=_NUMBER_OF_GRAPHS.value, 71 | algorithm=_ALGORITHM.value, 72 | directed=_DIRECTED.value, 73 | random_seed=random_seed, 74 | er_min_sparsity=_MIN_SPARSITY.value, 75 | er_max_sparsity=_MAX_SPARSITY.value, 76 | ) 77 | write_graphs( 78 | graphs=generated_graphs, 79 | output_dir=os.path.join( 80 | _OUTPUT_PATH.value, 81 | "directed" if _DIRECTED.value else "undirected", 82 | _ALGORITHM.value, 83 | _SPLIT.value, 84 | ), 85 | ) 86 | 87 | 88 | if __name__ == "__main__": 89 | app.run(main) 90 | -------------------------------------------------------------------------------- /talk_like_a_graph/graph_generators_test.py: -------------------------------------------------------------------------------- 1 | from absl.testing import parameterized 2 | from . 
import graph_generators 3 | from absl.testing import absltest 4 | 5 | 6 | class GraphGenerationTest(parameterized.TestCase): 7 | 8 | @parameterized.named_parameters( 9 | dict( 10 | testcase_name='er_undirected_1', 11 | algorithm='er', 12 | directed=False, 13 | k=1, 14 | ), 15 | dict( 16 | testcase_name='er_directed_1', 17 | algorithm='er', 18 | directed=True, 19 | k=1, 20 | ), 21 | dict( 22 | testcase_name='ba_undirected_5', 23 | algorithm='ba', 24 | directed=False, 25 | k=5, 26 | ), 27 | dict( 28 | testcase_name='ba_directed_5', 29 | algorithm='ba', 30 | directed=True, 31 | k=5, 32 | ), 33 | ) 34 | def test_number_of_graphs(self, algorithm, directed, k): 35 | generated_graph = graph_generators.generate_graphs(k, algorithm, directed) 36 | self.assertLen(generated_graph, k) 37 | 38 | @parameterized.named_parameters( 39 | dict( 40 | testcase_name='er_undirected', 41 | algorithm='er', 42 | directed=False, 43 | ), 44 | dict( 45 | testcase_name='er_directed', 46 | algorithm='er', 47 | directed=True, 48 | ), 49 | dict( 50 | testcase_name='ba_undirected', 51 | algorithm='ba', 52 | directed=False, 53 | ), 54 | dict( 55 | testcase_name='ba_directed', 56 | algorithm='ba', 57 | directed=True, 58 | ), 59 | dict( 60 | testcase_name='sbm_undirected', 61 | algorithm='sbm', 62 | directed=False, 63 | ), 64 | dict( 65 | testcase_name='sbm_directed', 66 | algorithm='sbm', 67 | directed=True, 68 | ), 69 | dict( 70 | testcase_name='sfn_undirected', 71 | algorithm='sfn', 72 | directed=False, 73 | ), 74 | dict( 75 | testcase_name='sfn_directed', 76 | algorithm='sfn', 77 | directed=True, 78 | ), 79 | dict( 80 | testcase_name='complete_undirected', 81 | algorithm='complete', 82 | directed=False, 83 | ), 84 | dict( 85 | testcase_name='complete_directed', 86 | algorithm='complete', 87 | directed=True, 88 | ), 89 | dict( 90 | testcase_name='star_undirected', 91 | algorithm='star', 92 | directed=False, 93 | ), 94 | dict( 95 | testcase_name='star_directed', 96 | 
algorithm='star', 97 | directed=True, 98 | ), 99 | dict( 100 | testcase_name='path_undirected', 101 | algorithm='path', 102 | directed=False, 103 | ), 104 | dict( 105 | testcase_name='path_directed', 106 | algorithm='path', 107 | directed=True, 108 | ), 109 | ) 110 | def test_directions(self, algorithm, directed): 111 | generated_graph = graph_generators.generate_graphs(1, algorithm, directed) 112 | self.assertEqual(generated_graph[0].is_directed(), directed) 113 | 114 | 115 | if __name__ == '__main__': 116 | absltest.main() 117 | -------------------------------------------------------------------------------- /talk_like_a_graph/graph_metrics.py: -------------------------------------------------------------------------------- 1 | """Metrics for seqio tasks over graph data. 2 | 3 | This module contains definitions of metric_fns to be used for scoring 4 | graph tasks from nlgraph and graphqa. 5 | """ 6 | 7 | from typing import Mapping, Sequence 8 | 9 | 10 | def yes_no_accuracy( 11 | targets: Sequence[str], predictions: Sequence[str] 12 | ) -> Mapping[str, float]: 13 | """Assesses the accuracy of LLM outputs on Yes/No tasks. 14 | 15 | Targets must contain either the word 'yes' or the word 'no' but not both. 16 | 17 | Predictions are binarized by checking for 'yes' or 'no' in the first line. 18 | 19 | Args: 20 | targets: The expected output strings. 21 | predictions: The LLM outputs. 22 | 23 | Returns: 24 | Returns a dict of the following metrics: 25 | yes_no_accuracy: The % where the target and prediction match. 
26 | yes_no_ambiguous: The % where the prediction contained both yes and no. 27 | yes_no_indeterminate: The % where the prediction contained neither yes 28 | nor no. 29 | 30 | Raises: 31 | ValueError: If a target string contains both 'yes' and 'no', or neither. 32 | """ 33 | num_correct = 0 34 | num_ambiguous = 0 35 | num_indeterminate = 0 36 | for target, prediction in zip(targets, predictions): 37 | normalized_target = target.lower() 38 | binarized_target = 'yes' in normalized_target 39 | if binarized_target and 'no' in normalized_target: 40 | raise ValueError(f'Ambiguous target string, {target}') 41 | if not binarized_target and 'no' not in normalized_target: 42 | raise ValueError(f'Indeterminate target string, {target}') 43 | normalized_prediction = prediction.splitlines() 44 | if not normalized_prediction: 45 | normalized_prediction = '' 46 | else: 47 | normalized_prediction = normalized_prediction[0] 48 | normalized_prediction = normalized_prediction.lower() 49 | binarized_prediction = 'yes' in normalized_prediction 50 | if binarized_prediction and 'no' in normalized_prediction: 51 | num_ambiguous += 1 52 | continue 53 | if not binarized_prediction and 'no' not in normalized_prediction: 54 | num_indeterminate += 1 55 | continue 56 | if binarized_prediction == binarized_target: 57 | num_correct += 1 58 | return { 59 | 'yes_no_accuracy': num_correct / len(targets), 60 | 'yes_no_ambiguous': num_ambiguous / len(targets), 61 | 'yes_no_indeterminate': num_indeterminate / len(targets), 62 | } 63 | -------------------------------------------------------------------------------- /talk_like_a_graph/graph_metrics_test.py: -------------------------------------------------------------------------------- 1 | from . 
import graph_metrics 2 | from absl.testing import absltest 3 | 4 | 5 | class GraphTasksTest(absltest.TestCase): 6 | 7 | def test_yes_no_correct(self): 8 | result = graph_metrics.yes_no_accuracy( 9 | targets=["Yes", "The answer is yes.", "No", "The answer is no.", "No"], 10 | predictions=[ 11 | "yes", 12 | "Yes\nNo", 13 | "No\nYes", 14 | "That's gonna be no from me.", 15 | "yes", 16 | ], 17 | ) 18 | self.assertEqual(result["yes_no_ambiguous"], 0) 19 | self.assertEqual(result["yes_no_indeterminate"], 0) 20 | self.assertAlmostEqual(result["yes_no_accuracy"], 0.8) 21 | 22 | def test_yes_no_ambiguous(self): 23 | result = graph_metrics.yes_no_accuracy( 24 | targets=["Yes", "The answer is yes.", "No", "The answer is no.", "No"], 25 | predictions=[ 26 | "yes", 27 | "Yes No", 28 | "No Yes", 29 | "That's gonna be no from me.", 30 | "yes", 31 | ], 32 | ) 33 | self.assertEqual(result["yes_no_ambiguous"], 0.4) 34 | self.assertEqual(result["yes_no_indeterminate"], 0) 35 | self.assertAlmostEqual(result["yes_no_accuracy"], 0.4) 36 | 37 | def test_yes_no_indeterminate(self): 38 | result = graph_metrics.yes_no_accuracy( 39 | targets=["Yes", "The answer is yes.", "No", "The answer is no.", "No"], 40 | predictions=[ 41 | "yes", 42 | "\n No", 43 | "", 44 | "That's gonna be no from me.", 45 | "yes", 46 | ], 47 | ) 48 | self.assertEqual(result["yes_no_ambiguous"], 0) 49 | self.assertEqual(result["yes_no_indeterminate"], 0.4) 50 | self.assertAlmostEqual(result["yes_no_accuracy"], 0.4) 51 | 52 | def test_yes_no_accuracy_raises_on_ambiguous_target(self): 53 | with self.assertRaises(ValueError): 54 | graph_metrics.yes_no_accuracy( 55 | targets=["Yes but maybe no"], 56 | predictions=["yes"], 57 | ) 58 | 59 | def test_yes_no_accuracy_raises_on_indeterminate_target(self): 60 | with self.assertRaises(ValueError): 61 | graph_metrics.yes_no_accuracy( 62 | targets=["Hmm?"], 63 | predictions=["yes"], 64 | ) 65 | 66 | 67 | if __name__ == "__main__": 68 | absltest.main() 69 | 
-------------------------------------------------------------------------------- /talk_like_a_graph/graph_tasks.py: -------------------------------------------------------------------------------- 1 | """The graph tasks to be tried with LLMs.""" 2 | 3 | import random 4 | 5 | import networkx as nx 6 | import numpy as np 7 | 8 | from . import graph_text_encoders 9 | 10 | 11 | class GraphTask: 12 | """The parent class for all the graph tasks.""" 13 | 14 | def __init__(self): 15 | self.name = 'default' 16 | self.maximum_nnodes_cot_graph = 10 17 | 18 | def prepare_examples_dict( 19 | self, 20 | graphs: list[nx.Graph], 21 | generator_algorithms: list[str], 22 | encoding_method: str, 23 | ) -> dict[int, dict[str, str | list[int]]]: 24 | raise NotImplementedError() 25 | 26 | def create_few_shot_example( 27 | self, graph: nx.Graph, encoding_method: str, cot: bool 28 | ): 29 | raise NotImplementedError() 30 | 31 | 32 | class CycleCheck(GraphTask): 33 | """The graph task to check if there is at least one cycle or not.""" 34 | 35 | def __init__(self): 36 | super().__init__() 37 | self.name = 'cycle_check' 38 | self._task_description = 'Q: Is there a cycle in this graph?\nA: ' 39 | 40 | def prepare_examples_dict( 41 | self, 42 | graphs: list[nx.Graph], 43 | generator_algorithms: list[str], 44 | encoding_method: str, 45 | ) -> dict[int, dict[str, str | list[int]]]: 46 | examples_dict = {} 47 | for ind, graph in enumerate(graphs): 48 | question = ( 49 | graph_text_encoders.encode_graph(graph, encoding_method) 50 | + self._task_description 51 | ) 52 | try: 53 | nx.find_cycle(graph) 54 | answer = 'Yes, there is a cycle.' 55 | except nx.NetworkXNoCycle: 56 | answer = 'No, there is no cycle.' 
57 | examples_dict[ind] = { 58 | 'question': question, 59 | 'answer': answer, 60 | 'nnodes': str(len(graph.nodes())), 61 | 'nedges': str(len(graph.edges())), 62 | 'task_description': self._task_description, 63 | 'graph': graph, 64 | 'algorithm': generator_algorithms[ind], 65 | 'node_ids': [], 66 | } 67 | return examples_dict 68 | 69 | def create_few_shot_example( 70 | self, graph: nx.Graph, encoding_method: str, cot: bool 71 | ) -> str: 72 | """Create a few shot example w or w/o cot for the given graph.""" 73 | name_dict = graph_text_encoders.get_tlag_node_encoder( 74 | graph, encoding_method 75 | ) 76 | question = ( 77 | graph_text_encoders.encode_graph(graph, encoding_method) 78 | + self._task_description 79 | ) 80 | try: 81 | cycle = nx.find_cycle(graph) 82 | cycle_text = '' 83 | answer = 'Yes, there is a cycle. ' 84 | if cot: 85 | for pair in cycle: 86 | cycle_text += ( 87 | name_dict[pair[0]] 88 | + ' is connected to ' 89 | + name_dict[pair[1]] 90 | + ', ' 91 | ) 92 | cycle_cot = 'The cycle is: %s.' % cycle_text[:-2] 93 | answer += cycle_cot 94 | except nx.NetworkXNoCycle: 95 | answer = 'No, there is no cycle.' 96 | return question + answer 97 | 98 | def choose_few_shot_examples( 99 | self, 100 | few_shots_dict: dict[tuple[str, str], list[str]], 101 | encoding_method: str, 102 | k: int = 2, 103 | ) -> str: 104 | """Choose few shot examples for each algorithm.""" 105 | pos_cycle_algorithms = ['er', 'ba', 'sbm', 'sfn', 'complete'] 106 | neg_cycle_algorithms = ['star', 'path'] 107 | few_shots_str = '' 108 | # choose k-1 shots for pos algorithms and one negative. 
109 | positive_algorithms = random.choices(pos_cycle_algorithms, k=k - 1) 110 | for positive_algorithm in positive_algorithms: 111 | example_list = few_shots_dict[(positive_algorithm, encoding_method)] 112 | few_shots_str += 'Example: ' + random.choice(example_list) + '\n' 113 | negative_algorithm = random.choice(neg_cycle_algorithms) 114 | example_list = few_shots_dict[(negative_algorithm, encoding_method)] 115 | few_shots_str += 'Example: ' + random.choice(example_list) + '\n' 116 | return few_shots_str 117 | 118 | 119 | class EdgeExistence(GraphTask): 120 | """The graph task to check if an edge exist in a graph or not.""" 121 | 122 | def __init__(self): 123 | super().__init__() 124 | self.name = 'edge_existence' 125 | 126 | def prepare_examples_dict( 127 | self, 128 | graphs: list[nx.Graph], 129 | generator_algorithms: list[str], 130 | encoding_method: str, 131 | ) -> dict[int, dict[str, str | list[int]]]: 132 | examples_dict = {} 133 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method) 134 | 135 | for ind, graph in enumerate(graphs): 136 | source, target = random.sample(list(graph.nodes()), k=2) 137 | question = graph_text_encoders.encode_graph(graph, encoding_method) 138 | task_description = 'Q: Is node %s connected to node %s?\nA: ' % ( 139 | name_dict[source], 140 | name_dict[target], 141 | ) 142 | question += task_description 143 | if ((source, target) in graph.edges()) or ( 144 | (target, source) in graph.edges() 145 | ): 146 | answer = 'Yes.' 147 | else: 148 | answer = 'No.' 
149 | examples_dict[ind] = { 150 | 'question': question, 151 | 'answer': answer, 152 | 'nnodes': str(len(graph.nodes())), 153 | 'nedges': str(len(graph.edges())), 154 | 'task_description': task_description, 155 | 'graph': graph, 156 | 'algorithm': generator_algorithms[ind], 157 | 'node_ids': [source, target], 158 | } 159 | return examples_dict 160 | 161 | def create_few_shot_example( 162 | self, graph: nx.Graph, encoding_method: str, cot: bool 163 | ) -> str: 164 | name_dict = graph_text_encoders.get_tlag_node_encoder( 165 | graph, encoding_method 166 | ) 167 | source, target = random.sample(list(graph.nodes()), k=2) 168 | question = graph_text_encoders.encode_graph(graph, encoding_method) 169 | question += 'Q: Is node %s connected to node %s?\nA: ' % ( 170 | name_dict[source], 171 | name_dict[target], 172 | ) 173 | if ((source, target) in graph.edges()) or ( 174 | (target, source) in graph.edges() 175 | ): 176 | answer = 'Yes.' 177 | if cot: 178 | answer += ( 179 | ' Because, there is an edge from %s to %s in the graph description.' 180 | % (name_dict[source], name_dict[target]) 181 | ) 182 | else: 183 | answer = 'No.' 184 | if cot: 185 | answer += ( 186 | ' Because, there is no edge from %s to %s in the graph description.' 
187 | % (name_dict[source], name_dict[target]) 188 | ) 189 | return question + answer 190 | 191 | 192 | class NodeCount(GraphTask): 193 | """The graph task for finding number of nodes in a graph.""" 194 | 195 | def __init__(self): 196 | super().__init__() 197 | self.name = 'node_count' 198 | self._task_description = 'Q: How many nodes are in this graph?\nA: ' 199 | 200 | def prepare_examples_dict( 201 | self, 202 | graphs: list[nx.Graph], 203 | generator_algorithms: list[str], 204 | encoding_method: str, 205 | ) -> dict[int, dict[str, str | list[int]]]: 206 | examples_dict = {} 207 | for ind, graph in enumerate(graphs): 208 | question = graph_text_encoders.encode_graph(graph, encoding_method) 209 | question += self._task_description 210 | answer = ' %d.' % len(graph.nodes()) 211 | examples_dict[ind] = { 212 | 'question': question, 213 | 'answer': answer, 214 | 'nnodes': str(len(graph.nodes())), 215 | 'nedges': str(len(graph.edges())), 216 | 'task_description': self._task_description, 217 | 'graph': graph, 218 | 'algorithm': generator_algorithms[ind], 219 | 'node_ids': [], 220 | } 221 | return examples_dict 222 | 223 | def get_nodes_string(self, name_dict: dict[int, str], nnodes: int) -> str: 224 | node_string = '' 225 | for i in range(nnodes - 1): 226 | node_string += name_dict[i] + ', ' 227 | node_string += 'and ' + name_dict[nnodes - 1] 228 | return node_string 229 | 230 | def create_few_shot_example( 231 | self, graph: nx.Graph, encoding_method: str, cot: bool 232 | ) -> str: 233 | name_dict = graph_text_encoders.get_tlag_node_encoder( 234 | graph, encoding_method 235 | ) 236 | question = graph_text_encoders.encode_graph(graph, encoding_method) 237 | question += self._task_description 238 | answer = '%d.' % len(graph.nodes()) 239 | if cot: 240 | answer += ' The nodes are %s.' 
% self.get_nodes_string( 241 | name_dict, len(graph.nodes()) 242 | ) 243 | 244 | return question + answer 245 | 246 | 247 | class NodeDegree(GraphTask): 248 | """The graph task for finding degree of a node in a graph.""" 249 | 250 | def __init__(self): 251 | super().__init__() 252 | self.name = 'node_degree' 253 | 254 | def prepare_examples_dict( 255 | self, 256 | graphs: list[nx.Graph], 257 | generator_algorithms: list[str], 258 | encoding_method: str, 259 | ) -> dict[int, dict[str, str | list[int]]]: 260 | examples_dict = {} 261 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method) 262 | for ind, graph in enumerate(graphs): 263 | question = graph_text_encoders.encode_graph(graph, encoding_method) 264 | source_node = random.sample(list(graph.nodes()), k=1)[0] 265 | task_description = ( 266 | 'Q: What is the degree of node %s?\nA: ' % name_dict[source_node] 267 | ) 268 | question += task_description 269 | answer = '%d.' % graph.degree[source_node] 270 | examples_dict[ind] = { 271 | 'question': question, 272 | 'answer': answer, 273 | 'nnodes': str(len(graph.nodes())), 274 | 'nedges': str(len(graph.edges())), 275 | 'task_description': task_description, 276 | 'graph': graph, 277 | 'algorithm': generator_algorithms[ind], 278 | 'node_ids': [source_node], 279 | } 280 | return examples_dict 281 | 282 | def get_edge_string( 283 | self, name_dict: dict[int, str], graph: nx.Graph, source_node: int 284 | ) -> str: 285 | """Gets a string identifying the edges a given node is connected to.""" 286 | edge_string = '' 287 | target_edges = graph.edges(source_node) 288 | target_nodes = [] 289 | for edge in target_edges: 290 | target_nodes.append(edge[1]) 291 | if target_nodes: 292 | for i in range(len(target_nodes) - 1): 293 | edge_string += name_dict[target_nodes[i]] + ', ' 294 | edge_string += 'and ' + name_dict[target_nodes[-1]] 295 | else: 296 | edge_string = 'no nodes' 297 | return edge_string 298 | 299 | def create_few_shot_example( 300 | self, graph: 
nx.Graph, encoding_method: str, cot: bool 301 | ) -> str: 302 | name_dict = graph_text_encoders.get_tlag_node_encoder( 303 | graph, encoding_method 304 | ) 305 | question = graph_text_encoders.encode_graph(graph, encoding_method) 306 | source_node = random.sample(list(graph.nodes()), k=1)[0] 307 | question += ( 308 | 'Q: What is the degree of node %s?\nA: ' % name_dict[source_node] 309 | ) 310 | answer = '%d.' % graph.degree[source_node] 311 | if cot: 312 | answer += ' This is because %s is connected to %s.' % ( 313 | name_dict[source_node], 314 | self.get_edge_string(name_dict, graph, source_node), 315 | ) 316 | return question + answer 317 | 318 | 319 | class EdgeCount(GraphTask): 320 | """The graph task for finding number of edges in a graph.""" 321 | 322 | def __init__(self): 323 | super().__init__() 324 | self.name = 'edge_count' 325 | self._task_description = 'Q: How many edges are in this graph?\nA: ' 326 | 327 | def prepare_examples_dict( 328 | self, 329 | graphs: list[nx.Graph], 330 | generator_algorithms: list[str], 331 | encoding_method: str, 332 | ) -> dict[int, dict[str, str | list[int]]]: 333 | examples_dict = {} 334 | for ind, graph in enumerate(graphs): 335 | question = graph_text_encoders.encode_graph(graph, encoding_method) 336 | question += self._task_description 337 | answer = ' %d.' 
% len(graph.edges()) 338 | examples_dict[ind] = { 339 | 'question': question, 340 | 'answer': answer, 341 | 'nnodes': str(len(graph.nodes())), 342 | 'nedges': str(len(graph.edges())), 343 | 'task_description': self._task_description, 344 | 'graph': graph, 345 | 'algorithm': generator_algorithms[ind], 346 | 'node_ids': [], 347 | } 348 | return examples_dict 349 | 350 | def get_edges_string( 351 | self, name_dict: dict[int, str], edges: list[tuple[int, int]] 352 | ) -> str: 353 | edges_string = '' 354 | for edge in edges: 355 | edges_string += ( 356 | '(' + name_dict[edge[0]] + ', ' + name_dict[edge[1]] + '), ' 357 | ) 358 | return edges_string.strip()[:-1] 359 | 360 | def create_few_shot_example( 361 | self, graph: nx.Graph, encoding_method: str, cot: bool 362 | ) -> str: 363 | name_dict = graph_text_encoders.get_tlag_node_encoder( 364 | graph, encoding_method 365 | ) 366 | question = graph_text_encoders.encode_graph(graph, encoding_method) 367 | question += self._task_description 368 | answer = '%d.' % len(graph.edges()) 369 | if cot: 370 | answer += ' The edges are %s.' 
% self.get_edges_string( 371 | name_dict, list(graph.edges()) 372 | ) 373 | return question + answer 374 | 375 | 376 | class ConnectedNodes(GraphTask): 377 | """The graph task for finding connected nodes to a given node in a graph.""" 378 | 379 | def __init__(self): 380 | super().__init__() 381 | self.name = 'connected_nodes' 382 | 383 | def prepare_examples_dict( 384 | self, 385 | graphs: list[nx.Graph], 386 | generator_algorithms: list[str], 387 | encoding_method: str, 388 | ) -> dict[int, dict[str, str | list[int]]]: 389 | examples_dict = {} 390 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method) 391 | for ind, graph in enumerate(graphs): 392 | question = graph_text_encoders.encode_graph(graph, encoding_method) 393 | source_node = random.sample(list(graph.nodes()), k=1)[0] 394 | task_description = ( 395 | 'Q: List all the nodes connected to %s in alphabetical order.\nA: ' 396 | % name_dict[source_node] 397 | ) 398 | question += task_description 399 | outgoing_edges = list(graph.edges(source_node)) 400 | if outgoing_edges: 401 | answer = self.get_connected_nodes(outgoing_edges, name_dict) + '.' 402 | else: 403 | answer = ' No nodes.' 
404 | examples_dict[ind] = { 405 | 'question': question, 406 | 'answer': answer, 407 | 'nnodes': str(len(graph.nodes())), 408 | 'nedges': str(len(graph.edges())), 409 | 'task_description': task_description, 410 | 'graph': graph, 411 | 'algorithm': generator_algorithms[ind], 412 | 'node_ids': [source_node], 413 | } 414 | return examples_dict 415 | 416 | def get_connected_nodes( 417 | self, edges: list[tuple[int, int]], name_dict: dict[int, str] 418 | ) -> str: 419 | """Gets a string including all the nodes that are connected to source.""" 420 | connected_nodes = [] 421 | for edge in edges: 422 | connected_nodes.append(name_dict[edge[1]]) 423 | connected_nodes_string = '' 424 | if connected_nodes: 425 | try: 426 | int(connected_nodes[0]) 427 | connected_nodes_string = ', '.join(map(str, connected_nodes)) 428 | except ValueError: 429 | # These are not integers, so sort them alphabetically. 430 | connected_nodes_string = ', '.join(map(str, sorted(connected_nodes))) 431 | return connected_nodes_string 432 | 433 | def create_few_shot_example( 434 | self, graph: nx.Graph, encoding_method: str, cot: bool 435 | ) -> str: 436 | name_dict = graph_text_encoders.get_tlag_node_encoder( 437 | graph, encoding_method 438 | ) 439 | question = graph_text_encoders.encode_graph(graph, encoding_method) 440 | source_node = random.sample(list(graph.nodes()), k=1)[0] 441 | question += ( 442 | 'Q: List all the nodes connected to %s in alphabetical order.\nA: ' 443 | % name_dict[source_node] 444 | ) 445 | outgoing_edges = list(graph.edges(source_node)) 446 | answer = '' 447 | if outgoing_edges: 448 | answer = self.get_connected_nodes(outgoing_edges, name_dict) + '.' 449 | if cot: 450 | answer += ' This is because there is an edge from %s to %s.' % ( 451 | name_dict[source_node], 452 | answer[:-1], 453 | ) 454 | else: 455 | answer = 'No nodes.' 456 | if cot: 457 | answer += ( 458 | ' This is because %s is not connected to any node.' 
459 | % name_dict[source_node] 460 | ) 461 | return question + answer 462 | 463 | 464 | class DisconnectedNodes(GraphTask): 465 | """The task for finding disconnected nodes for a given node in a graph.""" 466 | 467 | def __init__(self): 468 | super().__init__() 469 | self.name = 'disconnected_nodes' 470 | 471 | def prepare_examples_dict( 472 | self, 473 | graphs: list[nx.Graph], 474 | generator_algorithms: list[str], 475 | encoding_method: str, 476 | ) -> dict[int, dict[str, str | list[int]]]: 477 | examples_dict = {} 478 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method) 479 | for ind, graph in enumerate(graphs): 480 | question = graph_text_encoders.encode_graph(graph, encoding_method) 481 | source_node = random.sample(list(graph.nodes()), k=1)[0] 482 | task_description = ( 483 | 'Q: List all the nodes that are not connected to %s in alphabetical' 484 | ' order.\nA: ' 485 | % name_dict[source_node] 486 | ) 487 | question += task_description 488 | outgoing_edges = list(graph.edges(source_node)) 489 | answer = self.get_disconnected_nodes( 490 | source_node, outgoing_edges, name_dict, list(graph.nodes()) 491 | ) 492 | if not answer: 493 | answer = 'No nodes' 494 | 495 | answer += '.' 
496 | examples_dict[ind] = { 497 | 'question': question, 498 | 'answer': answer, 499 | 'nnodes': str(len(graph.nodes())), 500 | 'nedges': str(len(graph.edges())), 501 | 'task_description': task_description, 502 | 'graph': graph, 503 | 'algorithm': generator_algorithms[ind], 504 | 'node_ids': [source_node], 505 | } 506 | return examples_dict 507 | 508 | def get_disconnected_nodes( 509 | self, 510 | source: int, 511 | edges: list[tuple[int, int]], 512 | name_dict: dict[int, str], 513 | all_nodes: list[int], 514 | ) -> str: 515 | """Gets a string with all the nodes that are not connected to source.""" 516 | for edge in edges: 517 | if edge[1] in all_nodes: 518 | all_nodes.remove(edge[1]) 519 | if source in all_nodes: 520 | all_nodes.remove(source) 521 | all_nodes_names = [] 522 | for node in all_nodes: 523 | all_nodes_names.append(name_dict[node]) 524 | # sorted operation should be different for integers vs strings. 525 | if all_nodes_names: 526 | try: 527 | int(all_nodes_names[0]) 528 | for ind, value in enumerate(all_nodes_names): 529 | all_nodes_names[ind] = int(value) 530 | all_nodes_names = sorted(all_nodes_names) 531 | for ind, value in enumerate(all_nodes_names): 532 | all_nodes_names[ind] = str(value) 533 | except ValueError: 534 | all_nodes_names = sorted(all_nodes_names) 535 | return ', '.join(all_nodes_names) 536 | 537 | def create_few_shot_example( 538 | self, graph: nx.Graph, encoding_method: str, cot: bool 539 | ) -> str: 540 | name_dict = graph_text_encoders.get_tlag_node_encoder( 541 | graph, encoding_method 542 | ) 543 | question = graph_text_encoders.encode_graph(graph, encoding_method) 544 | source_node = random.sample(list(graph.nodes()), k=1)[0] 545 | question += ( 546 | 'Q: List all the nodes that are not connected to %s in alphabetical' 547 | ' order.\nA: ' 548 | % name_dict[source_node] 549 | ) 550 | outgoing_edges = list(graph.edges(source_node)) 551 | answer = '' 552 | disconnected_nodes_string = self.get_disconnected_nodes( 553 | source_node, outgoing_edges, 
name_dict, list(graph.nodes()) 554 | ) 555 | # List the disconnected nodes if there are any; otherwise say 'No nodes'. 556 | if disconnected_nodes_string: 557 | answer = disconnected_nodes_string + '.' 558 | if cot: 559 | answer += ' This is because there is not an edge from %s to %s.' % ( 560 | name_dict[source_node], 561 | disconnected_nodes_string, 562 | ) 563 | else: 564 | # An empty string means the source node is connected to all other nodes. 565 | answer = 'No nodes.' 566 | if cot: 567 | answer += ( 568 | ' This is because %s is connected to all nodes.' 569 | % name_dict[source_node] 570 | ) 571 | return question + answer 572 | 573 | 574 | class Reachability(GraphTask): 575 | """The graph task to check if there is a path from a source to target.""" 576 | 577 | def __init__(self): 578 | super().__init__() 579 | self.name = 'reachability' 580 | 581 | def prepare_examples_dict( 582 | self, 583 | graphs: list[nx.Graph], 584 | generator_algorithms: list[str], 585 | encoding_method: str, 586 | ) -> dict[int, dict[str, str | list[int]]]: 587 | examples_dict = {} 588 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method) 589 | 590 | for ind, graph in enumerate(graphs): 591 | source, target = random.sample(list(graph.nodes()), k=2) 592 | question = graph_text_encoders.encode_graph(graph, encoding_method) 593 | task_description = 'Q: Is there a path from node %s to node %s?\nA: ' % ( 594 | name_dict[source], 595 | name_dict[target], 596 | ) 597 | question += task_description 598 | if nx.has_path(graph, source, target): 599 | answer = 'Yes.' 600 | else: 601 | answer = 'No.' 
602 | examples_dict[ind] = { 603 | 'question': question, 604 | 'answer': answer, 605 | 'nnodes': str(len(graph.nodes())), 606 | 'nedges': str(len(graph.edges())), 607 | 'task_description': task_description, 608 | 'graph': graph, 609 | 'algorithm': generator_algorithms[ind], 610 | 'node_ids': [source, target], 611 | } 612 | return examples_dict 613 | 614 | def create_few_shot_example( 615 | self, graph: nx.Graph, encoding_method: str, cot: bool 616 | ) -> str: 617 | name_dict = graph_text_encoders.get_tlag_node_encoder( 618 | graph, encoding_method 619 | ) 620 | source, target = random.sample(list(graph.nodes()), k=2) 621 | question = graph_text_encoders.encode_graph(graph, encoding_method) 622 | question += 'Q: Is there a path from node %s to node %s?\nA: ' % ( 623 | name_dict[source], 624 | name_dict[target], 625 | ) 626 | if nx.has_path(graph, source, target): 627 | answer = 'Yes.' 628 | if cot: 629 | path = nx.shortest_path(graph, source, target) 630 | explanation = ' Because' 631 | for i in range(len(path) - 1): 632 | # The only edge or the non-last edges in the path. 633 | if len(path) == 2 or i < len(path) - 2: 634 | sep = ',' 635 | # The last edge in a path with more than one edge. 636 | else: 637 | sep = ', and' 638 | explanation += '%s there is an edge from node %s to node %s' % ( 639 | sep, 640 | name_dict[path[i]], 641 | name_dict[path[i + 1]], 642 | ) 643 | explanation += '.' 644 | answer += explanation 645 | else: 646 | answer = 'No.' 647 | if cot: 648 | answer += ( 649 | ' Because, there is no path connecting node %s to node %s based on' 650 | ' the graph description.' 
% (name_dict[source], name_dict[target]) 651 | ) 652 | return question + answer 653 | 654 | 655 | class ShortestPath(GraphTask): 656 | """The graph task to find the length of the shortest path from a source to a target.""" 657 | 658 | def __init__(self): 659 | super().__init__() 660 | self.name = 'shortest_path' 661 | 662 | def prepare_examples_dict( 663 | self, 664 | graphs: list[nx.Graph], 665 | generator_algorithms: list[str], 666 | encoding_method: str, 667 | ) -> dict[int, dict[str, str | list[int]]]: 668 | examples_dict = {} 669 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method) 670 | 671 | for ind, graph in enumerate(graphs): 672 | source, target = random.sample(list(graph.nodes()), k=2) 673 | question = graph_text_encoders.encode_graph(graph, encoding_method) 674 | task_description = ( 675 | 'Q: What is the length of the shortest path from node %s to node' 676 | ' %s?\nA: ' 677 | % ( 678 | name_dict[source], 679 | name_dict[target], 680 | ) 681 | ) 682 | question += task_description 683 | try: 684 | path = nx.shortest_path(graph, source, target) 685 | answer = str(len(path) - 1) + '.' 686 | except nx.NetworkXNoPath: 687 | answer = 'There is no path from node %s to node %s.' 
% ( 688 | name_dict[source], 689 | name_dict[target], 690 | ) 691 | examples_dict[ind] = { 692 | 'question': question, 693 | 'answer': answer, 694 | 'nnodes': str(len(graph.nodes())), 695 | 'nedges': str(len(graph.edges())), 696 | 'task_description': task_description, 697 | 'graph': graph, 698 | 'algorithm': generator_algorithms[ind], 699 | 'node_ids': [source, target], 700 | } 701 | return examples_dict 702 | 703 | def create_few_shot_example( 704 | self, graph: nx.Graph, encoding_method: str, cot: bool 705 | ) -> str: 706 | name_dict = graph_text_encoders.get_tlag_node_encoder( 707 | graph, encoding_method 708 | ) 709 | source, target = random.sample(list(graph.nodes()), k=2) 710 | question = graph_text_encoders.encode_graph(graph, encoding_method) 711 | question += ( 712 | 'Q: What is the length of the shortest path from node %s to node' 713 | ' %s?\nA: ' 714 | % ( 715 | name_dict[source], 716 | name_dict[target], 717 | ) 718 | ) 719 | if nx.has_path(graph, source, target): 720 | path = nx.shortest_path(graph, source, target) 721 | answer = str(len(path) - 1) + '.' 722 | if cot: 723 | explanation = ' Because' 724 | for i in range(len(path) - 1): 725 | # The only edge or the non-last edges in the path. 726 | if len(path) == 2 or i < len(path) - 2: 727 | sep = ',' 728 | # The last edge in a path with more than one edge. 729 | else: 730 | sep = ', and' 731 | explanation += '%s there is an edge from node %s to node %s' % ( 732 | sep, 733 | name_dict[path[i]], 734 | name_dict[path[i + 1]], 735 | ) 736 | explanation += '.' 737 | answer += explanation 738 | else: 739 | answer = 'There is no path from node %s to node %s.' % ( 740 | name_dict[source], 741 | name_dict[target], 742 | ) 743 | if cot: 744 | answer += ( 745 | ' Because, there is no path connecting node %s to node %s based on' 746 | ' the graph description.' 
% (name_dict[source], name_dict[target]) 747 | ) 748 | return question + answer 749 | 750 | 751 | class TriangleCounting(GraphTask): 752 | """The graph task to count the number of triangles in a graph.""" 753 | 754 | def __init__(self): 755 | super().__init__() 756 | self.name = 'triangle_counting' 757 | self._task_description = 'Q: How many triangles are in this graph?\nA: ' 758 | 759 | def prepare_examples_dict( 760 | self, 761 | graphs: list[nx.Graph], 762 | generator_algorithms: list[str], 763 | encoding_method: str, 764 | ) -> dict[int, dict[str, str | list[int]]]: 765 | examples_dict = {} 766 | for ind, graph in enumerate(graphs): 767 | question = ( 768 | graph_text_encoders.encode_graph(graph, encoding_method) 769 | + self._task_description 770 | ) 771 | ntriangles = int(np.sum(list(nx.triangles(graph).values())) / 3) 772 | 773 | answer = '%i.' % ntriangles 774 | examples_dict[ind] = { 775 | 'question': question, 776 | 'answer': answer, 777 | 'nnodes': str(len(graph.nodes())), 778 | 'nedges': str(len(graph.edges())), 779 | 'task_description': self._task_description, 780 | 'graph': graph, 781 | 'algorithm': generator_algorithms[ind], 782 | 'node_ids': [], 783 | } 784 | return examples_dict 785 | 786 | def create_few_shot_example( 787 | self, graph: nx.Graph, encoding_method: str, cot: bool 788 | ) -> str: 789 | """Create a few shot example w or w/o cot for the given graph.""" 790 | name_dict = graph_text_encoders.get_tlag_node_encoder( 791 | graph, encoding_method 792 | ) 793 | question = ( 794 | graph_text_encoders.encode_graph(graph, encoding_method) 795 | + self._task_description 796 | ) 797 | triangles_dict = nx.triangles(graph) 798 | ntriangles = int(np.sum(list(triangles_dict.values())) / 3) 799 | 800 | if ntriangles > 0: 801 | answer = '%i.' 
% ntriangles 802 | if cot: 803 | ntriangles_cot = '' 804 | for key, value in triangles_dict.items(): 805 | if value > 0: 806 | if value == 1: 807 | ntriangles_cot += ( 808 | 'There is %i triangle including node %s as a vertex.\n' 809 | % (value, name_dict[key]) 810 | ) 811 | else: 812 | ntriangles_cot += ( 813 | 'There are %i triangles including node %s as a vertex.\n' 814 | % (value, name_dict[key]) 815 | ) 816 | ntriangles_cot += ( 817 | 'Summing the number of triangles for all nodes and dividing them by' 818 | ' three gives us %i triangles in total.' % ntriangles 819 | ) 820 | answer += ntriangles_cot 821 | else: 822 | answer = '0.' 823 | if cot: 824 | ntriangles_cot = 'No three nodes form a triangle of edges.' 825 | answer += ntriangles_cot 826 | return question + answer 827 | 828 | 829 | class MaximumFlow(GraphTask): 830 | """The graph task to compute the maximum flow from a source to a target.""" 831 | 832 | def __init__(self): 833 | super().__init__() 834 | self.name = 'maximum_flow' 835 | 836 | def prepare_examples_dict( 837 | self, 838 | graphs: list[nx.Graph], 839 | generator_algorithms: list[str], 840 | encoding_method: str, 841 | ) -> dict[int, dict[str, str | list[int]]]: 842 | examples_dict = {} 843 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method) 844 | 845 | for ind, graph in enumerate(graphs): 846 | graph = add_edge_weight(graph) 847 | source, target = random.sample(list(graph.nodes()), k=2) 848 | question = graph_text_encoders.encode_graph(graph, encoding_method) 849 | task_description = ( 850 | 'Q: What is the maximum capacity of the flow from node %s to node' 851 | ' %s?\nA: ' % (name_dict[source], name_dict[target]) 852 | ) 853 | question += task_description 854 | maximum_flow_value = nx.maximum_flow( 855 | graph, source, target, capacity='weight' 856 | )[0] 857 | answer = str(maximum_flow_value) + '.' 
858 | examples_dict[ind] = { 859 | 'question': question, 860 | 'answer': answer, 861 | 'nnodes': str(len(graph.nodes())), 862 | 'nedges': str(len(graph.edges())), 863 | 'task_description': task_description, 864 | 'graph': graph, 865 | 'algorithm': generator_algorithms[ind], 866 | 'node_ids': [source, target], 867 | } 868 | return examples_dict 869 | 870 | def create_few_shot_example( 871 | self, graph: nx.Graph, encoding_method: str, cot: bool 872 | ) -> str: 873 | graph = add_edge_weight(graph) 874 | name_dict = graph_text_encoders.get_tlag_node_encoder( 875 | graph, encoding_method 876 | ) 877 | source, target = random.sample(list(graph.nodes()), k=2) 878 | question = graph_text_encoders.encode_graph(graph, encoding_method) 879 | question += ( 880 | 'Q: What is the maximum capacity of the flow from node %s to' 881 | ' node %s?\nA: ' % (name_dict[source], name_dict[target]) 882 | ) 883 | flow_value, flow_dict = nx.maximum_flow( 884 | graph, source, target, capacity='weight' 885 | ) 886 | answer = str(flow_value) + '.' 887 | if flow_value > 0: 888 | if cot: 889 | explanation = ' This is because of the following edges: ' 890 | for node, outflow in flow_dict.items(): 891 | for neighbor, value in outflow.items(): 892 | if value > 0: 893 | explanation += ( 894 | 'the edge from node %s to node %s with capacity %i, ' 895 | % ( 896 | name_dict[node], 897 | name_dict[neighbor], 898 | value, 899 | ) 900 | ) 901 | explanation = explanation.strip()[:-1] + '.' 902 | answer += explanation 903 | else: 904 | if cot: 905 | answer += ( 906 | ' Because, there is no path connecting node %s to node %s based on' 907 | ' the graph description.' 
% (name_dict[source], name_dict[target]) 908 | ) 909 | return question + answer 910 | 911 | 912 | def has_edge_weights(graph): 913 | for _, _, data in graph.edges(data=True): 914 | if 'weight' not in data: 915 | return False 916 | return True 917 | 918 | 919 | def add_edge_weight(graph): 920 | if has_edge_weights(graph): 921 | return graph 922 | else: 923 | for edge in graph.edges(): 924 | graph[edge[0]][edge[1]]['weight'] = random.randint(1, 10) 925 | return graph 926 | 927 | 928 | class NodeClassification(GraphTask): 929 | """The graph task to classify a given node in the graph.""" 930 | 931 | def __init__(self): 932 | super().__init__() 933 | self.name = 'node_classification' 934 | self.classes = [ 935 | 'soccer', 936 | 'baseball', 937 | 'tennis', 938 | 'golf', 939 | 'football', 940 | 'surfing', 941 | ] 942 | 943 | def prepare_examples_dict( 944 | self, 945 | graphs: list[nx.Graph], 946 | generator_algorithms: list[str], 947 | encoding_method: str, 948 | ) -> dict[int, dict[str, str | list[int]]]: 949 | classes = random.sample(list(self.classes), k=2) 950 | examples_dict = {} 951 | name_dict = graph_text_encoders.get_tlag_node_encoder(None, encoding_method) 952 | for ind, graph in enumerate(graphs): 953 | question = graph_text_encoders.encode_graph(graph, encoding_method) 954 | nnodes = len(graph.nodes()) 955 | # Sampling nnodes // 2 + 1 nodes. 956 | sampled_nodes = random.sample( 957 | list(graph.nodes(data=True)), k=nnodes // 2 + 1 958 | ) 959 | # Adding the class of half of the nodes. 960 | for node_data in sampled_nodes[:-1]: 961 | node_class = classes[node_data[1]['block']] 962 | question += ( 963 | 'Node ' + name_dict[node_data[0]] + ' likes ' + node_class + '.\n' 964 | ) 965 | # Reserving the last sampled node for the question. 
966 | task_description = 'Q: Does node %s like %s or %s?\nA: ' % ( 967 | name_dict[sampled_nodes[-1][0]], 968 | classes[0], 969 | classes[1], 970 | ) 971 | question += task_description 972 | answer = classes[sampled_nodes[-1][1]['block']] 973 | 974 | examples_dict[ind] = { 975 | 'question': question, 976 | 'answer': answer, 977 | 'nnodes': str(nnodes), 978 | 'nedges': str(len(graph.edges())), 979 | 'task_description': task_description, 980 | 'graph': graph, 981 | 'algorithm': generator_algorithms[ind], 982 | # id of the last sampled node 983 | 'node_ids': [sampled_nodes[-1][0]], 984 | } 985 | 986 | return examples_dict 987 | 988 | def create_few_shot_example( 989 | self, graph: nx.Graph, encoding_method: str, cot: bool 990 | ) -> str: 991 | classes = random.sample(list(self.classes), k=2) 992 | name_dict = graph_text_encoders.get_tlag_node_encoder( 993 | graph, encoding_method 994 | ) 995 | question = graph_text_encoders.encode_graph(graph, encoding_method) 996 | nnodes = len(graph.nodes()) 997 | sampled_nodes = random.sample( 998 | list(graph.nodes(data=True)), k=nnodes // 2 + 1 999 | ) 1000 | for node_data in sampled_nodes[:-1]: 1001 | node_class = classes[node_data[1]['block']] 1002 | question += ( 1003 | 'Node ' + name_dict[node_data[0]] + ' likes ' + node_class + '.\n' 1004 | ) 1005 | task_description = 'Q: Does node %s like %s or %s?\nA: ' % ( 1006 | name_dict[sampled_nodes[-1][0]], 1007 | classes[0], 1008 | classes[1], 1009 | ) 1010 | question += task_description 1011 | answer = classes[sampled_nodes[-1][1]['block']] 1012 | 1013 | if cot: 1014 | explanation = ( 1015 | ' This is because most of the nodes that are connected to node %s' 1016 | ' like %s.' 
1017 | % (sampled_nodes[-1][0], classes[sampled_nodes[-1][1]['block']]) 1018 | ) 1019 | answer += explanation 1020 | return question + answer 1021 | -------------------------------------------------------------------------------- /talk_like_a_graph/graph_tasks_generator.py: -------------------------------------------------------------------------------- 1 | r"""The graph tasks to be tried with LLMs.. 2 | 3 | This code loads graphs and creates graph tasks and output them as tf examples in 4 | a recordio file in the task directory provided. 5 | 6 | # Placeholder for Google-internal comments. 7 | """ 8 | 9 | from collections.abc import Sequence 10 | import os 11 | import random 12 | 13 | from absl import app 14 | from absl import flags 15 | import networkx as nx 16 | import numpy as np 17 | 18 | from . import graph_tasks 19 | from . import graph_tasks_utils as utils 20 | 21 | _TASK_DIR = flags.DEFINE_string( 22 | 'task_dir', None, 'The directory to write tasks.', required=True 23 | ) 24 | _GRAPHS_DIR = flags.DEFINE_string( 25 | 'graphs_dir', None, 'The directory containing the graphs.', required=True 26 | ) 27 | _RANDOM_SEED = flags.DEFINE_integer( 28 | 'random_seed', 29 | None, 30 | 'The random seed to use for task generation.', 31 | required=True, 32 | ) 33 | 34 | 35 | def zero_shot( 36 | task: graph_tasks.GraphTask, 37 | graphs: list[nx.Graph], 38 | algorithms: list[str], 39 | text_encoders: list[str], 40 | cot: bool, 41 | random_seed: int, 42 | split: str, 43 | ) -> None: 44 | """Creating zero-shot or zero-cot examples for the given task. 45 | 46 | Args: 47 | task: the corresponding graph task. 48 | graphs: the list of graphs to use for the task. 49 | algorithms: the algorithm used to generate the graphs. 50 | text_encoders: the encoders to use in the tasks. 51 | cot: whether to apply cot or not. 52 | random_seed: the random seed to use in the process. 53 | split: whether we are creating a train or test split. 
54 | """ 55 | random.seed(random_seed) 56 | zero_shot_examples = utils.create_zero_shot_task( 57 | task, graphs, algorithms, text_encoders, cot=cot 58 | ) 59 | 60 | file_name = task.name + ('_zero_cot_' if cot else '_zero_shot_') 61 | 62 | file_name += split + '.recordio' 63 | utils.write_examples( 64 | zero_shot_examples, 65 | os.path.join(_TASK_DIR.value, file_name), 66 | ) 67 | 68 | 69 | def few_shot( 70 | task: graph_tasks.GraphTask, 71 | graphs: list[nx.Graph], 72 | few_shot_graphs: list[nx.Graph], 73 | algorithms: list[str], 74 | text_encoders: list[str], 75 | cot: bool, 76 | bag: bool, 77 | random_seed: int, 78 | ) -> None: 79 | """Creating few-shot, cot, or cot-bag examples for the given task. 80 | 81 | Args: 82 | task: the corresponding graph task. 83 | graphs: the list of graphs to use for the task. 84 | few_shot_graphs: the list of graphs to generate few shot examples for. 85 | algorithms: the algorithm used to generate the graphs. 86 | text_encoders: the encoders to use in the tasks. 87 | cot: whether to apply cot or not. 88 | bag: whether to apply build-a-graph method or not. 89 | random_seed: the random seed to use in the process. 90 | """ 91 | random.seed(random_seed) 92 | few_shot_examples = utils.create_few_shot_task( 93 | task, 94 | graphs, 95 | algorithms, 96 | few_shot_graphs, 97 | text_encoders, 98 | cot=cot, 99 | bag=bag, 100 | random_seed=random_seed, 101 | ) 102 | file_name = task.name 103 | if cot and bag: 104 | file_name += '_few_shot_cot_bag_test.recordio' 105 | elif cot: 106 | file_name += '_few_shot_cot_test.recordio' 107 | else: 108 | file_name += '_few_shot_test.recordio' 109 | 110 | utils.write_examples( 111 | few_shot_examples, 112 | os.path.join(_TASK_DIR.value, file_name), 113 | ) 114 | 115 | 116 | def generate_random_sbm_graph(random_state: np.random.RandomState): 117 | # Sampling a small number as the probability of the two nodes in different 118 | # communities being connected. 
119 | small_number = random.uniform(0, 0.05) 120 | # Sampling a large number as probability of the nodes in one community 121 | # being connected. 122 | large_number = random.uniform(0.6, 0.8) 123 | number_of_nodes = random.choice(np.arange(5, 20)) 124 | sizes = [number_of_nodes // 2, number_of_nodes // 2] 125 | probs = [[large_number, small_number], [small_number, large_number]] 126 | return nx.stochastic_block_model(sizes, probs, seed=random_state) 127 | 128 | 129 | def main(argv: Sequence[str]) -> None: 130 | if len(argv) > 1: 131 | raise app.UsageError('Too many command-line arguments.') 132 | 133 | algorithms = ['er'] 134 | directions = ['undirected'] 135 | text_encoders = ['adjacency'] 136 | 137 | # Loading the graphs. 138 | graphs = [] 139 | generator_algorithms = [] 140 | for algorithm in algorithms: 141 | for direction in directions: 142 | loaded_graphs = utils.load_graphs( 143 | _GRAPHS_DIR.value, 144 | algorithm, 145 | 'train', 146 | direction, 147 | ) 148 | graphs += loaded_graphs 149 | generator_algorithms += [algorithm] * len(loaded_graphs) 150 | 151 | # Defining a task on the graphs 152 | task = graph_tasks.ShortestPath() 153 | 154 | if isinstance(task, graph_tasks.NodeClassification): 155 | # The node classification task requires SBM graphs. As it's not possible to 156 | # write graphs with data (e.g., blocks data as in SBM graphs), we regenerate 157 | # graphs. 158 | 159 | random_state = np.random.RandomState(_RANDOM_SEED.value) 160 | print('Generating sbm graphs') 161 | graphs = [ 162 | generate_random_sbm_graph(random_state) for _ in range(len(graphs)) 163 | ] 164 | 165 | zero_shot( 166 | task, 167 | graphs, 168 | generator_algorithms, 169 | text_encoders, 170 | cot=False, 171 | random_seed=_RANDOM_SEED.value, 172 | split='test', 173 | ) 174 | zero_shot( 175 | task, 176 | graphs, 177 | generator_algorithms, 178 | text_encoders, 179 | cot=True, 180 | random_seed=_RANDOM_SEED.value, 181 | split='test', 182 | ) 183 | 184 | # Loading few-shot graphs. 
185 | few_shot_graphs = [] 186 | for algorithm in algorithms: 187 | for direction in directions: 188 | few_shot_graphs += utils.load_graphs( 189 | _GRAPHS_DIR.value, 190 | algorithm, 191 | 'train', 192 | direction, 193 | ) 194 | 195 | if isinstance(task, graph_tasks.NodeClassification): 196 | # The node classification task requires SBM graphs. As it's not possible to 197 | # write graphs with data (e.g., blocks data as in SBM graphs), we regenerate 198 | # graphs. 199 | random_state = np.random.RandomState(_RANDOM_SEED.value + 1) 200 | print('Generating few shot sbm graphs') 201 | few_shot_graphs = [ 202 | generate_random_sbm_graph(random_state) 203 | for _ in range(len(few_shot_graphs)) 204 | ] 205 | 206 | few_shot( 207 | task, 208 | graphs, 209 | few_shot_graphs, 210 | generator_algorithms, 211 | text_encoders, 212 | cot=False, 213 | bag=False, 214 | random_seed=_RANDOM_SEED.value, 215 | ) 216 | 217 | few_shot( 218 | task, 219 | graphs, 220 | few_shot_graphs, 221 | generator_algorithms, 222 | text_encoders, 223 | cot=True, 224 | bag=False, 225 | random_seed=_RANDOM_SEED.value, 226 | ) 227 | 228 | few_shot( 229 | task, 230 | graphs, 231 | few_shot_graphs, 232 | generator_algorithms, 233 | text_encoders, 234 | cot=True, 235 | bag=True, 236 | random_seed=_RANDOM_SEED.value, 237 | ) 238 | 239 | 240 | if __name__ == '__main__': 241 | app.run(main) 242 | -------------------------------------------------------------------------------- /talk_like_a_graph/graph_tasks_utils.py: -------------------------------------------------------------------------------- 1 | """The graph tasks to be tried with LLMs.""" 2 | 3 | import os 4 | import random 5 | 6 | import networkx as nx 7 | import numpy as np 8 | import seqio 9 | import tensorflow as tf 10 | import tensorflow_gnn as tfgnn 11 | 12 | # Google-internal import(s). 13 | # Internal import. 14 | from . 
import graph_tasks 15 | from tensorflow.core.example import example_pb2 16 | from tensorflow.core.example import feature_pb2 17 | 18 | 19 | def laplacian_pos_embedding(graph: nx.Graph, units: int = 4) -> nx.Graph: 20 | """Adds the laplacian positional encoding.""" 21 | m = nx.normalized_laplacian_matrix( 22 | graph, nodelist=sorted(graph.nodes), weight=None 23 | ).astype(np.float32) 24 | u, _, _ = np.linalg.svd(m.todense(), compute_uv=True) 25 | if units > u.shape[1]: 26 | u = np.pad(u, ((0, 0), (0, units - u.shape[1]))) 27 | nx.set_node_attributes( 28 | graph, dict(zip(sorted(graph.nodes), u[:, :units])), name='lpe' 29 | ) 30 | return graph 31 | 32 | 33 | def to_tfgnn(graph: nx.Graph, node_ids: list[int]) -> tfgnn.GraphTensor: 34 | """Convert a given nx graph to a tfgnn graph.""" 35 | if graph.edges(data=True): 36 | s, t, w = zip(*[ 37 | (s, t, (d['weight'] if d and 'weight' in d else None)) 38 | for s, t, d in graph.edges(data=True) 39 | ]) 40 | else: 41 | s, t, w = (), (), () 42 | # tfgnn assumes graphs are directed. Adding the rev edges for an undirected 43 | # graph. 44 | if not graph.is_directed(): 45 | s, t, w = s + t, t + s, w + w 46 | 47 | graph = laplacian_pos_embedding(graph, units=4) 48 | features = set(k for n in graph.nodes for k in graph.nodes[n].keys()) # pylint: disable=g-complex-comprehension 49 | node_features = { 50 | f: tf.convert_to_tensor([graph.nodes[n][f] for n in graph.nodes]) 51 | for f in features 52 | } 53 | # If all edges have a non-trivial weight, then we record the weights. 
54 | if all(w): 55 | edge_features = {'weights': tf.convert_to_tensor(w, dtype=tf.int32)} 56 | gt = tfgnn.homogeneous( 57 | tf.convert_to_tensor(s, dtype=tf.int32), 58 | tf.convert_to_tensor(t, dtype=tf.int32), 59 | node_features=node_features, 60 | edge_features=edge_features, 61 | ) 62 | else: 63 | gt = tfgnn.homogeneous( 64 | tf.convert_to_tensor(s, dtype=tf.int32), 65 | tf.convert_to_tensor(t, dtype=tf.int32), 66 | node_features=node_features, 67 | ) 68 | 69 | if not node_ids: 70 | # No node is mentioned in the task description. 71 | return gt 72 | node_sets = { 73 | **gt.node_sets, 74 | '_readout': tfgnn.NodeSet.from_fields(sizes=[1]), 75 | } 76 | if len(node_ids) == 1: 77 | # Tasks requiring only one node id e.g., computing node degree. 78 | edge_sets = { 79 | **gt.edge_sets, 80 | '_readout/node': tfgnn.EdgeSet.from_fields( 81 | sizes=[1], 82 | adjacency=tfgnn.Adjacency.from_indices( 83 | source=('nodes', node_ids), 84 | target=('_readout', [0]), 85 | ), 86 | ), 87 | } 88 | elif len(node_ids) == 2: 89 | # Tasks requiring two nodes e.g., shortest path from one node to the other. 90 | edge_sets = { 91 | **gt.edge_sets, 92 | '_readout/source': tfgnn.EdgeSet.from_fields( 93 | sizes=[1], 94 | adjacency=tfgnn.Adjacency.from_indices( 95 | source=('nodes', node_ids[:1]), 96 | target=('_readout', [0]), 97 | ), 98 | ), 99 | '_readout/target': tfgnn.EdgeSet.from_fields( 100 | sizes=[1], 101 | adjacency=tfgnn.Adjacency.from_indices( 102 | source=('nodes', node_ids[1:]), 103 | target=('_readout', [0]), 104 | ), 105 | ), 106 | } 107 | else: 108 | # Raising an error if more than two nodes are mentiones. 
109 | raise ValueError(f'Invalid number of integers: {len(node_ids)}') 110 | 111 | return tfgnn.GraphTensor.from_pieces( 112 | context=gt.context, node_sets=node_sets, edge_sets=edge_sets 113 | ) 114 | 115 | 116 | def create_example_feature( 117 | key: int, 118 | question: str, 119 | answer: str, 120 | algorithm: str, 121 | encoding_method: str, 122 | nnodes: str, 123 | nedges: str, 124 | task_description: str, 125 | graph: nx.Graph, 126 | node_ids: list[int], 127 | ) -> example_pb2.Example: 128 | """Create a tensorflow example from a datapoint.""" 129 | key_feature = feature_pb2.Feature( 130 | bytes_list=tf.train.BytesList(value=[str(key).encode()]) 131 | ) 132 | question_feature = feature_pb2.Feature( 133 | bytes_list=tf.train.BytesList(value=[question.encode()]) 134 | ) 135 | answer_feature = feature_pb2.Feature( 136 | bytes_list=tf.train.BytesList(value=[answer.encode()]) 137 | ) 138 | algorithm_feature = feature_pb2.Feature( 139 | bytes_list=tf.train.BytesList(value=[algorithm.encode()]) 140 | ) 141 | encoding_method_feature = feature_pb2.Feature( 142 | bytes_list=tf.train.BytesList(value=[encoding_method.encode()]) 143 | ) 144 | nnodes_feature = feature_pb2.Feature( 145 | bytes_list=tf.train.BytesList(value=[nnodes.encode()]) 146 | ) 147 | nedges_feature = feature_pb2.Feature( 148 | bytes_list=tf.train.BytesList(value=[nedges.encode()]) 149 | ) 150 | task_description_feature = feature_pb2.Feature( 151 | bytes_list=tf.train.BytesList(value=[task_description.encode()]) 152 | ) 153 | gt = to_tfgnn(graph, node_ids) 154 | graph_feature = feature_pb2.Feature( 155 | bytes_list=tf.train.BytesList( 156 | value=[tfgnn.write_example(gt).SerializeToString()] 157 | ) 158 | ) 159 | directed_feature = feature_pb2.Feature( 160 | bytes_list=tf.train.BytesList(value=[str(graph.is_directed()).encode()]) 161 | ) 162 | example_feats = tf.train.Features( 163 | feature={ 164 | 'id': key_feature, 165 | 'question': question_feature, 166 | 'answer': answer_feature, 167 | 'algorithm': 
algorithm_feature, 168 | 'text_encoding': encoding_method_feature, 169 | 'nnodes': nnodes_feature, 170 | 'nedges': nedges_feature, 171 | 'task_description': task_description_feature, 172 | 'graph': graph_feature, 173 | 'directed': directed_feature, 174 | } 175 | ) 176 | return example_pb2.Example(features=example_feats) 177 | 178 | 179 | def load_graphs( 180 | base_path: str, 181 | algorithm: str, 182 | split: str, 183 | direction: str, 184 | max_nnodes: int = 20, 185 | ) -> list[nx.Graph]: 186 | """Load a list of graphs from a given algorithm and split.""" 187 | graphs_path = os.path.join( 188 | base_path, 189 | direction, 190 | algorithm, 191 | split, 192 | ) 193 | loaded_graphs = [] 194 | all_files = gfile.ListDir(graphs_path) 195 | for file in all_files: 196 | if file.endswith('.graphml'): 197 | path = os.path.join(graphs_path, file) 198 | graph = nx.read_graphml(os.Open(path, 'rb'), node_type=int) 199 | if graph.number_of_nodes() <= max_nnodes: 200 | loaded_graphs.append(graph) 201 | return loaded_graphs 202 | 203 | 204 | def prepare_examples( 205 | examples_dict: dict[int, dict[str, str | list[int]]], 206 | encoding_method: str, 207 | ) -> list[example_pb2.Example]: 208 | """Create a list of tf.train.Example from a dict of examples.""" 209 | examples = [] 210 | for key, value in examples_dict.items(): 211 | ( 212 | question, 213 | answer, 214 | nnodes, 215 | nedges, 216 | task_description, 217 | graph, 218 | algorithm, 219 | node_ids, 220 | ) = ( 221 | value['question'], 222 | value['answer'], 223 | value['nnodes'], 224 | value['nedges'], 225 | value['task_description'], 226 | value['graph'], 227 | value['algorithm'], 228 | value['node_ids'], 229 | ) 230 | examples.append( 231 | create_example_feature( 232 | key, 233 | question, 234 | answer, 235 | algorithm, 236 | encoding_method, 237 | nnodes, 238 | nedges, 239 | task_description, 240 | graph, 241 | node_ids, 242 | ) 243 | ) 244 | return examples 245 | 246 | 247 | def create_zero_shot_task( 248 | task: 
graph_tasks.GraphTask, 249 | graphs: list[nx.Graph], 250 | generator_algorithms: list[str], 251 | text_encoders: list[str], 252 | cot: bool = False, 253 | ) -> list[example_pb2.Example]: 254 | """Create a recordio file with zero-shot examples for the task.""" 255 | examples = [] 256 | for encoding_method in text_encoders: 257 | examples_dict = task.prepare_examples_dict( 258 | graphs, generator_algorithms, encoding_method 259 | ) 260 | if cot: 261 | for key in examples_dict.keys(): 262 | examples_dict[key]['question'] += "Let's think step by step. " 263 | examples += prepare_examples(examples_dict, encoding_method) 264 | return examples 265 | 266 | 267 | def write_examples(examples: list[example_pb2.Example], output_path: str): 268 | with recordio.RecordWriter(output_path) as output_file: 269 | for example in examples: 270 | output_file.WriteRecord(example.SerializeToString()) 271 | 272 | 273 | def prepare_few_shots( 274 | task: graph_tasks.GraphTask, 275 | graphs: list[nx.Graph], 276 | text_encoders: list[str], 277 | cot: bool, 278 | ) -> dict[str, list[str]]: 279 | """Create a dict of few-shot examples with their cot for the task.""" 280 | few_shots_examples_dict = {} 281 | for encoding_method in text_encoders: 282 | if encoding_method not in few_shots_examples_dict: 283 | few_shots_examples_dict[(encoding_method)] = [] 284 | for graph in graphs: 285 | few_shots_examples_dict[(encoding_method)].append( 286 | task.create_few_shot_example(graph, encoding_method, cot) 287 | ) 288 | return few_shots_examples_dict 289 | 290 | 291 | def choose_few_shot_examples( 292 | few_shots_dict: dict[str, list[str]], 293 | encoding_method: str, 294 | k: int = 2, 295 | ) -> str: 296 | """Choose few shot examples for each algorithm.""" 297 | few_shots_str = '' 298 | for _ in range(k): 299 | example_list = few_shots_dict[encoding_method] 300 | few_shots_str += 'Example: ' + random.choice(example_list) + '\n' 301 | return few_shots_str 302 | 303 | 304 | def create_few_shot_task( 305 | 
task: graph_tasks.GraphTask, 306 | graphs: list[nx.Graph], 307 | generator_algorithms: list[str], 308 | few_shots_graphs: list[nx.Graph], 309 | text_encoders: list[str], 310 | cot: bool, 311 | bag: bool, 312 | random_seed: int, 313 | ) -> list[example_pb2.Example]: 314 | """Create a recordio file with few-shot examples for the task.""" 315 | # LINT.IfChange 316 | vocab_path = None 317 | # LINT.ThenChange(//research/graph/llm/graphqa/copy.bara.sky) 318 | # Loading the palm tokenizer to calculate number of tokens in the sequence. 319 | sp_vocab = seqio.SentencePieceVocabulary(vocab_path) 320 | number_of_tokens = {} 321 | examples = [] 322 | print('prepare few shot task', 'cot', cot, 'bag', bag) 323 | few_shots_examples_dict = prepare_few_shots( 324 | task, 325 | few_shots_graphs, 326 | text_encoders, 327 | cot, 328 | ) 329 | for encoding_method in text_encoders: 330 | random.seed(random_seed) 331 | examples_dict = task.prepare_examples_dict( 332 | graphs, generator_algorithms, encoding_method 333 | ) 334 | for key in examples_dict.keys(): 335 | few_shots_examples = choose_few_shot_examples( 336 | few_shots_examples_dict, 337 | encoding_method, 338 | ) 339 | examples_dict[key]['question'] = ( 340 | few_shots_examples + 'Example: ' + examples_dict[key]['question'] 341 | ) 342 | if bag: 343 | examples_dict[key]['question'] = examples_dict[key]['question'].replace( 344 | '\nQ: ', 345 | "\nLet's construct the graph with the nodes and edges first.\nQ: ", 346 | ) # pytype: disable=attribute-error 347 | if encoding_method not in number_of_tokens: 348 | number_of_tokens[encoding_method] = [] 349 | number_of_tokens[encoding_method].append( 350 | len(sp_vocab.encode(examples_dict[key]['question'])) 351 | ) 352 | examples += prepare_examples(examples_dict, encoding_method) 353 | 354 | # Printing maximum number of tokens in the sequence. 
355 | for key, value in number_of_tokens.items(): 356 | print(key, np.max(value)) 357 | 358 | return examples 359 | -------------------------------------------------------------------------------- /talk_like_a_graph/graph_text_encoders.py: -------------------------------------------------------------------------------- 1 | """Library for encoding graphs in text.""" 2 | 3 | import networkx as nx 4 | 5 | from . import name_dictionaries 6 | 7 | 8 | def create_node_string(name_dict, nnodes: int) -> str: 9 | node_string = "" 10 | sorted_keys = list(sorted(name_dict.keys())) 11 | for i in sorted_keys[: nnodes - 1]: 12 | node_string += name_dict[i] + ", " 13 | node_string += "and " + name_dict[sorted_keys[nnodes - 1]] 14 | return node_string 15 | 16 | 17 | def nx_encoder(graph: nx.Graph, _: dict[int, str], edge_type="id") -> str: 18 | """Encoding a graph as entries of an adjacency matrix.""" 19 | if graph.is_directed(): 20 | output = ( 21 | "In a directed graph, (s,p,o) means that there is an edge from node s" 22 | " to node o of type p. " 23 | ) 24 | else: 25 | output = ( 26 | "In an undirected graph, (s,p,o) means that node s and node o are" 27 | " connected with an undirected edge of type p. 
" 28 | ) 29 | 30 | name_dict = {x: str(x) for x in graph.nodes()} 31 | 32 | nodes_string = create_node_string(name_dict, nnodes=len(graph.nodes())) 33 | output += "G describes a graph among nodes %s.\n" % nodes_string 34 | if graph.edges(): 35 | output += "The edges in G are: " 36 | for i, j in graph.edges(): 37 | edge_type = graph.get_edge_data(i, j)[edge_type] 38 | if edge_type is None: 39 | edge_type = "linked" 40 | output += "(%s, %s, %s) " % (name_dict[i], edge_type, name_dict[j]) 41 | return output.strip() + ".\n" 42 | 43 | 44 | def adjacency_encoder(graph: nx.Graph, name_dict: dict[int, str]) -> str: 45 | """Encoding a graph as entries of an adjacency matrix.""" 46 | if graph.is_directed(): 47 | output = ( 48 | "In a directed graph, (i,j) means that there is an edge from node i to" 49 | " node j. " 50 | ) 51 | else: 52 | output = ( 53 | "In an undirected graph, (i,j) means that node i and node j are" 54 | " connected with an undirected edge. " 55 | ) 56 | nodes_string = create_node_string(name_dict, len(graph.nodes())) 57 | output += "G describes a graph among nodes %s.\n" % nodes_string 58 | if graph.edges(): 59 | output += "The edges in G are: " 60 | for i, j in graph.edges(): 61 | output += "(%s, %s) " % (name_dict[i], name_dict[j]) 62 | return output.strip() + ".\n" 63 | 64 | 65 | def friendship_encoder(graph: nx.Graph, name_dict: dict[int, str]) -> str: 66 | """Encoding a graph as a friendship graph.""" 67 | if graph.is_directed(): 68 | raise ValueError("Friendship encoder is not defined for directed graphs.") 69 | nodes_string = create_node_string(name_dict, len(graph.nodes())) 70 | output = ( 71 | "G describes a friendship graph among nodes %s.\n" % nodes_string.strip() 72 | ) 73 | if graph.edges(): 74 | output += "We have the following edges in G:\n" 75 | for i, j in graph.edges(): 76 | output += "%s and %s are friends.\n" % (name_dict[i], name_dict[j]) 77 | return output 78 | 79 | 80 | def coauthorship_encoder(graph: nx.Graph, name_dict: dict[int, 
str]) -> str: 81 | """Encoding a graph as a coauthorship graph.""" 82 | if graph.is_directed(): 83 | raise ValueError("Coauthorship encoder is not defined for directed graphs.") 84 | nodes_string = create_node_string(name_dict, len(graph.nodes())) 85 | output = ( 86 | "G describes a coauthorship graph among nodes %s.\n" 87 | % nodes_string.strip() 88 | ) 89 | if graph.edges(): 90 | output += "In this coauthorship graph:\n" 91 | for i, j in graph.edges(): 92 | output += "%s and %s wrote a paper together.\n" % ( 93 | name_dict[i], 94 | name_dict[j], 95 | ) 96 | return output.strip() + ".\n" 97 | 98 | 99 | def incident_encoder(graph: nx.Graph, name_dict: dict[int, str]) -> str: 100 | """Encoding a graph with its incident lists.""" 101 | nodes_string = create_node_string(name_dict, len(graph.nodes())) 102 | output = "G describes a graph among nodes %s.\n" % nodes_string 103 | if graph.edges(): 104 | output += "In this graph:\n" 105 | for source_node in graph.nodes(): 106 | target_nodes = graph.neighbors(source_node) 107 | target_nodes_str = "" 108 | nedges = 0 109 | for target_node in target_nodes: 110 | target_nodes_str += name_dict[target_node] + ", " 111 | nedges += 1 112 | if nedges > 1: 113 | output += "Node %s is connected to nodes %s.\n" % ( 114 | name_dict[source_node], 115 | target_nodes_str[:-2], 116 | ) 117 | elif nedges == 1: 118 | output += "Node %s is connected to node %s.\n" % ( 119 | name_dict[source_node], 120 | target_nodes_str[:-2], 121 | ) 122 | return output 123 | 124 | 125 | def social_network_encoder(graph: nx.Graph, name_dict: dict[int, str]) -> str: 126 | """Encoding a graph as a social network graph.""" 127 | if graph.is_directed(): 128 | raise ValueError( 129 | "Social network encoder is not defined for directed graphs."
130 | ) 131 | nodes_string = create_node_string(name_dict, len(graph.nodes())) 132 | output = ( 133 | "G describes a social network graph among nodes %s.\n" 134 | % nodes_string.strip() 135 | ) 136 | if graph.edges(): 137 | output += "We have the following edges in G:\n" 138 | for i, j in graph.edges(): 139 | output += "%s and %s are connected.\n" % (name_dict[i], name_dict[j]) 140 | return output 141 | 142 | 143 | def expert_encoder(graph: nx.Graph, name_dict: dict[int, str]) -> str: 144 | nodes_string = create_node_string(name_dict, len(graph.nodes())) 145 | output = ( 146 | "You are a graph analyst and you have been given a graph G among nodes" 147 | " %s.\n" 148 | % nodes_string.strip() 149 | ) 150 | output += "G has the following undirected edges:\n" if graph.edges() else "" 151 | for i, j in graph.edges(): 152 | output += "%s -> %s\n" % (name_dict[i], name_dict[j]) 153 | return output 154 | 155 | 156 | def nodes_to_text(graph, encoding_type): 157 | """Get dictionary converting node ids to text.""" 158 | if encoding_type == "integer": 159 | return name_dictionaries.create_name_dict(graph, "integer", nnodes=1000) 160 | elif encoding_type == "popular": 161 | return name_dictionaries.create_name_dict(graph, "popular") 162 | elif encoding_type == "alphabet": 163 | return name_dictionaries.create_name_dict(graph, "alphabet") 164 | elif encoding_type == "got": 165 | return name_dictionaries.create_name_dict(graph, "got") 166 | elif encoding_type == "south_park": 167 | return name_dictionaries.create_name_dict(graph, "south_park") 168 | elif encoding_type == "politician": 169 | return name_dictionaries.create_name_dict(graph, "politician") 170 | elif encoding_type == "random": 171 | return name_dictionaries.create_name_dict( 172 | graph, "random_integer", nnodes=1000 173 | ) 174 | elif encoding_type == "nx_node_name": 175 | return name_dictionaries.create_name_dict(graph, "nx_node_name") 176 | else: 177 | raise ValueError("Unknown encoding type: %s" % encoding_type) 
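Concretely, the simplest naming strategies dispatched by `nodes_to_text` behave like the self-contained sketch below. `sketch_name_dict` is a hypothetical stand-in for `name_dictionaries.create_name_dict` used only for illustration (the real function also supports "popular", "got", "south_park", and other strategies, and "alphabet" continues past Z with "AA", "BB", ...; this sketch covers only the first 26 nodes):

```python
import string


def sketch_name_dict(nodes, strategy):
  """Map integer node ids to display names for a given strategy (sketch)."""
  if strategy == "integer":
    # Each node id is simply rendered as its own decimal string.
    return {n: str(n) for n in nodes}
  if strategy == "alphabet":
    # Node 0 -> "A", node 1 -> "B", etc. (sketch handles only 26 nodes).
    letters = string.ascii_uppercase
    return {n: letters[n] for n in nodes}
  raise ValueError("Unknown encoding type: %s" % strategy)


print(sketch_name_dict(range(3), "integer"))   # {0: '0', 1: '1', 2: '2'}
print(sketch_name_dict(range(3), "alphabet"))  # {0: 'A', 1: 'B', 2: 'C'}
```

The resulting dictionary is exactly the `name_dict` that the edge encoders above consume when rendering edges, e.g. `"(%s, %s)" % (name_dict[i], name_dict[j])`.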
178 | 179 | 180 | def get_tlag_node_encoder(graph, encoder_name): 181 | """Find the node encoder used in the 'Talk Like a Graph' paper.""" 182 | if encoder_name == "adjacency": 183 | return nodes_to_text(graph, "integer") 184 | elif encoder_name == "incident": 185 | return nodes_to_text(graph, "integer") 186 | elif encoder_name == "friendship": 187 | return nodes_to_text(graph, "popular") 188 | elif encoder_name == "south_park": 189 | return nodes_to_text(graph, "south_park") 190 | elif encoder_name == "got": 191 | return nodes_to_text(graph, "got") 192 | elif encoder_name == "politician": 193 | return nodes_to_text(graph, "politician") 194 | elif encoder_name == "social_network": 195 | return nodes_to_text(graph, "popular") 196 | elif encoder_name == "expert": 197 | return nodes_to_text(graph, "alphabet") 198 | elif encoder_name == "coauthorship": 199 | return nodes_to_text(graph, "popular") 200 | elif encoder_name == "random": 201 | return nodes_to_text(graph, "random") 202 | elif encoder_name == "nx_node_name": 203 | return nodes_to_text(graph, "nx_node_name") 204 | else: 205 | raise ValueError("Unknown graph encoder strategy: %s" % encoder_name) 206 | 207 | 208 | # A dictionary from edge encoder name to the corresponding function.
209 | EDGE_ENCODER_FN = { 210 | "adjacency": adjacency_encoder, 211 | "incident": incident_encoder, 212 | "friendship": friendship_encoder, 213 | "south_park": friendship_encoder, 214 | "got": friendship_encoder, 215 | "politician": social_network_encoder, 216 | "social_network": social_network_encoder, 217 | "expert": expert_encoder, 218 | "coauthorship": coauthorship_encoder, 219 | "random": adjacency_encoder, 220 | "nx_edge_encoder": nx_encoder, 221 | } 222 | 223 | 224 | def with_ids(graph: nx.Graph, node_encoder: str) -> nx.Graph: 225 | nx.set_node_attributes(graph, nodes_to_text(graph, node_encoder), name="id") 226 | return graph 227 | 228 | 229 | def encode_graph( 230 | graph: nx.Graph, graph_encoder=None, node_encoder=None, edge_encoder=None 231 | ) -> str: 232 | r"""Encodes a graph as text. 233 | 234 | This relies on choosing: 235 | a node_encoder and an edge_encoder: 236 | or 237 | a graph_encoder (a predefined pair of node and edge encoding strategies). 238 | 239 | Note that graph_encoders may assume that the graph has some properties 240 | (e.g. integer keys). 241 | 242 | Example usage: 243 | .. code-block:: python 244 | ``` 245 | # Use a predefined graph encoder from the paper. 246 | >>> G = nx.karate_club_graph() 247 | >>> encode_graph(G, graph_encoder="adjacency") 248 | 'In an undirected graph, (i,j) means that node i and node j are 249 | connected 250 | with an undirected edge. G describes a graph among nodes 0, 1, 2, 3, 4, 5, 251 | 6, 252 | 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 253 | 27, 28, 29, 30, 31, 32, and 33.\nThe edges in G are: (0, 1) (0, 2) (0, 3) 254 | ...' 255 | 256 | # Use the node's name in the graph as the node identifier. 257 | >>> G = nx.les_miserables_graph() 258 | >>> encode_graph(G, node_encoder="nx_node_name", edge_encoder="friendship") 259 | 'G describes a friendship graph among nodes Anzelma, Babet, Bahorel, 260 | Bamatabois, BaronessT, Blacheville, Bossuet, Boulatruelle, Brevet, ... 
261 | We have the following edges in G: 262 | Napoleon and Myriel are friends. Myriel and MlleBaptistine are friends...' 263 | 264 | # Use the `id` feature from the edges to describe the edge type. 265 | >>> G = nx.karate_club_graph() 266 | >>> encode_graph(G, node_encoder="nx_node_name", edge_encoder="nx_edge_id") 267 | 'In an undirected graph, (s,p,o) means that node s and node o are connected 268 | with an undirected edge of type p. G describes a graph among nodes 0, 1, 2, 3, 269 | 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 270 | 25, 26, 27, 28, 29, 30, 31, 32, and 33. 271 | The edges in G are: (0, linked, 1) (0, linked, 2) (0, linked, 3) ...' 272 | ``` 273 | 274 | Args: 275 | graph: the graph to be encoded. 276 | graph_encoder: the name of the graph encoder to use. 277 | node_encoder: the name of the node encoder to use. 278 | edge_encoder: the name of the edge encoder to use. 279 | 280 | Returns: 281 | The encoded graph as a string. 282 | """ 283 | 284 | # Check that only one of graph_encoder or (node_encoder, edge_encoder) is set. 285 | if graph_encoder and (node_encoder or edge_encoder): 286 | raise ValueError( 287 | "Only one of graph_encoder or (node_encoder, edge_encoder) can be set." 
288 | ) 289 | 290 | if graph_encoder: 291 | if isinstance(graph_encoder, str): 292 | node_encoder_dict = get_tlag_node_encoder(graph, graph_encoder) 293 | return EDGE_ENCODER_FN[graph_encoder](graph, node_encoder_dict) 294 | else: 295 | return graph_encoder(graph) 296 | 297 | else: 298 | node_encoder_dict = nodes_to_text(graph, node_encoder) 299 | return EDGE_ENCODER_FN[edge_encoder](graph, node_encoder_dict) 300 | -------------------------------------------------------------------------------- /talk_like_a_graph/graph_text_encoders_test.py: -------------------------------------------------------------------------------- 1 | """Testing for graph_text_encoders.py.""" 2 | 3 | from absl.testing import parameterized 4 | import networkx as nx 5 | 6 | from . import graph_text_encoders 7 | from absl.testing import absltest 8 | 9 | _G = nx.Graph() 10 | _G.add_node(0) 11 | _G.add_node(1) 12 | _G.add_node(2) 13 | _G.add_node(3) 14 | _G.add_edge(0, 1) 15 | _G.add_edge(1, 2) 16 | _G.add_edge(2, 3) 17 | _G.add_edge(3, 0) 18 | 19 | 20 | class GraphTextEncodersTest(absltest.TestCase, parameterized.TestCase): 21 | 22 | @parameterized.named_parameters( 23 | dict( 24 | testcase_name='adjacency_integer', 25 | encoding_method='adjacency', 26 | expected_result=( 27 | 'In an undirected graph, (i,j) means that node i and node j are' 28 | ' connected with an undirected edge. 
 G describes a graph among' 29 |  ' nodes 0, 1, 2, and 3.\nThe edges in G are: (0, 1) (0, 3) (1, 2)' 30 |  ' (2, 3).\n' 31 |  ), 32 |  ), 33 |  dict( 34 |  testcase_name='incident_integer', 35 |  encoding_method='incident', 36 |  expected_result=( 37 |  'G describes a graph among nodes 0, 1, 2, and 3.\nIn this' 38 |  ' graph:\nNode 0 is connected to nodes 1, 3.\nNode 1 is connected' 39 |  ' to nodes 0, 2.\nNode 2 is connected to nodes 1, 3.\nNode 3 is' 40 |  ' connected to nodes 2, 0.\n' 41 |  ), 42 |  ), 43 |  dict( 44 |  testcase_name='friendship_per_line_popular', 45 |  encoding_method='friendship', 46 |  expected_result=( 47 |  'G describes a friendship graph among nodes James, Robert, John,' 48 |  ' and Michael.\nWe have the following edges in G:\nJames and' 49 |  ' Robert are friends.\nJames and Michael are friends.\nRobert and' 50 |  ' John are friends.\nJohn and Michael are friends.\n' 51 |  ), 52 |  ), 53 |  dict( 54 |  testcase_name='social_network_politician', 55 |  encoding_method='politician', 56 |  expected_result=( 57 |  'G describes a social network graph among nodes Barack, Jimmy,' 58 |  ' Arnold, and Bernie.\nWe have the following edges in' 59 |  ' G:\nBarack and Jimmy are connected.\nBarack and Bernie are' 60 |  ' connected.\nJimmy and Arnold are connected.\nArnold and' 61 |  ' Bernie are connected.\n' 62 |  ), 63 |  ), 64 |  ) 65 |  def test_encoders(self, encoding_method, expected_result): 66 |  self.assertEqual( 67 |  graph_text_encoders.encode_graph(_G, encoding_method), 68 |  expected_result, 69 |  ) 70 |  71 |  72 | if __name__ == '__main__': 73 |  absltest.main() 74 | -------------------------------------------------------------------------------- /talk_like_a_graph/name_dictionaries.py: -------------------------------------------------------------------------------- 1 | """Creates a dictionary mapping integers to node names.""" 2 | 3 | import random 4 | 5 | _RANDOM_SEED = 1234 6 | random.seed(_RANDOM_SEED) 7 | 8 | _INTEGER_NAMES = [str(x) for x in range(10000)] 9 | 10 |
_POPULAR_NAMES = [ 11 | "James", 12 | "Robert", 13 | "John", 14 | "Michael", 15 | "David", 16 | "Mary", 17 | "Patricia", 18 | "Jennifer", 19 | "Linda", 20 | "Elizabeth", 21 | "William", 22 | "Richard", 23 | "Joseph", 24 | "Thomas", 25 | "Christopher", 26 | "Barbara", 27 | "Susan", 28 | "Jessica", 29 | "Sarah", 30 | "Karen", 31 | "Daniel", 32 | "Lisa", 33 | "Matthew", 34 | "Nancy", 35 | "Anthony", 36 | "Betty", 37 | "Mark", 38 | "Margaret", 39 | "Donald", 40 | "Sandra", 41 | "Steven", 42 | "Ashley", 43 | "Paul", 44 | "Kimberly", 45 | "Andrew", 46 | "Emily", 47 | "Joshua", 48 | "Donna", 49 | "Kenneth", 50 | "Michelle", 51 | "Kevin", 52 | "Carol", 53 | "Brian", 54 | "Amanda", 55 | "George", 56 | "Melissa", 57 | "Edward", 58 | "Deborah", 59 | "Ronald", 60 | "Stephanie", 61 | "Timothy", 62 | "Rebecca", 63 | "Jason", 64 | "Sharon", 65 | "Jeffrey", 66 | "Laura", 67 | "Ryan", 68 | "Cynthia", 69 | "Jacob", 70 | "Dorothy", 71 | "Gary", 72 | "Olivia", 73 | "Nicholas", 74 | "Emma", 75 | "Eric", 76 | "Sophia", 77 | "Jonathan", 78 | "Ava", 79 | "Stephen", 80 | "Isabella", 81 | "Scott", 82 | "Mia", 83 | "Justin", 84 | "Abigail", 85 | "Brandon", 86 | "Madison", 87 | "Frank", 88 | "Chloe", 89 | "Benjamin", 90 | "Victoria", 91 | "Samuel", 92 | "Lauren", 93 | "Gregory", 94 | "Hannah", 95 | "Alexander", 96 | "Grace", 97 | "Frank", 98 | "Alexis", 99 | "Raymond", 100 | "Alice", 101 | "Patrick", 102 | "Samantha", 103 | "Jack", 104 | "Natalie", 105 | "Dennis", 106 | "Anna", 107 | "Jerry", 108 | "Taylor", 109 | "Tyler", 110 | "Kayla", 111 | "Henry", 112 | "Hailey", 113 | "Douglas", 114 | "Jasmine", 115 | "Peter", 116 | "Nicole", 117 | "Adam", 118 | "Amy", 119 | "Nathan", 120 | "Christina", 121 | "Zachary", 122 | "Andrea", 123 | "Jose", 124 | "Leah", 125 | "Walter", 126 | "Angelina", 127 | "Harold", 128 | "Valerie", 129 | "Kyle", 130 | "Veronica", 131 | "Ethan", 132 | "Carl", 133 | "Arthur", 134 | "Roger", 135 | "Noah", 136 | ] 137 | 138 | 139 | _SOUTH_PARK_NAMES = [ 140 | "Eric", 141 | 
"Kenny", 142 | "Kyle", 143 | "Stan", 144 | "Tolkien", 145 | "Heidi", 146 | "Bebe", 147 | "Liane", 148 | "Sharon", 149 | "Linda", 150 | "Gerald", 151 | "Veronica", 152 | "Michael", 153 | "Jimbo", 154 | "Herbert", 155 | "Malcolm", 156 | "Gary", 157 | "Steve", 158 | "Chris", 159 | "Wendy", 160 | ] 161 | 162 | _GOT_NAMES = [ 163 | "Ned", 164 | "Cat", 165 | "Daenerys", 166 | "Jon", 167 | "Bran", 168 | "Sansa", 169 | "Arya", 170 | "Cersei", 171 | "Jaime", 172 | "Petyr", 173 | "Robert", 174 | "Jorah", 175 | "Viserys", 176 | "Joffrey", 177 | "Maester", 178 | "Theon", 179 | "Rodrik", 180 | "Lysa", 181 | "Stannis", 182 | "Osha", 183 | ] 184 | 185 | 186 | _POLITICIAN_NAMES = [ 187 | "Barack", 188 | "Jimmy", 189 | "Arnold", 190 | "Bernie", 191 | "Bill", 192 | "Kamala", 193 | "Hillary", 194 | "Elizabeth", 195 | "John", 196 | "Ben", 197 | "Joe", 198 | "Alexandria", 199 | "George", 200 | "Nancy", 201 | "Pete", 202 | "Madeleine", 203 | "Elijah", 204 | "Gabrielle", 205 | "Al", 206 | ] 207 | 208 | 209 | _ALPHABET_NAMES = [ 210 | "A", 211 | "B", 212 | "C", 213 | "D", 214 | "E", 215 | "F", 216 | "G", 217 | "H", 218 | "I", 219 | "J", 220 | "K", 221 | "L", 222 | "M", 223 | "N", 224 | "O", 225 | "P", 226 | "Q", 227 | "R", 228 | "S", 229 | "T", 230 | "U", 231 | "V", 232 | "W", 233 | "X", 234 | "Y", 235 | "Z", 236 | "AA", 237 | "BB", 238 | "CC", 239 | "DD", 240 | "EE", 241 | "FF", 242 | "GG", 243 | "HH", 244 | "II", 245 | "JJ", 246 | "KK", 247 | "LL", 248 | "MM", 249 | "NN", 250 | "OO", 251 | "PP", 252 | "QQ", 253 | "RR", 254 | "SS", 255 | "TT", 256 | "UU", 257 | "VV", 258 | "WW", 259 | "XX", 260 | "YY", 261 | "ZZ", 262 | ] 263 | 264 | 265 | def create_name_dict(graph, name: str, nnodes: int = 20) -> dict[int, str]: 266 | """The runner function to map integers to node names. 267 | 268 | Args: 269 | graph: the graph to be encoded. 270 | name: name of the approach for mapping. 271 | nnodes: optionally provide nnodes in the graph to be encoded. 
272 | 273 | Returns: 274 | A dictionary from integers to strings. 275 | """ 276 | if name == "alphabet": 277 | names_list = _ALPHABET_NAMES 278 | elif name == "integer": 279 | names_list = _INTEGER_NAMES 280 | elif name == "random_integer": 281 | names_list = [] 282 | for _ in range(nnodes): 283 | names_list.append(str(random.randint(0, 1000000))) 284 | elif name == "popular": 285 | names_list = _POPULAR_NAMES 286 | elif name == "south_park": 287 | names_list = _SOUTH_PARK_NAMES 288 | elif name == "got": 289 | names_list = _GOT_NAMES 290 | elif name == "politician": 291 | names_list = _POLITICIAN_NAMES 292 | elif name == "nx_node_name": 293 | return {x: str(x) for x in graph.nodes()} 294 | else: 295 | raise ValueError(f"Unknown approach: {name}") 296 | name_dict = {} 297 | for ind, value in enumerate(names_list): 298 | name_dict[ind] = value 299 | 300 | return name_dict 301 | -------------------------------------------------------------------------------- /tutorial/KDD-Tutorial-2-Let-Your-Graph-Do-The-Talking.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "V28" 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | }, 16 | "accelerator": "TPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "source": [ 22 | "This is a notebook that illustrates how to use Graph Neural Networks to encode structured data for use in Large Language Models. 
It is Part 2 of a two-part tutorial from KDD'24.\n", 23 | "\n", 24 | "**This notebook requires a TPUv2 runtime**\n", 25 | "\n", 26 | "If you find this tutorial useful or want to know more, please consider our publication:\n", 27 | "Let your graph do the talking: Encoding structured data for LLMs\n", 28 | "```\n", 29 | "@article{perozzi2024let,\n", 30 | " title={Let your graph do the talking: Encoding structured data for llms},\n", 31 | " author={Perozzi, Bryan and Fatemi, Bahare and Zelle, Dustin and Tsitsulin, Anton and Kazemi, Mehran and Al-Rfou, Rami and Halcrow, Jonathan},\n", 32 | " journal={arXiv preprint arXiv:2402.05862},\n", 33 | " year={2024}\n", 34 | "}\n", 35 | "```\n", 36 | "\n", 37 | "## Tutorial Part II: GNN Encoding of Graph Information\n", 38 | "This notebook takes the work we did in the first part of the tutorial and extends it to using a Graph Neural Network to directly encode a representation of a graph into a prompt (vs using a text encoding as we did in the previous part).\n", 39 | "\n", 40 | "## Notebook Outline:\n", 41 | "\n", 42 | "1. Setup (Install Dependencies, download Gemma weights)\n", 43 | "2. Dataset creation\n", 44 | "3. Graph-to-Text conversion\n", 45 | "4. Evaluation\n", 46 | "5. Exercise: Graph Encoding Challenge\n", 47 | "6. Exercise: DBLP Dataset\n", 48 | "## Setup\n", 49 | "\n", 50 | "## Prework!\n", 51 | "\n", 52 | "Sign up for Kaggle and consent to the Gemma TOS (this is a requirement to download the Gemma weights used in this notebook).\n", 53 | "https://www.kaggle.com/models/google/gemma/license/consent?returnUrl=%2Fmodels%2Fgoogle%2Fgemma%2FFlax%2F2b-it%2F2" 54 | ], 55 | "metadata": { 56 | "id": "EdaaBjBQ3c5g" 57 | } 58 | }, 59 | { 60 | "cell_type": "code", 61 | "source": [ 62 | "%%capture\n", 63 | "# @title Install Dependencies\n", 64 | "!pip install git+https://github.com/google-deepmind/gemma.git\n", 65 | "!pip install --user kaggle\n", 66 | "!pip install sparse_deferred\n", 67 | "!git clone 
https://github.com/google-research/talk-like-a-graph.git\n", 68 | "import sys\n", 69 | "sys.path.insert(0, \"/content/talk-like-a-graph\")\n" 70 | ], 71 | "metadata": { 72 | "id": "BtYsxQuUQ-K0" 73 | }, 74 | "execution_count": null, 75 | "outputs": [] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "source": [ 80 | "## Login to Kaggle\n", 81 | "Follow the link in the login dialog to get an API key if you don't already have one. Also make sure to approve the [Gemma TOS](https://www.kaggle.com/models/google/gemma/license/consent?returnUrl=%2Fmodels%2Fgoogle%2Fgemma%2FFlax%2F2b-it%2F2) as well." 82 | ], 83 | "metadata": { 84 | "id": "42csum1-l0bn" 85 | } 86 | }, 87 | { 88 | "cell_type": "code", 89 | "source": [ 90 | "import kagglehub\n", 91 | "\n", 92 | "kagglehub.login()" 93 | ], 94 | "metadata": { 95 | "id": "xmM6CSh7RbIk" 96 | }, 97 | "execution_count": null, 98 | "outputs": [] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "source": [ 103 | "## Download Gemma" 104 | ], 105 | "metadata": { 106 | "id": "500UNvQOmcMn" 107 | } 108 | }, 109 | { 110 | "cell_type": "code", 111 | "source": [ 112 | "import os\n", 113 | "VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", 114 | "weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n", 115 | "ckpt_path = os.path.join(weights_dir, VARIANT)\n", 116 | "vocab_path = os.path.join(weights_dir, 'tokenizer.model')" 117 | ], 118 | "metadata": { 119 | "id": "W0BzTj2tPROY" 120 | }, 121 | "execution_count": null, 122 | "outputs": [] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "source": [ 127 | "## Import dependencies" 128 | ], 129 | "metadata": { 130 | "id": "WGFZHBaCRa77" 131 | } 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": { 137 | "id": "ycYOnVRpHZuJ", 138 | "cellView": "form" 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "# @title\n", 143 | "import os\n", 144 | "from collections.abc import Sequence\n", 145 | "import 
dataclasses\n", 146 | "from typing import Any, Callable, Mapping\n", 147 | "import sys\n", 148 | "\n", 149 | "\n", 150 | "import chex\n", 151 | "from flax import linen as nn\n", 152 | "import jax\n", 153 | "import jax.numpy as jnp\n", 154 | "import networkx as nx\n", 155 | "import numpy as np\n", 156 | "\n", 157 | "\n", 158 | "from gemma import params as params_lib\n", 159 | "from gemma import transformer as transformer_lib\n", 160 | "from gemma import sampler as sampler_lib\n", 161 | "import sentencepiece as spm\n", 162 | "import sparse_deferred as sd\n", 163 | "from sparse_deferred import jax as sdjnp\n", 164 | "from sparse_deferred.structs import graph_struct\n", 165 | "from sparse_deferred import np as sdnp\n" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "source": [ 171 | "## GraphToken library code" 172 | ], 173 | "metadata": { 174 | "id": "DzQUN-bCmrjg" 175 | } 176 | }, 177 | { 178 | "cell_type": "code", 179 | "source": [ 180 | "# @title\n", 181 | "import collections\n", 182 | "from collections.abc import Iterable\n", 183 | "import io\n", 184 | "import json\n", 185 | "from typing import Any, Callable, NamedTuple, Sequence\n", 186 | "\n", 187 | "import numpy as np\n", 188 | "import tqdm\n", 189 | "\n", 190 | "# Code for converting NetworkX graphs to graph tensor\n", 191 | "def laplacian_pos_embedding(graph: nx.Graph, units: int = 4) -\u003e nx.Graph:\n", 192 | " \"\"\"Adds the laplacian positional encoding.\"\"\"\n", 193 | " m = nx.normalized_laplacian_matrix(\n", 194 | " graph, nodelist=sorted(graph.nodes), weight=None\n", 195 | " ).astype(np.float32)\n", 196 | " u, _, _ = np.linalg.svd(m.todense(), compute_uv=True)\n", 197 | " if units \u003e u.shape[1]:\n", 198 | " u = np.pad(u, ((0, 0), (0, units - u.shape[1])))\n", 199 | " nx.set_node_attributes(\n", 200 | " graph, dict(zip(sorted(graph.nodes), u[:, :units])), name='lpe'\n", 201 | " )\n", 202 | " return graph\n", 203 | "\n", 204 | "\n", 205 | "def to_graph_struct(graph: nx.Graph, 
node_ids: list[int] | None = None) -\u003e graph_struct.GraphStruct:\n", 206 | " if graph.edges(data=True):\n", 207 | " s, t, w = zip(*[\n", 208 | " (s, t, (d['weight'] if d and 'weight' in d else None))\n", 209 | " for s, t, d in graph.edges(data=True)\n", 210 | " ])\n", 211 | " else:\n", 212 | " s, t, w = (), (), ()\n", 213 | " # tfgnn assumes graphs are directed. Adding the rev edges for an undirected\n", 214 | " # graph.\n", 215 | " if not graph.is_directed():\n", 216 | " s, t, w = s + t, t + s, w + w\n", 217 | "\n", 218 | " graph = laplacian_pos_embedding(graph, units=4)\n", 219 | " return graph_struct.GraphStruct.new(\n", 220 | " nodes={'nodes': {'lpe': np.stack([graph.nodes('lpe')[i] for i in range(graph.number_of_nodes())])}},\n", 221 | " edges={'edges': ((np.array(s, dtype=np.int32), np.array(t, dtype=np.int32)), {})}\n", 222 | " )\n", 223 | "\n", 224 | "\n", 225 | "Tensor = sd.matrix.Tensor\n", 226 | "Features = dict[str, Tensor]\n", 227 | "FeatureSets = dict[str, Features]\n", 228 | "Edge = tuple[tuple[Tensor, ...], Features] # (endpoints, edge features)\n", 229 | "Edges = dict[str, Edge]\n", 230 | "Nodes = FeatureSets\n", 231 | "Schema = dict[str, tuple[str, ...]]\n", 232 | "_Schema = dict[str, tuple[dict[str, int], ...]]\n", 233 | "\n", 234 | "\n", 235 | "\n", 236 | "class FixedSizePadder:\n", 237 | " \"\"\"Adds padding to `GraphStruct` instances for fixed-sized tensors.\n", 238 | "\n", 239 | " Fixed-size tensors can be preferred when running on TPU accelerators.\n", 240 | "\n", 241 | " To use this class, you must first initialize it with statistics of your graphs\n", 242 | " then use it to pad graphs. 
The statistics can be initialized by invoking\n", 243 | " `calculate_pad_statistics`: this function records the *maximum* observed size\n", 244 | " of every node and edge set, as well as the standard deviation (std) of sizes.\n", 245 | "\n", 246 | " Once initialized, the `pad_graph()` function will add padding to the graph.\n", 247 | " Specifically, the node feature tensors will be padded with zeros. Similarly,\n", 248 | " edges will be inserted, among newly-added virtual nodes.\n", 249 | "\n", 250 | " Each node (or edge) size will become:\n", 251 | "\n", 252 | " `max observed [per calculate_pad_statistics] + slack*std + 1`\n", 253 | "\n", 254 | " NOTE: there will always be at least one more node or edge, even if the\n", 255 | " statistics show zero std. This is required for making virtual nodes.\n", 256 | "\n", 257 | " This applies to all node-set (features) and edge-set (features and adjacency list) sizes.\n", 258 | " \"\"\"\n", 259 | "\n", 260 | " def __init__(self, engine: sd.ComputeEngine, slack: float = 1.0):\n", 261 | " # `('edge'|'node', NodeOrEdgeName) -\u003e target size`\n", 262 | " # where `target size` is maximum observed size for node (or edge) set, plus\n", 263 | " # one, plus slack-times-std of observed sizes.\n", 264 | " self.sizes: dict[tuple[str, str], int] = {}\n", 265 | " self.slack = slack\n", 266 | " self._engine = engine\n", 267 | "\n", 268 | " def calculate_pad_statistics(\n", 269 | " self, examples: Iterable[graph_struct.GraphStruct], num_steps: int = 100):\n", 270 | " \"\"\"Measures the max and std of node \u0026 edge sizes of elements of `examples`.\n", 271 | "\n", 272 | " Calling this function is necessary before invoking `pad_graph`.\n", 273 | "\n", 274 | " Args:\n", 275 | " examples: iterable that yields `GraphStruct` examples.\n", 276 | " num_steps: If positive, considers this many samples of `examples`.\n", 277 | " Otherwise, iterates over all `examples`. 
Warning: this may run\n", 278 | " infinitely on infinite iterators (e.g., `dataset.repeat()`).\n", 279 | " \"\"\"\n", 280 | " sizes: dict[tuple[str, str], list[int]] = collections.defaultdict(list)\n", 281 | " for i, graph in enumerate(examples):\n", 282 | " assert isinstance(graph, graph_struct.GraphStruct)\n", 283 | " if i \u003e 0 and i \u003e= num_steps:\n", 284 | " break\n", 285 | " for node_name, features in graph.nodes.items():\n", 286 | " value_list = sizes[('nodes', node_name)]\n", 287 | " if not features:\n", 288 | " value_list.append(0)\n", 289 | " else:\n", 290 | " value_list.append(list(features.values())[0].shape[0])\n", 291 | "\n", 292 | " for edge_name, edges_tuple in graph.edges.items():\n", 293 | " value_list = sizes[('edges', edge_name)]\n", 294 | " source_nodes = edges_tuple[0][0]\n", 295 | " # if len(value_list) and edge_set.sizes.shape != value_list[-1].shape:\n", 296 | " # continue\n", 297 | " value_list.append(source_nodes.shape[0])\n", 298 | "\n", 299 | " self.sizes = {k: int(1 + max(v) + self.slack * np.std(v))\n", 300 | " for k, v in sizes.items()}\n", 301 | "\n", 302 | " def pad_graph(self, graph: graph_struct.GraphStruct) -\u003e graph_struct.GraphStruct:\n", 303 | " \"\"\"Pads node-sets and edge-sets, with zeros, to max-seen during `calc..`.\n", 304 | "\n", 305 | " This function is useful for running on TPU hardware.\n", 306 | "\n", 307 | " Args:\n", 308 | " graph: contains any number of nodes and edges.\n", 309 | "\n", 310 | " Returns:\n", 311 | " graph with deterministic number of nodes and edges. See class docstring.\n", 312 | " \"\"\"\n", 313 | " if not self.sizes:\n", 314 | " raise ValueError(\n", 315 | " 'No statistics have been initialized. 
'\n", 316 | " 'Perhaps you forgot to invoke \"calculate_pad_statistics\"?')\n", 317 | " # Edge set name -\u003e (1D vectors containing endpoints**), {\"feature\": Tensor})\n", 318 | " edges: Edges = {}\n", 319 | " # ** tuple should have 2 entries for directed graphs\n", 320 | "\n", 321 | " nodes: Nodes = {}\n", 322 | "\n", 323 | " # For every key in `edges`, store names of node sets that `key` edge\n", 324 | " # connects.\n", 325 | " schema = graph.schema\n", 326 | "\n", 327 | " e = self._engine # for short.\n", 328 | " for node_name, node_features in graph.nodes.items():\n", 329 | " padded_features = {}\n", 330 | " desired_size = self.sizes[('nodes', node_name)]\n", 331 | "\n", 332 | " for feature_name, feature in node_features.items():\n", 333 | " feature = feature[:desired_size] # if `is_oversized`.\n", 334 | " pad = self._engine.maximum(\n", 335 | " desired_size - self._engine.shape(feature)[0], 0)\n", 336 | " zeros = e.zeros(\n", 337 | " tuple([pad] + list(feature.shape[1:])), dtype=feature.dtype)\n", 338 | " padded_feature = e.concat([feature, zeros], axis=0)\n", 339 | " padded_feature = e.reshape(\n", 340 | " padded_feature, [desired_size] + list(padded_feature.shape[1:]))\n", 341 | " padded_features[feature_name] = padded_feature\n", 342 | "\n", 343 | " nodes[node_name] = padded_features\n", 344 | "\n", 345 | " for edge_name, (edge_endpoints, features) in graph.edges.items():\n", 346 | " padded_features = {}\n", 347 | " padded_endpoints = []\n", 348 | " desired_size = self.sizes[('edges', edge_name)]\n", 349 | " current_size = e.shape(edge_endpoints[0])[0]\n", 350 | "\n", 351 | " pad = e.maximum(desired_size - current_size, 0)\n", 352 | " e.assert_greater(pad, -1)\n", 353 | "\n", 354 | " for feature_name, feature in features.items():\n", 355 | " feature = feature[:desired_size] # if `is_oversized`.\n", 356 | " zeros = e.zeros(\n", 357 | " tuple([pad] + list(feature.shape[1:])), dtype=feature.dtype\n", 358 | " )\n", 359 | " padded_feature = 
e.concat([feature, zeros], axis=0)\n", 360 | " padded_feature = e.reshape(\n", 361 | " padded_feature, [desired_size] + list(padded_feature.shape[1:])\n", 362 | " )\n", 363 | " padded_features[feature_name] = padded_feature\n", 364 | "\n", 365 | " edge_endpoints = [node_ids[:desired_size] for node_ids in edge_endpoints]\n", 366 | " # [[src1_is_valid, src2_is_valid, ...], [tgt1_is_valid, ...]]\n", 367 | " valid = e.cast(\n", 368 | " [\n", 369 | " ids \u003c self.sizes[('nodes', node_name)]\n", 370 | " for ids, node_name in zip(edge_endpoints, schema[edge_name])\n", 371 | " ],\n", 372 | " dtype=bool,\n", 373 | " )\n", 374 | " valid = e.reduce_all(valid, axis=0)\n", 375 | "\n", 376 | " for node_ids, node_name in zip(edge_endpoints, schema[edge_name]):\n", 377 | " # Universe size (e.g., of source or target).\n", 378 | " max_endpoint = self.sizes[('nodes', node_name)] - 1\n", 379 | " node_ids = node_ids[:desired_size]\n", 380 | " node_ids = e.boolean_mask(node_ids, valid)\n", 381 | " pad = desired_size - e.shape(node_ids)[0] # Need only to compute once.\n", 382 | "\n", 383 | " padded_ids = e.concat([\n", 384 | " node_ids,\n", 385 | " e.ones((pad), dtype=node_ids.dtype) * max_endpoint\n", 386 | " ], axis=0)\n", 387 | " padded_ids = e.reshape(padded_ids, [desired_size])\n", 388 | " padded_endpoints.append(padded_ids)\n", 389 | "\n", 390 | " edges[edge_name] = (tuple(padded_endpoints), padded_features)\n", 391 | "\n", 392 | " graph = graph_struct.GraphStruct.new(nodes=nodes, edges=edges, schema=schema)\n", 393 | " return graph\n", 394 | "\n", 395 | "\n", 396 | "\n", 397 | "## gnn.py\n", 398 | "class GIN(nn.Module):\n", 399 | " \"\"\"Graph Isomorphism Network: https://arxiv.org/pdf/1810.00826.pdf.\"\"\"\n", 400 | "\n", 401 | " output_dim: int\n", 402 | " num_hidden_layers: int = 1\n", 403 | " hidden_dim: int = 32\n", 404 | " epsilon: float = 0.1 # See GIN paper (link above)\n", 405 | "\n", 406 | " def setup(self):\n", 407 | " layer_dims = [self.hidden_dim] * 
self.num_hidden_layers\n", 408 | " self.layers = [\n", 409 | " nn.Dense(dim, use_bias=False, dtype=jnp.bfloat16) for dim in layer_dims\n", 410 | " ]\n", 411 | " self.out_layer = nn.Dense(\n", 412 | " self.output_dim, use_bias=False, dtype=jnp.bfloat16\n", 413 | " )\n", 414 | "\n", 415 | " def __call__(self, graph: graph_struct.GraphStruct) -\u003e jax.Array:\n", 416 | " x = graph.nodes['nodes']['lpe']\n", 417 | " adj = graph.adj(sdjnp.engine, 'edges')\n", 418 | " adj = adj.add_eye(1 + self.epsilon) # self connections with 1+eps weight.\n", 419 | "\n", 420 | " for i, layer in enumerate(self.layers):\n", 421 | " x = layer(adj @ x)\n", 422 | " if i \u003c self.num_hidden_layers:\n", 423 | " x = nn.relu(x)\n", 424 | " x = jnp.concat(x, axis=-1)\n", 425 | " return self.out_layer(x)\n", 426 | "\n", 427 | "\n", 428 | "class GCN(nn.Module):\n", 429 | " \"\"\"Graph convolutional network: https://arxiv.org/pdf/1609.02907.pdf.\"\"\"\n", 430 | "\n", 431 | " output_dim: int\n", 432 | " num_hidden_layers: int = 1\n", 433 | " hidden_dim: int = 32\n", 434 | "\n", 435 | " def setup(self):\n", 436 | " layer_dims = [self.hidden_dim] * self.num_hidden_layers\n", 437 | " self.layers = [nn.Dense(dim, use_bias=False) for dim in layer_dims]\n", 438 | " self.out_layer = nn.Dense(\n", 439 | " self.output_dim, use_bias=False, dtype=jnp.bfloat16\n", 440 | " )\n", 441 | "\n", 442 | " def __call__(self, graph: graph_struct.GraphStruct) -\u003e jax.Array:\n", 443 | " x = graph.nodes['nodes']['lpe']\n", 444 | " adj = graph.adj(sdjnp.engine, 'edges')\n", 445 | " adj_symnorm = (adj + adj.transpose()).add_eye().normalize_symmetric()\n", 446 | "\n", 447 | " for i, layer in enumerate(self.layers):\n", 448 | " x = layer(adj_symnorm @ x)\n", 449 | " if i \u003c self.num_hidden_layers:\n", 450 | " x = nn.relu(x)\n", 451 | " x = jnp.concat(x, axis=-1)\n", 452 | " return self.out_layer(x)\n", 453 | "\n", 454 | "\n", 455 | "## sampler.py\n", 456 | "\n", 457 | "@dataclasses.dataclass\n", 458 | "class 
SamplerOutput:\n", 459 | "\n", 460 | " # Decoded samples from the model.\n", 461 | " text: list[str]\n", 462 | "\n", 463 | " # Per-step logits used during sampling.\n", 464 | " logits: list[list[float]]\n", 465 | "\n", 466 | " # Tokens corresponding to the generated samples.\n", 467 | " tokens: list[list[int]]\n", 468 | "\n", 469 | " graph_embeddings: list[jnp.ndarray]\n", 470 | "\n", 471 | "\n", 472 | "class GraphTokenSampler:\n", 473 | " \"\"\"Sampler for GraphToken.\"\"\"\n", 474 | "\n", 475 | " def __init__(\n", 476 | " self,\n", 477 | " gnn: nn.Module,\n", 478 | " llm: transformer_lib.Transformer,\n", 479 | " vocab: spm.SentencePieceProcessor,\n", 480 | " params: Mapping[str, Any],\n", 481 | " gnn_token_template: str = r'\u003cunused%d\u003e',\n", 482 | " ):\n", 483 | " \"\"\"Initializes the sampler.\n", 484 | "\n", 485 | " Args:\n", 486 | " gnn: The GNN model.\n", 487 | " llm: The LLM model.\n", 488 | " vocab: The vocab used by the LLM.\n", 489 | " params: The parameters for the GNN and LLM. 
This should contain the params\n", 490 | " for the gnn under params['gnn'] and the params for the llm under\n", 491 | " params['transformer']\n", 492 | " gnn_token_template: The token used to represent the GNN embedding.\n", 493 | " \"\"\"\n", 494 | "\n", 495 | " self._gnn = gnn\n", 496 | " self._llm = llm\n", 497 | " self._params = params\n", 498 | " self._vocab = vocab\n", 499 | " self._gnn_token_template = gnn_token_template\n", 500 | " self._sampler = sampler_lib.Sampler(\n", 501 | " transformer=self._llm,\n", 502 | " vocab=self._vocab,\n", 503 | " params=self._params['transformer'],\n", 504 | " )\n", 505 | "\n", 506 | " def __call__(\n", 507 | " self,\n", 508 | " input_strings: Sequence[str],\n", 509 | " input_graphs: Sequence[graph_struct.GraphStruct],\n", 510 | " total_generation_steps: int,\n", 511 | " echo: bool = False,\n", 512 | " return_logits: bool = True,\n", 513 | " forbidden_tokens: Sequence[str] | None = None,\n", 514 | " ) -\u003e SamplerOutput:\n", 515 | " \"\"\"Samples from the model.\n", 516 | "\n", 517 | " Args:\n", 518 | " input_strings: The input strings.\n", 519 | " input_graphs: The input graphs.\n", 520 | " total_generation_steps: The number of steps to generate.\n", 521 | " echo: Whether to echo the input.\n", 522 | " return_logits: Whether to return the logits.\n", 523 | " forbidden_tokens: Tokens that are forbidden, in addition to the GNN token.\n", 524 | "\n", 525 | " Returns:\n", 526 | " The sampled output.\n", 527 | " \"\"\"\n", 528 | " assert len(input_graphs) == len(input_strings), (\n", 529 | " len(input_graphs),\n", 530 | " len(input_strings),\n", 531 | " )\n", 532 | " augmented_inputs = []\n", 533 | " full_forbidden_tokens = []\n", 534 | " if forbidden_tokens is not None:\n", 535 | " full_forbidden_tokens += forbidden_tokens\n", 536 | " graph_embeddings = []\n", 537 | " augmented_transformer_params = self._params['transformer']\n", 538 | "\n", 539 | " placeholder_token = PLACEHOLDER_TOKEN\n", 540 | " 
full_forbidden_tokens.append(placeholder_token)\n", 541 | " placeholder_token_id = self._vocab.EncodeAsIds(placeholder_token)\n", 542 | " assert len(placeholder_token_id) == 1, placeholder_token\n", 543 | " placeholder_token_id = placeholder_token_id[0]\n", 544 | "\n", 545 | " for prompt, graph in zip(input_strings, input_graphs):\n", 546 | " embed = self._gnn.apply(self._params['gnn'], graph)\n", 547 | " assert (\n", 548 | " self._params['transformer']['embedder']['input_embedding'][\n", 549 | " placeholder_token_id\n", 550 | " ].shape\n", 551 | " == embed.shape\n", 552 | " )\n", 553 | " augmented_transformer_params['embedder']['input_embedding'] = (\n", 554 | " augmented_transformer_params['embedder']['input_embedding']\n", 555 | " .at[placeholder_token_id]\n", 556 | " .set(embed)\n", 557 | " )\n", 558 | " graph_embeddings.append(embed)\n", 559 | " augmented_inputs.append(placeholder_token + prompt)\n", 560 | "\n", 561 | " self._sampler.params = augmented_transformer_params\n", 562 | " o = self._sampler(\n", 563 | " input_strings=augmented_inputs,\n", 564 | " total_generation_steps=total_generation_steps,\n", 565 | " echo=echo,\n", 566 | " return_logits=return_logits,\n", 567 | " forbidden_tokens=full_forbidden_tokens,\n", 568 | " )\n", 569 | " return SamplerOutput(\n", 570 | " **dataclasses.asdict(o),\n", 571 | " graph_embeddings=graph_embeddings,\n", 572 | " )\n", 573 | "\n", 574 | "\n", 575 | "\n", 576 | "@chex.dataclass(frozen=True)\n", 577 | "class TrainingInput:\n", 578 | " \"\"\"Batch of training data for a GraphToken model.\"\"\"\n", 579 | "\n", 580 | " # Input tokens given to the model\n", 581 | " input_tokens: np.ndarray # size [B, L]\n", 582 | "\n", 583 | " # A mask that determines which tokens contribute to the target loss\n", 584 | " # calculation.\n", 585 | " target_mask: np.ndarray # size [B, L]\n", 586 | "\n", 587 | " input_graphs: list[graph_struct.GraphStruct] # size [B]\n", 588 | "\n", 589 | " # Ground truth for the input tokens, if 
representable as an integer.\n", 590 | " # For boolean classification tasks, this is 0/1.\n", 591 | " parsed_ground_truth: np.ndarray | None # size [B]\n", 592 | "\n", 593 | "\n", 594 | "def parse_int(s: str) -\u003e int:\n", 595 | " \"\"\"Parse a string as an integer.\"\"\"\n", 596 | " return int(float(s.strip()))\n", 597 | "\n", 598 | "\n", 599 | "def parse_yes_no(s: str) -\u003e bool:\n", 600 | " \"\"\"Parse a string as a yes/no answer, looking at the first 10 chars.\"\"\"\n", 601 | " return 'yes' in s.lower()[:10]\n", 602 | "\n", 603 | "PLACEHOLDER_TOKEN = '\u003cunused0\u003e'\n", 604 | "\n", 605 | "def graphqa_ds(\n", 606 | " vocab: spm.SentencePieceProcessor,\n", 607 | " encoded_examples: list,\n", 608 | " padder: FixedSizePadder | None = None,\n", 609 | " max_tokens: int = 100,\n", 610 | " gt_parser: Callable[[str], Any] | None = None,\n", 611 | ") -\u003e tuple[FixedSizePadder, list[TrainingInput]]:\n", 612 | " \"\"\"Load a GraphQA dataset as a list of TrainingInput.\n", 613 | "\n", 614 | " Args:\n", 615 | " vocab: The vocab to use for tokenization.\n", 616 | " encoded_examples: List of encoded examples generated by GraphQA.\n", 617 | " padder: The padder to use for padding the graph. If None, a new padder will\n", 618 | " be created and returned. This is so a padder can be shared across multiple\n", 619 | " datasets / splits.\n", 620 | " max_tokens: The maximum number of tokens to allow in the input. 
For\n", 621 | " 'task_only' prompting this can be quite small (100 tokens is plenty)\n", 622 | " gt_parser: A function to parse the ground truth from the answer string, used\n", 623 | " to supply the 'parsed_ground_truth' field in the TrainingInput.\n", 624 | "\n", 625 | " Returns:\n", 626 | " The padder used for padding the graphs, and a list of TrainingInput.\n", 627 | " \"\"\"\n", 628 | "\n", 629 | " output = []\n", 630 | " for ex in encoded_examples:\n", 631 | " query = PLACEHOLDER_TOKEN + ex['question'][ex['question'].find('Q:'):]\n", 632 | " answer = ex['answer']\n", 633 | " graph = to_graph_struct(ex['graph'])\n", 634 | " query_tokens = vocab.EncodeAsIds(query)\n", 635 | " answer_tokens = vocab.EncodeAsIds(answer) + [vocab.eos_id()]\n", 636 | " input_tokens = np.array([vocab.bos_id()] + query_tokens + answer_tokens)\n", 637 | " target_mask = np.zeros_like(input_tokens, dtype=jnp.int32)\n", 638 | " # Add one for BOS token\n", 639 | " target_mask[len(query_tokens) + 1 :] = 1\n", 640 | " orig_len = len(query_tokens) + len(answer_tokens) + 1\n", 641 | " input_tokens = np.pad(\n", 642 | " input_tokens,\n", 643 | " [[0, max_tokens - orig_len]],\n", 644 | " constant_values=vocab.pad_id(),\n", 645 | " )\n", 646 | "\n", 647 | " target_mask = np.pad(target_mask, [[0, max_tokens - orig_len]])\n", 648 | "\n", 649 | "\n", 650 | " # The GNN library that we are using requires a global feature. 
We set\n", 651 | " # a fake value here, but it is unused otherwise.\n", 652 | " #graph = graph.update(nodes={'g': {'foo': np.zeros([1])}})\n", 653 | "\n", 654 | " output.append(\n", 655 | " TrainingInput(\n", 656 | " input_tokens=np.array([input_tokens]),\n", 657 | " target_mask=np.array([target_mask]),\n", 658 | " input_graphs=[graph],\n", 659 | " parsed_ground_truth=np.array(gt_parser(answer)) if gt_parser else None,\n", 660 | " )\n", 661 | " )\n", 662 | " if padder is None:\n", 663 | " padder = FixedSizePadder(sdnp.engine)\n", 664 | " padder.calculate_pad_statistics(\n", 665 | " [e.input_graphs[0] for e in output], len(output)\n", 666 | " )\n", 667 | " for o in output:\n", 668 | " o.input_graphs[0] = padder.pad_graph(o.input_graphs[0])\n", 669 | " return padder, output\n", 670 | "\n", 671 | "\n", 672 | "def decode_questions(\n", 673 | " training_input: TrainingInput, vocab: spm.SentencePieceProcessor\n", 674 | ") -\u003e list[str]:\n", 675 | " \"\"\"Decode the question from the input tokens. 
(ignoring the first 2).\"\"\"\n", 676 | " b, l = training_input.input_tokens.shape\n", 677 | " question_tokens = []\n", 678 | " for i in range(b):\n", 679 | " question_tokens.append([])\n", 680 | " # Skip the first two tokens (BOS and control token).\n", 681 | " for j in range(2, l):\n", 682 | " if training_input.target_mask[i, j] == 1:\n", 683 | " break\n", 684 | " question_tokens[i].append(int(training_input.input_tokens[i, j]))\n", 685 | " return [''.join(vocab.DecodeIds(q)) for q in question_tokens]\n", 686 | "\n", 687 | "\n", 688 | "## training_loop\n", 689 | "import functools\n", 690 | "from typing import Any, MutableMapping\n", 691 | "\n", 692 | "import chex\n", 693 | "from flax import linen as nn\n", 694 | "from gemma import transformer as transformer_lib\n", 695 | "import jax\n", 696 | "import jax.numpy as jnp\n", 697 | "import optax\n", 698 | "import tqdm\n", 699 | "\n", 700 | "Params = MutableMapping[str, Any]\n", 701 | "\n", 702 | "\n", 703 | "def get_attention_mask_and_positions(\n", 704 | " example: jax.Array,\n", 705 | " pad_id: int,\n", 706 | ") -\u003e tuple[jax.Array, jax.Array]:\n", 707 | " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n", 708 | " pad_mask = example != pad_id\n", 709 | " current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n", 710 | " attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n", 711 | " return current_token_position, attention_mask\n", 712 | "\n", 713 | "\n", 714 | "def forward_and_loss_fn(\n", 715 | " params: Params,\n", 716 | " *,\n", 717 | " gnn: nn.Module,\n", 718 | " llm: transformer_lib.Transformer,\n", 719 | " input_tokens: jax.Array, # Shape [B, L]\n", 720 | " input_graphs: list[graph_struct.GraphStruct], # Shape [B]\n", 721 | " input_mask: jax.Array, # Shape [B, L]\n", 722 | " positions: jax.Array, # Shape [B, L]\n", 723 | " attention_mask: jax.Array, # [B, L, L]\n", 724 | " placeholder_token_id: int,\n", 725 | ") -\u003e jax.Array:\n", 
726 | " \"\"\"Forward pass and loss function.\n", 727 | "\n", 728 | " Args:\n", 729 | " params: Params for the gnn and transformer. The gnn params are stored in\n", 730 | " params['gnn'] and the llm params are stored in params['transformer'].\n", 731 | " gnn: gnn model to call.\n", 732 | " llm: gemma transformer model to call.\n", 733 | " input_tokens: input tokens sequence, shape [B, L].\n", 734 | " input_graphs: input graphs.\n", 735 | " input_mask: mask selecting the tokens to score in the loss, shape [B, L].\n", 736 | " positions: relative position of each token, shape [B, L].\n", 737 | " attention_mask: input attention mask, shape [B, L, L].\n", 738 | " placeholder_token_id: Index in the LLM vocabulary that we are using for passing\n", 739 | " graph embeddings.\n", 740 | "\n", 741 | " Returns:\n", 742 | " Softmax cross-entropy loss for the next-token prediction task.\n", 743 | " \"\"\"\n", 744 | " # Right now we only support batch_size = 1\n", 745 | " chex.assert_axis_dimension(input_tokens, 0, 1)\n", 746 | " chex.assert_equal_shape([input_tokens, input_mask, positions])\n", 747 | " chex.assert_axis_dimension(attention_mask, 0, 1)\n", 748 | " chex.assert_equal(len(input_graphs), 1)\n", 749 | "\n", 750 | " # Get the GNN embedding and update the transformer input embedding for a\n", 751 | " # control token.\n", 752 | " graph_embed = gnn.apply(params['gnn'], input_graphs[0])\n", 753 | " params['transformer']['embedder']['input_embedding'] = (\n", 754 | " params['transformer']['embedder']['input_embedding']\n", 755 | " .at[placeholder_token_id]\n", 756 | " .set(graph_embed)\n", 757 | " )\n", 758 | " # Forward pass on the input data.\n", 759 | " # No attention cache is needed here.\n", 760 | " logits, _ = llm.apply(\n", 761 | " {'params': params['transformer']},\n", 762 | " input_tokens,\n", 763 | " positions,\n", 764 | " None, # Attention cache is None.\n", 765 | " attention_mask,\n", 766 | " )\n", 767 | "\n", 768 | " # Exclude the last step as it does not appear in the
targets.\n", 769 | " logits = logits[0, :-1]\n", 770 | "\n", 771 | " # Similarly, the first token cannot be predicted.\n", 772 | " target_tokens = input_tokens[0, 1:]\n", 773 | " target_mask = input_mask[0, 1:]\n", 774 | "\n", 775 | " # Convert the target labels into one-hot encoded vectors.\n", 776 | " one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])\n", 777 | "\n", 778 | " # Don't update on unwanted tokens.\n", 779 | " one_hot = one_hot * target_mask.astype(one_hot.dtype)[..., jnp.newaxis]\n", 780 | "\n", 781 | " # Normalisation factor.\n", 782 | " norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)\n", 783 | "\n", 784 | " # Return the nll loss.\n", 785 | " return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor\n", 786 | "\n", 787 | "\n", 788 | "@functools.partial(\n", 789 | " jax.jit,\n", 790 | " static_argnames=['gnn', 'llm', 'optimizer', 'pad_id', 'placeholder_token_id'],\n", 791 | ")\n", 792 | "def train_step(\n", 793 | " llm: transformer_lib.Transformer,\n", 794 | " gnn: nn.Module,\n", 795 | " params: MutableMapping[str, Any],\n", 796 | " optimizer: optax.GradientTransformation,\n", 797 | " opt_state: optax.OptState,\n", 798 | " pad_id: int,\n", 799 | " example: TrainingInput,\n", 800 | " placeholder_token_id: int,\n", 801 | ") -\u003e tuple[jax.Array, Params, optax.OptState]:\n", 802 | " \"\"\"Train step.\n", 803 | "\n", 804 | " Args:\n", 805 | " llm: gemma transformer model.\n", 806 | " gnn: gnn model.\n", 807 | " params: model's input parameters.\n", 808 | " optimizer: optax optimizer to use.\n", 809 | " opt_state: input optimizer's state.\n", 810 | " pad_id: id of the pad token.\n", 811 | " example: input batch.\n", 812 | " placeholder_token_id: Index in the LLM vocabulary that we are using for passing\n", 813 | " graph embeddings.\n", 814 | "\n", 815 | " Returns:\n", 816 | " Training loss, updated parameters, updated optimizer state.\n", 817 | " \"\"\"\n", 818 | "\n", 819 | " # Build the position and attention mask vectors.\n", 820 | " 
positions, attention_mask = get_attention_mask_and_positions(\n", 821 | " jnp.array(example.input_tokens), pad_id\n", 822 | " )\n", 823 | "\n", 824 | " # Forward and backward passes\n", 825 | " train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(\n", 826 | " params,\n", 827 | " gnn=gnn,\n", 828 | " llm=llm,\n", 829 | " input_tokens=example.input_tokens,\n", 830 | " input_mask=example.target_mask,\n", 831 | " input_graphs=example.input_graphs,\n", 832 | " positions=positions,\n", 833 | " attention_mask=attention_mask,\n", 834 | " placeholder_token_id=placeholder_token_id,\n", 835 | " )\n", 836 | "\n", 837 | " updates, opt_state = optimizer.update(\n", 838 | " grads['gnn'], opt_state, params=params['gnn']\n", 839 | " )\n", 840 | " params['gnn'] = optax.apply_updates(params['gnn'], updates)\n", 841 | "\n", 842 | " return train_loss, params, opt_state\n", 843 | "\n", 844 | "\n", 845 | "@functools.partial(\n", 846 | " jax.jit, static_argnames=['gnn', 'llm', 'pad_id', 'placeholder_token_id']\n", 847 | ")\n", 848 | "def validation_step(\n", 849 | " gnn: nn.Module,\n", 850 | " llm: transformer_lib.Transformer,\n", 851 | " params: MutableMapping[str, Any],\n", 852 | " pad_id: int,\n", 853 | " example: TrainingInput,\n", 854 | " placeholder_token_id: int,\n", 855 | ") -\u003e jax.Array:\n", 856 | " \"\"\"Validation step.\n", 857 | "\n", 858 | " Args:\n", 859 | " gnn: gnn model.\n", 860 | " llm: gemma transformer model.\n", 861 | " params: model's input parameters. 
The gnn params are stored in params['gnn']\n", 862 | " and the llm params are stored in params['transformer'].\n", 863 | " pad_id: id of the pad token.\n", 864 | " example: input batch\n", 865 | " placeholder_token_id: Index in the LLM vocabulary that we are using for passing\n", 866 | " graph embeddings.\n", 867 | "\n", 868 | " Returns:\n", 869 | " Validation loss.\n", 870 | " \"\"\"\n", 871 | " jax_input = jax.tree.map(jnp.array, example)\n", 872 | " positions, attention_mask = get_attention_mask_and_positions(\n", 873 | " jax_input.input_tokens, pad_id\n", 874 | " )\n", 875 | " val_loss = forward_and_loss_fn(\n", 876 | " params,\n", 877 | " gnn=gnn,\n", 878 | " llm=llm,\n", 879 | " input_tokens=jax_input.input_tokens,\n", 880 | " input_mask=jax_input.target_mask,\n", 881 | " input_graphs=jax_input.input_graphs,\n", 882 | " positions=positions,\n", 883 | " attention_mask=attention_mask,\n", 884 | " placeholder_token_id=placeholder_token_id,\n", 885 | " )\n", 886 | " return val_loss\n", 887 | "\n", 888 | "\n", 889 | "@chex.dataclass(frozen=True)\n", 890 | "class TrainingConfig:\n", 891 | " learning_rate: float\n", 892 | " num_epochs: int\n", 893 | " eval_every_n: int\n", 894 | " batch_size: int\n", 895 | " max_steps: int | None = None\n", 896 | "\n", 897 | "\n", 898 | "def train_loop(\n", 899 | " llm: transformer_lib.Transformer,\n", 900 | " gnn: nn.Module,\n", 901 | " train_ds: list[TrainingInput],\n", 902 | " validation_ds: list[TrainingInput],\n", 903 | " params: Params,\n", 904 | " training_cfg: TrainingConfig,\n", 905 | " vocab: spm.SentencePieceProcessor,\n", 906 | ") -\u003e Params:\n", 907 | " \"\"\"Main training loop for GraphToken.\n", 908 | "\n", 909 | " Args:\n", 910 | " llm: Gemma transformer model.\n", 911 | " gnn: gnn model.\n", 912 | " train_ds: training dataset.\n", 913 | " validation_ds: validation dataset.\n", 914 | " params: Combined params for both the LLM and GNN. 
The GNN params are stored\n", 915 | " in params['gnn'] and the LLM params are stored in params['transformer'].\n", 916 | " training_cfg: training configuration.\n", 917 | " vocab: sentence piece vocabulary.\n", 918 | "\n", 919 | " Returns:\n", 920 | " Updated model's input parameters.\n", 921 | " \"\"\"\n", 922 | " optimizer = optax.lion(training_cfg.learning_rate)\n", 923 | " opt_state = optimizer.init(params['gnn'])\n", 924 | "\n", 925 | " avg_loss = 0\n", 926 | "\n", 927 | " placeholder_token_id = vocab.EncodeAsIds(PLACEHOLDER_TOKEN)\n", 928 | " assert (\n", 929 | " len(placeholder_token_id) == 1\n", 930 | " ), f'Placeholder token multiple ids: {placeholder_token_id}'\n", 931 | " placeholder_token_id = placeholder_token_id[0]\n", 932 | " # A first round of validation loss\n", 933 | " n_steps_eval = 0\n", 934 | " eval_loss = 0\n", 935 | "\n", 936 | " with tqdm.tqdm(range(training_cfg.num_epochs * len(train_ds))) as pbar:\n", 937 | " averaged_steps = 0\n", 938 | " for n_steps in pbar:\n", 939 | " train_example = train_ds[n_steps % len(train_ds)]\n", 940 | " train_loss, params, opt_state = train_step(\n", 941 | " gnn=gnn,\n", 942 | " llm=llm,\n", 943 | " params=params,\n", 944 | " optimizer=optimizer,\n", 945 | " opt_state=opt_state,\n", 946 | " pad_id=vocab.pad_id(),\n", 947 | " example=train_example,\n", 948 | " placeholder_token_id=placeholder_token_id,\n", 949 | " )\n", 950 | " averaged_steps += 1\n", 951 | " avg_loss += train_loss\n", 952 | " if n_steps and n_steps % training_cfg.eval_every_n == 0:\n", 953 | " val_iterator = validation_ds\n", 954 | " avg_loss /= averaged_steps\n", 955 | " averaged_steps = 0\n", 956 | " pbar.write(\n", 957 | " f'STEP {n_steps} training loss: {avg_loss}'\n", 958 | " )\n", 959 | " avg_loss = 0\n", 960 | " if (\n", 961 | " training_cfg.max_steps is not None\n", 962 | " and n_steps \u003e training_cfg.max_steps\n", 963 | " ):\n", 964 | " break\n", 965 | " if averaged_steps != 0:\n", 966 | " avg_loss /= averaged_steps\n", 967 | " 
pbar.write(\n", 968 | " f'STEP {n_steps} training loss: {avg_loss}'\n", 969 | " )\n", 970 | " return params\n", 971 | "\n", 972 | "def merge_params(llm_params, gnn_params):\n", 973 | " out = {}\n", 974 | " out.update(llm_params)\n", 975 | " out['gnn'] = gnn_params\n", 976 | " return out" 977 | ], 978 | "metadata": { 979 | "id": "RZKA-lU7IsZ8", 980 | "cellView": "form" 981 | }, 982 | "execution_count": null, 983 | "outputs": [] 984 | }, 985 | { 986 | "cell_type": "markdown", 987 | "source": [ 988 | "## Train a GraphToken Model" 989 | ], 990 | "metadata": { 991 | "id": "NSuoM8KbnXbF" 992 | } 993 | }, 994 | { 995 | "cell_type": "markdown", 996 | "source": [ 997 | "Load Gemma Weights" 998 | ], 999 | "metadata": { 1000 | "id": "oknoqsTMI78J" 1001 | } 1002 | }, 1003 | { 1004 | "cell_type": "code", 1005 | "source": [ 1006 | "params = params_lib.load_and_format_params(ckpt_path)\n", 1007 | "\n", 1008 | "# Reshard params over TPU device mesh\n", 1009 | "from jax.sharding import PartitionSpec as P\n", 1010 | "mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))\n", 1011 | "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", 1012 | "def try_to_shard(x):\n", 1013 | " try:\n", 1014 | " return jax.device_put(x, sharding)\n", 1015 | " except Exception:\n", 1016 | " return x\n", 1017 | "params = jax.tree.map(try_to_shard, params)\n", 1018 | "\n", 1019 | "\n", 1020 | "config_2b = transformer_lib.TransformerConfig.from_params(\n", 1021 | " params,\n", 1022 | " cache_size=128 # Number of time steps in the transformer's cache\n", 1023 | ")\n", 1024 | "model_2b = transformer_lib.Transformer(config=config_2b)\n", 1025 | "\n", 1026 | "# Load vocabulary\n", 1027 | "vocab = spm.SentencePieceProcessor()\n", 1028 | "assert vocab.Load(vocab_path)" 1029 | ], 1030 | "metadata": { 1031 | "id": "U1q6PiLSCTTa" 1032 | }, 1033 | "execution_count": null, 1034 | "outputs": [] 1035 | }, 1036 | { 1037 | "cell_type": "markdown", 1038 | "source": [ 1039 | "Generate some training data, for
the CycleCheck task" 1040 | ], 1041 | "metadata": { 1042 | "id": "NUDd2EisJHX_" 1043 | } 1044 | }, 1045 | { 1046 | "cell_type": "code", 1047 | "source": [ 1048 | "from talk_like_a_graph import graph_generators\n", 1049 | "from talk_like_a_graph import graph_tasks\n", 1050 | "random_seed = 9876\n", 1051 | "\n", 1052 | "train_graphs = graph_generators.generate_graphs(number_of_graphs=500,\n", 1053 | " algorithm='er', # Erdos-Renyi random graphs\n", 1054 | " directed=False,\n", 1055 | " random_seed=random_seed)\n", 1056 | "test_graphs = graph_generators.generate_graphs(number_of_graphs=10,\n", 1057 | " algorithm='er', # Erdos-Renyi random graphs\n", 1058 | " directed=False,\n", 1059 | " random_seed=random_seed + 12385)\n", 1060 | "task = graph_tasks.CycleCheck()\n", 1061 | "train_examples = list(task.prepare_examples_dict(\n", 1062 | " train_graphs,\n", 1063 | " generator_algorithms = ['er']*len(train_graphs),\n", 1064 | " encoding_method='adjacency').values())\n", 1065 | "test_examples = list(task.prepare_examples_dict(\n", 1066 | " test_graphs,\n", 1067 | " generator_algorithms = ['er']*len(test_graphs),\n", 1068 | " encoding_method='adjacency').values())\n", 1069 | "padder, train_ds = graphqa_ds(vocab, train_examples, max_tokens=25)\n", 1070 | "_, test_ds = graphqa_ds(vocab, test_examples, max_tokens=25, padder=padder)"
eval_every_n=250, batch_size=1\n", 1097 | ")\n", 1098 | "params_learned = train_loop(\n", 1099 | " llm=model_2b,\n", 1100 | " gnn=gin,\n", 1101 | " train_ds=train_ds,\n", 1102 | " validation_ds=test_ds,\n", 1103 | " params=merge_params(params, gnn_params),\n", 1104 | " training_cfg=train_config,\n", 1105 | " vocab=vocab,\n", 1106 | ")" 1107 | ], 1108 | "metadata": { 1109 | "id": "Uq3ZF3v7Ivo7" 1110 | }, 1111 | "execution_count": null, 1112 | "outputs": [] 1113 | }, 1114 | { 1115 | "cell_type": "markdown", 1116 | "source": [ 1117 | "Sample outputs" 1118 | ], 1119 | "metadata": { 1120 | "id": "4eaBMWhPchoJ" 1121 | } 1122 | }, 1123 | { 1124 | "cell_type": "code", 1125 | "source": [ 1126 | "from IPython.display import Markdown, display\n", 1127 | "\n", 1128 | "graph_token_sampler = GraphTokenSampler(\n", 1129 | " params=params_learned, llm=model_2b, gnn=gin, vocab=vocab\n", 1130 | ")\n", 1131 | "\n", 1132 | "\n", 1133 | "def get_graph_qa_question(ex):\n", 1134 | " with_graph = ex['question']\n", 1135 | " q_index = with_graph.find('Q:')\n", 1136 | " return with_graph[q_index:]\n", 1137 | "\n", 1138 | "for i in range(len(test_examples)):\n", 1139 | " tokenized_input = test_ds[i]\n", 1140 | " ex = test_examples[i]\n", 1141 | " prompt = get_graph_qa_question(ex)\n", 1142 | " if i == 0:\n", 1143 | " display(\n", 1144 | " Markdown(\n", 1145 | " '**Prompt:** '\n", 1146 | " + prompt\n", 1147 | " + '\\n\\n'\n", 1148 | " )\n", 1149 | " )\n", 1150 | " llm_output = graph_token_sampler(\n", 1151 | " [prompt],\n", 1152 | " tokenized_input.input_graphs,\n", 1153 | " total_generation_steps=15,\n", 1154 | " return_logits=False,\n", 1155 | " ).text[0]\n", 1156 | " display(Markdown(f'**LLM Output:** \"{llm_output}\"'))\n", 1157 | " display(\n", 1158 | " Markdown(f\"**Ground Truth:** {ex['answer']}\")\n", 1159 | " )\n", 1160 | " display(Markdown('-' * 80))\n", 1161 | " print()\n", 1162 | "\n" 1163 | ], 1164 | "metadata": { 1165 | "id": "0p4CxB2Lcilw" 1166 | }, 1167 | "execution_count": 
null, 1168 | "outputs": [] 1169 | }, 1170 | { 1171 | "cell_type": "markdown", 1172 | "source": [ 1173 | "## Exercise: Train a model for a different task.\n", 1174 | "\n", 1175 | "Take the above code and modify it for the NodeCount task.\n", 1176 | "How does your model perform?\n", 1177 | "\n", 1178 | "If your kernel runs out of memory run the following code to clear the TPU memory, then re-run the code block labeled 'Load Gemma weights' and retry.\n", 1179 | "```\n", 1180 | "for a in jax.live_arrays():\n", 1181 | " a.delete()\n", 1182 | "```" 1183 | ], 1184 | "metadata": { 1185 | "id": "uuW3rA3-qKy9" 1186 | } 1187 | } 1188 | ] 1189 | } 1190 | -------------------------------------------------------------------------------- /tutorial/README.md: -------------------------------------------------------------------------------- 1 | #

Tutorial on Graph Reasoning with LLMs (GReaL) 2 | 3 | ## **KDD'24 Tutorial** 4 | 5 | ### Sunday, August 25th, 2024 6 | 7 | Centre de Convencions Internacional de Barcelona 8 | P1 122-123 9 | 2:00 pm - 5:00 pm CEST
10 | 11 | ### Overview 12 | 13 | Graphs are a powerful tool for representing and analyzing complex relationships 14 | in real-world applications. Large Language Models (LLMs) have demonstrated 15 | impressive capabilities by advancing the state of the art on many language-based 16 | benchmarks. Their ability to process and understand natural language opens 17 | exciting possibilities in various domains. Despite the remarkable progress in 18 | automated reasoning with natural text, reasoning on graphs with LLMs remains an 19 | understudied problem that has recently gained more attention.  20 | 21 | This tutorial builds upon recent advances in expressing reasoning problems 22 | through the lens of tasks on graph data. The first part of the tutorial will 23 | provide an in-depth discussion of techniques for representing graphs as inputs 24 | to LLMs. The second, hands-on, portion will demonstrate these techniques in a 25 | practical setting. 26 | 27 | ### Schedule 28 | 29 | This talk is being given at [KDD 2024](https://kdd2024.kdd.org/) at 2pm in room 30 | P1 122-123 at the Centre de Convencions Internacional de Barcelona.
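The first part of the tutorial centers on serializing a graph into text an LLM can consume. A minimal sketch of that idea in plain Python (the function name and exact prompt phrasing below are illustrative only, not the API of the repo's `graph_text_encoders` module):

```python
def encode_graph_as_text(nodes, edges):
    """Serialize a graph as plain text, in the spirit of an 'adjacency'
    style encoding (hypothetical helper, for illustration only)."""
    preamble = ("In an undirected graph, (i,j) means that node i and node j "
                "are connected with an undirected edge.")
    node_list = ", ".join(str(n) for n in nodes)
    edge_list = " ".join(f"({u},{v})" for u, v in edges)
    return (f"{preamble}\nG describes a graph among nodes {node_list}.\n"
            f"The edges in G are: {edge_list}.")


# A 4-cycle, encoded for the LLM.
prompt = encode_graph_as_text(nodes=range(4),
                              edges=[(0, 1), (1, 2), (2, 3), (3, 0)])
print(prompt)
```

Appending a question such as `Q: Is there a cycle in this graph? A:` to the encoded graph yields a CycleCheck-style prompt like those built by the notebooks.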
31 | 32 | ### Tutorial Material 33 | 34 | The slides and notebooks used in the tutorial are available here: 35 | 36 | - [Tutorial Slides](https://drive.google.com/file/d/16rHZVyCeGDfY3djVafKHFPhnabOVtd03/) 37 | 38 | - [Talk Like a Graph Notebook](https://github.com/google-research/talk-like-a-graph/blob/main/tutorial/KDD-Tutorial-1-Talk-Like-a-Graph.ipynb) 39 | 40 | - [GraphToken (Let Your Graph Do the Talking) Notebook](https://github.com/google-research/talk-like-a-graph/blob/main/tutorial/KDD-Tutorial-2-Let-Your-Graph-Do-The-Talking.ipynb) 41 | 42 | 43 | 44 | ### Additional Resources 45 | 46 | *coming soon* 47 | -------------------------------------------------------------------------------- /tutorial/imgs/let_your_graph_do_the_talking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/talk-like-a-graph/36af51e19ef7a44049d64306e3cae56c07067e81/tutorial/imgs/let_your_graph_do_the_talking.png -------------------------------------------------------------------------------- /tutorial/imgs/talk_like_a_graph_colab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/talk-like-a-graph/36af51e19ef7a44049d64306e3cae56c07067e81/tutorial/imgs/talk_like_a_graph_colab.png --------------------------------------------------------------------------------