├── docs
├── requirements.txt
├── logo.png
├── Makefile
├── make.bat
└── source
│ ├── 010
│ ├── install_and_setup.rst
│ └── tutorial.rst
│ ├── 030
│ ├── llm_baseline.rst
│ ├── random_search.rst
│ ├── bayesian_optimization.rst
│ ├── evolutionary_algorithm.rst
│ ├── reinforcement_learning.rst
│ ├── benchmark_dataset.rst
│ ├── greedy_search.rst
│ ├── run_benchmark.rst
│ └── micro_action.rst
│ ├── index.rst
│ ├── conf.py
│ └── 020
│ ├── download_precomputed.rst
│ └── reproduce_datasets.rst
├── rdb2g_bench
├── common
│ ├── __init__.py
│ ├── text_embedder.py
│ └── search_space
│ │ ├── search_space.py
│ │ ├── row2ne_search_space.py
│ │ └── gnn_search_space.py
├── benchmark
│ ├── llm
│ │ ├── __init__.py
│ │ ├── prompts
│ │ │ ├── task.json
│ │ │ ├── action.json
│ │ │ └── prompt.py
│ │ ├── llm_micro_action.py
│ │ └── llm_utils.py
│ ├── __init__.py
│ ├── bench_runner.py
│ └── baselines
│ │ ├── random.py
│ │ ├── utils.py
│ │ └── ea.py
├── dataset
│ ├── __init__.py
│ ├── utils.py
│ ├── models
│ │ ├── modules
│ │ │ ├── sage_conv_edge.py
│ │ │ ├── nn.py
│ │ │ ├── gps_conv.py
│ │ │ └── gin_conv.py
│ │ └── model.py
│ └── dataset.py
└── __init__.py
├── .gitignore
├── .readthedocs.yaml
├── MANIFEST.in
├── requirements.txt
├── examples
├── run_benchmark.py
├── run_gnn_node.py
├── run_idgnn_link.py
├── load_dataset.py
└── download_dataset.py
├── LICENSE
├── pyproject.toml
└── README.md
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==7.1.2
2 | sphinx-rtd-theme==1.3.0rc1
3 |
--------------------------------------------------------------------------------
/docs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chlehdwon/RDB2G-Bench/HEAD/docs/logo.png
--------------------------------------------------------------------------------
/rdb2g_bench/common/__init__.py:
--------------------------------------------------------------------------------
"""Shared utilities for RDB2G-Bench.

Re-exports the ``search_space`` subpackage as part of the public API.
"""
from . import search_space

__all__ = [
    "search_space",
]
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | __pycache__
3 | logs/
4 | results/
5 | rdb2g_bench.egg-info
6 | build/
--------------------------------------------------------------------------------
/rdb2g_bench/benchmark/llm/__init__.py:
--------------------------------------------------------------------------------
"""LLM-based baseline for RDB2G-Bench.

Exposes ``run_llm_baseline`` from the ``llm_runner`` submodule.
"""
from .llm_runner import run_llm_baseline

__all__ = ["run_llm_baseline"]
--------------------------------------------------------------------------------
/rdb2g_bench/benchmark/__init__.py:
--------------------------------------------------------------------------------
"""Benchmark entry points for RDB2G-Bench.

Public API:
    baselines: the search-baseline subpackage.
    llm: the LLM-based baseline subpackage.
    run_benchmark: benchmark runner re-exported from ``bench_runner``.
"""
# Submodules are imported in the same order as before: baselines, llm, runner.
from . import baselines, llm
from .bench_runner import run_benchmark

__all__ = ["baselines", "llm", "run_benchmark"]
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: "2"
2 |
3 | build:
4 | os: "ubuntu-22.04"
5 | tools:
6 | python: "3.10"
7 |
8 | python:
9 | install:
10 | - requirements: docs/requirements.txt
11 | - method: pip
12 | path: .
13 |
14 | sphinx:
15 | configuration: docs/source/conf.py
--------------------------------------------------------------------------------
/rdb2g_bench/dataset/__init__.py:
--------------------------------------------------------------------------------
"""Dataset construction utilities for RDB2G-Bench.

Re-exports the ``models`` subpackage, the node-task and link-task workers,
and ``integrate_edge_tf`` from ``utils``.
"""
from . import models
from .node_worker import run_gnn_node_worker
from .link_worker import run_idgnn_link_worker
from .utils import integrate_edge_tf

__all__ = [
    "models",
    "run_gnn_node_worker",
    "run_idgnn_link_worker",
    "integrate_edge_tf",
]
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include LICENSE
3 | include requirements.txt
4 | include pyproject.toml
5 | recursive-include rdb2g_bench *.py
6 | recursive-include rdb2g_bench *.txt
7 | recursive-include rdb2g_bench *.md
8 | recursive-include rdb2g_bench *.json
9 | recursive-exclude rdb2g_bench __pycache__
10 | recursive-exclude rdb2g_bench *.pyc
11 | recursive-exclude rdb2g_bench *.pyo
--------------------------------------------------------------------------------
/rdb2g_bench/__init__.py:
--------------------------------------------------------------------------------
"""RDB2G-Bench: a benchmark framework for automatic graph modeling of relational databases."""

# Must match the ``version`` field in pyproject.toml so that
# ``rdb2g_bench.__version__`` reports the same version pip installed.
__version__ = "0.1.1"
__author__ = "Dongwon Choi"
__email__ = "chlehdwon@kaist.ac.kr"

# Import main modules for easy access
from . import benchmark
from . import common
from . import dataset

__all__ = [
    "__version__",
    "__author__",
    "__email__",
    "benchmark",
    "common",
    "dataset",
]
--------------------------------------------------------------------------------
/rdb2g_bench/common/text_embedder.py:
--------------------------------------------------------------------------------
from typing import List, Optional

import torch
from torch import Tensor

# Please run `pip install -U sentence-transformers`
from sentence_transformers import SentenceTransformer


class GloveTextEmbedding:
    """Callable sentence embedder backed by a SentenceTransformer GloVe model.

    Uses the ``sentence-transformers/average_word_embeddings_glove.6B.300d``
    checkpoint.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """Load the GloVe model; ``device`` is forwarded to SentenceTransformer."""
        model_id = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_id, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode ``sentences`` and return the embeddings as a tensor."""
        embeddings = self.model.encode(sentences, convert_to_tensor=True)
        return embeddings
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Core dependencies
2 | torch>=2.1.0
3 | torch_geometric==2.6.1
4 | numpy>=1.26.0
5 | pandas>=2.2.0
6 | scikit-learn>=1.6.0
7 | sentence-transformers>=3.3.1
8 |
9 | # RelBench
10 | relbench==1.1.0
11 | pytorch-frame>=0.2.5
12 |
13 | # Hugging Face datasets
14 | datasets>=2.15.0
15 |
16 | # Plotting and visualization
17 | matplotlib>=3.10.0
18 | seaborn>=0.13.0
19 | networkx>=3.1
20 |
21 | # Utilities
22 | tqdm>=4.65.0
23 |
24 | # LLM dependencies
25 | anthropic>=0.25.0
26 |
27 | # PyTorch Geometric dependencies
28 | # Install separately: pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
29 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
# Directory holding conf.py and the .rst sources.
SOURCEDIR = source
# Output directory for built documentation (git-ignored via "build/").
BUILDDIR = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
# Example: "make html" runs "sphinx-build -M html source build".
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

REM Run from this script's directory so the relative SOURCEDIR/BUILDDIR resolve.
pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
    set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

REM With no target argument, show the Sphinx help text.
if "%1" == "" goto help

REM Probe for sphinx-build; errorlevel 9009 is cmd's "command not found" code.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
    echo.
    echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
    echo.installed, then set the SPHINXBUILD environment variable to point
    echo.to the full path of the 'sphinx-build' executable. Alternatively you
    echo.may add the Sphinx directory to PATH.
    echo.
    echo.If you don't have Sphinx installed, grab it from
    echo.http://sphinx-doc.org/
    exit /b 1
)

REM Build the requested target (e.g. "make html") in Sphinx "make mode".
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
REM Restore the caller's working directory.
popd
36 |
--------------------------------------------------------------------------------
/examples/run_benchmark.py:
--------------------------------------------------------------------------------
"""Example: run RDB2G-Bench search baselines and the LLM baseline on rel-f1."""
import os

# The LLM-based baseline (Example 3) uses an Anthropic API key. ``setdefault``
# keeps a key you have already exported (``export ANTHROPIC_API_KEY=...``)
# instead of overwriting it with this placeholder; replace the placeholder only
# if you are not exporting the key. It is set before importing rdb2g_bench to
# keep the original ordering.
os.environ.setdefault("ANTHROPIC_API_KEY", "YOUR_API_KEY")

from rdb2g_bench.benchmark.bench_runner import run_benchmark  # noqa: E402
from rdb2g_bench.benchmark.llm.llm_runner import run_llm_baseline  # noqa: E402


if __name__ == "__main__":
    # Example 1: Basic run
    run_benchmark(
        dataset="rel-f1",
        task="driver-top3",
        gnn="GraphSAGE",
        budget_percentage=0.05,
        method="all",
        num_runs=1,
        seed=0,
    )

    # Example 2: Run only specific methods
    # (method names used in the docs: "random", "greedy", "ea", "bo", "rl")
    run_benchmark(
        dataset="rel-f1",
        task="driver-top3",
        gnn="GraphSAGE",
        budget_percentage=0.05,
        method=["rl", "bo"],
        num_runs=1,
        seed=0,
    )

    # Example 3: Run LLM-based baseline
    run_llm_baseline(
        dataset="rel-f1",
        task="driver-top3",
        gnn="GraphSAGE",
        budget_percentage=0.05,
        seed=0,
    )
--------------------------------------------------------------------------------
/docs/source/010/install_and_setup.rst:
--------------------------------------------------------------------------------
1 | Install and Setup
2 | ==================
3 |
4 | Installation
5 | ------------
6 |
7 | Clone the repository and install the package:
8 |
9 | .. code-block:: bash
10 |
11 | git clone https://github.com/chlehdwon/RDB2G-Bench.git
12 | cd RDB2G-Bench
13 | pip install -e .
14 |
15 | PyTorch Geometric Dependencies
16 | ------------------------------
17 |
18 | Install additional PyG dependencies. The example below shows installation for torch 2.1.0 + cuda 12.1:
19 |
20 | .. code-block:: bash
21 |
22 | pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
23 |
24 | .. note::
25 | You can skip this step if you don't want to reproduce the datasets.
26 |
27 | Environment Setup for LLM usage
28 | --------------------------------
29 |
30 | For LLM-based baselines, set up your API key:
31 |
32 | .. code-block:: bash
33 |
34 | export ANTHROPIC_API_KEY="YOUR_API_KEY"
35 |
--------------------------------------------------------------------------------
/examples/run_gnn_node.py:
--------------------------------------------------------------------------------
"""Examples for ``run_gnn_node_worker`` on the rel-f1 ``driver-top3`` node task."""
from rdb2g_bench.dataset.node_worker import run_gnn_node_worker

# Example 1: Basic run
results = run_gnn_node_worker(
    dataset_name="rel-f1",
    task_name="driver-top3",
    gnn="GraphSAGE",
)

print(results['processed_graphs'])  # [0, ..., 721]
print(results['total_processed'])  # 722
print(results['csv_file'])  # ./results/tables/rel-f1/driver-top3/{tag}/GraphSAGE/42.csv

# Example 2: Split the work between 2 workers, one per GPU.
# ``idx`` selects which worker's share a call processes (the variable names
# suggest an even/odd split of graph indices). Each call returns only after its
# share is done, so in this single script the two calls run one after the
# other. To run them in parallel, launch each call in its own process
# (e.g. separate scripts or shells).
results_even = run_gnn_node_worker(
    dataset_name="rel-f1",
    task_name="driver-top3",
    gnn="GraphSAGE",
    idx=0,
    workers=2,
    device="cuda:0",
)

results_odd = run_gnn_node_worker(
    dataset_name="rel-f1",
    task_name="driver-top3",
    gnn="GraphSAGE",
    idx=1,
    workers=2,
    device="cuda:1",
)

# Example 3: Run a specific GNN (GPS) on selected target indices only
results_0 = run_gnn_node_worker(
    dataset_name="rel-f1",
    task_name="driver-top3",
    gnn="GPS",
    target_indices=[0],
)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Dongwon Choi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/examples/run_idgnn_link.py:
--------------------------------------------------------------------------------
"""Examples for ``run_idgnn_link_worker`` on the rel-avito ``user-ad-visit`` link task."""
from rdb2g_bench.dataset.link_worker import run_idgnn_link_worker

# Example 1: Basic run
results = run_idgnn_link_worker(
    dataset_name="rel-avito",
    task_name="user-ad-visit",
    gnn="GraphSAGE",
)

print(results['processed_graphs'])  # [0, ..., 908]
print(results['total_processed'])  # 909
print(results['csv_file'])  # ./results/tables/rel-avito/user-ad-visit/{tag}/GraphSAGE/42.csv

# Example 2: Split the work between 2 workers, one per GPU.
# ``idx`` selects which worker's share a call processes (the variable names
# suggest an even/odd split of graph indices). Each call returns only after its
# share is done, so in this single script the two calls run one after the
# other. To run them in parallel, launch each call in its own process
# (e.g. separate scripts or shells).
results_even = run_idgnn_link_worker(
    dataset_name="rel-avito",
    task_name="user-ad-visit",
    gnn="GraphSAGE",
    idx=0,
    workers=2,
    device="cuda:0",
)

results_odd = run_idgnn_link_worker(
    dataset_name="rel-avito",
    task_name="user-ad-visit",
    gnn="GraphSAGE",
    idx=1,
    workers=2,
    device="cuda:1",
)

