├── README.md
├── bert_pretrain.py
├── coca_pretrain.py
├── data_prepare.py
├── finetune_bert.py
├── finetune_graph.py
├── finetune_image.py
├── finetune_multimodal.py
├── finetune_text.py
├── model_ensemble.py
├── model_soup_multimodal.py
├── model_soup_text.py
├── notebook
│   └── ccks2022.ipynb
├── pics
│   ├── 164923009848338131649230098175.png
│   ├── coca.png
│   ├── k3m.png
│   ├── model_soup_1.png
│   ├── model_soup_2.png
│   ├── pkgm_roberta.png
│   └── vit.png
├── pkgm_pretrain.py
├── pred_bert.py
├── pred_text.py
├── predict.sh
├── requirements.txt
├── run_coca_pretrain.sh
├── run_data_prepare.sh
├── run_finetune_graph.sh
├── run_finetune_image.sh
├── run_finetune_multimodal.sh
├── run_finetune_text.sh
├── run_model_soup_multimodal.sh
├── run_model_soup_text.sh
├── run_pkgm_pretrain.sh
├── run_pred_image.sh
├── run_pred_multimodal.sh
├── run_pred_text.sh
├── src
│   ├── __init__.py
│   ├── bert
│   │   ├── __init__.py
│   │   ├── data_utils.py
│   │   ├── log.py
│   │   └── model.py
│   ├── config
│   │   ├── coca_base.json
│   │   ├── coca_large.json
│   │   ├── eca_nfnet_l0.json
│   │   ├── gcn.json
│   │   ├── pkgm_base.json
│   │   ├── pkgm_large.json
│   │   ├── resnetv2_50.json
│   │   ├── roberta_base.json
│   │   ├── roberta_image_base.json
│   │   ├── roberta_image_large.json
│   │   ├── roberta_large.json
│   │   ├── vit_base_patch16_384.json
│   │   └── vit_large_patch16_384.json
│   ├── data
│   │   ├── __init__.py
│   │   └── data.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── graph.py
│   │   ├── image.py
│   │   ├── loss.py
│   │   ├── multimodal.py
│   │   └── text.py
│   └── utils
│       ├── __init__.py
│       ├── config.py
│       └── logger.py
├── submit
│   ├── Dockerfile
│   ├── deepAI_result.jsonl
│   ├── push.sh
│   ├── requirements.txt
│   ├── result.zip
│   ├── run.sh
│   └── similarity.py
├── torchkge
│   ├── README.rst
│   ├── __init__.py
│   ├── docs
│   │   ├── Makefile
│   │   ├── __init__.py
│   │   ├── _static
│   │   │   └── css
│   │   │       └── custom.css
│   │   ├── authors.rst
│   │   ├── conf.py
│   │   ├── contributing.rst
│   │   ├── history.rst
│   │   ├── index.rst
│   │   ├── installation.rst
│   │   ├── logo_torchKGE_small.png
│   │   ├── make.bat
│   │   ├── readme.rst
│   │   ├── reference
│   │   │   ├── data.rst
│   │   │   ├── evaluation.rst
│   │   │   ├── inference.rst
│   │   │   ├── models.rst
│   │   │   ├── sampling.rst
│   │   │   └── utils.rst
│   │   └── tutorials
│   │       ├── evaluation.rst
│   │       ├── linkprediction.rst
│   │       ├── training.rst
│   │       ├── transe.rst
│   │       ├── transe_early_stopping.rst
│   │       ├── transe_wrappers.rst
│   │       └── tripletclassification.rst
│   ├── requirements_dev.txt
│   └── torchkge
│       ├── __init__.py
│       ├── data_structures.py
│       ├── evaluation.py
│       ├── exceptions.py
│       ├── inference.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── bilinear.py
│       │   ├── deep.py
│       │   ├── interfaces.py
│       │   └── translation.py
│       ├── sampling.py
│       └── utils
│           ├── __init__.py
│           ├── data.py
│           ├── data_redundancy.py
│           ├── datasets.py
│           ├── dissimilarities.py
│           ├── losses.py
│           ├── modeling.py
│           ├── operations.py
│           ├── pretrained_models.py
│           └── training.py
└── train.sh
/model_ensemble.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import json
4 | import argparse
5 | import copy
6 | 
7 | from src.utils import logger
8 | 
9 | # categories that appear only in the validation set, never in the training set
10 | ONLY_VALID_CATES = ['投资贵金属', '客厅吸顶灯', '衬衫', '电热水壶', '养生壶/煎药壶', '鞋柜', '脱毛膏', '自热火锅', '洗烘套装', '椰棕床垫', '足浴器', '茶壶', '电动自行车']
11 | # categories that appear only in the test set, never in the training set
12 | ONLY_TEST_CATES = ['鞋柜', '洗衣机', '衬衫', '茶壶', '电动自行车', '脱毛膏', '投资贵金属', '椰棕床垫', '身体乳液', '客厅吸顶灯', '电热水壶', '足浴器', '养生壶/煎药壶', '洗烘套装', '自热火锅']
13 | 
14 | # models, their decision thresholds and F1 scores (F1 is used as the voting weight)
15 | models_and_thresholds = [
16 |     # ("roberta_base-v3.4-one_tower-cls-ce", 0.4),
17 |     ("roberta_large-v3.4-one_tower-cls-NA-ce", 0.3, 0.8610),
18 |     # ("roberta_large-v3.6-one_tower-cls-NA-ce", 0.4, 0.8478),
("roberta_large-v3.4-one_tower-cls_1,2,3,4_cat-NA-ce", 0.4, 0.8600), 20 | ("roberta_large-v4-one_tower-cls-NA-ce", 0.4, 0.8612), 21 | ("roberta_image_large-v5-one_tower-cls-begin-ce", 0.4, 0.8582), 22 | # ("roberta_image_large-v5.1-one_tower-cls-begin-ce", 0.4, 0.8446), 23 | ("eca_nfnet_l0-v6", 0.5, 0.7777), 24 | ("pkgm_large-v3.4-one_tower-cls-NA-ce", 0.4, 0.8096), 25 | # ("item_alignment-k3m_base", 0.6, 0.7635), 26 | ("bert_base-one_tower-cls-NA-ce", 0.3, 0.8510), 27 | # ("bert_adversarial-two_tower-cls-ce", 0.3, 0.8477), 28 | # ("fasttext", 0.5, 0.7024), 29 | ("textcnn-v3.4-two_tower-cls-NA-ce", 0.6, 0.7703), 30 | # ("coca_base-v5.2-two_tower-cls-sum-ce", 0.5, 0.7875), 31 | # ("coca_large-v5.2-two_tower-cls-sum-ce", 0.5, 0.7784), 32 | # ("vit_base_patch16_384-v6.2", 0.5, 0.7685) 33 | ] 34 | 35 | models_and_thresholds_in = [ 36 | # ("roberta_base-v3.4-one_tower-cls-ce", 0.4), 37 | ("roberta_large-v3.4-one_tower-cls-NA-ce", 0.3, 0.8610), 38 | # ("roberta_large-v3.6-one_tower-cls-NA-ce", 0.4, 0.8478), 39 | ("roberta_large-v3.4-one_tower-cls_1,2,3,4_cat-NA-ce", 0.4, 0.8600), 40 | ("roberta_large-v4-one_tower-cls-NA-ce", 0.3, 0.8612), 41 | ("roberta_image_large-v5-one_tower-cls-begin-ce", 0.4, 0.8582), 42 | # ("roberta_image_large-v5.1-one_tower-cls-begin-ce", 0.4, 0.8446), 43 | ("eca_nfnet_l0-v6", 0.4, 0.7777), 44 | ("pkgm_large-v3.4-one_tower-cls-NA-ce", 0.4, 0.8096), 45 | # ("item_alignment-k3m_base", 0.6, 0.7635), 46 | ("bert_base-one_tower-cls-NA-ce", 0.3, 0.8510), 47 | # ("bert_adversarial-two_tower-cls-ce", 0.3, 0.8477), 48 | # ("fasttext", 0.5, 0.7024), 49 | ("textcnn-v3.4-two_tower-cls-NA-ce", 0.6, 0.7703), 50 | # ("coca_base-v5.2-two_tower-cls-sum-ce", 0.5, 0.7875), 51 | # ("coca_large-v5.2-two_tower-cls-sum-ce", 0.5, 0.7784), 52 | # ("vit_base_patch16_384-v6.2", 0.5, 0.7685) 53 | ] 54 | 55 | models_and_thresholds_not_in = [ 56 | ("roberta_large-v3.4-one_tower-cls-NA-ce", 0.4, 0.8610), 57 | # ("roberta_large-v3.6-one_tower-cls-NA-ce", 0.4, 0.8583), 58 | ("roberta_large-v3.4-one_tower-cls_1,2,3,4_cat-NA-ce", 0.4, 0.8600), 59 | ("roberta_large-v4-one_tower-cls-NA-ce", 0.5, 0.8612), 60 | ("roberta_image_large-v5-one_tower-cls-begin-ce", 0.4, 0.8582), 61 | # ("roberta_image_large-v5.1-one_tower-cls-begin-ce", 0.4, 0.8446), 62 | # ("eca_nfnet_l0-v6", 0.5, 0.7783), 63 | ("pkgm_large-v3.4-one_tower-cls-NA-ce", 0.5, 0.8096), 64 | # ("item_alignment-k3m_base", 0.6, 0.7635), 65 | ("bert_base-one_tower-cls-NA-ce", 0.4, 0.8510), 66 | # ("bert_adversarial-two_tower-cls-ce", 0.3, 0.8477), 67 | # ("fasttext", 0.5, 0.7024), 68 | ("textcnn-v3.4-two_tower-cls-NA-ce", 0.6, 0.7703), 69 | # ("coca_base-v5.2-two_tower-cls-sum-ce", 0.5, 0.7875), 70 | # ("coca_large-v5.2-two_tower-cls-sum-ce", 0.5, 0.7882), 71 | # ("vit_base_patch16_384-v6.2", 0.5, 0.7685) 72 | ] 73 | 74 | 75 | def get_parser(): 76 | parser = argparse.ArgumentParser() 77 | 78 | # Required parameters 79 | parser.add_argument("--data_dir", required=True, type=str, help="数据地址") 80 | parser.add_argument("--ensemble_strategy", required=True, type=str, help="ensemble strategy: threshold, f1") 81 | 82 | parser.add_argument("--input_file", default="deepAI_result_threshold=0.4.jsonl", type=str, 83 | help="input file name") 84 | parser.add_argument("--split_by_valid_or_test", action="store_true", help="whether to use different models and thresholds based" 85 | "on whether the categories appeared in training data") 86 | 87 | return parser.parse_args() 88 | 89 | 90 | def ensemble(args, id_dict): 91 | if args.split_by_valid_or_test: 92 | # 
93 |         lines = dict()
94 |         for model, threshold, f1 in models_and_thresholds_in:
95 |             f = os.path.join(args.data_dir, "output", model, args.input_file)
96 |             ct = 0
97 |             total = 0
98 |             with open(f, "r", encoding="utf-8") as r:
99 |                 while True:
100 |                     line = r.readline()
101 |                     if not line:
102 |                         break
103 |                     d = json.loads(line.strip())
104 |                     src_item_id = d['src_item_id']
105 |                     src_cate_name = id_dict[src_item_id]['cate_name']
106 |                     tgt_item_id = d['tgt_item_id']
107 |                     tgt_cate_name = id_dict[tgt_item_id]['cate_name']
108 |                     # if src_cate_name in ONLY_VALID_CATES or tgt_cate_name in ONLY_VALID_CATES:
109 |                     if src_cate_name in ONLY_TEST_CATES or tgt_cate_name in ONLY_TEST_CATES:
110 |                         continue
111 |                     key = src_item_id + "-" + tgt_item_id
112 |                     prob = json.loads(d['tgt_item_emb'])[0]  # the field is a JSON list string like "[0.87]"; json.loads is safer than eval
113 |                     if key not in lines:
114 |                         dd = copy.deepcopy(d)
115 |                         dd['tgt_item_emb'] = prob - threshold
116 |                         dd['0'] = 0.0
117 |                         dd['1'] = 0.0
118 |                         lines[key] = dd
119 |                     else:
120 |                         lines[key]['tgt_item_emb'] += prob - threshold
121 |                     if prob >= threshold:
122 |                         ct += 1
123 |                         lines[key]['1'] += f1
124 |                     else:
125 |                         lines[key]['0'] += f1
126 |                     total += 1
127 |             logger.info(f"In Train: {model}-{threshold} p: {ct}, total: {total}")
128 | 
129 |         # process pairs whose categories do not appear in the training set
130 |         for model, threshold, f1 in models_and_thresholds_not_in:
131 |             f = os.path.join(args.data_dir, "output", model, args.input_file)
132 |             ct = 0
133 |             total = 0
134 |             with open(f, "r", encoding="utf-8") as r:
135 |                 while True:
136 |                     line = r.readline()
137 |                     if not line:
138 |                         break
139 |                     d = json.loads(line.strip())
140 |                     src_item_id = d['src_item_id']
141 |                     src_cate_name = id_dict[src_item_id]['cate_name']
142 |                     tgt_item_id = d['tgt_item_id']
143 |                     tgt_cate_name = id_dict[tgt_item_id]['cate_name']
144 |                     # if src_cate_name in ONLY_VALID_CATES or tgt_cate_name in ONLY_VALID_CATES:
145 |                     if src_cate_name in ONLY_TEST_CATES or tgt_cate_name in ONLY_TEST_CATES:
146 |                         key = src_item_id + "-" + tgt_item_id
147 |                         prob = json.loads(d['tgt_item_emb'])[0]
148 |                         if key not in lines:
149 |                             dd = copy.deepcopy(d)
150 |                             dd['tgt_item_emb'] = prob - threshold
151 |                             dd['0'] = 0.0
152 |                             dd['1'] = 0.0
153 |                             lines[key] = dd
154 |                         else:
155 |                             lines[key]['tgt_item_emb'] += prob - threshold
156 |                         if prob >= threshold:
157 |                             ct += 1
158 |                             lines[key]['1'] += f1
159 |                         else:
160 |                             lines[key]['0'] += f1
161 |                         total += 1
162 |             logger.info(f"Not In Train: {model}-{threshold} p: {ct}, total: {total}")
163 |     else:
164 |         lines = dict()
165 |         for model, threshold, f1 in models_and_thresholds:
166 |             f = os.path.join(args.data_dir, "output", model, args.input_file)
167 |             ct = 0
168 |             total = 0
169 |             with open(f, "r", encoding="utf-8") as r:
170 |                 while True:
171 |                     line = r.readline()
172 |                     if not line:
173 |                         break
174 |                     d = json.loads(line.strip())
175 |                     src_item_id = d['src_item_id']
176 |                     src_cate_name = id_dict[src_item_id]['cate_name']
177 |                     tgt_item_id = d['tgt_item_id']
178 |                     tgt_cate_name = id_dict[tgt_item_id]['cate_name']
179 |                     key = src_item_id + "-" + tgt_item_id
180 |                     prob = json.loads(d['tgt_item_emb'])[0]
181 |                     if key not in lines:
182 |                         dd = copy.deepcopy(d)
183 |                         dd['tgt_item_emb'] = prob - threshold
184 |                         dd['0'] = 0.0
185 |                         dd['1'] = 0.0
186 |                         lines[key] = dd
187 |                     else:
188 |                         lines[key]['tgt_item_emb'] += prob - threshold
189 |                     if prob >= threshold:
190 |                         ct += 1
191 |                         lines[key]['1'] += f1
192 |                     else:
193 |                         lines[key]['0'] += f1
194 |                     total += 1
195 |             logger.info(f"{model}-{threshold} p: {ct}, total: {total}")
196 | 
197 |     return lines
198 | 
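# Illustrative walk-through of the two ensemble strategies (hypothetical numbers,
# not taken from the repository): three models with (threshold, F1) =
# (0.3, 0.86), (0.4, 0.80), (0.5, 0.77) score the same pair 0.35, 0.55 and 0.45:
#   "threshold": sum of signed margins = (0.35-0.3) + (0.55-0.4) + (0.45-0.5) = 0.15 >= 0.0 -> positive
#   "f1":        F1-weighted votes:     '1' = 0.86 + 0.80 = 1.66 >= '0' = 0.77           -> positive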
199 | 
200 | def main():
201 |     args = get_parser()
202 | 
203 |     # load the item info file
204 |     id_dict = dict()
205 |     with open(os.path.join(args.data_dir, "raw", "item_info.jsonl"), "r", encoding="utf-8") as r:
206 |         while True:
207 |             line = r.readline()
208 |             if not line:
209 |                 break
210 |             d = json.loads(line.strip())
211 |             item_id = d['item_id']
212 |             id_dict[item_id] = d
213 |     logger.info(f"id dict length: {len(id_dict)}")
214 | 
215 |     # load each model's prediction results
216 |     lines = ensemble(args, id_dict)
217 | 
218 |     # merge the model predictions
219 |     lines_ensemble = []
220 |     threshold = 0.0
221 |     total = 0
222 |     ct = 0
223 |     for _, d in lines.items():
224 |         dd = copy.deepcopy(d)
225 |         if args.ensemble_strategy == "f1":
226 |             if dd['1'] >= dd['0']:
227 |                 ct += 1
228 |                 p = 1.0
229 |             else:
230 |                 p = -1.0
231 |         elif args.ensemble_strategy == "threshold":
232 |             if dd['tgt_item_emb'] >= threshold:
233 |                 ct += 1
234 |             p = dd['tgt_item_emb']
235 |         else:
236 |             raise ValueError(f"unsupported ensemble strategy: {args.ensemble_strategy}")
237 |         dd['tgt_item_emb'] = f"[{p}]"
238 |         dd['threshold'] = threshold
239 |         total += 1
240 |         lines_ensemble.append(dd)
241 |     logger.info(f"p: {ct}, total: {total}")
242 | 
243 |     # save the ensembled results
244 |     # model = "ensemble_f1-rl_v3.4_0.3-rl_v3.6_0.4-ril_v5_0.4-el0_v6_0.4-cl_v5.2_0.5-vb_v6_0.4"
245 |     # model = "ensemble-rl_v3.4_0.3-rlcat_v3.4_0.4-rl_v4_0.4-ril_v5_0.4-el0_v6_0.5-pl_v3.4_0.4-bb_0.3-tc_v3.4_0.6-cl_v5.2_0.5"
246 |     model = "ensemble"
247 |     model_dir = os.path.join(args.data_dir, "output", model)
248 |     if not os.path.isdir(model_dir):
249 |         os.mkdir(model_dir)
250 |     f = os.path.join(model_dir, "deepAI_result.jsonl")
251 |     with open(f, "w", encoding="utf-8") as w:
252 |         for dd in lines_ensemble:
253 |             w.write(json.dumps(dd)+"\n")
254 | 
255 | 
256 | if __name__ == "__main__":
257 |     main()
258 | 
--------------------------------------------------------------------------------
/pics/164923009848338131649230098175.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/pics/164923009848338131649230098175.png
--------------------------------------------------------------------------------
/pics/coca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/pics/coca.png
--------------------------------------------------------------------------------
/pics/k3m.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/pics/k3m.png
--------------------------------------------------------------------------------
/pics/model_soup_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/pics/model_soup_1.png
--------------------------------------------------------------------------------
/pics/model_soup_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/pics/model_soup_2.png
--------------------------------------------------------------------------------
/pics/pkgm_roberta.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/pics/pkgm_roberta.png
--------------------------------------------------------------------------------
/pics/vit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/pics/vit.png
--------------------------------------------------------------------------------
/pkgm_pretrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | 
5 | from torch.optim import Adam
6 | from torch.optim.lr_scheduler import LambdaLR
7 | from torchkge import LinkPredictionEvaluator
8 | from torchkge import TransEModel, PKGMModel
9 | from torchkge import load_ccks
10 | from torchkge import Trainer, MarginLoss
11 | from src.utils import logger
12 | 
13 | 
14 | def get_parser():
15 |     parser = argparse.ArgumentParser()
16 | 
17 |     # Required parameters
18 |     parser.add_argument("--data_dir", required=True, type=str, help="path to training data")
19 |     parser.add_argument("--output_dir", required=True, type=str, help="The output directory where the model checkpoints will be written.")
20 |     parser.add_argument("--model_name", default="transe_epoch-{}.bin", type=str, help="model saving name")
21 |     # training
22 |     parser.add_argument("--do_eval", action="store_true", help="whether to run evaluation")
23 |     parser.add_argument("--do_test", action="store_true", help="whether to run testing")
24 |     parser.add_argument("--cuda_mode", default="all", help="cuda mode, all or batch")
25 |     parser.add_argument("--train_batch_size", default=2048, type=int, help="Total batch size for training.")
26 |     parser.add_argument("--eval_batch_size", default=2048, type=int, help="Total batch size for evaluation.")
27 |     parser.add_argument("--learning_rate", default=1e-3, type=float, help="The initial learning rate for Adam.")
28 |     parser.add_argument("--start_epoch", default=0, type=int, help="starting training epoch")
29 |     parser.add_argument("--num_train_epochs", default=1000, type=int, help="Total number of training epochs to perform.")
30 |     parser.add_argument("--log_steps", default=None, type=int, help="every n steps, log training process")
31 |     parser.add_argument("--save_epochs", default=1000, type=int, help="every n epochs, save model")
32 |     parser.add_argument("--pretrained_model_path", default=None, type=str, help="pretrained model path")
33 |     # optimization
34 |     parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
35 |     parser.add_argument("--fp16", action="store_true", help="Whether to use 16-bit float precision instead of 32-bit")
36 |     parser.add_argument("--weight_decay", default=1e-5, type=float, help="weight decay")
37 |     parser.add_argument("--warmup_proportion", default=0.2, type=float, help="warmup proportion in learning rate scheduler")
38 |     parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="Number of update steps to accumulate before performing a backward/update pass.")
39 |     # Graph Embedding
40 |     parser.add_argument("--dim", default=768, type=int, help="dimension of graph embedding")
41 |     parser.add_argument("--margin", default=1.0, type=float, help="margin in the margin loss")
42 |     parser.add_argument("--n_neg", default=3, type=int, help="number of negative samples")
43 |     # parser.add_argument("--negative_entities", default=3, type=int, help="number of negative entities")
parser.add_argument("--negative_relations", default=3, type=int, help="number of negative relations") 45 | parser.add_argument("--norm", default="L2", type=str, help="vector norm: L1, L2, torus_L1, torus_L2") 46 | parser.add_argument("--sampling_type", default="bern", type=str, help="sampling type, Either 'unif' (uniform negative sampling) or " 47 | "'bern' (Bernoulli negative sampling)") 48 | 49 | return parser.parse_args() 50 | 51 | 52 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 53 | """ 54 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 55 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 56 | 57 | Args: 58 | optimizer ([`~torch.optim.Optimizer`]): 59 | The optimizer for which to schedule the learning rate. 60 | num_warmup_steps (`int`): 61 | The number of steps for the warmup phase. 62 | num_training_steps (`int`): 63 | The total number of training steps. 64 | last_epoch (`int`, *optional*, defaults to -1): 65 | The index of the last epoch when resuming training. 66 | 67 | Return: 68 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 69 | """ 70 | 71 | def lr_lambda(current_step: int): 72 | if current_step < num_warmup_steps: 73 | return float(current_step) / float(max(1, num_warmup_steps)) 74 | return max( 75 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 76 | ) 77 | 78 | return LambdaLR(optimizer, lr_lambda, last_epoch) 79 | 80 | 81 | def main(): 82 | args = get_parser() 83 | 84 | # Load dataset 85 | kgs = load_ccks(args.data_dir, args.do_eval, args.do_test) 86 | kg_train = kgs[0] 87 | logger.info(f"finished loading data") 88 | 89 | # Define the model and criterion 90 | if "transe" in args.model_name: 91 | model = TransEModel(args.dim, kg_train.n_ent, kg_train.n_rel, 92 | dissimilarity_type=args.norm) 93 | elif "pkgm" in args.model_name: 94 | model = PKGMModel(args.dim, kg_train.n_ent, kg_train.n_rel, 95 | dissimilarity_type=args.norm) 96 | else: 97 | raise ValueError(f"Unsuported model name: {args.model_name}") 98 | if args.pretrained_model_path is not None: 99 | state_dict = torch.load(args.pretrained_model_path, map_location="cpu") 100 | model.load_state_dict(state_dict) 101 | criterion = MarginLoss(args.margin) 102 | optimizer = Adam(model.parameters(), lr=args.learning_rate, 103 | weight_decay=args.weight_decay, eps=args.adam_epsilon) 104 | num_train_optimization_steps = int( 105 | len(kg_train) 106 | / args.train_batch_size 107 | / args.gradient_accumulation_steps 108 | ) * (args.num_train_epochs - args.start_epoch) 109 | num_warmup_steps = int(num_train_optimization_steps * args.warmup_proportion) 110 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_train_optimization_steps) 111 | 112 | # Start Training 113 | if args.fp16: 114 | scaler = torch.cuda.amp.GradScaler() 115 | else: 116 | scaler = None 117 | use_cuda = args.cuda_mode if torch.cuda.is_available() else None 118 | model_save_path = os.path.join(args.output_dir, args.model_name) 119 | trainer = Trainer(model, criterion, kg_train, args.num_train_epochs, 120 | args.train_batch_size, optimizer=optimizer, scheduler=scheduler, 121 | model_save_path=model_save_path, sampling_type=args.sampling_type, 122 | n_neg=args.n_neg, use_cuda=use_cuda, fp16=args.fp16, scaler=scaler, 123 | log_steps=args.log_steps, start_epoch=args.start_epoch, 124 | 
80 | 
81 | def main():
82 |     args = get_parser()
83 | 
84 |     # Load dataset
85 |     kgs = load_ccks(args.data_dir, args.do_eval, args.do_test)
86 |     kg_train = kgs[0]
87 |     logger.info("finished loading data")
88 | 
89 |     # Define the model and criterion
90 |     if "transe" in args.model_name:
91 |         model = TransEModel(args.dim, kg_train.n_ent, kg_train.n_rel,
92 |                             dissimilarity_type=args.norm)
93 |     elif "pkgm" in args.model_name:
94 |         model = PKGMModel(args.dim, kg_train.n_ent, kg_train.n_rel,
95 |                           dissimilarity_type=args.norm)
96 |     else:
97 |         raise ValueError(f"Unsupported model name: {args.model_name}")
98 |     if args.pretrained_model_path is not None:
99 |         state_dict = torch.load(args.pretrained_model_path, map_location="cpu")
100 |         model.load_state_dict(state_dict)
101 |     criterion = MarginLoss(args.margin)
102 |     optimizer = Adam(model.parameters(), lr=args.learning_rate,
103 |                      weight_decay=args.weight_decay, eps=args.adam_epsilon)
104 |     num_train_optimization_steps = int(
105 |         len(kg_train)
106 |         / args.train_batch_size
107 |         / args.gradient_accumulation_steps
108 |     ) * (args.num_train_epochs - args.start_epoch)
109 |     num_warmup_steps = int(num_train_optimization_steps * args.warmup_proportion)
110 |     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_train_optimization_steps)
111 | 
112 |     # Start Training
113 |     if args.fp16:
114 |         scaler = torch.cuda.amp.GradScaler()
115 |     else:
116 |         scaler = None
117 |     use_cuda = args.cuda_mode if torch.cuda.is_available() else None
118 |     model_save_path = os.path.join(args.output_dir, args.model_name)
119 |     trainer = Trainer(model, criterion, kg_train, args.num_train_epochs,
120 |                       args.train_batch_size, optimizer=optimizer, scheduler=scheduler,
121 |                       model_save_path=model_save_path, sampling_type=args.sampling_type,
122 |                       n_neg=args.n_neg, use_cuda=use_cuda, fp16=args.fp16, scaler=scaler,
123 |                       log_steps=args.log_steps, start_epoch=args.start_epoch,
124 |                       save_epochs=args.save_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps)
125 |     trainer.run()
126 | 
127 |     # Evaluation
128 |     if args.do_test:
129 |         if args.do_eval:
130 |             kg_test = kgs[2]
131 |         else:
132 |             kg_test = kgs[1]
133 |         evaluator = LinkPredictionEvaluator(model, kg_test)
134 |         evaluator.evaluate(args.eval_batch_size)
135 |         evaluator.print_results()
136 | 
137 | 
138 | if __name__ == "__main__":
139 |     main()
140 | 
--------------------------------------------------------------------------------
/pred_text.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 | import json
5 | import numpy as np
6 | import torch
7 | import jieba
8 | 
9 | from torch.utils.data import DataLoader
10 | from transformers import (
11 |     BertTokenizer,
12 |     BertConfig
13 | )
14 | from src.models import RobertaModel
15 | from src.data import RobertaDataset, collate
16 | from src.utils import logger, BOS_TOKEN
17 | 
18 | 
19 | def get_parser():
20 |     parser = argparse.ArgumentParser()
21 | 
22 |     # Required parameters
23 |     parser.add_argument("--data_dir", required=True, type=str, help="path to training data")
24 |     parser.add_argument("--output_dir", required=True, type=str, help="The output directory where the model checkpoints will be written.")
25 |     parser.add_argument("--config_file", required=True, type=str, help="The config file which specifies the model details.")
26 | 
27 |     # training
28 |     parser.add_argument("--seed", default=2345, type=int, help="random seed")
29 |     parser.add_argument("--train_batch_size", default=64, type=int, help="Total batch size for training.")
30 |     parser.add_argument("--eval_batch_size", default=64, type=int, help="Total batch size for evaluation.")
31 |     parser.add_argument("--learning_rate", default=1e-3, type=float, help="The initial learning rate for Adam.")
32 |     parser.add_argument("--start_epoch", default=0, type=int, help="starting training epoch")
33 |     parser.add_argument("--num_train_epochs", default=1000, type=int, help="Total number of training epochs to perform.")
34 |     parser.add_argument("--weight_decay", default=1e-5, type=float, help="weight decay")
35 |     parser.add_argument("--log_steps", default=None, type=int, help="every n steps, log training process")
36 |     parser.add_argument("--pretrained_model_path", default=None, type=str, help="pretrained model path, including roberta and pkgm")
37 |     parser.add_argument("--file_state_dict", default=None, type=str, help="finetuned model path")
38 |     parser.add_argument("--type_vocab_size", default=2, type=int, help="Number of unique segment ids")
39 |     parser.add_argument("--parameters_to_freeze", default=None, type=str, help="file that lists parameters that do not require gradient descent")
40 |     parser.add_argument("--threshold", default=0.5, type=float, help="default threshold for item embedding score for prediction")
41 |     # optimization
" 43 | "E.g., 0.1 = 10%% of training.") 44 | parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="Number of updates steps to accumualte before performing a backward/update pass.") 45 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 46 | parser.add_argument("--fp16", action="store_true", help="Whether to use 16-bit float precision instead of 32-bit") 47 | parser.add_argument("--margin", default=1.0, type=float, help="margin in loss function") 48 | # NLP 49 | parser.add_argument("--do_lower_case", default=True, type=bool, help="Whether to lower case the input text. True for uncased models, False for cased models.") 50 | parser.add_argument("--max_seq_len", default=None, type=int, help="max length for one item title") 51 | parser.add_argument("--max_seq_len_pv", default=None, type=int, help="max length of pvs, 'None' - do not add pvs as text") 52 | parser.add_argument("--max_position_embeddings", default=512, type=int, help="max position embedding length") 53 | parser.add_argument("--max_pvs", default=20, type=int, help="max number of pairs for one item") 54 | parser.add_argument("--cls_layers", default="1", type=str, help="which layers of cls used for classification") 55 | parser.add_argument("--cls_pool", default="cat", type=str, help="ways to pool multiple layers of cls used for classification") 56 | parser.add_argument("--auxiliary_task", action="store_true", help="whether to include auxiliary task. The task is additionally comparing pv pairs of src and tgt item." 57 | "for pv keys that are shared by two items, compute whether the pv value is the same") 58 | # TextCNN 59 | parser.add_argument("--filter_sizes", default="1,2,3,5", type=str, help="filter sizes") 60 | parser.add_argument("--num_filters", default=36, type=int, help="number of filters") 61 | 62 | return parser.parse_args() 63 | 64 | 65 | def load_raw_data(args): 66 | id2image_name = dict() 67 | with open(os.path.join(args.data_dir, "raw", "item_info.jsonl"), "r", encoding="utf-8") as r: 68 | while True: 69 | line = r.readline() 70 | if not line: 71 | break 72 | d = json.loads(line.strip()) 73 | id2image_name[d['item_id']] = d#['item_image_name'] 74 | logger.info(f"Finished loading item info, size: {len(id2image_name)}") 75 | 76 | test_data = [] 77 | with open(os.path.join(args.data_dir, "processed", "entity2id.txt"), "r", encoding="utf-8") as r: 78 | while True: 79 | line = r.readline() 80 | if not line: 81 | break 82 | item, idx = line.strip("\n").split("\t") 83 | if "/item/" in item: 84 | item = item.replace("/item/", "") 85 | text = id2image_name[item]['title'] 86 | elif "/value/" in item: 87 | text = item.replace("/value/", "") 88 | else: 89 | logger.warning(f"wrong format data: {item}") 90 | text = " ".join(jieba.cut(text)) 91 | test_data.append((idx, text)) 92 | 93 | return test_data 94 | 95 | 96 | def main(): 97 | args = get_parser() 98 | device = "cuda" if torch.cuda.is_available() else "cpu" 99 | n_gpu = torch.cuda.device_count() 100 | logger.info(f"device: {device}, n_gpu: {n_gpu}, 16-bits training: {args.fp16}") 101 | # 设定随机数种子 102 | random.seed(args.seed) 103 | np.random.seed(args.seed) 104 | torch.manual_seed(args.seed) 105 | if n_gpu > 0: 106 | torch.cuda.manual_seed_all(args.seed) 107 | # load tokenizer 108 | tokenizer = BertTokenizer.from_pretrained(args.pretrained_model_path, do_lower_case=args.do_lower_case) 109 | tokenizer.do_basic_tokenize = False 110 | tokenizer.bos_token = BOS_TOKEN 111 | logger.info(f"vocab size: 
{tokenizer.vocab_size}") 112 | # load model 113 | config = BertConfig.from_json_file(os.path.join(args.output_dir, args.config_file)) 114 | # config.interaction_type = args.interaction_type 115 | # config.type_vocab_size = args.type_vocab_size 116 | # config.classification_method = args.classification_method 117 | # config.similarity_measure = args.similarity_measure 118 | # config.loss_type = args.loss_type 119 | config.max_seq_len = args.max_seq_len 120 | # config.max_seq_len_pv = args.max_seq_len_pv 121 | # config.max_pvs = args.max_pvs 122 | # config.max_position_embeddings = args.max_position_embeddings 123 | # config.loss_margin = args.margin 124 | # config.cls_layers = args.cls_layers 125 | # config.cls_pool = args.cls_pool 126 | # config.filter_sizes = args.filter_sizes 127 | # config.num_filters = args.num_filters 128 | # config.auxiliary_task = args.auxiliary_task 129 | config.ensemble = None 130 | if args.max_seq_len_pv is None: 131 | max_seq_len = args.max_seq_len 132 | elif args.max_seq_len is None: 133 | max_seq_len = args.max_seq_len_pv 134 | else: 135 | max_seq_len = args.max_seq_len + args.max_seq_len_pv 136 | assert args.max_position_embeddings >= 2 * max_seq_len + 2 137 | 138 | model = RobertaModel.from_pretrained(args.pretrained_model_path, config=config, 139 | ignore_mismatched_sizes=True) 140 | # load previous model weights (if exists) 141 | if args.file_state_dict is not None: 142 | state_dict = torch.load(args.file_state_dict, map_location="cpu") 143 | model.load_state_dict(state_dict) 144 | # load raw data 145 | test_data = load_raw_data(args) 146 | logger.info(f"# test samples: {len(test_data)}") 147 | 148 | if device == "cuda": 149 | model.cuda() 150 | 151 | test_dataset = RobertaDataset(test_data, tokenizer, max_seq_len=args.max_seq_len) 152 | test_data_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, 153 | collate_fn=collate) 154 | 155 | model.eval() 156 | torch.set_grad_enabled(False) 157 | 158 | f = os.path.join(args.data_dir, "processed", f"feature_matrix.pt") 159 | feature_matrix = None 160 | for step, batch in enumerate(test_data_loader): 161 | ids = batch[0] 162 | if args.fp16: 163 | with torch.cuda.amp.autocast(): 164 | batch = tuple(t.to(device=device, non_blocking=True) if t is not None else t for t in batch[1:]) 165 | input_ids, segment_ids, input_mask, position_ids = batch 166 | output = model( 167 | input_ids=input_ids, 168 | token_type_ids=segment_ids, 169 | attention_mask=input_mask, 170 | position_ids=position_ids, 171 | output_hidden_states=True 172 | ) 173 | else: 174 | batch = tuple(t.to(device=device, non_blocking=True) if t is not None else t for t in batch[1:]) 175 | input_ids, segment_ids, input_mask, position_ids = batch 176 | output = model( 177 | input_ids=input_ids, 178 | token_type_ids=segment_ids, 179 | attention_mask=input_mask, 180 | position_ids=position_ids, 181 | output_hidden_states=True 182 | ) 183 | sequence_outputs = output.pooler_output 184 | if feature_matrix is None: 185 | feature_matrix = sequence_outputs 186 | else: 187 | feature_matrix = torch.cat((feature_matrix, sequence_outputs), dim=0) 188 | 189 | if args.log_steps is not None and step % args.log_steps == 0: 190 | logger.info(f"[Prediction] {step} samples processed") 191 | 192 | torch.save(feature_matrix, f) 193 | 194 | logger.info(f"[Prediction] Finished processing") 195 | 196 | 197 | if __name__ == "__main__": 198 | main() 199 | -------------------------------------------------------------------------------- /predict.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | ROOT_DIR="${HOME}/Data" 5 | DATA_DIR=${ROOT_DIR} 6 | OUTPUT_DIR=${ROOT_DIR}/output 7 | PRETRAINED_MODEL_PATH="${ROOT_DIR}/bert/roberta_large" 8 | 9 | 10 | # Roberta_large-v3.4 11 | FILE_STATE_DICT="${HOME}/Data/output/roberta_large-v3.4-one_tower-cls-NA-ce/text_finetune_epoch-9.bin" 12 | python finetune_text.py \ 13 | --data_dir $DATA_DIR \ 14 | --output_dir $OUTPUT_DIR \ 15 | --model_name "roberta_large" \ 16 | --data_version "v3.4" \ 17 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 18 | --file_state_dict $FILE_STATE_DICT \ 19 | --config_file "src/config/roberta_large.json" \ 20 | --do_pred \ 21 | --interaction_type "one_tower" \ 22 | --classification_method "cls" \ 23 | --similarity_measure "NA" \ 24 | --loss_type "ce" \ 25 | --max_seq_len 50 \ 26 | --max_seq_len_pv 205 \ 27 | --eval_batch_size 108 \ 28 | --threshold 0.4 \ 29 | --fp16 30 | 31 | # Roberta_large-v3.4-cls_cat_1,2,3,4 32 | FILE_STATE_DICT="${HOME}/Data/output/roberta_large-v3.4-one_tower-cls_1,2,3,4_cat-NA-ce/text_finetune_epoch-9.bin" 33 | python finetune_text.py \ 34 | --data_dir $DATA_DIR \ 35 | --output_dir $OUTPUT_DIR \ 36 | --model_name "roberta_large" \ 37 | --data_version "v3.4" \ 38 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 39 | --file_state_dict $FILE_STATE_DICT \ 40 | --config_file "src/config/roberta_large.json" \ 41 | --do_pred \ 42 | --interaction_type "one_tower" \ 43 | --classification_method "cls_1,2,3,4_cat" \ 44 | --similarity_measure "NA" \ 45 | --loss_type "ce" \ 46 | --cls_layers "1,2,3,4" \ 47 | --cls_pool "cat" \ 48 | --max_seq_len 50 \ 49 | --max_seq_len_pv 205 \ 50 | --eval_batch_size 108 \ 51 | --threshold 0.4 \ 52 | --fp16 53 | 54 | # Roberta_large-v4 55 | FILE_STATE_DICT="${HOME}/Data/output/roberta_large-v4-one_tower-cls-NA-ce/text_finetune_epoch-9.bin" 56 | python finetune_text.py \ 57 | --data_dir $DATA_DIR \ 58 | --output_dir $OUTPUT_DIR \ 59 | --model_name "roberta_large" \ 60 | --data_version "v4" \ 61 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 62 | --file_state_dict $FILE_STATE_DICT \ 63 | --config_file "src/config/roberta_large.json" \ 64 | --do_pred \ 65 | --interaction_type "one_tower" \ 66 | --classification_method "cls" \ 67 | --similarity_measure "NA" \ 68 | --loss_type "ce" \ 69 | --max_seq_len 50 \ 70 | --max_seq_len_pv 205 \ 71 | --eval_batch_size 108 \ 72 | --threshold 0.4 \ 73 | --fp16 74 | 75 | # pkgm_large-v3.4 76 | FILE_STATE_DICT="${HOME}/Data/output/pkgm_large-v3.4-one_tower-cls-NA-ce/text_finetune_epoch-9.bin" 77 | python finetune_text.py \ 78 | --data_dir $DATA_DIR \ 79 | --output_dir $OUTPUT_DIR \ 80 | --model_name "pkgm_large" \ 81 | --data_version "v3.4" \ 82 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 83 | --file_state_dict $FILE_STATE_DICT \ 84 | --config_file "src/config/pkgm_large.json" \ 85 | --do_pred \ 86 | --interaction_type "one_tower" \ 87 | --classification_method "cls" \ 88 | --similarity_measure "NA" \ 89 | --loss_type "ce" \ 90 | --max_seq_len 64 \ 91 | --max_pvs 30 \ 92 | --eval_batch_size 512 \ 93 | --threshold 0.4 \ 94 | --fp16 95 | 96 | # textcnn-v3.4 97 | FILE_STATE_DICT="${HOME}/Data/output/textcnn-v3.4-two_tower-cls-NA-ce/text_finetune_epoch-9.bin" 98 | python finetune_text.py \ 99 | --data_dir $DATA_DIR \ 100 | --output_dir $OUTPUT_DIR \ 101 | --model_name "textcnn" \ 102 | --data_version "v3.4" \ 103 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 104 | --file_state_dict $FILE_STATE_DICT \ 105 | 
--config_file "src/config/roberta_large.json" \ 106 | --do_pred \ 107 | --interaction_type "two_tower" \ 108 | --classification_method "cls" \ 109 | --similarity_measure "NA" \ 110 | --loss_type "ce" \ 111 | --max_seq_len 50 \ 112 | --max_seq_len_pv 205 \ 113 | --eval_batch_size 512 \ 114 | --threshold 0.4 \ 115 | --fp16 116 | 117 | # bert_base 118 | python pred_bert.py 119 | 120 | # roberta_image_large-v5 121 | FILE_STATE_DICT="${HOME}/Data/output/roberta_image_large-v5-one_tower-cls-begin-ce/multimodal_finetune_epoch-9.bin" 122 | python finetune_multimodal.py \ 123 | --data_dir $DATA_DIR \ 124 | --output_dir $OUTPUT_DIR \ 125 | --model_name "roberta_image_large" \ 126 | --data_version "v5" \ 127 | --config_file "src/config/roberta_image_large.json" \ 128 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 129 | --file_state_dict $FILE_STATE_DICT \ 130 | --do_pred \ 131 | --interaction_type "one_tower" \ 132 | --classification_method "cls" \ 133 | --ensemble "begin" \ 134 | --loss_type "ce" \ 135 | --max_seq_len 50 \ 136 | --max_seq_len_pv 205 \ 137 | --eval_batch_size 108 \ 138 | --threshold 0.4 \ 139 | --fp16 140 | 141 | # eca_nfnet_l0-v6 142 | FILE_STATE_DICT="${HOME}/Data/output/eca_nfnet_l0-v6/image_finetune_epoch-9.bin" 143 | python finetune_image.py \ 144 | --data_dir $DATA_DIR/raw \ 145 | --output_dir $OUTPUT_DIR \ 146 | --model_name "eca_nfnet_l0" \ 147 | --data_version "v6" \ 148 | --config_file "src/config/eca_nfnet_l0.json" \ 149 | --file_state_dict $FILE_STATE_DICT \ 150 | --do_pred \ 151 | --image_size 1000 \ 152 | --eval_batch_size 128 \ 153 | --threshold 0.4 \ 154 | --fp16 155 | 156 | # model ensemble 157 | python model_ensemble.py \ 158 | --data_dir $DATA_DIR \ 159 | --ensemble_strategy "threshold" \ 160 | --split_by_valid_or_test -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | jieba 3 | pytorch_transformers 4 | scikit-learn==1.1.1 5 | tensorpack==0.11 6 | tensorflow==2.2.0 7 | torch==1.11.0 8 | torchvision==0.12.0 9 | torch-geometric==2.0.4 10 | transformers==4.20.1 11 | timm==0.6.5 12 | Pillow==9.1.1 13 | matplotlib>=3.2.2 14 | seaborn>=0.11.0 15 | tqdm==4.64.0 16 | tensorboardX -------------------------------------------------------------------------------- /run_coca_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # data processing 4 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 5 | MAIN="/root/Code/commodity-alignment/pkgm/coca_pretrain.py" 6 | DATA_DIR=${ROOT_DIR}/raw 7 | OUTPUT_DIR=${ROOT_DIR}/output 8 | MODEL_NAME="coca_base" 9 | PRETRAINED_TEXT_MODEL_PATH="/root/autodl-tmp/Data/bert/roberta_base" 10 | PRETRAINED_IMAGE_MODEL_PATH="/root/autodl-tmp/Data/cv/vit_base_patch16_384/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz" 11 | #PRETRAINED_IMAGE_MODEL_PATH="/root/autodl-tmp/Data/cv/vit_large_patch16_384/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz" 12 | 13 | IMAGE_MODEL_NAME="vit_base_patch16_384" 14 | TRAIN_BATCH_SIZE=108 15 | LEARNING_RATE=1e-4 16 | NUM_EPOCHS=10 17 | 18 | python $MAIN \ 19 | --data_dir $DATA_DIR \ 20 | --output_dir $OUTPUT_DIR \ 21 | --model_name $MODEL_NAME \ 22 | --config_file ${MODEL_NAME}.json \ 23 | --pretrained_text_model_path $PRETRAINED_TEXT_MODEL_PATH \ 24 | --pretrained_image_model_path 
$PRETRAINED_IMAGE_MODEL_PATH \ 25 | --image_model_name $IMAGE_MODEL_NAME \ 26 | --image_size 384 \ 27 | --max_seq_len 64 \ 28 | --train_batch_size $TRAIN_BATCH_SIZE \ 29 | --warmup_proportion 0.3 \ 30 | --learning_rate $LEARNING_RATE \ 31 | --num_train_epochs $NUM_EPOCHS \ 32 | --log_steps 10 \ 33 | --fp16 34 | -------------------------------------------------------------------------------- /run_data_prepare.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # data processing 5 | #ROOT_DIR="/Users/zeyesun/Documents/Data/ccks2022/task9_商品同款" 6 | #MAIN="/Users/zeyesun/Documents/Code/torchkge/examples/train.py" 7 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 8 | MAIN="/root/Code/commodity-alignment/pkgm/data_prepare.py" 9 | DATA_DIR=${ROOT_DIR}/raw 10 | OUTPUT_DIR=${ROOT_DIR}/processed 11 | NUM_TRAIN_AUGMENT=0 12 | #NUM_TRAIN_AUGMENT=22000 13 | #NUM_TRAIN_AUGMENT=49000 14 | 15 | # ONLY IMAGE DATA 16 | #python $MAIN \ 17 | # --data_dir $DATA_DIR \ 18 | # --output_dir $OUTPUT_DIR \ 19 | # --dtypes "train,valid" \ 20 | # --only_image \ 21 | # --image_size 1000 22 | 23 | # TEXT DATA 24 | #python $MAIN \ 25 | # --data_dir $DATA_DIR \ 26 | # --output_dir $OUTPUT_DIR \ 27 | # --dtypes "train,valid" \ 28 | # --filter_method "freq" \ 29 | # --min_freq 10 \ 30 | # --min_prop 0.5 \ 31 | # --num_train_augment $NUM_TRAIN_AUGMENT \ 32 | # --num_neg 1 \ 33 | # --split_on_train \ 34 | # --prev_valid $OUTPUT_DIR/finetune_train_valid_orig.tsv \ 35 | # --valid_proportion 0.25 \ 36 | # --valid_pos_proportion 0.4 37 | ## --with_image \ 38 | ## --cv_model_name "eca_nfnet_l0" \ 39 | ## --finetuned \ 40 | ## --pretrained_model_path "/root/autodl-tmp/Data/ccks2022/task9/output/eca_nfnet_l0-v6-full/image_finetune_epoch-4.bin" \ 41 | ## --image_size 1000 \ 42 | ## --batch_size 256 43 | 44 | # IMAGE OBJECT DETECTION 45 | python $MAIN \ 46 | --data_dir $DATA_DIR \ 47 | --output_dir $OUTPUT_DIR \ 48 | --dtypes "train,valid" \ 49 | --only_image \ 50 | --object_detection \ 51 | --cv_model_name "yolov5x6" \ 52 | --code_path "/root/Code/yolov5" \ 53 | --pretrained_model_path "/root/autodl-tmp/Data/cv/yolov5x6.pt" \ 54 | --min_crop_ratio 0.1 55 | -------------------------------------------------------------------------------- /run_finetune_graph.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # data processing 5 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 6 | MAIN="/root/Code/commodity-alignment/pkgm/finetune_graph.py" 7 | DATA_DIR=${ROOT_DIR} 8 | OUTPUT_DIR=${ROOT_DIR}/output 9 | MODEL_NAME="gcn" 10 | DATA_VERSION="v1" 11 | PRETRAINED_MODEL_PATH="/root/autodl-tmp/Data/bert/roberta_large" 12 | #PARAMETERS_TO_FREEZE="/root/autodl-tmp/Data/ccks2022/task9/output/parameters_pkgm.json" 13 | #PARAMETERS_TO_FREEZE="/root/autodl-tmp/Data/ccks2022/task9/output/textcnn_parameters_to_freeze.json" 14 | 15 | INTERACTION_TYPE="two_tower" 16 | CLASSIFICATION_METHOD="cls" 17 | SIMILARITY_MEASURE="NA" 18 | LOSS_TYPE="ce" 19 | BATCH_SIZE=512 20 | LEARNING_RATE=1e-4 21 | NUM_EPOCHS=500 22 | 23 | python $MAIN \ 24 | --data_dir $DATA_DIR \ 25 | --output_dir $OUTPUT_DIR \ 26 | --model_name ${MODEL_NAME} \ 27 | --data_version ${DATA_VERSION} \ 28 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 29 | --config_file "${MODEL_NAME}.json" \ 30 | --do_train \ 31 | --do_eval \ 32 | --interaction_type $INTERACTION_TYPE \ 33 | --classification_method $CLASSIFICATION_METHOD \ 34 | --similarity_measure 
$SIMILARITY_MEASURE \ 35 | --loss_type $LOSS_TYPE \ 36 | --warmup_proportion 0.3 \ 37 | --train_batch_size $BATCH_SIZE \ 38 | --eval_batch_size $BATCH_SIZE \ 39 | --learning_rate $LEARNING_RATE \ 40 | --num_train_epochs $NUM_EPOCHS \ 41 | --num_layers 4 \ 42 | --hidden_size 128 \ 43 | --log_steps 10 \ 44 | --save_epochs 10 45 | # --fp16 46 | # --auxiliary_task 47 | # --parameters_to_freeze $PARAMETERS_TO_FREEZE 48 | -------------------------------------------------------------------------------- /run_finetune_image.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # data processing 5 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 6 | MAIN="/root/Code/commodity-alignment/pkgm/finetune_image.py" 7 | DATA_DIR=${ROOT_DIR}/raw 8 | OUTPUT_DIR=${ROOT_DIR}/output 9 | MODEL_NAME="vit_base_patch16_384" 10 | DATA_VERSION="v6" 11 | #PRETRAINED_MODEL_PATH="/root/autodl-tmp/Data/cv/${MODEL_NAME}.pth" 12 | PRETRAINED_MODEL_PATH="/root/autodl-tmp/Data/cv/vit_base_patch16_384/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz" 13 | #PRETRAINED_MODEL_PATH="/root/autodl-tmp/Data/cv/vit_large_patch16_384/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz" 14 | 15 | TRAIN_BATCH_SIZE=56 16 | EVAL_BATCH_SIZE=112 17 | LEARNING_RATE=5e-5 18 | NUM_EPOCHS=10 19 | 20 | python $MAIN \ 21 | --data_dir $DATA_DIR \ 22 | --output_dir $OUTPUT_DIR \ 23 | --model_name ${MODEL_NAME} \ 24 | --data_version ${DATA_VERSION} \ 25 | --config_file ${MODEL_NAME}.json \ 26 | --do_train \ 27 | --do_eval \ 28 | --warmup_proportion 0.3 \ 29 | --image_size 384 \ 30 | --train_batch_size $TRAIN_BATCH_SIZE \ 31 | --eval_batch_size $EVAL_BATCH_SIZE \ 32 | --learning_rate $LEARNING_RATE \ 33 | --num_train_epochs $NUM_EPOCHS \ 34 | --log_steps 10 \ 35 | --fp16 36 | # --pretrained_model_path $PRETRAINED_MODEL_PATH \ 37 | -------------------------------------------------------------------------------- /run_finetune_multimodal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # data processing 5 | #ROOT_DIR="/Users/zeyesun/Documents/Data/ccks2022/task9_商品同款" 6 | #MAIN="/Users/zeyesun/Documents/Code/torchkge/examples/train.py" 7 | #PRETRAINED_MODEL_PATH="/Users/zeyesun/Documents/Data/bert/chinese_roberta_wwm_ext_pytorch" 8 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 9 | MAIN="/root/Code/commodity-alignment/pkgm/finetune_multimodal.py" 10 | DATA_DIR=${ROOT_DIR} 11 | OUTPUT_DIR=${ROOT_DIR}/output 12 | MODEL_NAME="roberta_image_large" 13 | DATA_VERSION="v5-full" 14 | PRETRAINED_MODEL_PATH="${OUTPUT_DIR}/${MODEL_NAME}" 15 | #PARAMETERS_TO_FREEZE="/root/autodl-tmp/Data/ccks2022/task9/output/parameters_pkgm.json" 16 | IMAGE_MODEL_NAME="vit_base_patch16_384" 17 | 18 | INTERACTION_TYPE="two_tower" 19 | CLASSIFICATION_METHOD="cls" 20 | ENSEMBLE="begin" 21 | LOSS_TYPE="ce" 22 | TRAIN_BATCH_SIZE=16 23 | EVAL_BATCH_SIZE=64 24 | LEARNING_RATE=5e-5 25 | NUM_EPOCHS=10 26 | 27 | python $MAIN \ 28 | --data_dir $DATA_DIR \ 29 | --output_dir $OUTPUT_DIR \ 30 | --model_name ${MODEL_NAME} \ 31 | --data_version ${DATA_VERSION} \ 32 | --config_file "${MODEL_NAME}.json" \ 33 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 34 | --do_train \ 35 | --interaction_type $INTERACTION_TYPE \ 36 | --classification_method $CLASSIFICATION_METHOD \ 37 | --ensemble $ENSEMBLE \ 38 | --loss_type $LOSS_TYPE \ 39 | --max_seq_len 50 \ 40 | 
--max_seq_len_pv 205 \ 41 | --image_hidden_size 3072 \ 42 | --warmup_proportion 0.3 \ 43 | --train_batch_size $TRAIN_BATCH_SIZE \ 44 | --eval_batch_size $EVAL_BATCH_SIZE \ 45 | --learning_rate $LEARNING_RATE \ 46 | --num_train_epochs $NUM_EPOCHS \ 47 | --log_steps 10 \ 48 | --fp16 49 | # --parameters_to_freeze $PARAMETERS_TO_FREEZE 50 | -------------------------------------------------------------------------------- /run_finetune_text.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # data processing 5 | #ROOT_DIR="/Users/zeyesun/Documents/Data/ccks2022/task9_商品同款" 6 | #MAIN="/Users/zeyesun/Documents/Code/torchkge/examples/train.py" 7 | #PRETRAINED_MODEL_PATH="/Users/zeyesun/Documents/Data/bert/chinese_roberta_wwm_ext_pytorch" 8 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 9 | MAIN="/root/Code/commodity-alignment/pkgm/finetune_text.py" 10 | DATA_DIR=${ROOT_DIR} 11 | OUTPUT_DIR=${ROOT_DIR}/output 12 | MODEL_NAME="roberta_large" 13 | DATA_VERSION="v3.4-full" 14 | PRETRAINED_MODEL_PATH="/root/autodl-tmp/Data/bert/${MODEL_NAME}" 15 | PARAMETERS_TO_FREEZE="/root/autodl-tmp/Data/ccks2022/task9/output/parameters_pkgm.json" 16 | #PARAMETERS_TO_FREEZE="/root/autodl-tmp/Data/ccks2022/task9/output/textcnn_parameters_to_freeze.json" 17 | 18 | INTERACTION_TYPE="two_tower" 19 | CLASSIFICATION_METHOD="cls" 20 | SIMILARITY_MEASURE="NA" 21 | LOSS_TYPE="ce" 22 | BATCH_SIZE=40 23 | LEARNING_RATE=5e-5 24 | NUM_EPOCHS=10 25 | 26 | python $MAIN \ 27 | --data_dir $DATA_DIR \ 28 | --output_dir $OUTPUT_DIR \ 29 | --model_name ${MODEL_NAME} \ 30 | --data_version ${DATA_VERSION} \ 31 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 32 | --config_file "${MODEL_NAME}.json" \ 33 | --do_train \ 34 | --interaction_type $INTERACTION_TYPE \ 35 | --classification_method $CLASSIFICATION_METHOD \ 36 | --similarity_measure $SIMILARITY_MEASURE \ 37 | --loss_type $LOSS_TYPE \ 38 | --max_seq_len 50 \ 39 | --max_seq_len_pv 205 \ 40 | --max_pvs 30 \ 41 | --warmup_proportion 0.3 \ 42 | --train_batch_size $BATCH_SIZE \ 43 | --eval_batch_size $BATCH_SIZE \ 44 | --learning_rate $LEARNING_RATE \ 45 | --num_train_epochs $NUM_EPOCHS \ 46 | --log_steps 10 \ 47 | --fp16 48 | # --auxiliary_task 49 | # --parameters_to_freeze $PARAMETERS_TO_FREEZE 50 | -------------------------------------------------------------------------------- /run_model_soup_multimodal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # data processing 5 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 6 | MAIN="/root/Code/commodity-alignment/pkgm/model_soup_multimodal.py" 7 | DATA_DIR=${ROOT_DIR}/processed 8 | OUTPUT_DIR=${ROOT_DIR}/output 9 | MODEL_NAME="roberta_image_large" 10 | DATA_VERSION="v5-full" 11 | PRETRAINED_MODEL_PATH="/root/autodl-tmp/Data/bert/roberta_large" 12 | 13 | INTERACTION_TYPE="one_tower" 14 | CLASSIFICATION_METHOD="cls" 15 | ENSEMBLE="begin" 16 | LOSS_TYPE="ce" 17 | BATCH_SIZE=96 18 | 19 | FILE_STATE_DICT="/root/autodl-tmp/Data/ccks2022/task9/output/${MODEL_NAME}-${DATA_VERSION}-${INTERACTION_TYPE}-${CLASSIFICATION_METHOD}-${ENSEMBLE}-${LOSS_TYPE}/multimodal_finetune_epoch-{}.bin" 20 | EPOCHS="6,7,8,9" 21 | 22 | 23 | python $MAIN \ 24 | --data_dir $DATA_DIR \ 25 | --output_dir $OUTPUT_DIR \ 26 | --model_name ${MODEL_NAME} \ 27 | --data_version ${DATA_VERSION} \ 28 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 29 | --config_file "${MODEL_NAME}.json" \ 30 | --interaction_type $INTERACTION_TYPE \ 
31 | --classification_method $CLASSIFICATION_METHOD \ 32 | --ensemble $ENSEMBLE \ 33 | --loss_type $LOSS_TYPE \ 34 | --file_state_dict $FILE_STATE_DICT \ 35 | --epochs $EPOCHS \ 36 | --max_seq_len 50 \ 37 | --max_seq_len_pv 205 \ 38 | --max_pvs 30 \ 39 | --eval_batch_size $BATCH_SIZE \ 40 | --threshold 0.5 \ 41 | --log_steps 10 \ 42 | --fp16 \ 43 | -------------------------------------------------------------------------------- /run_model_soup_text.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # data processing 5 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 6 | MAIN="/root/Code/commodity-alignment/pkgm/model_soup_text.py" 7 | DATA_DIR=${ROOT_DIR}/processed 8 | OUTPUT_DIR=${ROOT_DIR}/output 9 | MODEL_NAME="roberta_large" 10 | DATA_VERSION="v3.4-full" 11 | PRETRAINED_MODEL_PATH="/root/autodl-tmp/Data/bert/${MODEL_NAME}" 12 | 13 | INTERACTION_TYPE="one_tower" 14 | CLASSIFICATION_METHOD="cls" 15 | SIMILARITY_MEASURE="NA" 16 | LOSS_TYPE="ce" 17 | BATCH_SIZE=40 18 | 19 | FILE_STATE_DICT="/root/autodl-tmp/Data/ccks2022/task9/output/${MODEL_NAME}-${DATA_VERSION}-${INTERACTION_TYPE}-${CLASSIFICATION_METHOD}-${SIMILARITY_MEASURE}-${LOSS_TYPE}/pkgm_finetune_epoch-{}.bin" 20 | EPOCHS="6,7,8,9" 21 | 22 | 23 | python $MAIN \ 24 | --data_dir $DATA_DIR \ 25 | --output_dir $OUTPUT_DIR \ 26 | --model_name ${MODEL_NAME} \ 27 | --data_version ${DATA_VERSION} \ 28 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 29 | --config_file "${MODEL_NAME}.json" \ 30 | --interaction_type $INTERACTION_TYPE \ 31 | --classification_method $CLASSIFICATION_METHOD \ 32 | --similarity_measure $SIMILARITY_MEASURE \ 33 | --loss_type $LOSS_TYPE \ 34 | --file_state_dict $FILE_STATE_DICT \ 35 | --epochs $EPOCHS \ 36 | --max_seq_len 50 \ 37 | --max_seq_len_pv 205 \ 38 | --max_pvs 30 \ 39 | --eval_batch_size $BATCH_SIZE \ 40 | --threshold 0.5 \ 41 | --log_steps 10 \ 42 | --fp16 \ 43 | -------------------------------------------------------------------------------- /run_pkgm_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # data processing 5 | #ROOT_DIR="/Users/zeyesun/Documents/Data/ccks2022/task9_商品同款" 6 | #MAIN="/Users/zeyesun/Documents/Code/torchkge/examples/train.py" 7 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 8 | MAIN="/root/Code/commodity-alignment/pkgm/pkgm_pretrain.py" 9 | DATA_DIR=${ROOT_DIR}/processed 10 | OUTPUT_DIR=${ROOT_DIR}/output 11 | #MODEL_NAME="transe_epoch-{}.bin" 12 | MODEL_NAME="pkgm_epoch-{}.bin" 13 | TRAIN_BATCH_SIZE=32768 14 | EVAL_BATCH_SIZE=32768 15 | LEARNING_RATE=1e-4 16 | NUM_EPOCHS=2000 17 | SAVE_EPOCHS=1000 18 | EMBEDDING_DIM=768 19 | MARGIN=1.0 20 | 21 | python $MAIN \ 22 | --data_dir $DATA_DIR \ 23 | --output_dir $OUTPUT_DIR \ 24 | --model_name $MODEL_NAME \ 25 | --n_neg 3 \ 26 | --train_batch_size $TRAIN_BATCH_SIZE \ 27 | --eval_batch_size $EVAL_BATCH_SIZE \ 28 | --learning_rate $LEARNING_RATE \ 29 | --num_train_epochs $NUM_EPOCHS \ 30 | --dim $EMBEDDING_DIM \ 31 | --margin $MARGIN \ 32 | --save_epochs $SAVE_EPOCHS 33 | # --fp16 34 | -------------------------------------------------------------------------------- /run_pred_image.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # data processing 5 | #ROOT_DIR="/Users/zeyesun/Documents/Data/ccks2022/task9_商品同款" 6 | #MAIN="/Users/zeyesun/Documents/Code/torchkge/examples/train.py" 7 | 
#PRETRAINED_MODEL_PATH="/Users/zeyesun/Documents/Data/bert/chinese_roberta_wwm_ext_pytorch" 8 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 9 | MAIN="/root/Code/commodity-alignment/pkgm/finetune_image.py" 10 | DATA_DIR=${ROOT_DIR}/raw 11 | OUTPUT_DIR=${ROOT_DIR}/output 12 | MODEL_NAME="eca_nfnet_l0" 13 | DATA_VERSION="v6.2-full" 14 | #PRETRAINED_MODEL_PATH="/root/autodl-tmp/Data/cv/${MODEL_NAME}.pth" 15 | 16 | THRESHOLD=0.4 17 | EPOCH=9 18 | EVAL_BATCH_SIZE=120 19 | 20 | FILE_STATE_DICT="/root/autodl-tmp/Data/ccks2022/task9/output/${MODEL_NAME}-${DATA_VERSION}/image_finetune_epoch-${EPOCH}.bin" 21 | 22 | python $MAIN \ 23 | --data_dir $DATA_DIR \ 24 | --output_dir $OUTPUT_DIR \ 25 | --model_name ${MODEL_NAME} \ 26 | --data_version ${DATA_VERSION} \ 27 | --config_file ${MODEL_NAME}.json \ 28 | --file_state_dict $FILE_STATE_DICT \ 29 | --do_pred \ 30 | --eval_batch_size $EVAL_BATCH_SIZE \ 31 | --threshold $THRESHOLD \ 32 | --image_size 800 \ 33 | --log_steps 10 \ 34 | --fp16 35 | -------------------------------------------------------------------------------- /run_pred_multimodal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # data processing 5 | #ROOT_DIR="/Users/zeyesun/Documents/Data/ccks2022/task9_商品同款" 6 | #MAIN="/Users/zeyesun/Documents/Code/torchkge/examples/train.py" 7 | #PRETRAINED_MODEL_PATH="/Users/zeyesun/Documents/Data/bert/chinese_roberta_wwm_ext_pytorch" 8 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 9 | MAIN="/root/Code/commodity-alignment/pkgm/finetune_multimodal.py" 10 | DATA_DIR=${ROOT_DIR}/processed 11 | OUTPUT_DIR=${ROOT_DIR}/output 12 | MODEL_NAME="roberta_image_large" 13 | DATA_VERSION="v5-full" 14 | PRETRAINED_MODEL_PATH="/root/autodl-tmp/Data/bert/${MODEL_NAME}" 15 | 16 | INTERACTION_TYPE="one_tower" 17 | CLASSIFICATION_METHOD="cls" 18 | ENSEMBLE="begin" 19 | LOSS_TYPE="ce" 20 | THRESHOLD=0.4 21 | EPOCH=9 22 | EVAL_BATCH_SIZE=108 23 | 24 | FILE_STATE_DICT="/root/autodl-tmp/Data/ccks2022/task9/output/${MODEL_NAME}-${DATA_VERSION}-${INTERACTION_TYPE}-${CLASSIFICATION_METHOD}-${ENSEMBLE}-${LOSS_TYPE}/multimodal_finetune_epoch-${EPOCH}.bin" 25 | 26 | python $MAIN \ 27 | --data_dir $DATA_DIR \ 28 | --output_dir $OUTPUT_DIR \ 29 | --model_name ${MODEL_NAME} \ 30 | --data_version ${DATA_VERSION} \ 31 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 32 | --file_state_dict $FILE_STATE_DICT \ 33 | --config_file "${MODEL_NAME}.json" \ 34 | --do_pred \ 35 | --interaction_type $INTERACTION_TYPE \ 36 | --classification_method $CLASSIFICATION_METHOD \ 37 | --ensemble $ENSEMBLE \ 38 | --loss_type $LOSS_TYPE \ 39 | --type_vocab_size 2 \ 40 | --max_seq_len 50 \ 41 | --max_seq_len_pv 205 \ 42 | --max_pvs 30 \ 43 | --image_hidden_size 3072 \ 44 | --eval_batch_size $EVAL_BATCH_SIZE \ 45 | --threshold $THRESHOLD \ 46 | --log_steps 100 \ 47 | --fp16 48 | -------------------------------------------------------------------------------- /run_pred_text.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # data processing 5 | #ROOT_DIR="/Users/zeyesun/Documents/Data/ccks2022/task9_商品同款" 6 | #MAIN="/Users/zeyesun/Documents/Code/torchkge/examples/train.py" 7 | #PRETRAINED_MODEL_PATH="/Users/zeyesun/Documents/Data/bert/chinese_roberta_wwm_ext_pytorch" 8 | ROOT_DIR="/root/autodl-tmp/Data/ccks2022/task9" 9 | MAIN="/root/Code/commodity-alignment/pkgm/finetune_text.py" 10 | DATA_DIR=${ROOT_DIR} 11 | OUTPUT_DIR=${ROOT_DIR}/output 12 | 
MODEL_NAME="roberta_large" 13 | DATA_VERSION="v3.4-full" 14 | PRETRAINED_MODEL_PATH="/root/autodl-tmp/Data/bert/${MODEL_NAME}" 15 | 16 | INTERACTION_TYPE="one_tower" 17 | CLASSIFICATION_METHOD="cls" 18 | SIMILARITY_MEASURE="NA" 19 | LOSS_TYPE="ce" 20 | THRESHOLD=0.4 21 | EPOCH=9 22 | EVAL_BATCH_SIZE=256 23 | 24 | FILE_STATE_DICT="/root/autodl-tmp/Data/ccks2022/task9/output/${MODEL_NAME}-${DATA_VERSION}-${INTERACTION_TYPE}-${CLASSIFICATION_METHOD}-${SIMILARITY_MEASURE}-${LOSS_TYPE}/pkgm_finetune_epoch-${EPOCH}.bin" 25 | 26 | python $MAIN \ 27 | --data_dir $DATA_DIR \ 28 | --output_dir $OUTPUT_DIR \ 29 | --model_name ${MODEL_NAME} \ 30 | --data_version ${DATA_VERSION} \ 31 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 32 | --file_state_dict $FILE_STATE_DICT \ 33 | --config_file "${MODEL_NAME}.json" \ 34 | --do_pred \ 35 | --interaction_type $INTERACTION_TYPE \ 36 | --classification_method $CLASSIFICATION_METHOD \ 37 | --similarity_measure $SIMILARITY_MEASURE \ 38 | --loss_type $LOSS_TYPE \ 39 | --type_vocab_size 2 \ 40 | --max_seq_len 50 \ 41 | --max_seq_len_pv 205 \ 42 | --max_pvs 30 \ 43 | --max_position_embeddings 512 \ 44 | --eval_batch_size $EVAL_BATCH_SIZE \ 45 | --threshold $THRESHOLD \ 46 | --log_steps 10 \ 47 | --fp16 48 | 49 | 50 | #MAIN="/root/Code/commodity-alignment/pkgm/pred_text.py" 51 | #python $MAIN \ 52 | # --data_dir $DATA_DIR \ 53 | # --output_dir $OUTPUT_DIR \ 54 | # --model_name ${MODEL_NAME}-${DATA_VERSION} \ 55 | # --pretrained_model_path $PRETRAINED_MODEL_PATH \ 56 | # --config_file "${MODEL_NAME}.json" \ 57 | # --max_seq_len 64 \ 58 | # --eval_batch_size $EVAL_BATCH_SIZE \ 59 | # --log_steps 10 \ 60 | # --fp16 61 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/src/__init__.py -------------------------------------------------------------------------------- /src/bert/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | ------------------------------------------------- 4 | # @Project :commodity-aligment 5 | # @File :__init__.py 6 | # @Date :2022/5/5 09:13 7 | # @Author :mengqingyang 8 | # @Email :mengqingyang0102@163.com 9 | ------------------------------------------------- 10 | """ 11 | from .log import * 12 | from .data_utils import * 13 | from .model import * 14 | -------------------------------------------------------------------------------- /src/bert/data_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | ------------------------------------------------- 4 | # @Project :EntityAlignNet 5 | # @File :data_bert_utils 6 | # @Date :2022/6/29 11:40 7 | # @Author :mengqingyang 8 | # @Email :mengqingyang0102@163.com 9 | ------------------------------------------------- 10 | """ 11 | 12 | import os 13 | import json 14 | 15 | import torch 16 | import numpy as np 17 | from random import shuffle 18 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 19 | 20 | from .log import LOGGER 21 | 22 | 23 | def read_data(data_dir, file, pair_names): 24 | records = [] 25 | file_name = os.path.join(data_dir, file) 26 | with open(file_name, 'r', encoding="utf8") as f: 27 | data = json.loads(f.read()) 28 | if "data" in data: 29 | for rec in data["data"]: 30 | 
records.append([rec.get(pair_names[0], ""), rec.get(pair_names[1], ""), int(rec["label"])]) 31 | else: 32 | for rec in data: 33 | records.append([rec.get(pair_names[0], ""), rec.get(pair_names[1], ""), int(rec["label"])]) 34 | return records 35 | 36 | 37 | def shuffle_pvs_pairs(pvs_pairs): 38 | shuffled_pvs = [] 39 | for pvs_a, pvs_b, label in pvs_pairs: 40 | pvs_a_ = pvs_a.split(";") 41 | shuffle(pvs_a_) 42 | 43 | pvs_b_ = pvs_b.split(";") 44 | shuffle(pvs_b_) 45 | 46 | rand = np.random.random() 47 | if rand < 0.5: 48 | shuffled_pvs.append([';'.join(pvs_b_), ";".join(pvs_a_), label]) 49 | else: 50 | shuffled_pvs.append([';'.join(pvs_a_), ";".join(pvs_b_), label]) 51 | return shuffled_pvs 52 | 53 | 54 | def join_data(data_dir, filename, do_shuffle=True): 55 | pvs_pairs = read_data(data_dir, filename, ["pvs_a", "pvs_b"]) 56 | if do_shuffle: 57 | pvs_pairs = shuffle_pvs_pairs(pvs_pairs) 58 | title_pairs = read_data(data_dir, filename, ["title_a", "title_b"]) 59 | industry_name_pairs = read_data(data_dir, filename, ["industry_name_a", "industry_name_b"]) 60 | cate_pairs = read_data(data_dir, filename, ["cate_a", "cate_b"]) 61 | cate_path_pairs = read_data(data_dir, filename, ["cate_path_a", "cate_path_b"]) 62 | return pvs_pairs, title_pairs, industry_name_pairs, cate_pairs, cate_path_pairs 63 | 64 | 65 | def show(data, mode, name, show_num=3): 66 | LOGGER.info("") 67 | LOGGER.info(f"======= {mode} / {name} ==========") 68 | for pv in data[:show_num]: 69 | LOGGER.info(pv) 70 | LOGGER.info(f"======= *******^_^******* ==========") 71 | 72 | 73 | def get_examples(pvs, titles, cates, cate_paths, industry_names): 74 | pvs_src, pvs_tgt, labels = zip(*pvs) 75 | titles_src, titles_tgt, labels = zip(*titles) 76 | cate_src, cate_tgt, labels = zip(*cates) 77 | cate_path_src, cate_path_tgt, labels = zip(*cate_paths) 78 | industry_name_src, industry_name_tgt, labels = zip(*industry_names) 79 | return pvs_src, pvs_tgt, titles_src, titles_tgt, cate_src, cate_tgt, cate_path_src, cate_path_tgt, industry_name_src, industry_name_tgt, labels 80 | 81 | 82 | def show_pairs(src_data, tgt_data, labels, name, mode="train", show_num=3): 83 | LOGGER.info("") 84 | LOGGER.info(f"======= {mode} / {name} ==========") 85 | for a, b, label in zip(src_data[:show_num], tgt_data[:show_num], labels[:show_num]): 86 | LOGGER.info(f"src_{name}:" + str(a)) 87 | LOGGER.info(f"tgt_{name}:" + str(b)) 88 | LOGGER.info("label:" + str(label)) 89 | LOGGER.info(f"======= *******^_^******* ==========") 90 | 91 | 92 | def encode(tokenizer, pvs_src, pvs_tgt, title_src, title_tgt, cate_src, cate_tgt, cate_path_src, cate_path_tgt, 93 | industry_name_src, 94 | industry_name_tgt, pvs_len=512, title_len=150, cate_len=20, cate_path_len=50, industry_name_len=20): 95 | pvs = tokenizer(pvs_src, pvs_tgt, 96 | padding='max_length', truncation=True, max_length=pvs_len) 97 | 98 | LOGGER.info("pvs encoded ^_^") 99 | title = tokenizer(title_src, title_tgt, 100 | padding='max_length', truncation=True, max_length=title_len) 101 | LOGGER.info("title encoded ^_^") 102 | cate = tokenizer(cate_src, cate_tgt, 103 | padding='max_length', truncation=True, max_length=cate_len) 104 | LOGGER.info("cate encoded ^_^") 105 | cate_path = tokenizer(cate_path_src, cate_path_tgt, 106 | padding='max_length', truncation=True, max_length=cate_path_len) 107 | LOGGER.info("cate_path encoded ^_^") 108 | industry_name = tokenizer(industry_name_src, industry_name_tgt, padding='max_length', 109 | truncation=True, max_length=industry_name_len) 110 | LOGGER.info("industry_name encoded 
^_^") 111 | return pvs, title, cate, cate_path, industry_name 112 | 113 | 114 | def convert_examples_to_features(example): 115 | input_ids = torch.tensor(example['input_ids']) 116 | attention_mask = torch.tensor(example['attention_mask']) 117 | token_type_ids = torch.tensor(example['token_type_ids']) 118 | return input_ids, attention_mask, token_type_ids 119 | 120 | 121 | def get_dataloader(pvs, titles, cates, cate_paths, industry_names, labels, batch_size=4, mode="train"): 122 | pvs_input_ids, pvs_attention_mask, pvs_token_type_ids = convert_examples_to_features(pvs) 123 | title_input_ids, title_attention_mask, title_token_type_ids = convert_examples_to_features(titles) 124 | cate_input_ids, cate_attention_mask, cate_token_type_ids = convert_examples_to_features(cates) 125 | cate_path_input_ids, cate_path_attention_mask, cate_path_token_type_ids = convert_examples_to_features(cate_paths) 126 | industry_name_input_ids, industry_name_attention_mask, industry_name_token_type_ids = convert_examples_to_features(industry_names) 127 | tensor_labels = torch.tensor(labels) 128 | 129 | # Create the DataLoader. 130 | data = TensorDataset(pvs_input_ids, pvs_attention_mask, pvs_token_type_ids, 131 | title_input_ids, title_attention_mask, title_token_type_ids, 132 | cate_input_ids, cate_attention_mask, cate_token_type_ids, 133 | cate_path_input_ids, cate_path_attention_mask, cate_path_token_type_ids, 134 | industry_name_input_ids, industry_name_attention_mask, industry_name_token_type_ids, 135 | tensor_labels) 136 | if mode == "train": 137 | sampler = RandomSampler(data) 138 | else: 139 | sampler = SequentialSampler(data) 140 | dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size) 141 | return dataloader 142 | 143 | -------------------------------------------------------------------------------- /src/bert/log.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | ------------------------------------------------- 4 | # @Project :commodity-alignment 5 | # @File :log 6 | # @Date :2022/8/7 15:09 7 | # @Author :mengqingyang 8 | # @Email :mengqingyang0102@163.com 9 | ------------------------------------------------- 10 | """ 11 | import logging 12 | 13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 14 | datefmt='%m/%d/%Y %H:%M:%S', 15 | level=logging.INFO) 16 | 17 | 18 | LOGGER = logging.getLogger(__name__) 19 | -------------------------------------------------------------------------------- /src/bert/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | ------------------------------------------------- 4 | # @Project :EntityAlignNet 5 | # @File :bert 6 | # @Date :2022/6/29 11:10 7 | # @Author :mengqingyang 8 | # @Email :mengqingyang0102@163.com 9 | ------------------------------------------------- 10 | """ 11 | import torch 12 | from torch import nn 13 | from pytorch_transformers import BertPreTrainedModel 14 | from pytorch_transformers.modeling_bert import BertEncoder, BertPooler, BertOnlyNSPHead, BertLayerNorm 15 | from torch.nn import CrossEntropyLoss 16 | 17 | 18 | class BertEmbeddings(nn.Module): 19 | """Construct the embeddings from word, position and token_type embeddings. 
20 | """ 21 | def __init__(self, config): 22 | super(BertEmbeddings, self).__init__() 23 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 24 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 25 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 26 | 27 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 28 | # any TensorFlow checkpoint file 29 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 30 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 31 | 32 | def forward(self, input_ids, token_type_ids=None, position_ids=None): 33 | seq_length = input_ids.size(1) 34 | if position_ids is None: 35 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 36 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 37 | if token_type_ids is None: 38 | token_type_ids = torch.zeros_like(input_ids) 39 | 40 | 41 | words_embeddings = self.word_embeddings(input_ids) 42 | position_embeddings = self.position_embeddings(position_ids) 43 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 44 | 45 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 46 | 47 | embeddings = self.LayerNorm(embeddings) 48 | embeddings = self.dropout(embeddings) 49 | return embeddings 50 | 51 | 52 | class BertModel(BertPreTrainedModel): 53 | r""" 54 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 55 | **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` 56 | Sequence of hidden-states at the output of the last layer of the model. 57 | **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` 58 | Last layer hidden-state of the first token of the sequence (classification token) 59 | further processed by a Linear layer and a Tanh activation function. The Linear 60 | layer weights are trained from the next sentence prediction (classification) 61 | objective during Bert pretraining. This output is usually *not* a good summary 62 | of the semantic content of the input, you're often better with averaging or pooling 63 | the sequence of hidden-states for the whole input sequence. 64 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 65 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 66 | of shape ``(batch_size, sequence_length, hidden_size)``: 67 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 68 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 69 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 70 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
71 | 72 | Examples:: 73 | 74 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 75 | model = BertModel.from_pretrained('bert-base-uncased') 76 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 77 | outputs = model(input_ids) 78 | last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple 79 | 80 | """ 81 | def __init__(self, config): 82 | super(BertModel, self).__init__(config) 83 | 84 | self.embeddings = BertEmbeddings(config) 85 | self.encoder = BertEncoder(config) 86 | self.pooler = BertPooler(config) 87 | 88 | self.init_weights() 89 | 90 | def _resize_token_embeddings(self, new_num_tokens): 91 | old_embeddings = self.embeddings.word_embeddings 92 | new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) 93 | self.embeddings.word_embeddings = new_embeddings 94 | return self.embeddings.word_embeddings 95 | 96 | def _prune_heads(self, heads_to_prune): 97 | """ Prunes heads of the model. 98 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 99 | See base class PreTrainedModel 100 | """ 101 | for layer, heads in heads_to_prune.items(): 102 | self.encoder.layer[layer].attention.prune_heads(heads) 103 | 104 | def forward(self, input_ids, 105 | token_type_ids=None, 106 | attention_mask=None, 107 | position_ids=None, 108 | head_mask=None, 109 | noise=None): 110 | if attention_mask is None: 111 | attention_mask = torch.ones_like(input_ids) 112 | if token_type_ids is None: 113 | token_type_ids = torch.zeros_like(input_ids) 114 | 115 | # We create a 3D attention mask from a 2D tensor mask. 116 | # Sizes are [batch_size, 1, 1, to_seq_length] 117 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 118 | # this attention mask is more simple than the triangular masking of causal attention 119 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 120 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 121 | 122 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 123 | # masked positions, this operation will create a tensor which is 0.0 for 124 | # positions we want to attend and -10000.0 for masked positions. 125 | # Since we are adding it to the raw scores before the softmax, this is 126 | # effectively the same as removing these entirely. 
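# e.g. an attention_mask of [1, 1, 0] is turned into (1.0 - [1, 1, 0]) * -10000.0
# = [-0.0, -0.0, -10000.0] below, so the masked position receives a negligible
# weight once the raw scores go through the softmax.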
127 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 128 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 129 | 130 | # Prepare head mask if needed 131 | # 1.0 in head_mask indicates we keep the head 132 | # attention_probs has shape bsz x n_heads x N x N 133 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 134 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 135 | if head_mask is not None: 136 | if head_mask.dim() == 1: 137 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 138 | head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) 139 | elif head_mask.dim() == 2: 140 | head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer 141 | head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility 142 | else: 143 | head_mask = [None] * self.config.num_hidden_layers 144 | 145 | embedding_output = self.embeddings(input_ids, 146 | position_ids=position_ids, 147 | token_type_ids=token_type_ids) 148 | if noise is not None: 149 | embedding_output = embedding_output + noise 150 | encoder_outputs = self.encoder(embedding_output, 151 | extended_attention_mask, 152 | head_mask=head_mask) 153 | sequence_output = encoder_outputs[0] 154 | pooled_output = self.pooler(sequence_output) 155 | 156 | outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here 157 | return outputs # sequence_output, pooled_output, (hidden_states), (attentions) 158 | 159 | 160 | class BertAlignModel(BertPreTrainedModel): 161 | def __init__(self, config=None): 162 | """title, cate, cate_path, pvs, industry_name""" 163 | super(BertAlignModel, self).__init__(config) 164 | self.loss_fct = CrossEntropyLoss() 165 | self.num_labels = 2 166 | self.bert = BertModel(config) 167 | self.cls = BertOnlyNSPHead(config) 168 | self.init_weights() 169 | 170 | def get_sim_eval_weight(self): 171 | w = self.cls.seq_relationship.weight.detach().cpu() 172 | b = self.cls.seq_relationship.bias.detach().cpu() 173 | weight = w[1] - w[0] 174 | return weight, b[1] - b[0] 175 | 176 | def forward(self, 177 | pvs_input_ids=None, pvs_token_type_ids=None, pvs_attention_mask=None, 178 | title_input_ids=None, title_token_type_ids=None, title_attention_mask=None, 179 | cate_input_ids=None, cate_token_type_ids=None, cate_attention_mask=None, 180 | cate_path_input_ids=None, cate_path_token_type_ids=None, cate_path_attention_mask=None, 181 | industry_name_input_ids=None, industry_name_token_type_ids=None, industry_name_attention_mask=None, 182 | next_sentence_label=None, pvs_noise=None, title_noise=None): 183 | title_out = self.bert(title_input_ids, 184 | token_type_ids=title_token_type_ids, 185 | attention_mask=title_attention_mask, 186 | noise=title_noise) 187 | title_pool_out = title_out[1] 188 | 189 | cate_out = self.bert(cate_input_ids, 190 | token_type_ids=cate_token_type_ids, 191 | attention_mask=cate_attention_mask) 192 | cate_pool_out = cate_out[1] 193 | 194 | cate_path_out = self.bert(cate_path_input_ids, 195 | token_type_ids=cate_path_token_type_ids, 196 | attention_mask=cate_path_attention_mask) 197 | cate_path_pool_out = cate_path_out[1] 198 | 199 | pvs_out = self.bert(pvs_input_ids, 200 | token_type_ids=pvs_token_type_ids, 201 | attention_mask=pvs_attention_mask, 202 | noise=pvs_noise) 203
| pvs_pool_out = pvs_out[1] 204 | 205 | industry_name_out = self.bert(industry_name_input_ids, 206 | token_type_ids=industry_name_token_type_ids, 207 | attention_mask=industry_name_attention_mask) 208 | industry_name_pool_out = industry_name_out[1] 209 | 210 | pool_out = title_pool_out + cate_pool_out + cate_path_pool_out + pvs_pool_out + industry_name_pool_out 211 | seq_relationship_score = self.cls(pool_out) 212 | output = [pool_out, seq_relationship_score] 213 | if next_sentence_label is not None: 214 | next_sentence_loss = self.loss_fct(seq_relationship_score.view(-1, self.num_labels), 215 | next_sentence_label.view(-1)) 216 | output.append(next_sentence_loss) 217 | return output 218 | 219 | -------------------------------------------------------------------------------- /src/config/coca_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128, 19 | "num_hidden_layers_multimodal": 12, 20 | "num_attention_heads_multimodal": 12, 21 | "feedforward_multiplication_multimodal": 6 22 | } 23 | -------------------------------------------------------------------------------- /src/config/coca_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128, 19 | "num_hidden_layers_multimodal": 24, 20 | "num_attention_heads_multimodal": 16, 21 | "feedforward_multiplication_multimodal": 12, 22 | "image_size": 384, 23 | "patch_size": 16 24 | } 25 | -------------------------------------------------------------------------------- /src/config/eca_nfnet_l0.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dropout_prob": 0.1, 3 | "num_labels": 2, 4 | "image_size": 800 5 | } 6 | -------------------------------------------------------------------------------- /src/config/gcn.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dropout_prob": 0.1, 3 | "hidden_size": 1024, 4 | "intermediate_size": 128, 5 | "num_hidden_layers": 2, 6 | "num_entities": 230023, 7 | "num_labels": 2, 8 | "alpha": 0.1, 9 | "theta": 0.5 10 | } 11 | -------------------------------------------------------------------------------- /src/config/pkgm_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size":
768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128, 19 | "cate_size": 179, 20 | "num_entities": 231619, 21 | "num_relations": 987, 22 | "kg_embedding_dim": 768, 23 | "entity_projection_bias": false 24 | } 25 | -------------------------------------------------------------------------------- /src/config/pkgm_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128, 19 | "num_entities": 258211, 20 | "num_relations": 1379, 21 | "kg_embedding_dim": 1024, 22 | "entity_projection_bias": false, 23 | "cate_size": 179 24 | } 25 | -------------------------------------------------------------------------------- /src/config/resnetv2_50.json: -------------------------------------------------------------------------------- 1 | { 2 | "hidden_dropout_prob": 0.1, 3 | "num_hidden_layers": 50, 4 | "num_labels": 2 5 | } 6 | -------------------------------------------------------------------------------- /src/config/roberta_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128, 19 | "cate_size": 179 20 | } 21 | -------------------------------------------------------------------------------- /src/config/roberta_image_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128, 19 | "interaction_type": "one_tower", 20 | "classification_method": "cls", 21 | "similarity_measure": "inner_product", 22 | "loss_type": "ce", 23 | "num_entities": 231619, 24 | "num_relations": 987, 25 | "kg_embedding_dim":768, 26 | "max_seq_len": 52, 27 | "max_pvs": 30, 28 | 
"entity_projection_bias": false, 29 | "image_hidden_size": 3072 30 | } 31 | -------------------------------------------------------------------------------- /src/config/roberta_image_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128, 19 | "classification_method": "cls", 20 | "loss_type": "ce", 21 | "num_entities": 231619, 22 | "num_relations": 987, 23 | "kg_embedding_dim": 1024, 24 | "max_seq_len": 52, 25 | "max_pvs": 30, 26 | "entity_projection_bias": false, 27 | "image_hidden_size": 3072 28 | } 29 | -------------------------------------------------------------------------------- /src/config/roberta_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128, 19 | "cate_size": 179 20 | } 21 | -------------------------------------------------------------------------------- /src/config/vit_base_patch16_384.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128, 19 | "patch_size": 16, 20 | "image_size": 384 21 | } 22 | -------------------------------------------------------------------------------- /src/config/vit_large_patch16_384.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 21128, 19 | "patch_size": 16, 20 | "image_size": 384 21 | } 22 | 
-------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import * 2 | from .base import * 3 | from .text import * 4 | from .image import * 5 | from .multimodal import * 6 | from .graph import * -------------------------------------------------------------------------------- /src/models/graph.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from .base import SequenceClassifierOutput, TwoTowerClassificationHead 7 | from torch_geometric.nn import GCN2Conv 8 | from ..utils import logger 9 | 10 | 11 | # GCN 12 | class GCN(torch.nn.Module): 13 | def __init__(self, config): 14 | # intermediate_size, num_hidden_layers, hidden_dropout_prob=0.1, 15 | # hidden_size=768, alpha=0.1, theta=0.5, num_labels=2): 16 | super().__init__() 17 | 18 | self.linear = torch.nn.Linear(config.hidden_size, config.intermediate_size) 19 | # self.lins = torch.nn.ModuleList() 20 | # self.lins.append(torch.nn.Linear(hidden_size, intermediate_size)) 21 | # self.lins.append(torch.nn.Linear(intermediate_size*2, num_labels)) 22 | 23 | self.convs = torch.nn.ModuleList() 24 | for layer in range(config.num_hidden_layers): 25 | self.convs.append( 26 | GCN2Conv(config.intermediate_size, config.alpha, config.theta, layer + 1, 27 | shared_weights=True, normalize=False)) 28 | 29 | self.dropout = config.hidden_dropout_prob 30 | 31 | def forward(self, x, adj_t): 32 | x = F.dropout(x, self.dropout, training=self.training) 33 | x = x_0 = self.linear(x).relu() 34 | # x = x_0 = self.lins[0](x).relu() 35 | 36 | for conv in self.convs: 37 | x = F.dropout(x, self.dropout, training=self.training) 38 | x = conv(x, x_0, adj_t) 39 | x = x.relu() 40 | 41 | x = F.dropout(x, self.dropout, training=self.training) 42 | # x = self.lins[1](x) 43 | 44 | return x # x.log_softmax(dim=-1) 45 | 46 | 47 | class GCNTwoTower(nn.Module): 48 | def __init__(self, config): 49 | # intermediate_size, 50 | # num_hidden_layers, 51 | # hidden_dropout_prob=0.1, 52 | # hidden_size=768, 53 | # num_labels=2, 54 | # loss_margin=0.0 55 | # ): 56 | super().__init__() 57 | self.config = config 58 | self.num_labels = config.num_labels 59 | 60 | # GCN encoder 61 | self.encoder = GCN(config) 62 | # intermediate_size, 63 | # num_hidden_layers, 64 | # hidden_dropout_prob=hidden_dropout_prob, 65 | # hidden_size=hidden_size) 66 | 67 | # self.classifier = RobertaClassificationHead(config) 68 | self.classifier = TwoTowerClassificationHead(config.intermediate_size, 69 | dropout=config.hidden_dropout_prob, 70 | num_labels=config.num_labels) 71 | 72 | # if config.loss_type == "cosine": 73 | # self.loss_fct = nn.CosineEmbeddingLoss(margin=loss_margin) 74 | # elif config.loss_type == "bce": 75 | # self.loss_fct = nn.BCEWithLogitsLoss() 76 | # elif config.loss_type == "euclidean": 77 | # self.loss_fct = EuclideanDistanceLoss() 78 | # elif config.loss_type == "hinge": 79 | # self.loss_fct = HingeLoss(margin=loss_margin) 80 | # else: 81 | self.loss_fct = nn.CrossEntropyLoss() 82 | 83 | def forward(self, feature_matrix, adjacency_matrix, pairs): 84 | node_embeddings = self.encoder(feature_matrix, adjacency_matrix) 85 | loss = 
None 86 | logits, probs, src_embeds, tgt_embeds = None, None, None, None 87 | for pair in pairs: 88 | i = pair['src_idx'] 89 | j = pair['tgt_idx'] 90 | src_node_embeddings = node_embeddings[i].unsqueeze(0) 91 | tgt_node_embeddings = node_embeddings[j].unsqueeze(0) 92 | 93 | # logits 94 | src_e, tgt_e, lgt, prob = self.classifier(src_node_embeddings, tgt_node_embeddings) 95 | if logits is None: 96 | logits = lgt 97 | probs = prob[:, 1] 98 | src_embeds = src_e 99 | tgt_embeds = tgt_e 100 | else: 101 | # accumulate the outputs of every pair, not only those of the first one 102 | logits = torch.cat((logits, lgt), dim=0) 103 | src_embeds = torch.cat((src_embeds, src_e), dim=0) 104 | tgt_embeds = torch.cat((tgt_embeds, tgt_e), dim=0) 105 | probs = torch.cat((probs, prob[:, 1]), dim=0) 106 | 107 | labels = pair.get('item_label', None) 108 | if labels is not None: 109 | if loss is None: 110 | loss = 0 111 | labels = torch.tensor(int(pair['item_label']), dtype=torch.long, device=feature_matrix.device) 112 | if self.config.loss_type == "cosine": 113 | loss += self.loss_fct(src_e, tgt_e, (labels*2-1).view(-1)) 114 | elif self.config.loss_type == "ce": 115 | loss += self.loss_fct(lgt.view(-1, self.num_labels), labels.view(-1)) 116 | elif self.config.loss_type == "hinge" or self.config.loss_type == "euclidean": 117 | loss += self.loss_fct(lgt.view(-1), (labels*2-1).view(-1)) 118 | else: 119 | loss += self.loss_fct(lgt.view(-1), labels.view(-1)) 120 | 121 | # if not return_dict: 122 | # output = (logits,) + outputs[2:] 123 | # return ((loss,) + output) if loss is not None else output 124 | 125 | if loss is not None: 126 | loss /= len(pairs) 127 | 128 | return SequenceClassifierOutput( 129 | loss=loss, 130 | logits=logits, 131 | probs=probs, 132 | src_embeds=src_embeds, 133 | tgt_embeds=tgt_embeds 134 | ) 135 | -------------------------------------------------------------------------------- /src/models/loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from torch.nn.modules.loss import _Loss 5 | 6 | 7 | class EuclideanDistanceLoss(_Loss): 8 | r"""Measures the loss given an input tensor :math:`x` (euclidean distance) and a labels tensor :math:`y` 9 | (containing 1 or -1). 10 | This is usually used for measuring whether two inputs are similar or 11 | dissimilar, e.g. using the L1 pairwise distance as :math:`x`, and is typically 12 | used for learning nonlinear embeddings or semi-supervised learning. 13 | 14 | The loss function for :math:`n`-th sample in the mini-batch is 15 | 16 | .. math:: 17 | l_n = \begin{cases} 18 | x_n^{y_n} 19 | \end{cases} 20 | 21 | and the total loss functions is 22 | 23 | .. math:: 24 | \ell(x, y) = \begin{cases} 25 | \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ 26 | \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} 27 | \end{cases} 28 | 29 | where :math:`L = \{l_1,\dots,l_N\}^\top`. 30 | 31 | Args: 32 | 33 | size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, 34 | the losses are averaged over each loss element in the batch. Note that for 35 | some losses, there are multiple elements per sample. If the field :attr:`size_average` 36 | is set to ``False``, the losses are instead summed for each minibatch. Ignored 37 | when :attr:`reduce` is ``False``. Default: ``True`` 38 | reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the 39 | losses are averaged or summed over observations for each minibatch depending 40 | on :attr:`size_average`.
When :attr:`reduce` is ``False``, returns a loss per 41 | batch element instead and ignores :attr:`size_average`. Default: ``True`` 42 | reduction (string, optional): Specifies the reduction to apply to the output: 43 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, 44 | ``'mean'``: the sum of the output will be divided by the number of 45 | elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` 46 | and :attr:`reduce` are in the process of being deprecated, and in the meantime, 47 | specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` 48 | 49 | Shape: 50 | - Input: :math:`(*)` where :math:`*` means, any number of dimensions. The sum operation 51 | operates over all the elements. 52 | - Target: :math:`(*)`, same shape as the input 53 | - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input 54 | """ 55 | __constants__ = ['reduction'] 56 | # margin: float 57 | 58 | def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: 59 | super(EuclideanDistanceLoss, self).__init__(size_average, reduce, reduction) 60 | 61 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 62 | loss = torch.pow(input, target) 63 | if self.reduction == "sum": 64 | return loss.sum() 65 | elif self.reduction == "mean": 66 | return loss.mean() 67 | else: 68 | return loss 69 | 70 | 71 | class HingeLoss(_Loss): 72 | r"""Measures the loss given an input tensor :math:`x` (inner product) and a labels tensor :math:`y` 73 | (containing 1 or -1). 74 | This is usually used for measuring whether two inputs are similar or 75 | dissimilar, e.g. using the L1 pairwise distance as :math:`x`, and is typically 76 | used for learning nonlinear embeddings or semi-supervised learning. 77 | 78 | The loss function for :math:`n`-th sample in the mini-batch is 79 | 80 | .. math:: 81 | l_n = \begin{cases} 82 | \max \{0, \Delta - y_n * x_n\} 83 | \end{cases} 84 | 85 | and the total loss functions is 86 | 87 | .. math:: 88 | \ell(x, y) = \begin{cases} 89 | \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ 90 | \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} 91 | \end{cases} 92 | 93 | where :math:`L = \{l_1,\dots,l_N\}^\top`. 94 | 95 | Args: 96 | margin (float, optional): Has a default value of `1`. 97 | size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, 98 | the losses are averaged over each loss element in the batch. Note that for 99 | some losses, there are multiple elements per sample. If the field :attr:`size_average` 100 | is set to ``False``, the losses are instead summed for each minibatch. Ignored 101 | when :attr:`reduce` is ``False``. Default: ``True`` 102 | reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the 103 | losses are averaged or summed over observations for each minibatch depending 104 | on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per 105 | batch element instead and ignores :attr:`size_average`. Default: ``True`` 106 | reduction (string, optional): Specifies the reduction to apply to the output: 107 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, 108 | ``'mean'``: the sum of the output will be divided by the number of 109 | elements in the output, ``'sum'``: the output will be summed. 
Note: :attr:`size_average` 110 | and :attr:`reduce` are in the process of being deprecated, and in the meantime, 111 | specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` 112 | 113 | Shape: 114 | - Input: :math:`(*)` where :math:`*` means, any number of dimensions. The sum operation 115 | operates over all the elements. 116 | - Target: :math:`(*)`, same shape as the input 117 | - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input 118 | """ 119 | __constants__ = ['margin', 'reduction'] 120 | margin: float 121 | 122 | def __init__(self, margin: float = 1.0, size_average=None, reduce=None, reduction: str = 'mean') -> None: 123 | super(HingeLoss, self).__init__(size_average, reduce, reduction) 124 | self.margin = margin 125 | 126 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 127 | zero = torch.zeros(1, device=input.device) 128 | loss = torch.max(zero, self.margin - input * target) 129 | if self.reduction == "sum": 130 | return loss.sum() 131 | elif self.reduction == "mean": 132 | return loss.mean() 133 | else: 134 | return loss 135 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import logger 2 | from .config import * -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | 2 | ROBERTA_WEIGHTS_NAME = "pytorch_model.bin" 3 | KG_WEIGHTS_NAME = "pkgm_model.bin" 4 | COCA_WEIGHTS_NAME = "coca_model.bin" 5 | VIT_WEIGHTS_NAME = "image_encoder.bin" 6 | 7 | BOS_TOKEN = "" 8 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | 4 | 5 | logging.basicConfig( 6 | format="%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)s] %(message)s", 7 | datefmt="%Y/%m/%d %H:%M:%S", 8 | level=logging.INFO 9 | ) 10 | 11 | logger = logging.getLogger(__name__) 12 | -------------------------------------------------------------------------------- /submit/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base Images 2 | ## Build from the Tianchi base image 3 | FROM registry.cn-shanghai.aliyuncs.com/ccks2022_task9_subtask2/submit:py3.7 4 | 5 | ## Copy the files in the current folder into the root directory of the image 6 | ADD result.zip requirements.txt run.sh / 7 | 8 | ## Set the default working directory to the root directory (run.sh and the generated result file must both be placed there for the submission to run) 9 | WORKDIR / 10 | 11 | ## After the image starts, always execute sh run.sh 12 | CMD ["sh", "run.sh"] 13 | -------------------------------------------------------------------------------- /submit/push.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | #Compress-Archive -Update -Path .\similarity.py,.\deepAI_result.jsonl -DestinationPath result.zip 4 | 5 | TAG="ensemble-threshold_0.0" 6 | 7 | docker build -t registry.cn-shanghai.aliyuncs.com/ccks2022_task9_subtask2/submit:$TAG .
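# note: pushing to this private Aliyun registry assumes a prior `docker login registry.cn-shanghai.aliyuncs.com` with valid credentials for the ccks2022_task9_subtask2 namespace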
8 | docker push registry.cn-shanghai.aliyuncs.com/ccks2022_task9_subtask2/submit:$TAG 9 | -------------------------------------------------------------------------------- /submit/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/submit/requirements.txt -------------------------------------------------------------------------------- /submit/result.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/submit/result.zip -------------------------------------------------------------------------------- /submit/run.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/submit/run.sh -------------------------------------------------------------------------------- /torchkge/README.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | TorchKGE 3 | ======== 4 | 5 | .. image:: https://graphs.telecom-paristech.fr/images/logo_torchKGE_small.png 6 | :align: right 7 | :width: 100px 8 | :alt: logo torchkge 9 | 10 | .. image:: https://img.shields.io/pypi/v/torchkge.svg 11 | :target: https://pypi.python.org/pypi/torchkge 12 | 13 | .. image:: https://github.com/torchkge-team/torchkge/actions/workflows/ci_checks.yml/badge.svg 14 | :target: https://github.com/torchkge-team/torchkge/actions/workflows/ci_checks.yml 15 | 16 | .. image:: https://readthedocs.org/projects/torchkge/badge/?version=latest 17 | :target: https://torchkge.readthedocs.io/en/latest/?badge=latest 18 | :alt: Documentation Status 19 | 20 | .. image:: https://pyup.io/repos/github/torchkge-team/torchkge/shield.svg 21 | :target: https://pyup.io/repos/github/torchkge-team/torchkge/ 22 | :alt: Updates 23 | 24 | .. image:: https://img.shields.io/pypi/pyversions/torchkge.svg 25 | :target: https://pypi.org/project/torchkge/ 26 | 27 | TorchKGE: Knowledge Graph embedding in Python and Pytorch. 28 | 29 | TorchKGE is a Python module for knowledge graph (KG) embedding relying solely on Pytorch. This package provides 30 | researchers and engineers with a clean and efficient API to design and test new models. It features a KG data structure, 31 | simple model interfaces and modules for negative sampling and model evaluation. Its main strength is a highly efficient 32 | evaluation module for the link prediction task, a central application of KG embedding. It has been `observed `_ to be up 33 | to five times faster than `AmpliGraph `_ and twenty-four times faster than 34 | `OpenKE `_. Various KG embedding models are also already implemented. Special 35 | attention has been paid to code efficiency and simplicity, documentation and API consistency. It is distributed using 36 | PyPI under BSD license. 37 | 38 | Citations 39 | --------- 40 | If you find this code useful in your research, please consider citing our `paper `_ (presented at `IWKG-KDD `_ 2020): 41 | 42 | .. 
code:: 43 | 44 | @inproceedings{arm2020torchkge, 45 | title={TorchKGE: Knowledge Graph Embedding in Python and PyTorch}, 46 | author={Armand Boschin}, 47 | year={2020}, 48 | month={Aug}, 49 | booktitle={International Workshop on Knowledge Graph: Mining Knowledge Graph for Deep Insights}, 50 | } 51 | 52 | * Free software: BSD license 53 | * Documentation: https://torchkge.readthedocs.io. 54 | -------------------------------------------------------------------------------- /torchkge/__init__.py: -------------------------------------------------------------------------------- 1 | from .torchkge.models import * 2 | from .torchkge.utils import * 3 | from .torchkge.evaluation import * -------------------------------------------------------------------------------- /torchkge/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = torchkge 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /torchkge/docs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/torchkge/docs/__init__.py -------------------------------------------------------------------------------- /torchkge/docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | html { 2 | background-color: #e6e6e6; 3 | } 4 | 5 | body.wy-body-for-nav { 6 | line-height: 1.5em; 7 | color: #333; 8 | } 9 | 10 | div.wy-nav-side { 11 | background-color: #333; 12 | } 13 | 14 | div.wy-side-nav-search { 15 | background-color: #777777; 16 | } 17 | 18 | 19 | div.wy-menu.wy-menu-vertical>p { 20 | color: #c5113b /* section titles */ 21 | } 22 | 23 | .wy-nav-top { 24 | background-color: #777777; 25 | } 26 | 27 | .wy-side-nav-search>a:hover, .wy-side-nav-search .wy-dropdown>a:hover { 28 | background: None; /*background for logo when hovered*/ 29 | } 30 | 31 | .wy-side-nav-search>div.version { 32 | color: white; 33 | } 34 | 35 | .wy-side-nav-search input[type=text] { 36 | border-color: #d9d9d9; 37 | } 38 | 39 | a { 40 | color: #c5113b; 41 | } 42 | -------------------------------------------------------------------------------- /torchkge/docs/authors.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Credits 3 | ======= 4 | 5 | Development Lead 6 | ---------------- 7 | 8 | * Armand Boschin 9 | 10 | Contributors 11 | ------------ 12 | 13 | None yet. Why not be the first? 
14 | -------------------------------------------------------------------------------- /torchkge/docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # torchkge documentation build configuration file, created by 5 | # sphinx-quickstart on Fri Jun 9 13:47:02 2017. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another 17 | # directory, add these directories to sys.path here. If the directory is 18 | # relative to the documentation root, use os.path.abspath to make it 19 | # absolute, like shown here. 20 | # 21 | import os 22 | import sys 23 | sys.path.insert(0, os.path.abspath('..')) 24 | 25 | import torchkge 26 | 27 | # -- General configuration --------------------------------------------- 28 | 29 | # If your documentation needs a minimal Sphinx version, state it here. 30 | # 31 | # needs_sphinx = '1.0' 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 35 | extensions = ['sphinx.ext.autodoc', 36 | 'sphinx.ext.viewcode', 37 | 'sphinx.ext.autosummary', 38 | 'sphinx.ext.coverage', 39 | 'sphinx.ext.doctest', 40 | 'sphinx.ext.intersphinx', 41 | 'sphinx.ext.mathjax', 42 | 'sphinx.ext.napoleon', 43 | 'sphinx.ext.todo', 44 | 'sphinx.ext.viewcode'] 45 | 46 | # Add any paths that contain templates here, relative to this directory. 47 | templates_path = ['_templates'] 48 | 49 | # The suffix(es) of source filenames. 50 | # You can specify multiple suffix as a list of string: 51 | # 52 | # source_suffix = ['.rst', '.md'] 53 | source_suffix = '.rst' 54 | 55 | # The master toctree document. 56 | master_doc = 'index' 57 | 58 | # General information about the project. 59 | project = u'TorchKGE' 60 | copyright = u"2022, TorchKGE developers" 61 | author = u"Armand Boschin" 62 | 63 | # The version info for the project you're documenting, acts as replacement 64 | # for |version| and |release|, also used in various other places throughout 65 | # the built documents. 66 | # 67 | # The short X.Y version. 68 | version = torchkge.__version__ 69 | # The full version, including alpha/beta/rc tags. 70 | release = torchkge.__version__ 71 | 72 | # The language for content autogenerated by Sphinx. Refer to documentation 73 | # for a list of supported languages. 74 | # 75 | # This is also used if you do content translation via gettext catalogs. 76 | # Usually you set "language" from the command line for these cases. 77 | language = None 78 | 79 | # List of patterns, relative to source directory, that match files and 80 | # directories to ignore when looking for source files. 81 | # This patterns also effect to html_static_path and html_extra_path 82 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 83 | 84 | # The name of the Pygments (syntax highlighting) style to use. 85 | pygments_style = 'sphinx' 86 | 87 | # If true, `todo` and `todoList` produce output, else they produce nothing. 88 | todo_include_todos = False 89 | 90 | 91 | # -- Options for HTML output ------------------------------------------- 92 | 93 | # The theme to use for HTML and HTML Help pages. 
See the documentation for 94 | # a list of builtin themes. 95 | # 96 | html_theme = 'sphinx_rtd_theme' 97 | html_logo = 'logo_torchKGE_small.png' 98 | html_favicon = "logo_torchKGE_small.png" 99 | html_theme_options = { 100 | 'logo_only': True 101 | } 102 | 103 | # Theme options are theme-specific and customize the look and feel of a 104 | # theme further. For a list of options available for each theme, see the 105 | # documentation. 106 | # 107 | # html_theme_options = {} 108 | 109 | # Add any paths that contain custom static files (such as style sheets) here, 110 | # relative to this directory. They are copied after the builtin static files, 111 | # so a file named "default.css" will overwrite the builtin "default.css". 112 | html_static_path = ['_static'] 113 | 114 | 115 | # -- Options for HTMLHelp output --------------------------------------- 116 | 117 | # Output file base name for HTML help builder. 118 | htmlhelp_basename = 'torchkgedoc' 119 | 120 | 121 | # -- Options for LaTeX output ------------------------------------------ 122 | 123 | latex_elements = { 124 | # The paper size ('letterpaper' or 'a4paper'). 125 | # 126 | # 'papersize': 'letterpaper', 127 | 128 | # The font size ('10pt', '11pt' or '12pt'). 129 | # 130 | # 'pointsize': '10pt', 131 | 132 | # Additional stuff for the LaTeX preamble. 133 | # 134 | # 'preamble': '', 135 | 136 | # Latex figure (float) alignment 137 | # 138 | # 'figure_align': 'htbp', 139 | } 140 | 141 | # Grouping the document tree into LaTeX files. List of tuples 142 | # (source start file, target name, title, author, documentclass 143 | # [howto, manual, or own class]). 144 | latex_documents = [ 145 | (master_doc, 'torchkge.tex', 146 | u'TorchKGE Documentation', 147 | u'Armand Boschin', 'manual'), 148 | ] 149 | 150 | 151 | # -- Options for manual page output ------------------------------------ 152 | 153 | # One entry per manual page. List of tuples 154 | # (source start file, name, description, authors, manual section). 155 | man_pages = [ 156 | (master_doc, 'torchkge', 157 | u'TorchKGE Documentation', 158 | [author], 1) 159 | ] 160 | 161 | 162 | # -- Options for Texinfo output ---------------------------------------- 163 | 164 | # Grouping the document tree into Texinfo files. List of tuples 165 | # (source start file, target name, title, author, 166 | # dir menu entry, description, category) 167 | texinfo_documents = [ 168 | (master_doc, 'torchkge', 169 | u'TorchKGE Documentation', 170 | author, 171 | 'torchkge', 172 | 'One line description of project.', 173 | 'Miscellaneous'), 174 | ] 175 | 176 | 177 | def setup(app): 178 | app.add_css_file('css/custom.css') 179 | 180 | 181 | nbsphinx_kernel_name = 'python3' 182 | -------------------------------------------------------------------------------- /torchkge/docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Contributing 5 | ============ 6 | 7 | Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given. 8 | 9 | You can contribute in many ways: 10 | 11 | Types of Contributions 12 | ---------------------- 13 | 14 | Report Bugs 15 | ~~~~~~~~~~~ 16 | 17 | Report bugs at https://github.com/torchkge-team/torchkge/issues. 18 | 19 | If you are reporting a bug, please include: 20 | 21 | * Your operating system name and version. 22 | * Any details about your local setup that might be helpful in troubleshooting. 23 | * Detailed steps to reproduce the bug. 
24 | 25 | Fix Bugs 26 | ~~~~~~~~ 27 | 28 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help wanted" is open to whoever wants 29 | to implement it. 30 | 31 | Implement Features 32 | ~~~~~~~~~~~~~~~~~~ 33 | 34 | Look through the GitHub issues for features. Anything tagged with "enhancement" and "help wanted" is open to whoever 35 | wants to implement it. 36 | 37 | Write Documentation 38 | ~~~~~~~~~~~~~~~~~~~ 39 | 40 | TorchKGE could always use more documentation, whether as part of the official TorchKGE docs, in docstrings, or even 41 | on the web in blog posts, articles, and such. 42 | 43 | Submit Feedback 44 | ~~~~~~~~~~~~~~~ 45 | 46 | The best way to send feedback is to file an issue at https://github.com/torchkge-team/torchkge/issues. 47 | 48 | If you are proposing a feature: 49 | 50 | * Explain in detail how it would work. 51 | * Keep the scope as narrow as possible, to make it easier to implement. 52 | * Remember that this is a volunteer-driven project, and that contributions 53 | are welcome :) 54 | 55 | Get Started! 56 | ------------ 57 | 58 | Ready to contribute? Here's how to set up `torchkge` for local development. 59 | 60 | 1. Fork the `torchkge` repo on GitHub. 61 | 2. Clone your fork locally:: 62 | 63 | $ git clone git@github.com:your_name_here/torchkge.git 64 | 65 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: 66 | 67 | $ mkvirtualenv torchkge 68 | $ cd torchkge/ 69 | $ python setup.py develop 70 | 71 | 4. Create a branch for local development:: 72 | 73 | $ git checkout -b dev/name-of-your-bugfix-or-feature 74 | 75 | Now you can make your changes locally. 76 | 77 | 5. When you're done making changes, check that your changes pass tests, including testing other 78 | Python versions with tox:: 79 | 80 | $ flake8 torchkge tests 81 | $ python setup.py test or py.test 82 | $ tox 83 | 84 | To get tox, just pip install it into your virtualenv. 85 | 86 | 6. Commit your changes and push your branch to GitHub:: 87 | 88 | $ git add . 89 | $ git commit -m "Your detailed description of your changes." 90 | $ git push origin dev/name-of-your-bugfix-or-feature 91 | 92 | 7. Submit a pull request through the GitHub website. 93 | 94 | Pull Request Guidelines 95 | ----------------------- 96 | 97 | Before you submit a pull request, check that it meets these guidelines: 98 | 99 | 1. The pull request should include tests. 100 | 2. If the pull request adds functionality, the docs should be updated. Put 101 | your new functionality into a function with a docstring, and add the 102 | feature to the list in README.rst. 103 | 3. The pull request should work for Python 3.7, 3.8, 3.9 and for PyPi. Check 104 | https://github.com/torchkge-team/torchkge/actions 105 | and make sure that the tests pass for all supported Python versions. 106 | 107 | Deploying 108 | --------- 109 | 110 | A reminder for the maintainers on how to deploy. 111 | Make sure all your changes are committed (including an entry in HISTORY.rst). 112 | Then run:: 113 | 114 | $ bumpversion patch # possible: major / minor / patch 115 | $ git push 116 | $ git push --tags 117 | 118 | GitHub Actions will then deploy to PyPI if tests pass.
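For instance, if the current released version is 0.17.3, ``bumpversion patch`` moves it to 0.17.4, while ``bumpversion minor`` and ``bumpversion major`` would produce 0.18.0 and 1.0.0 respectively.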
119 | 
--------------------------------------------------------------------------------
/torchkge/docs/history.rst:
--------------------------------------------------------------------------------
1 | =======
2 | History
3 | =======
4 | 
5 | 0.17.3 (2022-04-21)
6 | -------------------
7 | * Fix ConvKB scoring function and normalization step
8 | 
9 | 0.17.2 (2022-03-02)
10 | -------------------
11 | * Fix the documentation in evaluation and inference modules
12 | * Fix a typo in the sampling module's documentation
13 | 
14 | 0.17.1 (2022-02-25)
15 | -------------------
16 | * Add support of Python 3.7 back
17 | 
18 | 0.17.0 (2022-02-25)
19 | -------------------
20 | * Add relation prediction evaluation
21 | * Add relation negative sampling module
22 | * Add inference module
23 | * Update models' API accordingly to the previous new features
24 | * Switch from TravisCI to GitHub Actions
25 | 
26 | 0.16.25 (2021-03-01)
27 | --------------------
28 | * Update in available pretrained models
29 | 
30 | 0.16.24 (2021-02-16)
31 | --------------------
32 | * Fix deployment
33 | 
34 | 0.16.23 (2021-02-16)
35 | --------------------
36 | * Removed useless k_max parameter in link-prediction evaluation method
37 | 
38 | 0.16.22 (2021-02-05)
39 | --------------------
40 | * Add pretrained version of TransE for yago310 and ComplEx for fb15k237 and wdv5.
41 | 
42 | 0.16.21 (2021-02-02)
43 | --------------------
44 | * Add pretrained version of TransE for Wikidata-Vitals level 5
45 | 
46 | 0.16.20 (2021-01-22)
47 | --------------------
48 | * Add support for Python 3.8
49 | * Clean up loading process for kgs
50 | * Fix deprecation warning
51 | 
52 | 0.16.19 (2021-01-20)
53 | --------------------
54 | * Fix release
55 | 
56 | 0.16.18 (2021-01-20)
57 | --------------------
58 | * Add data loader for wikidata vitals knowledge graphs
59 | 
60 | 0.16.17 (2020-11-03)
61 | --------------------
62 | * Bug fix in get_ranks method
63 | 
64 | 0.16.16 (2020-10-07)
65 | --------------------
66 | * Bug fix in KG split method
67 | 
68 | 0.16.15 (2020-10-07)
69 | --------------------
70 | * Fix WikiDataSets loader (again)
71 | 
72 | 0.16.14 (2020-09-21)
73 | --------------------
74 | * Fix WikiDataSets loader
75 | 
76 | 0.16.13 (2020-08-06)
77 | --------------------
78 | * Fix reduction in BCE loss
79 | * Add pretrained models
80 | 
81 | 0.16.12 (2020-07-07)
82 | --------------------
83 | * Release patch
84 | 
85 | 0.16.11 (2020-07-07)
86 | --------------------
87 | * Fix bug in pre-trained models loading that caused all models to be redownloaded every time
88 | 
89 | 0.16.10 (2020-07-02)
90 | --------------------
91 | * Minor bug patch
92 | 
93 | 0.16.9 (2020-07-02)
94 | -------------------
95 | * Update URLs to retrieve datasets and pre-trained models.
96 | 
97 | 0.16.8 (2020-07-01)
98 | -------------------
99 | * Add binary cross-entropy loss
100 | 
101 | 0.16.7 (2020-06-23)
102 | -------------------
103 | * Change API for pre-trained models
104 | 
105 | 0.16.6 (2020-06-09)
106 | -------------------
107 | * Patch in pre-trained model loading
108 | * Added pre-trained loading for TransE on FB15k237 in dimension 100.
109 | 
110 | 0.16.5 (2020-06-02)
111 | -------------------
112 | * Release patch
113 | 
114 | 0.16.4 (2020-06-02)
115 | -------------------
116 | * Add parameter in data redundancy to exclude known reverse triplets from
117 |   duplicate search.
118 | 119 | 0.16.3 (2020-05-29) 120 | ------------------- 121 | * Release patch 122 | 123 | 0.16.2 (2020-05-29) 124 | ------------------- 125 | * Add methods to compute data redundancy in knowledge graphs as in 2020 126 | `paper `__ by Akrami et al 127 | (see references in concerned methods). 128 | 129 | 0.16.1 (2020-05-28) 130 | ------------------- 131 | * Patch an awkward import 132 | * Add dataset loaders for WN18RR and YAGO3-10 133 | 134 | 0.16.0 (2020-04-27) 135 | ------------------- 136 | * Redefinition of the models' API (simplified interfaces, renamed LP 137 | methods and added get_embeddings method) 138 | * Implementation of the new API for all models 139 | * TorusE implementation fixed 140 | * TransD reimplementation to avoid matmul usage (costly in 141 | back-propagation) 142 | * Added feature to negative samplers to generate several negative 143 | samples from each fact. Those can be fed directly to the models. 144 | * Added some wrappers for training to utils module. 145 | * Progress bars now make the most of tqdm's possibilities 146 | * Code reformatting 147 | * Docstrings update 148 | 149 | 0.15.5 (2020-04-23) 150 | ------------------- 151 | * Defined a new homemade and simpler DataLoader class. 152 | 153 | 0.15.4 (2020-04-22) 154 | ------------------- 155 | * Removed the use of torch DataLoader object. 156 | 157 | 0.15.3 (2020-04-02) 158 | ------------------- 159 | * Added a method to print results in link prediction evaluator 160 | 161 | 0.15.2 (2020-04-01) 162 | ------------------- 163 | * Fixed a misfit test 164 | 165 | 0.15.1 (2020-04-01) 166 | ------------------- 167 | * Cleared the definition of rank in link prediction 168 | 169 | 0.15.0 (2020-04-01) 170 | ------------------- 171 | * Improved use of tqdm progress bars 172 | 173 | 0.14.0 (2020-04-01) 174 | ------------------- 175 | * Change in the API of loss functions (margin and logistic loss) 176 | * Documentation update 177 | 178 | 0.13.0 (2020-02-10) 179 | ------------------- 180 | * Added ConvKB model 181 | 182 | 0.12.1 (2020-01-10) 183 | ------------------- 184 | * Minor patch in interfaces 185 | * Comment additions 186 | 187 | 0.12.0 (2019-12-05) 188 | ------------------- 189 | * Various bug fixes 190 | * New KG splitting method enforcing all entities and relations to appear at least once in the training set. 191 | 192 | 0.11.3 (2019-11-15) 193 | ------------------- 194 | * Minor bug fixes 195 | 196 | 0.11.2 (2019-11-11) 197 | ------------------- 198 | * Minor bug fixes 199 | 200 | 0.11.1 (2019-10-21) 201 | ------------------- 202 | * Fixed requirements conflicts 203 | 204 | 0.11.0 (2019-10-21) 205 | ------------------- 206 | * Added TorusE model 207 | * Added dataloaders 208 | * Fixed some bugs 209 | 210 | 0.10.4 (2019-10-07) 211 | ------------------- 212 | * Fixed error in bilinear models. 213 | 214 | 0.10.3 (2019-07-23) 215 | ------------------- 216 | * Added intermediate function for hit@k metric in link prediction. 217 | 218 | 0.10.2 (2019-07-22) 219 | ------------------- 220 | * Fixed assertion error in Analogy model 221 | 222 | 0.10.0 (2019-07-19) 223 | ------------------- 224 | * Implemented Triplet Classification evaluation method 225 | * Added Negative Sampler objects to standardize negative sampling methods. 226 | 227 | 228 | 0.9.0 (2019-07-17) 229 | ------------------ 230 | * Implemented HolE model (Nickel et al.) 231 | * Implemented ComplEx model (Trouillon et al.) 232 | * Implemented ANALOGY model (Liu et al.) 
233 | * Added knowledge graph splitting into train, validation and test instead of just train and test.
234 | 
235 | 0.8.0 (2019-07-09)
236 | ------------------
237 | * Implemented Bernoulli negative sampling as in Wang et al. paper on TransH (2014).
238 | 
239 | 0.7.0 (2019-07-01)
240 | ------------------
241 | * Implemented Mean Reciprocal Rank measure of performance.
242 | * Implemented Logistic Loss.
243 | * Changed implementation of margin loss to use torch methods.
244 | 
245 | 0.6.0 (2019-06-25)
246 | ------------------
247 | * Implemented DistMult
248 | 
249 | 0.5.0 (2019-06-24)
250 | ------------------
251 | * Changed implementation of LinkPrediction ranks by moving functions to model methods.
252 | * Implemented RESCAL.
253 | 
254 | 0.4.0 (2019-05-15)
255 | ------------------
256 | * Fixed a major bug/problem in the Evaluation protocol of LinkPrediction.
257 | 
258 | 0.3.1 (2019-05-10)
259 | ------------------
260 | * Minor bug fixes in the various normalization functions.
261 | 
262 | 0.3.0 (2019-05-09)
263 | ------------------
264 | * Fixed CUDA support.
265 | 
266 | 0.2.0 (2019-05-07)
267 | ------------------
268 | * Added support for filtered performance measures.
269 | 
270 | 0.1.7 (2019-04-03)
271 | ------------------
272 | * First real release on PyPI.
273 | 
274 | 0.1.0 (2019-04-01)
275 | ------------------
276 | * First release on PyPI.
277 | 
--------------------------------------------------------------------------------
/torchkge/docs/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to TorchKGE's documentation!
2 | ====================================
3 | 
4 | .. include:: readme.rst
5 | 
6 | .. toctree::
7 |     :maxdepth: 1
8 |     :caption: Tutorials:
9 | 
10 |     tutorials/training
11 |     tutorials/evaluation
12 | 
13 | .. toctree::
14 |     :maxdepth: 1
15 |     :caption: Reference:
16 | 
17 |     reference/models
18 |     reference/evaluation
19 |     reference/inference
20 |     reference/sampling
21 |     reference/data
22 |     reference/utils
23 | 
24 | 
25 | .. toctree::
26 |     :maxdepth: 1
27 |     :caption: Installation:
28 | 
29 |     installation
30 | 
31 | .. toctree::
32 |     :maxdepth: 1
33 |     :caption: About:
34 | 
35 |     contributing
36 |     authors
37 |     history
38 | 
--------------------------------------------------------------------------------
/torchkge/docs/installation.rst:
--------------------------------------------------------------------------------
1 | .. highlight:: shell
2 | 
3 | Stable release
4 | --------------
5 | 
6 | To install TorchKGE, run this command in your terminal:
7 | 
8 | .. code-block:: console
9 | 
10 |     $ pip install torchkge
11 | 
12 | This is the preferred method to install TorchKGE, as it will always install the most recent stable release.
13 | 
14 | If you don't have `pip`_ installed, this `Python installation guide`_ can guide
15 | you through the process.
16 | 
17 | .. _pip: https://pip.pypa.io
18 | .. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/
19 | 
20 | 
21 | From sources
22 | ------------
23 | 
24 | The sources for TorchKGE can be downloaded from the `Github repo`_.
25 | 
26 | You can either clone the public repository:
27 | 
28 | .. code-block:: console
29 | 
30 |     $ git clone git://github.com/torchkge/torchkge
31 | 
32 | Or download the `tarball`_:
33 | 
34 | .. code-block:: console
35 | 
36 |     $ curl -OL https://github.com/torchkge/torchkge/tarball/master
37 | 
38 | Once you have a copy of the source, you can install it with:
39 | 
40 | .. 
code-block:: console 41 | 42 | $ python setup.py install 43 | 44 | 45 | .. _Github repo: https://github.com/torchkge/torchkge 46 | .. _tarball: https://github.com/torchkge/torchkge/tarball/master 47 | -------------------------------------------------------------------------------- /torchkge/docs/logo_torchKGE_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunzeyeah/item-alignment/e1fd35edeb41081fe0f733dcc54df6679492f572/torchkge/docs/logo_torchKGE_small.png -------------------------------------------------------------------------------- /torchkge/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=torchkge 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /torchkge/docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | -------------------------------------------------------------------------------- /torchkge/docs/reference/data.rst: -------------------------------------------------------------------------------- 1 | .. _data: 2 | 3 | 4 | Data Structure 5 | ******************* 6 | 7 | .. currentmodule:: torchkge.data_structures 8 | 9 | Knowledge Graph 10 | --------------- 11 | .. autoclass:: torchkge.data_structures.KnowledgeGraph 12 | :members: 13 | 14 | Small KG 15 | -------- 16 | .. autoclass:: torchkge.data_structures.SmallKG 17 | :members: 18 | -------------------------------------------------------------------------------- /torchkge/docs/reference/evaluation.rst: -------------------------------------------------------------------------------- 1 | .. _evaluation: 2 | 3 | 4 | Evaluation 5 | ********** 6 | 7 | Link Prediction 8 | --------------- 9 | To assess the performance of the link prediction evaluation module of TorchKGE, it was compared with the ones of 10 | `AmpliGraph `_ (v1.3.1) and `OpenKE `_ (version of 11 | April, 9). The computation times (in seconds) reported in the following table are averaged over 5 independent evaluation 12 | processes. Experiments were done using PyTorch 1.5, TensorFlow 1.15 and a Tesla K80 GPU. Missing values for AmpliGraph 13 | are due to missing models in the library. 14 | 15 | .. 
tabularcolumns:: p{2cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}
16 | 
17 | +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
18 | | Model     |        TransE         |        TransD         |        RESCAL         |        ComplEx        |
19 | +===========+===========+===========+===========+===========+===========+===========+===========+===========+
20 | | Dataset   | FB15k     | WN18      | FB15k     | WN18      | FB15k     | WN18      | FB15k     | WN18      |
21 | +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
22 | |AmpliGraph | 354.8     | 39.8      |           |           |           |           | 537.2     | 94.9      |
23 | +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
24 | |OpenKE     | 235.6     | 42.2      | 258.5     | 43.7      | 789.1     | 178.4     | 354.7     | 63.9      |
25 | +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
26 | |TorchKGE   | 76.1      | 13.8      | 60.8      | 11.1      | 46.9      | 7.1       | 96.4      | 18.6      |
27 | +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
28 | 
29 | .. autoclass:: torchkge.evaluation.LinkPredictionEvaluator
30 |     :members:
31 | 
32 | Relation Prediction
33 | -------------------
34 | .. autoclass:: torchkge.evaluation.RelationPredictionEvaluator
35 | 
36 | Triplet Classification
37 | ----------------------
38 | .. autoclass:: torchkge.evaluation.TripletClassificationEvaluator
39 |     :members:
40 | 
--------------------------------------------------------------------------------
/torchkge/docs/reference/inference.rst:
--------------------------------------------------------------------------------
1 | .. _inference:
2 | 
3 | 
4 | Inference
5 | *********
6 | 
7 | Entity Inference
8 | ----------------
9 | 
10 | .. autoclass:: torchkge.inference.EntityInference
11 |     :members:
12 | 
13 | Relation Inference
14 | ------------------
15 | 
16 | .. autoclass:: torchkge.inference.RelationInference
17 |     :members:
18 | 
--------------------------------------------------------------------------------
/torchkge/docs/reference/models.rst:
--------------------------------------------------------------------------------
1 | .. _models:
2 | 
3 | Models
4 | ******
5 | 
6 | Interfaces
7 | ==========
8 | 
9 | Model
10 | -----
11 | .. autoclass:: torchkge.models.interfaces.Model
12 |     :members:
13 | 
14 | TranslationModel
15 | ----------------
16 | .. autoclass:: torchkge.models.interfaces.TranslationModel
17 |     :members:
18 | 
19 | Translational Models
20 | ====================
21 | 
22 | Parameters used to train the models available in pre-trained versions:
23 | 
24 | .. 
tabularcolumns:: p{2cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}p{3cm} 25 | 26 | +-------+-----------+-----------+-----------+---------------+------------+--------+--------+-----------------+ 27 | | | Dataset | Dimension | Optimizer | Learning Rate | Batch Size | Loss | Margin | L2 penalization | 28 | +-------+-----------+-----------+-----------+---------------+------------+--------+--------+-----------------+ 29 | |TransE | FB15k | 100 | Adam | 2.1e-5 | 32768 | Margin | .651 | 1e-5 | 30 | +-------+-----------+-----------+-----------+---------------+------------+--------+--------+-----------------+ 31 | |TransE | FB15k237 | 100 | Adam | 2.1e-5 | 32768 | Margin | .651 | 1e-5 | 32 | +-------+-----------+-----------+-----------+---------------+------------+--------+--------+-----------------+ 33 | |TransE | FB15k237 | 150 | Adam | 2.7e-5 | 32768 | Margin | .648 | 1e-5 | 34 | +-------+-----------+-----------+-----------+---------------+------------+--------+--------+-----------------+ 35 | 36 | TransE 37 | ------ 38 | .. autoclass:: torchkge.models.translation.TransEModel 39 | :members: 40 | 41 | TransH 42 | ------ 43 | .. autoclass:: torchkge.models.translation.TransHModel 44 | :members: 45 | 46 | TransR 47 | ------ 48 | .. autoclass:: torchkge.models.translation.TransRModel 49 | :members: 50 | 51 | TransD 52 | ------ 53 | .. autoclass:: torchkge.models.translation.TransDModel 54 | :members: 55 | 56 | TorusE 57 | ------ 58 | .. autoclass:: torchkge.models.translation.TorusEModel 59 | :members: 60 | 61 | Bilinear Models 62 | =============== 63 | 64 | RESCAL 65 | ------ 66 | .. autoclass:: torchkge.models.bilinear.RESCALModel 67 | :members: 68 | 69 | DistMult 70 | -------- 71 | .. autoclass:: torchkge.models.bilinear.DistMultModel 72 | :members: 73 | 74 | HolE 75 | ---- 76 | .. autoclass:: torchkge.models.bilinear.HolEModel 77 | :members: 78 | 79 | ComplEx 80 | ------- 81 | .. autoclass:: torchkge.models.bilinear.ComplExModel 82 | :members: 83 | 84 | ANALOGY 85 | ------- 86 | .. autoclass:: torchkge.models.bilinear.AnalogyModel 87 | :members: 88 | 89 | Deep Models 90 | =========== 91 | 92 | ConvKB 93 | ------ 94 | .. autoclass:: torchkge.models.deep.ConvKBModel 95 | :members: 96 | -------------------------------------------------------------------------------- /torchkge/docs/reference/sampling.rst: -------------------------------------------------------------------------------- 1 | .. _sampling: 2 | 3 | .. currentmodule:: torchkge.sampling 4 | 5 | Negative Sampling 6 | ***************** 7 | 8 | Uniform negative sampler 9 | ------------------------ 10 | .. autoclass:: torchkge.sampling.UniformNegativeSampler 11 | :members: 12 | 13 | Bernoulli negative sampler 14 | -------------------------- 15 | .. autoclass:: torchkge.sampling.BernoulliNegativeSampler 16 | :members: 17 | 18 | Positional negative sampler 19 | --------------------------- 20 | .. autoclass:: torchkge.sampling.PositionalNegativeSampler 21 | :members: 22 | -------------------------------------------------------------------------------- /torchkge/docs/reference/utils.rst: -------------------------------------------------------------------------------- 1 | .. _utils: 2 | 3 | 4 | Utils 5 | ***** 6 | 7 | .. currentmodule:: torchkge.utils 8 | 9 | Datasets loaders 10 | ---------------- 11 | 12 | .. autofunction:: torchkge.utils.datasets.load_fb13 13 | .. autofunction:: torchkge.utils.datasets.load_fb15k 14 | .. autofunction:: torchkge.utils.datasets.load_fb15k237 15 | .. autofunction:: torchkge.utils.datasets.load_wn18 16 | .. 
autofunction:: torchkge.utils.datasets.load_wn18rr
17 | .. autofunction:: torchkge.utils.datasets.load_yago3_10
18 | .. autofunction:: torchkge.utils.datasets.load_wikidatasets
19 | .. autofunction:: torchkge.utils.datasets.load_wikidata_vitals
20 | 
21 | 
22 | Pre-trained models
23 | ------------------
24 | 
25 | TransE model
26 | ============
27 | .. tabularcolumns:: p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}
28 | 
29 | +-----------+-----------+-----------+----------+--------------------+
30 | | Model     | Dataset   | Dimension | Test MRR | Filtered Test MRR  |
31 | +===========+===========+===========+==========+====================+
32 | | TransE    | FB15k     | 100       | 0.250    | 0.420              |
33 | +-----------+-----------+-----------+----------+--------------------+
34 | | TransE    | FB15k237  | 150       | 0.187    | 0.287              |
35 | +-----------+-----------+-----------+----------+--------------------+
36 | | TransE    | WDV5      | 150       | 0.258    | 0.305              |
37 | +-----------+-----------+-----------+----------+--------------------+
38 | | TransE    | WN18RR    | 100       | 0.201    | 0.236              |
39 | +-----------+-----------+-----------+----------+--------------------+
40 | | TransE    | Yago3-10  | 200       | 0.143    | 0.261              |
41 | +-----------+-----------+-----------+----------+--------------------+
42 | 
43 | .. autofunction:: torchkge.utils.pretrained_models.load_pretrained_transe
44 | 
45 | ComplEx Model
46 | =============
47 | .. tabularcolumns:: p{3cm}p{3cm}p{3cm}p{3cm}p{3cm}
48 | 
49 | +-----------+-----------+-----------+----------+--------------------+
50 | | Model     | Dataset   | Dimension | Test MRR | Filtered Test MRR  |
51 | +===========+===========+===========+==========+====================+
52 | | ComplEx   | FB15k237  | 200       | 0.180    | 0.308              |
53 | +-----------+-----------+-----------+----------+--------------------+
54 | | ComplEx   | WN18RR    | 200       | 0.290    | 0.455              |
55 | +-----------+-----------+-----------+----------+--------------------+
56 | | ComplEx   | WDV5      | 200       | 0.283    | 0.371              |
57 | +-----------+-----------+-----------+----------+--------------------+
58 | 
59 | .. autofunction:: torchkge.utils.pretrained_models.load_pretrained_complex
60 | 
61 | Data redundancy
62 | ---------------
63 | .. autofunction:: torchkge.utils.data_redundancy.duplicates
64 | .. autofunction:: torchkge.utils.data_redundancy.count_triplets
65 | .. autofunction:: torchkge.utils.data_redundancy.cartesian_product_relations
66 | 
67 | Dissimilarities
68 | ---------------
69 | .. autofunction:: torchkge.utils.dissimilarities.l1_dissimilarity
70 | .. autofunction:: torchkge.utils.dissimilarities.l2_dissimilarity
71 | .. autofunction:: torchkge.utils.dissimilarities.l1_torus_dissimilarity
72 | .. autofunction:: torchkge.utils.dissimilarities.l2_torus_dissimilarity
73 | .. autofunction:: torchkge.utils.dissimilarities.el2_torus_dissimilarity
74 | 
75 | Losses
76 | ------
77 | .. autoclass:: torchkge.utils.losses.MarginLoss
78 |     :members:
79 | .. autoclass:: torchkge.utils.losses.LogisticLoss
80 |     :members:
81 | .. autoclass:: torchkge.utils.losses.BinaryCrossEntropyLoss
82 |     :members:
83 | 
84 | Training wrappers
85 | -----------------
86 | .. autoclass:: torchkge.utils.training.TrainDataLoader
87 |     :members:
88 | .. autoclass:: torchkge.utils.training.Trainer
89 |     :members:
90 | 
--------------------------------------------------------------------------------
/torchkge/docs/tutorials/evaluation.rst:
--------------------------------------------------------------------------------
1 | Model Evaluation
2 | ****************
3 | 
4 | .. include:: linkprediction.rst
5 | 
6 | .. 
include:: tripletclassification.rst 7 | -------------------------------------------------------------------------------- /torchkge/docs/tutorials/linkprediction.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | Link Prediction 3 | =============== 4 | 5 | To evaluate a model on link prediction:: 6 | 7 | from torch import cuda 8 | from torchkge.utils.pretrained_models import load_pretrained_transe 9 | from torchkge.utils.datasets import load_fb15k 10 | from torchkge.evaluation import LinkPredictionEvaluator 11 | 12 | _, _, kg_test = load_fb15k() 13 | 14 | model = load_pretrained_transe('fb15k', 100) 15 | if cuda.is_available(): 16 | model.cuda() 17 | 18 | # Link prediction evaluation on test set. 19 | evaluator = LinkPredictionEvaluator(model, kg_test) 20 | evaluator.evaluate(b_size=32) 21 | evaluator.print_results() 22 | 23 | -------------------------------------------------------------------------------- /torchkge/docs/tutorials/training.rst: -------------------------------------------------------------------------------- 1 | Model Training 2 | ************** 3 | 4 | Here are two examples of models being trained on FB15k. 5 | 6 | .. include:: transe.rst 7 | 8 | .. include:: transe_wrappers.rst 9 | 10 | .. include:: transe_early_stopping.rst 11 | -------------------------------------------------------------------------------- /torchkge/docs/tutorials/transe.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | Simplest training 3 | ================= 4 | 5 | This is the python code to train TransE without any wrapper. This script shows how all parts of TorchKGE should be used 6 | together:: 7 | 8 | from torch import cuda 9 | from torch.optim import Adam 10 | 11 | from torchkge.models import TransEModel 12 | from torchkge.sampling import BernoulliNegativeSampler 13 | from torchkge.utils import MarginLoss, DataLoader 14 | from torchkge.utils.datasets import load_fb15k 15 | 16 | from tqdm.autonotebook import tqdm 17 | 18 | # Load dataset 19 | kg_train, _, _ = load_fb15k() 20 | 21 | # Define some hyper-parameters for training 22 | emb_dim = 100 23 | lr = 0.0004 24 | n_epochs = 1000 25 | b_size = 32768 26 | margin = 0.5 27 | 28 | # Define the model and criterion 29 | model = TransEModel(emb_dim, kg_train.n_ent, kg_train.n_rel, dissimilarity_type='L2') 30 | criterion = MarginLoss(margin) 31 | 32 | # Move everything to CUDA if available 33 | if cuda.is_available(): 34 | cuda.empty_cache() 35 | model.cuda() 36 | criterion.cuda() 37 | 38 | # Define the torch optimizer to be used 39 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5) 40 | 41 | sampler = BernoulliNegativeSampler(kg_train) 42 | dataloader = DataLoader(kg_train, batch_size=b_size, use_cuda='all') 43 | 44 | iterator = tqdm(range(n_epochs), unit='epoch') 45 | for epoch in iterator: 46 | running_loss = 0.0 47 | for i, batch in enumerate(dataloader): 48 | h, t, r = batch[0], batch[1], batch[2] 49 | n_h, n_t = sampler.corrupt_batch(h, t, r) 50 | 51 | optimizer.zero_grad() 52 | 53 | # forward + backward + optimize 54 | pos, neg = model(h, t, r, n_h, n_t) 55 | loss = criterion(pos, neg) 56 | loss.backward() 57 | optimizer.step() 58 | 59 | running_loss += loss.item() 60 | iterator.set_description( 61 | 'Epoch {} | mean loss: {:.5f}'.format(epoch + 1, 62 | running_loss / len(dataloader))) 63 | 64 | model.normalize_parameters() 65 | 66 | -------------------------------------------------------------------------------- 
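Once the loop above has finished, the trained model can be checked on link
prediction with the evaluator introduced earlier. This is a minimal sketch, reusing
the `model` trained above and taking the test split from the same `load_fb15k()`
loader::

    from torchkge.evaluation import LinkPredictionEvaluator

    # Held-out test split (the training split was used above).
    _, _, kg_test = load_fb15k()

    # Rank the true heads/tails against all candidate entities.
    evaluator = LinkPredictionEvaluator(model, kg_test)
    evaluator.evaluate(b_size=32)
    evaluator.print_results()

--------------------------------------------------------------------------------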
/torchkge/docs/tutorials/transe_early_stopping.rst: -------------------------------------------------------------------------------- 1 | ==================== 2 | Training with Ignite 3 | ==================== 4 | 5 | TorchKGE can be used along with the `PyTorch ignite `_ library. It makes it easy to include 6 | early stopping in the training process. Here is an example script of training a TransE model on FB15k on GPU with early 7 | stopping on evaluation MRR:: 8 | 9 | import torch 10 | from ignite.engine import Engine, Events 11 | from ignite.handlers import EarlyStopping 12 | from ignite.metrics import RunningAverage 13 | from torch.optim import Adam 14 | 15 | from torchkge.evaluation import LinkPredictionEvaluator 16 | from torchkge.models import TransEModel 17 | from torchkge.sampling import BernoulliNegativeSampler 18 | from torchkge.utils import MarginLoss, DataLoader 19 | from torchkge.utils.datasets import load_fb15k 20 | 21 | 22 | def process_batch(engine, batch): 23 | h, t, r = batch[0], batch[1], batch[2] 24 | n_h, n_t = sampler.corrupt_batch(h, t, r) 25 | 26 | optimizer.zero_grad() 27 | 28 | pos, neg = model(h, t, r, n_h, n_t) 29 | loss = criterion(pos, neg) 30 | loss.backward() 31 | optimizer.step() 32 | 33 | return loss.item() 34 | 35 | 36 | def linkprediction_evaluation(engine): 37 | model.normalize_parameters() 38 | 39 | loss = engine.state.output 40 | 41 | # validation MRR measure 42 | if engine.state.epoch % eval_epoch == 0: 43 | evaluator = LinkPredictionEvaluator(model, kg_val) 44 | evaluator.evaluate(b_size=256, verbose=False) 45 | val_mrr = evaluator.mrr()[1] 46 | else: 47 | val_mrr = 0 48 | 49 | print('Epoch {} | Train loss: {}, Validation MRR: {}'.format( 50 | engine.state.epoch, loss, val_mrr)) 51 | 52 | try: 53 | if engine.state.best_mrr < val_mrr: 54 | engine.state.best_mrr = val_mrr 55 | return val_mrr 56 | 57 | except AttributeError as e: 58 | if engine.state.epoch == 1: 59 | engine.state.best_mrr = val_mrr 60 | return val_mrr 61 | else: 62 | raise e 63 | 64 | device = torch.device('cuda') 65 | 66 | eval_epoch = 20 # do link prediction evaluation each 20 epochs 67 | max_epochs = 1000 68 | patience = 40 69 | batch_size = 32768 70 | emb_dim = 100 71 | lr = 0.0004 72 | margin = 0.5 73 | 74 | kg_train, kg_val, kg_test = load_fb15k() 75 | 76 | # Define the model, optimizer and criterion 77 | model = TransEModel(emb_dim, kg_train.n_ent, kg_train.n_rel, 78 | dissimilarity_type='L2') 79 | model.to(device) 80 | 81 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5) 82 | criterion = MarginLoss(margin) 83 | sampler = BernoulliNegativeSampler(kg_train, kg_val=kg_val, kg_test=kg_test) 84 | 85 | # Define the engine 86 | trainer = Engine(process_batch) 87 | 88 | # Define the moving average 89 | RunningAverage(output_transform=lambda x: x).attach(trainer, 'margin') 90 | 91 | # Add early stopping 92 | handler = EarlyStopping(patience=patience, 93 | score_function=linkprediction_evaluation, 94 | trainer=trainer) 95 | trainer.add_event_handler(Events.EPOCH_COMPLETED, handler) 96 | 97 | # Training 98 | train_iterator = DataLoader(kg_train, batch_size, use_cuda='all') 99 | trainer.run(train_iterator, 100 | epoch_length=len(train_iterator), 101 | max_epochs=max_epochs) 102 | 103 | print('Best score {:.3f} at epoch {}'.format(handler.best_score, 104 | trainer.state.epoch - handler.patience)) 105 | 106 | -------------------------------------------------------------------------------- /torchkge/docs/tutorials/transe_wrappers.rst: 
--------------------------------------------------------------------------------
1 | =================
2 | Shortest training
3 | =================
4 | 
5 | TorchKGE also provides simple utility wrappers for model training. Here is an example of how to use them::
6 | 
7 |     from torch.optim import Adam
8 | 
9 |     from torchkge.evaluation import LinkPredictionEvaluator
10 |     from torchkge.models import TransEModel
11 |     from torchkge.utils.datasets import load_fb15k
12 |     from torchkge.utils import Trainer, MarginLoss
13 | 
14 | 
15 |     def main():
16 |         # Define some hyper-parameters for training
17 |         emb_dim = 100
18 |         lr = 0.0004
19 |         margin = 0.5
20 |         n_epochs = 1000
21 |         batch_size = 32768
22 | 
23 |         # Load dataset
24 |         kg_train, kg_val, kg_test = load_fb15k()
25 | 
26 |         # Define the model and criterion
27 |         model = TransEModel(emb_dim, kg_train.n_ent, kg_train.n_rel,
28 |                             dissimilarity_type='L2')
29 |         criterion = MarginLoss(margin)
30 |         optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)
31 | 
32 |         trainer = Trainer(model, criterion, kg_train, n_epochs, batch_size,
33 |                           optimizer=optimizer, sampling_type='bern', use_cuda='all')
34 | 
35 |         trainer.run()
36 | 
37 |         evaluator = LinkPredictionEvaluator(model, kg_test)
38 |         evaluator.evaluate(200)
39 |         evaluator.print_results()
40 | 
41 | 
42 |     if __name__ == "__main__":
43 |         main()
44 | 
45 | 
--------------------------------------------------------------------------------
/torchkge/docs/tutorials/tripletclassification.rst:
--------------------------------------------------------------------------------
1 | ======================
2 | Triplet Classification
3 | ======================
4 | 
5 | To evaluate a model on triplet classification::
6 | 
7 |     from torch import cuda
8 |     from torchkge.evaluation import TripletClassificationEvaluator
9 |     from torchkge.utils.pretrained_models import load_pretrained_transe
10 |     from torchkge.utils.datasets import load_fb15k
11 | 
12 |     _, kg_val, kg_test = load_fb15k()
13 | 
14 |     model = load_pretrained_transe('fb15k', 100)
15 |     if cuda.is_available():
16 |         model.cuda()
17 | 
18 |     # Triplet classification evaluation on test set by learning thresholds on validation set
19 |     evaluator = TripletClassificationEvaluator(model, kg_val, kg_test)
20 |     evaluator.evaluate(b_size=128)
21 | 
22 |     print('Accuracy on test set: {}'.format(evaluator.accuracy(b_size=128)))
23 | 
--------------------------------------------------------------------------------
/torchkge/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | # Documentation
2 | sphinx>=3.1
3 | sphinx_rtd_theme==1.0
4 | numpydoc==1.2.1
5 | 
6 | # Tests
7 | flake8>=3.8.3
8 | tox==3.25.0
9 | pytest-runner>=6.0.0
10 | pytest==7.1.1
11 | 
12 | # Deployment
13 | pip==22.0.4
14 | bumpversion==0.6
15 | wheel==0.37.1
--------------------------------------------------------------------------------
/torchkge/torchkge/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | """Top-level package for TorchKGE."""
4 | 
5 | __author__ = """Armand Boschin"""
6 | __email__ = 'aboschin@enst.fr'
7 | __version__ = '0.17.3'
8 | 
9 | from .exceptions import NotYetEvaluatedError
10 | from .utils import MarginLoss, LogisticLoss
11 | from .utils import l1_dissimilarity, l2_dissimilarity
12 | from .data_structures import KnowledgeGraph
13 | from .evaluation import LinkPredictionEvaluator
14 | from .evaluation import TripletClassificationEvaluator
15 | from .models import 
ConvKBModel 16 | from .models import RESCALModel, DistMultModel 17 | from .models import TransEModel, TransHModel, TransRModel, TransDModel 18 | -------------------------------------------------------------------------------- /torchkge/torchkge/exceptions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | 8 | class NotYetEvaluatedError(Exception): 9 | def __init__(self, message): 10 | super().__init__(message) 11 | 12 | 13 | class SizeMismatchError(Exception): 14 | def __init__(self, message): 15 | super().__init__(message) 16 | 17 | 18 | class WrongDimensionError(Exception): 19 | def __init__(self, message): 20 | super().__init__(message) 21 | 22 | 23 | class NotYetImplementedError(Exception): 24 | def __init__(self, message): 25 | super().__init__(message) 26 | 27 | 28 | class WrongArgumentsError(Exception): 29 | def __init__(self, message): 30 | super().__init__(message) 31 | 32 | 33 | class SanityError(Exception): 34 | def __init__(self, message): 35 | super().__init__(message) 36 | 37 | 38 | class SplitabilityError(Exception): 39 | def __init__(self, message): 40 | super().__init__(message) 41 | 42 | 43 | class NoPreTrainedVersionError(Exception): 44 | def __init__(self, message): 45 | super().__init__(message) 46 | -------------------------------------------------------------------------------- /torchkge/torchkge/inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | from torch import empty, tensor 7 | from tqdm.autonotebook import tqdm 8 | 9 | from .exceptions import WrongArgumentsError 10 | from .utils import filter_scores 11 | from .utils.data import get_n_batches 12 | 13 | 14 | class DataLoader_: 15 | """This class is inspired from :class:`torch.utils.dataloader.DataLoader`. 16 | It is however way simpler. 17 | 18 | """ 19 | def __init__(self, a, b, batch_size, use_cuda=None): 20 | """ 21 | 22 | Parameters 23 | ---------- 24 | batch_size: int 25 | Size of the required batches. 26 | use_cuda: str (opt, default = None) 27 | Can be either None (no use of cuda at all), 'all' to move all the 28 | dataset to cuda and then split in batches or 'batch' to simply move 29 | the batches to cuda before they are returned. 
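        a: torch.Tensor
            First of the two index tensors to batch.
        b: torch.Tensor
            Second index tensor, batched in parallel with `a`.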
30 | """ 31 | self.a = a 32 | self.b = b 33 | 34 | self.use_cuda = use_cuda 35 | self.batch_size = batch_size 36 | 37 | if use_cuda is not None and use_cuda == 'all': 38 | self.a = self.a.cuda() 39 | self.b = self.b.cuda() 40 | 41 | def __len__(self): 42 | return get_n_batches(len(self.a), self.batch_size) 43 | 44 | def __iter__(self): 45 | return _DataLoaderIter(self) 46 | 47 | 48 | class _DataLoaderIter: 49 | def __init__(self, loader): 50 | self.a = loader.a 51 | self.b = loader.b 52 | 53 | self.use_cuda = loader.use_cuda 54 | self.batch_size = loader.batch_size 55 | 56 | self.n_batches = get_n_batches(len(self.a), self.batch_size) 57 | self.current_batch = 0 58 | 59 | def __next__(self): 60 | if self.current_batch == self.n_batches: 61 | raise StopIteration 62 | else: 63 | i = self.current_batch 64 | self.current_batch += 1 65 | 66 | tmp_a = self.a[i * self.batch_size: (i + 1) * self.batch_size] 67 | tmp_b = self.b[i * self.batch_size: (i + 1) * self.batch_size] 68 | 69 | if self.use_cuda is not None and self.use_cuda == 'batch': 70 | return tmp_a.cuda(), tmp_b.cuda() 71 | else: 72 | return tmp_a, tmp_b 73 | 74 | def __iter__(self): 75 | return self 76 | 77 | 78 | class RelationInference(object): 79 | """Use trained embedding model to infer missing relations in triples. 80 | 81 | Parameters 82 | ---------- 83 | model: torchkge.models.interfaces.Model 84 | Embedding model inheriting from the right interface. 85 | entities1: `torch.Tensor`, shape: (n_facts), dtype: `torch.long` 86 | List of the indices of known entities 1. 87 | entities2: `torch.Tensor`, shape: (n_facts), dtype: `torch.long` 88 | List of the indices of known entities 2. 89 | top_k: int 90 | Indicates the number of top predictions to return. 91 | dictionary: dict, optional (default=None) 92 | Dictionary of possible relations. It is used to filter predictions 93 | that are known to be True in the training set in order to return 94 | only new facts. 95 | 96 | Attributes 97 | ---------- 98 | model: torchkge.models.interfaces.Model 99 | Embedding model inheriting from the right interface. 100 | entities1: `torch.Tensor`, shape: (n_facts), dtype: `torch.long` 101 | List of the indices of known entities 1. 102 | entities2: `torch.Tensor`, shape: (n_facts), dtype: `torch.long` 103 | List of the indices of known entities 2. 104 | top_k: int 105 | Indicates the number of top predictions to return. 106 | dictionary: dict, optional (default=None) 107 | Dictionary of possible relations. It is used to filter predictions 108 | that are known to be True in the training set in order to return 109 | only new facts. 110 | predictions: `torch.Tensor`, shape: (n_facts, self.top_k), dtype: `torch.long` 111 | List of the indices of predicted relations for each test fact. 112 | scores: `torch.Tensor`, shape: (n_facts, self.top_k), dtype: `torch.float` 113 | List of the scores of resulting triples for each test fact. 114 | """ 115 | # TODO: add the possibility to infer link orientation as well. 
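    # Example usage (sketch): given a trained `model` and paired index tensors
    # `e1` and `e2`, rank the 5 most plausible relations for each pair:
    #
    #     inf = RelationInference(model, e1, e2, top_k=5)
    #     inf.evaluate(b_size=256)
    #     inf.predictions  # shape (n_facts, 5): indices of predicted relations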
116 | 
117 |     def __init__(self, model, entities1, entities2, top_k=1, dictionary=None):
118 | 
119 |         self.model = model
120 |         self.entities1 = entities1
121 |         self.entities2 = entities2
122 |         self.top_k = top_k
123 |         self.dictionary = dictionary
124 | 
125 |         self.predictions = empty(size=(len(entities1), top_k)).long()
126 |         self.scores = empty(size=(len(entities2), top_k))
127 | 
128 |     def evaluate(self, b_size, verbose=True):
129 |         use_cuda = next(self.model.parameters()).is_cuda
130 | 
131 |         if use_cuda:
132 |             dataloader = DataLoader_(self.entities1, self.entities2, batch_size=b_size, use_cuda='batch')
133 |             self.predictions = self.predictions.cuda()
134 |         else:
135 |             dataloader = DataLoader_(self.entities1, self.entities2, batch_size=b_size)
136 | 
137 |         for i, batch in tqdm(enumerate(dataloader), total=len(dataloader),
138 |                              unit='batch', disable=(not verbose),
139 |                              desc='Inference'):
140 |             ents1, ents2 = batch[0], batch[1]
141 |             h_emb, t_emb, _, candidates = self.model.inference_prepare_candidates(ents1, ents2, tensor([]).long(),
142 |                                                                                   entities=False)
143 |             scores = self.model.inference_scoring_function(h_emb, t_emb, candidates)
144 | 
145 |             if self.dictionary is not None:
146 |                 scores = filter_scores(scores, self.dictionary, ents1, ents2, None)
147 | 
148 |             scores, indices = scores.sort(descending=True)
149 | 
150 |             self.predictions[i * b_size: (i + 1) * b_size] = indices[:, :self.top_k]
151 |             self.scores[i * b_size: (i + 1) * b_size] = scores[:, :self.top_k]
152 | 
153 |         if use_cuda:
154 |             self.predictions = self.predictions.cpu()
155 |             self.scores = self.scores.cpu()
156 | 
157 | 
158 | class EntityInference(object):
159 |     """Use trained embedding model to infer missing entities in triples.
160 | 
161 |     Parameters
162 |     ----------
163 |     model: torchkge.models.interfaces.Model
164 |         Embedding model inheriting from the right interface.
165 |     known_entities: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
166 |         List of the indices of known entities.
167 |     known_relations: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
168 |         List of the indices of known relations.
169 |     top_k: int
170 |         Indicates the number of top predictions to return.
171 |     missing: str
172 |         String indicating if the missing entities are the heads or the tails.
173 |     dictionary: dict, optional (default=None)
174 |         Dictionary of possible heads or tails (depending on the value of `missing`).
175 |         It is used to filter predictions that are known to be True in the training set
176 |         in order to return only new facts.
177 | 
178 |     Attributes
179 |     ----------
180 |     model: torchkge.models.interfaces.Model
181 |         Embedding model inheriting from the right interface.
182 |     known_entities: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
183 |         List of the indices of known entities.
184 |     known_relations: `torch.Tensor`, shape: (n_facts), dtype: `torch.long`
185 |         List of the indices of known relations.
186 |     top_k: int
187 |         Indicates the number of top predictions to return.
188 |     missing: str
189 |         String indicating if the missing entities are the heads or the tails.
190 |     dictionary: dict, optional (default=None)
191 |         Dictionary of possible heads or tails (depending on the value of `missing`).
192 |         It is used to filter predictions that are known to be True in the training set
193 |         in order to return only new facts.
194 |     predictions: `torch.Tensor`, shape: (n_facts, self.top_k), dtype: `torch.long`
195 |         List of the indices of predicted entities for each test fact.
196 |     scores: `torch.Tensor`, shape: (n_facts, self.top_k), dtype: `torch.float`
197 |         List of the scores of resulting triples for each test fact.
198 | 
199 |     """
200 |     def __init__(self, model, known_entities, known_relations, top_k=1, missing='tails', dictionary=None):
201 |         try:
202 |             assert missing in ['heads', 'tails']
203 |         except AssertionError:
204 |             raise WrongArgumentsError("missing entity should either be 'heads' or 'tails'")
205 |         self.model = model
206 |         self.known_entities = known_entities
207 |         self.known_relations = known_relations
208 |         self.missing = missing
209 |         self.top_k = top_k
210 |         self.dictionary = dictionary
211 | 
212 |         self.predictions = empty(size=(len(known_entities), top_k)).long()
213 |         self.scores = empty(size=(len(known_entities), top_k))
214 | 
215 |     def evaluate(self, b_size, verbose=True):
216 |         use_cuda = next(self.model.parameters()).is_cuda
217 | 
218 |         if use_cuda:
219 |             dataloader = DataLoader_(self.known_entities, self.known_relations, batch_size=b_size, use_cuda='batch')
220 |             self.predictions = self.predictions.cuda()
221 |         else:
222 |             dataloader = DataLoader_(self.known_entities, self.known_relations, batch_size=b_size)
223 | 
224 |         for i, batch in tqdm(enumerate(dataloader), total=len(dataloader),
225 |                              unit='batch', disable=(not verbose),
226 |                              desc='Inference'):
227 |             known_ents, known_rels = batch[0], batch[1]
228 |             if self.missing == 'heads':
229 |                 _, t_emb, rel_emb, candidates = self.model.inference_prepare_candidates(tensor([]).long(), known_ents,
230 |                                                                                         known_rels,
231 |                                                                                         entities=True)
232 |                 scores = self.model.inference_scoring_function(candidates, t_emb, rel_emb)
233 |             else:
234 |                 h_emb, _, rel_emb, candidates = self.model.inference_prepare_candidates(known_ents, tensor([]).long(),
235 |                                                                                         known_rels,
236 |                                                                                         entities=True)
237 |                 scores = self.model.inference_scoring_function(h_emb, candidates, rel_emb)
238 | 
239 |             if self.dictionary is not None:
240 |                 scores = filter_scores(scores, self.dictionary, known_ents, known_rels, None)
241 | 
242 |             scores, indices = scores.sort(descending=True)
243 | 
244 |             self.predictions[i * b_size: (i + 1) * b_size] = indices[:, :self.top_k]
245 |             self.scores[i * b_size: (i + 1) * b_size] = scores[:, :self.top_k]
246 | 
247 |         if use_cuda:
248 |             self.predictions = self.predictions.cpu()
249 |             self.scores = self.scores.cpu()
--------------------------------------------------------------------------------
/torchkge/torchkge/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .interfaces import Model, TranslationModel, BilinearModel
2 | 
3 | from .translation import TransEModel
4 | from .translation import TransHModel
5 | from .translation import TransRModel
6 | from .translation import TransDModel
7 | from .translation import TorusEModel
8 | from .translation import PKGMModel
9 | 
10 | from .bilinear import RESCALModel
11 | from .bilinear import DistMultModel
12 | from .bilinear import HolEModel
13 | from .bilinear import ComplExModel
14 | from .bilinear import AnalogyModel
15 | 
16 | from .deep import ConvKBModel
--------------------------------------------------------------------------------
/torchkge/torchkge/models/deep.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Copyright TorchKGE developers
4 | @author: Armand Boschin
5 | """
6 | 
7 | from torch import nn, cat
8 | 
9 | from ..models.interfaces import Model
10 | from ..utils import init_embedding
11 | 
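# Typical construction (sketch; `kg` stands for a loaded KnowledgeGraph and
# n_filters=32 is an arbitrary choice):
#     model = ConvKBModel(emb_dim=100, n_filters=32,
#                         n_entities=kg.n_ent, n_relations=kg.n_rel)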
12 | 
13 | class ConvKBModel(Model):
14 |     """Implementation of ConvKB model detailed in 2018 paper by Nguyen et al.
15 |     This class inherits from the :class:`torchkge.models.interfaces.Model`
16 |     interface. It then has its attributes as well.
17 | 
18 | 
19 |     References
20 |     ----------
21 |     * Nguyen, D. Q., Nguyen, T. D., Nguyen, D. Q., and Phung, D.
22 |       `A Novel Embedding Model for Knowledge Base Completion Based on
23 |       Convolutional Neural Network.
24 |       `_
25 |       In Proceedings of the 2018 Conference of the North American Chapter of
26 |       the Association for Computational Linguistics: Human Language
27 |       Technologies (2018), vol. 2, pp. 327–333.
28 | 
29 |     Parameters
30 |     ----------
31 |     emb_dim: int
32 |         Dimension of embedding space.
33 |     n_filters: int
34 |         Number of filters used for convolution.
35 |     n_entities: int
36 |         Number of entities in the current data set.
37 |     n_relations: int
38 |         Number of relations in the current data set.
39 | 
40 |     Attributes
41 |     ----------
42 |     ent_emb: torch.nn.Embedding, shape: (n_ent, emb_dim)
43 |         Embeddings of the entities, initialized with Xavier uniform
44 |         distribution and then normalized.
45 |     rel_emb: torch.nn.Embedding, shape: (n_rel, emb_dim)
46 |         Embeddings of the relations, initialized with Xavier uniform
47 |         distribution.
48 | 
49 |     """
50 | 
51 |     def __init__(self, emb_dim, n_filters, n_entities, n_relations):
52 |         super().__init__(n_entities, n_relations)
53 |         self.emb_dim = emb_dim
54 | 
55 |         self.ent_emb = init_embedding(self.n_ent, self.emb_dim)
56 |         self.rel_emb = init_embedding(self.n_rel, self.emb_dim)
57 | 
58 |         self.convlayer = nn.Sequential(nn.Conv1d(3, n_filters, 1, stride=1),
59 |                                        nn.ReLU())
60 |         self.output = nn.Sequential(nn.Linear(emb_dim * n_filters, 2),
61 |                                     nn.Softmax(dim=1))
62 | 
63 |     def scoring_function(self, h_idx, t_idx, r_idx):
64 |         """Compute the scoring function for the triplets given as argument,
65 |         by applying convolutions to the concatenation of the embeddings. See
66 |         referenced paper for more details on the score. See
67 |         torchkge.models.interfaces.Models for more details on the API.
68 | 
69 |         """
70 |         b_size = h_idx.shape[0]
71 | 
72 |         h = self.ent_emb(h_idx).view(b_size, 1, -1)
73 |         t = self.ent_emb(t_idx).view(b_size, 1, -1)
74 |         r = self.rel_emb(r_idx).view(b_size, 1, -1)
75 |         concat = cat((h, r, t), dim=1)
76 | 
77 |         return self.output(self.convlayer(concat).reshape(b_size, -1))[:, 1]
78 | 
79 |     def normalize_parameters(self):
80 |         """Normalize the entity embeddings. This method should be called at
81 |         the end of each training epoch and at the end of training as well;
82 |         for ConvKB it is a no-op, kept for API consistency.
83 | 
84 |         """
85 |         pass
86 | 
87 |     def get_embeddings(self):
88 |         """Return the embeddings of entities and relations.
89 | 
90 |         Returns
91 |         -------
92 |         ent_emb: torch.Tensor, shape: (n_ent, emb_dim), dtype: torch.float
93 |             Embeddings of entities.
94 |         rel_emb: torch.Tensor, shape: (n_rel, emb_dim), dtype: torch.float
95 |             Embeddings of relations.
96 | 
97 |         """
98 |         self.normalize_parameters()
99 |         return self.ent_emb.weight.data, self.rel_emb.weight.data
100 | 
101 |     def inference_scoring_function(self, h, t, r):
102 |         """Link prediction evaluation helper function. See
103 |         torchkge.models.interfaces.Models for more details on the API.
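
        Whichever of `h`, `t` or `r` is 4-dimensional holds the candidates: the
        three branches below score candidate tails, candidate heads and
        candidate relations respectively.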
104 | 105 | """ 106 | b_size = h.shape[0] 107 | 108 | if (len(h.shape) == 2) & (len(t.shape) == 4) & (len(r.shape) == 2): 109 | concat = cat((h.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_ent, 1, self.emb_dim), 110 | r.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_ent, 1, self.emb_dim), 111 | t), dim=2) 112 | concat = concat.reshape(-1, 3, self.emb_dim) 113 | 114 | elif (len(h.shape) == 4) & (len(t.shape) == 2) & (len(r.shape) == 2): 115 | concat = cat((h, 116 | r.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_ent, 1, self.emb_dim), 117 | t.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_ent, 1, self.emb_dim)), dim=2) 118 | concat = concat.reshape(-1, 3, self.emb_dim) 119 | 120 | else: 121 | assert (len(h.shape) == 2) & (len(t.shape) == 2) & (len(r.shape) == 4) 122 | concat = cat((h.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_rel, 1, self.emb_dim), 123 | r, 124 | t.view(b_size, 1, 1, self.emb_dim).expand(b_size, self.n_rel, 1, self.emb_dim)), dim=2) 125 | concat = concat.reshape(-1, 3, self.emb_dim) 126 | 127 | scores = self.output(self.convlayer(concat).reshape(concat.shape[0], -1)) 128 | scores = scores.reshape(b_size, -1, 2) 129 | 130 | return scores[:, :, 1] 131 | 132 | def inference_prepare_candidates(self, h_idx, t_idx, r_idx, entities=True): 133 | """Link prediction evaluation helper function. Get entities embeddings 134 | and relations embeddings. The output will be fed to the 135 | `inference_scoring_function` method. See torchkge.models.interfaces.Models for 136 | more details on the API. 137 | 138 | """ 139 | b_size = h_idx.shape[0] 140 | 141 | h = self.ent_emb(h_idx) 142 | t = self.ent_emb(t_idx) 143 | r = self.rel_emb(r_idx) 144 | 145 | if entities: 146 | candidates = self.ent_emb.weight.data.view(1, self.n_ent, self.emb_dim) 147 | candidates = candidates.expand(b_size, self.n_ent, self.emb_dim) 148 | candidates = candidates.view(b_size, self.n_ent, 1, self.emb_dim) 149 | else: 150 | candidates = self.rel_emb.weight.data.view(1, self.n_rel, self.emb_dim) 151 | candidates = candidates.expand(b_size, self.n_rel, self.emb_dim) 152 | candidates = candidates.view(b_size, self.n_rel, 1, self.emb_dim) 153 | 154 | return h, t, r, candidates 155 | -------------------------------------------------------------------------------- /torchkge/torchkge/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import DataLoader, get_data_home, clear_data_home 2 | 3 | from .data_redundancy import count_triplets, duplicates 4 | from .data_redundancy import cartesian_product_relations 5 | 6 | from .datasets import load_fb15k, load_fb13, load_fb15k237, load_ccks 7 | from .datasets import load_wn18, load_wn18rr 8 | from .datasets import load_yago3_10, load_wikidatasets, load_wikidata_vitals 9 | 10 | from .dissimilarities import l1_dissimilarity, l2_dissimilarity 11 | from .dissimilarities import l1_torus_dissimilarity, l2_torus_dissimilarity, \ 12 | el2_torus_dissimilarity 13 | 14 | from .losses import MarginLoss, LogisticLoss, BinaryCrossEntropyLoss 15 | from .modeling import init_embedding, init_linear_projection, get_true_targets, load_embeddings, filter_scores 16 | from .operations import get_rank, get_mask, get_bernoulli_probs 17 | from .training import Trainer, TrainDataLoader 18 | -------------------------------------------------------------------------------- /torchkge/torchkge/utils/data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: 
utf-8 -*-
2 | """
3 | Copyright TorchKGE developers
4 | @author: Armand Boschin
5 | """
6 | 
7 | import shutil
8 | 
9 | from os import environ, makedirs
10 | from os.path import exists, expanduser, join
11 | 
12 | 
13 | def get_data_home(data_home=None):
14 |     """Returns the path to the data directory. The path is created if
15 |     it does not exist.
16 | 
17 |     If data_home is None, the data is downloaded into the home
18 |     directory of the user.
19 | 
20 |     Parameters
21 |     ----------
22 |     data_home: string
23 |         The path to the data set.
24 |     """
25 |     if data_home is None:
26 |         data_home = environ.get('TORCHKGE_DATA',
27 |                                 join('~', 'torchkge_data'))
28 |     data_home = expanduser(data_home)
29 |     if not exists(data_home):
30 |         makedirs(data_home)
31 |     return data_home
32 | 
33 | 
34 | def clear_data_home(data_home=None):
35 |     """Deletes the directory data_home
36 | 
37 |     Parameters
38 |     ----------
39 |     data_home: string
40 |         The path to the directory that should be removed.
41 |     """
42 |     data_home = get_data_home(data_home)
43 |     shutil.rmtree(data_home)
44 | 
45 | 
46 | def get_n_batches(n, b_size):
47 |     """Returns the number of batches. Let n be the number of samples in the
48 |     data set and b_size the number of samples per batch; since a partially
49 |     filled final batch still counts, the result is rounded up:
50 | 
51 |         n_batches = ceil(n / b_size)
52 | 
53 |     Parameters
54 |     ----------
55 |     n: int
56 |         Size of the data set.
57 |     b_size: int
58 |         Number of samples per batch.
59 |     """
60 |     n_batch = n // b_size
61 |     if n % b_size > 0:
62 |         n_batch += 1
63 |     return n_batch
64 | 
65 | 
66 | class DataLoader:
67 |     """This class is inspired from :class:`torch.utils.dataloader.DataLoader`.
68 |     It is however way simpler.
69 | 
70 |     """
71 |     def __init__(self, kg, batch_size, use_cuda=None):
72 |         """
73 | 
74 |         Parameters
75 |         ----------
76 |         kg: torchkge.data_structures.KnowledgeGraph or torchkge.data_structures.SmallKG
77 |             Knowledge graph in the form of an object implemented in
78 |             torchkge.data_structures.
79 |         batch_size: int
80 |             Size of the required batches.
81 |         use_cuda: str (opt, default = None)
82 |             Can be either None (no use of cuda at all), 'all' to move all the
83 |             dataset to cuda and then split in batches or 'batch' to simply move
84 |             the batches to cuda before they are returned.
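
        Iterating over the loader yields `(h, t, r)` index batches; the
        training tutorial, for instance, uses
        `DataLoader(kg_train, batch_size=32768, use_cuda='all')`.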
85 | """ 86 | self.h = kg.head_idx 87 | self.t = kg.tail_idx 88 | self.r = kg.relations 89 | 90 | self.use_cuda = use_cuda 91 | self.batch_size = batch_size 92 | 93 | if use_cuda is not None and use_cuda == 'all': 94 | self.h = self.h.cuda() 95 | self.t = self.t.cuda() 96 | self.r = self.r.cuda() 97 | 98 | def __len__(self): 99 | return get_n_batches(len(self.h), self.batch_size) 100 | 101 | def __iter__(self): 102 | return _DataLoaderIter(self) 103 | 104 | 105 | class _DataLoaderIter: 106 | def __init__(self, loader): 107 | self.h = loader.h 108 | self.t = loader.t 109 | self.r = loader.r 110 | 111 | self.use_cuda = loader.use_cuda 112 | self.batch_size = loader.batch_size 113 | 114 | self.n_batches = get_n_batches(len(self.h), self.batch_size) 115 | self.current_batch = 0 116 | 117 | def __next__(self): 118 | if self.current_batch == self.n_batches: 119 | raise StopIteration 120 | else: 121 | i = self.current_batch 122 | self.current_batch += 1 123 | 124 | tmp_h = self.h[i * self.batch_size: (i + 1) * self.batch_size] 125 | tmp_t = self.t[i * self.batch_size: (i + 1) * self.batch_size] 126 | tmp_r = self.r[i * self.batch_size: (i + 1) * self.batch_size] 127 | 128 | if self.use_cuda is not None and self.use_cuda == 'batch': 129 | return tmp_h.cuda(), tmp_t.cuda(), tmp_r.cuda() 130 | else: 131 | return tmp_h, tmp_t, tmp_r 132 | 133 | def __iter__(self): 134 | return self 135 | -------------------------------------------------------------------------------- /torchkge/torchkge/utils/data_redundancy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | 6 | This module contains functions implementing methods explained in `this 7 | paper`__ by Akrami et al. 8 | """ 9 | from itertools import combinations 10 | from torch import cat 11 | from tqdm.autonotebook import tqdm 12 | 13 | 14 | def concat_kgs(kg_tr, kg_val, kg_te): 15 | h = cat((kg_tr.head_idx, kg_val.head_idx, kg_te.head_idx)) 16 | t = cat((kg_tr.tail_idx, kg_val.tail_idx, kg_te.tail_idx)) 17 | r = cat((kg_tr.relations, kg_val.relations, kg_te.relations)) 18 | return h, t, r 19 | 20 | 21 | def get_pairs(kg, r, type='ht'): 22 | mask = (kg.relations == r) 23 | 24 | if type == 'ht': 25 | return set((i.item(), j.item()) for i, j in cat( 26 | (kg.head_idx[mask].view(-1, 1), 27 | kg.tail_idx[mask].view(-1, 1)), dim=1)) 28 | else: 29 | assert type == 'th' 30 | return set((j.item(), i.item()) for i, j in cat( 31 | (kg.head_idx[mask].view(-1, 1), 32 | kg.tail_idx[mask].view(-1, 1)), dim=1)) 33 | 34 | 35 | def count_triplets(kg1, kg2, duplicates, rev_duplicates): 36 | """ 37 | Parameters 38 | ---------- 39 | kg1: torchkge.data_structures.KnowledgeGraph 40 | kg2: torchkge.data_structures.KnowledgeGraph 41 | duplicates: list 42 | List returned by torchkge.utils.data_redundancy.duplicates. 43 | rev_duplicates: list 44 | List returned by torchkge.utils.data_redundancy.duplicates. 45 | 46 | Returns 47 | ------- 48 | n_duplicates: int 49 | Number of triplets in kg2 that have their duplicate triplet 50 | in kg1 51 | n_rev_duplicates: int 52 | Number of triplets in kg2 that have their reverse duplicate 53 | triplet in kg1. 
54 | """ 55 | n_duplicates = 0 56 | for r1, r2 in duplicates: 57 | ht_tr = get_pairs(kg1, r2, type='ht') 58 | ht_te = get_pairs(kg2, r1, type='ht') 59 | 60 | n_duplicates += len(ht_te.intersection(ht_tr)) 61 | 62 | ht_tr = get_pairs(kg1, r1, type='ht') 63 | ht_te = get_pairs(kg2, r2, type='ht') 64 | 65 | n_duplicates += len(ht_te.intersection(ht_tr)) 66 | 67 | n_rev_duplicates = 0 68 | for r1, r2 in rev_duplicates: 69 | th_tr = get_pairs(kg1, r2, type='th') 70 | ht_te = get_pairs(kg2, r1, type='ht') 71 | 72 | n_rev_duplicates += len(ht_te.intersection(th_tr)) 73 | 74 | th_tr = get_pairs(kg1, r1, type='th') 75 | ht_te = get_pairs(kg2, r2, type='ht') 76 | 77 | n_rev_duplicates += len(ht_te.intersection(th_tr)) 78 | 79 | return n_duplicates, n_rev_duplicates 80 | 81 | 82 | def duplicates(kg_tr, kg_val, kg_te, theta1=0.8, theta2=0.8, 83 | verbose=False, counts=False, reverses=None): 84 | """Return the duplicate and reverse duplicate relations as explained 85 | in the paper by Akrami et al. 86 | 87 | References 88 | ---------- 89 | * Farahnaz Akrami, Mohammed Samiul Saeef, Qingheng Zhang. 90 | `Realistic Re-evaluation of Knowledge Graph Completion Methods: 91 | An Experimental Study. <https://arxiv.org/abs/2003.08001>`_ 92 | SIGMOD’20, June 14–19, 2020, Portland, OR, USA 93 | 94 | Parameters 95 | ---------- 96 | kg_tr: torchkge.data_structures.KnowledgeGraph 97 | Train set 98 | kg_val: torchkge.data_structures.KnowledgeGraph 99 | Validation set 100 | kg_te: torchkge.data_structures.KnowledgeGraph 101 | Test set 102 | theta1: float 103 | First threshold (see paper). 104 | theta2: float 105 | Second threshold (see paper). 106 | verbose: bool 107 | counts: bool 108 | Whether the triplets involving (reverse) duplicate relations should 109 | be counted in all sets. 110 | reverses: list 111 | List of known reverse relations. 112 | 113 | Returns 114 | ------- 115 | duplicates: list 116 | List of pairs giving duplicate relations. 117 | rev_duplicates: list 118 | List of pairs giving reverse duplicate relations.
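Examples
--------
A usage sketch (assumes the three KnowledgeGraph splits; the thresholds
are the paper's defaults; illustrative only):

>>> dupl, rev_dupl = duplicates(kg_tr, kg_val, kg_te,
...                             theta1=0.8, theta2=0.8)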
119 | """ 120 | if verbose: 121 | print('Computing Ts') 122 | 123 | if reverses is None: 124 | reverses = [] 125 | 126 | T = dict() 127 | T_inv = dict() 128 | lengths = dict() 129 | 130 | h, t, r = concat_kgs(kg_tr, kg_val, kg_te) 131 | 132 | for r_ in tqdm(range(kg_tr.n_rel)): 133 | mask = (r == r_) 134 | lengths[r_] = mask.sum().item() 135 | 136 | pairs = cat((h[mask].view(-1, 1), t[mask].view(-1, 1)), dim=1) 137 | 138 | T[r_] = set([(h_.item(), t_.item()) for h_, t_ in pairs]) 139 | T_inv[r_] = set([(t_.item(), h_.item()) for h_, t_ in pairs]) 140 | 141 | if verbose: 142 | print('Finding duplicate relations') 143 | 144 | duplicates = [] 145 | rev_duplicates = [] 146 | 147 | iter_ = list(combinations(range(kg_tr.n_rel), 2))  # all pairs of relations 148 | 149 | for r1, r2 in tqdm(iter_): 150 | a = len(T[r1].intersection(T[r2])) / lengths[r1] 151 | b = len(T[r1].intersection(T[r2])) / lengths[r2] 152 | 153 | if a > theta1 and b > theta2: 154 | duplicates.append((r1, r2)) 155 | 156 | if (r1, r2) not in reverses: 157 | a = len(T[r1].intersection(T_inv[r2])) / lengths[r1] 158 | b = len(T[r1].intersection(T_inv[r2])) / lengths[r2] 159 | 160 | if a > theta1 and b > theta2: 161 | rev_duplicates.append((r1, r2)) 162 | 163 | if verbose: 164 | print('Duplicate relations: {}'.format(len(duplicates))) 165 | print('Reverse duplicate relations: ' 166 | '{}\n'.format(len(rev_duplicates))) 167 | 168 | if counts: 169 | dupl, rev = count_triplets(kg_tr, kg_tr, duplicates, rev_duplicates) 170 | print('{} train triplets have duplicate in train set ' 171 | '({}%)'.format(dupl, int(dupl / len(kg_tr) * 100))) 172 | print('{} train triplets have reverse duplicate in train set ' 173 | '({}%)\n'.format(rev, int(rev / len(kg_tr) * 100))) 174 | 175 | dupl, rev = count_triplets(kg_tr, kg_te, duplicates, rev_duplicates) 176 | print('{} test triplets have duplicate in train set ' 177 | '({}%)'.format(dupl, int(dupl / len(kg_te) * 100))) 178 | print('{} test triplets have reverse duplicate in train set ' 179 | '({}%)\n'.format(rev, int(rev / len(kg_te) * 100))) 180 | 181 | dupl, rev = count_triplets(kg_te, kg_te, duplicates, rev_duplicates) 182 | print('{} test triplets have duplicate in test set ' 183 | '({}%)'.format(dupl, int(dupl / len(kg_te) * 100))) 184 | print('{} test triplets have reverse duplicate in test set ' 185 | '({}%)\n'.format(rev, int(rev / len(kg_te) * 100))) 186 | 187 | return duplicates, rev_duplicates 188 | 189 | 190 | def cartesian_product_relations(kg_tr, kg_val, kg_te, theta=0.8): 191 | """Return the cartesian product relations as explained in the paper by 192 | Akrami et al. 193 | 194 | References 195 | ---------- 196 | * Farahnaz Akrami, Mohammed Samiul Saeef, Qingheng Zhang. 197 | `Realistic Re-evaluation of Knowledge Graph Completion Methods: An 198 | Experimental Study. <https://arxiv.org/abs/2003.08001>`_ 199 | SIGMOD’20, June 14–19, 2020, Portland, OR, USA 200 | 201 | Parameters 202 | ---------- 203 | kg_tr: torchkge.data_structures.KnowledgeGraph 204 | Train set 205 | kg_val: torchkge.data_structures.KnowledgeGraph 206 | Validation set 207 | kg_te: torchkge.data_structures.KnowledgeGraph 208 | Test set 209 | theta: float 210 | Threshold used to compute the cartesian product relations. 211 | 212 | Returns 213 | ------- 214 | selected_relations: list 215 | List of indices of relations that are cartesian product relations 216 | (see paper for details).
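Examples
--------
A usage sketch (assumes the three KnowledgeGraph splits; illustrative
only):

>>> cart_rels = cartesian_product_relations(kg_tr, kg_val, kg_te,
...                                         theta=0.8)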
217 | 218 | """ 219 | selected_relations = [] 220 | 221 | h, t, r = concat_kgs(kg_tr, kg_val, kg_te) 222 | 223 | S = dict() 224 | O = dict() 225 | lengths = dict() 226 | 227 | for r_ in tqdm(range(kg_tr.n_rel)): 228 | mask = (r == r_) 229 | lengths[r_] = mask.sum().item() 230 | 231 | S[r_] = set(h_.item() for h_ in h[mask]) 232 | O[r_] = set(t_.item() for t_ in t[mask]) 233 | 234 | if lengths[r_] / (len(S[r_]) * len(O[r_])) > theta: 235 | selected_relations.append(r_) 236 | 237 | return selected_relations 238 | -------------------------------------------------------------------------------- /torchkge/torchkge/utils/dissimilarities.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | from math import pi 8 | from torch import abs, cos, min 9 | 10 | 11 | def l1_dissimilarity(a, b): 12 | """Compute dissimilarity between rows of `a` and `b` as :math:`||a-b||_1`. 13 | 14 | """ 15 | assert len(a.shape) == len(b.shape) 16 | return (a-b).norm(p=1, dim=-1) 17 | 18 | 19 | def l2_dissimilarity(a, b): 20 | """Compute dissimilarity between rows of `a` and `b` as 21 | :math:`||a-b||_2^2`. 22 | 23 | """ 24 | assert len(a.shape) == len(b.shape) 25 | return (a-b).norm(p=2, dim=-1)**2 26 | 27 | 28 | def l1_torus_dissimilarity(a, b): 29 | """See the `paper by Ebisu et al. <https://arxiv.org/abs/1711.05435>`_ 30 | for details about the definition of this dissimilarity function. 31 | 32 | """ 33 | assert len(a.shape) == len(b.shape) 34 | return 2 * min(abs(a - b), 1 - abs(a - b)).sum(dim=-1) 35 | 36 | 37 | def l2_torus_dissimilarity(a, b): 38 | """See the `paper by Ebisu et al. <https://arxiv.org/abs/1711.05435>`_ 39 | for details about the definition of this dissimilarity function. 40 | 41 | """ 42 | assert len(a.shape) == len(b.shape) 43 | return 4 * min((a - b) ** 2, 1 - (a - b) ** 2).sum(dim=-1) 44 | 45 | 46 | def el2_torus_dissimilarity(a, b): 47 | """See the `paper by Ebisu et al. <https://arxiv.org/abs/1711.05435>`_ 48 | for details about the definition of this dissimilarity function. 49 | 50 | """ 51 | assert len(a.shape) == len(b.shape) 52 | tmp = min(a - b, 1 - (a - b)) 53 | tmp = 2 * (1 - cos(2 * pi * tmp)) 54 | return tmp.sum(dim=-1) / 4 55 | -------------------------------------------------------------------------------- /torchkge/torchkge/utils/losses.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | from torch import ones_like, zeros_like 8 | from torch.nn import Module, Sigmoid 9 | from torch.nn import MarginRankingLoss, SoftMarginLoss, BCELoss 10 | 11 | 12 | class MarginLoss(Module): 13 | """Margin loss as it was defined in the `TransE paper 14 | <https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data>`_ 15 | by Bordes et al. in 2013. This class implements the :class:`torch.nn.Module` 16 | interface. 17 | 18 | """ 19 | def __init__(self, margin): 20 | super().__init__() 21 | self.loss = MarginRankingLoss(margin=margin, reduction='sum') 22 | 23 | def forward(self, positive_triplets, negative_triplets): 24 | """ 25 | Parameters 26 | ---------- 27 | positive_triplets: torch.Tensor, dtype: torch.float, shape: (b_size) 28 | Scores of the true triplets as returned by the `forward` methods of 29 | the models. 30 | negative_triplets: torch.Tensor, dtype: torch.float, shape: (b_size) 31 | Scores of the negative triplets as returned by the `forward` 32 | methods of the models.
33 | 34 | Returns 35 | ------- 36 | loss: torch.Tensor, shape: (), dtype: torch.float 37 | Scalar loss (sum over the batch) of the form 38 | :math:`\\max\\{0, \\gamma - f(h,r,t) + f(h',r',t')\\}` where 39 | :math:`\\gamma` is the margin (defined at initialization), 40 | :math:`f(h,r,t)` is the score of a true fact and 41 | :math:`f(h',r',t')` is the score of the associated negative fact. 42 | """ 43 | return self.loss(positive_triplets, negative_triplets, 44 | target=ones_like(positive_triplets)) 45 | 46 | 47 | class LogisticLoss(Module): 48 | """Logistic loss as it was defined in the `TransE paper 49 | <https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data>`_ 50 | by Bordes et al. in 2013. This class implements the :class:`torch.nn.Module` 51 | interface. 52 | 53 | """ 54 | def __init__(self): 55 | super().__init__() 56 | self.loss = SoftMarginLoss(reduction='sum') 57 | 58 | def forward(self, positive_triplets, negative_triplets): 59 | """ 60 | Parameters 61 | ---------- 62 | positive_triplets: torch.Tensor, dtype: torch.float, shape: (b_size) 63 | Scores of the true triplets as returned by the `forward` methods 64 | of the models. 65 | negative_triplets: torch.Tensor, dtype: torch.float, shape: (b_size) 66 | Scores of the negative triplets as returned by the `forward` 67 | methods of the models. 68 | Returns 69 | ------- 70 | loss: torch.Tensor, shape: (), dtype: torch.float 71 | Scalar loss (sum over the batch) of the form :math:`\\log(1 + \\exp(-\\eta \\times f(h,r,t)))` 72 | where :math:`f(h,r,t)` is the score of the fact and :math:`\\eta` 73 | is 1 if the fact is true and -1 if it is false. 74 | """ 75 | targets = ones_like(positive_triplets) 76 | return self.loss(positive_triplets, targets) + \ 77 | self.loss(negative_triplets, -targets) 78 | 79 | 80 | class BinaryCrossEntropyLoss(Module): 81 | """Binary cross-entropy loss applied to the sigmoid of the scores. 82 | This class implements the :class:`torch.nn.Module` interface. 83 | """ 84 | 85 | def __init__(self): 86 | super().__init__() 87 | self.sig = Sigmoid() 88 | self.loss = BCELoss(reduction='sum') 89 | 90 | def forward(self, positive_triplets, negative_triplets): 91 | """ 92 | 93 | Parameters 94 | ---------- 95 | positive_triplets: torch.Tensor, dtype: torch.float, shape: (b_size) 96 | Scores of the true triplets as returned by the `forward` methods 97 | of the models. 98 | negative_triplets: torch.Tensor, dtype: torch.float, shape: (b_size) 99 | Scores of the negative triplets as returned by the `forward` 100 | methods of the models. 101 | Returns 102 | ------- 103 | loss: torch.Tensor, shape: (), dtype: torch.float 104 | Scalar loss (sum over the batch) of the form :math:`-\\eta \\cdot \\log(\\sigma(f(h,r,t))) 105 | - (1-\\eta) \\cdot \\log(1 - \\sigma(f(h,r,t)))` where :math:`f(h,r,t)` 106 | is the score of the fact, :math:`\\sigma` is the sigmoid and 107 | :math:`\\eta` is 1 if the fact is true and 0 if it is false.
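Examples
--------
A minimal sketch (the random tensors stand in for model scores;
illustrative only):

>>> from torch import randn
>>> criterion = BinaryCrossEntropyLoss()
>>> pos_scores = randn(128, requires_grad=True)
>>> neg_scores = randn(128, requires_grad=True)
>>> loss = criterion(pos_scores, neg_scores)
>>> loss.backward()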
108 | """ 109 | return self.loss(self.sig(positive_triplets), 110 | ones_like(positive_triplets)) + \ 111 | self.loss(self.sig(negative_triplets), 112 | zeros_like(negative_triplets)) 113 | -------------------------------------------------------------------------------- /torchkge/torchkge/utils/modeling.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | from torch import tensor 8 | from torch.nn import Embedding, Linear 9 | from torch.nn.init import xavier_uniform_ 10 | 11 | import pickle 12 | import tarfile 13 | 14 | from .data import get_data_home 15 | 16 | from os import makedirs, remove 17 | from os.path import exists 18 | from urllib.request import urlretrieve 19 | 20 | 21 | def init_embedding(n_vectors, dim): 22 | """Create a torch.nn.Embedding object holding `n_vectors` vectors of 23 | dimension `dim`, initialized with the Xavier uniform distribution. 24 | """ 25 | entity_embeddings = Embedding(n_vectors, dim) 26 | xavier_uniform_(entity_embeddings.weight.data) 27 | 28 | return entity_embeddings 29 | 30 | 31 | def init_linear_projection(n_vectors, dim, bias=False): 32 | """Create a torch.nn.Linear object with `n_vectors` input features and 33 | `dim` output features, initialized with the Xavier uniform distribution. 34 | """ 35 | linear_projection = Linear(n_vectors, dim, bias=bias) 36 | xavier_uniform_(linear_projection.weight.data) 37 | 38 | return linear_projection 39 | 40 | 41 | def load_embeddings(model, dim, dataset, data_home=None): 42 | # download the archive on first use, then load the pickled state dict 43 | if data_home is None: 44 | data_home = get_data_home() 45 | data_path = data_home + '/models/' 46 | targz_file = data_path + '{}_{}_{}.tar.gz'.format(model, dataset, dim) 47 | pkl_file = data_path + '{}_{}_{}.pkl'.format(model, dataset, dim) 48 | if not exists(pkl_file): 49 | if not exists(data_path): 50 | makedirs(data_path, exist_ok=True) 51 | urlretrieve("https://graphs.telecom-paris.fr/data/torchkge/models/{}_{}_{}.tar.gz".format(model, dataset, dim), 52 | targz_file) 53 | with tarfile.open(targz_file, 'r') as tf: 54 | # safely extract the archive, guarding against tar path traversal 55 | import os 56 | 57 | def is_within_directory(directory, target): 58 | 59 | abs_directory = os.path.abspath(directory) 60 | abs_target = os.path.abspath(target) 61 | 62 | prefix = os.path.commonprefix([abs_directory, abs_target]) 63 | 64 | return prefix == abs_directory 65 | 66 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False): 67 | 68 | for member in tar.getmembers(): 69 | member_path = os.path.join(path, member.name) 70 | if not is_within_directory(path, member_path): 71 | raise Exception("Attempted Path Traversal in Tar File") 72 | 73 | tar.extractall(path, members, numeric_owner=numeric_owner) 74 | 75 | 76 | safe_extract(tf, data_path) 77 | remove(targz_file) 78 | 79 | with open(pkl_file, 'rb') as f: 80 | state_dict = pickle.load(f) 81 | 82 | return state_dict 83 | 84 | 85 | def get_true_targets(dictionary, key1, key2, true_idx, i): 86 | """For a current index `i` of the batch, returns a tensor containing the 87 | indices of entities for which the triplet is an existing one (i.e. a true 88 | one under CWA). 89 | 90 | Parameters 91 | ---------- 92 | dictionary: defaultdict 93 | Dictionary with keys (int, int) and values (sets of ints) giving all 94 | possible entities for the (entity, relation) pair.
95 | key1: torch.Tensor, shape: (batch_size), dtype: torch.long 96 | key2: torch.Tensor, shape: (batch_size), dtype: torch.long 97 | true_idx: torch.Tensor, shape: (batch_size), dtype: torch.long 98 | Tensor containing the true entity for each sample. 99 | i: int 100 | Indicates which index of the batch is currently being treated. 101 | 102 | Returns 103 | ------- 104 | true_targets: torch.Tensor, shape: (n_true_targets), dtype: torch.long 105 | Tensor containing the indices of entities such that 106 | (key1[i], key2[i], true_target) is a true fact. 107 | 108 | """ 109 | try: 110 | true_targets = dictionary[key1[i].item(), key2[i].item()].copy() 111 | if true_idx is not None: 112 | true_targets.remove(true_idx[i].item()) 113 | if len(true_targets) > 0: 114 | return tensor(list(true_targets)).long() 115 | else: 116 | return None 117 | else: 118 | return tensor(list(true_targets)).long() 119 | except KeyError: 120 | return None 121 | 122 | 123 | def filter_scores(scores, dictionary, key1, key2, true_idx): 124 | # mask the known true triplets (other than the target) with -inf scores 125 | b_size = scores.shape[0] 126 | filt_scores = scores.clone() 127 | 128 | for i in range(b_size): 129 | true_targets = get_true_targets(dictionary, key1, key2, true_idx, i) 130 | if true_targets is None: 131 | continue 132 | filt_scores[i][true_targets] = - float('Inf') 133 | 134 | return filt_scores 135 | -------------------------------------------------------------------------------- /torchkge/torchkge/utils/operations.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | from collections import defaultdict 8 | from pandas import DataFrame 9 | from torch import zeros, cat 10 | 11 | 12 | def get_mask(length, start, end): 13 | """Create a mask of length `length` filled with 0s except between indices 14 | `start` (included) and `end` (excluded), where it is filled with 1s. 15 | 16 | Parameters 17 | ---------- 18 | length: int 19 | Length of the mask to be created. 20 | start: int 21 | First index (included) where the mask will be filled with 1s. 22 | end: int 23 | Last index (excluded) where the mask will be filled with 1s. 24 | 25 | Returns 26 | ------- 27 | mask: `torch.Tensor`, shape: (length), dtype: `torch.bool` 28 | Mask of length `length` filled with 0s except between indices `start` 29 | (included) and `end` (excluded). 30 | """ 31 | mask = zeros(length) 32 | mask[start:end] = 1 33 | return mask.bool() 34 | 35 | 36 | def get_rank(data, true, low_values=False): 37 | """Computes the rank of the entity at index true[i]. If the rank is k, 38 | then k-1 other entities have a better (or equal) value in data. 39 | 40 | Parameters 41 | ---------- 42 | data: `torch.Tensor`, dtype: `torch.float`, shape: (n_facts, dimensions) 43 | Scores for each entity. 44 | true: `torch.Tensor`, dtype: `torch.int`, shape: (n_facts) 45 | true[i] is the index of the true entity for test i of the batch. 46 | low_values: bool, optional (default=False) 47 | if True, the best rank is the lowest score; otherwise it is the highest.
48 | 49 | Returns 50 | ------- 51 | ranks: `torch.Tensor`, dtype: `torch.int`, shape: (n_facts) 52 | ranks[i] - 1 is the number of entities which have better (or equal) 53 | scores in data than the one at index true[i]. 54 | """ 55 | true_data = data.gather(1, true.long().view(-1, 1)) 56 | 57 | if low_values: 58 | return (data <= true_data).sum(dim=1) 59 | else: 60 | return (data >= true_data).sum(dim=1) 61 | 62 | 63 | def get_dictionaries(df, ent=True): 64 | """Build entities or relations dictionaries. 65 | 66 | Parameters 67 | ---------- 68 | df: `pandas.DataFrame` 69 | Data frame containing three columns [from, to, rel]. 70 | ent: bool 71 | if True then ent2ix is returned, if False then rel2ix is returned. 72 | 73 | Returns 74 | ------- 75 | dict: dictionary 76 | Either ent2ix or rel2ix. 77 | 78 | """ 79 | if ent: 80 | tmp = list(set(df['from'].unique()).union(set(df['to'].unique()))) 81 | return {ent: i for i, ent in enumerate(sorted(tmp))} 82 | else: 83 | tmp = list(df['rel'].unique()) 84 | return {rel: i for i, rel in enumerate(sorted(tmp))} 85 | 86 | 87 | def get_tph(t): 88 | """Get the average number of tails per head for each relation. 89 | 90 | Parameters 91 | ---------- 92 | t: `torch.Tensor`, dtype: `torch.long`, shape: (b_size, 3) 93 | First column contains head indices, second tails and third relations. 94 | Returns 95 | ------- 96 | d: dict 97 | keys: relation indices, values: average number of tails per head. 98 | """ 99 | df = DataFrame(t.numpy(), columns=['from', 'to', 'rel']) 100 | df = df.groupby(['from', 'rel']).count().groupby('rel').mean() 101 | df.reset_index(inplace=True) 102 | return {df.loc[i].values[0]: df.loc[i].values[1] for i in df.index} 103 | 104 | 105 | def get_hpt(t): 106 | """Get the average number of heads per tail for each relation. 107 | 108 | Parameters 109 | ---------- 110 | t: `torch.Tensor`, dtype: `torch.long`, shape: (b_size, 3) 111 | First column contains head indices, second tails and third relations. 112 | Returns 113 | ------- 114 | d: dict 115 | keys: relation indices, values: average number of heads per tail. 116 | """ 117 | df = DataFrame(t.numpy(), columns=['from', 'to', 'rel']) 118 | df = df.groupby(['rel', 'to']).count().groupby('rel').mean() 119 | df.reset_index(inplace=True) 120 | return {df.loc[i].values[0]: df.loc[i].values[1] for i in df.index} 121 | 122 | 123 | def get_bernoulli_probs(kg): 124 | """Evaluate the Bernoulli probabilities for negative sampling as in the 125 | TransH original paper by Wang et al. (2014). 126 | 127 | Parameters 128 | ---------- 129 | kg: `torchkge.data_structures.KnowledgeGraph` 130 | 131 | Returns 132 | ------- 133 | tph: dict 134 | keys: relations, values: sampling probabilities as described by 135 | Wang et al. in their paper.
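Examples
--------
A usage sketch (assumes `kg` is a KnowledgeGraph; illustrative only):

>>> probs = get_bernoulli_probs(kg)
>>> # probs[r] = tph / (tph + hpt) for relation r, i.e. the probability
>>> # of corrupting the head rather than the tail, as in the TransH paper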
136 | 137 | """ 138 | t = cat((kg.head_idx.view(-1, 1), 139 | kg.tail_idx.view(-1, 1), 140 | kg.relations.view(-1, 1)), dim=1) 141 | 142 | hpt = get_hpt(t) 143 | tph = get_tph(t) 144 | 145 | assert hpt.keys() == tph.keys() 146 | 147 | for k in tph.keys(): 148 | tph[k] = tph[k] / (tph[k] + hpt[k]) 149 | 150 | return tph 151 | 152 | 153 | def get_fitlering_dictionaries(kg, kg_te=None): 154 | dict_of_heads = defaultdict(set) 155 | dict_of_tails = defaultdict(set) 156 | dict_of_rels = defaultdict(set) 157 | for i in range(kg.n_facts): 158 | dict_of_heads[(kg.tail_idx[i].item(), kg.relations[i].item())].add(kg.head_idx[i].item()) 159 | dict_of_tails[(kg.head_idx[i].item(), kg.relations[i].item())].add(kg.tail_idx[i].item()) 160 | dict_of_rels[(kg.head_idx[i].item(), kg.tail_idx[i].item())].add(kg.relations[i].item()) 161 | if kg_te is not None: 162 | for i in range(kg_te.n_facts): 163 | dict_of_heads[(kg_te.tail_idx[i].item(), kg_te.relations[i].item())].add(kg_te.head_idx[i].item()) 164 | dict_of_tails[(kg_te.head_idx[i].item(), kg_te.relations[i].item())].add(kg_te.tail_idx[i].item()) 165 | dict_of_rels[(kg_te.head_idx[i].item(), kg_te.tail_idx[i].item())].add(kg_te.relations[i].item()) 166 | return dict_of_heads, dict_of_tails, dict_of_rels -------------------------------------------------------------------------------- /torchkge/torchkge/utils/pretrained_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | 7 | from ..exceptions import NoPreTrainedVersionError 8 | from ..models import TransEModel, ComplExModel 9 | from ..utils import load_embeddings 10 | 11 | 12 | def load_pretrained_transe(dataset, emb_dim, data_home=None): 13 | """Load a pretrained version of TransE model. 14 | 15 | Parameters 16 | ---------- 17 | dataset: str 18 | emb_dim: int 19 | Embedding dimension 20 | data_home: str (opt, default None) 21 | Path to the `torchkge_data` directory (containing data folders). Useful 22 | for pre-trained model loading. 23 | 24 | Returns 25 | ------- 26 | model: `torchkge.models.translation.TransEModel` 27 | Pretrained version of TransE model. 28 | """ 29 | try: 30 | assert (dataset in {'fb15k', 'wn18rr'} and emb_dim == 100) \ 31 | or (dataset == 'fb15k237' and emb_dim == 150) \ 32 | or (dataset == 'wdv5' and emb_dim == 150) \ 33 | or (dataset == 'yago310' and emb_dim == 200) 34 | 35 | except AssertionError: 36 | raise NoPreTrainedVersionError('No pre-trained version of TransE for ' 37 | '{} in dimension {}'.format(dataset, 38 | emb_dim)) 39 | 40 | state_dict = load_embeddings('transe', emb_dim, dataset, data_home) 41 | model = TransEModel(emb_dim, 42 | n_entities=state_dict['ent_emb.weight'].shape[0], 43 | n_relations=state_dict['rel_emb.weight'].shape[0], 44 | dissimilarity_type='L2') 45 | model.load_state_dict(state_dict) 46 | 47 | return model 48 | 49 | 50 | def load_pretrained_complex(dataset, emb_dim, data_home=None): 51 | """Load a pretrained version of ComplEx model. 52 | 53 | Parameters 54 | ---------- 55 | dataset: str 56 | emb_dim: int 57 | Embedding dimension 58 | data_home: str (opt, default None) 59 | Path to the `torchkge_data` directory (containing data folders). Useful 60 | for pre-trained model loading. 61 | 62 | Returns 63 | ------- 64 | model: `torchkge.models.bilinear.ComplExModel` 65 | Pretrained version of ComplEx model.
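Examples
--------
A usage sketch (downloads and caches the weights on first call; the
dataset/dimension pair must be one of the published combinations checked
below; illustrative only):

>>> model = load_pretrained_complex('wn18rr', emb_dim=200)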
66 | """ 67 | try: 68 | assert (dataset == 'wn18rr' and emb_dim == 200) \ 69 | or (dataset == 'fb15k237' and emb_dim == 200) \ 70 | or (dataset == 'wdv5' and emb_dim == 200) 71 | 72 | except AssertionError: 73 | raise NoPreTrainedVersionError('No pre-trained version of ComplEx for ' 74 | '{} in dimension {}'.format(dataset, 75 | emb_dim)) 76 | 77 | state_dict = load_embeddings('complex', emb_dim, dataset, data_home) 78 | model = ComplExModel(emb_dim, 79 | n_entities=state_dict['re_ent_emb.weight'].shape[0], 80 | n_relations=state_dict['re_rel_emb.weight'].shape[0]) 81 | model.load_state_dict(state_dict) 82 | 83 | return model 84 | -------------------------------------------------------------------------------- /torchkge/torchkge/utils/training.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Copyright TorchKGE developers 4 | @author: Armand Boschin 5 | """ 6 | import torch 7 | import logging 8 | from tqdm.autonotebook import tqdm 9 | from ..sampling import BernoulliNegativeSampler, UniformNegativeSampler 10 | from ..utils.data import get_n_batches 11 | 12 | logging.basicConfig( 13 | format="%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)s] %(message)s", 14 | datefmt="%Y/%m/%d %H:%M:%S", 15 | level=logging.INFO 16 | ) 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class TrainDataLoader: 21 | """Dataloader providing the training process with batches of true and 22 | negatively sampled facts. 23 | 24 | Parameters 25 | ---------- 26 | kg: torchkge.data_structures.KnowledgeGraph 27 | Dataset to be divided in batches. 28 | batch_size: int 29 | Size of the batches. 30 | sampling_type: str 31 | Either 'unif' (uniform negative sampling) or 'bern' (Bernoulli negative 32 | sampling). 33 | use_cuda: str (opt, default = None) 34 | Can be either None (no use of cuda at all), 'all' to move the whole 35 | dataset to cuda and then split it in batches, or 'batch' to simply 36 | move the batches to cuda before they are returned.
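n_neg: int
    Number of negative samples created for each true fact.

Examples
--------
A usage sketch (assumes kg_train is a KnowledgeGraph; illustrative
only):

>>> loader = TrainDataLoader(kg_train, batch_size=1024,
...                          sampling_type='bern', n_neg=1)
>>> for batch in loader:
...     h, t, r = batch['h'], batch['t'], batch['r']
...     nh, nt = batch['nh'], batch['nt']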
37 | 38 | """ 39 | 40 | def __init__(self, kg, batch_size, sampling_type, n_neg, use_cuda=None): 41 | self.h = kg.head_idx 42 | self.t = kg.tail_idx 43 | self.r = kg.relations 44 | 45 | self.use_cuda = use_cuda 46 | self.b_size = batch_size 47 | 48 | if sampling_type == 'unif': 49 | self.sampler = UniformNegativeSampler(kg, n_neg=n_neg) 50 | elif sampling_type == 'bern': 51 | self.sampler = BernoulliNegativeSampler(kg, n_neg=n_neg) 52 | 53 | self.tmp_cuda = use_cuda in ['batch', 'all'] 54 | 55 | if use_cuda == 'all': 56 | self.h = self.h.cuda() 57 | self.t = self.t.cuda() 58 | self.r = self.r.cuda() 59 | 60 | def __len__(self): 61 | return get_n_batches(len(self.h), self.b_size) 62 | 63 | def __iter__(self): 64 | return TrainDataLoaderIter(self) 65 | 66 | 67 | class TrainDataLoaderIter: 68 | def __init__(self, loader): 69 | self.h = loader.h 70 | self.t = loader.t 71 | self.r = loader.r 72 | 73 | self.nh, self.nt = loader.sampler.corrupt_kg(loader.b_size, 74 | loader.tmp_cuda) 75 | if loader.use_cuda: 76 | self.nh = self.nh.cuda() 77 | self.nt = self.nt.cuda() 78 | 79 | self.use_cuda = loader.use_cuda 80 | self.b_size = loader.b_size 81 | 82 | self.n_batches = get_n_batches(len(self.h), self.b_size) 83 | self.current_batch = 0 84 | 85 | def __next__(self): 86 | if self.current_batch == self.n_batches: 87 | raise StopIteration 88 | else: 89 | i = self.current_batch 90 | self.current_batch += 1 91 | 92 | batch = dict() 93 | batch['h'] = self.h[i * self.b_size: (i + 1) * self.b_size] 94 | batch['t'] = self.t[i * self.b_size: (i + 1) * self.b_size] 95 | batch['r'] = self.r[i * self.b_size: (i + 1) * self.b_size] 96 | batch['nh'] = self.nh[i * self.b_size: (i + 1) * self.b_size] 97 | batch['nt'] = self.nt[i * self.b_size: (i + 1) * self.b_size] 98 | 99 | if self.use_cuda == 'batch': 100 | batch['h'] = batch['h'].cuda() 101 | batch['t'] = batch['t'].cuda() 102 | batch['r'] = batch['r'].cuda() 103 | batch['nh'] = batch['nh'].cuda() 104 | batch['nt'] = batch['nt'].cuda() 105 | 106 | return batch 107 | 108 | def __iter__(self): 109 | return self 110 | 111 | 112 | class Trainer: 113 | """This class wraps a simple training procedure. 114 | 115 | Parameters 116 | ---------- 117 | model: torchkge.models.interfaces.Model 118 | Model to be trained. 119 | criterion: 120 | Criterion used to distinguish positive from negative scores. Can be 121 | one of the losses from torchkge.utils.losses. 122 | kg_train: torchkge.data_structures.KnowledgeGraph 123 | KG used for training. 124 | n_epochs: int 125 | Number of epochs in the training procedure. 126 | batch_size: int 127 | Size of the batches. 128 | sampling_type: str 129 | Either 'unif' (uniform negative sampling) or 'bern' (Bernoulli negative 130 | sampling). 131 | use_cuda: str (opt, default = None) 132 | Can be either None (no use of cuda at all), 'all' to move the whole 133 | dataset to cuda and then split it in batches, or 'batch' to simply 134 | move the batches to cuda before they are returned.
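n_neg: int (opt, default = 1)
    Number of negative samples created for each true fact.
optimizer: torch.optim.Optimizer
    Optimizer used to update the model parameters.
scheduler:
    Learning-rate scheduler stepped after each optimizer update.
model_save_path: str
    Format string used when checkpointing; the epoch number is filled in
    with .format() (e.g. a hypothetical 'model_epoch_{}.pt').
fp16: bool (opt, default = False)
    Whether to run the forward and backward passes in mixed precision;
    requires `scaler` to be a torch.cuda.amp.GradScaler.
scaler: torch.cuda.amp.GradScaler (opt, default = None)
    Gradient scaler used when `fp16` is True.
log_steps: int (opt, default = 100)
    Interval, in batches, between loss log messages.
start_epoch: int (opt, default = 0)
    Epoch to resume training from.
save_epochs: int (opt, default = None)
    If set, a checkpoint is written every `save_epochs` epochs.
gradient_accumulation_steps: int (opt, default = 1)
    Number of batches over which gradients are accumulated before each
    optimizer step.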
135 | 136 | 137 | Attributes 138 | ---------- 139 | 140 | """ 141 | def __init__(self, model, criterion, kg_train, n_epochs, batch_size, 142 | optimizer, scheduler, model_save_path, sampling_type='bern', n_neg=1, 143 | use_cuda=None, fp16=False, scaler=None, log_steps=100, 144 | start_epoch=0, save_epochs=None, gradient_accumulation_steps=1): 145 | 146 | self.model = model 147 | self.criterion = criterion 148 | self.kg_train = kg_train 149 | self.use_cuda = use_cuda 150 | self.n_epochs = n_epochs 151 | self.optimizer = optimizer 152 | self.scheduler = scheduler 153 | self.model_save_path = model_save_path 154 | self.sampling_type = sampling_type 155 | self.n_neg = n_neg 156 | self.batch_size = batch_size 157 | self.n_triples = len(kg_train) 158 | self.fp16 = fp16 159 | self.scaler = scaler 160 | self.log_steps = log_steps 161 | self.start_epoch = start_epoch 162 | self.save_epoch = save_epochs 163 | self.gradient_accumulation_steps = gradient_accumulation_steps 164 | 165 | def process_batch(self, current_batch): 166 | # gradients are zeroed in run() right after each optimizer step so 167 | # that they can accumulate across gradient_accumulation_steps batches 168 | h, t, r = current_batch['h'], current_batch['t'], current_batch['r'] 169 | nh, nt = current_batch['nh'], current_batch['nt'] 170 | 171 | if self.fp16: 172 | with torch.cuda.amp.autocast(): 173 | p, n = self.model(h, t, r, nh, nt) 174 | loss = self.criterion(p, n) 175 | self.scaler.scale(loss).backward() 176 | else: 177 | p, n = self.model(h, t, r, nh, nt) 178 | loss = self.criterion(p, n) 179 | loss.backward() 180 | 181 | return loss.detach().item() 182 | 183 | def run(self): 184 | if self.use_cuda in ['all', 'batch']: 185 | self.model.cuda() 186 | self.criterion.cuda() 187 | 188 | iterator = tqdm(range(self.start_epoch, self.n_epochs), unit='epoch') 189 | data_loader = TrainDataLoader(self.kg_train, 190 | batch_size=self.batch_size, 191 | sampling_type=self.sampling_type, 192 | n_neg=self.n_neg, 193 | use_cuda=self.use_cuda) 194 | for epoch in iterator: 195 | sum_ = 0 196 | for i, batch in enumerate(data_loader): 197 | loss = self.process_batch(batch) 198 | # step the optimizer once every gradient_accumulation_steps batches 199 | if (i + 1) % self.gradient_accumulation_steps == 0: 200 | if self.fp16: 201 | self.scaler.step(self.optimizer) 202 | self.scaler.update() 203 | else: 204 | self.optimizer.step() 205 | self.optimizer.zero_grad()  # reset the accumulated gradients 206 | self.scheduler.step()  # update the learning rate 207 | sum_ += loss 208 | if self.log_steps is not None and i % self.log_steps == 0: 209 | logger.info(f"[Epoch-{epoch + 1}] step: {i}, loss: {loss}") 210 | 211 | iterator.set_description( 212 | 'Epoch {} | mean loss: {:.5f}'.format(epoch + 1, sum_ / len(data_loader))) 213 | self.model.normalize_parameters() 214 | 215 | if self.save_epoch is not None and (epoch+1) % self.save_epoch == 0: 216 | torch.save(self.model.state_dict(), self.model_save_path.format(epoch+1)) 217 | 218 | torch.save(self.model.state_dict(), self.model_save_path.format(epoch+1)) 219 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | ROOT_DIR="${HOME}/Data" 5 | DATA_DIR=${ROOT_DIR} 6 | OUTPUT_DIR="${ROOT_DIR}/output" 7 | PRETRAINED_MODEL_PATH="${ROOT_DIR}/bert/roberta_large" 8 | 9 | # Roberta_large-v3.4 10 | python finetune_text.py \ 11 | --data_dir $DATA_DIR \ 12 | --output_dir $OUTPUT_DIR \ 13 | --model_name "roberta_large" \ 14 | --data_version "v3.4" \ 15 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 16 | --config_file "src/config/roberta_large.json" \ 17 | --do_train \ 18 | --interaction_type "one_tower" \ 19 |
--classification_method "cls" \ 20 | --similarity_measure "NA" \ 21 | --loss_type "ce" \ 22 | --max_seq_len 50 \ 23 | --max_seq_len_pv 205 \ 24 | --train_batch_size 40 \ 25 | --learning_rate 5e-5 \ 26 | --fp16 27 | 28 | # Roberta_large-v3.4-cls_cat_1,2,3,4 29 | python finetune_text.py \ 30 | --data_dir $DATA_DIR \ 31 | --output_dir $OUTPUT_DIR \ 32 | --model_name "roberta_large" \ 33 | --data_version "v3.4" \ 34 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 35 | --config_file "src/config/roberta_large.json" \ 36 | --do_train \ 37 | --interaction_type "one_tower" \ 38 | --classification_method "cls_1,2,3,4_cat" \ 39 | --similarity_measure "NA" \ 40 | --loss_type "ce" \ 41 | --cls_layers "1,2,3,4" \ 42 | --cls_pool "cat" \ 43 | --max_seq_len 50 \ 44 | --max_seq_len_pv 205 \ 45 | --train_batch_size 40 \ 46 | --learning_rate 5e-5 \ 47 | --fp16 48 | 49 | # Roberta_large-v4 50 | python finetune_text.py \ 51 | --data_dir $DATA_DIR \ 52 | --output_dir $OUTPUT_DIR \ 53 | --model_name "roberta_large" \ 54 | --data_version "v4" \ 55 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 56 | --config_file "src/config/roberta_large.json" \ 57 | --do_train \ 58 | --interaction_type "one_tower" \ 59 | --classification_method "cls" \ 60 | --similarity_measure "NA" \ 61 | --loss_type "ce" \ 62 | --max_seq_len 50 \ 63 | --max_seq_len_pv 205 \ 64 | --train_batch_size 40 \ 65 | --learning_rate 5e-5 \ 66 | --fp16 67 | 68 | # pkgm_large-v3.4 69 | python finetune_text.py \ 70 | --data_dir $DATA_DIR \ 71 | --output_dir $OUTPUT_DIR \ 72 | --model_name "pkgm_large" \ 73 | --data_version "v3.4" \ 74 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 75 | --config_file "src/config/pkgm_large.json" \ 76 | --do_train \ 77 | --interaction_type "one_tower" \ 78 | --classification_method "cls" \ 79 | --similarity_measure "NA" \ 80 | --loss_type "ce" \ 81 | --max_seq_len 64 \ 82 | --max_pvs 30 \ 83 | --train_batch_size 256 \ 84 | --learning_rate 5e-5 \ 85 | --fp16 86 | 87 | # textcnn-v3.4 88 | python finetune_text.py \ 89 | --data_dir $DATA_DIR \ 90 | --output_dir $OUTPUT_DIR \ 91 | --model_name "textcnn" \ 92 | --data_version "v3.4" \ 93 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 94 | --config_file "src/config/roberta_large.json" \ 95 | --do_train \ 96 | --interaction_type "two_tower" \ 97 | --classification_method "cls" \ 98 | --similarity_measure "NA" \ 99 | --loss_type "ce" \ 100 | --max_seq_len 50 \ 101 | --max_seq_len_pv 205 \ 102 | --train_batch_size 256 \ 103 | --learning_rate 5e-5 \ 104 | --fp16 105 | 106 | # bert_base 107 | python finetune_bert.py 108 | 109 | # roberta_image_large-v5 110 | python finetune_multimodal.py \ 111 | --data_dir $DATA_DIR \ 112 | --output_dir $OUTPUT_DIR \ 113 | --model_name "roberta_image_large" \ 114 | --data_version "v5" \ 115 | --config_file "src/config/roberta_image_large.json" \ 116 | --pretrained_model_path $PRETRAINED_MODEL_PATH \ 117 | --do_train \ 118 | --interaction_type "one_tower" \ 119 | --classification_method "cls" \ 120 | --ensemble "begin" \ 121 | --loss_type "ce" \ 122 | --max_seq_len 50 \ 123 | --max_seq_len_pv 205 \ 124 | --train_batch_size 40 \ 125 | --learning_rate 5e-5 \ 126 | --fp16 127 | 128 | # eca_nfnet_l0-v6 129 | python finetune_image.py \ 130 | --data_dir $DATA_DIR/raw \ 131 | --output_dir $OUTPUT_DIR \ 132 | --model_name "eca_nfnet_l0" \ 133 | --data_version "v6" \ 134 | --config_file "src/config/eca_nfnet_l0.json" \ 135 | --do_train \ 136 | --image_size 1000 \ 137 | --train_batch_size 64 \ 138 | --learning_rate 5e-5 \ 139 | --fp16 140 | 141 |
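# Note: the training blocks above are independent; to train a subset, comment
# out the others. Hypothetically, blocks can also be dispatched to separate
# GPUs by prefixing a command, e.g.: CUDA_VISIBLE_DEVICES=0 python finetune_text.py ...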
--------------------------------------------------------------------------------