├── CNN.ipynb └── README.md /CNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Using TensorFlow backend.\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import os, gc  # os is needed for the os.path.exists checks below\n", 18 | "import json\n", 19 | "import pandas as pd\n", 20 | "import numpy as np\n", 21 | "import gensim\n", 22 | "from gensim.models import Word2Vec\n", 23 | "from gensim.models.word2vec import LineSentence\n", 24 | "import tensorflow as tf\n", 25 | "import keras\n", 26 | "from keras.layers import *\n", 27 | "from keras.models import *\n", 28 | "from keras.optimizers import *\n", 29 | "from keras.callbacks import *\n", 30 | "from keras.preprocessing import text, sequence\n", 31 | "from keras.utils import to_categorical\n", 32 | "from sklearn.preprocessing import MultiLabelBinarizer\n", 33 | "from sklearn.model_selection import KFold\n", 34 | "import dask.dataframe as dd\n", 35 | "import warnings\n", 36 | "warnings.filterwarnings(\"ignore\")" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "train=pd.read_table(\"../smalldata/train_word.txt\",sep=\"\\n\",names=[\"fact\"])\n", 46 | "val=pd.read_table(\"../smalldata/val_word.txt\",sep=\"\\n\",names=[\"fact\"])\n", 47 | "test=pd.read_table(\"../smalldata/test_word.txt\",sep=\"\\n\",names=[\"fact\"])\n", 48 | "\n", 49 | "\n", 50 | "train_label=pd.read_table(\"../smalldata/train_label\",header=None,sep=\"\\n\")\n", 51 | "val_label=pd.read_table(\"../smalldata/val_label\",header=None,sep=\"\\n\")\n", 52 | "test_label=pd.read_table(\"../smalldata/test_label\",header=None,sep=\"\\n\")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "(154592, 1) (17131, 1) (32508, 1) (154592, 1) (17131, 1) (32508, 1)\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "print(train.shape,val.shape,test.shape,train_label.shape,val_label.shape,test_label.shape)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 4, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "step 1:transform the text into idx\n", 82 | "step 1 run finish\n", 83 | "step 2: train the word2vec\n", 84 | "step 2 run finish\n", 85 | "step 3: transform into embedding matrix\n", 86 | "step 3 run finish\n", 87 | "(579810, 300) (154592, 400) 0.1985374519239061\n", 88 | "word to idx finish\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "def w2v_pad(train,val, test, maxlen_,victor_size):\n", 94 | " max_features = 50000\n", 95 | " count = 0\n", 96 | " # Step 1: convert the words (or characters) into indices\n", 97 | " print(\"step 1:transform the text into idx\")\n", 98 | " \n", 99 | " tokenizer = text.Tokenizer(num_words=max_features, lower=True)\n", 100 | " tokenizer.fit_on_texts(pd.concat([train,val,test])['fact'].tolist())\n", 101 | " train_ = sequence.pad_sequences(tokenizer.texts_to_sequences(train['fact'].tolist()), maxlen=maxlen_)\n", 102 | " val_ = sequence.pad_sequences(tokenizer.texts_to_sequences(val['fact'].tolist()), maxlen=maxlen_)\n", 103 | " test_ = sequence.pad_sequences(tokenizer.texts_to_sequences(test['fact'].tolist()), maxlen=maxlen_) \n", 104 | " word_index = tokenizer.word_index\n", 105 | "\n", 106 | " print(\"step 1 
run finish\")\n", 107 | " \n", 108 | " # Step 2: train the word2vec model\n", 109 | " print(\"step 2: train the word2vec\")\n", 110 | " file_name = '../cache/' + 'Word2Vec_' + str(victor_size) +'.model'\n", 111 | " if not os.path.exists(file_name):\n", 112 | " print(\"train word2vec\")\n", 113 | " model = Word2Vec([line.split(\" \") for line in (pd.concat([train,val,test])['fact'].tolist())],\n", 114 | " size=victor_size, window=5, iter=15, workers=11, seed=2018, min_count=5)\n", 115 | " model.save(file_name)\n", 116 | " else:\n", 117 | " model = Word2Vec.load(file_name)\n", 118 | " print(\"step 2 run finish\")\n", 119 | " \n", 120 | " \n", 121 | " # Step 3: build the embedding matrix from the trained vectors\n", 122 | " print(\"step 3: transform into embedding matrix\")\n", 123 | "\n", 124 | " embedding_matrix = np.zeros((len(word_index) + 1, victor_size))\n", 125 | " for word, i in word_index.items():\n", 126 | " embedding_vector = model[word] if word in model else None\n", 127 | " if embedding_vector is not None:\n", 128 | " count += 1\n", 129 | " embedding_matrix[i] = embedding_vector\n", 130 | " else:\n", 131 | " unk_vec = np.random.random(victor_size) * 0.5\n", 132 | " unk_vec = unk_vec - unk_vec.mean()\n", 133 | " embedding_matrix[i] = unk_vec\n", 134 | "\n", 135 | " print(\"step 3 run finish\")\n", 136 | " \n", 137 | " print(embedding_matrix.shape, train_.shape, count * 1.0 / embedding_matrix.shape[0]) \n", 138 | " return train_,val_,test_, word_index, embedding_matrix\n", 139 | "\n", 140 | "word_seq_len=400\n", 141 | "victor_size=300\n", 142 | "\n", 143 | "train_,val_,test_, word2idx, word_embedding = w2v_pad(train,val,test, word_seq_len,victor_size)\n", 144 | "print(\"word to idx finish\")\n", 145 | "\n" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 5, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "name": "stdout", 162 | "output_type": "stream", 163 | "text": [ 164 | "(154592, 202) (32508, 202)\n" 165 | ] 166 | } 167 | ], 168 | "source": [ 169 | "train_label[0]=train_label[0].map(lambda x:x.split(\" \"))\n", 170 | "val_label[0]=val_label[0].map(lambda x:x.split(\" \"))\n", 171 | "test_label[0]=test_label[0].map(lambda x:x.split(\" \"))\n", 172 | "\n", 173 | "mlb = MultiLabelBinarizer()\n", 174 | "mlb.fit(train_label[0].tolist()+test_label[0].tolist()+val_label[0].tolist())\n", 175 | "\n", 176 | "train_label=mlb.transform(train_label[0].tolist())\n", 177 | "val_label=mlb.transform(val_label[0].tolist())\n", 178 | "test_label=mlb.transform(test_label[0].tolist())\n", 179 | "label_name = mlb.classes_\n", 180 | "\n", 181 | "print(train_label.shape,test_label.shape)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "Model section" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 8, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "def TextCNN(sent_length,embeddings_weight):\n", 205 | " content = Input(shape=(sent_length,), dtype='int32')\n", 206 | " embedding = Embedding(\n", 207 | " name=\"word_embedding\",\n", 208 | " input_dim=embeddings_weight.shape[0],\n", 209 | " weights=[embeddings_weight],\n", 210 | " output_dim=embeddings_weight.shape[1],\n", 211 | " trainable=False)\n", 212 | " x=embedding(content)\n", 213 | " conv1 = Conv1D(filters=64, 
kernel_size=1, padding='same')(x)\n", 214 | " conv1 = MaxPool1D(pool_size=32)(conv1)\n", 215 | " \n", 216 | " \n", 217 | " conv2 = Conv1D(filters=64, kernel_size=2, padding='same')(x)\n", 218 | " conv2 = MaxPool1D(pool_size=32)(conv2)\n", 219 | " \n", 220 | " conv3 = Conv1D(filters=64, kernel_size=3, padding='same')(x)\n", 221 | " conv3 = MaxPool1D(pool_size=32)(conv3)\n", 222 | " \n", 223 | " conv4 = Conv1D(filters=64, kernel_size=4, padding='same')(x)\n", 224 | " conv4 = MaxPool1D(pool_size=32)(conv4)\n", 225 | " \n", 226 | " cnn = concatenate([conv1, conv2, conv3, conv4], axis=-1)\n", 227 | " fc = Flatten()(cnn)\n", 228 | "\n", 229 | " #fc layer\n", 230 | " fc=Dense(512)(fc)\n", 231 | "# fc=BatchNormalization()(fc)\n", 232 | " fc=Activation(activation=\"relu\")(fc)\n", 233 | "# fc = Dropout(0.2)(fc)\n", 234 | " \n", 235 | " fc=Dense(256)(fc)\n", 236 | " #fc=BatchNormalization()(fc)\n", 237 | " fc = Activation(activation=\"relu\")(fc)\n", 238 | " output = Dense(202, activation=\"softmax\")(fc)\n", 239 | " \n", 240 | " \n", 241 | " model = Model(inputs=content, outputs=output)\n", 242 | " model.compile(loss= \"categorical_crossentropy\",\n", 243 | " optimizer='adam',\n", 244 | " metrics=['accuracy'])\n", 245 | " model.summary()\n", 246 | " return model" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 12, 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "name": "stdout", 256 | "output_type": "stream", 257 | "text": [ 258 | "__________________________________________________________________________________________________\n", 259 | "Layer (type) Output Shape Param # Connected to \n", 260 | "==================================================================================================\n", 261 | "input_1 (InputLayer) (None, 400) 0 \n", 262 | "__________________________________________________________________________________________________\n", 263 | "word_embedding (Embedding) (None, 400, 300) 173943000 input_1[0][0] \n", 264 | "__________________________________________________________________________________________________\n", 265 | "conv1d_1 (Conv1D) (None, 400, 64) 19264 word_embedding[0][0] \n", 266 | "__________________________________________________________________________________________________\n", 267 | "conv1d_2 (Conv1D) (None, 400, 64) 38464 word_embedding[0][0] \n", 268 | "__________________________________________________________________________________________________\n", 269 | "conv1d_3 (Conv1D) (None, 400, 64) 57664 word_embedding[0][0] \n", 270 | "__________________________________________________________________________________________________\n", 271 | "conv1d_4 (Conv1D) (None, 400, 64) 76864 word_embedding[0][0] \n", 272 | "__________________________________________________________________________________________________\n", 273 | "max_pooling1d_1 (MaxPooling1D) (None, 12, 64) 0 conv1d_1[0][0] \n", 274 | "__________________________________________________________________________________________________\n", 275 | "max_pooling1d_2 (MaxPooling1D) (None, 12, 64) 0 conv1d_2[0][0] \n", 276 | "__________________________________________________________________________________________________\n", 277 | "max_pooling1d_3 (MaxPooling1D) (None, 12, 64) 0 conv1d_3[0][0] \n", 278 | "__________________________________________________________________________________________________\n", 279 | "max_pooling1d_4 (MaxPooling1D) (None, 12, 64) 0 conv1d_4[0][0] \n", 280 | 
"__________________________________________________________________________________________________\n", 281 | "concatenate_1 (Concatenate) (None, 12, 256) 0 max_pooling1d_1[0][0] \n", 282 | " max_pooling1d_2[0][0] \n", 283 | " max_pooling1d_3[0][0] \n", 284 | " max_pooling1d_4[0][0] \n", 285 | "__________________________________________________________________________________________________\n", 286 | "flatten_1 (Flatten) (None, 3072) 0 concatenate_1[0][0] \n", 287 | "__________________________________________________________________________________________________\n", 288 | "dense_1 (Dense) (None, 512) 1573376 flatten_1[0][0] \n", 289 | "__________________________________________________________________________________________________\n", 290 | "activation_1 (Activation) (None, 512) 0 dense_1[0][0] \n", 291 | "__________________________________________________________________________________________________\n", 292 | "dense_2 (Dense) (None, 256) 131328 activation_1[0][0] \n", 293 | "__________________________________________________________________________________________________\n", 294 | "activation_2 (Activation) (None, 256) 0 dense_2[0][0] \n", 295 | "__________________________________________________________________________________________________\n", 296 | "dense_3 (Dense) (None, 202) 51914 activation_2[0][0] \n", 297 | "==================================================================================================\n", 298 | "Total params: 175,891,874\n", 299 | "Trainable params: 1,948,874\n", 300 | "Non-trainable params: 173,943,000\n", 301 | "__________________________________________________________________________________________________\n", 302 | "Train on 154592 samples, validate on 17131 samples\n", 303 | "Epoch 1/20\n", 304 | " - 32s - loss: 2.5675 - acc: 0.5186 - val_loss: 1.1968 - val_acc: 0.6976\n", 305 | "Epoch 2/20\n", 306 | " - 31s - loss: 1.5582 - acc: 0.6940 - val_loss: 1.0510 - val_acc: 0.7337\n", 307 | "Epoch 3/20\n", 308 | " - 31s - loss: 1.3956 - acc: 0.7251 - val_loss: 0.9766 - val_acc: 0.7561\n", 309 | "Epoch 4/20\n", 310 | " - 31s - loss: 1.3000 - acc: 0.7421 - val_loss: 0.9668 - val_acc: 0.7554\n", 311 | "Epoch 5/20\n", 312 | " - 31s - loss: 1.2293 - acc: 0.7549 - val_loss: 0.9545 - val_acc: 0.7656\n", 313 | "Epoch 6/20\n", 314 | " - 31s - loss: 1.1659 - acc: 0.7651 - val_loss: 0.9811 - val_acc: 0.7601\n", 315 | "Epoch 7/20\n", 316 | " - 31s - loss: 1.1206 - acc: 0.7736 - val_loss: 0.9912 - val_acc: 0.7568\n", 317 | "Epoch 8/20\n", 318 | " - 31s - loss: 1.0721 - acc: 0.7834 - val_loss: 1.0040 - val_acc: 0.7584\n", 319 | "Epoch 9/20\n", 320 | " - 31s - loss: 1.0430 - acc: 0.7866 - val_loss: 1.0245 - val_acc: 0.7546\n", 321 | "Epoch 10/20\n", 322 | " - 31s - loss: 1.0085 - acc: 0.7922 - val_loss: 1.0287 - val_acc: 0.7644\n", 323 | "Epoch 11/20\n", 324 | " - 31s - loss: 0.9755 - acc: 0.7997 - val_loss: 1.0434 - val_acc: 0.7667\n", 325 | "Epoch 12/20\n", 326 | " - 31s - loss: 0.9536 - acc: 0.8028 - val_loss: 1.0779 - val_acc: 0.7625\n", 327 | "Epoch 13/20\n", 328 | " - 31s - loss: 0.9274 - acc: 0.8077 - val_loss: 1.0662 - val_acc: 0.7593\n", 329 | "Epoch 14/20\n", 330 | " - 31s - loss: 0.9044 - acc: 0.8119 - val_loss: 1.1029 - val_acc: 0.7705\n", 331 | "Epoch 15/20\n", 332 | " - 31s - loss: 0.9006 - acc: 0.8129 - val_loss: 1.1115 - val_acc: 0.7613\n", 333 | "Epoch 16/20\n", 334 | " - 31s - loss: 0.8792 - acc: 0.8163 - val_loss: 1.1514 - val_acc: 0.7576\n", 335 | "Epoch 17/20\n", 336 | " - 31s - loss: 0.8638 - acc: 0.8182 - val_loss: 1.1479 - val_acc: 
0.7636\n", 337 | "Epoch 18/20\n", 338 | " - 31s - loss: 0.8555 - acc: 0.8203 - val_loss: 1.1370 - val_acc: 0.7696\n", 339 | "Epoch 19/20\n", 340 | " - 31s - loss: 0.8434 - acc: 0.8220 - val_loss: 1.1845 - val_acc: 0.7582\n", 341 | "Epoch 20/20\n", 342 | " - 31s - loss: 0.8328 - acc: 0.8242 - val_loss: 1.1786 - val_acc: 0.7503\n" 343 | ] 344 | }, 345 | { 346 | "data": { 347 | "text/plain": [ 348 | "" 349 | ] 350 | }, 351 | "execution_count": 12, 352 | "metadata": {}, 353 | "output_type": "execute_result" 354 | } 355 | ], 356 | "source": [ 357 | "from sklearn.model_selection import train_test_split\n", 358 | "\n", 359 | "\n", 360 | "file_path = \"../model/TextCNN.hdf\"\n", 361 | "\n", 362 | "model = TextCNN(word_seq_len, word_embedding)\n", 363 | "\n", 364 | "early_stopping = EarlyStopping(monitor='val_acc', patience=6)\n", 365 | "plateau = ReduceLROnPlateau(monitor=\"val_acc\", verbose=1, mode='max', factor=0.5, patience=3)\n", 366 | "checkpoint = ModelCheckpoint(file_path, monitor='val_acc', verbose=1, save_best_only=True, mode='max',save_weights_only=False)\n", 367 | "#if not os.path.exists(file_path):\n", 368 | "model.fit(train_, train_label,\n", 369 | " epochs=20,\n", 370 | " batch_size=128,\n", 371 | " validation_data=(val_, val_label),\n", 372 | " #callbacks=[early_stopping, plateau, checkpoint],\n", 373 | " verbose=2)\n", 374 | "# else:\n", 375 | "# model.load_weights(file_path)\n" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 16, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "test_pred=model.predict(test_,batch_size=128)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 14, 390 | "metadata": {}, 391 | "outputs": [ 392 | { 393 | "name": "stdout", 394 | "output_type": "stream", 395 | "text": [ 396 | "sklearn Macro-F1-Score: 0.5554373854255992\n", 397 | "sklearn Macro-precision-Score: 0.6289731656984494\n", 398 | "sklearn Macro-recall-Score: 0.5299730109888571\n", 399 | "sklearn hamming_loss: 0.0026159592703456393\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "# evaluate using only the single highest-scoring label per sample\n", 405 | "from sklearn.preprocessing import label_binarize\n", 406 | "from sklearn.metrics import *\n", 407 | "test_pred_onelabel=label_binarize(np.argmax(test_pred,axis=1),classes=list(range(0,202)))\n", 408 | "print('sklearn Macro-F1-Score:', f1_score(test_label, test_pred_onelabel, average='macro'))\n", 409 | "print('sklearn Macro-precision-Score:', precision_score(test_label, test_pred_onelabel, average='macro'))\n", 410 | "print('sklearn Macro-recall-Score:', recall_score(test_label, test_pred_onelabel, average='macro'))\n", 411 | "print('sklearn hamming_loss:', hamming_loss(test_label, test_pred_onelabel))" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 15, 417 | "metadata": {}, 418 | "outputs": [ 419 | { 420 | "name": "stdout", 421 | "output_type": "stream", 422 | "text": [ 423 | "0.0\n", 424 | "sklearn Macro-F1-Score: 0.010407896254786374\n", 425 | "sklearn Macro-precision-Score: 0.005287350440470405\n", 426 | "sklearn Macro-recall-Score: 1.0\n", 427 | "sklearn hamming_loss: 0.9947126495595296\n", 428 | "0.05\n", 429 | "sklearn Macro-F1-Score: 0.5144346725699229\n", 430 | "sklearn Macro-precision-Score: 0.43670731440793775\n", 431 | "sklearn Macro-recall-Score: 0.6900723564290342\n", 432 | "sklearn hamming_loss: 0.005386640546668178\n", 433 | "0.1\n", 434 | "sklearn Macro-F1-Score: 0.5554054258904327\n", 435 | "sklearn Macro-precision-Score: 0.5098931264996042\n", 436 | "sklearn 
Macro-recall-Score: 0.6542859792928142\n", 437 | "sklearn hamming_loss: 0.0038505373239428042\n", 438 | "0.15000000000000002\n", 439 | "sklearn Macro-F1-Score: 0.571610897014695\n", 440 | "sklearn Macro-precision-Score: 0.5532853513895253\n", 441 | "sklearn Macro-recall-Score: 0.6283328974819034\n", 442 | "sklearn hamming_loss: 0.0032223598882590364\n", 443 | "0.2\n", 444 | "sklearn Macro-F1-Score: 0.5761391952984019\n", 445 | "sklearn Macro-precision-Score: 0.5809014556283769\n", 446 | "sklearn Macro-recall-Score: 0.6054050710455529\n", 447 | "sklearn hamming_loss: 0.002887331922561027\n", 448 | "0.25\n", 449 | "sklearn Macro-F1-Score: 0.5797097039336729\n", 450 | "sklearn Macro-precision-Score: 0.6074361284616707\n", 451 | "sklearn Macro-recall-Score: 0.5886154469071126\n", 452 | "sklearn hamming_loss: 0.00268418314699687\n", 453 | "0.30000000000000004\n", 454 | "sklearn Macro-F1-Score: 0.5771798432156464\n", 455 | "sklearn Macro-precision-Score: 0.6215385591987598\n", 456 | "sklearn Macro-recall-Score: 0.5702052420226028\n", 457 | "sklearn hamming_loss: 0.0025672279298804743\n", 458 | "0.35000000000000003\n", 459 | "sklearn Macro-F1-Score: 0.568575409192357\n", 460 | "sklearn Macro-precision-Score: 0.6317458650705322\n", 461 | "sklearn Macro-recall-Score: 0.5485184665540663\n", 462 | "sklearn hamming_loss: 0.002498699482351336\n", 463 | "0.4\n", 464 | "sklearn Macro-F1-Score: 0.5671974019786552\n", 465 | "sklearn Macro-precision-Score: 0.6515112128267819\n", 466 | "sklearn Macro-recall-Score: 0.5344357993022667\n", 467 | "sklearn hamming_loss: 0.0024718972450954954\n", 468 | "0.45\n", 469 | "sklearn Macro-F1-Score: 0.5577397673345604\n", 470 | "sklearn Macro-precision-Score: 0.6682589241301129\n", 471 | "sklearn Macro-recall-Score: 0.5159222150452262\n", 472 | "sklearn hamming_loss: 0.0024891054996972564\n", 473 | "0.5\n", 474 | "sklearn Macro-F1-Score: 0.5453895338340395\n", 475 | "sklearn Macro-precision-Score: 0.6669357700383469\n", 476 | "sklearn Macro-recall-Score: 0.4969160832737277\n", 477 | "sklearn hamming_loss: 0.0025255017196071767\n", 478 | "0.55\n", 479 | "sklearn Macro-F1-Score: 0.53501335279995\n", 480 | "sklearn Macro-precision-Score: 0.6780075136233795\n", 481 | "sklearn Macro-recall-Score: 0.4796162984596181\n", 482 | "sklearn hamming_loss: 0.0025896138894066596\n", 483 | "0.6000000000000001\n", 484 | "sklearn Macro-F1-Score: 0.5238740995431711\n", 485 | "sklearn Macro-precision-Score: 0.6797945559663356\n", 486 | "sklearn Macro-recall-Score: 0.46388740405137774\n", 487 | "sklearn hamming_loss: 0.0026616449020317313\n", 488 | "0.65\n", 489 | "sklearn Macro-F1-Score: 0.5112305582530124\n", 490 | "sklearn Macro-precision-Score: 0.6829867344170846\n", 491 | "sklearn Macro-recall-Score: 0.44727371051374276\n", 492 | "sklearn hamming_loss: 0.002722559077613188\n", 493 | "0.7000000000000001\n", 494 | "sklearn Macro-F1-Score: 0.4971380770062703\n", 495 | "sklearn Macro-precision-Score: 0.6845154393536712\n", 496 | "sklearn Macro-recall-Score: 0.42919270412661004\n", 497 | "sklearn hamming_loss: 0.0027772295501975446\n", 498 | "0.75\n", 499 | "sklearn Macro-F1-Score: 0.4878045023129802\n", 500 | "sklearn Macro-precision-Score: 0.6860507591502716\n", 501 | "sklearn Macro-recall-Score: 0.41363271490133896\n", 502 | "sklearn hamming_loss: 0.0028518494152848287\n", 503 | "0.8\n", 504 | "sklearn Macro-F1-Score: 0.47491985901886397\n", 505 | "sklearn Macro-precision-Score: 0.687245032405263\n", 506 | "sklearn Macro-recall-Score: 0.3963256610055329\n", 507 | "sklearn hamming_loss: 
0.0029395658281221254\n", 508 | "0.8500000000000001\n", 509 | "sklearn Macro-F1-Score: 0.46010551338811884\n", 510 | "sklearn Macro-precision-Score: 0.6986568247096664\n", 511 | "sklearn Macro-recall-Score: 0.3766995483465471\n", 512 | "sklearn hamming_loss: 0.0030510387694361904\n", 513 | "0.9\n", 514 | "sklearn Macro-F1-Score: 0.43737922585964834\n", 515 | "sklearn Macro-precision-Score: 0.7009159195823822\n", 516 | "sklearn Macro-recall-Score: 0.35154448670039545\n", 517 | "sklearn hamming_loss: 0.003190227660639818\n", 518 | "0.9500000000000001\n", 519 | "sklearn Macro-F1-Score: 0.3972693505915991\n", 520 | "sklearn Macro-precision-Score: 0.6815773014480535\n", 521 | "sklearn Macro-recall-Score: 0.3099583875753521\n", 522 | "sklearn hamming_loss: 0.003418198962753418\n" 523 | ] 524 | } 525 | ], 526 | "source": [ 527 | "from sklearn.metrics import *\n", 528 | "\n", 529 | "\n", 530 | "for i in np.arange(0,1,0.05):\n", 531 | " print(i)\n", 532 | " temp=test_pred.copy()\n", 533 | " temp[temp<i]=0\n", 534 | " temp[temp>=i]=1\n", 535 | " print('sklearn Macro-F1-Score:', f1_score(test_label, temp, average='macro'))\n", 536 | " print('sklearn Macro-precision-Score:', precision_score(test_label, temp, average='macro'))\n", 537 | " print('sklearn Macro-recall-Score:', recall_score(test_label, temp, average='macro'))\n", 538 | " print('sklearn hamming_loss:', hamming_loss(test_label, temp))\n", 539 | " " 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": null, 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": null, 552 | "metadata": {}, 553 | "outputs": [], 554 | "source": [] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": null, 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [] 562 | } 563 | ], 564 | "metadata": { 565 | "kernelspec": { 566 | "display_name": "Python 3", 567 | "language": "python", 568 | "name": "python3" 569 | }, 570 | "language_info": { 571 | "codemirror_mode": { 572 | "name": "ipython", 573 | "version": 3 574 | }, 575 | "file_extension": ".py", 576 | "mimetype": "text/x-python", 577 | "name": "python", 578 | "nbconvert_exporter": "python", 579 | "pygments_lexer": "ipython3", 580 | "version": "3.6.8" 581 | } 582 | }, 583 | "nbformat": 4, 584 | "nbformat_minor": 2 585 | } 586 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Legal-Charge-Prediction 2 | Legal AI: empowering judicial decision-making with NLP 3 | 4 | # Background 5 | 6 | In recent years, breakthroughs in deep learning, natural language processing, and other AI technologies have begun to make their mark in the legal domain, attracting broad attention from both academia and industry. AI for law aims to use machine reading comprehension of legal texts and quantitative analysis to carry out tasks with real application value, such as charge prediction, law-article recommendation, and prison-term prediction, thereby helping judges, lawyers, and other practitioners reach legal decisions more efficiently. 7 | 8 | # Dataset 9 | Dataset download link: https://cail.oss-cn-qingdao.aliyuncs.com/CAIL2018_ALL_DATA.zip 10 | 11 | The dataset is divided into two main parts: a practice (exercise) dataset and an official competition dataset. 12 | 13 | --------------------------------------------------------------------------------
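
The notebook reads pre-segmented fact files (`train_word.txt`, `val_word.txt`, `test_word.txt`) and matching label files from `../smalldata/`, but the repository does not show how these are produced from the downloaded archive. Below is a minimal, hypothetical preprocessing sketch, not part of this repository: it assumes the published CAIL2018 JSON-lines format (a `fact` string and a `meta.accusation` charge list per line), assumes input files named `data_train.json`/`data_valid.json`/`data_test.json`, and uses `jieba` purely as an example word segmenter.

```python
# Hypothetical helper: convert raw CAIL2018 JSON lines into the word-segmented
# fact files and label files that CNN.ipynb expects under ../smalldata/.
# The "fact" and "meta"->"accusation" fields follow the published CAIL2018 format;
# the input file names and the choice of jieba are assumptions.
import json

import jieba  # any Chinese word segmenter could be substituted here


def convert_split(json_path, fact_path, label_path):
    """Write one segmented fact per line and its space-separated charges per line."""
    with open(json_path, encoding="utf-8") as fin, \
         open(fact_path, "w", encoding="utf-8") as f_fact, \
         open(label_path, "w", encoding="utf-8") as f_label:
        for line in fin:
            case = json.loads(line)
            # Space-join the tokens so keras' text.Tokenizer can split them again.
            words = " ".join(jieba.cut(case["fact"].replace("\n", " ")))
            charges = " ".join(case["meta"]["accusation"])
            f_fact.write(words + "\n")
            f_label.write(charges + "\n")


if __name__ == "__main__":
    # Output paths are the ones read by CNN.ipynb; input names are assumed.
    convert_split("data_train.json", "../smalldata/train_word.txt", "../smalldata/train_label")
    convert_split("data_valid.json", "../smalldata/val_word.txt", "../smalldata/val_label")
    convert_split("data_test.json", "../smalldata/test_word.txt", "../smalldata/test_label")
```

Each output line then matches what `pd.read_table(..., sep="\n")` and `MultiLabelBinarizer` expect in the notebook: one space-separated token sequence per case and one space-separated charge list per case.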