├── .pyup.yml ├── docs ├── readme.rst ├── logo_torchKGE_small.png ├── tutorials │ ├── evaluation.rst │ ├── training.rst │ ├── linkprediction.rst │ ├── tripletclassification.rst │ ├── transe_wrappers.rst │ ├── transe.rst │ └── transe_early_stopping.rst ├── authors.rst ├── reference │ ├── inference.rst │ ├── data.rst │ ├── sampling.rst │ ├── evaluation.rst │ ├── models.rst │ └── utils.rst ├── Makefile ├── index.rst ├── _static │ └── css │ │ └── custom.css ├── make.bat ├── installation.rst ├── contributing.rst ├── conf.py └── history.rst ├── tests ├── __init__.py ├── test_evaluation.py ├── test_data.py └── test_utils.py ├── tox.ini ├── .editorconfig ├── requirements_dev.txt ├── .github ├── ISSUE_TEMPLATE.md └── workflows │ ├── ci_checks.yml │ └── release.yml ├── .readthedocs.yaml ├── setup.cfg ├── torchkge ├── models │ ├── __init__.py │ ├── deep.py │ └── interfaces.py ├── __init__.py ├── utils │ ├── __init__.py │ ├── dissimilarities.py │ ├── modeling.py │ ├── losses.py │ ├── data.py │ ├── pretrained_models.py │ ├── operations.py │ ├── training.py │ ├── data_redundancy.py │ └── datasets.py ├── exceptions.py ├── inference.py ├── data_structures.py └── evaluation.py ├── setup.py ├── LICENSE ├── .gitignore └── README.rst /.pyup.yml: -------------------------------------------------------------------------------- 1 | update: insecure -------------------------------------------------------------------------------- /docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Unit test package for torchkge.""" 4 | -------------------------------------------------------------------------------- /docs/logo_torchKGE_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/torchkge-team/torchkge/HEAD/docs/logo_torchKGE_small.png -------------------------------------------------------------------------------- /docs/tutorials/evaluation.rst: -------------------------------------------------------------------------------- 1 | Model Evaluation 2 | **************** 3 | 4 | .. include:: linkprediction.rst 5 | 6 | .. include:: tripletclassification.rst 7 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Credits 3 | ======= 4 | 5 | Development Lead 6 | ---------------- 7 | 8 | * Armand Boschin 9 | 10 | Contributors 11 | ------------ 12 | 13 | None yet. Why not be the first? 14 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py38, flake8 3 | 4 | [testenv:flake8] 5 | basepython = python 6 | deps = flake8 7 | commands = flake8 torchkge 8 | 9 | [testenv] 10 | setenv = 11 | PYTHONPATH = {toxinidir} 12 | 13 | commands = python setup.py test 14 | -------------------------------------------------------------------------------- /docs/tutorials/training.rst: -------------------------------------------------------------------------------- 1 | Model Training 2 | ************** 3 | 4 | Here are three examples of models being trained on FB15k. 5 | 6 | ..
include:: transe.rst 7 | 8 | .. include:: transe_wrappers.rst 9 | 10 | .. include:: transe_early_stopping.rst 11 | -------------------------------------------------------------------------------- /docs/reference/inference.rst: -------------------------------------------------------------------------------- 1 | .. _inference: 2 | 3 | 4 | Inference 5 | ********* 6 | 7 | Entity Inference 8 | ---------------- 9 | 10 | .. autoclass:: torchkge.inference.EntityInference 11 | :members: 12 | 13 | Relation Inference 14 | ------------------ 15 | 16 | .. autoclass:: torchkge.inference.RelationInference 17 | :members: 18 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.22 2 | pandas>=1.4 3 | torch>=1.2 4 | tqdm>=4.64 5 | 6 | # Documentation 7 | sphinx==5.0.2 8 | sphinx_rtd_theme==1.0 9 | numpydoc==1.4.0 10 | 11 | # Tests 12 | flake8==4.0.1 13 | tox==3.25.1 14 | pytest-runner==6.0.0 15 | pytest==7.1.2 16 | 17 | # Deployment 18 | pip==23.3.1 19 | bumpversion==0.6 20 | wheel==0.38.2 21 | -------------------------------------------------------------------------------- /docs/reference/data.rst: -------------------------------------------------------------------------------- 1 | .. _data: 2 | 3 | 4 | Data Structure 5 | ************** 6 | 7 | .. currentmodule:: torchkge.data_structures 8 | 9 | Knowledge Graph 10 | --------------- 11 | .. autoclass:: torchkge.data_structures.KnowledgeGraph 12 | :members: 13 | 14 | Small KG 15 | -------- 16 | .. autoclass:: torchkge.data_structures.SmallKG 17 | :members: 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | * TorchKGE version: 2 | * Python version: 3 | * Operating System: 4 | 5 | ### Description 6 | 7 | Describe what you were trying to get done. 8 | Tell us what happened, what went wrong, and what you expected to happen. 9 | 10 | ### What I Did 11 | 12 | ``` 13 | Paste the command(s) you ran and the output. 14 | If there was a crash, please include the traceback here.
15 | ``` 16 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the version of Python and other tools you might need 8 | build: 9 | os: ubuntu-20.04 10 | tools: 11 | python: "3.8" 12 | 13 | # Build documentation in the docs/ directory with Sphinx 14 | sphinx: 15 | configuration: docs/conf.py 16 | 17 | python: 18 | install: 19 | - requirements: requirements_dev.txt -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.17.7 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | search = version='{current_version}' 8 | replace = version='{new_version}' 9 | 10 | [bumpversion:file:torchkge/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | 14 | [bdist_wheel] 15 | universal = 1 16 | 17 | [flake8] 18 | exclude = docs 19 | 20 | [aliases] 21 | test = pytest 22 | 23 | [tool:pytest] 24 | collect_ignore = ['setup.py'] 25 | -------------------------------------------------------------------------------- /torchkge/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .interfaces import Model, TranslationModel, BilinearModel 2 | 3 | from .translation import TransEModel 4 | from .translation import TransHModel 5 | from .translation import TransRModel 6 | from .translation import TransDModel 7 | from .translation import TorusEModel 8 | 9 | from .bilinear import RESCALModel 10 | from .bilinear import DistMultModel 11 | from .bilinear import HolEModel 12 | from .bilinear import ComplExModel 13 | from .bilinear import AnalogyModel 14 | 15 | from .deep import ConvKBModel 16 | -------------------------------------------------------------------------------- /docs/reference/sampling.rst: -------------------------------------------------------------------------------- 1 | .. _sampling: 2 | 3 | .. currentmodule:: torchkge.sampling 4 | 5 | Negative Sampling 6 | ***************** 7 | 8 | Uniform negative sampler 9 | ------------------------ 10 | .. autoclass:: torchkge.sampling.UniformNegativeSampler 11 | :members: 12 | 13 | Bernoulli negative sampler 14 | -------------------------- 15 | .. autoclass:: torchkge.sampling.BernoulliNegativeSampler 16 | :members: 17 | 18 | Positional negative sampler 19 | --------------------------- 20 | .. autoclass:: torchkge.sampling.PositionalNegativeSampler 21 | :members: 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = torchkge 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. 
$(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to TorchKGE's documentation! 2 | ====================================== 3 | 4 | .. include:: readme.rst 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | :caption: Tutorials: 9 | 10 | tutorials/training 11 | tutorials/evaluation 12 | 13 | .. toctree:: 14 | :maxdepth: 1 15 | :caption: Reference: 16 | 17 | reference/models 18 | reference/evaluation 19 | reference/inference 20 | reference/sampling 21 | reference/data 22 | reference/utils 23 | 24 | 25 | .. toctree:: 26 | :maxdepth: 1 27 | :caption: Installation: 28 | 29 | installation 30 | 31 | .. toctree:: 32 | :maxdepth: 1 33 | :caption: About: 34 | 35 | contributing 36 | authors 37 | history 38 | -------------------------------------------------------------------------------- /docs/tutorials/linkprediction.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | Link Prediction 3 | =============== 4 | 5 | To evaluate a model on link prediction:: 6 | 7 | from torch import cuda 8 | from torchkge.utils.pretrained_models import load_pretrained_transe 9 | from torchkge.utils.datasets import load_fb15k 10 | from torchkge.evaluation import LinkPredictionEvaluator 11 | 12 | _, _, kg_test = load_fb15k() 13 | 14 | model = load_pretrained_transe('fb15k', 100) 15 | if cuda.is_available(): 16 | model.cuda() 17 | 18 | # Link prediction evaluation on test set. 19 | evaluator = LinkPredictionEvaluator(model, kg_test) 20 | evaluator.evaluate(b_size=32) 21 | evaluator.print_results() 22 | 23 | -------------------------------------------------------------------------------- /.github/workflows/ci_checks.yml: -------------------------------------------------------------------------------- 1 | name: CI Checks 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | fail-fast: false 10 | matrix: 11 | python-version: ['3.8', '3.9', '3.10'] 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | python -m pip install -r requirements_dev.txt 23 | python setup.py develop 24 | - name: Test with pytest 25 | run: | 26 | py.test --doctest-modules -------------------------------------------------------------------------------- /torchkge/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for TorchKGE.""" 4 | 5 | __author__ = """Armand Boschin""" 6 | __email__ = 'aboschin@enst.fr' 7 | __version__ = '0.17.7' 8 | 9 | from torchkge.exceptions import NotYetEvaluatedError 10 | from torchkge.utils import MarginLoss, LogisticLoss 11 | from torchkge.utils import l1_dissimilarity, l2_dissimilarity 12 | from .data_structures import KnowledgeGraph 13 | from .evaluation import LinkPredictionEvaluator 14 | from .evaluation import TripletClassificationEvaluator 15 | from .models import ConvKBModel 16 | from .models import RESCALModel, DistMultModel, HolEModel, ComplExModel, AnalogyModel 17 | from .models import TransEModel, TransHModel, TransRModel,
TransDModel, TorusEModel 18 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | html { 2 | background-color: #e6e6e6; 3 | } 4 | 5 | body.wy-body-for-nav { 6 | line-height: 1.5em; 7 | color: #333; 8 | } 9 | 10 | div.wy-nav-side { 11 | background-color: #333; 12 | } 13 | 14 | div.wy-side-nav-search { 15 | background-color: #777777; 16 | } 17 | 18 | 19 | div.wy-menu.wy-menu-vertical>p { 20 | color: #c5113b; /* section titles */ 21 | } 22 | 23 | .wy-nav-top { 24 | background-color: #777777; 25 | } 26 | 27 | .wy-side-nav-search>a:hover, .wy-side-nav-search .wy-dropdown>a:hover { 28 | background: None; /*background for logo when hovered*/ 29 | } 30 | 31 | .wy-side-nav-search>div.version { 32 | color: white; 33 | } 34 | 35 | .wy-side-nav-search input[type=text] { 36 | border-color: #d9d9d9; 37 | } 38 | 39 | a { 40 | color: #c5113b; 41 | } 42 | -------------------------------------------------------------------------------- /docs/tutorials/tripletclassification.rst: -------------------------------------------------------------------------------- 1 | ====================== 2 | Triplet Classification 3 | ====================== 4 | 5 | To evaluate a model on triplet classification:: 6 | 7 | from torch import cuda 8 | from torchkge.evaluation import TripletClassificationEvaluator 9 | from torchkge.utils.pretrained_models import load_pretrained_transe 10 | from torchkge.utils.datasets import load_fb15k 11 | 12 | _, kg_val, kg_test = load_fb15k() 13 | 14 | model = load_pretrained_transe('fb15k', 100) 15 | if cuda.is_available(): 16 | model.cuda() 17 | 18 | # Triplet classification evaluation on test set by learning thresholds on validation set 19 | evaluator = TripletClassificationEvaluator(model, kg_val, kg_test) 20 | evaluator.evaluate(b_size=128) 21 | 22 | print('Accuracy on test set: {}'.format(evaluator.accuracy(b_size=128))) 23 | 24 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=torchkge 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /torchkge/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import DataLoader, get_data_home, clear_data_home, safe_extract 2 | 3 | from .data_redundancy import count_triplets, duplicates 4 | from .data_redundancy import cartesian_product_relations 5 | 6 | from .datasets import load_fb15k, load_fb13, load_fb15k237 7 | from .datasets import load_wn18, load_wn18rr 8 | from .datasets import load_yago3_10, load_wikidatasets, load_wikidata_vitals 9 | 10 | from .dissimilarities import l1_dissimilarity, l2_dissimilarity 11 | from .dissimilarities import l1_torus_dissimilarity, l2_torus_dissimilarity, \ 12 | el2_torus_dissimilarity 13 | 14 | from .losses import MarginLoss, LogisticLoss, BinaryCrossEntropyLoss 15 | from .modeling import init_embedding, get_true_targets, load_embeddings, filter_scores 16 | from .operations import get_rank, get_mask, get_bernoulli_probs 17 | from .pretrained_models import load_pretrained_transe, load_pretrained_rescal, load_pretrained_complex 18 | from .training import Trainer, TrainDataLoader 19 | -------------------------------------------------------------------------------- /torchkge/exceptions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | 8 | class NotYetEvaluatedError(Exception): 9 | def __init__(self, message): 10 | super().__init__(message) 11 | 12 | 13 | class SizeMismatchError(Exception): 14 | def __init__(self, message): 15 | super().__init__(message) 16 | 17 | 18 | class WrongDimensionError(Exception): 19 | def __init__(self, message): 20 | super().__init__(message) 21 | 22 | 23 | class NotYetImplementedError(Exception): 24 | def __init__(self, message): 25 | super().__init__(message) 26 | 27 | 28 | class WrongArgumentsError(Exception): 29 | def __init__(self, message): 30 | super().__init__(message) 31 | 32 | 33 | class SanityError(Exception): 34 | def __init__(self, message): 35 | super().__init__(message) 36 | 37 | 38 | class SplitabilityError(Exception): 39 | def __init__(self, message): 40 | super().__init__(message) 41 | 42 | 43 | class NoPreTrainedVersionError(Exception): 44 | def __init__(self, message): 45 | super().__init__(message) 46 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.8' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_PASSWORD }} 37 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | Stable release 4 | -------------- 5 | 6 | To install TorchKGE, run this command in your terminal: 7 | 8 | .. code-block:: console 9 | 10 | $ pip install torchkge 11 | 12 | This is the preferred method to install TorchKGE, as it will always install the most recent stable release. 13 | 14 | If you don't have `pip`_ installed, this `Python installation guide`_ can guide 15 | you through the process. 16 | 17 | .. _pip: https://pip.pypa.io 18 | .. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ 19 | 20 | 21 | From sources 22 | ------------ 23 | 24 | The sources for TorchKGE can be downloaded from the `Github repo`_. 25 | 26 | You can either clone the public repository: 27 | 28 | .. code-block:: console 29 | 30 | $ git clone git://github.com/torchkge/torchkge 31 | 32 | Or download the `tarball`_: 33 | 34 | .. code-block:: console 35 | 36 | $ curl -OL https://github.com/torchkge/torchkge/tarball/master 37 | 38 | Once you have a copy of the source, you can install it with: 39 | 40 | .. code-block:: console 41 | 42 | $ python setup.py install 43 | 44 | 45 | .. _Github repo: https://github.com/torchkge/torchkge 46 | .. 
_tarball: https://github.com/torchkge/torchkge/tarball/master 47 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """The setup script.""" 5 | 6 | from setuptools import setup, find_packages 7 | 8 | with open('README.rst') as readme_file: 9 | readme = readme_file.read() 10 | 11 | requirements = ['torch>=1.2.0', 'tqdm>=4.64', 'pandas>=1.4', 'numpy>=1.22'] 12 | 13 | setup_requirements = ['pytest-runner'] 14 | 15 | test_requirements = ['pytest'] 16 | 17 | setup( 18 | author="TorchKGE Developers", 19 | author_email='aboschin@enst.fr', 20 | classifiers=[ 21 | 'Development Status :: 2 - Pre-Alpha', 22 | 'Intended Audience :: Developers', 23 | 'License :: OSI Approved :: BSD License', 24 | 'Natural Language :: English', 25 | 'Programming Language :: Python :: 3.8', 26 | 'Programming Language :: Python :: 3.9', 27 | 'Programming Language :: Python :: 3.10' 28 | ], 29 | description="Knowledge Graph embedding in Python and PyTorch.", 30 | license="BSD license", 31 | long_description=readme, 32 | include_package_data=True, 33 | keywords='torchkge', 34 | name='torchkge', 35 | url='https://github.com/torchkge-team/torchkge', 36 | packages=find_packages(), 37 | install_requires=requirements, 38 | setup_requires=setup_requirements, 39 | tests_require=test_requirements, 40 | test_suite='tests', 41 | version='0.17.7', 42 | zip_safe=False, 43 | ) 44 | -------------------------------------------------------------------------------- /docs/tutorials/transe_wrappers.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | Shortest training 3 | ================= 4 | 5 | TorchKGE also provides simple utility wrappers for model training. Here is an example on how to use them:: 6 | 7 | from torch.optim import Adam 8 | 9 | from torchkge.evaluation import LinkPredictionEvaluator 10 | from torchkge.models import TransEModel 11 | from torchkge.utils.datasets import load_fb15k 12 | from torchkge.utils import Trainer, MarginLoss 13 | 14 | 15 | def main(): 16 | # Define some hyper-parameters for training 17 | emb_dim = 100 18 | lr = 0.0004 19 | margin = 0.5 20 | n_epochs = 1000 21 | batch_size = 32768 22 | 23 | # Load dataset 24 | kg_train, kg_val, kg_test = load_fb15k() 25 | 26 | # Define the model and criterion 27 | model = TransEModel(emb_dim, kg_train.n_ent, kg_train.n_rel, 28 | dissimilarity_type='L2') 29 | criterion = MarginLoss(margin) 30 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5) 31 | 32 | trainer = Trainer(model, criterion, kg_train, n_epochs, batch_size, 33 | optimizer=optimizer, sampling_type='bern', use_cuda='all',) 34 | 35 | trainer.run() 36 | 37 | evaluator = LinkPredictionEvaluator(model, kg_test) 38 | evaluator.evaluate(200) 39 | evaluator.print_results() 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | 3 | BSD License 4 | 5 | Copyright (c) 2022, TorchKGE developers 6 | Armand Boschin 7 | All rights reserved. 
8 | 9 | Redistribution and use in source and binary forms, with or without modification, 10 | are permitted provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, this 13 | list of conditions and the following disclaimer. 14 | 15 | * Redistributions in binary form must reproduce the above copyright notice, this 16 | list of conditions and the following disclaimer in the documentation and/or 17 | other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its 20 | contributors may be used to endorse or promote products derived from this 21 | software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 24 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 25 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 26 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 27 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 28 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 30 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 31 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 32 | OF THE POSSIBILITY OF SUCH DAMAGE. 33 | 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | .DS_Store 4 | test/ 5 | .travis.yml 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # dotenv 90 | .env 91 | 92 | # virtualenv 93 | .venv 94 | venv/ 95 | ENV/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /torchkge/utils/dissimilarities.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | from math import pi 8 | from torch import abs, cos, min 9 | 10 | 11 | def l1_dissimilarity(a, b): 12 | """Compute dissimilarity between rows of `a` and `b` as :math:`||a-b||_1`. 13 | 14 | """ 15 | assert len(a.shape) == len(b.shape) 16 | return (a-b).norm(p=1, dim=-1) 17 | 18 | 19 | def l2_dissimilarity(a, b): 20 | """Compute dissimilarity between rows of `a` and `b` as 21 | :math:`||a-b||_2^2`. 22 | 23 | """ 24 | assert len(a.shape) == len(b.shape) 25 | return (a-b).norm(p=2, dim=-1)**2 26 | 27 | 28 | def l1_torus_dissimilarity(a, b): 29 | """See `paper by Ebisu et al. `_ 30 | for details about the definition of this dissimilarity function. 31 | 32 | """ 33 | assert len(a.shape) == len(b.shape) 34 | return 2 * min(abs(a - b), 1 - abs(a - b)).sum(dim=-1) 35 | 36 | 37 | def l2_torus_dissimilarity(a, b): 38 | """See `paper by Ebisu et al. `_ 39 | for details about the definition of this dissimilarity function. 40 | 41 | """ 42 | assert len(a.shape) == len(b.shape) 43 | return 4 * min((a - b) ** 2, 1 - (a - b) ** 2).sum(dim=-1) 44 | 45 | 46 | def el2_torus_dissimilarity(a, b): 47 | """See `paper by Ebisu et al. `_ 48 | for details about the definition of this dissimilarity function. 
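For instance (an illustrative value, not from the paper): if every coordinate of :math:`a - b` equals :math:`1/4`, then each coordinate contributes :math:`\\min(1/4, 3/4) = 1/4`, the cosine term gives :math:`2(1 - \\cos(\\pi/2)) = 2`, and after the division by 4 each coordinate adds :math:`1/2` to the returned sum.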
49 | 50 | """ 51 | assert len(a.shape) == len(b.shape) 52 | tmp = min(a - b, 1 - (a - b)) 53 | tmp = 2 * (1 - cos(2 * pi * tmp)) 54 | return tmp.sum(dim=-1) / 4 55 | -------------------------------------------------------------------------------- /tests/test_evaluation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import unittest 3 | 4 | from torch import long 5 | 6 | from torchkge.data_structures import KnowledgeGraph 7 | from torchkge.evaluation import LinkPredictionEvaluator, TripletClassificationEvaluator 8 | from torchkge.models import TransEModel 9 | 10 | 11 | class TestUtils(unittest.TestCase): 12 | 13 | def setUp(self): 14 | df = pd.DataFrame([[0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], [1, 2, 1], [1, 3, 2], [2, 4, 0], [3, 4, 4], 15 | [5, 4, 0]], columns=['from', 'to', 'rel']) 16 | self.kg = KnowledgeGraph(df) 17 | 18 | def checkSanityLinkPrediction(self, evaluator): 19 | assert evaluator.rank_true_heads.dtype == long 20 | assert evaluator.rank_true_tails.dtype == long 21 | assert evaluator.filt_rank_true_heads.dtype == long 22 | assert evaluator.filt_rank_true_tails.dtype == long 23 | 24 | assert evaluator.rank_true_heads.shape[0] == len(self.kg) 25 | assert evaluator.rank_true_tails.shape[0] == len(self.kg) 26 | assert evaluator.filt_rank_true_heads.shape[0] == len(self.kg) 27 | assert evaluator.filt_rank_true_tails.shape[0] == len(self.kg) 28 | 29 | def test_LinkPredictionEvaluator(self): 30 | model = TransEModel(100, self.kg.n_ent, self.kg.n_rel, 'L1') 31 | 32 | evaluator = LinkPredictionEvaluator(model, self.kg) 33 | self.checkSanityLinkPrediction(evaluator) 34 | 35 | evaluator.evaluate(b_size=len(self.kg)) 36 | self.checkSanityLinkPrediction(evaluator) 37 | 38 | def test_TripletClassificationEvaluator(self): 39 | model = TransEModel(100, self.kg.n_ent, self.kg.n_rel, 'L1') 40 | kg1, kg2 = self.kg.split_kg(sizes=(4, 5)) 41 | # kg2 contains all relations so it will be used as validation 42 | evaluator = TripletClassificationEvaluator(model, kg2, kg1) 43 | assert evaluator.thresholds is None 44 | assert not evaluator.evaluated 45 | 46 | evaluator.evaluate(b_size=len(self.kg)) 47 | assert evaluator.evaluated 48 | assert evaluator.thresholds is not None 49 | assert (len(evaluator.thresholds.shape) == 1) & (evaluator.thresholds.shape[0] == self.kg.n_rel) 50 | -------------------------------------------------------------------------------- /docs/tutorials/transe.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | Simplest training 3 | ================= 4 | 5 | This is the Python code to train TransE without any wrapper.
This script shows how all parts of TorchKGE should be used 6 | together:: 7 | 8 | from torch import cuda 9 | from torch.optim import Adam 10 | 11 | from torchkge.models import TransEModel 12 | from torchkge.sampling import BernoulliNegativeSampler 13 | from torchkge.utils import MarginLoss, DataLoader 14 | from torchkge.utils.datasets import load_fb15k 15 | 16 | from tqdm.autonotebook import tqdm 17 | 18 | # Load dataset 19 | kg_train, _, _ = load_fb15k() 20 | 21 | # Define some hyper-parameters for training 22 | emb_dim = 100 23 | lr = 0.0004 24 | n_epochs = 1000 25 | b_size = 32768 26 | margin = 0.5 27 | 28 | # Define the model and criterion 29 | model = TransEModel(emb_dim, kg_train.n_ent, kg_train.n_rel, dissimilarity_type='L2') 30 | criterion = MarginLoss(margin) 31 | 32 | # Move everything to CUDA if available 33 | if cuda.is_available(): 34 | cuda.empty_cache() 35 | model.cuda() 36 | criterion.cuda() 37 | 38 | # Define the torch optimizer to be used 39 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5) 40 | 41 | sampler = BernoulliNegativeSampler(kg_train) 42 | dataloader = DataLoader(kg_train, batch_size=b_size, use_cuda='all') 43 | 44 | iterator = tqdm(range(n_epochs), unit='epoch') 45 | for epoch in iterator: 46 | running_loss = 0.0 47 | for i, batch in enumerate(dataloader): 48 | h, t, r = batch[0], batch[1], batch[2] 49 | n_h, n_t = sampler.corrupt_batch(h, t, r) 50 | 51 | optimizer.zero_grad() 52 | 53 | # forward + backward + optimize 54 | pos, neg = model(h, t, r, n_h, n_t) 55 | loss = criterion(pos, neg) 56 | loss.backward() 57 | optimizer.step() 58 | 59 | running_loss += loss.item() 60 | iterator.set_description( 61 | 'Epoch {} | mean loss: {:.5f}'.format(epoch + 1, 62 | running_loss / len(dataloader))) 63 | 64 | model.normalize_parameters() 65 | 66 | -------------------------------------------------------------------------------- /docs/reference/evaluation.rst: -------------------------------------------------------------------------------- 1 | .. _evaluation: 2 | 3 | 4 | Evaluation 5 | ********** 6 | 7 | Link Prediction 8 | --------------- 9 | To assess the performance of the link prediction evaluation module of TorchKGE, it was compared with those of 10 | `AmpliGraph `_ (v1.3.1) and `OpenKE `_ (version of 11 | April 9). The computation times (in seconds) reported in the following table are averaged over 5 independent evaluation 12 | processes. Experiments were done using PyTorch 1.5, TensorFlow 1.15 and a Tesla K80 GPU. Missing values for AmpliGraph 13 | are due to missing models in the library. 14 | 15 | ..
tabularcolumns:: p{2cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm} 16 | 17 | +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ 18 | | Model | TransE | TransD | RESCAL | ComplEx | 19 | +===========+===========+===========+===========+===========+===========+===========+===========+===========+ 20 | | Dataset |FB15k | WN18 | FB15k | WN18 | FB15k | WN18 | FB15k | WN18 | 21 | +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ 22 | |AmpliGraph | 354.8 | 39.8 | | | | | 537.2 | 94.9 | 23 | +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ 24 | |OpenKE | 235.6 | 42.2 | 258.5 | 43.7 | 789.1 | 178.4 | 354.7 | 63.9 | 25 | +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ 26 | |TorchKGE | 76.1 | 13.8 | 60.8 | 11.1 | 46.9 | 7.1 | 96.4 | 18.6 | 27 | +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+ 28 | 29 | .. autoclass:: torchkge.evaluation.LinkPredictionEvaluator 30 | :members: 31 | 32 | Relation Prediction 33 | ------------------- 34 | .. autoclass:: torchkge.evaluation.RelationPredictionEvaluator 35 | 36 | Triplet Classification 37 | ---------------------- 38 | .. autoclass:: torchkge.evaluation.TripletClassificationEvaluator 39 | :members: 40 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | TorchKGE 3 | ======== 4 | 5 | .. image:: https://graphs.telecom-paristech.fr/images/logo_torchKGE_small.png 6 | :align: right 7 | :width: 100px 8 | :alt: logo torchkge 9 | 10 | .. image:: https://img.shields.io/pypi/v/torchkge.svg 11 | :target: https://pypi.python.org/pypi/torchkge 12 | 13 | .. image:: https://github.com/torchkge-team/torchkge/actions/workflows/ci_checks.yml/badge.svg 14 | :target: https://github.com/torchkge-team/torchkge/actions/workflows/ci_checks.yml 15 | 16 | .. image:: https://readthedocs.org/projects/torchkge/badge/?version=latest 17 | :target: https://torchkge.readthedocs.io/en/latest/?badge=latest 18 | :alt: Documentation Status 19 | 20 | .. image:: https://img.shields.io/pypi/pyversions/torchkge.svg 21 | :target: https://pypi.org/project/torchkge/ 22 | 23 | TorchKGE: Knowledge Graph embedding in Python and PyTorch. 24 | 25 | TorchKGE is a Python module for knowledge graph (KG) embedding relying solely on PyTorch. This package provides 26 | researchers and engineers with a clean and efficient API to design and test new models. It features a KG data structure, 27 | simple model interfaces and modules for negative sampling and model evaluation. Its main strength is a highly efficient 28 | evaluation module for the link prediction task, a central application of KG embedding. It has been `observed `_ to be up 29 | to five times faster than `AmpliGraph `_ and twenty-four times faster than 30 | `OpenKE `_. Various KG embedding models are already implemented. Special 31 | attention has been paid to code efficiency and simplicity, documentation and API consistency. It is distributed via 32 | PyPI under the BSD license. 33 | 34 | Citations 35 | --------- 36 | If you find this code useful in your research, please consider citing our `paper `_ (presented at `IWKG-KDD `_ 2020): 37 | 38 | ..
code:: 39 | 40 | @inproceedings{arm2020torchkge, 41 | title={TorchKGE: Knowledge Graph Embedding in Python and PyTorch}, 42 | author={Armand Boschin}, 43 | year={2020}, 44 | month={Aug}, 45 | booktitle={International Workshop on Knowledge Graph: Mining Knowledge Graph for Deep Insights}, 46 | } 47 | 48 | * Free software: BSD license 49 | * Documentation: https://torchkge.readthedocs.io. 50 | -------------------------------------------------------------------------------- /docs/reference/models.rst: -------------------------------------------------------------------------------- 1 | .. _models: 2 | 3 | Models 4 | ****** 5 | 6 | Interfaces 7 | ========== 8 | 9 | Model 10 | ----- 11 | .. autoclass:: torchkge.models.interfaces.Model 12 | :members: 13 | 14 | TranslationModel 15 | ---------------- 16 | .. autoclass:: torchkge.models.interfaces.TranslationModel 17 | :members: 18 | 19 | Translational Models 20 | ==================== 21 | 22 | Parameters used to train the models available in pre-trained versions: 23 | 24 | .. tabularcolumns:: p{2cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm} 25 | 26 | +-------+-----------+-----------+-----------+---------------+------------+--------+--------+-----------------+ 27 | | | Dataset | Dimension | Optimizer | Learning Rate | Batch Size | Loss | Margin | L2 penalization | 28 | +-------+-----------+-----------+-----------+---------------+------------+--------+--------+-----------------+ 29 | |TransE | FB15k | 100 | Adam | 2.1e-5 | 32768 | Margin | .651 | 1e-5 | 30 | +-------+-----------+-----------+-----------+---------------+------------+--------+--------+-----------------+ 31 | |TransE | FB15k237 | 100 | Adam | 2.1e-5 | 32768 | Margin | .651 | 1e-5 | 32 | +-------+-----------+-----------+-----------+---------------+------------+--------+--------+-----------------+ 33 | |TransE | FB15k237 | 150 | Adam | 2.7e-5 | 32768 | Margin | .648 | 1e-5 | 34 | +-------+-----------+-----------+-----------+---------------+------------+--------+--------+-----------------+ 35 | 36 | TransE 37 | ------ 38 | .. autoclass:: torchkge.models.translation.TransEModel 39 | :members: 40 | 41 | TransH 42 | ------ 43 | .. autoclass:: torchkge.models.translation.TransHModel 44 | :members: 45 | 46 | TransR 47 | ------ 48 | .. autoclass:: torchkge.models.translation.TransRModel 49 | :members: 50 | 51 | TransD 52 | ------ 53 | .. autoclass:: torchkge.models.translation.TransDModel 54 | :members: 55 | 56 | TorusE 57 | ------ 58 | .. autoclass:: torchkge.models.translation.TorusEModel 59 | :members: 60 | 61 | Bilinear Models 62 | =============== 63 | 64 | RESCAL 65 | ------ 66 | .. autoclass:: torchkge.models.bilinear.RESCALModel 67 | :members: 68 | 69 | DistMult 70 | -------- 71 | .. autoclass:: torchkge.models.bilinear.DistMultModel 72 | :members: 73 | 74 | HolE 75 | ---- 76 | .. autoclass:: torchkge.models.bilinear.HolEModel 77 | :members: 78 | 79 | ComplEx 80 | ------- 81 | .. autoclass:: torchkge.models.bilinear.ComplExModel 82 | :members: 83 | 84 | ANALOGY 85 | ------- 86 | .. autoclass:: torchkge.models.bilinear.AnalogyModel 87 | :members: 88 | 89 | Deep Models 90 | =========== 91 | 92 | ConvKB 93 | ------ 94 | ..
autoclass:: torchkge.models.deep.ConvKBModel 95 | :members: 96 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import unittest 3 | 4 | from torch import Tensor, int64 5 | 6 | from torchkge.data_structures import KnowledgeGraph 7 | from torchkge.exceptions import WrongArgumentsError, SanityError, SizeMismatchError 8 | 9 | 10 | class TestUtils(unittest.TestCase): 11 | """Tests for `torchkge.utils`.""" 12 | 13 | def setUp(self): 14 | self.df = pd.DataFrame([[0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], [1, 2, 1], [1, 3, 2], [2, 4, 0], [3, 4, 4], 15 | [5, 4, 0]], columns=['from', 'to', 'rel']) 16 | self.kg = KnowledgeGraph(self.df) 17 | 18 | def test_KnowledgeGraph_Builder(self): 19 | assert len(self.kg) == 9 20 | assert self.kg.n_ent == 6 21 | assert self.kg.n_rel == 4 22 | assert (type(self.kg.ent2ix) == dict) & (type(self.kg.rel2ix) == dict) 23 | assert (type(self.kg.head_idx) == Tensor) & (type(self.kg.tail_idx) == Tensor) & \ (type(self.kg.relations) == Tensor) 24 | assert (self.kg.head_idx.dtype == int64) & (self.kg.tail_idx.dtype == int64) & \ 25 | (self.kg.relations.dtype == int64) 26 | assert (len(self.kg.head_idx) == len(self.kg.tail_idx) == len(self.kg.relations)) 27 | 28 | kg_dict = {'heads': self.kg.head_idx, 'tails': self.kg.tail_idx, 'relations': self.kg.relations} 29 | with self.assertRaises(WrongArgumentsError): 30 | KnowledgeGraph() 31 | with self.assertRaises(WrongArgumentsError): 32 | KnowledgeGraph(kg=kg_dict, df=self.df) 33 | with self.assertRaises(WrongArgumentsError): 34 | KnowledgeGraph(kg=kg_dict) 35 | with self.assertRaises(WrongArgumentsError): 36 | KnowledgeGraph(kg={'heads': self.kg.head_idx, 'tails': self.kg.tail_idx}, 37 | ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix) 38 | with self.assertRaises(SanityError): 39 | KnowledgeGraph(kg={'heads': self.kg.head_idx[:-1], 40 | 'tails': self.kg.tail_idx, 41 | 'relations': self.kg.relations}, 42 | ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix) 43 | with self.assertRaises(SanityError): 44 | KnowledgeGraph(kg={'heads': self.kg.head_idx.int(), 45 | 'tails': self.kg.tail_idx, 46 | 'relations': self.kg.relations}, 47 | ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix) 48 | 49 | def test_split_kg(self): 50 | assert (len(self.kg.split_kg()) == 2) & (len(self.kg.split_kg(validation=True)) == 3) 51 | with self.assertRaises(SizeMismatchError): 52 | self.kg.split_kg(sizes=(1, 2, 3, 4)) 53 | with self.assertRaises(WrongArgumentsError): 54 | self.kg.split_kg(sizes=(9, 9, 9)) 55 | with self.assertRaises(WrongArgumentsError): 56 | self.kg.split_kg(sizes=(9, 9)) 57 | 58 | -------------------------------------------------------------------------------- /torchkge/utils/modeling.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | from torch import tensor 8 | from torch.nn import Embedding 9 | from torch.nn.init import xavier_uniform_ 10 | 11 | import pickle 12 | import tarfile 13 | 14 | from torchkge.utils import get_data_home, safe_extract 15 | 16 | from os import makedirs, remove 17 | from os.path import exists 18 | from urllib.request import urlretrieve 19 | 20 | 21 | def init_embedding(n_vectors, dim): 22 | """Create a torch.nn.Embedding object with `n_vectors` samples and `dim` 23 | dimensions.
It is then initialized with Xavier uniform distribution. 24 | """ 25 | entity_embeddings = Embedding(n_vectors, dim) 26 | xavier_uniform_(entity_embeddings.weight.data) 27 | 28 | return entity_embeddings 29 | 30 | 31 | def load_embeddings(model, dim, dataset, data_home=None): 32 | 33 | if data_home is None: 34 | data_home = get_data_home() 35 | data_path = data_home + '/models/' 36 | targz_file = data_path + '{}_{}_{}.tar.gz'.format(model, dataset, dim) 37 | pkl_file = data_path + '{}_{}_{}.pkl'.format(model, dataset, dim) 38 | if not exists(pkl_file): 39 | if not exists(data_path): 40 | makedirs(data_path, exist_ok=True) 41 | urlretrieve("https://graphs.telecom-paris.fr/data/torchkge/models/{}_{}_{}.tar.gz".format(model, dataset, dim), 42 | targz_file) 43 | with tarfile.open(targz_file, 'r') as tf: 44 | safe_extract(tf, data_path) 45 | remove(targz_file) 46 | 47 | with open(pkl_file, 'rb') as f: 48 | state_dict = pickle.load(f) 49 | 50 | return state_dict 51 | 52 | 53 | def get_true_targets(dictionary, key1, key2, true_idx, i): 54 | """For a current index `i` of the batch, returns a tensor containing the 55 | indices of entities for which the triplet is an existing one (i.e. a true 56 | one under CWA). 57 | 58 | Parameters 59 | ---------- 60 | dictionary: default dict 61 | Dictionary of keys (int, int) and values list of ints giving all 62 | possible entities for the (entity, relation) pair. 63 | key1: torch.Tensor, shape: (batch_size), dtype: torch.long 64 | key2: torch.Tensor, shape: (batch_size), dtype: torch.long 65 | true_idx: torch.Tensor, shape: (batch_size), dtype: torch.long 66 | Tensor containing the true entity for each sample. 67 | i: int 68 | Indicates which index of the batch is currently treated. 69 | 70 | Returns 71 | ------- 72 | true_targets: torch.Tensor, shape: (batch_size), dtype: torch.long 73 | Tensor containing the indices of entities such that 74 | (e_idx[i], r_idx[i], true_target[any]) is a true fact. 75 | 76 | """ 77 | try: 78 | true_targets = dictionary[key1[i].item(), key2[i].item()].copy() 79 | if true_idx is not None: 80 | # Remove the entity currently being ranked from its own target set. 81 | true_targets.remove(true_idx[i].item()) 82 | if len(true_targets) == 0: 83 | return None 84 | return tensor(list(true_targets)).long() 85 | except KeyError: 86 | # No true fact is known for this (entity, relation) pair. 87 | return None 88 | 89 | 90 | 91 | def filter_scores(scores, dictionary, key1, key2, true_idx): 92 | # Filter out known true triplets by assigning them -inf scores. 93 | b_size = scores.shape[0] 94 | filt_scores = scores.clone() 95 | 96 | for i in range(b_size): 97 | true_targets = get_true_targets(dictionary, key1, key2, true_idx, i) 98 | if true_targets is None: 99 | continue 100 | filt_scores[i][true_targets] = - float('Inf') 101 | 102 | return filt_scores 103 | -------------------------------------------------------------------------------- /docs/tutorials/transe_early_stopping.rst: -------------------------------------------------------------------------------- 1 | ==================== 2 | Training with Ignite 3 | ==================== 4 | 5 | TorchKGE can be used along with the `PyTorch ignite `_ library. It makes it easy to include 6 | early stopping in the training process.
Here is an example script of training a TransE model on FB15k on GPU with early 7 | stopping on evaluation MRR:: 8 | 9 | import torch 10 | from ignite.engine import Engine, Events 11 | from ignite.handlers import EarlyStopping 12 | from ignite.metrics import RunningAverage 13 | from torch.optim import Adam 14 | 15 | from torchkge.evaluation import LinkPredictionEvaluator 16 | from torchkge.models import TransEModel 17 | from torchkge.sampling import BernoulliNegativeSampler 18 | from torchkge.utils import MarginLoss, DataLoader 19 | from torchkge.utils.datasets import load_fb15k 20 | 21 | 22 | def process_batch(engine, batch): 23 | h, t, r = batch[0], batch[1], batch[2] 24 | n_h, n_t = sampler.corrupt_batch(h, t, r) 25 | 26 | optimizer.zero_grad() 27 | 28 | pos, neg = model(h, t, r, n_h, n_t) 29 | loss = criterion(pos, neg) 30 | loss.backward() 31 | optimizer.step() 32 | 33 | return loss.item() 34 | 35 | 36 | def linkprediction_evaluation(engine): 37 | model.normalize_parameters() 38 | 39 | loss = engine.state.output 40 | 41 | # validation MRR measure 42 | if engine.state.epoch % eval_epoch == 0: 43 | evaluator = LinkPredictionEvaluator(model, kg_val) 44 | evaluator.evaluate(b_size=256, verbose=False) 45 | val_mrr = evaluator.mrr()[1] 46 | else: 47 | val_mrr = 0 48 | 49 | print('Epoch {} | Train loss: {}, Validation MRR: {}'.format( 50 | engine.state.epoch, loss, val_mrr)) 51 | 52 | try: 53 | if engine.state.best_mrr < val_mrr: 54 | engine.state.best_mrr = val_mrr 55 | return val_mrr 56 | 57 | except AttributeError as e: 58 | if engine.state.epoch == 1: 59 | engine.state.best_mrr = val_mrr 60 | return val_mrr 61 | else: 62 | raise e 63 | 64 | device = torch.device('cuda') 65 | 66 | eval_epoch = 20 # do link prediction evaluation each 20 epochs 67 | max_epochs = 1000 68 | patience = 40 69 | batch_size = 32768 70 | emb_dim = 100 71 | lr = 0.0004 72 | margin = 0.5 73 | 74 | kg_train, kg_val, kg_test = load_fb15k() 75 | 76 | # Define the model, optimizer and criterion 77 | model = TransEModel(emb_dim, kg_train.n_ent, kg_train.n_rel, 78 | dissimilarity_type='L2') 79 | model.to(device) 80 | 81 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5) 82 | criterion = MarginLoss(margin) 83 | sampler = BernoulliNegativeSampler(kg_train, kg_val=kg_val, kg_test=kg_test) 84 | 85 | # Define the engine 86 | trainer = Engine(process_batch) 87 | 88 | # Define the moving average 89 | RunningAverage(output_transform=lambda x: x).attach(trainer, 'margin') 90 | 91 | # Add early stopping 92 | handler = EarlyStopping(patience=patience, 93 | score_function=linkprediction_evaluation, 94 | trainer=trainer) 95 | trainer.add_event_handler(Events.EPOCH_COMPLETED, handler) 96 | 97 | # Training 98 | train_iterator = DataLoader(kg_train, batch_size, use_cuda='all') 99 | trainer.run(train_iterator, 100 | epoch_length=len(train_iterator), 101 | max_epochs=max_epochs) 102 | 103 | print('Best score {:.3f} at epoch {}'.format(handler.best_score, 104 | trainer.state.epoch - handler.patience)) 105 | 106 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Contributing 5 | ============ 6 | 7 | Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given. 
8 | 9 | You can contribute in many ways: 10 | 11 | Types of Contributions 12 | ---------------------- 13 | 14 | Report Bugs 15 | ~~~~~~~~~~~ 16 | 17 | Report bugs at https://github.com/torchkge-team/torchkge/issues. 18 | 19 | If you are reporting a bug, please include: 20 | 21 | * Your operating system name and version. 22 | * Any details about your local setup that might be helpful in troubleshooting. 23 | * Detailed steps to reproduce the bug. 24 | 25 | Fix Bugs 26 | ~~~~~~~~ 27 | 28 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help wanted" is open to whoever wants 29 | to implement it. 30 | 31 | Implement Features 32 | ~~~~~~~~~~~~~~~~~~ 33 | 34 | Look through the GitHub issues for features. Anything tagged with "enhancement" and "help wanted" is open to whoever 35 | wants to implement it. 36 | 37 | Write Documentation 38 | ~~~~~~~~~~~~~~~~~~~ 39 | 40 | TorchKGE could always use more documentation, whether as part of the official TorchKGE docs, in docstrings, or even 41 | on the web in blog posts, articles, and such. 42 | 43 | Submit Feedback 44 | ~~~~~~~~~~~~~~~ 45 | 46 | The best way to send feedback is to file an issue at https://github.com/torchkge-team/torchkge/issues. 47 | 48 | If you are proposing a feature: 49 | 50 | * Explain in detail how it would work. 51 | * Keep the scope as narrow as possible, to make it easier to implement. 52 | * Remember that this is a volunteer-driven project, and that contributions 53 | are welcome :) 54 | 55 | Get Started! 56 | ------------ 57 | 58 | Ready to contribute? Here's how to set up `torchkge` for local development. 59 | 60 | 1. Fork the `torchkge` repo on GitHub. 61 | 2. Clone your fork locally:: 62 | 63 | $ git clone git@github.com:your_name_here/torchkge.git 64 | 65 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: 66 | 67 | $ mkvirtualenv torchkge 68 | $ cd torchkge/ 69 | $ python setup.py develop 70 | 71 | 4. Create a branch for local development:: 72 | 73 | $ git checkout -b dev/name-of-your-bugfix-or-feature 74 | 75 | Now you can make your changes locally. 76 | 77 | 5. When you're done making changes, check that your changes pass tests, including testing other 78 | Python versions with tox:: 79 | 80 | $ flake8 torchkge tests 81 | $ python setup.py test or py.test 82 | $ tox 83 | 84 | To get tox, just pip install it into your virtualenv. 85 | 86 | 6. Commit your changes and push your branch to GitHub:: 87 | 88 | $ git add . 89 | $ git commit -m "Your detailed description of your changes." 90 | $ git push origin dev/name-of-your-bugfix-or-feature 91 | 92 | 7. Submit a pull request through the GitHub website. 93 | 94 | Pull Request Guidelines 95 | ----------------------- 96 | 97 | Before you submit a pull request, check that it meets these guidelines: 98 | 99 | 1. The pull request should include tests. 100 | 2. If the pull request adds functionality, the docs should be updated. Put 101 | your new functionality into a function with a docstring, and add the 102 | feature to the list in README.rst. 103 | 3. The pull request should work for Python 3.8, 3.9 and 3.10. Check 104 | https://github.com/torchkge-team/torchkge/actions 105 | and make sure that the tests pass for all supported Python versions. 106 | 107 | Deploying 108 | --------- 109 | 110 | A reminder for the maintainers on how to deploy. 111 | Make sure all your changes are committed (including an entry in HISTORY.rst).
112 | Then run:: 113 | 114 | $ bumpversion patch # possible: major / minor / patch 115 | $ git push 116 | $ git push --tags 117 | 118 | GitHub Actions will then deploy to PyPI if tests pass. 119 | -------------------------------------------------------------------------------- /docs/reference/utils.rst: -------------------------------------------------------------------------------- 1 | .. _utils: 2 | 3 | 4 | Utils 5 | ***** 6 | 7 | .. currentmodule:: torchkge.utils 8 | 9 | Datasets loaders 10 | ---------------- 11 | 12 | .. autofunction:: torchkge.utils.datasets.load_fb13 13 | .. autofunction:: torchkge.utils.datasets.load_fb15k 14 | .. autofunction:: torchkge.utils.datasets.load_fb15k237 15 | .. autofunction:: torchkge.utils.datasets.load_wn18 16 | .. autofunction:: torchkge.utils.datasets.load_wn18rr 17 | .. autofunction:: torchkge.utils.datasets.load_yago3_10 18 | .. autofunction:: torchkge.utils.datasets.load_wikidatasets 19 | .. autofunction:: torchkge.utils.datasets.load_wikidata_vitals 20 | 21 | 22 | Pre-trained models 23 | ------------------ 24 | 25 | TransE Model 26 | ============ 27 | .. tabularcolumns:: p{3cm}p{3cm}p{3cm}p{3cm}p{3cm} 28 | 29 | +-----------+-----------+-----------+----------+--------------------+ 30 | | Model | Dataset | Dimension | Test MRR | Filtered Test MRR | 31 | +===========+===========+===========+==========+====================+ 32 | | TransE | FB15k | 100 | 0.250 | 0.420 | 33 | +-----------+-----------+-----------+----------+--------------------+ 34 | | TransE | FB15k237 | 150 | 0.187 | 0.287 | 35 | +-----------+-----------+-----------+----------+--------------------+ 36 | | TransE | WDV5 | 150 | 0.258 | 0.305 | 37 | +-----------+-----------+-----------+----------+--------------------+ 38 | | TransE | WN18RR | 100 | 0.201 | 0.236 | 39 | +-----------+-----------+-----------+----------+--------------------+ 40 | | TransE | Yago3-10 | 200 | 0.143 | 0.261 | 41 | +-----------+-----------+-----------+----------+--------------------+ 42 | 43 | .. autofunction:: torchkge.utils.pretrained_models.load_pretrained_transe 44 | 45 | RESCAL Model 46 | ============= 47 | .. tabularcolumns:: p{3cm}p{3cm}p{3cm}p{3cm}p{3cm} 48 | 49 | +-----------+-----------+-----------+----------+--------------------+ 50 | | Model | Dataset | Dimension | Test MRR | Filtered Test MRR | 51 | +===========+===========+===========+==========+====================+ 52 | | RESCAL | FB15k237 | 200 | 0.180 | 0.307 | 53 | +-----------+-----------+-----------+----------+--------------------+ 54 | | RESCAL | WN18RR | 150 | 0.273 | 0.424 | 55 | +-----------+-----------+-----------+----------+--------------------+ 56 | | RESCAL | Yago3-10 | 200 | 0.127 | 0.334 | 57 | +-----------+-----------+-----------+----------+--------------------+ 58 | 59 | .. autofunction:: torchkge.utils.pretrained_models.load_pretrained_rescal 60 | 61 | ComplEx Model 62 | ============= 63 | ..
tabularcolumns:: p{3cm}p{3cm}p{3cm}p{3cm} 64 | 65 | +-----------+-----------+-----------+----------+--------------------+ 66 | | Model | Dataset | Dimension | Test MRR | Filtered Test MRR | 67 | +===========+===========+===========+==========+====================+ 68 | | ComplEx | FB15k237 | 200 | 0.180 | 0.308 | 69 | +-----------+-----------+-----------+----------+--------------------+ 70 | | ComplEx | WN18RR | 200 | 0.290 | 0.455 | 71 | +-----------+-----------+-----------+----------+--------------------+ 72 | | ComplEx | WDV5 | 200 | 0.283 | 0.371 | 73 | +-----------+-----------+-----------+----------+--------------------+ 74 | | ComplEx | Yago3-10 | 200 | 0.164 | 0.421 | 75 | +-----------+-----------+-----------+----------+--------------------+ 76 | 77 | .. autofunction:: torchkge.utils.pretrained_models.load_pretrained_complex 78 | 79 | Data redundancy 80 | --------------- 81 | .. autofunction:: torchkge.utils.data_redundancy.duplicates 82 | .. autofunction:: torchkge.utils.data_redundancy.count_triplets 83 | .. autofunction:: torchkge.utils.data_redundancy.cartesian_product_relations 84 | 85 | Dissimilarities 86 | --------------- 87 | .. autofunction:: torchkge.utils.dissimilarities.l1_dissimilarity 88 | .. autofunction:: torchkge.utils.dissimilarities.l2_dissimilarity 89 | .. autofunction:: torchkge.utils.dissimilarities.l1_torus_dissimilarity 90 | .. autofunction:: torchkge.utils.dissimilarities.l2_torus_dissimilarity 91 | .. autofunction:: torchkge.utils.dissimilarities.el2_torus_dissimilarity 92 | 93 | Losses 94 | ------ 95 | .. autoclass:: torchkge.utils.losses.MarginLoss 96 | :members: 97 | .. autoclass:: torchkge.utils.losses.LogisticLoss 98 | :members: 99 | .. autoclass:: torchkge.utils.losses.BinaryCrossEntropyLoss 100 | :members: 101 | 102 | Training wrappers 103 | ----------------- 104 | .. autoclass:: torchkge.utils.training.TrainDataLoader 105 | :members: 106 | .. autoclass:: torchkge.utils.training.Trainer 107 | :members: 108 | -------------------------------------------------------------------------------- /torchkge/utils/losses.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | from torch import ones_like, zeros_like 8 | from torch.nn import Module, Sigmoid 9 | from torch.nn import MarginRankingLoss, SoftMarginLoss, BCELoss 10 | 11 | 12 | class MarginLoss(Module): 13 | """Margin loss as it was defined in `TransE paper 14 | `_ 15 | by Bordes et al. in 2013. This class implements :class:`torch.nn.Module` 16 | interface. 17 | 18 | """ 19 | def __init__(self, margin): 20 | super().__init__() 21 | self.loss = MarginRankingLoss(margin=margin, reduction='sum') 22 | 23 | def forward(self, positive_triplets, negative_triplets): 24 | """ 25 | Parameters 26 | ---------- 27 | positive_triplets: torch.Tensor, dtype: torch.float, shape: (b_size) 28 | Scores of the true triplets as returned by the `forward` methods of 29 | the models. 30 | negative_triplets: torch.Tensor, dtype: torch.float, shape: (b_size) 31 | Scores of the negative triplets as returned by the `forward` 32 | methods of the models. 
33 |
34 |         Returns
35 |         -------
36 |         loss: torch.Tensor, shape: (), dtype: torch.float
37 |             Sum over the batch of losses of the form
38 |             :math:`\\max\\{0, \\gamma - f(h,r,t) + f(h',r',t')\\}` where
39 |             :math:`\\gamma` is the margin (defined at initialization),
40 |             :math:`f(h,r,t)` is the score of a true fact and
41 |             :math:`f(h',r',t')` is the score of the associated negative fact.
42 |         """
43 |         return self.loss(positive_triplets, negative_triplets,
44 |                          target=ones_like(positive_triplets))
45 |
46 |
47 | class LogisticLoss(Module):
48 |     """Logistic loss as it was defined in `TransE paper
49 |     `_
50 |     by Bordes et al. in 2013. This class implements :class:`torch.nn.Module`
51 |     interface.
52 |
53 |     """
54 |     def __init__(self):
55 |         super().__init__()
56 |         self.loss = SoftMarginLoss(reduction='sum')
57 |
58 |     def forward(self, positive_triplets, negative_triplets):
59 |         """
60 |         Parameters
61 |         ----------
62 |         positive_triplets: torch.Tensor, dtype: torch.float, shape: (b_size)
63 |             Scores of the true triplets as returned by the `forward` methods
64 |             of the models.
65 |         negative_triplets: torch.Tensor, dtype: torch.float, shape: (b_size)
66 |             Scores of the negative triplets as returned by the `forward`
67 |             methods of the models.
68 |         Returns
69 |         -------
70 |         loss: torch.Tensor, shape: (), dtype: torch.float
71 |             Loss of the form :math:`\\log(1 + \\exp(-\\eta \\times f(h,r,t)))`
72 |             where :math:`f(h,r,t)` is the score of the fact and :math:`\\eta`
73 |             is either 1 or -1 if the fact is true or false.
74 |         """
75 |         targets = ones_like(positive_triplets)
76 |         return self.loss(positive_triplets, targets) + \
77 |                self.loss(negative_triplets, -targets)
78 |
79 |
80 | class BinaryCrossEntropyLoss(Module):
81 |     """This class implements :class:`torch.nn.Module` interface.
82 |
83 |     """
84 |
85 |     def __init__(self):
86 |         super().__init__()
87 |         self.sig = Sigmoid()
88 |         self.loss = BCELoss(reduction='sum')
89 |
90 |     def forward(self, positive_triplets, negative_triplets):
91 |         """
92 |
93 |         Parameters
94 |         ----------
95 |         positive_triplets: torch.Tensor, dtype: torch.float, shape: (b_size)
96 |             Scores of the true triplets as returned by the `forward` methods
97 |             of the models.
98 |         negative_triplets: torch.Tensor, dtype: torch.float, shape: (b_size)
99 |             Scores of the negative triplets as returned by the `forward`
100 |             methods of the models.
101 |         Returns
102 |         -------
103 |         loss: torch.Tensor, shape: (), dtype: torch.float
104 |             Loss of the form :math:`-\\eta \\cdot \\log(\\sigma(f(h,r,t))) -
105 |             (1-\\eta) \\cdot \\log(1 - \\sigma(f(h,r,t)))` where :math:`f(h,r,t)`
106 |             is the score of the fact, :math:`\\sigma` is the sigmoid function
107 |             and :math:`\\eta` is either 1 or 0 if the fact is true or false.
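
        Examples
        --------
        A minimal usage sketch; the score values below are illustrative and
        not taken from the library's test suite:

        >>> from torch import tensor
        >>> criterion = BinaryCrossEntropyLoss()
        >>> positive_scores = tensor([2.0, -0.5])
        >>> negative_scores = tensor([0.3, -1.2])
        >>> loss = criterion(positive_scores, negative_scores)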
108 |         """
109 |         return self.loss(self.sig(positive_triplets),
110 |                          ones_like(positive_triplets)) + \
111 |                self.loss(self.sig(negative_triplets),
112 |                          zeros_like(negative_triplets))
113 |
--------------------------------------------------------------------------------
/torchkge/utils/data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Copyright TorchKGE developers
4 | @author: Armand Boschin
5 | """
6 |
7 | import shutil
8 |
9 | from os import environ, makedirs
10 | from os.path import exists, expanduser, join, abspath, commonprefix
11 |
12 | def is_within_directory(directory, target):
13 |     abs_directory = abspath(directory)
14 |     abs_target = abspath(target)
15 |
16 |     prefix = commonprefix([abs_directory, abs_target])
17 |
18 |     return prefix == abs_directory
19 |
20 |
21 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
22 |     for member in tar.getmembers():
23 |         member_path = join(path, member.name)
24 |         if not is_within_directory(path, member_path):
25 |             raise Exception("Attempted Path Traversal in Tar File")
26 |
27 |     tar.extractall(path, members, numeric_owner=numeric_owner)
28 |
29 |
30 | def get_data_home(data_home=None):
31 |     """Returns the path to the data directory. The path is created if
32 |     it does not exist.
33 |
34 |     If data_home is None, the data directory defaults to 'torchkge_data' in
35 |     the user's home directory (or to the TORCHKGE_DATA environment variable if it is set).
36 |
37 |     Parameters
38 |     ----------
39 |     data_home: string
40 |         The path to the data directory.
41 |     """
42 |     if data_home is None:
43 |         data_home = environ.get('TORCHKGE_DATA',
44 |                                 join('~', 'torchkge_data'))
45 |     data_home = expanduser(data_home)
46 |     if not exists(data_home):
47 |         makedirs(data_home)
48 |     return data_home
49 |
50 |
51 | def clear_data_home(data_home=None):
52 |     """Deletes the directory data_home
53 |
54 |     Parameters
55 |     ----------
56 |     data_home: string
57 |         The path to the directory that should be removed.
58 |     """
59 |     data_home = get_data_home(data_home)
60 |     shutil.rmtree(data_home)
61 |
62 |
63 | def get_n_batches(n, b_size):
64 |     """Returns the number of batches. Let n be the number of samples in the
65 |     data set and b_size the number of samples per batch; the number of
66 |     batches is then the ceiling of their ratio:
67 |
68 |         n_batches = ceil(n / b_size)
69 |
70 |     Parameters
71 |     ----------
72 |     n: int
73 |         Size of the data set.
74 |     b_size: int
75 |         Number of samples per batch.
76 |     """
77 |     n_batch = n // b_size
78 |     if n % b_size > 0:
79 |         n_batch += 1
80 |     return n_batch
81 |
82 |
83 | class DataLoader:
84 |     """This class is inspired from :class:`torch.utils.data.DataLoader`.
85 |     It is however way simpler.
86 |
87 |     """
88 |     def __init__(self, kg, batch_size, use_cuda=None):
89 |         """
90 |
91 |         Parameters
92 |         ----------
93 |         kg: torchkge.data_structures.KnowledgeGraph or torchkge.data_structures.SmallKG
94 |             Knowledge graph in the form of an object implemented in
95 |             torchkge.data_structures.
96 |         batch_size: int
97 |             Size of the required batches.
98 |         use_cuda: str (opt, default = None)
99 |             Can be either None (no use of cuda at all), 'all' to move all the
100 |             dataset to cuda and then split in batches or 'batch' to simply move
101 |             the batches to cuda before they are returned.
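
        Examples
        --------
        A minimal usage sketch, assuming `kg` is an existing
        torchkge.data_structures.KnowledgeGraph instance:

        >>> dataloader = DataLoader(kg, batch_size=1024)
        >>> for h, t, r in dataloader:
        ...     pass  # h, t and r are 1-D tensors of at most 1024 indices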
102 | """ 103 | self.h = kg.head_idx 104 | self.t = kg.tail_idx 105 | self.r = kg.relations 106 | 107 | self.use_cuda = use_cuda 108 | self.batch_size = batch_size 109 | 110 | if use_cuda is not None and use_cuda == 'all': 111 | self.h = self.h.cuda() 112 | self.t = self.t.cuda() 113 | self.r = self.r.cuda() 114 | 115 | def __len__(self): 116 | return get_n_batches(len(self.h), self.batch_size) 117 | 118 | def __iter__(self): 119 | return _DataLoaderIter(self) 120 | 121 | 122 | class _DataLoaderIter: 123 | def __init__(self, loader): 124 | self.h = loader.h 125 | self.t = loader.t 126 | self.r = loader.r 127 | 128 | self.use_cuda = loader.use_cuda 129 | self.batch_size = loader.batch_size 130 | 131 | self.n_batches = get_n_batches(len(self.h), self.batch_size) 132 | self.current_batch = 0 133 | 134 | def __next__(self): 135 | if self.current_batch == self.n_batches: 136 | raise StopIteration 137 | else: 138 | i = self.current_batch 139 | self.current_batch += 1 140 | 141 | tmp_h = self.h[i * self.batch_size: (i + 1) * self.batch_size] 142 | tmp_t = self.t[i * self.batch_size: (i + 1) * self.batch_size] 143 | tmp_r = self.r[i * self.batch_size: (i + 1) * self.batch_size] 144 | 145 | if self.use_cuda is not None and self.use_cuda == 'batch': 146 | return tmp_h.cuda(), tmp_t.cuda(), tmp_r.cuda() 147 | else: 148 | return tmp_h, tmp_t, tmp_r 149 | 150 | def __iter__(self): 151 | return self 152 | -------------------------------------------------------------------------------- /torchkge/utils/pretrained_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | from ..exceptions import NoPreTrainedVersionError 8 | from ..models import TransEModel, ComplExModel, RESCALModel 9 | from ..utils import load_embeddings 10 | 11 | 12 | def load_pretrained_transe(dataset, emb_dim=None, data_home=None): 13 | """Load a pretrained version of TransE model. 14 | 15 | Parameters 16 | ---------- 17 | dataset: str 18 | emb_dim: int (opt, default None) 19 | Embedding dimension 20 | data_home: str (opt, default None) 21 | Path to the `torchkge_data` directory (containing data folders). Useful 22 | for pre-trained model loading. 23 | 24 | Returns 25 | ------- 26 | model: `TorchKGE.model.translation.TransEModel` 27 | Pretrained version of TransE model. 28 | """ 29 | dims = {'fb15k': 100, 'wn18rr': 100, 'fb15k237': 150, 'wdv5': 150, 'yago310': 200} 30 | try: 31 | if emb_dim is None: 32 | emb_dim = dims[dataset] 33 | else: 34 | try: 35 | assert dims[dataset] == emb_dim 36 | except AssertionError: 37 | raise NoPreTrainedVersionError('No pre-trained version of TransE for ' 38 | '{} in dimension {}'.format(dataset, emb_dim)) 39 | except KeyError: 40 | raise NoPreTrainedVersionError('No pre-trained version of TransE for {}'.format(dataset)) 41 | 42 | state_dict = load_embeddings('transe', emb_dim, dataset, data_home) 43 | model = TransEModel(emb_dim, 44 | n_entities=state_dict['ent_emb.weight'].shape[0], 45 | n_relations=state_dict['rel_emb.weight'].shape[0], 46 | dissimilarity_type='L2') 47 | model.load_state_dict(state_dict) 48 | 49 | return model 50 | 51 | 52 | def load_pretrained_complex(dataset, emb_dim=None, data_home=None): 53 | """Load a pretrained version of ComplEx model. 
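
    A minimal usage sketch (assuming the pre-trained weights can be
    downloaded; see `load_pretrained_transe` above for the general pattern):

    >>> model = load_pretrained_complex('fb15k237')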
54 |
55 |     Parameters
56 |     ----------
57 |     dataset: str
58 |     emb_dim: int (opt, default None)
59 |         Embedding dimension
60 |     data_home: str (opt, default None)
61 |         Path to the `torchkge_data` directory (containing data folders). Useful
62 |         for pre-trained model loading.
63 |
64 |     Returns
65 |     -------
66 |     model: `torchkge.models.ComplExModel`
67 |         Pretrained version of ComplEx model.
68 |     """
69 |     dims = {'wn18rr': 200, 'fb15k237': 200, 'wdv5': 200, 'yago310': 200}
70 |     try:
71 |         if emb_dim is None:
72 |             emb_dim = dims[dataset]
73 |         else:
74 |             try:
75 |                 assert dims[dataset] == emb_dim
76 |             except AssertionError:
77 |                 raise NoPreTrainedVersionError('No pre-trained version of ComplEx for '
78 |                                                '{} in dimension {}'.format(dataset, emb_dim))
79 |     except KeyError:
80 |         raise NoPreTrainedVersionError('No pre-trained version of ComplEx for {}'.format(dataset))
81 |
82 |     state_dict = load_embeddings('complex', emb_dim, dataset, data_home)
83 |     model = ComplExModel(emb_dim,
84 |                          n_entities=state_dict['re_ent_emb.weight'].shape[0],
85 |                          n_relations=state_dict['re_rel_emb.weight'].shape[0])
86 |     model.load_state_dict(state_dict)
87 |
88 |     return model
89 |
90 |
91 | def load_pretrained_rescal(dataset, emb_dim=None, data_home=None):
92 |     """Load a pretrained version of RESCAL model.
93 |
94 |     Parameters
95 |     ----------
96 |     dataset: str
97 |     emb_dim: int (opt, default None)
98 |         Embedding dimension
99 |     data_home: str (opt, default None)
100 |         Path to the `torchkge_data` directory (containing data folders). Useful
101 |         for pre-trained model loading.
102 |
103 |     Returns
104 |     -------
105 |     model: `torchkge.models.RESCALModel`
106 |         Pretrained version of RESCAL model.
107 |     """
108 |     dims = {'wn18rr': 200, 'fb15k237': 200, 'yago310': 200}
109 |     try:
110 |         if emb_dim is None:
111 |             emb_dim = dims[dataset]
112 |         else:
113 |             try:
114 |                 assert dims[dataset] == emb_dim
115 |             except AssertionError:
116 |                 raise NoPreTrainedVersionError('No pre-trained version of RESCAL for '
117 |                                                '{} in dimension {}'.format(dataset, emb_dim))
118 |     except KeyError:
119 |         raise NoPreTrainedVersionError('No pre-trained version of RESCAL for {}'.format(dataset))
120 |
121 |     state_dict = load_embeddings('rescal', emb_dim, dataset, data_home)
122 |     model = RESCALModel(emb_dim,
123 |                         n_entities=state_dict['ent_emb.weight'].shape[0],
124 |                         n_relations=state_dict['rel_mat.weight'].shape[0])
125 |     model.load_state_dict(state_dict)
126 |
127 |     return model
128 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # torchkge documentation build configuration file, created by
5 | # sphinx-quickstart on Fri Jun 9 13:47:02 2017.
6 | #
7 | # This file is execfile()d with the current directory set to its
8 | # containing dir.
9 | #
10 | # Note that not all possible configuration values are present in this
11 | # autogenerated file.
12 | #
13 | # All configuration values have a default; values that are commented out
14 | # serve to show the default.
15 |
16 | # If extensions (or modules to document with autodoc) are in another
17 | # directory, add these directories to sys.path here. If the directory is
18 | # relative to the documentation root, use os.path.abspath to make it
19 | # absolute, like shown here.
20 | # 21 | import os 22 | import sys 23 | sys.path.insert(0, os.path.abspath('..')) 24 | 25 | import torchkge 26 | 27 | # -- General configuration --------------------------------------------- 28 | 29 | # If your documentation needs a minimal Sphinx version, state it here. 30 | # 31 | # needs_sphinx = '1.0' 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 35 | extensions = ['sphinx.ext.autodoc', 36 | 'sphinx.ext.viewcode', 37 | 'sphinx.ext.autosummary', 38 | 'sphinx.ext.coverage', 39 | 'sphinx.ext.doctest', 40 | 'sphinx.ext.intersphinx', 41 | 'sphinx.ext.mathjax', 42 | 'sphinx.ext.napoleon', 43 | 'sphinx.ext.todo', 44 | 'sphinx.ext.viewcode'] 45 | 46 | # Add any paths that contain templates here, relative to this directory. 47 | templates_path = ['_templates'] 48 | 49 | # The suffix(es) of source filenames. 50 | # You can specify multiple suffix as a list of string: 51 | # 52 | # source_suffix = ['.rst', '.md'] 53 | source_suffix = '.rst' 54 | 55 | # The master toctree document. 56 | master_doc = 'index' 57 | 58 | # General information about the project. 59 | project = u'TorchKGE' 60 | copyright = u"2022, TorchKGE developers" 61 | author = u"Armand Boschin" 62 | 63 | # The version info for the project you're documenting, acts as replacement 64 | # for |version| and |release|, also used in various other places throughout 65 | # the built documents. 66 | # 67 | # The short X.Y version. 68 | version = torchkge.__version__ 69 | # The full version, including alpha/beta/rc tags. 70 | release = torchkge.__version__ 71 | 72 | # The language for content autogenerated by Sphinx. Refer to documentation 73 | # for a list of supported languages. 74 | # 75 | # This is also used if you do content translation via gettext catalogs. 76 | # Usually you set "language" from the command line for these cases. 77 | language = 'en' 78 | 79 | # List of patterns, relative to source directory, that match files and 80 | # directories to ignore when looking for source files. 81 | # This patterns also effect to html_static_path and html_extra_path 82 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 83 | 84 | # The name of the Pygments (syntax highlighting) style to use. 85 | pygments_style = 'sphinx' 86 | 87 | # If true, `todo` and `todoList` produce output, else they produce nothing. 88 | todo_include_todos = False 89 | 90 | 91 | # -- Options for HTML output ------------------------------------------- 92 | 93 | # The theme to use for HTML and HTML Help pages. See the documentation for 94 | # a list of builtin themes. 95 | # 96 | html_theme = 'sphinx_rtd_theme' 97 | html_logo = 'logo_torchKGE_small.png' 98 | html_favicon = "logo_torchKGE_small.png" 99 | html_theme_options = { 100 | 'logo_only': True 101 | } 102 | 103 | # Theme options are theme-specific and customize the look and feel of a 104 | # theme further. For a list of options available for each theme, see the 105 | # documentation. 106 | # 107 | # html_theme_options = {} 108 | 109 | # Add any paths that contain custom static files (such as style sheets) here, 110 | # relative to this directory. They are copied after the builtin static files, 111 | # so a file named "default.css" will overwrite the builtin "default.css". 112 | html_static_path = ['_static'] 113 | 114 | 115 | # -- Options for HTMLHelp output --------------------------------------- 116 | 117 | # Output file base name for HTML help builder. 
118 | htmlhelp_basename = 'torchkgedoc' 119 | 120 | 121 | # -- Options for LaTeX output ------------------------------------------ 122 | 123 | latex_elements = { 124 | # The paper size ('letterpaper' or 'a4paper'). 125 | # 126 | # 'papersize': 'letterpaper', 127 | 128 | # The font size ('10pt', '11pt' or '12pt'). 129 | # 130 | # 'pointsize': '10pt', 131 | 132 | # Additional stuff for the LaTeX preamble. 133 | # 134 | # 'preamble': '', 135 | 136 | # Latex figure (float) alignment 137 | # 138 | # 'figure_align': 'htbp', 139 | } 140 | 141 | # Grouping the document tree into LaTeX files. List of tuples 142 | # (source start file, target name, title, author, documentclass 143 | # [howto, manual, or own class]). 144 | latex_documents = [ 145 | (master_doc, 'torchkge.tex', 146 | u'TorchKGE Documentation', 147 | u'Armand Boschin', 'manual'), 148 | ] 149 | 150 | 151 | # -- Options for manual page output ------------------------------------ 152 | 153 | # One entry per manual page. List of tuples 154 | # (source start file, name, description, authors, manual section). 155 | man_pages = [ 156 | (master_doc, 'torchkge', 157 | u'TorchKGE Documentation', 158 | [author], 1) 159 | ] 160 | 161 | 162 | # -- Options for Texinfo output ---------------------------------------- 163 | 164 | # Grouping the document tree into Texinfo files. List of tuples 165 | # (source start file, target name, title, author, 166 | # dir menu entry, description, category) 167 | texinfo_documents = [ 168 | (master_doc, 'torchkge', 169 | u'TorchKGE Documentation', 170 | author, 171 | 'torchkge', 172 | 'One line description of project.', 173 | 'Miscellaneous'), 174 | ] 175 | 176 | 177 | def setup(app): 178 | app.add_css_file('css/custom.css') 179 | 180 | 181 | nbsphinx_kernel_name = 'python3' 182 | -------------------------------------------------------------------------------- /torchkge/models/deep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | from torch import nn, cat 8 | 9 | from ..models.interfaces import Model 10 | from ..utils import init_embedding 11 | 12 | 13 | class ConvKBModel(Model): 14 | """Implementation of ConvKB model detailed in 2018 paper by Nguyen et al.. 15 | This class inherits from the :class:`torchkge.models.interfaces.Model` 16 | interface. It then has its attributes as well. 17 | 18 | 19 | References 20 | ---------- 21 | * Nguyen, D. Q., Nguyen, T. D., Nguyen, D. Q., and Phung, D. 22 | `A Novel Embed- ding Model for Knowledge Base Completion Based on 23 | Convolutional Neural Network. 24 | `_ 25 | In Proceedings of the 2018 Conference of the North American Chapter of 26 | the Association for Computational Linguistics: Human Language 27 | Technologies (2018), vol. 2, pp. 327–333. 28 | 29 | Parameters 30 | ---------- 31 | emb_dim: int 32 | Dimension of embedding space. 33 | n_filters: int 34 | Number of filters used for convolution. 35 | n_entities: int 36 | Number of entities in the current data set. 37 | n_relations: int 38 | Number of relations in the current data set. 39 | 40 | Attributes 41 | ---------- 42 | ent_emb: torch.nn.Embedding, shape: (n_ent, emb_dim) 43 | Embeddings of the entities, initialized with Xavier uniform 44 | distribution and then normalized. 45 | rel_emb: torch.nn.Embedding, shape: (n_rel, emb_dim) 46 | Embeddings of the relations, initialized with Xavier uniform 47 | distribution. 
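
    Examples
    --------
    A minimal instantiation sketch; the hyper-parameter values are
    illustrative and `kg` is assumed to be an existing
    torchkge.data_structures.KnowledgeGraph:

    >>> model = ConvKBModel(emb_dim=100, n_filters=64,
    ...                     n_entities=kg.n_ent, n_relations=kg.n_rel)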
48 |
49 |     """
50 |
51 |     def __init__(self, emb_dim, n_filters, n_entities, n_relations):
52 |         super().__init__(n_entities, n_relations)
53 |         self.emb_dim = emb_dim
54 |
55 |         self.ent_emb = init_embedding(self.n_ent, self.emb_dim)
56 |         self.rel_emb = init_embedding(self.n_rel, self.emb_dim)
57 |
58 |         self.convlayer = nn.Sequential(nn.Conv1d(3, n_filters, 1, stride=1),
59 |                                        nn.ReLU())
60 |         self.output = nn.Sequential(nn.Linear(emb_dim * n_filters, 2),
61 |                                     nn.Softmax(dim=1))
62 |
63 |     def scoring_function(self, h_idx, t_idx, r_idx):
64 |         """Compute the scoring function for the triplets given as argument:
65 |         by applying convolutions to the concatenation of the embeddings. See
66 |         referenced paper for more details on the score. See
67 |         torchkge.models.interfaces.Models for more details on the API.
68 |
69 |         """
70 |         b_size = h_idx.shape[0]
71 |
72 |         h = self.ent_emb(h_idx).view(b_size, 1, -1)
73 |         t = self.ent_emb(t_idx).view(b_size, 1, -1)
74 |         r = self.rel_emb(r_idx).view(b_size, 1, -1)
75 |         concat = cat((h, r, t), dim=1)
76 |
77 |         return self.output(self.convlayer(concat).reshape(b_size, -1))[:, 1]
78 |
79 |     def normalize_parameters(self):
80 |         """ConvKB applies no normalization to its parameters, so this method
81 |         is a no-op. It is kept so that the model complies with the common
82 |         API and can safely be called at the end of each training epoch.
83 |
84 |         """
85 |         pass
86 |
87 |     def get_embeddings(self):
88 |         """Return the embeddings of entities and relations.
89 |
90 |         Returns
91 |         -------
92 |         ent_emb: torch.Tensor, shape: (n_ent, emb_dim), dtype: torch.float
93 |             Embeddings of entities.
94 |         rel_emb: torch.Tensor, shape: (n_rel, emb_dim), dtype: torch.float
95 |             Embeddings of relations.
96 |
97 |         """
98 |         self.normalize_parameters()
99 |         return self.ent_emb.weight.data, self.rel_emb.weight.data
100 |
101 |     def inference_scoring_function(self, h, t, r):
102 |         """Link prediction evaluation helper function. See
103 |         torchkge.models.interfaces.Models for more details on the API.
104 |
105 |         """
106 |         b_size = h.shape[0]
107 |
108 |         if (len(h.shape) == 2) & (len(t.shape) == 4) & (len(r.shape) == 2):
109 |             concat = cat((h.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_ent, 1, self.emb_dim),
110 |                           r.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_ent, 1, self.emb_dim),
111 |                           t), dim=2)
112 |             concat = concat.reshape(-1, 3, self.emb_dim)
113 |
114 |         elif (len(h.shape) == 4) & (len(t.shape) == 2) & (len(r.shape) == 2):
115 |             concat = cat((h,
116 |                           r.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_ent, 1, self.emb_dim),
117 |                           t.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_ent, 1, self.emb_dim)), dim=2)
118 |             concat = concat.reshape(-1, 3, self.emb_dim)
119 |
120 |         else:
121 |             assert (len(h.shape) == 2) & (len(t.shape) == 2) & (len(r.shape) == 4)
122 |             concat = cat((h.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_rel, 1, self.emb_dim),
123 |                           r,
124 |                           t.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_rel, 1, self.emb_dim)), dim=2)
125 |             concat = concat.reshape(-1, 3, self.emb_dim)
126 |
127 |         scores = self.output(self.convlayer(concat).reshape(concat.shape[0], -1))
128 |         scores = scores.reshape(b_size, -1, 2)
129 |
130 |         return scores[:, :, 1]
131 |
132 |     def inference_prepare_candidates(self, h_idx, t_idx, r_idx, entities=True):
133 |         """Link prediction evaluation helper function. Get entities embeddings
134 |         and relations embeddings. The output will be fed to the
135 |         `inference_scoring_function` method. 
See torchkge.models.interfaces.Models for 136 | more details on the API. 137 | 138 | """ 139 | b_size = h_idx.shape[0] 140 | 141 | h = self.ent_emb(h_idx) 142 | t = self.ent_emb(t_idx) 143 | r = self.rel_emb(r_idx) 144 | 145 | if entities: 146 | candidates = self.ent_emb.weight.data.view(1, self.n_ent, self.emb_dim) 147 | candidates = candidates.expand(b_size, self.n_ent, self.emb_dim) 148 | candidates = candidates.view(b_size, self.n_ent, 1, self.emb_dim) 149 | else: 150 | candidates = self.rel_emb.weight.data.view(1, self.n_rel, self.emb_dim) 151 | candidates = candidates.expand(b_size, self.n_rel, self.emb_dim) 152 | candidates = candidates.view(b_size, self.n_rel, 1, self.emb_dim) 153 | 154 | return h, t, r, candidates 155 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import unittest 3 | 4 | from collections import defaultdict 5 | from torch import tensor, cat, eq, bool 6 | from torch.nn import Embedding 7 | 8 | from torchkge.data_structures import KnowledgeGraph 9 | from torchkge.utils.dissimilarities import l1_dissimilarity, l2_dissimilarity, \ 10 | l1_torus_dissimilarity, l2_torus_dissimilarity, el2_torus_dissimilarity 11 | from torchkge.utils.modeling import init_embedding, get_true_targets 12 | from torchkge.sampling import get_possible_heads_tails 13 | from torchkge.utils.operations import get_mask, get_rank 14 | from torchkge.utils.operations import get_dictionaries, get_tph, get_hpt, \ 15 | get_bernoulli_probs 16 | 17 | 18 | class TestUtils(unittest.TestCase): 19 | """Tests for `torchkge.utils`.""" 20 | 21 | def setUp(self): 22 | self.df = pd.DataFrame([[0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], 23 | [1, 2, 1], [1, 3, 2], [2, 4, 0], [3, 4, 4], 24 | [5, 4, 0]], columns=['from', 'to', 'rel']) 25 | self.heads = tensor([0, 0, 0, 0, 1, 1, 2, 3, 5]).long() 26 | self.tails = tensor([1, 2, 3, 4, 2, 3, 4, 4, 4]).long() 27 | self.rels = tensor([0, 0, 0, 0, 1, 2, 0, 4, 0]).long() 28 | 29 | # get_dictionaries 30 | self.d1 = {'a': 0, 'b': 1, 'c': 2, 'd': 3} 31 | self.d2 = {'R1': 0, 'R2': 1} 32 | 33 | # dissimilarities 34 | self.a = tensor([[1.4, 2, 3, 4], [5.4, 6, 7, 8]]).float() 35 | self.b = tensor([[1.3, 4, 2, 10], [5.9, 8, 6, 7]]).float() 36 | 37 | # bernoulli 38 | self.t1 = tensor([[0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], 39 | [1, 2, 1], [1, 3, 2], [2, 4, 0], [3, 4, 4], 40 | [5, 4, 0]]) 41 | self.t2 = tensor([[1, 44], [2, 33], [3, 44], [4, 33]]) 42 | self.r1_mean = {tensor([1, 0]): 0.0, tensor([2, 0]): 0.0, 43 | tensor([2, 1]): 1.0, tensor([3, 0]): 0.0, 44 | tensor([3, 2]): 1.0, tensor([4, 0]): 2.3333333, 45 | tensor([4, 4]): 3.0} 46 | self.r1_count = {tensor([1, 0]): 1, tensor([2, 0]): 1, 47 | tensor([2, 1]): 1, tensor([3, 0]): 1, 48 | tensor([3, 2]): 1, tensor([4, 0]): 3, 49 | tensor([4, 4]): 1} 50 | self.r2_mean = {33: 3., 44: 2.} 51 | self.r2_count = {33: 2, 44: 2} 52 | 53 | # get_true_targets 54 | self.e_idx = tensor([0, 0, 0]).long() 55 | self.r_idx = tensor([0, 0, 1]).long() 56 | self.true_idx = tensor([1, 2, 1]).long() 57 | self.dictionary = {(0, 0): [0, 1, 2], (0, 1): [1]} 58 | 59 | @staticmethod 60 | def compare_dicts_tensorkeys(d1, d2): 61 | for k in d1.keys(): 62 | found = False 63 | for kk in d2.keys(): 64 | if eq(k, kk).all(): 65 | found = True 66 | assert (d1[k] - d2[kk]) < 1e-03 67 | continue 68 | if not found: 69 | raise AssertionError 70 | 71 | def test_get_dictionaries(self): 72 | df = 
pd.DataFrame([['a', 'R1', 'b'], ['c', 'R2', 'd']], 73 | columns=['from', 'rel', 'to']) 74 | assert get_dictionaries(df, ent=True) == self.d1 75 | assert get_dictionaries(df, ent=False) == self.d2 76 | 77 | def test_get_tph(self): 78 | kg = KnowledgeGraph(df=self.df) 79 | t = cat((kg.head_idx.view(-1, 1), kg.tail_idx.view(-1, 1), 80 | kg.relations.view(-1, 1)), dim=1) 81 | assert get_tph(t) == {0: 2., 1: 1., 2: 1., 3: 1.} 82 | 83 | def test_get_hpt(self): 84 | kg = KnowledgeGraph(df=self.df) 85 | t = cat((kg.head_idx.view(-1, 1), kg.tail_idx.view(-1, 1), 86 | kg.relations.view(-1, 1)), dim=1) 87 | assert get_hpt(t) == {0: 1.5, 1: 1., 2: 1., 3: 1.} 88 | 89 | def test_get_bernoulli_probs(self): 90 | kg = KnowledgeGraph(df=self.df) 91 | probs = get_bernoulli_probs(kg) 92 | res = {0: 0.5714, 1: 0.5, 2: 0.5, 3: 0.5} 93 | 94 | for k in probs.keys(): 95 | assert (res[k] - probs[k]) < 1e-03 96 | 97 | def test_dissimilarities(self): 98 | assert ((l1_dissimilarity(self.a, self.b) == 99 | tensor([9.1000, 4.5000])).all() == 1) 100 | assert ((l2_dissimilarity(self.a, self.b) - 101 | tensor([41.0100, 6.2500])).sum() < 1e-03) 102 | assert ((l1_torus_dissimilarity(self.a, self.b) - 103 | 2 * tensor([0.1000, 0.5000])).sum() < 1e-03) 104 | assert ((l2_torus_dissimilarity(self.a, self.b) - 105 | 4 * tensor([0.1000, 0.5000])**2).sum() < 1e-03) 106 | assert ((el2_torus_dissimilarity(self.a, self.b) - 107 | tensor([0.6180, 2.0000])).sum() < 1e-03) 108 | 109 | def test_init_embedding(self): 110 | n = 10 111 | dim = 100 112 | 113 | p = init_embedding(n, dim) 114 | 115 | assert type(p) == Embedding 116 | assert p.weight.requires_grad 117 | assert p.weight.shape == (10, 100) 118 | 119 | def test_get_true_targets(self): 120 | assert eq(get_true_targets(self.dictionary, self.e_idx, 121 | self.r_idx, self.true_idx, 0), 122 | tensor([0, 2]).long()).all().item() 123 | assert eq(get_true_targets(self.dictionary, self.e_idx, 124 | self.r_idx, self.true_idx, 1), 125 | tensor([0, 1]).long()).all().item() 126 | assert get_true_targets(self.dictionary, self.e_idx, 127 | self.r_idx, self.true_idx, 2) is None 128 | 129 | def test_get_possible_heads_tails(self): 130 | kg = KnowledgeGraph(self.df) 131 | h, t = get_possible_heads_tails(kg) 132 | 133 | assert (type(h) == dict) & (type(t) == dict) 134 | 135 | assert h == {0: {0, 2, 5}, 1: {1}, 2: {1}, 3: {3}} 136 | assert t == {0: {1, 2, 3, 4}, 1: {2}, 2: {3}, 3: {4}} 137 | 138 | p_h, p_t = defaultdict(set), defaultdict(set) 139 | p_h[0].add(40) 140 | p_h[10].add(50) 141 | p_t[0].add(41) 142 | p_t[10].add(51) 143 | 144 | h, t = get_possible_heads_tails(kg, possible_heads=dict(p_h), 145 | possible_tails=dict(p_t)) 146 | 147 | assert h == {0: {0, 2, 5, 40}, 1: {1}, 2: {1}, 3: {3}, 10: {50}} 148 | assert t == {0: {1, 2, 3, 4, 41}, 1: {2}, 2: {3}, 3: {4}, 10: {51}} 149 | 150 | def test_get_mask(self): 151 | m = get_mask(10, 1, 2) 152 | assert m.dtype == bool 153 | assert len(m.shape) == 1 154 | assert m.shape[0] == 10 155 | 156 | def test_get_rank(self): 157 | data = tensor([[1, 2, 3, 4, 0], [1, 2, 1, 3, 0]]).float() 158 | true = tensor([4, 2]) 159 | r1 = get_rank(data, true) 160 | r2 = get_rank(data, true, low_values=True) 161 | 162 | assert eq(r1, tensor([5, 4])).all() 163 | assert eq(r2, tensor([1, 3])).all() 164 | -------------------------------------------------------------------------------- /torchkge/utils/operations.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | 
@author: Armand Boschin
5 | """
6 |
7 | from collections import defaultdict
8 | from pandas import DataFrame
9 | from torch import zeros, cat
10 | from numpy import unique
11 |
12 |
13 | def get_mask(length, start, end):
14 |     """Create a mask of length `length` filled with 0s except between indices
15 |     `start` (included) and `end` (excluded).
16 |
17 |     Parameters
18 |     ----------
19 |     length: int
20 |         Length of the mask to be created.
21 |     start: int
22 |         First index (included) where the mask will be filled with 1s.
23 |     end: int
24 |         Last index (excluded) where the mask will be filled with 1s.
25 |
26 |     Returns
27 |     -------
28 |     mask: `torch.Tensor`, shape: (length), dtype: `torch.bool`
29 |         Mask of length `length` filled with 0s except between indices `start`
30 |         (included) and `end` (excluded).
31 |     """
32 |     mask = zeros(length)
33 |     mask[start:end] = 1
34 |     return mask.bool()
35 |
36 |
37 | def get_rank(data, true, low_values=False):
38 |     """Computes the rank of entity at index true[i]. If the rank is k then
39 |     there are k-1 entities with better (higher or lower) value in data.
40 |
41 |     Parameters
42 |     ----------
43 |     data: `torch.Tensor`, dtype: `torch.float`, shape: (n_facts, dimensions)
44 |         Scores for each entity.
45 |     true: `torch.Tensor`, dtype: `torch.int`, shape: (n_facts)
46 |         true[i] is the index of the true entity for test i of the batch.
47 |     low_values: bool, optional (default=False)
48 |         if True, best rank is the lowest score else it is the highest.
49 |
50 |     Returns
51 |     -------
52 |     ranks: `torch.Tensor`, dtype: `torch.int`, shape: (n_facts)
53 |         ranks[i] - 1 is the number of entities which have better (or same)
54 |         scores in data than the one at index true[i]
55 |     """
56 |     true_data = data.gather(1, true.long().view(-1, 1))
57 |
58 |     if low_values:
59 |         return (data <= true_data).sum(dim=1)
60 |     else:
61 |         return (data >= true_data).sum(dim=1)
62 |
63 |
64 | def get_dictionaries(df, ent=True):
65 |     """Build entities or relations dictionaries.
66 |
67 |     Parameters
68 |     ----------
69 |     df: `pandas.DataFrame`
70 |         Data frame containing three columns [from, to, rel].
71 |     ent: bool
72 |         if True then ent2ix is returned, if False then rel2ix is returned.
73 |
74 |     Returns
75 |     -------
76 |     dict: dictionary
77 |         Either ent2ix or rel2ix.
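
    Examples
    --------
    A small sketch on a hypothetical frame:

    >>> import pandas as pd
    >>> df = pd.DataFrame([['a', 'r1', 'b'], ['b', 'r2', 'c']],
    ...                   columns=['from', 'rel', 'to'])
    >>> get_dictionaries(df, ent=True)
    {'a': 0, 'b': 1, 'c': 2}
    >>> get_dictionaries(df, ent=False)
    {'r1': 0, 'r2': 1}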
78 |
79 |     """
80 |     if ent:
81 |         tmp = list(set(df['from'].unique()).union(set(df['to'].unique())))
82 |         return {ent: i for i, ent in enumerate(sorted(tmp))}
83 |     else:
84 |         tmp = list(df['rel'].unique())
85 |         return {rel: i for i, rel in enumerate(sorted(tmp))}
86 |
87 |
88 | def extend_dicts(kg, attributes):
89 |     ent2ix = {k: v for k, v in kg.ent2ix.items()}
90 |     rel2ix = {k: v for k, v in kg.rel2ix.items()}
91 |
92 |     assert len(ent2ix) == len(unique(list(ent2ix.values())))
93 |     assert len(rel2ix) == len(unique(list(rel2ix.values())))
94 |
95 |     tmp = list(set(attributes['from'].unique()).union(set(attributes['to'].unique())))
96 |     for ent in sorted(tmp):
97 |         if ent in ent2ix.keys():
98 |             continue
99 |         else:
100 |             ent2ix[ent] = len(ent2ix)
101 |
102 |     tmp = list(attributes['rel'].unique())
103 |     for rel in sorted(tmp):
104 |         if rel in rel2ix.keys():
105 |             continue
106 |         else:
107 |             rel2ix[rel] = len(rel2ix)
108 |     assert len(ent2ix) == len(unique(list(ent2ix.values())))
109 |     assert len(rel2ix) == len(unique(list(rel2ix.values())))
110 |     assert len(ent2ix) == (max(list(ent2ix.values())) + 1)
111 |     assert len(rel2ix) == (max(list(rel2ix.values())) + 1)
112 |
113 |     return ent2ix, rel2ix
114 |
115 |
116 | def get_tph(t):
117 |     """Get the average number of tails per head for each relation.
118 |
119 |     Parameters
120 |     ----------
121 |     t: `torch.Tensor`, dtype: `torch.long`, shape: (b_size, 3)
122 |         First column contains head indices, second tails and third relations.
123 |     Returns
124 |     -------
125 |     d: dict
126 |         keys: relation indices, values: average number of tails per head.
127 |     """
128 |     df = DataFrame(t.numpy(), columns=['from', 'to', 'rel'])
129 |     df = df.groupby(['from', 'rel']).count().groupby('rel').mean()
130 |     df.reset_index(inplace=True)
131 |     return {df.loc[i].values[0]: df.loc[i].values[1] for i in df.index}
132 |
133 |
134 | def get_hpt(t):
135 |     """Get the average number of heads per tail for each relation.
136 |
137 |     Parameters
138 |     ----------
139 |     t: `torch.Tensor`, dtype: `torch.long`, shape: (b_size, 3)
140 |         First column contains head indices, second tails and third relations.
141 |     Returns
142 |     -------
143 |     d: dict
144 |         keys: relation indices, values: average number of heads per tail.
145 |     """
146 |     df = DataFrame(t.numpy(), columns=['from', 'to', 'rel'])
147 |     df = df.groupby(['rel', 'to']).count().groupby('rel').mean()
148 |     df.reset_index(inplace=True)
149 |     return {df.loc[i].values[0]: df.loc[i].values[1] for i in df.index}
150 |
151 |
152 | def get_bernoulli_probs(kg):
153 |     """Evaluate the Bernoulli probabilities for negative sampling as in the
154 |     TransH original paper by Wang et al. (2014).
155 |
156 |     Parameters
157 |     ----------
158 |     kg: `torchkge.data_structures.KnowledgeGraph`
159 |
160 |     Returns
161 |     -------
162 |     tph: dict
163 |         keys: relations, values: sampling probabilities as described by
164 |         Wang et al. in their paper.
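
    For each relation r, the code below computes the probability as
    tph(r) / (tph(r) + hpt(r)) from the outputs of `get_tph` and `get_hpt`,
    i.e. the probability of corrupting the head rather than the tail of a
    triplet involving that relation.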
165 | 166 | """ 167 | t = cat((kg.head_idx.view(-1, 1), 168 | kg.tail_idx.view(-1, 1), 169 | kg.relations.view(-1, 1)), dim=1) 170 | 171 | hpt = get_hpt(t) 172 | tph = get_tph(t) 173 | 174 | assert hpt.keys() == tph.keys() 175 | 176 | for k in tph.keys(): 177 | tph[k] = tph[k] / (tph[k] + hpt[k]) 178 | 179 | return tph 180 | 181 | 182 | def get_fitlering_dictionaries(kg, kg_te=None): 183 | dict_of_heads = defaultdict(set) 184 | dict_of_tails = defaultdict(set) 185 | dict_of_rels = defaultdict(set) 186 | for i in range(kg.n_facts): 187 | dict_of_heads[(kg.tail_idx[i].item(), kg.relations[i].item())].add(kg.head_idx[i].item()) 188 | dict_of_tails[(kg.head_idx[i].item(), kg.relations[i].item())].add(kg.tail_idx[i].item()) 189 | dict_of_rels[(kg.head_idx[i].item(), kg.tail_idx[i].item())].add(kg.relations[i].item()) 190 | if kg_te is not None: 191 | for i in range(kg_te.n_facts): 192 | dict_of_rels[(kg_te.tail_idx[i].item(), kg_te.relations[i].item())].add(kg_te.head_idx[i].item()) 193 | dict_of_rels[(kg_te.head_idx[i].item(), kg_te.relations[i].item())].add(kg_te.tail_idx[i].item()) 194 | dict_of_rels[(kg_te.head_idx[i].item(), kg_te.tail_idx[i].item())].add(kg_te.relations[i].item()) 195 | return dict_of_heads, dict_of_tails, dict_of_rels -------------------------------------------------------------------------------- /torchkge/utils/training.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | from typing import Optional 7 | 8 | from ..data_structures import SmallKG 9 | from ..sampling import BernoulliNegativeSampler, UniformNegativeSampler 10 | from ..utils.data import get_n_batches 11 | 12 | from tqdm.autonotebook import tqdm 13 | 14 | 15 | class TrainDataLoader: 16 | """Dataloader providing the training process with batches of true and 17 | negatively sampled facts. 18 | 19 | Parameters 20 | ---------- 21 | kg: torchkge.data_structures.KnowledgeGraph 22 | Dataset to be divided in batches. 23 | batch_size: int 24 | Size of the batches. 25 | sampling_type: str 26 | Either 'unif' (uniform negative sampling) or 'bern' (Bernoulli negative 27 | sampling). 28 | use_cuda: str (opt, default = None) 29 | Can be either None (no use of cuda at all), 'all' to move all the 30 | dataset to cuda and then split in batches or 'batch' to simply move 31 | the batches to cuda before they are returned. 
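
    Examples
    --------
    A minimal usage sketch, assuming `kg_train` is an existing
    torchkge.data_structures.KnowledgeGraph:

    >>> loader = TrainDataLoader(kg_train, batch_size=1024,
    ...                          sampling_type='bern')
    >>> for batch in loader:
    ...     h, t, r = batch['h'], batch['t'], batch['r']
    ...     nh, nt = batch['nh'], batch['nt']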
32 | 33 | """ 34 | 35 | def __init__(self, kg, batch_size, sampling_type, use_cuda=None): 36 | self.h = kg.head_idx 37 | self.t = kg.tail_idx 38 | self.r = kg.relations 39 | 40 | self.use_cuda = use_cuda 41 | self.b_size = batch_size 42 | self.iterator = None 43 | 44 | if sampling_type == 'unif': 45 | self.sampler = UniformNegativeSampler(kg) 46 | elif sampling_type == 'bern': 47 | self.sampler = BernoulliNegativeSampler(kg) 48 | 49 | self.tmp_cuda = use_cuda in ['batch', 'all'] 50 | 51 | if use_cuda is not None and use_cuda == 'all': 52 | self.h = self.h.cuda() 53 | self.t = self.t.cuda() 54 | self.r = self.r.cuda() 55 | 56 | def __len__(self): 57 | return get_n_batches(len(self.h), self.b_size) 58 | 59 | def __iter__(self): 60 | self.iterator = TrainDataLoaderIter(self) 61 | return self.iterator 62 | 63 | def get_counter_examples(self) -> SmallKG: 64 | return SmallKG(self.iterator.nh, self.iterator.nt, self.iterator.r) 65 | 66 | 67 | class TrainDataLoaderIter: 68 | def __init__(self, loader): 69 | self.h = loader.h 70 | self.t = loader.t 71 | self.r = loader.r 72 | 73 | self.nh, self.nt = loader.sampler.corrupt_kg(loader.b_size, 74 | loader.tmp_cuda) 75 | if loader.use_cuda: 76 | self.nh = self.nh.cuda() 77 | self.nt = self.nt.cuda() 78 | 79 | self.use_cuda = loader.use_cuda 80 | self.b_size = loader.b_size 81 | 82 | self.n_batches = get_n_batches(len(self.h), self.b_size) 83 | self.current_batch = 0 84 | 85 | def __next__(self): 86 | if self.current_batch == self.n_batches: 87 | raise StopIteration 88 | else: 89 | i = self.current_batch 90 | self.current_batch += 1 91 | 92 | batch = dict() 93 | batch['h'] = self.h[i * self.b_size: (i + 1) * self.b_size] 94 | batch['t'] = self.t[i * self.b_size: (i + 1) * self.b_size] 95 | batch['r'] = self.r[i * self.b_size: (i + 1) * self.b_size] 96 | batch['nh'] = self.nh[i * self.b_size: (i + 1) * self.b_size] 97 | batch['nt'] = self.nt[i * self.b_size: (i + 1) * self.b_size] 98 | 99 | if self.use_cuda == 'batch': 100 | batch['h'] = batch['h'].cuda() 101 | batch['t'] = batch['t'].cuda() 102 | batch['r'] = batch['r'].cuda() 103 | batch['nh'] = batch['nh'].cuda() 104 | batch['nt'] = batch['nt'].cuda() 105 | 106 | return batch 107 | 108 | def __iter__(self): 109 | return self 110 | 111 | 112 | class Trainer: 113 | """This class simply wraps a simple training procedure. 114 | 115 | Parameters 116 | ---------- 117 | model: torchkge.models.interfaces.Model 118 | Model to be trained. 119 | criterion: 120 | Criteria which should differentiate positive and negative scores. Can 121 | be an elements of torchkge.utils.losses 122 | kg_train: torchkge.data_structures.KnowledgeGraph 123 | KG used for training. 124 | n_epochs: int 125 | Number of epochs in the training procedure. 126 | batch_size: int 127 | Number of batches to use. 128 | sampling_type: str 129 | Either 'unif' (uniform negative sampling) or 'bern' (Bernoulli negative 130 | sampling). 131 | use_cuda: str (opt, default = None) 132 | Can be either None (no use of cuda at all), 'all' to move all the 133 | dataset to cuda and then split in batches or 'batch' to simply move 134 | the batches to cuda before they are returned. 
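
    Examples
    --------
    A minimal usage sketch; `model` is assumed to be one of the torchkge
    models and `kg_train` a torchkge.data_structures.KnowledgeGraph, and
    the hyper-parameter values are illustrative:

    >>> from torch.optim import Adam
    >>> from torchkge.utils.losses import MarginLoss
    >>> criterion = MarginLoss(0.5)
    >>> optimizer = Adam(model.parameters(), lr=1e-4)
    >>> trainer = Trainer(model, criterion, kg_train, n_epochs=100,
    ...                   batch_size=1024, optimizer=optimizer)
    >>> trainer.run()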
135 |
136 |
137 |     Attributes
138 |     ----------
139 |
140 |     """
141 |     def __init__(self, model, criterion, kg_train, n_epochs, batch_size,
142 |                  optimizer, sampling_type='bern', use_cuda=None):
143 |
144 |         self.model = model
145 |         self.criterion = criterion
146 |         self.kg_train = kg_train
147 |         self.use_cuda = use_cuda
148 |         self.n_epochs = n_epochs
149 |         self.optimizer = optimizer
150 |         self.sampling_type = sampling_type
151 |
152 |         self.batch_size = batch_size
153 |         self.n_triples = len(kg_train)
154 |         self.counter_examples: Optional[SmallKG] = None
155 |
156 |     def process_batch(self, current_batch):
157 |         self.optimizer.zero_grad()
158 |
159 |         h, t, r = current_batch['h'], current_batch['t'], current_batch['r']
160 |         nh, nt = current_batch['nh'], current_batch['nt']
161 |
162 |         p, n = self.model(h, t, r, nh, nt)
163 |         loss = self.criterion(p, n)
164 |         loss.backward()
165 |         self.optimizer.step()
166 |
167 |         return loss.detach().item()
168 |
169 |     def run(self):
170 |         if self.use_cuda in ['all', 'batch']:
171 |             self.model.cuda()
172 |             self.criterion.cuda()
173 |
174 |         iterator = tqdm(range(self.n_epochs), unit='epoch')
175 |         data_loader = TrainDataLoader(self.kg_train,
176 |                                       batch_size=self.batch_size,
177 |                                       sampling_type=self.sampling_type,
178 |                                       use_cuda=self.use_cuda)
179 |         for epoch in iterator:
180 |             sum_ = 0
181 |             for i, batch in enumerate(data_loader):
182 |                 loss = self.process_batch(batch)
183 |                 sum_ += loss
184 |
185 |             iterator.set_description(
186 |                 'Epoch {} | mean loss: {:.5f}'.format(epoch + 1, sum_ / len(data_loader)))
187 |         self.counter_examples = data_loader.get_counter_examples()
188 |         self.model.normalize_parameters()
189 |
190 |     def get_counter_examples(self) -> Optional[SmallKG]:
191 |         """
192 |         Retrieve the counter-examples generated while training the model.
193 |
194 |         If the model has not been trained yet, return None.
195 |
196 |         Returns
197 |         -------
198 |         A simple knowledge graph containing the triplets that were used as counter-examples during the last training epoch.
199 |         """
200 |         return self.counter_examples
201 |
--------------------------------------------------------------------------------
/docs/history.rst:
--------------------------------------------------------------------------------
1 | =======
2 | History
3 | =======
4 | 0.17.7 (2023-04-05)
5 | -------------------
6 | * Adding additional pretrained models
7 |
8 | 0.17.6 (2023-03-31)
9 | -------------------
10 | * Fix embedding dimension mixup in translation models
11 | * Fix implementation error in some bilinear models
12 | * Fix various docstring typos
13 |
14 | 0.17.5 (2022-09-18)
15 | -------------------
16 | * Fix bug in TransH implementation
17 |
18 | 0.17.4 (2022-07-04)
19 | -------------------
20 | * Upgrade dependencies
21 | * Fix normalization in translation models
22 | * Improve loading function of wikidatavitals datasets.
23 | 24 | 0.17.3 (2022-04-21) 25 | ------------------- 26 | * Fix ConvKB scoring function and normalization step 27 | 28 | 0.17.2 (2022-03-02) 29 | ------------------- 30 | * Fix the documentation in evaluation and inference modules 31 | * Fix a typo in the sampling module's documentation 32 | 33 | 0.17.1 (2022-02-25) 34 | ------------------- 35 | * Add support of Python 3.7 back 36 | 37 | 0.17.0 (2022-02-25) 38 | ------------------- 39 | * Add relation prediction evaluation 40 | * Add relation negative sampling module 41 | * Add inference module 42 | * Update models' API accordingly to the previous new features 43 | * Switch from TravisCI to GitHub Actions 44 | 45 | 0.16.25 (2021-03-01) 46 | -------------------- 47 | * Update in available pretrained models 48 | 49 | 0.16.24 (2021-02-16) 50 | -------------------- 51 | * Fix deployment 52 | 53 | 0.16.23 (2021-02-16) 54 | -------------------- 55 | * Removed useless k_max parameter in link-prediction evaluation method 56 | 57 | 0.16.22 (2021-02-05) 58 | -------------------- 59 | * Add pretrained version of TransE for yago310 and ComplEx for fb15k237 and wdv5. 60 | 61 | 0.16.21 (2021-02-02) 62 | -------------------- 63 | * Add pretrained version of TransE for Wikidata-Vitals level 5 64 | 65 | 0.16.20 (2021-01-22) 66 | -------------------- 67 | * Add support for Python 3.8 68 | * Clean up loading process for kgs 69 | * Fix deprecation warning 70 | 71 | 0.16.19 (2021-01-20) 72 | -------------------- 73 | * Fix release 74 | 75 | 0.16.18 (2021-01-20) 76 | -------------------- 77 | * Add data loader for wikidata vitals knowledge graphs 78 | 79 | 0.16.17 (2020-11-03) 80 | -------------------- 81 | * Bug fix get_ranks method 82 | 83 | 0.16.16 (2020-10-07) 84 | -------------------- 85 | * Bug fix in KG split method 86 | 87 | 0.16.15 (2020-10-07) 88 | -------------------- 89 | * Fix WikiDataSets loader (again) 90 | 91 | 0.16.14 (2020-09-21) 92 | -------------------- 93 | * Fix WikiDataSets loader 94 | 95 | 0.16.13 (2020-08-06) 96 | -------------------- 97 | * Fix reduction in BCE loss 98 | * Add pretrained models 99 | 100 | 0.16.12 (2020-07-07) 101 | -------------------- 102 | * Release patch 103 | 104 | 0.16.11 (2020-07-07) 105 | -------------------- 106 | * Fix bug in pre-trained models loading that made all models being redownloaded every time 107 | 108 | 0.16.10 (2020-07-02) 109 | -------------------- 110 | * Minor bug patch 111 | 112 | 0.16.9 (2020-07-02) 113 | ------------------- 114 | * Update urls to retrieve datasets and pre-trained models. 115 | 116 | 0.16.8 (2020-07-01) 117 | ------------------- 118 | * Add binary cross-entropy loss 119 | 120 | 0.16.7 (2020-06-23) 121 | ------------------- 122 | * Change API for pre-trained models 123 | 124 | 0.16.6 (2020-06-09) 125 | ------------------- 126 | * Patch in pre-trained model loading 127 | * Added pre-trained loading for TransE on FB15k237 in dimension 100. 128 | 129 | 0.16.5 (2020-06-02) 130 | ------------------- 131 | * Release patch 132 | 133 | 0.16.4 (2020-06-02) 134 | ------------------- 135 | * Add parameter in data redundancy to exclude know reverse triplets from 136 | duplicate search. 137 | 138 | 0.16.3 (2020-05-29) 139 | ------------------- 140 | * Release patch 141 | 142 | 0.16.2 (2020-05-29) 143 | ------------------- 144 | * Add methods to compute data redundancy in knowledge graphs as in 2020 145 | `paper `__ by Akrami et al 146 | (see references in concerned methods). 
147 | 148 | 0.16.1 (2020-05-28) 149 | ------------------- 150 | * Patch an awkward import 151 | * Add dataset loaders for WN18RR and YAGO3-10 152 | 153 | 0.16.0 (2020-04-27) 154 | ------------------- 155 | * Redefinition of the models' API (simplified interfaces, renamed LP 156 | methods and added get_embeddings method) 157 | * Implementation of the new API for all models 158 | * TorusE implementation fixed 159 | * TransD reimplementation to avoid matmul usage (costly in 160 | back-propagation) 161 | * Added feature to negative samplers to generate several negative 162 | samples from each fact. Those can be fed directly to the models. 163 | * Added some wrappers for training to utils module. 164 | * Progress bars now make the most of tqdm's possibilities 165 | * Code reformatting 166 | * Docstrings update 167 | 168 | 0.15.5 (2020-04-23) 169 | ------------------- 170 | * Defined a new homemade and simpler DataLoader class. 171 | 172 | 0.15.4 (2020-04-22) 173 | ------------------- 174 | * Removed the use of torch DataLoader object. 175 | 176 | 0.15.3 (2020-04-02) 177 | ------------------- 178 | * Added a method to print results in link prediction evaluator 179 | 180 | 0.15.2 (2020-04-01) 181 | ------------------- 182 | * Fixed a misfit test 183 | 184 | 0.15.1 (2020-04-01) 185 | ------------------- 186 | * Cleared the definition of rank in link prediction 187 | 188 | 0.15.0 (2020-04-01) 189 | ------------------- 190 | * Improved use of tqdm progress bars 191 | 192 | 0.14.0 (2020-04-01) 193 | ------------------- 194 | * Change in the API of loss functions (margin and logistic loss) 195 | * Documentation update 196 | 197 | 0.13.0 (2020-02-10) 198 | ------------------- 199 | * Added ConvKB model 200 | 201 | 0.12.1 (2020-01-10) 202 | ------------------- 203 | * Minor patch in interfaces 204 | * Comment additions 205 | 206 | 0.12.0 (2019-12-05) 207 | ------------------- 208 | * Various bug fixes 209 | * New KG splitting method enforcing all entities and relations to appear at least once in the training set. 210 | 211 | 0.11.3 (2019-11-15) 212 | ------------------- 213 | * Minor bug fixes 214 | 215 | 0.11.2 (2019-11-11) 216 | ------------------- 217 | * Minor bug fixes 218 | 219 | 0.11.1 (2019-10-21) 220 | ------------------- 221 | * Fixed requirements conflicts 222 | 223 | 0.11.0 (2019-10-21) 224 | ------------------- 225 | * Added TorusE model 226 | * Added dataloaders 227 | * Fixed some bugs 228 | 229 | 0.10.4 (2019-10-07) 230 | ------------------- 231 | * Fixed error in bilinear models. 232 | 233 | 0.10.3 (2019-07-23) 234 | ------------------- 235 | * Added intermediate function for hit@k metric in link prediction. 236 | 237 | 0.10.2 (2019-07-22) 238 | ------------------- 239 | * Fixed assertion error in Analogy model 240 | 241 | 0.10.0 (2019-07-19) 242 | ------------------- 243 | * Implemented Triplet Classification evaluation method 244 | * Added Negative Sampler objects to standardize negative sampling methods. 245 | 246 | 247 | 0.9.0 (2019-07-17) 248 | ------------------ 249 | * Implemented HolE model (Nickel et al.) 250 | * Implemented ComplEx model (Trouillon et al.) 251 | * Implemented ANALOGY model (Liu et al.) 252 | * Added knowledge graph splitting into train, validation and test instead of just train and test. 253 | 254 | 0.8.0 (2019-07-09) 255 | ------------------ 256 | * Implemented Bernoulli negative sampling as in Wang et al. paper on TransH (2014). 257 | 258 | 0.7.0 (2019-07-01) 259 | ------------------ 260 | * Implemented Mean Reciprocal Rank measure of performance. 
261 | * Implemented Logistic Loss. 262 | * Changed implementation of margin loss to use torch methods. 263 | 264 | 0.6.0 (2019-06-25) 265 | ------------------ 266 | * Implemented DistMult 267 | 268 | 0.5.0 (2019-06-24) 269 | ------------------ 270 | * Changed implementation of LinkPrediction ranks by moving functions to model methods. 271 | * Implemented RESCAL. 272 | 273 | 0.4.0 (2019-05-15) 274 | ------------------ 275 | * Fixed a major bug/problem in the Evaluation protocol of LinkPrediction. 276 | 277 | 0.3.1 (2019-05-10) 278 | ------------------ 279 | * Minor bug fixes in the various normalization functions. 280 | 281 | 0.3.0 (2019-05-09) 282 | ------------------ 283 | * Fixed CUDA support. 284 | 285 | 0.2.0 (2019-05-07) 286 | ------------------ 287 | * Added support for filtered performance measures. 288 | 289 | 0.1.7 (2019-04-03) 290 | ------------------ 291 | * First real release on PyPI. 292 | 293 | 0.1.0 (2019-04-01) 294 | ------------------ 295 | * First release on PyPI. 296 | -------------------------------------------------------------------------------- /torchkge/utils/data_redundancy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | 6 | This module contains functions implementing methods explained in `this 7 | paper`__ by Akrami et al. 8 | """ 9 | from itertools import combinations 10 | from torch import cat 11 | from tqdm.autonotebook import tqdm 12 | 13 | 14 | def concat_kgs(kg_tr, kg_val, kg_te): 15 | h = cat((kg_tr.head_idx, kg_val.head_idx, kg_te.head_idx)) 16 | t = cat((kg_tr.tail_idx, kg_val.tail_idx, kg_te.tail_idx)) 17 | r = cat((kg_tr.relations, kg_val.relations, kg_te.relations)) 18 | return h, t, r 19 | 20 | 21 | def get_pairs(kg, r, type='ht'): 22 | mask = (kg.relations == r) 23 | 24 | if type == 'ht': 25 | return set((i.item(), j.item()) for i, j in cat( 26 | (kg.head_idx[mask].view(-1, 1), 27 | kg.tail_idx[mask].view(-1, 1)), dim=1)) 28 | else: 29 | assert type == 'th' 30 | return set((j.item(), i.item()) for i, j in cat( 31 | (kg.head_idx[mask].view(-1, 1), 32 | kg.tail_idx[mask].view(-1, 1)), dim=1)) 33 | 34 | 35 | def count_triplets(kg1, kg2, duplicates, rev_duplicates): 36 | """ 37 | Parameters 38 | ---------- 39 | kg1: torchkge.data_structures.KnowledgeGraph 40 | kg2: torchkge.data_structures.KnowledgeGraph 41 | duplicates: list 42 | List returned by torchkge.utils.data_redundancy.duplicates. 43 | rev_duplicates: list 44 | List returned by torchkge.utils.data_redundancy.duplicates. 45 | 46 | Returns 47 | ------- 48 | n_duplicates: int 49 | Number of triplets in kg2 that have their duplicate triplet 50 | in kg1 51 | n_rev_duplicates: int 52 | Number of triplets in kg2 that have their reverse duplicate 53 | triplet in kg1. 
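
    Examples
    --------
    A usage sketch, with `dup` and `rev_dup` as returned by the
    `duplicates` function below:

    >>> n_dup, n_rev_dup = count_triplets(kg_tr, kg_te, dup, rev_dup)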
54 | """ 55 | n_duplicates = 0 56 | for r1, r2 in duplicates: 57 | ht_tr = get_pairs(kg1, r2, type='ht') 58 | ht_te = get_pairs(kg2, r1, type='ht') 59 | 60 | n_duplicates += len(ht_te.intersection(ht_tr)) 61 | 62 | ht_tr = get_pairs(kg1, r1, type='ht') 63 | ht_te = get_pairs(kg2, r2, type='ht') 64 | 65 | n_duplicates += len(ht_te.intersection(ht_tr)) 66 | 67 | n_rev_duplicates = 0 68 | for r1, r2 in rev_duplicates: 69 | th_tr = get_pairs(kg1, r2, type='th') 70 | ht_te = get_pairs(kg2, r1, type='ht') 71 | 72 | n_rev_duplicates += len(ht_te.intersection(th_tr)) 73 | 74 | th_tr = get_pairs(kg1, r1, type='th') 75 | ht_te = get_pairs(kg2, r2, type='ht') 76 | 77 | n_rev_duplicates += len(ht_te.intersection(th_tr)) 78 | 79 | return n_duplicates, n_rev_duplicates 80 | 81 | 82 | def duplicates(kg_tr, kg_val, kg_te, theta1=0.8, theta2=0.8, 83 | verbose=False, counts=False, reverses=None): 84 | """Return the duplicate and reverse duplicate relations as explained 85 | in paper by Akrami et al. 86 | 87 | References 88 | ---------- 89 | * Farahnaz Akrami, Mohammed Samiul Saeef, Quingheng Zhang. 90 | `Realistic Re-evaluation of Knowledge Graph Completion Methods: 91 | An Experimental Study. `_ 92 | SIGMOD’20, June 14–19, 2020, Portland, OR, USA 93 | 94 | Parameters 95 | ---------- 96 | kg_tr: torchkge.data_structures.KnowledgeGraph 97 | Train set 98 | kg_val: torchkge.data_structures.KnowledgeGraph 99 | Validation set 100 | kg_te: torchkge.data_structures.KnowledgeGraph 101 | Test set 102 | theta1: float 103 | First threshold (see paper). 104 | theta2: float 105 | Second threshold (see paper). 106 | verbose: bool 107 | counts: bool 108 | Should the triplets involving (reverse) duplicate relations be 109 | counted in all sets. 110 | reverses: list 111 | List of known reverse relations. 112 | 113 | Returns 114 | ------- 115 | duplicates: list 116 | List of pairs giving duplicate relations. 117 | rev_duplicates: list 118 | List of pairs giving reverse duplicate relations. 
119 | """ 120 | if verbose: 121 | print('Computing Ts') 122 | 123 | if reverses is None: 124 | reverses = [] 125 | 126 | T = dict() 127 | T_inv = dict() 128 | lengths = dict() 129 | 130 | h, t, r = concat_kgs(kg_tr, kg_val, kg_te) 131 | 132 | for r_ in tqdm(range(kg_tr.n_rel)): 133 | mask = (r == r_) 134 | lengths[r_] = mask.sum().item() 135 | 136 | pairs = cat((h[mask].view(-1, 1), t[mask].view(-1, 1)), dim=1) 137 | 138 | T[r_] = set([(h_.item(), t_.item()) for h_, t_ in pairs]) 139 | T_inv[r_] = set([(t_.item(), h_.item()) for h_, t_ in pairs]) 140 | 141 | if verbose: 142 | print('Finding duplicate relations') 143 | 144 | duplicates = [] 145 | rev_duplicates = [] 146 | 147 | iter_ = list(combinations(range(1345), 2)) 148 | 149 | for r1, r2 in tqdm(iter_): 150 | a = len(T[r1].intersection(T[r2])) / lengths[r1] 151 | b = len(T[r1].intersection(T[r2])) / lengths[r2] 152 | 153 | if a > theta1 and b > theta2: 154 | duplicates.append((r1, r2)) 155 | 156 | if (r1, r2) not in reverses: 157 | a = len(T[r1].intersection(T_inv[r2])) / lengths[r1] 158 | b = len(T[r1].intersection(T_inv[r2])) / lengths[r2] 159 | 160 | if a > theta1 and b > theta2: 161 | rev_duplicates.append((r1, r2)) 162 | 163 | if verbose: 164 | print('Duplicate relations: {}'.format(len(duplicates))) 165 | print('Reverse duplicate relations: ' 166 | '{}\n'.format(len(rev_duplicates))) 167 | 168 | if counts: 169 | dupl, rev = count_triplets(kg_tr, kg_tr, duplicates, rev_duplicates) 170 | print('{} train triplets have duplicate in train set ' 171 | '({}%)'.format(dupl, int(dupl / len(kg_tr)))) 172 | print('{} train triplets have reverse duplicate in train set ' 173 | '({}%)\n'.format(rev, int(rev / len(kg_tr) * 100))) 174 | 175 | dupl, rev = count_triplets(kg_tr, kg_te, duplicates, rev_duplicates) 176 | print('{} test triplets have duplicate in train set ' 177 | '({}%)'.format(dupl, int(dupl / len(kg_te)))) 178 | print('{} test triplets have reverse duplicate in train set ' 179 | '({}%)\n'.format(rev, int(rev / len(kg_te) * 100))) 180 | 181 | dupl, rev = count_triplets(kg_te, kg_te, duplicates, rev_duplicates) 182 | print('{} test triplets have duplicate in test set ' 183 | '({}%)'.format(dupl, int(dupl / len(kg_te)))) 184 | print('{} test triplets have reverse duplicate in test set ' 185 | '({}%)\n'.format(rev, int(rev / len(kg_te) * 100))) 186 | 187 | return duplicates, rev_duplicates 188 | 189 | 190 | def cartesian_product_relations(kg_tr, kg_val, kg_te, theta=0.8): 191 | """Return the cartesian product relations as explained in paper by 192 | Akrami et al. 193 | 194 | References 195 | ---------- 196 | * Farahnaz Akrami, Mohammed Samiul Saeef, Quingheng Zhang. 197 | `Realistic Re-evaluation of Knowledge Graph Completion Methods: An 198 | Experimental Study. `_ 199 | SIGMOD’20, June 14–19, 2020, Portland, OR, USA 200 | 201 | Parameters 202 | ---------- 203 | kg_tr: torchkge.data_structures.KnowledgeGraph 204 | Train set 205 | kg_val: torchkge.data_structures.KnowledgeGraph 206 | Validation set 207 | kg_te: torchkge.data_structures.KnowledgeGraph 208 | Test set 209 | theta: float 210 | Threshold used to compute the cartesian product relations. 211 | 212 | Returns 213 | ------- 214 | selected_relations: list 215 | List of relations index that are cartesian product relations 216 | (see paper for details). 
217 | 218 | """ 219 | selected_relations = [] 220 | 221 | h, t, r = concat_kgs(kg_tr, kg_val, kg_te) 222 | 223 | S = dict() 224 | O = dict() 225 | lengths = dict() 226 | 227 | for r_ in tqdm(range(kg_tr.n_rel)): 228 | mask = (r == r_) 229 | lengths[r_] = mask.sum().item() 230 | 231 | S[r_] = set(h_.item() for h_ in h[mask]) 232 | O[r_] = set(t_.item() for t_ in t[mask]) 233 | 234 | if lengths[r_] / (len(S[r_]) * len(O[r_])) > theta: 235 | selected_relations.append(r_) 236 | 237 | return selected_relations 238 | -------------------------------------------------------------------------------- /torchkge/inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | from torch import empty, tensor 7 | from tqdm.autonotebook import tqdm 8 | 9 | from .exceptions import WrongArgumentsError 10 | from .utils import filter_scores 11 | from .utils.data import get_n_batches 12 | 13 | 14 | class DataLoader_: 15 | """This class is inspired from :class:`torch.utils.dataloader.DataLoader`. 16 | It is however way simpler. 17 | 18 | """ 19 | def __init__(self, a, b, batch_size, use_cuda=None): 20 | """ 21 | 22 | Parameters 23 | ---------- 24 | batch_size: int 25 | Size of the required batches. 26 | use_cuda: str (opt, default = None) 27 | Can be either None (no use of cuda at all), 'all' to move all the 28 | dataset to cuda and then split in batches or 'batch' to simply move 29 | the batches to cuda before they are returned. 30 | """ 31 | self.a = a 32 | self.b = b 33 | 34 | self.use_cuda = use_cuda 35 | self.batch_size = batch_size 36 | 37 | if use_cuda is not None and use_cuda == 'all': 38 | self.a = self.a.cuda() 39 | self.b = self.b.cuda() 40 | 41 | def __len__(self): 42 | return get_n_batches(len(self.a), self.batch_size) 43 | 44 | def __iter__(self): 45 | return _DataLoaderIter(self) 46 | 47 | 48 | class _DataLoaderIter: 49 | def __init__(self, loader): 50 | self.a = loader.a 51 | self.b = loader.b 52 | 53 | self.use_cuda = loader.use_cuda 54 | self.batch_size = loader.batch_size 55 | 56 | self.n_batches = get_n_batches(len(self.a), self.batch_size) 57 | self.current_batch = 0 58 | 59 | def __next__(self): 60 | if self.current_batch == self.n_batches: 61 | raise StopIteration 62 | else: 63 | i = self.current_batch 64 | self.current_batch += 1 65 | 66 | tmp_a = self.a[i * self.batch_size: (i + 1) * self.batch_size] 67 | tmp_b = self.b[i * self.batch_size: (i + 1) * self.batch_size] 68 | 69 | if self.use_cuda is not None and self.use_cuda == 'batch': 70 | return tmp_a.cuda(), tmp_b.cuda() 71 | else: 72 | return tmp_a, tmp_b 73 | 74 | def __iter__(self): 75 | return self 76 | 77 | 78 | class RelationInference(object): 79 | """Use trained embedding model to infer missing relations in triples. 80 | 81 | Parameters 82 | ---------- 83 | model: torchkge.models.interfaces.Model 84 | Embedding model inheriting from the right interface. 85 | entities1: `torch.Tensor`, shape: (n_facts), dtype: `torch.long` 86 | List of the indices of known entities 1. 87 | entities2: `torch.Tensor`, shape: (n_facts), dtype: `torch.long` 88 | List of the indices of known entities 2. 89 | top_k: int 90 | Indicates the number of top predictions to return. 91 | dictionary: dict, optional (default=None) 92 | Dictionary of possible relations. It is used to filter predictions 93 | that are known to be True in the training set in order to return 94 | only new facts. 
95 | 
96 | Attributes
97 | ----------
98 | model: torchkge.models.interfaces.Model
99 | Embedding model inheriting from the right interface.
100 | entities1: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
101 | List of the indices of known entities 1.
102 | entities2: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
103 | List of the indices of known entities 2.
104 | top_k: int
105 | Indicates the number of top predictions to return.
106 | dictionary: dict, optional (default=None)
107 | Dictionary of possible relations. It is used to filter predictions
108 | that are known to be True in the training set in order to return
109 | only new facts.
110 | predictions: `torch.Tensor`, shape: (n_facts, self.top_k), dtype: `torch.long`
111 | List of the indices of predicted relations for each test fact.
112 | scores: `torch.Tensor`, shape: (n_facts, self.top_k), dtype: `torch.float`
113 | List of the scores of resulting triples for each test fact.
114 | """
115 | # TODO: add the possibility to infer link orientation as well.
116 | 
117 | def __init__(self, model, entities1, entities2, top_k=1, dictionary=None):
118 | self.model = model
119 | self.entities1 = entities1
120 | self.entities2 = entities2
121 | self.top_k = top_k
122 | self.dictionary = dictionary
123 | 
124 | self.predictions = empty(size=(len(entities1), top_k)).long()
125 | self.scores = empty(size=(len(entities1), top_k))
126 | 
127 | def evaluate(self, b_size, verbose=True):
128 | use_cuda = next(self.model.parameters()).is_cuda
129 | 
130 | if use_cuda:
131 | dataloader = DataLoader_(self.entities1, self.entities2, batch_size=b_size, use_cuda='batch')
132 | self.predictions = self.predictions.cuda()
133 | self.scores = self.scores.cuda()
134 | else:
135 | dataloader = DataLoader_(self.entities1, self.entities2, batch_size=b_size)
136 | 
137 | for i, batch in tqdm(enumerate(dataloader), total=len(dataloader),
138 | unit='batch', disable=(not verbose),
139 | desc='Inference'):
140 | ents1, ents2 = batch[0], batch[1]
141 | h_emb, t_emb, _, candidates = self.model.inference_prepare_candidates(ents1, ents2, tensor([]).long(),
142 | entities=False)
143 | scores = self.model.inference_scoring_function(h_emb, t_emb, candidates)
144 | 
145 | if self.dictionary is not None:
146 | scores = filter_scores(scores, self.dictionary, ents1, ents2, None)
147 | 
148 | scores, indices = scores.sort(descending=True)
149 | 
150 | self.predictions[i * b_size: (i + 1) * b_size] = indices[:, :self.top_k]
151 | self.scores[i * b_size: (i + 1) * b_size] = scores[:, :self.top_k]
152 | 
153 | if use_cuda:
154 | self.predictions = self.predictions.cpu()
155 | self.scores = self.scores.cpu()
156 | 
157 | 
158 | class EntityInference(object):
159 | """Use trained embedding model to infer missing entities in triples.
160 | 
161 | Parameters
162 | ----------
163 | model: torchkge.models.interfaces.Model
164 | Embedding model inheriting from the right interface.
165 | known_entities: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
166 | List of the indices of known entities.
167 | known_relations: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
168 | List of the indices of known relations.
169 | top_k: int
170 | Indicates the number of top predictions to return.
171 | missing: str
172 | String indicating if the missing entities are the heads or the tails.
173 | dictionary: dict, optional (default=None)
174 | Dictionary of possible heads or tails (depending on the value of `missing`).
175 | It is used to filter predictions that are known to be True in the training set
176 | in order to return only new facts.
177 | 
178 | Attributes
179 | ----------
180 | model: torchkge.models.interfaces.Model
181 | Embedding model inheriting from the right interface.
182 | known_entities: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
183 | List of the indices of known entities.
184 | known_relations: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
185 | List of the indices of known relations.
186 | top_k: int
187 | Indicates the number of top predictions to return.
188 | missing: str
189 | String indicating if the missing entities are the heads or the tails.
190 | dictionary: dict, optional (default=None)
191 | Dictionary of possible heads or tails (depending on the value of `missing`).
192 | It is used to filter predictions that are known to be True in the training set
193 | in order to return only new facts.
194 | predictions: `torch.Tensor`, shape: (n_facts, self.top_k), dtype: `torch.long`
195 | List of the indices of predicted entities for each test fact.
196 | scores: `torch.Tensor`, shape: (n_facts, self.top_k), dtype: `torch.float`
197 | List of the scores of resulting triples for each test fact.
198 | 
199 | """
200 | def __init__(self, model, known_entities, known_relations, top_k=1, missing='tails', dictionary=None):
201 | try:
202 | assert missing in ['heads', 'tails']
203 | except AssertionError:
204 | raise WrongArgumentsError("missing entity should either be 'heads' or 'tails'")
205 | self.model = model
206 | self.known_entities = known_entities
207 | self.known_relations = known_relations
208 | self.missing = missing
209 | self.top_k = top_k
210 | self.dictionary = dictionary
211 | 
212 | self.predictions = empty(size=(len(known_entities), top_k)).long()
213 | self.scores = empty(size=(len(known_entities), top_k))
214 | 
215 | def evaluate(self, b_size, verbose=True):
216 | use_cuda = next(self.model.parameters()).is_cuda
217 | 
218 | if use_cuda:
219 | dataloader = DataLoader_(self.known_entities, self.known_relations, batch_size=b_size, use_cuda='batch')
220 | self.predictions = self.predictions.cuda()
221 | self.scores = self.scores.cuda()
222 | else:
223 | dataloader = DataLoader_(self.known_entities, self.known_relations, batch_size=b_size)
224 | 
225 | for i, batch in tqdm(enumerate(dataloader), total=len(dataloader),
226 | unit='batch', disable=(not verbose),
227 | desc='Inference'):
228 | known_ents, known_rels = batch[0], batch[1]
229 | if self.missing == 'heads':
230 | _, t_emb, rel_emb, candidates = self.model.inference_prepare_candidates(tensor([]).long(), known_ents,
231 | known_rels,
232 | entities=True)
233 | scores = self.model.inference_scoring_function(candidates, t_emb, rel_emb)
234 | else:
235 | h_emb, _, rel_emb, candidates = self.model.inference_prepare_candidates(known_ents, tensor([]).long(),
236 | known_rels,
237 | entities=True)
238 | scores = self.model.inference_scoring_function(h_emb, candidates, rel_emb)
239 | 
240 | if self.dictionary is not None:
241 | scores = filter_scores(scores, self.dictionary, known_ents, known_rels, None)
242 | 
243 | scores, indices = scores.sort(descending=True)
244 | 
245 | self.predictions[i * b_size: (i + 1) * b_size] = indices[:, :self.top_k]
246 | self.scores[i * b_size: (i + 1) * b_size] = scores[:, :self.top_k]
247 | 
248 | if use_cuda:
249 | self.predictions = self.predictions.cpu()
250 | self.scores = self.scores.cpu()
251 | --------------------------------------------------------------------------------
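Both inference classes assume an already trained embedding model. Below is
a minimal usage sketch (an editor's example, not part of the library): the
trained model `model` and the training graph `kg_train` are assumed to
exist, e.g. coming out of one of the training loops shown in the tutorials.

from torch import tensor

from torchkge.inference import EntityInference

# Predict the 3 most plausible tails for the partial facts
# (h=0, r=5, ?) and (h=12, r=8, ?).
known_heads = tensor([0, 12]).long()
known_relations = tensor([5, 8]).long()

inference = EntityInference(model, known_heads, known_relations, top_k=3,
                            missing='tails', dictionary=kg_train.dict_of_tails)
inference.evaluate(b_size=2)

print(inference.predictions)  # shape (2, 3): indices of predicted tail entities
print(inference.scores)  # shape (2, 3): scores of the corresponding triples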
/torchkge/models/interfaces.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Copyright TorchKGE developers
4 | @author: Armand Boschin
5 | """
6 | 
7 | from torch.nn import Module
8 | 
9 | from ..utils.dissimilarities import l1_dissimilarity, l2_dissimilarity, \
10 | l1_torus_dissimilarity, l2_torus_dissimilarity, el2_torus_dissimilarity
11 | 
12 | 
13 | class Model(Module):
14 | """Model interface to be used by any other class implementing a knowledge
15 | graph embedding model. It is only required to implement the methods
16 | `scoring_function`, `normalize_parameters`, `inference_prepare_candidates`
17 | and `inference_scoring_function`.
18 | 
19 | Parameters
20 | ----------
21 | n_entities: int
22 | Number of entities to be embedded.
23 | n_relations: int
24 | Number of relations to be embedded.
25 | 
26 | Attributes
27 | ----------
28 | n_ent: int
29 | Number of entities to be embedded.
30 | n_rel: int
31 | Number of relations to be embedded.
32 | 
33 | """
34 | def __init__(self, n_entities, n_relations):
35 | super().__init__()
36 | self.n_ent = n_entities
37 | self.n_rel = n_relations
38 | 
39 | def forward(self, heads, tails, relations, negative_heads, negative_tails, negative_relations=None):
40 | """
41 | 
42 | Parameters
43 | ----------
44 | heads: torch.Tensor, dtype: torch.long, shape: (batch_size)
45 | Integer keys of the current batch's heads.
46 | tails: torch.Tensor, dtype: torch.long, shape: (batch_size)
47 | Integer keys of the current batch's tails.
48 | relations: torch.Tensor, dtype: torch.long, shape: (batch_size)
49 | Integer keys of the current batch's relations.
50 | negative_heads: torch.Tensor, dtype: torch.long, shape: (batch_size)
51 | Integer keys of the current batch's negatively sampled heads.
52 | negative_tails: torch.Tensor, dtype: torch.long, shape: (batch_size)
53 | Integer keys of the current batch's negatively sampled tails.
54 | negative_relations: torch.Tensor, dtype: torch.long, shape: (batch_size)
55 | Integer keys of the current batch's negatively sampled relations.
56 | 
57 | Returns
58 | -------
59 | positive_triplets: torch.Tensor, dtype: torch.float, shape: (b_size)
60 | Scoring function evaluated on true triples.
61 | negative_triplets: torch.Tensor, dtype: torch.float, shape: (b_size)
62 | Scoring function evaluated on negatively sampled triples.
63 | 
64 | """
65 | pos = self.scoring_function(heads, tails, relations)
66 | 
67 | if negative_relations is None:
68 | negative_relations = relations
69 | 
70 | if negative_heads.shape[0] > negative_relations.shape[0]:
71 | # in that case, several negative samples are sampled from each fact
72 | n_neg = int(negative_heads.shape[0] / negative_relations.shape[0])
73 | pos = pos.repeat(n_neg)
74 | neg = self.scoring_function(negative_heads,
75 | negative_tails,
76 | negative_relations.repeat(n_neg))
77 | else:
78 | neg = self.scoring_function(negative_heads,
79 | negative_tails,
80 | negative_relations)
81 | 
82 | return pos, neg
83 | 
84 | def scoring_function(self, h_idx, t_idx, r_idx):
85 | """Compute the scoring function for the triplets given as argument.
86 | 
87 | Parameters
88 | ----------
89 | h_idx: torch.Tensor, dtype: torch.long, shape: (b_size)
90 | Integer keys of the current batch's heads.
91 | t_idx: torch.Tensor, dtype: torch.long, shape: (b_size)
92 | Integer keys of the current batch's tails.
93 | r_idx: torch.Tensor, dtype: torch.long, shape: (b_size)
94 | Integer keys of the current batch's relations.
95 | 
96 | Returns
97 | -------
98 | score: torch.Tensor, dtype: torch.float, shape: (b_size)
99 | Score of each triplet.
100 | 
101 | """
102 | raise NotImplementedError
103 | 
104 | def normalize_parameters(self):
105 | """Normalize some parameters. This method should be called at the end
106 | of each training epoch and at the end of training as well.
107 | 
108 | """
109 | raise NotImplementedError
110 | 
111 | def get_embeddings(self):
112 | """Return the tensors representing entities and relations in the
113 | current model.
114 | 
115 | """
116 | raise NotImplementedError
117 | 
118 | def inference_scoring_function(self, h, t, r):
119 | """Link prediction evaluation helper function. Compute the scores of
120 | (h, r, c) or (c, r, t) for any candidate c. The arguments should
121 | match the ones of `inference_prepare_candidates`.
122 | 
123 | Parameters
124 | ----------
125 | h: torch.Tensor, shape: (b_size, ent_emb_dim) or (b_size, n_ent,
126 | ent_emb_dim), dtype: torch.float
127 | t: torch.Tensor, shape: (b_size, ent_emb_dim) or (b_size, n_ent,
128 | ent_emb_dim), dtype: torch.float
129 | r: torch.Tensor, shape: (b_size, ent_emb_dim) or (b_size, n_rel,
130 | ent_emb_dim), dtype: torch.float
131 | 
132 | Returns
133 | -------
134 | scores: torch.Tensor, shape: (b_size, n_ent), dtype: torch.float
135 | Scores of each candidate for each triple.
136 | """
137 | raise NotImplementedError
138 | 
139 | def inference_prepare_candidates(self, h_idx, t_idx, r_idx, entities=True):
140 | """Link prediction evaluation helper function. Get entity and
141 | relation embeddings, along with entity candidates (projected if
142 | needed). The output will be fed to the `inference_scoring_function`
143 | method of the model at hand.
144 | 
145 | Parameters
146 | ----------
147 | h_idx: torch.Tensor, shape: (b_size), dtype: torch.long
148 | List of heads indices.
149 | t_idx: torch.Tensor, shape: (b_size), dtype: torch.long
150 | List of tails indices.
151 | r_idx: torch.Tensor, shape: (b_size), dtype: torch.long
152 | List of relations indices.
153 | entities: bool
154 | Boolean indicating whether the candidates are entities (True) or relations (False).
155 | 
156 | Returns
157 | -------
158 | h: torch.Tensor, shape: (b_size, rel_emb_dim), dtype: torch.float
159 | Head vectors fed to `inference_scoring_function`. For translation
160 | models it is the entity embeddings projected in relation space,
161 | for example.
162 | t: torch.Tensor, shape: (b_size, rel_emb_dim), dtype: torch.float
163 | Tail vectors fed to `inference_scoring_function`. For translation
164 | models it is the entity embeddings projected in relation space,
165 | for example.
166 | r: torch.Tensor, shape: (b_size, rel_emb_dim), dtype: torch.float
167 | Relation embeddings or matrices.
168 | candidates: torch.Tensor, shape: (b_size, rel_emb_dim, n_ent),
169 | dtype: torch.float
170 | All entity embeddings prepared for batch evaluation. Axis 0 is
171 | simply duplication.
172 | 
173 | """
174 | raise NotImplementedError
175 | 
176 | 
177 | class TranslationModel(Model):
178 | """Model interface to be used by any other class implementing a
179 | translation knowledge graph embedding model. This interface inherits from
180 | the interface :class:`torchkge.models.interfaces.Model`. It is only
181 | required to implement the methods `scoring_function`,
182 | `normalize_parameters` and `inference_prepare_candidates`.
183 | 
184 | Parameters
185 | ----------
186 | n_entities: int
187 | Number of entities to be embedded.
188 | n_relations: int
189 | Number of relations to be embedded.
190 | dissimilarity_type: str
191 | One of 'L1', 'L2', 'torus_L1', 'torus_L2' and 'torus_eL2'.
192 | 
193 | Attributes
194 | ----------
195 | dissimilarity: function
196 | Dissimilarity function.
197 | 
198 | """
199 | def __init__(self, n_entities, n_relations, dissimilarity_type):
200 | super().__init__(n_entities, n_relations)
201 | 
202 | assert dissimilarity_type in ['L1', 'L2', 'torus_L1', 'torus_L2',
203 | 'torus_eL2']
204 | 
205 | if dissimilarity_type == 'L1':
206 | self.dissimilarity = l1_dissimilarity
207 | elif dissimilarity_type == 'L2':
208 | self.dissimilarity = l2_dissimilarity
209 | elif dissimilarity_type == 'torus_L1':
210 | self.dissimilarity = l1_torus_dissimilarity
211 | elif dissimilarity_type == 'torus_L2':
212 | self.dissimilarity = l2_torus_dissimilarity
213 | else:
214 | self.dissimilarity = el2_torus_dissimilarity
215 | 
216 | def scoring_function(self, h_idx, t_idx, r_idx):
217 | """See torchkge.models.interfaces.Models.
218 | 
219 | """
220 | raise NotImplementedError
221 | 
222 | def normalize_parameters(self):
223 | """See torchkge.models.interfaces.Models.
224 | 
225 | """
226 | raise NotImplementedError
227 | 
228 | def get_embeddings(self):
229 | """See torchkge.models.interfaces.Models.
230 | 
231 | """
232 | raise NotImplementedError
233 | 
234 | def inference_prepare_candidates(self, h_idx, t_idx, r_idx, entities=True):
235 | """See torchkge.models.interfaces.Models.
236 | 
237 | """
238 | raise NotImplementedError
239 | 
240 | def inference_scoring_function(self, proj_h, proj_t, r):
241 | """This overrides the method declared in
242 | torchkge.models.interfaces.Models. For translation models, the computed
243 | score is the opposite of the dissimilarity between projected heads +
244 | relations and projected tails. Projections are done in relation-specific subspaces.
245 | 
246 | """
247 | b_size = proj_h.shape[0]
248 | 
249 | if len(r.shape) == 2:
250 | if len(proj_t.shape) == 3:
251 | assert (len(proj_h.shape) == 2)
252 | # this is the tail completion case in link prediction
253 | hr = (proj_h + r).view(b_size, 1, r.shape[1])
254 | return - self.dissimilarity(hr, proj_t)
255 | else:
256 | assert (len(proj_h.shape) == 3) & (len(proj_t.shape) == 2)
257 | # this is the head completion case in link prediction
258 | r_ = r.view(b_size, 1, r.shape[1])
259 | t_ = proj_t.view(b_size, 1, r.shape[1])
260 | return - self.dissimilarity(proj_h + r_, t_)
261 | elif len(r.shape) == 3:
262 | # this is the relation prediction case
263 | # Two cases possible:
264 | # * proj_ent.shape == (b_size, self.n_rel, self.emb_dim) -> projection depending on relations
265 | # * proj_ent.shape == (b_size, self.emb_dim) -> no projection
266 | try:
267 | proj_h = proj_h.view(b_size, -1, self.emb_dim)
268 | proj_t = proj_t.view(b_size, -1, self.emb_dim)
269 | except AttributeError:
270 | proj_h = proj_h.view(b_size, -1, self.rel_emb_dim)
271 | proj_t = proj_t.view(b_size, -1, self.rel_emb_dim)
272 | return - self.dissimilarity(proj_h + r, proj_t)
273 | 
274 | 
275 | class BilinearModel(Model):
276 | """Model interface to be used by any other class implementing a
277 | bilinear knowledge graph embedding model. This interface inherits from
278 | the interface :class:`torchkge.models.interfaces.Model`. It is only
279 | required to implement the methods `scoring_function`,
280 | `normalize_parameters`, `inference_prepare_candidates` and `inference_scoring_function`.
281 | 282 | Parameters 283 | ---------- 284 | n_entities: int 285 | Number of entities to be embedded. 286 | n_relations: int 287 | Number of relations to be embedded. 288 | emb_dim: int 289 | Dimension of the embedding space. 290 | 291 | Attributes 292 | ---------- 293 | emb_dim: int 294 | Dimension of the embedding space. 295 | 296 | """ 297 | 298 | def __init__(self, emb_dim, n_entities, n_relations): 299 | super().__init__(n_entities, n_relations) 300 | self.emb_dim = emb_dim 301 | 302 | def scoring_function(self, h_idx, t_idx, r_idx): 303 | """See torchkge.models.interfaces.Models. 304 | 305 | """ 306 | raise NotImplementedError 307 | 308 | def normalize_parameters(self): 309 | """See torchkge.models.interfaces.Models. 310 | 311 | """ 312 | raise NotImplementedError 313 | 314 | def get_embeddings(self): 315 | """See torchkge.models.interfaces.Models. 316 | 317 | """ 318 | raise NotImplementedError 319 | 320 | def inference_scoring_function(self, h, t, r): 321 | """See torchkge.models.interfaces.Models. 322 | 323 | """ 324 | raise NotImplementedError 325 | 326 | def inference_prepare_candidates(self, h_idx, t_idx, r_idx, entities=True): 327 | """See torchkge.models.interfaces.Models. 328 | 329 | """ 330 | raise NotImplementedError 331 | -------------------------------------------------------------------------------- /torchkge/utils/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | 6 | This module's code is freely adapted from Scikit-Learn's 7 | sklearn.datasets.base.py code. 8 | """ 9 | 10 | import pickle 11 | import tarfile 12 | import zipfile 13 | 14 | from os import makedirs, remove 15 | from os.path import exists 16 | from pandas import concat, DataFrame, merge, read_csv 17 | from urllib.request import urlretrieve 18 | 19 | from torchkge.data_structures import KnowledgeGraph 20 | 21 | from torchkge.utils import get_data_home, safe_extract 22 | from torchkge.utils.operations import extend_dicts 23 | 24 | 25 | def load_fb13(data_home=None): 26 | """Load FB13 dataset. 27 | 28 | Parameters 29 | ---------- 30 | data_home: str, optional 31 | Path to the `torchkge_data` directory (containing data folders). If 32 | files are not present on disk in this directory, they are downloaded 33 | and then placed in the right place. 
34 | 35 | Returns 36 | ------- 37 | kg_train: torchkge.data_structures.KnowledgeGraph 38 | kg_val: torchkge.data_structures.KnowledgeGraph 39 | kg_test: torchkge.data_structures.KnowledgeGraph 40 | 41 | """ 42 | if data_home is None: 43 | data_home = get_data_home() 44 | data_path = data_home + '/FB13' 45 | if not exists(data_path): 46 | makedirs(data_path, exist_ok=True) 47 | urlretrieve("https://graphs.telecom-paristech.fr/data/torchkge/kgs/FB13.zip", 48 | data_home + '/FB13.zip') 49 | with zipfile.ZipFile(data_home + '/FB13.zip', 'r') as zip_ref: 50 | zip_ref.extractall(data_home) 51 | remove(data_home + '/FB13.zip') 52 | 53 | df1 = read_csv(data_path + '/train2id.txt', 54 | sep='\t', header=None, names=['from', 'rel', 'to']) 55 | df2 = read_csv(data_path + '/valid2id.txt', 56 | sep='\t', header=None, names=['from', 'rel', 'to']) 57 | df3 = read_csv(data_path + '/test2id.txt', 58 | sep='\t', header=None, names=['from', 'rel', 'to']) 59 | df = concat([df1, df2, df3]) 60 | kg = KnowledgeGraph(df) 61 | 62 | return kg.split_kg(sizes=(len(df1), len(df2), len(df3))) 63 | 64 | 65 | def load_fb15k(data_home=None): 66 | """Load FB15k dataset. See `here 67 | `__ 68 | for paper by Bordes et al. originally presenting the dataset. 69 | 70 | Parameters 71 | ---------- 72 | data_home: str, optional 73 | Path to the `torchkge_data` directory (containing data folders). If 74 | files are not present on disk in this directory, they are downloaded 75 | and then placed in the right place. 76 | 77 | Returns 78 | ------- 79 | kg_train: torchkge.data_structures.KnowledgeGraph 80 | kg_val: torchkge.data_structures.KnowledgeGraph 81 | kg_test: torchkge.data_structures.KnowledgeGraph 82 | 83 | """ 84 | if data_home is None: 85 | data_home = get_data_home() 86 | data_path = data_home + '/FB15k' 87 | if not exists(data_path): 88 | makedirs(data_path, exist_ok=True) 89 | urlretrieve("https://graphs.telecom-paristech.fr/data/torchkge/kgs/FB15k.zip", 90 | data_home + '/FB15k.zip') 91 | with zipfile.ZipFile(data_home + '/FB15k.zip', 'r') as zip_ref: 92 | zip_ref.extractall(data_home) 93 | remove(data_home + '/FB15k.zip') 94 | 95 | df1 = read_csv(data_path + '/freebase_mtr100_mte100-train.txt', 96 | sep='\t', header=None, names=['from', 'rel', 'to']) 97 | df2 = read_csv(data_path + '/freebase_mtr100_mte100-valid.txt', 98 | sep='\t', header=None, names=['from', 'rel', 'to']) 99 | df3 = read_csv(data_path + '/freebase_mtr100_mte100-test.txt', 100 | sep='\t', header=None, names=['from', 'rel', 'to']) 101 | df = concat([df1, df2, df3]) 102 | kg = KnowledgeGraph(df) 103 | 104 | return kg.split_kg(sizes=(len(df1), len(df2), len(df3))) 105 | 106 | 107 | def load_fb15k237(data_home=None): 108 | """Load FB15k237 dataset. See `here 109 | `__ for paper by Toutanova et 110 | al. originally presenting the dataset. 111 | 112 | Parameters 113 | ---------- 114 | data_home: str, optional 115 | Path to the `torchkge_data` directory (containing data folders). If 116 | files are not present on disk in this directory, they are downloaded 117 | and then placed in the right place. 
118 | 119 | Returns 120 | ------- 121 | kg_train: torchkge.data_structures.KnowledgeGraph 122 | kg_val: torchkge.data_structures.KnowledgeGraph 123 | kg_test: torchkge.data_structures.KnowledgeGraph 124 | 125 | """ 126 | if data_home is None: 127 | data_home = get_data_home() 128 | data_path = data_home + '/FB15k237' 129 | if not exists(data_path): 130 | makedirs(data_path, exist_ok=True) 131 | urlretrieve("https://graphs.telecom-paristech.fr/data/torchkge/kgs/FB15k237.zip", 132 | data_home + '/FB15k237.zip') 133 | with zipfile.ZipFile(data_home + '/FB15k237.zip', 'r') as zip_ref: 134 | zip_ref.extractall(data_home) 135 | remove(data_home + '/FB15k237.zip') 136 | 137 | df1 = read_csv(data_path + '/train.txt', 138 | sep='\t', header=None, names=['from', 'rel', 'to']) 139 | df2 = read_csv(data_path + '/valid.txt', 140 | sep='\t', header=None, names=['from', 'rel', 'to']) 141 | df3 = read_csv(data_path + '/test.txt', 142 | sep='\t', header=None, names=['from', 'rel', 'to']) 143 | df = concat([df1, df2, df3]) 144 | kg = KnowledgeGraph(df) 145 | 146 | return kg.split_kg(sizes=(len(df1), len(df2), len(df3))) 147 | 148 | 149 | def load_wn18(data_home=None): 150 | """Load WN18 dataset. 151 | 152 | Parameters 153 | ---------- 154 | data_home: str, optional 155 | Path to the `torchkge_data` directory (containing data folders). If 156 | files are not present on disk in this directory, they are downloaded 157 | and then placed in the right place. 158 | 159 | Returns 160 | ------- 161 | kg_train: torchkge.data_structures.KnowledgeGraph 162 | kg_val: torchkge.data_structures.KnowledgeGraph 163 | kg_test: torchkge.data_structures.KnowledgeGraph 164 | 165 | """ 166 | if data_home is None: 167 | data_home = get_data_home() 168 | data_path = data_home + '/WN18' 169 | if not exists(data_path): 170 | makedirs(data_path, exist_ok=True) 171 | urlretrieve("https://graphs.telecom-paristech.fr/data/torchkge/kgs/WN18.zip", 172 | data_home + '/WN18.zip') 173 | with zipfile.ZipFile(data_home + '/WN18.zip', 'r') as zip_ref: 174 | zip_ref.extractall(data_home) 175 | remove(data_home + '/WN18.zip') 176 | 177 | df1 = read_csv(data_path + '/wordnet-mlj12-train.txt', 178 | sep='\t', header=None, names=['from', 'rel', 'to']) 179 | df2 = read_csv(data_path + '/wordnet-mlj12-valid.txt', 180 | sep='\t', header=None, names=['from', 'rel', 'to']) 181 | df3 = read_csv(data_path + '/wordnet-mlj12-test.txt', 182 | sep='\t', header=None, names=['from', 'rel', 'to']) 183 | df = concat([df1, df2, df3]) 184 | kg = KnowledgeGraph(df) 185 | 186 | return kg.split_kg(sizes=(len(df1), len(df2), len(df3))) 187 | 188 | 189 | def load_wn18rr(data_home=None): 190 | """Load WN18RR dataset. See `here 191 | `__ for paper by Dettmers et 192 | al. originally presenting the dataset. 193 | 194 | Parameters 195 | ---------- 196 | data_home: str, optional 197 | Path to the `torchkge_data` directory (containing data folders). If 198 | files are not present on disk in this directory, they are downloaded 199 | and then placed in the right place. 
200 | 201 | Returns 202 | ------- 203 | kg_train: torchkge.data_structures.KnowledgeGraph 204 | kg_val: torchkge.data_structures.KnowledgeGraph 205 | kg_test: torchkge.data_structures.KnowledgeGraph 206 | 207 | """ 208 | if data_home is None: 209 | data_home = get_data_home() 210 | data_path = data_home + '/WN18RR' 211 | if not exists(data_path): 212 | makedirs(data_path, exist_ok=True) 213 | urlretrieve("https://graphs.telecom-paristech.fr/data/torchkge/kgs/WN18RR.zip", 214 | data_home + '/WN18RR.zip') 215 | with zipfile.ZipFile(data_home + '/WN18RR.zip', 'r') as zip_ref: 216 | zip_ref.extractall(data_home) 217 | remove(data_home + '/WN18RR.zip') 218 | 219 | df1 = read_csv(data_path + '/train.txt', 220 | sep='\t', header=None, names=['from', 'rel', 'to']) 221 | df2 = read_csv(data_path + '/valid.txt', 222 | sep='\t', header=None, names=['from', 'rel', 'to']) 223 | df3 = read_csv(data_path + '/test.txt', 224 | sep='\t', header=None, names=['from', 'rel', 'to']) 225 | df = concat([df1, df2, df3]) 226 | kg = KnowledgeGraph(df) 227 | 228 | return kg.split_kg(sizes=(len(df1), len(df2), len(df3))) 229 | 230 | 231 | def load_yago3_10(data_home=None): 232 | """Load YAGO3-10 dataset. See `here 233 | `__ for paper by Dettmers et 234 | al. originally presenting the dataset. 235 | 236 | Parameters 237 | ---------- 238 | data_home: str, optional 239 | Path to the `torchkge_data` directory (containing data folders). If 240 | files are not present on disk in this directory, they are downloaded 241 | and then placed in the right place. 242 | 243 | Returns 244 | ------- 245 | kg_train: torchkge.data_structures.KnowledgeGraph 246 | kg_val: torchkge.data_structures.KnowledgeGraph 247 | kg_test: torchkge.data_structures.KnowledgeGraph 248 | 249 | """ 250 | if data_home is None: 251 | data_home = get_data_home() 252 | data_path = data_home + '/YAGO3-10' 253 | if not exists(data_path): 254 | makedirs(data_path, exist_ok=True) 255 | urlretrieve("https://graphs.telecom-paristech.fr/data/torchkge/kgs/YAGO3-10.zip", 256 | data_home + '/YAGO3-10.zip') 257 | with zipfile.ZipFile(data_home + '/YAGO3-10.zip', 'r') as zip_ref: 258 | zip_ref.extractall(data_home) 259 | remove(data_home + '/YAGO3-10.zip') 260 | 261 | df1 = read_csv(data_path + '/train.txt', 262 | sep='\t', header=None, names=['from', 'rel', 'to']) 263 | df2 = read_csv(data_path + '/valid.txt', 264 | sep='\t', header=None, names=['from', 'rel', 'to']) 265 | df3 = read_csv(data_path + '/test.txt', 266 | sep='\t', header=None, names=['from', 'rel', 'to']) 267 | df = concat([df1, df2, df3]) 268 | kg = KnowledgeGraph(df) 269 | 270 | return kg.split_kg(sizes=(len(df1), len(df2), len(df3))) 271 | 272 | 273 | def load_wikidatasets(which, limit_=0, data_home=None): 274 | """Load WikiDataSets dataset. See `here 275 | `__ for paper by Boschin et al. 276 | originally presenting the dataset. 277 | 278 | Parameters 279 | ---------- 280 | which: str 281 | String indicating which subset of Wikidata should be loaded. 282 | Available ones are `humans`, `companies`, `animals`, `countries` and 283 | `films`. 284 | limit_: int, optional (default=0) 285 | This indicates a lower limit on the number of neighbors an entity 286 | should have in the graph to be kept. 287 | data_home: str, optional 288 | Path to the `torchkge_data` directory (containing data folders). If 289 | files are not present on disk in this directory, they are downloaded 290 | and then placed in the right place. 
291 | 292 | Returns 293 | ------- 294 | kg: torchkge.data_structures.KnowledgeGraph 295 | 296 | """ 297 | assert which in ['humans', 'companies', 'animals', 'countries', 'films'] 298 | 299 | if data_home is None: 300 | data_home = get_data_home() 301 | 302 | data_home = data_home + '/WikiDataSets' 303 | data_path = data_home + '/' + which 304 | if not exists(data_path): 305 | makedirs(data_path, exist_ok=True) 306 | urlretrieve("https://graphs.telecom-paristech.fr/data/WikiDataSets/{}.tar.gz".format(which), 307 | data_home + '/{}.tar.gz'.format(which)) 308 | 309 | with tarfile.open(data_home + '/{}.tar.gz'.format(which), 'r') as tf: 310 | safe_extract(tf, data_home) 311 | remove(data_home + '/{}.tar.gz'.format(which)) 312 | 313 | df = read_csv(data_path + '/edges.tsv', sep='\t', 314 | names=['from', 'to', 'rel'], skiprows=1) 315 | 316 | if limit_ > 0: 317 | a = df.groupby('from').count()['rel'] 318 | b = df.groupby('to').count()['rel'] 319 | 320 | # Filter out nodes with too few facts 321 | tmp = merge(right=DataFrame(a).reset_index(), 322 | left=DataFrame(b).reset_index(), 323 | how='outer', right_on='from', left_on='to', ).fillna(0) 324 | 325 | tmp['rel'] = tmp['rel_x'] + tmp['rel_y'] 326 | tmp = tmp.drop(['from', 'rel_x', 'rel_y'], axis=1) 327 | 328 | tmp = tmp.loc[tmp['rel'] >= limit_] 329 | df_bis = df.loc[df['from'].isin(tmp['to']) | df['to'].isin(tmp['to'])] 330 | 331 | kg = KnowledgeGraph(df_bis) 332 | else: 333 | kg = KnowledgeGraph(df) 334 | 335 | return kg 336 | 337 | 338 | def load_wikidata_vitals(level=5, data_home=None): 339 | """Load knowledge graph extracted from Wikidata using the entities 340 | corresponding to Wikipedia pages contained in Wikivitals. See `here 341 | `__ for details on Wikivitals and 342 | Wikivitals+ datasets. 343 | 344 | Parameters 345 | ---------- 346 | level: int (default=5) 347 | Either 4 or 5. 348 | data_home: str, optional 349 | Path to the `torchkge_data` directory (containing data folders). If 350 | files are not present on disk in this directory, they are downloaded 351 | and then placed in the right place. 
352 | 353 | Returns 354 | ------- 355 | kg: torchkge.data_structures.KnowledgeGraph 356 | kg_attr: torchkge.data_structures.KnowledgeGraph 357 | """ 358 | assert level in [4, 5] 359 | 360 | if data_home is None: 361 | data_home = get_data_home() 362 | 363 | data_path = data_home + '/wikidatavitals-level{}'.format(level) 364 | 365 | if not exists(data_path): 366 | makedirs(data_path, exist_ok=True) 367 | print('Downloading archive') 368 | urlretrieve("https://graphs.telecom-paristech.fr/data/torchkge/kgs/wikidatavitals-level{}.zip".format(level), 369 | data_home + '/wikidatavitals-level{}.zip'.format(level)) 370 | 371 | with zipfile.ZipFile(data_home + '/wikidatavitals-level{}.zip'.format(level), 'r') as zip_ref: 372 | zip_ref.extractall(data_home) 373 | remove(data_home + '/wikidatavitals-level{}.zip'.format(level)) 374 | 375 | if not exists(data_path+'/kgs.pkl'): 376 | print('Building torchkge.KnowledgeGraph objects from the archive.') 377 | df = read_csv(data_path + '/edges.tsv', sep='\t', 378 | names=['from', 'to', 'rel'], skiprows=1) 379 | attributes = read_csv(data_path + '/attributes.tsv', sep='\t', 380 | names=['from', 'to', 'rel'], skiprows=1) 381 | 382 | entities = read_csv(data_path + '/entities.tsv', sep='\t') 383 | relations = read_csv(data_path + '/relations.tsv', sep='\t') 384 | nodes = read_csv(data_path + '/nodes.tsv', sep='\t') 385 | 386 | df = enrich(df, entities, relations) 387 | attributes = enrich(attributes, entities, relations) 388 | 389 | relid2label = {relations.loc[i, 'wikidataID']: relations.loc[i, 'label'] 390 | for i in relations.index} 391 | entid2label = {entities.loc[i, 'wikidataID']: entities.loc[i, 'label'] for 392 | i in entities.index} 393 | entid2pagename = {nodes.loc[i, 'wikidataID']: nodes.loc[i, 'pageName'] for 394 | i in nodes.index} 395 | 396 | kg = KnowledgeGraph(df) 397 | ent2ix, rel2ix = extend_dicts(kg, attributes) 398 | kg_attr = KnowledgeGraph(attributes, ent2ix=ent2ix, rel2ix=rel2ix) 399 | 400 | kg.relid2label = relid2label 401 | kg_attr.relid2label = relid2label 402 | kg.entid2label = entid2label 403 | kg_attr.entid2label = entid2label 404 | kg.entid2pagename = entid2pagename 405 | kg_attr.entid2pagename = entid2pagename 406 | 407 | with open(data_path + '/kgs.pkl', 'wb') as f: 408 | pickle.dump((kg, kg_attr), f) 409 | 410 | else: 411 | print('Loading torchkge.KnowledgeGraph objects from disk.') 412 | with open(data_path + '/kgs.pkl', 'rb') as f: 413 | kg, kg_attr = pickle.load(f) 414 | 415 | return kg, kg_attr 416 | 417 | 418 | def enrich(df, entities, relations): 419 | df = merge(left=df, right=entities[['entityID', 'wikidataID']], 420 | left_on='from', right_on='entityID')[ 421 | ['to', 'rel', 'wikidataID']] 422 | df.columns = ['to', 'rel', 'from'] 423 | 424 | df = merge(left=df, right=entities[['entityID', 'wikidataID']], 425 | left_on='to', right_on='entityID')[ 426 | ['from', 'rel', 'wikidataID']] 427 | 428 | df.columns = ['from', 'rel', 'to'] 429 | 430 | df = merge(left=df, right=relations[['relationID', 'wikidataID']], 431 | left_on='rel', right_on='relationID')[ 432 | ['from', 'to', 'wikidataID']] 433 | 434 | df.columns = ['from', 'to', 'rel'] 435 | return df 436 | -------------------------------------------------------------------------------- /torchkge/data_structures.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | from collections import defaultdict 8 | 9 | from pandas import 
DataFrame
10 | from torch import cat, eq, int64, long, randperm, tensor, Tensor, zeros_like
11 | from torch.utils.data import Dataset
12 | 
13 | from torchkge.exceptions import SizeMismatchError, WrongArgumentsError, SanityError
14 | from torchkge.utils.operations import get_dictionaries
15 | 
16 | 
17 | class KnowledgeGraph(Dataset):
18 | """Knowledge graph representation. At least one of `df` and `kg`
19 | parameters should be passed.
20 | 
21 | Parameters
22 | ----------
23 | df: pandas.DataFrame, optional
24 | Data frame containing three columns [from, to, rel].
25 | kg: dict, optional
26 | Dictionary with keys ('heads', 'tails', 'relations') and values
27 | the corresponding torch long tensors.
28 | ent2ix: dict, optional
29 | Dictionary mapping entity labels to their integer key. This is
30 | computed if not passed as argument.
31 | rel2ix: dict, optional
32 | Dictionary mapping relation labels to their integer key. This is
33 | computed if not passed as argument.
34 | dict_of_heads: dict, optional
35 | Dictionary of possible heads :math:`h` so that the triple
36 | :math:`(h,r,t)` gives a true fact. The keys are tuples (t, r).
37 | This is computed if not passed as argument.
38 | dict_of_tails: dict, optional
39 | Dictionary of possible tails :math:`t` so that the triple
40 | :math:`(h,r,t)` gives a true fact. The keys are tuples (h, r).
41 | This is computed if not passed as argument.
42 | dict_of_rels: dict, optional
43 | Dictionary of possible relations :math:`r` so that the triple
44 | :math:`(h,r,t)` gives a true fact. The keys are tuples (h, t).
45 | This is computed if not passed as argument.
46 | 
47 | 
48 | Attributes
49 | ----------
50 | ent2ix: dict
51 | Dictionary mapping entity labels to their integer key.
52 | rel2ix: dict
53 | Dictionary mapping relation labels to their integer key.
54 | n_ent: int
55 | Number of distinct entities in the data set.
56 | n_rel: int
57 | Number of distinct relations in the data set.
58 | n_facts: int
59 | Number of samples in the data set. A sample is a fact: a triplet
60 | (h, r, t).
61 | head_idx: torch.Tensor, dtype = torch.long, shape: (n_facts)
62 | List of the int key of heads for each fact.
63 | tail_idx: torch.Tensor, dtype = torch.long, shape: (n_facts)
64 | List of the int key of tails for each fact.
65 | relations: torch.Tensor, dtype = torch.long, shape: (n_facts)
66 | List of the int key of relations for each fact.
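Examples
--------
A minimal sketch built on a toy data frame (labels are illustrative):

>>> from pandas import DataFrame
>>> from torchkge.data_structures import KnowledgeGraph
>>> df = DataFrame({'from': ['a', 'b', 'c', 'a'],
...                 'to': ['b', 'c', 'a', 'c'],
...                 'rel': ['r1', 'r1', 'r2', 'r2']})
>>> kg = KnowledgeGraph(df)
>>> kg.n_ent, kg.n_rel, kg.n_facts
(3, 2, 4)
>>> kg_train, kg_test = kg.split_kg(share=0.8)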
67 | 68 | """ 69 | 70 | def __init__(self, df=None, kg=None, ent2ix=None, rel2ix=None, 71 | dict_of_heads=None, dict_of_tails=None, dict_of_rels=None): 72 | 73 | if df is None: 74 | if kg is None: 75 | raise WrongArgumentsError("Please provide at least one " 76 | "argument of `df` and kg`") 77 | else: 78 | try: 79 | assert (type(kg) == dict) & ('heads' in kg.keys()) & \ 80 | ('tails' in kg.keys()) & \ 81 | ('relations' in kg.keys()) 82 | except AssertionError: 83 | raise WrongArgumentsError("Keys in the `kg` dict should " 84 | "contain `heads`, `tails`, " 85 | "`relations`.") 86 | try: 87 | assert (rel2ix is not None) & (ent2ix is not None) 88 | except AssertionError: 89 | raise WrongArgumentsError("Please provide the two " 90 | "dictionaries ent2ix and rel2ix " 91 | "if building from `kg`.") 92 | else: 93 | if kg is not None: 94 | raise WrongArgumentsError("`df` and kg` arguments should not " 95 | "both be provided.") 96 | 97 | if ent2ix is None: 98 | self.ent2ix = get_dictionaries(df, ent=True) 99 | else: 100 | self.ent2ix = ent2ix 101 | 102 | if rel2ix is None: 103 | self.rel2ix = get_dictionaries(df, ent=False) 104 | else: 105 | self.rel2ix = rel2ix 106 | 107 | self.n_ent = max(self.ent2ix.values()) + 1 108 | self.n_rel = max(self.rel2ix.values()) + 1 109 | 110 | if df is not None: 111 | # build kg from a pandas dataframe 112 | self.n_facts = len(df) 113 | self.head_idx = tensor(df['from'].map(self.ent2ix).values).long() 114 | self.tail_idx = tensor(df['to'].map(self.ent2ix).values).long() 115 | self.relations = tensor(df['rel'].map(self.rel2ix).values).long() 116 | else: 117 | # build kg from another kg 118 | self.n_facts = kg['heads'].shape[0] 119 | self.head_idx = kg['heads'] 120 | self.tail_idx = kg['tails'] 121 | self.relations = kg['relations'] 122 | 123 | if dict_of_heads is None or dict_of_tails is None or dict_of_rels is None: 124 | self.dict_of_heads = defaultdict(set) 125 | self.dict_of_tails = defaultdict(set) 126 | self.dict_of_rels = defaultdict(set) 127 | self.evaluate_dicts() 128 | 129 | else: 130 | self.dict_of_heads = dict_of_heads 131 | self.dict_of_tails = dict_of_tails 132 | self.dict_of_rels = dict_of_rels 133 | try: 134 | self.sanity_check() 135 | except AssertionError: 136 | raise SanityError("Please check the sanity of arguments.") 137 | 138 | def __len__(self): 139 | return self.n_facts 140 | 141 | def __getitem__(self, item): 142 | return (self.head_idx[item].item(), 143 | self.tail_idx[item].item(), 144 | self.relations[item].item()) 145 | 146 | def sanity_check(self): 147 | assert (type(self.dict_of_heads) == defaultdict) & \ 148 | (type(self.dict_of_tails) == defaultdict) & \ 149 | (type(self.dict_of_rels) == defaultdict) 150 | assert (type(self.ent2ix) == dict) & (type(self.rel2ix) == dict) 151 | assert (len(self.ent2ix) == self.n_ent) & \ 152 | (len(self.rel2ix) == self.n_rel) 153 | assert (type(self.head_idx) == Tensor) & \ 154 | (type(self.tail_idx) == Tensor) & \ 155 | (type(self.relations) == Tensor) 156 | assert (self.head_idx.dtype == int64) & \ 157 | (self.tail_idx.dtype == int64) & (self.relations.dtype == int64) 158 | assert (len(self.head_idx) == len(self.tail_idx) == len(self.relations)) 159 | 160 | def split_kg(self, share=0.8, sizes=None, validation=False): 161 | """Split the knowledge graph into train and test. If `sizes` is 162 | provided then it is used to split the samples as explained below. 
If 163 | only `share` is provided, the split is done at random but it assures 164 | to keep at least one fact involving each type of entity and relation 165 | in the training subset. 166 | 167 | Parameters 168 | ---------- 169 | share: float 170 | Percentage to allocate to train set. 171 | sizes: tuple 172 | Tuple of ints of length 2 or 3. 173 | 174 | * If len(sizes) == 2, then the first sizes[0] values of the 175 | knowledge graph will be used as training set and the rest as 176 | test set. 177 | 178 | * If len(sizes) == 3, then the first sizes[0] values of the 179 | knowledge graph will be used as training set, the following 180 | sizes[1] as validation set and the last sizes[2] as testing set. 181 | validation: bool 182 | Indicate if a validation set should be produced along with train 183 | and test sets. 184 | 185 | Returns 186 | ------- 187 | train_kg: torchkge.data_structures.KnowledgeGraph 188 | val_kg: torchkge.data_structures.KnowledgeGraph, optional 189 | test_kg: torchkge.data_structures.KnowledgeGraph 190 | 191 | """ 192 | if sizes is not None: 193 | try: 194 | if len(sizes) == 3: 195 | try: 196 | assert (sizes[0] + sizes[1] + sizes[2] == self.n_facts) 197 | except AssertionError: 198 | raise WrongArgumentsError('Sizes should sum to the ' 199 | 'number of facts.') 200 | elif len(sizes) == 2: 201 | try: 202 | assert (sizes[0] + sizes[1] == self.n_facts) 203 | except AssertionError: 204 | raise WrongArgumentsError('Sizes should sum to the ' 205 | 'number of facts.') 206 | else: 207 | raise SizeMismatchError('Tuple `sizes` should be of ' 208 | 'length 2 or 3.') 209 | except AssertionError: 210 | raise SizeMismatchError('Tuple `sizes` should sum up to the ' 211 | 'number of facts in the knowledge ' 212 | 'graph.') 213 | else: 214 | assert share < 1 215 | 216 | if ((sizes is not None) and (len(sizes) == 3)) or \ 217 | ((sizes is None) and validation): 218 | # return training, validation and a testing graphs 219 | 220 | if (sizes is None) and validation: 221 | mask_tr, mask_val, mask_te = self.get_mask(share, 222 | validation=True) 223 | else: 224 | mask_tr = cat([tensor([1 for _ in range(sizes[0])]), 225 | tensor([0 for _ in range(sizes[1] + sizes[2])])]).bool() 226 | mask_val = cat([tensor([0 for _ in range(sizes[0])]), 227 | tensor([1 for _ in range(sizes[1])]), 228 | tensor([0 for _ in range(sizes[2])])]).bool() 229 | mask_te = ~(mask_tr | mask_val) 230 | 231 | return (KnowledgeGraph( 232 | kg={'heads': self.head_idx[mask_tr], 233 | 'tails': self.tail_idx[mask_tr], 234 | 'relations': self.relations[mask_tr]}, 235 | ent2ix=self.ent2ix, rel2ix=self.rel2ix, 236 | dict_of_heads=self.dict_of_heads, 237 | dict_of_tails=self.dict_of_tails, 238 | dict_of_rels=self.dict_of_rels), 239 | KnowledgeGraph( 240 | kg={'heads': self.head_idx[mask_val], 241 | 'tails': self.tail_idx[mask_val], 242 | 'relations': self.relations[mask_val]}, 243 | ent2ix=self.ent2ix, rel2ix=self.rel2ix, 244 | dict_of_heads=self.dict_of_heads, 245 | dict_of_tails=self.dict_of_tails, 246 | dict_of_rels=self.dict_of_rels), 247 | KnowledgeGraph( 248 | kg={'heads': self.head_idx[mask_te], 249 | 'tails': self.tail_idx[mask_te], 250 | 'relations': self.relations[mask_te]}, 251 | ent2ix=self.ent2ix, rel2ix=self.rel2ix, 252 | dict_of_heads=self.dict_of_heads, 253 | dict_of_tails=self.dict_of_tails, 254 | dict_of_rels=self.dict_of_rels)) 255 | else: 256 | # return training and testing graphs 257 | 258 | assert (((sizes is not None) and len(sizes) == 2) or 259 | ((sizes is None) and not validation)) 260 | if sizes is None: 261 
| mask_tr, mask_te = self.get_mask(share, validation=False) 262 | else: 263 | mask_tr = cat([tensor([1 for _ in range(sizes[0])]), 264 | tensor([0 for _ in range(sizes[1])])]).bool() 265 | mask_te = ~mask_tr 266 | return (KnowledgeGraph( 267 | kg={'heads': self.head_idx[mask_tr], 268 | 'tails': self.tail_idx[mask_tr], 269 | 'relations': self.relations[mask_tr]}, 270 | ent2ix=self.ent2ix, rel2ix=self.rel2ix, 271 | dict_of_heads=self.dict_of_heads, 272 | dict_of_tails=self.dict_of_tails, 273 | dict_of_rels=self.dict_of_rels), 274 | KnowledgeGraph( 275 | kg={'heads': self.head_idx[mask_te], 276 | 'tails': self.tail_idx[mask_te], 277 | 'relations': self.relations[mask_te]}, 278 | ent2ix=self.ent2ix, rel2ix=self.rel2ix, 279 | dict_of_heads=self.dict_of_heads, 280 | dict_of_tails=self.dict_of_tails, 281 | dict_of_rels=self.dict_of_rels)) 282 | 283 | def get_mask(self, share, validation=False): 284 | """Returns masks to split knowledge graph into train, test and 285 | optionally validation sets. The mask is first created by dividing 286 | samples between subsets based on relation equilibrium. Then if any 287 | entity is not present in the training subset it is manually added by 288 | assigning a share of the sample involving the missing entity either 289 | as head or tail. 290 | 291 | Parameters 292 | ---------- 293 | share: float 294 | validation: bool 295 | 296 | Returns 297 | ------- 298 | mask: torch.Tensor, shape: (n), dtype: torch.bool 299 | mask_val: torch.Tensor, shape: (n), dtype: torch.bool (optional) 300 | mask_te: torch.Tensor, shape: (n), dtype: torch.bool 301 | """ 302 | 303 | uniques_r, counts_r = self.relations.unique(return_counts=True) 304 | uniques_e, _ = cat((self.head_idx, 305 | self.tail_idx)).unique(return_counts=True) 306 | 307 | mask = zeros_like(self.relations).bool() 308 | if validation: 309 | mask_val = zeros_like(self.relations).bool() 310 | 311 | # splitting relations among subsets 312 | for i, r in enumerate(uniques_r): 313 | rand = randperm(counts_r[i].item()) 314 | 315 | # list of indices k such that relations[k] == r 316 | sub_mask = eq(self.relations, r).nonzero(as_tuple=False)[:, 0] 317 | 318 | assert len(sub_mask) == counts_r[i].item() 319 | 320 | if validation: 321 | train_size, val_size, test_size = self.get_sizes(counts_r[i].item(), 322 | share=share, 323 | validation=True) 324 | mask[sub_mask[rand[:train_size]]] = True 325 | mask_val[sub_mask[rand[train_size:train_size + val_size]]] = True 326 | 327 | else: 328 | train_size, test_size = self.get_sizes(counts_r[i].item(), 329 | share=share, 330 | validation=False) 331 | mask[sub_mask[rand[:train_size]]] = True 332 | 333 | # adding missing entities to the train set 334 | u = cat((self.head_idx[mask], self.tail_idx[mask])).unique() 335 | if len(u) < self.n_ent: 336 | missing_entities = tensor(list(set(uniques_e.tolist()) - 337 | set(u.tolist())), dtype=long) 338 | for e in missing_entities: 339 | sub_mask = ((self.head_idx == e) | 340 | (self.tail_idx == e)).nonzero(as_tuple=False)[:, 0] 341 | rand = randperm(len(sub_mask)) 342 | sizes = self.get_sizes(mask.shape[0], 343 | share=share, 344 | validation=validation) 345 | mask[sub_mask[rand[:sizes[0]]]] = True 346 | if validation: 347 | mask_val[sub_mask[rand[:sizes[0]]]] = False 348 | 349 | if validation: 350 | assert not (mask & mask_val).any().item() 351 | return mask, mask_val, ~(mask | mask_val) 352 | else: 353 | return mask, ~mask 354 | 355 | @staticmethod 356 | def get_sizes(count, share, validation=False): 357 | """With `count` samples, returns how many 
358 |
359 |         """
360 |         if count == 1:
361 |             if validation:
362 |                 return 1, 0, 0
363 |             else:
364 |                 return 1, 0
365 |         if count == 2:
366 |             if validation:
367 |                 return 1, 1, 0
368 |             else:
369 |                 return 1, 1
370 |
371 |         n_train = int(count * share)
372 |         assert n_train < count
373 |         if n_train == 0:
374 |             n_train += 1  # the training set always gets at least one sample
375 |
376 |         if not validation:
377 |             return n_train, count - n_train
378 |         else:
379 |             if count - n_train == 1:
380 |                 n_train -= 1
381 |                 return n_train, 1, 1
382 |             else:
383 |                 n_val = int(int(count - n_train) / 2)  # half of the rest
384 |                 return n_train, n_val, count - n_train - n_val
385 |
386 |     def evaluate_dicts(self):
387 |         """Populates the dicts of alternative heads, tails and relations
388 |         that, substituted in a fact, still give a fact of the knowledge graph.
389 |
390 |         """
391 |         for i in range(self.n_facts):
392 |             self.dict_of_heads[(self.tail_idx[i].item(),
393 |                                 self.relations[i].item())].add(self.head_idx[i].item())
394 |             self.dict_of_tails[(self.head_idx[i].item(),
395 |                                 self.relations[i].item())].add(self.tail_idx[i].item())
396 |             self.dict_of_rels[(self.head_idx[i].item(),
397 |                                self.tail_idx[i].item())].add(self.relations[i].item())
398 |
399 |     def get_df(self):
400 |         """
401 |         Returns a Pandas DataFrame with columns ['from', 'to', 'rel'].
402 |         """
403 |         ix2ent = {v: k for k, v in self.ent2ix.items()}
404 |         ix2rel = {v: k for k, v in self.rel2ix.items()}
405 |
406 |         df = DataFrame(cat((self.head_idx.view(1, -1),
407 |                             self.tail_idx.view(1, -1),
408 |                             self.relations.view(1, -1))).transpose(0, 1).numpy(),
409 |                        columns=['from', 'to', 'rel'])
410 |
411 |         df['from'] = df['from'].apply(lambda x: ix2ent[x])
412 |         df['to'] = df['to'].apply(lambda x: ix2ent[x])
413 |         df['rel'] = df['rel'].apply(lambda x: ix2rel[x])
414 |
415 |         return df
416 |
417 |
418 | class SmallKG(Dataset):
419 |     """Minimalist version of a knowledge graph, built from tensors of head,
420 |     tail and relation indices.
421 |
422 |     """
423 |     def __init__(self, heads, tails, relations):
424 |         assert heads.shape == tails.shape == relations.shape
425 |         self.head_idx = heads
426 |         self.tail_idx = tails
427 |         self.relations = relations
428 |         self.length = heads.shape[0]
429 |
430 |     def __len__(self):
431 |         return self.length
432 |
433 |     def __getitem__(self, item):
434 |         return self.head_idx[item].item(), self.tail_idx[item].item(), self.relations[item].item()
435 |
--------------------------------------------------------------------------------
/torchkge/evaluation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Copyright TorchKGE developers
4 | @author: Armand Boschin
5 | """
6 |
7 | from torch import empty, zeros, cat
8 | from tqdm.autonotebook import tqdm
9 |
10 | from .data_structures import SmallKG
11 | from .exceptions import NotYetEvaluatedError
12 | from .sampling import PositionalNegativeSampler
13 | from .utils import DataLoader, get_rank, filter_scores
14 |
15 |
16 | class RelationPredictionEvaluator(object):
17 |     """Evaluate the performance of a given embedding using the relation prediction method.
18 |
19 |     References
20 |     ----------
21 |     * Armand Boschin, Thomas Bonald.
22 |       Enriching Wikidata with Semantified Wikipedia Hyperlinks.
23 |       In Proceedings of the Wikidata Workshop, ISWC 2021, 2021.
24 |       http://ceur-ws.org/Vol-2982/paper-6.pdf
25 |
26 |     Parameters
27 |     ----------
28 |     model: torchkge.models.interfaces.Model
29 |         Embedding model inheriting from the right interface.
30 |     knowledge_graph: torchkge.data_structures.KnowledgeGraph
31 |         Knowledge graph on which the evaluation will be done.
32 |     directed: bool, optional (default=True)
33 |         Indicates whether the orientation from head to tail is known when
34 |         predicting missing relations. If False, then both patterns (h, _, t)
35 |         and (t, _, h) are tried to find the best scoring triples.
36 |
37 |     Attributes
38 |     ----------
39 |     model: torchkge.models.interfaces.Model
40 |         Embedding model inheriting from the right interface.
41 |     kg: torchkge.data_structures.KnowledgeGraph
42 |         Knowledge graph on which the evaluation will be done.
43 |     rank_true_rels: torch.Tensor, shape: (n_facts), dtype: `torch.int`
44 |         For each fact, this is the rank of the true relation when all
45 |         relations are ranked in decreasing order of the scoring function
46 |         :math:`f_r(h,t)`.
47 |     filt_rank_true_rels: torch.Tensor, shape: (n_facts), dtype: `torch.int`
48 |         This is the same as `rank_true_rels` but in the filtered
49 |         case. See the referenced paper by Bordes et al. for more information.
50 |     evaluated: bool
51 |         Indicates if the method RelationPredictionEvaluator.evaluate has
52 |         already been called.
53 |     directed: bool
54 |         Indicates whether the orientation from head to tail is known when
55 |         predicting missing relations. If False, then both patterns (h, _, t)
56 |         and (t, _, h) are tried to find the best scoring triples.
57 |     """
58 |
59 |     def __init__(self, model, knowledge_graph, directed=True):
60 |         self.model = model
61 |         self.kg = knowledge_graph
62 |         self.directed = directed
63 |
64 |         self.rank_true_rels = empty(size=(knowledge_graph.n_facts,)).long()
65 |         self.filt_rank_true_rels = empty(size=(knowledge_graph.n_facts,)).long()
66 |
67 |         self.evaluated = False
68 |
69 |     def evaluate(self, b_size, verbose=True):
70 |         """Performs the relation prediction evaluation on the knowledge graph.
71 |
72 |         Parameters
73 |         ----------
74 |         b_size: int
75 |             Batch size used by the evaluation data loader.
76 |         verbose: bool
77 |             Indicates whether a progress bar should be displayed during
78 |             evaluation.
79 |
80 |         """
81 |         use_cuda = next(self.model.parameters()).is_cuda
82 |
83 |         if use_cuda:
84 |             dataloader = DataLoader(self.kg, batch_size=b_size, use_cuda='batch')
85 |             self.rank_true_rels = self.rank_true_rels.cuda()
86 |             self.filt_rank_true_rels = self.filt_rank_true_rels.cuda()
87 |         else:
88 |             dataloader = DataLoader(self.kg, batch_size=b_size)
89 |
90 |         for i, batch in tqdm(enumerate(dataloader), total=len(dataloader),
91 |                              unit='batch', disable=(not verbose),
92 |                              desc='Relation prediction evaluation'):
93 |             h_idx, t_idx, r_idx = batch[0], batch[1], batch[2]
94 |             h_emb, t_emb, r_emb, candidates = self.model.inference_prepare_candidates(h_idx, t_idx, r_idx, entities=False)
95 |
96 |             scores = self.model.inference_scoring_function(h_emb, t_emb, candidates)
97 |             filt_scores = filter_scores(scores, self.kg.dict_of_rels, h_idx, t_idx, r_idx)
98 |
99 |             if not self.directed:
100 |                 scores_bis = self.model.inference_scoring_function(t_emb, h_emb, candidates)
101 |                 filt_scores_bis = filter_scores(scores_bis, self.kg.dict_of_rels, h_idx, t_idx, r_idx)
102 |
103 |                 scores = cat((scores, scores_bis), dim=1)
104 |                 filt_scores = cat((filt_scores, filt_scores_bis), dim=1)
105 |
106 |             self.rank_true_rels[i * b_size: (i + 1) * b_size] = get_rank(scores, r_idx).detach()
107 |             self.filt_rank_true_rels[i * b_size: (i + 1) * b_size] = get_rank(filt_scores, r_idx).detach()
108 |
109 |         self.evaluated = True
110 |
111 |         if use_cuda:
112 |             self.rank_true_rels = self.rank_true_rels.cpu()
113 |             self.filt_rank_true_rels = self.filt_rank_true_rels.cpu()
114 |
115 |     def mean_rank(self):
116 |         """
117 |
118 |         Returns
119 |         -------
120 |         mean_rank: float
121 |             Mean rank of the true relation among all candidate relations,
122 |             averaged over the facts of the dataset.
123 |         filt_mean_rank: float
124 |             Filtered mean rank of the true relation among all candidate
125 |             relations, averaged over the facts of the dataset.
126 |
127 |         """
128 |         if not self.evaluated:
129 |             raise NotYetEvaluatedError('Evaluator not evaluated. Call '
130 |                                        'RelationPredictionEvaluator.evaluate')
131 |         sum_ = self.rank_true_rels.float().mean().item()
132 |         filt_sum = self.filt_rank_true_rels.float().mean().item()
133 |         return sum_, filt_sum
134 |
135 |     def hit_at_k(self, k=10):
136 |         """
137 |
138 |         Parameters
139 |         ----------
140 |         k: int
141 |             Hit@k is the share of facts for which the true relation is
142 |             ranked among the top k candidate relations.
143 |
144 |         Returns
145 |         -------
146 |         hitatk: float
147 |             Hit@k for relation prediction.
148 |         filt_hitatk: float
149 |             Filtered hit@k for relation prediction.
150 |
151 |         """
152 |         if not self.evaluated:
153 |             raise NotYetEvaluatedError('Evaluator not evaluated. Call '
154 |                                        'RelationPredictionEvaluator.evaluate')
155 |
156 |         return (self.rank_true_rels <= k).float().mean().item(), (self.filt_rank_true_rels <= k).float().mean().item()
157 |
158 |     def mrr(self):
159 |         """
160 |
161 |         Returns
162 |         -------
163 |         mrr: float
164 |             Mean reciprocal rank of the true relations.
165 |         filt_mrr: float
166 |             Filtered mean reciprocal rank of the true
167 |             relations.
168 |
169 |         """
170 |         if not self.evaluated:
171 |             raise NotYetEvaluatedError('Evaluator not evaluated. Call '
172 |                                        'RelationPredictionEvaluator.evaluate')
173 |         mrr = (self.rank_true_rels.float()**(-1)).mean()
174 |         filt_mrr = (self.filt_rank_true_rels.float()**(-1)).mean()
175 |
176 |         return mrr.item(), filt_mrr.item()
177 |
178 |     def print_results(self, k=None, n_digits=3):
179 |         """
180 |
181 |         Parameters
182 |         ----------
183 |         k: int or list
184 |             k (or list of k values) for which hit@k will be printed.
185 |         n_digits: int
186 |             Number of digits to be printed for hit@k and MRR.
187 |         """
188 |         if k is None:
189 |             k = 10
190 |
191 |         if isinstance(k, int):
192 |             print('Hit@{} : {} \t\t Filt. Hit@{} : {}'.format(
193 |                 k, round(self.hit_at_k(k=k)[0], n_digits),
194 |                 k, round(self.hit_at_k(k=k)[1], n_digits)))
195 |         elif isinstance(k, list):
196 |             for i in k:
197 |                 print('Hit@{} : {} \t\t Filt. Hit@{} : {}'.format(
198 |                     i, round(self.hit_at_k(k=i)[0], n_digits),
199 |                     i, round(self.hit_at_k(k=i)[1], n_digits)))
200 |
201 |         print('Mean Rank : {} \t Filt. Mean Rank : {}'.format(
202 |             int(self.mean_rank()[0]), int(self.mean_rank()[1])))
203 |         print('MRR : {} \t\t Filt. MRR : {}'.format(
204 |             round(self.mrr()[0], n_digits), round(self.mrr()[1], n_digits)))
205 |
206 |
207 | class LinkPredictionEvaluator(object):
208 |     """Evaluate the performance of a given embedding using the link prediction method.
209 |
210 |     References
211 |     ----------
212 |     * Antoine Bordes, Nicolas Usunier, Alberto Garcia-Duran, Jason Weston,
213 |       and Oksana Yakhnenko.
214 |       Translating Embeddings for Modeling Multi-relational Data.
215 |       In Advances in Neural Information Processing Systems 26, pages 2787–2795,
216 |       2013.
217 |       https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data
218 |
219 |     Parameters
220 |     ----------
221 |     model: torchkge.models.interfaces.Model
222 |         Embedding model inheriting from the right interface.
223 |     knowledge_graph: torchkge.data_structures.KnowledgeGraph
224 |         Knowledge graph on which the evaluation will be done.
225 |
226 |     Attributes
227 |     ----------
228 |     model: torchkge.models.interfaces.Model
229 |         Embedding model inheriting from the right interface.
230 |     kg: torchkge.data_structures.KnowledgeGraph
231 |         Knowledge graph on which the evaluation will be done.
232 |     rank_true_heads: torch.Tensor, shape: (n_facts), dtype: `torch.int`
233 |         For each fact, this is the rank of the true head when all entities
234 |         are ranked as possible replacement of the head entity. They are
235 |         ranked in decreasing order of scoring function :math:`f_r(h,t)`.
236 |     rank_true_tails: torch.Tensor, shape: (n_facts), dtype: `torch.int`
237 |         For each fact, this is the rank of the true tail when all entities
238 |         are ranked as possible replacement of the tail entity. They are
239 |         ranked in decreasing order of scoring function :math:`f_r(h,t)`.
240 |     filt_rank_true_heads: torch.Tensor, shape: (n_facts), dtype: `torch.int`
241 |         This is the same as `rank_true_heads` but in the filtered
242 |         case. See the referenced paper by Bordes et al. for more information.
243 |     filt_rank_true_tails: torch.Tensor, shape: (n_facts), dtype: `torch.int`
244 |         This is the same as `rank_true_tails` but in the filtered
245 |         case. See the referenced paper by Bordes et al. for more information.
246 |     evaluated: bool
247 |         Indicates if the method LinkPredictionEvaluator.evaluate has already
248 |         been called.
249 |
250 |     """
251 |
252 |     def __init__(self, model, knowledge_graph):
253 |         self.model = model
254 |         self.kg = knowledge_graph
255 |
256 |         self.rank_true_heads = empty(size=(knowledge_graph.n_facts,)).long()
257 |         self.rank_true_tails = empty(size=(knowledge_graph.n_facts,)).long()
258 |         self.filt_rank_true_heads = empty(size=(knowledge_graph.n_facts,)).long()
259 |         self.filt_rank_true_tails = empty(size=(knowledge_graph.n_facts,)).long()
260 |
261 |         self.evaluated = False
262 |
263 |     def evaluate(self, b_size, verbose=True):
264 |         """Performs the link prediction evaluation on the knowledge graph.
265 |
266 |         Parameters
267 |         ----------
268 |         b_size: int
269 |             Batch size used by the evaluation data loader.
270 |         verbose: bool
271 |             Indicates whether a progress bar should be displayed during
272 |             evaluation.
273 |
274 |         """
275 |         use_cuda = next(self.model.parameters()).is_cuda
276 |
277 |         if use_cuda:
278 |             dataloader = DataLoader(self.kg, batch_size=b_size, use_cuda='batch')
279 |             self.rank_true_heads = self.rank_true_heads.cuda()
280 |             self.rank_true_tails = self.rank_true_tails.cuda()
281 |             self.filt_rank_true_heads = self.filt_rank_true_heads.cuda()
282 |             self.filt_rank_true_tails = self.filt_rank_true_tails.cuda()
283 |         else:
284 |             dataloader = DataLoader(self.kg, batch_size=b_size)
285 |
286 |         for i, batch in tqdm(enumerate(dataloader), total=len(dataloader),
287 |                              unit='batch', disable=(not verbose),
288 |                              desc='Link prediction evaluation'):
289 |             h_idx, t_idx, r_idx = batch[0], batch[1], batch[2]
290 |             h_emb, t_emb, r_emb, candidates = self.model.inference_prepare_candidates(h_idx, t_idx, r_idx, entities=True)
291 |
292 |             scores = self.model.inference_scoring_function(h_emb, candidates, r_emb)
293 |             filt_scores = filter_scores(scores, self.kg.dict_of_tails, h_idx, r_idx, t_idx)
294 |             self.rank_true_tails[i * b_size: (i + 1) * b_size] = get_rank(scores, t_idx).detach()
295 |             self.filt_rank_true_tails[i * b_size: (i + 1) * b_size] = get_rank(filt_scores, t_idx).detach()
296 |
297 |             scores = self.model.inference_scoring_function(candidates, t_emb, r_emb)
298 |             filt_scores = filter_scores(scores, self.kg.dict_of_heads, t_idx, r_idx, h_idx)
299 |             self.rank_true_heads[i * b_size: (i + 1) * b_size] = get_rank(scores, h_idx).detach()
300 |             self.filt_rank_true_heads[i * b_size: (i + 1) * b_size] = get_rank(filt_scores, h_idx).detach()
301 |
302 |         self.evaluated = True
303 |
304 |         if use_cuda:
305 |             self.rank_true_heads = self.rank_true_heads.cpu()
306 |             self.rank_true_tails = self.rank_true_tails.cpu()
307 |             self.filt_rank_true_heads = self.filt_rank_true_heads.cpu()
308 |             self.filt_rank_true_tails = self.filt_rank_true_tails.cpu()
309 |
310 |     def mean_rank(self):
311 |         """
312 |
313 |         Returns
314 |         -------
315 |         mean_rank: float
316 |             Mean rank of the true entity when alternately replacing the head
317 |             and the tail in each fact of the dataset.
318 |         filt_mean_rank: float
319 |             Filtered mean rank of the true entity when alternately replacing
320 |             the head and the tail in each fact of the dataset.
321 |
322 |         """
323 |         if not self.evaluated:
324 |             raise NotYetEvaluatedError('Evaluator not evaluated. Call '
325 |                                        'LinkPredictionEvaluator.evaluate')
326 |         sum_ = (self.rank_true_heads.float().mean() +
327 |                 self.rank_true_tails.float().mean()).item()
328 |         filt_sum = (self.filt_rank_true_heads.float().mean() +
329 |                     self.filt_rank_true_tails.float().mean()).item()
330 |         return sum_ / 2, filt_sum / 2
331 |
332 |     def hit_at_k_heads(self, k=10):
333 |         if not self.evaluated:
334 |             raise NotYetEvaluatedError('Evaluator not evaluated. Call '
335 |                                        'LinkPredictionEvaluator.evaluate')
336 |         head_hit = (self.rank_true_heads <= k).float().mean()
337 |         filt_head_hit = (self.filt_rank_true_heads <= k).float().mean()
338 |
339 |         return head_hit.item(), filt_head_hit.item()
340 |
341 |     def hit_at_k_tails(self, k=10):
342 |         if not self.evaluated:
343 |             raise NotYetEvaluatedError('Evaluator not evaluated. Call '
344 |                                        'LinkPredictionEvaluator.evaluate')
345 |         tail_hit = (self.rank_true_tails <= k).float().mean()
346 |         filt_tail_hit = (self.filt_rank_true_tails <= k).float().mean()
347 |
348 |         return tail_hit.item(), filt_tail_hit.item()
349 |
350 |     def hit_at_k(self, k=10):
351 |         """
352 |
353 |         Parameters
354 |         ----------
355 |         k: int
356 |             Hit@k is the share of facts for which the true entity is ranked
357 |             among the top k candidates.
358 |
359 |         Returns
360 |         -------
361 |         avg_hitatk: float
362 |             Average of hit@k for head and tail replacement.
363 |         filt_avg_hitatk: float
364 |             Filtered average of hit@k for head and tail replacement.
365 |
366 |         """
367 |         if not self.evaluated:
368 |             raise NotYetEvaluatedError('Evaluator not evaluated. Call '
369 |                                        'LinkPredictionEvaluator.evaluate')
370 |
371 |         head_hit, filt_head_hit = self.hit_at_k_heads(k=k)
372 |         tail_hit, filt_tail_hit = self.hit_at_k_tails(k=k)
373 |
374 |         return (head_hit + tail_hit) / 2, (filt_head_hit + filt_tail_hit) / 2
375 |
376 |     def mrr(self):
377 |         """
378 |
379 |         Returns
380 |         -------
381 |         avg_mrr: float
382 |             Average of the mean reciprocal rank for head and tail replacement.
383 |         filt_avg_mrr: float
384 |             Filtered average of the mean reciprocal rank for head and tail
385 |             replacement.
386 |
387 |         """
388 |         if not self.evaluated:
389 |             raise NotYetEvaluatedError('Evaluator not evaluated. Call '
390 |                                        'LinkPredictionEvaluator.evaluate')
391 |         head_mrr = (self.rank_true_heads.float()**(-1)).mean()
392 |         tail_mrr = (self.rank_true_tails.float()**(-1)).mean()
393 |         filt_head_mrr = (self.filt_rank_true_heads.float()**(-1)).mean()
394 |         filt_tail_mrr = (self.filt_rank_true_tails.float()**(-1)).mean()
395 |
396 |         return ((head_mrr + tail_mrr).item() / 2,
397 |                 (filt_head_mrr + filt_tail_mrr).item() / 2)
398 |
399 |     def print_results(self, k=None, n_digits=3):
400 |         """
401 |
402 |         Parameters
403 |         ----------
404 |         k: int or list
405 |             k (or list of k values) for which hit@k will be printed.
406 |         n_digits: int
407 |             Number of digits to be printed for hit@k and MRR.
408 |         """
409 |         if k is None:
410 |             k = 10
411 |
412 |         if isinstance(k, int):
413 |             print('Hit@{} : {} \t\t Filt. Hit@{} : {}'.format(
414 |                 k, round(self.hit_at_k(k=k)[0], n_digits),
415 |                 k, round(self.hit_at_k(k=k)[1], n_digits)))
416 |         elif isinstance(k, list):
417 |             for i in k:
418 |                 print('Hit@{} : {} \t\t Filt. Hit@{} : {}'.format(
419 |                     i, round(self.hit_at_k(k=i)[0], n_digits),
420 |                     i, round(self.hit_at_k(k=i)[1], n_digits)))
421 |
422 |         print('Mean Rank : {} \t Filt. Mean Rank : {}'.format(
423 |             int(self.mean_rank()[0]), int(self.mean_rank()[1])))
424 |         print('MRR : {} \t\t Filt. MRR : {}'.format(
425 |             round(self.mrr()[0], n_digits), round(self.mrr()[1], n_digits)))
426 |
427 |
428 | class TripletClassificationEvaluator(object):
429 |     """Evaluate the performance of a given embedding using the triplet
430 |     classification method.
431 |
432 |     References
433 |     ----------
434 |     * Richard Socher, Danqi Chen, Christopher D Manning, and Andrew Ng.
435 |       Reasoning With Neural Tensor Networks for Knowledge Base Completion.
436 |       In Advances in Neural Information Processing Systems 26, pages 926–934.
437 |       2013.
438 |       https://nlp.stanford.edu/pubs/SocherChenManningNg_NIPS2013.pdf
439 |
440 |     Parameters
441 |     ----------
442 |     model: torchkge.models.interfaces.Model
443 |         Embedding model inheriting from the right interface.
444 |     kg_val: torchkge.data_structures.KnowledgeGraph
445 |         Knowledge graph on which the validation thresholds will be computed.
446 |     kg_test: torchkge.data_structures.KnowledgeGraph
447 |         Knowledge graph on which the testing evaluation will be done.
448 |
449 |     Attributes
450 |     ----------
451 |     model: torchkge.models.interfaces.Model
452 |         Embedding model inheriting from the right interface.
453 |     kg_val: torchkge.data_structures.KnowledgeGraph
454 |         Knowledge graph on which the validation thresholds will be computed.
455 |     kg_test: torchkge.data_structures.KnowledgeGraph
456 |         Knowledge graph on which the evaluation will be done.
457 |     evaluated: bool
458 |         Indicates whether the `evaluate` method has been called.
459 |     thresholds: torch.Tensor, shape: (n_rel)
460 |         Per-relation thresholds above which the scoring function considers
461 |         a triplet as true. They are defined by calling the `evaluate` method.
462 |     sampler: torchkge.sampling.NegativeSampler
463 |         Negative sampler.
464 |
465 |     """
466 |
467 |     def __init__(self, model, kg_val, kg_test):
468 |         self.model = model
469 |         self.kg_val = kg_val
470 |         self.kg_test = kg_test
471 |         self.is_cuda = next(self.model.parameters()).is_cuda
472 |
473 |         self.evaluated = False
474 |         self.thresholds = None
475 |
476 |         self.sampler = PositionalNegativeSampler(self.kg_val,
477 |                                                  kg_test=self.kg_test)
478 |
479 |     def get_scores(self, heads, tails, relations, batch_size):
480 |         """With head, tail and relation indices, compute the value of the
481 |         scoring function of the model.
482 |
483 |         Parameters
484 |         ----------
485 |         heads: torch.Tensor, dtype: torch.long, shape: n_facts
486 |             List of head indices.
487 |         tails: torch.Tensor, dtype: torch.long, shape: n_facts
488 |             List of tail indices.
489 |         relations: torch.Tensor, dtype: torch.long, shape: n_facts
490 |             List of relation indices.
491 |         batch_size: int
492 |
493 |         Returns
494 |         -------
495 |         scores: torch.Tensor, dtype: torch.float, shape: n_facts
496 |             List of scores of each triplet.
497 |         """
498 |         scores = []
499 |
500 |         small_kg = SmallKG(heads, tails, relations)
501 |         if self.is_cuda:
502 |             dataloader = DataLoader(small_kg, batch_size=batch_size,
503 |                                     use_cuda='batch')
504 |         else:
505 |             dataloader = DataLoader(small_kg, batch_size=batch_size)
506 |
507 |         for i, batch in enumerate(dataloader):
508 |             h_idx, t_idx, r_idx = batch[0], batch[1], batch[2]
509 |             scores.append(self.model.scoring_function(h_idx, t_idx, r_idx))
510 |
511 |         return cat(scores, dim=0)
512 |
513 |     def evaluate(self, b_size):
514 |         """Find relation thresholds using the validation set. As described in
515 |         the paper by Socher et al., for a relation, the threshold is a value
516 |         t such that if the score of a triplet is larger than t, the fact is
517 |         true. If a relation is not present in any fact of the validation set,
518 |         the largest score among all negative samples is used as the threshold.
519 |
520 |         Parameters
521 |         ----------
522 |         b_size: int
523 |             Batch size.
524 |         """
525 |         r_idx = self.kg_val.relations
526 |
527 |         neg_heads, neg_tails = self.sampler.corrupt_kg(b_size, self.is_cuda,
528 |                                                        which='main')
529 |         neg_scores = self.get_scores(neg_heads, neg_tails, r_idx, b_size)
530 |
531 |         self.thresholds = zeros(self.kg_val.n_rel)
532 |
533 |         for i in range(self.kg_val.n_rel):
534 |             mask = (r_idx == i).bool()
535 |             if mask.sum() > 0:
536 |                 self.thresholds[i] = neg_scores[mask].max()
537 |             else:
538 |                 self.thresholds[i] = neg_scores.max()
539 |
540 |         self.evaluated = True
541 |         self.thresholds.detach_()
542 |
543 |     def accuracy(self, b_size):
544 |         """
545 |
546 |         Parameters
547 |         ----------
548 |         b_size: int
549 |             Batch size.
550 |
551 |         Returns
552 |         -------
553 |         acc: float
554 |             Share of all triplets (true and negatively sampled ones) that
555 |             were correctly classified using the thresholds learned from the
556 |             validation set.
557 |
558 |         """
559 |         if not self.evaluated:
560 |             self.evaluate(b_size)
561 |
562 |         r_idx = self.kg_test.relations
563 |
564 |         neg_heads, neg_tails = self.sampler.corrupt_kg(b_size,
565 |                                                        self.is_cuda,
566 |                                                        which='test')
567 |         scores = self.get_scores(self.kg_test.head_idx,
568 |                                  self.kg_test.tail_idx,
569 |                                  r_idx,
570 |                                  b_size)
571 |         neg_scores = self.get_scores(neg_heads, neg_tails, r_idx, b_size)
572 |
573 |         if self.is_cuda:
574 |             self.thresholds = self.thresholds.cuda()
575 |
576 |         scores = (scores > self.thresholds[r_idx])  # true triplets scoring above their relation threshold
577 |         neg_scores = (neg_scores < self.thresholds[r_idx])  # negatives scoring below it
578 |
579 |         return (scores.sum().item() +
580 |                 neg_scores.sum().item()) / (2 * self.kg_test.n_facts)
581 |
--------------------------------------------------------------------------------
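The split arithmetic implemented by the static method KnowledgeGraph.get_sizes in torchkge/data_structures.py above can be checked directly. A minimal sketch: with `count` facts for a relation and a training `share`, int(count * share) facts go to the training set (never zero), and in the validation case the remaining facts are split evenly between validation and test. The calls below use only the static method as defined above.

    from torchkge.data_structures import KnowledgeGraph

    # 10 facts with share=0.8: int(10 * 0.8) = 8 facts go to the training set.
    print(KnowledgeGraph.get_sizes(10, share=0.8))                   # (8, 2)
    # With a validation set, the 2 remaining facts are split evenly.
    print(KnowledgeGraph.get_sizes(10, share=0.8, validation=True))  # (8, 1, 1)
    # Edge case: a relation occurring in a single fact goes to training only.
    print(KnowledgeGraph.get_sizes(1, share=0.8, validation=True))   # (1, 0, 0)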
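A minimal usage sketch of the evaluators defined in torchkge/evaluation.py above, assuming a TransE model that has already been trained (see docs/tutorials/training.rst) and the FB15k loader shipped in torchkge.utils.datasets; the evaluator calls match the signatures defined in the file.

    from torchkge.evaluation import (LinkPredictionEvaluator,
                                     TripletClassificationEvaluator)
    from torchkge.models import TransEModel
    from torchkge.utils.datasets import load_fb15k

    kg_train, kg_val, kg_test = load_fb15k()
    model = TransEModel(100, kg_train.n_ent, kg_train.n_rel,
                        dissimilarity_type='L2')
    # ... train `model` here before evaluating ...

    # Link prediction: rank the true head and tail against all entities.
    lp_evaluator = LinkPredictionEvaluator(model, kg_test)
    lp_evaluator.evaluate(b_size=256)
    lp_evaluator.print_results(k=[1, 3, 10])

    # Triplet classification: thresholds fit on kg_val, accuracy on kg_test.
    tc_evaluator = TripletClassificationEvaluator(model, kg_val, kg_test)
    print('Accuracy: {}'.format(tc_evaluator.accuracy(b_size=256)))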