# Example 3: Run on selected target indices only
results_0 = run_idgnn_link_worker(
    dataset_name="rel-avito",
    task_name="user-ad-visit",
    gnn="GraphSAGE",
    target_indices=[0],
)
--------------------------------------------------------------------------------
/docs/source/030/llm_baseline.rst:
--------------------------------------------------------------------------------
1 | LLM-based Baseline
2 | ==================
3 |
4 | This module implements a Large Language Model (LLM)-based baseline for RDB-to-Graph modeling search.
5 |
6 | How it Works
7 | ------------
8 |
9 | The LLM-based baseline process:
10 |
11 | 1. Provide the LLM with problem description and search space
12 | 2. LLM suggests promising micro actions for graph models
13 | 3. Evaluate suggested micro actions on the actual task
14 | 4. Provide feedback to the LLM about performance
15 | 5. LLM refines its suggestions based on feedback
16 | 6. Repeat until the budget is exhausted
17 |
18 | LLM-Based Baseline
19 | ------------------
20 |
21 | .. automodule:: rdb2g_bench.benchmark.llm.llm_runner
22 | :members:
23 | :undoc-members:
24 | :show-inheritance:
25 |
26 | Example Usage
27 | ~~~~~~~~~~~~~
28 |
29 | .. code-block:: python
30 |
31 | from rdb2g_bench.benchmark.llm.llm_runner import run_llm_baseline
32 |
33 | # Run LLM-based baseline
34 | results = run_llm_baseline(
35 | dataset="rel-f1",
36 | task="driver-top3",
37 | gnn="GraphSAGE",
38 | budget_percentage=0.05,
39 | seed=42,
40 | model="claude-3-5-sonnet-latest",
41 | temperature=0.8,
42 | )
43 |
44 |
--------------------------------------------------------------------------------
/rdb2g_bench/benchmark/llm/prompts/task.json:
--------------------------------------------------------------------------------
1 | {
2 | "driver-dnf": "For each driver predict if they will DNF (did not finish) a race in the next 1 month.",
3 | "driver-top3": "For each driver predict if they will qualify in the top-3 for a race in the next 1 month.",
4 | "driver-position": "Predict the average finishing position of each driver all races in the next 2 months.",
5 | "user-repeat": "Predict whether a user will attend an event(by responding yes or maybe) in the next 7 days if they have already attended an event in the last 14 days.",
6 | "user-ignore": "Predict whether a user will ignore more than 2 event invitations in the next 7 days",
7 | "user-attendance": "Predict how many events each user will respond yes or maybe in the next seven days.",
8 | "user-clicks": "Predict whether each customer will click on more than one Ads in the next 4 days.",
9 | "user-visits": "Predict whether each customer will visit more than one Ad in the next 4 days.",
10 | "ad-ctr": "Assuming the Ad will be clicked in the next 4 days, predict the Click-Through-Rate (CTR) for each Ad.",
11 | "user-ad-visit":"Predict the list of ads a user will visit in the next 4 days.",
12 | "post-post-related" :"Predict a list of existing posts that users will link a given post to in the next two years.",
13 | "study-outcome":"Predict if the trials will achieve its primary outcome (defined as p-value < 0.05)."
14 | }
15 |
16 |
--------------------------------------------------------------------------------
/docs/source/030/random_search.rst:
--------------------------------------------------------------------------------
1 | Random Search
2 | =============
3 |
4 | This module implements random search baseline for RDB-to-Graph modeling search.
5 | Random search provides a simple yet effective baseline by uniformly sampling graph models
6 | from the search space without any heuristic guidance, serving as an important comparison
7 | point for evaluating more sophisticated optimization algorithms.
8 |
9 | How it Works
10 | ------------
11 |
12 | The random search algorithm process:
13 |
14 | 1. Uniformly sample graph models from the entire search space
15 | 2. Evaluate the performance of each sampled graph model
16 | 3. Keep track of the best graph model found so far
17 | 4. Repeat until the budget is exhausted
18 |
19 | Random Search Baseline
20 | ----------------------
21 |
22 | .. automodule:: rdb2g_bench.benchmark.baselines.random
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 |
27 | Example Usage
28 | ~~~~~~~~~~~~~
29 |
30 | .. code-block:: python
31 |
32 | from rdb2g_bench.benchmark.bench_runner import run_benchmark
33 |
34 | # Run random search baseline
35 | results = run_benchmark(
36 | dataset="rel-f1",
37 | task="driver-top3",
38 | gnn="GraphSAGE",
39 | budget_percentage=0.05,
40 | method=["random"],
41 | num_runs=10,
42 | seed=42
43 | )
44 |
45 | # Access results
46 | print(f"Best architecture found: {results['random']['selected_graph_id']}")
47 | print(f"Performance: {results['random']['actual_y_perf_of_selected']:.4f}")
48 | print(f"Architectures evaluated: {results['random']['discovered_count']}")
--------------------------------------------------------------------------------
/docs/source/010/tutorial.rst:
--------------------------------------------------------------------------------
1 | Tutorial
2 | ========
3 |
4 | Basic Usage
5 | -----------
6 |
7 | This tutorial provides a quick overview of RDB2G-Bench.
8 |
9 | Quick Start Example
10 | -------------------
11 |
12 | Here's a simple example to get you started:
13 |
14 | .. code-block:: python
15 |
16 | from rdb2g_bench.dataset.dataset import load_rdb2g_bench
17 |
18 | # Load pre-computed benchmark results
19 | bench = load_rdb2g_bench("./results")
20 |
21 | # Access specific results
22 | result = bench['rel-f1']['driver-top3'][0]
23 | test_metric = result['test_metric']
24 | params = result['params']
25 | train_time = result['train_time']
26 |
27 | print(f"Test Metric: {test_metric}")
28 | print(f"Training Time: {train_time}")
29 |
30 | Run Your First Benchmark
31 | -----------------------------
32 |
33 | .. code-block:: python
34 |
35 | from rdb2g_bench.benchmark.bench_runner import run_benchmark
36 |
37 | # Run a simple benchmark
38 | results = run_benchmark(
39 | dataset="rel-f1",
40 | task="driver-top3",
41 | budget_percentage=0.05,
42 | method="all",
43 | num_runs=3,
44 | seed=42
45 | )
46 |
47 | For more detailed examples, check the ``examples/`` directory in the repository.
48 |
49 | Package Structure Overview
50 | --------------------------
51 |
52 | .. code-block:: text
53 |
54 | rdb2g_bench/
55 | ├── benchmark/ # Core benchmarking functionality
56 | ├── common/ # Shared utilities and search spaces
57 | ├── dataset/ # Dataset loading and processing
58 | └── __init__.py # Package initialization
--------------------------------------------------------------------------------
/docs/source/030/bayesian_optimization.rst:
--------------------------------------------------------------------------------
1 | Bayesian Optimization
2 | =====================
3 |
4 | This module implements Bayesian optimization baseline for RDB-to-Graph modeling search.
5 | Bayesian optimization is a sequential model-based optimization technique particularly
6 | effective for expensive black-box optimization problems like finding optimal graph models.
7 |
8 | How it Works
9 | ------------
10 |
11 | The Bayesian optimization algorithm process:
12 |
13 | 1. Build a probabilistic model (surrogate) of the objective function
14 | 2. Use acquisition function to determine next graph model to evaluate
15 | 3. Evaluate the objective function at the selected graph model
16 | 4. Update the surrogate model with new observation
17 | 5. Repeat until budget is exhausted
18 |
19 | Bayesian Optimization Baseline
20 | ------------------------------
21 |
22 | .. automodule:: rdb2g_bench.benchmark.baselines.bo
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 |
27 | Example Usage
28 | ~~~~~~~~~~~~~
29 |
30 | .. code-block:: python
31 |
32 | from rdb2g_bench.benchmark.bench_runner import run_benchmark
33 |
34 | # Run Bayesian optimization with default parameters
35 | results = run_benchmark(
36 | dataset="rel-f1",
37 | task="driver-top3",
38 | gnn="GraphSAGE",
39 | budget_percentage=0.05,
40 | method=["bo"],
41 | num_runs=10,
42 | seed=42
43 | )
44 |
45 | # Access results
46 | print(f"Best architecture found: {results['bo']['selected_graph_id']}")
47 | print(f"Performance: {results['bo']['actual_y_perf_of_selected']:.4f}")
48 | print(f"Rank: {results['bo']['rank_position_overall']}")
49 |
--------------------------------------------------------------------------------
/docs/source/030/evolutionary_algorithm.rst:
--------------------------------------------------------------------------------
1 | Evolutionary Algorithm
2 | ======================
3 |
4 | This module implements evolutionary algorithm baseline for RDB-to-Graph modeling search.
5 | Evolutionary algorithms are population-based metaheuristics that use biological evolution
6 | mechanisms such as reproduction, mutation, and selection to find optimal solutions.
7 |
8 | How it Works
9 | ------------
10 |
11 | The evolutionary algorithm process:
12 |
13 | 1. Initialize a population of random graph models
14 | 2. Evaluate the performance of each graph model
15 | 3. Select parents for reproduction
16 | 4. Apply crossover and mutation operators to the selected graph models
17 | 5. Replace old population with new generation
18 | 6. Repeat until budget is exhausted
19 |
20 | Evolutionary Algorithm Baseline
21 | -------------------------------
22 |
23 | .. automodule:: rdb2g_bench.benchmark.baselines.ea
24 | :members:
25 | :undoc-members:
26 | :show-inheritance:
27 |
28 | Example Usage
29 | ~~~~~~~~~~~~~
30 |
31 | .. code-block:: python
32 |
33 | from rdb2g_bench.benchmark.bench_runner import run_benchmark
34 |
35 | # Run evolutionary algorithm with default parameters
36 | results = run_benchmark(
37 | dataset="rel-f1",
38 | task="driver-top3",
39 | gnn="GraphSAGE",
40 | budget_percentage=0.05,
41 | method=["ea"],
42 | num_runs=10,
43 | seed=42
44 | )
45 |
46 | # Access results
47 | print(f"Best architecture found: {results['ea']['selected_graph_id']}")
48 | print(f"Performance: {results['ea']['actual_y_perf_of_selected']:.4f}")
49 | print(f"Generations completed: {results['ea']['total_iterations_run']}")
50 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to RDB2G-Bench Documentation!
2 | ========================================
3 |
4 | This is the official documentation of the paper **RDB2G-Bench: A Comprehensive Benchmark for Automatic Graph Modeling of Relational Databases.**
5 |
6 | **RDB2G-Bench** is an *easy-to-use framework* for benchmarking graph-based analysis and prediction tasks by converting relational database data into graphs.
7 |
8 | Contents
9 | --------
10 |
11 | .. toctree::
12 | :maxdepth: 2
13 | :caption: Get Started
14 |
15 | 010/install_and_setup
16 | 010/tutorial
17 |
18 | .. toctree::
19 | :maxdepth: 2
20 | :caption: Datasets
21 |
22 | 020/download_precomputed
23 | 020/reproduce_datasets
24 |
25 | .. toctree::
26 | :maxdepth: 2
27 | :caption: Benchmarks
28 |
29 | 030/benchmark_dataset
30 | 030/micro_action
31 | 030/run_benchmark
32 | 030/random_search
33 | 030/greedy_search
34 | 030/evolutionary_algorithm
35 | 030/bayesian_optimization
36 | 030/reinforcement_learning
37 | 030/llm_baseline
38 |
39 | Citation
40 | --------
41 |
42 | If you use RDB2G-Bench in your research, please cite:
43 |
44 | .. code-block:: bibtex
45 |
46 | @article{choi2025rdb2gbench,
47 | title={RDB2G-Bench: A Comprehensive Benchmark for Automatic Graph Modeling of Relational Databases},
48 | author={Dongwon Choi and Sunwoo Kim and Juyeon Kim and Kyungho Kim and Geon Lee and Shinhwan Kang and Myunghwan Kim and Kijung Shin},
49 | year={2025},
50 | url={https://arxiv.org/abs/2506.01360},
51 | }
52 |
53 | License
54 | -------
55 |
56 | This project is distributed under the MIT License as specified in the LICENSE file.
57 |
58 |
59 |
--------------------------------------------------------------------------------
/docs/source/030/reinforcement_learning.rst:
--------------------------------------------------------------------------------
1 | Reinforcement Learning
2 | ======================
3 |
4 | This module implements reinforcement learning baseline for RDB-to-Graph modeling search.
5 | The approach uses deep reinforcement learning with policy gradients to train
6 | a recurrent neural network controller that learns to generate sequences of micro actions
7 | for constructing high-performing graph models.
8 |
9 | How it Works
10 | ------------
11 |
12 | The reinforcement learning algorithm process:
13 |
14 | 1. Initialize RNN-based controller to learn micro actions for graph model optimization
15 | 2. Start from random graph model and convert to state embedding
16 | 3. Controller selects actions
17 | 4. Evaluate the performance of new graph model and compute reward
18 | 5. Update controller policy using Policy Gradient based on discounted rewards
19 | 6. Repeat episodes until the budget is exhausted
20 |
21 | Reinforcement Learning Baseline
22 | -------------------------------
23 |
24 | .. automodule:: rdb2g_bench.benchmark.baselines.rl
25 | :members:
26 | :undoc-members:
27 | :show-inheritance:
28 |
29 | Example Usage
30 | ~~~~~~~~~~~~~
31 |
32 | .. code-block:: python
33 |
34 | from rdb2g_bench.benchmark.bench_runner import run_benchmark
35 |
36 | # Run reinforcement learning with default parameters
37 | results = run_benchmark(
38 | dataset="rel-f1",
39 | task="driver-top3",
40 | gnn="GraphSAGE",
41 | budget_percentage=0.05,
42 | method=["rl"],
43 | num_runs=10,
44 | seed=42
45 | )
46 |
47 | # Access results
48 | print(f"Best architecture found: {results['rl']['selected_graph_id']}")
49 | print(f"Performance: {results['rl']['actual_y_perf_of_selected']:.4f}")
50 | print(f"Episodes completed: {results['rl']['total_iterations_run']}")
51 |
--------------------------------------------------------------------------------
/docs/source/030/benchmark_dataset.rst:
--------------------------------------------------------------------------------
1 | Performance Prediction Dataset
2 | ===============================
3 |
4 | This module implements the ``PerformancePredictionDataset`` class, which is used to load performance data for benchmarking RDB-to-Graph search algorithms on RDB2G-Bench.
5 |
6 | The ``PerformancePredictionDataset`` class is the core data interface for RDB2G-Bench. It loads pre-computed performance results from CSV files, processes graph configurations, and provides a unified interface for benchmark algorithms to access performance data.
7 |
8 | How it Works
9 | ------------
10 |
11 | The dataset loading process involves the following steps:
12 |
13 | 1. **RelBench Integration**: Connects to the specified RelBench dataset and task
14 | 2. **Graph Materialization**: Creates heterogeneous graph data with proper embeddings
15 | 3. **Results Loading**: Reads performance data from CSV files for the specified GNN backbone architectures
16 | 4. **Data Aggregation**: Groups results by graph models and aggregates across seeds
17 | 5. **Graph Indexing**: Creates mappings between graph models and search space
18 |
19 | Performance Prediction Dataset
20 | -------------------------------
21 |
22 | .. automodule:: rdb2g_bench.benchmark.dataset
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 |
27 | Example Usage
28 | -------------
29 |
30 | .. code-block:: python
31 |
32 | from rdb2g_bench.benchmark.dataset import PerformancePredictionDataset
33 |
34 | # Initialize dataset
35 | dataset = PerformancePredictionDataset(
36 | dataset_name="rel-f1",
37 | task_name="driver-top3",
38 | gnn="GraphSAGE",
39 | tag="hf",
40 | result_dir="./results"
41 | )
42 |
43 | # Access basic information
44 | print(f"Number of configurations: {len(dataset)}")
45 | print(f"Target metric: {dataset.target_col}")
46 | print(f"Task type: {dataset.task.task_type}")
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/docs/source/030/greedy_search.rst:
--------------------------------------------------------------------------------
1 | Greedy Search
2 | =============
3 |
4 | This module implements multiple greedy search strategies for RDB-to-Graph modeling search.
5 | Greedy algorithms make locally optimal choices at each step, providing fast and deterministic
6 | approaches for finding good graph models with minimal computational overhead.
7 |
8 | How it Works
9 | ------------
10 |
11 | The greedy search algorithm implements three different greedy strategies for graph model optimization:
12 |
13 | 1. **Forward Greedy**: Starts from a graph model containing only the target table(s) and iteratively moves to the best-performing neighboring graph model
14 | 2. **Backward Greedy**: Starts from the All-Rows-to-Nodes (AR2N) graph model and iteratively moves to the best-performing neighboring graph model
15 | 3. **Local Greedy**: Combines the forward and backward greedy strategies, starting from a randomly initialized graph model
16 |
17 | The greedy search algorithm process:
18 |
19 | 1. Starts from an initial graph model determined by the chosen strategy (e.g., a random graph model for Local Greedy)
20 | 2. Evaluates the performance of the graph model
21 | 3. Selects the best local improvement based on the chosen greedy strategy
22 | 4. Repeats until the budget is exhausted
23 |
24 |
25 | Greedy Search Baseline
26 | ----------------------
27 |
28 | .. automodule:: rdb2g_bench.benchmark.baselines.greedy
29 | :members:
30 | :undoc-members:
31 | :show-inheritance:
32 |
33 | Example Usage
34 | ~~~~~~~~~~~~~
35 |
36 | .. code-block:: python
37 |
38 | from rdb2g_bench.benchmark.bench_runner import run_benchmark
39 |
40 | # Run greedy search by default
41 | results = run_benchmark(
42 | dataset="rel-f1",
43 | task="driver-top3",
44 | gnn="GraphSAGE",
45 | budget_percentage=0.05,
46 | method=["greedy"],
47 | num_runs=10,
48 | seed=42
49 | )
50 |
51 | # Access results
52 | print(f"Selected graph IDs per run: {results['greedy']['selected_graph_ids_runs']}")
53 | print(f"Average performance: {results['greedy']['avg_actual_y_perf_of_selected']:.4f}")
54 | print(f"Average runtime: {results['greedy']['avg_run_time']:.2f}s")
55 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "rdb2g-bench"
7 | version = "0.1.1"
8 | description = "A benchmark framework for automatic graph modeling of relational databases"
9 | readme = "README.md"
10 | license = {file = "LICENSE"}
11 | authors = [
12 | {name = "RDB2G-Bench Team", email = "chlehdwon@kaist.ac.kr"}
13 | ]
14 | maintainers = [
15 | {name = "RDB2G-Bench Team", email = "chlehdwon@kaist.ac.kr"}
16 | ]
17 | keywords = ["graph neural networks", "relational databases", "benchmark", "machine learning"]
18 | classifiers = [
19 | "Development Status :: 3 - Alpha",
20 | "Intended Audience :: Developers",
21 | "Intended Audience :: Science/Research",
22 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
23 | "License :: OSI Approved :: MIT License",
24 | "Programming Language :: Python :: 3",
25 | "Programming Language :: Python :: 3.10",
26 | "Programming Language :: Python :: 3.11",
27 | "Operating System :: OS Independent",
28 | ]
29 | # Minimum Python is set by the dependency floors below: matplotlib>=3.10.0
30 | # requires Python 3.10+, while numpy>=1.26.0, pandas>=2.2.0 and
31 | # scikit-learn>=1.6.0 each require Python 3.9+.
32 | requires-python = ">=3.10"
33 | dependencies = [
34 | "torch>=2.1.0",
35 | "numpy>=1.26.0",
36 | "pandas>=2.2.0",
37 | "scikit-learn>=1.6.0",
38 | "pytorch-frame>=0.2.5",
39 | "torch_geometric==2.6.1",
40 | "relbench==1.1.0",
41 | "sentence-transformers>=3.3.1",
42 | "matplotlib>=3.10.0",
43 | "seaborn>=0.13.0",
44 | "networkx>=3.1",
45 | "tqdm>=4.65.0",
46 | "anthropic>=0.25.0",
47 | "datasets>=3.6.0"
48 | ]
49 |
50 | [project.urls]
51 | Homepage = "https://github.com/chlehdwon/RDB2G-Bench"
52 | Repository = "https://github.com/chlehdwon/RDB2G-Bench"
53 | Issues = "https://github.com/chlehdwon/RDB2G-Bench/issues"
54 |
55 | [tool.setuptools.packages.find]
56 | include = ["rdb2g_bench*"]
57 |
58 | [tool.setuptools.package-data]
59 | rdb2g_bench = ["*.txt", "*.md"]
60 | "rdb2g_bench.benchmark.llm.prompts" = ["*.json"]
--------------------------------------------------------------------------------
/rdb2g_bench/common/search_space/search_space.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | from torch_geometric.data import HeteroData
7 | from .gnn_search_space import GNNNodeSearchSpace, GNNLinkSearchSpace, IDGNNLinkSearchSpace
8 | from .row2ne_search_space import Row2NESearchSpace
9 |
class TotalSearchSpace():
    """Joint search space over RDB-to-graph conversions and GNN edge choices.

    The row-to-node/edge part is delegated to a ``Row2NESearchSpace`` built
    from ``hetero_data``. The GNN part is an instance of the caller-supplied
    ``GNNSearchSpace`` class. That instance is constructed with the node types
    and the full edge set discovered by the row-to-node/edge search space.

    Args:
        dataset: Dataset name (e.g. ``"rel-f1"``).
        task: Task name (e.g. ``"driver-top3"``).
        hetero_data: Heterogeneous graph built from the relational database.
        GNNSearchSpace: The GNN search-space *class* to instantiate (it is
            called with keyword arguments below, so pass the class, not an
            instance).
        num_layers: Number of GNN layers, forwarded to ``GNNSearchSpace``.
        src_entity_table: Source entity table of the task.
        dst_entity_table: Destination entity table; ``None`` by default
            (presumably only needed for link tasks -- TODO confirm).
    """

    def __init__(self,
                 dataset: str,
                 task: str,
                 hetero_data: HeteroData,
                 GNNSearchSpace: Union[GNNNodeSearchSpace, GNNLinkSearchSpace, IDGNNLinkSearchSpace],
                 num_layers: int,
                 src_entity_table: str,
                 dst_entity_table: Union[str, None] = None
                 ):
        self.dataset = dataset
        self.task = task
        self.data = hetero_data
        self.row2graph = Row2NESearchSpace(dataset, task, hetero_data)
        self.num_layers = num_layers
        self.src_entity_table = src_entity_table
        self.dst_entity_table = dst_entity_table
        # Edge set used to seed the GNN search space.
        self.full_edges = self.row2graph.find_full_edges()
        self.gnn_search_space = GNNSearchSpace(
            dataset=self.dataset,
            task=self.task,
            num_layers=self.num_layers,
            node_types=self.row2graph.node_types,
            edge_types=self.full_edges,
            src_entity_table=self.src_entity_table,
            dst_entity_table=self.dst_entity_table
        )

    def get_possible_edges(self):
        """Return ``Row2NESearchSpace.find_possible_edges()`` (recomputed on each call)."""
        return self.row2graph.find_possible_edges()

    def get_full_edges(self):
        """Return the full edge set computed once at construction time."""
        return self.full_edges

    def generate_all_graphs(self):
        """Enumerate every graph candidate via the GNN search space."""
        return self.gnn_search_space.generate_all_graphs()

    def get_full_graph_idx(self, graphs: List[tuple]) -> int:
        """Return the index of the last graph whose leading choice entries are all zero.

        Each graph is an equal-length sequence of numeric choices. Its first
        ``len(self.get_possible_edges())`` entries are summed, and graphs whose
        sum is zero are candidates. These are presumably the graphs that keep
        every row table as nodes (AR2N) -- TODO confirm against
        ``Row2NESearchSpace``.

        Args:
            graphs: Sequence of graph encodings (e.g. from ``generate_all_graphs``).

        Returns:
            The position of the last matching graph as a Python ``int``, or -1
            when there is no match or ``graphs`` is empty.
        """
        # An empty list would become a 1-D array and break the 2-D slicing below.
        if len(graphs) == 0:
            return -1
        graph_matrix = np.asarray(graphs)
        num_row_choices = len(self.get_possible_edges())
        matches = np.flatnonzero(graph_matrix[:, :num_row_choices].sum(axis=1) == 0)
        return -1 if matches.size == 0 else int(matches[-1])

    def get_data(self, edges: Union[torch.Tensor, tuple]) -> HeteroData:
        """Build the ``HeteroData`` graph selected by ``edges``.

        ``edges`` is converted to a tensor if it is not one already. It is then
        passed to ``Row2NESearchSpace.convert_row_to_edge``. The converted data,
        together with ``edges``, is handed to the GNN search space's
        ``get_data``.
        """
        if not isinstance(edges, torch.Tensor):
            edges = torch.tensor(edges)
        converted_data = self.row2graph.convert_row_to_edge(edges)
        return self.gnn_search_space.get_data(edges, converted_data)
--------------------------------------------------------------------------------
/examples/load_dataset.py:
--------------------------------------------------------------------------------
"""Example: browse pre-computed RDB2G-Bench results.

Loads the benchmark result tables under ``./results`` and walks the
dataset -> task -> GNN -> graph-index hierarchy. The "expected output" strings
and trailing comments below are illustrative. They come from one snapshot of
the published results and may differ from what you see locally.
"""
from rdb2g_bench.dataset.dataset import load_rdb2g_bench

bench = load_rdb2g_bench(result_dir="./results")

# Nested mapping: dataset -> task -> list of GNN names with results.
available = bench.get_available()
print(available)
"""
{
    'rel-stack': {
        'post-post-related': ['GraphSAGE', 'GIN', 'GPS']
    },
    'rel-event': {
        'user-attendance': ['GraphSAGE', 'GIN', 'GPS'],
        'user-repeat': ['GraphSAGE', 'GIN', 'GPS'],
        'user-ignore': ['GraphSAGE', 'GIN', 'GPS']
    },
    'rel-f1': {
        'driver-top3': ['GraphSAGE', 'GIN', 'GPS'],
        'driver-position': ['GraphSAGE', 'GIN', 'GPS'],
        'driver-dnf': ['GraphSAGE', 'GIN', 'GPS']
    },
    'rel-trial': {
        'study-outcome': ['GraphSAGE', 'GIN', 'GPS']
    },
    'rel-avito': {
        'user-visits': ['GraphSAGE', 'GIN', 'GPS'],
        'ad-ctr': ['GraphSAGE', 'GIN', 'GPS'],
        'user-clicks': ['GraphSAGE', 'GIN', 'GPS'],
        'user-ad-visit': ['GraphSAGE', 'GIN', 'GPS']
    }
}
"""

# Access specific task and GNN model
task = bench['rel-f1']['driver-top3']
available_gnns = task.get_available_gnns()
print(f"Available GNNs: {available_gnns}")  # ['GraphSAGE', 'GIN', 'GPS']

# Access specific GNN model
gnn = task['GraphSAGE']
indices = gnn.get_available_indices()
print(f"Available indices for GraphSAGE: {len(indices)}")  # 722

# Get results for specific graph configuration
result = gnn[0]
print(f"Index 0 results: {result.stats}")
"""
Index 0 results: {'test_metric_mean': 0.805233155657748,
                  'test_metric_std': 0.034900872955993194,
                  'params': 954241,
                  'train_time': 0.30655655304590856,
                  'valid_time': 0.06576369603474932,
                  'test_time': 0.05956562360127762}
"""

# Access aggregated statistics (aggregated across seeds)
stats = result.stats
test_metric_mean = stats['test_metric_mean']
test_metric_std = stats['test_metric_std']
params = stats['params']
train_time = stats['train_time']

print(f"Test metric: {test_metric_mean:.4f} ± {test_metric_std:.4f}")  # 0.8052 ± 0.0349
print(f"Parameters: {params:.0f}")  # 954241
print(f"Train time: {train_time:.4f} seconds")  # 0.3066

# Compare different GNN models on the same graph configuration (index 0)
print("\nComparing GNN models:")
for gnn_name in available_gnns:
    gnn_model = task[gnn_name]
    result_0 = gnn_model[0]
    perf = result_0.stats['test_metric_mean']
    print(f"{gnn_name}: {perf:.4f}")
--------------------------------------------------------------------------------
/rdb2g_bench/benchmark/llm/prompts/action.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_fk_pk_edge": "Here is the introduction of add_fk_pk_edge:\nDescription:\nCreates a directed edge from one table to another by adding a foreign key (FK) to primary key (PK) relationship. \nUse when you need to represent an important directional relationship between two tables in your graph schema.\nParameters:\n from_table_name: the name of the table containing the foreign key\n from_col_name: the name of the foreign key column in from_table_name\n to_table_name: the name of the table containing the primary key",
3 | "remove_fk_pk_edge": "Here is the introduction of remove_fk_pk_edge:\nDescription:\nEliminates a directed edge between tables by removing a FK-PK relationship. \nUse when a previously modeled relationship doesn't add meaningful context to your graph structure and should be excluded.\nParameters:\n from_table_name: the name of the table containing the foreign key\n from_col_name: the name of the foreign key column in from_table_name\n to_table_name: the name of the table containing the primary key",
4 | "convert_row_to_edge": "Here is the introduction of convert_row_to_edge:\nDescription:\nTransforms what was originally modeled as an entity table into a relationship edge in your graph. \nUse when an intermediate table (denoted as edge_table_name) better represents a relationship property between two tables (denoted as table_1_name and table_2_name) rather than being an independent entity.\nNote that table_1_name and table_2_name can be equal when the edge_table_name has 2 foreign keys which refer to the same primary key.\nParameters:\n table_1_name: the name of the first row table\n table_2_name: the name of the second row table\n edge_table_name: the name of the table to convert to edge between table_1_name and table_2_name",
5 | "convert_edge_to_row": "Here is the introduction of convert_edge_to_row:\nDescription:\nTransforms what was modeled as a relationship edge into a proper entity table in your graph. \nUse when an edge contains sufficient attributes and identity to justify becoming an entity table with its own properties.\nNote that table_1_name and table_2_name can be equal when the edge_table_name has 2 foreign keys which refer to the same primary key.\nParameters:\n table_1_name: the name of the first row table\n table_2_name: the name of the second row table\n edge_table_name: the name of the edge table to convert to rows between table_1_name and table_2_name",
6 | "add_row2edge_edge": "Here is the introduction of add_row2edge_edge:\nDescription:\nCreates a new edge (relationship) between two tables by introducing an edge (join) table.\nUse this when a many-to-many relationship between two tables is needed and meaningful for your graph schema.\nParameters:\n table_1_name: the name of the first row table\n table_2_name: the name of the second row table\n edge_table_name: the name of the edge (join) table to add as an edge between table_1_name and table_2_name",
7 | "remove_row2edge_edge": "Here is the introduction of remove_row2edge_edge:\nDescription:\nRemoves an edge (relationship) between two tables by deleting the edge (join) table.\nUse this when the many-to-many relationship is not necessary for your graph schema.\nParameters:\n table_1_name: the name of the first row table\n table_2_name: the name of the second row table\n edge_table_name: the name of the edge (join) table whose edge between table_1_name and table_2_name is removed"
8 | }
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
# Configuration file for the Sphinx documentation builder.

import os
import sys

# Make the package importable for autodoc from docs/source, docs/ and the repo root.
sys.path.append(os.path.abspath('.'))
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../..'))

# ReadTheDocs-specific configuration
on_rtd = os.environ.get('READTHEDOCS') == 'True'
if on_rtd:
    # ReadTheDocs automatically installs the package via .readthedocs.yaml,
    # but the repository root is also put first on the path so the in-tree
    # sources are the ones documented.
    sys.path.insert(0, os.path.abspath('../../'))

# -- Project information

project = 'RDB2G-Bench'
copyright = 'KAIST Data Mining Lab'
author = 'Dongwon Choi'

# Sphinx convention: ``release`` is the full version string (kept equal to
# the package version in pyproject.toml), ``version`` is the short X.Y form.
release = '0.1.1'
version = '0.1'

# -- General configuration

# Each extension is listed once; duplicate entries were redundant.
extensions = [
    'sphinx.ext.duration',
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.doctest',
    'sphinx.ext.intersphinx',
    'sphinx.ext.todo',
    'sphinx.ext.coverage',
    'sphinx.ext.mathjax',
    'sphinx.ext.ifconfig',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinx.ext.githubpages',
]

intersphinx_mapping = {
    'python': ('https://docs.python.org/3/', None),
    'sphinx': ('https://www.sphinx-doc.org/en/master/', None),
}
intersphinx_disabled_domains = ['std']

templates_path = ['_templates']

# -- Options for HTML output

html_theme = 'sphinx_rtd_theme'

# Show project logo at the top-left of the docs (path relative to docs/source).
html_logo = '../logo.png'

# HTML theme options for sphinx_rtd_theme
html_theme_options = {
    'prev_next_buttons_location': 'bottom',
    'style_external_links': False,
    'vcs_pageview_mode': '',
    'style_nav_header_background': '#FFFFFF',
    'logo_only': True,
    # The GitHub link is provided through html_context below.
}

# Add GitHub link in the HTML context
html_context = {
    'display_github': True,
    'github_user': 'chlehdwon',
    'github_repo': 'RDB2G-Bench',
    'github_version': 'main',
    'conf_py_path': '/docs/source/',
}

# -- Autodoc configuration for type hints
autodoc_typehints = 'signature'

# Heavy / optional third-party modules are mocked so autodoc can import the
# package without installing them. Standard-library modules are deliberately
# NOT listed: they are always importable, and mocking one that is not yet
# loaded would replace real stdlib behaviour during autodoc imports.
autodoc_mock_imports = [
    "torch", "numpy", "pandas", "typing_extensions",
    "torch_frame", "torch_geometric", "torch_scatter", "torch_sparse",
    "torch_cluster", "torch_spline_conv", "pyg_lib",
    "dgl", "sklearn", "scipy", "networkx",
    "ogb", "tqdm", "qpth", "quadprog", "cvxpy", "rdkit", "dgllife",
    "relbench", "anthropic", "openai",
    "matplotlib", "seaborn", "plotly", "wandb", "tensorboard",
    "transformers", "datasets", "tokenizers", "accelerate",
    # "ax" is the import name of the ax-platform distribution.
    "optuna", "ray", "hyperopt", "ax",
    "psutil", "memory_profiler", "line_profiler",
    # ReadTheDocs specific mocks
    "torch.nn", "torch.optim", "torch.utils", "torch.cuda"
]

if on_rtd:
    autodoc_mock_imports.extend([
        "torch.nn.functional", "torch.distributions",
        "torch_geometric.nn", "torch_geometric.data", "torch_geometric.utils",
        "relbench.base", "relbench.datasets", "relbench.modeling", "relbench.tasks"
    ])

# -- Options for EPUB output
epub_show_urls = 'footnote'

add_module_names = False
--------------------------------------------------------------------------------
/examples/download_dataset.py:
--------------------------------------------------------------------------------
"""Example: inspect and download the pre-computed RDB2G-Bench results.

First prints per-(dataset, task, GNN) summary statistics from
``get_dataset_stats``. Then downloads the result tables with
``download_rdb2g_bench``, once for everything and once for a single
dataset/task/GNN selection. The expected-output string below is illustrative.
"""
from rdb2g_bench.dataset.dataset import download_rdb2g_bench, get_dataset_stats

# List all available datasets and tasks with GNN models
df_stats = get_dataset_stats(cache_dir="~/.cache")
print(df_stats)
"""
 dataset task gnn idx seed test_metric_mean test_metric_std test_metric_min test_metric_max
0 rel-avito ad-ctr GraphSAGE 1304 5 0.0423 0.0010 0.0374 0.0458
1 rel-avito ad-ctr GIN 1304 5 0.0419 0.0012 0.0365 0.0461
2 rel-avito ad-ctr GPS 1304 5 0.0425 0.0009 0.0378 0.0459
3 rel-avito user-ad-visit GraphSAGE 909 5 0.0165 0.0076 0.0016 0.0372
4 rel-avito user-ad-visit GIN 909 5 0.0163 0.0074 0.0018 0.0369
5 rel-avito user-ad-visit GPS 909 5 0.0168 0.0078 0.0014 0.0375
6 rel-f1 driver-top3 GraphSAGE 722 15 0.7847 0.0150 0.6554 0.8732
7 rel-f1 driver-top3 GIN 722 15 0.7823 0.0148 0.6521 0.8698
8 rel-f1 driver-top3 GPS 722 15 0.7891 0.0152 0.6598 0.8756
...
"""

# Filter statistics by specific GNN model (df_stats is filtered like a pandas DataFrame)
graphsage_stats = df_stats[df_stats['gnn'] == 'GraphSAGE']
print("\nGraphSAGE performance across datasets:")
print(graphsage_stats[['dataset', 'task', 'test_metric_mean', 'test_metric_std']])

# Download entire RDB2G-Bench dataset (includes all GNN models).
# The return value maps each dataset/task combination to its saved files.
saved_files = download_rdb2g_bench(
    result_dir="./results",
    cache_dir="~/.cache",
    tag="hf"
)
print(f"\nDownloaded {len(saved_files)} dataset/task combinations")
for combo, files in saved_files.items():
    print(f"{combo}: {len(files)} files")

# Download specific RDB2G-Bench dataset with specific GNN model
saved_files = download_rdb2g_bench(
    result_dir="./results",
    cache_dir="~/.cache",
    dataset_names=["rel-f1"],
    task_names=["driver-top3"],
    gnn_names=["GIN"],
    tag="hf_gin"
)
# Plain string: the message has no placeholders, so an f-string is unnecessary.
print("\nSpecific GNN model download saved files:")
for combo, files in saved_files.items():
    print(f"{combo}: {files}")
