├── .ipynb_checkpoints └── edf_dataset_to_numpy-checkpoint.ipynb ├── README.md ├── jupyter_example ├── .ipynb_checkpoints │ └── check_dataset_info-checkpoint.ipynb ├── SleepEEGNet_CNN.ipynb ├── check_dataset_info.ipynb ├── edf_dataset_to_numpy.ipynb └── makeDataset_each.ipynb ├── main.py ├── models └── cnn │ ├── DeepSleepNet_cnn.py │ ├── ResNet.py │ ├── __init__.py │ ├── __pycache__ │ ├── DeepSleepNet_cnn.cpython-38.pyc │ ├── ResNet.cpython-38.pyc │ └── __init__.cpython-38.pyc │ └── modules │ ├── ResNet_module.py │ ├── __init__.py │ └── __pycache__ │ ├── ResNet_module.cpython-38.pyc │ └── __init__.cpython-38.pyc ├── train ├── representation_learning │ └── single_epoch │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── train_resnet_representationlearning.cpython-38.pyc │ │ └── train_resnet_simCLR.cpython-38.pyc │ │ ├── train_resnet_representationlearning.py │ │ └── train_resnet_simCLR.py └── single_epoch │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── train_deepsleepnet.cpython-38.pyc │ └── train_resnet.cpython-38.pyc │ ├── train_deepsleepnet.py │ └── train_resnet.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-38.pyc ├── function.cpython-38.pyc ├── loss_fn.cpython-38.pyc └── scheduler.cpython-38.pyc ├── dataloader ├── Transform.py ├── __init__.py ├── __pycache__ │ ├── Transform.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ └── sleep_edf.cpython-38.pyc └── sleep_edf.py ├── dataset └── Sleep_edf │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── edf_to_numpy.cpython-38.pyc │ ├── function.cpython-38.pyc │ └── makeDataset_each.cpython-38.pyc │ ├── edf_to_numpy.py │ ├── function.py │ └── makeDataset_each.py ├── function.py ├── loss_fn.py └── scheduler.py /README.md: -------------------------------------------------------------------------------- 1 | # DeepSleepNet_pytorch 2 | (Revising...) 
3 | 4 | ### If you want to see simple results for DeepSleepNet without the RNN module, check the Jupyter notebook files (.ipynb). 5 | 6 | ### I am currently revising this repository to support a wider variety of use cases... 7 | -------------------------------------------------------------------------------- /jupyter_example/edf_dataset_to_numpy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Pytorch 1.3.1\n", 8 | "### cuda 10.1\n", 9 | "### Dataset = Sleep-edf-2013\n", 10 | "### pip install pyEDFlib : python에서 edf 파일을 열기 위한 라이브러리\n", 11 | "### pip install matplotlib " 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import numpy as np\n", 21 | "from pyedflib import highlevel\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "import os\n", 24 | "import pandas as pd\n", 25 | "import random\n", 26 | "import shutil" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "### 함수 search_annotations_edf \n", 34 | "#### 해당 함수는 Hypnogram.edf이라는 이름을 가진 파일만 추려내기 위한 함수이며 해당 파일들은 annotations ( sleep stage ) 정보를 저장하고 있는 파일이다.\n", 35 | "\n", 36 | "### 함수 search_signals_edf \n", 37 | "#### 해당 함수는 PSG.edf이라는 이름을 가진 파일만 추려내기 위한 함수이며 해당 파일들은 signals 정보를 저장하고 있는 파일이다." 
38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "def search_annotations_edf(dirname):\n", 47 | " filenames = os.listdir(dirname)\n", 48 | " filenames = [file for file in filenames if file.endswith(\"Hypnogram.edf\")]\n", 49 | " return filenames\n", 50 | "\n", 51 | "def search_signals_edf(dirname):\n", 52 | " filenames = os.listdir(dirname)\n", 53 | " filenames = [file for file in filenames if file.endswith(\"PSG.edf\")]\n", 54 | " return filenames\n", 55 | "\n", 56 | "def search_correct_annotations(dirname,filename):\n", 57 | " search_filename = filename.split('-')[0][:-2]\n", 58 | " file_list = os.listdir(dirname)\n", 59 | " filename = [file for file in file_list if search_filename in file if file.endswith(\"Hypnogram.edf\")]\n", 60 | " \n", 61 | " return filename\n", 62 | "\n", 63 | "def search_signals_npy(dirname):\n", 64 | " filenames = os.listdir(dirname)\n", 65 | " filenames = [file for file in filenames if file.endswith(\".npy\")]\n", 66 | " return filenames\n", 67 | "\n", 68 | "def search_correct_annotations_npy(dirname,filename):\n", 69 | " search_filename = filename.split('-')[0][:-2]\n", 70 | " file_list = os.listdir(dirname)\n", 71 | " filename = [file for file in file_list if search_filename in file if file.endswith(\"npy\")]\n", 72 | " \n", 73 | " return filename\n", 74 | "\n", 75 | "def search_correct_signals_npy(dirname,filename):\n", 76 | " search_filename = filename.split('-')[0][:-2]\n", 77 | " file_list = os.listdir(dirname)\n", 78 | " filename = [file for file in file_list if search_filename in file if file.endswith(\"npy\")]\n", 79 | " \n", 80 | " return filename" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 5, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "path = 'D:/dataset/data_2013/' # sleep-edf 2013 데이터를 가지고 있는 폴더 명\n", 90 | "annotations_edf_list = search_annotations_edf(path)\n", 91 | "signals_edf_list = 
search_signals_edf(path)\n", 92 | "\n", 93 | "print('signals edf file list')\n", 94 | "print(signals_edf_list)\n", 95 | "\n", 96 | "print('annotations edf file list')\n", 97 | "print(annotations_edf_list)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "#### signals_edf_list[0].split('-')[0][:-2]\n", 105 | "#### signals 와 같은 annotations 파일을 찾기 위함\n" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "# for filename in signals_edf_list:\n", 115 | "# print('signals file name : %s , annotations file name : %s'%(filename,search_correct_annotations(signal_path,filename)[0]))" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 8, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "epoch_size = 30\n", 125 | "sample_rate = 100\n", 126 | "save_signals_path = path + 'origin_npy/'\n", 127 | "save_annotations_path = save_signals_path+'annotations/'\n", 128 | "\n", 129 | "os.makedirs(save_annotations_path,exist_ok=True)\n", 130 | "os.makedirs(save_signals_path,exist_ok=True)\n", 131 | "\n", 132 | "for filename in signals_edf_list:\n", 133 | " signals_filename = filename\n", 134 | " annotations_filename = search_correct_annotations(signal_path,filename)[0]\n", 135 | " \n", 136 | " signals_filename = path + signals_filename\n", 137 | " annotations_filename = path + annotations_filename\n", 138 | " \n", 139 | " _, _, annotations_header = highlevel.read_edf(annotations_filename)\n", 140 | " \n", 141 | " label = []\n", 142 | " for ann in annotations_header['annotations']:\n", 143 | " start = ann[0]\n", 144 | "\n", 145 | " length = ann[1]\n", 146 | " length = int(str(length)[2:-1]) // epoch_size # label은 30초 간격으로 사용할 것이기 때문에 30으로 나눈 값이 해당 sleep stage가 반복된 횟수이다.\n", 147 | " \n", 148 | " if ann[2] == 'Sleep stage W':\n", 149 | " for time in range(length):\n", 150 | " label.append(0)\n", 151 | " 
elif ann[2] == 'Sleep stage 1':\n", 152 | " for time in range(length):\n", 153 | " label.append(1)\n", 154 | " elif ann[2] == 'Sleep stage 2':\n", 155 | " for time in range(length):\n", 156 | " label.append(2)\n", 157 | " elif ann[2] == 'Sleep stage 3':\n", 158 | " for time in range(length):\n", 159 | " label.append(3)\n", 160 | " elif ann[2] == 'Sleep stage 4':\n", 161 | " for time in range(length):\n", 162 | " label.append(3)\n", 163 | " elif ann[2] == 'Sleep stage R':\n", 164 | " for time in range(length):\n", 165 | " label.append(4)\n", 166 | " else:\n", 167 | " for time in range(length):\n", 168 | " label.append(5)\n", 169 | " label = np.array(label)\n", 170 | " signals, _, signals_header = highlevel.read_edf(signals_filename)\n", 171 | " \n", 172 | " \n", 173 | " signals_len = len(signals[0]) // sample_rate // epoch_size\n", 174 | " annotations_len = len(label)\n", 175 | " if signals_header['startdate'] == annotations_header['startdate']:\n", 176 | " print(\"%s file's signal & annotations start time is same\"%signals_filename.split('/')[-1])\n", 177 | " \n", 178 | " if signals_len > annotations_len :\n", 179 | " signals = signals[:3][:annotations_len]\n", 180 | " elif signals_len < annotations_len :\n", 181 | " signals = signals[:3]\n", 182 | " label = label[:signals_len]\n", 183 | " else:\n", 184 | " signals = signals[:3]\n", 185 | " signals = np.array(signals)\n", 186 | " \n", 187 | " np.save(save_signals_path + signals_filename.split('/')[-1].split('.')[0],signals)\n", 188 | " np.save(save_annotations_path + annotations_filename.split('/')[-1].split('.')[0],label)\n", 189 | " \n", 190 | " if (len(signals[0])//sample_rate//epoch_size != len(label)):\n", 191 | " print('signals len : %d / annotations len : %d'%(len(signals[0])//sample_rate//epoch_size,len(label)))\n", 192 | " \n", 193 | " else:\n", 194 | " print(\"%s file''s signal & annotations start time is different\"%signals_filename.split('/')[-1])\n", 195 | " " 196 | ] 197 | }, 198 | { 199 | 
"cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "### Signals\n", 203 | "#### channel 0 : EEG Fpz-Cz , sample rate = 100 , dimension = $\\mu$V prefilter : HP = 0.5Hz , LP = 100Hz\n", 204 | "#### channel 1 : EEG Pz-Oz , sample rate = 100 , dimension = $\\mu$V prefilter : HP = 0.5Hz , LP = 100Hz\n", 205 | "#### channel 2 : EOG horizontal , sample rate = 100 , dimension = $\\mu$V prefilter : HP = 0.5Hz , LP = 100Hz\n", 206 | "#### channel 3 : Resp oro-nasal \n", 207 | "#### channel 4 : EMG submental\n", 208 | "#### channel 5 : Temp rectal\n", 209 | "#### channel 6 : Event marker\n", 210 | "\n", 211 | "#### 실제로 학습에 사용하는 채널은 0,1 채널이고, Benchmark에서 높은 성능을 가지는 채널은 0번 채널 ( Fpz-Cz )이다.\n", 212 | "\n" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 9, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "epoch_size = 30\n", 222 | "sample_rate = 100\n", 223 | "\n", 224 | "path = 'D:/dataset/data_2013/origin_npy/'\n", 225 | "\n", 226 | "signals_npy_list = search_signals_npy(path)\n", 227 | "\n", 228 | "print(signals_npy_list)\n", 229 | "\n", 230 | "channel_name_list = ['Fpz-Cz/','Pz-Oz/','EOG/']\n", 231 | "for channel_index,channel_name in enumerate(channel_name_list):\n", 232 | " save_path = path + channel_name\n", 233 | " os.makedirs(save_path,exist_ok=True)\n", 234 | "\n", 235 | " for filename in signals_npy_list:\n", 236 | " signals_filename = filename\n", 237 | "\n", 238 | " signals_filename = path + signals_filename\n", 239 | " \n", 240 | " signals = np.load(signals_filename)\n", 241 | " \n", 242 | " signals = signals[channel_index].reshape(1,-1)\n", 243 | " print(signals.shape)\n", 244 | " \n", 245 | " np.save(save_path + filename,signals)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": {}, 251 | "source": [ 252 | "#### Channel 별 npy 파일 분리 작업" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 11, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": 
[ 261 | "epoch_size = 30\n", 262 | "sample_rate = 100\n", 263 | "\n", 264 | "path = 'D:/dataset/data_2013/origin_npy/Fpz-Cz/'\n", 265 | "annotations_path = 'D:/dataset/data_2013/origin_npy/annotations/'\n", 266 | "signals_npy_list = search_signals_npy(path)\n", 267 | "\n", 268 | "print(signals_npy_list)\n", 269 | "\n", 270 | "\n", 271 | "for filename in signals_npy_list:\n", 272 | " signals_filename = path + filename\n", 273 | " annotations_filename = annotations_path+search_correct_annotations_npy(annotations_path,filename)[0]\n", 274 | " signals = np.load(signals_filename)\n", 275 | " label = np.load(annotations_filename)\n", 276 | " if len(signals[0])//sample_rate//epoch_size != len(label):\n", 277 | " print('%s is fault'%filename)\n", 278 | " " 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 13, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "fs = 100 # Sampling rate (512 Hz)\n", 288 | "epoch_size = 30\n", 289 | "#data = np.random.uniform(0, 100, 1024) # 2 sec of data b/w 0.0-100.0\n", 290 | "\n", 291 | "path = 'D:/dataset/data_2013/origin_npy/Fpz-Cz/'\n", 292 | "\n", 293 | "signals_npy_list = search_signals_npy(path)\n", 294 | "\n", 295 | "for filename in signals_npy_list:\n", 296 | " signals = np.load(path+filename)\n", 297 | " length = len(signals[0])//fs//epoch_size\n", 298 | " print(signals.shape)\n", 299 | " for index in range(length):\n", 300 | " data = signals[0,int(index*fs*30) : int((index+1)*fs*30)]\n", 301 | " # Get real amplitudes of FFT (only in postive frequencies)\n", 302 | " fft_vals = np.absolute(np.fft.rfft(data)) / (fs*epoch_size)# real fft 계산 \n", 303 | "\n", 304 | " \n", 305 | " # fft_vals[:1*30+1] = 0\n", 306 | " # fft_vals[35*30:] = 0\n", 307 | " \n", 308 | " # Get frequencies for amplitudes in Hz\n", 309 | " fft_freq = np.fft.rfftfreq(len(data), 1.0/fs)\n", 310 | "\n", 311 | " # Define EEG bands\n", 312 | " eeg_bands = {'Delta-0-0.5': (0, 0.5),\n", 313 | " 'Delta-0.5-1': (0.5, 1),\n", 314 
| " 'Delta-1-2': (1, 2),\n", 315 | " 'Delta-2-3': (2, 3),\n", 316 | " 'Delta-3-4': (3, 4),\n", 317 | " 'Theta-4-5': (4, 5),\n", 318 | " 'Theta-5-6': (5, 6),\n", 319 | " 'Theta-6-7': (6, 7),\n", 320 | " 'Theta-7-8': (7, 8),\n", 321 | " 'Alpha-8-9': (8, 9),\n", 322 | " 'Alpha-9-10': (9, 10),\n", 323 | " 'Alpha-10-11': (10, 11),\n", 324 | " 'Alpha-11-12': (11, 12),\n", 325 | " 'Beta': (12, 30),\n", 326 | " 'Gamma': (30, 45)}\n", 327 | "\n", 328 | " # Take the mean of the fft amplitude for each EEG band\n", 329 | " eeg_band_fft = []\n", 330 | " for band in eeg_bands: \n", 331 | " #print('band : ',band)\n", 332 | " freq_ix = np.where((fft_freq >= eeg_bands[band][0]) & \n", 333 | " (fft_freq <= eeg_bands[band][1]))[0]\n", 334 | " \n", 335 | " eeg_band_fft.append(np.mean(fft_vals[freq_ix]))\n", 336 | " eeg_band_fft = np.array(eeg_band_fft)\n", 337 | " # Plot the data (using pandas here cause it's easy)\n", 338 | "\n", 339 | " print(eeg_bands.keys())\n", 340 | " print(eeg_band_fft)\n", 341 | " plt.bar(eeg_bands.keys(),eeg_band_fft)\n", 342 | " plt.xticks(rotation=90)\n", 343 | " plt.show()\n", 344 | "\n" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "#### Frequency 대역 정보 확인" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 14, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "fs = 100 # Sampling rate (512 Hz)\n", 361 | "epoch_size = 30\n", 362 | "#data = np.random.uniform(0, 100, 1024) # 2 sec of data b/w 0.0-100.0\n", 363 | "\n", 364 | "path = 'D:/dataset/data_2013/origin_npy/annotations/'\n", 365 | "\n", 366 | "annotations_npy_list = search_signals_npy(path)\n", 367 | "\n", 368 | "check_index_size = 10\n", 369 | "\n", 370 | "for filename in annotations_npy_list:\n", 371 | " label_info = np.zeros([5],dtype=int)\n", 372 | " label = np.load(path + filename)\n", 373 | " \n", 374 | " \n", 375 | "\n", 376 | " plt.plot(label)\n", 377 | " plt.show()\n", 378 | " print('='*20)\n", 
379 | "\n" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": {}, 385 | "source": [ 386 | "#### label 구성 확인" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 15, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "fs = 100 # Sampling rate (512 Hz)\n", 396 | "epoch_size = 30\n", 397 | "#data = np.random.uniform(0, 100, 1024) # 2 sec of data b/w 0.0-100.0\n", 398 | "\n", 399 | "path = 'D:/dataset/data_2013/origin_npy/annotations/'\n", 400 | "signals_path = 'D:/dataset/data_2013/origin_npy/Fpz-Cz/'\n", 401 | "\n", 402 | "save_annotations_path = path + 'remove_wake/'\n", 403 | "save_signals_path = signals_path + 'remove_wake/'\n", 404 | "\n", 405 | "os.makedirs(save_annotations_path,exist_ok=True)\n", 406 | "os.makedirs(save_signals_path,exist_ok=True)\n", 407 | "annotations_npy_list = search_signals_npy(path)\n", 408 | "\n", 409 | "check_index_size = 10\n", 410 | "\n", 411 | "total_label = np.zeros([6],dtype=int)\n", 412 | "\n", 413 | "for filename in annotations_npy_list:\n", 414 | " label = np.load(path + filename)\n", 415 | " signals_filename = search_correct_signals_npy(signals_path,filename)[0]\n", 416 | " \n", 417 | " signals = np.load(signals_path+signals_filename)\n", 418 | " \n", 419 | " for remove_start_index in range(0,len(label),1):\n", 420 | " #print(np.bincount(label[remove_start_index:(remove_start_index+check_index_size)],minlength=6)[0])\n", 421 | " if(np.bincount(label[remove_start_index:(remove_start_index+check_index_size)],minlength=6)[0] != check_index_size):\n", 422 | " break\n", 423 | " \n", 424 | " for remove_end_index in range(len(label),-1,-1,):\n", 425 | " #print(np.bincount(label[remove_end_index-check_index_size:(remove_end_index)],minlength=6)[0])\n", 426 | " if(np.bincount(label[remove_end_index-check_index_size:(remove_end_index)],minlength=6)[0] != check_index_size and np.bincount(label[remove_end_index-check_index_size:(remove_end_index)],minlength=6)[5] == 0 ):\n", 
427 | " break\n", 428 | " \n", 429 | " #print('remove start index : %d / remove end index : %d'%(remove_start_index,remove_end_index))\n", 430 | " label = label[remove_start_index:remove_end_index+1]\n", 431 | " signals = signals[0,remove_start_index*fs*epoch_size:(remove_end_index+1)*fs*epoch_size].reshape(1,-1)\n", 432 | " #print(np.bincount(label,minlength=6))\n", 433 | " if len(label) ==len(signals[0])//30//fs:\n", 434 | " np.save(save_annotations_path+filename.split('.')[0],label)\n", 435 | " np.save(save_signals_path+signals_filename.split('.')[0],signals)\n", 436 | " for i in range(6):\n", 437 | " total_label[i] += np.bincount(label,minlength=6)[i]\n", 438 | " \n", 439 | " \n", 440 | "print(total_label)\n" 441 | ] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "metadata": {}, 446 | "source": [ 447 | "#### Wake 수 줄이는 작업" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 16, 453 | "metadata": {}, 454 | "outputs": [ 455 | { 456 | "name": "stdout", 457 | "output_type": "stream", 458 | "text": [ 459 | "[ 4258 2762 17340 5575 7522 59]\n" 460 | ] 461 | }, 462 | { 463 | "data": { 464 | "text/plain": [ 465 | "" 466 | ] 467 | }, 468 | "execution_count": 16, 469 | "metadata": {}, 470 | "output_type": "execute_result" 471 | }, 472 | { 473 | "data": { 474 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAASgklEQVR4nO3df6zddX3H8edrRZxTCTgurGth7Uw1QbJVuUESonFDsICxuOjWZpPOsVUNZDOabGX7A6cjYZvOjcRhqjaWzNGxIaORKnbMjZiA9FYrP0TWC3ZybUMrdRPjgim+98f53O3YntvennN6z237fCQn5/t9fz/f731/Q+jrfj/f7zk3VYUk6eT2U6NuQJI0eoaBJMkwkCQZBpIkDANJEnDKqBvo15lnnllLliwZdRuSdFzZvn37d6tq7OD6cRsGS5YsYWJiYtRtSNJxJcl/9qo7TSRJMgwkSYaBJAnDQJLELMIgyYYke5M80lX7hyQ72mtXkh2tviTJ/3Rt+3jXPhckeTjJZJKbk6TVX5Zka5Kd7f2MY3GikqSZzebK4NPAiu5CVf1GVS2vquXAHcBnuzY/Mb2tqt7dVb8FWAssa6/pY64D7q2qZcC9bV2SNIeOGAZVdR+wv9e29tv9rwO3He4YSRYCp1XV/dX5mtRbgava5pXAxra8sasuSZojg94zeB3wdFXt7KotTfK1JP+e5HWttgiY6hoz1WoAZ1fVHoD2ftZMPyzJ2iQTSSb27ds3YOuSpGmDhsFqfvKqYA9wblW9Gngf8PdJTgPSY9+j/kMKVbW+qsaranxs7JAP0EmS+tT3J5CTnAL8GnDBdK2qngOea8vbkzwBvILOlcDirt0XA7vb8tNJFlbVnjadtLffnnTiWbLu7lG3cIhdN1056hakoRvkyuCNwDer6v+mf5KMJVnQln+Rzo3iJ9v0z7NJLmr3Ga4G7mq7bQbWtOU1XXVJ0hyZzaOltwH3A69MMpXkmrZpFYfeOH498FCSrwP/BLy7qqZvPr8H+CQwCTwBfL7VbwIuTbITuLStS5Lm0BGniapq9Qz13+5Ru4POo6a9xk8A5/eoPwNccqQ+JEnHjp9AliQZBpIkw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJDGLMEiyIcneJI901T6Q5DtJdrTXFV3brk8ymeTxJG/qqq9otckk67rqS5N8JcnOJP+Q5NRhnqAk6chmc2XwaWBFj/pHq2p5e20BSHIesAp4Vdvnb5MsSLIA+BhwOXAesLqNBfjzdqxlwPeAawY5IUnS0TtiGFTVfcD+WR5vJbCpqp6rqm8Bk8CF7TVZVU9W1Y+ATcDKJAF+Ffintv9G4KqjPAdJ0oAGuWdwXZKH2jTSGa22CHiqa8xUq81U/1ngv6rqwEH1npKsTTKRZGLfvn0DtC5J6tZvGNwCvBxYDuwBPtLq6TG2+qj3VFXrq2q8qsbHxsaOrmNJ0oxO6Wenqnp6ejnJJ4DPtdUp4JyuoYuB3W25V/27wOlJTmlXB93jJUlzpK8rgyQLu1bfCkw/abQZWJXkhUmWAsuAB4FtwLL25NCpdG4yb66qAr4EvK3tvwa4q5+eJEn9O+KVQZLbgDcAZyaZAm4A3pBkOZ0pnV3AuwCq6tEktwPfAA4A11bV8+041wH3AAuADVX1aPsRfwRsSvJnwNeATw3t7CRJs3LEMKiq1T3KM/6DXVU3Ajf2qG8BtvSoP0nnaSNJ0oj4CWRJkmEgSTIMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiRmEQZJNiTZm+SRrtpfJvlmkoeS3Jnk9FZfkuR/kuxor4937XNBkoeTTCa5OUla/WVJtibZ2d7POBYnKkma2Wy
uDD4NrDiothU4v6p+CfgP4PqubU9U1fL2endX/RZgLbCsvaaPuQ64t6qWAfe2dUnSHDpiGFTVfcD+g2pfrKoDbfUBYPHhjpFkIXBaVd1fVQXcClzVNq8ENrbljV11SdIcGcY9g98BPt+1vjTJ15L8e5LXtdoiYKprzFSrAZxdVXsA2vtZQ+hJknQUThlk5yR/AhwAPtNKe4Bzq+qZJBcA/5zkVUB67F59/Ly1dKaaOPfcc/trWpJ0iL6vDJKsAd4M/Gab+qGqnquqZ9ryduAJ4BV0rgS6p5IWA7vb8tNtGml6OmnvTD+zqtZX1XhVjY+NjfXbuiTpIH2FQZIVwB8Bb6mqH3bVx5IsaMu/SOdG8ZNt+ufZJBe1p4iuBu5qu20G1rTlNV11SdIcOeI0UZLbgDcAZyaZAm6g8/TQC4Gt7QnRB9qTQ68HPpjkAPA88O6qmr75/B46Tya9iM49hun7DDcBtye5Bvg28PahnJkkadaOGAZVtbpH+VMzjL0DuGOGbRPA+T3qzwCXHKkPSdKx4yeQJUmGgSTJMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkScwyDJJsSLI3ySNdtZcl2ZpkZ3s/o9WT5OYkk0keSvKarn3WtPE7k6zpql+Q5OG2z81JMsyTlCQd3myvDD4NrDiotg64t6qWAfe2dYDLgWXttRa4BTrhAdwAvBa4ELhhOkDamLVd+x38syRJx9CswqCq7gP2H1ReCWxsyxuBq7rqt1bHA8DpSRYCbwK2VtX+qvoesBVY0badVlX3V1UBt3YdS5I0Bwa5Z3B2Ve0BaO9ntfoi4KmucVOtdrj6VI/6IZKsTTKRZGLfvn0DtC5J6nYsbiD3mu+vPuqHFqvWV9V4VY2PjY0N0KIkqdsgYfB0m+Khve9t9SngnK5xi4HdR6gv7lGXJM2RQcJgMzD9RNAa4K6u+tXtqaKLgP9u00j3AJclOaPdOL4MuKdtezbJRe0poqu7jiVJmgOnzGZQktuANwBnJpmi81TQTcDtSa4Bvg28vQ3fAlwBTAI/BN4JUFX7k3wI2NbGfbCqpm9Kv4fOE0svAj7fXpKkOTKrMKiq1TNsuqTH2AKuneE4G4ANPeoTwPmz6UWSNHx+AlmSZBhIkgwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJLEAGGQ5JVJdnS9vp/kvUk+kOQ7XfUruva5PslkkseTvKmrvqLVJpOsG/SkJElH55R+d6yqx4HlAEkWAN8B7gTeCXy0qj7cPT7JecAq4FXAzwP/kuQVbfPHgEuBKWBbks1V9Y1+e5MkHZ2+w+AglwBPVNV/JplpzEpgU1U9B3wrySRwYds2WVVPAiTZ1MYaBpI0R4YVBquA27rWr0tyNTABvL+qvgcsAh7oGjPVagBPHVR/ba8fkmQtsBbg3HPPHU7nkoZiybq7R93CIXbddOWoWzhuDHwDOcmpwFuAf2ylW4CX05lC2gN8ZHpoj93rMPVDi1Xrq2q8qsbHxsYG6luS9P+GcWVwOfDVqnoaYPodIMkngM+11SngnK79FgO72/JMdUnSHBjGo6Wr6ZoiSrKwa9tbgUfa8mZgVZIXJlkKLAMeBLYBy5IsbVcZq9pYSdIcGejKIMnP0HkK6F1d5b9IspzOVM+u6W1V9WiS2+ncGD4AXFtVz7fjXAfcAywANlTVo4P0JUk6OgOFQVX9EPjZg2rvOMz4G4Ebe9S3AFsG6UWS1D8/gSxJMgwkSYaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQG/Etnko7eknV3j7qFnnbddOWoW9AIeWUgSTIMJEmGgSSJIYRBkl1JHk6
yI8lEq70sydYkO9v7Ga2eJDcnmUzyUJLXdB1nTRu/M8maQfuSJM3esK4MfqWqllfVeFtfB9xbVcuAe9s6wOXAsvZaC9wCnfAAbgBeC1wI3DAdIJKkY+9YTROtBDa25Y3AVV31W6vjAeD0JAuBNwFbq2p/VX0P2AqsOEa9SZIOMowwKOCLSbYnWdtqZ1fVHoD2flarLwKe6tp3qtVmqv+EJGuTTCSZ2Ldv3xBalyTBcD5ncHFV7U5yFrA1yTcPMzY9anWY+k8WqtYD6wHGx8cP2S5J6s/AVwZVtbu97wXupDPn/3Sb/qG9723Dp4BzunZfDOw+TF2SNAcGCoMkL07y0ull4DLgEWAzMP1E0Brgrra8Gbi6PVV0EfDfbRrpHuCyJGe0G8eXtZokaQ4MOk10NnBnkulj/X1VfSHJNuD2JNcA3wbe3sZvAa4AJoEfAu8EqKr9ST4EbGvjPlhV+wfsTZI0SwOFQVU9Cfxyj/ozwCU96gVcO8OxNgAbBulHktQfP4EsSTIMJEmGgSSJk/TvGczH75P3u+QljZJXBpIkw0CSZBhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgSWKAv3SW5BzgVuDngB8D66vqb5J8APg9YF8b+sdVtaXtcz1wDfA88PtVdU+rrwD+BlgAfLKqbuq3r5PVfPzrbeBfcJOOF4P82csDwPur6qtJXgpsT7K1bftoVX24e3CS84BVwKuAnwf+Jckr2uaPAZcCU8C2JJur6hsD9CZJOgp9h0FV7QH2tOVnkzwGLDrMLiuBTVX1HPCtJJPAhW3bZFU9CZBkUxtrGEjSHBnKPYMkS4BXA19ppeuSPJRkQ5IzWm0R8FTXblOtNlO9189Zm2QiycS+fft6DZEk9WHgMEjyEuAO4L1V9X3gFuDlwHI6Vw4fmR7aY/c6TP3QYtX6qhqvqvGxsbFBW5ckNYPcMyDJC+gEwWeq6rMAVfV01/ZPAJ9rq1PAOV27LwZ2t+WZ6pKkOdD3lUGSAJ8CHquqv+qqL+wa9lbgkba8GViV5IVJlgLLgAeBbcCyJEuTnErnJvPmfvuSJB29Qa4MLgbeATycZEer/TGwOslyOlM9u4B3AVTVo0lup3Nj+ABwbVU9D5DkOuAeOo+WbqiqRwfoS5J0lAZ5mujL9J7v33KYfW4EbuxR33K4/SRJx5afQJYkGQaSJMNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCQxj8IgyYokjyeZTLJu1P1I0slkXoRBkgXAx4DLgfOA1UnOG21XknTyOGXUDTQXApNV9SRAkk3ASuAbI+1K0kltybq7R93CIXbddOUxOW6q6pgc+KiaSN4GrKiq323r7wBeW1XXHTRuLbC2rb4SeHxOG+3tTOC7o25iyE7Ec4IT87w8p+PHfDmvX6iqsYOL8+XKID1qh6RUVa0H1h/7dmYvyURVjY+6j2E6Ec8JTszz8pyOH/P9vObFPQNgCjina30xsHtEvUjSSWe+hME2YFmSpUlOBVYBm0fckySdNObFNFFVHUhyHXAPsADYUFWPjrit2ZpX01ZDciKeE5yY5+U5HT/m9XnNixvIkqTRmi/TRJKkETIMJEmGQb9OxK/PSLIhyd4kj4y6l2FJck6SLyV5LMmjSf5g1D0NQ5KfTvJgkq+38/rTUfc0LEkWJPlaks+NupdhSLIrycNJdiSZGHU/M/GeQR/a12f8B3ApncditwGrq+q4/sR0ktcDPwBurarzR93PMCRZCCysqq8meSmwHbjqBPhvFeDFVfWDJC8Avgz8QVU9MOLWBpbkfcA4cFpVvXnU/QwqyS5gvKrmwwfOZuSVQX/+7+szqupHwPTXZxzXquo+YP+o+ximqtpTVV9ty88CjwGLRtvV4KrjB231Be113P9
ml2QxcCXwyVH3crIxDPqzCHiqa32KE+AfmBNdkiXAq4GvjLaT4WjTKTuAvcDWqjoRzuuvgT8EfjzqRoaogC8m2d6+UmdeMgz6M6uvz9D8keQlwB3Ae6vq+6PuZxiq6vmqWk7nE/sXJjmup/aSvBnYW1XbR93LkF1cVa+h863M17bp2HnHMOiPX59xHGlz6ncAn6mqz466n2Grqv8C/g1YMeJWBnUx8JY2x74J+NUkfzfalgZXVbvb+17gTjrTzPOOYdAfvz7jONFutH4KeKyq/mrU/QxLkrEkp7flFwFvBL452q4GU1XXV9XiqlpC5/+pf62q3xpxWwNJ8uL24AJJXgxcBszLp/UMgz5U1QFg+uszHgNuP46+PmNGSW4D7gdemWQqyTWj7mkILgbeQee3zB3tdcWomxqChcCXkjxE55eTrVV1QjyKeYI5G/hykq8DDwJ3V9UXRtxTTz5aKknyykCSZBhIkjAMJEkYBpIkDANJEoaBJAnDQJIE/C8nHoLX0zV5vgAAAABJRU5ErkJggg==\n", 475 | "text/plain": [ 476 | "
" 477 | ] 478 | }, 479 | "metadata": { 480 | "needs_background": "light" 481 | }, 482 | "output_type": "display_data" 483 | } 484 | ], 485 | "source": [ 486 | "fs = 100 # Sampling rate (512 Hz)\n", 487 | "epoch_size = 30\n", 488 | "#data = np.random.uniform(0, 100, 1024) # 2 sec of data b/w 0.0-100.0\n", 489 | "\n", 490 | "path = 'D:/dataset/data_2013/origin_npy/annotations/remove_wake/'\n", 491 | "signals_path = 'D:/dataset/data_2013/origin_npy/Fpz-Cz/remove_wake/'\n", 492 | "\n", 493 | "annotations_npy_list = search_signals_npy(path)\n", 494 | "\n", 495 | "total_label = np.zeros([6],dtype=int)\n", 496 | "\n", 497 | "for filename in annotations_npy_list:\n", 498 | " label = np.load(path + filename)\n", 499 | " signals_filename = search_correct_signals_npy(signals_path,filename)[0]\n", 500 | " \n", 501 | " signals = np.load(signals_path+signals_filename)\n", 502 | " \n", 503 | " \n", 504 | " #print('remove start index : %d / remove end index : %d'%(remove_start_index,remove_end_index))\n", 505 | " #print(np.bincount(label,minlength=6))\n", 506 | " if len(label) !=len(signals[0])//30//fs:\n", 507 | " print('file is fault!!!')\n", 508 | " for i in range(6):\n", 509 | " total_label[i] += np.bincount(label,minlength=6)[i]\n", 510 | " \n", 511 | "print(total_label)\n", 512 | "\n", 513 | "x = np.arange(len(total_label))\n", 514 | "\n", 515 | "plt.bar(x,total_label,width=0.7)" 516 | ] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": 17, 521 | "metadata": {}, 522 | "outputs": [ 523 | { 524 | "name": "stdout", 525 | "output_type": "stream", 526 | "text": [ 527 | "38\n" 528 | ] 529 | } 530 | ], 531 | "source": [ 532 | "print(len(annotations_npy_list))" 533 | ] 534 | }, 535 | { 536 | "cell_type": "markdown", 537 | "metadata": {}, 538 | "source": [ 539 | "#### 최종 데이터셋 형태" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": 19, 545 | "metadata": {}, 546 | "outputs": [ 547 | { 548 | "name": "stdout", 549 | "output_type": "stream", 550 
| "text": [ 551 | "['SC4001EC-Hypnogram.npy', 'SC4002EC-Hypnogram.npy', 'SC4011EH-Hypnogram.npy', 'SC4012EC-Hypnogram.npy', 'SC4021EH-Hypnogram.npy', 'SC4022EJ-Hypnogram.npy', 'SC4031EC-Hypnogram.npy', 'SC4032EP-Hypnogram.npy', 'SC4041EC-Hypnogram.npy', 'SC4042EC-Hypnogram.npy', 'SC4051EC-Hypnogram.npy', 'SC4052EC-Hypnogram.npy', 'SC4061EC-Hypnogram.npy', 'SC4062EC-Hypnogram.npy', 'SC4071EC-Hypnogram.npy', 'SC4072EH-Hypnogram.npy', 'SC4081EC-Hypnogram.npy', 'SC4082EP-Hypnogram.npy', 'SC4091EC-Hypnogram.npy', 'SC4092EC-Hypnogram.npy', 'SC4101EC-Hypnogram.npy', 'SC4102EC-Hypnogram.npy', 'SC4111EC-Hypnogram.npy', 'SC4112EC-Hypnogram.npy', 'SC4121EC-Hypnogram.npy', 'SC4122EV-Hypnogram.npy', 'SC4131EC-Hypnogram.npy', 'SC4141EU-Hypnogram.npy', 'SC4142EU-Hypnogram.npy', 'SC4151EC-Hypnogram.npy', 'SC4152EC-Hypnogram.npy', 'SC4161EC-Hypnogram.npy', 'SC4171EU-Hypnogram.npy', 'SC4172EC-Hypnogram.npy', 'SC4181EC-Hypnogram.npy', 'SC4182EC-Hypnogram.npy', 'SC4191EP-Hypnogram.npy', 'SC4192EV-Hypnogram.npy']\n", 552 | "['SC4092EC-Hypnogram.npy', 'SC4171EU-Hypnogram.npy', 'SC4031EC-Hypnogram.npy', 'SC4072EH-Hypnogram.npy', 'SC4101EC-Hypnogram.npy', 'SC4022EJ-Hypnogram.npy', 'SC4161EC-Hypnogram.npy', 'SC4071EC-Hypnogram.npy', 'SC4102EC-Hypnogram.npy', 'SC4152EC-Hypnogram.npy', 'SC4061EC-Hypnogram.npy', 'SC4001EC-Hypnogram.npy', 'SC4111EC-Hypnogram.npy', 'SC4141EU-Hypnogram.npy', 'SC4131EC-Hypnogram.npy', 'SC4012EC-Hypnogram.npy', 'SC4122EV-Hypnogram.npy', 'SC4062EC-Hypnogram.npy', 'SC4192EV-Hypnogram.npy', 'SC4052EC-Hypnogram.npy', 'SC4042EC-Hypnogram.npy', 'SC4151EC-Hypnogram.npy', 'SC4021EH-Hypnogram.npy', 'SC4002EC-Hypnogram.npy', 'SC4181EC-Hypnogram.npy', 'SC4121EC-Hypnogram.npy', 'SC4081EC-Hypnogram.npy', 'SC4191EP-Hypnogram.npy', 'SC4011EH-Hypnogram.npy', 'SC4051EC-Hypnogram.npy', 'SC4142EU-Hypnogram.npy', 'SC4041EC-Hypnogram.npy', 'SC4112EC-Hypnogram.npy', 'SC4172EC-Hypnogram.npy', 'SC4182EC-Hypnogram.npy', 'SC4082EP-Hypnogram.npy', 'SC4091EC-Hypnogram.npy', 
'SC4032EP-Hypnogram.npy']\n", 553 | "[11.15174913 7.62978767 46.29686071 14.49083951 20.28514342 0.14561956]\n", 554 | "[12.08213347 6.37285589 45.93714787 16.22636785 19.1811694 0.20032553]\n" 555 | ] 556 | } 557 | ], 558 | "source": [ 559 | "fs = 100 # Sampling rate (512 Hz)\n", 560 | "epoch_size = 30\n", 561 | "#data = np.random.uniform(0, 100, 1024) # 2 sec of data b/w 0.0-100.0\n", 562 | "\n", 563 | "path = 'D:/dataset/data_2013/origin_npy/annotations/remove_wake/'\n", 564 | "signals_path = 'D:/dataset/data_2013/origin_npy/Fpz-Cz/remove_wake/'\n", 565 | "\n", 566 | "annotations_npy_list = search_signals_npy(path)\n", 567 | "\n", 568 | "print(annotations_npy_list)\n", 569 | "\n", 570 | "random.shuffle(annotations_npy_list)\n", 571 | "\n", 572 | "print(annotations_npy_list)\n", 573 | "\n", 574 | "trainDataset_count = 30\n", 575 | "testDataset_count = len(annotations_npy_list)-trainDataset_count\n", 576 | "\n", 577 | "train_label = np.zeros([6],dtype=int)\n", 578 | "test_label = np.zeros([6],dtype=int)\n", 579 | "\n", 580 | "for filename in annotations_npy_list[:trainDataset_count]:\n", 581 | " label = np.load(path + filename)\n", 582 | " \n", 583 | " for i in range(6):\n", 584 | " train_label[i] += np.bincount(label,minlength=6)[i]\n", 585 | "\n", 586 | " \n", 587 | "for filename in annotations_npy_list[trainDataset_count:]:\n", 588 | " label = np.load(path + filename)\n", 589 | " \n", 590 | " for i in range(6):\n", 591 | " test_label[i] += np.bincount(label,minlength=6)[i]\n", 592 | " \n", 593 | "train_label = train_label / np.sum(train_label) * 100\n", 594 | "test_label = test_label / np.sum(test_label) * 100\n", 595 | "print(train_label)\n", 596 | "print(test_label)" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 21, 602 | "metadata": {}, 603 | "outputs": [], 604 | "source": [ 605 | "path = 'D:/dataset/data_2013/origin_npy/annotations/remove_wake/'\n", 606 | "signals_path = 
'D:/dataset/data_2013/origin_npy/Fpz-Cz/remove_wake/'\n", 607 | "\n", 608 | "save_train_path = 'D:/dataset/data_2013/origin_npy/Fpz-Cz/remove_wake/train/'\n", 609 | "save_test_path = 'D:/dataset/data_2013/origin_npy/Fpz-Cz/remove_wake/test/'\n", 610 | "\n", 611 | "os.makedirs(save_train_path,exist_ok=True)\n", 612 | "os.makedirs(save_test_path,exist_ok=True)\n", 613 | "\n", 614 | "for filename in annotations_npy_list[:trainDataset_count]:\n", 615 | " signals_filename = search_correct_signals_npy(signals_path,filename)[0]\n", 616 | " shutil.copy(signals_path+signals_filename,save_train_path+filename)\n", 617 | " \n", 618 | "\n", 619 | " \n", 620 | "for filename in annotations_npy_list[trainDataset_count:]:\n", 621 | " signals_filename = search_correct_signals_npy(signals_path,filename)[0]\n", 622 | " shutil.copy(signals_path+signals_filename,save_test_path+filename)" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 255, 628 | "metadata": {}, 629 | "outputs": [ 630 | { 631 | "name": "stdout", 632 | "output_type": "stream", 633 | "text": [ 634 | "['SC4001EC-Hypnogram.npy', 'SC4011EH-Hypnogram.npy', 'SC4012EC-Hypnogram.npy', 'SC4021EH-Hypnogram.npy', 'SC4032EP-Hypnogram.npy', 'SC4041EC-Hypnogram.npy', 'SC4051EC-Hypnogram.npy', 'SC4052EC-Hypnogram.npy', 'SC4061EC-Hypnogram.npy', 'SC4062EC-Hypnogram.npy', 'SC4071EC-Hypnogram.npy', 'SC4081EC-Hypnogram.npy', 'SC4082EP-Hypnogram.npy', 'SC4091EC-Hypnogram.npy', 'SC4092EC-Hypnogram.npy', 'SC4102EC-Hypnogram.npy', 'SC4111EC-Hypnogram.npy', 'SC4112EC-Hypnogram.npy', 'SC4121EC-Hypnogram.npy', 'SC4122EV-Hypnogram.npy', 'SC4131EC-Hypnogram.npy', 'SC4142EU-Hypnogram.npy', 'SC4152EC-Hypnogram.npy', 'SC4161EC-Hypnogram.npy', 'SC4171EU-Hypnogram.npy', 'SC4172EC-Hypnogram.npy', 'SC4181EC-Hypnogram.npy', 'SC4191EP-Hypnogram.npy']\n", 635 | "['SC4002EC-Hypnogram.npy', 'SC4022EJ-Hypnogram.npy', 'SC4031EC-Hypnogram.npy', 'SC4042EC-Hypnogram.npy', 'SC4072EH-Hypnogram.npy', 'SC4101EC-Hypnogram.npy', 
'SC4141EU-Hypnogram.npy', 'SC4151EC-Hypnogram.npy', 'SC4182EC-Hypnogram.npy', 'SC4192EV-Hypnogram.npy']\n", 636 | "[11.52193537 6.78461816 47.01570681 15.07853403 19.44394295 0.15526268]\n", 637 | "[10.86447409 8.99093779 43.97719173 14.24498524 21.75949496 0.1629162 ]\n" 638 | ] 639 | } 640 | ], 641 | "source": [ 642 | "\n", 643 | "train_path = 'D:/dataset/data_2013/origin_npy/Fpz-Cz/remove_wake/train/'\n", 644 | "test_path = 'D:/dataset/data_2013/origin_npy/Fpz-Cz/remove_wake/test/'\n", 645 | "annotations_path = 'D:/dataset/data_2013/origin_npy/annotations/remove_wake/'\n", 646 | "\n", 647 | "train_list = search_signals_npy(train_path)\n", 648 | "test_list = search_signals_npy(test_path)\n", 649 | "\n", 650 | "print(train_list)\n", 651 | "print(test_list)\n", 652 | "\n", 653 | "train_label = np.zeros([6],dtype=int)\n", 654 | "test_label = np.zeros([6],dtype=int)\n", 655 | "\n", 656 | "for filename in train_list:\n", 657 | " filename = search_correct_annotations_npy(annotations_path,filename)[0]\n", 658 | " label = np.load(annotations_path + filename)\n", 659 | " \n", 660 | " for i in range(6):\n", 661 | " train_label[i] += np.bincount(label,minlength=6)[i]\n", 662 | "\n", 663 | " \n", 664 | "for filename in test_list:\n", 665 | " filename = search_correct_annotations_npy(annotations_path,filename)[0]\n", 666 | " label = np.load(annotations_path + filename)\n", 667 | " \n", 668 | " for i in range(6):\n", 669 | " test_label[i] += np.bincount(label,minlength=6)[i]\n", 670 | " \n", 671 | "train_label = train_label / np.sum(train_label) * 100\n", 672 | "test_label = test_label / np.sum(test_label) * 100\n", 673 | "print(train_label)\n", 674 | "print(test_label)" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": null, 680 | "metadata": {}, 681 | "outputs": [], 682 | "source": [] 683 | } 684 | ], 685 | "metadata": { 686 | "kernelspec": { 687 | "display_name": "pytorch-1.3.1-cuda10.1", 688 | "language": "python", 689 | "name": 
"pytorch-cuda10.1" 690 | }, 691 | "language_info": { 692 | "codemirror_mode": { 693 | "name": "ipython", 694 | "version": 3 695 | }, 696 | "file_extension": ".py", 697 | "mimetype": "text/x-python", 698 | "name": "python", 699 | "nbconvert_exporter": "python", 700 | "pygments_lexer": "ipython3", 701 | "version": "3.7.5" 702 | } 703 | }, 704 | "nbformat": 4, 705 | "nbformat_minor": 2 706 | } 707 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from models.cnn.DeepSleepNet_cnn import * 2 | from utils.dataset.Sleep_edf.edf_to_numpy import * 3 | from utils.dataset.Sleep_edf.makeDataset_each import * 4 | from train.single_epoch.train_deepsleepnet import * 5 | from train.single_epoch.train_resnet import * 6 | 7 | from train.representation_learning.single_epoch.train_resnet_representationlearning import * 8 | from train.representation_learning.single_epoch.train_resnet_simCLR import * 9 | 10 | from utils.function import * 11 | 12 | from torchsummary import summary 13 | 14 | def make_datasets(): 15 | # check_label() 16 | files = check_edf_dataset(path='/home/eslab/dataset/sleep-edf-database-expanded-1.0.0/sleep-cassette/',type='edf') 17 | signals_edf_list = files['signals_file_list'] 18 | annotation_list = files['annotation_file_list'] 19 | 20 | print(len(signals_edf_list),len(annotation_list)) 21 | # make_dataset() 22 | # check_wellmade() 23 | # remove_unnessersary_wake() 24 | # makeDataset_for_loader() 25 | 26 | def check_distribution_correct(): 27 | data_path = '/home/eslab/dataset/sleep_edf/origin_npy/remove_wake_version0/each/' 28 | patient_list = os.listdir(data_path) 29 | print(len(patient_list)) 30 | 31 | patient_list = [data_path + filename for filename in patient_list] 32 | 33 | labels, labels_percent = check_label_info_withPath(patient_list) 34 | print(labels) 35 | print(labels_percent) 36 | 37 | original_path = 
'/home/eslab/dataset/sleep_edf/annotations/remove_wake_version0/' 38 | 39 | annotation_list = search_signals_npy(original_path) 40 | # print(len(annotation_list)) 41 | labels = np.zeros(6) 42 | for filename in annotation_list: 43 | label = np.load(original_path + filename) 44 | labels += np.bincount(label,minlength=6) 45 | print(labels) 46 | print(labels/np.sum(labels)) 47 | 48 | def train_deepsleepnet_singleEpoch(): 49 | use_channel_list = [[0],[2],[0,2],[0,1,2]] 50 | # use_channel_list = [[0],[0,1,2]] 51 | aug_p_list = [0.] 52 | 53 | aug_method_list=[[]] 54 | # aug_method_list=[['h_flip'],['v_flip'],['h_flip','v_flip']] 55 | entropy_list = [0.,0.1,.2,.3,.4,.5,.6,.7,.8,.9,1,1.5,2] 56 | for aug_p in aug_p_list: 57 | for aug_method in aug_method_list: 58 | for entropy_hyperparam in entropy_list: 59 | for use_channel in use_channel_list: 60 | training_deepsleepnet_dataloader(use_dataset='sleep_edf',total_train_percent = 1.,train_percent=0.8,val_percent=0.1,test_percent=0.1,random_seed=2,use_channel=use_channel, 61 | entropy_hyperparam=entropy_hyperparam, 62 | aug_p=aug_p,aug_method=aug_method, 63 | classification_mode='5class',gpu_num=[0,1,2,3]) 64 | 65 | 66 | def train_resnet_singleEpoch(): 67 | use_channel_list = [[0]] 68 | # use_channel_list = [[0],[0,1,2]] 69 | aug_p_list = [0.] 70 | use_model_list = ['resnet18'] 71 | aug_method_list=[[]] 72 | first_conv_list = [[49, 4, 24]] 73 | entropy_list = [0.] 
74 | block_kernel_size_list = [7,3,5] 75 | layer_filters_list = [[64,64,64,128],[64,64,128,128],[64,128,128,128],[64,64,128,256],[64,64,64,256],[64,128,128,256],[64,128,256,256],[64,128,256,512],[32,64,128,256],[16,32,64,128]] 76 | for layer_filters in layer_filters_list: 77 | for first_conv in first_conv_list: 78 | for block_kernel_size in block_kernel_size_list: 79 | for use_model in use_model_list: 80 | for aug_p in aug_p_list: 81 | for aug_method in aug_method_list: 82 | for entropy_hyperparam in entropy_list: 83 | for use_channel in use_channel_list: 84 | training_resnet_dataloader(use_dataset='sleep_edf',total_train_percent = 1.,train_percent=0.8,val_percent=0.1,test_percent=0.1,random_seed=2,use_channel=use_channel, 85 | entropy_hyperparam=entropy_hyperparam, 86 | aug_p=aug_p,aug_method=aug_method,use_model = use_model, 87 | first_conv=first_conv,maxpool=[7,3,3], layer_filters=layer_filters,block_kernel_size=block_kernel_size,block_stride_size=2, 88 | classification_mode='5class',gpu_num=[0,1,2,3]) 89 | 90 | def train_resnet_singleEpoch_representationLearning(): 91 | batch_size_list = [256,512,1024,2048,4096] 92 | use_channel = [0,2] 93 | optim_list = ['SGD','LARS'] 94 | for optim in optim_list: 95 | for batch_size in batch_size_list: 96 | learning_rate = (0.3*batch_size / 256) 97 | training_resnet_dataloader_representationlearning(use_dataset='sleep_edf',total_train_percent = 1.,train_percent=0.8,val_percent=0.1,test_percent=0.1,use_model = 'resnet18', 98 | random_seed=2,use_channel=use_channel,entropy_hyperparam=0.,classification_mode='5class',aug_p=0.,aug_method=[],learning_rate=learning_rate,batch_size=batch_size,optim=optim, 99 | first_conv=[49, 4, 24],maxpool=[7,3,3], layer_filters=[64, 64, 64, 128],block_kernel_size=3,block_stride_size=2, 100 | gpu_num=[0,1,2,3]) 101 | 102 | def train_resnet_singleEpoch_representationLearning_simCLR(): 103 | batch_size_list = [256,512,1024,2048,4096] 104 | use_channel = [0,2] 105 | optim_list = ['SGD','LARS'] 106 | 
for optim in optim_list: 107 | for batch_size in batch_size_list: 108 | learning_rate = (0.3*batch_size / 256) 109 | training_resnet_dataloader_representationlearning_simCLR(use_dataset='sleep_edf',total_train_percent = 1.,train_percent=0.8,val_percent=0.1,test_percent=0.1,use_model = 'resnet18', 110 | random_seed=2,use_channel=use_channel,entropy_hyperparam=0.,classification_mode='5class',aug_p=1.,aug_method=['permute','crop'],learning_rate=learning_rate,batch_size=batch_size,optim=optim, 111 | first_conv=[49, 4, 24],maxpool=[7,3,3], layer_filters=[64, 64, 64, 128],block_kernel_size=3,block_stride_size=2, 112 | gpu_num=[0,1,2,3]) 113 | if __name__ == '__main__': 114 | train_resnet_singleEpoch_representationLearning_simCLR() 115 | train_resnet_singleEpoch_representationLearning() 116 | # train_resnet_singleEpoch() 117 | # check_distribution_correct() 118 | 119 | -------------------------------------------------------------------------------- /models/cnn/DeepSleepNet_cnn.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | 3 | 4 | class conv_bn_activate_layer(nn.Module): 5 | def __init__(self,in_channel,out_channel,filter_size,stride,padding,activation='relu',bias=False): 6 | super().__init__() 7 | self.conv = nn.Conv1d(in_channel,out_channel,kernel_size=filter_size,stride=stride,padding=padding,bias=bias) 8 | self.bn = nn.BatchNorm1d(out_channel) 9 | if activation =='relu': 10 | self.activation = nn.ReLU(inplace=True) 11 | elif activation == 'prelu': 12 | self.activation == nn.PReLU() 13 | elif activation == 'leakyrelu': 14 | self.activation == nn.LeakyReLU(negative_slope=0.01,inplace=True) 15 | def forward(self,x): 16 | x = self.conv(x) 17 | x = self.bn(x) 18 | x = self.activation(x) 19 | return x 20 | 21 | class feature_extractor(nn.Module): 22 | def __init__(self,version='big',in_channel=1,layer=[64,128,128,128],activation='relu',sample_rate = 100,dropout_p= 0.5): 23 | super().__init__() 24 | if version == 'big': 25 | first_conv_filter_size = int(sample_rate * 4) 26 | first_conv_stride = int(sample_rate // 2) 27 | conv_filter_size = 6 28 | mp1_size = 4 29 | mp2_size = 2 30 | elif version == 'small': 31 | first_conv_filter_size = int(sample_rate // 2) 32 | first_conv_stride = int(sample_rate // 16) 33 | conv_filter_size = 8 34 | mp1_size = 8 35 | mp2_size = 4 36 | 37 | first_conv_padding = int((first_conv_filter_size-1)//2) 38 | conv_padding_size = int((conv_filter_size-1)//2) 39 | ''' 40 | In this paper, authors doesn't mention about padding size in convolution layer 41 | And, they use even number in convoltuion layer filter, so, symmety padding size can't use here. 42 | if you want to use padding, you can changed 'bias' parameter. 
43 | ''' 44 | 45 | self.conv1 = conv_bn_activate_layer(in_channel=in_channel,out_channel=layer[0],filter_size=first_conv_filter_size,stride=first_conv_stride,padding=first_conv_padding,activation=activation,bias=False) 46 | self.maxpool1 = nn.MaxPool1d(kernel_size=mp1_size, stride=mp1_size, padding=0) 47 | if dropout_p > 0.: 48 | self.dropout = nn.Dropout(p=dropout_p) 49 | else: 50 | self.dropout = nn.Identitiy() 51 | 52 | self.conv2_1 = conv_bn_activate_layer(in_channel=layer[0],out_channel=layer[1],filter_size=conv_filter_size,stride=1,padding=conv_padding_size,activation=activation,bias=False) 53 | self.conv2_2 = conv_bn_activate_layer(in_channel=layer[1],out_channel=layer[2],filter_size=conv_filter_size,stride=1,padding=conv_padding_size,activation=activation,bias=False) 54 | self.conv2_3 = conv_bn_activate_layer(in_channel=layer[2],out_channel=layer[3],filter_size=conv_filter_size,stride=1,padding=conv_padding_size,activation=activation,bias=False) 55 | 56 | self.mxpool2 = nn.MaxPool1d(kernel_size=mp2_size,stride=mp2_size,padding=0) 57 | 58 | def forward(self,x): 59 | x = self.conv1(x) 60 | x = self.maxpool1(x) 61 | x = self.dropout(x) 62 | 63 | x = self.conv2_1(x) 64 | x = self.conv2_2(x) 65 | x = self.conv2_3(x) 66 | 67 | return x 68 | 69 | 70 | 71 | 72 | 73 | 74 | class DeepSleepNet_featureExtractor(nn.Module): 75 | def __init__(self,in_channel=1,layer=[64,128,128,128],activation='relu',sample_rate = 100,dropout_p=0.5): 76 | super().__init__() 77 | 78 | self.big_extractor = feature_extractor(version='big',in_channel=in_channel,layer=layer,activation=activation,sample_rate=sample_rate,dropout_p=dropout_p) 79 | self.small_extractor = feature_extractor(version='small',in_channel=in_channel,layer=layer,activation=activation,sample_rate=sample_rate,dropout_p=dropout_p) 80 | 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv1d): 83 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity=activation) 84 | elif isinstance(m, (nn.BatchNorm1d, 
nn.GroupNorm)): 85 | nn.init.constant_(m.weight, 1) 86 | nn.init.constant_(m.bias, 0) 87 | 88 | def forward(self, x): 89 | big_feature = self.big_extractor(x) 90 | small_feature = self.small_extractor(x) 91 | 92 | big_feature = torch.flatten(big_feature,1) 93 | small_feature = torch.flatten(small_feature,1) 94 | # print('big : ',big_feature.shape) 95 | # print('small : ',small_feature.shape) 96 | output = torch.cat((big_feature,small_feature),dim=1) 97 | return output 98 | 99 | class DeepSleepNet_CNN(nn.Module): # input channel = 8channel / output = 5 100 | def __init__(self,in_channel=1,out_channel=5,layer=[64,128,128,128],activation='relu',sample_rate = 100,dropout_p=0.5): 101 | super(DeepSleepNet_CNN, self).__init__() 102 | self.featureExtractor = DeepSleepNet_featureExtractor(in_channel=in_channel,layer=layer,activation=activation,sample_rate=sample_rate,dropout_p=dropout_p) 103 | if dropout_p > 0.: 104 | self.dropout = nn.Dropout(p=dropout_p) 105 | else: 106 | self.dropout = nn.Identity() 107 | self.fc = nn.Linear(9088, out_channel) # big and small conv concat 108 | 109 | 110 | def forward(self, x): 111 | x = self.featureExtractor(x) 112 | x = self.dropout(x) 113 | x = self.fc(x) 114 | 115 | return x -------------------------------------------------------------------------------- /models/cnn/ResNet.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | from .modules.ResNet_module import * 3 | 4 | class ResNet_featureExtractor(nn.Module): 5 | def __init__(self, block=BasicBlock, layers=[2,2,2,2], first_conv=[49, 4, 24],maxpool=[7,3,3], layer_filters=[64, 128, 128, 256], in_channel=1, 6 | block_kernel_size=3,block_stride_size=1,use_batchnorm=False, zero_init_residual=False, 7 | groups=1, width_per_group=64, replace_stride_with_dilation=None,dilation=1, 8 | norm_layer=None,dropout_p=0.): 9 | super().__init__() 10 | if norm_layer is None: 11 | norm_layer = nn.BatchNorm1d 12 | self._norm_layer = norm_layer 13 | self.use_batchnorm = use_batchnorm 14 | self.inplanes = layer_filters[0] 15 | self.dilation = dilation 16 | if replace_stride_with_dilation is None: 17 | # each element in the tuple indicates if we should replace 18 | # the 2x2 stride with a dilated convolution instead 19 | replace_stride_with_dilation = [False, False, False] 20 | if len(replace_stride_with_dilation) != 3: 21 | raise ValueError("replace_stride_with_dilation should be None " 22 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 23 | self.groups = groups 24 | self.base_width = width_per_group 25 | self.conv1_1d = nn.Conv1d(in_channel, self.inplanes, kernel_size=first_conv[0], stride=first_conv[1], 26 | padding=first_conv[2], 27 | bias=False) 28 | # self.conv1_1d = nn.Conv1d(in_channel, self.inplanes, kernel_size=200, stride=40, padding=100, 29 | # bias=False) 30 | 31 | self.dropout = nn.Identity() 32 | self.dropout_p = dropout_p 33 | if self.dropout_p != 0.: 34 | self.dropout = nn.Dropout(p=self.dropout_p) 35 | 36 | self.bn1 = nn.Identity() 37 | if self.use_batchnorm: 38 | self.bn1 = norm_layer(self.inplanes) 39 | 40 | 41 | self.relu = nn.ReLU(inplace=True) 42 | self.maxpool = nn.MaxPool1d(kernel_size=maxpool[0], stride=maxpool[1], padding=maxpool[2]) 43 | 44 | self.block_kernel_size = block_kernel_size 45 | self.block_stride_size = block_stride_size 46 | 47 | self.padding = self.block_kernel_size // 2 48 | 49 | 
self.layer1 = self._make_layer(block, layer_filters[0], layers[0]) 50 | self.layer2 = self._make_layer(block, layer_filters[1], layers[1], stride=self.block_stride_size, 51 | dilate=replace_stride_with_dilation[0]) 52 | self.layer3 = self._make_layer(block, layer_filters[2], layers[2], stride=self.block_stride_size, 53 | dilate=replace_stride_with_dilation[1]) 54 | self.layer4 = self._make_layer(block, layer_filters[3], layers[3], stride=self.block_stride_size, 55 | dilate=replace_stride_with_dilation[2]) 56 | self.avgpool = nn.AdaptiveAvgPool1d(1) # 57 | 58 | 59 | for m in self.modules(): 60 | if isinstance(m, nn.Conv1d): 61 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 62 | elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)): 63 | nn.init.constant_(m.weight, 1) 64 | nn.init.constant_(m.bias, 0) 65 | 66 | # Zero-initialize the last BN in each residual branch, 67 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
68 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 69 | if zero_init_residual: 70 | for m in self.modules(): 71 | ''' 72 | if isinstance(m, Bottleneck): 73 | nn.init.constant_(m.bn3.weight, 0) 74 | ''' 75 | 76 | if isinstance(m, BasicBlock): 77 | nn.init.constant_(m.bn2.weight, 0) 78 | 79 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 80 | norm_layer = self._norm_layer 81 | downsample = None 82 | print(f'self.dilation = {self.dilation}') 83 | if stride != 1 or self.inplanes != planes * block.expansion: 84 | downsample = nn.Sequential( 85 | conv1x1_1d(self.inplanes, planes * block.expansion, stride), 86 | norm_layer(planes * block.expansion), 87 | ) 88 | 89 | layers = [] 90 | # print('drop out : ',self.dropout_p) 91 | layers.append( 92 | block(self.inplanes, planes, block_kernel_size=self.block_kernel_size, stride=stride, downsample=downsample, 93 | groups=self.groups, 94 | base_width=self.base_width, padding=self.padding,dilation=self.dilation, 95 | norm_layer=norm_layer,dropout_p=self.dropout_p,use_batchnorm=self.use_batchnorm)) 96 | self.inplanes = planes * block.expansion 97 | for _ in range(1, blocks): 98 | layers.append(block(self.inplanes, planes, block_kernel_size=self.block_kernel_size, groups=self.groups, 99 | base_width=self.base_width, padding=self.padding,dilation=self.dilation, 100 | norm_layer=norm_layer,dropout_p=self.dropout_p,use_batchnorm=self.use_batchnorm)) 101 | 102 | return nn.Sequential(*layers) 103 | 104 | def _forward_impl(self, x): 105 | x = self.conv1_1d(x) 106 | x = self.bn1(x) 107 | 108 | x = self.relu(x) 109 | x = self.dropout(x) 110 | x = self.maxpool(x) 111 | x = self.layer1(x) 112 | x = self.layer2(x) 113 | x = self.layer3(x) 114 | x = self.layer4(x) 115 | 116 | x = self.avgpool(x) # -3 117 | x = torch.flatten(x, 1) # -2 118 | return x 119 | 120 | def forward(self, x): 121 | return self._forward_impl(x) 122 | 123 | class ResNet_classification(nn.Module): 124 | def 
__init__(self, block=BasicBlock, layers=[2,2,2,2], first_conv=[49, 4, 24],maxpool=[7,3,3], layer_filters=[64, 128, 128, 256], in_channel=1, 125 | block_kernel_size=3,block_stride_size=1, num_classes=5, use_batchnorm=True, zero_init_residual=False, 126 | groups=1, width_per_group=64, replace_stride_with_dilation=None,dilation=1, 127 | norm_layer=None,dropout_p=0.): 128 | super().__init__() 129 | self.featureExtractor = ResNet_featureExtractor(block=block,layers=layers,layer_filters=layer_filters, 130 | first_conv=first_conv,maxpool=maxpool,block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, 131 | in_channel=in_channel,use_batchnorm=use_batchnorm,zero_init_residual=zero_init_residual, 132 | groups = groups, width_per_group=width_per_group, replace_stride_with_dilation=replace_stride_with_dilation,dilation=dilation, 133 | norm_layer=norm_layer,dropout_p=dropout_p) 134 | 135 | if block == BasicBlock: 136 | self.classifier = nn.Linear(layer_filters[-1],num_classes) 137 | else: 138 | self.classifier = nn.Linear(layer_filters[3]*4, num_classes) 139 | 140 | def forward(self, x): 141 | x = self.featureExtractor(x) 142 | x = self.classifier(x) 143 | 144 | return x 145 | 146 | 147 | class ResNet_contrastiveLearning(nn.Module): 148 | def __init__(self, block=BasicBlock, layers=[2,2,2,2], first_conv=[49, 4, 24],maxpool=[7,3,3], layer_filters=[64, 128, 256, 512], in_channel=1, 149 | block_kernel_size=3,block_stride_size=1,embedding=256,feature_dim=128,use_batchnorm=False, zero_init_residual=False, 150 | groups=1, width_per_group=64, replace_stride_with_dilation=None,dilation=1, 151 | norm_layer=None,dropout_p=0.): 152 | super().__init__() 153 | self.featureExtracxt = ResNet_featureExtractor(block=block,layers=layers,first_conv=first_conv,maxpool=maxpool,layer_filters=layer_filters, 154 | in_channel=in_channel,block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, 155 | 
use_batchnorm=use_batchnorm,zero_init_residual=zero_init_residual,groups=groups, 156 | width_per_group=width_per_group,replace_stride_with_dilation=replace_stride_with_dilation,dilation=dilation, 157 | norm_layer=norm_layer,dropout_p=dropout_p) 158 | 159 | if block == BasicBlock: 160 | self.g = nn.Sequential(nn.Linear(layer_filters[-1],embedding,bias=False),nn.BatchNorm1d(embedding),nn.ReLU(inplace=True),nn.Linear(embedding,feature_dim,bias=True)) 161 | else: 162 | self.g = nn.Sequential(nn.Linear(layer_filters[-1]*4,embedding,bias=False),nn.BatchNorm1d(embedding),nn.ReLU(inplace=True),nn.Linear(embedding,feature_dim,bias=True)) 163 | 164 | 165 | 166 | 167 | def _forward_impl(self, x): 168 | x = self.featureExtracxt(x) 169 | # === original ResNet Feature Extract Layers === 170 | x = self.g(x) 171 | return F.normalize(x,dim=-1) 172 | 173 | def forward(self, x): 174 | return self._forward_impl(x) -------------------------------------------------------------------------------- /models/cnn/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F -------------------------------------------------------------------------------- /models/cnn/__pycache__/DeepSleepNet_cnn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/models/cnn/__pycache__/DeepSleepNet_cnn.cpython-38.pyc -------------------------------------------------------------------------------- /models/cnn/__pycache__/ResNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/models/cnn/__pycache__/ResNet.cpython-38.pyc -------------------------------------------------------------------------------- 
/models/cnn/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/models/cnn/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/cnn/modules/ResNet_module.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | def conv_1d(in_planes, out_planes, kernel_size=3, stride=1, groups=1, padding=1, dilation=1): 4 | return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 5 | padding=padding, groups=groups, bias=False, dilation=dilation) 6 | 7 | def conv1x1_1d(in_planes, out_planes, stride=1): # we use this function when we have to downsampling 8 | """1x1 convolution""" 9 | return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, 10 | bias=False) 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | __constants__ = ['downsample'] 15 | 16 | def __init__(self, inplanes, planes, stride=1, block_kernel_size=3, padding=1,downsample=None, groups=1, 17 | base_width=64, dilation=1, norm_layer=None,dropout_p=0.,use_batchnorm=True): 18 | super(BasicBlock, self).__init__() 19 | if norm_layer is None: 20 | norm_layer = nn.BatchNorm1d 21 | # if groups != 1 or base_width != 64: 22 | # raise ValueError('BasicBlock only supports groups=1 and base_width=64') 23 | # if dilation > 1: 24 | # raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 25 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 26 | self.dilation = dilation 27 | self.use_batchnorm = use_batchnorm 28 | print(f'block_kernel_size == {block_kernel_size}') 29 | self.conv1 = conv_1d(in_planes=inplanes, out_planes=planes, kernel_size=block_kernel_size, stride=stride, 30 | groups=groups,padding=padding,dilation=dilation) 31 | self.relu = 
nn.ReLU(inplace=True) 32 | self.conv2 = conv_1d(in_planes=planes, out_planes=planes, kernel_size=block_kernel_size, groups=groups, 33 | padding=padding,dilation=dilation) 34 | 35 | 36 | self.bn1 = nn.Identity() 37 | self.bn2 = nn.Identity() 38 | if self.use_batchnorm: 39 | self.bn1 = norm_layer(planes) 40 | self.bn2 = norm_layer(planes) 41 | 42 | self.dropout_p = dropout_p 43 | 44 | self.dropout1 = nn.Identity() 45 | self.dropout2 = nn.Identity() 46 | 47 | if self.dropout_p > 0: 48 | self.dropout1 = nn.Dropout(p=self.dropout_p) 49 | self.dropout2 = nn.Dropout(p=self.dropout_p) 50 | 51 | self.block_kernel_size = block_kernel_size 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | # print('kernel size : ',self.block_kernel_size) 57 | # print('dilation : ', self.dilation) 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | # print('out1.shape : ',out.shape ) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.dropout1(out) 66 | 67 | out = self.conv2(out) 68 | # print('out2.shape : ',out.shape) 69 | out = self.bn2(out) 70 | 71 | if self.downsample is not None: 72 | identity = self.downsample(x) 73 | # print('out : ',out.shape) 74 | # print('id : ',identity.shape) 75 | out += identity 76 | out = self.relu(out) 77 | 78 | out = self.dropout2(out) 79 | 80 | return out 81 | 82 | class Bottleneck(nn.Module): 83 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 84 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 85 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 86 | # This variant is also known as ResNet V1.5 and improves accuracy according to 87 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
88 | 89 | expansion: int = 4 90 | 91 | def __init__( 92 | self, 93 | inplanes: int, 94 | planes: int, 95 | stride: int = 1, 96 | block_kernel_size: int =3, 97 | padding:int =1, 98 | downsample = None, 99 | groups: int = 1, 100 | base_width: int = 64, 101 | dilation: int = 1, 102 | norm_layer = None,dropout_p:float=0.,use_batchnorm:bool=True): 103 | super(Bottleneck, self).__init__() 104 | if norm_layer is None: 105 | norm_layer = nn.BatchNorm1d 106 | width = int(planes * (base_width / 64.)) * groups 107 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 108 | print(f'block_kernel_size == {block_kernel_size}') 109 | print(f'padding == {padding}') 110 | print(f'stride = {stride}') 111 | print(f'dilation = {dilation}') 112 | self.conv1 = conv1x1_1d(in_planes=inplanes, out_planes=width, stride=1) 113 | 114 | self.conv2 = conv_1d(in_planes=width, out_planes=width,kernel_size=block_kernel_size, stride=stride,padding=padding, groups=groups, dilation=dilation) 115 | 116 | self.conv3 = conv1x1_1d(in_planes=width, out_planes=planes * self.expansion, stride=1) 117 | 118 | self.relu = nn.ReLU(inplace=True) 119 | self.downsample = downsample 120 | 121 | self.use_batchnorm = use_batchnorm 122 | if self.use_batchnorm: 123 | self.bn1 = norm_layer(width) 124 | self.bn2 = norm_layer(width) 125 | self.bn3 = norm_layer(planes * self.expansion) 126 | 127 | self.dropout_p = dropout_p 128 | 129 | self.dropout1 = nn.Identity() 130 | self.dropout2 = nn.Identity() 131 | self.dropout3 = nn.Identity() 132 | if self.dropout_p > 0: 133 | self.dropout1 = nn.Dropout(p=self.dropout_p) 134 | self.dropout2 = nn.Dropout(p=self.dropout_p) 135 | self.dropout3 = nn.Dropout(p=self.dropout_p) 136 | 137 | def forward(self, x): 138 | identity = x 139 | 140 | out = self.conv1(x) 141 | out = self.bn1(out) 142 | out = self.relu(out) 143 | # print(out.shape) 144 | out = self.dropout1(out) 145 | 146 | out = self.conv2(out) 147 | out = self.bn2(out) 148 | out = self.relu(out) 149 
| # print(out.shape) 150 | 151 | out = self.dropout2(out) 152 | 153 | 154 | out = self.conv3(out) 155 | out = self.bn3(out) 156 | 157 | if self.downsample is not None: 158 | identity = self.downsample(x) 159 | # print(f'out shape = {out.shape} // identity shape = {identity.shape}') 160 | out += identity 161 | out = self.relu(out) 162 | 163 | out = self.dropout3(out) 164 | 165 | return out -------------------------------------------------------------------------------- /models/cnn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F -------------------------------------------------------------------------------- /models/cnn/modules/__pycache__/ResNet_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/models/cnn/modules/__pycache__/ResNet_module.cpython-38.pyc -------------------------------------------------------------------------------- /models/cnn/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/models/cnn/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /train/representation_learning/single_epoch/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.function import * 2 | from utils.loss_fn import * 3 | from utils.scheduler import * 4 | from utils.dataloader.sleep_edf import * 5 | 6 | from models.cnn.DeepSleepNet_cnn import * 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from torch.autograd import Variable 13 | from 
torch.autograd import Function 14 | from torch.utils.data import DataLoader 15 | 16 | from torchvision import transforms, utils 17 | from torchvision.datasets import ImageFolder 18 | 19 | from torch import einsum 20 | from einops import rearrange, repeat 21 | 22 | from torchlars import LARS 23 | import torchlars 24 | 25 | # pip install torchsummary 26 | from torchsummary import summary 27 | # pip install tqdm 28 | from tqdm import tnrange, tqdm 29 | 30 | # multiprocessing 31 | import multiprocessing 32 | from multiprocessing import Process, Manager, Pool, Lock 33 | 34 | import os 35 | import random 36 | import math 37 | import time 38 | import sys 39 | import warnings 40 | import datetime 41 | import shutil 42 | 43 | # import argparse 44 | 45 | import itertools 46 | import numpy as np 47 | import pandas as pd 48 | 49 | -------------------------------------------------------------------------------- /train/representation_learning/single_epoch/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/train/representation_learning/single_epoch/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /train/representation_learning/single_epoch/__pycache__/train_resnet_representationlearning.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/train/representation_learning/single_epoch/__pycache__/train_resnet_representationlearning.cpython-38.pyc -------------------------------------------------------------------------------- /train/representation_learning/single_epoch/__pycache__/train_resnet_simCLR.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/train/representation_learning/single_epoch/__pycache__/train_resnet_simCLR.cpython-38.pyc -------------------------------------------------------------------------------- /train/representation_learning/single_epoch/train_resnet_representationlearning.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | from models.cnn.ResNet import * 3 | 4 | 5 | 6 | def train_resnet_dataloader_representationlearning(save_filename,logging_filename,train_dataset_list,val_dataset_list,test_dataset_list,batch_size = 512,entropy_hyperparam=0., 7 | epochs=100,optim='Adam',loss_function='CE',use_model='resnet18', 8 | learning_rate=0.001,scheduler=None,warmup_iter=20,cosine_decay_iter=40,stop_iter=10, 9 | use_channel=[0,1],class_num=6,classification_mode='6class',aug_p=0.,aug_method=['h_flip','v_flip'], 10 | first_conv=[49, 4, 24],maxpool=[7,3,3], layer_filters=[64, 128, 128, 256],block_kernel_size=5,block_stride_size=2, 11 | gpu_num=0,sample_rate= 100,epoch_size = 30): 12 | # cpu processor num 13 | cpu_num = multiprocessing.cpu_count() 14 | 15 | 16 | #dataload Training Dataset 17 | train_dataset = Sleep_Dataset_withPath_sleepEDF(dataset_list=train_dataset_list,class_num=class_num, 18 | use_channel=use_channel,use_cuda = True,classification_mode=classification_mode,aug_p = aug_p,aug_method = aug_method,) 19 | train_dataloader = DataLoader(dataset=train_dataset,batch_size=batch_size, shuffle=True, num_workers=(cpu_num//4)) 20 | 21 | # calculate weight from training dataset (for "Class Balanced Weight") 22 | weights,count = make_weights_for_balanced_classes(train_dataset.signals_files_path,nclasses=class_num) 23 | weights = torch.DoubleTensor(weights) 24 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights,len(weights)) 25 | 26 | 27 | #dataload Validation Dataset 28 | val_dataset = 
Sleep_Dataset_withPath_sleepEDF(dataset_list=val_dataset_list,class_num=class_num, 29 | use_channel=use_channel,use_cuda = True,classification_mode=classification_mode) 30 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, num_workers=(cpu_num//4)) 31 | 32 | 33 | test_dataset = Sleep_Dataset_withPath_sleepEDF(dataset_list=test_dataset_list,class_num=class_num, 34 | use_channel=use_channel,use_cuda = True,classification_mode=classification_mode) 35 | test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=(cpu_num//4)) 36 | 37 | 38 | print(train_dataset.length,val_dataset.length,test_dataset.length) 39 | 40 | # Adam optimizer paramQ 41 | b1 = 0.9 42 | b2 = 0.999 43 | 44 | # for Regularization 45 | beta = 0.001 46 | norm_square = 2 47 | 48 | check_file = open(logging_filename, 'w') # logging file 49 | 50 | best_accuracy = 0. 51 | best_epoch = 0 52 | if use_model == 'resnet18': 53 | model = ResNet_contrastiveLearning(block=BasicBlock, layers=[2,2,2,2], first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters, in_channel=len(use_channel), 54 | block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, embedding=256,feature_dim=128, use_batchnorm=True, zero_init_residual=False) 55 | elif use_model == 'resnet34': 56 | model = ResNet_contrastiveLearning(block=BasicBlock, layers=[3,4,6,3], first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters, in_channel=len(use_channel), 57 | block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, embedding=256,feature_dim=128, use_batchnorm=True, zero_init_residual=False) 58 | elif use_model == 'resnet50': 59 | model = ResNet_contrastiveLearning(block=Bottleneck, layers=[3,4,6,3], first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters, in_channel=len(use_channel), 60 | block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, embedding=256,feature_dim=128, use_batchnorm=True, zero_init_residual=False) 61 | 
62 | cuda = torch.cuda.is_available() 63 | device = torch.device(f"cuda:{gpu_num[0]}" if torch.cuda.is_available() else "cpu") 64 | torch.cuda.set_device(device) 65 | 66 | if cuda: 67 | print('can use CUDA!!!') 68 | model = model.cuda() 69 | summary(model,(len(use_channel),sample_rate*epoch_size)) 70 | # exit(1) 71 | print('torch.cuda.device_count() : ', torch.cuda.device_count()) 72 | 73 | if torch.cuda.device_count() > 1: 74 | print('Multi GPU Activation !!!', torch.cuda.device_count()) 75 | model = nn.DataParallel(model) 76 | 77 | # summary(model, (3, 6000)) 78 | model.apply(weights_init) # weight init 79 | print('loss function : %s' % loss_function) 80 | 81 | loss_fn = SupConLoss(temperature=0.07,contrast_mode='one').to(device) 82 | 83 | # optimizer ADAM (SGD의 경우에는 정상적으로 학습이 진행되지 않았음) 84 | if optim == 'Adam': 85 | print('Optimizer : Adam') 86 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(b1, b2)) 87 | elif optim == 'RMS': 88 | print('Optimizer : RMSprop') 89 | optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate) 90 | elif optim == 'SGD': 91 | print('Optimizer : SGD') 92 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5,nesterov=False) 93 | elif optim == 'AdamW': 94 | print('Optimizer AdamW') 95 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(b1, b2)) 96 | elif optim == 'LARS': 97 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5,nesterov=False) 98 | optimizer = torchlars.LARS(optimizer=optimizer,eps=1e-8,trust_coef=0.001) 99 | 100 | gamma = 0.8 101 | 102 | lr = learning_rate 103 | epochs = epochs 104 | if scheduler == 'WarmUp_restart_gamma': 105 | print(f'target lr : {learning_rate} / warmup_iter : {warmup_iter} / cosine_decay_iter : {cosine_decay_iter} / gamma : {gamma}') 106 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_decay_iter+1) 107 | scheduler = 
LearningRateWarmUP_restart_changeMax(optimizer=optimizer, 108 | warmup_iteration=warmup_iter, 109 | cosine_decay_iter=cosine_decay_iter, 110 | target_lr=lr, 111 | after_scheduler=scheduler_cosine,gamma=gamma) 112 | elif scheduler == 'WarmUp_restart': 113 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_decay_iter+1) 114 | scheduler = LearningRateWarmUP_restart(optimizer=optimizer, 115 | warmup_iteration=warmup_iter, 116 | cosine_decay_iter=cosine_decay_iter, 117 | target_lr=lr, 118 | after_scheduler=scheduler_cosine) 119 | elif scheduler == 'WarmUp': 120 | print(f'target lr : {learning_rate} / warmup_iter : {warmup_iter}') 121 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs-warmup_iter+1) 122 | scheduler = LearningRateWarmUP(optimizer=optimizer, 123 | warmup_iteration=warmup_iter, 124 | target_lr=lr, 125 | after_scheduler=scheduler_cosine) 126 | elif scheduler == 'StepLR': # 특정 epoch 도착하면 비율만큼 감소 127 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [11, 21], gamma=0.1) 128 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) 129 | elif scheduler == 'Reduce': # factor 비율만큼 줄여주기 130 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=.5, patience=10, 131 | min_lr=1e-6) 132 | elif scheduler == 'Cosine': 133 | print('Cosine Scheduler') 134 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer, T_max=epochs) 135 | 136 | 137 | # scheduler 138 | # scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=step_size, gamma=.5) 139 | # loss의 값이 최소가 되도록 하며, 50번 동안 loss의 값이 감소가 되지 않을 경우 factor값 만큼 140 | # learning_rate의 값을 줄이고, 최저 1e-6까지 줄어들 수 있게 설정 141 | 142 | best_loss = 0. 
143 | stop_count = 0 144 | check_loss = False 145 | for epoch in range(epochs): 146 | if scheduler != 'None': 147 | scheduler.step(epoch) 148 | train_total_loss = 0.0 149 | train_total_count = 0 150 | train_total_data = 0 151 | 152 | val_total_loss = 0.0 153 | val_total_count = 0 154 | val_total_data = 0 155 | 156 | start_time = time.time() 157 | model.train() 158 | 159 | output_str = 'current epoch : %d/%d / current_lr : %f \n' % (epoch+1,epochs,optimizer.state_dict()['param_groups'][0]['lr']) 160 | sys.stdout.write(output_str) 161 | check_file.write(output_str) 162 | with tqdm(train_dataloader,desc='Train',unit='batch') as tepoch: 163 | for index,(batch_signal, batch_label) in enumerate(tepoch): 164 | batch_signal = batch_signal.to(device) 165 | batch_label = batch_label.long().to(device) 166 | optimizer.zero_grad() 167 | 168 | pred = model(batch_signal) 169 | pred = pred.unsqueeze(1) 170 | loss = loss_fn(pred, batch_label) # + beta * norm 171 | 172 | train_total_loss += loss.item() 173 | 174 | loss.backward() 175 | optimizer.step() 176 | 177 | 178 | tepoch.set_postfix(loss=train_total_loss/(index+1)) 179 | 180 | train_total_loss /= index 181 | 182 | output_str = 'train dataset : %d/%d epochs spend time : %.4f sec / total_loss : %.4f \n' \ 183 | % (epoch + 1, epochs, time.time() - start_time, train_total_loss) 184 | # sys.stdout.write(output_str) 185 | check_file.write(output_str) 186 | 187 | # check validation dataset 188 | start_time = time.time() 189 | model.eval() 190 | 191 | with tqdm(val_dataloader,desc='Validation',unit='batch') as tepoch: 192 | for index,(batch_signal, batch_label) in enumerate(tepoch): 193 | batch_signal = batch_signal.to(device) 194 | batch_label = batch_label.long().to(device) 195 | 196 | with torch.no_grad(): 197 | pred = model(batch_signal) 198 | pred = pred.unsqueeze(1) # [batch , num of views(augmentation) , embedding_size] 199 | loss = loss_fn(pred, batch_label) 200 | # print(f'val loss : {loss.item()}') 201 | val_total_loss += 
loss.item() 202 | tepoch.set_postfix(loss=val_total_loss/(index+1)) 203 | 204 | val_total_loss /= (index + 1) 205 | 206 | output_str = 'val dataset : %d/%d epochs spend time : %.4f sec / total_loss : %.4f\n' \ 207 | % (epoch + 1, epochs, time.time() - start_time, val_total_loss) 208 | # sys.stdout.write(output_str) 209 | check_file.write(output_str) 210 | 211 | # scheduler.step(float(val_total_loss)) 212 | # scheduler.step(epoch) 213 | if epoch == 0: 214 | best_loss = val_total_loss 215 | best_epoch = epoch 216 | save_file = save_filename 217 | if torch.cuda.device_count() > 1: 218 | torch.save(model.module.state_dict(), save_file) 219 | else: 220 | torch.save(model.state_dict(), save_file) 221 | stop_count = 0 222 | # sys.stdout.write(output_str) 223 | check_file.write(output_str) 224 | else: 225 | if best_loss > val_total_loss: 226 | best_loss = val_total_loss 227 | best_epoch = epoch 228 | save_file = save_filename 229 | if torch.cuda.device_count() > 1: 230 | torch.save(model.module.state_dict(), save_file) 231 | else: 232 | torch.save(model.state_dict(), save_file) 233 | stop_count = 0 234 | else: 235 | stop_count += 1 236 | 237 | output_str = 'best epoch : %d/%d / best loss : %f%%\n' \ 238 | % (best_epoch + 1, epochs, best_loss) 239 | sys.stdout.write(output_str) 240 | print('=' * 30) 241 | 242 | output_str = 'best epoch : %d/%d / best loss : %f%%\n' \ 243 | % (best_epoch + 1, epochs, best_loss) 244 | sys.stdout.write(output_str) 245 | check_file.write(output_str) 246 | print('=' * 30) 247 | 248 | check_file.close() 249 | 250 | 251 | 252 | 253 | 254 | def training_resnet_dataloader_representationlearning(use_dataset='sleep_edf',total_train_percent = 1.,train_percent=0.8,val_percent=0.1,test_percent=0.1,use_model = 'resnet18', 255 | random_seed=2,use_channel=[0,1],entropy_hyperparam=0.,classification_mode='5class',aug_p=0.,aug_method=['h_flip','v_flip'], 256 | first_conv=[49, 4, 24],maxpool=[7,3,3], layer_filters=[64, 128, 128, 
256],block_kernel_size=5,block_stride_size=2,learning_rate=0.1,batch_size=512,optim='SGD', 257 | gpu_num=[0]): 258 | 259 | if use_dataset == 'sleep_edf': 260 | signals_path = '/home/eslab/dataset/sleep_edf_final/origin_npy/remove_wake_version1/each/' 261 | 262 | random.seed(random_seed) # seed 263 | np.random.seed(random_seed) 264 | torch.manual_seed(random_seed) 265 | 266 | # signals_path = '/home/eslab/dataset/seoulDataset/1channel_prefilter_butter_minmax_-1_1/signals_dataloader/' 267 | 268 | dataset_list = os.listdir(signals_path) 269 | dataset_list = [signals_path + filename + '/' for filename in dataset_list] 270 | dataset_list.sort() 271 | random.shuffle(dataset_list) 272 | 273 | 274 | 275 | training_fold_list = [] 276 | validation_fold_list = [] 277 | test_fold_list = [] 278 | 279 | 280 | 281 | val_length = int(len(dataset_list) * val_percent) 282 | test_length = int(len(dataset_list) * test_percent) 283 | train_length = int(len(dataset_list) - val_length - test_length) 284 | 285 | 286 | for i in range(0,val_length): 287 | validation_fold_list.append(dataset_list[i]) 288 | for i in range(val_length,val_length + test_length): 289 | test_fold_list.append(dataset_list[i]) 290 | for i in range(val_length + test_length,len(dataset_list)): 291 | training_fold_list.append(dataset_list[i]) 292 | 293 | 294 | 295 | # print(dataset_list[:10]) 296 | print('='*20) 297 | print(len(training_fold_list)) 298 | print(len(validation_fold_list)) 299 | print(len(test_fold_list)) 300 | print('='*20) 301 | 302 | train_label,train_label_percent = check_label_info_withPath(file_list = training_fold_list) 303 | val_label,val_label_percent = check_label_info_withPath(file_list = validation_fold_list) 304 | test_label,test_label_percent = check_label_info_withPath(file_list = test_fold_list) 305 | 306 | print(train_label) 307 | print(np.round(train_label_percent,3)) 308 | print(val_label) 309 | print(np.round(val_label_percent,3)) 310 | print(test_label) 311 | 
print(np.round(test_label_percent,3)) 312 | 313 | 314 | # exit(1) 315 | 316 | # number of classes 317 | if classification_mode == '6class': 318 | class_num = 6 319 | elif classification_mode =='5class': 320 | class_num = 5 321 | else: 322 | class_num=3 323 | 324 | # hyperparameters 325 | epochs = 100 326 | # batch_size = 2048 327 | warmup_iter=10 328 | cosine_decay_iter=10 329 | 330 | stop_iter = 10 331 | loss_function = 'CE' # CEs 332 | scheduler = 'WarmUp' # 'WarmUp_restart' 333 | 334 | print(f'class num = {class_num}') 335 | model_save_path = f'/data/hdd3/git/DeepSleepNet_pytorch/saved_model/representation_learning/{use_dataset}/{classification_mode}/'\ 336 | f'single_epoch_models_{round(train_percent,2)}_{round(val_percent,2)}_{round(test_percent,2)}/'\ 337 | f'optim_{optim}_random_seed_{random_seed}_scheduler_{scheduler}_withoutRegularization_aug_p_{aug_p}_aug_method_{aug_method}/'\ 338 | f'firstconv_{first_conv}_maxpool_{maxpool}_layerfilters_{layer_filters}_blockkernelsize_{block_kernel_size}_blockstridesize_{block_stride_size}/' 339 | logging_save_path = f'/data/hdd3/git/DeepSleepNet_pytorch/log/representation_learning/{use_dataset}/{classification_mode}/'\ 340 | f'single_epoch_models_{round(train_percent,2)}_{round(val_percent,2)}_{round(test_percent,2)}/'\ 341 | f'optim_{optim}_random_seed_{random_seed}_scheduler_{scheduler}_withoutRegularization_aug_p_{aug_p}_aug_method_{aug_method}/'\ 342 | f'firstconv_{first_conv}_maxpool_{maxpool}_layerfilters_{layer_filters}_blockkernelsize_{block_kernel_size}_blockstridesize_{block_stride_size}/' 343 | # model_save_path = '/home/eslab/kdy/git/Sleep_pytorch/saved_model/seoulDataset/single_epoch_models/' 344 | # logging_save_path = '/home/eslab/kdy/git/Sleep_pytorch/log/seoulDataset/single_epoch_models/' 345 | 346 | os.makedirs(model_save_path,exist_ok=True) 347 | os.makedirs(logging_save_path,exist_ok=True) 348 | 349 | save_filename = model_save_path + 
f'{use_model}_%.5f_{use_channel}_{batch_size}_entropy_{entropy_hyperparam}.pth'%(learning_rate) 350 | 351 | logging_filename = logging_save_path + f'{use_model}_%.5f_{use_channel}_{batch_size}_entropy_{entropy_hyperparam}.txt'%(learning_rate) 352 | print('logging filename : ',logging_filename) 353 | print('save filename : ',save_filename) 354 | 355 | # exit(1) 356 | train_resnet_dataloader_representationlearning(save_filename=save_filename,logging_filename=logging_filename,train_dataset_list=training_fold_list,val_dataset_list=validation_fold_list,test_dataset_list=test_fold_list, 357 | batch_size = batch_size,entropy_hyperparam=entropy_hyperparam, 358 | epochs=epochs,optim=optim,loss_function=loss_function, 359 | learning_rate=learning_rate,scheduler=scheduler,warmup_iter=warmup_iter,cosine_decay_iter=cosine_decay_iter,stop_iter=stop_iter, 360 | use_channel=use_channel,class_num=class_num,classification_mode=classification_mode,aug_p=aug_p,aug_method=aug_method, 361 | first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters,block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, 362 | gpu_num=gpu_num,sample_rate= 100,epoch_size = 30) -------------------------------------------------------------------------------- /train/representation_learning/single_epoch/train_resnet_simCLR.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | from models.cnn.ResNet import * 3 | 4 | 5 | 6 | def train_resnet_dataloader_representationlearning_simCLR(save_filename,logging_filename,train_dataset_list,val_dataset_list,test_dataset_list,batch_size = 512,entropy_hyperparam=0., 7 | epochs=100,optim='Adam',loss_function='CE',use_model='resnet18', 8 | learning_rate=0.001,scheduler=None,warmup_iter=20,cosine_decay_iter=40,stop_iter=10, 9 | use_channel=[0,1],class_num=6,classification_mode='6class',aug_p=0.,aug_method=['h_flip','v_flip'], 10 | first_conv=[49, 4, 24],maxpool=[7,3,3], layer_filters=[64, 128, 128, 256],block_kernel_size=5,block_stride_size=2, 11 | gpu_num=0,sample_rate= 100,epoch_size = 30): 12 | # cpu processor num 13 | cpu_num = multiprocessing.cpu_count() 14 | 15 | 16 | #dataload Training Dataset 17 | train_dataset = Sleep_Dataset_withPath_sleepEDF_simCLR(dataset_list=train_dataset_list,class_num=class_num, 18 | use_channel=use_channel,use_cuda = True,preprocessing = True,sample_rate=sample_rate, 19 | preprocessing_method = ['permute','crop'], 20 | permute_size=sample_rate, 21 | crop_size=2500, 22 | cutout_size=1000, 23 | classification_mode='5class') 24 | train_dataloader = DataLoader(dataset=train_dataset,batch_size=batch_size, shuffle=True, num_workers=(cpu_num//4)) 25 | 26 | # calculate weight from training dataset (for "Class Balanced Weight") 27 | weights,count = make_weights_for_balanced_classes(train_dataset.signals_files_path,nclasses=class_num) 28 | weights = torch.DoubleTensor(weights) 29 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights,len(weights)) 30 | 31 | 32 | #dataload Validation Dataset 33 | val_dataset = Sleep_Dataset_withPath_sleepEDF_simCLR(dataset_list=val_dataset_list,class_num=class_num, 34 | use_channel=use_channel,use_cuda = True,preprocessing = True,sample_rate=sample_rate, 35 | preprocessing_method = ['permute','crop'], 36 | permute_size=sample_rate, 37 | crop_size=2500, 38 | cutout_size=1000, 39 | classification_mode='5class') 40 | 
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, num_workers=(cpu_num//4)) 41 | 42 | 43 | 44 | 45 | print(train_dataset.length,val_dataset.length) 46 | 47 | # Adam optimizer paramQ 48 | b1 = 0.9 49 | b2 = 0.999 50 | 51 | # for Regularization 52 | beta = 0.001 53 | norm_square = 2 54 | 55 | check_file = open(logging_filename, 'w') # logging file 56 | 57 | best_accuracy = 0. 58 | best_epoch = 0 59 | if use_model == 'resnet18': 60 | model = ResNet_contrastiveLearning(block=BasicBlock, layers=[2,2,2,2], first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters, in_channel=len(use_channel), 61 | block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, embedding=256,feature_dim=128, use_batchnorm=True, zero_init_residual=False) 62 | elif use_model == 'resnet34': 63 | model = ResNet_contrastiveLearning(block=BasicBlock, layers=[3,4,6,3], first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters, in_channel=len(use_channel), 64 | block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, embedding=256,feature_dim=128, use_batchnorm=True, zero_init_residual=False) 65 | elif use_model == 'resnet50': 66 | model = ResNet_contrastiveLearning(block=Bottleneck, layers=[3,4,6,3], first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters, in_channel=len(use_channel), 67 | block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, embedding=256,feature_dim=128, use_batchnorm=True, zero_init_residual=False) 68 | 69 | cuda = torch.cuda.is_available() 70 | device = torch.device(f"cuda:{gpu_num[0]}" if torch.cuda.is_available() else "cpu") 71 | torch.cuda.set_device(device) 72 | 73 | if cuda: 74 | print('can use CUDA!!!') 75 | model = model.cuda() 76 | summary(model,(len(use_channel),sample_rate*epoch_size)) 77 | # exit(1) 78 | print('torch.cuda.device_count() : ', torch.cuda.device_count()) 79 | 80 | if torch.cuda.device_count() > 1: 81 | print('Multi GPU Activation !!!', 
torch.cuda.device_count()) 82 | model = nn.DataParallel(model) 83 | 84 | # summary(model, (3, 6000)) 85 | model.apply(weights_init) # weight init 86 | print('loss function : %s' % loss_function) 87 | 88 | loss_fn = SupConLoss(temperature=0.07,contrast_mode='all').to(device) 89 | 90 | # optimizer ADAM (SGD의 경우에는 정상적으로 학습이 진행되지 않았음) 91 | if optim == 'Adam': 92 | print('Optimizer : Adam') 93 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(b1, b2)) 94 | elif optim == 'RMS': 95 | print('Optimizer : RMSprop') 96 | optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate) 97 | elif optim == 'SGD': 98 | print('Optimizer : SGD') 99 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5,nesterov=False) 100 | elif optim == 'AdamW': 101 | print('Optimizer AdamW') 102 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(b1, b2)) 103 | elif optim == 'LARS': 104 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5,nesterov=False) 105 | optimizer = torchlars.LARS(optimizer=optimizer,eps=1e-8,trust_coef=0.001) 106 | 107 | gamma = 0.8 108 | 109 | lr = learning_rate 110 | epochs = epochs 111 | if scheduler == 'WarmUp_restart_gamma': 112 | print(f'target lr : {learning_rate} / warmup_iter : {warmup_iter} / cosine_decay_iter : {cosine_decay_iter} / gamma : {gamma}') 113 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_decay_iter+1) 114 | scheduler = LearningRateWarmUP_restart_changeMax(optimizer=optimizer, 115 | warmup_iteration=warmup_iter, 116 | cosine_decay_iter=cosine_decay_iter, 117 | target_lr=lr, 118 | after_scheduler=scheduler_cosine,gamma=gamma) 119 | elif scheduler == 'WarmUp_restart': 120 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_decay_iter+1) 121 | scheduler = LearningRateWarmUP_restart(optimizer=optimizer, 122 | warmup_iteration=warmup_iter, 123 | 
cosine_decay_iter=cosine_decay_iter, 124 | target_lr=lr, 125 | after_scheduler=scheduler_cosine) 126 | elif scheduler == 'WarmUp': 127 | print(f'target lr : {learning_rate} / warmup_iter : {warmup_iter}') 128 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs-warmup_iter+1) 129 | scheduler = LearningRateWarmUP(optimizer=optimizer, 130 | warmup_iteration=warmup_iter, 131 | target_lr=lr, 132 | after_scheduler=scheduler_cosine) 133 | elif scheduler == 'StepLR': # 특정 epoch 도착하면 비율만큼 감소 134 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [11, 21], gamma=0.1) 135 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) 136 | elif scheduler == 'Reduce': # factor 비율만큼 줄여주기 137 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=.5, patience=10, 138 | min_lr=1e-6) 139 | elif scheduler == 'Cosine': 140 | print('Cosine Scheduler') 141 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer, T_max=epochs) 142 | 143 | 144 | # scheduler 145 | # scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=step_size, gamma=.5) 146 | # loss의 값이 최소가 되도록 하며, 50번 동안 loss의 값이 감소가 되지 않을 경우 factor값 만큼 147 | # learning_rate의 값을 줄이고, 최저 1e-6까지 줄어들 수 있게 설정 148 | 149 | best_loss = 0. 
150 | stop_count = 0 151 | check_loss = False 152 | for epoch in range(epochs): 153 | if scheduler != 'None': 154 | scheduler.step(epoch) 155 | train_total_loss = 0.0 156 | train_total_count = 0 157 | train_total_data = 0 158 | 159 | val_total_loss = 0.0 160 | val_total_count = 0 161 | val_total_data = 0 162 | 163 | start_time = time.time() 164 | model.train() 165 | 166 | output_str = 'current epoch : %d/%d / current_lr : %f \n' % (epoch+1,epochs,optimizer.state_dict()['param_groups'][0]['lr']) 167 | sys.stdout.write(output_str) 168 | check_file.write(output_str) 169 | with tqdm(train_dataloader,desc='Train',unit='batch') as tepoch: 170 | for index,(batch_signal1,batch_signal2, batch_label) in enumerate(tepoch): 171 | 172 | bsz = batch_signal1.shape[0] 173 | 174 | batch_signal1 = batch_signal1.to(device) 175 | batch_signal2 = batch_signal2.to(device) 176 | batch_label = batch_label.long().to(device) 177 | batch_signal = torch.cat([batch_signal1,batch_signal2],dim=0) 178 | 179 | optimizer.zero_grad() 180 | 181 | pred = model(batch_signal) 182 | 183 | 184 | # print(features) 185 | f1, f2 = torch.split(pred, [bsz, bsz], dim=0) 186 | pred = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) 187 | 188 | 189 | 190 | loss = loss_fn(pred) # + beta * norm 191 | 192 | train_total_loss += loss.item() 193 | 194 | loss.backward() 195 | optimizer.step() 196 | 197 | 198 | tepoch.set_postfix(loss=train_total_loss/(index+1)) 199 | 200 | train_total_loss /= index 201 | 202 | output_str = 'train dataset : %d/%d epochs spend time : %.4f sec / total_loss : %.4f \n' \ 203 | % (epoch + 1, epochs, time.time() - start_time, train_total_loss) 204 | # sys.stdout.write(output_str) 205 | check_file.write(output_str) 206 | 207 | # check validation dataset 208 | start_time = time.time() 209 | model.eval() 210 | 211 | with tqdm(val_dataloader,desc='Validation',unit='batch') as tepoch: 212 | for index,(batch_signal1,batch_signal2, batch_label) in enumerate(tepoch): 213 | bsz = 
batch_signal1.shape[0] 214 | 215 | batch_signal1 = batch_signal1.to(device) 216 | batch_signal2 = batch_signal2.to(device) 217 | batch_label = batch_label.long().to(device) 218 | 219 | batch_signal = torch.cat([batch_signal1,batch_signal2],dim=0) 220 | 221 | with torch.no_grad(): 222 | pred = model(batch_signal) 223 | # print(features) 224 | f1, f2 = torch.split(pred, [bsz, bsz], dim=0) 225 | pred = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) 226 | 227 | loss = loss_fn(pred) # + beta * norm 228 | # print(f'val loss : {loss.item()}') 229 | val_total_loss += loss.item() 230 | tepoch.set_postfix(loss=val_total_loss/(index+1)) 231 | 232 | val_total_loss /= (index + 1) 233 | 234 | output_str = 'val dataset : %d/%d epochs spend time : %.4f sec / total_loss : %.4f\n' \ 235 | % (epoch + 1, epochs, time.time() - start_time, val_total_loss) 236 | # sys.stdout.write(output_str) 237 | check_file.write(output_str) 238 | 239 | # scheduler.step(float(val_total_loss)) 240 | # scheduler.step(epoch) 241 | if epoch == 0: 242 | best_loss = val_total_loss 243 | best_epoch = epoch 244 | save_file = save_filename 245 | if torch.cuda.device_count() > 1: 246 | torch.save(model.module.state_dict(), save_file) 247 | else: 248 | torch.save(model.state_dict(), save_file) 249 | stop_count = 0 250 | # sys.stdout.write(output_str) 251 | check_file.write(output_str) 252 | else: 253 | if best_loss > val_total_loss: 254 | best_loss = val_total_loss 255 | best_epoch = epoch 256 | save_file = save_filename 257 | if torch.cuda.device_count() > 1: 258 | torch.save(model.module.state_dict(), save_file) 259 | else: 260 | torch.save(model.state_dict(), save_file) 261 | stop_count = 0 262 | else: 263 | stop_count += 1 264 | 265 | output_str = 'best epoch : %d/%d / best loss : %f%%\n' \ 266 | % (best_epoch + 1, epochs, best_loss) 267 | sys.stdout.write(output_str) 268 | print('=' * 30) 269 | 270 | output_str = 'best epoch : %d/%d / best loss : %f%%\n' \ 271 | % (best_epoch + 1, epochs, best_loss) 
272 | sys.stdout.write(output_str) 273 | check_file.write(output_str) 274 | print('=' * 30) 275 | 276 | check_file.close() 277 | 278 | 279 | 280 | 281 | 282 | def training_resnet_dataloader_representationlearning_simCLR(use_dataset='sleep_edf',total_train_percent = 1.,train_percent=0.8,val_percent=0.1,test_percent=0.1,use_model = 'resnet18', 283 | random_seed=2,use_channel=[0,1],entropy_hyperparam=0.,classification_mode='5class',aug_p=0.,aug_method=['h_flip','v_flip'], 284 | first_conv=[49, 4, 24],maxpool=[7,3,3], layer_filters=[64, 128, 128, 256],block_kernel_size=5,block_stride_size=2,learning_rate=0.1,batch_size=512,optim='SGD', 285 | gpu_num=[0]): 286 | 287 | if use_dataset == 'sleep_edf': 288 | signals_path = '/home/eslab/dataset/sleep_edf_final/origin_npy/remove_wake_version1/each/' 289 | 290 | random.seed(random_seed) # seed 291 | np.random.seed(random_seed) 292 | torch.manual_seed(random_seed) 293 | 294 | # signals_path = '/home/eslab/dataset/seoulDataset/1channel_prefilter_butter_minmax_-1_1/signals_dataloader/' 295 | 296 | dataset_list = os.listdir(signals_path) 297 | dataset_list = [signals_path + filename + '/' for filename in dataset_list] 298 | dataset_list.sort() 299 | random.shuffle(dataset_list) 300 | 301 | 302 | 303 | training_fold_list = [] 304 | validation_fold_list = [] 305 | test_fold_list = [] 306 | 307 | 308 | 309 | val_length = int(len(dataset_list) * val_percent) 310 | test_length = int(len(dataset_list) * test_percent) 311 | train_length = int(len(dataset_list) - val_length - test_length) 312 | 313 | 314 | for i in range(0,val_length): 315 | validation_fold_list.append(dataset_list[i]) 316 | for i in range(val_length,val_length + test_length): 317 | test_fold_list.append(dataset_list[i]) 318 | for i in range(val_length + test_length,len(dataset_list)): 319 | training_fold_list.append(dataset_list[i]) 320 | 321 | 322 | 323 | # print(dataset_list[:10]) 324 | print('='*20) 325 | print(len(training_fold_list)) 326 | 
print(len(validation_fold_list)) 327 | print(len(test_fold_list)) 328 | print('='*20) 329 | 330 | train_label,train_label_percent = check_label_info_withPath(file_list = training_fold_list) 331 | val_label,val_label_percent = check_label_info_withPath(file_list = validation_fold_list) 332 | test_label,test_label_percent = check_label_info_withPath(file_list = test_fold_list) 333 | 334 | print(train_label) 335 | print(np.round(train_label_percent,3)) 336 | print(val_label) 337 | print(np.round(val_label_percent,3)) 338 | print(test_label) 339 | print(np.round(test_label_percent,3)) 340 | 341 | 342 | # exit(1) 343 | 344 | # number of classes 345 | if classification_mode == '6class': 346 | class_num = 6 347 | elif classification_mode =='5class': 348 | class_num = 5 349 | else: 350 | class_num=3 351 | 352 | # hyperparameters 353 | epochs = 100 354 | warmup_iter=10 355 | cosine_decay_iter=10 356 | 357 | stop_iter = 10 358 | loss_function = 'CE' # CEs 359 | scheduler = 'WarmUp' # 'WarmUp_restart' 360 | 361 | print(f'class num = {class_num}') 362 | model_save_path = f'/data/hdd3/git/DeepSleepNet_pytorch/saved_model/representation_learning_simCLR/{use_dataset}/{classification_mode}/'\ 363 | f'single_epoch_models_{round(train_percent,2)}_{round(val_percent,2)}_{round(test_percent,2)}/'\ 364 | f'optim_{optim}_random_seed_{random_seed}_scheduler_{scheduler}_withoutRegularization_aug_p_{aug_p}_aug_method_{aug_method}/'\ 365 | f'firstconv_{first_conv}_maxpool_{maxpool}_layerfilters_{layer_filters}_blockkernelsize_{block_kernel_size}_blockstridesize_{block_stride_size}/' 366 | logging_save_path = f'/data/hdd3/git/DeepSleepNet_pytorch/log/representation_learning_simCLR/{use_dataset}/{classification_mode}/'\ 367 | f'single_epoch_models_{round(train_percent,2)}_{round(val_percent,2)}_{round(test_percent,2)}/'\ 368 | f'optim_{optim}_random_seed_{random_seed}_scheduler_{scheduler}_withoutRegularization_aug_p_{aug_p}_aug_method_{aug_method}/'\ 369 | 
f'firstconv_{first_conv}_maxpool_{maxpool}_layerfilters_{layer_filters}_blockkernelsize_{block_kernel_size}_blockstridesize_{block_stride_size}/' 370 | # model_save_path = '/home/eslab/kdy/git/Sleep_pytorch/saved_model/seoulDataset/single_epoch_models/' 371 | # logging_save_path = '/home/eslab/kdy/git/Sleep_pytorch/log/seoulDataset/single_epoch_models/' 372 | 373 | os.makedirs(model_save_path,exist_ok=True) 374 | os.makedirs(logging_save_path,exist_ok=True) 375 | 376 | save_filename = model_save_path + f'{use_model}_%.5f_{use_channel}_{batch_size}_entropy_{entropy_hyperparam}.pth'%(learning_rate) 377 | 378 | logging_filename = logging_save_path + f'{use_model}_%.5f_{use_channel}_{batch_size}_entropy_{entropy_hyperparam}.txt'%(learning_rate) 379 | print('logging filename : ',logging_filename) 380 | print('save filename : ',save_filename) 381 | 382 | # exit(1) 383 | train_resnet_dataloader_representationlearning_simCLR(save_filename=save_filename,logging_filename=logging_filename,train_dataset_list=training_fold_list,val_dataset_list=validation_fold_list,test_dataset_list=test_fold_list, 384 | batch_size = batch_size,entropy_hyperparam=entropy_hyperparam, 385 | epochs=epochs,optim=optim,loss_function=loss_function, 386 | learning_rate=learning_rate,scheduler=scheduler,warmup_iter=warmup_iter,cosine_decay_iter=cosine_decay_iter,stop_iter=stop_iter, 387 | use_channel=use_channel,class_num=class_num,classification_mode=classification_mode,aug_p=aug_p,aug_method=aug_method, 388 | first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters,block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, 389 | gpu_num=gpu_num,sample_rate= 100,epoch_size = 30) -------------------------------------------------------------------------------- /train/single_epoch/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.function import * 2 | from utils.loss_fn import * 3 | from utils.scheduler import * 4 | from 
utils.dataloader.sleep_edf import * 5 | 6 | from models.cnn.DeepSleepNet_cnn import * 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from torch.autograd import Variable 13 | from torch.autograd import Function 14 | from torch.utils.data import DataLoader 15 | 16 | from torchvision import transforms, utils 17 | from torchvision.datasets import ImageFolder 18 | 19 | from torch import einsum 20 | from einops import rearrange, repeat 21 | 22 | # pip install torchsummary 23 | from torchsummary import summary 24 | # pip install tqdm 25 | from tqdm import tnrange, tqdm 26 | 27 | # multiprocessing 28 | import multiprocessing 29 | from multiprocessing import Process, Manager, Pool, Lock 30 | 31 | import os 32 | import random 33 | import math 34 | import time 35 | import sys 36 | import warnings 37 | import datetime 38 | import shutil 39 | 40 | # import argparse 41 | 42 | import itertools 43 | import numpy as np 44 | import pandas as pd 45 | -------------------------------------------------------------------------------- /train/single_epoch/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/train/single_epoch/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /train/single_epoch/__pycache__/train_deepsleepnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/train/single_epoch/__pycache__/train_deepsleepnet.cpython-38.pyc -------------------------------------------------------------------------------- /train/single_epoch/__pycache__/train_resnet.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/train/single_epoch/__pycache__/train_resnet.cpython-38.pyc -------------------------------------------------------------------------------- /train/single_epoch/train_deepsleepnet.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | 4 | 5 | 6 | def train_deepsleepnet_dataloader(save_filename,logging_filename,train_dataset_list,val_dataset_list,test_dataset_list,batch_size = 512,entropy_hyperparam=0., 7 | epochs=100,optim='Adam',loss_function='CE', 8 | learning_rate=0.001,scheduler=None,warmup_iter=20,cosine_decay_iter=40,stop_iter=10, 9 | use_channel=[0,1],class_num=6,classification_mode='6class',aug_p=0.,aug_method=['h_flip','v_flip'], 10 | gpu_num=0,sample_rate= 100,epoch_size = 30): 11 | # cpu processor num 12 | cpu_num = multiprocessing.cpu_count() 13 | 14 | 15 | #dataload Training Dataset 16 | train_dataset = Sleep_Dataset_withPath_sleepEDF(dataset_list=train_dataset_list,class_num=class_num, 17 | use_channel=use_channel,use_cuda = True,classification_mode=classification_mode,aug_p = aug_p,aug_method = ['h_flip'],) 18 | train_dataloader = DataLoader(dataset=train_dataset,batch_size=batch_size, shuffle=True, num_workers=(cpu_num//4)) 19 | 20 | # calculate weight from training dataset (for "Class Balanced Weight") 21 | weights,count = make_weights_for_balanced_classes(train_dataset.signals_files_path,nclasses=class_num) 22 | weights = torch.DoubleTensor(weights) 23 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights,len(weights)) 24 | 25 | 26 | #dataload Validation Dataset 27 | val_dataset = Sleep_Dataset_withPath_sleepEDF(dataset_list=val_dataset_list,class_num=class_num, 28 | use_channel=use_channel,use_cuda = True,classification_mode=classification_mode) 29 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, num_workers=(cpu_num//4)) 30 | 31 | 32 | test_dataset 
= Sleep_Dataset_withPath_sleepEDF(dataset_list=test_dataset_list,class_num=class_num, 33 | use_channel=use_channel,use_cuda = True,classification_mode=classification_mode) 34 | test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=(cpu_num//4)) 35 | 36 | 37 | print(train_dataset.length,val_dataset.length,test_dataset.length) 38 | 39 | # Adam optimizer paramQ 40 | b1 = 0.9 41 | b2 = 0.999 42 | 43 | # for Regularization 44 | beta = 0.001 45 | norm_square = 2 46 | 47 | check_file = open(logging_filename, 'w') # logging file 48 | 49 | best_accuracy = 0. 50 | best_epoch = 0 51 | model = DeepSleepNet_CNN(in_channel=len(use_channel),out_channel=class_num,layer=[64,128,128,128],activation='relu',sample_rate = sample_rate,dropout_p=0.5) 52 | 53 | cuda = torch.cuda.is_available() 54 | device = torch.device(f"cuda:{gpu_num[0]}" if torch.cuda.is_available() else "cpu") 55 | torch.cuda.set_device(device) 56 | 57 | if cuda: 58 | print('can use CUDA!!!') 59 | model = model.cuda() 60 | summary(model,(len(use_channel),sample_rate*epoch_size)) 61 | # exit(1) 62 | print('torch.cuda.device_count() : ', torch.cuda.device_count()) 63 | 64 | if torch.cuda.device_count() > 1: 65 | print('Multi GPU Activation !!!', torch.cuda.device_count()) 66 | model = nn.DataParallel(model) 67 | 68 | # summary(model, (3, 6000)) 69 | model.apply(weights_init) # weight init 70 | print('loss function : %s' % loss_function) 71 | if loss_function == 'CE': 72 | loss_fn = nn.CrossEntropyLoss().to(device) 73 | elif loss_function == 'CEW': 74 | samples_per_cls = count / np.sum(count) 75 | no_of_classes = class_num 76 | effective_num = 1.0 - np.power(beta, samples_per_cls) 77 | weights = (1.0 - beta) / np.array(effective_num) 78 | weights = weights / np.sum(weights) * no_of_classes 79 | weights = torch.tensor(weights).float() 80 | weights = weights.to(device) 81 | loss_fn = nn.CrossEntropyLoss(weight=weights).to(device) 82 | elif loss_function == 'FL': 83 | loss_fn = 
FocalLoss(gamma=2).to(device) 84 | elif loss_function == 'CBL': 85 | samples_per_cls = count / np.sum(count) 86 | loss_fn = CB_loss(samples_per_cls=samples_per_cls, no_of_classes=class_num, loss_type='focal', beta=0.9999, 87 | gamma=2.0) 88 | # loss_fn = FocalLoss(gamma=2).to(device) 89 | if entropy_hyperparam > 0.: 90 | loss_fn2 = Entropy() 91 | # optimizer ADAM (SGD의 경우에는 정상적으로 학습이 진행되지 않았음) 92 | if optim == 'Adam': 93 | print('Optimizer : Adam') 94 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(b1, b2)) 95 | elif optim == 'RMS': 96 | print('Optimizer : RMSprop') 97 | optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate) 98 | elif optim == 'SGD': 99 | print('Optimizer : SGD') 100 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5,nesterov=False) 101 | elif optim == 'AdamW': 102 | print('Optimizer AdamW') 103 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(b1, b2)) 104 | 105 | 106 | gamma = 0.8 107 | 108 | lr = learning_rate 109 | epochs = epochs 110 | if scheduler == 'WarmUp_restart_gamma': 111 | print(f'target lr : {learning_rate} / warmup_iter : {warmup_iter} / cosine_decay_iter : {cosine_decay_iter} / gamma : {gamma}') 112 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_decay_iter+1) 113 | scheduler = LearningRateWarmUP_restart_changeMax(optimizer=optimizer, 114 | warmup_iteration=warmup_iter, 115 | cosine_decay_iter=cosine_decay_iter, 116 | target_lr=lr, 117 | after_scheduler=scheduler_cosine,gamma=gamma) 118 | elif scheduler == 'WarmUp_restart': 119 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_decay_iter+1) 120 | scheduler = LearningRateWarmUP_restart(optimizer=optimizer, 121 | warmup_iteration=warmup_iter, 122 | cosine_decay_iter=cosine_decay_iter, 123 | target_lr=lr, 124 | after_scheduler=scheduler_cosine) 125 | elif scheduler == 'WarmUp': # 126 | print(f'target lr 
: {learning_rate} / warmup_iter : {warmup_iter}') 127 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs-warmup_iter+1) 128 | scheduler = LearningRateWarmUP(optimizer=optimizer, 129 | warmup_iteration=warmup_iter, 130 | target_lr=lr, 131 | after_scheduler=scheduler_cosine) 132 | elif scheduler == 'StepLR': # 특정 epoch 도착하면 비율만큼 감소 133 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [11, 21], gamma=0.1) 134 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) 135 | elif scheduler == 'Reduce': # factor 비율만큼 줄여주기 136 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=.5, patience=10, 137 | min_lr=1e-6) 138 | elif scheduler == 'Cosine': 139 | print('Cosine Scheduler') 140 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer, T_max=epochs) 141 | 142 | 143 | # scheduler 144 | # scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=step_size, gamma=.5) 145 | # loss의 값이 최소가 되도록 하며, 50번 동안 loss의 값이 감소가 되지 않을 경우 factor값 만큼 146 | # learning_rate의 값을 줄이고, 최저 1e-6까지 줄어들 수 있게 설정 147 | 148 | best_accuracy = 0. 149 | stop_count = 0 150 | best_test_accuracy = 0. 
151 | check_loss = False 152 | for epoch in range(epochs): 153 | if scheduler != 'None': 154 | scheduler.step(epoch) 155 | train_total_loss = 0.0 156 | train_total_count = 0 157 | train_total_data = 0 158 | 159 | val_total_loss = 0.0 160 | val_total_count = 0 161 | val_total_data = 0 162 | 163 | start_time = time.time() 164 | model.train() 165 | 166 | output_str = 'current epoch : %d/%d / current_lr : %f \n' % (epoch+1,epochs,optimizer.state_dict()['param_groups'][0]['lr']) 167 | sys.stdout.write(output_str) 168 | check_file.write(output_str) 169 | with tqdm(train_dataloader,desc='Train',unit='batch') as tepoch: 170 | for index,(batch_signal, batch_label) in enumerate(tepoch): 171 | batch_signal = batch_signal.to(device) 172 | batch_label = batch_label.long().to(device) 173 | optimizer.zero_grad() 174 | 175 | pred = model(batch_signal) 176 | 177 | # norm = 0 178 | # for parameter in model.parameters(): 179 | # norm += torch.norm(parameter, p=norm_square) 180 | 181 | loss = loss_fn(pred, batch_label) # + beta * norm 182 | if entropy_hyperparam > 0.: 183 | if check_loss == False: # Only once access! 
184 | print('Using Entropy loss for training!') 185 | check_loss = True 186 | loss2 = loss_fn2(pred) 187 | loss = loss + entropy_hyperparam * loss2 188 | 189 | _, predict = torch.max(pred, 1) 190 | 191 | check_count = (predict == batch_label).sum().item() 192 | 193 | train_total_loss += loss.item() 194 | 195 | train_total_count += check_count 196 | train_total_data += len(batch_signal) 197 | loss.backward() 198 | optimizer.step() 199 | 200 | accuracy = train_total_count / train_total_data 201 | tepoch.set_postfix(loss=train_total_loss/(index+1),accuracy=100.*accuracy) 202 | 203 | train_total_loss /= index 204 | train_accuracy = train_total_count / train_total_data * 100 205 | 206 | output_str = 'train dataset : %d/%d epochs spend time : %.4f sec / total_loss : %.4f correct : %d/%d -> %.4f%%\n' \ 207 | % (epoch + 1, epochs, time.time() - start_time, train_total_loss, 208 | train_total_count, train_total_data, train_accuracy) 209 | # sys.stdout.write(output_str) 210 | check_file.write(output_str) 211 | 212 | # check validation dataset 213 | start_time = time.time() 214 | model.eval() 215 | 216 | with tqdm(val_dataloader,desc='Validation',unit='batch') as tepoch: 217 | for index,(batch_signal, batch_label) in enumerate(tepoch): 218 | batch_signal = batch_signal.to(device) 219 | batch_label = batch_label.long().to(device) 220 | 221 | with torch.no_grad(): 222 | pred = model(batch_signal) 223 | 224 | loss = loss_fn(pred, batch_label) 225 | 226 | # acc 227 | _, predict = torch.max(pred, 1) 228 | check_count = (predict == batch_label).sum().item() 229 | 230 | val_total_loss += loss.item() 231 | val_total_count += check_count 232 | val_total_data += len(batch_signal) 233 | accuracy = val_total_count / val_total_data 234 | tepoch.set_postfix(loss=val_total_loss/(index+1),accuracy=100.*accuracy) 235 | 236 | val_total_loss /= index 237 | val_accuracy = val_total_count / val_total_data * 100 238 | 239 | output_str = 'val dataset : %d/%d epochs spend time : %.4f sec / 
total_loss : %.4f correct : %d/%d -> %.4f%%\n' \ 240 | % (epoch + 1, epochs, time.time() - start_time, val_total_loss, 241 | val_total_count, val_total_data, val_accuracy) 242 | # sys.stdout.write(output_str) 243 | check_file.write(output_str) 244 | 245 | # scheduler.step(float(val_total_loss)) 246 | # scheduler.step(epoch) 247 | if epoch == 0: 248 | best_accuracy = val_accuracy 249 | best_epoch = epoch 250 | save_file = save_filename 251 | if torch.cuda.device_count() > 1: 252 | torch.save(model.module.state_dict(), save_file) 253 | else: 254 | torch.save(model.state_dict(), save_file) 255 | stop_count = 0 256 | test_total_count = 0 257 | test_total_data = 0 258 | # check validation dataset 259 | start_time = time.time() 260 | model.eval() 261 | 262 | with tqdm(test_dataloader,desc='Test',unit='batch') as tepoch: 263 | for index,(batch_signal, batch_label) in enumerate(tepoch): 264 | batch_signal = batch_signal.to(device) 265 | batch_label = batch_label.long().to(device) 266 | 267 | with torch.no_grad(): 268 | pred = model(batch_signal) 269 | 270 | loss = loss_fn(pred, batch_label) 271 | 272 | # acc 273 | _, predict = torch.max(pred, 1) 274 | check_count = (predict == batch_label).sum().item() 275 | 276 | test_total_count += check_count 277 | test_total_data += len(batch_signal) 278 | accuracy = test_total_count / test_total_data 279 | tepoch.set_postfix(accuracy=100.*accuracy) 280 | 281 | 282 | test_accuracy = test_total_count / test_total_data * 100 283 | best_test_accuracy = test_accuracy 284 | output_str = 'test dataset : %d/%d epochs spend time : %.4f sec / correct : %d/%d -> %.4f%%\n' \ 285 | % (epoch + 1, epochs, time.time() - start_time, 286 | test_total_count, test_total_data, test_accuracy) 287 | # sys.stdout.write(output_str) 288 | check_file.write(output_str) 289 | else: 290 | if best_accuracy < val_accuracy: 291 | best_accuracy = val_accuracy 292 | best_epoch = epoch 293 | save_file = save_filename 294 | if torch.cuda.device_count() > 1: 295 | 
torch.save(model.module.state_dict(), save_file) 296 | else: 297 | torch.save(model.state_dict(), save_file) 298 | stop_count = 0 299 | test_total_count = 0 300 | test_total_data = 0 301 | # check validation dataset 302 | start_time = time.time() 303 | model.eval() 304 | 305 | with tqdm(test_dataloader,desc='Test',unit='batch') as tepoch: 306 | for index,(batch_signal, batch_label) in enumerate(tepoch): 307 | batch_signal = batch_signal.to(device) 308 | batch_label = batch_label.long().to(device) 309 | 310 | with torch.no_grad(): 311 | pred = model(batch_signal) 312 | 313 | loss = loss_fn(pred, batch_label) 314 | 315 | # acc 316 | _, predict = torch.max(pred, 1) 317 | check_count = (predict == batch_label).sum().item() 318 | 319 | test_total_count += check_count 320 | test_total_data += len(batch_signal) 321 | accuracy = test_total_count / test_total_data 322 | tepoch.set_postfix(accuracy=100.*accuracy) 323 | 324 | 325 | test_accuracy = test_total_count / test_total_data * 100 326 | best_test_accuracy = test_accuracy 327 | output_str = 'test dataset : %d/%d epochs spend time : %.4f sec / correct : %d/%d -> %.4f%%\n' \ 328 | % (epoch + 1, epochs, time.time() - start_time, 329 | test_total_count, test_total_data, test_accuracy) 330 | # sys.stdout.write(output_str) 331 | check_file.write(output_str) 332 | else: 333 | stop_count += 1 334 | if stop_count > stop_iter: 335 | print('Early Stopping') 336 | break 337 | 338 | output_str = 'best epoch : %d/%d / test accuracy : %f%%\n' \ 339 | % (best_epoch + 1, epochs, best_test_accuracy) 340 | sys.stdout.write(output_str) 341 | print('=' * 30) 342 | 343 | output_str = 'best epoch : %d/%d / test accuracy : %f%%\n' \ 344 | % (best_epoch + 1, epochs, best_test_accuracy) 345 | sys.stdout.write(output_str) 346 | check_file.write(output_str) 347 | print('=' * 30) 348 | 349 | check_file.close() 350 | 351 | 352 | 353 | 354 | 355 | def training_deepsleepnet_dataloader(use_dataset='sleep_edf',total_train_percent = 
1.,train_percent=0.8,val_percent=0.1,test_percent=0.1, 356 | random_seed=2,use_channel=[0,1],entropy_hyperparam=0.,classification_mode='6class',aug_p=0.,aug_method=['h_flip','v_flip'],gpu_num=[0]): 357 | 358 | if use_dataset == 'sleep_edf': 359 | signals_path = '/home/eslab/dataset/sleep_edf_final/origin_npy/remove_wake_version1/each/' 360 | 361 | random.seed(random_seed) # seed 362 | np.random.seed(random_seed) 363 | torch.manual_seed(random_seed) 364 | 365 | # signals_path = '/home/eslab/dataset/seoulDataset/1channel_prefilter_butter_minmax_-1_1/signals_dataloader/' 366 | 367 | dataset_list = os.listdir(signals_path) 368 | dataset_list = [signals_path + filename + '/' for filename in dataset_list] 369 | dataset_list.sort() 370 | random.shuffle(dataset_list) 371 | 372 | 373 | 374 | training_fold_list = [] 375 | validation_fold_list = [] 376 | test_fold_list = [] 377 | 378 | 379 | 380 | val_length = int(len(dataset_list) * val_percent) 381 | test_length = int(len(dataset_list) * test_percent) 382 | train_length = int(len(dataset_list) - val_length - test_length) 383 | 384 | 385 | for i in range(0,val_length): 386 | validation_fold_list.append(dataset_list[i]) 387 | for i in range(val_length,val_length + test_length): 388 | test_fold_list.append(dataset_list[i]) 389 | for i in range(val_length + test_length,len(dataset_list)): 390 | training_fold_list.append(dataset_list[i]) 391 | 392 | 393 | 394 | # print(dataset_list[:10]) 395 | print('='*20) 396 | print(len(training_fold_list)) 397 | print(len(validation_fold_list)) 398 | print(len(test_fold_list)) 399 | print('='*20) 400 | 401 | train_label,train_label_percent = check_label_info_withPath(file_list = training_fold_list) 402 | val_label,val_label_percent = check_label_info_withPath(file_list = validation_fold_list) 403 | test_label,test_label_percent = check_label_info_withPath(file_list = test_fold_list) 404 | 405 | print(train_label) 406 | print(np.round(train_label_percent,3)) 407 | print(val_label) 408 | 
print(np.round(val_label_percent,3)) 409 | print(test_label) 410 | print(np.round(test_label_percent,3)) 411 | 412 | 413 | # exit(1) 414 | 415 | # number of classes 416 | if classification_mode == '6class': 417 | class_num = 6 418 | elif classification_mode =='5class': 419 | class_num = 5 420 | else: 421 | class_num=3 422 | 423 | # hyperparameters 424 | epochs = 100 425 | batch_size = 512 426 | warmup_iter=10 427 | cosine_decay_iter=10 428 | learning_rate = 10**-4 429 | stop_iter = 10 430 | loss_function = 'CE' # CEs 431 | optim= 'Adam' 432 | scheduler = 'Cosine' # 'WarmUp_restart' 433 | 434 | print(f'class num = {class_num}') 435 | model_save_path = f'/data/hdd3/git/DeepSleepNet_pytorch/saved_model/{use_dataset}/{classification_mode}/single_epoch_models_{round(train_percent,2)}_{round(val_percent,2)}_{round(test_percent,2)}/random_seed_{random_seed}_scheduler_{scheduler}_withoutRegularization_aug_p_{aug_p}_aug_method_{aug_method}/' 436 | logging_save_path = f'/data/hdd3/git/DeepSleepNet_pytorch/log/{use_dataset}/{classification_mode}/single_epoch_models_{round(train_percent,2)}_{round(val_percent,2)}_{round(test_percent,2)}/random_seed_{random_seed}_scheduler_{scheduler}_withoutRegularization_aug_p_{aug_p}_aug_method_{aug_method}/' 437 | # model_save_path = '/home/eslab/kdy/git/Sleep_pytorch/saved_model/seoulDataset/single_epoch_models/' 438 | # logging_save_path = '/home/eslab/kdy/git/Sleep_pytorch/log/seoulDataset/single_epoch_models/' 439 | 440 | os.makedirs(model_save_path,exist_ok=True) 441 | os.makedirs(logging_save_path,exist_ok=True) 442 | 443 | save_filename = model_save_path + f'DeepSleepNet_%.5f_{use_channel}_entropy_{entropy_hyperparam}.pth'%(learning_rate) 444 | 445 | logging_filename = logging_save_path + f'DeepSleepNet_%.5f_{use_channel}_entropy_{entropy_hyperparam}.txt'%(learning_rate) 446 | print('logging filename : ',logging_filename) 447 | print('save filename : ',save_filename) 448 | 449 | # exit(1) 450 | 
train_deepsleepnet_dataloader(save_filename=save_filename,logging_filename=logging_filename,train_dataset_list=training_fold_list,val_dataset_list=validation_fold_list,test_dataset_list=test_fold_list, 451 | batch_size = batch_size,entropy_hyperparam=entropy_hyperparam, 452 | epochs=epochs,optim=optim,loss_function=loss_function, 453 | learning_rate=learning_rate,scheduler=scheduler,warmup_iter=warmup_iter,cosine_decay_iter=cosine_decay_iter,stop_iter=stop_iter, 454 | use_channel=use_channel,class_num=class_num,classification_mode=classification_mode,aug_p=aug_p,aug_method=aug_method, 455 | gpu_num=gpu_num,sample_rate= 100,epoch_size = 30) -------------------------------------------------------------------------------- /train/single_epoch/train_resnet.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | from models.cnn.ResNet import * 3 | 4 | 5 | 6 | def train_resnet_dataloader(save_filename,logging_filename,train_dataset_list,val_dataset_list,test_dataset_list,batch_size = 512,entropy_hyperparam=0., 7 | epochs=100,optim='Adam',loss_function='CE',use_model='resnet18', 8 | learning_rate=0.001,scheduler=None,warmup_iter=20,cosine_decay_iter=40,stop_iter=10, 9 | use_channel=[0,1],class_num=6,classification_mode='6class',aug_p=0.,aug_method=['h_flip','v_flip'], 10 | first_conv=[49, 4, 24],maxpool=[7,3,3], layer_filters=[64, 128, 128, 256],block_kernel_size=5,block_stride_size=2, 11 | gpu_num=0,sample_rate= 100,epoch_size = 30): 12 | # cpu processor num 13 | cpu_num = multiprocessing.cpu_count() 14 | 15 | 16 | #dataload Training Dataset 17 | train_dataset = Sleep_Dataset_withPath_sleepEDF(dataset_list=train_dataset_list,class_num=class_num, 18 | use_channel=use_channel,use_cuda = True,classification_mode=classification_mode,aug_p = aug_p,aug_method = aug_method,) 19 | train_dataloader = DataLoader(dataset=train_dataset,batch_size=batch_size, shuffle=True, num_workers=(cpu_num//4)) 20 | 21 | # calculate weight from 
training dataset (for "Class Balanced Weight") 22 | weights,count = make_weights_for_balanced_classes(train_dataset.signals_files_path,nclasses=class_num) 23 | weights = torch.DoubleTensor(weights) 24 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights,len(weights)) 25 | 26 | 27 | #dataload Validation Dataset 28 | val_dataset = Sleep_Dataset_withPath_sleepEDF(dataset_list=val_dataset_list,class_num=class_num, 29 | use_channel=use_channel,use_cuda = True,classification_mode=classification_mode) 30 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, num_workers=(cpu_num//4)) 31 | 32 | 33 | test_dataset = Sleep_Dataset_withPath_sleepEDF(dataset_list=test_dataset_list,class_num=class_num, 34 | use_channel=use_channel,use_cuda = True,classification_mode=classification_mode) 35 | test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=(cpu_num//4)) 36 | 37 | 38 | print(train_dataset.length,val_dataset.length,test_dataset.length) 39 | 40 | # Adam optimizer paramQ 41 | b1 = 0.9 42 | b2 = 0.999 43 | 44 | # for Regularization 45 | beta = 0.001 46 | norm_square = 2 47 | 48 | check_file = open(logging_filename, 'w') # logging file 49 | 50 | best_accuracy = 0. 
51 | best_epoch = 0 52 | if use_model == 'resnet18': 53 | model = ResNet_classification(block=BasicBlock, layers=[2,2,2,2], first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters, in_channel=len(use_channel), 54 | block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, num_classes=class_num, use_batchnorm=True, zero_init_residual=False) 55 | elif use_model == 'resnet34': 56 | model = ResNet_classification(block=BasicBlock, layers=[3,4,6,3], first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters, in_channel=len(use_channel), 57 | block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, num_classes=class_num, use_batchnorm=True, zero_init_residual=False) 58 | elif use_model == 'resnet50': 59 | model = ResNet_classification(block=Bottleneck, layers=[3,4,6,3], first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters, in_channel=len(use_channel), 60 | block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, num_classes=class_num, use_batchnorm=True, zero_init_residual=False) 61 | 62 | cuda = torch.cuda.is_available() 63 | device = torch.device(f"cuda:{gpu_num[0]}" if torch.cuda.is_available() else "cpu") 64 | torch.cuda.set_device(device) 65 | 66 | if cuda: 67 | print('can use CUDA!!!') 68 | model = model.cuda() 69 | summary(model,(len(use_channel),sample_rate*epoch_size)) 70 | # exit(1) 71 | print('torch.cuda.device_count() : ', torch.cuda.device_count()) 72 | 73 | if torch.cuda.device_count() > 1: 74 | print('Multi GPU Activation !!!', torch.cuda.device_count()) 75 | model = nn.DataParallel(model) 76 | 77 | # summary(model, (3, 6000)) 78 | model.apply(weights_init) # weight init 79 | print('loss function : %s' % loss_function) 80 | if loss_function == 'CE': 81 | loss_fn = nn.CrossEntropyLoss().to(device) 82 | elif loss_function == 'CEW': 83 | samples_per_cls = count / np.sum(count) 84 | no_of_classes = class_num 85 | effective_num = 1.0 - np.power(beta, samples_per_cls) 86 | weights = 
(1.0 - beta) / np.array(effective_num) 87 | weights = weights / np.sum(weights) * no_of_classes 88 | weights = torch.tensor(weights).float() 89 | weights = weights.to(device) 90 | loss_fn = nn.CrossEntropyLoss(weight=weights).to(device) 91 | elif loss_function == 'FL': 92 | loss_fn = FocalLoss(gamma=2).to(device) 93 | elif loss_function == 'CBL': 94 | samples_per_cls = count / np.sum(count) 95 | loss_fn = CB_loss(samples_per_cls=samples_per_cls, no_of_classes=class_num, loss_type='focal', beta=0.9999, 96 | gamma=2.0) 97 | # loss_fn = FocalLoss(gamma=2).to(device) 98 | if entropy_hyperparam > 0.: 99 | loss_fn2 = Entropy() 100 | # optimizer ADAM (SGD의 경우에는 정상적으로 학습이 진행되지 않았음) 101 | if optim == 'Adam': 102 | print('Optimizer : Adam') 103 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(b1, b2)) 104 | elif optim == 'RMS': 105 | print('Optimizer : RMSprop') 106 | optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate) 107 | elif optim == 'SGD': 108 | print('Optimizer : SGD') 109 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5,nesterov=False) 110 | elif optim == 'AdamW': 111 | print('Optimizer AdamW') 112 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(b1, b2)) 113 | 114 | 115 | gamma = 0.8 116 | 117 | lr = learning_rate 118 | epochs = epochs 119 | if scheduler == 'WarmUp_restart_gamma': 120 | print(f'target lr : {learning_rate} / warmup_iter : {warmup_iter} / cosine_decay_iter : {cosine_decay_iter} / gamma : {gamma}') 121 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_decay_iter+1) 122 | scheduler = LearningRateWarmUP_restart_changeMax(optimizer=optimizer, 123 | warmup_iteration=warmup_iter, 124 | cosine_decay_iter=cosine_decay_iter, 125 | target_lr=lr, 126 | after_scheduler=scheduler_cosine,gamma=gamma) 127 | elif scheduler == 'WarmUp_restart': 128 | scheduler_cosine = 
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_decay_iter+1) 129 | scheduler = LearningRateWarmUP_restart(optimizer=optimizer, 130 | warmup_iteration=warmup_iter, 131 | cosine_decay_iter=cosine_decay_iter, 132 | target_lr=lr, 133 | after_scheduler=scheduler_cosine) 134 | elif scheduler == 'WarmUp': # 135 | print(f'target lr : {learning_rate} / warmup_iter : {warmup_iter}') 136 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs-warmup_iter+1) 137 | scheduler = LearningRateWarmUP(optimizer=optimizer, 138 | warmup_iteration=warmup_iter, 139 | target_lr=lr, 140 | after_scheduler=scheduler_cosine) 141 | elif scheduler == 'StepLR': # 특정 epoch 도착하면 비율만큼 감소 142 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [11, 21], gamma=0.1) 143 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) 144 | elif scheduler == 'Reduce': # factor 비율만큼 줄여주기 145 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=.5, patience=10, 146 | min_lr=1e-6) 147 | elif scheduler == 'Cosine': 148 | print('Cosine Scheduler') 149 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer, T_max=epochs) 150 | 151 | 152 | # scheduler 153 | # scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=step_size, gamma=.5) 154 | # loss의 값이 최소가 되도록 하며, 50번 동안 loss의 값이 감소가 되지 않을 경우 factor값 만큼 155 | # learning_rate의 값을 줄이고, 최저 1e-6까지 줄어들 수 있게 설정 156 | 157 | best_accuracy = 0. 158 | stop_count = 0 159 | best_test_accuracy = 0. 
160 | check_loss = False 161 | for epoch in range(epochs): 162 | if scheduler != 'None': 163 | scheduler.step(epoch) 164 | train_total_loss = 0.0 165 | train_total_count = 0 166 | train_total_data = 0 167 | 168 | val_total_loss = 0.0 169 | val_total_count = 0 170 | val_total_data = 0 171 | 172 | start_time = time.time() 173 | model.train() 174 | 175 | output_str = 'current epoch : %d/%d / current_lr : %f \n' % (epoch+1,epochs,optimizer.state_dict()['param_groups'][0]['lr']) 176 | sys.stdout.write(output_str) 177 | check_file.write(output_str) 178 | with tqdm(train_dataloader,desc='Train',unit='batch') as tepoch: 179 | for index,(batch_signal, batch_label) in enumerate(tepoch): 180 | batch_signal = batch_signal.to(device) 181 | batch_label = batch_label.long().to(device) 182 | optimizer.zero_grad() 183 | 184 | pred = model(batch_signal) 185 | 186 | # norm = 0 187 | # for parameter in model.parameters(): 188 | # norm += torch.norm(parameter, p=norm_square) 189 | 190 | loss = loss_fn(pred, batch_label) # + beta * norm 191 | if entropy_hyperparam > 0.: 192 | if check_loss == False: # Only once access! 
193 | print('Using Entropy loss for training!') 194 | check_loss = True 195 | loss2 = loss_fn2(pred) 196 | loss = loss + entropy_hyperparam * loss2 197 | 198 | _, predict = torch.max(pred, 1) 199 | 200 | check_count = (predict == batch_label).sum().item() 201 | 202 | train_total_loss += loss.item() 203 | 204 | train_total_count += check_count 205 | train_total_data += len(batch_signal) 206 | loss.backward() 207 | optimizer.step() 208 | 209 | accuracy = train_total_count / train_total_data 210 | tepoch.set_postfix(loss=train_total_loss/(index+1),accuracy=100.*accuracy) 211 | 212 | train_total_loss /= index 213 | train_accuracy = train_total_count / train_total_data * 100 214 | 215 | output_str = 'train dataset : %d/%d epochs spend time : %.4f sec / total_loss : %.4f correct : %d/%d -> %.4f%%\n' \ 216 | % (epoch + 1, epochs, time.time() - start_time, train_total_loss, 217 | train_total_count, train_total_data, train_accuracy) 218 | # sys.stdout.write(output_str) 219 | check_file.write(output_str) 220 | 221 | # check validation dataset 222 | start_time = time.time() 223 | model.eval() 224 | 225 | with tqdm(val_dataloader,desc='Validation',unit='batch') as tepoch: 226 | for index,(batch_signal, batch_label) in enumerate(tepoch): 227 | batch_signal = batch_signal.to(device) 228 | batch_label = batch_label.long().to(device) 229 | 230 | with torch.no_grad(): 231 | pred = model(batch_signal) 232 | 233 | loss = loss_fn(pred, batch_label) 234 | 235 | # acc 236 | _, predict = torch.max(pred, 1) 237 | check_count = (predict == batch_label).sum().item() 238 | 239 | val_total_loss += loss.item() 240 | val_total_count += check_count 241 | val_total_data += len(batch_signal) 242 | accuracy = val_total_count / val_total_data 243 | tepoch.set_postfix(loss=val_total_loss/(index+1),accuracy=100.*accuracy) 244 | 245 | val_total_loss /= index 246 | val_accuracy = val_total_count / val_total_data * 100 247 | 248 | output_str = 'val dataset : %d/%d epochs spend time : %.4f sec / 
total_loss : %.4f correct : %d/%d -> %.4f%%\n' \ 249 | % (epoch + 1, epochs, time.time() - start_time, val_total_loss, 250 | val_total_count, val_total_data, val_accuracy) 251 | # sys.stdout.write(output_str) 252 | check_file.write(output_str) 253 | 254 | # scheduler.step(float(val_total_loss)) 255 | # scheduler.step(epoch) 256 | if epoch == 0: 257 | best_accuracy = val_accuracy 258 | best_epoch = epoch 259 | save_file = save_filename 260 | if torch.cuda.device_count() > 1: 261 | torch.save(model.module.state_dict(), save_file) 262 | else: 263 | torch.save(model.state_dict(), save_file) 264 | stop_count = 0 265 | test_total_count = 0 266 | test_total_data = 0 267 | # check validation dataset 268 | start_time = time.time() 269 | model.eval() 270 | 271 | with tqdm(test_dataloader,desc='Test',unit='batch') as tepoch: 272 | for index,(batch_signal, batch_label) in enumerate(tepoch): 273 | batch_signal = batch_signal.to(device) 274 | batch_label = batch_label.long().to(device) 275 | 276 | with torch.no_grad(): 277 | pred = model(batch_signal) 278 | 279 | loss = loss_fn(pred, batch_label) 280 | 281 | # acc 282 | _, predict = torch.max(pred, 1) 283 | check_count = (predict == batch_label).sum().item() 284 | 285 | test_total_count += check_count 286 | test_total_data += len(batch_signal) 287 | accuracy = test_total_count / test_total_data 288 | tepoch.set_postfix(accuracy=100.*accuracy) 289 | 290 | 291 | test_accuracy = test_total_count / test_total_data * 100 292 | best_test_accuracy = test_accuracy 293 | output_str = 'test dataset : %d/%d epochs spend time : %.4f sec / correct : %d/%d -> %.4f%%\n' \ 294 | % (epoch + 1, epochs, time.time() - start_time, 295 | test_total_count, test_total_data, test_accuracy) 296 | # sys.stdout.write(output_str) 297 | check_file.write(output_str) 298 | else: 299 | if best_accuracy < val_accuracy: 300 | best_accuracy = val_accuracy 301 | best_epoch = epoch 302 | save_file = save_filename 303 | if torch.cuda.device_count() > 1: 304 | 
torch.save(model.module.state_dict(), save_file) 305 | else: 306 | torch.save(model.state_dict(), save_file) 307 | stop_count = 0 308 | test_total_count = 0 309 | test_total_data = 0 310 | # check validation dataset 311 | start_time = time.time() 312 | model.eval() 313 | 314 | with tqdm(test_dataloader,desc='Test',unit='batch') as tepoch: 315 | for index,(batch_signal, batch_label) in enumerate(tepoch): 316 | batch_signal = batch_signal.to(device) 317 | batch_label = batch_label.long().to(device) 318 | 319 | with torch.no_grad(): 320 | pred = model(batch_signal) 321 | 322 | loss = loss_fn(pred, batch_label) 323 | 324 | # acc 325 | _, predict = torch.max(pred, 1) 326 | check_count = (predict == batch_label).sum().item() 327 | 328 | test_total_count += check_count 329 | test_total_data += len(batch_signal) 330 | accuracy = test_total_count / test_total_data 331 | tepoch.set_postfix(accuracy=100.*accuracy) 332 | 333 | 334 | test_accuracy = test_total_count / test_total_data * 100 335 | best_test_accuracy = test_accuracy 336 | output_str = 'test dataset : %d/%d epochs spend time : %.4f sec / correct : %d/%d -> %.4f%%\n' \ 337 | % (epoch + 1, epochs, time.time() - start_time, 338 | test_total_count, test_total_data, test_accuracy) 339 | # sys.stdout.write(output_str) 340 | check_file.write(output_str) 341 | else: 342 | stop_count += 1 343 | if stop_count > stop_iter: 344 | print('Early Stopping') 345 | break 346 | 347 | output_str = 'best epoch : %d/%d / test accuracy : %f%%\n' \ 348 | % (best_epoch + 1, epochs, best_test_accuracy) 349 | sys.stdout.write(output_str) 350 | print('=' * 30) 351 | 352 | output_str = 'best epoch : %d/%d / test accuracy : %f%%\n' \ 353 | % (best_epoch + 1, epochs, best_test_accuracy) 354 | sys.stdout.write(output_str) 355 | check_file.write(output_str) 356 | print('=' * 30) 357 | 358 | check_file.close() 359 | 360 | 361 | 362 | 363 | 364 | def training_resnet_dataloader(use_dataset='sleep_edf',total_train_percent = 
1.,train_percent=0.8,val_percent=0.1,test_percent=0.1,use_model = 'resnet18', 365 | random_seed=2,use_channel=[0,1],entropy_hyperparam=0.,classification_mode='5class',aug_p=0.,aug_method=['h_flip','v_flip'], 366 | first_conv=[49, 4, 24],maxpool=[7,3,3], layer_filters=[64, 128, 128, 256],block_kernel_size=5,block_stride_size=2, 367 | gpu_num=[0]): 368 | 369 | if use_dataset == 'sleep_edf': 370 | signals_path = '/home/eslab/dataset/sleep_edf_final/origin_npy/remove_wake_version1/each/' 371 | 372 | random.seed(random_seed) # seed 373 | np.random.seed(random_seed) 374 | torch.manual_seed(random_seed) 375 | 376 | # signals_path = '/home/eslab/dataset/seoulDataset/1channel_prefilter_butter_minmax_-1_1/signals_dataloader/' 377 | 378 | dataset_list = os.listdir(signals_path) 379 | dataset_list = [signals_path + filename + '/' for filename in dataset_list] 380 | dataset_list.sort() 381 | random.shuffle(dataset_list) 382 | 383 | 384 | 385 | training_fold_list = [] 386 | validation_fold_list = [] 387 | test_fold_list = [] 388 | 389 | 390 | 391 | val_length = int(len(dataset_list) * val_percent) 392 | test_length = int(len(dataset_list) * test_percent) 393 | train_length = int(len(dataset_list) - val_length - test_length) 394 | 395 | 396 | for i in range(0,val_length): 397 | validation_fold_list.append(dataset_list[i]) 398 | for i in range(val_length,val_length + test_length): 399 | test_fold_list.append(dataset_list[i]) 400 | for i in range(val_length + test_length,len(dataset_list)): 401 | training_fold_list.append(dataset_list[i]) 402 | 403 | 404 | 405 | # print(dataset_list[:10]) 406 | print('='*20) 407 | print(len(training_fold_list)) 408 | print(len(validation_fold_list)) 409 | print(len(test_fold_list)) 410 | print('='*20) 411 | 412 | train_label,train_label_percent = check_label_info_withPath(file_list = training_fold_list) 413 | val_label,val_label_percent = check_label_info_withPath(file_list = validation_fold_list) 414 | test_label,test_label_percent = 
check_label_info_withPath(file_list = test_fold_list) 415 | 416 | print(train_label) 417 | print(np.round(train_label_percent,3)) 418 | print(val_label) 419 | print(np.round(val_label_percent,3)) 420 | print(test_label) 421 | print(np.round(test_label_percent,3)) 422 | 423 | 424 | # exit(1) 425 | 426 | # number of classes 427 | if classification_mode == '6class': 428 | class_num = 6 429 | elif classification_mode =='5class': 430 | class_num = 5 431 | else: 432 | class_num=3 433 | 434 | # hyperparameters 435 | epochs = 100 436 | batch_size = 512 437 | warmup_iter=10 438 | cosine_decay_iter=10 439 | learning_rate = 10**-4 440 | stop_iter = 10 441 | loss_function = 'CE' # CEs 442 | optim= 'Adam' 443 | scheduler = 'Cosine' # 'WarmUp_restart' 444 | 445 | print(f'class num = {class_num}') 446 | model_save_path = f'/data/hdd3/git/DeepSleepNet_pytorch/saved_model/{use_dataset}/{classification_mode}/'\ 447 | f'single_epoch_models_{round(train_percent,2)}_{round(val_percent,2)}_{round(test_percent,2)}/'\ 448 | f'random_seed_{random_seed}_scheduler_{scheduler}_withoutRegularization_aug_p_{aug_p}_aug_method_{aug_method}/'\ 449 | f'firstconv_{first_conv}_maxpool_{maxpool}_layerfilters_{layer_filters}_blockkernelsize_{block_kernel_size}_blockstridesize_{block_stride_size}/' 450 | logging_save_path = f'/data/hdd3/git/DeepSleepNet_pytorch/log/{use_dataset}/{classification_mode}/'\ 451 | f'single_epoch_models_{round(train_percent,2)}_{round(val_percent,2)}_{round(test_percent,2)}/'\ 452 | f'random_seed_{random_seed}_scheduler_{scheduler}_withoutRegularization_aug_p_{aug_p}_aug_method_{aug_method}/'\ 453 | f'firstconv_{first_conv}_maxpool_{maxpool}_layerfilters_{layer_filters}_blockkernelsize_{block_kernel_size}_blockstridesize_{block_stride_size}/' 454 | # model_save_path = '/home/eslab/kdy/git/Sleep_pytorch/saved_model/seoulDataset/single_epoch_models/' 455 | # logging_save_path = '/home/eslab/kdy/git/Sleep_pytorch/log/seoulDataset/single_epoch_models/' 456 | 457 | 
os.makedirs(model_save_path,exist_ok=True) 458 | os.makedirs(logging_save_path,exist_ok=True) 459 | 460 | save_filename = model_save_path + f'{use_model}_%.5f_{use_channel}_entropy_{entropy_hyperparam}.pth'%(learning_rate) 461 | 462 | logging_filename = logging_save_path + f'{use_model}_%.5f_{use_channel}_entropy_{entropy_hyperparam}.txt'%(learning_rate) 463 | print('logging filename : ',logging_filename) 464 | print('save filename : ',save_filename) 465 | 466 | # exit(1) 467 | train_resnet_dataloader(save_filename=save_filename,logging_filename=logging_filename,train_dataset_list=training_fold_list,val_dataset_list=validation_fold_list,test_dataset_list=test_fold_list, 468 | batch_size = batch_size,entropy_hyperparam=entropy_hyperparam, 469 | epochs=epochs,optim=optim,loss_function=loss_function, 470 | learning_rate=learning_rate,scheduler=scheduler,warmup_iter=warmup_iter,cosine_decay_iter=cosine_decay_iter,stop_iter=stop_iter, 471 | use_channel=use_channel,class_num=class_num,classification_mode=classification_mode,aug_p=aug_p,aug_method=aug_method, 472 | first_conv=first_conv,maxpool=maxpool, layer_filters=layer_filters,block_kernel_size=block_kernel_size,block_stride_size=block_stride_size, 473 | gpu_num=gpu_num,sample_rate= 100,epoch_size = 30) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import os 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/utils/__pycache__/__init__.cpython-38.pyc 
# ------------------------------------------------------------------------------
# /utils/__pycache__/function.cpython-38.pyc:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/utils/__pycache__/function.cpython-38.pyc
# ------------------------------------------------------------------------------
# /utils/__pycache__/loss_fn.cpython-38.pyc:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/utils/__pycache__/loss_fn.cpython-38.pyc
# ------------------------------------------------------------------------------
# /utils/__pycache__/scheduler.cpython-38.pyc:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/utils/__pycache__/scheduler.cpython-38.pyc
# ------------------------------------------------------------------------------
# /utils/dataloader/Transform.py:
# ------------------------------------------------------------------------------
from . import *


def interp_1d(arr, short_sample=750, long_sample=6000):
    """Linearly resample the 1-D array ``arr`` onto ``long_sample`` points.

    Input samples are placed on ``np.linspace(0, long_sample, short_sample)`` and
    read back at positions ``0 .. long_sample - 1``.

    NOTE(review): the last input sample sits at x = ``long_sample``, one step past
    the last output position, so the output stops just short of the final input
    value. Kept as-is so existing preprocessing stays reproducible.
    """
    return np.interp(
        np.arange(0, long_sample),
        np.linspace(0, long_sample, num=short_sample),
        arr)


def interp_1d_multiChannel(arr, short_sample=750, long_sample=6000):
    """Resample every channel (row) of ``arr`` to ``long_sample`` points.

    Each ``arr[i]`` is flattened and resampled on the same grid as
    :func:`interp_1d`. The result is an ndarray of shape
    ``(len(arr), long_sample)``; a single-channel input therefore comes back as
    ``(1, long_sample)``.
    """
    # Build the grids once instead of once per channel. The old single-channel
    # special case produced exactly what the general path produces.
    new_x = np.arange(0, long_sample)
    old_x = np.linspace(0, long_sample, num=short_sample)
    return np.array([np.interp(new_x, old_x, np.reshape(channel, -1)) for channel in arr])


def interp_1d_multiChannel_tensor(arr, short_sample=750, long_sample=6000):
    """Torch counterpart of :func:`interp_1d_multiChannel`.

    ``arr`` is a tensor (or array-like) whose last axis is time. A 1-D input is
    treated as a single channel. The result is a tensor with the same leading
    shape and a last axis of length ``long_sample``, for example
    ``(channels, long_sample)`` for a ``(channels, time)`` input.

    ``short_sample`` is accepted only for signature compatibility; the input
    length is read from ``arr`` itself. Resampling uses ``F.interpolate``'s
    default ``align_corners=False`` convention, so values differ slightly from
    the numpy version.

    The previous implementation passed 1-D tensors to ``F.interpolate``, which
    accepts only 3-D+ input for ``mode='linear'``. It then wrapped the results in
    ``np.array``, so it failed for every 1-D/2-D input.
    """
    signals = torch.as_tensor(arr)
    if signals.dim() == 1:
        signals = signals.unsqueeze(0)
    if not torch.is_floating_point(signals):
        signals = signals.float()
    lead_shape = signals.shape[:-1]
    # F.interpolate expects (batch, channels, length) for 1-D linear resampling.
    flat = signals.reshape(1, -1, signals.shape[-1])
    resized = torch.nn.functional.interpolate(flat, size=long_sample, mode='linear', align_corners=False)
    return resized.reshape(*lead_shape, long_sample)


class Transform:
    """Numpy-based augmentations for multi-channel signals.

    Methods index ``signal[:, ...]`` and so expect a 2-D array with the time axis
    last. The assumed layout is ``(channels, time)``.
    """

    def __init__(self):
        pass

    def add_jittering(self, signal, mu=0, std=1, channel_dependent=False):
        """Add Gaussian noise ``N(mu, std)`` to ``signal``.

        With ``channel_dependent=False``, every element gets its own noise value.
        With ``channel_dependent=True``, one noise trace over the time axis is
        drawn and added identically to every channel.
        """
        if not channel_dependent:  # Channel Independent Noise
            noise = np.random.normal(loc=mu, scale=std, size=np.shape(signal))
        else:  # Channel Dependent Noise
            # The old code called np.reshape((-1,1)) on a tuple and lost the noise.
            # Broadcasting a (time,) vector over the channel axis gives the
            # intended shared-noise result.
            noise = np.random.normal(loc=mu, scale=std, size=np.shape(signal)[-1])
        return signal + noise

    def horizon_flip(self, signal):
        """Reverse ``signal`` along its last (time) axis.

        A contiguous copy is returned because ``np.flip`` yields a negative-stride
        view, which ``torch.from_numpy`` rejects.
        """
        return np.ascontiguousarray(np.flip(signal, axis=-1))

    def permute(self, signal, pieces_size=200):
        """Shuffle ``pieces_size``-sample segments of the time axis.

        The same segment order is applied to every channel.

        Raises:
            ValueError: if the signal length is not a multiple of ``pieces_size``.
        """
        n_samples = np.shape(signal)[-1]
        if n_samples % pieces_size != 0:
            raise ValueError("Fault Pieces Size!!!")
        n_pieces = n_samples // pieces_size
        order = np.arange(n_pieces)
        np.random.shuffle(order)

        n_channels = np.shape(signal)[0]
        pieces = signal.reshape(n_channels, n_pieces, pieces_size)
        return pieces[:, order, :].reshape(n_channels, -1)

    def cutout_resize(self, signal, length, long_sample=None):
        """Remove a random ``length``-sample window and stretch the rest back.

        The remaining samples are resampled to ``long_sample`` points. By default
        this is the input length; previously the length was hard-coded to 6000,
        which broke every other signal length.

        Raises:
            ValueError: unless ``0 <= length < signal length``. Previously such
            lengths made the start-sampling loop spin forever.
        """
        total = np.shape(signal)[-1]
        cutout_length = int(length)
        if not 0 <= cutout_length < total:
            raise ValueError(f'cutout length {cutout_length} must be in [0, {total})')
        if long_sample is None:
            long_sample = total
        # Uniform over every valid start, including 0, which the old ceil-based
        # rejection loop could never pick.
        start = np.random.randint(0, total - cutout_length + 1)
        kept = np.concatenate((signal[:, :start], signal[:, start + cutout_length:]), axis=1)
        return interp_1d_multiChannel(arr=kept, short_sample=total - cutout_length, long_sample=long_sample)

    def crop_resize(self, signal, length, long_sample=3000):
        """Take a random ``length``-sample crop and resample it to ``long_sample`` points.

        Raises:
            ValueError: unless ``0 < length <= signal length``. Previously a
            length equal to or larger than the signal made the loop spin forever.
        """
        total = np.shape(signal)[-1]
        crop_length = int(length)
        if not 0 < crop_length <= total:
            raise ValueError(f'crop length {crop_length} must be in (0, {total}]')
        start = np.random.randint(0, total - crop_length + 1)
        crop_signal = signal[:, start:start + crop_length]
        return interp_1d_multiChannel(arr=crop_signal, short_sample=crop_length, long_sample=long_sample)
# ------------------------------------------------------------------------------
# /utils/dataloader/__init__.py:
# ------------------------------------------------------------------------------
import numpy as np
import os

import torch
# ------------------------------------------------------------------------------
# /utils/dataloader/__pycache__/Transform.cpython-38.pyc:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/utils/dataloader/__pycache__/Transform.cpython-38.pyc
# ------------------------------------------------------------------------------
# /utils/dataloader/__pycache__/__init__.cpython-38.pyc:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/utils/dataloader/__pycache__/__init__.cpython-38.pyc
# ------------------------------------------------------------------------------
# /utils/dataloader/__pycache__/sleep_edf.cpython-38.pyc:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/utils/dataloader/__pycache__/sleep_edf.cpython-38.pyc
# ------------------------------------------------------------------------------
# /utils/dataloader/sleep_edf.py:
# ------------------------------------------------------------------------------
from . import *
from .Transform import *


def _label_from_filename(filename, suffix='.npy'):
    """Return the integer label encoded as the last ``_`` token before ``suffix``.

    Example: ``'subject_0001_3.npy'`` -> ``3``.
    """
    return int(filename.split(suffix)[0].split('_')[-1])


def _remap_label(labels, classification_mode):
    """Collapse stage labels for the 3-class modes.

    - ``'REM-NoneREM'``: 0 -> 0 (Wake), 4 -> 2 (REM), anything else -> 1 (non-REM).
    - ``'LS-DS'``: 0 -> 0, 1 or 2 -> 1, anything else -> 2. Presumably this means
      wake / light / deep sleep; confirm against the labelling scheme.
    - Any other mode returns the label unchanged.
    """
    if classification_mode == 'REM-NoneREM':
        if labels == 0:  # Wake
            return 0
        return 2 if labels == 4 else 1  # REM -> 2, None-REM -> 1
    if classification_mode == 'LS-DS':
        if labels == 0:
            return 0
        return 1 if labels in (1, 2) else 2
    return labels


def _collect_signal_files(dataset_list, classification_mode):
    """List every signal file in ``dataset_list`` together with its label.

    Folders are visited in list order, and files within each folder in sorted
    order. Paths are built by plain string concatenation, so each folder entry
    must end with a path separator. In ``'5class'`` mode, files labelled 5 are
    skipped.

    Returns:
        (file_paths, labels, count)
    """
    all_signals_files = []
    all_labels = []
    for signals_path in dataset_list:
        for signals_filename in sorted(os.listdir(signals_path)):
            label = _label_from_filename(signals_filename)
            if classification_mode == '5class' and label == 5:  # pass 'None' class
                continue
            all_signals_files.append(signals_path + signals_filename)
            all_labels.append(label)
    return all_signals_files, all_labels, len(all_signals_files)


def make_weights_for_balanced_classes(data_list, nclasses=6, check_file='.npy'):
    """Compute per-sample weights that balance classes, e.g. for a weighted sampler.

    Each path's label is parsed from its filename (last ``_`` token before
    ``check_file``). A class with ``count[c]`` samples gets weight ``N / count[c]``,
    where ``N`` is the total number of samples. Classes with no samples get weight
    0; previously they raised ``ZeroDivisionError``. Such weights are never
    assigned to any sample anyway.

    Returns:
        (weight, count): per-sample weights in ``data_list`` order, and per-class
        sample counts.
    """
    labels = [_label_from_filename(data, check_file) for data in data_list]
    count = [0] * nclasses
    for label in labels:
        count[label] += 1
    total = float(sum(count))
    weight_per_class = [total / class_count if class_count else 0. for class_count in count]
    weight = [weight_per_class[label] for label in labels]
    return weight, count


class Sleep_Dataset_withPath_sleepEDF(object):
    """Single-epoch dataset that loads one ``.npy`` signal file per item.

    Items are ``(signals, label)``. ``signals`` holds the ``use_channel`` rows of
    the loaded array and is optionally augmented with a sign inversion and/or a
    time reversal.
    """

    def read_dataset(self):
        """Return ``(file_paths, labels, count)`` for all folders in ``self.dataset_list``."""
        return _collect_signal_files(self.dataset_list, self.classification_mode)

    def __init__(self,
                 dataset_list,
                 class_num=5,
                 use_cuda=True,
                 use_channel=[0],  # list defaults are only read, never mutated
                 window_size=500,
                 stride=250,
                 sample_rate=125,
                 epoch_size=30,
                 aug_p=0.,
                 aug_method=['h_flip', 'v_flip'],
                 classification_mode='5class'
                 ):
        self.class_num = class_num
        self.dataset_list = dataset_list
        self.classification_mode = classification_mode
        self.signals_files_path, self.labels, self.length = self.read_dataset()
        self.use_channel = use_channel
        self.use_cuda = use_cuda
        self.seq_size = ((sample_rate * epoch_size) - window_size) // stride + 1
        self.window_size = window_size
        self.stride = stride
        self.aug_p = aug_p
        self.aug_method = aug_method

        print('classification_mode : ', classification_mode)
        print(f'window size = {window_size} / stride = {stride}')

    def __getitem__(self, index):
        labels = _remap_label(int(self.labels[index]), self.classification_mode)

        signals = np.load(self.signals_files_path[index])
        signals = signals[self.use_channel, :]

        # Augment with probability aug_p. The old `rand() > aug_p` test augmented
        # with probability 1 - aug_p. Every listed method is now applied; the old
        # `elif` made 'v_flip' unreachable whenever 'h_flip' was also listed.
        if self.aug_p > 0. and np.random.rand() < self.aug_p:
            if 'h_flip' in self.aug_method:  # amplitude (sign) inversion
                signals = -1 * signals
            if 'v_flip' in self.aug_method:  # time reversal
                # Copy so strides are positive; torch.from_numpy rejects negative strides.
                signals = np.ascontiguousarray(signals[:, ::-1])

        # use_cuda only controls conversion to a float32 torch tensor (still on CPU).
        if self.use_cuda:
            signals = torch.from_numpy(signals).float()

        return signals, labels

    def __len__(self):
        return self.length


class Sleep_Dataset_withPath_sleepEDF_simCLR(object):
    """Dataset producing two augmented views per item for SimCLR-style training.

    Items are ``(view1, view2, label)``. ``preprocessing_method[0]`` produces
    view 1. Each later method overwrites view 2, so the last recognised one wins.
    Supported methods are ``'permute'``, ``'crop'`` and ``'cutout'``.
    """

    def read_dataset(self):
        """Return ``(file_paths, labels, count)`` for all folders in ``self.dataset_list``."""
        return _collect_signal_files(self.dataset_list, self.classification_mode)

    def __init__(self,
                 dataset_list,
                 class_num=5,
                 use_cuda=True,
                 use_channel=[0],  # list defaults are only read, never mutated
                 window_size=500,
                 stride=250,
                 sample_rate=125,
                 epoch_size=30,
                 preprocessing=True,
                 preprocessing_method=['permute', 'crop'],
                 permute_size=200,
                 crop_size=1000,
                 cutout_size=1000,
                 classification_mode='5class'
                 ):
        self.class_num = class_num
        self.dataset_list = dataset_list
        self.classification_mode = classification_mode
        self.signals_files_path, self.labels, self.length = self.read_dataset()
        self.use_channel = use_channel
        self.use_cuda = use_cuda
        self.seq_size = ((sample_rate * epoch_size) - window_size) // stride + 1
        self.window_size = window_size
        self.stride = stride
        self.long_length = sample_rate * epoch_size  # target length for resized views
        self.preprocessing = preprocessing
        self.preprocessing_method = preprocessing_method
        self.Transform = Transform()
        self.permute_size = permute_size
        self.crop_size = crop_size
        self.cutout_size = cutout_size

        print('classification_mode : ', classification_mode)
        print(f'window size = {window_size} / stride = {stride}')

    def _apply_method(self, signals, method):
        """Return one augmented view of ``signals`` for ``method``, or None if unrecognised."""
        if method == 'permute':
            return self.Transform.permute(signal=signals, pieces_size=self.permute_size)
        if method == 'crop':
            return self.Transform.crop_resize(signal=signals, length=self.crop_size, long_sample=self.long_length)
        if method == 'cutout':
            # Resize to long_length like 'crop' so both views share one length.
            return self.Transform.cutout_resize(signal=signals, length=self.cutout_size, long_sample=self.long_length)
        return None

    def __getitem__(self, index):
        labels = _remap_label(int(self.labels[index]), self.classification_mode)

        signals = np.load(self.signals_files_path[index])
        signals = signals[self.use_channel, :]

        # Both views start as the un-augmented signal. This way preprocessing=False,
        # or a method list with fewer than two recognised entries, no longer raises
        # UnboundLocalError.
        signals1, signals2 = signals, signals.copy()
        if self.preprocessing:
            for signal_index, current_method in enumerate(self.preprocessing_method):
                view = self._apply_method(signals, current_method)
                if view is None:
                    continue
                if signal_index == 0:
                    signals1 = view
                else:
                    signals2 = view

        # use_cuda only controls conversion to float32 torch tensors (still on CPU).
        if self.use_cuda:
            signals1 = torch.from_numpy(np.ascontiguousarray(signals1)).float()
            signals2 = torch.from_numpy(np.ascontiguousarray(signals2)).float()

        return signals1, signals2, labels

    def __len__(self):
        return self.length
# ------------------------------------------------------------------------------
# /utils/dataset/Sleep_edf/__init__.py:
# ------------------------------------------------------------------------------
import numpy as np

# pip install pyedflib ( to read edf file in python )
from pyedflib import highlevel

import os
import pandas as pd
import random
import shutil
import itertools

import multiprocessing
from multiprocessing import Process, Manager, Pool, Lock
# ------------------------------------------------------------------------------
# /utils/dataset/Sleep_edf/__pycache__/ (compiled bytecode, not reproduced):
#   https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/utils/dataset/Sleep_edf/__pycache__/__init__.cpython-38.pyc
#   https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/utils/dataset/Sleep_edf/__pycache__/edf_to_numpy.cpython-38.pyc
#   https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/utils/dataset/Sleep_edf/__pycache__/function.cpython-38.pyc
#   https://raw.githubusercontent.com/dongyyyyy/DeepSleepNet_pytorch/7091443480694dd8a539a9b1e1e6ebb55da80ae2/utils/dataset/Sleep_edf/__pycache__/makeDataset_each.cpython-38.pyc
# ------------------------------------------------------------------------------
# /utils/dataset/Sleep_edf/edf_to_numpy.py
# ------------------------------------------------------------------------------
from . import *
from .function import *

# Download the Sleep-EDF (expanded) public dataset first; pyedflib is required
# (`pip install pyedflib`).
#   wget -r -N -c -np https://physionet.org/files/sleep-edfx/1.0.0/
# Downloading with wget can silently drop some files, so if a web browser is
# available, download the dataset from
#   https://www.physionet.org/content/sleep-edfx/1.0.0/

# Hypnogram annotation text -> integer class. Stages 3 and 4 are merged into
# class 3; any other annotation text becomes class 5.
_SLEEP_STAGE_TO_CLASS = {
    'Sleep stage W': 0,
    'Sleep stage 1': 1,
    'Sleep stage 2': 2,
    'Sleep stage 3': 3,
    'Sleep stage 4': 3,
    'Sleep stage R': 4,
}
_UNKNOWN_STAGE_CLASS = 5


def check_edf_dataset(path, type='edf'):
    """Return the signal and annotation file lists found under *path*.

    Args:
        path: Directory to scan (expected to end with '/').
        type: 'edf' lists the ``*PSG.edf`` / ``*Hypnogram.edf`` files in
            *path*; 'npy' lists the ``*.npy`` signals in *path* and the
            ``*.npy`` labels in the sibling ``annotations/`` directory.

    Returns:
        dict with keys 'signals_file_list' and 'annotation_file_list'.

    Raises:
        ValueError: if *type* is neither 'edf' nor 'npy'.
    """
    if type == 'edf':
        annotations_list = search_annotations_edf(path)
        signals_list = search_signals_edf(path)
    elif type == 'npy':
        # '<root>/<dir>/' -> '<root>/annotations/'
        annotations_path = '/'.join(path.split('/')[:-2]) + '/annotations/'
        annotations_list = search_signals_npy(annotations_path)
        signals_list = search_signals_npy(path)
    else:
        raise ValueError(f"type must be 'edf' or 'npy', got {type!r}")
    return {'signals_file_list': signals_list, 'annotation_file_list': annotations_list}


def make_dataset(dataset='sleepedf', path='/home/eslab/dataset/sleep-edf-database-expanded-1.0.0/sleep-cassette/'):
    """Convert every PSG/Hypnogram EDF pair under *path* to .npy files.

    For each recording, the first three signal channels and the per-epoch
    integer labels are saved under fixed output directories. Recordings whose
    signal and hypnogram start dates differ are reported and skipped. When the
    signal and label lengths (in 30-s epochs) disagree, the longer one is cut
    to the length of the shorter one.

    Raises:
        ValueError: if *dataset* is not 'sleepedf' (no sample rate is known).
    """
    if dataset != 'sleepedf':
        raise ValueError(f'unsupported dataset: {dataset!r}')
    sample_rate = 100  # Hz, assumed for the kept channels
    epoch_size = 30    # seconds per scored epoch
    samples_per_epoch = sample_rate * epoch_size

    files = check_edf_dataset(path=path)
    signals_edf_list = files['signals_file_list']
    annotation_list = files['annotation_file_list']
    print(f'signals file length : {len(signals_edf_list)} // annotation file length : {len(annotation_list)}')

    save_signals_path = '/home/eslab/dataset/sleep_edf_final/origin_npy/'
    save_annotations_path = '/home/eslab/dataset/sleep_edf_final/annotations/'
    os.makedirs(save_annotations_path, exist_ok=True)
    os.makedirs(save_signals_path, exist_ok=True)

    for filename in signals_edf_list:
        matches = search_correct_annotations(path, filename)
        if not matches:
            print(f'{filename}: no matching Hypnogram file, skipped')
            continue
        signals_filename = path + filename
        annotations_filename = path + matches[0]

        _, _, annotations_header = highlevel.read_edf(annotations_filename)
        label = []
        for ann in annotations_header['annotations']:
            # ann[1] is the annotation length in seconds. Labels are per 30-s
            # epoch, so length // 30 is how many epochs this stage repeats.
            n_epochs = int(ann[1] // epoch_size)
            label.extend([_SLEEP_STAGE_TO_CLASS.get(ann[2], _UNKNOWN_STAGE_CLASS)] * n_epochs)
        label = np.array(label)

        signals, _, signals_header = highlevel.read_edf(signals_filename)
        signals_len = len(signals[0]) // samples_per_epoch
        annotations_len = len(label)
        signal_name = signals_filename.split('/')[-1]

        if signals_header['startdate'] != annotations_header['startdate']:
            print("%s file's signal & annotations start time is different" % signal_name)
            continue
        print("%s file's signal & annotations start time is same" % signal_name)

        if signals_len > annotations_len:
            # Keep the first 3 channels and drop the trailing samples that have
            # no label (slice the sample axis, not the channel axis).
            signals = np.array(signals[:3])[:, :annotations_len * samples_per_epoch]
        elif signals_len < annotations_len:
            signals = signals[:3]
            label = label[:signals_len]
        else:
            signals = signals[:3]
        signals = np.array(signals)

        np.save(save_signals_path + signal_name.split('.')[0], signals)
        np.save(save_annotations_path + annotations_filename.split('/')[-1].split('.')[0], label)

        saved_len = len(signals[0]) // samples_per_epoch
        if saved_len != len(label):
            print('signals len : %d / annotations len : %d' % (saved_len, len(label)))
            print(signals_filename, '\n', annotations_filename)


def check_wellmade(path='/home/eslab/dataset/sleep-edf-database-expanded-1.0.0/sleep-cassette/',
                   path1='/home/eslab/dataset/sleep_edf_final/origin_npy/'):
    """Sanity-check the .npy files produced by make_dataset.

    Prints the file counts of the raw EDF directory *path* and of the
    converted directory *path1*. Then it reports every converted signal file
    that has no annotation file, or whose matched annotation file carries a
    different recording id.
    """
    files = check_edf_dataset(path=path)
    signals_edf_list = files['signals_file_list']
    annotation_edf_list = files['annotation_file_list']
    print(f'signals file length : {len(signals_edf_list)} // annotation file length : {len(annotation_edf_list)}')

    files = check_edf_dataset(path=path1, type='npy')
    signals_npy_list = files['signals_file_list']
    annotation_npy_list = files['annotation_file_list']
    print(f'signals file length : {len(signals_npy_list)} // annotation file length : {len(annotation_npy_list)}')

    annotations_path = '/'.join(path1.split('/')[:-2]) + '/annotations/'
    for signals_filename in signals_npy_list:
        matches = search_correct_npy(annotations_path, signals_filename)
        if not matches:
            print(f'signals file({signals_filename}) has no annotation file!!!')
            continue
        annotation_filename = matches[0]
        # Compare the recording ids: the text before '-' minus its last two
        # characters. split('-')[:-2] was always [] and never detected anything.
        if signals_filename.split('-')[0][:-2] != annotation_filename.split('-')[0][:-2]:
            print(f'signals file({signals_filename}) and annotation file({annotation_filename}) is not matched!!!')


def _is_sleep_window(segment, window):
    """Return True if *segment* is not all class 0/last-class and over half is classes 1-5."""
    counts = np.bincount(segment, minlength=6)
    return counts[0] + counts[-1] != window and counts[1:6].sum() > window // 2


def _sleep_window_bounds(label, window=20):
    """Return (start, end) epoch indices bracketing the sleep period of *label*.

    start is the first index whose forward *window*-epoch slice passes
    _is_sleep_window. end is the last (exclusive) index whose backward slice
    passes the same test. If no slice passes, the loops run to completion
    and give (len(label) - 1, 0).
    """
    start = None
    for start in range(len(label)):
        if _is_sleep_window(label[start:start + window], window):
            break
    for end in range(len(label), -1, -1):
        if _is_sleep_window(label[end - window:end], window):
            break
    return start, end


def remove_unnessersary_wake(path='/home/eslab/dataset/sleep_edf_final/origin_npy/'):
    """Trim the long wake periods before and after sleep in every recording.

    Signals are read from *path* and labels from the sibling annotations/
    directory. The span found by _sleep_window_bounds is kept, and the
    trimmed arrays are saved into remove_wake_version1/ subdirectories.
    Recordings without labels, without an annotation file, or without any
    sleep window are reported and skipped.
    """
    files = check_edf_dataset(path=path, type='npy')
    signals_npy_list = files['signals_file_list']
    annotation_path = '/'.join(path.split('/')[:-2]) + '/annotations/'

    save_signals_path = path + 'remove_wake_version1/'
    save_annotations_path = annotation_path + 'remove_wake_version1/'
    print(save_annotations_path, save_signals_path)
    os.makedirs(save_annotations_path, exist_ok=True)
    os.makedirs(save_signals_path, exist_ok=True)

    fs = 100
    epoch_size = 30
    samples_per_epoch = fs * epoch_size
    check_index_size = 20

    for signal_filename in signals_npy_list:
        matches = search_correct_npy(dirname=annotation_path, filename=signal_filename)
        if not matches:
            print(f'{signal_filename}: no annotation file, skipped')
            continue
        annotation_filename = matches[0]
        label = np.load(annotation_path + annotation_filename)
        signal = np.load(path + signal_filename)

        if len(label) != signal.shape[1] // samples_per_epoch:
            print(f'{signal_filename} file is fault!!!')
        if len(label) == 0:
            print(f'{signal_filename}: no labels, skipped')
            continue
        current_label = np.bincount(label, minlength=6)

        start, end = _sleep_window_bounds(label, check_index_size)
        if start > end:
            print(f'{signal_filename}: no sleep period found, skipped')
            continue

        # NOTE(review): end is an exclusive window bound, yet one more epoch
        # (end + 1) is kept; preserved as-is — confirm this is intended.
        label = label[start:end + 1]
        signal = signal[:, start * samples_per_epoch:(end + 1) * samples_per_epoch]
        if len(label) == len(signal[0]) // samples_per_epoch:
            np.save(save_signals_path + signal_filename, signal)
            np.save(save_annotations_path + annotation_filename, label)

        total_label = np.bincount(label, minlength=6)
        print(f'{annotation_filename} // original label distribution : {current_label} // new label distribution : {total_label}')


def check_label(path='/home/eslab/dataset/sleep_edf_final/annotations/remove_wake/',
                path1='/home/eslab/dataset/sleep_edf_final/annotations/remove_wake_version0/'):
    """Compare two label directories that contain identically named files.

    Prints each *path1* file with more than 10 epochs of the last class, and
    dumps any pair where the *path* labels contain a class missing from the
    *path1* labels. Finally it prints the overall class distribution of each
    directory.
    """
    total_label1 = np.zeros(6)
    total_label2 = np.zeros(6)
    for name in search_signals_npy(path):
        label1 = np.load(path + name)
        label2 = np.load(path1 + name)
        counts1 = np.bincount(label1, minlength=6)
        counts2 = np.bincount(label2, minlength=6)
        total_label1 += counts1
        total_label2 += counts2
        # Files with many last-class ('unknown') epochs may be worth removing
        # from the dataset. Previously observed:
        #   SC4091EC-Hypnogram.npy 11, SC4761EP-Hypnogram.npy 20,
        #   SC4762EG-Hypnogram.npy 148, SC4092EC-Hypnogram.npy 12
        if counts2[-1] > 10:
            print(name, counts2[-1])
        if np.setdiff1d(label1, label2).size:
            print('=' * 30)
            print(label1)
            print(label2)
            print('=' * 30)

    print(total_label1 / np.sum(total_label1))
    print(total_label2 / np.sum(total_label2))


# ------------------------------------------------------------------------------
# /utils/dataset/Sleep_edf/function.py
# ------------------------------------------------------------------------------
from . import *


def _list_files(dirname, suffix):
    """Return the names in *dirname* that end with *suffix*, sorted for determinism."""
    return sorted(name for name in os.listdir(dirname) if name.endswith(suffix))


def _recording_id(filename):
    """Return the text before the first '-' without its last two characters.

    The last characters differ between the PSG and Hypnogram files of one
    recording, so they are dropped.
    """
    return filename.split('-')[0][:-2]


# edf file check
def search_annotations_edf(dirname):
    """List annotation (``*Hypnogram.edf``) files in *dirname*."""
    return _list_files(dirname, "Hypnogram.edf")


def search_signals_edf(dirname):
    """List signal (``*PSG.edf``) files in *dirname*."""
    return _list_files(dirname, "PSG.edf")


def search_correct_annotations(dirname, filename):
    """List the Hypnogram .edf files in *dirname* whose name contains *filename*'s recording id."""
    search_filename = _recording_id(filename)
    return [name for name in _list_files(dirname, "Hypnogram.edf") if search_filename in name]


# npy file check
def search_signals_npy(dirname):
    """List ``*.npy`` files in *dirname*."""
    return _list_files(dirname, ".npy")


def search_correct_npy(dirname, filename):
    """List the .npy files in *dirname* whose name contains *filename*'s recording id."""
    search_filename = _recording_id(filename)
    return [name for name in _list_files(dirname, "npy") if search_filename in name]


# ------------------------------------------------------------------------------
# /utils/dataset/Sleep_edf/makeDataset_each.py
# ------------------------------------------------------------------------------
from . import *
from .function import *


def makeDataset_for_loader(path='/home/eslab/dataset/sleep_edf_final/origin_npy/remove_wake_version1/'):
    """Split each recording in *path* into one .npy file per 30-s epoch.

    Epochs are written to ``<path>each/<recording id>/<index:04d>_<label>.npy``.
    Labels come from the matching file in the annotation directory. That
    directory is *path* with the third-from-last '/'-separated component
    replaced by 'annotations' (so *path* is expected to end with '/').
    Recordings whose signal and label lengths disagree are skipped.
    """
    sample_rate = 100
    epoch_size = 30
    samples_per_epoch = sample_rate * epoch_size

    train_list = search_signals_npy(dirname=path)
    train_save_path = path + 'each/'
    os.makedirs(train_save_path, exist_ok=True)

    parts = path.split('/')
    parts[-3] = 'annotations'
    annotation_path = '/'.join(parts)

    for signal_filename in train_list:
        annotation_filename = search_correct_npy(dirname=annotation_path, filename=signal_filename)[0]
        create_folder_path = train_save_path + signal_filename.split('-')[0][:-2] + '/'
        os.makedirs(create_folder_path, exist_ok=True)

        signals = np.load(path + signal_filename)
        label = np.load(annotation_path + annotation_filename)
        if signals.shape[1] // samples_per_epoch != len(label):
            continue
        for index, stage in enumerate(label):
            epoch = signals[:, index * samples_per_epoch:(index + 1) * samples_per_epoch]
            np.save(create_folder_path + f'{index:04d}_{stage}', epoch)

    print(len(train_list))


# ------------------------------------------------------------------------------
# /utils/function.py
# ------------------------------------------------------------------------------
from . import *


def weights_init(m):
    """Xavier-uniform initialise the weight of a module whose class name contains Conv or Linear.

    Takes a single module (e.g. via ``model.apply(weights_init)``); biases
    are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        torch.nn.init.xavier_uniform_(m.weight.data)


def _epoch_label(filename, suffix='.npy'):
    """Parse the class from a per-epoch file name, e.g. '0012_3.npy' -> 3."""
    return int(filename.split(suffix)[0].split('_')[-1])


def _three_class(stage):
    """Map a stage to W (0) / NREM (1) / REM (2): 0 -> 0, 4 -> 2, anything else -> 1."""
    if stage == 0:
        return 0
    if stage == 4:
        return 2
    return 1


def check_label_info_withPath(file_list, class_mode='6class', check='All', check_file='.npy'):
    """Count epoch labels in every directory of *file_list*.

    Args:
        file_list: Directories of per-epoch files named '<index>_<label><check_file>'.
        class_mode: '5class' or '6class' count raw labels; any other value
            counts W / NREM / REM.
        check: Unused; kept for backward compatibility.
        check_file: File suffix stripped before parsing the label.

    Returns:
        (counts, counts / counts.sum()).
    """
    three_class = class_mode not in ('5class', '6class')
    size = 3 if three_class else (5 if class_mode == '5class' else 6)
    labels = np.zeros(size, dtype=np.intc)
    for signals_paths in file_list:
        for signals_filename in sorted(os.listdir(signals_paths)):
            stage = _epoch_label(signals_filename, check_file)
            labels[_three_class(stage) if three_class else stage] += 1
    labels_sum = labels.sum()
    print(labels_sum)
    labels_percent = labels / labels_sum
    return labels, labels_percent


def check_label_info_W_NR_R(signals_path, file_list):
    """Count W / NREM / REM epochs over ``signals_path + folder + '/'`` for every folder in *file_list*.

    Returns:
        (float counts of shape (3,), counts / counts.sum()).
    """
    labels = np.zeros(3)
    for dataset_folder in file_list:
        signals_paths = signals_path + dataset_folder + '/'
        for signals_filename in sorted(os.listdir(signals_paths)):
            labels[_three_class(_epoch_label(signals_filename))] += 1
    labels_sum = labels.sum()
    print(labels_sum)
    labels_percent = labels / labels_sum
    return labels, labels_percent


# (previous group, new group) -> transition index. Groups: 0 = W, 1 = NR, 2 = R.
_TRANSITION_INDEX = {
    (0, 1): 0,  # W  -> NR
    (0, 2): 1,  # W  -> R
    (1, 0): 2,  # NR -> W
    (1, 2): 3,  # NR -> R
    (2, 0): 4,  # R  -> W
    (2, 1): 5,  # R  -> NR
}


def check_label_change_W_NR_R(signals_path, file_list):
    """Report how long W / NR / R runs last before each kind of transition.

    Each folder's epochs are scanned in file-name order. On every change of
    group, the length of the run that just ended is recorded under that
    transition; the final run of each folder is not counted. The function
    prints the statistics and saves one histogram per transition to
    /home/eslab/<index>_plot.png.
    """
    total_change = [0] * 6
    total_count = [0] * 6
    total_num = [[] for _ in range(6)]

    for dataset_folder in file_list:
        signals_paths = signals_path + dataset_folder + '/'
        current_label = None
        count = 0
        for signals_filename in sorted(os.listdir(signals_paths)):
            group = _three_class(_epoch_label(signals_filename))
            if group == current_label:  # the current stage continues
                count += 1
                continue
            if current_label is not None:  # a run just ended: record it
                k = _TRANSITION_INDEX[(current_label, group)]
                total_change[k] += 1
                total_count[k] += count
                total_num[k].append(count)
            current_label = group
            count = 1

    print(total_change)
    print(total_count)
    total_change = np.array(total_change)
    total_count = np.array(total_count)
    total_num = [np.array(runs) for runs in total_num]

    print(total_count / total_change)
    print(np.sum(total_count) / np.sum(total_change))

    for index, runs in enumerate(total_num):
        print(runs[:100])
        print('mean : ', runs.mean())
        print('std : ', runs.std())
        plt.hist(runs, bins=50)
        plt.savefig('/home/eslab/%d_plot.png' % index)
        plt.cla()


def interp_1d(arr, short_sample=750, long_sample=6000):
    """Linearly resample 1-D *arr* (short_sample points) to long_sample points.

    Input samples are placed on np.linspace(0, long_sample, short_sample) and
    read back at integer positions 0 .. long_sample - 1.
    """
    return np.interp(
        np.arange(0, long_sample),
        np.linspace(0, long_sample, num=short_sample),
        arr)


def interp_1d_multiChannel(arr, short_sample=750, long_sample=6000):
    """Resample each channel (row) of *arr* like interp_1d; returns shape (C, long_sample)."""
    x = np.arange(0, long_sample)
    xp = np.linspace(0, long_sample, num=short_sample)
    return np.array([np.interp(x, xp, channel.reshape(-1)) for channel in arr])


def interp_1d_multiChannel_tensor(arr, short_sample=750, long_sample=6000):
    """Torch counterpart of interp_1d_multiChannel.

    *arr* is a (C, ...) tensor, where each row is flattened; a 1-D tensor is
    treated as one channel. Returns a (C, long_sample) float tensor on the
    same device. short_sample is unused (the length comes from *arr*) and is
    kept for API parity. F.interpolate with mode='linear' needs a 3-D
    (N, C, L) input, so the channels become C of a one-sample batch.
    align_corners=True maps the first and last samples onto the output ends.
    """
    channels = arr.shape[0] if arr.dim() > 1 else 1
    arr = arr.reshape(channels, -1)
    if not arr.is_floating_point():
        arr = arr.float()
    out = torch.nn.functional.interpolate(arr.unsqueeze(0), size=long_sample,
                                          mode='linear', align_corners=True)
    return out.squeeze(0)


# ------------------------------------------------------------------------------
# /utils/loss_fn.py
# ------------------------------------------------------------------------------
from . import *


class Entropy(nn.Module):
    """Mean entropy of softmax(x) over dim 1."""

    def __init__(self):
        super(Entropy, self).__init__()

    def forward(self, x):
        b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        b = -1.0 * b.sum(dim=1)
        return b.mean()


class CrossEntropy(nn.Module):
    """NOTE(review): despite its name, returns the mean prediction entropy.

    *label* and *class_num* are accepted but ignored.
    """

    def __init__(self):
        super(CrossEntropy, self).__init__()

    def forward(self, x, label, class_num):
        b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        b = -1.0 * b.sum(dim=1)
        return b.mean()


class Entropy_each(nn.Module):
    """Per-sample entropy of softmax(x) over dim 1 (no reduction)."""

    def __init__(self):
        super(Entropy_each, self).__init__()

    def forward(self, x):
        b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        b = -1.0 * b.sum(dim=1)
        return b

# def make_weights_for_balanced_classes(annotations_path, file_list,nclasses=5):
#     count = np.zeros((nclasses))
#     for filename in file_list:
#         annotations = np.load(annotations_path+filename)
#         count += np.bincount(annotations,minlength=nclasses)
#     weight_per_class = [0.] * nclasses
#     N = float(sum(count))
#     for i in range(nclasses):
#         weight_per_class[i] = N/float(count[i])
#     return weight_per_class , count


class FocalLoss(nn.Module):
    """Focal loss, optionally weighted by the weight of each sample's true class.

    Args:
        gamma: Focusing parameter of the (1 - pt) ** gamma factor.
        size_average: Unused; kept for backward compatibility.
        weights: Optional 1-D tensor of per-class weights.
        no_of_classes: Number of classes.
    """

    def __init__(self, gamma=2, size_average=True, weights=None, no_of_classes=5):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.no_of_classes = no_of_classes
        self.weights = weights

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target.view(-1, 1)).view(-1)
        # pt is detached, so the modulating factor carries no gradient.
        pt = logpt.detach().exp()

        if self.weights is not None:
            # Weight of each sample's true class, shape (N,). This matches
            # logpt so the product stays element-wise (not an (N, N) broadcast).
            weights = self.weights.to(input.device)[target]
        else:
            weights = 1

        loss = weights * -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean()


class MeanFalseError_loss(nn.Module):
    """Per-sample mean squared error over dim 1."""

    def __init__(self):
        super(MeanFalseError_loss, self).__init__()

    def forward(self, input, target):
        loss = (input - target) ** 2
        return loss.mean(axis=1)


class MeanFalseError(nn.Module):
    """Sum over the batch of the per-sample MSE between softmax(input) and one-hot target."""

    def __init__(self, class_num=5):
        super(MeanFalseError, self).__init__()
        self.class_num = class_num
        self.MeanFalseError_loss = MeanFalseError_loss()

    def forward(self, input, target):
        input = F.softmax(input, dim=1)
        target = F.one_hot(target.long(), self.class_num).to(input.dtype)
        loss = self.MeanFalseError_loss(input, target)
        return loss.sum()


class MeanFalseError_squared(nn.Module):
    """Like MeanFalseError, but each per-sample error is squared before summing."""

    def __init__(self, class_num=5):
        super(MeanFalseError_squared, self).__init__()
        self.class_num = class_num
        self.MeanFalseError_loss = MeanFalseError_loss()

    def forward(self, input, target):
        input = F.softmax(input, dim=1)
        target = F.one_hot(target.long(), self.class_num).to(input.dtype)
        loss = self.MeanFalseError_loss(input, target) ** 2
        return loss.sum()


class dice_loss(nn.Module):
    """Squared soft-Dice loss over the whole batch.

    *input* holds logits (N, C, ...). *target* holds class indices (N, ...)
    or a one-hot tensor with the same shape as *input*.
    """

    def __init__(self, smooth=1.):
        super(dice_loss, self).__init__()
        self.smooth = smooth

    def forward(self, input, target):
        num = input.size(0)
        # Dice needs probabilities in [0, 1], not log-probabilities.
        probs = F.softmax(input, dim=1)
        if target.shape != input.shape:
            # Class indices -> one-hot, with the class axis moved to dim 1.
            one_hot = F.one_hot(target.long(), input.size(1))
            order = [0, one_hot.dim() - 1] + list(range(1, one_hot.dim() - 1))
            target = one_hot.permute(*order)
        m1 = probs.reshape(num, -1).float()
        m2 = target.reshape(num, -1).float()
        intersection = (m1 * m2).sum().float()
        loss = (1 - ((2. * intersection + self.smooth) / (m1.sum() + m2.sum() + self.smooth))) ** 2
        return loss.mean()


class dice_cross_loss(nn.Module):
    """0.9 * cross-entropy + 0.1 * dice_loss."""

    def __init__(self, smooth=1.):
        super(dice_cross_loss, self).__init__()
        self.loss_dice = dice_loss(smooth=smooth)

    def forward(self, input, target):
        dl = self.loss_dice(input, target)  # dice loss
        cel = F.cross_entropy(input, target)  # cross-entropy loss
        return 0.9 * cel + 0.1 * dl


def CB_focal_loss(labels, logits, alpha, gamma):
    """Class-balanced focal loss on sigmoid outputs.

    Args:
        labels: One-hot float targets, same shape as *logits*.
        logits: Raw scores.
        alpha: Weights broadcastable to *logits*.
        gamma: Focusing parameter; 0 disables the modulating factor.

    Returns:
        Sum of the weighted focal terms divided by the number of positive labels.
    """
    BCLoss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")
    if gamma == 0.0:
        modulator = 1.0
    else:
        # softplus(-x) == log(1 + exp(-x)), without overflow for large -x.
        modulator = torch.exp(-gamma * labels * logits - gamma * F.softplus(-logits))
    loss = modulator * BCLoss
    weighted_loss = alpha * loss
    focal_loss = torch.sum(weighted_loss)
    focal_loss /= torch.sum(labels)
    return focal_loss


# samples_per_cls: number of samples per class
# no_of_classes: number of classes
class CB_loss(nn.Module):
    """Class-balanced loss using effective-number re-weighting.

    Args:
        samples_per_cls: Per-class sample counts. With no_of_classes == 6,
            only the first 5 are used and class 5 gets weight 0.
        no_of_classes: Number of output classes.
        loss_type: 'focal', 'sigmoid' or 'softmax'.
        beta: Effective-number hyper-parameter.
        gamma: Focusing parameter for loss_type 'focal'.
    """

    def __init__(self, samples_per_cls, no_of_classes, loss_type, beta=0.9999, gamma=2.0):
        super(CB_loss, self).__init__()
        self.samples_per_cls = samples_per_cls
        self.no_of_classes = no_of_classes
        self.loss_type = loss_type
        self.beta = beta
        self.gamma = gamma

        if no_of_classes == 6:
            self.effective_num = 1.0 - np.power(beta, samples_per_cls[:5])
        else:
            self.effective_num = 1.0 - np.power(beta, samples_per_cls)
        print(self.effective_num)
        weights = (1.0 - beta) / np.array(self.effective_num)
        if no_of_classes == 6:
            weights = np.append(weights, 0.)  # class 5 does not contribute
        print(weights)
        weights = weights / np.sum(weights) * no_of_classes
        # Kept on the CPU here; moved to the logits' device in forward().
        self.weights = torch.tensor(weights).float()

    def forward(self, logits, labels):
        labels_one_hot = F.one_hot(labels, self.no_of_classes).float()  # one-hot encoding
        # Weight of each sample's true class, repeated across all classes: (N, C).
        weights = (self.weights.to(logits.device).unsqueeze(0) * labels_one_hot).sum(1, keepdim=True)
        weights = weights.repeat(1, self.no_of_classes)

        if self.loss_type == "focal":
            return CB_focal_loss(labels_one_hot, logits, weights, self.gamma)
        if self.loss_type == "sigmoid":
            return F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, weight=weights)
        if self.loss_type == "softmax":
            pred = logits.softmax(dim=1)
            return F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)
        raise ValueError(f'unknown loss_type: {self.loss_type!r}')


def reduce_loss(loss, reduction='mean'):
    """Reduce *loss* by 'mean', 'sum', or return it unchanged for any other value."""
    return loss.mean() if reduction == 'mean' else loss.sum() if reduction == 'sum' else loss


def linear_combination(x, y, epsilon):
    """Return epsilon * x + (1 - epsilon) * y."""
    return epsilon * x + (1 - epsilon) * y


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy with uniform label smoothing of strength *epsilon*."""

    def __init__(self, epsilon: float = 0.1, reduction='mean'):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, preds, target):
        n = preds.size()[-1]
        log_preds = F.log_softmax(preds, dim=-1)
        loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
        nll = F.nll_loss(log_preds, target, reduction=self.reduction)
        return linear_combination(loss / n, nll, self.epsilon)


class JSD(nn.Module):
    """Symmetric divergence between the softmax outputs of two networks.

    Computes 0.5 * (KL(m || p1) + KL(m || p2)) with m = (p1 + p2) / 2, using
    F.kl_div(log_input, target). NOTE(review): the textbook JSD is
    0.5 * (KL(p1 || m) + KL(p2 || m)); JSD_temperal uses the same form as here.
    """

    def __init__(self):
        super(JSD, self).__init__()

    def forward(self, net_1_logits, net_2_logits):
        net_1_probs = F.softmax(net_1_logits, dim=1)
        net_2_probs = F.softmax(net_2_logits, dim=1)

        m = 0.5 * (net_1_probs + net_2_probs)
        loss = 0.0
        loss += F.kl_div(F.log_softmax(net_1_logits, dim=1), m, reduction="batchmean")
        loss += F.kl_div(F.log_softmax(net_2_logits, dim=1), m, reduction="batchmean")
        return 0.5 * loss


class JSD_temperal(nn.Module):
    """JSD computed on temperature-softened (logits / T) distributions."""

    def __init__(self, T=3):
        super().__init__()
        self.T = T

    def forward(self, net_1_logits, net_2_logits):
        net_1_probs = F.softmax(net_1_logits / self.T, dim=1)
        net_2_probs = F.softmax(net_2_logits / self.T, dim=1)
        m = (net_1_probs + net_2_probs) / 2
        loss = 0.0
        loss += F.kl_div(F.log_softmax(net_1_logits / self.T, dim=1), m, reduction="batchmean")
        loss += F.kl_div(F.log_softmax(net_2_logits / self.T, dim=1), m, reduction="batchmean")
        return 0.5 * loss


class KL_divergence(nn.Module):
    """0.5 * KL(p2 || p1) between the softmax outputs of two networks."""

    def __init__(self):
        super(KL_divergence, self).__init__()

    def forward(self, net_1_logits, net_2_logits):
        net_2_probs = F.softmax(net_2_logits, dim=1)
        loss = F.kl_div(F.log_softmax(net_1_logits, dim=1), net_2_probs, reduction="batchmean")
        return 0.5 * loss


class KL_divergence_temperal(nn.Module):
    """0.5 * KL(p2 || p1) on temperature-softened (logits / T) distributions."""

    def __init__(self, T=3):
        super(KL_divergence_temperal, self).__init__()
        self.T = T

    def forward(self, net_1_logits, net_2_logits):
        net_2_probs = F.softmax(net_2_logits / self.T, dim=1)
        loss = F.kl_div(F.log_softmax(net_1_logits / self.T, dim=1), net_2_probs, reduction="batchmean")
        return 0.5 * loss


class SupConLoss(nn.Module):
    """Supervised contrastive loss over multi-view features.

    With *labels*, samples sharing a label are positives. With an explicit
    *mask*, mask[i, j] marks positives. With neither, only the other views of
    the same sample are positives (identity mask).
    """

    def __init__(self, temperature=0.07, contrast_mode='one',
                 base_temperature=0.07, epsilon=1e-6):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
        self.epsilon = epsilon

    def forward(self, features, labels=None, mask=None):
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)
        # features: [batch, n_views, feature_size]; features[:, i] is augmented view i.

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # [batch, n_views, feature] => [batch * n_views, feature]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)

        if self.contrast_mode == 'one':  # only the first view is used as anchor
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':  # every view is used as anchor
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # Similarity of every anchor with every contrast sample, divided by tau:
        # [batch * anchor_count, feature] @ [feature, batch * contrast_count].
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # Subtract the row max for numerical stability.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Tile the mask to [batch * anchor_count, batch * contrast_count].
        mask = mask.repeat(anchor_count, contrast_count)
        # Zero the self-contrast entries.
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # Denominator: sum over all non-self samples.
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood over positives; epsilon keeps anchors without
        # positives from producing 0 / 0.
        mean_log_prob_pos = ((mask * log_prob).sum(1) + self.epsilon) / (mask.sum(1) + self.epsilon)

        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()
        return loss

# ------------------------------------------------------------------------------
# /utils/scheduler.py:
# --------------------------------------------------------------------------------
# NOTE(review): this star import is expected to provide ``torch`` and ``plt``
# (matplotlib.pyplot) from the package ``__init__`` — confirm there.
from . import *


def _cycle_position(cur_iteration, n_restarts, warmup_iteration, cosine_decay_iter):
    """Return ``cur_iteration`` relative to the start of the current cycle.

    A cycle is ``warmup_iteration + cosine_decay_iter`` iterations long and
    ``n_restarts`` cycles have already been completed.
    """
    return cur_iteration - n_restarts * (warmup_iteration + cosine_decay_iter)


def _step_after_scheduler(after_scheduler, optimizer, target_lr, epoch):
    """Advance ``after_scheduler`` to ``epoch``.

    When no scheduler was supplied (``None``, the constructors' default) every
    param group is held at ``target_lr`` instead of raising ``AttributeError``.
    """
    if after_scheduler is None:
        for param_group in optimizer.param_groups:
            param_group['lr'] = target_lr
    else:
        after_scheduler.step(epoch)


class LearningRateWarmUP_restart_changeMax(object):
    """Linear warm-up + cosine decay with restarts and a shrinking peak lr.

    Each cycle is ``warmup_iteration + cosine_decay_iter`` iterations long.
    At cycle position 1 the ``initial_lr`` of param group ``two_param`` is set
    to ``target_lr`` and a fresh ``CosineAnnealingLR(optimizer,
    cosine_decay_iter)`` is built. Positions ``1 .. warmup_iteration - 1``
    warm group ``two_param`` up linearly; the following positions step the
    cosine scheduler (which itself updates every param group). At the last
    position the lr is left unchanged, the restart counter is bumped and
    ``target_lr`` is multiplied by ``gamma``.

    Args:
        optimizer: wrapped ``torch.optim`` optimizer.
        warmup_iteration: warm-up length per cycle.
        cosine_decay_iter: cosine-decay length per cycle.
        target_lr: peak learning rate of the first cycle.
        gamma: factor applied to the peak after every cycle.
        after_scheduler: replaced by a new ``CosineAnnealingLR`` whenever a
            cycle starts.
        two_param: index of the param group that is warmed up.
    """

    def __init__(self, optimizer, warmup_iteration, cosine_decay_iter, target_lr=0.1, gamma=0.8,
                 after_scheduler=None, two_param=0):
        self.optimizer = optimizer
        self.warmup_iteration = warmup_iteration
        self.target_lr = target_lr
        self.after_scheduler = after_scheduler
        self.cur_iteration_decay = 0  # number of completed cycles
        self.cosine_decay_iter = cosine_decay_iter
        self.gamma = gamma
        self.two_param = two_param

    def _position(self, cur_iteration):
        """Position of ``cur_iteration`` inside the current cycle."""
        return _cycle_position(cur_iteration, self.cur_iteration_decay,
                               self.warmup_iteration, self.cosine_decay_iter)

    def warmup_learning_rate(self, cur_iteration):
        """Set group ``two_param``'s lr to its linear warm-up value."""
        warmup_lr = self.target_lr * float(self._position(cur_iteration)) / float(self.warmup_iteration)
        for index, param_group in enumerate(self.optimizer.param_groups):
            if index == self.two_param:
                param_group['lr'] = warmup_lr

    def step(self, cur_iteration):
        """Update the lr for the 0-based iteration ``cur_iteration``."""
        cur_iteration += 1
        position = self._position(cur_iteration)
        if position == 1:
            # New cycle: anchor the cosine schedule at the (possibly decayed) peak.
            for index, param_group in enumerate(self.optimizer.param_groups):
                if index == self.two_param:
                    param_group['initial_lr'] = self.target_lr
            self.after_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, self.cosine_decay_iter)
        if position < self.warmup_iteration:
            self.warmup_learning_rate(cur_iteration)
        elif position < self.warmup_iteration + self.cosine_decay_iter:
            self.after_scheduler.step(position - self.warmup_iteration)
        else:
            # End of the cycle: shrink the peak for the next one.
            self.cur_iteration_decay += 1
            self.target_lr = self.target_lr * self.gamma


class LearningRateWarmUP_restart(object):
    """Linear warm-up followed by ``after_scheduler``, restarted every cycle.

    A cycle lasts ``warmup_iteration + cosine_decay_iter`` iterations:
    positions ``1 .. warmup_iteration - 1`` set every param group's lr to
    ``target_lr * position / warmup_iteration``; the remaining positions call
    ``after_scheduler.step(position - warmup_iteration)`` with epochs
    ``0 .. cosine_decay_iter``. All cycles have the same shape.

    Args:
        optimizer: wrapped ``torch.optim`` optimizer.
        warmup_iteration: warm-up length per cycle.
        cosine_decay_iter: length of the post-warm-up phase per cycle.
        target_lr: learning rate approached during warm-up.
        after_scheduler: scheduler driven after warm-up (``check_warmup`` uses
            ``CosineAnnealingLR`` with ``T_max = cosine_decay_iter + 1``);
            ``None`` holds the lr at ``target_lr``.
    """

    def __init__(self, optimizer, warmup_iteration, cosine_decay_iter, target_lr=0.1, after_scheduler=None):
        self.optimizer = optimizer
        self.warmup_iteration = warmup_iteration
        self.target_lr = target_lr
        self.after_scheduler = after_scheduler
        self.cur_iteration_decay = 0  # number of completed cycles
        self.cosine_decay_iter = cosine_decay_iter

    def _position(self, cur_iteration):
        """Position of ``cur_iteration`` inside the current cycle."""
        return _cycle_position(cur_iteration, self.cur_iteration_decay,
                               self.warmup_iteration, self.cosine_decay_iter)

    def warmup_learning_rate(self, cur_iteration):
        """Set every param group's lr to its linear warm-up value."""
        warmup_lr = self.target_lr * float(self._position(cur_iteration)) / float(self.warmup_iteration)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = warmup_lr

    def step(self, cur_iteration):
        """Update the lr for the 0-based iteration ``cur_iteration``."""
        cur_iteration += 1
        cycle_len = self.warmup_iteration + self.cosine_decay_iter
        position = self._position(cur_iteration)
        # Past the end of the cycle: restart on this very iteration. Bumping
        # the counter in a branch that left the lr untouched made every cycle
        # after the first skip warm-up position 1 and keep a stale lr once.
        while cycle_len > 0 and position > cycle_len:
            self.cur_iteration_decay += 1
            position -= cycle_len
        if position < self.warmup_iteration:
            self.warmup_learning_rate(cur_iteration)
        else:
            _step_after_scheduler(self.after_scheduler, self.optimizer,
                                  self.target_lr, position - self.warmup_iteration)


class LearningRateWarmUP(object):
    """One linear warm-up phase, then ``after_scheduler`` for the rest of training.

    Args:
        optimizer: wrapped ``torch.optim`` optimizer.
        warmup_iteration: number of warm-up iterations.
        target_lr: learning rate approached during warm-up.
        after_scheduler: scheduler stepped with ``cur_iteration -
            warmup_iteration`` after warm-up; ``None`` holds the lr at
            ``target_lr``.
    """

    def __init__(self, optimizer, warmup_iteration, target_lr, after_scheduler=None):
        self.optimizer = optimizer
        self.warmup_iteration = warmup_iteration
        self.target_lr = target_lr
        self.after_scheduler = after_scheduler

    def warmup_learning_rate(self, cur_iteration):
        """Set every param group's lr to ``target_lr * cur_iteration / warmup_iteration``."""
        warmup_lr = self.target_lr * float(cur_iteration) / float(self.warmup_iteration)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = warmup_lr

    def step(self, cur_iteration):
        """Update the lr for the 0-based iteration ``cur_iteration``."""
        cur_iteration += 1
        if cur_iteration < self.warmup_iteration:
            self.warmup_learning_rate(cur_iteration)
        else:
            _step_after_scheduler(self.after_scheduler, self.optimizer,
                                  self.target_lr, cur_iteration - self.warmup_iteration)


def check_warmup(save_path='/home/eslab/A.png'):
    """Print and plot the lr trajectory of ``LearningRateWarmUP_restart``.

    Drives a dummy SGD optimizer for 200 iterations and saves the plot to
    ``save_path``. Swap the scheduler below for
    ``LearningRateWarmUP_restart_changeMax`` or ``LearningRateWarmUP`` to
    inspect the other schedules.
    """
    v = torch.zeros(10)
    lr = 0.1
    total_iter = 200
    warmup_iter = 20
    cosine_decay_iter = 80

    optim = torch.optim.SGD([v], lr=lr)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optim, cosine_decay_iter + 1)
    scheduler = LearningRateWarmUP_restart(optimizer=optim,
                                           warmup_iteration=warmup_iter,
                                           cosine_decay_iter=cosine_decay_iter,
                                           target_lr=lr,
                                           after_scheduler=scheduler_cosine)
    x_iter = [0]
    y_lr = [0.]

    for it in range(total_iter):
        scheduler.step(it)
        print("iter: ", it, " ,lr: ", optim.param_groups[0]['lr'])

        optim.zero_grad()
        optim.step()

        x_iter.append(it)
        y_lr.append(optim.param_groups[0]['lr'])

    plt.plot(x_iter, y_lr, 'b')
    plt.legend(['learning rate'])
    plt.xlabel('iteration')
    plt.ylabel('learning rate')
    plt.savefig(save_path)


# check_warmup()
# --------------------------------------------------------------------------------