├── __init__.py ├── data └── .gitkeep ├── benchmarks ├── src │ ├── __init__.py │ └── ckg_benchmarks │ │ ├── __init__.py │ │ ├── gcl │ │ ├── __init__.py │ │ ├── base.py │ │ ├── grace.py │ │ ├── mvgrl.py │ │ └── train.py │ │ ├── egraphmae │ │ ├── __init__.py │ │ ├── model.py │ │ ├── train.py │ │ └── egat.py │ │ ├── graphmae │ │ ├── __init__.py │ │ ├── model.py │ │ ├── gat.py │ │ └── train.py │ │ ├── graphsage │ │ ├── __init__.py │ │ ├── model.py │ │ └── train.py │ │ ├── utils.py │ │ └── base.py ├── setup.py └── README.md ├── picture ├── logo_212x212.png ├── logo_400x113.png └── companykg_illustration.png ├── src └── companykg │ ├── __init__.py │ ├── settings.py │ ├── utils.py │ ├── eval.py │ └── kg.py ├── tutorials ├── grace_train_example.sh ├── mvgrl_train_example.sh ├── egraphmae_train_example.sh ├── graphmae_train_example.sh ├── graphsage_train_example.sh ├── gcl_train.ipynb └── tutorial.ipynb ├── setup.py ├── LICENSE ├── .gitignore └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/gcl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/egraphmae/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/graphmae/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/graphsage/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /picture/logo_212x212.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EQTPartners/CompanyKG/HEAD/picture/logo_212x212.png -------------------------------------------------------------------------------- /picture/logo_400x113.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EQTPartners/CompanyKG/HEAD/picture/logo_400x113.png -------------------------------------------------------------------------------- /picture/companykg_illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EQTPartners/CompanyKG/HEAD/picture/companykg_illustration.png -------------------------------------------------------------------------------- /src/companykg/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | from .kg import CompanyKG -------------------------------------------------------------------------------- /tutorials/grace_train_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (C) 
eqtgroup.com Ltd 2023 4 | # https://github.com/EQTPartners/CompanyKG 5 | # License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 6 | # 7 | # This is an example of how to call the GCL training interface 8 | # to train a GNN with GRACE 9 | python -m ckg_benchmarks.gcl.train \ 10 | --device -1 \ 11 | --method grace \ 12 | --n-layer 1 \ 13 | --embedding-dim 8 \ 14 | --epochs 1 \ 15 | --sampler-edges 2 \ 16 | --batch-size 128 17 | -------------------------------------------------------------------------------- /tutorials/mvgrl_train_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (C) eqtgroup.com Ltd 2023 4 | # https://github.com/EQTPartners/CompanyKG 5 | # License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 6 | # 7 | # This is an example of how to call the GCL training interface 8 | # to train a GNN with GRACE 9 | python -m ckg_benchmarks.gcl.train \ 10 | --device -1 \ 11 | --method mvgrl \ 12 | --n-layer 1 \ 13 | --embedding-dim 8 \ 14 | --epochs 1 \ 15 | --sampler-edges 2 \ 16 | --batch-size 128 17 | -------------------------------------------------------------------------------- /tutorials/egraphmae_train_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (C) eqtgroup.com Ltd 2023 4 | # https://github.com/EQTPartners/CompanyKG 5 | # License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 6 | # 7 | # This is an example of how to call the GraphMAE training interface 8 | # Note that the hyperparameters used are extremely limited, so the 9 | # resulting model will not be good, but can be trained with small memory 10 | python -m ckg_benchmarks.egraphmae.train \ 11 | --epochs 1 \ 12 | --n-layer 2 \ 13 | --embedding-dim 8 \ 14 | --data-root-folder ./data \ 15 | --device -1 \ 16 | --disable-metis 17 | 
-------------------------------------------------------------------------------- /tutorials/graphmae_train_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (C) eqtgroup.com Ltd 2023 4 | # https://github.com/EQTPartners/CompanyKG 5 | # License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 6 | # 7 | # This is an example of how to call the GraphMAE training interface 8 | # Note that the hyperparameters used are extremely limited, so the 9 | # resulting model will not be good, but can be trained with small memory 10 | python -m ckg_benchmarks.graphmae.train \ 11 | --epochs 1 \ 12 | --n-layer 2 \ 13 | --embedding-dim 8 \ 14 | --data-root-folder ./data \ 15 | --device -1 \ 16 | --disable-metis 17 | -------------------------------------------------------------------------------- /src/companykg/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | CHUNK_SIZE = 1024 * 1024 * 64 # 64 MiB chunk sizes 8 | # This should point to the record number of the latest version 9 | ZENODO_RECORD_NUMBER = "8010239" 10 | ZENODO_DATASET_BASE_URI = f"https://zenodo.org/record/{ZENODO_RECORD_NUMBER}/files/" 11 | EDGES_FILENAME = "edges.pt" 12 | EDGES_WEIGHTS_FILENAME = "edges_weight.pt" 13 | NODES_FEATURES_FILENAME_TEMPLATE = "nodes_feature_.pt" 14 | EVAL_TASK_FILENAME_TEMPLATE = "eval_task_.parquet.gz" 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | from setuptools import setup 8 | 
9 | setup( 10 | name="CompanyKG", 11 | version="1.0", 12 | package_dir={"": "src"}, 13 | include_package_data=True, 14 | description="Company Knowledge Graph data loading and evaluation utilities", 15 | author="EQT Motherbrain", 16 | install_requires=[ 17 | "numpy", 18 | "scipy", 19 | "pandas", 20 | "torch", 21 | "scikit-learn", 22 | "pyarrow", 23 | "fastparquet", 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /tutorials/graphsage_train_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright (C) eqtgroup.com Ltd 2023 4 | # https://github.com/EQTPartners/CompanyKG 5 | # License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 6 | # 7 | # This is an example of how to call the GraphSAGE training interface 8 | # Note that the hyperparameters used are extremely limited, so the 9 | # resulting model will not be good, but can be trained with small memory 10 | python -m ckg_benchmarks.graphsage.train \ 11 | --epochs 1 \ 12 | --n-layer 2 \ 13 | --embedding-dim 8 \ 14 | --data-root-folder ./data \ 15 | --device -1 \ 16 | --train-batch-size 256 \ 17 | --inference-batch-size 256 \ 18 | --n-sample-neighbor 2 19 | -------------------------------------------------------------------------------- /src/companykg/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import logging 8 | 9 | import requests 10 | 11 | from companykg.settings import CHUNK_SIZE 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def download_zenodo_file(uri: str, dest_path: str) -> None: 17 | """Zenodo file downloader that maintains O(1) memory consumption""" 18 | logger.info(f"Downloading {uri} to {dest_path}") 19 | with requests.get(uri, 
stream=True) as r: 20 | r.raise_for_status() 21 | with open(dest_path, "wb") as f: 22 | for chunk in r.iter_content(chunk_size=CHUNK_SIZE): 23 | f.write(chunk) 24 | logger.info("...[DONE]") 25 | -------------------------------------------------------------------------------- /benchmarks/setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | from setuptools import setup 8 | 9 | setup( 10 | name="companykg-benchmarks", 11 | version="1.0", 12 | package_dir={"": "src"}, 13 | include_package_data=True, 14 | description="Company Knowledge Graph benchmarking utilities", 15 | author="EQT Motherbrain", 16 | install_requires=[ 17 | "companykg", 18 | "numpy", 19 | "scipy", 20 | "pandas", 21 | "torch", 22 | "dgl", 23 | "scikit-learn", 24 | "igraph", 25 | "click", 26 | "tqdm", 27 | "torch", 28 | "torch-scatter", 29 | "torch-sparse", 30 | "torch-cluster", 31 | "torch-spline-conv", 32 | "torch-geometric", 33 | "deepsnap", 34 | "DGL", 35 | "PyGCL @ git+https://github.com/ivanustyuzhEQT/PyGCL", 36 | "numba", 37 | ], 38 | dependency_links=["https://data.dgl.ai/wheels/repo.html"], 39 | ) 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 EQT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | 
The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/companykg/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | from pathlib import Path 5 | 6 | import torch 7 | 8 | from companykg import CompanyKG 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "embeddings_path", 17 | help="Path to a pyTorch tensor file containing embeddings to be evaluated" 18 | ) 19 | parser.add_argument( 20 | "--data-root-folder", 21 | default="./data", 22 | type=str, 23 | help="The root folder where the CompanyKG data is downloaded to", 24 | ) 25 | parser.add_argument( 26 | "--output", 27 | default="./eval_results.json", 28 | type=str, 29 | help="File path to output evaluation results to as JSON", 30 | ) 31 | opts = parser.parse_args() 32 | 33 | root_logger = logging.getLogger() 34 | root_logger.setLevel(logging.INFO) 35 | log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s") 36 | console_handler = logging.StreamHandler() 37 | console_handler.setFormatter(log_formatter) 38 | root_logger.addHandler(console_handler) 39 | 40 | # Load the CKG data 41 | # Node feature type and edge weights, as we only use 
it for evaluation 42 | logger.info(f"Loading CompanyKG from {opts.data_root_folder}") 43 | ckg = CompanyKG( 44 | data_root_folder=opts.data_root_folder, 45 | ) 46 | logger.info(f"Loading embeddings for evaluation from {opts.embeddings_path}") 47 | embed = torch.load(opts.embeddings_path) 48 | 49 | # Check we've got an embedding for every node 50 | if embed.shape[0] != ckg.n_nodes: 51 | raise ValueError(f"number of embeddings ({embed.shape[0]}) does not match number of " 52 | f"nodes in KG ({ckg.n_nodes})") 53 | 54 | logger.info("Running evaluation") 55 | # This will output the results to stdout 56 | results = ckg.evaluate(embed=embed, silent=False) 57 | 58 | eval_results_path = Path(opts.output) 59 | with eval_results_path.open("w") as f: 60 | json.dump(results, f) 61 | logger.info(f"Evaluation results are exported to {eval_results_path}") 62 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/gcl/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import logging 8 | 9 | import torch 10 | from torch_geometric.loader import NeighborLoader 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class BaseEncoder(torch.nn.Module): 16 | def __init__(self, hparams): 17 | super().__init__() 18 | self.hparams = hparams 19 | # All encoders should have a num_layers param 20 | self.num_layers = self.hparams["num_layers"] 21 | 22 | def encode_batch(self, node_embeddings, edges, batch_size): 23 | """ 24 | Pass node embeddings through the GNN to transform to a new 25 | set of embeddings. 26 | 27 | This processes a single batch. The node embeddings need 28 | to include all nodes in the subgraph relevant to encoding 29 | this batch and the edges define the subgraph for the batch. 
30 | The batch_size is given to specify which nodes (at the 31 | beginning of node_embeddings) we want to get an encoding for. 32 | 33 | :param node_embeddings: Input embeddings for each node in the graph 34 | :param edges: Edges in the graph 35 | :param batch_size: Number of nodes we're getting embeddings for 36 | :return: New embeddings for each node as a PyTorch tensor 37 | """ 38 | raise NotImplementedError("subclass should implement encode()") 39 | 40 | def encode(self, pyg_graph, batch_size=64, sample_edges=50, device=None): 41 | # We can't typically load a full graph into memory, so we fetch 42 | # batches, making sure we provide the necessary level of neighbours 43 | # for each included sample 44 | data_loader = NeighborLoader( 45 | pyg_graph, 46 | num_neighbors=[sample_edges] * self.num_layers, 47 | batch_size=batch_size, 48 | ) 49 | encodings = [] 50 | for batch in data_loader: 51 | if device is not None: 52 | batch.to(device) 53 | encodings.append( 54 | self.encode_batch(batch.x, batch.edge_index, batch_size).detach().cpu() 55 | ) 56 | all_encodings = torch.cat(encodings) 57 | return all_encodings 58 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import argparse 8 | import json 9 | 10 | import click 11 | import random 12 | from typing import Any, Callable 13 | 14 | import numpy as np 15 | import torch 16 | 17 | 18 | def set_random_seed(seed: int) -> None: 19 | """Set all relevant random seeds 20 | 21 | Args: 22 | seed (int): the seed to be used 23 | """ 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | 
torch.backends.cudnn.determinstic = True 30 | 31 | 32 | def ranged_type(value_type: type, min_value: Any, max_value: Any) -> Callable: 33 | """Return function handle of an argument type function for ArgumentParser checking a range: 34 | min_value <= arg <= max_value 35 | 36 | Args: 37 | value_type (type): value-type to convert arg to 38 | min_value (Any): minimum acceptable argument 39 | max_value (Any): maximum acceptable argument 40 | 41 | Returns: 42 | function: function handle of an argument type function for ArgumentParser 43 | """ 44 | 45 | def range_checker(arg: str): 46 | try: 47 | f = value_type(arg) 48 | except ValueError: 49 | raise argparse.ArgumentTypeError(f"must be a valid {value_type}") 50 | if f < min_value or f > max_value: 51 | raise argparse.ArgumentTypeError( 52 | f"must be within [{min_value}, {max_value}]" 53 | ) 54 | return f 55 | 56 | # Return function handle to checking function 57 | return range_checker 58 | 59 | 60 | class JsonDictParamType(click.ParamType): 61 | name = "json" 62 | 63 | def convert(self, value, param, ctx): 64 | if isinstance(value, str): 65 | try: 66 | proc_val = json.loads(value) 67 | except json.decoder.JSONDecodeError as e: 68 | self.fail(f"{value!r} is not valid JSON: {e}", param, ctx) 69 | else: 70 | if not isinstance(proc_val, dict): 71 | self.fail(f"JSON parameter '{value!r}' is not a dict", param, ctx) 72 | else: 73 | return proc_val 74 | else: 75 | return value 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | 
*.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # mac 132 | .DS_Store 133 | 134 | # data 135 | *.parquet.gz 136 | *.pt 137 | 138 | # experiment related 139 | experiments -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/gcl/grace.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import logging 8 | 9 | import GCL.augmentors as A 10 | import torch 11 | import torch.nn.functional as F 12 | from torch_geometric.nn import GCNConv 13 | 14 | from ckg_benchmarks.gcl.base import BaseEncoder 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class GConv(torch.nn.Module): 20 | def __init__(self, input_dim, hidden_dim, activation, num_layers): 21 | super(GConv, self).__init__() 22 | self.activation = activation() 23 | self.layers = torch.nn.ModuleList() 24 | self.layers.append(GCNConv(input_dim, hidden_dim, cached=False)) 25 | for _ in range(num_layers - 1): 26 | self.layers.append(GCNConv(hidden_dim, hidden_dim, cached=False)) 27 | 28 | def forward(self, x, edge_index, edge_weight=None): 29 | z = x 30 | for i, conv in enumerate(self.layers): 31 | z = conv(z, edge_index, edge_weight) 32 | z = self.activation(z) 33 | return z 34 | 35 | 36 | class 
GraceEncoder(BaseEncoder): 37 | def __init__(self, input_dim, hidden_dim, num_layers, proj_dim): 38 | super(GraceEncoder, self).__init__( 39 | { 40 | "input_dim": input_dim, 41 | "hidden_dim": hidden_dim, 42 | "num_layers": num_layers, 43 | "proj_dim": proj_dim, 44 | } 45 | ) 46 | 47 | aug1 = A.Compose([A.EdgeRemoving(pe=0.3), A.FeatureMasking(pf=0.3)]) 48 | aug2 = A.Compose([A.EdgeRemoving(pe=0.3), A.FeatureMasking(pf=0.3)]) 49 | 50 | gconv = GConv( 51 | input_dim=input_dim, 52 | hidden_dim=hidden_dim, 53 | activation=torch.nn.ReLU, 54 | num_layers=num_layers, 55 | ) 56 | self.encoder = gconv 57 | self.augmentor = (aug1, aug2) 58 | 59 | self.fc1 = torch.nn.Linear(hidden_dim, proj_dim) 60 | self.fc2 = torch.nn.Linear(proj_dim, hidden_dim) 61 | 62 | @staticmethod 63 | def from_hparams(hparams): 64 | return GraceEncoder( 65 | hparams["input_dim"], 66 | hparams["hidden_dim"], 67 | hparams["num_layers"], 68 | hparams["proj_dim"], 69 | ) 70 | 71 | def forward(self, x, edge_index, edge_weight=None): 72 | aug1, aug2 = self.augmentor 73 | x1, edge_index1, edge_weight1 = aug1(x, edge_index, edge_weight) 74 | x2, edge_index2, edge_weight2 = aug2(x, edge_index, edge_weight) 75 | z = self.encoder(x, edge_index, edge_weight) 76 | z1 = self.encoder(x1, edge_index1, edge_weight1) 77 | z2 = self.encoder(x2, edge_index2, edge_weight2) 78 | return z, z1, z2 79 | 80 | def project(self, z: torch.Tensor) -> torch.Tensor: 81 | z = F.elu(self.fc1(z)) 82 | return self.fc2(z) 83 | 84 | def encode_batch(self, node_embeddings, edges, batch_size): 85 | with torch.no_grad(): 86 | vectors = self.encoder(node_embeddings, edges)[:batch_size] 87 | return vectors.detach() 88 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/graphsage/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, 
https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import dgl 8 | import torch 9 | 10 | 11 | class GraphSAGE(torch.nn.Module): 12 | """The GraphSAGE model class""" 13 | 14 | def __init__( 15 | self, n_layer: int, in_feats: int, h_feats: int, dropout: float 16 | ) -> None: 17 | """Initializer of GraphSAGE model. 18 | 19 | Args: 20 | n_layer (int): the number of GNN layers. 21 | in_feats (int): the dimension of input feature. 22 | h_feats (int): the dimension of the graph embedding to be trained. 23 | dropout (float): the drop out rate of GNN layers. 24 | """ 25 | super(GraphSAGE, self).__init__() 26 | self.n_layer = n_layer 27 | self.h_feats = h_feats 28 | self.gcn_layers = torch.nn.ModuleList() 29 | for i in range(n_layer): 30 | if i == 0: 31 | self.gcn_layers.append( 32 | dgl.nn.SAGEConv( 33 | in_feats, h_feats, aggregator_type="gcn", feat_drop=dropout 34 | ) 35 | ) 36 | elif i == n_layer - 1: 37 | self.gcn_layers.append( 38 | dgl.nn.SAGEConv(h_feats, h_feats, aggregator_type="gcn") 39 | ) 40 | else: 41 | self.gcn_layers.append( 42 | dgl.nn.SAGEConv( 43 | h_feats, h_feats, aggregator_type="gcn", feat_drop=dropout 44 | ) 45 | ) 46 | 47 | def forward(self, mfgs: list, x: torch.Tensor) -> torch.Tensor: 48 | """The forward propagation function of the model. 49 | 50 | Args: 51 | mfgs (list): the Message-passing Flow Graphs (MFGs). 52 | x (torch.Tensor): the input feature tensor. 53 | 54 | Returns: 55 | torch.Tensor: the output tensor of the forward pass. 
56 | """ 57 | for i in range(self.n_layer): 58 | if i == 0: 59 | h_dst = x[: mfgs[i].num_dst_nodes()] 60 | h = self.gcn_layers[i](mfgs[i], (x, h_dst)) 61 | h = torch.nn.functional.leaky_relu(h) 62 | elif i == self.n_layer - 1: 63 | h_dst = h[: mfgs[i].num_dst_nodes()] 64 | h = self.gcn_layers[i](mfgs[i], (h, h_dst)) 65 | else: 66 | h_dst = h[: mfgs[i].num_dst_nodes()] 67 | h = self.gcn_layers[i](mfgs[i], (h, h_dst)) 68 | h = torch.nn.functional.leaky_relu(h) 69 | return h 70 | 71 | 72 | class DotPredictor(torch.nn.Module): 73 | """A pairwise predictor implemented with dot product""" 74 | 75 | def forward(self, g: dgl.DGLGraph, h: torch.Tensor) -> torch.Tensor: 76 | """The forward pass of DotPredictor. 77 | 78 | Args: 79 | g (dgl.DGLGraph): the input graph. 80 | h (torch.Tensor): the input node feature. 81 | 82 | Returns: 83 | torch.Tensor: the output scores. 84 | """ 85 | with g.local_scope(): 86 | g.ndata["h"] = h 87 | g.apply_edges(dgl.function.u_dot_v("h", "h", "score")) 88 | return g.edata["score"][:, 0] 89 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/gcl/mvgrl.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import logging 8 | 9 | import GCL.augmentors as A 10 | import torch 11 | from torch import nn 12 | from torch_geometric.nn import GCNConv 13 | from torch_geometric.nn.inits import uniform 14 | 15 | from ckg_benchmarks.gcl.base import BaseEncoder 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class GConv(nn.Module): 21 | def __init__(self, input_dim, hidden_dim, num_layers): 22 | super(GConv, self).__init__() 23 | self.layers = torch.nn.ModuleList() 24 | self.activation = nn.PReLU(hidden_dim) 25 | for i in range(num_layers): 26 | if i == 0: 27 | 
self.layers.append(GCNConv(input_dim, hidden_dim)) 28 | else: 29 | self.layers.append(GCNConv(hidden_dim, hidden_dim)) 30 | 31 | def forward(self, x, edge_index, edge_weight=None): 32 | z = x 33 | for conv in self.layers: 34 | z = conv(z, edge_index, edge_weight) 35 | z = self.activation(z) 36 | return z 37 | 38 | 39 | class MvgrlEncoder(BaseEncoder): 40 | def __init__(self, input_dim, hidden_dim, num_layers): 41 | super(MvgrlEncoder, self).__init__( 42 | { 43 | "input_dim": input_dim, 44 | "hidden_dim": hidden_dim, 45 | "num_layers": num_layers, 46 | } 47 | ) 48 | 49 | aug1 = A.Identity() 50 | aug2 = A.PPRDiffusion(alpha=0.2) 51 | gconv1 = GConv( 52 | input_dim=input_dim, hidden_dim=hidden_dim, num_layers=num_layers 53 | ) 54 | gconv2 = GConv( 55 | input_dim=input_dim, hidden_dim=hidden_dim, num_layers=num_layers 56 | ) 57 | 58 | self.encoder1 = gconv1 59 | self.encoder2 = gconv2 60 | self.augmentor = (aug1, aug2) 61 | self.project = torch.nn.Linear(hidden_dim, hidden_dim) 62 | uniform(hidden_dim, self.project.weight) 63 | 64 | @staticmethod 65 | def from_hparams(hparams): 66 | return MvgrlEncoder( 67 | hparams["input_dim"], 68 | hparams["hidden_dim"], 69 | hparams["num_layers"], 70 | ) 71 | 72 | @staticmethod 73 | def corruption(x, edge_index, edge_weight): 74 | return x[torch.randperm(x.size(0))], edge_index, edge_weight 75 | 76 | def forward(self, x, edge_index, batch_size, edge_weight=None): 77 | aug1, aug2 = self.augmentor 78 | x1, edge_index1, edge_weight1 = aug1(x, edge_index, edge_weight) 79 | x2, edge_index2, edge_weight2 = aug2(x, edge_index, edge_weight) 80 | z1 = self.encoder1(x1, edge_index1, edge_weight1)[:batch_size] 81 | z2 = self.encoder2(x2, edge_index2, edge_weight2)[:batch_size] 82 | g1 = self.project(torch.sigmoid(z1.mean(dim=0, keepdim=True))) 83 | g2 = self.project(torch.sigmoid(z2.mean(dim=0, keepdim=True))) 84 | z1n = self.encoder1(*self.corruption(x1, edge_index1, edge_weight1))[ 85 | :batch_size 86 | ] 87 | z2n = 
self.encoder2(*self.corruption(x2, edge_index2, edge_weight2))[ 88 | :batch_size 89 | ] 90 | return z1, z2, g1, g2, z1n, z2n 91 | 92 | def encode_batch(self, node_embeddings, edges, batch_size): 93 | # TODO Add edge weight here 94 | with torch.no_grad(): 95 | vectors = self.encoder1(node_embeddings, edges)[:batch_size] 96 | return vectors.detach() 97 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import logging 8 | import os 9 | from abc import ABC, abstractmethod 10 | from pathlib import Path 11 | from typing import Union, Optional 12 | 13 | import torch 14 | 15 | from ckg_benchmarks.utils import set_random_seed 16 | from companykg import CompanyKG 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class BaseTrainer(ABC): 22 | training_method_name = "unknown" 23 | 24 | def __init__( 25 | self, 26 | nodes_feature_type: str = "msbert", 27 | data_root_folder: str = "./data", 28 | seed: int = 42, 29 | device: Union[int, str] = 0, 30 | work_folder: str = "./experiments", 31 | finetune_from: Optional[str] = None, 32 | allow_retrain: bool = False, 33 | ): 34 | self.allow_retrain = allow_retrain 35 | self.nodes_feature_type = nodes_feature_type 36 | self.data_root_folder = data_root_folder 37 | self.seed = seed 38 | self.work_folder = Path(work_folder) / self.training_method_name 39 | self.device = device if device >= 0 else "cpu" 40 | self.finetune_from = finetune_from 41 | 42 | # Create CompanyKG object 43 | self.comkg = self.load_companykg() 44 | self.comkg.describe() 45 | 46 | # Create DGL graph 47 | self.graph = self.build_graph() 48 | logger.info(self.graph) 49 | 50 | # File path to which the final model will be saved 51 | 
self.eval_results_path = self.work_folder / f"{self.hparams_str}.pkl" 52 | self.embedding_save_path = self.work_folder / f"{self.hparams_str}.pt" 53 | 54 | self.model, self.optimizer, self.scheduler = self.init_model( 55 | load=self.finetune_from 56 | ) 57 | 58 | # Set at the end of training based on the final model 59 | self.embeddings = None 60 | 61 | def load_companykg(self): 62 | """ 63 | Load CompanyKG dataset in preparation for training 64 | 65 | """ 66 | return CompanyKG( 67 | nodes_feature_type=self.nodes_feature_type, 68 | load_edges_weights=False, 69 | data_root_folder=self.data_root_folder, 70 | ) 71 | 72 | @property 73 | @abstractmethod 74 | def hparams_str(self): 75 | ... 76 | 77 | @abstractmethod 78 | def build_graph(self): 79 | ... 80 | 81 | @abstractmethod 82 | def init_model(self, load=None): 83 | ... 84 | 85 | @abstractmethod 86 | def _train_model(self): 87 | ... 88 | 89 | @abstractmethod 90 | def inference(self) -> torch.Tensor: 91 | ... 92 | 93 | def train_model(self): 94 | set_random_seed(self.seed) 95 | self.model.to(self.device) 96 | 97 | # Check if the current trial has been run already 98 | # Allow this check to be overridden 99 | if not self.allow_retrain and self.eval_results_path.exists(): 100 | logger.info(f"Skip trial {self.hparams_str}: has already been run") 101 | return 102 | 103 | # Logging: send a copy to the output folder 104 | os.makedirs(self.work_folder, exist_ok=True) 105 | log_path = os.path.join(self.work_folder, f"{self.hparams_str}.log") 106 | logger.info(f"Sending training logs to {log_path}") 107 | file_handler = logging.FileHandler(log_path, mode="a") 108 | log_formatter = logging.Formatter( 109 | "%(asctime)s [%(levelname)-5.5s] %(message)s" 110 | ) 111 | file_handler.setFormatter(log_formatter) 112 | # Add handler to the root logger 113 | logging.getLogger().addHandler(file_handler) 114 | 115 | logger.info("Strating model training") 116 | self._train_model() 117 | logger.info("Model training complete") 118 | 119 | # 
Inference with best model 120 | logger.info("Projecting full KG using final model") 121 | embed = self.inference() 122 | torch.save(embed, self.embedding_save_path) 123 | logger.info(f"Best embeddings saved to {self.embedding_save_path}") 124 | # Keep the embeddings for later use 125 | self.embeddings = embed 126 | # You can evaluate these using: 127 | # results = trainer.comkg.evaluate(embed=trainer.embeddings, silent=True) 128 | 129 | def evaluate(self, silent=False): 130 | if self.embeddings is None: 131 | raise RuntimeError( 132 | "projected embeddings are not available for evaluation: training must be run first" 133 | ) 134 | results = self.comkg.evaluate(embed=self.embeddings, silent=silent) 135 | return results 136 | 137 | @classmethod 138 | def train(cls, **kwargs): 139 | """ 140 | Initialize a trainer and run the training routine, returning the trainer. 141 | This also provides the trained model as `trainer.model`. 142 | 143 | :return: trainer instance 144 | """ 145 | logger.info("Initializing model and trainer") 146 | trainer = cls(**kwargs) 147 | logger.info("Starting model training") 148 | trainer.train_model() 149 | logger.info("Model training complete") 150 | return trainer 151 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

