├── models ├── __init__.py ├── neuronet │ ├── __init__.py │ ├── resnet1d.py │ └── model.py ├── loss.py └── utils.py ├── dataset ├── __init__.py ├── utils.py └── selected_shhs1_files.txt ├── downstream ├── __init__.py ├── linear_prob.py └── fine_tuning.py ├── pretrained ├── __init__.py ├── data_loader.py └── train.py ├── figures ├── overview.jpg ├── hypnogram.jpg └── model_structure.jpg ├── requirements.txt ├── README.md ├── .gitignore └── LICENSE /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /downstream/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pretrained/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/neuronet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlcjfgmlnasa/NeuroNet/HEAD/figures/overview.jpg -------------------------------------------------------------------------------- /figures/hypnogram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlcjfgmlnasa/NeuroNet/HEAD/figures/hypnogram.jpg -------------------------------------------------------------------------------- /figures/model_structure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlcjfgmlnasa/NeuroNet/HEAD/figures/model_structure.jpg -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import os 3 | import numpy as np 4 | from sklearn.model_selection import KFold 5 | from scipy.signal import butter, lfilter 6 | 7 | 8 | def butter_bandpass_filter(signal, low_cut, high_cut, fs, order=5): 9 | if low_cut == 0: 10 | low_cut = 0.5 11 | nyq = 0.5 * fs 12 | low = low_cut / nyq 13 | high = high_cut / nyq 14 | b, a = butter(order, [low, high], btype='band') 15 | y = lfilter(b, a, signal, axis=-1) 16 | return y 17 | 18 | 19 | def split_train_test_val_files(base_path, n_splits=5): 20 | # Subject Variability 21 | files = os.listdir(base_path) 22 | files = np.array(files) 23 | 24 | size = len(files) 25 | print('File Path => ' + base_path) 26 | print('Total Subject Size => {}'.format(size)) 27 | kf = KFold(n_splits=n_splits) 28 | 29 | temp = {} 30 | for fold, (train_idx, test_idx) in enumerate(kf.split(files)): 31 | train_size = len(train_idx) 32 | val_point = int(train_size * 0.85) 33 | train_idx, val_idx = train_idx[:val_point], train_idx[val_point:] 34 | 35 | temp[fold] = { 36 | 'train_paths': list([os.path.join(base_path, f_name) for f_name in files[train_idx]]), 37 | 'ft_paths': list([os.path.join(base_path, f_name) for f_name in files[val_idx]]), 38 | 'eval_paths': 
list([os.path.join(base_path, f_name) for f_name in files[test_idx]]) 39 | } 40 | return temp 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2024.2.2 2 | charset-normalizer==3.3.2 3 | contourpy==1.2.1 4 | cycler==0.12.1 5 | decorator==5.1.1 6 | einops==0.8.0 7 | filelock==3.14.0 8 | fonttools==4.51.0 9 | fsspec==2024.3.1 10 | huggingface-hub==0.23.0 11 | idna==3.7 12 | importlib_resources==6.4.0 13 | Jinja2==3.1.4 14 | joblib==1.4.2 15 | kiwisolver==1.4.5 16 | lazy_loader==0.4 17 | mamba-ssm==1.2.0.post1 18 | MarkupSafe==2.1.5 19 | matplotlib==3.8.4 20 | mne==1.7.0 21 | mpmath==1.3.0 22 | networkx==3.2.1 23 | ninja==1.11.1.1 24 | numpy==1.26.4 25 | nvidia-cublas-cu12==12.1.3.1 26 | nvidia-cuda-cupti-cu12==12.1.105 27 | nvidia-cuda-nvrtc-cu12==12.1.105 28 | nvidia-cuda-runtime-cu12==12.1.105 29 | nvidia-cudnn-cu12==8.9.2.26 30 | nvidia-cufft-cu12==11.0.2.54 31 | nvidia-curand-cu12==10.3.2.106 32 | nvidia-cusolver-cu12==11.4.5.107 33 | nvidia-cusparse-cu12==12.1.0.106 34 | nvidia-nccl-cu12==2.19.3 35 | nvidia-nvjitlink-cu12==12.4.127 36 | nvidia-nvtx-cu12==12.1.105 37 | packaging==24.0 38 | pillow==10.3.0 39 | platformdirs==4.2.1 40 | pooch==1.8.1 41 | pyparsing==3.1.2 42 | python-dateutil==2.9.0.post0 43 | PyYAML==6.0.1 44 | regex==2024.5.10 45 | requests==2.31.0 46 | safetensors==0.4.3 47 | scikit-learn==1.4.2 48 | scipy==1.13.0 49 | six==1.16.0 50 | sympy==1.12 51 | threadpoolctl==3.5.0 52 | timm==0.9.16 53 | tokenizers==0.19.1 54 | torch==2.2.2 55 | torchvision==0.17.2 56 | tqdm==4.66.4 57 | transformers==4.40.2 58 | triton==2.2.0 59 | typing_extensions==4.11.0 60 | urllib3==2.2.1 61 | zipp==3.18.1 62 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as f 6 | 7 | 8 | class NTXentLoss(nn.Module): 9 | def __init__(self, temperature): 10 | super().__init__() 11 | self.criterion = nn.CrossEntropyLoss(reduction='sum') 12 | self.similarity_f = nn.CosineSimilarity(dim=-1) 13 | self.temperature = temperature 14 | 15 | @staticmethod 16 | def mask_correlated_samples(batch_size): 17 | n = 2 * batch_size 18 | mask = torch.ones((n, n), dtype=bool) 19 | mask = mask.fill_diagonal_(0) 20 | 21 | for i in range(batch_size): 22 | mask[i, batch_size + i] = 0 23 | mask[batch_size + i, i] = 0 24 | return mask 25 | 26 | def forward(self, z_i, z_j): 27 | batch_size = z_j.shape[0] 28 | n = 2 * batch_size 29 | z = torch.cat((z_i, z_j), dim=0) 30 | z = f.normalize(z, dim=-1) 31 | 32 | mask = self.mask_correlated_samples(batch_size) 33 | sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) 34 | 35 | sim_i_j = torch.diag(sim, batch_size) 36 | sim_j_i = torch.diag(sim, -batch_size) 37 | 38 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(n, 1) 39 | negative_samples = sim[mask].reshape(n, -1) 40 | 41 | labels = torch.from_numpy(np.array([0] * n)).reshape(-1).to(positive_samples.device).long() # .float() 42 | logits = torch.cat((positive_samples, negative_samples), dim=1) / self.temperature 43 | 44 | loss = self.criterion(logits, labels) 45 | loss /= n 46 | return loss, (labels, logits) 47 | -------------------------------------------------------------------------------- /pretrained/data_loader.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import mne 3 | import torch 4 | import random 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) 10 | 11 | random_seed = 777 12 | np.random.seed(random_seed) 13 | torch.manual_seed(random_seed) 14 | random.seed(random_seed) 15 | 16 | 17 | class TorchDataset(Dataset): 18 | def __init__(self, paths, sfreq, rfreq, scaler: bool = False): 19 | super().__init__() 20 | self.x, self.y = self.get_data(paths, sfreq, rfreq, scaler) 21 | self.x, self.y = torch.tensor(self.x, dtype=torch.float32), torch.tensor(self.y, dtype=torch.long) 22 | 23 | @staticmethod 24 | def get_data(paths, sfreq, rfreq, scaler_flag): 25 | info = mne.create_info(sfreq=sfreq, ch_types='eeg', ch_names=['Fp1']) 26 | scaler = mne.decoding.Scaler(info=info, scalings='median') 27 | total_x, total_y = [], [] 28 | for path in paths: 29 | data = np.load(path) 30 | x, y = data['x'], data['y'] 31 | x = np.expand_dims(x, axis=1) 32 | if scaler_flag: 33 | x = scaler.fit_transform(x) 34 | x = mne.EpochsArray(x, info=info) 35 | x = x.resample(rfreq) 36 | x = x.get_data().squeeze() 37 | total_x.append(x) 38 | total_y.append(y) 39 | total_x, total_y = np.concatenate(total_x), np.concatenate(total_y) 40 | return total_x, total_y 41 | 42 | def __len__(self): 43 | return len(self.y) 44 | 45 | def __getitem__(self, item): 46 | x = torch.tensor(self.x[item]) 47 | y = torch.tensor(self.y[item]) 48 | return x, y 49 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 7 | grid = np.arange(grid_size, dtype=float) 8 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) 9 | if cls_token: 10 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 11 | return pos_embed 12 | 13 | 14 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 15 | assert embed_dim % 2 == 0 16 | omega = np.arange(embed_dim // 2, dtype=float) 17 | omega /= embed_dim / 2. 18 | omega = 1. 
/ 10000 ** omega # (D/2,) 19 | 20 | pos = pos.reshape(-1) # (M,) 21 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 22 | 23 | emb_sin = np.sin(out) # (M, D/2) 24 | emb_cos = np.cos(out) # (M, D/2) 25 | 26 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 27 | return emb 28 | 29 | 30 | def get_2d_sincos_pos_embed(embed_dim, grid_sizes, cls_token=False): 31 | """ 32 | grid_size: int of the grid height and width 33 | return: 34 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 35 | """ 36 | grid_h = np.arange(grid_sizes[0], dtype=np.float32) 37 | grid_w = np.arange(grid_sizes[1], dtype=np.float32) 38 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 39 | grid = np.stack(grid, axis=0) 40 | 41 | grid = grid.reshape([2, 1, grid_sizes[0], grid_sizes[1]]) 42 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 43 | if cls_token: 44 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 45 | return pos_embed 46 | 47 | 48 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 49 | assert embed_dim % 2 == 0 50 | 51 | # use half of dimensions to encode grid_h 52 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 53 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 54 | 55 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 56 | return emb 57 | 58 | 59 | def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False): 60 | """ 61 | grid_size: int of the grid height and width 62 | return: 63 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 64 | """ 65 | grid_h = np.arange(grid_size[0], dtype=np.float32) 66 | grid_w = np.arange(grid_size[1], dtype=np.float32) 67 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 68 | grid = np.stack(grid, axis=0) 69 | 70 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) 71 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 72 | if cls_token: 73 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 74 | return pos_embed 75 | 76 | 77 | def model_size(model): 78 | size_model = 0 79 | for param in model.parameters(): 80 | if param.data.is_floating_point(): 81 | size_model += param.numel() * torch.finfo(param.data.dtype).bits 82 | else: 83 | size_model += param.numel() * torch.iinfo(param.data.dtype).bits 84 | mb_size = size_model / 8e6 85 | return mb_size 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeuroNet: A Novel Hybrid Self-Supervised Learning Framework for Sleep Stage Classification Using Single-Channel EEG 2 | 3 | >**Choel-Hui Lee, Hakseung Kim, Hyun-jee Han, Min-Kyung Jung, Byung C. Yoon and Dong-Joo Kim** 4 | 5 | [[`Paper`](https://arxiv.org/abs/2404.17585)] [[`Paper with Code`](https://paperswithcode.com/paper/neuronet-a-novel-hybrid-self-supervised)] [[`BibTeX`](#license-and-citation)] 6 | 7 | **Full code coming soon^^** 8 | 9 | ![neuronet structure](https://github.com/dlcjfgmlnasa/NeuroNet/blob/main/figures/model_structure.jpg) 10 | 11 | ## Introduction 🔥 12 | The classification of sleep stages is a pivotal aspect of diagnosing sleep disorders and evaluating sleep quality. However, the conventional manual scoring process, conducted by clinicians, is time-consuming and prone to human bias. 
Recent advancements in deep learning have substantially propelled the automation of sleep stage classification. Nevertheless, challenges persist, including the need for large datasets with labels and the inherent biases in human-generated annotations. This paper introduces NeuroNet, a self-supervised learning (SSL) framework designed to effectively harness unlabeled single-channel sleep electroencephalogram (EEG) signals by integrating contrastive learning tasks and masked prediction tasks. NeuroNet demonstrates superior performance over existing SSL methodologies through extensive experimentation conducted across three polysomnography (PSG) datasets. Additionally, this study proposes a Mamba-based temporal context module to capture the relationships among diverse EEG epochs. Combining NeuroNet with the Mamba-based temporal context module has demonstrated the capability to achieve, or even surpass, the performance of the latest supervised learning methodologies, even with a limited amount of labeled data. This study is expected to establish a new benchmark in sleep stage classification, promising to guide future research and applications in the field of sleep analysis. 13 | 14 | ## Main Result 🥇 15 | 16 |
17 | Performance on Sleep-EDFX across various self-supervised and supervised learning methods 18 |

[figure: performance comparison on Sleep-EDFX]

19 |
20 | 21 |
22 | The output hypnograms across five sleep stages. 23 |

image

24 | The first, second, and third columns correspond to #sc4031e0, #shhs1-204928, and #subject-53 within Sleep-EDFX, SHHS, and ISRUC, respectively. (A) is manually scored by a sleep expert. (B) and (C) respectively represent NeuroNet-B and NeuroNet-T. The first row for both (B) and (C) displays the results for NeuroNet+TCM, while the second row shows the results for NeuroNet. The errors are marked by the red dots. 25 |
26 | 27 | 28 | ## License and Citation 📰 29 | The software is licensed under the Apache License 2.0. Please cite the following paper if you have used this code: 30 | ``` 31 | @misc{lee2024neuronet, 32 | title={NeuroNet: A Novel Hybrid Self-Supervised Learning Framework for Sleep Stage Classification Using Single-Channel EEG}, 33 | author={Cheol-Hui Lee and Hakseung Kim and Hyun-jee Han and Min-Kyung Jung and Byung C. Yoon and Dong-Joo Kim}, 34 | year={2024}, 35 | eprint={2404.17585}, 36 | archivePrefix={arXiv}, 37 | primaryClass={cs.HC} 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | .idea/ 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | -------------------------------------------------------------------------------- /dataset/selected_shhs1_files.txt: -------------------------------------------------------------------------------- 1 | shhs1-200010 2 | shhs1-200017 3 | shhs1-200039 4 | shhs1-200052 5 | shhs1-200097 6 | shhs1-200122 7 | shhs1-200139 8 | shhs1-200152 9 | shhs1-200166 10 | shhs1-200174 11 | shhs1-200178 12 | shhs1-200214 13 | shhs1-200215 14 | shhs1-200258 15 | shhs1-200319 16 | shhs1-200336 17 | shhs1-200348 18 | shhs1-200349 19 | shhs1-200383 20 | shhs1-200390 21 | shhs1-200408 22 | shhs1-200409 23 | shhs1-200437 24 | shhs1-200466 25 | shhs1-200467 26 | shhs1-200477 27 | shhs1-200496 28 | shhs1-200516 29 | shhs1-200536 30 | shhs1-200558 31 | shhs1-200569 32 | shhs1-200606 33 | shhs1-200631 34 | shhs1-200645 35 | shhs1-200665 36 | shhs1-200691 37 | shhs1-200701 38 | shhs1-200702 39 | shhs1-200725 40 | shhs1-200737 41 | shhs1-200741 42 | shhs1-200771 43 | shhs1-200780 44 | shhs1-200783 45 | shhs1-200785 46 | shhs1-200793 47 | shhs1-200805 48 | shhs1-200823 49 | shhs1-200834 50 | shhs1-200843 51 | shhs1-200853 52 | shhs1-200869 53 | shhs1-200870 54 | shhs1-200879 55 | shhs1-200884 56 | shhs1-200886 57 | shhs1-200895 58 | shhs1-200897 59 | shhs1-200902 60 | shhs1-200914 61 | shhs1-200920 62 | shhs1-200922 63 | shhs1-200924 64 | shhs1-200953 65 | shhs1-200964 66 | shhs1-200975 67 | shhs1-200983 68 | shhs1-201021 69 | shhs1-201023 70 | shhs1-201081 71 | shhs1-201087 72 | shhs1-201090 73 | shhs1-201102 74 | shhs1-201117 75 | shhs1-201118 76 | shhs1-201125 77 | shhs1-201153 78 | shhs1-201171 79 | shhs1-201180 80 | shhs1-201237 81 | shhs1-201243 82 | shhs1-201254 83 | shhs1-201270 84 | shhs1-201316 85 | shhs1-201329 86 | shhs1-201359 87 | shhs1-201371 88 | shhs1-201515 89 | shhs1-201552 90 | shhs1-201560 91 | shhs1-201566 92 | shhs1-201581 93 | shhs1-201586 94 | shhs1-201637 95 | shhs1-201706 96 | shhs1-201725 97 | shhs1-201748 98 | shhs1-201792 99 | shhs1-201839 100 | shhs1-201917 101 | shhs1-201946 102 | shhs1-201986 103 | shhs1-202003 104 | shhs1-202039 105 | shhs1-202054 106 | shhs1-202097 107 | shhs1-202108 108 | shhs1-202139 109 | shhs1-202160 110 | shhs1-202222 111 | shhs1-202240 112 | shhs1-202253 113 | shhs1-202257 114 | shhs1-202261 115 | shhs1-202267 116 | shhs1-202282 117 | shhs1-202293 118 | shhs1-202392 119 | shhs1-202409 120 | shhs1-202424 121 | shhs1-202516 122 | 
shhs1-202552 123 | shhs1-202589 124 | shhs1-202636 125 | shhs1-202650 126 | shhs1-202675 127 | shhs1-202694 128 | shhs1-202715 129 | shhs1-202723 130 | shhs1-202735 131 | shhs1-202738 132 | shhs1-202865 133 | shhs1-202907 134 | shhs1-202911 135 | shhs1-202940 136 | shhs1-202956 137 | shhs1-202983 138 | shhs1-202987 139 | shhs1-202996 140 | shhs1-203027 141 | shhs1-203040 142 | shhs1-203106 143 | shhs1-203137 144 | shhs1-203168 145 | shhs1-203171 146 | shhs1-203173 147 | shhs1-203180 148 | shhs1-203185 149 | shhs1-203203 150 | shhs1-203218 151 | shhs1-203233 152 | shhs1-203250 153 | shhs1-203254 154 | shhs1-203286 155 | shhs1-203312 156 | shhs1-203324 157 | shhs1-203350 158 | shhs1-203358 159 | shhs1-203380 160 | shhs1-203381 161 | shhs1-203389 162 | shhs1-203409 163 | shhs1-203410 164 | shhs1-203421 165 | shhs1-203433 166 | shhs1-203460 167 | shhs1-203478 168 | shhs1-203479 169 | shhs1-203514 170 | shhs1-203523 171 | shhs1-203533 172 | shhs1-203546 173 | shhs1-203547 174 | shhs1-203580 175 | shhs1-203614 176 | shhs1-203622 177 | shhs1-203651 178 | shhs1-203655 179 | shhs1-203687 180 | shhs1-203691 181 | shhs1-203699 182 | shhs1-203726 183 | shhs1-203734 184 | shhs1-203746 185 | shhs1-203791 186 | shhs1-203824 187 | shhs1-203845 188 | shhs1-203847 189 | shhs1-203850 190 | shhs1-203862 191 | shhs1-203870 192 | shhs1-203884 193 | shhs1-203889 194 | shhs1-203897 195 | shhs1-203901 196 | shhs1-203941 197 | shhs1-203965 198 | shhs1-203971 199 | shhs1-203997 200 | shhs1-204009 201 | shhs1-204027 202 | shhs1-204028 203 | shhs1-204041 204 | shhs1-204042 205 | shhs1-204047 206 | shhs1-204051 207 | shhs1-204062 208 | shhs1-204074 209 | shhs1-204083 210 | shhs1-204094 211 | shhs1-204103 212 | shhs1-204106 213 | shhs1-204107 214 | shhs1-204111 215 | shhs1-204115 216 | shhs1-204132 217 | shhs1-204134 218 | shhs1-204145 219 | shhs1-204148 220 | shhs1-204149 221 | shhs1-204165 222 | shhs1-204187 223 | shhs1-204189 224 | shhs1-204219 225 | shhs1-204230 226 | shhs1-204234 227 | shhs1-204282 228 | shhs1-204283 229 | shhs1-204294 230 | shhs1-204298 231 | shhs1-204299 232 | shhs1-204317 233 | shhs1-204318 234 | shhs1-204350 235 | shhs1-204351 236 | shhs1-204355 237 | shhs1-204401 238 | shhs1-204405 239 | shhs1-204431 240 | shhs1-204434 241 | shhs1-204453 242 | shhs1-204473 243 | shhs1-204486 244 | shhs1-204494 245 | shhs1-204500 246 | shhs1-204509 247 | shhs1-204519 248 | shhs1-204529 249 | shhs1-204543 250 | shhs1-204550 251 | shhs1-204553 252 | shhs1-204554 253 | shhs1-204576 254 | shhs1-204611 255 | shhs1-204641 256 | shhs1-204666 257 | shhs1-204684 258 | shhs1-204688 259 | shhs1-204706 260 | shhs1-204710 261 | shhs1-204714 262 | shhs1-204724 263 | shhs1-204731 264 | shhs1-204737 265 | shhs1-204741 266 | shhs1-204743 267 | shhs1-204745 268 | shhs1-204748 269 | shhs1-204751 270 | shhs1-204754 271 | shhs1-204775 272 | shhs1-204777 273 | shhs1-204781 274 | shhs1-204783 275 | shhs1-204785 276 | shhs1-204787 277 | shhs1-204796 278 | shhs1-204802 279 | shhs1-204811 280 | shhs1-204817 281 | shhs1-204818 282 | shhs1-204826 283 | shhs1-204830 284 | shhs1-204831 285 | shhs1-204836 286 | shhs1-204844 287 | shhs1-204846 288 | shhs1-204851 289 | shhs1-204865 290 | shhs1-204873 291 | shhs1-204879 292 | shhs1-204887 293 | shhs1-204890 294 | shhs1-204894 295 | shhs1-204905 296 | shhs1-204908 297 | shhs1-204911 298 | shhs1-204916 299 | shhs1-204927 300 | shhs1-204928 301 | shhs1-204934 302 | shhs1-204980 303 | shhs1-204985 304 | shhs1-205045 305 | shhs1-205049 306 | shhs1-205071 307 | shhs1-205083 308 | shhs1-205146 309 | 
shhs1-205148 310 | shhs1-205178 311 | shhs1-205207 312 | shhs1-205213 313 | shhs1-205238 314 | shhs1-205244 315 | shhs1-205245 316 | shhs1-205257 317 | shhs1-205380 318 | shhs1-205429 319 | shhs1-205451 320 | shhs1-205477 321 | shhs1-205502 322 | shhs1-205516 323 | shhs1-205582 324 | shhs1-205610 325 | shhs1-205676 326 | shhs1-205700 327 | shhs1-205741 328 | shhs1-205780 329 | shhs1-205789 -------------------------------------------------------------------------------- /downstream/linear_prob.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import os 3 | import mne 4 | import torch 5 | import random 6 | import argparse 7 | import warnings 8 | import numpy as np 9 | import torch.nn as nn 10 | from typing import List 11 | import torch.optim as opt 12 | from torch.utils.data import Dataset, DataLoader 13 | from sklearn.metrics import accuracy_score, f1_score 14 | from models.neuronet.model import NeuroNet, NeuroNetEncoderWrapper 15 | 16 | 17 | warnings.filterwarnings(action='ignore') 18 | 19 | 20 | random_seed = 777 21 | np.random.seed(random_seed) 22 | torch.manual_seed(random_seed) 23 | random.seed(random_seed) 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | 27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 28 | 29 | 30 | def get_args(): 31 | file_name = 'mini' 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--n_fold', default=0, choices=[0, 1, 2, 3, 4]) 34 | parser.add_argument('--ckpt_path', default=os.path.join('..', '..', '..', 'ckpt', 35 | 'ISRUC-Sleep', 'cm_eeg', file_name), type=str) 36 | parser.add_argument('--epochs', default=300, type=int) 37 | parser.add_argument('--batch_size', default=512, type=int) 38 | parser.add_argument('--lr', default=0.00005, type=float) 39 | return parser.parse_args() 40 | 41 | 42 | class Classifier(nn.Module): 43 | def __init__(self, backbone, backbone_final_length): 44 | super().__init__() 45 | self.backbone = self.freeze_backbone(backbone) 46 | self.backbone_final_length = backbone_final_length 47 | self.feature_num = self.backbone_final_length * 2 48 | self.dropout_p = 0.5 49 | self.fc = nn.Sequential( 50 | nn.Linear(backbone_final_length, self.feature_num), 51 | nn.BatchNorm1d(self.feature_num), 52 | nn.ELU(), 53 | nn.Dropout(self.dropout_p), 54 | nn.Linear(self.feature_num, 5) 55 | ) 56 | 57 | def forward(self, x): 58 | x = self.backbone(x) 59 | x = self.fc(x) 60 | return x 61 | 62 | @staticmethod 63 | def freeze_backbone(backbone: nn.Module): 64 | for name, module in backbone.named_modules(): 65 | for param in module.parameters(): 66 | param.requires_grad = False 67 | return backbone 68 | 69 | 70 | class Trainer(object): 71 | def __init__(self, args): 72 | super().__init__() 73 | self.args = args 74 | self.ckpt_path = os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'model', 'best_model.pth') 75 | self.ckpt = torch.load(self.ckpt_path, map_location='cpu') 76 | self.sfreq, self.rfreq = self.ckpt['hyperparameter']['sfreq'], self.ckpt['hyperparameter']['rfreq'] 77 | self.ft_paths, self.eval_paths = self.ckpt['paths']['ft_paths'], self.ckpt['paths']['eval_paths'] 78 | self.model = self.get_pretrained_model().to(device) 79 | self.optimizer = opt.AdamW(self.model.parameters(), lr=self.args.lr) 80 | self.scheduler = opt.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.args.epochs) 81 | self.criterion = nn.CrossEntropyLoss() 82 | 83 | def train(self): 84 | print('Checkpoint File 
Path : {}'.format(self.ckpt_path)) 85 | train_dataset = TorchDataset(paths=self.ft_paths, sfreq=self.sfreq, rfreq=self.rfreq) 86 | train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.args.batch_size, 87 | shuffle=True, drop_last=True) 88 | eval_dataset = TorchDataset(paths=self.eval_paths, sfreq=self.sfreq, rfreq=self.rfreq) 89 | eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=self.args.batch_size, drop_last=False) 90 | 91 | best_model_state, best_mf1 = None, 0.0 92 | best_pred, best_real = [], [] 93 | 94 | for epoch in range(self.args.epochs): 95 | self.model.train() 96 | epoch_train_loss = [] 97 | for data in train_dataloader: 98 | self.optimizer.zero_grad() 99 | x, y = data 100 | x, y = x.to(device), y.to(device) 101 | 102 | pred = self.model(x) 103 | loss = self.criterion(pred, y) 104 | 105 | epoch_train_loss.append(float(loss.detach().cpu().item())) 106 | loss.backward() 107 | self.optimizer.step() 108 | 109 | self.model.eval() 110 | epoch_test_loss = [] 111 | epoch_real, epoch_pred = [], [] 112 | for data in eval_dataloader: 113 | with torch.no_grad(): 114 | x, y = data 115 | x, y = x.to(device), y.to(device) 116 | pred = self.model(x) 117 | loss = self.criterion(pred, y) 118 | pred = pred.argmax(dim=-1) 119 | real = y 120 | 121 | epoch_real.extend(list(real.detach().cpu().numpy())) 122 | epoch_pred.extend(list(pred.detach().cpu().numpy())) 123 | epoch_test_loss.append(float(loss.detach().cpu().item())) 124 | 125 | epoch_train_loss, epoch_test_loss = np.mean(epoch_train_loss), np.mean(epoch_test_loss) 126 | eval_acc, eval_mf1 = accuracy_score(y_true=epoch_real, y_pred=epoch_pred), \ 127 | f1_score(y_true=epoch_real, y_pred=epoch_pred, average='macro') 128 | 129 | print('[Epoch] : {0:03d} \t ' 130 | '[Train Loss] => {1:.4f} \t ' 131 | '[Evaluation Loss] => {2:.4f} \t ' 132 | '[Evaluation Accuracy] => {3:.4f} \t' 133 | '[Evaluation Macro-F1] => {4:.4f}'.format(epoch + 1, epoch_train_loss, epoch_test_loss, 134 | eval_acc, eval_mf1)) 135 | 136 | if best_mf1 < eval_mf1: 137 | best_mf1 = eval_mf1 138 | best_model_state = self.model.state_dict() 139 | best_pred, best_real = epoch_pred, epoch_real 140 | 141 | self.scheduler.step() 142 | 143 | self.save_ckpt(best_model_state, best_pred, best_real) 144 | 145 | def save_ckpt(self, model_state, pred, real): 146 | if not os.path.exists(os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'linear_prob')): 147 | os.makedirs(os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'linear_prob')) 148 | 149 | save_path = os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'linear_prob', 'best_model.pth') 150 | torch.save({ 151 | 'backbone_name': 'NeuroNet_LinearProb', 152 | 'model_state': model_state, 153 | 'hyperparameter': self.args.__dict__, 154 | 'result': {'real': real, 'pred': pred}, 155 | 'paths': {'train_paths': self.ft_paths, 'eval_paths': self.eval_paths} 156 | }, save_path) 157 | 158 | def get_pretrained_model(self): 159 | # 1. Prepared Pretrained Model 160 | model_parameter = self.ckpt['model_parameter'] 161 | pretrained_model = NeuroNet(**model_parameter) 162 | pretrained_model.load_state_dict(self.ckpt['model_state']) 163 | 164 | # 2. 
Encoder Wrapper 165 | backbone = NeuroNetEncoderWrapper( 166 | fs=model_parameter['fs'], second=model_parameter['second'], 167 | time_window=model_parameter['time_window'], time_step=model_parameter['time_step'], 168 | frame_backbone=pretrained_model.frame_backbone, 169 | patch_embed=pretrained_model.autoencoder.patch_embed, 170 | encoder_block=pretrained_model.autoencoder.encoder_block, 171 | encoder_norm=pretrained_model.autoencoder.encoder_norm, 172 | cls_token=pretrained_model.autoencoder.cls_token, 173 | pos_embed=pretrained_model.autoencoder.pos_embed, 174 | final_length=pretrained_model.autoencoder.embed_dim 175 | ) 176 | 177 | # 3. Generator Classifier 178 | model = Classifier(backbone=backbone, 179 | backbone_final_length=pretrained_model.autoencoder.embed_dim) 180 | return model 181 | 182 | 183 | class TorchDataset(Dataset): 184 | def __init__(self, paths: List, sfreq: int, rfreq: int): 185 | self.paths = paths 186 | self.info = mne.create_info(sfreq=sfreq, ch_types='eeg', ch_names=['Fp1']) 187 | self.xs, self.ys = self.get_data(rfreq) 188 | 189 | def __len__(self): 190 | return self.xs.shape[0] 191 | 192 | def get_data(self, rfreq): 193 | xs, ys = [], [] 194 | for path in self.paths: 195 | data = np.load(path) 196 | x, y = data['x'], data['y'] 197 | x = np.expand_dims(x, axis=1) 198 | x = mne.EpochsArray(x, info=self.info) 199 | x = x.resample(rfreq) 200 | x = x.get_data().squeeze() 201 | xs.append(x) 202 | ys.append(y) 203 | xs = np.concatenate(xs, axis=0) 204 | ys = np.concatenate(ys, axis=0) 205 | return xs, ys 206 | 207 | def __getitem__(self, idx): 208 | x = torch.tensor(self.xs[idx], dtype=torch.float) 209 | y = torch.tensor(self.ys[idx], dtype=torch.long) 210 | return x, y 211 | 212 | 213 | if __name__ == '__main__': 214 | augments = get_args() 215 | for n_fold in range(10): 216 | augments.n_fold = n_fold 217 | trainer = Trainer(augments) 218 | trainer.train() 219 | -------------------------------------------------------------------------------- /models/neuronet/resnet1d.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class FrameBackBone(nn.Module): 7 | def __init__(self, fs: int, window: int): 8 | super().__init__() 9 | self.model = BackBone(input_size=fs * window, input_channel=1, layers=[1, 1, 1, 1]) 10 | self.feature_num = self.model.get_final_length() // 2 11 | self.feature_layer = nn.Sequential( 12 | nn.Linear(self.model.get_final_length(), self.feature_num), 13 | nn.ELU(), 14 | nn.Linear(self.feature_num, self.feature_num) 15 | ) 16 | 17 | def forward(self, x): 18 | latent_seq = [] 19 | for i in range(x.shape[1]): 20 | sample = torch.unsqueeze(x[:, i, :], dim=1) 21 | latent = self.model(sample) 22 | latent_seq.append(latent) 23 | latent_seq = torch.stack(latent_seq, dim=1) 24 | latent_seq = self.feature_layer(latent_seq) 25 | return latent_seq 26 | 27 | 28 | class BackBone(nn.Module): 29 | def __init__(self, input_size, input_channel, layers): 30 | super().__init__() 31 | self.inplanes3 = 32 32 | self.inplanes5 = 32 33 | self.inplanes7 = 32 34 | 35 | self.input_size = input_size 36 | self.conv1 = nn.Conv1d(input_channel, 32, kernel_size=7, stride=2, padding=3, bias=False) 37 | self.bn1 = nn.BatchNorm1d(32) 38 | self.relu = nn.ELU(inplace=True) 39 | self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 40 | 41 | self.layer3x3_1 = self._make_layer3(BasicBlock3x3, 32, layers[0], stride=1) 42 | self.layer3x3_2 = 
self._make_layer3(BasicBlock3x3, 32, layers[1], stride=1) 43 | self.layer3x3_3 = self._make_layer3(BasicBlock3x3, 48, layers[2], stride=2) 44 | self.layer3x3_4 = self._make_layer3(BasicBlock3x3, 64, layers[3], stride=2) 45 | self.maxpool3 = nn.AvgPool1d(kernel_size=16, stride=1, padding=0) 46 | 47 | self.layer5x5_1 = self._make_layer5(BasicBlock5x5, 32, layers[0], stride=1) 48 | self.layer5x5_2 = self._make_layer5(BasicBlock5x5, 32, layers[1], stride=1) 49 | self.layer5x5_3 = self._make_layer5(BasicBlock5x5, 48, layers[2], stride=2) 50 | self.layer5x5_4 = self._make_layer5(BasicBlock5x5, 64, layers[3], stride=2) 51 | self.maxpool5 = nn.AvgPool1d(kernel_size=11, stride=1, padding=0) 52 | 53 | self.layer7x7_1 = self._make_layer7(BasicBlock7x7, 32, layers[0], stride=1) 54 | self.layer7x7_2 = self._make_layer7(BasicBlock7x7, 32, layers[1], stride=1) 55 | self.layer7x7_3 = self._make_layer7(BasicBlock7x7, 48, layers[2], stride=2) 56 | self.layer7x7_4 = self._make_layer7(BasicBlock7x7, 64, layers[3], stride=2) 57 | self.maxpool7 = nn.AvgPool1d(kernel_size=6, stride=1, padding=0) 58 | 59 | def forward(self, x0): 60 | b = x0.shape[0] 61 | x0 = self.conv1(x0) 62 | x0 = self.bn1(x0) 63 | x0 = self.relu(x0) 64 | x0 = self.maxpool(x0) 65 | 66 | x1 = self.layer3x3_1(x0) 67 | x1 = self.layer3x3_2(x1) 68 | x1 = self.layer3x3_3(x1) 69 | x1 = self.layer3x3_4(x1) 70 | x1 = self.maxpool3(x1) 71 | 72 | x2 = self.layer5x5_1(x0) 73 | x2 = self.layer5x5_2(x2) 74 | x2 = self.layer5x5_3(x2) 75 | x2 = self.layer5x5_4(x2) 76 | x2 = self.maxpool5(x2) 77 | 78 | x3 = self.layer7x7_1(x0) 79 | x3 = self.layer7x7_2(x3) 80 | x3 = self.layer7x7_3(x3) 81 | x3 = self.layer7x7_4(x3) 82 | x3 = self.maxpool7(x3) 83 | 84 | out = torch.cat([x1, x2, x3], dim=-1) 85 | out = torch.reshape(out, [b, -1]) 86 | return out 87 | 88 | def _make_layer3(self, block, planes, blocks, stride=2): 89 | downsample = None 90 | if stride != 1 or self.inplanes3 != planes * block.expansion: 91 | downsample = nn.Sequential( 92 | nn.Conv1d(self.inplanes3, planes * block.expansion, 93 | kernel_size=1, stride=stride, bias=False), 94 | nn.BatchNorm1d(planes * block.expansion), 95 | ) 96 | 97 | layers = list() 98 | layers.append(block(self.inplanes3, planes, stride, downsample)) 99 | self.inplanes3 = planes * block.expansion 100 | for i in range(1, blocks): 101 | layers.append(block(self.inplanes3, planes)) 102 | 103 | return nn.Sequential(*layers) 104 | 105 | def _make_layer5(self, block, planes, blocks, stride=2): 106 | downsample = None 107 | if stride != 1 or self.inplanes5 != planes * block.expansion: 108 | downsample = nn.Sequential( 109 | nn.Conv1d(self.inplanes5, planes * block.expansion, 110 | kernel_size=1, stride=stride, bias=False), 111 | nn.BatchNorm1d(planes * block.expansion), 112 | ) 113 | 114 | layers = list() 115 | layers.append(block(self.inplanes5, planes, stride, downsample)) 116 | self.inplanes5 = planes * block.expansion 117 | for i in range(1, blocks): 118 | layers.append(block(self.inplanes5, planes)) 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def _make_layer7(self, block, planes, blocks, stride=2): 123 | downsample = None 124 | if stride != 1 or self.inplanes7 != planes * block.expansion: 125 | downsample = nn.Sequential( 126 | nn.Conv1d(self.inplanes7, planes * block.expansion, 127 | kernel_size=1, stride=stride, bias=False), 128 | nn.BatchNorm1d(planes * block.expansion), 129 | ) 130 | 131 | layers = list() 132 | layers.append(block(self.inplanes7, planes, stride, downsample)) 133 | self.inplanes7 = planes * 
block.expansion 134 | for i in range(1, blocks): 135 | layers.append(block(self.inplanes7, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def get_final_length(self): 140 | x = torch.randn(1, 1, self.input_size) 141 | x = self.forward(x) 142 | return x.shape[-1] 143 | 144 | 145 | class BasicBlock3x3(nn.Module): 146 | expansion = 1 147 | 148 | def __init__(self, inplanes3, planes, stride=1, downsample=None): 149 | super(BasicBlock3x3, self).__init__() 150 | self.conv1 = conv3x3(inplanes3, planes, stride) 151 | self.bn1 = nn.BatchNorm1d(planes) 152 | self.relu = nn.ELU(inplace=True) 153 | self.conv2 = conv3x3(planes, planes) 154 | self.bn2 = nn.BatchNorm1d(planes) 155 | self.downsample = downsample 156 | self.stride = stride 157 | 158 | def forward(self, x): 159 | residual = x 160 | 161 | out = self.conv1(x) 162 | out = self.bn1(out) 163 | out = self.relu(out) 164 | 165 | out = self.conv2(out) 166 | out = self.bn2(out) 167 | 168 | if self.downsample is not None: 169 | residual = self.downsample(x) 170 | 171 | out += residual 172 | out = self.relu(out) 173 | 174 | return out 175 | 176 | 177 | class BasicBlock5x5(nn.Module): 178 | expansion = 1 179 | 180 | def __init__(self, inplanes5, planes, stride=1, downsample=None): 181 | super(BasicBlock5x5, self).__init__() 182 | self.conv1 = conv5x5(inplanes5, planes, stride) 183 | self.bn1 = nn.BatchNorm1d(planes) 184 | self.relu = nn.ELU(inplace=True) 185 | self.conv2 = conv5x5(planes, planes) 186 | self.bn2 = nn.BatchNorm1d(planes) 187 | self.downsample = downsample 188 | self.stride = stride 189 | 190 | def forward(self, x): 191 | residual = x 192 | 193 | out = self.conv1(x) 194 | out = self.bn1(out) 195 | out = self.relu(out) 196 | 197 | out = self.conv2(out) 198 | out = self.bn2(out) 199 | 200 | if self.downsample is not None: 201 | residual = self.downsample(x) 202 | 203 | d = residual.shape[2] - out.shape[2] 204 | out1 = residual[:, :, 0:-d] + out 205 | out1 = self.relu(out1) 206 | return out1 207 | 208 | 209 | class BasicBlock7x7(nn.Module): 210 | expansion = 1 211 | 212 | def __init__(self, inplanes7, planes, stride=1, downsample=None): 213 | super(BasicBlock7x7, self).__init__() 214 | self.conv1 = conv7x7(inplanes7, planes, stride) 215 | self.bn1 = nn.BatchNorm1d(planes) 216 | self.relu = nn.ELU(inplace=True) 217 | self.conv2 = conv7x7(planes, planes) 218 | self.bn2 = nn.BatchNorm1d(planes) 219 | self.downsample = downsample 220 | self.stride = stride 221 | 222 | def forward(self, x): 223 | residual = x 224 | 225 | out = self.conv1(x) 226 | out = self.bn1(out) 227 | out = self.relu(out) 228 | 229 | out = self.conv2(out) 230 | out = self.bn2(out) 231 | 232 | if self.downsample is not None: 233 | residual = self.downsample(x) 234 | 235 | d = residual.shape[2] - out.shape[2] 236 | out1 = residual[:, :, 0:-d] + out 237 | out1 = self.relu(out1) 238 | return out1 239 | 240 | 241 | def conv3x3(in_planes, out_planes, stride=1): 242 | return nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride, 243 | padding=1, bias=False) 244 | 245 | 246 | def conv5x5(in_planes, out_planes, stride=1): 247 | return nn.Conv1d(in_planes, out_planes, kernel_size=5, stride=stride, 248 | padding=1, bias=False) 249 | 250 | 251 | def conv7x7(in_planes, out_planes, stride=1): 252 | return nn.Conv1d(in_planes, out_planes, kernel_size=7, stride=stride, 253 | padding=1, bias=False) 254 | 255 | 256 | if __name__ == '__main__': 257 | # st = ST_BackBone(input_size=500) 258 | # ss = st( 259 | # torch.randn(50, 1, 500) 260 | # ) 261 | fb = 
FrameBackBone(fs=100, window=3) 262 | ss = fb( 263 | torch.randn(50, 1, 300) 264 | ) 265 | print(ss.shape) 266 | 267 | 268 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /pretrained/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import os 3 | import sys 4 | sys.path.extend([os.path.abspath('.'), os.path.abspath('..')]) 5 | 6 | import mne 7 | import torch 8 | import random 9 | import shutil 10 | import argparse 11 | import warnings 12 | import numpy as np 13 | import torch.optim as opt 14 | from models.utils import model_size 15 | from sklearn.decomposition import PCA 16 | from torch.utils.tensorboard import SummaryWriter 17 | from sklearn.neighbors import KNeighborsClassifier 18 | from sklearn.metrics import accuracy_score, f1_score 19 | from torch.utils.data import DataLoader 20 | from dataset.utils import split_train_test_val_files 21 | from pretrained.data_loader import TorchDataset 22 | from models.neuronet.model import NeuroNet 23 | 24 | 25 | warnings.filterwarnings(action='ignore') 26 | 27 | 28 | random_seed = 777 29 | np.random.seed(random_seed) 30 | torch.manual_seed(random_seed) 31 | random.seed(random_seed) 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = False 34 | 35 | mne.set_log_level(False) 36 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 37 | 38 | 39 | def get_args(): 40 | parser = argparse.ArgumentParser() 41 | # Dataset 42 | parser.add_argument('--base_path', default=os.path.join('..', '..', '..', 'data', 'stage', 'Sleep-EDFX-2018')) 43 | parser.add_argument('--k_splits', default=5) 44 | parser.add_argument('--n_fold', default=0, choices=[0, 1, 2, 3, 4]) 45 | 46 | # Dataset Hyperparameter 47 | parser.add_argument('--sfreq', default=100, type=int) 48 | parser.add_argument('--rfreq', default=100, type=int) 49 | parser.add_argument('--data_scaler', default=False, type=bool) 50 | 51 | # Train Hyperparameter 52 | parser.add_argument('--train_epochs', default=30, type=int) 53 | parser.add_argument('--train_warmup_epoch', type=int, default=100) 54 | parser.add_argument('--train_base_learning_rate', default=1e-5, type=float) 55 | parser.add_argument('--train_batch_size', default=256, type=int) 56 | parser.add_argument('--train_batch_accumulation', default=1, type=int) 57 | 58 | # Model Hyperparameter 59 | parser.add_argument('--second', default=30, type=int) 60 | 
parser.add_argument('--time_window', default=3, type=int) 61 | parser.add_argument('--time_step', default=0.375, type=int) 62 | 63 | # >> 1. NeuroNet-M Hyperparameter 64 | # parser.add_argument('--encoder_dim', default=512, type=int) 65 | # parser.add_argument('--encoder_heads', default=8, type=int) 66 | # parser.add_argument('--encoder_depths', default=4, type=int) 67 | # parser.add_argument('--decoder_embed_dim', default=192, type=int) 68 | # parser.add_argument('--decoder_heads', default=8, type=int) 69 | # parser.add_argument('--decoder_depths', default=1, type=int) 70 | 71 | # >> 2. NeuroNet-B Hyperparameter 72 | parser.add_argument('--encoder_embed_dim', default=768, type=int) 73 | parser.add_argument('--encoder_heads', default=8, type=int) 74 | parser.add_argument('--encoder_depths', default=4, type=int) 75 | parser.add_argument('--decoder_embed_dim', default=256, type=int) 76 | parser.add_argument('--decoder_heads', default=8, type=int) 77 | parser.add_argument('--decoder_depths', default=3, type=int) 78 | parser.add_argument('--alpha', default=1.0, type=float) 79 | 80 | parser.add_argument('--projection_hidden', default=[1024, 512], type=list) 81 | parser.add_argument('--temperature', default=0.05, type=float) 82 | parser.add_argument('--mask_ratio', default=0.8, type=float) 83 | parser.add_argument('--print_point', default=20, type=int) 84 | parser.add_argument('--ckpt_path', default=os.path.join('..', '..', '..', 'ckpt', 'Sleep-EDFX'), type=str) 85 | parser.add_argument('--model_name', default='mini') 86 | return parser.parse_args() 87 | 88 | 89 | class Trainer(object): 90 | def __init__(self, args): 91 | self.args = args 92 | self.model = NeuroNet( 93 | fs=args.rfreq, second=args.second, time_window=args.time_window, time_step=args.time_step, 94 | encoder_embed_dim=args.encoder_embed_dim, encoder_heads=args.encoder_heads, encoder_depths=args.encoder_depths, 95 | decoder_embed_dim=args.decoder_embed_dim, decoder_heads=args.decoder_heads, 96 | decoder_depths=args.decoder_depths, projection_hidden=args.projection_hidden, temperature=args.temperature 97 | ).to(device) 98 | print('Model Size : {0:.2f}MB'.format(model_size(self.model))) 99 | 100 | self.eff_batch_size = self.args.train_batch_size * self.args.train_batch_accumulation 101 | self.lr = self.args.train_base_learning_rate * self.eff_batch_size / 256 102 | self.optimizer = opt.AdamW(self.model.parameters(), lr=self.lr) 103 | self.train_paths, self.val_paths, self.eval_paths = self.data_paths() 104 | self.scheduler = opt.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.args.train_epochs) 105 | self.tensorboard_path = os.path.join(self.args.ckpt_path, self.args.model_name, 106 | str(self.args.n_fold), 'tensorboard') 107 | 108 | # remote tensorboard files 109 | if os.path.exists(self.tensorboard_path): 110 | shutil.rmtree(self.tensorboard_path) 111 | 112 | self.tensorboard_writer = SummaryWriter(log_dir=self.tensorboard_path) 113 | 114 | print('Frame Size : {}'.format(self.model.num_patches)) 115 | print('Leaning Rate : {0}'.format(self.lr)) 116 | print('Validation Paths : {0}'.format(len(self.val_paths))) 117 | print('Evaluation Paths : {0}'.format(len(self.eval_paths))) 118 | 119 | def train(self): 120 | print('K-Fold : {}/{}'.format(self.args.n_fold + 1, self.args.k_splits)) 121 | train_dataset = TorchDataset(paths=self.train_paths, sfreq=self.args.sfreq, rfreq=self.args.rfreq, 122 | scaler=self.args.data_scaler) 123 | train_dataloader = DataLoader(train_dataset, batch_size=self.args.train_batch_size, shuffle=True) 
124 | val_dataset = TorchDataset(paths=self.val_paths, sfreq=self.args.sfreq, rfreq=self.args.rfreq, 125 | scaler=self.args.data_scaler) 126 | val_dataloader = DataLoader(val_dataset, batch_size=self.args.train_batch_size, drop_last=True) 127 | eval_dataset = TorchDataset(paths=self.eval_paths, sfreq=self.args.sfreq, rfreq=self.args.rfreq, 128 | scaler=self.args.data_scaler) 129 | eval_dataloader = DataLoader(eval_dataset, batch_size=self.args.train_batch_size, drop_last=True) 130 | 131 | total_step = 0 132 | best_model_state, best_score = self.model.state_dict(), 0 133 | 134 | for epoch in range(self.args.train_epochs): 135 | step = 0 136 | self.model.train() 137 | self.optimizer.zero_grad() 138 | 139 | for x, _ in train_dataloader: 140 | x = x.to(device) 141 | out = self.model(x, mask_ratio=self.args.mask_ratio) 142 | recon_loss, contrastive_loss, (cl_labels, cl_logits) = out 143 | 144 | loss = recon_loss + self.args.alpha * contrastive_loss 145 | loss.backward() 146 | 147 | if (step + 1) % self.args.train_batch_accumulation == 0: 148 | self.optimizer.step() 149 | self.optimizer.zero_grad() 150 | 151 | if (total_step + 1) % self.args.print_point == 0: 152 | print('[Epoch] : {0:03d} [Step] : {1:06d} ' 153 | '[Reconstruction Loss] : {2:02.4f} [Contrastive Loss] : {3:02.4f} ' 154 | '[Total Loss] : {4:02.4f} [Contrastive Acc] : {5:02.4f}'.format( 155 | epoch, total_step + 1, recon_loss, contrastive_loss, loss, 156 | self.compute_metrics(cl_logits, cl_labels))) 157 | 158 | self.tensorboard_writer.add_scalar('Reconstruction Loss', recon_loss, total_step) 159 | self.tensorboard_writer.add_scalar('Contrastive loss', contrastive_loss, total_step) 160 | self.tensorboard_writer.add_scalar('Total loss', loss, total_step) 161 | 162 | step += 1 163 | total_step += 1 164 | 165 | val_acc, val_mf1 = self.linear_probing(val_dataloader, eval_dataloader) 166 | 167 | if val_mf1 > best_score: 168 | best_model_state = self.model.state_dict() 169 | best_score = val_mf1 170 | 171 | print('[Epoch] : {0:03d} \t [Accuracy] : {1:2.4f} \t [Macro-F1] : {2:2.4f} \n'.format( 172 | epoch, val_acc * 100, val_mf1 * 100)) 173 | self.tensorboard_writer.add_scalar('Validation Accuracy', val_acc, total_step) 174 | self.tensorboard_writer.add_scalar('Validation Macro-F1', val_mf1, total_step) 175 | 176 | self.optimizer.step() 177 | self.scheduler.step() 178 | 179 | self.save_ckpt(model_state=best_model_state) 180 | 181 | def linear_probing(self, val_dataloader, eval_dataloader): 182 | self.model.eval() 183 | (train_x, train_y), (test_x, test_y) = self.get_latent_vector(val_dataloader), \ 184 | self.get_latent_vector(eval_dataloader) 185 | pca = PCA(n_components=50) 186 | train_x = pca.fit_transform(train_x) 187 | test_x = pca.transform(test_x) 188 | 189 | model = KNeighborsClassifier() 190 | model.fit(train_x, train_y) 191 | 192 | out = model.predict(test_x) 193 | acc, mf1 = accuracy_score(test_y, out), f1_score(test_y, out, average='macro') 194 | self.model.train() 195 | return acc, mf1 196 | 197 | def get_latent_vector(self, dataloader): 198 | total_x, total_y = [], [] 199 | with torch.no_grad(): 200 | for data in dataloader: 201 | x, y = data 202 | x, y = x.to(device), y.to(device) 203 | latent = self.model.forward_latent(x) 204 | total_x.append(latent.detach().cpu().numpy()) 205 | total_y.append(y.detach().cpu().numpy()) 206 | total_x, total_y = np.concatenate(total_x, axis=0), np.concatenate(total_y, axis=0) 207 | return total_x, total_y 208 | 209 | def save_ckpt(self, model_state): 210 | ckpt_path = 
os.path.join(self.args.ckpt_path, self.args.model_name, str(self.args.n_fold), 'model') 211 | if not os.path.exists(ckpt_path): 212 | os.makedirs(ckpt_path) 213 | 214 | torch.save({ 215 | 'model_name': 'NeuroNet', 216 | 'model_state': model_state, 217 | 'model_parameter': { 218 | 'fs': self.args.rfreq, 'second': self.args.second, 219 | 'time_window': self.args.time_window, 'time_step': self.args.time_step, 220 | 'encoder_embed_dim': self.args.encoder_embed_dim, 'encoder_heads': self.args.encoder_heads, 221 | 'encoder_depths': self.args.encoder_depths, 222 | 'decoder_embed_dim': self.args.decoder_embed_dim, 'decoder_heads': self.args.decoder_heads, 223 | 'decoder_depths': self.args.decoder_depths, 224 | 'projection_hidden': self.args.projection_hidden, 'temperature': self.args.temperature 225 | }, 226 | 'hyperparameter': self.args.__dict__, 227 | 'paths': {'train_paths': self.train_paths, 'ft_paths': self.val_paths, 'eval_paths': self.eval_paths} 228 | }, os.path.join(ckpt_path, 'best_model.pth')) 229 | 230 | def data_paths(self): 231 | kf = split_train_test_val_files(base_path=self.args.base_path, n_splits=self.args.k_splits) 232 | 233 | paths = kf[self.args.n_fold] 234 | train_paths, ft_paths, eval_paths = paths['train_paths'], paths['ft_paths'], paths['eval_paths'] 235 | return train_paths, ft_paths, eval_paths 236 | 237 | @staticmethod 238 | def compute_metrics(output, target): 239 | output = output.argmax(dim=-1) 240 | accuracy = torch.mean(torch.eq(target, output).to(torch.float32)) 241 | return accuracy 242 | 243 | 244 | if __name__ == '__main__': 245 | augments = get_args() 246 | for n_fold in range(augments.k_splits): 247 | augments.n_fold = n_fold 248 | trainer = Trainer(augments) 249 | trainer.train() 250 | -------------------------------------------------------------------------------- /models/neuronet/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | from typing import List 6 | from models.neuronet.resnet1d import FrameBackBone 7 | from timm.models.vision_transformer import Block 8 | from models.utils import get_2d_sincos_pos_embed_flexible 9 | from models.loss import NTXentLoss 10 | from functools import partial 11 | 12 | 13 | class NeuroNet(nn.Module): 14 | def __init__(self, fs: int, second: int, time_window: int, time_step: float, 15 | encoder_embed_dim, encoder_heads: int, encoder_depths: int, 16 | decoder_embed_dim: int, decoder_heads: int, decoder_depths: int, 17 | projection_hidden: List, temperature=0.01): 18 | super().__init__() 19 | self.fs, self.second = fs, second 20 | self.time_window = time_window 21 | self.time_step = time_step 22 | 23 | self.num_patches, _ = frame_size(fs=fs, second=second, time_window=time_window, time_step=time_step) 24 | self.frame_backbone = FrameBackBone(fs=self.fs, window=self.time_window) 25 | self.autoencoder = MaskedAutoEncoderViT(input_size=self.frame_backbone.feature_num, 26 | encoder_embed_dim=encoder_embed_dim, num_patches=self.num_patches, 27 | encoder_heads=encoder_heads, encoder_depths=encoder_depths, 28 | decoder_embed_dim=decoder_embed_dim, decoder_heads=decoder_heads, 29 | decoder_depths=decoder_depths) 30 | self.contrastive_loss = NTXentLoss(temperature=temperature) 31 | 32 | projection_hidden = [encoder_embed_dim] + projection_hidden 33 | projectors = [] 34 | for i, (h1, h2) in enumerate(zip(projection_hidden[:-1], projection_hidden[1:])): 35 | if i != len(projection_hidden) - 2: 36 | 
projectors.append(nn.Linear(h1, h2))
37 | projectors.append(nn.BatchNorm1d(h2))
38 | projectors.append(nn.ELU())
39 | else:
40 | projectors.append(nn.Linear(h1, h2))
41 | self.projectors = nn.Sequential(*projectors)
42 | self.projectors_bn = nn.BatchNorm1d(projection_hidden[-1], affine=False)
43 | self.norm_pix_loss = False
44 |
45 | def forward(self, x: torch.Tensor, mask_ratio: float = 0.5):
46 | x = self.make_frame(x)
47 | x = self.frame_backbone(x)
48 |
49 | # Masked Prediction
50 | latent1, pred1, mask1 = self.autoencoder(x, mask_ratio)
51 | latent2, pred2, mask2 = self.autoencoder(x, mask_ratio)
52 | o1, o2 = latent1[:, :1, :].squeeze(), latent2[:, :1, :].squeeze()
53 | recon_loss1 = self.forward_mae_loss(x, pred1, mask1)
54 | recon_loss2 = self.forward_mae_loss(x, pred2, mask2)
55 | recon_loss = recon_loss1 + recon_loss2
56 |
57 | # Contrastive Learning
58 | o1, o2 = self.projectors(o1), self.projectors(o2)
59 | contrastive_loss, (labels, logits) = self.contrastive_loss(o1, o2)
60 |
61 | return recon_loss, contrastive_loss, (labels, logits)
62 |
63 | def forward_latent(self, x: torch.Tensor):
64 | x = self.make_frame(x)
65 | x = self.frame_backbone(x)
66 | latent = self.autoencoder.forward_encoder(x, mask_ratio=0)[0]
67 | latent_o = latent[:, :1, :].squeeze()
68 | return latent_o
69 |
70 | def forward_mae_loss(self,
71 | real: torch.Tensor,
72 | pred: torch.Tensor,
73 | mask: torch.Tensor):
74 |
75 | if self.norm_pix_loss:
76 | mean = real.mean(dim=-1, keepdim=True)
77 | var = real.var(dim=-1, keepdim=True)
78 | real = (real - mean) / (var + 1.e-6) ** .5
79 |
80 | loss = (pred - real) ** 2
81 | loss = loss.mean(dim=-1)
82 | loss = (loss * mask).sum() / mask.sum()
83 | return loss
84 |
85 | def make_frame(self, x):
86 | size = self.fs * self.second
87 | step = int(self.time_step * self.fs)
88 | window = int(self.time_window * self.fs)
89 | frame = []
90 | for i in range(0, size, step):
91 | start_idx, end_idx = i, i+window
92 | sample = x[..., start_idx: end_idx]
93 | if sample.shape[-1] == window:
94 | frame.append(sample)
95 | frame = torch.stack(frame, dim=1)
96 | return frame
97 |
98 |
99 | class MaskedAutoEncoderViT(nn.Module):
100 | def __init__(self, input_size: int, num_patches: int,
101 | encoder_embed_dim: int, encoder_heads: int, encoder_depths: int,
102 | decoder_embed_dim: int, decoder_heads: int, decoder_depths: int):
103 | super().__init__()
104 | self.patch_embed = nn.Linear(input_size, encoder_embed_dim)
105 | self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_embed_dim))
106 | self.embed_dim = encoder_embed_dim
107 | self.encoder_depths = encoder_depths
108 | self.mlp_ratio = 4.
109 | 110 | self.input_size = (num_patches, encoder_embed_dim) 111 | self.patch_size = (1, encoder_embed_dim) 112 | self.grid_h = int(self.input_size[0] // self.patch_size[0]) 113 | self.grid_w = int(self.input_size[1] // self.patch_size[1]) 114 | self.num_patches = self.grid_h * self.grid_w 115 | 116 | # MAE Encoder 117 | self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, encoder_embed_dim), requires_grad=False) 118 | self.encoder_block = nn.ModuleList([ 119 | Block(encoder_embed_dim, encoder_heads, self.mlp_ratio, qkv_bias=True, 120 | norm_layer=partial(nn.LayerNorm, eps=1e-6)) 121 | for _ in range(encoder_depths) 122 | ]) 123 | self.encoder_norm = nn.LayerNorm(encoder_embed_dim, eps=1e-6) 124 | 125 | # MAE Decoder 126 | self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True) 127 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 128 | self.decoder_pos_embed = nn.Parameter(torch.randn(1, self.num_patches, decoder_embed_dim), requires_grad=False) 129 | self.decoder_block = nn.ModuleList([ 130 | Block(decoder_embed_dim, decoder_heads, self.mlp_ratio, qkv_bias=True, 131 | norm_layer=partial(nn.LayerNorm, eps=1e-6)) 132 | for _ in range(decoder_depths) 133 | ]) 134 | self.decoder_norm = nn.LayerNorm(decoder_embed_dim, eps=1e-6) 135 | self.decoder_pred = nn.Linear(decoder_embed_dim, input_size, bias=True) 136 | self.initialize_weights() 137 | 138 | def forward(self, x, mask_ratio=0.8): 139 | latent, mask, ids_restore = self.forward_encoder(x, mask_ratio) 140 | pred = self.forward_decoder(latent, ids_restore) 141 | return latent, pred, mask 142 | 143 | def forward_encoder(self, x: torch.Tensor, mask_ratio: float = 0.5): 144 | # embed patches 145 | x = self.patch_embed(x) 146 | 147 | # add pos embed w/o cls token 148 | x = x + self.pos_embed[:, 1:, :] 149 | 150 | # masking: length -> length * mask_ratio 151 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 152 | 153 | # append cls token 154 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 155 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 156 | x = torch.cat((cls_tokens, x), dim=1) 157 | 158 | # apply Transformer blocks 159 | for block in self.encoder_block: 160 | x = block(x) 161 | 162 | x = self.encoder_norm(x) 163 | return x, mask, ids_restore 164 | 165 | def forward_decoder(self, x, ids_restore: torch.Tensor): 166 | # embed tokens 167 | x = self.decoder_embed(x[:, 1:, :]) 168 | 169 | # append mask tokens to sequence 170 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) 171 | x_ = torch.cat([x, mask_tokens], dim=1) # no cls token 172 | x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 173 | 174 | # add pos embed 175 | x = x + self.decoder_pos_embed 176 | 177 | # apply Transformer blocks 178 | for block in self.decoder_block: 179 | x = block(x) 180 | 181 | x = self.decoder_norm(x) 182 | 183 | # predictor projection 184 | x = self.decoder_pred(x) 185 | return x 186 | 187 | @staticmethod 188 | def random_masking(x, mask_ratio): 189 | n, l, d = x.shape # batch, length, dim 190 | len_keep = int(l * (1 - mask_ratio)) 191 | 192 | noise = torch.rand(n, l, device=x.device) # noise in [0, 1] 193 | 194 | # sort noise for each sample 195 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 196 | ids_restore = torch.argsort(ids_shuffle, dim=1) 197 | 198 | # keep the first subset 199 | ids_keep = ids_shuffle[:, :len_keep] 200 | x_masked = torch.gather(x, 
dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, d)) 201 | 202 | # generate the binary mask: 0 is keep, 1 is remove 203 | mask = torch.ones([n, l], device=x.device) 204 | mask[:, :len_keep] = 0 205 | 206 | mask = torch.gather(mask, dim=1, index=ids_restore) 207 | return x_masked, mask, ids_restore 208 | 209 | def initialize_weights(self): 210 | # initialization 211 | # initialize (and freeze) pos_embed by sin-cos embedding 212 | pos_embed = get_2d_sincos_pos_embed_flexible(self.pos_embed.shape[-1], 213 | (self.grid_h, self.grid_w), 214 | cls_token=True) 215 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 216 | decoder_pos_embed = get_2d_sincos_pos_embed_flexible(self.decoder_pos_embed.shape[-1], 217 | (self.grid_h, self.grid_w), 218 | cls_token=False) 219 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 220 | 221 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 222 | torch.nn.init.normal_(self.cls_token, std=.02) 223 | torch.nn.init.normal_(self.mask_token, std=.02) 224 | 225 | # initialize nn.Linear and nn.LayerNorm 226 | self.apply(self._init_weights) 227 | 228 | @staticmethod 229 | def _init_weights(m): 230 | if isinstance(m, nn.Linear): 231 | # we use xavier_uniform following official JAX ViT: 232 | torch.nn.init.xavier_uniform_(m.weight) 233 | if isinstance(m, nn.Linear) and m.bias is not None: 234 | nn.init.constant_(m.bias, 0) 235 | elif isinstance(m, nn.LayerNorm): 236 | nn.init.constant_(m.bias, 0) 237 | nn.init.constant_(m.weight, 1.0) 238 | 239 | 240 | def frame_size(fs, second, time_window, time_step): 241 | x = np.random.randn(1, fs * second) 242 | size = fs * second 243 | step = int(time_step * fs) 244 | window = int(time_window * fs) 245 | frame = [] 246 | for i in range(0, size, step): 247 | start_idx, end_idx = i, i + window 248 | sample = x[..., start_idx: end_idx] 249 | if sample.shape[-1] == window: 250 | frame.append(sample) 251 | frame = np.stack(frame, axis=1) 252 | return frame.shape[1], frame.shape[2] 253 | 254 | 255 | class NeuroNetEncoderWrapper(nn.Module): 256 | def __init__(self, fs: int, second: int, time_window: int, time_step: float, 257 | frame_backbone, patch_embed, encoder_block, encoder_norm, cls_token, pos_embed, 258 | final_length): 259 | 260 | super().__init__() 261 | self.fs, self.second = fs, second 262 | self.time_window = time_window 263 | self.time_step = time_step 264 | 265 | self.patch_embed = patch_embed 266 | self.frame_backbone = frame_backbone 267 | self.encoder_block = encoder_block 268 | self.encoder_norm = encoder_norm 269 | self.cls_token = cls_token 270 | self.pos_embed = pos_embed 271 | 272 | self.final_length = final_length 273 | 274 | def forward(self, x, semantic_token=True): 275 | # frame backbone 276 | x = self.make_frame(x) 277 | x = self.frame_backbone(x) 278 | 279 | # embed patches 280 | x = self.patch_embed(x) 281 | 282 | # add pos embed w/o cls token 283 | x = x + self.pos_embed[:, 1:, :] 284 | 285 | # append cls token 286 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 287 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 288 | x = torch.cat((cls_tokens, x), dim=1) 289 | 290 | # apply Transformer blocks 291 | for block in self.encoder_block: 292 | x = block(x) 293 | 294 | x = self.encoder_norm(x) 295 | 296 | # get semantic information 297 | x = torch.mean(x[:, 1:, :], dim=1) 298 | return x 299 | 300 | def make_frame(self, x): 301 | size = self.fs * self.second 302 | step = int(self.time_step * 
self.fs) 303 | window = int(self.time_window * self.fs) 304 | frame = [] 305 | for i in range(0, size, step): 306 | start_idx, end_idx = i, i+window 307 | sample = x[..., start_idx: end_idx] 308 | if sample.shape[-1] == window: 309 | frame.append(sample) 310 | frame = torch.stack(frame, dim=1) 311 | return frame 312 | 313 | 314 | if __name__ == '__main__': 315 | x0 = torch.randn((50, 3000)) 316 | m0 = NeuroNet(fs=100, second=30, time_window=5, time_step=0.5, 317 | encoder_embed_dim=256, encoder_depths=6, encoder_heads=8, 318 | decoder_embed_dim=128, decoder_heads=4, decoder_depths=8, 319 | projection_hidden=[1024, 512]) 320 | m0.forward(x0, mask_ratio=0.5) 321 | -------------------------------------------------------------------------------- /downstream/fine_tuning.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import os 3 | import mne 4 | import torch 5 | import random 6 | import argparse 7 | import warnings 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch.optim as opt 11 | from mamba_ssm import Mamba 12 | from models.utils import model_size 13 | from torch.utils.data import Dataset, DataLoader 14 | from models.neuronet.model import NeuroNet, NeuroNetEncoderWrapper 15 | from sklearn.metrics import accuracy_score, f1_score 16 | 17 | 18 | warnings.filterwarnings(action='ignore') 19 | 20 | 21 | random_seed = 777 22 | np.random.seed(random_seed) 23 | torch.manual_seed(random_seed) 24 | random.seed(random_seed) 25 | 26 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = False 29 | torch.backends.cuda.enable_mem_efficient_sdp(False) 30 | torch.backends.cuda.enable_flash_sdp(False) 31 | torch.backends.cuda.enable_math_sdp(True) 32 | 33 | 34 | def get_args(): 35 | file_name = 'mini' 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--n_fold', default=1, choices=[0, 1, 2, 3, 4]) 38 | parser.add_argument('--ckpt_path', default=os.path.join('..', '..', '..', 'ckpt', 39 | 'SHHS', 'cm_eeg', file_name), type=str) 40 | parser.add_argument('--temporal_context_length', default=20) 41 | parser.add_argument('--window_size', default=10) 42 | parser.add_argument('--epochs', default=150, type=int) 43 | parser.add_argument('--batch_size', default=64, type=int) 44 | parser.add_argument('--lr', default=0.0005, type=float) 45 | 46 | parser.add_argument('--embed_dim', default=256) 47 | parser.add_argument('--temporal_context_modules', choices=['lstm', 'mha', 'lstm_mha', 'mamba'], default='mamba') 48 | parser.add_argument('--save_path', default=os.path.join('..', '..', '..', 49 | 'ckpt', 'SHHS', 'cm_eeg', 50 | file_name), type=str) 51 | return parser.parse_args() 52 | 53 | 54 | class TemporalContextModule(nn.Module): 55 | def __init__(self, backbone, backbone_final_length, embed_dim): 56 | super().__init__() 57 | self.backbone = self.freeze_backbone(backbone) 58 | self.backbone_final_length = backbone_final_length 59 | self.embed_dim = embed_dim 60 | self.embed_layer = nn.Sequential( 61 | nn.Linear(backbone_final_length, embed_dim), 62 | nn.BatchNorm1d(embed_dim), 63 | nn.ELU(), 64 | nn.Linear(embed_dim, embed_dim) 65 | ) 66 | 67 | def apply_backbone(self, x): 68 | out = [] 69 | for x_ in torch.split(x, dim=1, split_size_or_sections=1): 70 | o = self.backbone(x_.squeeze()) 71 | o = self.embed_layer(o) 72 | out.append(o) 73 | out = torch.stack(out, dim=1) 74 | return out 75 | 76 | @staticmethod 77 | def 
freeze_backbone(backbone: nn.Module): 78 | for name, module in backbone.named_modules(): 79 | if name in ['encoder_block.3.ls1', 'encoder_block.3.drop_path1', 'encoder_block.3.norm2', 80 | 'encoder_block.3.mlp', 'encoder_block.3.mlp.fc1', 'encoder_block.3.mlp.act', 81 | 'encoder_block.3.mlp.drop1', 'encoder_block.3.mlp.norm', 'encoder_block.3.mlp.fc2', 82 | 'encoder_block.3.mlp.drop2', 'encoder_block.3.ls2', 'encoder_block.3.drop_path2', 83 | 'encoder_norm']: 84 | for param in module.parameters(): 85 | param.requires_grad = True 86 | else: 87 | for param in module.parameters(): 88 | param.requires_grad = False 89 | return backbone 90 | 91 | 92 | class LSTM_TCM(TemporalContextModule): 93 | def __init__(self, backbone, backbone_final_length, embed_dim): 94 | super().__init__(backbone=backbone, backbone_final_length=backbone_final_length, embed_dim=embed_dim) 95 | self.rnn_layer = 2 96 | self.lstm = nn.LSTM(input_size=self.embed_dim, hidden_size=self.embed_dim, num_layers=self.rnn_layer) 97 | self.fc = nn.Linear(self.embed_dim, 5) 98 | 99 | def forward(self, x): 100 | x = self.apply_backbone(x) 101 | x, _ = self.lstm(x) 102 | x = self.fc(x) 103 | return x 104 | 105 | 106 | class MHA_TCM(TemporalContextModule): 107 | def __init__(self, backbone, backbone_final_length, embed_dim): 108 | super().__init__(backbone=backbone, backbone_final_length=backbone_final_length, embed_dim=embed_dim) 109 | self.mha_heads = 8 110 | self.mha_layer = 2 111 | self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(self.embed_dim, self.mha_heads), 112 | num_layers=self.mha_layer) 113 | self.fc = nn.Linear(self.embed_dim, 5) 114 | 115 | def forward(self, x): 116 | x = self.apply_backbone(x) 117 | x = self.transformer(x) 118 | x = self.fc(x) 119 | return x 120 | 121 | 122 | class LSTM_MHA_TCM(TemporalContextModule): 123 | def __init__(self, backbone, backbone_final_length, embed_dim): 124 | super().__init__(backbone=backbone, backbone_final_length=backbone_final_length, embed_dim=embed_dim) 125 | self.mha_heads = 8 126 | self.mha_layer = 2 127 | self.rnn_layer = 1 128 | self.lstm = nn.LSTM(input_size=self.embed_dim, hidden_size=self.embed_dim, num_layers=self.rnn_layer) 129 | self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(self.embed_dim, self.mha_heads), 130 | num_layers=self.mha_layer) 131 | self.fc = nn.Linear(self.embed_dim, 5) 132 | 133 | def forward(self, x): 134 | x = self.apply_backbone(x) 135 | x, _ = self.lstm(x) 136 | x = self.transformer(x) 137 | x = self.fc(x) 138 | return x 139 | 140 | 141 | class MAMBA_TCM(TemporalContextModule): 142 | def __init__(self, backbone, backbone_final_length, embed_dim): 143 | super().__init__(backbone=backbone, backbone_final_length=backbone_final_length, embed_dim=embed_dim) 144 | self.mamba_heads = 8 145 | self.mamba_layer = 1 146 | self.mamba = nn.Sequential(*[ 147 | Mamba(d_model=self.embed_dim, 148 | d_state=16, 149 | d_conv=4, 150 | expand=2) 151 | for _ in range(self.mamba_layer) 152 | ]) 153 | self.fc = nn.Linear(self.embed_dim, 5) 154 | 155 | def forward(self, x): 156 | x = self.apply_backbone(x) 157 | x = self.mamba(x) 158 | x = self.fc(x) 159 | return x 160 | 161 | 162 | class Trainer(object): 163 | def __init__(self, args): 164 | super().__init__() 165 | self.args = args 166 | self.ckpt_path = os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'model', 'best_model.pth') 167 | self.ckpt = torch.load(self.ckpt_path, map_location='cpu') 168 | 169 | self.sfreq, self.rfreq = self.ckpt['hyperparameter']['sfreq'], 
self.ckpt['hyperparameter']['rfreq'] 170 | self.ft_paths, self.eval_paths = self.ckpt['paths']['ft_paths'], self.ckpt['paths']['eval_paths'] 171 | self.model = self.get_pretrained_model().to(device) 172 | 173 | self.optimizer = opt.AdamW(self.model.parameters(), lr=self.args.lr) 174 | self.scheduler = opt.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.args.epochs) 175 | self.criterion = nn.CrossEntropyLoss() 176 | 177 | def train(self): 178 | print('Checkpoint File Path : {}'.format(self.ckpt_path)) 179 | train_dataset, eval_dataset = TorchDataset(paths=self.ft_paths, 180 | temporal_context_length=self.args.temporal_context_length, 181 | window_size=self.args.window_size, 182 | sfreq=self.sfreq, rfreq=self.rfreq), \ 183 | TorchDataset(paths=self.eval_paths, 184 | temporal_context_length=self.args.temporal_context_length, 185 | window_size=self.args.temporal_context_length, 186 | sfreq=self.sfreq, rfreq=self.rfreq) 187 | 188 | train_dataloader, eval_dataloader = DataLoader(dataset=train_dataset, 189 | batch_size=self.args.batch_size, 190 | shuffle=True), \ 191 | DataLoader(dataset=eval_dataset, 192 | batch_size=self.args.batch_size, 193 | shuffle=False) 194 | 195 | best_model_state, best_mf1 = None, 0.0 196 | best_pred, best_real = [], [] 197 | 198 | for epoch in range(self.args.epochs): 199 | self.model.train() 200 | epoch_train_loss = [] 201 | for batch in train_dataloader: 202 | self.optimizer.zero_grad() 203 | x, y = batch 204 | x, y = x.to(device), y.to(device) 205 | 206 | out = self.model(x) 207 | loss, pred, real = self.get_loss(out, y) 208 | 209 | epoch_train_loss.append(float(loss.detach().cpu().item())) 210 | loss.backward() 211 | self.optimizer.step() 212 | 213 | self.model.eval() 214 | epoch_test_loss = [] 215 | epoch_real, epoch_pred = [], [] 216 | for batch in eval_dataloader: 217 | x, y = batch 218 | x, y = x.to(device), y.to(device) 219 | try: 220 | out = self.model(x) 221 | except IndexError: 222 | continue 223 | loss, pred, real = self.get_loss(out, y) 224 | pred = torch.argmax(pred, dim=-1) 225 | epoch_real.extend(list(real.detach().cpu().numpy())) 226 | epoch_pred.extend(list(pred.detach().cpu().numpy())) 227 | epoch_test_loss.append(float(loss.detach().cpu().item())) 228 | 229 | epoch_train_loss, epoch_test_loss = np.mean(epoch_train_loss), np.mean(epoch_test_loss) 230 | eval_acc, eval_mf1 = accuracy_score(y_true=epoch_real, y_pred=epoch_pred), \ 231 | f1_score(y_true=epoch_real, y_pred=epoch_pred, average='macro') 232 | 233 | print('[Epoch] : {0:03d} \t ' 234 | '[Train Loss] => {1:.4f} \t ' 235 | '[Evaluation Loss] => {2:.4f} \t ' 236 | '[Evaluation Accuracy] => {3:.4f} \t' 237 | '[Evaluation Macro-F1] => {4:.4f}'.format(epoch + 1, epoch_train_loss, epoch_test_loss, 238 | eval_acc, eval_mf1)) 239 | 240 | if best_mf1 < eval_mf1: 241 | best_mf1 = eval_mf1 242 | best_model_state = self.model.state_dict() 243 | best_pred, best_real = epoch_pred, epoch_real 244 | 245 | self.scheduler.step() 246 | 247 | self.save_ckpt(best_model_state, best_pred, best_real) 248 | 249 | def save_ckpt(self, model_state, pred, real): 250 | if not os.path.exists(os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'fine_tuning')): 251 | os.makedirs(os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'fine_tuning')) 252 | 253 | save_path = os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'fine_tuning', 'best_model.pth') 254 | torch.save({ 255 | 'backbone_name': 'NeuroNet_FineTuning', 256 | 'model_state': model_state, 257 | 'hyperparameter': self.args.__dict__, 258 | 
'result': {'real': real, 'pred': pred}, 259 | 'paths': {'train_paths': self.ft_paths, 'eval_paths': self.eval_paths} 260 | }, save_path) 261 | 262 | def get_pretrained_model(self): 263 | # 1. Prepared Pretrained Model 264 | model_parameter = self.ckpt['model_parameter'] 265 | pretrained_model = NeuroNet(**model_parameter) 266 | pretrained_model.load_state_dict(self.ckpt['model_state']) 267 | 268 | # 2. Encoder Wrapper 269 | backbone = NeuroNetEncoderWrapper( 270 | fs=model_parameter['fs'], second=model_parameter['second'], 271 | time_window=model_parameter['time_window'], time_step=model_parameter['time_step'], 272 | frame_backbone=pretrained_model.frame_backbone, 273 | patch_embed=pretrained_model.autoencoder.patch_embed, 274 | encoder_block=pretrained_model.autoencoder.encoder_block, 275 | encoder_norm=pretrained_model.autoencoder.encoder_norm, 276 | cls_token=pretrained_model.autoencoder.cls_token, 277 | pos_embed=pretrained_model.autoencoder.pos_embed, 278 | final_length=pretrained_model.autoencoder.embed_dim, 279 | ) 280 | 281 | # 3. Temporal Context Module 282 | tcm = self.get_temporal_context_module() 283 | model = tcm(backbone=backbone, 284 | backbone_final_length=pretrained_model.autoencoder.embed_dim, 285 | embed_dim=self.args.embed_dim) 286 | return model 287 | 288 | def get_temporal_context_module(self): 289 | if self.args.temporal_context_modules == 'lstm': 290 | return LSTM_TCM 291 | if self.args.temporal_context_modules == 'mha': 292 | return MHA_TCM 293 | if self.args.temporal_context_modules == 'lstm_mha': 294 | return LSTM_MHA_TCM 295 | if self.args.temporal_context_modules == 'mamba': 296 | return MAMBA_TCM 297 | 298 | def get_loss(self, pred, real): 299 | if pred.dim() == 3: 300 | pred = pred.view(-1, pred.size(2)) 301 | real = real.view(-1) 302 | loss = self.criterion(pred, real) 303 | return loss, pred, real 304 | 305 | 306 | class TorchDataset(Dataset): 307 | def __init__(self, paths, temporal_context_length, window_size, 308 | sfreq: int, rfreq: int): 309 | super().__init__() 310 | self.sfreq, self.rfreq = sfreq, rfreq 311 | self.info = mne.create_info(sfreq=sfreq, ch_types='eeg', ch_names=['Fp1']) 312 | self.x, self.y = self.get_data(paths, 313 | temporal_context_length=temporal_context_length, 314 | window_size=window_size) 315 | self.x, self.y = torch.tensor(self.x, dtype=torch.float32), torch.tensor(self.y, dtype=torch.long) 316 | 317 | def get_data(self, paths, temporal_context_length, window_size): 318 | total_x, total_y = [], [] 319 | for path in paths: 320 | data = np.load(path) 321 | x, y = data['x'], data['y'] 322 | x = np.expand_dims(x, axis=1) 323 | x = mne.EpochsArray(x, info=self.info) 324 | x = x.resample(self.rfreq) 325 | x = x.get_data().squeeze() 326 | x = self.many_to_many(x, temporal_context_length, window_size) 327 | y = self.many_to_many(y, temporal_context_length, window_size) 328 | total_x.append(x) 329 | total_y.append(y) 330 | total_x, total_y = np.concatenate(total_x), np.concatenate(total_y) 331 | return total_x, total_y 332 | 333 | @staticmethod 334 | def many_to_many(elements, temporal_context_length, window_size): 335 | size = len(elements) 336 | total = [] 337 | if size <= temporal_context_length: 338 | return elements 339 | for i in range(0, size-temporal_context_length+1, window_size): 340 | temp = np.array(elements[i:i+temporal_context_length]) 341 | total.append(temp) 342 | total.append(elements[size-temporal_context_length:size]) 343 | total = np.array(total) 344 | return total 345 | 346 | def __len__(self): 347 | return 
len(self.y)
348 |
349 | def __getitem__(self, item):
350 | x = self.x[item]
351 | y = self.y[item]
352 | return x, y
353 |
354 |
355 | if __name__ == '__main__':
356 | augments = get_args()
357 | for n_fold in range(10):
358 | augments.n_fold = n_fold
359 | trainer = Trainer(augments)
360 | trainer.train()
361 |
--------------------------------------------------------------------------------
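For reference, NeuroNet.make_frame and frame_size in models/neuronet/model.py share the same frame arithmetic: an epoch of fs * second samples is cut into windows of time_window seconds advanced by time_step seconds, and only full-length windows are kept. A small sketch of the resulting frame count, using the values from the __main__ test in model.py (not the training defaults):

fs, second = 100, 30                 # 3,000 samples per 30-second epoch
time_window, time_step = 5, 0.5      # seconds
window = int(time_window * fs)       # 500 samples per frame
step = int(time_step * fs)           # 50-sample hop between frames
num_frames = (fs * second - window) // step + 1
print(window, step, num_frames)      # 500 50 51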
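Usage note (not part of the repository): the checkpoint written by Trainer.save_ckpt in pretrained/train.py stores the NeuroNet constructor kwargs under 'model_parameter' and the weights under 'model_state'; downstream/fine_tuning.py rebuilds the backbone the same way. Below is a minimal, hypothetical sketch of reloading such a checkpoint to extract per-epoch latent vectors with forward_latent; the checkpoint path, batch size, and random input are placeholders, not values taken from the repository.

# -*- coding:utf-8 -*-
import torch
from models.neuronet.model import NeuroNet

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Placeholder path: wherever <ckpt_path>/<model_name>/<fold>/model/best_model.pth was written.
ckpt = torch.load('ckpt/Sleep-EDFX/mini/0/model/best_model.pth', map_location='cpu')

# 'model_parameter' holds exactly the kwargs NeuroNet.__init__ expects (see Trainer.save_ckpt).
model = NeuroNet(**ckpt['model_parameter'])
model.load_state_dict(ckpt['model_state'])
model = model.to(device).eval()

# One batch of 30-second EEG epochs: shape (batch, fs * second).
fs, second = ckpt['model_parameter']['fs'], ckpt['model_parameter']['second']
x = torch.randn(8, fs * second, device=device)

with torch.no_grad():
    z = model.forward_latent(x)   # (batch, encoder_embed_dim) cls-token embedding
print(z.shape)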