├── CameraReady_DiffShape_AAAI_2024.pdf ├── README.md ├── Supplementary Material_DiffShape_AAAI2024.pdf └── diffshape_ssc ├── __init__.py ├── diffusion ├── __init__.py ├── blocks.py ├── diff_utils.py └── diffusion_model.py ├── main_diffshape.py ├── parameter_shapelets.py ├── semi_backbone.py └── semi_utils.py /CameraReady_DiffShape_AAAI_2024.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qianlima-lab/DiffShape/8e9e83cbac654078302be9a0b506a8bcef85ad9b/CameraReady_DiffShape_AAAI_2024.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Language-Shapelets for Semi-supervised Time-Series Classification 2 | 3 | This is the training code for our paper "Diffusion Language-Shapelets for Semi-supervised Time-Series Classification" (AAAI-24). 4 | 5 | ## Abstract 6 | 7 | Semi-supervised time-series classification could effectively alleviate the issue of lacking labeled data. However, 8 | existing approaches usually ignore model interpretability, making it difficult for humans to understand the principles 9 | behind the predictions of a model. Shapelets are a set of discriminative subsequences that show high interpretability in 10 | time series classification tasks. Shapelet learning-based methods have demonstrated promising classification 11 | performance. Unfortunately, without enough labeled data, the shapelets learned by existing methods are often poorly 12 | discriminative, and even dissimilar to any subsequence of the original time series. To address this issue, we propose 13 | the Diffusion Language-Shapelets model (DiffShape) for semi-supervised time series classification. In DiffShape, a 14 | self-supervised diffusion learning mechanism is designed, which uses real subsequences as a condition. This helps to 15 | increase the similarity between the learned shapelets and real subsequences by using a large amount of unlabeled data. 16 | Furthermore, we introduce a contrastive language-shapelets learning strategy that improves the discriminability of the 17 | learned shapelets by incorporating the natural language descriptions of the time series. 18 | Experiments have been conducted on the UCR time series archive, and the results reveal that the proposed DiffShape 19 | method achieves state-of-the-art performance and exhibits superior interpretability over baselines. 20 | 21 | ## Datasets 22 | 23 | ### UCR archive time series datasets 24 | 25 | * [UCR time series archive](https://www.cs.ucr.edu/~eamonn/time_series_data_2018/UCRArchive_2018.zip) 26 | 27 | Following the recommendation of the creators of the UCR archive and of TS-TFC, we keep only datasets with an 28 | average of at least 30 samples per class, 29 | which makes the classification test results more stable. 30 | As a result, 31 | we use 106 of the original 128 UCR datasets in our experiments. 32 | 33 | Please refer to **page 13** of the [PDF](https://www.cs.ucr.edu/~eamonn/time_series_data_2018/BriefingDocument2018.pdf) document for the password to access the zipped file of the UCR archive.
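After unzipping the archive, each dataset is a folder of tab-separated `_TRAIN.tsv`/`_TEST.tsv` files with the class label in the first column. As a quick sanity check of the data (separate from the repository's own loaders such as `build_dataset`/`get_all_datasets` in `diffshape_ssc/semi_utils.py`), a split can be read roughly as sketched below; the helper name `load_ucr_split` and the `archive_root` path are illustrative placeholders, not part of this codebase.

```python
import numpy as np

def load_ucr_split(archive_root, dataset, split="TRAIN"):
    # Standard UCR 2018 layout (assumed): <root>/<Dataset>/<Dataset>_TRAIN.tsv,
    # one series per row, class label in the first column, values tab-separated.
    path = f"{archive_root}/{dataset}/{dataset}_{split}.tsv"
    data = np.loadtxt(path, delimiter="\t")
    labels, series = data[:, 0].astype(int), data[:, 1:]
    return series, labels

# Example with a placeholder path:
# x_train, y_train = load_ucr_split("UCRArchive_2018", "GunPoint")
# print(x_train.shape, np.unique(y_train))
```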
34 | 35 | ## Usage (Our Model) 36 | 37 | To train a DiffShape model on a dataset for semi-supervised time series classification, run 38 | 39 | ```bash 40 | python diffshape_ssc/main_diffshape.py --dataset [name of the dataset you want to train] ... 41 | ``` 42 | 43 | For detailed options and examples, please refer to the parser setup in ```diffshape_ssc/main_diffshape.py```. 44 | 45 | ## Citation 46 | If you use this code for your research, please cite our paper: 47 | ``` 48 | @inproceedings{liu2024diffusion, 49 | title={Diffusion language-shapelets for semi-supervised time-series classification}, 50 | author={Liu, Zhen and Pei, Wenbin and Lan, Disen and Ma, Qianli}, 51 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 52 | volume={38}, 53 | number={13}, 54 | pages={14079--14087}, 55 | year={2024} 56 | } 57 | ``` 58 | 59 | 60 | -------------------------------------------------------------------------------- /Supplementary Material_DiffShape_AAAI2024.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qianlima-lab/DiffShape/8e9e83cbac654078302be9a0b506a8bcef85ad9b/Supplementary Material_DiffShape_AAAI2024.pdf -------------------------------------------------------------------------------- /diffshape_ssc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qianlima-lab/DiffShape/8e9e83cbac654078302be9a0b506a8bcef85ad9b/diffshape_ssc/__init__.py -------------------------------------------------------------------------------- /diffshape_ssc/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qianlima-lab/DiffShape/8e9e83cbac654078302be9a0b506a8bcef85ad9b/diffshape_ssc/diffusion/__init__.py -------------------------------------------------------------------------------- /diffshape_ssc/diffusion/blocks.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | from typing import Any, Callable, Optional, Sequence, Type, TypeVar, Union 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import pack, rearrange, reduce, repeat, unpack 7 | from torch import Tensor, einsum, nn 8 | from typing_extensions import TypeGuard 9 | 10 | V = TypeVar("V") 11 | 12 | """ 13 | Helper functions 14 | """ 15 | 16 | 17 | class T: 18 | """Where the magic happens, builds a type template for a given type""" 19 | 20 | def __init__(self, t: Callable, override: bool = True): 21 | self.t = t 22 | self.override = override 23 | 24 | def __call__(self, *a, **ka): 25 | t, override = self.t, self.override 26 | 27 | class Inner: 28 | def __init__(self): 29 | self.args = a 30 | self.__dict__.update(**ka) 31 | 32 | def __call__(self, *b, **kb): 33 | if override: 34 | return t(*(*a, *b), **{**ka, **kb}) 35 | else: 36 | return t(*(*b, *a), **{**kb, **ka}) 37 | 38 | return Inner() 39 | 40 | 41 | def Ts(t: Callable[..., V]) -> Callable[..., Callable[..., V]]: 42 | """Builds a type template for a given type that accepts a list of instances""" 43 | return lambda *types: lambda: t(*[tp() for tp in types]) 44 | 45 | 46 | def exists(val: Optional[V]) -> TypeGuard[V]: 47 | return val is not None 48 | 49 | 50 | def default(val: Optional[V], d: V) -> V: 51 | return val if exists(val) else d 52 | 53 | 54 | def Module(modules: Sequence[nn.Module], forward_fn: Callable): 55 | """Functional module helper""" 56 | 57 | class 
Module(nn.Module): 58 | def __init__(self): 59 | super().__init__() 60 | self.blocks = nn.ModuleList(modules) 61 | 62 | def forward(self, *args, **kwargs): 63 | return forward_fn(*args, **kwargs) 64 | 65 | return Module() 66 | 67 | 68 | class Sequential(nn.Module): 69 | """Custom Sequential that includes all args""" 70 | 71 | def __init__(self, *blocks): 72 | super().__init__() 73 | self.blocks = nn.ModuleList(blocks) 74 | 75 | def forward(self, x: Tensor, *args) -> Tensor: 76 | for block in self.blocks: 77 | x = block(x, *args) 78 | return x 79 | 80 | 81 | def Select(args_fn: Callable) -> Callable[..., Type[nn.Module]]: 82 | """Selects (swap, remove, repeat) forward arguments given a (lambda) function""" 83 | 84 | def fn(block_t: Type[nn.Module]) -> Type[nn.Module]: 85 | class Select(nn.Module): 86 | def __init__(self, *args, **kwargs): 87 | super().__init__() 88 | self.block = block_t(*args, **kwargs) 89 | self.args_fn = args_fn 90 | 91 | def forward(self, *args, **kwargs): 92 | return self.block(*args_fn(*args), **kwargs) 93 | 94 | return Select 95 | 96 | return fn 97 | 98 | 99 | class Packed(Sequential): 100 | """Packs, and transposes non-channel dims, useful for attention-like view""" 101 | 102 | def forward(self, x: Tensor, *args) -> Tensor: 103 | x, ps = pack([x], "b d *") 104 | x = rearrange(x, "b d n -> b n d") 105 | x = super().forward(x, *args) 106 | x = rearrange(x, "b n d -> b d n") 107 | x = unpack(x, ps, "b d *")[0] 108 | return x 109 | 110 | 111 | def Repeat(m: Union[nn.Module, Type[nn.Module]], times: int) -> Any: 112 | ms = (m,) * times 113 | return Sequential(*ms) if isinstance(m, nn.Module) else Ts(Sequential)(*ms) 114 | 115 | 116 | def Skip(merge_fn: Callable[[Tensor, Tensor], Tensor] = torch.add) -> Type[Sequential]: 117 | class Skip(Sequential): 118 | """Adds skip connection around modules""" 119 | 120 | def forward(self, x: Tensor, *args) -> Tensor: 121 | return merge_fn(x, super().forward(x, *args)) 122 | 123 | return Skip 124 | 125 | 126 | """ 127 | Modules 128 | """ 129 | 130 | 131 | def Conv(dim: int, *args, **kwargs) -> nn.Module: 132 | return [nn.Conv1d, nn.Conv2d, nn.Conv3d][dim - 1](*args, **kwargs) 133 | 134 | 135 | def ConvTranspose(dim: int, *args, **kwargs) -> nn.Module: 136 | return [nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d][dim - 1]( 137 | *args, **kwargs 138 | ) 139 | 140 | 141 | def Downsample( 142 | dim: int, factor: int = 2, width: int = 1, conv_t=Conv, **kwargs 143 | ) -> nn.Module: 144 | width = width if factor > 1 else 1 145 | return conv_t( 146 | dim=dim, 147 | kernel_size=factor * width, 148 | stride=factor, 149 | padding=(factor * width - factor) // 2, 150 | **kwargs, 151 | ) 152 | 153 | 154 | def Upsample( 155 | dim: int, 156 | factor: int = 2, 157 | width: int = 1, 158 | conv_t=Conv, 159 | conv_tranpose_t=ConvTranspose, 160 | **kwargs, 161 | ) -> nn.Module: 162 | width = width if factor > 1 else 1 163 | return conv_tranpose_t( 164 | dim=dim, 165 | kernel_size=factor * width, 166 | stride=factor, 167 | padding=(factor * width - factor) // 2, 168 | **kwargs, 169 | ) 170 | 171 | 172 | def UpsampleInterpolate( 173 | dim: int, 174 | factor: int = 2, 175 | kernel_size: int = 3, 176 | mode: str = "nearest", 177 | conv_t=Conv, 178 | **kwargs, 179 | ) -> nn.Module: 180 | assert kernel_size % 2 == 1, "upsample kernel size must be odd" 181 | return nn.Sequential( 182 | nn.Upsample(scale_factor=factor, mode=mode), 183 | conv_t( 184 | dim=dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, **kwargs 185 | ), 186 | ) 187 | 188 
| 189 | def ConvBlock( 190 | dim: int, 191 | in_channels: int, 192 | activation_t=nn.SiLU, 193 | norm_t=T(nn.GroupNorm)(num_groups=1), 194 | conv_t=Conv, 195 | **kwargs, 196 | ) -> nn.Module: 197 | return nn.Sequential( 198 | norm_t(num_channels=in_channels), 199 | activation_t(), 200 | conv_t(dim=dim, in_channels=in_channels, **kwargs), 201 | ) 202 | 203 | 204 | def ResnetBlock( 205 | dim: int, 206 | in_channels: int, 207 | out_channels: int, 208 | kernel_size: int = 3, 209 | conv_block_t=ConvBlock, 210 | conv_t=Conv, 211 | **kwargs, 212 | ) -> nn.Module: 213 | ConvBlock = T(conv_block_t)( 214 | dim=dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, **kwargs 215 | ) 216 | Conv = T(conv_t)(dim=dim, kernel_size=1) 217 | 218 | conv_block = Sequential( 219 | ConvBlock(in_channels=in_channels, out_channels=out_channels), 220 | ConvBlock(in_channels=out_channels, out_channels=out_channels), 221 | ) 222 | conv = nn.Identity() 223 | if in_channels != out_channels: 224 | conv = Conv(in_channels=in_channels, out_channels=out_channels) 225 | 226 | return Module([conv_block, conv], lambda x: conv_block(x) + conv(x)) 227 | 228 | 229 | class GRN(nn.Module): 230 | """GRN (Global Response Normalization) layer from ConvNextV2 generic to any dim""" 231 | 232 | def __init__(self, dim: int, channels: int): 233 | super().__init__() 234 | ones = (1,) * dim 235 | self.gamma = nn.Parameter(torch.zeros(1, channels, *ones)) 236 | self.beta = nn.Parameter(torch.zeros(1, channels, *ones)) 237 | self.norm_dims = [d + 2 for d in range(dim)] 238 | 239 | def forward(self, x: Tensor) -> Tensor: 240 | Gx = torch.norm(x, p=2, dim=self.norm_dims, keepdim=True) 241 | Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6) 242 | return self.gamma * (x * Nx) + self.beta + x 243 | 244 | 245 | def ConvNextV2Block(dim: int, channels: int) -> nn.Module: 246 | block = nn.Sequential( 247 | # Depthwise and LayerNorm 248 | Conv( 249 | dim=dim, 250 | in_channels=channels, 251 | out_channels=channels, 252 | kernel_size=7, 253 | padding=3, 254 | groups=channels, 255 | ), 256 | nn.GroupNorm(num_groups=1, num_channels=channels), 257 | # Pointwise expand 258 | Conv(dim=dim, in_channels=channels, out_channels=channels * 4, kernel_size=1), 259 | # Activation and GRN 260 | nn.GELU(), 261 | GRN(dim=dim, channels=channels * 4), 262 | # Pointwise contract 263 | Conv( 264 | dim=dim, 265 | in_channels=channels * 4, 266 | out_channels=channels, 267 | kernel_size=1, 268 | ), 269 | ) 270 | 271 | return Module([block], lambda x: x + block(x)) 272 | 273 | 274 | def AttentionBase(features: int, head_features: int, num_heads: int) -> nn.Module: 275 | scale = head_features ** -0.5 276 | mid_features = head_features * num_heads 277 | to_out = nn.Linear(in_features=mid_features, out_features=features, bias=False) 278 | 279 | def forward( 280 | q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None 281 | ) -> Tensor: 282 | h = num_heads 283 | # Split heads 284 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) 285 | # Compute similarity matrix and add eventual mask 286 | sim = einsum("... n d, ... m d -> ... n m", q, k) * scale 287 | # Get attention matrix with softmax 288 | attn = sim.softmax(dim=-1) 289 | # Compute values 290 | out = einsum("... n m, ... m d -> ... 
n d", attn, v) 291 | out = rearrange(out, "b h n d -> b n (h d)") 292 | return to_out(out) 293 | 294 | return Module([to_out], forward) 295 | 296 | 297 | def LinearAttentionBase(features: int, head_features: int, num_heads: int) -> nn.Module: 298 | scale = head_features ** -0.5 299 | mid_features = head_features * num_heads 300 | to_out = nn.Linear(in_features=mid_features, out_features=features, bias=False) 301 | 302 | def forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor: 303 | h = num_heads 304 | # Split heads 305 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) 306 | # Softmax rows and cols 307 | q = q.softmax(dim=-1) * scale 308 | k = k.softmax(dim=-2) 309 | # Attend on channel dim 310 | attn = einsum("... n d, ... n c -> ... d c", k, v) 311 | out = einsum("... n d, ... d c -> ... n c", q, attn) 312 | out = rearrange(out, "b h n d -> b n (h d)") 313 | return to_out(out) 314 | 315 | return Module([to_out], forward) 316 | 317 | 318 | def FixedEmbedding(max_length: int, features: int): 319 | embedding = nn.Embedding(max_length, features) 320 | 321 | def forward(x: Tensor) -> Tensor: 322 | batch_size, length, device = *x.shape[0:2], x.device 323 | assert_message = "Input sequence length must be <= max_length" 324 | assert length <= max_length, assert_message 325 | position = torch.arange(length, device=device) 326 | fixed_embedding = embedding(position) 327 | fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) 328 | return fixed_embedding 329 | 330 | return Module([embedding], forward) 331 | 332 | 333 | class Attention(nn.Module): 334 | def __init__( 335 | self, 336 | features: int, 337 | *, 338 | head_features: int, 339 | num_heads: int, 340 | context_features: Optional[int] = None, 341 | max_length: Optional[int] = None, 342 | attention_base_t=AttentionBase, 343 | positional_embedding_t=None, 344 | ): 345 | super().__init__() 346 | self.context_features = context_features 347 | self.use_positional_embedding = exists(positional_embedding_t) 348 | self.use_context = exists(context_features) 349 | mid_features = head_features * num_heads 350 | context_features = default(context_features, features) 351 | 352 | self.max_length = max_length 353 | if self.use_positional_embedding: 354 | assert exists(max_length) 355 | self.positional_embedding = positional_embedding_t( 356 | max_length=max_length, features=features 357 | ) 358 | 359 | self.norm = nn.LayerNorm(features) 360 | self.norm_context = nn.LayerNorm(context_features) 361 | self.to_q = nn.Linear( 362 | in_features=features, out_features=mid_features, bias=False 363 | ) 364 | self.to_kv = nn.Linear( 365 | in_features=context_features, out_features=mid_features * 2, bias=False 366 | ) 367 | self.attention = attention_base_t( 368 | features, num_heads=num_heads, head_features=head_features 369 | ) 370 | 371 | def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor: 372 | assert_message = "You must provide a context when using context_features" 373 | assert not self.context_features or exists(context), assert_message 374 | skip = x 375 | if self.use_positional_embedding: 376 | x = x + self.positional_embedding(x) 377 | # Use context if provided 378 | context = context if exists(context) and self.use_context else x 379 | # Normalize then compute q from input and k,v from context 380 | x, context = self.norm(x), self.norm_context(context) 381 | q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) 382 | # Compute and return attention 383 | return skip 
+ self.attention(q, k, v) 384 | 385 | 386 | def CrossAttention(context_features: int, **kwargs): 387 | return Attention(context_features=context_features, **kwargs) 388 | 389 | 390 | def FeedForward(features: int, multiplier: int) -> nn.Module: 391 | mid_features = features * multiplier 392 | return Skip(torch.add)( 393 | nn.Linear(in_features=features, out_features=mid_features), 394 | nn.GELU(), 395 | nn.Linear(in_features=mid_features, out_features=features), 396 | ) 397 | 398 | 399 | def Modulation(in_features: int, num_features: int) -> nn.Module: 400 | to_scale_shift = nn.Sequential( 401 | nn.SiLU(), 402 | nn.Linear(in_features=num_features, out_features=in_features * 2, bias=True), 403 | ) 404 | norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6) 405 | 406 | def forward(x: Tensor, features: Tensor) -> Tensor: 407 | scale_shift = to_scale_shift(features) 408 | scale, shift = rearrange(scale_shift, "b d -> b 1 d").chunk(2, dim=-1) 409 | return norm(x) * (1 + scale) + shift 410 | 411 | return Module([to_scale_shift, norm], forward) 412 | 413 | 414 | def MergeAdd(): 415 | return Module([], lambda x, y, *_: x + y) 416 | 417 | 418 | def MergeCat(dim: int, channels: int, scale: float = 2 ** -0.5) -> nn.Module: 419 | conv = Conv(dim=dim, in_channels=channels * 2, out_channels=channels, kernel_size=1) 420 | return Module([conv], lambda x, y, *_: conv(torch.cat([x * scale, y], dim=1))) 421 | 422 | 423 | def MergeModulate(dim: int, channels: int, modulation_features: int): 424 | to_scale = nn.Sequential( 425 | nn.SiLU(), 426 | nn.Linear(in_features=modulation_features, out_features=channels, bias=True), 427 | ) 428 | 429 | def forward(x: Tensor, y: Tensor, features: Tensor, *args) -> Tensor: 430 | scale = rearrange(to_scale(features), f'b c -> b c {"1 " * dim}') 431 | return x + scale * y 432 | 433 | return Module([to_scale], forward) 434 | 435 | 436 | """ 437 | Embedders 438 | """ 439 | 440 | 441 | class NumberEmbedder(nn.Module): 442 | def __init__(self, features: int, dim: int = 256): 443 | super().__init__() 444 | assert dim % 2 == 0, f"dim must be divisible by 2, found {dim}" 445 | self.features = features 446 | self.weights = nn.Parameter(torch.randn(dim // 2)) 447 | self.to_out = nn.Linear(in_features=dim + 1, out_features=features) 448 | 449 | def to_embedding(self, x: Tensor) -> Tensor: 450 | x = rearrange(x, "b -> b 1") 451 | freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi 452 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) 453 | fouriered = torch.cat((x, fouriered), dim=-1) 454 | return self.to_out(fouriered) 455 | 456 | def forward(self, x: Union[Sequence[float], Tensor]) -> Tensor: 457 | if not torch.is_tensor(x): 458 | x = torch.tensor(x, device=self.weights.device) 459 | assert isinstance(x, Tensor) 460 | shape = x.shape 461 | x = rearrange(x, "... 
-> (...)") 462 | return self.to_embedding(x).view(*shape, self.features) # type: ignore 463 | 464 | 465 | class T5Embedder(nn.Module): 466 | def __init__(self, model: str = "t5-base", max_length: int = 64): 467 | super().__init__() 468 | from transformers import AutoTokenizer, T5EncoderModel 469 | 470 | self.tokenizer = AutoTokenizer.from_pretrained(model) 471 | self.transformer = T5EncoderModel.from_pretrained(model) 472 | self.max_length = max_length 473 | self.embedding_features = self.transformer.config.d_model 474 | 475 | @torch.no_grad() 476 | def forward(self, texts: Sequence[str]) -> Tensor: 477 | encoded = self.tokenizer( 478 | texts, 479 | truncation=True, 480 | max_length=self.max_length, 481 | padding="max_length", 482 | return_tensors="pt", 483 | ) 484 | 485 | device = next(self.transformer.parameters()).device 486 | input_ids = encoded["input_ids"].to(device) 487 | attention_mask = encoded["attention_mask"].to(device) 488 | 489 | self.transformer.eval() 490 | 491 | embedding = self.transformer( 492 | input_ids=input_ids, attention_mask=attention_mask 493 | )["last_hidden_state"] 494 | 495 | return embedding 496 | 497 | 498 | """ 499 | Plugins 500 | """ 501 | 502 | 503 | def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: 504 | if proba == 1: 505 | return torch.ones(shape, device=device, dtype=torch.bool) 506 | elif proba == 0: 507 | return torch.zeros(shape, device=device, dtype=torch.bool) 508 | else: 509 | return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) 510 | 511 | 512 | def LanguageShapeletGuidancePlugin( 513 | net_t: Type[nn.Module], 514 | embedding_max_length: int, 515 | ) -> Callable[..., nn.Module]: 516 | """Classifier-Free Guidance -> CFG(UNet, embedding_max_length=512)(...)""" 517 | 518 | def Net(embedding_features: int, **kwargs) -> nn.Module: 519 | fixed_embedding = FixedEmbedding( 520 | max_length=embedding_max_length, 521 | features=embedding_features, 522 | ) 523 | net = net_t(embedding_features=embedding_features, **kwargs) # type: ignore 524 | 525 | def forward( 526 | x: Tensor, 527 | embedding: Optional[Tensor] = None, 528 | embedding_scale: float = 1.0, 529 | embedding_mask_proba: float = 0.0, 530 | **kwargs, 531 | ): 532 | msg = "ClassiferFreeGuidancePlugin requires embedding" 533 | assert exists(embedding), msg 534 | b, device = embedding.shape[0], embedding.device 535 | embedding_mask = fixed_embedding(embedding) 536 | 537 | if embedding_mask_proba > 0.0: 538 | # Randomly mask embedding 539 | batch_mask = rand_bool( 540 | shape=(b, 1, 1), proba=embedding_mask_proba, device=device 541 | ) 542 | embedding = torch.where(batch_mask, embedding_mask, embedding) 543 | 544 | if embedding_scale != 1.0: 545 | # Compute both normal and fixed embedding outputs 546 | out = net(x, embedding=embedding, **kwargs) 547 | out_masked = net(x, embedding=embedding_mask, **kwargs) 548 | # Scale conditional output using classifier-free guidance 549 | return out_masked + (out - out_masked) * embedding_scale 550 | else: 551 | return net(x, embedding=embedding, **kwargs) 552 | 553 | return Module([fixed_embedding, net], forward) 554 | 555 | return Net 556 | 557 | 558 | def TimeConditioningPlugin( 559 | net_t: Type[nn.Module], 560 | num_layers: int = 2, 561 | ) -> Callable[..., nn.Module]: 562 | """Adds time conditioning (e.g. 
for diffusion)""" 563 | 564 | def Net(modulation_features: Optional[int] = None, **kwargs) -> nn.Module: 565 | msg = "TimeConditioningPlugin requires modulation_features" 566 | assert exists(modulation_features), msg 567 | 568 | embedder = NumberEmbedder(features=modulation_features) 569 | mlp = Repeat( 570 | nn.Sequential( 571 | nn.Linear(modulation_features, modulation_features), nn.GELU() 572 | ), 573 | times=num_layers, 574 | ) 575 | net = net_t(modulation_features=modulation_features, **kwargs) # type: ignore 576 | 577 | def forward( 578 | x: Tensor, 579 | time: Optional[Tensor] = None, 580 | features: Optional[Tensor] = None, 581 | **kwargs, 582 | ): 583 | msg = "TimeConditioningPlugin requires time in forward" 584 | assert exists(time), msg 585 | # Process time to time_features 586 | time_features = F.gelu(embedder(time)) 587 | time_features = mlp(time_features) 588 | # Overlap features if more than one per batch 589 | if time_features.ndim == 3: 590 | time_features = reduce(time_features, "b n d -> b d", "sum") 591 | # Merge time features with features if provided 592 | features = features + time_features if exists(features) else time_features 593 | return net(x, features=features, **kwargs) 594 | 595 | return Module([embedder, mlp, net], forward) 596 | 597 | return Net 598 | 599 | 600 | def TextConditioningPlugin( 601 | net_t: Type[nn.Module], embedder: Optional[nn.Module] = None 602 | ) -> Callable[..., nn.Module]: 603 | """Adds text conditioning""" 604 | embedder = embedder if exists(embedder) else T5Embedder() 605 | msg = "TextConditioningPlugin embedder requires embedding_features attribute" 606 | assert hasattr(embedder, "embedding_features"), msg 607 | features: int = embedder.embedding_features # type: ignore 608 | 609 | def Net(embedding_features: int = features, **kwargs) -> nn.Module: 610 | msg = f"TextConditioningPlugin requires embedding_features={features}" 611 | assert embedding_features == features, msg 612 | net = net_t(embedding_features=embedding_features, **kwargs) # type: ignore 613 | 614 | def forward( 615 | x: Tensor, text: Sequence[str], embedding: Optional[Tensor] = None, **kwargs 616 | ): 617 | text_embedding = embedder(text) # type: ignore 618 | if exists(embedding): 619 | text_embedding = torch.cat([text_embedding, embedding], dim=1) 620 | return net(x, embedding=text_embedding, **kwargs) 621 | 622 | return Module([embedder, net], forward) # type: ignore 623 | 624 | return Net 625 | -------------------------------------------------------------------------------- /diffshape_ssc/diffusion/diff_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | curPath = os.path.abspath(os.path.dirname(__file__)) 5 | rootPath = os.path.split(curPath)[0] 6 | sys.path.append(rootPath) 7 | 8 | from inspect import isfunction 9 | from typing import Dict, Tuple, TypeVar, Union 10 | 11 | from typing_extensions import TypeGuard 12 | 13 | from typing import Callable, List, Optional, Sequence, no_type_check 14 | 15 | import torch 16 | from torch import Tensor, nn 17 | 18 | from diffusion.blocks import ( 19 | Attention, 20 | Conv, 21 | ConvBlock, 22 | ConvNextV2Block, 23 | CrossAttention, 24 | Downsample, 25 | FeedForward, 26 | LinearAttentionBase, 27 | MergeAdd, 28 | MergeCat, 29 | MergeModulate, 30 | Modulation, 31 | Module, 32 | Packed, 33 | ResnetBlock, 34 | Select, 35 | Sequential, 36 | T, 37 | Upsample, 38 | UpsampleInterpolate, 39 | default, 40 | exists, 41 | ) 42 | 43 | T = TypeVar("T") 44 | 45 | 46 | 
def exists(val: Optional[T]) -> TypeGuard[T]: 47 | return val is not None 48 | 49 | 50 | def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: 51 | return_dicts: Tuple[Dict, Dict] = ({}, {}) 52 | for key in d.keys(): 53 | no_prefix = int(not key.startswith(prefix)) 54 | return_dicts[no_prefix][key] = d[key] 55 | return return_dicts 56 | 57 | 58 | def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: 59 | kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) 60 | if keep_prefix: 61 | return kwargs_with_prefix, kwargs 62 | kwargs_no_prefix = {k[len(prefix):]: v for k, v in kwargs_with_prefix.items()} 63 | return kwargs_no_prefix, kwargs 64 | 65 | 66 | def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: 67 | if exists(val): 68 | return val 69 | return d() if isfunction(d) else d 70 | 71 | 72 | """ 73 | Items 74 | """ 75 | 76 | # Selections for item forward parameters 77 | SelectX = Select(lambda x, *_: (x,)) 78 | SelectXF = Select(lambda x, f, *_: (x, f)) 79 | SelectXE = Select(lambda x, f, e, *_: (x, e)) 80 | SelectXC = Select(lambda x, f, e, c, *_: (x, c)) 81 | 82 | """ Downsample / Upsample """ 83 | 84 | 85 | def DownsampleItem( 86 | dim: Optional[int] = None, 87 | factor: Optional[int] = None, 88 | in_channels: Optional[int] = None, 89 | channels: Optional[int] = None, 90 | downsample_width: int = 1, 91 | **kwargs, 92 | ) -> nn.Module: 93 | msg = "DownsampleItem requires dim, factor, in_channels, channels" 94 | assert ( 95 | exists(dim) and exists(factor) and exists(in_channels) and exists(channels) 96 | ), msg 97 | Item = SelectX(Downsample) 98 | return Item( # type: ignore 99 | dim=dim, 100 | factor=factor, 101 | width=downsample_width, 102 | in_channels=in_channels, 103 | out_channels=channels, 104 | ) 105 | 106 | 107 | def UpsampleItem( 108 | dim: Optional[int] = None, 109 | factor: Optional[int] = None, 110 | channels: Optional[int] = None, 111 | out_channels: Optional[int] = None, 112 | upsample_mode: str = "nearest", 113 | upsample_kernel_size: int = 3, # Used with upsample_mode != "transpose" 114 | upsample_width: int = 1, # Used with upsample_mode == "transpose" 115 | **kwargs, 116 | ) -> nn.Module: 117 | msg = "UpsampleItem requires dim, factor, channels, out_channels" 118 | assert ( 119 | exists(dim) and exists(factor) and exists(channels) and exists(out_channels) 120 | ), msg 121 | if upsample_mode == "transpose": 122 | Item = SelectX(Upsample) 123 | return Item( # type: ignore 124 | dim=dim, 125 | factor=factor, 126 | width=upsample_width, 127 | in_channels=channels, 128 | out_channels=out_channels, 129 | ) 130 | else: 131 | Item = SelectX(UpsampleInterpolate) 132 | return Item( # type: ignore 133 | dim=dim, 134 | factor=factor, 135 | mode=upsample_mode, 136 | kernel_size=upsample_kernel_size, 137 | in_channels=channels, 138 | out_channels=out_channels, 139 | ) 140 | 141 | 142 | """ Main """ 143 | 144 | 145 | def ResnetItem( 146 | dim: Optional[int] = None, 147 | channels: Optional[int] = None, 148 | resnet_groups: Optional[int] = None, 149 | resnet_kernel_size: int = 3, 150 | **kwargs, 151 | ) -> nn.Module: 152 | msg = "ResnetItem requires dim, channels, and resnet_groups" 153 | assert exists(dim) and exists(channels) and exists(resnet_groups), msg 154 | Item = SelectX(ResnetBlock) 155 | conv_block_t = T(ConvBlock)(norm_t=T(nn.GroupNorm)(num_groups=resnet_groups)) 156 | return Item( # type: ignore 157 | dim=dim, 158 | in_channels=channels, 159 | out_channels=channels, 160 | kernel_size=resnet_kernel_size, 161 
| conv_block_t=conv_block_t, 162 | ) 163 | 164 | 165 | def ConvNextV2Item( 166 | dim: Optional[int] = None, 167 | channels: Optional[int] = None, 168 | **kwargs, 169 | ) -> nn.Module: 170 | msg = "ResnetItem requires dim and channels" 171 | assert exists(dim) and exists(channels), msg 172 | Item = SelectX(ConvNextV2Block) 173 | return Item(dim=dim, channels=channels) # type: ignore 174 | 175 | 176 | def AttentionItem( 177 | channels: Optional[int] = None, 178 | attention_features: Optional[int] = None, 179 | attention_heads: Optional[int] = None, 180 | **kwargs, 181 | ) -> nn.Module: 182 | msg = "AttentionItem requires channels, attention_features, attention_heads" 183 | assert ( 184 | exists(channels) and exists(attention_features) and exists(attention_heads) 185 | ), msg 186 | Item = SelectX(Attention) 187 | return Packed( 188 | Item( # type: ignore 189 | features=channels, 190 | head_features=attention_features, 191 | num_heads=attention_heads, 192 | ) 193 | ) 194 | 195 | 196 | def CrossAttentionItem( 197 | channels: Optional[int] = None, 198 | attention_features: Optional[int] = None, 199 | attention_heads: Optional[int] = None, 200 | embedding_features: Optional[int] = None, 201 | **kwargs, 202 | ) -> nn.Module: 203 | msg = "CrossAttentionItem requires channels, embedding_features, attention_*" 204 | assert ( 205 | exists(channels) 206 | and exists(embedding_features) 207 | and exists(attention_features) 208 | and exists(attention_heads) 209 | ), msg 210 | Item = SelectXE(CrossAttention) 211 | return Packed( 212 | Item( # type: ignore 213 | features=channels, 214 | head_features=attention_features, 215 | num_heads=attention_heads, 216 | context_features=embedding_features, 217 | ) 218 | ) 219 | 220 | 221 | def ModulationItem( 222 | channels: Optional[int] = None, modulation_features: Optional[int] = None, **kwargs 223 | ) -> nn.Module: 224 | msg = "ModulationItem requires channels, modulation_features" 225 | assert exists(channels) and exists(modulation_features), msg 226 | Item = SelectXF(Modulation) 227 | return Packed( 228 | Item(in_features=channels, num_features=modulation_features) # type: ignore 229 | ) 230 | 231 | 232 | def LinearAttentionItem( 233 | channels: Optional[int] = None, 234 | attention_features: Optional[int] = None, 235 | attention_heads: Optional[int] = None, 236 | **kwargs, 237 | ) -> nn.Module: 238 | msg = "LinearAttentionItem requires attention_features and attention_heads" 239 | assert ( 240 | exists(channels) and exists(attention_features) and exists(attention_heads) 241 | ), msg 242 | Item = SelectX(T(Attention)(attention_base_t=LinearAttentionBase)) 243 | return Packed( 244 | Item( # type: ignore 245 | features=channels, 246 | head_features=attention_features, 247 | num_heads=attention_heads, 248 | ) 249 | ) 250 | 251 | 252 | def LinearCrossAttentionItem( 253 | channels: Optional[int] = None, 254 | attention_features: Optional[int] = None, 255 | attention_heads: Optional[int] = None, 256 | embedding_features: Optional[int] = None, 257 | **kwargs, 258 | ) -> nn.Module: 259 | msg = "LinearCrossAttentionItem requires channels, embedding_features, attention_*" 260 | assert ( 261 | exists(channels) 262 | and exists(embedding_features) 263 | and exists(attention_features) 264 | and exists(attention_heads) 265 | ), msg 266 | Item = SelectXE(T(CrossAttention)(attention_base_t=LinearAttentionBase)) 267 | return Packed( 268 | Item( # type: ignore 269 | features=channels, 270 | head_features=attention_features, 271 | num_heads=attention_heads, 272 | 
context_features=embedding_features, 273 | ) 274 | ) 275 | 276 | 277 | def FeedForwardItem( 278 | channels: Optional[int] = None, attention_multiplier: Optional[int] = None, **kwargs 279 | ) -> nn.Module: 280 | msg = "FeedForwardItem requires channels, attention_multiplier" 281 | assert exists(channels) and exists(attention_multiplier), msg 282 | Item = SelectX(FeedForward) 283 | return Packed( 284 | Item(features=channels, multiplier=attention_multiplier) # type: ignore 285 | ) 286 | 287 | 288 | def InjectChannelsItem( 289 | dim: Optional[int] = None, 290 | channels: Optional[int] = None, 291 | depth: Optional[int] = None, 292 | context_channels: Optional[int] = None, 293 | **kwargs, 294 | ) -> nn.Module: 295 | msg = "InjectChannelsItem requires dim, depth, channels, context_channels" 296 | assert ( 297 | exists(dim) and exists(depth) and exists(channels) and exists(context_channels) 298 | ), msg 299 | msg = "InjectChannelsItem requires context_channels > 0" 300 | assert context_channels > 0, msg 301 | 302 | conv = Conv( 303 | dim=dim, 304 | in_channels=channels + context_channels, 305 | out_channels=channels, 306 | kernel_size=1, 307 | ) 308 | 309 | @no_type_check 310 | def forward(x: Tensor, channels: Sequence[Optional[Tensor]]) -> Tensor: 311 | msg_ = f"context `channels` at depth {depth} in forward" 312 | assert depth < len(channels), f"Required {msg_}" 313 | context = channels[depth] 314 | shape = torch.Size([x.shape[0], context_channels, *x.shape[2:]]) 315 | msg = f"Required {msg_} to be tensor of shape {shape}, found {context.shape}" 316 | assert torch.is_tensor(context) and context.shape == shape, msg 317 | return conv(torch.cat([x, context], dim=1)) + x 318 | 319 | return SelectXC(Module)([conv], forward) # type: ignore 320 | 321 | 322 | """ Skip Adapters """ 323 | 324 | 325 | def SkipAdapter( 326 | dim: Optional[int] = None, 327 | in_channels: Optional[int] = None, 328 | out_channels: Optional[int] = None, 329 | **kwargs, 330 | ): 331 | msg = "SkipAdapter requires dim, in_channels, out_channels" 332 | assert exists(dim) and exists(in_channels) and exists(out_channels), msg 333 | Item = SelectX(Conv) 334 | return ( 335 | Item( # type: ignore 336 | dim=dim, 337 | in_channels=in_channels, 338 | out_channels=out_channels, 339 | kernel_size=1, 340 | ) 341 | if in_channels != out_channels 342 | else SelectX(nn.Identity)() 343 | ) 344 | 345 | 346 | """ Skip Connections """ 347 | 348 | 349 | def SkipAdd(**kwargs) -> nn.Module: 350 | return MergeAdd() 351 | 352 | 353 | def SkipCat( 354 | dim: Optional[int] = None, 355 | out_channels: Optional[int] = None, 356 | skip_scale: float = 2 ** -0.5, 357 | **kwargs, 358 | ) -> nn.Module: 359 | msg = "SkipCat requires dim, out_channels" 360 | assert exists(dim) and exists(out_channels), msg 361 | return MergeCat(dim=dim, channels=out_channels, scale=skip_scale) 362 | 363 | 364 | def SkipModulate( 365 | dim: Optional[int] = None, 366 | out_channels: Optional[int] = None, 367 | modulation_features: Optional[int] = None, 368 | **kwargs, 369 | ) -> nn.Module: 370 | msg = "SkipModulate requires dim, out_channels, modulation_features" 371 | assert exists(dim) and exists(out_channels) and exists(modulation_features), msg 372 | return MergeModulate( 373 | dim=dim, channels=out_channels, modulation_features=modulation_features 374 | ) 375 | 376 | 377 | """ Block """ 378 | 379 | 380 | class Block(nn.Module): 381 | def __init__( 382 | self, 383 | in_channels: int, 384 | downsample_t: Callable = DownsampleItem, 385 | upsample_t: Callable = UpsampleItem, 386 | 
skip_t: Callable = SkipAdd, 387 | skip_adapter_t: Callable = SkipAdapter, 388 | items: Sequence[Callable] = [], 389 | items_up: Optional[Sequence[Callable]] = None, 390 | out_channels: Optional[int] = None, 391 | inner_block: Optional[nn.Module] = None, 392 | **kwargs, 393 | ): 394 | super().__init__() 395 | out_channels = default(out_channels, in_channels) 396 | 397 | items_up = default(items_up, items) # type: ignore 398 | items_down = [downsample_t] + list(items) 399 | items_up = list(items_up) + [upsample_t] 400 | items_kwargs = dict( 401 | in_channels=in_channels, out_channels=out_channels, **kwargs 402 | ) 403 | 404 | # Build items stack: items down -> inner block -> items up 405 | items_all: List[nn.Module] = [] 406 | items_all += [item_t(**items_kwargs) for item_t in items_down] 407 | items_all += [inner_block] if exists(inner_block) else [] 408 | items_all += [item_t(**items_kwargs) for item_t in items_up] 409 | 410 | self.skip_adapter = skip_adapter_t(**items_kwargs) 411 | self.block = Sequential(*items_all) 412 | self.skip = skip_t(**items_kwargs) 413 | 414 | def forward( 415 | self, 416 | x: Tensor, 417 | features: Optional[Tensor] = None, 418 | embedding: Optional[Tensor] = None, 419 | channels: Optional[Sequence[Tensor]] = None, 420 | ) -> Tensor: 421 | skip = self.skip_adapter(x) 422 | x = self.block(x, features, embedding, channels) 423 | x = self.skip(skip, x, features) 424 | return x 425 | 426 | 427 | # Block type, to be provided in UNet 428 | # XBlock = T(Block, override=False) 429 | 430 | """ UNet """ 431 | 432 | 433 | class XUNet(nn.Module): 434 | def __init__( 435 | self, 436 | in_channels: int, 437 | blocks: Sequence, 438 | out_channels: Optional[int] = None, 439 | **kwargs, 440 | ): 441 | super().__init__() 442 | num_layers = len(blocks) 443 | out_channels = default(out_channels, in_channels) 444 | 445 | def Net(i: int) -> Optional[nn.Module]: 446 | if i == num_layers: 447 | return None # noqa 448 | block_t = blocks[i] 449 | in_ch = in_channels if i == 0 else blocks[i - 1].channels 450 | out_ch = out_channels if i == 0 else in_ch 451 | 452 | return block_t( 453 | in_channels=in_ch, 454 | out_channels=out_ch, 455 | depth=i, 456 | inner_block=Net(i + 1), 457 | **kwargs, 458 | ) 459 | 460 | self.net = Net(0) 461 | 462 | def forward( 463 | self, 464 | x: Tensor, 465 | *, 466 | features: Optional[Tensor] = None, 467 | embedding: Optional[Tensor] = None, 468 | channels: Optional[Sequence[Tensor]] = None, 469 | ) -> Tensor: 470 | return self.net(x, features, embedding, channels) # type: ignore 471 | -------------------------------------------------------------------------------- /diffshape_ssc/diffusion/diffusion_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | curPath = os.path.abspath(os.path.dirname(__file__)) 5 | rootPath = os.path.split(curPath)[0] 6 | sys.path.append(rootPath) 7 | 8 | import torch.nn.functional as F 9 | from typing import Callable, Sequence, Type 10 | from math import pi 11 | from typing import Any, Optional, Tuple 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from einops import rearrange, repeat 16 | from torch import Tensor 17 | from tqdm import tqdm 18 | from diffusion.diff_utils import exists, groupby, default 19 | from a_unet.apex import ( 20 | AttentionItem, 21 | CrossAttentionItem, 22 | InjectChannelsItem, 23 | ModulationItem, 24 | ResnetItem, 25 | SkipCat, 26 | SkipModulate, 27 | XBlock, 28 | XUNet, 29 | ) 30 | 31 | from 
a_unet import ( 32 | TimeConditioningPlugin, 33 | ClassifierFreeGuidancePlugin, 34 | Module, 35 | T5Embedder, 36 | ) 37 | 38 | """ Distributions """ 39 | 40 | 41 | class Distribution: 42 | """Interface used by different distributions""" 43 | 44 | def __call__(self, num_samples: int, device: torch.device): 45 | raise NotImplementedError() 46 | 47 | 48 | class UniformDistribution(Distribution): 49 | def __init__(self, vmin: float = 0.0, vmax: float = 1.0): 50 | super().__init__() 51 | self.vmin, self.vmax = vmin, vmax 52 | 53 | def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")): 54 | vmax, vmin = self.vmax, self.vmin 55 | return (vmax - vmin) * torch.rand(num_samples, device=device) + vmin 56 | 57 | 58 | """ Diffusion Methods """ 59 | 60 | 61 | def pad_dims(x: Tensor, ndim: int) -> Tensor: 62 | # Pads additional ndims to the right of the tensor 63 | return x.view(*x.shape, *((1,) * ndim)) 64 | 65 | 66 | def clip(x: Tensor, dynamic_threshold: float = 0.0): 67 | if dynamic_threshold == 0.0: 68 | return x.clamp(-1.0, 1.0) 69 | else: 70 | # Dynamic thresholding 71 | # Find dynamic threshold quantile for each batch 72 | x_flat = rearrange(x, "b ... -> b (...)") 73 | scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1) 74 | # Clamp to a min of 1.0 75 | scale.clamp_(min=1.0) 76 | # Clamp all values and scale 77 | scale = pad_dims(scale, ndim=x.ndim - scale.ndim) 78 | x = x.clamp(-scale, scale) / scale 79 | return x 80 | 81 | 82 | def extend_dim(x: Tensor, dim: int): 83 | # e.g. if dim = 4: shape [b] => [b, 1, 1, 1], 84 | return x.view(*x.shape + (1,) * (dim - x.ndim)) 85 | 86 | 87 | def Conv1d_with_init(in_channels, out_channels, kernel_size): 88 | layer = nn.Conv1d(in_channels, out_channels, kernel_size) 89 | nn.init.kaiming_normal_(layer.weight) 90 | return layer 91 | 92 | 93 | class Diffusion(nn.Module): 94 | """Interface used by different diffusion methods""" 95 | 96 | pass 97 | 98 | 99 | class VDiffusion(Diffusion): 100 | def __init__( 101 | self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution(), loss_fn: Any = F.mse_loss 102 | ): 103 | super().__init__() 104 | self.net = net 105 | self.sigma_distribution = sigma_distribution 106 | self.loss_fn = loss_fn 107 | self.input_projection = Conv1d_with_init(2, 1, 1) 108 | 109 | def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: 110 | angle = sigmas * pi / 2 111 | alpha, beta = torch.cos(angle), torch.sin(angle) 112 | return alpha, beta 113 | 114 | def forward(self, x: Tensor, **kwargs) -> Tensor: # type: ignore 115 | batch_size, device = x.shape[0], x.device 116 | 117 | sigmas = self.sigma_distribution(num_samples=batch_size, device=device) 118 | sigmas_batch = extend_dim(sigmas, dim=x.ndim) 119 | # Get noise 120 | noise = torch.randn_like(x) 121 | # Combine input and noise weighted by half-circle 122 | alphas, betas = self.get_alpha_beta(sigmas_batch) 123 | x_noisy = alphas * x + betas * noise 124 | 125 | v_target = alphas * noise - betas * x 126 | # Predict velocity and return loss 127 | v_pred = self.net(x_noisy, sigmas, **kwargs) 128 | return self.loss_fn(v_pred, v_target) 129 | 130 | 131 | """ Schedules """ 132 | 133 | 134 | class Schedule(nn.Module): 135 | """Interface used by different sampling schedules""" 136 | 137 | def forward(self, num_steps: int, device: torch.device) -> Tensor: 138 | raise NotImplementedError() 139 | 140 | 141 | class LinearSchedule(Schedule): 142 | def __init__(self, start: float = 1.0, end: float = 0.0): 143 | super().__init__() 144 | 
self.start, self.end = start, end 145 | 146 | def forward(self, num_steps: int, device: Any) -> Tensor: 147 | return torch.linspace(self.start, self.end, num_steps, device=device) 148 | 149 | 150 | """ Samplers """ 151 | 152 | 153 | class Sampler(nn.Module): 154 | pass 155 | 156 | 157 | class VSampler(Sampler): 158 | diffusion_types = [VDiffusion] 159 | 160 | def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()): 161 | super().__init__() 162 | self.net = net 163 | self.schedule = schedule 164 | self.input_projection = Conv1d_with_init(2, 1, 1) 165 | 166 | def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: 167 | angle = sigmas * pi / 2 168 | alpha, beta = torch.cos(angle), torch.sin(angle) 169 | return alpha, beta 170 | 171 | def forward( # type: ignore 172 | self, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs 173 | ) -> Tensor: 174 | b = x_noisy.shape[0] 175 | sigmas = self.schedule(num_steps + 1, device=x_noisy.device) 176 | sigmas = repeat(sigmas, "i -> i b", b=b) 177 | sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1) 178 | alphas, betas = self.get_alpha_beta(sigmas_batch) 179 | progress_bar = tqdm(range(num_steps), disable=not show_progress) 180 | 181 | for i in progress_bar: 182 | v_pred = self.net(x_noisy, sigmas[i], **kwargs) 183 | x_pred = alphas[i] * x_noisy - betas[i] * v_pred 184 | noise_pred = betas[i] * x_noisy + alphas[i] * v_pred 185 | x_noisy = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred 186 | progress_bar.set_description(f"Sampling (noise={sigmas[i + 1, 0]:.2f})") 187 | 188 | return x_noisy 189 | 190 | 191 | class DiffusionModel(nn.Module): 192 | def __init__( 193 | self, 194 | net_t: Callable, 195 | diffusion_t: Callable = VDiffusion, 196 | sampler_t: Callable = VSampler, 197 | loss_fn: Callable = torch.nn.functional.mse_loss, 198 | dim: int = 1, 199 | **kwargs, 200 | ): 201 | super().__init__() 202 | diffusion_kwargs, kwargs = groupby("diffusion_", kwargs) 203 | sampler_kwargs, kwargs = groupby("sampler_", kwargs) 204 | 205 | self.net = net_t(dim=dim, **kwargs) 206 | self.diffusion = diffusion_t(net=self.net, loss_fn=loss_fn, **diffusion_kwargs) 207 | self.sampler = sampler_t(net=self.net, **sampler_kwargs) 208 | 209 | def forward(self, *args, **kwargs) -> Tensor: 210 | return self.diffusion(*args, **kwargs) 211 | 212 | def sample(self, *args, **kwargs) -> Tensor: 213 | return self.sampler(*args, **kwargs) 214 | 215 | 216 | def TextConditioningPlugin11( 217 | net_t: Type[nn.Module], embedder: Optional[nn.Module] = None 218 | ) -> Callable[..., nn.Module]: 219 | """Adds text conditioning""" 220 | embedder = embedder if exists(embedder) else T5Embedder() 221 | # msg = "TextConditioningPlugin embedder requires embedding_features attribute" 222 | # assert hasattr(embedder, "embedding_features"), msg 223 | features: int = embedder.embedding_features # type: ignore 224 | 225 | def Net(embedding_features: int = features, **kwargs) -> nn.Module: 226 | # msg = f"TextConditioningPlugin requires embedding_features={features}" 227 | # assert embedding_features == features, msg 228 | net = net_t(embedding_features=embedding_features, **kwargs) # type: ignore 229 | 230 | def forward( 231 | x: Tensor, text: Sequence[str], embedding: Optional[Tensor] = None, **kwargs 232 | ): 233 | text_embedding = embedding 234 | # if exists(embedding): 235 | # text_embedding = torch.cat([text_embedding, embedding], dim=1) 236 | return net(x, embedding=text_embedding, **kwargs) 237 | 238 | return Module([embedder, net], forward) 
# type: ignore 239 | 240 | return Net 241 | 242 | 243 | def UNetV0( 244 | dim: int, 245 | in_channels: int, 246 | channels: Sequence[int], 247 | factors: Sequence[int], 248 | items: Sequence[int], 249 | attentions: Optional[Sequence[int]] = None, 250 | cross_attentions: Optional[Sequence[int]] = None, 251 | context_channels: Optional[Sequence[int]] = None, 252 | attention_features: Optional[int] = None, 253 | attention_heads: Optional[int] = None, 254 | embedding_features: Optional[int] = None, 255 | resnet_groups: int = 8, 256 | use_modulation: bool = True, 257 | modulation_features: int = 1024, 258 | embedding_max_length: Optional[int] = None, 259 | use_time_conditioning: bool = True, 260 | use_embedding_cfg: bool = False, 261 | use_text_conditioning: bool = False, 262 | out_channels: Optional[int] = None, 263 | ): 264 | # Set defaults and check lengths 265 | num_layers = len(channels) 266 | attentions = default(attentions, [0] * num_layers) 267 | cross_attentions = default(cross_attentions, [0] * num_layers) 268 | context_channels = default(context_channels, [0] * num_layers) 269 | xs = (channels, factors, items, attentions, cross_attentions, context_channels) 270 | assert all(len(x) == num_layers for x in xs) # type: ignore 271 | 272 | # Define UNet type 273 | UNetV0 = XUNet 274 | 275 | if use_embedding_cfg: 276 | msg = "use_embedding_cfg requires embedding_max_length" 277 | # assert exists(embedding_max_length), msg 278 | # UNetV0 = LanguageShapeletGuidancePlugin(UNetV0, embedding_max_length) 279 | assert exists(embedding_max_length), msg 280 | UNetV0 = ClassifierFreeGuidancePlugin(UNetV0, embedding_max_length) 281 | 282 | if use_text_conditioning: 283 | UNetV0 = TextConditioningPlugin11(UNetV0) 284 | 285 | if use_time_conditioning: 286 | assert use_modulation, "use_time_conditioning requires use_modulation=True" 287 | UNetV0 = TimeConditioningPlugin(UNetV0) 288 | 289 | # Build 290 | return UNetV0( 291 | dim=dim, 292 | in_channels=in_channels, 293 | out_channels=out_channels, 294 | blocks=[ 295 | XBlock( 296 | channels=channels, 297 | factor=factor, 298 | context_channels=ctx_channels, 299 | items=( 300 | [ResnetItem] 301 | + [ModulationItem] * use_modulation 302 | + [InjectChannelsItem] * (ctx_channels > 0) 303 | + [AttentionItem] * att 304 | + [CrossAttentionItem] * cross 305 | ) 306 | * items, 307 | ) 308 | for channels, factor, items, att, cross, ctx_channels in zip(*xs) # type: ignore # noqa 309 | ], 310 | skip_t=SkipModulate if use_modulation else SkipCat, 311 | attention_features=attention_features, 312 | attention_heads=attention_heads, 313 | embedding_features=embedding_features, 314 | modulation_features=modulation_features, 315 | resnet_groups=resnet_groups, 316 | ) 317 | -------------------------------------------------------------------------------- /diffshape_ssc/main_diffshape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import time 4 | import torch.nn as nn 5 | import numpy as np 6 | from diffusion.diffusion_model import VDiffusion, VSampler, DiffusionModel, UNetV0 7 | from torch.utils.data import DataLoader 8 | from sklearn.model_selection import train_test_split 9 | 10 | from semi_utils import set_seed, build_loss, evaluate_model_acc, lan_shapelet_contrastive_loss, get_all_text_labels, \ 11 | get_each_sample_distance_shapelet, get_similarity_shapelet, build_model, build_dataset, \ 12 | get_all_datasets, fill_nan_value, normalize_per_series, shuffler, UCRDataset, \ 13 | 
get_pesudo_via_high_confidence_softlabels 14 | from semi_backbone import ProjectionHead 15 | from parameter_shapelets import * 16 | from torch.cuda.amp import GradScaler, autocast 17 | 18 | ucr_datasets_dict = { 19 | 'AllGestureWiimoteX': {'0': 'poteg - pick-up', '1': 'shake - shake', '2': 'desno - one move to the right', 20 | '3': 'levo - one move to the left', '4': 'gor - one move to up', 21 | '5': 'dol - one move to down', '6': 'kroglevo - one left circle', 22 | '7': 'krogdesno - one right circle', '8': 'suneknot - one move toward the screen', 23 | '9': 'sunekven - one move away from the screen'}, 24 | 'AllGestureWiimoteY': {'0': 'poteg - pick-up', '1': 'shake - shake', '2': 'desno - one move to the right', 25 | '3': 'levo - one move to the left', '4': 'gor - one move to up', 26 | '5': 'dol - one move to down', '6': 'kroglevo - one left circle', 27 | '7': 'krogdesno - one right circle', '8': 'suneknot - one move toward the screen', 28 | '9': 'sunekven - one move away from the screen'}, 29 | 'AllGestureWiimoteZ': {'0': 'poteg - pick-up', '1': 'shake - shake', '2': 'desno - one move to the right', 30 | '3': 'levo - one move to the left', '4': 'gor - one move to up', 31 | '5': 'dol - one move to down', '6': 'kroglevo - one left circle', 32 | '7': 'krogdesno - one right circle', '8': 'suneknot - one move toward the screen', 33 | '9': 'sunekven - one move away from the screen'}, 34 | 'ArrowHead': {'0': 'Avonlea', '1': 'Clovis', '2': 'Mix'}, 35 | 'BME': {'0': 'Begin', '1': 'Middle', '2': 'End'}, 36 | 'Car': {'0': 'Sedan', '1': 'Pickup', '2': 'Minivan', '3': 'SUV'}, 37 | 'CBF': {'0': 'Cylinder', '1': 'Bell', '2': 'Funnel'}, 38 | 'Chinatown': {'0': 'Weekend', '1': 'Weekday'}, 39 | 'CinCECGTorso': {'0': 'People 1', '1': 'People 2', '2': 'People 3', '3': 'People 4'}, 40 | 'ChlorineConcentration': {'0': 'low', '1': 'middle', '2': 'high'}, 41 | 'Computers': {'0': 'Laptop', '1': 'Desktop'}, 42 | 'CricketX': {'0': 'Cancel Call', '1': 'Dead Ball', '2': 'Four', '3': 'Last Hour', '4': 'Leg Bye', '5': 'No Ball', 43 | '6': 'One Short', '7': 'Out', '8': 'Penalty Runs', '9': 'Six', '10': 'TV Replay', '11': 'Wide'}, 44 | 'CricketY': {'0': 'Cancel Call', '1': 'Dead Ball', '2': 'Four', '3': 'Last Hour', '4': 'Leg Bye', '5': 'No Ball', 45 | '6': 'One Short', '7': 'Out', '8': 'Penalty Runs', '9': 'Six', '10': 'TV Replay', '11': 'Wide'}, 46 | 'CricketZ': {'0': 'Cancel Call', '1': 'Dead Ball', '2': 'Four', '3': 'Last Hour', '4': 'Leg Bye', '5': 'No Ball', 47 | '6': 'One Short', '7': 'Out', '8': 'Penalty Runs', '9': 'Six', '10': 'TV Replay', '11': 'Wide'}, 48 | 'Crop': {'0': 'corn', '1': 'wheat', '2': 'dense building', '3': 'built indu', '4': 'diffuse building', 49 | '5': 'temporary meadow', '6': 'hardwood', '7': 'wasteland', '8': 'jachere', '9': 'soy', '10': 'water', 50 | '11': 'pre', '12': 'softwood', '13': 'sunflower', '14': 'sorghum', '15': 'eucalyptus', '16': 'rapeseed', 51 | '17': 'but drilling', '18': 'barley', '19': 'peas', '20': 'poplars', '21': 'mineral surface', 52 | '22': 'gravel', '23': 'lake'}, 53 | 'DiatomSizeReduction': {'0': 'Gomphonema augur', '1': 'Fragilariforma bicapitata', '2': 'Stauroneis smithii', 54 | '3': 'Eunotia tenella'}, 55 | 'DistalPhalanxOutlineAgeGroup': {'0': '0-6 years old', '1': '7-12 years old', '2': '13-19 years old'}, 56 | 'DistalPhalanxOutlineCorrect': {'0': 'Correct', '1': 'Incorrect'}, 57 | 'DistalPhalanxTW': {'0': 'belong to 0-6 years old', '1': 'not belong to 0-6 years old', 58 | '2': 'belong to 7-12 years old', '3': 'not belong to 7-12 years old', 59 | '4': 'belong to 
13-19 years old', '5': 'not belong to 13-19 years old'}, 60 | 'DodgerLoopGame': {'0': 'Sunday', '1': 'Monday', '2': 'Tuesday', '3': 'Wednesday', '4': 'Thursday', '5': 'Friday', 61 | '6': 'Saturday'}, 62 | 'DodgerLoopWeekend': {'0': 'Sunday', '1': 'Monday', '2': 'Tuesday', '3': 'Wednesday', '4': 'Thursday', 63 | '5': 'Friday', '6': 'Saturday'}, 64 | 'Earthquakes': {'0': 'negative', '1': 'positive'}, 65 | 'ECG200': {'0': 'Ischemia', '1': 'Normal'}, 66 | 'ECG5000': {'0': 'Normal', '1': 'R-on-T premature ventricular contraction', 67 | '2': 'Supraventricular premature or ectopic beat', '3': 'Premature ventricular contraction', 68 | '4': 'Unclassifiable beat'}, 69 | 'ECGFiveDays': {'0': '12/11/1990', '1': '17/11/1990'}, 70 | 'ElectricDevices': {'0': 'screenGroup', '1': 'dishwasher', '2': 'coldGroup', '3': 'immersionHeater', '4': 'kettle', 71 | '5': 'ovenCooker', '6': 'washingMachine'}, 72 | 'EOGHorizontalSignal': {'0': 'Upper left', '1': 'Up', '2': 'Upper right', '3': 'Left', '4': 'Center', '5': 'Right', 73 | '6': 'Lower left', '7': 'Down', '8': 'Lower right', '9': 'Sil', 74 | '10': 'Left, up, right, down, left', '11': 'Blinking'}, 75 | 'EOGVerticalSignal': {'0': 'Upper left', '1': 'Up', '2': 'Upper right', '3': 'Left', '4': 'Center', '5': 'Right', 76 | '6': 'Lower left', '7': 'Down', '8': 'Lower right', '9': 'Sil', 77 | '10': 'Left, up, right, down, left', '11': 'Blinking'}, 78 | 'EthanolLevel': {'0': 'Thirty-five percent ethanol', '1': 'Thirty-eight percent ethanol', 79 | '2': 'Forty percent ethanol', '3': 'Forty-five percent ethanol'}, 80 | 'FaceAll': {'0': 'Student1', '1': 'Student2', '2': 'Student3', '3': 'Student4', '4': 'Student5', '5': 'Student6', 81 | '6': 'Student7', '7': 'Student8', '8': 'Student9', '9': 'Student10', '10': 'Student11', 82 | '11': 'Student12', '12': 'Student13', '13': 'Student14'}, 83 | 'FacesUCR': {'0': 'Student1', '1': 'Student2', '2': 'Student3', '3': 'Student4', '4': 'Student5', '5': 'Student6', 84 | '6': 'Student7', '7': 'Student8', '8': 'Student9', '9': 'Student10', '10': 'Student11', 85 | '11': 'Student12', '12': 'Student13', '13': 'Student14'}, 86 | 'Fish': {'0': 'Chinook salmon', '1': 'Winter coho', '2': 'brown trout', '3': 'Bonneville cutthroat', 87 | '4': 'Colorado River cutthroat trout', '5': 'Yellowstone cutthroat', '6': 'Mountain whitefish'}, 88 | 'FordA': {'0': 'not exists a certain symptom', '1': 'exists a certain symptom'}, 89 | 'FordB': {'0': 'not exists a certain symptom', '1': 'exists a certain symptom'}, 90 | 'FreezerRegularTrain': {'0': 'in the kitchen', '1': 'in the garage'}, 91 | 'FreezerSmallTrain': {'0': 'in the kitchen', '1': 'in the garage'}, 92 | 'GesturePebbleZ1': {'0': 'hh', '1': 'hu', '2': 'ud', '3': 'hud', '4': 'hh2', '5': 'hu2'}, 93 | 'GesturePebbleZ2': {'0': 'hh', '1': 'hu', '2': 'ud', '3': 'hud', '4': 'hh2', '5': 'hu2'}, 94 | 'GunPoint': {'0': 'Gun-Draw', '1': 'No gun pointing'}, 95 | 'GunPointAgeSpan': {'0': 'Gun (FG03, MG03, FG18, MG18)', '1': 'Point (FP03, MP03, FP18, MP18)'}, 96 | 'GunPointMaleVersusFemale': {'0': 'Female (FG03, FP03, FG18, FP18)', '1': 'Male (MG03, MP03, MG18, MP18)'}, 97 | 'GunPointOldVersusYoung': {'0': 'Young (FG03, MG03, FP03, MP03)', '1': 'Old (FG18, MG18, FP18, MP18)'}, 98 | 'Ham': {'0': 'Spanish', '1': 'French'}, 99 | 'HandOutlines': {'0': 'Male', '1': 'Female'}, 100 | 'Haptics': {'0': 'Person 1', '1': 'Person 2', '2': 'Person 3', '3': 'Person 4', '4': 'Person 5'}, 101 | 'Herring': {'0': 'North sea', '1': 'Thames'}, 102 | 'HouseTwenty': {'0': 'household aggregate usage of electricity', 103 | '1': 
'aggregate electricity load of Tumble Dryer and Washing Machine'}, 104 | 'InlineSkate': {'0': 'Individual1', '1': 'Individual2', '2': 'Individual3', '3': 'Individual4', '4': 'Individual5', 105 | '5': 'Individual6', '6': 'Individual7'}, 106 | 'InsectEPGRegularTrain': {'0': 'Phloem Salivation', '1': 'Phloem Ingestion', '2': 'Xylem Ingestion'}, 107 | 'InsectEPGSmallTrain': {'0': 'Stylet passage through plant cells1', '1': 'Contact with Phloem Tissue', 108 | '2': 'Contact with Phloem Tissue'}, 109 | 'InsectWingbeatSound': {'0': 'Male Ae. aegypti', '1': 'Female Ae. aegypti', '2': 'Male Cx. tarsalis', 110 | '3': 'Female Cx. tarsalis', '4': 'Male Cx. quinquefasciants', 111 | '5': 'Female Cx. quinquefasciants', '6': 'Male Cx. stigmatosoma', 112 | '7': 'Female Cx. stigmatosoma', '8': 'Musca domestica', '9': 'Drosophila simulans', 113 | '10': 'Other insects'}, 114 | 'ItalyPowerDemand': {'0': 'Days from Oct to March', '1': 'Days from April to September'}, 115 | 'LargeKitchenAppliances': {'0': 'Washing Machine', '1': 'Tumble Dryer', '2': 'Dishwasher'}, 116 | 'Lightning2': {'0': 'class 1', '1': 'class 2'}, 117 | 'Mallat': {'0': 'case 1', '1': 'case 2', '2': 'case 3', '3': 'case 4', '4': 'case 5', '5': 'case 6', '6': 'case 7', 118 | '7': 'Original'}, 119 | 'Meat': {'0': 'chicken', '1': 'pork', '2': 'turkey'}, 120 | 'MedicalImages': {'0': 'brain', '1': 'spine', '2': 'heart', '3': 'liver', '4': 'adiposity', '5': 'breast', 121 | '6': 'muscle', '7': 'bone', '8': 'lung', '9': 'other'}, 122 | 'MelbournePedestrian': {'0': 'Bourke Street Mall (North)', '1': 'Southern Cross Station', '2': 'New Quay', 123 | '3': 'Flinders St Station Underpass', '4': 'QV Market-Elizabeth (West)', 124 | '5': 'Convention/Exhibition Centre', '6': 'Chinatown-Swanston St (North)', 125 | '7': 'Webb Bridge', '8': 'Tin Alley-Swanston St (West)', '9': 'Southbank'}, 126 | 'MiddlePhalanxOutlineAgeGroup': {'0': '0-6 years old', '1': '7-12 years old', '2': '13-19 years old'}, 127 | 'MiddlePhalanxOutlineCorrect': {'0': 'Correct', '1': 'Incorrect'}, 128 | 'MiddlePhalanxTW': {'0': '0-6 years old is correct', '1': '0-6 years old is incorrect', 129 | '2': '7-12 years old is correct', '3': '7-12 years old is incorrect', 130 | '4': '13-19 years old is correct', '5': '13-19 years old is incorrect'}, 131 | 'MixedShapesRegularTrain': {'0': 'Arrowhead', '1': 'Butterfly', '2': 'Fish', '3': 'Seashell', '4': 'Shield'}, 132 | 'MixedShapesSmallTrain': {'0': 'Arrowhead', '1': 'Butterfly', '2': 'Fish', '3': 'Seashell', '4': 'Shield'}, 133 | 'MoteStrain': {'0': 'q8calibHumid', '1': 'q8calibHumTemp'}, 134 | 'NonInvasiveFetalECGThorax1': {'0': 'ECG signal 1', '1': 'ECG signal 2', '2': 'ECG signal 3', '3': 'ECG signal 4', 135 | '4': 'ECG signal 5', '5': 'ECG signal 6', '6': 'ECG signal 7', '7': 'ECG signal 8' 136 | , '8': 'ECG signal 9', '9': 'ECG signal 10', '10': 'ECG signal 11', '11': 'ECG signal 12', 137 | '12': 'ECG signal 13', '13': 'ECG signal 14', '14': 'ECG signal 15' 138 | , '15': 'ECG signal 16', '16': 'ECG signal 17', '17': 'ECG signal 18', '18': 'ECG signal 19', 139 | '19': 'ECG signal 20', '20': 'ECG signal 21', '21': 'ECG signal 22' 140 | , '22': 'ECG signal 23', '23': 'ECG signal 24', '24': 'ECG signal 25', '25': 'ECG signal 26', 141 | '26': 'ECG signal 27', '27': 'ECG signal 28', '28': 'ECG signal 29' 142 | , '29': 'ECG signal 30', '30': 'ECG signal 31', '31': 'ECG signal 32', '32': 'ECG signal 33', 143 | '33': 'ECG signal 34', '34': 'ECG signal 35', '35': 'ECG signal 36' 144 | , '36': 'ECG signal 37', '37': 'ECG signal 38', '38': 'ECG signal 39', 
'39': 'ECG signal 40', 145 | '40': 'ECG signal 41', '41': 'ECG signal 42'}, 146 | 'NonInvasiveFetalECGThorax2': {'0': 'ECG signal 1', '1': 'ECG signal 2', '2': 'ECG signal 3', '3': 'ECG signal 4', 147 | '4': 'ECG signal 5', '5': 'ECG signal 6', '6': 'ECG signal 7', '7': 'ECG signal 8' 148 | , '8': 'ECG signal 9', '9': 'ECG signal 10', '10': 'ECG signal 11', '11': 'ECG signal 12', 149 | '12': 'ECG signal 13', '13': 'ECG signal 14', '14': 'ECG signal 15' 150 | , '15': 'ECG signal 16', '16': 'ECG signal 17', '17': 'ECG signal 18', '18': 'ECG signal 19', 151 | '19': 'ECG signal 20', '20': 'ECG signal 21', '21': 'ECG signal 22' 152 | , '22': 'ECG signal 23', '23': 'ECG signal 24', '24': 'ECG signal 25', '25': 'ECG signal 26', 153 | '26': 'ECG signal 27', '27': 'ECG signal 28', '28': 'ECG signal 29' 154 | , '29': 'ECG signal 30', '30': 'ECG signal 31', '31': 'ECG signal 32', '32': 'ECG signal 33', 155 | '33': 'ECG signal 34', '34': 'ECG signal 35', '35': 'ECG signal 36' 156 | , '36': 'ECG signal 37', '37': 'ECG signal 38', '38': 'ECG signal 39', '39': 'ECG signal 40', 157 | '40': 'ECG signal 41', '41': 'ECG signal 42'}, 158 | 'OSULeaf': {'0': 'Acer Circinatum', '1': 'Acer Glabrum', '2': 'Acer Macrophyllum', '3': 'Acer Negundo', 159 | '4': 'Quercus Garryana', '5': 'Quercus Kelloggii'}, 160 | 'PhalangesOutlinesCorrect': {'0': 'Correct', '1': 'Incorrect'}, 161 | 'Phoneme': {'0': 'HH', '1': 'DH', '2': 'F', '3': 'S', '4': 'SH', '5': 'TH', '6': 'V', '7': 'Z', '8': 'ZH', 162 | '9': 'CH', '10': 'JH', '11': 'B', '12': 'D', '13': 'G', '14': 'K', '15': 'P', '16': 'T', '17': 'M', 163 | '18': 'N', '19': 'NG', '20': 'AA', '21': 'AE', '22': 'AH', '23': 'UW', '24': 'AO', '25': 'AW', 164 | '26': 'AY', '27': 'UH', '28': 'EH', '29': 'ER', '30': 'EY', '31': 'OY', '32': 'IH', '33': 'IY', 165 | '34': 'OW', '35': 'W', '36': 'Y', '37': 'L', '38': 'R'}, 166 | 'PLAID': {'0': 'air conditioner', '1': 'compact flourescent lamp', '2': 'fan', '3': 'fridge', '4': 'hairdryer', 167 | '5': 'heater', '6': 'incandescent light bulb', '7': 'laptop', '8': 'microwave', '9': 'vacuum', 168 | '10': 'wahing machine'}, 169 | 'Plane': {'0': 'Mirage', '1': 'Eurofighter', '2': 'F-14 wings closed', '3': 'F-14 wings opened', '4': 'Harrier', 170 | '5': 'F-22', '6': 'F-15'}, 171 | 'PowerCons': {'0': 'Warm season', '1': 'Cold season'}, 172 | 'ProximalPhalanxOutlineAgeGroup': {'0': '0-6 years old', '1': '7-12 years old', '2': '13-19 years old'}, 173 | 'ProximalPhalanxOutlineCorrect': {'0': 'Correct', '1': 'Incorrect'}, 174 | 'ProximalPhalanxTW': {'0': 'that 0-6 years old is correct', '1': 'that 0-6 years old is incorrect', 175 | '2': 'that 7-12 years old is correct', '3': 'that 7-12 years old is incorrect', 176 | '4': 'that 13-19 years old is correct', '5': 'that 13-19 years old is incorrect'}, 177 | 'RefrigerationDevices': {'0': 'Fridge/Freezer', '1': 'Refrigerator', '2': 'Upright Freezer'}, 178 | 'ScreenType': {'0': 'CRT TV', '1': 'LCD TV', '2': 'Computer Monitor'}, 179 | 'SemgHandGenderCh2': {'0': 'Female', '1': 'Male'}, 180 | 'SemgHandMovementCh2': {'0': 'Cylindrical', '1': 'Hook', '2': 'Tip', '3': 'Palmar', '4': 'Spherical', 181 | '5': 'Lateral'}, 182 | 'SemgHandSubjectCh2': {'0': 'Female 1', '1': 'Female 2', '2': 'Female 3', '3': 'Male 1', '4': 'Male 2'}, 183 | 'ShapeletSim': {'0': 'shape 1', '1': 'shape 2'}, 184 | 'SmallKitchenAppliances': {'0': 'Kettle', '1': 'Microwave', '2': 'Toaster'}, 185 | 'SmoothSubspace': {'0': 'from time stamp 1-5 ', '1': 'from time stamp 6-10', '2': 'from time stamp 11-15'}, 186 | 'SonyAIBORobotSurface1': {'0': 
'walk on carpet', '1': 'walk on cement'}, 187 | 'SonyAIBORobotSurface2': {'0': 'walk on carpet', '1': 'walk on cement'}, 188 | 'StarLightCurves': {'0': 'Cepheid', '1': 'Eclipsing Binary', '2': 'RR Lyrae'}, 189 | 'Strawberry': {'0': 'strawberry', '1': 'non-strawberry'}, 190 | 'SwedishLeaf': {'0': 'Ulmus carpinifolia', '1': 'Acer', '2': 'Salix aurita', '3': 'Quercus', '4': 'Alnus incana', 191 | '5': 'Betula pubescens', '6': 'Salix alba Sericea', '7': 'Populus tremula', '8': 'Ulmus glabra', 192 | '9': 'Sorbus aucuparia', '10': 'Salix sinerea', '11': 'Populus', '12': 'Tilia', 193 | '13': 'Sorbus intermedia', '14': 'Fagus silvatica'}, 194 | 'Symbols': {'0': 'symbol 1', '1': 'symbol 2', '2': 'symbol 3', '3': 'symbol 4', '4': 'symbol 5', '5': 'symbol 6'}, 195 | 'SyntheticControl': {'0': 'Normal', '1': 'Cyclic', '2': 'Increasing trend', '3': 'Decreasing trend', 196 | '4': 'Upward shift', '5': 'Downward shift'}, 197 | 'ToeSegmentation1': {'0': 'normal walk', '1': 'abnormal walk'}, 198 | 'ToeSegmentation2': {'0': 'normal walk', '1': 'abnormal walk'}, 199 | 'Trace': {'0': 'The second feature of class two', '1': 'The second feature of class six', 200 | '2': 'The third feature of class three', '3': 'The third feature of class seven'}, 201 | 'TwoLeadECG': {'0': 'signal 0', '1': 'signal 1'}, 202 | 'TwoPatterns': {'0': 'down-down', '1': 'up-down', '2': 'down-up', '3': 'up-up'}, 203 | 'UMD': {'0': 'Up', '1': 'Middle', '2': 'Down'}, 204 | 'UWaveGestureLibraryAll': {'0': 'fold line', '1': 'clockwise square', '2': 'right arrow', '3': 'left arrow', 205 | '4': 'up arrow', '5': 'down arrow', '6': 'clockwise circler', 206 | '7': 'anticlockwise circle'}, 207 | 'UWaveGestureLibraryX': {'0': 'fold line', '1': 'clockwise square', '2': 'right arrow', '3': 'left arrow', 208 | '4': 'up arrow', '5': 'down arrow', '6': 'clockwise circler', '7': 'anticlockwise circle'}, 209 | 'UWaveGestureLibraryY': {'0': 'fold line', '1': 'clockwise square', '2': 'right arrow', '3': 'left arrow', 210 | '4': 'up arrow', '5': 'down arrow', '6': 'clockwise circler', '7': 'anticlockwise circle'}, 211 | 'UWaveGestureLibraryZ': {'0': 'fold line', '1': 'clockwise square', '2': 'right arrow', '3': 'left arrow', 212 | '4': 'up arrow', '5': 'down arrow', '6': 'clockwise circler', '7': 'anticlockwise circle'}, 213 | 'Wafer': {'0': 'normal', '1': 'abnormal'}, 214 | 'Wine': {'0': 'strawberry', '1': 'non-strawberry'}, 215 | 'WordSynonyms': {'0': 'one', '1': 'two', '2': 'three', '3': 'four', '4': 'five', '5': 'six', '6': 'seven', 216 | '7': 'eight', '8': 'nine', '9': 'ten', '10': 'eleven', '11': 'twelve', '12': 'thirteen', 217 | '13': 'fourteen', '14': 'fifteen', '15': 'sixteen', '16': 'seventeen', '17': 'eighteen', 218 | '18': 'nineteen', '19': 'twenty', '20': 'twenty-one', '21': 'twenty-two', '22': 'twenty-three', 219 | '23': 'twenty-four', '24': 'twenty-five'}, 220 | 'Worms': {'0': 'wild-type', '1': 'goa-1', '2': 'unc-1', '3': 'unc-38', '4': 'unc-63'}, 221 | 'WormsTwoClass': {'0': 'wild-type', '1': 'mutant'}, 222 | 'Yoga': {'0': 'male', '1': 'female'}, 223 | } 224 | 225 | if __name__ == '__main__': 226 | 227 | parser = argparse.ArgumentParser() 228 | 229 | # Base setup 230 | parser.add_argument('--backbone', type=str, default='fcn', help='encoder backbone, fcn') 231 | parser.add_argument('--random_seed', type=int, default=42, help='shuffle seed') 232 | parser.add_argument('--total_dim', type=int, default=100, help='total numbers of slicing shapelets') 233 | parser.add_argument('--target_dim', type=int, default=5, help='2, 5, 10') 234 | 
parser.add_argument('--len_shapelet_ratio', type=float, default=0.2, help='0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8') 235 | 236 | # Dataset setup 237 | parser.add_argument('--dataset', type=str, default='Trace', help='') 238 | parser.add_argument('--dataroot', type=str, default='.../UCRArchive_2018', help='path of UCR folder') 239 | parser.add_argument('--num_classes', type=int, default=2, help='number of classes') 240 | parser.add_argument('--input_size', type=int, default=2, help='input_size') 241 | 242 | # training setup 243 | parser.add_argument('--labeled_ratio', type=float, default=0.1, help='0.1, 0.2, 0.4') 244 | parser.add_argument('--warmup_epochs', type=int, default=300, 245 | help='warmup epochs using only labeled data for ssl') 246 | parser.add_argument('--temperature', type=float, default=50, help='20 or 50') 247 | parser.add_argument('--sup_con_mu', type=float, default=0.001, help='0.001 or 0.005') ## prompt_toolkit_series_i 248 | parser.add_argument('--sup_df', type=float, default=0.01, 249 | help='0.001 or 0.01') 250 | parser.add_argument('--prompt_toolkit_series_i', type=int, default=0, help='') 251 | 252 | parser.add_argument('--loss', type=str, default='cross_entropy', help='loss function') 253 | parser.add_argument('--optimizer', type=str, default='adam', help='optimizer') 254 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 255 | parser.add_argument('--batch_size', type=int, default=128, help='') 256 | parser.add_argument('--epoch', type=int, default=1000, help='training epoch') 257 | parser.add_argument('--cuda', type=str, default='cuda:0') 258 | 259 | # classifier setup 260 | parser.add_argument('--classifier', type=str, default='linear', help='') 261 | parser.add_argument('--classifier_input', type=int, default=128, help='input dim of the classifiers') 262 | 263 | args = parser.parse_args() 264 | device = torch.device(args.cuda if torch.cuda.is_available() else "cpu") 265 | set_seed(args) 266 | 267 | if args.labeled_ratio == 0.1: 268 | args.total_dim = ucr_hyp_dict_10[args.dataset]["target_dim"] 269 | args.len_shapelet_ratio = ucr_hyp_dict_10[args.dataset]["len_shapelet_ratio"] 270 | args.prompt_toolkit_series_i = ucr_hyp_prompt_10[args.dataset]["prompt_toolkit_series_i"] - 1 271 | 272 | if args.labeled_ratio == 0.2: 273 | args.total_dim = ucr_hyp_dict_20[args.dataset]["target_dim"] 274 | args.len_shapelet_ratio = ucr_hyp_dict_20[args.dataset]["len_shapelet_ratio"] 275 | args.prompt_toolkit_series_i = ucr_hyp_prompt_20[args.dataset]["prompt_toolkit_series_i"] - 1 276 | 277 | if args.labeled_ratio == 0.4: 278 | args.total_dim = ucr_hyp_dict_40[args.dataset]["target_dim"] 279 | args.len_shapelet_ratio = ucr_hyp_dict_40[args.dataset]["len_shapelet_ratio"] 280 | args.prompt_toolkit_series_i = ucr_hyp_prompt_40[args.dataset]["prompt_toolkit_series_i"] - 1 281 | 282 | sum_dataset, sum_target, num_classes = build_dataset(args) 283 | args.num_classes = num_classes 284 | 285 | train_datasets, train_targets, val_datasets, val_targets, test_datasets, test_targets = get_all_datasets( 286 | sum_dataset, sum_target) 287 | 288 | len_series = train_datasets[0].shape[-1] 289 | 290 | window_shape_s = int(len_series * args.len_shapelet_ratio) 291 | temp_train_dataset = np.lib.stride_tricks.sliding_window_view( 292 | x=train_datasets[0], 293 | window_shape=window_shape_s, 294 | axis=1 295 | ) 296 | args.input_size = args.target_dim 297 | args.total_dim = temp_train_dataset.shape[1] 298 | 299 | while args.batch_size * 2 > train_datasets[0].shape[0]: 300 | 
args.batch_size = args.batch_size // 2 301 | 302 | text_embedding_labels = get_all_text_labels(ucr_datasets_dict=ucr_datasets_dict, 303 | dataset_name=args.dataset, num_labels=num_classes, device=device, 304 | prompt_toolkit_series_i=args.prompt_toolkit_series_i) 305 | 306 | fcn_model, fcn_classifier = build_model(args) 307 | fcn_model = fcn_model.to(device) 308 | fcn_classifier = fcn_classifier.to(device) 309 | 310 | mlp_text_head = ProjectionHead(input_dim=768, output_dim=128).to(device) 311 | 312 | conv_model = nn.Conv2d(in_channels=args.total_dim, out_channels=args.target_dim, kernel_size=3, padding='same') 313 | conv_model = conv_model.to(device) 314 | 315 | loss_fcn = build_loss(args).to(device) 316 | 317 | df_model = DiffusionModel( 318 | net_t=UNetV0, 319 | in_channels=args.target_dim, 320 | channels=[8, 32, 64, 64], 321 | factors=[1, 1, 1, 1], 322 | items=[1, 2, 2, 2], 323 | attentions=[0, 0, 0, 1], 324 | cross_attentions=[0, 0, 0, 1], 325 | modulation_features=64, 326 | attention_heads=8, 327 | attention_features=64, 328 | diffusion_t=VDiffusion, 329 | sampler_t=VSampler, 330 | use_embedding_cfg=True, 331 | embedding_features=window_shape_s, 332 | embedding_max_length=args.target_dim, 333 | ) 334 | df_model = df_model.to(device) 335 | 336 | conv_model_init_state = conv_model.state_dict() 337 | fcn_model_init_state = fcn_model.state_dict() 338 | fcn_classifier_init_state = fcn_classifier.state_dict() 339 | mlp_text_head_init_state = mlp_text_head.state_dict() 340 | df_model_init_state = df_model.state_dict() 341 | 342 | optimizer = torch.optim.Adam([{'params': conv_model.parameters()}, {'params': mlp_text_head.parameters()}, 343 | {'params': df_model.parameters()}, {'params': fcn_model.parameters()}, 344 | {'params': fcn_classifier.parameters()}], 345 | lr=args.lr) 346 | 347 | print('Start training on {}'.format(args.dataset)) 348 | 349 | losses = [] 350 | test_accuracies = [] 351 | train_time = 0.0 352 | end_val_epochs = [] 353 | 354 | for i, train_dataset in enumerate(train_datasets): 355 | t = time.time() 356 | 357 | conv_model.load_state_dict(conv_model_init_state) 358 | fcn_model.load_state_dict(fcn_model_init_state) 359 | fcn_classifier.load_state_dict(fcn_classifier_init_state) 360 | df_model.load_state_dict(df_model_init_state) 361 | mlp_text_head.load_state_dict(mlp_text_head_init_state) 362 | 363 | print('{} fold start training and evaluate'.format(i)) 364 | 365 | scaler = GradScaler() 366 | 367 | train_target = train_targets[i] 368 | val_dataset = val_datasets[i] 369 | val_target = val_targets[i] 370 | 371 | test_dataset = test_datasets[i] 372 | test_target = test_targets[i] 373 | 374 | train_dataset, val_dataset, test_dataset = fill_nan_value(train_dataset, val_dataset, test_dataset) 375 | 376 | # TODO normalize per series 377 | train_dataset = normalize_per_series(train_dataset) 378 | val_dataset = normalize_per_series(val_dataset) 379 | test_dataset = normalize_per_series(test_dataset) 380 | 381 | train_dataset = np.lib.stride_tricks.sliding_window_view( 382 | x=train_dataset, 383 | window_shape=window_shape_s, 384 | axis=1 385 | ) 386 | 387 | val_dataset = np.lib.stride_tricks.sliding_window_view( 388 | x=val_dataset, 389 | window_shape=window_shape_s, 390 | axis=1 391 | ) 392 | 393 | test_dataset = np.lib.stride_tricks.sliding_window_view( 394 | x=test_dataset, 395 | window_shape=window_shape_s, 396 | axis=1 397 | ) 398 | 399 | train_labeled, train_unlabeled, y_labeled, y_unlabeled = train_test_split(train_dataset, train_target, 400 | test_size=( 401 | 1 - 
args.labeled_ratio), 402 | random_state=args.random_seed) 403 | mask_labeled = np.zeros(len(y_labeled)) 404 | mask_unlabeled = np.ones(len(y_unlabeled)) 405 | mask_train = np.concatenate([mask_labeled, mask_unlabeled]) 406 | train_all_split = np.concatenate([train_labeled, train_unlabeled]) 407 | y_label_split = np.concatenate([y_labeled, y_unlabeled]) 408 | 409 | x_train_all, y_train_all = shuffler(train_all_split, y_label_split) 410 | mask_train, _ = shuffler(mask_train, mask_train) 411 | y_train_all[mask_train == 1] = -1 ## Generate unlabeled data 412 | 413 | x_train_all = torch.from_numpy(x_train_all).to(device) 414 | y_train_all = torch.from_numpy(y_train_all).to(device).to(torch.int64) 415 | 416 | x_train_labeled_all = x_train_all[mask_train == 0] 417 | y_train_labeled_all = y_train_all[mask_train == 0] 418 | 419 | train_set_labled = UCRDataset(x_train_labeled_all, y_train_labeled_all) 420 | train_set = UCRDataset(x_train_all, y_train_all) 421 | val_set = UCRDataset(torch.from_numpy(val_dataset).to(device), 422 | torch.from_numpy(val_target).to(device).to(torch.int64)) 423 | test_set = UCRDataset(torch.from_numpy(test_dataset).to(device), 424 | torch.from_numpy(test_target).to(device).to(torch.int64)) 425 | 426 | train_labeled_loader = DataLoader(train_set_labled, batch_size=args.batch_size, num_workers=0, 427 | drop_last=False) 428 | train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=0, drop_last=False) 429 | val_loader = DataLoader(val_set, batch_size=args.batch_size, num_workers=0) 430 | test_loader = DataLoader(test_set, batch_size=args.batch_size, num_workers=0) 431 | 432 | last_val_accu = 0 433 | stop_count = 0 434 | increase_count = 0 435 | 436 | num_steps = train_set.__len__() // args.batch_size 437 | 438 | max_val_accu = 0 439 | test_accuracy = 0 440 | end_val_epoch = 0 441 | 442 | for epoch in range(args.epoch): 443 | 444 | if stop_count == 80 or increase_count == 80: 445 | print('model convergent at epoch {}, early stopping'.format(epoch)) 446 | break 447 | 448 | num_iterations = 0 449 | epoch_train_loss = 0 450 | 451 | conv_model.train() 452 | df_model.train() 453 | fcn_model.train() 454 | fcn_classifier.train() 455 | mlp_text_head.train() 456 | 457 | if epoch < args.warmup_epochs: 458 | for x, y in train_labeled_loader: 459 | 460 | if x.shape[0] < 2: 461 | continue 462 | 463 | optimizer.zero_grad() 464 | 465 | with autocast(): 466 | predicted = conv_model(torch.unsqueeze(x, 2)) 467 | 468 | raw_similar_shapelets_list = None 469 | transformation_similart_loss = None 470 | for _i in range(predicted.shape[0]): 471 | _, _raw_similar_shapelets = get_each_sample_distance_shapelet( 472 | generator_shapelet=predicted[_i], 473 | raw_shapelet=x[_i], 474 | topk=1) 475 | _raw_similar_shapelets = torch.unsqueeze(_raw_similar_shapelets, 0) 476 | if raw_similar_shapelets_list is None: 477 | raw_similar_shapelets_list = _raw_similar_shapelets 478 | else: 479 | raw_similar_shapelets_list = torch.cat( 480 | (raw_similar_shapelets_list, _raw_similar_shapelets), 481 | 0) 482 | 483 | _i_sim_loss = get_similarity_shapelet(generator_shapelet=predicted[_i]) 484 | if transformation_similart_loss == None: 485 | transformation_similart_loss = 0.01 * _i_sim_loss 486 | else: 487 | transformation_similart_loss = transformation_similart_loss + 0.01 * _i_sim_loss 488 | 489 | loss_df = df_model(torch.squeeze(predicted, 2), embedding=raw_similar_shapelets_list) 490 | 491 | fcn_cls_emb = fcn_model(torch.squeeze(predicted, 2)) 492 | 493 | text_embedding = 
mlp_text_head(text_embedding_labels) 494 | text_embd_batch = None 495 | _i = 0 496 | for _y in y: 497 | temp_text_embd = torch.unsqueeze(text_embedding[_y], 0) 498 | if text_embd_batch is None: 499 | text_embd_batch = temp_text_embd 500 | else: 501 | text_embd_batch = torch.cat((text_embd_batch, temp_text_embd), 0) 502 | _i = _i + 1 503 | 504 | batch_sup_contrastive_loss = lan_shapelet_contrastive_loss( 505 | embd_batch=torch.nn.functional.normalize(fcn_cls_emb), 506 | text_embd_batch=torch.nn.functional.normalize(text_embd_batch), 507 | labels=y, 508 | device=device, 509 | temperature=args.temperature, 510 | base_temperature=args.temperature) 511 | 512 | fcn_cls_prd = fcn_classifier(fcn_cls_emb) 513 | step_loss_fcn = loss_fcn(fcn_cls_prd, y) 514 | 515 | sum_loss = transformation_similart_loss + args.sup_df * loss_df + step_loss_fcn + batch_sup_contrastive_loss * args.sup_con_mu ## 516 | 517 | scaler.scale(sum_loss).backward() 518 | scaler.step(optimizer) 519 | scaler.update() 520 | 521 | epoch_train_loss += sum_loss.item() 522 | else: 523 | for x, y in train_loader: 524 | 525 | if (num_iterations + 1) * args.batch_size < train_set.__len__(): 526 | mask_train_batch = mask_train[ 527 | num_iterations * args.batch_size: (num_iterations + 1) * args.batch_size] 528 | else: 529 | mask_train_batch = mask_train[num_iterations * args.batch_size:] 530 | 531 | mask_labeled = [True if mask_train_batch[m] == 0 else False for m in range(len(mask_train_batch))] 532 | 533 | optimizer.zero_grad() 534 | 535 | with autocast(): 536 | 537 | predicted = conv_model(torch.unsqueeze(x, 2)) 538 | 539 | raw_similar_shapelets_list = None 540 | transformation_similart_loss = None 541 | for _i in range(predicted.shape[0]): 542 | _, _raw_similar_shapelets = get_each_sample_distance_shapelet( 543 | generator_shapelet=predicted[_i], 544 | raw_shapelet=x[_i], 545 | topk=1) 546 | _raw_similar_shapelets = torch.unsqueeze(_raw_similar_shapelets, 0) 547 | if raw_similar_shapelets_list is None: 548 | raw_similar_shapelets_list = _raw_similar_shapelets 549 | else: 550 | raw_similar_shapelets_list = torch.cat( 551 | (raw_similar_shapelets_list, _raw_similar_shapelets), 552 | 0) 553 | 554 | _i_sim_loss = get_similarity_shapelet(generator_shapelet=predicted[_i]) 555 | if transformation_similart_loss == None: 556 | transformation_similart_loss = _i_sim_loss * 0.01 557 | else: 558 | transformation_similart_loss = transformation_similart_loss + _i_sim_loss * 0.01 559 | 560 | loss_df = df_model(torch.squeeze(predicted, 2), embedding=raw_similar_shapelets_list) 561 | 562 | fcn_cls_emb = fcn_model(torch.squeeze(predicted, 2)) 563 | 564 | new_mask_labeled = None 565 | end_all_label = None 566 | if len(y[mask_labeled]) >= 1: 567 | 568 | fcn_cls_prd = fcn_classifier(fcn_cls_emb) 569 | 570 | if epoch > args.warmup_epochs: 571 | 572 | new_mask_labeled, end_all_label = get_pesudo_via_high_confidence_softlabels(y_label=y, 573 | pseudo_label_soft=fcn_cls_prd, 574 | mask_label=mask_labeled, 575 | num_real_class=args.num_classes, 576 | device=device, 577 | p_cutoff=0.99) 578 | 579 | step_loss_fcn = loss_fcn(fcn_cls_prd[new_mask_labeled], end_all_label[new_mask_labeled]) 580 | else: 581 | step_loss_fcn = loss_fcn(fcn_cls_prd[mask_labeled], y[mask_labeled]) 582 | 583 | else: 584 | step_loss_fcn = 0 585 | 586 | if len(y[mask_labeled]) >= 1: 587 | if new_mask_labeled is not None: 588 | _mask_labeled = new_mask_labeled 589 | y = end_all_label 590 | else: 591 | _mask_labeled = mask_labeled 592 | 593 | text_embedding = 
mlp_text_head(text_embedding_labels) 594 | text_embd_batch = None 595 | _i = 0 596 | for _y in y[_mask_labeled]: 597 | temp_text_embd = torch.unsqueeze(text_embedding[_y], 0) 598 | if text_embd_batch is None: 599 | text_embd_batch = temp_text_embd 600 | else: 601 | text_embd_batch = torch.cat((text_embd_batch, temp_text_embd), 0) 602 | _i = _i + 1 603 | 604 | batch_sup_contrastive_loss = lan_shapelet_contrastive_loss( 605 | embd_batch=torch.nn.functional.normalize(fcn_cls_emb[_mask_labeled]), 606 | text_embd_batch=torch.nn.functional.normalize(text_embd_batch), 607 | labels=y[_mask_labeled], 608 | device=device, 609 | temperature=args.temperature, 610 | base_temperature=args.temperature) 611 | else: 612 | batch_sup_contrastive_loss = 0 613 | 614 | step_loss2 = 0.0 615 | if len(y[mask_labeled]) >= 1: 616 | _a, _b, _c = torch.squeeze(predicted[mask_labeled], 2).shape 617 | 618 | noise = torch.randn(_a, _b, _c).to(device) # [batch_size, in_channels, length] 619 | df_sample = df_model.sample(noise, num_steps=10, 620 | embedding=raw_similar_shapelets_list[ 621 | mask_labeled]) # Suggested num_steps 5, 10, 20 622 | df_sample = torch.unsqueeze(df_sample, 2) 623 | fcn_cls_emb1 = fcn_model(torch.squeeze(df_sample, 2)) 624 | fcn_cls_prd1 = fcn_classifier(fcn_cls_emb1) 625 | step_loss_fcn2 = loss_fcn(fcn_cls_prd1, y[mask_labeled]) 626 | step_loss2 = step_loss_fcn2 627 | 628 | sum_loss = transformation_similart_loss + args.sup_df * ( 629 | loss_df) + step_loss_fcn + batch_sup_contrastive_loss * args.sup_con_mu + step_loss2 ## 630 | 631 | scaler.scale(sum_loss).backward() 632 | scaler.step(optimizer) 633 | scaler.update() 634 | 635 | epoch_train_loss += sum_loss.item() 636 | 637 | num_iterations += 1 638 | 639 | conv_model.eval() 640 | df_model.eval() 641 | fcn_model.eval() 642 | fcn_classifier.eval() 643 | mlp_text_head.eval() 644 | 645 | val_accu = evaluate_model_acc(val_loader, conv_model, fcn_model, fcn_classifier) 646 | 647 | if max_val_accu < val_accu: 648 | max_val_accu = val_accu 649 | end_val_epoch = epoch 650 | test_accuracy = evaluate_model_acc(test_loader, conv_model, fcn_model, fcn_classifier) 651 | 652 | if (epoch > args.warmup_epochs * 2) and (last_val_accu >= val_accu): 653 | stop_count += 1 654 | else: 655 | stop_count = 0 656 | 657 | if (epoch > args.warmup_epochs * 2) and (end_val_epoch + 80 < epoch): 658 | increase_count += 1 659 | else: 660 | increase_count = 0 661 | 662 | last_val_accu = val_accu 663 | 664 | if epoch % 100 == 0: 665 | print("epoch : {}, train loss: {}, val_accu: {}, stop_count: {}, test_accuracy: {}".format(epoch, 666 | epoch_train_loss, 667 | val_accu, 668 | stop_count, 669 | test_accuracy)) 670 | 671 | test_accuracies.append(test_accuracy) 672 | train_time = time.time() - t 673 | 674 | test_acc_list = test_accuracies 675 | train_time_end = train_time 676 | print("The average test acc = ", np.mean(test_acc_list), ", all fold test acc = ", test_acc_list, 677 | ", training time = ", train_time_end) 678 | print('Done!') 679 | -------------------------------------------------------------------------------- /diffshape_ssc/parameter_shapelets.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Like [1], we use a cross-validation grid search method to select k (target_dim, number of shapelets), 3 | η (len_shapelet_ratio, the ratio of the time series length), and hard prompt template. 4 | 5 | [1] Grabocka, Josif, et al. "Learning time-series shapelets." 
Proceedings of the 20th ACM SIGKDD international conference 6 | on Knowledge discovery and data mining. 2014. 7 | ''' 8 | 9 | ucr_hyp_dict_10 = {'AllGestureWiimoteX': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 10 | 'AllGestureWiimoteY': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 11 | 'AllGestureWiimoteZ': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 12 | 'ArrowHead': {'target_dim': 2, 'len_shapelet_ratio': 0.2}, 13 | 'BME': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 14 | 'Car': {'target_dim': 10, 'len_shapelet_ratio': 0.4}, 15 | 'CBF': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 16 | 'Chinatown': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 17 | 'CinCECGTorso': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 18 | 'ChlorineConcentration': {'target_dim': 2, 'len_shapelet_ratio': 0.3}, 19 | 'Computers': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 20 | 'CricketX': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 21 | 'CricketY': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 22 | 'CricketZ': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 23 | 'Crop': {'target_dim': 2, 'len_shapelet_ratio': 0.5}, 24 | 'DiatomSizeReduction': {'target_dim': 2, 'len_shapelet_ratio': 0.1}, 25 | 'DistalPhalanxOutlineAgeGroup': {'target_dim': 2, 'len_shapelet_ratio': 0.5}, 26 | 'DistalPhalanxOutlineCorrect': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 27 | 'DistalPhalanxTW': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 28 | 'DodgerLoopGame': {'target_dim': 10, 'len_shapelet_ratio': 0.3}, 29 | 'DodgerLoopWeekend': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 30 | 'Earthquakes': {'target_dim': 10, 'len_shapelet_ratio': 0.6}, 31 | 'ECG200': {'target_dim': 2, 'len_shapelet_ratio': 0.5}, 32 | 'ECG5000': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 33 | 'ECGFiveDays': {'target_dim': 2, 'len_shapelet_ratio': 0.3}, 34 | 'ElectricDevices': {'target_dim': 10, 'len_shapelet_ratio': 0.6}, 35 | 'EOGHorizontalSignal': {'target_dim': 5, 'len_shapelet_ratio': 0.4}, 36 | 'EOGVerticalSignal': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 37 | 'EthanolLevel': {'target_dim': 10, 'len_shapelet_ratio': 0.1}, 38 | 'FaceAll': {'target_dim': 10, 'len_shapelet_ratio': 0.6}, 39 | 'FacesUCR': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 40 | 'Fish': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 41 | 'FordA': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 42 | 'FordB': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 43 | 'FreezerRegularTrain': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 44 | 'FreezerSmallTrain': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 45 | 'GesturePebbleZ1': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 46 | 'GesturePebbleZ2': {'target_dim': 5, 'len_shapelet_ratio': 0.5}, 47 | 'GunPoint': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 48 | 'GunPointAgeSpan': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 49 | 'GunPointMaleVersusFemale': {'target_dim': 2, 'len_shapelet_ratio': 0.4}, 50 | 'GunPointOldVersusYoung': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 51 | 'Ham': {'target_dim': 10, 'len_shapelet_ratio': 0.4}, 52 | 'HandOutlines': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 53 | 'Haptics': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 54 | 'Herring': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 55 | 'HouseTwenty': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 56 | 'InlineSkate': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 57 | 'InsectEPGRegularTrain': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 58 | 'InsectEPGSmallTrain': {'target_dim': 5, 'len_shapelet_ratio': 0.4}, 59 | 'InsectWingbeatSound': 
{'target_dim': 5, 'len_shapelet_ratio': 0.4}, 60 | 'ItalyPowerDemand': {'target_dim': 2, 'len_shapelet_ratio': 0.4}, 61 | 'LargeKitchenAppliances': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 62 | 'Lightning2': {'target_dim': 10, 'len_shapelet_ratio': 0.4}, 63 | 'Mallat': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 64 | 'Meat': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 65 | 'MedicalImages': {'target_dim': 2, 'len_shapelet_ratio': 0.6}, 66 | 'MelbournePedestrian': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 67 | 'MiddlePhalanxOutlineAgeGroup': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 68 | 'MiddlePhalanxOutlineCorrect': {'target_dim': 2, 'len_shapelet_ratio': 0.6}, 69 | 'MiddlePhalanxTW': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 70 | 'MixedShapesRegularTrain': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 71 | 'MixedShapesSmallTrain': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 72 | 'MoteStrain': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 73 | 'NonInvasiveFetalECGThorax1': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 74 | 'NonInvasiveFetalECGThorax2': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 75 | 'OSULeaf': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 76 | 'PhalangesOutlinesCorrect': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 77 | 'Phoneme': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 78 | 'PLAID': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 79 | 'Plane': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 80 | 'PowerCons': {'target_dim': 10, 'len_shapelet_ratio': 0.3}, 81 | 'ProximalPhalanxOutlineAgeGroup': {'target_dim': 2, 'len_shapelet_ratio': 0.3}, 82 | 'ProximalPhalanxOutlineCorrect': {'target_dim': 2, 'len_shapelet_ratio': 0.5}, 83 | 'ProximalPhalanxTW': {'target_dim': 2, 'len_shapelet_ratio': 0.3}, 84 | 'RefrigerationDevices': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 85 | 'ScreenType': {'target_dim': 10, 'len_shapelet_ratio': 0.2}, 86 | 'SemgHandGenderCh2': {'target_dim': 10, 'len_shapelet_ratio': 0.1}, 87 | 'SemgHandMovementCh2': {'target_dim': 10, 'len_shapelet_ratio': 0.3}, 88 | 'SemgHandSubjectCh2': {'target_dim': 10, 'len_shapelet_ratio': 0.2}, 89 | 'ShapeletSim': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 90 | 'SmallKitchenAppliances': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 91 | 'SmoothSubspace': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 92 | 'SonyAIBORobotSurface1': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 93 | 'SonyAIBORobotSurface2': {'target_dim': 2, 'len_shapelet_ratio': 0.6}, 94 | 'StarLightCurves': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 95 | 'Strawberry': {'target_dim': 2, 'len_shapelet_ratio': 0.6}, 96 | 'SwedishLeaf': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 97 | 'Symbols': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 98 | 'SyntheticControl': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 99 | 'ToeSegmentation1': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 100 | 'ToeSegmentation2': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 101 | 'Trace': {'target_dim': 2, 'len_shapelet_ratio': 0.4}, 102 | 'TwoLeadECG': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 103 | 'TwoPatterns': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 104 | 'UMD': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 105 | 'UWaveGestureLibraryAll': {'target_dim': 10, 'len_shapelet_ratio': 0.2}, 106 | 'UWaveGestureLibraryX': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 107 | 'UWaveGestureLibraryY': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 108 | 'UWaveGestureLibraryZ': {'target_dim': 10, 'len_shapelet_ratio': 0.6}, 109 | 'Wafer': 
{'target_dim': 10, 'len_shapelet_ratio': 0.8}, 110 | 'Wine': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 111 | 'WordSynonyms': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 112 | 'Worms': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 113 | 'WormsTwoClass': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 114 | 'Yoga': {'target_dim': 5, 'len_shapelet_ratio': 0.6}} 115 | 116 | ucr_hyp_dict_20 = {'AllGestureWiimoteX': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 117 | 'AllGestureWiimoteY': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 118 | 'AllGestureWiimoteZ': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 119 | 'ArrowHead': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 120 | 'BME': {'target_dim': 2, 'len_shapelet_ratio': 0.6}, 121 | 'Car': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 122 | 'CBF': {'target_dim': 2, 'len_shapelet_ratio': 0.4}, 123 | 'Chinatown': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 124 | 'CinCECGTorso': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 125 | 'ChlorineConcentration': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 126 | 'Computers': {'target_dim': 10, 'len_shapelet_ratio': 0.6}, 127 | 'CricketX': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 128 | 'CricketY': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 129 | 'CricketZ': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 130 | 'Crop': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 131 | 'DiatomSizeReduction': {'target_dim': 2, 'len_shapelet_ratio': 0.3}, 132 | 'DistalPhalanxOutlineAgeGroup': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 133 | 'DistalPhalanxOutlineCorrect': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 134 | 'DistalPhalanxTW': {'target_dim': 5, 'len_shapelet_ratio': 0.4}, 135 | 'DodgerLoopGame': {'target_dim': 10, 'len_shapelet_ratio': 0.6}, 136 | 'DodgerLoopWeekend': {'target_dim': 2, 'len_shapelet_ratio': 0.1}, 137 | 'Earthquakes': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 138 | 'ECG200': {'target_dim': 2, 'len_shapelet_ratio': 0.2}, 139 | 'ECG5000': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 140 | 'ECGFiveDays': {'target_dim': 2, 'len_shapelet_ratio': 0.1}, 141 | 'ElectricDevices': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 142 | 'EOGHorizontalSignal': {'target_dim': 10, 'len_shapelet_ratio': 0.1}, 143 | 'EOGVerticalSignal': {'target_dim': 10, 'len_shapelet_ratio': 0.2}, 144 | 'EthanolLevel': {'target_dim': 10, 'len_shapelet_ratio': 0.1}, 145 | 'FaceAll': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 146 | 'FacesUCR': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 147 | 'Fish': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 148 | 'FordA': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 149 | 'FordB': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 150 | 'FreezerRegularTrain': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 151 | 'FreezerSmallTrain': {'target_dim': 2, 'len_shapelet_ratio': 0.6}, 152 | 'GesturePebbleZ1': {'target_dim': 10, 'len_shapelet_ratio': 0.5}, 153 | 'GesturePebbleZ2': {'target_dim': 10, 'len_shapelet_ratio': 0.4}, 154 | 'GunPoint': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 155 | 'GunPointAgeSpan': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 156 | 'GunPointMaleVersusFemale': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 157 | 'GunPointOldVersusYoung': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 158 | 'Ham': {'target_dim': 10, 'len_shapelet_ratio': 0.1}, 159 | 'HandOutlines': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 160 | 'Haptics': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 161 | 'Herring': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 162 | 
'HouseTwenty': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 163 | 'InlineSkate': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 164 | 'InsectEPGRegularTrain': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 165 | 'InsectEPGSmallTrain': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 166 | 'InsectWingbeatSound': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 167 | 'ItalyPowerDemand': {'target_dim': 2, 'len_shapelet_ratio': 0.1}, 168 | 'LargeKitchenAppliances': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 169 | 'Lightning2': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 170 | 'Mallat': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 171 | 'Meat': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 172 | 'MedicalImages': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 173 | 'MelbournePedestrian': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 174 | 'MiddlePhalanxOutlineAgeGroup': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 175 | 'MiddlePhalanxOutlineCorrect': {'target_dim': 2, 'len_shapelet_ratio': 0.2}, 176 | 'MiddlePhalanxTW': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 177 | 'MixedShapesRegularTrain': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 178 | 'MixedShapesSmallTrain': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 179 | 'MoteStrain': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 180 | 'NonInvasiveFetalECGThorax1': {'target_dim': 10, 'len_shapelet_ratio': 0.3}, 181 | 'NonInvasiveFetalECGThorax2': {'target_dim': 10, 'len_shapelet_ratio': 0.2}, 182 | 'OSULeaf': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 183 | 'PhalangesOutlinesCorrect': {'target_dim': 5, 'len_shapelet_ratio': 0.5}, 184 | 'Phoneme': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 185 | 'PLAID': {'target_dim': 10, 'len_shapelet_ratio': 0.1}, 186 | 'Plane': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 187 | 'PowerCons': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 188 | 'ProximalPhalanxOutlineAgeGroup': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 189 | 'ProximalPhalanxOutlineCorrect': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 190 | 'ProximalPhalanxTW': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 191 | 'RefrigerationDevices': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 192 | 'ScreenType': {'target_dim': 2, 'len_shapelet_ratio': 0.4}, 193 | 'SemgHandGenderCh2': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 194 | 'SemgHandMovementCh2': {'target_dim': 10, 'len_shapelet_ratio': 0.2}, 195 | 'SemgHandSubjectCh2': {'target_dim': 10, 'len_shapelet_ratio': 0.2}, 196 | 'ShapeletSim': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 197 | 'SmallKitchenAppliances': {'target_dim': 10, 'len_shapelet_ratio': 0.6}, 198 | 'SmoothSubspace': {'target_dim': 10, 'len_shapelet_ratio': 0.5}, 199 | 'SonyAIBORobotSurface1': {'target_dim': 10, 'len_shapelet_ratio': 0.4}, 200 | 'SonyAIBORobotSurface2': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 201 | 'StarLightCurves': {'target_dim': 2, 'len_shapelet_ratio': 0.1}, 202 | 'Strawberry': {'target_dim': 2, 'len_shapelet_ratio': 0.2}, 203 | 'SwedishLeaf': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 204 | 'Symbols': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 205 | 'SyntheticControl': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 206 | 'ToeSegmentation1': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 207 | 'ToeSegmentation2': {'target_dim': 10, 'len_shapelet_ratio': 0.3}, 208 | 'Trace': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 209 | 'TwoLeadECG': {'target_dim': 2, 'len_shapelet_ratio': 0.5}, 210 | 'TwoPatterns': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 211 | 'UMD': {'target_dim': 2, 
'len_shapelet_ratio': 0.8}, 212 | 'UWaveGestureLibraryAll': {'target_dim': 10, 'len_shapelet_ratio': 0.3}, 213 | 'UWaveGestureLibraryX': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 214 | 'UWaveGestureLibraryY': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 215 | 'UWaveGestureLibraryZ': {'target_dim': 10, 'len_shapelet_ratio': 0.6}, 216 | 'Wafer': {'target_dim': 10, 'len_shapelet_ratio': 0.1}, 217 | 'Wine': {'target_dim': 2, 'len_shapelet_ratio': 0.4}, 218 | 'WordSynonyms': {'target_dim': 10, 'len_shapelet_ratio': 0.5}, 219 | 'Worms': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 220 | 'WormsTwoClass': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 221 | 'Yoga': {'target_dim': 10, 'len_shapelet_ratio': 0.7}} 222 | 223 | ucr_hyp_dict_40 = {'AllGestureWiimoteX': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 224 | 'AllGestureWiimoteY': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 225 | 'AllGestureWiimoteZ': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 226 | 'ArrowHead': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 227 | 'BME': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 228 | 'Car': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 229 | 'CBF': {'target_dim': 2, 'len_shapelet_ratio': 0.5}, 230 | 'Chinatown': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 231 | 'CinCECGTorso': {'target_dim': 2, 'len_shapelet_ratio': 0.4}, 232 | 'ChlorineConcentration': {'target_dim': 2, 'len_shapelet_ratio': 0.3}, 233 | 'Computers': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 234 | 'CricketX': {'target_dim': 2, 'len_shapelet_ratio': 0.6}, 235 | 'CricketY': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 236 | 'CricketZ': {'target_dim': 10, 'len_shapelet_ratio': 0.6}, 237 | 'Crop': {'target_dim': 5, 'len_shapelet_ratio': 0.7}, 238 | 'DiatomSizeReduction': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 239 | 'DistalPhalanxOutlineAgeGroup': {'target_dim': 2, 'len_shapelet_ratio': 0.1}, 240 | 'DistalPhalanxOutlineCorrect': {'target_dim': 2, 'len_shapelet_ratio': 0.1}, 241 | 'DistalPhalanxTW': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 242 | 'DodgerLoopGame': {'target_dim': 2, 'len_shapelet_ratio': 0.4}, 243 | 'DodgerLoopWeekend': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 244 | 'Earthquakes': {'target_dim': 10, 'len_shapelet_ratio': 0.4}, 245 | 'ECG200': {'target_dim': 2, 'len_shapelet_ratio': 0.4}, 246 | 'ECG5000': {'target_dim': 10, 'len_shapelet_ratio': 0.5}, 247 | 'ECGFiveDays': {'target_dim': 2, 'len_shapelet_ratio': 0.1}, 248 | 'ElectricDevices': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 249 | 'EOGHorizontalSignal': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 250 | 'EOGVerticalSignal': {'target_dim': 5, 'len_shapelet_ratio': 0.4}, 251 | 'EthanolLevel': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 252 | 'FaceAll': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 253 | 'FacesUCR': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 254 | 'Fish': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 255 | 'FordA': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 256 | 'FordB': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 257 | 'FreezerRegularTrain': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 258 | 'FreezerSmallTrain': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 259 | 'GesturePebbleZ1': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 260 | 'GesturePebbleZ2': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 261 | 'GunPoint': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 262 | 'GunPointAgeSpan': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 263 | 'GunPointMaleVersusFemale': {'target_dim': 10, 
'len_shapelet_ratio': 0.5}, 264 | 'GunPointOldVersusYoung': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 265 | 'Ham': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 266 | 'HandOutlines': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 267 | 'Haptics': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 268 | 'Herring': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 269 | 'HouseTwenty': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 270 | 'InlineSkate': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 271 | 'InsectEPGRegularTrain': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 272 | 'InsectEPGSmallTrain': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 273 | 'InsectWingbeatSound': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 274 | 'ItalyPowerDemand': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 275 | 'LargeKitchenAppliances': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 276 | 'Lightning2': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 277 | 'Mallat': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 278 | 'Meat': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 279 | 'MedicalImages': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 280 | 'MelbournePedestrian': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 281 | 'MiddlePhalanxOutlineAgeGroup': {'target_dim': 2, 'len_shapelet_ratio': 0.1}, 282 | 'MiddlePhalanxOutlineCorrect': {'target_dim': 2, 'len_shapelet_ratio': 0.2}, 283 | 'MiddlePhalanxTW': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 284 | 'MixedShapesRegularTrain': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 285 | 'MixedShapesSmallTrain': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 286 | 'MoteStrain': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 287 | 'NonInvasiveFetalECGThorax1': {'target_dim': 10, 'len_shapelet_ratio': 0.1}, 288 | 'NonInvasiveFetalECGThorax2': {'target_dim': 10, 'len_shapelet_ratio': 0.1}, 289 | 'OSULeaf': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 290 | 'PhalangesOutlinesCorrect': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 291 | 'Phoneme': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 292 | 'PLAID': {'target_dim': 2, 'len_shapelet_ratio': 0.3}, 293 | 'Plane': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 294 | 'PowerCons': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 295 | 'ProximalPhalanxOutlineAgeGroup': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 296 | 'ProximalPhalanxOutlineCorrect': {'target_dim': 2, 'len_shapelet_ratio': 0.3}, 297 | 'ProximalPhalanxTW': {'target_dim': 5, 'len_shapelet_ratio': 0.1}, 298 | 'RefrigerationDevices': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 299 | 'ScreenType': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 300 | 'SemgHandGenderCh2': {'target_dim': 10, 'len_shapelet_ratio': 0.1}, 301 | 'SemgHandMovementCh2': {'target_dim': 10, 'len_shapelet_ratio': 0.3}, 302 | 'SemgHandSubjectCh2': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 303 | 'ShapeletSim': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 304 | 'SmallKitchenAppliances': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 305 | 'SmoothSubspace': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 306 | 'SonyAIBORobotSurface1': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 307 | 'SonyAIBORobotSurface2': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 308 | 'StarLightCurves': {'target_dim': 2, 'len_shapelet_ratio': 0.2}, 309 | 'Strawberry': {'target_dim': 5, 'len_shapelet_ratio': 0.2}, 310 | 'SwedishLeaf': {'target_dim': 2, 'len_shapelet_ratio': 0.5}, 311 | 'Symbols': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 312 | 'SyntheticControl': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 313 | 
'ToeSegmentation1': {'target_dim': 2, 'len_shapelet_ratio': 0.7}, 314 | 'ToeSegmentation2': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 315 | 'Trace': {'target_dim': 2, 'len_shapelet_ratio': 0.8}, 316 | 'TwoLeadECG': {'target_dim': 2, 'len_shapelet_ratio': 0.2}, 317 | 'TwoPatterns': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 318 | 'UMD': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 319 | 'UWaveGestureLibraryAll': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 320 | 'UWaveGestureLibraryX': {'target_dim': 10, 'len_shapelet_ratio': 0.7}, 321 | 'UWaveGestureLibraryY': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 322 | 'UWaveGestureLibraryZ': {'target_dim': 10, 'len_shapelet_ratio': 0.8}, 323 | 'Wafer': {'target_dim': 5, 'len_shapelet_ratio': 0.3}, 324 | 'Wine': {'target_dim': 2, 'len_shapelet_ratio': 0.2}, 325 | 'WordSynonyms': {'target_dim': 10, 'len_shapelet_ratio': 0.6}, 326 | 'Worms': {'target_dim': 5, 'len_shapelet_ratio': 0.6}, 327 | 'WormsTwoClass': {'target_dim': 5, 'len_shapelet_ratio': 0.8}, 328 | 'Yoga': {'target_dim': 10, 'len_shapelet_ratio': 0.7}} 329 | 330 | ucr_hyp_prompt_10 = {'AllGestureWiimoteX': {'prompt_toolkit_series_i': 8}, 331 | 'AllGestureWiimoteY': {'prompt_toolkit_series_i': 6}, 332 | 'AllGestureWiimoteZ': {'prompt_toolkit_series_i': 5}, 'ArrowHead': {'prompt_toolkit_series_i': 6}, 333 | 'BME': {'prompt_toolkit_series_i': 4}, 'Car': {'prompt_toolkit_series_i': 2}, 334 | 'CBF': {'prompt_toolkit_series_i': 1}, 'Chinatown': {'prompt_toolkit_series_i': 6}, 335 | 'CinCECGTorso': {'prompt_toolkit_series_i': 2}, 336 | 'ChlorineConcentration': {'prompt_toolkit_series_i': 2}, 337 | 'Computers': {'prompt_toolkit_series_i': 8}, 'CricketX': {'prompt_toolkit_series_i': 7}, 338 | 'CricketY': {'prompt_toolkit_series_i': 7}, 'CricketZ': {'prompt_toolkit_series_i': 4}, 339 | 'Crop': {'prompt_toolkit_series_i': 2}, 'DiatomSizeReduction': {'prompt_toolkit_series_i': 6}, 340 | 'DistalPhalanxOutlineAgeGroup': {'prompt_toolkit_series_i': 5}, 341 | 'DistalPhalanxOutlineCorrect': {'prompt_toolkit_series_i': 2}, 342 | 'DistalPhalanxTW': {'prompt_toolkit_series_i': 1}, 343 | 'DodgerLoopGame': {'prompt_toolkit_series_i': 6}, 344 | 'DodgerLoopWeekend': {'prompt_toolkit_series_i': 8}, 'Earthquakes': {'prompt_toolkit_series_i': 3}, 345 | 'ECG200': {'prompt_toolkit_series_i': 1}, 'ECG5000': {'prompt_toolkit_series_i': 5}, 346 | 'ECGFiveDays': {'prompt_toolkit_series_i': 3}, 'ElectricDevices': {'prompt_toolkit_series_i': 4}, 347 | 'EOGHorizontalSignal': {'prompt_toolkit_series_i': 3}, 348 | 'EOGVerticalSignal': {'prompt_toolkit_series_i': 7}, 349 | 'EthanolLevel': {'prompt_toolkit_series_i': 6}, 'FaceAll': {'prompt_toolkit_series_i': 4}, 350 | 'FacesUCR': {'prompt_toolkit_series_i': 4}, 'Fish': {'prompt_toolkit_series_i': 5}, 351 | 'FordA': {'prompt_toolkit_series_i': 2}, 'FordB': {'prompt_toolkit_series_i': 5}, 352 | 'FreezerRegularTrain': {'prompt_toolkit_series_i': 3}, 353 | 'FreezerSmallTrain': {'prompt_toolkit_series_i': 5}, 354 | 'GesturePebbleZ1': {'prompt_toolkit_series_i': 3}, 355 | 'GesturePebbleZ2': {'prompt_toolkit_series_i': 4}, 'GunPoint': {'prompt_toolkit_series_i': 4}, 356 | 'GunPointAgeSpan': {'prompt_toolkit_series_i': 3}, 357 | 'GunPointMaleVersusFemale': {'prompt_toolkit_series_i': 1}, 358 | 'GunPointOldVersusYoung': {'prompt_toolkit_series_i': 3}, 'Ham': {'prompt_toolkit_series_i': 6}, 359 | 'HandOutlines': {'prompt_toolkit_series_i': 3}, 'Haptics': {'prompt_toolkit_series_i': 4}, 360 | 'Herring': {'prompt_toolkit_series_i': 7}, 'HouseTwenty': {'prompt_toolkit_series_i': 
8}, 361 | 'InlineSkate': {'prompt_toolkit_series_i': 6}, 362 | 'InsectEPGRegularTrain': {'prompt_toolkit_series_i': 1}, 363 | 'InsectEPGSmallTrain': {'prompt_toolkit_series_i': 1}, 364 | 'InsectWingbeatSound': {'prompt_toolkit_series_i': 1}, 365 | 'ItalyPowerDemand': {'prompt_toolkit_series_i': 2}, 366 | 'LargeKitchenAppliances': {'prompt_toolkit_series_i': 8}, 367 | 'Lightning2': {'prompt_toolkit_series_i': 5}, 'Mallat': {'prompt_toolkit_series_i': 7}, 368 | 'Meat': {'prompt_toolkit_series_i': 7}, 'MedicalImages': {'prompt_toolkit_series_i': 4}, 369 | 'MelbournePedestrian': {'prompt_toolkit_series_i': 6}, 370 | 'MiddlePhalanxOutlineAgeGroup': {'prompt_toolkit_series_i': 7}, 371 | 'MiddlePhalanxOutlineCorrect': {'prompt_toolkit_series_i': 5}, 372 | 'MiddlePhalanxTW': {'prompt_toolkit_series_i': 6}, 373 | 'MixedShapesRegularTrain': {'prompt_toolkit_series_i': 1}, 374 | 'MixedShapesSmallTrain': {'prompt_toolkit_series_i': 8}, 375 | 'MoteStrain': {'prompt_toolkit_series_i': 1}, 376 | 'NonInvasiveFetalECGThorax1': {'prompt_toolkit_series_i': 1}, 377 | 'NonInvasiveFetalECGThorax2': {'prompt_toolkit_series_i': 3}, 378 | 'OSULeaf': {'prompt_toolkit_series_i': 2}, 379 | 'PhalangesOutlinesCorrect': {'prompt_toolkit_series_i': 5}, 380 | 'Phoneme': {'prompt_toolkit_series_i': 8}, 'PLAID': {'prompt_toolkit_series_i': 3}, 381 | 'Plane': {'prompt_toolkit_series_i': 3}, 'PowerCons': {'prompt_toolkit_series_i': 3}, 382 | 'ProximalPhalanxOutlineAgeGroup': {'prompt_toolkit_series_i': 7}, 383 | 'ProximalPhalanxOutlineCorrect': {'prompt_toolkit_series_i': 6}, 384 | 'ProximalPhalanxTW': {'prompt_toolkit_series_i': 2}, 385 | 'RefrigerationDevices': {'prompt_toolkit_series_i': 5}, 386 | 'ScreenType': {'prompt_toolkit_series_i': 4}, 'SemgHandGenderCh2': {'prompt_toolkit_series_i': 3}, 387 | 'SemgHandMovementCh2': {'prompt_toolkit_series_i': 6}, 388 | 'SemgHandSubjectCh2': {'prompt_toolkit_series_i': 2}, 389 | 'ShapeletSim': {'prompt_toolkit_series_i': 6}, 390 | 'SmallKitchenAppliances': {'prompt_toolkit_series_i': 2}, 391 | 'SmoothSubspace': {'prompt_toolkit_series_i': 1}, 392 | 'SonyAIBORobotSurface1': {'prompt_toolkit_series_i': 1}, 393 | 'SonyAIBORobotSurface2': {'prompt_toolkit_series_i': 1}, 394 | 'StarLightCurves': {'prompt_toolkit_series_i': 8}, 'Strawberry': {'prompt_toolkit_series_i': 1}, 395 | 'SwedishLeaf': {'prompt_toolkit_series_i': 5}, 'Symbols': {'prompt_toolkit_series_i': 4}, 396 | 'SyntheticControl': {'prompt_toolkit_series_i': 3}, 397 | 'ToeSegmentation1': {'prompt_toolkit_series_i': 5}, 398 | 'ToeSegmentation2': {'prompt_toolkit_series_i': 3}, 'Trace': {'prompt_toolkit_series_i': 6}, 399 | 'TwoLeadECG': {'prompt_toolkit_series_i': 8}, 'TwoPatterns': {'prompt_toolkit_series_i': 2}, 400 | 'UMD': {'prompt_toolkit_series_i': 1}, 'UWaveGestureLibraryAll': {'prompt_toolkit_series_i': 7}, 401 | 'UWaveGestureLibraryX': {'prompt_toolkit_series_i': 3}, 402 | 'UWaveGestureLibraryY': {'prompt_toolkit_series_i': 1}, 403 | 'UWaveGestureLibraryZ': {'prompt_toolkit_series_i': 6}, 'Wafer': {'prompt_toolkit_series_i': 4}, 404 | 'Wine': {'prompt_toolkit_series_i': 3}, 'WordSynonyms': {'prompt_toolkit_series_i': 5}, 405 | 'Worms': {'prompt_toolkit_series_i': 7}, 'WormsTwoClass': {'prompt_toolkit_series_i': 5}, 406 | 'Yoga': {'prompt_toolkit_series_i': 6}} 407 | 408 | ucr_hyp_prompt_20 = {'AllGestureWiimoteX': {'prompt_toolkit_series_i': 4}, 409 | 'AllGestureWiimoteY': {'prompt_toolkit_series_i': 6}, 410 | 'AllGestureWiimoteZ': {'prompt_toolkit_series_i': 7}, 'ArrowHead': {'prompt_toolkit_series_i': 8}, 411 | 
'BME': {'prompt_toolkit_series_i': 1}, 'Car': {'prompt_toolkit_series_i': 1}, 412 | 'CBF': {'prompt_toolkit_series_i': 1}, 'Chinatown': {'prompt_toolkit_series_i': 5}, 413 | 'CinCECGTorso': {'prompt_toolkit_series_i': 1}, 414 | 'ChlorineConcentration': {'prompt_toolkit_series_i': 8}, 415 | 'Computers': {'prompt_toolkit_series_i': 3}, 'CricketX': {'prompt_toolkit_series_i': 6}, 416 | 'CricketY': {'prompt_toolkit_series_i': 3}, 'CricketZ': {'prompt_toolkit_series_i': 3}, 417 | 'Crop': {'prompt_toolkit_series_i': 4}, 'DiatomSizeReduction': {'prompt_toolkit_series_i': 6}, 418 | 'DistalPhalanxOutlineAgeGroup': {'prompt_toolkit_series_i': 7}, 419 | 'DistalPhalanxOutlineCorrect': {'prompt_toolkit_series_i': 7}, 420 | 'DistalPhalanxTW': {'prompt_toolkit_series_i': 3}, 421 | 'DodgerLoopGame': {'prompt_toolkit_series_i': 7}, 422 | 'DodgerLoopWeekend': {'prompt_toolkit_series_i': 5}, 'Earthquakes': {'prompt_toolkit_series_i': 4}, 423 | 'ECG200': {'prompt_toolkit_series_i': 7}, 'ECG5000': {'prompt_toolkit_series_i': 5}, 424 | 'ECGFiveDays': {'prompt_toolkit_series_i': 1}, 'ElectricDevices': {'prompt_toolkit_series_i': 3}, 425 | 'EOGHorizontalSignal': {'prompt_toolkit_series_i': 7}, 426 | 'EOGVerticalSignal': {'prompt_toolkit_series_i': 8}, 427 | 'EthanolLevel': {'prompt_toolkit_series_i': 7}, 'FaceAll': {'prompt_toolkit_series_i': 2}, 428 | 'FacesUCR': {'prompt_toolkit_series_i': 6}, 'Fish': {'prompt_toolkit_series_i': 4}, 429 | 'FordA': {'prompt_toolkit_series_i': 3}, 'FordB': {'prompt_toolkit_series_i': 6}, 430 | 'FreezerRegularTrain': {'prompt_toolkit_series_i': 2}, 431 | 'FreezerSmallTrain': {'prompt_toolkit_series_i': 3}, 432 | 'GesturePebbleZ1': {'prompt_toolkit_series_i': 4}, 433 | 'GesturePebbleZ2': {'prompt_toolkit_series_i': 4}, 'GunPoint': {'prompt_toolkit_series_i': 1}, 434 | 'GunPointAgeSpan': {'prompt_toolkit_series_i': 2}, 435 | 'GunPointMaleVersusFemale': {'prompt_toolkit_series_i': 3}, 436 | 'GunPointOldVersusYoung': {'prompt_toolkit_series_i': 4}, 'Ham': {'prompt_toolkit_series_i': 3}, 437 | 'HandOutlines': {'prompt_toolkit_series_i': 8}, 'Haptics': {'prompt_toolkit_series_i': 7}, 438 | 'Herring': {'prompt_toolkit_series_i': 2}, 'HouseTwenty': {'prompt_toolkit_series_i': 7}, 439 | 'InlineSkate': {'prompt_toolkit_series_i': 4}, 440 | 'InsectEPGRegularTrain': {'prompt_toolkit_series_i': 1}, 441 | 'InsectEPGSmallTrain': {'prompt_toolkit_series_i': 5}, 442 | 'InsectWingbeatSound': {'prompt_toolkit_series_i': 6}, 443 | 'ItalyPowerDemand': {'prompt_toolkit_series_i': 1}, 444 | 'LargeKitchenAppliances': {'prompt_toolkit_series_i': 5}, 445 | 'Lightning2': {'prompt_toolkit_series_i': 7}, 'Mallat': {'prompt_toolkit_series_i': 1}, 446 | 'Meat': {'prompt_toolkit_series_i': 1}, 'MedicalImages': {'prompt_toolkit_series_i': 4}, 447 | 'MelbournePedestrian': {'prompt_toolkit_series_i': 6}, 448 | 'MiddlePhalanxOutlineAgeGroup': {'prompt_toolkit_series_i': 1}, 449 | 'MiddlePhalanxOutlineCorrect': {'prompt_toolkit_series_i': 6}, 450 | 'MiddlePhalanxTW': {'prompt_toolkit_series_i': 1}, 451 | 'MixedShapesRegularTrain': {'prompt_toolkit_series_i': 6}, 452 | 'MixedShapesSmallTrain': {'prompt_toolkit_series_i': 2}, 453 | 'MoteStrain': {'prompt_toolkit_series_i': 7}, 454 | 'NonInvasiveFetalECGThorax1': {'prompt_toolkit_series_i': 3}, 455 | 'NonInvasiveFetalECGThorax2': {'prompt_toolkit_series_i': 3}, 456 | 'OSULeaf': {'prompt_toolkit_series_i': 3}, 457 | 'PhalangesOutlinesCorrect': {'prompt_toolkit_series_i': 5}, 458 | 'Phoneme': {'prompt_toolkit_series_i': 3}, 'PLAID': {'prompt_toolkit_series_i': 1}, 459 | 
'Plane': {'prompt_toolkit_series_i': 4}, 'PowerCons': {'prompt_toolkit_series_i': 8}, 460 | 'ProximalPhalanxOutlineAgeGroup': {'prompt_toolkit_series_i': 8}, 461 | 'ProximalPhalanxOutlineCorrect': {'prompt_toolkit_series_i': 3}, 462 | 'ProximalPhalanxTW': {'prompt_toolkit_series_i': 3}, 463 | 'RefrigerationDevices': {'prompt_toolkit_series_i': 7}, 464 | 'ScreenType': {'prompt_toolkit_series_i': 6}, 'SemgHandGenderCh2': {'prompt_toolkit_series_i': 4}, 465 | 'SemgHandMovementCh2': {'prompt_toolkit_series_i': 5}, 466 | 'SemgHandSubjectCh2': {'prompt_toolkit_series_i': 7}, 467 | 'ShapeletSim': {'prompt_toolkit_series_i': 3}, 468 | 'SmallKitchenAppliances': {'prompt_toolkit_series_i': 1}, 469 | 'SmoothSubspace': {'prompt_toolkit_series_i': 4}, 470 | 'SonyAIBORobotSurface1': {'prompt_toolkit_series_i': 2}, 471 | 'SonyAIBORobotSurface2': {'prompt_toolkit_series_i': 7}, 472 | 'StarLightCurves': {'prompt_toolkit_series_i': 1}, 'Strawberry': {'prompt_toolkit_series_i': 4}, 473 | 'SwedishLeaf': {'prompt_toolkit_series_i': 6}, 'Symbols': {'prompt_toolkit_series_i': 6}, 474 | 'SyntheticControl': {'prompt_toolkit_series_i': 1}, 475 | 'ToeSegmentation1': {'prompt_toolkit_series_i': 1}, 476 | 'ToeSegmentation2': {'prompt_toolkit_series_i': 7}, 'Trace': {'prompt_toolkit_series_i': 6}, 477 | 'TwoLeadECG': {'prompt_toolkit_series_i': 1}, 'TwoPatterns': {'prompt_toolkit_series_i': 2}, 478 | 'UMD': {'prompt_toolkit_series_i': 4}, 'UWaveGestureLibraryAll': {'prompt_toolkit_series_i': 7}, 479 | 'UWaveGestureLibraryX': {'prompt_toolkit_series_i': 2}, 480 | 'UWaveGestureLibraryY': {'prompt_toolkit_series_i': 3}, 481 | 'UWaveGestureLibraryZ': {'prompt_toolkit_series_i': 5}, 'Wafer': {'prompt_toolkit_series_i': 1}, 482 | 'Wine': {'prompt_toolkit_series_i': 1}, 'WordSynonyms': {'prompt_toolkit_series_i': 8}, 483 | 'Worms': {'prompt_toolkit_series_i': 4}, 'WormsTwoClass': {'prompt_toolkit_series_i': 4}, 484 | 'Yoga': {'prompt_toolkit_series_i': 5}} 485 | 486 | ucr_hyp_prompt_40 = {'AllGestureWiimoteX': {'prompt_toolkit_series_i': 1}, 487 | 'AllGestureWiimoteY': {'prompt_toolkit_series_i': 5}, 488 | 'AllGestureWiimoteZ': {'prompt_toolkit_series_i': 6}, 'ArrowHead': {'prompt_toolkit_series_i': 1}, 489 | 'BME': {'prompt_toolkit_series_i': 4}, 'Car': {'prompt_toolkit_series_i': 3}, 490 | 'CBF': {'prompt_toolkit_series_i': 1}, 'Chinatown': {'prompt_toolkit_series_i': 1}, 491 | 'CinCECGTorso': {'prompt_toolkit_series_i': 7}, 492 | 'ChlorineConcentration': {'prompt_toolkit_series_i': 1}, 493 | 'Computers': {'prompt_toolkit_series_i': 8}, 'CricketX': {'prompt_toolkit_series_i': 6}, 494 | 'CricketY': {'prompt_toolkit_series_i': 8}, 'CricketZ': {'prompt_toolkit_series_i': 1}, 495 | 'Crop': {'prompt_toolkit_series_i': 8}, 'DiatomSizeReduction': {'prompt_toolkit_series_i': 7}, 496 | 'DistalPhalanxOutlineAgeGroup': {'prompt_toolkit_series_i': 8}, 497 | 'DistalPhalanxOutlineCorrect': {'prompt_toolkit_series_i': 1}, 498 | 'DistalPhalanxTW': {'prompt_toolkit_series_i': 5}, 499 | 'DodgerLoopGame': {'prompt_toolkit_series_i': 5}, 500 | 'DodgerLoopWeekend': {'prompt_toolkit_series_i': 1}, 'Earthquakes': {'prompt_toolkit_series_i': 1}, 501 | 'ECG200': {'prompt_toolkit_series_i': 5}, 'ECG5000': {'prompt_toolkit_series_i': 3}, 502 | 'ECGFiveDays': {'prompt_toolkit_series_i': 2}, 'ElectricDevices': {'prompt_toolkit_series_i': 7}, 503 | 'EOGHorizontalSignal': {'prompt_toolkit_series_i': 4}, 504 | 'EOGVerticalSignal': {'prompt_toolkit_series_i': 4}, 505 | 'EthanolLevel': {'prompt_toolkit_series_i': 4}, 'FaceAll': {'prompt_toolkit_series_i': 
8}, 506 | 'FacesUCR': {'prompt_toolkit_series_i': 8}, 'Fish': {'prompt_toolkit_series_i': 3}, 507 | 'FordA': {'prompt_toolkit_series_i': 3}, 'FordB': {'prompt_toolkit_series_i': 2}, 508 | 'FreezerRegularTrain': {'prompt_toolkit_series_i': 8}, 509 | 'FreezerSmallTrain': {'prompt_toolkit_series_i': 4}, 510 | 'GesturePebbleZ1': {'prompt_toolkit_series_i': 5}, 511 | 'GesturePebbleZ2': {'prompt_toolkit_series_i': 5}, 'GunPoint': {'prompt_toolkit_series_i': 6}, 512 | 'GunPointAgeSpan': {'prompt_toolkit_series_i': 8}, 513 | 'GunPointMaleVersusFemale': {'prompt_toolkit_series_i': 6}, 514 | 'GunPointOldVersusYoung': {'prompt_toolkit_series_i': 3}, 'Ham': {'prompt_toolkit_series_i': 7}, 515 | 'HandOutlines': {'prompt_toolkit_series_i': 6}, 'Haptics': {'prompt_toolkit_series_i': 6}, 516 | 'Herring': {'prompt_toolkit_series_i': 1}, 'HouseTwenty': {'prompt_toolkit_series_i': 1}, 517 | 'InlineSkate': {'prompt_toolkit_series_i': 7}, 518 | 'InsectEPGRegularTrain': {'prompt_toolkit_series_i': 2}, 519 | 'InsectEPGSmallTrain': {'prompt_toolkit_series_i': 8}, 520 | 'InsectWingbeatSound': {'prompt_toolkit_series_i': 3}, 521 | 'ItalyPowerDemand': {'prompt_toolkit_series_i': 7}, 522 | 'LargeKitchenAppliances': {'prompt_toolkit_series_i': 5}, 523 | 'Lightning2': {'prompt_toolkit_series_i': 3}, 'Mallat': {'prompt_toolkit_series_i': 2}, 524 | 'Meat': {'prompt_toolkit_series_i': 4}, 'MedicalImages': {'prompt_toolkit_series_i': 8}, 525 | 'MelbournePedestrian': {'prompt_toolkit_series_i': 1}, 526 | 'MiddlePhalanxOutlineAgeGroup': {'prompt_toolkit_series_i': 4}, 527 | 'MiddlePhalanxOutlineCorrect': {'prompt_toolkit_series_i': 2}, 528 | 'MiddlePhalanxTW': {'prompt_toolkit_series_i': 8}, 529 | 'MixedShapesRegularTrain': {'prompt_toolkit_series_i': 4}, 530 | 'MixedShapesSmallTrain': {'prompt_toolkit_series_i': 7}, 531 | 'MoteStrain': {'prompt_toolkit_series_i': 5}, 532 | 'NonInvasiveFetalECGThorax1': {'prompt_toolkit_series_i': 1}, 533 | 'NonInvasiveFetalECGThorax2': {'prompt_toolkit_series_i': 5}, 534 | 'OSULeaf': {'prompt_toolkit_series_i': 3}, 535 | 'PhalangesOutlinesCorrect': {'prompt_toolkit_series_i': 8}, 536 | 'Phoneme': {'prompt_toolkit_series_i': 1}, 'PLAID': {'prompt_toolkit_series_i': 5}, 537 | 'Plane': {'prompt_toolkit_series_i': 4}, 'PowerCons': {'prompt_toolkit_series_i': 2}, 538 | 'ProximalPhalanxOutlineAgeGroup': {'prompt_toolkit_series_i': 2}, 539 | 'ProximalPhalanxOutlineCorrect': {'prompt_toolkit_series_i': 7}, 540 | 'ProximalPhalanxTW': {'prompt_toolkit_series_i': 5}, 541 | 'RefrigerationDevices': {'prompt_toolkit_series_i': 1}, 542 | 'ScreenType': {'prompt_toolkit_series_i': 1}, 'SemgHandGenderCh2': {'prompt_toolkit_series_i': 6}, 543 | 'SemgHandMovementCh2': {'prompt_toolkit_series_i': 8}, 544 | 'SemgHandSubjectCh2': {'prompt_toolkit_series_i': 5}, 545 | 'ShapeletSim': {'prompt_toolkit_series_i': 1}, 546 | 'SmallKitchenAppliances': {'prompt_toolkit_series_i': 8}, 547 | 'SmoothSubspace': {'prompt_toolkit_series_i': 2}, 548 | 'SonyAIBORobotSurface1': {'prompt_toolkit_series_i': 4}, 549 | 'SonyAIBORobotSurface2': {'prompt_toolkit_series_i': 1}, 550 | 'StarLightCurves': {'prompt_toolkit_series_i': 6}, 'Strawberry': {'prompt_toolkit_series_i': 3}, 551 | 'SwedishLeaf': {'prompt_toolkit_series_i': 5}, 'Symbols': {'prompt_toolkit_series_i': 4}, 552 | 'SyntheticControl': {'prompt_toolkit_series_i': 1}, 553 | 'ToeSegmentation1': {'prompt_toolkit_series_i': 7}, 554 | 'ToeSegmentation2': {'prompt_toolkit_series_i': 6}, 'Trace': {'prompt_toolkit_series_i': 1}, 555 | 'TwoLeadECG': {'prompt_toolkit_series_i': 7}, 
'TwoPatterns': {'prompt_toolkit_series_i': 2}, 556 | 'UMD': {'prompt_toolkit_series_i': 2}, 'UWaveGestureLibraryAll': {'prompt_toolkit_series_i': 8}, 557 | 'UWaveGestureLibraryX': {'prompt_toolkit_series_i': 7}, 558 | 'UWaveGestureLibraryY': {'prompt_toolkit_series_i': 4}, 559 | 'UWaveGestureLibraryZ': {'prompt_toolkit_series_i': 6}, 'Wafer': {'prompt_toolkit_series_i': 2}, 560 | 'Wine': {'prompt_toolkit_series_i': 6}, 'WordSynonyms': {'prompt_toolkit_series_i': 4}, 561 | 'Worms': {'prompt_toolkit_series_i': 2}, 'WormsTwoClass': {'prompt_toolkit_series_i': 7}, 562 | 'Yoga': {'prompt_toolkit_series_i': 6}} 563 | -------------------------------------------------------------------------------- /diffshape_ssc/semi_backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SqueezeChannels(nn.Module): 5 | def __init__(self): 6 | super(SqueezeChannels, self).__init__() 7 | 8 | def forward(self, x): 9 | return x.squeeze(2) 10 | 11 | 12 | class FCN(nn.Module): 13 | def __init__(self, num_classes, input_size=1): 14 | super(FCN, self).__init__() 15 | 16 | self.num_classes = num_classes 17 | self.conv_block1 = nn.Sequential( 18 | nn.Conv1d(in_channels=input_size, out_channels=128, kernel_size=8, padding='same'), 19 | nn.BatchNorm1d(128), 20 | nn.ReLU() 21 | ) 22 | 23 | self.conv_block2 = nn.Sequential( 24 | nn.Conv1d(in_channels=128, out_channels=256, kernel_size=5, padding='same'), 25 | nn.BatchNorm1d(256), 26 | nn.ReLU() 27 | ) 28 | 29 | self.conv_block3 = nn.Sequential( 30 | nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding='same'), 31 | nn.BatchNorm1d(128), 32 | nn.ReLU() 33 | ) 34 | 35 | self.network = nn.Sequential( 36 | self.conv_block1, 37 | self.conv_block2, 38 | self.conv_block3, 39 | nn.AdaptiveAvgPool1d(1), 40 | SqueezeChannels(), 41 | ) 42 | 43 | def forward(self, x): 44 | return self.network(x) 45 | 46 | 47 | class Classifier(nn.Module): 48 | def __init__(self, input_dims, output_dims) -> None: 49 | super(Classifier, self).__init__() 50 | 51 | self.dense = nn.Linear(input_dims, output_dims) 52 | self.softmax = nn.Softmax(dim=1) 53 | 54 | def forward(self, x): 55 | return self.softmax(self.dense(x)) 56 | 57 | 58 | class ProjectionHead(nn.Module): 59 | def __init__(self, input_dim, embedding_dim=384, output_dim=128) -> None: 60 | super(ProjectionHead, self).__init__() 61 | self.projection_head = nn.Sequential( 62 | nn.Linear(input_dim, embedding_dim), 63 | nn.BatchNorm1d(embedding_dim), 64 | nn.ReLU(inplace=True), 65 | nn.Linear(embedding_dim, output_dim), 66 | ) 67 | 68 | def forward(self, x): 69 | return self.projection_head(x) 70 | -------------------------------------------------------------------------------- /diffshape_ssc/semi_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from sklearn.model_selection import StratifiedKFold 7 | from sklearn.metrics import accuracy_score 8 | from transformers import T5Tokenizer, T5Model 9 | from semi_backbone import FCN, Classifier 10 | import torch.utils.data as data 11 | import random 12 | import torch.nn as nn 13 | 14 | DEFAULT_T5_NAME = 't5-base' 15 | MODEL_NAME = DEFAULT_T5_NAME 16 | 17 | prompt_toolkit_series = ['This time series is ', 'This time series can be described as ', 18 | 'The key attributes of this time series are ', 19 | 'The nature of this time series is depicted by ', 20 | 
'Here, the time series is defined by ', 21 | 'Describing this time series, we find ', 'Examining this time series reveals ', 22 | 'The features exhibited in this time series are '] 23 | prompt_toolkit_end = ['.'] 24 | 25 | 26 | def set_seed(args): 27 | random.seed(args.random_seed) 28 | np.random.seed(args.random_seed) 29 | torch.manual_seed(args.random_seed) 30 | torch.cuda.manual_seed(args.random_seed) 31 | torch.cuda.manual_seed_all(args.random_seed) 32 | 33 | 34 | def build_loss(args): 35 | if args.loss == 'cross_entropy': 36 | return nn.CrossEntropyLoss() 37 | elif args.loss == 'reconstruction': 38 | return nn.MSELoss() 39 | 40 | 41 | def lan_shapelet_contrastive_loss(embd_batch, text_embd_batch, labels, device, 42 | temperature=0.07, base_temperature=0.07): 43 | anchor_dot_contrast = torch.div( 44 | torch.matmul(embd_batch, text_embd_batch.T), 45 | temperature) 46 | 47 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 48 | logits = anchor_dot_contrast - logits_max.detach() 49 | labels = labels.contiguous().view(-1, 1) 50 | mask = torch.eq(labels, labels.T).float().to(device) 51 | 52 | logits_mask = torch.scatter( 53 | torch.ones_like(logits.detach()), 54 | 1, 55 | torch.arange(embd_batch.shape[0]).view(-1, 1).to(device), 56 | 0 57 | ) 58 | mask = mask * logits_mask 59 | 60 | exp_logits = torch.exp(logits) * logits_mask 61 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12) 62 | 63 | mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-12) 64 | num_anchor = 0 65 | for s in mask.sum(1): 66 | if s != 0: 67 | num_anchor = num_anchor + 1 68 | 69 | loss = - (temperature / base_temperature) * mean_log_prob_pos 70 | loss = loss.sum(0) / (num_anchor + 1e-12) 71 | 72 | return loss 73 | 74 | 75 | def get_all_text_labels(ucr_datasets_dict, dataset_name, num_labels, device, prompt_toolkit_series_i=None): 76 | text_labels = torch.empty(size=(num_labels, 768)) 77 | for _label in range(num_labels): 78 | text_labels[_label] = get_ont_text_label(ucr_datasets_dict=ucr_datasets_dict, dataset_name=dataset_name, 79 | num_label=_label, device=device, 80 | prompt_toolkit_series_i=prompt_toolkit_series_i) 81 | 82 | return text_labels.to(device) 83 | 84 | 85 | def get_ont_text_label(ucr_datasets_dict, dataset_name, num_label, device, prompt_toolkit_series_i=None): 86 | text_label = ucr_datasets_dict[dataset_name][str(num_label)] 87 | 88 | if prompt_toolkit_series_i is not None: 89 | text_label = prompt_toolkit_series[prompt_toolkit_series_i] + text_label + prompt_toolkit_end[0] 90 | 91 | # loading model and tokenizer 92 | tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME) 93 | model = T5Model.from_pretrained(MODEL_NAME) 94 | inputs = tokenizer(text_label, return_tensors='pt', padding=True) 95 | 96 | output = model.encoder(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], return_dict=True) 97 | pooled_sentence = output.last_hidden_state.detach() # shape is [batch_size, seq_len, hidden_size] 98 | 99 | return torch.mean(pooled_sentence, dim=1)[0] 100 | 101 | 102 | def get_each_sample_distance_shapelet(generator_shapelet, raw_shapelet, topk=1): 103 | if len(generator_shapelet.shape) < 3: 104 | emb1 = torch.unsqueeze(generator_shapelet, 1) # n*1*d 105 | else: 106 | emb1 = generator_shapelet 107 | emb2 = torch.unsqueeze(raw_shapelet, 0) # 1*n*d 108 | w = ((emb1 - emb2) ** 2).mean(2) # n*n*d -> n*n 109 | w = torch.exp(-w / 2) 110 | 111 | topk, indices = torch.topk(w, topk) 112 | 113 | indices = torch.squeeze(indices) 114 | 115 | 
indices_raw_shapelets = raw_shapelet[indices] 116 | 117 | return topk.reshape(-1), indices_raw_shapelets 118 | 119 | 120 | def get_similarity_shapelet(generator_shapelet): 121 | generator_shapelet = torch.squeeze(generator_shapelet) 122 | 123 | emb1 = torch.unsqueeze(generator_shapelet, 1) # n*1*d 124 | emb2 = torch.unsqueeze(generator_shapelet, 0) # 1*n*d 125 | 126 | w = ((emb1 - emb2) ** 2).mean(2) # n*n*d -> n*n 127 | w = torch.exp(-w / 2) 128 | 129 | return torch.norm(w) 130 | 131 | 132 | def build_dataset(args): 133 | sum_dataset, sum_target, num_classes = load_data(args.dataroot, args.dataset) 134 | 135 | sum_target = transfer_labels(sum_target) 136 | return sum_dataset, sum_target, num_classes 137 | 138 | 139 | def build_model(args): 140 | if args.backbone == 'fcn': 141 | model = FCN(args.num_classes, args.input_size) 142 | 143 | if args.classifier == 'linear': 144 | classifier = Classifier(args.classifier_input, args.num_classes) 145 | 146 | return model, classifier 147 | 148 | 149 | def load_data(dataroot, dataset): 150 | train = pd.read_csv(os.path.join(dataroot, dataset, dataset + '_TRAIN.tsv'), sep='\t', header=None) 151 | train_x = train.iloc[:, 1:] 152 | train_target = train.iloc[:, 0] 153 | 154 | test = pd.read_csv(os.path.join(dataroot, dataset, dataset + '_TEST.tsv'), sep='\t', header=None) 155 | test_x = test.iloc[:, 1:] 156 | test_target = test.iloc[:, 0] 157 | 158 | sum_dataset = pd.concat([train_x, test_x]).to_numpy(dtype=np.float32) 159 | sum_target = pd.concat([train_target, test_target]).to_numpy(dtype=np.float32) 160 | 161 | num_classes = len(np.unique(sum_target)) 162 | 163 | return sum_dataset, sum_target, num_classes 164 | 165 | 166 | def transfer_labels(labels): 167 | indicies = np.unique(labels) 168 | num_samples = labels.shape[0] 169 | 170 | for i in range(num_samples): 171 | new_label = np.argwhere(labels[i] == indicies)[0][0] 172 | labels[i] = new_label 173 | 174 | return labels 175 | 176 | 177 | def get_all_datasets(data, target): 178 | return k_fold(data, target) 179 | 180 | 181 | def k_fold(data, target): 182 | skf = StratifiedKFold(5, shuffle=True) 183 | train_sets = [] 184 | train_targets = [] 185 | 186 | val_sets = [] 187 | val_targets = [] 188 | 189 | test_sets = [] 190 | test_targets = [] 191 | 192 | for raw_index, test_index in skf.split(data, target): 193 | raw_set = data[raw_index] 194 | raw_target = target[raw_index] 195 | 196 | train_index, val_index = next(StratifiedKFold(4, shuffle=True).split(raw_set, raw_target)) 197 | train_sets.append(raw_set[train_index]) 198 | train_targets.append(raw_target[train_index]) 199 | 200 | val_sets.append(raw_set[val_index]) 201 | val_targets.append(raw_target[val_index]) 202 | 203 | test_sets.append(data[test_index]) 204 | test_targets.append(target[test_index]) 205 | 206 | return train_sets, train_targets, val_sets, val_targets, test_sets, test_targets 207 | 208 | 209 | def normalize_per_series(data): 210 | std_ = data.std(axis=1, keepdims=True) 211 | std_[std_ == 0] = 1.0 212 | return (data - data.mean(axis=1, keepdims=True)) / std_ 213 | 214 | 215 | def fill_nan_value(train_set, val_set, test_set): 216 | ind = np.where(np.isnan(train_set)) 217 | col_mean = np.nanmean(train_set, axis=0) 218 | col_mean[np.isnan(col_mean)] = 1e-6 219 | 220 | train_set[ind] = np.take(col_mean, ind[1]) 221 | 222 | ind_val = np.where(np.isnan(val_set)) 223 | val_set[ind_val] = np.take(col_mean, ind_val[1]) 224 | 225 | ind_test = np.where(np.isnan(test_set)) 226 | test_set[ind_test] = np.take(col_mean, ind_test[1]) 227 | return 
train_set, val_set, test_set 228 | 229 | 230 | def shuffler(x_train, y_train): 231 | indexes = np.array(list(range(x_train.shape[0]))) 232 | np.random.shuffle(indexes) 233 | x_train = x_train[indexes] 234 | y_train = y_train[indexes] 235 | return x_train, y_train 236 | 237 | 238 | class UCRDataset(data.Dataset): 239 | def __init__(self, dataset, target): 240 | self.dataset = dataset 241 | if len(self.dataset.shape) == 2: 242 | self.dataset = torch.unsqueeze(self.dataset, 1) 243 | self.target = target 244 | 245 | def __getitem__(self, index): 246 | return self.dataset[index], self.target[index] 247 | 248 | def __len__(self): 249 | return len(self.target) 250 | 251 | 252 | def evaluate_model_acc(val_loader, model, fcn_model, fcn_classifier): 253 | target_true = [] 254 | target_pred = [] 255 | 256 | num_val_samples = 0 257 | for data, target in val_loader: 258 | with torch.no_grad(): 259 | predicted = model(torch.unsqueeze(data, 2)) ## torch.unsqueeze(x, 2) 260 | 261 | fcn_cls_emb = fcn_model(torch.squeeze(predicted, 2)) 262 | val_pred = fcn_classifier(fcn_cls_emb) 263 | 264 | target_true.append(target.cpu().numpy()) 265 | target_pred.append(torch.argmax(val_pred.data, axis=1).cpu().numpy()) 266 | num_val_samples = num_val_samples + len(target) 267 | 268 | target_true = np.concatenate(target_true) 269 | target_pred = np.concatenate(target_pred) 270 | 271 | return accuracy_score(target_true, target_pred) 272 | 273 | 274 | def evaluate(val_loader, model, classifier, loss): 275 | target_true = [] 276 | target_pred = [] 277 | 278 | val_loss = 0 279 | sum_len = 0 280 | for data, target in val_loader: 281 | with torch.no_grad(): 282 | val_pred = model(data) 283 | val_pred = classifier(val_pred) 284 | val_loss += loss(val_pred, target).item() 285 | target_true.append(target.cpu().numpy()) 286 | target_pred.append(torch.argmax(val_pred.data, axis=1).cpu().numpy()) 287 | sum_len += len(target) 288 | 289 | return val_loss / sum_len, accuracy_score(np.concatenate(target_true), np.concatenate(target_pred)) # flatten per-batch results before scoring 290 | 291 | 292 | class ProjectionHead(nn.Module): 293 | def __init__(self, input_dim, embedding_dim=64, output_dim=32) -> None: 294 | super(ProjectionHead, self).__init__() 295 | self.projection_head = nn.Sequential( 296 | nn.Linear(input_dim, embedding_dim), 297 | nn.BatchNorm1d(embedding_dim), 298 | nn.ReLU(inplace=True), 299 | nn.Linear(embedding_dim, output_dim), 300 | ) 301 | 302 | def forward(self, x): 303 | return self.projection_head(x) 304 | 305 | 306 | def sup_contrastive_loss(embd_batch, labels, device, 307 | temperature=0.07, base_temperature=0.07): 308 | anchor_dot_contrast = torch.div( 309 | torch.matmul(embd_batch, embd_batch.T), 310 | temperature) 311 | 312 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 313 | logits = anchor_dot_contrast - logits_max.detach() 314 | labels = labels.contiguous().view(-1, 1) 315 | mask = torch.eq(labels, labels.T).float().to(device) 316 | 317 | logits_mask = torch.scatter( 318 | torch.ones_like(logits.detach()), 319 | 1, 320 | torch.arange(embd_batch.shape[0]).view(-1, 1).to(device), 321 | 0 322 | ) 323 | mask = mask * logits_mask 324 | # compute log_prob 325 | exp_logits = torch.exp(logits) * logits_mask 326 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12) 327 | 328 | # compute mean of log-likelihood over positive 329 | mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-12) 330 | num_anchor = 0 331 | for s in mask.sum(1): 332 | if s != 0: 333 | num_anchor = num_anchor + 1 334 | # loss 335 | loss = - (temperature / 
base_temperature) * mean_log_prob_pos 336 | loss = loss.sum(0) / (num_anchor + 1e-12) 337 | 338 | return loss 339 | 340 | 341 | def get_pesudo_via_high_confidence_softlabels(y_label, pseudo_label_soft, mask_label, num_real_class, device, 342 | p_cutoff=0.95): 343 | all_end_label = torch.argmax(pseudo_label_soft, 1) 344 | pseudo_label_hard = torch.argmax(pseudo_label_soft, 1) 345 | 346 | class_counter = Counter(y_label[mask_label]) 347 | for i in range(num_real_class): 348 | class_counter[i] = 0 349 | 350 | for i in range(len(mask_label)): 351 | if not mask_label[i]: ## unlabeled data ('is False' would miss numpy bools) 352 | class_counter[pseudo_label_hard[i]] += 1 353 | else: 354 | all_end_label[i] = y_label[i] 355 | 356 | classwise_acc = torch.zeros((num_real_class,)).to(device) 357 | for i in range(num_real_class): 358 | classwise_acc[i] = class_counter[i] / max(class_counter.values()) 359 | 360 | pseudo_label = torch.softmax(pseudo_label_soft, dim=-1) 361 | max_probs, max_idx = torch.max(pseudo_label, dim=-1) 362 | cpl_mask = max_probs.ge(p_cutoff * (classwise_acc[max_idx] / (2. - classwise_acc[max_idx]))) # per-class adaptive confidence threshold 363 | 364 | end_mask_labeled = mask_label.copy() 365 | for i in range(len(end_mask_labeled)): 366 | if not end_mask_labeled[i]: 367 | if cpl_mask[i]: 368 | end_mask_labeled[i] = True 369 | 370 | return end_mask_labeled, all_end_label 371 | --------------------------------------------------------------------------------
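The listing above ends with the data and evaluation utilities in `semi_utils.py`. As a quick orientation, the following minimal sketch (not part of the repository) shows one way these helpers could be wired together for a single cross-validation fold. The data root `./UCRArchive_2018`, the dataset name `GunPoint`, and the batch size are illustrative assumptions, and the flat imports assume the snippet is executed from inside `diffshape_ssc/`; for the complete semi-supervised training procedure, see `main_diffshape.py`.

```python
# Minimal usage sketch (not part of the repository); all paths and
# hyperparameters below are assumptions, adjust them to your setup.
import argparse

import torch
from torch.utils.data import DataLoader

from semi_backbone import FCN, Classifier
from semi_utils import (UCRDataset, build_dataset, evaluate, fill_nan_value,
                        get_all_datasets, normalize_per_series)

# Hypothetical data root and dataset name.
args = argparse.Namespace(dataroot='./UCRArchive_2018', dataset='GunPoint')
x, y, num_classes = build_dataset(args)  # train+test concatenated, labels remapped to 0..C-1
train_xs, train_ys, val_xs, val_ys, test_xs, test_ys = get_all_datasets(x, y)

# Take the first of the five stratified folds; NaNs are filled with
# column means computed on the training split.
train_x, val_x, test_x = fill_nan_value(train_xs[0], val_xs[0], test_xs[0])
val_x = normalize_per_series(val_x)

val_set = UCRDataset(torch.from_numpy(val_x), torch.from_numpy(val_ys[0]).long())
val_loader = DataLoader(val_set, batch_size=128, shuffle=False)

model = FCN(num_classes, input_size=1)     # 3-block 1D CNN backbone with global average pooling
classifier = Classifier(128, num_classes)  # the last FCN block outputs 128 channels
model.eval(); classifier.eval()

val_loss, val_acc = evaluate(val_loader, model, classifier, torch.nn.CrossEntropyLoss())
print(f'untrained baseline: loss={val_loss:.4f}, acc={val_acc:.4f}')
```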