├── toolbox ├── .gitignore ├── utils │ ├── Optimizer.py │ ├── __init__.py │ ├── SpeedUp.py │ ├── RandomSeeds.py │ ├── Time.py │ ├── Embed.py │ ├── DefaultDict.py │ ├── VisualizeStore.py │ ├── LaTeX.py │ ├── Download.py │ ├── Log.py │ ├── AutoML.py │ ├── Progbar.py │ └── Framework.py ├── web │ ├── log_app │ │ ├── server │ │ │ ├── __init__.py │ │ │ ├── app_utils.py │ │ │ └── data_container.py │ │ ├── README.md │ │ ├── static │ │ │ ├── img │ │ │ │ ├── chart.ico │ │ │ │ ├── table.ico │ │ │ │ └── loading.gif │ │ │ ├── fonts │ │ │ │ ├── glyphicons-halflings-regular.eot │ │ │ │ ├── glyphicons-halflings-regular.ttf │ │ │ │ ├── glyphicons-halflings-regular.woff │ │ │ │ └── glyphicons-halflings-regular.woff2 │ │ │ ├── css │ │ │ │ ├── bootstrap-table-filter-control.css │ │ │ │ ├── bootstrap-table-reorder-rows.css │ │ │ │ └── table.css │ │ │ └── js │ │ │ │ ├── chart.js │ │ │ │ ├── bootstrap-table-reorder-rows.js │ │ │ │ └── utils.js │ │ ├── __init__.py │ │ ├── templates │ │ │ ├── folder_img.html │ │ │ ├── folder.html │ │ │ └── multi_chart.html │ │ ├── main.py │ │ ├── output │ │ │ └── default.cfg │ │ ├── line_app.py │ │ ├── multi_char_app.py │ │ ├── app.py │ │ └── folder_app.py │ ├── __init__.py │ └── log_server │ │ ├── README.md │ │ ├── __init__.py │ │ └── main.py ├── nn │ ├── functional │ │ ├── __init__.py │ │ └── complex.py │ ├── __init__.py │ ├── Flatten.py │ ├── README.md │ ├── Highway.py │ ├── DistMult.py │ ├── RESCAL.py │ ├── GCN.py │ ├── GAT.py │ ├── Complex.py │ ├── LorentzE.py │ ├── TuckERT.py │ ├── MobiusEmbedding.py │ ├── ConvE.py │ ├── TransE.py │ ├── TuckER.py │ ├── TuckERTNT.py │ ├── TuckERTTR.py │ ├── Rotate3D.py │ ├── TuckERTTT.py │ ├── CoPER.py │ ├── TuckERCPD.py │ ├── BlaschkE.py │ ├── Regularizer.py │ ├── ParamE.py │ ├── ComplexTuckER.py │ ├── MobiusE.py │ └── TuckerMobiusE.py ├── README_en.md ├── exp │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ └── JsonConfig.py │ ├── Notebook.py │ ├── classic │ │ ├── experiment_demo.py │ │ └── train_CartPole.py │ 
├── DistributeSchema.py │ ├── Experiment.py │ └── OutputSchema.py ├── evaluate │ ├── __init__.py │ ├── EntityAlignment.py │ ├── GatherMetric.py │ └── Leaderboard.py ├── data │ ├── __init__.py │ ├── README.md │ ├── TripleDataset.py │ ├── ComplementaryDataset.py │ ├── FixWindowNegSamplingDataset.py │ ├── LinkPredictDataset.py │ └── ScoringAllDataset.py ├── __init__.py ├── game │ ├── __init__.py │ ├── SinglePlayerGame.py │ ├── Player.py │ └── MultiPlayerGame.py ├── optim │ ├── __init__.py │ └── lr_scheduler.py ├── requirements-web.txt ├── requirements.txt ├── CheetSheet.md ├── README.md ├── cli │ └── clean_output.py └── README_template_en.md ├── vis.png ├── requirements.txt └── .gitignore /toolbox/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /toolbox/utils/Optimizer.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /toolbox/web/log_app/server/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /toolbox/nn/functional/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["complex"] 2 | -------------------------------------------------------------------------------- /toolbox/README_en.md: -------------------------------------------------------------------------------- 1 | # KGE Toolbox 2 | 3 | 4 | 5 | ## CheatSheet -------------------------------------------------------------------------------- /toolbox/exp/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Experiment", "OutputSchema"] 2 | -------------------------------------------------------------------------------- 
/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/vis.png -------------------------------------------------------------------------------- /toolbox/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["EntityAlignment", "LinkPredict", "Evaluate"] 2 | -------------------------------------------------------------------------------- /toolbox/data/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["DataSchema", "DatasetSchema", "FixWindowNegSamplingDataset"] 2 | -------------------------------------------------------------------------------- /toolbox/utils/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Log", "Progbar", "RandomSeeds", "ModelParamStore", "DefaultDict", "VisualizeStore.py"] 2 | -------------------------------------------------------------------------------- /toolbox/web/log_app/README.md: -------------------------------------------------------------------------------- 1 | # Get Started 2 | 3 | ```shell 4 | python -m toolbox.web.log_app.main --log_dir="output" 5 | ``` 6 | -------------------------------------------------------------------------------- /toolbox/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/12/9 5 | @description: null 6 | """ 7 | -------------------------------------------------------------------------------- /toolbox/web/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/12/9 5 | @description: null 6 | """ 7 | -------------------------------------------------------------------------------- 
/toolbox/web/log_server/README.md: -------------------------------------------------------------------------------- 1 | # Get Started 2 | 3 | ```shell 4 | python -m toolbox.web.log_server.main --log_dir="output" 5 | ``` 6 | -------------------------------------------------------------------------------- /toolbox/game/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/9/26 5 | @description: null 6 | """ 7 | -------------------------------------------------------------------------------- /toolbox/optim/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/11/7 5 | @description: null 6 | """ 7 | -------------------------------------------------------------------------------- /toolbox/requirements-web.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | pytorch 3 | numpy 4 | scikit-learn 5 | scipy 6 | flask 7 | Flask-Cors 8 | jieba 9 | rouge 10 | -------------------------------------------------------------------------------- /toolbox/web/log_app/static/img/chart.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/toolbox/web/log_app/static/img/chart.ico -------------------------------------------------------------------------------- /toolbox/web/log_app/static/img/table.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/toolbox/web/log_app/static/img/table.ico -------------------------------------------------------------------------------- /toolbox/web/log_app/static/img/loading.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LinXueyuanStdio/KGE-toolbox/HEAD/toolbox/web/log_app/static/img/loading.gif -------------------------------------------------------------------------------- /toolbox/web/log_server/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/12/9 5 | @description: null 6 | """ 7 | -------------------------------------------------------------------------------- /toolbox/exp/config/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/9/27 5 | @description: 使用配置文件来管理脚本的各种超参数 6 | """ 7 | -------------------------------------------------------------------------------- /toolbox/requirements.txt: -------------------------------------------------------------------------------- 1 | # pytorch 请自己根据cuda版本安装 2 | tqdm 3 | numpy 4 | scikit-learn 5 | scipy 6 | click 7 | tensorboardX 8 | pandas 9 | pathlib 10 | pyyaml -------------------------------------------------------------------------------- /toolbox/data/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## TODO 3 | 1. Maybe we can create data dir and hold some classic datasets to prevent downloading from the source to accelerate coding process. 
class Flatten(nn.Module):
    """Collapse every dimension after the first one into a single axis."""

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (for example the result of ``transpose`` or ``permute``) are accepted
        instead of raising a ``RuntimeError``. For inputs that ``view`` could
        already handle, the result is the same.
        """
        n = x.size(0)
        return x.reshape(n, -1)
