├── README.md
├── readme_复赛.txt
└── run.py

/README.md:
--------------------------------------------------------------------------------
# megemini
Strategy write-up of team megemini, 4th place in the grand final of the Ruijin Hospital knowledge graph competition.

Second-round code:

run.html (HTML export of the Jupyter notebook)

run.py (executable file produced by merging the main code from run.html)

readme_复赛.txt (the readme submitted with the second-round entry)

--------------------------------------------------------------------------------
/readme_复赛.txt:
--------------------------------------------------------------------------------
Team: megemini
Score: 0.770 (2018-12-05)
File: 1544004574135_s2_01_final_v4_b_raw.zip

(1. Briefly describe what each file does and which files contain your own code.)
1. Directory structure:
|-run.ipynb    # the code, a single file; just run it. It first trains the models, then predicts and writes the results.
|-readme.txt   # this document
|-Demo
  |-DataSets   # data files, internal structure unchanged
  |-Models
    |-s2_01    # model files and training-history files. Create this folder before reproducing!!!
  |-result
    |-s2_01_final_v4_b_raw    # result files. Create this folder before reproducing!!!

(2. State which parts use open-source tools.)
2. Open-source tools used by the program:
(1) jieba and gensim, for pre-training the word vectors
(2) keras, for the model

(3. Which extra packages must be installed in a fresh dsw environment to reproduce.)
3. All installed packages were exported with pip freeze at the end of this file; install them accordingly when reproducing.

(4. Running time; give training time and prediction time separately.)
4. Running time:
(1) Data preparation: ~5 min
(2) Model training: ~868 sec (per epoch) × 25 (epochs) × 2 (forward and reversed runs) = 723.3 min ≈ 12 h
(3) Model prediction: ~36 min

(5. Location of the generated submission files.)
5. Directory: ./Demo/result/s2_01_final_v4_b_raw

(There is no item 6.)

(7. Team name, captain profile and contact information.)
7. Team
Name: megemini
Email: megemini@outlook.com

(8. Anything else that needs special explanation; omit if none.)
8. Other notes
(1) The code was edited in a Jupyter notebook, so the notebook records the output of every step; if your outputs differ substantially from them, the run has probably gone wrong.
(2) run.ipynb originally lived inside the Demo folder and writes its model files into the Models folder under Demo; if path errors occur during reproduction, please contact the team captain.
(3) When packaging the ann directory, be sure to filter out the hidden ipynb files, otherwise the results may be corrupted!!!

9. Solution approach
The basic pipeline is: (1) FastText embedding vectors -> (2) Stack-Bi-LSTM(GRU) model -> (3) Bi-Sample Training.
(1) FastText embedding vectors
The second-round task is to match arg_a (arg_1) with arg_b (arg_2); this program builds training data by sliding-window sampling around each arg_b. A possible problem is that arg_b is unevenly distributed in the text, so word vectors learned inside the network may be trained poorly, and the smaller the window, the more pronounced this becomes. Therefore FastText is first trained on the training text to produce the word vectors, which are then loaded into the Stack-Bi-LSTM model for training.
(2) Stack-Bi-LSTM(GRU) model
"Stack" refers to the structure with three inputs: text, arg_a and arg_b each pass through an embedding layer and a Bi-LSTM, are concatenated, and are then fed together into Bi-LSTM and Bi-GRU layers.
Here text is the encoded text inside the sliding window (320 characters here); arg_a is the index sequence of all arg_a entities inside the window; arg_b is the index sequence of the single arg_b entity to be matched in this window.
The model output is the 0/1 target sequence over this window marking, for this arg_b, the positions of the related arg_a entities.
(3) Bi-Sample Training
Although Bi-LSTM and Bi-GRU are bidirectional, the text itself still has a left-to-right order, so two copies of the data are used for training, one forward and one reversed, which effectively doubles the data. Mixing the two copies in a single training run, however, could make them act as noise for each other, so, for the same total training time, one model is trained per direction and their predictions are fused.

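As an illustration, a minimal sketch of this forward/reverse fusion (get_model, emb_weights and EPOCHS are the ones defined in run.py; x_text, x_arga, x_argb and y are stand-ins for the sampled window arrays and their 0/1 target sequences):

    model_fwd = get_model(emb_weights)   # trained on the windows in reading order
    model_bwd = get_model(emb_weights)   # trained on the same windows reversed
    model_fwd.fit([x_text, x_arga, x_argb], y, epochs=EPOCHS)
    model_bwd.fit([x_text[:, ::-1], x_arga[:, ::-1], x_argb[:, ::-1]], y[:, ::-1, :], epochs=EPOCHS)

    # at prediction time the reversed model's output is flipped back before averaging
    p_fwd = model_fwd.predict([x_text, x_arga, x_argb])
    p_bwd = model_bwd.predict([x_text[:, ::-1], x_arga[:, ::-1], x_argb[:, ::-1]])[:, ::-1, :]
    p = (p_fwd + p_bwd) / 2

run.py does the same thing with per-epoch checkpoints, averaging the predictions of the last few checkpoints of both directions.
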
10. pip environment

absl-py==0.6.1
asn1crypto==0.23.0
astor==0.7.1
backcall==0.1.0
bleach==3.0.2
boto==2.49.0
boto3==1.9.57
botocore==1.12.57
bz2file==0.98
certifi==2018.10.15
cffi==1.11.2
chardet==3.0.4
conda==4.3.31
cryptography==2.1.4
cycler==0.10.0
decorator==4.3.0
defusedxml==0.5.0
docutils==0.14
entrypoints==0.2.3
future==0.16.0
gast==0.2.0
gensim==3.6.0
grpcio==1.16.0
h5py==2.8.0
idna==2.7
ipykernel==5.1.0
ipython==7.1.1
ipython-genutils==0.2.0
jedi==0.13.1
jieba==0.39
Jinja2==2.10
jmespath==0.9.3
jsonschema==2.6.0
jupyter-client==5.2.3
jupyter-core==4.4.0
jupyterlab==0.34.12
jupyterlab-launcher==0.13.1
jupyterlab-prometheus==0.1
Keras==2.2.4
Keras-Applications==1.0.6
Keras-Preprocessing==1.0.5
keras-self-attention==0.31.0
kiwisolver==1.0.1
lightgbm==2.2.2
Markdown==3.0.1
MarkupSafe==1.1.0
matplotlib==3.0.0
mistune==0.8.4
nbconvert==5.4.0
nbformat==4.4.0
notebook==5.7.0
np-utils==0.5.5.2
numpy==1.15.2
pandas==0.23.4
pandocfilters==1.4.2
parso==0.3.1
pexpect==4.6.0
pickleshare==0.7.5
Pillow==5.3.0
prometheus-client==0.4.2
prompt-toolkit==2.0.7
protobuf==3.6.1
ptyprocess==0.6.0
pycosat==0.6.3
pycparser==2.18
Pygments==2.2.0
pyOpenSSL==17.5.0
pyparsing==2.2.2
PySocks==1.6.7
python-dateutil==2.7.5
pytz==2018.5
PyYAML==3.13
pyzmq==17.1.2
requests==2.19.1
ruamel-yaml==0.11.14
s3transfer==0.1.13
scikit-learn==0.20.0
scipy==1.1.0
Send2Trash==1.5.0
simplegeneric==0.8.1
simplejson==3.16.0
six==1.11.0
smart-open==1.7.1
tensorboard==1.11.0
tensorflow-gpu==1.11.0
termcolor==1.1.0
terminado==0.8.1
testpath==0.4.2
torch==0.4.1
torchtext==0.3.1
torchvision==0.2.1
tornado==5.1.1
tqdm==4.28.1
traitlets==4.3.2
urllib3==1.23
wcwidth==0.1.7
webencodings==0.5.1
Werkzeug==0.14.1
xgboost==0.81

--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import sys
4 | import time
5 | import random
6 | import warnings
7 | import re
8 | 
9 | import numpy as np
10 | import pandas as pd
11 | 
12 | from sklearn.model_selection import KFold
13 | from sklearn.metrics import accuracy_score
14 | from sklearn.linear_model import Ridge, RidgeCV
15 | from sklearn.linear_model import ElasticNet, Lasso, BayesianRidge, LassoLarsIC
16 | from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier, GradientBoostingRegressor, GradientBoostingClassifier
17 | from sklearn.ensemble import AdaBoostRegressor, BaggingRegressor, BaggingClassifier
18 | from sklearn.tree import ExtraTreeRegressor, ExtraTreeClassifier
19 | from sklearn.kernel_ridge import KernelRidge
20 | from sklearn.pipeline import make_pipeline
21 | from sklearn.preprocessing import RobustScaler
22 | from sklearn.base import BaseEstimator, TransformerMixin, RegressorMixin, clone
23 | from sklearn.model_selection import KFold, cross_val_score, train_test_split
24 | from sklearn.metrics import mean_squared_error
25 | from sklearn.feature_selection import 
SelectFromModel 26 | from sklearn.model_selection import StratifiedKFold 27 | from sklearn.metrics import f1_score, roc_auc_score 28 | 29 | from sklearn.feature_extraction.text import TfidfVectorizer, TfidfTransformer, HashingVectorizer, CountVectorizer 30 | from sklearn.decomposition import TruncatedSVD 31 | from sklearn import preprocessing, pipeline 32 | 33 | 34 | import xgboost as xgb 35 | import lightgbm as lgb 36 | from keras.models import Model, load_model, Sequential 37 | from keras.layers import Input, Add, Dense, Flatten, BatchNormalization, Activation, UpSampling2D, Embedding 38 | from keras.layers.core import Lambda 39 | from keras.layers.convolutional import Conv2D, Conv2DTranspose 40 | from keras.layers.pooling import MaxPooling2D, GlobalAveragePooling2D 41 | from keras.layers.merge import concatenate 42 | from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau 43 | from keras import backend as K 44 | from keras import losses, metrics, optimizers 45 | from keras.applications import MobileNet 46 | # from keras.applications.mobilenet import relu6, DepthwiseConv2D 47 | 48 | import tensorflow as tf 49 | import keras 50 | 51 | from glob import glob 52 | 53 | import simplejson as json 54 | 55 | import pickle 56 | import gensim 57 | import jieba 58 | import itertools 59 | 60 | from keras.models import Sequential 61 | from keras.layers import Dense, Dropout, Concatenate, Add, SpatialDropout1D 62 | from keras.layers import Embedding 63 | # from keras.layers import LSTM, TimeDistributed, Bidirectional, CuDNNLSTM, CuDNNGRU 64 | from keras.layers import LSTM, TimeDistributed, Bidirectional 65 | 66 | import keras.backend as K 67 | 68 | 69 | LABELS = ['Normal', 70 | 'Disease', 71 | 'Reason', 72 | 'Symptom', 73 | 'Test', 74 | 'Test_Value', 75 | 'Drug', 76 | 'Frequency', 77 | 'Amount', 78 | 'Method', 79 | 'Treatment', 80 | 'Operation', 81 | 'Anatomy', 82 | 'Level', 83 | 'Duration', 84 | 'SideEff', 85 | ] 86 | 87 | 88 | LABELS_0 = ['Normal', 'Test', 'Symptom', 'Treatment', 'Drug', 'Anatomy', ] 89 | LABELS_1 = ['Normal', 'Frequency', 'Duration', 'Amount', 'Method', 'SideEff', ] 90 | 91 | LABELS_DICT = dict(zip(LABELS, range(len(LABELS)))) 92 | 93 | LABELS_DICT_IDX = {str(v):k for k, v in LABELS_DICT.items()} 94 | 95 | ARG_0a_SET = set(['Test', 'Symptom', 'Treatment', 'Drug', 'Anatomy']) 96 | 97 | ARG_1a_SET = set(['Frequency', 'Duration', 'Amount', 'Method', 'SideEff']) 98 | 99 | ARG_0b = 'Disease' 100 | ARG_1b = 'Drug' 101 | 102 | 103 | TRAN_TAB = str.maketrans(u'!?【】『』〔〕()%#@&1234567890“”’‘', u'!?[][][]()%#@&1234567890""\'\'') 104 | 105 | TOTAL_SIZE = None 106 | 107 | TRAIN_LENGTH_HALF = 160 108 | TRAIN_LENGTH = (TRAIN_LENGTH_HALF*2) 109 | 110 | LEN_LABELS = len(LABELS_DICT) 111 | 112 | 113 | NOISE_RATIO = 0.02 114 | NOISE_ARGA = 0.01 115 | 116 | MARGIN_LIST = [-140, -112, -84, -56, -28, 0, 28, 56, 84, 112, 140] 117 | 118 | W2V_LEN_TEXT = 64 119 | W2V_LEN_ARGA = 32 120 | W2V_LEN_ARGB = 8 121 | 122 | TOTAL_ARGA = LEN_LABELS + 1 123 | TOTAL_ARGB = 3 124 | 125 | FLAG_DROP = True 126 | 127 | KFOLD = 10 128 | 129 | EPOCHS = 25 130 | 131 | 132 | def remove_semicolon(t): 133 | 134 | if ';' in (t): 135 | 136 | temp = t.split(';') 137 | temp_0 = temp[0].split(' ')[:-1] 138 | temp_1 = temp[-1].split(' ')[-1:] 139 | 140 | return temp_0 + temp_1 141 | 142 | else: 143 | return t.split(' ') 144 | 145 | def split_arg(t): 146 | 147 | t_list = t.split(' ') 148 | 149 | if '_' in t_list[0]: 150 | arg_a, arg_b = t_list[0].split('_') 151 | elif '-' in t_list[0]: 152 | arg_a, arg_b = 
t_list[0].split('-') 153 | 154 | arg1 = t_list[1].split(':')[1] 155 | arg2 = t_list[2].split(':')[1] 156 | 157 | return [arg_a, arg_b, arg1, arg2] 158 | 159 | def get_df(path, arg=False, filted=False): 160 | 161 | if arg: 162 | anno = pd.read_csv(path, sep='\t', header=None) 163 | 164 | anno = anno[anno[0].str.startswith('R')].reset_index(drop=True) 165 | 166 | anno_1 = np.vstack((anno[1].apply(split_arg))) 167 | 168 | anno_df = pd.DataFrame() 169 | anno_df['id'] = anno[0] 170 | anno_df['arg_a'] = anno_1[:, 0] 171 | anno_df['arg_b'] = anno_1[:, 1] 172 | anno_df['arg_1'] = anno_1[:, 2] 173 | anno_df['arg_2'] = anno_1[:, 3] 174 | 175 | return anno_df, anno 176 | 177 | else: 178 | anno = pd.read_csv(path, sep='\t', header=None) 179 | 180 | anno = anno[anno[0].str.startswith('T')].reset_index(drop=True) 181 | 182 | anno_1 = np.vstack((anno[1].apply(remove_semicolon))) 183 | 184 | anno_df = pd.DataFrame() 185 | anno_df['id'] = anno[0] 186 | anno_df['label'] = anno_1[:, 0] 187 | anno_df['idx_0'] = anno_1[:, 1].astype(int) 188 | anno_df['idx_1'] = anno_1[:, 2].astype(int) 189 | anno_df['text'] = anno[2] 190 | 191 | if filted: 192 | anno_df = anno_df[anno_df['label'].isin(ARG_ALL)].reset_index(drop=True) 193 | 194 | return anno_df, anno 195 | 196 | def text_filter(text): 197 | 198 | _l = len(text) 199 | _text = text.translate(TRAN_TAB) 200 | assert _l == len(_text) 201 | 202 | return _text 203 | 204 | def _add_anno_to_arg(df_anno, df_arg): 205 | 206 | anno_a = df_anno # train_anno_dict.get('0') 207 | anno_b = df_arg # train_arg_dict.get('0') 208 | 209 | arg1_df = (pd.merge(anno_b[['id', 'arg_1', 'arg_a']], 210 | anno_a[['id', 'idx_0', 'idx_1', 'label', 'text']], 211 | left_on=['arg_1'], 212 | right_on=['id']).drop(columns=['arg_1', 'id_y'])).rename(columns={'id_x':'id', 213 | 'idx_0':'arg_1_idx_0', 214 | 'idx_1':'arg_1_idx_1', 215 | 'label':'arg_a_label', 216 | 'text':'arg_a_text'}) 217 | arg2_df = (pd.merge(anno_b[['id', 'arg_2', 'arg_b']], 218 | anno_a[['id', 'idx_0', 'idx_1', 'label', 'text']], 219 | left_on=['arg_2'], 220 | right_on=['id']).drop(columns=['arg_2', 'id_y'])).rename(columns={'id_x':'id', 221 | 'idx_0':'arg_2_idx_0', 222 | 'idx_1':'arg_2_idx_1', 223 | 'label':'arg_b_label', 224 | 'text':'arg_b_text'}) 225 | return pd.merge(pd.merge(anno_b, arg1_df), arg2_df) 226 | 227 | def add_anno_to_arg(df_anno_dict, df_arg_dict): 228 | 229 | arg_dict = {} 230 | 231 | for k in df_anno_dict: 232 | anno = _add_anno_to_arg(df_anno_dict.get(k), df_arg_dict.get(k)) 233 | idx_all = np.sort(anno[['arg_1_idx_0', 'arg_1_idx_1', 'arg_2_idx_0', 'arg_2_idx_1']].values, axis=1) 234 | idx_len = idx_all[:, -1] - idx_all[:, 0] 235 | anno['idx_len_max'] = idx_len 236 | anno['idx_len_mid'] = (anno['arg_2_idx_1']+anno['arg_2_idx_0'])/2 - (anno['arg_1_idx_1']+anno['arg_1_idx_0'])/2 237 | anno['arg_len'] = anno['arg_2'].str[1:].astype(int) - anno['arg_1'].str[1:].astype(int) 238 | anno['len_abs_min'] = anno.apply(lambda x: np.min((np.abs(x['idx_len_max']), np.abs(x['idx_len_mid']), np.abs(x['arg_len']))), axis=1) 239 | 240 | # anno = anno[~(anno['arg_a'] != anno['arg_a_label']) | (anno['arg_b'] != anno['arg_b_label'])].reset_index(drop=True) 241 | anno = anno[(anno['arg_a'] == anno['arg_a_label']) & (anno['arg_b'] == anno['arg_b_label'])].reset_index(drop=True) 242 | 243 | # thres = np.mean(anno['idx_len_max']) + np.std(anno['idx_len_max']) * 2 244 | # anno = anno[(anno['idx_len_max'] < thres) | (anno['idx_len_max'] < 350)] 245 | # anno = anno[(anno['idx_len_max'] < thres)] 246 | 247 | arg_dict[k] = anno 248 | 249 
| return arg_dict 250 | 251 | 252 | def get_a_labels_dict(anno_dict, text_dict, labels_dict): 253 | train_ab_lables_dict = {} 254 | for k, v in anno_dict.items(): 255 | 256 | _labels = np.ones((len(text_dict.get(k)), 1), dtype=np.uint8) * labels_dict.get('Normal') 257 | 258 | for idx, row in enumerate(v.itertuples()): 259 | 260 | _label = row[2] 261 | 262 | if _label in labels_dict: 263 | 264 | _idx_0 = int(row[3]) 265 | _idx_1 = int(row[4]) 266 | 267 | if _idx_0 >= _idx_1: 268 | print('Bad idx', k, row) 269 | break 270 | 271 | _labels[_idx_0:_idx_1] = labels_dict.get(_label) 272 | 273 | train_ab_lables_dict[k] = _labels 274 | 275 | return train_ab_lables_dict 276 | 277 | def get_aa_labels_dict(anno_dict, text_dict, labels_set): 278 | train_ab_lables_dict = {} 279 | for k, v in anno_dict.items(): 280 | 281 | _labels = np.zeros((len(text_dict.get(k)), 1), dtype=np.uint8) 282 | 283 | for idx, row in enumerate(v.itertuples()): 284 | 285 | _label = row[2] 286 | 287 | if _label in labels_set: 288 | 289 | _idx_0 = int(row[3]) 290 | _idx_1 = int(row[4]) 291 | 292 | if _idx_0 >= _idx_1: 293 | print('Bad idx', k, row) 294 | break 295 | 296 | # _labels[_idx_0:_idx_1] = labels_dict.get(_label) 297 | _labels[_idx_0:_idx_1] = 1 298 | 299 | train_ab_lables_dict[k] = _labels 300 | 301 | return train_ab_lables_dict 302 | 303 | def _get_set_b_list_dict(b_dict, set_b, arg_df, len_text): 304 | 305 | for _b in set_b: 306 | _df = arg_df[arg_df['arg_2'] == _b] 307 | _x = np.zeros((len_text), dtype=np.uint8) 308 | 309 | for idx_b, row_b in enumerate(_df.itertuples()): 310 | 311 | assert row_b[6] >= 0 312 | assert row_b[7] >= 0 313 | assert row_b[6] <= len_text 314 | assert row_b[7] <= len_text 315 | 316 | _x[row_b[6]:row_b[7]] = 1 317 | 318 | b_dict[_b] = _x 319 | 320 | return b_dict 321 | 322 | def _get_meta_data(data_x_text_list, data_x_arga_list, data_x_argaa_list, data_x_argb_list, data_y_list, 323 | text_encoded, b_dict, list_a, df_b, x_idx, len_text, set_b, list_aa): 324 | 325 | np.random.seed(522) 326 | 327 | for idx_b, row_b in enumerate(df_b.itertuples()): 328 | 329 | _, id_b, label_b, idx_b_0, idx_b_1, _ = row_b 330 | 331 | _idx_b_mid = int((idx_b_0 + idx_b_1)/2) 332 | 333 | for margin in MARGIN_LIST: 334 | 335 | idx_b_mid = _idx_b_mid + margin 336 | 337 | _left = idx_b_mid - TRAIN_LENGTH_HALF 338 | _right = idx_b_mid + TRAIN_LENGTH_HALF 339 | 340 | idx_left = max(0, _left) 341 | idx_right = min(len_text, _right) 342 | 343 | idx_pad_left = np.abs(min(0, _left)) 344 | idx_pad_right = np.abs(min(0, len_text - _right)) 345 | 346 | # x arg b 347 | _b = np.zeros((TRAIN_LENGTH, 1), dtype=np.uint8) 348 | b_start = idx_pad_left + (idx_b_0 - idx_left) 349 | b_end = idx_pad_left + (idx_b_1 - idx_left) 350 | 351 | if (b_start < 0) or (b_end < 0) or (b_start > TRAIN_LENGTH) or (b_end > TRAIN_LENGTH): 352 | continue 353 | 354 | assert b_start >= 0 355 | assert b_end >= 0 356 | assert b_start <= TRAIN_LENGTH 357 | assert b_end <= TRAIN_LENGTH 358 | 359 | _b[b_start:b_end] = x_idx+1 360 | 361 | # x arg a 362 | _a = np.pad(list_a[idx_left:idx_right], ((idx_pad_left, idx_pad_right), (0, 0)), mode='constant', constant_values=0).copy() 363 | 364 | # x arg aa 365 | _aa = np.pad(list_aa[idx_left:idx_right], ((idx_pad_left, idx_pad_right), (0, 0)), mode='constant', constant_values=0).copy() 366 | 367 | # x text 368 | _x = np.pad(text_encoded[idx_left:idx_right], (idx_pad_left, idx_pad_right), mode='constant', constant_values=0).copy() 369 | 370 | mask_t = None 371 | mask_a = None 372 | if NOISE_RATIO != 0: 373 | mask = 
np.argwhere(np.squeeze(_a + _b) == 0).reshape(-1)[idx_pad_left:TRAIN_LENGTH-idx_pad_right] 374 | # mask = np.argwhere(np.squeeze(_b) == 0).reshape(-1)[idx_pad_left:TRAIN_LENGTH-idx_pad_right] 375 | if len(mask)>0: 376 | mask_t = np.random.choice(mask, int(len(mask)*NOISE_RATIO)) 377 | _x[mask_t] = 0 378 | 379 | if NOISE_ARGA != 0: 380 | mask = np.argwhere(np.squeeze(_a) != 0).reshape(-1) #[idx_pad_left:TRAIN_LENGTH-idx_pad_right] 381 | if len(mask)>0: 382 | # mask_a = np.random.choice(mask, int(np.ceil(len(mask)*NOISE_ARGA))) 383 | mask_a = np.random.choice(mask, int(len(mask)*NOISE_ARGA)) 384 | _x[mask_a] = 0 385 | # _a[mask_a] = 0 386 | # _aa[mask_a] = 0 387 | 388 | data_x_arga_list.append(_a) 389 | data_x_argaa_list.append(_aa) 390 | data_x_argb_list.append(_b) 391 | data_x_text_list.append(_x) 392 | 393 | # y 394 | _y = None 395 | if id_b in set_b: 396 | _y = np.pad( 397 | b_dict.get(id_b)[idx_left:idx_right], 398 | (idx_pad_left, idx_pad_right), 399 | mode='constant', constant_values=0)[..., np.newaxis].copy() 400 | # if mask_a is not None: 401 | # _y[mask_a] = 0 402 | 403 | else: 404 | _y = np.zeros((TRAIN_LENGTH, 1), dtype=np.uint8) 405 | 406 | data_y_list.append(_y) 407 | 408 | 409 | return (data_x_text_list, data_x_arga_list, data_x_argaa_list, data_x_argb_list, data_y_list) 410 | 411 | def get_meta_data(data_x_text_list, data_x_arga_list, data_x_argaa_list, data_x_argb_list, data_y_list, 412 | text_encoded, anno_df, list_a, arg_df, list_0a, list_1a): 413 | 414 | len_text = len(text_encoded) 415 | 416 | df_0b = anno_df[anno_df['label'] == ARG_0b].sort_values('idx_0').reset_index(drop=True) 417 | df_1b = anno_df[anno_df['label'] == ARG_1b].sort_values('idx_0').reset_index(drop=True) 418 | 419 | set_0b = None 420 | set_1b = None 421 | b_dict = None 422 | if arg_df is not None: 423 | set_0b = arg_df[arg_df['arg_b'] == ARG_0b]['arg_2'].unique() 424 | set_1b = arg_df[arg_df['arg_b'] == ARG_1b]['arg_2'].unique() 425 | 426 | b_dict = {} 427 | b_dict = _get_set_b_list_dict(b_dict, set_0b, arg_df, len_text) 428 | b_dict = _get_set_b_list_dict(b_dict, set_1b, arg_df, len_text) 429 | 430 | # print(df_0b.shape, df_1b.shape, list_0a.shape, list_1a.shape, len(set_0b), len(set_1b)) 431 | 432 | # list_a = list_0a 433 | list_aa = list_0a 434 | df_b = df_0b 435 | set_b = set_0b 436 | x_idx = 0 437 | 438 | (data_x_text_list, data_x_arga_list, data_x_argaa_list, data_x_argb_list, data_y_list) = \ 439 | _get_meta_data(data_x_text_list, data_x_arga_list, data_x_argaa_list, data_x_argb_list, data_y_list, 440 | text_encoded, b_dict, list_a, df_b, x_idx, len_text, set_b, list_aa) 441 | 442 | # list_a = list_1a 443 | list_aa = list_1a 444 | df_b = df_1b 445 | set_b = set_1b 446 | x_idx = 1 447 | 448 | (data_x_text_list, data_x_arga_list, data_x_argaa_list, data_x_argb_list, data_y_list) = \ 449 | _get_meta_data(data_x_text_list, data_x_arga_list, data_x_argaa_list, data_x_argb_list, data_y_list, 450 | text_encoded, b_dict, list_a, df_b, x_idx, len_text, set_b, list_aa) 451 | 452 | return (data_x_text_list, data_x_arga_list, data_x_argaa_list, data_x_argb_list, data_y_list) 453 | 454 | from keras.layers import LSTM, GRU 455 | 456 | 457 | emb_weights = None 458 | 459 | def get_model(emb_weights=emb_weights): 460 | 461 | inputs_text = Input((TRAIN_LENGTH, )) 462 | emb_text = Embedding(input_dim=TOTAL_SIZE+1,output_dim=W2V_LEN_TEXT, weights=[emb_weights], trainable=False)(inputs_text) 463 | emb_text = SpatialDropout1D(0.3)(emb_text) 464 | # x_text = Bidirectional(CuDNNLSTM(128, 
return_sequences=True))(emb_text) 465 | x_text = Bidirectional(LSTM(128, return_sequences=True))(emb_text) 466 | x_text = Dropout(0.5)(x_text) 467 | 468 | inputs_arg_a = Input((TRAIN_LENGTH, )) 469 | emb_arga = Embedding(input_dim=TOTAL_ARGA,output_dim=W2V_LEN_ARGA)(inputs_arg_a) 470 | emb_arga = SpatialDropout1D(0.3)(emb_arga) 471 | # x_arg_a = Bidirectional(CuDNNLSTM(128, return_sequences=True))(emb_arga) 472 | x_arg_a = Bidirectional(LSTM(128, return_sequences=True))(emb_arga) 473 | x_arg_a = Dropout(0.5)(x_arg_a) 474 | 475 | inputs_arg_b = Input((TRAIN_LENGTH, )) 476 | emb_argb = Embedding(input_dim=TOTAL_ARGB,output_dim=W2V_LEN_ARGB)(inputs_arg_b) 477 | # x_arg_b = Bidirectional(CuDNNLSTM(128, return_sequences=True))(emb_argb) 478 | x_arg_b = Bidirectional(LSTM(128, return_sequences=True))(emb_argb) 479 | 480 | inputs_all = Concatenate()([x_text, x_arg_a, x_arg_b]) 481 | 482 | # x0 = Bidirectional(CuDNNLSTM(128, return_sequences=True))(inputs_all) 483 | x0 = Bidirectional(LSTM(128, return_sequences=True))(inputs_all) 484 | # x1 = Bidirectional(CuDNNGRU(128, return_sequences=True))(inputs_all) 485 | x1 = Bidirectional(GRU(128, return_sequences=True))(inputs_all) 486 | 487 | x = Concatenate()([x0, x1]) 488 | 489 | # x = Bidirectional(CuDNNLSTM(128, return_sequences=True))(x) 490 | x = Bidirectional(LSTM(128, return_sequences=True))(x) 491 | x = Dropout(0.5)(x) 492 | x = TimeDistributed(Dense(64, activation='relu'))(x) 493 | x = Dropout(0.5)(x) 494 | x = TimeDistributed(Dense(1))(x) 495 | x = Activation('sigmoid')(x) 496 | 497 | model = Model(inputs=[inputs_text, inputs_arg_a, inputs_arg_b], outputs=x) 498 | model.compile(loss=losses.binary_crossentropy, 499 | optimizer=optimizers.RMSprop(), 500 | metrics=[metrics.binary_accuracy, metrics.binary_crossentropy]) 501 | 502 | return model 503 | 504 | 505 | def _get_test_meta_data(b_list, idx_pad_left_list, idx_left_list, idx_right_list, 506 | x_text_list, x_arga_list, x_argaa_list, x_argb_list, y_list, text_encoded, list_a, df_b, x_idx, len_text, list_aa): 507 | 508 | 509 | for idx_b, row_b in enumerate(df_b.itertuples()): 510 | 511 | _, id_b, label_b, idx_b_0, idx_b_1, _ = row_b 512 | 513 | _idx_b_mid = int((idx_b_0 + idx_b_1)/2) 514 | 515 | for margin in MARGIN_LIST: 516 | 517 | idx_b_mid = _idx_b_mid + margin 518 | 519 | _left = idx_b_mid - TRAIN_LENGTH_HALF 520 | _right = idx_b_mid + TRAIN_LENGTH_HALF 521 | 522 | idx_left = max(0, _left) 523 | idx_right = min(len_text, _right) 524 | 525 | idx_pad_left = np.abs(min(0, _left)) 526 | idx_pad_right = np.abs(min(0, len_text - _right)) 527 | 528 | # x arg b 529 | _b = np.zeros((TRAIN_LENGTH, 1), dtype=np.uint8) 530 | b_start = idx_pad_left + (idx_b_0 - idx_left) 531 | b_end = idx_pad_left + (idx_b_1 - idx_left) 532 | 533 | if (b_start < 0) or (b_end < 0) or (b_start > TRAIN_LENGTH) or (b_end > TRAIN_LENGTH): 534 | continue 535 | 536 | assert b_start >= 0 537 | assert b_end >= 0 538 | assert b_start <= TRAIN_LENGTH 539 | assert b_end <= TRAIN_LENGTH 540 | 541 | 542 | _b[b_start:b_end] = x_idx+1 543 | 544 | x_argb_list.append(_b) 545 | 546 | 547 | # b_list 548 | b_list.append(id_b) 549 | 550 | # idx_pad_left 551 | idx_pad_left_list.append(idx_pad_left) 552 | 553 | # idx_left_list 554 | idx_left_list.append(idx_left) 555 | 556 | # idx_right_list 557 | idx_right_list.append(idx_right) 558 | 559 | # x arg a 560 | _a = np.pad(list_a[idx_left:idx_right], ((idx_pad_left, idx_pad_right), (0, 0)), mode='constant', constant_values=0) 561 | 562 | x_arga_list.append(_a) 563 | 564 | # x arg aa 565 | _aa = 
np.pad(list_aa[idx_left:idx_right], ((idx_pad_left, idx_pad_right), (0, 0)), mode='constant', constant_values=0) 566 | 567 | x_argaa_list.append(_aa) 568 | 569 | # x text 570 | _x = np.pad(text_encoded[idx_left:idx_right], (idx_pad_left, idx_pad_right), mode='constant', constant_values=0).copy() 571 | 572 | x_text_list.append(_x) 573 | 574 | return b_list, idx_pad_left_list, idx_left_list, idx_right_list, x_text_list, x_arga_list, x_argaa_list, x_argb_list, y_list 575 | 576 | def get_test_meta_data(x_text_list, x_arga_list, x_argaa_list, x_argb_list, y_list, text_encoded, anno_df, list_a, list_0a, list_1a): 577 | 578 | len_text = len(text_encoded) 579 | 580 | df_0b = anno_df[anno_df['label'] == ARG_0b].sort_values('idx_0').reset_index(drop=True) 581 | df_1b = anno_df[anno_df['label'] == ARG_1b].sort_values('idx_0').reset_index(drop=True) 582 | 583 | b_list = [] 584 | idx_pad_left_list = [] 585 | idx_left_list = [] 586 | idx_right_list = [] 587 | 588 | # list_a = list_0a 589 | list_aa = list_0a 590 | df_b = df_0b 591 | x_idx = 0 592 | 593 | b_list, idx_pad_left_list, idx_left_list, idx_right_list, x_text_list, x_arga_list, x_argaa_list, x_argb_list, y_list = \ 594 | _get_test_meta_data(b_list, idx_pad_left_list, idx_left_list, idx_right_list, x_text_list, x_arga_list, x_argaa_list, x_argb_list, y_list, text_encoded, list_a, df_b, x_idx, len_text, list_aa) 595 | 596 | # list_a = list_1a 597 | list_aa = list_1a 598 | df_b = df_1b 599 | x_idx = 1 600 | 601 | b_list, idx_pad_left_list, idx_left_list, idx_right_list, x_text_list, x_arga_list, x_argaa_list, x_argb_list, y_list = \ 602 | _get_test_meta_data(b_list, idx_pad_left_list, idx_left_list, idx_right_list, x_text_list, x_arga_list, x_argaa_list, x_argb_list, y_list, text_encoded, list_a, df_b, x_idx, len_text, list_aa) 603 | 604 | return (b_list, idx_pad_left_list, idx_left_list, idx_right_list, x_text_list, x_arga_list, x_argaa_list, x_argb_list, y_list) 605 | 606 | 607 | if __name__ == "__main__": 608 | 609 | # Part1 train 610 | train_ann_file_list = glob('./Demo/DataSets/ruijin_round2_train/ruijin_round2_train/*.ann') 611 | train_txt_file_list = glob('./Demo/DataSets/ruijin_round2_train/ruijin_round2_train/*.txt') 612 | 613 | test_txt_file_list = glob('./Demo/DataSets/ruijin_round2_test_b/ruijin_round2_test_b/*.txt') 614 | test_ann_file_list = glob('./Demo/DataSets/ruijin_round2_test_b/ruijin_round2_test_b/*.ann') 615 | 616 | test_ann_file_list_dict = {str(path.split('/')[-1].split('.')[0]):path for path in test_ann_file_list} 617 | 618 | train_text_dict = {} 619 | for path in train_txt_file_list: 620 | with open(path) as f: 621 | train_text_dict[str(path.split('/')[-1].split('.')[0])] = list(text_filter(str(f.read()))) 622 | 623 | 624 | train_anno_dict = {str(path.split('/')[-1].split('.')[0]):get_df(path, filted=False)[0] for path in train_ann_file_list} 625 | 626 | train_arg_dict = {str(path.split('/')[-1].split('.')[0]):get_df(path, arg=True)[0] for path in train_ann_file_list} 627 | 628 | train_arg_dict = add_anno_to_arg(train_anno_dict, train_arg_dict) 629 | 630 | 631 | train_text_set = set() 632 | for k, v in train_text_dict.items(): 633 | train_text_set = train_text_set | set(v) 634 | 635 | 636 | train_encode_dict = dict(zip(sorted(list(train_text_set)), range(len(train_text_set)))) 637 | 638 | TOTAL_SIZE = len(train_text_set)+1 639 | 640 | train_text_encoded_dict = {} 641 | for k, v in train_text_dict.items(): 642 | train_text_encoded_dict[k] = list(map(lambda x: train_encode_dict.get(x) 643 | if x in train_encode_dict else 0, 
v))
644 | 
645 | 
646 |     train_encode_dict_idx = {str(v):k for k, v in train_encode_dict.items()}
647 | 
648 |     train_0a_lables_dict = get_aa_labels_dict(train_anno_dict, train_text_dict, LABELS_0)
649 |     train_1a_lables_dict = get_aa_labels_dict(train_anno_dict, train_text_dict, LABELS_1)
650 | 
651 |     train_a_lables_dict = get_a_labels_dict(train_anno_dict, train_text_dict, LABELS_DICT)
652 | 
653 | 
654 |     data_x_text_list = []
655 |     data_x_arga_list = []
656 |     data_x_argaa_list = []
657 |     data_x_argb_list = []
658 |     data_y_list = []
659 | 
660 | 
661 |     for k in train_text_encoded_dict:
662 | 
663 |         t0 = train_text_encoded_dict.get(k)
664 |         anno_0 = train_anno_dict.get(k)
665 |         arg_0 = train_arg_dict.get(k)
666 |         list_a = train_a_lables_dict.get(k)
667 |         list_0a = train_0a_lables_dict.get(k)
668 |         list_1a = train_1a_lables_dict.get(k)
669 | 
670 |         (data_x_text_list, data_x_arga_list, data_x_argaa_list, data_x_argb_list, data_y_list) = \
671 |             get_meta_data(data_x_text_list, data_x_arga_list, data_x_argaa_list, data_x_argb_list, data_y_list,
672 |                           text_encoded=t0, anno_df=anno_0, list_a=list_a, arg_df=arg_0, list_0a=list_0a, list_1a=list_1a)
673 | 
674 |     # pre-train FastText word vectors on the training texts
675 |     model_ft = gensim.models.FastText([jieba.lcut(''.join(v)) for v in train_text_dict.values()], size=W2V_LEN_TEXT, word_ngrams=9, window=24, min_count=1, workers=8, seed=522)
676 | 
677 | 
678 |     emb_weights = np.zeros((TOTAL_SIZE+1, W2V_LEN_TEXT))
679 | 
680 |     count_emb_0 = 0
681 | 
682 |     for k, v in train_encode_dict.items():
683 | 
684 |         _k = str(k)
685 |         if _k in model_ft.wv:
686 |             emb_weights[v] = model_ft.wv[_k]
687 |         else:
688 |             count_emb_0 += 1
689 | 
690 |     # quick-debug option: the commented block below keeps only the first 128 windows
691 |     # data_x_text_list = np.squeeze(data_x_text_list).astype(np.uint16)[:128]
692 |     # data_x_arga_list = np.squeeze(data_x_arga_list).astype(np.uint8)[:128]
693 |     # data_x_argaa_list = np.asarray(data_x_argaa_list, dtype=np.uint8)[:128]
694 |     # data_x_argb_list = np.squeeze(data_x_argb_list).astype(np.uint8)[:128]
695 |     # data_y_list = np.asarray(data_y_list, dtype=np.uint8)[:128]
696 | 
697 |     data_x_text_list = np.squeeze(data_x_text_list).astype(np.uint16)
698 |     data_x_arga_list = np.squeeze(data_x_arga_list).astype(np.uint8)
699 |     data_x_argaa_list = np.asarray(data_x_argaa_list, dtype=np.uint8)
700 |     data_x_argb_list = np.squeeze(data_x_argb_list).astype(np.uint8)
701 |     data_y_list = np.asarray(data_y_list, dtype=np.uint8)
702 | 
703 |     # hold out one fold of the KFOLD-way split as the validation set
704 |     kf = KFold(n_splits=KFOLD, shuffle=True, random_state=522).split(data_x_text_list)
705 |     train_idx, valid_idx = next(kf)
706 | 
707 | 
708 |     train_x_text_list = data_x_text_list[train_idx]
709 |     train_x_arga_list = data_x_arga_list[train_idx]
710 |     train_x_argaa_list = data_x_argaa_list[train_idx]
711 |     train_x_argb_list = data_x_argb_list[train_idx]
712 |     train_y_list = data_y_list[train_idx]
713 | 
714 |     valid_x_text_list = data_x_text_list[valid_idx]
715 |     valid_x_arga_list = data_x_arga_list[valid_idx]
716 |     valid_x_argaa_list = data_x_argaa_list[valid_idx]
717 |     valid_x_argb_list = data_x_argb_list[valid_idx]
718 |     valid_y_list = data_y_list[valid_idx]
719 | 
720 |     MODEL_BASE = 's2_01'
721 |     MODEL_NAME = MODEL_BASE + '_final_v4_0'
722 | 
723 | 
724 |     model = get_model(emb_weights)
725 | 
726 | 
727 |     from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
728 | 
729 |     checkpoint = ModelCheckpoint('./Demo/Models/' + MODEL_BASE + '/' + MODEL_NAME + '_{epoch:02d}.hd5', monitor='val_loss', verbose=1,
730 |                                  save_best_only=False, save_weights_only=True, mode='auto', period=1)
731 |     # lr_reduce = ReduceLROnPlateau(monitor='val_loss', 
factor=0.6, min_delta=1e-5, patience=2, verbose=1, min_lr = 0.00001) 732 | 733 | 734 | hist_0 = model.fit(x=[train_x_text_list, train_x_arga_list, train_x_argb_list], y=train_y_list, 735 | # batch_size=512, 736 | batch_size=64, 737 | epochs=EPOCHS, 738 | shuffle=True, 739 | validation_data=([valid_x_text_list, valid_x_arga_list, valid_x_argb_list], valid_y_list), 740 | # callbacks=[checkpoint, lr_reduce], 741 | callbacks=[checkpoint], 742 | verbose=2) 743 | 744 | with open('./Demo/Models/' + MODEL_BASE + '/' + MODEL_NAME + '_hist.pkl', 'wb') as f: 745 | pickle.dump(hist_0.history, f) 746 | 747 | 748 | MODEL_BASE = 's2_01' 749 | MODEL_NAME = MODEL_BASE + '_final_v4_1' 750 | 751 | 752 | model = get_model(emb_weights) 753 | 754 | 755 | from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau 756 | 757 | checkpoint = ModelCheckpoint('./Demo/Models/' + MODEL_BASE + '/' + MODEL_NAME + '_{epoch:02d}.hd5', monitor='val_loss', verbose=1, 758 | save_best_only=False, save_weights_only=True, mode='auto', period=1) 759 | # lr_reduce = ReduceLROnPlateau(monitor='val_loss', factor=0.6, min_delta=1e-5, patience=2, verbose=1, min_lr = 0.00001) 760 | 761 | 762 | 763 | hist_1 = model.fit(x=[train_x_text_list[:, ::-1], train_x_arga_list[:, ::-1], train_x_argb_list[:, ::-1]], y=train_y_list[:, ::-1, :], 764 | # batch_size=512, 765 | batch_size=64, 766 | epochs=EPOCHS, 767 | shuffle=True, 768 | validation_data=([valid_x_text_list[:, ::-1], valid_x_arga_list[:, ::-1], valid_x_argb_list[:, ::-1]], valid_y_list[:, ::-1, :]), 769 | # callbacks=[checkpoint, lr_reduce], 770 | callbacks=[checkpoint], 771 | verbose=2) 772 | 773 | with open('./Demo/Models/' + MODEL_BASE + '/' + MODEL_NAME + '_hist.pkl', 'wb') as f: 774 | pickle.dump(hist_1.history, f) 775 | 776 | 777 | # Part2 prediction 778 | test_text_dict = {} 779 | for path in test_txt_file_list: 780 | with open(path) as f: 781 | test_text_dict[str(path.split('/')[-1].split('.')[0])] = list(text_filter(str(f.read()))) 782 | 783 | 784 | 785 | test_text_encoded_dict = {} 786 | for k, v in test_text_dict.items(): 787 | test_text_encoded_dict[k] = list(map(lambda x: train_encode_dict.get(x) 788 | if x in train_encode_dict else 0, v)) 789 | 790 | 791 | test_anno_dict = {str(path.split('/')[-1].split('.')[0]):get_df(path, filted=False)[0] for path in test_ann_file_list} 792 | 793 | 794 | test_0a_lables_dict = get_aa_labels_dict(test_anno_dict, test_text_dict, LABELS_0) 795 | test_1a_lables_dict = get_aa_labels_dict(test_anno_dict, test_text_dict, LABELS_1) 796 | 797 | test_a_lables_dict = get_a_labels_dict(test_anno_dict, test_text_dict, LABELS_DICT) 798 | 799 | 800 | model = get_model(emb_weights) 801 | 802 | # PRE_THRES = 0.6 803 | PRE_THRES = 0.1 804 | 805 | pre_final_dict = {} 806 | 807 | for file_k in test_text_encoded_dict: 808 | 809 | print('Start ', file_k, '---------------------------------') 810 | 811 | # 1. 
get model data inputs 812 | print('Part 1...') 813 | test_x_text_list = [] 814 | test_x_arga_list = [] 815 | test_x_argaa_list = [] 816 | test_x_argb_list = [] 817 | test_y_list = [] 818 | 819 | test_b_list = [] 820 | test_idx_pad_left_list = [] 821 | test_idx_left_list = [] 822 | test_idx_right_list = [] 823 | 824 | t0 = test_text_encoded_dict.get(file_k) 825 | anno_0 = test_anno_dict.get(file_k) 826 | # arg_0 = train_arg_dict.get(file_k) 827 | arg_0 = None 828 | list_a = test_a_lables_dict.get(file_k) 829 | list_0a = test_0a_lables_dict.get(file_k) 830 | list_1a = test_1a_lables_dict.get(file_k) 831 | 832 | test_b_list, test_idx_pad_left_list, test_idx_left_list, test_idx_right_list, test_x_text_list, test_x_arga_list, test_x_argaa_list, test_x_argb_list, test_y_list = \ 833 | get_test_meta_data(test_x_text_list, test_x_arga_list, test_x_argaa_list, test_x_argb_list, test_y_list, 834 | text_encoded=t0, anno_df=anno_0, list_a=list_a, list_0a=list_0a, list_1a=list_1a) 835 | 836 | print(np.shape(test_b_list), np.shape(test_idx_left_list), np.shape(test_x_text_list), np.shape(test_x_arga_list), np.shape(test_x_argaa_list), 837 | np.shape(test_x_argb_list), np.shape(test_y_list)) 838 | 839 | # 2. get predictions 840 | print('Part 2...') 841 | test_x_text_list = np.squeeze(test_x_text_list).astype(int) 842 | test_x_arga_list = np.squeeze(test_x_arga_list).astype(int) 843 | test_x_argaa_list = np.asarray(test_x_argaa_list, dtype=int) 844 | test_x_argb_list = np.squeeze(test_x_argb_list).astype(int) 845 | 846 | pre_list = [] 847 | 848 | MODEL_BASE = 's2_01' 849 | MODEL_NAME = MODEL_BASE + '_final_v4_0' 850 | # for i in [17, 19, 21, 23, 25]: 851 | for i in [17-1, 19-1, 21-1, 23-1, 25-1]: 852 | w_p = './Demo/Models/' + MODEL_BASE + '/' + MODEL_NAME + '_%02d.hd5' % (i) 853 | model.load_weights(w_p) 854 | # pre_list.append(model.predict([test_x_text_list, test_x_arga_list, test_x_argb_list])) 855 | pre_list.append(model.predict([test_x_text_list, test_x_arga_list, test_x_argb_list], batch_size=32)) 856 | 857 | MODEL_NAME = MODEL_BASE + '_final_v4_1' 858 | # for i in [17, 19, 21, 23, 25]: 859 | for i in [17-1, 19-1, 21-1, 23-1, 25-1]: 860 | w_p = './Demo/Models/' + MODEL_BASE + '/' + MODEL_NAME + '_%02d.hd5' % (i) 861 | model.load_weights(w_p) 862 | # pre_list.append(model.predict([test_x_text_list[:, ::-1], test_x_arga_list[:, ::-1], test_x_argb_list[:, ::-1]])[:, ::-1, :]) 863 | pre_list.append(model.predict([test_x_text_list[:, ::-1], test_x_arga_list[:, ::-1], test_x_argb_list[:, ::-1]], batch_size=32)[:, ::-1, :]) 864 | 865 | pre = np.mean(pre_list, axis=0) 866 | print(pre.shape) 867 | 868 | # 3. 
get pre dict for each output 869 | print('Part 3...') 870 | pre_test_dict = {} 871 | for idx, p in enumerate(pre): 872 | 873 | # if not len(np.argwhere(p>0.5)) > 0: continue 874 | 875 | _test_b = test_b_list[idx] 876 | _test_idx_pad_left = test_idx_pad_left_list[idx] 877 | _test_idx_left = test_idx_left_list[idx] 878 | _test_idx_right = test_idx_right_list[idx] 879 | 880 | _test_b_label = anno_0[anno_0['id'] == _test_b].iloc[0]['label'] 881 | 882 | if _test_b_label == ARG_0b: 883 | # _anno_0 = anno_0[(anno_0['label'].isin(LABELS_0_DICT.keys())) & (anno_0['idx_0'] >= _test_idx_left) & (anno_0['idx_1'] <= _test_idx_right)] 884 | _anno_0 = anno_0[(anno_0['label'].isin(ARG_0a_SET)) & (anno_0['idx_0'] >= _test_idx_left) & (anno_0['idx_1'] <= _test_idx_right)] 885 | 886 | if _test_b_label == ARG_1b: 887 | # _anno_0 = anno_0[(anno_0['label'].isin(LABELS_1_DICT.keys())) & (anno_0['idx_0'] >= _test_idx_left) & (anno_0['idx_1'] <= _test_idx_right)] 888 | _anno_0 = anno_0[(anno_0['label'].isin(ARG_1a_SET)) & (anno_0['idx_0'] >= _test_idx_left) & (anno_0['idx_1'] <= _test_idx_right)] 889 | 890 | if len(_anno_0) > 0: 891 | 892 | for _idx_a, _row_a in enumerate(_anno_0.itertuples()): 893 | _, _id_a, _label_a, _idx_a_0, _idx_a_1, _ = _row_a 894 | 895 | _k = _label_a + _test_b_label + _id_a + _test_b 896 | 897 | if _k not in pre_test_dict: 898 | pre_test_dict[_k] = {} 899 | pre_test_dict[_k]['arg_a'] = _label_a 900 | pre_test_dict[_k]['arg_b'] = _test_b_label 901 | pre_test_dict[_k]['arg_1'] = _id_a 902 | pre_test_dict[_k]['arg_2'] = _test_b 903 | pre_test_dict[_k]['values'] = 0 904 | pre_test_dict[_k]['counts'] = 0 905 | 906 | _v_l = p[_test_idx_pad_left + _idx_a_0 - _test_idx_left: _test_idx_pad_left + _idx_a_1 - _test_idx_left] 907 | assert len(_v_l) > 0 908 | # if len(_v_l) == 0: 909 | # print(_row_a) 910 | # break 911 | _value = np.mean(_v_l) 912 | pre_test_dict[_k]['values'] = pre_test_dict[_k]['values'] + _value 913 | pre_test_dict[_k]['counts'] = pre_test_dict[_k]['counts'] + 1 914 | 915 | print(len(pre_test_dict)) 916 | 917 | # 4. 
filter pre with threshold
918 |         print('Part 4...')
919 |         len_max = 0
920 |         pre_test_df = []
921 |         r_id_idx = 1
922 |         for k, v in pre_test_dict.items():
923 |             _mean = v['values'] / v['counts']
924 | 
925 |             len_max = len_max if len_max > v['counts'] else v['counts']
926 |             assert len_max <= len(MARGIN_LIST)
927 | 
928 |             if _mean > PRE_THRES:
929 |                 _row = ['R'+str(r_id_idx), v['arg_a']+'_'+v['arg_b']+' '+'Arg1:'+v['arg_1']+' '+'Arg2:'+v['arg_2']]
930 |                 pre_test_df.append(_row)
931 |                 r_id_idx += 1
932 | 
933 |         print(len(pre_test_df))
934 | 
935 |         pre_final_dict[file_k] = pre_test_df
936 | 
937 |         print('End ', file_k, '---------------------------------')
938 | 
939 | 
940 |     MODEL_NAME = 's2_01_final_v4_b'
941 | 
942 |     # append the predicted relation (R) rows to the original annotations and write the final .ann files
943 |     for file_k in pre_final_dict:
944 |         _ann_result = pd.read_csv(test_ann_file_list_dict.get(file_k), sep='\t', header=None, names=['id', 'p1', 'p2'])
945 | 
946 |         _ann_pre_final = np.asarray(pre_final_dict.get(file_k))
947 |         _ann_pre_final_df = pd.DataFrame(columns=['id', 'p1', 'p2'])
948 |         _ann_pre_final_df['id'] = _ann_pre_final[:, 0]
949 |         _ann_pre_final_df['p1'] = _ann_pre_final[:, 1]
950 |         _ann_pre_final_df['p1'] = _ann_pre_final_df['p1'].str.replace('SideEff_Drug', 'SideEff-Drug')
951 | 
952 |         _final_df = _ann_result.append(_ann_pre_final_df, ignore_index=True)
953 | 
954 |         _final_df.to_csv('./Demo/result/' + MODEL_NAME + '_raw/' + str(file_k) + '.ann', index=False, header=None, sep='\t')
955 | 
--------------------------------------------------------------------------------