├── LICENSE
├── README.md
├── configs
│   └── config.ini
├── data_index
│   └── PLACEHOLDER
├── proc_data
│   └── PLACEHOLDER
├── requirements.txt
├── src
│   ├── build_data_index.py
│   ├── infer.py
│   ├── loss.py
│   ├── model.py
│   ├── proc_shhs.py
│   ├── recorder.py
│   ├── train.py
│   └── utlis.py
└── weight
    └── checkpoint.pt

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Songchi Zhou

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Annotation of sleep depth index (SDI) by deep learning
The implementation for the paper: ["Continuous Sleep Depth Index Annotation with Deep Learning Yields Novel Digital Biomarkers for Sleep Health"](https://www.nature.com/articles/s41746-025-01607-0).
A web app for annotating the Sleep Depth Index is available [here](http://183.162.233.24:10024/PSG_Sleep_depth) (with support for EDF-format input).

# Requirements
- Install the dependencies by running:

```bash
conda create -n sdi python=3.11
conda activate sdi
pip install -r requirements.txt
```

# Usage

You can modify the training configs in `configs/config.ini`, and run model training from the `src` directory by
```bash
python train.py --config ../configs/config.ini
```

After training, you can try the inference by running
```bash
python infer.py YOUR_DATA.edf OUTPUT.csv
```
Note that `infer.py` takes the input EDF path and the output CSV path as positional arguments.
Each row of the resulting CSV file corresponds to one 30-second epoch: the first column contains the Sleep Depth Index, and the second column indicates whether the epoch is classified as REM sleep.
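
To work with the predictions downstream, the CSV can be loaded with pandas. A minimal sketch (it assumes the `SDI`/`REM` column names written by `src/infer.py`):

```python
import pandas as pd

pred = pd.read_csv("OUTPUT.csv")    # one row per 30-second epoch
sdi = pred["SDI"]                   # sigmoid output, so values lie in (0, 1)
is_rem = pred["REM"].astype(bool)   # 1 where the epoch is classified as REM

# e.g., summarize a night by the mean depth over non-REM epochs
print("Mean non-REM SDI:", sdi[~is_rem].mean())
```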

# Citation

If you find the idea useful or use this code in your work, please cite our paper
```bibtex
@article{zhou2025continuous,
  title={Continuous sleep depth index annotation with deep learning yields novel digital biomarkers for sleep health},
  author={Zhou, Songchi and Song, Ge and Sun, Haoqi and Zhang, Deyun and Leng, Yue and Westover, M Brandon and Hong, Shenda},
  journal={npj Digital Medicine},
  volume={8},
  number={1},
  pages={203},
  year={2025},
  publisher={Nature Publishing Group UK London}
}
```
--------------------------------------------------------------------------------
/configs/config.ini:
--------------------------------------------------------------------------------
[TRAINING]
seed = 2024
device = cuda:3
n_epochs = 10
lr = 3e-4
weight_decay = 1e-3
batch_size = 512
checkup_steps = 1200
warmup_steps = 3000
num_train_steps = 99830

[DATA]
data_folder = ../proc_data/
save_path = ../weight/
cohorts = mesa,cfs,mros

[LOGGING]
log_file = SDI_mesa_cfs_mros.log
--------------------------------------------------------------------------------
/data_index/PLACEHOLDER:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sczzz3/SDI/9b1a9d7015fe261dd287154098177550ed33c892/data_index/PLACEHOLDER
--------------------------------------------------------------------------------
/proc_data/PLACEHOLDER:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sczzz3/SDI/9b1a9d7015fe261dd287154098177550ed33c892/proc_data/PLACEHOLDER
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
einops==0.7.0
mne==1.6.1
numpy==1.26.4
pandas==2.2.1
torch==2.2.1
transformers==4.39.2
xmltodict==0.13.0
--------------------------------------------------------------------------------
/src/build_data_index.py:
--------------------------------------------------------------------------------

import os
import random
import pickle

data_path = "../proc_data"
cohort = ['mesa', 'cfs', 'mros']
for c in cohort:

    c_path = os.path.join(data_path, c)
    data_files = os.listdir(c_path)

    random.seed(2024)
    random.shuffle(data_files)
    split_index = int(len(data_files) * 0.7)

    train_files = data_files[:split_index]
    test_files = data_files[split_index:]

    with open(os.path.join("../data_index/", c+'_train.pkl'), 'wb') as f:
        pickle.dump(train_files, f)
    with open(os.path.join("../data_index/", c+'_test.pkl'), 'wb') as f:
        pickle.dump(test_files, f)
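
# Illustrative only (not part of the original script): the pickles written above
# hold plain lists of .npz file names, which train.py's PSGDataSet reads back, e.g.:
#
#   with open("../data_index/mesa_train.pkl", "rb") as f:
#       train_files = pickle.load(f)
#   print(len(train_files), "training records for mesa")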
--------------------------------------------------------------------------------
/src/infer.py:
--------------------------------------------------------------------------------

import argparse
import numpy as np
import pandas as pd
import mne
import torch
from model import Net


def match_channel_names(channel_names, target_keywords):
    matched_channels = {}
    for key, keywords in target_keywords.items():
        matched = []
        for keyword in keywords:
            if key == 'EEG':
                matched = [ch for ch in channel_names if keyword in ch]
                if not matched:
                    matched = [ch for ch in channel_names if 'EEG' in ch]
            elif key == 'ECG':
                matched = [ch for ch in channel_names if ('ECG' in ch or 'EKG' in ch) and keyword in ch]
                if not matched:
                    matched = [ch for ch in channel_names if 'ECG' in ch or 'EKG' in ch]
            elif key == 'EOG':
                matched = [ch for ch in channel_names if ('EOG' in ch or 'ROC' in ch or 'LOC' in ch) and keyword in ch]
                if not matched:
                    matched = [ch for ch in channel_names if 'EOG' in ch or 'ROC' in ch or 'LOC' in ch]
            elif key == 'EMG':
                matched = [ch for ch in channel_names if ('EMG' in ch or 'chin' in ch.lower()) and keyword in ch]
                if not matched:
                    matched = [ch for ch in channel_names if 'EMG' in ch or 'chin' in ch.lower()]
            if matched:
                matched_channels[key] = matched[0]
                break
        if not matched:
            matched_channels[key] = None
    return matched_channels

def read(data_file):

    raw_data = mne.io.read_raw_edf(data_file, preload=True, verbose=False)
    raw_data.resample(sfreq=100)

    channel_names = raw_data.ch_names

    # Define the target keywords and conditions
    targets = {
        'EEG': ['C4'],
        'EOG': ['R'],
        'ECG': ['R', 'EKG'],
        'EMG': ['R']
    }

    matched_channels = match_channel_names(channel_names, targets)

    # Print the matched channels
    # for key, channel in matched_channels.items():
    #     print(f"{key} channel: {channel}")

    selected_channels = [ch for ch in matched_channels.values() if ch is not None]
    raw_selected = raw_data.copy().pick(selected_channels, verbose=False)

    epoch_duration = 30
    epochs = mne.make_fixed_length_epochs(raw_selected, duration=epoch_duration, preload=True, verbose=False)

    C4 = matched_channels['EEG']
    chin = matched_channels['EMG']
    EOGR = matched_channels['EOG']
    ECG = matched_channels['ECG']
    uV_signal = epochs.get_data(picks=[C4, chin, EOGR], units='uV')
    mV_signal = epochs.get_data(picks=[ECG], units='mV')

    combined_signal = []
    for epoch_idx in range(len(uV_signal)):
        combined_epoch = np.concatenate((uV_signal[epoch_idx], mV_signal[epoch_idx]), axis=0)
        combined_signal.append(combined_epoch)

    combined_signal = np.array(combined_signal)

    # Print the shape of the resulting data
    # print(f"Shape of the combined signal: {combined_signal.shape}")
    return combined_signal

def save_to_csv(depth, rem_pred, output_file):

    df = pd.DataFrame({
        'SDI': depth.numpy(),
        'REM': rem_pred.numpy()
    })
    df.to_csv(output_file, index=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process an EDF file and extract EEG, ECG, EOG, and EMG signals.')
    parser.add_argument('data_file', type=str, help='Path to the EDF file.')
    parser.add_argument('output_file', type=str, help='Path to the output CSV file.')
    args = parser.parse_args()

    psg = read(args.data_file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Net().to(device)
    model.load_state_dict(torch.load('../weight/checkpoint.pt', map_location=device))
    model.eval()

    with torch.no_grad():
        psg = torch.tensor(psg).float().to(device)
        logits, pred_nerm = model(psg)
        logits = torch.sigmoid(logits)
        depth = logits.squeeze(dim=-1).cpu()
        pred_nerm = torch.argmax(pred_nerm, dim=-1).cpu()  # 0 = NREM, 1 = REM

    save_to_csv(depth, pred_nerm, args.output_file)
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F


class PairMarginRankLoss(nn.Module):
    def __init__(self, pair_margins=None):
        """
        Initialize the loss function with a mapping of pair types to margins.

        :param pair_margins: A dictionary or any mapping where each key is a tuple representing a pair
                             (or pair type) and the value is the margin for that pair.
                             If None, the default margins below are used.
        """
        super(PairMarginRankLoss, self).__init__()
        if pair_margins is None:
            # Default margins between stage labels (0 = Wake, 1-3 = N1-N3, 4 = REM)
            pair_margins = {
                (0, 1): 1, (1, 0): 1,
                (0, 2): 1.5, (2, 0): 1.5,
                (0, 3): 3, (3, 0): 3,
                (0, 4): 1.2, (4, 0): 1.2,

                (1, 2): 0.5, (2, 1): 0.5,
                (1, 3): 2, (3, 1): 2,

                (2, 3): 1.5, (3, 2): 1.5,
            }
        self.pair_margins = pair_margins

    def forward(self, pred_depth, label):

        pred_depth = torch.combinations(pred_depth.squeeze(dim=1))
        label = torch.combinations(label)

        margins = torch.tensor([self.pair_margins.get(tuple(pair), 1.0) for pair in label.tolist()])
        margins = margins.to(pred_depth.device, non_blocking=True)

        # Exclude the relations between REM and the NREM stages (Wake vs. REM pairs are kept)
        uncertain_relationships = [[1, 4], [4, 1], [2, 4], [4, 2], [3, 4], [4, 3]]
        uncertain_relationships_tensor = torch.tensor(uncertain_relationships).to(pred_depth.device, non_blocking=True)
        label_expanded = label.unsqueeze(1).expand(-1, uncertain_relationships_tensor.size(0), -1)
        is_uncertain = torch.any(torch.all(label_expanded == uncertain_relationships_tensor, dim=2), dim=1)

        penalties = torch.relu(margins + (pred_depth[:, 1] - pred_depth[:, 0])) * (label[:, 0] > label[:, 1])
        penalties += torch.relu(margins + (pred_depth[:, 0] - pred_depth[:, 1])) * (label[:, 0] < label[:, 1])

        loss = torch.mean(penalties * (~is_uncertain))
        return loss
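
if __name__ == "__main__":
    # Illustrative smoke test only (not part of the original file): rank a toy
    # batch of predicted depths against stage labels (0 = Wake ... 4 = REM).
    torch.manual_seed(0)
    criterion = PairMarginRankLoss()
    pred_depth = torch.randn(8, 1)              # (batch, 1), as produced by Net's depth head
    label = torch.randint(0, 5, (8,))           # per-epoch sleep stage labels
    print(float(criterion(pred_depth, label)))  # scalar pairwise ranking loss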
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange, repeat, pack, unpack


# https://github.com/insitro/ChannelViT
class PatchEmbedPerChannel(nn.Module):

    def __init__(
        self,
        sig_size: int = 3000,
        patch_size: int = 100,
        in_chans: int = 4,
        embed_dim: int = 512,
    ):
        super().__init__()
        num_patches = (sig_size // patch_size) * in_chans
        self.sig_size = sig_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(
            1,
            embed_dim,
            kernel_size=(1, patch_size),
            stride=(1, patch_size),
        )

        self.channel_embed = nn.parameter.Parameter(
            torch.zeros(1, embed_dim, in_chans, 1)
        )
        # trunc_normal_(self.channel_embed, std=0.02)

    def forward(self, x):

        B, Cin, S = x.shape

        # shared projection layer across channels
        x = self.proj(x.unsqueeze(1))  # B Cout Cin S
        # channel specific offsets
        x += self.channel_embed[:, :, :, :]  # B Cout Cin S

        # preparing the output sequence
        x = x.flatten(2)  # B Cout CinS
        x = x.transpose(1, 2)  # B CinS Cout

        return x


class FeedForward(nn.Module):
    def __init__(self, dim, num_patches, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, num_patches, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):

        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, num_patches, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, num_patches, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, num_patches, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class Backbone(nn.Module):
    def __init__(self, *, seq_len, patch_size, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert (seq_len % patch_size) == 0

        num_patches = (seq_len // patch_size) * channels
        self.to_patch_embedding = PatchEmbedPerChannel(
            sig_size=seq_len,
            patch_size=patch_size,
            in_chans=channels,
            embed_dim=dim
        )

        self.cls_token = nn.Parameter(torch.randn(dim))

        # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
        self.num_extra_tokens = 1  # cls token
        # self.num_extra_tokens = 0  # avg pooling
        self.pos_embedding = nn.Parameter(
            # torch.zeros(1, num_patches // channels + self.num_extra_tokens, dim)
            torch.randn(1, num_patches // channels + self.num_extra_tokens, dim)
        )
        # trunc_normal_(self.pos_embedding, std=0.02)

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, num_patches, depth, heads, dim_head, mlp_dim, dropout)

    def interpolate_pos_encoding(self, x, L, c):

        # number of auxiliary dimensions before the patches
        if not hasattr(self, "num_extra_tokens"):
            # backward compatibility
            num_extra_tokens = 1
        else:
            num_extra_tokens = self.num_extra_tokens

        npatch = x.shape[1] - num_extra_tokens
        N = self.pos_embedding.shape[1] - num_extra_tokens

        if npatch == N:
            return self.pos_embedding

        class_pos_embed = self.pos_embedding[:, :num_extra_tokens]
        patch_pos_embed = self.pos_embedding[:, num_extra_tokens:]

        dim = x.shape[-1]
        L0 = L // self.to_patch_embedding.patch_size

        # see discussion at https://github.com/facebookresearch/dino/issues/8
        L0 += 0.1
        patch_pos_embed = F.interpolate(
            patch_pos_embed.reshape(1, N, dim).permute(0, 2, 1),
            size=int(L0),
            mode="linear",
        )  # torch.Size([1, 512, 30])

        assert int(L0) == patch_pos_embed.shape[-1]

        patch_pos_embed = patch_pos_embed.permute(0, 2, 1).view(1, 1, -1, dim)
        # create copies of the positional embeddings for each channel
        patch_pos_embed = patch_pos_embed.expand(-1, c, -1, dim).reshape(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)


    def forward(self, x):

        b, c, L = x.shape

        x = self.to_patch_embedding(x)
        _, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
        x, ps = pack([cls_tokens, x], 'b * d')

        # x += self.pos_embedding[:, :(n+1)]
        x = x + self.interpolate_pos_encoding(x, L, c)  # torch.Size([512, 121, 512]) + torch.Size([1, 121, 512])

        x = self.dropout(x)

        x = self.transformer(x)

        cls_tokens, _ = unpack(x, ps, 'b * d')
        return cls_tokens
        # return torch.mean(x, dim=1)


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        self.net1 = Backbone(
            seq_len=3000,
            channels=4,
            patch_size=100,
            dim=512,
            depth=6,
            heads=8,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1,
        )

        self.nerm_ff = nn.Sequential(
            nn.LayerNorm(512),
            nn.Linear(512, 2048),
            nn.GELU(),
            nn.Linear(2048, 512),
        )
        self.nerm_head = nn.Sequential(
            nn.LayerNorm(512),
            nn.Linear(512, 2)
        )
        self.depth_ff = nn.Sequential(
            nn.LayerNorm(512),
            nn.Linear(512, 2048),
            nn.GELU(),
            nn.Linear(2048, 512),
        )
        self.depth_head = nn.Sequential(
            nn.LayerNorm(512),
            nn.Linear(512, 1),
        )

    def forward(self, psg):
        psg_feat = self.net1(psg)

        nerm_feat = self.nerm_ff(psg_feat) + psg_feat
        depth_feat = self.depth_ff(psg_feat) + psg_feat

        return self.depth_head(depth_feat), self.nerm_head(nerm_feat)
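
if __name__ == "__main__":
    # Illustrative smoke test only (not part of the original file): verify the
    # expected shapes for a batch of 30 s, 100 Hz, 4-channel PSG epochs.
    model = Net()
    psg = torch.randn(2, 4, 3000)  # (batch, channels, 30 s * 100 Hz)
    depth_logit, nerm_logit = model(psg)
    print(depth_logit.shape, nerm_logit.shape)  # torch.Size([2, 1]) torch.Size([2, 2])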
--------------------------------------------------------------------------------
/src/proc_shhs.py:
--------------------------------------------------------------------------------

import os
import numpy as np
from tqdm import tqdm
import pandas as pd
import mne

import xmltodict
import warnings
warnings.filterwarnings("ignore")


channel_alias = {
    # 'C3': ['EEG(sec)', 'EEG 2', 'EEG2', 'EEG sec', 'EEG(SEC)'],
    'C4': ['EEG'],
    'ECG': ['ECG'],
    # 'EOGL': ['EOG(L)'],
    'EOGR': ['EOG(R)'],
    'EMG': ['EMG'],
}

def extract(data_file, anno_file):

    visits = data_file.split('/')[-2].split('shhs')[1]
    record_id = data_file.split('/')[-1].split('.edf')[0].split('shhs'+visits+'-')[1]
    nsrr_dataset = pd.read_csv("../shhs/datasets/shhs-harmonized-dataset-0.20.0.csv")
    record_row = nsrr_dataset[(nsrr_dataset.loc[:, 'nsrrid']==int(record_id)) & (nsrr_dataset.loc[:, 'visitnumber']==int(visits))]
    NSRR_AHI = float(record_row.loc[:, ['nsrr_ahi_hp4u_aasm15']].values[0][0])

    raw_data = mne.io.read_raw_edf(data_file, preload=True, verbose=False)
    pick_channels = {}
    try:
        for chan in channel_alias.keys():
            pick_channels[chan] = list(set(channel_alias[chan]).intersection(set(raw_data.info.ch_names)))[0]
    except IndexError:
        # a required channel is missing from this recording; log it and skip
        with open("error_shhs.txt", 'a+') as f:
            f.write(data_file + '\t' + str(raw_data.info.ch_names) + '\n')
        return [], [], None

    raw_data = raw_data.pick(list(pick_channels.values()))
    raw_data.resample(sfreq=100)
    # sf = raw_data.info["sfreq"]

    # XML-NSRR type
    with open(anno_file, 'r') as f:
        xml_text = f.read()
    json_text = xmltodict.parse(xml_text)
    events_list = dict(json_text)['PSGAnnotation']["ScoredEvents"]['ScoredEvent']
    onset = []
    duration = []
    desc = []

    for event in events_list:
        desc.append(event["EventConcept"])
        onset.append(float(event["Start"]))
        duration.append(float(event["Duration"]))

    anno_data = mne.Annotations(onset, duration, desc)
    raw_data.set_annotations(anno_data, emit_warning=False, verbose=False)

    annotation_desc_2_event_id = {
        'Wake|0': 0,
        'Stage 1 sleep|1': 1,
        'Stage 2 sleep|2': 2,
        'Stage 3 sleep|3': 3,
        'Stage 4 sleep|4': 3,
        'REM sleep|5': 4,
    }

    events_data, event_id_mapping = mne.events_from_annotations(
        raw_data, event_id=annotation_desc_2_event_id,
        chunk_duration=30.,
        verbose=False,)

    tmax = 30. - 1. / raw_data.info['sfreq']

    event_id = {}
    if np.any(events_data[:, 2] == 0):
        event_id['Wake'] = 0
    if np.any(events_data[:, 2] == 1):
        event_id['Stage 1'] = 1
    if np.any(events_data[:, 2] == 2):
        event_id['Stage 2'] = 2
    if np.any(events_data[:, 2] == 3):
        event_id['Stage 3 / Stage 4'] = 3
    if np.any(events_data[:, 2] == 4):
        event_id['REM sleep'] = 4

    epochs_data = mne.Epochs(raw=raw_data,
                             events=events_data,
                             event_id=event_id,
                             event_repeated='merge',
                             tmin=0., tmax=tmax,
                             baseline=None, preload=True, verbose=False)

    ###
    # C3 = list(set(channel_alias['Fz-Cz']).intersection(set(raw_data.info.ch_names)))[0]
    C4 = list(set(channel_alias['C4']).intersection(set(raw_data.info.ch_names)))[0]
    chin = list(set(channel_alias['EMG']).intersection(set(raw_data.info.ch_names)))[0]
    ECG = list(set(channel_alias['ECG']).intersection(set(raw_data.info.ch_names)))[0]
    # EOGL = list(set(channel_alias['EOGL']).intersection(set(raw_data.info.ch_names)))[0]
    EOGR = list(set(channel_alias['EOGR']).intersection(set(raw_data.info.ch_names)))[0]

    uV_signal = epochs_data.get_data(picks=[C4, chin, EOGR], copy=True, units='uV')
    mV_signal = epochs_data.get_data(picks=[ECG], copy=True, units='mV')
    psg_signal = np.concatenate((uV_signal, mV_signal), axis=1)

    ########################################################################
    sleep_events = {'Wake|0': 0, 'Stage 1 sleep|1': 1, 'Stage 2 sleep|2': 2, 'Stage 3 sleep|3': 3, 'Stage 4 sleep|4': 3, 'REM sleep|5': 4}
    apnea_events = {'Central apnea|Central Apnea': 1, 'Obstructive apnea|Obstructive Apnea': 1}
    hypopnea_events = {'Hypopnea|Hypopnea': 1}
    arousal_events = {'Arousal|Arousal ()': 1}


    EPOCH_LENGTH = 30
    epsilon = 1e-6

    total_epochs = int(raw_data.times[-1] // EPOCH_LENGTH) + 1
    excluded_epochs = set()
    epoch_sleep_stages = [None] * total_epochs  # Sleep stages per epoch
    epoch_apnea_events = [0] * total_epochs
    epoch_hypopnea_events = [0] * total_epochs
    epoch_arousal_events = [0] * total_epochs

    for event in events_list:
        event_type = event["EventType"]
        event_concept = event["EventConcept"]
        start_time = float(event["Start"])
        duration = float(event["Duration"])
        end_time = start_time + duration
        start_epoch = int(start_time // EPOCH_LENGTH)
        end_epoch = int((start_time + duration - epsilon) // EPOCH_LENGTH)

        if event_type == "Stages|Stages":
            if event_concept not in sleep_events:
                for epoch in range(start_epoch, min(end_epoch + 1, total_epochs)):
                    excluded_epochs.add(epoch)
            else:
                for epoch in range(start_epoch, min(end_epoch + 1, total_epochs)):
                    epoch_sleep_stages[epoch] = sleep_events[event_concept]

        for epoch in range(start_epoch, min(end_epoch + 1, total_epochs)):
            epoch_start_time = epoch * EPOCH_LENGTH
            epoch_end_time = (epoch + 1) * EPOCH_LENGTH

            overlap_start = max(start_time, epoch_start_time)
            overlap_end = min(end_time, epoch_end_time)
            overlap_duration = max(0, overlap_end - overlap_start)

            overlap_fraction = overlap_duration / EPOCH_LENGTH

            if event_type == "Respiratory|Respiratory" and event_concept in apnea_events:
                epoch_apnea_events[epoch] = max(epoch_apnea_events[epoch], overlap_fraction)

            elif event_type == "Respiratory|Respiratory" and event_concept in hypopnea_events:
                epoch_hypopnea_events[epoch] = max(epoch_hypopnea_events[epoch], overlap_fraction)

            elif event_type == "Arousals|Arousals" and event_concept in arousal_events:
                epoch_arousal_events[epoch] = max(epoch_arousal_events[epoch], overlap_fraction)


    sleep_stages = [stage for epoch, stage in enumerate(epoch_sleep_stages) if epoch not in excluded_epochs]
    apnea_events = [event for epoch, event in enumerate(epoch_apnea_events) if epoch not in excluded_epochs]
    hypopnea_events = [event for epoch, event in enumerate(epoch_hypopnea_events) if epoch not in excluded_epochs]
    arousal_events = [event for epoch, event in enumerate(epoch_arousal_events) if epoch not in excluded_epochs]

    epoch_sleep_stages_np = np.array(sleep_stages, dtype=int)
    epoch_apnea_events_np = np.array(apnea_events, dtype=float)
    epoch_hypopnea_events_np = np.array(hypopnea_events, dtype=float)
    epoch_arousal_events_np = np.array(arousal_events, dtype=float)
    epochs_labels = np.stack((epoch_sleep_stages_np, epoch_apnea_events_np, epoch_hypopnea_events_np, epoch_arousal_events_np), axis=-1)

    return psg_signal, epochs_labels, NSRR_AHI


if __name__ == '__main__':

    # cohort_list = ['shhs1', 'shhs2']
    cohort_list = ['shhs1']
    data_path = '../shhs/polysomnography/edfs/'
    label_path = '../shhs/polysomnography/annotations-events-nsrr/'
    save_path = '../proc_data/shhs/'
    os.makedirs(save_path, exist_ok=True)

    for cohort in cohort_list:

        cur_data_path = os.path.join(data_path, cohort)
        cur_label_path = os.path.join(label_path, cohort)

        records = os.listdir(cur_data_path)
        annos = [x.split('.')[0]+'-nsrr.xml' for x in records]

        for i in tqdm(range(len(records)), desc=cohort):

            if os.path.exists(os.path.join(save_path, records[i].split('.')[0]+'.npz')):
                continue

            if os.path.exists(os.path.join(cur_label_path, annos[i])):
                # extract() returns (psg, labels, ahi); unpack accordingly
                psg, y_data, ahi = extract(os.path.join(cur_data_path, records[i]), os.path.join(cur_label_path, annos[i]))
                if len(psg) > 0:
                    np.savez(os.path.join(save_path, records[i].split('.')[0]), psg=psg, y=y_data, ahi=ahi)

--------------------------------------------------------------------------------
/src/recorder.py:
--------------------------------------------------------------------------------

from functools import wraps
import torch
from torch import nn
from model import Attention  # absolute import, consistent with the other scripts in src/

def find_modules(nn_module, type):
    return [module for module in nn_module.modules() if isinstance(module, type)]

class Recorder(nn.Module):
    def __init__(self, vit, device = None):
        super().__init__()
        self.vit = vit

        self.data = None
        self.recordings = []
        self.hooks = []
        self.hook_registered = False
        self.ejected = False
        self.device = device

    def _hook(self, _, input, output):
        self.recordings.append(output.clone().detach())

    def _register_hook(self):
        modules = find_modules(self.vit.transformer, Attention)
        for module in modules:
            handle = module.attend.register_forward_hook(self._hook)
            self.hooks.append(handle)
        self.hook_registered = True

    def eject(self):
        self.ejected = True
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        return self.vit

    def clear(self):
        self.recordings.clear()

    def record(self, attn):
        recording = attn.clone().detach()
        self.recordings.append(recording)

    def forward(self, img):
        assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
        self.clear()
        if not self.hook_registered:
            self._register_hook()

        pred = self.vit(img)

        # move all recordings to one device before stacking
        target_device = self.device if self.device is not None else img.device
        recordings = tuple(map(lambda t: t.to(target_device), self.recordings))

        attns = torch.stack(recordings, dim = 1) if len(recordings) > 0 else None
        return pred, attns
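
if __name__ == "__main__":
    # Illustrative usage only (not part of the original file): record the
    # per-layer attention maps of the SDI backbone during a forward pass.
    from model import Net

    net = Net()
    recorder = Recorder(net.net1)  # wrap the Backbone (the ViT part)
    psg = torch.randn(2, 4, 3000)
    feats, attns = recorder(psg)
    # attns: (batch, layers, heads, tokens, tokens) -> here (2, 6, 8, 121, 121)
    print(feats.shape, attns.shape)
    recorder.eject()               # remove the forward hooks when done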
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------

import os
import numpy as np
import random
import logging
import pickle
import configparser
import argparse

import torch
from torch.utils.data import IterableDataset, DataLoader
from transformers import get_cosine_schedule_with_warmup

from loss import PairMarginRankLoss
from model import Net


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class PSGDataSet(IterableDataset):
    def __init__(self, data_folder, cohorts):
        super().__init__()
        self.folder = data_folder
        self.data_files = []
        for c in cohorts:
            with open(f"../data_index/{c}_train.pkl", 'rb') as f:
                data = pickle.load(f)
            self.data_files += [os.path.join(self.folder, c, x) for x in data]
        random.shuffle(self.data_files)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            iter_data_files = self.data_files
        else:
            per_worker = int(np.ceil(len(self.data_files) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = worker_id * per_worker
            iter_end = min(iter_start + per_worker, len(self.data_files))
            iter_data_files = self.data_files[iter_start:iter_end]

        for file_name in iter_data_files:
            data = np.load(file_name, mmap_mode='r')
            psg = torch.from_numpy(data['psg']).float()
            stage_label = torch.tensor(data['y']).long()[:, 0]
            nerm_label = torch.zeros_like(stage_label)
            nerm_label[stage_label <= 3] = 0
            nerm_label[stage_label == 4] = 1

            perm = torch.randperm(psg.size(0))
            permuted_psg = psg.index_select(0, perm)
            permuted_labels = stage_label[perm]
            permuted_nerm = nerm_label[perm]
            for c in range(permuted_psg.shape[0]):
                yield permuted_psg[c], permuted_labels[c], permuted_nerm[c]

def run(model, data_folder, cohorts, rank_loss, n_epoch, lr, decay, batch_size, device, checkup_steps, warmup_steps, num_train_steps, save_path):

    optimizer = torch.optim.AdamW([{'params': model.parameters(), 'lr': lr, 'weight_decay': decay}])
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_steps)
    rem_loss = torch.nn.CrossEntropyLoss()

    model.train()
    train_loss = 0
    train_steps = 0
    logging.info('==========================================')
    logging.info('              Start Training              ')
    logging.info('==========================================')

    for epoch in range(n_epoch):
        dataset = PSGDataSet(data_folder, cohorts)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=64, pin_memory=True)

        model.train()
        for step, batch in enumerate(dataloader):
            train_steps += 1
            psg, label, nerm = batch[0].to(device, non_blocking=True), batch[1].to(device, non_blocking=True), batch[2].to(device, non_blocking=True)

            pred_depth, pred_nerm = model(psg)
            if pred_depth.shape[0] == 1:
                continue

            loss = rank_loss(pred_depth, label)
            loss += rem_loss(pred_nerm, nerm)

            optimizer.zero_grad()
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            scheduler.step()

            if train_steps % checkup_steps == 0:
                logging.info('[%d, %5d] loss: %.10f' % (epoch, train_steps, train_loss / train_steps))

        torch.save(model.to('cpu').state_dict(), os.path.join(save_path, 'epoch'+str(epoch)+'_ckpt.pt'))
        model.to(device)
        logging.info('===================================')


def main(config_path):

    config = configparser.ConfigParser()
    config.read(config_path)

    set_seed(int(config['TRAINING']['seed']))
    device = torch.device(config['TRAINING']['device'])

    data_folder = config['DATA']['data_folder']
    save_path = config['DATA']['save_path']
    cohorts = config['DATA']['cohorts'].split(',')

    n_epochs = int(config['TRAINING']['n_epochs'])
    lr = float(config['TRAINING']['lr'])
    weight_decay = float(config['TRAINING']['weight_decay'])
    batch_size = int(config['TRAINING']['batch_size'])
    checkup_steps = int(config['TRAINING']['checkup_steps'])
    warmup_steps = int(config['TRAINING']['warmup_steps'])
    num_train_steps = int(config['TRAINING']['num_train_steps'])

    model = Net().to(device)
    rank_loss = PairMarginRankLoss().to(device)

    logging.basicConfig(
        filename=config['LOGGING']['log_file'],
        level=logging.INFO,
        filemode='w',
        format='%(name)s - %(levelname)s - %(message)s'
    )

    run(model, data_folder, cohorts, rank_loss, n_epochs, lr, weight_decay, batch_size, device, checkup_steps, warmup_steps, num_train_steps, save_path)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Train a model using a config file.")
    parser.add_argument('--config', type=str, required=True, help="Path to the config file.")

    args = parser.parse_args()
    main(args.config)
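
# NOTE (added annotation, not in the original script): train.py writes per-epoch
# weights to '<save_path>/epoch{N}_ckpt.pt', while infer.py loads
# '../weight/checkpoint.pt'. After training, copy or rename the checkpoint you
# want to deploy, e.g.:
#   cp ../weight/epoch9_ckpt.pt ../weight/checkpoint.pt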
--------------------------------------------------------------------------------
/src/utlis.py:
--------------------------------------------------------------------------------

import math
import torch
import warnings


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
    return tensor

def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

--------------------------------------------------------------------------------
/weight/checkpoint.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sczzz3/SDI/9b1a9d7015fe261dd287154098177550ed33c892/weight/checkpoint.pt
--------------------------------------------------------------------------------