def speed_up():
    """Release unoccupied cached GPU memory held by PyTorch's CUDA allocator.

    The previous implementation called ``torch.cuda.emptyCache()``, which is
    not a PyTorch API and raised ``AttributeError`` on every call; the public
    function is ``torch.cuda.empty_cache()``.

    Returns:
        None.
    """
    torch.cuda.empty_cache()
choose faster optimizer: AdamW 11 | """ 12 | print(info) 13 | -------------------------------------------------------------------------------- /toolbox/web/log_app/static/css/bootstrap-table-filter-control.css: -------------------------------------------------------------------------------- 1 | /** 2 | * @author: Dennis Hernández 3 | * @webSite: http://djhvscf.github.io/Blog 4 | * @version: v2.1.1 5 | */ 6 | 7 | .no-filter-control { 8 | height: 34px; 9 | } 10 | 11 | .filter-control { 12 | margin: 0 2px 2px 2px; 13 | } -------------------------------------------------------------------------------- /toolbox/web/log_app/templates/folder_img.html: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | 5 |Image for {{ img_path }}
10 | 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.5.0 2 | sphinx>=2.1.2 3 | networkx>=2.2 4 | setuptools>=40.8.0 5 | matplotlib>=3.0.3 6 | numpy>=1.16.2 7 | seaborn>=0.9.0 8 | scikit_learn>=0.20.3 9 | hyperopt>=0.2.1 10 | pathlib>=1.0.1 11 | numpydoc>=0.9.1 12 | sphinx-gallery>=0.3.1 13 | sphinx-rtd-theme>=0.4.3 14 | pytest>=3.6 15 | pyyaml>=5.3.1 -------------------------------------------------------------------------------- /toolbox/nn/README.md: -------------------------------------------------------------------------------- 1 | # nn 2 | 3 | This package contains classic knowledge graph embedding models. 4 | 5 | ## Symbols 6 | 7 | The comments in the source code use these symbols: 8 | 9 | ``` 10 | T: number of triples 11 | B: batch size 12 | d: dimension 13 | d_e: dimension of entity embedding 14 | d_r: dimension of relation embedding 15 | ``` 16 | -------------------------------------------------------------------------------- /toolbox/nn/Highway.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Highway(nn.Module): 6 | def __init__(self, x_hidden): 7 | super(Highway, self).__init__() 8 | self.lin = nn.Linear(x_hidden, x_hidden) 9 | 10 | def forward(self, x1, x2): 11 | gate = torch.sigmoid(self.lin(x1)) 12 | x = torch.mul(gate, x2) + torch.mul(1 - gate, x1) 13 | return x 14 | -------------------------------------------------------------------------------- /toolbox/game/SinglePlayerGame.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/9/26 5 | @description: null 6 | """ 7 | 8 | 9 | class SinglePlayerGameEnv: 10 | def __init__(self): 11 | pass 12 | 13 | def reset(self): 14 | # return state 15 | pass 16 | 
def set_seeds(seed=1234, deterministic=None):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and, when CUDA is available,
            ``torch.cuda.manual_seed_all``.
        deterministic: Whether to switch cuDNN to deterministic mode
            (``cudnn.deterministic = True`` and ``cudnn.benchmark = False``).
            ``None`` (the default) keeps the historical rule of enabling it
            only when ``seed == 0``. ``False`` leaves the cuDNN flags
            untouched; it does not reset them.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if deterministic is None:
        # Legacy behaviour: only seed 0 turned on deterministic cuDNN.
        deterministic = seed == 0
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
经典KGE模型的PyTorch版复现 20 | 21 | 22 | ## CheatSheet 23 | 24 | CUDA_VISIBLE_DEVICES=3 25 | 26 | 内存分析,打印每一行代码执行前后的内存变化 27 | ```shell 28 | pip install memory_profiler psutil 29 | python -m memory_profiler main.py 30 | ``` -------------------------------------------------------------------------------- /toolbox/game/Player.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2021/9/26 5 | @description: null 6 | """ 7 | 8 | 9 | class Player: 10 | def __init__(self): 11 | self.player_index = None 12 | 13 | def get_player_index(self): 14 | # for multi player game 15 | return self.player_index 16 | 17 | def set_player_index(self, player_index): 18 | # for multi player game 19 | self.player_index = player_index 20 | 21 | def get_action(self, state, **kwargs): 22 | raise NotImplementedError 23 | 24 | def reset(self): 25 | # for self-play player 26 | pass 27 | -------------------------------------------------------------------------------- /toolbox/cli/clean_output.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: linxy59@mail2.sysu.edu.cn 4 | @date: 2022/4/25 5 | @description: null 6 | """ 7 | import os 8 | 9 | import click 10 | 11 | from toolbox.exp.OutputSchema import OutputSchema, Cleaner 12 | 13 | 14 | @click.command() 15 | @click.option("--output_dir", type=str, default="output", help="Which dir to clean") 16 | def main(output_dir): 17 | dirs = os.listdir(output_dir) 18 | for d in dirs: 19 | output = OutputSchema(str(d)) 20 | cleaner = Cleaner(output.pathSchema) 21 | cleaner.remove_non_best_checkpoint_and_model() 22 | 23 | 24 | if __name__ == '__main__': 25 | main() 26 | -------------------------------------------------------------------------------- /toolbox/data/TripleDataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: lxy 3 | @email: 
linxy59@mail2.sysu.edu.cn 4 | @date: 2021/10/30 5 | @description: null 6 | """ 7 | from typing import List, Tuple 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | 13 | class TripleDataset(Dataset): 14 | def __init__(self, triples_ids: List[Tuple[int, int, int]]): 15 | self.triples_ids = triples_ids 16 | 17 | def __len__(self): 18 | return len(self.triples_ids) 19 | 20 | def __getitem__(self, idx): 21 | h, r, t = self.triples_ids[idx] 22 | h = torch.LongTensor([h]) 23 | r = torch.LongTensor([r]) 24 | t = torch.LongTensor([t]) 25 | return h, r, t 26 | -------------------------------------------------------------------------------- /toolbox/exp/Notebook.py: -------------------------------------------------------------------------------- 1 | from toolbox.exp.OutputSchema import OutputSchema 2 | from toolbox.utils.ModelParamStore import ModelParamStoreSchema 3 | from toolbox.utils.VisualizeStore import VisualizeStoreSchema 4 | 5 | 6 | class Self: 7 | pass 8 | 9 | 10 | def init_by_output(self: object, output: OutputSchema): 11 | self.debug = output.logger.debug 12 | self.log = output.logger.info 13 | self.warn = output.logger.warn 14 | self.error = output.logger.error 15 | self.critical = output.logger.critical 16 | self.success = output.logger.success 17 | self.fail = output.logger.failed 18 | self.vis = VisualizeStoreSchema(str(output.pathSchema.dir_path_visualize)) 19 | self.store = ModelParamStoreSchema(output.pathSchema) 20 | -------------------------------------------------------------------------------- /toolbox/utils/Time.py: -------------------------------------------------------------------------------- 1 | from timeit import default_timer as timer 2 | from functools import wraps 3 | from time import time 4 | 5 | 6 | class benchmark(object): 7 | 8 | def __init__(self, msg, fmt="%0.3g"): 9 | self.msg = msg 10 | self.fmt = fmt 11 | 12 | def __enter__(self): 13 | self.start = timer() 14 | return self 15 | 16 | def __exit__(self, *args): 17 
def timing(f):
    """Decorator that prints the duration of every call to ``f``.

    The elapsed time is measured with ``timer`` (``timeit.default_timer``),
    the same clock the ``benchmark`` context manager in this module uses.
    Unlike the wall-clock ``time()``, it is not affected by system clock
    adjustments while the call is running.

    Args:
        f: The callable to wrap.

    Returns:
        A wrapper with ``f``'s metadata that forwards all arguments, prints
        one line with the function name, arguments and elapsed seconds, and
        returns ``f``'s result unchanged.
    """
    @wraps(f)
    def wrap(*args, **kw):
        start = timer()
        result = f(*args, **kw)
        elapsed = timer() - start
        print(f'func:{f.__name__!r} args:[{args!r}, {kw!r}] took: {elapsed:2.4f} sec')
        return result
    return wrap
class AverageMeter(object):
    """Store the latest value of a metric together with its running average."""

    def __init__(self, name, fmt=':f'):
        """Create an empty meter.

        Args:
            name: Label printed by ``__str__``.
            fmt: Format spec (including the leading ``:``) applied to both the
                current value and the average in ``__str__``.
        """
        self.name = name
        self.fmt = fmt
        # reset() is the single place that defines the counters.
        self.reset()

    def reset(self):
        """Zero the current value, average, weighted sum and sample count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n`` samples.

        While the accumulated count is zero (e.g. ``n=0`` on a fresh meter)
        the average stays at 0 instead of raising ``ZeroDivisionError``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``"<name> <val> (<avg>)"`` using the configured format."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