--------------------------------------------------------------------------------
/docs/source/020/download_precomputed.rst:
--------------------------------------------------------------------------------
1 | Download Pre-computed Datasets
2 | ===============================
3 |
4 | This module provides functionality to load and access pre-computed benchmark results without running experiments.
5 | The RDB2G-Bench dataset can be downloaded from Hugging Face Hub and accessed through a hierarchical interface.
6 |
7 | The RDB2G-Bench dataset is also available on Hugging Face: https://huggingface.co/datasets/kaistdata/RDB2G-Bench
8 |
9 | Dataset Module
10 | --------------
11 |
12 | The dataset module provides functions for downloading and loading benchmark data.
13 |
14 | .. automodule:: rdb2g_bench.dataset.dataset
15 | :members:
16 | :undoc-members:
17 | :show-inheritance:
18 |
19 | Dataloader Module
20 | -----------------
21 |
22 | The dataloader module provides classes for hierarchical access to benchmark results.
23 |
24 | .. automodule:: rdb2g_bench.dataset.dataloader
25 | :members:
26 | :undoc-members:
27 | :show-inheritance:
28 |
29 | Data Access Pattern
30 | -------------------
31 |
32 | The benchmark data follows a hierarchical access pattern:
33 |
34 | .. code-block:: text
35 |
36 | RDB2GBench[dataset_name][task_name][gnn_name][idx] -> IndexAccessor
37 |
38 | Example:
39 | bench['rel-f1']['driver-top3']['GraphSAGE'][0] -> Results for graph configuration 0 with GraphSAGE
40 |
41 | Directory Structure
42 | ~~~~~~~~~~~~~~~~~~~
43 |
44 | The downloaded data is organized in the following directory structure:
45 |
46 | .. code-block:: text
47 |
48 | results/
49 | tables/
50 | dataset_name/ # e.g., rel-f1, rel-avito
51 | task_name/ # e.g., driver-top3, user-ad-visit
52 | tag/ # e.g., hf
53 | gnn_name/ # e.g., GraphSAGE, GIN, GPS
54 | 0.csv # Results for seed 0
55 | 1.csv # Results for seed 1
56 | ...
57 |
58 | Example Usage
59 | ~~~~~~~~~~~~~
60 |
61 | .. code-block:: python
62 |
63 | from rdb2g_bench.dataset.dataset import load_rdb2g_bench
64 |
65 | # Load benchmark results (downloads if missing)
66 | bench = load_rdb2g_bench("./results")
67 |
68 | # Check available datasets, tasks, and GNN models
69 | available = bench.get_available()
70 | print(available)
71 | # {'rel-f1': {'driver-top3': ['GIN', 'GPS', 'GraphSAGE'], ...}, ...}
72 |
73 | # Access results for rel-f1 dataset, driver-top3 task, GraphSAGE model
74 | result = bench['rel-f1']['driver-top3']['GraphSAGE'][0] # First graph configuration
75 |
76 | # Extract performance metrics (aggregated across seeds)
77 | test_metric = result.stats['test_metric_mean'] # Test performance mean
78 | test_std = result.stats['test_metric_std'] # Test performance std
79 | params = result.stats['params'] # Model parameters
80 | train_time = result.stats['train_time'] # Training time
81 |
82 | CSV File Format
83 | ~~~~~~~~~~~~~~~
84 |
85 | Each CSV file contains the following columns:
86 |
87 | .. code-block:: text
88 |
89 | idx : Graph configuration index
90 | graph : Graph structure representation (e.g., "graph_00000000000010")
91 | train_metric : Training performance metric
92 | valid_metric : Validation performance metric
93 | test_metric : Test performance metric
94 | params : Number of trainable parameters
95 | train_time : Average training time per epoch
96 | valid_time : Validation time
97 | test_time : Test time
98 | dataset : Dataset name
99 | task : Task name
100 | seed : Random seed
101 | gnn : GNN model name (GraphSAGE, GIN, GPS)
102 |
103 | Accessing GNN-specific Results
104 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
105 |
106 | The benchmark now supports GNN-specific access patterns:
107 |
108 | .. code-block:: python
109 |
110 | # Access specific GNN models
111 | graphsage_result = bench['rel-f1']['driver-top3']['GraphSAGE'][0]
112 | gin_result = bench['rel-f1']['driver-top3']['GIN'][0]
113 |
114 | # Get available GNN models for a task
115 | available_gnns = bench['rel-f1']['driver-top3'].get_available_gnns()
116 | print(available_gnns) # ['GIN', 'GPS', 'GraphSAGE']
117 |
118 | # Compare performance across GNN models
119 | for gnn in available_gnns:
120 | result = bench['rel-f1']['driver-top3'][gnn][0]
121 | print(f"{gnn}: {result.stats['test_metric_mean']:.4f}")
122 |
123 |
124 |
125 |
--------------------------------------------------------------------------------
/docs/source/030/run_benchmark.rst:
--------------------------------------------------------------------------------
1 | Run Benchmark
2 | =============
3 |
4 | This module provides the main interface for running comprehensive benchmarks on RDB2G-Bench.
5 | The benchmark runner executes multiple RDB-to-Graph modeling search algorithms and provides
6 | statistical analysis and comparison across different methods, datasets, and tasks.
7 |
8 | The benchmark system supports all available RDB-to-Graph modeling strategies and automatically handles
9 | data preparation, caching, multi-run execution, and results aggregation.
10 |
11 | Benchmark Runner Interface
12 | --------------------------
13 |
14 | .. automodule:: rdb2g_bench.benchmark.bench_runner
15 | :members:
16 | :undoc-members:
17 | :show-inheritance:
18 |
19 | Core Benchmark Engine
20 | ---------------------
21 |
22 | .. automodule:: rdb2g_bench.benchmark.benchmark
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 |
27 | Example Usage
28 | ~~~~~~~~~~~~~
29 |
30 | Basic Benchmark Execution
31 | ^^^^^^^^^^^^^^^^^^^^^^^^^^
32 |
33 | .. code-block:: python
34 |
35 | from rdb2g_bench.benchmark.bench_runner import run_benchmark
36 |
37 | # Run all methods on default dataset and task with specific GNN
38 | results = run_benchmark(
39 | dataset="rel-f1",
40 | task="driver-top3",
41 | gnn="GraphSAGE",
42 | budget_percentage=0.05,
43 | method="all",
44 | num_runs=10,
45 | seed=42
46 | )
47 |
48 | # Print summary of results
49 | for method, stats in results.items():
50 | if 'avg_actual_y_perf_of_selected' in stats:
51 | perf = stats['avg_actual_y_perf_of_selected']
52 | rank = stats.get('avg_rank_position_overall', 'N/A')
53 | print(f"{method}: Performance={perf:.4f}, Rank={rank}")
54 |
55 | Specific Methods Comparison
56 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^
57 |
58 | .. code-block:: python
59 |
60 | # Compare evolutionary algorithm vs greedy search
61 | results = run_benchmark(
62 | dataset="rel-avito",
63 | task="user-ad-visit",
64 | gnn="GIN",
65 | budget_percentage=0.05,
66 | method=["ea", "greedy", "random"],
67 | num_runs=15,
68 | seed=123
69 | )
70 |
71 | # Extract key metrics
72 | for method in ["ea", "greedy", "random"]:
73 | if method in results:
74 | stats = results[method]
75 | print(f"\n{method.upper()} Results:")
76 | print(f" Average Performance: {stats['avg_actual_y_perf_of_selected']:.4f}")
77 | print(f" Average Rank: {stats['avg_rank_position_overall']:.1f}")
78 | print(f" Average Runtime: {stats['avg_run_time']:.2f}s")
79 |
80 | Available Methods
81 | -----------------
82 |
83 | The benchmark runner supports the following search methods:
84 |
85 | - **"all"**: Run all available methods
86 | - **"ea"**: Evolutionary Algorithm baseline
87 | - **"greedy"**: Greedy Search strategies (forward, backward, local)
88 | - **"rl"**: Reinforcement Learning with policy gradients
89 | - **"bo"**: Bayesian Optimization with surrogate models
90 |
91 | You can specify a single method as a string or multiple methods as a list.
92 |
93 | Results Structure
94 | -----------------
95 |
96 | The returned results dictionary contains method-wise statistics:
97 |
98 | .. code-block:: python
99 |
100 | {
101 | "Method Name": {
102 | "avg_actual_y_perf_of_selected": float, # Average performance
103 | "avg_rank_position_overall": float, # Average ranking
104 | "avg_percentile_overall": float, # Average percentile
105 | "total_samples_overall": int, # Total architectures
106 | "selected_graph_ids_runs": list, # Selected graph IDs
107 | "avg_selection_metric_value": float, # Average selection metric
108 | "selected_graph_origins": list, # Method origins
109 | "avg_evaluation_time": float, # Average evaluation time
110 | "avg_run_time": float # Average total runtime
111 | }
112 | }
113 |
114 | Output Files
115 | ------------
116 |
117 | The benchmark automatically generates several output files:
118 |
119 | - **Individual Trajectories**: ``avg_trajectory_{method}_{gnn}_{num_runs}runs.csv``
120 | - **Combined Trajectories**: ``all_methods_trajectories_{gnn}_{num_runs}runs.csv``
121 | - **Performance Summary**: ``performance_summary_{gnn}_{num_runs}runs.csv``
122 |
123 | These files are saved in: ``{result_dir}/benchmark/{dataset}/{task}/{tag}/{gnn}/``
124 |
--------------------------------------------------------------------------------
/docs/source/020/reproduce_datasets.rst:
--------------------------------------------------------------------------------
1 | Reproduce Datasets
2 | ==================
3 |
4 | This module provides functionality to reproduce the datasets and benchmark results from the RDB2G-Bench paper.
5 | The workers support both node-level tasks (classification/regression) and link prediction tasks (recommendation)
6 | with various GNN architectures and comprehensive evaluation metrics.
7 |
8 | Node Worker
9 | -----------
10 |
11 | The node worker handles node classification and regression tasks using Graph Neural Networks.
12 |
13 | .. automodule:: rdb2g_bench.dataset.node_worker
14 | :members:
15 | :undoc-members:
16 | :show-inheritance:
17 |
18 | Link Worker
19 | -----------
20 |
21 | The link worker handles link prediction tasks using ID-aware Graph Neural Networks (IDGNN).
22 |
23 | .. automodule:: rdb2g_bench.dataset.link_worker
24 | :members:
25 | :undoc-members:
26 | :show-inheritance:
27 |
28 | Task Types
29 | ----------
30 |
31 | The benchmark supports different types of tasks:
32 |
33 | Node-Level Tasks
34 | ~~~~~~~~~~~~~~~~
35 |
36 | - **Binary Classification**: Predicting binary labels for nodes (e.g., driver performance)
37 | - **Regression**: Predicting continuous values for nodes (e.g., ratings, scores)
38 | - **Multilabel Classification**: Predicting multiple binary labels per node
39 |
40 | Link Prediction Tasks
41 | ~~~~~~~~~~~~~~~~~~~~~
42 |
43 | - **Recommendation**: Predicting user-item interactions using ranking metrics
44 | - **Link Prediction**: General link prediction between entities
45 |
46 | Supported Models
47 | ----------------
48 |
49 | Both workers support multiple GNN architectures:
50 |
51 | - **GraphSAGE**: Inductive graph representation learning
52 | - **GIN**: Graph Isomorphism Network, a sum-aggregation message-passing GNN as expressive as the 1-WL test
53 | - **GPS**: Graph transformer with positional encodings
54 |
55 | Example Usage
56 | ----------------
57 |
58 | Node Classification/Regression
59 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
60 |
61 | .. code-block:: python
62 |
63 | from rdb2g_bench.dataset.node_worker import run_gnn_node_worker
64 |
65 | # Run binary classification experiment
66 | results = run_gnn_node_worker(
67 | dataset_name="rel-f1",
68 | task_name="driver-top3",
69 | gnn="GraphSAGE",
70 | epochs=20,
71 | lr=0.005,
72 | batch_size=512,
73 | channels=128,
74 | )
75 | print(f"Processed {results['total_processed']} graph configurations")
76 |
77 | # Run regression experiment
78 | results = run_gnn_node_worker(
79 | dataset_name="rel-f1",
80 | task_name="driver-position",
81 | gnn="GIN",
82 | epochs=50,
83 | lr=0.001,
84 | weight_decay=1e-4
85 | )
86 |
87 | # Parallel processing with multiple workers
88 | results = run_gnn_node_worker(
89 | dataset_name="rel-f1",
90 | task_name="driver-top3",
91 | idx=0, # Worker 0 ~ 3
92 | workers=4, # Total 4 workers
93 | epochs=20
94 | )
95 |
96 | Link Prediction
97 | ~~~~~~~~~~~~~~~
98 |
99 | .. code-block:: python
100 |
101 | from rdb2g_bench.dataset.link_worker import run_idgnn_link_worker
102 |
103 | # Run recommendation experiment
104 | results = run_idgnn_link_worker(
105 | dataset_name="rel-avito",
106 | task_name="user-ad-visit",
107 | gnn="GraphSAGE",
108 | epochs=20,
109 | lr=0.001,
110 | batch_size=512,
111 | channels=128,
112 | temporal_strategy="last",
113 | )
114 |
115 | # Run with specific graph configurations
116 | results = run_idgnn_link_worker(
117 | dataset_name="rel-avito",
118 | task_name="user-ad-visit",
119 | target_indices=[0, 5, 10, 15], # Run only these configurations
120 | epochs=30,
121 | patience=10
122 | )
123 |
124 |
125 | Parallel Processing
126 | ~~~~~~~~~~~~~~~~~~~
127 |
128 | .. code-block:: python
129 |
130 | from rdb2g_bench.dataset.node_worker import run_gnn_node_worker
131 | from concurrent.futures import ProcessPoolExecutor
132 |
133 | def run_worker(worker_id, total_workers):
134 | """Run a single worker process."""
135 | return run_gnn_node_worker(
136 | dataset_name="rel-f1",
137 | task_name="driver-top3",
138 | idx=worker_id,
139 | workers=total_workers,
140 | epochs=20,
141 | save_csv=True
142 | )
143 |
144 | # Run multiple workers in parallel
145 | num_workers = 4
146 | with ProcessPoolExecutor(max_workers=num_workers) as executor:
147 | futures = [
148 | executor.submit(run_worker, i, num_workers)
149 | for i in range(num_workers)
150 | ]
151 |
152 | # Collect results
153 | all_results = [future.result() for future in futures]
154 | total_processed = sum(r['total_processed'] for r in all_results)
155 | print(f"Total graphs processed: {total_processed}")
156 |
--------------------------------------------------------------------------------
/rdb2g_bench/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
def integrate_edge_tf(batch, edge_tf_dict):
    """Attach edge-table features to ``batch`` and re-index the ids that reference them.

    Row-to-edge conversions appear in ``batch`` as edge types of two kinds:

    * Relation names starting with ``'r2e'``: the table name follows a
      4-character prefix, e.g. ``'r2e_<table>'``.
    * Relation names starting with ``'rev_r2e'``: the table name follows an
      8-character prefix, e.g. ``'rev_r2e_<table>'``.

    Their edge stores may carry a ``mapped_node_ids`` tensor of row ids into
    ``edge_tf_dict[<table>]``.

    Steps:
      1. Every entry of ``batch`` keyed by a table name in ``edge_tf_dict`` is
         deleted.
      2. For each such table referenced by at least one non-empty
         ``mapped_node_ids``, the sorted union of referenced ids is gathered
         into ``batch[table]['tf'] = edge_tf_dict[table][ids]``.
      3. Every non-empty ``mapped_node_ids`` is rewritten to positions within
         that gathered ``tf`` (a ``torch.long`` tensor).

    Args:
        batch: Heterogeneous mini-batch exposing ``edge_types`` plus item
            get/delete access. Presumably a PyG ``HeteroData`` -- TODO confirm.
        edge_tf_dict: Mapping from table name to a feature container that can
            be indexed with a ``torch.long`` index tensor.

    Returns:
        ``batch`` itself, modified in place.
    """
    # Snapshot edge types before any deletion and group them by source table.
    # All r2e types are grouped first, then rev_r2e. This keeps the table
    # processing order stable, and with it the order in which
    # ``batch[table]`` entries are re-created.
    all_edge_types = list(batch.edge_types)
    table_to_edge_types = {}
    for kind in ('r2e', 'rev_r2e'):
        prefix_len = len(kind) + 1  # kind plus its separator character
        for edge_type in all_edge_types:
            rel = edge_type[1]
            if rel.startswith(kind):
                groups = table_to_edge_types.setdefault(
                    rel[prefix_len:], {'r2e': [], 'rev_r2e': []})
                groups[kind].append(edge_type)

    for table_name in edge_tf_dict:
        del batch[table_name]

    for table_name, groups in table_to_edge_types.items():
        if table_name not in edge_tf_dict:
            continue

        # Non-empty id arrays, keyed by the edge type they came from.
        ids_by_edge_type = {}
        for edge_type in groups['r2e'] + groups['rev_r2e']:
            store = batch[edge_type]
            if 'mapped_node_ids' in store and len(store['mapped_node_ids']) > 0:
                ids_by_edge_type[edge_type] = store['mapped_node_ids'].cpu().numpy()

        if not ids_by_edge_type:
            continue

        # Sorted unique ids select the rows for this batch; each id's position
        # in this array becomes its new, batch-local index.
        unique_ids = np.unique(np.concatenate(list(ids_by_edge_type.values())))
        batch[table_name]['tf'] = edge_tf_dict[table_name][
            torch.as_tensor(unique_ids, dtype=torch.long)]

        # searchsorted on the sorted unique ids yields each id's position,
        # replacing a per-element Python loop over a lookup array.
        for edge_type, old_ids in ids_by_edge_type.items():
            new_ids = np.searchsorted(unique_ids, old_ids)
            batch[edge_type]['mapped_node_ids'] = torch.as_tensor(new_ids, dtype=torch.long)

    return batch
79 |
def divide_node_edge_dict(batch, x_dict):
    """Move encoded R2E table features out of ``x_dict`` onto their edge types.

    Edge types whose relation starts with ``r2e`` / ``rev_r2e`` are grouped by
    the table name following the prefix (``rel[4:]`` / ``rel[8:]``). For each
    such table present in ``x_dict``, every grouped edge type receives
    ``x_dict[table][mapped_node_ids]`` as its edge features. The table is then
    removed from ``x_dict``.

    Args:
        batch: Heterogeneous batch exposing ``edge_types`` and per-edge-type
            stores that hold ``mapped_node_ids``.
        x_dict: Mapping from node type to its encoded features. It is mutated in
            place: processed R2E tables are deleted from it.

    Returns:
        ``(x_dict, edge_dict)``, where ``edge_dict`` maps each R2E edge type to
        its gathered features.

    Raises:
        KeyError: If an R2E edge type of a table in ``x_dict`` has no
            ``mapped_node_ids``. This is an explicit check so it is not
            stripped under ``python -O`` like an ``assert``.
    """
    table_to_edge_types = {}
    for edge_type in batch.edge_types:
        _, rel, _ = edge_type
        if rel.startswith('r2e'):
            kind, table_name = 'r2e', rel[4:]
        elif rel.startswith('rev_r2e'):
            kind, table_name = 'rev_r2e', rel[8:]
        else:
            continue
        groups = table_to_edge_types.setdefault(table_name, {'r2e': [], 'rev_r2e': []})
        groups[kind].append(edge_type)

    edge_dict = {}
    tables_to_remove = []
    for table_name, groups in table_to_edge_types.items():
        if table_name not in x_dict:
            continue
        table_features = x_dict[table_name]
        for edge_type in groups['r2e'] + groups['rev_r2e']:
            store = batch[edge_type]
            if 'mapped_node_ids' not in store:
                raise KeyError(
                    f"Edge type {edge_type!r} is missing 'mapped_node_ids', which is "
                    f"required to gather features of table {table_name!r}."
                )
            edge_dict[edge_type] = table_features[store['mapped_node_ids']]
        tables_to_remove.append(table_name)

    # Delete only after every table succeeded, so a failure leaves x_dict intact.
    for table_name in tables_to_remove:
        del x_dict[table_name]

    return x_dict, edge_dict
--------------------------------------------------------------------------------
/docs/source/030/micro_action.rst:
--------------------------------------------------------------------------------
1 | Micro Actions
2 | =============
3 |
4 | This module defines the core micro actions used by all optimization algorithms for RDB-to-Graph modeling search.
5 | These atomic operations enable systematic exploration of the graph model space by transforming
6 | one valid graph model into another.
7 |
8 | How it Works
9 | ------------
10 |
11 | Micro actions represent atomic operations for transforming graph models:
12 |
13 | 1. **add_fk_pk_edge**: Add foreign key-primary key edge between tables
14 | 2. **remove_fk_pk_edge**: Remove foreign key-primary key edge between tables
15 | 3. **convert_row_to_edge**: Convert row representation to edge representation
16 | 4. **convert_edge_to_row**: Convert edge representation to row representation
17 |
18 | Each action transforms the current graph model (edge set) into another valid graph model, enabling systematic exploration of the graph model space.
19 |
20 | Micro Action Set
21 | ----------------
22 |
23 | .. automodule:: rdb2g_bench.benchmark.micro_action
24 | :members:
25 | :undoc-members:
26 | :show-inheritance:
27 |
28 | Example Usage
29 | ~~~~~~~~~~~~~
30 |
31 | Example with Data Preparation
32 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
33 |
34 | .. code-block:: python
35 |
36 | import os
37 | import json
38 | from pathlib import Path
   | import torch
