├── CRT.py
├── README.md
├── base_models.py
├── data_processing.py
├── dataset
│   ├── har_test_all.npy
│   ├── har_test_label.npy
│   ├── har_train_all.npy
│   ├── har_train_label.npy
│   ├── har_valid_all.npy
│   └── har_valid_label.npy
├── main.py
└── requirements.txt
/CRT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange, repeat
6 | from einops.layers.torch import Rearrange
7 |
8 | from base_models import MLP, resnet1d18
9 |
10 |
11 | class cnn_extractor(nn.Module):
12 | def __init__(self, dim, input_plane):
13 | super(cnn_extractor, self).__init__()
14 | self.cnn = resnet1d18(input_channels=dim, inplanes=input_plane)
15 |
16 | def forward(self, x):
17 | x = self.cnn(x)
18 | return x
19 |
20 |
21 | class PreNorm(nn.Module):
22 | def __init__(self, dim, fn):
23 | super().__init__()
24 | self.norm = nn.LayerNorm(dim)
25 | self.fn = fn
26 |
27 | def forward(self, x, **kwargs):
28 | return self.fn(self.norm(x), **kwargs)
29 |
30 |
31 | class FeedForward(nn.Module):
32 | def __init__(self, dim, hidden_dim, dropout=0.):
33 | super().__init__()
34 | self.net = nn.Sequential(
35 | nn.Linear(dim, hidden_dim),
36 | nn.GELU(),
37 | nn.Dropout(dropout),
38 | nn.Linear(hidden_dim, dim),
39 | nn.Dropout(dropout)
40 | )
41 |
42 | def forward(self, x):
43 | return self.net(x)
44 |
45 |
46 | class Attention(nn.Module):
47 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
48 | super().__init__()
49 | inner_dim = dim_head * heads
50 | project_out = not (heads == 1 and dim_head == dim)
51 |
52 | self.heads = heads
53 | self.scale = dim_head ** -0.5
54 |
55 | self.attend = nn.Softmax(dim=-1)
56 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
57 |
58 | self.to_out = nn.Sequential(
59 | nn.Linear(inner_dim, dim),
60 | nn.Dropout(dropout)
61 | ) if project_out else nn.Identity()
62 |
63 | def forward(self, x):
64 | qkv = self.to_qkv(x).chunk(3, dim=-1)
65 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
66 |
67 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
68 |
69 | attn = self.attend(dots)
70 |
71 | out = torch.matmul(attn, v)
72 | out = rearrange(out, 'b h n d -> b n (h d)')
73 | return self.to_out(out)
74 |
75 |
76 | class Transformer(nn.Module):
77 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
78 | super().__init__()
79 | self.layers = nn.ModuleList([])
80 | for _ in range(depth):
81 | self.layers.append(nn.ModuleList([
82 | PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
83 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
84 | ]))
85 |
86 | def forward(self, x):
87 | for attn, ff in self.layers:
88 | x = attn(x) + x
89 | x = ff(x) + x
90 | return x
91 |
92 |
93 | class TFR(nn.Module):
94 | def __init__(self, seq_len, patch_len, num_classes, dim, depth, heads, mlp_dim, channels=12,
95 | dim_head=64, dropout=0., emb_dropout=0.):
96 | '''
97 | The encoder of CRT: CNN patch extractors for the temporal / magnitude / phase segments, followed by a shared Transformer.
98 | '''
99 | super().__init__()
100 |
101 | assert seq_len % (4 * patch_len) == 0, \
102 | 'seq_len should be a multiple of 4 * patch_len; otherwise a patch would mix magnitude and phase data.'
103 |
104 | num_patches = seq_len // patch_len
105 | patch_dim = channels * patch_len
106 |
107 | self.to_patch = nn.Sequential(Rearrange('b c (n p1) -> b n c p1', p1=patch_len),
108 | Rearrange('b n c p1 -> (b n) c p1'))
109 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 3, dim))
110 | self.modal_embedding = nn.Parameter(torch.randn(3, 1, dim))
111 | self.cls_token = nn.Parameter(torch.randn(1, 3, dim))
112 | self.dropout = nn.Dropout(emb_dropout)
113 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
114 |
115 | self.cnn1 = cnn_extractor(dim=channels, input_plane=dim // 8) # For temporal data
116 | self.cnn2 = cnn_extractor(dim=channels, input_plane=dim // 8) # For magnitude data
117 | self.cnn3 = cnn_extractor(dim=channels, input_plane=dim // 8) # For phase data
118 |
119 | def forward(self, x):
120 | batch, _, time_steps = x.shape
121 | # t, m, p refer to temporal features, magnitude features, and phase features respectively
122 | # If the temporal data has length L, the magnitude and phase data each have length L // 2 here; these lengths can be adjusted by users.
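# e.g., with time_steps = 256: t = x[:, :, :128], m = x[:, :, 128:192], p = x[:, :, 192:]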
123 | t, m, p = x[:, :, :time_steps // 2], x[:, :, time_steps // 2: time_steps * 3 // 4], x[:, :, -time_steps // 4:]
124 | t, m, p = self.to_patch(t), self.to_patch(m), self.to_patch(p)
125 | patch2seq = nn.Sequential(nn.AdaptiveAvgPool1d(1),
126 | Rearrange('(b n) c 1 -> b n c', b=batch))
127 |
128 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch)
129 |
130 | x = torch.cat((cls_tokens[:, 0:1, :], patch2seq(self.cnn1(t)),
131 | cls_tokens[:, 1:2, :], patch2seq(self.cnn2(m)),
132 | cls_tokens[:, 2:3, :], patch2seq(self.cnn3(p))), dim=1)
133 |
134 | b, t, c = x.shape  # t = (number of patch tokens) + 3 cls tokens
135 | num_tokens = t - 3
136 | t_token_idx, m_token_idx, p_token_idx = 0, num_tokens // 2 + 1, num_tokens * 3 // 4 + 2
137 | x[:, :m_token_idx] += self.modal_embedding[:1]  # index the token dimension, not the batch dimension
138 | x[:, m_token_idx: p_token_idx] += self.modal_embedding[1:2]
139 | x[:, p_token_idx:] += self.modal_embedding[2:]
140 | x += self.pos_embedding[:, : t]
141 | x = self.dropout(x)
142 | x = self.transformer(x)
143 | t_token, m_token, p_token = x[:, t_token_idx], x[:, m_token_idx], x[:, p_token_idx]
144 | avg = (t_token + m_token + p_token) / 3
145 | return avg
146 |
147 |
148 | def TFR_Encoder(seq_len, patch_len, dim, num_class, in_dim):
149 | vit = TFR(seq_len=seq_len,
150 | patch_len=patch_len,
151 | num_classes=num_class,
152 | dim=dim,
153 | depth=6,
154 | heads=8,
155 | mlp_dim=dim,
156 | dropout=0.2,
157 | emb_dropout=0.1,
158 | channels=in_dim)
159 | return vit
160 |
161 | class CRT(nn.Module):
162 | def __init__(
163 | self,
164 | encoder,
165 | decoder_dim,
166 | decoder_depth=2,
167 | decoder_heads=8,
168 | decoder_dim_head=64,
169 | patch_len=20,
170 | in_dim=12
171 | ):
172 | super().__init__()
173 | self.encoder = encoder
174 | num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
175 | self.to_patch = encoder.to_patch
176 | pixel_values_per_patch = in_dim * patch_len
177 |
178 | # decoder parameters
179 | self.modal_embedding = self.encoder.modal_embedding
180 | self.mask_token = nn.Parameter(torch.randn(3, decoder_dim))
181 | self.decoder = Transformer(dim=decoder_dim,
182 | depth=decoder_depth,
183 | heads=decoder_heads,
184 | dim_head=decoder_dim_head,
185 | mlp_dim=decoder_dim)
186 | self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
187 | self.to_pixels = nn.ModuleList([nn.Linear(decoder_dim, pixel_values_per_patch) for _ in range(3)])
188 | self.projs = nn.ModuleList([nn.Linear(decoder_dim, decoder_dim) for _ in range(2)])
189 |
190 | def IDC_loss(self, tokens, encoded_tokens):
191 | '''
192 | IDC loss: penalizes similarity between non-corresponding pairs of tokens
193 | :param tokens: tokens before the Transformer
194 | :param encoded_tokens: the same tokens after the Transformer
195 | '''
196 | B, T, D = tokens.shape
197 | tokens, encoded_tokens = F.normalize(tokens, dim=-1), F.normalize(encoded_tokens, dim=-1)
198 | encoded_tokens = encoded_tokens.transpose(2, 1)
199 | cross_mul = torch.exp(torch.matmul(tokens, encoded_tokens))
200 | mask = (1 - torch.eye(T)).unsqueeze(0).to(tokens.device)  # zero the diagonal so corresponding pairs are excluded
201 | cross_mul = cross_mul * mask
202 | return torch.log(cross_mul.sum(-1).sum(-1)).mean(-1)
203 |
204 |
205 | def forward(self, x, mask_ratio=0.75, beta=1e-4):
206 | device = x.device
207 | patches = self.to_patch[0](x)
208 | batch, num_patches, c, length = patches.shape
209 |
210 | num_masked = int(mask_ratio * num_patches)
211 |
212 | # masked_indices1: masked index of temporal features
213 | # masked_indices2: masked index of spectral features
214 | rand_indices1 = torch.randperm(num_patches // 2, device=device)
215 | masked_indices1 = rand_indices1[: num_masked // 2].sort()[0]
216 | unmasked_indices1 = rand_indices1[num_masked // 2:].sort()[0]
217 | rand_indices2 = torch.randperm(num_patches // 4, device=device)
218 | masked_indices2, unmasked_indices2 = rand_indices2[: num_masked // 4].sort()[0], rand_indices2[num_masked // 4:].sort()[0]
219 | rand_indices = torch.cat((unmasked_indices1, masked_indices1,  # unmasked first, matching the order of pred_pixel_values below
220 | unmasked_indices2 + num_patches // 2, masked_indices2 + num_patches // 2,
221 | unmasked_indices2 + num_patches // 4 * 3, masked_indices2 + num_patches // 4 * 3))
222 |
223 | masked_num_t, masked_num_f = masked_indices1.shape[0], 2 * masked_indices2.shape[0]
224 | unmasked_num_t, unmasked_num_f = unmasked_indices1.shape[0], 2 * unmasked_indices2.shape[0]
225 |
226 | # t, m, p refer to temporal, magnitude, phase
227 | tpatches = patches[:, : num_patches // 2, :, :]
228 | mpatches, ppatches = patches[:, num_patches // 2: num_patches * 3 // 4, :, :], patches[:, -num_patches // 4:, :, :]
229 |
230 | # 1. Generate tokens from patches via CNNs.
231 | unmasked_tpatches = tpatches[:, unmasked_indices1, :, :]
232 | unmasked_mpatches, unmasked_ppatches = mpatches[:, unmasked_indices2, :, :], ppatches[:, unmasked_indices2, :, :]
233 | t_tokens, m_tokens, p_tokens = self.to_patch[1](unmasked_tpatches), self.to_patch[1](unmasked_mpatches), self.to_patch[1](unmasked_ppatches)
234 | t_tokens, m_tokens, p_tokens = self.encoder.cnn1(t_tokens), self.encoder.cnn2(m_tokens), self.encoder.cnn3(p_tokens)
235 | Flat = nn.Sequential(nn.AdaptiveAvgPool1d(1),
236 | Rearrange('(b n) c 1 -> b n c', b=batch))
237 | t_tokens, m_tokens, p_tokens = Flat(t_tokens), Flat(m_tokens), Flat(p_tokens)
238 | ori_tokens = torch.cat((t_tokens, m_tokens, p_tokens), 1).clone()
239 |
240 | # 2. Add three cls_tokens before temporal, magnitude and phase tokens.
241 | cls_tokens = repeat(self.encoder.cls_token, '() n d -> b n d', b=batch)
242 | tokens = torch.cat((cls_tokens[:, 0:1, :], t_tokens,
243 | cls_tokens[:, 1:2, :], m_tokens,
244 | cls_tokens[:, 2:3, :], p_tokens), dim=1)
245 |
246 | # 3. Generate Positional Embeddings.
247 | t_idx, m_idx, p_idx = num_patches // 2 - 1, num_patches * 3 // 4 - 1, num_patches - 1
248 | pos_embedding = torch.cat((self.encoder.pos_embedding[:, 0:1, :], self.encoder.pos_embedding[:, unmasked_indices1 + 1, :],
249 | self.encoder.pos_embedding[:, t_idx + 2: t_idx + 3],
250 | self.encoder.pos_embedding[:, unmasked_indices2 + t_idx + 3, :],
251 | self.encoder.pos_embedding[:, m_idx + 3: m_idx + 4],
252 | self.encoder.pos_embedding[:, unmasked_indices2 + m_idx + 4, :]), dim=1)
253 |
254 | # 4. Generate Domain-type Embedding
255 | modal_embedding = torch.cat((repeat(self.modal_embedding[0], '1 d -> 1 n d', n=unmasked_num_t + 1),
256 | repeat(self.modal_embedding[1], '1 d -> 1 n d', n=unmasked_num_f // 2 + 1),
257 | repeat(self.modal_embedding[2], '1 d -> 1 n d', n=unmasked_num_f // 2 + 1)), dim=1)
258 |
259 | tokens = tokens + pos_embedding + modal_embedding
260 |
261 | encoded_tokens = self.encoder.transformer(tokens)
262 |
263 | t_idx, m_idx, p_idx = unmasked_num_t, unmasked_num_f // 2 + unmasked_num_t + 1, -1
264 |
265 | idc_loss = self.IDC_loss(self.projs[0](ori_tokens), self.projs[1](torch.cat(([encoded_tokens[:, 1: t_idx+1], encoded_tokens[:, t_idx+2: m_idx+1], encoded_tokens[:, m_idx+2: ]]), dim=1)))
266 |
267 | decoder_tokens = encoded_tokens
268 |
269 | # repeat mask tokens for the number of masked patches, and add positional embeddings at the masked indices derived above
270 | mask_tokens1 = repeat(self.mask_token[0], 'd -> b n d', b=batch, n=masked_num_t)
271 | mask_tokens2 = repeat(self.mask_token[1], 'd -> b n d', b=batch, n=masked_num_f // 2)
272 | mask_tokens3 = repeat(self.mask_token[2], 'd -> b n d', b=batch, n=masked_num_f // 2)
273 | mask_tokens = torch.cat((mask_tokens1, mask_tokens2, mask_tokens3), dim=1)
274 |
276 | decoder_pos_emb = self.decoder_pos_emb(torch.cat(
277 | (masked_indices1, masked_indices2 + num_patches // 2, masked_indices2 + num_patches * 3 // 4)))
278 |
279 | mask_tokens = mask_tokens + decoder_pos_emb
280 | # concat the masked tokens to the decoder tokens and attend with decoder
281 | decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim=1)
282 | decoded_tokens = self.decoder(decoder_tokens)
283 |
284 | mask_tokens = decoded_tokens[:, -mask_tokens.shape[1]:]
285 |
286 | pred_pixel_values_t = self.to_pixels[0](torch.cat((decoder_tokens[:, 1: t_idx + 1], mask_tokens[:, : masked_num_t]), 1))
287 | pred_pixel_values_m = self.to_pixels[1](torch.cat((decoder_tokens[:, t_idx+2: m_idx+1], mask_tokens[:, masked_num_t: masked_num_f // 2 + masked_num_t]), 1))
288 | pred_pixel_values_p = self.to_pixels[2](torch.cat((decoder_tokens[:, m_idx+2: -mask_tokens.shape[1]], mask_tokens[:, -masked_num_f // 2:]), 1))
289 | pred_pixel_values = torch.cat((pred_pixel_values_t, pred_pixel_values_m, pred_pixel_values_p), dim=1)
290 |
291 | recon_loss = F.mse_loss(pred_pixel_values, rearrange(patches[:, rand_indices], 'b n c p -> b n (c p)'))
292 |
294 | return recon_loss + beta * idc_loss
295 |
296 | class Model(nn.Module):
297 | def __init__(self, seq_len, patch_len, dim, num_class, in_dim):
298 | super(Model, self).__init__()
299 | self.encoder = TFR_Encoder(seq_len=seq_len,
300 | patch_len=patch_len,
301 | dim=dim,
302 | num_class=num_class,
303 | in_dim=in_dim)
304 | self.crt = CRT(encoder=self.encoder,
305 | decoder_dim=dim,
306 | in_dim=in_dim,
307 | patch_len=patch_len)
308 | self.classifier = MLP(dim, dim//2, num_class)
309 |
310 | def forward(self, x, ssl=False, ratio=0.5):
311 | if not ssl:
312 | features = self.encoder(x)
313 | return self.classifier(features)
314 | return self.crt(x, mask_ratio=ratio)
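
# A minimal shape sanity check (sizes are illustrative, mirroring the HAR example
# in the README); run this file directly to execute it:
if __name__ == '__main__':
    model = Model(seq_len=256, patch_len=8, dim=128, num_class=6, in_dim=9)
    x = torch.randn(2, 9, 256)  # (batch, channels, time + magnitude + phase)
    print(model(x).shape)  # classification logits: (2, 6)
    print(float(model(x, ssl=True, ratio=0.5)))  # scalar self-supervised loss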
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Supervised Time Series Representation Learning via Cross Reconstruction Transformer
2 | The official implementation for our TNNLS paper [Self-Supervised Time Series Representation Learning via Cross Reconstruction Transformer](https://arxiv.org/abs/2205.09928).
3 |
4 | ## Overview of Cross Reconstruction Transformer
5 |
6 |
7 |
8 |
9 | ## Getting Started
10 | ### Installation
11 | Clone our repository and install the required packages with the following commands:
12 | ```
13 | git clone https://github.com/BobZwr/Cross-Reconstruction-Transformer.git
14 | cd Cross-Reconstruction-Transformer
15 | pip install -r requirements.txt
16 | ```
17 | We use torch==1.13.0. Note that PyTorch is not listed in requirements.txt, so install it separately.
18 |
19 | ### Processing Data (Optional)
20 | We provide `data_processing.py` to generate phase and magnitude information based on the time-domain data. You can modify this file to adapt it to your own datasets.
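
For reference, a minimal sketch of the layout this script produces (the shapes are illustrative; `np.angle` computes essentially the same quadrant-aware phase as the script):
```
import numpy as np

x = np.random.randn(8, 9, 128)                    # (N, C, L) time-domain batch
freq = np.fft.fft(x)[:, :, :x.shape[-1] // 2]     # keep the first L // 2 frequency bins
magnitude, phase = np.abs(freq), np.angle(freq)
data = np.concatenate([x, magnitude, phase], -1)  # (N, C, 2L): L time | L/2 magnitude | L/2 phase
print(data.shape)                                 # (8, 9, 256), i.e. seq_len = 256
```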
21 |
22 | ## Training and Evaluating
23 | We provide sample commands for training and evaluating our CRT:
24 | ```
25 | # For Training:
26 | python main.py --ssl True --sl True --load True --seq_len 256 --patch_len 8 --in_dim 9 --n_classes 6
27 | ```
28 |
29 | ```
30 | # For Testing:
31 | python main.py --ssl False --sl False --load False --seq_len 256 --patch_len 8 --in_dim 9 --n_classes 6
32 | ```
33 | We also provide a subset of the HAR dataset for training and testing.
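
After finetuning, the saved checkpoint can be used for inference directly. A minimal sketch (the checkpoint name follows `main.py`; the input shapes assume the provided HAR subset):
```
import numpy as np
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('Finetuned_Model.pkl', map_location=device)  # saved by finetuning()
model.eval()

x = torch.tensor(np.load('./dataset/har_test_all.npy')[:4], dtype=torch.float, device=device)
with torch.no_grad():
    probs = torch.softmax(model(x), dim=1)  # (4, n_classes) class probabilities
print(probs.argmax(dim=1))
```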
34 |
35 | If you find the code and datasets useful, please cite our paper:
36 | ```
37 | @article{zhang2023self,
38 | title={Self-Supervised Time Series Representation Learning via Cross Reconstruction Transformer},
39 | author={Zhang, Wenrui and Yang, Ling and Geng, Shijia and Hong, Shenda},
40 | journal={IEEE Transactions on Neural Networks and Learning Systems},
41 | year={2023},
42 | publisher={IEEE}
43 | }
44 | ```
45 |
--------------------------------------------------------------------------------
/base_models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import functools
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.utils.data import Dataset
8 |
9 |
10 | class SSLDataSet(Dataset):
11 | def __init__(self, data):
12 | super(SSLDataSet, self).__init__()
13 | self.data = data
14 |
15 | def __getitem__(self, idx):
16 | return torch.tensor(self.data[idx], dtype=torch.float)
17 |
18 | def __len__(self):
19 | return self.data.shape[0]
20 |
21 | class FTDataSet(Dataset):
22 | def __init__(self, data, label, multi_label=False):
23 | super(FTDataSet, self).__init__()
24 | self.data = data
25 | self.label = label
26 | self.multi_label = multi_label
27 |
28 | def __getitem__(self, index):
29 | if self.multi_label:
30 | return (torch.tensor(self.data[index], dtype=torch.float), torch.tensor(self.label[index], dtype=torch.float))
31 | else:
32 | return (torch.tensor(self.data[index], dtype=torch.float), torch.tensor(self.label[index], dtype=torch.long))
33 |
34 | def __len__(self):
35 | return self.data.shape[0]
36 |
37 | # Resnet 1d
38 |
39 | def conv(in_planes, out_planes, stride=1, kernel_size=3):
40 | "convolution with padding 自动使用zeros进行padding"
41 | return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
42 | padding=(kernel_size - 1) // 2, bias=False)
43 |
44 | class ZeroPad1d(nn.Module):
45 | def __init__(self, pad_left, pad_right):
46 | super().__init__()
47 | self.pad_left = pad_left
48 | self.pad_right = pad_right
49 |
50 | def forward(self, x):
51 | return F.pad(x, (self.pad_left, self.pad_right))
52 |
53 |
54 | class BasicBlock1d(nn.Module):
55 | expansion = 1
56 |
57 | def __init__(self, inplanes, planes, stride=1, downsample=None):
58 | super().__init__()
59 |
60 | self.conv1 = conv(inplanes, planes, stride=stride, kernel_size=3)
61 | self.bn1 = nn.BatchNorm1d(planes)
62 | self.relu = nn.ReLU(inplace=True)
63 |
64 | self.conv2 = conv(planes, planes, kernel_size=3)
65 | self.bn2 = nn.BatchNorm1d(planes)
66 |
67 | self.downsample = downsample
68 | self.stride = stride
69 |
70 | def forward(self, x):
71 | residual = x
72 |
73 | out = self.conv1(x)
74 | out = self.bn1(out)
75 | out = self.relu(out)
76 |
77 | out = self.conv2(out)
78 | out = self.bn2(out)
79 |
80 | if self.downsample is not None:
81 | residual = self.downsample(x)
82 |
83 | out += residual
84 | out = self.relu(out)
85 |
86 | return out
87 |
88 |
89 | class Bottleneck1d(nn.Module):
90 | """Bottleneck for ResNet52 ..."""
91 | expansion = 4
92 |
93 | def __init__(self, inplanes, planes, stride=1, downsample=None):
94 | super().__init__()
95 | kernel_size = 3
96 | self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
97 | self.bn1 = nn.BatchNorm1d(planes)
98 | self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=stride,
99 | padding=(kernel_size - 1) // 2, bias=False)
100 | self.bn2 = nn.BatchNorm1d(planes)
101 | self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=1, bias=False)
102 | self.bn3 = nn.BatchNorm1d(planes * 4)
103 | self.relu = nn.ReLU(inplace=True)
104 | self.downsample = downsample
105 | self.stride = stride
106 |
107 | def forward(self, x):
108 | residual = x
109 |
110 | out = self.conv1(x)
111 | out = self.bn1(out)
112 | out = self.relu(out)
113 |
114 | out = self.conv2(out)
115 | out = self.bn2(out)
116 | out = self.relu(out)
117 |
118 | out = self.conv3(out)
119 | out = self.bn3(out)
120 |
121 | if self.downsample is not None:
122 | residual = self.downsample(x)
123 |
124 | out += residual
125 | out = self.relu(out)
126 |
127 | return out
128 |
129 | class ResNet1d(nn.Module):
130 | '''1d adaptation of the torchvision resnet'''
131 |
132 | def __init__(self, block, layers, kernel_size=3, input_channels=12, inplanes=64,
133 | fix_feature_dim=False, kernel_size_stem=None, stride_stem=2, pooling_stem=True,
134 | stride=2):
135 | super(ResNet1d, self).__init__()
136 |
137 | self.inplanes = inplanes
138 | layers_tmp = []
139 | if kernel_size_stem is None:
140 | kernel_size_stem = kernel_size[0] if isinstance(kernel_size, list) else kernel_size
141 |
142 | # conv-bn-relu (basic feature extraction)
143 | layers_tmp.append(nn.Conv1d(input_channels, inplanes,
144 | kernel_size=kernel_size_stem,
145 | stride=stride_stem,
146 | padding=(kernel_size_stem - 1) // 2, bias=False))
147 | layers_tmp.append(nn.BatchNorm1d(inplanes))
148 | layers_tmp.append(nn.ReLU(inplace=True))
149 |
150 | if pooling_stem is True:
151 | layers_tmp.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1))
152 |
153 | for i, l in enumerate(layers):
154 | if i == 0:
155 | layers_tmp.append(self._make_block(block, inplanes, layers[0]))
156 | else:
157 | layers_tmp.append(
158 | self._make_block(block, inplanes if fix_feature_dim else (2 ** i) * inplanes, layers[i],
159 | stride=stride))
160 |
161 | self.feature_extractor = nn.Sequential(*layers_tmp)
162 |
163 | def _make_block(self, block, planes, blocks, stride=1, kernel_size=3):
164 | down_sample = None
165 |
166 | # downsample the residual whenever the stride or the channel count changes
167 | if stride != 1 or self.inplanes != planes * block.expansion:
168 | down_sample = nn.Sequential(
169 | nn.Conv1d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
170 | nn.BatchNorm1d(planes * block.expansion),
171 | )
172 |
173 | layers = []
174 | layers.append(block(self.inplanes, planes, stride, down_sample))
175 | self.inplanes = planes * block.expansion
176 |
177 | for i in range(1, blocks):
178 | layers.append(block(self.inplanes, planes))
179 |
180 | return nn.Sequential(*layers)
181 |
182 | def forward(self, x):
183 | return self.feature_extractor(x)
184 |
185 | def resnet1d14(inplanes, input_channels):
186 | return ResNet1d(BasicBlock1d, [2,2,2], inplanes=inplanes, input_channels=input_channels)
187 |
188 | def resnet1d18(**kwargs):
189 | return ResNet1d(BasicBlock1d, [2, 2, 2, 2], **kwargs)
190 |
191 | # MLP
192 | class MLP(nn.Module):
193 | def __init__(self, in_channels, hidden_channels, n_classes, bn=True):
194 | super(MLP, self).__init__()
195 | self.in_channels = in_channels
196 | self.n_classes = n_classes
197 | self.hidden_channels = hidden_channels
198 | self.fc1 = nn.Linear(self.in_channels, self.hidden_channels)
199 | self.fc2 = nn.Linear(self.hidden_channels, self.n_classes)
200 | self.ac = nn.ReLU()
201 | self.bn = nn.BatchNorm1d(hidden_channels) if bn else nn.Identity()  # honor the bn flag
202 |
203 | def forward(self, x):
204 | hidden = self.fc1(x)
205 | hidden = self.ac(hidden)
206 | hidden = self.bn(hidden)
207 | out = self.fc2(hidden)
208 |
209 | return out
210 |
211 | # Time-steps features -> aggregated features
212 | class Flatten(nn.Module):
213 | def __init__(self):
214 | super(Flatten, self).__init__()
215 |
216 | def forward(self, tensor):
217 | b = tensor.size(0)
218 | return tensor.reshape(b, -1)
219 |
220 | class AdaptiveConcatPool1d(nn.Module):
221 | "Layer that concats `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`."
222 |
223 | def __init__(self, sz=None):
224 | "Output will be 2*sz or 2 if sz is None"
225 | super().__init__()
226 | sz = sz or 1
227 | self.ap, self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz)
228 |
229 | def forward(self, x):
230 | """x is shaped of B, C, T"""
231 | return torch.cat([self.mp(x), self.ap(x), x[..., -1:]], 1)
232 |
233 | def bn_drop_lin(n_in, n_out, bn, p, actn):
234 | "`n_in`->bn->dropout->linear(`n_in`,`n_out`)->`actn`"
235 | layers = list()
236 |
237 | if bn:
238 | layers.append(nn.BatchNorm1d(n_in))
239 |
240 | if p > 0.:
241 | layers.append(nn.Dropout(p=p))
242 |
243 | layers.append(nn.Linear(n_in, n_out))
244 |
245 | if actn is not None:
246 | layers.append(actn)
247 |
248 | return layers
249 |
250 | def create_head1d(nf: int, nc: int, lin_ftrs=[512, ], dropout=0.5, bn: bool = True, act="relu"):
251 | lin_ftrs = [3 * nf] + lin_ftrs + [nc]
252 |
253 | activations = [nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)] * (len(lin_ftrs) - 2) + [None]
254 | layers = [AdaptiveConcatPool1d(), Flatten()]
255 |
256 | for ni, no, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], activations):
257 | layers += bn_drop_lin(ni, no, bn, dropout, actn)
258 |
259 | layers += [nn.Sigmoid()]
260 |
261 | return nn.Sequential(*layers)
262 |
--------------------------------------------------------------------------------
/data_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 | data_path = '...'
5 |
6 | time = np.load(data_path)
7 |
8 | freq = np.fft.fft(time)[:,:,:time.shape[-1] // 2]
9 |
10 | a = freq.real
11 | b = freq.imag
12 |
13 | magnitude = np.abs(freq)
14 |
15 | phase = np.zeros_like(a)  # quadrant-aware phase below, essentially np.angle(freq)
16 | phase[a > 0] = np.arctan(b / a)[a > 0]
17 | phase[a < 0] = (np.arctan(b / a) + np.sign(b) * np.pi)[a < 0]
18 | phase[a == 0] = (np.sign(b) * np.pi / 2)[a == 0]
19 |
20 | data = np.concatenate([time, magnitude, phase], -1)
21 |
22 | np.save(os.path.join(data_path, 'time_freq_data.npy'), data)  # np.save expects (file, array)
23 |
--------------------------------------------------------------------------------
/dataset/har_test_all.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_test_all.npy
--------------------------------------------------------------------------------
/dataset/har_test_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_test_label.npy
--------------------------------------------------------------------------------
/dataset/har_train_all.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_train_all.npy
--------------------------------------------------------------------------------
/dataset/har_train_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_train_label.npy
--------------------------------------------------------------------------------
/dataset/har_valid_all.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_valid_all.npy
--------------------------------------------------------------------------------
/dataset/har_valid_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BobZwr/Cross-Reconstruction-Transformer/9656810f4a1d1e91e42f3921282157e3f151754c/dataset/har_valid_label.npy
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import random
4 |
5 | from tqdm import tqdm, trange
6 | from sklearn.metrics import roc_auc_score
7 |
8 | from CRT import CRT, TFR_Encoder, Model
9 | from base_models import SSLDataSet, FTDataSet
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | from torch.utils.data import DataLoader
16 |
17 |
18 |
19 | def self_supervised_learning(model, X, n_epoch, lr, batch_size, device, min_ratio=0.3, max_ratio=0.8):
20 | optimizer = optim.Adam(model.parameters(), lr)
21 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20)
22 |
23 | model.to(device)
24 | model.train()
25 |
26 | dataset = SSLDataSet(X)
27 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
28 |
29 | losses = []
30 |
31 | pbar = trange(n_epoch)
32 | for epoch in pbar:
33 | for batch in dataloader:
34 | x = batch.to(device)
35 | loss = model(x, ssl=True, ratio=max(min_ratio, min(max_ratio, epoch / n_epoch)))  # mask ratio ramps from min_ratio to max_ratio
36 | optimizer.zero_grad()
37 | loss.backward()
38 | optimizer.step()
39 | losses.append(float(loss))
40 | scheduler.step(sum(losses) / len(losses))  # ReduceLROnPlateau expects the monitored metric, not the epoch index
41 | pbar.set_description(str(sum(losses) / len(losses)))
42 | torch.save(model.to('cpu'), 'Pretrained_Model.pkl')
43 |
44 | def finetuning(model, train_set, valid_set, n_epoch, lr, batch_size, device, multi_label=True):
45 | # multi_label: whether the classification task is a multi-label task.
46 | model.train()
47 | model.to(device)
48 |
49 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
50 |
51 | loss_func = nn.BCEWithLogitsLoss() if multi_label else nn.CrossEntropyLoss()
52 |
53 | for stage in range(2):
54 | # stage 0: finetune only the classifier; stage 1: finetune the whole model
55 | best_auc = 0
56 | step = 0
57 | if stage == 0:
58 | min_lr = 1e-6
59 | optimizer = optim.Adam(model.classifier.parameters(), lr=lr)
60 | else:
61 | min_lr = 1e-8
62 | optimizer = optim.Adam(model.parameters(), lr=lr/2)
63 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, mode='max', factor=0.8, min_lr=min_lr)
64 | pbar = trange(n_epoch)
65 | for _ in pbar:
66 | for batch_idx, batch in enumerate(train_loader):
67 | step += 1
68 | x, y = tuple(t.to(device) for t in batch)
69 | pred = model(x)
70 | loss = loss_func(pred, y)
71 | optimizer.zero_grad()
72 | loss.backward()
73 | optimizer.step()
74 | if step % 10 == 0:
75 | valid_auc = test(model, valid_set, batch_size, multi_label)
76 | pbar.set_description('Best Validation AUC: {:.4f} --------- AUC on this step: {:.4f}'.format(best_auc, valid_auc))
77 | if valid_auc > best_auc:
78 | best_auc = valid_auc
79 | torch.save(model, 'Finetuned_Model.pkl')
80 | scheduler.step(best_auc)
81 |
82 | def test(model, dataset, batch_size, multi_label):
83 | model.eval()
84 | testloader = DataLoader(dataset, batch_size=batch_size)
85 |
86 | pred_prob = []
87 | with torch.no_grad():
88 | for batch in testloader:
89 | x, y = tuple(t.to(device) for t in batch)
90 | pred = model(x)
91 | pred = torch.sigmoid(pred) if multi_label else F.softmax(pred, dim=1)
92 | pred_prob.extend([i.cpu().detach().numpy().tolist() for i in pred])
93 | auc = roc_auc_score(dataset.label, pred_prob, multi_class='ovr')
94 |
95 | print('AUC is {:.2f}'.format(auc * 100))
96 | print('More metrics can be added in the test function.')
97 | model.train()
98 | return auc
99 |
100 | def str2bool(v):
101 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
102 | return True
103 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
104 | return False
105 | else:
106 | raise argparse.ArgumentTypeError('Unsupported value encountered.')
107 |
108 | def set_seed(seed):
109 | torch.manual_seed(seed)
110 | torch.cuda.manual_seed(seed)
111 | torch.cuda.manual_seed_all(seed)
112 | np.random.seed(seed) # Numpy module.
113 | random.seed(seed) # Python random module.
114 | torch.backends.cudnn.benchmark = False
115 | torch.backends.cudnn.deterministic = True
116 |
117 | if __name__ == '__main__':
118 | parser = argparse.ArgumentParser()
119 | parser.add_argument("--ssl", type=str2bool, default=False)
120 | parser.add_argument("--sl", type=str2bool, default=True)
121 | parser.add_argument("--load", type=str2bool, default=True)
122 | parser.add_argument("--test", type=str2bool, default=True)
123 | # all default values of parameters are for PTB-XL
124 | parser.add_argument("--seq_len", type=int, default=10000)
125 | parser.add_argument("--patch_len", type=int, default=20)
126 | parser.add_argument("--dim", type=int, default=128)
127 | parser.add_argument("--in_dim", type=int, default=12)
128 | parser.add_argument("--n_classes", type=int, default=5)
129 | opt = parser.parse_args()
130 |
131 | set_seed(0)
132 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
133 |
134 | seq_len = opt.seq_len
135 | patch_len = opt.patch_len
136 | dim = opt.dim
137 | in_dim = opt.in_dim
138 | n_classes = opt.n_classes
139 |
140 | if opt.ssl:
141 | model = Model(seq_len, patch_len, dim, n_classes, in_dim).to(device)
142 | X = np.load('./dataset/har_train_all.npy')
143 | self_supervised_learning(model, X, 100, 1e-3, 128, device)
144 | if opt.load:
145 | model = torch.load('Pretrained_Model.pkl', map_location=device)
146 | else:
147 | model = Model(seq_len, patch_len, dim, n_classes, in_dim).to(device)
148 | if opt.sl:
149 | train_X, train_y = np.load('./dataset/har_train_all.npy'), np.load('./dataset/har_train_label.npy')
150 | valid_X, valid_y = np.load('./dataset/har_valid_all.npy'), np.load('./dataset/har_valid_label.npy')
151 | TrainSet = FTDataSet(train_X, train_y, multi_label=False) # multi_label = True when the dataset is PTBXL
152 | ValidSet = FTDataSet(valid_X, valid_y, multi_label=False)
153 | finetuning(model, TrainSet, ValidSet, 100, 1e-3, 128, device, multi_label=False)
154 | if opt.test:
155 | test_X, test_y = np.load('./dataset/har_test_all.npy'), np.load('./dataset/har_test_label.npy')
156 | TestSet = FTDataSet(test_X, test_y, multi_label=False)
157 | model = torch.load('Finetuned_Model.pkl', map_location=device)
158 | test(model, TestSet, 100, multi_label=False)
159 |
160 |
161 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.5.0
2 | numpy==1.23.4
3 | scikit-learn==1.1.2
4 | tqdm==4.64.0
5 |
--------------------------------------------------------------------------------