├── .gitignore
├── README.md
├── svm_model.py
├── token.py
├── ace_data.py
├── feature_extract.py
└── ace_filereader.py

/.gitignore:
--------------------------------------------------------------------------------
/*.pyc
/*.pkl
/Chinese

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Requirements
Anaconda (Python 2.7; the scripts use Python 2 syntax)
jieba
scikit-learn

# Run
Step 1: copy the `Chinese` directory from [ACE 2005 (LDC2006T06)](https://catalog.ldc.upenn.edu/LDC2006T06) into this folder.

Step 2: extract the features and write `feature.pkl`:
```bash
python feature_extract.py
```

Step 3: train the SVM model and evaluate it:
```bash
python svm_model.py
```

--------------------------------------------------------------------------------
/svm_model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 03 16:32:36 2017

@author: DaCapo
"""

import pickle
import numpy as np
import matplotlib.pyplot as plt  # optional, for plotting results
from sklearn import svm
from sklearn import preprocessing
from sklearn import metrics


def softmax(x):
    # row-wise softmax over the decision-function scores
    x -= np.max(x, axis=-1, keepdims=True)
    x = np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)
    return x

def computsx(sx, alpha):
    # micro-averaged F1 on the test samples whose top softmax score exceeds alpha
    prex = np.argmax(sx, axis=1)
    sxx = sx > alpha
    scc = np.sum(sxx, axis=1)
    return metrics.f1_score(y_test[scc > 0], prex[scc > 0], average="micro")
#    print(metrics.classification_report(y_test[scc>0], prex[scc>0]))

with open('feature.pkl', 'rb') as f:
    data = pickle.load(f)

X = np.array([x[0] for x in data["features"].values()])
y = np.array([x[1] for x in data["features"].values()])

X_scaled = preprocessing.scale(X)

#np.random.choice(5, 3)
X_train = X_scaled[:8000]
y_train = y[:8000]
X_test = X_scaled[8000:]
y_test = y[8000:]
print "initializing"
clf = svm.SVC(probability=True, decision_function_shape="ovr")
print "training"
clf.fit(X_train, y_train)
predicted = clf.predict(X_test)
#clf.predict_proba(X_test)
#print float(np.sum(clf.predict(X_test)==y_test))/float(len(y_test))
print(metrics.classification_report(y_test, predicted))
print(metrics.confusion_matrix(y_test, predicted))
sx = softmax(clf.decision_function(X_test))
#sx /= np.max(sx,axis=-1,keepdims=True)

# sweep the confidence threshold alpha and record micro-F1 at each value
f1_s = []
for i in range(1, 100):
    alpha = i / 100.0
    f1_s.append(computsx(sx, alpha))

--------------------------------------------------------------------------------
/token.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 30 17:04:58 2017

@author: DaCapo
"""

# NOTE: this module shadows the standard-library `token` module;
# feature_extract.py imports it as `import token`.

#categories = ['alt.atheism', 'soc.religion.christian',
#              'comp.graphics', 'sci.med']

#from sklearn.datasets import fetch_20newsgroups
#twenty_train = fetch_20newsgroups(subset='train',
#    categories=categories, shuffle=True, random_state=42)


import ace_data
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
import jieba
import numpy as np
import re

def has_chinese(s):
    """Return a match if the unicode string starts with a Chinese character."""
    return re.match(ur"[\u4e00-\u9fa5]", s)

def has_number(s):
    """Return a match if the unicode string starts with a digit."""
    return re.match(ur"[\u0030-\u0039]", s)

def has_alphabet(s):
    """Return a match if the unicode string starts with an English letter."""
    return re.match(ur"[\u0041-\u005a]", s) or re.match(ur"[\u0061-\u007a]", s)

def checklist(l, callback):
    return [x for x in l if callback(x[4])]


# Build the bag-of-words vocabulary from the segmented documents.
def get_tokens():
    vectorizer = CountVectorizer()
#    global data, cut_docs
    data = ace_data.load()
    docs = data["docs"]
    cut_docs = []
    BOW_data = []
    #nes = erdata.conbime_list_dic(data["nes"].values())
#    xxx = [x for x in nes.values() if has_number(x[4])]
    nes = data["nes"]

    print "loading dictionary"
    #for ne in nes:
    #    jieba.add_word(ne[4])
    #    BOW_data += ne[4]+" "

    print "segmenting documents"
    # word segmentation
    for k in docs:
        result = ' '.join(jieba.cut(docs[k]))
        cut_docs.append(result)
        this_nes = [x[1][4] for x in nes[k].items()]
        BOW_data.append(result + " " + " ".join(this_nes))

    print "building bag of words"
    X_train_counts = vectorizer.fit_transform(cut_docs)
    print X_train_counts.shape

    tfidf_transformer = TfidfTransformer()
    X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)
    X_train_tfidf.shape
    tf = np.sum(X_train_counts.A, axis=0)
    tokens = vectorizer.get_feature_names()
    tf_dict = {}
    for i in range(len(tokens)):
        if not has_number(tokens[i]):
            tf_dict[tokens[i]] = tf[i]

    # keep the 20,000 most frequent tokens
    dict_sorted = sorted(tf_dict.iteritems(), key=lambda d: d[1], reverse=True)
    tf_tokens = [key for key, value in dict_sorted][:20000]

    return tf_tokens

if __name__ == "__main__":
    tokens = get_tokens()
    # quick check that pure digits were filtered out of the vocabulary
    # (inline version of feature_extract.get_token, to avoid a circular import)
    print tokens.index(u"66") + 1 if u"66" in tokens else 0

--------------------------------------------------------------------------------
/ace_data.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 30 19:35:26 2017

@author: DaCapo
"""

import os
from xml.dom.minidom import parse
import xml.etree.ElementTree as ET
import pickle
import ace_filereader as fr

# Collect the raw document texts (deprecated)
def getText():
    filename = []
    data = []
    for root, dirs, files in os.walk("Chinese"):
        for fn in files:
            root_arr = root.split("\\")
            X = ""
            if(fn.find(".sgm") > 0 and root.find("\\adj") > 0):
                doc = parse(root + "\\" + fn)
                if(root_arr[1] == "wl"):
                    for turn in doc.getElementsByTagName('POST'):
                        X += turn.childNodes[-1].data.replace("\n", "").replace(" ", "")
                    data.append(X)
                    filename.append(fn)
                elif(root_arr[1] == "nw"):
                    for turn in doc.getElementsByTagName('TEXT'):
                        X += turn.childNodes[-1].data.replace("\n", "").replace(" ", "")
                    data.append(X)
                    filename.append(fn)

                else:
                    for turn in doc.getElementsByTagName('TURN'):
                        X += turn.childNodes[-1].data.replace("\n", "").replace(" ", "")
                    data.append(X)
                    filename.append(fn)
                assert(X != "")
    return dict(zip(filename, data))

# Walk the corpus and collect every entity, relation and document.
# NOTE: the walk assumes Windows-style path separators ("\\") and the
# adj/ subfolders of the ACE 2005 Chinese data.
def get_ERDs():
    E = {}
    R = {}
    D = {}
    for root, dirs, files in os.walk("Chinese"):
        for fn in files:
#            root_arr = root.split("\\")
#            fn_arr = fn.split(".")
            if(fn.find(".sgm") > 0 and root.find("\\adj") > 0):
                f_no = fn[0:-4]
                named_entities, rels, doc = fr.get_ERD(root + "\\" + f_no)
                E[f_no] = named_entities
                R[f_no] = rels
                D[f_no] = doc
    return E, R, D

# Load the ACE data, building the cache on first use.
def load():
    # build the cache file if it does not exist yet
    if not os.path.exists(r'nes_res.pkl'):
        create()
    # read the cache
    with open('nes_res.pkl', 'rb') as f:
        data = pickle.load(f)

    return data

# Parse the corpus and pickle the entities, relations and documents.
def create():
    nes, res, docs = get_ERDs()
#    texts = getText()

    with open('nes_res.pkl', 'wb') as output:
        pickle.dump({"nes": nes, "res": res, "docs": docs}, output)

# Merge a list of dicts into a single dict.
def conbime_list_dic(ld):
    return dict(pair for d in ld for pair in d.items())


if __name__ == "__main__":
    datatest = load()

#    for x in rl:
#        if x =="METONYMY":
#            print x

--------------------------------------------------------------------------------
/feature_extract.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 03 10:30:53 2017

@author: DaCapo
"""

import pickle

import ace_data
import jieba
import numpy as np
import re
import jieba.posseg as pseg
import token



# part-of-speech tag list; a tag's index in this list is used as the POS feature value
# (the commented-out line is an alternative, larger tag set)
#POS = u"eng Mg Rg a ad ag al an b begin bl c cc d dg dl e end f gb gc gg gi gm gp i j k l m mq n nba nbc nf ng nhd nhm nis nit nmc nnd nnt nr nr1 nr2 nrf nrj ns nsf nt ntc ntcb ntcf ntch nth nto nts ntu nz o p pba pbei q qt qv r rr ry rys ryt ryv rz rzs rzt rzv s t tg u ude1 ude2 ude3 udeng udh uguo ule ulian uls usuo uyy uzhe uzhi v vd vf vg vi vl vn vshi vx vyou w x y z"
POS = u"eng a ad ag an b c d df dg e f g h i j k l m mg mq n ng nr nrfg nrt ns nt nz o p q r rg rr rz s t tg u ud ug uj ul uv uz v vd vg vi vn vq x y z zg"
POS = POS.split(" ")

# Map a word to its 1-based index in the vocabulary (0 = out of vocabulary).
def get_token(s, tokens):
    return tokens.index(s) + 1 if s in tokens else 0

# Segment and POS-tag every document, recording each word's character offset.
def get_seg(docs):
    cut_result = {}
    for f_no in docs:
        doc = docs[f_no]
        st = []
        cutw = []
        offset = 0
        words = pseg.cut(doc)
        for word, flag in words:
            if not word == " ":
                st.append(offset)
                cutw.append((word, POS.index(flag)))
            offset += len(word)
        cut_result[f_no] = [st, cutw]
    return cut_result

# load the data
data = ace_data.load()
docs = data["docs"]
nes = data["nes"]
res = data["res"]

# load the vocabulary
tokens = token.get_tokens()

# segment the documents
seg_docs = get_seg(docs)

# index lists for entity types, entity subtypes, relation types and relation subtypes
el = sorted(list(set([x[1] for f in nes.values() for x in f.values()])))
esl = sorted(list(set([x[-1] for f in nes.values() for x in f.values()])))
rl = sorted(list(set([x[1] for f in res.values() for x in f.values()])))
rsl = sorted(list(set([x[2] for f in res.values() for x in f.values()])))

# feature extraction: for every relation mention, combine the entity types,
# subtypes, argument order and a window of w context words / POS tags per argument
w = 2
features = {}
for f_no in res:
    sts = seg_docs[f_no][0]
    words = seg_docs[f_no][1]
    for r_no in res[f_no]:
        e1 = nes[f_no][res[f_no][r_no][4]]
        e2 = nes[f_no][res[f_no][r_no][5]]
        st1 = e1[2]
        ed1 = e1[3]
        st2 = e2[2]
        ed2 = e2[3]
        sidx1 = sts.index(st1)
        eidx1 = sidx1
        while sts[eidx1 + 1] < ed1:
            eidx1 += 1
        sidx2 = sts.index(st2)
        eidx2 = sidx2
        while sts[eidx2 + 1] < ed2:
            eidx2 += 1

        # relative order of the two argument extents
        # (0: arg1 before arg2, 1: arg2 before arg1, 2: overlapping)
        if res[f_no][r_no][6][1] <= res[f_no][r_no][7][0]:
            order = 0
        elif res[f_no][r_no][7][1] <= res[f_no][r_no][6][0]:
            order = 1
        else:
            order = 2

        # combine the features
        feature = [el.index(e1[1]), el.index(e2[1]), esl.index(e1[-1]), esl.index(e2[-1]), order]
        e1_w, e1_t, e2_w, e2_t = [], [], [], []
        for i in range(w):
            e1_w.insert(0, get_token(words[sidx1 - i][0], tokens))
            e1_w.append(get_token(words[sidx1 + i][0], tokens))
            e1_t.insert(0, words[sidx1 - i][1])
            e1_t.append(words[eidx1 + i][1])
            e2_w.insert(0, get_token(words[sidx2 - i][0], tokens))
            e2_w.append(get_token(words[sidx2 + i][0], tokens))
            e2_t.insert(0, words[sidx2 - i][1])
            e2_t.append(words[eidx2 + i][1])

        feature.extend(e1_w)
        feature.extend(e1_t)
        feature.extend(e2_w)
        feature.extend(e2_t)

        features[r_no] = (feature, rl.index(res[f_no][r_no][1]), rsl.index(res[f_no][r_no][2]))

# save the extracted features
with open('feature.pkl', 'wb') as output:
    out_data = {"features": features,
                "POS": POS, "el": el, "esl": esl, "rl": rl, "rsl": rsl,
                "tokens": tokens}
    pickle.dump(out_data, output)

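Each value in the `"features"` dict written above is a tuple `(feature_vector, relation_type_index, relation_subtype_index)`, where the type indices point into the stored `rl` and `rsl` lists. A minimal sketch, not part of the repository, for sanity-checking `feature.pkl` before training:

```python
# Sanity-check feature.pkl: sample count, vector length and label names.
import pickle
import numpy as np

with open('feature.pkl', 'rb') as f:
    data = pickle.load(f)

X = np.array([v[0] for v in data["features"].values()])
y = np.array([v[1] for v in data["features"].values()])
print("%d samples, %d features each" % X.shape)
print("relation types: " + ", ".join(data["rl"]))
print("label counts: " + str(np.bincount(y)))
```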
--------------------------------------------------------------------------------
/ace_filereader.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 03 10:30:53 2017

@author: DaCapo
"""

import xml.etree.ElementTree as ET
import sys, re
import codecs

import HTMLParser
parser = HTMLParser.HTMLParser()

# Read the entity mentions, relation mentions and document text of one ACE file
# (path is given without its extension).
def get_ERD(path):
    apf_tree = ET.parse(path + ".apf.xml")
    apf_root = apf_tree.getroot()

    # entity mentions
    named_entities = {}
    check_nes = {}
    ne_starts = {}
    ne_ends = {}
    ne_map = {}
    for entity in apf_root.iter('entity'):
        ne_type = entity.attrib["TYPE"]
        ne_subtype = entity.attrib["SUBTYPE"] if 'SUBTYPE' in entity.attrib else ""
        for mention in entity.iter('entity_mention'):
            ne_id = mention.attrib["ID"]
            for child in mention:
                if child.tag == 'head':
                    for charseq in child:
                        start = int(charseq.attrib["START"])
                        end = int(charseq.attrib["END"]) + 1
                        text = re.sub(r"\n", r"", charseq.text)
                        ne_tuple = (ne_type, start, end, text)
                        if ne_tuple in check_nes:
                            sys.stderr.write("duplicated entity %s\n" % (ne_id))
                            ne_map[ne_id] = check_nes[ne_tuple]
                            continue
                        check_nes[ne_tuple] = ne_id
                        named_entities[ne_id] = [ne_id, ne_type, start, end, text, ne_subtype]
                        if not start in ne_starts:
                            ne_starts[start] = []
                        ne_starts[start].append(ne_id)
                        if not end in ne_ends:
                            ne_ends[end] = []
                        ne_ends[end].append(ne_id)

    # relation mentions
    rels = {}
    check_rels = []
    for relation in apf_root.iter('relation'):
        rel_type = relation.attrib["TYPE"]
        rel_subtype = relation.attrib["SUBTYPE"] if 'SUBTYPE' in relation.attrib else ""
        for mention in relation.iter('relation_mention'):
            rel_id = mention.attrib["ID"]
            for child in mention:
                if child.tag == 'extent':
                    for charseq in child:
                        text = re.sub(r"\n", r"", charseq.text)
            rel = [rel_id, rel_type, rel_subtype, text]
            if rel_type == "METONYMY":
                print rel_type
            ignore = False
            for arg in mention.iter('relation_mention_argument'):
                arg_id = arg.attrib["REFID"]
                if arg.attrib["ROLE"] != "Arg-1" and arg.attrib["ROLE"] != "Arg-2":
                    continue
                if arg_id in ne_map:
                    arg_id = ne_map[arg_id]
                for child in arg:
                    for charseq in child:
                        start = int(charseq.attrib["START"])
                        end = int(charseq.attrib["END"]) + 1
                        text = re.sub(r"\n", r"", charseq.text)
                rel.append(arg_id)
                rel.append([start, end, text])
                if not arg_id in named_entities:
                    ignore = True
                    # ignored duplicated entity

            if ignore:
                sys.stderr.write("ignored relation %s\n" % (rel_id))
                continue
            if rel[1:] in check_rels:
                sys.stderr.write("duplicated relation %s\n" % (rel_id))
                continue
            check_rels.append(rel[1:])
            # reorder to [id, type, subtype, text, arg1_id, arg2_id, arg1_extent, arg2_extent]
            rel[5], rel[6] = rel[6], rel[5]
            rels[rel_id] = rel


    # document text
    with codecs.open(path + ".sgm", 'r', 'utf-8') as f:
        doc = "".join(f.readlines())
#        doc = re.sub(r"&amp;", "&", doc)
        doc = re.sub(r"<[^>]+>", "", doc)
        doc = re.sub(r"(\S+)\n(\S[^:])", r"\1 \2", doc)

    # Insert a space at every mention boundary that is glued to the surrounding
    # text and shift the stored mention offsets accordingly.
    offset = 0
    size = len(doc)
    current = 0
    regions = []
    for i in range(size):
        if i in ne_starts or i in ne_ends:
            inc = 0
            if (doc[i-1] != " " and doc[i-1] != "\n") and (doc[i] != " " and doc[i] != "\n"):
                regions.append(doc[current:i])
                inc = 1
                current = i
            if i in ne_starts:
                for ent in ne_starts[i]:
                    named_entities[ent][2] += offset + inc
            if i in ne_ends:
                for ent in ne_ends[i]:
                    named_entities[ent][3] += offset
            offset += inc
    regions.append(doc[current:])
    doc = " ".join(regions)

    # replace newlines inside entity head spans with spaces
    for ne in named_entities.values():
        if "\n" in doc[int(ne[2]):int(ne[3])]:
            l = []
            l.append(doc[0:int(ne[2])])
            l.append(doc[int(ne[2]):int(ne[3])].replace("\n", " "))
            l.append(doc[int(ne[3]):])
            doc = "".join(l)

    # replace newlines inside relation argument extents with spaces
    for rel in rels.values():
        for ne in [rel[6], rel[7]]:
            if "\n" in doc[int(ne[0]):int(ne[1])]:
                l = []
                l.append(doc[0:int(ne[0])])
                l.append(doc[int(ne[0]):int(ne[1])].replace("\n", " "))
                l.append(doc[int(ne[1]):])
                doc = "".join(l)

    # sanity check: every entity head must match its span in the cleaned document
    for ne in named_entities.values():
#        print parser.unescape(doc[int(ne[2]):int(ne[3])]), ne[4], ne[0]
        assert parser.unescape(doc[int(ne[2]):int(ne[3])]).replace("&amp;", "&").replace(" ", "") == ne[4].replace(" ", ""), "%s <=> %s" % (doc[int(ne[2]):int(ne[3])], ne[4])

#    for rel in rels.values():
#        for ne in [rel[6],rel[7]]:
#            print parser.unescape(doc[int(ne[0]):int(ne[1])]).replace("&amp;", "&").replace(" ", ""), ne[2].replace(" ","")
#            print parser.unescape(doc[int(ne[0]):int(ne[1])]).replace("&amp;", "&").replace(" ", "") == ne[2].replace(" ","")
#            assert parser.unescape(doc[int(ne[0]):int(ne[1])]).replace("&amp;", "&").replace(" ", "") == ne[2].replace(" ",""), "%s <=> %s" % (doc[int(ne[2]):int(ne[0])], ne[1])
#

    return named_entities, rels, doc


if __name__ == "__main__":
    nesssssss, relsss, doccccc = get_ERD("Chinese\\wl\\adj\\DAVYZW_20050110.1403")
    checks = [(ne[4], doccccc[int(ne[2]):int(ne[3])]) for ne in nesssssss.values()]

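For debugging a single document, `get_ERD` can be called directly with the file path minus its extension, as the `__main__` block above does; it returns the entity mentions, relation mentions and the cleaned text with character offsets already aligned. A small usage sketch, not part of the repository (the path is the example file used in the `__main__` block):

```python
# Inspect one ACE file: mention counts and one aligned entity span.
import ace_filereader as fr

nes, rels, doc = fr.get_ERD("Chinese\\wl\\adj\\DAVYZW_20050110.1403")
print("%d entity mentions, %d relation mentions" % (len(nes), len(rels)))

ne = nes.values()[0]  # [id, type, start, end, head text, subtype]
print("%s: %s" % (ne[1], doc[int(ne[2]):int(ne[3])]))
```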
--------------------------------------------------------------------------------
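Since `svm_model.py` trains on the integer labels produced by `rl.index(...)`, a predicted class index can be mapped back to an ACE relation type name through the `rl` list stored in `feature.pkl`. A minimal closing sketch, not part of the repository (the commented `clf` / `X_test` names refer to the objects created in `svm_model.py`):

```python
# Map predicted class indices back to ACE relation type names.
import pickle

with open('feature.pkl', 'rb') as f:
    meta = pickle.load(f)

rl = meta["rl"]  # relation type list; training labels are indices into it
print("classes: " + ", ".join(rl))

# after running svm_model.py in the same session:
# predicted = clf.predict(X_test)
# print([rl[i] for i in predicted[:10]])
```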