├── models
│   ├── __init__.py
│   ├── neuronet
│   │   ├── __init__.py
│   │   ├── resnet1d.py
│   │   └── model.py
│   ├── loss.py
│   └── utils.py
├── dataset
│   ├── __init__.py
│   ├── utils.py
│   └── selected_shhs1_files.txt
├── downstream
│   ├── __init__.py
│   ├── linear_prob.py
│   └── fine_tuning.py
├── pretrained
│   ├── __init__.py
│   ├── data_loader.py
│   └── train.py
├── figures
│   ├── overview.jpg
│   ├── hypnogram.jpg
│   └── model_structure.jpg
├── requirements.txt
├── README.md
├── .gitignore
└── LICENSE
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/downstream/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pretrained/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/neuronet/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figures/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dlcjfgmlnasa/NeuroNet/HEAD/figures/overview.jpg
--------------------------------------------------------------------------------
/figures/hypnogram.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dlcjfgmlnasa/NeuroNet/HEAD/figures/hypnogram.jpg
--------------------------------------------------------------------------------
/figures/model_structure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dlcjfgmlnasa/NeuroNet/HEAD/figures/model_structure.jpg
--------------------------------------------------------------------------------
/dataset/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import os
3 | import numpy as np
4 | from sklearn.model_selection import KFold
5 | from scipy.signal import butter, lfilter
6 |
7 |
8 | def butter_bandpass_filter(signal, low_cut, high_cut, fs, order=5):
9 | if low_cut == 0:
10 | low_cut = 0.5
11 | nyq = 0.5 * fs
12 | low = low_cut / nyq
13 | high = high_cut / nyq
14 | b, a = butter(order, [low, high], btype='band')
15 | y = lfilter(b, a, signal, axis=-1)
16 | return y
17 |
18 |
19 | def split_train_test_val_files(base_path, n_splits=5):
20 | # Subject Variability
21 | files = os.listdir(base_path)
22 | files = np.array(files)
23 |
24 | size = len(files)
25 | print('File Path => ' + base_path)
26 | print('Total Subject Size => {}'.format(size))
27 | kf = KFold(n_splits=n_splits)
28 |
29 | temp = {}
30 | for fold, (train_idx, test_idx) in enumerate(kf.split(files)):
31 | train_size = len(train_idx)
32 | val_point = int(train_size * 0.85)
33 | train_idx, val_idx = train_idx[:val_point], train_idx[val_point:]
34 |
35 | temp[fold] = {
36 | 'train_paths': list([os.path.join(base_path, f_name) for f_name in files[train_idx]]),
37 | 'ft_paths': list([os.path.join(base_path, f_name) for f_name in files[val_idx]]),
38 | 'eval_paths': list([os.path.join(base_path, f_name) for f_name in files[test_idx]])
39 | }
40 | return temp
41 |
--------------------------------------------------------------------------------
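A minimal usage sketch for the `dataset/utils.py` helpers above; the data directory and the per-subject `.npz` layout below are hypothetical.

```python
import numpy as np
from dataset.utils import butter_bandpass_filter, split_train_test_val_files

# Band-pass a synthetic 30-second EEG epoch (100 Hz) to the 0.5-40 Hz sleep band.
epoch = np.random.randn(3000)
filtered = butter_bandpass_filter(epoch, low_cut=0.5, high_cut=40, fs=100)

# Build the subject-wise 5-fold split; each fold holds pretraining ('train_paths'),
# fine-tuning ('ft_paths'), and evaluation ('eval_paths') file lists.
folds = split_train_test_val_files(base_path='./data/Sleep-EDFX-2018', n_splits=5)
print(len(folds[0]['train_paths']), len(folds[0]['ft_paths']), len(folds[0]['eval_paths']))
```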
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2024.2.2
2 | charset-normalizer==3.3.2
3 | contourpy==1.2.1
4 | cycler==0.12.1
5 | decorator==5.1.1
6 | einops==0.8.0
7 | filelock==3.14.0
8 | fonttools==4.51.0
9 | fsspec==2024.3.1
10 | huggingface-hub==0.23.0
11 | idna==3.7
12 | importlib_resources==6.4.0
13 | Jinja2==3.1.4
14 | joblib==1.4.2
15 | kiwisolver==1.4.5
16 | lazy_loader==0.4
17 | mamba-ssm==1.2.0.post1
18 | MarkupSafe==2.1.5
19 | matplotlib==3.8.4
20 | mne==1.7.0
21 | mpmath==1.3.0
22 | networkx==3.2.1
23 | ninja==1.11.1.1
24 | numpy==1.26.4
25 | nvidia-cublas-cu12==12.1.3.1
26 | nvidia-cuda-cupti-cu12==12.1.105
27 | nvidia-cuda-nvrtc-cu12==12.1.105
28 | nvidia-cuda-runtime-cu12==12.1.105
29 | nvidia-cudnn-cu12==8.9.2.26
30 | nvidia-cufft-cu12==11.0.2.54
31 | nvidia-curand-cu12==10.3.2.106
32 | nvidia-cusolver-cu12==11.4.5.107
33 | nvidia-cusparse-cu12==12.1.0.106
34 | nvidia-nccl-cu12==2.19.3
35 | nvidia-nvjitlink-cu12==12.4.127
36 | nvidia-nvtx-cu12==12.1.105
37 | packaging==24.0
38 | pillow==10.3.0
39 | platformdirs==4.2.1
40 | pooch==1.8.1
41 | pyparsing==3.1.2
42 | python-dateutil==2.9.0.post0
43 | PyYAML==6.0.1
44 | regex==2024.5.10
45 | requests==2.31.0
46 | safetensors==0.4.3
47 | scikit-learn==1.4.2
48 | scipy==1.13.0
49 | six==1.16.0
50 | sympy==1.12
51 | threadpoolctl==3.5.0
52 | timm==0.9.16
53 | tokenizers==0.19.1
54 | torch==2.2.2
55 | torchvision==0.17.2
56 | tqdm==4.66.4
57 | transformers==4.40.2
58 | triton==2.2.0
59 | typing_extensions==4.11.0
60 | urllib3==2.2.1
61 | zipp==3.18.1
62 |
--------------------------------------------------------------------------------
/models/loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as f
6 |
7 |
8 | class NTXentLoss(nn.Module):
9 | def __init__(self, temperature):
10 | super().__init__()
11 | self.criterion = nn.CrossEntropyLoss(reduction='sum')
12 | self.similarity_f = nn.CosineSimilarity(dim=-1)
13 | self.temperature = temperature
14 |
15 | @staticmethod
16 | def mask_correlated_samples(batch_size):
17 | n = 2 * batch_size
18 | mask = torch.ones((n, n), dtype=bool)
19 | mask = mask.fill_diagonal_(0)
20 |
21 | for i in range(batch_size):
22 | mask[i, batch_size + i] = 0
23 | mask[batch_size + i, i] = 0
24 | return mask
25 |
26 | def forward(self, z_i, z_j):
27 | batch_size = z_j.shape[0]
28 | n = 2 * batch_size
29 | z = torch.cat((z_i, z_j), dim=0)
30 | z = f.normalize(z, dim=-1)
31 |
32 | mask = self.mask_correlated_samples(batch_size)
33 | sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0))
34 |
35 | sim_i_j = torch.diag(sim, batch_size)
36 | sim_j_i = torch.diag(sim, -batch_size)
37 |
38 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(n, 1)
39 | negative_samples = sim[mask].reshape(n, -1)
40 |
41 | labels = torch.from_numpy(np.array([0] * n)).reshape(-1).to(positive_samples.device).long() # .float()
42 | logits = torch.cat((positive_samples, negative_samples), dim=1) / self.temperature
43 |
44 | loss = self.criterion(logits, labels)
45 | loss /= n
46 | return loss, (labels, logits)
47 |
--------------------------------------------------------------------------------
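A minimal sketch of how the `NTXentLoss` above is typically driven; `z_i` and `z_j` stand in for projection-head outputs of two views of the same batch, and the shapes are illustrative.

```python
import torch
from models.loss import NTXentLoss

criterion = NTXentLoss(temperature=0.05)
z_i = torch.randn(8, 128)  # projections of view 1 (batch of 8, 128-dim)
z_j = torch.randn(8, 128)  # projections of view 2 of the same 8 samples
loss, (labels, logits) = criterion(z_i, z_j)
print(loss.item(), logits.shape)  # logits: (16, 15) -> positive column + 14 negatives per row
```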
/pretrained/data_loader.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import mne
3 | import torch
4 | import random
5 | import numpy as np
6 | from torch.utils.data import Dataset
7 | import warnings
8 |
9 | warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
10 |
11 | random_seed = 777
12 | np.random.seed(random_seed)
13 | torch.manual_seed(random_seed)
14 | random.seed(random_seed)
15 |
16 |
17 | class TorchDataset(Dataset):
18 | def __init__(self, paths, sfreq, rfreq, scaler: bool = False):
19 | super().__init__()
20 | self.x, self.y = self.get_data(paths, sfreq, rfreq, scaler)
21 | self.x, self.y = torch.tensor(self.x, dtype=torch.float32), torch.tensor(self.y, dtype=torch.long)
22 |
23 | @staticmethod
24 | def get_data(paths, sfreq, rfreq, scaler_flag):
25 | info = mne.create_info(sfreq=sfreq, ch_types='eeg', ch_names=['Fp1'])
26 | scaler = mne.decoding.Scaler(info=info, scalings='median')
27 | total_x, total_y = [], []
28 | for path in paths:
29 | data = np.load(path)
30 | x, y = data['x'], data['y']
31 | x = np.expand_dims(x, axis=1)
32 | if scaler_flag:
33 | x = scaler.fit_transform(x)
34 | x = mne.EpochsArray(x, info=info)
35 | x = x.resample(rfreq)
36 | x = x.get_data().squeeze()
37 | total_x.append(x)
38 | total_y.append(y)
39 | total_x, total_y = np.concatenate(total_x), np.concatenate(total_y)
40 | return total_x, total_y
41 |
42 | def __len__(self):
43 | return len(self.y)
44 |
45 | def __getitem__(self, item):
46 | x = torch.tensor(self.x[item])
47 | y = torch.tensor(self.y[item])
48 | return x, y
49 |
--------------------------------------------------------------------------------
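A minimal sketch of wrapping the `TorchDataset` above in a `DataLoader`; the `.npz` paths below are hypothetical and are assumed to hold an epoch array `x` of shape `(n_epochs, n_samples)` and integer labels `y`.

```python
from torch.utils.data import DataLoader
from pretrained.data_loader import TorchDataset

paths = ['./data/subject-01.npz', './data/subject-02.npz']  # hypothetical files
dataset = TorchDataset(paths=paths, sfreq=100, rfreq=100, scaler=False)
loader = DataLoader(dataset, batch_size=256, shuffle=True)
for x, y in loader:
    print(x.shape, y.shape)  # (batch, n_samples), (batch,)
    break
```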
/models/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import torch
3 | import numpy as np
4 |
5 |
6 | def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
7 | grid = np.arange(grid_size, dtype=float)
8 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
9 | if cls_token:
10 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
11 | return pos_embed
12 |
13 |
14 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
15 | assert embed_dim % 2 == 0
16 | omega = np.arange(embed_dim // 2, dtype=float)
17 | omega /= embed_dim / 2.
18 | omega = 1. / 10000 ** omega # (D/2,)
19 |
20 | pos = pos.reshape(-1) # (M,)
21 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
22 |
23 | emb_sin = np.sin(out) # (M, D/2)
24 | emb_cos = np.cos(out) # (M, D/2)
25 |
26 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
27 | return emb
28 |
29 |
30 | def get_2d_sincos_pos_embed(embed_dim, grid_sizes, cls_token=False):
31 | """
32 |     grid_sizes: (grid_height, grid_width) of the grid
33 |     return:
34 |     pos_embed: [grid_height*grid_width, embed_dim] or [1+grid_height*grid_width, embed_dim] (w/ or w/o cls_token)
35 | """
36 | grid_h = np.arange(grid_sizes[0], dtype=np.float32)
37 | grid_w = np.arange(grid_sizes[1], dtype=np.float32)
38 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
39 | grid = np.stack(grid, axis=0)
40 |
41 | grid = grid.reshape([2, 1, grid_sizes[0], grid_sizes[1]])
42 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
43 | if cls_token:
44 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
45 | return pos_embed
46 |
47 |
48 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
49 | assert embed_dim % 2 == 0
50 |
51 | # use half of dimensions to encode grid_h
52 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
53 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
54 |
55 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
56 | return emb
57 |
58 |
59 | def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
60 | """
61 |     grid_size: (grid_height, grid_width) of the grid
62 |     return:
63 |     pos_embed: [grid_height*grid_width, embed_dim] or [1+grid_height*grid_width, embed_dim] (w/ or w/o cls_token)
64 | """
65 | grid_h = np.arange(grid_size[0], dtype=np.float32)
66 | grid_w = np.arange(grid_size[1], dtype=np.float32)
67 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
68 | grid = np.stack(grid, axis=0)
69 |
70 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
71 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
72 | if cls_token:
73 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
74 | return pos_embed
75 |
76 |
77 | def model_size(model):
78 | size_model = 0
79 | for param in model.parameters():
80 | if param.data.is_floating_point():
81 | size_model += param.numel() * torch.finfo(param.data.dtype).bits
82 | else:
83 | size_model += param.numel() * torch.iinfo(param.data.dtype).bits
84 | mb_size = size_model / 8e6
85 | return mb_size
86 |
--------------------------------------------------------------------------------
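A minimal sketch of the sin-cos positional-embedding helpers above; the dimensions are illustrative.

```python
from models.utils import get_1d_sincos_pos_embed, get_2d_sincos_pos_embed

pe_1d = get_1d_sincos_pos_embed(embed_dim=768, grid_size=77, cls_token=True)
print(pe_1d.shape)  # (1 + 77, 768): one row per position plus the [CLS] slot

pe_2d = get_2d_sincos_pos_embed(embed_dim=768, grid_sizes=(8, 4))
print(pe_2d.shape)  # (8 * 4, 768): half the channels encode height, half encode width
```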
/README.md:
--------------------------------------------------------------------------------
1 | # NeuroNet: A Novel Hybrid Self-Supervised Learning Framework for Sleep Stage Classification Using Single-Channel EEG
2 |
3 | >**Cheol-Hui Lee, Hakseung Kim, Hyun-jee Han, Min-Kyung Jung, Byung C. Yoon and Dong-Joo Kim**
4 |
5 | [[`Paper`](https://arxiv.org/abs/2404.17585)] [[`Paper with Code`](https://paperswithcode.com/paper/neuronet-a-novel-hybrid-self-supervised)] [[`BibTeX`](#license-and-citation)]
6 |
7 | **Full code coming soon^^**
8 |
9 | 
10 |
11 | ## Introduction 🔥
12 | The classification of sleep stages is a pivotal aspect of diagnosing sleep disorders and evaluating sleep quality. However, the conventional manual scoring process, conducted by clinicians, is time-consuming and prone to human bias. Recent advancements in deep learning have substantially propelled the automation of sleep stage classification. Nevertheless, challenges persist, including the need for large datasets with labels and the inherent biases in human-generated annotations. This paper introduces NeuroNet, a self-supervised learning (SSL) framework designed to effectively harness unlabeled single-channel sleep electroencephalogram (EEG) signals by integrating contrastive learning tasks and masked prediction tasks. NeuroNet demonstrates superior performance over existing SSL methodologies through extensive experimentation conducted across three polysomnography (PSG) datasets. Additionally, this study proposes a Mamba-based temporal context module to capture the relationships among diverse EEG epochs. Combining NeuroNet with the Mamba-based temporal context module has demonstrated the capability to achieve, or even surpass, the performance of the latest supervised learning methodologies, even with a limited amount of labeled data. This study is expected to establish a new benchmark in sleep stage classification, promising to guide future research and applications in the field of sleep analysis.
13 |
14 | ## Main Result 🥇
15 |
16 |
17 | Performance on Sleep-EDFX across various self-supervised and supervised learning methods.
18 |
19 |
20 |
21 |
22 | The output hypnograms across five sleep stages.
23 |
24 | The first, second, and third columns correspond to #sc4031e0, #shhs1-204928, and #subject-53 within Sleep-EDFX, SHHS, and ISRUC, respectively. (A) is manually scored by a sleep expert. (B) and (C) represent NeuroNet-B and NeuroNet-T, respectively. The first row of both (B) and (C) shows the results for NeuroNet+TCM, while the second row shows the results for NeuroNet. Errors are marked with red dots.
25 |
26 |
27 |
28 | ## License and Citation 📰
29 | The software is licensed under the Apache License 2.0. Please cite the following paper if you use this code:
30 | ```
31 | @misc{lee2024neuronet,
32 | title={NeuroNet: A Novel Hybrid Self-Supervised Learning Framework for Sleep Stage Classification Using Single-Channel EEG},
33 | author={Cheol-Hui Lee and Hakseung Kim and Hyun-jee Han and Min-Kyung Jung and Byung C. Yoon and Dong-Joo Kim},
34 | year={2024},
35 | eprint={2404.17585},
36 | archivePrefix={arXiv},
37 | primaryClass={cs.HC}
38 | }
39 | ```
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | .idea/
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/#use-with-ide
111 | .pdm.toml
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 |
--------------------------------------------------------------------------------
/dataset/selected_shhs1_files.txt:
--------------------------------------------------------------------------------
1 | shhs1-200010
2 | shhs1-200017
3 | shhs1-200039
4 | shhs1-200052
5 | shhs1-200097
6 | shhs1-200122
7 | shhs1-200139
8 | shhs1-200152
9 | shhs1-200166
10 | shhs1-200174
11 | shhs1-200178
12 | shhs1-200214
13 | shhs1-200215
14 | shhs1-200258
15 | shhs1-200319
16 | shhs1-200336
17 | shhs1-200348
18 | shhs1-200349
19 | shhs1-200383
20 | shhs1-200390
21 | shhs1-200408
22 | shhs1-200409
23 | shhs1-200437
24 | shhs1-200466
25 | shhs1-200467
26 | shhs1-200477
27 | shhs1-200496
28 | shhs1-200516
29 | shhs1-200536
30 | shhs1-200558
31 | shhs1-200569
32 | shhs1-200606
33 | shhs1-200631
34 | shhs1-200645
35 | shhs1-200665
36 | shhs1-200691
37 | shhs1-200701
38 | shhs1-200702
39 | shhs1-200725
40 | shhs1-200737
41 | shhs1-200741
42 | shhs1-200771
43 | shhs1-200780
44 | shhs1-200783
45 | shhs1-200785
46 | shhs1-200793
47 | shhs1-200805
48 | shhs1-200823
49 | shhs1-200834
50 | shhs1-200843
51 | shhs1-200853
52 | shhs1-200869
53 | shhs1-200870
54 | shhs1-200879
55 | shhs1-200884
56 | shhs1-200886
57 | shhs1-200895
58 | shhs1-200897
59 | shhs1-200902
60 | shhs1-200914
61 | shhs1-200920
62 | shhs1-200922
63 | shhs1-200924
64 | shhs1-200953
65 | shhs1-200964
66 | shhs1-200975
67 | shhs1-200983
68 | shhs1-201021
69 | shhs1-201023
70 | shhs1-201081
71 | shhs1-201087
72 | shhs1-201090
73 | shhs1-201102
74 | shhs1-201117
75 | shhs1-201118
76 | shhs1-201125
77 | shhs1-201153
78 | shhs1-201171
79 | shhs1-201180
80 | shhs1-201237
81 | shhs1-201243
82 | shhs1-201254
83 | shhs1-201270
84 | shhs1-201316
85 | shhs1-201329
86 | shhs1-201359
87 | shhs1-201371
88 | shhs1-201515
89 | shhs1-201552
90 | shhs1-201560
91 | shhs1-201566
92 | shhs1-201581
93 | shhs1-201586
94 | shhs1-201637
95 | shhs1-201706
96 | shhs1-201725
97 | shhs1-201748
98 | shhs1-201792
99 | shhs1-201839
100 | shhs1-201917
101 | shhs1-201946
102 | shhs1-201986
103 | shhs1-202003
104 | shhs1-202039
105 | shhs1-202054
106 | shhs1-202097
107 | shhs1-202108
108 | shhs1-202139
109 | shhs1-202160
110 | shhs1-202222
111 | shhs1-202240
112 | shhs1-202253
113 | shhs1-202257
114 | shhs1-202261
115 | shhs1-202267
116 | shhs1-202282
117 | shhs1-202293
118 | shhs1-202392
119 | shhs1-202409
120 | shhs1-202424
121 | shhs1-202516
122 | shhs1-202552
123 | shhs1-202589
124 | shhs1-202636
125 | shhs1-202650
126 | shhs1-202675
127 | shhs1-202694
128 | shhs1-202715
129 | shhs1-202723
130 | shhs1-202735
131 | shhs1-202738
132 | shhs1-202865
133 | shhs1-202907
134 | shhs1-202911
135 | shhs1-202940
136 | shhs1-202956
137 | shhs1-202983
138 | shhs1-202987
139 | shhs1-202996
140 | shhs1-203027
141 | shhs1-203040
142 | shhs1-203106
143 | shhs1-203137
144 | shhs1-203168
145 | shhs1-203171
146 | shhs1-203173
147 | shhs1-203180
148 | shhs1-203185
149 | shhs1-203203
150 | shhs1-203218
151 | shhs1-203233
152 | shhs1-203250
153 | shhs1-203254
154 | shhs1-203286
155 | shhs1-203312
156 | shhs1-203324
157 | shhs1-203350
158 | shhs1-203358
159 | shhs1-203380
160 | shhs1-203381
161 | shhs1-203389
162 | shhs1-203409
163 | shhs1-203410
164 | shhs1-203421
165 | shhs1-203433
166 | shhs1-203460
167 | shhs1-203478
168 | shhs1-203479
169 | shhs1-203514
170 | shhs1-203523
171 | shhs1-203533
172 | shhs1-203546
173 | shhs1-203547
174 | shhs1-203580
175 | shhs1-203614
176 | shhs1-203622
177 | shhs1-203651
178 | shhs1-203655
179 | shhs1-203687
180 | shhs1-203691
181 | shhs1-203699
182 | shhs1-203726
183 | shhs1-203734
184 | shhs1-203746
185 | shhs1-203791
186 | shhs1-203824
187 | shhs1-203845
188 | shhs1-203847
189 | shhs1-203850
190 | shhs1-203862
191 | shhs1-203870
192 | shhs1-203884
193 | shhs1-203889
194 | shhs1-203897
195 | shhs1-203901
196 | shhs1-203941
197 | shhs1-203965
198 | shhs1-203971
199 | shhs1-203997
200 | shhs1-204009
201 | shhs1-204027
202 | shhs1-204028
203 | shhs1-204041
204 | shhs1-204042
205 | shhs1-204047
206 | shhs1-204051
207 | shhs1-204062
208 | shhs1-204074
209 | shhs1-204083
210 | shhs1-204094
211 | shhs1-204103
212 | shhs1-204106
213 | shhs1-204107
214 | shhs1-204111
215 | shhs1-204115
216 | shhs1-204132
217 | shhs1-204134
218 | shhs1-204145
219 | shhs1-204148
220 | shhs1-204149
221 | shhs1-204165
222 | shhs1-204187
223 | shhs1-204189
224 | shhs1-204219
225 | shhs1-204230
226 | shhs1-204234
227 | shhs1-204282
228 | shhs1-204283
229 | shhs1-204294
230 | shhs1-204298
231 | shhs1-204299
232 | shhs1-204317
233 | shhs1-204318
234 | shhs1-204350
235 | shhs1-204351
236 | shhs1-204355
237 | shhs1-204401
238 | shhs1-204405
239 | shhs1-204431
240 | shhs1-204434
241 | shhs1-204453
242 | shhs1-204473
243 | shhs1-204486
244 | shhs1-204494
245 | shhs1-204500
246 | shhs1-204509
247 | shhs1-204519
248 | shhs1-204529
249 | shhs1-204543
250 | shhs1-204550
251 | shhs1-204553
252 | shhs1-204554
253 | shhs1-204576
254 | shhs1-204611
255 | shhs1-204641
256 | shhs1-204666
257 | shhs1-204684
258 | shhs1-204688
259 | shhs1-204706
260 | shhs1-204710
261 | shhs1-204714
262 | shhs1-204724
263 | shhs1-204731
264 | shhs1-204737
265 | shhs1-204741
266 | shhs1-204743
267 | shhs1-204745
268 | shhs1-204748
269 | shhs1-204751
270 | shhs1-204754
271 | shhs1-204775
272 | shhs1-204777
273 | shhs1-204781
274 | shhs1-204783
275 | shhs1-204785
276 | shhs1-204787
277 | shhs1-204796
278 | shhs1-204802
279 | shhs1-204811
280 | shhs1-204817
281 | shhs1-204818
282 | shhs1-204826
283 | shhs1-204830
284 | shhs1-204831
285 | shhs1-204836
286 | shhs1-204844
287 | shhs1-204846
288 | shhs1-204851
289 | shhs1-204865
290 | shhs1-204873
291 | shhs1-204879
292 | shhs1-204887
293 | shhs1-204890
294 | shhs1-204894
295 | shhs1-204905
296 | shhs1-204908
297 | shhs1-204911
298 | shhs1-204916
299 | shhs1-204927
300 | shhs1-204928
301 | shhs1-204934
302 | shhs1-204980
303 | shhs1-204985
304 | shhs1-205045
305 | shhs1-205049
306 | shhs1-205071
307 | shhs1-205083
308 | shhs1-205146
309 | shhs1-205148
310 | shhs1-205178
311 | shhs1-205207
312 | shhs1-205213
313 | shhs1-205238
314 | shhs1-205244
315 | shhs1-205245
316 | shhs1-205257
317 | shhs1-205380
318 | shhs1-205429
319 | shhs1-205451
320 | shhs1-205477
321 | shhs1-205502
322 | shhs1-205516
323 | shhs1-205582
324 | shhs1-205610
325 | shhs1-205676
326 | shhs1-205700
327 | shhs1-205741
328 | shhs1-205780
329 | shhs1-205789
--------------------------------------------------------------------------------
/downstream/linear_prob.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import os
3 | import mne
4 | import torch
5 | import random
6 | import argparse
7 | import warnings
8 | import numpy as np
9 | import torch.nn as nn
10 | from typing import List
11 | import torch.optim as opt
12 | from torch.utils.data import Dataset, DataLoader
13 | from sklearn.metrics import accuracy_score, f1_score
14 | from models.neuronet.model import NeuroNet, NeuroNetEncoderWrapper
15 |
16 |
17 | warnings.filterwarnings(action='ignore')
18 |
19 |
20 | random_seed = 777
21 | np.random.seed(random_seed)
22 | torch.manual_seed(random_seed)
23 | random.seed(random_seed)
24 | torch.backends.cudnn.deterministic = True
25 | torch.backends.cudnn.benchmark = False
26 |
27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
28 |
29 |
30 | def get_args():
31 | file_name = 'mini'
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--n_fold', default=0, choices=[0, 1, 2, 3, 4])
34 | parser.add_argument('--ckpt_path', default=os.path.join('..', '..', '..', 'ckpt',
35 | 'ISRUC-Sleep', 'cm_eeg', file_name), type=str)
36 | parser.add_argument('--epochs', default=300, type=int)
37 | parser.add_argument('--batch_size', default=512, type=int)
38 | parser.add_argument('--lr', default=0.00005, type=float)
39 | return parser.parse_args()
40 |
41 |
42 | class Classifier(nn.Module):
43 | def __init__(self, backbone, backbone_final_length):
44 | super().__init__()
45 | self.backbone = self.freeze_backbone(backbone)
46 | self.backbone_final_length = backbone_final_length
47 | self.feature_num = self.backbone_final_length * 2
48 | self.dropout_p = 0.5
49 | self.fc = nn.Sequential(
50 | nn.Linear(backbone_final_length, self.feature_num),
51 | nn.BatchNorm1d(self.feature_num),
52 | nn.ELU(),
53 | nn.Dropout(self.dropout_p),
54 | nn.Linear(self.feature_num, 5)
55 | )
56 |
57 | def forward(self, x):
58 | x = self.backbone(x)
59 | x = self.fc(x)
60 | return x
61 |
62 | @staticmethod
63 | def freeze_backbone(backbone: nn.Module):
64 | for name, module in backbone.named_modules():
65 | for param in module.parameters():
66 | param.requires_grad = False
67 | return backbone
68 |
69 |
70 | class Trainer(object):
71 | def __init__(self, args):
72 | super().__init__()
73 | self.args = args
74 | self.ckpt_path = os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'model', 'best_model.pth')
75 | self.ckpt = torch.load(self.ckpt_path, map_location='cpu')
76 | self.sfreq, self.rfreq = self.ckpt['hyperparameter']['sfreq'], self.ckpt['hyperparameter']['rfreq']
77 | self.ft_paths, self.eval_paths = self.ckpt['paths']['ft_paths'], self.ckpt['paths']['eval_paths']
78 | self.model = self.get_pretrained_model().to(device)
79 | self.optimizer = opt.AdamW(self.model.parameters(), lr=self.args.lr)
80 | self.scheduler = opt.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.args.epochs)
81 | self.criterion = nn.CrossEntropyLoss()
82 |
83 | def train(self):
84 | print('Checkpoint File Path : {}'.format(self.ckpt_path))
85 | train_dataset = TorchDataset(paths=self.ft_paths, sfreq=self.sfreq, rfreq=self.rfreq)
86 | train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.args.batch_size,
87 | shuffle=True, drop_last=True)
88 | eval_dataset = TorchDataset(paths=self.eval_paths, sfreq=self.sfreq, rfreq=self.rfreq)
89 | eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=self.args.batch_size, drop_last=False)
90 |
91 | best_model_state, best_mf1 = None, 0.0
92 | best_pred, best_real = [], []
93 |
94 | for epoch in range(self.args.epochs):
95 | self.model.train()
96 | epoch_train_loss = []
97 | for data in train_dataloader:
98 | self.optimizer.zero_grad()
99 | x, y = data
100 | x, y = x.to(device), y.to(device)
101 |
102 | pred = self.model(x)
103 | loss = self.criterion(pred, y)
104 |
105 | epoch_train_loss.append(float(loss.detach().cpu().item()))
106 | loss.backward()
107 | self.optimizer.step()
108 |
109 | self.model.eval()
110 | epoch_test_loss = []
111 | epoch_real, epoch_pred = [], []
112 | for data in eval_dataloader:
113 | with torch.no_grad():
114 | x, y = data
115 | x, y = x.to(device), y.to(device)
116 | pred = self.model(x)
117 | loss = self.criterion(pred, y)
118 | pred = pred.argmax(dim=-1)
119 | real = y
120 |
121 | epoch_real.extend(list(real.detach().cpu().numpy()))
122 | epoch_pred.extend(list(pred.detach().cpu().numpy()))
123 | epoch_test_loss.append(float(loss.detach().cpu().item()))
124 |
125 | epoch_train_loss, epoch_test_loss = np.mean(epoch_train_loss), np.mean(epoch_test_loss)
126 | eval_acc, eval_mf1 = accuracy_score(y_true=epoch_real, y_pred=epoch_pred), \
127 | f1_score(y_true=epoch_real, y_pred=epoch_pred, average='macro')
128 |
129 | print('[Epoch] : {0:03d} \t '
130 | '[Train Loss] => {1:.4f} \t '
131 | '[Evaluation Loss] => {2:.4f} \t '
132 | '[Evaluation Accuracy] => {3:.4f} \t'
133 | '[Evaluation Macro-F1] => {4:.4f}'.format(epoch + 1, epoch_train_loss, epoch_test_loss,
134 | eval_acc, eval_mf1))
135 |
136 | if best_mf1 < eval_mf1:
137 | best_mf1 = eval_mf1
138 | best_model_state = self.model.state_dict()
139 | best_pred, best_real = epoch_pred, epoch_real
140 |
141 | self.scheduler.step()
142 |
143 | self.save_ckpt(best_model_state, best_pred, best_real)
144 |
145 | def save_ckpt(self, model_state, pred, real):
146 | if not os.path.exists(os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'linear_prob')):
147 | os.makedirs(os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'linear_prob'))
148 |
149 | save_path = os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'linear_prob', 'best_model.pth')
150 | torch.save({
151 | 'backbone_name': 'NeuroNet_LinearProb',
152 | 'model_state': model_state,
153 | 'hyperparameter': self.args.__dict__,
154 | 'result': {'real': real, 'pred': pred},
155 | 'paths': {'train_paths': self.ft_paths, 'eval_paths': self.eval_paths}
156 | }, save_path)
157 |
158 | def get_pretrained_model(self):
159 |         # 1. Prepare the pretrained model
160 | model_parameter = self.ckpt['model_parameter']
161 | pretrained_model = NeuroNet(**model_parameter)
162 | pretrained_model.load_state_dict(self.ckpt['model_state'])
163 |
164 | # 2. Encoder Wrapper
165 | backbone = NeuroNetEncoderWrapper(
166 | fs=model_parameter['fs'], second=model_parameter['second'],
167 | time_window=model_parameter['time_window'], time_step=model_parameter['time_step'],
168 | frame_backbone=pretrained_model.frame_backbone,
169 | patch_embed=pretrained_model.autoencoder.patch_embed,
170 | encoder_block=pretrained_model.autoencoder.encoder_block,
171 | encoder_norm=pretrained_model.autoencoder.encoder_norm,
172 | cls_token=pretrained_model.autoencoder.cls_token,
173 | pos_embed=pretrained_model.autoencoder.pos_embed,
174 | final_length=pretrained_model.autoencoder.embed_dim
175 | )
176 |
177 |         # 3. Generate the classifier on top of the frozen encoder
178 | model = Classifier(backbone=backbone,
179 | backbone_final_length=pretrained_model.autoencoder.embed_dim)
180 | return model
181 |
182 |
183 | class TorchDataset(Dataset):
184 | def __init__(self, paths: List, sfreq: int, rfreq: int):
185 | self.paths = paths
186 | self.info = mne.create_info(sfreq=sfreq, ch_types='eeg', ch_names=['Fp1'])
187 | self.xs, self.ys = self.get_data(rfreq)
188 |
189 | def __len__(self):
190 | return self.xs.shape[0]
191 |
192 | def get_data(self, rfreq):
193 | xs, ys = [], []
194 | for path in self.paths:
195 | data = np.load(path)
196 | x, y = data['x'], data['y']
197 | x = np.expand_dims(x, axis=1)
198 | x = mne.EpochsArray(x, info=self.info)
199 | x = x.resample(rfreq)
200 | x = x.get_data().squeeze()
201 | xs.append(x)
202 | ys.append(y)
203 | xs = np.concatenate(xs, axis=0)
204 | ys = np.concatenate(ys, axis=0)
205 | return xs, ys
206 |
207 | def __getitem__(self, idx):
208 | x = torch.tensor(self.xs[idx], dtype=torch.float)
209 | y = torch.tensor(self.ys[idx], dtype=torch.long)
210 | return x, y
211 |
212 |
213 | if __name__ == '__main__':
214 | augments = get_args()
215 | for n_fold in range(10):
216 | augments.n_fold = n_fold
217 | trainer = Trainer(augments)
218 | trainer.train()
219 |
--------------------------------------------------------------------------------
/models/neuronet/resnet1d.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class FrameBackBone(nn.Module):
7 | def __init__(self, fs: int, window: int):
8 | super().__init__()
9 | self.model = BackBone(input_size=fs * window, input_channel=1, layers=[1, 1, 1, 1])
10 | self.feature_num = self.model.get_final_length() // 2
11 | self.feature_layer = nn.Sequential(
12 | nn.Linear(self.model.get_final_length(), self.feature_num),
13 | nn.ELU(),
14 | nn.Linear(self.feature_num, self.feature_num)
15 | )
16 |
17 | def forward(self, x):
18 | latent_seq = []
19 | for i in range(x.shape[1]):
20 | sample = torch.unsqueeze(x[:, i, :], dim=1)
21 | latent = self.model(sample)
22 | latent_seq.append(latent)
23 | latent_seq = torch.stack(latent_seq, dim=1)
24 | latent_seq = self.feature_layer(latent_seq)
25 | return latent_seq
26 |
27 |
28 | class BackBone(nn.Module):
29 | def __init__(self, input_size, input_channel, layers):
30 | super().__init__()
31 | self.inplanes3 = 32
32 | self.inplanes5 = 32
33 | self.inplanes7 = 32
34 |
35 | self.input_size = input_size
36 | self.conv1 = nn.Conv1d(input_channel, 32, kernel_size=7, stride=2, padding=3, bias=False)
37 | self.bn1 = nn.BatchNorm1d(32)
38 | self.relu = nn.ELU(inplace=True)
39 | self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
40 |
41 | self.layer3x3_1 = self._make_layer3(BasicBlock3x3, 32, layers[0], stride=1)
42 | self.layer3x3_2 = self._make_layer3(BasicBlock3x3, 32, layers[1], stride=1)
43 | self.layer3x3_3 = self._make_layer3(BasicBlock3x3, 48, layers[2], stride=2)
44 | self.layer3x3_4 = self._make_layer3(BasicBlock3x3, 64, layers[3], stride=2)
45 | self.maxpool3 = nn.AvgPool1d(kernel_size=16, stride=1, padding=0)
46 |
47 | self.layer5x5_1 = self._make_layer5(BasicBlock5x5, 32, layers[0], stride=1)
48 | self.layer5x5_2 = self._make_layer5(BasicBlock5x5, 32, layers[1], stride=1)
49 | self.layer5x5_3 = self._make_layer5(BasicBlock5x5, 48, layers[2], stride=2)
50 | self.layer5x5_4 = self._make_layer5(BasicBlock5x5, 64, layers[3], stride=2)
51 | self.maxpool5 = nn.AvgPool1d(kernel_size=11, stride=1, padding=0)
52 |
53 | self.layer7x7_1 = self._make_layer7(BasicBlock7x7, 32, layers[0], stride=1)
54 | self.layer7x7_2 = self._make_layer7(BasicBlock7x7, 32, layers[1], stride=1)
55 | self.layer7x7_3 = self._make_layer7(BasicBlock7x7, 48, layers[2], stride=2)
56 | self.layer7x7_4 = self._make_layer7(BasicBlock7x7, 64, layers[3], stride=2)
57 | self.maxpool7 = nn.AvgPool1d(kernel_size=6, stride=1, padding=0)
58 |
59 | def forward(self, x0):
60 | b = x0.shape[0]
61 | x0 = self.conv1(x0)
62 | x0 = self.bn1(x0)
63 | x0 = self.relu(x0)
64 | x0 = self.maxpool(x0)
65 |
66 | x1 = self.layer3x3_1(x0)
67 | x1 = self.layer3x3_2(x1)
68 | x1 = self.layer3x3_3(x1)
69 | x1 = self.layer3x3_4(x1)
70 | x1 = self.maxpool3(x1)
71 |
72 | x2 = self.layer5x5_1(x0)
73 | x2 = self.layer5x5_2(x2)
74 | x2 = self.layer5x5_3(x2)
75 | x2 = self.layer5x5_4(x2)
76 | x2 = self.maxpool5(x2)
77 |
78 | x3 = self.layer7x7_1(x0)
79 | x3 = self.layer7x7_2(x3)
80 | x3 = self.layer7x7_3(x3)
81 | x3 = self.layer7x7_4(x3)
82 | x3 = self.maxpool7(x3)
83 |
84 | out = torch.cat([x1, x2, x3], dim=-1)
85 | out = torch.reshape(out, [b, -1])
86 | return out
87 |
88 | def _make_layer3(self, block, planes, blocks, stride=2):
89 | downsample = None
90 | if stride != 1 or self.inplanes3 != planes * block.expansion:
91 | downsample = nn.Sequential(
92 | nn.Conv1d(self.inplanes3, planes * block.expansion,
93 | kernel_size=1, stride=stride, bias=False),
94 | nn.BatchNorm1d(planes * block.expansion),
95 | )
96 |
97 | layers = list()
98 | layers.append(block(self.inplanes3, planes, stride, downsample))
99 | self.inplanes3 = planes * block.expansion
100 | for i in range(1, blocks):
101 | layers.append(block(self.inplanes3, planes))
102 |
103 | return nn.Sequential(*layers)
104 |
105 | def _make_layer5(self, block, planes, blocks, stride=2):
106 | downsample = None
107 | if stride != 1 or self.inplanes5 != planes * block.expansion:
108 | downsample = nn.Sequential(
109 | nn.Conv1d(self.inplanes5, planes * block.expansion,
110 | kernel_size=1, stride=stride, bias=False),
111 | nn.BatchNorm1d(planes * block.expansion),
112 | )
113 |
114 | layers = list()
115 | layers.append(block(self.inplanes5, planes, stride, downsample))
116 | self.inplanes5 = planes * block.expansion
117 | for i in range(1, blocks):
118 | layers.append(block(self.inplanes5, planes))
119 |
120 | return nn.Sequential(*layers)
121 |
122 | def _make_layer7(self, block, planes, blocks, stride=2):
123 | downsample = None
124 | if stride != 1 or self.inplanes7 != planes * block.expansion:
125 | downsample = nn.Sequential(
126 | nn.Conv1d(self.inplanes7, planes * block.expansion,
127 | kernel_size=1, stride=stride, bias=False),
128 | nn.BatchNorm1d(planes * block.expansion),
129 | )
130 |
131 | layers = list()
132 | layers.append(block(self.inplanes7, planes, stride, downsample))
133 | self.inplanes7 = planes * block.expansion
134 | for i in range(1, blocks):
135 | layers.append(block(self.inplanes7, planes))
136 |
137 | return nn.Sequential(*layers)
138 |
139 | def get_final_length(self):
140 | x = torch.randn(1, 1, self.input_size)
141 | x = self.forward(x)
142 | return x.shape[-1]
143 |
144 |
145 | class BasicBlock3x3(nn.Module):
146 | expansion = 1
147 |
148 | def __init__(self, inplanes3, planes, stride=1, downsample=None):
149 | super(BasicBlock3x3, self).__init__()
150 | self.conv1 = conv3x3(inplanes3, planes, stride)
151 | self.bn1 = nn.BatchNorm1d(planes)
152 | self.relu = nn.ELU(inplace=True)
153 | self.conv2 = conv3x3(planes, planes)
154 | self.bn2 = nn.BatchNorm1d(planes)
155 | self.downsample = downsample
156 | self.stride = stride
157 |
158 | def forward(self, x):
159 | residual = x
160 |
161 | out = self.conv1(x)
162 | out = self.bn1(out)
163 | out = self.relu(out)
164 |
165 | out = self.conv2(out)
166 | out = self.bn2(out)
167 |
168 | if self.downsample is not None:
169 | residual = self.downsample(x)
170 |
171 | out += residual
172 | out = self.relu(out)
173 |
174 | return out
175 |
176 |
177 | class BasicBlock5x5(nn.Module):
178 | expansion = 1
179 |
180 | def __init__(self, inplanes5, planes, stride=1, downsample=None):
181 | super(BasicBlock5x5, self).__init__()
182 | self.conv1 = conv5x5(inplanes5, planes, stride)
183 | self.bn1 = nn.BatchNorm1d(planes)
184 | self.relu = nn.ELU(inplace=True)
185 | self.conv2 = conv5x5(planes, planes)
186 | self.bn2 = nn.BatchNorm1d(planes)
187 | self.downsample = downsample
188 | self.stride = stride
189 |
190 | def forward(self, x):
191 | residual = x
192 |
193 | out = self.conv1(x)
194 | out = self.bn1(out)
195 | out = self.relu(out)
196 |
197 | out = self.conv2(out)
198 | out = self.bn2(out)
199 |
200 | if self.downsample is not None:
201 | residual = self.downsample(x)
202 |
203 | d = residual.shape[2] - out.shape[2]
204 | out1 = residual[:, :, 0:-d] + out
205 | out1 = self.relu(out1)
206 | return out1
207 |
208 |
209 | class BasicBlock7x7(nn.Module):
210 | expansion = 1
211 |
212 | def __init__(self, inplanes7, planes, stride=1, downsample=None):
213 | super(BasicBlock7x7, self).__init__()
214 | self.conv1 = conv7x7(inplanes7, planes, stride)
215 | self.bn1 = nn.BatchNorm1d(planes)
216 | self.relu = nn.ELU(inplace=True)
217 | self.conv2 = conv7x7(planes, planes)
218 | self.bn2 = nn.BatchNorm1d(planes)
219 | self.downsample = downsample
220 | self.stride = stride
221 |
222 | def forward(self, x):
223 | residual = x
224 |
225 | out = self.conv1(x)
226 | out = self.bn1(out)
227 | out = self.relu(out)
228 |
229 | out = self.conv2(out)
230 | out = self.bn2(out)
231 |
232 | if self.downsample is not None:
233 | residual = self.downsample(x)
234 |
235 | d = residual.shape[2] - out.shape[2]
236 | out1 = residual[:, :, 0:-d] + out
237 | out1 = self.relu(out1)
238 | return out1
239 |
240 |
241 | def conv3x3(in_planes, out_planes, stride=1):
242 | return nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride,
243 | padding=1, bias=False)
244 |
245 |
246 | def conv5x5(in_planes, out_planes, stride=1):
247 | return nn.Conv1d(in_planes, out_planes, kernel_size=5, stride=stride,
248 | padding=1, bias=False)
249 |
250 |
251 | def conv7x7(in_planes, out_planes, stride=1):
252 | return nn.Conv1d(in_planes, out_planes, kernel_size=7, stride=stride,
253 | padding=1, bias=False)
254 |
255 |
256 | if __name__ == '__main__':
257 | # st = ST_BackBone(input_size=500)
258 | # ss = st(
259 | # torch.randn(50, 1, 500)
260 | # )
261 | fb = FrameBackBone(fs=100, window=3)
262 | ss = fb(
263 | torch.randn(50, 1, 300)
264 | )
265 | print(ss.shape)
266 |
267 |
268 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/pretrained/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import os
3 | import sys
4 | sys.path.extend([os.path.abspath('.'), os.path.abspath('..')])
5 |
6 | import mne
7 | import torch
8 | import random
9 | import shutil
10 | import argparse
11 | import warnings
12 | import numpy as np
13 | import torch.optim as opt
14 | from models.utils import model_size
15 | from sklearn.decomposition import PCA
16 | from torch.utils.tensorboard import SummaryWriter
17 | from sklearn.neighbors import KNeighborsClassifier
18 | from sklearn.metrics import accuracy_score, f1_score
19 | from torch.utils.data import DataLoader
20 | from dataset.utils import split_train_test_val_files
21 | from pretrained.data_loader import TorchDataset
22 | from models.neuronet.model import NeuroNet
23 |
24 |
25 | warnings.filterwarnings(action='ignore')
26 |
27 |
28 | random_seed = 777
29 | np.random.seed(random_seed)
30 | torch.manual_seed(random_seed)
31 | random.seed(random_seed)
32 | torch.backends.cudnn.deterministic = True
33 | torch.backends.cudnn.benchmark = False
34 |
35 | mne.set_log_level(False)
36 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
37 |
38 |
39 | def get_args():
40 | parser = argparse.ArgumentParser()
41 | # Dataset
42 | parser.add_argument('--base_path', default=os.path.join('..', '..', '..', 'data', 'stage', 'Sleep-EDFX-2018'))
43 | parser.add_argument('--k_splits', default=5)
44 | parser.add_argument('--n_fold', default=0, choices=[0, 1, 2, 3, 4])
45 |
46 | # Dataset Hyperparameter
47 | parser.add_argument('--sfreq', default=100, type=int)
48 | parser.add_argument('--rfreq', default=100, type=int)
49 | parser.add_argument('--data_scaler', default=False, type=bool)
50 |
51 | # Train Hyperparameter
52 | parser.add_argument('--train_epochs', default=30, type=int)
53 | parser.add_argument('--train_warmup_epoch', type=int, default=100)
54 | parser.add_argument('--train_base_learning_rate', default=1e-5, type=float)
55 | parser.add_argument('--train_batch_size', default=256, type=int)
56 | parser.add_argument('--train_batch_accumulation', default=1, type=int)
57 |
58 | # Model Hyperparameter
59 | parser.add_argument('--second', default=30, type=int)
60 | parser.add_argument('--time_window', default=3, type=int)
61 |     parser.add_argument('--time_step', default=0.375, type=float)
62 |
63 | # >> 1. NeuroNet-M Hyperparameter
64 | # parser.add_argument('--encoder_dim', default=512, type=int)
65 | # parser.add_argument('--encoder_heads', default=8, type=int)
66 | # parser.add_argument('--encoder_depths', default=4, type=int)
67 | # parser.add_argument('--decoder_embed_dim', default=192, type=int)
68 | # parser.add_argument('--decoder_heads', default=8, type=int)
69 | # parser.add_argument('--decoder_depths', default=1, type=int)
70 |
71 | # >> 2. NeuroNet-B Hyperparameter
72 | parser.add_argument('--encoder_embed_dim', default=768, type=int)
73 | parser.add_argument('--encoder_heads', default=8, type=int)
74 | parser.add_argument('--encoder_depths', default=4, type=int)
75 | parser.add_argument('--decoder_embed_dim', default=256, type=int)
76 | parser.add_argument('--decoder_heads', default=8, type=int)
77 | parser.add_argument('--decoder_depths', default=3, type=int)
78 | parser.add_argument('--alpha', default=1.0, type=float)
79 |
80 |     parser.add_argument('--projection_hidden', default=[1024, 512], nargs='+', type=int)  # type=list would split a CLI value into characters
81 | parser.add_argument('--temperature', default=0.05, type=float)
82 | parser.add_argument('--mask_ratio', default=0.8, type=float)
83 | parser.add_argument('--print_point', default=20, type=int)
84 | parser.add_argument('--ckpt_path', default=os.path.join('..', '..', '..', 'ckpt', 'Sleep-EDFX'), type=str)
85 | parser.add_argument('--model_name', default='mini')
86 | return parser.parse_args()
87 |
88 |
89 | class Trainer(object):
90 | def __init__(self, args):
91 | self.args = args
92 | self.model = NeuroNet(
93 | fs=args.rfreq, second=args.second, time_window=args.time_window, time_step=args.time_step,
94 | encoder_embed_dim=args.encoder_embed_dim, encoder_heads=args.encoder_heads, encoder_depths=args.encoder_depths,
95 | decoder_embed_dim=args.decoder_embed_dim, decoder_heads=args.decoder_heads,
96 | decoder_depths=args.decoder_depths, projection_hidden=args.projection_hidden, temperature=args.temperature
97 | ).to(device)
98 | print('Model Size : {0:.2f}MB'.format(model_size(self.model)))
99 |
100 | self.eff_batch_size = self.args.train_batch_size * self.args.train_batch_accumulation
101 | self.lr = self.args.train_base_learning_rate * self.eff_batch_size / 256
102 | self.optimizer = opt.AdamW(self.model.parameters(), lr=self.lr)
103 | self.train_paths, self.val_paths, self.eval_paths = self.data_paths()
104 | self.scheduler = opt.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.args.train_epochs)
105 | self.tensorboard_path = os.path.join(self.args.ckpt_path, self.args.model_name,
106 | str(self.args.n_fold), 'tensorboard')
107 |
108 |         # remove stale tensorboard logs from previous runs
109 | if os.path.exists(self.tensorboard_path):
110 | shutil.rmtree(self.tensorboard_path)
111 |
112 | self.tensorboard_writer = SummaryWriter(log_dir=self.tensorboard_path)
113 |
114 | print('Frame Size : {}'.format(self.model.num_patches))
115 |         print('Learning Rate : {0}'.format(self.lr))
116 | print('Validation Paths : {0}'.format(len(self.val_paths)))
117 | print('Evaluation Paths : {0}'.format(len(self.eval_paths)))
118 |
119 | def train(self):
120 | print('K-Fold : {}/{}'.format(self.args.n_fold + 1, self.args.k_splits))
121 | train_dataset = TorchDataset(paths=self.train_paths, sfreq=self.args.sfreq, rfreq=self.args.rfreq,
122 | scaler=self.args.data_scaler)
123 | train_dataloader = DataLoader(train_dataset, batch_size=self.args.train_batch_size, shuffle=True)
124 | val_dataset = TorchDataset(paths=self.val_paths, sfreq=self.args.sfreq, rfreq=self.args.rfreq,
125 | scaler=self.args.data_scaler)
126 | val_dataloader = DataLoader(val_dataset, batch_size=self.args.train_batch_size, drop_last=True)
127 | eval_dataset = TorchDataset(paths=self.eval_paths, sfreq=self.args.sfreq, rfreq=self.args.rfreq,
128 | scaler=self.args.data_scaler)
129 | eval_dataloader = DataLoader(eval_dataset, batch_size=self.args.train_batch_size, drop_last=True)
130 |
131 | total_step = 0
132 | best_model_state, best_score = self.model.state_dict(), 0
133 |
134 | for epoch in range(self.args.train_epochs):
135 | step = 0
136 | self.model.train()
137 | self.optimizer.zero_grad()
138 |
139 | for x, _ in train_dataloader:
140 | x = x.to(device)
141 | out = self.model(x, mask_ratio=self.args.mask_ratio)
142 | recon_loss, contrastive_loss, (cl_labels, cl_logits) = out
143 |
144 | loss = recon_loss + self.args.alpha * contrastive_loss
145 | loss.backward()
146 |
147 | if (step + 1) % self.args.train_batch_accumulation == 0:
148 | self.optimizer.step()
149 | self.optimizer.zero_grad()
150 |
151 | if (total_step + 1) % self.args.print_point == 0:
152 | print('[Epoch] : {0:03d} [Step] : {1:06d} '
153 | '[Reconstruction Loss] : {2:02.4f} [Contrastive Loss] : {3:02.4f} '
154 | '[Total Loss] : {4:02.4f} [Contrastive Acc] : {5:02.4f}'.format(
155 | epoch, total_step + 1, recon_loss, contrastive_loss, loss,
156 | self.compute_metrics(cl_logits, cl_labels)))
157 |
158 | self.tensorboard_writer.add_scalar('Reconstruction Loss', recon_loss, total_step)
159 | self.tensorboard_writer.add_scalar('Contrastive loss', contrastive_loss, total_step)
160 | self.tensorboard_writer.add_scalar('Total loss', loss, total_step)
161 |
162 | step += 1
163 | total_step += 1
164 |
165 | val_acc, val_mf1 = self.linear_probing(val_dataloader, eval_dataloader)
166 |
167 | if val_mf1 > best_score:
168 | best_model_state = self.model.state_dict()
169 | best_score = val_mf1
170 |
171 | print('[Epoch] : {0:03d} \t [Accuracy] : {1:2.4f} \t [Macro-F1] : {2:2.4f} \n'.format(
172 | epoch, val_acc * 100, val_mf1 * 100))
173 | self.tensorboard_writer.add_scalar('Validation Accuracy', val_acc, total_step)
174 | self.tensorboard_writer.add_scalar('Validation Macro-F1', val_mf1, total_step)
175 |
176 |             self.optimizer.step()  # apply any gradients still pending from batch accumulation
177 | self.scheduler.step()
178 |
179 | self.save_ckpt(model_state=best_model_state)
180 |
181 | def linear_probing(self, val_dataloader, eval_dataloader):
182 | self.model.eval()
183 | (train_x, train_y), (test_x, test_y) = self.get_latent_vector(val_dataloader), \
184 | self.get_latent_vector(eval_dataloader)
185 | pca = PCA(n_components=50)
186 | train_x = pca.fit_transform(train_x)
187 | test_x = pca.transform(test_x)
188 |
189 | model = KNeighborsClassifier()
190 | model.fit(train_x, train_y)
191 |
192 | out = model.predict(test_x)
193 | acc, mf1 = accuracy_score(test_y, out), f1_score(test_y, out, average='macro')
194 | self.model.train()
195 | return acc, mf1
196 |
197 | def get_latent_vector(self, dataloader):
198 | total_x, total_y = [], []
199 | with torch.no_grad():
200 | for data in dataloader:
201 | x, y = data
202 | x, y = x.to(device), y.to(device)
203 | latent = self.model.forward_latent(x)
204 | total_x.append(latent.detach().cpu().numpy())
205 | total_y.append(y.detach().cpu().numpy())
206 | total_x, total_y = np.concatenate(total_x, axis=0), np.concatenate(total_y, axis=0)
207 | return total_x, total_y
208 |
209 | def save_ckpt(self, model_state):
210 | ckpt_path = os.path.join(self.args.ckpt_path, self.args.model_name, str(self.args.n_fold), 'model')
211 | if not os.path.exists(ckpt_path):
212 | os.makedirs(ckpt_path)
213 |
214 | torch.save({
215 | 'model_name': 'NeuroNet',
216 | 'model_state': model_state,
217 | 'model_parameter': {
218 | 'fs': self.args.rfreq, 'second': self.args.second,
219 | 'time_window': self.args.time_window, 'time_step': self.args.time_step,
220 | 'encoder_embed_dim': self.args.encoder_embed_dim, 'encoder_heads': self.args.encoder_heads,
221 | 'encoder_depths': self.args.encoder_depths,
222 | 'decoder_embed_dim': self.args.decoder_embed_dim, 'decoder_heads': self.args.decoder_heads,
223 | 'decoder_depths': self.args.decoder_depths,
224 | 'projection_hidden': self.args.projection_hidden, 'temperature': self.args.temperature
225 | },
226 | 'hyperparameter': self.args.__dict__,
227 | 'paths': {'train_paths': self.train_paths, 'ft_paths': self.val_paths, 'eval_paths': self.eval_paths}
228 | }, os.path.join(ckpt_path, 'best_model.pth'))
229 |
230 | def data_paths(self):
231 | kf = split_train_test_val_files(base_path=self.args.base_path, n_splits=self.args.k_splits)
232 |
233 | paths = kf[self.args.n_fold]
234 | train_paths, ft_paths, eval_paths = paths['train_paths'], paths['ft_paths'], paths['eval_paths']
235 | return train_paths, ft_paths, eval_paths
236 |
237 | @staticmethod
238 | def compute_metrics(output, target):
239 | output = output.argmax(dim=-1)
240 | accuracy = torch.mean(torch.eq(target, output).to(torch.float32))
241 | return accuracy
242 |
243 |
244 | if __name__ == '__main__':
245 | augments = get_args()
246 | for n_fold in range(augments.k_splits):
247 | augments.n_fold = n_fold
248 | trainer = Trainer(augments)
249 | trainer.train()
250 |
--------------------------------------------------------------------------------
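Note: the per-epoch validation in Trainer.linear_probing above extracts encoder latents without gradients, reduces them to 50 dimensions with PCA, and scores a k-NN classifier on the held-out split. A self-contained sketch of that protocol, with random arrays standing in for the encoder latents and sleep-stage labels (shapes here are purely illustrative):

import numpy as np
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score

rng = np.random.default_rng(0)
train_x, train_y = rng.normal(size=(600, 768)), rng.integers(0, 5, size=600)  # stand-ins for latents / labels
test_x, test_y = rng.normal(size=(200, 768)), rng.integers(0, 5, size=200)

pca = PCA(n_components=50)                      # same 50-dim reduction as linear_probing
train_x, test_x = pca.fit_transform(train_x), pca.transform(test_x)

knn = KNeighborsClassifier().fit(train_x, train_y)
pred = knn.predict(test_x)
print(accuracy_score(test_y, pred), f1_score(test_y, pred, average='macro'))
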
/models/neuronet/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | from typing import List
6 | from models.neuronet.resnet1d import FrameBackBone
7 | from timm.models.vision_transformer import Block
8 | from models.utils import get_2d_sincos_pos_embed_flexible
9 | from models.loss import NTXentLoss
10 | from functools import partial
11 |
12 |
13 | class NeuroNet(nn.Module):
14 | def __init__(self, fs: int, second: int, time_window: int, time_step: float,
15 |                  encoder_embed_dim: int, encoder_heads: int, encoder_depths: int,
16 | decoder_embed_dim: int, decoder_heads: int, decoder_depths: int,
17 | projection_hidden: List, temperature=0.01):
18 | super().__init__()
19 | self.fs, self.second = fs, second
20 | self.time_window = time_window
21 | self.time_step = time_step
22 |
23 | self.num_patches, _ = frame_size(fs=fs, second=second, time_window=time_window, time_step=time_step)
24 | self.frame_backbone = FrameBackBone(fs=self.fs, window=self.time_window)
25 | self.autoencoder = MaskedAutoEncoderViT(input_size=self.frame_backbone.feature_num,
26 | encoder_embed_dim=encoder_embed_dim, num_patches=self.num_patches,
27 | encoder_heads=encoder_heads, encoder_depths=encoder_depths,
28 | decoder_embed_dim=decoder_embed_dim, decoder_heads=decoder_heads,
29 | decoder_depths=decoder_depths)
30 | self.contrastive_loss = NTXentLoss(temperature=temperature)
31 |
32 | projection_hidden = [encoder_embed_dim] + projection_hidden
33 | projectors = []
34 | for i, (h1, h2) in enumerate(zip(projection_hidden[:-1], projection_hidden[1:])):
35 | if i != len(projection_hidden) - 2:
36 | projectors.append(nn.Linear(h1, h2))
37 | projectors.append(nn.BatchNorm1d(h2))
38 | projectors.append(nn.ELU())
39 | else:
40 | projectors.append(nn.Linear(h1, h2))
41 | self.projectors = nn.Sequential(*projectors)
42 | self.projectors_bn = nn.BatchNorm1d(projection_hidden[-1], affine=False)
43 | self.norm_pix_loss = False
44 |
45 | def forward(self, x: torch.Tensor, mask_ratio: float = 0.5) -> (torch.Tensor, torch.Tensor):
46 | x = self.make_frame(x)
47 | x = self.frame_backbone(x)
48 |
49 | # Masked Prediction
50 | latent1, pred1, mask1 = self.autoencoder(x, mask_ratio)
51 | latent2, pred2, mask2 = self.autoencoder(x, mask_ratio)
52 | o1, o2 = latent1[:, :1, :].squeeze(), latent2[:, :1, :].squeeze()
53 | recon_loss1 = self.forward_mae_loss(x, pred1, mask1)
54 | recon_loss2 = self.forward_mae_loss(x, pred2, mask2)
55 | recon_loss = recon_loss1 + recon_loss2
56 |
57 | # Contrastive Learning
58 | o1, o2 = self.projectors(o1), self.projectors(o2)
59 | contrastive_loss, (labels, logits) = self.contrastive_loss(o1, o2)
60 |
61 | return recon_loss, contrastive_loss, (labels, logits)
62 |
63 | def forward_latent(self, x: torch.Tensor):
64 | x = self.make_frame(x)
65 | x = self.frame_backbone(x)
66 | latent = self.autoencoder.forward_encoder(x, mask_ratio=0)[0]
67 | latent_o = latent[:, :1, :].squeeze()
68 | return latent_o
69 |
70 | def forward_mae_loss(self,
71 | real: torch.Tensor,
72 | pred: torch.Tensor,
73 | mask: torch.Tensor):
74 |
75 | if self.norm_pix_loss:
76 | mean = real.mean(dim=-1, keepdim=True)
77 | var = real.var(dim=-1, keepdim=True)
78 | real = (real - mean) / (var + 1.e-6) ** .5
79 |
80 | loss = (pred - real) ** 2
81 | loss = loss.mean(dim=-1)
82 | loss = (loss * mask).sum() / mask.sum()
83 | return loss
84 |
85 | def make_frame(self, x):
86 | size = self.fs * self.second
87 | step = int(self.time_step * self.fs)
88 | window = int(self.time_window * self.fs)
89 | frame = []
90 | for i in range(0, size, step):
91 | start_idx, end_idx = i, i+window
92 | sample = x[..., start_idx: end_idx]
93 | if sample.shape[-1] == window:
94 | frame.append(sample)
95 | frame = torch.stack(frame, dim=1)
96 | return frame
97 |
98 |
99 | class MaskedAutoEncoderViT(nn.Module):
100 | def __init__(self, input_size: int, num_patches: int,
101 | encoder_embed_dim: int, encoder_heads: int, encoder_depths: int,
102 | decoder_embed_dim: int, decoder_heads: int, decoder_depths: int):
103 | super().__init__()
104 | self.patch_embed = nn.Linear(input_size, encoder_embed_dim)
105 | self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_embed_dim))
106 | self.embed_dim = encoder_embed_dim
107 | self.encoder_depths = encoder_depths
108 | self.mlp_ratio = 4.
109 |
110 | self.input_size = (num_patches, encoder_embed_dim)
111 | self.patch_size = (1, encoder_embed_dim)
112 | self.grid_h = int(self.input_size[0] // self.patch_size[0])
113 | self.grid_w = int(self.input_size[1] // self.patch_size[1])
114 | self.num_patches = self.grid_h * self.grid_w
115 |
116 | # MAE Encoder
117 | self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, encoder_embed_dim), requires_grad=False)
118 | self.encoder_block = nn.ModuleList([
119 | Block(encoder_embed_dim, encoder_heads, self.mlp_ratio, qkv_bias=True,
120 | norm_layer=partial(nn.LayerNorm, eps=1e-6))
121 | for _ in range(encoder_depths)
122 | ])
123 | self.encoder_norm = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
124 |
125 | # MAE Decoder
126 | self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
127 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
128 | self.decoder_pos_embed = nn.Parameter(torch.randn(1, self.num_patches, decoder_embed_dim), requires_grad=False)
129 | self.decoder_block = nn.ModuleList([
130 | Block(decoder_embed_dim, decoder_heads, self.mlp_ratio, qkv_bias=True,
131 | norm_layer=partial(nn.LayerNorm, eps=1e-6))
132 | for _ in range(decoder_depths)
133 | ])
134 | self.decoder_norm = nn.LayerNorm(decoder_embed_dim, eps=1e-6)
135 | self.decoder_pred = nn.Linear(decoder_embed_dim, input_size, bias=True)
136 | self.initialize_weights()
137 |
138 | def forward(self, x, mask_ratio=0.8):
139 | latent, mask, ids_restore = self.forward_encoder(x, mask_ratio)
140 | pred = self.forward_decoder(latent, ids_restore)
141 | return latent, pred, mask
142 |
143 | def forward_encoder(self, x: torch.Tensor, mask_ratio: float = 0.5):
144 | # embed patches
145 | x = self.patch_embed(x)
146 |
147 | # add pos embed w/o cls token
148 | x = x + self.pos_embed[:, 1:, :]
149 |
150 | # masking: length -> length * mask_ratio
151 | x, mask, ids_restore = self.random_masking(x, mask_ratio)
152 |
153 | # append cls token
154 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
155 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
156 | x = torch.cat((cls_tokens, x), dim=1)
157 |
158 | # apply Transformer blocks
159 | for block in self.encoder_block:
160 | x = block(x)
161 |
162 | x = self.encoder_norm(x)
163 | return x, mask, ids_restore
164 |
165 | def forward_decoder(self, x, ids_restore: torch.Tensor):
166 | # embed tokens
167 | x = self.decoder_embed(x[:, 1:, :])
168 |
169 | # append mask tokens to sequence
170 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
171 | x_ = torch.cat([x, mask_tokens], dim=1) # no cls token
172 | x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
173 |
174 | # add pos embed
175 | x = x + self.decoder_pos_embed
176 |
177 | # apply Transformer blocks
178 | for block in self.decoder_block:
179 | x = block(x)
180 |
181 | x = self.decoder_norm(x)
182 |
183 | # predictor projection
184 | x = self.decoder_pred(x)
185 | return x
186 |
187 | @staticmethod
188 | def random_masking(x, mask_ratio):
189 | n, l, d = x.shape # batch, length, dim
190 | len_keep = int(l * (1 - mask_ratio))
191 |
192 | noise = torch.rand(n, l, device=x.device) # noise in [0, 1]
193 |
194 | # sort noise for each sample
195 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
196 | ids_restore = torch.argsort(ids_shuffle, dim=1)
197 |
198 | # keep the first subset
199 | ids_keep = ids_shuffle[:, :len_keep]
200 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, d))
201 |
202 | # generate the binary mask: 0 is keep, 1 is remove
203 | mask = torch.ones([n, l], device=x.device)
204 | mask[:, :len_keep] = 0
205 |
206 | mask = torch.gather(mask, dim=1, index=ids_restore)
207 | return x_masked, mask, ids_restore
208 |
209 | def initialize_weights(self):
210 | # initialization
211 | # initialize (and freeze) pos_embed by sin-cos embedding
212 | pos_embed = get_2d_sincos_pos_embed_flexible(self.pos_embed.shape[-1],
213 | (self.grid_h, self.grid_w),
214 | cls_token=True)
215 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
216 | decoder_pos_embed = get_2d_sincos_pos_embed_flexible(self.decoder_pos_embed.shape[-1],
217 | (self.grid_h, self.grid_w),
218 | cls_token=False)
219 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
220 |
221 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
222 | torch.nn.init.normal_(self.cls_token, std=.02)
223 | torch.nn.init.normal_(self.mask_token, std=.02)
224 |
225 | # initialize nn.Linear and nn.LayerNorm
226 | self.apply(self._init_weights)
227 |
228 | @staticmethod
229 | def _init_weights(m):
230 | if isinstance(m, nn.Linear):
231 | # we use xavier_uniform following official JAX ViT:
232 | torch.nn.init.xavier_uniform_(m.weight)
233 | if isinstance(m, nn.Linear) and m.bias is not None:
234 | nn.init.constant_(m.bias, 0)
235 | elif isinstance(m, nn.LayerNorm):
236 | nn.init.constant_(m.bias, 0)
237 | nn.init.constant_(m.weight, 1.0)
238 |
239 |
240 | def frame_size(fs, second, time_window, time_step):
241 | x = np.random.randn(1, fs * second)
242 | size = fs * second
243 | step = int(time_step * fs)
244 | window = int(time_window * fs)
245 | frame = []
246 | for i in range(0, size, step):
247 | start_idx, end_idx = i, i + window
248 | sample = x[..., start_idx: end_idx]
249 | if sample.shape[-1] == window:
250 | frame.append(sample)
251 | frame = np.stack(frame, axis=1)
252 | return frame.shape[1], frame.shape[2]
253 |
254 |
255 | class NeuroNetEncoderWrapper(nn.Module):
256 | def __init__(self, fs: int, second: int, time_window: int, time_step: float,
257 | frame_backbone, patch_embed, encoder_block, encoder_norm, cls_token, pos_embed,
258 | final_length):
259 |
260 | super().__init__()
261 | self.fs, self.second = fs, second
262 | self.time_window = time_window
263 | self.time_step = time_step
264 |
265 | self.patch_embed = patch_embed
266 | self.frame_backbone = frame_backbone
267 | self.encoder_block = encoder_block
268 | self.encoder_norm = encoder_norm
269 | self.cls_token = cls_token
270 | self.pos_embed = pos_embed
271 |
272 | self.final_length = final_length
273 |
274 |     def forward(self, x, semantic_token=True):  # note: semantic_token is currently unused below
275 | # frame backbone
276 | x = self.make_frame(x)
277 | x = self.frame_backbone(x)
278 |
279 | # embed patches
280 | x = self.patch_embed(x)
281 |
282 | # add pos embed w/o cls token
283 | x = x + self.pos_embed[:, 1:, :]
284 |
285 | # append cls token
286 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
287 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
288 | x = torch.cat((cls_tokens, x), dim=1)
289 |
290 | # apply Transformer blocks
291 | for block in self.encoder_block:
292 | x = block(x)
293 |
294 | x = self.encoder_norm(x)
295 |
296 | # get semantic information
297 | x = torch.mean(x[:, 1:, :], dim=1)
298 | return x
299 |
300 | def make_frame(self, x):
301 | size = self.fs * self.second
302 | step = int(self.time_step * self.fs)
303 | window = int(self.time_window * self.fs)
304 | frame = []
305 | for i in range(0, size, step):
306 | start_idx, end_idx = i, i+window
307 | sample = x[..., start_idx: end_idx]
308 | if sample.shape[-1] == window:
309 | frame.append(sample)
310 | frame = torch.stack(frame, dim=1)
311 | return frame
312 |
313 |
314 | if __name__ == '__main__':
315 | x0 = torch.randn((50, 3000))
316 | m0 = NeuroNet(fs=100, second=30, time_window=5, time_step=0.5,
317 | encoder_embed_dim=256, encoder_depths=6, encoder_heads=8,
318 | decoder_embed_dim=128, decoder_heads=4, decoder_depths=8,
319 | projection_hidden=[1024, 512])
320 | m0.forward(x0, mask_ratio=0.5)
321 |
--------------------------------------------------------------------------------
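Note: with the NeuroNet-B defaults from pretrained/train.py (fs=100 Hz, 30-s epochs, 3-s windows, 0.375-s step), frame_size cuts each 3000-sample epoch into 73 overlapping frames of 300 samples. A minimal forward-pass sketch under those settings, assuming the repository root is on PYTHONPATH and timm is installed (the batch of random signals is only illustrative):

import torch
from models.neuronet.model import NeuroNet, frame_size

num_frames, frame_len = frame_size(fs=100, second=30, time_window=3, time_step=0.375)
print(num_frames, frame_len)  # -> 73 300

model = NeuroNet(fs=100, second=30, time_window=3, time_step=0.375,
                 encoder_embed_dim=768, encoder_heads=8, encoder_depths=4,
                 decoder_embed_dim=256, decoder_heads=8, decoder_depths=3,
                 projection_hidden=[1024, 512], temperature=0.05)
recon_loss, contrastive_loss, (labels, logits) = model(torch.randn(8, 3000), mask_ratio=0.8)
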
/downstream/fine_tuning.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import os
3 | import mne
4 | import torch
5 | import random
6 | import argparse
7 | import warnings
8 | import numpy as np
9 | import torch.nn as nn
10 | import torch.optim as opt
11 | from mamba_ssm import Mamba
12 | from models.utils import model_size
13 | from torch.utils.data import Dataset, DataLoader
14 | from models.neuronet.model import NeuroNet, NeuroNetEncoderWrapper
15 | from sklearn.metrics import accuracy_score, f1_score
16 |
17 |
18 | warnings.filterwarnings(action='ignore')
19 |
20 |
21 | random_seed = 777
22 | np.random.seed(random_seed)
23 | torch.manual_seed(random_seed)
24 | random.seed(random_seed)
25 |
26 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
27 | torch.backends.cudnn.deterministic = True
28 | torch.backends.cudnn.benchmark = False
29 | torch.backends.cuda.enable_mem_efficient_sdp(False)
30 | torch.backends.cuda.enable_flash_sdp(False)
31 | torch.backends.cuda.enable_math_sdp(True)
32 |
33 |
34 | def get_args():
35 | file_name = 'mini'
36 | parser = argparse.ArgumentParser()
37 |     parser.add_argument('--n_fold', default=1, type=int, choices=[0, 1, 2, 3, 4])
38 | parser.add_argument('--ckpt_path', default=os.path.join('..', '..', '..', 'ckpt',
39 | 'SHHS', 'cm_eeg', file_name), type=str)
40 |     parser.add_argument('--temporal_context_length', default=20, type=int)
41 |     parser.add_argument('--window_size', default=10, type=int)
42 | parser.add_argument('--epochs', default=150, type=int)
43 | parser.add_argument('--batch_size', default=64, type=int)
44 | parser.add_argument('--lr', default=0.0005, type=float)
45 |
46 |     parser.add_argument('--embed_dim', default=256, type=int)
47 | parser.add_argument('--temporal_context_modules', choices=['lstm', 'mha', 'lstm_mha', 'mamba'], default='mamba')
48 | parser.add_argument('--save_path', default=os.path.join('..', '..', '..',
49 | 'ckpt', 'SHHS', 'cm_eeg',
50 | file_name), type=str)
51 | return parser.parse_args()
52 |
53 |
54 | class TemporalContextModule(nn.Module):
55 | def __init__(self, backbone, backbone_final_length, embed_dim):
56 | super().__init__()
57 | self.backbone = self.freeze_backbone(backbone)
58 | self.backbone_final_length = backbone_final_length
59 | self.embed_dim = embed_dim
60 | self.embed_layer = nn.Sequential(
61 | nn.Linear(backbone_final_length, embed_dim),
62 | nn.BatchNorm1d(embed_dim),
63 | nn.ELU(),
64 | nn.Linear(embed_dim, embed_dim)
65 | )
66 |
67 | def apply_backbone(self, x):
68 | out = []
69 | for x_ in torch.split(x, dim=1, split_size_or_sections=1):
70 | o = self.backbone(x_.squeeze())
71 | o = self.embed_layer(o)
72 | out.append(o)
73 | out = torch.stack(out, dim=1)
74 | return out
75 |
76 | @staticmethod
77 | def freeze_backbone(backbone: nn.Module):
78 | for name, module in backbone.named_modules():
79 | if name in ['encoder_block.3.ls1', 'encoder_block.3.drop_path1', 'encoder_block.3.norm2',
80 | 'encoder_block.3.mlp', 'encoder_block.3.mlp.fc1', 'encoder_block.3.mlp.act',
81 | 'encoder_block.3.mlp.drop1', 'encoder_block.3.mlp.norm', 'encoder_block.3.mlp.fc2',
82 | 'encoder_block.3.mlp.drop2', 'encoder_block.3.ls2', 'encoder_block.3.drop_path2',
83 | 'encoder_norm']:
84 | for param in module.parameters():
85 | param.requires_grad = True
86 | else:
87 | for param in module.parameters():
88 | param.requires_grad = False
89 | return backbone
90 |
91 |
92 | class LSTM_TCM(TemporalContextModule):
93 | def __init__(self, backbone, backbone_final_length, embed_dim):
94 | super().__init__(backbone=backbone, backbone_final_length=backbone_final_length, embed_dim=embed_dim)
95 | self.rnn_layer = 2
96 |         self.lstm = nn.LSTM(input_size=self.embed_dim, hidden_size=self.embed_dim, num_layers=self.rnn_layer, batch_first=True)  # (batch, time, embed) inputs
97 | self.fc = nn.Linear(self.embed_dim, 5)
98 |
99 | def forward(self, x):
100 | x = self.apply_backbone(x)
101 | x, _ = self.lstm(x)
102 | x = self.fc(x)
103 | return x
104 |
105 |
106 | class MHA_TCM(TemporalContextModule):
107 | def __init__(self, backbone, backbone_final_length, embed_dim):
108 | super().__init__(backbone=backbone, backbone_final_length=backbone_final_length, embed_dim=embed_dim)
109 | self.mha_heads = 8
110 | self.mha_layer = 2
111 |         self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(self.embed_dim, self.mha_heads, batch_first=True),
112 |                                                   num_layers=self.mha_layer)  # (batch, time, embed) inputs
113 | self.fc = nn.Linear(self.embed_dim, 5)
114 |
115 | def forward(self, x):
116 | x = self.apply_backbone(x)
117 | x = self.transformer(x)
118 | x = self.fc(x)
119 | return x
120 |
121 |
122 | class LSTM_MHA_TCM(TemporalContextModule):
123 | def __init__(self, backbone, backbone_final_length, embed_dim):
124 | super().__init__(backbone=backbone, backbone_final_length=backbone_final_length, embed_dim=embed_dim)
125 | self.mha_heads = 8
126 | self.mha_layer = 2
127 | self.rnn_layer = 1
128 |         self.lstm = nn.LSTM(input_size=self.embed_dim, hidden_size=self.embed_dim, num_layers=self.rnn_layer, batch_first=True)
129 |         self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(self.embed_dim, self.mha_heads, batch_first=True),
130 |                                                   num_layers=self.mha_layer)  # (batch, time, embed) inputs
131 | self.fc = nn.Linear(self.embed_dim, 5)
132 |
133 | def forward(self, x):
134 | x = self.apply_backbone(x)
135 | x, _ = self.lstm(x)
136 | x = self.transformer(x)
137 | x = self.fc(x)
138 | return x
139 |
140 |
141 | class MAMBA_TCM(TemporalContextModule):
142 | def __init__(self, backbone, backbone_final_length, embed_dim):
143 | super().__init__(backbone=backbone, backbone_final_length=backbone_final_length, embed_dim=embed_dim)
144 | self.mamba_heads = 8
145 | self.mamba_layer = 1
146 | self.mamba = nn.Sequential(*[
147 | Mamba(d_model=self.embed_dim,
148 | d_state=16,
149 | d_conv=4,
150 | expand=2)
151 | for _ in range(self.mamba_layer)
152 | ])
153 | self.fc = nn.Linear(self.embed_dim, 5)
154 |
155 | def forward(self, x):
156 | x = self.apply_backbone(x)
157 | x = self.mamba(x)
158 | x = self.fc(x)
159 | return x
160 |
161 |
162 | class Trainer(object):
163 | def __init__(self, args):
164 | super().__init__()
165 | self.args = args
166 | self.ckpt_path = os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'model', 'best_model.pth')
167 | self.ckpt = torch.load(self.ckpt_path, map_location='cpu')
168 |
169 | self.sfreq, self.rfreq = self.ckpt['hyperparameter']['sfreq'], self.ckpt['hyperparameter']['rfreq']
170 | self.ft_paths, self.eval_paths = self.ckpt['paths']['ft_paths'], self.ckpt['paths']['eval_paths']
171 | self.model = self.get_pretrained_model().to(device)
172 |
173 | self.optimizer = opt.AdamW(self.model.parameters(), lr=self.args.lr)
174 | self.scheduler = opt.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.args.epochs)
175 | self.criterion = nn.CrossEntropyLoss()
176 |
177 | def train(self):
178 | print('Checkpoint File Path : {}'.format(self.ckpt_path))
179 | train_dataset, eval_dataset = TorchDataset(paths=self.ft_paths,
180 | temporal_context_length=self.args.temporal_context_length,
181 | window_size=self.args.window_size,
182 | sfreq=self.sfreq, rfreq=self.rfreq), \
183 | TorchDataset(paths=self.eval_paths,
184 | temporal_context_length=self.args.temporal_context_length,
185 | window_size=self.args.temporal_context_length,
186 | sfreq=self.sfreq, rfreq=self.rfreq)
187 |
188 | train_dataloader, eval_dataloader = DataLoader(dataset=train_dataset,
189 | batch_size=self.args.batch_size,
190 | shuffle=True), \
191 | DataLoader(dataset=eval_dataset,
192 | batch_size=self.args.batch_size,
193 | shuffle=False)
194 |
195 | best_model_state, best_mf1 = None, 0.0
196 | best_pred, best_real = [], []
197 |
198 | for epoch in range(self.args.epochs):
199 | self.model.train()
200 | epoch_train_loss = []
201 | for batch in train_dataloader:
202 | self.optimizer.zero_grad()
203 | x, y = batch
204 | x, y = x.to(device), y.to(device)
205 |
206 | out = self.model(x)
207 | loss, pred, real = self.get_loss(out, y)
208 |
209 | epoch_train_loss.append(float(loss.detach().cpu().item()))
210 | loss.backward()
211 | self.optimizer.step()
212 |
213 | self.model.eval()
214 | epoch_test_loss = []
215 | epoch_real, epoch_pred = [], []
216 | for batch in eval_dataloader:
217 | x, y = batch
218 | x, y = x.to(device), y.to(device)
219 | try:
220 | out = self.model(x)
221 | except IndexError:
222 | continue
223 | loss, pred, real = self.get_loss(out, y)
224 | pred = torch.argmax(pred, dim=-1)
225 | epoch_real.extend(list(real.detach().cpu().numpy()))
226 | epoch_pred.extend(list(pred.detach().cpu().numpy()))
227 | epoch_test_loss.append(float(loss.detach().cpu().item()))
228 |
229 | epoch_train_loss, epoch_test_loss = np.mean(epoch_train_loss), np.mean(epoch_test_loss)
230 | eval_acc, eval_mf1 = accuracy_score(y_true=epoch_real, y_pred=epoch_pred), \
231 | f1_score(y_true=epoch_real, y_pred=epoch_pred, average='macro')
232 |
233 | print('[Epoch] : {0:03d} \t '
234 | '[Train Loss] => {1:.4f} \t '
235 | '[Evaluation Loss] => {2:.4f} \t '
236 | '[Evaluation Accuracy] => {3:.4f} \t'
237 | '[Evaluation Macro-F1] => {4:.4f}'.format(epoch + 1, epoch_train_loss, epoch_test_loss,
238 | eval_acc, eval_mf1))
239 |
240 | if best_mf1 < eval_mf1:
241 | best_mf1 = eval_mf1
242 | best_model_state = self.model.state_dict()
243 | best_pred, best_real = epoch_pred, epoch_real
244 |
245 | self.scheduler.step()
246 |
247 | self.save_ckpt(best_model_state, best_pred, best_real)
248 |
249 | def save_ckpt(self, model_state, pred, real):
250 | if not os.path.exists(os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'fine_tuning')):
251 | os.makedirs(os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'fine_tuning'))
252 |
253 | save_path = os.path.join(self.args.ckpt_path, str(self.args.n_fold), 'fine_tuning', 'best_model.pth')
254 | torch.save({
255 | 'backbone_name': 'NeuroNet_FineTuning',
256 | 'model_state': model_state,
257 | 'hyperparameter': self.args.__dict__,
258 | 'result': {'real': real, 'pred': pred},
259 | 'paths': {'train_paths': self.ft_paths, 'eval_paths': self.eval_paths}
260 | }, save_path)
261 |
262 | def get_pretrained_model(self):
263 |         # 1. Prepare Pretrained Model
264 | model_parameter = self.ckpt['model_parameter']
265 | pretrained_model = NeuroNet(**model_parameter)
266 | pretrained_model.load_state_dict(self.ckpt['model_state'])
267 |
268 | # 2. Encoder Wrapper
269 | backbone = NeuroNetEncoderWrapper(
270 | fs=model_parameter['fs'], second=model_parameter['second'],
271 | time_window=model_parameter['time_window'], time_step=model_parameter['time_step'],
272 | frame_backbone=pretrained_model.frame_backbone,
273 | patch_embed=pretrained_model.autoencoder.patch_embed,
274 | encoder_block=pretrained_model.autoencoder.encoder_block,
275 | encoder_norm=pretrained_model.autoencoder.encoder_norm,
276 | cls_token=pretrained_model.autoencoder.cls_token,
277 | pos_embed=pretrained_model.autoencoder.pos_embed,
278 | final_length=pretrained_model.autoencoder.embed_dim,
279 | )
280 |
281 | # 3. Temporal Context Module
282 | tcm = self.get_temporal_context_module()
283 | model = tcm(backbone=backbone,
284 | backbone_final_length=pretrained_model.autoencoder.embed_dim,
285 | embed_dim=self.args.embed_dim)
286 | return model
287 |
288 | def get_temporal_context_module(self):
289 | if self.args.temporal_context_modules == 'lstm':
290 | return LSTM_TCM
291 | if self.args.temporal_context_modules == 'mha':
292 | return MHA_TCM
293 | if self.args.temporal_context_modules == 'lstm_mha':
294 | return LSTM_MHA_TCM
295 | if self.args.temporal_context_modules == 'mamba':
296 | return MAMBA_TCM
297 |
298 | def get_loss(self, pred, real):
299 | if pred.dim() == 3:
300 | pred = pred.view(-1, pred.size(2))
301 | real = real.view(-1)
302 | loss = self.criterion(pred, real)
303 | return loss, pred, real
304 |
305 |
306 | class TorchDataset(Dataset):
307 | def __init__(self, paths, temporal_context_length, window_size,
308 | sfreq: int, rfreq: int):
309 | super().__init__()
310 | self.sfreq, self.rfreq = sfreq, rfreq
311 | self.info = mne.create_info(sfreq=sfreq, ch_types='eeg', ch_names=['Fp1'])
312 | self.x, self.y = self.get_data(paths,
313 | temporal_context_length=temporal_context_length,
314 | window_size=window_size)
315 | self.x, self.y = torch.tensor(self.x, dtype=torch.float32), torch.tensor(self.y, dtype=torch.long)
316 |
317 | def get_data(self, paths, temporal_context_length, window_size):
318 | total_x, total_y = [], []
319 | for path in paths:
320 | data = np.load(path)
321 | x, y = data['x'], data['y']
322 | x = np.expand_dims(x, axis=1)
323 | x = mne.EpochsArray(x, info=self.info)
324 | x = x.resample(self.rfreq)
325 | x = x.get_data().squeeze()
326 | x = self.many_to_many(x, temporal_context_length, window_size)
327 | y = self.many_to_many(y, temporal_context_length, window_size)
328 | total_x.append(x)
329 | total_y.append(y)
330 | total_x, total_y = np.concatenate(total_x), np.concatenate(total_y)
331 | return total_x, total_y
332 |
333 | @staticmethod
334 | def many_to_many(elements, temporal_context_length, window_size):
335 | size = len(elements)
336 | total = []
337 | if size <= temporal_context_length:
338 | return elements
339 | for i in range(0, size-temporal_context_length+1, window_size):
340 | temp = np.array(elements[i:i+temporal_context_length])
341 | total.append(temp)
342 | total.append(elements[size-temporal_context_length:size])
343 | total = np.array(total)
344 | return total
345 |
346 | def __len__(self):
347 | return len(self.y)
348 |
349 |     def __getitem__(self, item):
350 |         # self.x / self.y are already tensors; index directly to avoid torch.tensor(tensor) copy warnings
351 |         x, y = self.x[item], self.y[item]
352 |         return x, y
353 |
354 |
355 | if __name__ == '__main__':
356 | augments = get_args()
357 |     for n_fold in range(10):  # adjust to the number of folds actually pretrained (k_splits in pretrained/train.py)
358 | augments.n_fold = n_fold
359 | trainer = Trainer(augments)
360 | trainer.train()
361 |
--------------------------------------------------------------------------------
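Note: TorchDataset.many_to_many in downstream/fine_tuning.py slides a window of temporal_context_length epochs across each recording with stride window_size and always appends one extra window covering the tail. A toy check of that behaviour (importing downstream.fine_tuning assumes mamba_ssm is installed, since it is imported at module level):

import numpy as np
from downstream.fine_tuning import TorchDataset

labels = np.arange(7)  # 7 consecutive sleep epochs
windows = TorchDataset.many_to_many(labels, temporal_context_length=3, window_size=2)
print(windows)
# [[0 1 2]
#  [2 3 4]
#  [4 5 6]
#  [4 5 6]]  <- tail window appended once more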