39 | from torch_frame import stype
40 | from torch_frame.config.text_embedder import TextEmbedderConfig
41 | from relbench.datasets import get_dataset
42 | from relbench.tasks import get_task
43 | from relbench.modeling.graph import make_pkey_fkey_graph
44 | from relbench.modeling.utils import get_stype_proposal
45 |
46 | from rdb2g_bench.benchmark.micro_action import MicroActionSet
47 | from rdb2g_bench.common.search_space.gnn_search_space import GNNNodeSearchSpace
48 | from rdb2g_bench.common.text_embedder import GloveTextEmbedding
49 |
50 | # Step 1: Load dataset and task
51 | dataset_name = "rel-f1"
52 | task_name = "driver-top3"
53 |
54 | dataset = get_dataset(dataset_name, download=True)
55 | task = get_task(dataset_name, task_name, download=True)
56 |
57 | # Step 2: Prepare column type information
58 | cache_dir = os.path.expanduser("~/.cache/relbench_examples")
59 | stypes_cache_path = Path(f"{cache_dir}/{dataset_name}/stypes.json")
60 |
61 | try:
62 | with open(stypes_cache_path, "r") as f:
63 | col_to_stype_dict = json.load(f)
64 | for table, col_to_stype in col_to_stype_dict.items():
65 | for col, stype_str in col_to_stype.items():
66 | col_to_stype[col] = stype(stype_str)
67 | except FileNotFoundError:
68 | col_to_stype_dict = get_stype_proposal(dataset.get_db())
69 | Path(stypes_cache_path).parent.mkdir(parents=True, exist_ok=True)
70 | with open(stypes_cache_path, "w") as f:
71 | json.dump(col_to_stype_dict, f, indent=2, default=str)
72 |
73 | # Step 3: Create heterogeneous graph data
74 | device = "cuda" if torch.cuda.is_available() else "cpu"
75 | hetero_data, col_stats_dict = make_pkey_fkey_graph(
76 | dataset.get_db(),
77 | col_to_stype_dict=col_to_stype_dict,
78 | text_embedder_cfg=TextEmbedderConfig(
79 | text_embedder=GloveTextEmbedding(device=device),
80 | batch_size=256
81 | ),
82 | cache_dir=f"{cache_dir}/{dataset_name}/materialized",
83 | )
84 |
85 | # Step 4: Initialize micro action set
86 | micro_actions = MicroActionSet(
87 | dataset=dataset_name,
88 | task=task_name,
89 | hetero_data=hetero_data,
90 | GNNSpaceClass=GNNNodeSearchSpace,
91 | num_layers=2,
92 | src_entity_table=task.entity_table
93 | )
94 |
95 | print(f"Total number of valid graph models: {len(micro_actions.valid_edge_sets_list)}")
96 | print(f"Number of FK-PK edges: {len(micro_actions.fk_pk_indices)}")
97 | print(f"Number of R2E edges: {len(micro_actions.r2e_indices)}")
98 |
99 | Basic Micro Action Operations
100 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
101 |
102 | .. code-block:: python
103 |
104 | # Get the first valid edge set as starting point
105 | current_edge_set = micro_actions.valid_edge_sets_list[0]
106 | print(f"Starting edge set: {current_edge_set}")
107 |
108 | # Explore all possible FK-PK edge additions
109 | add_fk_actions = micro_actions.add_fk_pk_edge(current_edge_set)
110 | print(f"Possible FK-PK additions: {len(add_fk_actions)}")
111 |
112 | for new_set, index in add_fk_actions[:3]: # Show first 3
113 | print(f" Add action: {current_edge_set} -> {new_set} (index: {index})")
114 |
115 | # Explore FK-PK edge removals
116 | remove_fk_actions = micro_actions.remove_fk_pk_edge(current_edge_set)
117 | print(f"Possible FK-PK removals: {len(remove_fk_actions)}")
118 |
119 | # Explore row-to-edge conversions
120 | row_to_edge_actions = micro_actions.convert_row_to_edge(current_edge_set)
121 | print(f"Possible row-to-edge conversions: {len(row_to_edge_actions)}")
122 |
123 | # Explore edge-to-row conversions
124 | edge_to_row_actions = micro_actions.convert_edge_to_row(current_edge_set)
125 | print(f"Possible edge-to-row conversions: {len(edge_to_row_actions)}")
126 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ----
6 |
7 | [](https://github.com/chlehdwon/RDB2G-Bench/releases)
8 | [](https://rdb2g-bench.readthedocs.io/en/latest/)
9 | [](https://huggingface.co/datasets/kaistdata/RDB2G-Bench)
10 | [](https://arxiv.org/abs/2506.01360)
11 | [](https://opensource.org/licenses/MIT)
12 |
13 | This is the official implementation of the paper **RDB2G-Bench: A Comprehensive Benchmark for Automatic Graph Modeling of Relational Databases.**
14 |
15 | **RDB2G-Bench** is an **easy-to-use framework** for benchmarking graph-based analysis and prediction tasks by converting relational database data into graphs.
16 |
17 | ## 🚀 Installation
18 |
19 | ```bash
20 | git clone https://github.com/chlehdwon/RDB2G-Bench.git
21 | cd RDB2G-Bench
22 | pip install -e .
23 | ```
24 |
25 | Also, please install the additional PyG dependencies. The command below is an example for torch 2.1.0 + CUDA 12.1.
26 |
27 | ```bash
28 | pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
29 | ```
30 | You can skip this step if you do not plan to reproduce our datasets.
31 |
32 | ## ⚡ Package Usage
33 |
34 | Comprehensive documentation and detailed guides are available at [our documentation site](https://rdb2g-bench.readthedocs.io/en/latest/).
35 |
36 | You can also check the `examples/` directory for complete usage examples and tutorials.
37 |
38 | ### Download Pre-computed Datasets
39 |
40 | ```python
41 | from rdb2g_bench.dataset.dataset import load_rdb2g_bench
42 |
43 | bench = load_rdb2g_bench("./results")
44 |
45 | result = bench['rel-f1']['driver-top3'][0] # Access by graph index
46 | test_metric = result['test_metric'] # Test performance
47 | params = result['params'] # Model parameters
48 | train_time = result['train_time'] # Train time
49 | ```
50 |
51 | ### Reproduce Datasets for Classification & Regression Tasks
52 |
53 | ```python
54 | from rdb2g_bench.dataset.node_worker import run_gnn_node_worker
55 |
56 | results = run_gnn_node_worker(
57 | dataset_name="rel-f1",
58 | task_name="driver-top3",
59 | gnn="GraphSAGE",
60 | epochs=20,
61 | lr=0.005
62 | )
63 | ```
64 |
65 | ### Reproduce Datasets for Recommendation Tasks
66 |
67 | ```python
68 | from rdb2g_bench.dataset.link_worker import run_idgnn_link_worker
69 |
70 | results = run_idgnn_link_worker(
71 | dataset_name="rel-avito",
72 | task_name="user-ad-visit",
73 | gnn="GraphSAGE",
74 | epochs=20,
75 | lr=0.001
76 | )
77 | ```
78 |
79 | ### Run Benchmarks
80 |
81 | ```python
82 | from rdb2g_bench.benchmark.bench_runner import run_benchmark
83 |
84 | results = run_benchmark(
85 | dataset="rel-f1",
86 | task="driver-top3",
87 | gnn="GraphSAGE",
88 | budget_percentage=0.05,
89 | method="all",
90 | num_runs=10,
91 | seed=0
92 | )
93 | ```
94 |
95 | ### Run LLM-based baseline
96 |
97 | Before using LLM-based baseline, you need to set up your API key:
98 |
99 | ```bash
100 | export ANTHROPIC_API_KEY="YOUR_API_KEY"
101 | ```
102 |
103 | ```python
104 | from rdb2g_bench.benchmark.llm.llm_runner import run_llm_baseline
105 |
106 | results = run_llm_baseline(
107 | dataset="rel-f1",
108 | task="driver-top3",
109 | gnn="GraphSAGE",
110 | budget_percentage=0.05,
111 | model="claude-3-5-sonnet-latest",
112 | temperature=0.8,
113 | seed=42
114 | )
115 | ```
116 |
117 | ## 📁 Package Structure
118 |
119 | ```
120 | rdb2g_bench/
121 | ├── benchmark/ # Core benchmarking functionality
122 | │ ├── llm/ # LLM-based baseline methods
123 | │ └── baselines/ # Other baseline methods
124 | ├── common/ # Shared utilities and search spaces
125 | ├── dataset/ # Dataset loading and processing
126 | └── __init__.py # Package initialization
127 | ```
128 |
129 | ## 📖 Reference
130 |
131 | The dataset construction and implementation of RDB2G-Bench are based on the [RelBench](https://github.com/snap-stanford/relbench) framework.
132 |
133 | ## 📝 Citation
134 |
135 | If you use RDB2G-Bench in your research, please cite:
136 |
137 | ```bibtex
138 | @inproceedings{choi2025rdb2gbench,
139 | title={RDB2G-Bench: A Comprehensive Benchmark for Automatic Graph Modeling of Relational Databases},
140 | author={Dongwon Choi and Sunwoo Kim and Juyeon Kim and Kyungho Kim and Geon Lee and Shinhwan Kang and Myunghwan Kim and Kijung Shin},
141 | year={2025},
142 | booktitle={NeurIPS},
143 | }
144 | ```
145 | or
146 | ```bibtex
147 | @article{choi2025rdb2gbench,
148 | title={RDB2G-Bench: A Comprehensive Benchmark for Automatic Graph Modeling of Relational Databases},
149 | author={Dongwon Choi and Sunwoo Kim and Juyeon Kim and Kyungho Kim and Geon Lee and Shinhwan Kang and Myunghwan Kim and Kijung Shin},
150 | year={2025},
151 | url={https://arxiv.org/abs/2506.01360},
152 | }
153 | ```
154 |
155 | ## ⚖️ License
156 |
157 | This project is distributed under the MIT License as specified in the LICENSE file.
158 |
159 |
160 |
--------------------------------------------------------------------------------
/rdb2g_bench/dataset/models/modules/sage_conv_edge.py:
--------------------------------------------------------------------------------
1 | # Custom implementation of SAGEConv with edge features based on torch_geometric.nn.conv.SAGEConv
2 |
3 | from typing import List, Optional, Tuple, Union
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import Tensor
8 |
9 | from torch_geometric.nn.aggr import Aggregation, MultiAggregation
10 | from torch_geometric.nn.conv import MessagePassing
11 | from torch_geometric.nn.dense.linear import Linear
12 | from torch_geometric.typing import Adj, OptPairTensor, Size, SparseTensor, PairTensor
13 | from torch_geometric.utils import spmm
14 |
15 |
class SAGEConvEdge(MessagePassing):
    r"""The GraphSAGE operator with edge-feature support.

    Each message is the concatenation of the source node features and the
    edge features, :math:`[\mathbf{x}_j \,\|\, \mathbf{e}_{ji}]`. Messages are
    aggregated, projected by :obj:`lin_l`, and, if :obj:`root_weight` is set,
    summed with a projection of the target node features.

    When no edge features are given, zeros shaped like :math:`\mathbf{x}_j`
    are used instead. :obj:`lin_l` is sized for messages of width
    ``2 * in_channels[0]``, so edge features must have the same width as the
    source node features.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        aggr (str or [str] or Aggregation, optional): The aggregation scheme
            applied to messages. (default: :obj:`"mean"`)
        normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized. (default: :obj:`False`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output.
            (default: :obj:`True`)
        project (bool, optional): If set to :obj:`True`, source node features
            are first passed through a linear layer followed by ReLU. Requires
            a positive :obj:`in_channels`. (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, :obj:`lin_l` will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        aggr: Optional[Union[str, List[str], Aggregation]] = "mean",
        normalize: bool = False,
        root_weight: bool = True,
        project: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight
        self.project = project

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        super().__init__(aggr, **kwargs)

        if self.project:
            if in_channels[0] <= 0:
                raise ValueError(f"'{self.__class__.__name__}' does not "
                                 f"support lazy initialization with "
                                 f"`project=True`")
            self.lin = Linear(in_channels[0], in_channels[0], bias=True)

        # Every message is [x_j || edge_attr] (see `message`), i.e. twice the
        # source node width.
        msg_channels = in_channels[0] * 2
        if isinstance(self.aggr_module, MultiAggregation):
            # The multi-aggregation operates on full messages, so its output
            # size must be derived from the message width, not from the node
            # width alone.
            aggr_out_channels = self.aggr_module.get_out_channels(msg_channels)
        else:
            aggr_out_channels = msg_channels

        # Projection of the aggregated (node + edge) messages.
        self.lin_l = Linear(aggr_out_channels, out_channels, bias=bias)
        if self.root_weight:
            # Projection of the target (root) node features.
            self.lin_r = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the aggregation module and all linear layers."""
        super().reset_parameters()
        if self.project:
            self.lin.reset_parameters()
        self.lin_l.reset_parameters()
        if self.root_weight:
            self.lin_r.reset_parameters()

    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: Optional[Tensor] = None,
        size: Size = None,
    ) -> Tensor:
        """Run one message-passing step.

        Args:
            x: Node features, or a ``(source, target)`` pair of node features.
            edge_index: Graph connectivity.
            edge_attr: Optional edge features; zeros are used when omitted.
            size: Optional ``(num_src, num_dst)`` size passed to ``propagate``.

        Returns:
            Updated target node features of width ``out_channels``.
        """
        if isinstance(x, Tensor):
            x = (x, x)

        if self.project and hasattr(self, 'lin'):
            x = (self.lin(x[0]).relu(), x[1])

        # propagate_type: (x: PairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
        out = self.lin_l(out)

        x_r = x[1]
        if self.root_weight and x_r is not None:
            out += self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(
        self,
        x_j: Tensor,
        edge_attr: Optional[Tensor] = None,
    ) -> Tensor:
        """Build each message as ``[x_j || edge_attr]``.

        Zeros stand in for missing edge features.
        """
        if edge_attr is None:
            edge_attr = torch.zeros_like(x_j)
        return torch.cat([x_j, edge_attr], dim=-1)

    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:
        # NOTE(review): this fused sparse path aggregates x[0] only. edge_attr
        # is not available here, and the result has in_channels[0] features,
        # while lin_l expects the concatenated message width. It appears to be
        # reachable only with sparse adjacency inputs; confirm before relying
        # on SparseTensor edge indices with this layer.
        if isinstance(adj_t, SparseTensor):
            adj_t = adj_t.set_value(None, layout=None)
        return spmm(adj_t, x[0], reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, aggr={self.aggr})')
137 |
--------------------------------------------------------------------------------
/rdb2g_bench/dataset/models/model.py:
--------------------------------------------------------------------------------
1 | # Reference: https://github.com/snap-stanford/relbench/blob/main/examples/model.py
2 |
3 | from typing import Any, Dict, List
4 |
5 | import torch
6 | from torch import Tensor
7 | from torch.nn import Embedding, ModuleDict
8 | from torch_frame.data.stats import StatType
9 | from torch_geometric.data import HeteroData
10 | from torch_geometric.nn import MLP
11 | from torch_geometric.typing import NodeType
12 |
13 | from relbench.modeling.nn import HeteroEncoder, HeteroTemporalEncoder, HeteroGraphSAGE
14 | from .modules.nn import HeteroGraphSAGE, HeteroGIN, HeteroGPS
15 | from ..utils import divide_node_edge_dict
16 |
class Model(torch.nn.Module):
    """Heterogeneous GNN for node-level and ID-aware link-level prediction.

    Pipeline: per-table ``HeteroEncoder`` → optional relative-time encoding →
    optional shallow embeddings → split of R2E table features onto edges
    (``divide_node_edge_dict``) → heterogeneous GNN → MLP head.

    Args:
        data: Graph whose node stores provide ``tf`` (and optionally ``time``).
        col_stats_dict: Per-table column statistics forwarded to ``HeteroEncoder``.
        num_layers: Number of message-passing layers (forwarded to GraphSAGE).
        channels: Hidden dimension.
        out_channels: Output dimension of the prediction head.
        aggr: Neighbor aggregation (forwarded to GraphSAGE).
        norm: Normalization used in the prediction-head ``MLP``.
        gnn: One of ``"GraphSAGE"``, ``"GIN"``, ``"GPS"``.
        shallow_list: Node types that receive a learnable per-node embedding.
        id_awareness: If ``True``, learn an embedding that
            :meth:`forward_dst_readout` adds to the seed rows of the entity table.

    Raises:
        ValueError: If ``gnn`` is not a supported type.
    """

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        gnn: str = 'GraphSAGE',
        # List of node types to add shallow embeddings to input. It is only
        # read, never mutated, so the shared default list is safe.
        shallow_list: List[NodeType] = [],
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )

        # Create temporal encoder if time attributes exist
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )

        # Initialize GNN based on specified type.
        # NOTE(review): `num_layers` and `aggr` are only forwarded to GraphSAGE;
        # confirm HeteroGIN / HeteroGPS are meant to use their own defaults.
        if gnn == "GraphSAGE":
            self.gnn = HeteroGraphSAGE(
                node_types=data.node_types,
                edge_types=data.edge_types,
                channels=channels,
                aggr=aggr,
                num_layers=num_layers,
            )
        elif gnn == "GIN":
            self.gnn = HeteroGIN(
                node_types=data.node_types,
                edge_types=data.edge_types,
                channels=channels,
            )
        elif gnn == "GPS":
            self.gnn = HeteroGPS(
                node_types=data.node_types,
                edge_types=data.edge_types,
                channels=channels,
            )
        else:
            raise ValueError(f"Unknown GNN type: {gnn}")

        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )

        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )
        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all submodules and embeddings."""
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    @staticmethod
    def _num_seed_nodes(batch: HeteroData, entity_table: NodeType, seed_time):
        """Return how many leading rows of ``entity_table`` are seed nodes."""
        if seed_time is not None:
            return seed_time.size(0)
        return batch[entity_table]['batch_size']

    def _encode_inputs(self, batch: HeteroData, entity_table: NodeType, add_id_awareness: bool):
        """Encode node features and add ID-awareness, temporal and shallow terms.

        Returns:
            ``(x_dict, seed_time)``. ``x_dict`` is taken before R2E features are
            split onto edges. ``seed_time`` is ``None`` when the entity table
            has no ``seed_time``.
        """
        seed_time = batch[entity_table].seed_time if hasattr(batch[entity_table], "seed_time") else None
        x_dict = self.encoder(batch.tf_dict)

        if add_id_awareness:
            # Mark the seed (root) rows, whether or not the task is temporal.
            num_seeds = self._num_seed_nodes(batch, entity_table, seed_time)
            x_dict[entity_table][:num_seeds] += self.id_awareness_emb.weight

        if seed_time is not None:
            rel_time_dict = self.temporal_encoder(
                seed_time, batch.time_dict, batch.batch_dict
            )
            for node_type, rel_time in rel_time_dict.items():
                x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        return x_dict, seed_time

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        """Predict for the seed rows of ``entity_table``."""
        x_dict, seed_time = self._encode_inputs(batch, entity_table, add_id_awareness=False)

        x_dict, edge_dict = divide_node_edge_dict(batch, x_dict)

        x_dict = self.gnn(
            x_dict,
            edge_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        num_seeds = self._num_seed_nodes(batch, entity_table, seed_time)
        return self.head(x_dict[entity_table][:num_seeds])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        """ID-aware forward pass that reads out every row of ``dst_table``.

        Raises:
            RuntimeError: If the model was built with ``id_awareness=False``.
        """
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        x_dict, _ = self._encode_inputs(batch, entity_table, add_id_awareness=True)

        x_dict, edge_dict = divide_node_edge_dict(batch, x_dict)

        x_dict = self.gnn(
            x_dict,
            edge_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])
--------------------------------------------------------------------------------
/rdb2g_bench/benchmark/bench_runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, List, Union, Any
3 |
4 | from .benchmark import main as benchmark_main
5 |
6 |
def run_benchmark(
    dataset: str = "rel-f1",
    task: str = "driver-top3",
    gnn: str = "GraphSAGE",
    budget_percentage: float = 0.05,
    method: Union[str, List[str]] = "all",
    num_runs: int = 10,
    seed: int = 0,
    **kwargs
) -> Dict[str, Any]:
    """
    Execute comprehensive benchmark analysis on specified dataset and task.

    This function provides a high-level interface for running graph-model search
    benchmarks on RDB2G-Bench. It supports multiple optimization methods and automatically
    handles data preparation, model training, evaluation, and results aggregation across
    multiple runs for statistical robustness.

    The benchmark process includes:

    1. Dataset and task preparation with proper caching
    2. Search space initialization with micro actions
    3. Performance prediction dataset setup
    4. Multiple independent runs with different random seeds
    5. Statistical analysis and results aggregation
    6. Trajectory analysis and CSV export for visualization

    Args:
        dataset (str): Name of the RelBench dataset to benchmark on.
            Available datasets include "rel-f1", "rel-avito", "rel-amazon", etc.
            Defaults to "rel-f1".
        task (str): Name of the RelBench task to evaluate.
            Task names depend on the dataset (e.g., "driver-top3", "user-ad-visit").
            Defaults to "driver-top3".
        gnn (str): Name of the GNN model to benchmark (e.g., "GraphSAGE", "GIN", "GPS").
            Defaults to "GraphSAGE".
        budget_percentage (float): Budget percentage for search algorithms as fraction
            of total search space (0.0-1.0). Higher values allow more thorough search
            but increase computational cost. Defaults to 0.05 (5%).
        method (Union[str, List[str]]): Search method(s) to benchmark.
            Options: "all", "ea", "greedy", "rl", "bo", "random", or list of methods.
            "all" runs all available methods. A single string is wrapped in a list.
            Defaults to "all".
        num_runs (int): Number of independent runs for statistical analysis.
            More runs provide better statistical confidence but increase runtime.
            Defaults to 10.
        seed (int): Base random seed for reproducibility. Each run uses seed + run_index.
            Defaults to 0.
        **kwargs: Additional configuration parameters, which override the defaults below
            and are passed through as attributes of the arguments object:

            - tag (str): Experiment tag for result organization. Defaults to "hf".
            - cache_dir (str): Directory for caching datasets and models.
              Defaults to "~/.cache/relbench_examples".
            - result_dir (str): Root directory for saving results and trajectories.
              Defaults to "./results".

    Returns:
        Dict[str, Any]: Dictionary containing comprehensive benchmark results.

        - avg_actual_y_perf_of_selected (float): Average performance across runs
        - avg_rank_position_overall (float): Average ranking position
        - avg_percentile_overall (float): Average percentile ranking
        - selected_graph_ids_runs (List[int]): Graph IDs selected in each run
        - avg_evaluation_time (float): Average time spent on evaluations
        - avg_run_time (float): Average total runtime per run
        - method_wise_statistics (Dict): Detailed statistics for each method
        - performance_trajectories (Dict): Performance over time for each method
        - statistical_comparisons (Dict): Rankings and comparisons between methods

    Example:
        >>> # Run all methods on default dataset/task with specific GNN
        >>> results = run_benchmark(
        ...     dataset="rel-f1",
        ...     task="driver-top3",
        ...     gnn="GraphSAGE",
        ...     budget_percentage=0.05,
        ...     method="all",
        ...     num_runs=10,
        ...     seed=42
        ... )
        >>>
        >>> # Print aggregated results
        >>> for method, stats in results.items():
        ...     if 'avg_actual_y_perf_of_selected' in stats:
        ...         print(f"{method}: {stats['avg_actual_y_perf_of_selected']:.4f}")

        >>> # Run specific methods with custom configuration
        >>> results = run_benchmark(
        ...     dataset="rel-avito",
        ...     task="user-ad-visit",
        ...     gnn="GIN",
        ...     budget_percentage=0.10,
        ...     method=["ea", "greedy", "rl"],
        ...     num_runs=5,
        ...     tag="custom_experiment",
        ...     cache_dir="/custom/cache",
        ...     result_dir="/custom/results"
        ... )

        >>> # Run all methods with the GPS backbone on the default dataset/task
        >>> results = run_benchmark(
        ...     gnn="GPS",
        ...     method="all",
        ...     num_runs=10,
        ...     budget_percentage=0.05
        ... )

    Note:
        - Results are automatically saved to CSV files for further analysis
        - Performance trajectories are exported for visualization
        - All intermediate data is cached to speed up repeated experiments
        - Large datasets may require significant memory and storage
    """

    # Accept a single method name as shorthand for a one-element list.
    if isinstance(method, str):
        method = [method]

    default_params = {
        "tag": "hf",
        "cache_dir": "~/.cache/relbench_examples",
        "result_dir": "./results",
    }

    # User-supplied kwargs take precedence over the defaults.
    params = {**default_params, **kwargs}

    class Args:
        """Simple attribute bag: every keyword argument becomes an attribute."""

        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

    args = Args(
        dataset=dataset,
        task=task,
        gnn=gnn,
        budget_percentage=budget_percentage,
        method=method,
        num_runs=num_runs,
        seed=seed,
        **params
    )

    results = benchmark_main(args)

    # NOTE(review): the Returns section lists flat keys, but the example treats
    # the result as {method_name: stats}; confirm the layout against benchmark.main.
    return results
--------------------------------------------------------------------------------
/rdb2g_bench/benchmark/baselines/random.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import time
5 | from typing import Dict, Union, List, Optional
6 |
7 | from ..dataset import PerformancePredictionDataset
8 | from ..micro_action import MicroActionSet
9 | from .utils import calculate_overall_rank, get_performance_for_index, update_trajectory_and_best, pad_trajectory, calculate_evaluation_time
10 |
11 | def random_heuristic_analysis(
12 | dataset: PerformancePredictionDataset,
13 | micro_action_set: MicroActionSet,
14 | overall_actual_y: torch.Tensor,
15 | higher_is_better: bool,
16 | termination_threshold_ratio: float,
17 | method_name: str = "Random Heuristic",
18 | ):
19 | """
20 | Perform Neural Architecture Search using Random Sampling Strategy.
21 |
22 | This function implements pure random search that uniformly samples graph neural
23 | network architectures from the entire valid search space. Each architecture is
24 | evaluated independently without any guidance from previous evaluations, providing
25 | an unbiased baseline for comparison with other optimization methods.
26 |
27 |
28 | Args:
29 | dataset (PerformancePredictionDataset): Dataset containing architecture
30 | performance data
31 | micro_action_set (MicroActionSet): Set of micro actions defining the
32 | architecture search space
33 | overall_actual_y (torch.Tensor): Complete performance tensor for
34 | ranking calculations
35 | higher_is_better (bool): Whether higher performance values are better
36 | termination_threshold_ratio (float): Fraction of total architectures to
37 | evaluate as budget
38 | method_name (str): Name identifier for this method.
39 | Defaults to "Random Heuristic".
40 |
41 | Returns:
42 | Dict[str, Union[str, int, float, List, Optional[int]]]: Dictionary containing search results and performance metrics.
43 |
44 | - method (str): Method name
45 | - selected_graph_id (Optional[int]): Index of best found architecture
46 | - actual_y_perf_of_selected (float): Performance of selected architecture
47 | - selection_metric_value (float): Metric value used for selection
48 | - selected_graph_origin (str): Origin method name
49 | - discovered_count (int): Number of architectures evaluated
50 | - total_iterations_run (int): Number of random samples drawn
51 | - rank_position_overall (float): Rank among all architectures
52 | - percentile_overall (float): Percentile ranking
53 | - total_samples_overall (int): Total available architectures
54 | - performance_trajectory (List): Performance over time
55 | - total_evaluation_time (float): Time spent on evaluations
56 | - total_run_time (float): Total algorithm runtime
57 |
58 | Example:
59 | >>> results = random_heuristic_analysis(
60 | ... dataset=dataset,
61 | ... micro_action_set=micro_actions,
62 | ... overall_actual_y=y_tensor,
63 | ... higher_is_better=True,
64 | ... termination_threshold_ratio=0.05
65 | ... )
66 | >>> print(f"Best architecture: {results['selected_graph_id']}")
67 | >>> print(f"Performance: {results['actual_y_perf_of_selected']:.4f}")
68 | >>> print(f"Evaluated: {results['discovered_count']} architectures")
69 | """
70 | performance_cache = {}
71 | time_cache = {}
72 | performance_trajectory = []
73 | total_evaluated_count = 0
74 | best_perf_so_far = float('-inf') if higher_is_better else float('inf')
75 | best_index_so_far = -1
76 | method_origin = "Random Heuristic"
77 | total_samples_overall = overall_actual_y.numel() if overall_actual_y is not None else 0
78 | total_evaluation_time = 0.0
79 |
80 | start_time = time.time()
81 |
82 | num_total_valid_graphs = len(micro_action_set.valid_edge_sets_list)
83 | if num_total_valid_graphs == 0:
84 | print(f"Error: No valid graphs available in MicroActionSet. Cannot proceed.")
85 | return {
86 | "method": method_name, "selected_graph_id": None, "actual_y_perf_of_selected": np.nan,
87 | "selection_metric_value": np.nan, "selected_graph_origin": method_origin,
88 | "discovered_count": 0, "total_iterations_run": 0,
89 | "rank_position_overall": np.nan, "percentile_overall": np.nan,
90 | "total_samples_overall": total_samples_overall, "performance_trajectory": [],
91 | "total_evaluation_time": 0.0, "total_run_time": 0.0
92 | }
93 |
94 | evaluation_budget = max(1, int(termination_threshold_ratio * num_total_valid_graphs))
95 | print(f"{method_name}: Total valid graphs: {num_total_valid_graphs}. Budget set to {evaluation_budget} unique evaluations.")
96 |
97 | all_valid_indices = list(range(num_total_valid_graphs))
98 | shuffled_indices = random.sample(all_valid_indices, k=num_total_valid_graphs)
99 |
100 | for index in shuffled_indices:
101 | if total_evaluated_count >= evaluation_budget:
102 | print(f"{method_name}: Budget reached ({total_evaluated_count}/{evaluation_budget}).")
103 | break
104 |
105 | initial_cache_size = len(performance_cache)
106 | perf = get_performance_for_index(index, dataset, performance_cache)
107 |
108 | if perf is not None:
109 | performance_cache[index] = perf
110 |
111 | eval_time = calculate_evaluation_time(index, dataset, time_cache)
112 | if eval_time is not None:
113 | if len(performance_cache) > initial_cache_size:
114 | total_evaluation_time += eval_time
115 |
116 | total_evaluated_count, best_perf_so_far, best_index_so_far = \
117 | update_trajectory_and_best(
118 | index, perf, performance_cache, initial_cache_size,
119 | total_evaluated_count, performance_trajectory,
120 | best_perf_so_far, best_index_so_far, higher_is_better
121 | )
122 |
123 | print(f"{method_name}: Finished evaluating {total_evaluated_count} graphs.")
124 |
125 | total_run_time = time.time() - start_time
126 |
127 | pad_trajectory(performance_trajectory, total_evaluated_count, evaluation_budget, method_name)
128 |
129 | final_selected_perf = best_perf_so_far if best_index_so_far != -1 else np.nan
130 | final_selected_index = best_index_so_far if best_index_so_far != -1 else None
131 |
132 | results = {
133 | "method": method_name,
134 | "selected_graph_id": final_selected_index,
135 | "actual_y_perf_of_selected": final_selected_perf,
136 | "selection_metric_value": final_selected_perf,
137 | "selected_graph_origin": method_origin,
138 | "discovered_count": total_evaluated_count,
139 | "total_iterations_run": total_evaluated_count,
140 | "rank_position_overall": np.nan,
141 | "percentile_overall": np.nan,
142 | "total_samples_overall": total_samples_overall,
143 | "performance_trajectory": performance_trajectory,
144 | "total_evaluation_time": total_evaluation_time,
145 | "total_run_time": total_run_time
146 | }
147 |
148 | if final_selected_index is not None and not np.isnan(final_selected_perf) and overall_actual_y is not None:
149 | rank_info = calculate_overall_rank(
150 | final_selected_perf,
151 | overall_actual_y,
152 | higher_is_better
153 | )
154 | if rank_info:
155 | results["rank_position_overall"] = rank_info["rank_position_overall"]
156 | results["percentile_overall"] = rank_info["percentile_overall"]
157 |
158 | return results
--------------------------------------------------------------------------------
/rdb2g_bench/dataset/models/modules/nn.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch_geometric.nn import HeteroConv, LayerNorm, SAGEConv
6 | from torch_geometric.typing import EdgeType, NodeType
7 | from .sage_conv_edge import SAGEConvEdge
8 | from .gin_conv import GINConv, GINEConv
9 | from .gps_conv import GPSConv
10 |
11 |
class HeteroGraphSAGE(torch.nn.Module):
    r"""Heterogeneous GraphSAGE encoder with per-node-type LayerNorm + ReLU.

    Every layer is a :class:`HeteroConv` holding one convolution per edge type.
    Per-destination results are summed across edge types (``aggr="sum"``).
    Edge types whose relation name starts with ``"r2e_"`` or ``"rev_r2e_"`` get
    a :class:`SAGEConvEdge` (presumably the edge-attribute-aware variant).
    Every other edge type gets a plain :class:`SAGEConv`. After each layer,
    every node type's embeddings go through ``LayerNorm(mode="node")`` and then
    ReLU.

    Args:
        node_types (List[NodeType]): Node types that get a LayerNorm per layer.
        edge_types (List[EdgeType]): Edge types that get a convolution per layer.
        channels (int): Input and output dimensionality of every layer.
        aggr (str): Neighbourhood aggregation inside each convolution.
            (default: ``"mean"``)
        num_layers (int): Number of message-passing layers. (default: ``2``)
    """

    def __init__(
        self,
        node_types: List[NodeType],
        edge_types: List[EdgeType],
        channels: int,
        aggr: str = "mean",
        num_layers: int = 2,
    ):
        super().__init__()

        # Modules are created in the same order as before (all convs, then all
        # norms) so seeded weight initialisation is unchanged.
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv(
                {
                    edge_type: SAGEConvEdge((channels, channels), channels, aggr=aggr)
                    if edge_type[1].startswith(('r2e_', 'rev_r2e_'))
                    else SAGEConv((channels, channels), channels, aggr=aggr)
                    for edge_type in edge_types
                },
                aggr="sum",
            )
            self.convs.append(conv)

        self.norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            norm_dict = torch.nn.ModuleDict()
            for node_type in node_types:
                norm_dict[node_type] = LayerNorm(channels, mode="node")
            self.norms.append(norm_dict)

    def reset_parameters(self):
        """Re-initialise every convolution and normalisation layer."""
        for conv in self.convs:
            conv.reset_parameters()
        for norm_dict in self.norms:
            for norm in norm_dict.values():
                norm.reset_parameters()

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_dict: Dict[EdgeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
        num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
    ) -> Dict[NodeType, Tensor]:
        """Run all layers.

        Args:
            x_dict: Node features per node type.
            edge_dict: Edge attributes keyed by edge type. Only the entries
                whose edge type also appears in ``edge_index_dict`` are
                forwarded to the convolutions as ``edge_attr``.
            edge_index_dict: Connectivity per edge type.
            num_sampled_nodes_dict: Unused in this implementation.
            num_sampled_edges_dict: Unused in this implementation.

        Returns:
            Node embeddings per node type, as produced by the last layer.
        """
        # The selection depends only on the inputs, not on the layer, so it
        # is built once instead of once per layer.
        edge_attr_dict = {
            edge_type: edge_dict[edge_type]
            for edge_type in edge_index_dict
            if edge_type in edge_dict
        }

        for conv, norm_dict in zip(self.convs, self.norms):
            x_dict = conv(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
            x_dict = {key: norm_dict[key](x).relu() for key, x in x_dict.items()}

        return x_dict
71 |
class HeteroGIN(torch.nn.Module):
    r"""Heterogeneous GIN encoder with per-node-type LayerNorm + ReLU.

    Every layer is a :class:`HeteroConv` holding one convolution per edge type.
    Per-destination results are summed across edge types (``aggr="sum"``).
    Edge types whose relation name starts with ``"r2e_"`` or ``"rev_r2e_"`` get
    a :class:`GINEConv` (adds edge features to messages). Every other edge
    type gets a :class:`GINConv`. Both use a trainable ``eps``. After each
    layer, every node type's embeddings go through ``LayerNorm(mode="node")``
    and then ReLU.

    Args:
        node_types (List[NodeType]): Node types that get a LayerNorm per layer.
        edge_types (List[EdgeType]): Edge types that get a convolution per layer.
        channels (int): Input and output dimensionality of every layer.
        aggr (str): Accepted for signature parity with the other encoders;
            not used by the GIN convolutions built here. (default: ``"mean"``)
        num_layers (int): Number of message-passing layers. (default: ``2``)
    """

    def __init__(
        self,
        node_types: List[NodeType],
        edge_types: List[EdgeType],
        channels: int,
        aggr: str = "mean",
        num_layers: int = 2,
    ):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # NOTE(review): this single MLP instance is shared by every edge
            # type's convolution in the layer (only ``eps`` is per edge type),
            # and each conv's constructor re-initialises it. Kept as-is to
            # preserve benchmark results; confirm whether per-edge-type MLPs
            # were intended.
            nn_gin = torch.nn.Sequential(
                torch.nn.Linear(channels, channels * 2),
                torch.nn.ReLU(),
                torch.nn.Linear(channels * 2, channels),
            )
            conv = HeteroConv(
                {
                    edge_type: GINEConv(nn=nn_gin, train_eps=True)
                    if edge_type[1].startswith(('r2e_', 'rev_r2e_'))
                    else GINConv(nn=nn_gin, train_eps=True)
                    for edge_type in edge_types
                },
                aggr="sum",
            )
            self.convs.append(conv)

        self.norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            norm_dict = torch.nn.ModuleDict()
            for node_type in node_types:
                norm_dict[node_type] = LayerNorm(channels, mode="node")
            self.norms.append(norm_dict)

    def reset_parameters(self):
        """Re-initialise every convolution and normalisation layer."""
        for conv in self.convs:
            conv.reset_parameters()
        for norm_dict in self.norms:
            for norm in norm_dict.values():
                norm.reset_parameters()

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_dict: Dict[EdgeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
        num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
    ) -> Dict[NodeType, Tensor]:
        """Run all layers.

        Args:
            x_dict: Node features per node type.
            edge_dict: Edge attributes keyed by edge type. Only the entries
                whose edge type also appears in ``edge_index_dict`` are
                forwarded to the convolutions as ``edge_attr``.
                NOTE(review): ``GINConv.forward`` takes no ``edge_attr``, so
                this presumably must only contain r2e edge types. Confirm
                against callers.
            edge_index_dict: Connectivity per edge type.
            num_sampled_nodes_dict: Unused in this implementation.
            num_sampled_edges_dict: Unused in this implementation.

        Returns:
            Node embeddings per node type, as produced by the last layer.
        """
        # Loop-invariant: identical for every layer, so build it once.
        edge_attr_dict = {
            edge_type: edge_dict[edge_type]
            for edge_type in edge_index_dict
            if edge_type in edge_dict
        }

        for conv, norm_dict in zip(self.convs, self.norms):
            x_dict = conv(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
            x_dict = {key: norm_dict[key](x).relu() for key, x in x_dict.items()}

        return x_dict
136 |
class HeteroGPS(torch.nn.Module):
    r"""Heterogeneous GPS (local MPNN + global attention) encoder.

    Every layer is a :class:`HeteroConv` holding one :class:`GPSConv` per edge
    type. Each :class:`GPSConv` uses a :class:`SAGEConvEdge` as its local MPNN
    and is built with ``norm=None``. Per-destination results are summed
    across edge types. After each layer, every node type's embeddings go
    through ``LayerNorm(mode="node")`` and then ReLU. No ``batch`` vector is
    passed to the GPS layers.

    Args:
        node_types (List[NodeType]): Node types that get a LayerNorm per layer.
        edge_types (List[EdgeType]): Edge types that get a GPS layer per layer.
        channels (int): Input and output dimensionality of every layer.
        num_layers (int): Number of layers. (default: ``2``)
        heads (int): Attention heads per GPS layer. (default: ``1``)
        dropout (float): Dropout inside each GPS layer. (default: ``0.0``)
        attn_type (str): ``"multihead"`` or ``"performer"``.
            (default: ``"multihead"``)
        aggr (str): Aggregation of the local :class:`SAGEConvEdge`.
            (default: ``"mean"``)
    """

    def __init__(
        self,
        node_types: List[NodeType],
        edge_types: List[EdgeType],
        channels: int,
        num_layers: int = 2,
        heads: int = 1,
        dropout: float = 0.0,
        attn_type: str = 'multihead',
        aggr: str = "mean",
    ):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv_dict = {}
            for edge_type in edge_types:
                local_mpnn = SAGEConvEdge(
                    in_channels=(channels, channels),
                    out_channels=channels,
                    aggr=aggr
                )
                gps_layer = GPSConv(
                    channels=channels,
                    conv=local_mpnn,
                    heads=heads,
                    dropout=dropout,
                    attn_type=attn_type,
                    norm=None,
                )
                conv_dict[edge_type] = gps_layer

            conv = HeteroConv(conv_dict, aggr="sum")
            self.convs.append(conv)

        self.norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            norm_dict = torch.nn.ModuleDict()
            for node_type in node_types:
                norm_dict[node_type] = LayerNorm(channels, mode="node")
            self.norms.append(norm_dict)

    def reset_parameters(self):
        """Re-initialise every GPS layer and normalisation layer."""
        for conv in self.convs:
            conv.reset_parameters()
        for norm_dict in self.norms:
            for norm in norm_dict.values():
                norm.reset_parameters()

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_dict: Dict[EdgeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
        num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
    ) -> Dict[NodeType, Tensor]:
        """Run all layers.

        Args:
            x_dict: Node features per node type.
            edge_dict: Edge attributes keyed by edge type. Only the entries
                whose edge type also appears in ``edge_index_dict`` are
                forwarded to the GPS layers as ``edge_attr``.
            edge_index_dict: Connectivity per edge type.
            num_sampled_nodes_dict: Unused in this implementation.
            num_sampled_edges_dict: Unused in this implementation.

        Returns:
            Node embeddings per node type, as produced by the last layer.
        """
        # Loop-invariant: identical for every layer, so build it once.
        edge_attr_dict = {
            edge_type: edge_dict[edge_type]
            for edge_type in edge_index_dict
            if edge_type in edge_dict
        }

        for conv, norm_dict in zip(self.convs, self.norms):
            x_dict = conv(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
            x_dict = {key: norm_dict[key](x).relu() for key, x in x_dict.items()}

        return x_dict
207 |
--------------------------------------------------------------------------------
/rdb2g_bench/dataset/models/modules/gps_conv.py:
--------------------------------------------------------------------------------
1 | # Custom implementation of GPSConv based on torch_geometric.nn.conv.GPSConv
2 |
3 | import inspect
4 | from typing import Any, Dict, Optional, Union
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 | from torch.nn import Dropout, Linear, Sequential
10 |
11 | from torch_geometric.nn.attention import PerformerAttention
12 | from torch_geometric.nn.conv import MessagePassing
13 | from torch_geometric.nn.inits import reset
14 | from torch_geometric.nn.resolver import (
15 | activation_resolver,
16 | normalization_resolver,
17 | )
18 | from torch_geometric.typing import Adj, PairTensor
19 | from torch_geometric.utils import to_dense_batch
20 |
21 |
class GPSConv(torch.nn.Module):
    r"""The general, powerful, scalable (GPS) graph transformer layer from the
    `"Recipe for a General, Powerful, Scalable Graph Transformer"
    <https://arxiv.org/abs/2205.12454>`_ paper.

    The GPS layer is based on a 3-part recipe:

    1. Inclusion of positional (PE) and structural encodings (SE) to the input
       features (done in a pre-processing step via
       :class:`torch_geometric.transforms`).
    2. A local message passing layer (MPNN) that operates on the input graph.
    3. A global attention layer that operates on the entire graph.

    In this implementation ``x`` may be a single tensor or a
    ``(source, destination)`` pair. The local MPNN receives ``x`` unchanged,
    and ``edge_attr`` is forwarded to it. The residual connections and the
    global attention use only the destination features.

    .. note::

        For an example of using :class:`GPSConv`, see
        `examples/graph_gps.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_gps.py>`_.

    Args:
        channels (int): Size of each input sample.
        conv (MessagePassing, optional): The local message passing layer.
            If :obj:`None`, only the global attention branch is used.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        dropout (float, optional): Dropout probability of intermediate
            embeddings. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. :obj:`None` disables all three normalizations.
            (default: :obj:`"batch_norm"`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        attn_type (str): Global attention type, :obj:`multihead` or
            :obj:`performer`. (default: :obj:`multihead`)
        attn_kwargs (Dict[str, Any], optional): Arguments passed to the
            attention layer. (default: :obj:`None`)

    Raises:
        ValueError: If :obj:`attn_type` is not supported.
    """
    def __init__(
        self,
        channels: int,
        conv: Optional[MessagePassing],
        heads: int = 1,
        dropout: float = 0.0,
        act: str = 'relu',
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Optional[str] = 'batch_norm',
        norm_kwargs: Optional[Dict[str, Any]] = None,
        attn_type: str = 'multihead',
        attn_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        self.channels = channels
        self.conv = conv
        self.heads = heads
        self.dropout = dropout
        self.attn_type = attn_type

        attn_kwargs = attn_kwargs or {}
        if attn_type == 'multihead':
            self.attn = torch.nn.MultiheadAttention(
                channels,
                heads,
                batch_first=True,
                **attn_kwargs,
            )
        elif attn_type == 'performer':
            self.attn = PerformerAttention(
                channels=channels,
                heads=heads,
                **attn_kwargs,
            )
        else:
            # TODO: Support BigBird
            raise ValueError(f'{attn_type} is not supported')

        self.mlp = Sequential(
            Linear(channels, channels * 2),
            activation_resolver(act, **(act_kwargs or {})),
            Dropout(dropout),
            Linear(channels * 2, channels),
            Dropout(dropout),
        )

        norm_kwargs = norm_kwargs or {}
        self.norm1 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm2 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm3 = normalization_resolver(norm, channels, **norm_kwargs)

        # All three norms are built from the same spec, so inspecting norm1 is
        # enough to know whether they accept a ``batch`` argument.
        self.norm_with_batch = False
        if self.norm1 is not None:
            signature = inspect.signature(self.norm1.forward)
            self.norm_with_batch = 'batch' in signature.parameters

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        if self.conv is not None:
            self.conv.reset_parameters()
        self.attn._reset_parameters()
        reset(self.mlp)
        if self.norm1 is not None:
            self.norm1.reset_parameters()
        if self.norm2 is not None:
            self.norm2.reset_parameters()
        if self.norm3 is not None:
            self.norm3.reset_parameters()

    def _normalize(self, norm, x: Tensor, batch: Optional[Tensor]) -> Tensor:
        """Apply ``norm`` to ``x``, passing ``batch`` if the norm accepts it.

        Returns ``x`` unchanged when ``norm`` is ``None``.
        """
        if norm is None:
            return x
        if self.norm_with_batch:
            return norm(x, batch=batch)
        return norm(x)

    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        batch: Optional[torch.Tensor] = None,
        edge_attr: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x: Node features, or a ``(source, destination)`` pair of them.
            edge_index: Connectivity used by the local MPNN.
            batch: Graph assignment of the destination nodes. It is used by
                the global attention (``to_dense_batch`` treats :obj:`None`
                as a single graph) and by batch-aware normalizations.
            edge_attr: Optional edge features forwarded to the local MPNN.
            **kwargs: Extra keyword arguments forwarded to the local MPNN.

        Returns:
            Updated embeddings with one row per destination node.

        Raises:
            TypeError: If ``self.attn`` has been replaced by an unsupported
                module.
        """
        branches = []

        # Residual connections and global attention act on the destination
        # side only.
        x_dst = x[1] if isinstance(x, tuple) else x

        if self.conv is not None:  # Local MPNN.
            h_local = self.conv(x, edge_index, edge_attr=edge_attr, **kwargs)
            h_local = F.dropout(h_local, p=self.dropout, training=self.training)
            branches.append(self._normalize(self.norm1, h_local + x_dst, batch))

        # Global attention over the (padded) dense batch of destination nodes.
        h_dense, mask = to_dense_batch(x_dst, batch)
        if isinstance(self.attn, torch.nn.MultiheadAttention):
            h_global, _ = self.attn(h_dense, h_dense, h_dense,
                                    key_padding_mask=~mask,
                                    need_weights=False)
        elif isinstance(self.attn, PerformerAttention):
            h_global = self.attn(h_dense, mask=mask)
        else:
            # Unreachable through __init__; guards against a swapped module,
            # which previously surfaced as an UnboundLocalError.
            raise TypeError(
                f'Unsupported attention module: {type(self.attn).__name__}')

        h_global = h_global[mask]  # Back to one row per node.
        h_global = F.dropout(h_global, p=self.dropout, training=self.training)
        branches.append(self._normalize(self.norm2, h_global + x_dst, batch))

        out = sum(branches)  # Combine local and global outputs.
        out = out + self.mlp(out)
        return self._normalize(self.norm3, out, batch)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.channels}, '
                f'conv={self.conv}, heads={self.heads}, '
                f'attn_type={self.attn_type})')
--------------------------------------------------------------------------------
/rdb2g_bench/benchmark/llm/prompts/prompt.py:
--------------------------------------------------------------------------------
1 | import json
2 |
def get_score_feedback(initial_score, past_score, current_score, higher_is_better, last_action_num):
    """Build LLM-facing feedback describing how the score moved.

    Args:
        initial_score: Score of the schema before any action was applied.
        past_score: Score before the last ``last_action_num`` actions.
        current_score: Score after the last ``last_action_num`` actions.
        higher_is_better: Whether a larger score is better for the metric.
        last_action_num: Number of actions between ``past_score`` and
            ``current_score``.

    Returns:
        A multi-line feedback string. It reports whether performance improved,
        stayed the same, or decreased. If it improved but is still no better
        than ``initial_score``, it also asks for other actions.
    """
    def is_better(a, b):
        # Strict comparison in the metric's preferred direction.
        return a > b if higher_is_better else a < b

    score_feedback = (
        f"In history actions, after the last {last_action_num} actions, "
        f"the score has changed from {past_score:.4f} to {current_score:.4f}.\n"
    )
    score_feedback += (
        "Since a **higher** score is better, " if higher_is_better else "Since a **lower** score is better, "
    )
    if is_better(current_score, past_score):
        score_feedback += "the performance has **improved**."
        # Improved, but still no better than where the search started.
        if current_score == initial_score or is_better(initial_score, current_score):
            score_feedback += f"\nHowever, the initial score was {initial_score:.4f}, Please try other actions to improve the performance."
    elif current_score == past_score:
        # Previously reported as "decreased", which misled the model.
        score_feedback += "the performance has **not changed**. Please consider either reversing the previous action or exploring alternative actions to improve the schema."
    else:
        score_feedback += "the performance has **decreased**. Please consider either reversing the previous action or exploring alternative actions to improve the schema."
    return score_feedback
16 |
17 |
def get_action_info(actions=None, edge_info=None):
    """Assemble the action documentation section of the prompt.

    Reads ``./prompts/action.json`` (relative to the current working
    directory). For every action in ``actions`` that has non-empty candidate
    text in ``edge_info``, it emits that action's documentation followed by
    the candidate text.

    Args:
        actions: Action names, in output order. Defaults to
            ``['add_fk_pk_edge', 'remove_fk_pk_edge', 'convert_edge_to_row',
            'convert_row_to_edge']``. ('add_row2edge_edge' and
            'remove_row2edge_edge' are currently disabled.)
        edge_info: Maps action name to candidate text. Missing or empty
            entries are skipped. Defaults to empty text for every default
            action, which yields an empty string.

    Returns:
        The concatenated documentation, each entry ending in a blank line.
    """
    # None sentinels avoid sharing mutable defaults across calls.
    if actions is None:
        actions = ['add_fk_pk_edge', 'remove_fk_pk_edge', 'convert_edge_to_row', 'convert_row_to_edge']
    if edge_info is None:
        edge_info = {"add_fk_pk_edge": "", "remove_fk_pk_edge": "", "convert_row_to_edge": "", "convert_edge_to_row": ""}
    with open("./prompts/action.json", "r", encoding="utf-8") as f:
        action_info = json.load(f)
    parts = []
    for action_name in actions:
        candidates = edge_info.get(action_name, "")
        if candidates != "":
            parts.append(f"{action_info[action_name]}\n{candidates}\n\n")
    return "".join(parts)
27 |
def augmentation_prompt(dataset_name, task_name, edge_info, history_actions="", error_msg="", past_score=0, current_score=0, higher_is_better=True, initial_score=0, budget=10, initial_attempt=True, last_action_num=1):
    """Build the schema-augmentation prompt sent to the LLM.

    Loads ``./prompts/data_stats.json``, ``./prompts/schema.json`` and
    ``./prompts/task.json`` (relative to the current working directory). From
    them it selects the entries for ``dataset_name`` / ``task_name``.

    Args:
        dataset_name: Key into the stats and schema JSON files.
        task_name: Key into the task JSON file.
        edge_info: Per-action candidate text passed to ``get_action_info``.
        history_actions: Previously applied actions, as text. Ignored on the
            initial attempt.
        error_msg: Actions known to fail. If non-empty, a warning is added.
        past_score, current_score, initial_score: Scores used to build score
            feedback on follow-up attempts that have history.
        higher_is_better: Direction of the metric.
        budget: Number of remaining attempts mentioned in follow-up prompts.
        initial_attempt: Selects the first-turn prompt (no history, no
            feedback, no "None" halt option).
        last_action_num: Number of actions covered by the score feedback.

    Returns:
        The prompt string.
    """
    with open("./prompts/data_stats.json", "r", encoding="utf-8") as f:
        stats = json.load(f)
    stats = stats[dataset_name]
    with open("./prompts/schema.json", "r", encoding="utf-8") as f:
        schema = json.load(f)
    schema = schema[dataset_name]
    with open("./prompts/task.json", "r", encoding="utf-8") as f:
        task = json.load(f)
    task = task[task_name]

    action_info = get_action_info(edge_info=edge_info)
    score_feedback = ""


    if error_msg != "":
        error_msg = f"Warning: The following actions will cause errors: \n{error_msg}"


    if initial_attempt:
        return f"""You are expected to construct graph schema based on the original inputs.
You will be given an original schema represented in the dictionary format:

1. dataset_name: name of the dataset
2. tables: meta data for list of tables, each one will present following attributes
1. name: table name
2. columns: list of columns, each column will have following attributes
1. name: column name
2. dtype: column type, can be either text, categorical, float, primary_key, foreign_key, or multi_category. primary_key and foreign_key are two special types of categorical columns, which presents a structural relationship with other tables. Multi_category means this column is of list type, and each cell main contains a list of categorical values. After a column is set as primary_key or foreign_key, it should not be changed to other types.
3. link_to (optional): if this column is a foreign key, point to which primary key from which table
3. statistics of the table: statistics of the column value of tables. These statistics can be used to help you determine the characteristics of the columns.


Here are the documents of the actions:
{action_info}
{error_msg}

Now, you need to:

1. Actively think about which actions (from the list below) should be conducted to improve the schema.
2. Output all actions you can think of from the above list to make the schema better, and output your selections in the following format:
If multiple actions are needed, please list **all** of them.


[
{{
"explanation": ,
"action": ,
"parameters":
}},
{{
"explanation": ,
"action": ,
"parameters":
}},
{{
"explanation": ,
"action": ,
"parameters":
}},
...
]




{stats}


{task}


{schema}




Return your output in the json format inside ."""

    else:
        # Add history actions and score feedback only if there are history actions
        if history_actions.strip() != "":
            history_actions = f"History Actions: \n{history_actions}"
            # get_score_feedback takes no error_msg; passing it raised a
            # TypeError (6 positional args for 5 parameters).
            score_feedback = get_score_feedback(initial_score, past_score, current_score, higher_is_better, last_action_num)
        return f"""You are expected to construct graph schema based on the original inputs.
You will be given an original schema represented in the dictionary format:

1. dataset_name: name of the dataset
2. tables: meta data for list of tables, each one will present following attributes
1. name: table name
2. columns: list of columns, each column will have following attributes
1. name: column name
2. dtype: column type, can be either text, categorical, float, primary_key, foreign_key, or multi_category. primary_key and foreign_key are two special types of categorical columns, which presents a structural relationship with other tables. Multi_category means this column is of list type, and each cell main contains a list of categorical values. After a column is set as primary_key or foreign_key, it should not be changed to other types.
3. link_to (optional): if this column is a foreign key, point to which primary key from which table
3. statistics of the table: statistics of the column value of tables. These statistics can be used to help you determine the characteristics of the columns.


Here are the documents of the actions:
{action_info}
{error_msg}

Now, you need to
1. Actively think about whether any one of the 4 actions should be conducted
2. Output all actions you can think of from the above list to make the schema better, and output your selections in the following format:
If multiple actions are needed, please list **all** of them.


[
{{
"explanation": ,
"action": ,
"parameters":
}},
{{
"explanation": ,
"action": ,
"parameters":
}},
{{
"explanation": ,
"action": ,
"parameters":
}},
...
]


3. If you think there's no more action, you can output

None


{history_actions}



{stats}


{task}


{schema}



{score_feedback}

Note that the current schema may **not be optimal**, so other actions may yield better results.
Please **only** halt the program with `None` if you believe no further actions are worth trying.
You can try {budget} more times to improve the performance.
Return your output in the json format inside ."""
180 |
181 |
--------------------------------------------------------------------------------
/rdb2g_bench/dataset/models/modules/gin_conv.py:
--------------------------------------------------------------------------------
1 | # Custom implementation of GINConv based on torch_geometric.nn.conv.GINConv
2 |
3 | from typing import Callable, Optional, Union
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 | from torch_geometric.nn.conv import MessagePassing
9 | from torch_geometric.nn.dense.linear import Linear
10 | from torch_geometric.nn.inits import reset
11 | from torch_geometric.typing import (
12 | Adj,
13 | OptPairTensor,
14 | OptTensor,
15 | Size,
16 | SparseTensor,
17 | )
18 | from torch_geometric.utils import spmm
19 |
20 |
class GINConv(MessagePassing):
    r"""The graph isomorphism operator from the `"How Powerful are
    Graph Neural Networks?" <https://arxiv.org/abs/1810.00826>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot
        \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)

    or

    .. math::
        \mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} +
        (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right),

    here :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.* an MLP.

    Args:
        nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that
            maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to
            shape :obj:`[-1, out_channels]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`. It is re-initialised by
            :meth:`reset_parameters` (also called from the constructor).
        eps (float, optional): (Initial) :math:`\epsilon`-value.
            (default: :obj:`0.`)
        train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon`
            will be a trainable parameter. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`. ``aggr``
            defaults to ``'add'``.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
    """
    def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps
        # eps is either learnable or a fixed (non-trainable) buffer; both are
        # filled with initial_eps in reset_parameters().
        if train_eps:
            self.eps = torch.nn.Parameter(torch.empty(1))
        else:
            self.register_buffer('eps', torch.empty(1))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the aggregation module, the MLP and ``eps``."""
        super().reset_parameters()
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        size: Size = None,
    ) -> Tensor:
        """Aggregate neighbour features, add the scaled self term, apply ``nn``.

        A single tensor ``x`` is treated as the pair ``(x, x)``.
        """

        if isinstance(x, Tensor):
            x = (x, x)

        # propagate_type: (x: OptPairTensor)
        out = self.propagate(edge_index, x=x, size=size)

        # Self term (1 + eps) * x_i, using destination-side features.
        x_r = x[1]
        if x_r is not None:
            out = out + (1 + self.eps) * x_r

        return self.nn(out)

    def message(self, x_j: Tensor) -> Tensor:
        # Plain GIN: messages are the unmodified source features.
        return x_j

    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:
        # Sparse fast path. Edge values are dropped so the aggregation is
        # over unweighted neighbours.
        if isinstance(adj_t, SparseTensor):
            adj_t = adj_t.set_value(None, layout=None)
        return spmm(adj_t, x[0], reduce=self.aggr)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(nn={self.nn})'
104 |
105 |
106 | class GINEConv(MessagePassing):
107 | r"""The modified :class:`GINConv` operator from the `"Strategies for
108 | Pre-training Graph Neural Networks" `_
109 | paper.
110 |
111 | .. math::
112 | \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot
113 | \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathrm{ReLU}
114 | ( \mathbf{x}_j + \mathbf{e}_{j,i} ) \right)
115 |
116 | that is able to incorporate edge features :math:`\mathbf{e}_{j,i}` into
117 | the aggregation procedure.
118 |
119 | Args:
120 | nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that
121 | maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to
122 | shape :obj:`[-1, out_channels]`, *e.g.*, defined by
123 | :class:`torch.nn.Sequential`.
124 | eps (float, optional): (Initial) :math:`\epsilon`-value.
125 | (default: :obj:`0.`)
126 | train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon`
127 | will be a trainable parameter. (default: :obj:`False`)
128 | edge_dim (int, optional): Edge feature dimensionality. If set to
129 | :obj:`None`, node and edge feature dimensionality is expected to
130 | match. Other-wise, edge features are linearly transformed to match
131 | node feature dimensionality. (default: :obj:`None`)
132 | **kwargs (optional): Additional arguments of
133 | :class:`torch_geometric.nn.conv.MessagePassing`.
134 |
135 | Shapes:
136 | - **input:**
137 | node features :math:`(|\mathcal{V}|, F_{in})` or
138 | :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
139 | if bipartite,
140 | edge indices :math:`(2, |\mathcal{E}|)`,
141 | edge features :math:`(|\mathcal{E}|, D)` *(optional)*
142 | - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
143 | :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
144 | """
145 | def __init__(self, nn: torch.nn.Module, eps: float = 0.,
146 | train_eps: bool = False, edge_dim: Optional[int] = None,
147 | **kwargs):
148 | kwargs.setdefault('aggr', 'add')
149 | super().__init__(**kwargs)
150 | self.nn = nn
151 | self.initial_eps = eps
152 | if train_eps:
153 | self.eps = torch.nn.Parameter(torch.empty(1))
154 | else:
155 | self.register_buffer('eps', torch.empty(1))
156 | if edge_dim is not None:
157 | if isinstance(self.nn, torch.nn.Sequential):
158 | nn = self.nn[0]
159 | if hasattr(nn, 'in_features'):
160 | in_channels = nn.in_features
161 | elif hasattr(nn, 'in_channels'):
162 | in_channels = nn.in_channels
163 | else:
164 | raise ValueError("Could not infer input channels from `nn`.")
165 | self.lin = Linear(edge_dim, in_channels)
166 |
167 | else:
168 | self.lin = None
169 | self.reset_parameters()
170 |
171 | def reset_parameters(self):
172 | reset(self.nn)
173 | self.eps.data.fill_(self.initial_eps)
174 | if self.lin is not None:
175 | self.lin.reset_parameters()
176 |
177 | def forward(
178 | self,
179 | x: Union[Tensor, OptPairTensor],
180 | edge_index: Adj,
181 | edge_attr: OptTensor = None,
182 | size: Size = None,
183 | ) -> Tensor:
184 | if isinstance(x, Tensor):
185 | x = (x, x)
186 |
187 | # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
188 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
189 |
190 | x_r = x[1]
191 | if x_r is not None:
192 | out = out + (1 + self.eps) * x_r
193 |
194 | return self.nn(out)
195 |
def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
    """Compute the per-edge GINE message ``ReLU(x_j + e)``.

    Args:
        x_j: Source-node features gathered per edge; the last dimension is
            the node feature width ``F``.
        edge_attr: Edge features, or ``None`` when the edge type has none
            (treated as all-zero edge features). Without ``self.lin`` they
            must already have width ``F``; with it they are projected to
            ``F`` first.

    Returns:
        Tensor: Messages with the same shape as ``x_j``.

    Raises:
        ValueError: If there is no projection and the node and edge feature
            widths differ.
    """
    if edge_attr is None:
        if self.lin is None:
            edge_attr = torch.zeros_like(x_j)
        else:
            # The zero edge features must have the projection's *input* width
            # (edge_dim). zeros_like(x_j) only matched it when edge_dim == F,
            # and crashed in self.lin otherwise.
            edge_attr = x_j.new_zeros(*x_j.shape[:-1], self.lin.weight.size(-1))
    if self.lin is None and x_j.size(-1) != edge_attr.size(-1):
        raise ValueError("Node and edge feature dimensionalities do not "
                         "match. Consider setting the 'edge_dim' "
                         "attribute of 'GINEConv'")
    if self.lin is not None:
        edge_attr = self.lin(edge_attr)

    return (x_j + edge_attr).relu()
207 |
def __repr__(self) -> str:
    """Show the layer's class name together with its wrapped network."""
    class_name = type(self).__name__
    return f'{class_name}(nn={self.nn})'
210 |
--------------------------------------------------------------------------------
/rdb2g_bench/common/search_space/row2ne_search_space.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Tuple, Union
2 |
3 | import copy
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from torch_geometric.data import HeteroData
8 |
class Row2NESearchSpace():
    """Search space over row-to-edge (r2e) conversions of a heterogeneous graph.

    A node type is a conversion candidate when it is the source of exactly
    two ``f2p`` edge types and the destination of none (see
    :meth:`find_possible_edges`). Converting it adds direct
    ``r2e_<node_type>`` edges (and ``rev_r2e_<node_type>`` reverse edges)
    between the two tables its rows point to, and drops the edge types
    incident to it.

    Args:
        dataset (str): Dataset name.
        task (str): Task name.
        hetero_data (HeteroData): The original heterogeneous graph.
    """

    def __init__(self, dataset: str, task: str, hetero_data: "HeteroData"):
        self.dataset = dataset
        self.task = task
        self.data = hetero_data
        self.node_types = sorted(hetero_data.node_types)
        self.edge_types = sorted(hetero_data.edge_types)
        self.possible_edge_types = self.find_possible_edges()
        self.full_edges = self.find_full_edges()
        # Single-entry cache keyed on the r2e part of the last selection.
        self._cached_r2e_key = None
        self._cached_converted_data = None

    def find_possible_edges(self):
        """Return the sorted node types eligible for row-to-edge conversion.

        A node type qualifies when it is the source of exactly two ``f2p``
        edge types and the destination of none.
        """
        possible_edges = []
        for node_type in self.node_types:
            num_outgoing = sum(1 for src, rel, _ in self.edge_types
                               if src == node_type and rel.startswith('f2p'))
            num_incoming = sum(1 for _, rel, dst in self.edge_types
                               if dst == node_type and rel.startswith('f2p'))
            if num_outgoing == 2 and num_incoming == 0:
                possible_edges.append(node_type)

        return sorted(possible_edges)

    def _resolve_r2e_endpoints(self, node_type):
        """Return ``(src_node, dst_node, src_edge_type, dst_edge_type)`` for ``node_type``.

        ``src_node``/``dst_node`` are the tables ``node_type`` points to via
        ``f2p`` edges. The edge types are the first edge types leaving
        ``node_type`` towards each of them, or ``None`` if none was found.

        Raises:
            ValueError: If ``node_type`` has neither one nor two outgoing
                ``f2p`` edges.
        """
        connected_nodes = [dst for src, rel, dst in self.edge_types
                           if src == node_type and rel.startswith('f2p')]
        if len(connected_nodes) == 1:
            src_node, dst_node = connected_nodes[0], connected_nodes[0]
        elif len(connected_nodes) == 2:
            src_node, dst_node = connected_nodes[0], connected_nodes[1]
        else:
            raise ValueError(f"Invalid number of connected nodes: {len(connected_nodes)}")

        # The `!=` checks make the two endpoints bind to different edge types
        # when both foreign keys point to the same table.
        src_edge_type, dst_edge_type = None, None
        for edge_type in self.edge_types:
            if edge_type[0] == node_type and edge_type[2] == src_node and edge_type != dst_edge_type and src_edge_type is None:
                src_edge_type = edge_type
            if edge_type[0] == node_type and edge_type[2] == dst_node and edge_type != src_edge_type and dst_edge_type is None:
                dst_edge_type = edge_type

        return src_node, dst_node, src_edge_type, dst_edge_type

    def find_full_edges(self):
        """Return all candidate edge types of the search space.

        First come the ``r2e_<node_type>`` edge types, one per entry of
        ``possible_edge_types`` whose two endpoint edge types were found.
        They are followed by every ``f2p`` edge type of the data, in sorted
        order.
        """
        full_edges = []
        for node_type in self.possible_edge_types:
            src_node, dst_node, src_edge_type, dst_edge_type = self._resolve_r2e_endpoints(node_type)
            if src_edge_type and dst_edge_type:
                full_edges.append((src_node, f"r2e_{node_type}", dst_node))

        return full_edges + [edge for edge in sorted(self.data.edge_types) if edge[1].startswith("f2p")]

    def _add_r2e_edges(self, converted_data, node_type):
        """Add ``r2e_<node_type>`` / ``rev_r2e_<node_type>`` edges to ``converted_data``.

        Each row of ``node_type`` that appears in both foreign-key edge types
        becomes one edge between the two rows it references. The row ids are
        stored as ``mapped_node_ids`` on both edge types. Nothing is added
        when no row qualifies.
        """
        src_node, dst_node, src_edge_type, dst_edge_type = self._resolve_r2e_endpoints(node_type)
        if not (src_edge_type and dst_edge_type):
            return

        new_edge_type = (src_node, f"r2e_{node_type}", dst_node)
        reverse_edge_type = (dst_node, f"rev_r2e_{node_type}", src_node)

        src_edges = self.data[src_edge_type].edge_index.detach().cpu().numpy()
        dst_edges = self.data[dst_edge_type].edge_index.detach().cpu().numpy()

        # row id of node_type -> referenced row id, per foreign key
        src_mapping = dict(zip(src_edges[0], src_edges[1]))
        dst_mapping = dict(zip(dst_edges[0], dst_edges[1]))

        # Rows that reference both endpoint tables
        node_ids = np.arange(len(self.data[node_type].tf))
        mask = np.isin(node_ids, list(src_mapping.keys())) & np.isin(node_ids, list(dst_mapping.keys()))
        valid_node_ids = node_ids[mask]
        if len(valid_node_ids) == 0:
            return

        # Dense lookup tables make the per-row mapping a vectorised gather.
        max_node_id = max(max(src_mapping.keys()), max(dst_mapping.keys())) + 1
        src_lookup = np.full(max_node_id, -1)
        dst_lookup = np.full(max_node_id, -1)
        src_lookup[np.array(list(src_mapping.keys()), dtype=np.int64)] = np.array(list(src_mapping.values()), dtype=np.int64)
        dst_lookup[np.array(list(dst_mapping.keys()), dtype=np.int64)] = np.array(list(dst_mapping.values()), dtype=np.int64)

        new_edge_index = torch.tensor(
            np.stack([src_lookup[valid_node_ids], dst_lookup[valid_node_ids]], axis=0),
            dtype=torch.long,
        ).contiguous()
        mapped_node_ids = torch.tensor(valid_node_ids, dtype=torch.long)

        converted_data[new_edge_type].edge_index = new_edge_index
        converted_data[new_edge_type].mapped_node_ids = mapped_node_ids
        converted_data[reverse_edge_type].edge_index = new_edge_index.flip(0).contiguous()
        converted_data[reverse_edge_type].mapped_node_ids = mapped_node_ids.clone()

    def convert_row_to_edge(self, edges: torch.Tensor):
        """Build a new HeteroData in which the selected node types become edges.

        Args:
            edges: Binary selection vector (tensor or array-like). Its first
                ``len(self.possible_edge_types)`` entries select, aligned
                with ``self.full_edges``, which r2e conversions to apply.
                The remaining entries are not used here.

        Returns:
            HeteroData (on CPU): Node attributes are copied for every node
            type, except ``time`` of converted node types. Edge indices are
            copied for edge types not incident to a converted node type.
            ``r2e_*`` / ``rev_r2e_*`` edges are added for each conversion.

        The result for the most recent r2e selection is cached as an
        independent deep copy, so callers may mutate what they receive.
        """
        if not isinstance(edges, torch.Tensor):
            edges = torch.tensor(edges)

        # Only the row-to-edge part of the selection determines the result.
        r2e_part = edges[:len(self.possible_edge_types)]
        current_r2e_key = tuple(r2e_part.tolist())
        if current_r2e_key == self._cached_r2e_key:
            print("Using cached converted data")
            return copy.deepcopy(self._cached_converted_data)

        # r2e relations are named "r2e_<node_type>"
        nodes_to_convert = {
            self.full_edges[i][1][4:]
            for i, is_selected in enumerate(r2e_part) if is_selected == 1
        }

        converted_data = HeteroData()

        # Copy node data and ensure tensors are contiguous
        for node_type in self.node_types:
            for attr_name, attr_value in self.data[node_type].items():
                if node_type in nodes_to_convert and attr_name == 'time':
                    continue
                if isinstance(attr_value, torch.Tensor):
                    converted_data[node_type][attr_name] = attr_value.clone().detach().contiguous()
                else:
                    converted_data[node_type][attr_name] = attr_value

        # Copy edge indices of edge types not incident to a converted node type
        for edge_type in self.edge_types:
            src_node, _, dst_node = edge_type
            if src_node not in nodes_to_convert and dst_node not in nodes_to_convert:
                converted_data[edge_type].edge_index = self.data[edge_type].edge_index.clone().detach().contiguous()

        # Sorted so the insertion order of the new edge types (and therefore
        # converted_data.edge_types) does not depend on set hash ordering.
        for node_type in sorted(nodes_to_convert):
            self._add_r2e_edges(converted_data, node_type)

        converted_data = converted_data.to('cpu')
        # HeteroData.to() works in place and returns the same object, so
        # cache a deep copy; otherwise caller mutations would leak into it.
        self._cached_converted_data = copy.deepcopy(converted_data)
        self._cached_r2e_key = current_r2e_key

        return converted_data
--------------------------------------------------------------------------------
/rdb2g_bench/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from pathlib import Path
4 | from typing import Optional, List, Dict, Union
5 | from datasets import load_dataset
6 |
7 | from .dataloader import RDB2GBench
8 |
9 |
def _has_local_results(tables_path: Path) -> bool:
    """Return True if ``tables_path`` contains at least one GNN results directory with a CSV.

    Results are expected at ``<tables_path>/<dataset>/<task>/<tag>/<gnn>/*.csv``.
    """
    if not tables_path.exists():
        return False
    # Intermediate "*" components only descend into directories, matching
    # the dataset/task/tag directory nesting.
    return any(
        gnn_dir.is_dir() and any(gnn_dir.glob("*.csv"))
        for gnn_dir in tables_path.glob("*/*/*/*")
    )


def load_rdb2g_bench(
    result_dir: str = "./results",
    download: bool = True,
    cache_dir: Optional[str] = None,
    tag: str = "hf"
) -> RDB2GBench:
    """
    Load RDB2G-Bench results from local directory, download if missing.

    This function serves as the main entry point for accessing RDB2G-Bench data.
    It first checks for existing data in the specified directory, and if no data
    is found and download is enabled, it automatically downloads the dataset
    from Hugging Face Hub.

    Args:
        result_dir (str): Directory containing the benchmark results.
            Defaults to "./results".
        download (bool): Whether to download data if results are missing.
            Defaults to True.
        cache_dir (Optional[str]): Cache directory for Hugging Face datasets.
            If None, uses default HF cache location.
        tag (str): Tag to use for downloaded data organization.
            Defaults to "hf".

    Returns:
        RDB2GBench: Benchmark object for accessing organized results with
        hierarchical access pattern: bench[dataset][task][idx].

    Raises:
        RuntimeError: If no valid data is found in the result directory
            after download attempts.

    Example:
        >>> bench = load_rdb2g_bench("./my_results")
        >>> available = bench.get_available()
        >>> result = bench['rel-f1']['driver-top3'][0]
    """
    has_data = _has_local_results(Path(result_dir) / "tables")

    if not has_data and download:
        print(f"No data found in {result_dir}, downloading from Hugging Face...")
        download_rdb2g_bench(
            result_dir=result_dir,
            cache_dir=cache_dir,
            tag=tag
        )

    bench = RDB2GBench(result_dir)

    if not bench.get_available():
        raise RuntimeError(f"No valid data found in {result_dir}")

    return bench
84 |
85 |
def download_rdb2g_bench(
    result_dir: str = "./results",
    cache_dir: Optional[str] = None,
    dataset_names: Optional[List[str]] = None,
    task_names: Optional[List[str]] = None,
    gnn_names: Optional[List[str]] = None,
    tag: str = "hf",
) -> Dict[str, List[str]]:
    """
    Fetch RDB2G-Bench results from the Hugging Face Hub and write them as CSV files.

    The ``train`` split of ``kaistdata/RDB2G-Bench`` is optionally filtered by
    dataset, task and GNN name. It is then written to
    ``<result_dir>/tables/<dataset>/<task>/<tag>/<gnn>/<seed>.csv``, one file
    per seed, with the rows of each file ordered by ``idx``.

    Args:
        result_dir (str): Output directory, created if missing.
            Defaults to "./results".
        cache_dir (Optional[str]): Hugging Face datasets cache directory.
            If None, the default HF cache location is used.
        dataset_names (Optional[List[str]]): Datasets to keep; None keeps all.
        task_names (Optional[List[str]]): Tasks to keep; None keeps all.
        gnn_names (Optional[List[str]]): GNN models to keep; None keeps all.
        tag (str): Directory level used to label this download.
            Defaults to "hf".

    Returns:
        Dict[str, List[str]]: Maps ``"dataset/task"`` keys to the CSV paths
        written for that combination. Empty when no rows survive the filters.

    Example:
        >>> saved_files = download_rdb2g_bench(
        ...     dataset_names=['rel-f1'],
        ...     task_names=['driver-top3'],
        ...     gnn_names=['GraphSAGE']
        ... )
        >>> print(saved_files)
        {'rel-f1/driver-top3': ['./results/tables/rel-f1/driver-top3/hf/GraphSAGE/0.csv', ...]}
    """
    output_root = Path(result_dir)
    output_root.mkdir(parents=True, exist_ok=True)

    hub_split = load_dataset(
        "kaistdata/RDB2G-Bench",
        cache_dir=cache_dir,
        split="train"
    )
    records = hub_split.to_pandas()

    # Each filter is optional; None keeps every value of that column.
    column_filters = (
        ('dataset', dataset_names),
        ('task', task_names),
        ('gnn', gnn_names),
    )
    for column, wanted in column_filters:
        if wanted is not None:
            records = records[records[column].isin(wanted)]

    if records.empty:
        return {}

    exported_columns = [
        'idx', 'graph', 'train_metric', 'valid_metric', 'test_metric', 'params',
        'train_time', 'valid_time', 'test_time', 'gnn'
    ]

    written = {}
    for (dataset_name, task_name), combo_rows in records.groupby(['dataset', 'task']):
        combo_rows = combo_rows.sort_values(['seed', 'idx'])
        combo_paths = []
        written[f"{dataset_name}/{task_name}"] = combo_paths

        # One directory per GNN, one CSV per seed inside it.
        for gnn_name, gnn_rows in combo_rows.groupby('gnn'):
            gnn_dir = output_root / "tables" / dataset_name / task_name / tag / gnn_name
            gnn_dir.mkdir(parents=True, exist_ok=True)

            for seed, seed_rows in gnn_rows.groupby('seed'):
                csv_path = gnn_dir / f"{seed}.csv"
                seed_rows[exported_columns].copy().to_csv(csv_path, index=False)
                combo_paths.append(str(csv_path))

    return written
179 |
180 |
def get_dataset_stats(cache_dir: Optional[str] = None) -> pd.DataFrame:
    """
    Summarise every dataset/task/GNN combination in RDB2G-Bench.

    Loads the ``train`` split of ``kaistdata/RDB2G-Bench`` and groups it by
    ``dataset``, ``task`` and ``gnn``. Values are rounded to 4 decimals.

    Args:
        cache_dir (Optional[str]): Hugging Face datasets cache directory.
            If None, the default HF cache location is used.

    Returns:
        pd.DataFrame: One row per combination with columns:

        - dataset: Dataset name
        - task: Task name
        - gnn: GNN model name
        - idx: Number of distinct graph configurations
        - seed: Number of distinct random seeds
        - test_metric_mean: Mean test performance
        - test_metric_std: Standard deviation of test performance
        - test_metric_min: Minimum test performance
        - test_metric_max: Maximum test performance

    Example (values illustrative):
        >>> stats = get_dataset_stats()
        >>> print(stats.head())
        dataset     task         gnn        idx  seed  test_metric_mean  test_metric_std  ...
        rel-f1      driver-top3  GraphSAGE  50   10    0.8542            0.0123           ...
    """
    results = load_dataset(
        "kaistdata/RDB2G-Bench",
        cache_dir=cache_dir,
        split="train"
    ).to_pandas()

    aggregations = {
        'idx': 'nunique',
        'seed': 'nunique',
        'test_metric': ['mean', 'std', 'min', 'max'],
    }
    summary = results.groupby(['dataset', 'task', 'gnn']).agg(aggregations).round(4)

    # Flatten the (column, statistic) header, e.g. ('test_metric', 'mean') -> 'test_metric_mean'.
    summary.columns = [f"{column}_{statistic}".strip() for column, statistic in summary.columns]

    renamed = summary.rename(columns={
        'idx_nunique': 'idx',
        'seed_nunique': 'seed',
    })
    return renamed.reset_index()
235 |
--------------------------------------------------------------------------------
/rdb2g_bench/benchmark/baselines/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional, Tuple
3 | from ..dataset import PerformancePredictionDataset
4 | import numpy as np
5 |
def calculate_overall_rank(selected_actual_y: float,
                           overall_actual_y: torch.Tensor,
                           higher_is_better: bool):
    """
    Rank a performance value against a population of performances.

    The rank is the number of population values at least as good as
    ``selected_actual_y``, with ties counted. A unique best value therefore
    has rank 1, and tied values share the worst of their positions.

    Args:
        selected_actual_y (float): The performance score to rank.
        overall_actual_y (torch.Tensor): All performance values to rank against.
        higher_is_better (bool): Whether larger values are better.

    Returns:
        Optional[Dict]: None if ``overall_actual_y`` is empty, otherwise:

            rank_position_overall (int): Rank as described above.
            percentile_overall (float): ``rank / total * 100``.
            total_samples_overall (int): Number of population values.
    """
    if overall_actual_y.numel() == 0:
        print("Warning: Overall actual y tensor is empty, cannot calculate overall rank.")
        return None

    # Drop singleton dimensions, but keep a lone value as a length-1 tensor.
    population = overall_actual_y.squeeze()
    if population.ndim == 0:
        population = population.unsqueeze(0)

    total = population.numel()
    if higher_is_better:
        at_least_as_good = population >= selected_actual_y
    else:
        at_least_as_good = population <= selected_actual_y
    rank_position = torch.sum(at_least_as_good).item()

    percentile = 0
    if total > 0:
        percentile = (rank_position / total) * 100

    return {
        "rank_position_overall": rank_position,
        "percentile_overall": percentile,
        "total_samples_overall": total,
    }
48 |
def get_performance_for_index(
    index: int,
    dataset: "PerformancePredictionDataset",
    performance_cache: dict
) -> Optional[float]:
    """
    Retrieves the performance for a given graph index, using a cache if available.

    Args:
        index (int): Graph index to retrieve performance for.
        dataset (PerformancePredictionDataset): Dataset whose
            ``df_result_group`` holds one row per ``idx``. The value is read
            from its ``target_col`` column.
        performance_cache (dict): Cache consulted before the dataset.

    Returns:
        Optional[float]: The cached or stored performance. Returns None if
        ``index`` is not in the dataset, the value is None/NaN, or the lookup
        raises (the error is printed).

    Note:
        ``performance_cache`` is only read here. A value found in the dataset
        is returned but not stored, so callers are responsible for populating
        the cache.
    """
    if index in performance_cache:
        return performance_cache[index]

    try:
        row = dataset.df_result_group[dataset.df_result_group['idx'] == index]
        if len(row) == 0:
            print(f"Warning: No performance data found for index {index}")
            return None

        performance = row.iloc[0][dataset.target_col]

        # np.floating also catches float32/float16 NaNs, which are not
        # Python floats (only np.float64 subclasses float).
        if performance is None or (isinstance(performance, (float, np.floating)) and np.isnan(performance)):
            print(f"Warning: NaN or None performance found for index {index}")
            return None

        return performance

    except Exception as e:
        # Deliberate best-effort: report and signal "no data" with None.
        print(f"Error retrieving performance for index {index}: {str(e)}")
        return None
86 |
def update_trajectory_and_best(
    index: int,
    perf: Optional[float],
    performance_cache: dict,
    initial_cache_size: int,
    total_evaluated_count: int,
    performance_trajectory: list,
    global_best_perf: float,
    global_best_index: int,
    higher_is_better: bool,
) -> Tuple[int, float, int]:
    """
    Updates trajectory and global best after a performance evaluation.

    For a valid (finite) ``perf``:

    - the evaluation count is incremented when ``performance_cache`` holds
      more than ``initial_cache_size`` entries (presumably so pre-seeded
      entries are not counted; confirm with callers);
    - ``perf`` becomes the global best if none exists yet
      (``global_best_index == -1``) or it strictly improves on it;
    - ``(count, best_perf)`` is appended to ``performance_trajectory``
      (mutated in place).

    An invalid ``perf`` (None, NaN or infinite) changes nothing. It is
    neither counted nor allowed to become the global best.

    Args:
        index (int): Index of the evaluated architecture.
        perf (Optional[float]): Performance value obtained.
        performance_cache (dict): Cache storing performance values.
        initial_cache_size (int): Initial size of the performance cache.
        total_evaluated_count (int): Current count of evaluated architectures.
        performance_trajectory (list): ``(count, best_perf)`` history, appended to.
        global_best_perf (float): Current best performance found.
        global_best_index (int): Index of the current best architecture, -1 if none.
        higher_is_better (bool): Whether higher values indicate better performance.

    Returns:
        Tuple[int, float, int]: ``(new_total_evaluated_count,
        new_global_best_perf, new_global_best_index)``.
    """
    if perf is None or not np.isfinite(perf):
        # Previously an invalid perf could become the global best (freezing
        # later comparisons on NaN), or raise TypeError comparing None.
        return total_evaluated_count, global_best_perf, global_best_index

    new_total_evaluated_count = total_evaluated_count
    if len(performance_cache) > initial_cache_size:
        new_total_evaluated_count += 1

    if global_best_index == -1:
        is_better = True
    elif higher_is_better:
        is_better = perf > global_best_perf
    else:
        is_better = perf < global_best_perf

    new_global_best_perf = global_best_perf
    new_global_best_index = global_best_index
    if is_better:
        new_global_best_perf = perf
        new_global_best_index = index

    performance_trajectory.append((new_total_evaluated_count, new_global_best_perf))

    return new_total_evaluated_count, new_global_best_perf, new_global_best_index
155 |
def pad_trajectory(
    performance_trajectory: list,
    total_evaluated_count: int,
    evaluation_budget: int,
    method_name: str = "Search"
) -> None:
    """Extend the trajectory in place up to ``evaluation_budget``.

    When the search stopped early, ``(step, last_best)`` entries are appended
    for each step from ``total_evaluated_count + 1`` through
    ``evaluation_budget``. Nothing is padded, and a warning is printed
    instead, if the trajectory is empty or its last best value is not finite.
    """
    if total_evaluated_count >= evaluation_budget:
        return

    if not performance_trajectory:
        print(f"Warning: {method_name} Search ended early ({total_evaluated_count}/{evaluation_budget}) but performance trajectory is empty. Cannot pad trajectory.")
        return

    final_best_perf_to_pad = performance_trajectory[-1][1]
    if not np.isfinite(final_best_perf_to_pad):
        print(f"Warning: {method_name} Search ended early ({total_evaluated_count}/{evaluation_budget}) but last best performance is not finite ({final_best_perf_to_pad}). Cannot pad trajectory.")
        return

    print(f"Padding {method_name} trajectory from {total_evaluated_count + 1} to {evaluation_budget} with final best perf: {final_best_perf_to_pad:.4f}")
    performance_trajectory.extend(
        (step, final_best_perf_to_pad)
        for step in range(total_evaluated_count + 1, evaluation_budget + 1)
    )
174 |
def calculate_evaluation_time(
    index: int,
    dataset: "PerformancePredictionDataset",
    time_cache: dict
) -> Optional[float]:
    """
    Estimate the total evaluation time of a graph over 20 training epochs.

    With ``epochs = 20``:

    - normally: ``eval_time = (train_time + valid_time) * epochs + test_time``;
    - if ``valid_time`` or ``test_time`` is missing (None/NaN):
      ``eval_time = train_time * epochs + train_time + train_time``.

    NOTE(review): the fallback is not the normal formula with the missing
    times replaced by ``train_time``, which would give
    ``2 * train_time * epochs + train_time``. Confirm which estimate is
    intended.

    Args:
        index (int): Graph index to calculate evaluation time for.
        dataset (PerformancePredictionDataset): Dataset whose
            ``df_result_group`` holds ``train_time``/``valid_time``/``test_time``
            per ``idx``.
        time_cache (dict): Cache of computed times. Successful results are
            stored here.

    Returns:
        Optional[float]: Total evaluation time. Returns None if ``index`` is
        missing, ``train_time`` is None/NaN, or the lookup raises (the error
        is printed).
    """
    epochs = 20

    def _is_missing(value) -> bool:
        # np.floating also catches float32/float16 NaNs, which are not
        # Python floats (only np.float64 subclasses float).
        return value is None or (isinstance(value, (float, np.floating)) and np.isnan(value))

    if index in time_cache:
        return time_cache[index]

    try:
        row = dataset.df_result_group[dataset.df_result_group['idx'] == index]
        if len(row) == 0:
            print(f"Warning: No time data found for index {index}")
            return None

        record = row.iloc[0]
        train_time = record['train_time']

        if _is_missing(train_time):
            print(f"Warning: NaN or None train_time found for index {index}")
            return None

        valid_time = record['valid_time']
        test_time = record['test_time']

        if _is_missing(valid_time) or _is_missing(test_time):
            print(f"Warning: NaN or None valid_time or test_time found for index {index}")
            eval_time = train_time * epochs + train_time + train_time
        else:
            eval_time = (train_time + valid_time) * epochs + test_time

        time_cache[index] = eval_time

        return eval_time

    except Exception as e:
        # Deliberate best-effort: report and signal "no data" with None.
        print(f"Error calculating evaluation time for index {index}: {str(e)}")
        return None
--------------------------------------------------------------------------------
/rdb2g_bench/common/search_space/gnn_search_space.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Tuple
2 |
3 | import copy
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from torch_geometric.data import HeteroData
8 |
class GNNSearchSpace():
    """Base search space over subsets of a fixed list of edge types.

    A candidate graph is encoded as a binary vector aligned with
    ``edge_types``: entry ``i`` is 1 when ``edge_types[i]`` is selected.
    Subclasses set ``reachable_nodes`` (the node types reachability starts
    from) and implement :meth:`is_valid`.

    Args:
        dataset (str): Dataset name.
        task (str): Task name.
        num_layers (int): Maximum number of hops explored by
            :meth:`filter_unreachable`.
        node_types (List[str]): Node types of the graph.
        edge_types (List[Tuple[str, str, str]]): Candidate ``(src, rel, dst)``
            edge types.
        src_entity_table (str): Source entity table of the task.
        dst_entity_table (str, optional): Destination entity table, if any.
    """

    def __init__(
        self,
        dataset: str,
        task: str,
        num_layers: int,
        node_types: List[str],
        edge_types: List[Tuple[str, str, str]],
        src_entity_table: str,
        dst_entity_table: str = None,
    ):
        self.dataset = dataset
        self.task = task
        self.src_entity_table = src_entity_table
        self.dst_entity_table = dst_entity_table
        self.num_layers = num_layers
        self.node_types = node_types
        self.edge_types = edge_types
        self.reachable_nodes = set()

    def filter_unreachable(self, edges):
        """Return a copy of ``edges`` with unusable selected edges set to 0.

        Layer 0 is ``reachable_nodes``. Layer ``k + 1`` holds every node type
        joined, by a selected edge, to a node type reached in layers
        ``0..k``. At most ``num_layers`` layers are built, stopping early
        when a layer is empty. A selected edge is kept only if it joins some
        layer ``k`` to layer ``k + 1``, in either direction.
        """
        reachable_nodes = self.reachable_nodes.copy()
        layer_reachable = [reachable_nodes.copy()]

        for _ in range(self.num_layers):
            new_reachable = set()
            for i, is_selected in enumerate(edges):
                if is_selected == 1:
                    src, _, dst = self.edge_types[i]
                    if src in reachable_nodes:
                        new_reachable.add(dst)
                    if dst in reachable_nodes:
                        new_reachable.add(src)

            reachable_nodes.update(new_reachable)
            layer_reachable.append(new_reachable)

            if not new_reachable:
                break

        filtered_edges = np.copy(edges)
        for i, is_selected in enumerate(edges):
            if is_selected == 1:
                src, _, dst = self.edge_types[i]
                # Check connection between consecutive layers
                for layer in range(len(layer_reachable) - 1):
                    if (src in layer_reachable[layer] and dst in layer_reachable[layer + 1]) or \
                       (dst in layer_reachable[layer] and src in layer_reachable[layer + 1]):
                        break
                else:
                    filtered_edges[i] = 0

        return filtered_edges

    def is_valid(self, edges) -> bool:
        """Return whether ``edges`` is an acceptable graph; defined by subclasses."""
        raise NotImplementedError

    def is_possible(self, edges) -> bool:
        """Return False if a selected ``r2e_<x>`` edge coexists with a selected edge touching ``x``.

        An ``r2e_<x>`` edge replaces node type ``x``, so ``x`` cannot also
        take part in other selected edges.
        """
        for i, is_selected in enumerate(edges):
            if is_selected == 1:
                _, rel, _ = self.edge_types[i]
                if rel.startswith('r2e'):
                    rel = rel[4:]
                    for j, edge in enumerate(self.edge_types):
                        if edge[0] == rel or edge[2] == rel:
                            if edges[j] == 1:
                                return False
        return True

    def generate_all_graphs(self):
        """Enumerate every edge subset and return the distinct valid graphs.

        Each of the ``2 ** len(edge_types)`` subsets that passes
        :meth:`is_possible` is pruned with :meth:`filter_unreachable`. Pruned
        results accepted by :meth:`is_valid` are collected as tuples and
        returned in sorted order.
        """
        num_edge_types = len(self.edge_types)
        possible_graphs = set()
        for mask in range(2 ** num_edge_types):
            # Create binary array of selected edges (1 = included, 0 = excluded)
            selected_edges = np.zeros(num_edge_types, dtype=int)
            for i in range(num_edge_types):
                if mask & (1 << i):
                    selected_edges[i] = 1
            if self.is_possible(selected_edges):
                filtered_edges = self.filter_unreachable(selected_edges)
                key = tuple(filtered_edges)
                if key not in possible_graphs and self.is_valid(filtered_edges):
                    possible_graphs.add(key)

        return sorted(list(possible_graphs))

    def get_data(self, edges: np.ndarray, data: "HeteroData") -> "HeteroData":
        """Build a HeteroData restricted to the selected edge types.

        Included node types are the endpoints of every selected edge type,
        plus ``x`` for each selected ``r2e_<x>``, when present in ``data``.
        Each selected edge type present in ``data`` is copied together with
        its ``rev_`` counterpart. Attribute values are shared with ``data``,
        not copied.

        Raises:
            AssertionError: If a selected edge type's reverse edge type is
                missing from ``data`` (only when assertions are enabled).
        """
        new_data = HeteroData()

        involved_nodes = set()
        for i, is_selected in enumerate(edges):
            if is_selected:
                src, rel, dst = self.edge_types[i]
                involved_nodes.add(src)
                involved_nodes.add(dst)
                if rel.startswith("r2e"):
                    involved_nodes.add(rel[4:])

        for node in involved_nodes:
            if node in data.node_types:
                for attr_name, attr_value in data[node].items():
                    new_data[node][attr_name] = attr_value

        for i, is_selected in enumerate(edges):
            if is_selected:
                edge_type = self.edge_types[i]
                if edge_type in data.edge_types:
                    for attr_name, attr_value in data[edge_type].items():
                        new_data[edge_type][attr_name] = attr_value
                    rev_edge_type = (edge_type[2], 'rev_' + edge_type[1], edge_type[0])
                    assert rev_edge_type in data.edge_types
                    for attr_name, attr_value in data[rev_edge_type].items():
                        new_data[rev_edge_type][attr_name] = attr_value

        return new_data

    def __repr__(self) -> str:
        # Previously read the nonexistent `self.matrix` (AttributeError) and
        # reported a stale "BasicSearchSpace" name.
        return (f"{type(self).__name__}(dataset={self.dataset}, "
                f"task={self.task}, "
                f"src_entity_table={self.src_entity_table}, "
                f"node_types={self.node_types})")
131 |
132 |
class GNNNodeSearchSpace(GNNSearchSpace):
    """Search space for node-level tasks on ``src_entity_table``.

    Reachability is seeded from ``src_entity_table`` alone. A candidate graph
    is valid only if some selected edge type has that table as an endpoint.

    Raises:
        ValueError: If ``dst_entity_table`` is given.
    """

    def __init__(
        self,
        dataset: str,
        task: str,
        num_layers: int,
        node_types: List[str],
        edge_types: List[Tuple[str, str, str]],
        src_entity_table: str,
        dst_entity_table: str = None,
    ):
        if dst_entity_table is not None:
            raise ValueError("dst_entity_table must be None for GNNNodeSearchSpace")
        super().__init__(dataset, task, num_layers, node_types, edge_types,
                         src_entity_table, dst_entity_table)
        self.reachable_nodes = {self.src_entity_table}

    def filter_unreachable(self, edges: torch.Tensor) -> torch.Tensor:
        """Apply the base-class reachability filter."""
        return super().filter_unreachable(edges)

    def is_valid(self, edges: torch.Tensor) -> bool:
        """Return True when a selected edge type touches ``src_entity_table``."""
        target = self.src_entity_table
        return any(
            is_selected == 1 and target in (self.edge_types[i][0], self.edge_types[i][2])
            for i, is_selected in enumerate(edges)
        )

    def __repr__(self) -> str:
        parts = [
            f"dataset={self.dataset}",
            f"task={self.task}",
            f"src_entity_table={self.src_entity_table}",
            f"node_types={self.node_types}",
        ]
        return "GNNNodeSearchSpace(" + ", ".join(parts) + ")"
174 |
175 |
class GNNLinkSearchSpace(GNNSearchSpace):
    """Search space for link tasks between two entity tables.

    Reachability is seeded from both ``src_entity_table`` and
    ``dst_entity_table``. A candidate graph is valid only if each of them is
    an endpoint of at least one selected edge type.
    """

    def __init__(
        self,
        dataset: str,
        task: str,
        num_layers: int,
        node_types: List[str],
        edge_types: List[Tuple[str, str, str]],
        src_entity_table: str,
        dst_entity_table: str,
    ):
        super().__init__(dataset, task, num_layers, node_types, edge_types,
                         src_entity_table, dst_entity_table)
        self.reachable_nodes = {self.src_entity_table, self.dst_entity_table}

    def filter_unreachable(self, edges: np.ndarray) -> np.ndarray:
        """Apply the base-class reachability filter."""
        return super().filter_unreachable(edges)

    def is_valid(self, edges: np.ndarray) -> bool:
        """Return True when both entity tables touch some selected edge type."""
        found_src = found_dst = False
        for position, flag in enumerate(edges):
            if flag != 1:
                continue
            head, _, tail = self.edge_types[position]
            found_src = found_src or self.src_entity_table in (head, tail)
            found_dst = found_dst or self.dst_entity_table in (head, tail)
            if found_src and found_dst:
                return True
        return False

    def __repr__(self) -> str:
        parts = [
            f"dataset={self.dataset}",
            f"task={self.task}",
            f"src_entity_table={self.src_entity_table}",
            f"dst_entity_table={self.dst_entity_table}",
            f"node_types={self.node_types}",
        ]
        return "GNNLinkSearchSpace(" + ", ".join(parts) + ")"
220 |
221 |
class IDGNNLinkSearchSpace(GNNSearchSpace):
    """GNN search space for ID-aware link prediction.

    Only ``src_entity_table`` is registered in ``reachable_nodes``. An edge
    set is valid iff:

    * both ``src_entity_table`` and ``dst_entity_table`` are endpoints of at
      least one selected edge type, and
    * ``dst_entity_table`` is connected to ``src_entity_table`` through the
      selected edge types, traversed in either direction.
    """

    def __init__(
        self,
        dataset: str,
        task: str,
        num_layers: int,
        node_types: List[str],
        edge_types: List[Tuple[str, str, str]],
        src_entity_table: str,
        dst_entity_table: str,
    ):
        super().__init__(dataset, task, num_layers, node_types, edge_types, src_entity_table, dst_entity_table)
        self.reachable_nodes = {self.src_entity_table}

    def filter_unreachable(self, edges: np.ndarray) -> np.ndarray:
        """Delegate edge filtering to the parent search space unchanged."""
        return super().filter_unreachable(edges)

    def is_valid(self, edges: np.ndarray) -> bool:
        """Check endpoint coverage and src -> dst connectivity of an edge set.

        Args:
            edges: Selection flags aligned with ``self.edge_types``; an entry
                equal to 1 marks the edge type as selected.

        Returns:
            True if both entity tables touch a selected edge and
            ``dst_entity_table`` is reachable from ``src_entity_table`` over
            the selected edges (undirected); False otherwise.
        """
        selected = [
            self.edge_types[i]
            for i, is_selected in enumerate(edges)
            if is_selected == 1
        ]

        endpoints = set()
        for src, _, dst in selected:
            endpoints.add(src)
            endpoints.add(dst)
        # If source or destination entity has no edges, the graph is invalid.
        if self.src_entity_table not in endpoints or self.dst_entity_table not in endpoints:
            return False

        # Breadth-first search from src over the selected edges. Only tables
        # discovered in the current round form the next frontier, so the
        # loop ends at the fixed point (or as soon as dst is reached).
        reachable = {self.src_entity_table}
        frontier = {self.src_entity_table}
        while frontier and self.dst_entity_table not in reachable:
            discovered = set()
            for src, _, dst in selected:
                if src in frontier and dst not in reachable:
                    discovered.add(dst)
                if dst in frontier and src not in reachable:
                    discovered.add(src)
            reachable |= discovered
            frontier = discovered

        return self.dst_entity_table in reachable

    def __repr__(self) -> str:
        return (f"IDGNNLinkSearchSpace(dataset={self.dataset}, "
                f"task={self.task}, "
                f"src_entity_table={self.src_entity_table}, "
                f"dst_entity_table={self.dst_entity_table}, "
                f"node_types={self.node_types})")
290 |
--------------------------------------------------------------------------------
/rdb2g_bench/benchmark/llm/llm_micro_action.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Optional, Type, Union, List
2 | from torch_geometric.data import HeteroData
3 |
4 | from ...common.search_space.search_space import TotalSearchSpace
5 | from ...common.search_space.gnn_search_space import GNNNodeSearchSpace, GNNLinkSearchSpace, IDGNNLinkSearchSpace
6 | from ..micro_action import MicroActionSet
7 |
class LLMMicroActionSet(MicroActionSet):
    """Micro actions addressed by table / column names, for LLM-driven search.

    The single-step actions (``add_fk_pk_edge``, ``remove_fk_pk_edge``,
    ``convert_row_to_edge``, ``convert_edge_to_row``) return a triple
    ``(new_edge_set, graph_idx, error_msg)``:

    * success: the new edge set, its index in ``self.valid_edge_sets_list``
      and ``""``;
    * failure: ``current_edge_set`` unchanged, ``-1`` and a human-readable
      error message for the caller.

    The ``get_possible_*`` methods list every valid edge set reachable from
    ``current_edge_set`` with one action of the given kind.

    Relies on attributes initialised by ``MicroActionSet`` (not visible
    here): ``full_edges``, ``valid_edge_sets``, ``valid_edge_sets_list``,
    ``fk_pk_indices``, ``r2e_to_f2p_map`` and ``f2p_pair_to_r2e_map``.
    """

    def __init__(self,
                 dataset: str,
                 task: str,
                 hetero_data: HeteroData,
                 GNNSpaceClass: Type[Union[GNNNodeSearchSpace, GNNLinkSearchSpace, IDGNNLinkSearchSpace]],
                 num_layers: int,
                 src_entity_table: str,
                 dst_entity_table: Optional[str] = None):
        super().__init__(dataset, task, hetero_data, GNNSpaceClass, num_layers, src_entity_table, dst_entity_table)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _llm_set_positions(edge_set, assignments) -> Tuple[int, ...]:
        """Return ``edge_set`` as a tuple with ``(index, value)`` assignments applied in order."""
        updated = list(edge_set)
        for index, value in assignments:
            updated[index] = value
        return tuple(updated)

    def _llm_resolve(self, current_edge_set, new_edge_set, invalid_msg):
        """Accept ``new_edge_set`` if it is a valid edge set, else keep ``current_edge_set``."""
        if new_edge_set in self.valid_edge_sets:
            return new_edge_set, self.valid_edge_sets_list.index(new_edge_set), ""
        return current_edge_set, -1, invalid_msg

    def _llm_find_r2e_index(self, table_1_name, table_2_name, edge_table_name):
        """Index in ``self.full_edges`` of the r2e edge between the two tables, or None.

        Tries the ``(table_1, table_2)`` orientation first, then the reverse.
        """
        relation = f"r2e_{edge_table_name}"
        for candidate in ((table_1_name, relation, table_2_name),
                          (table_2_name, relation, table_1_name)):
            if candidate in self.full_edges:
                return self.full_edges.index(candidate)
        return None

    # ------------------------------------------------------- single-step actions

    def add_fk_pk_edge(self,
                       current_edge_set: Tuple[int, ...],
                       from_table_name: str,
                       from_col_name: str,
                       to_table_name: str,
                       ) -> Tuple[Tuple[int, ...], int, str]:
        """Enable the fk-pk edge ``(from_table_name, f2p_<from_col_name>, to_table_name)``.

        Fails if the edge type does not exist, is already enabled, or the
        resulting edge set is not a valid edge set.
        """
        edge_to_add = (from_table_name, f"f2p_{from_col_name}", to_table_name)
        if edge_to_add not in self.full_edges:
            return current_edge_set, -1, f"Given edge type({edge_to_add}) between {from_table_name} and {to_table_name} is an invalid edge type."
        edge_index = self.full_edges.index(edge_to_add)
        if current_edge_set[edge_index] != 0:
            return current_edge_set, -1, f"Given edge type({edge_to_add}) between {from_table_name} and {to_table_name} is already connected."
        new_edge_set = self._llm_set_positions(current_edge_set, [(edge_index, 1)])
        return self._llm_resolve(current_edge_set, new_edge_set, "Given add_fk_pk_edge action is not valid.")

    def remove_fk_pk_edge(self,
                          current_edge_set: Tuple[int, ...],
                          from_table_name: str,
                          from_col_name: str,
                          to_table_name: str,
                          ) -> Tuple[Tuple[int, ...], int, str]:
        """Disable the fk-pk edge ``(from_table_name, f2p_<from_col_name>, to_table_name)``.

        Fails if the edge type does not exist, is not enabled, or the
        resulting edge set is not a valid edge set.
        """
        edge_to_remove = (from_table_name, f"f2p_{from_col_name}", to_table_name)
        if edge_to_remove not in self.full_edges:
            return current_edge_set, -1, f"Given edge type({edge_to_remove}) between {from_table_name} and {to_table_name} is an invalid edge type."
        edge_index = self.full_edges.index(edge_to_remove)
        if current_edge_set[edge_index] != 1:
            return current_edge_set, -1, f"Given edge type({edge_to_remove}) between {from_table_name} and {to_table_name} is not connected."
        new_edge_set = self._llm_set_positions(current_edge_set, [(edge_index, 0)])
        return self._llm_resolve(current_edge_set, new_edge_set, "Given remove_fk_pk_edge action is not valid.")

    def convert_row_to_edge(self,
                            current_edge_set: Tuple[int, ...],
                            table_1_name: str,
                            table_2_name: str,
                            edge_table_name: str
                            ) -> Tuple[Tuple[int, ...], int, str]:
        """Replace the two f2p edges of ``edge_table_name`` with one r2e edge.

        Requires the r2e edge to be disabled and both of its f2p edges (from
        ``self.r2e_to_f2p_map``) to be enabled.
        """
        convert_edge_index = self._llm_find_r2e_index(table_1_name, table_2_name, edge_table_name)
        if convert_edge_index is None:
            return current_edge_set, -1, f"Given edge type({edge_table_name}) between {table_1_name} and {table_2_name} is an invalid edge type."
        if current_edge_set[convert_edge_index] != 0:
            return current_edge_set, -1, f"Given edge type({edge_table_name}) between {table_1_name} and {table_2_name} is already converted to edge."

        invalid_msg = "Given convert_row_to_edge action is not valid."
        f2p_indices = self.r2e_to_f2p_map.get(convert_edge_index)
        if not f2p_indices or f2p_indices[0] is None or f2p_indices[1] is None:
            return current_edge_set, -1, invalid_msg
        f2p_idx1, f2p_idx2 = f2p_indices
        if current_edge_set[f2p_idx1] != 1 or current_edge_set[f2p_idx2] != 1:
            return current_edge_set, -1, invalid_msg

        new_edge_set = self._llm_set_positions(
            current_edge_set, [(convert_edge_index, 1), (f2p_idx1, 0), (f2p_idx2, 0)]
        )
        return self._llm_resolve(current_edge_set, new_edge_set, invalid_msg)

    def convert_edge_to_row(self,
                            current_edge_set: Tuple[int, ...],
                            table_1_name: str,
                            table_2_name: str,
                            edge_table_name: str
                            ) -> Tuple[Tuple[int, ...], int, str]:
        """Replace the r2e edge of ``edge_table_name`` with its two f2p edges.

        Requires the r2e edge to be enabled and both of its f2p edges (from
        ``self.r2e_to_f2p_map``) to be disabled.
        """
        convert_edge_index = self._llm_find_r2e_index(table_1_name, table_2_name, edge_table_name)
        if convert_edge_index is None:
            return current_edge_set, -1, f"Given edge type({edge_table_name}) between {table_1_name} and {table_2_name} is an invalid edge type."
        if current_edge_set[convert_edge_index] != 1:
            return current_edge_set, -1, f"Given edge type({edge_table_name}) between {table_1_name} and {table_2_name} is not converted to edge."

        invalid_msg = "Given convert_edge_to_row action is not valid."
        f2p_indices = self.r2e_to_f2p_map.get(convert_edge_index)
        if not f2p_indices or f2p_indices[0] is None or f2p_indices[1] is None:
            return current_edge_set, -1, invalid_msg
        f2p_idx1, f2p_idx2 = f2p_indices
        if current_edge_set[f2p_idx1] != 0 or current_edge_set[f2p_idx2] != 0:
            return current_edge_set, -1, invalid_msg

        new_edge_set = self._llm_set_positions(
            current_edge_set, [(convert_edge_index, 0), (f2p_idx1, 1), (f2p_idx2, 1)]
        )
        return self._llm_resolve(current_edge_set, new_edge_set, invalid_msg)

    # ------------------------------------------------------ action enumeration

    def get_possible_add_fk_pk_edge(self,
                                    current_edge_set: Tuple[int, ...]
                                    ) -> List[Tuple[int, ...]]:
        """List valid edge sets obtained by enabling one currently disabled fk-pk edge."""
        possible_next_sets = []
        for edge_index_to_add in self.fk_pk_indices:
            if current_edge_set[edge_index_to_add] == 0:
                new_edge_set = self._llm_set_positions(current_edge_set, [(edge_index_to_add, 1)])
                if new_edge_set in self.valid_edge_sets:
                    # Consistency check between the lookup set and the indexed list.
                    assert new_edge_set in self.valid_edge_sets_list, f"Edge set {new_edge_set} in set but not in list!"
                    possible_next_sets.append(new_edge_set)
        return possible_next_sets

    def get_possible_remove_fk_pk_edge(self,
                                       current_edge_set: Tuple[int, ...]
                                       ) -> List[Tuple[int, ...]]:
        """List valid edge sets obtained by disabling one currently enabled fk-pk edge."""
        possible_next_sets = []
        for edge_index_to_remove in self.fk_pk_indices:
            if current_edge_set[edge_index_to_remove] == 1:
                new_edge_set = self._llm_set_positions(current_edge_set, [(edge_index_to_remove, 0)])
                if new_edge_set in self.valid_edge_sets:
                    assert new_edge_set in self.valid_edge_sets_list, f"Edge set {new_edge_set} in set but not in list!"
                    possible_next_sets.append(new_edge_set)
        return possible_next_sets

    def get_possible_convert_row_to_edge(self,
                                         current_edge_set: Tuple[int, ...]
                                         ) -> List[Tuple[int, ...]]:
        """List valid edge sets obtained by one row-to-edge conversion."""
        possible_next_sets = []
        for (f2p_idx1, f2p_idx2), r2e_idx in self.f2p_pair_to_r2e_map.items():
            if current_edge_set[r2e_idx] == 0 and current_edge_set[f2p_idx1] == 1 and current_edge_set[f2p_idx2] == 1:
                new_edge_set = self._llm_set_positions(
                    current_edge_set, [(r2e_idx, 1), (f2p_idx1, 0), (f2p_idx2, 0)]
                )
                if new_edge_set in self.valid_edge_sets:
                    possible_next_sets.append(new_edge_set)
        return possible_next_sets

    def get_possible_convert_edge_to_row(self,
                                         current_edge_set: Tuple[int, ...]
                                         ) -> List[Tuple[int, ...]]:
        """List valid edge sets obtained by one edge-to-row conversion."""
        possible_next_sets = []
        for (f2p_idx1, f2p_idx2), r2e_idx in self.f2p_pair_to_r2e_map.items():
            if current_edge_set[f2p_idx1] == 0 and current_edge_set[f2p_idx2] == 0 and current_edge_set[r2e_idx] == 1:
                new_edge_set = self._llm_set_positions(
                    current_edge_set, [(f2p_idx1, 1), (f2p_idx2, 1), (r2e_idx, 0)]
                )
                if new_edge_set in self.valid_edge_sets:
                    possible_next_sets.append(new_edge_set)
        return possible_next_sets
--------------------------------------------------------------------------------
/rdb2g_bench/benchmark/baselines/ea.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 | import time
5 | from typing import Dict, Union, List, Optional
6 |
7 | from ..dataset import PerformancePredictionDataset
8 | from ..micro_action import MicroActionSet
9 | from .utils import calculate_overall_rank, get_performance_for_index, update_trajectory_and_best, pad_trajectory, calculate_evaluation_time
10 |
def evolutionary_heuristic_analysis(
    dataset: PerformancePredictionDataset,
    micro_action_set: MicroActionSet,
    overall_actual_y: torch.Tensor,
    higher_is_better: bool,
    termination_threshold_ratio: float,
    method_name: str = "Evolutionary Heuristic",
    population_size: int = 10,
    tournament_size: int = 10,
    max_iterations: int = 1000,
):
    """
    Perform Neural Architecture Search using Evolutionary Algorithm.

    This function implements a complete evolutionary algorithm for finding optimal
    graph neural network architectures. It maintains a population of architectures,
    applies micro action-based mutations, and uses tournament selection for evolution.

    Each generation samples up to ``tournament_size`` population members,
    mutates every sampled member with one randomly chosen applicable micro
    action, evaluates the children, and replaces the oldest population member
    with the best child. The search stops after ``max_iterations`` generations
    or once the number of unique evaluated architectures reaches
    ``int(termination_threshold_ratio * num_valid_graphs)``.

    Args:
        dataset (PerformancePredictionDataset): Dataset containing architecture
            performance data
        micro_action_set (MicroActionSet): Set of micro actions for architecture
            space exploration
        overall_actual_y (torch.Tensor): Complete performance tensor for
            ranking calculations
        higher_is_better (bool): Whether higher performance values are better
        termination_threshold_ratio (float): Fraction of total architectures to
            evaluate as budget
        method_name (str): Name identifier for this method.
            Defaults to "Evolutionary Heuristic".
        population_size (int): Number of individuals in the population.
            Defaults to 10.
        tournament_size (int): Number of individuals selected for tournament.
            Defaults to 10.
        max_iterations (int): Maximum number of evolutionary generations.
            Defaults to 1000.

    Returns:
        Dict[str, Union[str, int, float, List, Optional[int]]]: Dictionary containing search results and performance metrics.

        - method (str): Method name
        - selected_graph_id (int): Index of best found architecture (-1 if none was evaluated)
        - actual_y_perf_of_selected (float): Performance of selected architecture
        - selection_metric_value (float): Metric value used for selection
        - selected_graph_origin (str): Origin method name
        - discovered_count (int): Number of architectures evaluated
        - total_iterations_run (int): Number of generations completed
        - rank_position_overall (float): Rank among all architectures
        - percentile_overall (float): Percentile ranking
        - total_samples_overall (int): Total available architectures
        - performance_trajectory (List): Performance over time
        - total_evaluation_time (float): Time spent on evaluations
        - total_run_time (float): Total algorithm runtime

    Example:
        >>> results = evolutionary_heuristic_analysis(
        ...     dataset=dataset,
        ...     micro_action_set=micro_actions,
        ...     overall_actual_y=y_tensor,
        ...     higher_is_better=True,
        ...     termination_threshold_ratio=0.05,
        ...     population_size=20,
        ...     tournament_size=5,
        ...     max_iterations=100
        ... )
        >>> print(f"Best architecture: {results['selected_graph_id']}")
        >>> print(f"Performance: {results['actual_y_perf_of_selected']:.4f}")
    """
    population = []  # (index into valid_edge_sets_list, age) pairs
    seen_indices = set()
    performance_cache = {}
    time_cache = {}
    performance_trajectory = []
    num_total_valid_graphs = len(micro_action_set.valid_edge_sets_list)
    discovered_count = 0
    best_perf_so_far = float('-inf') if higher_is_better else float('inf')
    best_index_so_far = -1
    final_iteration_count = 0
    total_evaluated_count = 0
    total_evaluation_time = 0.0

    start_time = time.time()

    # ---- Initial population: distinct random graphs, all with age 0. ----
    if num_total_valid_graphs == 0:
        print("Error: No valid graphs found in MicroActionSet. Cannot initialize population.")
        initial_indices = []
    elif population_size <= 0:
        print("Warning: Cannot initialize population with size <= 0.")
        initial_indices = []
    else:
        # Warn before clamping so an oversized request is actually reported.
        if population_size > num_total_valid_graphs:
            print(f"Warning: Requested population size ({population_size}) exceeds number of valid graphs ({num_total_valid_graphs}). Using {num_total_valid_graphs}.")
        pop_size_to_init = min(population_size, num_total_valid_graphs)
        initial_indices = np.random.choice(
            num_total_valid_graphs,
            size=pop_size_to_init,
            replace=False
        )

    print(f"Initializing population with {len(initial_indices)} individuals...")
    for index in initial_indices:
        population.append((index, 0))
        if index not in seen_indices:
            seen_indices.add(index)
            discovered_count += 1

        initial_cache_size = len(performance_cache)
        perf = get_performance_for_index(index, dataset, performance_cache)

        if perf is not None:
            performance_cache[index] = perf

        # Evaluation time is only charged for graphs not seen in the cache before.
        eval_time = calculate_evaluation_time(index, dataset, time_cache)
        if eval_time is not None:
            if len(performance_cache) > initial_cache_size:
                total_evaluation_time += eval_time

        total_evaluated_count, best_perf_so_far, best_index_so_far = \
            update_trajectory_and_best(
                index, perf, performance_cache, initial_cache_size,
                total_evaluated_count, performance_trajectory,
                best_perf_so_far, best_index_so_far, higher_is_better
            )

    print(f"After initialization ({total_evaluated_count} evals): Best perf={best_perf_so_far:.4f}")

    termination_count_threshold = int(termination_threshold_ratio * num_total_valid_graphs)
    print(f"Evolution budget set to {termination_count_threshold} unique evaluations.")

    # ---- Evolution loop. ----
    for iteration in range(max_iterations):
        final_iteration_count = iteration + 1
        if not population:
            break

        current_pop_size = len(population)
        actual_tournament_size = min(tournament_size, current_pop_size)
        if actual_tournament_size <= 0:
            break

        tournament_candidate_indices = np.random.choice(
            current_pop_size,
            size=actual_tournament_size,
            replace=False
        )

        # Mutate every tournament member once.
        mutated_indices_this_step = []
        for pop_list_idx in tournament_candidate_indices:
            current_index, _ = population[pop_list_idx]
            current_edge_set = micro_action_set.valid_edge_sets_list[current_index]

            # Each action returns a list of (edge_set, index) candidates, per
            # the unpacking below.
            action_fns = [
                micro_action_set.add_fk_pk_edge,
                micro_action_set.remove_fk_pk_edge,
                micro_action_set.convert_row_to_edge,
                micro_action_set.convert_edge_to_row,
            ]
            random.shuffle(action_fns)

            mutated_index = current_index  # kept when no action applies
            for chosen_action_fn in action_fns:
                possible_next_states = chosen_action_fn(current_edge_set)

                if possible_next_states:
                    _, next_index = random.choice(possible_next_states)
                    mutated_index = next_index

                    if mutated_index not in seen_indices:
                        seen_indices.add(mutated_index)
                        discovered_count += 1
                    # NOTE(review): stops at the first action kind offering any
                    # mutation -- confirm this matches the intended nesting of `break`.
                    break

            mutated_indices_this_step.append(mutated_index)

        # Evaluate the children; stop as soon as the unique-evaluation budget is hit.
        valid_tournament_results = []  # (perf, index) for children with a known performance
        terminated_early = False
        for index in mutated_indices_this_step:
            initial_cache_size = len(performance_cache)
            perf = get_performance_for_index(index, dataset, performance_cache)

            if perf is not None:
                valid_tournament_results.append((perf, index))
                # NOTE(review): cached only for non-None perf, mirroring the
                # initialization phase -- confirm intended nesting.
                performance_cache[index] = perf

            eval_time = calculate_evaluation_time(index, dataset, time_cache)
            if eval_time is not None:
                if len(performance_cache) > initial_cache_size:
                    total_evaluation_time += eval_time

            new_total_evaluated_count, new_best_perf_so_far, new_best_index_so_far = \
                update_trajectory_and_best(
                    index, perf, performance_cache, initial_cache_size,
                    total_evaluated_count, performance_trajectory,
                    best_perf_so_far, best_index_so_far, higher_is_better
                )

            if len(performance_cache) > initial_cache_size:
                total_evaluated_count = new_total_evaluated_count
                best_perf_so_far = new_best_perf_so_far
                best_index_so_far = new_best_index_so_far

                if total_evaluated_count >= termination_count_threshold:
                    print(f"Termination condition met at iteration {iteration+1}: Evaluated {total_evaluated_count} unique graphs >= threshold {termination_count_threshold}")
                    terminated_early = True
                    break
            else:
                best_perf_so_far = new_best_perf_so_far
                best_index_so_far = new_best_index_so_far

        if terminated_early:
            break

        if not valid_tournament_results:
            print(f"Warning: Iteration {iteration+1} - No valid performances in tournament. Skipping replacement.")
            continue

        if higher_is_better:
            best_tournament_perf, best_tournament_index = max(valid_tournament_results, key=lambda item: item[0])
        else:
            best_tournament_perf, best_tournament_index = min(valid_tournament_results, key=lambda item: item[0])

        # Replace the oldest member (first one on ties) with the best child,
        # then age every other member by one generation.
        oldest_pop_list_idx = -1
        max_age = -1
        for i, (_, age) in enumerate(population):
            if age > max_age:
                max_age = age
                oldest_pop_list_idx = i

        if oldest_pop_list_idx != -1:
            population[oldest_pop_list_idx] = (best_tournament_index, 0)

        for i in range(len(population)):
            if i != oldest_pop_list_idx:
                idx, age = population[i]
                population[i] = (idx, age + 1)

    final_selected_index = best_index_so_far
    final_selected_perf = best_perf_so_far

    pad_trajectory(performance_trajectory, total_evaluated_count, termination_count_threshold, method_name)

    total_run_time = time.time() - start_time

    results = {
        "method": method_name,
        "selected_graph_id": final_selected_index,
        "actual_y_perf_of_selected": final_selected_perf,
        "selection_metric_value": final_selected_perf,
        "selected_graph_origin": "Evolutionary",
        "discovered_count": discovered_count,
        "total_iterations_run": final_iteration_count,
        "rank_position_overall": np.nan,
        "percentile_overall": np.nan,
        "total_samples_overall": overall_actual_y.numel() if overall_actual_y is not None else 0,
        "performance_trajectory": performance_trajectory,
        "total_evaluation_time": total_evaluation_time,
        "total_run_time": total_run_time
    }

    rank_info = calculate_overall_rank(
        final_selected_perf,
        overall_actual_y,
        higher_is_better
    )
    if rank_info:
        results["rank_position_overall"] = rank_info["rank_position_overall"]
        results["percentile_overall"] = rank_info["percentile_overall"]

    return results
292 |
--------------------------------------------------------------------------------
/rdb2g_bench/benchmark/llm/llm_utils.py:
--------------------------------------------------------------------------------
1 | from .llm_micro_action import LLMMicroActionSet
2 |
3 | import pandas as pd
4 | import json
5 |
6 |
def get_budget(dataset, task, budget, gnn="GraphSAGE", tag="final", path='../results/tables'):
    """Convert a fractional search budget into an absolute number of graphs.

    Reads ``{path}/{dataset}/{task}/{tag}/{gnn}/0.csv``, takes the number of
    graphs to be ``max(idx) + 1`` and returns ``int(num_graphs * budget)``.
    """
    results_csv = f"{path}/{dataset}/{task}/{tag}/{gnn}/0.csv"
    num_graphs = pd.read_csv(results_csv)['idx'].max() + 1
    return int(num_graphs * budget)
13 |
14 |
def get_micro_action_result(
    action: dict,
    llm_micro_action_set: LLMMicroActionSet,
    current_edge_set: list
):
    """Dispatch one parsed LLM action to the matching LLMMicroActionSet method.

    Args:
        action: Parsed action holding an ``'action'`` name and a
            ``'parameters'`` mapping with the keyword arguments for it.
        llm_micro_action_set: Object providing the four single-step actions.
        current_edge_set: Edge set the action is applied to.

    Returns:
        ``(new_edge_set, graph_idx, error_msg)`` from the invoked action, or
        ``(None, None, None)`` if the action name is not supported.
    """
    fk_pk_params = ('from_table_name', 'from_col_name', 'to_table_name')
    r2e_params = ('table_1_name', 'table_2_name', 'edge_table_name')
    supported_actions = (
        ('add_fk_pk_edge', fk_pk_params),
        ('remove_fk_pk_edge', fk_pk_params),
        ('convert_row_to_edge', r2e_params),
        ('convert_edge_to_row', r2e_params),
    )

    requested_action = action['action']
    for action_name, param_names in supported_actions:
        if requested_action == action_name:
            break
    else:
        # Unsupported action: signal the caller to skip it instead of raising.
        return None, None, None

    apply_action = getattr(llm_micro_action_set, action_name)
    parameters = action['parameters']
    new_edge_set, graph_idx, error_msg = apply_action(
        current_edge_set=current_edge_set,
        **{name: parameters[name] for name in param_names},
    )
    return new_edge_set, graph_idx, error_msg
56 |
def check_all_actions(
    actions: list[dict],
    llm_micro_action_set: LLMMicroActionSet,
    current_edge_set: tuple,
    perf_pred_dataset: dict
):
    """Apply each candidate action and pick the most promising one.

    Every action is applied to ``current_edge_set`` through
    ``get_micro_action_result``, and actions with an unsupported name are
    skipped. Among the supported actions:

    * if at least one is valid (``graph_idx != -1``), the valid one with the
      highest ``perf_pred_dataset.get(graph_idx).y`` is chosen, with ties going
      to the earliest action. Invalid actions never compete.
    * otherwise the first supported action's failed result is returned, so
      its error message can be reported.

    Side effect: each action's ``'parameters'`` entry is normalised in place,
    so a list-wrapped parameter dict is replaced by its first element.

    Args:
        actions: Parsed LLM actions (``{'action': ..., 'parameters': ...}``).
        llm_micro_action_set: Provides the single-step micro actions.
        current_edge_set: Edge set the actions start from.
        perf_pred_dataset: Object whose ``.get(graph_idx).y.item()`` gives the
            score of a graph.

    Returns:
        ``(action, new_edge_set, graph_idx, error_msg)`` where all four items
        come from the same action. If no action is supported, returns
        ``(actions[0] or None, current_edge_set, -1, <error message>)``.
    """
    candidates = []  # (action, new_edge_set, graph_idx, error_msg), aligned per action
    for action in actions:
        params = action['parameters']
        action['parameters'] = params[0] if isinstance(params, list) else params
        new_edge_set, graph_idx, error_msg = get_micro_action_result(
            action=action,
            llm_micro_action_set=llm_micro_action_set,
            current_edge_set=current_edge_set
        )
        if new_edge_set is None:
            continue  # unsupported action name
        candidates.append((action, new_edge_set, graph_idx, error_msg))

    if not candidates:
        fallback_action = actions[0] if actions else None
        return fallback_action, current_edge_set, -1, "None of the given actions is supported."

    valid = [candidate for candidate in candidates if candidate[2] != -1]
    if not valid:
        return candidates[0]

    scores = [perf_pred_dataset.get(graph_idx).y.item() for _, _, graph_idx, _ in valid]
    # max() returns the first maximal element, so ties favour earlier actions.
    best_position = max(range(len(valid)), key=scores.__getitem__)
    return valid[best_position]
85 |
86 |
87 |
def get_available_edges(full_edges: list):
    """Split edge types into row-to-edge (``r2e*``) and fk-pk (``f2p*``) groups.

    Edge types whose relation has neither prefix are dropped. Both groups are
    printed, then returned as ``(r2e_edges, f2p_edges)`` in input order.
    """
    edges = list(full_edges)
    r2e_edges = [edge for edge in edges if edge[1].startswith("r2e")]
    f2p_edges = [edge for edge in edges if edge[1].startswith("f2p")]

    for label, group in (("R2E", r2e_edges), ("F2P", f2p_edges)):
        print(f"{label} edges: {group}")
    return r2e_edges, f2p_edges
100 |
def check_none_action(parsed_response_text):
    """Return True when a parsed LLM response carries no usable action.

    The response may be a single action dict or a sequence of them; only the
    first entry is inspected. It counts as "no action" when the response is
    empty, the first entry is not a dict, it has no (or an empty)
    ``'action'`` value, or that value is the string ``"none"`` (any case).
    A non-string action value is treated as an action, and
    ``get_micro_action_result`` will later reject it as unsupported.
    """
    if not parsed_response_text:
        return True

    if isinstance(parsed_response_text, dict):
        parsed_response_text = [parsed_response_text]

    first = parsed_response_text[0]
    if not isinstance(first, dict):
        return True

    action = first.get('action')
    if not action:
        return True
    return isinstance(action, str) and action.lower() == "none"
116 |
117 |
def get_changed_edge(before_edge_set, after_edge_set, full_edges, prefix=None):
    """Return the edge type at a position where two edge sets differ.

    Args:
        before_edge_set: Edge set before the action.
        after_edge_set: Edge set after the action, aligned with ``full_edges``.
        full_edges: Edge types ``(src_table, relation, dst_table)`` in
            edge-set order.
        prefix: Optional relation prefix such as ``"r2e_"``. Conversion
            actions flip several positions at once (an r2e edge and its two
            f2p edges). When ``prefix`` is given, the first changed edge whose
            relation starts with it is returned, falling back to the first
            changed edge if none matches. ``None`` (the default) returns the
            first changed edge.

    Returns:
        The selected edge type from ``full_edges``.

    Raises:
        IndexError: If the two edge sets do not differ.
    """
    changed_edge_idx = [i for i, (a, b) in enumerate(zip(before_edge_set, after_edge_set)) if a != b]
    if not changed_edge_idx:
        raise IndexError("get_changed_edge: before_edge_set and after_edge_set do not differ.")

    if prefix is not None:
        for i in changed_edge_idx:
            if full_edges[i][1].startswith(prefix):
                return full_edges[i]
    return full_edges[changed_edge_idx[0]]
123 |
def get_edge_info(full_edges, current_edge_set, llm_micro_action_set):
    """Describe, as prompt text, which micro actions are available right now.

    For each action kind, the matching ``get_possible_*`` method of
    ``llm_micro_action_set`` enumerates the reachable edge sets. Each one is
    mapped back to the edge type the action targets and rendered as one JSON
    line under a "Note: ONLY ..." header. If there are none, a
    "Note: There are **NO** ..." message is used instead.

    Conversion actions flip several positions at once (one r2e edge and its
    two f2p edges), so the targeted edge is picked by relation prefix:
    ``"f2p_"`` for fk-pk actions and ``"r2e_"`` for the others. Taking the
    lowest changed index instead could pick an f2p edge and produce a wrong
    ``edge_table_name``.

    Args:
        full_edges: Edge types ``(src_table, relation, dst_table)`` in
            edge-set order.
        current_edge_set: Edge set the actions start from.
        llm_micro_action_set: Provides the six ``get_possible_*`` methods.

    Returns:
        Dict mapping each action name to its prompt text.

    Raises:
        IndexError: If an enumerated edge set equals ``current_edge_set``.
    """
    def targeted_edge(new_edge_set, prefix):
        # Positions flipped by the action; prefer the one with the expected relation kind.
        changed = [i for i, (a, b) in enumerate(zip(current_edge_set, new_edge_set)) if a != b]
        preferred = [i for i in changed if full_edges[i][1].startswith(prefix)]
        return full_edges[(preferred or changed)[0]]

    def fk_pk_params(edge):
        return {"from_table_name": edge[0], "from_col_name": edge[1].replace('f2p_', ''), "to_table_name": edge[2]}

    def r2e_params(edge):
        return {"table_1_name": edge[0], "table_2_name": edge[2], "edge_table_name": edge[1].replace('r2e_', '')}

    # (result key, enumerator method, relation prefix, formatter, header, note when empty)
    sections = [
        ("add_fk_pk_edge", "get_possible_add_fk_pk_edge", "f2p_", fk_pk_params,
         "Note: ONLY the following set of fk_pk_edge can be added:",
         "Note: There are **NO** fk_pk_edge that can be added in current schema."),
        ("remove_fk_pk_edge", "get_possible_remove_fk_pk_edge", "f2p_", fk_pk_params,
         "Note: ONLY the following set of fk_pk_edge can be removed:",
         "Note: There are **NO** fk_pk_edge that can be removed in current schema."),
        ("convert_row_to_edge", "get_possible_convert_row_to_edge", "r2e_", r2e_params,
         "Note: ONLY the following set of edges can be converted from row to edge:",
         "Note: There are **NO** edges that can be converted from row to edge in current schema."),
        ("convert_edge_to_row", "get_possible_convert_edge_to_row", "r2e_", r2e_params,
         "Note: ONLY the following set of edges can be converted from edge to row:",
         "Note: There are **NO** edges that can be converted from edge to row in current schema."),
        ("add_row2edge_edge", "get_possible_add_row2edge_edge", "r2e_", r2e_params,
         "Note: ONLY the following set of edges can be added:",
         "Note: There are **NO** row2edge edges that can be added in current schema."),
        ("remove_row2edge_edge", "get_possible_remove_row2edge_edge", "r2e_", r2e_params,
         "Note: ONLY the following set of edges can be removed:",
         "Note: There are **NO** row2edge edges that can be removed in current schema."),
    ]

    info = {}
    for key, method_name, prefix, to_params, header, empty_note in sections:
        next_edge_sets = getattr(llm_micro_action_set, method_name)(current_edge_set)
        edges = [targeted_edge(new_edge_set, prefix) for new_edge_set in next_edge_sets]
        if not edges:
            info[key] = empty_note
            continue
        info[key] = "\n".join([header] + [json.dumps(to_params(edge)) for edge in edges])
    return info
177 |
178 |
def conduct_multiple_actions(
    actions: list[dict],
    llm_micro_action_set: LLMMicroActionSet,
    current_edge_set: tuple,
):
    """Apply a sequence of micro actions one after another.

    Each action is evaluated with ``get_micro_action_result`` against the edge
    set produced by the previously accepted action (starting from
    ``current_edge_set``), so accepted actions compose.

    Note: each ``action`` dict is modified in place -- if its ``'parameters'``
    entry is a list, it is replaced by the list's first element (or ``{}`` if
    the list is empty).

    Args:
        actions: Action dicts; each is read for the ``'action'`` and
            ``'parameters'`` keys.
        llm_micro_action_set: Forwarded unchanged to ``get_micro_action_result``.
        current_edge_set: Edge set to start from.

    Returns:
        ``(valid_actions, invalid_actions, edge_set, graph_idx, error_msg)``:
        the accepted and rejected actions, the edge set after the last
        accepted action (``current_edge_set`` if none were accepted), the graph
        index returned for the last accepted action (``-1`` if none), and one
        concatenated error entry per rejected action.
    """
    valid_actions = []
    invalid_actions = []
    edge_set = current_edge_set
    graph_idx = -1
    error_msg = ""
    for action in actions:
        params = action['parameters']
        if isinstance(params, list):
            # Only the first parameter dict of a list is used; an empty list
            # previously raised IndexError. NOTE(review): confirm that {} is
            # handled (rejected) by get_micro_action_result.
            action['parameters'] = params[0] if params else {}
        new_edge_set, new_graph_idx, new_error_msg = get_micro_action_result(
            action=action,
            llm_micro_action_set=llm_micro_action_set,
            # Chain from the running edge set so multiple accepted actions
            # build on each other instead of each starting from scratch.
            current_edge_set=edge_set,
        )
        # No resulting edge set: drop the action without counting it as
        # either valid or invalid.
        if new_edge_set is None:
            continue
        # A graph index of -1 means the action was rejected; record why.
        if new_graph_idx == -1:
            error_msg += f"Action: {action['action']} \nParameters: {json.dumps(action['parameters'])} \nError: {new_error_msg}\n"
            invalid_actions.append(action)
            continue
        valid_actions.append(action)
        edge_set = new_edge_set
        graph_idx = new_graph_idx
    return valid_actions, invalid_actions, edge_set, graph_idx, error_msg
213 |
def update_edge_set(
    current_edge_set,
    new_edge_set,
    best_edge_set,
    update_best,
):
    """Move the search to ``new_edge_set`` and optionally record it as best.

    ``current_edge_set`` is accepted for symmetry with the other update
    helpers but its value is not read.

    Returns:
        ``(current_edge_set, best_edge_set)``. The current edge set is always
        ``new_edge_set``. The best edge set becomes ``new_edge_set`` only when
        ``update_best`` is truthy; otherwise the incoming ``best_edge_set`` is
        kept.
    """
    if not update_best:
        return new_edge_set, best_edge_set
    print(f"\033[93mBest edge set updated to {new_edge_set}\033[0m")
    return new_edge_set, new_edge_set
227 |
def update_score(
    current_score,
    new_score,
    best_score,
    score_result,
    update_best,
):
    """Advance the score trajectory by one step.

    If ``new_score`` equals ``current_score``, the step is treated as "no
    move" (the caller's convention for ``graph_idx == -1``): no score changes,
    and ``update_best`` is ignored. In every case the resulting best score is
    appended to ``score_result``, which is mutated in place and returned.

    Returns:
        ``(past_score, current_score, best_score, score_result)``.
    """
    previous = current_score
    if not (new_score == current_score):
        if update_best:
            best_score = new_score
            print(f"\033[93mBest score updated to {best_score:.4f}\033[0m")
        score_result.append(best_score)
        return previous, new_score, best_score, score_result

    # Unchanged score: repeat the current best in the trajectory.
    score_result.append(best_score)
    return previous, previous, best_score, score_result
247 |
def update_action(
    parsed_all_actions,
    valid_actions,
    action_result,
    valid_action_result,
    best_valid_action_result,
    last_action_num,
    update_best,
):
    """Fold one round of actions into the running action histories.

    Args:
        parsed_all_actions: Every action parsed this round.
        valid_actions: The subset of this round's actions that were accepted.
        action_result: History of all actions; extended in place.
        valid_action_result: History of valid actions; extended in place.
        best_valid_action_result: Valid-action history at the best point so
            far; replaced by a copy of ``valid_action_result`` when
            ``update_best`` is truthy.
        last_action_num: Size of the most recent non-empty valid batch.
        update_best: Whether this round produced a new best.

    Returns:
        ``(action_result, valid_action_result, best_valid_action_result,
        last_action_num)``.
    """
    action_result.extend(parsed_all_actions)
    valid_action_result.extend(valid_actions)

    if update_best:
        # Snapshot instead of aliasing, so later extends of the valid history
        # do not leak into the recorded best.
        best_valid_action_result = list(valid_action_result)
        print(f"\033[93mBest valid action result updated to {len(best_valid_action_result)} actions \033[0m")

    # An empty batch keeps the previous count.
    new_action_num = len(valid_actions) if valid_actions else last_action_num
    return action_result, valid_action_result, best_valid_action_result, new_action_num
271 |
272 |
def remove_invalid_history_actions(
    action_result,
    invalid_actions,
):
    """Remove every history action that equals one of ``invalid_actions``.

    ``action_result`` is filtered in place (and also returned), so callers
    holding a reference to the same list see the change. Membership uses
    ``==``, so equal-but-distinct dicts are removed too.

    Args:
        action_result: History of actions; modified in place.
        invalid_actions: Actions to purge from the history.

    Returns:
        The same ``action_result`` list, now without invalid actions.
    """
    print(f"Total history actions: {len(action_result)}")
    print(f"InValid history actions: {len(invalid_actions)}")
    # Rebuild in one pass instead of calling list.remove() while iterating:
    # removing during iteration shifts later items left and skips the item
    # right after each removal, so adjacent invalid actions used to survive.
    action_result[:] = [action for action in action_result if action not in invalid_actions]
    print(f"Total history actions: {len(action_result)}")
    return action_result
286 |
def get_history_actions(
    valid_action_result,
    max_history_actions,
):
    """Serialize the most recent valid actions as a JSON string.

    Args:
        valid_action_result: History of JSON-serializable actions.
        max_history_actions: Maximum number of trailing actions to include.
            Values ``<= 0`` include none.

    Returns:
        The selected actions as an indented JSON array, or ``""`` when no
        actions are selected.
    """
    if max_history_actions <= 0:
        # Without this guard, ``[-0:]`` would select the WHOLE history, and
        # negative limits would silently drop leading items instead.
        return ""
    latest_actions = valid_action_result[-max_history_actions:]
    if not latest_actions:
        return ""
    return json.dumps(latest_actions, indent=2)
296 |
--------------------------------------------------------------------------------