├── toolbox ├── .gitignore ├── utils │ ├── Optimizer.py │ ├── __init__.py │ ├── SpeedUp.py │ ├── RandomSeeds.py │ ├── Time.py │ ├── Embed.py │ ├── DefaultDict.py │ ├── VisualizeStore.py │ ├── LaTeX.py │ ├── Download.py │ ├── Log.py │ ├── AutoML.py │ ├── Progbar.py │ └── Framework.py ├── web │ ├── log_app │ │ ├── server │ │ │ ├── __init__.py │ │ │ ├── app_utils.py │ │ │ └── data_container.py │ │ ├── README.md │ │ ├── static │ │ │ ├── img │ │ │ │ ├── chart.ico │ │ │ │ ├── table.ico │ │ │ │ └── loading.gif │ │ │ ├── fonts │ │ │ │ ├── glyphicons-halflings-regular.eot │ │ │ │ ├── glyphicons-halflings-regular.ttf │ │ │ │ ├── glyphicons-halflings-regular.woff │ │ │ │ └── glyphicons-halflings-regular.woff2 │ │ │ ├── css │ │ │ │ ├── bootstrap-table-filter-control.css │ │ │ │ ├── bootstrap-table-reorder-rows.css │ │ │ │ └── table.css │ │ │ └── js │ │ │ │ ├── chart.js │ │ │ │ ├── bootstrap-table-reorder-rows.js │ │ │ │ └── utils.js │ │ ├── __init__.py │ │ ├── templates │ │ │ ├── folder_img.html │ │ │ ├── folder.html │ │ │ └── multi_chart.html │ │ ├── main.py │ │ ├── output │ │ │ └── default.cfg │ │ ├── line_app.py │ │ ├── multi_char_app.py │ │ ├── app.py │ │ └── folder_app.py │ ├── __init__.py │ └── log_server │ │ ├── README.md │ │ ├── __init__.py │ │ └── main.py ├── nn │ ├── functional │ │ ├── __init__.py │ │ └── complex.py │ ├── __init__.py │ ├── Flatten.py │ ├── README.md │ ├── Highway.py │ ├── DistMult.py │ ├── RESCAL.py │ ├── GCN.py │ ├── GAT.py │ ├── Complex.py │ ├── LorentzE.py │ ├── TuckERT.py │ ├── MobiusEmbedding.py │ ├── ConvE.py │ ├── TransE.py │ ├── TuckER.py │ ├── TuckERTNT.py │ ├── TuckERTTR.py │ ├── Rotate3D.py │ ├── TuckERTTT.py │ ├── CoPER.py │ ├── TuckERCPD.py │ ├── BlaschkE.py │ ├── Regularizer.py │ ├── ParamE.py │ ├── ComplexTuckER.py │ ├── MobiusE.py │ └── TuckerMobiusE.py ├── README_en.md ├── exp │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ └── JsonConfig.py │ ├── Notebook.py │ ├── classic │ │ ├── experiment_demo.py │ │ └── train_CartPole.py │ 
├── DistributeSchema.py │ ├── Experiment.py │ └── OutputSchema.py ├── evaluate │ ├── __init__.py │ ├── EntityAlignment.py │ ├── GatherMetric.py │ └── Leaderboard.py ├── data │ ├── __init__.py │ ├── README.md │ ├── TripleDataset.py │ ├── ComplementaryDataset.py │ ├── FixWindowNegSamplingDataset.py │ ├── LinkPredictDataset.py │ └── ScoringAllDataset.py ├── __init__.py ├── game │ ├── __init__.py │ ├── SinglePlayerGame.py │ ├── Player.py │ └── MultiPlayerGame.py ├── optim │ ├── __init__.py │ └── lr_scheduler.py ├── requirements-web.txt ├── requirements.txt ├── CheetSheet.md ├── README.md ├── cli │ └── clean_output.py └── README_template_en.md ├── vis.png ├── requirements.txt └── .gitignore /toolbox/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /toolbox/utils/Optimizer.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /toolbox/web/log_app/server/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /toolbox/nn/functional/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["complex"] 2 | -------------------------------------------------------------------------------- /toolbox/README_en.md: -------------------------------------------------------------------------------- 1 | # KGE Toolbox 2 | 3 | 4 | 5 | ## CheatSheet -------------------------------------------------------------------------------- /toolbox/exp/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Experiment", "OutputSchema"] 2 | -------------------------------------------------------------------------------- 
/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/vis.png -------------------------------------------------------------------------------- /toolbox/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["EntityAlignment", "GatherMetric", "Leaderboard"] 2 | -------------------------------------------------------------------------------- /toolbox/data/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["TripleDataset", "ComplementaryDataset", "FixWindowNegSamplingDataset", "LinkPredictDataset", "ScoringAllDataset"] 2 | -------------------------------------------------------------------------------- /toolbox/utils/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Log", "Progbar", "RandomSeeds", "ModelParamStore", "DefaultDict", "VisualizeStore"] 2 | -------------------------------------------------------------------------------- /toolbox/web/log_app/README.md: -------------------------------------------------------------------------------- 1 | # Get Started 2 | 3 | ```shell 4 | python -m toolbox.web.log_app.main --log_dir="output" 5 | ``` 6 | -------------------------------------------------------------------------------- /toolbox/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/12/9 5 | @description: null 6 | """ 7 | -------------------------------------------------------------------------------- /toolbox/web/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/12/9 5 | @description: null 6 | """ 7 | --------------------------------------------------------------------------------
/toolbox/web/log_server/README.md: -------------------------------------------------------------------------------- 1 | # Get Started 2 | 3 | ```shell 4 | python -m toolbox.web.log_server.main --log_dir="output" 5 | ``` 6 | -------------------------------------------------------------------------------- /toolbox/game/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/9/26 5 | @description: null 6 | """ 7 | -------------------------------------------------------------------------------- /toolbox/optim/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/11/7 5 | @description: null 6 | """ 7 | -------------------------------------------------------------------------------- /toolbox/requirements-web.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | torch 3 | numpy 4 | scikit-learn 5 | scipy 6 | flask 7 | Flask-Cors 8 | jieba 9 | rouge 10 | fastapi 11 | uvicorn 12 | -------------------------------------------------------------------------------- /toolbox/web/log_app/static/img/chart.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/toolbox/web/log_app/static/img/chart.ico -------------------------------------------------------------------------------- /toolbox/web/log_app/static/img/table.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/toolbox/web/log_app/static/img/table.ico -------------------------------------------------------------------------------- /toolbox/web/log_app/static/img/loading.gif: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/toolbox/web/log_app/static/img/loading.gif -------------------------------------------------------------------------------- /toolbox/web/log_server/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/12/9 5 | @description: null 6 | """ 7 | -------------------------------------------------------------------------------- /toolbox/exp/config/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/9/27 5 | @description: 使用配置文件来管理脚本的各种超参数 6 | """ 7 | -------------------------------------------------------------------------------- /toolbox/requirements.txt: -------------------------------------------------------------------------------- 1 | # pytorch 请自己根据cuda版本安装 2 | tqdm 3 | numpy 4 | scikit-learn 5 | scipy 6 | click 7 | tensorboardX 8 | pandas 9 | pathlib 10 | pyyaml -------------------------------------------------------------------------------- /toolbox/data/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## TODO 3 | 1. Maybe we can create data dir and hold some classic datasets to prevent downloading from the source to accelerate coding process. 
4 | -------------------------------------------------------------------------------- /toolbox/web/log_app/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/2/19 5 | @description: 通过 web api 获取远程的实验日志服务器的数据,呈现到前端 6 | """ 7 | -------------------------------------------------------------------------------- /toolbox/nn/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["CapsE", "TransE", "ConvE", "GAT", "GCN", "Flatten", "CoPER", "OctonionE", "RotatE", "EchoE", "QuatE", "HAKE", "TuckER", "Complex", "Highway", "DistMult"] 2 | -------------------------------------------------------------------------------- /toolbox/web/log_app/static/fonts/glyphicons-halflings-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/toolbox/web/log_app/static/fonts/glyphicons-halflings-regular.eot -------------------------------------------------------------------------------- /toolbox/web/log_app/static/fonts/glyphicons-halflings-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/toolbox/web/log_app/static/fonts/glyphicons-halflings-regular.ttf -------------------------------------------------------------------------------- /toolbox/nn/Flatten.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Flatten(nn.Module): 5 | def forward(self, x): 6 | n = x.size(0) 7 | x = x.view(n, -1) 8 | return x 9 | -------------------------------------------------------------------------------- /toolbox/web/log_app/static/fonts/glyphicons-halflings-regular.woff: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/toolbox/web/log_app/static/fonts/glyphicons-halflings-regular.woff -------------------------------------------------------------------------------- /toolbox/web/log_app/static/fonts/glyphicons-halflings-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/toolbox/web/log_app/static/fonts/glyphicons-halflings-regular.woff2 -------------------------------------------------------------------------------- /toolbox/evaluate/EntityAlignment.py: -------------------------------------------------------------------------------- 1 | 2 | def entity_alignment(predictions): 3 | """ 4 | predictions : torch.Tensor, similarity matrix of shape (entity_count, entity_count) 5 | """ 6 | pass -------------------------------------------------------------------------------- /toolbox/CheetSheet.md: -------------------------------------------------------------------------------- 1 | # 指定 gpu 2 | ```shell 3 | CUDA_VISIBLE_DEVICES=3 4 | ``` 5 | 6 | # 内存分析,打印每一行代码执行前后的内存变化 7 | ```shell 8 | pip install memory_profiler psutil 9 | python -m memory_profiler main.py 10 | ``` -------------------------------------------------------------------------------- /toolbox/utils/SpeedUp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def speed_up(): 5 | torch.cuda.empty_cache() 6 | 7 | def how_to_speed_up(): 8 | info = """ 9 | 1. for your dataloader: pin_memory == True, num_worker >= 8 10 | 2.
choose faster optimizer: AdamW 11 | """ 12 | print(info) 13 | -------------------------------------------------------------------------------- /toolbox/web/log_app/static/css/bootstrap-table-filter-control.css: -------------------------------------------------------------------------------- 1 | /** 2 | * @author: Dennis Hernández 3 | * @webSite: http://djhvscf.github.io/Blog 4 | * @version: v2.1.1 5 | */ 6 | 7 | .no-filter-control { 8 | height: 34px; 9 | } 10 | 11 | .filter-control { 12 | margin: 0 2px 2px 2px; 13 | } -------------------------------------------------------------------------------- /toolbox/web/log_app/templates/folder_img.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Image 6 | 7 | 8 | 9 |

Image for {{ img_path }}

10 | 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.5.0 2 | sphinx>=2.1.2 3 | networkx>=2.2 4 | setuptools>=40.8.0 5 | matplotlib>=3.0.3 6 | numpy>=1.16.2 7 | seaborn>=0.9.0 8 | scikit_learn>=0.20.3 9 | hyperopt>=0.2.1 10 | pathlib>=1.0.1 11 | numpydoc>=0.9.1 12 | sphinx-gallery>=0.3.1 13 | sphinx-rtd-theme>=0.4.3 14 | pytest>=3.6 15 | pyyaml>=5.3.1 -------------------------------------------------------------------------------- /toolbox/nn/README.md: -------------------------------------------------------------------------------- 1 | # nn 2 | 3 | This package contains classic knowledge graph embedding models. 4 | 5 | ## Symbols 6 | 7 | The comments in the source code use these symbols: 8 | 9 | ``` 10 | T: number of triples 11 | B: batch size 12 | d: dimension 13 | d_e: dimension of entity embedding 14 | d_r: dimension of relation embedding 15 | ``` 16 | -------------------------------------------------------------------------------- /toolbox/nn/Highway.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Highway(nn.Module): 6 | def __init__(self, x_hidden): 7 | super(Highway, self).__init__() 8 | self.lin = nn.Linear(x_hidden, x_hidden) 9 | 10 | def forward(self, x1, x2): 11 | gate = torch.sigmoid(self.lin(x1)) 12 | x = torch.mul(gate, x2) + torch.mul(1 - gate, x1) 13 | return x 14 | -------------------------------------------------------------------------------- /toolbox/game/SinglePlayerGame.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/9/26 5 | @description: null 6 | """ 7 | 8 | 9 | class SinglePlayerGameEnv: 10 | def __init__(self): 11 | pass 12 | 13 | def reset(self): 14 | # return state 15 | pass 16 | 
17 | def step(self, action): 18 | # return state, reward, done, info 19 | pass 20 | 21 | def state(self): 22 | # return state 23 | pass 24 | 25 | def render(self, show=False): 26 | pass 27 | -------------------------------------------------------------------------------- /toolbox/utils/RandomSeeds.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/2/19 5 | @description: 随机种子 6 | """ 7 | import random 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | def set_seeds(seed=1234): 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | if torch.cuda.is_available(): 18 | torch.cuda.manual_seed_all(seed) 19 | if seed == 0: 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False 22 | -------------------------------------------------------------------------------- /toolbox/README.md: -------------------------------------------------------------------------------- 1 | # KGE Toolbox 2 | 3 | 知识图谱嵌入工具箱,助力快速实验出成果 4 | 5 | 1. 通用 6 | 1. 命令行参数解析 7 | 2. 日志 8 | 3. 进度条 9 | 4. 随机种子 10 | 5. TensorBoard监控 11 | 6. 超参数自动搜索AutoML 12 | 7. 梯度累加(应对小内存gpu对batch_size的限制) 13 | 8. 中断训练、恢复训练 14 | 2. 知识图谱嵌入领域专用工具 15 | 1. 嵌入降维可视化 16 | 2. 数据集、常用数据预处理 17 | 3. 链接预测任务、实体对齐任务(自动生成对应的数据集并训练) 18 | 4. 测试指标(Hit@k、MR、MRR、AUC) 19 | 5.
经典KGE模型的PyTorch版复现 20 | 21 | 22 | ## CheatSheet 23 | 24 | CUDA_VISIBLE_DEVICES=3 25 | 26 | 内存分析,打印每一行代码执行前后的内存变化 27 | ```shell 28 | pip install memory_profiler psutil 29 | python -m memory_profiler main.py 30 | ``` -------------------------------------------------------------------------------- /toolbox/game/Player.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/9/26 5 | @description: null 6 | """ 7 | 8 | 9 | class Player: 10 | def __init__(self): 11 | self.player_index = None 12 | 13 | def get_player_index(self): 14 | # for multi player game 15 | return self.player_index 16 | 17 | def set_player_index(self, player_index): 18 | # for multi player game 19 | self.player_index = player_index 20 | 21 | def get_action(self, state, **kwargs): 22 | raise NotImplementedError 23 | 24 | def reset(self): 25 | # for self-play player 26 | pass 27 | -------------------------------------------------------------------------------- /toolbox/cli/clean_output.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/4/25 5 | @description: null 6 | """ 7 | import os 8 | 9 | import click 10 | 11 | from toolbox.exp.OutputSchema import OutputSchema, Cleaner 12 | 13 | 14 | @click.command() 15 | @click.option("--output_dir", type=str, default="output", help="Which dir to clean") 16 | def main(output_dir): 17 | dirs = os.listdir(output_dir) 18 | for d in dirs: 19 | output = OutputSchema(str(d)) 20 | cleaner = Cleaner(output.pathSchema) 21 | cleaner.remove_non_best_checkpoint_and_model() 22 | 23 | 24 | if __name__ == '__main__': 25 | main() 26 | -------------------------------------------------------------------------------- /toolbox/data/TripleDataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: 
linxy59@mail2.sysu.edu.cn 4 | @date: 2021/10/30 5 | @description: null 6 | """ 7 | from typing import List, Tuple 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | 13 | class TripleDataset(Dataset): 14 | def __init__(self, triples_ids: List[Tuple[int, int, int]]): 15 | self.triples_ids = triples_ids 16 | 17 | def __len__(self): 18 | return len(self.triples_ids) 19 | 20 | def __getitem__(self, idx): 21 | h, r, t = self.triples_ids[idx] 22 | h = torch.LongTensor([h]) 23 | r = torch.LongTensor([r]) 24 | t = torch.LongTensor([t]) 25 | return h, r, t 26 | -------------------------------------------------------------------------------- /toolbox/exp/Notebook.py: -------------------------------------------------------------------------------- 1 | from toolbox.exp.OutputSchema import OutputSchema 2 | from toolbox.utils.ModelParamStore import ModelParamStoreSchema 3 | from toolbox.utils.VisualizeStore import VisualizeStoreSchema 4 | 5 | 6 | class Self: 7 | pass 8 | 9 | 10 | def init_by_output(self: object, output: OutputSchema): 11 | self.debug = output.logger.debug 12 | self.log = output.logger.info 13 | self.warn = output.logger.warn 14 | self.error = output.logger.error 15 | self.critical = output.logger.critical 16 | self.success = output.logger.success 17 | self.fail = output.logger.failed 18 | self.vis = VisualizeStoreSchema(str(output.pathSchema.dir_path_visualize)) 19 | self.store = ModelParamStoreSchema(output.pathSchema) 20 | -------------------------------------------------------------------------------- /toolbox/utils/Time.py: -------------------------------------------------------------------------------- 1 | from timeit import default_timer as timer 2 | from functools import wraps 3 | from time import time 4 | 5 | 6 | class benchmark(object): 7 | 8 | def __init__(self, msg, fmt="%0.3g"): 9 | self.msg = msg 10 | self.fmt = fmt 11 | 12 | def __enter__(self): 13 | self.start = timer() 14 | return self 15 | 16 | def __exit__(self, *args): 17 
| t = timer() - self.start 18 | print(("%s : " + self.fmt + " seconds") % (self.msg, t)) 19 | self.time = t 20 | 21 | 22 | def timing(f): 23 | @wraps(f) 24 | def wrap(*args, **kw): 25 | ts = time() 26 | result = f(*args, **kw) 27 | te = time() 28 | print(f'func:{f.__name__!r} args:[{args!r}, {kw!r}] took: {te - ts:2.4f} sec') 29 | return result 30 | return wrap 31 | -------------------------------------------------------------------------------- /toolbox/nn/functional/complex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mobius_mul_with_unit_norm(Q_1, Q_2): 5 | a_h = Q_1 # = {a_h + b_h i + c_h j + d_h k : a_r, b_r, c_r, d_r \in R^k} 6 | a_r, b_r, c_r, d_r = Q_2 # = {a_r + b_r i + c_r j + d_r k : a_r, b_r, c_r, d_r \in R^k} 7 | 8 | # Normalize the relation to eliminate the scaling effect 9 | denominator = torch.sqrt(a_r ** 2 + b_r ** 2 + c_r ** 2 + d_r ** 2) 10 | p = a_r / denominator 11 | q = b_r / denominator 12 | u = c_r / denominator 13 | v = d_r / denominator 14 | # Q'=E Hamilton product R 15 | h_r = (a_h * p + q) / (a_h * u + v) 16 | return h_r 17 | 18 | 19 | def mobius_mul(Q_1, Q_2): 20 | a_h = Q_1 # = {a_h : a_r, b_r, c_r, d_r \in R^k} 21 | a_r, b_r, c_r, d_r = Q_2 # = {a_r + b_r i + c_r j + d_r k : a_r, b_r, c_r, d_r \in R^k} 22 | h_r = (a_h * a_r + b_r) / (a_h * c_r + d_r) 23 | return h_r 24 | -------------------------------------------------------------------------------- /toolbox/web/log_app/static/css/bootstrap-table-reorder-rows.css: -------------------------------------------------------------------------------- 1 | .reorder_rows_onDragClass td { 2 | background-color: #eee; 3 | -webkit-box-shadow: 11px 5px 12px 2px #333, 0 1px 0 #ccc inset, 0 -1px 0 #ccc inset; 4 | -webkit-box-shadow: 6px 3px 5px #555, 0 1px 0 #ccc inset, 0 -1px 0 #ccc inset; 5 | -moz-box-shadow: 6px 4px 5px 1px #555, 0 1px 0 #ccc inset, 0 -1px 0 #ccc inset; 6 | -box-shadow: 6px 4px 5px 1px #555, 0 1px 0 #ccc 
inset, 0 -1px 0 #ccc inset; 7 | } 8 | 9 | .reorder_rows_onDragClass td:last-child { 10 | -webkit-box-shadow: 8px 7px 12px 0 #333, 0 1px 0 #ccc inset, 0 -1px 0 #ccc inset; 11 | -webkit-box-shadow: 1px 8px 6px -4px #555, 0 1px 0 #ccc inset, 0 -1px 0 #ccc inset; 12 | -moz-box-shadow: 0 9px 4px -4px #555, 0 1px 0 #ccc inset, 0 -1px 0 #ccc inset, -1px 0 0 #ccc inset; 13 | -box-shadow: 0 9px 4px -4px #555, 0 1px 0 #ccc inset, 0 -1px 0 #ccc inset, -1px 0 0 #ccc inset; 14 | } -------------------------------------------------------------------------------- /toolbox/evaluate/GatherMetric.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/3/23 5 | @description: null 6 | """ 7 | class AverageMeter(object): 8 | """Computes and stores the average and current value""" 9 | 10 | def __init__(self, name, fmt=':f'): 11 | self.name = name 12 | self.fmt = fmt 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | self.avg = self.sum / self.count 30 | 31 | def __str__(self): 32 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 33 | return fmtstr.format(**self.__dict__) 34 | -------------------------------------------------------------------------------- /toolbox/utils/Embed.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | 6 | def get_vec(entities_embedding, id_list: List[int], embedding_dim=200, device="cuda"): 7 | tensor = torch.LongTensor(id_list).view(-1, 1).to(device) 8 | return entities_embedding(tensor).view(-1, embedding_dim).cpu().detach().numpy() 9 | 10 | 11 | def get_vec2(entities_embedding, id_list: List[int], 
embedding_dim=200, device="cuda"): 12 | all_entity_ids = torch.LongTensor(id_list).view(-1).to(device) 13 | all_entity_vec = torch.index_select( 14 | entities_embedding, 15 | dim=0, 16 | index=all_entity_ids 17 | ).view(-1, embedding_dim).cpu().detach().numpy() 18 | return all_entity_vec 19 | 20 | 21 | def get_vec3(entities_embedding, orth: torch.Tensor, id_list: List[int], device="cuda"): 22 | all_entity_ids = torch.LongTensor(id_list).view(-1).to(device) 23 | all_entity_vec = torch.index_select( 24 | entities_embedding, 25 | dim=0, 26 | index=all_entity_ids 27 | ).view(-1, 200) 28 | all_entity_vec = all_entity_vec.matmul(orth.transpose(0, 1)) 29 | return all_entity_vec.cpu().detach().numpy() 30 | 31 | -------------------------------------------------------------------------------- /toolbox/exp/classic/experiment_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/2/20 5 | @description: 实验的基础能力 demo 6 | """ 7 | import random 8 | 9 | from toolbox.exp.Experiment import Experiment 10 | from toolbox.exp.OutputSchema import OutputSchema 11 | 12 | 13 | class MyExperiment(Experiment): 14 | 15 | def __init__(self, output: OutputSchema): 16 | super(MyExperiment, self).__init__(output) 17 | self.log(f"{locals()}") 18 | 19 | self.model_param_store.save_scripts([__file__]) 20 | seed = 10 21 | max_steps = 10 22 | learning_rate = 0.0001 23 | self.metric_log_store.set_rng_seed(seed) 24 | self.metric_log_store.add_hyper(learning_rate, "learning_rate") 25 | self.metric_log_store.add_progress(max_steps) 26 | for step in range(max_steps): 27 | acc = random.randint(10, 100) 28 | self.metric_log_store.add_loss(acc, step, name="loss") 29 | self.metric_log_store.add_metric({"acc": acc}, step, "Test") 30 | if acc > 50: 31 | self.metric_log_store.add_best_metric({"acc": acc}, "Test") 32 | self.metric_log_store.finish() 33 | 34 | 35 | output = OutputSchema("demo") 36 | 
MyExperiment(output) 37 | -------------------------------------------------------------------------------- /toolbox/web/log_server/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/2/19 5 | @description: 监听指定文件夹下的日志,以 web api 的形式提供服务给 log_app 6 | ``` 7 | python toolbox/web/log_server/main.py 8 | python -m toolbox.web.log_server.main 9 | ``` 10 | """ 11 | from typing import Optional, List 12 | 13 | import click 14 | import uvicorn 15 | from fastapi import FastAPI 16 | 17 | from .log_server import LogServer 18 | 19 | app = FastAPI() 20 | 21 | log_server = LogServer() 22 | 23 | 24 | @app.get("/") 25 | def read_root(): 26 | return {"你好": "请到该链接下查看 api 文档 /docs"} 27 | 28 | 29 | @app.get("/logs") 30 | def read_logs(ignore_log_names: Optional[dict] = None) -> List[dict]: 31 | return log_server.read_logs(ignore_log_names) 32 | 33 | 34 | @app.get("/certain_logs") 35 | def read_certain_logs(log_dir_names: List[str]) -> List[dict]: 36 | return log_server.read_certain_logs(log_dir_names) 37 | 38 | 39 | @click.command() 40 | @click.option("--log_dir", type=str, default="./output", help="日志所在文件夹") 41 | @click.option("--ip", type=str, default="0.0.0.0", help="IP 地址") 42 | @click.option("--port", type=int, default=45666, help="端口。如果该端口不可用,会自动选一个可用的") 43 | def main(log_dir: str, ip, port): 44 | app.root_path = log_dir 45 | log_server.set_log_dir(log_dir) 46 | uvicorn.run(app, host=ip, port=port) 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /toolbox/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/11/7 5 | @description: null 6 | """ 7 | import torch 8 | 9 | 10 | def get_scheduler(optimizer, lr_policy="exp", epoch_count=5, lr_decay_iters=25, 
niter=100, niter_decay=100, ): 11 | """Return a learning rate scheduler 12 | Parameters: 13 | optimizer -- the network optimizer 14 | lr_policy -- name of the learning-rate scheduler: linear | step | plateau | cosine | exp (default); any other value raises NotImplementedError 15 | """ 16 | if lr_policy == 'linear': 17 | def lambda_rule(epoch): 18 | lr_l = 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1) 19 | return lr_l 20 | 21 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 22 | elif lr_policy == 'step': 23 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_iters, gamma=0.5) 24 | elif lr_policy == 'plateau': 25 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 26 | elif lr_policy == 'cosine': 27 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=niter, eta_min=0) 28 | elif lr_policy == 'exp': 29 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98) 30 | else: 31 | raise NotImplementedError('learning rate policy [%s] is not implemented' % lr_policy) 32 | return scheduler 33 | -------------------------------------------------------------------------------- /toolbox/nn/DistMult.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CoreDistMult(nn.Module): 7 | def __init__(self, input_dropout_rate=0.2): 8 | super(CoreDistMult, self).__init__() 9 | self.dropout = nn.Dropout(input_dropout_rate) 10 | 11 | def forward(self, h, r): 12 | h = self.dropout(h) 13 | r = self.dropout(r) 14 | 15 | x = h * r 16 | x = F.relu(x) 17 | return x 18 | 19 | 20 | class DistMult(nn.Module): 21 | def __init__(self, num_entities, num_relations, embedding_dim, input_dropout_rate=0.2): 22 | super(DistMult, self).__init__() 23 | self.E = nn.Embedding(num_entities, embedding_dim, padding_idx=0) 24 | self.R = nn.Embedding(num_relations, embedding_dim, padding_idx=0) 25 |
self.core = CoreDistMult(input_dropout_rate) 26 | self.loss = nn.BCELoss() 27 | self.b = nn.Parameter(torch.zeros(num_entities)) 28 | 29 | def init(self): 30 | nn.init.xavier_normal_(self.E.weight.data) 31 | nn.init.xavier_normal_(self.R.weight.data) 32 | 33 | def forward(self, h_idx, r_idx): 34 | h = self.E(h_idx) 35 | r = self.R(r_idx) 36 | 37 | t = self.core(h, r) 38 | t = t.view(-1, self.E.embedding_dim) # embedding_dim is stored on the nn.Embedding; DistMult itself never sets self.embedding_dim 39 | 40 | x = torch.mm(t, self.E.weight.transpose(1, 0)) 41 | x = x + self.b.expand_as(x) 42 | x = torch.sigmoid(x) 43 | return x 44 | -------------------------------------------------------------------------------- /toolbox/game/MultiPlayerGame.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/9/26 5 | @description: null 6 | """ 7 | from typing import Tuple 8 | 9 | from toolbox.game.Player import Player 10 | 11 | 12 | class DoublePlayerGameEnv: 13 | def __init__(self): 14 | self.players = [1, 2] # player1 and player2 15 | self.current_player: int = 1 16 | 17 | def reset(self, start_player=0): 18 | self.current_player = self.players[start_player] # start player 19 | # return state 20 | 21 | def step(self, action): 22 | # return state, reward, done, info 23 | pass 24 | 25 | def game_end(self) -> Tuple[bool, int]: 26 | # return bool, player_id 27 | # if player 1 win, then return True, 1 28 | # if no one win, then return False, -1 29 | return False, -1 30 | 31 | def render(self, player1, player2): 32 | pass 33 | 34 | def get_current_player(self): 35 | return self.current_player 36 | 37 | 38 | class GameManager: 39 | """game manager""" 40 | 41 | def start_play(self, player1: Player, player2: Player, start_player=0, is_shown=True): 42 | """start a game between two players""" 43 | pass 44 | 45 | def start_self_play(self, player: Player, is_shown=False): 46 | """ start a self-play game using a MCTS player, reuse the search tree, 47 | and store the self-play
data: (state, mcts_probs, z) for training 48 | """ 49 | pass 50 | -------------------------------------------------------------------------------- /toolbox/utils/DefaultDict.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/3/11 5 | @description: null 6 | """ 7 | class DefaultDict(dict): 8 | def __init__(self, default_value): 9 | super(DefaultDict, self).__init__() 10 | self._default_value = default_value 11 | 12 | def __missing__(self, key): 13 | value = self._default_value(key) if callable(self._default_value) else self._default_value 14 | self[key] = value 15 | return value 16 | 17 | 18 | class LambdaSet(object): 19 | def __call__(self, key): 20 | return set() 21 | 22 | 23 | class LambdaDefaultDict(object): 24 | def __init__(self, value): 25 | self._value = value 26 | 27 | def __call__(self, key): 28 | return self._value() 29 | 30 | 31 | if __name__ == '__main__': 32 | import pickle 33 | import random 34 | 35 | sro_t = DefaultDict(DefaultDict(DefaultDict(LambdaSet()))) 36 | 37 | print(sro_t[0][0][0]) 38 | for i in range(10): 39 | for j in range(10): 40 | for k in range(10): 41 | sro_t[i][j][k] = {m for m in range(random.randint(3, 9))} 42 | print(sro_t[0][0][0]) 43 | 44 | with open("output/test.pkl", "wb") as f: 45 | pickle.dump(sro_t, f) 46 | 47 | with open("output/test.pkl", "rb") as f: 48 | another = pickle.load(f) 49 | for i in another: 50 | for j in another[i]: 51 | for k in another[i][j]: 52 | print(another[i][j][k]) 53 | for i in range(10, 12): 54 | for j in range(10, 12): 55 | for k in range(10, 12): 56 | print(another[i][j][k]) 57 | -------------------------------------------------------------------------------- /toolbox/nn/RESCAL.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/10/30 5 | @description: null 6 | """ 7 | import 
torch 8 | from torch import nn 9 | 10 | 11 | class CoreRESCAL(nn.Module): 12 | def __init__(self, entity_dim): 13 | super(CoreRESCAL, self).__init__() 14 | self.entity_dim = entity_dim 15 | 16 | def forward(self, h, r): 17 | h = h.view(-1, 1, self.entity_dim) 18 | r = r.view(-1, self.entity_dim, self.entity_dim) 19 | 20 | t = torch.bmm(h, r).view(-1, self.entity_dim) 21 | return t 22 | 23 | 24 | class RESCAL(nn.Module): 25 | def __init__(self, num_entities, num_relations, entity_dim, input_dropout=0.3): 26 | super(RESCAL, self).__init__() 27 | self.entity_dim = entity_dim 28 | 29 | self.E = nn.Embedding(num_entities, entity_dim) 30 | self.R = nn.Embedding(num_relations, entity_dim * entity_dim) 31 | 32 | self.core = CoreRESCAL(entity_dim) 33 | self.input_dropout = nn.Dropout(input_dropout) 34 | 35 | self.loss = nn.BCELoss() 36 | self.b = nn.Parameter(torch.zeros(num_entities)) 37 | 38 | def init(self): 39 | nn.init.kaiming_uniform_(self.E.weight.data) 40 | nn.init.kaiming_uniform_(self.R.weight.data) 41 | 42 | def forward(self, h_idx, r_idx): 43 | h = self.input_dropout(self.E(h_idx)) 44 | r = self.R(r_idx) 45 | 46 | t = self.core(h, r) 47 | t = t.view(-1, self.entity_dim) 48 | 49 | x = torch.mm(t, self.input_dropout(self.E.weight).transpose(1, 0)) 50 | x = x + self.b.expand_as(x) 51 | x = torch.sigmoid(x) 52 | return x 53 | -------------------------------------------------------------------------------- /toolbox/web/log_app/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/2/20 5 | @description: 监听远程日志服务器,显示前端 6 | ``` 7 | python toolbox/web/log_app/main.py 8 | python -m toolbox.web.log_app.main 9 | ``` 10 | """ 11 | import os 12 | import shutil 13 | 14 | import click 15 | 16 | from .app import start_app 17 | 18 | 19 | def create_settings(preference_dir_of_log_app="output"): 20 | pj_path = os.path.realpath('.') # user project path 21 | tools_path = 
os.path.realpath(__file__)[:-len("main.py")] # installed pkg path 22 | if not os.path.isdir(os.path.join(pj_path, preference_dir_of_log_app)): 23 | shutil.copytree(os.path.join(tools_path, "output"), os.path.join(pj_path, preference_dir_of_log_app)) 24 | elif not os.path.exists(os.path.join(pj_path, preference_dir_of_log_app, "default.cfg")): 25 | shutil.copy(os.path.join(tools_path, "output", "default.cfg"), os.path.join(pj_path, preference_dir_of_log_app)) 26 | 27 | 28 | @click.command() 29 | @click.option("--log_dir", type=str, default="output", help="app 配置文件所在文件夹.") 30 | @click.option("--log_config_name", type=str, default="default.cfg", help="启动 app 的配置文件。app 停机后会把运行期间被修改的配置保存到该文件,方便下次运行(这个设置在配置文件中关闭)。") 31 | @click.option("--ip", type=str, default="0.0.0.0", help="IP 地址") 32 | @click.option("--port", type=int, default=44666, help="端口。如果该端口不可用,会自动选一个可用的") 33 | @click.option("--standby_hours", type=int, default=24, help="空转小时数。如果超过这个时间没有任何操作,会自动停止运行,防止资源浪费") 34 | def main(log_dir, log_config_name, ip, port, standby_hours): 35 | log_dir = os.path.abspath(log_dir) 36 | create_settings(log_dir) 37 | start_app(log_dir, log_config_name, standby_hours, port, ip) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /toolbox/web/log_app/output/default.cfg: -------------------------------------------------------------------------------- 1 | [frontend_settings] 2 | # 以下的几个设置主要是用于控制前端的显示 3 | Ignore_null_value_when_filter=True 4 | Wrap_display=False 5 | Pagination=True 6 | Hide_hidden_columns_when_reorder=False 7 | # 前端的任何变动都不会尝试更新到服务器,即所有改动不会保存 8 | Offline=False 9 | # 是否保存本次前端页面的改动(包括删除,增加,column排序等)。在server关闭时和更改config时会判断 10 | Save_settings=True 11 | # row是否是可以通过拖拽交换的,如果可以交换则无法进行复制 12 | Reorderable_rows=False 13 | # 当选择revert代码时 revert到的路径: ../-revert 或 ../-revert- 14 | No_suffix_when_reset=True 15 | # 是否忽略掉filter_condition中的不存在对应key的log 16 | 
Ignore_filter_condition_not_exist_log=True 17 | 18 | [basic_settings] 19 | # 如果有内容长度超过这个值,在前端就会被用...替代。 20 | str_max_length=20 21 | # float的值保留几位小数 22 | round_to=6 23 | # 是否在表格中忽略不改变的column 24 | ignore_unchanged_columns=True 25 | 26 | [data_settings] 27 | # 在这里的log将不在前端显示出来,但是可以通过display点击出来。建议通过前端选择 28 | hidden_logs= 29 | # 在这里的log将在前端删除。建议通过前端选择 30 | deleted_logs= 31 | # 可以设置条件,只有满足以下条件的field才会被显示,请通过前端增加filter条件。 32 | filter_condition= 33 | # 默认加载 远程日志服务器 api 列表。 34 | # api 应该来自toolbox/web/log_server/serve_the_logs.py。 35 | # 必须带有 "http://" 36 | remote_log_servers=http://127.0.0.1:45666 37 | 38 | [column_settings] 39 | # 隐藏的column,建议通过前端选择 40 | hidden_columns= 41 | # 不需要显示的column,用逗号隔开,不要使用引号。需要将其从父节点一直写到它本身,比如排除meta中的fit_id, 写为meta-fit_id 42 | exclude_columns= 43 | # 允许编辑的column 44 | editable_columns=memo 45 | # column的显示顺序,强烈推荐不要手动更改 46 | column_order= 47 | 48 | [chart_settings] 49 | # 在走势图中,每个对象最多显示的点的数量,不要太大,否则前端可能会卡住 50 | max_points=200 51 | # 不需要在走势图中显示的column名称 52 | chart_exclude_columns= 53 | # 前端间隔秒多久尝试更新一次走势图,不要设置为太小。 54 | update_every=4 55 | # 如果前端超过max_no_updates次更新都没有获取到更新的数据,就停止刷新。如果evaluation的时间特别长,可能需要调大这个选项。 56 | max_no_updates=40 57 | 58 | [multi_chart_settings] 59 | # 最多支持可对比的log 60 | max_compare_metrics = 10 -------------------------------------------------------------------------------- /toolbox/data/ComplementaryDataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Set 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class ComplementaryTrainDataset(Dataset): 8 | """ 9 | 生成 补集划分 的数据集 10 | head0: Bx(T-1) 11 | rel0: Bx(T-1) 12 | tail0: Bx(T-1) 13 | head: Bx1 14 | rel: Bx1 15 | tail: Bx1 16 | """ 17 | 18 | def __init__(self, triples_ids: List[Tuple[int, int, int]]): 19 | self.triples_ids: List[Tuple[int, int, int]] = triples_ids 20 | self.triples_idx_set: Set[int] = set([i for i in range(len(triples_ids))]) 21 | self.triples = 
torch.LongTensor(triples_ids) 22 | 23 | def __len__(self): 24 | return len(self.triples_ids) 25 | 26 | def __getitem__(self, idx): 27 | context_idx = list(self.triples_idx_set.difference({idx})) 28 | sample_triples = self.triples[idx] 29 | context_triples = self.triples[context_idx] 30 | head0, rel0, tail0 = context_triples[:, 0], context_triples[:, 1], context_triples[:, 2] 31 | head, rel, tail = sample_triples[0], sample_triples[1], sample_triples[2] 32 | return head0, rel0, tail0, head, rel, tail 33 | 34 | 35 | class ComplementaryTestDataset(Dataset): 36 | """ 37 | 生成 补集划分 的数据集 38 | head: Bx1 39 | rel: Bx1 40 | tail: Bx1 41 | """ 42 | 43 | def __init__(self, triples_ids: List[Tuple[int, int, int]]): 44 | self.triples_ids: List[Tuple[int, int, int]] = triples_ids 45 | self.triples_idx_set: Set[int] = set([i for i in range(len(triples_ids))]) 46 | self.triples = torch.LongTensor(triples_ids) 47 | 48 | def __len__(self): 49 | return len(self.triples_ids) 50 | 51 | def __getitem__(self, idx): 52 | sample_triples = self.triples[idx] 53 | head, rel, tail = sample_triples[0], sample_triples[1], sample_triples[2] 54 | return head, rel, tail 55 | -------------------------------------------------------------------------------- /toolbox/README_template_en.md: -------------------------------------------------------------------------------- 1 | # KGE Toolbox 2 | 3 | ## Environment 4 | 5 | create a conda environment with `pytorch` `cython` and `scikit-learn` : 6 | ```shell 7 | conda create --name toolbox_env python=3.7 8 | source activate toolbox_env 9 | conda install --file requirements.txt -c pytorch 10 | ``` 11 | ## How to run 12 | 13 | ```shell 14 | python train.py --batch_size=512 --name=TryMyModel 15 | ``` 16 | 17 | ## Contributing to 18 | 19 | To contribute to , follow these steps: 20 | 21 | 1. Fork this repository. 22 | 2. Create a branch: `git checkout -b `. 23 | 3. Make your changes and commit them: `git commit -m ''` 24 | 4. 
Push to the original branch: `git push origin /` 25 | 5. Create the pull request. 26 | 27 | Alternatively see the GitHub documentation on [creating a pull request](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request). 28 | 29 | ## Contributors 30 | 31 | Thanks to the following people who have contributed to this project: 32 | 33 | * [@LinxueyuanStdio](https://github.com/LinxueyuanStdio) 📖 34 | * [@scottydocs](https://github.com/scottydocs) 📖 35 | * [@cainwatson](https://github.com/cainwatson) 🐛 36 | * [@calchuchesta](https://github.com/calchuchesta) 🐛 37 | 38 | You might want to consider using something like the [All Contributors](https://github.com/all-contributors/all-contributors) specification and its [emoji key](https://allcontributors.org/docs/en/emoji-key). 39 | 40 | ## Contact 41 | 42 | If you want to contact me you can reach me at . 43 | 44 | ## License 45 | 46 | 47 | This project uses the following license: [](). 48 | -------------------------------------------------------------------------------- /toolbox/nn/GCN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch_geometric.utils import degree 5 | from torch_sparse import spmm 6 | 7 | 8 | class GCN(nn.Module): 9 | def __init__(self, embedding_dim, out_dim): 10 | super(GCN, self).__init__() 11 | self.W = nn.Linear(embedding_dim, out_dim, bias=False) 12 | 13 | def forward(self, x, h, t): 14 | """ 15 | x: matrix of shape Exd 16 | h: vector of length T or matrix of shape Tx1 17 | t: vector of length T or matrix of shape Tx1 18 | 19 | PlainGCN(X) = sigma(D^(1/2) M D^(1/2) X) 20 | sigma is activation function 21 | M is adjacency matrix of shape ExE 22 | D is degree matrix of shape ExE 23 | X is embedding matrix of shape Exd 24 | """ 25 | deg = degree(h, x.size(0), dtype=x.dtype) 26 | deg_inv_sqrt = deg.pow(-0.5) 27 | norm = 
deg_inv_sqrt[t] * deg_inv_sqrt[h] 28 | 29 | # x = F.selu(spmm(torch.cat([t, h], dim= 0), norm, x.size(0), x.size(0), self.W(x))) 30 | x = F.relu(spmm(torch.cat([t, h], dim=0), norm, x.size(0), x.size(0), self.W(x))) 31 | return x 32 | 33 | 34 | class PlainGCN(nn.Module): 35 | def __init__(self): 36 | super(PlainGCN, self).__init__() 37 | 38 | def forward(self, x, h, t): 39 | """ 40 | x: matrix of shape Exd 41 | h: vector of length T or matrix of shape Tx1 42 | t: vector of length T or matrix of shape Tx1 43 | 44 | PlainGCN(X) = sigma(D^(1/2) M D^(1/2) X) 45 | sigma is activation function 46 | M is adjacency matrix of shape ExE 47 | D is degree matrix of shape ExE 48 | X is embedding matrix of shape Exd 49 | """ 50 | deg = degree(h, x.size(0), dtype=x.dtype) 51 | deg_inv_sqrt = deg.pow(-0.5) 52 | norm = deg_inv_sqrt[t] * deg_inv_sqrt[h] 53 | 54 | # x = F.selu(spmm(torch.cat([t, h], dim= 0), norm, x.size(0), x.size(0), x)) 55 | x = F.relu(spmm(torch.cat([t, h], dim=0), norm, x.size(0), x.size(0), x)) 56 | return x 57 | -------------------------------------------------------------------------------- /toolbox/web/log_app/static/js/chart.js: -------------------------------------------------------------------------------- 1 | function generate_range_modal(current_step, charts, ele, range_checked, ranges) { 2 | // 根据charts中的内容生成几个range 3 | var sliders = {}; 4 | 5 | if (!jQuery.isEmptyObject(charts)) { 6 | for (var key in charts) { 7 | if (!(key in range_checked)) { 8 | range_checked[key] = ''; 9 | } 10 | var checked = range_checked[key]; 11 | var _range = []; 12 | if (!(key in ranges)) { 13 | _range = [0, current_step]; 14 | } else { 15 | _range = ranges[key]; 16 | } 17 | var name = key; 18 | var id = key + '_range_bar'; 19 | var html = '
'; 20 | html += "
\n" + 21 | " \n" + 28 | "
"; 29 | html += ''; 30 | html += '
'; 31 | ele.append(html); 32 | var enabled = checked !== ''; 33 | sliders[key] = new Slider('#' + id, { 34 | max: current_step, 35 | value: _range, 36 | enabled: enabled, 37 | step: 50 38 | }); 39 | } 40 | } 41 | return sliders; 42 | } -------------------------------------------------------------------------------- /toolbox/nn/GAT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch_geometric.utils import softmax 5 | from torch_sparse import spmm 6 | 7 | 8 | class GAT(nn.Module): 9 | """ 10 | 第一种GAT,可以不用关系,只用实体对 11 | """ 12 | 13 | def __init__(self, hidden): 14 | super(GAT, self).__init__() 15 | self.a_i = nn.Linear(hidden, 1, bias=False) 16 | self.a_j = nn.Linear(hidden, 1, bias=False) 17 | 18 | def forward(self, x, h, t): 19 | """ 20 | x: matrix of shape Exd 21 | h: vector of length T or matrix of shape Tx1 22 | t: vector of length T or matrix of shape Tx1 23 | 24 | GAT(X) = AX 25 | A is self attention from X 26 | A is matrix of shape ExE 27 | X is embedding matrix of shape Exd 28 | """ 29 | e_i = self.a_i(x)[h].view(-1) 30 | e_j = self.a_j(x)[t].view(-1) 31 | e = e_i + e_j 32 | alpha = softmax(F.leaky_relu(e).float(), h) 33 | sparse_index = torch.cat([h.view(1, -1), t.view(1, -1)], dim=0) 34 | x = F.relu(spmm(sparse_index, alpha, x.size(0), x.size(0), x)) 35 | return x 36 | 37 | 38 | class GAT2(nn.Module): 39 | def __init__(self, hidden_dim): 40 | super(GAT2, self).__init__() 41 | self.a_i = nn.Linear(hidden_dim, 1, bias=False) 42 | self.a_j = nn.Linear(hidden_dim, 1, bias=False) 43 | self.a_k = nn.Linear(hidden_dim, 1, bias=False) 44 | 45 | def forward(self, E, R, T): 46 | """ 47 | E: 矩阵 |E| x d_e ,即实体数 x 嵌入维度 48 | R: 矩阵 |R| x d_r ,即关系数 x 嵌入维度 49 | T: 矩阵 |T| x 3, 即三元组数 x 3,每一行是[头实体索引,关系索引,尾实体索引] 50 | """ 51 | h = T[:, 0] 52 | r = T[:, 1] 53 | t = T[:, 2] 54 | e_i = self.a_i(E)[h].view(-1) 55 | e_j = self.a_j(E)[t].view(-1) 56 | r_k = 
self.a_j(R)[r].view(-1) 57 | e = e_i + r_k + e_j 58 | alpha = softmax(F.leaky_relu(e).float(), h) 59 | sparse_index = torch.cat([h.view(1, -1), t.view(1, -1)], dim=0) 60 | E = F.relu(spmm(sparse_index, alpha, E.size(0), E.size(0), E)) 61 | return E 62 | -------------------------------------------------------------------------------- /toolbox/evaluate/Leaderboard.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/3/23 5 | @description: null 6 | """ 7 | from typing import List 8 | 9 | 10 | def QueryEmbeddingLeaderboard(): 11 | header = ["Model", "1p", "2p", "3p", "2i", "3i", "pi", "ip", "2u", "up", "AVG"] 12 | FB15k = [ 13 | ["GQE", 53.9, 15.5, 11.1, 40.2, 52.4, 27.5, 19.4, 22.3, 11.7, 28.2], 14 | ["Q2B", 70.5, 23.0, 15.1, 61.2, 71.8, 41.8, 28.7, 37.7, 19.0, 40.1], 15 | ["BetaE", 65.1, 25.7, 24.7, 55.8, 66.5, 43.9, 28.1, 40.1, 25.2, 41.6], 16 | ["LogicE", 72.3, 29.8, 26.2, 56.1, 66.3, 42.7, 32.6, 43.4, 27.5, 44.1], 17 | ["ConE", 73.3, 33.8, 29.2, 64.4, 73.7, 50.9, 35.7, 55.7, 31.4, 49.8], 18 | ] 19 | FB237 = [ 20 | ["GQE", 35.2, 7.4, 5.5, 23.6, 35.7, 16.7, 10.9, 8.4, 5.8, 16.6], 21 | ["Q2B", 41.3, 9.9, 7.2, 31.1, 45.4, 21.9, 13.3, 11.9, 8.1, 21.1], 22 | ["BetaE", 39.0, 10.9, 10.0, 28.8, 42.5, 22.4, 12.6, 12.4, 9.7, 20.9], 23 | ["LogicE", 41.3, 11.8, 10.4, 31.4, 43.9, 23.8, 14.0, 13.4, 10.2, 22.3], 24 | ["ConE", 41.8, 12.8, 11.0, 32.6, 47.3, 25.5, 14.0, 14.5, 10.8, 23.4], 25 | ["BoolE", 43.3, 13.0, 11.0, 34.5, 48.0, 27.0, 16.7, 15.1, 11.2, 24.4], 26 | ] 27 | NELL = [ 28 | ["GQE", 33.1, 12.1, 9.9, 27.3, 35.1, 18.5, 14.5, 8.5, 9.0, 18.7], 29 | ["Q2B", 42.7, 14.5, 11.7, 34.7, 45.8, 23.2, 17.4, 12.0, 10.7, 23.6], 30 | ["BetaE", 53.0, 13.0, 11.4, 37.6, 47.5, 24.1, 14.3, 12.2, 8.5, 24.6], 31 | ["LogicE", 58.3, 17.7, 15.4, 40.5, 50.4, 27.3, 19.2, 15.9, 12.7, 28.6], 32 | ["ConE", 53.1, 16.1, 13.9, 40.0, 50.8, 26.3, 17.5, 15.3, 11.3, 27.2], 33 | ] 34 | return 
header, FB15k, FB237, NELL 35 | 36 | 37 | def append_to_QueryEmbeddingLeaderboard(name: str, FB15k_result: List[float], FB237_result: List[float], NELL_result: List[float]): 38 | header, FB15k, FB237, NELL = QueryEmbeddingLeaderboard() 39 | FB15k.append([name] + FB15k_result) 40 | FB237.append([name] + FB237_result) 41 | NELL.append([name] + NELL_result) 42 | return header, FB15k, FB237, NELL 43 | 44 | -------------------------------------------------------------------------------- /toolbox/exp/DistributeSchema.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, List 3 | 4 | import torch 5 | from torch import distributed as dist 6 | 7 | 8 | class DistributeSchema: 9 | 10 | def __init__(self, local_rank=-1, gpus: Optional[List[int]] = None): 11 | self.local_rank = local_rank 12 | self.gpus = gpus 13 | self.world_size = self.get_world_size() 14 | if self.world_size > 1 and not dist.is_initialized(): 15 | dist.init_process_group("nccl", init_method="env://") 16 | self.device = self.get_device() 17 | 18 | def reduce_sum(self, tensor): 19 | if self.get_world_size() > 1: 20 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 21 | 22 | def running_in_main_node(self) -> bool: 23 | return self.get_local_rank() == 0 24 | 25 | def get_local_rank(self): 26 | if self.local_rank == -1: 27 | if dist.is_initialized(): 28 | return dist.get_rank() 29 | if "RANK" in os.environ: 30 | return int(os.environ["RANK"]) 31 | return 0 32 | return self.local_rank 33 | 34 | def get_world_size(self): 35 | if dist.is_initialized(): 36 | return dist.get_world_size() 37 | if "WORLD_SIZE" in os.environ: 38 | return int(os.environ["WORLD_SIZE"]) 39 | return 1 40 | 41 | def synchronize(self): 42 | if self.get_world_size() > 1: 43 | dist.barrier() 44 | 45 | def get_device(self): 46 | if self.gpus: 47 | device = torch.device(self.gpus[self.get_local_rank()]) 48 | else: 49 | if torch.cuda.is_available(): 50 | device = 
torch.device("cuda") 51 | else: 52 | device = torch.device("cpu") 53 | return device 54 | 55 | def wrap_to_parallel_model(self, model): 56 | if self.world_size > 1: 57 | return torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.device]) 58 | return model 59 | 60 | def __repr__(self): 61 | return f"{self.__class__.__name__}(gpu={self.local_rank} in {self.gpus})" 62 | 63 | -------------------------------------------------------------------------------- /toolbox/nn/Complex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Complex(nn.Module): 6 | def __init__(self, num_entities, num_relations, embedding_dim, input_dropout_rate=0.2): 7 | super(Complex, self).__init__() 8 | self.num_entities = num_entities 9 | self.embedding_dim = embedding_dim 10 | self.E_real = nn.Embedding(num_entities, embedding_dim) 11 | self.E_img = nn.Embedding(num_entities, embedding_dim) 12 | self.R_real = nn.Embedding(num_relations, embedding_dim) 13 | self.R_img = nn.Embedding(num_relations, embedding_dim) 14 | self.dropout = nn.Dropout(input_dropout_rate) 15 | self.loss = nn.BCELoss() 16 | 17 | def init(self): 18 | nn.init.xavier_normal_(self.E_real.weight.data) 19 | nn.init.xavier_normal_(self.E_img.weight.data) 20 | nn.init.xavier_normal_(self.R_real.weight.data) 21 | nn.init.xavier_normal_(self.R_img.weight.data) 22 | 23 | def forward(self, e1, rel): 24 | e1_embedded_real = self.E_real(e1).view(-1, self.embedding_dim) 25 | rel_embedded_real = self.R_real(rel).view(-1, self.embedding_dim) 26 | e1_embedded_img = self.E_img(e1).view(-1, self.embedding_dim) 27 | rel_embedded_img = self.R_img(rel).view(-1, self.embedding_dim) 28 | 29 | e1_embedded_real = self.dropout(e1_embedded_real) 30 | rel_embedded_real = self.dropout(rel_embedded_real) 31 | e1_embedded_img = self.dropout(e1_embedded_img) 32 | rel_embedded_img = self.dropout(rel_embedded_img) 33 | 34 | # complex space bilinear 
product (equivalent to HolE) 35 | realrealreal = torch.mm(e1_embedded_real * rel_embedded_real, self.E_real.weight.transpose(1, 0)) 36 | realimgimg = torch.mm(e1_embedded_real * rel_embedded_img, self.E_img.weight.transpose(1, 0)) 37 | imgrealimg = torch.mm(e1_embedded_img * rel_embedded_real, self.E_img.weight.transpose(1, 0)) 38 | imgimgreal = torch.mm(e1_embedded_img * rel_embedded_img, self.E_real.weight.transpose(1, 0)) 39 | pred = realrealreal + realimgimg + imgrealimg - imgimgreal 40 | pred = torch.sigmoid(pred) 41 | 42 | return pred 43 | -------------------------------------------------------------------------------- /toolbox/nn/LorentzE.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/3/17 5 | @description: null 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class LorentzE(nn.Module): 12 | def __init__(self, num_entities, num_relations, embedding_dim, input_dropout_rate=0.2): 13 | super(LorentzE, self).__init__() 14 | self.num_entities = num_entities 15 | self.embedding_dim = embedding_dim 16 | 17 | self.E_ct = nn.Embedding(num_entities, embedding_dim) 18 | self.E_x = nn.Embedding(num_entities, embedding_dim) 19 | self.E_y = nn.Embedding(num_entities, embedding_dim) 20 | self.E_z = nn.Embedding(num_entities, embedding_dim) # E x d 21 | 22 | self.R_v1 = nn.Embedding(num_relations, embedding_dim) # R x d 23 | self.R_v2 = nn.Embedding(num_relations, embedding_dim) 24 | self.R_v3 = nn.Embedding(num_relations, embedding_dim) 25 | 26 | self.dropout = nn.Dropout(input_dropout_rate) 27 | self.loss = nn.BCELoss() 28 | 29 | def forward(self, h_idx, r_idx): 30 | # h_idx Bx1 31 | # r_idx Bx1 32 | h_ct = self.E_ct(h_idx).view(-1, self.embedding_dim) # Bxd 33 | h_x = self.E_x(h_idx).view(-1, self.embedding_dim) # Bxd 34 | h_y = self.E_y(h_idx).view(-1, self.embedding_dim) # Bxd 35 | h_z = self.E_z(h_idx).view(-1, self.embedding_dim) # Bxd 
36 | r_v1 = self.R_v1(h_idx).view(-1, self.embedding_dim) # Bxd 37 | r_v2 = self.R_v2(h_idx).view(-1, self.embedding_dim) # Bxd 38 | r_v3 = self.R_v3(h_idx).view(-1, self.embedding_dim) # Bxd 39 | # f : x -> [0, c] 40 | t_ct = 1 * h_ct + r_v1 * h_x + r_v2 * h_y + r_v3 * h_z 41 | t_x = h_x 42 | t_y = h_y 43 | t_z = h_z 44 | # 1 vs. N , N=E 45 | score_ct = torch.mm(t_ct, self.E_ct.weight.transpose(1, 0)) # Bxd, Exd -> BxE, 46 | score_x = torch.mm(t_x, self.E_x.weight.transpose(1, 0)) # Bxd, Exd -> BxE, 47 | score_y = torch.mm(t_y, self.E_y.weight.transpose(1, 0)) # Bxd, Exd -> BxE, 48 | score_z = torch.mm(t_z, self.E_z.weight.transpose(1, 0)) # Bxd, Exd -> BxE, 49 | score = (score_ct + score_x + score_y + score_z) / 4 50 | score = score.sigmoid() 51 | return score 52 | -------------------------------------------------------------------------------- /toolbox/web/log_app/line_app.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from flask import Blueprint 4 | from flask import render_template 5 | from flask import request 6 | 7 | from .server.data_container import all_data 8 | from .server.table_utils import expand_dict 9 | from .server.table_utils import generate_columns 10 | 11 | line_page = Blueprint('line_page', __name__, template_folder='templates') 12 | 13 | 14 | @line_page.route('/line', methods=['POST']) 15 | def line_index(): 16 | ids = request.values['ids'] 17 | # 取出所有的logs 18 | flat_logs = [all_data['data'][id].copy() for id in ids.split(',')] 19 | hidden_columns = all_data['hidden_columns'].copy() 20 | 21 | # 删除不是共有的部分 22 | value_dict_count = defaultdict(list) # 每个key有多少个 23 | for log in flat_logs: 24 | for key in log.keys(): 25 | value_dict_count[key] += [log[key]] 26 | for key, _lst in list(value_dict_count.items()): 27 | if len(_lst) != len(flat_logs): # 有些没有这个值 28 | for log in flat_logs: 29 | log.pop(key, None) 30 | if len(set(_lst)) == 1: # 只有一个值 31 | 
hidden_columns[key] = 1 32 | 33 | logs = [expand_dict([log])[0] for log in flat_logs] 34 | # column_order, column_dict, hidden_columns, settings, logs 35 | hidden_columns['id'] = 1 36 | hidden_columns['memo'] = 1 37 | hidden_columns['meta'] = 1 38 | res = generate_columns(logs, hidden_columns=hidden_columns, column_order=all_data['column_order'], editable_columns={}, 39 | exclude_columns={}, ignore_unchanged_columns=False, 40 | str_max_length=20, round_to=6, num_extra_log=0) 41 | 42 | column_order = res['column_order'] 43 | column_order.pop('id') 44 | column_order['OrderKeys'].remove('id') 45 | if 'metric' in column_order: # 将metric放在第一的位置 46 | column_order['OrderKeys'].remove('metric') 47 | column_order['OrderKeys'].insert(0, 'metric') 48 | column_dict = res['column_dict'] 49 | column_dict.pop('id') 50 | hidden_columns = res['hidden_columns'] 51 | data = res['data'] 52 | for key, log in data.items(): 53 | log.pop('id') 54 | 55 | return render_template('line.html', data=data, column_order=column_order, column_dict=column_dict, 56 | hidden_columns=hidden_columns) 57 | -------------------------------------------------------------------------------- /toolbox/nn/TuckERT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class TuckERT(nn.Module): 7 | def __init__(self, d, de, dr, dt, device="cpu", input_dropout=0., hidden_dropout1=0., hidden_dropout2=0., **kwargs): 8 | super(TuckERT, self).__init__() 9 | 10 | self.device = device 11 | 12 | # Embeddings dimensionality 13 | self.de = de 14 | self.dr = dr 15 | self.dt = dt 16 | 17 | # Data dimensionality 18 | self.ne = len(d.entities) 19 | self.nr = len(d.relations) 20 | self.nt = len(d.time) 21 | 22 | # Embedding matrices 23 | self.E = nn.Embedding(self.ne, de) 24 | self.R = nn.Embedding(self.nr, dr) 25 | self.T = nn.Embedding(self.nt, dt) 26 | 27 | # Core tensor 28 | self.W = 
nn.Parameter(torch.tensor(np.random.uniform(-0.1, 0.1, (dr, de, de, dt)), dtype=torch.float, device=self.device, requires_grad=True)) 29 | 30 | # "Specia"l layers 31 | self.input_dropout = nn.Dropout(input_dropout) 32 | self.hidden_dropout1 = nn.Dropout(hidden_dropout1) 33 | self.hidden_dropout2 = nn.Dropout(hidden_dropout2) 34 | self.loss = nn.BCELoss() 35 | 36 | self.bne = nn.BatchNorm1d(de) 37 | 38 | def init(self): 39 | nn.init.xavier_normal_(self.E.weight.data) 40 | nn.init.xavier_normal_(self.R.weight.data) 41 | nn.init.xavier_normal_(self.T.weight.data) 42 | 43 | def forward(self, e1_idx, r_idx, t_idx): 44 | # Mode 1 product with entity vector 45 | e1 = self.E(e1_idx) 46 | x = self.bne(e1) 47 | x = self.input_dropout(x) 48 | x = e1 49 | x = x.view(-1, 1, self.de) 50 | 51 | # Mode 2 product with relation vector 52 | r = self.R(r_idx) 53 | W_mat = torch.mm(r, self.W.view(r.size(1), -1)) 54 | W_mat = W_mat.view(-1, self.de, self.de * self.dt) 55 | x = torch.bmm(x, W_mat) 56 | 57 | # Mode 3 product with time vector 58 | t = self.T(t_idx) 59 | x = x.view(-1, self.de, self.dt) 60 | x = torch.bmm(x, t.view(*t.shape, -1)) 61 | 62 | # Mode 4 product with entity matrix 63 | x = x.view(-1, self.de) 64 | x = torch.mm(x, self.E.weight.transpose(1, 0)) 65 | 66 | pred = torch.sigmoid(x) 67 | return pred 68 | -------------------------------------------------------------------------------- /toolbox/data/FixWindowNegSamplingDataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Tuple, Set, Dict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from toolbox.data.functional import build_map_hr_t 9 | 10 | 11 | def get_neg_sampling_batch(entity_ids: Set[int], 12 | hr_t: Dict[Tuple[int, int], Set[int]], 13 | hr_pairs: List[Tuple[int, int]], 14 | idx: int, 15 | sampling_window_size=200): 16 | assert sampling_window_size in range(len(entity_ids)) 17 | 
assert idx in range(len(hr_pairs)) 18 | max_positive_count = sampling_window_size // 10 19 | batch = hr_pairs[idx] 20 | target_sim = np.zeros(sampling_window_size) 21 | target_ids = [] 22 | positive_ids: List[int] = list(hr_t[batch]) 23 | if len(positive_ids) > max_positive_count: 24 | positive_ids = random.choices(positive_ids, k=max_positive_count) 25 | positive_count = len(positive_ids) 26 | if positive_count >= sampling_window_size: 27 | ids = random.choices(positive_ids, k=sampling_window_size) 28 | target_ids.append(ids) 29 | target_sim[:] = 1. 30 | else: 31 | negative_count = sampling_window_size - positive_count 32 | negative_ids = list(entity_ids.difference(set(positive_ids))) 33 | negative_ids = random.choices(negative_ids, k=negative_count) 34 | ids = positive_ids + negative_ids 35 | target_ids.append(ids) 36 | target_sim[:positive_count] = 1. 37 | target_ids = torch.LongTensor(target_ids).view(-1) 38 | target_sim = torch.FloatTensor(target_sim).view(-1) 39 | return np.array(batch), target_ids, target_sim 40 | 41 | 42 | class FixedWindowNegSamplingDataset(Dataset): 43 | def __init__(self, train_triples_ids: List[Tuple[int, int, int]], entity_ids: List[int], sampling_window_size=200): 44 | self.hr_t = build_map_hr_t(train_triples_ids) 45 | self.hr_pairs = list(self.hr_t.keys()) 46 | self.sampling_window_size = sampling_window_size 47 | self.entity_ids: Set[int] = set(entity_ids) 48 | 49 | def __len__(self): 50 | return len(self.hr_pairs) 51 | 52 | def __getitem__(self, idx): 53 | return get_neg_sampling_batch(self.entity_ids, self.hr_t, self.hr_pairs, idx, self.sampling_window_size) 54 | -------------------------------------------------------------------------------- /toolbox/web/log_app/templates/folder.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Log Folder 6 | 7 | 8 | 9 | 26 | 31 | 32 | 33 |
34 |

Folder

35 |
36 |

Current Folder {{ id+ossep+subdir }}

37 |
38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | {% if subdir %} 48 | 49 | 50 | 51 | 52 | 53 | {% endif %} 54 | {% for i in contents %} 55 | 56 | 67 | 68 | 69 | 70 | {% endfor %} 71 | 72 |
File or DirectoryModify TimeSize
..{{ ossep }}
57 | {% if i.isfile %} 58 | {{ i.filename }} 60 | 61 | {% else %} 62 | {{ i.filename }} 64 | 65 | {% endif %} 66 | {{ i.mtime }}{{ i.size }}
73 |
74 |
75 | 76 | -------------------------------------------------------------------------------- /toolbox/exp/config/JsonConfig.py: -------------------------------------------------------------------------------- 1 | import json 2 | from shutil import copyfile 3 | from typing import Any, Dict 4 | 5 | 6 | class Config(object): 7 | """Config load from json file 8 | 9 | Examples: 10 | >>> from toolbox.exp.config.JsonConfig import Config 11 | >>> config = Config(config={"name":"hyperparams", "dataset":"FB15k"}, config_file="xxx.json") 12 | >>> # use it like a json object 13 | >>> # config.name 14 | >>> # config.dataset 15 | """ 16 | 17 | def __init__(self, config=None, config_file=None): 18 | if config: 19 | self._update(config) 20 | if config_file: 21 | with open(config_file, 'r') as fin: 22 | config = json.load(fin) 23 | self._update(config) 24 | 25 | def __getitem__(self, key): 26 | return self.__dict__[key] 27 | 28 | def __contains__(self, item): 29 | return item in self.__dict__ 30 | 31 | def items(self): 32 | return self.__dict__.items() 33 | 34 | def add(self, key, value): 35 | """Add key value pair 36 | """ 37 | self.__dict__[key] = value 38 | 39 | def _update(self, config: Dict[str, Any]): 40 | if not isinstance(config, dict): 41 | return 42 | 43 | for key in config: 44 | if isinstance(config[key], dict): 45 | config[key] = Config(config[key]) 46 | elif isinstance(config[key], list): 47 | config[key] = [Config(x) if isinstance(x, dict) else x for x in config[key]] 48 | 49 | self.__dict__.update(config) 50 | 51 | def save_as_json(self, dir_name: str, filename: str): 52 | if type(self.source) is list: 53 | for s in self.source: 54 | c = Config(s) 55 | c.save(dir_name) 56 | elif type(self.source) is dict: 57 | json.dumps(self.source, indent=4) 58 | else: 59 | copyfile(self.source, dir_name + filename) 60 | 61 | def show(self, fun=print): 62 | if type(self.source) is list: 63 | for s in self.source: 64 | c = Config(s) 65 | c.show(fun) 66 | elif type(self.source) is 
dict: 67 | fun(json.dumps(self.source)) 68 | else: 69 | with open(self.source) as f: 70 | fun(json.dumps(json.load(f), indent=4)) 71 | -------------------------------------------------------------------------------- /toolbox/nn/MobiusEmbedding.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from toolbox.nn.ComplexEmbedding import ComplexEmbedding, ComplexDropout, ComplexBatchNorm1d 7 | 8 | 9 | class MobiusEmbedding(nn.Module): 10 | def __init__(self, num_entities, embedding_dim, num_channels, norm_num_channels=2): 11 | super(MobiusEmbedding, self).__init__() 12 | self.num_entities = num_entities 13 | self.embedding_dim = embedding_dim 14 | self.num_channels = num_channels 15 | self.embeddings = nn.ModuleList([ComplexEmbedding(num_entities, embedding_dim, norm_num_channels) for _ in range(num_channels)]) 16 | 17 | def forward(self, idx): 18 | embedings = [] 19 | for embedding in self.embeddings: 20 | embedings.append(embedding(idx)) 21 | return tuple(embedings) 22 | 23 | def init(self): 24 | for embedding in self.embeddings: 25 | embedding.init() 26 | 27 | def get_embeddings(self): 28 | return [embedding.get_embeddings() for embedding in self.embeddings] 29 | 30 | def get_cat_embedding(self): 31 | return torch.cat([embedding.get_cat_embedding() for embedding in self.embeddings], 1) 32 | 33 | 34 | class MobiusDropout(nn.Module): 35 | def __init__(self, dropout_rate_list: List[List[float]]): 36 | super(MobiusDropout, self).__init__() 37 | self.dropout_rate_list = dropout_rate_list 38 | self.dropouts = nn.ModuleList([ComplexDropout(dropout_rate) for dropout_rate in dropout_rate_list]) 39 | 40 | def forward(self, complex_numbers): 41 | out = [] 42 | for idx, complex_number in enumerate(list(complex_numbers)): 43 | out.append(self.dropouts[idx](complex_number)) 44 | return tuple(out) 45 | 46 | 47 | class MobiusBatchNorm1d(nn.Module): 48 | def 
__init__(self, embedding_dim, num_channels, norm_num_channels=2): 49 | super(MobiusBatchNorm1d, self).__init__() 50 | self.embedding_dim = embedding_dim 51 | self.num_channels = num_channels 52 | self.batch_norms = nn.ModuleList([ComplexBatchNorm1d(embedding_dim, norm_num_channels) for _ in range(num_channels)]) 53 | 54 | def forward(self, complex_numbers): 55 | out = [] 56 | for idx, complex_number in enumerate(list(complex_numbers)): 57 | out.append(self.batch_norms[idx](complex_number)) 58 | return tuple(out) 59 | -------------------------------------------------------------------------------- /toolbox/utils/VisualizeStore.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/2/19 5 | @description: 可视化 6 | run the command below to open tensorbard 7 | ```shell 8 | tensorboard --logdir . 9 | ``` 10 | """ 11 | 12 | 13 | def get_writer(log_dir: str, comments=""): 14 | from torch.utils.tensorboard import SummaryWriter 15 | return SummaryWriter(log_dir, comments) 16 | 17 | 18 | def add_scalar(writer, name: str, value, step_num: int): 19 | writer.add_scalar(name, value, step_num) 20 | 21 | 22 | def add_result(writer, result, step_num: int): 23 | left2right = result["left2right"] 24 | right2left = result["right2left"] 25 | using_time = result["time"] 26 | sorted(left2right) 27 | sorted(right2left) 28 | for i in left2right: 29 | add_scalar(writer, i, left2right[i], step_num) 30 | for i in right2left: 31 | add_scalar(writer, i, right2left[i], step_num) 32 | add_scalar(writer, "using time (s)", using_time, step_num) 33 | 34 | 35 | class VisualizeStoreSchema: 36 | def __init__(self, log_dir: str, comments=""): 37 | self.writer = get_writer(log_dir, comments) 38 | print() 39 | print("Tensorboard is activated on dir " + log_dir) 40 | print("You can open tensorboard with:") 41 | print(" tensorboard --logdir " + log_dir + " --host=your_ip --port=6006") 42 | print() 43 | 
44 | def add_scalar(self, name: str, value, step_num: int): 45 | add_scalar(self.writer, name, value, step_num) 46 | 47 | def add_model(self, model): 48 | self.writer.add_graph(model) 49 | 50 | def add_embedding(self, embedding, labels=None, step_num: int = 0): 51 | # fix: 52 | # module 'tensorflow._api.v2.io.gfile' has no attribute 'get_filesystem' 53 | # see: 54 | # https://github.com/pytorch/pytorch/issues/30966 55 | import tensorflow as tf 56 | import tensorboard as tb 57 | tf.io.gfile = tb.compat.tensorflow_stub.io.gfile 58 | self.writer.add_embedding(embedding, metadata=labels, global_step=step_num) 59 | 60 | def add_result(self, result, step_num: int): 61 | add_result(self.writer, result, step_num) 62 | 63 | def add_link_prediction_result(self, result, step_num: int, scope: str): 64 | for key in result: 65 | for i in result[key]: 66 | self.add_scalar(f"{scope}_{key}_{i}", result[key][i], step_num) 67 | -------------------------------------------------------------------------------- /toolbox/nn/ConvE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class CoreConvE(nn.Module): 7 | def __init__(self, embedding_dim, img_h=10, input_dropout=0.2, hidden_dropout1=0.3, hidden_dropout2=0.2): 8 | super(CoreConvE, self).__init__() 9 | self.inp_drop = nn.Dropout(input_dropout) 10 | self.feature_map_drop = nn.Dropout2d(hidden_dropout1) 11 | self.hidden_drop = nn.Dropout(hidden_dropout2) 12 | 13 | self.img_h = img_h 14 | self.img_w = embedding_dim // self.img_h 15 | 16 | self.conv1 = nn.Conv2d(1, 32, (3, 3), 1, 0, bias=True) 17 | self.bn0 = nn.BatchNorm2d(1) 18 | self.bn1 = nn.BatchNorm2d(32) 19 | self.bn2 = nn.BatchNorm1d(embedding_dim) 20 | 21 | hidden_size = (self.img_h * 2 - 3 + 1) * (self.img_w - 3 + 1) * 32 22 | self.fc = nn.Linear(hidden_size, embedding_dim) 23 | 24 | def forward(self, h, r): 25 | h = h.view(-1, 1, self.img_h, self.img_w) 26 
| r = r.view(-1, 1, self.img_h, self.img_w) 27 | 28 | x = torch.cat([h, r], 2) 29 | x = self.bn0(x) 30 | x = self.inp_drop(x) 31 | x = self.conv1(x) 32 | x = self.bn1(x) 33 | x = F.relu(x) 34 | x = self.feature_map_drop(x) 35 | x = x.view(x.shape[0], -1) 36 | x = self.fc(x) 37 | x = self.hidden_drop(x) 38 | x = self.bn2(x) 39 | x = F.relu(x) 40 | return x 41 | 42 | 43 | class ConvE(nn.Module): 44 | def __init__(self, num_entities, num_relations, embedding_dim, hidden_dropout=0.3): 45 | super(ConvE, self).__init__() 46 | self.E = nn.Embedding(num_entities, embedding_dim) 47 | self.R = nn.Embedding(num_relations, embedding_dim) 48 | 49 | self.core = CoreConvE(embedding_dim) 50 | self.dropout = nn.Dropout(hidden_dropout) 51 | self.b = nn.Parameter(torch.zeros(num_entities)) 52 | self.m = nn.PReLU() 53 | 54 | self.loss = nn.BCELoss() 55 | 56 | def init(self): 57 | nn.init.xavier_normal_(self.E.weight.data) 58 | nn.init.xavier_normal_(self.R.weight.data) 59 | 60 | def forward(self, h, r): 61 | h = self.E(h) # Bxd 62 | r = self.R(r) # Bxd 63 | t = self.core(h, r) 64 | 65 | x = torch.mm(t, self.dropout(self.E.weight).transpose(1, 0)) 66 | x = x + self.b.expand_as(x) 67 | x = torch.sigmoid(x) 68 | return x # batch_size x E 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to 
inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/ -------------------------------------------------------------------------------- /toolbox/web/log_app/server/app_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import socket 3 | import threading 4 | import time 5 | from urllib import request as urequest 6 | 7 | 8 | def cmd_parser(): 9 | # 返回为app.py准备的command line parser 10 | parser = argparse.ArgumentParser(description="To display your experiment logs in html.") 11 | 12 | parser.add_argument('-d', '--log_dir', help='Where to read logs. This directory should include a lot of logs.', 13 | required=True, type=str) 14 | parser.add_argument('-l', '--log_config_name', 15 | help="Log config name. Will try to find it in {log_dir}/{log_config_name}. Default is " 16 | "default.cfg", 17 | required=False, 18 | type=str, default='default.cfg') 19 | parser.add_argument('-p', '--port', help='What port to use. 
Default 5000, but when it is blocked, pick 5001 ...', 20 | required=False, type=int, default=5000) 21 | 22 | return parser 23 | 24 | 25 | def get_usage_port(start_port): 26 | # 给定一个start_port, 依次累加直到找到一个可用的port 27 | while start_port < 65535: 28 | if net_is_used(start_port): 29 | start_port += 1 30 | else: 31 | return start_port 32 | 33 | 34 | def net_is_used(port, ip='0.0.0.0'): 35 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 36 | try: 37 | s.connect((ip, port)) 38 | s.shutdown(2) 39 | return True 40 | except: 41 | return False 42 | 43 | 44 | class ServerWatcher(threading.Thread): 45 | def __init__(self, LEAST_REQUEST_TIMESTAMP, port=5000): 46 | super().__init__() 47 | self.deque = LEAST_REQUEST_TIMESTAMP 48 | self._stop_flag = False 49 | self.port = port 50 | 51 | def set_server_wait_seconds(self, server_wait_seconds): 52 | self.server_wait_seconds = server_wait_seconds 53 | 54 | def run(self): 55 | while (time.time() - self.deque[0]) < self.server_wait_seconds and not self._stop_flag: 56 | time.sleep(1) 57 | print("This server is going to shut down.") 58 | try: 59 | if not self._stop_flag: # 不是手动关闭的 60 | req = urequest.Request(f'http://127.0.0.1:{self.port}/kill', headers={}, data=''.encode('utf-8')) 61 | page = urequest.urlopen(req).read().decode('utf-8') 62 | except Exception as e: 63 | import traceback 64 | traceback.print_exc() 65 | raise RuntimeError("Error occurred when try to automatically shut down server.") 66 | 67 | def stop(self): 68 | self._stop_flag = True 69 | -------------------------------------------------------------------------------- /toolbox/nn/TransE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class CoreTransE(nn.Module): 7 | def __init__(self): 8 | super(CoreTransE, self).__init__() 9 | 10 | def forward(self, h, r): 11 | x = h + r 12 | x = F.relu(x) 13 | return x 14 | 15 | 16 | class 
TransE(nn.Module): 17 | def __init__(self, num_entities, num_relations, embedding_dim, hidden_dropout=0.2): 18 | super(TransE, self).__init__() 19 | self.embedding_dim = embedding_dim 20 | self.E = nn.Embedding(num_entities, embedding_dim) 21 | self.R = nn.Embedding(num_relations, embedding_dim) 22 | 23 | self.core = CoreTransE() 24 | self.dropout = nn.Dropout(hidden_dropout) 25 | self.b = nn.Parameter(torch.zeros(num_entities)) 26 | self.m = nn.PReLU() 27 | 28 | self.loss = nn.BCELoss() 29 | 30 | def init(self): 31 | nn.init.xavier_normal_(self.E.weight.data) 32 | nn.init.xavier_normal_(self.R.weight.data) 33 | 34 | def forward(self, h_idx, r_idx): 35 | h = self.E(h_idx) # Bxd 36 | r = self.R(r_idx) # Bxd 37 | 38 | t = self.core(h, r) 39 | t = t.view(-1, self.embedding_dim) 40 | 41 | x = torch.mm(t, self.dropout(self.E.weight).transpose(1, 0)) 42 | x = x + self.b.expand_as(x) 43 | x = torch.sigmoid(x) 44 | return x # batch_size x E 45 | 46 | 47 | class ReverseTransE(nn.Module): 48 | def __init__(self, num_entities, num_relations, embedding_dim, hidden_dropout=0.2): 49 | super(ReverseTransE, self).__init__() 50 | self.embedding_dim = embedding_dim 51 | self.E = nn.Embedding(num_entities, embedding_dim) 52 | self.R = nn.Embedding(num_relations, embedding_dim) 53 | 54 | self.core = CoreTransE() 55 | self.dropout = nn.Dropout(hidden_dropout) 56 | self.b = nn.Parameter(torch.zeros(num_entities)) 57 | self.m = nn.PReLU() 58 | 59 | self.loss = nn.BCELoss() 60 | 61 | def init(self): 62 | nn.init.xavier_normal_(self.E.weight.data) 63 | nn.init.xavier_normal_(self.R.weight.data) 64 | 65 | def forward(self, h_idx, r_idx): 66 | h = self.E(h_idx.view(-1)) # Bxd 67 | 68 | R = torch.cat([self.R.weight, -self.R.weight], dim=0) 69 | r = torch.index_select(R, index=r_idx.view(-1), dim=0) # Bxd 70 | 71 | t = self.core(h, r) 72 | t = t.view(-1, self.embedding_dim) 73 | 74 | x = torch.mm(t, self.dropout(self.E.weight).transpose(1, 0)) 75 | x = x + self.b.expand_as(x) 76 | x = 
torch.sigmoid(x) 77 | return x # batch_size x E 78 | -------------------------------------------------------------------------------- /toolbox/nn/TuckER.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class CoreTuckER(nn.Module): 7 | def __init__(self, entity_dim, relation_dim, hidden_dropout1=0.4, hidden_dropout2=0.5): 8 | super(CoreTuckER, self).__init__() 9 | self.entity_dim = entity_dim 10 | self.relation_dim = relation_dim 11 | 12 | self.W = nn.Parameter(torch.FloatTensor(np.random.uniform(-0.01, 0.01, (relation_dim, entity_dim, entity_dim)))) 13 | 14 | self.hidden_dropout1 = nn.Dropout(hidden_dropout1) 15 | self.hidden_dropout2 = nn.Dropout(hidden_dropout2) 16 | 17 | self.bn0 = nn.BatchNorm1d(entity_dim) 18 | self.bn1 = nn.BatchNorm1d(entity_dim) 19 | 20 | self.m = nn.PReLU() 21 | 22 | def forward(self, h, r): 23 | h = self.bn0(h.view(-1, self.entity_dim)).view(-1, 1, self.entity_dim) 24 | 25 | W = self.W.view(self.relation_dim, -1) 26 | W = torch.mm(r.view(-1, self.relation_dim), W) 27 | W = W.view(-1, self.entity_dim, self.entity_dim) 28 | W = self.hidden_dropout1(W) 29 | 30 | t = torch.bmm(h, W) 31 | t = t.view(-1, self.entity_dim) 32 | t = self.bn1(t) 33 | t = self.hidden_dropout2(t) 34 | t = self.m(t) 35 | return t 36 | 37 | def w(self, h, r): 38 | h = torch.cat([h.transpose(1, 0).unsqueeze(dim=0)] * r.size(0), dim=0) # BxdxE 39 | 40 | W = self.W.view(self.relation_dim, -1) 41 | W = torch.mm(r.view(-1, self.relation_dim), W) 42 | W = W.view(-1, self.entity_dim, self.entity_dim) # Bxdxd 43 | W = self.hidden_dropout1(W) 44 | t = torch.bmm(W, h) # BxdxE 45 | return t 46 | 47 | 48 | class TuckER(nn.Module): 49 | def __init__(self, num_entities, num_relations, entity_dim, relation_dim, input_dropout=0.3, hidden_dropout=0.3, hidden_dropout2=0.3): 50 | super(TuckER, self).__init__() 51 | self.entity_dim = entity_dim 52 | self.relation_dim = 
relation_dim 53 | 54 | self.E = nn.Embedding(num_entities, entity_dim) 55 | self.R = nn.Embedding(num_relations, relation_dim) 56 | 57 | self.core = CoreTuckER(entity_dim, relation_dim, hidden_dropout, hidden_dropout2) 58 | self.input_dropout = nn.Dropout(input_dropout) 59 | 60 | self.loss = nn.BCELoss() 61 | self.b = nn.Parameter(torch.zeros(num_entities)) 62 | 63 | def init(self): 64 | nn.init.kaiming_uniform_(self.E.weight.data) 65 | nn.init.kaiming_uniform_(self.R.weight.data) 66 | 67 | def forward(self, h_idx, r_idx): 68 | h = self.input_dropout(self.E(h_idx)) 69 | r = self.R(r_idx) 70 | 71 | t = self.core(h, r) 72 | t = t.view(-1, self.entity_dim) 73 | 74 | x = torch.mm(t, self.input_dropout(self.E.weight).transpose(1, 0)) 75 | x = x + self.b.expand_as(x) 76 | x = torch.sigmoid(x) 77 | return x 78 | -------------------------------------------------------------------------------- /toolbox/nn/TuckERTNT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class TuckERTNT(nn.Module): 7 | def __init__(self, d, de, dr, dt, device="cpu", input_dropout=0., hidden_dropout1=0., hidden_dropout2=0., **kwargs): 8 | super(TuckERTNT, self).__init__() 9 | 10 | self.device = device 11 | 12 | # Embeddings dimensionality 13 | self.de = de 14 | self.dr = dr 15 | self.dt = dt 16 | 17 | # Data dimensionality 18 | self.ne = len(d.entities) 19 | self.nr = len(d.relations) 20 | self.nt = len(d.time) 21 | 22 | # Embedding matrices 23 | self.E = nn.Embedding(self.ne, de) 24 | self.R = nn.Embedding(self.nr, dr) 25 | self.T = nn.Embedding(self.nt, dt) 26 | 27 | # Core tensor 28 | self.W = nn.Parameter(torch.tensor(np.random.uniform(-0.1, 0.1, (dr, de, dt, de)), dtype=torch.float, device=self.device, requires_grad=True)) 29 | 30 | # "Special" Layers 31 | self.input_dropout = nn.Dropout(input_dropout) 32 | self.hidden_dropout1 = nn.Dropout(hidden_dropout1) 33 | self.hidden_dropout2 
= nn.Dropout(hidden_dropout2) 34 | self.loss = nn.BCELoss() 35 | 36 | self.bne = nn.BatchNorm1d(de) 37 | 38 | def init(self): 39 | nn.init.xavier_normal_(self.E.weight.data) 40 | nn.init.xavier_normal_(self.R.weight.data) 41 | nn.init.xavier_normal_(self.T.weight.data) 42 | 43 | def forward(self, e1_idx, r_idx, t_idx): 44 | ### Temporal part 45 | # Mode 1 product with entity vector 46 | e1 = self.E(e1_idx) 47 | x = self.bne(e1) 48 | x = self.input_dropout(x) 49 | x = x.view(-1, 1, self.de) # (B, 1, de) 50 | 51 | # Mode 2 product with relation vector 52 | r = self.R(r_idx) # (B, dr) 53 | W_mat = torch.mm(r, self.W.view(r.size(1), -1)) # (B, dr) * (dr, de*de*dt) = (B, de*de*dt) 54 | W_mat = W_mat.view(-1, self.de, self.de * self.dt) # (B, de, de*dt) 55 | x = torch.bmm(x, W_mat) # (B, 1, de) * (B, de, de*dt) = (B, 1, de*dt) 56 | 57 | # Mode 4 product with entity matrix 58 | x = x.view(-1, self.de) # (B, de*dt) -> (B*dt, de) 59 | x = torch.mm(x, self.E.weight.transpose(1, 0)) # (B*dt, de) * (E, de)^T = (B*dt, E) 60 | 61 | # Mode 3 product with time vector 62 | t = self.T(t_idx).view(-1, 1, self.dt) # (B, 1, dt) 63 | xt = x.view(-1, self.dt, self.ne) # (B, dt, E) 64 | xt = torch.bmm(t, xt) # (B, 1, dt) * (B, dt, E) -> (B, 1, E) 65 | xt = xt.view(-1, self.ne) # (B, E) 66 | 67 | ### Non temporal part 68 | # mode 3 product with identity matrix 69 | x = x.view(-1, self.dt) # (B*E, dt) 70 | x = torch.mm(x, torch.ones(self.dt).to(self.device).view(self.dt, 1)) # (B*E, dt) * (dt, 1) = (B*E, 1) 71 | x = x.view(-1, self.ne) # (B, E) 72 | 73 | # Sum of the 2 models 74 | x = x + xt 75 | 76 | # Turn results into "probabilities" 77 | pred = torch.sigmoid(x) 78 | return pred 79 | -------------------------------------------------------------------------------- /toolbox/data/LinkPredictDataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict, Set 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 
| 6 | 7 | class LinkPredictDataset(Dataset): 8 | def __init__(self, test_triples_ids: List[Tuple[int, int, int]], hr_t: Dict[Tuple[int, int], Set[int]], max_relation_id: int, entity_count: int): 9 | """ 10 | test_triples_ids: without reverse r 11 | hr_t: all hr->t, MUST with reverse r 12 | """ 13 | self.test_triples_ids = test_triples_ids 14 | self.hr_t = hr_t 15 | self.entity_count = entity_count 16 | self.max_relation_id = max_relation_id 17 | 18 | def __len__(self): 19 | return len(self.test_triples_ids) 20 | 21 | def __getitem__(self, idx): 22 | h, r, t = self.test_triples_ids[idx] 23 | reverse_r = r + self.max_relation_id 24 | 25 | mask_for_hr = torch.zeros(self.entity_count).long() 26 | mask_for_hr[list(self.hr_t[(h, r)])] = 1 27 | mask_for_hr[t] = 0 28 | 29 | mask_for_tReverser = torch.zeros(self.entity_count).long() 30 | mask_for_tReverser[list(self.hr_t[(t, reverse_r)])] = 1 31 | mask_for_tReverser[h] = 0 32 | 33 | h = torch.LongTensor([h]) 34 | r = torch.LongTensor([r]) 35 | t = torch.LongTensor([t]) 36 | reverse_r = torch.LongTensor([reverse_r]) 37 | 38 | return h, r, mask_for_hr, t, reverse_r, mask_for_tReverser 39 | 40 | 41 | class LinkPredictTypeConstraintDataset(Dataset): 42 | def __init__(self, 43 | test_triples_ids: List[Tuple[int, int, int]], 44 | r_t: Dict[int, Set[int]], 45 | hr_t: Dict[Tuple[int, int], Set[int]], 46 | max_relation_id: int, entity_count: int): 47 | """ 48 | test_triples_ids: without reverse r 49 | hr_t: all hr->t, MUST with reverse r 50 | r_t: all r->t, MUST with reverse r 51 | """ 52 | self.test_triples_ids = test_triples_ids 53 | self.r_t = r_t 54 | self.hr_t = hr_t 55 | self.entity_count = entity_count 56 | self.max_relation_id = max_relation_id 57 | self.all_entity_ids = set(range(self.entity_count)) 58 | 59 | def __len__(self): 60 | return len(self.test_triples_ids) 61 | 62 | def __getitem__(self, idx): 63 | h, r, t = self.test_triples_ids[idx] 64 | reverse_r = r + self.max_relation_id 65 | 66 | 
tail_type_constraint_mask_for_hr = torch.zeros(self.entity_count).long() 67 | valid = list({t} | (self.r_t[r] - self.hr_t[(h, r)])) 68 | tail_type_constraint_mask_for_hr[valid] = 1 69 | 70 | tail_type_constraint_for_tReverser = torch.zeros(self.entity_count).long() 71 | valid = list({h} | (self.r_t[reverse_r] - self.hr_t[(t, reverse_r)])) 72 | tail_type_constraint_for_tReverser[valid] = 1 73 | 74 | h = torch.LongTensor([h]) 75 | r = torch.LongTensor([r]) 76 | t = torch.LongTensor([t]) 77 | reverse_r = torch.LongTensor([reverse_r]) 78 | 79 | return h, r, tail_type_constraint_mask_for_hr, t, reverse_r, tail_type_constraint_for_tReverser 80 | -------------------------------------------------------------------------------- /toolbox/nn/TuckERTTR.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class TuckERTTR(nn.Module): 7 | def __init__(self, d, de, dr, dt, ranks, device='cpu', input_dropout=0., hidden_dropout1=0., hidden_dropout2=0., **kwargs): 8 | super(TuckERTTR, self).__init__() 9 | 10 | self.device = device 11 | 12 | # Embeddings dimensionality 13 | self.de = de 14 | self.dr = dr 15 | self.dt = dt 16 | 17 | # Data dimensionality 18 | self.ne = len(d.entities) 19 | self.nr = len(d.relations) 20 | self.nt = len(d.time) 21 | 22 | # Embedding matrices 23 | self.E = nn.Embedding(self.ne, de).to(self.device) 24 | self.R = nn.Embedding(self.nr, dr).to(self.device) 25 | self.T = nn.Embedding(self.nt, dt).to(self.device) 26 | 27 | # Size of Tensor Ring decompostion tensors 28 | ni = [self.dr, self.de, self.de, self.dt] 29 | if isinstance(ranks, int) or isinstance(ranks, np.int64): 30 | ranks = [ranks for _ in range(5)] 31 | elif isinstance(ranks, list) and len(ranks) == 5: 32 | pass 33 | else: 34 | raise TypeError('ranks must be int or list of len 5') 35 | 36 | # List of tensors of the TR 37 | self.Zlist = nn.ParameterList([ 38 | 
nn.Parameter(torch.tensor(np.random.uniform(-1e-1, 1e-1, (ranks[i], ni[i], ranks[i + 1])), dtype=torch.float, requires_grad=True).to(self.device)) 39 | for i in range(4) 40 | ]) 41 | 42 | # dropout Layers 43 | self.input_dropout = nn.Dropout(input_dropout) 44 | self.hidden_dropout1 = nn.Dropout(hidden_dropout1) 45 | self.hidden_dropout2 = nn.Dropout(hidden_dropout2) 46 | 47 | # batchnorm layers 48 | self.bne = nn.BatchNorm1d(de) 49 | 50 | # loss 51 | self.loss = nn.BCELoss() 52 | 53 | def init(self): 54 | nn.init.xavier_normal_(self.E.weight.data) 55 | nn.init.xavier_normal_(self.R.weight.data) 56 | nn.init.xavier_normal_(self.T.weight.data) 57 | 58 | def forward(self, e1_idx, r_idx, t_idx): 59 | 60 | e1 = self.E(e1_idx) 61 | r = self.R(r_idx) 62 | t = self.T(t_idx) 63 | 64 | # Recover core tensor from TR (compute the trace of hadamart (element wise) product of all tensors) 65 | W = torch.einsum('aib,bjc,ckd,dla->ijkl', list(self.Zlist)) 66 | W = W.view(self.dr, self.de, self.de, self.dt) 67 | 68 | # Mode 1 product with entity vector 69 | x = e1 70 | x = x.view(-1, 1, self.de) 71 | 72 | # Mode 2 product with relation vector 73 | W_mat = torch.mm(r, W.view(self.dr, -1)) 74 | W_mat = W_mat.view(-1, self.de, self.de * self.dt) 75 | x = torch.bmm(x, W_mat) 76 | 77 | # Mode 3 product with temporal vector 78 | x = x.view(-1, self.de, self.dt) 79 | x = torch.bmm(x, t.view(*t.shape, -1)) 80 | 81 | # Mode 4 product with entity matrix 82 | x = x.view(-1, self.de) 83 | x = torch.mm(x, self.E.weight.transpose(1, 0)) 84 | 85 | # Turn results into "probabilities" 86 | pred = torch.sigmoid(x) 87 | return pred 88 | -------------------------------------------------------------------------------- /toolbox/nn/Rotate3D.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/10/30 5 | @description: null 6 | """ 7 | import torch 8 | from torch import nn 9 | 10 | pi = 
3.14159262358979323846 11 | 12 | class CoreRotate3D(nn.Module): 13 | def __init__(self, entity_dim): 14 | super(CoreRotate3D, self).__init__() 15 | self.entity_dim = entity_dim 16 | 17 | def forward(self, h, r): 18 | h = h.view(-1, 1, self.entity_dim) 19 | r = r.view(-1, self.entity_dim, self.entity_dim) 20 | 21 | t = torch.bmm(h, r).view(-1, self.entity_dim) 22 | return t 23 | 24 | head_i, head_j, head_k = torch.chunk(head, 3, dim=2) 25 | beta_1, beta_2, theta, bias = torch.chunk(rel, 4, dim=2) 26 | tail_i, tail_j, tail_k = torch.chunk(tail, 3, dim=2) 27 | 28 | bias = torch.abs(bias) 29 | 30 | # Make phases of relations uniformly distributed in [-pi, pi] 31 | beta_1 = beta_1 / (self.embedding_range.item() / self.pi) 32 | beta_2 = beta_2 / (self.embedding_range.item() / self.pi) 33 | theta = theta / (self.embedding_range.item() / self.pi) 34 | cos_theta = torch.cos(theta) 35 | sin_theta = torch.sin(theta) 36 | 37 | # Obtain representation of the rotation axis 38 | rel_i = torch.cos(beta_1) 39 | rel_j = torch.sin(beta_1)*torch.cos(beta_2) 40 | rel_k = torch.sin(beta_1)*torch.sin(beta_2) 41 | 42 | C = rel_i*head_i + rel_j*head_j + rel_k*head_k 43 | C = C*(1-cos_theta) 44 | 45 | # Rotate the head entity 46 | new_head_i = head_i*cos_theta + C*rel_i + sin_theta*(rel_j*head_k-head_j*rel_k) 47 | new_head_j = head_j*cos_theta + C*rel_j - sin_theta*(rel_i*head_k-head_i*rel_k) 48 | new_head_k = head_k*cos_theta + C*rel_k + sin_theta*(rel_i*head_j-head_i*rel_j) 49 | 50 | score_i = new_head_i*bias - tail_i 51 | score_j = new_head_j*bias - tail_j 52 | score_k = new_head_k*bias - tail_k 53 | 54 | score = torch.stack([score_i, score_j, score_k], dim=0) 55 | score = score.norm(dim=0, p=self.p) 56 | score = self.gamma.item() - score.sum(dim=2) 57 | return score 58 | 59 | class Rotate3D(nn.Module): 60 | def __init__(self, num_entities, num_relations, entity_dim, input_dropout=0.3): 61 | super(Rotate3D, self).__init__() 62 | self.entity_dim = entity_dim 63 | 64 | self.E = 
nn.Embedding(num_entities, entity_dim) 65 | self.R = nn.Embedding(num_relations, entity_dim * entity_dim) 66 | 67 | self.core = CoreRotate3D(entity_dim) 68 | self.input_dropout = nn.Dropout(input_dropout) 69 | 70 | self.loss = nn.BCELoss() 71 | self.b = nn.Parameter(torch.zeros(num_entities)) 72 | 73 | def init(self): 74 | nn.init.kaiming_uniform_(self.E.weight.data) 75 | nn.init.kaiming_uniform_(self.R.weight.data) 76 | 77 | def forward(self, h_idx, r_idx): 78 | h = self.input_dropout(self.E(h_idx)) 79 | r = self.R(r_idx) 80 | 81 | t = self.core(h, r) 82 | t = t.view(-1, self.entity_dim) 83 | 84 | x = torch.mm(t, self.input_dropout(self.E.weight).transpose(1, 0)) 85 | x = x + self.b.expand_as(x) 86 | x = torch.sigmoid(x) 87 | return x 88 | -------------------------------------------------------------------------------- /toolbox/web/log_app/server/data_container.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | 4 | 5 | class HandlerWatcher(threading.Thread): 6 | """ 7 | 一个用于监控reader的类。 8 | 9 | """ 10 | 11 | def __init__(self): 12 | super().__init__() 13 | self.all_handlers = all_handlers 14 | self._stop_flag = False 15 | self._quit = True 16 | self._start = False 17 | 18 | def run(self): 19 | self._quit = False 20 | self._start = True 21 | while not self._stop_flag: 22 | if len(self.all_handlers) > 0: 23 | for _uuid in list(self.all_handlers.keys()): 24 | handler = self.all_handlers[_uuid] 25 | if handler.reader._quit: 26 | handler.reader.stop() 27 | handler = self.all_handlers.pop(_uuid) 28 | print(f"Delete handler {_uuid}") 29 | del handler 30 | time.sleep(0.5) 31 | # 删除所有的handler 32 | for _uuid in list(self.all_handlers.keys()): 33 | handler = self.all_handlers.pop(_uuid) 34 | if handler.reader._quit: 35 | handler.reader.stop() 36 | print(f"Delete handler {_uuid}") 37 | del handler 38 | 39 | self._quit = True 40 | 41 | def stop(self): 42 | self._stop_flag = True 43 | count = 0 44 | while 
not self._quit: 45 | time.sleep(0.6) 46 | if count > 3: 47 | raise RuntimeError("Some bug happens.") 48 | count += 1 49 | 50 | 51 | # singleton 52 | """ 53 | all_data包含以下的key: 54 | settings: {} 一级dict,包含了所有的frontend_settings中的内容,value全部是bool值 55 | basic_settings: {} 一级dict包含了config中basic_settings中的setting. 56 | hidden_rows: {} 一级dict key为隐藏的row的id 57 | deleted_rows: {} 一级dict key为删除的row的id 58 | filter_condition: {} 一级dict,expanded的key以及它等于的value. value可以为str或者list[str], list[str]表示满足任意条件 59 | 即可 60 | hidden_columns: {} 一级dict key为隐藏的column. 展平后的值 61 | exclude_columns: {} 一级dict,需要排除的column 62 | editable_columns: {} 一级dict,支持编辑的column名 63 | column_order: {} nested dict. 表示column的顺序的. 类似于{"meta":{"fit_id": xxx, ...}, "metric": {...}, "OrderKeys:["meta", "metric"]"} 64 | field_columns: {}, 一级dict,key是expanded后的且会显示在前端table的名称,比如hyper-lr, id 65 | column_dict: {}, 二级dict,第一级是展开的key, 比如hyper-lr; 第二级是{'title':, 'field':}等用于生成前端header的内容。 66 | chart_settings: {} 保存chart相关的设置,包含以下的内容 67 | chart_exclude_columns:{} 一级dict,需要排除的column名称 68 | max_points:int 前端每条线最多显示多少个点 69 | update_every: int, 隔多少秒update一次 70 | max_no_updates: int 多少次没有得到更新就认为已经停止了 71 | config: 读取的ConfigParser对象 72 | extra_data: {}, 第一层的key为前端增加的记录获取用户修改某条记录留下的修改记录; 第一层的value对应的是一个一级dict。 73 | root_log_dir: str, log文件夹的路径 74 | log_config_name: str, 75 | log_agent: LogAgent()对象 76 | port: int, port 77 | uuid: str, 这个server的uuid 78 | token: str,None 这个server的token,放访问路径上的 79 | data: {id1:{'id':id1, 'field1':xxx,}, id2:{}}, 所有的数据都在这个里面,这是一个一级dict. 
extra_data已经替换了里面的值 80 | """ 81 | all_data = {} 82 | all_handlers = {} # 存放的是key是一个特有的uuid, value是一个ChartStepLogHandler 83 | handler_watcher = HandlerWatcher() # 对all_handlers中的内容进行监控,如果发现太长时间没有update就把它移除,防止线程爆炸了 84 | -------------------------------------------------------------------------------- /toolbox/utils/LaTeX.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/12/9 5 | @description: 这里维护各种模型的复现结果。用户进行实验后,可以拉取实验结果,一键生成对应的 latex 表格,就不用手动抄写到论文中了 6 | https://pandas.pydata.org/docs/reference/api/pandas.io.formats.style.Styler.to_latex.html 7 | """ 8 | from typing import * 9 | 10 | import pandas as pd 11 | 12 | 13 | def QueryEmbeddingLeaderboard(): 14 | header = ["Model", "1p", "2p", "3p", "2i", "3i", "pi", "ip", "2u", "up", "AVG"] 15 | FB15k = [ 16 | ["GQE", 53.9, 15.5, 11.1, 40.2, 52.4, 27.5, 19.4, 22.3, 11.7, 28.2], 17 | ["Q2B", 70.5, 23.0, 15.1, 61.2, 71.8, 41.8, 28.7, 37.7, 19.0, 40.1], 18 | ["BetaE", 65.1, 25.7, 24.7, 55.8, 66.5, 43.9, 28.1, 40.1, 25.2, 41.6], 19 | ["LogicE", 72.3, 29.8, 26.2, 56.1, 66.3, 42.7, 32.6, 43.4, 27.5, 44.1], 20 | ["ConE", 73.3, 33.8, 29.2, 64.4, 73.7, 50.9, 35.7, 55.7, 31.4, 49.8], 21 | ] 22 | FB237 = [ 23 | ["GQE", 35.2, 7.4, 5.5, 23.6, 35.7, 16.7, 10.9, 8.4, 5.8, 16.6], 24 | ["Q2B", 41.3, 9.9, 7.2, 31.1, 45.4, 21.9, 13.3, 11.9, 8.1, 21.1], 25 | ["BetaE", 39.0, 10.9, 10.0, 28.8, 42.5, 22.4, 12.6, 12.4, 9.7, 20.9], 26 | ["LogicE", 41.3, 11.8, 10.4, 31.4, 43.9, 23.8, 14.0, 13.4, 10.2, 22.3], 27 | ["ConE", 41.8, 12.8, 11.0, 32.6, 47.3, 25.5, 14.0, 14.5, 10.8, 23.4], 28 | ["BoolE", 43.3, 13.0, 11.0, 34.5, 48.0, 27.0, 16.7, 15.1, 11.2, 24.4], 29 | ] 30 | NELL = [ 31 | ["GQE", 33.1, 12.1, 9.9, 27.3, 35.1, 18.5, 14.5, 8.5, 9.0, 18.7], 32 | ["Q2B", 42.7, 14.5, 11.7, 34.7, 45.8, 23.2, 17.4, 12.0, 10.7, 23.6], 33 | ["BetaE", 53.0, 13.0, 11.4, 37.6, 47.5, 24.1, 14.3, 12.2, 8.5, 24.6], 34 | ["LogicE", 58.3, 17.7, 15.4, 40.5, 
50.4, 27.3, 19.2, 15.9, 12.7, 28.6], 35 | ["ConE", 53.1, 16.1, 13.9, 40.0, 50.8, 26.3, 17.5, 15.3, 11.3, 27.2], 36 | ] 37 | return header, FB15k, FB237, NELL 38 | 39 | 40 | def append_to_QueryEmbeddingLeaderboard(name: str, FB15k_result: List[float], FB237_result: List[float], NELL_result: List[float]): 41 | header, FB15k, FB237, NELL = QueryEmbeddingLeaderboard() 42 | FB15k.append([name] + FB15k_result) 43 | FB237.append([name] + FB237_result) 44 | NELL.append([name] + NELL_result) 45 | return header, FB15k, FB237, NELL 46 | 47 | 48 | def QueryEmbeddingLeaderboard_to_latex_table(name: str, 49 | FB15k_result: List[float], 50 | FB237_result: List[float], 51 | NELL_result: List[float], 52 | output_filename: str = "table.tex"): 53 | "Dataset" 54 | 55 | 56 | def dataframe_to_latex_table(df: pd.DataFrame, output_filename: str = "table.tex"): 57 | columns = list(df.columns) 58 | df.columns = pd.MultiIndex.from_tuples([ 59 | ("Numeric", "Integers"), 60 | ("Numeric", "Floats"), 61 | ("Non-Numeric", "Strings") 62 | ]) 63 | df.index = pd.MultiIndex.from_tuples([ 64 | ("L0", "ix1"), ("L0", "ix2"), ("L1", "ix3") 65 | ]) 66 | s = df.style.highlight_max( 67 | props='cellcolor:[HTML]{FFFF00}; color:{red}; itshape:; bfseries:;' 68 | ) 69 | s.to_latex( 70 | column_format="rrrrr", position="h", position_float="centering", 71 | hrules=True, label="table:5", caption="Styled LaTeX Table", 72 | multirow_align="t", multicol_align="r" 73 | ) 74 | pass 75 | -------------------------------------------------------------------------------- /toolbox/utils/Download.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/3/17 5 | @description: null 6 | A rudimentary URL downloader (like wget or curl) to demonstrate Rich progress bars. 
7 | """ 8 | import os.path 9 | import signal 10 | import sys 11 | from concurrent.futures import ThreadPoolExecutor 12 | from functools import partial 13 | from threading import Event 14 | from typing import Iterable 15 | from urllib.request import urlopen 16 | 17 | from rich.progress import ( 18 | BarColumn, 19 | DownloadColumn, 20 | Progress, 21 | TaskID, 22 | TextColumn, 23 | TimeRemainingColumn, 24 | TransferSpeedColumn, 25 | ) 26 | 27 | 28 | class DownloadManager: 29 | def __init__(self): 30 | self.progress = Progress( 31 | TextColumn("[bold blue]{task.fields[filename]}", justify="right"), 32 | BarColumn(bar_width=None), 33 | "[progress.percentage]{task.percentage:>3.1f}%", 34 | "•", 35 | DownloadColumn(), 36 | "•", 37 | TransferSpeedColumn(), 38 | "•", 39 | TimeRemainingColumn(), 40 | ) 41 | 42 | self.done_event = Event() 43 | 44 | def handle_sigint(signum, frame): 45 | self.done_event.set() 46 | 47 | signal.signal(signal.SIGINT, handle_sigint) 48 | 49 | def copy_url(self, task_id: TaskID, url: str, path: str) -> None: 50 | """Copy data from a url to a local file.""" 51 | self.progress.console.log(f"Requesting {url}") 52 | response = urlopen(url) 53 | # This will break if the response doesn't contain content length 54 | self.progress.update(task_id, total=int(response.info()["Content-length"])) 55 | with open(path, "wb") as dest_file: 56 | self.progress.start_task(task_id) 57 | for data in iter(partial(response.read, 32768), b""): 58 | dest_file.write(data) 59 | self.progress.update(task_id, advance=len(data)) 60 | if self.done_event.is_set(): 61 | return 62 | self.progress.console.log(f"Downloaded {path}") 63 | 64 | def download(self, urls: Iterable[str], dest_dir: str): 65 | """Download multuple files to the given directory.""" 66 | 67 | with self.progress: 68 | with ThreadPoolExecutor(max_workers=4) as pool: 69 | for url in urls: 70 | filename = url.split("/")[-1] 71 | dest_path = os.path.join(dest_dir, filename) 72 | task_id = 
self.progress.add_task("download", filename=filename, start=False) 73 | pool.submit(self.copy_url, task_id, url, dest_path) 74 | 75 | def download_to_path(self, urls: Iterable[str], dest_path: str): 76 | """Download multuple files to the given directory.""" 77 | 78 | with self.progress: 79 | with ThreadPoolExecutor(max_workers=4) as pool: 80 | for url in urls: 81 | filename = url.split("/")[-1] 82 | task_id = self.progress.add_task("download", filename=filename, start=False) 83 | pool.submit(self.copy_url, task_id, url, dest_path) 84 | 85 | 86 | if __name__ == "__main__": 87 | # Try with https://releases.ubuntu.com/20.04/ubuntu-20.04.3-desktop-amd64.iso 88 | if sys.argv[1:]: 89 | DownloadManager().download(sys.argv[1:], "./") 90 | else: 91 | print("Usage:\n\tpython downloader.py URL1 URL2 URL3 (etc)") 92 | -------------------------------------------------------------------------------- /toolbox/web/log_app/multi_char_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | 4 | from flask import Blueprint 5 | from flask import render_template 6 | from flask import request, jsonify 7 | 8 | from .log_read import is_log_dir_has_step, is_log_record_finish 9 | from .server.chart_utils import MultiChartStepLogHandler 10 | from .server.data_container import all_data, all_handlers, handler_watcher 11 | from .server.utils import check_uuid 12 | 13 | multi_chart_page = Blueprint("multi_chart_page", __name__, template_folder='templates') 14 | 15 | 16 | @multi_chart_page.route('/multi_chart', methods=['POST']) 17 | def chart(): 18 | res = check_uuid(all_data['uuid'], request.values['uuid']) 19 | if res is not None: 20 | return jsonify(res) 21 | 22 | logs = request.values['ids'].split(',') # [] 23 | titles = request.values['titles'].split(',') # TODO 如何对比loss呢 24 | root_log_dir = all_data['root_log_dir'] 25 | 26 | # check是否具有metric的记录 27 | has_step_logs = [] 28 | finish_logs = [] 29 | for log in logs: 30 | 
full_log_path = os.path.join(root_log_dir, log) 31 | if is_log_dir_has_step(full_log_path, ('metric.log',)): 32 | has_step_logs.append(log) 33 | if is_log_record_finish(full_log_path): 34 | finish_logs.append(log) 35 | 36 | msg = '' 37 | results = {} 38 | multi_chart_uuid = str(uuid.uuid1()) 39 | find_titles = [] 40 | if len(has_step_logs) > 1: # 至少得有两个吧 41 | # 需要读取数据 42 | update_every = all_data['chart_settings']['update_every'] 43 | wait_seconds = update_every * 5 # 如果本来应该收到三次更新,但是却没有收到,则自动关闭 44 | max_no_updates = all_data['chart_settings']['max_no_updates'] 45 | 46 | handler = MultiChartStepLogHandler(root_log_dir, has_step_logs, multi_chart_uuid, 47 | titles=titles, round_to=all_data['basic_settings']['round_to'], 48 | wait_seconds=wait_seconds, max_no_updates=max_no_updates) 49 | results = handler.update_logs(handler_names=('metric',)) 50 | all_handlers[multi_chart_uuid] = handler 51 | if not handler_watcher._start: 52 | handler_watcher.start() 53 | 54 | for title in titles: 55 | if title in results: 56 | find_titles.append(title) 57 | if len(find_titles) == 0: 58 | msg = 'No log has step information.' 59 | 60 | results['update_every'] = update_every 61 | results['max_no_updates'] = max_no_updates 62 | results['multi_chart_uuid'] = multi_chart_uuid 63 | 64 | else: 65 | # 没有的话怎么办 66 | msg = 'Less than 2 logs have step information.' 
67 | 68 | return render_template('multi_chart.html', data=results, message=msg, multi_chart_uuid=multi_chart_uuid, 69 | titles=','.join(find_titles), logs=request.values['ids']) 70 | 71 | 72 | @multi_chart_page.route('/multi_chart/new_step', methods=['POST']) 73 | def chart_new_step(): 74 | # 获取某个log_dir的更新 75 | multi_chart_uuid = request.json['multi_chart_uuid'] 76 | 77 | try: 78 | results = {} 79 | if multi_chart_uuid in all_handlers: 80 | handler = all_handlers[multi_chart_uuid] 81 | results = handler.update_logs() 82 | print(results) 83 | return jsonify(data=results, status='success') 84 | except Exception as e: 85 | import traceback 86 | traceback.print_exc() 87 | return jsonify({'status': 'fail', 'message': f"Exception occurred in the server: {str(e)}."}) 88 | -------------------------------------------------------------------------------- /toolbox/utils/Log.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/2/19 5 | @description: 日志 6 | """ 7 | import logging 8 | 9 | # logging.basicConfig(format='%(message)s', level=logging.INFO) 10 | DEBUG_SUCCESS_NUM = 1001 11 | DEBUG_FAILED_NUM = 1002 12 | logging.addLevelName(DEBUG_SUCCESS_NUM, "SUCCESS") 13 | logging.addLevelName(DEBUG_FAILED_NUM, "FAILED") 14 | 15 | 16 | def debug_success(self, message, *args, **kws): 17 | if self.isEnabledFor(DEBUG_SUCCESS_NUM): 18 | self._log(DEBUG_SUCCESS_NUM, message, args, **kws) 19 | 20 | 21 | def debug_failed(self, message, *args, **kws): 22 | if self.isEnabledFor(DEBUG_FAILED_NUM): 23 | self._log(DEBUG_FAILED_NUM, message, args, **kws) 24 | 25 | 26 | logging.Logger.success = debug_success 27 | logging.Logger.failed = debug_failed 28 | 29 | 30 | class ColorFormatter(logging.Formatter): 31 | """Logging Formatter to add colors and count warning / errors""" 32 | 33 | blue = "\x1b[34m" 34 | cyan = "\x1b[36;1m" 35 | green = "\x1b[32;1m" 36 | orange = "\x1b[33;21m" 37 | 
grey = "\x1b[38;21m" 38 | yellow = "\x1b[33;21m" 39 | red = "\x1b[31;21m" 40 | bold_red = "\x1b[31;1m" 41 | reset = "\x1b[0m" 42 | 43 | time_prefix = "[%(asctime)s]" 44 | filename_prefix = " (%(filename)s:%(lineno)d) " 45 | msg = "%(message)s" 46 | 47 | prefix = orange + time_prefix + reset + grey + filename_prefix + reset 48 | 49 | FORMATS = { 50 | logging.DEBUG: prefix + blue + msg + reset, 51 | logging.INFO: prefix + cyan + msg + reset, 52 | logging.WARNING: prefix + yellow + msg + reset, 53 | logging.ERROR: prefix + red + msg + reset, 54 | logging.CRITICAL: prefix + bold_red + msg + reset, 55 | DEBUG_SUCCESS_NUM: prefix + green + msg + reset, 56 | DEBUG_FAILED_NUM: prefix + bold_red + msg + reset, 57 | } 58 | 59 | def format(self, record): 60 | log_fmt = self.FORMATS.get(record.levelno) 61 | formatter = logging.Formatter(log_fmt) 62 | return formatter.format(record) 63 | 64 | 65 | def Log(filename: str, name_scope="0", write_to_console=True): 66 | """Return instance of logger 统一的日志样式 67 | 68 | Examples: 69 | >>> from toolbox.utils.Log import Log 70 | >>> log = Log("./train.log") 71 | >>> log.debug("debug message") 72 | >>> log.info("info message") 73 | >>> log.warning("warning message") 74 | >>> log.error("error message") 75 | >>> log.critical("critical message") 76 | """ 77 | logger = logging.getLogger('log-%s' % name_scope) 78 | logger.setLevel(logging.DEBUG) 79 | 80 | file_handler = logging.FileHandler(filename) 81 | file_handler.setLevel(logging.DEBUG) 82 | file_handler.setFormatter(logging.Formatter('[%(asctime)s] p%(process)s (%(filename)s:%(lineno)d) - %(message)s', '%m-%d %H:%M:%S')) 83 | logger.addHandler(file_handler) 84 | 85 | if write_to_console: 86 | console_handler = logging.StreamHandler() 87 | console_handler.setLevel(logging.DEBUG) 88 | console_handler.setFormatter(ColorFormatter()) 89 | logger.addHandler(console_handler) 90 | 91 | return logger 92 | 93 | 94 | def log_result(logger, result): 95 | """ 96 | :param logger: from toolbox.utils.Log() 
97 | :param result: from toolbox.Evaluate.evaluate() 98 | """ 99 | from toolbox.evaluate.Evaluate import pretty_print 100 | pretty_print(result, logger.info) 101 | -------------------------------------------------------------------------------- /toolbox/nn/TuckERTTT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class TuckERTTT(nn.Module): 7 | def __init__(self, d, de, dr, dt, ranks, device='cpu', input_dropout=0., hidden_dropout1=0., hidden_dropout2=0., **kwargs): 8 | super(TuckERTTT, self).__init__() 9 | 10 | self.device = device 11 | 12 | # Embeddings dimensionality 13 | self.de = de 14 | self.dr = dr 15 | self.dt = dt 16 | 17 | # Data dimensionality 18 | self.ne = len(d.entities) 19 | self.nr = len(d.relations) 20 | self.nt = len(d.time) 21 | 22 | # Embedding matrices 23 | self.E = nn.Embedding(self.ne, de).to(self.device) 24 | self.R = nn.Embedding(self.nr, dr).to(self.device) 25 | self.T = nn.Embedding(self.nt, dt).to(self.device) 26 | 27 | ## Core Tensor 28 | # Size of Tensor Ring decompostion tensors 29 | ni = [self.dr, self.de, self.de, self.dt] 30 | if isinstance(ranks, int) or isinstance(ranks, np.int64): 31 | ranks = [ranks for _ in range(3)] 32 | elif isinstance(ranks, list) and len(ranks) == 3: 33 | pass 34 | else: 35 | raise TypeError('ranks must be int or list of len 3') 36 | 37 | list_tmp = [1] 38 | list_tmp.extend(ranks) 39 | list_tmp.append(1) 40 | ranks = list_tmp 41 | 42 | # List of tensors of the tensor train 43 | self.Glist = nn.ParameterList([ 44 | nn.Parameter(torch.tensor(np.random.uniform(-1e-1, 1e-1, (ranks[i], ni[i], ranks[i + 1])), dtype=torch.float, requires_grad=True).to(self.device)) 45 | for i in range(4) 46 | ]) 47 | 48 | # "Special" Layers 49 | self.input_dropout = nn.Dropout(input_dropout) 50 | self.hidden_dropout1 = nn.Dropout(hidden_dropout1) 51 | self.hidden_dropout2 = nn.Dropout(hidden_dropout2) 52 | self.loss 
= nn.BCELoss() 53 | 54 | self.bne = nn.BatchNorm1d(de) 55 | self.bnr = nn.BatchNorm1d(dr) 56 | self.bnt = nn.BatchNorm1d(dt) 57 | 58 | def init(self): 59 | nn.init.xavier_normal_(self.E.weight.data) 60 | nn.init.xavier_normal_(self.R.weight.data) 61 | nn.init.xavier_normal_(self.T.weight.data) 62 | 63 | def forward(self, e1_idx, r_idx, t_idx): 64 | 65 | e1 = self.E(e1_idx) 66 | r = self.R(r_idx) 67 | t = self.T(t_idx) 68 | 69 | # Product between embedding matrices and TT-cores 70 | # RG1 = torch.mm(r,self.Glist[0].view(-1,self.dr)) 71 | # G2E = torch.einsum('aib,ic->acb',(self.Glist[1],e1)) 72 | # G3E = torch.einsum('aib,ic->acb',(self.Glist[2],self.E)) 73 | # TG4 = torch.mm(t,self.Glist[3].view(-1,self.dr).transpose()) 74 | 75 | W = torch.einsum('i1,1j2,3k4,4l->ijkl', list(self.Glist)) 76 | W = W.view(self.dr, self.de, self.de, self.dt) 77 | 78 | # Mode 1 product with entity vector 79 | x = e1 80 | x = x.view(-1, 1, self.de) 81 | 82 | # Mode 2 product with relation vector 83 | W_mat = torch.mm(r, W.view(self.dr, -1)) 84 | W_mat = W_mat.view(-1, self.de, self.de * self.dt) 85 | x = torch.bmm(x, W_mat) 86 | 87 | # Mode 3 product with temporal vector 88 | t = self.T(t_idx) 89 | x = x.view(-1, self.de, self.dt) 90 | x = torch.bmm(x, t.view(*t.shape, -1)) 91 | 92 | # Mode 4 product with entity matrix 93 | x = x.view(-1, self.de) 94 | x = torch.mm(x, self.E.weight.transpose(1, 0)) 95 | 96 | # Turn results into "probabilities" 97 | pred = torch.sigmoid(x) 98 | return pred 99 | -------------------------------------------------------------------------------- /toolbox/nn/CoPER.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from operator import mul 3 | from typing import List 4 | 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class ContextualParameterGenerator(nn.Module): 10 | def __init__(self, feature_in_dim: int, shape: List[int]): 11 | super(ContextualParameterGenerator, 
self).__init__() 12 | self.feature_in_dim = feature_in_dim 13 | self.shape = shape 14 | self.feature_out_dim = reduce(mul, shape, 1) 15 | self.hidden_dim = 200 16 | self.generate = nn.Sequential( 17 | nn.Linear(self.feature_in_dim, self.hidden_dim), 18 | nn.Dropout(0.1), 19 | nn.ReLU(), 20 | nn.Linear(self.hidden_dim, self.feature_out_dim), 21 | nn.Dropout(0.1), 22 | nn.ReLU(), 23 | ) 24 | 25 | def forward(self, x): 26 | x = self.generate(x) 27 | return x.view(-1, *self.shape) 28 | 29 | 30 | class CoreCoPER(nn.Module): 31 | def __init__(self, 32 | feature_in_dim: int, 33 | feature_out_dim: int, 34 | generate_dim: int, 35 | hidden_dropout1: float = 0.2, 36 | hidden_dropout2: float = 0.2): 37 | super(CoreCoPER, self).__init__() 38 | self.feature_in_dim = feature_in_dim 39 | self.feature_out_dim = feature_out_dim 40 | self.generate_dim = generate_dim 41 | 42 | self.conv_in_height = 10 43 | self.conv_in_width = self.feature_in_dim // 10 44 | 45 | self.conv_filter_height = 3 46 | self.conv_filter_width = 3 47 | self.conv_num_channels = 32 48 | 49 | self.conv_out_height = self.conv_in_height - self.conv_filter_height + 1 50 | self.conv_out_width = self.conv_in_width - self.conv_filter_width + 1 51 | 52 | self.fc_input_dim = self.conv_out_height * self.conv_out_width * self.conv_num_channels 53 | 54 | self.generate_conv_weight = ContextualParameterGenerator(self.generate_dim, [self.conv_num_channels, 1, self.conv_filter_height, self.conv_filter_width]) # conv weight 55 | self.generate_conv_bias = ContextualParameterGenerator(self.generate_dim, [self.conv_num_channels]) # conv bias 56 | self.generate_fc_weight = ContextualParameterGenerator(self.generate_dim, [self.fc_input_dim, self.feature_out_dim]) # fc weight 57 | self.generate_fc_bias = ContextualParameterGenerator(self.generate_dim, [1, self.feature_out_dim]) # fc bias 58 | 59 | self.hidden_dropout1 = nn.Dropout(hidden_dropout1) 60 | self.hidden_dropout2 = nn.Dropout(hidden_dropout2) 61 | 62 | self.bn1 = 
nn.BatchNorm1d(self.feature_out_dim) 63 | 64 | self.m = nn.PReLU() 65 | 66 | def forward(self, input_embedding, generate_embedding): 67 | """ 68 | input_embedding: batch_size x self.feature_in_dim 69 | generate_embedding: batch_size x self.generate_dim 70 | """ 71 | img = input_embedding.view(1, -1, self.conv_in_height, self.conv_in_width) 72 | 73 | r = generate_embedding 74 | batch_size = generate_embedding.size(0) 75 | 76 | conv_weight = self.generate_conv_weight(r).view(-1, 1, self.conv_filter_height, self.conv_filter_width) 77 | conv_bias = self.generate_conv_bias(r).view(-1) 78 | fc_weight = self.generate_fc_weight(r) 79 | fc_bias = self.generate_fc_bias(r) 80 | 81 | x = F.conv2d(img, conv_weight, bias=conv_bias, groups=batch_size) 82 | x = F.relu(x) 83 | x = self.hidden_dropout1(x) 84 | x = x.view(-1, 1, self.fc_input_dim) 85 | 86 | x = x.bmm(fc_weight) + fc_bias 87 | x = x.view(-1, self.feature_out_dim) 88 | x = self.bn1(x) 89 | x = self.hidden_dropout2(x) 90 | x = self.m(x) 91 | return x 92 | -------------------------------------------------------------------------------- /toolbox/nn/TuckERCPD.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class TuckERCPD(torch.nn.Module): 7 | def __init__(self, d, de, dr, dt, device="cpu", **kwargs): 8 | super(TuckERCPD, self).__init__() 9 | 10 | self.device = device 11 | 12 | # Embeddings dimensionality 13 | self.de = de 14 | self.dr = dr 15 | self.dt = dt 16 | 17 | # Data dimensionality 18 | self.ne = len(d.entities) 19 | self.nr = len(d.relations) 20 | self.nt = len(d.time) 21 | 22 | # CPD rank 23 | self.p = min(self.de, self.dt, self.dr) 24 | 25 | # Embedding matrices 26 | self.E = torch.nn.Embedding(self.ne, de).to(self.device) 27 | self.R = torch.nn.Embedding(self.nr, dr).to(self.device) 28 | self.T = torch.nn.Embedding(self.nt, dt).to(self.device) 29 | 30 | ### CPD Decomp of core tensor of Tucker 31 | 
# Core identity tensor of CPD 32 | self.I = torch.zeros(*[self.p for _ in range(4)]).to(self.device) 33 | for i in range(self.p): 34 | self.I[i, i, i, i] = 1 35 | 36 | # Factors of CPD 37 | self.Flist = torch.nn.ParameterList([ 38 | torch.nn.Parameter(torch.tensor(np.random.uniform(-1e-1, 1e-1, (d, self.p)), dtype=torch.float, requires_grad=True).to(self.device)) 39 | for d in [self.dr, self.de, self.de, self.dt] 40 | ]) 41 | 42 | # Dropout Layers 43 | self.input_dropout = torch.nn.Dropout(kwargs["input_dropout"]) 44 | self.hidden_dropout1 = torch.nn.Dropout(kwargs["hidden_dropout1"]) 45 | self.hidden_dropout2 = torch.nn.Dropout(kwargs["hidden_dropout2"]) 46 | self.hidden_dropout3 = torch.nn.Dropout(kwargs["hidden_dropout3"]) 47 | 48 | # batchnorm layers 49 | self.bne = torch.nn.BatchNorm1d(self.de) 50 | self.bnp3 = torch.nn.BatchNorm1d(self.p ** 3) 51 | self.bnp2 = torch.nn.BatchNorm1d(self.p ** 2) 52 | self.bnp = torch.nn.BatchNorm1d(self.p) 53 | 54 | # Loss 55 | self.loss = torch.nn.BCELoss() 56 | 57 | def init(self): 58 | nn.init.xavier_normal_(self.E.weight.data) 59 | nn.init.xavier_normal_(self.R.weight.data) 60 | nn.init.xavier_normal_(self.T.weight.data) 61 | 62 | for i in range(len(self.Flist)): 63 | nn.init.xavier_normal_(self.Flist[i]) 64 | 65 | def forward(self, e1_idx, r_idx, t_idx): 66 | 67 | # Select corresponding embeddings 68 | e1 = self.E(e1_idx) 69 | e1 = self.bne(e1) 70 | e1 = self.input_dropout(e1) 71 | 72 | r = self.R(r_idx) 73 | t = self.T(t_idx) 74 | 75 | # Compute intermediate factor matrices from embeddings and factor of core tensor 76 | 77 | fr = torch.mm(r, self.Flist[0]) 78 | fe1 = torch.mm(e1, self.Flist[1]) 79 | FE = torch.mm(self.E.weight, self.Flist[2]) 80 | ft = torch.mm(t, self.Flist[3]) 81 | 82 | ### Recover tensor 83 | 84 | # Mode 1 product with intermediate relaton vecot 85 | x = torch.mm(fr, self.I.view(self.p, -1)) 86 | x = self.bnp3(x) 87 | x = self.hidden_dropout1(x) 88 | 89 | x = x.view(-1, self.p, self.p ** 2) 90 | x = 
torch.bmm(fe1.view(-1, 1, self.p), x).view(-1, self.p ** 2) 91 | x = self.bnp2(x) 92 | x = self.hidden_dropout2(x) 93 | 94 | x = x.view(-1, self.p, self.p) 95 | x = torch.bmm(x, ft.view(*ft.shape, -1)).view(-1, self.p) 96 | x = self.bnp(x) 97 | x = self.hidden_dropout3(x) 98 | 99 | x = torch.mm(x, FE.transpose(1, 0)) 100 | 101 | # Turn results into "probabilities" 102 | pred = torch.sigmoid(x) 103 | return pred 104 | -------------------------------------------------------------------------------- /toolbox/exp/Experiment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Optional, List 3 | 4 | from toolbox.exp.DistributeSchema import DistributeSchema 5 | from toolbox.exp.OutputSchema import OutputSchema 6 | from toolbox.utils.LaTeXSotre import EvaluateLaTeXStoreSchema 7 | from toolbox.utils.MetricLogStore import MetricLogStoreSchema 8 | from toolbox.utils.ModelParamStore import ModelParamStoreSchema 9 | from toolbox.utils.VisualizeStore import VisualizeStoreSchema 10 | 11 | 12 | class Experiment: 13 | 14 | def __init__(self, output: OutputSchema, local_rank: int = -1, gpus: Optional[List[int]] = None): 15 | self.output = output 16 | self.distribute = DistributeSchema(local_rank, gpus) 17 | self.debug = self.log_in_main_node(output.logger.debug) 18 | self.log = self.log_in_main_node(output.logger.info) 19 | self.warn = self.log_in_main_node(output.logger.warn) 20 | self.error = self.log_in_main_node(output.logger.error) 21 | self.critical = self.log_in_main_node(output.logger.critical) 22 | self.success = self.log_in_main_node(output.logger.success) 23 | self.fail = self.log_in_main_node(output.logger.failed) 24 | self.visualize_store = VisualizeStoreSchema(str(output.pathSchema.dir_path_visualize)) 25 | self.model_param_store = ModelParamStoreSchema(output.pathSchema) 26 | self.metric_log_store = MetricLogStoreSchema(str(output.pathSchema.dir_path_log)) 27 | self.latex_store = 
EvaluateLaTeXStoreSchema(output.pathSchema) 28 | 29 | def re_init(self, output: OutputSchema, local_rank: int = -1, gpus: Optional[List[int]] = None): 30 | self.output = output 31 | self.distribute = DistributeSchema(local_rank, gpus) 32 | self.debug = self.log_in_main_node(output.logger.debug) 33 | self.log = self.log_in_main_node(output.logger.info) 34 | self.warn = self.log_in_main_node(output.logger.warn) 35 | self.error = self.log_in_main_node(output.logger.error) 36 | self.critical = self.log_in_main_node(output.logger.critical) 37 | self.success = self.log_in_main_node(output.logger.success) 38 | self.fail = self.log_in_main_node(output.logger.failed) 39 | self.visualize_store = VisualizeStoreSchema(str(output.pathSchema.dir_path_visualize)) 40 | self.model_param_store = ModelParamStoreSchema(output.pathSchema) 41 | self.metric_log_store = MetricLogStoreSchema(str(output.pathSchema.dir_path_log)) 42 | self.latex_store = EvaluateLaTeXStoreSchema(output.pathSchema) 43 | 44 | def log_in_main_node(self, log_func): 45 | if self.distribute.running_in_main_node(): 46 | return log_func 47 | return lambda x: [x] 48 | 49 | def dump_model(self, model): 50 | self.debug("Model Structure".center(50, "-")) 51 | self.debug(model) 52 | if hasattr(model, "print_hyperparams"): 53 | self.debug("Model Hyperparams".center(50, "-")) 54 | model.print_hyperparams(self.debug) 55 | self.debug("Trainable parameters".center(50, "-")) 56 | num_params = 0 57 | for name, param in model.named_parameters(): 58 | if param.requires_grad: 59 | ps = np.prod(param.size()) 60 | num_params += ps 61 | self.debug(f"{name}: {sizeof_fmt(ps)}") 62 | self.log('Total Parameters: %s (%d)' % (sizeof_fmt(num_params), num_params)) 63 | self.debug("-" * 50) 64 | 65 | def __repr__(self): 66 | return f"{self.__class__.__name__}({self.output}, {self.distribute})" 67 | 68 | 69 | def sizeof_fmt(num, suffix='B'): 70 | for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: 71 | if abs(num) < 1024.0: 72 | return 
"%3.1f%s%s" % (num, unit, suffix) 73 | num /= 1024.0 74 | return "%.1f%s%s" % (num, 'Yi', suffix) 75 | -------------------------------------------------------------------------------- /toolbox/web/log_app/static/js/bootstrap-table-reorder-rows.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @author: Dennis Hernández 3 | * @webSite: http://djhvscf.github.io/Blog 4 | * @version: v1.0.1 5 | */ 6 | 7 | (function ($) { 8 | 9 | 'use strict'; 10 | 11 | var isSearch = false; 12 | 13 | var rowAttr = function (row, index) { 14 | return { 15 | id: 'customId_' + index 16 | }; 17 | }; 18 | 19 | $.extend($.fn.bootstrapTable.defaults, { 20 | reorderableRows: false, 21 | onDragStyle: null, 22 | onDropStyle: null, 23 | onDragClass: "reorder_rows_onDragClass", 24 | dragHandle: null, 25 | useRowAttrFunc: false, 26 | onReorderRowsDrag: function (table, row) { 27 | return false; 28 | }, 29 | onReorderRowsDrop: function (table, row) { 30 | return false; 31 | }, 32 | onReorderRow: function (newData) { 33 | return false; 34 | } 35 | }); 36 | 37 | $.extend($.fn.bootstrapTable.Constructor.EVENTS, { 38 | 'reorder-row.bs.table': 'onReorderRow' 39 | }); 40 | 41 | var BootstrapTable = $.fn.bootstrapTable.Constructor, 42 | _init = BootstrapTable.prototype.init, 43 | _initSearch = BootstrapTable.prototype.initSearch; 44 | 45 | BootstrapTable.prototype.init = function () { 46 | 47 | if (!this.options.reorderableRows) { 48 | _init.apply(this, Array.prototype.slice.apply(arguments)); 49 | return; 50 | } 51 | 52 | var that = this; 53 | if (this.options.useRowAttrFunc) { 54 | this.options.rowAttributes = rowAttr; 55 | } 56 | 57 | var onPostBody = this.options.onPostBody; 58 | this.options.onPostBody = function () { 59 | setTimeout(function () { 60 | that.makeRowsReorderable(); 61 | onPostBody.apply(); 62 | }, 1); 63 | }; 64 | 65 | _init.apply(this, Array.prototype.slice.apply(arguments)); 66 | }; 67 | 68 | BootstrapTable.prototype.initSearch = 
function () { 69 | _initSearch.apply(this, Array.prototype.slice.apply(arguments)); 70 | 71 | if (!this.options.reorderableRows) { 72 | return; 73 | } 74 | 75 | //Known issue after search if you reorder the rows the data is not display properly 76 | //isSearch = true; 77 | }; 78 | 79 | BootstrapTable.prototype.makeRowsReorderable = function () { 80 | if (this.options.cardView) { 81 | return; 82 | } 83 | 84 | var that = this; 85 | this.$el.tableDnD({ 86 | onDragStyle: that.options.onDragStyle, 87 | onDropStyle: that.options.onDropStyle, 88 | onDragClass: that.options.onDragClass, 89 | onDrop: that.onDrop, 90 | onDragStart: that.options.onReorderRowsDrag, 91 | dragHandle: that.options.dragHandle 92 | }); 93 | }; 94 | 95 | BootstrapTable.prototype.onDrop = function (table, droppedRow) { 96 | var tableBs = $(table), 97 | tableBsData = tableBs.data('bootstrap.table'), 98 | tableBsOptions = tableBs.data('bootstrap.table').options, 99 | row = null, 100 | newData = []; 101 | 102 | for (var i = 0; i < table.tBodies[0].rows.length; i++) { 103 | row = $(table.tBodies[0].rows[i]); 104 | newData.push(tableBsOptions.data[row.data('index')]); 105 | row.data('index', i).attr('data-index', i); 106 | } 107 | 108 | tableBsOptions.data = tableBsOptions.data.slice(0, tableBsData.pageFrom - 1) 109 | .concat(newData) 110 | .concat(tableBsOptions.data.slice(tableBsData.pageTo)); 111 | 112 | //Call the user defined function 113 | tableBsOptions.onReorderRowsDrop.apply(table, [table, droppedRow]); 114 | 115 | //Call the event reorder-row 116 | tableBsData.trigger('reorder-row', newData); 117 | }; 118 | })(jQuery); -------------------------------------------------------------------------------- /toolbox/nn/BlaschkE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from toolbox.nn.ComplexEmbedding import ComplexEmbedding, ComplexDropout, ComplexScoringAll, 
ComplexBatchNorm1d, ComplexMult, ComplexAdd, ComplexDiv 6 | 7 | 8 | class BlaschkeMult(nn.Module): 9 | """ 10 | h - r 11 | h * r = -------------- 12 | r_conj * h - 1 13 | h in C^d 14 | r in C^d 15 | """ 16 | 17 | def __init__(self, norm_flag=False): 18 | super(BlaschkeMult, self).__init__() 19 | self.flag_hamilton_mul_norm = norm_flag 20 | self.complex_mul = ComplexMult(norm_flag) 21 | self.complex_add = ComplexAdd() 22 | self.complex_div = ComplexDiv() 23 | 24 | def forward(self, h, r): 25 | r_a, r_b = r 26 | # h_a, h_b = h 27 | # r = F.normalize(torch.cat([r_a.unsqueeze(dim=1), r_b.unsqueeze(dim=1)], dim=1), dim=1, p=2) 28 | # r_a = r[:, 0, :] 29 | # r_b = r[:, 1, :] 30 | # h = F.normalize(torch.cat([h_a.unsqueeze(dim=1), h_b.unsqueeze(dim=1)], dim=1), dim=1, p=2) 31 | # h_a = h[:, 0, :] 32 | # h_b = h[:, 1, :] 33 | # print(r_a.size(), r_b.size(), h_a.size(), h_b.size()) 34 | # r_norm = torch.sqrt(r_a ** 2 + r_b ** 2) 35 | # r_a = r_a / r_norm 36 | # r_b = r_b / r_norm 37 | # h_norm = torch.sqrt(h_a ** 2 + h_b ** 2) 38 | # h_a = h_a / h_norm 39 | # h_b = h_b / h_norm 40 | # h = (h_a, h_b) 41 | 42 | neg_r = (-r_a, -r_b) 43 | hr_top = self.complex_add(h, neg_r) 44 | 45 | neg_one = (-torch.ones_like(r_a), torch.zeros_like(r_b)) # -1 = -1 + 0 i = (-1, 0) 46 | conjugate_r = (r_a, -r_b) 47 | hr_bottom = self.complex_add(self.complex_mul(h, conjugate_r), neg_one) 48 | 49 | h_r = self.complex_div(hr_top, hr_bottom) 50 | return h_r 51 | 52 | 53 | class BlaschkE(nn.Module): 54 | 55 | def __init__(self, 56 | num_entities, num_relations, 57 | embedding_dim, 58 | norm_flag=False, input_dropout=0.2, hidden_dropout=0.3): 59 | super(BlaschkE, self).__init__() 60 | self.embedding_dim = embedding_dim 61 | self.num_entities = num_entities 62 | self.num_relations = num_relations 63 | self.loss = nn.BCELoss() 64 | self.flag_hamilton_mul_norm = norm_flag 65 | self.E = ComplexEmbedding(self.num_entities, self.embedding_dim, 2) # a + bi 66 | self.R = ComplexEmbedding(self.num_relations, 
self.embedding_dim, 2) # a + bi 67 | self.E_dropout = ComplexDropout([input_dropout, input_dropout]) 68 | self.R_dropout = ComplexDropout([input_dropout, input_dropout]) 69 | self.hidden_dp = ComplexDropout([hidden_dropout, hidden_dropout]) 70 | self.E_bn = ComplexBatchNorm1d(self.embedding_dim, 2) 71 | self.R_bn = ComplexBatchNorm1d(self.embedding_dim, 4) 72 | 73 | self.mul = BlaschkeMult(norm_flag) 74 | self.scoring_all = ComplexScoringAll() 75 | 76 | def forward(self, h_idx, r_idx): 77 | return self.forward_head_batch(h_idx.view(-1), r_idx.view(-1)) 78 | 79 | def forward_head_batch(self, h_idx, r_idx): 80 | """ 81 | Completed. 82 | Given a head entity and a relation (h,r), we compute scores for all possible triples,i.e., 83 | [score(h,r,x)|x \in Entities] => [0.0,0.1,...,0.8], shape=> (1, |Entities|) 84 | Given a batch of head entities and relations => shape (size of batch,| Entities|) 85 | """ 86 | h = self.E(h_idx) 87 | r = self.R(r_idx) 88 | 89 | t = self.mul(h, r) 90 | if self.flag_hamilton_mul_norm: 91 | score_a, score_b = self.scoring_all(t, self.E.get_embeddings()) # a + b i 92 | else: 93 | score_a, score_b = self.scoring_all(self.E_dropout(t), self.E_dropout(self.E_bn(self.E.get_embeddings()))) 94 | score = score_a + score_b 95 | return torch.sigmoid(score) 96 | 97 | def init(self): 98 | self.E.init() 99 | self.R.init() 100 | -------------------------------------------------------------------------------- /toolbox/nn/Regularizer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class Fro(nn.Module): 8 | def __init__(self, weight: float): 9 | super(Fro, self).__init__() 10 | self.weight = weight 11 | 12 | def forward(self, factors: Tuple[torch.Tensor]): 13 | norm = 0 14 | for factor in factors: 15 | for f in factor: 16 | norm += self.weight * torch.sum(torch.norm(f, 2) ** 2) 17 | return norm / factors[0][0].shape[0] 18 | 19 | 20 | 
class N3(nn.Module): 21 | def __init__(self, weight: float): 22 | super(N3, self).__init__() 23 | self.weight = weight 24 | 25 | def forward(self, factors: Tuple[torch.Tensor]): 26 | norm = 0 27 | for factor in factors: 28 | for f in factor: 29 | norm += self.weight * torch.sum(torch.abs(f) ** 3) / f.shape[0] 30 | return norm 31 | 32 | 33 | class L1(nn.Module): 34 | def __init__(self, weight: float): 35 | super(L1, self).__init__() 36 | self.weight = weight 37 | 38 | def forward(self, factors: Tuple[torch.Tensor]): 39 | norm = 0 40 | for factor in factors: 41 | for f in factor: 42 | norm += self.weight * torch.sum(torch.abs(f) ** 1) 43 | return norm / factors[0][0].shape[0] 44 | 45 | 46 | class L2(nn.Module): 47 | def __init__(self, weight: float): 48 | super(L2, self).__init__() 49 | self.weight = weight 50 | 51 | def forward(self, factors: Tuple[torch.Tensor]): 52 | norm = 0 53 | for factor in factors: 54 | for f in factor: 55 | norm += self.weight * torch.sum(torch.abs(f) ** 2) 56 | return norm / factors[0][0].shape[0] 57 | 58 | 59 | class NA(nn.Module): 60 | def __init__(self, weight: float): 61 | super(NA, self).__init__() 62 | self.weight = weight 63 | 64 | def forward(self, factors: Tuple[torch.Tensor]): 65 | return torch.Tensor([0.0]).cuda() 66 | 67 | 68 | class DURA(nn.Module): 69 | def __init__(self, weight: float): 70 | super(DURA, self).__init__() 71 | self.weight = weight 72 | 73 | def forward(self, factors: Tuple[torch.Tensor]): 74 | norm = 0 75 | 76 | for factor in factors: 77 | h, r, t = factor 78 | 79 | norm += torch.sum(t ** 2 + h ** 2) 80 | norm += torch.sum(h ** 2 * r ** 2 + t ** 2 * r ** 2) 81 | 82 | return self.weight * norm / h.shape[0] 83 | 84 | 85 | class DURA_RESCAL(nn.Module): 86 | def __init__(self, weight: float): 87 | super(DURA_RESCAL, self).__init__() 88 | self.weight = weight 89 | 90 | def forward(self, factors: Tuple[torch.Tensor]): 91 | norm = 0 92 | for factor in factors: 93 | h, r, t = factor 94 | norm += torch.sum(h ** 2 + t ** 
2) 95 | norm += torch.sum( 96 | torch.bmm(r.transpose(1, 2), h.unsqueeze(-1)) ** 2 + torch.bmm(r, t.unsqueeze(-1)) ** 2) 97 | return self.weight * norm / h.shape[0] 98 | 99 | 100 | class DURA_RESCAL_W(nn.Module): 101 | def __init__(self, weight: float): 102 | super(DURA_RESCAL_W, self).__init__() 103 | self.weight = weight 104 | 105 | def forward(self, factors: Tuple[torch.Tensor]): 106 | norm = 0 107 | for factor in factors: 108 | h, r, t = factor 109 | norm += 2.0 * torch.sum(h ** 2 + t ** 2) 110 | norm += 0.5 * torch.sum( 111 | torch.bmm(r.transpose(1, 2), h.unsqueeze(-1)) ** 2 + torch.bmm(r, t.unsqueeze(-1)) ** 2) 112 | return self.weight * norm / h.shape[0] 113 | 114 | 115 | class DURA_W(nn.Module): 116 | def __init__(self, weight: float): 117 | super(DURA_W, self).__init__() 118 | self.weight = weight 119 | 120 | def forward(self, factors: Tuple[torch.Tensor]): 121 | norm = 0 122 | for factor in factors: 123 | h, r, t = factor 124 | 125 | norm += 0.5 * torch.sum(t ** 2 + h ** 2) 126 | norm += 1.5 * torch.sum(h ** 2 * r ** 2 + t ** 2 * r ** 2) 127 | 128 | return self.weight * norm / h.shape[0] 129 | -------------------------------------------------------------------------------- /toolbox/web/log_app/static/js/utils.js: -------------------------------------------------------------------------------- 1 | // 处理json数据. 
2 | /* 3 | value是一个数组 4 | [ 5 | {}, //每一条是一次实验记录 6 | {}, 7 | {} 8 | ] 9 | 10 | */ 11 | function processData(column_dict) { 12 | // 将数据设置为居中,将一些内容设置为json类型 13 | for (var key1 in column_dict) { 14 | var v1 = column_dict[key1]; 15 | v1['valign'] = 'middle'; 16 | v1['align'] = 'center'; 17 | if (('field' in v1) && window.settings['Wrap display']) { 18 | v1['class'] = 'word-wrap'; 19 | } 20 | for (var key in v1) { 21 | if (v1[key] === 'true') 22 | v1[key] = true; 23 | if (v1[key] === 'false') 24 | v1[key] = false; 25 | } 26 | v1['escape'] = true; 27 | 28 | } 29 | return column_dict; 30 | } 31 | 32 | function change_field_class(columns, class_name) { 33 | columns.forEach(function (v1, i) { 34 | v1.forEach(function (v2, i) { 35 | if (('field' in v2)) { 36 | v2['class'] = class_name; 37 | } 38 | }) 39 | }); 40 | return columns; 41 | } 42 | 43 | 44 | function convert_to_columns(column_order, column_dict, hidden_columns) { 45 | // 根据column_order, column_dict, hidden_columns生成columns, columns可以用于生成table对象 46 | const max_depth = get_max_col_ord_depth(column_order); 47 | const columns = []; 48 | for (let i = 0; i < max_depth; i++) { 49 | columns[i] = []; 50 | } 51 | 52 | generate_columns(column_order, column_dict, hidden_columns, '', columns, 0, max_depth); 53 | 54 | return columns; 55 | } 56 | 57 | function generate_columns(column_order, column_dict, hidden_columns, prefix, columns, depth, max_depth) { 58 | var total_colspan = 0; 59 | 60 | const keys = get_order_keys(column_order); 61 | 62 | keys.forEach(function (key) { 63 | var field; 64 | if (prefix === '') 65 | field = key; 66 | else 67 | field = prefix + '-' + key; 68 | 69 | if (!(field in hidden_columns)) //没有隐藏 70 | { 71 | var item = column_dict[field]; 72 | 73 | if (!(column_order[key] === 'EndOfOrder')) // 说明还有下一层 74 | { 75 | var colspan = generate_columns(column_order[key], column_dict, hidden_columns, field, columns, 76 | depth + 1, max_depth); 77 | item['colspan'] = colspan; 78 | item['rowspan'] = 1; 79 | 
total_colspan += colspan; 80 | } else { 81 | item['rowspan'] = max_depth - depth; 82 | item['colspan'] = 1; 83 | total_colspan += 1; 84 | } 85 | if (item['colspan'] !== 0) //只有当下面的内容没有全被隐藏的时候才显示 86 | columns[depth].push(item); 87 | } 88 | }); 89 | 90 | return total_colspan; 91 | } 92 | 93 | function get_order_keys(column_order) { 94 | // 给定order_columns返回他的key顺序. 按照这个顺序访问内容即可,已经删掉了OrderKeys关键字 95 | var keys = []; 96 | var key; 97 | if (column_order.hasOwnProperty('OrderKeys')) { 98 | column_order['OrderKeys'].forEach(function (key, i) { 99 | keys.push(key); 100 | }); 101 | } else { 102 | for (key in column_order) { 103 | keys.push(key); 104 | } 105 | } 106 | return keys; 107 | } 108 | 109 | function get_max_col_ord_depth(value) { 110 | // 根据column_order获取最大depth 111 | 112 | var depth = 0; 113 | var keys = get_order_keys(value); 114 | keys.forEach(function (key) { 115 | if (value[key] === 'EndOfOrder') 116 | depth = Math.max(depth, 1); 117 | else 118 | depth = Math.max(depth, get_max_col_ord_depth(value[key]) + 1); 119 | }); 120 | return depth; 121 | } 122 | 123 | const prompt = function (message, style, time) { 124 | $('.alert').remove(); 125 | style = (style === undefined) ? 'alert-success' : style; 126 | time = (time === undefined) ? 1200 : time; 127 | $('
') 128 | .appendTo('body') 129 | .addClass('alert ' + style) 130 | .html(message) 131 | .show() 132 | .delay(time) 133 | .fadeOut(); 134 | }; 135 | 136 | // 成功提示 137 | const success_prompt = function (message, time) { 138 | prompt(message, 'alert-success', time); 139 | }; 140 | 141 | // 警告提示 142 | const warning_prompt = function (message, time) { 143 | prompt(message, 'alert-warning', time); 144 | }; 145 | 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /toolbox/exp/classic/train_CartPole.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | # Hyper Parameters 8 | BATCH_SIZE = 32 9 | LR = 0.01 # learning rate 10 | EPSILON = 0.9 # greedy policy 11 | GAMMA = 0.9 # reward discount 12 | TARGET_REPLACE_ITER = 100 # target update frequency 13 | MEMORY_CAPACITY = 2000 14 | env = gym.make('CartPole-v0') 15 | env = env.unwrapped 16 | N_ACTIONS = env.action_space.n 17 | N_STATES = env.observation_space.shape[0] 18 | ENV_A_SHAPE = 0 if isinstance(env.action_space.sample(), int) else env.action_space.sample().shape # to confirm the shape 19 | 20 | 21 | class Net(nn.Module): 22 | def __init__(self, ): 23 | super(Net, self).__init__() 24 | self.fc1 = nn.Linear(N_STATES, 50) 25 | self.fc1.weight.data.normal_(0, 0.1) # initialization 26 | self.out = nn.Linear(50, N_ACTIONS) 27 | self.out.weight.data.normal_(0, 0.1) # initialization 28 | 29 | def forward(self, x): 30 | x = self.fc1(x) 31 | x = F.relu(x) 32 | actions_value = self.out(x) 33 | return actions_value 34 | 35 | 36 | class DQN(object): 37 | def __init__(self): 38 | self.eval_net, self.target_net = Net(), Net() 39 | 40 | self.learn_step_counter = 0 # for target updating 41 | self.memory_counter = 0 # for storing memory 42 | self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2)) # initialize memory 43 | 
self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR) 44 | self.loss_func = nn.MSELoss() 45 | 46 | def choose_action(self, x): 47 | x = torch.unsqueeze(torch.FloatTensor(x), 0) 48 | # input only one sample 49 | if np.random.uniform() < EPSILON: # greedy 50 | actions_value = self.eval_net.forward(x) 51 | action = torch.max(actions_value, 1)[1].data.numpy() 52 | action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) # return the argmax index 53 | else: # random 54 | action = np.random.randint(0, N_ACTIONS) 55 | action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) 56 | return action 57 | 58 | def store_transition(self, s, a, r, s_): 59 | transition = np.hstack((s, [a, r], s_)) 60 | # replace the old memory with new memory 61 | index = self.memory_counter % MEMORY_CAPACITY 62 | self.memory[index, :] = transition 63 | self.memory_counter += 1 64 | 65 | def learn(self): 66 | # target parameter update 67 | if self.learn_step_counter % TARGET_REPLACE_ITER == 0: 68 | self.target_net.load_state_dict(self.eval_net.state_dict()) 69 | self.learn_step_counter += 1 70 | 71 | # sample batch transitions 72 | sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE) 73 | b_memory = self.memory[sample_index, :] 74 | b_s = torch.FloatTensor(b_memory[:, :N_STATES]) 75 | b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int)) 76 | b_r = torch.FloatTensor(b_memory[:, N_STATES + 1:N_STATES + 2]) 77 | b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:]) 78 | 79 | # q_eval w.r.t the action in experience 80 | q_eval = self.eval_net(b_s).gather(1, b_a) # shape (batch, 1) 81 | q_next = self.target_net(b_s_).detach() # detach from graph, don't backpropagate 82 | q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1) # shape (batch, 1) 83 | loss = self.loss_func(q_eval, q_target) 84 | 85 | self.optimizer.zero_grad() 86 | loss.backward() 87 | self.optimizer.step() 88 | 89 | 90 | dqn = DQN() 91 | 92 | print('\nCollecting 
experience...') 93 | for i_episode in range(400): 94 | s = env.reset() 95 | ep_r = 0 96 | while True: 97 | env.render() 98 | a = dqn.choose_action(s) 99 | 100 | # take action 101 | s_, r, done, info = env.step(a) 102 | 103 | # modify the reward 104 | x, x_dot, theta, theta_dot = s_ 105 | r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8 106 | r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5 107 | r = r1 + r2 108 | 109 | dqn.store_transition(s, a, r, s_) 110 | 111 | ep_r += r 112 | if dqn.memory_counter > MEMORY_CAPACITY: 113 | dqn.learn() 114 | if done: 115 | print('Ep: ', i_episode, 116 | '| Ep_r: ', round(ep_r, 2)) 117 | 118 | if done: 119 | break 120 | s = s_ 121 | -------------------------------------------------------------------------------- /toolbox/web/log_app/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import uuid 4 | from collections import deque 5 | from threading import Timer 6 | from urllib import request as urequest 7 | 8 | from flask import Flask, url_for, redirect 9 | from flask import jsonify 10 | from flask import request 11 | from flask import send_from_directory 12 | 13 | from .chart_app import chart_page 14 | from .folder_app import folder_page 15 | from .line_app import line_page 16 | from .log_read import log_agent 17 | from .multi_char_app import multi_chart_page 18 | from .server.app_utils import ServerWatcher 19 | from .server.app_utils import get_usage_port 20 | from .server.data_container import all_data 21 | from .server.data_container import handler_watcher 22 | from .server.table_utils import prepare_data 23 | from .server.table_utils import save_all_data 24 | from .server.utils import check_uuid, colored_string 25 | from .summary_app import summary_page 26 | from .table_app import table_page 27 | 28 | app = Flask(__name__) 29 | 30 | app.register_blueprint(chart_page) 31 | app.register_blueprint(table_page) 32 | 
app.register_blueprint(summary_page) 33 | app.register_blueprint(line_page) 34 | app.register_blueprint(multi_chart_page) 35 | app.register_blueprint(folder_page) 36 | 37 | LEAST_REQUEST_TIMESTAMP = deque(maxlen=1) 38 | LEAST_REQUEST_TIMESTAMP.append(time.time()) 39 | 40 | 41 | @app.route('/') 42 | def index(): 43 | return redirect(url_for('table_page.table')) 44 | 45 | 46 | @app.before_request 47 | def update_last_request_ms(): 48 | global LEAST_REQUEST_TIMESTAMP 49 | LEAST_REQUEST_TIMESTAMP.append(time.time()) 50 | 51 | 52 | @app.route('/kill', methods=['POST']) 53 | def seriouslykill(): 54 | time.sleep(1) 55 | func = request.environ.get('werkzeug.server.shutdown') 56 | if func is None: 57 | raise RuntimeError('Not running with the Werkzeug Server') 58 | func() 59 | return "stopping" 60 | 61 | 62 | @app.route('/arange_kill', methods=['POST']) 63 | def arange_kill(): 64 | res = check_uuid(all_data['uuid'], request.json['uuid']) 65 | if res is not None: 66 | return jsonify(res) 67 | 68 | def shutdown(): 69 | req = urequest.Request('http://127.0.0.1:{}/kill'.format(all_data['port']), headers={}, data=''.encode('utf-8')) 70 | page = urequest.urlopen(req).read().decode('utf-8') 71 | 72 | print("Shutting down from the frontend...") 73 | Timer(1.0, shutdown).start() 74 | return jsonify(status='success', msg='') 75 | 76 | 77 | @app.route('/table.ico') 78 | def get_table_ico(): 79 | return send_from_directory(os.path.join('.', 'static', 'img'), 'table.ico') 80 | 81 | 82 | @app.route('/chart.ico') 83 | def get_chart_ico(): 84 | return send_from_directory(os.path.join('.', 'static', 'img'), 'chart.ico') 85 | 86 | 87 | def start_app(log_dir, log_config_name, standby_hours, start_port, ip='0.0.0.0', token=None): 88 | """ 89 | log_dir app日志目录 90 | log_config_name app的配置文件名 91 | start_port 端口。如果该端口不可用,会自动选一个可用的 92 | standby_hours 空转小时数。如果超过这个时间没有任何操作,会自动停止运行,防止资源浪费 93 | """ 94 | os.chdir(os.path.dirname(os.path.abspath(__file__))) # 可能需要把运行路径移动到这里 95 | all_data['root_log_dir'] 
= log_dir # will be used by chart_app 96 | server_wait_seconds = int(standby_hours * 3600) 97 | print("This server will automatically shutdown if no api access for {} hours.".format(standby_hours)) 98 | all_data['log_config_name'] = log_config_name 99 | all_data['log_agent'] = log_agent 100 | if token is None: 101 | all_data['token'] = None 102 | else: 103 | all_data['token'] = str(token) 104 | print(colored_string(f"You specify token:{all_data['token']}, remember to add this token when access your table.", color='red')) 105 | 106 | # 准备数据 107 | all_data.update(prepare_data(log_agent, all_data['root_log_dir'], all_data['log_config_name'])) 108 | print(f"Finish preparing data. Found {len(all_data['data'])} records in {log_dir}.") 109 | all_data['uuid'] = str(uuid.uuid1()) 110 | 111 | port = get_usage_port(start_port=start_port) 112 | all_data['port'] = port 113 | 114 | server_watcher = ServerWatcher(LEAST_REQUEST_TIMESTAMP, port) 115 | server_watcher.set_server_wait_seconds(server_wait_seconds) 116 | server_watcher.start() 117 | app.run(host=ip, port=port, debug=False, threaded=True) 118 | 119 | # TODO 输出访问的ip地址 120 | print("Shutting down server...") 121 | save_all_data(all_data, all_data['root_log_dir'], all_data['log_config_name']) 122 | handler_watcher.stop() 123 | server_watcher.stop() 124 | 125 | 126 | if __name__ == '__main__': 127 | from .server.app_utils import cmd_parser 128 | 129 | parser = cmd_parser() 130 | args = parser.parse_args() 131 | start_app(args.log_dir, args.log_config_name, args.port, 1, '123') 132 | -------------------------------------------------------------------------------- /toolbox/web/log_app/folder_app.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 主要用于显示log folder下的文件 3 | 4 | 5 | """ 6 | import base64 7 | import os 8 | import time 9 | 10 | from flask import Blueprint 11 | from flask import make_response, send_file 12 | from flask import render_template, redirect, url_for 13 | from 
flask import request, jsonify 14 | 15 | from .server.data_container import all_data 16 | from .server.folder_utils import get_image_size 17 | from .server.utils import check_uuid 18 | 19 | folder_page = Blueprint('folder_page', __name__, template_folder='templates') 20 | 21 | 22 | @folder_page.route('/folder', methods=['POST', 'GET']) 23 | def show_folder(): 24 | if request.method == 'POST': 25 | uuid = request.values['uuid'] 26 | id = request.values['id'] if 'id' in request.values else '' 27 | subdir = request.values['subdir'] if 'subdir' in request.values else '' 28 | else: 29 | uuid = request.args.get('uuid') 30 | id = request.args.get('id') 31 | subdir = request.args.get('subdir') 32 | res = check_uuid(all_data['uuid'], uuid) 33 | if res is not None: 34 | return jsonify(res) 35 | if id: 36 | # TODO 远程日志服务器 37 | log_dir = all_data['root_log_dir'] 38 | folder = os.path.join(log_dir, id) 39 | if os.path.relpath(folder, log_dir).startswith('.'): 40 | return jsonify(status='fail', msg='Permission denied.') 41 | 42 | if subdir == '': # 如果为空,说明还是需要访问folder 43 | pass 44 | elif os.path.isfile(os.path.join(folder, subdir)): # 文件直接发送 45 | if os.path.splitext(subdir)[1][1:] in ('jpg', 'png', 'jpeg', 'fig'): 46 | return redirect(url_for('folder_page.show_image', uuid=uuid, id=id, subdir=subdir), code=301) 47 | resp = make_response(send_file(os.path.join(folder, subdir))) 48 | resp.headers["Content-type"] = "text/plan;charset=UTF-8" 49 | return resp 50 | elif os.path.isdir(os.path.join(folder, subdir)): # 如果是directory 51 | folder = os.path.join(folder, subdir) 52 | else: 53 | return jsonify(status='fail', msg="Invalid file.") 54 | 55 | if os.path.relpath(folder, log_dir).startswith('.'): 56 | return jsonify(status='fail', msg='Permission denied.') 57 | 58 | current_list = os.listdir(folder) 59 | contents = [] 60 | for i in sorted(current_list): 61 | fullpath = folder + os.sep + i 62 | # 如果是目录,在后面添加一个sep 63 | if os.path.isdir(fullpath): 64 | extra = os.sep 65 | else: 66 | 
extra = '' 67 | content = {} 68 | content['filename'] = i + extra 69 | content['mtime'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(os.stat(fullpath).st_mtime)) 70 | content['size'] = str(round(os.path.getsize(fullpath) / 1024)) + 'k' 71 | content['isfile'] = os.path.isfile(fullpath) 72 | if extra: 73 | contents.insert(0, content) 74 | else: 75 | contents.append(content) 76 | subdir = os.path.relpath(os.path.abspath(folder), start=os.path.abspath(os.path.join(log_dir, id))) 77 | if subdir.startswith('.'): 78 | subdir = '' 79 | else: 80 | if not subdir.endswith(os.sep): 81 | subdir += os.sep 82 | return render_template('folder.html', contents=contents, subdir=subdir, ossep=os.sep, 83 | uuid=all_data['uuid'], id=id) 84 | else: 85 | return jsonify(status='fail', msg="The request lacks id or filename.") 86 | 87 | 88 | @folder_page.route('/folder/show_image', methods=['GET']) 89 | def show_image(): 90 | uuid = request.args.get('uuid') 91 | id = request.args.get('id') 92 | subdir = request.args.get('subdir') 93 | res = check_uuid(all_data['uuid'], uuid) 94 | if res is not None: 95 | return jsonify(res) 96 | if id: 97 | log_dir = all_data['root_log_dir'] 98 | folder = os.path.join(log_dir, id) 99 | if os.path.splitext(subdir)[1][1:] in ('jpg', 'png', 'jpeg', 'fig'): 100 | img_stream = '' 101 | with open(os.path.join(folder, subdir), 'rb') as img_f: 102 | img_stream = img_f.read() 103 | img_stream = base64.b64encode(img_stream).decode('ascii') 104 | try: 105 | width = get_image_size(os.path.join(folder, subdir))[0] 106 | except: 107 | width = -1 108 | if width == -1: 109 | width = 1000 110 | return render_template('folder_img.html', img_stream=img_stream, img_path=subdir, width=width) 111 | return jsonify(status='fail', msg=f"Fail to show {os.path.relpath(os.path.join(folder, subdir)), log_dir}") 112 | -------------------------------------------------------------------------------- /toolbox/nn/ParamE.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class CoreParamEGate(nn.Module): 7 | def __init__(self, entity_dim, hidden_dim=100): 8 | super(CoreParamEGate, self).__init__() 9 | self.entity_dim = entity_dim 10 | self.hidden_dim = hidden_dim 11 | self.linear = nn.Linear(hidden_dim, entity_dim) 12 | 13 | def forward(self, h, r1, r2): 14 | h = h.view(-1, 1, self.entity_dim) 15 | r1 = r1.view(-1, self.entity_dim, self.hidden_dim) 16 | r2 = r2.view(-1, self.entity_dim, self.hidden_dim) 17 | gate = torch.sigmoid(torch.bmm(h, r1).squeeze(dim=1)) 18 | value = torch.tanh(torch.bmm(h, r2).squeeze(dim=1)) 19 | x = (1 - gate) * value 20 | x = self.linear(x) 21 | x = F.relu(x) 22 | return x 23 | 24 | 25 | class CoreParamECNN(nn.Module): 26 | def __init__(self, entity_dim, 27 | conv_num_channels=32, 28 | conv_num_channels2=64, 29 | conv_filter_height=3, conv_filter_width=3, 30 | hidden_dropout1: float = 0.2, 31 | hidden_dropout2: float = 0.2): 32 | super(CoreParamECNN, self).__init__() 33 | self.entity_dim = entity_dim 34 | 35 | self.conv_in_height = 10 36 | self.conv_in_width = self.entity_dim // 10 37 | 38 | self.conv_filter_height = conv_filter_height 39 | self.conv_filter_width = conv_filter_width 40 | self.conv_num_channels = conv_num_channels 41 | self.conv_num_channels2 = conv_num_channels2 42 | 43 | self.conv_out_height = self.conv_in_height - self.conv_filter_height + 1 44 | self.conv_out_width = self.conv_in_width - self.conv_filter_width + 1 45 | 46 | self.conv_out_height2 = self.conv_out_height - self.conv_filter_height + 1 47 | self.conv_out_width2 = self.conv_out_width - self.conv_filter_width + 1 48 | 49 | self.hidden_dropout1 = nn.Dropout(hidden_dropout1) 50 | self.hidden_dropout2 = nn.Dropout(hidden_dropout2) 51 | 52 | hidden_dim = self.conv_num_channels * self.conv_num_channels2 * self.conv_out_height2 * self.conv_out_width2 53 | 
self.hidden_dim = hidden_dim 54 | self.linear = nn.Linear(hidden_dim, entity_dim) 55 | 56 | def forward(self, h, r1, r2): 57 | img = h.view(1, -1, self.conv_in_height, self.conv_in_width) 58 | batch_size = img.size(1) 59 | conv_weight1 = r1.view(batch_size * self.conv_num_channels, 1, self.conv_filter_height, self.conv_filter_width) 60 | conv_weight2 = r2.view(batch_size * self.conv_num_channels2, 1, self.conv_filter_height, self.conv_filter_width) 61 | 62 | x = F.conv2d(img, weight=conv_weight1, groups=batch_size) 63 | x = F.relu(x) 64 | x = self.hidden_dropout1(x) 65 | x = x.view(-1, batch_size, self.conv_out_height, self.conv_out_width) 66 | 67 | x = F.conv2d(x, weight=conv_weight2, groups=batch_size) 68 | x = F.relu(x) 69 | x = self.hidden_dropout2(x) 70 | x = x.view(batch_size, self.hidden_dim) 71 | 72 | x = self.linear(x) 73 | x = F.relu(x) 74 | return x 75 | 76 | 77 | class ParamE(nn.Module): 78 | def __init__(self, num_entities, num_relations, entity_dim, relation_dim, hidden_dropout=0.2): 79 | super(ParamE, self).__init__() 80 | self.entity_dim = entity_dim 81 | self.relation_dim = relation_dim 82 | self.E = nn.Embedding(num_entities, entity_dim) 83 | 84 | conv_num_channels = 32 85 | conv_num_channels2 = 64 86 | conv_filter_height = 3 87 | conv_filter_width = 3 88 | self.R1 = nn.Embedding(num_relations, conv_num_channels * conv_filter_height * conv_filter_width) 89 | self.R2 = nn.Embedding(num_relations, conv_num_channels2 * conv_filter_height * conv_filter_width) 90 | 91 | self.core = CoreParamECNN(entity_dim, 92 | conv_num_channels, conv_num_channels2, 93 | conv_filter_height, conv_filter_width) 94 | self.dropout = nn.Dropout(hidden_dropout) 95 | self.b = nn.Parameter(torch.zeros(num_entities)) 96 | self.m = nn.PReLU() 97 | 98 | self.loss = nn.BCELoss() 99 | 100 | def init(self): 101 | nn.init.xavier_normal_(self.E.weight.data) 102 | nn.init.xavier_normal_(self.R1.weight.data) 103 | nn.init.xavier_normal_(self.R2.weight.data) 104 | 105 | def 
forward(self, h_idx, r_idx): 106 | h = self.E(h_idx) # Bxd 107 | r1 = self.R1(r_idx) # Bxd 108 | r2 = self.R2(r_idx) # Bxd 109 | 110 | t = self.core(h, r1, r2) 111 | t = t.view(-1, self.entity_dim) 112 | 113 | x = torch.mm(t, self.dropout(self.E.weight).transpose(1, 0)) 114 | x = x + self.b.expand_as(x) 115 | x = torch.sigmoid(x) 116 | return x # batch_size x E 117 | -------------------------------------------------------------------------------- /toolbox/nn/ComplexTuckER.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | from toolbox.nn.ComplexEmbedding import ComplexEmbedding, ComplexDropout, ComplexScoringAll, ComplexBatchNorm1d 6 | 7 | 8 | class CoreTuckER(nn.Module): 9 | def __init__(self, entity_dim, relation_dim, hidden_dropout1=0.4): 10 | super(CoreTuckER, self).__init__() 11 | self.entity_dim = entity_dim 12 | self.relation_dim = relation_dim 13 | 14 | self.W = nn.Parameter(torch.FloatTensor(np.random.uniform(-0.01, 0.01, (relation_dim, entity_dim, entity_dim)))) 15 | 16 | self.hidden_dropout1 = nn.Dropout(hidden_dropout1) 17 | 18 | def forward(self, h, r): 19 | h = h.view(-1, 1, self.entity_dim) 20 | 21 | W = self.W.view(self.relation_dim, -1) 22 | W = torch.mm(r.view(-1, self.relation_dim), W) 23 | W = W.view(-1, self.entity_dim, self.entity_dim) 24 | W = self.hidden_dropout1(W) 25 | 26 | t = torch.bmm(h, W) 27 | t = t.view(-1, self.entity_dim) 28 | return t 29 | 30 | def w(self, h, r): 31 | h = torch.cat([h.transpose(1, 0).unsqueeze(dim=0)] * r.size(0), dim=0) # BxdxE 32 | 33 | W = self.W.view(self.relation_dim, -1) 34 | W = torch.mm(r.view(-1, self.relation_dim), W) 35 | W = W.view(-1, self.entity_dim, self.entity_dim) # Bxdxd 36 | W = self.hidden_dropout1(W) 37 | t = torch.bmm(W, h) # BxdxE 38 | return t 39 | 40 | 41 | class ComplexTuckER(nn.Module): 42 | def __init__(self, entity_dim, relation_dim, hidden_dropout=0.4): 43 | super(ComplexTuckER, 
self).__init__() 44 | self.entity_dim = entity_dim 45 | self.relation_dim = relation_dim 46 | 47 | self.Wa = CoreTuckER(entity_dim, relation_dim, hidden_dropout) 48 | self.Wb = CoreTuckER(entity_dim, relation_dim, hidden_dropout) 49 | 50 | self.bn0 = nn.BatchNorm1d(entity_dim) 51 | self.bn1 = nn.BatchNorm1d(entity_dim) 52 | self.bn3 = nn.BatchNorm1d(relation_dim) 53 | self.bn4 = nn.BatchNorm1d(relation_dim) 54 | self.bn5 = nn.BatchNorm1d(entity_dim) 55 | self.bn6 = nn.BatchNorm1d(entity_dim) 56 | self.ma = nn.PReLU() 57 | self.mb = nn.PReLU() 58 | 59 | def forward(self, h, r): 60 | h_a, h_b = h 61 | h_a = self.bn0(h_a) 62 | h_b = self.bn1(h_b) 63 | r_a, r_b = r 64 | r_a = self.bn3(r_a) 65 | r_b = self.bn4(r_b) 66 | t_a = self.Wa(h_a, r_a) - self.Wb(h_a, r_b) - self.Wb(h_b, r_a) - self.Wa(h_b, r_b) 67 | t_b = self.Wb(h_a, r_a) + self.Wa(h_a, r_b) + self.Wa(h_b, r_a) - self.Wb(h_b, r_b) 68 | t_a = self.bn5(t_a) 69 | t_b = self.bn6(t_b) 70 | t_a = self.ma(t_a) 71 | t_b = self.mb(t_b) 72 | return t_a, t_b 73 | 74 | class TuckER(nn.Module): 75 | def __init__(self, num_entities, num_relations, entity_dim, relation_dim, input_dropout=0.3, hidden_dropout=0.3, hidden_dropout2=0.3): 76 | super(TuckER, self).__init__() 77 | self.entity_dim = entity_dim 78 | self.relation_dim = relation_dim 79 | self.flag_hamilton_mul_norm = True 80 | 81 | self.E = ComplexEmbedding(num_entities, entity_dim) 82 | self.R = ComplexEmbedding(num_relations, relation_dim) 83 | 84 | self.core = ComplexTuckER(entity_dim, relation_dim, hidden_dropout) 85 | self.E_bn = ComplexBatchNorm1d(entity_dim, 2) 86 | self.E_dropout = ComplexDropout([input_dropout, input_dropout]) 87 | self.R_dropout = ComplexDropout([input_dropout, input_dropout]) 88 | self.hidden_dp = ComplexDropout([hidden_dropout, hidden_dropout]) 89 | 90 | self.scoring_all = ComplexScoringAll() 91 | self.bce = nn.BCELoss() 92 | self.b1 = nn.Parameter(torch.zeros(num_entities)) 93 | self.b2 = nn.Parameter(torch.zeros(num_entities)) 94 | 95 | 
def init(self): 96 | self.E.init() 97 | self.R.init() 98 | 99 | def forward(self, h_idx, r_idx): 100 | return self.forward_head_batch(h_idx.view(-1), r_idx.view(-1)) 101 | 102 | def forward_head_batch(self, h_idx, r_idx): 103 | """ 104 | Completed. 105 | Given a head entity and a relation (h,r), we compute scores for all possible triples,i.e., 106 | [score(h,r,x)|x \in Entities] => [0.0,0.1,...,0.8], shape=> (1, |Entities|) 107 | Given a batch of head entities and relations => shape (size of batch,| Entities|) 108 | """ 109 | h = self.E(h_idx) 110 | r = self.R(r_idx) 111 | 112 | t = self.core(h, r) 113 | 114 | if self.flag_hamilton_mul_norm: 115 | score_a, score_b = self.scoring_all(t, self.E.get_embeddings()) # a + b i 116 | else: 117 | score_a, score_b = self.scoring_all(self.E_dropout(t), self.E_dropout(self.E_bn(self.E.get_embeddings()))) 118 | score_a = score_a + self.b1.expand_as(score_a) 119 | score_b = score_b + self.b2.expand_as(score_b) 120 | 121 | y_a = torch.sigmoid(score_a) 122 | y_b = torch.sigmoid(score_b) 123 | 124 | return y_a, y_b 125 | 126 | def loss(self, target, y): 127 | y_a, y_b = target 128 | return self.bce(y_a, y) + self.bce(y_b, y) 129 | -------------------------------------------------------------------------------- /toolbox/utils/AutoML.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def config_tuning_space(tuning_space_raw): 5 | from hyperopt import hp 6 | from hyperopt.pyll.base import scope 7 | if tuning_space_raw is None: 8 | return None 9 | 10 | hyper_obj = {} 11 | if "learning_rate" in tuning_space_raw: 12 | hyper_obj = {**hyper_obj, **{ 13 | "learning_rate": hp.loguniform('learning_rate', np.log(tuning_space_raw['learning_rate']['min']), 14 | np.log(tuning_space_raw['learning_rate']['max']))}} 15 | if "hidden_size" in tuning_space_raw: 16 | hyper_obj = {**hyper_obj, **{"hidden_size": scope.int( 17 | hp.qloguniform('hidden_size', 
np.log(tuning_space_raw['hidden_size']['min']), 18 | np.log(tuning_space_raw['hidden_size']['max']), 1))}} 19 | if "ent_hidden_size" in tuning_space_raw: 20 | hyper_obj = {**hyper_obj, **{"ent_hidden_size": scope.int( 21 | hp.qloguniform("ent_hidden_size", np.log(tuning_space_raw['ent_hidden_size']['min']), 22 | np.log(tuning_space_raw['ent_hidden_size']['max']), 1))}} 23 | if "rel_hidden_size" in tuning_space_raw: 24 | hyper_obj = {**hyper_obj, **{"rel_hidden_size": scope.int( 25 | hp.qloguniform("rel_hidden_size", np.log(tuning_space_raw['rel_hidden_size']['min']), 26 | np.log(tuning_space_raw['rel_hidden_size']['max']), 1))}} 27 | if "batch_size" in tuning_space_raw: 28 | hyper_obj = {**hyper_obj, **{"batch_size": scope.int( 29 | hp.qloguniform("batch_size", np.log(tuning_space_raw['batch_size']['min']), 30 | np.log(tuning_space_raw['batch_size']['max']), 1))}} 31 | if "margin" in tuning_space_raw: 32 | hyper_obj = {**hyper_obj, **{ 33 | "margin": hp.uniform("margin", tuning_space_raw["margin"]["min"], tuning_space_raw["margin"]["max"])}} 34 | if "lmbda" in tuning_space_raw: 35 | hyper_obj = {**hyper_obj, **{"lmbda": hp.loguniform('lmbda', np.log(tuning_space_raw["lmbda"]["min"]), 36 | np.log(tuning_space_raw["lmbda"]["max"]))}} 37 | if "distance_measure" in tuning_space_raw: 38 | hyper_obj = {**hyper_obj, 39 | **{"distance_measure": hp.choice('distance_measure', tuning_space_raw["distance_measure"])}} 40 | if "cmax" in tuning_space_raw: 41 | hyper_obj = {**hyper_obj, **{"cmax": hp.loguniform('cmax', np.log(tuning_space_raw["cmax"]["min"]), 42 | np.log(tuning_space_raw["cmax"]["max"]))}} 43 | if "cmin" in tuning_space_raw: 44 | hyper_obj = {**hyper_obj, **{"cmin": hp.loguniform('cmin', np.log(tuning_space_raw["cmin"]["min"]), 45 | np.log(tuning_space_raw["cmin"]["max"]))}} 46 | if "optimizer" in tuning_space_raw: 47 | hyper_obj = {**hyper_obj, **{"optimizer": hp.choice("optimizer", tuning_space_raw["optimizer"])}} 48 | if "bilinear" in tuning_space_raw: 49 | 
hyper_obj = {**hyper_obj, **{"bilinear": hp.choice('bilinear', tuning_space_raw["bilinear"])}} 50 | if "epochs" in tuning_space_raw: 51 | hyper_obj = {**hyper_obj, **{"epochs": hp.choice("epochs", tuning_space_raw["epochs"])}} 52 | if "feature_map_dropout" in tuning_space_raw: 53 | hyper_obj = {**hyper_obj, **{ 54 | "feature_map_dropout": hp.choice('feature_map_dropout', tuning_space_raw["feature_map_dropout"])}} 55 | if "input_dropout" in tuning_space_raw: 56 | hyper_obj = {**hyper_obj, 57 | **{"input_dropout": hp.choice('input_dropout', tuning_space_raw["input_dropout"])}} 58 | if "hidden_dropout" in tuning_space_raw: 59 | hyper_obj = {**hyper_obj, 60 | **{"hidden_dropout": hp.choice('hidden_dropout', tuning_space_raw["hidden_dropout"])}} 61 | if "use_bias" in tuning_space_raw: 62 | hyper_obj = {**hyper_obj, **{"use_bias": hp.choice('use_bias', tuning_space_raw["use_bias"])}} 63 | if "label_smoothing" in tuning_space_raw: 64 | hyper_obj = {**hyper_obj, 65 | **{"label_smoothing": hp.choice('label_smoothing', tuning_space_raw["label_smoothing"])}} 66 | if "lr_decay" in tuning_space_raw: 67 | hyper_obj = {**hyper_obj, **{"lr_decay": hp.choice('lr_decay', tuning_space_raw["lr_decay"])}} 68 | if "l1_flag" in tuning_space_raw: 69 | hyper_obj = {**hyper_obj, **{"l1_flag": hp.choice('l1_flag', tuning_space_raw["l1_flag"])}} 70 | if "sampling" in tuning_space_raw: 71 | hyper_obj = {**hyper_obj, **{"sampling": hp.choice('sampling', tuning_space_raw["sampling"])}} 72 | 73 | return hyper_obj 74 | 75 | 76 | def grid_search(config, run): 77 | # config = { 78 | # "embedding_dim": [100 * i for i in range(1, 6)], 79 | # "input_dropout": [0.1 * i for i in range(2)], 80 | # "hidden_dropout": [0.1 * i for i in range(2)], 81 | # } 82 | from sklearn.model_selection import ParameterGrid 83 | for i, setting in enumerate(ParameterGrid(config)): 84 | # input_dropout = setting["input_dropout"] 85 | # hidden_dropout = setting["hidden_dropout"] 86 | # embedding_dim = setting["embedding_dim"] 
87 | run(i, setting) 88 | -------------------------------------------------------------------------------- /toolbox/web/log_app/templates/multi_chart.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | chart-{{ log_dir }} 8 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 33 | 55 | 75 | 76 | 77 |
78 | 82 | 95 |
96 |
97 | 99 | 100 | 101 | 115 | 116 | 117 | 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /toolbox/data/ScoringAllDataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Tuple, Dict, Set 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | from toolbox.data.functional import build_map_hr_t, with_inverse_relations 8 | 9 | 10 | class ScoringAllDataset(Dataset): 11 | def __init__(self, train_triples_ids: List[Tuple[int, int, int]], entity_count: int): 12 | self.hr_t = build_map_hr_t(train_triples_ids) 13 | self.hr_pairs = list(self.hr_t.keys()) 14 | self.entity_count = entity_count 15 | 16 | def __len__(self): 17 | return len(self.hr_pairs) 18 | 19 | def __getitem__(self, idx): 20 | h, r = self.hr_pairs[idx] 21 | data = torch.zeros(self.entity_count).float() 22 | data[list(self.hr_t[(h, r)])] = 1. 23 | h = torch.LongTensor([h]) 24 | r = torch.LongTensor([r]) 25 | return h, r, data 26 | 27 | 28 | class ScoringNegativeDataset(Dataset): 29 | def __init__(self, train_triples_ids: List[Tuple[int, int, int]], entity_count: int, negative_sample_size: int = 32): 30 | self.hr_t = build_map_hr_t(train_triples_ids) 31 | self.hr_pairs = list(self.hr_t.keys()) 32 | self.entity_count = entity_count 33 | self.negative_sample_size = negative_sample_size 34 | 35 | def __len__(self): 36 | return len(self.hr_pairs) 37 | 38 | def __getitem__(self, idx): 39 | h, r = self.hr_pairs[idx] 40 | data = torch.zeros(self.entity_count).float() 41 | data[list(self.hr_t[(h, r)])] = 1. 
42 | idx = torch.randperm(self.entity_count)[:self.negative_sample_size] 43 | targets = data[idx] 44 | h = torch.LongTensor([h]) 45 | r = torch.LongTensor([r]) 46 | return h, r, idx, targets 47 | 48 | 49 | class ScoringOneVsNegativeDataset(Dataset): 50 | def __init__(self, train_triples_ids: List[Tuple[int, int, int]], max_relation_id: int, entity_count: int, negative_sample_size: int = 32): 51 | self.train_triples_ids = train_triples_ids 52 | train_triples, _, _ = with_inverse_relations(train_triples_ids, max_relation_id) 53 | self.hr_t = build_map_hr_t(train_triples) 54 | self.entity_count = entity_count 55 | self.max_relation_id = max_relation_id 56 | self.negative_sample_size = negative_sample_size 57 | 58 | def __len__(self): 59 | return len(self.train_triples_ids) 60 | 61 | def __getitem__(self, idx): 62 | h, r, t = self.train_triples_ids[idx] 63 | reverse_r = self.max_relation_id + r 64 | t_sample, t_target = self.sampling(h, r, t) 65 | h_sample, h_target = self.sampling(t, reverse_r, h) 66 | h = torch.LongTensor([h]) 67 | r = torch.LongTensor([r]) 68 | t = torch.LongTensor([t]) 69 | reverse_r = torch.LongTensor([reverse_r]) 70 | return h, r, t_sample, t_target, t, reverse_r, h_sample, h_target 71 | 72 | def sampling(self, h, r, answer_idx): 73 | valid_negative = list(set(range(self.entity_count)) - self.hr_t[(h, r)]) 74 | sample_negative = random.choices(valid_negative, k=self.negative_sample_size) 75 | sample_idx = torch.LongTensor([answer_idx] + sample_negative) # (1, 1 + num_negative) 76 | target = torch.zeros(1 + self.negative_sample_size).float() # (1, 1 + num_negative) 77 | target[0] = 1 78 | return sample_idx, target 79 | 80 | 81 | class ComplementaryScoringAllDataset(Dataset): 82 | def __init__(self, hr_t: Dict[Tuple[int, int], Set[int]], all_keys: List[Tuple[int, int]], entity_count: int): 83 | self.hr_t = hr_t 84 | self.hr_pairs = all_keys 85 | self.entity_count = entity_count 86 | 87 | def __len__(self): 88 | return len(self.hr_pairs) 89 | 90 | 
def __getitem__(self, idx): 91 | h, r = self.hr_pairs[idx] 92 | data = torch.ones(self.entity_count).float() 93 | value = list(self.hr_t[(h, r)]) 94 | if len(value) > 0: 95 | data[value] = 0. 96 | h = torch.LongTensor([h]) 97 | r = torch.LongTensor([r]) 98 | return h, r, data 99 | 100 | 101 | class BidirectionalScoringAllDataset(Dataset): 102 | def __init__(self, test_triples_ids: List[Tuple[int, int, int]], hr_t: Dict[Tuple[int, int], Set[int]], max_relation_id: int, entity_count: int): 103 | """ 104 | test_triples_ids: without reverse r 105 | hr_t: all hr->t, MUST with reverse r 106 | """ 107 | self.test_triples_ids = test_triples_ids 108 | self.hr_t = hr_t 109 | self.entity_count = entity_count 110 | self.max_relation_id = max_relation_id 111 | 112 | def __len__(self): 113 | return len(self.test_triples_ids) 114 | 115 | def __getitem__(self, idx): 116 | h, r, t = self.test_triples_ids[idx] 117 | reverse_r = r + self.max_relation_id 118 | 119 | predict_for_hr = torch.zeros(self.entity_count).float() 120 | predict_for_hr[list(self.hr_t[(h, r)])] = 1. 121 | 122 | predict_for_tReverser = torch.zeros(self.entity_count).float() 123 | predict_for_tReverser[list(self.hr_t[(t, reverse_r)])] = 1. 
124 | 125 | h = torch.LongTensor([h]) 126 | r = torch.LongTensor([r]) 127 | t = torch.LongTensor([t]) 128 | reverse_r = torch.LongTensor([reverse_r]) 129 | 130 | return h, r, predict_for_hr, t, reverse_r, predict_for_tReverser 131 | -------------------------------------------------------------------------------- /toolbox/exp/OutputSchema.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/10/26 5 | @description: 输出目录管理 6 | """ 7 | from pathlib import Path 8 | from typing import Union 9 | 10 | from toolbox.utils.Log import Log 11 | 12 | 13 | class OutputPathSchema: 14 | """ 15 | 输出目录 下的路径 16 | """ 17 | 18 | def __init__(self, output_path: Union[Path, str]): 19 | self.output_path: Path = output_path if output_path is Path else Path(output_path) 20 | 21 | self.dir_path_log = self.output_path / 'logs' 22 | self.dir_path_visualize = self.output_path / 'visualize' 23 | self.dir_path_checkpoint = self.output_path / 'checkpoint' 24 | self.dir_path_latex = self.output_path / 'latex' 25 | self.dir_path_deploy = self.output_path / 'deploy' 26 | self.dir_path_scripts = self.output_path / 'scripts' 27 | 28 | self.build_dir_structure() 29 | 30 | def log_path(self, filename) -> Path: 31 | return self.dir_path_log / filename 32 | 33 | def visualize_path(self, filename) -> Path: 34 | return self.dir_path_visualize / filename 35 | 36 | def checkpoint_path(self, filename="checkpoint.tar") -> Path: 37 | return self.dir_path_checkpoint / filename 38 | 39 | def latex_path(self, filename="best.tex") -> Path: 40 | return self.dir_path_latex / filename 41 | 42 | def deploy_path(self, filename="model.tar") -> Path: 43 | return self.dir_path_deploy / filename 44 | 45 | def scripts_path(self, filename) -> Path: 46 | return self.dir_path_scripts / filename 47 | 48 | def build_dir_structure(self): 49 | self.output_path.mkdir(parents=True, exist_ok=True) 50 | 
self.dir_path_log.mkdir(parents=True, exist_ok=True) 51 | self.dir_path_visualize.mkdir(parents=True, exist_ok=True) 52 | self.dir_path_checkpoint.mkdir(parents=True, exist_ok=True) 53 | self.dir_path_latex.mkdir(parents=True, exist_ok=True) 54 | self.dir_path_deploy.mkdir(parents=True, exist_ok=True) 55 | self.dir_path_scripts.mkdir(parents=True, exist_ok=True) 56 | 57 | def clean(self): 58 | # clean the dir, and recreate dir structure 59 | import shutil 60 | shutil.rmtree(self.output_path) 61 | self.build_dir_structure() 62 | 63 | 64 | class OutputSchema: 65 | """ 66 | 输出目录 67 | ./output 68 | - experiment name 69 | - visualize Tensorboard 可视化 70 | - events... 71 | - logs log 日志,包含超参数、指标、最佳指标等日志,配合 toolbox.web.log_app 使用 72 | - config.log 73 | - loss.log 74 | - checkpoint 检查点,用于恢复训练 75 | - checkpoint_score_xx.tar 76 | - deploy 部署模型,基于 checkpoint ,不同的是这里的 tar 文件内只包含模型,不包含优化器梯度等信息 77 | - model_score_xx.tar 78 | - config.yaml 79 | - output.log 打印到命令行的日志 80 | 81 | Args: 82 | experiment_name (str): Name of your experiment 83 | overwrite (bool): If True, it will delete the folder and create new one. 
84 | 85 | Examples: 86 | >>> from toolbox.exp.OutputSchema import OutputSchema 87 | >>> output = OutputSchema("output_name") 88 | >>> output.dump() 89 | 90 | """ 91 | 92 | def __init__(self, experiment_name: str, overwrite=False): 93 | self.name = experiment_name 94 | self.home_path = self.output_home_path() 95 | self.pathSchema = OutputPathSchema(self.home_path) 96 | if overwrite: 97 | self.pathSchema.clean() 98 | self.logger = Log(str(self.home_path / "output.log"), name_scope=experiment_name + "output") 99 | 100 | def output_home_path(self) -> Path: 101 | data_home_path: Path = Path('.') / 'output' 102 | data_home_path.mkdir(parents=True, exist_ok=True) 103 | data_home_path = data_home_path.resolve() 104 | return data_home_path / self.name 105 | 106 | def output_path_child(self, child_dir_name: str) -> Path: 107 | return self.home_path / child_dir_name 108 | 109 | def child_log(self, name: str, write_to_console=False) -> Log: 110 | return Log(str(self.pathSchema.log_path(name)), name_scope=self.name + "output-" + name, write_to_console=write_to_console) 111 | 112 | def dump(self): 113 | """ Displays all the metadata of the knowledge graph""" 114 | for key, value in self.__dict__.items(): 115 | self.logger.info("%s %s" % (key, value)) 116 | 117 | def __repr__(self): 118 | return f"{self.__class__.__name__}({self.home_path})" 119 | 120 | 121 | class Cleaner: 122 | def __init__(self, pathSchema: OutputPathSchema): 123 | self.pathSchema: OutputPathSchema = pathSchema 124 | 125 | def remove_non_best_checkpoint_and_model(self): 126 | def remove_non_best(dir_path: Path): 127 | dir_name = str(dir_path) 128 | import os 129 | print("In", dir_name) 130 | filenames = os.listdir(dir_name) 131 | to_delete_files = set([f for f in filenames if "best" not in f]) 132 | for filename in to_delete_files: 133 | print(" remove", filename) 134 | os.remove(str(dir_path / filename)) 135 | 136 | remove_non_best(self.pathSchema.dir_path_checkpoint) 137 | 
remove_non_best(self.pathSchema.dir_path_deploy) 138 | -------------------------------------------------------------------------------- /toolbox/web/log_app/static/css/table.css: -------------------------------------------------------------------------------- 1 | 2 | .word-wrap { 3 | word-break: break-all; 4 | max-width: 80px; 5 | word-wrap: break-word; 6 | overflow-wrap: break-word; 7 | min-width: 30px; 8 | } 9 | 10 | #columns_dialogue div { 11 | margin-top: 2px; 12 | padding: 1px 0px 1px 10px; 13 | margin-bottom: 2px; 14 | } 15 | 16 | #add_row_dialogue div { 17 | margin-top: 2px; 18 | padding: 1px 0px 1px 10px; 19 | margin-bottom: 2px; 20 | } 21 | 22 | 23 | 24 | /* 以下内容用于checkbox*/ 25 | 26 | /* 27 | P.S: if you like my content maybe you will become a donator and donate some money? That helps me to create new awesome materials. https://www.paypal.me/melnik909 28 | */ 29 | 30 | /* 31 | I've used nested span elements for creating an animation of square turn and creating an arrow animation. But if you know other a solution please email me 32 | melnik909@ya.ru 33 | */ 34 | 35 | /* 36 | ===== 37 | LEVEL 1. 
CORE STYLES 38 | ===== 39 | */ 40 | 41 | .toggle{ 42 | --uiToggleSize: var(--toggleSize, 20px); 43 | --uiToggleIndent: var(--toggleIndent, .4em); 44 | --uiToggleBorderWidth: var(--toggleBorderWidth, 2px); 45 | --uiToggleColor: var(--toggleColor, #000); 46 | --uiToggleDisabledColor: var(--toggleDisabledColor, #868e96); 47 | --uiToggleBgColor: var(--toggleBgColor, #fff); 48 | --uiToggleArrowWidth: var(--toggleArrowWidth, 2px); 49 | --uiToggleArrowColor: var(--toggleArrowColor, #fff); 50 | 51 | display: inline-block; 52 | position: relative; 53 | } 54 | 55 | .toggle__input{ 56 | position: absolute; 57 | left: -99999px; 58 | } 59 | 60 | .toggle__label{ 61 | display: inline-flex; 62 | cursor: pointer; 63 | min-height: var(--uiToggleSize); 64 | padding-left: calc(var(--uiToggleSize) + var(--uiToggleIndent)); 65 | } 66 | 67 | .toggle__label:before, .toggle__label:after{ 68 | content: ""; 69 | box-sizing: border-box; 70 | width: 1em; 71 | height: 1em; 72 | font-size: var(--uiToggleSize); 73 | 74 | position: absolute; 75 | left: 0; 76 | top: 0; 77 | } 78 | 79 | .toggle__label:before{ 80 | border: var(--uiToggleBorderWidth) solid var(--uiToggleColor); 81 | z-index: 2; 82 | } 83 | 84 | .toggle__input:disabled ~ .toggle__label:before{ 85 | border-color: var(--uiToggleDisabledColor); 86 | } 87 | 88 | /*.toggle__input:focus ~ .toggle__label:before{*/ 89 | /*box-shadow: 0 0 0 2px var(--uiToggleBgColor), 0 0 0px 4px var(--uiToggleColor);*/ 90 | /*}*/ 91 | 92 | /*.toggle__input:not(:disabled):checked:focus ~ .toggle__label:after{*/ 93 | /*box-shadow: 0 0 0 2px var(--uiToggleBgColor), 0 0 0px 4px var(--uiToggleColor);*/ 94 | /*}*/ 95 | 96 | .toggle__input:not(:disabled) ~ .toggle__label:after{ 97 | background-color: var(--uiToggleColor); 98 | opacity: 0; 99 | } 100 | 101 | .toggle__input:not(:disabled):checked ~ .toggle__label:after{ 102 | opacity: 1; 103 | } 104 | 105 | .toggle__text{ 106 | margin-top: auto; 107 | margin-bottom: auto; 108 | } 109 | 110 | /* 111 | The arrow size and 
position depends from sizes of square because I needed an arrow correct positioning from the top left corner of the element toggle 112 | */ 113 | 114 | .toggle__text:before{ 115 | content: ""; 116 | box-sizing: border-box; 117 | width: 0; 118 | height: 0; 119 | font-size: var(--uiToggleSize); 120 | 121 | border-left-width: 0; 122 | border-bottom-width: 0; 123 | border-left-style: solid; 124 | border-bottom-style: solid; 125 | border-color: var(--uiToggleArrowColor); 126 | 127 | position: absolute; 128 | top: .5428em; 129 | left: .2em; 130 | z-index: 3; 131 | 132 | transform-origin: left top; 133 | transform: rotate(-40deg) skew(10deg); 134 | } 135 | 136 | .toggle__input:not(:disabled):checked ~ .toggle__label .toggle__text:before{ 137 | width: .5em; 138 | height: .25em; 139 | border-left-width: var(--uiToggleArrowWidth); 140 | border-bottom-width: var(--uiToggleArrowWidth); 141 | will-change: width, height; 142 | transition: width .1s ease-out .2s, height .2s ease-out; 143 | } 144 | 145 | /* 146 | ===== 147 | LEVEL 2. 
PRESENTATION STYLES 148 | ===== 149 | */ 150 | 151 | /* 152 | The demo skin 153 | */ 154 | 155 | .toggle__label:before, .toggle__label:after{ 156 | border-radius: 2px; 157 | } 158 | 159 | /* 160 | The animation of switching states 161 | */ 162 | 163 | .toggle__input:not(:disabled) ~ .toggle__label:before, 164 | .toggle__input:not(:disabled) ~ .toggle__label:after{ 165 | opacity: 1; 166 | transform-origin: center center; 167 | will-change: transform; 168 | transition: transform .2s ease-out; 169 | } 170 | 171 | .toggle__input:not(:disabled) ~ .toggle__label:before{ 172 | transform: rotateY(0deg); 173 | transition-delay: .2s; 174 | } 175 | 176 | .toggle__input:not(:disabled) ~ .toggle__label:after{ 177 | transform: rotateY(90deg); 178 | } 179 | 180 | .toggle__input:not(:disabled):checked ~ .toggle__label:before{ 181 | transform: rotateY(-90deg); 182 | transition-delay: 0s; 183 | } 184 | 185 | .toggle__input:not(:disabled):checked ~ .toggle__label:after{ 186 | transform: rotateY(0deg); 187 | transition-delay: .2s; 188 | } 189 | 190 | .toggle__text:before{ 191 | opacity: 0; 192 | } 193 | 194 | .toggle__input:not(:disabled):checked ~ .toggle__label .toggle__text:before{ 195 | opacity: 1; 196 | transition: opacity .1s ease-out .3s, width .1s ease-out .5s, height .2s ease-out .3s; 197 | } 198 | 199 | /* 200 | ===== 201 | LEVEL 3. 
SETTINGS 202 | ===== 203 | */ 204 | 205 | .toggle{ 206 | --toggleColor: #008490; 207 | --toggleBgColor: #0d49b6; 208 | /*--toggleSize: 15px;*/ 209 | } 210 | 211 | 212 | 213 | -------------------------------------------------------------------------------- /toolbox/utils/Progbar.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/2/19 5 | @description: 进度条 6 | """ 7 | import sys 8 | import time 9 | from typing import Dict, Any, Union, Tuple, List 10 | import datetime 11 | import numpy as np 12 | 13 | 14 | class Progbar(object): 15 | """Progress bar class inspired by keras 进度条 16 | 17 | Examples: 18 | >>> from toolbox.utils.Progbar import Progbar 19 | >>> progbar = Progbar(max_step=100) 20 | >>> for i in range(100): 21 | >>> progbar.update(i, [("step", i), ("next", i+1)]) 22 | """ 23 | 24 | def __init__(self, max_step: int, width: int = 15, mode: str = "instant"): 25 | self.max_step: int = max_step 26 | self.width: int = width 27 | self.mode: str = mode 28 | self.last_width: int = 0 29 | 30 | self.sum_values: Dict[str, Any] = {} 31 | 32 | self.start: float = time.time() 33 | self.last_step: int = 0 34 | 35 | self.info: str = "" 36 | self.bar: str = "" 37 | 38 | def _update_values(self, curr_step: int, values: List[Tuple[str, Union[float, str, int]]]): 39 | for k, v in values: 40 | if k not in self.sum_values: 41 | if isinstance(v, float) or isinstance(v, int): 42 | if self.mode == "instant": 43 | self.sum_values[k] = v 44 | else: 45 | self.sum_values[k] = [v * (curr_step - self.last_step), curr_step - self.last_step] 46 | elif isinstance(v, str): 47 | self.sum_values[k] = (v + " ")[:20] 48 | else: 49 | self.sum_values[k] = (str(v) + " ")[:20] 50 | else: 51 | if isinstance(v, float) or isinstance(v, int): 52 | if self.mode == "instant": 53 | self.sum_values[k] = v 54 | else: 55 | self.sum_values[k][0] += v * (curr_step - self.last_step) 56 | 
self.sum_values[k][1] += (curr_step - self.last_step) 57 | elif isinstance(v, str): 58 | self.sum_values[k] = (v + " ")[:20] 59 | else: 60 | self.sum_values[k] = (str(v) + " ")[:20] 61 | 62 | def _write_bar(self, curr_step: int): 63 | last_width = self.last_width 64 | sys.stdout.write("\b" * last_width) 65 | sys.stdout.write("\r") 66 | 67 | numdigits = int(np.floor(np.log10(self.max_step))) + 1 68 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) 69 | bar = barstr % (curr_step, self.max_step) 70 | prog = float(curr_step) / self.max_step 71 | prog_width = int(self.width * prog) 72 | if prog_width > 0: 73 | bar += ('=' * (prog_width - 1)) 74 | if curr_step < self.max_step: 75 | bar += '>' 76 | else: 77 | bar += '=' 78 | bar += ('.' * (self.width - prog_width)) 79 | bar += ']' 80 | sys.stdout.write(bar) 81 | 82 | return bar 83 | 84 | def _get_eta(self, curr_step: int): 85 | now = time.time() 86 | if curr_step: 87 | time_per_step = (now - self.start) / curr_step 88 | else: 89 | time_per_step = 0 90 | eta = time_per_step * (self.max_step - curr_step) 91 | 92 | if curr_step < self.max_step: 93 | info = ' - ETA: %s' % str(datetime.timedelta(seconds=eta)) 94 | else: 95 | info = ' - %s' % str(datetime.timedelta(seconds=now - self.start)) 96 | 97 | return info 98 | 99 | def _get_values_sum(self): 100 | info = "" 101 | for name, value in self.sum_values.items(): 102 | if isinstance(value, str): 103 | info += ' - %s: %s' % (name, value) 104 | else: 105 | if self.mode == "instant": 106 | if isinstance(value, int): 107 | info += ' - %s: %d' % (name, value) 108 | else: 109 | info += ' - %s: %.6f' % (name, value) 110 | else: 111 | info += ' - %s: %.6f' % (name, value[0] / max(1, value[1])) 112 | return info 113 | 114 | def _write_info(self, curr_step: int): 115 | info = "" 116 | info += self._get_eta(curr_step) 117 | info += self._get_values_sum() 118 | 119 | sys.stdout.write(info) 120 | 121 | return info 122 | 123 | def _update_width(self, curr_step: int): 124 | curr_width = 
len(self.bar) + len(self.info) 125 | if curr_width < self.last_width: 126 | sys.stdout.write(" " * (self.last_width - curr_width)) 127 | 128 | if curr_step >= self.max_step: 129 | sys.stdout.write("\n") 130 | 131 | sys.stdout.flush() 132 | 133 | self.last_width = curr_width 134 | 135 | def update(self, curr_step: int, values: Union[Dict[str, Any], List[Tuple[str, Any]]]): 136 | """Updates the progress bar. 137 | The progress bar will display averages for these values. 138 | 139 | Args: 140 | values: Dict or List of tuples (name, value_for_last_step). 141 | """ 142 | if isinstance(values, dict): 143 | values = [(k, v) for k, v in values.items()] 144 | self._update_values(curr_step, values) 145 | self.bar = self._write_bar(curr_step) 146 | self.info = self._write_info(curr_step) 147 | self._update_width(curr_step) 148 | self.last_step = curr_step 149 | return values 150 | -------------------------------------------------------------------------------- /toolbox/utils/Framework.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import os 4 | import sys 5 | from functools import wraps 6 | 7 | import numpy as np 8 | from packaging import version 9 | 10 | if sys.version_info < (3, 8): 11 | import importlib_metadata 12 | else: 13 | import importlib.metadata as importlib_metadata 14 | 15 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 16 | 17 | ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} 18 | ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) 19 | 20 | USE_TF = os.environ.get("USE_TF", "AUTO").upper() 21 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() 22 | USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() 23 | 24 | if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: 25 | _torch_available = importlib.util.find_spec("torch") is not None 26 | if _torch_available: 27 | try: 28 | _torch_version = 
importlib_metadata.version("torch") 29 | logger.info(f"PyTorch version {_torch_version} available.") 30 | except importlib_metadata.PackageNotFoundError: 31 | _torch_available = False 32 | else: 33 | logger.info("Disabling PyTorch because USE_TF is set") 34 | _torch_available = False 35 | 36 | if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: 37 | _tf_available = importlib.util.find_spec("tensorflow") is not None 38 | if _tf_available: 39 | candidates = ( 40 | "tensorflow", 41 | "tensorflow-cpu", 42 | "tensorflow-gpu", 43 | "tf-nightly", 44 | "tf-nightly-cpu", 45 | "tf-nightly-gpu", 46 | "intel-tensorflow", 47 | ) 48 | _tf_version = None 49 | # For the metadata, we have to look for both tensorflow and tensorflow-cpu 50 | for pkg in candidates: 51 | try: 52 | _tf_version = importlib_metadata.version(pkg) 53 | break 54 | except importlib_metadata.PackageNotFoundError: 55 | pass 56 | _tf_available = _tf_version is not None 57 | if _tf_available: 58 | if version.parse(_tf_version) < version.parse("2"): 59 | logger.info(f"TensorFlow found but with version {_tf_version}. 
Transformers requires version 2 minimum.") 60 | _tf_available = False 61 | else: 62 | logger.info(f"TensorFlow version {_tf_version} available.") 63 | else: 64 | logger.info("Disabling Tensorflow because USE_TORCH is set") 65 | _tf_available = False 66 | 67 | _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False 68 | 69 | 70 | def is_offline_mode(): 71 | return _is_offline_mode 72 | 73 | 74 | def is_torch_available(): 75 | return _torch_available 76 | 77 | 78 | def is_torch_cuda_available(): 79 | if is_torch_available(): 80 | import torch 81 | 82 | return torch.cuda.is_available() 83 | else: 84 | return False 85 | 86 | 87 | def is_tf_available(): 88 | return _tf_available 89 | 90 | 91 | def is_deepspeed_available(): 92 | return importlib.util.find_spec("deepspeed") is not None 93 | 94 | 95 | def is_psutil_available(): 96 | return importlib.util.find_spec("psutil") is not None 97 | 98 | 99 | class cached_property(property): 100 | """ 101 | Descriptor that mimics @property but caches output in member variable. 102 | 103 | From tensorflow_datasets 104 | 105 | Built-in in functools from Python 3.8. 106 | """ 107 | 108 | def __get__(self, obj, objtype=None): 109 | # See docs.python.org/3/howto/descriptor.html#properties 110 | if obj is None: 111 | return self 112 | if self.fget is None: 113 | raise AttributeError("unreadable attribute") 114 | attr = "__cached_" + self.fget.__name__ 115 | cached = getattr(obj, attr, None) 116 | if cached is None: 117 | cached = self.fget(obj) 118 | setattr(obj, attr, cached) 119 | return cached 120 | 121 | 122 | def torch_required(func): 123 | # Chose a different decorator name than in tests so it's clear they are not the same. 
124 | @wraps(func) 125 | def wrapper(*args, **kwargs): 126 | if is_torch_available(): 127 | return func(*args, **kwargs) 128 | else: 129 | raise ImportError(f"Method `{func.__name__}` requires PyTorch.") 130 | 131 | return wrapper 132 | 133 | 134 | def tf_required(func): 135 | # Chose a different decorator name than in tests so it's clear they are not the same. 136 | @wraps(func) 137 | def wrapper(*args, **kwargs): 138 | if is_tf_available(): 139 | return func(*args, **kwargs) 140 | else: 141 | raise ImportError(f"Method `{func.__name__}` requires TF.") 142 | 143 | return wrapper 144 | 145 | 146 | def is_tensor(x): 147 | """ Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`. """ 148 | if is_torch_available(): 149 | import torch 150 | 151 | if isinstance(x, torch.Tensor): 152 | return True 153 | if is_tf_available(): 154 | import tensorflow as tf 155 | 156 | if isinstance(x, tf.Tensor): 157 | return True 158 | return isinstance(x, np.ndarray) 159 | 160 | 161 | def _is_numpy(x): 162 | return isinstance(x, np.ndarray) 163 | 164 | 165 | def _is_torch(x): 166 | import torch 167 | 168 | return isinstance(x, torch.Tensor) 169 | 170 | 171 | def _is_torch_device(x): 172 | import torch 173 | 174 | return isinstance(x, torch.device) 175 | 176 | 177 | def _is_tensorflow(x): 178 | import tensorflow as tf 179 | 180 | return isinstance(x, tf.Tensor) 181 | -------------------------------------------------------------------------------- /toolbox/nn/MobiusE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from toolbox.nn.ComplexEmbedding import ComplexEmbedding, ComplexDropout, ComplexScoringAll, ComplexBatchNorm1d, ComplexMult, ComplexAdd, ComplexDiv, ComplexAlign 6 | from toolbox.nn.MobiusEmbedding import MobiusEmbedding, MobiusDropout, MobiusBatchNorm1d 7 | from toolbox.nn.Regularizer import N3 8 | 9 | 10 | def 
mobius_mul_with_unit_norm(Q_1, Q_2): 11 | a_h = Q_1 # = {a_h + b_h i + c_h j + d_h k : a_r, b_r, c_r, d_r \in R^k} 12 | a_r, b_r, c_r, d_r = Q_2 # = {a_r + b_r i + c_r j + d_r k : a_r, b_r, c_r, d_r \in R^k} 13 | 14 | # Normalize the relation to eliminate the scaling effect 15 | denominator = torch.sqrt(a_r ** 2 + b_r ** 2 + c_r ** 2 + d_r ** 2) 16 | p = a_r / denominator 17 | q = b_r / denominator 18 | u = c_r / denominator 19 | v = d_r / denominator 20 | # Q'=E Hamilton product R 21 | h_r = (a_h * p + q) / (a_h * u + v) 22 | return h_r 23 | 24 | 25 | def mobius_mul(Q_1, Q_2): 26 | a_h = Q_1 # = {a_h : a_r, b_r, c_r, d_r \in R^k} 27 | a_r, b_r, c_r, d_r = Q_2 # = {a_r + b_r i + c_r j + d_r k : a_r, b_r, c_r, d_r \in R^k} 28 | h_r = (a_h * a_r + b_r) / (a_h * c_r + d_r) 29 | return h_r 30 | 31 | 32 | class MobiusMult(nn.Module): 33 | """ 34 | a_r * h + b_r 35 | h * r = ------------- 36 | c_r * h + d_r 37 | h in CP^d 38 | r_a, r_b, r_c, r_d in C^d 39 | """ 40 | 41 | def __init__(self, norm_flag=False): 42 | super(MobiusMult, self).__init__() 43 | self.flag_hamilton_mul_norm = norm_flag 44 | self.complex_mul = ComplexMult(norm_flag) 45 | self.complex_add = ComplexAdd() 46 | self.complex_div = ComplexDiv() 47 | 48 | def forward(self, h, r): 49 | r_a, r_b, r_c, r_d = r 50 | hr_top = self.complex_add(self.complex_mul(h, r_a), r_b) 51 | hr_bottom = self.complex_add(self.complex_mul(h, r_c), r_d) 52 | h_r = self.complex_div(hr_top, hr_bottom) 53 | return h_r 54 | 55 | 56 | class MobiusE(nn.Module): 57 | 58 | def __init__(self, 59 | num_entities, num_relations, 60 | embedding_dim, 61 | norm_flag=False, input_dropout=0.2, hidden_dropout=0.3, regularization_weight=0.1): 62 | super(MobiusE, self).__init__() 63 | self.embedding_dim = embedding_dim 64 | self.num_entities = num_entities 65 | self.num_relations = num_relations 66 | self.loss = nn.BCELoss() 67 | self.flag_hamilton_mul_norm = norm_flag 68 | self.E = ComplexEmbedding(self.num_entities, self.embedding_dim, 2) # a + 
bi 69 | self.R = MobiusEmbedding(self.num_relations, self.embedding_dim, 4) # 4 numbers: a + bi 70 | self.E_dropout = ComplexDropout([input_dropout, input_dropout]) 71 | self.R_dropout = MobiusDropout([[input_dropout, input_dropout]] * 4) 72 | self.hidden_dp = ComplexDropout([hidden_dropout, hidden_dropout]) 73 | self.E_bn = ComplexBatchNorm1d(self.embedding_dim, 2) 74 | self.R_bn = MobiusBatchNorm1d(self.embedding_dim, 4) 75 | self.b = nn.Parameter(torch.zeros(num_entities)) 76 | 77 | self.mul = MobiusMult(norm_flag) 78 | self.scoring_all = ComplexScoringAll() 79 | self.align = ComplexAlign() 80 | self.regularizer = N3(regularization_weight) 81 | 82 | def forward(self, h_idx, r_idx): 83 | h_idx = h_idx.view(-1) 84 | r_idx = r_idx.view(-1) 85 | return self.forward_head_batch(h_idx, r_idx) 86 | 87 | def forward_head_batch(self, h_idx, r_idx): 88 | """ 89 | Completed. 90 | Given a head entity and a relation (h,r), we compute scores for all possible triples,i.e., 91 | [score(h,r,x)|x \in Entities] => [0.0,0.1,...,0.8], shape=> (1, |Entities|) 92 | Given a batch of head entities and relations => shape (size of batch,| Entities|) 93 | """ 94 | h = self.E(h_idx) 95 | r = self.R(r_idx) 96 | 97 | t = self.mul(h, r) 98 | 99 | if self.flag_hamilton_mul_norm: 100 | score_a, score_b = self.scoring_all(t, self.E.get_embeddings()) # a + b i 101 | else: 102 | score_a, score_b = self.scoring_all(self.E_dropout(t), self.E_dropout(self.E_bn(self.E.get_embeddings()))) 103 | x = score_a + score_b 104 | x = x + self.b.expand_as(x) 105 | x = torch.sigmoid(x) 106 | 107 | return x 108 | 109 | def regular_loss(self, h_idx, r_idx): 110 | h = self.E(h_idx) 111 | r = self.R(r_idx) 112 | h_a, h_a_i = h 113 | (r_a, r_a_i), (r_b, r_b_i), (r_c, r_c_i), (r_d, r_d_i) = r 114 | factors = ( 115 | torch.sqrt(h_a ** 2 + h_a_i ** 2), 116 | torch.sqrt(r_a ** 2 + r_a_i ** 2 + r_c ** 2 + r_c_i ** 2 + r_b ** 2 + r_b_i ** 2 + r_d ** 2 + r_d_i ** 2), 117 | ) 118 | regular_loss = self.regularizer(factors) 119 
| return regular_loss 120 | 121 | def reverse_loss(self, h_idx, r_idx, max_relation_idx): 122 | h = self.E(h_idx) 123 | h_a, h_b = h 124 | h = (h_a.detach(), h_b.detach()) 125 | 126 | r = self.R(r_idx) 127 | reverse_rel_idx = (r_idx + max_relation_idx) % (2 * max_relation_idx) 128 | 129 | t = self.mul(h, r) 130 | reverse_r = self.R(reverse_rel_idx) 131 | reverse_t = self.mul(t, reverse_r) 132 | reverse_a, reverse_b = self.align(reverse_t, h) # a + b i 133 | reverse_score = reverse_a + reverse_b 134 | reverse_score = torch.mean(F.relu(reverse_score)) 135 | 136 | return reverse_score 137 | 138 | def init(self): 139 | self.E.init() 140 | self.R.init() 141 | 142 | # 143 | # TRAIN: {'MRR': 0.3367987126111984, 'hits@[1,3,10]': tensor([0.2347, 0.3808, 0.5365])} 144 | # VALID : {'MRR': 0.28631188720464706, 'hits@[1,3,10]': tensor([0.2089, 0.3107, 0.4430])} 145 | # 146 | # 147 | # TEST : ({'rhs': 0.38042595982551575, 'lhs': 0.1826966404914856}, {'rhs': tensor([0.2927, 0.4182, 0.5544]), 'lhs': tensor([0.1148, 0.1959, 0.3186])}) 148 | -------------------------------------------------------------------------------- /toolbox/nn/TuckerMobiusE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from toolbox.nn.ComplexEmbedding import ComplexEmbedding, ComplexDropout, ComplexBatchNorm1d, ComplexMult, ComplexAdd, ComplexDiv, ComplexAlign 6 | from toolbox.nn.MobiusEmbedding import MobiusEmbedding, MobiusDropout, MobiusBatchNorm1d 7 | from toolbox.nn.Regularizer import N3 8 | from toolbox.nn.TuckER import CoreTuckER 9 | 10 | 11 | def mobius_mul_with_unit_norm(Q_1, Q_2): 12 | a_h = Q_1 # = {a_h + b_h i + c_h j + d_h k : a_r, b_r, c_r, d_r \in R^k} 13 | a_r, b_r, c_r, d_r = Q_2 # = {a_r + b_r i + c_r j + d_r k : a_r, b_r, c_r, d_r \in R^k} 14 | 15 | # Normalize the relation to eliminate the scaling effect 16 | denominator = torch.sqrt(a_r ** 2 + b_r ** 2 + c_r ** 2 + 
d_r ** 2) 17 | p = a_r / denominator 18 | q = b_r / denominator 19 | u = c_r / denominator 20 | v = d_r / denominator 21 | # Q'=E Hamilton product R 22 | h_r = (a_h * p + q) / (a_h * u + v) 23 | return h_r 24 | 25 | 26 | def mobius_mul(Q_1, Q_2): 27 | a_h = Q_1 # = {a_h : a_r, b_r, c_r, d_r \in R^k} 28 | a_r, b_r, c_r, d_r = Q_2 # = {a_r + b_r i + c_r j + d_r k : a_r, b_r, c_r, d_r \in R^k} 29 | h_r = (a_h * a_r + b_r) / (a_h * c_r + d_r) 30 | return h_r 31 | 32 | 33 | class MobiusMult(nn.Module): 34 | """ 35 | a_r * h + b_r 36 | h * r = ------------- 37 | c_r * h + d_r 38 | h in CP^d 39 | r_a, r_b, r_c, r_d in C^d 40 | """ 41 | 42 | def __init__(self, norm_flag=False): 43 | super(MobiusMult, self).__init__() 44 | self.flag_hamilton_mul_norm = norm_flag 45 | self.complex_mul = ComplexMult(norm_flag) 46 | self.complex_add = ComplexAdd() 47 | self.complex_div = ComplexDiv() 48 | 49 | def forward(self, h, r): 50 | r_a, r_b, r_c, r_d = r 51 | hr_top = self.complex_add(self.complex_mul(h, r_a), r_b) 52 | hr_bottom = self.complex_add(self.complex_mul(h, r_c), r_d) 53 | h_r = self.complex_div(hr_top, hr_bottom) 54 | return h_r 55 | 56 | 57 | class BatchComplexScoringAll(nn.Module): 58 | def forward(self, complex_numbers, embeddings): 59 | out = [] 60 | for idx, complex_number in enumerate(list(complex_numbers)): 61 | ans = torch.bmm(complex_number.unsqueeze(dim=1), embeddings[idx]).squeeze(dim=1) 62 | out.append(ans) 63 | return tuple(out) 64 | 65 | 66 | class TuckerMobiusE(nn.Module): 67 | 68 | def __init__(self, 69 | num_entities, num_relations, 70 | embedding_dim, 71 | norm_flag=False, input_dropout=0.2, hidden_dropout=0.3, regularization_weight=0.1): 72 | super(TuckerMobiusE, self).__init__() 73 | self.embedding_dim = embedding_dim 74 | self.num_entities = num_entities 75 | self.num_relations = num_relations 76 | self.loss = nn.BCELoss() 77 | self.flag_hamilton_mul_norm = norm_flag 78 | self.E = ComplexEmbedding(self.num_entities, self.embedding_dim, 2) # a + bi 79 | 
self.R = MobiusEmbedding(self.num_relations, self.embedding_dim, 4) # 4 numbers: a + bi 80 | self.R2 = nn.Embedding(num_relations, embedding_dim) 81 | self.real_tucker = CoreTuckER(embedding_dim, embedding_dim, hidden_dropout) 82 | self.img_tucker = CoreTuckER(embedding_dim, embedding_dim, hidden_dropout) 83 | self.E_dropout = ComplexDropout([input_dropout, input_dropout]) 84 | self.R_dropout = MobiusDropout([[input_dropout, input_dropout]] * 4) 85 | self.hidden_dp = ComplexDropout([hidden_dropout, hidden_dropout]) 86 | self.E_bn = ComplexBatchNorm1d(self.embedding_dim, 2) 87 | self.R_bn = MobiusBatchNorm1d(self.embedding_dim, 4) 88 | 89 | self.mul = MobiusMult(norm_flag) 90 | self.scoring_all = BatchComplexScoringAll() 91 | self.align = ComplexAlign() 92 | self.regularizer = N3(regularization_weight) 93 | 94 | def forward(self, h_idx, r_idx): 95 | h_idx = h_idx.view(-1) 96 | r_idx = r_idx.view(-1) 97 | return self.forward_head_batch(h_idx, r_idx) 98 | 99 | def forward_head_batch(self, h_idx, r_idx): 100 | """ 101 | Completed. 
102 | Given a head entity and a relation (h,r), we compute scores for all possible triples,i.e., 103 | [score(h,r,x)|x \in Entities] => [0.0,0.1,...,0.8], shape=> (1, |Entities|) 104 | Given a batch of head entities and relations => shape (size of batch,| Entities|) 105 | """ 106 | h = self.E(h_idx) 107 | r = self.R(r_idx) 108 | r2 = self.R2(r_idx) 109 | 110 | h_a, h_b = h 111 | t_a = self.real_tucker(h_a, r2) 112 | t_b = self.img_tucker(h_b, r2) 113 | h = (t_a, t_b) 114 | 115 | t = self.mul(h, r) 116 | # (re_relation_a, im_relation_a), (re_relation_c, im_relation_c), (re_relation_b, im_relation_b), (re_relation_d, im_relation_d) = r 117 | 118 | E_a, E_b = self.E.get_embeddings() 119 | E_a = self.real_tucker.w(E_a, r2) 120 | E_b = self.img_tucker.w(E_b, r2) 121 | E = (E_a, E_b) 122 | # E = self.E.get_embeddings() 123 | if self.flag_hamilton_mul_norm: 124 | score_a, score_b = self.scoring_all(t, E) # a + b i 125 | else: 126 | score_a, score_b = self.scoring_all(self.E_dropout(t), self.E_dropout(self.E_bn(E))) 127 | score = score_a + score_b 128 | score = torch.sigmoid(score) 129 | 130 | return score 131 | 132 | def reverse_loss(self, h_idx, r_idx, max_relation_idx): 133 | h = self.E(h_idx) 134 | h_a, h_b = h 135 | h = (h_a.detach(), h_b.detach()) 136 | 137 | r = self.R(r_idx) 138 | reverse_rel_idx = (r_idx + max_relation_idx) % (2 * max_relation_idx) 139 | 140 | t = self.mul(h, r) 141 | reverse_r = self.R(reverse_rel_idx) 142 | reverse_t = self.mul(t, reverse_r) 143 | reverse_a, reverse_b = self.align(reverse_t, h) # a + b i 144 | reverse_score = reverse_a + reverse_b 145 | reverse_score = torch.mean(F.relu(reverse_score)) 146 | 147 | return reverse_score 148 | 149 | def init(self): 150 | self.E.init() 151 | self.R.init() 152 | --------------------------------------------------------------------------------