├── docs ├── requirements.txt ├── logo.png ├── Makefile ├── make.bat └── source │ ├── 010 │ ├── install_and_setup.rst │ └── tutorial.rst │ ├── 030 │ ├── llm_baseline.rst │ ├── random_search.rst │ ├── bayesian_optimization.rst │ ├── evolutionary_algorithm.rst │ ├── reinforcement_learning.rst │ ├── benchmark_dataset.rst │ ├── greedy_search.rst │ ├── run_benchmark.rst │ └── micro_action.rst │ ├── index.rst │ ├── conf.py │ └── 020 │ ├── download_precomputed.rst │ └── reproduce_datasets.rst ├── rdb2g_bench ├── common │ ├── __init__.py │ ├── text_embedder.py │ └── search_space │ │ ├── search_space.py │ │ ├── row2ne_search_space.py │ │ └── gnn_search_space.py ├── benchmark │ ├── llm │ │ ├── __init__.py │ │ ├── prompts │ │ │ ├── task.json │ │ │ ├── action.json │ │ │ └── prompt.py │ │ ├── llm_micro_action.py │ │ └── llm_utils.py │ ├── __init__.py │ ├── bench_runner.py │ └── baselines │ │ ├── random.py │ │ ├── utils.py │ │ └── ea.py ├── dataset │ ├── __init__.py │ ├── utils.py │ ├── models │ │ ├── modules │ │ │ ├── sage_conv_edge.py │ │ │ ├── nn.py │ │ │ ├── gps_conv.py │ │ │ └── gin_conv.py │ │ └── model.py │ └── dataset.py └── __init__.py ├── .gitignore ├── .readthedocs.yaml ├── MANIFEST.in ├── requirements.txt ├── examples ├── run_benchmark.py ├── run_gnn_node.py ├── run_idgnn_link.py ├── load_dataset.py └── download_dataset.py ├── LICENSE ├── pyproject.toml └── README.md /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==7.1.2 2 | sphinx-rtd-theme==1.3.0rc1 3 | -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chlehdwon/RDB2G-Bench/HEAD/docs/logo.png -------------------------------------------------------------------------------- /rdb2g_bench/common/__init__.py: -------------------------------------------------------------------------------- 1 | from . import search_space 2 | 3 | __all__ = ["search_space"] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | __pycache__ 3 | logs/ 4 | results/ 5 | rdb2g_bench.egg-info 6 | build/ -------------------------------------------------------------------------------- /rdb2g_bench/benchmark/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .llm_runner import run_llm_baseline 2 | 3 | __all__ = [ 4 | "run_llm_baseline" 5 | ] -------------------------------------------------------------------------------- /rdb2g_bench/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | from . import baselines 2 | from . import llm 3 | from .bench_runner import run_benchmark 4 | 5 | __all__ = [ 6 | "baselines", 7 | "llm", 8 | "run_benchmark" 9 | ] -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.10" 7 | 8 | python: 9 | install: 10 | - requirements: docs/requirements.txt 11 | - method: pip 12 | path: . 13 | 14 | sphinx: 15 | configuration: docs/source/conf.py -------------------------------------------------------------------------------- /rdb2g_bench/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from . import models 2 | from .node_worker import run_gnn_node_worker 3 | from .link_worker import run_idgnn_link_worker 4 | from .utils import integrate_edge_tf 5 | 6 | __all__ = [ 7 | "models", 8 | "run_gnn_node_worker", 9 | "run_idgnn_link_worker", 10 | "integrate_edge_tf" 11 | ] -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements.txt 4 | include pyproject.toml 5 | recursive-include rdb2g_bench *.py 6 | recursive-include rdb2g_bench *.txt 7 | recursive-include rdb2g_bench *.md 8 | recursive-include rdb2g_bench *.json 9 | recursive-exclude rdb2g_bench __pycache__ 10 | recursive-exclude rdb2g_bench *.pyc 11 | recursive-exclude rdb2g_bench *.pyo -------------------------------------------------------------------------------- /rdb2g_bench/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | __author__ = "Dongwon Choi" 3 | __email__ = "chlehdwon@kaist.ac.kr" 4 | 5 | # Import main modules for easy access 6 | from . import benchmark 7 | from . import common 8 | from . import dataset 9 | 10 | __all__ = [ 11 | "__version__", 12 | "__author__", 13 | "__email__", 14 | "benchmark", 15 | "common", 16 | "dataset", 17 | ] -------------------------------------------------------------------------------- /rdb2g_bench/common/text_embedder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | 5 | # Please run `pip install -U sentence-transformers` 6 | from sentence_transformers import SentenceTransformer 7 | from torch import Tensor 8 | 9 | 10 | class GloveTextEmbedding: 11 | def __init__(self, device: Optional[torch.device] = None): 12 | self.model = SentenceTransformer( 13 | "sentence-transformers/average_word_embeddings_glove.6B.300d", 14 | device=device, 15 | ) 16 | 17 | def __call__(self, sentences: List[str]) -> Tensor: 18 | return self.model.encode(sentences, convert_to_tensor=True) 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core dependencies 2 | torch>=2.1.0 3 | torch_geometric==2.6.1 4 | numpy>=1.26.0 5 | pandas>=2.2.0 6 | scikit-learn>=1.6.0 7 | sentence-transformers>=3.3.1 8 | 9 | # RelBench 10 | relbench==1.1.0 11 | pytorch-frame>=0.2.5 12 | 13 | # Hugging Face datasets 14 | datasets>=2.15.0 15 | 16 | # Plotting and visualization 17 | matplotlib>=3.10.0 18 | seaborn>=0.13.0 19 | networkx>=3.1 20 | 21 | # Utilities 22 | tqdm>=4.65.0 23 | 24 | # LLM dependencies 25 | anthropic>=0.25.0 26 | 27 | # PyTorch Geometric dependencies 28 | # Install separately: pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html 29 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /examples/run_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["ANTHROPIC_API_KEY"] = "YOUR_API_KEY" 3 | 4 | from rdb2g_bench.benchmark.bench_runner import run_benchmark 5 | from rdb2g_bench.benchmark.llm.llm_runner import run_llm_baseline 6 | 7 | # Example 1: Basic run 8 | run_benchmark( 9 | dataset="rel-f1", 10 | task="driver-top3", 11 | gnn="GraphSAGE", 12 | budget_percentage=0.05, 13 | method="all", 14 | num_runs=1, 15 | seed=0, 16 | ) 17 | 18 | # Example 2: Run only for specific methods (Currently, Greedy, BO, RL, and EA are supported) 19 | run_benchmark( 20 | dataset="rel-f1", 21 | task="driver-top3", 22 | gnn="GraphSAGE", 23 | budget_percentage=0.05, 24 | method=["rl", "bo"], 25 | num_runs=1, 26 | seed=0, 27 | ) 28 | 29 | # Example 3: Run LLM-based baseline 30 | run_llm_baseline( 31 | dataset="rel-f1", 32 | task="driver-top3", 33 | gnn="GraphSAGE", 34 | budget_percentage=0.05, 35 | seed=0, 36 | ) -------------------------------------------------------------------------------- /docs/source/010/install_and_setup.rst: -------------------------------------------------------------------------------- 1 | Install and Setup 2 | ================== 3 | 4 | Installation 5 | ------------ 6 | 7 | Clone the repository and install the package: 8 | 9 | .. code-block:: bash 10 | 11 | git clone https://github.com/chlehdwon/RDB2G-Bench.git 12 | cd RDB2G-Bench 13 | pip install -e . 14 | 15 | PyTorch Geometric Dependencies 16 | ------------------------------ 17 | 18 | Install additional PyG dependencies. The example below shows installation for torch 2.1.0 + cuda 12.1: 19 | 20 | .. code-block:: bash 21 | 22 | pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html 23 | 24 | .. note:: 25 | You can skip this step if you don't want to reproduce the datasets. 26 | 27 | Environment Setup for LLM usage 28 | -------------------------------- 29 | 30 | For LLM-based baselines, set up your API key: 31 | 32 | .. code-block:: bash 33 | 34 | export ANTHROPIC_API_KEY="YOUR_API_KEY" 35 | -------------------------------------------------------------------------------- /examples/run_gnn_node.py: -------------------------------------------------------------------------------- 1 | from rdb2g_bench.dataset.node_worker import run_gnn_node_worker 2 | 3 | # Example 1: Basic run 4 | results = run_gnn_node_worker( 5 | dataset_name="rel-f1", 6 | task_name="driver-top3", 7 | gnn="GraphSAGE", 8 | ) 9 | 10 | print(results['processed_graphs']) # [0, ..., 721] 11 | print(results['total_processed']) # 722 12 | print(results['csv_file']) # ./results/tables/rel-f1/driver-top3/{tag}/GraphSAGE/42.csv 13 | 14 | # Example 2: Run parallelly on multiple GPUs 15 | results_even = run_gnn_node_worker( 16 | dataset_name="rel-f1", 17 | task_name="driver-top3", 18 | gnn="GraphSAGE", 19 | idx=0, 20 | workers=2, 21 | device="cuda:0" 22 | ) 23 | 24 | results_odd = run_gnn_node_worker( 25 | dataset_name="rel-f1", 26 | task_name="driver-top3", 27 | gnn="GraphSAGE", 28 | idx=1, 29 | workers=2, 30 | device="cuda:1" 31 | ) 32 | 33 | # Example 3: Run worker on specific GNN and target indices 34 | results_0 = run_gnn_node_worker( 35 | dataset_name="rel-f1", 36 | task_name="driver-top3", 37 | gnn="GPS", 38 | target_indices=[0] 39 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Dongwon Choi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/run_idgnn_link.py: -------------------------------------------------------------------------------- 1 | from rdb2g_bench.dataset.link_worker import run_idgnn_link_worker 2 | 3 | # Example 1: Basic run 4 | results = run_idgnn_link_worker( 5 | dataset_name="rel-avito", 6 | task_name="user-ad-visit", 7 | gnn="GraphSAGE", 8 | ) 9 | 10 | print(results['processed_graphs']) # [0, ..., 908] 11 | print(results['total_processed']) # 909 12 | print(results['csv_file']) # ./results/tables/rel-avito/user-ad-visit/{tag}/GraphSAGE/42.csv 13 | 14 | # Example 2: Run parallelly on multiple GPUs 15 | results_even = run_idgnn_link_worker( 16 | dataset_name="rel-avito", 17 | task_name="user-ad-visit", 18 | gnn="GraphSAGE", 19 | idx=0, 20 | workers=2, 21 | device="cuda:0" 22 | ) 23 | 24 | results_odd = run_idgnn_link_worker( 25 | dataset_name="rel-avito", 26 | task_name="user-ad-visit", 27 | gnn="GraphSAGE", 28 | idx=1, 29 | workers=2, 30 | device="cuda:1" 31 | ) 32 | 33 | # Example 3: Run worker on specific target indices 34 | results_0 = run_idgnn_link_worker( 35 | dataset_name="rel-avito", 36 | task_name="user-ad-visit", 37 | gnn="GraphSAGE", 38 | target_indices=[0] 39 | ) -------------------------------------------------------------------------------- /docs/source/030/llm_baseline.rst: -------------------------------------------------------------------------------- 1 | LLM-based Baseline 2 | ================== 3 | 4 | This module implements Large Language Model (LLM) based baseline. 5 | 6 | How it Works 7 | ------------ 8 | 9 | The LLM-based baseline process: 10 | 11 | 1. Provide the LLM with problem description and search space 12 | 2. LLM suggests promising micro actions for graph models 13 | 3. Evaluate suggested micro actions on the actual task 14 | 4. Provide feedback to the LLM about performance 15 | 5. LLM refines its suggestions based on feedback 16 | 6. Repeat until the budget is exhausted 17 | 18 | LLM-Based Baseline 19 | ------------------ 20 | 21 | .. automodule:: rdb2g_bench.benchmark.llm.llm_runner 22 | :members: 23 | :undoc-members: 24 | :show-inheritance: 25 | 26 | Example Usage 27 | ~~~~~~~~~~~~~ 28 | 29 | .. code-block:: python 30 | 31 | from rdb2g_bench.benchmark.llm.llm_runner import run_llm_baseline 32 | 33 | # Run LLM-based baseline 34 | results = run_llm_baseline( 35 | dataset="rel-f1", 36 | task="driver-top3", 37 | gnn="GraphSAGE", 38 | budget_percentage=0.05, 39 | seed=42, 40 | model="claude-3-5-sonnet-latest", 41 | temperature=0.8, 42 | ) 43 | 44 | -------------------------------------------------------------------------------- /rdb2g_bench/benchmark/llm/prompts/task.json: -------------------------------------------------------------------------------- 1 | { 2 | "driver-dnf": "For each driver predict the if they will DNF (did not finish) a race in the next 1 month.", 3 | "driver-top3": "For each driver predict if they will qualify in the top-3 for a race in the next 1 month.", 4 | "driver-position": "Predict the average finishing position of each driver all races in the next 2 months.", 5 | "user-repeat": "Predict whether a user will attend an event(by responding yes or maybe) in the next 7 days if they have already attended an event in the last 14 days.", 6 | "user-ignore": "Predict whether a user will ignore more than 2 event invitations in the next 7 days", 7 | "user-attendance": "Predict how many events each user will respond yes or maybe in the next seven days.", 8 | "user-clicks": "Predict whether each customer will click on more than one Ads in the next 4 days.", 9 | "user-visits": "Predict whether each customer will visit more than one Ad in the next 4 days.", 10 | "ad-ctr": "Assuming the Ad will be clicked in the next 4 days, predict the Click-Through-Rate (CTR) for each Ad.", 11 | "user-ad-visit":"Predict the list of ads a user will visit in the next 4 days.", 12 | "post-post-related" :"Predict a list of existing posts that users will link a given post to in the next two years.", 13 | "study-outcome":"Predict if the trials will achieve its primary outcome (defined as p-value < 0.05)." 14 | } 15 | 16 | -------------------------------------------------------------------------------- /docs/source/030/random_search.rst: -------------------------------------------------------------------------------- 1 | Random Search 2 | ============= 3 | 4 | This module implements random search baseline for RDB-to-Graph modeling search. 5 | Random search provides a simple yet effective baseline by uniformly sampling graph models 6 | from the search space without any heuristic guidance, serving as an important comparison 7 | point for evaluating more sophisticated optimization algorithms. 8 | 9 | How it Works 10 | ------------ 11 | 12 | The random search algorithm process: 13 | 14 | 1. Uniformly sample graph models from the entire search space 15 | 2. Evaluate the performance of each sampled graph model 16 | 3. Keep track of the best graph model found so far 17 | 4. Repeat until the budget is exhausted 18 | 19 | Random Search Baseline 20 | ---------------------- 21 | 22 | .. automodule:: rdb2g_bench.benchmark.baselines.random 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | Example Usage 28 | ~~~~~~~~~~~~~ 29 | 30 | .. code-block:: python 31 | 32 | from rdb2g_bench.benchmark.runner import run_benchmark 33 | 34 | # Run random search baseline 35 | results = run_benchmark( 36 | dataset="rel-f1", 37 | task="driver-top3", 38 | gnn="GraphSAGE", 39 | budget_percentage=0.05, 40 | method=["random"], 41 | num_runs=10, 42 | seed=42 43 | ) 44 | 45 | # Access results 46 | print(f"Best architecture found: {results['random']['selected_graph_id']}") 47 | print(f"Performance: {results['random']['actual_y_perf_of_selected']:.4f}") 48 | print(f"Architectures evaluated: {results['random']['discovered_count']}") -------------------------------------------------------------------------------- /docs/source/010/tutorial.rst: -------------------------------------------------------------------------------- 1 | Tutorial 2 | ======== 3 | 4 | Basic Usage 5 | ----------- 6 | 7 | This tutorial provides a quick overview of RDB2G-Bench. 8 | 9 | Quick Start Example 10 | ------------------- 11 | 12 | Here's a simple example to get you started: 13 | 14 | .. code-block:: python 15 | 16 | from rdb2g_bench.dataset.dataset import load_rdb2g_bench 17 | 18 | # Load pre-computed benchmark results 19 | bench = load_rdb2g_bench("./results") 20 | 21 | # Access specific results 22 | result = bench['rel-f1']['driver-top3'][0] 23 | test_metric = result['test_metric'] 24 | params = result['params'] 25 | train_time = result['train_time'] 26 | 27 | print(f"Test Metric: {test_metric}") 28 | print(f"Training Time: {train_time}") 29 | 30 | Run First Benchmark 31 | ----------------------------- 32 | 33 | .. code-block:: python 34 | 35 | from rdb2g_bench.benchmark.runner import run_benchmark 36 | 37 | # Run a simple benchmark 38 | results = run_benchmark( 39 | dataset="rel-f1", 40 | task="driver-top3", 41 | budget_percentage=0.05, 42 | method="all", 43 | num_runs=3, 44 | seed=42 45 | ) 46 | 47 | For more detailed examples, check the ``examples/`` directory in the repository. 48 | 49 | Package Structure Overview 50 | -------------------------- 51 | 52 | .. code-block:: text 53 | 54 | rdb2g_bench/ 55 | ├── benchmark/ # Core benchmarking functionality 56 | ├── common/ # Shared utilities and search spaces 57 | ├── dataset/ # Dataset loading and processing 58 | └── __init__.py # Package initialization -------------------------------------------------------------------------------- /docs/source/030/bayesian_optimization.rst: -------------------------------------------------------------------------------- 1 | Bayesian Optimization 2 | ===================== 3 | 4 | This module implements Bayesian optimization baseline for RDB-to-Graph modeling search. 5 | Bayesian optimization is a sequential model-based optimization technique particularly 6 | effective for expensive black-box optimization problems like finding optimal graph models. 7 | 8 | How it Works 9 | ------------ 10 | 11 | The Bayesian optimization algorithm process: 12 | 13 | 1. Build a probabilistic model (surrogate) of the objective function 14 | 2. Use acquisition function to determine next graph model to evaluate 15 | 3. Evaluate the objective function at the selected graph model 16 | 4. Update the surrogate model with new observation 17 | 5. Repeat until budget is exhausted 18 | 19 | Bayesian Optimization Baseline 20 | ------------------------------ 21 | 22 | .. automodule:: rdb2g_bench.benchmark.baselines.bo 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | Example Usage 28 | ~~~~~~~~~~~~~ 29 | 30 | .. code-block:: python 31 | 32 | from rdb2g_bench.benchmark.runner import run_benchmark 33 | 34 | # Run Bayesian optimization with default parameters 35 | results = run_benchmark( 36 | dataset="rel-f1", 37 | task="driver-top3", 38 | gnn="GraphSAGE", 39 | budget_percentage=0.05, 40 | method=["bo"], 41 | num_runs=10, 42 | seed=42 43 | ) 44 | 45 | # Access results 46 | print(f"Best architecture found: {results['bo']['selected_graph_id']}") 47 | print(f"Performance: {results['bo']['actual_y_perf_of_selected']:.4f}") 48 | print(f"Rank: {results['bo']['rank_position_overall']}") 49 | -------------------------------------------------------------------------------- /docs/source/030/evolutionary_algorithm.rst: -------------------------------------------------------------------------------- 1 | Evolutionary Algorithm 2 | ====================== 3 | 4 | This module implements evolutionary algorithm baseline for RDB-to-Graph modeling search. 5 | Evolutionary algorithms are population-based metaheuristics that use biological evolution 6 | mechanisms such as reproduction, mutation, and selection to find optimal solutions. 7 | 8 | How it Works 9 | ------------ 10 | 11 | The evolutionary algorithm process: 12 | 13 | 1. Initialize a population of random graph models 14 | 2. Evaluate the performance of each graph model 15 | 3. Select parents for reproduction 16 | 4. Apply crossover and mutation operators to the selected graph models 17 | 5. Replace old population with new generation 18 | 6. Repeat until budget is exhausted 19 | 20 | Evolutionary Algorithm Baseline 21 | ------------------------------- 22 | 23 | .. automodule:: rdb2g_bench.benchmark.baselines.ea 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | Example Usage 29 | ~~~~~~~~~~~~~ 30 | 31 | .. code-block:: python 32 | 33 | from rdb2g_bench.benchmark.runner import run_benchmark 34 | 35 | # Run evolutionary algorithm with default parameters 36 | results = run_benchmark( 37 | dataset="rel-f1", 38 | task="driver-top3", 39 | gnn="GraphSAGE", 40 | budget_percentage=0.05, 41 | method=["ea"], 42 | num_runs=10, 43 | seed=42 44 | ) 45 | 46 | # Access results 47 | print(f"Best architecture found: {results['ea']['selected_graph_id']}") 48 | print(f"Performance: {results['ea']['actual_y_perf_of_selected']:.4f}") 49 | print(f"Generations completed: {results['ea']['total_iterations_run']}") 50 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to RDB2G-Bench Documentation! 2 | ======================================== 3 | 4 | This is the official documentation of the paper **RDB2G-Bench: A Comprehensive Benchmark for Automatic Graph Modeling of Relational Databases.** 5 | 6 | **RDB2G-Bench** is an *easy-to-use framework* for benchmarking graph-based analysis and prediction tasks by converting relational database data into graphs. 7 | 8 | Contents 9 | -------- 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :caption: Get Started 14 | 15 | 010/install_and_setup 16 | 010/tutorial 17 | 18 | .. toctree:: 19 | :maxdepth: 2 20 | :caption: Datasets 21 | 22 | 020/download_precomputed 23 | 020/reproduce_datasets 24 | 25 | .. toctree:: 26 | :maxdepth: 2 27 | :caption: Benchmarks 28 | 29 | 030/benchmark_dataset 30 | 030/micro_action 31 | 030/run_benchmark 32 | 030/random_search 33 | 030/greedy_search 34 | 030/evolutionary_algorithm 35 | 030/bayesian_optimization 36 | 030/reinforcement_learning 37 | 030/llm_baseline 38 | 39 | Citation 40 | -------- 41 | 42 | If you use RDB2G-Bench in your research, please cite: 43 | 44 | .. code-block:: bibtex 45 | 46 | @article{choi2025rdb2gbench, 47 | title={RDB2G-Bench: A Comprehensive Benchmark for Automatic Graph Modeling of Relational Databases}, 48 | author={Dongwon Choi and Sunwoo Kim and Juyeon Kim and Kyungho Kim and Geon Lee and Shinhwan Kang and Myunghwan Kim and Kijung Shin}, 49 | year={2025}, 50 | url={https://arxiv.org/abs/2506.01360}, 51 | } 52 | 53 | License 54 | ------- 55 | 56 | This project is distributed under the MIT License as specified in the LICENSE file. 57 | 58 | 59 | -------------------------------------------------------------------------------- /docs/source/030/reinforcement_learning.rst: -------------------------------------------------------------------------------- 1 | Reinforcement Learning 2 | ====================== 3 | 4 | This module implements reinforcement learning baseline for RDB-to-Graph modeling search. 5 | The approach uses deep reinforcement learning with policy gradients to train 6 | a recurrent neural network controller that learns to generate sequences of micro actions 7 | for constructing high-performing graph models. 8 | 9 | How it Works 10 | ------------ 11 | 12 | The reinforcement learning algorithm process: 13 | 14 | 1. Initialize RNN-based controller to learn micro actions for graph model optimization 15 | 2. Start from random graph model and convert to state embedding 16 | 3. Controller selects actions 17 | 4. Evaluate the performance of new graph model and compute reward 18 | 5. Update controller policy using Policy Gradient based on discounted rewards 19 | 6. Repeat episodes until the budget is exhausted 20 | 21 | Reinforcement Learning Baseline 22 | ------------------------------- 23 | 24 | .. automodule:: rdb2g_bench.benchmark.baselines.rl 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | 29 | Example Usage 30 | ~~~~~~~~~~~~~ 31 | 32 | .. code-block:: python 33 | 34 | from rdb2g_bench.benchmark.runner import run_benchmark 35 | 36 | # Run reinforcement learning with default parameters 37 | results = run_benchmark( 38 | dataset="rel-f1", 39 | task="driver-top3", 40 | gnn="GraphSAGE", 41 | budget_percentage=0.05, 42 | method=["rl"], 43 | num_runs=10, 44 | seed=42 45 | ) 46 | 47 | # Access results 48 | print(f"Best architecture found: {results['rl']['selected_graph_id']}") 49 | print(f"Performance: {results['rl']['actual_y_perf_of_selected']:.4f}") 50 | print(f"Episodes completed: {results['rl']['total_iterations_run']}") 51 | -------------------------------------------------------------------------------- /docs/source/030/benchmark_dataset.rst: -------------------------------------------------------------------------------- 1 | Performance Prediction Dataset 2 | =============================== 3 | 4 | This module implements the ``PerformancePredictionDataset`` class, which is used to load performance data for benchmarking RDB-to-Graph search algorithms on RDB2G-Bench. 5 | 6 | The ``PerformancePredictionDataset`` class is the core data interface for RDB2G-Bench. It loads pre-computed performance results from CSV files, processes graph configurations, and provides a unified interface for benchmark algorithms to access performance data. 7 | 8 | How it Works 9 | ------------ 10 | 11 | The dataset loading process involves the following steps: 12 | 13 | 1. **RelBench Integration**: Connects to the specified RelBench dataset and task 14 | 2. **Graph Materialization**: Creates heterogeneous graph data with proper embeddings 15 | 3. **Results Loading**: Reads performance data from CSV files for the specified GNN backbone architectures 16 | 4. **Data Aggregation**: Groups results by graph models and aggregates across seeds 17 | 5. **Graph Indexing**: Creates mappings between graph models and search space 18 | 19 | Performance Prediction Dataset 20 | ------------------------------- 21 | 22 | .. automodule:: rdb2g_bench.benchmark.dataset 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | Example Usage 28 | ------------- 29 | 30 | .. code-block:: python 31 | 32 | from rdb2g_bench.benchmark.dataset import PerformancePredictionDataset 33 | 34 | # Initialize dataset 35 | dataset = PerformancePredictionDataset( 36 | dataset_name="rel-f1", 37 | task_name="driver-top3", 38 | gnn="GraphSAGE", 39 | tag="hf", 40 | result_dir="./results" 41 | ) 42 | 43 | # Access basic information 44 | print(f"Number of configurations: {len(dataset)}") 45 | print(f"Target metric: {dataset.target_col}") 46 | print(f"Task type: {dataset.task.task_type}") 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /docs/source/030/greedy_search.rst: -------------------------------------------------------------------------------- 1 | Greedy Search 2 | ============= 3 | 4 | This module implements multiple greedy search strategies for RDB-to-Graph modeling search. 5 | Greedy algorithms make locally optimal choices at each step, providing fast and deterministic 6 | approaches for finding good graph models with minimal computational overhead. 7 | 8 | How it Works 9 | ------------ 10 | 11 | The greedy search algorithm implements three different greedy strategies for graph model optimization: 12 | 13 | 1. **Forward Greedy**: Starts with a graph model only with target table(s) and iteratively moves to the best graph model 14 | 2. **Backward Greedy**: Starts with the All-Rows-to-Nodes(AR2N) graph model and iteratively moves to the best graph model 15 | 3. **Local Greedy**: Combines forward and backward greedy strategies with randomly initialized graph model 16 | 17 | The greedy search algorithm process: 18 | 19 | 1. Starts with a random graph model 20 | 2. Evaluates the performance of the graph model 21 | 3. Selects the best local improvement based on the chosen greedy strategy 22 | 4. Repeats until the budget is exhausted 23 | 24 | 25 | Greedy Search Baseline 26 | ---------------------- 27 | 28 | .. automodule:: rdb2g_bench.benchmark.baselines.greedy 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | Example Usage 34 | ~~~~~~~~~~~~~ 35 | 36 | .. code-block:: python 37 | 38 | from rdb2g_bench.benchmark.runner import run_benchmark 39 | 40 | # Run greedy search by default 41 | results = run_benchmark( 42 | dataset="rel-f1", 43 | task="driver-top3", 44 | gnn="GraphSAGE", 45 | budget_percentage=0.05, 46 | method=["greedy"], 47 | num_runs=10, 48 | seed=42 49 | ) 50 | 51 | # Access results 52 | print(f"Best architecture found: {results['greedy']['selected_graph_id']}") 53 | print(f"Performance: {results['greedy']['actual_y_perf_of_selected']:.4f}") 54 | print(f"Greedy steps completed: {results['greedy']['total_iterations_run']}") 55 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "rdb2g-bench" 7 | version = "0.1.1" 8 | description = "A benchmark framework for automatic graph modeling of relational databases" 9 | readme = "README.md" 10 | license = {file = "LICENSE"} 11 | authors = [ 12 | {name = "RDB2G-Bench Team", email = "chlehdwon@kaist.ac.kr"} 13 | ] 14 | maintainers = [ 15 | {name = "RDB2G-Bench Team", email = "chlehdwon@kaist.ac.kr"} 16 | ] 17 | keywords = ["graph neural networks", "relational databases", "benchmark", "machine learning"] 18 | classifiers = [ 19 | "Development Status :: 3 - Alpha", 20 | "Intended Audience :: Developers", 21 | "Intended Audience :: Science/Research", 22 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 23 | "License :: OSI Approved :: MIT License", 24 | "Programming Language :: Python :: 3", 25 | "Programming Language :: Python :: 3.7", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Operating System :: OS Independent", 31 | ] 32 | requires-python = ">=3.7" 33 | dependencies = [ 34 | "torch>=2.1.0", 35 | "numpy>=1.26.0", 36 | "pandas>=2.2.0", 37 | "scikit-learn>=1.6.0", 38 | "pytorch-frame>=0.2.5", 39 | "torch_geometric==2.6.1", 40 | "relbench==1.1.0", 41 | "sentence-transformers>=3.3.1", 42 | "matplotlib>=3.10.0", 43 | "seaborn>=0.13.0", 44 | "networkx>=3.1", 45 | "tqdm>=4.65.0", 46 | "anthropic>=0.25.0", 47 | "datasets>=3.6.0" 48 | ] 49 | 50 | [project.urls] 51 | Homepage = "https://github.com/chlehdwon/RDB2G-Bench" 52 | Repository = "https://github.com/chlehdwon/RDB2G-Bench" 53 | Issues = "https://github.com/chlehdwon/RDB2G-Bench/issues" 54 | 55 | [tool.setuptools.packages.find] 56 | include = ["rdb2g_bench*"] 57 | 58 | [tool.setuptools.package-data] 59 | rdb2g_bench = ["*.txt", "*.md"] -------------------------------------------------------------------------------- /rdb2g_bench/common/search_space/search_space.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from torch_geometric.data import HeteroData 7 | from .gnn_search_space import GNNNodeSearchSpace, GNNLinkSearchSpace, IDGNNLinkSearchSpace 8 | from .row2ne_search_space import Row2NESearchSpace 9 | 10 | class TotalSearchSpace(): 11 | def __init__(self, 12 | dataset: str, 13 | task: str, 14 | hetero_data: HeteroData, 15 | GNNSearchSpace: Union[GNNNodeSearchSpace, GNNLinkSearchSpace, IDGNNLinkSearchSpace], 16 | num_layers: int, 17 | src_entity_table: str, 18 | dst_entity_table: str = None 19 | ): 20 | self.dataset = dataset 21 | self.task = task 22 | self.data = hetero_data 23 | self.row2graph = Row2NESearchSpace(dataset, task, hetero_data) 24 | self.num_layers = num_layers 25 | self.src_entity_table = src_entity_table 26 | self.dst_entity_table = dst_entity_table 27 | self.full_edges = self.row2graph.find_full_edges() 28 | self.gnn_search_space = GNNSearchSpace( 29 | dataset=self.dataset, 30 | task=self.task, 31 | num_layers=self.num_layers, 32 | node_types=self.row2graph.node_types, 33 | edge_types=self.full_edges, 34 | src_entity_table=self.src_entity_table, 35 | dst_entity_table=self.dst_entity_table 36 | ) 37 | 38 | def get_possible_edges(self): 39 | return self.row2graph.find_possible_edges() 40 | 41 | def get_full_edges(self): 42 | return self.full_edges 43 | 44 | def generate_all_graphs(self): 45 | return self.gnn_search_space.generate_all_graphs() 46 | 47 | def get_full_graph_idx(self, graphs: List[tuple]) -> int: 48 | graphs = np.array(graphs) 49 | r2n_idx = np.where(np.sum(graphs[:, :len(self.get_possible_edges())], axis=1) == 0)[0] 50 | return -1 if len(r2n_idx) == 0 else r2n_idx[-1] 51 | 52 | def get_data(self, edges: Union[torch.Tensor, tuple]) -> HeteroData: 53 | if type(edges) != torch.Tensor: 54 | edges = torch.tensor(edges) 55 | converted_data = self.row2graph.convert_row_to_edge(edges) 56 | return self.gnn_search_space.get_data(edges, converted_data) -------------------------------------------------------------------------------- /examples/load_dataset.py: -------------------------------------------------------------------------------- 1 | from rdb2g_bench.dataset.dataset import load_rdb2g_bench 2 | 3 | bench = load_rdb2g_bench(result_dir="./results") 4 | 5 | available = bench.get_available() 6 | print(available) 7 | """ 8 | { 9 | 'rel-stack': { 10 | 'post-post-related': ['GraphSAGE', 'GIN', 'GPS'] 11 | }, 12 | 'rel-event': { 13 | 'user-attendance': ['GraphSAGE', 'GIN', 'GPS'], 14 | 'user-repeat': ['GraphSAGE', 'GIN', 'GPS'], 15 | 'user-ignore': ['GraphSAGE', 'GIN', 'GPS'] 16 | }, 17 | 'rel-f1': { 18 | 'driver-top3': ['GraphSAGE', 'GIN', 'GPS'], 19 | 'driver-position': ['GraphSAGE', 'GIN', 'GPS'], 20 | 'driver-dnf': ['GraphSAGE', 'GIN', 'GPS'] 21 | }, 22 | 'rel-trial': { 23 | 'study-outcome': ['GraphSAGE', 'GIN', 'GPS'] 24 | }, 25 | 'rel-avito': { 26 | 'user-visits': ['GraphSAGE', 'GIN', 'GPS'], 27 | 'ad-ctr': ['GraphSAGE', 'GIN', 'GPS'], 28 | 'user-clicks': ['GraphSAGE', 'GIN', 'GPS'], 29 | 'user-ad-visit': ['GraphSAGE', 'GIN', 'GPS'] 30 | } 31 | } 32 | """ 33 | 34 | # Access specific task and GNN model 35 | task = bench['rel-f1']['driver-top3'] 36 | available_gnns = task.get_available_gnns() 37 | print(f"Available GNNs: {available_gnns}") # ['GraphSAGE', 'GIN', 'GPS'] 38 | 39 | # Access specific GNN model 40 | gnn = task['GraphSAGE'] 41 | indices = gnn.get_available_indices() 42 | print(f"Available indices for GraphSAGE: {len(indices)}") # 722 43 | 44 | # Get results for specific graph configuration 45 | result = gnn[0] 46 | print(f"Index 0 results: {result.stats}") 47 | """ 48 | Index 0: {'test_metric_mean': 0.805233155657748, 49 | 'test_metric_std': 0.034900872955993194, 50 | 'params': 954241, 51 | 'train_time': 0.30655655304590856, 52 | 'valid_time': 0.06576369603474932, 53 | 'test_time': 0.05956562360127762} 54 | """ 55 | 56 | # Access aggregated statistics 57 | stats = result.stats 58 | test_metric_mean = stats['test_metric_mean'] 59 | test_metric_std = stats['test_metric_std'] 60 | params = stats['params'] 61 | train_time = stats['train_time'] 62 | 63 | print(f"Test metric: {test_metric_mean:.4f} ± {test_metric_std:.4f}") # 0.8052 ± 0.0347 64 | print(f"Parameters: {params:.0f}") # 954241 65 | print(f"Train time: {train_time:.4f} seconds") # 0.3066 66 | 67 | # Compare different GNN models 68 | print("\nComparing GNN models:") 69 | for gnn_name in available_gnns: 70 | gnn_model = task[gnn_name] 71 | result_0 = gnn_model[0] 72 | perf = result_0.stats['test_metric_mean'] 73 | print(f"{gnn_name}: {perf:.4f}") -------------------------------------------------------------------------------- /rdb2g_bench/benchmark/llm/prompts/action.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_fk_pk_edge": "Here is the introduction of add_fk_pk_edge:\nDescription:\nCreates a directed edge from one table to another by adding a foreign key (FK) to primary key (PK) relationship. \nUse when you need to represent an important directional relationship between two tables in your graph schema.\nParameters:\n from_table_name: the name of the table containing the foreign key\n from_col_name: the name of the foreign key column in to_table\n to_table_name: the name of the table containing the primary key", 3 | "remove_fk_pk_edge": "Here is the introduction of remove_fk_pk_edge:\nDescription:\nEliminates a directed edge between tables by removing a FK-PK relationship. \nUse when a previously modeled relationship doesn't add meaningful context to your graph structure and should be excluded.\nParameters:\n from_table_name: the name of the table containing the foreign key\n from_col_name: the name of the primary key column in to_table\n to_table_name: the name of the table containing the primary key", 4 | "convert_row_to_edge": "Here is the introduction of convert_row_to_edge:\nDescription:\nTransforms what was originally modeled as an entity table into a relationship edge in your graph. \nUse when an intermediate table (denoted as edge_table_name) better represents a relationship property between two tables (denoted as table_1_name and table_2_name) rather than being an independent entity.\nNote that table_1_name and table_2_name can be equal when the edge_table_name has 2 foreign keys which refer to the same primary key.\nParameters:\n table_1_name: the name of the first row table\n table_2_name: the name of the second row table\n edge_table_name: the name of the table to convert to edge between table_1_name and table_2_name", 5 | "convert_edge_to_row": "Here is the introduction of convert_edge_to_row:\nDescription:\nTransforms what was modeled as a relationship edge into a proper entity table in your graph. \nUse when an edge contains sufficient attributes and identity to justify becoming an entity table with its own properties.\nNote that table_1_name and table_2_name can be equal when the edge_table_name has 2 foreign keys which refer to the same primary key.\nParameters:\n table_1_name: the name of the first row table\n table_2_name: the name of the second row table\n edge_table_name: the name of the edge table to convert to rows between table_1_name and table_2_name", 6 | "add_row2edge_edge": "Here is the introduction of add_row2edge_edge:\nDescription:\nCreate a new edge (relationship) between two tables by introducing an edge (join) table.\nUse this when a many-to-many relationship between two tables is needed and meaningful for your graph schema.\nParameters:\n table_1_name: the name of the first row table\n table_2_name: the name of the second row table\n edge_table_name: the name of the edge table to convert to rows between table_1_name and table_2_name", 7 | "remove_row2edge_edge": "Here is the introduction of remove_row2edge_edge:\nDescription:\nRemove an edge (relationship) between two tables by deleting the edge (join) table.\nUse this when the many-to-many relationship is not necessary for your graph schema.\nParameters:\n table_1_name: the name of the first row table\n table_2_name: the name of the second row table\n edge_table_name: the name of the edge table to convert to rows between table_1_name and table_2_name" 8 | } -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | 3 | import os 4 | import sys 5 | sys.path.append(os.path.abspath('.')) 6 | sys.path.append(os.path.abspath('..')) 7 | sys.path.append(os.path.abspath('../..')) 8 | 9 | # ReadTheDocs-specific configuration 10 | on_rtd = os.environ.get('READTHEDOCS') == 'True' 11 | if on_rtd: 12 | # ReadTheDocs automatically installs the package via .readthedocs.yaml 13 | # But we need to ensure the path is correct 14 | sys.path.insert(0, os.path.abspath('../../')) 15 | 16 | # -- Project information 17 | 18 | project = 'RDB2G-Bench' 19 | copyright = 'KAIST Data Mining Lab' 20 | author = 'Dongwon Choi' 21 | 22 | release = '0.1' 23 | version = '0.1.1' 24 | 25 | # -- General configuration 26 | 27 | extensions = [ 28 | 'sphinx.ext.duration', 29 | 'sphinx.ext.autodoc', 30 | 'sphinx.ext.autosummary', 31 | 'sphinx.ext.doctest', 32 | 'sphinx.ext.intersphinx', 33 | 'sphinx.ext.todo', 34 | 'sphinx.ext.coverage', 35 | 'sphinx.ext.mathjax', 36 | 'sphinx.ext.ifconfig', 37 | 'sphinx.ext.napoleon', 38 | 'sphinx.ext.ifconfig', 39 | 'sphinx.ext.viewcode', 40 | 'sphinx.ext.githubpages', 41 | 'sphinx.ext.todo', 42 | 'sphinx.ext.napoleon', 43 | ] 44 | 45 | intersphinx_mapping = { 46 | 'python': ('https://docs.python.org/3/', None), 47 | 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), 48 | } 49 | intersphinx_disabled_domains = ['std'] 50 | 51 | templates_path = ['_templates'] 52 | 53 | # -- Options for HTML output 54 | 55 | html_theme = 'sphinx_rtd_theme' 56 | 57 | # Show project logo at the top-left of the docs 58 | html_logo = '../logo.png' 59 | 60 | # HTML theme options for sphinx_rtd_theme 61 | html_theme_options = { 62 | 'prev_next_buttons_location': 'bottom', 63 | 'style_external_links': False, 64 | 'vcs_pageview_mode': '', 65 | 'style_nav_header_background': '#FFFFFF', 66 | 'logo_only': True, 67 | # These options are for the theme but we'll add GitHub link via html_context 68 | } 69 | 70 | # Add GitHub link in the HTML context 71 | html_context = { 72 | 'display_github': True, 73 | 'github_user': 'chlehdwon', 74 | 'github_repo': 'RDB2G-Bench', 75 | 'github_version': 'main', 76 | 'conf_py_path': '/docs/source/', 77 | } 78 | 79 | # -- Autodoc configuration for type hints 80 | autodoc_typehints = 'signature' 81 | 82 | autodoc_mock_imports = [ 83 | "torch", "numpy", "pandas", "typing_extensions", 84 | "torch_frame", "torch_geometric", "torch_scatter", "torch_sparse", 85 | "torch_cluster", "torch_spline_conv", "pyg_lib", 86 | "dgl", "sklearn", "scipy", "networkx", 87 | "ogb", "tqdm", "qpth", "quadprog", "cvxpy", "rdkit", "dgllife", 88 | "relbench", "anthropic", "openai", 89 | "typin", "pathlib", "json", "ast", "copy", "itertools", "pickle", 90 | "matplotlib", "seaborn", "plotly", "wandb", "tensorboard", 91 | "transformers", "datasets", "tokenizers", "accelerate", 92 | "optuna", "ray", "hyperopt", "ax-platform", 93 | "psutil", "memory_profiler", "line_profiler", 94 | # ReadTheDocs specific mocks 95 | "torch.nn", "torch.optim", "torch.utils", "torch.cuda" 96 | ] 97 | 98 | if on_rtd: 99 | autodoc_mock_imports.extend([ 100 | "torch.nn.functional", "torch.distributions", 101 | "torch_geometric.nn", "torch_geometric.data", "torch_geometric.utils", 102 | "relbench.base", "relbench.datasets", "relbench.modeling", "relbench.tasks" 103 | ]) 104 | 105 | # -- Options for EPUB output 106 | epub_show_urls = 'footnote' 107 | 108 | add_module_names = False -------------------------------------------------------------------------------- /examples/download_dataset.py: -------------------------------------------------------------------------------- 1 | from rdb2g_bench.dataset.dataset import download_rdb2g_bench, get_dataset_stats 2 | 3 | # List all available datasets and tasks with GNN models 4 | df_stats = get_dataset_stats(cache_dir="~/.cache") 5 | print(df_stats) 6 | """ 7 | dataset task gnn idx seed test_metric_mean test_metric_std test_metric_min test_metric_max 8 | 0 rel-avito ad-ctr GraphSAGE 1304 5 0.0423 0.0010 0.0374 0.0458 9 | 1 rel-avito ad-ctr GIN 1304 5 0.0419 0.0012 0.0365 0.0461 10 | 2 rel-avito ad-ctr GPS 1304 5 0.0425 0.0009 0.0378 0.0459 11 | 3 rel-avito user-ad-visit GraphSAGE 909 5 0.0165 0.0076 0.0016 0.0372 12 | 4 rel-avito user-ad-visit GIN 909 5 0.0163 0.0074 0.0018 0.0369 13 | 5 rel-avito user-ad-visit GPS 909 5 0.0168 0.0078 0.0014 0.0375 14 | 6 rel-f1 driver-top3 GraphSAGE 722 15 0.7847 0.0150 0.6554 0.8732 15 | 7 rel-f1 driver-top3 GIN 722 15 0.7823 0.0148 0.6521 0.8698 16 | 8 rel-f1 driver-top3 GPS 722 15 0.7891 0.0152 0.6598 0.8756 17 | ... 18 | """ 19 | 20 | # Filter statistics by specific GNN model 21 | graphsage_stats = df_stats[df_stats['gnn'] == 'GraphSAGE'] 22 | print("\nGraphSAGE performance across datasets:") 23 | print(graphsage_stats[['dataset', 'task', 'test_metric_mean', 'test_metric_std']]) 24 | 25 | # Download entire RDB2G-Bench dataset (includes all GNN models) 26 | saved_files = download_rdb2g_bench( 27 | result_dir="./results", 28 | cache_dir="~/.cache", 29 | tag="hf" 30 | ) 31 | print(f"\nDownloaded {len(saved_files)} dataset/task combinations") 32 | for combo, files in saved_files.items(): 33 | print(f"{combo}: {len(files)} files") 34 | 35 | # Download specific RDB2G-Bench dataset with specific GNN model 36 | saved_files = download_rdb2g_bench( 37 | result_dir="./results", 38 | cache_dir="~/.cache", 39 | dataset_names=["rel-f1"], 40 | task_names=["driver-top3"], 41 | gnn_names=["GIN"], 42 | tag="hf_gin" 43 | ) 44 | print(f"\nSpecific GNN model download saved files:") 45 | for combo, files in saved_files.items(): 46 | print(f"{combo}: {files}") -------------------------------------------------------------------------------- /docs/source/020/download_precomputed.rst: -------------------------------------------------------------------------------- 1 | Download Pre-computed Datasets 2 | =============================== 3 | 4 | This module provides functionality to load and access pre-computed benchmark results without running experiments. 5 | The RDB2G-Bench dataset can be downloaded from Hugging Face Hub and accessed through a hierarchical interface. 6 | 7 | The RDB2G-Bench dataset is also available from the HuggingFace: https://huggingface.co/datasets/kaistdata/RDB2G-Bench 8 | 9 | Dataset Module 10 | -------------- 11 | 12 | The dataset module provides functions for downloading and loading benchmark data. 13 | 14 | .. automodule:: rdb2g_bench.dataset.dataset 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | Dataloader Module 20 | ----------------- 21 | 22 | The dataloader module provides classes for hierarchical access to benchmark results. 23 | 24 | .. automodule:: rdb2g_bench.dataset.dataloader 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | 29 | Data Access Pattern 30 | ------------------- 31 | 32 | The benchmark data follows a hierarchical access pattern: 33 | 34 | .. code-block:: text 35 | 36 | RDB2GBench[dataset_name][task_name][gnn_name][idx] -> IndexAccessor 37 | 38 | Example: 39 | bench['rel-f1']['driver-top3']['GraphSAGE'][0] -> Results for graph configuration 0 with GraphSAGE 40 | 41 | Directory Structure 42 | ~~~~~~~~~~~~~~~~~~~ 43 | 44 | The downloaded data is organized in the following directory structure: 45 | 46 | .. code-block:: text 47 | 48 | results/ 49 | tables/ 50 | dataset_name/ # e.g., rel-f1, rel-avito 51 | task_name/ # e.g., driver-top3, user-ad-visit 52 | tag/ # e.g., hf 53 | gnn_name/ # e.g., GraphSAGE, GIN, GPS 54 | 0.csv # Results for seed 0 55 | 1.csv # Results for seed 1 56 | ... 57 | 58 | Example Usage 59 | ~~~~~~~~~~~~~ 60 | 61 | .. code-block:: python 62 | 63 | from rdb2g_bench.dataset.dataset import load_rdb2g_bench 64 | 65 | # Load benchmark results (downloads if missing) 66 | bench = load_rdb2g_bench("./results") 67 | 68 | # Check available datasets, tasks, and GNN models 69 | available = bench.get_available() 70 | print(available) 71 | # {'rel-f1': {'driver-top3': ['GIN', 'GPS', 'GraphSAGE'], ...}, ...} 72 | 73 | # Access results for rel-f1 dataset, driver-top3 task, GraphSAGE model 74 | result = bench['rel-f1']['driver-top3']['GraphSAGE'][0] # First graph configuration 75 | 76 | # Extract performance metrics (aggregated across seeds) 77 | test_metric = result['test_metric'] # Test performance mean 78 | test_std = result['test_metric_std'] # Test performance std 79 | params = result['params'] # Model parameters 80 | train_time = result['train_time'] # Training time 81 | 82 | CSV File Format 83 | ~~~~~~~~~~~~~~~ 84 | 85 | Each CSV file contains the following columns: 86 | 87 | .. code-block:: text 88 | 89 | idx : Graph configuration index 90 | graph : Graph structure representation (e.g., "graph_00000000000010") 91 | train_metric : Training performance metric 92 | valid_metric : Validation performance metric 93 | test_metric : Test performance metric 94 | params : Number of trainable parameters 95 | train_time : Average training time per epoch 96 | valid_time : Validation time 97 | test_time : Test time 98 | dataset : Dataset name 99 | task : Task name 100 | seed : Random seed 101 | gnn : GNN model name (GraphSAGE, GIN, GPS) 102 | 103 | Accessing GNN-specific Results 104 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 105 | 106 | The benchmark now supports GNN-specific access patterns: 107 | 108 | .. code-block:: python 109 | 110 | # Access specific GNN models 111 | graphsage_result = bench['rel-f1']['driver-top3']['GraphSAGE'][0] 112 | gin_result = bench['rel-f1']['driver-top3']['GIN'][0] 113 | 114 | # Get available GNN models for a task 115 | available_gnns = bench['rel-f1']['driver-top3'].get_available_gnns() 116 | print(available_gnns) # ['GIN', 'GPS', 'GraphSAGE'] 117 | 118 | # Compare performance across GNN models 119 | for gnn in available_gnns: 120 | result = bench['rel-f1']['driver-top3'][gnn][0] 121 | print(f"{gnn}: {result['test_metric']}") 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /docs/source/030/run_benchmark.rst: -------------------------------------------------------------------------------- 1 | Run Benchmark 2 | ============= 3 | 4 | This module provides the main interface for running comprehensive benchmarks on RDB2G-Bench. 5 | The benchmark runner executes multiple RDB-to-Graph modeling search algorithms and provides 6 | statistical analysis and comparison across different methods, datasets, and tasks. 7 | 8 | The benchmark system supports all available RDB-to-Graph modeling strategies and automatically handles 9 | data preparation, caching, multi-run execution, and results aggregation. 10 | 11 | Benchmark Runner Interface 12 | -------------------------- 13 | 14 | .. automodule:: rdb2g_bench.benchmark.bench_runner 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | Core Benchmark Engine 20 | --------------------- 21 | 22 | .. automodule:: rdb2g_bench.benchmark.benchmark 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | Example Usage 28 | ~~~~~~~~~~~~~ 29 | 30 | Basic Benchmark Execution 31 | ^^^^^^^^^^^^^^^^^^^^^^^^^^ 32 | 33 | .. code-block:: python 34 | 35 | from rdb2g_bench.benchmark.bench_runner import run_benchmark 36 | 37 | # Run all methods on default dataset and task with specific GNN 38 | results = run_benchmark( 39 | dataset="rel-f1", 40 | task="driver-top3", 41 | gnn="GraphSAGE", 42 | budget_percentage=0.05, 43 | method="all", 44 | num_runs=10, 45 | seed=42 46 | ) 47 | 48 | # Print summary of results 49 | for method, stats in results.items(): 50 | if 'avg_actual_y_perf_of_selected' in stats: 51 | perf = stats['avg_actual_y_perf_of_selected'] 52 | rank = stats.get('avg_rank_position_overall', 'N/A') 53 | print(f"{method}: Performance={perf:.4f}, Rank={rank}") 54 | 55 | Specific Methods Comparison 56 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 57 | 58 | .. code-block:: python 59 | 60 | # Compare evolutionary algorithm vs greedy search 61 | results = run_benchmark( 62 | dataset="rel-avito", 63 | task="user-ad-visit", 64 | gnn="GIN", 65 | budget_percentage=0.05, 66 | method=["ea", "greedy", "random"], 67 | num_runs=15, 68 | seed=123 69 | ) 70 | 71 | # Extract key metrics 72 | for method in ["ea", "greedy", "random"]: 73 | if method in results: 74 | stats = results[method] 75 | print(f"\n{method.upper()} Results:") 76 | print(f" Average Performance: {stats['avg_actual_y_perf_of_selected']:.4f}") 77 | print(f" Average Rank: {stats['avg_rank_position_overall']:.1f}") 78 | print(f" Average Runtime: {stats['avg_run_time']:.2f}s") 79 | 80 | Available Methods 81 | ----------------- 82 | 83 | The benchmark runner supports the following search methods: 84 | 85 | - **"all"**: Run all available methods 86 | - **"ea"**: Evolutionary Algorithm baseline 87 | - **"greedy"**: Greedy Search strategies (forward, backward, random) 88 | - **"rl"**: Reinforcement Learning with policy gradients 89 | - **"bo"**: Bayesian Optimization with surrogate models 90 | 91 | You can specify a single method as a string or multiple methods as a list. 92 | 93 | Results Structure 94 | ----------------- 95 | 96 | The returned results dictionary contains method-wise statistics: 97 | 98 | .. code-block:: python 99 | 100 | { 101 | "Method Name": { 102 | "avg_actual_y_perf_of_selected": float, # Average performance 103 | "avg_rank_position_overall": float, # Average ranking 104 | "avg_percentile_overall": float, # Average percentile 105 | "total_samples_overall": int, # Total architectures 106 | "selected_graph_ids_runs": list, # Selected graph IDs 107 | "avg_selection_metric_value": float, # Average selection metric 108 | "selected_graph_origins": list, # Method origins 109 | "avg_evaluation_time": float, # Average evaluation time 110 | "avg_run_time": float # Average total runtime 111 | } 112 | } 113 | 114 | Output Files 115 | ------------ 116 | 117 | The benchmark automatically generates several output files: 118 | 119 | - **Individual Trajectories**: ``avg_trajectory_{method}_{gnn}_{num_runs}runs.csv`` 120 | - **Combined Trajectories**: ``all_methods_trajectories_{gnn}_{num_runs}runs.csv`` 121 | - **Performance Summary**: ``performance_summary_{gnn}_{num_runs}runs.csv`` 122 | 123 | These files are saved in: ``{result_dir}/benchmark/{dataset}/{task}/{tag}/{gnn}/`` 124 | -------------------------------------------------------------------------------- /docs/source/020/reproduce_datasets.rst: -------------------------------------------------------------------------------- 1 | Reproduce Datasets 2 | ================== 3 | 4 | This module provides functionality to reproduce the datasets and benchmark results from the RDB2G-Bench paper. 5 | The workers support both node-level tasks (classification/regression) and link prediction tasks (recommendation) 6 | with various GNN architectures and comprehensive evaluation metrics. 7 | 8 | Node Worker 9 | ----------- 10 | 11 | The node worker handles node classification and regression tasks using Graph Neural Networks. 12 | 13 | .. automodule:: rdb2g_bench.dataset.node_worker 14 | :members: 15 | :undoc-members: 16 | :show-inheritance: 17 | 18 | Link Worker 19 | ----------- 20 | 21 | The link worker handles link prediction tasks using ID-aware Graph Neural Networks (IDGNN). 22 | 23 | .. automodule:: rdb2g_bench.dataset.link_worker 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | Task Types 29 | ---------- 30 | 31 | The benchmark supports different types of tasks: 32 | 33 | Node-Level Tasks 34 | ~~~~~~~~~~~~~~~~ 35 | 36 | - **Binary Classification**: Predicting binary labels for nodes (e.g., driver performance) 37 | - **Regression**: Predicting continuous values for nodes (e.g., ratings, scores) 38 | - **Multilabel Classification**: Predicting multiple binary labels per node 39 | 40 | Link Prediction Tasks 41 | ~~~~~~~~~~~~~~~~~~~~~ 42 | 43 | - **Recommendation**: Predicting user-item interactions using ranking metrics 44 | - **Link Prediction**: General link prediction between entities 45 | 46 | Supported Models 47 | ---------------- 48 | 49 | Both workers support multiple GNN architectures: 50 | 51 | - **GraphSAGE**: Inductive graph representation learning 52 | - **GIN**: Graph Isomorphism Network for graph-level tasks 53 | - **GPS**: Graph transformer with positional encodings 54 | 55 | Example Usage 56 | ---------------- 57 | 58 | Node Classification/Regression 59 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 60 | 61 | .. code-block:: python 62 | 63 | from rdb2g_bench.dataset.node_worker import run_gnn_node_worker 64 | 65 | # Run binary classification experiment 66 | results = run_gnn_node_worker( 67 | dataset_name="rel-f1", 68 | task_name="driver-top3", 69 | gnn="GraphSAGE", 70 | epochs=20, 71 | lr=0.005, 72 | batch_size=512, 73 | channels=128, 74 | ) 75 | print(f"Processed {results['total_processed']} graph configurations") 76 | 77 | # Run regression experiment 78 | results = run_gnn_node_worker( 79 | dataset_name="rel-f1", 80 | task_name="driver-position", 81 | gnn="GIN", 82 | epochs=50, 83 | lr=0.001, 84 | weight_decay=1e-4 85 | ) 86 | 87 | # Parallel processing with multiple workers 88 | results = run_gnn_node_worker( 89 | dataset_name="rel-f1", 90 | task_name="driver-top3", 91 | idx=0, # Worker 0 ~ 3 92 | workers=4, # Total 4 workers 93 | epochs=20 94 | ) 95 | 96 | Link Prediction 97 | ~~~~~~~~~~~~~~~ 98 | 99 | .. code-block:: python 100 | 101 | from rdb2g_bench.dataset.link_worker import run_idgnn_link_worker 102 | 103 | # Run recommendation experiment 104 | results = run_idgnn_link_worker( 105 | dataset_name="rel-avito", 106 | task_name="user-ad-visit", 107 | gnn="GraphSAGE", 108 | epochs=20, 109 | lr=0.001, 110 | batch_size=512, 111 | channels=128, 112 | temporal_strategy="last", 113 | ) 114 | 115 | # Run with specific graph configurations 116 | results = run_idgnn_link_worker( 117 | dataset_name="rel-avito", 118 | task_name="user-ad-visit", 119 | target_indices=[0, 5, 10, 15], # Run only these configurations 120 | epochs=30, 121 | patience=10 122 | ) 123 | 124 | 125 | Parallel Processing 126 | ~~~~~~~~~~~~~~~~~~~ 127 | 128 | .. code-block:: python 129 | 130 | import multiprocessing as mp 131 | from concurrent.futures import ProcessPoolExecutor 132 | 133 | def run_worker(worker_id, total_workers): 134 | """Run a single worker process.""" 135 | return run_gnn_node_worker( 136 | dataset_name="rel-f1", 137 | task_name="driver-top3", 138 | idx=worker_id, 139 | workers=total_workers, 140 | epochs=20, 141 | save_csv=True 142 | ) 143 | 144 | # Run multiple workers in parallel 145 | num_workers = 4 146 | with ProcessPoolExecutor(max_workers=num_workers) as executor: 147 | futures = [ 148 | executor.submit(run_worker, i, num_workers) 149 | for i in range(num_workers) 150 | ] 151 | 152 | # Collect results 153 | all_results = [future.result() for future in futures] 154 | total_processed = sum(r['total_processed'] for r in all_results) 155 | print(f"Total graphs processed: {total_processed}") 156 | -------------------------------------------------------------------------------- /rdb2g_bench/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def integrate_edge_tf(batch, edge_tf_dict): 5 | r2e_edge_types = [] 6 | rev_r2e_edge_types = [] 7 | 8 | for edge_type in batch.edge_types: 9 | src, rel, dst = edge_type 10 | if rel.startswith('r2e'): 11 | r2e_edge_types.append(edge_type) 12 | elif rel.startswith('rev_r2e'): 13 | rev_r2e_edge_types.append(edge_type) 14 | 15 | for edge_type in edge_tf_dict: 16 | del batch[edge_type] 17 | 18 | table_to_edge_types = {} 19 | for edge_type in r2e_edge_types: 20 | src, rel, dst = edge_type 21 | table_name = rel[4:] 22 | if table_name not in table_to_edge_types: 23 | table_to_edge_types[table_name] = {'r2e': [], 'rev_r2e': []} 24 | table_to_edge_types[table_name]['r2e'].append(edge_type) 25 | 26 | for edge_type in rev_r2e_edge_types: 27 | src, rel, dst = edge_type 28 | table_name = rel[8:] 29 | if table_name not in table_to_edge_types: 30 | table_to_edge_types[table_name] = {'r2e': [], 'rev_r2e': []} 31 | table_to_edge_types[table_name]['rev_r2e'].append(edge_type) 32 | 33 | for table_name, edge_types in table_to_edge_types.items(): 34 | if table_name not in edge_tf_dict: 35 | continue 36 | 37 | all_mapped_ids = set() 38 | r2e_mapped_ids_dict = {} 39 | rev_r2e_mapped_ids_dict = {} 40 | 41 | for edge_type in edge_types['r2e']: 42 | if 'mapped_node_ids' in batch[edge_type] and len(batch[edge_type]['mapped_node_ids']) > 0: 43 | mapped_ids = batch[edge_type]['mapped_node_ids'].cpu().numpy() 44 | all_mapped_ids.update(mapped_ids) 45 | r2e_mapped_ids_dict[edge_type] = mapped_ids 46 | 47 | for edge_type in edge_types['rev_r2e']: 48 | if 'mapped_node_ids' in batch[edge_type] and len(batch[edge_type]['mapped_node_ids']) > 0: 49 | mapped_ids = batch[edge_type]['mapped_node_ids'].cpu().numpy() 50 | all_mapped_ids.update(mapped_ids) 51 | rev_r2e_mapped_ids_dict[edge_type] = mapped_ids 52 | 53 | if not all_mapped_ids: 54 | continue 55 | 56 | all_mapped_ids = sorted(list(all_mapped_ids)) 57 | mapped_tensor = torch.tensor(all_mapped_ids, dtype=torch.long) 58 | 59 | batch[table_name]['tf'] = edge_tf_dict[table_name][mapped_tensor] 60 | 61 | max_id = max(all_mapped_ids) + 1 62 | id_to_idx_array = np.full(max_id, -1) 63 | for idx, id_val in enumerate(all_mapped_ids): 64 | id_to_idx_array[id_val] = idx 65 | 66 | for edge_type in edge_types['r2e']: 67 | if edge_type in r2e_mapped_ids_dict: 68 | old_mapped_ids = r2e_mapped_ids_dict[edge_type] 69 | new_indices = torch.tensor(id_to_idx_array[old_mapped_ids], dtype=torch.long) 70 | batch[edge_type]['mapped_node_ids'] = new_indices 71 | 72 | for edge_type in edge_types['rev_r2e']: 73 | if edge_type in rev_r2e_mapped_ids_dict: 74 | old_mapped_ids = rev_r2e_mapped_ids_dict[edge_type] 75 | new_indices = torch.tensor(id_to_idx_array[old_mapped_ids], dtype=torch.long) 76 | batch[edge_type]['mapped_node_ids'] = new_indices 77 | 78 | return batch 79 | 80 | def divide_node_edge_dict(batch, x_dict): 81 | table_to_edge_types = {} 82 | for edge_type in batch.edge_types: 83 | src, rel, dst = edge_type 84 | if rel.startswith('r2e'): 85 | table_name = rel[4:] 86 | if table_name not in table_to_edge_types: 87 | table_to_edge_types[table_name] = {'r2e': [], 'rev_r2e': []} 88 | table_to_edge_types[table_name]['r2e'].append(edge_type) 89 | elif rel.startswith('rev_r2e'): 90 | table_name = rel[8:] 91 | if table_name not in table_to_edge_types: 92 | table_to_edge_types[table_name] = {'r2e': [], 'rev_r2e': []} 93 | table_to_edge_types[table_name]['rev_r2e'].append(edge_type) 94 | edge_dict = {} 95 | tables_to_remove = set() 96 | for table_name, edge_types in table_to_edge_types.items(): 97 | if table_name in x_dict: 98 | table_features = x_dict[table_name] 99 | for edge_type in edge_types['r2e'] + edge_types['rev_r2e']: 100 | assert hasattr(batch[edge_type], 'mapped_node_ids') 101 | mapped_ids = batch[edge_type]['mapped_node_ids'] 102 | edge_dict[edge_type] = table_features[mapped_ids] 103 | tables_to_remove.add(table_name) 104 | 105 | for table_name in tables_to_remove: 106 | if table_name in x_dict: 107 | del x_dict[table_name] 108 | 109 | return x_dict, edge_dict -------------------------------------------------------------------------------- /docs/source/030/micro_action.rst: -------------------------------------------------------------------------------- 1 | Micro Actions 2 | ============= 3 | 4 | This module defines the core micro actions used by all optimization algorithms for RDB-to-Graph modeling search. 5 | These atomic operations enable systematic exploration of the graph model space by transforming 6 | one valid graph model into another. 7 | 8 | How it Works 9 | ------------ 10 | 11 | Micro actions represent atomic operations for transforming graph models: 12 | 13 | 1. **add_fk_pk_edge**: Add foreign key-primary key edge between tables 14 | 2. **remove_fk_pk_edge**: Remove foreign key-primary key edge between tables 15 | 3. **convert_row_to_edge**: Convert row representation to edge representation 16 | 4. **convert_edge_to_row**: Convert edge representation to row representation 17 | 18 | Each action transforms the current graph model (edge set) to a new valid graph model, enabling systematic exploration of the graph model space. 19 | 20 | Micro Action Set 21 | ---------------- 22 | 23 | .. automodule:: rdb2g_bench.benchmark.micro_action 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | Example Usage 29 | ~~~~~~~~~~~~~ 30 | 31 | Example with Data Preparation 32 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 33 | 34 | .. code-block:: python 35 | 36 | import os 37 | import json 38 | from pathlib import Path 39 | from torch_frame import stype 40 | from torch_frame.config.text_embedder import TextEmbedderConfig 41 | from relbench.datasets import get_dataset 42 | from relbench.tasks import get_task 43 | from relbench.modeling.graph import make_pkey_fkey_graph 44 | from relbench.modeling.utils import get_stype_proposal 45 | 46 | from rdb2g_bench.benchmark.micro_action import MicroActionSet 47 | from rdb2g_bench.common.search_space.gnn_search_space import GNNNodeSearchSpace 48 | from rdb2g_bench.common.text_embedder import GloveTextEmbedding 49 | 50 | # Step 1: Load dataset and task 51 | dataset_name = "rel-f1" 52 | task_name = "driver-top3" 53 | 54 | dataset = get_dataset(dataset_name, download=True) 55 | task = get_task(dataset_name, task_name, download=True) 56 | 57 | # Step 2: Prepare column type information 58 | cache_dir = os.path.expanduser("~/.cache/relbench_examples") 59 | stypes_cache_path = Path(f"{cache_dir}/{dataset_name}/stypes.json") 60 | 61 | try: 62 | with open(stypes_cache_path, "r") as f: 63 | col_to_stype_dict = json.load(f) 64 | for table, col_to_stype in col_to_stype_dict.items(): 65 | for col, stype_str in col_to_stype.items(): 66 | col_to_stype[col] = stype(stype_str) 67 | except FileNotFoundError: 68 | col_to_stype_dict = get_stype_proposal(dataset.get_db()) 69 | Path(stypes_cache_path).parent.mkdir(parents=True, exist_ok=True) 70 | with open(stypes_cache_path, "w") as f: 71 | json.dump(col_to_stype_dict, f, indent=2, default=str) 72 | 73 | # Step 3: Create heterogeneous graph data 74 | device = "cuda" if torch.cuda.is_available() else "cpu" 75 | hetero_data, col_stats_dict = make_pkey_fkey_graph( 76 | dataset.get_db(), 77 | col_to_stype_dict=col_to_stype_dict, 78 | text_embedder_cfg=TextEmbedderConfig( 79 | text_embedder=GloveTextEmbedding(device=device), 80 | batch_size=256 81 | ), 82 | cache_dir=f"{cache_dir}/{dataset_name}/materialized", 83 | ) 84 | 85 | # Step 4: Initialize micro action set 86 | micro_actions = MicroActionSet( 87 | dataset=dataset_name, 88 | task=task_name, 89 | hetero_data=hetero_data, 90 | GNNSpaceClass=GNNNodeSearchSpace, 91 | num_layers=2, 92 | src_entity_table=task.entity_table 93 | ) 94 | 95 | print(f"Total number of valid graph models: {len(micro_actions.valid_edge_sets_list)}") 96 | print(f"Number of FK-PK edges: {len(micro_actions.fk_pk_indices)}") 97 | print(f"Number of R2E edges: {len(micro_actions.r2e_indices)}") 98 | 99 | Basic Micro Action Operations 100 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 101 | 102 | .. code-block:: python 103 | 104 | # Get the first valid edge set as starting point 105 | current_edge_set = micro_actions.valid_edge_sets_list[0] 106 | print(f"Starting edge set: {current_edge_set}") 107 | 108 | # Explore all possible FK-PK edge additions 109 | add_fk_actions = micro_actions.add_fk_pk_edge(current_edge_set) 110 | print(f"Possible FK-PK additions: {len(add_fk_actions)}") 111 | 112 | for new_set, index in add_fk_actions[:3]: # Show first 3 113 | print(f" Add action: {current_edge_set} -> {new_set} (index: {index})") 114 | 115 | # Explore FK-PK edge removals 116 | remove_fk_actions = micro_actions.remove_fk_pk_edge(current_edge_set) 117 | print(f"Possible FK-PK removals: {len(remove_fk_actions)}") 118 | 119 | # Explore row-to-edge conversions 120 | row_to_edge_actions = micro_actions.convert_row_to_edge(current_edge_set) 121 | print(f"Possible row-to-edge conversions: {len(row_to_edge_actions)}") 122 | 123 | # Explore edge-to-row conversions 124 | edge_to_row_actions = micro_actions.convert_edge_to_row(current_edge_set) 125 | print(f"Possible edge-to-row conversions: {len(edge_to_row_actions)}") 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | ---- 6 | 7 | [![Latest Release](https://img.shields.io/badge/Latest-v0.1.2-success)](https://github.com/chlehdwon/RDB2G-Bench/releases) 8 | [![Read the Docs](https://img.shields.io/readthedocs/RDB2G-Bench)](https://rdb2g-bench.readthedocs.io/en/latest/) 9 | [![Hugging Face](https://img.shields.io/badge/🤗_Hugging_Face-Datasets-blue)](https://huggingface.co/datasets/kaistdata/RDB2G-Bench) 10 | [![arXiv](https://img.shields.io/badge/arXiv-2506.01360-b31b1b.svg)](https://arxiv.org/abs/2506.01360) 11 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 12 | 13 | This is the official implementation of the paper **RDB2G-Bench: A Comprehensive Benchmark for Automatic Graph Modeling of Relational Databases.** 14 | 15 | **RDB2G-Bench** is an **easy-to-use framework** for benchmarking graph-based analysis and prediction tasks by converting relational database data into graphs. 16 | 17 | ## 🚀 Installation 18 | 19 | ```bash 20 | git clone https://github.com/chlehdwon/RDB2G-Bench.git 21 | cd RDB2G-Bench 22 | pip install -e . 23 | ``` 24 | 25 | Also, please install additional PyG dependencies. The below shows an example when you use torch 2.1.0 + cuda 12.1. 26 | 27 | ```bash 28 | pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html 29 | ``` 30 | You can skip this part if you don't want to reproduce our dataset. 31 | 32 | ## ⚡ Package Usage 33 | 34 | Comprehensive documentation and detailed guides are available at [our documentation site](https://rdb2g-bench.readthedocs.io/en/latest/). 35 | 36 | You can also check the `examples/` directory for complete usage examples and tutorials. 37 | 38 | ### Download Pre-computed Datasets 39 | 40 | ```python 41 | from rdb2g_bench.dataset.dataset import load_rdb2g_bench 42 | 43 | bench = load_rdb2g_bench("./results") 44 | 45 | result = bench['rel-f1']['driver-top3'][0] # Access by graph index 46 | test_metric = result['test_metric'] # Test performance 47 | params = result['params'] # Model parameters 48 | train_time = result['train_time'] # Train time 49 | ``` 50 | 51 | ### Reproduce Datasets for Classification & Regression Tasks 52 | 53 | ```python 54 | from rdb2g_bench.dataset.node_worker import run_gnn_node_worker 55 | 56 | results = run_gnn_node_worker( 57 | dataset_name="rel-f1", 58 | task_name="driver-top3", 59 | gnn="GraphSAGE", 60 | epochs=20, 61 | lr=0.005 62 | ) 63 | ``` 64 | 65 | ### Reproduce Datasets for Recommendation Tasks 66 | 67 | ```python 68 | from rdb2g_bench.dataset.link_worker import run_idgnn_link_worker 69 | 70 | results = run_idgnn_link_worker( 71 | dataset_name="rel-avito", 72 | task_name="user-ad-visit", 73 | gnn="GraphSAGE", 74 | epochs=20, 75 | lr=0.001 76 | ) 77 | ``` 78 | 79 | ### Run Benchmarks 80 | 81 | ```python 82 | from rdb2g_bench.benchmark.bench_runner import run_benchmark 83 | 84 | results = run_benchmark( 85 | dataset="rel-f1", 86 | task="driver-top3", 87 | gnn="GraphSAGE", 88 | budget_percentage=0.05, 89 | method="all", 90 | num_runs=10, 91 | seed=0 92 | ) 93 | ``` 94 | 95 | ### Run LLM-based baseline 96 | 97 | Before using LLM-based baseline, you need to set up your API key: 98 | 99 | ```bash 100 | export ANTHROPIC_API_KEY="YOUR_API_KEY" 101 | ``` 102 | 103 | ```python 104 | from rdb2g_bench.benchmark.llm.llm_runner import run_llm_baseline 105 | 106 | results = run_llm_baseline( 107 | dataset="rel-f1", 108 | task="driver-top3", 109 | gnn="GraphSAGE", 110 | budget_percentage=0.05, 111 | model="claude-3-5-sonnet-latest", 112 | temperature=0.8, 113 | seed=42 114 | ) 115 | ``` 116 | 117 | ## 📁 Package Structure 118 | 119 | ``` 120 | rdb2g_bench/ 121 | ├── benchmark/ # Core benchmarking functionality 122 | │ ├── llm/ # LLM-based baseline methods 123 | │ └── baselines/ # Other baseline methods 124 | ├── common/ # Shared utilities and search spaces 125 | ├── dataset/ # Dataset loading and processing 126 | └── __init__.py # Package initialization 127 | ``` 128 | 129 | ## 📖 Reference 130 | 131 | The dataset construction and implementation of RDB2G-Bench is based on the [RelBench](https://github.com/snap-stanford/relbench) framework. 132 | 133 | ## 📝 Citation 134 | 135 | If you use RDB2G-Bench in your research, please cite: 136 | 137 | ```bibtex 138 | @inproceedings{choi2025rdb2gbench, 139 | title={RDB2G-Bench: A Comprehensive Benchmark for Automatic Graph Modeling of Relational Databases}, 140 | author={Dongwon Choi and Sunwoo Kim and Juyeon Kim and Kyungho Kim and Geon Lee and Shinhwan Kang and Myunghwan Kim and Kijung Shin}, 141 | year={2025}, 142 | booktitle={NeurIPS}, 143 | } 144 | ``` 145 | or 146 | ```bibtex 147 | @article{choi2025rdb2gbench, 148 | title={RDB2G-Bench: A Comprehensive Benchmark for Automatic Graph Modeling of Relational Databases}, 149 | author={Dongwon Choi and Sunwoo Kim and Juyeon Kim and Kyungho Kim and Geon Lee and Shinhwan Kang and Myunghwan Kim and Kijung Shin}, 150 | year={2025}, 151 | url={https://arxiv.org/abs/2506.01360}, 152 | } 153 | ``` 154 | 155 | ## ⚖️ License 156 | 157 | This project is distributed under the MIT License as specified in the LICENSE file. 158 | 159 | 160 | -------------------------------------------------------------------------------- /rdb2g_bench/dataset/models/modules/sage_conv_edge.py: -------------------------------------------------------------------------------- 1 | # Custom implementation of SAGEConv with edge features based on torch_geometric.nn.conv.SAGEConv 2 | 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | 9 | from torch_geometric.nn.aggr import Aggregation, MultiAggregation 10 | from torch_geometric.nn.conv import MessagePassing 11 | from torch_geometric.nn.dense.linear import Linear 12 | from torch_geometric.typing import Adj, OptPairTensor, Size, SparseTensor, PairTensor 13 | from torch_geometric.utils import spmm 14 | 15 | 16 | class SAGEConvEdge(MessagePassing): 17 | r"""The GraphSAGE operator with edge features support. 18 | 19 | This implementation extends the original GraphSAGE to handle edge features by: 20 | 1. Transforming edge features separately 21 | 2. Combining node and edge features in the message passing phase 22 | 3. Supporting both directed and undirected edges 23 | 24 | Args: 25 | in_channels (int or tuple): Size of each input sample, or :obj:`-1` to 26 | derive the size from the first input(s) to the forward method. 27 | A tuple corresponds to the sizes of source and target 28 | dimensionalities. 29 | out_channels (int): Size of each output sample. 30 | normalize (bool, optional): If set to :obj:`False`, output features 31 | will not be :math:`\ell_2`-normalized. (default: :obj:`True`) 32 | root_weight (bool, optional): If set to :obj:`False`, the layer will 33 | not add the transformed root node features to the output. 34 | (default: :obj:`True`) 35 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 36 | an additive bias. (default: :obj:`True`) 37 | edge_feature_channels (int, optional): Size of edge features. If not provided, 38 | will use the same size as node features. (default: :obj:`None`) 39 | **kwargs (optional): Additional arguments of 40 | :class:`torch_geometric.nn.conv.MessagePassing`. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | in_channels: Union[int, Tuple[int, int]], 46 | out_channels: int, 47 | aggr: Optional[Union[str, List[str], Aggregation]] = "mean", 48 | normalize: bool = False, 49 | root_weight: bool = True, 50 | project: bool = False, 51 | bias: bool = True, 52 | **kwargs, 53 | ): 54 | self.in_channels = in_channels 55 | self.out_channels = out_channels 56 | self.normalize = normalize 57 | self.root_weight = root_weight 58 | self.project = project 59 | 60 | if isinstance(in_channels, int): 61 | in_channels = (in_channels, in_channels) 62 | 63 | super().__init__(aggr, **kwargs) 64 | 65 | if self.project: 66 | if in_channels[0] <= 0: 67 | raise ValueError(f"'{self.__class__.__name__}' does not " 68 | f"support lazy initialization with " 69 | f"`project=True`") 70 | self.lin = Linear(in_channels[0], in_channels[0], bias=True) 71 | 72 | if isinstance(self.aggr_module, MultiAggregation): 73 | aggr_out_channels = self.aggr_module.get_out_channels( 74 | in_channels[0]) 75 | else: 76 | # node features + edge features 77 | aggr_out_channels = in_channels[0] * 2 78 | 79 | # Linear layer for node features 80 | self.lin_l = Linear(aggr_out_channels, out_channels, bias=bias) 81 | if self.root_weight: 82 | self.lin_r = Linear(in_channels[1], out_channels, bias=False) 83 | 84 | self.reset_parameters() 85 | 86 | def reset_parameters(self): 87 | super().reset_parameters() 88 | if self.project: 89 | self.lin.reset_parameters() 90 | self.lin_l.reset_parameters() 91 | if self.root_weight: 92 | self.lin_r.reset_parameters() 93 | 94 | def forward( 95 | self, 96 | x: Union[Tensor, PairTensor], 97 | edge_index: Adj, 98 | edge_attr: Optional[Tensor] = None, 99 | size: Size = None, 100 | ) -> Tensor: 101 | if isinstance(x, Tensor): 102 | x = (x, x) 103 | 104 | if self.project and hasattr(self, 'lin'): 105 | x = (self.lin(x[0]).relu(), x[1]) 106 | 107 | # propagate_type: (x: PairTensor, edge_attr: OptTensor) 108 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size) 109 | out = self.lin_l(out) 110 | 111 | x_r = x[1] 112 | if self.root_weight and x_r is not None: 113 | out += self.lin_r(x_r) 114 | 115 | if self.normalize: 116 | out = F.normalize(out, p=2., dim=-1) 117 | 118 | return out 119 | 120 | def message( 121 | self, 122 | x_j: Tensor, 123 | edge_attr: Optional[Tensor] = None, 124 | ) -> Tensor: 125 | if edge_attr is None: 126 | edge_attr = torch.zeros_like(x_j) 127 | return torch.cat([x_j, edge_attr], dim=-1) 128 | 129 | def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor: 130 | if isinstance(adj_t, SparseTensor): 131 | adj_t = adj_t.set_value(None, layout=None) 132 | return spmm(adj_t, x[0], reduce=self.aggr) 133 | 134 | def __repr__(self) -> str: 135 | return (f'{self.__class__.__name__}({self.in_channels}, ' 136 | f'{self.out_channels}, aggr={self.aggr})') 137 | -------------------------------------------------------------------------------- /rdb2g_bench/dataset/models/model.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/snap-stanford/relbench/blob/main/examples/model.py 2 | 3 | from typing import Any, Dict, List 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.nn import Embedding, ModuleDict 8 | from torch_frame.data.stats import StatType 9 | from torch_geometric.data import HeteroData 10 | from torch_geometric.nn import MLP 11 | from torch_geometric.typing import NodeType 12 | 13 | from relbench.modeling.nn import HeteroEncoder, HeteroTemporalEncoder, HeteroGraphSAGE 14 | from .modules.nn import HeteroGraphSAGE, HeteroGIN, HeteroGPS 15 | from ..utils import divide_node_edge_dict 16 | 17 | class Model(torch.nn.Module): 18 | def __init__( 19 | self, 20 | data: HeteroData, 21 | col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]], 22 | num_layers: int, 23 | channels: int, 24 | out_channels: int, 25 | aggr: str, 26 | norm: str, 27 | gnn: str = 'GraphSAGE', 28 | # List of node types to add shallow embeddings to input 29 | shallow_list: List[NodeType] = [], 30 | # ID awareness 31 | id_awareness: bool = False, 32 | ): 33 | super().__init__() 34 | 35 | self.encoder = HeteroEncoder( 36 | channels=channels, 37 | node_to_col_names_dict={ 38 | node_type: data[node_type].tf.col_names_dict 39 | for node_type in data.node_types 40 | }, 41 | node_to_col_stats=col_stats_dict, 42 | ) 43 | 44 | # Create temporal encoder if time attributes exist 45 | self.temporal_encoder = HeteroTemporalEncoder( 46 | node_types=[ 47 | node_type for node_type in data.node_types if "time" in data[node_type] 48 | ], 49 | channels=channels, 50 | ) 51 | 52 | # Initialize GNN based on specified type 53 | if gnn == "GraphSAGE": 54 | self.gnn = HeteroGraphSAGE( 55 | node_types=data.node_types, 56 | edge_types=data.edge_types, 57 | channels=channels, 58 | aggr=aggr, 59 | num_layers=num_layers, 60 | ) 61 | elif gnn == "GIN": 62 | self.gnn = HeteroGIN( 63 | node_types=data.node_types, 64 | edge_types=data.edge_types, 65 | channels=channels, 66 | ) 67 | elif gnn == "GPS": 68 | self.gnn = HeteroGPS( 69 | node_types=data.node_types, 70 | edge_types=data.edge_types, 71 | channels=channels, 72 | ) 73 | else: 74 | raise ValueError(f"Unknown GNN type: {gnn}") 75 | 76 | self.head = MLP( 77 | channels, 78 | out_channels=out_channels, 79 | norm=norm, 80 | num_layers=1, 81 | ) 82 | 83 | self.embedding_dict = ModuleDict( 84 | { 85 | node: Embedding(data.num_nodes_dict[node], channels) 86 | for node in shallow_list 87 | } 88 | ) 89 | self.id_awareness_emb = None 90 | if id_awareness: 91 | self.id_awareness_emb = torch.nn.Embedding(1, channels) 92 | self.reset_parameters() 93 | 94 | def reset_parameters(self): 95 | self.encoder.reset_parameters() 96 | self.temporal_encoder.reset_parameters() 97 | self.gnn.reset_parameters() 98 | self.head.reset_parameters() 99 | for embedding in self.embedding_dict.values(): 100 | torch.nn.init.normal_(embedding.weight, std=0.1) 101 | if self.id_awareness_emb is not None: 102 | self.id_awareness_emb.reset_parameters() 103 | 104 | def forward( 105 | self, 106 | batch: HeteroData, 107 | entity_table: NodeType, 108 | ) -> Tensor: 109 | seed_time = batch[entity_table].seed_time if hasattr(batch[entity_table], "seed_time") else None 110 | x_dict = self.encoder(batch.tf_dict) 111 | if seed_time is not None: 112 | rel_time_dict = self.temporal_encoder( 113 | seed_time, batch.time_dict, batch.batch_dict 114 | ) 115 | for node_type, rel_time in rel_time_dict.items(): 116 | x_dict[node_type] = x_dict[node_type] + rel_time 117 | 118 | for node_type, embedding in self.embedding_dict.items(): 119 | x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id) 120 | 121 | x_dict, edge_dict = divide_node_edge_dict(batch, x_dict) 122 | 123 | x_dict = self.gnn( 124 | x_dict, 125 | edge_dict, 126 | batch.edge_index_dict, 127 | batch.num_sampled_nodes_dict, 128 | batch.num_sampled_edges_dict, 129 | ) 130 | 131 | if seed_time is not None: 132 | return self.head(x_dict[entity_table][: seed_time.size(0)]) 133 | else: 134 | return self.head(x_dict[entity_table][: batch[entity_table]['batch_size']]) 135 | 136 | def forward_dst_readout( 137 | self, 138 | batch: HeteroData, 139 | entity_table: NodeType, 140 | dst_table: NodeType, 141 | ) -> Tensor: 142 | if self.id_awareness_emb is None: 143 | raise RuntimeError( 144 | "id_awareness must be set True to use forward_dst_readout" 145 | ) 146 | seed_time = batch[entity_table].seed_time if hasattr(batch[entity_table], "seed_time") else None 147 | x_dict = self.encoder(batch.tf_dict) 148 | if seed_time is not None: 149 | x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight 150 | rel_time_dict = self.temporal_encoder( 151 | seed_time, batch.time_dict, batch.batch_dict 152 | ) 153 | for node_type, rel_time in rel_time_dict.items(): 154 | x_dict[node_type] = x_dict[node_type] + rel_time 155 | 156 | for node_type, embedding in self.embedding_dict.items(): 157 | x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id) 158 | 159 | x_dict, edge_dict = divide_node_edge_dict(batch, x_dict) 160 | 161 | x_dict = self.gnn( 162 | x_dict, 163 | edge_dict, 164 | batch.edge_index_dict, 165 | ) 166 | 167 | return self.head(x_dict[dst_table]) -------------------------------------------------------------------------------- /rdb2g_bench/benchmark/bench_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List, Union, Any 3 | 4 | from .benchmark import main as benchmark_main 5 | 6 | 7 | def run_benchmark( 8 | dataset: str = "rel-f1", 9 | task: str = "driver-top3", 10 | gnn: str = "GraphSAGE", 11 | budget_percentage: float = 0.05, 12 | method: Union[str, List[str]] = "all", 13 | num_runs: int = 10, 14 | seed: int = 0, 15 | **kwargs 16 | ) -> Dict[str, Any]: 17 | """ 18 | Execute comprehensive benchmark analysis on specified dataset and task. 19 | 20 | This function provides a high-level interface for running neural architecture search 21 | benchmarks on RDB2G-Bench. It supports multiple optimization methods and automatically 22 | handles data preparation, model training, evaluation, and results aggregation across 23 | multiple runs for statistical robustness. 24 | 25 | The benchmark process includes: 26 | 27 | 1. Dataset and task preparation with proper caching 28 | 2. Search space initialization with micro actions 29 | 3. Performance prediction dataset setup 30 | 4. Multiple independent runs with different random seeds 31 | 5. Statistical analysis and results aggregation 32 | 6. Trajectory analysis and CSV export for visualization 33 | 34 | Args: 35 | dataset (str): Name of the RelBench dataset to benchmark on. 36 | Available datasets include "rel-f1", "rel-avito", "rel-amazon", etc. 37 | Defaults to "rel-f1". 38 | task (str): Name of the RelBench task to evaluate. 39 | Task names depend on the dataset (e.g., "driver-top3", "user-ad-visit"). 40 | Defaults to "driver-top3". 41 | gnn (str): Name of the GNN model to benchmark (e.g., "GraphSAGE", "GIN", "GPS"). 42 | Defaults to "GraphSAGE". 43 | budget_percentage (float): Budget percentage for search algorithms as fraction 44 | of total search space (0.0-1.0). Higher values allow more thorough search 45 | but increase computational cost. Defaults to 0.05 (5%). 46 | method (Union[str, List[str]]): Search method(s) to benchmark. 47 | Options: "all", "ea", "greedy", "rl", "bo", "random", or list of methods. 48 | "all" runs all available methods. Defaults to "all". 49 | num_runs (int): Number of independent runs for statistical analysis. 50 | More runs provide better statistical confidence but increase runtime. 51 | Defaults to 10. 52 | seed (int): Base random seed for reproducibility. Each run uses seed + run_index. 53 | Defaults to 0. 54 | **kwargs: Additional configuration parameters: 55 | 56 | - tag (str): Experiment tag for result organization. Defaults to "hf". 57 | - cache_dir (str): Directory for caching datasets and models. 58 | Defaults to "~/.cache/relbench_examples". 59 | - result_dir (str): Root directory for saving results and trajectories. 60 | Defaults to "./results". 61 | 62 | Returns: 63 | Dict[str, Any]: Dictionary containing comprehensive benchmark results. 64 | 65 | - avg_actual_y_perf_of_selected (float): Average performance across runs 66 | - avg_rank_position_overall (float): Average ranking position 67 | - avg_percentile_overall (float): Average percentile ranking 68 | - selected_graph_ids_runs (List[int]): Graph IDs selected in each run 69 | - avg_evaluation_time (float): Average time spent on evaluations 70 | - avg_run_time (float): Average total runtime per run 71 | - method_wise_statistics (Dict): Detailed statistics for each method 72 | - performance_trajectories (Dict): Performance over time for each method 73 | - statistical_comparisons (Dict): Rankings and comparisons between methods 74 | 75 | Example: 76 | >>> # Run all methods on default dataset/task with specific GNN 77 | >>> results = run_benchmark( 78 | ... dataset="rel-f1", 79 | ... task="driver-top3", 80 | ... gnn="GraphSAGE", 81 | ... budget_percentage=0.05, 82 | ... method="all", 83 | ... num_runs=10, 84 | ... seed=42 85 | ... ) 86 | >>> 87 | >>> # Print aggregated results 88 | >>> for method, stats in results.items(): 89 | ... if 'avg_actual_y_perf_of_selected' in stats: 90 | ... print(f"{method}: {stats['avg_actual_y_perf_of_selected']:.4f}") 91 | 92 | >>> # Run specific methods with custom configuration 93 | >>> results = run_benchmark( 94 | ... dataset="rel-avito", 95 | ... task="user-ad-visit", 96 | ... gnn="GIN", 97 | ... budget_percentage=0.10, 98 | ... method=["ea", "greedy", "rl"], 99 | ... num_runs=5, 100 | ... tag="custom_experiment", 101 | ... cache_dir="/custom/cache", 102 | ... result_dir="/custom/results" 103 | ... ) 104 | 105 | >>> # Run quick test with single method 106 | >>> results = run_benchmark( 107 | ... gnn="GPS", 108 | ... method="all", 109 | ... num_runs=10, 110 | ... budget_percentage=0.05 111 | ... ) 112 | 113 | Note: 114 | - Results are automatically saved to CSV files for further analysis 115 | - Performance trajectories are exported for visualization 116 | - All intermediate data is cached to speed up repeated experiments 117 | - Large datasets may require significant memory and storage 118 | """ 119 | 120 | if isinstance(method, str): 121 | method = [method] 122 | 123 | default_params = { 124 | "tag": "hf", 125 | "cache_dir": "~/.cache/relbench_examples", 126 | "result_dir": "./results", 127 | } 128 | 129 | params = {**default_params, **kwargs} 130 | 131 | class Args: 132 | def __init__(self, **kwargs): 133 | for key, value in kwargs.items(): 134 | setattr(self, key, value) 135 | 136 | args = Args( 137 | dataset=dataset, 138 | task=task, 139 | gnn=gnn, 140 | budget_percentage=budget_percentage, 141 | method=method, 142 | num_runs=num_runs, 143 | seed=seed, 144 | **params 145 | ) 146 | 147 | results = benchmark_main(args) 148 | 149 | return results -------------------------------------------------------------------------------- /rdb2g_bench/benchmark/baselines/random.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import time 5 | from typing import Dict, Union, List, Optional 6 | 7 | from ..dataset import PerformancePredictionDataset 8 | from ..micro_action import MicroActionSet 9 | from .utils import calculate_overall_rank, get_performance_for_index, update_trajectory_and_best, pad_trajectory, calculate_evaluation_time 10 | 11 | def random_heuristic_analysis( 12 | dataset: PerformancePredictionDataset, 13 | micro_action_set: MicroActionSet, 14 | overall_actual_y: torch.Tensor, 15 | higher_is_better: bool, 16 | termination_threshold_ratio: float, 17 | method_name: str = "Random Heuristic", 18 | ): 19 | """ 20 | Perform Neural Architecture Search using Random Sampling Strategy. 21 | 22 | This function implements pure random search that uniformly samples graph neural 23 | network architectures from the entire valid search space. Each architecture is 24 | evaluated independently without any guidance from previous evaluations, providing 25 | an unbiased baseline for comparison with other optimization methods. 26 | 27 | 28 | Args: 29 | dataset (PerformancePredictionDataset): Dataset containing architecture 30 | performance data 31 | micro_action_set (MicroActionSet): Set of micro actions defining the 32 | architecture search space 33 | overall_actual_y (torch.Tensor): Complete performance tensor for 34 | ranking calculations 35 | higher_is_better (bool): Whether higher performance values are better 36 | termination_threshold_ratio (float): Fraction of total architectures to 37 | evaluate as budget 38 | method_name (str): Name identifier for this method. 39 | Defaults to "Random Heuristic". 40 | 41 | Returns: 42 | Dict[str, Union[str, int, float, List, Optional[int]]]: Dictionary containing search results and performance metrics. 43 | 44 | - method (str): Method name 45 | - selected_graph_id (Optional[int]): Index of best found architecture 46 | - actual_y_perf_of_selected (float): Performance of selected architecture 47 | - selection_metric_value (float): Metric value used for selection 48 | - selected_graph_origin (str): Origin method name 49 | - discovered_count (int): Number of architectures evaluated 50 | - total_iterations_run (int): Number of random samples drawn 51 | - rank_position_overall (float): Rank among all architectures 52 | - percentile_overall (float): Percentile ranking 53 | - total_samples_overall (int): Total available architectures 54 | - performance_trajectory (List): Performance over time 55 | - total_evaluation_time (float): Time spent on evaluations 56 | - total_run_time (float): Total algorithm runtime 57 | 58 | Example: 59 | >>> results = random_heuristic_analysis( 60 | ... dataset=dataset, 61 | ... micro_action_set=micro_actions, 62 | ... overall_actual_y=y_tensor, 63 | ... higher_is_better=True, 64 | ... termination_threshold_ratio=0.05 65 | ... ) 66 | >>> print(f"Best architecture: {results['selected_graph_id']}") 67 | >>> print(f"Performance: {results['actual_y_perf_of_selected']:.4f}") 68 | >>> print(f"Evaluated: {results['discovered_count']} architectures") 69 | """ 70 | performance_cache = {} 71 | time_cache = {} 72 | performance_trajectory = [] 73 | total_evaluated_count = 0 74 | best_perf_so_far = float('-inf') if higher_is_better else float('inf') 75 | best_index_so_far = -1 76 | method_origin = "Random Heuristic" 77 | total_samples_overall = overall_actual_y.numel() if overall_actual_y is not None else 0 78 | total_evaluation_time = 0.0 79 | 80 | start_time = time.time() 81 | 82 | num_total_valid_graphs = len(micro_action_set.valid_edge_sets_list) 83 | if num_total_valid_graphs == 0: 84 | print(f"Error: No valid graphs available in MicroActionSet. Cannot proceed.") 85 | return { 86 | "method": method_name, "selected_graph_id": None, "actual_y_perf_of_selected": np.nan, 87 | "selection_metric_value": np.nan, "selected_graph_origin": method_origin, 88 | "discovered_count": 0, "total_iterations_run": 0, 89 | "rank_position_overall": np.nan, "percentile_overall": np.nan, 90 | "total_samples_overall": total_samples_overall, "performance_trajectory": [], 91 | "total_evaluation_time": 0.0, "total_run_time": 0.0 92 | } 93 | 94 | evaluation_budget = max(1, int(termination_threshold_ratio * num_total_valid_graphs)) 95 | print(f"{method_name}: Total valid graphs: {num_total_valid_graphs}. Budget set to {evaluation_budget} unique evaluations.") 96 | 97 | all_valid_indices = list(range(num_total_valid_graphs)) 98 | shuffled_indices = random.sample(all_valid_indices, k=num_total_valid_graphs) 99 | 100 | for index in shuffled_indices: 101 | if total_evaluated_count >= evaluation_budget: 102 | print(f"{method_name}: Budget reached ({total_evaluated_count}/{evaluation_budget}).") 103 | break 104 | 105 | initial_cache_size = len(performance_cache) 106 | perf = get_performance_for_index(index, dataset, performance_cache) 107 | 108 | if perf is not None: 109 | performance_cache[index] = perf 110 | 111 | eval_time = calculate_evaluation_time(index, dataset, time_cache) 112 | if eval_time is not None: 113 | if len(performance_cache) > initial_cache_size: 114 | total_evaluation_time += eval_time 115 | 116 | total_evaluated_count, best_perf_so_far, best_index_so_far = \ 117 | update_trajectory_and_best( 118 | index, perf, performance_cache, initial_cache_size, 119 | total_evaluated_count, performance_trajectory, 120 | best_perf_so_far, best_index_so_far, higher_is_better 121 | ) 122 | 123 | print(f"{method_name}: Finished evaluating {total_evaluated_count} graphs.") 124 | 125 | total_run_time = time.time() - start_time 126 | 127 | pad_trajectory(performance_trajectory, total_evaluated_count, evaluation_budget, method_name) 128 | 129 | final_selected_perf = best_perf_so_far if best_index_so_far != -1 else np.nan 130 | final_selected_index = best_index_so_far if best_index_so_far != -1 else None 131 | 132 | results = { 133 | "method": method_name, 134 | "selected_graph_id": final_selected_index, 135 | "actual_y_perf_of_selected": final_selected_perf, 136 | "selection_metric_value": final_selected_perf, 137 | "selected_graph_origin": method_origin, 138 | "discovered_count": total_evaluated_count, 139 | "total_iterations_run": total_evaluated_count, 140 | "rank_position_overall": np.nan, 141 | "percentile_overall": np.nan, 142 | "total_samples_overall": total_samples_overall, 143 | "performance_trajectory": performance_trajectory, 144 | "total_evaluation_time": total_evaluation_time, 145 | "total_run_time": total_run_time 146 | } 147 | 148 | if final_selected_index is not None and not np.isnan(final_selected_perf) and overall_actual_y is not None: 149 | rank_info = calculate_overall_rank( 150 | final_selected_perf, 151 | overall_actual_y, 152 | higher_is_better 153 | ) 154 | if rank_info: 155 | results["rank_position_overall"] = rank_info["rank_position_overall"] 156 | results["percentile_overall"] = rank_info["percentile_overall"] 157 | 158 | return results -------------------------------------------------------------------------------- /rdb2g_bench/dataset/models/modules/nn.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch_geometric.nn import HeteroConv, LayerNorm, SAGEConv 6 | from torch_geometric.typing import EdgeType, NodeType 7 | from .sage_conv_edge import SAGEConvEdge 8 | from .gin_conv import GINConv, GINEConv 9 | from .gps_conv import GPSConv 10 | 11 | 12 | class HeteroGraphSAGE(torch.nn.Module): 13 | def __init__( 14 | self, 15 | node_types: List[NodeType], 16 | edge_types: List[EdgeType], 17 | channels: int, 18 | aggr: str = "mean", 19 | num_layers: int = 2, 20 | ): 21 | super().__init__() 22 | 23 | self.convs = torch.nn.ModuleList() 24 | for _ in range(num_layers): 25 | conv = HeteroConv( 26 | { 27 | edge_type: SAGEConvEdge((channels, channels), channels, aggr=aggr) 28 | if edge_type[1].startswith('r2e_') or edge_type[1].startswith('rev_r2e_') 29 | else SAGEConv((channels, channels), channels, aggr=aggr) 30 | for edge_type in edge_types 31 | }, 32 | aggr="sum", 33 | ) 34 | self.convs.append(conv) 35 | 36 | self.norms = torch.nn.ModuleList() 37 | for _ in range(num_layers): 38 | norm_dict = torch.nn.ModuleDict() 39 | for node_type in node_types: 40 | norm_dict[node_type] = LayerNorm(channels, mode="node") 41 | self.norms.append(norm_dict) 42 | 43 | def reset_parameters(self): 44 | for conv in self.convs: 45 | conv.reset_parameters() 46 | for norm_dict in self.norms: 47 | for norm in norm_dict.values(): 48 | norm.reset_parameters() 49 | 50 | def forward( 51 | self, 52 | x_dict: Dict[NodeType, Tensor], 53 | edge_dict: Dict[str, Tensor], 54 | edge_index_dict: Dict[EdgeType, Tensor], 55 | num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None, 56 | num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None, 57 | ) -> Dict[NodeType, Tensor]: 58 | for _, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)): 59 | edge_attr_dict = {} 60 | 61 | for edge_type in edge_index_dict.keys(): 62 | if edge_type in edge_dict: 63 | edge_attr_dict[edge_type] = edge_dict[edge_type] 64 | 65 | x_dict = conv(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict) 66 | 67 | x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()} 68 | x_dict = {key: x.relu() for key, x in x_dict.items()} 69 | 70 | return x_dict 71 | 72 | class HeteroGIN(torch.nn.Module): 73 | def __init__( 74 | self, 75 | node_types: List[NodeType], 76 | edge_types: List[EdgeType], 77 | channels: int, 78 | aggr: str = "mean", 79 | num_layers: int = 2, 80 | ): 81 | super().__init__() 82 | 83 | self.convs = torch.nn.ModuleList() 84 | for _ in range(num_layers): 85 | nn_gin = torch.nn.Sequential( 86 | torch.nn.Linear(channels, channels * 2), 87 | torch.nn.ReLU(), 88 | torch.nn.Linear(channels * 2, channels), 89 | ) 90 | conv = HeteroConv( 91 | { 92 | edge_type: GINEConv(nn=nn_gin, train_eps=True) 93 | if edge_type[1].startswith('r2e_') or edge_type[1].startswith('rev_r2e_') 94 | else GINConv(nn=nn_gin, train_eps=True) 95 | for edge_type in edge_types 96 | }, 97 | aggr="sum", 98 | ) 99 | self.convs.append(conv) 100 | 101 | self.norms = torch.nn.ModuleList() 102 | for _ in range(num_layers): 103 | norm_dict = torch.nn.ModuleDict() 104 | for node_type in node_types: 105 | norm_dict[node_type] = LayerNorm(channels, mode="node") 106 | self.norms.append(norm_dict) 107 | 108 | def reset_parameters(self): 109 | for conv in self.convs: 110 | conv.reset_parameters() 111 | for norm_dict in self.norms: 112 | for norm in norm_dict.values(): 113 | norm.reset_parameters() 114 | 115 | def forward( 116 | self, 117 | x_dict: Dict[NodeType, Tensor], 118 | edge_dict: Dict[str, Tensor], 119 | edge_index_dict: Dict[EdgeType, Tensor], 120 | num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None, 121 | num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None, 122 | ) -> Dict[NodeType, Tensor]: 123 | for _, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)): 124 | edge_attr_dict = {} 125 | 126 | for edge_type in edge_index_dict.keys(): 127 | if edge_type in edge_dict: 128 | edge_attr_dict[edge_type] = edge_dict[edge_type] 129 | 130 | x_dict = conv(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict) 131 | 132 | x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()} 133 | x_dict = {key: x.relu() for key, x in x_dict.items()} 134 | 135 | return x_dict 136 | 137 | class HeteroGPS(torch.nn.Module): 138 | def __init__( 139 | self, 140 | node_types: List[NodeType], 141 | edge_types: List[EdgeType], 142 | channels: int, 143 | num_layers: int = 2, 144 | heads: int = 1, 145 | dropout: float = 0.0, 146 | attn_type: str = 'multihead', 147 | aggr: str = "mean", 148 | ): 149 | super().__init__() 150 | 151 | self.convs = torch.nn.ModuleList() 152 | for _ in range(num_layers): 153 | conv_dict = {} 154 | for edge_type in edge_types: 155 | local_mpnn = SAGEConvEdge( 156 | in_channels=(channels, channels), 157 | out_channels=channels, 158 | aggr=aggr 159 | ) 160 | gps_layer = GPSConv( 161 | channels=channels, 162 | conv=local_mpnn, 163 | heads=heads, 164 | dropout=dropout, 165 | attn_type=attn_type, 166 | norm=None, 167 | ) 168 | conv_dict[edge_type] = gps_layer 169 | 170 | conv = HeteroConv(conv_dict, aggr="sum") 171 | self.convs.append(conv) 172 | 173 | self.norms = torch.nn.ModuleList() 174 | for _ in range(num_layers): 175 | norm_dict = torch.nn.ModuleDict() 176 | for node_type in node_types: 177 | norm_dict[node_type] = LayerNorm(channels, mode="node") 178 | self.norms.append(norm_dict) 179 | 180 | def reset_parameters(self): 181 | for conv in self.convs: 182 | conv.reset_parameters() 183 | for norm_dict in self.norms: 184 | for norm in norm_dict.values(): 185 | norm.reset_parameters() 186 | 187 | def forward( 188 | self, 189 | x_dict: Dict[NodeType, Tensor], 190 | edge_dict: Dict[str, Tensor], 191 | edge_index_dict: Dict[EdgeType, Tensor], 192 | num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None, 193 | num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None, 194 | ) -> Dict[NodeType, Tensor]: 195 | for conv, norm_dict in zip(self.convs, self.norms): 196 | edge_attr_dict = {} 197 | for edge_type in edge_index_dict.keys(): 198 | if edge_type in edge_dict: 199 | edge_attr_dict[edge_type] = edge_dict[edge_type] 200 | 201 | x_dict = conv(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict) 202 | 203 | x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()} 204 | x_dict = {key: x.relu() for key, x in x_dict.items()} 205 | 206 | return x_dict 207 | -------------------------------------------------------------------------------- /rdb2g_bench/dataset/models/modules/gps_conv.py: -------------------------------------------------------------------------------- 1 | # Custom implementation of GPSConv based on torch_geometric.nn.conv.GPSConv 2 | 3 | import inspect 4 | from typing import Any, Dict, Optional, Union 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from torch.nn import Dropout, Linear, Sequential 10 | 11 | from torch_geometric.nn.attention import PerformerAttention 12 | from torch_geometric.nn.conv import MessagePassing 13 | from torch_geometric.nn.inits import reset 14 | from torch_geometric.nn.resolver import ( 15 | activation_resolver, 16 | normalization_resolver, 17 | ) 18 | from torch_geometric.typing import Adj, PairTensor 19 | from torch_geometric.utils import to_dense_batch 20 | 21 | 22 | class GPSConv(torch.nn.Module): 23 | r"""The general, powerful, scalable (GPS) graph transformer layer from the 24 | `"Recipe for a General, Powerful, Scalable Graph Transformer" 25 | `_ paper. 26 | 27 | The GPS layer is based on a 3-part recipe: 28 | 29 | 1. Inclusion of positional (PE) and structural encodings (SE) to the input 30 | features (done in a pre-processing step via 31 | :class:`torch_geometric.transforms`). 32 | 2. A local message passing layer (MPNN) that operates on the input graph. 33 | 3. A global attention layer that operates on the entire graph. 34 | 35 | .. note:: 36 | 37 | For an example of using :class:`GPSConv`, see 38 | `examples/graph_gps.py 39 | `_. 41 | 42 | Args: 43 | channels (int): Size of each input sample. 44 | conv (MessagePassing, optional): The local message passing layer. 45 | heads (int, optional): Number of multi-head-attentions. 46 | (default: :obj:`1`) 47 | dropout (float, optional): Dropout probability of intermediate 48 | embeddings. (default: :obj:`0.`) 49 | act (str or Callable, optional): The non-linear activation function to 50 | use. (default: :obj:`"relu"`) 51 | act_kwargs (Dict[str, Any], optional): Arguments passed to the 52 | respective activation function defined by :obj:`act`. 53 | (default: :obj:`None`) 54 | norm (str or Callable, optional): The normalization function to 55 | use. (default: :obj:`"batch_norm"`) 56 | norm_kwargs (Dict[str, Any], optional): Arguments passed to the 57 | respective normalization function defined by :obj:`norm`. 58 | (default: :obj:`None`) 59 | attn_type (str): Global attention type, :obj:`multihead` or 60 | :obj:`performer`. (default: :obj:`multihead`) 61 | attn_kwargs (Dict[str, Any], optional): Arguments passed to the 62 | attention layer. (default: :obj:`None`) 63 | """ 64 | def __init__( 65 | self, 66 | channels: int, 67 | conv: Optional[MessagePassing], 68 | heads: int = 1, 69 | dropout: float = 0.0, 70 | act: str = 'relu', 71 | act_kwargs: Optional[Dict[str, Any]] = None, 72 | norm: Optional[str] = 'batch_norm', 73 | norm_kwargs: Optional[Dict[str, Any]] = None, 74 | attn_type: str = 'multihead', 75 | attn_kwargs: Optional[Dict[str, Any]] = None, 76 | ): 77 | super().__init__() 78 | 79 | self.channels = channels 80 | self.conv = conv 81 | self.heads = heads 82 | self.dropout = dropout 83 | self.attn_type = attn_type 84 | 85 | attn_kwargs = attn_kwargs or {} 86 | if attn_type == 'multihead': 87 | self.attn = torch.nn.MultiheadAttention( 88 | channels, 89 | heads, 90 | batch_first=True, 91 | **attn_kwargs, 92 | ) 93 | elif attn_type == 'performer': 94 | self.attn = PerformerAttention( 95 | channels=channels, 96 | heads=heads, 97 | **attn_kwargs, 98 | ) 99 | else: 100 | # TODO: Support BigBird 101 | raise ValueError(f'{attn_type} is not supported') 102 | 103 | self.mlp = Sequential( 104 | Linear(channels, channels * 2), 105 | activation_resolver(act, **(act_kwargs or {})), 106 | Dropout(dropout), 107 | Linear(channels * 2, channels), 108 | Dropout(dropout), 109 | ) 110 | 111 | norm_kwargs = norm_kwargs or {} 112 | self.norm1 = normalization_resolver(norm, channels, **norm_kwargs) 113 | self.norm2 = normalization_resolver(norm, channels, **norm_kwargs) 114 | self.norm3 = normalization_resolver(norm, channels, **norm_kwargs) 115 | 116 | self.norm_with_batch = False 117 | if self.norm1 is not None: 118 | signature = inspect.signature(self.norm1.forward) 119 | self.norm_with_batch = 'batch' in signature.parameters 120 | 121 | def reset_parameters(self): 122 | r"""Resets all learnable parameters of the module.""" 123 | if self.conv is not None: 124 | self.conv.reset_parameters() 125 | self.attn._reset_parameters() 126 | reset(self.mlp) 127 | if self.norm1 is not None: 128 | self.norm1.reset_parameters() 129 | if self.norm2 is not None: 130 | self.norm2.reset_parameters() 131 | if self.norm3 is not None: 132 | self.norm3.reset_parameters() 133 | 134 | def forward( 135 | self, 136 | x: Union[Tensor, PairTensor], 137 | edge_index: Adj, 138 | batch: Optional[torch.Tensor] = None, 139 | edge_attr: Optional[Tensor] = None, 140 | **kwargs, 141 | ) -> Tensor: 142 | r"""Runs the forward pass of the module.""" 143 | hs = [] 144 | 145 | # Determine input for residual connections and global attention based on x type 146 | x_for_residual_and_global = x[1] if isinstance(x, tuple) else x 147 | 148 | if self.conv is not None: # Local MPNN. 149 | # Pass edge_attr to the local message passing layer 150 | h_local = self.conv(x, edge_index, edge_attr=edge_attr, **kwargs) 151 | h_local = F.dropout(h_local, p=self.dropout, training=self.training) 152 | h_local = h_local + x_for_residual_and_global 153 | if self.norm1 is not None: 154 | if self.norm_with_batch: 155 | h_local = self.norm1(h_local, batch=batch) 156 | else: 157 | h_local = self.norm1(h_local) 158 | hs.append(h_local) 159 | 160 | # Global attention transformer-style model. 161 | # Use x_for_residual_and_global for global attention input 162 | h_global_in, mask = to_dense_batch(x_for_residual_and_global, batch) 163 | 164 | if isinstance(self.attn, torch.nn.MultiheadAttention): 165 | h_global, _ = self.attn(h_global_in, h_global_in, h_global_in, key_padding_mask=~mask, 166 | need_weights=False) 167 | elif isinstance(self.attn, PerformerAttention): 168 | h_global = self.attn(h_global_in, mask=mask) 169 | 170 | h_global = h_global[mask] 171 | h_global = F.dropout(h_global, p=self.dropout, training=self.training) 172 | h_global = h_global + x_for_residual_and_global # Residual connection. 173 | if self.norm2 is not None: 174 | if self.norm_with_batch: 175 | h_global = self.norm2(h_global, batch=batch) 176 | else: 177 | h_global = self.norm2(h_global) 178 | hs.append(h_global) 179 | 180 | out = sum(hs) # Combine local and global outputs. 181 | 182 | out = out + self.mlp(out) 183 | if self.norm3 is not None: 184 | if self.norm_with_batch: 185 | out = self.norm3(out, batch=batch) 186 | else: 187 | out = self.norm3(out) 188 | 189 | return out 190 | 191 | def __repr__(self) -> str: 192 | return (f'{self.__class__.__name__}({self.channels}, ' 193 | f'conv={self.conv}, heads={self.heads}, ' 194 | f'attn_type={self.attn_type})') -------------------------------------------------------------------------------- /rdb2g_bench/benchmark/llm/prompts/prompt.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def get_score_feedback(initial_score, past_score, current_score, higher_is_better, last_action_num): 4 | score_feedback = f"In history actions, after the last {last_action_num} actions, the score has changed from {past_score:.4f} to {current_score:.4f}.\n" 5 | score_feedback += ( 6 | "Since a **higher** score is better, " if higher_is_better else "Since a **lower** score is better, " 7 | ) 8 | if (current_score > past_score and higher_is_better) or (current_score < past_score and not higher_is_better): 9 | score_feedback += f"the performance has **improved**." 10 | # if score has improved but still lower than initial score 11 | if (initial_score == current_score) or (initial_score > current_score and higher_is_better) or (initial_score < current_score and not higher_is_better): 12 | score_feedback += f"\nHowever, the initial score was {initial_score:.4f}, Please try other actions to improve the performance." 13 | else: 14 | score_feedback += "the performance has **decreased**. Please consider either reversing the previous action or exploring alternative actions to improve the schema." 15 | return score_feedback 16 | 17 | 18 | def get_action_info(actions= [ 'add_fk_pk_edge','remove_fk_pk_edge', 'convert_edge_to_row', 'convert_row_to_edge' ], #'add_row2edge_edge', 'remove_row2edge_edge', 19 | edge_info={"add_fk_pk_edge": "", "remove_fk_pk_edge": "", "convert_row_to_edge": "", "convert_edge_to_row": ""}): #, 'add_row2edge_edge', 'remove_row2edge_edge' 20 | with open("./prompts/action.json", "r") as f: 21 | action_info = json.load(f) 22 | action = "" 23 | for i, action_name in enumerate(actions): 24 | if edge_info[action_name] != "": 25 | action += f"{action_info[action_name]}\n{edge_info[action_name]}\n\n" 26 | return action 27 | 28 | def augmentation_prompt(dataset_name, task_name, edge_info, history_actions="", error_msg="", past_score=0, current_score=0, higher_is_better=True, initial_score=0, budget=10, initial_attempt=True, last_action_num=1): 29 | with open("./prompts/data_stats.json", "r") as f: 30 | stats = json.load(f) 31 | stats = stats[dataset_name] 32 | with open("./prompts/schema.json", "r") as f: 33 | schema = json.load(f) 34 | schema = schema[dataset_name] 35 | with open("./prompts/task.json", "r") as f: 36 | task = json.load(f) 37 | task = task[task_name] 38 | 39 | action_info = get_action_info(edge_info=edge_info) 40 | score_feedback = "" 41 | 42 | 43 | if error_msg != "": 44 | error_msg = f"Warning: The following actions will cause errors: \n{error_msg}" 45 | 46 | 47 | if initial_attempt: 48 | return f"""You are expected to construct graph schema based on the original inputs. 49 | You will be given an original schema represented in the dictionary format: 50 | 51 | 1. dataset_name: name of the dataset 52 | 2. tables: meta data for list of tables, each one will present following attributes 53 | 1. name: table name 54 | 2. columns: list of columns, each column will have following attributes 55 | 1. name: column name 56 | 2. dtype: column type, can be either text, categorical, float, primary_key, foreign_key, or multi_category. primary_key and foreign_key are two special types of categorical columns, which presents a structural relationship with other tables. Multi_category means this column is of list type, and each cell main contains a list of categorical values. After a column is set as primary_key or foreign_key, it should not be changed to other types. 57 | 3. link_to (optional): if this column is a foreign key, point to which primary key from which table 58 | 3. statistics of the table: statistics of the column value of tables. These statistics can be used to help you determine the characteristics of the columns. 59 | 60 | 61 | Here are the documents of the actions: 62 | {action_info} 63 | {error_msg} 64 | 65 | Now, you need to: 66 | 67 | 1. Actively think about which actions (from the list below) should be conducted to improve the schema. 68 | 2. Output all actions you can think of from the above list to make the schema better, and output your selections in the following format: 69 | If multiple actions are needed, please list **all** of them. 70 | 71 | 72 | [ 73 | {{ 74 | "explanation": , 75 | "action": , 76 | "parameters": 77 | }}, 78 | {{ 79 | "explanation": , 80 | "action": , 81 | "parameters": 82 | }}, 83 | {{ 84 | "explanation": , 85 | "action": , 86 | "parameters": 87 | }}, 88 | ... 89 | ] 90 | 91 | 92 | 93 | 94 | {stats} 95 | 96 | 97 | {task} 98 | 99 | 100 | {schema} 101 | 102 | 103 | 104 | 105 | Return your output in the json format inside .""" 106 | 107 | else: 108 | # Add history actions and score feedback only if there are history actions 109 | if history_actions.strip() != "": 110 | history_actions = f"History Actions: \n{history_actions}" 111 | score_feedback = get_score_feedback(initial_score, past_score, current_score, higher_is_better, error_msg, last_action_num) 112 | return f"""You are expected to construct graph schema based on the original inputs. 113 | You will be given an original schema represented in the dictionary format: 114 | 115 | 1. dataset_name: name of the dataset 116 | 2. tables: meta data for list of tables, each one will present following attributes 117 | 1. name: table name 118 | 2. columns: list of columns, each column will have following attributes 119 | 1. name: column name 120 | 2. dtype: column type, can be either text, categorical, float, primary_key, foreign_key, or multi_category. primary_key and foreign_key are two special types of categorical columns, which presents a structural relationship with other tables. Multi_category means this column is of list type, and each cell main contains a list of categorical values. After a column is set as primary_key or foreign_key, it should not be changed to other types. 121 | 3. link_to (optional): if this column is a foreign key, point to which primary key from which table 122 | 3. statistics of the table: statistics of the column value of tables. These statistics can be used to help you determine the characteristics of the columns. 123 | 124 | 125 | Here are the documents of the actions: 126 | {action_info} 127 | {error_msg} 128 | 129 | Now, you need to 130 | 1. Actively think about whether any one of the 4 actions should be conducted 131 | 2. Output all actions you can think of from the above list to make the schema better, and output your selections in the following format: 132 | If multiple actions are needed, please list **all** of them. 133 | 134 | 135 | [ 136 | {{ 137 | "explanation": , 138 | "action": , 139 | "parameters": 140 | }}, 141 | {{ 142 | "explanation": , 143 | "action": , 144 | "parameters": 145 | }}, 146 | {{ 147 | "explanation": , 148 | "action": , 149 | "parameters": 150 | }}, 151 | ... 152 | ] 153 | 154 | 155 | 3. If you think there's no more action, you can output 156 | 157 | None 158 | 159 | 160 | {history_actions} 161 | 162 | 163 | 164 | {stats} 165 | 166 | 167 | {task} 168 | 169 | 170 | {schema} 171 | 172 | 173 | 174 | {score_feedback} 175 | 176 | Note that the current schema may **not be optimal**, so other actions may yield better results. 177 | Please **only** halt the program with `None` if you believe no further actions are worth trying. 178 | You can try {budget} more times to improve the performance. 179 | Return your output in the json format inside .""" 180 | 181 | -------------------------------------------------------------------------------- /rdb2g_bench/dataset/models/modules/gin_conv.py: -------------------------------------------------------------------------------- 1 | # Custom implementation of GINConv based on torch_geometric.nn.conv.GINConv 2 | 3 | from typing import Callable, Optional, Union 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from torch_geometric.nn.conv import MessagePassing 9 | from torch_geometric.nn.dense.linear import Linear 10 | from torch_geometric.nn.inits import reset 11 | from torch_geometric.typing import ( 12 | Adj, 13 | OptPairTensor, 14 | OptTensor, 15 | Size, 16 | SparseTensor, 17 | ) 18 | from torch_geometric.utils import spmm 19 | 20 | 21 | class GINConv(MessagePassing): 22 | r"""The graph isomorphism operator from the `"How Powerful are 23 | Graph Neural Networks?" `_ paper. 24 | 25 | .. math:: 26 | \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot 27 | \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) 28 | 29 | or 30 | 31 | .. math:: 32 | \mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + 33 | (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right), 34 | 35 | here :math:`h_{\mathbf{\Theta}}` denotes a neural network, *.i.e.* an MLP. 36 | 37 | Args: 38 | nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that 39 | maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to 40 | shape :obj:`[-1, out_channels]`, *e.g.*, defined by 41 | :class:`torch.nn.Sequential`. 42 | eps (float, optional): (Initial) :math:`\epsilon`-value. 43 | (default: :obj:`0.`) 44 | train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon` 45 | will be a trainable parameter. (default: :obj:`False`) 46 | **kwargs (optional): Additional arguments of 47 | :class:`torch_geometric.nn.conv.MessagePassing`. 48 | 49 | Shapes: 50 | - **input:** 51 | node features :math:`(|\mathcal{V}|, F_{in})` or 52 | :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` 53 | if bipartite, 54 | edge indices :math:`(2, |\mathcal{E}|)` 55 | - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or 56 | :math:`(|\mathcal{V}_t|, F_{out})` if bipartite 57 | """ 58 | def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False, 59 | **kwargs): 60 | kwargs.setdefault('aggr', 'add') 61 | super().__init__(**kwargs) 62 | self.nn = nn 63 | self.initial_eps = eps 64 | if train_eps: 65 | self.eps = torch.nn.Parameter(torch.empty(1)) 66 | else: 67 | self.register_buffer('eps', torch.empty(1)) 68 | self.reset_parameters() 69 | 70 | def reset_parameters(self): 71 | super().reset_parameters() 72 | reset(self.nn) 73 | self.eps.data.fill_(self.initial_eps) 74 | 75 | def forward( 76 | self, 77 | x: Union[Tensor, OptPairTensor], 78 | edge_index: Adj, 79 | size: Size = None, 80 | ) -> Tensor: 81 | 82 | if isinstance(x, Tensor): 83 | x = (x, x) 84 | 85 | # propagate_type: (x: OptPairTensor) 86 | out = self.propagate(edge_index, x=x, size=size) 87 | 88 | x_r = x[1] 89 | if x_r is not None: 90 | out = out + (1 + self.eps) * x_r 91 | 92 | return self.nn(out) 93 | 94 | def message(self, x_j: Tensor) -> Tensor: 95 | return x_j 96 | 97 | def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor: 98 | if isinstance(adj_t, SparseTensor): 99 | adj_t = adj_t.set_value(None, layout=None) 100 | return spmm(adj_t, x[0], reduce=self.aggr) 101 | 102 | def __repr__(self) -> str: 103 | return f'{self.__class__.__name__}(nn={self.nn})' 104 | 105 | 106 | class GINEConv(MessagePassing): 107 | r"""The modified :class:`GINConv` operator from the `"Strategies for 108 | Pre-training Graph Neural Networks" `_ 109 | paper. 110 | 111 | .. math:: 112 | \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot 113 | \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathrm{ReLU} 114 | ( \mathbf{x}_j + \mathbf{e}_{j,i} ) \right) 115 | 116 | that is able to incorporate edge features :math:`\mathbf{e}_{j,i}` into 117 | the aggregation procedure. 118 | 119 | Args: 120 | nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that 121 | maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to 122 | shape :obj:`[-1, out_channels]`, *e.g.*, defined by 123 | :class:`torch.nn.Sequential`. 124 | eps (float, optional): (Initial) :math:`\epsilon`-value. 125 | (default: :obj:`0.`) 126 | train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon` 127 | will be a trainable parameter. (default: :obj:`False`) 128 | edge_dim (int, optional): Edge feature dimensionality. If set to 129 | :obj:`None`, node and edge feature dimensionality is expected to 130 | match. Other-wise, edge features are linearly transformed to match 131 | node feature dimensionality. (default: :obj:`None`) 132 | **kwargs (optional): Additional arguments of 133 | :class:`torch_geometric.nn.conv.MessagePassing`. 134 | 135 | Shapes: 136 | - **input:** 137 | node features :math:`(|\mathcal{V}|, F_{in})` or 138 | :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` 139 | if bipartite, 140 | edge indices :math:`(2, |\mathcal{E}|)`, 141 | edge features :math:`(|\mathcal{E}|, D)` *(optional)* 142 | - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or 143 | :math:`(|\mathcal{V}_t|, F_{out})` if bipartite 144 | """ 145 | def __init__(self, nn: torch.nn.Module, eps: float = 0., 146 | train_eps: bool = False, edge_dim: Optional[int] = None, 147 | **kwargs): 148 | kwargs.setdefault('aggr', 'add') 149 | super().__init__(**kwargs) 150 | self.nn = nn 151 | self.initial_eps = eps 152 | if train_eps: 153 | self.eps = torch.nn.Parameter(torch.empty(1)) 154 | else: 155 | self.register_buffer('eps', torch.empty(1)) 156 | if edge_dim is not None: 157 | if isinstance(self.nn, torch.nn.Sequential): 158 | nn = self.nn[0] 159 | if hasattr(nn, 'in_features'): 160 | in_channels = nn.in_features 161 | elif hasattr(nn, 'in_channels'): 162 | in_channels = nn.in_channels 163 | else: 164 | raise ValueError("Could not infer input channels from `nn`.") 165 | self.lin = Linear(edge_dim, in_channels) 166 | 167 | else: 168 | self.lin = None 169 | self.reset_parameters() 170 | 171 | def reset_parameters(self): 172 | reset(self.nn) 173 | self.eps.data.fill_(self.initial_eps) 174 | if self.lin is not None: 175 | self.lin.reset_parameters() 176 | 177 | def forward( 178 | self, 179 | x: Union[Tensor, OptPairTensor], 180 | edge_index: Adj, 181 | edge_attr: OptTensor = None, 182 | size: Size = None, 183 | ) -> Tensor: 184 | if isinstance(x, Tensor): 185 | x = (x, x) 186 | 187 | # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) 188 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size) 189 | 190 | x_r = x[1] 191 | if x_r is not None: 192 | out = out + (1 + self.eps) * x_r 193 | 194 | return self.nn(out) 195 | 196 | def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor: 197 | if edge_attr is None: 198 | edge_attr = torch.zeros_like(x_j) 199 | if self.lin is None and x_j.size(-1) != edge_attr.size(-1): 200 | raise ValueError("Node and edge feature dimensionalities do not " 201 | "match. Consider setting the 'edge_dim' " 202 | "attribute of 'GINEConv'") 203 | if self.lin is not None: 204 | edge_attr = self.lin(edge_attr) 205 | 206 | return (x_j + edge_attr).relu() 207 | 208 | def __repr__(self) -> str: 209 | return f'{self.__class__.__name__}(nn={self.nn})' 210 | -------------------------------------------------------------------------------- /rdb2g_bench/common/search_space/row2ne_search_space.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple, Union 2 | 3 | import copy 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from torch_geometric.data import HeteroData 8 | 9 | class Row2NESearchSpace(): 10 | def __init__(self, dataset: str, task: str, hetero_data: HeteroData): 11 | self.dataset = dataset 12 | self.task = task 13 | self.data = hetero_data 14 | self.node_types = sorted(list(hetero_data.node_types)) 15 | self.edge_types = sorted(list(hetero_data.edge_types)) 16 | self.possible_edge_types = self.find_possible_edges() 17 | self.full_edges = self.find_full_edges() 18 | self._cached_r2e_key = None 19 | self._cached_converted_data = None 20 | 21 | def find_possible_edges(self): 22 | possible_edges = [] 23 | for node_type in self.node_types: 24 | outgoing_edges = [(src, rel, dst) for src, rel, dst in self.edge_types 25 | if src == node_type and rel.startswith('f2p')] 26 | incoming_edges = [(src, rel, dst) for src, rel, dst in self.edge_types 27 | if dst == node_type and rel.startswith('f2p')] 28 | if len(outgoing_edges) == 2 and len(incoming_edges) == 0: 29 | possible_edges.append(node_type) 30 | 31 | return sorted(possible_edges) 32 | 33 | def find_full_edges(self): 34 | full_edges = [] 35 | 36 | for node_type in self.possible_edge_types: 37 | connected_nodes = [] 38 | for src, rel, dst in self.edge_types: 39 | if src == node_type and rel.startswith('f2p'): 40 | connected_nodes.append(dst) 41 | 42 | if len(connected_nodes) == 1: 43 | src_node, dst_node = connected_nodes[0], connected_nodes[0] 44 | elif len(connected_nodes) == 2: 45 | src_node, dst_node = connected_nodes[0], connected_nodes[1] 46 | else: 47 | raise ValueError(f"Invalid number of connected nodes: {len(connected_nodes)}") 48 | 49 | src_edge_type, dst_edge_type = None, None 50 | for edge_type in self.edge_types: 51 | if edge_type[0] == node_type and edge_type[2] == src_node and edge_type != dst_edge_type and src_edge_type is None: 52 | src_edge_type = edge_type 53 | if edge_type[0] == node_type and edge_type[2] == dst_node and edge_type != src_edge_type and dst_edge_type is None: 54 | dst_edge_type = edge_type 55 | 56 | if src_edge_type and dst_edge_type: 57 | full_edges.append((src_node, f"r2e_{node_type}", dst_node)) 58 | 59 | return full_edges + [edge for edge in sorted(self.data.edge_types) if edge[1].startswith("f2p")] 60 | 61 | def convert_row_to_edge(self, edges: torch.Tensor): 62 | if not isinstance(edges, torch.Tensor): 63 | edges = torch.tensor(edges) 64 | 65 | # Determine the cache key based on the row-to-edge part of the edges tensor 66 | r2e_part = edges[:len(self.possible_edge_types)] 67 | current_r2e_key = tuple(r2e_part.tolist()) 68 | if current_r2e_key == self._cached_r2e_key: 69 | print("Using cached converted data") 70 | return copy.deepcopy(self._cached_converted_data) 71 | 72 | converted_data = HeteroData() 73 | # Copy node data and ensure tensors are contiguous 74 | nodes_to_convert = set() 75 | for i, is_selected in enumerate(r2e_part): 76 | if is_selected == 1: 77 | _, rel, _ = self.full_edges[i] 78 | nodes_to_convert.add(rel[4:]) 79 | 80 | for node_type in self.node_types: 81 | for attr_name, attr_value in self.data[node_type].items(): 82 | if node_type in nodes_to_convert and attr_name == 'time': 83 | continue 84 | if isinstance(attr_value, torch.Tensor): 85 | converted_data[node_type][attr_name] = attr_value.clone().detach().contiguous() 86 | else: 87 | converted_data[node_type][attr_name] = attr_value 88 | 89 | # Copy existing edges and ensure tensors are contiguous 90 | for edge_type in self.edge_types: 91 | src_node, edge_name, dst_node = edge_type 92 | if src_node not in nodes_to_convert and dst_node not in nodes_to_convert: 93 | converted_data[edge_type].edge_index = self.data[edge_type].edge_index.clone().detach().contiguous() 94 | 95 | # Convert nodes to edges 96 | for node_type in nodes_to_convert: 97 | connected_nodes = [] 98 | for src, rel, dst in self.edge_types: 99 | if src == node_type and rel.startswith('f2p'): 100 | connected_nodes.append(dst) 101 | 102 | if len(connected_nodes) == 1: 103 | src_node, dst_node = connected_nodes[0], connected_nodes[0] 104 | elif len(connected_nodes) == 2: 105 | src_node, dst_node = connected_nodes[0], connected_nodes[1] 106 | else: 107 | raise ValueError(f"Invalid number of connected nodes: {len(connected_nodes)}") 108 | 109 | # Find relevant edge types 110 | src_edge_type, dst_edge_type = None, None 111 | for edge_type in self.edge_types: 112 | if edge_type[0] == node_type and edge_type[2] == src_node and edge_type != dst_edge_type and src_edge_type is None: 113 | src_edge_type = edge_type 114 | if edge_type[0] == node_type and edge_type[2] == dst_node and edge_type != src_edge_type and dst_edge_type is None: 115 | dst_edge_type = edge_type 116 | 117 | if src_edge_type and dst_edge_type: 118 | # Create new edge (from src_node to dst_node) 119 | new_edge_type = (src_node, f"r2e_{node_type}", dst_node) 120 | reverse_edge_type = (dst_node, f"rev_r2e_{node_type}", src_node) 121 | 122 | src_edges = self.data[src_edge_type].edge_index.clone().detach().contiguous() 123 | dst_edges = self.data[dst_edge_type].edge_index.clone().detach().contiguous() 124 | 125 | # Create mappings using numpy operations 126 | src_node_ids = src_edges[0].numpy() 127 | src_table_ids = src_edges[1].numpy() 128 | src_mapping = dict(zip(src_node_ids, src_table_ids)) 129 | 130 | dst_node_ids = dst_edges[0].numpy() 131 | dst_table_ids = dst_edges[1].numpy() 132 | dst_mapping = dict(zip(dst_node_ids, dst_table_ids)) 133 | 134 | # Get common node ids that exist in both mappings 135 | node_ids = np.arange(len(self.data[node_type].tf)) 136 | mask = np.isin(node_ids, list(src_mapping.keys())) & np.isin(node_ids, list(dst_mapping.keys())) 137 | valid_node_ids = node_ids[mask] 138 | 139 | # Construct edge pairs 140 | if len(valid_node_ids) > 0: 141 | # Use vectorized operations to get src and dst ids 142 | max_node_id = max(max(src_mapping.keys()), max(dst_mapping.keys())) + 1 143 | src_lookup = np.full(max_node_id, -1) 144 | dst_lookup = np.full(max_node_id, -1) 145 | 146 | src_keys = np.array(list(src_mapping.keys()), dtype=np.int64) 147 | src_values = np.array(list(src_mapping.values()), dtype=np.int64) 148 | src_lookup[src_keys] = src_values 149 | 150 | dst_keys = np.array(list(dst_mapping.keys()), dtype=np.int64) 151 | dst_values = np.array(list(dst_mapping.values()), dtype=np.int64) 152 | dst_lookup[dst_keys] = dst_values 153 | 154 | # Apply vectorized lookups 155 | src_ids = src_lookup[valid_node_ids] 156 | dst_ids = dst_lookup[valid_node_ids] 157 | new_edges = np.stack([src_ids, dst_ids], axis=1) 158 | 159 | if len(new_edges) > 0: 160 | new_edge_index = torch.tensor(new_edges, dtype=torch.long).t().contiguous() 161 | converted_data[new_edge_type].edge_index = new_edge_index 162 | converted_data[new_edge_type].mapped_node_ids = torch.tensor(valid_node_ids, dtype=torch.long).clone().detach().contiguous() 163 | 164 | reverse_edge_index = torch.stack([new_edge_index[1], new_edge_index[0]], dim=0).contiguous() 165 | converted_data[reverse_edge_type].edge_index = reverse_edge_index 166 | converted_data[reverse_edge_type].mapped_node_ids = torch.tensor(valid_node_ids, dtype=torch.long).clone().detach().contiguous() 167 | 168 | self._cached_converted_data = converted_data.to('cpu') 169 | self._cached_r2e_key = current_r2e_key 170 | 171 | return converted_data -------------------------------------------------------------------------------- /rdb2g_bench/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from pathlib import Path 4 | from typing import Optional, List, Dict, Union 5 | from datasets import load_dataset 6 | 7 | from .dataloader import RDB2GBench 8 | 9 | 10 | def load_rdb2g_bench( 11 | result_dir: str = "./results", 12 | download: bool = True, 13 | cache_dir: Optional[str] = None, 14 | tag: str = "hf" 15 | ) -> RDB2GBench: 16 | """ 17 | Load RDB2G-Bench results from local directory, download if missing. 18 | 19 | This function serves as the main entry point for accessing RDB2G-Bench data. 20 | It first checks for existing data in the specified directory, and if no data 21 | is found and download is enabled, it automatically downloads the dataset 22 | from Hugging Face Hub. 23 | 24 | Args: 25 | result_dir (str): Directory containing the benchmark results. 26 | Defaults to "./results". 27 | download (bool): Whether to download data if results are missing. 28 | Defaults to True. 29 | cache_dir (Optional[str]): Cache directory for Hugging Face datasets. 30 | If None, uses default HF cache location. 31 | tag (str): Tag to use for downloaded data organization. 32 | Defaults to "hf". 33 | 34 | Returns: 35 | RDB2GBench: Benchmark object for accessing organized results with 36 | hierarchical access pattern: bench[dataset][task][idx]. 37 | 38 | Raises: 39 | RuntimeError: If no valid data is found in the result directory 40 | after download attempts. 41 | 42 | Example: 43 | >>> bench = load_rdb2g_bench("./my_results") 44 | >>> available = bench.get_available() 45 | >>> result = bench['rel-f1']['driver-top3'][0] 46 | """ 47 | result_path = Path(result_dir) 48 | tables_path = result_path / "tables" 49 | 50 | has_data = False 51 | if tables_path.exists(): 52 | for dataset_dir in tables_path.iterdir(): 53 | if dataset_dir.is_dir(): 54 | for task_dir in dataset_dir.iterdir(): 55 | if task_dir.is_dir(): 56 | for tag_dir in task_dir.iterdir(): 57 | if tag_dir.is_dir(): 58 | # Check for GNN subdirectories with CSV files 59 | for gnn_dir in tag_dir.iterdir(): 60 | if gnn_dir.is_dir() and list(gnn_dir.glob("*.csv")): 61 | has_data = True 62 | break 63 | if has_data: 64 | break 65 | if has_data: 66 | break 67 | if has_data: 68 | break 69 | 70 | if not has_data and download: 71 | print(f"No data found in {result_dir}, downloading from Hugging Face...") 72 | download_rdb2g_bench( 73 | result_dir=result_dir, 74 | cache_dir=cache_dir, 75 | tag=tag 76 | ) 77 | 78 | bench = RDB2GBench(result_dir) 79 | 80 | if not bench.get_available(): 81 | raise RuntimeError(f"No valid data found in {result_dir}") 82 | 83 | return bench 84 | 85 | 86 | def download_rdb2g_bench( 87 | result_dir: str = "./results", 88 | cache_dir: Optional[str] = None, 89 | dataset_names: Optional[List[str]] = None, 90 | task_names: Optional[List[str]] = None, 91 | gnn_names: Optional[List[str]] = None, 92 | tag: str = "hf", 93 | ) -> Dict[str, List[str]]: 94 | """ 95 | Download RDB2G-Bench dataset from Hugging Face and organize it by dataset/task. 96 | 97 | This function downloads the complete or filtered RDB2G-Bench dataset from 98 | Hugging Face Hub and organizes it into a structured directory format. 99 | The data is grouped by dataset and task, with separate CSV files for each 100 | random seed. 101 | 102 | Args: 103 | result_dir (str): Directory to save the organized results. 104 | Will be created if it doesn't exist. Defaults to "./results". 105 | cache_dir (Optional[str]): Cache directory for Hugging Face datasets. 106 | If None, uses default HF cache location. 107 | dataset_names (Optional[List[str]]): List of specific datasets to download. 108 | If None, downloads all available datasets. 109 | task_names (Optional[List[str]]): List of specific tasks to download. 110 | If None, downloads all available tasks. 111 | gnn_names (Optional[List[str]]): List of specific GNN models to download. 112 | If None, downloads all available GNN models. 113 | tag (str): Tag to identify the download and organize files. 114 | Defaults to "hf". 115 | 116 | Returns: 117 | Dict[str, List[str]]: Dictionary mapping dataset/task combinations to lists of saved file paths. 118 | 119 | Keys are in format "dataset/task". 120 | 121 | Example: 122 | >>> saved_files = download_rdb2g_bench( 123 | ... dataset_names=['rel-f1'], 124 | ... task_names=['driver-top3'], 125 | ... gnn_names=['GraphSAGE'] 126 | ... ) 127 | >>> print(saved_files) 128 | {'rel-f1/driver-top3': ['./results/tables/rel-f1/driver-top3/hf/GraphSAGE/0.csv', ...]} 129 | """ 130 | result_dir = Path(result_dir) 131 | result_dir.mkdir(parents=True, exist_ok=True) 132 | 133 | dataset = load_dataset( 134 | "kaistdata/RDB2G-Bench", 135 | cache_dir=cache_dir, 136 | split="train" 137 | ) 138 | 139 | df = dataset.to_pandas() 140 | 141 | if dataset_names is not None: 142 | df = df[df['dataset'].isin(dataset_names)] 143 | 144 | if task_names is not None: 145 | df = df[df['task'].isin(task_names)] 146 | 147 | if gnn_names is not None: 148 | df = df[df['gnn'].isin(gnn_names)] 149 | 150 | if df.empty: 151 | return {} 152 | 153 | saved_files = {} 154 | grouped = df.groupby(['dataset', 'task']) 155 | 156 | for (dataset_name, task_name), group_df in grouped: 157 | group_df = group_df.sort_values(['seed', 'idx']) 158 | 159 | combination_key = f"{dataset_name}/{task_name}" 160 | saved_files[combination_key] = [] 161 | 162 | # Group by GNN to create separate directories 163 | for gnn_name, gnn_group_df in group_df.groupby('gnn'): 164 | gnn_dir = result_dir / "tables" / dataset_name / task_name / tag / gnn_name 165 | gnn_dir.mkdir(parents=True, exist_ok=True) 166 | 167 | for seed, seed_df in gnn_group_df.groupby('seed'): 168 | output_df = seed_df[[ 169 | 'idx', 'graph', 'train_metric', 'valid_metric', 'test_metric', 'params', 170 | 'train_time', 'valid_time', 'test_time', 'gnn' 171 | ]].copy() 172 | 173 | filename = f"{seed}.csv" 174 | filepath = gnn_dir / filename 175 | output_df.to_csv(filepath, index=False) 176 | saved_files[combination_key].append(str(filepath)) 177 | 178 | return saved_files 179 | 180 | 181 | def get_dataset_stats(cache_dir: Optional[str] = None) -> pd.DataFrame: 182 | """ 183 | Get comprehensive statistics about all available datasets and tasks in RDB2G-Bench. 184 | 185 | This function loads the complete dataset from Hugging Face and computes 186 | aggregate statistics for each dataset/task combination, including counts 187 | of unique indices and seeds, as well as performance metrics statistics. 188 | 189 | Args: 190 | cache_dir (Optional[str]): Cache directory for Hugging Face datasets. 191 | If None, uses default HF cache location. 192 | 193 | Returns: 194 | pd.DataFrame: DataFrame with statistical information about all datasets and tasks. 195 | 196 | - dataset: Dataset name 197 | - task: Task name 198 | - gnn: GNN model name 199 | - idx: Number of unique graph configurations 200 | - seed: Number of random seeds 201 | - test_metric_mean: Mean test performance 202 | - test_metric_std: Standard deviation of test performance 203 | - test_metric_min: Minimum test performance 204 | - test_metric_max: Maximum test performance 205 | 206 | Example: 207 | >>> stats = get_dataset_stats() 208 | >>> print(stats.head()) 209 | dataset task gnn idx seed test_metric_mean test_metric_std ... 210 | rel-f1 driver-top3 GraphSAGE 50 10 0.8542 0.0123 ... 211 | """ 212 | dataset = load_dataset( 213 | "kaistdata/RDB2G-Bench", 214 | cache_dir=cache_dir, 215 | split="train" 216 | ) 217 | 218 | df = dataset.to_pandas() 219 | 220 | stats = df.groupby(['dataset', 'task', 'gnn']).agg({ 221 | 'idx': 'nunique', 222 | 'seed': 'nunique', 223 | 'test_metric': ['mean', 'std', 'min', 'max'] 224 | }).round(4) 225 | 226 | stats.columns = ['_'.join(col).strip() for col in stats.columns] 227 | 228 | stats = stats.rename(columns={ 229 | 'idx_nunique': 'idx', 230 | 'seed_nunique': 'seed', 231 | }) 232 | stats = stats.reset_index() 233 | 234 | return stats 235 | -------------------------------------------------------------------------------- /rdb2g_bench/benchmark/baselines/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Tuple 3 | from ..dataset import PerformancePredictionDataset 4 | import numpy as np 5 | 6 | def calculate_overall_rank(selected_actual_y: float, 7 | overall_actual_y: torch.Tensor, 8 | higher_is_better: bool): 9 | """ 10 | Calculates the rank (position and percentile) of a selected performance. 11 | 12 | Args: 13 | selected_actual_y (float): The performance score to rank 14 | overall_actual_y (torch.Tensor): Tensor containing all performance values 15 | higher_is_better (bool): Whether higher values indicate better performance 16 | 17 | Returns: 18 | Optional[Dict]: Dictionary containing rank information with keys: 19 | 20 | rank_position_overall (int): Position in the overall ranking 21 | percentile_overall (float): Percentile ranking (0-100) 22 | total_samples_overall (int): Total number of samples in ranking 23 | 24 | Returns None if overall_actual_y is empty. 25 | """ 26 | if overall_actual_y.numel() == 0: 27 | print("Warning: Overall actual y tensor is empty, cannot calculate overall rank.") 28 | return None 29 | 30 | overall_actual_y = overall_actual_y.squeeze() 31 | if overall_actual_y.ndim == 0: 32 | overall_actual_y = overall_actual_y.unsqueeze(0) 33 | 34 | num_total_overall = overall_actual_y.numel() 35 | if higher_is_better: 36 | num_better_or_equal = torch.sum(overall_actual_y >= selected_actual_y).item() 37 | else: 38 | num_better_or_equal = torch.sum(overall_actual_y <= selected_actual_y).item() 39 | 40 | rank_position = num_better_or_equal 41 | percentile = (rank_position / num_total_overall) * 100 if num_total_overall > 0 else 0 42 | 43 | return { 44 | "rank_position_overall": rank_position, 45 | "percentile_overall": percentile, 46 | "total_samples_overall": num_total_overall 47 | } 48 | 49 | def get_performance_for_index( 50 | index: int, 51 | dataset: PerformancePredictionDataset, 52 | performance_cache: dict 53 | ) -> float: 54 | """ 55 | Retrieves the performance for a given graph index, using a cache if available. 56 | 57 | Args: 58 | index (int): Graph index to retrieve performance for 59 | dataset (PerformancePredictionDataset): Dataset containing performance data 60 | performance_cache (dict): Cache for storing retrieved performance values 61 | 62 | Returns: 63 | Optional[float]: Performance value for the given index, or None if not found 64 | or if performance data is invalid (NaN/None). 65 | """ 66 | if index in performance_cache: 67 | return performance_cache[index] 68 | 69 | try: 70 | row = dataset.df_result_group[dataset.df_result_group['idx'] == index] 71 | if len(row) == 0: 72 | print(f"Warning: No performance data found for index {index}") 73 | return None 74 | 75 | performance = row.iloc[0][dataset.target_col] 76 | 77 | if performance is None or (isinstance(performance, float) and np.isnan(performance)): 78 | print(f"Warning: NaN or None performance found for index {index}") 79 | return None 80 | 81 | return performance 82 | 83 | except Exception as e: 84 | print(f"Error retrieving performance for index {index}: {str(e)}") 85 | return None 86 | 87 | def update_trajectory_and_best( 88 | index: int, 89 | perf: Optional[float], 90 | performance_cache: dict, 91 | initial_cache_size: int, 92 | total_evaluated_count: int, 93 | performance_trajectory: list, 94 | global_best_perf: float, 95 | global_best_index: int, 96 | higher_is_better: bool, 97 | ) -> Tuple[int, float, int]: 98 | """ 99 | Updates trajectory and global best after a performance evaluation. 100 | 101 | Args: 102 | index (int): Index of the evaluated architecture 103 | perf (Optional[float]): Performance value obtained 104 | performance_cache (dict): Cache storing performance values 105 | initial_cache_size (int): Initial size of performance cache 106 | total_evaluated_count (int): Current count of evaluated architectures 107 | performance_trajectory (list): List tracking performance over time 108 | global_best_perf (float): Current best performance found 109 | global_best_index (int): Index of current best architecture 110 | higher_is_better (bool): Whether higher values indicate better performance 111 | 112 | Returns: 113 | Tuple[int, float, int]: Updated values: 114 | 115 | new_total_evaluated_count (int): Updated evaluation count 116 | new_global_best_perf (float): Updated best performance 117 | new_global_best_index (int): Updated best architecture index 118 | """ 119 | new_total_evaluated_count = total_evaluated_count 120 | new_global_best_perf = global_best_perf 121 | new_global_best_index = global_best_index 122 | 123 | if perf is not None and np.isfinite(perf): 124 | if len(performance_cache) > initial_cache_size: 125 | new_total_evaluated_count += 1 126 | 127 | is_better = False 128 | if new_global_best_index == -1: 129 | is_better = True 130 | elif higher_is_better and perf > new_global_best_perf: 131 | is_better = True 132 | elif not higher_is_better and perf < new_global_best_perf: 133 | is_better = True 134 | 135 | if is_better: 136 | new_global_best_perf = perf 137 | new_global_best_index = index 138 | 139 | performance_trajectory.append((new_total_evaluated_count, new_global_best_perf)) 140 | 141 | else: 142 | is_better_than_current_global = False 143 | if new_global_best_index == -1: 144 | is_better_than_current_global = True 145 | elif higher_is_better and perf > new_global_best_perf: 146 | is_better_than_current_global = True 147 | elif not higher_is_better and perf < new_global_best_perf: 148 | is_better_than_current_global = True 149 | 150 | if is_better_than_current_global: 151 | new_global_best_perf = perf 152 | new_global_best_index = index 153 | 154 | return new_total_evaluated_count, new_global_best_perf, new_global_best_index 155 | 156 | def pad_trajectory( 157 | performance_trajectory: list, 158 | total_evaluated_count: int, 159 | evaluation_budget: int, 160 | method_name: str = "Search" 161 | ) -> None: 162 | """Pads the performance trajectory if the search ended before using the full budget.""" 163 | if total_evaluated_count < evaluation_budget: 164 | if performance_trajectory: 165 | final_best_perf_to_pad = performance_trajectory[-1][1] 166 | if np.isfinite(final_best_perf_to_pad): 167 | print(f"Padding {method_name} trajectory from {total_evaluated_count + 1} to {evaluation_budget} with final best perf: {final_best_perf_to_pad:.4f}") 168 | for i in range(total_evaluated_count + 1, evaluation_budget + 1): 169 | performance_trajectory.append((i, final_best_perf_to_pad)) 170 | else: 171 | print(f"Warning: {method_name} Search ended early ({total_evaluated_count}/{evaluation_budget}) but last best performance is not finite ({final_best_perf_to_pad}). Cannot pad trajectory.") 172 | else: 173 | print(f"Warning: {method_name} Search ended early ({total_evaluated_count}/{evaluation_budget}) but performance trajectory is empty. Cannot pad trajectory.") 174 | 175 | def calculate_evaluation_time( 176 | index: int, 177 | dataset: PerformancePredictionDataset, 178 | time_cache: dict 179 | ) -> float: 180 | """ 181 | Calculates the evaluation time: eval_time = train_time * 20 + valid_time + test_time. 182 | 183 | Args: 184 | index (int): Graph index to calculate evaluation time for 185 | dataset (PerformancePredictionDataset): Dataset containing timing data 186 | time_cache (dict): Cache for storing calculated evaluation times 187 | 188 | Returns: 189 | Optional[float]: Total evaluation time in seconds, or None if timing data 190 | is not available or invalid. 191 | """ 192 | epochs = 20 193 | 194 | if index in time_cache: 195 | return time_cache[index] 196 | 197 | try: 198 | row = dataset.df_result_group[dataset.df_result_group['idx'] == index] 199 | if len(row) == 0: 200 | print(f"Warning: No time data found for index {index}") 201 | return None 202 | 203 | train_time = row.iloc[0]['train_time'] 204 | 205 | if train_time is None or (isinstance(train_time, float) and np.isnan(train_time)): 206 | print(f"Warning: NaN or None train_time found for index {index}") 207 | return None 208 | 209 | valid_time = row.iloc[0]['valid_time'] 210 | test_time = row.iloc[0]['test_time'] 211 | 212 | if (valid_time is None or (isinstance(valid_time, float) and np.isnan(valid_time)) or 213 | test_time is None or (isinstance(test_time, float) and np.isnan(test_time))): 214 | print(f"Warning: NaN or None valid_time or test_time found for index {index}") 215 | eval_time = train_time * epochs + train_time + train_time 216 | else: 217 | eval_time = (train_time + valid_time) * epochs + test_time 218 | 219 | time_cache[index] = eval_time 220 | 221 | return eval_time 222 | 223 | except Exception as e: 224 | print(f"Error calculating evaluation time for index {index}: {str(e)}") 225 | return None -------------------------------------------------------------------------------- /rdb2g_bench/common/search_space/gnn_search_space.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | 3 | import copy 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from torch_geometric.data import HeteroData 8 | 9 | class GNNSearchSpace(): 10 | def __init__( 11 | self, 12 | dataset: str, 13 | task: str, 14 | num_layers: int, 15 | node_types: List[str], 16 | edge_types: List[Tuple[str, str, str]], 17 | src_entity_table: str, 18 | dst_entity_table: str = None, 19 | ): 20 | self.dataset = dataset 21 | self.task = task 22 | self.src_entity_table = src_entity_table 23 | self.dst_entity_table = dst_entity_table 24 | self.num_layers = num_layers 25 | self.node_types = node_types 26 | self.edge_types = edge_types 27 | self.reachable_nodes = set() 28 | 29 | def filter_unreachable(self, edges): 30 | reachable_nodes = self.reachable_nodes.copy() 31 | layer_reachable = [reachable_nodes.copy()] 32 | 33 | for _ in range(self.num_layers): 34 | new_reachable = set() 35 | for i, is_selected in enumerate(edges): 36 | if is_selected == 1: 37 | src, _, dst = self.edge_types[i] 38 | if src in reachable_nodes: 39 | new_reachable.add(dst) 40 | if dst in reachable_nodes: 41 | new_reachable.add(src) 42 | 43 | reachable_nodes.update(new_reachable) 44 | layer_reachable.append(new_reachable) 45 | 46 | if not new_reachable: 47 | break 48 | 49 | filtered_edges = np.copy(edges) 50 | for i, is_selected in enumerate(edges): 51 | if is_selected == 1: 52 | src, _, dst = self.edge_types[i] 53 | # Check connection between consecutive layers 54 | for layer in range(len(layer_reachable)-1): 55 | if (src in layer_reachable[layer] and dst in layer_reachable[layer+1]) or \ 56 | (dst in layer_reachable[layer] and src in layer_reachable[layer+1]): 57 | break 58 | else: 59 | filtered_edges[i] = 0 60 | 61 | return filtered_edges 62 | 63 | def is_valid(self, edges) -> bool: 64 | raise NotImplementedError 65 | 66 | def is_possible(self, edges) -> bool: 67 | for i, is_selected in enumerate(edges): 68 | if is_selected == 1: 69 | _, rel, _ = self.edge_types[i] 70 | if rel.startswith('r2e'): 71 | rel = rel[4:] 72 | for j, edge in enumerate(self.edge_types): 73 | if edge[0] == rel or edge[2] == rel: 74 | if edges[j] == 1: 75 | return False 76 | return True 77 | 78 | def generate_all_graphs(self): 79 | num_edge_types = len(self.edge_types) 80 | possible_graphs = set() 81 | for mask in range(2 ** num_edge_types): 82 | # Create binary array of selected edges (1 = included, 0 = excluded) 83 | selected_edges = np.zeros(num_edge_types, dtype=int) 84 | for i in range(num_edge_types): 85 | if mask & (1 << i): 86 | selected_edges[i] = 1 87 | if self.is_possible(selected_edges): 88 | filtered_edges = self.filter_unreachable(selected_edges) 89 | key = tuple(filtered_edges) 90 | if key not in possible_graphs and self.is_valid(filtered_edges): 91 | possible_graphs.add(key) 92 | 93 | return sorted(list(possible_graphs)) 94 | 95 | def get_data(self, edges: np.ndarray, data: HeteroData) -> HeteroData: 96 | new_data = HeteroData() 97 | 98 | involved_nodes = set() 99 | for i, is_selected in enumerate(edges): 100 | if is_selected: 101 | src, rel, dst = self.edge_types[i] 102 | involved_nodes.add(src) 103 | involved_nodes.add(dst) 104 | if rel.startswith("r2e"): 105 | involved_nodes.add(rel[4:]) 106 | 107 | for node in involved_nodes: 108 | if node in data.node_types: 109 | for attr_name, attr_value in data[node].items(): 110 | new_data[node][attr_name] = attr_value 111 | 112 | for i, is_selected in enumerate(edges): 113 | if is_selected: 114 | edge_type = self.edge_types[i] 115 | if edge_type in data.edge_types: 116 | for attr_name, attr_value in data[edge_type].items(): 117 | new_data[edge_type][attr_name] = attr_value 118 | rev_edge_type = (edge_type[2], 'rev_' + edge_type[1], edge_type[0]) 119 | assert rev_edge_type in data.edge_types 120 | for attr_name, attr_value in data[rev_edge_type].items(): 121 | new_data[rev_edge_type][attr_name] = attr_value 122 | 123 | return new_data 124 | 125 | def __repr__(self) -> str: 126 | return (f"BasicSearchSpace(dataset={self.dataset}, " 127 | f"task={self.task}, " 128 | f"src_entity_table={self.src_entity_table}, " 129 | f"matrix_shape={self.matrix}, " 130 | f"node_types={self.node_types})") 131 | 132 | 133 | class GNNNodeSearchSpace(GNNSearchSpace): 134 | def __init__( 135 | self, 136 | dataset: str, 137 | task: str, 138 | num_layers: int, 139 | node_types: List[str], 140 | edge_types: List[Tuple[str, str, str]], 141 | src_entity_table: str, 142 | dst_entity_table: str = None, 143 | ): 144 | super().__init__(dataset, task, num_layers, node_types, edge_types, src_entity_table, dst_entity_table) 145 | self.reachable_nodes = {self.src_entity_table} 146 | if dst_entity_table is not None: 147 | raise ValueError("dst_entity_table must be None for GNNNodeSearchSpace") 148 | 149 | 150 | def filter_unreachable(self, edges: torch.Tensor) -> torch.Tensor: 151 | return super().filter_unreachable(edges) 152 | 153 | def is_valid(self, edges: torch.Tensor) -> bool: 154 | # if target entity has no incoming/outcoming edges, the graph is invalid 155 | src_has_edges = False 156 | 157 | for i, is_selected in enumerate(edges): 158 | if is_selected == 1: 159 | src, _, dst = self.edge_types[i] 160 | if src == self.src_entity_table or dst == self.src_entity_table: 161 | src_has_edges = True 162 | break 163 | 164 | if not src_has_edges: 165 | return False 166 | 167 | return True 168 | 169 | def __repr__(self) -> str: 170 | return (f"GNNNodeSearchSpace(dataset={self.dataset}, " 171 | f"task={self.task}, " 172 | f"src_entity_table={self.src_entity_table}, " 173 | f"node_types={self.node_types})") 174 | 175 | 176 | class GNNLinkSearchSpace(GNNSearchSpace): 177 | def __init__( 178 | self, 179 | dataset: str, 180 | task: str, 181 | num_layers: int, 182 | node_types: List[str], 183 | edge_types: List[Tuple[str, str, str]], 184 | src_entity_table: str, 185 | dst_entity_table: str, 186 | ): 187 | super().__init__(dataset, task, num_layers, node_types, edge_types, src_entity_table, dst_entity_table) 188 | self.reachable_nodes = {self.src_entity_table, self.dst_entity_table} 189 | 190 | def filter_unreachable(self, edges: np.ndarray) -> np.ndarray: 191 | return super().filter_unreachable(edges) 192 | 193 | def is_valid(self, edges: np.ndarray) -> bool: 194 | # Check if source and destination entity tables have any edges 195 | src_has_edges = False 196 | dst_has_edges = False 197 | 198 | for i, is_selected in enumerate(edges): 199 | if is_selected == 1: 200 | src, _, dst = self.edge_types[i] 201 | if src == self.src_entity_table or dst == self.src_entity_table: 202 | src_has_edges = True 203 | if src == self.dst_entity_table or dst == self.dst_entity_table: 204 | dst_has_edges = True 205 | if src_has_edges and dst_has_edges: 206 | break 207 | 208 | # If either source or destination entity has no edges, the graph is invalid 209 | if not src_has_edges or not dst_has_edges: 210 | return False 211 | 212 | return True 213 | 214 | def __repr__(self) -> str: 215 | return (f"GNNLinkSearchSpace(dataset={self.dataset}, " 216 | f"task={self.task}, " 217 | f"src_entity_table={self.src_entity_table}, " 218 | f"dst_entity_table={self.dst_entity_table}, " 219 | f"node_types={self.node_types})") 220 | 221 | 222 | class IDGNNLinkSearchSpace(GNNSearchSpace): 223 | def __init__( 224 | self, 225 | dataset: str, 226 | task: str, 227 | num_layers: int, 228 | node_types: List[str], 229 | edge_types: List[Tuple[str, str, str]], 230 | src_entity_table: str, 231 | dst_entity_table: str, 232 | ): 233 | super().__init__(dataset, task, num_layers, node_types, edge_types, src_entity_table, dst_entity_table) 234 | self.reachable_nodes = {self.src_entity_table} 235 | 236 | def filter_unreachable(self, edges: np.ndarray) -> np.ndarray: 237 | return super().filter_unreachable(edges) 238 | 239 | def is_valid(self, edges: np.ndarray) -> bool: 240 | # Check if source and destination entity tables have any edges 241 | src_has_edges = False 242 | dst_has_edges = False 243 | 244 | for i, is_selected in enumerate(edges): 245 | if is_selected == 1: 246 | src, _, dst = self.edge_types[i] 247 | if src == self.src_entity_table or dst == self.src_entity_table: 248 | src_has_edges = True 249 | if src == self.dst_entity_table or dst == self.dst_entity_table: 250 | dst_has_edges = True 251 | if src_has_edges and dst_has_edges: 252 | break 253 | 254 | # If source or destination entity has no edges, the graph is invalid 255 | if not src_has_edges or not dst_has_edges: 256 | return False 257 | 258 | reachable_from_src = {self.src_entity_table} 259 | 260 | for _ in range(len(self.node_types)): 261 | if self.dst_entity_table in reachable_from_src: 262 | break 263 | 264 | new_reachable = set() 265 | for i, is_selected in enumerate(edges): 266 | if is_selected == 1: 267 | src, _, dst = self.edge_types[i] 268 | if src in reachable_from_src: 269 | new_reachable.add(dst) 270 | if dst in reachable_from_src: 271 | new_reachable.add(src) 272 | 273 | reachable_from_src.update(new_reachable) 274 | 275 | if not new_reachable: 276 | break 277 | 278 | # If dst_entity_table is not reachable from src_entity_table, the graph is invalid 279 | if self.dst_entity_table not in reachable_from_src: 280 | return False 281 | 282 | return True 283 | 284 | def __repr__(self) -> str: 285 | return (f"IDGNNLinkSearchSpace(dataset={self.dataset}, " 286 | f"task={self.task}, " 287 | f"src_entity_table={self.src_entity_table}, " 288 | f"dst_entity_table={self.dst_entity_table}, " 289 | f"node_types={self.node_types})") 290 | -------------------------------------------------------------------------------- /rdb2g_bench/benchmark/llm/llm_micro_action.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, Type, Union, List 2 | from torch_geometric.data import HeteroData 3 | 4 | from ...common.search_space.search_space import TotalSearchSpace 5 | from ...common.search_space.gnn_search_space import GNNNodeSearchSpace, GNNLinkSearchSpace, IDGNNLinkSearchSpace 6 | from ..micro_action import MicroActionSet 7 | 8 | class LLMMicroActionSet(MicroActionSet): 9 | def __init__(self, 10 | dataset: str, 11 | task: str, 12 | hetero_data: HeteroData, 13 | GNNSpaceClass: Type[Union[GNNNodeSearchSpace, GNNLinkSearchSpace, IDGNNLinkSearchSpace]], 14 | num_layers: int, 15 | src_entity_table: str, 16 | dst_entity_table: Optional[str] = None): 17 | super().__init__(dataset, task, hetero_data, GNNSpaceClass, num_layers, src_entity_table, dst_entity_table) 18 | 19 | def add_fk_pk_edge(self, 20 | current_edge_set: Tuple[int, ...], 21 | from_table_name: str, 22 | from_col_name: str, 23 | to_table_name: str, 24 | ) -> Tuple[Tuple[int, ...], int, str]: 25 | graph_idx, error_msg = -1, "" 26 | edge_to_add = (from_table_name, f"f2p_{from_col_name}", to_table_name) 27 | if edge_to_add in self.full_edges: 28 | edge_index = self.full_edges.index(edge_to_add) 29 | if current_edge_set[edge_index] == 0: 30 | new_edge_set_list = list(current_edge_set) 31 | new_edge_set_list[edge_index] = 1 32 | new_edge_set = tuple(new_edge_set_list) 33 | if new_edge_set in self.valid_edge_sets: 34 | graph_idx = self.valid_edge_sets_list.index(new_edge_set) 35 | error_msg = "" 36 | else: 37 | new_edge_set = current_edge_set 38 | error_msg = f"Given add_fk_pk_edge action is not valid." 39 | else: 40 | new_edge_set = current_edge_set 41 | error_msg = f"Given edge type({edge_to_add}) between {from_table_name} and {to_table_name} is already connected." 42 | else: 43 | new_edge_set = current_edge_set 44 | error_msg = f"Given edge type({edge_to_add}) between {from_table_name} and {to_table_name} is an invalid edge type." 45 | 46 | return new_edge_set, graph_idx, error_msg 47 | 48 | def remove_fk_pk_edge(self, 49 | current_edge_set: Tuple[int, ...], 50 | from_table_name: str, 51 | from_col_name: str, 52 | to_table_name: str, 53 | ) -> Tuple[Tuple[int, ...], int, str]: 54 | graph_idx, error_msg = -1, "" 55 | edge_to_remove = (from_table_name, f"f2p_{from_col_name}", to_table_name) 56 | if edge_to_remove in self.full_edges: 57 | edge_index = self.full_edges.index(edge_to_remove) 58 | if current_edge_set[edge_index] == 1: 59 | new_edge_set_list = list(current_edge_set) 60 | new_edge_set_list[edge_index] = 0 61 | new_edge_set = tuple(new_edge_set_list) 62 | if new_edge_set in self.valid_edge_sets: 63 | graph_idx = self.valid_edge_sets_list.index(new_edge_set) 64 | error_msg = "" 65 | else: 66 | new_edge_set = current_edge_set 67 | error_msg = f"Given remove_fk_pk_edge action is not valid." 68 | else: 69 | new_edge_set = current_edge_set 70 | error_msg = f"Given edge type({edge_to_remove}) between {from_table_name} and {to_table_name} is not connected." 71 | else: 72 | new_edge_set = current_edge_set 73 | error_msg = f"Given edge type({edge_to_remove}) between {from_table_name} and {to_table_name} is an invalid edge type." 74 | 75 | return new_edge_set, graph_idx, error_msg 76 | 77 | def convert_row_to_edge(self, 78 | current_edge_set: Tuple[int, ...], 79 | table_1_name: str, 80 | table_2_name: str, 81 | edge_table_name: str 82 | ) -> Tuple[Tuple[int, ...], int, str]: 83 | graph_idx, error_msg = -1, "" 84 | if (table_1_name, f"r2e_{edge_table_name}", table_2_name) in self.full_edges: 85 | convert_edge = (table_1_name, f"r2e_{edge_table_name}", table_2_name) 86 | convert_edge_index = self.full_edges.index(convert_edge) 87 | elif (table_2_name, f"r2e_{edge_table_name}", table_1_name) in self.full_edges: 88 | convert_edge = (table_2_name, f"r2e_{edge_table_name}", table_1_name) 89 | convert_edge_index = self.full_edges.index(convert_edge) 90 | else: 91 | return current_edge_set, -1, f"Given edge type({edge_table_name}) between {table_1_name} and {table_2_name} is an invalid edge type." 92 | 93 | if current_edge_set[convert_edge_index] == 0: 94 | f2p_indices = self.r2e_to_f2p_map.get(convert_edge_index) 95 | if f2p_indices and f2p_indices[0] is not None and f2p_indices[1] is not None: 96 | f2p_idx1, f2p_idx2 = f2p_indices 97 | if current_edge_set[f2p_idx1] == 1 and current_edge_set[f2p_idx2] == 1: 98 | new_edge_set_list = list(current_edge_set) 99 | new_edge_set_list[convert_edge_index] = 1 100 | new_edge_set_list[f2p_idx1] = 0 101 | new_edge_set_list[f2p_idx2] = 0 102 | new_edge_set = tuple(new_edge_set_list) 103 | if new_edge_set in self.valid_edge_sets: 104 | graph_idx = self.valid_edge_sets_list.index(new_edge_set) 105 | error_msg = "" 106 | else: 107 | new_edge_set = current_edge_set 108 | error_msg = f"Given convert_edge_to_row action is not valid." 109 | else: 110 | new_edge_set = current_edge_set 111 | error_msg = f"Given convert_edge_to_row action is not valid." 112 | else: 113 | new_edge_set = current_edge_set 114 | error_msg = f"Given convert_edge_to_row action is not valid." 115 | else: 116 | new_edge_set = current_edge_set 117 | error_msg = f"Given edge type({edge_table_name}) between {table_1_name} and {table_2_name} is already converted to edge." 118 | 119 | return new_edge_set, graph_idx, error_msg 120 | 121 | def convert_edge_to_row(self, 122 | current_edge_set: Tuple[int, ...], 123 | table_1_name: str, 124 | table_2_name: str, 125 | edge_table_name: str 126 | ) -> Tuple[Tuple[int, ...], int, str]: 127 | graph_idx, error_msg = -1, "" 128 | if (table_1_name, f"r2e_{edge_table_name}", table_2_name) in self.full_edges: 129 | convert_edge = (table_1_name, f"r2e_{edge_table_name}", table_2_name) 130 | convert_edge_index = self.full_edges.index(convert_edge) 131 | elif (table_2_name, f"r2e_{edge_table_name}", table_1_name) in self.full_edges: 132 | convert_edge = (table_2_name, f"r2e_{edge_table_name}", table_1_name) 133 | convert_edge_index = self.full_edges.index(convert_edge) 134 | else: 135 | return current_edge_set, -1, f"Given edge type({edge_table_name}) between {table_1_name} and {table_2_name} is an invalid edge type." 136 | 137 | if current_edge_set[convert_edge_index] == 1: 138 | f2p_indices = self.r2e_to_f2p_map.get(convert_edge_index) 139 | if f2p_indices and f2p_indices[0] is not None and f2p_indices[1] is not None: 140 | f2p_idx1, f2p_idx2 = f2p_indices 141 | if current_edge_set[f2p_idx1] == 0 and current_edge_set[f2p_idx2] == 0: 142 | new_edge_set_list = list(current_edge_set) 143 | new_edge_set_list[convert_edge_index] = 0 144 | new_edge_set_list[f2p_idx1] = 1 145 | new_edge_set_list[f2p_idx2] = 1 146 | new_edge_set = tuple(new_edge_set_list) 147 | if new_edge_set in self.valid_edge_sets: 148 | graph_idx = self.valid_edge_sets_list.index(new_edge_set) 149 | error_msg = "" 150 | else: 151 | new_edge_set = current_edge_set 152 | error_msg = f"Given convert_edge_to_row action is not valid." 153 | else: 154 | new_edge_set = current_edge_set 155 | error_msg = f"Given convert_edge_to_row action is not valid." 156 | else: 157 | new_edge_set = current_edge_set 158 | error_msg = f"Given convert_edge_to_row action is not valid." 159 | else: 160 | new_edge_set = current_edge_set 161 | error_msg = f"Given edge type({edge_table_name}) between {table_1_name} and {table_2_name} is not converted to edge." 162 | 163 | return new_edge_set, graph_idx, error_msg 164 | 165 | def get_possible_add_fk_pk_edge(self, 166 | current_edge_set: Tuple[int, ...] 167 | ) -> List[Tuple[Tuple[int, ...], int]]: 168 | possible_next_sets_with_indices = [] 169 | for edge_index_to_add in self.fk_pk_indices: 170 | if current_edge_set[edge_index_to_add] == 0: 171 | new_edge_set_list = list(current_edge_set) 172 | new_edge_set_list[edge_index_to_add] = 1 173 | # new_edge_set_list = self.search_space.gnn_search_space.filter_unreachable(new_edge_set_list) 174 | new_edge_set = tuple(new_edge_set_list) 175 | if new_edge_set in self.valid_edge_sets: 176 | assert new_edge_set in self.valid_edge_sets_list, f"Edge set {new_edge_set} in set but not in list!" 177 | # index = self.valid_edge_sets_list.index(new_edge_set) 178 | possible_next_sets_with_indices.append(new_edge_set) 179 | return possible_next_sets_with_indices 180 | 181 | def get_possible_remove_fk_pk_edge(self, 182 | current_edge_set: Tuple[int, ...] 183 | ) -> List[Tuple[Tuple[int, ...], int]]: 184 | possible_next_sets_with_indices = [] 185 | for edge_index_to_remove in self.fk_pk_indices: 186 | if current_edge_set[edge_index_to_remove] == 1: 187 | new_edge_set_list = list(current_edge_set) 188 | new_edge_set_list[edge_index_to_remove] = 0 189 | new_edge_set = tuple(new_edge_set_list) 190 | if new_edge_set in self.valid_edge_sets: 191 | assert new_edge_set in self.valid_edge_sets_list, f"Edge set {new_edge_set} in set but not in list!" 192 | possible_next_sets_with_indices.append(new_edge_set) 193 | return possible_next_sets_with_indices 194 | 195 | def get_possible_convert_row_to_edge(self, 196 | current_edge_set: Tuple[int, ...] 197 | ) -> List[Tuple[Tuple[int, ...], int]]: 198 | possible_next_sets_with_indices = [] 199 | for f2p_pair, r2e_idx in self.f2p_pair_to_r2e_map.items(): 200 | f2p_idx1, f2p_idx2 = f2p_pair 201 | if current_edge_set[r2e_idx] == 0 and current_edge_set[f2p_idx1] == 1 and current_edge_set[f2p_idx2] == 1: 202 | new_edge_set_list = list(current_edge_set) 203 | new_edge_set_list[r2e_idx] = 1 204 | new_edge_set_list[f2p_idx1] = 0 205 | new_edge_set_list[f2p_idx2] = 0 206 | new_edge_set = tuple(new_edge_set_list) 207 | if new_edge_set in self.valid_edge_sets: 208 | possible_next_sets_with_indices.append(new_edge_set) 209 | return possible_next_sets_with_indices 210 | 211 | def get_possible_convert_edge_to_row(self, 212 | current_edge_set: Tuple[int, ...] 213 | ) -> List[Tuple[Tuple[int, ...], int]]: 214 | possible_next_sets_with_indices = [] 215 | for f2p_pair, r2e_idx in self.f2p_pair_to_r2e_map.items(): 216 | f2p_idx1, f2p_idx2 = f2p_pair 217 | if current_edge_set[f2p_idx1] == 0 and current_edge_set[f2p_idx2] == 0 and current_edge_set[r2e_idx] == 1: 218 | new_edge_set_list = list(current_edge_set) 219 | new_edge_set_list[f2p_idx1] = 1 220 | new_edge_set_list[f2p_idx2] = 1 221 | new_edge_set_list[r2e_idx] = 0 222 | new_edge_set = tuple(new_edge_set_list) 223 | if new_edge_set in self.valid_edge_sets: 224 | possible_next_sets_with_indices.append(new_edge_set) 225 | return possible_next_sets_with_indices -------------------------------------------------------------------------------- /rdb2g_bench/benchmark/baselines/ea.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import time 5 | from typing import Dict, Union, List, Optional 6 | 7 | from ..dataset import PerformancePredictionDataset 8 | from ..micro_action import MicroActionSet 9 | from .utils import calculate_overall_rank, get_performance_for_index, update_trajectory_and_best, pad_trajectory, calculate_evaluation_time 10 | 11 | def evolutionary_heuristic_analysis( 12 | dataset: PerformancePredictionDataset, 13 | micro_action_set: MicroActionSet, 14 | overall_actual_y: torch.Tensor, 15 | higher_is_better: bool, 16 | termination_threshold_ratio: float, 17 | method_name: str = "Evolutionary Heuristic", 18 | population_size: int = 10, 19 | tournament_size: int = 10, 20 | max_iterations: int = 1000, 21 | ): 22 | """ 23 | Perform Neural Architecture Search using Evolutionary Algorithm. 24 | 25 | This function implements a complete evolutionary algorithm for finding optimal 26 | graph neural network architectures. It maintains a population of architectures, 27 | applies micro action-based mutations, and uses tournament selection for evolution. 28 | 29 | Args: 30 | dataset (PerformancePredictionDataset): Dataset containing architecture 31 | performance data 32 | micro_action_set (MicroActionSet): Set of micro actions for architecture 33 | space exploration 34 | overall_actual_y (torch.Tensor): Complete performance tensor for 35 | ranking calculations 36 | higher_is_better (bool): Whether higher performance values are better 37 | termination_threshold_ratio (float): Fraction of total architectures to 38 | evaluate as budget 39 | method_name (str): Name identifier for this method. 40 | Defaults to "Evolutionary Heuristic". 41 | population_size (int): Number of individuals in the population. 42 | Defaults to 10. 43 | tournament_size (int): Number of individuals selected for tournament. 44 | Defaults to 10. 45 | max_iterations (int): Maximum number of evolutionary generations. 46 | Defaults to 1000. 47 | 48 | Returns: 49 | Dict[str, Union[str, int, float, List, Optional[int]]]: Dictionary containing search results and performance metrics. 50 | 51 | - method (str): Method name 52 | - selected_graph_id (Optional[int]): Index of best found architecture 53 | - actual_y_perf_of_selected (float): Performance of selected architecture 54 | - selection_metric_value (float): Metric value used for selection 55 | - selected_graph_origin (str): Origin method name 56 | - discovered_count (int): Number of architectures evaluated 57 | - total_iterations_run (int): Number of generations completed 58 | - rank_position_overall (float): Rank among all architectures 59 | - percentile_overall (float): Percentile ranking 60 | - total_samples_overall (int): Total available architectures 61 | - performance_trajectory (List): Performance over time 62 | - total_evaluation_time (float): Time spent on evaluations 63 | - total_run_time (float): Total algorithm runtime 64 | 65 | Example: 66 | >>> results = evolutionary_heuristic_analysis( 67 | ... dataset=dataset, 68 | ... micro_action_set=micro_actions, 69 | ... overall_actual_y=y_tensor, 70 | ... higher_is_better=True, 71 | ... termination_threshold_ratio=0.05, 72 | ... population_size=20, 73 | ... tournament_size=5, 74 | ... max_iterations=100 75 | ... ) 76 | >>> print(f"Best architecture: {results['selected_graph_id']}") 77 | >>> print(f"Performance: {results['actual_y_perf_of_selected']:.4f}") 78 | """ 79 | population = [] 80 | seen_indices = set() 81 | performance_cache = {} 82 | time_cache = {} 83 | performance_trajectory = [] 84 | num_total_valid_graphs = len(micro_action_set.valid_edge_sets_list) 85 | discovered_count = 0 86 | best_perf_so_far = float('-inf') if higher_is_better else float('inf') 87 | best_index_so_far = -1 88 | final_iteration_count = 0 89 | total_evaluated_count = 0 90 | total_evaluation_time = 0.0 91 | 92 | start_time = time.time() 93 | 94 | if num_total_valid_graphs == 0: 95 | print("Error: No valid graphs found in MicroActionSet. Cannot initialize population.") 96 | initial_indices = [] 97 | else: 98 | pop_size_to_init = min(population_size, num_total_valid_graphs) 99 | if pop_size_to_init <= 0: 100 | print("Warning: Cannot initialize population with size <= 0.") 101 | initial_indices = [] 102 | elif pop_size_to_init > num_total_valid_graphs: 103 | print(f"Warning: Requested population size ({population_size}) exceeds number of valid graphs ({num_total_valid_graphs}). Using {num_total_valid_graphs}.") 104 | pop_size_to_init = num_total_valid_graphs 105 | initial_indices = np.random.choice( 106 | num_total_valid_graphs, 107 | size=pop_size_to_init, 108 | replace=False 109 | ) 110 | else: 111 | initial_indices = np.random.choice( 112 | num_total_valid_graphs, 113 | size=pop_size_to_init, 114 | replace=False 115 | ) 116 | 117 | print(f"Initializing population with {len(initial_indices)} individuals...") 118 | for index in initial_indices: 119 | population.append((index, 0)) 120 | if index not in seen_indices: 121 | seen_indices.add(index) 122 | discovered_count += 1 123 | 124 | initial_cache_size = len(performance_cache) 125 | perf = get_performance_for_index(index, dataset, performance_cache) 126 | 127 | if perf is not None: 128 | performance_cache[index] = perf 129 | 130 | eval_time = calculate_evaluation_time(index, dataset, time_cache) 131 | if eval_time is not None: 132 | if len(performance_cache) > initial_cache_size: 133 | total_evaluation_time += eval_time 134 | 135 | total_evaluated_count, best_perf_so_far, best_index_so_far = \ 136 | update_trajectory_and_best( 137 | index, perf, performance_cache, initial_cache_size, 138 | total_evaluated_count, performance_trajectory, 139 | best_perf_so_far, best_index_so_far, higher_is_better 140 | ) 141 | 142 | print(f"After initialization ({total_evaluated_count} evals): Best perf={best_perf_so_far:.4f}") 143 | 144 | termination_count_threshold = int(termination_threshold_ratio * num_total_valid_graphs) 145 | print(f"Evolution budget set to {termination_count_threshold} unique evaluations.") 146 | 147 | for iteration in range(max_iterations): 148 | final_iteration_count = iteration + 1 149 | if not population: 150 | break 151 | 152 | mutated_indices_this_step = [] 153 | 154 | current_pop_size = len(population) 155 | actual_tournament_size = min(tournament_size, current_pop_size) 156 | if actual_tournament_size <= 0: 157 | break 158 | 159 | tournament_candidate_indices = np.random.choice( 160 | current_pop_size, 161 | size=actual_tournament_size, 162 | replace=False 163 | ) 164 | 165 | for pop_list_idx in tournament_candidate_indices: 166 | current_index, _ = population[pop_list_idx] 167 | current_edge_set = micro_action_set.valid_edge_sets_list[current_index] 168 | 169 | action_fns = [ 170 | micro_action_set.add_fk_pk_edge, 171 | micro_action_set.remove_fk_pk_edge, 172 | micro_action_set.convert_row_to_edge, 173 | micro_action_set.convert_edge_to_row, 174 | ] 175 | random.shuffle(action_fns) 176 | 177 | mutated_index = current_index 178 | for chosen_action_fn in action_fns: 179 | possible_next_states = chosen_action_fn(current_edge_set) 180 | 181 | if possible_next_states: 182 | _, next_index = random.choice(possible_next_states) 183 | mutated_index = next_index 184 | 185 | if mutated_index not in seen_indices: 186 | seen_indices.add(mutated_index) 187 | discovered_count += 1 188 | break 189 | 190 | mutated_indices_this_step.append(mutated_index) 191 | 192 | tournament_performances = [] 193 | tournament_indices = mutated_indices_this_step 194 | valid_tournament_results = [] 195 | 196 | terminated_early = False 197 | for index in tournament_indices: 198 | initial_cache_size = len(performance_cache) 199 | perf = get_performance_for_index(index, dataset, performance_cache) 200 | 201 | if perf is not None: 202 | tournament_performances.append(perf) 203 | valid_tournament_results.append((perf, index)) 204 | 205 | performance_cache[index] = perf 206 | 207 | eval_time = calculate_evaluation_time(index, dataset, time_cache) 208 | if eval_time is not None: 209 | if len(performance_cache) > initial_cache_size: 210 | total_evaluation_time += eval_time 211 | 212 | new_total_evaluated_count, new_best_perf_so_far, new_best_index_so_far = \ 213 | update_trajectory_and_best( 214 | index, perf, performance_cache, initial_cache_size, 215 | total_evaluated_count, performance_trajectory, 216 | best_perf_so_far, best_index_so_far, higher_is_better 217 | ) 218 | 219 | if len(performance_cache) > initial_cache_size: 220 | total_evaluated_count = new_total_evaluated_count 221 | best_perf_so_far = new_best_perf_so_far 222 | best_index_so_far = new_best_index_so_far 223 | 224 | if total_evaluated_count >= termination_count_threshold: 225 | print(f"Termination condition met at iteration {iteration+1}: Evaluated {total_evaluated_count} unique graphs >= threshold {termination_count_threshold}") 226 | terminated_early = True 227 | break 228 | else: 229 | best_perf_so_far = new_best_perf_so_far 230 | best_index_so_far = new_best_index_so_far 231 | 232 | if terminated_early: 233 | break 234 | 235 | if not valid_tournament_results: 236 | print(f"Warning: Iteration {iteration+1} - No valid performances in tournament. Skipping replacement.") 237 | continue 238 | 239 | if higher_is_better: 240 | best_tournament_perf, best_tournament_index = max(valid_tournament_results, key=lambda item: item[0]) 241 | else: 242 | best_tournament_perf, best_tournament_index = min(valid_tournament_results, key=lambda item: item[0]) 243 | 244 | oldest_pop_list_idx = -1 245 | max_age = -1 246 | for i, (_, age) in enumerate(population): 247 | if age > max_age: 248 | max_age = age 249 | oldest_pop_list_idx = i 250 | 251 | if oldest_pop_list_idx != -1: 252 | population[oldest_pop_list_idx] = (best_tournament_index, 0) 253 | 254 | for i in range(len(population)): 255 | if i != oldest_pop_list_idx: 256 | idx, age = population[i] 257 | population[i] = (idx, age + 1) 258 | 259 | final_selected_index = best_index_so_far 260 | final_selected_perf = best_perf_so_far 261 | 262 | pad_trajectory(performance_trajectory, total_evaluated_count, termination_count_threshold, method_name) 263 | 264 | total_run_time = time.time() - start_time 265 | 266 | results = { 267 | "method": method_name, 268 | "selected_graph_id": final_selected_index, 269 | "actual_y_perf_of_selected": final_selected_perf, 270 | "selection_metric_value": final_selected_perf, 271 | "selected_graph_origin": "Evolutionary", 272 | "discovered_count": discovered_count, 273 | "total_iterations_run": final_iteration_count, 274 | "rank_position_overall": np.nan, 275 | "percentile_overall": np.nan, 276 | "total_samples_overall": overall_actual_y.numel() if overall_actual_y is not None else 0, 277 | "performance_trajectory": performance_trajectory, 278 | "total_evaluation_time": total_evaluation_time, 279 | "total_run_time": total_run_time 280 | } 281 | 282 | rank_info = calculate_overall_rank( 283 | final_selected_perf, 284 | overall_actual_y, 285 | higher_is_better 286 | ) 287 | if rank_info: 288 | results["rank_position_overall"] = rank_info["rank_position_overall"] 289 | results["percentile_overall"] = rank_info["percentile_overall"] 290 | 291 | return results 292 | -------------------------------------------------------------------------------- /rdb2g_bench/benchmark/llm/llm_utils.py: -------------------------------------------------------------------------------- 1 | from .llm_micro_action import LLMMicroActionSet 2 | 3 | import pandas as pd 4 | import json 5 | 6 | 7 | def get_budget(dataset, task, budget, gnn="GraphSAGE", tag="final", path='../results/tables'): 8 | csv_path = f"{path}/{dataset}/{task}/{tag}/{gnn}/0.csv" 9 | df = pd.read_csv(csv_path) 10 | max_idx = df['idx'].max() + 1 11 | budget = int(max_idx * budget) 12 | return budget 13 | 14 | 15 | def get_micro_action_result( 16 | action: dict, 17 | llm_micro_action_set: LLMMicroActionSet, 18 | current_edge_set: list 19 | ): 20 | 21 | if action['action'] == 'add_fk_pk_edge': 22 | new_edge_set, graph_idx, error_msg = llm_micro_action_set.add_fk_pk_edge( 23 | current_edge_set=current_edge_set, 24 | from_table_name=action['parameters']['from_table_name'], 25 | from_col_name=action['parameters']['from_col_name'], 26 | to_table_name=action['parameters']['to_table_name'] 27 | ) 28 | elif action['action'] == 'remove_fk_pk_edge': 29 | new_edge_set, graph_idx, error_msg = llm_micro_action_set.remove_fk_pk_edge( 30 | current_edge_set=current_edge_set, 31 | from_table_name=action['parameters']['from_table_name'], 32 | from_col_name=action['parameters']['from_col_name'], 33 | to_table_name=action['parameters']['to_table_name'] 34 | ) 35 | 36 | elif action['action'] == 'convert_row_to_edge': 37 | new_edge_set, graph_idx, error_msg = llm_micro_action_set.convert_row_to_edge( 38 | current_edge_set=current_edge_set, 39 | table_1_name=action['parameters']['table_1_name'], 40 | table_2_name=action['parameters']['table_2_name'], 41 | edge_table_name=action['parameters']['edge_table_name'] 42 | ) 43 | elif action['action'] == 'convert_edge_to_row': 44 | new_edge_set, graph_idx, error_msg = llm_micro_action_set.convert_edge_to_row( 45 | current_edge_set=current_edge_set, 46 | table_1_name=action['parameters']['table_1_name'], 47 | table_2_name=action['parameters']['table_2_name'], 48 | edge_table_name=action['parameters']['edge_table_name'] 49 | ) 50 | 51 | else: 52 | # raise ValueError(f"Action {action['action']} is unsupported.") 53 | return None, None, None 54 | 55 | return new_edge_set, graph_idx, error_msg 56 | 57 | def check_all_actions( 58 | actions: list [dict], 59 | llm_micro_action_set: LLMMicroActionSet, 60 | current_edge_set: tuple, 61 | perf_pred_dataset: dict 62 | ): 63 | new_edge_sets, graph_idxs, error_msgs, scores = [], [], [], [] 64 | for action in actions : 65 | action['parameters'] = action['parameters'][0] if type(action['parameters']) == list else action['parameters'] 66 | new_edge_set, graph_idx, error_msg = get_micro_action_result( 67 | action=action, 68 | llm_micro_action_set=llm_micro_action_set, 69 | current_edge_set=current_edge_set 70 | ) 71 | if new_edge_set is None: 72 | continue 73 | # update results 74 | new_edge_sets.append(new_edge_set) 75 | graph_idxs.append(graph_idx) 76 | error_msgs.append(error_msg) 77 | score = perf_pred_dataset.get(graph_idx).y.item() if graph_idx != -1 else 0 78 | scores.append(score) 79 | 80 | if all(graph_idx == -1 for graph_idx in graph_idxs): 81 | return actions[0], new_edge_sets[0], graph_idxs[0], error_msgs[0] 82 | else: 83 | idx = scores.index(max(scores)) 84 | return actions[idx], new_edge_sets[idx], graph_idxs[idx], error_msgs[idx] 85 | 86 | 87 | 88 | def get_available_edges(full_edges: list): 89 | r2e_edges = [] 90 | f2p_edges = [] 91 | for edge in full_edges: 92 | if edge[1].startswith("r2e"): 93 | r2e_edges.append(edge) 94 | elif edge[1].startswith("f2p"): 95 | f2p_edges.append(edge) 96 | 97 | print(f"R2E edges: {r2e_edges}") 98 | print(f"F2P edges: {f2p_edges}") 99 | return r2e_edges, f2p_edges 100 | 101 | def check_none_action(parsed_response_text): 102 | if not parsed_response_text: 103 | return True 104 | if len(parsed_response_text) == 0 : 105 | return True 106 | 107 | parsed_response_text = [parsed_response_text] if type(parsed_response_text) == dict else parsed_response_text 108 | 109 | if 'action' not in parsed_response_text[0]: 110 | return True 111 | elif not parsed_response_text[0]['action']: 112 | return True 113 | elif parsed_response_text[0]['action'].lower() == "none": 114 | return True 115 | return False 116 | 117 | 118 | def get_changed_edge(before_edge_set, after_edge_set, full_edges): 119 | 120 | changed_edge_idx = [i for i, (a, b) in enumerate(zip(before_edge_set, after_edge_set)) if a != b] 121 | 122 | return full_edges[changed_edge_idx[0]] 123 | 124 | def get_edge_info(full_edges, current_edge_set, llm_micro_action_set): 125 | 126 | add_f2p_edge = llm_micro_action_set.get_possible_add_fk_pk_edge(current_edge_set) 127 | add_f2p_edge = [get_changed_edge(current_edge_set, new_edge_set, full_edges) for new_edge_set in add_f2p_edge] if len(add_f2p_edge) > 0 else [] 128 | 129 | remove_f2p_edge = llm_micro_action_set.get_possible_remove_fk_pk_edge(current_edge_set) 130 | remove_f2p_edge = [get_changed_edge(current_edge_set, new_edge_set, full_edges) for new_edge_set in remove_f2p_edge] if len(remove_f2p_edge) > 0 else [] 131 | 132 | convert_row2edge = llm_micro_action_set.get_possible_convert_row_to_edge(current_edge_set) 133 | convert_row2edge = [get_changed_edge(current_edge_set, new_edge_set, full_edges) for new_edge_set in convert_row2edge] if len(convert_row2edge) > 0 else [] 134 | 135 | convert_edge2row = llm_micro_action_set.get_possible_convert_edge_to_row(current_edge_set) 136 | convert_edge2row = [get_changed_edge(current_edge_set, new_edge_set, full_edges) for new_edge_set in convert_edge2row] if len(convert_edge2row) > 0 else [] 137 | 138 | add_row2edge_edge = llm_micro_action_set.get_possible_add_row2edge_edge(current_edge_set) 139 | add_row2edge_edge = [get_changed_edge(current_edge_set, new_edge_set, full_edges) for new_edge_set in add_row2edge_edge] if len(add_row2edge_edge) > 0 else [] 140 | 141 | remove_row2edge_edge = llm_micro_action_set.get_possible_remove_row2edge_edge(current_edge_set) 142 | remove_row2edge_edge = [get_changed_edge(current_edge_set, new_edge_set, full_edges) for new_edge_set in remove_row2edge_edge] if len(remove_row2edge_edge) > 0 else [] 143 | 144 | 145 | add_f2p_edge_info = "Note: ONLY the following set of fk_pk_edge can be added:" if len(add_f2p_edge) > 0 else "Note: There are **NO** fk_pk_edge that can be added in current schema." 146 | remove_f2p_edge_info = "Note: ONLY the following set of fk_pk_edge can be removed:" if len(remove_f2p_edge) > 0 else "Note: There are **NO** fk_pk_edge that can be removed in current schema." 147 | convert_row2edge_info = "Note: ONLY the following set of edges can be converted from row to edge:" if len(convert_row2edge) > 0 else "Note: There are **NO** edges that can be converted from row to edge in current schema." 148 | convert_edge2row_info = "Note: ONLY the following set of edges can be converted from edge to row:" if len(convert_edge2row) > 0 else "Note: There are **NO** edges that can be converted from edge to row in current schema." 149 | add_row2edge_edge_info = "Note: ONLY the following set of edges can be added:" if len(add_row2edge_edge) > 0 else "Note: There are **NO** row2edge edges that can be added in current schema." 150 | remove_row2edge_edge_info = "Note: ONLY the following set of edges can be removed:" if len(remove_row2edge_edge) > 0 else "Note: There are **NO** row2edge edges that can be removed in current schema." 151 | 152 | for edge in add_f2p_edge: 153 | edge_info = {"from_table_name": edge[0], "from_col_name": edge[1].replace('f2p_',''), "to_table_name": edge[2]} 154 | add_f2p_edge_info += f'\n{json.dumps(edge_info)}' 155 | for edge in remove_f2p_edge: 156 | edge_info = {"from_table_name": edge[0], "from_col_name": edge[1].replace('f2p_',''), "to_table_name": edge[2]} 157 | remove_f2p_edge_info += f'\n{json.dumps(edge_info)}' 158 | for edge in convert_row2edge: 159 | edge_info = {"table_1_name": edge[0], "table_2_name": edge[2], "edge_table_name": edge[1].replace('r2e_','')} 160 | convert_row2edge_info += f'\n{json.dumps(edge_info)}' 161 | for edge in convert_edge2row: 162 | edge_info = {"table_1_name": edge[0], "table_2_name": edge[2], "edge_table_name": edge[1].replace('r2e_','')} 163 | convert_edge2row_info += f'\n{json.dumps(edge_info)}' 164 | for edge in add_row2edge_edge: 165 | edge_info = {"table_1_name": edge[0], "table_2_name": edge[2], "edge_table_name": edge[1].replace('r2e_','')} 166 | add_row2edge_edge_info += f'\n{json.dumps(edge_info)}' 167 | for edge in remove_row2edge_edge: 168 | edge_info = {"table_1_name": edge[0], "table_2_name": edge[2], "edge_table_name": edge[1].replace('r2e_','')} 169 | remove_row2edge_edge_info += f'\n{json.dumps(edge_info)}' 170 | 171 | return {"add_fk_pk_edge": add_f2p_edge_info, 172 | "remove_fk_pk_edge": remove_f2p_edge_info, 173 | "convert_row_to_edge": convert_row2edge_info, 174 | "convert_edge_to_row": convert_edge2row_info, 175 | "add_row2edge_edge": add_row2edge_edge_info, 176 | "remove_row2edge_edge": remove_row2edge_edge_info} 177 | 178 | 179 | def conduct_multiple_actions( 180 | actions: list [dict], 181 | llm_micro_action_set: LLMMicroActionSet, 182 | current_edge_set: tuple 183 | ): 184 | valid_actions = [] 185 | invalid_actions = [] 186 | edge_set = current_edge_set 187 | graph_idx = -1 188 | error_msg = "" 189 | for action in actions : 190 | action['parameters'] = action['parameters'][0] if type(action['parameters']) == list else action['parameters'] 191 | new_edge_set, new_graph_idx, new_error_msg = get_micro_action_result( 192 | action=action, 193 | llm_micro_action_set=llm_micro_action_set, 194 | current_edge_set=current_edge_set 195 | ) 196 | # if the action is not valid, skip it 197 | if new_edge_set is None: 198 | continue 199 | # if the action is invalid, add it to the invalid actions 200 | if new_graph_idx == -1: 201 | error_msg += f"Action: {action['action']} \nParameters: {json.dumps(action['parameters'])} \nError: {new_error_msg}\n" 202 | invalid_actions.append(action) 203 | continue 204 | else: 205 | valid_actions.append(action) 206 | edge_set = new_edge_set 207 | graph_idx = new_graph_idx 208 | # print(f"New edge set: {edge_set}") 209 | # print(f"New graph idx: {graph_idx}") 210 | # print(f"Valid actions: {len(valid_actions)} Invalid actions: {len(invalid_actions)}") 211 | # print(f"Error msg: {error_msg}") 212 | return valid_actions, invalid_actions, edge_set, graph_idx, error_msg 213 | 214 | def update_edge_set( 215 | current_edge_set, 216 | new_edge_set, 217 | best_edge_set, 218 | update_best, 219 | ): 220 | # update best edge set 221 | current_edge_set = new_edge_set 222 | # print(f"Edge set changed from {current_edge_set} to {new_edge_set}") 223 | if update_best: 224 | best_edge_set = new_edge_set 225 | print(f"\033[93mBest edge set updated to {best_edge_set}\033[0m") 226 | return current_edge_set, best_edge_set 227 | 228 | def update_score( 229 | current_score, 230 | new_score, 231 | best_score, 232 | score_result, 233 | update_best, 234 | ): 235 | # if graph_idx = -1, do not update the score ( i.e., new_score = current_score ) 236 | if new_score == current_score: 237 | score_result.append(best_score) 238 | return current_score, current_score, best_score, score_result 239 | 240 | past_score = current_score 241 | current_score = new_score 242 | if update_best: 243 | best_score = current_score 244 | print(f"\033[93mBest score updated to {best_score:.4f}\033[0m") 245 | score_result.append(best_score) 246 | return past_score, current_score, best_score, score_result 247 | 248 | def update_action( 249 | parsed_all_actions, # ALL actions 250 | valid_actions, # ONLY Valid actions 251 | action_result, # history actions 252 | valid_action_result, # history valid actions 253 | best_valid_action_result, 254 | last_action_num, 255 | update_best, 256 | ): 257 | action_result.extend(parsed_all_actions) 258 | valid_action_result.extend(valid_actions) 259 | 260 | if update_best: 261 | best_valid_action_result = valid_action_result.copy() 262 | print(f"\033[93mBest valid action result updated to {len(best_valid_action_result)} actions \033[0m") 263 | else: 264 | pass # use the best_valid_action_result 265 | 266 | if len(valid_actions) > 0: 267 | last_action_num = len(valid_actions) 268 | else: 269 | pass # use the last_action_num 270 | return action_result, valid_action_result, best_valid_action_result, last_action_num 271 | 272 | 273 | def remove_invalid_history_actions( 274 | action_result, 275 | invalid_actions, 276 | ): 277 | print(f"Total history actions: {len(action_result)}") 278 | print(f"InValid history actions: {len(invalid_actions)}") 279 | for action in action_result: 280 | if action in invalid_actions: 281 | action_result.remove(action) 282 | 283 | 284 | print(f"Total history actions: {len(action_result)}") 285 | return action_result 286 | 287 | def get_history_actions( 288 | valid_action_result, 289 | # parsed_all_actions, 290 | max_history_actions, 291 | ): 292 | # valid_action_result.extend(parsed_all_actions) 293 | latest_actions = valid_action_result[-max_history_actions:] 294 | history_actions = json.dumps(latest_actions, indent=2).strip() if len(latest_actions) > 0 else "" 295 | return history_actions 296 | --------------------------------------------------------------------------------