├── .gitignore
├── LICENSE
├── README.md
├── examples
│   ├── DistMult
│   │   └── DistMult-FB15k.sh
│   ├── RESCAL
│   │   └── RESCAL-FB15k.sh
│   ├── TransE
│   │   └── TransE-FB15k.sh
│   ├── TransH
│   │   └── TransH-FB15k.sh
│   └── TransR
│       └── TransR-FB15k.sh
├── images
│   └── structure.png
├── krl
│   ├── __init__.py
│   ├── base_model.py
│   ├── config.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── instance
│   │   │   ├── __init__.py
│   │   │   ├── huggingface_krl_datasets_conf.json
│   │   │   ├── index.py
│   │   │   └── utils.py
│   │   └── krl_dataset.py
│   ├── evaluator.py
│   ├── lit_model
│   │   ├── TransXLitModel.py
│   │   └── __init__.py
│   ├── metric.py
│   ├── metric_fomatter.py
│   ├── models
│   │   ├── DistMult.py
│   │   ├── RESCAL.py
│   │   ├── TransD.py
│   │   ├── TransE.py
│   │   ├── TransH.py
│   │   ├── TransR.py
│   │   └── __init__.py
│   ├── negative_sampler.py
│   ├── serializer.py
│   ├── storage.py
│   ├── trainer.py
│   ├── typer_apps
│   │   ├── __init__.py
│   │   ├── distmult.py
│   │   ├── rescal.py
│   │   ├── transe.py
│   │   ├── transh.py
│   │   └── transr.py
│   └── utils
│       ├── __init__.py
│       ├── data.py
│       ├── device.py
│       ├── logs_dir.py
│       ├── optim.py
│       └── seed.py
├── requirements.txt
├── test.ipynb
├── transe.ipynb
└── typer_app.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | # Run my scripts 131 | my-scripts/ 132 | nohup.out 133 | 134 | # pytorch lightning 135 | lightning_logs/ 136 | ./test.ipynb 137 | 138 | # model checkpoint 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 yubin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KRL 2 | 3 | This is a framework that uses PyTorch to reproduce Knowledge Representation Learning (KRL) models. With it, you can easily run KRL models, and you can quickly implement your own ideas by making simple changes or extensions. 4 | 5 | Currently, we have implemented several knowledge representation learning models, including **TransE** and **RESCAL**, along with simple abstractions that decouple the publicly reusable code. 6 | 7 | Don't use this repo yet: it is still under active development and will suffer breaking changes in the future. 8 | 9 | The overview structure of this library is below: 10 | 11 |

12 | <div align=center><img src="./images/structure.png"/></div> 13 |

14 | 15 | 16 | ## How to use it? 17 | 18 | Don't use it yet! The documentation has not been updated; I will write it soon. 19 | 20 | If you just want to use the models directly, then **you just need to run the sample scripts located in the directory** `./examples`, and all you need to do is change the few parameters that you want to change. The calls to the training code are wrapped with [Typer](https://typer.tiangolo.com/), which is a great tool for building CLIs, so each model can be trained from the command line. 21 | 22 | Example: 23 | 24 | ```shell 25 | cd ./examples/TransE 26 | sh TransE-FB15k.sh 27 | ``` 28 | 29 | If you are trying to make some changes or innovations, then you may need to briefly understand the logic of how this library works. Fortunately, the code is cleanly wrapped and decoupled, which makes it easy for you to understand the logic of the entire program and make changes to parts of it. 30 | 31 | First, run through the `transe.ipynb` notebook in the project root directory and you will get to know the core logic of the library. Successfully running through this notebook is therefore the first step to using the project. 32 | This notebook replicates the operation of the **TransE** model. 33 | 34 | In the real library, the running logic of a model still differs somewhat from the logic in the notebook: to allow module code to be reused, the library further abstracts and encapsulates parts of the code, but this does not change the core idea of how a model runs. 35 | 36 | 37 | + Notice: Before you run a script, you should download the dataset, such as FB15k, and modify the script to point to your own dataset and checkpoint paths. 38 | 39 | The example `./transe.ipynb` is a good tutorial for reproducing TransE if you want to know the structure of this repo. This tutorial can be run without any dependencies on the rest of the repo, except for common third-party libraries like PyTorch and NumPy. 40 | 41 | ## Dataset 42 | 43 | In this library, the different types of datasets are uniformly encapsulated in the `KRLDataset` class, but you need to download the data yourself and specify the disk path of the dataset when generating the configuration instance. 44 | 45 | You can download the datasets here: [KGDatasets](https://github.com/ZhenfengLei/KGDatasets). 46 | 47 | 48 | ## Plan 49 | 50 | | Status | Model | Year | Paper | Remarks | 51 | | :----: | :----: | :----: | :--- | --- | 52 | | :heavy_check_mark: | [RESCAL](/krl/models/RESCAL.py) | 2011 | ICML'11, [OpenReview](https://openreview.net/forum?id=H14QEiZ_WS) | | 53 | | :heavy_check_mark: | [TransE](/krl/models/TransE.py) | 2013 | NIPS'13, [ACM](http://dl.acm.org/doi/10.5555/2999792.2999923) | | 54 | | :heavy_check_mark: | [TransH](/krl/models/TransH.py) | 2014 | AAAI'14, [ResearchGate](https://www.researchgate.net/publication/319207032_Knowledge_Graph_Embedding_by_Translating_on_Hyperplanes) | | 55 | | :heavy_check_mark: | [DistMult](/krl/models/DistMult.py) | 2014 | ICLR'15, [arXiv](http://arxiv.org/abs/1412.6575) | | 56 | | :heavy_check_mark: | [TransR](/krl/models/TransR.py) | 2015 | AAAI'15, [AAAI](https://ojs.aaai.org/index.php/AAAI/article/view/9491) | performance is low, but I don't know why.
| 57 | | :white_circle: | TransD | 2015 | ACL-IJCNLP 2015, [ACL Anthology](https://aclanthology.org/P15-1067) | | 58 | | :white_circle: | TransA | 2015 | [arXiv](https://arxiv.org/abs/1509.05490) | | 59 | | :white_circle: | TransG | 2015 | [arXiv](https://arxiv.org/abs/1509.05488) | | 60 | | :white_circle: | KG2E | 2015 | CIKM'15, [ACM](https://dl.acm.org/doi/10.1145/2806416.2806502) | | 61 | | :white_circle: | TranSparse | 2016 | AAAI'16, [AAAI](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/11982) | | 62 | | :white_circle: | TransF | 2016 | AAAI'16, [AAAI](https://www.aaai.org/ocs/index.php/KR/KR16/paper/view/12887) | | 63 | | :white_circle: | ComplEx | 2016 | ICML'16, [arXiv](http://arxiv.org/abs/1606.06357) | | 64 | | :white_circle: | HolE | 2016 | AAAI'16, [arXiv](http://arxiv.org/abs/1510.04935) | | 65 | | :white_circle: | R-GCN | 2017 | ESWC'18, [arXiv](http://arxiv.org/abs/1703.06103) | | 66 | | :white_circle: | ConvKB | 2018 | NAACL-HLT 2018, [arXiv](http://arxiv.org/abs/1712.02121) | | 67 | | :white_circle: | ConvE | 2018 | AAAI'18, [arXiv](http://arxiv.org/abs/1707.01476) | | 68 | | :white_circle: | SimplE | 2018 | NIPS'18, [arXiv](http://arxiv.org/abs/1802.04868) | | 69 | | :white_circle: | RotatE | 2019 | ICLR'19, [arXiv](http://arxiv.org/abs/1902.10197) | | 70 | | :white_circle: | QuatE | 2019 | NeurIPS'19, [arXiv](http://arxiv.org/abs/1904.10281) | | 71 | | :white_circle: | ConvR | 2019 | NAACL-HLT 2019, [ACL Anthology](https://aclanthology.org/N19-1103) | | 72 | | :white_circle: | KG-BERT | 2019 | [arXiv](http://arxiv.org/abs/1909.03193) | | 73 | | :white_circle: | PairRE | 2021 | ACL-IJCNLP 2021, [ACL Anthology](https://aclanthology.org/2021.acl-long.336) | | 74 | 75 | 76 | ## How to contribute? 77 | 78 | If you have read the code of this project, you should already understand the code style of this library. All you need to do is extend it as you see fit and submit a pull request after testing it successfully. 79 | -------------------------------------------------------------------------------- /examples/DistMult/DistMult-FB15k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. # /examples/ 4 | cd ../krl # /krl 5 | python typer_app.py DistMult train --dataset-name "FB15k"\ 6 | --base-dir /root/yubin/dataset/KRL/master/FB15k \ 7 | --batch-size 128 \ 8 | --valid-batch-size 32 \ 9 | --valid-freq 3 \ 10 | --lr 0.001 \ 11 | --epoch-size 500 \ 12 | --embed-dim 50 \ 13 | --alpha 0.001 \ 14 | --regul-type F2 \ 15 | --ckpt-path /root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/distmult_fb15k.ckpt \ 16 | --metric-result-path /root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/distmult_fb15k_metrics.txt 17 | -------------------------------------------------------------------------------- /examples/RESCAL/RESCAL-FB15k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ..
# /examples/ 4 | cd ../krl # /krl 5 | python typer_app.py RESCAL train --dataset-name "FB15k"\ 6 | --base-dir /root/yubin/dataset/KRL/master/FB15k \ 7 | --batch-size 128 \ 8 | --valid-batch-size 32 \ 9 | --valid-freq 3 \ 10 | --lr 0.001 \ 11 | --epoch-size 500 \ 12 | --embed-dim 50 \ 13 | --alpha 0.001 \ 14 | --regul-type F2 \ 15 | --ckpt-path /root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/rescal_fb15k.ckpt \ 16 | --metric-result-path /root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/rescal_fb15k_metrics.txt 17 | -------------------------------------------------------------------------------- /examples/TransE/TransE-FB15k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. # /examples 4 | cd ../krl # /krl 5 | python typer_app.py TransE train --dataset-name "FB15k"\ 6 | --base-dir /root/yubin/dataset/KRL/master/FB15k \ 7 | --batch-size 128 \ 8 | --valid-batch-size 64 \ 9 | --valid-freq 5 \ 10 | --lr 0.001 \ 11 | --epoch-size 500 \ 12 | --embed-dim 50 \ 13 | --norm 1 \ 14 | --margin 2.0 \ 15 | --ckpt-path /root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transe_fb15k.ckpt \ 16 | --metric-result-path /root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transe_fb15k_metrics.txt 17 | -------------------------------------------------------------------------------- /examples/TransH/TransH-FB15k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. # /examples/ 4 | cd ../krl # /krl 5 | python typer_app.py TransH train --dataset-name "FB15k"\ 6 | --base-dir /root/yubin/dataset/KRL/master/FB15k \ 7 | --batch-size 4800 \ 8 | --valid-batch-size 64 \ 9 | --valid-freq 5 \ 10 | --lr 0.001 \ 11 | --epoch-size 10 \ 12 | --embed-dim 100 \ 13 | --norm 1 \ 14 | --margin 2.0 \ 15 | --c 1.0 \ 16 | --ckpt-path /root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transh_fb15k.ckpt \ 17 | --metric-result-path /root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transh_fb15k_metrics.txt 18 | -------------------------------------------------------------------------------- /examples/TransR/TransR-FB15k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 
# /examples/ 4 | cd ../krl # /krl 5 | python typer_app.py TransR train --dataset-name "FB15k"\ 6 | --base-dir /root/yubin/dataset/KRL/master/FB15k \ 7 | --batch-size 4800 \ 8 | --valid-batch-size 32 \ 9 | --valid-freq 25 \ 10 | --lr 0.001 \ 11 | --epoch-size 400 \ 12 | --embed-dim 50 \ 13 | --norm 1 \ 14 | --margin 1.0 \ 15 | --c 0.1 \ 16 | --ckpt-path /root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transr_fb15k.ckpt \ 17 | --metric-result-path /root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transr_fb15k_metrics.txt 18 | -------------------------------------------------------------------------------- /images/structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yubinCloud/kg2vec/c1e291031530933e5a381c2bdbf1f17e655e041e/images/structure.png -------------------------------------------------------------------------------- /krl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yubinCloud/kg2vec/c1e291031530933e5a381c2bdbf1f17e655e041e/krl/__init__.py -------------------------------------------------------------------------------- /krl/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for various KRL models. 3 | """ 4 | 5 | import torch.nn as nn 6 | import torch 7 | from abc import ABC, abstractmethod 8 | 9 | from .dataset import KRLDatasetDict 10 | from .config import TrainConf, HyperParam 11 | 12 | 13 | class KRLModel(nn.Module, ABC): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | @abstractmethod 18 | def embed(self, triples): 19 | """get the embeddings of the triples 20 | 21 | :param triples: a batch of triples, which consists of (heads, rels, tails) id. 22 | :return: the embedding of (heads, rels, tails) 23 | """ 24 | pass 25 | 26 | @abstractmethod 27 | def loss(self): 28 | """Return model losses based on the input. 29 | """ 30 | pass 31 | 32 | @abstractmethod 33 | def forward(self): 34 | """Return model losses based on the input. 35 | """ 36 | pass 37 | 38 | @abstractmethod 39 | def predict(self, triples: torch.Tensor): 40 | """Calculate the dissimilarity scores for the given triples. 41 | 42 | :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail) 43 | :return: dissimilarity scores for the given triples 44 | """ 45 | pass 46 | 47 | 48 | class TransXBaseModel(KRLModel): 49 | pass 50 | 51 | 52 | class ModelMain(ABC): 53 | """Act as a main function for a KRL model.
54 | """ 55 | def __init__(self) -> None: 56 | super().__init__() 57 | 58 | @abstractmethod 59 | def __call__(self): 60 | pass 61 | 62 | 63 | class LitModelMain(ModelMain): 64 | def __init__( 65 | self, 66 | dataset: KRLDatasetDict, 67 | train_conf: TrainConf, 68 | seed: int = None, 69 | ) -> None: 70 | super().__init__() 71 | self.datasets = dataset 72 | self.dataset_conf = dataset.dataset_conf 73 | self.train_conf = train_conf 74 | self.seed = seed 75 | self.params = None 76 | 77 | @abstractmethod 78 | def __call__(self): 79 | pass -------------------------------------------------------------------------------- /krl/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Settings classes used to fill in configuration information. 3 | """ 4 | 5 | from pydantic import BaseModel, Field 6 | from abc import ABC 7 | from pathlib import Path 8 | from typing import Optional, Mapping 9 | 10 | 11 | ######## Dataset config ######## 12 | 13 | 14 | class DatasetConf(BaseModel, ABC): 15 | dataset_name: str = Field(title='name of the dataset, handy when printing') 16 | 17 | 18 | class BuiletinDatasetConf(DatasetConf): 19 | """ 20 | Configuration of a built-in dataset 21 | """ 22 | source: str = Field(title='source of the dataset') 23 | 24 | 25 | class BuiletinHuggingfaceDatasetConf(BuiletinDatasetConf): 26 | """ 27 | A dataset hosted on Hugging Face datasets 28 | 29 | """ 30 | huggingface_repo: str = Field(title='the Hugging Face repo that hosts this dataset') 31 | 32 | 33 | class LocalDatasetConf(DatasetConf): 34 | """ 35 | Configuration of a dataset stored on the local disk 36 | """ 37 | base_dir: Optional[Path] = Field(title='directory of the dataset') 38 | entity2id_path: Optional[str] = Field(default='entity2id.txt', title='file name of entity2id') 39 | relation2id_path: Optional[str] = Field(default='relation2id.txt', title='file name of relation2id') 40 | train_path: Optional[str] = Field(default='train.txt', title='file of the training set') 41 | valid_path: Optional[str] = Field(default='valid.txt', title='file of the valid set') 42 | test_path: Optional[str] = Field(default='test.txt', title='file of the testing set') 43 | 44 | 45 | ######## Dataset meta-data ######## 46 | 47 | class KRLDatasetMeta(BaseModel): 48 | entity2id: Mapping[str, int] 49 | rel2id: Mapping[str, int] 50 | 51 | 52 | ######## Hyper-parameters config ######## 53 | 54 | 55 | class HyperParam(BaseModel, ABC): 56 | """ 57 | Hyper-parameters; every hyper-parameter config class should inherit from this class 58 | """ 59 | batch_size: int = 128 60 | valid_batch_size: int = 64 61 | learning_rate: float = 0.001 62 | optimizer: str = Field(default='adam', title='optimizer name') 63 | epoch_size: int = 500 64 | valid_freq: int = Field(default=5, title='how many epochs to train between two validations, which decide whether to save the model') 65 | num_works: int = Field(default=64, title='`num_works` in dataloader') 66 | early_stoping_patience: int = Field(default=5, title='the patience of EarlyStopping') 67 | 68 | 69 | 70 | ######## Training config ######## 71 | 72 | 73 | class TrainConf(BaseModel): 74 | """ 75 | Training configuration 76 | """ 77 | logs_dir: Path = Field(title='The directory used to keep the log.') 78 | 79 | 80 | ######## Other HyperParam class for various models ########## 81 | 82 | class TransHyperParam(HyperParam): 83 | """ 84 | Hyper-parameter class for the Trans-family models 85 | """ 86 | embed_dim: int = 50 87 | norm: int = 1 88 | margin: float = 2.0 89 | -------------------------------------------------------------------------------- /krl/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .krl_dataset import KRLDataset, LocalKRLDataset, create_mapping, BuiletinHuggingfaceDataset, KRLDatasetDict 2 | from .instance import load_krl_dataset
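The config classes above are plain pydantic models, so wiring up an experiment is mostly a matter of instantiating them and asking the dataset package for a `KRLDatasetDict`. The following is a minimal sketch (not part of the repo) of how these pieces compose; it assumes the package is importable as `krl` and that the Hugging Face `datasets` library can reach the hub:

```python
from krl.config import TransHyperParam, TrainConf
from krl.dataset import load_krl_dataset

# hyper-parameters for a Trans-family model; unspecified fields keep their defaults
params = TransHyperParam(embed_dim=50, norm=1, margin=2.0, batch_size=128)

# where the training logs should be kept (pydantic coerces the str into a Path)
train_conf = TrainConf(logs_dir='lightning_logs/')

# downloads the built-in FB15k dataset from Hugging Face and builds the id mappings
datasets = load_krl_dataset('FB15k')
print(len(datasets.meta.entity2id), 'entities,', len(datasets.meta.rel2id), 'relations')
```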
-------------------------------------------------------------------------------- /krl/dataset/instance/__init__.py: -------------------------------------------------------------------------------- 1 | from .index import load_krl_dataset 2 | -------------------------------------------------------------------------------- /krl/dataset/instance/huggingface_krl_datasets_conf.json: -------------------------------------------------------------------------------- 1 | { 2 | "fb15k": { 3 | "name": "FB15k", 4 | "repo": "VLyb/FB15k" 5 | }, 6 | "wn18": { 7 | "name": "WN18", 8 | "repo": "VLyb/WN18" 9 | }, 10 | "fb15k-237": { 11 | "name": "FB15k-237", 12 | "repo": "VLyb/FB15k-237" 13 | }, 14 | "wn18rr": { 15 | "name": "WN18RR", 16 | "repo": "VLyb/WN18RR" 17 | }, 18 | "yago3-10": { 19 | "name": "YAGO3-10", 20 | "repo": "VLyb/YAGO3-10" 21 | }, 22 | "nations": { 23 | "name": "Nations", 24 | "repo": "VLyb/Nations" 25 | }, 26 | "dbpedia50": { 27 | "name": "DBpedia50", 28 | "repo": "VLyb/DBpedia50" 29 | }, 30 | "dbpedia500": { 31 | "name": "DBpedia500", 32 | "repo": "VLyb/DBpedia500" 33 | }, 34 | "kinship": { 35 | "name": "Kinship", 36 | "repo": "VLyb/Kinship" 37 | }, 38 | "umls": { 39 | "name": "UMLS", 40 | "repo": "VLyb/UMLS" 41 | } 42 | } -------------------------------------------------------------------------------- /krl/dataset/instance/index.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import json 3 | from pathlib import Path 4 | 5 | from .utils import convert_huggingface_dataset 6 | from ..krl_dataset import BuiletinHuggingfaceDataset, KRLDatasetDict 7 | from krl.config import BuiletinHuggingfaceDatasetConf 8 | 9 | 10 | 11 | def __load_huggingface_krl_dataset(dataset_name: str) -> KRLDatasetDict: 12 | json_path = Path(__file__).parent / 'huggingface_krl_datasets_conf.json' 13 | with json_path.open() as f: 14 | name_to_confs = json.load(f) 15 | conf_dict = name_to_confs.get(dataset_name.lower()) 16 | if conf_dict is None: 17 | raise NotImplementedError(f"dataset {dataset_name} hasn't been implemented.") 18 | conf = BuiletinHuggingfaceDatasetConf( 19 | dataset_name=conf_dict['name'], 20 | source='HuggingFace', 21 | huggingface_repo=conf_dict['repo'] 22 | ) 23 | dataset_dict = load_dataset(conf.huggingface_repo) 24 | train_triples, valid_triples, test_triples, dataset_meta = convert_huggingface_dataset(dataset_dict) 25 | return KRLDatasetDict( 26 | train=BuiletinHuggingfaceDataset(conf, 'train', dataset_meta, train_triples), 27 | valid=BuiletinHuggingfaceDataset(conf, 'valid', dataset_meta, valid_triples), 28 | test=BuiletinHuggingfaceDataset(conf, 'test', dataset_meta, test_triples), 29 | meta=dataset_meta, 30 | dataset_conf=conf 31 | ) 32 | 33 | 34 | def load_krl_dataset(dataset_name: str): 35 | return __load_huggingface_krl_dataset(dataset_name) 36 | -------------------------------------------------------------------------------- /krl/dataset/instance/utils.py: -------------------------------------------------------------------------------- 1 | from datasets.dataset_dict import DatasetDict as HuggingfaceDatasetDict 2 | from datasets.arrow_dataset import Dataset as HuggingFaceDataset 3 | from typing import Tuple, List 4 | 5 | from ..krl_dataset import BuiletinHuggingfaceDataset 6 | from ...config import KRLDatasetMeta 7 | 8 | 9 | Triple = Tuple[int, int, int] 10 | Triples = List[Triple] 11 | 12 | def convert_huggingface_dataset( 13 | dataset_dict: HuggingfaceDatasetDict 14 | ) -> Tuple[Triples, Triples, Triples,
KRLDatasetMeta]: 14 | entity2id = {} 15 | rel2id = {} 16 | 17 | def read_dataset(dataset: HuggingFaceDataset): 18 | triples = [] 19 | for triple in dataset: 20 | head, rel, tail = str(triple['head']), str(triple['relation']), str(triple['tail']) 21 | if head not in entity2id: 22 | entity2id[head] = len(entity2id) 23 | head_id = entity2id[head] 24 | if tail not in entity2id: 25 | entity2id[tail] = len(entity2id) 26 | tail_id = entity2id[tail] 27 | if rel not in rel2id: 28 | rel2id[rel] = len(rel2id) 29 | rel_id = rel2id[rel] 30 | triples.append((head_id, rel_id, tail_id)) 31 | return triples 32 | 33 | train_triples = read_dataset(dataset_dict['train']) 34 | valid_triples = read_dataset(dataset_dict['validation']) 35 | test_triples = read_dataset(dataset_dict['test']) 36 | 37 | dataset_meta = KRLDatasetMeta( 38 | entity2id=entity2id, 39 | rel2id=rel2id 40 | ) 41 | 42 | return train_triples, valid_triples, test_triples, dataset_meta 43 | 44 | -------------------------------------------------------------------------------- /krl/dataset/krl_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | The dataset class used to read KRL data, such as FB15k 3 | """ 4 | 5 | from typing import Literal, Tuple, Dict, List, Mapping 6 | from torch.utils.data import Dataset 7 | from abc import ABC 8 | from pydantic import BaseModel 9 | 10 | from ..config import DatasetConf, LocalDatasetConf, BuiletinHuggingfaceDatasetConf, KRLDatasetMeta 11 | 12 | 13 | EntityMapping = Dict[str, int] 14 | RelMapping = Dict[str, int] 15 | Triple = List[int] 16 | 17 | def create_mapping(dataset_conf: DatasetConf) -> KRLDatasetMeta: 18 | """ 19 | create the mappings of `entity2id` and `relation2id` 20 | """ 21 | # read entity2id 22 | entity2id = dict() 23 | entity2id_path = dataset_conf.base_dir / dataset_conf.entity2id_path 24 | if not entity2id_path.exists(): 25 | raise FileNotFoundError(f'{entity2id_path} not found.') 26 | with entity2id_path.open() as f: 27 | for line in f: 28 | entity, entity_id = line.split() 29 | entity = entity.strip() 30 | entity_id = int(entity_id.strip()) 31 | entity2id[entity] = entity_id 32 | # read relation2id 33 | rel2id = dict() 34 | rel2id_path = dataset_conf.base_dir / dataset_conf.relation2id_path 35 | if not rel2id_path.exists(): 36 | raise FileNotFoundError(f'{rel2id_path} not found.') 37 | with rel2id_path.open() as f: 38 | for line in f: 39 | rel, rel_id = line.split() 40 | rel = rel.strip() 41 | rel_id = int(rel_id.strip()) 42 | rel2id[rel] = rel_id 43 | return KRLDatasetMeta( 44 | entity2id=entity2id, 45 | rel2id=rel2id 46 | ) 47 | 48 | 49 | class KRLDataset(Dataset, ABC): 50 | """ 51 | Abstract base class for KRL datasets. 52 | Subclasses are responsible for populating `self.triples`. 53 | 54 | """ 55 | def __init__(self, 56 | dataset_conf: DatasetConf, 57 | mode: Literal['train', 'valid', 'test'], 58 | dataset_meta: KRLDatasetMeta) -> None: 59 | super().__init__() 60 | self.dataset_name = dataset_conf.dataset_name 61 | self.conf = dataset_conf 62 | if mode not in {'train', 'valid', 'test'}: 63 | raise ValueError(f'unsupported dataset mode: {mode}') 64 | self.mode = mode 65 | self.triples # subclasses must assign this before calling super().__init__() 66 | self.meta = dataset_meta 67 | 68 | 69 | class LocalKRLDataset(Dataset): 70 | """ 71 | A dataset stored on the local disk 72 | """ 73 | def __init__(self, 74 | dataset_conf: LocalDatasetConf, 75 | mode: Literal['train', 'valid', 'test'], 76 | dataset_meta: KRLDatasetMeta) -> None: 77 | super().__init__() 78 | self.conf = dataset_conf 79 | if mode not in {'train', 'valid', 'test'}: 80 | raise
ValueError(f'unsupported dataset mode: {mode}') 81 | self.mode = mode 82 | self.triples = [] 83 | self.meta = dataset_meta 84 | self._read_triples() # read the dataset file and build all the triples 85 | 86 | def _split_and_to_id(self, line: str) -> Triple: 87 | """Split one line of the dataset file and convert the entities and the relation into ids 88 | 89 | :param line: one line of the dataset file 90 | :return: [head_id, rel_id, tail_id] 91 | """ 92 | head, tail, rel = line.split() 93 | head_id = self.meta.entity2id[head.strip()] 94 | tail_id = self.meta.entity2id[tail.strip()] 95 | rel_id = self.meta.rel2id[rel.strip()] 96 | return (head_id, rel_id, tail_id) 97 | 98 | def _read_triples(self): 99 | data_path = { 100 | 'train': self.conf.train_path, 101 | 'valid': self.conf.valid_path, 102 | 'test': self.conf.test_path 103 | }.get(self.mode) 104 | p = self.conf.base_dir / data_path 105 | if not p.exists(): 106 | raise FileNotFoundError(f'{p} not found.') 107 | with p.open() as f: 108 | self.triples = [self._split_and_to_id(line) for line in f] 109 | 110 | def __len__(self): 111 | """Denotes the total number of samples.""" 112 | return len(self.triples) 113 | 114 | def __getitem__(self, index) -> Triple: 115 | """Returns (head id, relation id, tail id).""" 116 | triple = self.triples[index] 117 | return triple[0], triple[1], triple[2] 118 | 119 | 120 | class BuiletinHuggingfaceDataset(KRLDataset): 121 | def __init__( 122 | self, 123 | dataset_conf: BuiletinHuggingfaceDatasetConf, 124 | mode: Literal['train', 'valid', 'test'], 125 | dataset_meta: KRLDatasetMeta, 126 | triples: List[Tuple[int, int, int]], 127 | ) -> None: 128 | self.triples = triples 129 | super().__init__(dataset_conf, mode, dataset_meta) 130 | 131 | def __len__(self): 132 | """Denotes the total number of samples.""" 133 | return len(self.triples) 134 | 135 | def __getitem__(self, index) -> Triple: 136 | """Returns (head id, relation id, tail id).""" 137 | triple = self.triples[index] 138 | return triple[0], triple[1], triple[2] 139 | 140 | 141 | class KRLDatasetDict(BaseModel): 142 | train: KRLDataset 143 | valid: KRLDataset 144 | test: KRLDataset 145 | 146 | meta: KRLDatasetMeta 147 | 148 | dataset_conf: DatasetConf 149 | 150 | class Config: 151 | arbitrary_types_allowed = True 152 | -------------------------------------------------------------------------------- /krl/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC, abstractmethod 3 | from typing import List 4 | 5 | from .metric import KRLMetricBase, RankMetric, MetricEnum 6 | 7 | 8 | def cal_hits_at_k(predictions: torch.Tensor, 9 | ground_truth_idx: torch.Tensor, 10 | device: torch.device, 11 | k: int) -> float: 12 | """Calculates number of hits@k. 13 | 14 | :param predictions: BxN tensor of prediction values where B is batch size and N number of classes.
Predictions 15 | must be sorted in class ids order 16 | :param ground_truth_idx: Bx1 tensor with index of ground truth class 17 | :param k: number of top K results to be considered as hits 18 | :return: Hits@K score 19 | """ 20 | assert predictions.size()[0] == ground_truth_idx.size()[0] # has the same batch_size 21 | 22 | zero_tensor = torch.tensor([0], device=device) 23 | one_tensor = torch.tensor([1], device=device) 24 | _, indices = predictions.topk(k, largest=False) # indices: [batch_size, k] 25 | where_flags = indices == ground_truth_idx # where_flags: [batch_size, k], type: bool 26 | hits = torch.where(where_flags, one_tensor, zero_tensor).sum().item() 27 | return hits 28 | 29 | def cal_mrr(predictions: torch.Tensor, ground_truth_idx: torch.Tensor) -> float: 30 | """Calculates mean reciprocal rank (MRR) for given predictions and ground truth values. 31 | 32 | :param predictions: BxN tensor of prediction values where B is batch size and N number of classes. Predictions 33 | must be sorted in class ids order 34 | :param ground_truth_idx: Bx1 tensor with index of ground truth class 35 | :return: Mean reciprocal rank score 36 | """ 37 | assert predictions.size(0) == ground_truth_idx.size(0) 38 | 39 | indices = predictions.argsort() 40 | return (1.0 / (indices == ground_truth_idx).nonzero()[:, 1].float().add(1.0)).sum().item() 41 | 42 | 43 | 44 | class Evaluator(ABC): 45 | """ 46 | Every evaluator should derive this base class. 47 | """ 48 | 49 | def __init__(self, device: torch.device) -> None: 50 | super().__init__() 51 | self.device = device 52 | 53 | @abstractmethod 54 | def evaluate( 55 | self, 56 | predictions: torch.Tensor, 57 | ground_truth_idx: torch.Tensor 58 | ) -> KRLMetricBase: 59 | """ 60 | Calculate the metrics of model's prediction. 61 | 62 | :param predictions: the prediction scores produced by the model 63 | :param ground_truth_idx: index of the ground-truth class 64 | :return: the calculated metrics 65 | """ 66 | pass 67 | 68 | @abstractmethod 69 | def clear(self): 70 | """ 71 | Clear this evaluator for reusing. 72 | """ 73 | pass 74 | 75 | @abstractmethod 76 | def reset_metrics(self, metrics: List[MetricEnum]): 77 | """ 78 | reset the metrics that you want to calculate. 79 | 80 | :param metrics: the metric list. 81 | :type metrics: List[MetricEnum] 82 | """ 83 | pass 84 | 85 | def export_metrics(self) -> KRLMetricBase: 86 | """ 87 | export the metric result stored in evaluator. 88 | """ 89 | pass 90 | 91 | 92 | class RankEvaluator(Evaluator): 93 | 94 | _SUPPORT_METRICS = { 95 | MetricEnum.MRR, 96 | MetricEnum.HITS_AT_1, 97 | MetricEnum.HITS_AT_3, 98 | MetricEnum.HITS_AT_10 99 | } 100 | 101 | def __init__( 102 | self, 103 | device: torch.device, 104 | metrics: List[MetricEnum] 105 | ) -> None: 106 | """ 107 | :param metrics: The metrics that you want to calculate.
108 | """ 109 | super().__init__(device) 110 | self.example_cnt = 0 111 | self.metrics = None 112 | self._mrr_sum = None 113 | self._hits_at_1_sum = None 114 | self._hits_at_3_sum = None 115 | self._hits_at_10_sum = None 116 | # check the metrics that you want to calculate 117 | self.reset_metrics(metrics) 118 | # set to 0 if you want to calculate this metric 119 | self.clear() 120 | 121 | def clear(self): 122 | self.example_cnt = 0 123 | self._mrr_sum = None if MetricEnum.MRR not in self.metrics else 0 124 | self._hits_at_1_sum = None if MetricEnum.HITS_AT_1 not in self.metrics else 0 125 | self._hits_at_3_sum = None if MetricEnum.HITS_AT_3 not in self.metrics else 0 126 | self._hits_at_10_sum = None if MetricEnum.HITS_AT_10 not in self.metrics else 0 127 | 128 | def reset_metrics(self, metrics: List[MetricEnum]): 129 | for m in metrics: 130 | if m not in RankEvaluator._SUPPORT_METRICS: 131 | raise NotImplementedError(f"Evaluator doesn't support metric: {m.value}") 132 | self.metrics = set(metrics) 133 | 134 | def evaluate( 135 | self, 136 | predictions: torch.Tensor, 137 | ground_truth_idx: torch.Tensor 138 | ): 139 | self.example_cnt += predictions.size(0) 140 | if MetricEnum.MRR in self.metrics: 141 | self._mrr_sum += cal_mrr(predictions, ground_truth_idx) 142 | if MetricEnum.HITS_AT_1 in self.metrics: 143 | self._hits_at_1_sum += cal_hits_at_k(predictions, ground_truth_idx, self.device, 1) 144 | if MetricEnum.HITS_AT_3 in self.metrics: 145 | self._hits_at_3_sum += cal_hits_at_k(predictions, ground_truth_idx, self.device, 3) 146 | if MetricEnum.HITS_AT_10 in self.metrics: 147 | self._hits_at_10_sum += cal_hits_at_k(predictions, ground_truth_idx, self.device, 10) 148 | 149 | def export_metrics(self) -> RankMetric: 150 | """ 151 | Export the metric result stored in evaluator 152 | """ 153 | result = RankMetric( 154 | mrr=None if MetricEnum.MRR not in self.metrics else self._percentage(self._mrr_sum), 155 | hits_at_1=None if MetricEnum.HITS_AT_1 not in self.metrics else self._percentage(self._hits_at_1_sum), 156 | hits_at_3=None if MetricEnum.HITS_AT_3 not in self.metrics else self._percentage(self._hits_at_3_sum), 157 | hits_at_10=None if MetricEnum.HITS_AT_10 not in self.metrics else self._percentage(self._hits_at_10_sum) 158 | ) 159 | return result 160 | 161 | def _percentage(self, sum): 162 | return sum / self.example_cnt * 100 -------------------------------------------------------------------------------- /krl/lit_model/TransXLitModel.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from typing import List, Any 5 | 6 | from ..base_model import TransXBaseModel 7 | from ..config import HyperParam 8 | from ..negative_sampler import NegativeSampler 9 | from ..metric import HitsAtK, MRR 10 | from ..dataset import KRLDatasetDict 11 | from ..
import utils 12 | 13 | 14 | class TransXLitModel(pl.LightningModule): 15 | def __init__( 16 | self, 17 | model: TransXBaseModel, 18 | dataset_dict: KRLDatasetDict, 19 | train_neg_sampler: NegativeSampler, 20 | hyper_params: HyperParam 21 | ) -> None: 22 | super().__init__() 23 | self.model = model 24 | self.model_device = next(model.parameters()).device 25 | self.dataset_dict = dataset_dict 26 | self.train_neg_sampler = train_neg_sampler 27 | self.params = hyper_params 28 | self.val_hits10 = HitsAtK(10) 29 | self.test_hits1 = HitsAtK(1) 30 | self.test_hits10 = HitsAtK(10) 31 | self.test_mrr = MRR() 32 | 33 | def training_step(self, batch: List[torch.Tensor], batch_idx: torch.Tensor): 34 | """ 35 | training step 36 | :param batch: [3, batch_size] 37 | :param batch_idx: [batch_size] 38 | """ 39 | pos_heads, pos_rels, pos_tails = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device) 40 | pos_triples = torch.stack([pos_heads, pos_rels, pos_tails], dim=1) # pos_triples: [batch_size, 3] 41 | neg_triples = self.train_neg_sampler.neg_sample(pos_heads, pos_rels, pos_tails) # neg_triples: [batch_size, 3] 42 | # calculate the loss 43 | loss, _, _ = self.model.forward(pos_triples, neg_triples) 44 | return loss 45 | 46 | def validation_step(self, batch: List[torch.Tensor], batch_idx: torch.Tensor): 47 | preds, target = self._get_preds_and_target(batch) 48 | self.val_hits10.update(preds, target) 49 | 50 | def validation_epoch_end(self, outputs: List[Any]) -> None: 51 | val_hits_at_10 = self.val_hits10.compute() 52 | self.val_hits10.reset() 53 | self.log('val_hits@10', val_hits_at_10) 54 | 55 | def test_step(self, batch: List[torch.Tensor], batch_idx: torch.Tensor): 56 | preds, target = self._get_preds_and_target(batch) 57 | self.test_hits1.update(preds, target) 58 | self.test_hits10.update(preds, target) 59 | self.test_mrr.update(preds, target) 60 | 61 | def test_epoch_end(self, outputs: List[Any]) -> None: 62 | result = { 63 | 'hits_at_1': self.test_hits1.compute(), 64 | 'hits_at_10': self.test_hits10.compute(), 65 | 'mrr': self.test_mrr.compute() 66 | } 67 | 68 | self.test_hits1.reset() 69 | self.test_hits10.reset() 70 | self.test_mrr.reset() 71 | 72 | self.log_dict(result) 73 | 74 | def configure_optimizers(self): 75 | optimizer = utils.create_optimizer(self.params.optimizer, self.model, self.params.learning_rate) 76 | milestones = int(self.params.epoch_size / 2) 77 | stepLR = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[milestones], gamma=0.1) 78 | return { 79 | 'optimizer': optimizer, 80 | 'lr_scheduler': stepLR 81 | } 82 | 83 | def train_dataloader(self) -> DataLoader: 84 | return DataLoader( 85 | self.dataset_dict.train, 86 | batch_size=self.params.batch_size, 87 | num_workers=64 88 | ) 89 | 90 | def val_dataloader(self) -> DataLoader: 91 | return DataLoader( 92 | self.dataset_dict.valid, 93 | batch_size=self.params.valid_batch_size, 94 | num_workers=64 95 | ) 96 | 97 | def test_dataloader(self) -> DataLoader: 98 | return DataLoader( 99 | self.dataset_dict.test, 100 | batch_size=self.params.valid_batch_size, 101 | num_workers=64 102 | ) 103 | 104 | 105 | def _get_preds_and_target(self, batch: List[torch.Tensor]): 106 | ent_num = len(self.dataset_dict.meta.entity2id) 107 | entity_ids = torch.arange(0, ent_num, device=self.device) 108 | entity_ids.unsqueeze_(0) 109 | heads, rels, tails = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device) 110 | batch_size = heads.size(0) 111 | all_entities = entity_ids.repeat(batch_size, 1) # all_entities:
[batch_size, ent_num] 112 | # heads: [batch_size,] -> [batch_size, 1] -> [batch_size, ent_num] 113 | heads_expanded = heads.reshape(-1, 1).repeat(1, ent_num) # _expanded: [batch_size, ent_num] 114 | rels_expanded = rels.reshape(-1, 1).repeat(1, ent_num) 115 | tails_expanded = tails.reshape(-1, 1).repeat(1, ent_num) 116 | # check all possible tails 117 | triplets = torch.stack([heads_expanded, rels_expanded, all_entities], dim=2).reshape(-1, 3) # triplets: [batch_size * ent_num, 3] 118 | tails_predictions = self.model.predict(triplets).reshape(batch_size, -1) # tails_prediction: [batch_size, ent_num] 119 | # check all possible heads 120 | triplets = torch.stack([all_entities, rels_expanded, tails_expanded], dim=2).reshape(-1, 3) 121 | heads_predictions = self.model.predict(triplets).reshape(batch_size, -1) # heads_prediction: [batch_size, ent_num] 122 | 123 | # concatenate the tail and head predictions 124 | predictions = torch.cat([tails_predictions, heads_predictions], dim=0) # predictions: [batch_size * 2, ent_num] 125 | ground_truth_entity_id = torch.cat([tails.reshape(-1, 1), heads.reshape(-1, 1)], dim=0) # [batch_size * 2, 1] 126 | 127 | return predictions, ground_truth_entity_id 128 | -------------------------------------------------------------------------------- /krl/lit_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .TransXLitModel import TransXLitModel -------------------------------------------------------------------------------- /krl/metric.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculate the metrics of KRL models. 3 | """ 4 | from pydantic import BaseModel 5 | from typing import Optional 6 | from abc import ABC 7 | from enum import Enum 8 | import torchmetrics 9 | import torch 10 | 11 | 12 | class MetricEnum(Enum): 13 | """ 14 | Enumerate all metric names. Each name is an attribute of a metric model derived from `KRLMetricBase`. 15 | """ 16 | MRR = 'mrr' 17 | HITS_AT_1 = 'hits_at_1' 18 | HITS_AT_3 = 'hits_at_3' 19 | HITS_AT_10 = 'hits_at_10' 20 | 21 | 22 | class KRLMetricBase(BaseModel, ABC): 23 | """All metric model classes should derive from this class. 24 | """ 25 | pass 26 | 27 | 28 | class RankMetric(KRLMetricBase): 29 | mrr: Optional[float] 30 | hits_at_1: Optional[float] 31 | hits_at_3: Optional[float] 32 | hits_at_10: Optional[float] 33 | 34 | 35 | class MRR(torchmetrics.Metric): 36 | def __init__(self) -> None: 37 | super().__init__() 38 | self.add_state("mrr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") 39 | self.add_state("example_cnt", default=torch.tensor(0), dist_reduce_fx="sum") 40 | 41 | def _cal_mrr(self, predictions: torch.Tensor, ground_truth_idx: torch.Tensor) -> float: 42 | """Calculates mean reciprocal rank (MRR) for given predictions and ground truth values. 43 | 44 | :param predictions: BxN tensor of prediction values where B is batch size and N number of classes.
Predictions 45 | must be sorted in class ids order 46 | :param ground_truth_idx: Bx1 tensor with index of ground truth class 47 | :return: Mean reciprocal rank score 48 | """ 49 | assert predictions.size(0) == ground_truth_idx.size(0) 50 | 51 | indices = predictions.argsort() 52 | return (1.0 / (indices == ground_truth_idx).nonzero()[:, 1].float().add(1.0)).sum().item() 53 | 54 | def update(self, preds: torch.Tensor, target: torch.Tensor): 55 | self.mrr_sum += self._cal_mrr(preds, target) 56 | self.example_cnt += preds.size(0) 57 | 58 | def compute(self): 59 | return self.mrr_sum.float() / self.example_cnt * 100 60 | 61 | 62 | class HitsAtK(torchmetrics.Metric): 63 | def __init__(self, k: int) -> None: 64 | super().__init__() 65 | self.k = k 66 | self.add_state("hits_sum", default=torch.tensor(0), dist_reduce_fx="sum") 67 | self.add_state("example_cnt", default=torch.tensor(0), dist_reduce_fx="sum") 68 | 69 | def _cal_hits_at_k( 70 | self, 71 | predictions: torch.Tensor, 72 | ground_truth_idx: torch.Tensor 73 | ) -> float: 74 | """Calculates number of hits@k. 75 | 76 | :param predictions: BxN tensor of prediction values where B is batch size and N number of classes. Predictions 77 | must be sorted in class ids order 78 | :param ground_truth_idx: Bx1 tensor with index of ground truth class 79 | :param k: number of top K results to be considered as hits 80 | :return: Hits@K score 81 | """ 82 | assert predictions.size()[0] == ground_truth_idx.size()[0] # has the same batch_size 83 | 84 | device = predictions.device 85 | 86 | zero_tensor = torch.tensor([0], device=device) 87 | one_tensor = torch.tensor([1], device=device) 88 | _, indices = predictions.topk(self.k, largest=False) # indices: [batch_size, k] 89 | where_flags = indices == ground_truth_idx # where_flags: [batch_size, k], type: bool 90 | hits = torch.where(where_flags, one_tensor, zero_tensor).sum().item() 91 | return hits 92 | 93 | def update(self, preds: torch.Tensor, target: torch.Tensor): 94 | self.hits_sum += self._cal_hits_at_k(preds, target) 95 | self.example_cnt += preds.size(0) 96 | 97 | def compute(self) -> float: 98 | return self.hits_sum.float() / self.example_cnt * 100 99 | -------------------------------------------------------------------------------- /krl/metric_fomatter.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | from .metric import KRLMetricBase, RankMetric 5 | from .config import DatasetConf 6 | 7 | 8 | class MetricFormatter(ABC): 9 | """ 10 | Every formatter class should derive this class. 11 | """ 12 | @abstractmethod 13 | def convert(self, metric: KRLMetricBase) -> Any: 14 | """ 15 | Convert metric model into a specific format 16 | 17 | :param metric: The metric instance that we want to convert. 18 | """ 19 | pass 20 | 21 | 22 | _STRING_TEMPLATE = """dataset: {dataset_name}, 23 | Hits@1: {hits_at_1}, 24 | Hits@3: {hits_at_3}, 25 | Hits@10: {hits_at_10}, 26 | MRR: {mrr}. 27 | """ 28 | 29 | class StringFormatter(MetricFormatter): 30 | """ 31 | Convert the metric into string.
32 | """ 33 | def convert(self, metric: RankMetric, dataset_conf: DatasetConf) -> str: 34 | return _STRING_TEMPLATE.format( 35 | dataset_name=dataset_conf.dataset_name, 36 | hits_at_1=metric.hits_at_1, 37 | hits_at_3=metric.hits_at_3, 38 | hits_at_10=metric.hits_at_10, 39 | mrr=metric.mrr 40 | ) 41 | -------------------------------------------------------------------------------- /krl/models/DistMult.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: 3 | 4 | - https://github.com/Sujit-O/pykg2vec/blob/master/pykg2vec/models/pointwise.py 5 | - https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/openke/module/model/DistMult.py 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | from pydantic import Field 12 | from typing import Literal 13 | 14 | from ..base_model import KRLModel, ModelMain 15 | from ..config import HyperParam, LocalDatasetConf, TrainConf 16 | from ..dataset import create_mapping, LocalKRLDataset 17 | from ..negative_sampler import BernNegSampler 18 | from .. import utils 19 | from ..trainer import RescalTrainer 20 | from ..metric import MetricEnum 21 | from ..evaluator import RankEvaluator 22 | from .. import storage 23 | from ..metric_fomatter import StringFormatter 24 | from ..serializer import FileSerializer 25 | 26 | 27 | class DistMultHyperParam(HyperParam): 28 | """Hyper-parameters of DistMult 29 | """ 30 | embed_dim: int 31 | alpha: float = Field(0.001, title='regularization parameter') 32 | regul_type: Literal['F2', 'N3'] = Field('F2', title='regularization type') 33 | 34 | 35 | class DistMult(KRLModel): 36 | def __init__( 37 | self, 38 | ent_num: int, 39 | rel_num: int, 40 | device: torch.device, 41 | hyper_params: DistMultHyperParam 42 | ): 43 | super().__init__() 44 | self.ent_num = ent_num 45 | self.rel_num = rel_num 46 | self.device = device 47 | self.embed_dim = hyper_params.embed_dim 48 | self.alpha = hyper_params.alpha 49 | self.regul_type = hyper_params.regul_type.upper() 50 | 51 | # initialize entity embedding 52 | self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim) 53 | nn.init.xavier_uniform_(self.ent_embedding.weight.data) 54 | 55 | # initialize relation embedding 56 | self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim) 57 | nn.init.xavier_uniform_(self.rel_embedding.weight.data) 58 | 59 | self.criterion = nn.MSELoss() # use this when the negative-sample labels are 0 60 | # self.criterion = lambda preds, labels: F.softplus(preds * labels).mean() # use this when the negative-sample labels are -1 61 | 62 | 63 | def embed(self, triples): 64 | """get the embedding of triples 65 | 66 | :param triples: [heads, rels, tails] 67 | """ 68 | assert triples.shape[1] == 3 69 | # get entity ids and relation ids 70 | heads = triples[:, 0] 71 | rels = triples[:, 1] 72 | tails = triples[:, 2] 73 | # id -> embedding 74 | h_embs = self.ent_embedding(heads) # [batch, emb] 75 | t_embs = self.ent_embedding(tails) 76 | r_embs = self.rel_embedding(rels) # [batch, emb] 77 | return h_embs, r_embs, t_embs 78 | 79 | def _get_reg(self, h_embs, r_embs, t_embs): 80 | """Calculate regularization term 81 | 82 | :param h_embs: heads embedding, size: [batch, embed] 83 | :param r_embs: rels embedding 84 | :param t_embs: tails embeddings 85 | :return: the regularization term 86 | """ 87 | if self.regul_type == 'F2': 88 | regul = (torch.mean(h_embs ** 2) + torch.mean(t_embs ** 2) + torch.mean(r_embs ** 2)) / 3 89 | else: 90 | regul = torch.mean(torch.sum(h_embs ** 3, -1) +
torch.sum(r_embs ** 3, -1) + torch.sum(t_embs ** 3, -1)) 91 | return regul 92 | 93 | def _scoring(self, h_embs, r_embs, t_embs): 94 | """Compute the scores of a batch of triples. 95 | The higher the score the better: positive triples should approach 1 and negative triples 0. 96 | This score can also be regarded as the `pred` 97 | 98 | :param h_embs: embedding of a batch heads, size: [batch, emb] 99 | :return: size: [batch,] 100 | """ 101 | return torch.sum(h_embs * r_embs * t_embs, dim=1) 102 | 103 | def loss(self, triples: torch.Tensor, labels: torch.Tensor): 104 | """Calculate the loss 105 | 106 | :param triples: a batch of triples. size: [batch, 3] 107 | :param labels: the label of each triple, label = 1 if the triple is positive, label = 0 if the triple is negative. size: [batch,] 108 | """ 109 | assert triples.shape[1] == 3 110 | assert triples.shape[0] == labels.shape[0] 111 | 112 | h_embs, r_embs, t_embs = self.embed(triples) 113 | 114 | scores = self._scoring(h_embs, r_embs, t_embs) 115 | regul = self._get_reg(h_embs, r_embs, t_embs) 116 | loss = self.criterion(scores, labels.float()) + self.alpha * regul 117 | 118 | return loss, scores 119 | 120 | def forward(self, triples, labels): 121 | loss, scores = self.loss(triples, labels) 122 | return loss, scores 123 | 124 | def predict(self, triples): 125 | """Calculate the dissimilarity scores for the given triples. 126 | 127 | :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail) 128 | :return: dissimilarity scores for the given triples 129 | """ 130 | h_embs, r_embs, t_embs = self.embed(triples) 131 | scores = self._scoring(h_embs, r_embs, t_embs) 132 | return -scores 133 | 134 | 135 | class DistMultMain(ModelMain): 136 | def __init__( 137 | self, 138 | dataset_conf: LocalDatasetConf, 139 | train_conf: TrainConf, 140 | hyper_params: DistMultHyperParam, 141 | device: torch.device 142 | ) -> None: 143 | super().__init__() 144 | self.dataset_conf = dataset_conf 145 | self.train_conf = train_conf 146 | self.hyper_params = hyper_params 147 | self.device = device 148 | 149 | def __call__(self): 150 | # create mapping 151 | dataset_meta = create_mapping(self.dataset_conf) # create_mapping returns a KRLDatasetMeta 152 | entity2id, rel2id = dataset_meta.entity2id, dataset_meta.rel2id 153 | ent_num, rel_num = len(entity2id), len(rel2id) 154 | 155 | # create dataset and dataloader 156 | train_dataset, train_dataloader, valid_dataset, valid_dataloader = utils.create_local_dataloader(self.dataset_conf, self.hyper_params, entity2id, rel2id) 157 | 158 | # create negative-sampler 159 | neg_sampler = BernNegSampler(train_dataset, self.device) 160 | 161 | # create model 162 | model = DistMult(ent_num, rel_num, self.device, self.hyper_params) 163 | model = model.to(self.device) 164 | 165 | # create optimizer 166 | optimizer = utils.create_optimizer(self.hyper_params.optimizer, model, self.hyper_params.learning_rate) 167 | 168 | # create trainer 169 | trainer = RescalTrainer( 170 | model=model, 171 | train_conf=self.train_conf, 172 | params=self.hyper_params, 173 | dataset_conf=self.dataset_conf, 174 | entity2id=entity2id, 175 | rel2id=rel2id, 176 | device=self.device, 177 | train_dataloder=train_dataloader, 178 | valid_dataloder=valid_dataloader, 179 | train_neg_sampler=neg_sampler, 180 | valid_neg_sampler=neg_sampler, 181 | optimzer=optimizer 182 | ) 183 | 184 | # training process 185 | trainer.run_training() 186 | 187 | # create evaluator 188 | metrics = [ 189 | MetricEnum.MRR, 190 | MetricEnum.HITS_AT_1, 191 | MetricEnum.HITS_AT_3, 192 | MetricEnum.HITS_AT_10 193 | ] 194 | evaluator = RankEvaluator(self.device, metrics) 195 | 196 | # Testing the best checkpoint on test dataset 197 | # load best model 198 | ckpt =
storage.load_checkpoint(self.train_conf) 199 | model.load_state_dict(ckpt.model_state_dict) 200 | model = model.to(self.device) 201 | # create test-dataset 202 | test_dataset = LocalKRLDataset(self.dataset_conf, 'test', dataset_meta) 203 | test_dataloader = DataLoader(test_dataset, self.hyper_params.valid_batch_size) 204 | # run inference on test-dataset 205 | metric = trainer.run_inference(test_dataloader, ent_num, evaluator) 206 | 207 | # choose the metric formatter 208 | metric_formatter = StringFormatter() 209 | 210 | # choose the way of serializing 211 | serializer = FileSerializer(self.train_conf, self.dataset_conf) 212 | # serialize the metric 213 | serializer.serialize(metric, metric_formatter) 214 | -------------------------------------------------------------------------------- /krl/models/RESCAL.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: 3 | 4 | - https://yubincloud.github.io/notebook-paper/KG/KRL/1101.RESCAL-and-extensions.html 5 | - https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/openke/module/model/RESCAL.py 6 | - https://github.com/nju-websoft/muKG/blob/main/src/torch/kge_models/RESCAL.py 7 | """ 8 | 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import DataLoader 14 | from pydantic import Field 15 | from typing import Literal 16 | 17 | from ..base_model import KRLModel, ModelMain 18 | from ..config import HyperParam, TrainConf 19 | from ..dataset import KRLDatasetDict 20 | from ..negative_sampler import BernNegSampler 21 | from .. import utils 22 | from ..trainer import RescalTrainer 23 | from ..metric import MetricEnum 24 | from ..evaluator import RankEvaluator 25 | from .. import storage 26 | from ..metric_fomatter import StringFormatter 27 | from ..serializer import FileSerializer 28 | 29 | 30 | class RescalHyperParam(HyperParam): 31 | """Hyper-parameters of RESCAL 32 | """ 33 | embed_dim: int 34 | alpha: float = Field(0.001, title='regularization parameter') 35 | regul_type: Literal['F2', 'N3'] = Field('F2', title='regularization type') 36 | 37 | 38 | class RESCAL(KRLModel): 39 | def __init__( 40 | self, 41 | ent_num: int, 42 | rel_num: int, 43 | device: torch.device, 44 | hyper_params: RescalHyperParam, 45 | ): 46 | super().__init__() 47 | self.ent_num = ent_num 48 | self.rel_num = rel_num 49 | self.device = device 50 | self.embed_dim = hyper_params.embed_dim 51 | self.alpha = hyper_params.alpha 52 | self.regul_type = hyper_params.regul_type.upper() 53 | 54 | # initialize entity embedding 55 | self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim) 56 | nn.init.xavier_uniform_(self.ent_embedding.weight.data) 57 | self.ent_embedding.weight.data = F.normalize(self.ent_embedding.weight.data, 2, 1) # in many implementations, this line can be removed 58 | 59 | # initialize relation embedding 60 | self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim * self.embed_dim) 61 | nn.init.xavier_uniform_(self.rel_embedding.weight.data) 62 | self.rel_embedding.weight.data = F.normalize(self.rel_embedding.weight.data, 2, 1) 63 | 64 | self.criterion = nn.MSELoss() 65 | 66 | def embed(self, triples): 67 | """Get the embeddings of a batch of triples 68 | 69 | :param triples: a batch of triples: [heads, rels, tails] 70 | """ 71 | assert triples.shape[1] == 3 72 | # get entity ids and relation ids 73 | heads = triples[:, 0] 74 | rels = triples[:, 1] 75 | tails = triples[:, 2] 76 | # id -> embedding 77 | h_embs = self.ent_embedding(heads) # [batch, emb] 78 | t_embs = self.ent_embedding(tails) 79 | r_embs =
self.rel_embedding(rels) # [batch, emb * emb] 80 | return h_embs, r_embs, t_embs 81 | 82 | 83 | def _scoring(self, h_embs, r_embs, t_embs): 84 | """Compute the scores of a batch of triples. 85 | The higher the score the better: positive triples should approach 1 and negative triples 0. 86 | This score can also be regarded as the `pred` 87 | 88 | :param h_embs: heads embedding, size: [batch, embed] 89 | :param r_embs: rels embedding, size: [batch, embed * embed] 90 | :return: size: [batch,] 91 | """ 92 | # calculate scores 93 | r_embs = r_embs.view(-1, self.embed_dim, self.embed_dim) # [batch, emb, emb] 94 | t_embs = t_embs.view(-1, self.embed_dim, 1) # [batch, emb, 1] 95 | 96 | tr = torch.matmul(r_embs, t_embs) # [batch, emb, 1] 97 | tr = tr.view(-1, self.embed_dim) # [batch, emb] 98 | 99 | return torch.sum(h_embs * tr, dim=1) 100 | 101 | def _get_reg(self, h_embs, r_embs, t_embs): 102 | """Calculate regularization term 103 | 104 | :param h_embs: heads embedding, size: [batch, embed] 105 | :param r_embs: rels embedding 106 | :param t_embs: tails embeddings 107 | :return: the regularization term 108 | """ 109 | if self.regul_type == 'F2': 110 | regul = (torch.mean(h_embs ** 2) + torch.mean(t_embs ** 2) + torch.mean(r_embs ** 2)) / 3 111 | else: 112 | regul = torch.mean(torch.sum(h_embs ** 3, -1) + torch.sum(r_embs ** 3, -1) + torch.sum(t_embs ** 3, -1)) 113 | return regul 114 | 115 | def loss(self, triples: torch.Tensor, labels: torch.Tensor): 116 | """Calculate the loss 117 | 118 | :param triples: a batch of triples. size: [batch, 3] 119 | :param labels: the label of each triple, label = 1 if the triple is positive, label = 0 if the triple is negative. size: [batch,] 120 | """ 121 | assert triples.shape[1] == 3 122 | assert triples.shape[0] == labels.shape[0] 123 | 124 | h_embs, r_embs, t_embs = self.embed(triples) 125 | 126 | scores = self._scoring(h_embs, r_embs, t_embs) 127 | regul = self._get_reg(h_embs, r_embs, t_embs) 128 | loss = self.criterion(scores, labels.float()) + self.alpha * regul 129 | 130 | return loss, scores 131 | 132 | def forward(self, triples, labels): 133 | loss, scores = self.loss(triples, labels) 134 | return loss, scores 135 | 136 | def predict(self, triples): 137 | """Calculate the dissimilarity scores for the given triples.
147 | class RescalMain(ModelMain):
148 |     def __init__(
149 |         self,
150 |         dataset: KRLDatasetDict,
151 |         train_conf: TrainConf,
152 |         hyper_params: RescalHyperParam,
153 |         device: torch.device
154 |     ) -> None:
155 |         super().__init__()
156 |         self.datasets = dataset
157 |         self.dataset_conf = dataset.dataset_conf
158 |         self.train_conf = train_conf
159 |         self.params = hyper_params
160 |         self.device = device
161 | 
162 |     def __call__(self):
163 |         # create mapping
164 |         entity2id = self.datasets.entity2id
165 |         rel2id = self.datasets.rel2id
166 |         ent_num = len(entity2id)
167 |         rel_num = len(rel2id)
168 | 
169 |         # create dataset and dataloader
170 |         train_dataset = self.datasets.train
171 |         train_dataloader = DataLoader(train_dataset, self.params.batch_size)
172 |         valid_dataset = self.datasets.valid
173 |         valid_dataloader = DataLoader(valid_dataset, self.params.batch_size)
174 | 
175 |         # create negative-sampler
176 |         neg_sampler = BernNegSampler(train_dataset)  # BernNegSampler only takes the dataset
177 | 
178 |         # create model
179 |         model = RESCAL(ent_num, rel_num, self.device, self.params)
180 |         model = model.to(self.device)
181 | 
182 |         # create optimizer
183 |         optimizer = utils.create_optimizer(self.params.optimizer, model, self.params.learning_rate)
184 | 
185 |         # create trainer
186 |         trainer = RescalTrainer(
187 |             model=model,
188 |             train_conf=self.train_conf,
189 |             params=self.params,
190 |             dataset_conf=self.dataset_conf,
191 |             entity2id=entity2id,
192 |             rel2id=rel2id,
193 |             device=self.device,
194 |             train_dataloder=train_dataloader,
195 |             valid_dataloder=valid_dataloader,
196 |             train_neg_sampler=neg_sampler,
197 |             valid_neg_sampler=neg_sampler,
198 |             optimzer=optimizer
199 |         )
200 | 
201 |         # training process
202 |         trainer.run_training()
203 | 
204 |         # create evaluator
205 |         metrics = [
206 |             MetricEnum.MRR,
207 |             MetricEnum.HITS_AT_1,
208 |             MetricEnum.HITS_AT_3,
209 |             MetricEnum.HITS_AT_10
210 |         ]
211 |         evaluator = RankEvaluator(self.device, metrics)
212 | 
213 |         # test the best checkpoint on the test dataset
214 |         # load the best model
215 |         ckpt = storage.load_checkpoint(self.train_conf)
216 |         model.load_state_dict(ckpt.model_state_dict)
217 |         model = model.to(self.device)
218 |         # create test-dataset
219 |         test_dataset = self.datasets.test
220 |         test_dataloader = DataLoader(test_dataset, self.params.valid_batch_size)
221 |         # run inference on test-dataset
222 |         metric = trainer.run_inference(test_dataloader, ent_num, evaluator)
223 | 
224 |         # choose the metric formatter
225 |         metric_formatter = StringFormatter()
226 | 
227 |         # choose the serialization method
228 |         serializer = FileSerializer(self.train_conf, self.dataset_conf)
229 |         # serialize the metric
230 |         serializer.serialize(metric, metric_formatter)
231 | 
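Editorial aside: the batched matmul in `_scoring` above is exactly the bilinear form h^T M_r t from the RESCAL paper. A minimal, self-contained sketch (toy tensors, not part of the library) checking the equivalence:

```python
# Sketch: verify that RESCAL's batched-matmul scoring equals the bilinear
# form h^T M_r t. All tensors here are made-up toy data.
import torch

batch, dim = 4, 8
h = torch.randn(batch, dim)          # head embeddings
M = torch.randn(batch, dim, dim)     # one relation matrix per triple
t = torch.randn(batch, dim)          # tail embeddings

# the _scoring computation: sum(h * (M @ t), dim=1)
scores = torch.sum(h * torch.matmul(M, t.unsqueeze(-1)).squeeze(-1), dim=1)
# the explicit bilinear form, for comparison
expected = torch.einsum('bi,bij,bj->b', h, M, t)
assert torch.allclose(scores, expected, atol=1e-5)
```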
-------------------------------------------------------------------------------- /krl/models/TransD.py: --------------------------------------------------------------------------------

1 | """
2 | Reference:
3 | 
4 | - https://github.com/nju-websoft/muKG/blob/main/src/torch/kge_models/TransD.py
5 | """
6 | import torch
7 | import torch.nn as nn
8 | 
9 | from ..config import HyperParam
10 | from ..base_model import KRLModel
11 | 
12 | 
13 | class TransDHyperParam(HyperParam):
14 |     ent_dim: int
15 |     rel_dim: int
16 |     norm: int
17 |     margin: float
18 | 
19 | 
20 | class TransD(KRLModel):
21 |     def __init__(
22 |         self,
23 |         ent_num: int,
24 |         rel_num: int,
25 |         device: torch.device,
26 |         hyper_params: TransDHyperParam
27 |     ):
28 |         super().__init__()
29 |         self.ent_num = ent_num
30 |         self.rel_num = rel_num
31 |         self.device = device
32 |         self.norm = hyper_params.norm
33 |         self.ent_dim = hyper_params.ent_dim
34 |         self.rel_dim = hyper_params.rel_dim
35 |         self.margin = hyper_params.margin
36 | 
37 |         self.margin_loss_fn = nn.MarginRankingLoss(margin=self.margin)
38 | 
39 |         # initialize ent_embedding
40 |         self.ent_embedding = nn.Embedding(self.ent_num, self.ent_dim)
41 |         nn.init.xavier_uniform_(self.ent_embedding.weight.data)
42 | 
43 |         # initialize rel_embedding
44 |         self.rel_embedding = nn.Embedding(self.rel_num, self.rel_dim)
45 |         nn.init.xavier_uniform_(self.rel_embedding.weight.data)
46 | 
47 |         # initialize the transfer embeddings
48 |         self.ent_transfer = nn.Embedding(self.ent_num, self.ent_dim)
49 |         nn.init.xavier_uniform_(self.ent_transfer.weight.data)
50 |         self.rel_transfer = nn.Embedding(self.rel_num, self.rel_dim)
51 |         nn.init.xavier_uniform_(self.rel_transfer.weight.data)
52 | 
53 |         self.dist_fn = nn.PairwiseDistance(p=self.norm)
54 | 
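TransD.py stops at the constructor. For orientation only, here is a hedged sketch (my own, following the TransD paper; not code from this repo) of the projection those transfer embeddings are meant to support, assuming ent_dim == rel_dim:

```python
# Sketch of the TransD projection h_perp = (r_p h_p^T + I) h, computed in
# its efficient form h + (h_p . h) * r_p. Toy tensors, editor-added.
import torch

batch, dim = 4, 8
h = torch.randn(batch, dim)      # entity embedding (ent_embedding)
h_p = torch.randn(batch, dim)    # entity transfer vector (ent_transfer)
r_p = torch.randn(batch, dim)    # relation transfer vector (rel_transfer)

h_proj = h + torch.sum(h_p * h, dim=1, keepdim=True) * r_p
assert h_proj.shape == (batch, dim)
```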
-------------------------------------------------------------------------------- /krl/models/TransE.py: --------------------------------------------------------------------------------

1 | 
2 | import torch
3 | import torch.nn as nn
4 | import lightning.pytorch as pl
5 | from lightning.pytorch.loggers import CSVLogger
6 | from lightning.pytorch.callbacks import EarlyStopping
7 | from datetime import datetime
8 | 
9 | from ..base_model import TransXBaseModel, LitModelMain
10 | from ..config import TransHyperParam, TrainConf
11 | from ..dataset import KRLDatasetDict
12 | from ..negative_sampler import BernNegSampler
13 | from ..lit_model import TransXLitModel
14 | from .. import utils
15 | 
16 | 
17 | class TransEHyperParam(TransHyperParam):
18 |     """Hyper-parameters of TransE
19 |     """
20 |     pass
21 | 
22 | 
23 | class TransE(TransXBaseModel):
24 |     def __init__(self,
25 |                  ent_num: int,
26 |                  rel_num: int,
27 |                  hyper_params: TransEHyperParam
28 |     ):
29 |         super().__init__()
30 |         self.ent_num = ent_num
31 |         self.rel_num = rel_num
32 |         self.norm = hyper_params.norm
33 |         self.embed_dim = hyper_params.embed_dim
34 |         self.margin = hyper_params.margin
35 | 
36 |         # initialize ent_embedding; the original paper's uniform init is kept below, commented out
37 |         self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim)
38 |         torch.nn.init.xavier_uniform_(self.ent_embedding.weight.data)
39 |         #uniform_range = 6 / np.sqrt(self.embed_dim)
40 |         #self.ent_embedding.weight.data.uniform_(-uniform_range, uniform_range)
41 | 
42 |         # initialize rel_embedding
43 |         self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim)
44 |         torch.nn.init.xavier_uniform_(self.rel_embedding.weight.data)
45 |         #uniform_range = 6 / np.sqrt(self.embed_dim)
46 |         #self.rel_embedding.weight.data.uniform_(-uniform_range, uniform_range)
47 | 
48 |         self.dist_fn = nn.PairwiseDistance(p=self.norm)  # the function for calculating the distance
49 |         self.criterion = nn.MarginRankingLoss(margin=self.margin)
50 | 
51 |     def embed(self, triples):
52 |         """get the embedding of triples
53 | 
54 |         :param triples: [heads, rels, tails]
55 |         :return: embedding of triples.
56 |         """
57 |         assert triples.shape[1] == 3
58 |         heads = triples[:, 0]
59 |         rels = triples[:, 1]
60 |         tails = triples[:, 2]
61 |         h_embs = self.ent_embedding(heads)  # h_embs: [batch, embed_dim]
62 |         r_embs = self.rel_embedding(rels)
63 |         t_embs = self.ent_embedding(tails)
64 |         return h_embs, r_embs, t_embs
65 | 
66 |     def _distance(self, triples):
67 |         """Compute the distance for a batch of triples
68 | 
69 |         :param triples: a batch of triples, size: [batch, 3]
70 |         :return: size: [batch,]
71 |         """
72 |         h_embs, r_embs, t_embs = self.embed(triples)
73 |         return self.dist_fn(h_embs + r_embs, t_embs)
74 | 
75 |     def loss(self, pos_distances: torch.Tensor, neg_distances: torch.Tensor):
76 |         """Calculate the loss
77 | 
78 |         :param pos_distances: [batch, ]
79 |         :param neg_distances: [batch, ]
80 |         :return: loss
81 |         """
82 |         ones = torch.tensor([-1], dtype=torch.long, device=pos_distances.device)
83 |         return self.criterion(pos_distances, neg_distances, ones)
84 | 
85 |     def forward(self, pos_triples: torch.Tensor, neg_triples: torch.Tensor):
86 |         """Return model losses based on the input.
87 | 
88 |         :param pos_triples: triplets of positives in Bx3 shape (B - batch, 3 - head, relation and tail)
89 |         :param neg_triples: triplets of negatives in Bx3 shape (B - batch, 3 - head, relation and tail)
90 |         :return: tuple of the model loss, positive triplets loss component, negative triples loss component
91 |         """
92 |         assert pos_triples.size()[1] == 3
93 |         assert neg_triples.size()[1] == 3
94 | 
95 |         pos_distances = self._distance(pos_triples)
96 |         neg_distances = self._distance(neg_triples)
97 |         loss = self.loss(pos_distances, neg_distances)
98 |         return loss, pos_distances, neg_distances
99 | 
100 |     def predict(self, triples: torch.Tensor):
101 |         """Calculate the dissimilarity score for the given triples.
102 | 
103 |         :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail)
104 |         :return: dissimilarity score for the given triples
105 |         """
106 |         return self._distance(triples)
107 | 
108 | 
109 | class TransELitMain(LitModelMain):
110 |     def __init__(
111 |         self,
112 |         dataset: KRLDatasetDict,
113 |         train_conf: TrainConf,
114 |         hyper_params: TransEHyperParam,
115 |         seed: int = None,
116 |     ) -> None:
117 |         super().__init__(
118 |             dataset,
119 |             train_conf,
120 |             seed
121 |         )
122 |         self.params = hyper_params
123 | 
124 |     def __call__(self):
125 |         # seed everything
126 |         pl.seed_everything(self.seed)
127 | 
128 |         # create mapping
129 |         ent_num = len(self.datasets.meta.entity2id)
130 |         rel_num = len(self.datasets.meta.rel2id)
131 | 
132 |         # create negative-sampler
133 |         train_dataset = self.datasets.train
134 |         neg_sampler = BernNegSampler(train_dataset)
135 | 
136 |         # create model
137 |         model = TransE(ent_num, rel_num, self.params)
138 |         model_wrapped = TransXLitModel(model, self.datasets, neg_sampler, self.params)
139 | 
140 |         # callbacks
141 |         early_stopping = EarlyStopping('val_hits@10', mode="max", patience=self.params.early_stoping_patience, check_on_train_epoch_end=False)
142 | 
143 |         # create trainer
144 |         trainer = pl.Trainer(
145 |             accelerator="gpu", devices=[0],  # the old `gpus="0,"` argument was removed in lightning 2.x
146 |             max_epochs=self.params.epoch_size,
147 |             logger=CSVLogger(self.train_conf.logs_dir, name=f'{model.__class__.__name__}-{self.dataset_conf.dataset_name}'),
148 |             callbacks=[early_stopping]
149 |         )
150 |         trainer.fit(model=model_wrapped)
151 | 
152 |         trainer.test(model_wrapped)
153 | 
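The sign convention in `TransE.loss` is easy to get backwards, so here is a small self-contained check (toy numbers, editor-added) of how `nn.MarginRankingLoss` with target −1 penalizes positives that are not at least `margin` closer than negatives:

```python
# Sketch: with target -1, MarginRankingLoss asks the first argument (positive
# distances) to be smaller, i.e. loss = mean(max(0, margin + pos_d - neg_d)).
import torch
import torch.nn as nn

pos_d = torch.tensor([1.0, 2.0])   # distances of positive triples
neg_d = torch.tensor([3.0, 1.5])   # distances of corrupted triples
margin = 2.0

criterion = nn.MarginRankingLoss(margin=margin)
loss = criterion(pos_d, neg_d, torch.tensor([-1.0]))
expected = torch.mean(torch.clamp(margin + pos_d - neg_d, min=0))
assert torch.isclose(loss, expected)
```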
-------------------------------------------------------------------------------- /krl/models/TransH.py: --------------------------------------------------------------------------------

1 | """
2 | Reference:
3 | 
4 | - https://github.com/LYuhang/Trans-Implementation/blob/master/code/models/TransH.py
5 | - https://github.com/zqhead/TransH/blob/master/TransH_torch.py
6 | """
7 | 
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.utils.data import DataLoader
12 | from pydantic import Field
13 | 
14 | from ..base_model import TransXBaseModel, ModelMain
15 | from ..config import TransHyperParam, LocalDatasetConf, TrainConf
16 | from ..dataset import create_mapping, LocalKRLDataset
17 | from ..negative_sampler import BernNegSampler
18 | from .. import utils
19 | from ..trainer import TransETrainer
20 | from ..metric import MetricEnum
21 | from ..evaluator import RankEvaluator
22 | from .. import storage
23 | from ..metric_fomatter import StringFormatter
24 | from ..serializer import FileSerializer
25 | 
26 | 
27 | class TransHHyperParam(TransHyperParam):
28 |     """Hyper-parameters of TransH
29 | 
30 |     :param TransHyperParam: the base hyper-parameter model
31 |     """
32 |     C: float = Field(default=0.1, description='a hyper-parameter weighting the importance of soft constraints.')
33 |     eps: float = Field(default=1e-3, description='the $\epsilon$ in the loss function')
34 | 
35 | 
36 | class TransH(TransXBaseModel):
37 |     def __init__(
38 |         self,
39 |         ent_num: int,
40 |         rel_num: int,
41 |         device: torch.device,
42 |         hyper_params: TransHHyperParam
43 |     ):
44 |         super().__init__()
45 |         self.ent_num = ent_num
46 |         self.rel_num = rel_num
47 |         self.device = device
48 |         self.norm = hyper_params.norm
49 |         self.embed_dim = hyper_params.embed_dim
50 |         self.margin = hyper_params.margin
51 |         self.C = hyper_params.C  # a hyper-parameter weighting the importance of soft constraints
52 |         self.eps = hyper_params.eps  # the $\epsilon$ in the loss function
53 | 
54 |         self.margin_loss_fn = nn.MarginRankingLoss(margin=self.margin)
55 | 
56 |         # initialize ent_embedding
57 |         self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim)
58 |         nn.init.xavier_uniform_(self.ent_embedding.weight.data)
59 | 
60 |         # initialize rel_embedding, the embedding for the relation-specific translation vector $d_r$
61 |         self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim)
62 |         nn.init.xavier_uniform_(self.rel_embedding.weight.data)
63 | 
64 |         # initialize rel_hyper_embedding, the embedding for the relation-specific hyperplane normal $w_r$
65 |         self.rel_hyper_embedding = nn.Embedding(self.rel_num, self.embed_dim)
66 |         nn.init.xavier_uniform_(self.rel_hyper_embedding.weight.data)
67 | 
68 |         self.dist_fn = nn.PairwiseDistance(p=self.norm)  # the function for calculating the distance
69 | 
70 |     def embed(self, triples):
71 |         """get the embedding of triples
72 | 
73 |         :param triples: [heads, rels, tails]
74 |         :return: embedding of triples.
75 |         """
76 |         assert triples.shape[1] == 3
77 |         heads = triples[:, 0]
78 |         rels = triples[:, 1]
79 |         tails = triples[:, 2]
80 |         h_embs = self.ent_embedding(heads)  # h_embs: [batch, embed_dim]
81 |         t_embs = self.ent_embedding(tails)
82 |         r_embs = self.rel_embedding(rels)
83 |         r_hyper_embs = self.rel_hyper_embedding(rels)  # relation hyperplane, size: [batch_size, embed_dim]
84 |         return h_embs, r_embs, t_embs, r_hyper_embs
85 | 
86 |     def _project(self, ent_embeds, rel_hyper_embeds):
87 |         """Project entity embeddings onto the relation hyperplane.
88 |         Computational process: $h - w_r^T h w_r$
89 | 
90 |         :param ent_embeds: entity embedding, size: [batch_size, embed_dim]
91 |         :param rel_hyper_embeds: relation hyperplane, size: [batch_size, embed_dim]
92 |         """
93 |         return ent_embeds - rel_hyper_embeds * torch.sum(ent_embeds * rel_hyper_embeds, dim=1, keepdim=True)
94 | 
95 |     def _distance(self, triples):
96 |         """Compute the distance for a batch of triples
97 | 
98 |         :param triples: a batch of triples, size: [batch, 3]
99 |         :return: size: [batch,]
100 |         """
101 |         assert triples.shape[1] == 3
102 |         # step 1: transform the index tensor into embedding tensors
103 |         h_embs, r_embs, t_embs, r_hyper_embs = self.embed(triples)
104 |         # step 2: project the head and tail embeddings onto the relation hyperplane
105 |         h_embs = self._project(h_embs, r_hyper_embs)
106 |         t_embs = self._project(t_embs, r_hyper_embs)
107 |         # step 3: calculate the similarity score on the relation hyperplane
108 |         return self.dist_fn(h_embs + r_embs, t_embs)
109 | 
110 |     def _cal_margin_based_loss(self, pos_distances, neg_distances):
111 |         """Calculate the margin-based loss
112 | 
113 |         :param pos_distances: [batch, ]
114 |         :param neg_distances: [batch, ]
115 |         :return: margin-based loss, size: [1,]
116 |         """
117 |         ones = torch.tensor([-1], dtype=torch.long, device=self.device)
118 |         return self.margin_loss_fn(pos_distances, neg_distances, ones)
119 | 
120 |     def _cal_scale_loss(self):
121 |         """Calculate the scale loss.
122 |         F.relu(x) is equal to max(x, 0).
123 |         """
124 |         ent_norm = torch.norm(self.ent_embedding.weight, p=2, dim=1)  # the L2 norm of the entity embeddings, size: [ent_num, ]
125 |         scale_loss = torch.sum(F.relu(ent_norm - 1))
126 |         return scale_loss
127 | 
128 |     def _cal_orthogonal_loss(self):
129 |         """Calculate the orthogonal loss.
130 |         """
131 |         orth_loss = torch.sum(F.relu(torch.sum(self.rel_hyper_embedding.weight * self.rel_embedding.weight, dim=1, keepdim=False) / torch.norm(self.rel_embedding.weight, p=2, dim=1, keepdim=False) - self.eps ** 2))
132 |         return orth_loss
133 | 
134 |     def loss(self, pos_distances, neg_distances):
135 |         """Calculate the loss
136 | 
137 |         :param pos_distances: [batch, ]
138 |         :param neg_distances: [batch, ]
139 |         :return: loss
140 |         """
141 |         margin_based_loss = self._cal_margin_based_loss(pos_distances, neg_distances)
142 |         scale_loss = self._cal_scale_loss()
143 |         orth_loss = self._cal_orthogonal_loss()
144 |         ent_num = self.ent_num
145 |         return margin_based_loss + self.C * (scale_loss / ent_num + orth_loss / ent_num)
146 | 
147 |     def forward(self, pos_triples: torch.Tensor, neg_triples: torch.Tensor):
148 |         """Return model losses based on the input.
149 | 
150 |         :param pos_triples: triplets of positives in Bx3 shape (B - batch, 3 - head, relation and tail)
151 |         :param neg_triples: triplets of negatives in Bx3 shape (B - batch, 3 - head, relation and tail)
152 |         :return: tuple of the model loss, positive triplets loss component, negative triples loss component
153 |         """
154 |         assert pos_triples.size()[1] == 3
155 |         assert neg_triples.size()[1] == 3
156 | 
157 |         pos_distances = self._distance(pos_triples)
158 |         neg_distances = self._distance(neg_triples)
159 |         loss = self.loss(pos_distances, neg_distances)
160 |         return loss, pos_distances, neg_distances
161 | 
162 |     def predict(self, triples: torch.Tensor):
163 |         """Calculate the dissimilarity score for the given triples.
164 | 
165 |         :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail)
166 |         :return: dissimilarity score for the given triples
167 |         """
168 |         return self._distance(triples)
169 | 
170 | 
171 | class TransHMain(ModelMain):
172 |     def __init__(
173 |         self,
174 |         dataset_conf: LocalDatasetConf,
175 |         train_conf: TrainConf,
176 |         hyper_params: TransHHyperParam,
177 |         device: torch.device
178 |     ) -> None:
179 |         super().__init__()
180 |         self.dataset_conf = dataset_conf
181 |         self.train_conf = train_conf
182 |         self.hyper_params = hyper_params
183 |         self.device = device
184 | 
185 |     def __call__(self):
186 |         # create mapping
187 |         entity2id, rel2id = create_mapping(self.dataset_conf)
188 |         ent_num = len(entity2id)
189 |         rel_num = len(rel2id)
190 | 
191 |         # create dataset and dataloader
192 |         train_dataset, train_dataloader, valid_dataset, valid_dataloader = utils.create_local_dataloader(self.dataset_conf, self.hyper_params, entity2id, rel2id)
193 | 
194 |         # create negative-sampler
195 |         neg_sampler = BernNegSampler(train_dataset)  # BernNegSampler only takes the dataset
196 | 
197 |         # create model
198 |         model = TransH(ent_num, rel_num, self.device, self.hyper_params)
199 |         model = model.to(self.device)
200 | 
201 |         # create optimizer
202 |         optimizer = utils.create_optimizer(self.hyper_params.optimizer, model, self.hyper_params.learning_rate)
203 | 
204 |         # create trainer
205 |         trainer = TransETrainer(
206 |             model=model,
207 |             train_conf=self.train_conf,
208 |             params=self.hyper_params,
209 |             dataset_conf=self.dataset_conf,
210 |             entity2id=entity2id,
211 |             rel2id=rel2id,
212 |             device=self.device,
213 |             train_dataloder=train_dataloader,
214 |             valid_dataloder=valid_dataloader,
215 |             train_neg_sampler=neg_sampler,
216 |             valid_neg_sampler=neg_sampler,
217 |             optimzer=optimizer
218 |         )
219 | 
220 |         # training process
221 |         trainer.run_training()
222 | 
223 |         # create evaluator
224 |         metrics = [
225 |             MetricEnum.MRR,
226 |             MetricEnum.HITS_AT_1,
227 |             MetricEnum.HITS_AT_3,
228 |             MetricEnum.HITS_AT_10
229 |         ]
230 |         evaluator = RankEvaluator(self.device, metrics)
231 | 
232 |         # test the best checkpoint on the test dataset
233 |         # load the best model
234 |         ckpt = storage.load_checkpoint(self.train_conf)
235 |         model.load_state_dict(ckpt.model_state_dict)
236 |         model = model.to(self.device)
237 |         # create test-dataset
238 |         test_dataset = LocalKRLDataset(self.dataset_conf, 'test', entity2id, rel2id)
239 |         test_dataloader = DataLoader(test_dataset, self.hyper_params.valid_batch_size)
240 |         # run inference on test-dataset
241 |         metric = trainer.run_inference(test_dataloader, ent_num, evaluator)
242 | 
243 |         # choose the metric formatter
244 |         metric_formatter = StringFormatter()
245 | 
246 |         # choose the serialization method
247 |         serializer = FileSerializer(self.train_conf, self.dataset_conf)
248 |         # serialize the metric
249 |         serializer.serialize(metric, metric_formatter)
250 | 
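A quick sanity sketch of `_project` above: when the hyperplane normal w_r has unit norm, the projected embedding is orthogonal to it (toy tensors, editor-added):

```python
# Sketch: h_perp = h - (w . h) w is orthogonal to a unit normal w.
import torch
import torch.nn.functional as F

h = torch.randn(5, 16)
w = F.normalize(torch.randn(5, 16), p=2, dim=1)   # unit hyperplane normals
h_perp = h - w * torch.sum(h * w, dim=1, keepdim=True)
assert torch.allclose(torch.sum(h_perp * w, dim=1), torch.zeros(5), atol=1e-5)
```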
-------------------------------------------------------------------------------- /krl/models/TransR.py: --------------------------------------------------------------------------------

1 | """
2 | Reference:
3 | 
4 | - https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/openke/module/model/TransR.py
5 | - https://github.com/zqhead/TransR
6 | - https://github.com/Sujit-O/pykg2vec/blob/master/pykg2vec/models/pairwise.py
7 | - https://github.com/nju-websoft/muKG/blob/main/src/torch/kge_models/TransR.py
8 | 
9 | Note: although this implementation runs, I don't know why it gets a very low hits@10 metric.
10 | """
11 | 
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | from torch.utils.data import DataLoader
16 | 
17 | from ..base_model import TransXBaseModel, ModelMain
18 | from ..config import HyperParam, LocalDatasetConf, TrainConf
19 | from ..dataset import create_mapping, LocalKRLDataset
20 | from ..negative_sampler import BernNegSampler
21 | from .. import utils
22 | from ..trainer import TransETrainer
23 | from ..metric import MetricEnum
24 | from ..evaluator import RankEvaluator
25 | from .. import storage
26 | from ..metric_fomatter import StringFormatter
27 | from ..serializer import FileSerializer
28 | 
29 | class TransRHyperParam(HyperParam):
30 |     """Hyper-parameters of TransR
31 |     """
32 |     embed_dim: int
33 |     norm: int
34 |     margin: float
35 |     C: float
36 | 
37 | 
38 | class TransR(TransXBaseModel):
39 |     def __init__(self,
40 |                  ent_num: int,
41 |                  rel_num: int,
42 |                  device: torch.device,
43 |                  hyper_params: TransRHyperParam
44 |     ):
45 |         super().__init__()
46 |         self.ent_num = ent_num
47 |         self.rel_num = rel_num
48 |         self.device = device
49 |         self.norm = hyper_params.norm
50 |         self.embed_dim = hyper_params.embed_dim
51 |         self.C = hyper_params.C
52 | 
53 |         self.margin = hyper_params.margin
54 |         self.epsilon = 2.0
55 |         self.embedding_range = (self.margin + self.epsilon) / self.embed_dim
56 | 
57 |         # initialize ent_embedding as in the original paper
58 |         self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim)
59 |         nn.init.uniform_(
60 |             tensor=self.ent_embedding.weight.data,
61 |             a=-self.embedding_range,
62 |             b=self.embedding_range
63 |         )
64 |         # torch.nn.init.xavier_uniform_(self.ent_embedding.weight.data)
65 | 
66 |         # initialize rel_embedding
67 |         self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim)
68 |         nn.init.uniform_(
69 |             tensor=self.rel_embedding.weight.data,
70 |             a=-self.embedding_range,
71 |             b=self.embedding_range
72 |         )
73 |         # torch.nn.init.xavier_uniform_(self.rel_embedding.weight.data)
74 | 
75 |         # initialize the transfer matrix
76 |         self.transfer_matrix = nn.Embedding(self.rel_num, self.embed_dim * self.embed_dim)
77 |         nn.init.uniform_(
78 |             tensor=self.transfer_matrix.weight.data,
79 |             a=-self.embedding_range,
80 |             b=self.embedding_range
81 |         )
82 |         # nn.init.xavier_uniform_(self.transfer_matrix.weight.data)
83 | 
84 |         self.dist_fn = nn.PairwiseDistance(p=self.norm)  # the function for calculating the distance
85 |         self.margin_loss_fn = nn.MarginRankingLoss(margin=self.margin)
86 | 
87 |     def _transfer(self, ent_embs: torch.Tensor, rels_transfer: torch.Tensor):
88 |         """Transfer entity embeddings into the relation-specific space
89 | 
90 |         :param ent_embs: [batch, embed_dim]
91 |         :param rels_transfer: [batch, embed_dim * embed_dim]
92 |         """
93 |         assert ent_embs.size(0) == rels_transfer.size(0)
94 |         assert ent_embs.size(1) == self.embed_dim
95 |         assert rels_transfer.size(1) == self.embed_dim * self.embed_dim
96 | 
97 |         rels_transfer = rels_transfer.view(-1, self.embed_dim, self.embed_dim)  # [batch, embed_dim, embed_dim]
98 |         ent_embs = ent_embs.view(-1, 1, self.embed_dim)  # [batch, 1, embed_dim]
99 | 
100 |         ent_proj = torch.matmul(ent_embs, rels_transfer)  # [batch, 1, embed_dim]
101 |         return ent_proj.view(-1, self.embed_dim)  # [batch, embed_dim]
102 | 
103 |     def embed(self, triples):
104 |         """get the embedding of triples
105 | 
106 |         :param triples: [heads, rels, tails]
107 |         :return: embedding of triples.
108 |         """
109 |         assert triples.shape[1] == 3
110 |         heads = triples[:, 0]
111 |         rels = triples[:, 1]
112 |         tails = triples[:, 2]
113 |         # id -> embedding
114 |         h_embs = self.ent_embedding(heads)  # h_embs: [batch, embed_dim]
115 |         r_embs = self.rel_embedding(rels)
116 |         t_embs = self.ent_embedding(tails)
117 |         rels_transfer = self.transfer_matrix(rels)
118 |         # transfer the entity embeddings from entity space into the relation-specific space
119 |         h_embs = self._transfer(h_embs, rels_transfer)
120 |         t_embs = self._transfer(t_embs, rels_transfer)
121 |         return h_embs, r_embs, t_embs
122 | 
123 |     def _distance(self, triples):
124 |         """Compute the distance for a batch of triples
125 | 
126 |         :param triples: a batch of triples, size: [batch, 3]
127 |         :return: size: [batch,]
128 |         """
129 |         h_embs, r_embs, t_embs = self.embed(triples)
130 |         return self.dist_fn(h_embs + r_embs, t_embs)
131 | 
132 |     def _cal_margin_base_loss(self, pos_distances, neg_distances):
133 |         """Calculate the margin-based loss
134 | 
135 |         :param pos_distances: [batch, ]
136 |         :param neg_distances: [batch, ]
137 |         :return: margin-based loss, size: [1,]
138 |         """
139 |         ones = torch.tensor([-1], dtype=torch.long, device=self.device)
140 |         return self.margin_loss_fn(pos_distances, neg_distances, ones)
141 | 
142 |     def _cal_scale_loss(self, embedding: nn.Embedding):
143 |         """Calculate the scale loss.
144 |         F.relu(x) is equal to max(x, 0).
145 |         """
146 |         norm = torch.norm(embedding.weight, p=2, dim=1)  # the L2 norm of each embedding, size: [num_embeddings, ]
147 |         scale_loss = torch.sum(F.relu(norm - 1))
148 |         return scale_loss
149 | 
150 |     def loss(self, pos_distances, neg_distances):
151 |         """Calculate the loss
152 | 
153 |         :param pos_distances: [batch, ]
154 |         :param neg_distances: [batch, ]
155 |         :return: loss
156 |         """
157 |         margin_based_loss = self._cal_margin_base_loss(pos_distances, neg_distances)
158 |         ent_scale_loss = self._cal_scale_loss(self.ent_embedding)
159 |         rel_scale_loss = self._cal_scale_loss(self.rel_embedding)
160 |         return margin_based_loss + self.C * ((ent_scale_loss + rel_scale_loss) / (self.ent_num + self.rel_num))
161 | 
162 |     def forward(self, pos_triples: torch.Tensor, neg_triples: torch.Tensor):
163 |         """Return model losses based on the input.
164 | 
165 |         :param pos_triples: triplets of positives in Bx3 shape (B - batch, 3 - head, relation and tail)
166 |         :param neg_triples: triplets of negatives in Bx3 shape (B - batch, 3 - head, relation and tail)
167 |         :return: tuple of the model loss, positive triplets loss component, negative triples loss component
168 |         """
169 |         assert pos_triples.size()[1] == 3
170 |         assert neg_triples.size()[1] == 3
171 | 
172 |         pos_distances = self._distance(pos_triples)
173 |         neg_distances = self._distance(neg_triples)
174 |         loss = self.loss(pos_distances, neg_distances)
175 |         return loss, pos_distances, neg_distances
176 | 
177 |     def predict(self, triples: torch.Tensor):
178 |         """Calculate the dissimilarity score for the given triples.
179 | 
180 |         :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail)
181 |         :return: dissimilarity score for the given triples
182 |         """
183 |         return self._distance(triples)
184 | 
185 | 
186 | class TransRMain(ModelMain):
187 | 
188 |     def __init__(
189 |         self,
190 |         dataset_conf: LocalDatasetConf,
191 |         train_conf: TrainConf,
192 |         hyper_params: TransRHyperParam,
193 |         device: torch.device
194 |     ) -> None:
195 |         super().__init__()
196 |         self.dataset_conf = dataset_conf
197 |         self.train_conf = train_conf
198 |         self.hyper_params = hyper_params
199 |         self.device = device
200 | 
201 |     def __call__(self):
202 |         # create mapping
203 |         entity2id, rel2id = create_mapping(self.dataset_conf)
204 |         ent_num = len(entity2id)
205 |         rel_num = len(rel2id)
206 | 
207 |         # create dataset and dataloader
208 |         train_dataset, train_dataloader, valid_dataset, valid_dataloader = utils.create_local_dataloader(self.dataset_conf, self.hyper_params, entity2id, rel2id)
209 | 
210 |         # create negative-sampler
211 |         neg_sampler = BernNegSampler(train_dataset)  # BernNegSampler only takes the dataset
212 | 
213 |         # create model
214 |         model = TransR(ent_num, rel_num, self.device, self.hyper_params)
215 |         model = model.to(self.device)
216 | 
217 |         # create optimizer
218 |         optimizer = utils.create_optimizer(self.hyper_params.optimizer, model, self.hyper_params.learning_rate)
219 | 
220 |         # create trainer
221 |         trainer = TransETrainer(
222 |             model=model,
223 |             train_conf=self.train_conf,
224 |             params=self.hyper_params,
225 |             dataset_conf=self.dataset_conf,
226 |             entity2id=entity2id,
227 |             rel2id=rel2id,
228 |             device=self.device,
229 |             train_dataloder=train_dataloader,
230 |             valid_dataloder=valid_dataloader,
231 |             train_neg_sampler=neg_sampler,
232 |             valid_neg_sampler=neg_sampler,
233 |             optimzer=optimizer
234 |         )
235 | 
236 |         # training process
237 |         trainer.run_training()
238 | 
239 |         # create evaluator
240 |         metrics = [
241 |             MetricEnum.MRR,
242 |             MetricEnum.HITS_AT_1,
243 |             MetricEnum.HITS_AT_3,
244 |             MetricEnum.HITS_AT_10
245 |         ]
246 |         evaluator = RankEvaluator(self.device, metrics)
247 | 
248 |         # test the best checkpoint on the test dataset
249 |         # load the best model
250 |         ckpt = storage.load_checkpoint(self.train_conf)
251 |         model.load_state_dict(ckpt.model_state_dict)
252 |         model = model.to(self.device)
253 |         # create test-dataset
254 |         test_dataset = LocalKRLDataset(self.dataset_conf, 'test', entity2id, rel2id)
255 |         test_dataloader = DataLoader(test_dataset, self.hyper_params.valid_batch_size)
256 |         # run inference on test-dataset
257 |         metric = trainer.run_inference(test_dataloader, ent_num, evaluator)
258 | 
259 |         # choose the metric formatter
260 |         metric_formatter = StringFormatter()
261 | 
262 |         # choose the serialization method
263 |         serializer = FileSerializer(self.train_conf, self.dataset_conf)
264 |         # serialize the metric
265 |         serializer.serialize(metric, metric_formatter)
266 | 
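For reference, the row-vector convention used by `_transfer` above, written out explicitly (editor-added toy check):

```python
# Sketch: TransR's projection e_r = e . M_r, computed as a batched matmul.
import torch

batch, dim = 3, 6
e = torch.randn(batch, dim)
M = torch.randn(batch, dim, dim)   # stored flattened as dim * dim in the model
proj = torch.matmul(e.view(batch, 1, dim), M).view(batch, dim)
expected = torch.einsum('bi,bij->bj', e, M)
assert torch.allclose(proj, expected, atol=1e-5)
```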
-------------------------------------------------------------------------------- /krl/models/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/yubinCloud/kg2vec/c1e291031530933e5a381c2bdbf1f17e655e041e/krl/models/__init__.py
-------------------------------------------------------------------------------- /krl/negative_sampler.py: --------------------------------------------------------------------------------

1 | """
2 | The sampler used to obtain negative samples for KRL.
3 | """
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from abc import ABC, abstractmethod
7 | 
8 | from .dataset import KRLDataset
9 | 
10 | 
11 | class NegativeSampler(ABC):
12 |     def __init__(self, dataset: KRLDataset):
13 |         self.dataset = dataset
14 | 
15 |     @abstractmethod
16 |     def neg_sample(self, heads, rels, tails):
17 |         """Perform negative sampling.
18 | 
19 |         :param heads: a tensor of batch_size head indices, size: [batch_size]
20 |         :param rels: size [batch_size]
21 |         :param tails: size [batch_size]
22 |         """
23 |         pass
24 | 
25 | 
26 | class RandomNegativeSampler(NegativeSampler):
27 |     """
28 |     Sample negatives by randomly replacing either the head or the tail.
29 |     """
30 |     def __init__(self, dataset: KRLDataset):
31 |         super().__init__(dataset)
32 | 
33 |     def neg_sample(self, heads, rels, tails):
34 |         device = heads.device
35 |         ent_num = len(self.dataset.entity2id)
36 |         head_or_tail = torch.randint(high=2, size=heads.size(), device=device)
37 |         random_entities = torch.randint(high=ent_num, size=heads.size(), device=device)
38 |         corrupted_heads = torch.where(head_or_tail == 1, random_entities, heads)
39 |         corrupted_tails = torch.where(head_or_tail == 0, random_entities, tails)
40 |         return torch.stack([corrupted_heads, rels, corrupted_tails], dim=1)
41 | 
42 | 
43 | class BernNegSampler(NegativeSampler):
44 |     """
45 |     Uses a Bernoulli distribution to decide whether to replace the head entity or the tail entity.
46 |     The specific sampling process follows the TransH paper and the implementations referenced there.
47 |     """
48 |     def __init__(self,
49 |                  dataset: KRLDataset):
50 |         """init function
51 | 
52 |         :param dataset: the KRLDataset to draw negative samples for
53 | 
54 |         """
55 |         super().__init__(dataset)
56 |         self.entity2id = dataset.meta.entity2id
57 |         self.rel2id = dataset.meta.rel2id
58 |         self.ent_num = len(self.entity2id)
59 |         self.rel_num = len(self.rel2id)
60 | 
61 |         self.probs_of_replace_head = self._cal_tph_and_hpt()  # probability of replacing the head when sampling
62 |         assert self.probs_of_replace_head.shape[0] == self.rel_num
63 | 
64 |     def _cal_tph_and_hpt(self):
65 |         dataloader = DataLoader(self.dataset, batch_size=1)
66 |         r_h_matrix = torch.zeros([self.rel_num, self.ent_num])  # [i, j]: how many tail entities occur with relation r_i and head h_j
67 |         r_t_matrix = torch.zeros([self.rel_num, self.ent_num])  # [i, j]: how many head entities occur with relation r_i and tail t_j
68 |         for batch in iter(dataloader):
69 |             h, r, t = batch[0], batch[1], batch[2]
70 |             h = h.item()
71 |             r = r.item()
72 |             t = t.item()
73 |             r_h_matrix[r, h] += 1
74 |             r_t_matrix[r, t] += 1
75 |         tph = torch.sum(r_h_matrix, dim=1) / torch.sum(r_h_matrix != 0, dim=1)
76 |         tph.nan_to_num_(1)  # fill NaN with 1
77 |         hpt = torch.sum(r_t_matrix, dim=1) / torch.sum(r_t_matrix != 0, dim=1)
78 |         hpt.nan_to_num_(1)
79 |         probs_of_replace_head = tph / (tph + hpt)
80 |         probs_of_replace_head.nan_to_num_(0.5)
81 |         return probs_of_replace_head
82 | 
83 |     def neg_sample(self, heads, rels, tails):
84 |         device = heads.device
85 |         batch_size = heads.shape[0]
86 |         rands = torch.rand([batch_size])
87 |         probs = self.probs_of_replace_head[rels.cpu()]
88 |         head_or_tail = (rands < probs).to(device)  # True means replace the head, False means replace the tail
89 |         random_entities = torch.randint(high=self.ent_num, size=heads.size(), device=device)
90 |         corrupted_heads = torch.where(head_or_tail == True, random_entities, heads)
91 |         corrupted_tails = torch.where(head_or_tail == False, random_entities, tails)
92 |         return torch.stack([corrupted_heads, rels, corrupted_tails], dim=1)
93 | 
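To make the tph/hpt statistic above concrete, here is a toy example (hypothetical triples, editor-added): relations with many tails per head get their heads corrupted with higher probability.

```python
# Sketch: the tails-per-head / heads-per-tail statistic behind BernNegSampler.
import torch

rel_num, ent_num = 2, 4
triples = [(0, 0, 1), (0, 0, 2), (0, 0, 3),   # relation 0: one head, many tails
           (1, 1, 0), (2, 1, 0), (3, 1, 0)]   # relation 1: many heads, one tail
r_h = torch.zeros(rel_num, ent_num)
r_t = torch.zeros(rel_num, ent_num)
for h, r, t in triples:
    r_h[r, h] += 1
    r_t[r, t] += 1
tph = r_h.sum(dim=1) / (r_h != 0).sum(dim=1)   # tails per head
hpt = r_t.sum(dim=1) / (r_t != 0).sum(dim=1)   # heads per tail
p_replace_head = tph / (tph + hpt)
print(p_replace_head)  # tensor([0.7500, 0.2500]): corrupt heads of relation 0 more often
```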
-------------------------------------------------------------------------------- /krl/serializer.py: --------------------------------------------------------------------------------

1 | from abc import ABC, abstractmethod
2 | 
3 | from .config import TrainConf, DatasetConf
4 | from .metric import KRLMetricBase, RankMetric
5 | from .metric_fomatter import MetricFormatter, StringFormatter
6 | 
7 | 
8 | class BaseSerializer(ABC):
9 |     """
10 |     Serialize the metrics.
11 |     Every serializer class should derive from this class.
12 |     """
13 |     def __init__(self, train_conf: TrainConf, dataset_conf: DatasetConf) -> None:
14 |         super().__init__()
15 |         self.train_conf = train_conf
16 |         self.dataset_conf = dataset_conf
17 | 
18 |     @abstractmethod
19 |     def serialize(self, metric: KRLMetricBase, formatter: MetricFormatter) -> bool:
20 |         pass
21 | 
22 | 
23 | class FileSerializer(BaseSerializer):
24 |     """
25 |     Serialize the metric into a local file.
26 |     """
27 |     def serialize(self, metric: RankMetric, formatter: StringFormatter) -> bool:
28 |         """
29 |         Serialize the string metric into a file.
30 | 
31 |         :param metric: the metric instance that you want to serialize.
32 |         :param formatter: We will use this formatter to convert the metric instance into a string.
33 |         :return: success or not.
34 |         """
35 |         result = formatter.convert(metric, self.dataset_conf)
36 |         with open(self.train_conf.metric_result_path, 'w') as f:
37 |             f.write(result)
38 |         return True
39 | 
-------------------------------------------------------------------------------- /krl/storage.py: --------------------------------------------------------------------------------

1 | 
2 | from pydantic import BaseModel
3 | import torch
4 | 
5 | from .config import TrainConf, HyperParam
6 | from .base_model import KRLModel
7 | 
8 | class CheckpointFormat(BaseModel):
9 |     model_state_dict: dict
10 |     optim_state_dict: dict
11 |     epoch_id: int
12 |     best_score: float
13 |     hyper_params: dict
14 | 
15 | 
16 | def save_checkpoint(model: KRLModel,
17 |                     optimzer: torch.optim.Optimizer,
18 |                     epoch_id: int,
19 |                     best_score: float,
20 |                     hyper_params: HyperParam,
21 |                     train_conf: TrainConf):
22 |     ckpt = CheckpointFormat(
23 |         model_state_dict=model.state_dict(),
24 |         optim_state_dict=optimzer.state_dict(),
25 |         epoch_id=epoch_id,
26 |         best_score=best_score,
27 |         hyper_params=hyper_params.dict()
28 |     )
29 |     torch.save(ckpt.dict(), train_conf.checkpoint_path)
30 | 
31 | 
32 | def load_checkpoint(train_conf: TrainConf) -> CheckpointFormat:
33 |     ckpt = torch.load(train_conf.checkpoint_path)
34 |     return CheckpointFormat.parse_obj(ckpt)
35 | 
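Usage sketch for the checkpoint round-trip (editor-added; the paths are hypothetical, and the `TrainConf` fields are the ones the typer apps below construct):

```python
# Sketch: save the best checkpoint during training, restore it for testing.
# model, optimizer, epoch_id, best_score and hyper_params come from the caller.
from krl import storage
from krl.config import TrainConf

train_conf = TrainConf(
    checkpoint_path='/tmp/transe_fb15k.ckpt',          # hypothetical path
    metric_result_path='/tmp/transe_fb15k_metrics.txt' # hypothetical path
)

# inside the validation hook, whenever the score improves:
# storage.save_checkpoint(model, optimizer, epoch_id, best_score, hyper_params, train_conf)

# after training, restore the best weights:
# ckpt = storage.load_checkpoint(train_conf)
# model.load_state_dict(ckpt.model_state_dict)
```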
-------------------------------------------------------------------------------- /krl/trainer.py: --------------------------------------------------------------------------------

1 | """
2 | KRLTrainer for training and testing models.
3 | """
4 | from typing import Mapping
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from rich.progress import Progress as RichProgress
8 | from abc import ABC, abstractmethod
9 | 
10 | from .base_model import KRLModel
11 | from .config import TrainConf, HyperParam, DatasetConf
12 | from .negative_sampler import NegativeSampler
13 | from . import storage
14 | from .evaluator import RankEvaluator
15 | from .metric import RankMetric, MetricEnum
16 | 
17 | 
18 | 
19 | class KRLTrainer(ABC):
20 |     def __init__(self,
21 |                  model: KRLModel,
22 |                  train_conf: TrainConf,
23 |                  params: HyperParam,
24 |                  dataset_conf: DatasetConf,
25 |                  entity2id: Mapping[str, int],
26 |                  rel2id: Mapping[str, int],
27 |                  device: torch.device,
28 |                  train_dataloader: DataLoader,
29 |                  valid_dataloader: DataLoader,
30 |                  train_neg_sampler: NegativeSampler,
31 |                  valid_neg_sampler: NegativeSampler,
32 |                  optimzer: torch.optim.Optimizer
33 |     ) -> None:
34 |         self.model = model
35 |         self.train_conf = train_conf
36 |         self.params = params
37 |         self.dataset_conf = dataset_conf
38 |         self.entity2id = entity2id
39 |         self.rel2id = rel2id
40 |         self.device = device
41 |         self.train_dataloader = train_dataloader
42 |         self.valid_dataloader = valid_dataloader
43 |         self.train_neg_sampler = train_neg_sampler
44 |         self.valid_neg_sampler = valid_neg_sampler
45 |         self.optimzer = optimzer
46 | 
47 |     def run_inference(self,
48 |                       dataloader: DataLoader,
49 |                       ent_num: int,
50 |                       evaluator: RankEvaluator,
51 |     ) -> RankMetric:
52 |         """
53 |         Run the inference process on the KRL model.
54 | 
55 |         Override this function if your model needs a different inference procedure.
56 |         The implementation here ranks every entity as a candidate head and tail, the standard evaluation protocol for TransE-style models.
57 |         """
58 |         model = self.model
59 |         device = self.device
60 | 
61 |         # entity_ids = [[0, 1, 2, ..., ent_num]], shape: [1, ent_num]
62 |         entity_ids = torch.arange(0, ent_num, device=device).unsqueeze(0)
63 |         for i, batch in enumerate(dataloader):
64 |             # batch: [3, batch_size]
65 |             heads, rels, tails = batch[0].to(device), batch[1].to(device), batch[2].to(device)
66 |             batch_size = heads.size()[0]
67 |             all_entities = entity_ids.repeat(batch_size, 1)  # all_entities: [batch_size, ent_num]
68 |             # heads: [batch_size,] -> [batch_size, 1] -> [batch_size, ent_num]
69 |             heads_expanded = heads.reshape(-1, 1).repeat(1, ent_num)  # _expanded: [batch_size, ent_num]
70 |             rels_expanded = rels.reshape(-1, 1).repeat(1, ent_num)
71 |             tails_expanded = tails.reshape(-1, 1).repeat(1, ent_num)
72 |             # check all possible tails
73 |             triplets = torch.stack([heads_expanded, rels_expanded, all_entities], dim=2).reshape(-1, 3)  # triplets: [batch_size * ent_num, 3]
74 |             tails_predictions = model.predict(triplets).reshape(batch_size, -1)  # tails_predictions: [batch_size, ent_num]
75 |             # check all possible heads
76 |             triplets = torch.stack([all_entities, rels_expanded, tails_expanded], dim=2).reshape(-1, 3)
77 |             heads_predictions = model.predict(triplets).reshape(batch_size, -1)  # heads_predictions: [batch_size, ent_num]
78 | 
79 |             # concatenate the predictions
80 |             predictions = torch.cat([tails_predictions, heads_predictions], dim=0)  # predictions: [batch_size * 2, ent_num]
81 |             ground_truth_entity_id = torch.cat([tails.reshape(-1, 1), heads.reshape(-1, 1)], dim=0)  # [batch_size * 2, 1]
82 |             # calculate metrics
83 |             evaluator.evaluate(predictions, ground_truth_entity_id)
84 | 
85 |         metric_result = evaluator.export_metrics()  # get the result from the evaluator
86 |         return metric_result
87 | 
88 |     @abstractmethod
89 |     def run_training(self):
90 |         """
91 |         Run the training process on the KRL model.
92 | 
93 |         Override this function if you need more logic for training the model.
94 |         The implementation here just provides an example of training TransE.
95 |         """
96 |         pass
97 | 
98 | 
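`run_inference` feeds `(predictions, ground_truth_entity_id)` pairs to the evaluator. The rank logic this relies on boils down to the following editor-added toy sketch (RankEvaluator's exact tie-handling may differ):

```python
# Sketch: predictions hold a dissimilarity per candidate entity (lower is
# better), so the rank of the true entity is 1 + the number of candidates
# scored strictly better than it.
import torch

predictions = torch.tensor([[0.3, 0.1, 0.9, 0.2]])   # toy scores for 4 entities
ground_truth = torch.tensor([[3]])                    # true entity id is 3
gt_scores = predictions.gather(1, ground_truth)       # score of the true entity
ranks = 1 + (predictions < gt_scores).sum(dim=1)      # -> tensor([2])
hits_at_1 = (ranks <= 1).float().mean()               # -> 0.0
mrr = (1.0 / ranks.float()).mean()                    # -> 0.5
```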
99 | class TransETrainer(KRLTrainer):
100 |     """
101 |     Trainer for training TransE and other similar models.
102 |     """
103 |     def __init__(
104 |         self,
105 |         model: KRLModel,
106 |         train_conf: TrainConf,
107 |         params: HyperParam,
108 |         dataset_conf: DatasetConf,
109 |         entity2id: Mapping[str, int],
110 |         rel2id: Mapping[str, int],
111 |         device: torch.device,
112 |         train_dataloder: DataLoader,
113 |         valid_dataloder: DataLoader,
114 |         train_neg_sampler: NegativeSampler,
115 |         valid_neg_sampler: NegativeSampler,
116 |         optimzer: torch.optim.Optimizer
117 |     ) -> None:
118 |         super().__init__(model, train_conf, params, dataset_conf, entity2id, rel2id, device, train_dataloder, valid_dataloder, train_neg_sampler, valid_neg_sampler, optimzer)
119 | 
120 |     def run_training(self):
121 |         """
122 |         Run the training process on the TransE model and other similar models, for example, TransH.
123 | 
124 |         Override this function if you need more logic for training the model.
125 |         The implementation here just provides an example of training TransE.
126 |         """
127 |         device = self.device
128 |         optimzer = self.optimzer
129 |         model = self.model
130 |         DATASET_LEN = len(self.train_dataloader.dataset)
131 |         # prepare the tools for training
132 |         best_score = 0.0
133 |         evaluator = RankEvaluator(self.device, [MetricEnum.HITS_AT_10])
134 |         # training loop
135 |         with RichProgress() as rich_progress:
136 |             train_task = rich_progress.add_task('[green]Total training...', total=self.params.epoch_size)
137 |             for epoch_id in range(1, self.params.epoch_size + 1):
138 |                 epoch_task = rich_progress.add_task(f'[cyan]Epoch {epoch_id}', total=DATASET_LEN)
139 |                 model.train()
140 |                 for i, batch in enumerate(self.train_dataloader):
141 |                     # get a batch of training data
142 |                     pos_heads, pos_rels, pos_tails = batch[0].to(device), batch[1].to(device), batch[2].to(device)
143 |                     pos_triples = torch.stack([pos_heads, pos_rels, pos_tails], dim=1)  # pos_triples: [batch_size, 3]
144 |                     neg_triples = self.train_neg_sampler.neg_sample(pos_heads, pos_rels, pos_tails)  # neg_triples: [batch_size, 3]
145 |                     optimzer.zero_grad()
146 |                     # calculate loss
147 |                     loss, _, _ = model(pos_triples, neg_triples)
148 |                     loss.backward()
149 |                     # update model
150 |                     optimzer.step()
151 |                     rich_progress.update(epoch_task, advance=pos_triples.size(0))
152 |                 rich_progress.remove_task(epoch_task)
153 | 
154 |                 if epoch_id % self.params.valid_freq == 0:
155 |                     model.eval()
156 |                     with torch.no_grad():
157 |                         ent_num = len(self.entity2id)
158 |                         evaluator.clear()  # clear the evaluator
159 |                         metric = self.run_inference(self.valid_dataloader, ent_num, evaluator)
160 |                         hits_at_10 = metric.hits_at_10
161 |                         if hits_at_10 > best_score:
162 |                             best_score = hits_at_10
163 |                             print('best score of valid: ', best_score)
164 |                             storage.save_checkpoint(model, optimzer, epoch_id, best_score, self.params, self.train_conf)
165 | 
166 |                 rich_progress.update(train_task, advance=1)
167 | 
168 | 
169 | class RescalTrainer(KRLTrainer):
170 |     """
171 |     Trainer for training RESCAL and other similar models.
172 | """ 173 | def __init__(self, model: KRLModel, train_conf: TrainConf, params: HyperParam, dataset_conf: DatasetConf, entity2id: Mapping[str, int], rel2id: Mapping[str, int], device: torch.device, train_dataloder: DataLoader, valid_dataloder: DataLoader, train_neg_sampler: NegativeSampler, valid_neg_sampler: NegativeSampler, optimzer: torch.optim.Optimizer) -> None: 174 | super().__init__(model, train_conf, params, dataset_conf, entity2id, rel2id, device, train_dataloder, valid_dataloder, train_neg_sampler, valid_neg_sampler, optimzer) 175 | 176 | def run_training(self): 177 | device = self.device 178 | optimzer = self.optimzer 179 | model = self.model 180 | DATASET_LEN = len(self.train_dataloader.dataset) 181 | # prepare tools for training 182 | best_score = 0.0 183 | evaluator = RankEvaluator(self.device, [MetricEnum.HITS_AT_10]) 184 | # training loop 185 | with RichProgress() as rich_progress: 186 | train_task = rich_progress.add_task('[green]Total training...', total=self.params.epoch_size) 187 | for epoch_id in range(1, self.params.epoch_size + 1): 188 | epoch_task = rich_progress.add_task(f'[cyan]Epoch {epoch_id}', total=DATASET_LEN) 189 | loss_sum = 0 190 | model.train() 191 | for batch in iter(self.train_dataloader): 192 | pos_heads, pos_rels, pos_tails = batch[0].to(device), batch[1].to(device), batch[2].to(device) 193 | pos_triples = torch.stack([pos_heads, pos_rels, pos_tails], dim=1) # pos_triples: [batch_size, 3] 194 | neg_triples = self.train_neg_sampler.neg_sample(pos_heads, pos_rels, pos_tails) # neg_triples: [batch_size, 3] 195 | triples = torch.cat([pos_triples, neg_triples]) 196 | pos_num = pos_triples.size(0) 197 | total_num = triples.size(0) 198 | labels = torch.zeros([total_num], device=device) 199 | labels[0: pos_num] = 1 # the pos_triple label is equal to 1. 200 | shuffle_index = torch.randperm(total_num, device=device) # index sequence for shuffling data 201 | triples = triples[shuffle_index] 202 | labels = labels[shuffle_index] 203 | # calculate loss 204 | optimzer.zero_grad() 205 | loss, _ = model(triples, labels) 206 | loss.backward() 207 | loss_sum += loss.cpu().item() 208 | # update model 209 | optimzer.step() 210 | rich_progress.update(epoch_task, advance=pos_triples.size(0)) 211 | rich_progress.remove_task(epoch_task) 212 | 213 | if epoch_id % self.params.valid_freq == 0: 214 | model.eval() 215 | with torch.no_grad(): 216 | ent_num = len(self.entity2id) 217 | evaluator.clear() 218 | metric = self.run_inference(self.valid_dataloader, ent_num, evaluator) 219 | hits_at_10 = metric.hits_at_10 220 | if hits_at_10 > best_score: 221 | best_score = hits_at_10 222 | print('best score of valid: ', best_score) 223 | storage.save_checkpoint(model, optimzer, epoch_id, best_score, self.params, self.train_conf) 224 | 225 | rich_progress.update(train_task, advance=1) -------------------------------------------------------------------------------- /krl/typer_apps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yubinCloud/kg2vec/c1e291031530933e5a381c2bdbf1f17e655e041e/krl/typer_apps/__init__.py -------------------------------------------------------------------------------- /krl/typer_apps/distmult.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import typer 3 | 4 | from ..config import LocalDatasetConf, TrainConf 5 | from ..models.DistMult import DistMultHyperParam, DistMultMain 6 | from .. 
import utils
7 | 
8 | 
9 | 
10 | app = typer.Typer()
11 | 
12 | 
13 | @app.command(name='train')
14 | def train_distmult(
15 |     dataset_name: str = typer.Option(...),
16 |     base_dir: Path = typer.Option(...),
17 |     batch_size: int = typer.Option(...),
18 |     valid_batch_size: int = typer.Option(...),
19 |     valid_freq: int = typer.Option(50),
20 |     lr: float = typer.Option(0.001),
21 |     epoch_size: int = typer.Option(...),
22 |     optimizer: str = typer.Option('adam'),
23 |     embed_dim: int = typer.Option(...),
24 |     alpha: float = typer.Option(0.001, help='regularization parameter'),
25 |     regul_type: str = typer.Option('F2', help='regularization type, F2 or N3', case_sensitive=False),
26 |     ckpt_path: Path = typer.Option(...),
27 |     metric_result_path: Path = typer.Option(...)
28 | ):
29 |     if not base_dir.exists():
30 |         print("base_dir doesn't exist")
31 |         raise typer.Exit()
32 |     dataset_conf = LocalDatasetConf(
33 |         dataset_name=dataset_name,
34 |         base_dir=base_dir
35 |     )
36 |     train_conf = TrainConf(
37 |         checkpoint_path=ckpt_path.absolute().as_posix(),
38 |         metric_result_path=metric_result_path.absolute().as_posix()
39 |     )
40 |     hyper_params = DistMultHyperParam(
41 |         batch_size=batch_size,
42 |         valid_batch_size=valid_batch_size,
43 |         learning_rate=lr,
44 |         optimizer=optimizer,
45 |         epoch_size=epoch_size,
46 |         embed_dim=embed_dim,
47 |         valid_freq=valid_freq,
48 |         alpha=alpha,
49 |         regul_type=regul_type
50 |     )
51 |     device = utils.get_device()
52 | 
53 |     main = DistMultMain(dataset_conf, train_conf, hyper_params, device)
54 | 
55 |     main()
56 | 
-------------------------------------------------------------------------------- /krl/typer_apps/rescal.py: --------------------------------------------------------------------------------

1 | from pathlib import Path
2 | import typer
3 | 
4 | from ..config import TrainConf
5 | from ..models.RESCAL import RescalHyperParam, RescalMain
6 | from ..
import utils 7 | from ..dataset import load_krl_dataset 8 | 9 | 10 | app = typer.Typer() 11 | 12 | 13 | @app.command(name='train') 14 | def train_rescal( 15 | dataset_name: str = typer.Option('FB15k'), 16 | batch_size: int = typer.Option(256), 17 | valid_batch_size: int = typer.Option(16), 18 | valid_freq: int = typer.Option(5), 19 | lr: float = typer.Option(0.01), 20 | epoch_size: int = typer.Option(200), 21 | optimizer: str = typer.Option('adam'), 22 | embed_dim: int = typer.Option(50), 23 | alpha: float = typer.Option(0.001, help='regularization parameter'), 24 | regul_type: str = typer.Option('F2', help='regularization type, F2 or N3', case_sensitive=False), 25 | ckpt_path: Path = typer.Option(Path('/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/rescal_fb15k.ckpt')), 26 | metric_result_path: Path = typer.Option(Path('/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/rescal_fb15k_metrics.txt')) 27 | ): 28 | dataset_dict = load_krl_dataset(dataset_name) 29 | train_conf = TrainConf( 30 | checkpoint_path=ckpt_path.absolute().as_posix(), 31 | metric_result_path=metric_result_path.absolute().as_posix() 32 | ) 33 | hyper_params = RescalHyperParam( 34 | batch_size=batch_size, 35 | valid_batch_size=valid_batch_size, 36 | learning_rate=lr, 37 | optimizer=optimizer, 38 | epoch_size=epoch_size, 39 | embed_dim=embed_dim, 40 | valid_freq=valid_freq, 41 | alpha=alpha, 42 | regul_type=regul_type 43 | ) 44 | device = utils.get_device() 45 | 46 | main = RescalMain(dataset_dict, train_conf, hyper_params, device) 47 | 48 | main() 49 | -------------------------------------------------------------------------------- /krl/typer_apps/transe.py: -------------------------------------------------------------------------------- 1 | import typer 2 | from pathlib import Path 3 | 4 | from ..config import TrainConf 5 | from ..models.TransE import TransEHyperParam, TransELitMain 6 | from ..dataset import load_krl_dataset 7 | 8 | 9 | app = typer.Typer() 10 | 11 | 12 | @app.command(name='train') 13 | def train_transe( 14 | dataset_name: str = typer.Option('FB15k'), 15 | batch_size: int = typer.Option(256), 16 | valid_batch_size: int = typer.Option(16), 17 | valid_freq: int = typer.Option(5), 18 | lr: float = typer.Option(0.01), 19 | optimizer: str = typer.Option('adam'), 20 | epoch_size: int = typer.Option(500), 21 | embed_dim: int = typer.Option(100), 22 | norm: int = typer.Option(1), 23 | margin: float = typer.Option(2.0), 24 | logs_dir: Path = typer.Option('lightning_logs/'), 25 | early_stoping_patience: int = typer.Option(5) 26 | ): 27 | dataset_dict = load_krl_dataset(dataset_name) 28 | train_conf = TrainConf( 29 | logs_dir=logs_dir 30 | ) 31 | hyper_params = TransEHyperParam( 32 | batch_size=batch_size, 33 | valid_batch_size=valid_batch_size, 34 | learning_rate=lr, 35 | optimizer=optimizer, 36 | epoch_size=epoch_size, 37 | early_stoping_patience=early_stoping_patience, 38 | embed_dim=embed_dim, 39 | norm=norm, 40 | margin=margin, 41 | valid_freq=valid_freq 42 | ) 43 | 44 | main = TransELitMain(dataset_dict, train_conf, hyper_params) 45 | 46 | main() -------------------------------------------------------------------------------- /krl/typer_apps/transh.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import typer 3 | 4 | from ..config import LocalDatasetConf, TrainConf 5 | from ..models.TransH import TransHHyperParam, TransHMain 6 | from .. 
import utils
7 | 
8 | 
9 | app = typer.Typer()
10 | 
11 | 
12 | @app.command(name='train')
13 | def train_transh(
14 |     dataset_name: str = typer.Option(...),
15 |     base_dir: Path = typer.Option(...),
16 |     batch_size: int = typer.Option(128),
17 |     valid_batch_size: int = typer.Option(64),
18 |     valid_freq: int = typer.Option(5),
19 |     lr: float = typer.Option(0.001),
20 |     epoch_size: int = typer.Option(500),
21 |     optimizer: str = typer.Option('adam'),
22 |     embed_dim: int = typer.Option(50),
23 |     norm: int = typer.Option(2),
24 |     margin: float = typer.Option(1.0),
25 |     C: float = typer.Option(1.0, help='a hyper-parameter weighting the importance of soft constraints.'),
26 |     eps: float = typer.Option(1e-3, help='the $\epsilon$ in the loss function'),
27 |     ckpt_path: Path = typer.Option(...),
28 |     metric_result_path: Path = typer.Option(...)
29 | ):
30 |     if not base_dir.exists():
31 |         print("base_dir doesn't exist.")
32 |         raise typer.Exit()
33 |     # initialize all configurations
34 |     dataset_conf = LocalDatasetConf(
35 |         dataset_name=dataset_name,
36 |         base_dir=base_dir
37 |     )
38 |     train_conf = TrainConf(
39 |         checkpoint_path=ckpt_path.absolute().as_posix(),
40 |         metric_result_path=metric_result_path.absolute().as_posix()
41 |     )
42 |     hyper_params = TransHHyperParam(
43 |         batch_size=batch_size,
44 |         valid_batch_size=valid_batch_size,
45 |         learning_rate=lr,
46 |         optimizer=optimizer,
47 |         epoch_size=epoch_size,
48 |         embed_dim=embed_dim,
49 |         norm=norm,
50 |         margin=margin,
51 |         valid_freq=valid_freq,
52 |         C=C,
53 |         eps=eps
54 |     )
55 |     device = utils.get_device()
56 | 
57 |     main = TransHMain(dataset_conf, train_conf, hyper_params, device)
58 | 
59 |     main()
60 | 
-------------------------------------------------------------------------------- /krl/typer_apps/transr.py: --------------------------------------------------------------------------------

1 | from pathlib import Path
2 | import typer
3 | 
4 | from ..config import LocalDatasetConf, TrainConf
5 | from ..models.TransR import TransRHyperParam, TransRMain
6 | from .. import utils
7 | 
8 | 
9 | app = typer.Typer()
10 | 
11 | 
12 | @app.command(name='train')
13 | def train_transr(
14 |     dataset_name: str = typer.Option(...),
15 |     base_dir: Path = typer.Option(...),
16 |     batch_size: int = typer.Option(128),
17 |     valid_batch_size: int = typer.Option(64),
18 |     valid_freq: int = typer.Option(5),
19 |     lr: float = typer.Option(0.001),
20 |     epoch_size: int = typer.Option(500),
21 |     optimizer: str = typer.Option('adam'),
22 |     embed_dim: int = typer.Option(50),
23 |     norm: int = typer.Option(2),
24 |     margin: float = typer.Option(1.0),
25 |     C: float = typer.Option(1.0),
26 |     ckpt_path: Path = typer.Option(...),
27 |     metric_result_path: Path = typer.Option(...)
28 | ):
29 |     if not base_dir.exists():
30 |         print("base_dir doesn't exist.")
31 |         raise typer.Exit()
32 |     dataset_conf = LocalDatasetConf(
33 |         dataset_name=dataset_name,
34 |         base_dir=base_dir
35 |     )
36 |     train_conf = TrainConf(
37 |         checkpoint_path=ckpt_path.absolute().as_posix(),
38 |         metric_result_path=metric_result_path.absolute().as_posix()
39 |     )
40 |     hyper_params = TransRHyperParam(
41 |         batch_size=batch_size,
42 |         valid_batch_size=valid_batch_size,
43 |         learning_rate=lr,
44 |         optimizer=optimizer,
45 |         epoch_size=epoch_size,
46 |         embed_dim=embed_dim,
47 |         norm=norm,
48 |         margin=margin,
49 |         C=C,
50 |         valid_freq=valid_freq
51 |     )
52 |     device = utils.get_device()
53 | 
54 |     main = TransRMain(dataset_conf, train_conf, hyper_params, device)
55 | 
56 |     main()
-------------------------------------------------------------------------------- /krl/utils/__init__.py: --------------------------------------------------------------------------------

1 | from .device import get_device
2 | from .seed import set_seed
3 | from .data import create_local_dataloader
4 | from .optim import create_optimizer
5 | from .logs_dir import create_logs_dir
-------------------------------------------------------------------------------- /krl/utils/data.py: --------------------------------------------------------------------------------

1 | from torch.utils.data import DataLoader
2 | from typing import Mapping, Tuple
3 | 
4 | from ..config import HyperParam, LocalDatasetConf
5 | from ..dataset import LocalKRLDataset
6 | 
7 | 
8 | 
9 | def create_local_dataloader(
10 |     dataset_conf: LocalDatasetConf,
11 |     params: HyperParam,
12 |     entity2id: Mapping[str, int],
13 |     rel2id: Mapping[str, int]
14 | ) -> Tuple[LocalKRLDataset, DataLoader, LocalKRLDataset, DataLoader]:
15 |     train_dataset = LocalKRLDataset(dataset_conf, 'train', entity2id, rel2id)
16 |     train_dataloader = DataLoader(train_dataset, params.batch_size)
17 |     valid_dataset = LocalKRLDataset(dataset_conf, 'valid', entity2id, rel2id)
18 |     valid_dataloader = DataLoader(valid_dataset, params.valid_batch_size)
19 |     return train_dataset, train_dataloader, valid_dataset, valid_dataloader
20 | 
-------------------------------------------------------------------------------- /krl/utils/device.py: --------------------------------------------------------------------------------

1 | import torch
2 | 
3 | def get_device() -> torch.device:
4 |     return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
5 | 
-------------------------------------------------------------------------------- /krl/utils/logs_dir.py: --------------------------------------------------------------------------------

1 | from datetime import datetime
2 | 
3 | from ..config import TrainConf, DatasetConf
4 | from ..base_model import KRLModel
5 | 
6 | 
7 | def create_logs_dir(
8 |     model: KRLModel,
9 |     train_conf: TrainConf,
10 |     dataset_conf: DatasetConf
11 | ) -> str:
12 |     dir = train_conf.logs_dir / model.__class__.__name__ / dataset_conf.dataset_name / datetime.now().strftime(r'%y%m%d-%H%M%S')
13 |     dir.mkdir(parents=True, exist_ok=True)
14 |     return str(dir)
15 | 
/krl/utils/optim.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | import torch
3 | from torch.optim import Optimizer
4 | from typing import Union
5 | 
6 | from ..base_model import KRLModel
7 | 
8 | 
9 | class Lion(Optimizer):
10 |     r"""Implements the Lion algorithm."""
11 | 
12 |     def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
13 |         """Initialize the hyperparameters.
14 |         Args:
15 |           params (iterable): iterable of parameters to optimize or dicts defining
16 |             parameter groups
17 |           lr (float, optional): learning rate (default: 1e-4)
18 |           betas (Tuple[float, float], optional): coefficients for interpolating the update
19 |             direction and for the running average of the gradient (default: (0.9, 0.99))
20 |           weight_decay (float, optional): weight decay coefficient (default: 0)
21 |         """
22 | 
23 |         if not 0.0 <= lr:
24 |             raise ValueError('Invalid learning rate: {}'.format(lr))
25 |         if not 0.0 <= betas[0] < 1.0:
26 |             raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
27 |         if not 0.0 <= betas[1] < 1.0:
28 |             raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
29 |         defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
30 |         super().__init__(params, defaults)
31 | 
32 |     @torch.no_grad()
33 |     def step(self, closure=None):
34 |         """Performs a single optimization step.
35 |         Args:
36 |           closure (callable, optional): A closure that reevaluates the model
37 |             and returns the loss.
38 |         Returns:
39 |           the loss.
40 |         """
41 |         loss = None
42 |         if closure is not None:
43 |             with torch.enable_grad():
44 |                 loss = closure()
45 | 
46 |         for group in self.param_groups:
47 |             for p in group['params']:
48 |                 if p.grad is None:
49 |                     continue
50 | 
51 |                 # Perform stepweight decay
52 |                 p.data.mul_(1 - group['lr'] * group['weight_decay'])
53 | 
54 |                 grad = p.grad
55 |                 state = self.state[p]
56 |                 # State initialization
57 |                 if len(state) == 0:
58 |                     # Exponential moving average of gradient values
59 |                     state['exp_avg'] = torch.zeros_like(p)
60 | 
61 |                 exp_avg = state['exp_avg']
62 |                 beta1, beta2 = group['betas']
63 | 
64 |                 # Weight update
65 |                 update = exp_avg * beta1 + grad * (1 - beta1)
66 |                 p.add_(torch.sign(update), alpha=-group['lr'])
67 |                 # Decay the momentum running average coefficient
68 |                 exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
69 | 
70 |         return loss
71 | 
72 | 
73 | optimizer_map = {
74 |     'adam': optim.Adam,
75 |     'adamw': optim.AdamW,
76 |     'radam': optim.RAdam,
77 |     'sgd': optim.SGD,
78 |     'adagrad': optim.Adagrad,
79 |     'rms': optim.RMSprop,
80 |     'lion': Lion
81 | }
82 | 
83 | 
84 | def create_optimizer(optimizer_name: Union[str, Optimizer], model: KRLModel, lr: float) -> Optimizer:
85 |     """Create an optimizer from an optimizer name.
86 |     """
87 |     # if `optimizer_name` is already an Optimizer, return it directly.
88 |     if isinstance(optimizer_name, Optimizer):
89 |         return optimizer_name
90 | 
91 |     optim_klass = optimizer_map.get(optimizer_name.lower())
92 |     if optim_klass is None:
93 |         raise NotImplementedError(f'No support for {optimizer_name} optimizer')
94 |     return optim_klass(model.parameters(), lr)
95 | 
--------------------------------------------------------------------------------
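Unlike Adam, Lion keeps a single state tensor per parameter (the gradient EMA) and takes sign-based steps, which makes it cheaper in memory. A quick smoke test of the optimizer above (a sketch, not part of the repository; the linear layer is a made-up stand-in for a real KRL model):

import torch
from krl.utils.optim import Lion

layer = torch.nn.Linear(4, 2)                              # toy stand-in for a model
opt = Lion(layer.parameters(), lr=1e-4, weight_decay=0.01)
loss = layer(torch.randn(8, 4)).pow(2).mean()              # dummy objective
loss.backward()
opt.step()                                                 # decoupled weight decay + sign(update) step
opt.zero_grad()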
/krl/utils/seed.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 | 
5 | 
6 | def set_seed(seed: int):
7 |     random.seed(seed)
8 |     torch.manual_seed(seed)
9 |     if torch.cuda.is_available():
10 |         torch.cuda.manual_seed(seed)
11 |         torch.cuda.manual_seed_all(seed)
12 |     np.random.seed(seed)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | typer[all]
3 | pydantic
4 | loguru
--------------------------------------------------------------------------------
/test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [
8 |     {
9 |      "name": "stderr",
10 |      "output_type": "stream",
11 |      "text": [
12 |       "/root/miniconda3/envs/ais/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13 |       "  from .autonotebook import tqdm as notebook_tqdm\n"
14 |      ]
15 |     }
16 |    ],
17 |    "source": [
18 |     "from krl.dataset.instance import load_krl_dataset"
19 |    ]
20 |   },
21 |   {
22 |    "cell_type": "code",
23 |    "execution_count": 2,
24 |    "metadata": {},
25 |    "outputs": [
26 |     {
27 |      "name": "stderr",
28 |      "output_type": "stream",
29 |      "text": [
30 |       "Using custom data configuration VLyb--WN18-18d07f2b5874d7b1\n",
31 |       "Found cached dataset csv (/root/.cache/huggingface/datasets/VLyb___csv/VLyb--WN18-18d07f2b5874d7b1/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)\n",
32 |       "100%|██████████| 3/3 [00:00<00:00, 466.55it/s]\n"
33 |      ]
34 |     }
35 |    ],
36 |    "source": [
37 |     "wn18 = load_krl_dataset('WN18')"
38 |    ]
39 |   },
40 |   {
41 |    "cell_type": "code",
42 |    "execution_count": 6,
43 |    "metadata": {},
44 |    "outputs": [
45 |     {
46 |      "data": {
47 |       "text/plain": [
48 |        "(10, 4, 11)"
49 |       ]
50 |      },
51 |      "execution_count": 6,
52 |      "metadata": {},
53 |      "output_type": "execute_result"
54 |     }
55 |    ],
56 |    "source": [
57 |     "wn18.train[5]"
58 |    ]
59 |   },
60 |   {
61 |    "cell_type": "code",
62 |    "execution_count": null,
63 |    "metadata": {},
64 |    "outputs": [],
65 |    "source": []
66 |   }
67 |  ],
68 |  "metadata": {
69 |   "kernelspec": {
70 |    "display_name": "ais",
71 |    "language": "python",
72 |    "name": "python3"
73 |   },
74 |   "language_info": {
75 |    "codemirror_mode": {
76 |     "name": "ipython",
77 |     "version": 3
78 |    },
79 |    "file_extension": ".py",
80 |    "mimetype": "text/x-python",
81 |    "name": "python",
82 |    "nbconvert_exporter": "python",
83 |    "pygments_lexer": "ipython3",
84 |    "version": "3.9.16"
85 |   },
86 |   "orig_nbformat": 4,
87 |   "vscode": {
88 |    "interpreter": {
89 |     "hash": "e79d845ce09a116dcacebcdcf6a8ee439742a7b8247d9dc7720dd3aefde3f822"
90 |    }
91 |   }
92 |  },
93 |  "nbformat": 4,
94 |  "nbformat_minor": 2
95 | }
96 | 
--------------------------------------------------------------------------------
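The notebook above doubles as a smoke test for the bundled dataset index. The same flow in plain Python (a sketch under the assumption, suggested by the notebook output, that each split is an indexable container of `(head_id, rel_id, tail_id)` triples):

from krl.dataset.instance import load_krl_dataset

wn18 = load_krl_dataset('WN18')   # fetched and cached via HuggingFace datasets
print(wn18.train[5])              # e.g. (10, 4, 11)
for i in range(3):                # peek at the first few triples
    print(wn18.train[i])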
/transe.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# TransE\n",
8 |     "\n",
9 |     "Here we will show how to reproduce the TransE model."
10 |    ]
11 |   },
12 |   {
13 |    "cell_type": "code",
14 |    "execution_count": 1,
15 |    "metadata": {},
16 |    "outputs": [],
17 |    "source": [
18 |     "from pydantic import BaseSettings, BaseModel, Field\n",
19 |     "from typing import Optional, Literal, Tuple, Dict, List\n",
20 |     "from torch.utils.data import Dataset, DataLoader\n",
21 |     "import torch.nn as nn\n",
22 |     "import torch\n",
23 |     "import numpy as np\n",
24 |     "from abc import ABC, abstractmethod"
25 |    ]
26 |   },
27 |   {
28 |    "cell_type": "markdown",
29 |    "metadata": {},
30 |    "source": [
31 |     "## Config Class\n",
32 |     "\n",
33 |     "Define our configuration classes, which let you instantiate configuration items that match your runtime environment or your own ideas.\n",
34 |     "\n",
35 |     "With the help of configuration classes, you can change the dataset or adjust the hyperparameters of the model without changing the model logic."
36 |    ]
37 |   },
38 |   {
39 |    "cell_type": "code",
40 |    "execution_count": 3,
41 |    "metadata": {},
42 |    "outputs": [],
43 |    "source": [
44 |     "class DatasetConf(BaseSettings):\n",
45 |     "    \"\"\"\n",
46 |     "    Configuration related to the dataset.\n",
47 |     "    \"\"\"\n",
48 |     "    dataset_name: str = Field(title='name of the dataset, handy when printing logs')\n",
49 |     "    base_dir: str = Field(title='base directory of the dataset')\n",
50 |     "    entity2id_path: str = Field(default='entity2id.txt', title='file name of entity2id')\n",
51 |     "    relation2id_path: str = Field(default='relation2id.txt', title='file name of relation2id')\n",
52 |     "    train_path: str = Field(default='train.txt', title='file of the training set')\n",
53 |     "    valid_path: str = Field(default='valid.txt', title='file of the validation set')\n",
54 |     "    test_path: str = Field(default='test.txt', title='file of the testing set')\n",
55 |     "\n",
56 |     "\n",
57 |     "class HyperParam(BaseModel):\n",
58 |     "    \"\"\"\n",
59 |     "    Hyper-parameters.\n",
60 |     "    \"\"\"\n",
61 |     "    batch_size: int = 128\n",
62 |     "    valid_batch_size: int = 64\n",
63 |     "    learning_rate: float = 0.001\n",
64 |     "    epoch_size: int = 500\n",
65 |     "    embed_dim: int = 50\n",
66 |     "    norm: int = 1\n",
67 |     "    margin: float = 2.0\n",
68 |     "    valid_freq: int = Field(title='how many epochs between validation runs that decide whether to save the model')\n",
69 |     "\n",
70 |     "\n",
71 |     "class TrainConf(BaseModel):\n",
72 |     "    \"\"\"\n",
73 |     "    Training-related configuration.\n",
74 |     "    \"\"\"\n",
75 |     "    checkpoint_path: str = Field(title='path to save the model checkpoint')\n",
76 |     "    metric_result_path: str = Field(title='where to write the metrics of the test run')"
77 |    ]
78 |   },
79 |   {
80 |    "cell_type": "markdown",
81 |    "metadata": {},
82 |    "source": [
83 |     "## Dataset\n",
84 |     "\n",
85 |     "Define the classes used to read datasets, including the entity-to-ID mapping, the relation-to-ID mapping, and the collection of triples.\n",
86 |     "\n",
87 |     "+ The `create_mapping` function generates the entity-to-ID mapping dictionary and the relation-to-ID mapping dictionary.\n",
88 |     "+ `KRLDataset` is a thin wrapper around the `Dataset` class in PyTorch, and is similar in usage."
89 |    ]
90 |   },
91 |   {
92 |    "cell_type": "code",
93 |    "execution_count": 4,
94 |    "metadata": {},
95 |    "outputs": [],
96 |    "source": [
97 |     "EntityMapping = Dict[str, int]\n",
98 |     "RelMapping = Dict[str, int]\n",
99 |     "Triple = List[int]\n",
100 |     "\n",
101 |     "def create_mapping(dataset_conf: DatasetConf) -> Tuple[EntityMapping, RelMapping]:\n",
102 |     "    \"\"\"\n",
103 |     "    Create the mappings of `entity2id` and `relation2id`.\n",
104 |     "    \"\"\"\n",
105 |     "    # read entity2id\n",
106 |     "    entity2id = dict()\n",
107 |     "    with open(dataset_conf.base_dir + dataset_conf.entity2id_path) as f:\n",
108 |     "        for line in f:\n",
109 |     "            entity, entity_id = line.split()\n",
110 |     "            entity = entity.strip()\n",
111 |     "            entity_id = int(entity_id.strip())\n",
112 |     "            entity2id[entity] = entity_id\n",
113 |     "    # read relation2id\n",
114 |     "    rel2id = dict()\n",
115 |     "    with open(dataset_conf.base_dir + dataset_conf.relation2id_path) as f:\n",
116 |     "        for line in f:\n",
117 |     "            rel, rel_id = line.split()\n",
118 |     "            rel = rel.strip()\n",
119 |     "            rel_id = int(rel_id.strip())\n",
120 |     "            rel2id[rel] = rel_id\n",
121 |     "    return entity2id, rel2id\n",
122 |     "\n",
123 |     "\n",
124 |     "class KRLDataset(Dataset):\n",
125 |     "    def __init__(self,\n",
126 |     "                 dataset_conf: DatasetConf,\n",
127 |     "                 mode: Literal['train', 'valid', 'test'],\n",
128 |     "                 entity2id: Dict[str, int],\n",
129 |     "                 rel2id: Dict[str, int]) -> None:\n",
130 |     "        super().__init__()\n",
131 |     "        self.conf = dataset_conf\n",
132 |     "        self.mode = mode\n",
133 |     "        self.triples = []\n",
134 |     "        self.entity2id = entity2id\n",
135 |     "        self.rel2id = rel2id\n",
136 |     "        self._read_triples()  # read the dataset and collect all the triples\n",
137 |     "    \n",
138 |     "    def _split_and_to_id(self, line: str) -> Triple:\n",
139 |     "        \"\"\"Split one line of the dataset file and convert the entities and the relation into ids.\n",
140 |     "\n",
141 |     "        :param line: one line of the dataset file\n",
142 |     "        :return: (head_id, rel_id, tail_id)\n",
143 |     "        \"\"\"\n",
144 |     "        head, tail, rel = line.split()\n",
145 |     "        head_id = self.entity2id[head.strip()]\n",
146 |     "        rel_id = self.rel2id[rel.strip()]\n",
147 |     "        tail_id = self.entity2id[tail.strip()]\n",
148 |     "        return (head_id, rel_id, tail_id)\n",
149 |     "    \n",
150 |     "    def _read_triples(self):\n",
151 |     "        data_path = {\n",
152 |     "            'train': self.conf.train_path,\n",
153 |     "            'valid': self.conf.valid_path,\n",
154 |     "            'test': self.conf.test_path\n",
155 |     "        }.get(self.mode)\n",
156 |     "        with open(self.conf.base_dir + data_path) as f:\n",
157 |     "            self.triples = [self._split_and_to_id(line) for line in f]\n",
158 |     "    \n",
159 |     "    def __len__(self):\n",
160 |     "        \"\"\"Denotes the total number of samples.\"\"\"\n",
161 |     "        return len(self.triples)\n",
162 |     "    \n",
163 |     "    def __getitem__(self, index) -> Triple:\n",
164 |     "        \"\"\"Returns (head id, relation id, tail id).\"\"\"\n",
165 |     "        triple = self.triples[index]\n",
166 |     "        return triple[0], triple[1], triple[2]"
167 |    ]
168 |   },
169 |   {
170 |    "cell_type": "markdown",
171 |    "metadata": {},
172 |    "source": [
173 |     "## Negative Sampler\n",
174 |     "\n",
175 |     "In order to train the model, we need not only positive samples but also negative samples. The goal of the negative sampler is to generate negative samples from the positive samples in the dataset.\n",
176 |     "\n",
177 |     "Since there are multiple negative sampling strategies, we define a common abstract base class `NegativeSampler`; every negative sampler that implements a particular strategy should inherit from it."
178 |    ]
179 |   },
180 |   {
181 |    "cell_type": "code",
182 |    "execution_count": 5,
183 |    "metadata": {},
184 |    "outputs": [],
185 |    "source": [
186 |     "class NegativeSampler(ABC):\n",
187 |     "    def __init__(self, dataset: KRLDataset, device: torch.device):\n",
188 |     "        self.dataset = dataset\n",
189 |     "        self.device = device\n",
190 |     "    \n",
191 |     "    @abstractmethod\n",
192 |     "    def neg_sample(self, heads, rels, tails):\n",
193 |     "        \"\"\"Perform negative sampling.\n",
194 |     "\n",
195 |     "        :param heads: a tensor of batch_size head indices, size: [batch_size]\n",
196 |     "        :param rels: size [batch_size]\n",
197 |     "        :param tails: size [batch_size]\n",
198 |     "        \"\"\"\n",
199 |     "        pass\n"
200 |    ]
201 |   },
202 |   {
203 |    "cell_type": "markdown",
204 |    "metadata": {},
205 |    "source": [
206 |     "The simplest negative sampling strategy is to randomly replace the head entity or the tail entity of a triple to obtain a negative sample."
207 |    ]
208 |   },
209 |   {
210 |    "cell_type": "code",
211 |    "execution_count": null,
212 |    "metadata": {},
213 |    "outputs": [],
214 |    "source": [
215 |     "class RandomNegativeSampler(NegativeSampler):\n",
216 |     "    \"\"\"\n",
217 |     "    Sample negatives by randomly replacing the head or the tail.\n",
218 |     "    \"\"\"\n",
219 |     "    def __init__(self, dataset: KRLDataset, device: torch.device):\n",
220 |     "        super().__init__(dataset, device)\n",
221 |     "    \n",
222 |     "    def neg_sample(self, heads, rels, tails):\n",
223 |     "        ent_num = len(self.dataset.entity2id)\n",
224 |     "        head_or_tail = torch.randint(high=2, size=heads.size(), device=self.device)\n",
225 |     "        random_entities = torch.randint(high=ent_num, size=heads.size(), device=self.device)\n",
226 |     "        corrupted_heads = torch.where(head_or_tail == 1, random_entities, heads)\n",
227 |     "        corrupted_tails = torch.where(head_or_tail == 0, random_entities, tails)\n",
228 |     "        return torch.stack([corrupted_heads, rels, corrupted_tails], dim=1)"
229 |    ]
230 |   },
231 |   {
232 |    "cell_type": "markdown",
233 |    "metadata": {},
234 |    "source": [
235 |     "## Model\n",
236 |     "\n",
237 |     "Define the TransE model."
238 |    ]
239 |   },
240 |   {
241 |    "cell_type": "code",
242 |    "execution_count": 6,
243 |    "metadata": {},
244 |    "outputs": [],
245 |    "source": [
246 |     "class TransE(nn.Module):\n",
247 |     "    def __init__(\n",
248 |     "        self,\n",
249 |     "        ent_num: int,\n",
250 |     "        rel_num: int,\n",
251 |     "        device: torch.device,\n",
252 |     "        norm: int,\n",
253 |     "        embed_dim: int,\n",
254 |     "        margin: float\n",
255 |     "    ):\n",
256 |     "        super().__init__()\n",
257 |     "        self.ent_num = ent_num\n",
258 |     "        self.rel_num = rel_num\n",
259 |     "        self.device = device\n",
260 |     "        self.norm = norm\n",
261 |     "        self.embed_dim = embed_dim\n",
262 |     "        self.margin = margin\n",
263 |     "\n",
264 |     "        # Initialize ent_embedding\n",
265 |     "        self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim)\n",
266 |     "        torch.nn.init.xavier_uniform_(self.ent_embedding.weight.data)\n",
267 |     "        #uniform_range = 6 / np.sqrt(self.embed_dim)\n",
268 |     "        #self.ent_embedding.weight.data.uniform_(-uniform_range, uniform_range)\n",
269 |     "    \n",
270 |     "        # Initialize rel_embedding\n",
271 |     "        self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim)\n",
272 |     "        torch.nn.init.xavier_uniform_(self.rel_embedding.weight.data)\n",
273 |     "        #uniform_range = 6 / np.sqrt(self.embed_dim)\n",
274 |     "        #self.rel_embedding.weight.data.uniform_(-uniform_range, uniform_range)\n",
275 |     "\n",
276 |     "        self.criterion = nn.MarginRankingLoss(margin=self.margin)\n",
277 |     "    \n",
278 |     "    def _distance(self, triples):\n",
279 |     "        \"\"\"Calculate the distances of a batch of triples.\n",
280 |     "\n",
281 |     "        :param triples: triples of a batch, size: [batch, 3]\n",
282 |     "        :return: size: [batch,]\n",
283 |     "        \"\"\"\n",
284 |     "        heads = triples[:, 0]\n",
285 |     "        rels = triples[:, 1]\n",
286 |     "        tails = triples[:, 2]\n",
287 |     "        h_embs = self.ent_embedding(heads)  # h_embs: [batch, embed_dim]\n",
288 |     "        r_embs = self.rel_embedding(rels)\n",
289 |     "        t_embs = self.ent_embedding(tails)\n",
290 |     "        dist = h_embs + r_embs - t_embs  # [batch, embed_dim]\n",
291 |     "        return torch.norm(dist, p=self.norm, dim=1)\n",
292 |     "    \n",
293 |     "    def loss(self, pos_distances, neg_distances):\n",
294 |     "        \"\"\"Calculate the loss of TransE training.\n",
295 |     "\n",
296 |     "        :param pos_distances: [batch, ]\n",
297 |     "        :param neg_distances: [batch, ]\n",
298 |     "        :return: loss\n",
299 |     "        \"\"\"\n",
300 |     "        target = torch.tensor([-1], dtype=torch.long, device=self.device)  # -1 tells the criterion that pos_distances should be the smaller one\n",
301 |     "        return self.criterion(pos_distances, neg_distances, target)\n",
302 |     "    \n",
303 |     "    def forward(self, pos_triples: torch.Tensor, neg_triples: torch.Tensor):\n",
304 |     "        \"\"\"Return model losses based on the input.\n",
305 |     "\n",
306 |     "        :param pos_triples: triples of positives in Bx3 shape (B - batch, 3 - head, relation and tail)\n",
307 |     "        :param neg_triples: triples of negatives in Bx3 shape (B - batch, 3 - head, relation and tail)\n",
308 |     "        :return: tuple of the model loss, the positive triples loss component, and the negative triples loss component\n",
309 |     "        \"\"\"\n",
310 |     "        assert pos_triples.size()[1] == 3\n",
311 |     "        assert neg_triples.size()[1] == 3\n",
312 |     "    \n",
313 |     "        pos_distances = self._distance(pos_triples)\n",
314 |     "        neg_distances = self._distance(neg_triples)\n",
315 |     "        loss = self.loss(pos_distances, neg_distances)\n",
316 |     "        return loss, pos_distances, neg_distances\n",
317 |     "    \n",
318 |     "    def predict(self, triples: torch.Tensor):\n",
319 |     "        \"\"\"Calculate the dissimilarity scores for the given triples.\n",
320 |     "\n",
321 |     "        :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail)\n",
322 |     "        :return: dissimilarity scores for the given triples\n",
323 |     "        \"\"\"\n",
324 |     "        return self._distance(triples)"
325 |    ]
326 |   },
327 |   {
328 |    "cell_type": "markdown",
329 |    "metadata": {},
330 |    "source": [
331 |     "## Metric\n",
332 |     "\n",
333 |     "Calculate the metrics that measure the quality of link prediction, i.e., MRR and Hits@K."
334 |    ]
335 |   },
336 |   {
337 |    "cell_type": "code",
338 |    "execution_count": 7,
339 |    "metadata": {},
340 |    "outputs": [],
341 |    "source": [
342 |     "# metric\n",
343 |     "\n",
344 |     "def cal_hits_at_k(\n",
345 |     "    predictions: torch.Tensor,\n",
346 |     "    ground_truth_idx: torch.Tensor,\n",
347 |     "    device: torch.device,\n",
348 |     "    k: int\n",
349 |     ") -> float:\n",
350 |     "    \"\"\"Calculates the number of hits@k.\n",
351 |     "\n",
352 |     "    :param predictions: BxN tensor of prediction values where B is batch size and N number of classes. Predictions\n",
353 |     "    must be sorted in class ids order\n",
354 |     "    :param ground_truth_idx: Bx1 tensor with index of ground truth class\n",
355 |     "    :param k: number of top K results to be considered as hits\n",
356 |     "    :return: Hits@K score\n",
357 |     "    \"\"\"\n",
358 |     "    assert predictions.size()[0] == ground_truth_idx.size()[0]  # must have the same batch_size\n",
359 |     "    \n",
360 |     "    zero_tensor = torch.tensor([0], device=device)\n",
361 |     "    one_tensor = torch.tensor([1], device=device)\n",
362 |     "    _, indices = predictions.topk(k, largest=False)  # indices: [batch_size, k]\n",
363 |     "    where_flags = indices == ground_truth_idx  # where_flags: [batch_size, k], type: bool\n",
364 |     "    hits = torch.where(where_flags, one_tensor, zero_tensor).sum().item()\n",
365 |     "    return hits\n",
366 |     "\n",
367 |     "def cal_mrr(predictions: torch.Tensor, ground_truth_idx: torch.Tensor) -> float:\n",
368 |     "    \"\"\"Calculates the sum of reciprocal ranks for the batch (the caller averages it into MRR).\n",
369 |     "\n",
370 |     "    :param predictions: BxN tensor of prediction values where B is batch size and N number of classes. Predictions\n",
371 |     "    must be sorted in class ids order\n",
372 |     "    :param ground_truth_idx: Bx1 tensor with index of ground truth class\n",
373 |     "    :return: sum of the reciprocal rank scores\n",
374 |     "    \"\"\"\n",
375 |     "    assert predictions.size(0) == ground_truth_idx.size(0)\n",
376 |     "\n",
377 |     "    indices = predictions.argsort()\n",
378 |     "    return (1.0 / (indices == ground_truth_idx).nonzero()[:, 1].float().add(1.0)).sum().item()\n"
379 |    ]
380 |   },
381 |   {
382 |    "cell_type": "markdown",
383 |    "metadata": {},
384 |    "source": [
385 |     "## Inference Operation\n",
386 |     "\n",
387 |     "Run the inference process for the model, i.e., iterate through the validation or test set and compute the metrics."
388 |    ]
389 |   },
390 |   {
391 |    "cell_type": "code",
392 |    "execution_count": 8,
393 |    "metadata": {},
394 |    "outputs": [],
395 |    "source": [
396 |     "def run_testing(\n",
397 |     "    model: TransE,\n",
398 |     "    dataloader: DataLoader,\n",
399 |     "    ent_num: int,\n",
400 |     "    device: torch.device,\n",
401 |     ") -> Tuple[float, float, float, float]:\n",
402 |     "    \"\"\"Run the evaluation loop for a TransX model.\n",
403 |     "\n",
404 |     "    :param model: TransE model\n",
405 |     "    :param ent_num: number of entities in the dataset\n",
406 |     "    :return: Hits@1, Hits@3, Hits@10 and MRR scores\n",
407 |     "    \"\"\"\n",
408 |     "    hits_at_1 = 0.0\n",
409 |     "    hits_at_3 = 0.0\n",
410 |     "    hits_at_10 = 0.0\n",
411 |     "    mrr = 0.0\n",
412 |     "    examples_count = 0\n",
413 |     "    \n",
414 |     "    # entity_ids = [[0, 1, 2, ..., ent_num - 1]], shape: [1, ent_num]\n",
415 |     "    entity_ids = torch.arange(0, ent_num, device=device).unsqueeze(0)\n",
416 |     "    for i, batch in enumerate(dataloader):\n",
417 |     "        # batch: a list of three tensors, each of size [batch_size]\n",
418 |     "        heads, rels, tails = batch[0].to(device), batch[1].to(device), batch[2].to(device)\n",
419 |     "        batch_size = heads.size()[0]\n",
420 |     "        all_entities = entity_ids.repeat(batch_size, 1)  # all_entities: [batch_size, ent_num]\n",
421 |     "        # heads: [batch_size,] -> [batch_size, 1] -> [batch_size, ent_num]\n",
422 |     "        heads_expanded = heads.reshape(-1, 1).repeat(1, ent_num)  # _expanded: [batch_size, ent_num]\n",
423 |     "        rels_expanded = rels.reshape(-1, 1).repeat(1, ent_num)\n",
424 |     "        tails_expanded = tails.reshape(-1, 1).repeat(1, ent_num)\n",
425 |     "        # check all possible tails\n",
426 |     "        triplets = torch.stack([heads_expanded, rels_expanded, all_entities], dim=2).reshape(-1, 3)  # triplets: [batch_size * ent_num, 3]\n",
427 |     "        tails_predictions = model.predict(triplets).reshape(batch_size, -1)  # tails_predictions: [batch_size, ent_num]\n",
428 |     "        # check all possible heads\n",
429 |     "        triplets = torch.stack([all_entities, rels_expanded, tails_expanded], dim=2).reshape(-1, 3)\n",
430 |     "        heads_predictions = model.predict(triplets).reshape(batch_size, -1)  # heads_predictions: [batch_size, ent_num]\n",
431 |     "        \n",
432 |     "        # concatenate the tail and head predictions\n",
433 |     "        predictions = torch.cat([tails_predictions, heads_predictions], dim=0)  # predictions: [batch_size * 2, ent_num]\n",
434 |     "        ground_truth_entity_id = torch.cat([tails.reshape(-1, 1), heads.reshape(-1, 1)], dim=0)  # [batch_size * 2, 1]\n",
435 |     "        # calculate metrics\n",
436 |     "        hits_at_1 += cal_hits_at_k(predictions, ground_truth_entity_id, device=device, k=1)\n",
437 |     "        hits_at_3 += cal_hits_at_k(predictions, ground_truth_entity_id, device=device, k=3)\n",
438 |     "        hits_at_10 += cal_hits_at_k(predictions, ground_truth_entity_id, device=device, k=10)\n",
439 |     "        mrr += cal_mrr(predictions, ground_truth_entity_id)\n",
440 |     "        \n",
441 |     "        examples_count += predictions.size()[0]\n",
442 |     "    \n",
443 |     "    hits_at_1_score = hits_at_1 / examples_count * 100\n",
444 |     "    hits_at_3_score = hits_at_3 / examples_count * 100\n",
445 |     "    hits_at_10_score = hits_at_10 / examples_count * 100\n",
446 |     "    mrr_score = mrr / examples_count * 100\n",
447 |     "    \n",
448 |     "    return hits_at_1_score, hits_at_3_score, hits_at_10_score, mrr_score"
449 |    ]
450 |   },
451 |   {
452 |    "cell_type": "markdown",
453 |    "metadata": {},
454 |    "source": [
455 |     "## Checkpoint\n",
456 |     "\n",
457 |     "During training, whenever the model beats the best score so far on the validation set, its state should be packed into a checkpoint and saved to disk.\n",
458 |     "\n",
459 |     "The process of storing and loading checkpoints is simply encapsulated here."
460 |    ]
461 |   },
462 |   {
463 |    "cell_type": "code",
464 |    "execution_count": 9,
465 |    "metadata": {},
466 |    "outputs": [],
467 |    "source": [
468 |     "class CheckpointFormat(BaseModel):\n",
469 |     "    model_state_dict: dict\n",
470 |     "    optim_state_dict: dict\n",
471 |     "    epoch_id: int\n",
472 |     "    best_score: float\n",
473 |     "\n",
474 |     "\n",
475 |     "def save_checkpoint(model: TransE,\n",
476 |     "                    optimizer: torch.optim.Optimizer,\n",
477 |     "                    epoch_id: int,\n",
478 |     "                    best_score: float,\n",
479 |     "                    train_conf: TrainConf):\n",
480 |     "    ckpt = CheckpointFormat(\n",
481 |     "        model_state_dict=model.state_dict(),\n",
482 |     "        optim_state_dict=optimizer.state_dict(),\n",
483 |     "        epoch_id=epoch_id,\n",
484 |     "        best_score=best_score\n",
485 |     "    )\n",
486 |     "    torch.save(ckpt.dict(), train_conf.checkpoint_path)\n",
487 |     "\n",
488 |     "\n",
489 |     "def load_checkpoint(train_conf: TrainConf) -> CheckpointFormat:\n",
490 |     "    ckpt = torch.load(train_conf.checkpoint_path)\n",
491 |     "    return CheckpointFormat.parse_obj(ckpt)\n",
492 |     "    "
493 |    ]
494 |   },
495 |   {
496 |    "cell_type": "markdown",
497 |    "metadata": {},
498 |    "source": [
499 |     "## Training Operation\n",
500 |     "\n",
501 |     "The process of training a model using a dataset.\n",
502 |     "\n",
503 |     "In the real library, this part of the functionality is encapsulated in a `Trainer` class."
504 |    ]
505 |   },
506 |   {
507 |    "cell_type": "code",
508 |    "execution_count": 10,
509 |    "metadata": {},
510 |    "outputs": [],
511 |    "source": [
512 |     "def run_training(model: TransE,\n",
513 |     "                 train_conf: TrainConf,\n",
514 |     "                 params: HyperParam,\n",
515 |     "                 device: torch.device,\n",
516 |     "                 dataset_conf: DatasetConf,\n",
517 |     "                 entity2id: Dict[str, int],\n",
518 |     "                 rel2id: Dict[str, int]):\n",
519 |     "    # prepare the datasets\n",
520 |     "    train_dataset = KRLDataset(dataset_conf, 'train', entity2id, rel2id)\n",
521 |     "    valid_dataset = KRLDataset(dataset_conf, 'valid', entity2id, rel2id)\n",
522 |     "    # dataset -> dataloader\n",
523 |     "    train_dataloader = DataLoader(train_dataset, params.batch_size)\n",
524 |     "    valid_dataloader = DataLoader(valid_dataset, params.valid_batch_size)\n",
525 |     "    # negative samplers\n",
526 |     "    train_neg_sampler = RandomNegativeSampler(train_dataset, device)\n",
527 |     "    valid_neg_sampler = RandomNegativeSampler(valid_dataset, device)\n",
528 |     "    # prepare the training utilities\n",
529 |     "    optimizer = torch.optim.Adam(model.parameters(), lr=params.learning_rate)\n",
530 |     "    min_valid_loss = 10000.0\n",
531 |     "    best_score = 0.0\n",
532 |     "    # training loop\n",
533 |     "    for epoch_id in range(1, params.epoch_size + 1):\n",
534 |     "        print(\"Starting epoch: \", epoch_id)\n",
535 |     "        loss_sum = 0\n",
536 |     "        model.train()\n",
537 |     "        for i, batch in enumerate(train_dataloader):\n",
538 |     "            # fetch one batch of training data\n",
539 |     "            pos_heads, pos_rels, pos_tails = batch[0].to(device), batch[1].to(device), batch[2].to(device)\n",
540 |     "            pos_triples = torch.stack([pos_heads, pos_rels, pos_tails], dim=1)  # pos_triples: [batch_size, 3]\n",
541 |     "            neg_triples = train_neg_sampler.neg_sample(pos_heads, pos_rels, pos_tails)  # neg_triples: [batch_size, 3]\n",
542 |     "            optimizer.zero_grad()\n",
543 |     "            # compute the loss\n",
544 |     "            loss, pos_dist, neg_dist = model(pos_triples, neg_triples)\n",
545 |     "            loss.backward()\n",
546 |     "            loss_sum += loss.cpu().item()\n",
547 |     "            # update model\n",
548 |     "            optimizer.step()\n",
549 |     "        \n",
550 |     "        if epoch_id % params.valid_freq == 0:\n",
551 |     "            model.eval()\n",
552 |     "            _, _, hits_at_10, _ = run_testing(model, valid_dataloader, len(valid_dataset.entity2id), device)\n",
553 | " score = hits_at_10\n", 554 | " print('valid hits@10:', score)\n", 555 | " if score > best_score:\n", 556 | " best_score = score\n", 557 | " print('best score of valid: ', best_score)\n", 558 | " save_checkpoint(model, optimzer, epoch_id, best_score, train_conf)\n", 559 | " " 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 11, 565 | "metadata": {}, 566 | "outputs": [], 567 | "source": [ 568 | "def get_device() -> torch.device:\n", 569 | " return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", 570 | "\n", 571 | "\n", 572 | "def main(dataset_conf: DatasetConf,\n", 573 | " params: HyperParam,\n", 574 | " train_conf: TrainConf,\n", 575 | " device: torch.device):\n", 576 | " entity2id, rel2id = create_mapping(dataset_conf)\n", 577 | " device = get_device()\n", 578 | " ent_num = len(entity2id)\n", 579 | " rel_num = len(rel2id)\n", 580 | " model = TransE(ent_num, rel_num, device,\n", 581 | " norm=params.norm,\n", 582 | " embed_dim=params.embed_dim,\n", 583 | " margin=params.margin)\n", 584 | " model = model.to(device)\n", 585 | " run_training(model, train_conf, params, device, dataset_conf, entity2id, rel2id)\n", 586 | " \n", 587 | " # Testing the best checkpoint on test dataset\n", 588 | " ckpt = load_checkpoint(train_conf)\n", 589 | " model.load_state_dict(ckpt.model_state_dict)\n", 590 | " model = model.to(device)\n", 591 | " test_dataset = KRLDataset(dataset_conf, 'test', entity2id, rel2id)\n", 592 | " test_dataloder = DataLoader(test_dataset, params.valid_batch_size)\n", 593 | " hits_at_1, hits_at_3, hits_at_10, mrr = run_testing(model, test_dataloder, ent_num, device)\n", 594 | " \n", 595 | " # write results\n", 596 | " with open(train_conf.metric_result_path, 'w') as f:\n", 597 | " f.write(f'dataset: {dataset_conf.dataset_name}\\n')\n", 598 | " f.write(f'Hits@1: {hits_at_1}\\n')\n", 599 | " f.write(f'Hits@3: {hits_at_3}\\n')\n", 600 | " f.write(f'Hits@10: {hits_at_10}\\n')\n", 601 | " f.write(f'MRR: {mrr}\\n')" 602 | ] 603 | }, 604 | { 605 | "cell_type": "markdown", 606 | "metadata": {}, 607 | "source": [ 608 | "## Begin\n", 609 | "\n", 610 | "Instantiate the configuration and call the main function.\n", 611 | "\n", 612 | "Next, you need the path of the dataset and where to save the checkpoints." 
613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": 12, 618 | "metadata": {}, 619 | "outputs": [], 620 | "source": [ 621 | "fb15k_dataset_conf = DatasetConf(\n", 622 | " dataset_name='FB15K',\n", 623 | " base_dir='/root/yubin/dataset/KRL/master/FB15k/' # TODO: change it!\n", 624 | ")\n", 625 | "\n", 626 | "fb15k_hyper_params = HyperParam(\n", 627 | " valid_freq=5,\n", 628 | " batch_size=128,\n", 629 | " valid_batch_size=64,\n", 630 | " learning_rate=0.001,\n", 631 | " epoch_size=500,\n", 632 | " embed_dim=50,\n", 633 | " norm=1,\n", 634 | " margin=2.0\n", 635 | ")\n", 636 | "\n", 637 | "fb15k_train_conf = TrainConf(\n", 638 | " checkpoint_path='/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transe_fb15k.ckpt', # TODO: change it!\n", 639 | " metric_result_path='/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transe_fb15k_metrics.txt' # TODO: change it!\n", 640 | ")\n", 641 | "\n", 642 | "device = get_device()" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": null, 648 | "metadata": {}, 649 | "outputs": [], 650 | "source": [ 651 | "main(fb15k_dataset_conf, fb15k_hyper_params, fb15k_train_conf, device)" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": 14, 657 | "metadata": {}, 658 | "outputs": [], 659 | "source": [ 660 | "ckpt = torch.load(fb15k_train_conf.checkpoint_path)" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": 15, 666 | "metadata": {}, 667 | "outputs": [ 668 | { 669 | "data": { 670 | "text/plain": [ 671 | "{'model_state_dict': {'ent_embedding.weight': tensor([[ 0.0126, 1.1495, 0.7716, ..., -0.8967, 0.3144, -0.2118],\n", 672 | " [ 0.6201, -1.0493, 0.7837, ..., 0.2150, 0.2489, 0.0075],\n", 673 | " [-0.2746, -0.2592, 0.1011, ..., -1.1602, 0.1287, 0.4214],\n", 674 | " ...,\n", 675 | " [ 0.8894, 0.2856, 0.2504, ..., 0.8878, 0.9393, -0.2801],\n", 676 | " [ 0.8036, 0.7737, 0.1246, ..., 0.2968, -0.0092, -0.2363],\n", 677 | " [-1.0176, -0.0581, -0.5224, ..., -0.3004, -1.3833, -0.6132]],\n", 678 | " device='cuda:0'),\n", 679 | " 'rel_embedding.weight': tensor([[-0.0909, -0.4988, -0.7200, ..., -0.0086, 0.0742, -0.1348],\n", 680 | " [-0.9219, 0.5674, 0.4704, ..., -1.6722, -1.1977, -0.0505],\n", 681 | " [-0.6558, -0.0297, -0.1181, ..., -0.2469, -0.2535, -0.6061],\n", 682 | " ...,\n", 683 | " [ 0.3357, 0.2950, -0.2752, ..., 0.5626, 0.3279, -0.5322],\n", 684 | " [-0.3469, -0.0041, -0.7309, ..., 0.1166, 0.0848, 0.3135],\n", 685 | " [ 0.0369, -0.0189, -0.0132, ..., 0.1579, -0.0358, -0.0636]],\n", 686 | " device='cuda:0')},\n", 687 | " 'optim_state_dict': {'state': {0: {'step': tensor(1887500.),\n", 688 | " 'exp_avg': tensor([[-5.6052e-45, 5.6052e-45, -5.6052e-45, ..., -5.6052e-45,\n", 689 | " -5.6052e-45, -5.6052e-45],\n", 690 | " [ 5.6052e-45, -5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n", 691 | " -5.6052e-45, -5.6052e-45],\n", 692 | " [ 5.6052e-45, 5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n", 693 | " -5.6052e-45, 5.6052e-45],\n", 694 | " ...,\n", 695 | " [ 5.6052e-45, 5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n", 696 | " 5.6052e-45, 5.6052e-45],\n", 697 | " [-5.6052e-45, 5.6052e-45, -5.6052e-45, ..., 5.6052e-45,\n", 698 | " 5.6052e-45, 5.6052e-45],\n", 699 | " [-3.9882e-41, 3.9814e-41, 3.9882e-41, ..., -3.9882e-41,\n", 700 | " -3.9882e-41, 3.9882e-41]], device='cuda:0'),\n", 701 | " 'exp_avg_sq': tensor([[5.3320e-10, 5.3320e-10, 5.3320e-10, ..., 5.3320e-10, 5.3320e-10,\n", 702 | " 5.3320e-10],\n", 703 | " [1.1443e-09, 1.1443e-09, 1.1443e-09, ..., 1.1443e-09, 1.8754e-08,\n", 
704 | " 1.8754e-08],\n", 705 | " [3.3255e-08, 3.3255e-08, 2.2694e-08, ..., 3.3255e-08, 3.3255e-08,\n", 706 | " 2.2694e-08],\n", 707 | " ...,\n", 708 | " [3.6491e-13, 3.6491e-13, 3.6491e-13, ..., 3.6491e-13, 3.6491e-13,\n", 709 | " 3.6491e-13],\n", 710 | " [2.1251e-13, 2.1251e-13, 2.1251e-13, ..., 2.1251e-13, 2.1251e-13,\n", 711 | " 2.1251e-13],\n", 712 | " [5.5608e-08, 5.5608e-08, 5.5608e-08, ..., 5.5608e-08, 5.5608e-08,\n", 713 | " 5.5608e-08]], device='cuda:0')},\n", 714 | " 1: {'step': tensor(1887500.),\n", 715 | " 'exp_avg': tensor([[-5.6052e-45, -5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n", 716 | " 5.6052e-45, 5.6052e-45],\n", 717 | " [ 5.6052e-45, -5.6052e-45, 5.6052e-45, ..., 5.6052e-45,\n", 718 | " 5.6052e-45, 5.6052e-45],\n", 719 | " [ 5.6052e-45, 5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n", 720 | " 5.6052e-45, -5.6052e-45],\n", 721 | " ...,\n", 722 | " [ 5.6052e-45, 5.6052e-45, 5.6052e-45, ..., 5.6052e-45,\n", 723 | " 5.6052e-45, 5.6052e-45],\n", 724 | " [ 5.6052e-45, 5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n", 725 | " -5.6052e-45, 5.6052e-45],\n", 726 | " [-5.6052e-45, 5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n", 727 | " -5.6052e-45, 5.6052e-45]], device='cuda:0'),\n", 728 | " 'exp_avg_sq': tensor([[7.0065e-43, 3.1465e-32, 3.1465e-32, ..., 4.8395e-37, 7.0065e-43,\n", 729 | " 7.0065e-43],\n", 730 | " [2.6102e-08, 2.8141e-08, 1.2741e-07, ..., 2.5118e-10, 2.0500e-08,\n", 731 | " 1.0558e-07],\n", 732 | " [2.6078e-24, 7.9196e-29, 7.0065e-43, ..., 2.6077e-24, 7.9196e-29,\n", 733 | " 7.0065e-43],\n", 734 | " ...,\n", 735 | " [6.1684e-23, 1.7540e-34, 7.0065e-43, ..., 7.0065e-43, 1.7540e-34,\n", 736 | " 1.7540e-34],\n", 737 | " [7.0065e-43, 4.4173e-41, 4.4173e-41, ..., 7.0065e-43, 7.0065e-43,\n", 738 | " 7.0065e-43],\n", 739 | " [7.0065e-43, 7.0065e-43, 7.0065e-43, ..., 7.0065e-43, 7.0065e-43,\n", 740 | " 7.0065e-43]], device='cuda:0')}},\n", 741 | " 'param_groups': [{'lr': 0.001,\n", 742 | " 'betas': (0.9, 0.999),\n", 743 | " 'eps': 1e-08,\n", 744 | " 'weight_decay': 0,\n", 745 | " 'amsgrad': False,\n", 746 | " 'maximize': False,\n", 747 | " 'foreach': None,\n", 748 | " 'capturable': False,\n", 749 | " 'differentiable': False,\n", 750 | " 'fused': False,\n", 751 | " 'params': [0, 1]}]},\n", 752 | " 'epoch_id': 500,\n", 753 | " 'best_score': 39.291}" 754 | ] 755 | }, 756 | "execution_count": 15, 757 | "metadata": {}, 758 | "output_type": "execute_result" 759 | } 760 | ], 761 | "source": [ 762 | "ckpt" 763 | ] 764 | } 765 | ], 766 | "metadata": { 767 | "kernelspec": { 768 | "display_name": "Python 3 (ipykernel)", 769 | "language": "python", 770 | "name": "python3" 771 | }, 772 | "language_info": { 773 | "codemirror_mode": { 774 | "name": "ipython", 775 | "version": 3 776 | }, 777 | "file_extension": ".py", 778 | "mimetype": "text/x-python", 779 | "name": "python", 780 | "nbconvert_exporter": "python", 781 | "pygments_lexer": "ipython3", 782 | "version": "3.9.15" 783 | }, 784 | "vscode": { 785 | "interpreter": { 786 | "hash": "0418effca45178467ac68c18e34d93809a092be692e0a4443d8690099b71f4bc" 787 | } 788 | } 789 | }, 790 | "nbformat": 4, 791 | "nbformat_minor": 4 792 | } 793 | -------------------------------------------------------------------------------- /typer_app.py: -------------------------------------------------------------------------------- 1 | import typer 2 | 3 | from krl.typer_apps.rescal import app as rescal_app 4 | from krl.typer_apps.transe import app as transe_app 5 | from krl.typer_apps.transh import app as transh_app 6 | from krl.typer_apps.distmult import app as 
distmult_app 7 | from krl.typer_apps.transr import app as transr_app 8 | 9 | 10 | 11 | app = typer.Typer() 12 | 13 | app.add_typer(rescal_app, name='RESCAL') 14 | app.add_typer(transe_app, name='TransE') 15 | app.add_typer(transh_app, name='TransH') 16 | app.add_typer(distmult_app, name='DistMult') 17 | app.add_typer(transr_app, name='TransR') 18 | 19 | 20 | if __name__ == '__main__': 21 | app() 22 | --------------------------------------------------------------------------------
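As a usage illustration for the aggregated CLI above, here is a hypothetical invocation through typer's test runner (a sketch, not part of the repository; the dataset and output paths are placeholders, and the flag names follow typer's default kebab-case rendering of the options defined in krl/typer_apps):

from typer.testing import CliRunner
from typer_app import app

runner = CliRunner()
result = runner.invoke(app, [
    'TransH', 'train',
    '--dataset-name', 'FB15k',
    '--base-dir', '/path/to/FB15k',                        # placeholder
    '--ckpt-path', './transh_fb15k.ckpt',                  # placeholder
    '--metric-result-path', './transh_fb15k_metrics.txt',  # placeholder
])
print(result.exit_code)

The equivalent shell command would be python typer_app.py TransH train --dataset-name FB15k --base-dir ... --ckpt-path ... --metric-result-path ... (again with real paths substituted in).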