├── CRT.py
├── README.md
├── base_models.py
├── data_processing.py
├── dataset
│   ├── har_test_all.npy
│   ├── har_test_label.npy
│   ├── har_train_all.npy
│   ├── har_train_label.npy
│   ├── har_valid_all.npy
│   └── har_valid_label.npy
├── main.py
└── requirements.txt

/CRT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | 
5 | from einops import rearrange, repeat
6 | from einops.layers.torch import Rearrange
7 | 
8 | from base_models import MLP, resnet1d18
9 | 
10 | 
11 | class cnn_extractor(nn.Module):
12 |     def __init__(self, dim, input_plane):
13 |         super(cnn_extractor, self).__init__()
14 |         self.cnn = resnet1d18(input_channels=dim, inplanes=input_plane)
15 | 
16 |     def forward(self, x):
17 |         x = self.cnn(x)
18 |         return x
19 | 
20 | 
21 | class PreNorm(nn.Module):
22 |     def __init__(self, dim, fn):
23 |         super().__init__()
24 |         self.norm = nn.LayerNorm(dim)
25 |         self.fn = fn
26 | 
27 |     def forward(self, x, **kwargs):
28 |         return self.fn(self.norm(x), **kwargs)
29 | 
30 | 
31 | class FeedForward(nn.Module):
32 |     def __init__(self, dim, hidden_dim, dropout=0.):
33 |         super().__init__()
34 |         self.net = nn.Sequential(
35 |             nn.Linear(dim, hidden_dim),
36 |             nn.GELU(),
37 |             nn.Dropout(dropout),
38 |             nn.Linear(hidden_dim, dim),
39 |             nn.Dropout(dropout)
40 |         )
41 | 
42 |     def forward(self, x):
43 |         return self.net(x)
44 | 
45 | 
46 | class Attention(nn.Module):
47 |     def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
48 |         super().__init__()
49 |         inner_dim = dim_head * heads
50 |         project_out = not (heads == 1 and dim_head == dim)
51 | 
52 |         self.heads = heads
53 |         self.scale = dim_head ** -0.5
54 | 
55 |         self.attend = nn.Softmax(dim=-1)
56 |         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
57 | 
58 |         self.to_out = nn.Sequential(
59 |             nn.Linear(inner_dim, dim),
60 |             nn.Dropout(dropout)
61 |         ) if project_out else nn.Identity()
62 | 
63 |     def forward(self, x):
64 |         qkv = self.to_qkv(x).chunk(3, dim=-1)
65 |         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
66 | 
67 |         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
68 | 
69 |         attn = self.attend(dots)
70 | 
71 |         out = torch.matmul(attn, v)
72 |         out = rearrange(out, 'b h n d -> b n (h d)')
73 |         return self.to_out(out)
74 | 
75 | 
76 | class Transformer(nn.Module):
77 |     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
78 |         super().__init__()
79 |         self.layers = nn.ModuleList([])
80 |         for _ in range(depth):
81 |             self.layers.append(nn.ModuleList([
82 |                 PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
83 |                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
84 |             ]))
85 | 
86 |     def forward(self, x):
87 |         for attn, ff in self.layers:
88 |             x = attn(x) + x
89 |             x = ff(x) + x
90 |         return x
91 | 
92 | 
93 | class TFR(nn.Module):
94 |     def __init__(self, seq_len, patch_len, num_classes, dim, depth, heads, mlp_dim, channels=12,
95 |                  dim_head=64, dropout=0., emb_dropout=0.):
96 |         '''
97 |         The encoder of CRT
98 |         '''
99 |         super().__init__()
100 | 
101 |         assert seq_len % (4 * patch_len) == 0, \
102 |             'seq_len should be a multiple of 4 * patch_len; otherwise some patch would mix magnitude and phase data.'
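        # Layout note (illustrative, inferred from data_processing.py and the README settings):
        # the input is assumed to be the concatenation [temporal | magnitude | phase] along
        # time, taking up 1/2, 1/4 and 1/4 of seq_len respectively, hence the divisibility
        # check above -- e.g. the HAR setting seq_len=256, patch_len=8 gives 256 % 32 == 0.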
103 | 
104 |         num_patches = seq_len // patch_len
105 |         patch_dim = channels * patch_len
106 | 
107 |         self.to_patch = nn.Sequential(Rearrange('b c (n p1) -> b n c p1', p1=patch_len),
108 |                                       Rearrange('b n c p1 -> (b n) c p1'))
109 |         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 3, dim))
110 |         self.modal_embedding = nn.Parameter(torch.randn(3, 1, dim))
111 |         self.cls_token = nn.Parameter(torch.randn(1, 3, dim))
112 |         self.dropout = nn.Dropout(emb_dropout)
113 |         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
114 | 
115 |         self.cnn1 = cnn_extractor(dim=channels, input_plane=dim // 8)  # For temporal data
116 |         self.cnn2 = cnn_extractor(dim=channels, input_plane=dim // 8)  # For magnitude data
117 |         self.cnn3 = cnn_extractor(dim=channels, input_plane=dim // 8)  # For phase data
118 | 
119 |     def forward(self, x):
120 |         batch, _, time_steps = x.shape
121 |         # t, m, p refer to temporal, magnitude and phase features respectively.
122 |         # If the temporal signal has length L, the magnitude and phase parts are each of length L // 2 here; these lengths can be adjusted by users.
123 |         t, m, p = x[:, :, :time_steps // 2], x[:, :, time_steps // 2: time_steps * 3 // 4], x[:, :, -time_steps // 4:]
124 |         t, m, p = self.to_patch(t), self.to_patch(m), self.to_patch(p)
125 |         patch2seq = nn.Sequential(nn.AdaptiveAvgPool1d(1),
126 |                                   Rearrange('(b n) c 1 -> b n c', b=batch))
127 | 
128 |         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch)
129 | 
130 |         x = torch.cat((cls_tokens[:, 0:1, :], patch2seq(self.cnn1(t)),
131 |                        cls_tokens[:, 1:2, :], patch2seq(self.cnn2(m)),
132 |                        cls_tokens[:, 2:3, :], patch2seq(self.cnn3(p))), dim=1)
133 | 
134 |         b, t, c = x.shape  # t = num_patches + 3
135 |         time_steps = t - 3
136 |         t_token_idx, m_token_idx, p_token_idx = 0, time_steps // 2 + 1, time_steps * 3 // 4 + 2
137 |         x[:, :m_token_idx] += self.modal_embedding[:1]  # slice the token (sequence) dimension, not the batch dimension
138 |         x[:, m_token_idx: p_token_idx] += self.modal_embedding[1:2]
139 |         x[:, p_token_idx:] += self.modal_embedding[2:]
140 |         x += self.pos_embedding[:, : t]
141 |         x = self.dropout(x)
142 |         x = self.transformer(x)
143 |         t_token, m_token, p_token = x[:, t_token_idx], x[:, m_token_idx], x[:, p_token_idx]
144 |         avg = (t_token + m_token + p_token) / 3
145 |         return avg
146 | 
147 | 
148 | def TFR_Encoder(seq_len, patch_len, dim, num_class, in_dim):
149 |     vit = TFR(seq_len=seq_len,
150 |               patch_len=patch_len,
151 |               num_classes=num_class,
152 |               dim=dim,
153 |               depth=6,
154 |               heads=8,
155 |               mlp_dim=dim,
156 |               dropout=0.2,
157 |               emb_dropout=0.1,
158 |               channels=in_dim)
159 |     return vit
160 | 
161 | class CRT(nn.Module):
162 |     def __init__(
163 |             self,
164 |             encoder,
165 |             decoder_dim,
166 |             decoder_depth=2,
167 |             decoder_heads=8,
168 |             decoder_dim_head=64,
169 |             patch_len=20,
170 |             in_dim=12
171 |     ):
172 |         super().__init__()
173 |         self.encoder = encoder
174 |         num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
175 |         self.to_patch = encoder.to_patch
176 |         pixel_values_per_patch = in_dim * patch_len
177 | 
178 |         # decoder parameters
179 |         self.modal_embedding = self.encoder.modal_embedding
180 |         self.mask_token = nn.Parameter(torch.randn(3, decoder_dim))
181 |         self.decoder = Transformer(dim=decoder_dim,
182 |                                    depth=decoder_depth,
183 |                                    heads=decoder_heads,
184 |                                    dim_head=decoder_dim_head,
185 |                                    mlp_dim=decoder_dim)
186 |         self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
187 |         self.to_pixels = nn.ModuleList([nn.Linear(decoder_dim, pixel_values_per_patch) for i in range(3)])
188 |         self.projs = nn.ModuleList([nn.Linear(decoder_dim, decoder_dim) for i in range(2)])
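        # The three rows of self.mask_token and the three self.to_pixels heads correspond
        # to the temporal / magnitude / phase streams; the two self.projs heads project the
        # pre-encoder and post-encoder tokens that IDC_loss below compares.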
189 | 
190 |     def IDC_loss(self, tokens, encoded_tokens):
191 |         '''
192 |         :param tokens: tokens before the Transformer encoder
193 |         :param encoded_tokens: the same tokens after the Transformer encoder
194 |         :return: a scalar loss that penalises similarity between non-corresponding token pairs
195 |         '''
196 |         B, T, D = tokens.shape
197 |         tokens, encoded_tokens = F.normalize(tokens, dim=-1), F.normalize(encoded_tokens, dim=-1)
198 |         encoded_tokens = encoded_tokens.transpose(2, 1)
199 |         cross_mul = torch.exp(torch.matmul(tokens, encoded_tokens))
200 |         mask = (1 - torch.eye(T)).unsqueeze(0).to(tokens.device)
201 |         cross_mul = cross_mul * mask  # keep only the off-diagonal (mismatched) pairs
202 |         return torch.log(cross_mul.sum(-1).sum(-1)).mean(-1)
203 | 
204 | 
205 |     def forward(self, x, mask_ratio=0.75, beta=1e-4):
206 |         device = x.device
207 |         patches = self.to_patch[0](x)
208 |         batch, num_patches, c, length = patches.shape
209 | 
210 |         num_masked = int(mask_ratio * num_patches)
211 | 
212 |         # masked_indices1: masked indices of the temporal patches
213 |         # masked_indices2: masked indices of the spectral (magnitude / phase) patches
214 |         rand_indices1 = torch.randperm(num_patches // 2, device=device)
215 |         masked_indices1 = rand_indices1[: num_masked // 2].sort()[0]
216 |         unmasked_indices1 = rand_indices1[num_masked // 2:].sort()[0]
217 |         rand_indices2 = torch.randperm(num_patches // 4, device=device)
218 |         masked_indices2, unmasked_indices2 = rand_indices2[: num_masked // 4].sort()[0], rand_indices2[num_masked // 4:].sort()[0]
219 |         rand_indices = torch.cat((unmasked_indices1, masked_indices1,  # unmasked first, then masked, per modality -- this must match the order of pred_pixel_values below
220 |                                   unmasked_indices2 + num_patches // 2, masked_indices2 + num_patches // 2,
221 |                                   unmasked_indices2 + num_patches // 4 * 3, masked_indices2 + num_patches // 4 * 3))
222 | 
223 |         masked_num_t, masked_num_f = masked_indices1.shape[0], 2 * masked_indices2.shape[0]
224 |         unmasked_num_t, unmasked_num_f = unmasked_indices1.shape[0], 2 * unmasked_indices2.shape[0]
225 | 
226 |         # t, m, p refer to temporal, magnitude, phase
227 |         tpatches = patches[:, : num_patches // 2, :, :]
228 |         mpatches, ppatches = patches[:, num_patches // 2: num_patches * 3 // 4, :, :], patches[:, -num_patches // 4:, :, :]
229 | 
230 |         # 1. Generate tokens from patches via CNNs.
231 |         unmasked_tpatches = tpatches[:, unmasked_indices1, :, :]
232 |         unmasked_mpatches, unmasked_ppatches = mpatches[:, unmasked_indices2, :, :], ppatches[:, unmasked_indices2, :, :]
233 |         t_tokens, m_tokens, p_tokens = self.to_patch[1](unmasked_tpatches), self.to_patch[1](unmasked_mpatches), self.to_patch[1](unmasked_ppatches)
234 |         t_tokens, m_tokens, p_tokens = self.encoder.cnn1(t_tokens), self.encoder.cnn2(m_tokens), self.encoder.cnn3(p_tokens)
235 |         Flat = nn.Sequential(nn.AdaptiveAvgPool1d(1),
236 |                              Rearrange('(b n) c 1 -> b n c', b=batch))
237 |         t_tokens, m_tokens, p_tokens = Flat(t_tokens), Flat(m_tokens), Flat(p_tokens)
238 |         ori_tokens = torch.cat((t_tokens, m_tokens, p_tokens), 1).clone()
239 | 
240 |         # 2. Add three cls_tokens before temporal, magnitude and phase tokens.
241 |         cls_tokens = repeat(self.encoder.cls_token, '() n d -> b n d', b=batch)
242 |         tokens = torch.cat((cls_tokens[:, 0:1, :], t_tokens,
243 |                             cls_tokens[:, 1:2, :], m_tokens,
244 |                             cls_tokens[:, 2:3, :], p_tokens), dim=1)
245 | 
246 |         # 3. Generate Positional Embeddings.
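        # Each modality keeps the absolute positions it would occupy in the full, unmasked
        # token sequence: slot 0 is the temporal cls token, unmasked temporal patches are
        # shifted by +1, slot t_idx + 2 holds the magnitude cls token (magnitude patches
        # shift by t_idx + 3), and slot m_idx + 3 holds the phase cls token (phase patches
        # shift by m_idx + 4).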
247 |         t_idx, m_idx, p_idx = num_patches // 2 - 1, num_patches * 3 // 4 - 1, num_patches - 1
248 |         pos_embedding = torch.cat((self.encoder.pos_embedding[:, 0:1, :], self.encoder.pos_embedding[:, unmasked_indices1 + 1, :],
249 |                                    self.encoder.pos_embedding[:, t_idx + 2: t_idx + 3],
250 |                                    self.encoder.pos_embedding[:, unmasked_indices2 + t_idx + 3, :],
251 |                                    self.encoder.pos_embedding[:, m_idx + 3: m_idx + 4],
252 |                                    self.encoder.pos_embedding[:, unmasked_indices2 + m_idx + 4, :]), dim=1)
253 | 
254 |         # 4. Generate Domain-type Embedding
255 |         modal_embedding = torch.cat((repeat(self.modal_embedding[0], '1 d -> 1 n d', n=unmasked_num_t + 1),
256 |                                      repeat(self.modal_embedding[1], '1 d -> 1 n d', n=unmasked_num_f // 2 + 1),
257 |                                      repeat(self.modal_embedding[2], '1 d -> 1 n d', n=unmasked_num_f // 2 + 1)), dim=1)
258 | 
259 |         tokens = tokens + pos_embedding + modal_embedding
260 | 
261 |         encoded_tokens = self.encoder.transformer(tokens)
262 | 
263 |         t_idx, m_idx, p_idx = unmasked_num_t, unmasked_num_f // 2 + unmasked_num_t + 1, -1
264 | 
265 |         idc_loss = self.IDC_loss(self.projs[0](ori_tokens), self.projs[1](torch.cat(([encoded_tokens[:, 1: t_idx + 1], encoded_tokens[:, t_idx + 2: m_idx + 1], encoded_tokens[:, m_idx + 2:]]), dim=1)))
266 | 
267 |         decoder_tokens = encoded_tokens
268 | 
269 |         # repeat mask tokens for number of masked, and add the positions using the masked indices derived above
270 |         mask_tokens1 = repeat(self.mask_token[0], 'd -> b n d', b=batch, n=masked_num_t)
271 |         mask_tokens2 = repeat(self.mask_token[1], 'd -> b n d', b=batch, n=masked_num_f // 2)
272 |         mask_tokens3 = repeat(self.mask_token[2], 'd -> b n d', b=batch, n=masked_num_f // 2)
273 |         mask_tokens = torch.cat((mask_tokens1, mask_tokens2, mask_tokens3), dim=1)
274 | 
275 |         # mask_tokens = repeat(self.mask_token[0], 'd -> b n d', b=batch, n=masked_num_f+masked_num_t)
276 |         decoder_pos_emb = self.decoder_pos_emb(torch.cat(
277 |             (masked_indices1, masked_indices2 + num_patches // 2, masked_indices2 + num_patches * 3 // 4)))
278 | 
279 |         mask_tokens = mask_tokens + decoder_pos_emb
280 |         # concat the masked tokens to the decoder tokens and attend with decoder
281 |         decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim=1)
282 |         decoded_tokens = self.decoder(decoder_tokens)
283 | 
284 |         mask_tokens = decoded_tokens[:, -mask_tokens.shape[1]:]
285 | 
286 |         pred_pixel_values_t = self.to_pixels[0](torch.cat((decoder_tokens[:, 1: t_idx + 1], mask_tokens[:, : masked_num_t]), 1))
287 |         pred_pixel_values_m = self.to_pixels[1](torch.cat((decoder_tokens[:, t_idx + 2: m_idx + 1], mask_tokens[:, masked_num_t: masked_num_f // 2 + masked_num_t]), 1))
288 |         pred_pixel_values_p = self.to_pixels[2](torch.cat((decoder_tokens[:, m_idx + 2: -mask_tokens.shape[1]], mask_tokens[:, -masked_num_f // 2:]), 1))
289 |         pred_pixel_values = torch.cat((pred_pixel_values_t, pred_pixel_values_m, pred_pixel_values_p), dim=1)
290 | 
291 |         recon_loss = F.mse_loss(pred_pixel_values, rearrange(patches[:, rand_indices], 'b n c p -> b n (c p)'))
292 | 
293 |         # print(float(recon_loss), '....', float(idc_loss))
294 |         return recon_loss + beta * idc_loss
295 | 
296 | class Model(nn.Module):
297 |     def __init__(self, seq_len, patch_len, dim, num_class, in_dim):
298 |         super(Model, self).__init__()
299 |         self.encoder = TFR_Encoder(seq_len=seq_len,
300 |                                    patch_len=patch_len,
301 |                                    dim=dim,
302 |                                    num_class=num_class,
303 |                                    in_dim=in_dim)
304 |         self.crt = CRT(encoder=self.encoder,
305 |                        decoder_dim=dim,
306 |                        in_dim=in_dim,
307 |                        patch_len=patch_len)
308 |         self.classifier = MLP(dim, dim // 2, num_class)
309 | 
310 |     def forward(self, x, ssl=False, ratio=0.5):
311 |         if not ssl:
312 |             features = self.encoder(x)
313 |             return self.classifier(features)
314 |         return self.crt(x, mask_ratio=ratio)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Supervised Time Series Representation Learning via Cross Reconstruction Transformer
2 | The official implementation of our TNNLS paper [Self-Supervised Time Series Representation Learning via Cross Reconstruction Transformer](https://arxiv.org/abs/2205.09928).
3 | 
4 | ## Overview of Cross Reconstruction Transformer
5 | *(figure)*
6 | 
7 | *(figure)*
8 | 
9 | ## Getting Started
10 | ### Installation
11 | Git clone our repository and install the required packages with the following commands:
12 | ```
13 | git clone https://github.com/BobZwr/Cross-Reconstruction-Transformer.git
14 | cd Cross-Reconstruction-Transformer
15 | pip install -r requirements.txt
16 | ```
17 | We use torch==1.13.0.
18 | 
19 | ### Processing Data (Optional)
20 | We provide `data_processing.py` to generate phase and magnitude information from the time-domain data (a minimal sketch of the transformation is shown below). You can modify this file to adapt it to your own datasets.
21 | 
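The snippet below sketches what `data_processing.py` computes, using `np.angle`, which folds the script's branch-wise arctangent sign handling into one call; the batch shape `(8, 9, 128)` is a hypothetical choice matching the HAR settings used below (`--in_dim 9`, concatenated length 256):

```python
import numpy as np

x = np.random.randn(8, 9, 128)                 # hypothetical batch: (N, C, L)
freq = np.fft.fft(x)[:, :, :x.shape[-1] // 2]  # keep the first half of the spectrum

magnitude = np.abs(freq)                       # |FFT|, length L // 2
phase = np.angle(freq)                         # phase in (-pi, pi], length L // 2

# temporal (L) + magnitude (L // 2) + phase (L // 2) -> total length 2 * L
data = np.concatenate([x, magnitude, phase], axis=-1)
print(data.shape)                              # (8, 9, 256), i.e. --seq_len 256
```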
22 | ## Training and Evaluating
23 | We provide sample scripts for training and evaluating our CRT:
24 | ```
25 | # For Training:
26 | python main.py --ssl True --sl True --load True --seq_len 256 --patch_len 8 --in_dim 9 --n_classes 6
27 | ```
28 | 
29 | ```
30 | # For Testing:
31 | python main.py --ssl False --sl False --load False --seq_len 256 --patch_len 8 --in_dim 9 --n_classes 6
32 | ```
33 | We also provide a subset of the HAR dataset for training and testing.
34 | 
35 | If you find the code and datasets useful, please cite our paper:
36 | ```
37 | @article{zhang2023self,
38 |   title={Self-Supervised Time Series Representation Learning via Cross Reconstruction Transformer},
39 |   author={Zhang, Wenrui and Yang, Ling and Geng, Shijia and Hong, Shenda},
40 |   journal={IEEE Transactions on Neural Networks and Learning Systems},
41 |   year={2023},
42 |   publisher={IEEE}
43 | }
44 | ```
45 | 
--------------------------------------------------------------------------------
/base_models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import functools
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.utils.data import Dataset
8 | 
9 | 
10 | class SSLDataSet(Dataset):
11 |     def __init__(self, data):
12 |         super(SSLDataSet, self).__init__()
13 |         self.data = data
14 | 
15 |     def __getitem__(self, idx):
16 |         return torch.tensor(self.data[idx], dtype=torch.float)
17 | 
18 |     def __len__(self):
19 |         return self.data.shape[0]
20 | 
21 | class FTDataSet(Dataset):
22 |     def __init__(self, data, label, multi_label=False):
23 |         super(FTDataSet, self).__init__()
24 |         self.data = data
25 |         self.label = label
26 |         self.multi_label = multi_label
27 | 
28 |     def __getitem__(self, index):
29 |         if self.multi_label:
30 |             return (torch.tensor(self.data[index], dtype=torch.float), torch.tensor(self.label[index], dtype=torch.float))
31 |         else:
32 |             return (torch.tensor(self.data[index], dtype=torch.float), torch.tensor(self.label[index], dtype=torch.long))
33 | 
34 |     def __len__(self):
35 |         return self.data.shape[0]
36 | 
37 | # Resnet 1d
38 | 
39 | def conv(in_planes, out_planes, stride=1, kernel_size=3):
40 |     "convolution with padding (zeros are used for padding automatically)"
41 |     return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
42 |                      padding=(kernel_size - 1) // 2, bias=False)
43 | 
44 | class ZeroPad1d(nn.Module):
45 |     def __init__(self, pad_left, pad_right):
46 |         super().__init__()
47 |         self.pad_left = pad_left
48 |         self.pad_right = pad_right
49 | 
50 |     def forward(self, x):
51 |         return F.pad(x, (self.pad_left, self.pad_right))
52 | 
53 | 
54 | class BasicBlock1d(nn.Module):
55 |     expansion = 1
56 | 
57 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
58 |         super().__init__()
59 | 
60 |         self.conv1 = conv(inplanes, planes, stride=stride, kernel_size=3)
61 |         self.bn1 = nn.BatchNorm1d(planes)
62 |         self.relu = nn.ReLU(inplace=True)
63 | 
64 |         self.conv2 = conv(planes, planes, kernel_size=3)
65 |         self.bn2 = nn.BatchNorm1d(planes)
66 | 
67 |         self.downsample = downsample
68 |         self.stride = stride
69 | 
70 |     def forward(self, x):
71 |         residual = x
72 | 
73 |         out = self.conv1(x)
74 |         out = self.bn1(out)
75 |         out = self.relu(out)
76 | 
77 |         out = self.conv2(out)
78 |         out = self.bn2(out)
79 | 
80 |         if self.downsample is not None:
81 |             residual = self.downsample(x)
82 | 
83 |         out += residual
84 |         out = self.relu(out)
85 | 
86 |         return out
87 | 
88 | 
89 | class Bottleneck1d(nn.Module):
90 |     """Bottleneck block (as used in ResNet-50 and deeper variants)"""
91 |     expansion = 4
92 | 
93 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
94 |         super().__init__()
95 |         kernel_size = 3
96 |         self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
97 |         self.bn1 = nn.BatchNorm1d(planes)
98 |         self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=stride,
99 |                                padding=(kernel_size - 1) // 2, bias=False)
100 |         self.bn2 = nn.BatchNorm1d(planes)
101 |         self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=1, bias=False)
102 |         self.bn3 = nn.BatchNorm1d(planes * 4)
103 |         self.relu = nn.ReLU(inplace=True)
104 |         self.downsample = downsample
105 |         self.stride = stride
106 | 
107 |     def forward(self, x):
108 |         residual = x
109 | 
110 |         out = self.conv1(x)
111 |         out = self.bn1(out)
112 |         out = self.relu(out)
113 | 
114 |         out = self.conv2(out)
115 |         out = self.bn2(out)
116 |         out = self.relu(out)
117 | 
118 |         out = self.conv3(out)
119 |         out = self.bn3(out)
120 | 
121 |         if self.downsample is not None:
122 |             residual = self.downsample(x)
123 | 
124 |         out += residual
125 |         out = self.relu(out)
126 | 
127 |         return out
128 | 
129 | class ResNet1d(nn.Module):
130 |     '''1d adaptation of the torchvision resnet'''
131 | 
132 |     def __init__(self, block, layers, kernel_size=3, input_channels=12, inplanes=64,
133 |                  fix_feature_dim=False, kernel_size_stem=None, stride_stem=2, pooling_stem=True,
134 |                  stride=2):
135 |         super(ResNet1d, self).__init__()
136 | 
137 |         self.inplanes = inplanes
138 |         layers_tmp = []
139 |         if kernel_size_stem is None:
140 |             kernel_size_stem = kernel_size[0] if isinstance(kernel_size, list) else kernel_size
141 | 
142 |         # conv-bn-relu (basic feature extraction)
143 |         layers_tmp.append(nn.Conv1d(input_channels, inplanes,
144 |                                     kernel_size=kernel_size_stem,
145 |                                     stride=stride_stem,
146 |                                     padding=(kernel_size_stem - 1) // 2, bias=False))
147 |         layers_tmp.append(nn.BatchNorm1d(inplanes))
148 |         layers_tmp.append(nn.ReLU(inplace=True))
149 | 
150 |         if pooling_stem is True:
151 |             layers_tmp.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1))
152 | 
153 |         for i, l in enumerate(layers):
154 |             if i == 0:
155 |                 layers_tmp.append(self._make_block(block, inplanes, layers[0]))
156 |             else:
157 |                 layers_tmp.append(
158 |                     self._make_block(block, inplanes if fix_feature_dim else (2 ** i) * inplanes, layers[i],
159 |                                      stride=stride))
160 | 
161 |         self.feature_extractor = nn.Sequential(*layers_tmp)
162 | 
163 |     def _make_block(self, block, planes, blocks, stride=1, kernel_size=3):
164 |         down_sample = None
165 | 
166 |         # a downsampling shortcut is needed whenever the stride or the channel width changes
167 |         if stride != 1 or self.inplanes != planes * block.expansion:
168 |             down_sample = nn.Sequential(
169 |                 nn.Conv1d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
170 |                 nn.BatchNorm1d(planes * block.expansion),
171 |             )
172 | 
173 |         layers = []
174 |         layers.append(block(self.inplanes, planes, stride, down_sample))
175 |         self.inplanes = planes * block.expansion
176 | 
177 |         for i in range(1, blocks):
178 |             layers.append(block(self.inplanes, planes))
179 | 
180 |         return nn.Sequential(*layers)
181 | 
182 |     def forward(self, x):
183 |         return self.feature_extractor(x)
184 | 
185 | def resnet1d14(inplanes, input_channels):
186 |     return ResNet1d(BasicBlock1d, [2, 2, 2], inplanes=inplanes, input_channels=input_channels)
187 | 
188 | def resnet1d18(**kwargs):
189 |     return ResNet1d(BasicBlock1d, [2, 2, 2, 2], **kwargs)
190 | 
191 | # MLP
192 | class MLP(nn.Module):
193 |     def __init__(self, in_channels, hidden_channels, n_classes, bn=True):
194 |         super(MLP, self).__init__()
195 |         self.in_channels = in_channels
196 |         self.n_classes = n_classes
197 |         self.hidden_channels = hidden_channels
198 |         self.fc1 = nn.Linear(self.in_channels, self.hidden_channels)
199 |         self.fc2 = nn.Linear(self.hidden_channels, self.n_classes)
200 |         self.ac = nn.ReLU()
201 |         self.bn = nn.BatchNorm1d(hidden_channels)
202 | 
203 |     def forward(self, x):
204 |         hidden = self.fc1(x)
205 |         hidden = self.ac(hidden)
206 |         hidden = self.bn(hidden)
207 |         out = self.fc2(hidden)
208 | 
209 |         return out
210 | 
211 | # Time-steps features -> aggregated features
212 | class Flatten(nn.Module):
213 |     def __init__(self):
214 |         super(Flatten, self).__init__()
215 | 
216 |     def forward(self, tensor):
217 |         b = tensor.size(0)
218 |         return tensor.reshape(b, -1)
219 | 
220 | class AdaptiveConcatPool1d(nn.Module):
221 |     "Layer that concats `AdaptiveAvgPool1d`, `AdaptiveMaxPool1d` and the last time step."
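    # With the default sz=1, forward() returns 3 * C output channels per example
    # (max-pool, avg-pool and the last time step), which is why create_head1d below
    # sizes the first linear layer of the head from 3 * nf.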
222 | 
223 |     def __init__(self, sz=None):
224 |         "With sz=1 (the default), the output concatenates max-pool, avg-pool and the last step along the channel dimension"
225 |         super().__init__()
226 |         sz = sz or 1
227 |         self.ap, self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz)
228 | 
229 |     def forward(self, x):
230 |         """x has shape (B, C, T)"""
231 |         return torch.cat([self.mp(x), self.ap(x), x[..., -1:]], 1)
232 | 
233 | def bn_drop_lin(n_in, n_out, bn, p, actn):
234 |     "`n_in`->bn->dropout->linear(`n_in`,`n_out`)->`actn`"
235 |     layers = list()
236 | 
237 |     if bn:
238 |         layers.append(nn.BatchNorm1d(n_in))
239 | 
240 |     if p > 0.:
241 |         layers.append(nn.Dropout(p=p))
242 | 
243 |     layers.append(nn.Linear(n_in, n_out))
244 | 
245 |     if actn is not None:
246 |         layers.append(actn)
247 | 
248 |     return layers
249 | 
250 | def create_head1d(nf: int, nc: int, lin_ftrs=[512, ], dropout=0.5, bn: bool = True, act="relu"):
251 |     lin_ftrs = [3 * nf] + lin_ftrs + [nc]
252 | 
253 |     activations = [nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)] * (len(lin_ftrs) - 2) + [None]
254 |     layers = [AdaptiveConcatPool1d(), Flatten()]
255 | 
256 |     for ni, no, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], activations):
257 |         layers += bn_drop_lin(ni, no, bn, dropout, actn)
258 | 
259 |     layers += [nn.Sigmoid()]
260 | 
261 |     return nn.Sequential(*layers)
262 | 
--------------------------------------------------------------------------------
/data_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | 
4 | data_path = '...'
5 | 
6 | time = np.load(data_path)
7 | 
8 | freq = np.fft.fft(time)[:, :, :time.shape[-1] // 2]
9 | 
10 | a = freq.real
11 | b = freq.imag
12 | 
13 | magnitude = np.abs(freq)
14 | 
15 | phase = np.zeros_like(a)  # branch-wise arctangent; equivalent to np.angle(freq) up to the a < 0, b == 0 boundary case
16 | phase[a > 0] = np.arctan(b / a)[a > 0]
17 | phase[a < 0] = (np.arctan(b / a) + np.sign(b) * np.pi)[a < 0]
18 | phase[a == 0] = (np.sign(b) * np.pi / 2)[a == 0]
19 | 
20 | data = np.concatenate([time, magnitude, phase], -1)
21 | 
22 | np.save(os.path.join(data_path, 'time_freq_data.npy'), data)  # np.save expects the target path first, then the array
23 | 
--------------------------------------------------------------------------------
/dataset/har_test_all.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_test_all.npy
--------------------------------------------------------------------------------
/dataset/har_test_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_test_label.npy
--------------------------------------------------------------------------------
/dataset/har_train_all.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_train_all.npy
--------------------------------------------------------------------------------
/dataset/har_train_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_train_label.npy
--------------------------------------------------------------------------------
/dataset/har_valid_all.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_valid_all.npy
--------------------------------------------------------------------------------
/dataset/har_valid_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_valid_label.npy
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import random
4 | 
5 | from tqdm import tqdm, trange
6 | from sklearn.metrics import roc_auc_score
7 | 
8 | from CRT import CRT, TFR_Encoder, Model
9 | from base_models import SSLDataSet, FTDataSet
10 | 
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | from torch.utils.data import DataLoader
16 | 
17 | 
18 | 
19 | def self_supervised_learning(model, X, n_epoch, lr, batch_size, device, min_ratio=0.3, max_ratio=0.8):
20 |     optimizer = optim.Adam(model.parameters(), lr)
21 |     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20)
22 | 
23 |     model.to(device)
24 |     model.train()
25 | 
26 |     dataset = SSLDataSet(X)
27 |     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
28 | 
29 |     losses = []
30 | 
31 |     pbar = trange(n_epoch)
32 |     for epoch in pbar:
33 |         for batch in dataloader:
34 |             x = batch.to(device)
35 |             loss = model(x, ssl=True, ratio=max(min_ratio, min(max_ratio, epoch / n_epoch)))  # the mask ratio grows linearly over epochs, clipped to [min_ratio, max_ratio]
36 |             optimizer.zero_grad()
37 |             loss.backward()
38 |             optimizer.step()
39 |             losses.append(float(loss))
40 |         scheduler.step(sum(losses) / len(losses))  # ReduceLROnPlateau expects the monitored metric (here the running mean loss), not the epoch index
41 |         pbar.set_description(str(sum(losses) / len(losses)))
42 |     torch.save(model.to('cpu'), 'Pretrained_Model.pkl')
43 | 
44 | def finetuning(model, train_set, valid_set, n_epoch, lr, batch_size, device, multi_label=True):
45 |     # multi_label: whether the classification task is a multi-label task.
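    # The two-stage loop below first fine-tunes only the classifier head and then the whole
    # model, keeping the checkpoint with the best validation AUC in 'Finetuned_Model.pkl'.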
46 |     model.train()
47 |     model.to(device)
48 | 
49 |     train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
50 | 
51 |     loss_func = nn.BCEWithLogitsLoss() if multi_label else nn.CrossEntropyLoss()
52 | 
53 |     for stage in range(2):
54 |         # stage0: finetuning only classifier; stage1: finetuning whole model
55 |         best_auc = 0
56 |         step = 0
57 |         if stage == 0:
58 |             min_lr = 1e-6
59 |             optimizer = optim.Adam(model.classifier.parameters(), lr=lr)
60 |         else:
61 |             min_lr = 1e-8
62 |             optimizer = optim.Adam(model.parameters(), lr=lr / 2)
63 |         scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, mode='max', factor=0.8, min_lr=min_lr)
64 |         pbar = trange(n_epoch)
65 |         for _ in pbar:
66 |             for batch_idx, batch in enumerate(train_loader):
67 |                 step += 1
68 |                 x, y = tuple(t.to(device) for t in batch)
69 |                 pred = model(x)
70 |                 loss = loss_func(pred, y)
71 |                 optimizer.zero_grad()
72 |                 loss.backward()
73 |                 optimizer.step()
74 |                 if step % 10 == 0:
75 |                     valid_auc = test(model, valid_set, batch_size, multi_label)
76 |                     pbar.set_description('Best Validation AUC: {:.4f} --------- AUC on this step: {:.4f}'.format(best_auc, valid_auc))
77 |                     if valid_auc > best_auc:
78 |                         best_auc = valid_auc
79 |                         torch.save(model, 'Finetuned_Model.pkl')
80 |             scheduler.step(best_auc)
81 | 
82 | def test(model, dataset, batch_size, multi_label):
83 |     model.eval()
84 |     testloader = DataLoader(dataset, batch_size=batch_size)
85 | 
86 |     pred_prob = []
87 |     with torch.no_grad():
88 |         for batch in testloader:
89 |             x, y = tuple(t.to(device) for t in batch)
90 |             pred = model(x)
91 |             pred = torch.sigmoid(pred) if multi_label else F.softmax(pred, dim=1)
92 |             pred_prob.extend([i.cpu().detach().numpy().tolist() for i in pred])
93 |     auc = roc_auc_score(dataset.label, pred_prob, multi_class='ovr')
94 | 
95 |     print('AUC is {:.2f}'.format(auc * 100))
96 |     print('More metrics can be added in the test function.')
97 |     model.train()
98 |     return auc
99 | 
100 | def str2bool(v):
101 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
102 |         return True
103 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
104 |         return False
105 |     else:
106 |         raise argparse.ArgumentTypeError('Unsupported value encountered.')
107 | 
108 | def set_seed(seed):
109 |     torch.manual_seed(seed)
110 |     torch.cuda.manual_seed(seed)
111 |     torch.cuda.manual_seed_all(seed)
112 |     np.random.seed(seed)  # Numpy module.
113 |     random.seed(seed)  # Python random module.
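    # the two cuDNN flags below trade speed for reproducibility: benchmark=False
    # disables convolution algorithm auto-tuning, and deterministic=True forces
    # deterministic kernels where available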
114 |     torch.backends.cudnn.benchmark = False
115 |     torch.backends.cudnn.deterministic = True
116 | 
117 | if __name__ == '__main__':
118 |     parser = argparse.ArgumentParser()
119 |     parser.add_argument("--ssl", type=str2bool, default=False)
120 |     parser.add_argument("--sl", type=str2bool, default=True)
121 |     parser.add_argument("--load", type=str2bool, default=True)
122 |     parser.add_argument("--test", type=str2bool, default=True)
123 |     # all default values of parameters are for PTB-XL
124 |     parser.add_argument("--seq_len", type=int, default=10000)
125 |     parser.add_argument("--patch_len", type=int, default=20)
126 |     parser.add_argument("--dim", type=int, default=128)
127 |     parser.add_argument("--in_dim", type=int, default=12)
128 |     parser.add_argument("--n_classes", type=int, default=5)
129 |     opt = parser.parse_args()
130 | 
131 |     set_seed(0)
132 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
133 | 
134 |     seq_len = opt.seq_len
135 |     patch_len = opt.patch_len
136 |     dim = opt.dim
137 |     in_dim = opt.in_dim
138 |     n_classes = opt.n_classes
139 | 
140 |     if opt.ssl:
141 |         model = Model(seq_len, patch_len, dim, n_classes, in_dim).to(device)
142 |         X = np.load('./dataset/har_train_all.npy')
143 |         self_supervised_learning(model, X, 100, 1e-3, 128, device)
144 |     if opt.load:
145 |         model = torch.load('Pretrained_Model.pkl', map_location=device)
146 |     else:
147 |         model = Model(seq_len, patch_len, dim, n_classes, in_dim).to(device)
148 |     if opt.sl:
149 |         train_X, train_y = np.load('./dataset/har_train_all.npy'), np.load('./dataset/har_train_label.npy')
150 |         valid_X, valid_y = np.load('./dataset/har_valid_all.npy'), np.load('./dataset/har_valid_label.npy')
151 |         TrainSet = FTDataSet(train_X, train_y, multi_label=False)  # multi_label = True when the dataset is PTBXL
152 |         ValidSet = FTDataSet(valid_X, valid_y, multi_label=False)
153 |         finetuning(model, TrainSet, ValidSet, 100, 1e-3, 128, device, multi_label=False)
154 |     if opt.test:
155 |         test_X, test_y = np.load('./dataset/har_test_all.npy'), np.load('./dataset/har_test_label.npy')
156 |         TestSet = FTDataSet(test_X, test_y, multi_label=False)
157 |         model = torch.load('Finetuned_Model.pkl', map_location=device)
158 |         test(model, TestSet, 100, multi_label=False)
159 | 
160 | 
161 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.5.0
2 | numpy==1.23.4
3 | scikit-learn==1.1.2
4 | tqdm==4.64.0
--------------------------------------------------------------------------------