def get_vec3(entities_embedding, orth: torch.Tensor, id_list: List[int], device="cuda", embedding_dim=200):
    """Select embedding rows by id, apply ``orth`` and return a NumPy array.

    Args:
        entities_embedding: Tensor indexed along dim 0 with ``index_select``.
            Assumed to already live on ``device``; this function does not
            move it.
        orth: Matrix applied to the selected rows as ``rows @ orth.T``.
        id_list: Row indices to select.
        device: Device the index tensor is created on.
        embedding_dim: Width the selected rows are reshaped to before the
            matrix product. Defaults to 200, the value that used to be
            hard-coded, so existing callers are unaffected.

    Returns:
        ``numpy.ndarray`` with one transformed row per selected id.
    """
    all_entity_ids = torch.LongTensor(id_list).view(-1).to(device)
    all_entity_vec = torch.index_select(
        entities_embedding,
        dim=0,
        index=all_entity_ids
    ).view(-1, embedding_dim)
    all_entity_vec = all_entity_vec.matmul(orth.transpose(0, 1))
    return all_entity_vec.cpu().detach().numpy()
def get_scheduler(optimizer, lr_policy="exp", epoch_count=5, lr_decay_iters=25, niter=100, niter_decay=100, ):
    """Return a learning-rate scheduler for ``optimizer``.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        lr_policy: Scheduler name, one of
            ``linear`` | ``step`` | ``plateau`` | ``cosine`` | ``exp``.
        epoch_count: Epoch offset used by the ``linear`` policy.
        lr_decay_iters: ``step_size`` of the ``step`` policy.
        niter: Start of the decay for ``linear``; ``T_max`` for ``cosine``.
        niter_decay: Length of the decay phase of the ``linear`` policy.

    Returns:
        A ``torch.optim.lr_scheduler`` instance.

    Raises:
        NotImplementedError: If ``lr_policy`` is not a supported name.
            (Previously the exception object was returned instead of raised,
            so callers silently received a non-scheduler.)
    """
    if lr_policy == 'linear':
        def lambda_rule(epoch):
            # Factor is 1.0 until epoch + epoch_count exceeds niter, then
            # decreases linearly over the following niter_decay + 1 epochs.
            lr_l = 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)
            return lr_l

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_iters, gamma=0.5)
    elif lr_policy == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif lr_policy == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=niter, eta_min=0)
    elif lr_policy == 'exp':
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % lr_policy)
    return scheduler
