├── README.md
├── config.py
├── dataset
│   ├── __init__.py
│   ├── extract.cpp
│   ├── filternyt.py
│   └── nyt.py
├── main_att.py
├── main_mil.py
├── models
│   ├── BasicModule.py
│   ├── PCNN_ATT.py
│   ├── PCNN_ONE.py
│   └── __init__.py
├── plot.ipynb
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | 
2 | **Update 2019.03.09**:
3 | 
4 | - Migrated to Python 3.x
5 | - Migrated to PyTorch 0.4+ (removed `Variable` wrappers, etc.)
6 | - Piecewise pooling is now computed with a mask
7 | - The large NYT dataset is recommended over FilterNYT
8 | 
9 | 
10 | 2019.03.05:
11 | 
12 | Fixed the bug in the mask-based piecewise pooling.
13 | 
14 | - Updated to PyTorch 0.4+; version 0.3 is not compatible
15 | 
16 | 
17 | 2018.11.3:
18 | 
19 | The **mask-based** `use_pcnn=True` path currently has some issues and is being fixed. In the meantime:
20 | 
21 | - Test with `use_pcnn=False`; the performance difference is small
22 | - Or use the version from before the mask change: https://github.com/ShomyLiu/pytorch-relation-extraction/tree/7e3ef1720d43690fc0da0d81e54bdc0fc0cf822a
23 | 
24 | 
25 | Update 2018.10.14:
26 | 
27 | Code for fully supervised relation extraction with PCNN (Zeng 2014): [PCNN](https://github.com/ShomyLiu/pytorch-pcnn)
28 | 
29 | 
30 | Update 2018.9.10:
31 | - Following OpenNRE, piecewise pooling can be computed quickly with a mask.
32 | - Reworked the data processing for NYT with 53 relation classes (done)
33 | - Reworked the data processing for NYT with 27 relation classes (not finished)
34 | 
35 | The data processing has been updated.
36 | 
37 | A PyTorch reproduction of PCNN+MIL (Zeng 2015) and PCNN+ATT (Lin 2016), comparing both models on the two dataset versions (27 relations / 53 relations).
38 | 
39 | 
40 | 
41 | Related blog posts:
42 | 
43 | - [Notes on relation extraction papers](http://shomy.top/2018/02/28/relation-extraction/)
44 | 
45 | - [Notes on the reproduced results](http://shomy.top/2018/07/05/pytorch-relation-extraction/)
46 | 
47 | 
48 | 
49 | The code organization and structure mainly follow [Chen Yun's PyTorch practice guide](https://zhuanlan.zhihu.com/p/29024978) (personally recommended), so implementation details covered there are not repeated here.
50 | 
51 | 
52 | 
53 | ## Implementation overview
54 | 
55 | 
56 | Environment:
57 | 
58 | - Python 3.x
59 | - PyTorch 0.4+ (0.3 is not compatible)
60 | - fire
61 | 
62 | The main directories in brief:
63 | 
64 | ```
65 | ├── checkpoints        # saved / pre-trained models
66 | ├── config.py          # hyperparameters
67 | ├── dataset            # data directory
68 | │   ├── FilterNYT      # SMALL dataset
69 | │   ├── NYT            # LARGE dataset
70 | │   ├── filternyt.py
71 | │   ├── __init__.py
72 | │   ├── nyt.py
73 | ├── main_mil.py        # entry point for PCNN+ONE
74 | ├── main_att.py        # entry point for PCNN+ATT
75 | ├── models             # model directory
76 | │   ├── BasicModule.py
77 | │   ├── __init__.py
78 | │   ├── PCNN_ATT.py
79 | │   ├── PCNN_ONE.py
80 | ├── plot.ipynb
81 | ├── README.md
82 | ├── utils.py           # utility functions
83 | ```
84 | 
85 | 
86 | 
87 | The code largely follows Chen Yun's guide: data and models are kept separate, hyperparameters/configuration live in their own file, and the fire library manages command-line arguments, which makes changing parameters convenient.
88 | 
89 | Because PCNN+ONE and PCNN+ATT are trained and tested differently, there are two separate entry points for simplicity: `main_mil.py` and `main_att.py`.
90 | 
91 | Both are run the same way. For example, to train PCNN+ONE on the large dataset, options can be appended on the command line; anything not specified falls back to the defaults in `config.py` (more override examples follow the note below):
92 | 
93 | ```
94 | 
95 | python main_mil.py train --data="NYT" --batch_size=128
96 | 
97 | ```
98 | 
99 | Note: the data must be preprocessed first as described in the next section (this mainly generates the npy files that the models load directly). 
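Since `fire` forwards every `--key=value` argument to `DefaultConfig.parse`, any attribute defined in `config.py` can be overridden from the command line in the same way. A couple of illustrative examples (the option names below are taken from `DefaultConfig`; these particular combinations are untested sketches):

```
# train PCNN+ONE on the SMALL dataset without piecewise pooling
python main_mil.py train --data="FilterNYT" --use_pcnn=False

# train PCNN+ATT on the LARGE dataset on CPU for more epochs
python main_att.py train --data="NYT" --use_gpu=False --num_epochs=20
```

`DefaultConfig.parse` raises an exception for unknown keys, so only attributes that already exist in `config.py` can be overridden this way.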
100 | 
101 | 
102 | 
103 | ## Data preprocessing
104 | 
105 | To save space, only the raw LARGE and SMALL data are provided, so they must be preprocessed to generate the npy-format data.
106 | 
107 | First download the two raw datasets:
108 | 
109 | [Baidu Netdisk](https://pan.baidu.com/s/1Mu46NOtrrJhqN68s9WfLKg) [Google Drive](https://drive.google.com/drive/folders/1kqHG0KszGhkyLA4AZSLZ2XZm9sxD8b58?usp=sharing)
110 | 
111 | A brief description of the data format:
112 | - Line 1: the two entity IDs: ent1id ent2id
113 | - Line 2: the bag label(s) and the number of sentences in the bag. Because a few bags have multiple labels (never more than 4), the labels are written as 4 integers with -1 meaning empty; e.g. `2 4 -1 -1 3` means the bag has labels 2 and 4 and contains 3 sentences
114 | - The following lines are the sentences in the bag
115 | 
116 | 
117 | Put the two zip files under the `dataset` directory and unzip them; this creates two directories, NYT and FilterNYT. The LARGE dataset lives in NYT and the SMALL one in FilterNYT. The raw data were obtained from the open-source code of Zeng 2015 and Lin 2016 respectively.
118 | 
119 | 
120 | 
121 | For the LARGE data:
122 | 
123 | 
124 | 
125 | - Change into the NYT directory.
126 | 
127 | - Compile and run `extract.cpp` (Lin 2016's preprocessing code, found under `dataset/`): `g++ extract.cpp -o extract`, then run `./extract` to obtain `bags_train.txt`, `bags_test.txt` and `vector.txt` (inside the NYT directory).
128 | 
129 | - Change back to the repository root and run the preprocessing: `python dataset/nyt.py`, which generates a series of npy files in the NYT directory.
130 | 
131 | 
132 | 
133 | For the SMALL data:
134 | 
135 | - Simply run `python dataset/filternyt.py` to generate the npy files in the FilterNYT directory.
136 | 
137 | 
138 | 
139 | The generated npy files are loaded directly through PyTorch `Dataset` classes; see the `*Data` classes in `nyt.py` and `filternyt.py`.
140 | 
141 | Once preprocessing is finished, training/testing can be run with the commands shown above.
142 | 
143 | 
144 | 
145 | ## Tuning notes
146 | 
147 | Reproducing the results took quite a bit of effort and involved several pitfalls; briefly:
148 | 
149 | - Use `Adadelta` rather than `Adam` as the optimizer. `SGD` also works, but not as well as `Adadelta`.
150 | 
151 | - The select-instance and predict parts of Zeng 2015's Theano code contain some mistakes (the instance with the highest probability is not actually chosen).
152 | 
153 | - A relatively large batch size works better (128).
154 | 
155 | 
156 | 
157 | See the blog posts for a discussion of the results.
158 | 
159 | 
160 | 
161 | ## References
162 | 
163 | - [PCNN+ONE Zeng 2015](https://github.com/smilelhh/ds_pcnns)
164 | - [PCNN+ATT Lin 2016](https://github.com/thunlp/OpenNRE)
165 | - [RE-DS-Word-Attention-Models](https://github.com/SharmisthaJat/RE-DS-Word-Attention-Models)
166 | - [GloRE](https://github.com/ppuliu/GloRE)
167 | 
168 | ## Appendix
169 | If you use this code, you may optionally cite:
170 | ```
171 | @inproceedings{liu2019reet,
172 |   title={REET: Joint Relation Extraction and Entity Typing via Multi-task Learning},
173 |   author={Liu, Hongtao and Wang, Peiyi and Wu, Fangzhao and Jiao, Pengfei and Wang, Wenjun and Xie, Xing and Sun, Yueheng},
174 |   booktitle={CCF International Conference on Natural Language Processing and Chinese Computing},
175 |   pages={327--339},
176 |   year={2019},
177 |   organization={Springer}
178 | }
179 | ```
180 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | data_dic = {
4 |     'NYT': {
5 |         'data_root': './dataset/NYT/',
6 |         'w2v_path': './dataset/NYT/w2v.npy',
7 |         'p1_2v_path': './dataset/NYT/p1_2v.npy',
8 |         'p2_2v_path': './dataset/NYT/p2_2v.npy',
9 |         'vocab_size': 114043,
10 |         'rel_num': 53
11 |     },
12 |     'FilterNYT': {
13 |         'data_root': './dataset/FilterNYT/',
14 |         'w2v_path': './dataset/FilterNYT/w2v.npy',
15 |         'p1_2v_path': './dataset/FilterNYT/p1_2v.npy',
16 |         'p2_2v_path': './dataset/FilterNYT/p2_2v.npy',
17 |         'vocab_size': 160695 + 2,
18 |         'rel_num': 27
19 |     }
20 | }
21 | 
22 | 
23 | class DefaultConfig(object):
24 | 
25 |     model = 'PCNN_ONE'  # the name of the model to use (PCNN_ONE / PCNN_ATT, see models/)
26 |     data = 'NYT'  # SEM NYT FilterNYT
27 | 
28 |     result_dir = './out'
29 |     data_root = data_dic[data]['data_root']  # the data dir
30 |     w2v_path = data_dic[data]['w2v_path']
31 |     p1_2v_path = data_dic[data]['p1_2v_path']
32 |     p2_2v_path = data_dic[data]['p2_2v_path']
33 |     load_model_path = 'checkpoints/model.pth'  # the trained model
34 | 
35 |     seed = 3435
36 |     batch_size = 128  # batch size
37 | 
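    # Note: when use_gpu is True, gpu_id below is passed to torch.cuda.set_device()
    # in main_mil.py / main_att.py; set use_gpu=False to run everything on the CPU.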
use_gpu = True # user GPU or not 38 | gpu_id = 1 39 | num_workers = 0 # how many workers for loading data 40 | 41 | max_len = 80 + 2 # max_len for each sentence + two padding 42 | limit = 50 # the position range <-limit, limit> 43 | 44 | vocab_size = data_dic[data]['vocab_size'] # vocab + UNK + BLANK 45 | rel_num = data_dic[data]['rel_num'] 46 | word_dim = 50 47 | pos_dim = 5 48 | pos_size = limit * 2 + 2 49 | 50 | norm_emb=True 51 | 52 | num_epochs = 16 # the number of epochs for training 53 | drop_out = 0.5 54 | lr = 0.0003 # initial learning rate 55 | lr_decay = 0.95 # when val_loss increase, lr = lr*lr_decay 56 | weight_decay = 0.0001 # optimizer parameter 57 | 58 | # Conv 59 | filters = [3] 60 | filters_num = 230 61 | sen_feature_dim = filters_num 62 | 63 | rel_dim = filters_num * len(filters) 64 | rel_filters_num = 100 65 | 66 | print_opt = 'DEF' 67 | use_pcnn=True 68 | 69 | 70 | def parse(self, kwargs): 71 | ''' 72 | user can update the default hyperparamter 73 | ''' 74 | for k, v in kwargs.items(): 75 | if not hasattr(self, k): 76 | raise Exception('opt has No key: {}'.format(k)) 77 | setattr(self, k, v) 78 | data_list = ['data_root', 'w2v_path', 'rel_num', 'vocab_size', 'p1_2v_path', 'p2_2v_path'] 79 | for r in data_list: 80 | setattr(self, r, data_dic[self.data][r]) 81 | 82 | print('*************************************************') 83 | print('user config:') 84 | for k, v in self.__class__.__dict__.items(): 85 | if not k.startswith('__'): 86 | print("{} => {}".format(k, getattr(self, k))) 87 | 88 | print('*************************************************') 89 | 90 | 91 | DefaultConfig.parse = parse 92 | opt = DefaultConfig() 93 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .nyt import NYTData 4 | from .filternyt import FilterNYTData 5 | -------------------------------------------------------------------------------- /dataset/extract.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | using namespace std; 17 | int output_model = 0; 18 | 19 | string version = ""; 20 | 21 | float wt = 0.5; 22 | 23 | int num_threads = 12; 24 | int trainTimes = 15; 25 | float alpha = 0.03; 26 | float reduce = 0.98; 27 | int tt,tt1; 28 | int dimensionC = 230;//1000; 29 | int dimensionWPE = 5;//25; 30 | int window = 3; 31 | int limit = 50; 32 | //int limit = 100; 33 | float marginPositive = 2.5; 34 | float marginNegative = 0.5; 35 | float margin = 2; 36 | float Belt = 0.001; 37 | float *matrixB1, *matrixRelation, *matrixW1, *matrixRelationDao, *matrixRelationPr, *matrixRelationPrDao; 38 | float *matrixB1_egs, *matrixRelation_egs, *matrixW1_egs, *matrixRelationPr_egs; 39 | float *matrixB1_exs, *matrixRelation_exs, *matrixW1_exs, *matrixRelationPr_exs; 40 | float *wordVecDao,*wordVec_egs,*wordVec_exs; 41 | float *positionVecE1, *positionVecE2, *matrixW1PositionE1, *matrixW1PositionE2; 42 | float *positionVecE1_egs, *positionVecE2_egs, *matrixW1PositionE1_egs, *matrixW1PositionE2_egs, *positionVecE1_exs, *positionVecE2_exs, *matrixW1PositionE1_exs, *matrixW1PositionE2_exs; 43 | float *matrixW1PositionE1Dao; 44 | float *matrixW1PositionE2Dao; 45 | float *positionVecDaoE1; 46 | float *positionVecDaoE2; 
47 | float *matrixW1Dao; 48 | float *matrixB1Dao; 49 | double mx = 0; 50 | int batch = 16; 51 | int npoch; 52 | int len; 53 | float rate = 1; 54 | FILE *logg; 55 | 56 | float *wordVec; 57 | long long wordTotal, dimension, relationTotal; 58 | long long PositionMinE1, PositionMaxE1, PositionTotalE1,PositionMinE2, PositionMaxE2, PositionTotalE2; 59 | map wordMapping; 60 | vector wordList; 61 | map relationMapping; 62 | vector trainLists, trainPositionE1, trainPositionE2, trainmaskLists; 63 | vector trainLength; 64 | vector headList, tailList, relationList, ldist, rdist,testldist, testrdist; 65 | vector testtrainLists, testPositionE1, testPositionE2, testmaskLists; 66 | vector testtrainLength; 67 | vector testheadList, testtailList, testrelationList; 68 | vector testheadLists, testtailLists, testrelationLists; 69 | vector trainheadLists, traintailLists, trainrelationLists; 70 | vector nam; 71 | 72 | map > bags_train, bags_test; 73 | 74 | void init() { 75 | 76 | 77 | FILE *f = fopen("vec.bin", "rb"); 78 | fscanf(f, "%lld", &wordTotal); 79 | fscanf(f, "%lld", &dimension); 80 | cout<<"wordTotal=\t"< tmpp; 150 | while (fscanf(f,"%s", buffer)==1) { 151 | std::string con = buffer; 152 | if (con=="###END###") break; 153 | int gg = wordMapping[con]; 154 | if (con == head_s) lefnum = len; 155 | if (con == tail_s) rignum = len; 156 | len++; 157 | tmpp.push_back(gg); 158 | } 159 | int first_num = lefnum, second_num = rignum; 160 | if (lefnum > rignum){ 161 | swap(first_num, second_num); 162 | } 163 | headList.push_back(head); 164 | tailList.push_back(tail); 165 | trainheadLists.push_back(head_s); 166 | traintailLists.push_back(tail_s); 167 | 168 | relationList.push_back(num); 169 | trainLength.push_back(len); 170 | ldist.push_back(lefnum); 171 | rdist.push_back(rignum); 172 | int *con=(int *)calloc(len,sizeof(int)); 173 | int *conl=(int *)calloc(len,sizeof(int)); 174 | int *conr=(int *)calloc(len,sizeof(int)); 175 | int *trainmask = (int *)calloc(len, sizeof(int)); 176 | 177 | for (int i = 0; i < len; i++) { 178 | if( i - first_num <= 0) trainmask[i] = 1; 179 | else if (i - second_num <= 0) trainmask[i] = 2; 180 | else trainmask[i] = 3; 181 | con[i] = tmpp[i]; 182 | conl[i] = lefnum - i; 183 | conr[i] = rignum - i; 184 | if (conl[i] >= limit) conl[i] = limit; 185 | if (conr[i] >= limit) conr[i] = limit; 186 | if (conl[i] <= -limit) conl[i] = -limit; 187 | if (conr[i] <= -limit) conr[i] = -limit; 188 | if (conl[i] > PositionMaxE1) PositionMaxE1 = conl[i]; 189 | if (conr[i] > PositionMaxE2) PositionMaxE2 = conr[i]; 190 | if (conl[i] < PositionMinE1) PositionMinE1 = conl[i]; 191 | if (conr[i] < PositionMinE2) PositionMinE2 = conr[i]; 192 | } 193 | trainLists.push_back(con); 194 | trainmaskLists.push_back(trainmask); 195 | trainPositionE1.push_back(conl); 196 | trainPositionE2.push_back(conr); 197 | } 198 | fclose(f); 199 | 200 | f = fopen("test.txt", "r"); 201 | while (fscanf(f,"%s",buffer)==1) { 202 | string e1 = buffer; 203 | fscanf(f,"%s",buffer); 204 | string e2 = buffer; 205 | bags_test[e1+"\t"+e2].push_back(testheadList.size()); 206 | fscanf(f,"%s",buffer); 207 | string head_s = (string)(buffer); 208 | int head = wordMapping[(string)(buffer)]; 209 | fscanf(f,"%s",buffer); 210 | string tail_s = (string)(buffer); 211 | int tail = wordMapping[(string)(buffer)]; 212 | fscanf(f,"%s",buffer); 213 | int num = relationMapping[(string)(buffer)]; 214 | int len = 0 , lefnum = 0, rignum = 0; 215 | std::vector tmpp; 216 | while (fscanf(f,"%s", buffer)==1) { 217 | std::string con = buffer; 218 | if 
(con=="###END###") break; 219 | int gg = wordMapping[con]; 220 | if (head_s == con) lefnum = len; 221 | if (tail_s == con) rignum = len; 222 | len++; 223 | tmpp.push_back(gg); 224 | } 225 | int first_num = lefnum, second_num = rignum; 226 | if (lefnum > rignum){ 227 | swap(first_num, second_num); 228 | } 229 | testheadList.push_back(head); 230 | testtailList.push_back(tail); 231 | testheadLists.push_back(head_s); 232 | testtailLists.push_back(tail_s); 233 | testrelationList.push_back(num); 234 | testtrainLength.push_back(len); 235 | testldist.push_back(lefnum); 236 | testrdist.push_back(rignum); 237 | int *con=(int *)calloc(len,sizeof(int)); 238 | int *conl=(int *)calloc(len,sizeof(int)); 239 | int *conr=(int *)calloc(len,sizeof(int)); 240 | int *testmask = (int *)calloc(len,sizeof(int)); 241 | for (int i = 0; i < len; i++) { 242 | con[i] = tmpp[i]; 243 | if( i - first_num <= 0) testmask[i] = 1; 244 | else if (i - second_num <= 0) testmask[i] = 2; 245 | else testmask[i] = 3; 246 | 247 | conl[i] = lefnum - i; 248 | conr[i] = rignum - i; 249 | if (conl[i] >= limit) conl[i] = limit; 250 | if (conr[i] >= limit) conr[i] = limit; 251 | if (conl[i] <= -limit) conl[i] = -limit; 252 | if (conr[i] <= -limit) conr[i] = -limit; 253 | if (conl[i] > PositionMaxE1) PositionMaxE1 = conl[i]; 254 | if (conr[i] > PositionMaxE2) PositionMaxE2 = conr[i]; 255 | if (conl[i] < PositionMinE1) PositionMinE1 = conl[i]; 256 | if (conr[i] < PositionMinE2) PositionMinE2 = conr[i]; 257 | } 258 | testtrainLists.push_back(con); 259 | testmaskLists.push_back(testmask); 260 | testPositionE1.push_back(conl); 261 | testPositionE2.push_back(conr); 262 | } 263 | fclose(f); 264 | cout< >::iterator it; 290 | fout.open(("../bags_train.txt"),ios::out); 291 | cout << "Number of bags "<< bags_train.size()<<'\t'<first; 295 | fout << bagname<<"\t"; 296 | vector indices = it->second; 297 | for(int i=0;ifirst; 356 | fout << bagname<<"\t"; 357 | vector indices = it->second; 358 | for(int i=0;i 0 ~ limit * 2+2 188 | : -51 => 1 189 | : -50 => 1 190 | : 50 => 101 191 | : >50: 102 192 | ''' 193 | 194 | def padding(x): 195 | if x < 1: 196 | return 1 197 | if x > self.limit * 2 + 1: 198 | return self.limit * 2 + 1 199 | return x 200 | 201 | if sen_len < self.max_len: 202 | index = np.arange(sen_len) 203 | else: 204 | index = np.arange(self.max_len) 205 | 206 | pf1 = [] 207 | pf2 = [] 208 | pf1 += list(map(padding, index - ent_pos[0] + 2 + self.limit)) 209 | pf2 += list(map(padding, index - ent_pos[1] + 2 + self.limit)) 210 | 211 | if len(pf1) < self.max_len + 2 * self.pad: 212 | pf1 += [0] * (self.max_len + 2 * self.pad - len(pf1)) 213 | pf2 += [0] * (self.max_len + 2 * self.pad - len(pf2)) 214 | mask += [0] * (self.max_len + 2 * self.pad - len(mask)) 215 | return [pf1, pf2], mask 216 | 217 | 218 | if __name__ == "__main__": 219 | data = FilterNYTLoad('./dataset/FilterNYT/') 220 | 221 | -------------------------------------------------------------------------------- /dataset/nyt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from torch.utils.data import Dataset 4 | import os 5 | import numpy as np 6 | 7 | 8 | class NYTData(Dataset): 9 | 10 | def __init__(self, root_path, train=True): 11 | if train: 12 | path = os.path.join(root_path, 'train/') 13 | print('loading train data') 14 | else: 15 | path = os.path.join(root_path, 'test/') 16 | print('loading test data') 17 | 18 | self.labels = np.load(path + 'labels.npy') 19 | self.x = np.load(path + 'bags_feature.npy') 20 | 
self.x = list(zip(self.x, self.labels)) 21 | 22 | print('loading finish') 23 | 24 | def __getitem__(self, idx): 25 | assert idx < len(self.x) 26 | return self.x[idx] 27 | 28 | def __len__(self): 29 | return len(self.x) 30 | 31 | 32 | class NYTLoad(object): 33 | ''' 34 | load and preprocess data 35 | ''' 36 | def __init__(self, root_path, max_len=80, limit=50, pos_dim=5, pad=1): 37 | 38 | self.max_len = max_len 39 | self.limit = limit 40 | self.root_path = root_path 41 | self.pos_dim = pos_dim 42 | self.pad = pad 43 | 44 | self.w2v_path = os.path.join(root_path, 'vector.txt') 45 | self.train_path = os.path.join(root_path, 'bags_train.txt') 46 | self.test_path = os.path.join(root_path, 'bags_test.txt') 47 | 48 | print('loading start....') 49 | self.w2v, self.word2id, self.id2word = self.load_w2v() 50 | self.p1_2v, self.p2_2v = self.load_p2v() 51 | 52 | np.save(os.path.join(self.root_path, 'w2v.npy'), self.w2v) 53 | np.save(os.path.join(self.root_path, 'p1_2v.npy'), self.p1_2v) 54 | np.save(os.path.join(self.root_path, 'p2_2v.npy'), self.p2_2v) 55 | 56 | print("parsing train text...") 57 | self.bags_feature, self.labels = self.parse_sen(self.train_path, 'train') 58 | np.save(os.path.join(self.root_path, 'train', 'bags_feature.npy'), self.bags_feature) 59 | np.save(os.path.join(self.root_path, 'train', 'labels.npy'), self.labels) 60 | 61 | print("parsing test text...") 62 | self.bags_feature, self.labels = self.parse_sen(self.test_path, 'test') 63 | np.save(os.path.join(self.root_path, 'test', 'bags_feature.npy'), self.bags_feature) 64 | np.save(os.path.join(self.root_path, 'test', 'labels.npy'), self.labels) 65 | print('save finish!') 66 | 67 | def load_p2v(self): 68 | pos1_vec = [np.zeros(self.pos_dim)] 69 | pos1_vec.extend([np.random.uniform(low=-1.0, high=1.0, size=self.pos_dim) for _ in range(self.limit * 2 + 1)]) 70 | pos2_vec = [np.zeros(self.pos_dim)] 71 | pos2_vec.extend([np.random.uniform(low=-1.0, high=1.0, size=self.pos_dim) for _ in range(self.limit * 2 + 1)]) 72 | return np.array(pos1_vec, dtype=np.float32), np.array(pos2_vec, dtype=np.float32) 73 | 74 | def load_w2v(self): 75 | ''' 76 | reading from vec.bin 77 | add two extra tokens: 78 | : UNK for unkown tokens 79 | ''' 80 | wordlist = [] 81 | 82 | f = open(self.w2v_path) 83 | # dim = int(f.readline().split()[1]) 84 | # f = f.readlines() 85 | 86 | vecs = [] 87 | for line in f: 88 | line = line.strip('\n').split() 89 | vec = list(map(float, line[1].split(',')[:-1])) 90 | vecs.append(vec) 91 | wordlist.append(line[0]) 92 | 93 | # wordlist.append('UNK') 94 | # vecs.append(np.random.uniform(low=-0.5, high=0.5, size=dim)) 95 | word2id = {j: i for i, j in enumerate(wordlist)} 96 | id2word = {i: j for i, j in enumerate(wordlist)} 97 | 98 | return np.array(vecs, dtype=np.float32), word2id, id2word 99 | 100 | def parse_sen(self, path, flag): 101 | ''' 102 | parse the records in data 103 | ''' 104 | all_sens =[] 105 | all_labels =[] 106 | f = open(path) 107 | while 1: 108 | line = f.readline() 109 | if not line: 110 | break 111 | if flag == 'train': 112 | line = line.split('\t') 113 | num = line[3].strip().split(',') 114 | num = len(num) 115 | else: 116 | line = line.split('\t') 117 | num = line[2].strip().split(',') 118 | num = len(num) 119 | 120 | ldists = [] 121 | rdists = [] 122 | sentences = [] 123 | entitiesPos = [] 124 | pos = [] 125 | masks = [] 126 | rels = [] 127 | 128 | for i in range(num): 129 | ent_pair_line = f.readline().strip().split(',') 130 | # entities = ent_pair_line[:2] 131 | # ignore the entities index in vocab 132 
| entities = [0, 0] 133 | epos = list(map(lambda x: int(x) + 1, ent_pair_line[2:4])) 134 | pos.append(epos) 135 | epos.sort() 136 | entitiesPos.append(epos) 137 | 138 | rel = int(ent_pair_line[4]) 139 | rels.append(rel) 140 | sent = f.readline().strip().split(',') 141 | sentences.append(list(map(lambda x: int(x), sent))) 142 | ldist = f.readline().strip().split(',') 143 | rdist = f.readline().strip().split(',') 144 | mask = f.readline().strip().split(",") 145 | ldists.append(list(map(int, ldist))) 146 | rdists.append(list(map(int, rdist))) 147 | masks.append(list(map(int, mask))) 148 | 149 | rels = list(set(rels)) 150 | if len(rels) < 4: 151 | rels.extend([-1] * (4 - len(rels))) 152 | else: 153 | rels = rels[:4] 154 | bag = [entities, num, sentences, ldists, rdists, pos, entitiesPos, masks] 155 | 156 | all_labels.append(rels) 157 | all_sens += [bag] 158 | 159 | f.close() 160 | bags_feature = self.get_sentence_feature(all_sens) 161 | 162 | return bags_feature, all_labels 163 | 164 | def get_sentence_feature(self, bags): 165 | ''' 166 | : word embedding 167 | : postion embedding 168 | return: 169 | sen list 170 | pos_left 171 | pos_right 172 | ''' 173 | update_bags = [] 174 | 175 | for bag in bags: 176 | es, num, sens, ldists, rdists, pos, enPos, masks = bag 177 | new_sen = [] 178 | new_pos = [] 179 | new_entPos = [] 180 | new_masks= [] 181 | 182 | for idx, sen in enumerate(sens): 183 | sen, pf1, pf2, pos, mask = self.get_pad_sen_pos(sen, ldists[idx], rdists[idx], enPos[idx], masks[idx]) 184 | new_sen.append(sen) 185 | new_pos.append([pf1, pf2]) 186 | new_entPos.append(pos) 187 | new_masks.append(mask) 188 | update_bags.append([es, num, new_sen, new_pos, new_entPos, new_masks]) 189 | 190 | return update_bags 191 | 192 | def get_pad_sen_pos(self, sen, ldist, rdist, pos, mask): 193 | ''' 194 | refer: github.com/SharmisthaJat/RE-DS-Word-Attention-Models 195 | ''' 196 | x = [] 197 | pf1 = [] 198 | pf2 = [] 199 | masks = [] 200 | 201 | # shorter than max_len 202 | if len(sen) <= self.max_len: 203 | for i, ind in enumerate(sen): 204 | x.append(ind) 205 | pf1.append(ldist[i] + 1) 206 | pf2.append(rdist[i] + 1) 207 | masks.append(mask[i]) 208 | # longer than max_len, expand between two entities 209 | else: 210 | idx = [i for i in range(pos[0], pos[1] + 1)] 211 | if len(idx) > self.max_len: 212 | idx = idx[:self.max_len] 213 | for i in idx: 214 | x.append(sen[i]) 215 | pf1.append(ldist[i] + 1) 216 | pf2.append(rdist[i] + 1) 217 | masks.append(mask[i]) 218 | pos[0] = 1 219 | pos[1] = len(idx) - 1 220 | else: 221 | for i in idx: 222 | x.append(sen[i]) 223 | pf1.append(ldist[i] + 1) 224 | pf2.append(rdist[i] + 1) 225 | masks.append(mask[i]) 226 | 227 | before = pos[0] - 1 228 | after = pos[1] + 1 229 | pos[0] = 1 230 | pos[1] = len(idx) - 1 231 | numAdded = 0 232 | while True: 233 | added = 0 234 | if before >= 0 and len(x) + 1 <= self.max_len + self.pad: 235 | x.append(sen[before]) 236 | pf1.append(ldist[before] + 1) 237 | pf2.append(rdist[before] + 1) 238 | masks.append(mask[before]) 239 | added = 1 240 | numAdded += 1 241 | 242 | if after < len(sen) and len(x) + 1 <= self.max_len + self.pad: 243 | x.append(sen[after]) 244 | pf1.append(ldist[after] + 1) 245 | pf2.append(rdist[after] + 1) 246 | masks.append(mask[after]) 247 | added = 1 248 | 249 | if added == 0: 250 | break 251 | 252 | before -= 1 253 | after += 1 254 | 255 | pos[0] = pos[0] + numAdded 256 | pos[1] = pos[1] + numAdded 257 | 258 | while len(x) < self.max_len + 2 * self.pad: 259 | x.append(0) 260 | pf1.append(0) 261 | pf2.append(0) 262 | 
masks.append(0) 263 | 264 | if pos[0] == pos[1]: 265 | if pos[1] + 1 < len(sen): 266 | pos[1] += 1 267 | else: 268 | if pos[0] - 1 >= 1: 269 | pos[0] = pos[0] - 1 270 | else: 271 | raise Exception('pos= {},{}'.format(pos[0], pos[1])) 272 | 273 | return [x, pf1, pf2, pos, masks] 274 | 275 | def get_pad_sen(self, sen): 276 | ''' 277 | padding the sentences 278 | ''' 279 | sen.insert(0, self.word2id['BLANK']) 280 | if len(sen) < self.max_len + 2 * self.pad: 281 | sen += [self.word2id['BLANK']] * (self.max_len +2 * self.pad - len(sen)) 282 | else: 283 | sen = sen[: self.max_len + 2 * self.pad] 284 | 285 | return sen 286 | 287 | def get_pos_feature(self, sen_len, ent_pos): 288 | ''' 289 | clip the postion range: 290 | : -limit ~ limit => 0 ~ limit * 2+2 291 | : -51 => 1 292 | : -50 => 1 293 | : 50 => 101 294 | : >50: 102 295 | ''' 296 | 297 | def padding(x): 298 | if x < 1: 299 | return 1 300 | if x > self.limit * 2 + 1: 301 | return self.limit * 2 + 1 302 | return x 303 | 304 | if sen_len < self.max_len: 305 | index = np.arange(sen_len) 306 | else: 307 | index = np.arange(self.max_len) 308 | 309 | pf1 = [0] 310 | pf2 = [0] 311 | pf1 += list(map(padding, index - ent_pos[0] + 2 + self.limit)) 312 | pf2 += list(map(padding, index - ent_pos[1] + 2 + self.limit)) 313 | 314 | if len(pf1) < self.max_len + 2 * self.pad: 315 | pf1 += [0] * (self.max_len + 2 * self.pad - len(pf1)) 316 | pf2 += [0] * (self.max_len + 2 * self.pad - len(pf2)) 317 | return [pf1, pf2] 318 | 319 | 320 | if __name__ == "__main__": 321 | data = NYTLoad('./dataset/NYT/') 322 | -------------------------------------------------------------------------------- /main_att.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from config import opt 4 | import models 5 | import dataset 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import torch.optim as optim 9 | from utils import save_pr, now, eval_metric 10 | 11 | 12 | def collate_fn(batch): 13 | ''' 14 | custom for DataLoader 15 | ''' 16 | data, label = zip(*batch) 17 | return data, label 18 | 19 | 20 | def test(**kwargs): 21 | pass 22 | 23 | 24 | def train(**kwargs): 25 | 26 | kwargs.update({'model': 'PCNN_ATT'}) 27 | opt.parse(kwargs) 28 | 29 | if opt.use_gpu: 30 | torch.cuda.set_device(opt.gpu_id) 31 | 32 | model = getattr(models, 'PCNN_ATT')(opt) 33 | if opt.use_gpu: 34 | model.cuda() 35 | 36 | # loading data 37 | DataModel = getattr(dataset, opt.data + 'Data') 38 | train_data = DataModel(opt.data_root, train=True) 39 | train_data_loader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers, collate_fn=collate_fn) 40 | 41 | test_data = DataModel(opt.data_root, train=False) 42 | test_data_loader = DataLoader(test_data, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers, collate_fn=collate_fn) 43 | print('{} train data: {}; test data: {}'.format(now(), len(train_data), len(test_data))) 44 | 45 | # criterion and optimizer 46 | # criterion = nn.CrossEntropyLoss() 47 | optimizer = optim.Adadelta(model.parameters(), rho=0.95, eps=1e-6) 48 | 49 | # train 50 | # max_pre = -1.0 51 | # max_rec = -1.0 52 | for epoch in range(opt.num_epochs): 53 | total_loss = 0 54 | for idx, (data, label_set) in enumerate(train_data_loader): 55 | 56 | label = [l[0] for l in label_set] 57 | 58 | optimizer.zero_grad() 59 | model.batch_size = opt.batch_size 60 | loss = model(data, label) 61 | if opt.use_gpu: 62 | label = torch.LongTensor(label).cuda() 63 | else: 64 | label = 
torch.LongTensor(label) 65 | loss.backward() 66 | optimizer.step() 67 | total_loss += loss.item() 68 | 69 | # if idx % 100 == 99: 70 | # print('{}: Train iter: {} finish'.format(now(), idx)) 71 | 72 | if epoch > 2: 73 | # true_y, pred_y, pred_p= predict(model, test_data_loader) 74 | # all_pre, all_rec = eval_metric(true_y, pred_y, pred_p) 75 | pred_res, p_num = predict_var(model, test_data_loader) 76 | all_pre, all_rec = eval_metric_var(pred_res, p_num) 77 | 78 | last_pre, last_rec = all_pre[-1], all_rec[-1] 79 | if last_pre > 0.24 and last_rec > 0.24: 80 | save_pr(opt.result_dir, model.model_name, epoch, all_pre, all_rec, opt=opt.print_opt) 81 | print('{} Epoch {} save pr'.format(now(), epoch + 1)) 82 | 83 | print('{} Epoch {}/{}: train loss: {}; test precision: {}, test recall {}'.format(now(), epoch + 1, opt.num_epochs, total_loss, last_pre, last_rec)) 84 | else: 85 | print('{} Epoch {}/{}: train loss: {};'.format(now(), epoch + 1, opt.num_epochs, total_loss)) 86 | 87 | 88 | def predict_var(model, test_data_loader): 89 | ''' 90 | Apply the prediction method in Lin 2016 91 | ''' 92 | model.eval() 93 | 94 | res = [] 95 | true_y = [] 96 | for idx, (data, labels) in enumerate(test_data_loader): 97 | out = model(data) 98 | true_y.extend(labels) 99 | if opt.use_gpu: 100 | # out = map(lambda o: o.data.cpu().numpy().tolist(), out) 101 | out = out.data.cpu().numpy().tolist() 102 | else: 103 | # out = map(lambda o: o.data.numpy().tolist(), out) 104 | out = out.data.numpy().tolist() 105 | 106 | for r in range(1, opt.rel_num): 107 | for j in range(len(out[0])): 108 | res.append([labels[j], r, out[r][j]]) 109 | 110 | # if idx % 100 == 99: 111 | # print('{} Eval: iter {}'.format(now(), idx)) 112 | 113 | model.train() 114 | positive_num = len([i for i in true_y if i[0] > 0]) 115 | return res, positive_num 116 | 117 | 118 | def eval_metric_var(pred_res, p_num): 119 | ''' 120 | Apply the evalation method in Lin 2016 121 | ''' 122 | 123 | pred_res_sort = sorted(pred_res, key=lambda x: -x[2]) 124 | correct = 0.0 125 | all_pre = [] 126 | all_rec = [] 127 | 128 | for i in range(2000): 129 | true_y = pred_res_sort[i][0] 130 | pred_y = pred_res_sort[i][1] 131 | for j in true_y: 132 | if pred_y == j: 133 | correct += 1 134 | break 135 | precision = correct / (i + 1) 136 | recall = correct / p_num 137 | all_pre.append(precision) 138 | all_rec.append(recall) 139 | 140 | print("positive_num: {}; correct: {}".format(p_num, correct)) 141 | return all_pre, all_rec 142 | 143 | 144 | def predict(model, test_data_loader): 145 | ''' 146 | Apply the prediction method in Zeng 2015 147 | ''' 148 | 149 | model.eval() 150 | 151 | pred_y = [] 152 | true_y = [] 153 | pred_p = [] 154 | for idx, (data, labels) in enumerate(test_data_loader): 155 | true_y.extend(labels) 156 | out = model(data) 157 | res = torch.max(out, 1) 158 | if model.opt.use_gpu: 159 | pred_y.extend(res[1].data.cpu().numpy().tolist()) 160 | pred_p.extend(res[0].data.cpu().numpy().tolist()) 161 | else: 162 | pred_y.extend(res[1].data.numpy().tolist()) 163 | pred_p.extend(res[0].data.numpy().tolist()) 164 | # if idx % 100 == 99: 165 | # print('{} Eval: iter {}'.format(now(), idx)) 166 | 167 | size = len(test_data_loader.dataset) 168 | assert len(pred_y) == size and len(true_y) == size 169 | assert len(pred_y) == len(pred_p) 170 | model.train() 171 | return true_y, pred_y, pred_p 172 | 173 | 174 | if __name__ == "__main__": 175 | import fire 176 | fire.Fire() 177 | -------------------------------------------------------------------------------- /main_mil.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from config import opt 4 | import models 5 | import dataset 6 | import torch 7 | import numpy as np 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | import torch.optim as optim 11 | import torch.nn.functional as F 12 | from utils import save_pr, now, eval_metric 13 | 14 | 15 | def collate_fn(batch): 16 | data, label = zip(*batch) 17 | return data, label 18 | 19 | 20 | def test(**kwargs): 21 | pass 22 | 23 | 24 | def setup_seed(seed): 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | np.random.seed(seed) 28 | torch.backends.cudnn.deterministic = True 29 | 30 | 31 | def train(**kwargs): 32 | 33 | setup_seed(opt.seed) 34 | 35 | kwargs.update({'model': 'PCNN_ONE'}) 36 | opt.parse(kwargs) 37 | 38 | if opt.use_gpu: 39 | torch.cuda.set_device(opt.gpu_id) 40 | 41 | # torch.manual_seed(opt.seed) 42 | model = getattr(models, 'PCNN_ONE')(opt) 43 | if opt.use_gpu: 44 | # torch.cuda.manual_seed_all(opt.seed) 45 | model.cuda() 46 | # parallel 47 | # model = nn.DataParallel(model) 48 | 49 | # loading data 50 | DataModel = getattr(dataset, opt.data + 'Data') 51 | train_data = DataModel(opt.data_root, train=True) 52 | train_data_loader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers, collate_fn=collate_fn) 53 | 54 | test_data = DataModel(opt.data_root, train=False) 55 | test_data_loader = DataLoader(test_data, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers, collate_fn=collate_fn) 56 | print('train data: {}; test data: {}'.format(len(train_data), len(test_data))) 57 | 58 | criterion = nn.CrossEntropyLoss() 59 | optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, model.parameters()), rho=1.0, eps=1e-6, weight_decay=opt.weight_decay) 60 | # optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.lr, betas=(0.9, 0.999), weight_decay=opt.weight_decay) 61 | # optimizer = optim.Adadelta(model.parameters(), rho=1.0, eps=1e-6, weight_decay=opt.weight_decay) 62 | # train 63 | print("start training...") 64 | max_pre = -1.0 65 | max_rec = -1.0 66 | for epoch in range(opt.num_epochs): 67 | 68 | total_loss = 0 69 | for idx, (data, label_set) in enumerate(train_data_loader): 70 | label = [l[0] for l in label_set] 71 | 72 | if opt.use_gpu: 73 | label = torch.LongTensor(label).cuda() 74 | else: 75 | label = torch.LongTensor(label) 76 | 77 | data = select_instance(model, data, label) 78 | model.batch_size = opt.batch_size 79 | 80 | optimizer.zero_grad() 81 | 82 | out = model(data, train=True) 83 | loss = criterion(out, label) 84 | loss.backward() 85 | optimizer.step() 86 | 87 | total_loss += loss.item() 88 | 89 | if epoch < -1: 90 | continue 91 | true_y, pred_y, pred_p = predict(model, test_data_loader) 92 | all_pre, all_rec, fp_res = eval_metric(true_y, pred_y, pred_p) 93 | 94 | last_pre, last_rec = all_pre[-1], all_rec[-1] 95 | if last_pre > 0.24 and last_rec > 0.24: 96 | save_pr(opt.result_dir, model.model_name, epoch, all_pre, all_rec, fp_res, opt=opt.print_opt) 97 | print('{} Epoch {} save pr'.format(now(), epoch + 1)) 98 | if last_pre > max_pre and last_rec > max_rec: 99 | print("save model") 100 | max_pre = last_pre 101 | max_rec = last_rec 102 | model.save(opt.print_opt) 103 | 104 | print('{} Epoch {}/{}: train loss: {}; test precision: {}, test recall {}'.format(now(), epoch + 1, opt.num_epochs, total_loss, last_pre, last_rec)) 105 | 106 | 107 | def 
select_instance(model, batch_data, labels): 108 | 109 | model.eval() 110 | select_ent = [] 111 | select_num = [] 112 | select_sen = [] 113 | select_pf = [] 114 | select_pool = [] 115 | select_mask = [] 116 | for idx, bag in enumerate(batch_data): 117 | insNum = bag[1] 118 | label = labels[idx] 119 | max_ins_id = 0 120 | if insNum > 1: 121 | model.batch_size = insNum 122 | if opt.use_gpu: 123 | data = map(lambda x: torch.LongTensor(x).cuda(), bag) 124 | else: 125 | data = map(lambda x: torch.LongTensor(x), bag) 126 | 127 | out = model(data) 128 | 129 | # max_ins_id = torch.max(torch.max(out, 1)[0], 0)[1] 130 | max_ins_id = torch.max(out[:, label], 0)[1] 131 | 132 | if opt.use_gpu: 133 | # max_ins_id = max_ins_id.data.cpu().numpy()[0] 134 | max_ins_id = max_ins_id.item() 135 | else: 136 | max_ins_id = max_ins_id.data.numpy()[0] 137 | 138 | max_sen = bag[2][max_ins_id] 139 | max_pf = bag[3][max_ins_id] 140 | max_pool = bag[4][max_ins_id] 141 | max_mask = bag[5][max_ins_id] 142 | 143 | select_ent.append(bag[0]) 144 | select_num.append(bag[1]) 145 | select_sen.append(max_sen) 146 | select_pf.append(max_pf) 147 | select_pool.append(max_pool) 148 | select_mask.append(max_mask) 149 | 150 | if opt.use_gpu: 151 | data = map(lambda x: torch.LongTensor(x).cuda(), [select_ent, select_num, select_sen, select_pf, select_pool, select_mask]) 152 | else: 153 | data = map(lambda x: torch.LongTensor(x), [select_ent, select_num, select_sen, select_pf, select_pool, select_mask]) 154 | 155 | model.train() 156 | return data 157 | 158 | 159 | def predict(model, test_data_loader): 160 | 161 | model.eval() 162 | 163 | pred_y = [] 164 | true_y = [] 165 | pred_p = [] 166 | for idx, (data, labels) in enumerate(test_data_loader): 167 | true_y.extend(labels) 168 | for bag in data: 169 | insNum = bag[1] 170 | model.batch_size = insNum 171 | if opt.use_gpu: 172 | data = map(lambda x: torch.LongTensor(x).cuda(), bag) 173 | else: 174 | data = map(lambda x: torch.LongTensor(x), bag) 175 | 176 | out = model(data) 177 | out = F.softmax(out, 1) 178 | max_ins_prob, max_ins_label = map(lambda x: x.data.cpu().numpy(), torch.max(out, 1)) 179 | tmp_prob = -1.0 180 | tmp_NA_prob = -1.0 181 | pred_label = 0 182 | pos_flag = False 183 | 184 | for i in range(insNum): 185 | if pos_flag and max_ins_label[i] < 1: 186 | continue 187 | else: 188 | if max_ins_label[i] > 0: 189 | pos_flag = True 190 | if max_ins_prob[i] > tmp_prob: 191 | pred_label = max_ins_label[i] 192 | tmp_prob = max_ins_prob[i] 193 | else: 194 | if max_ins_prob[i] > tmp_NA_prob: 195 | tmp_NA_prob = max_ins_prob[i] 196 | 197 | if pos_flag: 198 | pred_p.append(tmp_prob) 199 | else: 200 | pred_p.append(tmp_NA_prob) 201 | 202 | pred_y.append(pred_label) 203 | 204 | size = len(test_data_loader.dataset) 205 | assert len(pred_y) == size and len(true_y) == size 206 | 207 | model.train() 208 | return true_y, pred_y, pred_p 209 | 210 | 211 | if __name__ == "__main__": 212 | import fire 213 | fire.Fire() 214 | -------------------------------------------------------------------------------- /models/BasicModule.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import time 5 | 6 | 7 | class BasicModule(torch.nn.Module): 8 | ''' 9 | 封装了nn.Module,主要是提供了save和load两个方法 10 | ''' 11 | 12 | def __init__(self): 13 | super(BasicModule, self).__init__() 14 | self.model_name=str(type(self)) # model name 15 | 16 | def load(self, path): 17 | ''' 18 | 可加载指定路径的模型 19 | ''' 20 | self.load_state_dict(torch.load(path)) 
21 | 22 | def save(self, name=None): 23 | ''' 24 | 保存模型,默认使用“模型名字+时间”作为文件名 25 | ''' 26 | prefix = 'checkpoints/' 27 | if name is None: 28 | name = prefix + self.model_name + '_' 29 | name = time.strftime(name + '%m%d_%H:%M:%S.pth') 30 | else: 31 | name = prefix + self.model_name + '_' + str(name)+ '.pth' 32 | torch.save(self.state_dict(), name) 33 | return name 34 | -------------------------------------------------------------------------------- /models/PCNN_ATT.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .BasicModule import BasicModule 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | 11 | class PCNN_ATT(BasicModule): 12 | ''' 13 | Lin 2016 Att PCNN 14 | ''' 15 | def __init__(self, opt): 16 | super(PCNN_ATT, self).__init__() 17 | 18 | self.opt = opt 19 | self.model_name = 'PCNN_ATT' 20 | self.test_scale_p = 0.5 21 | 22 | self.word_embs = nn.Embedding(self.opt.vocab_size, self.opt.word_dim) 23 | self.pos1_embs = nn.Embedding(self.opt.pos_size, self.opt.pos_dim) 24 | self.pos2_embs = nn.Embedding(self.opt.pos_size, self.opt.pos_dim) 25 | 26 | all_filter_num = self.opt.filters_num * len(self.opt.filters) 27 | 28 | rel_dim = all_filter_num 29 | 30 | if self.opt.use_pcnn: 31 | rel_dim = all_filter_num * 3 32 | masks = torch.LongTensor(([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])) 33 | if self.opt.use_gpu: 34 | masks = masks.cuda() 35 | self.mask_embedding = nn.Embedding(4, 3) 36 | self.mask_embedding.weight.data.copy_(masks) 37 | self.mask_embedding.weight.requires_grad = False 38 | 39 | self.rel_embs = nn.Parameter(torch.randn(self.opt.rel_num, rel_dim)) 40 | self.rel_bias = nn.Parameter(torch.randn(self.opt.rel_num)) 41 | 42 | # the Relation-Specific Attention Diagonal Matrix 43 | # self.att_w = nn.ParameterList([nn.Parameter(torch.eye(rel_dim)) for _ in range(self.opt.rel_num)]) 44 | 45 | # Conv filter width 46 | feature_dim = self.opt.word_dim + self.opt.pos_dim * 2 47 | 48 | # option for multi size filter 49 | # here is only a kind of filter with height = 3 50 | self.convs = nn.ModuleList([nn.Conv2d(1, self.opt.filters_num, (k, feature_dim), padding=(int(k / 2), 0)) for k in self.opt.filters]) 51 | self.dropout = nn.Dropout(self.opt.drop_out) 52 | 53 | self.init_model_weight() 54 | self.init_word_emb() 55 | 56 | def init_model_weight(self): 57 | ''' 58 | use xavier to init 59 | ''' 60 | nn.init.xavier_uniform(self.rel_embs) 61 | nn.init.uniform(self.rel_bias) 62 | for conv in self.convs: 63 | nn.init.xavier_uniform(conv.weight) 64 | nn.init.uniform(conv.bias) 65 | 66 | def init_word_emb(self): 67 | 68 | def p_2norm(path): 69 | v = torch.from_numpy(np.load(path)) 70 | if self.opt.norm_emb: 71 | v = torch.div(v, v.norm(2, 1).unsqueeze(1)) 72 | v[v != v] = 0.0 73 | return v 74 | 75 | w2v = p_2norm(self.opt.w2v_path) 76 | p1_2v = p_2norm(self.opt.p1_2v_path) 77 | p2_2v = p_2norm(self.opt.p2_2v_path) 78 | 79 | if self.opt.use_gpu: 80 | self.word_embs.weight.data.copy_(w2v.cuda()) 81 | self.pos1_embs.weight.data.copy_(p1_2v.cuda()) 82 | self.pos2_embs.weight.data.copy_(p2_2v.cuda()) 83 | else: 84 | self.pos1_embs.weight.data.copy_(p1_2v) 85 | self.pos2_embs.weight.data.copy_(p2_2v) 86 | self.word_embs.weight.data.copy_(w2v) 87 | 88 | def init_int_constant(self, num): 89 | ''' 90 | a util function for generating a LongTensor Variable 91 | ''' 92 | if self.opt.use_gpu: 93 | return 
Variable(torch.LongTensor([num]).cuda()) 94 | else: 95 | return Variable(torch.LongTensor([num])) 96 | 97 | def mask_piece_pooling(self, x, mask): 98 | ''' 99 | refer: https://github.com/thunlp/OpenNRE 100 | A fast piecewise pooling using mask 101 | ''' 102 | x = x.unsqueeze(-1).permute(0, 2, 1, 3) 103 | masks = self.mask_embedding(mask).unsqueeze(-2) * 100 104 | x = masks + x 105 | x = torch.max(x, 1)[0] - 100 106 | return x.view(-1, x.size(1) * x.size(2)) 107 | 108 | def piece_max_pooling(self, x, insPool): 109 | ''' 110 | piecewise pool into 3 segements 111 | x: the batch data 112 | insPool: the batch Pool 113 | ''' 114 | split_batch_x = torch.split(x, 1, 0) 115 | split_pool = torch.split(insPool, 1, 0) 116 | batch_res = [] 117 | 118 | for i in range(len(split_pool)): 119 | ins = split_batch_x[i].squeeze() # all_filter_num * max_len 120 | pool = split_pool[i].squeeze().data # 2 121 | seg_1 = ins[:, :pool[0]].max(1)[0].unsqueeze(1) # all_filter_num * 1 122 | seg_2 = ins[:, pool[0]: pool[1]].max(1)[0].unsqueeze(1) # all_filter_num * 1 123 | seg_3 = ins[:, pool[1]:].max(1)[0].unsqueeze(1) 124 | piece_max_pool = torch.cat([seg_1, seg_2, seg_3], 1).view(1, -1) # 1 * 3all_filter_num 125 | batch_res.append(piece_max_pool) 126 | 127 | out = torch.cat(batch_res, 0) 128 | assert out.size(1) == 3 * self.opt.filters_num 129 | return out 130 | 131 | def forward(self, x, label=None): 132 | 133 | # get all sentences embedding in all bags of one batch 134 | self.bags_feature = self.get_bags_feature(x) 135 | 136 | if label is None: 137 | # for test 138 | assert self.training is False 139 | return self.test(x) 140 | else: 141 | # for train 142 | assert self.training is True 143 | return self.fit(x, label) 144 | 145 | def fit(self, x, label): 146 | ''' 147 | train process 148 | ''' 149 | x = self.get_batch_feature(label) # batch_size * sentence_feature_num 150 | x = self.dropout(x) 151 | out = x.mm(self.rel_embs.t()) + self.rel_bias # o = Ms + d (formual 10 in paper) 152 | 153 | if self.opt.use_gpu: 154 | v_label = torch.LongTensor(label).cuda() 155 | else: 156 | v_label = torch.LongTensor(label) 157 | ce_loss = F.cross_entropy(out, Variable(v_label)) 158 | return ce_loss 159 | 160 | def test(self, x): 161 | ''' 162 | test process 163 | ''' 164 | pre_y = [] 165 | for label in range(0, self.opt.rel_num): 166 | labels = [label for _ in range(len(x))] # generate the batch labels 167 | bags_feature = self.get_batch_feature(labels) 168 | out = self.test_scale_p * bags_feature.mm(self.rel_embs.t()) + self.rel_bias 169 | # out = F.softmax(out, 1) 170 | # pre_y.append(out[:, label]) 171 | pre_y.append(out.unsqueeze(1)) 172 | 173 | # return pre_y 174 | res = torch.cat(pre_y, 1).max(1)[0] 175 | return F.softmax(res, 1).t() 176 | 177 | def get_batch_feature(self, labels): 178 | ''' 179 | Using Attention to get all bags embedding in a batch 180 | ''' 181 | batch_feature = [] 182 | 183 | for bag_embs, label in zip(self.bags_feature, labels): 184 | # calculate the weight: xAr or xr 185 | alpha = bag_embs.mm(self.rel_embs[label].view(-1, 1)) 186 | # alpha = bag_embs.mm(self.att_w[label]).mm(self.rel_embs[label].view(-1, 1)) 187 | bag_embs = bag_embs * F.softmax(alpha, 0) 188 | bag_vec = torch.sum(bag_embs, 0) 189 | batch_feature.append(bag_vec.unsqueeze(0)) 190 | 191 | return torch.cat(batch_feature, 0) 192 | 193 | def get_bags_feature(self, bags): 194 | ''' 195 | get all bags embedding in one batch before Attention 196 | ''' 197 | bags_feature = [] 198 | for bag in bags: 199 | if self.opt.use_gpu: 200 | data = 
map(lambda x: Variable(torch.LongTensor(x).cuda()), bag) 201 | else: 202 | data = map(lambda x: Variable(torch.LongTensor(x)), bag) 203 | 204 | bag_embs = self.get_ins_emb(data) # get all instances embedding in one bag 205 | bags_feature.append(bag_embs) 206 | 207 | return bags_feature 208 | 209 | def get_ins_emb(self, x): 210 | ''' 211 | x: all instance in a Bag 212 | ''' 213 | insEnt, _, insX, insPFs, insPool, mask = x 214 | insPF1, insPF2 = [i.squeeze(1) for i in torch.split(insPFs, 1, 1)] 215 | 216 | word_emb = self.word_embs(insX) 217 | pf1_emb = self.pos1_embs(insPF1) 218 | pf2_emb = self.pos2_embs(insPF2) 219 | 220 | x = torch.cat([word_emb, pf1_emb, pf2_emb], 2) # insNum * 1 * maxLen * (word_dim + 2pos_dim) 221 | x = x.unsqueeze(1) # insNum * 1 * maxLen * (word_dim + 2pos_dim) 222 | x = [conv(x).squeeze(3) for conv in self.convs] 223 | x = [self.mask_piece_pooling(i, mask) for i in x] 224 | # x = [self.piece_max_pooling(i, insPool) for i in x] 225 | x = torch.cat(x, 1).tanh() 226 | return x 227 | -------------------------------------------------------------------------------- /models/PCNN_ONE.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .BasicModule import BasicModule 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class PCNN_ONE(BasicModule): 11 | ''' 12 | Zeng 2015 DS PCNN 13 | ''' 14 | def __init__(self, opt): 15 | super(PCNN_ONE, self).__init__() 16 | 17 | self.opt = opt 18 | 19 | self.model_name = 'PCNN_ONE' 20 | 21 | self.word_embs = nn.Embedding(self.opt.vocab_size, self.opt.word_dim) 22 | self.pos1_embs = nn.Embedding(self.opt.pos_size, self.opt.pos_dim) 23 | self.pos2_embs = nn.Embedding(self.opt.pos_size, self.opt.pos_dim) 24 | 25 | feature_dim = self.opt.word_dim + self.opt.pos_dim * 2 26 | 27 | # for more filter size 28 | self.convs = nn.ModuleList([nn.Conv2d(1, self.opt.filters_num, (k, feature_dim), padding=(int(k / 2), 0)) for k in self.opt.filters]) 29 | 30 | all_filter_num = self.opt.filters_num * len(self.opt.filters) 31 | 32 | if self.opt.use_pcnn: 33 | all_filter_num = all_filter_num * 3 34 | masks = torch.FloatTensor(([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])) 35 | if self.opt.use_gpu: 36 | masks = masks.cuda() 37 | self.mask_embedding = nn.Embedding(4, 3) 38 | self.mask_embedding.weight.data.copy_(masks) 39 | self.mask_embedding.weight.requires_grad = False 40 | 41 | self.linear = nn.Linear(all_filter_num, self.opt.rel_num) 42 | self.dropout = nn.Dropout(self.opt.drop_out) 43 | 44 | self.init_model_weight() 45 | self.init_word_emb() 46 | 47 | def init_model_weight(self): 48 | ''' 49 | use xavier to init 50 | ''' 51 | for conv in self.convs: 52 | nn.init.xavier_uniform_(conv.weight) 53 | nn.init.constant_(conv.bias, 0.0) 54 | 55 | nn.init.xavier_uniform_(self.linear.weight) 56 | nn.init.constant_(self.linear.bias, 0.0) 57 | 58 | def init_word_emb(self): 59 | 60 | def p_2norm(path): 61 | v = torch.from_numpy(np.load(path)) 62 | if self.opt.norm_emb: 63 | v = torch.div(v, v.norm(2, 1).unsqueeze(1)) 64 | v[v != v] = 0.0 65 | return v 66 | 67 | w2v = p_2norm(self.opt.w2v_path) 68 | p1_2v = p_2norm(self.opt.p1_2v_path) 69 | p2_2v = p_2norm(self.opt.p2_2v_path) 70 | 71 | if self.opt.use_gpu: 72 | self.word_embs.weight.data.copy_(w2v.cuda()) 73 | self.pos1_embs.weight.data.copy_(p1_2v.cuda()) 74 | self.pos2_embs.weight.data.copy_(p2_2v.cuda()) 75 | else: 76 | self.pos1_embs.weight.data.copy_(p1_2v) 77 | 
self.pos2_embs.weight.data.copy_(p2_2v) 78 | self.word_embs.weight.data.copy_(w2v) 79 | 80 | def mask_piece_pooling(self, x, mask): 81 | ''' 82 | refer: https://github.com/thunlp/OpenNRE 83 | A fast piecewise pooling using mask 84 | ''' 85 | x = x.unsqueeze(-1).permute(0, 2, 1, -1) 86 | masks = self.mask_embedding(mask).unsqueeze(-2) * 100 87 | x = masks.float() + x 88 | x = torch.max(x, 1)[0] - torch.FloatTensor([100]).cuda() 89 | x = x.view(-1, x.size(1) * x.size(2)) 90 | return x 91 | 92 | def piece_max_pooling(self, x, insPool): 93 | ''' 94 | old version piecewise 95 | ''' 96 | split_batch_x = torch.split(x, 1, 0) 97 | split_pool = torch.split(insPool, 1, 0) 98 | batch_res = [] 99 | for i in range(len(split_pool)): 100 | ins = split_batch_x[i].squeeze() # all_filter_num * max_len 101 | pool = split_pool[i].squeeze().data # 2 102 | seg_1 = ins[:, :pool[0]].max(1)[0].unsqueeze(1) # all_filter_num * 1 103 | seg_2 = ins[:, pool[0]: pool[1]].max(1)[0].unsqueeze(1) # all_filter_num * 1 104 | seg_3 = ins[:, pool[1]:].max(1)[0].unsqueeze(1) 105 | piece_max_pool = torch.cat([seg_1, seg_2, seg_3], 1).view(1, -1) # 1 * 3all_filter_num 106 | batch_res.append(piece_max_pool) 107 | 108 | out = torch.cat(batch_res, 0) 109 | assert out.size(1) == 3 * self.opt.filters_num 110 | return out 111 | 112 | def forward(self, x, train=False): 113 | 114 | insEnt, _, insX, insPFs, insPool, insMasks = x 115 | insPF1, insPF2 = [i.squeeze(1) for i in torch.split(insPFs, 1, 1)] 116 | 117 | word_emb = self.word_embs(insX) 118 | pf1_emb = self.pos1_embs(insPF1) 119 | pf2_emb = self.pos2_embs(insPF2) 120 | 121 | x = torch.cat([word_emb, pf1_emb, pf2_emb], 2) 122 | x = x.unsqueeze(1) 123 | x = self.dropout(x) 124 | 125 | x = [conv(x).squeeze(3) for conv in self.convs] 126 | if self.opt.use_pcnn: 127 | x = [self.mask_piece_pooling(i, insMasks) for i in x] 128 | # x = [self.piece_max_pooling(i, insPool) for i in x] 129 | else: 130 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] 131 | x = torch.cat(x, 1).tanh() 132 | x = self.dropout(x) 133 | x = self.linear(x) 134 | 135 | return x 136 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .PCNN_ONE import PCNN_ONE 4 | from .PCNN_ATT import PCNN_ATT 5 | -------------------------------------------------------------------------------- /plot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 27, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import os\n", 13 | "import matplotlib as mpl\n", 14 | "# mpl.rcParams['figure.figsize'] = (15,15)\n", 15 | "\n", 16 | "mpl.rcParams['figure.figsize'] = (9,6)\n", 17 | "plt.ioff()\n", 18 | "\n", 19 | "color = [ 'grey', 'r', 'b', 'black','teal','cornflowerblue', 'g', 'gray', 'c', 'r','m', 'y', 'k']\n", 20 | "label_font_size = 18\n", 21 | "marker = ['>', 'v', '^', 'o', 's']\n", 22 | "\n", 23 | "d = 'out/'\n", 24 | "def getXY(file_name, n = 3000):\n", 25 | " x, y = [], []\n", 26 | " with open(file_name) as f:\n", 27 | " for i, line in enumerate(f):\n", 28 | " entries = line.split()\n", 29 | " y.append(float(entries[0]))\n", 30 | " x.append(float(entries[1]))\n", 31 | " if i >= n:\n", 32 | " break\n", 33 | " return x, y\n", 34 | "\n", 35 | "def 
plot_mul(files):\n", 36 | " '''\n", 37 | " 对比不同模型的PR曲线\n", 38 | " 传入列表,元素为文件完整名\n", 39 | " '''\n", 40 | " print d\n", 41 | " for i, f in enumerate(files):\n", 42 | " path = './{}/{}'.format(d,f)\n", 43 | " if not os.path.exists(path):\n", 44 | " print('{} is not exists'.format(f))\n", 45 | " continue\n", 46 | "\n", 47 | " x, y = getXY(path)\n", 48 | " plt.plot(x,y, marker = marker[i%len(marker)], markevery = 100, markersize = 5, color = color[i%len(color)])\n", 49 | " legend = ['_'.join(i.split('_')[:-2]) for i in files]\n", 50 | " plt.legend(legend, prop={'size':12})\n", 51 | " plt.ylim([0.3, 1])\n", 52 | " plt.xlim([0.0, 0.5])\n", 53 | " plt.xlabel('Recall', fontsize=label_font_size)\n", 54 | " plt.ylabel('Precision', fontsize=label_font_size)\n", 55 | " plt.gca().tick_params(labelsize=16)\n", 56 | " plt.grid(linestyle='dashdot')\n", 57 | " plt.show() \n", 58 | " \n", 59 | "\n", 60 | "def plot_one(prefix, flag=True):\n", 61 | " '''\n", 62 | " 绘制同一个模型不同epoch的PR曲线\n", 63 | " 传入模型前缀即可(如: PCNN_ATT_DEF)\n", 64 | " '''\n", 65 | " fid = []\n", 66 | " print d\n", 67 | " for i in range(1, 18):\n", 68 | " if flag:\n", 69 | " path = './{}/{}_{}_PR.txt'.format(d, prefix, i)\n", 70 | " else:\n", 71 | " path = './{}/{}_{}.txt'.format(d, prefix, i)\n", 72 | " if not os.path.exists(path):\n", 73 | " #print path\n", 74 | " continue\n", 75 | " fid.append(i)\n", 76 | " x, y = getXY(path)\n", 77 | " plt.plot(x,y, marker = '>', markevery = 100, markersize = 5, color = color[(i-1)%len(color)])\n", 78 | " \n", 79 | " plt.legend([prefix + str(i) for i in fid],prop={'size':10})\n", 80 | " plt.ylim([0.2, 1])\n", 81 | " plt.xlim([0.0, 0.5])\n", 82 | " plt.xlabel('Recall', fontsize=label_font_size,)\n", 83 | " plt.ylabel('Precision', fontsize=label_font_size)\n", 84 | " plt.gca().tick_params(labelsize=16)\n", 85 | " plt.grid(linestyle='dashdot')\n", 86 | " plt.show()" 87 | ] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "Python 2", 93 | "language": "python", 94 | "name": "python2" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": "ipython", 99 | "version": 2 100 | }, 101 | "file_extension": ".py", 102 | "mimetype": "text/x-python", 103 | "name": "python", 104 | "nbconvert_exporter": "python", 105 | "pygments_lexer": "ipython2", 106 | "version": "2.7.12" 107 | } 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 2 111 | } 112 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import time 5 | 6 | 7 | def now(): 8 | return str(time.strftime('%Y-%m-%d %H:%M:%S')) 9 | 10 | 11 | def save_pr(out_dir, name, epoch, pre, rec, fp_res=None, opt=None): 12 | if opt is None: 13 | out = open('{}/{}_{}_PR.txt'.format(out_dir, name, epoch + 1), 'w') 14 | else: 15 | out = open('{}/{}_{}_{}_PR.txt'.format(out_dir, name, opt, epoch + 1), 'w') 16 | 17 | if fp_res is not None: 18 | fp_out = open('{}/{}_{}_FP.txt'.format(out_dir, name, epoch + 1), 'w') 19 | for idx, r, p in fp_res: 20 | fp_out.write('{} {} {}\n'.format(idx, r, p)) 21 | fp_out.close() 22 | 23 | for p, r in zip(pre, rec): 24 | out.write('{} {}\n'.format(p, r)) 25 | 26 | out.close() 27 | 28 | 29 | def eval_metric(true_y, pred_y, pred_p): 30 | ''' 31 | calculate the precision and recall for p-r curve 32 | reglect the NA relation 33 | ''' 34 | assert len(true_y) == len(pred_y) 35 | positive_num = len([i for i in true_y if i[0] > 0]) 36 | index 
= np.argsort(pred_p)[::-1] 37 | 38 | tp = 0 39 | fp = 0 40 | fn = 0 41 | all_pre = [0] 42 | all_rec = [0] 43 | fp_res = [] 44 | 45 | for idx in range(len(true_y)): 46 | i = true_y[index[idx]] 47 | j = pred_y[index[idx]] 48 | 49 | if i[0] == 0: # NA relation 50 | if j > 0: 51 | fp_res.append((index[idx], j, pred_p[index[idx]])) 52 | fp += 1 53 | else: 54 | if j == 0: 55 | fn += 1 56 | else: 57 | for k in i: 58 | if k == -1: 59 | break 60 | if k == j: 61 | tp += 1 62 | break 63 | 64 | if fp + tp == 0: 65 | precision = 1.0 66 | else: 67 | precision = tp * 1.0 / (tp + fp) 68 | recall = tp * 1.0 / positive_num 69 | if precision != all_pre[-1] or recall != all_rec[-1]: 70 | all_pre.append(precision) 71 | all_rec.append(recall) 72 | 73 | print("tp={}; fp={}; fn={}; positive_num={}".format(tp, fp, fn, positive_num)) 74 | return all_pre[1:], all_rec[1:], fp_res 75 | --------------------------------------------------------------------------------