⚠️ Repository Upgraded and Migrated to Version 2.x ⚠️

3 |

This repository corresponds to CompanyKG Version 1.x. We have extended this work to Version 2.x, hosted in a new repository. Since 2.x is backward compatible, we recommend submitting issues and pull requests for both Version 1.x and 2.x to the CompanyKG2 repository.

4 |
5 | 6 | --- 7 | 8 | ![CompanyKG Logo](./picture/logo_400x113.png) 9 | 10 |
11 |

12 | Dataset • 13 | Tutorial • 14 | External Results • 15 | Arxiv • 16 | IEEE Transactions on Big Data • 17 | Citation 18 |

19 | 20 |
21 | 22 | [![version](https://img.shields.io/badge/Version-v1.1-green)](https://github.com/EQTPartners/CompanyKG/releases/tag/1.1) 23 | [![python](https://img.shields.io/badge/Python-3.8-yellow)](https://www.python.org/doc/versions/) 24 | [![python](https://img.shields.io/badge/Motherbrain-Research-orange)](https://motherbrain.ai/) 25 | 26 | This repository contains all code accompanying the release of the CompanyKG 27 | knowledge graph illustrated in Figure 1 below. 28 | For details of the dataset and benchmark experiments, see the official release of the [paper](https://arxiv.org/abs/2306.10649) and [dataset](https://zenodo.org/record/8010239). 29 | 30 | ![CompanyKG Illustration](./picture/companykg_illustration.png) 31 | 32 | There are two main parts to the code release: 33 | * CompanyKG dataset access and task evaluations (see below) 34 | * [Benchmark model training and experiments](./benchmarks/README.md) 35 | 36 | 37 | 38 | ## Pre-Requisites 39 | 40 | * Python 3.8 41 | 42 | There are also optional dependencies, if you want to be able to convert the KG 43 | to one of the data structures used by these packages: 44 | 45 | * [DGL](https://pypi.org/project/dgl/): `dgl` 46 | * [iGraph](https://pypi.org/project/python-igraph/): `python-igraph` 47 | * [PyTorch Geometric (PyG)](https://pypi.org/project/torch-geometric/): `torch-geometric` 48 | 49 | 50 | ## Setup 51 | 52 | The `companykg` Python package provides a data structure to load CompanyKG into memory, 53 | convert between different graph representations and run evaluation of trained embeddings 54 | or other company-ranking predictors on three evaluation tasks. 55 | 56 | To install the `companykg` package and its Python dependencies, activate a virtual 57 | environment (such as [Virtualenv](https://github.com/pyenv/pyenv-virtualenv) or Conda) and run: 58 | 59 | ```bash 60 | pip install -e .
61 | ``` 62 | 63 | The first time you instantiate the CompanyKG class, if the dataset is not already available 64 | (in the default subdirectory or another location you specify), the latest version will be automatically 65 | downloaded from Zenodo. 66 | 67 | 68 | ## Basic usage 69 | 70 | By default, the CompanyKG dataset will be loaded from (and, if necessary, downloaded to) 71 | a `data` subdirectory of the working directory. To load the dataset from this default location, 72 | simply instantiate the `CompanyKG` class: 73 | ```python 74 | from companykg import CompanyKG 75 | 76 | ckg = CompanyKG() 77 | ``` 78 | 79 | If you have already downloaded the dataset and want to 80 | load it from its current location, specify the path: 81 | ```python 82 | ckg = CompanyKG(data_root_folder="/path/to/stored/companykg/directory") 83 | ``` 84 | 85 | The graph can be loaded with different vector representations (embeddings) of 86 | company description data associated with the nodes: `msbert` (mSBERT), `simcse` (SimCSE), 87 | `ada2` (ADA2) or `pause` (PAUSE). 88 | 89 | ```python 90 | ckg = CompanyKG(nodes_feature_type="pause") 91 | ``` 92 | 93 | If you want to experiment with different embedding types, you can also load embeddings 94 | of a different type into an already-loaded graph: 95 | 96 | ```python 97 | ckg.change_feature_type("simcse") 98 | ``` 99 | 100 | By default, edge weights are not loaded into the graph. To change this, use: 101 | ```python 102 | ckg = CompanyKG(load_edges_weights=True) 103 | ``` 104 | 105 | 106 | A tutorial showing further ways to use CompanyKG is [here](./tutorials/tutorial.ipynb). 107 | 108 | 109 | ## Training benchmark models 110 | 111 | Implementations of various benchmark graph-based learning models are provided in this repository. 112 | 113 | To use them, install the `ckg_benchmarks` Python package, along with its dependencies, from the 114 | `benchmarks` subdirectory.
First install `companykg` as above and then: 115 | 116 | ```bash 117 | cd benchmarks 118 | pip install -e . 119 | ``` 120 | 121 | Further instructions for using the benchmarks package for model training are provided in 122 | the [benchmarks README file](./benchmarks/README.md). 123 | 124 | 125 | ## External Results 126 | We collect all benchmarking results on this dataset here. Feel free to reach out to us (via a GitHub issue or [email shown in our paper](https://arxiv.org/pdf/2306.10649.pdf)) if you wish to include your experimental results. 127 | - [Knorreman](https://github.com/Knorreman/fastRP) reported results using [fastRP algorithm](https://arxiv.org/pdf/1908.11512.pdf) achieving [competitive results](https://github.com/EQTPartners/CompanyKG/issues/1#issuecomment-1749707045) (i.e., `sp_auc=85.7%`, `sr_test_acc=69.2%`, `R@50=0.353`, and `R@100=0.430` obtained on different hyper-parameters and initial node embeddings). 128 | 129 | 130 | ## Cite This Work 131 | 132 | Cite the [paper](https://arxiv.org/abs/2306.10649): 133 | ```bibtex 134 | @article{cao2023companykg, 135 | author = {Lele Cao and 136 | Vilhelm von Ehrenheim and 137 | Mark Granroth-Wilding and 138 | Richard Anselmo Stahl and 139 | Drew McCornack and 140 | Armin Catovic and 141 | Dhiana Deva Cavacanti Rocha}, 142 | title = {{CompanyKG: A Large-Scale Heterogeneous Graph for Company Similarity Quantification}}, 143 | journal = {IEEE Transactions on Big Data}, 144 | year = {2024}, 145 | doi = {10.1109/TBDATA.2024.3407573} 146 | } 147 | ``` 148 | 149 | Cite the [official release of the CompanyKG dataset on Zenodo](https://zenodo.org/record/8010239): 150 | ```bibtex 151 | @article{companykg_2023_8010239, 152 | author = {Lele Cao and 153 | Vilhelm von Ehrenheim and 154 | Mark Granroth-Wilding and 155 | Richard Anselmo Stahl and 156 | Drew McCornack and 157 | Armin Catovic and 158 | Dhiana Deva Cavacanti Rocha}, 159 | title = {{CompanyKG Dataset: A Large-Scale Heterogeneous Graph for Company Similarity
Quantification}}, 160 | month = June, 161 | year = 2023, 162 | publisher = {Zenodo}, 163 | version = {1.1}, 164 | doi = {10.5281/zenodo.8010239}, 165 | url = {https://doi.org/10.5281/zenodo.8010239} 166 | } 167 | ``` 168 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/graphmae/model.py: -------------------------------------------------------------------------------- 1 | # This implementation is adapted from GraphMAE 2 | # https://github.com/THUDM/GraphMAE 3 | 4 | from functools import partial 5 | from itertools import chain 6 | 7 | import dgl 8 | import numpy as np 9 | import torch 10 | from ckg_benchmarks.graphmae.gat import GAT 11 | 12 | 13 | def sce_loss(x, y, alpha=3): 14 | x = torch.nn.functional.normalize(x, p=2, dim=-1) 15 | y = torch.nn.functional.normalize(y, p=2, dim=-1) 16 | loss = (1 - (x * y).sum(dim=-1)).pow_(alpha) 17 | return loss.mean() 18 | 19 | 20 | def mask_edge(graph, mask_prob): 21 | E = graph.num_edges() 22 | mask_rates = torch.FloatTensor(np.ones(E) * mask_prob) 23 | masks = torch.bernoulli(1 - mask_rates) 24 | mask_idx = masks.nonzero().squeeze(1) 25 | return mask_idx 26 | 27 | 28 | def drop_edge(graph, drop_rate, return_edges=False): 29 | if drop_rate <= 0: 30 | return graph 31 | 32 | n_node = graph.num_nodes() 33 | edge_mask = mask_edge(graph, drop_rate) 34 | src = graph.edges()[0] 35 | dst = graph.edges()[1] 36 | 37 | nsrc = src[edge_mask] 38 | ndst = dst[edge_mask] 39 | 40 | ng = dgl.graph((nsrc, ndst), num_nodes=n_node) 41 | ng = ng.add_self_loop() 42 | 43 | dsrc = src[~edge_mask] 44 | ddst = dst[~edge_mask] 45 | 46 | if return_edges: 47 | return ng, (dsrc, ddst) 48 | return ng 49 | 50 | 51 | class GraphMAE(torch.nn.Module): 52 | def __init__( 53 | self, 54 | in_dim: int, 55 | num_hidden: int, 56 | num_layers: int, 57 | feat_drop: float, 58 | attn_drop: float, 59 | nhead: int, 60 | nhead_out: int = 1, 61 | mask_rate: float = 0.5, 62 | drop_edge_rate: float = 0.5, 63 | 
replace_rate: float = 0.15, 64 | alpha_l: float = 3, 65 | concat_hidden: bool = False, 66 | ): 67 | super(GraphMAE, self).__init__() 68 | self._mask_rate = mask_rate 69 | self._drop_edge_rate = drop_edge_rate 70 | self._output_hidden_size = num_hidden 71 | self._concat_hidden = concat_hidden 72 | self._norm = torch.nn.LayerNorm # fix to LayerNorm 73 | 74 | self._replace_rate = replace_rate 75 | self._mask_token_rate = 1 - self._replace_rate 76 | 77 | assert num_hidden % nhead == 0 78 | assert num_hidden % nhead_out == 0 79 | 80 | enc_num_hidden = num_hidden // nhead 81 | enc_nhead = nhead 82 | 83 | dec_in_dim = num_hidden 84 | dec_num_hidden = num_hidden // nhead_out 85 | 86 | # build encoder 87 | self.encoder = GAT( 88 | in_dim=in_dim, 89 | num_hidden=enc_num_hidden, 90 | out_dim=enc_num_hidden, 91 | num_layers=num_layers, 92 | nhead=enc_nhead, 93 | nhead_out=enc_nhead, 94 | concat_out=True, 95 | feat_drop=feat_drop, 96 | attn_drop=attn_drop, 97 | negative_slope=0.2, 98 | residual=True, 99 | norm=self._norm, 100 | encoding=True, 101 | ) 102 | 103 | # build decoder for attribute prediction 104 | self.decoder = GAT( 105 | in_dim=dec_in_dim, 106 | num_hidden=dec_num_hidden, 107 | out_dim=in_dim, 108 | num_layers=1, 109 | nhead=nhead, 110 | nhead_out=nhead_out, 111 | concat_out=True, 112 | feat_drop=feat_drop, 113 | attn_drop=attn_drop, 114 | negative_slope=0.2, 115 | residual=True, 116 | norm=self._norm, 117 | encoding=False, 118 | ) 119 | 120 | self.enc_mask_token = torch.nn.Parameter(torch.zeros(1, in_dim)) 121 | if concat_hidden: 122 | self.encoder_to_decoder = torch.nn.Linear( 123 | dec_in_dim * num_layers, dec_in_dim, bias=False 124 | ) 125 | else: 126 | self.encoder_to_decoder = torch.nn.Linear( 127 | dec_in_dim, dec_in_dim, bias=False 128 | ) 129 | 130 | # * setup loss function 131 | self.criterion = self.setup_loss_fn(alpha_l) 132 | 133 | @property 134 | def output_hidden_dim(self): 135 | return self._output_hidden_size 136 | 137 | def setup_loss_fn(self, 
alpha_l): 138 | return partial(sce_loss, alpha=alpha_l) 139 | 140 | def encoding_mask_noise(self, g, x, mask_rate=0.3): 141 | num_nodes = g.num_nodes() 142 | perm = torch.randperm(num_nodes, device=x.device) 143 | num_mask_nodes = int(mask_rate * num_nodes) 144 | 145 | # random masking 146 | num_mask_nodes = int(mask_rate * num_nodes) 147 | mask_nodes = perm[:num_mask_nodes] 148 | keep_nodes = perm[num_mask_nodes:] 149 | 150 | if self._replace_rate > 0: 151 | num_noise_nodes = int(self._replace_rate * num_mask_nodes) 152 | perm_mask = torch.randperm(num_mask_nodes, device=x.device) 153 | token_nodes = mask_nodes[ 154 | perm_mask[: int(self._mask_token_rate * num_mask_nodes)] 155 | ] 156 | noise_nodes = mask_nodes[ 157 | perm_mask[-int(self._replace_rate * num_mask_nodes) :] 158 | ] 159 | noise_to_be_chosen = torch.randperm(num_nodes, device=x.device)[ 160 | :num_noise_nodes 161 | ] 162 | 163 | out_x = x.clone() 164 | out_x[token_nodes] = 0.0 165 | out_x[noise_nodes] = x[noise_to_be_chosen] 166 | else: 167 | out_x = x.clone() 168 | token_nodes = mask_nodes 169 | out_x[mask_nodes] = 0.0 170 | 171 | out_x[token_nodes] += self.enc_mask_token 172 | use_g = g.clone() 173 | 174 | return use_g, out_x, (mask_nodes, keep_nodes) 175 | 176 | def forward(self, g, x): 177 | # ---- attribute reconstruction ---- 178 | loss = self.mask_attr_prediction(g, x) 179 | loss_item = {"loss": loss.item()} 180 | return loss, loss_item 181 | 182 | def mask_attr_prediction(self, g, x): 183 | pre_use_g, use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise( 184 | g, x, self._mask_rate 185 | ) 186 | 187 | if self._drop_edge_rate > 0: 188 | use_g, masked_edges = drop_edge( 189 | pre_use_g, self._drop_edge_rate, return_edges=True 190 | ) 191 | else: 192 | use_g = pre_use_g 193 | 194 | enc_rep, all_hidden = self.encoder(use_g, use_x, return_hidden=True) 195 | if self._concat_hidden: 196 | enc_rep = torch.cat(all_hidden, dim=1) 197 | 198 | # ---- attribute reconstruction ---- 199 | rep = 
self.encoder_to_decoder(enc_rep) 200 | 201 | recon = self.decoder(pre_use_g, rep) 202 | 203 | x_init = x[mask_nodes] 204 | x_rec = recon[mask_nodes] 205 | 206 | loss = self.criterion(x_rec, x_init) 207 | return loss 208 | 209 | def embed(self, g, x): 210 | rep = self.encoder(g, x) 211 | return rep 212 | 213 | @property 214 | def enc_params(self): 215 | return self.encoder.parameters() 216 | 217 | @property 218 | def dec_params(self): 219 | return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()]) 220 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/egraphmae/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | from functools import partial 8 | from itertools import chain 9 | 10 | import dgl 11 | import torch 12 | from ckg_benchmarks.egraphmae.egat import EGAT 13 | from ckg_benchmarks.graphmae.model import mask_edge, sce_loss 14 | 15 | 16 | def drop_edge(graph, drop_rate, e=None, return_edges=False): 17 | 18 | if drop_rate <= 0: 19 | return graph 20 | 21 | n_node = graph.num_nodes() 22 | edge_mask = mask_edge(graph, drop_rate) 23 | src = graph.edges()[0] 24 | dst = graph.edges()[1] 25 | 26 | nsrc = src[edge_mask] 27 | ndst = dst[edge_mask] 28 | 29 | ng = dgl.graph((nsrc, ndst), num_nodes=n_node) 30 | 31 | if e is not None: 32 | ex = e[edge_mask] 33 | ng.edata["weight"] = ex 34 | else: 35 | ex = None 36 | 37 | ng = ng.add_self_loop() 38 | 39 | dsrc = src[~edge_mask] 40 | ddst = dst[~edge_mask] 41 | 42 | if e is not None: 43 | if return_edges: 44 | return ng, ng.edata["weight"], (dsrc, ddst) 45 | return ng, ng.edata["weight"] 46 | else: 47 | if return_edges: 48 | return ng, None, (dsrc, ddst) 49 | return ng, None 50 | 51 | 52 | class EGraphMAE(torch.nn.Module): 53 | 
def __init__( 54 | self, 55 | in_dim: int, 56 | num_hidden: int, 57 | num_layers: int, 58 | feat_drop: float, 59 | attn_drop: float, 60 | nhead: int, 61 | num_edge_features: int, 62 | num_edge_hidden: int, 63 | nhead_out: int = 1, 64 | mask_rate: float = 0.5, 65 | drop_edge_rate: float = 0.5, 66 | replace_rate: float = 0.15, 67 | alpha_l: float = 3, 68 | concat_hidden: bool = False, 69 | ): 70 | super(EGraphMAE, self).__init__() 71 | self._mask_rate = mask_rate 72 | self._drop_edge_rate = drop_edge_rate 73 | self._output_hidden_size = num_hidden 74 | self._concat_hidden = concat_hidden 75 | self._norm = torch.nn.LayerNorm # fix to LayerNorm 76 | 77 | self._replace_rate = replace_rate 78 | self._mask_token_rate = 1 - self._replace_rate 79 | 80 | assert num_hidden % nhead == 0 81 | assert num_hidden % nhead_out == 0 82 | 83 | enc_num_hidden = num_hidden // nhead 84 | enc_nhead = nhead 85 | enc_num_hidden_e = num_edge_hidden // nhead 86 | # enc_nhead_e = nhead 87 | 88 | dec_in_dim = num_hidden 89 | dec_num_hidden = num_hidden // nhead_out 90 | 91 | dec_in_dim_e = num_edge_hidden 92 | dec_num_hidden_e = num_edge_hidden // nhead_out 93 | 94 | # build encoder 95 | self.encoder = EGAT( 96 | in_dim=in_dim, 97 | num_hidden=enc_num_hidden, 98 | out_dim=enc_num_hidden, 99 | num_layers=num_layers, 100 | nhead=enc_nhead, 101 | nhead_out=enc_nhead, 102 | concat_out=True, 103 | feat_drop=feat_drop, 104 | attn_drop=attn_drop, 105 | negative_slope=0.2, 106 | residual=True, 107 | norm=self._norm, 108 | in_dim_e=num_edge_features, 109 | num_hidden_e=enc_num_hidden_e, 110 | out_dim_e=enc_num_hidden_e, 111 | encoding=True, 112 | ) 113 | 114 | # build decoder for attribute prediction 115 | self.decoder = EGAT( 116 | in_dim=dec_in_dim, 117 | num_hidden=dec_num_hidden, 118 | out_dim=in_dim, 119 | num_layers=1, 120 | nhead=nhead, 121 | nhead_out=nhead_out, 122 | feat_drop=feat_drop, 123 | attn_drop=attn_drop, 124 | negative_slope=0.2, 125 | residual=True, 126 | norm=self._norm, 127 | 
concat_out=True, 128 | in_dim_e=dec_in_dim_e, 129 | num_hidden_e=dec_num_hidden_e, 130 | out_dim_e=num_edge_features, 131 | encoding=False, 132 | ) 133 | 134 | self.enc_mask_token = torch.nn.Parameter(torch.zeros(1, in_dim)) 135 | if concat_hidden: 136 | self.encoder_to_decoder = torch.nn.Linear( 137 | dec_in_dim * num_layers, dec_in_dim, bias=False 138 | ) 139 | ## Do not bother to create self.encoder_to_decoder_e 140 | else: 141 | self.encoder_to_decoder = torch.nn.Linear( 142 | dec_in_dim, dec_in_dim, bias=False 143 | ) 144 | self.encoder_to_decoder_e = torch.nn.Linear( 145 | dec_in_dim_e, dec_in_dim_e, bias=False 146 | ) 147 | 148 | # * setup loss function 149 | self.criterion = self.setup_loss_fn(alpha_l) 150 | 151 | @property 152 | def output_hidden_dim(self): 153 | return self._output_hidden_size 154 | 155 | def setup_loss_fn(self, alpha_l): 156 | return partial(sce_loss, alpha=alpha_l) 157 | 158 | def encoding_mask_noise(self, g, x, mask_rate=0.3): 159 | num_nodes = g.num_nodes() 160 | perm = torch.randperm(num_nodes, device=x.device) 161 | num_mask_nodes = int(mask_rate * num_nodes) 162 | 163 | # random masking 164 | num_mask_nodes = int(mask_rate * num_nodes) 165 | mask_nodes = perm[:num_mask_nodes] 166 | keep_nodes = perm[num_mask_nodes:] 167 | 168 | if self._replace_rate > 0: 169 | num_noise_nodes = int(self._replace_rate * num_mask_nodes) 170 | perm_mask = torch.randperm(num_mask_nodes, device=x.device) 171 | token_nodes = mask_nodes[ 172 | perm_mask[: int(self._mask_token_rate * num_mask_nodes)] 173 | ] 174 | noise_nodes = mask_nodes[ 175 | perm_mask[-int(self._replace_rate * num_mask_nodes) :] 176 | ] 177 | noise_to_be_chosen = torch.randperm(num_nodes, device=x.device)[ 178 | :num_noise_nodes 179 | ] 180 | 181 | out_x = x.clone() 182 | out_x[token_nodes] = 0.0 183 | out_x[noise_nodes] = x[noise_to_be_chosen] 184 | else: 185 | out_x = x.clone() 186 | token_nodes = mask_nodes 187 | out_x[mask_nodes] = 0.0 188 | 189 | out_x[token_nodes] += 
self.enc_mask_token 190 | use_g = g.clone() 191 | 192 | return use_g, out_x, (mask_nodes, keep_nodes) 193 | 194 | def forward(self, g, x, e): 195 | # ---- attribute reconstruction ---- 196 | loss = self.mask_attr_prediction(g, x, e) 197 | loss_item = {"loss": loss.item()} 198 | return loss, loss_item 199 | 200 | def mask_attr_prediction(self, g, x, e): 201 | pre_use_g, use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise( 202 | g, x, self._mask_rate 203 | ) 204 | 205 | if self._drop_edge_rate > 0: 206 | use_g, use_e, masked_edges = drop_edge( 207 | pre_use_g, self._drop_edge_rate, e, return_edges=True 208 | ) 209 | else: 210 | use_g = pre_use_g 211 | use_e = e 212 | 213 | enc_rep, enc_rep_e, all_hidden = self.encoder( 214 | use_g, use_x, use_e, return_hidden=True 215 | ) 216 | 217 | if self._concat_hidden: ## Should be false! 218 | enc_rep = torch.cat(all_hidden, dim=1) 219 | 220 | # ---- attribute reconstruction ---- 221 | rep = self.encoder_to_decoder(enc_rep) 222 | rep_e = self.encoder_to_decoder_e(enc_rep_e) 223 | 224 | rep[mask_nodes] = 0 225 | 226 | recon, recon_e = self.decoder(use_g, rep, rep_e) 227 | 228 | x_init = x[mask_nodes] 229 | x_rec = recon[mask_nodes] 230 | 231 | loss = self.criterion(x_rec, x_init) 232 | return loss 233 | 234 | def embed(self, g, x, e): 235 | rep = self.encoder(g, x, e) 236 | return rep 237 | 238 | @property 239 | def enc_params(self): 240 | return self.encoder.parameters() 241 | 242 | @property 243 | def dec_params(self): 244 | return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()]) 245 | -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # CompanyKG benchmarks 2 | 3 | The code in this directory, released as a PyPi package `companykg-benchmarks`, implements all model 4 | training are required code for experimentation that constitutes the initial benchmarking of 
the 5 | CompanyKG knowledge graph. 6 | 7 | 8 | ## Setup 9 | 10 | You will need Python>=3.8. 11 | 12 | All other dependencies needed to run the full model training and evaluation are 13 | covered by `setup.py`, so just follow instructions to install the `ckg_benchmarks` 14 | package in your virtual environment. 15 | 16 | Note that we also depend on the CompanyKG package `companykg`, the main package 17 | provided at the top level in this same repository. 18 | See [the main README file](../README.md) for instructions. 19 | 20 | 21 | Activate a Python virtual environment (such as 22 | [Virtualenv](https://github.com/pyenv/pyenv-virtualenv) or Conda). 23 | `cd` to the directory containing this `README` (`/benchmarks/` in the repository). 24 | Install the `companykg-benchmarks` package and its dependencies using: 25 | ```bash 26 | pip install -e . 27 | ``` 28 | 29 | Note that this package has a lot more dependencies than `companykg`, since it 30 | needs all the machine learning libraries used to train models. 31 | 32 | 33 | ## Benchmark models 34 | 35 | `companykg-benchmarks` provides training routines for the following graph learning-based 36 | graph node encoding models: 37 | * GRACE (GCL) 38 | * MVGRL (GCL) 39 | * GraphSAGE 40 | * GraphMAE 41 | * eGraphMAE 42 | 43 | Each model, once trained on the CompanyKG graph, can be used to produce a new embedding 44 | to represent each node of the graph. These can be used to measure the similarity of companies 45 | and thus applied to the three evaluation tasks: 46 | * Similarity Prediction (SP) 47 | * Similarity Ranking (SR) 48 | * Competitor Retrieval (CR) 49 | 50 | Each model is trained on the CompanyKG graph, loaded using the `companykg` package. Once 51 | this is downloaded, the training routines can be pointed to the local dataset using the 52 | `--data-root-folder` option. 53 | 54 | The models all have the same training interface and training can be run from the command line 55 | or programmatically, e.g. 
in a Jupyter notebook. We provide examples of both below. 56 | 57 | 58 | ### Command-line training 59 | 60 | You can run training from the command line in the following form (from within a 61 | virtual environment with the `companykg-benchmarks` package installed): 62 | ```bash 63 | python -m ckg_benchmarks.<model>.train <options>... 64 | ``` 65 | where `<model>` is `gcl` (GRACE and MVGRL), `graphsage`, `graphmae` or `egraphmae`. 66 | 67 | The remaining options control data location, training options and model hyperparameters. 68 | See the `--help` for each command for more details. 69 | 70 | Examples of training commands for each model type are provided below and in the `tutorials` 71 | directory. 72 | 73 | ### Python training 74 | 75 | Each model's `train` module contains a `train_model` function that can be called to train 76 | and return a trainer instance, which includes the trained model. 77 | The function's keyword arguments match the options of the command-line 78 | interface. 79 | 80 | ```python 81 | from ckg_benchmarks.<model>.train import train_model 82 | 83 | trainer = train_model( 84 | data_root_folder="path/to/data", 85 | n_layer=3, 86 | ... 87 | ) 88 | # Final trained model is available as: 89 | trainer.model 90 | ``` 91 | 92 | See below for examples of training with specific methods. 93 | 94 | The notebook [gcl_train](../tutorials/gcl_train.ipynb) provides a full example of 95 | training and evaluating with one method. 96 | 97 | 98 | 99 | ### Examples 100 | 101 | For each of the training methods, we show below an example of 102 | how to run training from the command line and an equivalent 103 | example in Python code. 104 | 105 | Note that all examples use extremely limited hyperparameters, so the 106 | resulting model will not be good, but can be trained with small memory 107 | in a short time. In practice, you would want to adjust the hyperparameters 108 | to something more like the published model selection results.
109 | 110 | 111 | #### GCL (GRACE and MVGRL) 112 | 113 | Train with GRACE from the command line: 114 | ```bash 115 | python -m ckg_benchmarks.gcl.train \ 116 | --device -1 \ 117 | --method grace \ 118 | --n-layer 1 \ 119 | --embedding-dim 8 \ 120 | --epochs 1 \ 121 | --sampler-edges 2 \ 122 | --batch-size 128 123 | ``` 124 | 125 | `device=-1` forces use of CPU. Select a different device number 126 | to use a GPU. 127 | 128 | To train MVGRL, simply change the `method` parameter: 129 | ```bash 130 | python -m ckg_benchmarks.gcl.train \ 131 | --device -1 \ 132 | --method mvgrl \ 133 | --n-layer 1 \ 134 | --embedding-dim 8 \ 135 | --epochs 1 \ 136 | --sampler-edges 2 \ 137 | --batch-size 128 138 | ``` 139 | 140 | To do the same thing from Python code: 141 | ```python 142 | from ckg_benchmarks.gcl.train import train_model 143 | trainer = train_model( 144 | # Use CPU 145 | device=-1, 146 | # Train with GRACE; you can also use 'mvgrl' here 147 | method="grace", 148 | # Typically we use 2 or 3 149 | n_layer=1, 150 | # Minimum value we usually consider is 8 151 | embedding_dim=8, 152 | # For our experiments we trained for 100 epochs, here just 1 for testing 153 | epochs=1, 154 | # We usually sample 5 or 10 edges for training 155 | sampler_edges=2, 156 | # For GPU you'll want to set your batch size bigger if you can, as it makes it faster 157 | batch_size=128, 158 | ) 159 | ``` 160 | 161 | 162 | #### GraphSAGE 163 | 164 | Train from the command line: 165 | ```bash 166 | python -m ckg_benchmarks.graphsage.train \ 167 | --device -1 \ 168 | --n-layer 2 \ 169 | --embedding_dim 8 \ 170 | --epochs 1 \ 171 | --train-batch-size 256 \ 172 | --inference-batch-size 256 \ 173 | --n-sample-neighbor 2 174 | ``` 175 | 176 | The same thing from Python code: 177 | ```python 178 | from ckg_benchmarks.graphsage.train import train_model 179 | trainer = train_model( 180 | device=-1, 181 | n_layer=2, 182 | embedding_dim=8, 183 | epochs=1, 184 | train_batch_size=256, 185 | inference_batch_size=256,
186 | n_sample_neighbor=2, 187 | ) 188 | ``` 189 | 190 | 191 | 192 | #### GraphMAE 193 | 194 | Train from command line: 195 | ```bash 196 | python -m ckg_benchmarks.graphmae.train \ 197 | --device -1 \ 198 | --n-layer 2 \ 199 | --embedding_dim 8 \ 200 | --epochs 1 \ 201 | --disable-metis 202 | ``` 203 | 204 | The same thing from Python code: 205 | ```python 206 | from ckg_benchmarks.graphmae.train import train_model 207 | trainer = train_model( 208 | device=-1, 209 | n_layer=2, 210 | embedding_dim=8, 211 | epochs=1, 212 | disable_metis=True, 213 | ) 214 | ``` 215 | 216 | `disable_metis` is useful for training with small memory, but you may want to 217 | drop this option for more efficient training. 218 | 219 | 220 | 221 | #### eGraphMAE 222 | 223 | Train from command line: 224 | ```bash 225 | python -m ckg_benchmarks.egraphmae.train \ 226 | --device -1 \ 227 | --n-layer 2 \ 228 | --embedding-dim 8 \ 229 | --epochs 1 \ 230 | --disable-metis 231 | ``` 232 | 233 | The same thing from Python code: 234 | ```python 235 | from ckg_benchmarks.egraphmae.train import train_model 236 | trainer = train_model( 237 | device=-1, 238 | n_layer=2, 239 | embedding_dim=8, 240 | epochs=1, 241 | disable_metis=True, 242 | ) 243 | ``` 244 | 245 | 246 | 247 | ### Evaluating trained models 248 | 249 | Evaluation on all three tasks is implemented in the `companykg` package. 250 | As part of the training routine, the full graph's node embeddings are projected 251 | into the GNN's embedding space and the resulting embeddings are output to a file. 252 | It is therefore easy to evaluate these embeddings using CompanyKG. 253 | 254 | When training is run from the command line, the path to which the embeddings 255 | are output is printed at the end of training. 
You can use the `companykg.eval` 256 | tool to run all evaluation tasks on the embeddings: 257 | ```bash 258 | python -m companykg.eval <path-to-embeddings> 259 | ``` 260 | 261 | When training is run from Python, the returned trainer instance provides a method 262 | `evaluate()` that runs the CompanyKG evaluation method on the projected embeddings. 263 | The evaluation results will be output to stdout (unless you specify `silent=True`) 264 | and the returned dictionary contains the results for the tasks. 265 | ```python 266 | from ckg_benchmarks.gcl.train import train_model 267 | 268 | trainer = train_model(...) 269 | results_dict = trainer.evaluate() 270 | ``` 271 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/gcl/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import argparse 8 | import logging 9 | 10 | import GCL.losses as L 11 | import torch 12 | from GCL.models import DualBranchContrast 13 | from torch.optim import Adam 14 | from torch_geometric.loader import NeighborLoader 15 | from tqdm.auto import tqdm 16 | 17 | from ckg_benchmarks.base import BaseTrainer 18 | from ckg_benchmarks.gcl.grace import GraceEncoder 19 | from ckg_benchmarks.gcl.mvgrl import MvgrlEncoder 20 | from ckg_benchmarks.utils import ranged_type 21 | from companykg import CompanyKG 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def grace_train_step(encoder_model, contrast_model, batch, optimizer, batch_size): 27 | optimizer.zero_grad() 28 | z, z1, z2 = encoder_model(batch.x, batch.edge_index) 29 | h1, h2 = [encoder_model.project(x) for x in [z1, z2]] 30 | h1 = h1[:batch_size] 31 | h2 = h2[:batch_size] 32 | loss = contrast_model(h1, h2) 33 | loss.backward() 34 | optimizer.step() 35 | return
loss.item() 36 | 37 | 38 | def mvgrl_train_step(encoder_model, contrast_model, data, optimizer, batch_size): 39 | optimizer.zero_grad() 40 | z1, z2, g1, g2, z1n, z2n = encoder_model(data.x, data.edge_index, batch_size) 41 | loss = contrast_model(h1=z1, h2=z2, g1=g1, g2=g2, h3=z1n, h4=z2n) 42 | loss.backward() 43 | optimizer.step() 44 | return loss.item() 45 | 46 | 47 | class GclTrainer(BaseTrainer): 48 | def __init__( 49 | self, 50 | method: str = "grace", 51 | edge_weights: bool = False, 52 | embedding_dim: int = 64, 53 | n_layer: int = 2, 54 | sampler_edges: int = 5, 55 | batch_size: int = 16, 56 | learning_rate: float = None, 57 | epochs: int = 500, 58 | **kwargs, 59 | ): 60 | self.method = method 61 | if self.method not in ["grace", "mvgrl"]: 62 | raise ValueError("training method should be 'grace' or 'mvgrl'") 63 | self.training_method_name = self.method 64 | 65 | if "finetune_from" in kwargs: 66 | raise ValueError( 67 | "'finetune_from' is not currently supported for GCL methods" 68 | ) 69 | self.edge_weights = edge_weights 70 | self.batch_size = batch_size 71 | self.sampler_edges = sampler_edges 72 | self.embedding_dim = embedding_dim 73 | self.n_layer = n_layer 74 | self.epochs = epochs 75 | 76 | # This will be set by the model init 77 | self.contrast_model = None 78 | 79 | if learning_rate is None: 80 | # Use different defaults for GRACE and MVGRL 81 | learning_rate = 0.01 if self.method == "mvgrl" else 0.001 82 | self.learning_rate = learning_rate 83 | 84 | if self.method == "grace": 85 | self.train_step = grace_train_step 86 | else: 87 | self.train_step = mvgrl_train_step 88 | 89 | super().__init__(**kwargs) 90 | 91 | def load_companykg(self): 92 | """ 93 | Load CompanyKG dataset in preparation for training 94 | 95 | """ 96 | return CompanyKG( 97 | nodes_feature_type=self.nodes_feature_type, 98 | load_edges_weights=self.edge_weights, 99 | data_root_folder=self.data_root_folder, 100 | ) 101 | 102 | @property 103 | def hparams_str(self): 104 | return 
"_".join( 105 | str(x) 106 | for x in [ 107 | self.comkg.nodes_feature_type, 108 | self.epochs, 109 | self.n_layer, 110 | self.embedding_dim, 111 | self.sampler_edges, 112 | self.batch_size, 113 | self.seed, 114 | ] 115 | ) 116 | 117 | def build_graph(self): 118 | # Convert to PyG graph 119 | return self.comkg.to_pyg() 120 | 121 | def init_model(self, load=None): 122 | if self.method == "mvgrl": 123 | model = MvgrlEncoder( 124 | self.comkg.nodes_feature_dim, self.embedding_dim, self.n_layer 125 | ) 126 | self.contrast_model = DualBranchContrast(loss=L.JSD(), mode="G2L") 127 | else: 128 | model = GraceEncoder( 129 | self.comkg.nodes_feature_dim, 130 | self.embedding_dim, 131 | self.n_layer, 132 | self.embedding_dim, 133 | ) 134 | self.contrast_model = DualBranchContrast( 135 | loss=L.InfoNCE(tau=0.2), mode="L2L", intraview_negs=True 136 | ) 137 | 138 | optimizer = Adam(model.parameters(), lr=self.learning_rate) 139 | return model, optimizer, None 140 | 141 | def _train_model(self): 142 | data_loader = NeighborLoader( 143 | self.graph, 144 | num_neighbors=[self.sampler_edges] * self.n_layer, 145 | batch_size=self.batch_size, 146 | shuffle=True, 147 | ) 148 | 149 | self.model.train() 150 | for epoch in range(self.epochs): 151 | logger.info(f"Starting epoch {epoch + 1}") 152 | with tqdm( 153 | total=len(data_loader), 154 | desc=f"Epoch {epoch + 1}/{self.epochs}", 155 | disable=None, 156 | ) as pbar: 157 | epoch_loss = 0 158 | for bnum, batch in enumerate(data_loader): 159 | batch.to(self.device) 160 | loss = self.train_step( 161 | self.model, 162 | self.contrast_model, 163 | batch, 164 | self.optimizer, 165 | self.batch_size, 166 | ) 167 | epoch_loss += loss 168 | 169 | pbar.set_postfix({"loss": epoch_loss / (bnum + 1)}) 170 | pbar.update() 171 | logger.info(f"Epoch {epoch+1} loss: {epoch_loss}") 172 | 173 | def inference(self) -> torch.Tensor: 174 | self.model.to(self.device) 175 | return self.model.encode( 176 | self.graph, batch_size=self.batch_size, 
device=self.device 177 | ) 178 | 179 | 180 | # Alias for initializing and training a model 181 | train_model = GclTrainer.train 182 | 183 | 184 | if __name__ == "__main__": 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument( 187 | "--method", 188 | type=str, 189 | default="grace", 190 | choices=["grace", "mvgrl"], 191 | help="The training method to use: grace or mvgrl", 192 | ) 193 | parser.add_argument( 194 | "--epochs", 195 | default=1000, 196 | type=ranged_type(int, 1, 2000), 197 | help="The max number of training epochs", 198 | ) 199 | parser.add_argument( 200 | "--nodes-feature-type", 201 | type=str, 202 | default="msbert", 203 | choices=["msbert", "ada2", "simcse", "pause"], 204 | help="The type of nodes feature: msbert, ada2, simcse or pause", 205 | ) 206 | parser.add_argument( 207 | "--device", 208 | default=0, 209 | type=int, 210 | help="The device used to carry out the training: use cpu when less than 0", 211 | ) 212 | parser.add_argument( 213 | "--data-root-folder", 214 | default="./data", 215 | type=str, 216 | help="The root folder where the CompanyKG data is downloaded to", 217 | ) 218 | parser.add_argument( 219 | "--work-folder", 220 | default="./experiments", 221 | type=str, 222 | help="The working folder where models and logs are saved to", 223 | ) 224 | parser.add_argument( 225 | "--n-layer", 226 | default=2, 227 | type=ranged_type(int, 2, 4), 228 | help="The number of GNN layers", 229 | ) 230 | parser.add_argument( 231 | "--learning-rate", 232 | type=ranged_type(float, 0.00001, 0.01), 233 | help="The training learning rate. 
Defaults: 0.01 (MVGRL), 0.001 (GRACE)", 234 | ) 235 | parser.add_argument( 236 | "--embedding-dim", 237 | default=64, 238 | type=ranged_type(int, 8, 512), 239 | help="The dimension of the node embedding to be learned", 240 | ) 241 | parser.add_argument( 242 | "--seed", 243 | default=42, 244 | type=int, 245 | help="The seed used to run the experiment", 246 | ) 247 | parser.add_argument( 248 | "--batch-size", 249 | default=16, 250 | type=int, 251 | help="Batch size to use for training and inference", 252 | ) 253 | parser.add_argument( 254 | "--sampler-edges", 255 | default=5, 256 | type=int, 257 | help="Number of neighbors to sample to build training batches", 258 | ) 259 | parser.add_argument( 260 | "--edge-weights", 261 | action="store_true", 262 | help="Load edge weights for training", 263 | ) 264 | opts = parser.parse_args() 265 | 266 | root_logger = logging.getLogger() 267 | root_logger.setLevel(logging.INFO) 268 | log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s") 269 | console_handler = logging.StreamHandler() 270 | console_handler.setFormatter(log_formatter) 271 | root_logger.addHandler(console_handler) 272 | 273 | # Create trainer and run 274 | logger.info("Initializing trainer") 275 | trainer = GclTrainer( 276 | nodes_feature_type=opts.nodes_feature_type, 277 | data_root_folder=opts.data_root_folder, 278 | method=opts.method, 279 | edge_weights=opts.edge_weights, 280 | embedding_dim=opts.embedding_dim, 281 | n_layer=opts.n_layer, 282 | sampler_edges=opts.sampler_edges, 283 | batch_size=opts.batch_size, 284 | learning_rate=opts.learning_rate, 285 | epochs=opts.epochs, 286 | seed=opts.seed, 287 | device=opts.device, 288 | work_folder=opts.work_folder, 289 | ) 290 | 291 | logger.info("Starting model training") 292 | trainer.train_model() 293 | logger.info("Training complete") 294 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/graphsage/train.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import argparse 8 | import logging 9 | import os 10 | 11 | import dgl 12 | import torch 13 | 14 | from ckg_benchmarks.base import BaseTrainer 15 | from ckg_benchmarks.graphsage.model import GraphSAGE, DotPredictor 16 | from ckg_benchmarks.utils import ranged_type 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class GraphSageTrainer(BaseTrainer): 22 | training_method_name = "graphsage" 23 | 24 | def __init__( 25 | self, 26 | embedding_dim: int = 64, 27 | n_layer: int = 2, 28 | dropout_rate: float = 0.1, 29 | learning_rate: float = 0.001, 30 | epochs: int = 2, 31 | train_batch_size: int = 2048, 32 | inference_batch_size: int = 2048, 33 | n_sample_neighbor: int = 8, 34 | **kwargs 35 | ): 36 | if "finetune_from" in kwargs: 37 | raise ValueError("'finetune_from' is not currently supported for GCL methods") 38 | self.embedding_dim = embedding_dim 39 | self.n_layer = n_layer 40 | self.dropout_rate = dropout_rate 41 | self.learning_rate = learning_rate 42 | self.epochs = epochs 43 | self.train_batch_size = train_batch_size 44 | self.inference_batch_size = inference_batch_size 45 | self.n_sample_neighbor = n_sample_neighbor 46 | # This will get set by the model init 47 | self.predictor = None 48 | 49 | super().__init__(**kwargs) 50 | 51 | # Create training sampler, also used for inference 52 | self.negative_sampler = dgl.dataloading.negative_sampler.Uniform(1) 53 | self.sampler = dgl.dataloading.NeighborSampler( 54 | [self.n_sample_neighbor for _ in range(self.n_layer)] 55 | ) 56 | 57 | @property 58 | def hparams_str(self): 59 | return "{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}_{8}".format( 60 | self.nodes_feature_type, 61 | self.epochs, 62 | self.n_sample_neighbor, 63 | self.train_batch_size, 64 | self.n_layer, 
65 | self.dropout_rate, 66 | self.embedding_dim, 67 | self.learning_rate, 68 | self.seed, 69 | ) 70 | 71 | def build_graph(self): 72 | # Create DGL graph 73 | graph = self.comkg.get_dgl_graph(self.work_folder)[0] 74 | graph = dgl.add_reverse_edges(graph) 75 | return graph 76 | 77 | def init_model(self, load=None): 78 | # Create the model 79 | model = GraphSAGE( 80 | n_layer=self.n_layer, 81 | in_feats=self.comkg.nodes_feature_dim, 82 | h_feats=self.embedding_dim, 83 | dropout=self.dropout_rate, 84 | ) 85 | logger.info(model) 86 | predictor = DotPredictor() 87 | optimizer = torch.optim.Adam( 88 | list(model.parameters()) + list(predictor.parameters()), lr=self.learning_rate 89 | ) 90 | self.predictor = predictor 91 | # We don't use a scheduler, so just return None 92 | return model, optimizer, None 93 | 94 | def _train_model(self): 95 | train_dataloader = dgl.dataloading.DataLoader( 96 | self.graph, 97 | torch.arange(self.graph.number_of_edges()), 98 | dgl.dataloading.as_edge_prediction_sampler( 99 | self.sampler, negative_sampler=self.negative_sampler 100 | ), 101 | device=self.device, 102 | batch_size=self.train_batch_size, 103 | shuffle=True, 104 | drop_last=False, 105 | num_workers=2, 106 | ) 107 | total_steps = len(train_dataloader) 108 | 109 | self.predictor.to(self.device) 110 | 111 | # Start training loop 112 | for epoch in range(self.epochs): 113 | logger.info(f"Starting epoch {epoch+1}/{self.epochs}") 114 | self.model.train() 115 | for step, (_, pos_graph, neg_graph, mfgs) in enumerate(train_dataloader): 116 | inputs = mfgs[0].srcdata["feat"] 117 | outputs = self.model(mfgs, inputs) 118 | pos_score = self.predictor(pos_graph, outputs) 119 | neg_score = self.predictor(neg_graph, outputs) 120 | score = torch.cat([pos_score, neg_score]) 121 | label = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)]) 122 | loss = torch.nn.functional.binary_cross_entropy_with_logits(score, label) 123 | self.optimizer.zero_grad() 124 | loss.backward() 125 | 
self.optimizer.step() 126 | logger.info(f"{step}/{total_steps} of epoch {epoch}: loss={loss.item()}") 127 | 128 | # Save the trained model 129 | epoch_snapshot_name = f"{self.hparams_str}_e{epoch}" 130 | model_save_path = os.path.join(self.work_folder, f"{epoch_snapshot_name}.pth") 131 | torch.save(self.model, model_save_path) 132 | logger.info(f"Model saved to {model_save_path}") 133 | 134 | def inference(self) -> torch.Tensor: 135 | """The inference (prediction) function for GraphSAGE model. 136 | 137 | Args: 138 | model (GraphSAGE): a GraphSAGE model. 139 | g (dgl.DGLGraph): the input graph to be predicted. 140 | sampler (dgl.dataloading.neighbor_sampler.NeighborSampler): a sampler that 141 | should be exactly the same as the one used in training. 142 | batch_size (int): the batch size for inference mini-batches. 143 | device (str): the device (cuda or cpu) to run the inference. 144 | 145 | Returns: 146 | torch.Tensor: the predicted node embeddings. 147 | """ 148 | self.model.eval() 149 | with torch.no_grad(): 150 | _dataloader = dgl.dataloading.DataLoader( 151 | self.graph, 152 | torch.arange(self.graph.number_of_nodes()), 153 | self.sampler, 154 | batch_size=self.inference_batch_size, 155 | shuffle=False, 156 | drop_last=False, 157 | num_workers=0, # num_workers > 0 will cause evaluation randomness 158 | device=self.device, 159 | ) 160 | result = [] 161 | for _, _, mfgs in _dataloader: 162 | inputs = mfgs[0].srcdata["feat"] 163 | result.append(self.model(mfgs, inputs)) 164 | return torch.cat(result) 165 | 166 | 167 | train_model = GraphSageTrainer.train 168 | 169 | 170 | if __name__ == "__main__": 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument( 173 | "--epochs", 174 | default=2, 175 | type=ranged_type(int, 1, 100), 176 | help="The max number of training epochs", 177 | ) 178 | parser.add_argument( 179 | "--nodes-feature-type", 180 | type=str, 181 | default="msbert", 182 | choices=["msbert", "ada2", "simcse", "pause"], 183 | help="The type of 
nodes feature: msbert, ada2, simcse or pause", 184 | ) 185 | parser.add_argument( 186 | "--device", 187 | default=0, 188 | type=int, 189 | help="The device used to carry out the training: use cpu when less than 0", 190 | ) 191 | parser.add_argument( 192 | "--data-root-folder", 193 | default="./data", 194 | type=str, 195 | help="The root folder where the CompanyKG data is downloaded to", 196 | ) 197 | parser.add_argument( 198 | "--work-folder", 199 | default="./experiments", 200 | type=str, 201 | help="The working folder where models and logs are saved to", 202 | ) 203 | parser.add_argument( 204 | "--n-sample-neighbor", 205 | default=8, 206 | type=ranged_type(int, 2, 64), 207 | help="The number of neighbor to be sampled", 208 | ) 209 | parser.add_argument( 210 | "--train-batch-size", 211 | default=2048, 212 | type=ranged_type(int, 16, 2**15), 213 | help="The number of samples in each training mini-batch", 214 | ) 215 | parser.add_argument( 216 | "--inference-batch-size", 217 | default=2048, 218 | type=ranged_type(int, 16, 2**15), 219 | help="The number of samples in each inference mini-batch", 220 | ) 221 | parser.add_argument( 222 | "--n-layer", 223 | default=2, 224 | type=ranged_type(int, 2, 4), 225 | choices=range(2, 4), 226 | help="The number of GNN layers", 227 | ) 228 | parser.add_argument( 229 | "--dropout-rate", 230 | default=0.1, 231 | type=ranged_type(float, 0.0, 0.3), 232 | help="The feature dropout rate of GNN layers", 233 | ) 234 | parser.add_argument( 235 | "--learning-rate", 236 | default=0.001, 237 | type=ranged_type(float, 0.00001, 0.01), 238 | help="The training learning rate", 239 | ) 240 | parser.add_argument( 241 | "--embedding-dim", 242 | default=128, 243 | type=ranged_type(int, 8, 1024), 244 | help="The dimension of the embedding to be learned", 245 | ) 246 | parser.add_argument( 247 | "--seed", 248 | default=42, 249 | type=int, 250 | help="The seed used to run the experiment", 251 | ) 252 | opts = parser.parse_args() 253 | 254 | root_logger = 
logging.getLogger() 255 | root_logger.setLevel(logging.INFO) 256 | log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s") 257 | console_handler = logging.StreamHandler() 258 | console_handler.setFormatter(log_formatter) 259 | root_logger.addHandler(console_handler) 260 | 261 | logger.info("Initializing trainer") 262 | trainer = GraphSageTrainer( 263 | nodes_feature_type=opts.nodes_feature_type, 264 | data_root_folder=opts.data_root_folder, 265 | embedding_dim=opts.embedding_dim, 266 | n_layer=opts.n_layer, 267 | dropout_rate=opts.dropout_rate, 268 | learning_rate=opts.learning_rate, 269 | epochs=opts.epochs, 270 | seed=opts.seed, 271 | device=opts.device, 272 | work_folder=opts.work_folder, 273 | train_batch_size=opts.train_batch_size, 274 | inference_batch_size=opts.inference_batch_size, 275 | n_sample_neighbor=opts.n_sample_neighbor, 276 | ) 277 | 278 | logger.info("Starting model training") 279 | trainer.train_model() 280 | logger.info("Training complete") 281 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/egraphmae/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import argparse 8 | import logging 9 | import os 10 | 11 | import dgl 12 | import numpy as np 13 | import torch 14 | 15 | from ckg_benchmarks.egraphmae.model import EGraphMAE 16 | from ckg_benchmarks.graphmae.train import GraphMAETrainer 17 | from ckg_benchmarks.utils import ranged_type 18 | from companykg.kg import CompanyKG 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class EGraphMAETrainer(GraphMAETrainer): 24 | """ 25 | Subclass the GraphMAETrainer to inherit most of the training 26 | procedure from it. We override certain parts here that are 27 | specific to eGraphMAE. 
28 | 29 | """ 30 | training_method_name = "egraphmae" 31 | 32 | def __init__(self, edge_hidden_dim=32, **kwargs): 33 | self.edge_hidden_dim = edge_hidden_dim 34 | super().__init__(**kwargs) 35 | 36 | def load_companykg(self): 37 | # Override to load with edge weights 38 | return CompanyKG( 39 | nodes_feature_type=self.nodes_feature_type, 40 | load_edges_weights=True, 41 | data_root_folder=self.data_root_folder, 42 | ) 43 | 44 | def build_graph(self): 45 | graph = self.comkg.get_dgl_graph(self.work_folder)[0] 46 | graph.edata["weight"] = graph.edata["weight"] / np.linalg.norm( 47 | graph.edata["weight"], axis=1, keepdims=True 48 | ) 49 | graph = dgl.add_reverse_edges(graph, copy_edata=True) 50 | graph = dgl.remove_self_loop(graph) 51 | graph = dgl.add_self_loop(graph, fill_data=1.0) 52 | # The following code for creating a subgraph for SR evaluation that 53 | # can be used in early stopping was in the original implementation, 54 | # but wasn't being used. 55 | # I'm leaving it here so we can revive this efficiency trick if we 56 | # need to later 57 | # But note that it happened before the add_reverse_edges, etc (which 58 | # are then applied separately to the subgraph) 59 | """ 60 | # Create a subgraph for evaluation: EGraphMAE requires significantly more memory 61 | # hence can only inference on a subgraph on CPU during training. 
62 | sr_df = self.comkg.eval_tasks["sr"][1] 63 | sr_nids = list( 64 | set(sr_df.target_node_id.unique()) 65 | .union(set(sr_df.candidate0_node_id.unique())) 66 | .union(set(sr_df.candidate1_node_id.unique())) 67 | ) 68 | graph = self.comkg.get_dgl_graph(self.work_folder)[0] 69 | eval_g = dgl.merge( 70 | [ 71 | dgl.khop_in_subgraph(graph, sr_nids, k=opts.n_layer)[0], 72 | dgl.khop_out_subgraph(graph, sr_nids, k=opts.n_layer)[0], 73 | ] 74 | ) 75 | eval_g = dgl.add_reverse_edges(eval_g, copy_edata=True) 76 | eval_g = dgl.remove_self_loop(eval_g) 77 | eval_g = dgl.add_self_loop(eval_g, fill_data=1.0) 78 | """ 79 | return graph 80 | 81 | @property 82 | def hparams_str(self): 83 | # We have one extra parameter 84 | return f"{super().hparams_str}_{self.edge_hidden_dim}" 85 | 86 | def inference(self) -> torch.Tensor: 87 | """ 88 | eGraphMAE inference at the moment only supports CPU inference, so doesn't 89 | pay attention to self.device. 90 | 91 | """ 92 | # Due to memory limitation we have to save and load to cpu for inference 93 | tmp_model_path = os.path.join(self.work_folder, "tmp_model.pth") 94 | torch.save(self.model, tmp_model_path) 95 | tmp_model = torch.load(tmp_model_path, map_location="cpu") 96 | tmp_model.eval() 97 | # Note that embed will return both node and edge embedding 98 | embed_dense = ( 99 | tmp_model.embed( 100 | self.graph, self.graph.ndata["feat"], self.graph.edata["weight"] 101 | )[0].cpu().detach() 102 | ) 103 | embed_dense / np.linalg.norm(embed_dense, axis=1, keepdims=True) 104 | if embed_dense.shape[0] == self.comkg.n_nodes: 105 | return embed_dense.numpy() 106 | else: 107 | # The evaluation logic requires full embedding matrix 108 | embed = torch.zeros(self.comkg.n_nodes, self.embedding_dim) 109 | for idx, x in enumerate(self.graph.ndata["_ID"].tolist()): 110 | embed[x, :] = embed_dense[idx, :] 111 | return embed.numpy() 112 | 113 | def init_model(self, load=None): 114 | # Initialize model and training 115 | if load is None: 116 | # 
Initialize a new model 117 | model = EGraphMAE( 118 | in_dim=self.comkg.nodes_feature_dim, 119 | num_hidden=self.embedding_dim, 120 | num_layers=self.n_layer, 121 | feat_drop=self.dropout_rate + 0.1, 122 | attn_drop=self.dropout_rate, 123 | nhead=self.n_heads, 124 | mask_rate=self.mask_rate, 125 | drop_edge_rate=self.drop_edge_rate, 126 | num_edge_features=self.comkg.edges_weight_dim, 127 | num_edge_hidden=self.edge_hidden_dim, 128 | ) 129 | else: 130 | # Load a pre-trained model 131 | model = torch.load(load) 132 | 133 | optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate) 134 | lr_schedule = lambda epoch: (1 + np.cos(epoch * np.pi / self.epochs)) * 0.5 135 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule) 136 | return model, optimizer, scheduler 137 | 138 | def training_loss(self, subgraph): 139 | loss, _ = self.model(subgraph, subgraph.ndata["feat"], subgraph.edata["weight"]) 140 | return loss 141 | 142 | 143 | train_model = EGraphMAETrainer.train 144 | 145 | 146 | if __name__ == "__main__": 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument( 149 | "--epochs", 150 | default=1000, 151 | type=ranged_type(int, 1, 2000), 152 | help="The max number of training epochs", 153 | ) 154 | parser.add_argument( 155 | "--nodes-feature-type", 156 | type=str, 157 | default="msbert", 158 | choices=["msbert", "ada2", "simcse", "pause"], 159 | help="The type of nodes feature: msbert, ada2, simcse or pause", 160 | ) 161 | parser.add_argument( 162 | "--device", 163 | default=0, 164 | type=int, 165 | help="The device used to carry out the training: use cpu when less than 0", 166 | ) 167 | parser.add_argument( 168 | "--data-root-folder", 169 | default="./data", 170 | type=str, 171 | help="The root folder where the CompanyKG data is downloaded to", 172 | ) 173 | parser.add_argument( 174 | "--finetune-from", 175 | type=str, 176 | help="The saved model to be finetuned from, ex. 
./msbert_1000_2_0.01_0.001_256_0.1_0.1_8_100_42.pth", 177 | ) 178 | parser.add_argument( 179 | "--work-folder", 180 | default="./experiments", 181 | type=str, 182 | help="The working folder where models and logs are saved to", 183 | ) 184 | parser.add_argument( 185 | "--n-layer", 186 | default=2, 187 | type=ranged_type(int, 2, 4), 188 | help="The number of GNN layers", 189 | ) 190 | parser.add_argument( 191 | "--dropout-rate", 192 | default=0.1, 193 | type=ranged_type(float, 0.0, 0.3), 194 | help="The feature dropout rate of GNN layers", 195 | ) 196 | parser.add_argument( 197 | "--learning-rate", 198 | default=0.001, 199 | type=ranged_type(float, 0.00001, 0.01), 200 | help="The training learning rate", 201 | ) 202 | parser.add_argument( 203 | "--embedding-dim", 204 | default=64, 205 | type=ranged_type(int, 8, 512), 206 | help="The dimension of the node embedding to be learned", 207 | ) 208 | parser.add_argument( 209 | "--drop-edge-rate", 210 | default=0.5, 211 | type=ranged_type(float, 0.1, 0.8), 212 | help="The rate of edges to be dropped during training", 213 | ) 214 | parser.add_argument( 215 | "--mask-rate", 216 | default=0.5, 217 | type=ranged_type(float, 0.1, 0.8), 218 | help="The rate of nodes feature to be masked during training", 219 | ) 220 | parser.add_argument( 221 | "--n-heads", 222 | default=8, 223 | type=ranged_type(int, 1, 8), 224 | help="The number of attention heads", 225 | ) 226 | parser.add_argument( 227 | "--n-lives", 228 | default=100, 229 | type=ranged_type(int, 10, 200), 230 | help="The number of training epochs allowed when evaluation metrics are not improving", 231 | ) 232 | parser.add_argument( 233 | "--edge-hidden-dim", 234 | default=32, 235 | type=ranged_type(int, 4, 64), 236 | help="The dimension of the edge embedding", 237 | ) 238 | parser.add_argument( 239 | "--seed", 240 | default=42, 241 | type=int, 242 | help="The seed used to run the experiment", 243 | ) 244 | parser.add_argument( 245 | "--disable-metis", 246 | 
action="store_true", 247 | help="Force trainer not to use Metis partitions, even when the embedding size " 248 | "is large" 249 | ) 250 | opts = parser.parse_args() 251 | 252 | root_logger = logging.getLogger() 253 | root_logger.setLevel(logging.INFO) 254 | log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s") 255 | console_handler = logging.StreamHandler() 256 | console_handler.setFormatter(log_formatter) 257 | root_logger.addHandler(console_handler) 258 | 259 | # Create trainer and run 260 | logger.info("Initializing trainer") 261 | trainer = EGraphMAETrainer( 262 | edge_hidden_dim=opts.edge_hidden_dim, 263 | nodes_feature_type=opts.nodes_feature_type, 264 | data_root_folder=opts.data_root_folder, 265 | embedding_dim=opts.embedding_dim, 266 | n_layer=opts.n_layer, 267 | dropout_rate=opts.dropout_rate, 268 | n_heads=opts.n_heads, 269 | mask_rate=opts.mask_rate, 270 | drop_edge_rate=opts.drop_edge_rate, 271 | learning_rate=opts.learning_rate, 272 | epochs=opts.epochs, 273 | n_lives=opts.n_lives, 274 | seed=opts.seed, 275 | device=opts.device, 276 | finetune_from=opts.finetune_from, 277 | work_folder=opts.work_folder, 278 | disable_metis=opts.disable_metis, 279 | ) 280 | 281 | logger.info("Starting model training") 282 | trainer.train_model() 283 | logger.info("Training complete") 284 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/egraphmae/egat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import dgl.function as fn 8 | import torch 9 | import torch.nn as nn 10 | from dgl.ops import edge_softmax 11 | from dgl.utils import expand_as_pair 12 | 13 | 14 | class EGAT(nn.Module): 15 | def __init__( 16 | self, 17 | in_dim, 18 | num_hidden, 19 | out_dim, 20 | 
num_layers, 21 | nhead, 22 | nhead_out, 23 | feat_drop, 24 | attn_drop, 25 | negative_slope, 26 | residual, 27 | norm, 28 | in_dim_e, 29 | num_hidden_e, 30 | out_dim_e, 31 | concat_out=False, 32 | encoding=False, 33 | ): 34 | super(EGAT, self).__init__() 35 | self.out_dim = out_dim 36 | self.num_heads = nhead 37 | self.num_layers = num_layers 38 | self.gat_layers = nn.ModuleList() 39 | self.activation = nn.PReLU() # fix to PReLU 40 | self.concat_out = concat_out 41 | self.in_dim_e = in_dim_e 42 | self.num_hidden_e = num_hidden_e 43 | self.out_dim_e = out_dim_e 44 | 45 | last_activation = self.activation if encoding else None 46 | last_residual = encoding and residual 47 | last_norm = norm if encoding else None 48 | 49 | if num_layers == 1: 50 | self.gat_layers.append( 51 | EGATConv( 52 | in_dim, 53 | in_dim_e, 54 | out_dim, 55 | out_dim_e, 56 | nhead_out, 57 | feat_drop, 58 | attn_drop, 59 | negative_slope, 60 | last_residual, 61 | norm=last_norm, 62 | ) 63 | ) 64 | else: 65 | # input projection (no residual) 66 | self.gat_layers.append( 67 | EGATConv( 68 | in_dim, 69 | in_dim_e, 70 | num_hidden, 71 | num_hidden_e, 72 | nhead, 73 | feat_drop, 74 | attn_drop, 75 | negative_slope, 76 | residual, 77 | self.activation, 78 | norm=norm, 79 | ) 80 | ) 81 | # hidden layers 82 | for l in range(1, num_layers - 1): 83 | # due to multi-head, the in_dim = num_hidden * num_heads 84 | self.gat_layers.append( 85 | EGATConv( 86 | num_hidden * nhead, 87 | num_hidden_e * nhead, 88 | num_hidden, 89 | num_hidden_e, 90 | nhead, 91 | feat_drop, 92 | attn_drop, 93 | negative_slope, 94 | residual, 95 | self.activation, 96 | norm=norm, 97 | ) 98 | ) 99 | # output projection 100 | self.gat_layers.append( 101 | EGATConv( 102 | num_hidden * nhead, 103 | num_hidden_e * nhead, 104 | out_dim, 105 | out_dim_e, 106 | nhead_out, 107 | feat_drop, 108 | attn_drop, 109 | negative_slope, 110 | last_residual, 111 | activation=last_activation, 112 | norm=last_norm, 113 | ) 114 | ) 115 | 116 | self.head = 
nn.Identity() 117 | 118 | def forward(self, g, inputs, eputs, return_hidden=False): 119 | h = inputs 120 | e = eputs 121 | hidden_list = [] 122 | 123 | for l in range(self.num_layers): 124 | h, e = self.gat_layers[l](g, h, e) 125 | hidden_list.append(h) 126 | # h = h.flatten(1) 127 | # output projection 128 | if return_hidden: 129 | return self.head(h), self.head(e), hidden_list 130 | else: 131 | return self.head(h), self.head(e) 132 | 133 | def reset_classifier(self, num_classes): 134 | self.head = nn.Linear(self.num_heads * self.out_dim, num_classes) 135 | 136 | 137 | class EGATConv(nn.Module): 138 | def __init__( 139 | self, 140 | in_node_feats, 141 | in_edge_feats, 142 | out_node_feats, 143 | out_edge_feats, 144 | num_heads, 145 | feat_drop=0.0, 146 | attn_drop=0.0, 147 | negative_slope=0.2, 148 | residual=False, 149 | activation=None, 150 | bias=True, 151 | norm=None, 152 | ): 153 | 154 | super().__init__() 155 | self._num_heads = num_heads 156 | self._in_src_node_feats, self._in_dst_node_feats = expand_as_pair(in_node_feats) 157 | self._out_node_feats = out_node_feats 158 | self._out_edge_feats = out_edge_feats 159 | 160 | if isinstance(in_node_feats, tuple): 161 | self.fc_node_src = nn.Linear( 162 | self._in_src_node_feats, out_node_feats * num_heads, bias=False 163 | ) 164 | self.fc_ni = nn.Linear( 165 | self._in_src_node_feats, out_edge_feats * num_heads, bias=False 166 | ) 167 | self.fc_nj = nn.Linear( 168 | self._in_dst_node_feats, out_edge_feats * num_heads, bias=False 169 | ) 170 | else: 171 | self.fc_node_src = nn.Linear( 172 | self._in_src_node_feats, out_node_feats * num_heads, bias=False 173 | ) 174 | self.fc_ni = nn.Linear( 175 | self._in_src_node_feats, out_edge_feats * num_heads, bias=False 176 | ) 177 | self.fc_nj = nn.Linear( 178 | self._in_src_node_feats, out_edge_feats * num_heads, bias=False 179 | ) 180 | 181 | self.fc_fij = nn.Linear(in_edge_feats, out_edge_feats * num_heads, bias=False) 182 | self.attn = 
nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_edge_feats))) 183 | self.feat_drop = nn.Dropout(feat_drop) 184 | self.attn_drop = nn.Dropout(attn_drop) 185 | self.leaky_relu = nn.LeakyReLU(negative_slope) 186 | if bias: 187 | self.bias = nn.Parameter( 188 | torch.FloatTensor(size=(num_heads * out_edge_feats,)) 189 | ) 190 | else: 191 | self.register_buffer("bias", None) 192 | 193 | if residual: 194 | if self._in_dst_node_feats != out_node_feats * num_heads: 195 | self.res_fc = nn.Linear( 196 | self._in_dst_node_feats, num_heads * out_node_feats, bias=False 197 | ) 198 | else: 199 | self.res_fc = nn.Identity() 200 | else: 201 | self.register_buffer("res_fc", None) 202 | 203 | self.reset_parameters() 204 | 205 | self.activation = activation 206 | 207 | self.norm = norm 208 | if norm is not None: 209 | self.norm = norm(num_heads * out_node_feats) 210 | self.norm_e = norm(num_heads * out_edge_feats) 211 | 212 | def reset_parameters(self): 213 | """ 214 | Reinitialize learnable parameters. 215 | """ 216 | gain = nn.init.calculate_gain("relu") 217 | nn.init.xavier_normal_(self.fc_node_src.weight, gain=gain) 218 | nn.init.xavier_normal_(self.fc_ni.weight, gain=gain) 219 | nn.init.xavier_normal_(self.fc_fij.weight, gain=gain) 220 | nn.init.xavier_normal_(self.fc_nj.weight, gain=gain) 221 | nn.init.xavier_normal_(self.attn, gain=gain) 222 | nn.init.constant_(self.bias, 0) 223 | if isinstance(self.res_fc, nn.Linear): 224 | nn.init.xavier_normal_(self.res_fc.weight, gain=gain) 225 | 226 | def forward(self, graph, nfeats, efeats, get_attention=False): 227 | 228 | with graph.local_scope(): 229 | if (graph.in_degrees() == 0).any(): 230 | raise RuntimeError( 231 | "There are 0-in-degree nodes in the graph, " 232 | "output for those nodes will be invalid. " 233 | "This is harmful for some applications, " 234 | "causing silent performance regression. " 235 | "Adding self-loop on the input graph by " 236 | "calling `g = dgl.add_self_loop(g)` will resolve " 237 | "the issue." 
238 | ) 239 | 240 | # calc edge attention 241 | # same trick way as in dgl.nn.pytorch.GATConv, but also includes edge feats 242 | # https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.py 243 | if isinstance(nfeats, tuple): 244 | nfeats_src, nfeats_dst = nfeats 245 | nfeats_src = self.feat_drop(nfeats_src) 246 | nfeats_dst = self.feat_drop(nfeats_dst) 247 | dst_prefix_shape = nfeats_dst.shape[:-1] 248 | else: 249 | nfeats_src = nfeats_dst = self.feat_drop(nfeats) 250 | dst_prefix_shape = nfeats.shape[:-1] 251 | 252 | f_ni = self.fc_ni(nfeats_src) 253 | f_nj = self.fc_nj(nfeats_dst) 254 | f_fij = self.fc_fij(efeats) 255 | 256 | graph.srcdata.update({"f_ni": f_ni}) 257 | graph.dstdata.update({"f_nj": f_nj}) 258 | # add ni, nj factors 259 | graph.apply_edges(fn.u_add_v("f_ni", "f_nj", "f_tmp")) 260 | # add fij to node factor 261 | f_out = graph.edata.pop("f_tmp") + f_fij 262 | # f_out = f_fij 263 | if self.bias is not None: 264 | f_out = f_out + self.bias 265 | f_out = self.leaky_relu(f_out) 266 | f_out = f_out.view(-1, self._num_heads, self._out_edge_feats) 267 | # compute attention factor 268 | e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1) 269 | graph.edata["a"] = self.attn_drop(edge_softmax(graph, e)) 270 | # graph.edata['a'] = e 271 | graph.srcdata["h_out"] = self.fc_node_src(nfeats_src).view( 272 | -1, self._num_heads, self._out_node_feats 273 | ) 274 | # calc weighted sum 275 | graph.update_all(fn.u_mul_e("h_out", "a", "m"), fn.sum("m", "h_out")) 276 | 277 | h_out = graph.dstdata["h_out"].view( 278 | -1, self._num_heads, self._out_node_feats 279 | ) 280 | 281 | # residual 282 | if self.res_fc is not None: 283 | # Use -1 rather than self._num_heads to handle broadcasting 284 | resval = self.res_fc(nfeats_dst).view( 285 | *dst_prefix_shape, -1, self._out_node_feats 286 | ) 287 | h_out = h_out + resval 288 | 289 | rst_h_out = h_out.flatten(1) 290 | rst_f_out = f_out.flatten(1) 291 | 292 | if self.norm is not None: 293 | rst_h_out = 
self.norm(rst_h_out) 294 | rst_f_out = self.norm_e(rst_f_out) 295 | 296 | # activation 297 | if self.activation: 298 | rst_h_out = self.activation(rst_h_out) 299 | rst_f_out = self.activation(rst_f_out) 300 | 301 | if get_attention: 302 | return rst_h_out, rst_f_out, graph.edata.pop("a") 303 | else: 304 | return rst_h_out, rst_f_out 305 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/graphmae/gat.py: -------------------------------------------------------------------------------- 1 | # This implementation is adapted from GraphMAE 2 | # https://github.com/THUDM/GraphMAE 3 | 4 | import dgl.function as fn 5 | import torch 6 | import torch.nn as nn 7 | from dgl.ops import edge_softmax 8 | from dgl.utils import expand_as_pair 9 | 10 | 11 | class GAT(nn.Module): 12 | def __init__( 13 | self, 14 | in_dim, 15 | num_hidden, 16 | out_dim, 17 | num_layers, 18 | nhead, 19 | nhead_out, 20 | feat_drop, 21 | attn_drop, 22 | negative_slope, 23 | residual, 24 | norm, 25 | concat_out=False, 26 | encoding=False, 27 | ): 28 | super(GAT, self).__init__() 29 | self.out_dim = out_dim 30 | self.num_heads = nhead 31 | self.num_layers = num_layers 32 | self.gat_layers = nn.ModuleList() 33 | self.activation = nn.PReLU() # fix to PReLU 34 | self.concat_out = concat_out 35 | 36 | last_activation = self.activation if encoding else None 37 | last_residual = encoding and residual 38 | last_norm = norm if encoding else None 39 | 40 | if num_layers == 1: 41 | self.gat_layers.append( 42 | GATConv( 43 | in_dim, 44 | out_dim, 45 | nhead_out, 46 | feat_drop, 47 | attn_drop, 48 | negative_slope, 49 | last_residual, 50 | norm=last_norm, 51 | concat_out=concat_out, 52 | ) 53 | ) 54 | else: 55 | # input projection (no residual) 56 | self.gat_layers.append( 57 | GATConv( 58 | in_dim, 59 | num_hidden, 60 | nhead, 61 | feat_drop, 62 | attn_drop, 63 | negative_slope, 64 | residual, 65 | self.activation, 66 | norm=norm, 67 | 
concat_out=concat_out, 68 | ) 69 | ) 70 | # hidden layers 71 | for l in range(1, num_layers - 1): 72 | # due to multi-head, the in_dim = num_hidden * num_heads 73 | self.gat_layers.append( 74 | GATConv( 75 | num_hidden * nhead, 76 | num_hidden, 77 | nhead, 78 | feat_drop, 79 | attn_drop, 80 | negative_slope, 81 | residual, 82 | self.activation, 83 | norm=norm, 84 | concat_out=concat_out, 85 | ) 86 | ) 87 | # output projection 88 | self.gat_layers.append( 89 | GATConv( 90 | num_hidden * nhead, 91 | out_dim, 92 | nhead_out, 93 | feat_drop, 94 | attn_drop, 95 | negative_slope, 96 | last_residual, 97 | activation=last_activation, 98 | norm=last_norm, 99 | concat_out=concat_out, 100 | ) 101 | ) 102 | 103 | self.head = nn.Identity() 104 | 105 | def forward(self, g, inputs, return_hidden=False): 106 | h = inputs 107 | hidden_list = [] 108 | for l in range(self.num_layers): 109 | h = self.gat_layers[l](g, h) 110 | hidden_list.append(h) 111 | # h = h.flatten(1) 112 | # output projection 113 | if return_hidden: 114 | return self.head(h), hidden_list 115 | else: 116 | return self.head(h) 117 | 118 | def reset_classifier(self, num_classes): 119 | self.head = nn.Linear(self.num_heads * self.out_dim, num_classes) 120 | 121 | 122 | class GATConv(nn.Module): 123 | def __init__( 124 | self, 125 | in_feats, 126 | out_feats, 127 | num_heads, 128 | feat_drop=0.0, 129 | attn_drop=0.0, 130 | negative_slope=0.2, 131 | residual=False, 132 | activation=None, 133 | allow_zero_in_degree=False, 134 | bias=True, 135 | norm=None, 136 | concat_out=True, 137 | ): 138 | super(GATConv, self).__init__() 139 | self._num_heads = num_heads 140 | self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) 141 | self._out_feats = out_feats 142 | self._allow_zero_in_degree = allow_zero_in_degree 143 | self._concat_out = concat_out 144 | 145 | if isinstance(in_feats, tuple): 146 | self.fc_src = nn.Linear( 147 | self._in_src_feats, out_feats * num_heads, bias=False 148 | ) 149 | self.fc_dst = 
nn.Linear( 150 | self._in_dst_feats, out_feats * num_heads, bias=False 151 | ) 152 | else: 153 | self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False) 154 | self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) 155 | self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats))) 156 | self.feat_drop = nn.Dropout(feat_drop) 157 | self.attn_drop = nn.Dropout(attn_drop) 158 | self.leaky_relu = nn.LeakyReLU(negative_slope) 159 | if bias: 160 | self.bias = nn.Parameter(torch.FloatTensor(size=(num_heads * out_feats,))) 161 | else: 162 | self.register_buffer("bias", None) 163 | if residual: 164 | if self._in_dst_feats != out_feats * num_heads: 165 | self.res_fc = nn.Linear( 166 | self._in_dst_feats, num_heads * out_feats, bias=False 167 | ) 168 | else: 169 | self.res_fc = nn.Identity() 170 | else: 171 | self.register_buffer("res_fc", None) 172 | self.reset_parameters() 173 | self.activation = activation 174 | 175 | self.norm = norm 176 | if norm is not None: 177 | self.norm = norm(num_heads * out_feats) 178 | 179 | def reset_parameters(self): 180 | """ 181 | 182 | Description 183 | ----------- 184 | Reinitialize learnable parameters. 185 | 186 | Note 187 | ---- 188 | The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization. 189 | The attention weights are using xavier initialization method. 
190 | """ 191 | gain = nn.init.calculate_gain("relu") 192 | if hasattr(self, "fc"): 193 | nn.init.xavier_normal_(self.fc.weight, gain=gain) 194 | else: 195 | nn.init.xavier_normal_(self.fc_src.weight, gain=gain) 196 | nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) 197 | nn.init.xavier_normal_(self.attn_l, gain=gain) 198 | nn.init.xavier_normal_(self.attn_r, gain=gain) 199 | if self.bias is not None: 200 | nn.init.constant_(self.bias, 0) 201 | if isinstance(self.res_fc, nn.Linear): 202 | nn.init.xavier_normal_(self.res_fc.weight, gain=gain) 203 | 204 | def set_allow_zero_in_degree(self, set_value): 205 | self._allow_zero_in_degree = set_value 206 | 207 | def forward(self, graph, feat, get_attention=False): 208 | with graph.local_scope(): 209 | if not self._allow_zero_in_degree: 210 | if (graph.in_degrees() == 0).any(): 211 | raise RuntimeError( 212 | "There are 0-in-degree nodes in the graph, " 213 | "output for those nodes will be invalid. " 214 | "This is harmful for some applications, " 215 | "causing silent performance regression. " 216 | "Adding self-loop on the input graph by " 217 | "calling `g = dgl.add_self_loop(g)` will resolve " 218 | "the issue. Setting ``allow_zero_in_degree`` " 219 | "to be `True` when constructing this module will " 220 | "suppress the check and let the code run." 
221 | ) 222 | 223 | if isinstance(feat, tuple): 224 | src_prefix_shape = feat[0].shape[:-1] 225 | dst_prefix_shape = feat[1].shape[:-1] 226 | h_src = self.feat_drop(feat[0]) 227 | h_dst = self.feat_drop(feat[1]) 228 | if not hasattr(self, "fc_src"): 229 | feat_src = self.fc(h_src).view( 230 | *src_prefix_shape, self._num_heads, self._out_feats 231 | ) 232 | feat_dst = self.fc(h_dst).view( 233 | *dst_prefix_shape, self._num_heads, self._out_feats 234 | ) 235 | else: 236 | feat_src = self.fc_src(h_src).view( 237 | *src_prefix_shape, self._num_heads, self._out_feats 238 | ) 239 | feat_dst = self.fc_dst(h_dst).view( 240 | *dst_prefix_shape, self._num_heads, self._out_feats 241 | ) 242 | else: 243 | src_prefix_shape = dst_prefix_shape = feat.shape[:-1] 244 | h_src = h_dst = self.feat_drop(feat) 245 | feat_src = feat_dst = self.fc(h_src).view( 246 | *src_prefix_shape, self._num_heads, self._out_feats 247 | ) 248 | if graph.is_block: 249 | feat_dst = feat_src[: graph.number_of_dst_nodes()] 250 | h_dst = h_dst[: graph.number_of_dst_nodes()] 251 | dst_prefix_shape = ( 252 | graph.number_of_dst_nodes(), 253 | ) + dst_prefix_shape[1:] 254 | # NOTE: GAT paper uses "first concatenation then linear projection" 255 | # to compute attention scores, while ours is "first projection then 256 | # addition", the two approaches are mathematically equivalent: 257 | # We decompose the weight vector a mentioned in the paper into 258 | # [a_l || a_r], then 259 | # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j 260 | # Our implementation is much efficient because we do not need to 261 | # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, 262 | # addition could be optimized with DGL's built-in function u_add_v, 263 | # which further speeds up computation and saves memory footprint. 
264 | el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1) 265 | er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1) 266 | graph.srcdata.update({"ft": feat_src, "el": el}) 267 | graph.dstdata.update({"er": er}) 268 | # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. 269 | graph.apply_edges(fn.u_add_v("el", "er", "e")) 270 | e = self.leaky_relu(graph.edata.pop("e")) 271 | # compute softmax 272 | graph.edata["a"] = self.attn_drop(edge_softmax(graph, e)) 273 | # message passing 274 | graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft")) 275 | rst = graph.dstdata["ft"] 276 | 277 | # bias 278 | if self.bias is not None: 279 | rst = rst + self.bias.view( 280 | *((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats 281 | ) 282 | 283 | # residual 284 | if self.res_fc is not None: 285 | # Use -1 rather than self._num_heads to handle broadcasting 286 | resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self._out_feats) 287 | rst = rst + resval 288 | 289 | if self._concat_out: 290 | rst = rst.flatten(1) 291 | else: 292 | rst = torch.mean(rst, dim=1) 293 | 294 | if self.norm is not None: 295 | rst = self.norm(rst) 296 | 297 | # activation 298 | if self.activation: 299 | rst = self.activation(rst) 300 | 301 | if get_attention: 302 | return rst, graph.edata["a"] 303 | else: 304 | return rst 305 | -------------------------------------------------------------------------------- /benchmarks/src/ckg_benchmarks/graphmae/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import argparse 8 | import logging 9 | import math 10 | import os 11 | from random import randrange 12 | from typing import Union 13 | 14 | import dgl 15 | import numpy as np 16 | import torch 17 | 18 | from ckg_benchmarks.base import 
BaseTrainer 19 | from ckg_benchmarks.graphmae.model import GraphMAE 20 | from ckg_benchmarks.utils import ranged_type 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class GraphMAETrainer(BaseTrainer): 26 | training_method_name = "graphmae" 27 | 28 | def __init__( 29 | self, 30 | embedding_dim: int = 64, 31 | n_layer: int = 2, 32 | dropout_rate: float = 0.1, 33 | n_heads: int = 8, 34 | mask_rate: float = 0.5, 35 | drop_edge_rate: float = 0.5, 36 | learning_rate: float = 0.001, 37 | epochs: int = 500, 38 | n_lives: int = 100, 39 | disable_metis: bool = False, 40 | **kwargs 41 | ): 42 | self.disable_metis = disable_metis 43 | self.embedding_dim = embedding_dim 44 | self.n_layer = n_layer 45 | self.dropout_rate = dropout_rate 46 | self.n_heads = n_heads 47 | self.mask_rate = mask_rate 48 | self.drop_edge_rate = drop_edge_rate 49 | self.learning_rate = learning_rate 50 | self.epochs = epochs 51 | self.n_lives = n_lives 52 | 53 | super().__init__(**kwargs) 54 | 55 | def build_graph(self): 56 | """ 57 | Build a DGL graph from the loaded CompanyKG graph 58 | and prepare it for training. 
59 | 60 | """ 61 | graph = self.comkg.get_dgl_graph(self.work_folder)[0] 62 | graph = dgl.add_reverse_edges(graph) 63 | graph = graph.remove_self_loop() 64 | graph = graph.add_self_loop() 65 | return graph 66 | 67 | @property 68 | def hparams_str(self): 69 | return "{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}_{8}_{9}_{10}".format( 70 | self.comkg.nodes_feature_type, 71 | self.epochs, 72 | self.n_layer, 73 | self.dropout_rate, 74 | self.learning_rate, 75 | self.embedding_dim, 76 | self.drop_edge_rate, 77 | self.mask_rate, 78 | self.n_heads, 79 | self.n_lives, 80 | self.seed, 81 | ) 82 | 83 | def init_model(self, load=None): 84 | # Initialize model and training 85 | if load is None: 86 | # Initialize a new model 87 | model = GraphMAE( 88 | in_dim=self.comkg.nodes_feature_dim, 89 | num_hidden=self.embedding_dim, 90 | num_layers=self.n_layer, 91 | feat_drop=self.dropout_rate + 0.1, 92 | attn_drop=self.dropout_rate, 93 | nhead=self.n_heads, 94 | mask_rate=self.mask_rate, 95 | drop_edge_rate=self.drop_edge_rate, 96 | ) 97 | else: 98 | # Load a pre-trained model 99 | model = torch.load(load) 100 | 101 | optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate) 102 | lr_schedule = lambda epoch: (1 + np.cos(epoch * np.pi / self.epochs)) * 0.5 103 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule) 104 | return model, optimizer, scheduler 105 | 106 | def inference_gpu(self) -> torch.Tensor: 107 | g = self.graph.to(self.device) 108 | self.model.eval() 109 | embed = self.model.embed(g, g.ndata["feat"]).cpu().detach().numpy() 110 | self.model.train() 111 | return embed / np.linalg.norm(embed, axis=1, keepdims=True) 112 | 113 | def inference_cpu(self) -> torch.Tensor: 114 | tmp_model_path = os.path.join(self.work_folder, "tmp_model.pth") 115 | torch.save(self.model, tmp_model_path) 116 | tmp_model = torch.load(tmp_model_path, map_location="cpu") 117 | tmp_model.eval() 118 | g = self.graph.to("cpu") 119 | embed = tmp_model.embed(g, 
g.ndata["feat"]).cpu().detach().numpy() 120 | return embed / np.linalg.norm(embed, axis=1, keepdims=True) 121 | 122 | def inference(self) -> torch.Tensor: 123 | try: 124 | return self.inference_gpu() 125 | except: 126 | logger.warning( 127 | f"GPU inference failed, fall back to CPU inference. Please be patient ..." 128 | ) 129 | return self.inference_cpu() 130 | 131 | def training_loss(self, subgraph): 132 | loss, _ = self.model(subgraph, subgraph.ndata["feat"]) 133 | return loss 134 | 135 | def train_epoch( 136 | self, 137 | epoch: int, 138 | train_dataloader: Union[dgl.dataloading.dataloader.DataLoader, list], 139 | metis: int = 1, 140 | ): 141 | self.model.train() 142 | n_step = len(train_dataloader) 143 | for step, subgraph in enumerate(train_dataloader): 144 | subgraph = subgraph.to(self.device) 145 | loss = self.training_loss(subgraph) 146 | self.optimizer.zero_grad() 147 | loss.backward() 148 | self.optimizer.step() 149 | self.scheduler.step() 150 | if metis <= 1: 151 | logger.info(f"epoch {epoch}: loss={loss.item()}") 152 | else: 153 | logger.info( 154 | f"epoch {epoch} | metis {metis} | step {step}/{n_step}: loss={loss.item()}" 155 | ) 156 | 157 | def _train_model(self): 158 | # For high-dim nodes feature, we need to create Metis partitions 159 | if self.comkg.nodes_feature_type != "pause" and not self.disable_metis: 160 | logger.info("Using Metis partitions") 161 | n_metis = [5, 10, 15, 20, 30, 40, 50, 100, 300, 400, 600, 800, 1000] 162 | train_dataloaders = [] 163 | 164 | for p in n_metis: 165 | sampler = dgl.dataloading.ClusterGCNSampler( 166 | self.graph, 167 | p, 168 | cache_path=os.path.join( 169 | self.work_folder, f"{self.comkg.nodes_feature_type}_metis_{p}.pkl" 170 | ), 171 | ) 172 | train_dataloader = dgl.dataloading.DataLoader( 173 | graph=self.graph, 174 | indices=torch.arange(p), 175 | graph_sampler=sampler, 176 | batch_size=math.ceil(p / 10), 177 | shuffle=True, 178 | drop_last=False, 179 | num_workers=2, 180 | ) 181 | 
train_dataloaders.append(train_dataloader) 182 | use_metis = True 183 | else: 184 | use_metis = False 185 | 186 | # Training Procedure 187 | best_sr_acc = 0 188 | n_lives = self.n_lives 189 | model_save_path = os.path.join(self.work_folder, f"{self.hparams_str}.pth") 190 | 191 | for epoch in range(self.epochs): 192 | logger.info(f"Starting epoch {epoch+1}/{self.epochs}") 193 | eval_gap = 1 194 | if use_metis: 195 | train_dataloader_idx = randrange(len(n_metis)) 196 | current_n_metis = n_metis[train_dataloader_idx] 197 | train_dataloader = train_dataloaders[train_dataloader_idx] 198 | with train_dataloader.enable_cpu_affinity(): 199 | self.train_epoch( 200 | epoch=epoch, 201 | train_dataloader=train_dataloader, 202 | metis=current_n_metis, 203 | ) 204 | else: 205 | eval_gap = 3 206 | self.train_epoch( 207 | epoch=epoch, 208 | train_dataloader=[self.graph], 209 | ) 210 | 211 | # Training SR and SP evaluation 212 | if epoch % eval_gap == 0: 213 | # Inference 214 | embed = self.inference() 215 | 216 | # SR Eval on the validation set 217 | sr_acc = self.comkg.evaluate_sr(embed=embed, split="validation") 218 | logger.info(f"SR Accuracy: {sr_acc}, Lives left: {n_lives}") 219 | 220 | # SR task is prioritized here due to it being more challenging 221 | # and having a defined validation set 222 | if sr_acc > best_sr_acc: 223 | best_sr_acc = sr_acc 224 | n_lives = self.n_lives 225 | 226 | # Save the best-so-far trained model 227 | torch.save(self.model, model_save_path) 228 | logger.info(f"Model saved to {model_save_path}") 229 | else: 230 | n_lives -= 1 231 | if n_lives < 0: 232 | break 233 | 234 | # Load the best model 235 | self.model = torch.load(model_save_path) 236 | logger.info(f"Best model loaded from {model_save_path}") 237 | 238 | 239 | train_model = GraphMAETrainer.train 240 | 241 | 242 | if __name__ == "__main__": 243 | parser = argparse.ArgumentParser() 244 | parser.add_argument( 245 | "--epochs", 246 | default=500, 247 | type=ranged_type(int, 1, 2000), 248 | 
help="The max number of training epochs", 249 | ) 250 | parser.add_argument( 251 | "--nodes-feature-type", 252 | type=str, 253 | default="msbert", 254 | choices=["msbert", "ada2", "simcse", "pause"], 255 | help="The type of nodes feature: msbert, ada2, simcse or pause", 256 | ) 257 | parser.add_argument( 258 | "--device", 259 | default=0, 260 | type=int, 261 | help="The device used to carry out the training: use cpu when less than 0", 262 | ) 263 | parser.add_argument( 264 | "--data-root-folder", 265 | default="./data", 266 | type=str, 267 | help="The root folder where the CompanyKG data is downloaded to", 268 | ) 269 | parser.add_argument( 270 | "--finetune-from", 271 | type=str, 272 | help="The saved model to be finetuned from, ex. ./msbert_1000_2_0.01_0.001_256_0.1_0.1_8_100_42.pth", 273 | ) 274 | parser.add_argument( 275 | "--work-folder", 276 | default="./experiments", 277 | type=str, 278 | help="The working folder where models and logs are saved to", 279 | ) 280 | parser.add_argument( 281 | "--n-layer", 282 | default=2, 283 | type=ranged_type(int, 2, 4), 284 | help="The number of GNN layers", 285 | ) 286 | parser.add_argument( 287 | "--dropout-rate", 288 | default=0.1, 289 | type=ranged_type(float, 0.0, 0.3), 290 | help="The feature dropout rate of GNN layers", 291 | ) 292 | parser.add_argument( 293 | "--learning-rate", 294 | default=0.001, 295 | type=ranged_type(float, 0.00001, 0.01), 296 | help="The training learning rate", 297 | ) 298 | parser.add_argument( 299 | "--embedding-dim", 300 | default=64, 301 | type=ranged_type(int, 8, 512), 302 | help="The dimension of the embedding to be learned", 303 | ) 304 | parser.add_argument( 305 | "--drop-edge-rate", 306 | default=0.5, 307 | type=ranged_type(float, 0.1, 0.9), 308 | help="The rate of edges to be dropped during training", 309 | ) 310 | parser.add_argument( 311 | "--mask-rate", 312 | default=0.5, 313 | type=ranged_type(float, 0.1, 0.9), 314 | help="The rate of nodes feature to be masked during training", 
315 | ) 316 | parser.add_argument( 317 | "--n-heads", 318 | default=8, 319 | type=ranged_type(int, 1, 8), 320 | help="The number of attention heads", 321 | ) 322 | parser.add_argument( 323 | "--n-lives", 324 | default=100, 325 | type=ranged_type(int, 10, 200), 326 | help="The number of training epochs allowed when evaluation metrics are not improving", 327 | ) 328 | parser.add_argument( 329 | "--seed", 330 | default=42, 331 | type=int, 332 | help="The seed used to run the experiment", 333 | ) 334 | parser.add_argument( 335 | "--disable-metis", 336 | action="store_true", 337 | help="Force trainer not to use Metis partitions, even when the embedding size " 338 | "is large" 339 | ) 340 | opts = parser.parse_args() 341 | 342 | root_logger = logging.getLogger() 343 | root_logger.setLevel(logging.INFO) 344 | log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s") 345 | console_handler = logging.StreamHandler() 346 | console_handler.setFormatter(log_formatter) 347 | root_logger.addHandler(console_handler) 348 | 349 | # Create trainer and run 350 | logger.info("Initializing trainer") 351 | trainer = GraphMAETrainer( 352 | nodes_feature_type=opts.nodes_feature_type, 353 | data_root_folder=opts.data_root_folder, 354 | embedding_dim=opts.embedding_dim, 355 | n_layer=opts.n_layer, 356 | dropout_rate=opts.dropout_rate, 357 | n_heads=opts.n_heads, 358 | mask_rate=opts.mask_rate, 359 | drop_edge_rate=opts.drop_edge_rate, 360 | learning_rate=opts.learning_rate, 361 | epochs=opts.epochs, 362 | n_lives=opts.n_lives, 363 | seed=opts.seed, 364 | device=opts.device, 365 | finetune_from=opts.finetune_from, 366 | work_folder=opts.work_folder, 367 | disable_metis=opts.disable_metis, 368 | ) 369 | 370 | logger.info("Starting model training") 371 | trainer.train_model() 372 | logger.info("Training complete") 373 | -------------------------------------------------------------------------------- /src/companykg/kg.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) eqtgroup.com Ltd 2023 3 | https://github.com/EQTPartners/CompanyKG 4 | License: MIT, https://github.com/EQTPartners/CompanyKG/LICENSE.md 5 | """ 6 | 7 | import logging 8 | import os 9 | from typing import Tuple, List 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from scipy import spatial 15 | from sklearn.metrics import roc_auc_score, accuracy_score 16 | 17 | from companykg.settings import ( 18 | EDGES_FILENAME, 19 | ZENODO_DATASET_BASE_URI, 20 | EDGES_WEIGHTS_FILENAME, 21 | NODES_FEATURES_FILENAME_TEMPLATE, 22 | EVAL_TASK_FILENAME_TEMPLATE, 23 | ) 24 | from companykg.utils import download_zenodo_file 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class CompanyKG: 30 | """The CompanyKG class that provides utility functions 31 | to load data and carry out evaluations. 32 | """ 33 | 34 | def __init__( 35 | self, 36 | nodes_feature_type: str = "msbert", 37 | load_edges_weights: bool = False, 38 | data_root_folder: str = "./data", 39 | ) -> None: 40 | """Initialize a CompanyKG object. 41 | 42 | Args: 43 | nodes_feature_type (str, optional): the desired note feature type. 44 | Viable values include "msbert", "pause", "simcse", "ada2". Defaults to "msbert". 45 | load_edges_weights (bool, optional): load edge weights or not. Defaults to False. 46 | data_root_folder (str, optional): root folder of downloaded data. Defaults to "./data". 47 | If the folder does not exist, the latest version of the dataset will be downloaded from 48 | Zenodo. 
49 | """ 50 | 51 | self.data_root_folder = data_root_folder 52 | 53 | # Load nodes feature: only load one type 54 | self.nodes_feature_type = nodes_feature_type 55 | 56 | # Create a local data directory - NOP if directory already exists 57 | os.makedirs(data_root_folder, exist_ok=True) 58 | 59 | # Load edges 60 | # First check if edges file exists - download if it doesn't 61 | self.edges_file = os.path.join(data_root_folder, EDGES_FILENAME) 62 | if not os.path.exists(self.edges_file): 63 | download_zenodo_file( 64 | os.path.join(ZENODO_DATASET_BASE_URI, EDGES_FILENAME), 65 | self.edges_file, 66 | ) 67 | self.edges = torch.load(self.edges_file) 68 | logger.info(f"[DONE] Loaded {self.edges_file}") 69 | 70 | # Load edge weights [Optional] 71 | # First check if edge weights file exists - download if it doesn't 72 | self.load_edges_weights = load_edges_weights 73 | if load_edges_weights: 74 | self.edges_weight_file = os.path.join( 75 | data_root_folder, EDGES_WEIGHTS_FILENAME 76 | ) 77 | if not os.path.exists(self.edges_weight_file): 78 | download_zenodo_file( 79 | os.path.join(ZENODO_DATASET_BASE_URI, EDGES_WEIGHTS_FILENAME), 80 | self.edges_weight_file, 81 | ) 82 | self.edges_weight = torch.load(self.edges_weight_file).to_dense() 83 | logger.info(f"[DONE] Loaded {self.edges_weight_file}") 84 | 85 | # Load nodes feaures file 86 | # Check for nodes features file - download if it doesn't exist 87 | _nodes_feature_filename = NODES_FEATURES_FILENAME_TEMPLATE.replace( 88 | "", 89 | nodes_feature_type, 90 | ) 91 | self.nodes_feature_file = os.path.join( 92 | data_root_folder, _nodes_feature_filename 93 | ) 94 | if not os.path.exists(self.nodes_feature_file): 95 | download_zenodo_file( 96 | os.path.join(ZENODO_DATASET_BASE_URI, _nodes_feature_filename), 97 | self.nodes_feature_file, 98 | ) 99 | self._load_node_features() 100 | logger.info(f"[DONE] Loaded {self.nodes_feature_file}") 101 | 102 | # Load evaluation test data 103 | self.eval_task_types = ("sp", "sr", "cr") 104 | 
self.eval_tasks = dict() 105 | for task_type in self.eval_task_types: 106 | # Check if evaluation test data exists - otherwise download it 107 | _eval_task_filename = EVAL_TASK_FILENAME_TEMPLATE.replace( 108 | "", task_type 109 | ) 110 | _eval_task_file = os.path.join(data_root_folder, _eval_task_filename) 111 | if not os.path.exists(_eval_task_file): 112 | download_zenodo_file( 113 | os.path.join(ZENODO_DATASET_BASE_URI, _eval_task_filename), 114 | _eval_task_file, 115 | ) 116 | self.eval_tasks[task_type] = ( 117 | _eval_task_file, 118 | pd.read_parquet(_eval_task_file), 119 | ) 120 | logger.info(f"[DONE] Loaded {_eval_task_file}") 121 | 122 | self.n_edges = self.edges.shape[0] 123 | if self.load_edges_weights: 124 | self.edges_weight_dim = self.edges_weight.shape[1] 125 | 126 | # Default Top-K for CR task 127 | self.eval_cr_top_ks = [50, 100, 200, 500, 1000, 2000, 5000, 10000] 128 | 129 | def _load_node_features(self): 130 | self.nodes_feature = torch.load(self.nodes_feature_file) 131 | if self.nodes_feature.dtype is not torch.float32: 132 | self.nodes_feature = self.nodes_feature.to(dtype=torch.float32) 133 | # Set Vars 134 | self.n_nodes = self.nodes_feature.shape[0] 135 | self.nodes_feature_dim = self.nodes_feature.shape[1] 136 | 137 | def change_feature_type(self, feature_type: str): 138 | if feature_type != self.nodes_feature_type: 139 | self.nodes_feature_type = feature_type 140 | self._load_node_features() 141 | 142 | @property 143 | def nodes_id(self) -> list: 144 | """Get an ordered list of node IDs. 145 | 146 | Returns: 147 | list: an ordered (ascending) list of node IDs. 
148 | """ 149 | return [i for i in range(self.n_nodes)] 150 | 151 | def describe(self) -> None: 152 | """Print key statistics of loaded data.""" 153 | print(f"data_root_folder={self.data_root_folder}") 154 | print(f"n_nodes={self.n_nodes}, n_edges={self.n_edges}") 155 | print(f"nodes_feature_type={self.nodes_feature_type}") 156 | print(f"nodes_feature_dimension={self.nodes_feature_dim}") 157 | if self.load_edges_weights: 158 | print(f"edges_weight_dimension={self.edges_weight_dim}") 159 | for task_type in self.eval_task_types: 160 | print(f"{task_type}: {len(self.eval_tasks[task_type][1])} samples") 161 | 162 | def to_pyg(self): 163 | """ 164 | Build a PyTorch-geometric graph from the loaded CompanyKG. 165 | 166 | """ 167 | try: 168 | from torch_geometric.data import Data 169 | except ImportError as e: 170 | raise ImportError( 171 | "pytorch-geometric is not installed: please install to produce PyG graph" 172 | ) from e 173 | 174 | # Incxlude edges going in both directions, since PyG uses directed graphs 175 | edge_index = torch.concat([self.edges.T, self.edges[:, [1, 0]].T], dim=1) 176 | return Data(x=self.nodes_feature, edge_index=edge_index) 177 | 178 | def to_igraph(self): 179 | """ 180 | Build an iGraph graph from the loaded CompanyKG. 181 | Requires iGraph to be installed. 
182 | 183 | """ 184 | try: 185 | import igraph as ig 186 | except ImportError as e: 187 | raise ImportError( 188 | "python-igraph is not installed: please install to produce iGraph graph" 189 | ) from e 190 | 191 | g = ig.Graph() 192 | g.add_vertices(self.n_nodes) 193 | # Names should be strings 194 | g.vs["name"] = [str(i) for i in self.nodes_id] 195 | 196 | logger.info("Building iGraph graph from edges") 197 | if self.load_edges_weights: 198 | # Convert tensors to Np arrays 199 | edge_weights = self.edges_weight.numpy() 200 | edges = self.edges.numpy() 201 | 202 | # Flatten the non-zero weights for each edge so we have a separate edge for each weight type 203 | nonzeros = np.nonzero(edge_weights) 204 | # These are just the column indices of the nonzeros 205 | types = nonzeros[1] 206 | # The weights for these separated edges are the flattened non-zero values 207 | weights = edge_weights[nonzeros] 208 | # The edges themselves are indexed by the row indices of the non-zero values 209 | # This repeats edges where there are multiple non-zero weight types 210 | edges = edges[nonzeros[0]] 211 | 212 | # Flatten the non-zero weights for each edge so we have a separate edge for each weight type 213 | attrs = { 214 | "type": types, 215 | "weight": weights, 216 | } 217 | g.add_edges( 218 | edges, 219 | attributes=attrs, 220 | ) 221 | else: 222 | g.add_edges((i, j) for (i, j) in self.edges) 223 | return g 224 | 225 | def evaluate_sp(self, embed: torch.Tensor) -> float: 226 | """Evaluate the specified node embeddings on SP task. 227 | 228 | Args: 229 | embed (torch.Tensor): the node embeddings to be evaluated. 230 | 231 | Returns: 232 | float: AUC score on SP task. 
233 | """ 234 | test_data = self.eval_tasks["sp"][1] 235 | gt = test_data["label"].tolist() 236 | pred = [] 237 | for _, row in test_data.iterrows(): 238 | node_embeds = (embed[row["node_id0"]], embed[row["node_id1"]]) 239 | try: 240 | with np.errstate(invalid="ignore"): 241 | pred.append( 242 | 1 - 0.5 * spatial.distance.cosine(node_embeds[0], node_embeds[1]) 243 | ) 244 | except: 245 | print(row) 246 | raise 247 | return roc_auc_score(gt, pred) 248 | 249 | def evaluate_sr(self, embed: torch.Tensor, split: str = "validation") -> float: 250 | """Evaluate the specified node embeddings on SR task. 251 | 252 | Args: 253 | embed (torch.Tensor): the node embeddings to be evaluated. 254 | split (str): the split (validation/test) on which the evaluation will be run. 255 | 256 | Returns: 257 | float: Accuracy on SR task. 258 | """ 259 | test_data = self.eval_tasks["sr"][1] 260 | test_data = test_data[test_data["split"] == split] 261 | gt = test_data["label"].tolist() 262 | pred = [] 263 | for _, row in test_data.iterrows(): 264 | query_embed = embed[row["target_node_id"]] 265 | candidate0_embed = embed[row["candidate0_node_id"]] 266 | candidate1_embed = embed[row["candidate1_node_id"]] 267 | with np.errstate(invalid="ignore"): 268 | _p1 = 1 - 0.5 * spatial.distance.cosine(query_embed, candidate0_embed) 269 | _p2 = 1 - 0.5 * spatial.distance.cosine(query_embed, candidate1_embed) 270 | pred.append(0) if _p1 >= _p2 else pred.append(1) 271 | return accuracy_score(gt, pred) 272 | 273 | @staticmethod 274 | def search_most_similar( 275 | target_embed: torch.Tensor, embed: torch.Tensor 276 | ) -> Tuple[np.array, np.array]: 277 | """Search top-K most similar nodes to a target node. 278 | 279 | Args: 280 | target_embed (torch.Tensor): the embedding of the target node. 281 | embed (torch.Tensor): the node embeddings to be searched from, i.e. candidate nodes. 282 | K (int, optional): the number of nodes to be returned as search result. Defaults to 50. 
283 | 284 | Returns: 285 | Tuple[np.array, np.array]: the node IDs and the cosine similarity scores. 286 | """ 287 | with np.errstate(invalid="ignore"): 288 | sims = np.dot(embed, target_embed) / ( 289 | np.linalg.norm(embed, axis=1) * np.linalg.norm(target_embed) 290 | ) 291 | # Reverse so the most similar is first 292 | max_ids = np.argsort(sims)[ 293 | -2::-1 294 | ] # remove target company (first element in the reversed array) 295 | return max_ids, sims[max_ids] 296 | 297 | def cr_top_ks(self, embed: torch.Tensor, ks: List[int]): 298 | """ 299 | Evaluate CR (Competitor retrieval) as the average recall @ k for a number of 300 | different values of k. 301 | 302 | :param embed: the node embeddings to be evaluated 303 | :param ks: list of k-values to evaluate at 304 | :return: 305 | """ 306 | test_data = self.eval_tasks["cr"][1] 307 | target_nids = list(test_data["target_node_id"].unique()) 308 | target_nids.sort() 309 | # Collect recalls for different ks from each sample 310 | k_recalls = dict((k, []) for k in ks) 311 | 312 | for nid in target_nids: 313 | competitor_df = test_data[test_data["target_node_id"] == nid] 314 | competitor_nids = set(list(competitor_df["competitor_node_id"].unique())) 315 | # Get as many predictions as we'll need for the highest k 316 | res_nids, res_dists = CompanyKG.search_most_similar(embed[nid], embed) 317 | for k in sorted(ks): 318 | k_res_nids = set(res_nids[:k]) 319 | common_set = k_res_nids & competitor_nids 320 | recall = len(common_set) / len(competitor_nids) 321 | k_recalls[k].append(recall) 322 | 323 | # Average the recalls over samples for each k 324 | recalls = [np.mean(k_recalls[k]) for k in sorted(ks)] 325 | return recalls 326 | 327 | def cr_top_k(self, embed: torch.Tensor, k: int = 50) -> float: 328 | """Evaluate CR (Competitor Retrieval) performance using top-K hit rate. 329 | This function will evaluate each target company in CR test set. 
330 | 331 | Args: 332 | embed (torch.Tensor): the node embeddings to be evaluated 333 | k (int, optional): the number of nodes to be returned as search result. Defaults to 50. 334 | 335 | Returns: 336 | Tuple[float, list]: the overall hit rate and the per-target hit rate. 337 | """ 338 | return self.cr_top_ks(embed, [k])[0] 339 | 340 | def evaluate_cr(self, embed: torch.Tensor) -> list: 341 | """Evaluate the specified node embeddings on CR task. 342 | 343 | Args: 344 | embed (torch.Tensor): the node embeddings to be evaluated. 345 | 346 | Returns: 347 | float: the list of tuples containing the CR results. 348 | The first element in each tuple is the overall hit rate for top-K. 349 | """ 350 | return self.cr_top_ks(embed, self.eval_cr_top_ks) 351 | 352 | def evaluate( 353 | self, 354 | embeddings_file: str = None, 355 | embed: torch.Tensor = None, 356 | silent: bool = False, 357 | ) -> dict: 358 | """Evaluate the specified embedding on all evaluation tasks: SP, SR and CR. 359 | When none parameters provided, it will evaluate the embodied nodes feature. 360 | 361 | Args: 362 | embeddings_file (str, optional): the path to the embedding file; 363 | it has highest priority. Defaults to None. 364 | embed (torch.Tensor, optional): the embedding to be evaluated; 365 | it has second highest priority. Defaults to None. 366 | silent (bool): by default, evaluation results are printed to stdout; 367 | if True, nothing is output, you just get the results in the 368 | returned dict 369 | 370 | Returns: 371 | dict: a dictionary of evaluation results. 
372 | """ 373 | if embeddings_file is not None: 374 | try: 375 | embed = torch.load(embeddings_file) 376 | except: 377 | embed = torch.load(embeddings_file, map_location="cpu") 378 | result_dict = {"source": embeddings_file} 379 | if not silent: 380 | print(f"Evaluate Node Embeddings {embeddings_file}:") 381 | elif embed is not None: 382 | result_dict = {"source": f"embed {embed.shape}"} 383 | if not silent: 384 | print(f"Evaluate Custom Embeddings:") 385 | else: 386 | embed = self.nodes_feature 387 | result_dict = {"source": self.nodes_feature_type} 388 | if not silent: 389 | print(f"Evaluate Node Features {self.nodes_feature_type}:") 390 | # SP Task 391 | if not silent: 392 | print("Evaluate SP ...") 393 | result_dict["sp_auc"] = self.evaluate_sp(embed) 394 | if not silent: 395 | print("SP AUC:", result_dict["sp_auc"]) 396 | # SR Task 397 | if not silent: 398 | print("Evaluate SR ...") 399 | result_dict["sr_validation_acc"] = self.evaluate_sr(embed) 400 | result_dict["sr_test_acc"] = self.evaluate_sr(embed, split="test") 401 | if not silent: 402 | print( 403 | "SR Validation ACC:", 404 | result_dict["sr_validation_acc"], 405 | "SR Test ACC:", 406 | result_dict["sr_test_acc"], 407 | ) 408 | 409 | # CR Task 410 | if not silent: 411 | print(f"Evaluate CR with top-K hit rate (K={self.eval_cr_top_ks}) ...") 412 | result_dict["cr_topk_hit_rate"] = self.evaluate_cr(embed) 413 | if not silent: 414 | print("CR Hit Rates:", result_dict["cr_topk_hit_rate"]) 415 | 416 | return result_dict 417 | 418 | def get_dgl_graph(self, work_folder: str) -> list: 419 | """Obtain a DGL graph. If it has not been built before, a new graph will be constructed, 420 | otherwise it will simply load from file in the specified working directory. 421 | 422 | Args: 423 | work_folder (str): the working directory of graph building. 424 | 425 | Returns: 426 | list: the built graph(s). 
427 | """ 428 | try: 429 | import dgl 430 | except ImportError as e: 431 | raise ImportError( 432 | "DGL is not installed. Please install to produce DGL graph" 433 | ) from e 434 | 435 | dgl_file = os.path.join(work_folder, f"dgl_{self.nodes_feature_type}.bin") 436 | if os.path.isfile(dgl_file): 437 | return dgl.data.utils.load_graphs(dgl_file)[0] 438 | else: 439 | graph_data = { 440 | ("_N", "_E", "_N"): self.edges.tolist(), 441 | } 442 | g = dgl.heterograph(graph_data) 443 | g.ndata["feat"] = self.nodes_feature 444 | if self.load_edges_weights: 445 | g.edata["weight"] = self.edges_weight 446 | dgl.data.utils.save_graphs(dgl_file, [g]) 447 | return [g] 448 | -------------------------------------------------------------------------------- /tutorials/gcl_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "4bc68038-aef2-4c70-99bc-9bab76607231", 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "source": [ 10 | "# Tutorial: training a GCL benchmark model\n", 11 | "\n", 12 | "This tutorial uses the `ckg_benchmarks` package to train a GCL model and evaluate it on CompanyKG.\n", 13 | "\n", 14 | "The trained model is not likely to produce good results, since the hyperparameters are set for minimal computation (e.g. only one training epoch). But this code can serve as an example for training better models and a test of the GCL model training code.\n", 15 | "\n", 16 | "We demonstrate how to train with GRACE and MVGRL here. (The only difference is the `method` argument to `train_model`.)\n", 17 | "\n", 18 | "You can apply an almost identical training procedure to other GNN training methods by using their `train_model` functions and adjusting the parameters." 
19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "id": "34c3bc77-9eb8-4d05-84dd-f6c597c20760", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "%load_ext autoreload\n", 29 | "%autoreload 2" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "b42e0eec-ec4f-4828-9cda-1a68aea6bcac", 35 | "metadata": {}, 36 | "source": [ 37 | "We initialize logging so that we see model training progress." 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "id": "40d9db24-ef37-4e6c-a96f-7340e80c330e", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "import logging\n", 48 | "logger = logging.getLogger()\n", 49 | "handler = logging.StreamHandler()\n", 50 | "handler.setFormatter(\n", 51 | " logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s')\n", 52 | ")\n", 53 | "logger.addHandler(handler)\n", 54 | "logger.setLevel(logging.INFO)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "id": "9764cf6b-2572-420d-822e-2b8488f8e7da", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "from companykg import CompanyKG\n", 65 | "from ckg_benchmarks.gcl.train import train_model" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "id": "5f74dcb9-aa91-403b-a671-38a98677c26b", 71 | "metadata": {}, 72 | "source": [ 73 | "## Load data\n", 74 | "\n", 75 | "Prepare the CompanyKG dataset. \n", 76 | "\n", 77 | "The first time you run this, the data will be downloaded from Zenodo to the `data` subdirectory, which could take some time. After that, it will be quicker to load.\n", 78 | "\n", 79 | "The dataset is then loaded into memory using the mSBERT node feature type.\n", 80 | "\n", 81 | "This step is not strictly necessary, as we don't use the loaded data for training: the training routine takes care of loading it itself. But loading it here causes it to be downloaded if necessary." 
82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "id": "aeeabe85-74ea-47ea-8617-a14495938625", 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stderr", 92 | "output_type": "stream", 93 | "text": [ 94 | "2023-06-14 12:39:58,723 companykg.kg INFO [DONE] Loaded ./data/edges.pt\n", 95 | "2023-06-14 12:40:10,900 companykg.kg INFO [DONE] Loaded ./data/edges_weight.pt\n", 96 | "2023-06-14 12:40:31,937 companykg.kg INFO [DONE] Loaded ./data/nodes_feature_msbert.pt\n", 97 | "2023-06-14 12:40:33,453 companykg.kg INFO [DONE] Loaded ./data/eval_task_sp.parquet.gz\n", 98 | "2023-06-14 12:40:33,521 companykg.kg INFO [DONE] Loaded ./data/eval_task_sr.parquet.gz\n", 99 | "2023-06-14 12:40:33,526 companykg.kg INFO [DONE] Loaded ./data/eval_task_cr.parquet.gz\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "ckg = CompanyKG(\n", 105 | " nodes_feature_type=\"msbert\", \n", 106 | " load_edges_weights=True,\n", 107 | ")" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "id": "8f3e7947-598b-4570-94e8-7516e8bab756", 113 | "metadata": {}, 114 | "source": [ 115 | "## GRACE training" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "id": "a3fb668d-8510-42d9-934d-38988053dd08", 121 | "metadata": {}, 122 | "source": [ 123 | "### Train model\n", 124 | "Now we set a minimal GCL model training using GRACE.\n", 125 | "\n", 126 | "Training uses a GPU if it's available.\n", 127 | "\n", 128 | "To train a better model, adjust the parameters set here, in particular `epochs`.\n", 129 | "\n", 130 | "Calling this training method is equivalent to running the following command:\n", 131 | "```\n", 132 | "python -m ckg_benchmarks.gcl.train \\\n", 133 | " --device -1 \\\n", 134 | " --method grace \\\n", 135 | " --n-layer 1 \\\n", 136 | " --embedding-dim 8 \\\n", 137 | " --epochs 1 \\\n", 138 | " --sampler-edges 2 \\\n", 139 | " --batch-size 128\n", 140 | "```" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 7, 
146 | "id": "a28865e3-db96-43af-9964-d3c6ca615d73", 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "name": "stderr", 151 | "output_type": "stream", 152 | "text": [ 153 | "2023-06-14 12:41:41,929 ckg_benchmarks.base INFO Initializing model and trainer\n", 154 | "2023-06-14 12:41:42,184 companykg.kg INFO [DONE] Loaded ./data/edges.pt\n", 155 | "2023-06-14 12:41:42,912 companykg.kg INFO [DONE] Loaded ./data/nodes_feature_msbert.pt\n", 156 | "2023-06-14 12:41:42,917 companykg.kg INFO [DONE] Loaded ./data/eval_task_sp.parquet.gz\n", 157 | "2023-06-14 12:41:42,921 companykg.kg INFO [DONE] Loaded ./data/eval_task_sr.parquet.gz\n", 158 | "2023-06-14 12:41:42,925 companykg.kg INFO [DONE] Loaded ./data/eval_task_cr.parquet.gz\n" 159 | ] 160 | }, 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "data_root_folder=./data\n", 166 | "n_nodes=1169931, n_edges=50815503\n", 167 | "nodes_feature_type=msbert\n", 168 | "nodes_feature_dimension=512\n", 169 | "sp: 3219 samples\n", 170 | "sr: 1856 samples\n", 171 | "cr: 400 samples\n" 172 | ] 173 | }, 174 | { 175 | "name": "stderr", 176 | "output_type": "stream", 177 | "text": [ 178 | "2023-06-14 12:41:43,293 ckg_benchmarks.base INFO Data(x=[1169931, 512], edge_index=[2, 101631006])\n", 179 | "2023-06-14 12:41:43,296 ckg_benchmarks.base INFO Starting model training\n", 180 | "2023-06-14 12:42:02,491 ckg_benchmarks.base INFO Sending training logs to experiments/grace/msbert_1_1_8_2_128_42.log\n", 181 | "2023-06-14 12:42:02,492 ckg_benchmarks.base INFO Strating model training\n", 182 | "2023-06-14 12:42:04,991 ckg_benchmarks.gcl.train INFO Starting epoch 1\n" 183 | ] 184 | }, 185 | { 186 | "data": { 187 | "application/vnd.jupyter.widget-view+json": { 188 | "model_id": "42ea2ec415e24d46b65f004458da7a35", 189 | "version_major": 2, 190 | "version_minor": 0 191 | }, 192 | "text/plain": [ 193 | "Epoch 1/1: 0%| | 0/9141 [00:00\n", 184 | "\n", 197 | "\n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " 
\n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | "
node_id0node_id1label
01217691366310
11511073368520
23336013638220
324814195710
4372534371440
............
32143612228832351
32155361548832351
321633005810884871
321749459410884871
3218162410884871
\n", 275 | "

3219 rows × 3 columns

\n", 276 | "" 277 | ], 278 | "text/plain": [ 279 | " node_id0 node_id1 label\n", 280 | "0 121769 136631 0\n", 281 | "1 151107 336852 0\n", 282 | "2 333601 363822 0\n", 283 | "3 2481 419571 0\n", 284 | "4 37253 437144 0\n", 285 | "... ... ... ...\n", 286 | "3214 361222 883235 1\n", 287 | "3215 536154 883235 1\n", 288 | "3216 330058 1088487 1\n", 289 | "3217 494594 1088487 1\n", 290 | "3218 1624 1088487 1\n", 291 | "\n", 292 | "[3219 rows x 3 columns]" 293 | ] 294 | }, 295 | "execution_count": 8, 296 | "metadata": {}, 297 | "output_type": "execute_result" 298 | } 299 | ], 300 | "source": [ 301 | "# SP samples\n", 302 | "comkg.eval_tasks['sp'][1]" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 9, 308 | "metadata": {}, 309 | "outputs": [ 310 | { 311 | "data": { 312 | "text/html": [ 313 | "
\n", 314 | "\n", 327 | "\n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | "
target_node_idcandidate0_node_idcandidate1_node_idlabelsplit
02013891984357976520test
14507036184866243840test
210974152979783865840validation
38100024441010165340test
4861572115565811152080test
..................
18515222576690896079811test
185210836622030704824781validation
18533542765518878659950test
185483067250488210468821test
1855804707991173917441test
\n", 429 | "

1856 rows × 5 columns

\n", 430 | "
" 431 | ], 432 | "text/plain": [ 433 | " target_node_id candidate0_node_id candidate1_node_id label \\\n", 434 | "0 201389 198435 797652 0 \n", 435 | "1 450703 618486 624384 0 \n", 436 | "2 1097415 297978 386584 0 \n", 437 | "3 81000 244410 1016534 0 \n", 438 | "4 861572 1155658 1115208 0 \n", 439 | "... ... ... ... ... \n", 440 | "1851 522257 669089 607981 1 \n", 441 | "1852 1083662 203070 482478 1 \n", 442 | "1853 354276 551887 865995 0 \n", 443 | "1854 830672 504882 1046882 1 \n", 444 | "1855 804707 991173 91744 1 \n", 445 | "\n", 446 | " split \n", 447 | "0 test \n", 448 | "1 test \n", 449 | "2 validation \n", 450 | "3 test \n", 451 | "4 test \n", 452 | "... ... \n", 453 | "1851 test \n", 454 | "1852 validation \n", 455 | "1853 test \n", 456 | "1854 test \n", 457 | "1855 test \n", 458 | "\n", 459 | "[1856 rows x 5 columns]" 460 | ] 461 | }, 462 | "execution_count": 9, 463 | "metadata": {}, 464 | "output_type": "execute_result" 465 | } 466 | ], 467 | "source": [ 468 | "# SR samples\n", 469 | "comkg.eval_tasks['sr'][1]" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 10, 475 | "metadata": {}, 476 | "outputs": [ 477 | { 478 | "data": { 479 | "text/html": [ 480 | "
\n", 481 | "\n", 494 | "\n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | "
target_node_idcompetitor_node_id
0384334994
13843263332
238431034500
3498145823
44981288480
.........
39511446341004440
39611446341077443
3971163522172921
3981163522268689
39911635221149354
\n", 560 | "

400 rows × 2 columns

\n", 561 | "
" 562 | ], 563 | "text/plain": [ 564 | " target_node_id competitor_node_id\n", 565 | "0 3843 34994\n", 566 | "1 3843 263332\n", 567 | "2 3843 1034500\n", 568 | "3 4981 45823\n", 569 | "4 4981 288480\n", 570 | ".. ... ...\n", 571 | "395 1144634 1004440\n", 572 | "396 1144634 1077443\n", 573 | "397 1163522 172921\n", 574 | "398 1163522 268689\n", 575 | "399 1163522 1149354\n", 576 | "\n", 577 | "[400 rows x 2 columns]" 578 | ] 579 | }, 580 | "execution_count": 10, 581 | "metadata": {}, 582 | "output_type": "execute_result" 583 | } 584 | ], 585 | "source": [ 586 | "# CR samples\n", 587 | "comkg.eval_tasks['cr'][1]" 588 | ] 589 | }, 590 | { 591 | "attachments": {}, 592 | "cell_type": "markdown", 593 | "metadata": {}, 594 | "source": [ 595 | "## Evaluate Node Feature" 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "execution_count": 11, 601 | "metadata": {}, 602 | "outputs": [ 603 | { 604 | "name": "stdout", 605 | "output_type": "stream", 606 | "text": [ 607 | "Evaluate Node Features msbert:\n", 608 | "Evaluate SP ...\n", 609 | "SP AUC: 0.8059550091101482\n", 610 | "Evaluate SR ...\n", 611 | "SR Validation ACC: 0.6956521739130435 SR Test ACC: 0.6713709677419355\n", 612 | "Evaluate CR with top-K hit rate (K=[50, 100, 200, 500, 1000, 2000, 5000, 10000]) ...\n", 613 | "CR Hit Rates: [0.12955922001974632, 0.18240535049745576, 0.23030967570441258, 0.31102329687856, 0.4143004291030607, 0.47711466165413524, 0.5583993126756285, 0.6349049707602339]\n" 614 | ] 615 | } 616 | ], 617 | "source": [ 618 | "# Run all evaluation tasks on the loaded node feature\n", 619 | "eval_results = comkg.evaluate()" 620 | ] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "execution_count": 12, 625 | "metadata": {}, 626 | "outputs": [ 627 | { 628 | "data": { 629 | "text/plain": [ 630 | "0.8059550091101482" 631 | ] 632 | }, 633 | "execution_count": 12, 634 | "metadata": {}, 635 | "output_type": "execute_result" 636 | } 637 | ], 638 | "source": [ 639 | "# Show AUC score for SP task\n", 
640 | "eval_results[\"sp_auc\"]" 641 | ] 642 | }, 643 | { 644 | "cell_type": "code", 645 | "execution_count": 13, 646 | "metadata": {}, 647 | "outputs": [ 648 | { 649 | "data": { 650 | "text/plain": [ 651 | "0.6713709677419355" 652 | ] 653 | }, 654 | "execution_count": 13, 655 | "metadata": {}, 656 | "output_type": "execute_result" 657 | } 658 | ], 659 | "source": [ 660 | "# Show test accuracy for SR task\n", 661 | "eval_results[\"sr_test_acc\"]" 662 | ] 663 | }, 664 | { 665 | "cell_type": "code", 666 | "execution_count": 14, 667 | "metadata": {}, 668 | "outputs": [ 669 | { 670 | "data": { 671 | "text/plain": [ 672 | "0.6956521739130435" 673 | ] 674 | }, 675 | "execution_count": 14, 676 | "metadata": {}, 677 | "output_type": "execute_result" 678 | } 679 | ], 680 | "source": [ 681 | "# Show validation accuracy for SR task\n", 682 | "eval_results[\"sr_validation_acc\"]" 683 | ] 684 | }, 685 | { 686 | "cell_type": "code", 687 | "execution_count": 15, 688 | "metadata": {}, 689 | "outputs": [ 690 | { 691 | "data": { 692 | "text/plain": [ 693 | "[0.12955922001974632,\n", 694 | " 0.18240535049745576,\n", 695 | " 0.23030967570441258,\n", 696 | " 0.31102329687856,\n", 697 | " 0.4143004291030607,\n", 698 | " 0.47711466165413524,\n", 699 | " 0.5583993126756285,\n", 700 | " 0.6349049707602339]" 701 | ] 702 | }, 703 | "execution_count": 15, 704 | "metadata": {}, 705 | "output_type": "execute_result" 706 | } 707 | ], 708 | "source": [ 709 | "# Show Top-K Hit Rate for CR task\n", 710 | "eval_results[\"cr_topk_hit_rate\"]" 711 | ] 712 | }, 713 | { 714 | "attachments": {}, 715 | "cell_type": "markdown", 716 | "metadata": {}, 717 | "source": [ 718 | "## Evaluate Saved Embedding" 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": 16, 724 | "metadata": {}, 725 | "outputs": [ 726 | { 727 | "name": "stdout", 728 | "output_type": "stream", 729 | "text": [ 730 | "Evaluate Node Embeddings ./data/nodes_feature_msbert.pt:\n", 731 | "Evaluate SP ...\n", 732 | "SP 
AUC: 0.8059550091101482\n", 733 | "Evaluate SR ...\n", 734 | "SR Validation ACC: 0.6956521739130435 SR Test ACC: 0.6713709677419355\n", 735 | "Evaluate CR with top-K hit rate (K=[50, 100, 200, 500, 1000, 2000, 5000, 10000]) ...\n", 736 | "CR Hit Rates: [0.12955922001974632, 0.18240535049745576, 0.23030967570441258, 0.31102329687856, 0.4143004291030607, 0.47711466165413524, 0.5583993126756285, 0.6349049707602339]\n" 737 | ] 738 | } 739 | ], 740 | "source": [ 741 | "# Run all evaluation tasks on the specified embeddings saved in torch.Tensor format\n", 742 | "\n", 743 | "EMBEDDINGS_FILE = \"./data/nodes_feature_msbert.pt\"\n", 744 | "\n", 745 | "eval_results = comkg.evaluate(embeddings_file=EMBEDDINGS_FILE)" 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "execution_count": 17, 751 | "metadata": {}, 752 | "outputs": [ 753 | { 754 | "data": { 755 | "text/plain": [ 756 | "0.8059550091101482" 757 | ] 758 | }, 759 | "execution_count": 17, 760 | "metadata": {}, 761 | "output_type": "execute_result" 762 | } 763 | ], 764 | "source": [ 765 | "# Show AUC score for SP task\n", 766 | "eval_results[\"sp_auc\"]" 767 | ] 768 | }, 769 | { 770 | "cell_type": "code", 771 | "execution_count": 18, 772 | "metadata": {}, 773 | "outputs": [ 774 | { 775 | "data": { 776 | "text/plain": [ 777 | "0.6713709677419355" 778 | ] 779 | }, 780 | "execution_count": 18, 781 | "metadata": {}, 782 | "output_type": "execute_result" 783 | } 784 | ], 785 | "source": [ 786 | "# Show test accuracy for SR task\n", 787 | "eval_results[\"sr_test_acc\"]" 788 | ] 789 | }, 790 | { 791 | "cell_type": "code", 792 | "execution_count": 19, 793 | "metadata": {}, 794 | "outputs": [ 795 | { 796 | "data": { 797 | "text/plain": [ 798 | "0.6956521739130435" 799 | ] 800 | }, 801 | "execution_count": 19, 802 | "metadata": {}, 803 | "output_type": "execute_result" 804 | } 805 | ], 806 | "source": [ 807 | "# Show validation accuracy for SR task\n", 808 | "eval_results[\"sr_validation_acc\"]" 809 | ] 810 | }, 811 | { 
812 | "cell_type": "code", 813 | "execution_count": 20, 814 | "metadata": {}, 815 | "outputs": [ 816 | { 817 | "data": { 818 | "text/plain": [ 819 | "[0.12955922001974632,\n", 820 | " 0.18240535049745576,\n", 821 | " 0.23030967570441258,\n", 822 | " 0.31102329687856,\n", 823 | " 0.4143004291030607,\n", 824 | " 0.47711466165413524,\n", 825 | " 0.5583993126756285,\n", 826 | " 0.6349049707602339]" 827 | ] 828 | }, 829 | "execution_count": 20, 830 | "metadata": {}, 831 | "output_type": "execute_result" 832 | } 833 | ], 834 | "source": [ 835 | "# Show Top-K Hit Rate for CR task\n", 836 | "eval_results[\"cr_topk_hit_rate\"]" 837 | ] 838 | }, 839 | { 840 | "attachments": {}, 841 | "cell_type": "markdown", 842 | "metadata": {}, 843 | "source": [ 844 | "## Create DGL Graph" 845 | ] 846 | }, 847 | { 848 | "cell_type": "code", 849 | "execution_count": 21, 850 | "metadata": {}, 851 | "outputs": [ 852 | { 853 | "data": { 854 | "text/plain": [ 855 | "[Graph(num_nodes=1169931, num_edges=50815503,\n", 856 | " ndata_schemes={'feat': Scheme(shape=(512,), dtype=torch.float32)}\n", 857 | " edata_schemes={'weight': Scheme(shape=(15,), dtype=torch.float32)})]" 858 | ] 859 | }, 860 | "execution_count": 21, 861 | "metadata": {}, 862 | "output_type": "execute_result" 863 | } 864 | ], 865 | "source": [ 866 | "# Takes about 15 mins, the graph will be saved to work_folder\n", 867 | "g = comkg.get_dgl_graph(work_folder=\"./experiments\")\n", 868 | "g" 869 | ] 870 | }, 871 | { 872 | "cell_type": "code", 873 | "execution_count": 22, 874 | "metadata": {}, 875 | "outputs": [ 876 | { 877 | "data": { 878 | "text/plain": [ 879 | "[Graph(num_nodes=1169931, num_edges=50815503,\n", 880 | " ndata_schemes={'feat': Scheme(shape=(512,), dtype=torch.float32)}\n", 881 | " edata_schemes={'weight': Scheme(shape=(15,), dtype=torch.float32)})]" 882 | ] 883 | }, 884 | "execution_count": 22, 885 | "metadata": {}, 886 | "output_type": "execute_result" 887 | } 888 | ], 889 | "source": [ 890 | "# When call the same 
function again, it will load from file directly\n", 891 | "g = comkg.get_dgl_graph(work_folder=\"./experiments\")\n", 892 | "g" 893 | ] 894 | }, 895 | { 896 | "attachments": {}, 897 | "cell_type": "markdown", 898 | "metadata": {}, 899 | "source": [ 900 | "## Create iGraph" 901 | ] 902 | }, 903 | { 904 | "cell_type": "code", 905 | "execution_count": 24, 906 | "metadata": {}, 907 | "outputs": [ 908 | { 909 | "data": { 910 | "text/plain": [ 911 | "" 912 | ] 913 | }, 914 | "execution_count": 24, 915 | "metadata": {}, 916 | "output_type": "execute_result" 917 | } 918 | ], 919 | "source": [ 920 | "g = comkg.to_igraph()\n", 921 | "g" 922 | ] 923 | } 924 | ], 925 | "metadata": { 926 | "environment": { 927 | "kernel": "python3", 928 | "name": "pytorch-gpu.1-13.m107", 929 | "type": "gcloud", 930 | "uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.1-13:m107" 931 | }, 932 | "kernelspec": { 933 | "display_name": "Python 3 (ipykernel)", 934 | "language": "python", 935 | "name": "python3" 936 | }, 937 | "language_info": { 938 | "codemirror_mode": { 939 | "name": "ipython", 940 | "version": 3 941 | }, 942 | "file_extension": ".py", 943 | "mimetype": "text/x-python", 944 | "name": "python", 945 | "nbconvert_exporter": "python", 946 | "pygments_lexer": "ipython3", 947 | "version": "3.9.18" 948 | } 949 | }, 950 | "nbformat": 4, 951 | "nbformat_minor": 4 952 | } 953 | --------------------------------------------------------------------------------