def forward(self, h_idx, r_idx):
    """Score every entity as a candidate tail for each (head, relation) pair.

    Args:
        h_idx: Index tensor of head entities, looked up in ``self.E``.
        r_idx: Index tensor of relations, looked up in ``self.R``.

    Returns:
        Tensor with one row per query and one column per row of
        ``self.E.weight``, holding sigmoid scores in ``[0, 1]``.
    """
    h = self.E(h_idx)
    r = self.R(r_idx)

    t = self.core(h, r)
    # ``__init__`` never assigns ``self.embedding_dim``, so reading it here
    # raised AttributeError; take the width from the entity embedding table.
    t = t.view(-1, self.E.embedding_dim)

    x = torch.mm(t, self.E.weight.transpose(1, 0))
    x = x + self.b.expand_as(x)
    x = torch.sigmoid(x)
    return x
class DefaultDict(dict):
    """A ``dict`` that fills in missing keys from a default.

    ``default_value`` is either

    * a callable, which is invoked as ``default_value(key)`` on every miss, or
    * a plain value, of which every missing key receives its own deep copy.

    Using module-level classes such as ``LambdaSet`` as the callable (rather
    than a ``lambda``) keeps instances picklable, as the demo below exercises.
    """

    def __init__(self, default_value):
        super(DefaultDict, self).__init__()
        self._default_value = default_value

    def __missing__(self, key):
        """Create, store and return the default value for ``key``."""
        if callable(self._default_value):
            value = self._default_value(key)
        else:
            # A non-callable default is a template. Handing out the template
            # itself would make every missing key alias the same object, so a
            # nested DefaultDict(DefaultDict(...)) would share one inner dict
            # between all outer keys. Copy it instead.
            import copy

            value = copy.deepcopy(self._default_value)
        self[key] = value
        return value


