├── .gitignore
├── LICENSE
├── README.md
├── examples
├── DistMult
│ └── DistMult-FB15k.sh
├── RESCAL
│ └── RESCAL-FB15k.sh
├── TransE
│ └── TransE-FB15k.sh
├── TransH
│ └── TransH-FB15k.sh
└── TransR
│ └── TransR-FB15k.sh
├── images
└── structure.png
├── krl
├── __init__.py
├── base_model.py
├── config.py
├── dataset
│ ├── __init__.py
│ ├── instance
│ │ ├── __init__.py
│ │ ├── huggingface_krl_datasets_conf.json
│ │ ├── index.py
│ │ └── utils.py
│ └── krl_dataset.py
├── evaluator.py
├── lit_model
│ ├── TransXLitModel.py
│ └── __init__.py
├── metric.py
├── metric_fomatter.py
├── models
│ ├── DistMult.py
│ ├── RESCAL.py
│ ├── TransD.py
│ ├── TransE.py
│ ├── TransH.py
│ ├── TransR.py
│ └── __init__.py
├── negative_sampler.py
├── serializer.py
├── storage.py
├── trainer.py
├── typer_apps
│ ├── __init__.py
│ ├── distmult.py
│ ├── rescal.py
│ ├── transe.py
│ ├── transh.py
│ └── transr.py
└── utils
│ ├── __init__.py
│ ├── data.py
│ ├── device.py
│ ├── logs_dir.py
│ ├── optim.py
│ └── seed.py
├── requirements.txt
├── test.ipynb
├── transe.ipynb
└── typer_app.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 |
93 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
94 | __pypackages__/
95 |
96 | # Celery stuff
97 | celerybeat-schedule
98 | celerybeat.pid
99 |
100 | # SageMath parsed files
101 | *.sage.py
102 |
103 | # Environments
104 | .env
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
130 | # Run my scripts
131 | my-scripts/
132 | nohup.out
133 |
134 | # pytorch lightning
135 | lightning_logs/
136 | /test.ipynb
137 |
138 | # model checkpoint
139 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 yubin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # KRL
2 |
3 | This is a framework that uses PyTorch to reproduce Knowledge Representation Learning (KRL) models. With it, you can easily run models of knowledge representation learning, while quickly implementing your ideas after making simple changes or extensions.
4 |
5 | Currently, several knowledge representation learning models are implemented, including **TransE**, **TransH**, **TransR**, **DistMult**, and **RESCAL**, together with a simple abstraction and decoupling of the reusable code.
6 |
7 | This repository is still under active development and is not ready for use yet; expect breaking changes in the future.
8 |
9 | The overall structure of this library is shown below:
10 |
11 | ![Overall structure of KRL](./images/structure.png)
12 |
13 |
14 |
15 |
16 | ## How to use it?
17 |
18 | Not recommended for use yet: the documentation is out of date and will be updated soon.
19 |
20 | If you just want to use the models directly, **run the example scripts in the `./examples` directory**, changing only the parameters you need. The training code is wrapped with [Typer](https://typer.tiangolo.com/), a great tool for building CLIs, so it can be invoked from the command line. The examples can be found in the `./examples` directory.
21 |
22 | Example:
23 |
24 | ```shell
25 | cd ./examples/TransE
26 | sh TransE-FB15k.sh
27 | ```
28 |
29 | If you are trying to make some changes or innovations, then you may need to briefly understand the logic of how this library works. Fortunately, the code is cleanly wrapped and decoupled, which makes it easy for you to understand the logic of the entire program and make changes to parts of it.
30 |
31 | First, run through the `transe.ipynb` notebook in the project root directory to learn the core logic of the library; successfully running this notebook is the first step to using the project.
32 | This notebook replicates the operation of the **TransE** model.
33 |
34 | In the library itself, a model runs somewhat differently from the notebook: to reuse module code, parts of the program are further abstracted and encapsulated. The core idea of how the model runs, however, is unchanged.
35 |
36 |
37 | + Notice: before running an example script, download the dataset (e.g. FB15k) and edit the script to set the paths of the dataset and the checkpoints.
38 |
39 | The example `./transe.ipynb` is a good tutorial for reproducing the TransE if you want to know the structure of this repo. This tutorial can be run without any dependencies, except for the use of common third-party libraries like PyTorch and Numpy.
40 |
41 | ## Dataset
42 |
43 | In this library, the different types of datasets are uniformly encapsulated in the `KRLDataset` class. However, you need to download the dataset yourself and specify its path on disk when creating the configuration instance.
44 |
45 | You can download the dataset here: [KGDataset](https://github.com/ZhenfengLei/KGDatasets).
46 |
47 |
48 | ## Plan
49 |
50 | | Status | Model | Year | Paper | Remarks |
51 | | :----: | :----: | :----: | :--- | --- |
52 | | :heavy_check_mark: | [RESCAL](/krl/models/RESCAL.py) | 2011 | ICML'11, [OpenReview](https://openreview.net/forum?id=H14QEiZ_WS) | |
53 | | :heavy_check_mark: | [TransE](/krl/models/TransE.py) | 2013 | NIPS'13, [ACM](http://dl.acm.org/doi/10.5555/2999792.2999923) | |
54 | | :heavy_check_mark: | [TransH](/krl/models/TransH.py) | 2014 | AAAI'14, [ResearchGate](https://www.researchgate.net/publication/319207032_Knowledge_Graph_Embedding_by_Translating_on_Hyperplanes) | |
55 | | :heavy_check_mark: | [DistMult](/krl/models/DistMult.py) | 2014 | ICLR'15, [arXiv](http://arxiv.org/abs/1412.6575) | |
56 | | :heavy_check_mark: | [TransR](/krl/models/TransR.py) | 2015 | AAAI'15, [AAAI](https://ojs.aaai.org/index.php/AAAI/article/view/9491) | Performance is lower than expected; the cause is still unknown. |
57 | | :white_circle: | TransD | 2015 | ACL-IJCNLP 2015, [Aclanthology](https://aclanthology.org/P15-1067) | |
58 | | :white_circle: | TransA | 2015 | [arXiv](https://arxiv.org/abs/1509.05490) | |
59 | | :white_circle: | TransG | 2015 | [arXiv](https://arxiv.org/abs/1509.05488) | |
60 | | :white_circle: | KG2E | 2015 | CIKM'15, [ACM](https://dl.acm.org/doi/10.1145/2806416.2806502) | |
61 | | :white_circle: | TranSparse | 2016 | AAAI'16, [AAAI](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/11982) | |
62 | | :white_circle: | TransF | 2016 | AAAI'16, [AAAI](https://www.aaai.org/ocs/index.php/KR/KR16/paper/view/12887) | |
63 | | :white_circle: | ComplEx | 2016 | ICML'16, [arXiv](http://arxiv.org/abs/1606.06357) | |
64 | | :white_circle: | HolE | 2016 | AAAI'16, [arXiv](http://arxiv.org/abs/1510.04935) | |
65 | | :white_circle: | R-GCN | 2017 | ESWC'18, [arXiv](http://arxiv.org/abs/1703.06103) | |
66 | | :white_circle: | ConvKB | 2018 | NAACL-HLT 2018, [arXiv](http://arxiv.org/abs/1712.02121) | |
67 | | :white_circle: | ConvE | 2018 | AAAI'18, [arXiv](http://arxiv.org/abs/1707.01476) | |
68 | | :white_circle: | SimplE | 2018 | NIPS'18, [arXiv](http://arxiv.org/abs/1802.04868) | |
69 | | :white_circle: | RotatE | 2019 | ICLR'19, [arXiv](http://arxiv.org/abs/1902.10197) | |
70 | | :white_circle: | QuatE | 2019 | NeurIPS'19, [arXiv](http://arxiv.org/abs/1904.10281) | |
71 | | :white_circle: | ConvR | 2019 | NAACL-HLT 2019, [Aclanthology](https://aclanthology.org/N19-1103) | |
72 | | :white_circle: | KG-BERT | 2019 | [arXiv](http://arxiv.org/abs/1909.03193) | |
73 | | :white_circle: | PairRE | 2021 | ACL-IJCNLP 2021, [Aclanthology](https://aclanthology.org/2021.acl-long.336) | |
74 |
75 |
76 | ## How to contribute?
77 |
78 | Once you have read the code of this project, you should be familiar with its code style. Extend it as you see fit, and submit a pull request once your changes pass testing.
79 |
--------------------------------------------------------------------------------
/examples/DistMult/DistMult-FB15k.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train DistMult on FB15k.
#
# typer_app.py lives in the repository root, so the script changes into that
# directory first. The root is resolved from this script's own location, so the
# script works regardless of the caller's working directory.
# Override DATASET_DIR / OUTPUT_DIR in the environment to use your own paths.

cd "$(dirname "$0")/../.." || exit 1  # repository root

DATASET_DIR="${DATASET_DIR:-/root/yubin/dataset/KRL/master/FB15k}"
OUTPUT_DIR="${OUTPUT_DIR:-/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp}"

python typer_app.py DistMult train --dataset-name "FB15k" \
    --base-dir "$DATASET_DIR" \
    --batch-size 128 \
    --valid-batch-size 32 \
    --valid-freq 3 \
    --lr 0.001 \
    --epoch-size 500 \
    --embed-dim 50 \
    --alpha 0.001 \
    --regul-type F2 \
    --ckpt-path "$OUTPUT_DIR/distmult_fb15k.ckpt" \
    --metric-result-path "$OUTPUT_DIR/distmult_fb15k_metrics.txt"
--------------------------------------------------------------------------------
/examples/RESCAL/RESCAL-FB15k.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train RESCAL on FB15k.
#
# typer_app.py lives in the repository root, so the script changes into that
# directory first. The root is resolved from this script's own location, so the
# script works regardless of the caller's working directory.
# Override DATASET_DIR / OUTPUT_DIR in the environment to use your own paths.

cd "$(dirname "$0")/../.." || exit 1  # repository root

DATASET_DIR="${DATASET_DIR:-/root/yubin/dataset/KRL/master/FB15k}"
OUTPUT_DIR="${OUTPUT_DIR:-/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp}"

python typer_app.py RESCAL train --dataset-name "FB15k" \
    --base-dir "$DATASET_DIR" \
    --batch-size 128 \
    --valid-batch-size 32 \
    --valid-freq 3 \
    --lr 0.001 \
    --epoch-size 500 \
    --embed-dim 50 \
    --alpha 0.001 \
    --regul-type F2 \
    --ckpt-path "$OUTPUT_DIR/rescal_fb15k.ckpt" \
    --metric-result-path "$OUTPUT_DIR/rescal_fb15k_metrics.txt"
--------------------------------------------------------------------------------
/examples/TransE/TransE-FB15k.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train TransE on FB15k.
#
# typer_app.py lives in the repository root, so the script changes into that
# directory first. The root is resolved from this script's own location, so the
# script works regardless of the caller's working directory.
# Override DATASET_DIR / OUTPUT_DIR in the environment to use your own paths.

cd "$(dirname "$0")/../.." || exit 1  # repository root

DATASET_DIR="${DATASET_DIR:-/root/yubin/dataset/KRL/master/FB15k}"
OUTPUT_DIR="${OUTPUT_DIR:-/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp}"

python typer_app.py TransE train --dataset-name "FB15k" \
    --base-dir "$DATASET_DIR" \
    --batch-size 128 \
    --valid-batch-size 64 \
    --valid-freq 5 \
    --lr 0.001 \
    --epoch-size 500 \
    --embed-dim 50 \
    --norm 1 \
    --margin 2.0 \
    --ckpt-path "$OUTPUT_DIR/transe_fb15k.ckpt" \
    --metric-result-path "$OUTPUT_DIR/transe_fb15k_metrics.txt"
--------------------------------------------------------------------------------
/examples/TransH/TransH-FB15k.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train TransH on FB15k.
#
# typer_app.py lives in the repository root, so the script changes into that
# directory first. The root is resolved from this script's own location, so the
# script works regardless of the caller's working directory.
# Override DATASET_DIR / OUTPUT_DIR in the environment to use your own paths.

cd "$(dirname "$0")/../.." || exit 1  # repository root

DATASET_DIR="${DATASET_DIR:-/root/yubin/dataset/KRL/master/FB15k}"
OUTPUT_DIR="${OUTPUT_DIR:-/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp}"

python typer_app.py TransH train --dataset-name "FB15k" \
    --base-dir "$DATASET_DIR" \
    --batch-size 4800 \
    --valid-batch-size 64 \
    --valid-freq 5 \
    --lr 0.001 \
    --epoch-size 10 \
    --embed-dim 100 \
    --norm 1 \
    --margin 2.0 \
    --c 1.0 \
    --ckpt-path "$OUTPUT_DIR/transh_fb15k.ckpt" \
    --metric-result-path "$OUTPUT_DIR/transh_fb15k_metrics.txt"
--------------------------------------------------------------------------------
/examples/TransR/TransR-FB15k.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train TransR on FB15k.
#
# typer_app.py lives in the repository root, so the script changes into that
# directory first. The root is resolved from this script's own location, so the
# script works regardless of the caller's working directory.
# Override DATASET_DIR / OUTPUT_DIR in the environment to use your own paths.

cd "$(dirname "$0")/../.." || exit 1  # repository root

DATASET_DIR="${DATASET_DIR:-/root/yubin/dataset/KRL/master/FB15k}"
OUTPUT_DIR="${OUTPUT_DIR:-/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp}"

python typer_app.py TransR train --dataset-name "FB15k" \
    --base-dir "$DATASET_DIR" \
    --batch-size 4800 \
    --valid-batch-size 32 \
    --valid-freq 25 \
    --lr 0.001 \
    --epoch-size 400 \
    --embed-dim 50 \
    --norm 1 \
    --margin 1.0 \
    --c 0.1 \
    --ckpt-path "$OUTPUT_DIR/transr_fb15k.ckpt" \
    --metric-result-path "$OUTPUT_DIR/transr_fb15k_metrics.txt"
--------------------------------------------------------------------------------
/images/structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yubinCloud/kg2vec/c1e291031530933e5a381c2bdbf1f17e655e041e/images/structure.png
--------------------------------------------------------------------------------
/krl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yubinCloud/kg2vec/c1e291031530933e5a381c2bdbf1f17e655e041e/krl/__init__.py
--------------------------------------------------------------------------------
/krl/base_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Base class for various KRL models.
3 | """
4 |
5 | import torch.nn as nn
6 | import torch
7 | from abc import ABC, abstractmethod
8 |
9 | from .dataset import KRLDatasetDict
10 | from .config import TrainConf, HyperParam
11 |
12 |
class KRLModel(nn.Module, ABC):
    """Abstract base class of every KRL model.

    Subclasses must implement `embed`, `loss`, `forward` and `predict`;
    the class cannot be instantiated until all four are overridden.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def embed(self, triples):
        """Get the embeddings of the triples.

        :param triples: a batch of triples, which consists of (heads, rels, tails) id.
        :return: the embedding of (heads, rels, tails)
        """
        pass

    @abstractmethod
    def loss(self):
        """Return model losses based on the input.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run the model's forward pass and return the model losses based on the input.

        The concrete inputs are defined by each subclass.
        """
        pass

    @abstractmethod
    def predict(self, triples: torch.Tensor):
        """Calculate the dissimilarity score for the given triples.

        :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: dissimilarity score for the given triples
        """
        pass
46 |
47 |
class TransXBaseModel(KRLModel):
    """Shared ancestor of the translation-based (TransX) models.

    It currently adds nothing on top of `KRLModel`.
    """
50 |
51 |
class ModelMain(ABC):
    """The "main function" of a KRL model.

    Concrete subclasses put the whole run into `__call__`.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self):
        """Execute the run; must be provided by subclasses."""
61 |
62 |
class LitModelMain(ModelMain):
    """
    Base "main" class that stores the datasets and training configuration
    shared by model runs. (The `Lit` prefix suggests PyTorch Lightning — TODO confirm.)

    :param dataset: the train/valid/test splits together with their meta-data and config.
    :param train_conf: training configuration.
    :param seed: random seed; stored as-is on the instance.
    """

    def __init__(
        self,
        dataset: KRLDatasetDict,
        train_conf: TrainConf,
        seed: int = None,
    ) -> None:
        super().__init__()
        self.seed = seed
        self.train_conf = train_conf
        self.datasets = dataset
        # Shortcut to the configuration the datasets were built from.
        self.dataset_conf = dataset.dataset_conf
        # Hyper-parameters are not set here (presumably assigned by subclasses).
        self.params = None

    @abstractmethod
    def __call__(self):
        """Run the model; implemented by subclasses."""
--------------------------------------------------------------------------------
/krl/config.py:
--------------------------------------------------------------------------------
1 | """
2 | 用于填写配置信息的 Settings
3 | """
4 |
5 | from pydantic import BaseModel, Field
6 | from abc import ABC
7 | from pathlib import Path
8 | from typing import Optional, Mapping
9 |
10 |
11 | ######## Dataset config ########
12 |
13 |
class DatasetConf(BaseModel, ABC):
    """Abstract base of every dataset configuration."""

    # Name of the dataset, handy when printing / logging.
    dataset_name: str = Field(title='数据集的名称,方便打印时查看')
16 |
17 |
class BuiletinDatasetConf(DatasetConf):
    """Configuration of a dataset that ships with the library (built-in)."""

    # Where the built-in dataset comes from.
    source: str = Field(title='数据集的来源')
23 |
24 |
class BuiletinHuggingfaceDatasetConf(BuiletinDatasetConf):
    """Configuration of a built-in dataset hosted on Hugging Face Datasets."""

    # The Hugging Face repo that stores this dataset.
    huggingface_repo: str = Field(title='Hugging Face 中存放该数据集的 repo')
31 |
32 |
class LocalDatasetConf(DatasetConf):
    """
    Configuration of a dataset stored on the local disk.

    All `*_path` fields are file names relative to `base_dir`.
    """
    # Directory that contains the dataset files.
    base_dir: Optional[Path] = Field(title='数据集的目录')
    # File name of the entity -> id mapping.
    entity2id_path: Optional[str] = Field(default='entity2id.txt', title='entity2id 的文件名')
    # File name of the relation -> id mapping.
    relation2id_path: Optional[str] = Field(default='relation2id.txt', title='relation2id 的文件名')
    # File names of the three splits.
    train_path: Optional[str] = Field(default='train.txt', title='training set 的文件')
    valid_path: Optional[str] = Field(default='valid.txt', title='valid set 的文件')
    test_path: Optional[str] = Field(default='test.txt', title='testing set 的文件')
43 |
44 |
45 | ######## Dataset meta-data ########
46 |
class KRLDatasetMeta(BaseModel):
    """Name-to-id lookup tables shared by all splits of a dataset."""

    # entity name -> entity id
    entity2id: Mapping[str, int]
    # relation name -> relation id
    rel2id: Mapping[str, int]
50 |
51 |
52 | ######## Hyper-parameters config ########
53 |
54 |
class HyperParam(BaseModel, ABC):
    """
    Hyper-parameters common to all models; every model-specific
    hyper-parameter config class should inherit from it.
    """
    batch_size: int = 128
    valid_batch_size: int = 64
    learning_rate: float = 0.001
    # `default=` must be spelled correctly: pydantic does not treat an unknown
    # keyword (e.g. `defualt=`) as a default, so the field would become required.
    optimizer: str = Field(default='adam', title='optimizer name')
    epoch_size: int = 500
    # How often (in training rounds, presumably epochs) to run validation to
    # decide whether to save the model.
    valid_freq: int = Field(default=5, title='训练过程中,每隔多少次就做一次 valid 来验证是否保存模型')
    num_works: int = Field(default=64, title='`num_works` in dataloader')
    early_stoping_patience: int = Field(default=5, title='the patience of EarlyStoping')
67 |
68 |
69 |
70 | ######## Training config ########
71 |
72 |
class TrainConf(BaseModel):
    """Training-time settings that are not hyper-parameters."""

    # Where training logs are written.
    logs_dir: Path = Field(title='The directory used to keep the log.')
78 |
79 |
80 | ######## Other HyperParam class for various models ##########
81 |
class TransHyperParam(HyperParam):
    """
    Hyper-parameters shared by the Trans-series models.
    """
    embed_dim: int = 50
    norm: int = 1
    # The margin is a real number (the example scripts pass `--margin 2.0` and
    # `--margin 1.0`). Annotating it as `int` made pydantic truncate or reject
    # fractional margins such as 0.5.
    margin: float = 2.0
89 |
--------------------------------------------------------------------------------
/krl/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .krl_dataset import KRLDataset, LocalKRLDataset, create_mapping, BuiletinHuggingfaceDataset, KRLDatasetDict
2 | from .instance import load_krl_dataset
--------------------------------------------------------------------------------
/krl/dataset/instance/__init__.py:
--------------------------------------------------------------------------------
1 | from .index import load_krl_dataset
2 |
--------------------------------------------------------------------------------
/krl/dataset/instance/huggingface_krl_datasets_conf.json:
--------------------------------------------------------------------------------
1 | {
2 | "fb15k": {
3 | "name": "FB15k",
4 | "repo": "VLyb/FB15k"
5 | },
6 | "wn18": {
7 | "name": "WN18",
8 | "repo": "VLyb/WN18"
9 | },
10 | "fb15k-237": {
11 | "name": "FB15k-237",
12 | "repo": "VLyb/FB15k-237"
13 | },
14 | "wn18rr": {
15 | "name": "WN18RR",
16 | "repo": "VLyb/WN18RR"
17 | },
18 | "yago3-10": {
19 | "name": "YAGO3-10",
20 | "repo": "VLyb/YAGO3-10"
21 | },
22 | "nations": {
23 | "name": "Nations",
24 | "repo": "VLyb/Nations"
25 | },
26 | "dbpedia50": {
27 | "name": "DBpedia50",
28 | "repo": "VLyb/DBpedia50"
29 | },
30 | "dbpedia500": {
31 | "name": "DBpedia500",
32 | "repo": "VLyb/DBpedia500"
33 | },
34 | "kinship": {
35 | "name": "Kinship",
36 | "repo": "VLyb/Kinship"
37 | },
38 | "umls": {
39 | "name": "UMLS",
40 | "repo": "VLyb/UMLS"
41 | }
42 | }
--------------------------------------------------------------------------------
/krl/dataset/instance/index.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import json
3 | from pathlib import Path
4 |
5 | from .utils import convert_huggingface_dataset
6 | from ..krl_dataset import BuiletinHuggingfaceDataset, KRLDatasetDict
7 | from krl.config import BuiletinHuggingfaceDatasetConf
8 |
9 |
10 |
def __load_huggingface_krl_dataset(dataset_name: str) -> KRLDatasetDict:
    """Load a built-in KRL dataset hosted on Hugging Face.

    The name is looked up case-insensitively in
    `huggingface_krl_datasets_conf.json` (stored next to this module). The
    matching repo is then loaded with `datasets.load_dataset`, and its records
    are converted into id triples.

    :param dataset_name: dataset key, e.g. ``"FB15k"`` or ``"wn18rr"``.
    :return: the train/valid/test datasets plus shared meta-data and config.
    :raises NotImplementedError: if the dataset is not listed in the JSON config.
    """
    json_path = Path(__file__).parent / 'huggingface_krl_datasets_conf.json'
    with json_path.open(encoding='utf-8') as f:
        name_to_confs = json.load(f)
    conf_dict = name_to_confs.get(dataset_name.lower())
    if conf_dict is None:
        # `NotImplemented` is a constant, not an exception class; calling it
        # raised a confusing TypeError instead of this message.
        raise NotImplementedError(f"dataset {dataset_name} hasn't been implemented.")
    conf = BuiletinHuggingfaceDatasetConf(
        dataset_name=conf_dict['name'],
        source='HuggingFace',
        huggingface_repo=conf_dict['repo']
    )
    dataset_dict = load_dataset(conf.huggingface_repo)
    train_triples, valid_triples, test_triples, dataset_meta = convert_huggingface_dataset(dataset_dict)
    return KRLDatasetDict(
        train=BuiletinHuggingfaceDataset(conf, 'train', dataset_meta, train_triples),
        valid=BuiletinHuggingfaceDataset(conf, 'valid', dataset_meta, valid_triples),
        test=BuiletinHuggingfaceDataset(conf, 'test', dataset_meta, test_triples),
        meta=dataset_meta,
        dataset_conf=conf
    )
32 |
33 |
def load_krl_dataset(dataset_name: str):
    """Load a built-in KRL dataset by name.

    Every built-in dataset currently comes from Hugging Face, so this
    delegates to the Hugging Face loader.
    """
    loader = __load_huggingface_krl_dataset
    return loader(dataset_name)
36 |
--------------------------------------------------------------------------------
/krl/dataset/instance/utils.py:
--------------------------------------------------------------------------------
1 | from datasets.dataset_dict import DatasetDict as HuggingfaceDatasetDict
2 | from datasets.arrow_dataset import Dataset as HuggingFaceDataset
3 | from typing import Tuple, List
4 |
5 | from ..krl_dataset import BuiletinHuggingfaceDataset
6 | from ...config import KRLDatasetMeta
7 |
8 |
9 | Triple = Tuple[int, int, int]
10 | Triples = List[Triple]
11 |
def convert_huggingface_dataset(
    dataset_dict: HuggingfaceDatasetDict
) -> Tuple[Triples, Triples, Triples, KRLDatasetMeta]:
    """Convert a Hugging Face KRL dataset into integer id triples.

    Names are converted with ``str()`` and numbered in order of first
    appearance. The scan covers the train split, then validation, then test;
    within a record, the head entity is numbered before the tail.

    :param dataset_dict: a Hugging Face ``DatasetDict`` with ``train``,
        ``validation`` and ``test`` splits whose records have ``head``,
        ``relation`` and ``tail`` fields.
    :return: ``(train_triples, valid_triples, test_triples, meta)``. Each triple
        is ``(head_id, rel_id, tail_id)``, and ``meta`` holds the
        ``entity2id`` / ``rel2id`` mappings built along the way.
    """
    entity2id = {}
    rel2id = {}

    def id_of(name, table):
        # Hand out the next free id the first time a name is seen.
        return table.setdefault(name, len(table))

    def to_triples(split):
        return [
            (
                id_of(str(record['head']), entity2id),
                id_of(str(record['relation']), rel2id),
                id_of(str(record['tail']), entity2id),
            )
            for record in split
        ]

    train_triples = to_triples(dataset_dict['train'])
    valid_triples = to_triples(dataset_dict['validation'])
    test_triples = to_triples(dataset_dict['test'])

    meta = KRLDatasetMeta(entity2id=entity2id, rel2id=rel2id)
    return train_triples, valid_triples, test_triples, meta
44 |
45 |
--------------------------------------------------------------------------------
/krl/dataset/krl_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | The dataset class used to read KRL data, such as FB15k
3 | """
4 |
5 | from typing import Literal, Tuple, Dict, List, Mapping
6 | from torch.utils.data import Dataset
7 | from abc import ABC
8 | from pydantic import BaseModel
9 |
10 | from ..config import DatasetConf, LocalDatasetConf, BuiletinHuggingfaceDatasetConf, KRLDatasetMeta
11 |
12 |
# Name -> id mappings for entities and relations.
EntityMapping = Dict[str, int]
RelMapping = Dict[str, int]
# One id triple (head_id, rel_id, tail_id). The datasets in this module store
# and return tuples, so the alias is a fixed-length tuple rather than List[int].
Triple = Tuple[int, int, int]
16 |
def _read_id_mapping(path) -> Dict[str, int]:
    """Read a ``name id`` per line file into a ``{name: id}`` dict.

    Fields are whitespace-separated; blank lines are skipped.

    :raises FileNotFoundError: if ``path`` does not exist.
    :raises ValueError: if a non-blank line does not have exactly two fields,
        or its id is not an integer.
    """
    if not path.exists():
        raise FileNotFoundError(f'{path} not found.')
    mapping = {}
    with path.open(encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing empty lines
            name, raw_id = fields
            mapping[name] = int(raw_id)
    return mapping


def create_mapping(dataset_conf: DatasetConf) -> KRLDatasetMeta:
    """
    Create the `entity2id` and `relation2id` mappings of a local dataset.

    :param dataset_conf: a config providing ``base_dir``, ``entity2id_path`` and
        ``relation2id_path`` (i.e. a ``LocalDatasetConf``).
    :return: the entity and relation name-to-id mappings.
    :raises FileNotFoundError: if either mapping file is missing.
    """
    entity2id = _read_id_mapping(dataset_conf.base_dir / dataset_conf.entity2id_path)
    rel2id = _read_id_mapping(dataset_conf.base_dir / dataset_conf.relation2id_path)
    return KRLDatasetMeta(
        entity2id=entity2id,
        rel2id=rel2id
    )
47 |
48 |
class KRLDataset(Dataset, ABC):
    """
    Abstract base class of KRL datasets.

    Subclasses must assign ``self.triples`` *before* calling
    ``KRLDataset.__init__``.

    :param dataset_conf: the dataset configuration; its ``dataset_name`` is
        copied onto the instance.
    :param mode: which split this dataset holds: ``'train'``, ``'valid'`` or ``'test'``.
    :param dataset_meta: the entity / relation name-to-id mappings.
    :raises ValueError: if ``mode`` is not a supported split.
    :raises AttributeError: if the subclass has not set ``self.triples``.
    """
    def __init__(self,
                 dataset_conf: DatasetConf,
                 mode: Literal['train', 'valid', 'test'],
                 dataset_meta: KRLDatasetMeta) -> None:
        super().__init__()
        self.dataset_name = dataset_conf.dataset_name
        self.conf = dataset_conf
        if mode not in {'train', 'valid', 'test'}:
            raise ValueError(f'dataset mode not support:{mode} mode')
        self.mode = mode
        # Subclasses supply the triples; fail loudly (and clearly) if they forgot.
        if not hasattr(self, 'triples'):
            raise AttributeError(
                f'{type(self).__name__} must set `self.triples` before calling KRLDataset.__init__'
            )
        self.meta = dataset_meta
67 |
68 |
class LocalKRLDataset(KRLDataset):
    """
    A KRL dataset whose triples are read from files on the local disk.

    Each non-blank line of the split file holds three whitespace-separated
    fields in the order ``head tail relation``. They are converted to ids with
    ``dataset_meta``.

    :param dataset_conf: local dataset config (``base_dir`` and split file names).
    :param mode: which split to read: ``'train'``, ``'valid'`` or ``'test'``.
    :param dataset_meta: entity / relation name-to-id mappings.
    :raises ValueError: if ``mode`` is not a supported split.
    :raises FileNotFoundError: if the split file does not exist.
    """
    def __init__(self,
                 dataset_conf: LocalDatasetConf,
                 mode: Literal['train', 'valid', 'test'],
                 dataset_meta: KRLDatasetMeta) -> None:
        # Derive from KRLDataset (not only torch's Dataset) so instances pass the
        # `KRLDataset` type check of `KRLDatasetDict`. KRLDataset.__init__
        # validates `mode`, stores conf/mode/meta, and expects `self.triples`
        # to exist already.
        self.triples = []
        super().__init__(dataset_conf, mode, dataset_meta)
        self._read_triples()  # read the split file and fill self.triples

    def _split_and_to_id(self, line: str) -> Triple:
        """Split one line of the dataset file and convert names to ids.

        :param line: one line of the split file (``head tail relation``).
        :return: ``(head_id, rel_id, tail_id)``
        :raises KeyError: if a name is missing from the mappings.
        """
        head, tail, rel = line.split()
        head_id = self.meta.entity2id[head]
        tail_id = self.meta.entity2id[tail]
        rel_id = self.meta.rel2id[rel]
        return (head_id, rel_id, tail_id)

    def _read_triples(self):
        """Read this split's file and store its id triples in ``self.triples``."""
        data_path = {
            'train': self.conf.train_path,
            'valid': self.conf.valid_path,
            'test': self.conf.test_path
        }[self.mode]
        p = self.conf.base_dir / data_path
        if not p.exists():
            raise FileNotFoundError(f'{p} not found.')
        with p.open(encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline), which cannot be
            # unpacked into three fields.
            self.triples = [self._split_and_to_id(line) for line in f if line.strip()]

    def __len__(self):
        """Denotes the total number of samples."""
        return len(self.triples)

    def __getitem__(self, index) -> Triple:
        """Returns (head id, relation id, tail id)."""
        triple = self.triples[index]
        return triple[0], triple[1], triple[2]
118 |
119 |
class BuiletinHuggingfaceDataset(KRLDataset):
    """
    A built-in KRL dataset whose triples were loaded from Hugging Face and
    already converted to ids.

    :param dataset_conf: config of the Hugging Face dataset.
    :param mode: which split the triples belong to.
    :param dataset_mata: entity / relation name-to-id mappings.
    :param triples: the ``(head_id, rel_id, tail_id)`` triples of this split.
    """
    def __init__(
        self,
        dataset_conf: BuiletinHuggingfaceDatasetConf,
        mode: Literal['train', 'valid', 'test'],
        dataset_mata: KRLDatasetMeta,
        triples: List[Tuple[int, int, int]],
    ) -> None:
        # KRLDataset.__init__ reads `self.triples`, so it must be set first.
        self.triples = triples
        super().__init__(dataset_conf, mode, dataset_mata)

    def __len__(self):
        """Number of triples in this split."""
        return len(self.triples)

    def __getitem__(self, index) -> Triple:
        """Return the ``(head_id, rel_id, tail_id)`` triple at ``index``."""
        record = self.triples[index]
        head, rel, tail = record[0], record[1], record[2]
        return head, rel, tail
139 |
140 |
class KRLDatasetDict(BaseModel):
    """The three splits of a KRL dataset plus their shared meta-data and config."""

    # The dataset splits.
    train: KRLDataset
    valid: KRLDataset
    test: KRLDataset

    # Entity / relation name-to-id mappings shared by all splits.
    meta: KRLDatasetMeta

    # The configuration the dataset was created from.
    dataset_conf: DatasetConf

    class Config:
        # KRLDataset is not a pydantic type; let pydantic validate it with a
        # plain isinstance check.
        arbitrary_types_allowed = True
152 |
--------------------------------------------------------------------------------
/krl/evaluator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from abc import ABC, abstractmethod
3 | from typing import List
4 |
5 | from .metric import KRLMetricBase, RankMetric, MetricEnum
6 |
7 |
def cal_hits_at_k(predictions: torch.Tensor,
                  ground_truth_idx: torch.Tensor,
                  device: torch.device,
                  k: int) -> int:
    """Count the hits@k of a batch.

    A row is a hit when its ground-truth class is among the ``k`` classes
    with the *smallest* prediction values.

    :param predictions: BxN tensor of prediction values, where B is the batch
        size and N the number of classes. Predictions must be sorted in class
        id order.
    :param ground_truth_idx: tensor of shape (B, 1) or (B,) holding the index
        of the ground-truth class of each row.
    :param device: unused; kept for backward compatibility with existing callers.
    :param k: number of top results that count as hits; values larger than N
        are clamped to N.
    :return: the number of hits in the batch (a count, not divided by B).
    """
    assert predictions.size(0) == ground_truth_idx.size(0)  # same batch size

    # topk raises if k exceeds the number of classes; every class is "in the top k" then.
    k = min(k, predictions.size(-1))
    _, indices = predictions.topk(k, largest=False)  # indices: [batch_size, k]
    # Comparing a bool mask directly avoids allocating 0/1 helper tensors.
    hits = (indices == ground_truth_idx.reshape(-1, 1)).sum().item()
    return hits
28 |
def cal_mrr(predictions: torch.Tensor, ground_truth_idx: torch.Tensor) -> float:
    """Sum the reciprocal ranks of the ground truth over a batch.

    Ranks are 1-based and come from sorting each row in ascending order
    (smallest value = rank 1). The function returns the *sum*, not the mean.
    Callers accumulate it (e.g. into `RankEvaluator._mrr_sum`) and presumably
    divide by the example count to get MRR.

    :param predictions: BxN tensor of prediction values, where B is the batch
        size and N the number of classes. Predictions must be sorted in class
        id order.
    :param ground_truth_idx: tensor of shape (B, 1) or (B,) holding the index
        of the ground-truth class of each row.
    :return: sum over the batch of 1 / rank of the ground truth.
    """
    assert predictions.size(0) == ground_truth_idx.size(0)

    indices = predictions.argsort()
    # Column of the ground truth in each sorted row == its 0-based rank.
    positions = (indices == ground_truth_idx.reshape(-1, 1)).nonzero()[:, 1]
    return (1.0 / positions.float().add(1.0)).sum().item()
41 |
42 |
43 |
class Evaluator(ABC):
    """
    Base class that every evaluator must derive from.

    :param device: the torch device the evaluator works with.
    """

    def __init__(self, device: torch.device) -> None:
        super().__init__()
        self.device = device

    @abstractmethod
    def evaluate(
        self,
        predictions: torch.Tensor,
        ground_truth_idx: torch.Tensor
    ) -> KRLMetricBase:
        """
        Calculate the metrics of a model's predictions.

        :param predictions: the model's prediction values.
        :param ground_truth_idx: indices of the ground-truth answers.
        :return: the computed metrics.
        """

    @abstractmethod
    def clear(self):
        """Reset the evaluator's accumulated state so it can be reused."""

    @abstractmethod
    def reset_metrics(self, metrics: List[MetricEnum]):
        """
        Choose the metrics that should be calculated.

        :param metrics: the metrics to calculate.
        :type metrics: List[MetricEnum]
        """

    def export_metrics(self) -> KRLMetricBase:
        """
        Export the metric results stored in the evaluator.

        This base implementation returns ``None``; subclasses override it.
        """
        return None
90 |
91 |
class RankEvaluator(Evaluator):
    """
    Accumulates ranking metrics (MRR, Hits@1, Hits@3, Hits@10) over batches.

    Call ``evaluate`` once per batch, then ``export_metrics`` to get the
    averages as percentages. ``clear`` resets the running sums.
    """

    _SUPPORT_METRICS = {
        MetricEnum.MRR,
        MetricEnum.HITS_AT_1,
        MetricEnum.HITS_AT_3,
        MetricEnum.HITS_AT_10
    }

    def __init__(
        self,
        device: torch.device,
        metrics: List[MetricEnum]
    ) -> None:
        """
        :param device: device passed to the Hits@k helper.
        :param metrics: The metrics that you want to calculate.
        :raises NotImplementedError: if a requested metric is not supported.
        """
        super().__init__(device)
        self.example_cnt = 0
        self.metrics = None
        self._mrr_sum = None
        self._hits_at_1_sum = None
        self._hits_at_3_sum = None
        self._hits_at_10_sum = None
        # validates the metrics and zeroes the sums of the enabled ones
        self.reset_metrics(metrics)

    def clear(self):
        """Reset the example counter; enabled metrics restart at 0, disabled ones stay None."""
        self.example_cnt = 0
        self._mrr_sum = None if MetricEnum.MRR not in self.metrics else 0
        self._hits_at_1_sum = None if MetricEnum.HITS_AT_1 not in self.metrics else 0
        self._hits_at_3_sum = None if MetricEnum.HITS_AT_3 not in self.metrics else 0
        self._hits_at_10_sum = None if MetricEnum.HITS_AT_10 not in self.metrics else 0

    def reset_metrics(self, metrics: List[MetricEnum]):
        """
        Replace the set of metrics to calculate and clear all accumulated state.

        Clearing is required: a newly enabled metric would otherwise keep a
        ``None`` sum (crashing ``evaluate``), and any sum started now would not
        match an ``example_cnt`` accumulated before it was enabled.

        :param metrics: the metrics to calculate.
        :raises NotImplementedError: if a metric is not supported; state is left unchanged.
        """
        for m in metrics:
            if m not in RankEvaluator._SUPPORT_METRICS:
                raise NotImplementedError(f"Evaluator don't support metric: {m.value}")
        self.metrics = set(metrics)
        self.clear()

    def evaluate(
        self,
        predictions: torch.Tensor,
        ground_truth_idx: torch.Tensor
    ):
        """
        Add one batch to the running sums of every enabled metric.

        :param predictions: BxN prediction scores (lower ranks first).
        :param ground_truth_idx: Bx1 indices of the ground-truth classes.
        """
        self.example_cnt += predictions.size(0)
        if MetricEnum.MRR in self.metrics:
            self._mrr_sum += cal_mrr(predictions, ground_truth_idx)
        if MetricEnum.HITS_AT_1 in self.metrics:
            self._hits_at_1_sum += cal_hits_at_k(predictions, ground_truth_idx, self.device, 1)
        if MetricEnum.HITS_AT_3 in self.metrics:
            self._hits_at_3_sum += cal_hits_at_k(predictions, ground_truth_idx, self.device, 3)
        if MetricEnum.HITS_AT_10 in self.metrics:
            self._hits_at_10_sum += cal_hits_at_k(predictions, ground_truth_idx, self.device, 10)

    def export_metrics(self) -> RankMetric:
        """
        Export the metric results stored in the evaluator as percentages.

        Metrics that are not enabled are reported as ``None``.

        :raises ZeroDivisionError: if no example has been evaluated yet.
        """
        result = RankMetric(
            mrr=None if MetricEnum.MRR not in self.metrics else self._percentage(self._mrr_sum),
            hits_at_1=None if MetricEnum.HITS_AT_1 not in self.metrics else self._percentage(self._hits_at_1_sum),
            hits_at_3=None if MetricEnum.HITS_AT_3 not in self.metrics else self._percentage(self._hits_at_3_sum),
            hits_at_10=None if MetricEnum.HITS_AT_10 not in self.metrics else self._percentage(self._hits_at_10_sum)
        )
        return result

    def _percentage(self, total):
        """Average ``total`` over the evaluated examples, scaled to a percentage."""
        return total / self.example_cnt * 100
--------------------------------------------------------------------------------
/krl/lit_model/TransXLitModel.py:
--------------------------------------------------------------------------------
1 | import lightning.pytorch as pl
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from typing import List, Any
5 |
6 | from ..base_model import TransXBaseModel
7 | from ..config import HyperParam
8 | from ..negative_sampler import NegativeSampler
9 | from ..metric import HitsAtK, MRR
10 | from ..dataset import KRLDatasetDict
11 | from .. import utils
12 |
13 |
class TransXLitModel(pl.LightningModule):
    """
    LightningModule that trains and evaluates a translation-based (TransX) KRL model.

    Validation reports Hits@10 (logged as ``val_hits@10``). Testing reports
    Hits@1, Hits@10 and MRR. For evaluation, every entity is scored both as a
    candidate tail and as a candidate head of each triple.
    """

    def __init__(
        self,
        model: TransXBaseModel,
        dataset_dict: KRLDatasetDict,
        train_neg_sampler: NegativeSampler,
        hyper_params: HyperParam
    ) -> None:
        super().__init__()
        self.model = model
        self.model_device = next(model.parameters()).device
        self.dataset_dict = dataset_dict
        self.train_neg_sampler = train_neg_sampler
        self.params = hyper_params
        # Metric accumulators; each is reset at the end of its validation/test epoch.
        self.val_hits10 = HitsAtK(10)
        self.test_hits1 = HitsAtK(1)
        self.test_hits10 = HitsAtK(10)
        self.test_mrr = MRR()

    def training_step(self, batch: List[torch.Tensor], batch_idx: int):
        """
        Compute the training loss of one batch.

        :param batch: [heads, rels, tails]; each is a tensor of size [batch_size]
        :param batch_idx: index of this batch within the epoch
        :return: the loss returned by the model's forward pass
        """
        pos_heads, pos_rels, pos_tails = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device)
        pos_triples = torch.stack([pos_heads, pos_rels, pos_tails], dim=1)  # pos_triples: [batch_size, 3]
        neg_triples = self.train_neg_sampler.neg_sample(pos_heads, pos_rels, pos_tails)  # neg_triples: [batch_size, 3]
        # calculate loss
        loss, _, _ = self.model.forward(pos_triples, neg_triples)
        return loss

    def validation_step(self, batch: List[torch.Tensor], batch_idx: int):
        preds, target = self._get_preds_and_target(batch)
        self.val_hits10.update(preds, target)

    def on_validation_epoch_end(self) -> None:
        # Lightning 2.x removed the `validation_epoch_end(outputs)` hook (defining it raises);
        # `on_validation_epoch_end` exists in both 1.x and 2.x.
        val_hits_at_10 = self.val_hits10.compute()
        self.val_hits10.reset()
        self.log('val_hits@10', val_hits_at_10)

    def test_step(self, batch: List[torch.Tensor], batch_idx: int):
        preds, target = self._get_preds_and_target(batch)
        self.test_hits1.update(preds, target)
        self.test_hits10.update(preds, target)
        self.test_mrr.update(preds, target)

    def on_test_epoch_end(self) -> None:
        # Replaces the `test_epoch_end(outputs)` hook removed in Lightning 2.x.
        result = {
            'hits_at_1': self.test_hits1.compute(),
            'hits_at_10': self.test_hits10.compute(),
            'mrr': self.test_mrr.compute()
        }

        self.test_hits1.reset()
        self.test_hits10.reset()
        self.test_mrr.reset()

        self.log_dict(result)

    def configure_optimizers(self):
        """Create the optimizer and a MultiStepLR that divides the LR by 10 halfway through training."""
        optimizer = utils.create_optimizer(self.params.optimizer, self.model, self.params.learning_rate)
        milestones = int(self.params.epoch_size / 2)
        stepLR = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[milestones], gamma=0.1)
        return {
            'optimizer': optimizer,
            'lr_scheduler': stepLR
        }

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.dataset_dict.train,
            batch_size=self.params.batch_size,
            num_workers=64
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.dataset_dict.valid,
            batch_size=self.params.valid_batch_size,
            num_workers=64
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.dataset_dict.test,
            batch_size=self.params.valid_batch_size,
            num_workers=64
        )

    def _get_preds_and_target(self, batch: List[torch.Tensor]):
        """
        Score every entity as the tail and as the head of each triple in ``batch``.

        Note: this scores ``2 * batch_size * ent_num`` triples per call, so memory grows with both.

        :param batch: [heads, rels, tails], each of size [batch_size]
        :return: (predictions [2 * batch_size, ent_num], ground truth entity ids [2 * batch_size, 1]);
            the first half is tail prediction, the second half is head prediction.
        """
        ent_num = len(self.dataset_dict.meta.entity2id)
        entity_ids = torch.arange(0, ent_num, device=self.device)
        entity_ids.unsqueeze_(0)
        heads, rels, tails = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device)
        batch_size = heads.size(0)
        all_entities = entity_ids.repeat(batch_size, 1)  # all_entities: [batch_size, ent_num]
        # heads: [batch_size,] -> [batch_size, 1] -> [batch_size, ent_num]
        heads_expanded = heads.reshape(-1, 1).repeat(1, ent_num)  # _expanded: [batch_size, ent_num]
        rels_expanded = rels.reshape(-1, 1).repeat(1, ent_num)
        tails_expanded = tails.reshape(-1, 1).repeat(1, ent_num)
        # check all possible tails
        triplets = torch.stack([heads_expanded, rels_expanded, all_entities], dim=2).reshape(-1, 3)  # triplets: [batch_size * ent_num, 3]
        tails_predictions = self.model.predict(triplets).reshape(batch_size, -1)  # tails_prediction: [batch_size, ent_num]
        # check all possible heads
        triplets = torch.stack([all_entities, rels_expanded, tails_expanded], dim=2).reshape(-1, 3)
        heads_predictions = self.model.predict(triplets).reshape(batch_size, -1)  # heads_prediction: [batch_size, ent_num]

        # concatenate tail and head predictions
        predictions = torch.cat([tails_predictions, heads_predictions], dim=0)  # predictions: [batch_size * 2, ent_num]
        ground_truth_entity_id = torch.cat([tails.reshape(-1, 1), heads.reshape(-1, 1)], dim=0)  # [batch_size * 2, 1]

        return predictions, ground_truth_entity_id
128 |
--------------------------------------------------------------------------------
/krl/lit_model/__init__.py:
--------------------------------------------------------------------------------
1 | from .TransXLitModel import TransXLitModel
--------------------------------------------------------------------------------
/krl/metric.py:
--------------------------------------------------------------------------------
1 | """
2 | Calculate the metrics of KRL models.
3 | """
4 | from pydantic import BaseModel
5 | from typing import Optional
6 | from abc import ABC
7 | from enum import Enum
8 | import torchmetrics
9 | import torch
10 |
11 |
class MetricEnum(Enum):
    """
    Enumerate all metric names. Each value is the name of the matching field on
    the metric models derived from `KRLMetricBase` (e.g. `RankMetric`).
    """
    MRR = 'mrr'  # mean reciprocal rank
    HITS_AT_1 = 'hits_at_1'  # Hits@1
    HITS_AT_3 = 'hits_at_3'  # Hits@3
    HITS_AT_10 = 'hits_at_10'  # Hits@10
20 |
21 |
class KRLMetricBase(BaseModel, ABC):
    """
    Common base class that every metric model must derive from.

    Concrete metric models (such as ``RankMetric``) declare their metric
    values as pydantic fields.
    """
26 |
27 |
class RankMetric(KRLMetricBase):
    """
    Ranking metrics; ``RankEvaluator`` fills them in as percentages.

    A field is ``None`` when its metric was not computed. The explicit ``None``
    defaults keep every field optional under both pydantic v1 (which implied the
    default) and pydantic v2 (which makes un-defaulted ``Optional`` fields required).
    """
    mrr: Optional[float] = None
    hits_at_1: Optional[float] = None
    hits_at_3: Optional[float] = None
    hits_at_10: Optional[float] = None
33 |
34 |
class MRR(torchmetrics.Metric):
    """
    Mean reciprocal rank reported as a percentage.

    Lower prediction values are ranked first.
    """
    # update() only needs the current batch, so the full-state forward path is unnecessary.
    full_state_update = False
    higher_is_better = True

    def __init__(self) -> None:
        super().__init__()
        self.add_state("mrr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("example_cnt", default=torch.tensor(0), dist_reduce_fx="sum")

    def _cal_mrr(self, predictions: torch.Tensor, ground_truth_idx: torch.Tensor) -> float:
        """Sum the reciprocal ranks of the ground truth classes over a batch.

        :param predictions: BxN tensor of prediction values where B is batch size and N number of classes.
            Predictions must be sorted in class ids order.
        :param ground_truth_idx: Bx1 (or 1-D, size B) tensor with the index of the ground truth class.
        :return: sum (not mean) of the reciprocal ranks in the batch.
        """
        assert predictions.size(0) == ground_truth_idx.size(0)

        order = predictions.argsort(dim=-1)  # ascending: best candidate in column 0
        # Column vector so the comparison is row-wise even if a 1-D target is passed.
        targets = ground_truth_idx.reshape(-1, 1).to(order.device)
        ranks = (order == targets).nonzero()[:, 1].float().add(1.0)
        return (1.0 / ranks).sum().item()

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate the reciprocal-rank sum and the example count of one batch."""
        self.mrr_sum += self._cal_mrr(preds, target)
        self.example_cnt += preds.size(0)

    def compute(self):
        """Return the mean reciprocal rank multiplied by 100."""
        return self.mrr_sum.float() / self.example_cnt * 100
60 |
61 |
class HitsAtK(torchmetrics.Metric):
    """
    Hits@k reported as a percentage: the share of examples whose ground truth
    is among the k lowest-scored predictions.
    """
    # update() only needs the current batch, so the full-state forward path is unnecessary.
    full_state_update = False
    higher_is_better = True

    def __init__(self, k: int) -> None:
        super().__init__()
        self.k = k
        self.add_state("hits_sum", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("example_cnt", default=torch.tensor(0), dist_reduce_fx="sum")

    def _cal_hits_at_k(
        self,
        predictions: torch.Tensor,
        ground_truth_idx: torch.Tensor
    ) -> int:
        """Count the hits@k (``self.k``) in a batch.

        :param predictions: BxN tensor of prediction values where B is batch size and N number of classes.
            Predictions must be sorted in class ids order.
        :param ground_truth_idx: Bx1 (or 1-D, size B) tensor with the index of the ground truth class.
        :return: number of rows whose ground truth is within the top-k (k is clamped to N).
        """
        assert predictions.size(0) == ground_truth_idx.size(0)  # has the same batch_size

        # topk raises if k exceeds the number of candidates; with k >= N every row is a hit anyway.
        k = min(self.k, predictions.size(1))
        _, indices = predictions.topk(k, largest=False)  # indices: [batch_size, k]
        # Column vector so the comparison is row-wise even if a 1-D target is passed.
        targets = ground_truth_idx.reshape(-1, 1).to(indices.device)
        return (indices == targets).any(dim=1).sum().item()

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate the hit count and the example count of one batch."""
        self.hits_sum += self._cal_hits_at_k(preds, target)
        self.example_cnt += preds.size(0)

    def compute(self) -> float:
        """Return Hits@k multiplied by 100."""
        return self.hits_sum.float() / self.example_cnt * 100
99 |
--------------------------------------------------------------------------------
/krl/metric_fomatter.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any
3 |
4 | from .metric import KRLMetricBase, RankMetric
5 | from .config import DatasetConf
6 |
7 |
class MetricFormatter(ABC):
    """
    Abstract base for classes that render a metric model into some output format.

    NOTE(review): StringFormatter.convert also requires a ``dataset_conf``
    argument; consider aligning the abstract signature with it.
    """

    @abstractmethod
    def convert(self, metric: KRLMetricBase) -> Any:
        """
        Render ``metric`` in this formatter's target format.

        :param metric: the metric instance to convert.
        :return: the converted representation; its type depends on the subclass.
        """
20 |
21 |
22 | _STRING_TEMPLATE = """dataset: {dataset_name},
23 | Hits@1: {hits_at_1},
24 | Hits@3: {hits_at_3},
25 | Hist@10: {hits_at_10},
26 | MRR: {mrr}.
27 | """
28 |
class StringFormatter(MetricFormatter):
    """
    Render a ranking metric as a human-readable multi-line string.
    """
    def convert(self, metric: RankMetric, dataset_conf: DatasetConf) -> str:
        """
        :param metric: ranking metric whose values fill the template.
        :param dataset_conf: supplies the dataset name printed on the first line.
        :return: the formatted report.
        """
        placeholders = {
            'dataset_name': dataset_conf.dataset_name,
            'hits_at_1': metric.hits_at_1,
            'hits_at_3': metric.hits_at_3,
            'hits_at_10': metric.hits_at_10,
            'mrr': metric.mrr,
        }
        return _STRING_TEMPLATE.format(**placeholders)
41 |
--------------------------------------------------------------------------------
/krl/models/DistMult.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference:
3 |
4 | - https://github.com/Sujit-O/pykg2vec/blob/master/pykg2vec/models/pointwise.py
5 | - https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/openke/module/model/DistMult.py
6 | """
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.utils.data import DataLoader
11 | from pydantic import Field
12 | from typing import Literal
13 |
14 | from ..base_model import KRLModel, ModelMain
15 | from ..config import HyperParam, LocalDatasetConf, TrainConf
16 | from ..dataset import create_mapping, LocalKRLDataset
17 | from ..negative_sampler import BernNegSampler
18 | from .. import utils
19 | from ..trainer import RescalTrainer
20 | from ..metric import MetricEnum
21 | from ..evaluator import RankEvaluator
22 | from .. import storage
23 | from ..metric_fomatter import StringFormatter
24 | from ..serializer import FileSerializer
25 |
26 |
class DistMultHyperParam(HyperParam):
    """
    Hyper-parameters of DistMult.

    Adds the embedding size and the regularization settings to the shared ``HyperParam``.
    """
    embed_dim: int  # dimension of both entity and relation embeddings
    alpha: float = Field(default=0.001, title='regularization parameter')
    regul_type: Literal['F2', 'N3'] = Field(default='F2', title='regularization type')
33 |
34 |
class DistMult(KRLModel):
    """
    DistMult: scores a triple (h, r, t) as ``sum(h * r * t)``.

    Trained pointwise: an MSE loss against 0/1 labels plus an F2 or N3
    regularizer on the embeddings used in the batch.
    """
    def __init__(
        self,
        ent_num: int,
        rel_num: int,
        device: torch.device,
        hyper_params: DistMultHyperParam
    ):
        super().__init__()
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.device = device
        self.embed_dim = hyper_params.embed_dim
        self.alpha = hyper_params.alpha
        self.regul_type = hyper_params.regul_type.upper()

        # initialize entity embedding
        self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim)
        nn.init.xavier_uniform_(self.ent_embedding.weight.data)

        # initialize relation embedding
        self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim)
        nn.init.xavier_uniform_(self.rel_embedding.weight.data)

        # Use MSE when negative samples are labelled 0.
        self.criterion = nn.MSELoss()
        # Use this instead when negative samples are labelled -1:
        # self.criterion = lambda preds, labels: F.softplus(preds * labels).mean()

    def embed(self, triples):
        """Get the embeddings of a batch of triples.

        :param triples: [batch, 3] tensor of (head, relation, tail) ids
        :return: (h_embs, r_embs, t_embs), each of size [batch, embed_dim]
        """
        assert triples.shape[1] == 3
        # get entity ids and relation ids
        heads = triples[:, 0]
        rels = triples[:, 1]
        tails = triples[:, 2]
        # id -> embedding
        h_embs = self.ent_embedding(heads)  # [batch, emb]
        t_embs = self.ent_embedding(tails)
        r_embs = self.rel_embedding(rels)  # [batch, emb]
        return h_embs, r_embs, t_embs

    def _get_reg(self, h_embs, r_embs, t_embs):
        """Calculate the regularization term.

        - ``F2``: mean squared entry of each embedding, averaged over the three.
        - ``N3``: per-triple sum of cubed *absolute* values, averaged over the batch.
          Without ``abs()``, negative coordinates make the penalty negative and the
          optimizer is rewarded for growing them.

        :param h_embs: heads embedding, size: [batch, embed]
        :param r_embs: rels embedding
        :param t_embs: tails embedding
        :return: scalar regularization tensor
        """
        if self.regul_type == 'F2':
            regul = (torch.mean(h_embs ** 2) + torch.mean(t_embs ** 2) + torch.mean(r_embs ** 2)) / 3
        else:
            regul = torch.mean(
                torch.sum(h_embs.abs() ** 3, -1)
                + torch.sum(r_embs.abs() ** 3, -1)
                + torch.sum(t_embs.abs() ** 3, -1)
            )
        return regul

    def _scoring(self, h_embs, r_embs, t_embs):
        """Compute the scores of a batch of triples.

        Higher is better: positives are pushed towards 1 and negatives towards 0.
        This score can also be regarded as the `pred`.

        :param h_embs: embedding of a batch heads, size: [batch, emb]
        :return: size: [batch,]
        """
        return torch.sum(h_embs * r_embs * t_embs, dim=1)

    def loss(self, triples: torch.Tensor, labels: torch.Tensor):
        """Calculate the loss.

        :param triples: a batch of triples. size: [batch, 3]
        :param labels: the label of each triple, label = 1 if the triple is positive, label = 0 if the triple is negative. size: [batch,]
        :return: (loss, scores)
        """
        assert triples.shape[1] == 3
        assert triples.shape[0] == labels.shape[0]

        h_embs, r_embs, t_embs = self.embed(triples)

        scores = self._scoring(h_embs, r_embs, t_embs)
        regul = self._get_reg(h_embs, r_embs, t_embs)
        loss = self.criterion(scores, labels.float()) + self.alpha * regul

        return loss, scores

    def forward(self, triples, labels):
        loss, scores = self.loss(triples, labels)
        return loss, scores

    def predict(self, triples):
        """Calculate the dissimilarity score of given triples.

        :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: dissimilarity score (negated DistMult score; lower is better), size [B]
        """
        h_embs, r_embs, t_embs = self.embed(triples)
        scores = self._scoring(h_embs, r_embs, t_embs)
        return -scores
133 |
134 |
class DistMultMain(ModelMain):
    """
    End-to-end DistMult pipeline built from a local dataset config: train the
    model, test the stored checkpoint and serialize the resulting metrics.
    """
    def __init__(
        self,
        dataset_conf: LocalDatasetConf,
        train_conf: TrainConf,
        hyper_params: DistMultHyperParam,
        device: torch.device
    ) -> None:
        super().__init__()
        self.dataset_conf = dataset_conf
        self.train_conf = train_conf
        self.hyper_params = hyper_params
        self.device = device

    def __call__(self):
        """Train, reload the checkpoint, evaluate on the test split and write the metrics.

        The steps are order-dependent: the checkpoint is loaded only after training.
        """
        # create mapping
        entity2id, rel2id = create_mapping(self.dataset_conf)
        ent_num = len(entity2id)
        rel_num = len(rel2id)

        # create dataset and dataloader
        train_dataset, train_dataloader, valid_dataset, valid_dataloader = utils.create_local_dataloader(self.dataset_conf, self.hyper_params, entity2id, rel2id)

        # create negative-sampler (shared by training and validation)
        neg_sampler = BernNegSampler(train_dataset, self.device)

        # create model
        model = DistMult(ent_num, rel_num, self.device, self.hyper_params)
        model = model.to(self.device)

        # create optimizer
        optimizer = utils.create_optimizer(self.hyper_params.optimizer, model, self.hyper_params.learning_rate)

        # create trainer (DistMult reuses the RESCAL trainer)
        trainer = RescalTrainer(
            model=model,
            train_conf=self.train_conf,
            params=self.hyper_params,
            dataset_conf=self.dataset_conf,
            entity2id=entity2id,
            rel2id=rel2id,
            device=self.device,
            train_dataloder=train_dataloader,
            valid_dataloder=valid_dataloader,
            train_neg_sampler=neg_sampler,
            valid_neg_sampler=neg_sampler,
            optimzer=optimizer
        )

        # training process
        trainer.run_training()

        # create evaluator
        metrics = [
            MetricEnum.MRR,
            MetricEnum.HITS_AT_1,
            MetricEnum.HITS_AT_3,
            MetricEnum.HITS_AT_10
        ]
        evaluator = RankEvaluator(self.device, metrics)

        # Testing the best checkpoint on test dataset
        # load best model; load_state_dict updates `model` in place, which presumably
        # is the same module the trainer holds — TODO confirm in RescalTrainer
        ckpt = storage.load_checkpoint(self.train_conf)
        model.load_state_dict(ckpt.model_state_dict)
        model = model.to(self.device)
        # create test-dataset
        test_dataset = LocalKRLDataset(self.dataset_conf, 'test', entity2id, rel2id)
        test_dataloder = DataLoader(test_dataset, self.hyper_params.valid_batch_size)
        # run inference on test-dataset
        metric = trainer.run_inference(test_dataloder, ent_num, evaluator)

        # choose the metric formatter
        metric_formatter = StringFormatter()

        # choose the way to serialize
        serilizer = FileSerializer(self.train_conf, self.dataset_conf)
        # serialize the metric
        serilizer.serialize(metric, metric_formatter)
214 |
--------------------------------------------------------------------------------
/krl/models/RESCAL.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference:
3 |
4 | - https://yubincloud.github.io/notebook-paper/KG/KRL/1101.RESCAL-and-extensions.html
5 | - https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/openke/module/model/RESCAL.py
6 | - https://github.com/nju-websoft/muKG/blob/main/src/torch/kge_models/RESCAL.py
7 | """
8 |
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import DataLoader
14 | from pydantic import Field
15 | from typing import Literal
16 |
17 | from ..base_model import KRLModel, ModelMain
18 | from ..config import HyperParam, TrainConf
19 | from ..dataset import KRLDatasetDict
20 | from ..negative_sampler import BernNegSampler
21 | from .. import utils
22 | from ..trainer import RescalTrainer
23 | from ..metric import MetricEnum
24 | from ..evaluator import RankEvaluator
25 | from .. import storage
26 | from ..metric_fomatter import StringFormatter
27 | from ..serializer import FileSerializer
28 |
29 |
class RescalHyperParam(HyperParam):
    """
    Hyper-parameters of RESCAL.

    Adds the embedding size and the regularization settings to the shared ``HyperParam``.
    """
    embed_dim: int  # entity embedding size; relations use embed_dim * embed_dim values
    alpha: float = Field(default=0.001, title='regularization parameter')
    regul_type: Literal['F2', 'N3'] = Field(default='F2', title='regularization type')
36 |
37 |
class RESCAL(KRLModel):
    """
    RESCAL: scores a triple (h, r, t) as the bilinear form ``h^T M_r t``, where
    each relation is a full ``embed_dim x embed_dim`` matrix ``M_r``.

    Trained pointwise: an MSE loss against 0/1 labels plus an F2 or N3
    regularizer on the embeddings used in the batch.
    """
    def __init__(
        self,
        ent_num: int,
        rel_num: int,
        device: torch.device,
        hyper_params: RescalHyperParam,
    ):
        super().__init__()
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.device = device
        self.embed_dim = hyper_params.embed_dim
        self.alpha = hyper_params.alpha
        self.regul_type = hyper_params.regul_type.upper()

        # initialize entity embedding
        self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim)
        nn.init.xavier_uniform_(self.ent_embedding.weight.data)
        # L2-normalize each row; many implementations omit this step.
        self.ent_embedding.weight.data = F.normalize(self.ent_embedding.weight.data, 2, 1)

        # initialize relation embedding: each row is a flattened embed_dim x embed_dim matrix
        self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim * self.embed_dim)
        nn.init.xavier_uniform_(self.rel_embedding.weight.data)
        self.rel_embedding.weight.data = F.normalize(self.rel_embedding.weight.data, 2, 1)

        self.criterion = nn.MSELoss()

    def embed(self, triples):
        """Get the embeddings of a batch of triples.

        :param triples: [batch, 3] tensor of (head, relation, tail) ids
        :return: (h_embs [batch, emb], r_embs [batch, emb * emb], t_embs [batch, emb])
        """
        assert triples.shape[1] == 3
        # get entity ids and relation ids
        heads = triples[:, 0]
        rels = triples[:, 1]
        tails = triples[:, 2]
        # id -> embedding
        h_embs = self.ent_embedding(heads)  # [batch, emb]
        t_embs = self.ent_embedding(tails)
        r_embs = self.rel_embedding(rels)  # [batch, emb * emb]
        return h_embs, r_embs, t_embs

    def _scoring(self, h_embs, r_embs, t_embs):
        """Compute the scores of a batch of triples.

        Higher is better: positives are pushed towards 1 and negatives towards 0.
        This score can also be regarded as the `pred`.

        :param h_embs: heads embedding, size: [batch, embed]
        :param r_embs: rels embedding, size: [batch, embed * embed]
        :return: size: [batch,]
        """
        # calculate scores
        r_embs = r_embs.view(-1, self.embed_dim, self.embed_dim)  # [batch, emb, emb]
        t_embs = t_embs.view(-1, self.embed_dim, 1)  # [batch, emb, 1]

        tr = torch.matmul(r_embs, t_embs)  # [batch, emb, 1]
        tr = tr.view(-1, self.embed_dim)  # [batch, emb]

        return torch.sum(h_embs * tr, dim=1)

    def _get_reg(self, h_embs, r_embs, t_embs):
        """Calculate the regularization term.

        - ``F2``: mean squared entry of each embedding, averaged over the three.
        - ``N3``: per-triple sum of cubed *absolute* values, averaged over the batch.
          Without ``abs()``, negative coordinates make the penalty negative and the
          optimizer is rewarded for growing them.

        :param h_embs: heads embedding, size: [batch, embed]
        :param r_embs: rels embedding
        :param t_embs: tails embedding
        :return: scalar regularization tensor
        """
        if self.regul_type == 'F2':
            regul = (torch.mean(h_embs ** 2) + torch.mean(t_embs ** 2) + torch.mean(r_embs ** 2)) / 3
        else:
            regul = torch.mean(
                torch.sum(h_embs.abs() ** 3, -1)
                + torch.sum(r_embs.abs() ** 3, -1)
                + torch.sum(t_embs.abs() ** 3, -1)
            )
        return regul

    def loss(self, triples: torch.Tensor, labels: torch.Tensor):
        """Calculate the loss.

        :param triples: a batch of triples. size: [batch, 3]
        :param labels: the label of each triple, label = 1 if the triple is positive, label = 0 if the triple is negative. size: [batch,]
        :return: (loss, scores)
        """
        assert triples.shape[1] == 3
        assert triples.shape[0] == labels.shape[0]

        h_embs, r_embs, t_embs = self.embed(triples)

        scores = self._scoring(h_embs, r_embs, t_embs)
        regul = self._get_reg(h_embs, r_embs, t_embs)
        loss = self.criterion(scores, labels.float()) + self.alpha * regul

        return loss, scores

    def forward(self, triples, labels):
        loss, scores = self.loss(triples, labels)
        return loss, scores

    def predict(self, triples):
        """Calculate the dissimilarity score of given triples.

        :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: dissimilarity score (negated RESCAL score; lower is better), size [B]
        """
        h_embs, r_embs, t_embs = self.embed(triples)
        return -self._scoring(h_embs, r_embs, t_embs)
144 |
145 |
146 |
class RescalMain(ModelMain):
    """
    End-to-end RESCAL pipeline over a prepared ``KRLDatasetDict``: train the
    model, test the stored checkpoint and serialize the resulting metrics.
    """
    def __init__(
        self,
        dataset: KRLDatasetDict,
        train_conf: TrainConf,
        hyper_params: RescalHyperParam,
        device: torch.device
    ) -> None:
        super().__init__()
        self.datasets = dataset
        self.dataset_conf = dataset.dataset_conf
        self.train_conf = train_conf
        self.params = hyper_params
        self.device = device

    def __call__(self):
        """Train, reload the checkpoint, evaluate on the test split and write the metrics.

        The steps are order-dependent: the checkpoint is loaded only after training.
        """
        # create mapping
        entity2id = self.datasets.entity2id
        rel2id = self.datasets.rel2id
        ent_num = len(entity2id)
        rel_num = len(rel2id)

        # create dataset and dataloader
        train_dataset = self.datasets.train
        train_dataloader = DataLoader(train_dataset, self.params.batch_size)
        valid_dataset = self.datasets.valid
        # NOTE(review): validation uses `batch_size` while the test loader below uses
        # `valid_batch_size` — confirm which is intended.
        valid_dataloader = DataLoader(valid_dataset, self.params.batch_size)

        # create negative-sampler (shared by training and validation)
        neg_sampler = BernNegSampler(train_dataset, self.device)

        # create model
        model = RESCAL(ent_num, rel_num, self.device, self.params)
        model = model.to(self.device)

        # create optimizer
        optimizer = utils.create_optimizer(self.params.optimizer, model, self.params.learning_rate)

        # create trainer
        trainer = RescalTrainer(
            model=model,
            train_conf=self.train_conf,
            params=self.params,
            dataset_conf=self.dataset_conf,
            entity2id=entity2id,
            rel2id=rel2id,
            device=self.device,
            train_dataloder=train_dataloader,
            valid_dataloder=valid_dataloader,
            train_neg_sampler=neg_sampler,
            valid_neg_sampler=neg_sampler,
            optimzer=optimizer
        )

        # training process
        trainer.run_training()

        # create evaluator
        metrics = [
            MetricEnum.MRR,
            MetricEnum.HITS_AT_1,
            MetricEnum.HITS_AT_3,
            MetricEnum.HITS_AT_10
        ]
        evaluator = RankEvaluator(self.device, metrics)

        # Testing the best checkpoint on test dataset
        # load best model; load_state_dict updates `model` in place, which presumably
        # is the same module the trainer holds — TODO confirm in RescalTrainer
        ckpt = storage.load_checkpoint(self.train_conf)
        model.load_state_dict(ckpt.model_state_dict)
        model = model.to(self.device)
        # create test-dataset
        test_dataset = self.datasets.test
        test_dataloder = DataLoader(test_dataset, self.params.valid_batch_size)
        # run inference on test-dataset
        metric = trainer.run_inference(test_dataloder, ent_num, evaluator)

        # choose the metric formatter
        metric_formatter = StringFormatter()

        # choose the way to serialize
        serilizer = FileSerializer(self.train_conf, self.dataset_conf)
        # serialize the metric
        serilizer.serialize(metric, metric_formatter)
231 |
--------------------------------------------------------------------------------
/krl/models/TransD.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference:
3 |
4 | - https://github.com/nju-websoft/muKG/blob/main/src/torch/kge_models/TransD.py
5 | """
6 | import torch
7 | import torch.nn as nn
8 |
9 | from ..config import HyperParam
10 | from ..base_model import KRLModel
11 |
12 |
class TransDHyperParam(HyperParam):
    """Hyper-parameters of TransD."""
    ent_dim: int  # entity embedding dimension
    rel_dim: int  # relation embedding dimension
    norm: int  # p of the p-norm used by the pairwise distance
    margin: float  # margin of the margin ranking loss
18 |
19 |
class TransD(KRLModel):
    """
    TransD model state: entity/relation embeddings, their transfer vectors,
    a margin ranking loss and a p-norm distance function.

    NOTE(review): only the constructor is defined; no scoring, loss or predict
    methods exist in this class yet.
    """
    def __init__(
        self,
        ent_num: int,
        rel_num: int,
        device: torch.device,
        hyper_params: TransDHyperParam
    ):
        super().__init__()
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.device = device
        self.norm = hyper_params.norm
        self.ent_dim = hyper_params.ent_dim
        self.rel_dim = hyper_params.rel_dim
        self.margin = hyper_params.margin

        self.margin_loss_fn = nn.MarginRankingLoss(margin=self.margin)

        # initialize ent_embedding
        self.ent_embedding = nn.Embedding(self.ent_num, self.ent_dim)
        nn.init.xavier_uniform_(self.ent_embedding.weight.data)

        # initialize rel_embedding
        self.rel_embedding = nn.Embedding(self.rel_num, self.rel_dim)
        nn.init.xavier_uniform_(self.rel_embedding.weight.data)

        # initialize the transfer embeddings (presumably used to build TransD's
        # projection matrices — TODO confirm once scoring is implemented)
        self.ent_transfer = nn.Embedding(self.ent_num, self.ent_dim)
        nn.init.xavier_uniform_(self.ent_transfer.weight.data)
        self.rel_transfer = nn.Embedding(self.rel_num, self.rel_dim)
        nn.init.xavier_uniform_(self.rel_transfer.weight.data)

        # p-norm distance between pairs of vectors
        self.dist_fn = nn.PairwiseDistance(p=self.norm)
54 |
--------------------------------------------------------------------------------
/krl/models/TransE.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import lightning.pytorch as pl
5 | from lightning.pytorch.loggers import CSVLogger
6 | from lightning.pytorch.callbacks import EarlyStopping
7 | from datetime import datetime
8 |
9 | from ..base_model import TransXBaseModel, LitModelMain
10 | from ..config import TransHyperParam, TrainConf
11 | from ..dataset import KRLDatasetDict
12 | from ..negative_sampler import BernNegSampler
13 | from ..lit_model import TransXLitModel
14 | from .. import utils
15 |
16 |
class TransEHyperParam(TransHyperParam):
    """Hyper-parameters of TransE.

    TransE declares no fields of its own; everything comes from ``TransHyperParam``.
    """
21 |
22 |
class TransE(TransXBaseModel):
    """TransE: a triple (h, r, t) is modelled as a translation, h + r ~ t.

    The dissimilarity of a triple is ``||h + r - t||_p`` with
    ``p = hyper_params.norm``; lower values mean more plausible triples.
    """

    def __init__(self,
                 ent_num: int,
                 rel_num: int,
                 hyper_params: TransEHyperParam
                 ):
        """
        :param ent_num: number of entities (rows of the entity embedding)
        :param rel_num: number of relations (rows of the relation embedding)
        :param hyper_params: provides ``norm``, ``embed_dim`` and ``margin``
        """
        super().__init__()
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.norm = hyper_params.norm
        self.embed_dim = hyper_params.embed_dim
        self.margin = hyper_params.margin

        # Entity embeddings, Xavier-uniform initialised.
        # (The TransE paper samples from U(-6/sqrt(k), 6/sqrt(k)); Xavier is used instead.)
        self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim)
        torch.nn.init.xavier_uniform_(self.ent_embedding.weight.data)

        # Relation embeddings, same initialisation.
        self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim)
        torch.nn.init.xavier_uniform_(self.rel_embedding.weight.data)

        self.dist_fn = nn.PairwiseDistance(p=self.norm)  # L_p distance between h + r and t
        self.criterion = nn.MarginRankingLoss(margin=self.margin)

    def embed(self, triples):
        """Look up the embeddings of a batch of triples.

        :param triples: id tensor of size [batch, 3], columns (head, rel, tail)
        :return: (h_embs, r_embs, t_embs), each of size [batch, embed_dim]
        """
        assert triples.shape[1] == 3
        heads = triples[:, 0]
        rels = triples[:, 1]
        tails = triples[:, 2]
        h_embs = self.ent_embedding(heads)  # [batch, embed_dim]
        r_embs = self.rel_embedding(rels)
        t_embs = self.ent_embedding(tails)
        return h_embs, r_embs, t_embs

    def _distance(self, triples):
        """Compute the distance ``||h + r - t||_p`` for a batch of triples.

        :param triples: a batch of triples, size [batch, 3]
        :return: size [batch,]
        """
        h_embs, r_embs, t_embs = self.embed(triples)
        return self.dist_fn(h_embs + r_embs, t_embs)

    def loss(self, pos_distances: torch.Tensor, neg_distances: torch.Tensor):
        """Margin ranking loss, ``mean(max(0, pos - neg + margin))``.

        :param pos_distances: [batch, ]
        :param neg_distances: [batch, ]
        :return: scalar loss tensor
        """
        # target = -1 asks MarginRankingLoss to rank the first input (the positive
        # distance) below the second. Built with the inputs' dtype/device so no
        # implicit integer->float promotion or device mismatch can occur.
        target = pos_distances.new_full((1,), -1.0)
        return self.criterion(pos_distances, neg_distances, target)

    def forward(self, pos_triples: torch.Tensor, neg_triples: torch.Tensor):
        """Return the model loss for a batch of positive and negative triples.

        :param pos_triples: positive triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :param neg_triples: negative triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: tuple of (loss, positive distances, negative distances)
        """
        assert pos_triples.size()[1] == 3
        assert neg_triples.size()[1] == 3

        pos_distances = self._distance(pos_triples)
        neg_distances = self._distance(neg_triples)
        loss = self.loss(pos_distances, neg_distances)
        return loss, pos_distances, neg_distances

    def predict(self, triples: torch.Tensor):
        """Calculate the dissimilarity score for the given triples.

        :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: dissimilarity score for each triple, size [B,]
        """
        return self._distance(triples)
107 |
108 |
class TransELitMain(LitModelMain):
    """Train and test TransE with PyTorch Lightning."""

    def __init__(
        self,
        dataset: KRLDatasetDict,
        train_conf: TrainConf,
        hyper_params: TransEHyperParam,
        seed: int = None,
    ) -> None:
        super().__init__(
            dataset,
            train_conf,
            seed
        )
        self.params = hyper_params

    def __call__(self):
        """Seed, build sampler and model, fit with early stopping, then test."""
        pl.seed_everything(self.seed)

        ent_num = len(self.datasets.meta.entity2id)
        rel_num = len(self.datasets.meta.rel2id)

        # Bernoulli negative sampler whose statistics come from the training split.
        train_dataset = self.datasets.train
        neg_sampler = BernNegSampler(train_dataset)

        model = TransE(ent_num, rel_num, self.params)
        model_wrapped = TransXLitModel(model, self.datasets, neg_sampler, self.params)

        # Stop when validation hits@10 stops improving; checked after validation runs,
        # not at the end of each training epoch.
        early_stopping = EarlyStopping('val_hits@10', mode="max", patience=self.params.early_stoping_patience, check_on_train_epoch_end=False)

        # NOTE(review): self.dataset_conf is not assigned in this class; confirm that
        # LitModelMain provides it, otherwise the logger name below raises AttributeError.
        trainer = pl.Trainer(
            # Single GPU with index 0. The former `gpus="0,"` argument was removed
            # from the Trainer in Lightning 2.0; accelerator/devices is its replacement.
            accelerator="gpu",
            devices=[0],
            max_epochs=self.params.epoch_size,
            logger=CSVLogger(self.train_conf.logs_dir, name=f'{model.__class__.__name__}-{self.dataset_conf.dataset_name}'),
            callbacks=[early_stopping]
        )
        trainer.fit(model=model_wrapped)

        trainer.test(model_wrapped)
153 |
--------------------------------------------------------------------------------
/krl/models/TransH.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference:
3 |
4 | - https://github.com/LYuhang/Trans-Implementation/blob/master/code/models/TransH.py
5 | - https://github.com/zqhead/TransH/blob/master/TransH_torch.py
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.utils.data import DataLoader
12 | from pydantic import Field
13 |
14 | from ..base_model import TransXBaseModel, ModelMain
15 | from ..config import TransHyperParam, LocalDatasetConf, TrainConf
16 | from ..dataset import create_mapping, LocalKRLDataset
17 | from ..negative_sampler import BernNegSampler
18 | from .. import utils
19 | from ..trainer import TransETrainer
20 | from ..metric import MetricEnum
21 | from ..evaluator import RankEvaluator
22 | from .. import storage
23 | from ..metric_fomatter import StringFormatter
24 | from ..serializer import FileSerializer
25 |
26 |
class TransHHyperParam(TransHyperParam):
    """Hyper-parameters of TransH, on top of the shared ``TransHyperParam`` fields.

    ``C`` weights the soft-constraint terms of the loss; ``eps`` is the
    tolerance of the orthogonality constraint.
    """
    C: float = Field(default=0.1, description='a hyper-parameter weighting the importance of soft constraints.')
    # Raw string: a plain '\e' is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    eps: float = Field(default=1e-3, description=r'the $\epsilon$ in loss function')
34 |
35 |
class TransH(TransXBaseModel):
    """TransH: translation on relation-specific hyperplanes.

    Each relation r has a translation vector d_r (``rel_embedding``) and a
    hyperplane normal w_r (``rel_hyper_embedding``). Head and tail are projected
    onto the hyperplane (with w_r normalised to unit length) and scored with
    ``||h_perp + d_r - t_perp||_p``.
    """

    def __init__(
        self,
        ent_num: int,
        rel_num: int,
        device: torch.device,
        hyper_params: TransHHyperParam
    ):
        super().__init__()
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.device = device
        self.norm = hyper_params.norm
        self.embed_dim = hyper_params.embed_dim
        self.margin = hyper_params.margin
        self.C = hyper_params.C  # weight of the soft-constraint terms in the loss
        self.eps = hyper_params.eps  # epsilon of the orthogonality constraint

        self.margin_loss_fn = nn.MarginRankingLoss(margin=self.margin)

        # Entity embeddings.
        self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim)
        nn.init.xavier_uniform_(self.ent_embedding.weight.data)

        # Relation-specific translation vectors d_r.
        self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim)
        nn.init.xavier_uniform_(self.rel_embedding.weight.data)

        # Relation-specific hyperplane normals w_r (normalised to unit length when used).
        self.rel_hyper_embedding = nn.Embedding(self.rel_num, self.embed_dim)
        nn.init.xavier_uniform_(self.rel_hyper_embedding.weight.data)

        self.dist_fn = nn.PairwiseDistance(p=self.norm)  # L_p distance on the hyperplane

    def embed(self, triples):
        """Look up the embeddings of a batch of triples.

        :param triples: id tensor of size [batch, 3], columns (head, rel, tail)
        :return: (h_embs, r_embs, t_embs, r_hyper_embs), each of size [batch, embed_dim]
        """
        assert triples.shape[1] == 3
        heads = triples[:, 0]
        rels = triples[:, 1]
        tails = triples[:, 2]
        h_embs = self.ent_embedding(heads)  # [batch, embed_dim]
        t_embs = self.ent_embedding(tails)
        r_embs = self.rel_embedding(rels)
        r_hyper_embs = self.rel_hyper_embedding(rels)  # hyperplane normals, [batch, embed_dim]
        return h_embs, r_embs, t_embs, r_hyper_embs

    def _project(self, ent_embeds, rel_hyper_embeds):
        """Project entity embeddings onto the relation hyperplane: h - (w_r^T h) w_r.

        The formula is an orthogonal projection only for a unit-length normal, so
        w_r is L2-normalised first (the TransH paper constrains ||w_r||_2 = 1).

        :param ent_embeds: entity embedding, size [batch_size, embed_dim]
        :param rel_hyper_embeds: hyperplane normals, size [batch_size, embed_dim]
        :return: projected embeddings, size [batch_size, embed_dim]
        """
        normal = F.normalize(rel_hyper_embeds, p=2, dim=1)
        return ent_embeds - normal * torch.sum(ent_embeds * normal, dim=1, keepdim=True)

    def _distance(self, triples):
        """Compute the TransH distance for a batch of triples.

        :param triples: a batch of triples, size [batch, 3]
        :return: size [batch,]
        """
        assert triples.shape[1] == 3
        # step 1: ids -> embeddings
        h_embs, r_embs, t_embs, r_hyper_embs = self.embed(triples)
        # step 2: project head and tail onto the relation hyperplane
        h_embs = self._project(h_embs, r_hyper_embs)
        t_embs = self._project(t_embs, r_hyper_embs)
        # step 3: translation distance on the hyperplane
        return self.dist_fn(h_embs + r_embs, t_embs)

    def _cal_margin_based_loss(self, pos_distances, neg_distances):
        """Margin ranking loss, ``mean(max(0, pos - neg + margin))``.

        :param pos_distances: [batch, ]
        :param neg_distances: [batch, ]
        :return: scalar loss tensor
        """
        # target = -1: positive distances should be smaller. Built from the inputs so
        # that it follows their device/dtype even if the module was moved after
        # construction (self.device is fixed at __init__).
        target = pos_distances.new_full((1,), -1.0)
        return self.margin_loss_fn(pos_distances, neg_distances, target)

    def _cal_scale_loss(self):
        """Scale constraint: sum over entities of max(||e||_2 - 1, 0).

        F.relu(x) is equal to max(x, 0).
        """
        ent_norm = torch.norm(self.ent_embedding.weight, p=2, dim=1)  # [ent_num, ]
        scale_loss = torch.sum(F.relu(ent_norm - 1))
        return scale_loss

    def _cal_orthogonal_loss(self):
        """Orthogonality constraint of the TransH paper.

        sum_r max((w_r^T d_r)^2 / ||d_r||_2^2 - eps^2, 0) with unit-length w_r,
        i.e. it penalises cos^2 of the angle between w_r and d_r above eps^2,
        whatever the sign of the dot product.
        """
        normals = F.normalize(self.rel_hyper_embedding.weight, p=2, dim=1)
        trans = self.rel_embedding.weight
        dots = torch.sum(normals * trans, dim=1)
        # clamp avoids 0/0 = nan for an all-zero translation vector
        trans_sq_norm = torch.sum(trans * trans, dim=1).clamp_min(1e-12)
        orth_loss = torch.sum(F.relu(dots ** 2 / trans_sq_norm - self.eps ** 2))
        return orth_loss

    def loss(self, pos_distances, neg_distances):
        """Total loss: margin loss + C * (scale and orthogonality constraints).

        :param pos_distances: [batch, ]
        :param neg_distances: [batch, ]
        :return: scalar loss tensor
        """
        margin_based_loss = self._cal_margin_based_loss(pos_distances, neg_distances)
        scale_loss = self._cal_scale_loss()
        orth_loss = self._cal_orthogonal_loss()
        # NOTE(review): both constraint sums are divided by ent_num, although orth_loss
        # sums over relations; kept unchanged to preserve the loss scaling.
        ent_num = self.ent_num
        return margin_based_loss + self.C * (scale_loss / ent_num + orth_loss / ent_num)

    def forward(self, pos_triples: torch.Tensor, neg_triples: torch.Tensor):
        """Return the model loss for a batch of positive and negative triples.

        :param pos_triples: positive triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :param neg_triples: negative triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: tuple of (loss, positive distances, negative distances)
        """
        assert pos_triples.size()[1] == 3
        assert neg_triples.size()[1] == 3

        pos_distances = self._distance(pos_triples)
        neg_distances = self._distance(neg_triples)
        loss = self.loss(pos_distances, neg_distances)
        return loss, pos_distances, neg_distances

    def predict(self, triples: torch.Tensor):
        """Calculate the dissimilarity score for the given triples.

        :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: dissimilarity score for each triple, size [B,]
        """
        return self._distance(triples)
169 |
170 |
class TransHMain(ModelMain):
    """TransH pipeline.

    Trains the model, reloads the saved checkpoint, evaluates it on the test
    split and writes the metrics to a file.
    """

    def __init__(
        self,
        dataset_conf: LocalDatasetConf,
        train_conf: TrainConf,
        hyper_params: TransHHyperParam,
        device: torch.device
    ) -> None:
        super().__init__()
        self.dataset_conf = dataset_conf
        self.train_conf = train_conf
        self.hyper_params = hyper_params
        self.device = device

    def __call__(self):
        # entity / relation -> id mappings
        entity2id, rel2id = create_mapping(self.dataset_conf)
        ent_num = len(entity2id)
        rel_num = len(rel2id)

        # datasets and dataloaders for the train / valid splits
        train_dataset, train_dataloader, valid_dataset, valid_dataloader = utils.create_local_dataloader(self.dataset_conf, self.hyper_params, entity2id, rel2id)

        # BernNegSampler.__init__ accepts only the dataset (passing a device raised
        # TypeError); neg_sample() takes the device from its input tensors.
        # NOTE(review): BernNegSampler reads dataset.meta.entity2id / rel2id; confirm
        # the dataset returned by create_local_dataloader exposes `.meta`.
        neg_sampler = BernNegSampler(train_dataset)

        model = TransH(ent_num, rel_num, self.device, self.hyper_params)
        model = model.to(self.device)

        optimizer = utils.create_optimizer(self.hyper_params.optimizer, model, self.hyper_params.learning_rate)

        # NOTE(review): KRLTrainer.__init__ names these parameters train_dataloader /
        # valid_dataloader; confirm TransETrainer accepts the *_dataloder keywords.
        trainer = TransETrainer(
            model=model,
            train_conf=self.train_conf,
            params=self.hyper_params,
            dataset_conf=self.dataset_conf,
            entity2id=entity2id,
            rel2id=rel2id,
            device=self.device,
            train_dataloder=train_dataloader,
            valid_dataloder=valid_dataloader,
            train_neg_sampler=neg_sampler,
            valid_neg_sampler=neg_sampler,
            optimzer=optimizer
        )

        trainer.run_training()

        metrics = [
            MetricEnum.MRR,
            MetricEnum.HITS_AT_1,
            MetricEnum.HITS_AT_3,
            MetricEnum.HITS_AT_10
        ]
        evaluator = RankEvaluator(self.device, metrics)

        # Evaluate the saved (best) checkpoint on the test split.
        ckpt = storage.load_checkpoint(self.train_conf)
        model.load_state_dict(ckpt.model_state_dict)
        model = model.to(self.device)
        test_dataset = LocalKRLDataset(self.dataset_conf, 'test', entity2id, rel2id)
        test_dataloader = DataLoader(test_dataset, self.hyper_params.valid_batch_size)
        metric = trainer.run_inference(test_dataloader, ent_num, evaluator)

        # Format the metric as a string and write it to a file.
        metric_formatter = StringFormatter()
        serializer = FileSerializer(self.train_conf, self.dataset_conf)
        serializer.serialize(metric, metric_formatter)
250 |
--------------------------------------------------------------------------------
/krl/models/TransR.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference:
3 |
4 | - https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/openke/module/model/TransR.py
5 | - https://github.com/zqhead/TransR
6 | - https://github.com/Sujit-O/pykg2vec/blob/master/pykg2vec/models/pairwise.py
7 | - https://github.com/nju-websoft/muKG/blob/main/src/torch/kge_models/TransR.py
8 |
9 | Note: Although the model runs, its hits@10 metric is unexpectedly low; the cause is not yet known.
10 | """
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | from torch.utils.data import DataLoader
16 |
17 | from ..base_model import TransXBaseModel, ModelMain
18 | from ..config import HyperParam, LocalDatasetConf, TrainConf
19 | from ..dataset import create_mapping, LocalKRLDataset
20 | from ..negative_sampler import BernNegSampler
21 | from .. import utils
22 | from ..trainer import TransETrainer
23 | from ..metric import MetricEnum
24 | from ..evaluator import RankEvaluator
25 | from .. import storage
26 | from ..metric_fomatter import StringFormatter
27 | from ..serializer import FileSerializer
28 |
class TransRHyperParam(HyperParam):
    """Hyper-parameters of TransR.

    A single ``embed_dim`` sizes both the entity and the relation space.
    """
    embed_dim: int  # vector size; each transfer matrix is embed_dim x embed_dim
    norm: int  # p of the L_p distance given to nn.PairwiseDistance
    margin: float  # ranking-loss margin; also sets the uniform init range
    C: float  # weight of the embedding-scale penalty in the loss
36 |
37 |
class TransR(TransXBaseModel):
    """TransR: entities are mapped into a relation-specific space by a matrix M_r.

    A triple is scored as ``||h M_r + r - t M_r||_p``. In this implementation
    the entity and relation spaces share the dimension ``embed_dim``, so every
    M_r is ``embed_dim x embed_dim``.
    """

    def __init__(self,
                 ent_num: int,
                 rel_num: int,
                 device: torch.device,
                 hyper_params: TransRHyperParam
                 ):
        super().__init__()
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.device = device
        self.norm = hyper_params.norm
        self.embed_dim = hyper_params.embed_dim
        self.C = hyper_params.C

        self.margin = hyper_params.margin
        self.epsilon = 2.0
        # Every parameter is drawn from U(-range, range), range = (margin + epsilon) / embed_dim.
        self.embedding_range = (self.margin + self.epsilon) / self.embed_dim

        # Entity embeddings.
        self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim)
        nn.init.uniform_(
            tensor=self.ent_embedding.weight.data,
            a=-self.embedding_range,
            b=self.embedding_range
        )

        # Relation embeddings.
        self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim)
        nn.init.uniform_(
            tensor=self.rel_embedding.weight.data,
            a=-self.embedding_range,
            b=self.embedding_range
        )

        # One flattened embed_dim x embed_dim transfer matrix M_r per relation.
        self.transfer_matrix = nn.Embedding(self.rel_num, self.embed_dim * self.embed_dim)
        nn.init.uniform_(
            tensor=self.transfer_matrix.weight.data,
            a=-self.embedding_range,
            b=self.embedding_range
        )

        self.dist_fn = nn.PairwiseDistance(p=self.norm)  # L_p distance in relation space
        self.margin_loss_fn = nn.MarginRankingLoss(margin=self.margin)

    def _transfer(self, ent_embs: torch.Tensor, rels_tranfer: torch.Tensor):
        """Map entity embeddings into the relation-specific space (e M_r).

        :param ent_embs: [batch, embed_dim]
        :param rels_tranfer: flattened matrices, [batch, embed_dim * embed_dim]
        :return: [batch, embed_dim]
        """
        assert ent_embs.size(0) == rels_tranfer.size(0)
        assert ent_embs.size(1) == self.embed_dim
        assert rels_tranfer.size(1) == self.embed_dim * self.embed_dim

        rels_tranfer = rels_tranfer.view(-1, self.embed_dim, self.embed_dim)  # [batch, dim, dim]
        ent_embs = ent_embs.view(-1, 1, self.embed_dim)  # [batch, 1, dim]

        ent_proj = torch.matmul(ent_embs, rels_tranfer)  # [batch, 1, dim]
        return ent_proj.view(-1, self.embed_dim)  # [batch, dim]

    def embed(self, triples):
        """Look up a batch of triples and map the entities into relation space.

        :param triples: id tensor of size [batch, 3], columns (head, rel, tail)
        :return: (h_embs, r_embs, t_embs), each of size [batch, embed_dim];
            h_embs and t_embs are already multiplied by M_r
        """
        assert triples.shape[1] == 3
        heads = triples[:, 0]
        rels = triples[:, 1]
        tails = triples[:, 2]
        # ids -> embeddings
        h_embs = self.ent_embedding(heads)  # [batch, embed_dim]
        r_embs = self.rel_embedding(rels)
        t_embs = self.ent_embedding(tails)
        rels_transfer = self.transfer_matrix(rels)
        # entity space -> relation-specific space
        h_embs = self._transfer(h_embs, rels_transfer)
        t_embs = self._transfer(t_embs, rels_transfer)
        return h_embs, r_embs, t_embs

    def _distance(self, triples):
        """Compute the TransR distance for a batch of triples.

        :param triples: a batch of triples, size [batch, 3]
        :return: size [batch,]
        """
        h_embs, r_embs, t_embs = self.embed(triples)
        return self.dist_fn(h_embs + r_embs, t_embs)

    def _cal_margin_base_loss(self, pos_distances, neg_distances):
        """Margin ranking loss, ``mean(max(0, pos - neg + margin))``.

        :param pos_distances: [batch, ]
        :param neg_distances: [batch, ]
        :return: scalar loss tensor
        """
        # target = -1: positive distances should be smaller. Built from the inputs so
        # that it follows their device/dtype even if the module was moved after
        # construction (self.device is fixed at __init__).
        target = pos_distances.new_full((1,), -1.0)
        return self.margin_loss_fn(pos_distances, neg_distances, target)

    def _cal_scale_loss(self, embedding: nn.Embedding):
        """Scale constraint: sum over rows of max(||row||_2 - 1, 0).

        F.relu(x) is equal to max(x, 0).
        """
        norm = torch.norm(embedding.weight, p=2, dim=1)  # L2 norm of every row
        scale_loss = torch.sum(F.relu(norm - 1))
        return scale_loss

    def loss(self, pos_distances, neg_distances):
        """Total loss: margin loss + C * mean scale penalty over entities and relations.

        :param pos_distances: [batch, ]
        :param neg_distances: [batch, ]
        :return: scalar loss tensor
        """
        margin_based_loss = self._cal_margin_base_loss(pos_distances, neg_distances)
        ent_scale_loss = self._cal_scale_loss(self.ent_embedding)
        rel_scale_loss = self._cal_scale_loss(self.rel_embedding)
        return margin_based_loss + self.C * ((ent_scale_loss + rel_scale_loss) / (self.ent_num + self.rel_num))

    def forward(self, pos_triples: torch.Tensor, neg_triples: torch.Tensor):
        """Return the model loss for a batch of positive and negative triples.

        :param pos_triples: positive triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :param neg_triples: negative triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: tuple of (loss, positive distances, negative distances)
        """
        assert pos_triples.size()[1] == 3
        assert neg_triples.size()[1] == 3

        pos_distances = self._distance(pos_triples)
        neg_distances = self._distance(neg_triples)
        loss = self.loss(pos_distances, neg_distances)
        return loss, pos_distances, neg_distances

    def predict(self, triples: torch.Tensor):
        """Calculate the dissimilarity score for the given triples.

        :param triples: triples in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: dissimilarity score for each triple, size [B,]
        """
        return self._distance(triples)
184 |
185 |
class TransRMain(ModelMain):
    """TransR pipeline.

    Trains the model, reloads the saved checkpoint, evaluates it on the test
    split and writes the metrics to a file.
    """

    def __init__(
        self,
        dataset_conf: LocalDatasetConf,
        train_conf: TrainConf,
        hyper_params: TransRHyperParam,
        device: torch.device
    ) -> None:
        super().__init__()
        self.dataset_conf = dataset_conf
        self.train_conf = train_conf
        self.hyper_params = hyper_params
        self.device = device

    def __call__(self):
        # entity / relation -> id mappings
        entity2id, rel2id = create_mapping(self.dataset_conf)
        ent_num = len(entity2id)
        rel_num = len(rel2id)

        # datasets and dataloaders for the train / valid splits
        train_dataset, train_dataloader, valid_dataset, valid_dataloader = utils.create_local_dataloader(self.dataset_conf, self.hyper_params, entity2id, rel2id)

        # BernNegSampler.__init__ accepts only the dataset (passing a device raised
        # TypeError); neg_sample() takes the device from its input tensors.
        # NOTE(review): BernNegSampler reads dataset.meta.entity2id / rel2id; confirm
        # the dataset returned by create_local_dataloader exposes `.meta`.
        neg_sampler = BernNegSampler(train_dataset)

        model = TransR(ent_num, rel_num, self.device, self.hyper_params)
        model = model.to(self.device)

        optimizer = utils.create_optimizer(self.hyper_params.optimizer, model, self.hyper_params.learning_rate)

        # NOTE(review): KRLTrainer.__init__ names these parameters train_dataloader /
        # valid_dataloader; confirm TransETrainer accepts the *_dataloder keywords.
        trainer = TransETrainer(
            model=model,
            train_conf=self.train_conf,
            params=self.hyper_params,
            dataset_conf=self.dataset_conf,
            entity2id=entity2id,
            rel2id=rel2id,
            device=self.device,
            train_dataloder=train_dataloader,
            valid_dataloder=valid_dataloader,
            train_neg_sampler=neg_sampler,
            valid_neg_sampler=neg_sampler,
            optimzer=optimizer
        )

        trainer.run_training()

        metrics = [
            MetricEnum.MRR,
            MetricEnum.HITS_AT_1,
            MetricEnum.HITS_AT_3,
            MetricEnum.HITS_AT_10
        ]
        evaluator = RankEvaluator(self.device, metrics)

        # Evaluate the saved (best) checkpoint on the test split.
        ckpt = storage.load_checkpoint(self.train_conf)
        model.load_state_dict(ckpt.model_state_dict)
        model = model.to(self.device)
        test_dataset = LocalKRLDataset(self.dataset_conf, 'test', entity2id, rel2id)
        test_dataloader = DataLoader(test_dataset, self.hyper_params.valid_batch_size)
        metric = trainer.run_inference(test_dataloader, ent_num, evaluator)

        # Format the metric as a string and write it to a file.
        metric_formatter = StringFormatter()
        serializer = FileSerializer(self.train_conf, self.dataset_conf)
        serializer.serialize(metric, metric_formatter)
266 |
--------------------------------------------------------------------------------
/krl/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yubinCloud/kg2vec/c1e291031530933e5a381c2bdbf1f17e655e041e/krl/models/__init__.py
--------------------------------------------------------------------------------
/krl/negative_sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | The sampler used to obtain negative samples for KRL.
3 | """
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from abc import ABC, abstractmethod
7 |
8 | from .dataset import KRLDataset
9 |
10 |
class NegativeSampler(ABC):
    """Abstract base for negative samplers.

    Stores the dataset it was built from. Subclasses turn a batch of positive
    triples into corrupted ones in :meth:`neg_sample`.
    """

    def __init__(self, dataset: KRLDataset):
        self.dataset = dataset

    @abstractmethod
    def neg_sample(self, heads, rels, tails):
        """Perform negative sampling for one batch.

        :param heads: tensor of batch_size head ids, size [batch_size]
        :param rels: relation ids, size [batch_size]
        :param tails: tail ids, size [batch_size]
        """
24 |
25 |
class RandomNegativeSampler(NegativeSampler):
    """Corrupt each triple by replacing its head or its tail (chosen uniformly
    at random) with a uniformly random entity.

    The constructor is inherited unchanged from ``NegativeSampler``.
    """

    def neg_sample(self, heads, rels, tails):
        """Return corrupted triples of size [batch_size, 3] on the device of ``heads``."""
        dev = heads.device
        # NOTE(review): reads dataset.entity2id, while BernNegSampler reads
        # dataset.meta.entity2id; confirm which attribute KRLDataset exposes.
        n_entities = len(self.dataset.entity2id)
        # coin flip per triple: True -> replace head, False -> replace tail
        swap_head = torch.randint(high=2, size=heads.size(), device=dev) == 1
        candidates = torch.randint(high=n_entities, size=heads.size(), device=dev)
        new_heads = torch.where(swap_head, candidates, heads)
        new_tails = torch.where(~swap_head, candidates, tails)
        return torch.stack([new_heads, rels, new_tails], dim=1)
41 |
42 |
class BernNegSampler(NegativeSampler):
    """Negative sampler using the "bern" strategy from the TransH paper.

    For each relation r, tph is the average number of tails per head and hpt
    the average number of heads per tail over the dataset's triples. A triple
    with relation r gets its head replaced with probability tph / (tph + hpt);
    otherwise its tail is replaced.
    """

    def __init__(self,
                 dataset: KRLDataset):
        """Estimate the per-relation head-replacement probabilities.

        :param dataset: KRLDataset for negative sampling. It must expose
            ``meta.entity2id`` / ``meta.rel2id`` and yield (head, rel, tail) ids.
        """
        super().__init__(dataset)
        self.entity2id = dataset.meta.entity2id
        self.rel2id = dataset.meta.rel2id
        self.ent_num = len(self.entity2id)
        self.rel_num = len(self.rel2id)

        self.probs_of_replace_head = self._cal_tph_and_hpt()  # [rel_num], P(replace head | r)
        assert self.probs_of_replace_head.shape[0] == self.rel_num

    def _cal_tph_and_hpt(self, batch_size: int = 4096):
        """Return P(replace head | r) = tph / (tph + hpt) for every relation.

        Triples are counted in batches with a scatter-add instead of one
        ``.item()`` call per triple; the counts are identical.

        :param batch_size: number of triples read per DataLoader batch
        :return: float tensor of size [rel_num]
        """
        dataloader = DataLoader(self.dataset, batch_size=batch_size)
        # r_h_count[r, h]: number of triples (h, r, *), i.e. tails seen with head h under r
        r_h_count = torch.zeros([self.rel_num, self.ent_num])
        # r_t_count[r, t]: number of triples (*, r, t), i.e. heads seen with tail t under r
        r_t_count = torch.zeros([self.rel_num, self.ent_num])
        for batch in dataloader:
            if torch.is_tensor(batch):
                # items collated into a [B, 3] tensor
                heads, rels, tails = batch[:, 0], batch[:, 1], batch[:, 2]
            else:
                # (head, rel, tail) items collated into [heads, rels, tails]
                heads, rels, tails = batch[0], batch[1], batch[2]
            heads = heads.reshape(-1).long()
            rels = rels.reshape(-1).long()
            tails = tails.reshape(-1).long()
            ones = torch.ones(heads.shape[0], dtype=r_h_count.dtype)
            r_h_count.index_put_((rels, heads), ones, accumulate=True)
            r_t_count.index_put_((rels, tails), ones, accumulate=True)
        # average over the heads (resp. tails) that actually occur with r;
        # relations without triples give 0/0 = nan, which is replaced by 1
        tph = torch.sum(r_h_count, dim=1) / torch.sum(r_h_count != 0, dim=1)
        tph.nan_to_num_(1)
        hpt = torch.sum(r_t_count, dim=1) / torch.sum(r_t_count != 0, dim=1)
        hpt.nan_to_num_(1)
        probs_of_replace_head = tph / (tph + hpt)
        probs_of_replace_head.nan_to_num_(0.5)
        return probs_of_replace_head

    def neg_sample(self, heads, rels, tails):
        """Corrupt each triple's head or tail according to its relation's probability.

        :param heads: size [batch_size]
        :param rels: size [batch_size]
        :param tails: size [batch_size]
        :return: corrupted triples, size [batch_size, 3], on the device of ``heads``
        """
        device = heads.device
        batch_size = heads.shape[0]
        rands = torch.rand([batch_size])
        probs = self.probs_of_replace_head[rels.cpu()]  # probabilities live on the CPU
        replace_head = (rands < probs).to(device)  # True -> replace head, False -> replace tail
        random_entities = torch.randint(high=self.ent_num, size=heads.size(), device=device)
        corrupted_heads = torch.where(replace_head, random_entities, heads)
        corrupted_tails = torch.where(replace_head, tails, random_entities)
        return torch.stack([corrupted_heads, rels, corrupted_tails], dim=1)
93 |
--------------------------------------------------------------------------------
/krl/serializer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from .config import TrainConf, DatasetConf
4 | from .metric import KRLMetricBase, RankMetric
5 | from .metric_fomatter import MetricFormatter, StringFormatter
6 |
7 |
class BaseSerializer(ABC):
    """Abstract base for metric serializers.

    Every serializer derives from this class and implements :meth:`serialize`.
    The training and dataset configs are kept for subclasses to use.
    """

    def __init__(self, train_conf: TrainConf, dataset_conf: DatasetConf) -> None:
        super().__init__()
        self.train_conf, self.dataset_conf = train_conf, dataset_conf

    @abstractmethod
    def serialize(self, metric: KRLMetricBase, formatter: MetricFormatter) -> bool:
        """Persist ``metric`` using ``formatter``; return whether it succeeded."""
21 |
22 |
class FileSerializer(BaseSerializer):
    """Serialize a metric into the local file ``train_conf.metric_result_path``."""

    def serialize(self, metric: RankMetric, formatter: StringFormatter) -> bool:
        """
        Serialize the string form of the metric into a file, overwriting it.

        :param metric: the metric instance to serialize.
        :param formatter: converts the metric instance into a string.
        :return: True once the file has been written.
        """
        result = formatter.convert(metric, self.dataset_conf)
        # Explicit encoding: the platform default differs between systems
        # (e.g. a legacy code page on Windows).
        with open(self.train_conf.metric_result_path, 'w', encoding='utf-8') as f:
            f.write(result)
        return True
39 |
--------------------------------------------------------------------------------
/krl/storage.py:
--------------------------------------------------------------------------------
1 |
2 | from pydantic import BaseModel
3 | import torch
4 |
5 | from .config import TrainConf, HyperParam
6 | from .base_model import KRLModel
7 |
class CheckpointFormat(BaseModel):
    """Schema of the checkpoint dict written by ``save_checkpoint``."""
    model_state_dict: dict  # model.state_dict()
    optim_state_dict: dict  # optimizer.state_dict()
    epoch_id: int  # epoch id passed to save_checkpoint
    best_score: float  # best score passed to save_checkpoint
    hyper_params: dict  # hyper-parameters converted with .dict()
14 |
15 |
def save_checkpoint(model: KRLModel,
                    optimzer: torch.optim.Optimizer,
                    epoch_id: int,
                    best_score: float,
                    hyper_params: HyperParam,
                    train_conf: TrainConf):
    """Write model/optimizer state and training progress to ``train_conf.checkpoint_path``.

    The data is validated through ``CheckpointFormat`` and stored with
    ``torch.save`` as a plain dict.
    """
    fields = dict(
        model_state_dict=model.state_dict(),
        optim_state_dict=optimzer.state_dict(),
        epoch_id=epoch_id,
        best_score=best_score,
        hyper_params=hyper_params.dict(),
    )
    torch.save(CheckpointFormat(**fields).dict(), train_conf.checkpoint_path)
30 |
31 |
def load_checkpoint(train_conf: TrainConf, map_location=None) -> CheckpointFormat:
    """Load a checkpoint written by ``save_checkpoint``.

    :param train_conf: provides ``checkpoint_path``.
    :param map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` lets a
        checkpoint saved from GPU tensors be loaded on a CPU-only machine.
        ``None`` (the default) keeps the previous behaviour.
    :return: the checkpoint validated as ``CheckpointFormat``.
    """
    ckpt = torch.load(train_conf.checkpoint_path, map_location=map_location)
    return CheckpointFormat.parse_obj(ckpt)
35 |
--------------------------------------------------------------------------------
/krl/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | KRLTrainer for training and testing models.
3 | """
4 | from typing import Mapping
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from rich.progress import Progress as RichProgress
8 | from abc import ABC, abstractmethod
9 |
10 | from .base_model import KRLModel
11 | from .config import TrainConf, HyperParam, DatasetConf
12 | from .negative_sampler import NegativeSampler
13 | from . import storage
14 | from .evaluator import RankEvaluator
15 | from .metric import RankMetric, MetricEnum
16 |
17 |
18 |
class KRLTrainer(ABC):
    """
    Abstract base class for trainers of KRL models.

    It stores the model, configuration objects, id mappings, data loaders,
    negative samplers and optimizer passed to the constructor, provides a
    generic link-prediction evaluation loop (`run_inference`), and requires
    subclasses to implement the training loop (`run_training`).
    """
    def __init__(self,
                 model: KRLModel,
                 train_conf: TrainConf,
                 params: HyperParam,
                 dataset_conf: DatasetConf,
                 entity2id: Mapping[str, int],
                 rel2id: Mapping[str, int],
                 device: torch.device,
                 train_dataloader: DataLoader,
                 valid_dataloader: DataLoader,
                 train_neg_sampler: NegativeSampler,
                 valid_neg_sampler: NegativeSampler,
                 optimzer: torch.optim.Optimizer
                 ) -> None:
        self.model = model
        self.train_conf = train_conf
        self.params = params
        self.dataset_conf = dataset_conf
        self.entity2id = entity2id
        self.rel2id = rel2id
        self.device = device
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.train_neg_sampler = train_neg_sampler
        self.valid_neg_sampler = valid_neg_sampler
        # NOTE: the misspelled name `optimzer` is part of the interface
        # (constructor keyword and the attribute subclasses read); keep it.
        self.optimzer = optimzer

    def run_inference(self,
                      dataloder: DataLoader,
                      ent_num: int,
                      evaluator: RankEvaluator,
                      ) -> RankMetric:
        """
        Evaluate the KRL model on link prediction over `dataloder`.

        For every triple (h, r, t) in the loader the model is asked to predict
        all candidate tails (h, r, e) and all candidate heads (e, r, t) for
        e in [0, ent_num). The prediction rows and the true entity ids are
        passed to `evaluator`, which accumulates metrics across batches.
        Autograd is disabled for the whole pass.

        Override this method if a model needs a different evaluation procedure.

        :param dataloder: loader whose batches hold head, relation and tail id
            tensors at positions 0, 1 and 2.
        :param ent_num: number of entities; candidates are ids 0 .. ent_num - 1.
        :param evaluator: metric accumulator; it is NOT cleared here, so call
            `evaluator.clear()` beforehand if previous results must be discarded.
        :return: the metrics exported from `evaluator` after all batches.
        """
        model = self.model
        device = self.device

        # Evaluation never back-propagates. Disabling autograd here keeps the
        # [batch_size * ent_num] scoring passes from building a graph even when
        # the caller did not wrap this call in torch.no_grad().
        with torch.no_grad():
            # entity_ids = [[0, 1, ..., ent_num - 1]], shape: [1, ent_num]
            entity_ids = torch.arange(0, ent_num, device=device).unsqueeze(0)
            for batch in dataloder:
                heads, rels, tails = batch[0].to(device), batch[1].to(device), batch[2].to(device)
                batch_size = heads.size(0)
                all_entities = entity_ids.repeat(batch_size, 1)  # [batch_size, ent_num]
                # [batch_size] -> [batch_size, 1] -> [batch_size, ent_num]
                heads_expanded = heads.reshape(-1, 1).repeat(1, ent_num)
                rels_expanded = rels.reshape(-1, 1).repeat(1, ent_num)
                tails_expanded = tails.reshape(-1, 1).repeat(1, ent_num)
                # replace the tail with every entity -> triplets: [batch_size * ent_num, 3]
                triplets = torch.stack([heads_expanded, rels_expanded, all_entities], dim=2).reshape(-1, 3)
                tails_predictions = model.predict(triplets).reshape(batch_size, -1)  # [batch_size, ent_num]
                # replace the head with every entity
                triplets = torch.stack([all_entities, rels_expanded, tails_expanded], dim=2).reshape(-1, 3)
                heads_predictions = model.predict(triplets).reshape(batch_size, -1)  # [batch_size, ent_num]

                # tail-prediction rows first, then head-prediction rows: [batch_size * 2, ent_num]
                predictions = torch.cat([tails_predictions, heads_predictions], dim=0)
                # matching true ids: tails for the first half, heads for the second: [batch_size * 2, 1]
                ground_truth_entity_id = torch.cat([tails.reshape(-1, 1), heads.reshape(-1, 1)], dim=0)
                evaluator.evaluate(predictions, ground_truth_entity_id)

        return evaluator.export_metrics()

    @abstractmethod
    def run_training(self):
        """
        Run the training process on the KRL model.

        Subclasses implement the model-specific training loop.
        """
        pass
97 |
98 |
class TransETrainer(KRLTrainer):
    """
    Trainer for TransE and models trained the same way (e.g. TransH).

    Every step feeds a batch of positive triples plus the negatives produced by
    `train_neg_sampler` to the model, whose forward pass returns three values
    with the loss first. Every `valid_freq` epochs Hits@10 is measured on the
    validation loader, and a checkpoint is saved whenever it improves.
    """
    def __init__(
        self,
        model: KRLModel,
        train_conf: TrainConf,
        params: HyperParam,
        dataset_conf: DatasetConf,
        entity2id: Mapping[str, int],
        rel2id: Mapping[str, int],
        device: torch.device,
        train_dataloder: DataLoader,
        valid_dataloder: DataLoader,
        train_neg_sampler: NegativeSampler,
        valid_neg_sampler: NegativeSampler,
        optimzer: torch.optim.Optimizer
    ) -> None:
        super().__init__(
            model, train_conf, params, dataset_conf,
            entity2id, rel2id, device,
            train_dataloder, valid_dataloder,
            train_neg_sampler, valid_neg_sampler,
            optimzer,
        )

    def run_training(self):
        """
        Train the model for `params.epoch_size` epochs with periodic validation.

        Progress is rendered with a rich progress bar: one overall task plus a
        transient per-epoch task advanced by the number of positive triples.
        """
        dev = self.device
        opt = self.optimzer
        net = self.model
        params = self.params
        loader = self.train_dataloader
        n_samples = len(loader.dataset)

        best = 0.0
        evaluator = RankEvaluator(self.device, [MetricEnum.HITS_AT_10])

        with RichProgress() as progress:
            overall = progress.add_task('[green]Total training...', total=params.epoch_size)
            for epoch in range(1, params.epoch_size + 1):
                per_epoch = progress.add_task(f'[cyan]Epoch {epoch}', total=n_samples)
                net.train()
                for batch in loader:
                    heads, rels, tails = (batch[k].to(dev) for k in range(3))
                    positives = torch.stack([heads, rels, tails], dim=1)  # [batch_size, 3]
                    negatives = self.train_neg_sampler.neg_sample(heads, rels, tails)  # [batch_size, 3]
                    opt.zero_grad()
                    loss, _, _ = net(positives, negatives)
                    loss.backward()
                    opt.step()
                    progress.update(per_epoch, advance=positives.size(0))
                progress.remove_task(per_epoch)

                should_validate = epoch % params.valid_freq == 0
                if should_validate:
                    net.eval()
                    with torch.no_grad():
                        n_entities = len(self.entity2id)
                        evaluator.clear()  # drop the results of the previous validation
                        result = self.run_inference(self.valid_dataloader, n_entities, evaluator)
                        score = result.hits_at_10
                        if score > best:
                            best = score
                            print('best score of valid: ', best)
                            storage.save_checkpoint(net, opt, epoch, best, self.params, self.train_conf)

                progress.update(overall, advance=1)
167 |
168 |
class RescalTrainer(KRLTrainer):
    """
    Trainer for training RESCAL and other similar models.

    Positive triples and the negatives from `train_neg_sampler` are
    concatenated, labelled 1 / 0 respectively, shuffled with one shared
    permutation, and passed to the model together with the labels. The model's
    forward pass returns two values with the loss first. Every `valid_freq`
    epochs Hits@10 is measured on the validation loader, and a checkpoint is
    saved whenever it improves.
    """
    def __init__(
        self,
        model: KRLModel,
        train_conf: TrainConf,
        params: HyperParam,
        dataset_conf: DatasetConf,
        entity2id: Mapping[str, int],
        rel2id: Mapping[str, int],
        device: torch.device,
        train_dataloder: DataLoader,
        valid_dataloder: DataLoader,
        train_neg_sampler: NegativeSampler,
        valid_neg_sampler: NegativeSampler,
        optimzer: torch.optim.Optimizer
    ) -> None:
        super().__init__(model, train_conf, params, dataset_conf, entity2id, rel2id, device,
                         train_dataloder, valid_dataloder, train_neg_sampler, valid_neg_sampler, optimzer)

    def run_training(self):
        """
        Train the model for `params.epoch_size` epochs with periodic validation.
        """
        device = self.device
        optimzer = self.optimzer
        model = self.model
        DATASET_LEN = len(self.train_dataloader.dataset)
        # prepare tools for training
        best_score = 0.0
        evaluator = RankEvaluator(self.device, [MetricEnum.HITS_AT_10])
        # training loop
        with RichProgress() as rich_progress:
            train_task = rich_progress.add_task('[green]Total training...', total=self.params.epoch_size)
            for epoch_id in range(1, self.params.epoch_size + 1):
                epoch_task = rich_progress.add_task(f'[cyan]Epoch {epoch_id}', total=DATASET_LEN)
                model.train()
                for batch in self.train_dataloader:
                    pos_heads, pos_rels, pos_tails = batch[0].to(device), batch[1].to(device), batch[2].to(device)
                    pos_triples = torch.stack([pos_heads, pos_rels, pos_tails], dim=1)  # pos_triples: [batch_size, 3]
                    neg_triples = self.train_neg_sampler.neg_sample(pos_heads, pos_rels, pos_tails)  # neg_triples: [batch_size, 3]
                    triples = torch.cat([pos_triples, neg_triples])
                    pos_num = pos_triples.size(0)
                    total_num = triples.size(0)
                    # positives occupy the first pos_num rows and get label 1; negatives get 0
                    labels = torch.zeros([total_num], device=device)
                    labels[:pos_num] = 1
                    # shuffle triples and labels with the same permutation
                    shuffle_index = torch.randperm(total_num, device=device)
                    triples = triples[shuffle_index]
                    labels = labels[shuffle_index]
                    # calculate loss and update the model.
                    # The loss is deliberately not copied to the host here: doing so
                    # (`loss.item()`) forces a device sync every batch.
                    optimzer.zero_grad()
                    loss, _ = model(triples, labels)
                    loss.backward()
                    optimzer.step()
                    rich_progress.update(epoch_task, advance=pos_num)
                rich_progress.remove_task(epoch_task)

                if epoch_id % self.params.valid_freq == 0:
                    model.eval()
                    with torch.no_grad():
                        ent_num = len(self.entity2id)
                        evaluator.clear()  # drop the results of the previous validation
                        metric = self.run_inference(self.valid_dataloader, ent_num, evaluator)
                        hits_at_10 = metric.hits_at_10
                        if hits_at_10 > best_score:
                            best_score = hits_at_10
                            print('best score of valid: ', best_score)
                            storage.save_checkpoint(model, optimzer, epoch_id, best_score, self.params, self.train_conf)

                rich_progress.update(train_task, advance=1)
--------------------------------------------------------------------------------
/krl/typer_apps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yubinCloud/kg2vec/c1e291031530933e5a381c2bdbf1f17e655e041e/krl/typer_apps/__init__.py
--------------------------------------------------------------------------------
/krl/typer_apps/distmult.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import typer
3 |
4 | from ..config import LocalDatasetConf, TrainConf
5 | from ..models.DistMult import DistMultHyperParam, DistMultMain
6 | from .. import utils
7 |
8 |
9 |
app = typer.Typer()


@app.command(name='train')
def train_distmult(
    dataset_name: str = typer.Option(...),
    base_dir: Path = typer.Option(...),
    batch_size: int = typer.Option(...),
    valid_batch_size: int = typer.Option(...),
    valid_freq: int = typer.Option(50),
    lr: float = typer.Option(0.001),
    epoch_size: int = typer.Option(...),
    optimizer: str = typer.Option('adam'),
    embed_dim: int = typer.Option(...),
    alpha: float = typer.Option(0.001, help='regularization parameter'),
    regul_type: str = typer.Option('F2', help='regularization type, F2 or N3', case_sensitive=False),
    ckpt_path: Path = typer.Option(...),
    metric_result_path: Path = typer.Option(...)
):
    """Train a DistMult model on a local dataset stored under BASE_DIR."""
    if not base_dir.exists():
        # Report on stderr and exit non-zero so calling scripts notice the failure
        # (a bare typer.Exit() terminates with status 0, i.e. "success").
        typer.echo(f"base_dir doesn't exist: {base_dir}", err=True)
        raise typer.Exit(code=1)
    dataset_conf = LocalDatasetConf(
        dataset_name=dataset_name,
        base_dir=base_dir
    )
    train_conf = TrainConf(
        checkpoint_path=ckpt_path.absolute().as_posix(),
        metric_result_path=metric_result_path.absolute().as_posix()
    )
    hyper_params = DistMultHyperParam(
        batch_size=batch_size,
        valid_batch_size=valid_batch_size,
        learning_rate=lr,
        optimizer=optimizer,
        epoch_size=epoch_size,
        embed_dim=embed_dim,
        valid_freq=valid_freq,
        alpha=alpha,
        regul_type=regul_type
    )
    device = utils.get_device()

    main = DistMultMain(dataset_conf, train_conf, hyper_params, device)

    main()
56 |
--------------------------------------------------------------------------------
/krl/typer_apps/rescal.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import typer
3 |
4 | from ..config import TrainConf
5 | from ..models.RESCAL import RescalHyperParam, RescalMain
6 | from .. import utils
7 | from ..dataset import load_krl_dataset
8 |
9 |
app = typer.Typer()


@app.command(name='train')
def train_rescal(
    dataset_name: str = typer.Option('FB15k'),
    batch_size: int = typer.Option(256),
    valid_batch_size: int = typer.Option(16),
    valid_freq: int = typer.Option(5),
    lr: float = typer.Option(0.01),
    epoch_size: int = typer.Option(200),
    optimizer: str = typer.Option('adam'),
    embed_dim: int = typer.Option(50),
    alpha: float = typer.Option(0.001, help='regularization parameter'),
    regul_type: str = typer.Option('F2', help='regularization type, F2 or N3', case_sensitive=False),
    ckpt_path: Path = typer.Option(Path('/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/rescal_fb15k.ckpt')),
    metric_result_path: Path = typer.Option(Path('/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/rescal_fb15k_metrics.txt'))
):
    # Load the named dataset, build the training and hyper-parameter configs
    # from the CLI options, then run the RESCAL pipeline on the selected device.
    # (Documented with comments on purpose: Typer renders a docstring as --help.)
    # NOTE(review): the default ckpt/metric paths are absolute, machine-specific paths.
    dataset_dict = load_krl_dataset(dataset_name)

    def to_abs_posix(path: Path) -> str:
        return path.absolute().as_posix()

    train_conf = TrainConf(
        checkpoint_path=to_abs_posix(ckpt_path),
        metric_result_path=to_abs_posix(metric_result_path),
    )
    hparam_kwargs = dict(
        batch_size=batch_size,
        valid_batch_size=valid_batch_size,
        learning_rate=lr,
        optimizer=optimizer,
        epoch_size=epoch_size,
        embed_dim=embed_dim,
        valid_freq=valid_freq,
        alpha=alpha,
        regul_type=regul_type,
    )
    hyper_params = RescalHyperParam(**hparam_kwargs)

    RescalMain(dataset_dict, train_conf, hyper_params, utils.get_device())()
49 |
--------------------------------------------------------------------------------
/krl/typer_apps/transe.py:
--------------------------------------------------------------------------------
1 | import typer
2 | from pathlib import Path
3 |
4 | from ..config import TrainConf
5 | from ..models.TransE import TransEHyperParam, TransELitMain
6 | from ..dataset import load_krl_dataset
7 |
8 |
app = typer.Typer()


@app.command(name='train')
def train_transe(
    dataset_name: str = typer.Option('FB15k'),
    batch_size: int = typer.Option(256),
    valid_batch_size: int = typer.Option(16),
    valid_freq: int = typer.Option(5),
    lr: float = typer.Option(0.01),
    optimizer: str = typer.Option('adam'),
    epoch_size: int = typer.Option(500),
    embed_dim: int = typer.Option(100),
    norm: int = typer.Option(1),
    margin: float = typer.Option(2.0),
    logs_dir: Path = typer.Option('lightning_logs/'),
    early_stoping_patience: int = typer.Option(5)
):
    # Load the named dataset, assemble the training / hyper-parameter configs
    # from the CLI options, and run the Lightning-based TransE pipeline.
    # (Documented with comments on purpose: Typer renders a docstring as --help.)
    dataset_dict = load_krl_dataset(dataset_name)
    train_conf = TrainConf(logs_dir=logs_dir)
    hparam_kwargs = dict(
        batch_size=batch_size,
        valid_batch_size=valid_batch_size,
        learning_rate=lr,
        optimizer=optimizer,
        epoch_size=epoch_size,
        early_stoping_patience=early_stoping_patience,
        embed_dim=embed_dim,
        norm=norm,
        margin=margin,
        valid_freq=valid_freq,
    )
    hyper_params = TransEHyperParam(**hparam_kwargs)

    TransELitMain(dataset_dict, train_conf, hyper_params)()
--------------------------------------------------------------------------------
/krl/typer_apps/transh.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import typer
3 |
4 | from ..config import LocalDatasetConf, TrainConf
5 | from ..models.TransH import TransHHyperParam, TransHMain
6 | from .. import utils
7 |
8 |
app = typer.Typer()


@app.command(name='train')
def train_tranh(
    dataset_name: str = typer.Option(...),
    base_dir: Path = typer.Option(...),
    batch_size: int = typer.Option(128),
    valid_batch_size: int = typer.Option(64),
    valid_freq: int = typer.Option(5),
    lr: float = typer.Option(0.001),
    epoch_size: int = typer.Option(500),
    optimizer: str = typer.Option('adam'),
    embed_dim: int = typer.Option(50),
    norm: int = typer.Option(2),
    margin: float = typer.Option(1.0),
    C: float = typer.Option(1.0, help='a hyper-parameter weighting the importance of soft constraints.'),
    # raw string: '\e' is an invalid escape sequence in a normal literal
    eps: float = typer.Option(1e-3, help=r'the $\epsilon$ in loss function'),
    ckpt_path: Path = typer.Option(...),
    metric_result_path: Path = typer.Option(...)
):
    """Train a TransH model on a local dataset stored under BASE_DIR."""
    if not base_dir.exists():
        # Report on stderr and exit non-zero so calling scripts notice the failure
        # (a bare typer.Exit() terminates with status 0, i.e. "success").
        typer.echo(f"base_dir doesn't exist: {base_dir}", err=True)
        raise typer.Exit(code=1)
    # initialize all configurations
    dataset_conf = LocalDatasetConf(
        dataset_name=dataset_name,
        base_dir=base_dir
    )
    train_conf = TrainConf(
        checkpoint_path=ckpt_path.absolute().as_posix(),
        metric_result_path=metric_result_path.absolute().as_posix()
    )
    hyper_params = TransHHyperParam(
        batch_size=batch_size,
        valid_batch_size=valid_batch_size,
        learning_rate=lr,
        optimizer=optimizer,
        epoch_size=epoch_size,
        embed_dim=embed_dim,
        norm=norm,
        margin=margin,
        valid_freq=valid_freq,
        C=C,
        eps=eps
    )
    device = utils.get_device()

    main = TransHMain(dataset_conf, train_conf, hyper_params, device)

    main()
60 |
--------------------------------------------------------------------------------
/krl/typer_apps/transr.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import typer
3 |
4 | from ..config import LocalDatasetConf, TrainConf
5 | from ..models.TransR import TransRHyperParam, TransRMain
6 | from .. import utils
7 |
8 |
app = typer.Typer()


@app.command(name='train')
def train_transr(
    dataset_name: str = typer.Option(...),
    base_dir: Path = typer.Option(...),
    batch_size: int = typer.Option(128),
    valid_batch_size: int = typer.Option(64),
    valid_freq: int = typer.Option(5),
    lr: float = typer.Option(0.001),
    epoch_size: int = typer.Option(500),
    optimizer: str = typer.Option('adam'),
    embed_dim: int = typer.Option(50),
    norm: int = typer.Option(2),
    margin: float = typer.Option(1.0),
    C: float = typer.Option(1.0),
    ckpt_path: Path = typer.Option(...),
    metric_result_path: Path = typer.Option(...)
):
    """Train a TransR model on a local dataset stored under BASE_DIR."""
    if not base_dir.exists():
        # Report on stderr and exit non-zero so calling scripts notice the failure
        # (a bare typer.Exit() terminates with status 0, i.e. "success").
        typer.echo(f"base_dir doesn't exist: {base_dir}", err=True)
        raise typer.Exit(code=1)
    dataset_conf = LocalDatasetConf(
        dataset_name=dataset_name,
        base_dir=base_dir
    )
    train_conf = TrainConf(
        checkpoint_path=ckpt_path.absolute().as_posix(),
        metric_result_path=metric_result_path.absolute().as_posix()
    )
    hyper_params = TransRHyperParam(
        batch_size=batch_size,
        valid_batch_size=valid_batch_size,
        learning_rate=lr,
        optimizer=optimizer,
        epoch_size=epoch_size,
        embed_dim=embed_dim,
        norm=norm,
        margin=margin,
        C=C,
        valid_freq=valid_freq
    )
    device = utils.get_device()

    main = TransRMain(dataset_conf, train_conf, hyper_params, device)

    main()
--------------------------------------------------------------------------------
/krl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .device import get_device
2 | from .seed import set_seed
3 | from .data import create_local_dataloader
4 | from .optim import create_optimizer
5 | from .logs_dir import create_logs_dir
--------------------------------------------------------------------------------
/krl/utils/data.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from typing import Mapping, Tuple
3 |
4 | from ..config import HyperParam, LocalDatasetConf
5 | from ..dataset import LocalKRLDataset
6 |
7 |
8 |
def create_local_dataloader(
    dataset_conf: LocalDatasetConf,
    params: HyperParam,
    entity2id: Mapping[str, int],
    rel2id: Mapping[str, int],
    *,
    shuffle_train: bool = True,
) -> Tuple[LocalKRLDataset, DataLoader, LocalKRLDataset, DataLoader]:
    """Build the train and validation datasets and their data loaders.

    :param dataset_conf: location/names of the local dataset files.
    :param params: hyper-parameters; ``batch_size`` and ``valid_batch_size`` are used.
    :param entity2id: mapping from entity name to integer id.
    :param rel2id: mapping from relation name to integer id.
    :param shuffle_train: reshuffle the training set at every epoch (default).
        Without shuffling every epoch visits triples in file order, which
        biases mini-batch SGD. The validation loader is never shuffled.
    :return: ``(train_dataset, train_dataloader, valid_dataset, valid_dataloader)``.
    """
    train_dataset = LocalKRLDataset(dataset_conf, 'train', entity2id, rel2id)
    train_dataloader = DataLoader(train_dataset, params.batch_size, shuffle=shuffle_train)
    valid_dataset = LocalKRLDataset(dataset_conf, 'valid', entity2id, rel2id)
    valid_dataloader = DataLoader(valid_dataset, params.valid_batch_size)
    return train_dataset, train_dataloader, valid_dataset, valid_dataloader
20 |
--------------------------------------------------------------------------------
/krl/utils/device.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
def get_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
5 |
--------------------------------------------------------------------------------
/krl/utils/logs_dir.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | from ..config import TrainConf, DatasetConf
4 | from ..base_model import KRLModel
5 |
6 |
def create_logs_dir(
    model: KRLModel,
    train_conf: TrainConf,
    dataset_conf: DatasetConf
) -> str:
    """Create (if missing) and return a timestamped logging directory.

    Layout: ``<train_conf.logs_dir>/<model class name>/<dataset name>/<yymmdd-HHMMSS>``.
    Parent directories are created as needed. An already existing directory
    (e.g. two runs started within the same second) is reused, not an error.

    :return: the directory path as a string.
    """
    from pathlib import Path

    # Path() accepts both str and Path, so a logs_dir given as a plain string works too.
    # (Named log_dir rather than `dir` to avoid shadowing the builtin.)
    log_dir = (
        Path(train_conf.logs_dir)
        / model.__class__.__name__
        / dataset_conf.dataset_name
        / datetime.now().strftime(r'%y%m%d-%H%M%S')
    )
    log_dir.mkdir(parents=True, exist_ok=True)
    return str(log_dir)
15 |
--------------------------------------------------------------------------------
/krl/utils/optim.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | import torch
3 | from torch.optim import Optimizer
4 | from typing import Union
5 |
6 | from ..base_model import KRLModel
7 |
8 |
class Lion(Optimizer):
    r"""Implements the Lion optimizer (EvoLved Sign Momentum).

    For each parameter ``p`` with gradient ``g`` and momentum buffer ``m``
    (initialised to zeros), one step performs::

        p <- p * (1 - lr * weight_decay)                  # decoupled weight decay
        p <- p - lr * sign(beta1 * m + (1 - beta1) * g)   # sign update
        m <- beta2 * m + (1 - beta2) * g                  # momentum update
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        """Initialize the hyperparameters.

        Args:
            params (iterable): iterable of parameters to optimize or dicts defining
                parameter groups
            lr (float, optional): learning rate (default: 1e-4)
            betas (Tuple[float, float], optional): coefficients used for the
                interpolated update direction (beta1) and for the momentum running
                average (beta2) (default: (0.9, 0.99))
            weight_decay (float, optional): decoupled weight decay coefficient (default: 0)

        Raises:
            ValueError: if ``lr`` or ``weight_decay`` is negative, or a beta is outside [0, 1).
        """
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
        # A negative coefficient would make the decay step *grow* the weights.
        if not 0.0 <= weight_decay:
            raise ValueError('Invalid weight_decay value: {}'.format(weight_decay))
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            the loss returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # the closure needs autograd to compute fresh gradients
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                # Decoupled (stepweight) decay, applied before the sign update.
                p.mul_(1 - lr * weight_decay)

                grad = p.grad
                state = self.state[p]
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                exp_avg = state['exp_avg']

                # Weight update: only the sign of the interpolated direction is used.
                update = exp_avg * beta1 + grad * (1 - beta1)
                p.add_(torch.sign(update), alpha=-lr)
                # Momentum is updated with beta2, after the step.
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss
71 |
72 |
# Lower-case optimizer names accepted by `create_optimizer`, mapped to their classes.
optimizer_map = dict(
    adam=optim.Adam,
    adamw=optim.AdamW,
    radam=optim.RAdam,
    sgd=optim.SGD,
    adagrad=optim.Adagrad,
    rms=optim.RMSprop,
    lion=Lion,
)


def create_optimizer(optimizer_name: Union[str, Optimizer], model: KRLModel, lr: float) -> Optimizer:
    """Create an optimizer for ``model`` from a short name such as ``'adam'``.

    The name is looked up in ``optimizer_map`` case-insensitively, and the
    optimizer is built over ``model.parameters()`` with learning rate ``lr``.
    An ``Optimizer`` instance passed as ``optimizer_name`` is returned
    unchanged (``model`` and ``lr`` are then ignored).

    :raises NotImplementedError: if the name is not a key of ``optimizer_map``.
    """
    if isinstance(optimizer_name, Optimizer):
        return optimizer_name

    optimizer_cls = optimizer_map.get(optimizer_name.lower())
    if optimizer_cls is not None:
        return optimizer_cls(model.parameters(), lr)
    raise NotImplementedError(f'No support for {optimizer_name} optimizer')
95 |
--------------------------------------------------------------------------------
/krl/utils/seed.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 |
5 |
def set_seed(seed: int):
    """Seed Python's `random`, PyTorch (CPU and, when available, all CUDA devices) and NumPy."""
    seeders = [random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders += [torch.cuda.manual_seed, torch.cuda.manual_seed_all]
    seeders.append(np.random.seed)
    for seeder in seeders:
        seeder(seed)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
2 | typer[all]
3 | pydantic
4 | loguru
5 |
--------------------------------------------------------------------------------
/test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stderr",
10 | "output_type": "stream",
11 | "text": [
12 | "/root/miniconda3/envs/ais/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13 | " from .autonotebook import tqdm as notebook_tqdm\n"
14 | ]
15 | }
16 | ],
17 | "source": [
18 | "from krl.dataset.instance import load_krl_dataset"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 2,
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "name": "stderr",
28 | "output_type": "stream",
29 | "text": [
30 | "Using custom data configuration VLyb--WN18-18d07f2b5874d7b1\n",
31 | "Found cached dataset csv (/root/.cache/huggingface/datasets/VLyb___csv/VLyb--WN18-18d07f2b5874d7b1/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)\n",
32 | "100%|██████████| 3/3 [00:00<00:00, 466.55it/s]\n"
33 | ]
34 | }
35 | ],
36 | "source": [
37 | "wn18 = load_krl_dataset('WN18')"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 6,
43 | "metadata": {},
44 | "outputs": [
45 | {
46 | "data": {
47 | "text/plain": [
48 | "(10, 4, 11)"
49 | ]
50 | },
51 | "execution_count": 6,
52 | "metadata": {},
53 | "output_type": "execute_result"
54 | }
55 | ],
56 | "source": [
57 | "wn18.train[5]"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": []
66 | }
67 | ],
68 | "metadata": {
69 | "kernelspec": {
70 | "display_name": "ais",
71 | "language": "python",
72 | "name": "python3"
73 | },
74 | "language_info": {
75 | "codemirror_mode": {
76 | "name": "ipython",
77 | "version": 3
78 | },
79 | "file_extension": ".py",
80 | "mimetype": "text/x-python",
81 | "name": "python",
82 | "nbconvert_exporter": "python",
83 | "pygments_lexer": "ipython3",
84 | "version": "3.9.16"
85 | },
86 | "orig_nbformat": 4,
87 | "vscode": {
88 | "interpreter": {
89 | "hash": "e79d845ce09a116dcacebcdcf6a8ee439742a7b8247d9dc7720dd3aefde3f822"
90 | }
91 | }
92 | },
93 | "nbformat": 4,
94 | "nbformat_minor": 2
95 | }
96 |
--------------------------------------------------------------------------------
/transe.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# TransE\n",
8 | "\n",
9 | "Here we will show how to reproduce the TransE model."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from pydantic import BaseSettings, BaseModel, Field\n",
19 | "from typing import Optional, Literal, Tuple, Dict, List\n",
20 | "from torch.utils.data import Dataset, DataLoader\n",
21 | "import torch.nn as nn\n",
22 | "import torch\n",
23 | "import numpy as np\n",
24 | "from abc import ABC, abstractmethod"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {},
30 | "source": [
31 | "## Config Class\n",
32 | "\n",
"Define our configuration classes that can instantiate configuration items that meet your requirements based on your runtime environment or your own ideas.\n",
34 | "\n",
35 | "With the help of configuration classes, you can change the dataset or adjust the hyperparameters of the model without changing the model logic."
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 3,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "class DatasetConf(BaseSettings):\n",
45 | " \"\"\"\n",
46 | " 数据集的相关配置信息\n",
47 | " \"\"\"\n",
48 | " dataset_name: str = Field(title='数据集的名称,方便打印时查看')\n",
49 | " base_dir: str = Field(title='数据集的目录')\n",
50 | " entity2id_path: str = Field(default='entity2id.txt', title='entity2id 的文件名')\n",
51 | " relation2id_path: str = Field(default='relation2id.txt', title='relation2id 的文件名')\n",
52 | " train_path: str = Field(default='train.txt', title='training set 的文件')\n",
53 | " valid_path: str = Field(default='valid.txt', title='valid set 的文件')\n",
54 | " test_path: str = Field(default='test.txt', title='testing set 的目录')\n",
55 | "\n",
56 | "\n",
57 | "class HyperParam(BaseModel):\n",
58 | " \"\"\"\n",
59 | " 超参数\n",
60 | " \"\"\"\n",
61 | " batch_size: int = 128\n",
62 | " valid_batch_size: int = 64\n",
63 | " learning_rate: float = 0.001\n",
64 | " epoch_size: int = 500\n",
65 | " embed_dim: int = 50\n",
66 | " norm: int = 1\n",
67 | " margin: int = 2.0\n",
68 | " valid_freq: int = Field(title='训练过程中,每隔多少次就做一次 valid 来验证是否保存模型')\n",
69 | "\n",
70 | "\n",
71 | "class TrainConf(BaseModel):\n",
72 | " \"\"\"\n",
73 | " 训练的一些配置\n",
74 | " \"\"\"\n",
75 | " checkpoint_path: str = Field(title='保存模型的路径')\n",
76 | " metric_result_path: str = Field(title='运行 test 的 metric 输出位置')"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {},
82 | "source": [
83 | "## Dataset\n",
84 | "\n",
85 | "Defines the classes used to read datasets, including reading entity-to-ID mappings, relationship-to-ID mappings, and triplet collections.\n",
86 | "\n",
87 | "+ The `create_mapping` function is used to generate the entity-to-ID mapping dictionary and the relationship-to-ID mapping dictionary.\n",
88 | "+ `KRLDataset` is a further wrapper around the `Dataset` class in PyTorch, and is similar in usage."
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 4,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "EntityMapping = Dict[str, int]\n",
98 | "RelMapping = Dict[str, int]\n",
99 | "Triple = List[int]\n",
100 | "\n",
101 | "def create_mapping(dataset_conf: DatasetConf) -> Tuple[EntityMapping, RelMapping]:\n",
102 | " \"\"\"\n",
103 | " create mapping of `entity2id` and `relation2id`\n",
104 | " \"\"\"\n",
105 | " # 读取 entity2id\n",
106 | " entity2id = dict()\n",
107 | " with open(dataset_conf.base_dir + dataset_conf.entity2id_path) as f:\n",
108 | " for line in f:\n",
109 | " entity, entity_id = line.split()\n",
110 | " entity = entity.strip()\n",
111 | " entity_id = int(entity_id.strip())\n",
112 | " entity2id[entity] = entity_id\n",
113 | " # 读取 relation2id\n",
114 | " rel2id = dict()\n",
115 | " with open(dataset_conf.base_dir + dataset_conf.relation2id_path) as f:\n",
116 | " for line in f:\n",
117 | " rel, rel_id = line.split()\n",
118 | " rel = rel.strip()\n",
119 | " rel_id = int(rel_id.strip())\n",
120 | " rel2id[rel] = rel_id\n",
121 | " return entity2id, rel2id\n",
122 | "\n",
123 | "\n",
124 | "class KRLDataset(Dataset):\n",
125 | " def __init__(self,\n",
126 | " dataset_conf: DatasetConf,\n",
127 | " mode: Literal['train', 'valid', 'test'],\n",
128 | " entity2id: Dict[str, int],\n",
129 | " rel2id: Dict[str, int]) -> None:\n",
130 | " super().__init__()\n",
131 | " self.conf = dataset_conf\n",
132 | " self.mode = mode\n",
133 | " self.triples = []\n",
134 | " self.entity2id = entity2id\n",
135 | " self.rel2id = rel2id\n",
136 | " self._read_triples() # 读取数据集,并获得所有的 triples\n",
137 | " \n",
138 | " def _split_and_to_id(self, line: str) -> Triple:\n",
139 | " \"\"\"将数据集文件中的一行数据进行切分,并将 entity 和 rel 转换成 id\n",
140 | "\n",
141 | " :param line: 数据集的一行数据\n",
142 | " :return: [head_id, rel_id, tail_id]\n",
143 | " \"\"\"\n",
144 | " head, tail, rel = line.split()\n",
145 | " head_id = self.entity2id[head.strip()]\n",
146 | " rel_id = self.rel2id[rel.strip()]\n",
147 | " tail_id = self.entity2id[tail.strip()]\n",
148 | " return (head_id, rel_id, tail_id)\n",
149 | " \n",
150 | " def _read_triples(self):\n",
151 | " data_path = {\n",
152 | " 'train': self.conf.train_path,\n",
153 | " 'valid': self.conf.valid_path,\n",
154 | " 'test': self.conf.test_path\n",
155 | " }.get(self.mode)\n",
156 | " with open(self.conf.base_dir + data_path) as f:\n",
157 | " self.triples = [self._split_and_to_id(line) for line in f]\n",
158 | " \n",
159 | " def __len__(self):\n",
160 | " \"\"\"Denotes the total number of samples.\"\"\"\n",
161 | " return len(self.triples)\n",
162 | " \n",
163 | " def __getitem__(self, index) -> Triple:\n",
164 | " \"\"\"Returns (head id, relation id, tail id).\"\"\"\n",
165 | " triple = self.triples[index]\n",
166 | " return triple[0], triple[1], triple[2]"
167 | ]
168 | },
169 | {
170 | "cell_type": "markdown",
171 | "metadata": {},
172 | "source": [
173 | "## Negative Sampler\n",
174 | "\n",
175 | "In order to train the model, we need not only positive samples, but also negative samples. The goal of the negative sampler is to generate negative samples based on the positive samples in the dataset.\n",
176 | "\n",
177 | "Since there are multiple negative sampling strategies, we define a common abstract base class `NegativeSampler`; every negative sampler that implements a particular strategy should inherit from it and implement `neg_sample`."
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 5,
183 | "metadata": {},
184 | "outputs": [],
185 | "source": [
186 | "class NegativeSampler(ABC):\n",
187 | " def __init__(self, dataset: KRLDataset, device: torch.device):\n",
188 | " self.dataset = dataset\n",
189 | " self.device = device\n",
190 | " \n",
191 | " @abstractmethod\n",
192 | " def neg_sample(self, heads, rels, tails):\n",
193 | " \"\"\"执行负采样\n",
194 | "\n",
195 | " :param heads: 由 batch_size 个 head idx 组成的 tensor,size: [batch_size]\n",
196 | " :param rels: size [batch_size]\n",
197 | " :param tails: size [batch_size]\n",
198 | " \"\"\"\n",
199 | " pass\n"
200 | ]
201 | },
202 | {
203 | "cell_type": "markdown",
204 | "metadata": {},
205 | "source": [
206 | "The simplest negative sampling strategy corrupts each positive triplet by replacing either its head or its tail entity (chosen at random with equal probability) with an entity drawn uniformly at random from all entities."
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": null,
212 | "metadata": {},
213 | "outputs": [],
214 | "source": [
215 | "class RandomNegativeSampler(NegativeSampler):\n",
216 | " \"\"\"\n",
217 | " 随机替换 head 或者 tail 来实现采样\n",
218 | " \"\"\"\n",
219 | " def __init__(self, dataset: KRLDataset, device: torch.device):\n",
220 | " super().__init__(dataset, device)\n",
221 | " \n",
222 | " def neg_sample(self, heads, rels, tails):\n",
223 | " ent_num = len(self.dataset.entity2id)\n",
224 | " head_or_tail = torch.randint(high=2, size=heads.size(), device=self.device)\n",
225 | " random_entities = torch.randint(high=ent_num, size=heads.size(), device=self.device)\n",
226 | " corupted_heads = torch.where(head_or_tail == 1, random_entities, heads)\n",
227 | " corupted_tails = torch.where(head_or_tail == 0, random_entities, tails)\n",
228 | " return torch.stack([corupted_heads, rels, corupted_tails], dim=1)"
229 | ]
230 | },
231 | {
232 | "cell_type": "markdown",
233 | "metadata": {},
234 | "source": [
235 | "## Model\n",
236 | "\n",
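237 | "Define the TransE model, which scores a triplet (h, r, t) by the p-norm distance ||h + r - t|| between embeddings (a smaller distance means a more plausible triplet)."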
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": 6,
243 | "metadata": {},
244 | "outputs": [],
245 | "source": [
246 | "class TransE(nn.Module):\n",
247 | " def __init__(\n",
248 | " self,\n",
249 | " ent_num: int,\n",
250 | " rel_num: int,\n",
251 | " device: torch.device,\n",
252 | " norm: int,\n",
253 | " embed_dim: int,\n",
254 | " margin: float\n",
255 | " ):\n",
256 | " super().__init__()\n",
257 | " self.ent_num = ent_num\n",
258 | " self.rel_num = rel_num\n",
259 | " self.device = device\n",
260 | " self.norm = norm\n",
261 | " self.embed_dim = embed_dim\n",
262 | " self.margin = margin\n",
263 | "\n",
264 | " # Initialize ent_embedding\n",
265 | " self.ent_embedding = nn.Embedding(self.ent_num, self.embed_dim)\n",
266 | " torch.nn.init.xavier_uniform_(self.ent_embedding.weight.data)\n",
267 | " #uniform_range = 6 / np.sqrt(self.embed_dim)\n",
268 | " #self.ent_embedding.weight.data.uniform_(-uniform_range, uniform_range)\n",
269 | " \n",
270 | " # Initialize rel_embedding\n",
271 | " self.rel_embedding = nn.Embedding(self.rel_num, self.embed_dim)\n",
272 | " torch.nn.init.xavier_uniform_(self.rel_embedding.weight.data)\n",
273 | " #uniform_range = 6 / np.sqrt(self.embed_dim)\n",
274 | " #self.rel_embedding.weight.data.uniform_(-uniform_range, uniform_range)\n",
275 | "\n",
276 | " self.criterion = nn.MarginRankingLoss(margin=self.margin)\n",
277 | " \n",
278 | " def _distance(self, triples):\n",
279 | " \"\"\"Calculate the distance of a batch's triplet\n",
280 | "\n",
281 | " :param triples: triples of a batch,size: [batch, 3]\n",
282 | " :return: size: [batch,]\n",
283 | " \"\"\"\n",
284 | " heads = triples[:, 0]\n",
285 | " rels = triples[:, 1]\n",
286 | " tails = triples[:, 2]\n",
287 | " h_embs = self.ent_embedding(heads) # h_embs: [batch, embed_dim]\n",
288 | " r_embs = self.rel_embedding(rels)\n",
289 | " t_embs = self.ent_embedding(tails)\n",
290 | " dist = h_embs + r_embs - t_embs # [batch, embed_dim]\n",
291 | " return torch.norm(dist, p=self.norm, dim=1)\n",
292 | " \n",
293 | " def loss(self, pos_distances, neg_distances):\n",
294 | " \"\"\"Calculate the loss of TransE training\n",
295 | "\n",
296 | " :param pos_distances: [batch, ]\n",
297 | " :param neg_distances: [batch, ]\n",
298 | " :return: loss\n",
299 | " \"\"\"\n",
300 | " ones = torch.tensor([-1], dtype=torch.long, device=self.device)\n",
301 | " return self.criterion(pos_distances, neg_distances, ones)\n",
302 | " \n",
303 | " def forward(self, pos_triples: torch.Tensor, neg_triples: torch.Tensor):\n",
304 | " \"\"\"Return model losses based on the input.\n",
305 | "\n",
306 | " :param pos_triples: triplets of positives in Bx3 shape (B - batch, 3 - head, relation and tail)\n",
307 | " :param neg_triples: triplets of negatives in Bx3 shape (B - batch, 3 - head, relation and tail)\n",
308 | " :return: tuple of the model loss, positive triplets loss component, negative triples loss component\n",
309 | " \"\"\"\n",
310 | " assert pos_triples.size()[1] == 3\n",
311 | " assert neg_triples.size()[1] == 3\n",
312 | " \n",
313 | " pos_distances = self._distance(pos_triples)\n",
314 | " neg_distances = self._distance(neg_triples)\n",
315 | " loss = self.loss(pos_distances, neg_distances)\n",
316 | " return loss, pos_distances, neg_distances\n",
317 | " \n",
318 | " def predict(self, triples: torch.Tensor):\n",
319 | " \"\"\"Calculate dissimilarity scores for the given triplets.\n",
320 | "\n",
321 | " :param triples: triplets in Bx3 shape (B - batch, 3 - head, relation and tail)\n",
322 | " :return: dissimilarity score for given triplets\n",
323 | " \"\"\"\n",
324 | " return self._distance(triples)"
325 | ]
326 | },
327 | {
328 | "cell_type": "markdown",
329 | "metadata": {},
330 | "source": [
331 | "## Metric\n",
332 | "\n",
333 | "Calculate the metrics used to evaluate link prediction, i.e. Hits@k (with k = 1, 3 and 10 below) and MRR (mean reciprocal rank)."
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": 7,
339 | "metadata": {},
340 | "outputs": [],
341 | "source": [
342 | "# metric\n",
343 | "\n",
344 | "def cal_hits_at_k(\n",
345 | " predictions: torch.Tensor,\n",
346 | " ground_truth_idx: torch.Tensor,\n",
347 | " device: torch.device,\n",
348 | " k: int\n",
349 | ") -> float:\n",
350 | " \"\"\"Calculates number of hits@k.\n",
351 | "\n",
352 | " :param predictions: BxN tensor of prediction values where B is batch size and N number of classes. Predictions\n",
353 | " must be sorted in class ids order\n",
354 | " :param ground_truth_idx: Bx1 tensor with index of ground truth class\n",
355 | " :param k: number of top K results to be considered as hits\n",
356 | " :return: Hits@K count for the batch (number of hits, not yet normalized)\n",
357 | " \"\"\"\n",
358 | " assert predictions.size()[0] == ground_truth_idx.size()[0] # has the same batch_size\n",
359 | " \n",
360 | " zero_tensor = torch.tensor([0], device=device)\n",
361 | " one_tensor = torch.tensor([1], device=device)\n",
362 | " _, indices = predictions.topk(k, largest=False) # indices: [batch_size, k]\n",
363 | " where_flags = indices == ground_truth_idx # where_flags: [batch_size, k], type: bool\n",
364 | " hits = torch.where(where_flags, one_tensor, zero_tensor).sum().item()\n",
365 | " return hits\n",
366 | "\n",
367 | "def cal_mrr(predictions: torch.Tensor, ground_truth_idx: torch.Tensor) -> float:\n",
368 | " \"\"\"Calculates mean reciprocal rank (MRR) for given predictions and ground truth values.\n",
369 | "\n",
370 | " :param predictions: BxN tensor of prediction values where B is batch size and N number of classes. Predictions\n",
371 | " must be sorted in class ids order\n",
372 | " :param ground_truth_idx: Bx1 tensor with index of ground truth class\n",
373 | " :return: Mean reciprocal rank score\n",
374 | " \"\"\"\n",
375 | " assert predictions.size(0) == ground_truth_idx.size(0)\n",
376 | "\n",
377 | " indices = predictions.argsort()\n",
378 | " return (1.0 / (indices == ground_truth_idx).nonzero()[:, 1].float().add(1.0)).sum().item()\n"
379 | ]
380 | },
381 | {
382 | "cell_type": "markdown",
383 | "metadata": {},
384 | "source": [
385 | "## Inference Operation\n",
386 | "\n",
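387 | "Run the inference process for the model, i.e., iterate through the validation or test set and compute the metrics. For every triplet, the tail (and, separately, the head) is replaced by every entity in turn, and the rank of the true entity among these candidates is used to compute Hits@k and MRR. This is the raw evaluation setting: other known true triplets are not filtered out of the candidates."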
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": 8,
393 | "metadata": {},
394 | "outputs": [],
395 | "source": [
396 | "def run_testing(\n",
397 | " model: TransE,\n",
398 | " dataloader: DataLoader,\n",
399 | " ent_num: int,\n",
400 | " device: torch.device,\n",
401 | ") -> Tuple[float, float, float, float]:\n",
402 | " \"\"\"Run test programs against Trans models\n",
403 | "\n",
404 | " :param model: TransE model\n",
405 | " :param ent_num: Number of entities in the dataset\n",
406 | " :return: (Hits@1, Hits@3, Hits@10, MRR), each expressed as a percentage\n",
407 | " \"\"\"\n",
408 | " hits_at_1 = 0.0\n",
409 | " hits_at_3 = 0.0\n",
410 | " hits_at_10 = 0.0\n",
411 | " mrr = 0.0\n",
412 | " examples_count = 0\n",
413 | " \n",
414 | " # entity_ids = [[0, 1, 2, ..., ent_num - 1]], shape: [1, ent_num]\n",
415 | " entitiy_ids = torch.arange(0, ent_num, device=device).unsqueeze(0)\n",
416 | " for i, batch in enumerate(dataloader):\n",
417 | " # batch: [3, batch_size]\n",
418 | " heads, rels, tails = batch[0].to(device), batch[1].to(device), batch[2].to(device)\n",
419 | " batch_size = heads.size()[0]\n",
420 | " all_entities = entitiy_ids.repeat(batch_size, 1) # all_entities: [batch_size, ent_num]\n",
421 | " # heads: [batch_size,] -> [batch_size, 1] -> [batch_size, ent_num]\n",
422 | " heads_expanded = heads.reshape(-1, 1).repeat(1, ent_num) # _expanded: [batch_size, ent_num]\n",
423 | " rels_expanded = rels.reshape(-1, 1).repeat(1, ent_num)\n",
424 | " tails_expanded = tails.reshape(-1, 1).repeat(1, ent_num)\n",
425 | " # check all possible tails\n",
426 | " triplets = torch.stack([heads_expanded, rels_expanded, all_entities], dim=2).reshape(-1, 3) # triplets: [batch_size * ent_num, 3]\n",
427 | " tails_predictions = model.predict(triplets).reshape(batch_size, -1) # tails_prediction: [batch_size, ent_num]\n",
428 | " # check all possible heads\n",
429 | " triplets = torch.stack([all_entities, rels_expanded, tails_expanded], dim=2).reshape(-1, 3)\n",
430 | " heads_predictions = model.predict(triplets).reshape(batch_size, -1) # heads_prediction: [batch_size, ent_num]\n",
431 | " \n",
432 | " # Concatenate tail and head predictions\n",
433 | " predictions = torch.cat([tails_predictions, heads_predictions], dim=0) # predictions: [batch_size * 2, ent_num]\n",
434 | " ground_truth_entity_id = torch.cat([tails.reshape(-1, 1), heads.reshape(-1, 1)], dim=0) # [batch_size * 2, 1]\n",
435 | " # calculate metrics\n",
436 | " hits_at_1 += cal_hits_at_k(predictions, ground_truth_entity_id, device=device, k=1)\n",
437 | " hits_at_3 += cal_hits_at_k(predictions, ground_truth_entity_id, device=device, k=3)\n",
438 | " hits_at_10 += cal_hits_at_k(predictions, ground_truth_entity_id, device=device, k=10)\n",
439 | " mrr += cal_mrr(predictions, ground_truth_entity_id)\n",
440 | " \n",
441 | " examples_count += predictions.size()[0]\n",
442 | " \n",
443 | " hits_at_1_score = hits_at_1 / examples_count * 100\n",
444 | " hits_at_3_score = hits_at_3 / examples_count * 100\n",
445 | " hits_at_10_score = hits_at_10 / examples_count * 100\n",
446 | " mrr_score = mrr / examples_count * 100\n",
447 | " \n",
448 | " return hits_at_1_score, hits_at_3_score, hits_at_10_score, mrr_score"
449 | ]
450 | },
451 | {
452 | "cell_type": "markdown",
453 | "metadata": {},
454 | "source": [
455 | "## Checkpoint\n",
456 | "\n",
457 | "During training, whenever the model beats its best validation score so far, its current state is saved to disk as a checkpoint.\n",
458 | "\n",
459 | "Saving and loading checkpoints are wrapped in two small helper functions here."
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": 9,
465 | "metadata": {},
466 | "outputs": [],
467 | "source": [
468 | "class CheckpointFormat(BaseModel):\n",
469 | " model_state_dict: dict\n",
470 | " optim_state_dict: dict\n",
471 | " epoch_id: int\n",
472 | " best_score: float\n",
473 | "\n",
474 | "\n",
475 | "def save_checkpoint(model: TransE,\n",
476 | " optimzer: torch.optim.Optimizer,\n",
477 | " epoch_id: int,\n",
478 | " best_score: float,\n",
479 | " train_conf: TrainConf):\n",
480 | " ckpt = CheckpointFormat(\n",
481 | " model_state_dict=model.state_dict(),\n",
482 | " optim_state_dict=optimzer.state_dict(),\n",
483 | " epoch_id=epoch_id,\n",
484 | " best_score=best_score\n",
485 | " )\n",
486 | " torch.save(ckpt.dict(), train_conf.checkpoint_path)\n",
487 | "\n",
488 | "\n",
489 | "def load_checkpoint(train_conf: TrainConf) -> CheckpointFormat:\n",
490 | " ckpt = torch.load(train_conf.checkpoint_path)\n",
491 | " return CheckpointFormat.parse_obj(ckpt)\n",
492 | " "
493 | ]
494 | },
495 | {
496 | "cell_type": "markdown",
497 | "metadata": {},
498 | "source": [
499 | "## Training Operation\n",
500 | "\n",
501 | "The process of training a model using a dataset.\n",
502 | "\n",
503 | "In the real library, this part of the functionality is encapsulated in a `Trainer` class."
504 | ]
505 | },
506 | {
507 | "cell_type": "code",
508 | "execution_count": 10,
509 | "metadata": {},
510 | "outputs": [],
511 | "source": [
512 | "def run_training(model: TransE,\n",
513 | " train_conf: TrainConf,\n",
514 | " params: HyperParam,\n",
515 | " device: torch.device,\n",
516 | " dataset_conf: DatasetConf,\n",
517 | " entity2id: Dict[str, int],\n",
518 | " rel2id: Dict[str, int]):\n",
519 | " # 准备数据集\n",
520 | " train_dataset = KRLDataset(dataset_conf, 'train', entity2id, rel2id)\n",
521 | " valid_dataset = KRLDataset(dataset_conf, 'valid', entity2id, rel2id)\n",
522 | " # dataset -> dataloader\n",
523 | " train_dataloder = DataLoader(train_dataset, params.batch_size)\n",
524 | " valid_dataloder = DataLoader(valid_dataset, params.valid_batch_size)\n",
525 | " # 负采样器\n",
526 | " train_neg_sampler = RandomNegativeSampler(train_dataset, device)\n",
527 | " valid_neg_sampler = RandomNegativeSampler(valid_dataset, device)\n",
528 | " # 准备训练的工具\n",
529 | " optimzer = torch.optim.Adam(model.parameters(), lr=params.learning_rate)\n",
530 | " min_valid_loss = 10000.0\n",
531 | " best_score = 0.0\n",
532 | " # training loop\n",
533 | " for epoch_id in range(1, params.epoch_size + 1):\n",
534 | " print(\"Starting epoch: \", epoch_id)\n",
535 | " loss_sum = 0\n",
536 | " model.train()\n",
537 | " for i, batch in enumerate(train_dataloder):\n",
538 | " # 获取一个 batch 的训练资料\n",
539 | " pos_heads, pos_rels, pos_tails = batch[0].to(device), batch[1].to(device), batch[2].to(device)\n",
540 | " pos_triples = torch.stack([pos_heads, pos_rels, pos_tails], dim=1) # pos_triples: [batch_size, 3]\n",
541 | " neg_triples = train_neg_sampler.neg_sample(pos_heads, pos_rels, pos_tails) # neg_triples: [batch_size, 3]\n",
542 | " optimzer.zero_grad()\n",
543 | " # 计算 loss\n",
544 | " loss, pos_dist, neg_dist = model(pos_triples, neg_triples)\n",
545 | " loss.backward()\n",
546 | " loss_sum += loss.cpu().item()\n",
547 | " # update model\n",
548 | " optimzer.step()\n",
549 | " \n",
550 | " if epoch_id % params.valid_freq == 0:\n",
551 | " model.eval()\n",
552 | " _, _, hits_at_10, _ = run_testing(model, valid_dataloder, len(valid_dataset.entity2id), device)\n",
553 | " score = hits_at_10\n",
554 | " print('valid hits@10:', score)\n",
555 | " if score > best_score:\n",
556 | " best_score = score\n",
557 | " print('best score of valid: ', best_score)\n",
558 | " save_checkpoint(model, optimzer, epoch_id, best_score, train_conf)\n",
559 | " "
560 | ]
561 | },
562 | {
563 | "cell_type": "code",
564 | "execution_count": 11,
565 | "metadata": {},
566 | "outputs": [],
567 | "source": [
568 | "def get_device() -> torch.device:\n",
569 | " return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
570 | "\n",
571 | "\n",
572 | "def main(dataset_conf: DatasetConf,\n",
573 | " params: HyperParam,\n",
574 | " train_conf: TrainConf,\n",
575 | " device: torch.device):\n",
576 | " entity2id, rel2id = create_mapping(dataset_conf)\n",
577 | " device = get_device()\n",
578 | " ent_num = len(entity2id)\n",
579 | " rel_num = len(rel2id)\n",
580 | " model = TransE(ent_num, rel_num, device,\n",
581 | " norm=params.norm,\n",
582 | " embed_dim=params.embed_dim,\n",
583 | " margin=params.margin)\n",
584 | " model = model.to(device)\n",
585 | " run_training(model, train_conf, params, device, dataset_conf, entity2id, rel2id)\n",
586 | " \n",
587 | " # Testing the best checkpoint on test dataset\n",
588 | " ckpt = load_checkpoint(train_conf)\n",
589 | " model.load_state_dict(ckpt.model_state_dict)\n",
590 | " model = model.to(device)\n",
591 | " test_dataset = KRLDataset(dataset_conf, 'test', entity2id, rel2id)\n",
592 | " test_dataloder = DataLoader(test_dataset, params.valid_batch_size)\n",
593 | " hits_at_1, hits_at_3, hits_at_10, mrr = run_testing(model, test_dataloder, ent_num, device)\n",
594 | " \n",
595 | " # write results\n",
596 | " with open(train_conf.metric_result_path, 'w') as f:\n",
597 | " f.write(f'dataset: {dataset_conf.dataset_name}\\n')\n",
598 | " f.write(f'Hits@1: {hits_at_1}\\n')\n",
599 | " f.write(f'Hits@3: {hits_at_3}\\n')\n",
600 | " f.write(f'Hits@10: {hits_at_10}\\n')\n",
601 | " f.write(f'MRR: {mrr}\\n')"
602 | ]
603 | },
604 | {
605 | "cell_type": "markdown",
606 | "metadata": {},
607 | "source": [
608 | "## Begin\n",
609 | "\n",
610 | "Instantiate the configuration and call the main function.\n",
611 | "\n",
612 | "Before running, set the dataset directory and the paths where the checkpoint and the metric results will be written (see the `TODO` comments below)."
613 | ]
614 | },
615 | {
616 | "cell_type": "code",
617 | "execution_count": 12,
618 | "metadata": {},
619 | "outputs": [],
620 | "source": [
621 | "fb15k_dataset_conf = DatasetConf(\n",
622 | " dataset_name='FB15K',\n",
623 | " base_dir='/root/yubin/dataset/KRL/master/FB15k/' # TODO: change it!\n",
624 | ")\n",
625 | "\n",
626 | "fb15k_hyper_params = HyperParam(\n",
627 | " valid_freq=5,\n",
628 | " batch_size=128,\n",
629 | " valid_batch_size=64,\n",
630 | " learning_rate=0.001,\n",
631 | " epoch_size=500,\n",
632 | " embed_dim=50,\n",
633 | " norm=1,\n",
634 | " margin=2.0\n",
635 | ")\n",
636 | "\n",
637 | "fb15k_train_conf = TrainConf(\n",
638 | " checkpoint_path='/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transe_fb15k.ckpt', # TODO: change it!\n",
639 | " metric_result_path='/root/sharespace/yubin/papers/KRL/scratch/TransX/tmp/transe_fb15k_metrics.txt' # TODO: change it!\n",
640 | ")\n",
641 | "\n",
642 | "device = get_device()"
643 | ]
644 | },
645 | {
646 | "cell_type": "code",
647 | "execution_count": null,
648 | "metadata": {},
649 | "outputs": [],
650 | "source": [
651 | "main(fb15k_dataset_conf, fb15k_hyper_params, fb15k_train_conf, device)"
652 | ]
653 | },
654 | {
655 | "cell_type": "code",
656 | "execution_count": 14,
657 | "metadata": {},
658 | "outputs": [],
659 | "source": [
660 | "ckpt = torch.load(fb15k_train_conf.checkpoint_path)"
661 | ]
662 | },
663 | {
664 | "cell_type": "code",
665 | "execution_count": 15,
666 | "metadata": {},
667 | "outputs": [
668 | {
669 | "data": {
670 | "text/plain": [
671 | "{'model_state_dict': {'ent_embedding.weight': tensor([[ 0.0126, 1.1495, 0.7716, ..., -0.8967, 0.3144, -0.2118],\n",
672 | " [ 0.6201, -1.0493, 0.7837, ..., 0.2150, 0.2489, 0.0075],\n",
673 | " [-0.2746, -0.2592, 0.1011, ..., -1.1602, 0.1287, 0.4214],\n",
674 | " ...,\n",
675 | " [ 0.8894, 0.2856, 0.2504, ..., 0.8878, 0.9393, -0.2801],\n",
676 | " [ 0.8036, 0.7737, 0.1246, ..., 0.2968, -0.0092, -0.2363],\n",
677 | " [-1.0176, -0.0581, -0.5224, ..., -0.3004, -1.3833, -0.6132]],\n",
678 | " device='cuda:0'),\n",
679 | " 'rel_embedding.weight': tensor([[-0.0909, -0.4988, -0.7200, ..., -0.0086, 0.0742, -0.1348],\n",
680 | " [-0.9219, 0.5674, 0.4704, ..., -1.6722, -1.1977, -0.0505],\n",
681 | " [-0.6558, -0.0297, -0.1181, ..., -0.2469, -0.2535, -0.6061],\n",
682 | " ...,\n",
683 | " [ 0.3357, 0.2950, -0.2752, ..., 0.5626, 0.3279, -0.5322],\n",
684 | " [-0.3469, -0.0041, -0.7309, ..., 0.1166, 0.0848, 0.3135],\n",
685 | " [ 0.0369, -0.0189, -0.0132, ..., 0.1579, -0.0358, -0.0636]],\n",
686 | " device='cuda:0')},\n",
687 | " 'optim_state_dict': {'state': {0: {'step': tensor(1887500.),\n",
688 | " 'exp_avg': tensor([[-5.6052e-45, 5.6052e-45, -5.6052e-45, ..., -5.6052e-45,\n",
689 | " -5.6052e-45, -5.6052e-45],\n",
690 | " [ 5.6052e-45, -5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n",
691 | " -5.6052e-45, -5.6052e-45],\n",
692 | " [ 5.6052e-45, 5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n",
693 | " -5.6052e-45, 5.6052e-45],\n",
694 | " ...,\n",
695 | " [ 5.6052e-45, 5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n",
696 | " 5.6052e-45, 5.6052e-45],\n",
697 | " [-5.6052e-45, 5.6052e-45, -5.6052e-45, ..., 5.6052e-45,\n",
698 | " 5.6052e-45, 5.6052e-45],\n",
699 | " [-3.9882e-41, 3.9814e-41, 3.9882e-41, ..., -3.9882e-41,\n",
700 | " -3.9882e-41, 3.9882e-41]], device='cuda:0'),\n",
701 | " 'exp_avg_sq': tensor([[5.3320e-10, 5.3320e-10, 5.3320e-10, ..., 5.3320e-10, 5.3320e-10,\n",
702 | " 5.3320e-10],\n",
703 | " [1.1443e-09, 1.1443e-09, 1.1443e-09, ..., 1.1443e-09, 1.8754e-08,\n",
704 | " 1.8754e-08],\n",
705 | " [3.3255e-08, 3.3255e-08, 2.2694e-08, ..., 3.3255e-08, 3.3255e-08,\n",
706 | " 2.2694e-08],\n",
707 | " ...,\n",
708 | " [3.6491e-13, 3.6491e-13, 3.6491e-13, ..., 3.6491e-13, 3.6491e-13,\n",
709 | " 3.6491e-13],\n",
710 | " [2.1251e-13, 2.1251e-13, 2.1251e-13, ..., 2.1251e-13, 2.1251e-13,\n",
711 | " 2.1251e-13],\n",
712 | " [5.5608e-08, 5.5608e-08, 5.5608e-08, ..., 5.5608e-08, 5.5608e-08,\n",
713 | " 5.5608e-08]], device='cuda:0')},\n",
714 | " 1: {'step': tensor(1887500.),\n",
715 | " 'exp_avg': tensor([[-5.6052e-45, -5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n",
716 | " 5.6052e-45, 5.6052e-45],\n",
717 | " [ 5.6052e-45, -5.6052e-45, 5.6052e-45, ..., 5.6052e-45,\n",
718 | " 5.6052e-45, 5.6052e-45],\n",
719 | " [ 5.6052e-45, 5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n",
720 | " 5.6052e-45, -5.6052e-45],\n",
721 | " ...,\n",
722 | " [ 5.6052e-45, 5.6052e-45, 5.6052e-45, ..., 5.6052e-45,\n",
723 | " 5.6052e-45, 5.6052e-45],\n",
724 | " [ 5.6052e-45, 5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n",
725 | " -5.6052e-45, 5.6052e-45],\n",
726 | " [-5.6052e-45, 5.6052e-45, 5.6052e-45, ..., -5.6052e-45,\n",
727 | " -5.6052e-45, 5.6052e-45]], device='cuda:0'),\n",
728 | " 'exp_avg_sq': tensor([[7.0065e-43, 3.1465e-32, 3.1465e-32, ..., 4.8395e-37, 7.0065e-43,\n",
729 | " 7.0065e-43],\n",
730 | " [2.6102e-08, 2.8141e-08, 1.2741e-07, ..., 2.5118e-10, 2.0500e-08,\n",
731 | " 1.0558e-07],\n",
732 | " [2.6078e-24, 7.9196e-29, 7.0065e-43, ..., 2.6077e-24, 7.9196e-29,\n",
733 | " 7.0065e-43],\n",
734 | " ...,\n",
735 | " [6.1684e-23, 1.7540e-34, 7.0065e-43, ..., 7.0065e-43, 1.7540e-34,\n",
736 | " 1.7540e-34],\n",
737 | " [7.0065e-43, 4.4173e-41, 4.4173e-41, ..., 7.0065e-43, 7.0065e-43,\n",
738 | " 7.0065e-43],\n",
739 | " [7.0065e-43, 7.0065e-43, 7.0065e-43, ..., 7.0065e-43, 7.0065e-43,\n",
740 | " 7.0065e-43]], device='cuda:0')}},\n",
741 | " 'param_groups': [{'lr': 0.001,\n",
742 | " 'betas': (0.9, 0.999),\n",
743 | " 'eps': 1e-08,\n",
744 | " 'weight_decay': 0,\n",
745 | " 'amsgrad': False,\n",
746 | " 'maximize': False,\n",
747 | " 'foreach': None,\n",
748 | " 'capturable': False,\n",
749 | " 'differentiable': False,\n",
750 | " 'fused': False,\n",
751 | " 'params': [0, 1]}]},\n",
752 | " 'epoch_id': 500,\n",
753 | " 'best_score': 39.291}"
754 | ]
755 | },
756 | "execution_count": 15,
757 | "metadata": {},
758 | "output_type": "execute_result"
759 | }
760 | ],
761 | "source": [
762 | "ckpt"
763 | ]
764 | }
765 | ],
766 | "metadata": {
767 | "kernelspec": {
768 | "display_name": "Python 3 (ipykernel)",
769 | "language": "python",
770 | "name": "python3"
771 | },
772 | "language_info": {
773 | "codemirror_mode": {
774 | "name": "ipython",
775 | "version": 3
776 | },
777 | "file_extension": ".py",
778 | "mimetype": "text/x-python",
779 | "name": "python",
780 | "nbconvert_exporter": "python",
781 | "pygments_lexer": "ipython3",
782 | "version": "3.9.15"
783 | },
784 | "vscode": {
785 | "interpreter": {
786 | "hash": "0418effca45178467ac68c18e34d93809a092be692e0a4443d8690099b71f4bc"
787 | }
788 | }
789 | },
790 | "nbformat": 4,
791 | "nbformat_minor": 4
792 | }
793 |
--------------------------------------------------------------------------------
/typer_app.py:
--------------------------------------------------------------------------------
# Command-line entry point: mounts each model's Typer sub-application
# under its own sub-command name on a single root CLI.
import typer

from krl.typer_apps.rescal import app as rescal_app
from krl.typer_apps.transe import app as transe_app
from krl.typer_apps.transh import app as transh_app
from krl.typer_apps.distmult import app as distmult_app
from krl.typer_apps.transr import app as transr_app


# Root CLI; every model is attached to it below as a sub-command group.
app = typer.Typer()

# (sub-command name, sub-application) pairs, mounted in exactly this order.
for _command_name, _sub_app in (
    ('RESCAL', rescal_app),
    ('TransE', transe_app),
    ('TransH', transh_app),
    ('DistMult', distmult_app),
    ('TransR', transr_app),
):
    app.add_typer(_sub_app, name=_command_name)

# Drop the loop variables so the module namespace matches the explicit form.
del _command_name, _sub_app


if __name__ == '__main__':
    app()
22 |
--------------------------------------------------------------------------------