├── AudioPreprocess.ipynb ├── README.md ├── classfication ├── .DS_Store ├── BiLSTM.ipynb ├── cnn_audio.ipynb └── fusion_net.ipynb ├── cnn_audio_reg_avid.py ├── regression ├── BiLSTM.ipynb ├── cnn_audio_reg.py └── fusion_net_reg.ipynb └── requirements.txt /AudioPreprocess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 57, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "ename": "FileNotFoundError", 10 | "evalue": "[Errno 2] No such file or directory: '/Users/apple/Downloads/Audio/Development/Freeform/206_1_Freeform_audio.mp4'", 11 | "output_type": "error", 12 | "traceback": [ 13 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 14 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 15 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtrim_ms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m \u001b[0msound\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mAudioSegment\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/Users/apple/Downloads/Audio/Development/Freeform/206_1_Freeform_audio.mp4\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"mp4\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 20\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msound\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 16 | "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/pydub/audio_segment.py\u001b[0m in \u001b[0;36mfrom_file\u001b[0;34m(cls, file, format, codec, parameters, **kwargs)\u001b[0m\n\u001b[1;32m 608\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 610\u001b[0;31m \u001b[0mfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_fd_or_path_or_tempfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtempfile\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 612\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 17 | "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/pydub/utils.py\u001b[0m in \u001b[0;36m_fd_or_path_or_tempfile\u001b[0;34m(fd, mode, tempfile)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbasestring\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0mfd\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 18 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/Users/apple/Downloads/Audio/Development/Freeform/206_1_Freeform_audio.mp4'" 19 | ] 20 | } 21 | ], 22 | "source": [ 23 | "from pydub import AudioSegment\n", 24 | "\n", 25 | "def detect_leading_silence(sound, silence_threshold=-30.0, chunk_size=10):\n", 26 | " '''\n", 27 | " sound is a pydub.AudioSegment\n", 28 | " silence_threshold in dB\n", 29 | " chunk_size in ms\n", 30 | "\n", 31 | " iterate over chunks until you find the first one with sound\n", 32 | " '''\n", 33 | " trim_ms = 0 # ms\n", 34 | "\n", 35 | " assert chunk_size > 0 # to avoid infinite loop\n", 36 | " while sound[trim_ms:trim_ms+chunk_size].dBFS < silence_threshold and trim_ms < len(sound):\n", 37 | " trim_ms += chunk_size\n", 38 | "\n", 39 | " return trim_ms\n", 40 | "\n", 41 | "sound = AudioSegment.from_file(\"/Users/apple/Downloads/Audio/Development/Freeform/206_1_Freeform_audio.mp4\", format=\"mp4\")\n", 42 | "print(type(sound))\n", 43 | "\n", 44 | "start_trim = detect_leading_silence(sound)\n", 45 | "end_trim = detect_leading_silence(sound.reverse())\n", 46 | "\n", 47 | "duration = len(sound) \n", 48 | "trimmed_sound = sound[start_trim:duration-end_trim]\n", 49 | "trimmed_sound.export('test.mp4', format=\"mp4\")\n" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 58, 55 | "metadata": { 56 | "scrolled": true 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "import os\n", 61 | "\n", 62 | "def del_sil_in_batch(path, save_path):\n", 63 | " file_list = os.listdir(path)\n", 64 | " file_list.sort()\n", 65 | " if '.DS_Store' in file_list:\n", 66 | " file_list.remove('.DS_Store')\n", 67 | " if '.mp4' in file_list:\n", 68 | " file_list.remove('.mp4')\n", 69 | " for file in file_list:\n", 70 | " if os.path.isfile(os.path.join(path, file)):\n", 71 | " sound = AudioSegment.from_file(os.path.join(path, file), format=\"wav\")\n", 72 | " start_trim = detect_leading_silence(sound)\n", 73 | " end_trim = detect_leading_silence(sound.reverse())\n", 74 | " duration = len(sound) \n", 75 | " trimmed_sound = sound[start_trim:duration-end_trim]\n", 76 | " trimmed_sound.export(os.path.join(save_path, file[:5])+'.wav', format=\"wav\")\n", 77 | "# sound.export(os.path.join(save_path, file[:5])+'.wav', format=\"wav\")\n", 78 | " \n", 79 | "paths = ['/Users/apple/Downloads/depression/AViD/Audio/Development/Freeform', '/Users/apple/Downloads/depression/AViD/Audio/Training/Freeform', '/Users/apple/Downloads/depression/AViD/Audio/Testing/Freeform']\n", 80 | "save_paths = ['/Users/apple/Downloads/depression/AViD/Audio/Development/Trim', '/Users/apple/Downloads/depression/AViD/Audio/Training/Trim', '/Users/apple/Downloads/depression/AViD/Audio/Testing/Trim']\n", 81 | "\n", 82 | "for i in range(len(paths)):\n", 83 | " del_sil_in_batch(paths[i], save_paths[i])\n", 84 | " " 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 43, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "True 
/Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/310_2_Depression.csv\n", 97 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/315_3_Depression.csv\n", 98 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/344_2_Depression.csv\n", 99 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/365_1_Depression.csv\n", 100 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/246_1_Depression.csv\n", 101 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/335_2_Depression.csv\n", 102 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/325_2_Depression.csv\n", 103 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/245_3_Depression.csv\n", 104 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/206_2_Depression.csv\n", 105 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/357_1_Depression.csv\n", 106 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/245_1_Depression.csv\n", 107 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/220_1_Depression.csv\n", 108 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/334_1_Depression.csv\n", 109 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/368_1_Depression.csv\n", 110 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/319_2_Depression.csv\n", 111 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/246_2_Depression.csv\n", 112 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/250_1_Depression.csv\n", 113 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/359_1_Depression.csv\n", 114 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/211_2_Depression.csv\n", 115 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/341_1_Depression.csv\n", 116 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/240_3_Depression.csv\n", 117 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/342_3_Depression.csv\n", 118 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/338_1_Depression.csv\n", 119 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/328_1_Depression.csv\n", 120 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/244_2_Depression.csv\n", 121 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/317_4_Depression.csv\n", 122 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/364_1_Depression.csv\n", 123 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/212_1_Depression.csv\n", 124 | "True 
/Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/341_2_Depression.csv\n", 125 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/347_2_Depression.csv\n", 126 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/345_3_Depression.csv\n", 127 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/226_2_Depression.csv\n", 128 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/234_2_Depression.csv\n", 129 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/242_1_Depression.csv\n", 130 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/357_2_Depression.csv\n", 131 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/203_2_Depression.csv\n", 132 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/346_1_Depression.csv\n", 133 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/244_3_Depression.csv\n", 134 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/343_1_Depression.csv\n", 135 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/237_1_Depression.csv\n", 136 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/249_1_Depression.csv\n", 137 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/218_3_Depression.csv\n", 138 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/220_3_Depression.csv\n", 139 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/214_3_Depression.csv\n", 140 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/247_3_Depression.csv\n", 141 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/315_2_Depression.csv\n", 142 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/210_2_Depression.csv\n", 143 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/224_1_Depression.csv\n", 144 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/247_1_Depression.csv\n", 145 | "True /Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels/236_3_Depression.csv\n", 146 | "[['203_1', '205_2', '207_2', '208_2', '209_1', '213_1', '214_1', '215_2', '215_3', '217_2', '217_3', '219_1', '219_3', '223_1', '223_2', '225_2', '226_1', '227_2', '228_1', '229_2', '230_1', '232_1', '233_1', '234_3', '236_1', '237_3', '238_2', '239_1', '240_1', '240_2', '241_2', '242_3', '243_1', '306_3', '308_3', '310_4', '312_2', '317_1', '317_3', '318_2', '318_3', '320_1', '320_2', '321_2', '322_1', '324_1', '329_1', '331_1', '332_2', '332_4'], ['3', '3', '10', '4', '6', '33', '11', '9', '5', '30', '32', '33', '19', '0', '0', '15', '35', '13', '1', '17', '20', '10', '4', '23', '23', '37', '37', '22', '14', '16', '44', '9', '41', '19', '21', '19', '0', '24', '17', '12', '7', '3', '8', '11', '0', '0', '24', '5', '2', '0']]\n" 147 | ] 148 | } 149 | ], 
150 | "source": [ 151 | "import numpy as np\n", 152 | "import pickle\n", 153 | "\n", 154 | "def generate_labels(path):\n", 155 | " file_list = os.listdir(path)\n", 156 | " file_list.sort()\n", 157 | " if '.DS_Store' in file_list:\n", 158 | " file_list.remove('.DS_Store')\n", 159 | " nums, labels = [], []\n", 160 | " for file in file_list:\n", 161 | " if os.path.isfile(os.path.join(path, file)):\n", 162 | " nums.append(file[:5])\n", 163 | " with open(os.path.join(path, file)) as f:\n", 164 | " labels.append(f.readlines()[0].split('\\n')[0])\n", 165 | " return [nums, labels]\n", 166 | "\n", 167 | "paths = ['/Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Development_DepressionLabels',\n", 168 | " '/Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Testing_DepressionLabels',\n", 169 | " '/Users/apple/Downloads/depression/AViD/AVEC2014_DepressionLabels/Training_DepressionLabels']\n", 170 | "\n", 171 | "info = {}\n", 172 | "info['dev'] = generate_labels(paths[0])\n", 173 | "info['train'] = generate_labels(paths[2])\n", 174 | "info['test'] = generate_labels(paths[1])\n", 175 | "test_index = list(set(info['test'][0])-set(info['dev'][0])-set(info['train'][0]))\n", 176 | "new_test_labels = []\n", 177 | "path = paths[1]\n", 178 | "for file in test_index:\n", 179 | " print(os.path.isfile(os.path.join(path, file)+'_Depression.csv'), os.path.join(path, file)+'_Depression.csv')\n", 180 | " if os.path.isfile(os.path.join(path, file)+'_Depression.csv'):\n", 181 | " with open(os.path.join(path, file)+'_Depression.csv') as f:\n", 182 | " new_test_labels.append(f.readlines()[0].split('\\n')[0])\n", 183 | "info['test'][0] = test_index\n", 184 | "info['test'][1] = new_test_labels\n", 185 | "\n", 186 | "print(info['train'])\n", 187 | "\n", 188 | "# np.save('avid_info.npy', info)\n", 189 | "with open('avid_info.pkl', 'wb') as f:\n", 190 | " pickle.dump(info, f)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 59, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "/Users/apple/Downloads/depression/AViD/Audio/Development/Trim/241_1.wav\n", 203 | "/Users/apple/Downloads/depression/AViD/Audio/Development/Trim/241_3.wav\n", 204 | "/Users/apple/Downloads/depression/AViD/Audio/Training/Trim/234_3.wav\n", 205 | "/Users/apple/Downloads/depression/AViD/Audio/Training/Trim/238_2.wav\n", 206 | "/Users/apple/Downloads/depression/AViD/Audio/Training/Trim/241_2.wav\n", 207 | "/Users/apple/Downloads/depression/AViD/Audio/Training/Trim/308_3.wav\n", 208 | "/Users/apple/Downloads/depression/AViD/Audio/Testing/Trim/310_2.wav\n" 209 | ] 210 | } 211 | ], 212 | "source": [ 213 | "import os\n", 214 | "\n", 215 | "def del_sil_in_batch(path, save_path):\n", 216 | " file_list = os.listdir(path)\n", 217 | " file_list.sort()\n", 218 | " if '.DS_Store' in file_list:\n", 219 | " file_list.remove('.DS_Store')\n", 220 | " if '.mp4' in file_list:\n", 221 | " file_list.remove('.mp4')\n", 222 | " for file in file_list:\n", 223 | " if os.path.isfile(os.path.join(path, file)):\n", 224 | " try:\n", 225 | " sound = AudioSegment.from_file(os.path.join(path, file), format=\"wav\")\n", 226 | " except:\n", 227 | " print(os.path.join(path, file))\n", 228 | " continue\n", 229 | " sound.export(os.path.join(save_path, file), format=\"wav\", bitrate='441k')\n", 230 | " \n", 231 | "paths = ['/Users/apple/Downloads/depression/AViD/Audio/Development/Trim', '/Users/apple/Downloads/depression/AViD/Audio/Training/Trim', 
'/Users/apple/Downloads/depression/AViD/Audio/Testing/Trim']\n", 232 | "save_paths = ['/Users/apple/Downloads/depression/AViD/Audio/Development/Trim', '/Users/apple/Downloads/depression/AViD/Audio/Training/Trim', '/Users/apple/Downloads/depression/AViD/Audio/Testing/Trim']\n", 233 | "\n", 234 | "for i in range(len(paths)):\n", 235 | " del_sil_in_batch(paths[i], save_paths[i])\n", 236 | " " 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 61, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "import sox\n", 246 | "\n", 247 | "def upsample_wav(path, save_path):\n", 248 | " file_list = os.listdir(path)\n", 249 | " file_list.sort()\n", 250 | " if '.DS_Store' in file_list:\n", 251 | " file_list.remove('.DS_Store')\n", 252 | " if '.mp4' in file_list:\n", 253 | " file_list.remove('.mp4')\n", 254 | " for file in file_list:\n", 255 | " if os.path.isfile(os.path.join(path, file)):\n", 256 | " tfm = sox.Transformer()\n", 257 | " tfm.rate(44100)\n", 258 | " try:\n", 259 | " tfm.build(os.path.join(path, file), os.path.join(save_path, file))\n", 260 | " except:\n", 261 | " print(os.path.join(path, file))\n", 262 | " continue\n", 263 | " \n", 264 | "paths = ['/Users/apple/Downloads/depression/AViD/Audio/Development/Trim', '/Users/apple/Downloads/depression/AViD/Audio/Training/Trim', '/Users/apple/Downloads/depression/AViD/Audio/Testing/Trim']\n", 265 | "save_paths = ['/Users/apple/Downloads/depression/AViD/Audio/Development/tmp', '/Users/apple/Downloads/depression/AViD/Audio/Training/tmp', '/Users/apple/Downloads/depression/AViD/Audio/Testing/tmp']\n", 266 | "\n", 267 | "for i in range(len(paths)):\n", 268 | " upsample_wav(paths[i], save_paths[i])\n", 269 | " " 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "def del_copies(path, save_path):\n", 279 | " file_list = os.listdir(path)\n", 280 | " file_list.sort()\n", 281 | " if '.DS_Store' in file_list:\n", 282 | " file_list.remove('.DS_Store')\n", 283 | " if '.mp4' in file_list:\n", 284 | " file_list.remove('.mp4')\n", 285 | " for file in file_list:\n", 286 | " if os.path.isfile(os.path.join(path, file)):\n", 287 | " tfm = sox.Transformer()\n", 288 | " tfm.rate(44100)\n", 289 | " try:\n", 290 | " tfm.build(os.path.join(path, file), os.path.join(save_path, file))\n", 291 | " except:\n", 292 | " print(os.path.join(path, file))\n", 293 | " continue\n", 294 | " \n", 295 | "paths = ['/Users/apple/Downloads/depression/AViD/Audio/Development/Trim', '/Users/apple/Downloads/depression/AViD/Audio/Training/Trim', '/Users/apple/Downloads/depression/AViD/Audio/Testing/Trim']\n", 296 | "save_paths = ['/Users/apple/Downloads/depression/AViD/Audio/Development/new_trim', '/Users/apple/Downloads/depression/AViD/Audio/Training/new_trim', '/Users/apple/Downloads/depression/AViD/Audio/Testing/new_trim']\n", 297 | "\n", 298 | "for i in range(len(paths)):\n", 299 | " upsample_wav(paths[i], save_paths[i])\n", 300 | " " 301 | ] 302 | } 303 | ], 304 | "metadata": { 305 | "kernelspec": { 306 | "display_name": "Python 3", 307 | "language": "python", 308 | "name": "python3" 309 | }, 310 | "language_info": { 311 | "codemirror_mode": { 312 | "name": "ipython", 313 | "version": 3 314 | }, 315 | "file_extension": ".py", 316 | "mimetype": "text/x-python", 317 | "name": "python", 318 | "nbconvert_exporter": "python", 319 | "pygments_lexer": "ipython3", 320 | "version": "3.7.3" 321 | } 322 | }, 323 | "nbformat": 4, 324 | "nbformat_minor": 2 325 | 
} 326 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DepressionDectection
2 | 
3 | This depression detection work runs on **Jupyter Notebook**.
4 | 
5 | ### Dataset
6 | 
7 | 1. DAIC: https://dcapswoz.ict.usc.edu/
8 | 2. AVID: http://avec2013-db.sspnet.eu/
9 | 3. To obtain the DAIC and AVID datasets, an agreement form must be signed and returned to the corresponding email address.
10 | 
11 | ### How to run
12 | 
13 | ##### DAIC
14 | 
15 | The classification and regression model training code is provided in the **classfication** and **regression** folders. To run it, the **prefix** variable in *BiLSTM.ipynb, cnn_audio.ipynb and fusion_net.ipynb* should be set to the path where the DAIC dataset is placed. To run *fusion_net.ipynb*, the paths of the trained **lstm_model** and **cnn_model** should also be set to the corresponding files.
16 | 
17 | ##### AVID
18 | 
19 | Preprocessing code for the audio recordings of the AVID dataset is provided in **AudioPreprocess.ipynb**; change its paths to your AVID dataset path before running. The regression model training code is provided in **cnn_audio_reg_avid.py**. It takes the preprocessed audio clips as input, together with the label information saved to **avid_info.pkl** during preprocessing, so its **prefix** variable should be set to the directory where **avid_info.pkl** is stored. A minimal sketch of both settings follows.
20 | 
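Both setups come down to one path variable. A minimal, hypothetical sketch (the folder names below are placeholders, not part of the repository):

```python
# DAIC: in BiLSTM.ipynb, cnn_audio.ipynb and fusion_net.ipynb
prefix = '/path/to/DAIC/'           # contains the XXX_P session folders and the AVEC 2017 split CSVs

# AVID: in cnn_audio_reg_avid.py
prefix = '/path/to/avid_workdir/'   # directory where AudioPreprocess.ipynb saved avid_info.pkl
```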
1024)\n", 80 | "(35, 1024)\n", 81 | "(44, 1024)\n", 82 | "(51, 1024)\n", 83 | "(37, 1024)\n", 84 | "(51, 1024)\n", 85 | "(40, 1024)\n", 86 | "(48, 1024)\n", 87 | "(47, 1024)\n", 88 | "(48, 1024)\n", 89 | "(51, 1024)\n", 90 | "(46, 1024)\n", 91 | "(47, 1024)\n", 92 | "(37, 1024)\n", 93 | "(44, 1024)\n", 94 | "(68, 1024)\n", 95 | "(52, 1024)\n", 96 | "(37, 1024)\n", 97 | "(43, 1024)\n", 98 | "(34, 1024)\n", 99 | "(38, 1024)\n", 100 | "(32, 1024)\n", 101 | "(33, 1024)\n", 102 | "(64, 1024)\n", 103 | "(34, 1024)\n", 104 | "(48, 1024)\n", 105 | "(47, 1024)\n", 106 | "(59, 1024)\n", 107 | "(48, 1024)\n", 108 | "(40, 1024)\n", 109 | "(35, 1024)\n", 110 | "(42, 1024)\n", 111 | "(54, 1024)\n", 112 | "(45, 1024)\n", 113 | "(37, 1024)\n", 114 | "(49, 1024)\n", 115 | "(54, 1024)\n", 116 | "(48, 1024)\n", 117 | "(43, 1024)\n", 118 | "(35, 1024)\n", 119 | "(35, 1024)\n", 120 | "(50, 1024)\n", 121 | "(38, 1024)\n", 122 | "(35, 1024)\n", 123 | "(42, 1024)\n", 124 | "(46, 1024)\n", 125 | "(42, 1024)\n", 126 | "(31, 1024)\n", 127 | "(43, 1024)\n", 128 | "(45, 1024)\n", 129 | "(47, 1024)\n", 130 | "(45, 1024)\n", 131 | "(62, 1024)\n", 132 | "(48, 1024)\n", 133 | "(46, 1024)\n", 134 | "(37, 1024)\n", 135 | "(43, 1024)\n", 136 | "(39, 1024)\n", 137 | "(41, 1024)\n", 138 | "(38, 1024)\n", 139 | "(42, 1024)\n", 140 | "(55, 1024)\n", 141 | "(33, 1024)\n", 142 | "(42, 1024)\n", 143 | "(33, 1024)\n", 144 | "(51, 1024)\n", 145 | "(30, 1024)\n", 146 | "(40, 1024)\n", 147 | "(43, 1024)\n", 148 | "(44, 1024)\n", 149 | "(39, 1024)\n", 150 | "(38, 1024)\n", 151 | "(43, 1024)\n", 152 | "(39, 1024)\n", 153 | "(49, 1024)\n", 154 | "(49, 1024)\n", 155 | "(40, 1024)\n", 156 | "(107,) (35,) (107,) (35,)\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "import numpy as np\n", 162 | "import pandas as pd\n", 163 | "import wave\n", 164 | "import librosa\n", 165 | "import re\n", 166 | "from allennlp.commands.elmo import ElmoEmbedder\n", 167 | "# from bert_serving.client import BertClient\n", 168 | "\n", 169 | "prefix = '/Users/apple/Downloads/depression/'\n", 170 | "\n", 171 | "elmo = ElmoEmbedder()\n", 172 | "# bc = BertClient(ip='100.66.165.12')\n", 173 | "\n", 174 | "train_split_df = pd.read_csv(prefix+'train_split_Depression_AVEC2017 (1).csv')\n", 175 | "test_split_df = pd.read_csv(prefix+'dev_split_Depression_AVEC2017.csv')\n", 176 | "train_split_num = train_split_df[['Participant_ID']]['Participant_ID'].tolist()\n", 177 | "test_split_num = test_split_df[['Participant_ID']]['Participant_ID'].tolist()\n", 178 | "train_split_clabel = train_split_df[['PHQ8_Binary']]['PHQ8_Binary'].tolist()\n", 179 | "test_split_clabel = test_split_df[['PHQ8_Binary']]['PHQ8_Binary'].tolist()\n", 180 | "\n", 181 | "topics = []\n", 182 | "with open('/Users/apple/Downloads/depression/queries.txt', 'r') as f:\n", 183 | " for line in f.readlines():\n", 184 | " topics.append(line.strip('\\n').strip())\n", 185 | " \n", 186 | "\n", 187 | "def identify_topics(sentence):\n", 188 | " if sentence in topics:\n", 189 | " return True\n", 190 | " return False\n", 191 | "\n", 192 | "def extract_features(number, text_features, target, mode, text_targets):\n", 193 | " \n", 194 | " transcript = pd.read_csv(prefix+'{0}_P/{0}_TRANSCRIPT.csv'.format(number), sep='\\t').fillna('')\n", 195 | " \n", 196 | " \n", 197 | " time_range = []\n", 198 | " responses = []\n", 199 | " response = ''\n", 200 | " response_flag = False\n", 201 | " start_time = 0\n", 202 | " stop_time = 0\n", 203 | "\n", 204 | " signal = []\n", 205 | " \n", 206 | " global counter1, counter2\n", 207 | 
"\n", 208 | " for t in transcript.itertuples():\n", 209 | " # participant一句话结束\n", 210 | " if getattr(t,'speaker') == 'Ellie':\n", 211 | "# if '(' in getattr(t,'value'):\n", 212 | "# content = re.findall(re.compile(r'[(](.*?)[)]', re.S), getattr(t,'value'))[0]\n", 213 | "# print(content)\n", 214 | "# else:\n", 215 | "# content = getattr(t,'value').strip()\n", 216 | " content = getattr(t,'value').strip()\n", 217 | " if identify_topics(content):\n", 218 | " response_flag = True\n", 219 | " if len(response) != 0:\n", 220 | " responses.append(response.strip())\n", 221 | " response = ''\n", 222 | " elif response_flag and len(content.split()) > 4:\n", 223 | " response_flag = False\n", 224 | " if len(response) != 0:\n", 225 | " responses.append(response)\n", 226 | " response = ''\n", 227 | " elif getattr(t,'speaker') == 'Participant':\n", 228 | " if 'scrubbed_entry' in getattr(t,'value'):\n", 229 | " continue\n", 230 | " elif response_flag:\n", 231 | " content = getattr(t,'value').split('\\n')[0].strip()\n", 232 | "# if '<' in getattr(t,'value'):\n", 233 | "# content = re.sub(u\"\\\\<.*?\\\\>\", \"\", content)\n", 234 | " response+=' '+content\n", 235 | " \n", 236 | " text_feature = elmo.embed_sentence(responses).mean(0)\n", 237 | "# text_feature = bc.encode(responses)\n", 238 | "# while text_feature.shape[0] < 30:\n", 239 | "# print(number)\n", 240 | "# text_feature = np.vstack((text_feature, np.zeros(text_feature.shape[1])))\n", 241 | " print(text_feature.shape)\n", 242 | "# text_features.append(text_feature[:30])\n", 243 | " text_features.append(text_feature)\n", 244 | " text_targets.append(target)\n", 245 | " \n", 246 | "def extract_features1(number, text_features, target, mode, text_targets):\n", 247 | " \n", 248 | " transcript = pd.read_csv(prefix+'{0}_P/{0}_TRANSCRIPT.csv'.format(number), sep='\\t').fillna('')\n", 249 | " \n", 250 | " \n", 251 | " time_range = []\n", 252 | " responses = []\n", 253 | " response = ''\n", 254 | " response_flag = False\n", 255 | " start_time = 0\n", 256 | " stop_time = 0\n", 257 | "\n", 258 | " signal = []\n", 259 | " \n", 260 | " global counter1, counter2\n", 261 | "\n", 262 | " for t in transcript.itertuples():\n", 263 | " # participant一句话结束\n", 264 | " if getattr(t,'speaker') == 'Ellie':\n", 265 | " if '(' in getattr(t,'value'):\n", 266 | " content = re.findall(re.compile(r'[(](.*?)[)]', re.S), getattr(t,'value'))[0]\n", 267 | " else:\n", 268 | " content = getattr(t,'value').strip()\n", 269 | " content = getattr(t,'value').strip()\n", 270 | " if identify_topics(content):\n", 271 | " response_flag = True\n", 272 | " if len(response) != 0:\n", 273 | " responses.append(response.strip())\n", 274 | " response = ''\n", 275 | " elif response_flag and len(content.split()) > 4:\n", 276 | " response_flag = False\n", 277 | " if len(response) != 0:\n", 278 | " responses.append(response)\n", 279 | " response = ''\n", 280 | " elif getattr(t,'speaker') == 'Participant':\n", 281 | " if 'scrubbed_entry' in getattr(t,'value'):\n", 282 | " continue\n", 283 | " elif response_flag:\n", 284 | " response+=' '+getattr(t,'value').split('\\n')[0].strip()\n", 285 | " \n", 286 | " text_feature = elmo.embed_sentence(responses).mean(0)\n", 287 | " print(text_feature.shape)\n", 288 | " text_features.append(text_feature)\n", 289 | " if target == 1:\n", 290 | " counter1 += len(text_feature)\n", 291 | " else:\n", 292 | " counter2 += len(text_feature)\n", 293 | " text_targets.append(target)\n", 294 | "\n", 295 | "counter1 = 0\n", 296 | "counter2 = 0\n", 297 | " \n", 298 | "# training set\n", 
299 | "text_features_train = []\n", 300 | "text_ctargets_train = []\n", 301 | "\n", 302 | "# test set\n", 303 | "text_features_test = []\n", 304 | "text_ctargets_test = []\n", 305 | "\n", 306 | "# ======================= classification =======================\n", 307 | "\n", 308 | "training set\n", 309 | "for index in range(len(train_split_num)):\n", 310 | " extract_features(train_split_num[index], text_features_train, train_split_clabel[index], 'train', text_ctargets_train)\n", 311 | " \n", 312 | "# test set\n", 313 | "for index in range(len(test_split_num)):\n", 314 | " extract_features(test_split_num[index], text_features_test, test_split_clabel[index], 'test', text_ctargets_test)\n", 315 | "\n", 316 | "# ======================= classification =======================\n" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "# print(\"Saving npz file locally...\")\n", 326 | "\n", 327 | "# np.savez(prefix+'data/text/train_samples.npz', text_features_train)\n", 328 | "# np.savez(prefix+'data/text/train_labels.npz', text_features_test)\n", 329 | "# np.savez(prefix+'data/text/test_samples.npz', text_ctargets_train)\n", 330 | "# np.savez(prefix+'data/text/test_labels.npz', text_ctargets_test)\n", 331 | "\n", 332 | "prefix = '/Users/apple/Downloads/depression/'\n", 333 | "\n", 334 | "text_features_train = np.load(prefix+'data/text/train_samples.npz')['arr_0']\n", 335 | "text_features_test = np.load(prefix+'data/text/train_labels.npz')['arr_0']\n", 336 | "text_ctargets_train = np.load(prefix+'data/text/test_samples.npz')['arr_0']\n", 337 | "text_ctargets_test = np.load(prefix+'data/text/test_labels.npz')['arr_0']" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 221, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "import torch\n", 347 | "import torch.nn as nn\n", 348 | "from torch.autograd import Variable\n", 349 | "from torch.nn import functional as F\n", 350 | "import torch.optim as optim\n", 351 | "from sklearn.metrics import confusion_matrix" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 222, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "class BiLSTM(nn.Module):\n", 361 | " \n", 362 | " def __init__(self, config):\n", 363 | " super(BiLSTM, self).__init__()\n", 364 | " self.num_classes = config['num_classes']\n", 365 | " self.learning_rate = config['learning_rate']\n", 366 | " self.dropout = config['dropout']\n", 367 | " self.hidden_dims = config['hidden_dims']\n", 368 | " self.rnn_layers = config['rnn_layers']\n", 369 | " self.embedding_size = config['embedding_size']\n", 370 | " self.bidirectional = config['bidirectional']\n", 371 | "\n", 372 | " self.build_model()\n", 373 | " self.init_weight()\n", 374 | " \n", 375 | " def init_weight(net):\n", 376 | " for name, param in net.named_parameters():\n", 377 | " if 'bias' in name:\n", 378 | " nn.init.constant_(param, 0.0)\n", 379 | " elif 'weight' in name:\n", 380 | " nn.init.xavier_uniform_(param)\n", 381 | "\n", 382 | " def build_model(self):\n", 383 | " # attention layer\n", 384 | " self.attention_layer = nn.Sequential(\n", 385 | " nn.Linear(self.hidden_dims, self.hidden_dims),\n", 386 | " nn.ReLU(inplace=True)\n", 387 | " )\n", 388 | " # self.attention_weights = self.attention_weights.view(self.hidden_dims, 1)\n", 389 | "\n", 390 | " # 双层lstm\n", 391 | " self.lstm_net = nn.LSTM(self.embedding_size, self.hidden_dims,\n", 392 | " 
354 | {
355 | "cell_type": "code",
356 | "execution_count": 222,
357 | "metadata": {},
358 | "outputs": [],
359 | "source": [
360 | "class BiLSTM(nn.Module):\n",
361 | "    \n",
362 | "    def __init__(self, config):\n",
363 | "        super(BiLSTM, self).__init__()\n",
364 | "        self.num_classes = config['num_classes']\n",
365 | "        self.learning_rate = config['learning_rate']\n",
366 | "        self.dropout = config['dropout']\n",
367 | "        self.hidden_dims = config['hidden_dims']\n",
368 | "        self.rnn_layers = config['rnn_layers']\n",
369 | "        self.embedding_size = config['embedding_size']\n",
370 | "        self.bidirectional = config['bidirectional']\n",
371 | "\n",
372 | "        self.build_model()\n",
373 | "        self.init_weight()\n",
374 | "    \n",
375 | "    def init_weight(self):\n",
376 | "        for name, param in self.named_parameters():\n",
377 | "            if 'bias' in name:\n",
378 | "                nn.init.constant_(param, 0.0)\n",
379 | "            elif 'weight' in name:\n",
380 | "                nn.init.xavier_uniform_(param)\n",
381 | "\n",
382 | "    def build_model(self):\n",
383 | "        # attention layer\n",
384 | "        self.attention_layer = nn.Sequential(\n",
385 | "            nn.Linear(self.hidden_dims, self.hidden_dims),\n",
386 | "            nn.ReLU(inplace=True)\n",
387 | "        )\n",
388 | "        # self.attention_weights = self.attention_weights.view(self.hidden_dims, 1)\n",
389 | "\n",
390 | "        # two-layer bidirectional LSTM\n",
391 | "        self.lstm_net = nn.LSTM(self.embedding_size, self.hidden_dims,\n",
392 | "                                num_layers=self.rnn_layers, dropout=self.dropout,\n",
393 | "                                bidirectional=self.bidirectional)\n",
394 | "        \n",
395 | "#         self.init_weight()\n",
396 | "        \n",
397 | "        # fully connected output head\n",
398 | "#         self.fc_out = nn.Linear(self.hidden_dims, self.num_classes)\n",
399 | "        self.fc_out = nn.Sequential(\n",
400 | "            nn.Dropout(self.dropout),\n",
401 | "            nn.Linear(self.hidden_dims, self.hidden_dims),\n",
402 | "            nn.ReLU(inplace=True),\n",
403 | "            nn.Dropout(self.dropout),\n",
404 | "            nn.Linear(self.hidden_dims, self.num_classes),\n",
405 | "            nn.ReLU(),\n",
406 | "        )\n",
407 | "\n",
408 | "    def attention_net_with_w(self, lstm_out, lstm_hidden):\n",
409 | "        '''\n",
410 | "        :param lstm_out:    [batch_size, len_seq, n_hidden * 2]\n",
411 | "        :param lstm_hidden: [batch_size, num_layers * num_directions, n_hidden]\n",
412 | "        :return:            [batch_size, n_hidden]\n",
413 | "        '''\n",
414 | "        lstm_tmp_out = torch.chunk(lstm_out, 2, -1)\n",
415 | "        # h [batch_size, time_step, hidden_dims]\n",
416 | "        h = lstm_tmp_out[0] + lstm_tmp_out[1]\n",
417 | "#         h = lstm_out\n",
418 | "        # [batch_size, num_layers * num_directions, n_hidden]\n",
419 | "        lstm_hidden = torch.sum(lstm_hidden, dim=1)\n",
420 | "        # [batch_size, 1, n_hidden]\n",
421 | "        lstm_hidden = lstm_hidden.unsqueeze(1)\n",
422 | "        # atten_w [batch_size, 1, hidden_dims]\n",
423 | "        atten_w = self.attention_layer(lstm_hidden)\n",
424 | "        # m [batch_size, time_step, hidden_dims]\n",
425 | "        m = nn.Tanh()(h)\n",
426 | "        # atten_context [batch_size, 1, time_step]\n",
427 | "        atten_context = torch.bmm(atten_w, m.transpose(1, 2))\n",
428 | "        # softmax_w [batch_size, 1, time_step]\n",
429 | "        softmax_w = F.softmax(atten_context, dim=-1)\n",
430 | "        # context [batch_size, 1, hidden_dims]\n",
431 | "        context = torch.bmm(softmax_w, h)\n",
432 | "        result = context.squeeze(1)\n",
433 | "        return result\n",
434 | "\n",
435 | "    def forward(self, x):\n",
436 | "        \n",
437 | "        # x : [len_seq, batch_size, embedding_dim]\n",
438 | "        x = x.permute(1, 0, 2)\n",
439 | "        output, (final_hidden_state, final_cell_state) = self.lstm_net(x)\n",
440 | "        # output : [batch_size, len_seq, n_hidden * 2]\n",
441 | "        output = output.permute(1, 0, 2)\n",
442 | "        # final_hidden_state : [batch_size, num_layers * num_directions, n_hidden]\n",
443 | "        final_hidden_state = final_hidden_state.permute(1, 0, 2)\n",
444 | "        # final_hidden_state = torch.mean(final_hidden_state, dim=0, keepdim=True)\n",
445 | "        # atten_out = self.attention_net(output, final_hidden_state)\n",
446 | "        atten_out = self.attention_net_with_w(output, final_hidden_state)\n",
447 | "        return self.fc_out(atten_out)\n",
448 | "    \n",
449 | "    "
450 | ]
451 | },
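{
"cell_type": "markdown",
"metadata": {},
"source": [
"*(Added)* The next cell rebalances the training set: each participant's `(num_responses, 1024)` feature matrix is cut into windows of `cut = 10` responses, and participants whose (PHQ8) label is `>= 10` contribute two to three windows (the `debt` counter carries over windows that a short recording could not supply), while the rest contribute one. A tiny illustration of the windowing, with a hypothetical response count:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# (Added sketch) windowing a (num_responses, 1024) matrix into cut-sized chunks\n",
"demo = np.zeros((36, 1024))   # e.g. a participant with 36 responses\n",
"cut_demo = 10\n",
"windows = [demo[j*cut_demo:(j+1)*cut_demo] for j in range(len(demo) // cut_demo)]\n",
"print(len(windows), windows[0].shape)  # 3 (10, 1024)"
]
},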
| " counter+=1\n", 483 | " else:\n", 484 | " X_train.append(text_features_train[i][:cut])\n", 485 | " Y_train.append(text_ctargets_train[i])\n", 486 | " \n", 487 | " \n", 488 | "for i in range(len(text_features_test)):\n", 489 | " X_test.append(text_features_test[i][:cut])\n", 490 | " Y_test.append(text_ctargets_test[i])\n", 491 | "\n", 492 | "# for i in range(len(text_features_train)):\n", 493 | "# if text_ctargets_train[i] == 1:\n", 494 | "# times = int(len(text_features_train[i]) / 10)\n", 495 | "# for j in range(times):\n", 496 | "# X_train.append(text_features_train[i][j*10:(j+1)*10])\n", 497 | "# Y_train.append(text_ctargets_train[i])\n", 498 | "# counter+=1\n", 499 | "# else:\n", 500 | "# times = \n", 501 | "# X_train.append(text_features_train[i][:10])\n", 502 | "# Y_train.append(text_ctargets_train[i])\n", 503 | " \n", 504 | " \n", 505 | "# for i in range(len(text_features_test)):\n", 506 | "# X_test.append(text_features_test[i][:10])\n", 507 | "# Y_test.append(text_ctargets_test[i])\n", 508 | " \n", 509 | "X_train = np.array(X_train)\n", 510 | "Y_train = np.array(Y_train)\n", 511 | "X_test = np.array(X_test)\n", 512 | "Y_test = np.array(Y_test)" 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": 285, 518 | "metadata": {}, 519 | "outputs": [ 520 | { 521 | "data": { 522 | "text/plain": [ 523 | "(154, 10, 1024)" 524 | ] 525 | }, 526 | "execution_count": 285, 527 | "metadata": {}, 528 | "output_type": "execute_result" 529 | } 530 | ], 531 | "source": [ 532 | "X_train.shape" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 340, 538 | "metadata": {}, 539 | "outputs": [], 540 | "source": [ 541 | "\n", 542 | "config = {\n", 543 | " 'num_classes': 2,\n", 544 | " 'dropout': 0.5,\n", 545 | " 'rnn_layers': 2,\n", 546 | " 'embedding_size': 1024,\n", 547 | " 'batch_size': 8,\n", 548 | " 'epochs': 200,\n", 549 | " 'learning_rate': 5e-4,\n", 550 | " 'hidden_dims': 128,\n", 551 | " 'bidirectional': True\n", 552 | "}\n", 553 | "\n", 554 | "model = BiLSTM(config)\n", 555 | "\n", 556 | "# if args.cuda:\n", 557 | "# model = model.cuda()\n", 558 | "# X_train = X_train.cuda()\n", 559 | "# Y_train = Y_train.cuda()\n", 560 | "# X_test = X_test.cuda()\n", 561 | "# Y_test = Y_test.cuda()\n", 562 | "\n", 563 | "optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])\n", 564 | "# optimizer = optim.Adam(model.parameters())\n", 565 | "# criterion = nn.CrossEntropyLoss()\n", 566 | "criterion = nn.SmoothL1Loss()\n", 567 | "max_f1 = -1\n", 568 | "max_acc = -1\n", 569 | "train_acc = -1\n", 570 | "min_mae = 100\n", 571 | "\n", 572 | "def save(model, filename):\n", 573 | " save_filename = '{}.pt'.format(filename)\n", 574 | " torch.save(model, save_filename)\n", 575 | " print('Saved as %s' % save_filename)\n", 576 | " \n", 577 | "def standard_confusion_matrix(y_test, y_test_pred):\n", 578 | " \"\"\"\n", 579 | " Make confusion matrix with format:\n", 580 | " -----------\n", 581 | " | TP | FP |\n", 582 | " -----------\n", 583 | " | FN | TN |\n", 584 | " -----------\n", 585 | " Parameters\n", 586 | " ----------\n", 587 | " y_true : ndarray - 1D\n", 588 | " y_pred : ndarray - 1D\n", 589 | "\n", 590 | " Returns\n", 591 | " -------\n", 592 | " ndarray - 2D\n", 593 | " \"\"\"\n", 594 | " [[tn, fp], [fn, tp]] = confusion_matrix(y_test, y_test_pred)\n", 595 | " return np.array([[tp, fp], [fn, tn]])\n", 596 | "\n", 597 | "def model_performance(y_test, y_test_pred_proba):\n", 598 | " \"\"\"\n", 599 | " Evaluation metrics for network performance.\n", 600 | " 
\"\"\"\n", 601 | " y_test_pred = y_test_pred_proba.data.max(1, keepdim=True)[1]\n", 602 | "\n", 603 | " # Computing confusion matrix for test dataset\n", 604 | " conf_matrix = standard_confusion_matrix(y_test, y_test_pred)\n", 605 | " print(\"Confusion Matrix:\")\n", 606 | " print(conf_matrix)\n", 607 | "\n", 608 | " return y_test_pred, conf_matrix\n", 609 | "\n", 610 | "def plot_roc_curve(y_test, y_score):\n", 611 | " \"\"\"\n", 612 | " Plots ROC curve for final trained model. Code taken from:\n", 613 | " https://vkolachalama.blogspot.com/2016/05/keras-implementation-of-mlp-neural.html\n", 614 | " \"\"\"\n", 615 | " fpr, tpr, _ = roc_curve(y_test, y_score)\n", 616 | " roc_auc = auc(fpr, tpr)\n", 617 | " plt.figure()\n", 618 | " plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)\n", 619 | " plt.plot([0, 1], [0, 1], 'k--')\n", 620 | " plt.xlim([0.0, 1.05])\n", 621 | " plt.ylim([0.0, 1.05])\n", 622 | " plt.xlabel('False Positive Rate')\n", 623 | " plt.ylabel('True Positive Rate')\n", 624 | " plt.title('Receiver operating characteristic curve')\n", 625 | " plt.legend(loc=\"lower right\")\n", 626 | " plt.savefig(prefix+'images/BiLSTM_roc.png')\n", 627 | " plt.close()\n", 628 | "\n", 629 | "\n", 630 | "def train(epoch):\n", 631 | " global lr, train_acc\n", 632 | " model.train()\n", 633 | " batch_idx = 1\n", 634 | " total_loss = 0\n", 635 | " correct = 0\n", 636 | " for i in range(0, X_train.shape[0], config['batch_size']):\n", 637 | " if i + config['batch_size'] > X_train.shape[0]:\n", 638 | " x, y = X_train[i:], Y_train[i:]\n", 639 | " else:\n", 640 | " x, y = X_train[i:(i+config['batch_size'])], Y_train[i:(i+config['batch_size'])]\n", 641 | " if False:\n", 642 | " x, y = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=True).cuda(), Variable(torch.from_numpy(y)).cuda()\n", 643 | " else:\n", 644 | "# x, y = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=True), Variable(torch.from_numpy(y))\n", 645 | " x, y = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=True), Variable(torch.from_numpy(y)).type(torch.FloatTensor)\n", 646 | " # 将模型的参数梯度设置为0\n", 647 | " optimizer.zero_grad()\n", 648 | " output = model(x)\n", 649 | "# pred = output.data.max(1, keepdim=True)[1]\n", 650 | "# correct += pred.eq(y.data.view_as(pred)).cpu().sum()\n", 651 | " loss = criterion(output.flatten(), y)\n", 652 | " # 后向传播调整参数\n", 653 | " loss.backward()\n", 654 | " # 根据梯度更新网络参数\n", 655 | " optimizer.step()\n", 656 | " batch_idx += 1\n", 657 | " # loss.item()能够得到张量中的元素值\n", 658 | " total_loss += loss.item()\n", 659 | " \n", 660 | "# train_acc = correct\n", 661 | " train_acc = total_loss/batch_idx\n", 662 | " cur_loss = total_loss\n", 663 | "# print('Train Epoch: {:2d}\\t Learning rate: {:.4f}\\tLoss: {:.6f}\\t Accuracy: {}/{} ({:.0f}%)\\n '.format(\n", 664 | "# epoch+1, config['learning_rate'], cur_loss, correct, len(X_train),\n", 665 | "# 100. 
670 | "def evaluate(model):\n",
671 | "    model.eval()\n",
672 | "    batch_idx = 1\n",
673 | "    total_loss = 0\n",
674 | "    global max_f1, max_acc, min_mae\n",
675 | "    with torch.no_grad():\n",
676 | "        x, y = Variable(torch.from_numpy(X_test).type(torch.FloatTensor), requires_grad=True), Variable(torch.from_numpy(Y_test)).type(torch.FloatTensor)\n",
677 | "        optimizer.zero_grad()\n",
678 | "        output = model(x)\n",
679 | "        loss = criterion(output.flatten(), y)\n",
680 | "        total_loss += loss.item()\n",
681 | "#         print(y, output)\n",
682 | "#         y_test_pred, conf_matrix = model_performance(y, output)\n",
683 | "        print('\\nTest set: Average loss: {:.4f} \\t MAE: {:.4f}\\n'.format(total_loss, F.l1_loss(output.flatten(), y)))\n",
684 | "        \n",
685 | "        # custom evaluation metrics (need conf_matrix from the model_performance call above; disabled for this regression run)\n",
686 | "#         print('Calculating additional test metrics...')\n",
687 | "#         accuracy = float(conf_matrix[0][0] + conf_matrix[1][1]) / np.sum(conf_matrix)\n",
688 | "#         precision = float(conf_matrix[0][0]) / (conf_matrix[0][0] + conf_matrix[0][1])\n",
689 | "#         recall = float(conf_matrix[0][0]) / (conf_matrix[0][0] + conf_matrix[1][0])\n",
690 | "#         f1_score = 2 * (precision * recall) / (precision + recall)\n",
691 | "#         print(\"Accuracy: {}\".format(accuracy))\n",
692 | "#         print(\"Precision: {}\".format(precision))\n",
693 | "#         print(\"Recall: {}\".format(recall))\n",
694 | "#         print(\"F1-Score: {}\\n\".format(f1_score))\n",
695 | "#         print('='*89)\n",
696 | "        \n",
697 | "#         if max_f1 <= f1_score and train_acc > 151:\n",
698 | "#             max_f1 = f1_score\n",
699 | "#             max_acc = accuracy\n",
700 | "#             save(model, 'BiLSTM_elmo_{}_{:.2f}'.format(config['hidden_dims'], max_f1)) \n",
701 | "#             print('*'*64)\n",
702 | "#             print('model saved: f1: {}\\tacc: {}'.format(max_f1, max_acc))\n",
703 | "#             print('*'*64)\n",
704 | "        if min_mae >= F.l1_loss(output.flatten(), y) and train_acc < 2.0:\n",
705 | "            min_mae = F.l1_loss(output.flatten(), y)\n",
706 | "            save(model, 'BiLSTM_reg_{}_{:.2f}'.format(config['hidden_dims'], min_mae)) \n",
707 | "            print('*'*64)\n",
708 | "            print('model saved: f1: {}\\tacc: {}'.format(min_mae, train_acc))\n",
709 | "            print('*'*64)\n",
710 | "\n",
711 | "    return total_loss\n"
712 | ]
713 | },
714 | {
715 | "cell_type": "code",
716 | "execution_count": 341,
717 | "metadata": {
718 | "scrolled": true
719 | },
720 | "outputs": [
721 | {
722 | "name": "stdout",
723 | "output_type": "stream",
724 | "text": [
725 | "Train Epoch:  2\t Learning rate: 0.0005\t Loss: 6.222528\t \n",
726 | "\n",
727 | "Test set: Average loss: 5.2454 \t MAE: 5.7352\n",
728 | "\n",
729 | "Train Epoch:  3\t Learning rate: 0.0005\t Loss: 4.801173\t \n",
730 | "\n",
731 | "Test set: Average loss: 5.5409 \t MAE: 6.0273\n",
732 | "\n",
733 | "Train Epoch:  4\t Learning rate: 0.0005\t Loss: 4.431237\t \n",
734 | "\n",
735 | "Test set: Average loss: 5.2539 \t MAE: 5.7438\n",
736 | "\n",
737 | "Train Epoch:  5\t Learning rate: 0.0005\t Loss: 4.382258\t \n",
738 | "\n",
739 | "Test set: Average loss: 5.3497 \t MAE: 5.8280\n",
740 | "\n",
741 | "Train Epoch:  6\t Learning rate: 0.0005\t Loss: 4.518691\t \n",
742 | "\n",
743 | "Test set: Average loss: 5.4297 \t MAE: 5.9101\n",
744 | "\n",
745 | "Train Epoch:  7\t Learning rate: 0.0005\t Loss: 4.425840\t \n",
746 | "\n",
747 | "Test set: Average loss: 5.3102 \t MAE: 5.7929\n",
748 | "\n",
749 | "Train Epoch:  8\t Learning rate: 
0.0005\t Loss: 4.562018\t \n", 750 | "\n", 751 | "Test set: Average loss: 5.0805 \t MAE: 5.5750\n", 752 | "\n", 753 | "Train Epoch: 9\t Learning rate: 0.0005\t Loss: 4.654327\t \n", 754 | "\n", 755 | "Test set: Average loss: 5.2846 \t MAE: 5.7715\n", 756 | "\n", 757 | "Train Epoch: 10\t Learning rate: 0.0005\t Loss: 4.565650\t \n", 758 | "\n", 759 | "Test set: Average loss: 5.2604 \t MAE: 5.7470\n", 760 | "\n", 761 | "Train Epoch: 11\t Learning rate: 0.0005\t Loss: 4.451685\t \n", 762 | "\n", 763 | "Test set: Average loss: 5.2530 \t MAE: 5.7409\n", 764 | "\n", 765 | "Train Epoch: 12\t Learning rate: 0.0005\t Loss: 4.512582\t \n", 766 | "\n", 767 | "Test set: Average loss: 5.1506 \t MAE: 5.6465\n", 768 | "\n", 769 | "Train Epoch: 13\t Learning rate: 0.0005\t Loss: 4.443742\t \n", 770 | "\n", 771 | "Test set: Average loss: 5.1558 \t MAE: 5.6496\n", 772 | "\n", 773 | "Train Epoch: 14\t Learning rate: 0.0005\t Loss: 4.298081\t \n", 774 | "\n", 775 | "Test set: Average loss: 4.8663 \t MAE: 5.3347\n", 776 | "\n", 777 | "Train Epoch: 15\t Learning rate: 0.0005\t Loss: 3.788248\t \n", 778 | "\n", 779 | "Test set: Average loss: 4.2885 \t MAE: 4.7878\n", 780 | "\n", 781 | "Train Epoch: 16\t Learning rate: 0.0005\t Loss: 3.246390\t \n", 782 | "\n", 783 | "Test set: Average loss: 4.1861 \t MAE: 4.6585\n", 784 | "\n", 785 | "Train Epoch: 17\t Learning rate: 0.0005\t Loss: 4.021637\t \n", 786 | "\n", 787 | "Test set: Average loss: 3.8039 \t MAE: 4.2900\n", 788 | "\n", 789 | "Train Epoch: 18\t Learning rate: 0.0005\t Loss: 3.298809\t \n", 790 | "\n", 791 | "Test set: Average loss: 3.8797 \t MAE: 4.3641\n", 792 | "\n", 793 | "Train Epoch: 19\t Learning rate: 0.0005\t Loss: 2.909718\t \n", 794 | "\n", 795 | "Test set: Average loss: 3.6439 \t MAE: 4.1346\n", 796 | "\n", 797 | "Train Epoch: 20\t Learning rate: 0.0005\t Loss: 2.991214\t \n", 798 | "\n", 799 | "Test set: Average loss: 3.5464 \t MAE: 4.0234\n", 800 | "\n", 801 | "Train Epoch: 21\t Learning rate: 0.0005\t Loss: 2.892707\t \n", 802 | "\n", 803 | "Test set: Average loss: 3.1266 \t MAE: 3.5861\n", 804 | "\n", 805 | "Train Epoch: 22\t Learning rate: 0.0005\t Loss: 2.806894\t \n", 806 | "\n", 807 | "Test set: Average loss: 3.3687 \t MAE: 3.8398\n", 808 | "\n", 809 | "Train Epoch: 23\t Learning rate: 0.0005\t Loss: 2.462016\t \n", 810 | "\n", 811 | "Test set: Average loss: 3.2983 \t MAE: 3.7604\n", 812 | "\n", 813 | "Train Epoch: 24\t Learning rate: 0.0005\t Loss: 2.452944\t \n", 814 | "\n", 815 | "Test set: Average loss: 3.7188 \t MAE: 4.2055\n", 816 | "\n", 817 | "Train Epoch: 25\t Learning rate: 0.0005\t Loss: 2.329042\t \n", 818 | "\n", 819 | "Test set: Average loss: 3.4484 \t MAE: 3.8869\n", 820 | "\n", 821 | "Train Epoch: 26\t Learning rate: 0.0005\t Loss: 2.495918\t \n", 822 | "\n", 823 | "Test set: Average loss: 3.8797 \t MAE: 4.3500\n", 824 | "\n", 825 | "Train Epoch: 27\t Learning rate: 0.0005\t Loss: 2.272594\t \n", 826 | "\n", 827 | "Test set: Average loss: 3.8090 \t MAE: 4.2980\n", 828 | "\n", 829 | "Train Epoch: 28\t Learning rate: 0.0005\t Loss: 2.619707\t \n", 830 | "\n", 831 | "Test set: Average loss: 3.6732 \t MAE: 4.1515\n", 832 | "\n", 833 | "Train Epoch: 29\t Learning rate: 0.0005\t Loss: 2.312229\t \n", 834 | "\n", 835 | "Test set: Average loss: 3.6419 \t MAE: 4.1090\n", 836 | "\n", 837 | "Train Epoch: 30\t Learning rate: 0.0005\t Loss: 2.056890\t \n", 838 | "\n", 839 | "Test set: Average loss: 4.1940 \t MAE: 4.6646\n", 840 | "\n", 841 | "Train Epoch: 31\t Learning rate: 0.0005\t Loss: 2.042540\t \n", 842 | "\n", 843 | "Test set: 
Average loss: 3.8574 \t MAE: 4.3044\n", 844 | "\n", 845 | "Train Epoch: 32\t Learning rate: 0.0005\t Loss: 2.110537\t \n", 846 | "\n", 847 | "Test set: Average loss: 3.8590 \t MAE: 4.3267\n", 848 | "\n", 849 | "Train Epoch: 33\t Learning rate: 0.0005\t Loss: 2.181078\t \n", 850 | "\n", 851 | "Test set: Average loss: 4.0800 \t MAE: 4.5380\n", 852 | "\n", 853 | "Train Epoch: 34\t Learning rate: 0.0005\t Loss: 1.838920\t \n", 854 | "\n", 855 | "Test set: Average loss: 3.6816 \t MAE: 4.1475\n", 856 | "\n", 857 | "Saved as BiLSTM_reg_128_4.15.pt\n", 858 | "****************************************************************\n", 859 | "model saved: f1: 4.147465705871582\tacc: 1.8389204747620083\n", 860 | "****************************************************************\n", 861 | "Train Epoch: 35\t Learning rate: 0.0005\t Loss: 1.695412\t \n", 862 | "\n", 863 | "Test set: Average loss: 4.4402 \t MAE: 4.9363\n", 864 | "\n", 865 | "Train Epoch: 36\t Learning rate: 0.0005\t Loss: 2.076556\t \n", 866 | "\n", 867 | "Test set: Average loss: 4.0873 \t MAE: 4.5635\n", 868 | "\n", 869 | "Train Epoch: 37\t Learning rate: 0.0005\t Loss: 1.879723\t \n", 870 | "\n", 871 | "Test set: Average loss: 4.5997 \t MAE: 5.0551\n", 872 | "\n", 873 | "Train Epoch: 38\t Learning rate: 0.0005\t Loss: 1.973436\t \n", 874 | "\n", 875 | "Test set: Average loss: 4.1427 \t MAE: 4.6181\n", 876 | "\n", 877 | "Train Epoch: 39\t Learning rate: 0.0005\t Loss: 1.913318\t \n", 878 | "\n", 879 | "Test set: Average loss: 4.1102 \t MAE: 4.5848\n", 880 | "\n", 881 | "Train Epoch: 40\t Learning rate: 0.0005\t Loss: 1.856847\t \n", 882 | "\n", 883 | "Test set: Average loss: 4.1139 \t MAE: 4.5657\n", 884 | "\n", 885 | "Train Epoch: 41\t Learning rate: 0.0005\t Loss: 1.771701\t \n", 886 | "\n", 887 | "Test set: Average loss: 3.9007 \t MAE: 4.3717\n", 888 | "\n", 889 | "Train Epoch: 42\t Learning rate: 0.0005\t Loss: 1.764812\t \n", 890 | "\n", 891 | "Test set: Average loss: 3.7831 \t MAE: 4.2577\n", 892 | "\n", 893 | "Train Epoch: 43\t Learning rate: 0.0005\t Loss: 1.636393\t \n", 894 | "\n", 895 | "Test set: Average loss: 3.9092 \t MAE: 4.3770\n", 896 | "\n", 897 | "Train Epoch: 44\t Learning rate: 0.0005\t Loss: 1.874798\t \n", 898 | "\n", 899 | "Test set: Average loss: 4.0795 \t MAE: 4.5496\n", 900 | "\n", 901 | "Train Epoch: 45\t Learning rate: 0.0005\t Loss: 1.580698\t \n", 902 | "\n", 903 | "Test set: Average loss: 4.1065 \t MAE: 4.5840\n", 904 | "\n", 905 | "Train Epoch: 46\t Learning rate: 0.0005\t Loss: 1.548155\t \n", 906 | "\n", 907 | "Test set: Average loss: 4.0280 \t MAE: 4.5011\n", 908 | "\n", 909 | "Train Epoch: 47\t Learning rate: 0.0005\t Loss: 1.400702\t \n", 910 | "\n", 911 | "Test set: Average loss: 4.2299 \t MAE: 4.6754\n", 912 | "\n", 913 | "Train Epoch: 48\t Learning rate: 0.0005\t Loss: 1.549356\t \n", 914 | "\n", 915 | "Test set: Average loss: 4.5074 \t MAE: 4.9787\n", 916 | "\n", 917 | "Train Epoch: 49\t Learning rate: 0.0005\t Loss: 1.574735\t \n", 918 | "\n", 919 | "Test set: Average loss: 4.5332 \t MAE: 4.9923\n", 920 | "\n", 921 | "Train Epoch: 50\t Learning rate: 0.0005\t Loss: 1.354524\t \n", 922 | "\n", 923 | "Test set: Average loss: 4.7064 \t MAE: 5.1539\n", 924 | "\n", 925 | "Train Epoch: 51\t Learning rate: 0.0005\t Loss: 2.049156\t \n", 926 | "\n", 927 | "Test set: Average loss: 4.1546 \t MAE: 4.6241\n", 928 | "\n", 929 | "Train Epoch: 52\t Learning rate: 0.0005\t Loss: 1.527975\t \n", 930 | "\n", 931 | "Test set: Average loss: 4.0918 \t MAE: 4.5543\n", 932 | "\n", 933 | "Train Epoch: 53\t Learning rate: 
0.0005\t Loss: 1.346940\t \n", 934 | "\n", 935 | "Test set: Average loss: 4.0010 \t MAE: 4.4820\n", 936 | "\n", 937 | "Train Epoch: 54\t Learning rate: 0.0005\t Loss: 1.353701\t \n", 938 | "\n", 939 | "Test set: Average loss: 4.3962 \t MAE: 4.8609\n", 940 | "\n", 941 | "Train Epoch: 55\t Learning rate: 0.0005\t Loss: 1.306012\t \n", 942 | "\n", 943 | "Test set: Average loss: 4.2128 \t MAE: 4.6894\n", 944 | "\n", 945 | "Train Epoch: 56\t Learning rate: 0.0005\t Loss: 1.371004\t \n", 946 | "\n", 947 | "Test set: Average loss: 4.1537 \t MAE: 4.6346\n", 948 | "\n", 949 | "Train Epoch: 57\t Learning rate: 0.0005\t Loss: 1.368044\t \n", 950 | "\n", 951 | "Test set: Average loss: 4.2551 \t MAE: 4.7531\n", 952 | "\n", 953 | "Train Epoch: 58\t Learning rate: 0.0005\t Loss: 1.447491\t \n", 954 | "\n", 955 | "Test set: Average loss: 4.0097 \t MAE: 4.4771\n", 956 | "\n", 957 | "Train Epoch: 59\t Learning rate: 0.0005\t Loss: 1.247104\t \n", 958 | "\n", 959 | "Test set: Average loss: 3.9311 \t MAE: 4.4053\n", 960 | "\n", 961 | "Train Epoch: 60\t Learning rate: 0.0005\t Loss: 1.315407\t \n", 962 | "\n", 963 | "Test set: Average loss: 4.3192 \t MAE: 4.7980\n", 964 | "\n", 965 | "Train Epoch: 61\t Learning rate: 0.0005\t Loss: 1.247827\t \n", 966 | "\n", 967 | "Test set: Average loss: 4.2105 \t MAE: 4.6818\n", 968 | "\n", 969 | "Train Epoch: 62\t Learning rate: 0.0005\t Loss: 1.415605\t \n", 970 | "\n", 971 | "Test set: Average loss: 4.3634 \t MAE: 4.8293\n", 972 | "\n", 973 | "Train Epoch: 63\t Learning rate: 0.0005\t Loss: 1.222708\t \n", 974 | "\n", 975 | "Test set: Average loss: 4.5952 \t MAE: 5.0836\n", 976 | "\n", 977 | "Train Epoch: 64\t Learning rate: 0.0005\t Loss: 1.311415\t \n", 978 | "\n", 979 | "Test set: Average loss: 4.2124 \t MAE: 4.6954\n", 980 | "\n", 981 | "Train Epoch: 65\t Learning rate: 0.0005\t Loss: 1.487544\t \n", 982 | "\n", 983 | "Test set: Average loss: 4.0394 \t MAE: 4.4944\n", 984 | "\n", 985 | "Train Epoch: 66\t Learning rate: 0.0005\t Loss: 1.071867\t \n", 986 | "\n", 987 | "Test set: Average loss: 3.9212 \t MAE: 4.3847\n", 988 | "\n", 989 | "Train Epoch: 67\t Learning rate: 0.0005\t Loss: 1.151054\t \n", 990 | "\n", 991 | "Test set: Average loss: 4.4762 \t MAE: 4.9519\n", 992 | "\n", 993 | "Train Epoch: 68\t Learning rate: 0.0005\t Loss: 1.309198\t \n", 994 | "\n", 995 | "Test set: Average loss: 4.4172 \t MAE: 4.8880\n", 996 | "\n", 997 | "Train Epoch: 69\t Learning rate: 0.0005\t Loss: 1.121209\t \n", 998 | "\n", 999 | "Test set: Average loss: 4.4038 \t MAE: 4.8850\n", 1000 | "\n", 1001 | "Train Epoch: 70\t Learning rate: 0.0005\t Loss: 1.122766\t \n", 1002 | "\n", 1003 | "Test set: Average loss: 3.9661 \t MAE: 4.4481\n", 1004 | "\n", 1005 | "Train Epoch: 71\t Learning rate: 0.0005\t Loss: 1.306642\t \n", 1006 | "\n", 1007 | "Test set: Average loss: 4.0182 \t MAE: 4.4837\n", 1008 | "\n", 1009 | "Train Epoch: 72\t Learning rate: 0.0005\t Loss: 1.271639\t \n", 1010 | "\n", 1011 | "Test set: Average loss: 4.3392 \t MAE: 4.7997\n", 1012 | "\n", 1013 | "Train Epoch: 73\t Learning rate: 0.0005\t Loss: 1.111040\t \n", 1014 | "\n", 1015 | "Test set: Average loss: 4.1592 \t MAE: 4.6452\n", 1016 | "\n", 1017 | "Train Epoch: 74\t Learning rate: 0.0005\t Loss: 1.092647\t \n", 1018 | "\n", 1019 | "Test set: Average loss: 4.5354 \t MAE: 5.0171\n", 1020 | "\n", 1021 | "Train Epoch: 75\t Learning rate: 0.0005\t Loss: 1.071747\t \n", 1022 | "\n", 1023 | "Test set: Average loss: 4.2062 \t MAE: 4.6748\n", 1024 | "\n", 1025 | "Train Epoch: 76\t Learning rate: 0.0005\t Loss: 1.139450\t \n", 
1026 | "\n", 1027 | "Test set: Average loss: 3.7759 \t MAE: 4.2505\n", 1028 | "\n", 1029 | "Train Epoch: 77\t Learning rate: 0.0005\t Loss: 1.193882\t \n", 1030 | "\n", 1031 | "Test set: Average loss: 3.6220 \t MAE: 4.0822\n", 1032 | "\n", 1033 | "Saved as BiLSTM_reg_128_4.08.pt\n", 1034 | "****************************************************************\n", 1035 | "model saved: f1: 4.082216739654541\tacc: 1.1938817671367101\n", 1036 | "****************************************************************\n" 1037 | ] 1038 | }, 1039 | { 1040 | "name": "stdout", 1041 | "output_type": "stream", 1042 | "text": [ 1043 | "Train Epoch: 78\t Learning rate: 0.0005\t Loss: 1.062055\t \n", 1044 | "\n", 1045 | "Test set: Average loss: 3.9293 \t MAE: 4.4046\n", 1046 | "\n", 1047 | "Train Epoch: 79\t Learning rate: 0.0005\t Loss: 1.161488\t \n", 1048 | "\n", 1049 | "Test set: Average loss: 4.0311 \t MAE: 4.4819\n", 1050 | "\n", 1051 | "Train Epoch: 80\t Learning rate: 0.0005\t Loss: 1.155397\t \n", 1052 | "\n", 1053 | "Test set: Average loss: 4.4762 \t MAE: 4.9510\n", 1054 | "\n", 1055 | "Train Epoch: 81\t Learning rate: 0.0005\t Loss: 1.124911\t \n", 1056 | "\n", 1057 | "Test set: Average loss: 3.9555 \t MAE: 4.4218\n", 1058 | "\n", 1059 | "Train Epoch: 82\t Learning rate: 0.0005\t Loss: 1.087112\t \n", 1060 | "\n", 1061 | "Test set: Average loss: 3.9758 \t MAE: 4.4648\n", 1062 | "\n", 1063 | "Train Epoch: 83\t Learning rate: 0.0005\t Loss: 1.143451\t \n", 1064 | "\n", 1065 | "Test set: Average loss: 4.1492 \t MAE: 4.6386\n", 1066 | "\n", 1067 | "Train Epoch: 84\t Learning rate: 0.0005\t Loss: 1.021506\t \n", 1068 | "\n", 1069 | "Test set: Average loss: 4.2033 \t MAE: 4.6706\n", 1070 | "\n", 1071 | "Train Epoch: 85\t Learning rate: 0.0005\t Loss: 1.397525\t \n", 1072 | "\n", 1073 | "Test set: Average loss: 4.3981 \t MAE: 4.8657\n", 1074 | "\n", 1075 | "Train Epoch: 86\t Learning rate: 0.0005\t Loss: 0.968711\t \n", 1076 | "\n", 1077 | "Test set: Average loss: 3.7188 \t MAE: 4.2074\n", 1078 | "\n", 1079 | "Train Epoch: 87\t Learning rate: 0.0005\t Loss: 1.095073\t \n", 1080 | "\n", 1081 | "Test set: Average loss: 3.9837 \t MAE: 4.4526\n", 1082 | "\n", 1083 | "Train Epoch: 88\t Learning rate: 0.0005\t Loss: 1.061223\t \n", 1084 | "\n", 1085 | "Test set: Average loss: 4.3055 \t MAE: 4.7557\n", 1086 | "\n", 1087 | "Train Epoch: 89\t Learning rate: 0.0005\t Loss: 0.967244\t \n", 1088 | "\n", 1089 | "Test set: Average loss: 3.9507 \t MAE: 4.4074\n", 1090 | "\n", 1091 | "Train Epoch: 90\t Learning rate: 0.0005\t Loss: 1.116576\t \n", 1092 | "\n", 1093 | "Test set: Average loss: 4.2355 \t MAE: 4.7020\n", 1094 | "\n", 1095 | "Train Epoch: 91\t Learning rate: 0.0005\t Loss: 1.129589\t \n", 1096 | "\n", 1097 | "Test set: Average loss: 4.2573 \t MAE: 4.7175\n", 1098 | "\n", 1099 | "Train Epoch: 92\t Learning rate: 0.0005\t Loss: 0.998689\t \n", 1100 | "\n", 1101 | "Test set: Average loss: 4.2436 \t MAE: 4.7121\n", 1102 | "\n", 1103 | "Train Epoch: 93\t Learning rate: 0.0005\t Loss: 1.028936\t \n", 1104 | "\n", 1105 | "Test set: Average loss: 4.2015 \t MAE: 4.6581\n", 1106 | "\n", 1107 | "Train Epoch: 94\t Learning rate: 0.0005\t Loss: 1.025468\t \n", 1108 | "\n", 1109 | "Test set: Average loss: 4.2462 \t MAE: 4.6956\n", 1110 | "\n", 1111 | "Train Epoch: 95\t Learning rate: 0.0005\t Loss: 0.973014\t \n", 1112 | "\n", 1113 | "Test set: Average loss: 3.9069 \t MAE: 4.3618\n", 1114 | "\n", 1115 | "Train Epoch: 96\t Learning rate: 0.0005\t Loss: 0.917344\t \n", 1116 | "\n", 1117 | "Test set: Average loss: 4.2787 \t MAE: 
4.7302\n", 1118 | "\n", 1119 | "Train Epoch: 97\t Learning rate: 0.0005\t Loss: 1.120929\t \n", 1120 | "\n", 1121 | "Test set: Average loss: 4.3165 \t MAE: 4.7986\n", 1122 | "\n", 1123 | "Train Epoch: 98\t Learning rate: 0.0005\t Loss: 0.962194\t \n", 1124 | "\n", 1125 | "Test set: Average loss: 3.8568 \t MAE: 4.3219\n", 1126 | "\n", 1127 | "Train Epoch: 99\t Learning rate: 0.0005\t Loss: 1.249937\t \n", 1128 | "\n", 1129 | "Test set: Average loss: 4.0346 \t MAE: 4.4936\n", 1130 | "\n", 1131 | "Train Epoch: 100\t Learning rate: 0.0005\t Loss: 1.070985\t \n", 1132 | "\n", 1133 | "Test set: Average loss: 4.2920 \t MAE: 4.7649\n", 1134 | "\n", 1135 | "Train Epoch: 101\t Learning rate: 0.0005\t Loss: 0.921444\t \n", 1136 | "\n", 1137 | "Test set: Average loss: 4.0885 \t MAE: 4.5694\n", 1138 | "\n", 1139 | "Train Epoch: 102\t Learning rate: 0.0005\t Loss: 1.260023\t \n", 1140 | "\n", 1141 | "Test set: Average loss: 4.1256 \t MAE: 4.5917\n", 1142 | "\n", 1143 | "Train Epoch: 103\t Learning rate: 0.0005\t Loss: 1.159951\t \n", 1144 | "\n", 1145 | "Test set: Average loss: 4.2488 \t MAE: 4.7348\n", 1146 | "\n", 1147 | "Train Epoch: 104\t Learning rate: 0.0005\t Loss: 1.322786\t \n", 1148 | "\n", 1149 | "Test set: Average loss: 4.3939 \t MAE: 4.8783\n", 1150 | "\n", 1151 | "Train Epoch: 105\t Learning rate: 0.0005\t Loss: 0.962339\t \n", 1152 | "\n", 1153 | "Test set: Average loss: 4.4167 \t MAE: 4.8781\n", 1154 | "\n", 1155 | "Train Epoch: 106\t Learning rate: 0.0005\t Loss: 0.944769\t \n", 1156 | "\n", 1157 | "Test set: Average loss: 4.1260 \t MAE: 4.6052\n", 1158 | "\n", 1159 | "Train Epoch: 107\t Learning rate: 0.0005\t Loss: 1.055879\t \n", 1160 | "\n", 1161 | "Test set: Average loss: 4.4353 \t MAE: 4.8979\n", 1162 | "\n", 1163 | "Train Epoch: 108\t Learning rate: 0.0005\t Loss: 0.929050\t \n", 1164 | "\n", 1165 | "Test set: Average loss: 4.3397 \t MAE: 4.7854\n", 1166 | "\n", 1167 | "Train Epoch: 109\t Learning rate: 0.0005\t Loss: 0.961483\t \n", 1168 | "\n", 1169 | "Test set: Average loss: 3.9624 \t MAE: 4.4224\n", 1170 | "\n", 1171 | "Train Epoch: 110\t Learning rate: 0.0005\t Loss: 0.946940\t \n", 1172 | "\n", 1173 | "Test set: Average loss: 4.2200 \t MAE: 4.6670\n", 1174 | "\n", 1175 | "Train Epoch: 111\t Learning rate: 0.0005\t Loss: 0.918805\t \n", 1176 | "\n", 1177 | "Test set: Average loss: 3.9667 \t MAE: 4.4206\n", 1178 | "\n", 1179 | "Train Epoch: 112\t Learning rate: 0.0005\t Loss: 1.147821\t \n", 1180 | "\n", 1181 | "Test set: Average loss: 4.0068 \t MAE: 4.4806\n", 1182 | "\n", 1183 | "Train Epoch: 113\t Learning rate: 0.0005\t Loss: 1.246980\t \n", 1184 | "\n", 1185 | "Test set: Average loss: 4.1982 \t MAE: 4.6641\n", 1186 | "\n", 1187 | "Train Epoch: 114\t Learning rate: 0.0005\t Loss: 1.113731\t \n", 1188 | "\n", 1189 | "Test set: Average loss: 4.4716 \t MAE: 4.9515\n", 1190 | "\n", 1191 | "Train Epoch: 115\t Learning rate: 0.0005\t Loss: 0.942133\t \n", 1192 | "\n", 1193 | "Test set: Average loss: 3.8472 \t MAE: 4.3116\n", 1194 | "\n", 1195 | "Train Epoch: 116\t Learning rate: 0.0005\t Loss: 0.836111\t \n", 1196 | "\n", 1197 | "Test set: Average loss: 3.7778 \t MAE: 4.2443\n", 1198 | "\n", 1199 | "Train Epoch: 117\t Learning rate: 0.0005\t Loss: 1.042380\t \n", 1200 | "\n", 1201 | "Test set: Average loss: 3.8928 \t MAE: 4.3669\n", 1202 | "\n", 1203 | "Train Epoch: 118\t Learning rate: 0.0005\t Loss: 0.888997\t \n", 1204 | "\n", 1205 | "Test set: Average loss: 4.3176 \t MAE: 4.7722\n", 1206 | "\n", 1207 | "Train Epoch: 119\t Learning rate: 0.0005\t Loss: 0.863362\t \n", 1208 | 
"\n", 1209 | "Test set: Average loss: 3.7324 \t MAE: 4.2109\n", 1210 | "\n", 1211 | "Train Epoch: 120\t Learning rate: 0.0005\t Loss: 1.225582\t \n", 1212 | "\n", 1213 | "Test set: Average loss: 4.1852 \t MAE: 4.6686\n", 1214 | "\n", 1215 | "Train Epoch: 121\t Learning rate: 0.0005\t Loss: 1.197303\t \n", 1216 | "\n", 1217 | "Test set: Average loss: 4.4508 \t MAE: 4.9219\n", 1218 | "\n", 1219 | "Train Epoch: 122\t Learning rate: 0.0005\t Loss: 1.091303\t \n", 1220 | "\n", 1221 | "Test set: Average loss: 4.3728 \t MAE: 4.8488\n", 1222 | "\n", 1223 | "Train Epoch: 123\t Learning rate: 0.0005\t Loss: 1.260954\t \n", 1224 | "\n", 1225 | "Test set: Average loss: 4.1683 \t MAE: 4.6418\n", 1226 | "\n", 1227 | "Train Epoch: 124\t Learning rate: 0.0005\t Loss: 0.879488\t \n", 1228 | "\n", 1229 | "Test set: Average loss: 3.9878 \t MAE: 4.4624\n", 1230 | "\n", 1231 | "Train Epoch: 125\t Learning rate: 0.0005\t Loss: 1.115321\t \n", 1232 | "\n", 1233 | "Test set: Average loss: 3.9576 \t MAE: 4.4411\n", 1234 | "\n", 1235 | "Train Epoch: 126\t Learning rate: 0.0005\t Loss: 0.900170\t \n", 1236 | "\n", 1237 | "Test set: Average loss: 4.1885 \t MAE: 4.6625\n", 1238 | "\n", 1239 | "Train Epoch: 127\t Learning rate: 0.0005\t Loss: 0.914275\t \n", 1240 | "\n", 1241 | "Test set: Average loss: 4.1754 \t MAE: 4.6519\n", 1242 | "\n", 1243 | "Train Epoch: 128\t Learning rate: 0.0005\t Loss: 0.845138\t \n", 1244 | "\n", 1245 | "Test set: Average loss: 4.3401 \t MAE: 4.8130\n", 1246 | "\n", 1247 | "Train Epoch: 129\t Learning rate: 0.0005\t Loss: 1.078165\t \n", 1248 | "\n", 1249 | "Test set: Average loss: 4.2290 \t MAE: 4.7133\n", 1250 | "\n", 1251 | "Train Epoch: 130\t Learning rate: 0.0005\t Loss: 0.893762\t \n", 1252 | "\n", 1253 | "Test set: Average loss: 4.2804 \t MAE: 4.7473\n", 1254 | "\n", 1255 | "Train Epoch: 131\t Learning rate: 0.0005\t Loss: 0.898790\t \n", 1256 | "\n", 1257 | "Test set: Average loss: 4.3800 \t MAE: 4.8504\n", 1258 | "\n", 1259 | "Train Epoch: 132\t Learning rate: 0.0005\t Loss: 0.943924\t \n", 1260 | "\n", 1261 | "Test set: Average loss: 4.0918 \t MAE: 4.5740\n", 1262 | "\n", 1263 | "Train Epoch: 133\t Learning rate: 0.0005\t Loss: 0.892788\t \n", 1264 | "\n", 1265 | "Test set: Average loss: 4.2037 \t MAE: 4.6698\n", 1266 | "\n", 1267 | "Train Epoch: 134\t Learning rate: 0.0005\t Loss: 0.993899\t \n", 1268 | "\n", 1269 | "Test set: Average loss: 4.5271 \t MAE: 4.9923\n", 1270 | "\n", 1271 | "Train Epoch: 135\t Learning rate: 0.0005\t Loss: 1.000830\t \n", 1272 | "\n", 1273 | "Test set: Average loss: 3.9788 \t MAE: 4.4592\n", 1274 | "\n", 1275 | "Train Epoch: 136\t Learning rate: 0.0005\t Loss: 0.865986\t \n", 1276 | "\n", 1277 | "Test set: Average loss: 3.9782 \t MAE: 4.4541\n", 1278 | "\n", 1279 | "Train Epoch: 137\t Learning rate: 0.0005\t Loss: 0.778416\t \n", 1280 | "\n", 1281 | "Test set: Average loss: 4.0501 \t MAE: 4.5125\n", 1282 | "\n", 1283 | "Train Epoch: 138\t Learning rate: 0.0005\t Loss: 0.899644\t \n", 1284 | "\n", 1285 | "Test set: Average loss: 3.9759 \t MAE: 4.4407\n", 1286 | "\n", 1287 | "Train Epoch: 139\t Learning rate: 0.0005\t Loss: 0.858805\t \n", 1288 | "\n", 1289 | "Test set: Average loss: 4.1885 \t MAE: 4.6453\n", 1290 | "\n", 1291 | "Train Epoch: 140\t Learning rate: 0.0005\t Loss: 0.980515\t \n", 1292 | "\n", 1293 | "Test set: Average loss: 4.2229 \t MAE: 4.6874\n", 1294 | "\n", 1295 | "Train Epoch: 141\t Learning rate: 0.0005\t Loss: 0.990140\t \n", 1296 | "\n", 1297 | "Test set: Average loss: 4.2213 \t MAE: 4.6812\n", 1298 | "\n", 1299 | "Train Epoch: 
142\t Learning rate: 0.0005\t Loss: 1.142093\t \n", 1300 | "\n", 1301 | "Test set: Average loss: 4.4396 \t MAE: 4.9052\n", 1302 | "\n", 1303 | "Train Epoch: 143\t Learning rate: 0.0005\t Loss: 0.918046\t \n", 1304 | "\n", 1305 | "Test set: Average loss: 4.5148 \t MAE: 4.9913\n", 1306 | "\n", 1307 | "Train Epoch: 144\t Learning rate: 0.0005\t Loss: 0.942595\t \n", 1308 | "\n", 1309 | "Test set: Average loss: 4.2595 \t MAE: 4.7389\n", 1310 | "\n", 1311 | "Train Epoch: 145\t Learning rate: 0.0005\t Loss: 0.875068\t \n", 1312 | "\n", 1313 | "Test set: Average loss: 3.9844 \t MAE: 4.4673\n", 1314 | "\n", 1315 | "Train Epoch: 146\t Learning rate: 0.0005\t Loss: 0.932485\t \n", 1316 | "\n", 1317 | "Test set: Average loss: 4.4675 \t MAE: 4.9310\n", 1318 | "\n", 1319 | "Train Epoch: 147\t Learning rate: 0.0005\t Loss: 1.047656\t \n", 1320 | "\n", 1321 | "Test set: Average loss: 4.2100 \t MAE: 4.6831\n", 1322 | "\n", 1323 | "Train Epoch: 148\t Learning rate: 0.0005\t Loss: 0.835819\t \n", 1324 | "\n", 1325 | "Test set: Average loss: 4.0317 \t MAE: 4.5118\n", 1326 | "\n", 1327 | "Train Epoch: 149\t Learning rate: 0.0005\t Loss: 0.980606\t \n", 1328 | "\n", 1329 | "Test set: Average loss: 4.3075 \t MAE: 4.7642\n", 1330 | "\n", 1331 | "Train Epoch: 150\t Learning rate: 0.0005\t Loss: 0.830669\t \n", 1332 | "\n", 1333 | "Test set: Average loss: 4.5313 \t MAE: 4.9935\n", 1334 | "\n", 1335 | "Train Epoch: 151\t Learning rate: 0.0005\t Loss: 0.893053\t \n", 1336 | "\n", 1337 | "Test set: Average loss: 4.4451 \t MAE: 4.9095\n", 1338 | "\n", 1339 | "Train Epoch: 152\t Learning rate: 0.0005\t Loss: 0.872342\t \n", 1340 | "\n", 1341 | "Test set: Average loss: 4.2925 \t MAE: 4.7543\n", 1342 | "\n", 1343 | "Train Epoch: 153\t Learning rate: 0.0005\t Loss: 0.983376\t \n", 1344 | "\n", 1345 | "Test set: Average loss: 4.1181 \t MAE: 4.5934\n", 1346 | "\n", 1347 | "Train Epoch: 154\t Learning rate: 0.0005\t Loss: 1.020965\t \n", 1348 | "\n", 1349 | "Test set: Average loss: 4.1233 \t MAE: 4.5878\n", 1350 | "\n", 1351 | "Train Epoch: 155\t Learning rate: 0.0005\t Loss: 0.784421\t \n", 1352 | "\n", 1353 | "Test set: Average loss: 4.4377 \t MAE: 4.9053\n", 1354 | "\n", 1355 | "Train Epoch: 156\t Learning rate: 0.0005\t Loss: 1.126911\t \n", 1356 | "\n", 1357 | "Test set: Average loss: 3.9217 \t MAE: 4.4043\n", 1358 | "\n" 1359 | ] 1360 | }, 1361 | { 1362 | "name": "stdout", 1363 | "output_type": "stream", 1364 | "text": [ 1365 | "Train Epoch: 157\t Learning rate: 0.0005\t Loss: 0.992670\t \n", 1366 | "\n", 1367 | "Test set: Average loss: 4.0878 \t MAE: 4.5467\n", 1368 | "\n", 1369 | "Train Epoch: 158\t Learning rate: 0.0005\t Loss: 0.865013\t \n", 1370 | "\n", 1371 | "Test set: Average loss: 4.3303 \t MAE: 4.7961\n", 1372 | "\n", 1373 | "Train Epoch: 159\t Learning rate: 0.0005\t Loss: 0.814098\t \n", 1374 | "\n", 1375 | "Test set: Average loss: 4.0651 \t MAE: 4.5495\n", 1376 | "\n", 1377 | "Train Epoch: 160\t Learning rate: 0.0005\t Loss: 0.843722\t \n", 1378 | "\n", 1379 | "Test set: Average loss: 3.9392 \t MAE: 4.4237\n", 1380 | "\n", 1381 | "Train Epoch: 161\t Learning rate: 0.0005\t Loss: 0.778094\t \n", 1382 | "\n", 1383 | "Test set: Average loss: 4.1093 \t MAE: 4.5677\n", 1384 | "\n", 1385 | "Train Epoch: 162\t Learning rate: 0.0005\t Loss: 0.760861\t \n", 1386 | "\n", 1387 | "Test set: Average loss: 4.0424 \t MAE: 4.4901\n", 1388 | "\n", 1389 | "Train Epoch: 163\t Learning rate: 0.0005\t Loss: 0.976117\t \n", 1390 | "\n", 1391 | "Test set: Average loss: 3.9120 \t MAE: 4.3788\n", 1392 | "\n", 1393 | "Train Epoch: 
164\t Learning rate: 0.0005\t Loss: 0.747449\t \n", 1394 | "\n", 1395 | "Test set: Average loss: 4.1145 \t MAE: 4.5898\n", 1396 | "\n", 1397 | "Train Epoch: 165\t Learning rate: 0.0005\t Loss: 0.960359\t \n", 1398 | "\n", 1399 | "Test set: Average loss: 4.3442 \t MAE: 4.8045\n", 1400 | "\n", 1401 | "Train Epoch: 166\t Learning rate: 0.0005\t Loss: 0.837411\t \n", 1402 | "\n", 1403 | "Test set: Average loss: 4.5189 \t MAE: 4.9757\n", 1404 | "\n", 1405 | "Train Epoch: 167\t Learning rate: 0.0005\t Loss: 0.917638\t \n", 1406 | "\n", 1407 | "Test set: Average loss: 4.2090 \t MAE: 4.6905\n", 1408 | "\n", 1409 | "Train Epoch: 168\t Learning rate: 0.0005\t Loss: 0.850689\t \n", 1410 | "\n", 1411 | "Test set: Average loss: 4.0824 \t MAE: 4.5608\n", 1412 | "\n", 1413 | "Train Epoch: 169\t Learning rate: 0.0005\t Loss: 0.974172\t \n", 1414 | "\n", 1415 | "Test set: Average loss: 4.0506 \t MAE: 4.5225\n", 1416 | "\n", 1417 | "Train Epoch: 170\t Learning rate: 0.0005\t Loss: 0.863295\t \n", 1418 | "\n", 1419 | "Test set: Average loss: 4.3983 \t MAE: 4.8468\n", 1420 | "\n", 1421 | "Train Epoch: 171\t Learning rate: 0.0005\t Loss: 0.931016\t \n", 1422 | "\n", 1423 | "Test set: Average loss: 4.1384 \t MAE: 4.6132\n", 1424 | "\n", 1425 | "Train Epoch: 172\t Learning rate: 0.0005\t Loss: 0.794198\t \n", 1426 | "\n", 1427 | "Test set: Average loss: 4.3596 \t MAE: 4.8161\n", 1428 | "\n", 1429 | "Train Epoch: 173\t Learning rate: 0.0005\t Loss: 0.820409\t \n", 1430 | "\n", 1431 | "Test set: Average loss: 4.1466 \t MAE: 4.6183\n", 1432 | "\n", 1433 | "Train Epoch: 174\t Learning rate: 0.0005\t Loss: 0.967401\t \n", 1434 | "\n", 1435 | "Test set: Average loss: 4.2012 \t MAE: 4.6818\n", 1436 | "\n", 1437 | "Train Epoch: 175\t Learning rate: 0.0005\t Loss: 0.931624\t \n", 1438 | "\n", 1439 | "Test set: Average loss: 4.0032 \t MAE: 4.4773\n", 1440 | "\n", 1441 | "Train Epoch: 176\t Learning rate: 0.0005\t Loss: 0.914719\t \n", 1442 | "\n", 1443 | "Test set: Average loss: 4.2812 \t MAE: 4.7415\n", 1444 | "\n", 1445 | "Train Epoch: 177\t Learning rate: 0.0005\t Loss: 0.853825\t \n", 1446 | "\n", 1447 | "Test set: Average loss: 4.1724 \t MAE: 4.6384\n", 1448 | "\n", 1449 | "Train Epoch: 178\t Learning rate: 0.0005\t Loss: 0.782554\t \n", 1450 | "\n", 1451 | "Test set: Average loss: 4.1872 \t MAE: 4.6614\n", 1452 | "\n", 1453 | "Train Epoch: 179\t Learning rate: 0.0005\t Loss: 0.790341\t \n", 1454 | "\n", 1455 | "Test set: Average loss: 4.3984 \t MAE: 4.8560\n", 1456 | "\n", 1457 | "Train Epoch: 180\t Learning rate: 0.0005\t Loss: 0.691040\t \n", 1458 | "\n", 1459 | "Test set: Average loss: 4.4033 \t MAE: 4.8658\n", 1460 | "\n", 1461 | "Train Epoch: 181\t Learning rate: 0.0005\t Loss: 0.850953\t \n", 1462 | "\n", 1463 | "Test set: Average loss: 3.8660 \t MAE: 4.3404\n", 1464 | "\n", 1465 | "Train Epoch: 182\t Learning rate: 0.0005\t Loss: 0.896042\t \n", 1466 | "\n", 1467 | "Test set: Average loss: 4.2568 \t MAE: 4.7351\n", 1468 | "\n", 1469 | "Train Epoch: 183\t Learning rate: 0.0005\t Loss: 0.960200\t \n", 1470 | "\n", 1471 | "Test set: Average loss: 4.3396 \t MAE: 4.7906\n", 1472 | "\n", 1473 | "Train Epoch: 184\t Learning rate: 0.0005\t Loss: 0.959877\t \n", 1474 | "\n", 1475 | "Test set: Average loss: 3.8674 \t MAE: 4.3530\n", 1476 | "\n", 1477 | "Train Epoch: 185\t Learning rate: 0.0005\t Loss: 0.725641\t \n", 1478 | "\n", 1479 | "Test set: Average loss: 3.9436 \t MAE: 4.4092\n", 1480 | "\n", 1481 | "Train Epoch: 186\t Learning rate: 0.0005\t Loss: 0.767895\t \n", 1482 | "\n", 1483 | "Test set: Average loss: 
4.2262 \t MAE: 4.6659\n", 1484 | "\n", 1485 | "Train Epoch: 187\t Learning rate: 0.0005\t Loss: 0.843481\t \n", 1486 | "\n", 1487 | "Test set: Average loss: 4.2961 \t MAE: 4.7634\n", 1488 | "\n", 1489 | "Train Epoch: 188\t Learning rate: 0.0005\t Loss: 0.874929\t \n", 1490 | "\n", 1491 | "Test set: Average loss: 4.0638 \t MAE: 4.5159\n", 1492 | "\n", 1493 | "Train Epoch: 189\t Learning rate: 0.0005\t Loss: 1.043584\t \n", 1494 | "\n", 1495 | "Test set: Average loss: 4.0349 \t MAE: 4.5109\n", 1496 | "\n", 1497 | "Train Epoch: 190\t Learning rate: 0.0005\t Loss: 0.935190\t \n", 1498 | "\n", 1499 | "Test set: Average loss: 4.0865 \t MAE: 4.5353\n", 1500 | "\n", 1501 | "Train Epoch: 191\t Learning rate: 0.0005\t Loss: 0.694541\t \n", 1502 | "\n", 1503 | "Test set: Average loss: 4.3420 \t MAE: 4.8012\n", 1504 | "\n", 1505 | "Train Epoch: 192\t Learning rate: 0.0005\t Loss: 0.865203\t \n", 1506 | "\n", 1507 | "Test set: Average loss: 3.9585 \t MAE: 4.4290\n", 1508 | "\n", 1509 | "Train Epoch: 193\t Learning rate: 0.0005\t Loss: 0.858346\t \n", 1510 | "\n", 1511 | "Test set: Average loss: 4.1437 \t MAE: 4.6031\n", 1512 | "\n", 1513 | "Train Epoch: 194\t Learning rate: 0.0005\t Loss: 0.906553\t \n", 1514 | "\n", 1515 | "Test set: Average loss: 4.4437 \t MAE: 4.9088\n", 1516 | "\n", 1517 | "Train Epoch: 195\t Learning rate: 0.0005\t Loss: 0.836493\t \n", 1518 | "\n", 1519 | "Test set: Average loss: 4.1875 \t MAE: 4.6559\n", 1520 | "\n", 1521 | "Train Epoch: 196\t Learning rate: 0.0005\t Loss: 0.852588\t \n", 1522 | "\n", 1523 | "Test set: Average loss: 4.4612 \t MAE: 4.9229\n", 1524 | "\n", 1525 | "Train Epoch: 197\t Learning rate: 0.0005\t Loss: 0.905962\t \n", 1526 | "\n", 1527 | "Test set: Average loss: 4.4976 \t MAE: 4.9639\n", 1528 | "\n", 1529 | "Train Epoch: 198\t Learning rate: 0.0005\t Loss: 0.915486\t \n", 1530 | "\n", 1531 | "Test set: Average loss: 4.2601 \t MAE: 4.7137\n", 1532 | "\n", 1533 | "Train Epoch: 199\t Learning rate: 0.0005\t Loss: 0.840349\t \n", 1534 | "\n", 1535 | "Test set: Average loss: 4.0902 \t MAE: 4.5491\n", 1536 | "\n", 1537 | "Train Epoch: 200\t Learning rate: 0.0005\t Loss: 0.837486\t \n", 1538 | "\n", 1539 | "Test set: Average loss: 4.2122 \t MAE: 4.6674\n", 1540 | "\n" 1541 | ] 1542 | } 1543 | ], 1544 | "source": [ 1545 | "for ep in range(1, config['epochs']):\n", 1546 | " train(ep)\n", 1547 | " tloss = evaluate(model)" 1548 | ] 1549 | }, 1550 | { 1551 | "cell_type": "code", 1552 | "execution_count": 9, 1553 | "metadata": {}, 1554 | "outputs": [ 1555 | { 1556 | "name": "stdout", 1557 | "output_type": "stream", 1558 | "text": [ 1559 | "Confusion Matrix:\n", 1560 | "[[10 2]\n", 1561 | " [ 2 21]]\n", 1562 | "\n", 1563 | "Test set: Average loss: 1.1024\n", 1564 | "Calculating additional test metrics...\n", 1565 | "Accuracy: 0.8857142857142857\n", 1566 | "Precision: 0.8333333333333334\n", 1567 | "Recall: 0.8333333333333334\n", 1568 | "F1-Score: 0.8333333333333334\n", 1569 | "\n", 1570 | "=========================================================================================\n" 1571 | ] 1572 | }, 1573 | { 1574 | "data": { 1575 | "text/plain": [ 1576 | "1.102392554283142" 1577 | ] 1578 | }, 1579 | "execution_count": 9, 1580 | "metadata": {}, 1581 | "output_type": "execute_result" 1582 | } 1583 | ], 1584 | "source": [ 1585 | "lstm_model = torch.load('/Users/apple/Downloads/depression/BiLSTM_elmo_128_0.83.pt')\n", 1586 | "model = BiLSTM(config)\n", 1587 | "model.load_state_dict(lstm_model.state_dict())\n", 1588 | "evaluate(model)" 1589 | ] 1590 | }, 1591 | { 1592 | 
"cell_type": "code", 1593 | "execution_count": 342, 1594 | "metadata": {}, 1595 | "outputs": [ 1596 | { 1597 | "name": "stdout", 1598 | "output_type": "stream", 1599 | "text": [ 1600 | "\n", 1601 | "Test set: Average loss: 3.4328 \t MAE: 3.8846\n", 1602 | "\n", 1603 | "Saved as BiLSTM_reg_128_3.88.pt\n", 1604 | "****************************************************************\n", 1605 | "model saved: f1: 3.8845977783203125\tacc: 0.8374860261877378\n", 1606 | "****************************************************************\n" 1607 | ] 1608 | }, 1609 | { 1610 | "data": { 1611 | "text/plain": [ 1612 | "3.432814359664917" 1613 | ] 1614 | }, 1615 | "execution_count": 342, 1616 | "metadata": {}, 1617 | "output_type": "execute_result" 1618 | } 1619 | ], 1620 | "source": [ 1621 | "lstm_model = torch.load('/Users/apple/Downloads/depression/BiLSTM_reg_128_3.88.pt')\n", 1622 | "model = BiLSTM(config)\n", 1623 | "model.load_state_dict(lstm_model.state_dict())\n", 1624 | "evaluate(model)" 1625 | ] 1626 | }, 1627 | { 1628 | "cell_type": "code", 1629 | "execution_count": null, 1630 | "metadata": {}, 1631 | "outputs": [], 1632 | "source": [] 1633 | } 1634 | ], 1635 | "metadata": { 1636 | "kernelspec": { 1637 | "display_name": "Python 3", 1638 | "language": "python", 1639 | "name": "python3" 1640 | }, 1641 | "language_info": { 1642 | "codemirror_mode": { 1643 | "name": "ipython", 1644 | "version": 3 1645 | }, 1646 | "file_extension": ".py", 1647 | "mimetype": "text/x-python", 1648 | "name": "python", 1649 | "nbconvert_exporter": "python", 1650 | "pygments_lexer": "ipython3", 1651 | "version": "3.7.3" 1652 | } 1653 | }, 1654 | "nbformat": 4, 1655 | "nbformat_minor": 2 1656 | } -------------------------------------------------------------------------------- /cnn_audio_reg_avid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import wave 4 | import librosa 5 | from python_speech_features import * 6 | import pickle 7 | 8 | import torch 9 | from torch.utils import data # 获取迭代数据 10 | from torch.autograd import Variable # 获取变量 11 | import torchvision 12 | import matplotlib.pyplot as plt 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from sklearn.metrics import confusion_matrix, mean_absolute_error, mean_squared_error 16 | from torch.nn import functional as F 17 | 18 | from keras.models import Sequential 19 | from keras.layers import Dense, Dropout, Activation, Flatten 20 | from keras.layers import Conv2D, MaxPooling2D 21 | 22 | prefix = '/Users/apple/Downloads/depression/' 23 | with open(prefix+'avid_info.pkl', 'rb') as f: 24 | split_info = pickle.load(f) 25 | 26 | train_split_num = split_info['train'][0] 27 | train_split_label = [int(x) for x in split_info['train'][1]] 28 | test_split_num = split_info['test'][0] 29 | test_split_label = [int(x) for x in split_info['test'][1]] 30 | # train_split_label = np.hstack((train_split_label, dev_split_label)) 31 | # train_split_num = np.hstack((train_split_num, dev_split_num)) 32 | 33 | def extract_features(number, audio_features, mode): 34 | wavefile = wave.open(prefix+'AViD/Audio/{1}/Trim/{0}.wav'.format(number, mode)) 35 | sr = wavefile.getframerate() 36 | nframes = wavefile.getnframes() 37 | wave_data = np.frombuffer(wavefile.readframes(nframes), dtype=np.short).astype(np.float) 38 | 39 | if len(wave_data) < sr*15: 40 | wave_data = np.hstack((wave_data, [1e-2]*(sr*15-len(wave_data)))) 41 | 42 | # 1分钟 43 | clip = sr*1*15 44 | melspec = 
librosa.feature.melspectrogram(wave_data[:clip], n_mels=80, sr=sr) 45 | audio_features.append(melspec) 46 | if sr == 32000: 47 | print(prefix+'AViD/Audio/{1}/Trim/{0}.wav'.format(number, mode)) 48 | 49 | # training set 50 | audio_features_train = [] 51 | # test set 52 | audio_features_test = [] 53 | 54 | # # training set 55 | # for index in range(len(train_split_num)): 56 | # extract_features(train_split_num[index], audio_features_train, 'Training') 57 | 58 | # # test set 59 | # for index in range(len(test_split_num)): 60 | # extract_features(test_split_num[index], audio_features_test, 'Testing') 61 | 62 | # print(np.shape(audio_features_train), np.shape(audio_features_test)) 63 | 64 | # print("Saving npz file locally...") 65 | # np.savez(prefix+'data/audio/train_samples_reg_avid.npz', audio_features_train) 66 | # np.savez(prefix+'data/audio/train_labels_reg_avid.npz', train_split_label) 67 | # np.savez(prefix+'data/audio/test_samples_reg_avid.npz', audio_features_test) 68 | # np.savez(prefix+'data/audio/test_labels_reg_avid.npz', test_split_label) 69 | 70 | audio_features_train = np.load(prefix+'data/audio/train_samples_reg_avid.npz', allow_pickle=True)['arr_0'] 71 | audio_ctargets_train = np.load(prefix+'data/audio/train_labels_reg_avid.npz', allow_pickle=True)['arr_0'] 72 | audio_features_test = np.load(prefix+'data/audio/test_samples_reg_avid.npz', allow_pickle=True)['arr_0'] 73 | audio_ctargets_test = np.load(prefix+'data/audio/test_labels_reg_avid.npz', allow_pickle=True)['arr_0'] 74 | 75 | config = { 76 | 'num_classes': 1, 77 | 'dropout': 0.5, 78 | 'rnn_layers': 2, 79 | 'embedding_size': 80, 80 | 'batch_size': 2, 81 | 'epochs': 30, 82 | 'learning_rate': 1e-3, 83 | 'cuda': False, 84 | } 85 | 86 | X_train = np.array(audio_features_train) 87 | Y_train = np.array(audio_ctargets_train) 88 | X_test = np.array(audio_features_test) 89 | Y_test = np.array(audio_ctargets_test) 90 | 91 | X_train = X_train.astype('float32') 92 | X_test = X_test.astype('float32') 93 | 94 | X_train = np.array([(X - X.min()) / (X.max() - X.min()) if X.max() != X.min() else X for X in X_train ]) 95 | X_test = np.array([(X - X.min()) / (X.max() - X.min()) if X.max() != X.min() else X for X in X_test ]) 96 | 97 | class CNN(nn.Module): 98 | def __init__(self): 99 | super(CNN, self).__init__() 100 | self.conv2d_1 = nn.Conv2d(1, 32, (1,7), 1) 101 | self.conv2d_2 = nn.Conv2d(32, 32, (1,7), 2) 102 | self.dense_1 = nn.Linear(87360, 128) 103 | self.dense_2 = nn.Linear(128, 128) 104 | self.dense_3 = nn.Linear(128, 1) 105 | self.dropout = nn.Dropout(0.5) 106 | 107 | def forward(self, x): 108 | x = F.relu(self.conv2d_1(x)) 109 | x = F.max_pool2d(x, (4, 3), (1, 3)) 110 | x = F.relu(self.conv2d_2(x)) 111 | x = F.max_pool2d(x, (1, 3), (1, 3)) 112 | # flatten in keras 113 | x = x.permute((0, 2, 3, 1)) 114 | x = x.contiguous().view(-1, 87360) 115 | x = F.relu(self.dense_1(x)) 116 | x = F.relu(self.dense_2(x)) 117 | x = self.dropout(x) 118 | output = self.dense_3(x) 119 | # output = torch.sigmoid(self.dense_3(x)) 120 | # output = torch.relu(self.dense_3(x)) 121 | return output 122 | 123 | model = CNN() 124 | # optimizer = optim.Adam(model.parameters(), lr=config['learning_rate']) 125 | optimizer = optim.Adam(model.parameters()) 126 | criterion = nn.SmoothL1Loss() 127 | 128 | def save(model, filename): 129 | save_filename = '{}.pt'.format(filename) 130 | torch.save(model, save_filename) 131 | print('Saved as %s' % save_filename) 132 | 133 | def train(epoch): 134 | global lr 135 | model.train() 136 | batch_idx = 1 137 | total_loss = 0 
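 # Shape bookkeeping for the 87360 flatten size in CNN.forward() above,
 # assuming 44.1 kHz audio (the sr == 32000 print in extract_features flags
 # the exceptions). A 15 s clip yields an 80-mel spectrogram of ~1292 frames
 # (hop 512); the (1,7) kernels only shrink the time axis, so the height goes
 # 80 -> 77 (pool (4,3)) -> 39 (stride-2 conv) while the width ends at 70,
 # giving 32 channels * 39 * 70 = 87360 inputs to dense_1.
 # The loop below batches X_train by hand rather than through a DataLoader;
 # the final slice is simply shorter when batch_size does not divide the set.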
138 | correct = 0
139 | for i in range(0, X_train.shape[0], config['batch_size']):
140 | if i + config['batch_size'] > X_train.shape[0]:
141 | x, y = X_train[i:], Y_train[i:]
142 | else:
143 | x, y = X_train[i:(i+config['batch_size'])], Y_train[i:(i+config['batch_size'])]
144 | if config['cuda']:
145 | x, y = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=True).cuda(), Variable(torch.from_numpy(y)).cuda()
146 | else:
147 | x, y = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=True), Variable(torch.from_numpy(y).type(torch.FloatTensor))
148 | # zero the parameter gradients
149 | optimizer.zero_grad()
150 | output = model(x.unsqueeze(1))  # add a channel dim: (N, 80, T) -> (N, 1, 80, T)
151 | loss = criterion(output, y.view_as(output))
152 | # backpropagate
153 | loss.backward()
154 | # update the network weights from the gradients
155 | optimizer.step()
156 | batch_idx += 1
157 | # loss.item() extracts the scalar loss value from the tensor
158 | total_loss += loss.item()
159 | 
160 | cur_loss = total_loss
161 | print('Train Epoch: {:2d}\t Learning rate: {:.4f}\t Loss: {:.6f} \n '.format(
162 | epoch, config['learning_rate'], cur_loss/batch_idx))
163 | 
164 | def evaluate(model):
165 | model.eval()
166 | batch_idx = 1
167 | total_loss = 0
168 | pred = np.array([])
169 | batch_size = config['batch_size']
170 | global min_mae
171 | for i in range(0, X_test.shape[0], batch_size):
172 | if i + batch_size > X_test.shape[0]:
173 | x, y = X_test[i:], Y_test[i:]
174 | else:
175 | x, y = X_test[i:(i+batch_size)], Y_test[i:(i+batch_size)]
176 | if config['cuda']:
177 | x, y = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=True).cuda(), Variable(torch.from_numpy(y)).cuda()
178 | else:
179 | x, y = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=True), Variable(torch.from_numpy(y).type(torch.FloatTensor))
180 | with torch.no_grad():
181 | output = model(x.unsqueeze(1))
182 | loss = criterion(output, y.view_as(output))
183 | pred = np.hstack((pred, output.flatten().numpy()))
184 | total_loss += loss.item()
185 | 
186 | # print(Y_test, pred)
187 | mae = mean_absolute_error(Y_test, pred)
188 | rmse = np.sqrt(mean_squared_error(Y_test, pred))
189 | print('MAE: {}\t RMSE: {}\n'.format(mae, rmse))
190 | print('='*89)
191 | 
192 | 
193 | if mae < min_mae and mae < 9.41:
194 | min_mae = mae
195 | save(model, 'cnn_reg_avid_{:.2f}'.format(mae))  # save() appends '.pt'
196 | return total_loss
197 | 
198 | min_mae = 100
199 | 
200 | # for ep in range(1, config['epochs']):
201 | # train(ep)
202 | # tloss = evaluate(model)
203 | 
204 | model = torch.load('/Users/apple/Downloads/depression/cnn_reg_avid_9.30.pt')
205 | # model = BiLSTM(config)
206 | # model.load_state_dict(lstm_model.state_dict())
207 | evaluate(model)
208 | 
209 | 
--------------------------------------------------------------------------------
/regression/BiLSTM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 220,
6 | "metadata": {
7 | "scrolled": true
8 | },
9 | "outputs": [
10 | {
11 | "name": "stdout",
12 | "output_type": "stream",
13 | "text": [
14 | "(36, 1024)\n",
15 | "(49, 1024)\n",
16 | "(46, 1024)\n",
17 | "(47, 1024)\n",
18 | "(42, 1024)\n",
19 | "(43, 1024)\n",
20 | "(48, 1024)\n",
21 | "(55, 1024)\n",
22 | "(50, 1024)\n",
23 | "(38, 1024)\n",
24 | "(44, 1024)\n",
25 | "(63, 1024)\n",
26 | "(59, 1024)\n",
27 | "(54, 1024)\n",
28 | "(44, 1024)\n",
29 | "(38, 1024)\n",
30 | "(56, 1024)\n",
31 | "(55, 1024)\n",
32 | "(41, 1024)\n",
33 | "(48, 1024)\n",
34 | "(46, 1024)\n",
35 | "(56, 1024)\n",
36 | "(41, 
1024)\n", 37 | "(49, 1024)\n", 38 | "(48, 1024)\n", 39 | "(47, 1024)\n", 40 | "(63, 1024)\n", 41 | "(55, 1024)\n", 42 | "(37, 1024)\n", 43 | "(50, 1024)\n", 44 | "(56, 1024)\n", 45 | "(44, 1024)\n", 46 | "(51, 1024)\n", 47 | "(32, 1024)\n", 48 | "(43, 1024)\n", 49 | "(38, 1024)\n", 50 | "(37, 1024)\n", 51 | "(30, 1024)\n", 52 | "(41, 1024)\n", 53 | "(33, 1024)\n", 54 | "(42, 1024)\n", 55 | "(33, 1024)\n", 56 | "(42, 1024)\n", 57 | "(34, 1024)\n", 58 | "(28, 1024)\n", 59 | "(31, 1024)\n", 60 | "(25, 1024)\n", 61 | "(54, 1024)\n", 62 | "(43, 1024)\n", 63 | "(39, 1024)\n", 64 | "(46, 1024)\n", 65 | "(40, 1024)\n", 66 | "(35, 1024)\n", 67 | "(38, 1024)\n", 68 | "(34, 1024)\n", 69 | "(48, 1024)\n", 70 | "(42, 1024)\n", 71 | "(47, 1024)\n", 72 | "(43, 1024)\n", 73 | "(39, 1024)\n", 74 | "(54, 1024)\n", 75 | "(53, 1024)\n", 76 | "(43, 1024)\n", 77 | "(41, 1024)\n", 78 | "(29, 1024)\n", 79 | "(53, 1024)\n", 80 | "(35, 1024)\n", 81 | "(44, 1024)\n", 82 | "(51, 1024)\n", 83 | "(37, 1024)\n", 84 | "(51, 1024)\n", 85 | "(40, 1024)\n", 86 | "(48, 1024)\n", 87 | "(47, 1024)\n", 88 | "(48, 1024)\n", 89 | "(51, 1024)\n", 90 | "(46, 1024)\n", 91 | "(47, 1024)\n", 92 | "(37, 1024)\n", 93 | "(44, 1024)\n", 94 | "(68, 1024)\n", 95 | "(52, 1024)\n", 96 | "(37, 1024)\n", 97 | "(43, 1024)\n", 98 | "(34, 1024)\n", 99 | "(38, 1024)\n", 100 | "(32, 1024)\n", 101 | "(33, 1024)\n", 102 | "(64, 1024)\n", 103 | "(34, 1024)\n", 104 | "(48, 1024)\n", 105 | "(47, 1024)\n", 106 | "(59, 1024)\n", 107 | "(48, 1024)\n", 108 | "(40, 1024)\n", 109 | "(35, 1024)\n", 110 | "(42, 1024)\n", 111 | "(54, 1024)\n", 112 | "(45, 1024)\n", 113 | "(37, 1024)\n", 114 | "(49, 1024)\n", 115 | "(54, 1024)\n", 116 | "(48, 1024)\n", 117 | "(43, 1024)\n", 118 | "(35, 1024)\n", 119 | "(35, 1024)\n", 120 | "(50, 1024)\n", 121 | "(38, 1024)\n", 122 | "(35, 1024)\n", 123 | "(42, 1024)\n", 124 | "(46, 1024)\n", 125 | "(42, 1024)\n", 126 | "(31, 1024)\n", 127 | "(43, 1024)\n", 128 | "(45, 1024)\n", 129 | "(47, 1024)\n", 130 | "(45, 1024)\n", 131 | "(62, 1024)\n", 132 | "(48, 1024)\n", 133 | "(46, 1024)\n", 134 | "(37, 1024)\n", 135 | "(43, 1024)\n", 136 | "(39, 1024)\n", 137 | "(41, 1024)\n", 138 | "(38, 1024)\n", 139 | "(42, 1024)\n", 140 | "(55, 1024)\n", 141 | "(33, 1024)\n", 142 | "(42, 1024)\n", 143 | "(33, 1024)\n", 144 | "(51, 1024)\n", 145 | "(30, 1024)\n", 146 | "(40, 1024)\n", 147 | "(43, 1024)\n", 148 | "(44, 1024)\n", 149 | "(39, 1024)\n", 150 | "(38, 1024)\n", 151 | "(43, 1024)\n", 152 | "(39, 1024)\n", 153 | "(49, 1024)\n", 154 | "(49, 1024)\n", 155 | "(40, 1024)\n", 156 | "(107,) (35,) (107,) (35,)\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "import numpy as np\n", 162 | "import pandas as pd\n", 163 | "import wave\n", 164 | "import librosa\n", 165 | "import re\n", 166 | "from allennlp.commands.elmo import ElmoEmbedder\n", 167 | "# from bert_serving.client import BertClient\n", 168 | "\n", 169 | "prefix = '/Users/apple/Downloads/depression/'\n", 170 | "\n", 171 | "elmo = ElmoEmbedder()\n", 172 | "# bc = BertClient(ip='100.66.165.12')\n", 173 | "\n", 174 | "train_split_df = pd.read_csv(prefix+'train_split_Depression_AVEC2017 (1).csv')\n", 175 | "test_split_df = pd.read_csv(prefix+'dev_split_Depression_AVEC2017.csv')\n", 176 | "train_split_num = train_split_df[['Participant_ID']]['Participant_ID'].tolist()\n", 177 | "test_split_num = test_split_df[['Participant_ID']]['Participant_ID'].tolist()\n", 178 | "# train_split_clabel = train_split_df[['PHQ8_Binary']]['PHQ8_Binary'].tolist()\n", 179 | "# test_split_clabel = 
test_split_df[['PHQ8_Binary']]['PHQ8_Binary'].tolist()\n",
180 | "train_split_rlabel = train_split_df[['PHQ8_Score']]['PHQ8_Score'].tolist()\n",
181 | "test_split_rlabel = test_split_df[['PHQ8_Score']]['PHQ8_Score'].tolist()\n",
182 | "\n",
183 | "topics = []\n",
184 | "with open('/Users/apple/Downloads/depression/queries.txt', 'r') as f:\n",
185 | " for line in f.readlines():\n",
186 | " topics.append(line.strip('\\n').strip())\n",
187 | " \n",
188 | "\n",
189 | "def identify_topics(sentence):\n",
190 | " if sentence in topics:\n",
191 | " return True\n",
192 | " return False\n",
193 | "\n",
194 | "def extract_features(number, text_features, target, mode, text_targets):\n",
195 | " \n",
196 | " transcript = pd.read_csv(prefix+'{0}_P/{0}_TRANSCRIPT.csv'.format(number), sep='\\t').fillna('')\n",
197 | " \n",
198 | " \n",
199 | " time_range = []\n",
200 | " responses = []\n",
201 | " response = ''\n",
202 | " response_flag = False\n",
203 | " start_time = 0\n",
204 | " stop_time = 0\n",
205 | "\n",
206 | " signal = []\n",
207 | " \n",
208 | " global counter1, counter2\n",
209 | "\n",
210 | " for t in transcript.itertuples():\n",
211 | " # an Ellie turn marks the end of the participant's current answer\n",
212 | " if getattr(t,'speaker') == 'Ellie':\n",
213 | "# if '(' in getattr(t,'value'):\n",
214 | "# content = re.findall(re.compile(r'[(](.*?)[)]', re.S), getattr(t,'value'))[0]\n",
215 | "# print(content)\n",
216 | "# else:\n",
217 | "# content = getattr(t,'value').strip()\n",
218 | " content = getattr(t,'value').strip()\n",
219 | " if identify_topics(content):\n",
220 | " response_flag = True\n",
221 | " if len(response) != 0:\n",
222 | " responses.append(response.strip())\n",
223 | " response = ''\n",
224 | " elif response_flag and len(content.split()) > 4:\n",
225 | " response_flag = False\n",
226 | " if len(response) != 0:\n",
227 | " responses.append(response)\n",
228 | " response = ''\n",
229 | " elif getattr(t,'speaker') == 'Participant':\n",
230 | " if 'scrubbed_entry' in getattr(t,'value'):\n",
231 | " continue\n",
232 | " elif response_flag:\n",
233 | " content = getattr(t,'value').split('\\n')[0].strip()\n",
234 | "# if '<' in getattr(t,'value'):\n",
235 | "# content = re.sub(u\"\\\\<.*?\\\\>\", \"\", content)\n",
236 | " response+=' '+content\n",
237 | " \n",
238 | " text_feature = elmo.embed_sentence(responses).mean(0)\n",
239 | "# text_feature = bc.encode(responses)\n",
240 | "# while text_feature.shape[0] < 30:\n",
241 | "# print(number)\n",
242 | "# text_feature = np.vstack((text_feature, np.zeros(text_feature.shape[1])))\n",
243 | " print(text_feature.shape)\n",
244 | "# text_features.append(text_feature[:30])\n",
245 | " text_features.append(text_feature)\n",
246 | " text_targets.append(target)\n",
247 | " \n",
248 | "def extract_features1(number, text_features, target, mode, text_targets):\n",
249 | " \n",
250 | " transcript = pd.read_csv(prefix+'{0}_P/{0}_TRANSCRIPT.csv'.format(number), sep='\\t').fillna('')\n",
251 | " \n",
252 | " \n",
253 | " time_range = []\n",
254 | " responses = []\n",
255 | " response = ''\n",
256 | " response_flag = False\n",
257 | " start_time = 0\n",
258 | " stop_time = 0\n",
259 | "\n",
260 | " signal = []\n",
261 | " \n",
262 | " global counter1, counter2\n",
263 | "\n",
264 | " for t in transcript.itertuples():\n",
265 | " # an Ellie turn marks the end of the participant's current answer\n",
266 | " if getattr(t,'speaker') == 'Ellie':\n",
267 | " if '(' in getattr(t,'value'):\n",
268 | " content = re.findall(re.compile(r'[(](.*?)[)]', re.S), getattr(t,'value'))[0]\n",
269 | " else:\n",
270 | " content = 
getattr(t,'value').strip()\n",
271 | " # content now holds the prompt text (parenthetical form already extracted above)\n",
272 | " if identify_topics(content):\n",
273 | " response_flag = True\n",
274 | " if len(response) != 0:\n",
275 | " responses.append(response.strip())\n",
276 | " response = ''\n",
277 | " elif response_flag and len(content.split()) > 4:\n",
278 | " response_flag = False\n",
279 | " if len(response) != 0:\n",
280 | " responses.append(response)\n",
281 | " response = ''\n",
282 | " elif getattr(t,'speaker') == 'Participant':\n",
283 | " if 'scrubbed_entry' in getattr(t,'value'):\n",
284 | " continue\n",
285 | " elif response_flag:\n",
286 | " response+=' '+getattr(t,'value').split('\\n')[0].strip()\n",
287 | " \n",
288 | " text_feature = elmo.embed_sentence(responses).mean(0)\n",
289 | " print(text_feature.shape)\n",
290 | " text_features.append(text_feature)\n",
291 | " if target == 1:\n",
292 | " counter1 += len(text_feature)\n",
293 | " else:\n",
294 | " counter2 += len(text_feature)\n",
295 | " text_targets.append(target)\n",
296 | "\n",
297 | "counter1 = 0\n",
298 | "counter2 = 0\n",
299 | " \n",
300 | "# training set\n",
301 | "text_features_train = []\n",
302 | "text_ctargets_train = []\n",
303 | "\n",
304 | "# test set\n",
305 | "text_features_test = []\n",
306 | "text_ctargets_test = []\n",
307 | "\n",
308 | "# ======================= classification =======================\n",
309 | "\n",
310 | "# training set\n",
311 | "# for index in range(len(train_split_num)):\n",
312 | "# extract_features(train_split_num[index], text_features_train, train_split_clabel[index], 'train', text_ctargets_train)\n",
313 | " \n",
314 | "# # test set\n",
315 | "# for index in range(len(test_split_num)):\n",
316 | "# extract_features(test_split_num[index], text_features_test, test_split_clabel[index], 'test', text_ctargets_test)\n",
317 | "\n",
318 | "# ======================= regression =======================\n",
319 | "\n",
320 | "# training set\n",
321 | "for index in range(len(train_split_num)):\n",
322 | " extract_features(train_split_num[index], text_features_train, train_split_rlabel[index], 'train', text_ctargets_train)\n",
323 | " \n",
324 | "# test set\n",
325 | "for index in range(len(test_split_num)):\n",
326 | " extract_features(test_split_num[index], text_features_test, test_split_rlabel[index], 'test', text_ctargets_test)\n",
327 | "print(np.shape(text_features_train), np.shape(text_features_test), np.shape(text_ctargets_train), np.shape(text_ctargets_test))\n"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": null,
333 | "metadata": {},
334 | "outputs": [],
335 | "source": [
336 | "# print(\"Saving npz file locally...\")\n",
337 | "\n",
338 | "# np.savez(prefix+'data/text/train_samples.npz', text_features_train)\n",
339 | "# np.savez(prefix+'data/text/test_samples.npz', text_features_test)\n",
340 | "# np.savez(prefix+'data/text/train_labels.npz', text_ctargets_train)\n",
341 | "# np.savez(prefix+'data/text/test_labels.npz', text_ctargets_test)\n",
342 | "\n",
343 | "prefix = '/Users/apple/Downloads/depression/'\n",
344 | "\n",
345 | "text_features_train = np.load(prefix+'data/text/train_samples.npz')['arr_0']\n",
346 | "text_features_test = np.load(prefix+'data/text/test_samples.npz')['arr_0']\n",
347 | "text_ctargets_train = np.load(prefix+'data/text/train_labels.npz')['arr_0']\n",
348 | "text_ctargets_test = np.load(prefix+'data/text/test_labels.npz')['arr_0']"
349 | ]
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": 221,
354 | "metadata": {},
355 | 
"outputs": [], 356 | "source": [ 357 | "import torch\n", 358 | "import torch.nn as nn\n", 359 | "from torch.autograd import Variable\n", 360 | "from torch.nn import functional as F\n", 361 | "import torch.optim as optim\n", 362 | "from sklearn.metrics import confusion_matrix" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 222, 368 | "metadata": {}, 369 | "outputs": [], 370 | "source": [ 371 | "class BiLSTM(nn.Module):\n", 372 | " \n", 373 | " def __init__(self, config):\n", 374 | " super(BiLSTM, self).__init__()\n", 375 | " self.num_classes = config['num_classes']\n", 376 | " self.learning_rate = config['learning_rate']\n", 377 | " self.dropout = config['dropout']\n", 378 | " self.hidden_dims = config['hidden_dims']\n", 379 | " self.rnn_layers = config['rnn_layers']\n", 380 | " self.embedding_size = config['embedding_size']\n", 381 | " self.bidirectional = config['bidirectional']\n", 382 | "\n", 383 | " self.build_model()\n", 384 | " self.init_weight()\n", 385 | " \n", 386 | " def init_weight(net):\n", 387 | " for name, param in net.named_parameters():\n", 388 | " if 'bias' in name:\n", 389 | " nn.init.constant_(param, 0.0)\n", 390 | " elif 'weight' in name:\n", 391 | " nn.init.xavier_uniform_(param)\n", 392 | "\n", 393 | " def build_model(self):\n", 394 | " # attention layer\n", 395 | " self.attention_layer = nn.Sequential(\n", 396 | " nn.Linear(self.hidden_dims, self.hidden_dims),\n", 397 | " nn.ReLU(inplace=True)\n", 398 | " )\n", 399 | " # self.attention_weights = self.attention_weights.view(self.hidden_dims, 1)\n", 400 | "\n", 401 | " # 双层lstm\n", 402 | " self.lstm_net = nn.LSTM(self.embedding_size, self.hidden_dims,\n", 403 | " num_layers=self.rnn_layers, dropout=self.dropout,\n", 404 | " bidirectional=self.bidirectional)\n", 405 | " \n", 406 | "# self.init_weight()\n", 407 | " \n", 408 | " # FC层\n", 409 | "# self.fc_out = nn.Linear(self.hidden_dims, self.num_classes)\n", 410 | " self.fc_out = nn.Sequential(\n", 411 | " nn.Dropout(self.dropout),\n", 412 | " nn.Linear(self.hidden_dims, self.hidden_dims),\n", 413 | " nn.ReLU(inplace=True),\n", 414 | " nn.Dropout(self.dropout),\n", 415 | " nn.Linear(self.hidden_dims, self.num_classes),\n", 416 | " nn.ReLU(),\n", 417 | " )\n", 418 | "\n", 419 | " def attention_net_with_w(self, lstm_out, lstm_hidden):\n", 420 | " '''\n", 421 | " :param lstm_out: [batch_size, len_seq, n_hidden * 2]\n", 422 | " :param lstm_hidden: [batch_size, num_layers * num_directions, n_hidden]\n", 423 | " :return: [batch_size, n_hidden]\n", 424 | " '''\n", 425 | " lstm_tmp_out = torch.chunk(lstm_out, 2, -1)\n", 426 | " # h [batch_size, time_step, hidden_dims]\n", 427 | " h = lstm_tmp_out[0] + lstm_tmp_out[1]\n", 428 | "# h = lstm_out\n", 429 | " # [batch_size, num_layers * num_directions, n_hidden]\n", 430 | " lstm_hidden = torch.sum(lstm_hidden, dim=1)\n", 431 | " # [batch_size, 1, n_hidden]\n", 432 | " lstm_hidden = lstm_hidden.unsqueeze(1)\n", 433 | " # atten_w [batch_size, 1, hidden_dims]\n", 434 | " atten_w = self.attention_layer(lstm_hidden)\n", 435 | " # m [batch_size, time_step, hidden_dims]\n", 436 | " m = nn.Tanh()(h)\n", 437 | " # atten_context [batch_size, 1, time_step]\n", 438 | " atten_context = torch.bmm(atten_w, m.transpose(1, 2))\n", 439 | " # softmax_w [batch_size, 1, time_step]\n", 440 | " softmax_w = F.softmax(atten_context, dim=-1)\n", 441 | " # context [batch_size, 1, hidden_dims]\n", 442 | " context = torch.bmm(softmax_w, h)\n", 443 | " result = context.squeeze(1)\n", 444 | " return result\n", 445 | "\n", 446 | " def 
forward(self, x):\n", 447 | " \n", 448 | " # x : [len_seq, batch_size, embedding_dim]\n", 449 | " x = x.permute(1, 0, 2)\n", 450 | " output, (final_hidden_state, final_cell_state) = self.lstm_net(x)\n", 451 | " # output : [batch_size, len_seq, n_hidden * 2]\n", 452 | " output = output.permute(1, 0, 2)\n", 453 | " # final_hidden_state : [batch_size, num_layers * num_directions, n_hidden]\n", 454 | " final_hidden_state = final_hidden_state.permute(1, 0, 2)\n", 455 | " # final_hidden_state = torch.mean(final_hidden_state, dim=0, keepdim=True)\n", 456 | " # atten_out = self.attention_net(output, final_hidden_state)\n", 457 | " atten_out = self.attention_net_with_w(output, final_hidden_state)\n", 458 | " return self.fc_out(atten_out)\n", 459 | " \n", 460 | " " 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 284, 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "# data imbalance\n", 470 | "X_train = []\n", 471 | "Y_train = []\n", 472 | "X_test = []\n", 473 | "Y_test = []\n", 474 | "\n", 475 | "counter = 0\n", 476 | "\n", 477 | "cut = 10\n", 478 | "debt = 0\n", 479 | "\n", 480 | "for i in range(len(text_features_train)):\n", 481 | "# if text_ctargets_train[i] == 1:\n", 482 | " if text_ctargets_train[i] >= 10:\n", 483 | " times = 3+debt if counter < 46 else 2+debt\n", 484 | "# print(times, text_features_train[i].shape, debt)\n", 485 | " for j in range(times):\n", 486 | " if (j+1)*cut > len(text_features_train[i]):\n", 487 | " debt+=1\n", 488 | " continue\n", 489 | " X_train.append(text_features_train[i][j*cut:(j+1)*cut])\n", 490 | " Y_train.append(text_ctargets_train[i])\n", 491 | " if debt > 0:\n", 492 | " debt -= 1\n", 493 | " counter+=1\n", 494 | " else:\n", 495 | " X_train.append(text_features_train[i][:cut])\n", 496 | " Y_train.append(text_ctargets_train[i])\n", 497 | " \n", 498 | " \n", 499 | "for i in range(len(text_features_test)):\n", 500 | " X_test.append(text_features_test[i][:cut])\n", 501 | " Y_test.append(text_ctargets_test[i])\n", 502 | "\n", 503 | "# for i in range(len(text_features_train)):\n", 504 | "# if text_ctargets_train[i] == 1:\n", 505 | "# times = int(len(text_features_train[i]) / 10)\n", 506 | "# for j in range(times):\n", 507 | "# X_train.append(text_features_train[i][j*10:(j+1)*10])\n", 508 | "# Y_train.append(text_ctargets_train[i])\n", 509 | "# counter+=1\n", 510 | "# else:\n", 511 | "# times = \n", 512 | "# X_train.append(text_features_train[i][:10])\n", 513 | "# Y_train.append(text_ctargets_train[i])\n", 514 | " \n", 515 | " \n", 516 | "# for i in range(len(text_features_test)):\n", 517 | "# X_test.append(text_features_test[i][:10])\n", 518 | "# Y_test.append(text_ctargets_test[i])\n", 519 | " \n", 520 | "X_train = np.array(X_train)\n", 521 | "Y_train = np.array(Y_train)\n", 522 | "X_test = np.array(X_test)\n", 523 | "Y_test = np.array(Y_test)" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": 285, 529 | "metadata": {}, 530 | "outputs": [ 531 | { 532 | "data": { 533 | "text/plain": [ 534 | "(154, 10, 1024)" 535 | ] 536 | }, 537 | "execution_count": 285, 538 | "metadata": {}, 539 | "output_type": "execute_result" 540 | } 541 | ], 542 | "source": [ 543 | "X_train.shape" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": 340, 549 | "metadata": {}, 550 | "outputs": [], 551 | "source": [ 552 | "\n", 553 | "config = {\n", 554 | " 'num_classes': 1,\n", 555 | " 'dropout': 0.5,\n", 556 | " 'rnn_layers': 2,\n", 557 | " 'embedding_size': 1024,\n", 558 | " 'batch_size': 8,\n", 
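" # the settings below are those of the logged 200-epoch run; hidden_dims\n",
" # is the per-direction LSTM width, reused by attention_layer and fc_out\n",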
559 | " 'epochs': 200,\n", 560 | " 'learning_rate': 5e-4,\n", 561 | " 'hidden_dims': 128,\n", 562 | " 'bidirectional': True\n", 563 | "}\n", 564 | "\n", 565 | "model = BiLSTM(config)\n", 566 | "\n", 567 | "# if args.cuda:\n", 568 | "# model = model.cuda()\n", 569 | "# X_train = X_train.cuda()\n", 570 | "# Y_train = Y_train.cuda()\n", 571 | "# X_test = X_test.cuda()\n", 572 | "# Y_test = Y_test.cuda()\n", 573 | "\n", 574 | "optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])\n", 575 | "# optimizer = optim.Adam(model.parameters())\n", 576 | "# criterion = nn.CrossEntropyLoss()\n", 577 | "criterion = nn.SmoothL1Loss()\n", 578 | "max_f1 = -1\n", 579 | "max_acc = -1\n", 580 | "train_acc = -1\n", 581 | "min_mae = 100\n", 582 | "\n", 583 | "def save(model, filename):\n", 584 | " save_filename = '{}.pt'.format(filename)\n", 585 | " torch.save(model, save_filename)\n", 586 | " print('Saved as %s' % save_filename)\n", 587 | " \n", 588 | "def standard_confusion_matrix(y_test, y_test_pred):\n", 589 | " \"\"\"\n", 590 | " Make confusion matrix with format:\n", 591 | " -----------\n", 592 | " | TP | FP |\n", 593 | " -----------\n", 594 | " | FN | TN |\n", 595 | " -----------\n", 596 | " Parameters\n", 597 | " ----------\n", 598 | " y_true : ndarray - 1D\n", 599 | " y_pred : ndarray - 1D\n", 600 | "\n", 601 | " Returns\n", 602 | " -------\n", 603 | " ndarray - 2D\n", 604 | " \"\"\"\n", 605 | " [[tn, fp], [fn, tp]] = confusion_matrix(y_test, y_test_pred)\n", 606 | " return np.array([[tp, fp], [fn, tn]])\n", 607 | "\n", 608 | "def model_performance(y_test, y_test_pred_proba):\n", 609 | " \"\"\"\n", 610 | " Evaluation metrics for network performance.\n", 611 | " \"\"\"\n", 612 | " y_test_pred = y_test_pred_proba.data.max(1, keepdim=True)[1]\n", 613 | "\n", 614 | " # Computing confusion matrix for test dataset\n", 615 | " conf_matrix = standard_confusion_matrix(y_test, y_test_pred)\n", 616 | " print(\"Confusion Matrix:\")\n", 617 | " print(conf_matrix)\n", 618 | "\n", 619 | " return y_test_pred, conf_matrix\n", 620 | "\n", 621 | "def plot_roc_curve(y_test, y_score):\n", 622 | " \"\"\"\n", 623 | " Plots ROC curve for final trained model. 
Code taken from:\n",
624 | " https://vkolachalama.blogspot.com/2016/05/keras-implementation-of-mlp-neural.html\n",
625 | " \"\"\"\n",
626 | " fpr, tpr, _ = roc_curve(y_test, y_score)\n",
627 | " roc_auc = auc(fpr, tpr)\n",
628 | " plt.figure()\n",
629 | " plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)\n",
630 | " plt.plot([0, 1], [0, 1], 'k--')\n",
631 | " plt.xlim([0.0, 1.05])\n",
632 | " plt.ylim([0.0, 1.05])\n",
633 | " plt.xlabel('False Positive Rate')\n",
634 | " plt.ylabel('True Positive Rate')\n",
635 | " plt.title('Receiver operating characteristic curve')\n",
636 | " plt.legend(loc=\"lower right\")\n",
637 | " plt.savefig(prefix+'images/BiLSTM_roc.png')\n",
638 | " plt.close()\n",
639 | "\n",
640 | "\n",
641 | "def train(epoch):\n",
642 | " global lr, train_acc\n",
643 | " model.train()\n",
644 | " batch_idx = 1\n",
645 | " total_loss = 0\n",
646 | " correct = 0\n",
647 | " for i in range(0, X_train.shape[0], config['batch_size']):\n",
648 | " if i + config['batch_size'] > X_train.shape[0]:\n",
649 | " x, y = X_train[i:], Y_train[i:]\n",
650 | " else:\n",
651 | " x, y = X_train[i:(i+config['batch_size'])], Y_train[i:(i+config['batch_size'])]\n",
652 | " if False:\n",
653 | " x, y = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=True).cuda(), Variable(torch.from_numpy(y)).cuda()\n",
654 | " else:\n",
655 | "# x, y = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=True), Variable(torch.from_numpy(y))\n",
656 | " x, y = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=True), Variable(torch.from_numpy(y)).type(torch.FloatTensor)\n",
657 | " # zero the parameter gradients\n",
658 | " optimizer.zero_grad()\n",
659 | " output = model(x)\n",
660 | "# pred = output.data.max(1, keepdim=True)[1]\n",
661 | "# correct += pred.eq(y.data.view_as(pred)).cpu().sum()\n",
662 | " loss = criterion(output.flatten(), y)\n",
663 | " # backpropagate\n",
664 | " loss.backward()\n",
665 | " # update the network weights from the gradients\n",
666 | " optimizer.step()\n",
667 | " batch_idx += 1\n",
668 | " # loss.item() extracts the scalar loss value from the tensor\n",
669 | " total_loss += loss.item()\n",
670 | " \n",
671 | "# train_acc = correct\n",
672 | " train_acc = total_loss/batch_idx\n",
673 | " cur_loss = total_loss\n",
674 | "# print('Train Epoch: {:2d}\\t Learning rate: {:.4f}\\tLoss: {:.6f}\\t Accuracy: {}/{} ({:.0f}%)\\n '.format(\n",
675 | "# epoch+1, config['learning_rate'], cur_loss, correct, len(X_train),\n",
676 | "# 100. 
* correct / len(X_train)))\n",
677 | " print('Train Epoch: {:2d}\\t Learning rate: {:.4f}\\t Loss: {:.6f}\\t '.format(\n",
678 | " epoch+1, config['learning_rate'], total_loss/batch_idx))\n",
679 | "\n",
680 | "\n",
681 | "def evaluate(model):\n",
682 | " model.eval()\n",
683 | " batch_idx = 1\n",
684 | " total_loss = 0\n",
685 | " global max_f1, max_acc, min_mae\n",
686 | " with torch.no_grad():\n",
687 | " x, y = Variable(torch.from_numpy(X_test).type(torch.FloatTensor), requires_grad=True), Variable(torch.from_numpy(Y_test)).type(torch.FloatTensor)\n",
688 | " optimizer.zero_grad()\n",
689 | " output = model(x)\n",
690 | " loss = criterion(output.flatten(), y)\n",
691 | " total_loss += loss.item()\n",
692 | "# print(y, output)\n",
693 | "# y_test_pred, conf_matrix = model_performance(y, output)\n",
694 | " print('\\nTest set: Average loss: {:.4f} \\t MAE: {:.4f}\\n'.format(total_loss, F.l1_loss(output.flatten(), y)))\n",
695 | " \n",
696 | "# # custom evaluation metrics\n",
697 | "# print('Calculating additional test metrics...')\n",
698 | "# accuracy = float(conf_matrix[0][0] + conf_matrix[1][1]) / np.sum(conf_matrix)\n",
699 | "# precision = float(conf_matrix[0][0]) / (conf_matrix[0][0] + conf_matrix[0][1])\n",
700 | "# recall = float(conf_matrix[0][0]) / (conf_matrix[0][0] + conf_matrix[1][0])\n",
701 | "# f1_score = 2 * (precision * recall) / (precision + recall)\n",
702 | "# print(\"Accuracy: {}\".format(accuracy))\n",
703 | "# print(\"Precision: {}\".format(precision))\n",
704 | "# print(\"Recall: {}\".format(recall))\n",
705 | "# print(\"F1-Score: {}\\n\".format(f1_score))\n",
706 | "# print('='*89)\n",
707 | " \n",
708 | "# if max_f1 <= f1_score and train_acc > 151:\n",
709 | "# max_f1 = f1_score\n",
710 | "# max_acc = accuracy\n",
711 | "# save(model, 'BiLSTM_elmo_{}_{:.2f}'.format(config['hidden_dims'], max_f1)) \n",
712 | "# print('*'*64)\n",
713 | "# print('model saved: f1: {}\\tacc: {}'.format(max_f1, max_acc))\n",
714 | "# print('*'*64)\n",
715 | " if min_mae >= F.l1_loss(output.flatten(), y) and train_acc < 2.0:\n",
716 | " min_mae = F.l1_loss(output.flatten(), y)\n",
717 | " save(model, 'BiLSTM_reg_{}_{:.2f}'.format(config['hidden_dims'], min_mae)) \n",
718 | " print('*'*64)\n",
719 | " print('model saved: f1: {}\\tacc: {}'.format(min_mae, train_acc))  # legacy labels: 'f1' is the test MAE, 'acc' the mean train loss\n",
720 | " print('*'*64)\n",
721 | "\n",
722 | " return total_loss\n"
723 | ]
724 | },
725 | {
726 | "cell_type": "code",
727 | "execution_count": 341,
728 | "metadata": {
729 | "scrolled": true
730 | },
731 | "outputs": [
732 | {
733 | "name": "stdout",
734 | "output_type": "stream",
735 | "text": [
736 | "Train Epoch: 2\t Learning rate: 0.0005\t Loss: 6.222528\t \n",
737 | "\n",
738 | "Test set: Average loss: 5.2454 \t MAE: 5.7352\n",
739 | "\n",
740 | "Train Epoch: 3\t Learning rate: 0.0005\t Loss: 4.801173\t \n",
741 | "\n",
742 | "Test set: Average loss: 5.5409 \t MAE: 6.0273\n",
743 | "\n",
744 | "Train Epoch: 4\t Learning rate: 0.0005\t Loss: 4.431237\t \n",
745 | "\n",
746 | "Test set: Average loss: 5.2539 \t MAE: 5.7438\n",
747 | "\n",
748 | "Train Epoch: 5\t Learning rate: 0.0005\t Loss: 4.382258\t \n",
749 | "\n",
750 | "Test set: Average loss: 5.3497 \t MAE: 5.8280\n",
751 | "\n",
752 | "Train Epoch: 6\t Learning rate: 0.0005\t Loss: 4.518691\t \n",
753 | "\n",
754 | "Test set: Average loss: 5.4297 \t MAE: 5.9101\n",
755 | "\n",
756 | "Train Epoch: 7\t Learning rate: 0.0005\t Loss: 4.425840\t \n",
757 | "\n",
758 | "Test set: Average loss: 5.3102 \t MAE: 5.7929\n",
759 | "\n",
760 | "Train Epoch: 8\t Learning 
rate: 0.0005\t Loss: 4.562018\t \n", 761 | "\n", 762 | "Test set: Average loss: 5.0805 \t MAE: 5.5750\n", 763 | "\n", 764 | "Train Epoch: 9\t Learning rate: 0.0005\t Loss: 4.654327\t \n", 765 | "\n", 766 | "Test set: Average loss: 5.2846 \t MAE: 5.7715\n", 767 | "\n", 768 | "Train Epoch: 10\t Learning rate: 0.0005\t Loss: 4.565650\t \n", 769 | "\n", 770 | "Test set: Average loss: 5.2604 \t MAE: 5.7470\n", 771 | "\n", 772 | "Train Epoch: 11\t Learning rate: 0.0005\t Loss: 4.451685\t \n", 773 | "\n", 774 | "Test set: Average loss: 5.2530 \t MAE: 5.7409\n", 775 | "\n", 776 | "Train Epoch: 12\t Learning rate: 0.0005\t Loss: 4.512582\t \n", 777 | "\n", 778 | "Test set: Average loss: 5.1506 \t MAE: 5.6465\n", 779 | "\n", 780 | "Train Epoch: 13\t Learning rate: 0.0005\t Loss: 4.443742\t \n", 781 | "\n", 782 | "Test set: Average loss: 5.1558 \t MAE: 5.6496\n", 783 | "\n", 784 | "Train Epoch: 14\t Learning rate: 0.0005\t Loss: 4.298081\t \n", 785 | "\n", 786 | "Test set: Average loss: 4.8663 \t MAE: 5.3347\n", 787 | "\n", 788 | "Train Epoch: 15\t Learning rate: 0.0005\t Loss: 3.788248\t \n", 789 | "\n", 790 | "Test set: Average loss: 4.2885 \t MAE: 4.7878\n", 791 | "\n", 792 | "Train Epoch: 16\t Learning rate: 0.0005\t Loss: 3.246390\t \n", 793 | "\n", 794 | "Test set: Average loss: 4.1861 \t MAE: 4.6585\n", 795 | "\n", 796 | "Train Epoch: 17\t Learning rate: 0.0005\t Loss: 4.021637\t \n", 797 | "\n", 798 | "Test set: Average loss: 3.8039 \t MAE: 4.2900\n", 799 | "\n", 800 | "Train Epoch: 18\t Learning rate: 0.0005\t Loss: 3.298809\t \n", 801 | "\n", 802 | "Test set: Average loss: 3.8797 \t MAE: 4.3641\n", 803 | "\n", 804 | "Train Epoch: 19\t Learning rate: 0.0005\t Loss: 2.909718\t \n", 805 | "\n", 806 | "Test set: Average loss: 3.6439 \t MAE: 4.1346\n", 807 | "\n", 808 | "Train Epoch: 20\t Learning rate: 0.0005\t Loss: 2.991214\t \n", 809 | "\n", 810 | "Test set: Average loss: 3.5464 \t MAE: 4.0234\n", 811 | "\n", 812 | "Train Epoch: 21\t Learning rate: 0.0005\t Loss: 2.892707\t \n", 813 | "\n", 814 | "Test set: Average loss: 3.1266 \t MAE: 3.5861\n", 815 | "\n", 816 | "Train Epoch: 22\t Learning rate: 0.0005\t Loss: 2.806894\t \n", 817 | "\n", 818 | "Test set: Average loss: 3.3687 \t MAE: 3.8398\n", 819 | "\n", 820 | "Train Epoch: 23\t Learning rate: 0.0005\t Loss: 2.462016\t \n", 821 | "\n", 822 | "Test set: Average loss: 3.2983 \t MAE: 3.7604\n", 823 | "\n", 824 | "Train Epoch: 24\t Learning rate: 0.0005\t Loss: 2.452944\t \n", 825 | "\n", 826 | "Test set: Average loss: 3.7188 \t MAE: 4.2055\n", 827 | "\n", 828 | "Train Epoch: 25\t Learning rate: 0.0005\t Loss: 2.329042\t \n", 829 | "\n", 830 | "Test set: Average loss: 3.4484 \t MAE: 3.8869\n", 831 | "\n", 832 | "Train Epoch: 26\t Learning rate: 0.0005\t Loss: 2.495918\t \n", 833 | "\n", 834 | "Test set: Average loss: 3.8797 \t MAE: 4.3500\n", 835 | "\n", 836 | "Train Epoch: 27\t Learning rate: 0.0005\t Loss: 2.272594\t \n", 837 | "\n", 838 | "Test set: Average loss: 3.8090 \t MAE: 4.2980\n", 839 | "\n", 840 | "Train Epoch: 28\t Learning rate: 0.0005\t Loss: 2.619707\t \n", 841 | "\n", 842 | "Test set: Average loss: 3.6732 \t MAE: 4.1515\n", 843 | "\n", 844 | "Train Epoch: 29\t Learning rate: 0.0005\t Loss: 2.312229\t \n", 845 | "\n", 846 | "Test set: Average loss: 3.6419 \t MAE: 4.1090\n", 847 | "\n", 848 | "Train Epoch: 30\t Learning rate: 0.0005\t Loss: 2.056890\t \n", 849 | "\n", 850 | "Test set: Average loss: 4.1940 \t MAE: 4.6646\n", 851 | "\n", 852 | "Train Epoch: 31\t Learning rate: 0.0005\t Loss: 2.042540\t \n", 853 | "\n", 854 | 
"Test set: Average loss: 3.8574 \t MAE: 4.3044\n", 855 | "\n", 856 | "Train Epoch: 32\t Learning rate: 0.0005\t Loss: 2.110537\t \n", 857 | "\n", 858 | "Test set: Average loss: 3.8590 \t MAE: 4.3267\n", 859 | "\n", 860 | "Train Epoch: 33\t Learning rate: 0.0005\t Loss: 2.181078\t \n", 861 | "\n", 862 | "Test set: Average loss: 4.0800 \t MAE: 4.5380\n", 863 | "\n", 864 | "Train Epoch: 34\t Learning rate: 0.0005\t Loss: 1.838920\t \n", 865 | "\n", 866 | "Test set: Average loss: 3.6816 \t MAE: 4.1475\n", 867 | "\n", 868 | "Saved as BiLSTM_reg_128_4.15.pt\n", 869 | "****************************************************************\n", 870 | "model saved: f1: 4.147465705871582\tacc: 1.8389204747620083\n", 871 | "****************************************************************\n", 872 | "Train Epoch: 35\t Learning rate: 0.0005\t Loss: 1.695412\t \n", 873 | "\n", 874 | "Test set: Average loss: 4.4402 \t MAE: 4.9363\n", 875 | "\n", 876 | "Train Epoch: 36\t Learning rate: 0.0005\t Loss: 2.076556\t \n", 877 | "\n", 878 | "Test set: Average loss: 4.0873 \t MAE: 4.5635\n", 879 | "\n", 880 | "Train Epoch: 37\t Learning rate: 0.0005\t Loss: 1.879723\t \n", 881 | "\n", 882 | "Test set: Average loss: 4.5997 \t MAE: 5.0551\n", 883 | "\n", 884 | "Train Epoch: 38\t Learning rate: 0.0005\t Loss: 1.973436\t \n", 885 | "\n", 886 | "Test set: Average loss: 4.1427 \t MAE: 4.6181\n", 887 | "\n", 888 | "Train Epoch: 39\t Learning rate: 0.0005\t Loss: 1.913318\t \n", 889 | "\n", 890 | "Test set: Average loss: 4.1102 \t MAE: 4.5848\n", 891 | "\n", 892 | "Train Epoch: 40\t Learning rate: 0.0005\t Loss: 1.856847\t \n", 893 | "\n", 894 | "Test set: Average loss: 4.1139 \t MAE: 4.5657\n", 895 | "\n", 896 | "Train Epoch: 41\t Learning rate: 0.0005\t Loss: 1.771701\t \n", 897 | "\n", 898 | "Test set: Average loss: 3.9007 \t MAE: 4.3717\n", 899 | "\n", 900 | "Train Epoch: 42\t Learning rate: 0.0005\t Loss: 1.764812\t \n", 901 | "\n", 902 | "Test set: Average loss: 3.7831 \t MAE: 4.2577\n", 903 | "\n", 904 | "Train Epoch: 43\t Learning rate: 0.0005\t Loss: 1.636393\t \n", 905 | "\n", 906 | "Test set: Average loss: 3.9092 \t MAE: 4.3770\n", 907 | "\n", 908 | "Train Epoch: 44\t Learning rate: 0.0005\t Loss: 1.874798\t \n", 909 | "\n", 910 | "Test set: Average loss: 4.0795 \t MAE: 4.5496\n", 911 | "\n", 912 | "Train Epoch: 45\t Learning rate: 0.0005\t Loss: 1.580698\t \n", 913 | "\n", 914 | "Test set: Average loss: 4.1065 \t MAE: 4.5840\n", 915 | "\n", 916 | "Train Epoch: 46\t Learning rate: 0.0005\t Loss: 1.548155\t \n", 917 | "\n", 918 | "Test set: Average loss: 4.0280 \t MAE: 4.5011\n", 919 | "\n", 920 | "Train Epoch: 47\t Learning rate: 0.0005\t Loss: 1.400702\t \n", 921 | "\n", 922 | "Test set: Average loss: 4.2299 \t MAE: 4.6754\n", 923 | "\n", 924 | "Train Epoch: 48\t Learning rate: 0.0005\t Loss: 1.549356\t \n", 925 | "\n", 926 | "Test set: Average loss: 4.5074 \t MAE: 4.9787\n", 927 | "\n", 928 | "Train Epoch: 49\t Learning rate: 0.0005\t Loss: 1.574735\t \n", 929 | "\n", 930 | "Test set: Average loss: 4.5332 \t MAE: 4.9923\n", 931 | "\n", 932 | "Train Epoch: 50\t Learning rate: 0.0005\t Loss: 1.354524\t \n", 933 | "\n", 934 | "Test set: Average loss: 4.7064 \t MAE: 5.1539\n", 935 | "\n", 936 | "Train Epoch: 51\t Learning rate: 0.0005\t Loss: 2.049156\t \n", 937 | "\n", 938 | "Test set: Average loss: 4.1546 \t MAE: 4.6241\n", 939 | "\n", 940 | "Train Epoch: 52\t Learning rate: 0.0005\t Loss: 1.527975\t \n", 941 | "\n", 942 | "Test set: Average loss: 4.0918 \t MAE: 4.5543\n", 943 | "\n", 944 | "Train Epoch: 53\t Learning 
rate: 0.0005\t Loss: 1.346940\t \n", 945 | "\n", 946 | "Test set: Average loss: 4.0010 \t MAE: 4.4820\n", 947 | "\n", 948 | "Train Epoch: 54\t Learning rate: 0.0005\t Loss: 1.353701\t \n", 949 | "\n", 950 | "Test set: Average loss: 4.3962 \t MAE: 4.8609\n", 951 | "\n", 952 | "Train Epoch: 55\t Learning rate: 0.0005\t Loss: 1.306012\t \n", 953 | "\n", 954 | "Test set: Average loss: 4.2128 \t MAE: 4.6894\n", 955 | "\n", 956 | "Train Epoch: 56\t Learning rate: 0.0005\t Loss: 1.371004\t \n", 957 | "\n", 958 | "Test set: Average loss: 4.1537 \t MAE: 4.6346\n", 959 | "\n", 960 | "Train Epoch: 57\t Learning rate: 0.0005\t Loss: 1.368044\t \n", 961 | "\n", 962 | "Test set: Average loss: 4.2551 \t MAE: 4.7531\n", 963 | "\n", 964 | "Train Epoch: 58\t Learning rate: 0.0005\t Loss: 1.447491\t \n", 965 | "\n", 966 | "Test set: Average loss: 4.0097 \t MAE: 4.4771\n", 967 | "\n", 968 | "Train Epoch: 59\t Learning rate: 0.0005\t Loss: 1.247104\t \n", 969 | "\n", 970 | "Test set: Average loss: 3.9311 \t MAE: 4.4053\n", 971 | "\n", 972 | "Train Epoch: 60\t Learning rate: 0.0005\t Loss: 1.315407\t \n", 973 | "\n", 974 | "Test set: Average loss: 4.3192 \t MAE: 4.7980\n", 975 | "\n", 976 | "Train Epoch: 61\t Learning rate: 0.0005\t Loss: 1.247827\t \n", 977 | "\n", 978 | "Test set: Average loss: 4.2105 \t MAE: 4.6818\n", 979 | "\n", 980 | "Train Epoch: 62\t Learning rate: 0.0005\t Loss: 1.415605\t \n", 981 | "\n", 982 | "Test set: Average loss: 4.3634 \t MAE: 4.8293\n", 983 | "\n", 984 | "Train Epoch: 63\t Learning rate: 0.0005\t Loss: 1.222708\t \n", 985 | "\n", 986 | "Test set: Average loss: 4.5952 \t MAE: 5.0836\n", 987 | "\n", 988 | "Train Epoch: 64\t Learning rate: 0.0005\t Loss: 1.311415\t \n", 989 | "\n", 990 | "Test set: Average loss: 4.2124 \t MAE: 4.6954\n", 991 | "\n", 992 | "Train Epoch: 65\t Learning rate: 0.0005\t Loss: 1.487544\t \n", 993 | "\n", 994 | "Test set: Average loss: 4.0394 \t MAE: 4.4944\n", 995 | "\n", 996 | "Train Epoch: 66\t Learning rate: 0.0005\t Loss: 1.071867\t \n", 997 | "\n", 998 | "Test set: Average loss: 3.9212 \t MAE: 4.3847\n", 999 | "\n", 1000 | "Train Epoch: 67\t Learning rate: 0.0005\t Loss: 1.151054\t \n", 1001 | "\n", 1002 | "Test set: Average loss: 4.4762 \t MAE: 4.9519\n", 1003 | "\n", 1004 | "Train Epoch: 68\t Learning rate: 0.0005\t Loss: 1.309198\t \n", 1005 | "\n", 1006 | "Test set: Average loss: 4.4172 \t MAE: 4.8880\n", 1007 | "\n", 1008 | "Train Epoch: 69\t Learning rate: 0.0005\t Loss: 1.121209\t \n", 1009 | "\n", 1010 | "Test set: Average loss: 4.4038 \t MAE: 4.8850\n", 1011 | "\n", 1012 | "Train Epoch: 70\t Learning rate: 0.0005\t Loss: 1.122766\t \n", 1013 | "\n", 1014 | "Test set: Average loss: 3.9661 \t MAE: 4.4481\n", 1015 | "\n", 1016 | "Train Epoch: 71\t Learning rate: 0.0005\t Loss: 1.306642\t \n", 1017 | "\n", 1018 | "Test set: Average loss: 4.0182 \t MAE: 4.4837\n", 1019 | "\n", 1020 | "Train Epoch: 72\t Learning rate: 0.0005\t Loss: 1.271639\t \n", 1021 | "\n", 1022 | "Test set: Average loss: 4.3392 \t MAE: 4.7997\n", 1023 | "\n", 1024 | "Train Epoch: 73\t Learning rate: 0.0005\t Loss: 1.111040\t \n", 1025 | "\n", 1026 | "Test set: Average loss: 4.1592 \t MAE: 4.6452\n", 1027 | "\n", 1028 | "Train Epoch: 74\t Learning rate: 0.0005\t Loss: 1.092647\t \n", 1029 | "\n", 1030 | "Test set: Average loss: 4.5354 \t MAE: 5.0171\n", 1031 | "\n", 1032 | "Train Epoch: 75\t Learning rate: 0.0005\t Loss: 1.071747\t \n", 1033 | "\n", 1034 | "Test set: Average loss: 4.2062 \t MAE: 4.6748\n", 1035 | "\n", 1036 | "Train Epoch: 76\t Learning rate: 0.0005\t Loss: 
1.139450\t \n", 1037 | "\n", 1038 | "Test set: Average loss: 3.7759 \t MAE: 4.2505\n", 1039 | "\n", 1040 | "Train Epoch: 77\t Learning rate: 0.0005\t Loss: 1.193882\t \n", 1041 | "\n", 1042 | "Test set: Average loss: 3.6220 \t MAE: 4.0822\n", 1043 | "\n", 1044 | "Saved as BiLSTM_reg_128_4.08.pt\n", 1045 | "****************************************************************\n", 1046 | "model saved: f1: 4.082216739654541\tacc: 1.1938817671367101\n", 1047 | "****************************************************************\n" 1048 | ] 1049 | }, 1050 | { 1051 | "name": "stdout", 1052 | "output_type": "stream", 1053 | "text": [ 1054 | "Train Epoch: 78\t Learning rate: 0.0005\t Loss: 1.062055\t \n", 1055 | "\n", 1056 | "Test set: Average loss: 3.9293 \t MAE: 4.4046\n", 1057 | "\n", 1058 | "Train Epoch: 79\t Learning rate: 0.0005\t Loss: 1.161488\t \n", 1059 | "\n", 1060 | "Test set: Average loss: 4.0311 \t MAE: 4.4819\n", 1061 | "\n", 1062 | "Train Epoch: 80\t Learning rate: 0.0005\t Loss: 1.155397\t \n", 1063 | "\n", 1064 | "Test set: Average loss: 4.4762 \t MAE: 4.9510\n", 1065 | "\n", 1066 | "Train Epoch: 81\t Learning rate: 0.0005\t Loss: 1.124911\t \n", 1067 | "\n", 1068 | "Test set: Average loss: 3.9555 \t MAE: 4.4218\n", 1069 | "\n", 1070 | "Train Epoch: 82\t Learning rate: 0.0005\t Loss: 1.087112\t \n", 1071 | "\n", 1072 | "Test set: Average loss: 3.9758 \t MAE: 4.4648\n", 1073 | "\n", 1074 | "Train Epoch: 83\t Learning rate: 0.0005\t Loss: 1.143451\t \n", 1075 | "\n", 1076 | "Test set: Average loss: 4.1492 \t MAE: 4.6386\n", 1077 | "\n", 1078 | "Train Epoch: 84\t Learning rate: 0.0005\t Loss: 1.021506\t \n", 1079 | "\n", 1080 | "Test set: Average loss: 4.2033 \t MAE: 4.6706\n", 1081 | "\n", 1082 | "Train Epoch: 85\t Learning rate: 0.0005\t Loss: 1.397525\t \n", 1083 | "\n", 1084 | "Test set: Average loss: 4.3981 \t MAE: 4.8657\n", 1085 | "\n", 1086 | "Train Epoch: 86\t Learning rate: 0.0005\t Loss: 0.968711\t \n", 1087 | "\n", 1088 | "Test set: Average loss: 3.7188 \t MAE: 4.2074\n", 1089 | "\n", 1090 | "Train Epoch: 87\t Learning rate: 0.0005\t Loss: 1.095073\t \n", 1091 | "\n", 1092 | "Test set: Average loss: 3.9837 \t MAE: 4.4526\n", 1093 | "\n", 1094 | "Train Epoch: 88\t Learning rate: 0.0005\t Loss: 1.061223\t \n", 1095 | "\n", 1096 | "Test set: Average loss: 4.3055 \t MAE: 4.7557\n", 1097 | "\n", 1098 | "Train Epoch: 89\t Learning rate: 0.0005\t Loss: 0.967244\t \n", 1099 | "\n", 1100 | "Test set: Average loss: 3.9507 \t MAE: 4.4074\n", 1101 | "\n", 1102 | "Train Epoch: 90\t Learning rate: 0.0005\t Loss: 1.116576\t \n", 1103 | "\n", 1104 | "Test set: Average loss: 4.2355 \t MAE: 4.7020\n", 1105 | "\n", 1106 | "Train Epoch: 91\t Learning rate: 0.0005\t Loss: 1.129589\t \n", 1107 | "\n", 1108 | "Test set: Average loss: 4.2573 \t MAE: 4.7175\n", 1109 | "\n", 1110 | "Train Epoch: 92\t Learning rate: 0.0005\t Loss: 0.998689\t \n", 1111 | "\n", 1112 | "Test set: Average loss: 4.2436 \t MAE: 4.7121\n", 1113 | "\n", 1114 | "Train Epoch: 93\t Learning rate: 0.0005\t Loss: 1.028936\t \n", 1115 | "\n", 1116 | "Test set: Average loss: 4.2015 \t MAE: 4.6581\n", 1117 | "\n", 1118 | "Train Epoch: 94\t Learning rate: 0.0005\t Loss: 1.025468\t \n", 1119 | "\n", 1120 | "Test set: Average loss: 4.2462 \t MAE: 4.6956\n", 1121 | "\n", 1122 | "Train Epoch: 95\t Learning rate: 0.0005\t Loss: 0.973014\t \n", 1123 | "\n", 1124 | "Test set: Average loss: 3.9069 \t MAE: 4.3618\n", 1125 | "\n", 1126 | "Train Epoch: 96\t Learning rate: 0.0005\t Loss: 0.917344\t \n", 1127 | "\n", 1128 | "Test set: Average loss: 
4.2787 \t MAE: 4.7302\n", 1129 | "\n", 1130 | "Train Epoch: 97\t Learning rate: 0.0005\t Loss: 1.120929\t \n", 1131 | "\n", 1132 | "Test set: Average loss: 4.3165 \t MAE: 4.7986\n", 1133 | "\n", 1134 | "Train Epoch: 98\t Learning rate: 0.0005\t Loss: 0.962194\t \n", 1135 | "\n", 1136 | "Test set: Average loss: 3.8568 \t MAE: 4.3219\n", 1137 | "\n", 1138 | "Train Epoch: 99\t Learning rate: 0.0005\t Loss: 1.249937\t \n", 1139 | "\n", 1140 | "Test set: Average loss: 4.0346 \t MAE: 4.4936\n", 1141 | "\n", 1142 | "Train Epoch: 100\t Learning rate: 0.0005\t Loss: 1.070985\t \n", 1143 | "\n", 1144 | "Test set: Average loss: 4.2920 \t MAE: 4.7649\n", 1145 | "\n", 1146 | "Train Epoch: 101\t Learning rate: 0.0005\t Loss: 0.921444\t \n", 1147 | "\n", 1148 | "Test set: Average loss: 4.0885 \t MAE: 4.5694\n", 1149 | "\n", 1150 | "Train Epoch: 102\t Learning rate: 0.0005\t Loss: 1.260023\t \n", 1151 | "\n", 1152 | "Test set: Average loss: 4.1256 \t MAE: 4.5917\n", 1153 | "\n", 1154 | "Train Epoch: 103\t Learning rate: 0.0005\t Loss: 1.159951\t \n", 1155 | "\n", 1156 | "Test set: Average loss: 4.2488 \t MAE: 4.7348\n", 1157 | "\n", 1158 | "Train Epoch: 104\t Learning rate: 0.0005\t Loss: 1.322786\t \n", 1159 | "\n", 1160 | "Test set: Average loss: 4.3939 \t MAE: 4.8783\n", 1161 | "\n", 1162 | "Train Epoch: 105\t Learning rate: 0.0005\t Loss: 0.962339\t \n", 1163 | "\n", 1164 | "Test set: Average loss: 4.4167 \t MAE: 4.8781\n", 1165 | "\n", 1166 | "Train Epoch: 106\t Learning rate: 0.0005\t Loss: 0.944769\t \n", 1167 | "\n", 1168 | "Test set: Average loss: 4.1260 \t MAE: 4.6052\n", 1169 | "\n", 1170 | "Train Epoch: 107\t Learning rate: 0.0005\t Loss: 1.055879\t \n", 1171 | "\n", 1172 | "Test set: Average loss: 4.4353 \t MAE: 4.8979\n", 1173 | "\n", 1174 | "Train Epoch: 108\t Learning rate: 0.0005\t Loss: 0.929050\t \n", 1175 | "\n", 1176 | "Test set: Average loss: 4.3397 \t MAE: 4.7854\n", 1177 | "\n", 1178 | "Train Epoch: 109\t Learning rate: 0.0005\t Loss: 0.961483\t \n", 1179 | "\n", 1180 | "Test set: Average loss: 3.9624 \t MAE: 4.4224\n", 1181 | "\n", 1182 | "Train Epoch: 110\t Learning rate: 0.0005\t Loss: 0.946940\t \n", 1183 | "\n", 1184 | "Test set: Average loss: 4.2200 \t MAE: 4.6670\n", 1185 | "\n", 1186 | "Train Epoch: 111\t Learning rate: 0.0005\t Loss: 0.918805\t \n", 1187 | "\n", 1188 | "Test set: Average loss: 3.9667 \t MAE: 4.4206\n", 1189 | "\n", 1190 | "Train Epoch: 112\t Learning rate: 0.0005\t Loss: 1.147821\t \n", 1191 | "\n", 1192 | "Test set: Average loss: 4.0068 \t MAE: 4.4806\n", 1193 | "\n", 1194 | "Train Epoch: 113\t Learning rate: 0.0005\t Loss: 1.246980\t \n", 1195 | "\n", 1196 | "Test set: Average loss: 4.1982 \t MAE: 4.6641\n", 1197 | "\n", 1198 | "Train Epoch: 114\t Learning rate: 0.0005\t Loss: 1.113731\t \n", 1199 | "\n", 1200 | "Test set: Average loss: 4.4716 \t MAE: 4.9515\n", 1201 | "\n", 1202 | "Train Epoch: 115\t Learning rate: 0.0005\t Loss: 0.942133\t \n", 1203 | "\n", 1204 | "Test set: Average loss: 3.8472 \t MAE: 4.3116\n", 1205 | "\n", 1206 | "Train Epoch: 116\t Learning rate: 0.0005\t Loss: 0.836111\t \n", 1207 | "\n", 1208 | "Test set: Average loss: 3.7778 \t MAE: 4.2443\n", 1209 | "\n", 1210 | "Train Epoch: 117\t Learning rate: 0.0005\t Loss: 1.042380\t \n", 1211 | "\n", 1212 | "Test set: Average loss: 3.8928 \t MAE: 4.3669\n", 1213 | "\n", 1214 | "Train Epoch: 118\t Learning rate: 0.0005\t Loss: 0.888997\t \n", 1215 | "\n", 1216 | "Test set: Average loss: 4.3176 \t MAE: 4.7722\n", 1217 | "\n", 1218 | "Train Epoch: 119\t Learning rate: 0.0005\t Loss: 
0.863362\t \n", 1219 | "\n", 1220 | "Test set: Average loss: 3.7324 \t MAE: 4.2109\n", 1221 | "\n", 1222 | "Train Epoch: 120\t Learning rate: 0.0005\t Loss: 1.225582\t \n", 1223 | "\n", 1224 | "Test set: Average loss: 4.1852 \t MAE: 4.6686\n", 1225 | "\n", 1226 | "Train Epoch: 121\t Learning rate: 0.0005\t Loss: 1.197303\t \n", 1227 | "\n", 1228 | "Test set: Average loss: 4.4508 \t MAE: 4.9219\n", 1229 | "\n", 1230 | "Train Epoch: 122\t Learning rate: 0.0005\t Loss: 1.091303\t \n", 1231 | "\n", 1232 | "Test set: Average loss: 4.3728 \t MAE: 4.8488\n", 1233 | "\n", 1234 | "Train Epoch: 123\t Learning rate: 0.0005\t Loss: 1.260954\t \n", 1235 | "\n", 1236 | "Test set: Average loss: 4.1683 \t MAE: 4.6418\n", 1237 | "\n", 1238 | "Train Epoch: 124\t Learning rate: 0.0005\t Loss: 0.879488\t \n", 1239 | "\n", 1240 | "Test set: Average loss: 3.9878 \t MAE: 4.4624\n", 1241 | "\n", 1242 | "Train Epoch: 125\t Learning rate: 0.0005\t Loss: 1.115321\t \n", 1243 | "\n", 1244 | "Test set: Average loss: 3.9576 \t MAE: 4.4411\n", 1245 | "\n", 1246 | "Train Epoch: 126\t Learning rate: 0.0005\t Loss: 0.900170\t \n", 1247 | "\n", 1248 | "Test set: Average loss: 4.1885 \t MAE: 4.6625\n", 1249 | "\n", 1250 | "Train Epoch: 127\t Learning rate: 0.0005\t Loss: 0.914275\t \n", 1251 | "\n", 1252 | "Test set: Average loss: 4.1754 \t MAE: 4.6519\n", 1253 | "\n", 1254 | "Train Epoch: 128\t Learning rate: 0.0005\t Loss: 0.845138\t \n", 1255 | "\n", 1256 | "Test set: Average loss: 4.3401 \t MAE: 4.8130\n", 1257 | "\n", 1258 | "Train Epoch: 129\t Learning rate: 0.0005\t Loss: 1.078165\t \n", 1259 | "\n", 1260 | "Test set: Average loss: 4.2290 \t MAE: 4.7133\n", 1261 | "\n", 1262 | "Train Epoch: 130\t Learning rate: 0.0005\t Loss: 0.893762\t \n", 1263 | "\n", 1264 | "Test set: Average loss: 4.2804 \t MAE: 4.7473\n", 1265 | "\n", 1266 | "Train Epoch: 131\t Learning rate: 0.0005\t Loss: 0.898790\t \n", 1267 | "\n", 1268 | "Test set: Average loss: 4.3800 \t MAE: 4.8504\n", 1269 | "\n", 1270 | "Train Epoch: 132\t Learning rate: 0.0005\t Loss: 0.943924\t \n", 1271 | "\n", 1272 | "Test set: Average loss: 4.0918 \t MAE: 4.5740\n", 1273 | "\n", 1274 | "Train Epoch: 133\t Learning rate: 0.0005\t Loss: 0.892788\t \n", 1275 | "\n", 1276 | "Test set: Average loss: 4.2037 \t MAE: 4.6698\n", 1277 | "\n", 1278 | "Train Epoch: 134\t Learning rate: 0.0005\t Loss: 0.993899\t \n", 1279 | "\n", 1280 | "Test set: Average loss: 4.5271 \t MAE: 4.9923\n", 1281 | "\n", 1282 | "Train Epoch: 135\t Learning rate: 0.0005\t Loss: 1.000830\t \n", 1283 | "\n", 1284 | "Test set: Average loss: 3.9788 \t MAE: 4.4592\n", 1285 | "\n", 1286 | "Train Epoch: 136\t Learning rate: 0.0005\t Loss: 0.865986\t \n", 1287 | "\n", 1288 | "Test set: Average loss: 3.9782 \t MAE: 4.4541\n", 1289 | "\n", 1290 | "Train Epoch: 137\t Learning rate: 0.0005\t Loss: 0.778416\t \n", 1291 | "\n", 1292 | "Test set: Average loss: 4.0501 \t MAE: 4.5125\n", 1293 | "\n", 1294 | "Train Epoch: 138\t Learning rate: 0.0005\t Loss: 0.899644\t \n", 1295 | "\n", 1296 | "Test set: Average loss: 3.9759 \t MAE: 4.4407\n", 1297 | "\n", 1298 | "Train Epoch: 139\t Learning rate: 0.0005\t Loss: 0.858805\t \n", 1299 | "\n", 1300 | "Test set: Average loss: 4.1885 \t MAE: 4.6453\n", 1301 | "\n", 1302 | "Train Epoch: 140\t Learning rate: 0.0005\t Loss: 0.980515\t \n", 1303 | "\n", 1304 | "Test set: Average loss: 4.2229 \t MAE: 4.6874\n", 1305 | "\n", 1306 | "Train Epoch: 141\t Learning rate: 0.0005\t Loss: 0.990140\t \n", 1307 | "\n", 1308 | "Test set: Average loss: 4.2213 \t MAE: 4.6812\n", 1309 | "\n", 
1310 | "Train Epoch: 142\t Learning rate: 0.0005\t Loss: 1.142093\t \n", 1311 | "\n", 1312 | "Test set: Average loss: 4.4396 \t MAE: 4.9052\n", 1313 | "\n", 1314 | "Train Epoch: 143\t Learning rate: 0.0005\t Loss: 0.918046\t \n", 1315 | "\n", 1316 | "Test set: Average loss: 4.5148 \t MAE: 4.9913\n", 1317 | "\n", 1318 | "Train Epoch: 144\t Learning rate: 0.0005\t Loss: 0.942595\t \n", 1319 | "\n", 1320 | "Test set: Average loss: 4.2595 \t MAE: 4.7389\n", 1321 | "\n", 1322 | "Train Epoch: 145\t Learning rate: 0.0005\t Loss: 0.875068\t \n", 1323 | "\n", 1324 | "Test set: Average loss: 3.9844 \t MAE: 4.4673\n", 1325 | "\n", 1326 | "Train Epoch: 146\t Learning rate: 0.0005\t Loss: 0.932485\t \n", 1327 | "\n", 1328 | "Test set: Average loss: 4.4675 \t MAE: 4.9310\n", 1329 | "\n", 1330 | "Train Epoch: 147\t Learning rate: 0.0005\t Loss: 1.047656\t \n", 1331 | "\n", 1332 | "Test set: Average loss: 4.2100 \t MAE: 4.6831\n", 1333 | "\n", 1334 | "Train Epoch: 148\t Learning rate: 0.0005\t Loss: 0.835819\t \n", 1335 | "\n", 1336 | "Test set: Average loss: 4.0317 \t MAE: 4.5118\n", 1337 | "\n", 1338 | "Train Epoch: 149\t Learning rate: 0.0005\t Loss: 0.980606\t \n", 1339 | "\n", 1340 | "Test set: Average loss: 4.3075 \t MAE: 4.7642\n", 1341 | "\n", 1342 | "Train Epoch: 150\t Learning rate: 0.0005\t Loss: 0.830669\t \n", 1343 | "\n", 1344 | "Test set: Average loss: 4.5313 \t MAE: 4.9935\n", 1345 | "\n", 1346 | "Train Epoch: 151\t Learning rate: 0.0005\t Loss: 0.893053\t \n", 1347 | "\n", 1348 | "Test set: Average loss: 4.4451 \t MAE: 4.9095\n", 1349 | "\n", 1350 | "Train Epoch: 152\t Learning rate: 0.0005\t Loss: 0.872342\t \n", 1351 | "\n", 1352 | "Test set: Average loss: 4.2925 \t MAE: 4.7543\n", 1353 | "\n", 1354 | "Train Epoch: 153\t Learning rate: 0.0005\t Loss: 0.983376\t \n", 1355 | "\n", 1356 | "Test set: Average loss: 4.1181 \t MAE: 4.5934\n", 1357 | "\n", 1358 | "Train Epoch: 154\t Learning rate: 0.0005\t Loss: 1.020965\t \n", 1359 | "\n", 1360 | "Test set: Average loss: 4.1233 \t MAE: 4.5878\n", 1361 | "\n", 1362 | "Train Epoch: 155\t Learning rate: 0.0005\t Loss: 0.784421\t \n", 1363 | "\n", 1364 | "Test set: Average loss: 4.4377 \t MAE: 4.9053\n", 1365 | "\n", 1366 | "Train Epoch: 156\t Learning rate: 0.0005\t Loss: 1.126911\t \n", 1367 | "\n", 1368 | "Test set: Average loss: 3.9217 \t MAE: 4.4043\n", 1369 | "\n" 1370 | ] 1371 | }, 1372 | { 1373 | "name": "stdout", 1374 | "output_type": "stream", 1375 | "text": [ 1376 | "Train Epoch: 157\t Learning rate: 0.0005\t Loss: 0.992670\t \n", 1377 | "\n", 1378 | "Test set: Average loss: 4.0878 \t MAE: 4.5467\n", 1379 | "\n", 1380 | "Train Epoch: 158\t Learning rate: 0.0005\t Loss: 0.865013\t \n", 1381 | "\n", 1382 | "Test set: Average loss: 4.3303 \t MAE: 4.7961\n", 1383 | "\n", 1384 | "Train Epoch: 159\t Learning rate: 0.0005\t Loss: 0.814098\t \n", 1385 | "\n", 1386 | "Test set: Average loss: 4.0651 \t MAE: 4.5495\n", 1387 | "\n", 1388 | "Train Epoch: 160\t Learning rate: 0.0005\t Loss: 0.843722\t \n", 1389 | "\n", 1390 | "Test set: Average loss: 3.9392 \t MAE: 4.4237\n", 1391 | "\n", 1392 | "Train Epoch: 161\t Learning rate: 0.0005\t Loss: 0.778094\t \n", 1393 | "\n", 1394 | "Test set: Average loss: 4.1093 \t MAE: 4.5677\n", 1395 | "\n", 1396 | "Train Epoch: 162\t Learning rate: 0.0005\t Loss: 0.760861\t \n", 1397 | "\n", 1398 | "Test set: Average loss: 4.0424 \t MAE: 4.4901\n", 1399 | "\n", 1400 | "Train Epoch: 163\t Learning rate: 0.0005\t Loss: 0.976117\t \n", 1401 | "\n", 1402 | "Test set: Average loss: 3.9120 \t MAE: 4.3788\n", 1403 | "\n", 
1404 | "Train Epoch: 164\t Learning rate: 0.0005\t Loss: 0.747449\t \n", 1405 | "\n", 1406 | "Test set: Average loss: 4.1145 \t MAE: 4.5898\n", 1407 | "\n", 1408 | "Train Epoch: 165\t Learning rate: 0.0005\t Loss: 0.960359\t \n", 1409 | "\n", 1410 | "Test set: Average loss: 4.3442 \t MAE: 4.8045\n", 1411 | "\n", 1412 | "Train Epoch: 166\t Learning rate: 0.0005\t Loss: 0.837411\t \n", 1413 | "\n", 1414 | "Test set: Average loss: 4.5189 \t MAE: 4.9757\n", 1415 | "\n", 1416 | "Train Epoch: 167\t Learning rate: 0.0005\t Loss: 0.917638\t \n", 1417 | "\n", 1418 | "Test set: Average loss: 4.2090 \t MAE: 4.6905\n", 1419 | "\n", 1420 | "Train Epoch: 168\t Learning rate: 0.0005\t Loss: 0.850689\t \n", 1421 | "\n", 1422 | "Test set: Average loss: 4.0824 \t MAE: 4.5608\n", 1423 | "\n", 1424 | "Train Epoch: 169\t Learning rate: 0.0005\t Loss: 0.974172\t \n", 1425 | "\n", 1426 | "Test set: Average loss: 4.0506 \t MAE: 4.5225\n", 1427 | "\n", 1428 | "Train Epoch: 170\t Learning rate: 0.0005\t Loss: 0.863295\t \n", 1429 | "\n", 1430 | "Test set: Average loss: 4.3983 \t MAE: 4.8468\n", 1431 | "\n", 1432 | "Train Epoch: 171\t Learning rate: 0.0005\t Loss: 0.931016\t \n", 1433 | "\n", 1434 | "Test set: Average loss: 4.1384 \t MAE: 4.6132\n", 1435 | "\n", 1436 | "Train Epoch: 172\t Learning rate: 0.0005\t Loss: 0.794198\t \n", 1437 | "\n", 1438 | "Test set: Average loss: 4.3596 \t MAE: 4.8161\n", 1439 | "\n", 1440 | "Train Epoch: 173\t Learning rate: 0.0005\t Loss: 0.820409\t \n", 1441 | "\n", 1442 | "Test set: Average loss: 4.1466 \t MAE: 4.6183\n", 1443 | "\n", 1444 | "Train Epoch: 174\t Learning rate: 0.0005\t Loss: 0.967401\t \n", 1445 | "\n", 1446 | "Test set: Average loss: 4.2012 \t MAE: 4.6818\n", 1447 | "\n", 1448 | "Train Epoch: 175\t Learning rate: 0.0005\t Loss: 0.931624\t \n", 1449 | "\n", 1450 | "Test set: Average loss: 4.0032 \t MAE: 4.4773\n", 1451 | "\n", 1452 | "Train Epoch: 176\t Learning rate: 0.0005\t Loss: 0.914719\t \n", 1453 | "\n", 1454 | "Test set: Average loss: 4.2812 \t MAE: 4.7415\n", 1455 | "\n", 1456 | "Train Epoch: 177\t Learning rate: 0.0005\t Loss: 0.853825\t \n", 1457 | "\n", 1458 | "Test set: Average loss: 4.1724 \t MAE: 4.6384\n", 1459 | "\n", 1460 | "Train Epoch: 178\t Learning rate: 0.0005\t Loss: 0.782554\t \n", 1461 | "\n", 1462 | "Test set: Average loss: 4.1872 \t MAE: 4.6614\n", 1463 | "\n", 1464 | "Train Epoch: 179\t Learning rate: 0.0005\t Loss: 0.790341\t \n", 1465 | "\n", 1466 | "Test set: Average loss: 4.3984 \t MAE: 4.8560\n", 1467 | "\n", 1468 | "Train Epoch: 180\t Learning rate: 0.0005\t Loss: 0.691040\t \n", 1469 | "\n", 1470 | "Test set: Average loss: 4.4033 \t MAE: 4.8658\n", 1471 | "\n", 1472 | "Train Epoch: 181\t Learning rate: 0.0005\t Loss: 0.850953\t \n", 1473 | "\n", 1474 | "Test set: Average loss: 3.8660 \t MAE: 4.3404\n", 1475 | "\n", 1476 | "Train Epoch: 182\t Learning rate: 0.0005\t Loss: 0.896042\t \n", 1477 | "\n", 1478 | "Test set: Average loss: 4.2568 \t MAE: 4.7351\n", 1479 | "\n", 1480 | "Train Epoch: 183\t Learning rate: 0.0005\t Loss: 0.960200\t \n", 1481 | "\n", 1482 | "Test set: Average loss: 4.3396 \t MAE: 4.7906\n", 1483 | "\n", 1484 | "Train Epoch: 184\t Learning rate: 0.0005\t Loss: 0.959877\t \n", 1485 | "\n", 1486 | "Test set: Average loss: 3.8674 \t MAE: 4.3530\n", 1487 | "\n", 1488 | "Train Epoch: 185\t Learning rate: 0.0005\t Loss: 0.725641\t \n", 1489 | "\n", 1490 | "Test set: Average loss: 3.9436 \t MAE: 4.4092\n", 1491 | "\n", 1492 | "Train Epoch: 186\t Learning rate: 0.0005\t Loss: 0.767895\t \n", 1493 | "\n", 1494 | "Test 
set: Average loss: 4.2262 \t MAE: 4.6659\n", 1495 | "\n", 1496 | "Train Epoch: 187\t Learning rate: 0.0005\t Loss: 0.843481\t \n", 1497 | "\n", 1498 | "Test set: Average loss: 4.2961 \t MAE: 4.7634\n", 1499 | "\n", 1500 | "Train Epoch: 188\t Learning rate: 0.0005\t Loss: 0.874929\t \n", 1501 | "\n", 1502 | "Test set: Average loss: 4.0638 \t MAE: 4.5159\n", 1503 | "\n", 1504 | "Train Epoch: 189\t Learning rate: 0.0005\t Loss: 1.043584\t \n", 1505 | "\n", 1506 | "Test set: Average loss: 4.0349 \t MAE: 4.5109\n", 1507 | "\n", 1508 | "Train Epoch: 190\t Learning rate: 0.0005\t Loss: 0.935190\t \n", 1509 | "\n", 1510 | "Test set: Average loss: 4.0865 \t MAE: 4.5353\n", 1511 | "\n", 1512 | "Train Epoch: 191\t Learning rate: 0.0005\t Loss: 0.694541\t \n", 1513 | "\n", 1514 | "Test set: Average loss: 4.3420 \t MAE: 4.8012\n", 1515 | "\n", 1516 | "Train Epoch: 192\t Learning rate: 0.0005\t Loss: 0.865203\t \n", 1517 | "\n", 1518 | "Test set: Average loss: 3.9585 \t MAE: 4.4290\n", 1519 | "\n", 1520 | "Train Epoch: 193\t Learning rate: 0.0005\t Loss: 0.858346\t \n", 1521 | "\n", 1522 | "Test set: Average loss: 4.1437 \t MAE: 4.6031\n", 1523 | "\n", 1524 | "Train Epoch: 194\t Learning rate: 0.0005\t Loss: 0.906553\t \n", 1525 | "\n", 1526 | "Test set: Average loss: 4.4437 \t MAE: 4.9088\n", 1527 | "\n", 1528 | "Train Epoch: 195\t Learning rate: 0.0005\t Loss: 0.836493\t \n", 1529 | "\n", 1530 | "Test set: Average loss: 4.1875 \t MAE: 4.6559\n", 1531 | "\n", 1532 | "Train Epoch: 196\t Learning rate: 0.0005\t Loss: 0.852588\t \n", 1533 | "\n", 1534 | "Test set: Average loss: 4.4612 \t MAE: 4.9229\n", 1535 | "\n", 1536 | "Train Epoch: 197\t Learning rate: 0.0005\t Loss: 0.905962\t \n", 1537 | "\n", 1538 | "Test set: Average loss: 4.4976 \t MAE: 4.9639\n", 1539 | "\n", 1540 | "Train Epoch: 198\t Learning rate: 0.0005\t Loss: 0.915486\t \n", 1541 | "\n", 1542 | "Test set: Average loss: 4.2601 \t MAE: 4.7137\n", 1543 | "\n", 1544 | "Train Epoch: 199\t Learning rate: 0.0005\t Loss: 0.840349\t \n", 1545 | "\n", 1546 | "Test set: Average loss: 4.0902 \t MAE: 4.5491\n", 1547 | "\n", 1548 | "Train Epoch: 200\t Learning rate: 0.0005\t Loss: 0.837486\t \n", 1549 | "\n", 1550 | "Test set: Average loss: 4.2122 \t MAE: 4.6674\n", 1551 | "\n" 1552 | ] 1553 | } 1554 | ], 1555 | "source": [ 1556 | "for ep in range(1, config['epochs']):\n", 1557 | " train(ep)\n", 1558 | " tloss = evaluate(model)" 1559 | ] 1560 | }, 1561 | { 1562 | "cell_type": "code", 1563 | "execution_count": 9, 1564 | "metadata": {}, 1565 | "outputs": [ 1566 | { 1567 | "name": "stdout", 1568 | "output_type": "stream", 1569 | "text": [ 1570 | "Confusion Matrix:\n", 1571 | "[[10 2]\n", 1572 | " [ 2 21]]\n", 1573 | "\n", 1574 | "Test set: Average loss: 1.1024\n", 1575 | "Calculating additional test metrics...\n", 1576 | "Accuracy: 0.8857142857142857\n", 1577 | "Precision: 0.8333333333333334\n", 1578 | "Recall: 0.8333333333333334\n", 1579 | "F1-Score: 0.8333333333333334\n", 1580 | "\n", 1581 | "=========================================================================================\n" 1582 | ] 1583 | }, 1584 | { 1585 | "data": { 1586 | "text/plain": [ 1587 | "1.102392554283142" 1588 | ] 1589 | }, 1590 | "execution_count": 9, 1591 | "metadata": {}, 1592 | "output_type": "execute_result" 1593 | } 1594 | ], 1595 | "source": [ 1596 | "lstm_model = torch.load('/Users/apple/Downloads/depression/BiLSTM_elmo_128_0.83.pt')\n", 1597 | "model = BiLSTM(config)\n", 1598 | "model.load_state_dict(lstm_model.state_dict())\n", 1599 | "evaluate(model)" 1600 | ] 1601 | }, 
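1602 |   {
1603 |    "cell_type": "markdown",
1604 |    "metadata": {},
1605 |    "source": [
1606 |     "The cell above reloads a classification checkpoint (its saved output still shows the confusion-matrix metrics from the earlier classification run of this notebook), and the cell below repeats the same pattern for the regression checkpoint. Because `save()` serializes the whole module with `torch.save(model, ...)`, reloading goes through the pickled module's `state_dict()` rather than `torch.load` alone. A minimal sketch of the pattern (the checkpoint filename is an example):\n",
1607 |     "\n",
1608 |     "```python\n",
1609 |     "checkpoint = torch.load('BiLSTM_reg_128_3.88.pt')  # full pickled module\n",
1610 |     "model = BiLSTM(config)                             # fresh instance with the same config\n",
1611 |     "model.load_state_dict(checkpoint.state_dict())     # copy the weights only\n",
1612 |     "model.eval()                                       # evaluate() also switches to eval mode\n",
1613 |     "```"
1614 |    ]
1615 |   },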
1616 |   {
1617 |    "cell_type": "code",
1618 |    "execution_count": 342,
1619 |    "metadata": {},
1620 |    "outputs": [
1621 |     {
1622 |      "name": "stdout",
1623 |      "output_type": "stream",
1624 |      "text": [
1625 |       "\n",
1626 |       "Test set: Average loss: 3.4328 \t MAE: 3.8846\n",
1627 |       "\n",
1628 |       "Saved as BiLSTM_reg_128_3.88.pt\n",
1629 |       "****************************************************************\n",
1630 |       "model saved: f1: 3.8845977783203125\tacc: 0.8374860261877378\n",
1631 |       "****************************************************************\n"
1632 |      ]
1633 |     },
1634 |     {
1635 |      "data": {
1636 |       "text/plain": [
1637 |        "3.432814359664917"
1638 |       ]
1639 |      },
1640 |      "execution_count": 342,
1641 |      "metadata": {},
1642 |      "output_type": "execute_result"
1643 |     }
1644 |    ],
1645 |    "source": [
1646 |     "lstm_model = torch.load('/Users/apple/Downloads/depression/BiLSTM_reg_128_3.88.pt')\n",
1647 |     "model = BiLSTM(config)\n",
1648 |     "model.load_state_dict(lstm_model.state_dict())\n",
1649 |     "evaluate(model)"
1650 |    ]
1651 |   },
1652 |   {
1653 |    "cell_type": "code",
1654 |    "execution_count": null,
1655 |    "metadata": {},
1656 |    "outputs": [],
1657 |    "source": []
1658 |   }
1659 |  ],
1660 |  "metadata": {
1661 |   "kernelspec": {
1662 |    "display_name": "Python 3",
1663 |    "language": "python",
1664 |    "name": "python3"
1665 |   },
1666 |   "language_info": {
1667 |    "codemirror_mode": {
1668 |     "name": "ipython",
1669 |     "version": 3
1670 |    },
1671 |    "file_extension": ".py",
1672 |    "mimetype": "text/x-python",
1673 |    "name": "python",
1674 |    "nbconvert_exporter": "python",
1675 |    "pygments_lexer": "ipython3",
1676 |    "version": "3.7.3"
1677 |   }
1678 |  },
1679 |  "nbformat": 4,
1680 |  "nbformat_minor": 2
1681 | }
--------------------------------------------------------------------------------
/regression/cnn_audio_reg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import wave
4 | import librosa
5 | from python_speech_features import *
6 | 
7 | import torch
8 | from torch.utils import data  # data loading utilities
9 | from torch.autograd import Variable  # autograd Variable wrapper
10 | import torchvision
11 | import matplotlib.pyplot as plt
12 | import torch.nn as nn
13 | import torch.optim as optim
14 | from sklearn.metrics import confusion_matrix, mean_absolute_error
15 | from torch.nn import functional as F
16 | 
17 | from keras.models import Sequential
18 | from keras.layers import Dense, Dropout, Activation, Flatten
19 | from keras.layers import Conv2D, MaxPooling2D
20 | 
21 | prefix = '/Users/apple/Downloads/depression/'
22 | 
23 | train_split_df = pd.read_csv(prefix+'train_split_Depression_AVEC2017 (1).csv')
24 | test_split_df = pd.read_csv(prefix+'dev_split_Depression_AVEC2017.csv')
25 | train_split_num = train_split_df['Participant_ID'].tolist()
26 | test_split_num = test_split_df['Participant_ID'].tolist()
27 | train_split_clabel = train_split_df['PHQ8_Score'].tolist()
28 | test_split_clabel = test_split_df['PHQ8_Score'].tolist()
29 | 
30 | def extract_features(number, audio_features, target, audio_targets, mode):
31 |     transcript = pd.read_csv(prefix+'{0}_P/{0}_TRANSCRIPT.csv'.format(number), sep='\t').fillna('')
32 | 
33 |     wavefile = wave.open(prefix+'{0}_P/{0}_AUDIO.wav'.format(number), 'r')
34 |     sr = wavefile.getframerate()
35 |     nframes = wavefile.getnframes()
36 |     wave_data = np.frombuffer(wavefile.readframes(nframes), dtype=np.short)
37 | 
38 |     time_range = []
39 |     response = ''
40 |     response_flag = False
41 |     time_collect_flag = False
42 |     start_time = 0
43 |     stop_time = 0
44 | 
45 |     signal = []
46 | 
47 |     global counter_train
48 | 
49 |     for t in transcript.itertuples():
50 |         # keep only the participant's speech; skip the virtual interviewer (Ellie)
51 |         if getattr(t,'speaker') == 'Ellie':
52 |             continue
53 |         elif getattr(t,'speaker') == 'Participant':
54 |             if 'scrubbed_entry' in getattr(t,'value'):
55 |                 continue
56 |             start_time = int(getattr(t,'start_time')*sr)
57 |             stop_time = int(getattr(t,'stop_time')*sr)
58 |             signal = np.hstack((signal, wave_data[start_time:stop_time].astype(float)))
59 | 
60 |     # 15-second clips (sr * 15 samples)
61 |     clip = sr*1*15
62 |     if target >= 10 and mode == 'train':
63 |         # oversample high-scoring (PHQ8 >= 10) participants: take 2-3 clips instead of 1
64 |         times = 3 if counter_train < 48 else 2
65 |         for i in range(times):
66 |             if clip*(i+1) > len(signal):
67 |                 continue
68 |             melspec = librosa.feature.melspectrogram(signal[clip*i:clip*(i+1)], n_mels=80, sr=sr)
69 |             # melspec = base.logfbank(signal[clip*i:clip*(i+1)], samplerate=sr, winlen=0.064, winstep=0.032, nfilt=80, nfft=1024, lowfreq=130, highfreq=6854)
70 |             logspec = melspec
71 |             audio_features.append(logspec)
72 |             audio_targets.append(target)
73 |             counter_train += 1
74 |     else:
75 |         melspec = librosa.feature.melspectrogram(signal[:clip], n_mels=80, sr=sr)
76 |         # melspec = base.logfbank(signal[:clip], samplerate=sr, winlen=0.064, winstep=0.032, nfilt=80, nfft=1024, lowfreq=130, highfreq=6854)
77 |         logspec = melspec
78 |         audio_features.append(logspec)
79 |         audio_targets.append(target)
80 |     # print(melspec.shape)
81 |     print('{}_P feature done'.format(number))
82 | 
83 | # training set
84 | audio_features_train = []
85 | audio_ctargets_train = []
86 | 
87 | # test set
88 | audio_features_test = []
89 | audio_ctargets_test = []
90 | mark_test = []
91 | 
92 | counter_train = 0
93 | counter_test = 0
94 | 
95 | # training set
96 | for index in range(len(train_split_num)):
97 |     extract_features(train_split_num[index], audio_features_train, train_split_clabel[index], audio_ctargets_train, 'train')
98 | 
99 | # test set
100 | for index in range(len(test_split_num)):
101 |     extract_features(test_split_num[index], audio_features_test, test_split_clabel[index], audio_ctargets_test, 'test')
102 | 
103 | print(np.shape(audio_ctargets_train), np.shape(audio_ctargets_test))
104 | print(counter_train, counter_test)
105 | 
106 | print("Saving npz file locally...")
107 | np.savez(prefix+'data/audio/train_samples_reg.npz', audio_features_train)
108 | np.savez(prefix+'data/audio/train_labels_reg.npz', audio_ctargets_train)
109 | np.savez(prefix+'data/audio/test_samples_reg.npz', audio_features_test)
110 | np.savez(prefix+'data/audio/test_labels_reg.npz', audio_ctargets_test)
111 | 
112 | config = {
113 |     'num_classes': 1,
114 |     'dropout': 0.5,
115 |     'rnn_layers': 2,
116 |     'embedding_size': 80,
117 |     'batch_size': 2,
118 |     'epochs': 30,
119 |     'learning_rate': 1e-4,
120 |     'cuda': torch.cuda.is_available(),  # use the GPU when one is available
121 | }
122 | 
123 | X_train = np.array(audio_features_train)
124 | Y_train = np.array(audio_ctargets_train)
125 | X_test = np.array(audio_features_test)
126 | Y_test = np.array(audio_ctargets_test)
127 | 
128 | def save(model, filename):
129 |     save_filename = '{}.pt'.format(filename)
130 |     torch.save(model, save_filename)
131 |     print('Saved as %s' % save_filename)
132 | 
133 | class CNN(nn.Module):
134 |     def __init__(self):
135 |         super(CNN, self).__init__()
136 |         self.conv2d_1 = nn.Conv2d(1, 32, (1,7), 1)
137 |         self.conv2d_2 = nn.Conv2d(32, 32, (1,7), 1)
138 |         self.dense_1 = nn.Linear(120736, 128)
139 |         self.dense_2 = nn.Linear(128, 128)
140 |         self.dense_3 = nn.Linear(128, 1)
141 |         self.dropout = nn.Dropout(0.5)
142 | 
143 |     def forward(self, x):
144 |         x = F.relu(self.conv2d_1(x))
145 |         x = F.max_pool2d(x, (4, 3), (1, 3))
146 |         x = F.relu(self.conv2d_2(x))
147 |         x = F.max_pool2d(x, (1, 3), (1, 3))
148 |         # flatten, channels-last as in the original Keras model
149 |         x = x.permute((0, 2, 3, 1))
150 |         x = x.contiguous().view(-1, 120736)
151 |         x = F.relu(self.dense_1(x))
152 |         x = F.relu(self.dense_2(x))
153 |         x = self.dropout(x)
154 |         # output = torch.sigmoid(self.dense_3(x))
155 |         output = torch.relu(self.dense_3(x))
156 |         return output
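157 | 
158 | # Shape walkthrough for the flatten size 120736 used above (a sketch, assuming 16 kHz
159 | # DAIC-WOZ audio and librosa defaults n_fft=2048, hop_length=512, center=True):
160 | #   15 s clip -> 240000 samples -> 1 + 240000 // 512 = 469 frames, i.e. input (1, 80, 469)
161 | #   conv2d_1, kernel (1, 7):           (32, 80, 463)
162 | #   max_pool2d (4, 3), stride (1, 3):  (32, 77, 154)
163 | #   conv2d_2, kernel (1, 7):           (32, 77, 148)
164 | #   max_pool2d (1, 3), stride (1, 3):  (32, 77, 49)
165 | #   flatten: 32 * 77 * 49 = 120736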
166 | 
167 | def train(epoch):
168 |     model.train()
169 |     batch_idx = 1
170 |     total_loss = 0
171 |     for i in range(0, X_train.shape[0], config['batch_size']):
172 |         if i + config['batch_size'] > X_train.shape[0]:
173 |             x, y = X_train[i:], Y_train[i:]
174 |         else:
175 |             x, y = X_train[i:(i+config['batch_size'])], Y_train[i:(i+config['batch_size'])]
176 |         x = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=True)
177 |         y = Variable(torch.from_numpy(y)).type(torch.FloatTensor)  # float targets for regression
178 |         if config['cuda']:
179 |             x, y = x.cuda(), y.cuda()
180 |         # zero the parameter gradients
181 |         optimizer.zero_grad()
182 |         output = model(x.unsqueeze(1))
183 |         loss = criterion(output.flatten(), y)
184 |         # backpropagate
185 |         loss.backward()
186 |         # update the network parameters from the gradients
187 |         optimizer.step()
188 |         batch_idx += 1
189 |         # loss.item() extracts the scalar value of the loss tensor
190 |         total_loss += loss.item()
191 | 
192 |     cur_loss = total_loss
193 |     print('Train Epoch: {:2d}\t Learning rate: {:.4f}\t Loss: {:.6f} \n '.format(
194 |         epoch, config['learning_rate'], cur_loss/batch_idx))
195 | 
196 | def evaluate(model):
197 |     model.eval()
198 |     total_loss = 0
199 |     pred = np.array([])
200 |     for i in range(0, X_test.shape[0], config['batch_size']):
201 |         if i + config['batch_size'] > X_test.shape[0]:
202 |             x, y = X_test[i:], Y_test[i:]
203 |         else:
204 |             x, y = X_test[i:(i+config['batch_size'])], Y_test[i:(i+config['batch_size'])]
205 |         x = Variable(torch.from_numpy(x).type(torch.FloatTensor))
206 |         y = Variable(torch.from_numpy(y).type(torch.FloatTensor))
207 |         if config['cuda']:
208 |             x, y = x.cuda(), y.cuda()
209 |         with torch.no_grad():
210 |             output = model(x.unsqueeze(1))
211 |             loss = criterion(output.flatten(), y)
212 |             pred = np.hstack((pred, output.flatten().cpu().numpy()))
213 |             total_loss += loss.item()
214 | 
215 |     print(Y_test, pred)
216 |     print('MAE: {}'.format(mean_absolute_error(Y_test, pred)))
217 |     print('='*89)
218 | 
219 |     return total_loss
220 | 
221 | model = CNN()
222 | if config['cuda']:
223 |     model = model.cuda()
224 | # MAE loss for PHQ8 regression; the original CrossEntropyLoss was left over from the
225 | # classification variant and does not fit a single-output regression head
226 | criterion = nn.L1Loss()
227 | # the optimizer must be created after the model it optimizes
228 | optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=1e-4)
229 | 
230 | for ep in range(1, config['epochs']):
231 |     train(ep)
232 |     tloss = evaluate(model)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.18.5
2 | matplotlib==3.2.2
3 | pandas==1.0.5
4 | fuzzywuzzy==0.18.0
5 | keras==2.4.3
6 | librosa==0.8.0
7 | pyenchant==3.1.1
8 | python_speech_features==0.6
9 | scikit_learn==0.23.1
10 | torch==1.6.0
11 | torchvision==0.7.0
12 | 
--------------------------------------------------------------------------------