class LambdaSet(object):
    """Picklable factory: called with a key, returns a new empty ``set``."""

    def __call__(self, key):
        return set()


class LambdaDefaultDict(object):
    """Picklable factory wrapping a zero-argument callable.

    Called with a key, it ignores the key and returns ``value()``.
    """

    def __init__(self, value):
        self._value = value

    def __call__(self, key):
        return self._value()


if __name__ == '__main__':
    import os
    import pickle
    import random

    # Three nested levels: sro_t[i][j][k] is a set.
    sro_t = DefaultDict(DefaultDict(DefaultDict(LambdaSet())))

    print(sro_t[0][0][0])
    for i in range(10):
        for j in range(10):
            for k in range(10):
                sro_t[i][j][k] = set(range(random.randint(3, 9)))
    print(sro_t[0][0][0])

    # The demo writes into ./output, which may not exist yet.
    os.makedirs("output", exist_ok=True)
    with open("output/test.pkl", "wb") as f:
        pickle.dump(sro_t, f)

    with open("output/test.pkl", "rb") as f:
        another = pickle.load(f)
    for i in another:
        for j in another[i]:
            for k in another[i][j]:
                print(another[i][j][k])
    # Keys absent from the pickled data must still get defaults after loading.
    for i in range(10, 12):
        for j in range(10, 12):
            for k in range(10, 12):
                print(another[i][j][k])
