├── data ├── attrs.json ├── data_preprocess.py ├── entities.json ├── relationships.json ├── schema.json ├── yanbao │ └── yanbao000.txt └── yanbao_txt │ ├── yanbao000.txt │ ├── yanbao001.txt │ ├── yanbao002.txt │ ├── yanbao003.txt │ ├── yanbao004.txt │ └── yanbao005.txt ├── expirement_attr ├── __init__.py └── etlspan │ ├── __init__.py │ ├── eval.py │ ├── models │ ├── __init__.py │ ├── layers.py │ ├── model.py │ └── submodel.py │ ├── train.py │ └── utils │ ├── __init__.py │ ├── constant.py │ ├── helper.py │ ├── loader.py │ ├── result.py │ ├── score.py │ ├── torch_utils.py │ └── vocab.py ├── expirement_er ├── __init__.py ├── bertMRC │ ├── config │ │ ├── en_bert_base_cased.json │ │ ├── en_bert_base_uncased.json │ │ └── zh_bert.json │ ├── data_loader │ │ ├── __init__.py │ │ ├── bert_tokenizer.py │ │ ├── entities_type.py │ │ ├── model_config.py │ │ ├── mrc_data_loader.py │ │ ├── mrc_data_processor.py │ │ └── mrc_utils.py │ ├── layer │ │ ├── __init__.py │ │ ├── bert_basic_model.py │ │ ├── bert_layernorm.py │ │ ├── bert_utils.py │ │ ├── classifier.py │ │ └── optim.py │ ├── metric │ │ ├── __init__.py │ │ ├── flat_span_f1.py │ │ ├── mrc_ner_evaluate.py │ │ └── nest_span_f1.py │ ├── model │ │ ├── __init__.py │ │ └── bert_mrc.py │ └── train_bert_mrc.py ├── bertcrf │ ├── __init__.py │ ├── eval.py │ ├── models │ │ ├── __init__.py │ │ ├── model.py │ │ ├── mycrf.py │ │ └── submodel.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ ├── constant.py │ │ ├── helper.py │ │ ├── loader.py │ │ ├── result.py │ │ ├── score.py │ │ ├── torch_utils.py │ │ └── vocab.py ├── etlspan │ ├── __init__.py │ ├── eval.py │ ├── models │ │ ├── __init__.py │ │ ├── model.py │ │ └── submodel.py │ ├── saved_models │ │ └── baiduRE │ │ │ └── 5_17_bert_2 │ │ │ └── 00 │ │ │ ├── best_dev_results.json │ │ │ ├── config.json │ │ │ └── logs.txt │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ ├── constant.py │ │ ├── helper.py │ │ ├── loader.py │ │ ├── result.py │ │ ├── score.py │ │ ├── torch_utils.py │ │ └── vocab.py └── spert │ ├── __init__.py │ ├── args.py │ ├── config_reader.py │ ├── configs │ ├── example_eval.conf │ └── example_train.conf │ ├── spert.py │ └── spert │ ├── __init__.py │ ├── entities.py │ ├── evaluator.py │ ├── input_reader.py │ ├── loss.py │ ├── models.py │ ├── opt.py │ ├── sampling.py │ ├── spert_trainer.py │ ├── trainer.py │ └── util.py ├── expirement_re ├── __init__.py ├── bert_att_mis │ ├── __init__.py │ ├── checkpoints │ │ └── pcnnone_DEF.pth │ ├── config.py │ ├── dataloader.py │ ├── main_att.py │ ├── models │ │ ├── BasicModule.py │ │ ├── __init__.py │ │ └── bert_att_mis.py │ └── utils.py ├── bert_one_mis │ ├── __init__.py │ ├── checkpoints │ │ └── pcnnone_DEF.pth │ ├── config.py │ ├── dataloader.py │ ├── main_mil.py │ ├── models │ │ ├── BasicModule.py │ │ ├── __init__.py │ │ └── bert_one_mis.py │ └── utils.py └── data │ ├── baidu19 │ ├── data_transfer.py │ ├── train.txt │ ├── train_mulit.txt │ ├── valid.txt │ └── valid_mulit.txt │ ├── train_mulit.txt │ └── valid_mulit.txt ├── extract_attrs.py ├── extract_entities.py ├── extract_relations.py ├── images ├── d1.PNG └── d2.PNG ├── main.py ├── model_attribute └── __init__.py ├── model_entity ├── EtlModel.py ├── HanlpNER.py └── __init__.py ├── model_relation └── __init__.py ├── output └── process.py ├── parameters.py ├── readme.md ├── regulation.py ├── result_process.py ├── test.py └── utils ├── EvaluateScores.py ├── MyDataset.py ├── QuestionText.py ├── __init__.py ├── bert_dict_enhance.py ├── functions.py ├── preprocess_data.py ├── preprocessing.py ├── schemas.py └── 
torch_utils.py /data/attrs.json: -------------------------------------------------------------------------------- 1 | { 2 | "attrs": [ 3 | [ 4 | "高品质服务+跨越式成长,长期价值可期", 5 | "评级", 6 | "优于大市" 7 | ], 8 | [ 9 | "高品质服务+跨越式成长,长期价值可期", 10 | "发布时间", 11 | "2020/2/19" 12 | ], 13 | [ 14 | "结构性问题改善遥遥无期", 15 | "评级", 16 | "增持" 17 | ], 18 | [ 19 | "结构性问题改善遥遥无期", 20 | "上次评级", 21 | "增持" 22 | ], 23 | [ 24 | "结构性问题改善遥遥无期", 25 | "发布时间", 26 | "2020/02/13" 27 | ], 28 | [ 29 | "Q4业绩延续强势,发行可转债优化财务结构", 30 | "评级", 31 | "买入" 32 | ], 33 | [ 34 | "Q4业绩延续强势,发行可转债优化财务结构", 35 | "上次评级", 36 | "买入" 37 | ], 38 | [ 39 | "Q4业绩延续强势,发行可转债优化财务结构", 40 | "发布时间", 41 | "2020/1/19" 42 | ], 43 | [ 44 | "新能源金属与材料:重申,迎来新周期又三年", 45 | "发布时间", 46 | "2020/2/6" 47 | ], 48 | [ 49 | "崛起中的地产新秀", 50 | "评级", 51 | "买入" 52 | ], 53 | [ 54 | "崛起中的地产新秀", 55 | "发布时间", 56 | "2020/2/18" 57 | ], 58 | [ 59 | "交运角度看复工:进度、优缺点与交叉验证", 60 | "评级", 61 | "买入" 62 | ], 63 | [ 64 | "交运角度看复工:进度、优缺点与交叉验证", 65 | "上次评级", 66 | "买入" 67 | ], 68 | [ 69 | "交运角度看复工:进度、优缺点与交叉验证", 70 | "发布时间", 71 | "2020/2/16" 72 | ], 73 | [ 74 | "携手“鲲鹏”,展翅入“云”", 75 | "评级", 76 | "买入" 77 | ], 78 | [ 79 | "携手“鲲鹏”,展翅入“云”", 80 | "发布时间", 81 | "2020/2/14" 82 | ], 83 | [ 84 | "布局弹性,早周期领跑", 85 | "发布时间", 86 | "2020/2/20" 87 | ], 88 | [ 89 | "布局弹性,早周期领跑", 90 | "评级", 91 | "增持" 92 | ], 93 | [ 94 | "布局弹性,早周期领跑", 95 | "上次评级", 96 | "增持" 97 | ], 98 | [ 99 | "疫情利于调味品企业集中度提升", 100 | "评级", 101 | "优于大势" 102 | ], 103 | [ 104 | "疫情利于调味品企业集中度提升", 105 | "上次评级", 106 | "优于大势" 107 | ], 108 | [ 109 | "疫情利于调味品企业集中度提升", 110 | "发布时间", 111 | "2020/2/25" 112 | ], 113 | [ 114 | "国内海洋油气龙头,成本低、油价弹性大", 115 | "发布时间", 116 | "2020/2/7" 117 | ], 118 | [ 119 | "国内海洋油气龙头,成本低、油价弹性大", 120 | "评级", 121 | "优于大市" 122 | ], 123 | [ 124 | "探索5G手机能量的聚与散:快充、无线充电与热管理", 125 | "评级", 126 | "推荐" 127 | ], 128 | [ 129 | "探索5G手机能量的聚与散:快充、无线充电与热管理", 130 | "上次评级", 131 | "推荐" 132 | ], 133 | [ 134 | "探索5G手机能量的聚与散:快充、无线充电与热管理", 135 | "发布时间", 136 | "2020/2/5" 137 | ], 138 | [ 139 | "羊奶粉龙头,业绩增长稳健", 140 | "评级", 141 | "买入" 142 | ], 143 | [ 144 | "羊奶粉龙头,业绩增长稳健", 145 | "发布时间", 146 | "2020/2/27" 147 | ], 148 | [ 149 | "新中民物业", 150 | "全称", 151 | "新中民物业集团" 152 | ], 153 | [ 154 | "特斯拉", 155 | "英文名", 156 | "Tesla" 157 | ], 158 | [ 159 | "绿地物业", 160 | "全称", 161 | "上海绿地物业服务有限公司" 162 | ], 163 | [ 164 | "美菜", 165 | "全称", 166 | "美菜网" 167 | ], 168 | [ 169 | "世界卫生组织", 170 | "英文名", 171 | "WHO" 172 | ], 173 | [ 174 | "北海康城", 175 | "英文名", 176 | "Canbridge Pharmaceuticals Inc" 177 | ], 178 | [ 179 | "美国疾病预防控制中心", 180 | "英文名", 181 | "CDC" 182 | ], 183 | [ 184 | "《乘用车企业平均燃料消耗量与新能源汽车积分并行管理办法》修正案(征求意见稿)", 185 | "发布时间", 186 | "2019/7/9" 187 | ], 188 | [ 189 | "关于修改《乘用车企业平均燃料消耗量与新能源汽车积分并行管理办法》的决定(征求意见稿)", 190 | "发布时间", 191 | "9/11" 192 | ], 193 | [ 194 | "国家公共卫生监测信息体系建设规划", 195 | "发布时间", 196 | "2004/3/5" 197 | ], 198 | [ 199 | "突发公共卫生事件医疗救治体系建设规划", 200 | "发布时间", 201 | "2004/3/5" 202 | ] 203 | ] 204 | } -------------------------------------------------------------------------------- /data/data_preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | def label(entities2label): 6 | filePath = 'yanbao' 7 | labelPath = 'label-7-12' 8 | filename = os.listdir(filePath) 9 | for i in filename: 10 | yanbao_label(entities2label, os.path.join(filePath, i), os.path.join(labelPath, i)) 11 | 12 | 13 | def yanbao_label(entities2label, yanbao_filename, label_filename): 14 | f = open(yanbao_filename, encoding='utf-8') 15 | f_w = open(label_filename, 'w+', encoding='utf-8') 16 | while 1: 17 | line = f.readline() 18 | if not line: 
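            # readline() returns an empty string at end of file, so stop the loop here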
19 | break 20 | f_w.writelines('sentence:' + line) 21 | # 遍历句子将句中的所有实体找出来 22 | sent_entity = [] 23 | for i in entities2label: 24 | if i in line: 25 | sent_entity.append(entities2label[i]) 26 | entities = ' '.join(sent_entity) 27 | if entities: 28 | ent_str = 'entities: ' + entities + '\n' 29 | f_w.writelines(ent_str) 30 | f.close() 31 | f_w.close() 32 | 33 | 34 | def readentitys(): 35 | # 读取文件 36 | # 取得种子知识图谱中的实体 37 | with open('entities.json', encoding='utf-8') as f: 38 | entities_dict = json.load(f) 39 | # 取得新抽取出的实体 40 | with open('answers-7-12.json', encoding='utf-8') as f: 41 | data_dict = json.load(f) 42 | new_entities_dict = data_dict['entities'] 43 | # 合并新旧文件中的所有实体 44 | for ent_type, ents in new_entities_dict.items(): 45 | entities_dict[ent_type] = entities_dict[ent_type] + list(set(new_entities_dict[ent_type] + ents)) 46 | # 将所有类型的实体加进一个list中 47 | entities2label = {} 48 | for ent_type, ents in entities_dict.items(): 49 | for entity in entities_dict[ent_type]: 50 | entities2label[entity] = entity + '-' + ent_type 51 | return entities2label 52 | 53 | 54 | if __name__ == "__main__": 55 | entities2label = readentitys() 56 | label(entities2label) -------------------------------------------------------------------------------- /data/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "entity_type": [ 3 | "人物", 4 | "行业", 5 | "业务", 6 | "产品", 7 | "研报", 8 | "机构", 9 | "风险", 10 | "文章", 11 | "指标", 12 | "品牌" 13 | ], 14 | "attrs": { 15 | "研报": { 16 | "发布时间": "date", 17 | "评级": "string", 18 | "上次评级": "string" 19 | }, 20 | "机构": { 21 | "全称": "string", 22 | "英文名": "string" 23 | }, 24 | "文章": { 25 | "发布时间": "date" 26 | } 27 | }, 28 | "relationships": [ 29 | [ 30 | "行业", 31 | "隶属", 32 | "行业" 33 | ], 34 | [ 35 | "机构", 36 | "属于", 37 | "行业" 38 | ], 39 | [ 40 | "研报", 41 | "涉及", 42 | "行业" 43 | ], 44 | [ 45 | "行业", 46 | "拥有", 47 | "风险" 48 | ], 49 | [ 50 | "机构", 51 | "拥有", 52 | "风险" 53 | ], 54 | [ 55 | "机构", 56 | "隶属", 57 | "机构" 58 | ], 59 | [ 60 | "机构", 61 | "投资", 62 | "机构" 63 | ], 64 | [ 65 | "机构", 66 | "并购", 67 | "机构" 68 | ], 69 | [ 70 | "机构", 71 | "客户", 72 | "机构" 73 | ], 74 | [ 75 | "人物", 76 | "任职于", 77 | "机构" 78 | ], 79 | [ 80 | "人物", 81 | "投资", 82 | "机构" 83 | ], 84 | [ 85 | "研报", 86 | "涉及", 87 | "机构" 88 | ], 89 | [ 90 | "机构", 91 | "生产销售", 92 | "产品" 93 | ], 94 | [ 95 | "机构", 96 | "采购买入", 97 | "产品" 98 | ], 99 | [ 100 | "机构", 101 | "开展", 102 | "业务" 103 | ], 104 | [ 105 | "机构", 106 | "拥有", 107 | "品牌" 108 | ], 109 | [ 110 | "产品", 111 | "属于", 112 | "品牌" 113 | ], 114 | [ 115 | "机构", 116 | "发布", 117 | "文章" 118 | ], 119 | [ 120 | "研报", 121 | "采用", 122 | "指标" 123 | ] 124 | ] 125 | } -------------------------------------------------------------------------------- /data/yanbao/yanbao000.txt: -------------------------------------------------------------------------------- 1 | 赋能全球生物药研发的 CRO/CDMO 龙头 2 | 3 | 药明生物:高速发展的全球生物药 CRO/CDMO 龙头 4 | 5 | 全球唯一提供全方位生物制剂研发服务的企业 6 | 7 | 药明生物为全球目前唯一能提供全方位生物制剂研发服务的企业,研发管线涵盖各类生物药。截至2019年中期,药明生物综合项目数224个,其中单抗项目132个,融合蛋白项目35个,抗体偶联药物25个,双特异性抗体19个,全球新(First-in-Class)项目58个,占比26%。最近一年抗体偶联药物和双特异性抗体项目增长较快,符合生物药行业发展方向和创新趋势,彰显全球领先的技术平台实力。 8 | 药明生物成立于2010年,为CRO龙头公司药明康德的生物药研发外包服务部门,后从药明康德独立出来,并于2017年在香港交易所上市。成立以来,药明生物高速发展,收入从2014年的3.32亿元增长至2018年的25.35亿元(CAGR为66.2%),归母净利润从0.42亿元增长至6.31亿元(CAGR约96.9%)。欧洲、北美和中国地区持续为公司高成长的三驾马车。 9 | 10 | 优秀的商业模式:“Follow the Molecule”一体化解决方案 11 | 12 | 药明生物采用“Follow the 
Molecule”的业务发展策略,提供一体化解决方案。客户对公司提供服务的需求随着生物制剂开发过程的推进并最终商业化生产而不断增加,并使得公司来自每个项目的综合收益随着项目在生物制剂开发周期中的推进而呈几何级数增加。生物药研发项目难度高、粘性高,早期项目的锁定将使得公司有更多项目进入商业化生产阶段并持续贡献里程碑收入和销售分成收入。 13 | 14 | 综合项目数持续增加,平均阶段收益不断提高 15 | 16 | 公司的综合项目数量增长强劲,从2016年中期的75个增长至2019年中期的224个。每个项目所带来的平均收益也不断提高,从2016年的960万元提高至2018年的1236万元。 17 | 18 | 未完成服务订单及潜在里程碑收入充足。公司的未完成订单总额增长强劲,由2016的2.65亿美元增长至2019年中期的46.3亿美元。其中未完成服务订单总额由2016年的2.41亿美元增长至2019年中期的17.36亿美元;未完成潜在里程碑收入由2016年的0.24亿美元增长至2019中期的28.94亿美元。 19 | 2017年,未完成潜在里程碑收入大幅增加至10亿美元,主要由于公司将PD-1抗体(GLS-010)国际权益授权给Arcus Biosicence,于2017年下半年获得1850万美元的前期许可费,合同总金额高达8.16亿美元及未来预计高达10%的销售分成收入。同时公司获得与Arcus共同研发生物制剂的三年独家合作伙伴关系,并成为GLS-010的独家生产商。类似于此类的交易成为未完成订单的主要增长动力之一。原液及制剂可在药明生物欧洲、中国及美国选择任何两个生产基地生产。 20 | 2018年下半年和2019年上半年,新签商业化生产合同带动未完成服务订单的大幅提升;WuXiBody双抗项目的签订带动未完成潜在里程碑收入的大幅提升。 21 | 22 | 全球客户日益增长,引领创新潮流吸引独家客户 23 | 24 | 药明目前拥有220家全球合作伙伴,其中包括全球前20大药企中的13家以及中国前50大药企中的22家。 25 | 26 | 三驾马车持续推动收入快速增长 27 | 28 | 全球业务快速拓展,欧洲增长十分靓丽。药明生物在北美、中国和欧洲都保持强劲的增长,美国和中国增速均超过50%。全球中欧洲市场表现最为靓丽,过去5年的CAGR高达185%,WuXiBody在欧洲被广泛关注。公司与瑞士10家客户签订了17个项目的合作协议,瑞士也成为公司于欧洲的最大市场。 29 | 技术实力全球领先,全球新(First-in-Class)项目不断增加。2018年,公司的205个项目中已经有51个全球新项目,占比约25%。2019年上半年,公司的224个项目中有58个全球新项目,占比26%。公司的ADC和双抗项目同比增长强劲,可看出客户对药明生物尖端新技术的认可。预计当前项目中有超过20个可以实现商业化生产,销售收入可达20亿美元。预计有超过40个项目未来可能产生超过40亿美元里程碑及每年超过2亿美元的销售分成。 30 | 31 | 多项核心优势,打造全球龙头 32 | 33 | 得益于独家核心尖端技术和对于客户需求的快速响应,药明生物于全球和中国的市场份额不断扩大。2015-2018年药明生物于全球的市场份额从1%提高至3.2%,随着中国工程师红利的进一步释放和海外客户对药明生物尖端科技的进一步认可,全球的CRO/CDMO产业链有望继续向中国转移,药明生物于全球市场的份额有望持续提升。 -------------------------------------------------------------------------------- /data/yanbao_txt/yanbao000.txt: -------------------------------------------------------------------------------- 1 | 赋能全球生物药研发的 CRO/CDMO 龙头 2 | 3 | 药明生物:高速发展的全球生物药 CRO/CDMO 龙头 4 | 5 | 全球唯一提供全方位生物制剂研发服务的企业 6 | 7 | 药明生物为全球目前唯一能提供全方位生物制剂研发服务的企业,研发管线涵盖各类生物药。截至2019年中期,药明生物综合项目数224个,其中单抗项目132个,融合蛋白项目35个,抗体偶联药物25个,双特异性抗体19个,全球新(First-in-Class)项目58个,占比26%。最近一年抗体偶联药物和双特异性抗体项目增长较快,符合生物药行业发展方向和创新趋势,彰显全球领先的技术平台实力。 8 | 药明生物成立于2010年,为CRO龙头公司药明康德的生物药研发外包服务部门,后从药明康德独立出来,并于2017年在香港交易所上市。成立以来,药明生物高速发展,收入从2014年的3.32亿元增长至2018年的25.35亿元(CAGR为66.2%),归母净利润从0.42亿元增长至6.31亿元(CAGR约96.9%)。欧洲、北美和中国地区持续为公司高成长的三驾马车。 9 | 10 | 优秀的商业模式:“Follow the Molecule”一体化解决方案 11 | 12 | 药明生物采用“Follow the Molecule”的业务发展策略,提供一体化解决方案。客户对公司提供服务的需求随着生物制剂开发过程的推进并最终商业化生产而不断增加,并使得公司来自每个项目的综合收益随着项目在生物制剂开发周期中的推进而呈几何级数增加。生物药研发项目难度高、粘性高,早期项目的锁定将使得公司有更多项目进入商业化生产阶段并持续贡献里程碑收入和销售分成收入。 13 | 14 | 综合项目数持续增加,平均阶段收益不断提高 15 | 16 | 公司的综合项目数量增长强劲,从2016年中期的75个增长至2019年中期的224个。每个项目所带来的平均收益也不断提高,从2016年的960万元提高至2018年的1236万元。 17 | 18 | 未完成服务订单及潜在里程碑收入充足。公司的未完成订单总额增长强劲,由2016的2.65亿美元增长至2019年中期的46.3亿美元。其中未完成服务订单总额由2016年的2.41亿美元增长至2019年中期的17.36亿美元;未完成潜在里程碑收入由2016年的0.24亿美元增长至2019中期的28.94亿美元。 19 | 2017年,未完成潜在里程碑收入大幅增加至10亿美元,主要由于公司将PD-1抗体(GLS-010)国际权益授权给Arcus Biosicence,于2017年下半年获得1850万美元的前期许可费,合同总金额高达8.16亿美元及未来预计高达10%的销售分成收入。同时公司获得与Arcus共同研发生物制剂的三年独家合作伙伴关系,并成为GLS-010的独家生产商。类似于此类的交易成为未完成订单的主要增长动力之一。原液及制剂可在药明生物欧洲、中国及美国选择任何两个生产基地生产。 20 | 2018年下半年和2019年上半年,新签商业化生产合同带动未完成服务订单的大幅提升;WuXiBody双抗项目的签订带动未完成潜在里程碑收入的大幅提升。 21 | 22 | 全球客户日益增长,引领创新潮流吸引独家客户 23 | 24 | 药明目前拥有220家全球合作伙伴,其中包括全球前20大药企中的13家以及中国前50大药企中的22家。 25 | 26 | 三驾马车持续推动收入快速增长 27 | 28 | 全球业务快速拓展,欧洲增长十分靓丽。药明生物在北美、中国和欧洲都保持强劲的增长,美国和中国增速均超过50%。全球中欧洲市场表现最为靓丽,过去5年的CAGR高达185%,WuXiBody在欧洲被广泛关注。公司与瑞士10家客户签订了17个项目的合作协议,瑞士也成为公司于欧洲的最大市场。 29 | 技术实力全球领先,全球新(First-in-Class)项目不断增加。2018年,公司的205个项目中已经有51个全球新项目,占比约25%。2019年上半年,公司的224个项目中有58个全球新项目,占比26%。公司的ADC和双抗项目同比增长强劲,可看出客户对药明生物尖端新技术的认可。预计当前项目中有超过20个可以实现商业化生产,销售收入可达20亿美元。预计有超过40个项目未来可能产生超过40亿美元里程碑及每年超过2亿美元的销售分成。 30 | 31 | 多项核心优势,打造全球龙头 32 | 33 | 
得益于独家核心尖端技术和对于客户需求的快速响应,药明生物于全球和中国的市场份额不断扩大。2015-2018年药明生物于全球的市场份额从1%提高至3.2%,随着中国工程师红利的进一步释放和海外客户对药明生物尖端科技的进一步认可,全球的CRO/CDMO产业链有望继续向中国转移,药明生物于全球市场的份额有望持续提升。 -------------------------------------------------------------------------------- /data/yanbao_txt/yanbao002.txt: -------------------------------------------------------------------------------- 1 | 赋能全球生物药研发的 CRO/CDMO 龙头 2 | 3 | 多项核心优势,打造全球龙头 4 | 5 | 得益于独家核心尖端技术和对于客户需求的快速响应,药明生物于全球和中国的市场份额不断扩大。2015-2018年药明生物于全球的市场份额从1%提高至3.2%,随着中国工程师红利的进一步释放和海外客户对药明生物尖端科技的进一步认可,全球的CRO/CDMO产业链有望继续向中国转移,药明生物于全球市场的份额有望持续提升。 6 | 7 | 核心优势 1:尖端科技,引领创新潮流 8 | 9 | (1)WuXiBodyTM双抗技术平台上市后客户迅速拓展 10 | 11 | WuXiBodyTM技术平台2018年下半年上市,得到广泛的关注。WuXiBodyTM平台为药明生物尖端科技的代表:1)灵活性高:几乎所有的单抗序列都可以用来构建双抗,基于生物学特性的二/三/四价;2)速度快:减少CMC工艺开发的挑战:无表达、多聚体或错配导致的纯化困难,节约6-18个月的研发时间;3)质量高:低免疫原性,无需复杂工程的天然序列;与单抗接近的体内半衰期,比典型的双抗更长。 12 | 13 | (2)WuXia细胞系平台,专有技术可换取许可费和销售分成 14 | 15 | 公司拥有自主知识产权的细胞系与自主算法相结合,更具成本优势、效率更高,且通过向客户许可细胞系构建和开发过程中产生的专有技术以换取许可费和未来的销售分成费。目前WuXia平台已开发出220多个细胞系,每年可开发超过60个项目。 16 | 17 | (3)WuXiUP连续生产平台,表达量显著提高 18 | 19 | 高表达量,高纯化率。WuXiUP平台使用2000L一次性生物反应器生产产品具有与20000L传统不锈钢罐生物反应器相当的批次产量,显著降低生产成本,并且在实现高产量的同时保证了媲美传统纯化工艺的高纯化收率。目前该技术正在放大至GMP生产阶段,并将应用于药明生物的全球生产基地。截至2019年年中,公司已经开展15个WuXiUP项目,WuXiUP连续细胞培养的表达量已经高达30-50g/L,是传统技术的十倍。 20 | 21 | 一次性生物反应器生物药生产的全球领导者 22 | 23 | 药明生物是一次性反应器最大的使用商,500+批次生产的成功率高达98%,采用横向拓展(Scale-out)策略的生产成本可与传统万升以上不锈钢大罐的水平相当。建厂投资较低、建厂更快、成本更有优势。 24 | 25 | 核心优势 2:“药明速度”,显著加快研发进度 26 | 27 | 凭借强大的技术平台、一体化的服务模式以及对客户需求的快速响应,药明生物可以在最短时间内为客户实现从DNA到IND相关CMC工作。目前行业平均时间为18-24个月,药明生物平均为15个月,最快纪录已经从9个月缩短到7个月。研发时间的缩短将有效提升药企的研发效率和研发回报率,“药明速度”较竞争对手优势显著,对国内外药企均具备极高的吸引力。 28 | 29 | 核心优势 3:过往记录优秀,人才团队庞大 30 | 31 | 从过往记录来看,药明生物的项目交付可以达到100%,没有客户流失,客户满意度高、认可度高。优秀的过往记录增加了药明生物与国际巨头竞争中的优势。 32 | 人才队伍快速增长,研发人员占比高,员工保留率高。经过8年的发展,药明生物员工总数已经迅速上升至4512人,预计2019年底员工总数将达到5600人。其中,研发人员达到1682人,占比37%,是业内最大的生物制药研发团队之一。2019年上半年员工保留率95%,核心员工保留率达到98%,保留率高于biotech行业平均。 33 | 34 | 核心优势 4:卓越的质量体系和高质量的产能扩张 35 | 36 | 高质量产能扩张顺利,支持各类项目于4周内快速启动。药明生物的MFG1是中国首个及当前唯一同时获美国FDA及欧盟EMA认证的生物药生产设施。目前药明生物于无锡和上海的多个生产设施已经顺利完成GMP认证。美国和爱尔兰的生产设施也在快速推进中。药明生物的“按需扩产计划”和“全球双厂生产”策略将满足部分客户多区域、稳定供应的要求,更好的加强客户后期阶段项目的粘性。 37 | 38 | 核心优势 5:新业务的孵化能力强:疫苗业务值得期待 39 | 40 | 疫苗业务有望成为新增长点,助力公司长期稳健成长。疫苗为全球生物药中的重要细分领域,具备高研发和生产壁垒,由几家国际巨头所垄断。由于某些具备高临床需求的疫苗产品上市后容易面临供不应求的情况,疫苗行业的CDMO机会不容忽视。2018年7月,药明生物通过全资子公司与海利生物订立协议,成立合资企业从事人用疫苗(包括癌症疫苗)CDMO业务,并提供人用疫苗从概念到商业化生产全过程的发现、开发及生产端到端服务及解决方案平台。该合资公司由药明生物和海利生物分别拥有70%和30%股权。 41 | 2019年5月,药明生物与一家全球疫苗巨头签订意向书,通过与海利生物共同设立的公司(药明海德),建设疫苗原液(DS)及制剂(DP)生产、质量控制(QC)实验室于一体的综合疫苗生产基地,并为疫苗合作伙伴生产若干疫苗。预期生产合约期限长达20年,总合约价值超过30亿美元。与全球疫苗巨头合作,并为全球市场生产疫苗充分展示了药明生物的技术优势及全球高标准的质量标准。开展该项目后,疫苗业务将对药明生物的业务增长作出重大贡献。 42 | 43 | -------------------------------------------------------------------------------- /data/yanbao_txt/yanbao003.txt: -------------------------------------------------------------------------------- 1 | 赋能全球生物药研发的 CRO/CDMO 龙头 2 | 3 | 财务分析:里程碑收入有望持续提高利润率 4 | 5 | 2019年上半年,药明生物里程碑收入约2.14亿元,占收入比例约13.3%。考虑随着未来BLA申报和双抗IND申报带来的较高里程碑收入,公司的利润率水平有望随着里程碑收入的增长而持续提高。 6 | 资产周转率的下降影响了公司ROE的提高,未来随着利润体量的快速增长、里程碑和销售分成带来的利润率的提高,公司ROE具备提升空间。 7 | 8 | 金融资产增加,关注孵化项目 9 | 10 | 根据公司的投资策略按公允价值基准管理及评估已购买未上市优先股的投资表现。 11 | 2018年5月和2019年1月,公司分别以300万美元和1200万美元的现金代价购买Inhibrx, Inc. 的429,799股及1,719,197股系列夹层2优先股。Inhibrx总部位于美国加州,致力于开发创新生物药物,聚焦于癌症、传染病及罕见病领域。 12 | 2018年9月和2019年1月,公司分别以500万美元和500万美元的现金代价购买Canbridge Pharmaceuticals Inc. 
(北海康城)的481,454股系列C-1优先股及481,454股系列C-3优先股。Canbridge为根据开曼群岛法律注册成立的获豁免有限公司,致力于开发、销售用于治疗或预防肿瘤或罕见疾的药物业务。 13 | 2019年3月,公司以190万美元现金代价购买Virtuoso Therapeutics, Inc.的2,856,055股系列A优先股。Virtuoso为根据开曼群岛法律正式注册成立并获有效存续的豁免公司,致力于肿瘤治疗的抗体药物研发业务。 14 | 15 | 收入拆分与盈利预测:有望持续高成长 16 | 17 | 近年来药明生物持续高速发展,收入从2014年的3.32亿元增长至2018年的25.35亿元(CAGR为66.2%),净利润从0.42亿元增长至6.31亿元(CAGR约96.9%)。我们预计依托公司强大的尖端技术、优秀的科学人才、庞大的技术团队和高质量的生产体系,对于全球客户的吸引力将持续上升,每年新增全球新项目数量可观,且考虑生物药研发生产粘性高、进入后期之后项目收费将逐步提高,公司的收入有望持续高速增长。随着更多的全球新项目申报IND和更多的产品申报BLA,公司的里程碑收入也有望大幅提升,随之提高公司的利润率水平。 18 | 19 | 费用率情况: 20 | 21 | 2019年上半年,公司毛利率41.8%,同比提高2.5pct,主要受益于里程碑收入的提高,2020年随着更多的项目申报IND/BLA,公司的里程碑收入仍有望持续提高,并带动利润率的稳步提升。 22 | 2019年公司销售费用率1.6%,同比下降0.3pct;管理费用率9.3%,同比提高1.1pct。考虑公司的经营管理较为稳定,未来期间费用率有望维持现有水平。 23 | 综上,预计公司2019-2021年营业收入37.72/54.05/76.98亿元,同比增长48.8%/43.3%/42.3% ;归母净利润9.61/14.05/20.06亿元,同比增长52.3%/46.3%/42.8%,EPS为0.74/1.09/1.55元/股,当前股价(98.10港元)对应PE119/81/57x。如果获取里程碑进度超预期或者获得较大的订单,公司的收入和利润有望进一步提速。 24 | 25 | 估值与投资建议 26 | 27 | 我们采用DCF方法对公司进行估值:假设WACC=9.10%,永续增长率=3.0%。得到公司合理价值为116.16港元/股,对应2020年PE96x,2021年PE 67x,首次覆盖给予“买入”评级。 28 | 29 | 风险提示 30 | 31 | (1)研发进度不达预期:药企的药品研发进度受到审评政策、临床方案设计等多方面因素影响,药企研发进度低于预期可能影响公司的收入增速和里程碑获取的时间。 32 | (2)新增项目数量低于预期:随着政策推动下制药产业出清速度加快,长期来看药企赢利能力的下降可能引起全行业研发投入的减少; 33 | (3)中止项目数量上升:全球新(First-in-Class)药品的研发难度大、成功率低,此类项目占比的提升可能造成公司整体成功率的下降; 34 | (4)行业政策风险:公司于全球开展业务,可能受到各地区法规变化的影响。 -------------------------------------------------------------------------------- /data/yanbao_txt/yanbao004.txt: -------------------------------------------------------------------------------- 1 | 2020年02月13日 2 | 3 | 华虹半导体(01347.HK) 增持(维持评级) 4 | 5 | 结构性问题改善遥遥无期 6 | 7 | 双季低于预期,结构性问题改善遥遥无期 8 | 9 | ■华虹公布四季度营收符合预期,但其中美国(-10%q/q)的150-180纳米(一15%q/q)模拟芯片及电源管理(-17%q/q)通讯客户(一25%q/q)的环比衰退似乎是拖累华虹的主要因素。加上管理费用(占28%的营收)跳升,华虹四季度首次步入营业亏损,四季度EPS为US$0.02,比市场一致预期低35%。又因为公司担心客户年后复工延迟,华虹预期一季度营收环比衰退18%,同比衰退9%(低于彭博分析师预期环比增长2%,同比增长13%)。公司并预期毛利率大幅下滑到21-23%,低于市场预期的26%。 10 | 11 | ■折旧费用攀升影响毛利率可期:我们担心的是等到华虹12“全能量产后,即使有政府的补贴,季度折旧摊销费用将超过5,000万美元,占营业成本比例会从过去的19%,拉到未来的25-30%,这样20-25%的毛利率及10%以下的营业利润率可能在2020-2021年将成为常态,这是比市场预期的10-15%的营业利润率差很多。 12 | 13 | 投资建议 14 | 15 | ■调降获利,结构性问题改善遥遥无期:虽然我们认为华虹从二季度开始,营收环比及同比增速将逐步改善,以5-10%的ROE而言,1.5x以下P/B估值都还算合理,但我们仍然看到公司几项影响未来股价结构性的问题,所以我们还是决定调降华虹2020E EPS近23%到US$0.12,及下调2021E EPS近16%到US$0.16,维持目标价在HK$16。 16 | 17 | 风险提示 18 | 19 | ■12“晶圓代工扩产计划所拉高的折旧费用可能破坏毛利结构,增长动力的MCU和电力功率半导体下游景气度的下滑,高毛利分立器件晶圓代工的竞争加速,而与市场预估值的差异是短期所面临的风险。 20 | 21 | 双季低于预期,12“厂量产逆风可期 22 | 23 | ■四季度/一季度都低于预期:华虹公布四季度营收(2.43亿美金)环比增长2%,同比去年同期小幅衰退2%及27%的毛利率都符合彭博21位分析师的预期。其中美国(-10%q/q)的150-180纳米(-15%q/q)模拟芯片及电源管理(-17%q/q)通讯客户(-25%q/q)的环比衰退似乎是拖累华虹的主要因素。加上管理费用(占28%的营收)跳升(vs.去年同期14.8%的营收及三季度营收的16%),华虹四季度首次步入营业亏损,四季度US$0.02EPS是35%低于市场预期。又因为公司担心客户年后复工延迟,华虹预期一季度营收环比衰退18%,同比衰退9%(低于彭博分析师预期环比增长2%,同比增长13%)。公司并预期毛利率大幅下滑到21-23%,也低于市场预期的26%毛利率。 24 | 25 | ■调降获利,结构性问题改善遥遥无期:虽然我们认为华虹从二季度开始,营收环比及同比增速将逐步改善,以5-10%的ROE而言,1.5x以下P/B估值都还算合理,但我们仍然看到公司几项影响未来股价结构性的问题,所以我们还是决定调降华虹2020E EPS近23%到US$0.12,及下调2021E EPS近16%到US$0.16,维持目标价在HK$16。 26 | 27 | ■折旧费用攀升影响毛利率可期:我们担心的是等到华虹12“全能量产后,即使有政府的补贴,季度折旧摊销费用将超过5,000万美元,占营业成本比会从过去的19%,拉到未来的25-30%,这样20-25%的毛利率及10%以下的营业利润率可能在2020-2021年将成为常态,这是比市场预期的10-15%的营业利润率来得差很多。 28 | 29 | 风险提示 30 | 31 | 12“晶圆代工扩产计划所拉高的折旧费用可能持续破坏整体毛利结构,增长动力的MCU和电力功率半导体下游景气度的下滑,高毛利分立器件晶圆代工的竞争加速,而与市场预估值的差异是短期所面临的风险。 -------------------------------------------------------------------------------- /data/yanbao_txt/yanbao005.txt: -------------------------------------------------------------------------------- 1 | 公告点评|耐用消费品与服装 2 | 3 | 【广发海外】安踏体育(02020.HK) 4 | 5 | Q4业绩延续强势,发行可转债优化财务结构 6 | 7 | 公司评级 买入 8 | 前次评级 买入 9 | 报告日期 2020-01-19 
10 | 11 | 核心观点: 12 | 公司公布19年全年及19Q4运营数据。19Q4安踏主品牌、FILA零售流水分别实现高双位数、50%-55%增长,全年分别实现中双位数、55%-60%增长;其他品牌(不包括Amer Sports)19Q4同比增长25%-30%,19年全年同比增长30%-35%。 13 |  14 | 主品牌提速,FILA维持高增。分品牌看,1)安踏:主品牌Q4较前6季度有所提速,18Q2-19Q4安踏流水增速分别为低双位数、中双位数、中双位数、低双位数、中双位数、中双位数;2)FILA:Q4延续强势,19年Q1-Q3流水增速分别为65%-70%、55%-60%、50%-55%。 15 | 16 | 公司发行10亿欧元5年期可转债,优化财务结构。公司此次发行的可转债主要用于偿还此前收购Amer产生的8.5亿欧元银行贷款中的大部分。且本次可转债转股价格为105.28港元/股,较公告前一日收盘价溢价40%,充分体现投资者对公司未来发展的信心。 17 | 18 | Amer Sports中国区有望稳健增长。公司的“大品牌、大渠道、大国家”战略清晰,未来品牌端有望打造始祖鸟、Salomon、Wilson三大10亿级品牌;渠道端有望全面实现零售思维转型,直营规模达10亿级;向“世界级体育用品集团”的愿景不断靠近。 19 | 20 | 19-21年业绩分别为2.03元/股、2.86元/股、3.52元/股。预计19-21年公司收入分别为325.6/406/485.9亿元,YOY+35.1%/24.7%/19.7%;19-21年归母净利润分别为54.9/77.3/95亿元,YOY+33.8%/40.7%/22.9%;19-21年EPS分别为2.03/2.86/3.52元/股。据wind一致预期,20年李宁、滔搏平均PE为27x,给予公司20年27xPE,CNY/HKD汇率取0.89,对应合理价值为88.3港元/股,维持“买入”评级。 21 | 22 | 风险提示。行业景气度下行;FILA、Amer Sports增速低于预期。 -------------------------------------------------------------------------------- /expirement_attr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_attr/__init__.py -------------------------------------------------------------------------------- /expirement_attr/etlspan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_attr/etlspan/__init__.py -------------------------------------------------------------------------------- /expirement_attr/etlspan/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run evaluation with saved models. 3 | """ 4 | 5 | import os 6 | import random 7 | import argparse 8 | import pickle 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.autograd import Variable 13 | import json 14 | from models.model import RelationModel 15 | from utils import torch_utils, helper, result 16 | from utils.vocab import Vocab 17 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM 18 | 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('model_dir', type=str, help='Directory of the model.') 23 | parser.add_argument('--model', type=str, default='best_model.pt', help='Name of the model file.') 24 | parser.add_argument('--data_dir', type=str, default='dataset/NYT-multi/data') 25 | parser.add_argument('--dataset', type=str, default='test', help="Evaluate on dev or test.") 26 | 27 | parser.add_argument('--seed', type=int, default=42) 28 | parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available()) 29 | parser.add_argument('--cpu', action='store_true') 30 | args = parser.parse_args() 31 | 32 | torch.manual_seed(args.seed) 33 | random.seed(args.seed) 34 | if args.cpu: 35 | args.cuda = False 36 | elif args.cuda: 37 | torch.cuda.manual_seed(args.seed) 38 | 39 | # load opt 40 | model_file = args.model_dir + '/' + args.model 41 | print("Loading model from {}".format(model_file)) 42 | opt = torch_utils.load_config(model_file) 43 | model = RelationModel(opt) 44 | model.load(model_file) 45 | 46 | 47 | 48 | # load data 49 | data_file = args.data_dir + '/{}.json'.format(args.dataset) 50 | print("Loading data from {} with batch size {}...".format(data_file, opt['batch_size'])) 51 | data = json.load(open(data_file)) 52 | 53 | id2predicate, predicate2id, id2subj_type, subj_type2id, 
id2obj_type, obj_type2id = json.load(open(opt['data_dir'] + '/schemas.json')) 54 | id2predicate = {int(i):j for i,j in id2predicate.items()} 55 | 56 | # 加载bert词表 57 | UNCASED = args.bert_model # your path for model and vocab 58 | VOCAB = 'bert-base-chinese-vocab.txt' 59 | bert_tokenizer = BertTokenizer.from_pretrained(os.path.join(UNCASED,VOCAB)) 60 | 61 | helper.print_config(opt) 62 | 63 | results = result.evaluate(bert_tokenizer, data, id2predicate, model) 64 | results_save_dir = opt['model_save_dir'] + '/results.json' 65 | print("Dumping the best test results to {}".format(results_save_dir)) 66 | 67 | with open(results_save_dir, 'w') as fw: 68 | json.dump(results, fw, indent=4, ensure_ascii=False) 69 | 70 | print("Evaluation ended.") 71 | 72 | -------------------------------------------------------------------------------- /expirement_attr/etlspan/models/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | -------------------------------------------------------------------------------- /expirement_attr/etlspan/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | import math 7 | 8 | from utils import torch_utils 9 | 10 | 11 | class CharEncoder(nn.Module): 12 | def __init__(self, opt): 13 | super().__init__() 14 | self.char_encoder = nn.Conv1d(in_channels=opt['char_emb_dim'], out_channels=opt['char_hidden_dim'], kernel_size=3, 15 | padding=1) 16 | nn.init.xavier_uniform_(self.char_encoder.weight) 17 | nn.init.uniform_(self.char_encoder.bias) 18 | 19 | def forward(self, x, mask): 20 | x = self.char_encoder(x.transpose(1,2)).transpose(1,2) 21 | x = F.relu(x) 22 | hidden_masked = x - ((1 - mask) * 1e10).unsqueeze(2).repeat(1,1,x.shape[2]).float() 23 | sentence_rep = F.max_pool1d(torch.transpose(hidden_masked, 1, 2), hidden_masked.size(1)).squeeze(2) 24 | return sentence_rep 25 | 26 | 27 | def seq_and_vec(seq_len, vec): 28 | return vec.unsqueeze(1).repeat(1,seq_len,1) 29 | 30 | 31 | -------------------------------------------------------------------------------- /expirement_attr/etlspan/models/submodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from utils import torch_utils 7 | 8 | from models.layers import * 9 | 10 | 11 | 12 | class SubjTypeModel(nn.Module): 13 | 14 | 15 | def __init__(self, opt, filter=3): 16 | super(SubjTypeModel, self).__init__() 17 | 18 | self.dropout = nn.Dropout(opt['dropout']) 19 | self.hidden_dim = opt['word_emb_dim'] 20 | 21 | self.position_embedding = nn.Embedding(500, opt['position_emb_dim']) 22 | 23 | self.linear_subj_start = nn.Linear(self.hidden_dim, opt['num_subj_type']+1) 24 | self.linear_subj_end = nn.Linear(self.hidden_dim, opt['num_subj_type']+1) 25 | self.init_weights() 26 | 27 | 28 | 29 | def init_weights(self): 30 | 31 | self.position_embedding.weight.data.uniform_(-1.0, 1.0) 32 | self.linear_subj_start.bias.data.fill_(0) 33 | init.xavier_uniform_(self.linear_subj_start.weight, gain=1) # initialize linear layer 34 | 35 | self.linear_subj_end.bias.data.fill_(0) 36 | init.xavier_uniform_(self.linear_subj_end.weight, gain=1) # initialize linear layer 37 | 38 | def forward(self, hidden): 39 | 40 | 
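        # Apply dropout to the shared hidden states, then score every token position with
        # separate start/end linear classifiers (num_subj_type + 1 output classes each).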
subj_start_inputs = self.dropout(hidden) 41 | subj_start_logits = self.linear_subj_start(subj_start_inputs) 42 | 43 | subj_end_inputs = self.dropout(hidden) 44 | subj_end_logits = self.linear_subj_end(subj_end_inputs) 45 | 46 | return subj_start_logits.squeeze(-1), subj_end_logits.squeeze(-1) 47 | 48 | 49 | def predict_subj_start(self, hidden): 50 | 51 | subj_start_logits = self.linear_subj_start(hidden) 52 | 53 | return subj_start_logits.squeeze(-1)[0].data.cpu().numpy() 54 | 55 | def predict_subj_end(self, hidden): 56 | 57 | subj_end_logits = self.linear_subj_end(hidden) 58 | 59 | return subj_end_logits.squeeze(-1)[0].data.cpu().numpy() 60 | 61 | 62 | class ObjBaseModel(nn.Module): 63 | 64 | def __init__(self, opt, filter=3): 65 | super(ObjBaseModel, self).__init__() 66 | 67 | self.dropout = self.drop = nn.Dropout(opt['dropout']) 68 | self.distance_to_subj_embedding = nn.Embedding(400, opt['position_emb_dim']) 69 | self.distance_to_obj_start_embedding = nn.Embedding(500, opt['position_emb_dim']) 70 | self.input_dim = opt['obj_input_dim'] 71 | 72 | self.linear_obj_start = nn.Linear(self.input_dim, opt['num_class']+1) 73 | self.linear_obj_end = nn.Linear(self.input_dim, opt['num_class']+1) 74 | self.init_weights() 75 | 76 | def init_weights(self): 77 | self.linear_obj_start.bias.data.fill_(0) 78 | init.xavier_uniform_(self.linear_obj_start.weight, gain=1) # initialize linear layer 79 | self.linear_obj_end.bias.data.fill_(0) 80 | init.xavier_uniform_(self.linear_obj_end.weight, gain=1) # initialize linear layer 81 | self.distance_to_subj_embedding.weight.data.uniform_(-1.0, 1.0) 82 | self.distance_to_obj_start_embedding.weight.data.uniform_(-1.0, 1.0) 83 | 84 | def forward(self, hidden, subj_start_position, subj_end_position, distance_to_subj): 85 | 86 | 87 | batch_size, seq_len, input_size = hidden.shape 88 | 89 | subj_start_hidden = torch.gather(hidden, dim=1, index=subj_start_position.unsqueeze(2).repeat(1,1,input_size)).squeeze(1) 90 | subj_end_hidden = torch.gather(hidden, dim=1, index=subj_end_position.unsqueeze(2).repeat(1,1,input_size)).squeeze(1) 91 | distance_to_subj_emb = self.distance_to_subj_embedding(distance_to_subj+200) # To avoid negative indices 92 | 93 | subj_related_info = torch.cat([seq_and_vec(seq_len,subj_start_hidden), seq_and_vec(seq_len,subj_end_hidden), distance_to_subj_emb], dim=2) 94 | obj_inputs = torch.cat([hidden, subj_related_info], dim=2) 95 | obj_inputs = self.dropout(obj_inputs) 96 | 97 | obj_start_outputs = self.dropout(obj_inputs) 98 | obj_start_logits = self.linear_obj_start(obj_start_outputs) 99 | 100 | obj_end_outputs = self.dropout(obj_inputs) 101 | obj_end_logits = self.linear_obj_end(obj_end_outputs) 102 | 103 | return obj_start_logits, obj_end_logits 104 | 105 | def predict_obj_start(self, hidden, subj_start_position, subj_end_position, distance_to_subj): 106 | 107 | batch_size, seq_len, input_size = hidden.size() 108 | 109 | subj_start_hidden = torch.gather(hidden, dim=1, index=subj_start_position.unsqueeze(2).repeat(1,1,input_size)).squeeze(1) 110 | subj_end_hidden = torch.gather(hidden, dim=1, index=subj_end_position.unsqueeze(2).repeat(1,1,input_size)).squeeze(1) 111 | distance_to_subj_emb = self.distance_to_subj_embedding(distance_to_subj+200) 112 | subj_related_info = torch.cat([seq_and_vec(seq_len,subj_start_hidden), seq_and_vec(seq_len,subj_end_hidden), distance_to_subj_emb], dim=2) 113 | obj_inputs = torch.cat([hidden, subj_related_info], dim=2) 114 | obj_inputs = self.dropout(obj_inputs) 115 | 116 | obj_start_logits = 
self.linear_obj_start(obj_inputs) 117 | 118 | return obj_start_logits.squeeze(-1)[0].data.cpu().numpy() 119 | 120 | 121 | 122 | def predict_obj_end(self, hidden, subj_start_position, subj_end_position, distance_to_subj): 123 | 124 | batch_size, seq_len, input_size = hidden.size() 125 | 126 | subj_start_hidden = torch.gather(hidden, dim=1, index=subj_start_position.unsqueeze(2).repeat(1, 1, input_size)).squeeze(1) 127 | subj_end_hidden = torch.gather(hidden, dim=1, index=subj_end_position.unsqueeze(2).repeat(1, 1, input_size)).squeeze(1) 128 | distance_to_subj_emb = self.distance_to_subj_embedding(distance_to_subj + 200) 129 | subj_related_info = torch.cat([seq_and_vec(seq_len, subj_start_hidden), seq_and_vec(seq_len, subj_end_hidden), distance_to_subj_emb],dim=2) 130 | obj_inputs = torch.cat([hidden, subj_related_info], dim=2) 131 | obj_inputs = self.dropout(obj_inputs) 132 | 133 | obj_end_logits = self.linear_obj_end(obj_inputs) 134 | 135 | return obj_end_logits.squeeze(-1)[0].data.cpu().numpy() 136 | -------------------------------------------------------------------------------- /expirement_attr/etlspan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | -------------------------------------------------------------------------------- /expirement_attr/etlspan/utils/constant.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define common constants. 3 | """ 4 | TRAIN_JSON = 'train.json' 5 | DEV_JSON = 'dev.json' 6 | TEST_JSON = 'test.json' 7 | 8 | GLOVE_DIR = 'dataset/glove' 9 | 10 | EMB_INIT_RANGE = 1.0 11 | MAX_LEN = 100 12 | 13 | # vocab 14 | PAD_TOKEN = '' 15 | PAD_ID = 0 16 | UNK_TOKEN = '' 17 | UNK_ID = 1 18 | 19 | VOCAB_PREFIX = [PAD_TOKEN, UNK_TOKEN] 20 | 21 | INFINITY_NUMBER = 1e12 22 | 23 | -------------------------------------------------------------------------------- /expirement_attr/etlspan/utils/helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions. 3 | """ 4 | 5 | import os 6 | import json 7 | import argparse 8 | 9 | ### IO 10 | def check_dir(d): 11 | if not os.path.exists(d): 12 | print("Directory {} does not exist. Exit.".format(d)) 13 | exit(1) 14 | 15 | def check_files(files): 16 | for f in files: 17 | if f is not None and not os.path.exists(f): 18 | print("File {} does not exist. Exit.".format(f)) 19 | exit(1) 20 | 21 | def ensure_dir(d, verbose=True): 22 | if not os.path.exists(d): 23 | if verbose: 24 | print("Directory {} do not exist; creating...".format(d)) 25 | os.makedirs(d) 26 | 27 | def save_config(config, path, verbose=True): 28 | with open(path, 'w') as outfile: 29 | json.dump(config, outfile, indent=2) 30 | if verbose: 31 | print("Config saved to file {}".format(path)) 32 | return config 33 | 34 | def load_config(path, verbose=True): 35 | with open(path) as f: 36 | config = json.load(f) 37 | if verbose: 38 | print("Config loaded from file {}".format(path)) 39 | return config 40 | 41 | def print_config(config): 42 | info = "Running with the following configs:\n" 43 | for k,v in config.items(): 44 | info += "\t{} : {}\n".format(k, str(v)) 45 | print("\n" + info + "\n") 46 | return 47 | 48 | class FileLogger(object): 49 | """ 50 | A file logger that opens the file periodically and write to it. 
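    Any existing log file is removed when the logger is constructed, so each run
    starts from a fresh file; messages are appended one per call to log().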
51 | """ 52 | def __init__(self, filename, header=None): 53 | self.filename = filename 54 | if os.path.exists(filename): 55 | # remove the old file 56 | os.remove(filename) 57 | if header is not None: 58 | with open(filename, 'w') as out: 59 | print(header, file=out) 60 | 61 | def log(self, message): 62 | with open(self.filename, 'a') as out: 63 | print(message, file=out) 64 | 65 | 66 | -------------------------------------------------------------------------------- /expirement_attr/etlspan/utils/score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from utils import loader 4 | import json 5 | 6 | def extract_items(bert_tokenizer, tokens_in, id2predicate, model): 7 | R = [] 8 | # word2vec编码 9 | # _t = [word2id.get(w, 1) for w in tokens_in] 10 | _t = bert_tokenizer.convert_tokens_to_ids(tokens_in) 11 | 12 | # _c = [[char2id.get(c, 1) for c in token] for token in tokens_in] 13 | # _pos = loader.get_pos_tags(tokens_in, pos_tags, pos2id) 14 | 15 | _t = np.array(loader.seq_padding([_t])) 16 | # _c = np.array(loader.char_padding([_c ])) 17 | # _pos = np.array(loader.seq_padding([_pos])) 18 | _s1, _s2, hidden = model.predict_subj_per_instance(_t) 19 | _s1, _s2 = np.argmax(_s1, 1), np.argmax(_s2, 1) 20 | for i,_ss1 in enumerate(_s1): 21 | if _ss1 > 0: 22 | _subject = '' 23 | for j,_ss2 in enumerate(_s2[i:]): 24 | if _ss2 == _ss1: 25 | _subject = ''.join(tokens_in[i: i+j+1]) 26 | break 27 | if _subject: 28 | _k1, _k2 = np.array([[i]]), np.array([[i+j]]) 29 | if len(tokens_in) > 150: 30 | continue 31 | distance_to_subj = np.array([loader.get_positions(i, i+j, len(tokens_in))]) 32 | _o1, _o2 = model.predict_obj_per_instance([_t,_k1, _k2, distance_to_subj], hidden) 33 | _o1, _o2 = np.argmax(_o1, 1), np.argmax(_o2, 1) 34 | for i,_oo1 in enumerate(_o1): 35 | if _oo1 > 0: 36 | for j,_oo2 in enumerate(_o2[i:]): 37 | if _oo2 == _oo1: 38 | _object = ''.join(tokens_in[i: i+j+1]) 39 | _predicate = id2predicate[_oo1] 40 | 41 | #处理关系类型并解析数据 42 | 43 | R.append((_subject, _predicate, _object)) 44 | break 45 | return set(R) 46 | 47 | 48 | def evaluate(bert_tokenizer, data, id2predicate, model): 49 | official_A, official_B, official_C = 1e-10, 1e-10, 1e-10 50 | manual_A, manual_B, manual_C = 1e-10, 1e-10, 1e-10 51 | 52 | results = [] 53 | for d in tqdm(iter(data)): 54 | R = extract_items(bert_tokenizer, d['tokens'], id2predicate, model) 55 | official_T = set([tuple(i) for i in d['spo_list']]) 56 | results.append({'text':' '.join(d['tokens']), 'predict':list(R), 'truth':list(official_T)}) 57 | official_A += len(R & official_T) 58 | official_B += len(R) 59 | official_C += len(official_T) 60 | return 2 * official_A / (official_B + official_C), official_A / official_B, official_A / official_C, results 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /expirement_attr/etlspan/utils/vocab.py: -------------------------------------------------------------------------------- 1 | """ 2 | A class for basic vocab operations. 
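Also provides helpers for loading GloVe-style word vectors and building embedding matrices.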
3 | """ 4 | 5 | from __future__ import print_function 6 | import os 7 | import random 8 | import numpy as np 9 | import pickle 10 | 11 | from utils import constant 12 | 13 | 14 | 15 | random.seed(1234) 16 | np.random.seed(1234) 17 | 18 | def build_embedding(wv_file, vocab, wv_dim): 19 | vocab_size = len(vocab) 20 | emb = np.random.uniform(-1, 1, (vocab_size, wv_dim)) 21 | emb[constant.PAD_ID] = 0 # should be all 0 (using broadcast) 22 | 23 | w2id = {w: i for i, w in enumerate(vocab)} 24 | with open(wv_file, encoding="utf8") as f: 25 | for line in f: 26 | elems = line.split() 27 | token = ''.join(elems[0:-wv_dim]) 28 | if token in w2id: 29 | emb[w2id[token]] = [float(v) for v in elems[-wv_dim:]] 30 | return emb 31 | 32 | def load_glove_vocab(file, wv_dim): 33 | """ 34 | Load all words from glove. 35 | """ 36 | vocab = set() 37 | with open(file, encoding='utf8') as f: 38 | try: 39 | for line_num, line in enumerate(f): 40 | elems = line.split() 41 | token = ''.join(elems[0:-wv_dim]) 42 | vocab.add(token) 43 | except Exception as e: 44 | print(line) 45 | return vocab 46 | 47 | def normalize_glove(token): 48 | mapping = {'-LRB-': '(', 49 | '-RRB-': ')', 50 | '-LSB-': '[', 51 | '-RSB-': ']', 52 | '-LCB-': '{', 53 | '-RCB-': '}'} 54 | if token in mapping: 55 | token = mapping[token] 56 | return token 57 | 58 | class Vocab(object): 59 | def __init__(self, filename, load=False, word_counter=None, threshold=0): 60 | if load: 61 | assert os.path.exists(filename), "Vocab file does not exist at " + filename 62 | # load from file and ignore all other params 63 | self.id2word, self.word2id = self.load(filename) 64 | self.size = len(self.id2word) 65 | print("Vocab size {} loaded from file".format(self.size)) 66 | else: 67 | print("Creating vocab from scratch...") 68 | assert word_counter is not None, "word_counter is not provided for vocab creation." 69 | self.word_counter = word_counter 70 | if threshold > 1: 71 | # remove words that occur less than thres 72 | self.word_counter = dict([(k,v) for k,v in self.word_counter.items() if v >= threshold]) 73 | self.id2word = sorted(self.word_counter, key=lambda k:self.word_counter[k], reverse=True) 74 | # add special tokens to the beginning 75 | self.id2word = ['**PAD**', '**UNK**'] + self.id2word 76 | self.word2id = dict([(self.id2word[idx],idx) for idx in range(len(self.id2word))]) 77 | self.size = len(self.id2word) 78 | self.save(filename) 79 | print("Vocab size {} saved to file {}".format(self.size, filename)) 80 | 81 | def load(self, filename): 82 | with open(filename, 'rb') as infile: 83 | id2word = pickle.load(infile) 84 | word2id = dict([(id2word[idx], idx) for idx in range(len(id2word))]) 85 | return id2word, word2id 86 | 87 | def save(self, filename): 88 | #assert not os.path.exists(filename), "Cannot save vocab: file exists at " + filename 89 | if os.path.exists(filename): 90 | print("Overwriting old vocab file at " + filename) 91 | os.remove(filename) 92 | with open(filename, 'wb') as outfile: 93 | pickle.dump(self.id2word, outfile) 94 | return 95 | 96 | def map(self, token_list): 97 | """ 98 | Map a list of tokens to their ids. 99 | """ 100 | return [self.word2id[w] if w in self.word2id else constant.VOCAB_UNK_ID for w in token_list] 101 | 102 | def unmap(self, idx_list): 103 | """ 104 | Unmap ids back to tokens. 
105 | """ 106 | return [self.id2word[idx] for idx in idx_list] 107 | 108 | def get_embeddings(self, word_vectors=None, dim=100): 109 | #self.embeddings = 2 * constant.EMB_INIT_RANGE * np.random.rand(self.size, dim) - constant.EMB_INIT_RANGE 110 | self.embeddings = np.zeros((self.size, dim)) 111 | if word_vectors is not None: 112 | assert len(list(word_vectors.values())[0]) == dim, \ 113 | "Word vectors does not have required dimension {}.".format(dim) 114 | for w, idx in self.word2id.items(): 115 | if w in word_vectors: 116 | self.embeddings[idx] = np.asarray(word_vectors[w]) 117 | return self.embeddings 118 | 119 | 120 | -------------------------------------------------------------------------------- /expirement_er/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | -------------------------------------------------------------------------------- /expirement_er/bertMRC/config/en_bert_base_cased.json: -------------------------------------------------------------------------------- 1 | { 2 | "bert_frozen": "false", 3 | "hidden_size": 768, 4 | "hidden_dropout_prob": 0.2, 5 | "classifier_sign": "multi_nonlinear", 6 | "clip_grad": 1, 7 | "bert_config": { 8 | "attention_probs_dropout_prob": 0.1, 9 | "hidden_act": "gelu", 10 | "hidden_dropout_prob": 0.1, 11 | "hidden_size": 768, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 3072, 14 | "max_position_embeddings": 512, 15 | "num_attention_heads": 12, 16 | "num_hidden_layers": 12, 17 | "type_vocab_size": 2, 18 | "vocab_size": 28996 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /expirement_er/bertMRC/config/en_bert_base_uncased.json: -------------------------------------------------------------------------------- 1 | { 2 | "bert_frozen": "false", 3 | "hidden_size": 768, 4 | "hidden_dropout_prob": 0.2, 5 | "classifier_sign": "multi_nonlinear", 6 | "clip_grad": 1, 7 | "bert_config": { 8 | "attention_probs_dropout_prob": 0.1, 9 | "hidden_act": "gelu", 10 | "hidden_dropout_prob": 0.1, 11 | "hidden_size": 768, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 3072, 14 | "max_position_embeddings": 512, 15 | "num_attention_heads": 12, 16 | "num_hidden_layers": 12, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522 19 | } 20 | } -------------------------------------------------------------------------------- /expirement_er/bertMRC/config/zh_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "bert_frozen": "false", 3 | "hidden_size": 768, 4 | "hidden_dropout_prob": 0.2, 5 | "classifier_sign": "multi_nonlinear", 6 | "clip_grad": 1, 7 | "bert_config": { 8 | "attention_probs_dropout_prob": 0.1, 9 | "directionality": "bidi", 10 | "hidden_act": "gelu", 11 | "hidden_dropout_prob": 0.1, 12 | "hidden_size": 768, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "max_position_embeddings": 512, 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pooler_fc_size": 768, 19 | "pooler_num_attention_heads": 12, 20 | "pooler_num_fc_layers": 3, 21 | "pooler_size_per_head": 128, 22 | "pooler_type": "first_token_transform", 23 | "type_vocab_size": 2, 24 | "vocab_size": 21128 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /expirement_er/bertMRC/data_loader/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_er/bertMRC/data_loader/__init__.py -------------------------------------------------------------------------------- /expirement_er/bertMRC/data_loader/bert_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | from pytorch_pretrained_bert.tokenization import BertTokenizer 6 | 7 | 8 | def whitespace_tokenize(text): 9 | """ 10 | Desc: 11 | runs basic whitespace cleaning and splitting on a piece of text 12 | """ 13 | text = text.strip() 14 | if not text: 15 | return [] 16 | tokens = text.split() 17 | return tokens 18 | 19 | 20 | class BertTokenizer4Tagger(BertTokenizer): 21 | """ 22 | Desc: 23 | slove the problem of tagging span can not fit after run word_piece tokenizing 24 | """ 25 | def __init__(self, vocab_file, do_lower_case=False, max_len=None, 26 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 27 | 28 | super(BertTokenizer4Tagger, self).__init__(vocab_file, do_lower_case=do_lower_case, 29 | max_len=max_len, never_split=never_split) 30 | 31 | def tokenize(self, text, label_lst=None): 32 | """ 33 | Desc: 34 | text: 35 | label_lst: ["B", "M", "E", "S", "O"] 36 | """ 37 | 38 | split_tokens = [] 39 | split_labels = [] 40 | 41 | if label_lst is None: 42 | for token in self.basic_tokenizer.tokenize(text): 43 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 44 | split_tokens.append(sub_token) 45 | return split_tokens 46 | 47 | for token, label in zip(self.basic_tokenizer.tokenize(text), label_lst): 48 | # cureent token should be 1 single word 49 | sub_tokens = self.wordpiece_tokenizer.tokenize(token) 50 | if len(sub_tokens) > 1: 51 | for tmp_idx, tmp_sub_token in enumerate(sub_tokens): 52 | if tmp_idx == 0: 53 | split_tokens.append(tmp_sub_token) 54 | split_labels.append(label) 55 | else: 56 | split_tokens.append(tmp_sub_token) 57 | split_labels.append("X") 58 | else: 59 | split_tokens.append(sub_token) 60 | split_labels.append(label) 61 | 62 | return split_tokens, split_labels 63 | -------------------------------------------------------------------------------- /expirement_er/bertMRC/data_loader/entities_type.py: -------------------------------------------------------------------------------- 1 | baidu10_type2id = { 2 | "图书作品": 1, 3 | "学科专业": 2, 4 | "景点": 3, 5 | "历史人物": 4, 6 | "生物": 5, 7 | "网络小说": 6, 8 | "电视综艺": 7, 9 | "歌曲": 8, 10 | "机构": 9, 11 | "行政区": 10, 12 | "企业": 11, 13 | "影视作品": 12, 14 | "国家": 13, 15 | "书籍": 14, 16 | "人物": 15, 17 | "地点": 16, 18 | "音乐专辑": 17, 19 | "城市": 18, 20 | "Text": 19, 21 | "气候": 20, 22 | "Date": 21, 23 | "语言": 22, 24 | "Number": 23, 25 | "出版社": 24, 26 | "网站": 25, 27 | "目": 26, 28 | "学校": 27, 29 | "作品": 28, 30 | "O": 29 31 | } -------------------------------------------------------------------------------- /expirement_er/bertMRC/data_loader/model_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import json 6 | import copy 7 | 8 | 9 | 10 | class Config(object): 11 | 12 | @classmethod 13 | def from_dict(cls, json_object): 14 | config_instance = Config() 15 | for key, value in json_object.items(): 16 | try: 17 | tmp_value = Config.from_json_dict(value) 18 | config_instance.__dict__[key] = tmp_value 19 | except: 20 | config_instance.__dict__[key] = value 21 | return config_instance 22 | 23 | @classmethod 24 | def 
from_json_file(cls, json_file): 25 | with open(json_file, "r") as f: 26 | text = f.read() 27 | return Config.from_dict(json.loads(text)) 28 | 29 | 30 | @classmethod 31 | def from_json_dict(cls, json_str): 32 | return Config.from_dict(json_str) 33 | 34 | 35 | @classmethod 36 | def from_json_str(cls, json_str): 37 | return Config.from_dict(json.loads(json_str)) 38 | 39 | 40 | def to_dict(self): 41 | output = copy.deepcopy(self.__dict__) 42 | output = {k: v.to_dict() if isinstance(v, Config) else v for k, v in output.items()} 43 | return output 44 | 45 | 46 | def print_config(self): 47 | model_config = self.to_dict() 48 | json_config = json.dumps(model_config, indent=2) 49 | print(json_config) 50 | return json_config 51 | 52 | 53 | def to_json_string(self): 54 | return json.dumps(self.to_dict(), indent = 2) + "\n" 55 | 56 | 57 | def update_args(self, args_namespace): 58 | args_dict = args_namespace.__dict__ 59 | print("Please notice that merge the args_dict and json_config ... ...") 60 | for args_key, args_value in args_dict.items(): 61 | if args_key not in self.__dict__.keys(): 62 | self.__dict__[args_key] = args_value 63 | else: 64 | print("update the config from args input ... ...") 65 | self.__dict__[args_key] = args_values 66 | 67 | 68 | 69 | if __name__ == "__main__": 70 | str_instance = "{'name': 'lixiaoya'}" 71 | json_config = json.loads(str_instance) 72 | print("check the content of json_files") 73 | print(json_config.keys()) 74 | 75 | -------------------------------------------------------------------------------- /expirement_er/bertMRC/data_loader/mrc_data_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import os 6 | import torch 7 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 8 | from expirement_er.bertMRC.data_loader.mrc_utils import convert_examples_to_features 9 | 10 | 11 | class MRCNERDataLoader(object): 12 | def __init__(self, config, data_processor, label_list, tokenizer, mode="train", allow_impossible=True): 13 | 14 | self.data_dir = config.data_dir 15 | self.max_seq_length = config.max_seq_length 16 | 17 | if mode == "train": 18 | self.train_batch_size = config.train_batch_size 19 | self.dev_batch_size = config.dev_batch_size 20 | self.test_batch_size = config.test_batch_size 21 | self.num_train_epochs = config.num_train_epochs 22 | elif mode == "test": 23 | self.test_batch_size = config.test_batch_size 24 | 25 | self.data_processor = data_processor 26 | self.label_list = label_list 27 | self.allow_impossible = allow_impossible 28 | self.tokenizer = tokenizer 29 | self.max_seq_len = config.max_seq_length 30 | self.data_cache = config.data_cache 31 | 32 | self.num_train_instances = 0 33 | self.num_dev_instances = 0 34 | self.num_test_instances = 0 35 | 36 | def convert_examples_to_features(self, data_sign="train",): 37 | 38 | print("=*="*10) 39 | print("loading {} data ... 
...".format(data_sign)) 40 | 41 | if data_sign == "train": 42 | examples = self.data_processor.get_train_examples(self.data_dir) 43 | self.num_train_instances = len(examples) 44 | elif data_sign == "dev": 45 | examples = self.data_processor.get_dev_examples(self.data_dir) 46 | self.num_dev_instances = len(examples) 47 | elif data_sign == "test": 48 | examples = self.data_processor.get_test_examples(self.data_dir) 49 | self.num_test_instances = len(examples) 50 | else: 51 | raise ValueError("please notice that the data_sign can only be train/dev/test !!") 52 | 53 | cache_path = os.path.join(self.data_dir, "mrc-ner.{}.cache.{}".format(data_sign, str(self.max_seq_len))) 54 | if os.path.exists(cache_path) and self.data_cache: 55 | features = torch.load(cache_path) 56 | else: 57 | features = convert_examples_to_features(examples, self.tokenizer, self.label_list, self.max_seq_length, allow_impossible=self.allow_impossible) 58 | if self.data_cache: 59 | torch.save(features, cache_path) 60 | return features 61 | 62 | 63 | def get_dataloader(self, data_sign="train"): 64 | 65 | features = self.convert_examples_to_features(data_sign=data_sign) 66 | 67 | print(f"{len(features)} {data_sign} data loaded") 68 | input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 69 | input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 70 | output_mask = torch.tensor([f.output_mask for f in features], dtype=torch.long) 71 | segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 72 | start_pos = torch.tensor([f.start_position for f in features], dtype=torch.long) 73 | end_pos = torch.tensor([f.end_position for f in features], dtype=torch.long) 74 | span_pos = torch.tensor([f.span_position for f in features], dtype=torch.long) 75 | ner_cate = torch.tensor([f.ner_cate for f in features], dtype=torch.long) 76 | dataset = TensorDataset(input_ids, input_mask, output_mask, segment_ids, start_pos, end_pos, span_pos, ner_cate) 77 | 78 | if data_sign == "train": 79 | datasampler = SequentialSampler(dataset) # RandomSampler(dataset) 80 | dataloader = DataLoader(dataset, sampler=datasampler, batch_size=self.train_batch_size) 81 | elif data_sign == "dev": 82 | datasampler = SequentialSampler(dataset) 83 | dataloader = DataLoader(dataset, sampler=datasampler, batch_size=self.dev_batch_size) 84 | elif data_sign == "test": 85 | datasampler = SequentialSampler(dataset) 86 | dataloader = DataLoader(dataset, sampler=datasampler, batch_size=self.test_batch_size) 87 | 88 | return dataloader 89 | 90 | 91 | def get_num_train_epochs(self, ): 92 | return int((self.num_train_instances / self.train_batch_size) * self.num_train_epochs) 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /expirement_er/bertMRC/data_loader/mrc_data_processor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | # Author: Xiaoy LI 6 | # Description: 7 | # mrc_ner_data_processor.py 8 | 9 | 10 | import os 11 | from expirement_er.bertMRC.data_loader.mrc_utils import read_mrc_ner_examples 12 | from expirement_er.bertMRC.data_loader.entities_type import * 13 | 14 | 15 | class QueryNERProcessor(object): 16 | # processor for the query-based ner dataset 17 | def get_train_examples(self, data_dir): 18 | data = read_mrc_ner_examples(os.path.join(data_dir, "train_.json")) 19 | return data 20 | 21 | def get_dev_examples(self, data_dir): 
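        # Note: the dev and test splits both read "dev_.json" here; only the training
        # split has a separate file ("train_.json").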
22 | return read_mrc_ner_examples(os.path.join(data_dir, "dev_.json")) 23 | 24 | def get_test_examples(self, data_dir): 25 | return read_mrc_ner_examples(os.path.join(data_dir, "dev_.json")) 26 | 27 | 28 | class Baidu19Processor(QueryNERProcessor): 29 | def get_labels(self, ): 30 | return baidu10_type2id.keys() 31 | 32 | 33 | class YanbaoProcessor(QueryNERProcessor): 34 | def get_labels(self, ): 35 | return ["NS", "NR", "NT", "O"] 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /expirement_er/bertMRC/layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_er/bertMRC/layer/__init__.py -------------------------------------------------------------------------------- /expirement_er/bertMRC/layer/bert_layernorm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # Author: Xiaoy LI 7 | # Description: 8 | # bert_layernorm.py 9 | 10 | 11 | import torch 12 | from torch import nn 13 | 14 | 15 | 16 | class BertLayerNorm(nn.Module): 17 | def __init__(self, hidden_size, eps=1e-12): 18 | # construct a layernorm module in the TF style 19 | # epsilon inside the square are not 20 | super(BertLayerNorm, self).__init__() 21 | self.weight = nn.Parameter(torch.ones(hidden_size)) 22 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 23 | self.variance_epsilon = eps 24 | 25 | 26 | def forward(self, x): 27 | u = x.mean(-1, keepdim=True) 28 | s = (x - u).pow(2).mean(-1, keepdim=True) 29 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 30 | return self.weight * x + self.bias -------------------------------------------------------------------------------- /expirement_er/bertMRC/layer/classifier.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # Author: Xiaoy LI 7 | # Description: 8 | # 9 | 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | 16 | class SingleLinearClassifier(nn.Module): 17 | def __init__(self, hidden_size, num_label): 18 | super(SingleLinearClassifier, self).__init__() 19 | self.num_label = num_label 20 | self.classifier = nn.Linear(hidden_size, num_label) 21 | 22 | def forward(self, input_features): 23 | features_output = self.classifier(input_features) 24 | 25 | return features_output 26 | 27 | 28 | class MultiNonLinearClassifier(nn.Module): 29 | def __init__(self, hidden_size, num_label, dropout_rate): 30 | super(MultiNonLinearClassifier, self).__init__() 31 | self.num_label = num_label 32 | self.classifier1 = nn.Linear(hidden_size, int(hidden_size / 2)) 33 | self.classifier2 = nn.Linear(int(hidden_size/2), num_label) 34 | self.dropout = nn.Dropout(dropout_rate) 35 | 36 | def forward(self, input_features): 37 | features_output1 = self.classifier1(input_features) 38 | features_output1 = nn.ReLU()(features_output1) 39 | features_output1 = self.dropout(features_output1) 40 | features_output2 = self.classifier2(features_output1) 41 | return features_output2 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /expirement_er/bertMRC/layer/optim.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 
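# Optimizer utilities: a linear learning-rate decay helper, a linear warmup/decay
# LambdaLR schedule, and an AdamW optimizer with decoupled weight decay
# (adapted from Hugging Face, as the author note below indicates).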
3 | 4 | 5 | 6 | # Author: hugging face -- transformer 7 | # description: 8 | # 9 | 10 | 11 | import math 12 | import torch 13 | from torch.optim import Optimizer 14 | from torch.optim.lr_scheduler import LambdaLR 15 | 16 | 17 | def lr_linear_decay(optimizer, decay_rate=0.95): 18 | for param_group in optimizer.param_groups: 19 | param_group["lr"] = param_group["lr"]*decay_rate 20 | print("current learning rate", param_group["lr"]) 21 | 22 | 23 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 24 | def lr_lambda(current_step): 25 | if current_step < num_warmup_steps: 26 | return float(current_step) / float(max(1, num_warmup_steps)) 27 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) 28 | 29 | return LambdaLR(optimizer, lr_lambda, last_epoch) 30 | 31 | 32 | 33 | class AdamW(Optimizer): 34 | """ Implements Adam algorithm with weight decay fix. 35 | Parameters: 36 | lr (float): learning rate. Default 1e-3. 37 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 38 | eps (float): Adams epsilon. Default: 1e-6 39 | weight_decay (float): Weight decay. Default: 0.0 40 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 41 | """ 42 | 43 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): 44 | if lr < 0.0: 45 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 46 | if not 0.0 <= betas[0] < 1.0: 47 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 48 | if not 0.0 <= betas[1] < 1.0: 49 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 50 | if not 0.0 <= eps: 51 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 52 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias) 53 | super().__init__(params, defaults) 54 | 55 | def step(self, closure=None): 56 | """Performs a single optimization step. 57 | Arguments: 58 | closure (callable, optional): A closure that reevaluates the model 59 | and returns the loss. 
60 | """ 61 | loss = None 62 | if closure is not None: 63 | loss = closure() 64 | 65 | for group in self.param_groups: 66 | for p in group["params"]: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | if grad.is_sparse: 71 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 72 | 73 | state = self.state[p] 74 | 75 | # State initialization 76 | if len(state) == 0: 77 | state["step"] = 0 78 | # Exponential moving average of gradient values 79 | state["exp_avg"] = torch.zeros_like(p.data) 80 | # Exponential moving average of squared gradient values 81 | state["exp_avg_sq"] = torch.zeros_like(p.data) 82 | 83 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 84 | beta1, beta2 = group["betas"] 85 | 86 | state["step"] += 1 87 | 88 | # Decay the first and second moment running average coefficient 89 | # In-place operations to update the averages at the same time 90 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 91 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 92 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 93 | 94 | step_size = group["lr"] 95 | if group["correct_bias"]: # No bias correction for Bert 96 | bias_correction1 = 1.0 - beta1 ** state["step"] 97 | bias_correction2 = 1.0 - beta2 ** state["step"] 98 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 99 | 100 | p.data.addcdiv_(-step_size, exp_avg, denom) 101 | 102 | # Just adding the square of the weights to the loss function is *not* 103 | # the correct way of using L2 regularization/weight decay with Adam, 104 | # since that will interact with the m and v parameters in strange ways. 105 | # 106 | # Instead we want to decay the weights in a manner that doesn't interact 107 | # with the m/v parameters. This is equivalent to adding the square 108 | # of the weights to the loss with plain (non-momentum) SGD. 109 | # Add weight decay at the end (fixed version) 110 | if group["weight_decay"] > 0.0: 111 | p.data.add_(-group["lr"] * group["weight_decay"], p.data) 112 | 113 | return loss 114 | -------------------------------------------------------------------------------- /expirement_er/bertMRC/metric/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_er/bertMRC/metric/__init__.py -------------------------------------------------------------------------------- /expirement_er/bertMRC/metric/flat_span_f1.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: Yuxian Meng 4 | @contact: yuxian_meng@shannonai.com 5 | 6 | @version: 1.0 7 | @file: span_f1 8 | """ 9 | 10 | import json 11 | from typing import List, Set, Tuple 12 | 13 | 14 | def mask_span_f1(batch_preds, batch_labels, batch_masks=None, label_list: List[str] = None, 15 | output_path = None): 16 | """ 17 | compute span-based F1 18 | Args: 19 | batch_preds: predication . [batch, length] 20 | batch_labels: ground truth. [batch, length] 21 | label_list: label_list[idx] = label_idx. 
one label for every position 22 | batch_masks: [batch, length] 23 | 24 | Returns: 25 | span-based f1 26 | 27 | Examples: 28 | >>> label_list = ["B-W", "M-W", "E-W", "S-W", "O"] 29 | >>> batch_golden = [[0, 1, 2, 3, 4], [0, 2, 4]] 30 | >>> batch_preds = [[0, 1, 2, 3, 4], [4, 4, 4]] 31 | >>> metric_dic = mask_span_f1(batch_preds=batch_preds, batch_labels=batch_golden, label_list=label_list) 32 | """ 33 | fake_term = "一" 34 | true_positives = 0 35 | false_positives = 0 36 | false_negatives = 0 37 | 38 | if batch_masks is None: 39 | batch_masks = [None] * len(batch_preds) 40 | 41 | outputs = [] 42 | 43 | for preds, labels, masks in zip(batch_preds, batch_labels, batch_masks): 44 | if masks is not None: 45 | preds = trunc_by_mask(preds, masks) 46 | labels = trunc_by_mask(labels, masks) 47 | 48 | preds = [label_list[idx] if idx < len(label_list) else "O" for idx in preds] 49 | labels = [label_list[idx] for idx in labels] 50 | 51 | pred_tags: List[Tag] = bmes_decode(char_label_list=[(fake_term, pred) for pred in preds])[1] 52 | golden_tags: List[Tag] = bmes_decode(char_label_list=[(fake_term, label) for label in labels])[1] 53 | 54 | pred_set: Set[Tuple] = set((tag.begin, tag.end, tag.tag) for tag in pred_tags) 55 | golden_set: Set[Tuple] = set((tag.begin, tag.end, tag.tag) for tag in golden_tags) 56 | pred_tags = sorted([list(s) for s in pred_set], key=lambda x: x[0]) 57 | golden_tags = sorted([list(s) for s in golden_set], key=lambda x: x[0]) 58 | outputs.append( 59 | { 60 | "preds": " ".join(preds), 61 | "golden": " ".join(labels), 62 | "pred_tags:": "|".join(" ".join(str(s) for s in tag) for tag in pred_tags), 63 | "gold_tags:": "|".join(" ".join(str(s) for s in tag) for tag in golden_tags) 64 | } 65 | ) 66 | 67 | for pred in pred_set: 68 | if pred in golden_set: 69 | true_positives += 1 70 | else: 71 | false_positives += 1 72 | 73 | for pred in golden_set: 74 | if pred not in pred_set: 75 | false_negatives += 1 76 | 77 | precision = true_positives / (true_positives + false_positives + 1e-10) 78 | recall = true_positives / (true_positives + false_negatives + 1e-10) 79 | f1 = 2 * precision * recall / (precision + recall + 1e-10) 80 | 81 | if output_path: 82 | json.dump(outputs, open(output_path, "w"), indent=4, sort_keys=True, ensure_ascii=False) 83 | print(f"Wrote visualization to {output_path}") 84 | 85 | return { 86 | "span-precision": precision, 87 | "span-recall": recall, 88 | "span-f1": f1 89 | } 90 | 91 | 92 | def trunc_by_mask(lst: List, masks: List) -> List: 93 | """mask according to truncate lst""" 94 | out = [] 95 | for item, mask in zip(lst, masks): 96 | if mask: 97 | out.append(item) 98 | return out 99 | 100 | 101 | 102 | class Tag(object): 103 | def __init__(self, term, tag, begin, end): 104 | self.term = term 105 | self.tag = tag 106 | self.begin = begin 107 | self.end = end 108 | 109 | def to_tuple(self): 110 | return tuple([self.term, self.begin, self.end]) 111 | 112 | def __str__(self): 113 | return str({key: value for key, value in self.__dict__.items()}) 114 | 115 | def __repr__(self): 116 | return str({key: value for key, value in self.__dict__.items()}) 117 | 118 | 119 | def bmes_decode(char_label_list: List[Tuple[str, str]]) -> Tuple[str, List[Tag]]: 120 | idx = 0 121 | length = len(char_label_list) 122 | tags = [] 123 | while idx < length: 124 | term, label = char_label_list[idx] 125 | current_label = label[0] 126 | 127 | # correct labels 128 | if idx + 1 == length and current_label == "B": 129 | current_label = "S" 130 | 131 | # merge chars 132 | if current_label == 
"O": 133 | idx += 1 134 | continue 135 | if current_label == "S": 136 | tags.append(Tag(term, label[2:], idx, idx + 1)) 137 | idx += 1 138 | continue 139 | if current_label == "B": 140 | end = idx + 1 141 | while end + 1 < length and char_label_list[end][1][0] == "M": 142 | end += 1 143 | if char_label_list[end][1][0] == "E": # end with E 144 | entity = "".join(char_label_list[i][0] for i in range(idx, end + 1)) 145 | tags.append(Tag(entity, label[2:], idx, end + 1)) 146 | idx = end + 1 147 | else: # end with M/B 148 | entity = "".join(char_label_list[i][0] for i in range(idx, end)) 149 | tags.append(Tag(entity, label[2:], idx, end)) 150 | idx = end 151 | continue 152 | else: 153 | idx += 1 154 | continue 155 | 156 | sentence = "".join(term for term, _ in char_label_list) 157 | return sentence, tags 158 | 159 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /expirement_er/bertMRC/metric/nest_span_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # Author: xiaoyli 7 | # description: 8 | # 9 | 10 | 11 | 12 | class Tag(object): 13 | def __init__(self, tag, begin, end): 14 | self.tag = tag 15 | self.begin = begin 16 | self.end = end 17 | 18 | def to_tuple(self): 19 | return tuple([self.tag, self.begin, self.end]) 20 | 21 | def __str__(self): 22 | return str({key: value for key, value in self.__dict__.items()}) 23 | 24 | def __repr__(self): 25 | return str({key: value for key, value in self.__dict__.items()}) 26 | 27 | 28 | def nested_calculate_f1(pred_span_tag_lst, gold_span_tag_lst, dims=2): 29 | if dims == 2: 30 | true_positives = 0 31 | false_positives = 0 32 | false_negatives = 0 33 | 34 | for pred_span_tags, gold_span_tags in zip(pred_span_tag_lst, gold_span_tag_lst): 35 | pred_set = set((tag.begin, tag.end, tag.tag) for tag in pred_span_tags) 36 | gold_set = set((tag.begin, tag.end, tag.tag) for tag in gold_span_tags) 37 | 38 | for pred in pred_set: 39 | if pred in gold_set: 40 | true_positives += 1 41 | else: 42 | false_positives += 1 43 | 44 | for pred in gold_set: 45 | if pred not in pred_set: 46 | false_negatives += 1 47 | 48 | 49 | precision = true_positives / (true_positives + false_positives + 1e-10) 50 | recall = true_positives / (true_positives + false_negatives + 1e-10) 51 | f1 = 2 * precision * recall / (precision + recall + 1e-10) 52 | 53 | return precision, recall, f1 54 | 55 | else: 56 | raise ValueError("Can not be other number except 2 !") 57 | 58 | 59 | -------------------------------------------------------------------------------- /expirement_er/bertMRC/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_er/bertMRC/model/__init__.py -------------------------------------------------------------------------------- /expirement_er/bertMRC/model/bert_mrc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | 6 | # Author: Xiaoy LI 7 | # Description: 8 | # Bert Model for MRC-Based NER Task 9 | 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import CrossEntropyLoss 14 | 15 | 16 | from expirement_er.bertMRC.layer.classifier import MultiNonLinearClassifier 17 | from expirement_er.bertMRC.layer.bert_basic_model import BertModel, BertConfig 18 | 
19 | 20 | class BertQueryNER(nn.Module): 21 | def __init__(self, config): 22 | super(BertQueryNER, self).__init__() 23 | bert_config = BertConfig.from_dict(config.bert_config.to_dict()) 24 | self.bert = BertModel(bert_config) 25 | 26 | self.start_outputs = nn.Linear(config.hidden_size, 2) 27 | self.end_outputs = nn.Linear(config.hidden_size, 2) 28 | 29 | self.dropout = nn.Dropout(config.dropout) 30 | 31 | self.span_embedding = MultiNonLinearClassifier(config.hidden_size*2, 1, config.dropout) 32 | self.hidden_size = config.hidden_size 33 | self.bert = self.bert.from_pretrained(config.bert_model) 34 | self.loss_wb = config.weight_start 35 | self.loss_we = config.weight_end 36 | self.loss_ws = config.weight_span 37 | 38 | 39 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, 40 | start_positions=None, end_positions=None, span_positions=None, output_mask=None): 41 | """ 42 | Args: 43 | start_positions: (batch x max_len x 1) 44 | [[0, 1, 0, 0, 1, 0, 1, 0, 0, ], [0, 1, 0, 0, 1, 0, 1, 0, 0, ]] 45 | end_positions: (batch x max_len x 1) 46 | [[0, 1, 0, 0, 1, 0, 1, 0, 0, ], [0, 1, 0, 0, 1, 0, 1, 0, 0, ]] 47 | span_positions: (batch x max_len x max_len) 48 | span_positions[k][i][j] is one of [0, 1], 49 | span_positions[k][i][j] represents whether or not from start_pos{i} to end_pos{j} of the K-th sentence in the batch is an entity. 50 | :param output_mask: 51 | """ 52 | 53 | sequence_output, pooled_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 54 | 55 | sequence_heatmap = sequence_output # batch x seq_len x hidden 56 | batch_size, seq_len, hid_size = sequence_heatmap.size() 57 | 58 | start_logits = self.start_outputs(sequence_heatmap) # batch x seq_len x 2 59 | end_logits = self.end_outputs(sequence_heatmap) # batch x seq_len x 2 60 | 61 | # start_logits = self.dropout(start_logits) 62 | # end_logits = self.dropout(end_logits) 63 | 64 | # for every position $i$ in sequence, should concate $j$ to 65 | # predict if $i$ and $j$ are start_pos and end_pos for an entity. 
66 | start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1) 67 | end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1) 68 | # the shape of start_end_concat[0] is : batch x 1 x seq_len x 2*hidden 69 | 70 | span_matrix = torch.cat([start_extend, end_extend], 3) # batch x seq_len x seq_len x 2*hidden 71 | 72 | span_logits = self.span_embedding(span_matrix) # batch x seq_len x seq_len x 1 73 | span_logits = torch.squeeze(span_logits) # batch x seq_len x seq_len 74 | 75 | if start_positions is not None and end_positions is not None: 76 | loss_fct = CrossEntropyLoss() 77 | start_loss = loss_fct(start_logits.view(-1, 2), start_positions.view(-1)) 78 | end_loss = loss_fct(end_logits.view(-1, 2), end_positions.view(-1)) 79 | 80 | span_loss_fct = nn.BCEWithLogitsLoss() 81 | span_loss = span_loss_fct(span_logits.view(batch_size, -1), span_positions.view(batch_size, -1).float()) 82 | total_loss = self.loss_wb * start_loss + self.loss_we * end_loss + self.loss_ws * span_loss 83 | return total_loss 84 | else: 85 | span_logits = torch.sigmoid(span_logits) # batch x seq_len x seq_len 86 | start_logits = torch.argmax(start_logits, dim=-1) 87 | end_logits = torch.argmax(end_logits, dim=-1) 88 | 89 | return start_logits, end_logits, span_logits 90 | 91 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run evaluation with saved models. 3 | """ 4 | 5 | import os 6 | import random 7 | import argparse 8 | import pickle 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.autograd import Variable 13 | import json 14 | from expirement_er.bertcrf.models.model import RelationModel 15 | from expirement_er.bertcrf.utils import torch_utils, helper, result 16 | from expirement_er.bertcrf.utils.vocab import Vocab 17 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM 18 | 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('model_dir', type=str, help='Directory of the model.') 23 | parser.add_argument('--model', type=str, default='best_model.pt', help='Name of the model file.') 24 | parser.add_argument('--data_dir', type=str, default='dataset/NYT-multi/data') 25 | parser.add_argument('--dataset', type=str, default='test', help="Evaluate on dev or test.") 26 | 27 | parser.add_argument('--seed', type=int, default=42) 28 | parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available()) 29 | parser.add_argument('--cpu', action='store_true') 30 | args = parser.parse_args() 31 | 32 | torch.manual_seed(args.seed) 33 | random.seed(args.seed) 34 | if args.cpu: 35 | args.cuda = False 36 | elif args.cuda: 37 | torch.cuda.manual_seed(args.seed) 38 | 39 | # load opt 40 | model_file = args.model_dir + '/' + args.model 41 | print("Loading model from {}".format(model_file)) 42 | opt = torch_utils.load_config(model_file) 43 | model = RelationModel(opt) 44 | model.load(model_file) 45 | 46 | 47 | 48 | # load data 49 | data_file = args.data_dir + '/{}.json'.format(args.dataset) 50 | print("Loading data from {} with batch size {}...".format(data_file, opt['batch_size'])) 51 | data = json.load(open(data_file)) 52 | 53 | 
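# Note added for readability (structure inferred from the unpacking on the next
# line, not verified against the dataset): schemas.json is expected to contain a
# JSON array of six mappings,
#   [id2predicate, predicate2id, id2subj_type, subj_type2id, id2obj_type, obj_type2id],
# and the string keys of id2predicate are converted back to int immediately below.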
id2predicate, predicate2id, id2subj_type, subj_type2id, id2obj_type, obj_type2id = json.load(open(opt['data_dir'] + '/schemas.json')) 54 | id2predicate = {int(i):j for i,j in id2predicate.items()} 55 | 56 | # 加载bert词表 57 | UNCASED = args.bert_model # your path for model and vocab 58 | VOCAB = 'bert-base-chinese-vocab.txt' 59 | bert_tokenizer = BertTokenizer.from_pretrained(os.path.join(UNCASED,VOCAB)) 60 | 61 | helper.print_config(opt) 62 | 63 | results = result.evaluate(bert_tokenizer, data, id2predicate, model) 64 | results_save_dir = opt['model_save_dir'] + '/results.json' 65 | print("Dumping the best test results to {}".format(results_save_dir)) 66 | 67 | with open(results_save_dir, 'w') as fw: 68 | json.dump(results, fw, indent=4, ensure_ascii=False) 69 | 70 | print("Evaluation ended.") 71 | 72 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/models/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/models/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | A joint model for relation extraction, written in pytorch. 3 | """ 4 | import math 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from torch.nn import init 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | import pdb 12 | from expirement_er.bertcrf.utils import torch_utils, loader 13 | from expirement_er.bertcrf.models import submodel 14 | 15 | 16 | class RelationModel(object): 17 | """ A wrapper class for the training and evaluation of models. """ 18 | def __init__(self, opt, bert_model): 19 | self.opt = opt 20 | self.bert_model = bert_model 21 | self.model = BiLSTMCNN(opt, bert_model) 22 | self.parameters = [p for p in self.model.parameters() if p.requires_grad] 23 | if opt['cuda']: 24 | self.model.cuda() 25 | self.optimizer = torch_utils.get_optimizer(opt['optim'], self.parameters, opt['lr'], opt['weight_decay']) 26 | 27 | 28 | def update(self, batch): 29 | """ Run a step of forward and backward model update. """ 30 | if self.opt['cuda']: 31 | inputs = Variable(torch.LongTensor(batch[0]).cuda()) 32 | tags = Variable(torch.LongTensor(batch[1]).cuda()) 33 | 34 | mask = (inputs.data>0).float() 35 | self.model.train() 36 | self.optimizer.zero_grad() 37 | 38 | loss = self.model(inputs, tags, mask) 39 | 40 | loss.backward() 41 | # torch.nn.utils.clip_grad_norm(self.model.parameters(), self.opt['max_grad_norm']) 42 | self.optimizer.step() 43 | loss_val = loss.data.item() 44 | return loss_val 45 | 46 | 47 | def predict_subj_per_instance(self, words): 48 | 49 | if self.opt['cuda']: 50 | words = Variable(torch.LongTensor(words).cuda()) 51 | mask = (words.data > 0).float() 52 | 53 | # forward 54 | self.model.eval() 55 | hidden = self.model.based_encoder(words) 56 | best_tags_list = self.model.subj_sublayer.predict_subj_start(hidden, mask) 57 | 58 | return best_tags_list 59 | 60 | 61 | def update_lr(self, new_lr): 62 | torch_utils.change_lr(self.optimizer, new_lr) 63 | 64 | def save(self, filename, epoch): 65 | params = { 66 | 'model': self.model.state_dict(), 67 | 'config': self.opt, 68 | 'epoch': epoch 69 | } 70 | try: 71 | torch.save(params, filename) 72 | print("model saved to {}".format(filename)) 73 | except BaseException: 74 | print("[Warning: Saving failed... 
continuing anyway.]") 75 | 76 | def load(self, filename): 77 | try: 78 | checkpoint = torch.load(filename) 79 | except BaseException: 80 | print("Cannot load model from {}".format(filename)) 81 | exit() 82 | self.model.load_state_dict(checkpoint['model']) 83 | self.opt = checkpoint['config'] 84 | 85 | class BiLSTMCNN(nn.Module): 86 | """ A sequence model for relation extraction. """ 87 | 88 | def __init__(self, opt, bert_model): 89 | super(BiLSTMCNN, self).__init__() 90 | self.drop = nn.Dropout(opt['dropout']) 91 | self.input_size = opt['word_emb_dim'] 92 | self.subj_sublayer = submodel.SubjTypeModel(opt) 93 | self.opt = opt 94 | self.bert_model = bert_model 95 | self.topn = self.opt.get('topn', 1e10) 96 | self.use_cuda = opt['cuda'] 97 | 98 | def based_encoder(self, words): 99 | hidden, _ = self.bert_model(words, output_all_encoded_layers=False) 100 | return hidden 101 | 102 | def forward(self, inputs, tags, mask): 103 | 104 | hidden = self.based_encoder(inputs) 105 | loss = self.subj_sublayer(hidden, tags, mask) 106 | return loss 107 | 108 | 109 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/models/mycrf.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | PAD = "" # padding 5 | SOS = "" # start of sequence 6 | EOS = "" # end of sequence 7 | UNK = "" # unknown token 8 | 9 | PAD_IDX = 59 10 | SOS_IDX = 0 11 | EOS_IDX = 1 12 | UNK_IDX = 60 13 | 14 | CUDA = torch.cuda.is_available() 15 | torch.manual_seed(0) # for reproducibility 16 | # torch.cuda.set_device(0) 17 | 18 | Tensor = torch.cuda.FloatTensor if CUDA else torch.FloatTensor 19 | LongTensor = torch.cuda.LongTensor if CUDA else torch.LongTensor 20 | randn = lambda *x: torch.randn(*x).cuda() if CUDA else torch.randn 21 | zeros = lambda *x: torch.zeros(*x).cuda() if CUDA else torch.zeros 22 | 23 | def log_sum_exp(x): 24 | m = torch.max(x, -1)[0] 25 | return m + torch.log(torch.sum(torch.exp(x - m.unsqueeze(-1)), -1)) 26 | 27 | 28 | class mycrf(nn.Module): 29 | def __init__(self, num_tags): 30 | super(mycrf, self).__init__() 31 | self.batch_size = 1 32 | self.num_tags = num_tags 33 | 34 | # matrix of transition scores from j to i 35 | self.trans = nn.Parameter(randn(num_tags, num_tags)) 36 | self.trans.data[SOS_IDX, :] = -10000 # no transition to SOS 37 | self.trans.data[:, EOS_IDX] = -10000 # no transition from EOS except to PAD 38 | self.trans.data[:, PAD_IDX] = -10000 # no transition from PAD except to PAD 39 | self.trans.data[PAD_IDX, :] = -10000 # no transition to PAD except from EOS 40 | self.trans.data[PAD_IDX, EOS_IDX] = 0 41 | self.trans.data[PAD_IDX, PAD_IDX] = 0 42 | 43 | def forward(self, h, mask): # forward algorithm 44 | # initialize forward variables in log space 45 | score = Tensor(self.batch_size, self.num_tags).fill_(-10000) # [B, C] 46 | score[:, SOS_IDX] = 0. 47 | trans = self.trans.unsqueeze(0) # [1, C, C] 48 | for t in range(h.size(1)): # recursion through the sequence 49 | mask_t = mask[:, t].unsqueeze(1) 50 | emit_t = h[:, t].unsqueeze(2) # [B, C, 1] 51 | score_t = score.unsqueeze(1) + emit_t + trans # [B, 1, C] -> [B, C, C] 52 | score_t = log_sum_exp(score_t) # [B, C, C] -> [B, C] 53 | score = score_t * mask_t + score * (1 - mask_t) 54 | score = log_sum_exp(score + self.trans[EOS_IDX]) 55 | return score # partition function 56 | 57 | def score(self, h, y0, mask): # calculate the score of a given sequence 58 | score = Tensor(self.batch_size).fill_(0.) 
59 | h = h.unsqueeze(3) 60 | trans = self.trans.unsqueeze(2) 61 | for t in range(h.size(1)-1): # recursion through the sequence 62 | mask_t = mask[:, t] 63 | emit_t = torch.cat([h[t, y0[t + 1]] for h, y0 in zip(h, y0)]) 64 | trans_t = torch.cat([trans[y0[t + 1], y0[t]] for y0 in y0]) 65 | score += (emit_t + trans_t) * mask_t 66 | # last_tag = y0.gather(1, mask.sum(1).long().unsqueeze(1)) 67 | # last_tag = last_tag.squeeze(1) 68 | # score += self.trans[EOS_IDX, last_tag] 69 | return score 70 | 71 | def decode(self, h, mask): # Viterbi decoding 72 | # initialize backpointers and viterbi variables in log space 73 | bptr = LongTensor() 74 | score = Tensor(self.batch_size, self.num_tags).fill_(-10000) 75 | score[:, SOS_IDX] = 0. 76 | 77 | for t in range(h.size(1)-1): # recursion through the sequence 78 | mask_t = mask[:, t].unsqueeze(1) 79 | score_t = score.unsqueeze(1) + self.trans # [B, 1, C] -> [B, C, C] 80 | score_t, bptr_t = score_t.max(2) # best previous scores and tags 81 | score_t += h[:, t] # plus emission scores 82 | bptr = torch.cat((bptr, bptr_t.unsqueeze(1)), 1) 83 | score = score_t * mask_t + score * (1 - mask_t) 84 | score += self.trans[EOS_IDX] 85 | best_score, best_tag = torch.max(score, 1) 86 | 87 | # back-tracking 88 | bptr = bptr.tolist() 89 | best_path = [[i] for i in best_tag.tolist()] 90 | for b in range(self.batch_size): 91 | i = best_tag[b] # best tag 92 | j = int(mask[b].sum().item()) 93 | for bptr_t in reversed(bptr[b][:j]): 94 | i = bptr_t[i] 95 | best_path[b].append(i) 96 | best_path[b].pop() 97 | best_path[b].reverse() 98 | 99 | return best_path 100 | 101 | 102 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/models/submodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from expirement_er.bertcrf.utils import torch_utils 7 | from pytorchcrf import CRF 8 | from expirement_er.bertcrf.models.mycrf import mycrf 9 | import numpy as np 10 | 11 | class SubjTypeModel(nn.Module): 12 | 13 | def __init__(self, opt, filter=3): 14 | super(SubjTypeModel, self).__init__() 15 | self.opt = opt 16 | self.num_tags = opt['num_tags'] 17 | 18 | self.dropout = nn.Dropout(opt['dropout']) 19 | self.hidden_dim = opt['word_emb_dim'] 20 | 21 | self.position_embedding = nn.Embedding(500, opt['position_emb_dim']) 22 | 23 | self.linear_subj = nn.Linear(self.hidden_dim, opt['num_tags']) 24 | self.init_weights() 25 | 26 | # 1.使用CRF工具包 27 | # self.crf_module = CRF(opt['num_tags'], False) 28 | 29 | # 2.使用mycrf 30 | if opt['decode_function'] == 'mycrf': 31 | self.crf_module = mycrf(opt['num_tags']) 32 | 33 | # 3.使用softmax 34 | if self.opt['decode_function'] == 'softmax': 35 | self.criterion = nn.CrossEntropyLoss(reduction='none') 36 | self.criterion.cuda() 37 | 38 | 39 | def init_weights(self): 40 | 41 | self.position_embedding.weight.data.uniform_(-1.0, 1.0) 42 | self.linear_subj.bias.data.fill_(0) 43 | init.xavier_uniform_(self.linear_subj.weight, gain=1) # initialize linear layer 44 | 45 | def forward(self, hidden, tags, mask): 46 | 47 | subj_start_inputs = self.dropout(hidden) 48 | subj_start_logits = self.linear_subj(subj_start_inputs) 49 | 50 | # 1.使用CRF工具包 51 | # score = self.crf_module.forward(subj_start_logits, tags=tags, mask=mask) 52 | # return score 53 | 54 | # 2.使用mycrf 55 | if self.opt['decode_function'] == 'mycrf': 56 | Z = 
self.crf_module.forward(subj_start_logits, mask) 57 | score = self.crf_module.score(subj_start_logits, tags, mask) 58 | return torch.mean(Z - score) # NLL loss 59 | 60 | # 3.使用softmax 61 | if self.opt['decode_function'] == 'softmax': 62 | # subj_start_logits = torch.softmax(subj_start_logits, dim=2) 63 | _s1 = subj_start_logits.view(-1, self.num_tags) 64 | tags = tags.view(-1).squeeze() 65 | loss = self.criterion(_s1, tags) 66 | loss = loss.view_as(mask) 67 | loss = torch.sum(loss.mul(mask.float())) / torch.sum(mask.float()) 68 | return loss 69 | 70 | def predict_subj_start(self, hidden, mask): 71 | 72 | subj_start_logits = self.linear_subj(hidden) 73 | # 1.使用CRF工具包 74 | # best_tags_list = self.crf_module.decode(subj_start_logits) 75 | # return best_tags_list 76 | 77 | # 2.使用mycrf 78 | if self.opt['decode_function'] == 'mycrf': 79 | return self.crf_module.decode(subj_start_logits, mask) 80 | 81 | # 3.使用softmax 82 | if self.opt['decode_function'] == 'softmax': 83 | # subj_start_logits = torch.softmax(subj_start_logits, dim=2) 84 | best_tags_list = subj_start_logits.cpu().detach().numpy() 85 | best_tags_list = np.argmax(best_tags_list, 2).tolist() 86 | return best_tags_list 87 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/utils/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/utils/constant.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define common constants. 3 | """ 4 | TRAIN_JSON = 'train.json' 5 | DEV_JSON = 'dev.json' 6 | TEST_JSON = 'test.json' 7 | 8 | GLOVE_DIR = 'dataset/glove' 9 | 10 | EMB_INIT_RANGE = 1.0 11 | MAX_LEN = 100 12 | 13 | # vocab 14 | PAD_TOKEN = '' 15 | PAD_ID = 0 16 | UNK_TOKEN = '' 17 | UNK_ID = 1 18 | 19 | VOCAB_PREFIX = [PAD_TOKEN, UNK_TOKEN] 20 | 21 | INFINITY_NUMBER = 1e12 22 | 23 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/utils/helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions. 3 | """ 4 | 5 | import os 6 | import json 7 | import argparse 8 | 9 | ### IO 10 | def check_dir(d): 11 | if not os.path.exists(d): 12 | print("Directory {} does not exist. Exit.".format(d)) 13 | exit(1) 14 | 15 | def check_files(files): 16 | for f in files: 17 | if f is not None and not os.path.exists(f): 18 | print("File {} does not exist. Exit.".format(f)) 19 | exit(1) 20 | 21 | def ensure_dir(d, verbose=True): 22 | if not os.path.exists(d): 23 | if verbose: 24 | print("Directory {} do not exist; creating...".format(d)) 25 | os.makedirs(d) 26 | 27 | def save_config(config, path, verbose=True): 28 | with open(path, 'w') as outfile: 29 | json.dump(config, outfile, indent=2) 30 | if verbose: 31 | print("Config saved to file {}".format(path)) 32 | return config 33 | 34 | def load_config(path, verbose=True): 35 | with open(path) as f: 36 | config = json.load(f) 37 | if verbose: 38 | print("Config loaded from file {}".format(path)) 39 | return config 40 | 41 | def print_config(config): 42 | info = "Running with the following configs:\n" 43 | for k,v in config.items(): 44 | info += "\t{} : {}\n".format(k, str(v)) 45 | print("\n" + info + "\n") 46 | return 47 | 48 | class FileLogger(object): 49 | """ 50 | A file logger that opens the file periodically and write to it. 
51 | """ 52 | def __init__(self, filename, header=None): 53 | self.filename = filename 54 | if os.path.exists(filename): 55 | # remove the old file 56 | os.remove(filename) 57 | if header is not None: 58 | with open(filename, 'w') as out: 59 | print(header, file=out) 60 | 61 | def log(self, message): 62 | with open(self.filename, 'a') as out: 63 | print(message, file=out) 64 | 65 | 66 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/utils/score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from expirement_er.bertcrf.utils import loader 4 | import json 5 | from torch import nn 6 | import torch 7 | 8 | 9 | def extract_items(opt, tokens, tokens_ids, model, idx2tag): 10 | R = [] 11 | 12 | _t = np.array(loader.seq_padding([tokens_ids])) 13 | 14 | best_tags_list_ids = model.predict_subj_per_instance(_t) 15 | # best_tags_list_ids = best_tags_list_ids.squeeze(-1).cpu().detach().numpy().tolist()[0] 16 | best_tags_list_ids = best_tags_list_ids[0] 17 | # best_tags_list = [[idx2tag[idx] for idx in idxs] for idxs in best_tags_list_ids] 18 | 19 | record = False 20 | for token_index, tag_id in enumerate(best_tags_list_ids): 21 | tag = idx2tag[tag_id] 22 | if tag.startswith('B'): 23 | start_token_index = token_index 24 | record = True 25 | elif record and tag == 'O': 26 | end_token_index = token_index 27 | str_start_index = start_token_index 28 | str_end_index = end_token_index 29 | # 使用crf时多了一个起始标签 30 | if opt['decode_function'] == 'mycrf': 31 | entity_name = tokens[str_start_index + 1: str_end_index + 1] 32 | elif opt['decode_function'] == 'softmax': 33 | entity_name = tokens[str_start_index: str_end_index] 34 | entity_name = ''.join(entity_name) 35 | R.append(entity_name) 36 | record = False 37 | # if R == []: 38 | # R.append(" ") 39 | return set(tuple(R)) 40 | 41 | 42 | def evaluate(opt, data, model, idx2tag): 43 | official_A, official_B, official_C = 1e-10, 1e-10, 1e-10 44 | manual_A, manual_B, manual_C = 1e-10, 1e-10, 1e-10 45 | 46 | # { 47 | # "sent": "半导体行情的风险是什么", 48 | # "sent_tokens": ["半", "导", "体", "行", "情", "的", "风", "险", "是", "什", "么"], 49 | # "sent_token_ids": [1288, 2193, 860, 6121, 2658, 4638, 7599, 7372, 3221, 784, 720], 50 | # "entity_labels": [{"entity_type": "研报", "start_token_id": 0, "end_token_id": 10, "start_index": 0, "end_index": 10, 51 | # "entity_tokens": ["半", "导", "体", "行", "情", "的", "风", "险", "是", "什", "么"], 52 | # "entity_name": "半导体行情的风险是什么"}], 53 | # "tags": ["B-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report"], 54 | # "tag_ids": [11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]}, 55 | 56 | results = [] 57 | index = 0 58 | for d in tqdm(iter(data)): 59 | # index += 1 60 | # if index > 200: 61 | # break 62 | if len(d['sent_token_ids']) > 180: 63 | continue 64 | R = extract_items(opt, d['sent_tokens'], d['sent_token_ids'], model, idx2tag) 65 | # official_T = set([tuple(i['entity_name']) for i in d['spo_details']]) 66 | official_T_list = [] 67 | for i in d['entity_labels']: 68 | official_T_list.append(i['entity_name']) 69 | official_T = set(tuple(official_T_list)) 70 | results.append({'text': ''.join(d['sent_tokens']), 'predict': list(R), 'truth': list(official_T)}) 71 | official_A += len(R & official_T) 72 | official_B += len(R) 73 | official_C += len(official_T) 74 | return 2 * official_A / (official_B + official_C), official_A / official_B, official_A / 
official_C, results 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for torch. 3 | """ 4 | 5 | import torch 6 | from torch import nn, optim 7 | from torch.optim import Optimizer 8 | 9 | ### class 10 | class MyAdagrad(Optimizer): 11 | """My modification of the Adagrad optimizer that allows to specify an initial 12 | accumulater value. This mimics the behavior of the default Adagrad implementation 13 | in Tensorflow. The default PyTorch Adagrad uses 0 for initial acculmulator value. 14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-2) 19 | lr_decay (float, optional): learning rate decay (default: 0) 20 | init_accu_value (float, optional): initial accumulater value. 21 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 22 | """ 23 | 24 | def __init__(self, params, lr=1e-2, lr_decay=0, init_accu_value=0.1, weight_decay=0): 25 | defaults = dict(lr=lr, lr_decay=lr_decay, init_accu_value=init_accu_value, \ 26 | weight_decay=weight_decay) 27 | super(MyAdagrad, self).__init__(params, defaults) 28 | 29 | for group in self.param_groups: 30 | for p in group['params']: 31 | state = self.state[p] 32 | state['step'] = 0 33 | state['sum'] = torch.ones(p.data.size()).type_as(p.data) *\ 34 | init_accu_value 35 | 36 | def share_memory(self): 37 | for group in self.param_groups: 38 | for p in group['params']: 39 | state = self.state[p] 40 | state['sum'].share_memory_() 41 | 42 | def step(self, closure=None): 43 | """Performs a single optimization step. 44 | 45 | Arguments: 46 | closure (callable, optional): A closure that reevaluates the model 47 | and returns the loss. 
48 | """ 49 | loss = None 50 | if closure is not None: 51 | loss = closure() 52 | 53 | for group in self.param_groups: 54 | for p in group['params']: 55 | if p.grad is None: 56 | continue 57 | 58 | grad = p.grad.data 59 | state = self.state[p] 60 | 61 | state['step'] += 1 62 | 63 | if group['weight_decay'] != 0: 64 | if p.grad.data.is_sparse: 65 | raise RuntimeError("weight_decay option is not compatible with sparse gradients ") 66 | grad = grad.add(group['weight_decay'], p.data) 67 | 68 | clr = group['lr'] / (1 + (state['step'] - 1) * group['lr_decay']) 69 | 70 | if p.grad.data.is_sparse: 71 | grad = grad.coalesce() # the update is non-linear so indices must be unique 72 | grad_indices = grad._indices() 73 | grad_values = grad._values() 74 | size = torch.Size([x for x in grad.size()]) 75 | 76 | def make_sparse(values): 77 | constructor = type(p.grad.data) 78 | if grad_indices.dim() == 0 or values.dim() == 0: 79 | return constructor() 80 | return constructor(grad_indices, values, size) 81 | state['sum'].add_(make_sparse(grad_values.pow(2))) 82 | std = state['sum']._sparse_mask(grad) 83 | std_values = std._values().sqrt_().add_(1e-10) 84 | p.data.add_(-clr, make_sparse(grad_values / std_values)) 85 | else: 86 | state['sum'].addcmul_(1, grad, grad) 87 | std = state['sum'].sqrt().add_(1e-10) 88 | p.data.addcdiv_(-clr, grad, std) 89 | 90 | return loss 91 | 92 | 93 | ### torch specific functions 94 | def get_optimizer(name, parameters, lr, weight_decay): 95 | if name == 'sgd': 96 | return torch.optim.SGD(parameters, lr=lr) 97 | elif name in ['adagrad', 'myadagrad']: 98 | # use my own adagrad to allow for init accumulator value 99 | return MyAdagrad(parameters, lr=lr, init_accu_value=0.1) 100 | elif name == 'adam': 101 | return torch.optim.Adam(parameters, lr=lr, betas=(0.9, 0.99), weight_decay=weight_decay) # use default lr 102 | elif name == 'adamax': 103 | return torch.optim.Adamax(parameters, lr=lr) # use default lr 104 | else: 105 | raise Exception("Unsupported optimizer: {}".format(name)) 106 | 107 | def change_lr(optimizer, new_lr): 108 | for param_group in optimizer.param_groups: 109 | param_group['lr'] = new_lr 110 | 111 | def flatten_indices(seq_lens, width): 112 | flat = [] 113 | for i, l in enumerate(seq_lens): 114 | for j in range(l): 115 | flat.append(i * width + j) 116 | return flat 117 | 118 | def set_cuda(var, cuda): 119 | if cuda: 120 | return var.cuda() 121 | return var 122 | 123 | def keep_partial_grad(grad, topk): 124 | """ 125 | Keep only the topk rows of grads. 126 | """ 127 | assert topk < grad.size(0) 128 | grad.data[topk:].zero_() 129 | return grad 130 | 131 | ### model IO 132 | def save(model, optimizer, opt, filename): 133 | params = { 134 | 'model': model.state_dict(), 135 | 'optimizer': optimizer.state_dict(), 136 | 'config': opt 137 | } 138 | try: 139 | torch.save(params, filename) 140 | except BaseException: 141 | print("[ Warning: model saving failed. ]") 142 | 143 | def load(model, optimizer, filename): 144 | try: 145 | dump = torch.load(filename) 146 | except BaseException: 147 | print("[ Fail: model loading failed. ]") 148 | if model is not None: 149 | model.load_state_dict(dump['model']) 150 | if optimizer is not None: 151 | optimizer.load_state_dict(dump['optimizer']) 152 | opt = dump['config'] 153 | return model, optimizer, opt 154 | 155 | def load_config(filename): 156 | try: 157 | dump = torch.load(filename) 158 | except BaseException: 159 | print("[ Fail: model loading failed. 
]") 160 | return dump['config'] 161 | 162 | 163 | -------------------------------------------------------------------------------- /expirement_er/bertcrf/utils/vocab.py: -------------------------------------------------------------------------------- 1 | """ 2 | A class for basic vocab operations. 3 | """ 4 | 5 | from __future__ import print_function 6 | import os 7 | import random 8 | import numpy as np 9 | import pickle 10 | 11 | from expirement_er.etlspan.utils import constant 12 | 13 | 14 | 15 | random.seed(1234) 16 | np.random.seed(1234) 17 | 18 | def build_embedding(wv_file, vocab, wv_dim): 19 | vocab_size = len(vocab) 20 | emb = np.random.uniform(-1, 1, (vocab_size, wv_dim)) 21 | emb[constant.PAD_ID] = 0 # should be all 0 (using broadcast) 22 | 23 | w2id = {w: i for i, w in enumerate(vocab)} 24 | with open(wv_file, encoding="utf8") as f: 25 | for line in f: 26 | elems = line.split() 27 | token = ''.join(elems[0:-wv_dim]) 28 | if token in w2id: 29 | emb[w2id[token]] = [float(v) for v in elems[-wv_dim:]] 30 | return emb 31 | 32 | def load_glove_vocab(file, wv_dim): 33 | """ 34 | Load all words from glove. 35 | """ 36 | vocab = set() 37 | with open(file, encoding='utf8') as f: 38 | try: 39 | for line_num, line in enumerate(f): 40 | elems = line.split() 41 | token = ''.join(elems[0:-wv_dim]) 42 | vocab.add(token) 43 | except Exception as e: 44 | print(line) 45 | return vocab 46 | 47 | def normalize_glove(token): 48 | mapping = {'-LRB-': '(', 49 | '-RRB-': ')', 50 | '-LSB-': '[', 51 | '-RSB-': ']', 52 | '-LCB-': '{', 53 | '-RCB-': '}'} 54 | if token in mapping: 55 | token = mapping[token] 56 | return token 57 | 58 | class Vocab(object): 59 | def __init__(self, filename, load=False, word_counter=None, threshold=0): 60 | if load: 61 | assert os.path.exists(filename), "Vocab file does not exist at " + filename 62 | # load from file and ignore all other params 63 | self.id2word, self.word2id = self.load(filename) 64 | self.size = len(self.id2word) 65 | print("Vocab size {} loaded from file".format(self.size)) 66 | else: 67 | print("Creating vocab from scratch...") 68 | assert word_counter is not None, "word_counter is not provided for vocab creation." 69 | self.word_counter = word_counter 70 | if threshold > 1: 71 | # remove words that occur less than thres 72 | self.word_counter = dict([(k,v) for k,v in self.word_counter.items() if v >= threshold]) 73 | self.id2word = sorted(self.word_counter, key=lambda k:self.word_counter[k], reverse=True) 74 | # add special tokens to the beginning 75 | self.id2word = ['**PAD**', '**UNK**'] + self.id2word 76 | self.word2id = dict([(self.id2word[idx],idx) for idx in range(len(self.id2word))]) 77 | self.size = len(self.id2word) 78 | self.save(filename) 79 | print("Vocab size {} saved to file {}".format(self.size, filename)) 80 | 81 | def load(self, filename): 82 | with open(filename, 'rb') as infile: 83 | id2word = pickle.load(infile) 84 | word2id = dict([(id2word[idx], idx) for idx in range(len(id2word))]) 85 | return id2word, word2id 86 | 87 | def save(self, filename): 88 | #assert not os.path.exists(filename), "Cannot save vocab: file exists at " + filename 89 | if os.path.exists(filename): 90 | print("Overwriting old vocab file at " + filename) 91 | os.remove(filename) 92 | with open(filename, 'wb') as outfile: 93 | pickle.dump(self.id2word, outfile) 94 | return 95 | 96 | def map(self, token_list): 97 | """ 98 | Map a list of tokens to their ids. 
99 | """ 100 | return [self.word2id[w] if w in self.word2id else constant.VOCAB_UNK_ID for w in token_list] 101 | 102 | def unmap(self, idx_list): 103 | """ 104 | Unmap ids back to tokens. 105 | """ 106 | return [self.id2word[idx] for idx in idx_list] 107 | 108 | def get_embeddings(self, word_vectors=None, dim=100): 109 | #self.embeddings = 2 * constant.EMB_INIT_RANGE * np.random.rand(self.size, dim) - constant.EMB_INIT_RANGE 110 | self.embeddings = np.zeros((self.size, dim)) 111 | if word_vectors is not None: 112 | assert len(list(word_vectors.values())[0]) == dim, \ 113 | "Word vectors does not have required dimension {}.".format(dim) 114 | for w, idx in self.word2id.items(): 115 | if w in word_vectors: 116 | self.embeddings[idx] = np.asarray(word_vectors[w]) 117 | return self.embeddings 118 | 119 | 120 | -------------------------------------------------------------------------------- /expirement_er/etlspan/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | -------------------------------------------------------------------------------- /expirement_er/etlspan/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run evaluation with saved models. 3 | """ 4 | 5 | import os 6 | import random 7 | import argparse 8 | import pickle 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.autograd import Variable 13 | import json 14 | from expirement_er.etlspan.models.model import RelationModel 15 | from expirement_er.etlspan.utils import torch_utils, helper, result 16 | from expirement_er.etlspan.utils.vocab import Vocab 17 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('model_dir', type=str, help='Directory of the model.') 22 | parser.add_argument('--model', type=str, default='best_model.pt', help='Name of the model file.') 23 | parser.add_argument('--data_dir', type=str, default='dataset/NYT-multi/data') 24 | parser.add_argument('--dataset', type=str, default='test', help="Evaluate on dev or test.") 25 | 26 | parser.add_argument('--seed', type=int, default=42) 27 | parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available()) 28 | parser.add_argument('--cpu', action='store_true') 29 | args = parser.parse_args() 30 | 31 | torch.manual_seed(args.seed) 32 | random.seed(args.seed) 33 | if args.cpu: 34 | args.cuda = False 35 | elif args.cuda: 36 | torch.cuda.manual_seed(args.seed) 37 | 38 | # load opt 39 | model_file = args.model_dir + '/' + args.model 40 | print("Loading model from {}".format(model_file)) 41 | opt = torch_utils.load_config(model_file) 42 | model = RelationModel(opt) 43 | model.load(model_file) 44 | 45 | 46 | 47 | # load data 48 | data_file = args.data_dir + '/{}.json'.format(args.dataset) 49 | print("Loading data from {} with batch size {}...".format(data_file, opt['batch_size'])) 50 | data = json.load(open(data_file)) 51 | 52 | id2predicate, predicate2id, id2subj_type, subj_type2id, id2obj_type, obj_type2id = json.load(open(opt['data_dir'] + '/schemas.json')) 53 | id2predicate = {int(i):j for i,j in id2predicate.items()} 54 | 55 | # 加载bert词表 56 | UNCASED = args.bert_model # your path for model and vocab 57 | VOCAB = 'bert-base-chinese-vocab.txt' 58 | bert_tokenizer = BertTokenizer.from_pretrained(os.path.join(UNCASED,VOCAB)) 59 | 60 | helper.print_config(opt) 61 | 62 | results = 
result.evaluate(bert_tokenizer, data, id2predicate, model) 63 | results_save_dir = opt['model_save_dir'] + '/results.json' 64 | print("Dumping the best test results to {}".format(results_save_dir)) 65 | 66 | with open(results_save_dir, 'w') as fw: 67 | json.dump(results, fw, indent=4, ensure_ascii=False) 68 | 69 | print("Evaluation ended.") 70 | 71 | -------------------------------------------------------------------------------- /expirement_er/etlspan/models/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | -------------------------------------------------------------------------------- /expirement_er/etlspan/models/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | A joint model for relation extraction, written in pytorch. 3 | """ 4 | import math 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from torch.nn import init 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | import pdb 12 | from expirement_er.etlspan.utils import torch_utils, loader 13 | from expirement_er.etlspan.models import submodel 14 | 15 | 16 | class RelationModel(object): 17 | """ A wrapper class for the training and evaluation of models. """ 18 | def __init__(self, opt, bert_model): 19 | self.opt = opt 20 | self.bert_model = bert_model 21 | self.model = BiLSTMCNN(opt, bert_model) 22 | self.subj_criterion = nn.BCELoss(reduction='none') 23 | self.obj_criterion = nn.CrossEntropyLoss(reduction='none') 24 | self.parameters = [p for p in self.model.parameters() if p.requires_grad] 25 | if opt['cuda']: 26 | self.model.cuda() 27 | self.subj_criterion.cuda() 28 | self.obj_criterion.cuda() 29 | self.optimizer = torch_utils.get_optimizer(opt['optim'], self.parameters, opt['lr'], opt['weight_decay']) 30 | 31 | 32 | def update(self, batch): 33 | """ Run a step of forward and backward model update. 
""" 34 | if self.opt['cuda']: 35 | inputs = Variable(torch.LongTensor(batch[0]).cuda()) 36 | subj_start_type = Variable(torch.LongTensor(batch[3]).cuda()) 37 | subj_end_type = Variable(torch.LongTensor(batch[4]).cuda()) 38 | 39 | mask = (inputs.data>0).float() 40 | self.model.train() 41 | self.optimizer.zero_grad() 42 | 43 | subj_start_logits, subj_end_logits = self.model(inputs) 44 | subj_start_logits = subj_start_logits.view(-1, self.opt['num_subj_type']+1) 45 | subj_start_type = subj_start_type.view(-1).squeeze() 46 | subj_start_loss = self.obj_criterion(subj_start_logits, subj_start_type).view_as(mask) 47 | subj_start_loss = torch.sum(subj_start_loss.mul(mask.float())) / torch.sum(mask.float()) 48 | 49 | subj_end_loss = self.obj_criterion(subj_end_logits.view(-1, self.opt['num_subj_type']+1), subj_end_type.view(-1).squeeze()).view_as(mask) 50 | subj_end_loss = torch.sum(subj_end_loss.mul(mask.float())) / torch.sum(mask.float()) 51 | 52 | loss = subj_start_loss + subj_end_loss 53 | 54 | # backward 55 | loss.backward() 56 | # torch.nn.utils.clip_grad_norm(self.model.parameters(), self.opt['max_grad_norm']) 57 | self.optimizer.step() 58 | loss_val = loss.data.item() 59 | return loss_val 60 | 61 | 62 | def predict_subj_per_instance(self, words): 63 | 64 | if self.opt['cuda']: 65 | words = Variable(torch.LongTensor(words).cuda()) 66 | # chars = Variable(torch.LongTensor(chars).cuda()) 67 | # pos_tags = Variable(torch.LongTensor(pos_tags).cuda()) 68 | else: 69 | words = Variable(torch.LongTensor(words)) 70 | # features = Variable(torch.LongTensor(features)) 71 | 72 | batch_size, seq_len = words.size() 73 | mask = (words.data>0).float() 74 | # forward 75 | self.model.eval() 76 | hidden = self.model.based_encoder(words) 77 | 78 | subj_start_logits = self.model.subj_sublayer.predict_subj_start(hidden) 79 | subj_end_logits = self.model.subj_sublayer.predict_subj_end(hidden) 80 | 81 | return subj_start_logits, subj_end_logits, hidden 82 | 83 | 84 | def update_lr(self, new_lr): 85 | torch_utils.change_lr(self.optimizer, new_lr) 86 | 87 | def save(self, filename, epoch): 88 | params = { 89 | 'model': self.model.state_dict(), 90 | 'config': self.opt, 91 | 'epoch': epoch 92 | } 93 | try: 94 | torch.save(params, filename) 95 | print("model saved to {}".format(filename)) 96 | except BaseException: 97 | print("[Warning: Saving failed... continuing anyway.]") 98 | 99 | def load(self, filename): 100 | try: 101 | checkpoint = torch.load(filename) 102 | except BaseException: 103 | print("Cannot load model from {}".format(filename)) 104 | exit() 105 | self.model.load_state_dict(checkpoint['model']) 106 | self.opt = checkpoint['config'] 107 | 108 | class BiLSTMCNN(nn.Module): 109 | """ A sequence model for relation extraction. 
""" 110 | 111 | def __init__(self, opt, bert_model): 112 | super(BiLSTMCNN, self).__init__() 113 | self.drop = nn.Dropout(opt['dropout']) 114 | self.input_size = opt['word_emb_dim'] 115 | self.subj_sublayer = submodel.SubjTypeModel(opt) 116 | self.opt = opt 117 | self.bert_model = bert_model 118 | self.topn = self.opt.get('topn', 1e10) 119 | self.use_cuda = opt['cuda'] 120 | 121 | def based_encoder(self, words): 122 | hidden, _ = self.bert_model(words, output_all_encoded_layers=False) 123 | return hidden 124 | 125 | def forward(self, inputs): 126 | 127 | hidden = self.based_encoder(inputs) 128 | subj_start_logits, subj_end_logits = self.subj_sublayer(hidden) 129 | 130 | return subj_start_logits, subj_end_logits 131 | 132 | 133 | -------------------------------------------------------------------------------- /expirement_er/etlspan/models/submodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from expirement_er.etlspan.utils import torch_utils 7 | 8 | 9 | class SubjTypeModel(nn.Module): 10 | 11 | 12 | def __init__(self, opt, filter=3): 13 | super(SubjTypeModel, self).__init__() 14 | 15 | self.dropout = nn.Dropout(opt['dropout']) 16 | self.hidden_dim = opt['word_emb_dim'] 17 | 18 | self.position_embedding = nn.Embedding(500, opt['position_emb_dim']) 19 | 20 | self.linear_subj_start = nn.Linear(self.hidden_dim, opt['num_subj_type']+1) 21 | self.linear_subj_end = nn.Linear(self.hidden_dim, opt['num_subj_type']+1) 22 | self.init_weights() 23 | 24 | 25 | 26 | def init_weights(self): 27 | 28 | self.position_embedding.weight.data.uniform_(-1.0, 1.0) 29 | self.linear_subj_start.bias.data.fill_(0) 30 | init.xavier_uniform_(self.linear_subj_start.weight, gain=1) # initialize linear layer 31 | 32 | self.linear_subj_end.bias.data.fill_(0) 33 | init.xavier_uniform_(self.linear_subj_end.weight, gain=1) # initialize linear layer 34 | 35 | def forward(self, hidden): 36 | 37 | subj_start_inputs = self.dropout(hidden) 38 | subj_start_logits = self.linear_subj_start(subj_start_inputs) 39 | 40 | subj_end_inputs = self.dropout(hidden) 41 | subj_end_logits = self.linear_subj_end(subj_end_inputs) 42 | 43 | return subj_start_logits.squeeze(-1), subj_end_logits.squeeze(-1) 44 | 45 | 46 | def predict_subj_start(self, hidden): 47 | 48 | subj_start_logits = self.linear_subj_start(hidden) 49 | subj_start_logits = torch.argmax(subj_start_logits, 2) 50 | 51 | return subj_start_logits.squeeze(-1)[0].data.cpu().numpy() 52 | 53 | def predict_subj_end(self, hidden): 54 | 55 | subj_end_logits = self.linear_subj_end(hidden) 56 | subj_end_logits = torch.argmax(subj_end_logits, 2) 57 | 58 | return subj_end_logits.squeeze(-1)[0].data.cpu().numpy() 59 | 60 | 61 | -------------------------------------------------------------------------------- /expirement_er/etlspan/saved_models/baiduRE/5_17_bert_2/00/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "dataset/yanbao", 3 | "vocab_dir": "dataset/baiduRE/vocab", 4 | "word_emb_dim": 768, 5 | "char_emb_dim": 100, 6 | "pos_emb_dim": 20, 7 | "position_emb_dim": 20, 8 | "obj_input_dim": 2324, 9 | "char_hidden_dim": 100, 10 | "num_layers": 1, 11 | "dropout": 0.4, 12 | "word_dropout": 0.04, 13 | "topn": 10000000000.0, 14 | "subj_loss_weight": 1, 15 | "type_loss_weight": 1, 16 | "lr": 2e-05, 17 | "lr_decay": 0, 18 | "weight_decay": 0, 19 | 
"optim": "adam", 20 | "num_epoch": 50, 21 | "load_saved": "", 22 | "batch_size": 10, 23 | "max_grad_norm": 5.0, 24 | "log_step": 50, 25 | "log": "logs.txt", 26 | "save_epoch": 10, 27 | "save_dir": "./saved_models/baiduRE/5_17_bert_2", 28 | "id": "00", 29 | "info": "", 30 | "bert_model": "../../bert-base-chinese", 31 | "seed": 35, 32 | "cuda": true, 33 | "cpu": false, 34 | "num_subj_type": 10, 35 | "model_save_dir": "./saved_models/baiduRE/5_17_bert_2/00" 36 | } -------------------------------------------------------------------------------- /expirement_er/etlspan/saved_models/baiduRE/5_17_bert_2/00/logs.txt: -------------------------------------------------------------------------------- 1 | # epoch train_loss\dev_p dev_r dev_f1 2 | -------------------------------------------------------------------------------- /expirement_er/etlspan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | -------------------------------------------------------------------------------- /expirement_er/etlspan/utils/constant.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define common constants. 3 | """ 4 | TRAIN_JSON = 'train.json' 5 | DEV_JSON = 'dev.json' 6 | TEST_JSON = 'test.json' 7 | 8 | GLOVE_DIR = 'dataset/glove' 9 | 10 | EMB_INIT_RANGE = 1.0 11 | MAX_LEN = 100 12 | 13 | # vocab 14 | PAD_TOKEN = '' 15 | PAD_ID = 0 16 | UNK_TOKEN = '' 17 | UNK_ID = 1 18 | 19 | VOCAB_PREFIX = [PAD_TOKEN, UNK_TOKEN] 20 | 21 | INFINITY_NUMBER = 1e12 22 | 23 | -------------------------------------------------------------------------------- /expirement_er/etlspan/utils/helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions. 3 | """ 4 | 5 | import os 6 | import json 7 | import argparse 8 | 9 | ### IO 10 | def check_dir(d): 11 | if not os.path.exists(d): 12 | print("Directory {} does not exist. Exit.".format(d)) 13 | exit(1) 14 | 15 | def check_files(files): 16 | for f in files: 17 | if f is not None and not os.path.exists(f): 18 | print("File {} does not exist. Exit.".format(f)) 19 | exit(1) 20 | 21 | def ensure_dir(d, verbose=True): 22 | if not os.path.exists(d): 23 | if verbose: 24 | print("Directory {} do not exist; creating...".format(d)) 25 | os.makedirs(d) 26 | 27 | def save_config(config, path, verbose=True): 28 | with open(path, 'w') as outfile: 29 | json.dump(config, outfile, indent=2) 30 | if verbose: 31 | print("Config saved to file {}".format(path)) 32 | return config 33 | 34 | def load_config(path, verbose=True): 35 | with open(path) as f: 36 | config = json.load(f) 37 | if verbose: 38 | print("Config loaded from file {}".format(path)) 39 | return config 40 | 41 | def print_config(config): 42 | info = "Running with the following configs:\n" 43 | for k,v in config.items(): 44 | info += "\t{} : {}\n".format(k, str(v)) 45 | print("\n" + info + "\n") 46 | return 47 | 48 | class FileLogger(object): 49 | """ 50 | A file logger that opens the file periodically and write to it. 
51 | """ 52 | def __init__(self, filename, header=None): 53 | self.filename = filename 54 | if os.path.exists(filename): 55 | # remove the old file 56 | os.remove(filename) 57 | if header is not None: 58 | with open(filename, 'w') as out: 59 | print(header, file=out) 60 | 61 | def log(self, message): 62 | with open(self.filename, 'a') as out: 63 | print(message, file=out) 64 | 65 | 66 | -------------------------------------------------------------------------------- /expirement_er/etlspan/utils/score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from expirement_er.etlspan.utils import loader 4 | import json 5 | from torch import nn 6 | import torch 7 | 8 | 9 | def extract_items(tokens, tokens_ids, model): 10 | R = [] 11 | if len(tokens_ids) > 180: 12 | return set(R) 13 | 14 | _t = np.array(loader.seq_padding([tokens_ids])) 15 | 16 | _s1, _s2, hidden = model.predict_subj_per_instance(_t) 17 | 18 | _s1 = _s1.tolist() 19 | _s2 = _s2.tolist() 20 | 21 | for i, _ss1 in enumerate(_s1): 22 | if _ss1 > 0: 23 | for j, _ss2 in enumerate(_s2[i:]): 24 | if _ss2 == _ss1: 25 | entity = ''.join(tokens[i: i + j + 1]) 26 | R.append(entity) 27 | break 28 | return set(R) 29 | 30 | 31 | def evaluate(data, model): 32 | official_A, official_B, official_C = 1e-10, 1e-10, 1e-10 33 | manual_A, manual_B, manual_C = 1e-10, 1e-10, 1e-10 34 | 35 | # { 36 | # "sent": "半导体行情的风险是什么", 37 | # "sent_tokens": ["半", "导", "体", "行", "情", "的", "风", "险", "是", "什", "么"], 38 | # "sent_token_ids": [1288, 2193, 860, 6121, 2658, 4638, 7599, 7372, 3221, 784, 720], 39 | # "entity_labels": [{"entity_type": "研报", "start_token_id": 0, "end_token_id": 10, "start_index": 0, "end_index": 10, 40 | # "entity_tokens": ["半", "导", "体", "行", "情", "的", "风", "险", "是", "什", "么"], 41 | # "entity_name": "半导体行情的风险是什么"}], 42 | # "tags": ["B-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report", "I-Report"], 43 | # "tag_ids": [11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]}, 44 | 45 | results = [] 46 | for d in tqdm(iter(data)): 47 | R = extract_items(d['sent_tokens'], d['sent_token_ids'], model) 48 | # official_T = set([tuple(i['entity_name']) for i in d['spo_details']]) 49 | official_T_list = [] 50 | for i in d['entity_labels']: 51 | official_T_list.append(i['entity_name']) 52 | official_T = set(official_T_list) 53 | results.append({'text': ''.join(d['sent_tokens']), 'predict': list(R), 'truth': list(official_T)}) 54 | official_A += len(R & official_T) 55 | official_B += len(R) 56 | official_C += len(official_T) 57 | return 2 * official_A / (official_B + official_C), official_A / official_B, official_A / official_C, results 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /expirement_er/etlspan/utils/vocab.py: -------------------------------------------------------------------------------- 1 | """ 2 | A class for basic vocab operations. 
3 | """ 4 | 5 | from __future__ import print_function 6 | import os 7 | import random 8 | import numpy as np 9 | import pickle 10 | 11 | from expirement_er.etlspan.utils import constant 12 | 13 | 14 | 15 | random.seed(1234) 16 | np.random.seed(1234) 17 | 18 | def build_embedding(wv_file, vocab, wv_dim): 19 | vocab_size = len(vocab) 20 | emb = np.random.uniform(-1, 1, (vocab_size, wv_dim)) 21 | emb[constant.PAD_ID] = 0 # should be all 0 (using broadcast) 22 | 23 | w2id = {w: i for i, w in enumerate(vocab)} 24 | with open(wv_file, encoding="utf8") as f: 25 | for line in f: 26 | elems = line.split() 27 | token = ''.join(elems[0:-wv_dim]) 28 | if token in w2id: 29 | emb[w2id[token]] = [float(v) for v in elems[-wv_dim:]] 30 | return emb 31 | 32 | def load_glove_vocab(file, wv_dim): 33 | """ 34 | Load all words from glove. 35 | """ 36 | vocab = set() 37 | with open(file, encoding='utf8') as f: 38 | try: 39 | for line_num, line in enumerate(f): 40 | elems = line.split() 41 | token = ''.join(elems[0:-wv_dim]) 42 | vocab.add(token) 43 | except Exception as e: 44 | print(line) 45 | return vocab 46 | 47 | def normalize_glove(token): 48 | mapping = {'-LRB-': '(', 49 | '-RRB-': ')', 50 | '-LSB-': '[', 51 | '-RSB-': ']', 52 | '-LCB-': '{', 53 | '-RCB-': '}'} 54 | if token in mapping: 55 | token = mapping[token] 56 | return token 57 | 58 | class Vocab(object): 59 | def __init__(self, filename, load=False, word_counter=None, threshold=0): 60 | if load: 61 | assert os.path.exists(filename), "Vocab file does not exist at " + filename 62 | # load from file and ignore all other params 63 | self.id2word, self.word2id = self.load(filename) 64 | self.size = len(self.id2word) 65 | print("Vocab size {} loaded from file".format(self.size)) 66 | else: 67 | print("Creating vocab from scratch...") 68 | assert word_counter is not None, "word_counter is not provided for vocab creation." 69 | self.word_counter = word_counter 70 | if threshold > 1: 71 | # remove words that occur less than thres 72 | self.word_counter = dict([(k,v) for k,v in self.word_counter.items() if v >= threshold]) 73 | self.id2word = sorted(self.word_counter, key=lambda k:self.word_counter[k], reverse=True) 74 | # add special tokens to the beginning 75 | self.id2word = ['**PAD**', '**UNK**'] + self.id2word 76 | self.word2id = dict([(self.id2word[idx],idx) for idx in range(len(self.id2word))]) 77 | self.size = len(self.id2word) 78 | self.save(filename) 79 | print("Vocab size {} saved to file {}".format(self.size, filename)) 80 | 81 | def load(self, filename): 82 | with open(filename, 'rb') as infile: 83 | id2word = pickle.load(infile) 84 | word2id = dict([(id2word[idx], idx) for idx in range(len(id2word))]) 85 | return id2word, word2id 86 | 87 | def save(self, filename): 88 | #assert not os.path.exists(filename), "Cannot save vocab: file exists at " + filename 89 | if os.path.exists(filename): 90 | print("Overwriting old vocab file at " + filename) 91 | os.remove(filename) 92 | with open(filename, 'wb') as outfile: 93 | pickle.dump(self.id2word, outfile) 94 | return 95 | 96 | def map(self, token_list): 97 | """ 98 | Map a list of tokens to their ids. 99 | """ 100 | return [self.word2id[w] if w in self.word2id else constant.VOCAB_UNK_ID for w in token_list] 101 | 102 | def unmap(self, idx_list): 103 | """ 104 | Unmap ids back to tokens. 
105 | """ 106 | return [self.id2word[idx] for idx in idx_list] 107 | 108 | def get_embeddings(self, word_vectors=None, dim=100): 109 | #self.embeddings = 2 * constant.EMB_INIT_RANGE * np.random.rand(self.size, dim) - constant.EMB_INIT_RANGE 110 | self.embeddings = np.zeros((self.size, dim)) 111 | if word_vectors is not None: 112 | assert len(list(word_vectors.values())[0]) == dim, \ 113 | "Word vectors does not have required dimension {}.".format(dim) 114 | for w, idx in self.word2id.items(): 115 | if w in word_vectors: 116 | self.embeddings[idx] = np.asarray(word_vectors[w]) 117 | return self.embeddings 118 | 119 | 120 | -------------------------------------------------------------------------------- /expirement_er/spert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_er/spert/__init__.py -------------------------------------------------------------------------------- /expirement_er/spert/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def _add_common_args(arg_parser): 5 | arg_parser.add_argument('--config', type=str) 6 | 7 | # Input 8 | arg_parser.add_argument('--types_path', type=str, help="Path to type specifications") 9 | 10 | # Preprocessing 11 | arg_parser.add_argument('--tokenizer_path', type=str, help="Path to tokenizer") 12 | arg_parser.add_argument('--max_span_size', type=int, default=10, help="Maximum size of spans") 13 | arg_parser.add_argument('--lowercase', action='store_true', default=False, 14 | help="If true, input is lowercased during preprocessing") 15 | arg_parser.add_argument('--sampling_processes', type=int, default=4, 16 | help="Number of sampling processes. 0 = no multiprocessing for sampling") 17 | arg_parser.add_argument('--sampling_limit', type=int, default=100, help="Maximum number of sample batches in queue") 18 | 19 | # Logging 20 | arg_parser.add_argument('--label', type=str, help="Label of run. 
Used as the directory name of logs/models") 21 | arg_parser.add_argument('--log_path', type=str, help="Path do directory where training/evaluation logs are stored") 22 | arg_parser.add_argument('--store_predictions', action='store_true', default=False, 23 | help="If true, store predictions on disc (in log directory)") 24 | arg_parser.add_argument('--store_examples', action='store_true', default=False, 25 | help="If true, store evaluation examples on disc (in log directory)") 26 | arg_parser.add_argument('--example_count', type=int, default=None, 27 | help="Count of evaluation example to store (if store_examples == True)") 28 | arg_parser.add_argument('--debug', action='store_true', default=False, help="Debugging mode on/off") 29 | 30 | # Model / Training / Evaluation 31 | arg_parser.add_argument('--model_path', type=str, help="Path to directory that contains model checkpoints") 32 | arg_parser.add_argument('--model_type', type=str, default="spert", help="Type of model") 33 | arg_parser.add_argument('--cpu', action='store_true', default=False, 34 | help="If true, train/evaluate on CPU even if a CUDA device is available") 35 | arg_parser.add_argument('--eval_batch_size', type=int, default=1, help="Evaluation batch size") 36 | arg_parser.add_argument('--max_pairs', type=int, default=1000, 37 | help="Maximum entity pairs to process during training/evaluation") 38 | arg_parser.add_argument('--rel_filter_threshold', type=float, default=0.4, help="Filter threshold for relations") 39 | arg_parser.add_argument('--size_embedding', type=int, default=25, help="Dimensionality of size embedding") 40 | arg_parser.add_argument('--prop_drop', type=float, default=0.1, help="Probability of dropout used in SpERT") 41 | arg_parser.add_argument('--freeze_transformer', action='store_true', default=False, help="Freeze BERT weights") 42 | arg_parser.add_argument('--no_overlapping', action='store_true', default=False, 43 | help="If true, do not evaluate on overlapping entities " 44 | "and relations with overlapping entities") 45 | 46 | # Misc 47 | arg_parser.add_argument('--seed', type=int, default=None, help="Seed") 48 | arg_parser.add_argument('--cache_path', type=str, default=None, 49 | help="Path to cache transformer models (for HuggingFace transformers library)") 50 | 51 | 52 | def train_argparser(): 53 | arg_parser = argparse.ArgumentParser() 54 | 55 | # Input 56 | arg_parser.add_argument('--train_path', type=str, help="Path to train dataset") 57 | arg_parser.add_argument('--valid_path', type=str, help="Path to validation dataset") 58 | 59 | # Logging 60 | arg_parser.add_argument('--save_path', type=str, help="Path to directory where model checkpoints are stored") 61 | arg_parser.add_argument('--init_eval', action='store_true', default=False, 62 | help="If true, evaluate validation set before training") 63 | arg_parser.add_argument('--save_optimizer', action='store_true', default=False, 64 | help="Save optimizer alongside model") 65 | arg_parser.add_argument('--train_log_iter', type=int, default=1, help="Log training process every x iterations") 66 | arg_parser.add_argument('--final_eval', action='store_true', default=False, 67 | help="Evaluate the model only after training, not at every epoch") 68 | 69 | # Model / Training 70 | arg_parser.add_argument('--train_batch_size', type=int, default=1, help="Training batch size") 71 | arg_parser.add_argument('--epochs', type=int, default=20, help="Number of epochs") 72 | arg_parser.add_argument('--neg_entity_count', type=int, default=100, 73 | help="Number of 
negative entity samples per document (sentence)") 74 | arg_parser.add_argument('--neg_relation_count', type=int, default=100, 75 | help="Number of negative relation samples per document (sentence)") 76 | arg_parser.add_argument('--lr', type=float, default=5e-5, help="Learning rate") 77 | arg_parser.add_argument('--lr_warmup', type=float, default=0.1, 78 | help="Proportion of total train iterations to warmup in linear increase/decrease schedule") 79 | arg_parser.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay to apply") 80 | arg_parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm") 81 | 82 | _add_common_args(arg_parser) 83 | 84 | return arg_parser 85 | 86 | 87 | def eval_argparser(): 88 | arg_parser = argparse.ArgumentParser() 89 | 90 | # Input 91 | arg_parser.add_argument('--dataset_path', type=str, help="Path to dataset") 92 | 93 | _add_common_args(arg_parser) 94 | 95 | return arg_parser 96 | -------------------------------------------------------------------------------- /expirement_er/spert/config_reader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import multiprocessing as mp 3 | 4 | 5 | def process_configs(target, arg_parser): 6 | args, _ = arg_parser.parse_known_args() 7 | ctx = mp.get_context('fork') 8 | 9 | for run_args, _run_config, _run_repeat in _yield_configs(arg_parser, args): 10 | p = ctx.Process(target=target, args=(run_args,)) 11 | p.start() 12 | p.join() 13 | 14 | 15 | def _read_config(path): 16 | lines = open(path).readlines() 17 | 18 | runs = [] 19 | run = [1, dict()] 20 | for line in lines: 21 | stripped_line = line.strip() 22 | 23 | # continue in case of comment 24 | if stripped_line.startswith('#'): 25 | continue 26 | 27 | if not stripped_line: 28 | if run[1]: 29 | runs.append(run) 30 | 31 | run = [1, dict()] 32 | continue 33 | 34 | if stripped_line.startswith('[') and stripped_line.endswith(']'): 35 | repeat = int(stripped_line[1:-1]) 36 | run[0] = repeat 37 | else: 38 | key, value = stripped_line.split('=') 39 | key, value = (key.strip(), value.strip()) 40 | run[1][key] = value 41 | 42 | if run[1]: 43 | runs.append(run) 44 | 45 | return runs 46 | 47 | 48 | def _convert_config(config): 49 | config_list = [] 50 | for k, v in config.items(): 51 | if v.lower() == 'true': 52 | config_list.append('--' + k) 53 | elif v.lower() != 'false': 54 | config_list.extend(['--' + k] + v.split(' ')) 55 | 56 | return config_list 57 | 58 | 59 | def _yield_configs(arg_parser, args, verbose=True): 60 | _print = (lambda x: print(x)) if verbose else lambda x: x 61 | 62 | if args.config: 63 | config = _read_config(args.config) 64 | 65 | for run_repeat, run_config in config: 66 | print("-" * 50) 67 | print("Config:") 68 | print(run_config) 69 | 70 | args_copy = copy.deepcopy(args) 71 | config_list = _convert_config(run_config) 72 | run_args = arg_parser.parse_args(config_list, namespace=args_copy) 73 | run_args_dict = vars(run_args) 74 | 75 | # set boolean values 76 | for k, v in run_config.items(): 77 | if v.lower() == 'false': 78 | run_args_dict[k] = False 79 | 80 | print("Repeat %s times" % run_repeat) 81 | print("-" * 50) 82 | 83 | for iteration in range(run_repeat): 84 | _print("Iteration %s" % iteration) 85 | _print("-" * 50) 86 | 87 | yield run_args, run_config, run_repeat 88 | 89 | else: 90 | yield args, None, None 91 | -------------------------------------------------------------------------------- /expirement_er/spert/configs/example_eval.conf: 
-------------------------------------------------------------------------------- 1 | [1] 2 | label = conll04_eval 3 | model_type = spert 4 | model_path = data/models/conll04 5 | tokenizer_path = data/models/conll04 6 | dataset_path = data/datasets/conll04/conll04_test.json 7 | types_path = data/datasets/conll04/conll04_types.json 8 | eval_batch_size = 1 9 | rel_filter_threshold = 0.4 10 | size_embedding = 25 11 | prop_drop = 0.1 12 | max_span_size = 10 13 | store_predictions = true 14 | store_examples = true 15 | sampling_processes = 4 16 | sampling_limit = 100 17 | max_pairs = 1000 18 | log_path = data/log/ -------------------------------------------------------------------------------- /expirement_er/spert/configs/example_train.conf: -------------------------------------------------------------------------------- 1 | [1] 2 | label = conll04_train 3 | model_type = spert 4 | model_path = bert-base-cased 5 | tokenizer_path = bert-base-cased 6 | train_path = data/datasets/conll04/conll04_train.json 7 | valid_path = data/datasets/conll04/conll04_dev.json 8 | types_path = data/datasets/conll04/conll04_types.json 9 | train_batch_size = 2 10 | eval_batch_size = 1 11 | neg_entity_count = 100 12 | neg_relation_count = 100 13 | epochs = 20 14 | lr = 5e-5 15 | lr_warmup = 0.1 16 | weight_decay = 0.01 17 | max_grad_norm = 1.0 18 | rel_filter_threshold = 0.4 19 | size_embedding = 25 20 | prop_drop = 0.1 21 | max_span_size = 10 22 | store_predictions = true 23 | store_examples = true 24 | sampling_processes = 4 25 | sampling_limit = 100 26 | max_pairs = 1000 27 | final_eval = true 28 | log_path = data/log/ 29 | save_path = data/save/ -------------------------------------------------------------------------------- /expirement_er/spert/spert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from expirement_er.spert.args import train_argparser, eval_argparser 4 | from expirement_er.spert.config_reader import process_configs 5 | from expirement_er.spert.spert import input_reader 6 | from expirement_er.spert.spert.spert_trainer import SpERTTrainer 7 | 8 | 9 | def __train(run_args): 10 | trainer = SpERTTrainer(run_args) 11 | trainer.train(train_path=run_args.train_path, valid_path=run_args.valid_path, 12 | types_path=run_args.types_path, input_reader_cls=input_reader.JsonInputReader) 13 | 14 | 15 | def _train(): 16 | arg_parser = train_argparser() 17 | process_configs(target=__train, arg_parser=arg_parser) 18 | 19 | 20 | def __eval(run_args): 21 | trainer = SpERTTrainer(run_args) 22 | trainer.eval(dataset_path=run_args.dataset_path, types_path=run_args.types_path, 23 | input_reader_cls=input_reader.JsonInputReader) 24 | 25 | 26 | def _eval(): 27 | arg_parser = eval_argparser() 28 | process_configs(target=__eval, arg_parser=arg_parser) 29 | 30 | 31 | if __name__ == '__main__': 32 | arg_parser = argparse.ArgumentParser(add_help=False) 33 | arg_parser.add_argument('mode', type=str, help="Mode: 'train' or 'eval'") 34 | args, _ = arg_parser.parse_known_args() 35 | 36 | if args.mode == 'train': 37 | _train() 38 | elif args.mode == 'eval': 39 | _eval() 40 | else: 41 | raise Exception("Mode not in ['train', 'eval'], e.g. 
'python spert.py train ...'") 42 | -------------------------------------------------------------------------------- /expirement_er/spert/spert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_er/spert/spert/__init__.py -------------------------------------------------------------------------------- /expirement_er/spert/spert/loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | 5 | 6 | class Loss(ABC): 7 | def compute(self, *args, **kwargs): 8 | pass 9 | 10 | 11 | class SpERTLoss(Loss): 12 | def __init__(self, rel_criterion, entity_criterion, model, optimizer, scheduler, max_grad_norm): 13 | self._rel_criterion = rel_criterion 14 | self._entity_criterion = entity_criterion 15 | self._model = model 16 | self._optimizer = optimizer 17 | self._scheduler = scheduler 18 | self._max_grad_norm = max_grad_norm 19 | 20 | def compute(self, entity_logits, rel_logits, entity_types, rel_types, entity_sample_masks, rel_sample_masks): 21 | # entity loss 22 | entity_logits = entity_logits.view(-1, entity_logits.shape[-1]) 23 | entity_types = entity_types.view(-1) 24 | entity_sample_masks = entity_sample_masks.view(-1).float() 25 | 26 | entity_loss = self._entity_criterion(entity_logits, entity_types) 27 | entity_loss = (entity_loss * entity_sample_masks).sum() / entity_sample_masks.sum() 28 | 29 | # relation loss 30 | rel_sample_masks = rel_sample_masks.view(-1).float() 31 | rel_count = rel_sample_masks.sum() 32 | 33 | if rel_count.item() != 0: 34 | rel_logits = rel_logits.view(-1, rel_logits.shape[-1]) 35 | rel_types = rel_types.view(-1, rel_types.shape[-1]) 36 | 37 | rel_loss = self._rel_criterion(rel_logits, rel_types) 38 | rel_loss = rel_loss.sum(-1) / rel_loss.shape[-1] 39 | rel_loss = (rel_loss * rel_sample_masks).sum() / rel_count 40 | 41 | # joint loss 42 | train_loss = entity_loss + rel_loss 43 | else: 44 | # corner case: no positive/negative relation samples 45 | train_loss = entity_loss 46 | 47 | train_loss.backward() 48 | torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._max_grad_norm) 49 | self._optimizer.step() 50 | self._scheduler.step() 51 | self._model.zero_grad() 52 | return train_loss.item() 53 | -------------------------------------------------------------------------------- /expirement_er/spert/spert/opt.py: -------------------------------------------------------------------------------- 1 | # optional packages 2 | 3 | try: 4 | import tensorboardX 5 | except ImportError: 6 | tensorboardX = None 7 | 8 | 9 | try: 10 | import jinja2 11 | except ImportError: 12 | jinja2 = None 13 | -------------------------------------------------------------------------------- /expirement_er/spert/spert/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import logging 4 | import os 5 | import sys 6 | from typing import List, Dict, Tuple 7 | 8 | import torch 9 | from torch.nn import DataParallel 10 | from torch.optim import Optimizer 11 | from transformers import PreTrainedModel 12 | from transformers import PreTrainedTokenizer 13 | 14 | from expirement_er.spert.spert import util 15 | from expirement_er.spert.spert.opt import tensorboardX 16 | 17 | SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__)) 18 | 19 | 20 | class BaseTrainer: 21 | """ Trainer base class with 
common methods """ 22 | 23 | def __init__(self, args: argparse.Namespace): 24 | self.args = args 25 | self._debug = self.args.debug 26 | 27 | # logging 28 | name = str(datetime.datetime.now()).replace(' ', '_') 29 | self._log_path = os.path.join(self.args.log_path, self.args.label, name) 30 | util.create_directories_dir(self._log_path) 31 | 32 | if hasattr(args, 'save_path'): 33 | self._save_path = os.path.join(self.args.save_path, self.args.label, name) 34 | util.create_directories_dir(self._save_path) 35 | 36 | self._log_paths = dict() 37 | 38 | # file + console logging 39 | log_formatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s") 40 | self._logger = logging.getLogger() 41 | util.reset_logger(self._logger) 42 | 43 | file_handler = logging.FileHandler(os.path.join(self._log_path, 'all.log')) 44 | file_handler.setFormatter(log_formatter) 45 | self._logger.addHandler(file_handler) 46 | 47 | console_handler = logging.StreamHandler(sys.stdout) 48 | console_handler.setFormatter(log_formatter) 49 | self._logger.addHandler(console_handler) 50 | 51 | if self._debug: 52 | self._logger.setLevel(logging.DEBUG) 53 | else: 54 | self._logger.setLevel(logging.INFO) 55 | 56 | # tensorboard summary 57 | self._summary_writer = tensorboardX.SummaryWriter(self._log_path) if tensorboardX is not None else None 58 | 59 | self._best_results = dict() 60 | self._log_arguments() 61 | 62 | # CUDA devices 63 | self._device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu") 64 | self._gpu_count = torch.cuda.device_count() 65 | 66 | # set seed 67 | if args.seed is not None: 68 | util.set_seed(args.seed) 69 | 70 | def _add_dataset_logging(self, *labels, data: Dict[str, List[str]]): 71 | for label in labels: 72 | dic = dict() 73 | 74 | for key, columns in data.items(): 75 | path = os.path.join(self._log_path, '%s_%s.csv' % (key, label)) 76 | util.create_csv(path, *columns) 77 | dic[key] = path 78 | 79 | self._log_paths[label] = dic 80 | self._best_results[label] = 0 81 | 82 | def _log_arguments(self): 83 | util.save_dict(self._log_path, self.args, 'args') 84 | if self._summary_writer is not None: 85 | util.summarize_dict(self._summary_writer, self.args, 'args') 86 | 87 | def _log_tensorboard(self, dataset_label: str, data_label: str, data: object, iteration: int): 88 | if self._summary_writer is not None: 89 | self._summary_writer.add_scalar('data/%s/%s' % (dataset_label, data_label), data, iteration) 90 | 91 | def _log_csv(self, dataset_label: str, data_label: str, *data: Tuple[object]): 92 | logs = self._log_paths[dataset_label] 93 | util.append_csv(logs[data_label], *data) 94 | 95 | def _save_best(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, optimizer: Optimizer, 96 | accuracy: float, iteration: int, label: str, extra=None): 97 | if accuracy > self._best_results[label]: 98 | self._logger.info("[%s] Best model in iteration %s: %s%% accuracy" % (label, iteration, accuracy)) 99 | self._save_model(self._save_path, model, tokenizer, iteration, 100 | optimizer=optimizer if self.args.save_optimizer else None, 101 | save_as_best=True, name='model_%s' % label, extra=extra) 102 | self._best_results[label] = accuracy 103 | 104 | def _save_model(self, save_path: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, 105 | iteration: int, optimizer: Optimizer = None, save_as_best: bool = False, 106 | extra: dict = None, include_iteration: int = True, name: str = 'model'): 107 | extra_state = dict(iteration=iteration) 108 | 109 | if 
optimizer: 110 | extra_state['optimizer'] = optimizer.state_dict() 111 | 112 | if extra: 113 | extra_state.update(extra) 114 | 115 | if save_as_best: 116 | dir_path = os.path.join(save_path, '%s_best' % name) 117 | else: 118 | dir_name = '%s_%s' % (name, iteration) if include_iteration else name 119 | dir_path = os.path.join(save_path, dir_name) 120 | 121 | util.create_directories_dir(dir_path) 122 | 123 | # save model 124 | if isinstance(model, DataParallel): 125 | model.module.save_pretrained(dir_path) 126 | else: 127 | model.save_pretrained(dir_path) 128 | 129 | # save vocabulary 130 | tokenizer.save_pretrained(dir_path) 131 | 132 | # save extra 133 | state_path = os.path.join(dir_path, 'extra.state') 134 | torch.save(extra_state, state_path) 135 | 136 | def _get_lr(self, optimizer): 137 | lrs = [] 138 | for group in optimizer.param_groups: 139 | lr_scheduled = group['lr'] 140 | lrs.append(lr_scheduled) 141 | return lrs 142 | 143 | def _close_summary_writer(self): 144 | if self._summary_writer is not None: 145 | self._summary_writer.close() 146 | -------------------------------------------------------------------------------- /expirement_re/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_re/__init__.py -------------------------------------------------------------------------------- /expirement_re/bert_att_mis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_re/bert_att_mis/__init__.py -------------------------------------------------------------------------------- /expirement_re/bert_att_mis/checkpoints/pcnnone_DEF.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_re/bert_att_mis/checkpoints/pcnnone_DEF.pth -------------------------------------------------------------------------------- /expirement_re/bert_att_mis/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | label2id_dict = { 3 | "妻子": 1, 4 | "成立日期": 2, 5 | "号": 3, 6 | "出生地": 4, 7 | "注册资本": 5, 8 | "作曲": 6, 9 | "歌手": 7, 10 | "出品公司": 8, 11 | "人口数量": 9, 12 | "连载网站": 10, 13 | "创始人": 11, 14 | "首都": 12, 15 | "民族": 13, 16 | "目": 14, 17 | "邮政编码": 15, 18 | "毕业院校": 16, 19 | "作者": 17, 20 | "母亲": 18, 21 | "所在城市": 19, 22 | "制片人": 20, 23 | "出生日期": 21, 24 | "作词": 22, 25 | "占地面积": 23, 26 | "主演": 24, 27 | "面积": 25, 28 | "嘉宾": 26, 29 | "总部地点": 27, 30 | "修业年限": 28, 31 | "编剧": 29, 32 | "导演": 30, 33 | "主角": 31, 34 | "上映时间": 32, 35 | "出版社": 33, 36 | "祖籍": 34, 37 | "董事长": 35, 38 | "朝代": 36, 39 | "海拔": 37, 40 | "父亲": 38, 41 | "身高": 39, 42 | "主持人": 40, 43 | "改编自": 41, 44 | "简称": 42, 45 | "国籍": 43, 46 | "所属专辑": 44, 47 | "丈夫": 45, 48 | "气候": 46, 49 | "官方语言": 47, 50 | "字": 48, 51 | "无关系": 49 52 | } 53 | id2label_dict = {v: k for k, v in label2id_dict.items()} 54 | 55 | 56 | class DefaultConfig(object): 57 | label2id = label2id_dict 58 | id2label = id2label_dict 59 | 60 | model = 'bert_att_mis' # the name of used model, in 61 | 62 | root_path = '../data' 63 | result_dir = './out' 64 | 65 | bert_tokenizer_path = '../../bert-base-chinese/vocab.txt' 66 | bert_path = '../../bert-base-chinese' 67 | 68 | load_model_path = 'checkpoints/model.pth' # 
the trained model 69 | 70 | seed = 3435 71 | batch_size = 10 # batch size 72 | use_gpu = True # user GPU or not 73 | gpu_id = 0 74 | num_workers = 0 # how many workers for loading data 75 | 76 | hidden_dim = 768 77 | rel_num = len(label2id_dict) 78 | rel_dim = 768 79 | 80 | num_epochs = 16 # the number of epochs for training 81 | drop_out = 0.5 82 | lr = 2e-5 # initial learning rate 83 | lr_decay = 0.95 # when val_loss increase, lr = lr*lr_decay 84 | weight_decay = 0.0001 # optimizer parameter 85 | 86 | print_opt = 'DEF' 87 | 88 | 89 | def parse(self, kwargs): 90 | ''' 91 | user can update the default hyperparamter 92 | ''' 93 | for k, v in kwargs.items(): 94 | if not hasattr(self, k): 95 | raise Exception('opt has No key: {}'.format(k)) 96 | setattr(self, k, v) 97 | 98 | print('*************************************************') 99 | print('user config:') 100 | for k, v in self.__class__.__dict__.items(): 101 | if not k.startswith('__'): 102 | print("{} => {}".format(k, getattr(self, k))) 103 | 104 | print('*************************************************') 105 | 106 | 107 | DefaultConfig.parse = parse 108 | opt = DefaultConfig() 109 | -------------------------------------------------------------------------------- /expirement_re/bert_att_mis/dataloader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from torch.utils.data import Dataset 4 | import os 5 | import numpy as np 6 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM 7 | 8 | 9 | class PCNNData(Dataset): 10 | 11 | def __init__(self, opt, train=True): 12 | self.opt = opt 13 | self.bert_tokenizer = BertTokenizer.from_pretrained(opt.bert_tokenizer_path) 14 | if train: 15 | path = os.path.join(opt.root_path, 'train_mulit.txt') 16 | print('loading train data') 17 | else: 18 | path = os.path.join(opt.root_path, 'valid_mulit.txt') 19 | print('loading valid data') 20 | 21 | self.x = self.parse_sen(path) 22 | print('loading finish') 23 | 24 | def __getitem__(self, idx): 25 | assert idx < len(self.x) 26 | return self.x[idx] 27 | 28 | def __len__(self): 29 | return len(self.x) 30 | 31 | def parse_sen(self, path): 32 | all_bags = [] 33 | 34 | str1 = list('实体') 35 | str2 = list('和实体') 36 | str3 = list('之间的关系是什么?') 37 | 38 | f = open(path, encoding='utf-8') 39 | while 1: 40 | 41 | question = [] 42 | 43 | # if len(all_bags) > 50: 44 | # break 45 | # 第一行读取实体和关系 46 | line = f.readline() 47 | if not line: 48 | break 49 | entities = list(map(str, line.split('-*-'))) 50 | entitity = entities[:2] 51 | pre = entities[2][0:-1] 52 | label = self.opt.label2id[pre] 53 | # 第二行读取句子数量 54 | line = f.readline() 55 | sent_num = int(line) 56 | 57 | ent1_str = self.bert_tokenizer.tokenize(entitity[0].replace(' ', '')) 58 | ent1 = self.bert_tokenizer.convert_tokens_to_ids(ent1_str) 59 | ent2_str = self.bert_tokenizer.tokenize(entitity[1].replace(' ', '')) 60 | ent2 = self.bert_tokenizer.convert_tokens_to_ids(ent2_str) 61 | 62 | question = ['[CLS]'] + str1 + ent1_str + str2 + ent2_str + str3 + ['[SEP]'] 63 | 64 | # 通过MRC问答的方式加入实体 65 | # 实体1 和 实体2 的关系是什么? 
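# The next `sent_num` lines hold this bag's sentences: each one is tokenized
# with the BERT tokenizer, prefixed with the '[CLS]' question built above
# (实体<ent1>和实体<ent2>之间的关系是什么?), terminated with '[SEP]', converted to
# token ids, and the whole bag is padded to a common length with seq_padding below.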
66 | sentences = [] 67 | for i in range(0, sent_num): 68 | sent = f.readline() 69 | sent = sent.replace(' ', '') 70 | sent = self.bert_tokenizer.tokenize(sent) 71 | sent = question + sent + ['[SEP]'] 72 | word = self.bert_tokenizer.convert_tokens_to_ids(sent[:]) 73 | sentences.append(word) 74 | 75 | sentences = self.seq_padding(sentences) 76 | 77 | bag = [ent1, ent2, sent_num, label, sentences] 78 | all_bags += [bag] 79 | return all_bags 80 | 81 | def seq_padding(self, X): 82 | L = [len(x) for x in X] 83 | # ML = config.MAX_LEN #max(L) 84 | ML = max(L) 85 | return [x + [0] * (ML - len(x)) for x in X] 86 | 87 | 88 | -------------------------------------------------------------------------------- /expirement_re/bert_att_mis/main_att.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from expirement_re.bert_att_mis.config import opt 4 | from expirement_re.bert_att_mis.models.bert_att_mis import * 5 | import torch 6 | from expirement_re.bert_att_mis.dataloader import * 7 | import torch.optim as optim 8 | from expirement_re.bert_att_mis.utils import save_pr, now, eval_metric 9 | from torch.utils.data import DataLoader 10 | 11 | 12 | def collate_fn(batch): 13 | return batch 14 | 15 | 16 | def test(**kwargs): 17 | pass 18 | 19 | 20 | def train(**kwargs): 21 | 22 | kwargs.update({'model': 'bert_att_mis'}) 23 | opt.parse(kwargs) 24 | 25 | if opt.use_gpu: 26 | torch.cuda.set_device(opt.gpu_id) 27 | 28 | model = bert_att_mis(opt) 29 | if opt.use_gpu: 30 | model.cuda() 31 | 32 | # loading data 33 | train_data = PCNNData(opt, train=True) 34 | train_data_loader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, collate_fn=collate_fn) 35 | 36 | test_data = PCNNData(opt, train=False) 37 | test_data_loader = DataLoader(test_data, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers, collate_fn=collate_fn) 38 | print('{} train data: {}; test data: {}'.format(now(), len(train_data), len(test_data))) 39 | 40 | # criterion and optimizer 41 | # criterion = nn.CrossEntropyLoss() 42 | optimizer = optim.Adadelta(model.parameters(), rho=0.95, eps=1e-6) 43 | 44 | # train 45 | # max_pre = -1.0 46 | # max_rec = -1.0 47 | for epoch in range(opt.num_epochs): 48 | total_loss = 0 49 | for idx, data in enumerate(train_data_loader): 50 | 51 | label = [l[3] for l in data] 52 | 53 | optimizer.zero_grad() 54 | model.batch_size = opt.batch_size 55 | loss = model(data, label) 56 | if opt.use_gpu: 57 | label = torch.LongTensor(label).cuda() 58 | else: 59 | label = torch.LongTensor(label) 60 | loss.backward() 61 | optimizer.step() 62 | total_loss += loss.item() 63 | 64 | # true_y, pred_y, pred_p= predict(model, test_data_loader) 65 | # all_pre, all_rec = eval_metric(true_y, pred_y, pred_p) 66 | correct, positive_num = predict_var(model, test_data_loader) 67 | # last_pre, last_rec = eval_metric_var(pred_res, p_num) 68 | 69 | # if pred_res > 0.8 and p_num > 0.8: 70 | # save_pr(opt.result_dir, model.model_name, epoch, last_pre, last_rec, opt=opt.print_opt) 71 | # print('{} Epoch {} save pr'.format(now(), epoch + 1)) 72 | 73 | print('{} Epoch {}/{}: train loss: {}; test correct: {}, test num {}'.format(now(), epoch + 1, opt.num_epochs, total_loss, correct, positive_num)) 74 | 75 | 76 | def predict_var(model, test_data_loader): 77 | 78 | model.eval() 79 | 80 | res = [] 81 | true_y = [] 82 | 83 | correct = 0 84 | for idx, data in enumerate(test_data_loader): 85 | labels = [l[3] for l in data] 86 | out = 
model(data) 87 | true_y.extend(labels) 88 | if opt.use_gpu: 89 | # out = map(lambda o: o.data.cpu().numpy().tolist(), out) 90 | out = out.data.cpu().numpy().tolist() 91 | else: 92 | # out = map(lambda o: o.data.numpy().tolist(), out) 93 | out = out.data.numpy().tolist() 94 | 95 | for i, j in zip(out, labels): 96 | if i == j: 97 | correct += 1 98 | 99 | model.train() 100 | positive_num = len(test_data_loader.dataset.x) 101 | return correct, positive_num 102 | 103 | 104 | def eval_metric_var(pred_res, p_num): 105 | correct = 0.0 106 | p_nums = 0 107 | 108 | for i in pred_res: 109 | true_y = i[0] 110 | pred_y = i[1] 111 | values = i[2] 112 | if values > 0.5: 113 | p_nums += 1 114 | if true_y == pred_y: 115 | correct += 1 116 | 117 | if p_nums == 0: 118 | precision = 0 119 | else: 120 | precision = correct / p_nums 121 | recall = correct / p_num 122 | 123 | print("positive_num: {}; correct: {}".format(p_num, correct)) 124 | return precision, recall 125 | 126 | 127 | def predict(model, test_data_loader): 128 | 129 | model.eval() 130 | 131 | pred_y = [] 132 | true_y = [] 133 | pred_p = [] 134 | for idx, data in enumerate(test_data_loader): 135 | labels = [l[3] for l in data] 136 | true_y.extend(labels) 137 | out = model(data) 138 | res = torch.max(out, 1) 139 | if model.opt.use_gpu: 140 | pred_y.extend(res[1].data.cpu().numpy().tolist()) 141 | pred_p.extend(res[0].data.cpu().numpy().tolist()) 142 | else: 143 | pred_y.extend(res[1].data.numpy().tolist()) 144 | pred_p.extend(res[0].data.numpy().tolist()) 145 | # if idx % 100 == 99: 146 | # print('{} Eval: iter {}'.format(now(), idx)) 147 | 148 | size = len(test_data_loader.dataset) 149 | assert len(pred_y) == size and len(true_y) == size 150 | assert len(pred_y) == len(pred_p) 151 | model.train() 152 | return true_y, pred_y, pred_p 153 | 154 | 155 | if __name__ == "__main__": 156 | train() 157 | -------------------------------------------------------------------------------- /expirement_re/bert_att_mis/models/BasicModule.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import time 5 | 6 | 7 | class BasicModule(torch.nn.Module): 8 | ''' 9 | 封装了nn.Module,主要是提供了save和load两个方法 10 | ''' 11 | 12 | def __init__(self): 13 | super(BasicModule, self).__init__() 14 | self.model_name=str(type(self)) # model name 15 | 16 | def load(self, path): 17 | ''' 18 | 可加载指定路径的模型 19 | ''' 20 | self.load_state_dict(torch.load(path)) 21 | 22 | def save(self, name=None): 23 | ''' 24 | 保存模型,默认使用“模型名字+时间”作为文件名 25 | ''' 26 | prefix = 'checkpoints/' 27 | if name is None: 28 | name = prefix + self.model_name + '_' 29 | name = time.strftime(name + '%m%d_%H:%M:%S.pth') 30 | else: 31 | name = prefix + self.model_name + '_' + str(name)+ '.pth' 32 | torch.save(self.state_dict(), name) 33 | return name 34 | -------------------------------------------------------------------------------- /expirement_re/bert_att_mis/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_re/bert_att_mis/models/__init__.py -------------------------------------------------------------------------------- /expirement_re/bert_att_mis/models/bert_att_mis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .BasicModule import BasicModule 4 | import numpy as np 5 | import torch 6 | import torch.nn as 
nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM 10 | 11 | class bert_att_mis(BasicModule): 12 | 13 | def __init__(self, opt): 14 | super(bert_att_mis, self).__init__() 15 | 16 | self.opt = opt 17 | self.model_name = 'bert_att_mis' 18 | self.test_scale_p = 0.5 19 | 20 | self.bert_model = BertModel.from_pretrained(opt.bert_path) 21 | self.bert_model.cuda() 22 | 23 | self.bags_feature = [] 24 | 25 | rel_dim = opt.rel_dim 26 | 27 | self.rel_embs = nn.Parameter(torch.randn(self.opt.rel_num, rel_dim)) 28 | self.rel_bias = nn.Parameter(torch.randn(self.opt.rel_num)) 29 | 30 | self.dropout = nn.Dropout(self.opt.drop_out) 31 | 32 | self.init_model_weight() 33 | 34 | def init_model_weight(self): 35 | ''' 36 | use xavier to init 37 | ''' 38 | nn.init.xavier_uniform(self.rel_embs) 39 | nn.init.uniform(self.rel_bias) 40 | 41 | def forward(self, x, label=None): 42 | # get all sentences embedding in all bags of one batch 43 | self.bags_feature = self.get_bags_feature(x) 44 | 45 | if label is None: 46 | return self.test(x) 47 | else: 48 | return self.fit(label) 49 | 50 | def fit(self, label): 51 | ''' 52 | train process 53 | ''' 54 | x = self.get_batch_feature(label) # batch_size * sentence_feature_num 55 | x = self.dropout(x) 56 | out = x.mm(self.rel_embs.t()) + self.rel_bias # o = Ms + d (formual 10 in paper) 57 | 58 | if self.opt.use_gpu: 59 | v_label = torch.LongTensor(label).cuda() 60 | else: 61 | v_label = torch.LongTensor(label) 62 | ce_loss = F.cross_entropy(out, Variable(v_label)) 63 | return ce_loss 64 | 65 | def test(self, x): 66 | ''' 67 | test process 68 | ''' 69 | pre_y = [] 70 | for label in range(0, self.opt.rel_num): 71 | labels = [label for _ in range(len(x))] # generate the batch labels 72 | bags_feature = self.get_batch_feature(labels) 73 | out = self.test_scale_p * bags_feature.mm(self.rel_embs.t()) + self.rel_bias 74 | # out = F.softmax(out, 1) 75 | # pre_y.append(out[:, label]) 76 | pre_y.append(out.unsqueeze(1)) 77 | 78 | # return pre_y 79 | res = torch.cat(pre_y, 1).max(1)[0] 80 | return torch.argmax(F.softmax(res, 1), 1) 81 | 82 | def get_batch_feature(self, labels): 83 | ''' 84 | Using Attention to get all bags embedding in a batch 85 | ''' 86 | batch_feature = [] 87 | 88 | for bag_embs, label in zip(self.bags_feature, labels): 89 | # calculate the weight: xAr or xr 90 | alpha = bag_embs.mm(self.rel_embs[label].view(-1, 1)) 91 | # alpha = bag_embs.mm(self.att_w[label]).mm(self.rel_embs[label].view(-1, 1)) 92 | bag_embs = bag_embs * F.softmax(alpha, 0) 93 | bag_vec = torch.sum(bag_embs, 0) 94 | batch_feature.append(bag_vec.unsqueeze(0)) 95 | 96 | return torch.cat(batch_feature, 0) 97 | 98 | def get_bags_feature(self, bags): 99 | ''' 100 | get all bags embedding in one batch before Attention 101 | ''' 102 | bags_feature = [] 103 | for bag in bags: 104 | if self.opt.use_gpu: 105 | data = map(lambda x: Variable(torch.LongTensor(x).cuda()), bag) 106 | else: 107 | data = map(lambda x: Variable(torch.LongTensor(x)), bag) 108 | 109 | ent1, ent2, bags_num, lab, sen = data 110 | hidden, out = self.bert_model(sen, output_all_encoded_layers=False) 111 | bags_feature.append(out) 112 | 113 | return bags_feature 114 | 115 | 116 | -------------------------------------------------------------------------------- /expirement_re/bert_att_mis/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 
| import time 5 | 6 | 7 | def now(): 8 | return str(time.strftime('%Y-%m-%d %H:%M:%S')) 9 | 10 | def seq_padding(X): 11 | L = [len(x) for x in X] 12 | # ML = config.MAX_LEN #max(L) 13 | ML = max(L) 14 | return [x + [0] * (ML - len(x)) for x in X] 15 | 16 | def save_pr(out_dir, name, epoch, pre, rec, fp_res=None, opt=None): 17 | if opt is None: 18 | out = open('{}/{}_{}_PR.txt'.format(out_dir, name, epoch + 1), 'w') 19 | else: 20 | out = open('{}/{}_{}_{}_PR.txt'.format(out_dir, name, opt, epoch + 1), 'w') 21 | 22 | if fp_res is not None: 23 | fp_out = open('{}/{}_{}_FP.txt'.format(out_dir, name, epoch + 1), 'w') 24 | for idx, r, p in fp_res: 25 | fp_out.write('{} {} {}\n'.format(idx, r, p)) 26 | fp_out.close() 27 | 28 | for p, r in zip(pre, rec): 29 | out.write('{} {}\n'.format(p, r)) 30 | 31 | out.close() 32 | 33 | 34 | def eval_metric(true_y, pred_y, pred_p): 35 | ''' 36 | calculate the precision and recall for p-r curve 37 | reglect the NA relation 38 | ''' 39 | assert len(true_y) == len(pred_y) 40 | positive_num = len([i for i in true_y if i[0] > 0]) 41 | index = np.argsort(pred_p)[::-1] 42 | 43 | tp = 0 44 | fp = 0 45 | fn = 0 46 | all_pre = [0] 47 | all_rec = [0] 48 | fp_res = [] 49 | 50 | for idx in range(len(true_y)): 51 | i = true_y[index[idx]] 52 | j = pred_y[index[idx]] 53 | 54 | if i[0] == 0: # NA relation 55 | if j > 0: 56 | fp_res.append((index[idx], j, pred_p[index[idx]])) 57 | fp += 1 58 | else: 59 | if j == 0: 60 | fn += 1 61 | else: 62 | for k in i: 63 | if k == -1: 64 | break 65 | if k == j: 66 | tp += 1 67 | break 68 | 69 | if fp + tp == 0: 70 | precision = 1.0 71 | else: 72 | precision = tp * 1.0 / (tp + fp) 73 | recall = tp * 1.0 / positive_num 74 | if precision != all_pre[-1] or recall != all_rec[-1]: 75 | all_pre.append(precision) 76 | all_rec.append(recall) 77 | 78 | print("tp={}; fp={}; fn={}; positive_num={}".format(tp, fp, fn, positive_num)) 79 | return all_pre[1:], all_rec[1:], fp_res 80 | -------------------------------------------------------------------------------- /expirement_re/bert_one_mis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_re/bert_one_mis/__init__.py -------------------------------------------------------------------------------- /expirement_re/bert_one_mis/checkpoints/pcnnone_DEF.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_re/bert_one_mis/checkpoints/pcnnone_DEF.pth -------------------------------------------------------------------------------- /expirement_re/bert_one_mis/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | label2id_dict = { 3 | "妻子": 1, 4 | "成立日期": 2, 5 | "号": 3, 6 | "出生地": 4, 7 | "注册资本": 5, 8 | "作曲": 6, 9 | "歌手": 7, 10 | "出品公司": 8, 11 | "人口数量": 9, 12 | "连载网站": 10, 13 | "创始人": 11, 14 | "首都": 12, 15 | "民族": 13, 16 | "目": 14, 17 | "邮政编码": 15, 18 | "毕业院校": 16, 19 | "作者": 17, 20 | "母亲": 18, 21 | "所在城市": 19, 22 | "制片人": 20, 23 | "出生日期": 21, 24 | "作词": 22, 25 | "占地面积": 23, 26 | "主演": 24, 27 | "面积": 25, 28 | "嘉宾": 26, 29 | "总部地点": 27, 30 | "修业年限": 28, 31 | "编剧": 29, 32 | "导演": 30, 33 | "主角": 31, 34 | "上映时间": 32, 35 | "出版社": 33, 36 | "祖籍": 34, 37 | "董事长": 35, 38 | "朝代": 36, 39 | "海拔": 37, 40 | "父亲": 38, 41 | "身高": 39, 42 | "主持人": 40, 43 | "改编自": 41, 44 | 
"简称": 42, 45 | "国籍": 43, 46 | "所属专辑": 44, 47 | "丈夫": 45, 48 | "气候": 46, 49 | "官方语言": 47, 50 | "字": 48, 51 | "无关系": 49 52 | } 53 | id2label_dict = {v: k for k, v in label2id_dict.items()} 54 | 55 | 56 | class DefaultConfig(object): 57 | label2id = label2id_dict 58 | id2label = id2label_dict 59 | 60 | model = 'bert_one_mis' # the name of used model, in 61 | 62 | root_path = '../data' 63 | result_dir = './out' 64 | 65 | bert_tokenizer_path = '../../bert-base-chinese/vocab.txt' 66 | bert_path = '../../bert-base-chinese' 67 | 68 | load_model_path = 'checkpoints/model.pth' # the trained model 69 | 70 | seed = 3435 71 | batch_size = 10 # batch size 72 | use_gpu = True # user GPU or not 73 | gpu_id = 0 74 | num_workers = 0 # how many workers for loading data 75 | 76 | hidden_dim = 768 77 | rel_num = len(label2id_dict) 78 | 79 | num_epochs = 16 # the number of epochs for training 80 | drop_out = 0.5 81 | lr = 2e-5 # initial learning rate 82 | lr_decay = 0.95 # when val_loss increase, lr = lr*lr_decay 83 | weight_decay = 0.0001 # optimizer parameter 84 | 85 | print_opt = 'DEF' 86 | 87 | 88 | def parse(self, kwargs): 89 | ''' 90 | user can update the default hyperparamter 91 | ''' 92 | for k, v in kwargs.items(): 93 | if not hasattr(self, k): 94 | raise Exception('opt has No key: {}'.format(k)) 95 | setattr(self, k, v) 96 | 97 | print('*************************************************') 98 | print('user config:') 99 | for k, v in self.__class__.__dict__.items(): 100 | if not k.startswith('__'): 101 | print("{} => {}".format(k, getattr(self, k))) 102 | 103 | print('*************************************************') 104 | 105 | 106 | DefaultConfig.parse = parse 107 | opt = DefaultConfig() 108 | -------------------------------------------------------------------------------- /expirement_re/bert_one_mis/dataloader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from torch.utils.data import Dataset 4 | import os 5 | import numpy as np 6 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM 7 | 8 | 9 | class PCNNData(Dataset): 10 | 11 | def __init__(self, opt, train=True): 12 | self.opt = opt 13 | self.bert_tokenizer = BertTokenizer.from_pretrained(opt.bert_tokenizer_path) 14 | if train: 15 | path = os.path.join(opt.root_path, 'train_mulit.txt') 16 | print('loading train data') 17 | else: 18 | path = os.path.join(opt.root_path, 'valid_mulit.txt') 19 | print('loading valid data') 20 | 21 | self.x = self.parse_sen(path) 22 | print('loading finish') 23 | 24 | def __getitem__(self, idx): 25 | assert idx < len(self.x) 26 | return self.x[idx] 27 | 28 | def __len__(self): 29 | return len(self.x) 30 | 31 | def parse_sen(self, path): 32 | all_bags = [] 33 | 34 | str1 = list('实体') 35 | str2 = list('和实体') 36 | str3 = list('之间的关系是什么?') 37 | 38 | f = open(path, encoding='utf-8') 39 | while 1: 40 | 41 | question = [] 42 | 43 | # if len(all_bags) > 200: 44 | # break 45 | # 第一行读取实体和关系 46 | line = f.readline() 47 | if not line: 48 | break 49 | entities = list(map(str, line.split('-*-'))) 50 | entitity = entities[:2] 51 | pre = entities[2][0:-1] 52 | label = self.opt.label2id[pre] 53 | # 第二行读取句子数量 54 | line = f.readline() 55 | sent_num = int(line) 56 | 57 | ent1_str = self.bert_tokenizer.tokenize(entitity[0].replace(' ', '')) 58 | ent1 = self.bert_tokenizer.convert_tokens_to_ids(ent1_str) 59 | ent2_str = self.bert_tokenizer.tokenize(entitity[1].replace(' ', '')) 60 | ent2 = self.bert_tokenizer.convert_tokens_to_ids(ent2_str) 
61 | 62 | question = ['[CLS]'] + str1 + ent1_str + str2 + ent2_str + str3 + ['[SEP]'] 63 | 64 | # 通过MRC问答的方式加入实体 65 | # 实体1 和 实体2 的关系是什么? 66 | sentences = [] 67 | for i in range(0, sent_num): 68 | sent = f.readline() 69 | sent = sent.replace(' ', '') 70 | sent = self.bert_tokenizer.tokenize(sent) 71 | sent = question + sent + ['[SEP]'] 72 | word = self.bert_tokenizer.convert_tokens_to_ids(sent[:]) 73 | sentences.append(word) 74 | 75 | bag = [ent1, ent2, sent_num, label, sentences] 76 | all_bags += [bag] 77 | return all_bags 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /expirement_re/bert_one_mis/models/BasicModule.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import time 5 | 6 | 7 | class BasicModule(torch.nn.Module): 8 | ''' 9 | 封装了nn.Module,主要是提供了save和load两个方法 10 | ''' 11 | 12 | def __init__(self): 13 | super(BasicModule, self).__init__() 14 | self.model_name=str(type(self)) # model name 15 | 16 | def load(self, path): 17 | ''' 18 | 可加载指定路径的模型 19 | ''' 20 | self.load_state_dict(torch.load(path)) 21 | 22 | def save(self, name=None): 23 | ''' 24 | 保存模型,默认使用“模型名字+时间”作为文件名 25 | ''' 26 | prefix = 'checkpoints/' 27 | if name is None: 28 | name = prefix + self.model_name + '_' 29 | name = time.strftime(name + '%m%d_%H:%M:%S.pth') 30 | else: 31 | name = prefix + self.model_name + '_' + str(name)+ '.pth' 32 | torch.save(self.state_dict(), name) 33 | return name 34 | -------------------------------------------------------------------------------- /expirement_re/bert_one_mis/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/expirement_re/bert_one_mis/models/__init__.py -------------------------------------------------------------------------------- /expirement_re/bert_one_mis/models/bert_one_mis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .BasicModule import BasicModule 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM 9 | from expirement_re.bert_one_mis import utils 10 | 11 | 12 | class bert_one_mis(BasicModule): 13 | 14 | def __init__(self, opt): 15 | super(bert_one_mis, self).__init__() 16 | 17 | self.opt = opt 18 | 19 | self.model_name = 'bert_one_mis' 20 | 21 | self.bert_model = BertModel.from_pretrained(opt.bert_path) 22 | self.bert_model.cuda() 23 | 24 | hidden_dim = self.opt.hidden_dim 25 | rel_num = self.opt.rel_num 26 | 27 | self.linear = nn.Linear(hidden_dim, rel_num) 28 | self.dropout = nn.Dropout(self.opt.drop_out) 29 | 30 | 31 | def forward(self, x, train=False): 32 | 33 | ent1, ent2, select_num, select_lab, select_sen = x 34 | 35 | hidden, _ = self.bert_model(select_sen, output_all_encoded_layers=False) 36 | out = self.dropout(_) 37 | out = self.linear(out) 38 | # out = torch.argmax(out, dim=1) 39 | 40 | return out 41 | -------------------------------------------------------------------------------- /expirement_re/bert_one_mis/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import time 5 | 6 | 7 | def now(): 8 | return str(time.strftime('%Y-%m-%d %H:%M:%S')) 9 | 10 | def 
seq_padding(X): 11 | L = [len(x) for x in X] 12 | # ML = config.MAX_LEN #max(L) 13 | ML = max(L) 14 | return [x + [0] * (ML - len(x)) for x in X] 15 | 16 | def save_pr(out_dir, name, epoch, pre, rec, fp_res=None, opt=None): 17 | if opt is None: 18 | out = open('{}/{}_{}_PR.txt'.format(out_dir, name, epoch + 1), 'w') 19 | else: 20 | out = open('{}/{}_{}_{}_PR.txt'.format(out_dir, name, opt, epoch + 1), 'w') 21 | 22 | if fp_res is not None: 23 | fp_out = open('{}/{}_{}_FP.txt'.format(out_dir, name, epoch + 1), 'w') 24 | for idx, r, p in fp_res: 25 | fp_out.write('{} {} {}\n'.format(idx, r, p)) 26 | fp_out.close() 27 | 28 | for p, r in zip(pre, rec): 29 | out.write('{} {}\n'.format(p, r)) 30 | 31 | out.close() 32 | 33 | 34 | def eval_metric(true_y, pred_y, pred_p): 35 | ''' 36 | calculate the precision and recall for p-r curve 37 | reglect the NA relation 38 | ''' 39 | assert len(true_y) == len(pred_y) 40 | positive_num = len([i for i in true_y if i[0] > 0]) 41 | index = np.argsort(pred_p)[::-1] 42 | 43 | tp = 0 44 | fp = 0 45 | fn = 0 46 | all_pre = [0] 47 | all_rec = [0] 48 | fp_res = [] 49 | 50 | for idx in range(len(true_y)): 51 | i = true_y[index[idx]] 52 | j = pred_y[index[idx]] 53 | 54 | if i[0] == 0: # NA relation 55 | if j > 0: 56 | fp_res.append((index[idx], j, pred_p[index[idx]])) 57 | fp += 1 58 | else: 59 | if j == 0: 60 | fn += 1 61 | else: 62 | for k in i: 63 | if k == -1: 64 | break 65 | if k == j: 66 | tp += 1 67 | break 68 | 69 | if fp + tp == 0: 70 | precision = 1.0 71 | else: 72 | precision = tp * 1.0 / (tp + fp) 73 | recall = tp * 1.0 / positive_num 74 | if precision != all_pre[-1] or recall != all_rec[-1]: 75 | all_pre.append(precision) 76 | all_rec.append(recall) 77 | 78 | print("tp={}; fp={}; fn={}; positive_num={}".format(tp, fp, fn, positive_num)) 79 | return all_pre[1:], all_rec[1:], fp_res 80 | -------------------------------------------------------------------------------- /expirement_re/data/baidu19/data_transfer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import torch 4 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM 5 | import os 6 | 7 | # { 8 | # "tokens": 9 | # ["Massachusetts", "ASTON", "MAGNA", "Great", "Barrington", ";", "also", "at", 10 | # "Bard", "College", ",", "Annandale-on-Hudson", ",", "N.Y.", ",", "July", "1-Aug", "."], 11 | # "spo_list": 12 | # [["Annandale-on-Hudson", "/location/location/contains", "Bard College"]], 13 | # "spo_details": 14 | # [[11, 12, "LOCATION", "/location/location/contains", 8, 10, "ORGANIZATION"]], 15 | # "pos_tags": 16 | # ["NNP", "NNP", "NNP", "NNP", "NNP", ":", "RB", "IN", "NNP", "NNP", ",", "NNP", ",", "NNP", ",", "NNP", "NNP", "."] 17 | # } 18 | 19 | # {"postag": [{"word": "《", "pos": "w"}, {"word": "忘记我还是忘记他", "pos": "nw"}, {"word": "》", "pos": "w"}, {"word": "是", "pos": "v"}, {"word": "迪克牛仔", "pos": "nr"}, {"word": "于", "pos": "p"}, {"word": "2002年", "pos": "t"}, {"word": "发行", "pos": "v"}, {"word": "的", "pos": "u"}, {"word": "专辑", "pos": "n"}], 20 | # "text": "《忘记我还是忘记他》是迪克牛仔于2002年发行的专辑", 21 | # "spo_list": [{"predicate": "歌手", "object_type": "人物", "subject_type": "歌曲", "object": "迪克牛仔", "subject": "忘记我还是忘记他"}]} 22 | 23 | # {"text": "《邪少兵王》是冰火未央写的网络小说连载于旗峰天下", 24 | # "spo_list": [{"predicate": "作者", 25 | # "object_type": {"@value": "人物"}, 26 | # "subject_type": "图书作品", 27 | # "object": {"@value": "冰火未央"}, 28 | # "subject": "邪少兵王"}]} 29 | 30 | # { 31 | # S_TYPE: 娱乐人物, 32 | # P: 饰演, 33 | # O_TYPE: { 34 | # 
@value: 角色 35 | # inWork: 影视作品 36 | # } 37 | # } 38 | fr = open('train.txt', 'r', encoding='utf-8') 39 | 40 | sentdic_dict = {} 41 | 42 | for line in fr: 43 | ins = json.loads(line) 44 | sent = ins['text'] 45 | for spo in ins['spo_list']: 46 | ent1 = spo['subject'] 47 | ent2 = spo['object'] 48 | label = spo['predicate'] 49 | entity_key = ent1 + '-*-' + ent2 + '-*-' + label 50 | ent1_t = spo['subject_type'] 51 | ent2_t = spo['object_type'] 52 | enttype = ent1_t + '-*-' + ent2_t 53 | 54 | # 一个bag是一个词典 55 | # sentdic ['sents': , 'label': , 'entitytype': ] 56 | sentdic = {} 57 | if entity_key in sentdic_dict.keys(): 58 | sentdic = sentdic_dict[entity_key] 59 | sentdic['sents'] = sentdic['sents'] + sent + '-*-' 60 | sentdic['entitytype'] = sentdic['entitytype'] + enttype + '-*-' 61 | sentdic_dict[entity_key] = sentdic 62 | else: 63 | sentdic['sents'] = sent + '-*-' 64 | sentdic['entitytype'] = enttype + '-*-' 65 | sentdic_dict[entity_key] = sentdic 66 | 67 | 68 | with open('train_mulit.txt', 'w', encoding='utf-8') as file_obj: 69 | # json.dump(sentdic_dict, file_obj, ensure_ascii=False) 70 | for key, value in sentdic_dict.items(): 71 | file_obj.write(key + '\n') 72 | sents = [] 73 | sents_ = value['sents'] 74 | sents = sents_.split('-*-') 75 | file_obj.write(str(len(sents) - 1) + '\n') 76 | for i in sents: 77 | if i != '': 78 | file_obj.write(i + '\n') 79 | 80 | print('保存成功') 81 | 82 | -------------------------------------------------------------------------------- /extract_attrs.py: -------------------------------------------------------------------------------- 1 | import jieba.posseg as pseg 2 | 3 | from utils.functions import * 4 | 5 | # # 属性抽取 6 | # 7 | # 通过规则抽取属性 8 | # 9 | # - 研报时间 10 | # - 研报评级 11 | # - 文章时间 12 | 13 | # In[ ]: 14 | 15 | 16 | def find_article_time(yanbao_txt, entity): 17 | str_start_index = yanbao_txt.index(entity) 18 | str_end_index = str_start_index + len(entity) 19 | para_start_index = yanbao_txt.rindex('\n', 0, str_start_index) 20 | para_end_index = yanbao_txt.index('\n', str_end_index) 21 | 22 | para = yanbao_txt[para_start_index + 1: para_end_index].strip() 23 | if len(entity) > 5: 24 | ret = re.findall(r'(\d{4})\s*[年-]\s*(\d{1,2})\s*[月-]\s*(\d{1,2})\s*日?', para) 25 | if ret: 26 | year, month, day = ret[0] 27 | time = '{}/{}/{}'.format(year, month.lstrip(), day.lstrip()) 28 | return time 29 | 30 | start_index = 0 31 | time = None 32 | min_gap = float('inf') 33 | for word, poseg in pseg.cut(para): 34 | if poseg in ['t', 'TIME'] and str_start_index <= start_index < str_end_index: 35 | gap = abs(start_index - (str_start_index + str_end_index) // 2) 36 | if gap < min_gap: 37 | min_gap = gap 38 | time = word 39 | start_index += len(word) 40 | return time 41 | 42 | 43 | def find_yanbao_time(yanbao_txt, entity): 44 | paras = [para.strip() for para in yanbao_txt.split('\n') if para.strip()][:5] 45 | for para in paras: 46 | ret = re.findall(r'(\d{4})\s*[\./年-]\s*(\d{1,2})\s*[\./月-]\s*(\d{1,2})\s*日?', para) 47 | if ret: 48 | year, month, day = ret[0] 49 | time = '{}/{}/{}'.format(year, month.lstrip(), day.lstrip()) 50 | return time 51 | return None 52 | 53 | 54 | # In[ ]: 55 | 56 | 57 | def extract_attrs(entities_json): 58 | train_attrs = read_json(Path(DATA_DIR, 'attrs.json'))['attrs'] 59 | 60 | seen_pingjis = [] 61 | for attr in train_attrs: 62 | if attr[1] == '评级': 63 | seen_pingjis.append(attr[2]) 64 | article_entities = entities_json.get('文章', []) 65 | yanbao_entities = entities_json.get('研报', []) 66 | 67 | attrs_json = [] 68 | for file_path in 
tqdm.tqdm(list(Path(DATA_DIR, 'yanbao_txt').glob('*.txt'))): 69 | yanbao_txt = '\n' + Path(file_path).open(encoding='UTF-8').read() + '\n' 70 | for entity in article_entities: 71 | if entity not in yanbao_txt: 72 | continue 73 | time = find_article_time(yanbao_txt, entity) 74 | if time: 75 | attrs_json.append([entity, '发布时间', time]) 76 | 77 | yanbao_txt = '\n'.join( 78 | [para.strip() for para in yanbao_txt.split('\n') if 79 | len(para.strip()) != 0]) 80 | for entity in yanbao_entities: 81 | if entity not in yanbao_txt: 82 | continue 83 | 84 | paras = yanbao_txt.split('\n') 85 | for para_id, para in enumerate(paras): 86 | if entity in para: 87 | break 88 | 89 | paras = paras[: para_id + 5] 90 | for para in paras: 91 | for pingji in seen_pingjis: 92 | if pingji in para: 93 | if '上次' in para: 94 | attrs_json.append([entity, '上次评级', pingji]) 95 | continue 96 | elif '维持' in para: 97 | attrs_json.append([entity, '上次评级', pingji]) 98 | attrs_json.append([entity, '评级', pingji]) 99 | 100 | time = find_yanbao_time(yanbao_txt, entity) 101 | if time: 102 | attrs_json.append([entity, '发布时间', time]) 103 | attrs_json = list(set(tuple(_) for _ in attrs_json) - set(tuple(_) for _ in train_attrs)) 104 | 105 | return attrs_json 106 | 107 | # In[ ]: 108 | 109 | -------------------------------------------------------------------------------- /extract_relations.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | from utils.functions import * 4 | 5 | # # 关系抽取 6 | # 7 | # - 对于研报实体,整个文档抽取特定类型(行业,机构,指标)的关系实体 8 | # - 其他的实体仅考虑与其出现在同一句话中的其他实体组织成特定关系 9 | 10 | # In[ ]: 11 | 12 | 13 | def extract_relations(schema, entities_json): 14 | relation_by_rules = [] 15 | relation_schema = schema['relationships'] 16 | unique_s_o_types = [] 17 | so_type_cnt = defaultdict(int) 18 | for s_type, p, o_type in schema['relationships']: 19 | so_type_cnt[(s_type, o_type)] += 1 20 | for (s_type, o_type), cnt in so_type_cnt.items(): 21 | if cnt == 1 and s_type != o_type: 22 | unique_s_o_types.append((s_type, o_type)) 23 | 24 | for path in tqdm.tqdm(list(Path(DATA_DIR, 'yanbao_txt').glob('*.txt'))): 25 | with open(path, encoding='UTF-8') as f: 26 | entity_dict_in_file = defaultdict(lambda: defaultdict(list)) 27 | main_org = None 28 | for line_idx, line in enumerate(f.readlines()): 29 | for sent_idx, sent in enumerate(split_to_sents(line)): 30 | for ent_type, ents in entities_json.items(): 31 | for ent in ents: 32 | if ent in sent: 33 | if ent_type == '机构' and len(line) - len(ent) < 3 or \ 34 | re.findall('[\((]\d+\.*[A-Z]*[\))]', line): 35 | main_org = ent 36 | else: 37 | if main_org and '客户' in sent: 38 | relation_by_rules.append([ent, '客户', main_org]) 39 | entity_dict_in_file[ent_type][ 40 | ('test', ent)].append( 41 | [line_idx, sent_idx, sent, 42 | sent.find(ent)] 43 | ) 44 | 45 | for s_type, p, o_type in relation_schema: 46 | s_ents = entity_dict_in_file[s_type] 47 | o_ents = entity_dict_in_file[o_type] 48 | if o_type == '业务' and not '业务' in line: 49 | continue 50 | if o_type == '行业' and not '行业' in line: 51 | continue 52 | if o_type == '文章' and not ('《' in line or not '》' in line): 53 | continue 54 | if s_ents and o_ents: 55 | for (s_ent_src, s_ent), (o_ent_src, o_ent) in product(s_ents, o_ents): 56 | if s_ent != o_ent: 57 | s_occs = [tuple(_[:2]) for _ in 58 | s_ents[(s_ent_src, s_ent)]] 59 | o_occs = [tuple(_[:2]) for _ in 60 | o_ents[(o_ent_src, o_ent)]] 61 | intersection = set(s_occs) & set(o_occs) 62 | if s_type == '研报' and s_ent_src == 'test': 63 | 
relation_by_rules.append([s_ent, p, o_ent]) 64 | continue 65 | if not intersection: 66 | continue 67 | if (s_type, o_type) in unique_s_o_types and s_ent_src == 'test': 68 | relation_by_rules.append([s_ent, p, o_ent]) 69 | 70 | train_relations = read_json(Path(DATA_DIR, 'relationships.json'))['relationships'] 71 | result_relations_set = list(set(tuple(_) for _ in relation_by_rules) - set(tuple(_) for _ in train_relations)) 72 | 73 | submit_relations = [] 74 | for i in result_relations_set: 75 | if len(list(i[2])) > 3: 76 | submit_relations.append(i) 77 | 78 | return submit_relations 79 | 80 | # In[ ]: -------------------------------------------------------------------------------- /images/d1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/images/d1.PNG -------------------------------------------------------------------------------- /images/d2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/images/d2.PNG -------------------------------------------------------------------------------- /model_attribute/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/model_attribute/__init__.py -------------------------------------------------------------------------------- /model_entity/EtlModel.py: -------------------------------------------------------------------------------- 1 | """ 2 | A joint model for relation extraction, written in pytorch. 3 | """ 4 | import torch 5 | from torch import nn 6 | from torch.nn import init 7 | from utils import torch_utils 8 | 9 | 10 | class EntityModel(object): 11 | """ A wrapper class for the training and evaluation of models. 
""" 12 | 13 | def __init__(self, opt, bert_model): 14 | self.opt = opt 15 | self.obj_criterion = nn.CrossEntropyLoss(reduction='none') 16 | self.model = BiLSTMCNN(opt, bert_model) 17 | self.parameters = [p for p in self.model.parameters() if p.requires_grad] 18 | if opt['cuda']: 19 | self.model.cuda() 20 | self.obj_criterion.cuda() 21 | self.optimizer = torch_utils.get_optimizer(opt['optim'], self.parameters, opt['lr'], opt['weight_decay']) 22 | 23 | def update(self, T, S1, S2, mask): 24 | 25 | inputs = T 26 | subj_start_type = S1 27 | subj_end_type = S2 28 | 29 | self.model.train() 30 | self.optimizer.zero_grad() 31 | 32 | subj_start_logits, subj_end_logits = self.model(inputs) 33 | subj_start_logits = subj_start_logits.view(-1, self.opt['num_subj_type'] + 1) 34 | subj_start_type = subj_start_type.view(-1).squeeze() 35 | subj_start_loss = self.obj_criterion(subj_start_logits, subj_start_type).view_as(mask) 36 | subj_start_loss = torch.sum(subj_start_loss.mul(mask.float())) / torch.sum(mask.float()) 37 | 38 | subj_end_loss = self.obj_criterion(subj_end_logits.view(-1, self.opt['num_subj_type'] + 1), 39 | subj_end_type.view(-1).squeeze()).view_as(mask) 40 | subj_end_loss = torch.sum(subj_end_loss.mul(mask.float())) / torch.sum(mask.float()) 41 | 42 | loss = subj_start_loss + subj_end_loss 43 | 44 | # backward 45 | loss.backward() 46 | # torch.nn.utils.clip_grad_norm(self.model.parameters(), self.opt['max_grad_norm']) 47 | self.optimizer.step() 48 | loss_val = loss.data.item() 49 | return loss_val 50 | 51 | def predict_subj_per_instance(self, words, mask): 52 | 53 | self.model.eval() 54 | hidden = self.model.based_encoder(words) 55 | 56 | subj_start_logits = self.model.subj_sublayer.predict_subj_start(hidden, mask) 57 | subj_end_logits = self.model.subj_sublayer.predict_subj_end(hidden, mask) 58 | 59 | return subj_start_logits, subj_end_logits 60 | 61 | def update_lr(self, new_lr): 62 | torch_utils.change_lr(self.optimizer, new_lr) 63 | 64 | def save(self, filename, epoch): 65 | params = { 66 | 'model': self.model.state_dict(), 67 | 'config': self.opt, 68 | 'epoch': epoch 69 | } 70 | try: 71 | torch.save(params, filename) 72 | print("model saved to {}".format(filename)) 73 | except BaseException: 74 | print("[Warning: Saving failed... 
continuing anyway.]") 75 | 76 | def load(self, filename): 77 | try: 78 | checkpoint = torch.load(filename) 79 | except BaseException: 80 | print("Cannot load model from {}".format(filename)) 81 | exit() 82 | self.model.load_state_dict(checkpoint['model']) 83 | self.opt = checkpoint['config'] 84 | 85 | 86 | class SubjTypeModel(nn.Module): 87 | 88 | def __init__(self, opt): 89 | super(SubjTypeModel, self).__init__() 90 | 91 | self.dropout = nn.Dropout(opt['dropout']) 92 | self.hidden_dim = opt['word_emb_dim'] 93 | 94 | self.linear_subj_start = nn.Linear(self.hidden_dim, opt['num_subj_type'] + 1) 95 | self.linear_subj_end = nn.Linear(self.hidden_dim, opt['num_subj_type'] + 1) 96 | self.init_weights() 97 | 98 | def init_weights(self): 99 | self.linear_subj_start.bias.data.fill_(0) 100 | init.xavier_uniform_(self.linear_subj_start.weight, gain=1) # initialize linear layer 101 | 102 | self.linear_subj_end.bias.data.fill_(0) 103 | init.xavier_uniform_(self.linear_subj_end.weight, gain=1) # initialize linear layer 104 | 105 | def forward(self, hidden): 106 | subj_start_inputs = self.dropout(hidden) 107 | subj_start_logits = self.linear_subj_start(subj_start_inputs) 108 | 109 | subj_end_inputs = self.dropout(hidden) 110 | subj_end_logits = self.linear_subj_end(subj_end_inputs) 111 | 112 | return subj_start_logits.squeeze(-1), subj_end_logits.squeeze(-1) 113 | 114 | def predict_subj_start(self, hidden, mask): 115 | subj_start_logits = self.linear_subj_start(hidden) 116 | subj_start_logits = torch.argmax(subj_start_logits, 2) 117 | subj_start_logits = subj_start_logits.mul(mask.float()) 118 | 119 | return subj_start_logits.squeeze(-1).data.cpu().numpy().tolist() 120 | 121 | def predict_subj_end(self, hidden, mask): 122 | subj_end_logits = self.linear_subj_end(hidden) 123 | subj_end_logits = torch.argmax(subj_end_logits, 2) 124 | subj_end_logits = subj_end_logits.mul(mask.float()) 125 | 126 | return subj_end_logits.squeeze(-1).data.cpu().numpy().tolist() 127 | 128 | 129 | class BiLSTMCNN(nn.Module): 130 | """ A sequence model for relation extraction. 
""" 131 | 132 | def __init__(self, opt, bert_model): 133 | super(BiLSTMCNN, self).__init__() 134 | self.drop = nn.Dropout(opt['dropout']) 135 | self.input_size = opt['word_emb_dim'] 136 | self.subj_sublayer = SubjTypeModel(opt) 137 | self.opt = opt 138 | self.bert_model = bert_model 139 | self.topn = self.opt.get('topn', 1e10) 140 | self.use_cuda = opt['cuda'] 141 | 142 | def based_encoder(self, words): 143 | hidden, _ = self.bert_model(words, output_all_encoded_layers=False) 144 | return hidden 145 | 146 | def forward(self, inputs): 147 | hidden = self.based_encoder(inputs) 148 | subj_start_logits, subj_end_logits = self.subj_sublayer(hidden) 149 | 150 | return subj_start_logits, subj_end_logits -------------------------------------------------------------------------------- /model_entity/HanlpNER.py: -------------------------------------------------------------------------------- 1 | # # 用hanlp进行实体识别 2 | # 3 | # hanlp支持对人物、机构的实体识别,可以使用它来对其中的两个实体类型进行识别:人物、机构。 4 | # 5 | # hanlp见[https://github.com/hankcs/HanLP](https://github.com/hankcs/HanLP) 6 | 7 | # In[ ]: 8 | 9 | import hanlp 10 | import re 11 | 12 | 13 | ## NER by third party tool 14 | class HanlpNER: 15 | def __init__(self): 16 | self.recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH) 17 | self.max_sent_len = 126 18 | self.ent_type_map = { 19 | 'NR': '人物', 20 | 'NT': '机构' 21 | } 22 | self.black_list = {'公司'} 23 | 24 | def recognize(self, sent): 25 | entities_dict = {} 26 | for result in self.recognizer.predict([list(sent)]): 27 | for entity, hanlp_ent_type, _, _ in result: 28 | if not re.findall(r'^[\.\s\da-zA-Z]{1,2}$', entity) and \ 29 | len(entity) > 1 and entity not in self.black_list \ 30 | and hanlp_ent_type in self.ent_type_map: 31 | entities_dict.setdefault(self.ent_type_map[hanlp_ent_type], []).append(entity) 32 | return entities_dict 33 | -------------------------------------------------------------------------------- /model_entity/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/model_entity/__init__.py -------------------------------------------------------------------------------- /model_relation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/model_relation/__init__.py -------------------------------------------------------------------------------- /output/process.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | fr = open('answers11.json', 'r', encoding='utf-8') 4 | 5 | ins = json.load(fr) 6 | 7 | sentdic_list = {} 8 | 9 | sentdic_list['attrs'] = ins['attrs'] 10 | sentdic_list['entities'] = ins['entities'] 11 | list_ = [] 12 | for i in ins['relationships']: 13 | if len(list(i[2])) > 3: 14 | list_.append(i) 15 | 16 | sentdic_list['relationships'] = list_ 17 | 18 | with open('test.json', 'w', encoding = 'utf-8') as file_obj: 19 | json.dump(sentdic_list, file_obj,ensure_ascii=False, indent=4) 20 | 21 | print('保存成功') 22 | 23 | -------------------------------------------------------------------------------- /parameters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils.schemas import * 3 | import torch 4 | 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--lr_decay', type=float, 
default=0) 8 | parser.add_argument('--topn', type=int, default=1e10, help='Only finetune top N embeddings.') 9 | parser.add_argument('--word_emb_dim', type=int, default=768, help='Word embedding dimension.') 10 | parser.add_argument('--dropout', type=float, default=0.4, help='Input and RNN dropout rate.') 11 | parser.add_argument('--weight_decay', type=float, default=0, help='Applies to SGD and Adagrad.') 12 | parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.') 13 | parser.add_argument('--bert_model', type=str, default='./bert-base-chinese', help='bert模型位置') 14 | parser.add_argument('--data_dir', type=str, default='./data/ccks2020-stage2-open', help='输入数据文件位置') 15 | parser.add_argument('--out_dir', type=str, default='./output/6.17', help='输出模型文件位置') 16 | parser.add_argument('--out_data_dir', type=str, default='./output/data', help='输出数据文件位置') 17 | parser.add_argument('--batch_size', type=int, default=6, help='batch_size') 18 | parser.add_argument('--total_epoch_nums', type=int, default=5, help='epoch') 19 | parser.add_argument('--nums_round', type=int, default=10, help='nums_round') 20 | parser.add_argument('--lr', type=float, default=2e-5) 21 | parser.add_argument('--seed', type=int, default=35) 22 | parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available()) 23 | parser.add_argument('--cpu', action='store_true', help='Ignore CUDA.') 24 | args = parser.parse_args() 25 | opt = vars(args) 26 | 27 | opt['num_subj_type'] = len(entity_type2id) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # 方法说明 2 | 3 | 该代码方法用到了开源工具Hanlp,和官方的预训练模型bert-base-chinese。 4 | 5 | 项目目录结构如下: 6 | 7 | ![](images/d1.PNG) 8 | 9 | 其中expirement_attr、expirement_er和expirement_re三个文件夹下分别是做评测过程中进行的一些相关实验,data文件夹下存放的评测数据。 10 | 11 | ## 1.实体抽取方法 12 | 13 | 通过Hanlp实体识别工具,抽取“人物”和“机构”两种类型的实体。 14 | 15 | 通过规则,抽取“研报“,“文章“,“风险“,“ 机构“四种类型的实体。 16 | 17 | 除了规则匹配外,还可以采用远程监督的方法,主要用于抽取研报中的实体,具体流程如下图所示: 18 | 19 | ![](images/d2.PNG) 20 | 21 | 1.使用规则和外部工具抽取一部分实体 22 | 23 | 2.将原始数据平均分成两半,一半用于训练,一半用于测试,对用于训练的一半数据使用远程监督进行标注 24 | 25 | 3.采用将远程监督方法标注的数据按4:1划分,分别作为训练和验证集,训练模型 26 | 27 | 4.使用上一步训练出的模型在测试集上进行预测,抽取出一部分实体 28 | 29 | 5.查看是否达到中止循环的条件,达到条件后中止 30 | 31 | 6.通过规则匹配的方法筛选掉一些实体,剩下的实体加入种子知识图谱,然后从第2步开始,重复上一次训练,迭代进行实体抽取 32 | 33 | ## 2.属性抽取方法 34 | 35 | 使用规则匹配的抽取方法 36 | 37 | ## 3.关系抽取方法 38 | 39 | 使用规则匹配的抽取方法 40 | 41 | # 程序运行说明 42 | 43 | 需要先安装python3.7和pytorch1.3 44 | 45 | 然后需要使用以下命令安装相关依赖库: 46 | 47 | ``` 48 | pip install jieba 49 | pip install hanlp 50 | pip install pytorch_pretrained_bert 51 | ``` 52 | 53 | 54 | 使用如下命令启动程序: 55 | 56 | ``` 57 | python main.py 58 | ``` 59 | 60 | 最终结果存放在 61 | 62 | output文件夹下,名称为answers.json 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /regulation.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import defaultdict 3 | from utils.preprocessing import * 4 | 5 | 6 | DATA_DIR = './data' # 输入数据文件夹 7 | 8 | 9 | # ## 通过规则抽取实体 10 | # 11 | # - 机构 12 | # - 研报 13 | # - 文章 14 | # - 风险 15 | 16 | # In[ ]: 17 | 18 | 19 | def aug_entities_by_rules(yanbao_dir): 20 | entities_by_rule = defaultdict(list) 21 | for file in list(yanbao_dir.glob('*.txt'))[:]: 22 | with open(file, encoding='utf-8') as f: 23 | found_yanbao = False 24 | found_fengxian = False 25 | for lidx, line in enumerate(f): 26 | # 公司的标题 27 | ret = 
re.findall('^[\((]*[\d一二三四五六七八九十①②③④⑤]*[\))\.\s]*(.*有限公司)$', line) 28 | if ret: 29 | entities_by_rule['机构'].append(ret[0]) 30 | 31 | # 研报 32 | if not found_yanbao and lidx <= 5 and len(line) > 10: 33 | may_be_yanbao = line.strip() 34 | if not re.findall(r'\d{4}\s*[年-]\s*\d{1,2}\s*[月-]\s*\d{1,2}\s*日?', may_be_yanbao) \ 35 | and not re.findall('^[\d一二三四五六七八九十]+\s*[\.、]\s*.*$', may_be_yanbao) \ 36 | and not re.findall('[\((]\d+\.*[A-Z]*[\))]', may_be_yanbao) \ 37 | and len(may_be_yanbao) > 5 \ 38 | and len(may_be_yanbao) < 100: 39 | entities_by_rule['研报'].append(may_be_yanbao) 40 | found_yanbao = True 41 | 42 | # 文章 43 | for sent in split_to_sents(line): 44 | results = re.findall('《(.*?)》', sent) 45 | for result in results: 46 | entities_by_rule['文章'].append(result) 47 | 48 | # 风险 49 | for sent in split_to_sents(line): 50 | if found_fengxian: 51 | sent = sent.split(':')[0] 52 | fengxian_entities = re.split('以及|、|,|;|。', sent) 53 | fengxian_entities = [re.sub('^[■]+[\d一二三四五六七八九十①②③④⑤]+', '', ent) for ent in fengxian_entities] 54 | fengxian_entities = [re.sub('^[\((]*[\d一二三四五六七八九十①②③④⑤]+[\))\.\s]+', '', ent) for ent in 55 | fengxian_entities] 56 | fengxian_entities = [_ for _ in fengxian_entities if len(_) >= 4] 57 | entities_by_rule['风险'] += fengxian_entities 58 | found_fengxian = False 59 | if not found_fengxian and re.findall('^\s*[\d一二三四五六七八九十]*\s*[\.、]*\s*风险提示[::]*$', sent): 60 | found_fengxian = True 61 | 62 | results = re.findall('^\s*[\d一二三四五六七八九十]*\s*[\.、]*\s*风险提示[::]*(.{5,})$', sent) 63 | if results: 64 | fengxian_entities = re.split('以及|、|,|;|。', results[0]) 65 | fengxian_entities = [re.sub('^[■]+[\d一二三四五六七八九十①②③④⑤]+', '', ent) for ent in fengxian_entities] 66 | fengxian_entities = [re.sub('^[\((]*[\d一二三四五六七八九十①②③④⑤]+[\))\.\s]+', '', ent) for ent in 67 | fengxian_entities] 68 | fengxian_entities = [_ for _ in fengxian_entities if len(_) >= 4] 69 | entities_by_rule['风险'] += fengxian_entities 70 | 71 | for ent_type, ents in entities_by_rule.items(): 72 | entities_by_rule[ent_type] = list(set(ents)) 73 | return entities_by_rule -------------------------------------------------------------------------------- /result_process.py: -------------------------------------------------------------------------------- 1 | import json 2 | from regulation import * 3 | from extract_attrs import * 4 | from extract_relations import * 5 | from extract_entities import * 6 | 7 | 8 | fr = open('answers11.json', 'r', encoding='utf-8') 9 | 10 | ins = json.load(fr) 11 | dict__ = { 12 | "品牌": ["魅族"], 13 | "机构": ["日本厚生劳动省"], 14 | } 15 | 16 | dict_ = { 17 | '行业': [], 18 | '机构': ["官网", "三环", "学习中心", "营运中心", "维达团队", "发改", "保险", "云商", "解放军海", "沪深", "日本厚生劳动", "人保", "招商", "新加", "研究院", "美国总部", "中百", "市场", "移为", "太阳", "云计算", "田忌", "腾讯系", "工程学院", "平台", "上证报", "中国太平", "字节", "亚太", "两颗", "2颗鸡", "件公司", "我厨", "2018", "201",], 19 | '研报': [], 20 | '指标': ["– "], 21 | '人物': [], 22 | '业务': [], 23 | '风险': [], 24 | '文章': [], 25 | '品牌': [], 26 | '产品': [] 27 | } 28 | 29 | submit_entities = {} 30 | 31 | for key, value in ins['entities'].items(): 32 | submit_entities[key] = list(set(ins['entities'][key]).difference(set(dict_[key]))) 33 | 34 | 35 | # for key, value in submit_entities.items(): 36 | # for key__, value__ in dict__.items(): 37 | # if key__ == key: 38 | # submit_entities[key] = list(set(submit_entities[key]).union(set(dict__[key]))) 39 | # else: 40 | # submit_entities[key] = submit_entities[key] 41 | 42 | # submit_entities = ins['entities'] 43 | # 使用规则匹配得到实体属性 44 | train_attrs = read_json(Path(DATA_DIR, 
'attrs.json'))['attrs'] 45 | submit_attrs = extract_attrs(submit_entities) 46 | 47 | # 使用规则匹配得到实体关系三元组 48 | schema = read_json(Path(DATA_DIR, 'schema.json')) 49 | submit_relations = extract_relations(schema, submit_entities) 50 | 51 | 52 | final_answer = {'attrs': submit_attrs, 53 | 'entities': submit_entities, 54 | 'relationships': submit_relations, 55 | } 56 | 57 | with open('output/answers22.json', mode='w', encoding='UTF-8') as fw: 58 | json.dump(final_answer, fw, ensure_ascii=False, indent=4) 59 | 60 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | yanbao_str = "我们看好国内的半导体上游的芯片设计产业,上游芯片设计公司越多,对下游的代工需求越增加旺盛,有利于国内的半导体代工厂,国内两大代工巨头都在港股,我们在港股范围内推荐中芯国际和华虹半导体。" 2 | 3 | yanbao_list = list(yanbao_str) 4 | 5 | print(len(yanbao_list)) -------------------------------------------------------------------------------- /utils/EvaluateScores.py: -------------------------------------------------------------------------------- 1 | # # 定义训练时评价指标 2 | # 3 | # 仅供训练时参考, 包含实体的precision,recall以及f1。 4 | # 5 | # 只有和标注的数据完全相同才算是1,否则为0 6 | 7 | # In[ ]: 8 | 9 | 10 | # 训练时指标 11 | class EvaluateScores: 12 | def __init__(self, entities_json, predict_entities_json): 13 | self.entities_json = entities_json 14 | self.predict_entities_json = predict_entities_json 15 | 16 | def compute_entities_score(self): 17 | return evaluate_entities(self.entities_json, self.predict_entities_json, list(set(self.entities_json.keys()))) 18 | 19 | 20 | def evaluate_entities(true_entities, pred_entities, entity_types): 21 | scores = [] 22 | 23 | ps2 = [] 24 | rs2 = [] 25 | fs2 = [] 26 | 27 | for ent_type in entity_types: 28 | true_entities_list = true_entities.get(ent_type, []) 29 | pred_entities_list = pred_entities.get(ent_type, []) 30 | s = _compute_metrics(true_entities_list, pred_entities_list) 31 | scores.append(s) 32 | ps = [i['p'] for i in scores] 33 | rs = [i['r'] for i in scores] 34 | fs = [i['f'] for i in scores] 35 | s = { 36 | 'p': sum(ps) / len(ps), 37 | 'r': sum(rs) / len(rs), 38 | 'f': sum(fs) / len(fs), 39 | } 40 | return s 41 | 42 | 43 | def _compute_metrics(ytrue, ypred): 44 | ytrue = set(ytrue) 45 | ypred = set(ypred) 46 | tr = len(ytrue) 47 | pr = len(ypred) 48 | hit = len(ypred.intersection(ytrue)) 49 | p = hit / pr if pr!=0 else 0 50 | r = hit / tr if tr!=0 else 0 51 | f1 = 2 * p * r / (p + r) if (p+r)!=0 else 0 52 | return { 53 | 'p': p, 54 | 'r': r, 55 | 'f': f1, 56 | } -------------------------------------------------------------------------------- /utils/MyDataset.py: -------------------------------------------------------------------------------- 1 | # # 定义dataset 以及 dataloader 2 | 3 | # In[ ]: 4 | 5 | import torch 6 | from pytorch_pretrained_bert import BertModel, BertTokenizer 7 | import numpy as np 8 | from utils.schemas import * 9 | from torch.autograd import Variable 10 | 11 | 12 | class MyDataset(torch.utils.data.Dataset): 13 | def __init__(self, preprocessed_datas, tokenizer: BertTokenizer, max_length=256): 14 | self.preprocessed_datas = preprocessed_datas 15 | self.tokenizer = tokenizer 16 | self.max_length = max_length 17 | 18 | def pad_sent_ids(self, sent_ids, ts1, ts2, max_length, padded_token_id): 19 | mask = [1] * (min(len(sent_ids), max_length)) + [0] * (max_length - len(sent_ids)) 20 | sent_ids = sent_ids[:max_length] + [padded_token_id] * (max_length - len(sent_ids)) 21 | ts1 = ts1[:max_length] + [padded_token_id] * (max_length - len(ts1)) 22 | ts2 = ts2[:max_length] + 
[padded_token_id] * (max_length - len(ts2)) 23 | return sent_ids, ts1, ts2, mask 24 | 25 | def process_one_preprocessed_data(self, preprocessed_data): 26 | import copy 27 | preprocessed_data = copy.deepcopy(preprocessed_data) 28 | 29 | sent_token_ids = preprocessed_data['sent_token_ids'] 30 | ts1 = preprocessed_data['ts1'] 31 | ts2 = preprocessed_data['ts2'] 32 | 33 | sent_token_ids, ts1, ts2, mask = self.pad_sent_ids(sent_token_ids, ts1, ts2, max_length = self.max_length, padded_token_id=0) 34 | 35 | T = np.array(sent_token_ids) 36 | S1 = np.array(ts1) 37 | S2 = np.array(ts2) 38 | mask = np.array(mask) 39 | 40 | preprocessed_data['T'] = Variable(torch.LongTensor(T).cuda()) 41 | preprocessed_data['S1'] = Variable(torch.LongTensor(S1).cuda()) 42 | preprocessed_data['S2'] = Variable(torch.LongTensor(S2).cuda()) 43 | preprocessed_data['mask'] = Variable(torch.LongTensor(mask).cuda()) 44 | 45 | return preprocessed_data 46 | 47 | def __getitem__(self, item): 48 | return self.process_one_preprocessed_data( 49 | self.preprocessed_datas[item] 50 | ) 51 | 52 | def __len__(self): 53 | return len(self.preprocessed_datas) 54 | -------------------------------------------------------------------------------- /utils/QuestionText.py: -------------------------------------------------------------------------------- 1 | # type2query = { 2 | # "人物": '社会中的以人为身份的个体,企业董事长,员工等都是人物;人物有姓名、年龄等属性;例如:钟南山、冯欣;', 3 | # "行业": '指一组提供同类相互密切替代商品或服务的公司;例如:交通运输、旅游业;', 4 | # "业务": '各行业中需要处理的事务,通常偏向指销售的事务;例如:住宅物业服务、维修;', 5 | # "产品": '做为商品提供给市场,被人们使用和消费,并能满足人们某种需求的任何东西;例如:奶粉、手机;', 6 | # "研报": '企业或者机构发布的公司或者行业的发展调查报告;有发布时间、评级、上次评级等属性;例如:证券研究报告/行业深度报告;', 7 | # "机构": '指政府、机关、团体等的内部组织;有全称、简称、英文名等属性;例如:美国国会、广汽集团;', 8 | # "风险": '指生产目的与劳动成果之间的不确定性;例如:行业政策风险、研发进度不达预期', 9 | # "文章": '用文字写成的,用于描述、记录或表达思想的文本,如各种手册、行业报告等等;文章有发布时间属性;例如:边缘计算专题报告、电商法', 10 | # "指标": '衡量目标的参数,预期中打算达到的指数、规格、标准,一般用数据表示;例如:市场占有率、利润率', 11 | # "品牌": '商业公司的产品或者业务的名称,也可以是公司名称;例如:小米集团、Note系列', 12 | # } 13 | type2query = { 14 | "指标": '衡量目标的参数,预期中打算达到的指数、规格、标准,一般用数据表示;例如:市场占有率、利润率', 15 | "品牌": '商业公司的产品或者业务的名称,也可以是公司名称;例如:小米集团、Note系列', 16 | "行业": '指一组提供同类相互密切替代商品或服务的公司;例如:交通运输、旅游业;', 17 | "业务": '各行业中需要处理的事务,通常偏向指销售的事务;例如:住宅物业服务、维修;', 18 | "产品": '做为商品提供给市场,被人们使用和消费,并能满足人们某种需求的任何东西;例如:奶粉、手机;' 19 | } 20 | query2type = {v: k for k, v in type2query.items()} -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaStudenttwo/ccks_kg/a5404669de86a7f7b87c07c15a5f24c95497ab86/utils/__init__.py -------------------------------------------------------------------------------- /utils/bert_dict_enhance.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import tqdm 4 | from utils.preprocess_data import Article 5 | from pytorch_pretrained_bert import BertModel, BertTokenizer 6 | 7 | 8 | DATA_DIR = '../data' # 输入数据文件夹 9 | YANBAO_DIR_PATH = str(Path(DATA_DIR, 'yanbao_txt__')) 10 | PRETRAINED_BERT_MODEL_DIR = '../bert-base-chinese/' 11 | 12 | 13 | tokenizer = BertTokenizer.from_pretrained( 14 | os.path.join(PRETRAINED_BERT_MODEL_DIR, 'vocab.txt') 15 | ) 16 | 17 | yanbao_texts = [] 18 | for yanbao_file_path in Path(YANBAO_DIR_PATH).glob('*.txt'): 19 | with open(yanbao_file_path, encoding='utf-8') as f: 20 | yanbao_texts.append(f.read()) 21 | 22 | 23 | article_tokens = [] 24 | for article in tqdm.tqdm([Article(t) for t in yanbao_texts]): 25 | for para_text 
in article.para_texts: 26 | for sent in article.split_into_sentence(para_text): 27 | sent_tokens = list(sent) 28 | for i in sent_tokens: 29 | article_tokens.append(i) 30 | 31 | dict_list = [] 32 | dict = tokenizer.vocab 33 | for i in dict: 34 | dict_list.append(i) 35 | 36 | notin_tokens = [] 37 | for i in article_tokens: 38 | if i not in dict_list and i not in notin_tokens: 39 | notin_tokens.append(i) 40 | 41 | num = 0 42 | for i in notin_tokens: 43 | num = num + 1 44 | print(i) 45 | print(num) -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from collections import defaultdict 3 | import torch 4 | import json 5 | from pytorch_pretrained_bert import BertModel, BertTokenizer 6 | import jieba 7 | from jieba.analyse.tfidf import TFIDF 8 | from jieba.posseg import POSTokenizer 9 | 10 | from model_entity.HanlpNER import HanlpNER 11 | from utils.preprocessing import * 12 | from pathlib import Path 13 | from utils.MyDataset import MyDataset 14 | from parameters import * 15 | 16 | # In[ ]: 17 | 18 | 19 | # # 预训练模型配置 20 | # 21 | # 参考 https://github.com/huggingface/pytorch-transformers 下载预训练模型,并配置下面参数为相关路径 22 | # 23 | # ```python 24 | # PRETRAINED_BERT_MODEL_DIR = '/you/path/to/bert-base-chinese/' 25 | # ``` 26 | 27 | # In[ ]: 28 | 29 | 30 | # # 一些参数 31 | 32 | # In[ ]: 33 | 34 | 35 | DATA_DIR = opt['data_dir'] # 输入数据文件夹 36 | OUT_DIR = opt['out_dir'] # 输出文件夹 37 | 38 | Path(OUT_DIR).mkdir(exist_ok=True) 39 | 40 | BATCH_SIZE = opt['batch_size'] 41 | TOTAL_EPOCH_NUMS = opt['total_epoch_nums'] 42 | 43 | if torch.cuda.is_available(): 44 | DEVICE = 'cuda:0' 45 | else: 46 | DEVICE = 'cpu' 47 | YANBAO_DIR_PATH = str(Path(DATA_DIR, 'yanbao_txt')) 48 | SAVE_MODEL_DIR = str(OUT_DIR) 49 | 50 | 51 | def aug_entities_by_third_party_tool(): 52 | hanlpner = HanlpNER() 53 | entities_by_third_party_tool = defaultdict(list) 54 | for file in tqdm.tqdm(list(Path(DATA_DIR, 'yanbao_txt').glob('*.txt'))[:]): 55 | with open(file, encoding='utf-8') as f: 56 | sents = [[]] 57 | cur_sent_len = 0 58 | for line in f: 59 | for sent in split_to_subsents(line): 60 | sent = sent[:hanlpner.max_sent_len] 61 | if cur_sent_len + len(sent) > hanlpner.max_sent_len: 62 | sents.append([sent]) 63 | cur_sent_len = len(sent) 64 | else: 65 | sents[-1].append(sent) 66 | cur_sent_len += len(sent) 67 | sents = [''.join(_) for _ in sents] 68 | sents = [_ for _ in sents if _] 69 | for sent in sents: 70 | entities_dict = hanlpner.recognize(sent) 71 | for ent_type, ents in entities_dict.items(): 72 | entities_by_third_party_tool[ent_type] += ents 73 | 74 | for ent_type, ents in entities_by_third_party_tool.items(): 75 | entities_by_third_party_tool[ent_type] = list([ent for ent in set(ents) if len(ent) > 1]) 76 | return entities_by_third_party_tool 77 | 78 | 79 | def custom_collate_fn(data): 80 | # copy from torch official,无需深究 81 | from torch._six import container_abcs, string_classes 82 | 83 | r"""Converts each NumPy array data field into a tensor""" 84 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 85 | elem_type = type(data) 86 | if isinstance(data, torch.Tensor): 87 | return data 88 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' and elem_type.__name__ != 'string_': 89 | # array of string classes and object 90 | if elem_type.__name__ == 'ndarray' and np_str_obj_array_pattern.search(data.dtype.str) is not None: 91 | return data 92 | return torch.as_tensor(data) 93 | 
elif isinstance(data, container_abcs.Mapping): 94 | tmp_dict = {} 95 | for key in data: 96 | if key in ['sent_token_ids', 'tag_ids', 'mask']: 97 | tmp_dict[key] = custom_collate_fn(data[key]) 98 | if key == 'mask': 99 | tmp_dict[key] = tmp_dict[key].byte() 100 | else: 101 | tmp_dict[key] = data[key] 102 | return tmp_dict 103 | elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple 104 | return elem_type(*(custom_collate_fn(d) for d in data)) 105 | elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes): 106 | return [custom_collate_fn(d) for d in data] 107 | else: 108 | return data 109 | 110 | 111 | def read_json(file_path): 112 | with open(file_path, mode='r', encoding='utf8') as f: 113 | return json.load(f) 114 | 115 | 116 | def build_dataloader(preprocessed_datas, tokenizer: BertTokenizer, batch_size=32, shuffle=True): 117 | dataset = MyDataset(preprocessed_datas, tokenizer) 118 | import torch.utils.data 119 | dataloader = torch.utils.data.DataLoader( 120 | dataset, batch_size=batch_size, collate_fn=custom_collate_fn, shuffle=shuffle) 121 | return dataloader 122 | 123 | 124 | # ### 模型预测结果后处理函数 125 | # 126 | # - `review_model_predict_entities`函数将模型预测结果后处理,从而生成提交文件格式 127 | 128 | 129 | def review_model_predict_entities(model_predict_entities): 130 | word_tag_map = POSTokenizer().word_tag_tab 131 | idf_freq = TFIDF().idf_freq 132 | reviewed_entities = defaultdict(list) 133 | for ent_type, ent_and_sent_list in model_predict_entities.items(): 134 | for ent, sent in ent_and_sent_list: 135 | start = sent.lower().find(ent) 136 | if start == -1: 137 | continue 138 | start += 1 139 | end = start + len(ent) - 1 140 | tokens = jieba.lcut(sent) 141 | offset = 0 142 | selected_tokens = [] 143 | for token in tokens: 144 | offset += len(token) 145 | if offset >= start: 146 | selected_tokens.append(token) 147 | if offset >= end: 148 | break 149 | 150 | fixed_entity = ''.join(selected_tokens) 151 | fixed_entity = re.sub(r'\d*\.?\d+%$', '', fixed_entity) 152 | if ent_type == '人物': 153 | if len(fixed_entity) >= 10: 154 | continue 155 | if len(fixed_entity) <= 1: 156 | continue 157 | if re.findall(r'^\d+$', fixed_entity): 158 | continue 159 | if word_tag_map.get(fixed_entity, '') == 'v' and idf_freq[fixed_entity] < 7: 160 | continue 161 | reviewed_entities[ent_type].append(fixed_entity) 162 | return reviewed_entities -------------------------------------------------------------------------------- /utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | # # 预处理函数 5 | # 6 | # 对文章进行预处理,切分句子和子句等 7 | 8 | # In[ ]: 9 | 10 | 11 | def split_to_sents(content, filter_length=(2, 1000)): 12 | content = re.sub(r"\s*", "", content) 13 | content = re.sub("([。!…??!;;])", "\\1\1", content) 14 | sents = content.split("\1") 15 | sents = [_[: filter_length[1]] for _ in sents] 16 | return [_ for _ in sents 17 | if filter_length[0] <= len(_) <= filter_length[1]] 18 | 19 | 20 | def split_to_subsents(content, filter_length=(2, 1000)): 21 | content = re.sub(r"\s*", "", content) 22 | content = re.sub("([。!…??!;;,,])", "\\1\1", content) 23 | sents = content.split("\1") 24 | sents = [_[: filter_length[1]] for _ in sents] 25 | return [_ for _ in sents 26 | if filter_length[0] <= len(_) <= filter_length[1]] -------------------------------------------------------------------------------- /utils/schemas.py: -------------------------------------------------------------------------------- 1 | # entity_type2id = { 2 | # 
'人物': 1, 3 | # '行业': 2, 4 | # '业务': 3, 5 | # '研报': 4, 6 | # '机构': 5, 7 | # '风险': 6, 8 | # '文章': 7, 9 | # '指标': 8, 10 | # '品牌': 9, 11 | # '产品': 10 12 | # } 13 | entity_type2id = { 14 | '指标': 1, 15 | '品牌': 2, 16 | '行业': 3, 17 | '业务': 4, 18 | '产品': 5 19 | } 20 | entity_id2type = {v: k for k, v in entity_type2id.items()} 21 | --------------------------------------------------------------------------------
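
A minimal usage sketch (not part of the repository) tying together the last three utilities in the listing: the sentence splitter in `utils/preprocessing.py`, the exact-match entity metric in `utils/EvaluateScores.py`, and the reduced label map in `utils/schemas.py`. It assumes the project root is on `PYTHONPATH`; the sample text is adapted from `test.py`, and the gold/predicted entity lists are made up for illustration.

```python
# Illustrative sketch only: split raw report text into sentences, fake a set of
# entity predictions, and score them with the exact-match metric shown above.
from utils.preprocessing import split_to_sents
from utils.EvaluateScores import EvaluateScores
from utils.schemas import entity_type2id

text = "我们看好国内的半导体上游的芯片设计产业。在港股范围内推荐中芯国际和华虹半导体。"
sents = split_to_sents(text)   # two sentences, punctuation kept, whitespace stripped

# Hypothetical gold vs. predicted entities, keyed by type as in entities.json
gold = {"机构": ["中芯国际", "华虹半导体"], "行业": ["半导体"]}
pred = {"机构": ["中芯国际"], "行业": ["半导体", "芯片设计"]}

scores = EvaluateScores(gold, pred).compute_entities_score()
print({k: round(v, 3) for k, v in scores.items()})   # macro-averaged p / r / f over types

# Note: entity_type2id only covers the five model-predicted types (指标, 品牌, 行业,
# 业务, 产品); rule-based types such as 研报 or 风险 are produced by regulation.py
# and never pass through this mapping.
print(entity_type2id)
```

With the toy lists above the metric reports p = 0.75, r = 0.75, f ≈ 0.667, since only fully identical strings count as hits; partial matches such as "芯片设计" score nothing, which is why `review_model_predict_entities` trims model spans before scoring.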