class CoreRESCAL(nn.Module):
    """Bilinear core of RESCAL: head row-vector times relation matrix."""

    def __init__(self, entity_dim):
        super().__init__()
        self.entity_dim = entity_dim

    def forward(self, h, r):
        """Return ``h @ R`` for every (head, relation) pair in the batch.

        ``h`` is viewed as ``(-1, 1, entity_dim)`` and ``r`` as
        ``(-1, entity_dim, entity_dim)``; the batched product is flattened
        back to ``(-1, entity_dim)``.
        """
        dim = self.entity_dim
        head_rows = h.view(-1, 1, dim)
        relation_mats = r.view(-1, dim, dim)
        return torch.bmm(head_rows, relation_mats).view(-1, dim)


class RESCAL(nn.Module):
    """RESCAL scorer: scores each (head, relation) pair against all entities.

    Entities are embedded as vectors of size ``entity_dim`` and relations as
    flattened ``entity_dim * entity_dim`` matrices. ``self.loss`` holds a
    ``BCELoss`` instance; it is not applied inside ``forward``.
    """

    def __init__(self, num_entities, num_relations, entity_dim, input_dropout=0.3):
        super().__init__()
        self.entity_dim = entity_dim

        # Embedding tables: one vector per entity, one flattened matrix per relation.
        self.E = nn.Embedding(num_entities, entity_dim)
        self.R = nn.Embedding(num_relations, entity_dim * entity_dim)

        self.core = CoreRESCAL(entity_dim)
        self.input_dropout = nn.Dropout(input_dropout)

        self.loss = nn.BCELoss()
        # One bias term per candidate entity, starting at zero.
        self.b = nn.Parameter(torch.zeros(num_entities))

    def init(self):
        """Re-initialise both embedding tables with Kaiming-uniform values."""
        for table in (self.E, self.R):
            nn.init.kaiming_uniform_(table.weight.data)

    def forward(self, h_idx, r_idx):
        """Return sigmoid scores with one column per entity.

        ``h_idx`` and ``r_idx`` are index tensors into the entity and
        relation embedding tables respectively.
        """
        head = self.input_dropout(self.E(h_idx))
        relation = self.R(r_idx)

        projected = self.core(head, relation).view(-1, self.entity_dim)

        # Dropout is also applied to the full entity table used as candidates.
        candidates = self.input_dropout(self.E.weight)
        scores = torch.mm(projected, candidates.transpose(1, 0))
        scores = scores + self.b  # broadcasts the per-entity bias over rows
        return torch.sigmoid(scores)
@click.command()
@click.option("--log_dir", type=str, default="output", help="app 配置文件所在文件夹.")
@click.option("--log_config_name", type=str, default="default.cfg", help="启动 app 的配置文件。app 停机后会把运行期间被修改的配置保存到该文件,方便下次运行(这个设置在配置文件中关闭)。")
@click.option("--ip", type=str, default="0.0.0.0", help="IP 地址")
@click.option("--port", type=int, default=44666, help="端口。如果该端口不可用,会自动选一个可用的")
@click.option("--standby_hours", type=int, default=24, help="空转小时数。如果超过这个时间没有任何操作,会自动停止运行,防止资源浪费")
def main(log_dir, log_config_name, ip, port, standby_hours):
    # Command-line entry point of the log app.
    # Documented with "#" comments rather than a docstring so that the help
    # text click generates for this command stays exactly as it was.
    #
    # log_dir          folder holding the app's config files
    # log_config_name  config file the app is started with
    # ip, port         address the app is served on
    # standby_hours    idle hours after which the app stops itself
    settings_dir = os.path.abspath(log_dir)
    # Ensure the settings folder and its default.cfg exist before starting.
    create_settings(settings_dir)
    # start_app takes (dir, config name, standby hours, port, ip) in that order.
    start_app(settings_dir, log_config_name, standby_hours, port, ip)
|---|---|---|
| ..{{ ossep }} | 50 |51 | | 52 | |
| 57 | {% if i.isfile %} 58 | {{ i.filename }} 60 | 61 | {% else %} 62 | {{ i.filename }} 64 | 65 | {% endif %} 66 | | 67 |{{ i.mtime }} | 68 |{{ i.size }} | 69 |